diff options
Diffstat (limited to 'tests/test_client.py')
-rw-r--r-- | tests/test_client.py | 59 |
1 files changed, 53 insertions, 6 deletions
diff --git a/tests/test_client.py b/tests/test_client.py index 87f7bcb2..fec1485e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -34,6 +34,7 @@ import weakref from tempfile import mkstemp import paramiko +from paramiko.pkey import PublicBlob from paramiko.common import PY2 from paramiko.ssh_exception import SSHException, AuthenticationException @@ -52,6 +53,8 @@ class NullServer(paramiko.ServerInterface): def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types self.__allowed_keys = kwargs.pop("allowed_keys", []) + # And allow them to set a (single...meh) expected public blob (cert) + self.__expected_public_blob = kwargs.pop("public_blob", None) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): @@ -72,12 +75,18 @@ class NullServer(paramiko.ServerInterface): expected = FINGERPRINTS[key.get_name()] except KeyError: return paramiko.AUTH_FAILED - if ( + # Base check: allowed auth type & fingerprint matches + happy = ( key.get_name() in self.__allowed_keys and key.get_fingerprint() == expected + ) + # Secondary check: if test wants assertions about cert data + if ( + self.__expected_public_blob is not None + and key.public_blob != self.__expected_public_blob ): - return paramiko.AUTH_SUCCESSFUL - return paramiko.AUTH_FAILED + happy = False + return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED def check_channel_request(self, kind, chanid): return paramiko.OPEN_SUCCEEDED @@ -117,7 +126,7 @@ class SSHClientTest(unittest.TestCase): if hasattr(self, attr): getattr(self, attr).close() - def _run(self, allowed_keys=None, delay=0): + def _run(self, allowed_keys=None, delay=0, public_blob=None): if allowed_keys is None: allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() @@ -128,7 +137,7 @@ class SSHClientTest(unittest.TestCase): keypath = _support("test_ecdsa_256.key") host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) - server = NullServer(allowed_keys=allowed_keys) + server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob) if delay: time.sleep(delay) self.ts.start_server(self.event, server) @@ -140,7 +149,9 @@ class SSHClientTest(unittest.TestCase): The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ - run_kwargs = {"allowed_keys": kwargs.pop("allowed_keys", None)} + run_kwargs = {} + for key in ("allowed_keys", "public_blob"): + run_kwargs[key] = kwargs.pop(key, None) # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() host_key = paramiko.RSAKey.from_private_key_file( @@ -254,6 +265,42 @@ class SSHClientTest(unittest.TestCase): allowed_keys=["ecdsa-sha2-nistp256"], ) + def test_certs_allowed_as_key_filename_values(self): + # NOTE: giving cert path here, not key path. (Key path test is below. + # They're similar except for which path is given; the expected auth and + # server-side behavior is 100% identical.) + # NOTE: only bothered whipping up one cert per overall class/family. + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + cert_name = "test_{0}.key-cert.pub".format(type_) + cert_path = _support(os.path.join("cert_support", cert_name)) + self._test_connection( + key_filename=cert_path, + public_blob=PublicBlob.from_file(cert_path), + ) + + def test_certs_implicitly_loaded_alongside_key_filename_keys(self): + # NOTE: a regular test_connection() w/ test_rsa.key would incidentally + # test this (because test_xxx.key-cert.pub exists) but incidental tests + # stink, so NullServer and friends were updated to allow assertions + # about the server-side key object's public blob. Thus, we can prove + # that a specific cert was found, along with regular authorization + # succeeding proving that the overall flow works. + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + key_name = "test_{0}.key".format(type_) + key_path = _support(os.path.join("cert_support", key_name)) + self._test_connection( + key_filename=key_path, + public_blob=PublicBlob.from_file( + "{0}-cert.pub".format(key_path) + ), + ) + + def test_default_key_locations_trigger_cert_loads_if_found(self): + # TODO: what it says on the tin: ~/.ssh/id_rsa tries to load + # ~/.ssh/id_rsa-cert.pub. Right now no other tests actually test that + # code path (!) so we're punting too, sob. + pass + def test_4_auto_add_policy(self): """ verify that SSHClient's AutoAddPolicy works. |