diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 45 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 2 | ||||
-rw-r--r-- | pkg/sentry/strace/socket.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 6 |
9 files changed, 55 insertions, 38 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 4d8a5ac22..635042263 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -291,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// GetAddress reads an sockaddr struct from the given address and converts it -// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 -// addresses. -func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and +// AF_INET6 addresses. +// +// strict indicates whether addresses with the AF_UNSPEC family are accepted of not. +// +// AddressAndFamily returns an address, its family. +func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported + return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } // Get the rest of the fields based on the address family. @@ -310,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } // Drop the terminating NUL (if one exists) and everything after // it for filesystem (non-abstract) addresses. @@ -321,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse } return tcpip.FullAddress{ Addr: tcpip.Address(path), - }, nil + }, family, nil case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -334,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse Addr: bytesToIPAddress(a.Addr[:]), Port: ntohs(a.Port), } - return out, nil + return out, family, nil case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -350,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse if isLinkLocal(out.Addr) { out.NIC = tcpip.NICID(a.Scope_id) } - return out, nil + return out, family, nil case linux.AF_UNSPEC: - return tcpip.FullAddress{}, nil + return tcpip.FullAddress{}, family, nil default: - return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported } } @@ -482,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr, false /* strict */) + addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) if err != nil { return err } + if family == linux.AF_UNSPEC { + err := s.Endpoint.Disconnect() + if err == tcpip.ErrNotSupported { + return syserr.ErrAddressFamilyNotSupported + } + return syserr.TranslateNetstackError(err) + } // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -515,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr, true /* strict */) + addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -2023,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to, true /* strict */) + addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) if err != nil { return 0, err } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 8a3f65236..0d0cb68df 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */) + addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 386b40af7..f779186ad 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, err := epsocket.GetAddress(int(family), b, true /* strict */) + fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8146e8444..1c811ab68 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -105,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ba2dd85b8..29a6025d9 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -358,6 +358,9 @@ type Endpoint interface { // ErrAddressFamilyNotSupported must be returned. Connect(address FullAddress) *Error + // Disconnect disconnects the endpoint from its peer. + Disconnect() *Error + // Shutdown closes the read and/or write end of the endpoint connection // to its peer. Shutdown(flags ShutdownFlags) *Error diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 2e8d5d4bf..451d3880e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -428,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - nicid := addr.NIC localPort := uint16(0) switch e.state { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 53c9515a4..13e17e2a6 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -349,16 +349,16 @@ func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) return 0, tcpip.ControlMessages{}, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect implements tcpip.Endpoint.Connect. func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - if ep.closed { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index e5f835c20..24b32e4af 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1362,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" && addr.Port == 0 { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - return e.connect(addr, true, true) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 640bb8667..8b3356406 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -711,7 +711,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } -func (e *endpoint) disconnect() *tcpip.Error { +// Disconnect implements tcpip.Endpoint.Disconnect. +func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -750,9 +751,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - if addr.Addr == "" { - return e.disconnect() - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState |