summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/hostinet
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/hostinet')
-rw-r--r--pkg/sentry/socket/hostinet/socket.go35
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go10
2 files changed, 26 insertions, 19 deletions
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 618ea42c7..92beb1bcf 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
- var peerAddr []byte
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ var peerAddr linux.SockAddr
+ var peerAddrBuf []byte
var peerAddrlen uint32
var peerAddrPtr *byte
var peerAddrlenPtr *uint32
if peerRequested {
- peerAddr = make([]byte, sizeofSockaddr)
- peerAddrlen = uint32(len(peerAddr))
- peerAddrPtr = &peerAddr[0]
+ peerAddrBuf = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddrBuf))
+ peerAddrPtr = &peerAddrBuf[0]
peerAddrlenPtr = &peerAddrlen
}
@@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
if peerRequested {
- peerAddr = peerAddr[:peerAddrlen]
+ peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
}
if syscallErr != nil {
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
@@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
@@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}
- var senderAddr []byte
+ var senderAddr linux.SockAddr
+ var senderAddrBuf []byte
if senderRequested {
- senderAddr = make([]byte, sizeofSockaddr)
+ senderAddrBuf = make([]byte, sizeofSockaddr)
}
var msgFlags int
@@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if dsts.NumBlocks() == 1 {
// Skip allocating []syscall.Iovec.
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr)
+ return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf)
}
iovs := iovecsFromBlockSeq(dsts)
@@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
}
- if len(senderAddr) != 0 {
- msg.Name = &senderAddr[0]
- msg.Namelen = uint32(len(senderAddr))
+ if len(senderAddrBuf) != 0 {
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(len(senderAddrBuf))
}
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
return 0, err
}
- senderAddr = senderAddr[:msg.Namelen]
+ senderAddrBuf = senderAddrBuf[:msg.Namelen]
msgFlags = int(msg.Flags)
return n, nil
})
@@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We don't allow control messages.
msgFlags &^= linux.MSG_CTRUNC
- return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
+ if senderRequested {
+ senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
+ }
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err)
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
index 6c69ba9c7..e69ec38c2 100644
--- a/pkg/sentry/socket/hostinet/socket_unsafe.go
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -18,10 +18,12 @@ import (
"syscall"
"unsafe"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -91,25 +93,25 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
if errno != 0 {
return nil, 0, syserr.FromError(errno)
}
- return addr[:addrlen], addrlen, nil
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
if errno != 0 {
return nil, 0, syserr.FromError(errno)
}
- return addr[:addrlen], addrlen, nil
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
}
func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {