diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index b2b2d98a1..9d1bcfd41 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -285,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // GetAddress reads an sockaddr struct from the given address and converts it // to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 // addresses. -func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { +func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) { + if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -317,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -330,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -343,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { } return out, nil + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, nil + default: return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -465,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, false /* strict */) if err != nil { return err } @@ -498,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -1922,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to) + addrBuf, err := GetAddress(s.family, to, true /* strict */) if err != nil { return 0, err } |