summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/client.py18
-rw-r--r--tests/test_client.py21
2 files changed, 31 insertions, 8 deletions
diff --git a/paramiko/client.py b/paramiko/client.py
index 33eab422..7870ea9f 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -255,9 +255,9 @@ class SSHClient (object):
@type password: str
@param pkey: an optional private key to use for authentication
@type pkey: L{PKey}
- @param key_filename: the filename of an optional private key to use
- for authentication
- @type key_filename: str
+ @param key_filename: the filename, or list of filenames, of optional
+ private key(s) to try for authentication
+ @type key_filename: str or list(str)
@param timeout: an optional timeout (in seconds) for the TCP connect
@type timeout: float
@param allow_agent: set to False to disable connecting to the SSH agent
@@ -306,7 +306,13 @@ class SSHClient (object):
if username is None:
username = getpass.getuser()
- self._auth(username, password, pkey, key_filename, allow_agent, look_for_keys)
+ if key_filename is None:
+ key_filenames = []
+ elif isinstance(key_filename, (str, unicode)):
+ key_filenames = [ key_filename ]
+ else:
+ key_filenames = key_filename
+ self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys)
def close(self):
"""
@@ -382,7 +388,7 @@ class SSHClient (object):
"""
return self._transport
- def _auth(self, username, password, pkey, key_filename, allow_agent, look_for_keys):
+ def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys):
"""
Try, in order:
@@ -403,7 +409,7 @@ class SSHClient (object):
except SSHException, e:
saved_exception = e
- if key_filename is not None:
+ for key_filename in key_filenames:
for pkey_class in (RSAKey, DSSKey):
try:
key = pkey_class.from_private_key_file(key_filename, password)
diff --git a/tests/test_client.py b/tests/test_client.py
index d372eebb..59cd67cb 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -148,8 +148,25 @@ class SSHClientTest (unittest.TestCase):
stdin.close()
stdout.close()
stderr.close()
+
+ def test_3_multiple_key_files(self):
+ """
+ verify that SSHClient accepts and tries multiple key files.
+ """
+ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = paramiko.RSAKey(data=str(host_key))
+
+ self.tc = paramiko.SSHClient()
+ self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
+ self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
+
+ self.event.wait(1.0)
+ self.assert_(self.event.isSet())
+ self.assert_(self.ts.is_active())
+ self.assertEquals('slowdive', self.ts.get_username())
+ self.assertEquals(True, self.ts.is_authenticated())
- def test_3_auto_add_policy(self):
+ def test_4_auto_add_policy(self):
"""
verify that SSHClient's AutoAddPolicy works.
"""
@@ -169,7 +186,7 @@ class SSHClientTest (unittest.TestCase):
self.assertEquals(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
- def test_4_cleanup(self):
+ def test_5_cleanup(self):
"""
verify that when an SSHClient is collected, its transport (and the
transport's packetizer) is closed.