diff options
-rw-r--r-- | paramiko/client.py | 6 | ||||
-rw-r--r-- | tests/test_client.py | 27 |
2 files changed, 31 insertions, 2 deletions
diff --git a/paramiko/client.py b/paramiko/client.py index 4326abbd..4f75cefe 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -172,7 +172,7 @@ class SSHClient (object): def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None, key_filename=None, timeout=None, allow_agent=True, look_for_keys=True, - compress=False, sock=None): + compress=False, sock=None, banner_timeout=None): """ Connect to an SSH server and authenticate to it. The server's host key is checked against the system host keys (see `load_system_host_keys`) @@ -212,6 +212,8 @@ class SSHClient (object): :param socket sock: an open socket or socket-like object (such as a `.Channel`) to use for communication to the target host + :param float banner_timeout: an optional timeout (in seconds) to wait + for the SSH banner to be presented. :raises BadHostKeyException: if the server's host key could not be verified @@ -241,6 +243,8 @@ class SSHClient (object): t.use_compression(compress=compress) if self._log_channel is not None: t.set_log_channel(self._log_channel) + if banner_timeout is not None: + t.banner_timeout = banner_timeout t.start_client() ResourceManager.register(self, t) diff --git a/tests/test_client.py b/tests/test_client.py index 33dd9f23..6fda7f5e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,6 +27,7 @@ import unittest import weakref import warnings import os +import time from tests.util import test_path import paramiko from paramiko.common import PY2, b @@ -91,7 +92,7 @@ class SSHClientTest (unittest.TestCase): if hasattr(self, attr): getattr(self, attr).close() - def _run(self, allowed_keys=None): + def _run(self, allowed_keys=None, delay=0): if allowed_keys is None: allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() @@ -99,6 +100,8 @@ class SSHClientTest (unittest.TestCase): host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys) + if delay: + time.sleep(delay) self.ts.start_server(self.event, server) def _test_connection(self, **kwargs): @@ -295,3 +298,25 @@ class SSHClientTest (unittest.TestCase): gc.collect() self.assertTrue(p() is None) + + def test_7_banner_timeout(self): + """ + verify that the SSHClient has a configurable banner timeout. + """ + # Start the thread with a 1 second wait. + threading.Thread(target=self._run, kwargs={'delay': 1}).start() + host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + public_host_key = paramiko.RSAKey(data=host_key.asbytes()) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + # Connect with a half second banner timeout. + self.assertRaises( + paramiko.SSHException, + self.tc.connect, + self.addr, + self.port, + username='slowdive', + password='pygmalion', + banner_timeout=0.5 + ) |