summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket/epsocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket/epsocket.go')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go25
1 files changed, 13 insertions, 12 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 2a38e370a..9d1bcfd41 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -40,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -286,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// GetAddress reads an sockaddr struct from the given address and converts it
// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
// addresses.
-func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
+func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) {
+ if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@@ -318,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrBadAddress
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -331,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrBadAddress
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -344,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
}
return out, nil
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, nil
+
default:
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@@ -466,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr)
+ addr, err := GetAddress(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
@@ -499,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr)
+ addr, err := GetAddress(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -537,7 +539,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, wq, terr := s.Endpoint.Accept()
if terr != nil {
@@ -575,10 +577,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- fdFlags := kernel.FDFlags{
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- }
- fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ })
t.Kernel().RecordSocket(ns)
@@ -1924,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to)
+ addrBuf, err := GetAddress(s.family, to, true /* strict */)
if err != nil {
return 0, err
}