diff options
Diffstat (limited to 'tests/test_client.py')
-rw-r--r-- | tests/test_client.py | 164 |
1 files changed, 84 insertions, 80 deletions
diff --git a/tests/test_client.py b/tests/test_client.py index 4a7117ac..dd2676f5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -23,42 +23,43 @@ Some unit tests for SSHClient. from __future__ import with_statement import gc +import os import platform import socket -from tempfile import mkstemp import threading +import time import unittest -import weakref import warnings -import os -import time -from tests.util import test_path +import weakref +from tempfile import mkstemp import paramiko from paramiko.py3compat import PY2, b from paramiko.ssh_exception import SSHException +from .util import _support, slow + FINGERPRINTS = { - 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c', - 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5', - 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60', + "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c", + "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5", + "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60", } -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types - self.__allowed_keys = kwargs.pop('allowed_keys', []) + self.__allowed_keys = kwargs.pop("allowed_keys", []) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -68,8 +69,8 @@ class NullServer (paramiko.ServerInterface): except KeyError: return paramiko.AUTH_FAILED if ( - key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected ): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -78,7 +79,7 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True @@ -93,17 +94,16 @@ class NullServer (paramiko.ServerInterface): return True -class SSHClientTest (unittest.TestCase): - +class SSHClientTest(unittest.TestCase): def setUp(self): self.sockl = socket.socket() - self.sockl.bind(('localhost', 0)) + self.sockl.bind(("localhost", 0)) self.sockl.listen(1) self.addr, self.port = self.sockl.getsockname() self.connect_kwargs = dict( hostname=self.addr, port=self.port, - username='slowdive', + username="slowdive", look_for_keys=False, ) self.event = threading.Event() @@ -118,10 +118,10 @@ class SSHClientTest (unittest.TestCase): allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - keypath = test_path('test_rsa.key') + keypath = _support('test_rsa.key') host_key = paramiko.RSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) - keypath = test_path('test_ecdsa_256.key') + keypath = _support('test_ecdsa_256.key') host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys) @@ -136,15 +136,19 @@ class SSHClientTest (unittest.TestCase): The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ - run_kwargs = {'allowed_keys': kwargs.pop('allowed_keys', None)} + run_kwargs = {"allowed_keys": kwargs.pop("allowed_keys", None)} # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() - host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) # Client setup self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Actual connection self.tc.connect(**dict(self.connect_kwargs, **kwargs)) @@ -153,22 +157,22 @@ class SSHClientTest (unittest.TestCase): self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(False, self.tc.get_transport().gss_kex_used) # Command execution functions? - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEqual('Hello there.\n', stdout.readline()) - self.assertEqual('', stdout.readline()) - self.assertEqual('This is on stderr.\n', stderr.readline()) - self.assertEqual('', stderr.readline()) + self.assertEqual("Hello there.\n", stdout.readline()) + self.assertEqual("", stdout.readline()) + self.assertEqual("This is on stderr.\n", stderr.readline()) + self.assertEqual("", stderr.readline()) # Cleanup stdin.close() @@ -179,25 +183,25 @@ class SSHClientTest (unittest.TestCase): """ verify that the SSHClient stuff works too. """ - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ - self._test_connection(key_filename=test_path('test_dss.key')) + self._test_connection(key_filename=_support("test_dss.key")) def test_client_rsa(self): """ verify that SSHClient works with an RSA key. """ - self._test_connection(key_filename=test_path('test_rsa.key')) + self._test_connection(key_filename=_support("test_rsa.key")) def test_2_5_client_ecdsa(self): """ verify that SSHClient works with an ECDSA key. """ - self._test_connection(key_filename=test_path('test_ecdsa_256.key')) + self._test_connection(key_filename=_support("test_ecdsa_256.key")) def test_3_multiple_key_files(self): """ @@ -205,22 +209,22 @@ class SSHClientTest (unittest.TestCase): """ # This is dumb :( types_ = { - 'rsa': 'ssh-rsa', - 'dss': 'ssh-dss', - 'ecdsa': 'ecdsa-sha2-nistp256', + "rsa": "ssh-rsa", + "dss": "ssh-dss", + "ecdsa": "ecdsa-sha2-nistp256", } # Various combos of attempted & valid keys # TODO: try every possible combo using itertools functions for attempt, accept in ( - (['rsa', 'dss'], ['dss']), # Original test #3 - (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly - (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail - (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success + (["rsa", "dss"], ["dss"]), # Original test #3 + (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly + (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail + (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success ): try: self._test_connection( key_filename=[ - test_path('test_{0}.key'.format(x)) for x in attempt + _support("test_{0}.key".format(x)) for x in attempt ], allowed_keys=[types_[x] for x in accept], ) @@ -236,10 +240,11 @@ class SSHClientTest (unittest.TestCase): """ # Until #387 is fixed we have to catch a high-up exception since # various platforms trigger different errors here >_< - self.assertRaises(SSHException, + self.assertRaises( + SSHException, self._test_connection, - key_filename=[test_path('test_rsa.key')], - allowed_keys=['ecdsa-sha2-nistp256'], + key_filename=[_support("test_rsa.key")], + allowed_keys=["ecdsa-sha2-nistp256"], ) def test_4_auto_add_policy(self): @@ -248,18 +253,18 @@ class SSHClientTest (unittest.TestCase): """ threading.Thread(target=self._run).start() hostname = '[%s]:%d' % (self.addr, self.port) - key_file = test_path('test_ecdsa_256.key') + key_file = _support('test_ecdsa_256.key') public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] @@ -269,9 +274,11 @@ class SSHClientTest (unittest.TestCase): """ verify that SSHClient correctly saves a known_hosts file. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") - host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) fd, localname = mkstemp() os.close(fd) @@ -279,11 +286,13 @@ class SSHClientTest (unittest.TestCase): client = paramiko.SSHClient() self.assertEquals(0, len(client.get_host_keys())) - host_id = '[%s]:%d' % (self.addr, self.port) + host_id = "[%s]:%d" % (self.addr, self.port) - client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key) + client.get_host_keys().add(host_id, "ssh-rsa", public_host_key) self.assertEquals(1, len(client.get_host_keys())) - self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa']) + self.assertEquals( + public_host_key, client.get_host_keys()[host_id]["ssh-rsa"] + ) client.save_host_keys(localname) @@ -306,7 +315,7 @@ class SSHClientTest (unittest.TestCase): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -335,7 +344,7 @@ class SSHClientTest (unittest.TestCase): self.tc = tc self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -350,19 +359,19 @@ class SSHClientTest (unittest.TestCase): verify that the SSHClient has a configurable banner timeout. """ # Start the thread with a 1 second wait. - threading.Thread(target=self._run, kwargs={'delay': 1}).start() - host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + threading.Thread(target=self._run, kwargs={"delay": 1}).start() + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Connect with a half second banner timeout. kwargs = dict(self.connect_kwargs, banner_timeout=0.5) - self.assertRaises( - paramiko.SSHException, - self.tc.connect, - **kwargs - ) + self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs) def test_8_auth_trickledown(self): """ @@ -378,9 +387,9 @@ class SSHClientTest (unittest.TestCase): # 'television' as per tests/test_pkey.py). NOTE: must use # key_filename, loading the actual key here with PKey will except # immediately; we're testing the try/except crap within Client. - key_filename=[test_path('test_rsa_password.key')], + key_filename=[_support("test_rsa_password.key")], # Actual password for default 'slowdive' user - password='pygmalion', + password="pygmalion", ) self._test_connection(**kwargs) @@ -390,10 +399,7 @@ class SSHClientTest (unittest.TestCase): """ if not paramiko.GSS_AUTH_AVAILABLE: return # for python 2.6 lacks skipTest - kwargs = dict( - gss_kex=True, - key_filename=[test_path('test_rsa.key')], - ) + kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_10_auth_trickledown_gssauth(self): @@ -402,10 +408,7 @@ class SSHClientTest (unittest.TestCase): """ if not paramiko.GSS_AUTH_AVAILABLE: return # for python 2.6 lacks skipTest - kwargs = dict( - gss_auth=True, - key_filename=[test_path('test_rsa.key')], - ) + kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_11_reject_policy(self): @@ -420,7 +423,8 @@ class SSHClientTest (unittest.TestCase): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', **self.connect_kwargs + password="pygmalion", + **self.connect_kwargs ) def test_12_reject_policy_gsskex(self): @@ -439,9 +443,9 @@ class SSHClientTest (unittest.TestCase): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', + password="pygmalion", gss_kex=True, - **self.connect_kwargs + **self.connect_kwargs ) def _client_host_key_bad(self, host_key): |