diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket/epsocket.go')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 2a38e370a..9d1bcfd41 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -40,7 +40,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -286,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // GetAddress reads an sockaddr struct from the given address and converts it // to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 // addresses. -func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { +func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) { + if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -318,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -331,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -344,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { } return out, nil + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, nil + default: return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -466,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, false /* strict */) if err != nil { return err } @@ -499,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -537,7 +539,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -575,10 +577,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - fdFlags := kernel.FDFlags{ + fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, - } - fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) + }) t.Kernel().RecordSocket(ns) @@ -1924,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to) + addrBuf, err := GetAddress(s.family, to, true /* strict */) if err != nil { return 0, err } |