diff options
Diffstat (limited to 'pkg/sentry/socket/hostinet/socket.go')
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 35 |
1 files changed, 20 insertions, 15 deletions
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 618ea42c7..92beb1bcf 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { - var peerAddr []byte +func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + var peerAddr linux.SockAddr + var peerAddrBuf []byte var peerAddrlen uint32 var peerAddrPtr *byte var peerAddrlenPtr *uint32 if peerRequested { - peerAddr = make([]byte, sizeofSockaddr) - peerAddrlen = uint32(len(peerAddr)) - peerAddrPtr = &peerAddr[0] + peerAddrBuf = make([]byte, sizeofSockaddr) + peerAddrlen = uint32(len(peerAddrBuf)) + peerAddrPtr = &peerAddrBuf[0] peerAddrlenPtr = &peerAddrlen } @@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } if peerRequested { - peerAddr = peerAddr[:peerAddrlen] + peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen]) } if syscallErr != nil { return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) @@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary @@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr []byte + var senderAddr linux.SockAddr + var senderAddrBuf []byte if senderRequested { - senderAddr = make([]byte, sizeofSockaddr) + senderAddrBuf = make([]byte, sizeofSockaddr) } var msgFlags int @@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if dsts.NumBlocks() == 1 { // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr) + return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) } iovs := iovecsFromBlockSeq(dsts) @@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Iov: &iovs[0], Iovlen: uint64(len(iovs)), } - if len(senderAddr) != 0 { - msg.Name = &senderAddr[0] - msg.Namelen = uint32(len(senderAddr)) + if len(senderAddrBuf) != 0 { + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(len(senderAddrBuf)) } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } - senderAddr = senderAddr[:msg.Namelen] + senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) return n, nil }) @@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We don't allow control messages. msgFlags &^= linux.MSG_CTRUNC - return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err) + if senderRequested { + senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) + } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) } // SendMsg implements socket.Socket.SendMsg. |