diff options
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 290 |
1 files changed, 286 insertions, 4 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py index e2174896..737ef705 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -23,6 +23,7 @@ Some unit tests for the ssh2 protocol in Transport. from __future__ import with_statement from binascii import hexlify +from contextlib import contextmanager import select import socket import time @@ -38,6 +39,8 @@ from paramiko import ( Packetizer, RSAKey, SSHException, + AuthenticationException, + IncompatiblePeer, SecurityOptions, ServerInterface, Transport, @@ -80,6 +83,9 @@ class NullServer(ServerInterface): paranoid_did_public_key = False paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) + def __init__(self, allowed_keys=None): + self.allowed_keys = allowed_keys if allowed_keys is not None else [] + def get_allowed_auths(self, username): if username == "slowdive": return "publickey,password" @@ -90,6 +96,11 @@ class NullServer(ServerInterface): return AUTH_SUCCESSFUL return AUTH_FAILED + def check_auth_publickey(self, username, key): + if key in self.allowed_keys: + return AUTH_SUCCESSFUL + return AUTH_FAILED + def check_channel_request(self, kind, chanid): if kind == "bogus": return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED @@ -154,6 +165,7 @@ class TransportTest(unittest.TestCase): self.socks.close() self.sockc.close() + # TODO: unify with newer contextmanager def setup_test_server( self, client_options=None, server_options=None, connect_kwargs=None ): @@ -245,7 +257,7 @@ class TransportTest(unittest.TestCase): self.assertEqual(True, self.tc.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated()) - def testa_long_banner(self): + def test_long_banner(self): """ verify that a long banner doesn't mess up the handshake. """ @@ -339,7 +351,7 @@ class TransportTest(unittest.TestCase): self.assertEqual("This is on stderr.\n", f.readline()) self.assertEqual("", f.readline()) - def testa_channel_can_be_used_as_context_manager(self): + def test_channel_can_be_used_as_context_manager(self): """ verify that exec_command() does something reasonable. """ @@ -744,7 +756,7 @@ class TransportTest(unittest.TestCase): threading.Thread.__init__( self, None, None, self.__class__.__name__ ) - self.setDaemon(True) + self.daemon = True self.chan = chan self.iterations = iterations self.done_event = done_event @@ -768,7 +780,7 @@ class TransportTest(unittest.TestCase): threading.Thread.__init__( self, None, None, self.__class__.__name__ ) - self.setDaemon(True) + self.daemon = True self.chan = chan self.done_event = done_event self.watchdog_event = threading.Event() @@ -1169,3 +1181,273 @@ class AlgorithmDisablingTests(unittest.TestCase): assert "ssh-dss" not in server_keys assert "diffie-hellman-group14-sha256" not in kexen assert "zlib" not in compressions + + +@contextmanager +def server( + hostkey=None, + init=None, + server_init=None, + client_init=None, + connect=None, + pubkeys=None, + catch_error=False, +): + """ + SSH server contextmanager for testing. + + :param hostkey: + Host key to use for the server; if None, loads + ``test_rsa.key``. + :param init: + Default `Transport` constructor kwargs to use for both sides. + :param server_init: + Extends and/or overrides ``init`` for server transport only. + :param client_init: + Extends and/or overrides ``init`` for client transport only. + :param connect: + Kwargs to use for ``connect()`` on the client. + :param pubkeys: + List of public keys for auth. + :param catch_error: + Whether to capture connection errors & yield from contextmanager. + Necessary for connection_time exception testing. + """ + if init is None: + init = {} + if server_init is None: + server_init = {} + if client_init is None: + client_init = {} + if connect is None: + connect = dict(username="slowdive", password="pygmalion") + socks = LoopSocket() + sockc = LoopSocket() + sockc.link(socks) + tc = Transport(sockc, **dict(init, **client_init)) + ts = Transport(socks, **dict(init, **server_init)) + + if hostkey is None: + hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + ts.add_server_key(hostkey) + event = threading.Event() + server = NullServer(allowed_keys=pubkeys) + assert not event.is_set() + assert not ts.is_active() + assert tc.get_username() is None + assert ts.get_username() is None + assert not tc.is_authenticated() + assert not ts.is_authenticated() + + err = None + # Trap errors and yield instead of raising right away; otherwise callers + # cannot usefully deal with problems at connect time which stem from errors + # in the server side. + try: + ts.start_server(event, server) + tc.connect(**connect) + + event.wait(1.0) + assert event.is_set() + assert ts.is_active() + assert tc.is_active() + + except Exception as e: + if not catch_error: + raise + err = e + + yield (tc, ts, err) if catch_error else (tc, ts) + + tc.close() + ts.close() + socks.close() + sockc.close() + + +_disable_sha2 = dict( + disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"]) +) +_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"])) +_disable_sha2_pubkey = dict( + disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"]) +) +_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"])) + + +class TestSHA2SignatureKeyExchange(unittest.TestCase): + # NOTE: these all rely on the default server() hostkey being RSA + # NOTE: these rely on both sides being properly implemented re: agreed-upon + # hostkey during kex being what's actually used. Truly proving that eg + # SHA512 was used, is quite difficult w/o super gross hacks. However, there + # are new tests in test_pkey.py which use known signature blobs to prove + # the SHA2 family was in fact used! + + def test_base_case_ssh_rsa_still_used_as_fallback(self): + # Prove that ssh-rsa is used if either, or both, participants have SHA2 + # algorithms disabled + for which in ("init", "client_init", "server_init"): + with server(**{which: _disable_sha2}) as (tc, _): + assert tc.host_key_type == "ssh-rsa" + + def test_kex_with_sha2_512(self): + # It's the default! + with server() as (tc, _): + assert tc.host_key_type == "rsa-sha2-512" + + def test_kex_with_sha2_256(self): + # No 512 -> you get 256 + with server( + init=dict(disabled_algorithms=dict(keys=["rsa-sha2-512"])) + ) as (tc, _): + assert tc.host_key_type == "rsa-sha2-256" + + def _incompatible_peers(self, client_init, server_init): + with server( + client_init=client_init, server_init=server_init, catch_error=True + ) as (tc, ts, err): + # If neither side blew up then that's bad! + assert err is not None + # If client side blew up first, it'll be straightforward + if isinstance(err, IncompatiblePeer): + pass + # If server side blew up first, client sees EOF & we need to check + # the server transport for its saved error (otherwise it can only + # appear in log output) + elif isinstance(err, EOFError): + assert ts.saved_exception is not None + assert isinstance(ts.saved_exception, IncompatiblePeer) + # If it was something else, welp + else: + raise err + + def test_client_sha2_disabled_server_sha1_disabled_no_match(self): + self._incompatible_peers( + client_init=_disable_sha2, server_init=_disable_sha1 + ) + + def test_client_sha1_disabled_server_sha2_disabled_no_match(self): + self._incompatible_peers( + client_init=_disable_sha1, server_init=_disable_sha2 + ) + + def test_explicit_client_hostkey_not_limited(self): + # Be very explicit about the hostkey on BOTH ends, + # and ensure it still ends up choosing sha2-512. + # (This is a regression test vs previous implementation which overwrote + # the entire preferred-hostkeys structure when given an explicit key as + # a client.) + hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _): + assert tc.host_key_type == "rsa-sha2-512" + + +class TestExtInfo(unittest.TestCase): + def test_ext_info_handshake(self): + with server() as (tc, _): + kex = tc._get_latest_kex_init() + assert kex["kex_algo_list"][-1] == "ext-info-c" + assert tc.server_extensions == { + "server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss" # noqa + } + + def test_client_uses_server_sig_algs_for_pubkey_auth(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + server_init=dict( + disabled_algorithms=dict(pubkeys=["rsa-sha2-512"]) + ), + ) as (tc, _): + assert tc.is_authenticated() + # Client settled on 256 despite itself not having 512 disabled + assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" + + +# TODO: these could move into test_auth.py but that badly needs refactoring +# with this module anyways... +class TestSHA2SignaturePubkeys(unittest.TestCase): + def test_pubkey_auth_honors_disabled_algorithms(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + init=dict( + disabled_algorithms=dict( + pubkeys=["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] + ) + ), + catch_error=True, + ) as (_, _, err): + assert isinstance(err, SSHException) + assert "no RSA pubkey algorithms" in str(err) + + def test_client_sha2_disabled_server_sha1_disabled_no_match(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + client_init=_disable_sha2_pubkey, + server_init=_disable_sha1_pubkey, + catch_error=True, + ) as (tc, ts, err): + assert isinstance(err, AuthenticationException) + + def test_client_sha1_disabled_server_sha2_disabled_no_match(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + client_init=_disable_sha1_pubkey, + server_init=_disable_sha2_pubkey, + catch_error=True, + ) as (tc, ts, err): + assert isinstance(err, AuthenticationException) + + def test_ssh_rsa_still_used_when_sha2_disabled(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + # NOTE: this works because key obj comparison uses public bytes + # TODO: would be nice for PKey to grow a legit "give me another obj of + # same class but just the public bits" using asbytes() + with server( + pubkeys=[privkey], connect=dict(pkey=privkey), init=_disable_sha2 + ) as (tc, _): + assert tc.is_authenticated() + + def test_sha2_512(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + init=dict( + disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-256"]) + ), + ) as (tc, ts): + assert tc.is_authenticated() + assert tc._agreed_pubkey_algorithm == "rsa-sha2-512" + + def test_sha2_256(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + init=dict( + disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"]) + ), + ) as (tc, ts): + assert tc.is_authenticated() + assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" + + def test_sha2_256_when_client_only_enables_256(self): + privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + # Client-side only; server still accepts all 3. + client_init=dict( + disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"]) + ), + ) as (tc, ts): + assert tc.is_authenticated() + assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" |