diff options
-rw-r--r-- | paramiko/client.py | 18 | ||||
-rw-r--r-- | tests/test_client.py | 21 |
2 files changed, 31 insertions, 8 deletions
diff --git a/paramiko/client.py b/paramiko/client.py index 33eab422..7870ea9f 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -255,9 +255,9 @@ class SSHClient (object): @type password: str @param pkey: an optional private key to use for authentication @type pkey: L{PKey} - @param key_filename: the filename of an optional private key to use - for authentication - @type key_filename: str + @param key_filename: the filename, or list of filenames, of optional + private key(s) to try for authentication + @type key_filename: str or list(str) @param timeout: an optional timeout (in seconds) for the TCP connect @type timeout: float @param allow_agent: set to False to disable connecting to the SSH agent @@ -306,7 +306,13 @@ class SSHClient (object): if username is None: username = getpass.getuser() - self._auth(username, password, pkey, key_filename, allow_agent, look_for_keys) + if key_filename is None: + key_filenames = [] + elif isinstance(key_filename, (str, unicode)): + key_filenames = [ key_filename ] + else: + key_filenames = key_filename + self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys) def close(self): """ @@ -382,7 +388,7 @@ class SSHClient (object): """ return self._transport - def _auth(self, username, password, pkey, key_filename, allow_agent, look_for_keys): + def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys): """ Try, in order: @@ -403,7 +409,7 @@ class SSHClient (object): except SSHException, e: saved_exception = e - if key_filename is not None: + for key_filename in key_filenames: for pkey_class in (RSAKey, DSSKey): try: key = pkey_class.from_private_key_file(key_filename, password) diff --git a/tests/test_client.py b/tests/test_client.py index d372eebb..59cd67cb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -148,8 +148,25 @@ class SSHClientTest (unittest.TestCase): stdin.close() stdout.close() stderr.close() + + def test_3_multiple_key_files(self): + """ + verify that SSHClient accepts and tries multiple key files. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ]) + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) - def test_3_auto_add_policy(self): + def test_4_auto_add_policy(self): """ verify that SSHClient's AutoAddPolicy works. """ @@ -169,7 +186,7 @@ class SSHClientTest (unittest.TestCase): self.assertEquals(1, len(self.tc.get_host_keys())) self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa']) - def test_4_cleanup(self): + def test_5_cleanup(self): """ verify that when an SSHClient is collected, its transport (and the transport's packetizer) is closed. |