summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJeff Forcier <jeff@bitprophet.org>2023-05-04 14:46:23 -0400
committerJeff Forcier <jeff@bitprophet.org>2023-05-18 13:57:13 -0400
commita39ddd1a9cb4b4c5afa80e0fbc05ed678708f2ce (patch)
tree3968bd8beadb58edf624a141b51db7b44f05cce3
parente22c5ea330814801d8487dc3da347f987bafe5ec (diff)
Modernize auth tests to use shared server manager
Also move auth tests to be new style filename, obj naming Also allow test task module selector to see new-style test modules
-rw-r--r--tasks.py7
-rw-r--r--tests/_util.py20
-rw-r--r--tests/auth.py136
-rw-r--r--tests/test_auth.py272
4 files changed, 160 insertions, 275 deletions
diff --git a/tasks.py b/tasks.py
index 1f5e999c..f58f699e 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,4 +1,5 @@
import os
+from pathlib import Path
from os.path import join
from shutil import rmtree, copytree
@@ -50,8 +51,10 @@ def test(
opts += " -f"
modstr = ""
if module is not None:
- # NOTE: implicit test_ prefix as we're not on pytest-relaxed yet
- modstr = " tests/test_{}.py".format(module)
+ base = f"{module}.py"
+ tests = Path("tests")
+ legacy = tests / f"test_{base}"
+ modstr = str(legacy if legacy.exists() else tests / base)
# Switch runner depending on coverage or no coverage.
# TODO: get pytest's coverage plugin working, IIRC it has issues?
runner = "pytest"
diff --git a/tests/_util.py b/tests/_util.py
index 2bfe314d..eaf6aac4 100644
--- a/tests/_util.py
+++ b/tests/_util.py
@@ -346,6 +346,8 @@ def server(
pubkeys=None,
catch_error=False,
transport_factory=None,
+ defer=False,
+ skip_verify=False,
):
"""
SSH server contextmanager for testing.
@@ -368,6 +370,13 @@ def server(
Necessary for connection_time exception testing.
:param transport_factory:
Like the same-named param in SSHClient: which Transport class to use.
+ :param bool defer:
+ Whether to defer authentication during connecting.
+
+ This is really just shorthand for ``connect={}`` which would do roughly
+ the same thing. Also: this implies skip_verify=True automatically!
+ :param bool skip_verify:
+ Whether NOT to do the default "make sure auth passed" check.
"""
if init is None:
init = {}
@@ -376,7 +385,12 @@ def server(
if client_init is None:
client_init = {}
if connect is None:
- connect = dict(username="slowdive", password="pygmalion")
+ # No auth at all please
+ if defer:
+ connect = dict()
+ # Default username based auth
+ else:
+ connect = dict(username="slowdive", password="pygmalion")
socks = LoopSocket()
sockc = LoopSocket()
sockc.link(socks)
@@ -417,6 +431,10 @@ def server(
yield (tc, ts, err) if catch_error else (tc, ts)
+ if not (catch_error or skip_verify):
+ assert ts.is_authenticated()
+ assert tc.is_authenticated()
+
tc.close()
ts.close()
socks.close()
diff --git a/tests/auth.py b/tests/auth.py
new file mode 100644
index 00000000..08de6148
--- /dev/null
+++ b/tests/auth.py
@@ -0,0 +1,136 @@
+# Copyright (C) 2008 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+"""
+Some unit tests for authenticating over a Transport.
+"""
+
+import unittest
+from pytest import raises
+
+from paramiko import (
+ DSSKey,
+ BadAuthenticationType,
+ AuthenticationException,
+)
+
+from ._util import _support, server, unicodey
+
+
+class AuthHandler_:
+ def bad_auth_type(self):
+ """
+ verify that we get the right exception when an unsupported auth
+ type is requested.
+ """
+ # Server won't allow password auth for this user, so should fail
+ # and return just publickey allowed types
+ with server(
+ connect=dict(username="unknown", password="error"),
+ catch_error=True,
+ ) as (_, _, err):
+ assert isinstance(err, BadAuthenticationType)
+ assert err.allowed_types == ["publickey"]
+
+ def bad_password(self):
+ """
+ verify that a bad password gets the right exception, and that a retry
+ with the right password works.
+ """
+ # NOTE: Transport.connect doesn't do any auth upfront if no userauth
+ # related kwargs given.
+ with server(defer=True) as (tc, ts):
+ # Auth once, badly
+ with raises(AuthenticationException):
+ tc.auth_password(username="slowdive", password="error")
+ # And again, correctly
+ tc.auth_password(username="slowdive", password="pygmalion")
+
+ def multipart_auth(self):
+ """
+ verify that multipart auth works.
+ """
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_password(
+ username="paranoid", password="paranoid"
+ ) == ["publickey"]
+ key = DSSKey.from_private_key_file(_support("dss.key"))
+ assert tc.auth_publickey(username="paranoid", key=key) == []
+
+ def interactive_auth(self):
+ """
+ verify keyboard-interactive auth works.
+ """
+
+ def handler(title, instructions, prompts):
+ self.got_title = title
+ self.got_instructions = instructions
+ self.got_prompts = prompts
+ return ["cat"]
+
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_interactive("commie", handler) == []
+ assert self.got_title == "password"
+ assert self.got_prompts == [("Password", False)]
+
+ def interactive_fallback(self):
+ """
+ verify that a password auth attempt will fallback to "interactive"
+ if password auth isn't supported but interactive is.
+ """
+ with server(defer=True) as (tc, ts):
+ # This username results in an allowed_auth of just kbd-int,
+ # and has a configured interactive->response on the server.
+ assert tc.auth_password("commie", "cat") == []
+
+ def utf8(self):
+ """
+ verify that utf-8 encoding happens in authentication.
+ """
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_password("utf8", unicodey) == []
+
+ def non_utf8(self):
+ """
+ verify that non-utf-8 encoded passwords can be used for broken
+ servers.
+ """
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_password("non-utf8", "\xff") == []
+
+ def auth_exception_when_disconnected(self):
+ """
+ verify that we catch a server disconnecting during auth, and report
+ it as an auth failure.
+ """
+ with server(defer=True, skip_verify=True) as (tc, ts), raises(
+ AuthenticationException
+ ):
+ tc.auth_password("bad-server", "hello")
+
+ def non_responsive_triggers_auth_exception(self):
+ """
+ verify that authentication times out if server takes to long to
+ respond (or never responds).
+ """
+ with server(defer=True, skip_verify=True) as (tc, ts), raises(
+ AuthenticationException
+ ) as info:
+ tc.auth_timeout = 1 # 1 second, to speed up test
+ tc.auth_password("unresponsive-server", "hello")
+ assert "Authentication timeout" in str(info.value)
diff --git a/tests/test_auth.py b/tests/test_auth.py
deleted file mode 100644
index 02df8c12..00000000
--- a/tests/test_auth.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# Copyright (C) 2008 Robey Pointer <robeypointer@gmail.com>
-#
-# This file is part of paramiko.
-#
-# Paramiko is free software; you can redistribute it and/or modify it under the
-# terms of the GNU Lesser General Public License as published by the Free
-# Software Foundation; either version 2.1 of the License, or (at your option)
-# any later version.
-#
-# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
-# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
-# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
-# details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
-# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-
-"""
-Some unit tests for authenticating over a Transport.
-"""
-
-import sys
-import threading
-import unittest
-from time import sleep
-
-from paramiko import (
- Transport,
- ServerInterface,
- RSAKey,
- DSSKey,
- BadAuthenticationType,
- InteractiveQuery,
- AuthenticationException,
-)
-from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
-from paramiko.util import u
-
-from ._loop import LoopSocket
-from ._util import _support, slow
-
-
-_pwd = u("\u2022")
-
-
-class NullServer(ServerInterface):
- paranoid_did_password = False
- paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support("dss.key"))
-
- def get_allowed_auths(self, username):
- if username == "slowdive":
- return "publickey,password"
- if username == "paranoid":
- if (
- not self.paranoid_did_password
- and not self.paranoid_did_public_key
- ):
- return "publickey,password"
- elif self.paranoid_did_password:
- return "publickey"
- else:
- return "password"
- if username == "commie":
- return "keyboard-interactive"
- if username == "utf8":
- return "password"
- if username == "non-utf8":
- return "password"
- return "publickey"
-
- def check_auth_password(self, username, password):
- if (username == "slowdive") and (password == "pygmalion"):
- return AUTH_SUCCESSFUL
- if (username == "paranoid") and (password == "paranoid"):
- # 2-part auth (even openssh doesn't support this)
- self.paranoid_did_password = True
- if self.paranoid_did_public_key:
- return AUTH_SUCCESSFUL
- return AUTH_PARTIALLY_SUCCESSFUL
- if (username == "utf8") and (password == _pwd):
- return AUTH_SUCCESSFUL
- if (username == "non-utf8") and (password == "\xff"):
- return AUTH_SUCCESSFUL
- if username == "bad-server":
- raise Exception("Ack!")
- if username == "unresponsive-server":
- sleep(5)
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_publickey(self, username, key):
- if (username == "paranoid") and (key == self.paranoid_key):
- # 2-part auth
- self.paranoid_did_public_key = True
- if self.paranoid_did_password:
- return AUTH_SUCCESSFUL
- return AUTH_PARTIALLY_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_interactive(self, username, submethods):
- if username == "commie":
- self.username = username
- return InteractiveQuery(
- "password", "Please enter a password.", ("Password", False)
- )
- return AUTH_FAILED
-
- def check_auth_interactive_response(self, responses):
- if self.username == "commie":
- if (len(responses) == 1) and (responses[0] == "cat"):
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
-
-class AuthTest(unittest.TestCase):
- def setUp(self):
- self.socks = LoopSocket()
- self.sockc = LoopSocket()
- self.sockc.link(self.socks)
- self.tc = Transport(self.sockc)
- self.ts = Transport(self.socks)
-
- def tearDown(self):
- self.tc.close()
- self.ts.close()
- self.socks.close()
- self.sockc.close()
-
- def start_server(self):
- host_key = RSAKey.from_private_key_file(_support("rsa.key"))
- self.public_host_key = RSAKey(data=host_key.asbytes())
- self.ts.add_server_key(host_key)
- self.event = threading.Event()
- self.server = NullServer()
- self.assertTrue(not self.event.is_set())
- self.ts.start_server(self.event, self.server)
-
- def verify_finished(self):
- self.event.wait(1.0)
- self.assertTrue(self.event.is_set())
- self.assertTrue(self.ts.is_active())
-
- def test_bad_auth_type(self):
- """
- verify that we get the right exception when an unsupported auth
- type is requested.
- """
- self.start_server()
- try:
- self.tc.connect(
- hostkey=self.public_host_key,
- username="unknown",
- password="error",
- )
- self.assertTrue(False)
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertEqual(BadAuthenticationType, etype)
- self.assertEqual(["publickey"], evalue.allowed_types)
-
- def test_bad_password(self):
- """
- verify that a bad password gets the right exception, and that a retry
- with the right password works.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- try:
- self.tc.auth_password(username="slowdive", password="error")
- self.assertTrue(False)
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertTrue(issubclass(etype, AuthenticationException))
- self.tc.auth_password(username="slowdive", password="pygmalion")
- self.verify_finished()
-
- def test_multipart_auth(self):
- """
- verify that multipart auth works.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password(
- username="paranoid", password="paranoid"
- )
- self.assertEqual(["publickey"], remain)
- key = DSSKey.from_private_key_file(_support("dss.key"))
- remain = self.tc.auth_publickey(username="paranoid", key=key)
- self.assertEqual([], remain)
- self.verify_finished()
-
- def test_interactive_auth(self):
- """
- verify keyboard-interactive auth works.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
-
- def handler(title, instructions, prompts):
- self.got_title = title
- self.got_instructions = instructions
- self.got_prompts = prompts
- return ["cat"]
-
- remain = self.tc.auth_interactive("commie", handler)
- self.assertEqual(self.got_title, "password")
- self.assertEqual(self.got_prompts, [("Password", False)])
- self.assertEqual([], remain)
- self.verify_finished()
-
- def test_interactive_auth_fallback(self):
- """
- verify that a password auth attempt will fallback to "interactive"
- if password auth isn't supported but interactive is.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password("commie", "cat")
- self.assertEqual([], remain)
- self.verify_finished()
-
- def test_auth_utf8(self):
- """
- verify that utf-8 encoding happens in authentication.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password("utf8", _pwd)
- self.assertEqual([], remain)
- self.verify_finished()
-
- def test_auth_non_utf8(self):
- """
- verify that non-utf-8 encoded passwords can be used for broken
- servers.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password("non-utf8", "\xff")
- self.assertEqual([], remain)
- self.verify_finished()
-
- def test_auth_gets_disconnected(self):
- """
- verify that we catch a server disconnecting during auth, and report
- it as an auth failure.
- """
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- try:
- self.tc.auth_password("bad-server", "hello")
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertTrue(issubclass(etype, AuthenticationException))
-
- @slow
- def test_auth_non_responsive(self):
- """
- verify that authentication times out if server takes to long to
- respond (or never responds).
- """
- self.tc.auth_timeout = 1 # 1 second, to speed up test
- self.start_server()
- self.tc.connect()
- try:
- self.tc.auth_password("unresponsive-server", "hello")
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertTrue(issubclass(etype, AuthenticationException))
- self.assertTrue("Authentication timeout" in str(evalue))