diff options
Diffstat (limited to 'tests/test_client.py')
-rw-r--r-- | tests/test_client.py | 162 |
1 files changed, 87 insertions, 75 deletions
diff --git a/tests/test_client.py b/tests/test_client.py index 597c278e..ddfb33fc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -41,25 +41,25 @@ from .util import _support, slow FINGERPRINTS = { - 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c', - 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5', - 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60', + "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c", + "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5", + "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60", } -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types - self.__allowed_keys = kwargs.pop('allowed_keys', []) + self.__allowed_keys = kwargs.pop("allowed_keys", []) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -69,8 +69,8 @@ class NullServer (paramiko.ServerInterface): except KeyError: return paramiko.AUTH_FAILED if ( - key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected ): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -79,22 +79,21 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True -class SSHClientTest (unittest.TestCase): - +class SSHClientTest(unittest.TestCase): def setUp(self): self.sockl = socket.socket() - self.sockl.bind(('localhost', 0)) + self.sockl.bind(("localhost", 0)) self.sockl.listen(1) self.addr, self.port = self.sockl.getsockname() self.connect_kwargs = dict( hostname=self.addr, port=self.port, - username='slowdive', + username="slowdive", look_for_keys=False, ) self.event = threading.Event() @@ -109,7 +108,9 @@ class SSHClientTest (unittest.TestCase): allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys) if delay: @@ -123,15 +124,19 @@ class SSHClientTest (unittest.TestCase): The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ - run_kwargs = {'allowed_keys': kwargs.pop('allowed_keys', None)} + run_kwargs = {"allowed_keys": kwargs.pop("allowed_keys", None)} # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) # Client setup self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Actual connection self.tc.connect(**dict(self.connect_kwargs, **kwargs)) @@ -140,22 +145,22 @@ class SSHClientTest (unittest.TestCase): self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(False, self.tc.get_transport().gss_kex_used) # Command execution functions? - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEqual('Hello there.\n', stdout.readline()) - self.assertEqual('', stdout.readline()) - self.assertEqual('This is on stderr.\n', stderr.readline()) - self.assertEqual('', stderr.readline()) + self.assertEqual("Hello there.\n", stdout.readline()) + self.assertEqual("", stdout.readline()) + self.assertEqual("This is on stderr.\n", stderr.readline()) + self.assertEqual("", stderr.readline()) # Cleanup stdin.close() @@ -166,25 +171,25 @@ class SSHClientTest (unittest.TestCase): """ verify that the SSHClient stuff works too. """ - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ - self._test_connection(key_filename=_support('test_dss.key')) + self._test_connection(key_filename=_support("test_dss.key")) def test_client_rsa(self): """ verify that SSHClient works with an RSA key. """ - self._test_connection(key_filename=_support('test_rsa.key')) + self._test_connection(key_filename=_support("test_rsa.key")) def test_2_5_client_ecdsa(self): """ verify that SSHClient works with an ECDSA key. """ - self._test_connection(key_filename=_support('test_ecdsa_256.key')) + self._test_connection(key_filename=_support("test_ecdsa_256.key")) def test_3_multiple_key_files(self): """ @@ -192,22 +197,22 @@ class SSHClientTest (unittest.TestCase): """ # This is dumb :( types_ = { - 'rsa': 'ssh-rsa', - 'dss': 'ssh-dss', - 'ecdsa': 'ecdsa-sha2-nistp256', + "rsa": "ssh-rsa", + "dss": "ssh-dss", + "ecdsa": "ecdsa-sha2-nistp256", } # Various combos of attempted & valid keys # TODO: try every possible combo using itertools functions for attempt, accept in ( - (['rsa', 'dss'], ['dss']), # Original test #3 - (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly - (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail - (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success + (["rsa", "dss"], ["dss"]), # Original test #3 + (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly + (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail + (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success ): try: self._test_connection( key_filename=[ - _support('test_{}.key'.format(x)) for x in attempt + _support("test_{}.key".format(x)) for x in attempt ], allowed_keys=[types_[x] for x in accept], ) @@ -223,10 +228,11 @@ class SSHClientTest (unittest.TestCase): """ # Until #387 is fixed we have to catch a high-up exception since # various platforms trigger different errors here >_< - self.assertRaises(SSHException, + self.assertRaises( + SSHException, self._test_connection, - key_filename=[_support('test_rsa.key')], - allowed_keys=['ecdsa-sha2-nistp256'], + key_filename=[_support("test_rsa.key")], + allowed_keys=["ecdsa-sha2-nistp256"], ) def test_4_auto_add_policy(self): @@ -234,29 +240,38 @@ class SSHClientTest (unittest.TestCase): verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) - self.assertEqual(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa']) + self.assertEqual( + public_host_key, + self.tc.get_host_keys()["[%s]:%d" % (self.addr, self.port)][ + "ssh-rsa" + ], + ) def test_5_save_host_keys(self): """ verify that SSHClient correctly saves a known_hosts file. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) fd, localname = mkstemp() os.close(fd) @@ -264,11 +279,13 @@ class SSHClientTest (unittest.TestCase): client = paramiko.SSHClient() self.assertEquals(0, len(client.get_host_keys())) - host_id = '[%s]:%d' % (self.addr, self.port) + host_id = "[%s]:%d" % (self.addr, self.port) - client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key) + client.get_host_keys().add(host_id, "ssh-rsa", public_host_key) self.assertEquals(1, len(client.get_host_keys())) - self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa']) + self.assertEquals( + public_host_key, client.get_host_keys()[host_id]["ssh-rsa"] + ) client.save_host_keys(localname) @@ -291,7 +308,7 @@ class SSHClientTest (unittest.TestCase): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -320,7 +337,7 @@ class SSHClientTest (unittest.TestCase): self.tc = tc self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -335,19 +352,19 @@ class SSHClientTest (unittest.TestCase): verify that the SSHClient has a configurable banner timeout. """ # Start the thread with a 1 second wait. - threading.Thread(target=self._run, kwargs={'delay': 1}).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + threading.Thread(target=self._run, kwargs={"delay": 1}).start() + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Connect with a half second banner timeout. kwargs = dict(self.connect_kwargs, banner_timeout=0.5) - self.assertRaises( - paramiko.SSHException, - self.tc.connect, - **kwargs - ) + self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs) def test_8_auth_trickledown(self): """ @@ -363,9 +380,9 @@ class SSHClientTest (unittest.TestCase): # 'television' as per tests/test_pkey.py). NOTE: must use # key_filename, loading the actual key here with PKey will except # immediately; we're testing the try/except crap within Client. - key_filename=[_support('test_rsa_password.key')], + key_filename=[_support("test_rsa_password.key")], # Actual password for default 'slowdive' user - password='pygmalion', + password="pygmalion", ) self._test_connection(**kwargs) @@ -375,10 +392,7 @@ class SSHClientTest (unittest.TestCase): """ if not paramiko.GSS_AUTH_AVAILABLE: return # for python 2.6 lacks skipTest - kwargs = dict( - gss_kex=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_10_auth_trickledown_gssauth(self): @@ -387,10 +401,7 @@ class SSHClientTest (unittest.TestCase): """ if not paramiko.GSS_AUTH_AVAILABLE: return # for python 2.6 lacks skipTest - kwargs = dict( - gss_auth=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_11_reject_policy(self): @@ -405,7 +416,8 @@ class SSHClientTest (unittest.TestCase): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', **self.connect_kwargs + password="pygmalion", + **self.connect_kwargs ) def test_12_reject_policy_gsskex(self): @@ -424,7 +436,7 @@ class SSHClientTest (unittest.TestCase): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', + password="pygmalion", gss_kex=True, - **self.connect_kwargs + **self.connect_kwargs ) |