summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/client.py6
-rw-r--r--tests/test_client.py27
2 files changed, 31 insertions, 2 deletions
diff --git a/paramiko/client.py b/paramiko/client.py
index 4326abbd..4f75cefe 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -172,7 +172,7 @@ class SSHClient (object):
def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None,
key_filename=None, timeout=None, allow_agent=True, look_for_keys=True,
- compress=False, sock=None):
+ compress=False, sock=None, banner_timeout=None):
"""
Connect to an SSH server and authenticate to it. The server's host key
is checked against the system host keys (see `load_system_host_keys`)
@@ -212,6 +212,8 @@ class SSHClient (object):
:param socket sock:
an open socket or socket-like object (such as a `.Channel`) to use
for communication to the target host
+ :param float banner_timeout: an optional timeout (in seconds) to wait
+ for the SSH banner to be presented.
:raises BadHostKeyException: if the server's host key could not be
verified
@@ -241,6 +243,8 @@ class SSHClient (object):
t.use_compression(compress=compress)
if self._log_channel is not None:
t.set_log_channel(self._log_channel)
+ if banner_timeout is not None:
+ t.banner_timeout = banner_timeout
t.start_client()
ResourceManager.register(self, t)
diff --git a/tests/test_client.py b/tests/test_client.py
index 33dd9f23..6fda7f5e 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -27,6 +27,7 @@ import unittest
import weakref
import warnings
import os
+import time
from tests.util import test_path
import paramiko
from paramiko.common import PY2, b
@@ -91,7 +92,7 @@ class SSHClientTest (unittest.TestCase):
if hasattr(self, attr):
getattr(self, attr).close()
- def _run(self, allowed_keys=None):
+ def _run(self, allowed_keys=None, delay=0):
if allowed_keys is None:
allowed_keys = FINGERPRINTS.keys()
self.socks, addr = self.sockl.accept()
@@ -99,6 +100,8 @@ class SSHClientTest (unittest.TestCase):
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.ts.add_server_key(host_key)
server = NullServer(allowed_keys=allowed_keys)
+ if delay:
+ time.sleep(delay)
self.ts.start_server(self.event, server)
def _test_connection(self, **kwargs):
@@ -295,3 +298,25 @@ class SSHClientTest (unittest.TestCase):
gc.collect()
self.assertTrue(p() is None)
+
+ def test_7_banner_timeout(self):
+ """
+ verify that the SSHClient has a configurable banner timeout.
+ """
+ # Start the thread with a 1 second wait.
+ threading.Thread(target=self._run, kwargs={'delay': 1}).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
+
+ self.tc = paramiko.SSHClient()
+ self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
+ # Connect with a half second banner timeout.
+ self.assertRaises(
+ paramiko.SSHException,
+ self.tc.connect,
+ self.addr,
+ self.port,
+ username='slowdive',
+ password='pygmalion',
+ banner_timeout=0.5
+ )