diff options
author | Jeff Forcier <jeff@bitprophet.org> | 2023-05-04 14:46:23 -0400 |
---|---|---|
committer | Jeff Forcier <jeff@bitprophet.org> | 2023-05-18 13:57:13 -0400 |
commit | a39ddd1a9cb4b4c5afa80e0fbc05ed678708f2ce (patch) | |
tree | 3968bd8beadb58edf624a141b51db7b44f05cce3 | |
parent | e22c5ea330814801d8487dc3da347f987bafe5ec (diff) |
Modernize auth tests to use shared server manager
Also move auth tests to be new style filename, obj naming
Also allow test task module selector to see new-style test modules
-rw-r--r-- | tasks.py | 7 | ||||
-rw-r--r-- | tests/_util.py | 20 | ||||
-rw-r--r-- | tests/auth.py | 136 | ||||
-rw-r--r-- | tests/test_auth.py | 272 |
4 files changed, 160 insertions, 275 deletions
@@ -1,4 +1,5 @@ import os +from pathlib import Path from os.path import join from shutil import rmtree, copytree @@ -50,8 +51,10 @@ def test( opts += " -f" modstr = "" if module is not None: - # NOTE: implicit test_ prefix as we're not on pytest-relaxed yet - modstr = " tests/test_{}.py".format(module) + base = f"{module}.py" + tests = Path("tests") + legacy = tests / f"test_{base}" + modstr = str(legacy if legacy.exists() else tests / base) # Switch runner depending on coverage or no coverage. # TODO: get pytest's coverage plugin working, IIRC it has issues? runner = "pytest" diff --git a/tests/_util.py b/tests/_util.py index 2bfe314d..eaf6aac4 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -346,6 +346,8 @@ def server( pubkeys=None, catch_error=False, transport_factory=None, + defer=False, + skip_verify=False, ): """ SSH server contextmanager for testing. @@ -368,6 +370,13 @@ def server( Necessary for connection_time exception testing. :param transport_factory: Like the same-named param in SSHClient: which Transport class to use. + :param bool defer: + Whether to defer authentication during connecting. + + This is really just shorthand for ``connect={}`` which would do roughly + the same thing. Also: this implies skip_verify=True automatically! + :param bool skip_verify: + Whether NOT to do the default "make sure auth passed" check. """ if init is None: init = {} @@ -376,7 +385,12 @@ def server( if client_init is None: client_init = {} if connect is None: - connect = dict(username="slowdive", password="pygmalion") + # No auth at all please + if defer: + connect = dict() + # Default username based auth + else: + connect = dict(username="slowdive", password="pygmalion") socks = LoopSocket() sockc = LoopSocket() sockc.link(socks) @@ -417,6 +431,10 @@ def server( yield (tc, ts, err) if catch_error else (tc, ts) + if not (catch_error or skip_verify): + assert ts.is_authenticated() + assert tc.is_authenticated() + tc.close() ts.close() socks.close() diff --git a/tests/auth.py b/tests/auth.py new file mode 100644 index 00000000..08de6148 --- /dev/null +++ b/tests/auth.py @@ -0,0 +1,136 @@ +# Copyright (C) 2008 Robey Pointer <robeypointer@gmail.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +""" +Some unit tests for authenticating over a Transport. +""" + +import unittest +from pytest import raises + +from paramiko import ( + DSSKey, + BadAuthenticationType, + AuthenticationException, +) + +from ._util import _support, server, unicodey + + +class AuthHandler_: + def bad_auth_type(self): + """ + verify that we get the right exception when an unsupported auth + type is requested. + """ + # Server won't allow password auth for this user, so should fail + # and return just publickey allowed types + with server( + connect=dict(username="unknown", password="error"), + catch_error=True, + ) as (_, _, err): + assert isinstance(err, BadAuthenticationType) + assert err.allowed_types == ["publickey"] + + def bad_password(self): + """ + verify that a bad password gets the right exception, and that a retry + with the right password works. + """ + # NOTE: Transport.connect doesn't do any auth upfront if no userauth + # related kwargs given. + with server(defer=True) as (tc, ts): + # Auth once, badly + with raises(AuthenticationException): + tc.auth_password(username="slowdive", password="error") + # And again, correctly + tc.auth_password(username="slowdive", password="pygmalion") + + def multipart_auth(self): + """ + verify that multipart auth works. + """ + with server(defer=True) as (tc, ts): + assert tc.auth_password( + username="paranoid", password="paranoid" + ) == ["publickey"] + key = DSSKey.from_private_key_file(_support("dss.key")) + assert tc.auth_publickey(username="paranoid", key=key) == [] + + def interactive_auth(self): + """ + verify keyboard-interactive auth works. + """ + + def handler(title, instructions, prompts): + self.got_title = title + self.got_instructions = instructions + self.got_prompts = prompts + return ["cat"] + + with server(defer=True) as (tc, ts): + assert tc.auth_interactive("commie", handler) == [] + assert self.got_title == "password" + assert self.got_prompts == [("Password", False)] + + def interactive_fallback(self): + """ + verify that a password auth attempt will fallback to "interactive" + if password auth isn't supported but interactive is. + """ + with server(defer=True) as (tc, ts): + # This username results in an allowed_auth of just kbd-int, + # and has a configured interactive->response on the server. + assert tc.auth_password("commie", "cat") == [] + + def utf8(self): + """ + verify that utf-8 encoding happens in authentication. + """ + with server(defer=True) as (tc, ts): + assert tc.auth_password("utf8", unicodey) == [] + + def non_utf8(self): + """ + verify that non-utf-8 encoded passwords can be used for broken + servers. + """ + with server(defer=True) as (tc, ts): + assert tc.auth_password("non-utf8", "\xff") == [] + + def auth_exception_when_disconnected(self): + """ + verify that we catch a server disconnecting during auth, and report + it as an auth failure. + """ + with server(defer=True, skip_verify=True) as (tc, ts), raises( + AuthenticationException + ): + tc.auth_password("bad-server", "hello") + + def non_responsive_triggers_auth_exception(self): + """ + verify that authentication times out if server takes to long to + respond (or never responds). + """ + with server(defer=True, skip_verify=True) as (tc, ts), raises( + AuthenticationException + ) as info: + tc.auth_timeout = 1 # 1 second, to speed up test + tc.auth_password("unresponsive-server", "hello") + assert "Authentication timeout" in str(info.value) diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 02df8c12..00000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (C) 2008 Robey Pointer <robeypointer@gmail.com> -# -# This file is part of paramiko. -# -# Paramiko is free software; you can redistribute it and/or modify it under the -# terms of the GNU Lesser General Public License as published by the Free -# Software Foundation; either version 2.1 of the License, or (at your option) -# any later version. -# -# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY -# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR -# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more -# details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with Paramiko; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - -""" -Some unit tests for authenticating over a Transport. -""" - -import sys -import threading -import unittest -from time import sleep - -from paramiko import ( - Transport, - ServerInterface, - RSAKey, - DSSKey, - BadAuthenticationType, - InteractiveQuery, - AuthenticationException, -) -from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL -from paramiko.util import u - -from ._loop import LoopSocket -from ._util import _support, slow - - -_pwd = u("\u2022") - - -class NullServer(ServerInterface): - paranoid_did_password = False - paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support("dss.key")) - - def get_allowed_auths(self, username): - if username == "slowdive": - return "publickey,password" - if username == "paranoid": - if ( - not self.paranoid_did_password - and not self.paranoid_did_public_key - ): - return "publickey,password" - elif self.paranoid_did_password: - return "publickey" - else: - return "password" - if username == "commie": - return "keyboard-interactive" - if username == "utf8": - return "password" - if username == "non-utf8": - return "password" - return "publickey" - - def check_auth_password(self, username, password): - if (username == "slowdive") and (password == "pygmalion"): - return AUTH_SUCCESSFUL - if (username == "paranoid") and (password == "paranoid"): - # 2-part auth (even openssh doesn't support this) - self.paranoid_did_password = True - if self.paranoid_did_public_key: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - if (username == "utf8") and (password == _pwd): - return AUTH_SUCCESSFUL - if (username == "non-utf8") and (password == "\xff"): - return AUTH_SUCCESSFUL - if username == "bad-server": - raise Exception("Ack!") - if username == "unresponsive-server": - sleep(5) - return AUTH_SUCCESSFUL - return AUTH_FAILED - - def check_auth_publickey(self, username, key): - if (username == "paranoid") and (key == self.paranoid_key): - # 2-part auth - self.paranoid_did_public_key = True - if self.paranoid_did_password: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - return AUTH_FAILED - - def check_auth_interactive(self, username, submethods): - if username == "commie": - self.username = username - return InteractiveQuery( - "password", "Please enter a password.", ("Password", False) - ) - return AUTH_FAILED - - def check_auth_interactive_response(self, responses): - if self.username == "commie": - if (len(responses) == 1) and (responses[0] == "cat"): - return AUTH_SUCCESSFUL - return AUTH_FAILED - - -class AuthTest(unittest.TestCase): - def setUp(self): - self.socks = LoopSocket() - self.sockc = LoopSocket() - self.sockc.link(self.socks) - self.tc = Transport(self.sockc) - self.ts = Transport(self.socks) - - def tearDown(self): - self.tc.close() - self.ts.close() - self.socks.close() - self.sockc.close() - - def start_server(self): - host_key = RSAKey.from_private_key_file(_support("rsa.key")) - self.public_host_key = RSAKey(data=host_key.asbytes()) - self.ts.add_server_key(host_key) - self.event = threading.Event() - self.server = NullServer() - self.assertTrue(not self.event.is_set()) - self.ts.start_server(self.event, self.server) - - def verify_finished(self): - self.event.wait(1.0) - self.assertTrue(self.event.is_set()) - self.assertTrue(self.ts.is_active()) - - def test_bad_auth_type(self): - """ - verify that we get the right exception when an unsupported auth - type is requested. - """ - self.start_server() - try: - self.tc.connect( - hostkey=self.public_host_key, - username="unknown", - password="error", - ) - self.assertTrue(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEqual(BadAuthenticationType, etype) - self.assertEqual(["publickey"], evalue.allowed_types) - - def test_bad_password(self): - """ - verify that a bad password gets the right exception, and that a retry - with the right password works. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - try: - self.tc.auth_password(username="slowdive", password="error") - self.assertTrue(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertTrue(issubclass(etype, AuthenticationException)) - self.tc.auth_password(username="slowdive", password="pygmalion") - self.verify_finished() - - def test_multipart_auth(self): - """ - verify that multipart auth works. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password( - username="paranoid", password="paranoid" - ) - self.assertEqual(["publickey"], remain) - key = DSSKey.from_private_key_file(_support("dss.key")) - remain = self.tc.auth_publickey(username="paranoid", key=key) - self.assertEqual([], remain) - self.verify_finished() - - def test_interactive_auth(self): - """ - verify keyboard-interactive auth works. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - - def handler(title, instructions, prompts): - self.got_title = title - self.got_instructions = instructions - self.got_prompts = prompts - return ["cat"] - - remain = self.tc.auth_interactive("commie", handler) - self.assertEqual(self.got_title, "password") - self.assertEqual(self.got_prompts, [("Password", False)]) - self.assertEqual([], remain) - self.verify_finished() - - def test_interactive_auth_fallback(self): - """ - verify that a password auth attempt will fallback to "interactive" - if password auth isn't supported but interactive is. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password("commie", "cat") - self.assertEqual([], remain) - self.verify_finished() - - def test_auth_utf8(self): - """ - verify that utf-8 encoding happens in authentication. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password("utf8", _pwd) - self.assertEqual([], remain) - self.verify_finished() - - def test_auth_non_utf8(self): - """ - verify that non-utf-8 encoded passwords can be used for broken - servers. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password("non-utf8", "\xff") - self.assertEqual([], remain) - self.verify_finished() - - def test_auth_gets_disconnected(self): - """ - verify that we catch a server disconnecting during auth, and report - it as an auth failure. - """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - try: - self.tc.auth_password("bad-server", "hello") - except: - etype, evalue, etb = sys.exc_info() - self.assertTrue(issubclass(etype, AuthenticationException)) - - @slow - def test_auth_non_responsive(self): - """ - verify that authentication times out if server takes to long to - respond (or never responds). - """ - self.tc.auth_timeout = 1 # 1 second, to speed up test - self.start_server() - self.tc.connect() - try: - self.tc.auth_password("unresponsive-server", "hello") - except: - etype, evalue, etb = sys.exc_info() - self.assertTrue(issubclass(etype, AuthenticationException)) - self.assertTrue("Authentication timeout" in str(evalue)) |