diff options
Diffstat (limited to 'tests/test_auth.py')
-rw-r--r-- | tests/test_auth.py | 118 |
1 files changed, 67 insertions, 51 deletions
diff --git a/tests/test_auth.py b/tests/test_auth.py index dacdd654..acabb1bd 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -26,8 +26,13 @@ import unittest from time import sleep from paramiko import ( - Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType, - InteractiveQuery, AuthenticationException, + Transport, + ServerInterface, + RSAKey, + DSSKey, + BadAuthenticationType, + InteractiveQuery, + AuthenticationException, ) from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko.py3compat import u @@ -36,54 +41,57 @@ from .loop import LoopSocket from .util import _support, slow -_pwd = u('\u2022') +_pwd = u("\u2022") -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - if username == 'paranoid': - if not self.paranoid_did_password and not self.paranoid_did_public_key: - return 'publickey,password' + if username == "slowdive": + return "publickey,password" + if username == "paranoid": + if ( + not self.paranoid_did_password + and not self.paranoid_did_public_key + ): + return "publickey,password" elif self.paranoid_did_password: - return 'publickey' + return "publickey" else: - return 'password' - if username == 'commie': - return 'keyboard-interactive' - if username == 'utf8': - return 'password' - if username == 'non-utf8': - return 'password' - return 'publickey' + return "password" + if username == "commie": + return "keyboard-interactive" + if username == "utf8": + return "password" + if username == "non-utf8": + return "password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL - if (username == 'paranoid') and (password == 'paranoid'): + if (username == "paranoid") and (password == "paranoid"): # 2-part auth (even openssh doesn't support this) self.paranoid_did_password = True if self.paranoid_did_public_key: return AUTH_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL - if (username == 'utf8') and (password == _pwd): + if (username == "utf8") and (password == _pwd): return AUTH_SUCCESSFUL - if (username == 'non-utf8') and (password == '\xff'): + if (username == "non-utf8") and (password == "\xff"): return AUTH_SUCCESSFUL - if username == 'bad-server': + if username == "bad-server": raise Exception("Ack!") - if username == 'unresponsive-server': + if username == "unresponsive-server": sleep(5) return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): - if (username == 'paranoid') and (key == self.paranoid_key): + if (username == "paranoid") and (key == self.paranoid_key): # 2-part auth self.paranoid_did_public_key = True if self.paranoid_did_password: @@ -92,19 +100,21 @@ class NullServer (ServerInterface): return AUTH_FAILED def check_auth_interactive(self, username, submethods): - if username == 'commie': + if username == "commie": self.username = username - return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return InteractiveQuery( + "password", "Please enter a password.", ("Password", False) + ) return AUTH_FAILED def check_auth_interactive_response(self, responses): - if self.username == 'commie': - if (len(responses) == 1) and (responses[0] == 'cat'): + if self.username == "commie": + if (len(responses) == 1) and (responses[0] == "cat"): return AUTH_SUCCESSFUL return AUTH_FAILED -class AuthTest (unittest.TestCase): +class AuthTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() @@ -120,7 +130,7 @@ class AuthTest (unittest.TestCase): self.sockc.close() def start_server(self): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) self.event = threading.Event() @@ -140,13 +150,16 @@ class AuthTest (unittest.TestCase): """ self.start_server() try: - self.tc.connect(hostkey=self.public_host_key, - username='unknown', password='error') + self.tc.connect( + hostkey=self.public_host_key, + username="unknown", + password="error", + ) self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertEqual(BadAuthenticationType, etype) - self.assertEqual(['publickey'], evalue.allowed_types) + self.assertEqual(["publickey"], evalue.allowed_types) def test_bad_password(self): """ @@ -156,12 +169,12 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - self.tc.auth_password(username='slowdive', password='error') + self.tc.auth_password(username="slowdive", password="error") self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.tc.auth_password(username='slowdive', password='pygmalion') + self.tc.auth_password(username="slowdive", password="pygmalion") self.verify_finished() def test_multipart_auth(self): @@ -170,10 +183,12 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEqual(['publickey'], remain) - key = DSSKey.from_private_key_file(_support('test_dss.key')) - remain = self.tc.auth_publickey(username='paranoid', key=key) + remain = self.tc.auth_password( + username="paranoid", password="paranoid" + ) + self.assertEqual(["publickey"], remain) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + remain = self.tc.auth_publickey(username="paranoid", key=key) self.assertEqual([], remain) self.verify_finished() @@ -188,10 +203,11 @@ class AuthTest (unittest.TestCase): self.got_title = title self.got_instructions = instructions self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEqual(self.got_title, 'password') - self.assertEqual(self.got_prompts, [('Password', False)]) + return ["cat"] + + remain = self.tc.auth_interactive("commie", handler) + self.assertEqual(self.got_title, "password") + self.assertEqual(self.got_prompts, [("Password", False)]) self.assertEqual([], remain) self.verify_finished() @@ -202,7 +218,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('commie', 'cat') + remain = self.tc.auth_password("commie", "cat") self.assertEqual([], remain) self.verify_finished() @@ -212,7 +228,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('utf8', _pwd) + remain = self.tc.auth_password("utf8", _pwd) self.assertEqual([], remain) self.verify_finished() @@ -223,7 +239,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('non-utf8', '\xff') + remain = self.tc.auth_password("non-utf8", "\xff") self.assertEqual([], remain) self.verify_finished() @@ -235,7 +251,7 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - remain = self.tc.auth_password('bad-server', 'hello') + remain = self.tc.auth_password("bad-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) @@ -250,8 +266,8 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect() try: - remain = self.tc.auth_password('unresponsive-server', 'hello') + remain = self.tc.auth_password("unresponsive-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.assertTrue('Authentication timeout' in str(evalue)) + self.assertTrue("Authentication timeout" in str(evalue)) |