diff options
-rw-r--r-- | paramiko/auth_handler.py | 5 | ||||
-rw-r--r-- | paramiko/client.py | 7 | ||||
-rw-r--r-- | paramiko/transport.py | 2 | ||||
-rw-r--r-- | sites/www/changelog.rst | 3 | ||||
-rw-r--r-- | tests/test_auth.py | 25 | ||||
-rw-r--r-- | tests/test_client.py | 21 |
6 files changed, 62 insertions, 1 deletions
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index 33f01da6..6131db3f 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -21,6 +21,7 @@ """ import weakref +import time from paramiko.common import ( cMSG_SERVICE_REQUEST, cMSG_DISCONNECT, DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, cMSG_USERAUTH_REQUEST, @@ -190,6 +191,7 @@ class AuthHandler (object): return m.asbytes() def wait_for_response(self, event): + max_ts = time.time() + self.transport.auth_timeout if self.transport.auth_timeout is not None else None while True: event.wait(0.1) if not self.transport.is_active(): @@ -199,6 +201,9 @@ class AuthHandler (object): raise e if event.is_set(): break + if max_ts is not None and max_ts <= time.time(): + raise AuthenticationException('Authentication timeout.') + if not self.is_authenticated(): e = self.transport.get_exception() if e is None: diff --git a/paramiko/client.py b/paramiko/client.py index 08fe69d4..8b54d553 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -227,7 +227,8 @@ class SSHClient (ClosingContextManager): gss_kex=False, gss_deleg_creds=True, gss_host=None, - banner_timeout=None + banner_timeout=None, + auth_timeout=None ): """ Connect to an SSH server and authenticate to it. The server's host key @@ -279,6 +280,8 @@ class SSHClient (ClosingContextManager): The targets name in the kerberos database. default: hostname :param float banner_timeout: an optional timeout (in seconds) to wait for the SSH banner to be presented. + :param float auth_timeout: an optional timeout (in seconds) to wait for + an authentication response. :raises: `.BadHostKeyException` -- if the server's host key could not be @@ -339,6 +342,8 @@ class SSHClient (ClosingContextManager): t.set_log_channel(self._log_channel) if banner_timeout is not None: t.banner_timeout = banner_timeout + if auth_timeout is not None: + t.auth_timeout = auth_timeout t.start_client(timeout=timeout) t.set_sshclient(self) ResourceManager.register(self, t) diff --git a/paramiko/transport.py b/paramiko/transport.py index 688e09e7..998212a2 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -397,6 +397,8 @@ class Transport(threading.Thread, ClosingContextManager): # how long (seconds) to wait for the handshake to finish after SSH # banner sent. self.handshake_timeout = 15 + # how long (seconds) to wait for the auth response. + self.auth_timeout = 30 # server mode: self.server_mode = False diff --git a/sites/www/changelog.rst b/sites/www/changelog.rst index 5c0b3552..234d9df6 100644 --- a/sites/www/changelog.rst +++ b/sites/www/changelog.rst @@ -2,6 +2,9 @@ Changelog ========= +* :feature:`add-auth-timeout` Adds a timeout for the authentication process. + This is a fix to prevent the client getting stuck if an SSH server becomes + un-responsive during the authentication. Credit to ``@timsavage``. * :support:`921` Tighten up the ``__hash__`` implementation for various key classes; less code is good code. Thanks to Francisco Couzo for the patch. * :bug:`983` Move ``sha1`` above the now-arguably-broken ``md5`` in the list of diff --git a/tests/test_auth.py b/tests/test_auth.py index 96f7611c..58b2f44f 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -23,6 +23,7 @@ Some unit tests for authenticating over a Transport. import sys import threading import unittest +from time import sleep from paramiko import ( Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType, @@ -74,6 +75,9 @@ class NullServer (ServerInterface): return AUTH_SUCCESSFUL if username == 'bad-server': raise Exception("Ack!") + if username == 'unresponsive-server': + sleep(5) + return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): @@ -233,3 +237,24 @@ class AuthTest (unittest.TestCase): except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) + + def test_9_auth_non_responsive(self): + """ + verify that authentication times out if server takes to long to + respond (or never responds). + """ + auth_timeout = self.tc.auth_timeout + self.tc.auth_timeout = 2 # Reduce to 2 seconds to speed up test + + try: + self.start_server() + self.tc.connect() + try: + remain = self.tc.auth_password('unresponsive-server', 'hello') + except: + etype, evalue, etb = sys.exc_info() + self.assertTrue(issubclass(etype, AuthenticationException)) + self.assertTrue('Authentication timeout' in str(evalue)) + finally: + # Restore value + self.tc.auth_timeout = auth_timeout diff --git a/tests/test_client.py b/tests/test_client.py index 3a9001e2..aa3ff59b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -61,6 +61,9 @@ class NullServer (paramiko.ServerInterface): def check_auth_password(self, username, password): if (username == 'slowdive') and (password == 'pygmalion'): return paramiko.AUTH_SUCCESSFUL + if (username == 'slowdive') and (password == 'unresponsive-server'): + time.sleep(5) + return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): @@ -384,6 +387,24 @@ class SSHClientTest (unittest.TestCase): ) self._test_connection(**kwargs) + def test_9_auth_timeout(self): + """ + verify that the SSHClient has a configurable auth timeout + """ + threading.Thread(target=self._run).start() + host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + public_host_key = paramiko.RSAKey(data=host_key.asbytes()) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + # Connect with a half second auth timeout + kwargs = dict(self.connect_kwargs, password='unresponsive-server', auth_timeout=0.5) + self.assertRaises( + paramiko.AuthenticationException, + self.tc.connect, + **kwargs + ) + def test_update_environment(self): """ Verify that environment variables can be set by the client. |