summaryrefslogtreecommitdiffhomepage
path: root/tests/test_transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r--tests/test_transport.py290
1 files changed, 286 insertions, 4 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py
index e2174896..737ef705 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -23,6 +23,7 @@ Some unit tests for the ssh2 protocol in Transport.
from __future__ import with_statement
from binascii import hexlify
+from contextlib import contextmanager
import select
import socket
import time
@@ -38,6 +39,8 @@ from paramiko import (
Packetizer,
RSAKey,
SSHException,
+ AuthenticationException,
+ IncompatiblePeer,
SecurityOptions,
ServerInterface,
Transport,
@@ -80,6 +83,9 @@ class NullServer(ServerInterface):
paranoid_did_public_key = False
paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ def __init__(self, allowed_keys=None):
+ self.allowed_keys = allowed_keys if allowed_keys is not None else []
+
def get_allowed_auths(self, username):
if username == "slowdive":
return "publickey,password"
@@ -90,6 +96,11 @@ class NullServer(ServerInterface):
return AUTH_SUCCESSFUL
return AUTH_FAILED
+ def check_auth_publickey(self, username, key):
+ if key in self.allowed_keys:
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
def check_channel_request(self, kind, chanid):
if kind == "bogus":
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
@@ -154,6 +165,7 @@ class TransportTest(unittest.TestCase):
self.socks.close()
self.sockc.close()
+ # TODO: unify with newer contextmanager
def setup_test_server(
self, client_options=None, server_options=None, connect_kwargs=None
):
@@ -245,7 +257,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual(True, self.tc.is_authenticated())
self.assertEqual(True, self.ts.is_authenticated())
- def testa_long_banner(self):
+ def test_long_banner(self):
"""
verify that a long banner doesn't mess up the handshake.
"""
@@ -339,7 +351,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual("This is on stderr.\n", f.readline())
self.assertEqual("", f.readline())
- def testa_channel_can_be_used_as_context_manager(self):
+ def test_channel_can_be_used_as_context_manager(self):
"""
verify that exec_command() does something reasonable.
"""
@@ -744,7 +756,7 @@ class TransportTest(unittest.TestCase):
threading.Thread.__init__(
self, None, None, self.__class__.__name__
)
- self.setDaemon(True)
+ self.daemon = True
self.chan = chan
self.iterations = iterations
self.done_event = done_event
@@ -768,7 +780,7 @@ class TransportTest(unittest.TestCase):
threading.Thread.__init__(
self, None, None, self.__class__.__name__
)
- self.setDaemon(True)
+ self.daemon = True
self.chan = chan
self.done_event = done_event
self.watchdog_event = threading.Event()
@@ -1169,3 +1181,273 @@ class AlgorithmDisablingTests(unittest.TestCase):
assert "ssh-dss" not in server_keys
assert "diffie-hellman-group14-sha256" not in kexen
assert "zlib" not in compressions
+
+
+@contextmanager
+def server(
+ hostkey=None,
+ init=None,
+ server_init=None,
+ client_init=None,
+ connect=None,
+ pubkeys=None,
+ catch_error=False,
+):
+ """
+ SSH server contextmanager for testing.
+
+ :param hostkey:
+ Host key to use for the server; if None, loads
+ ``test_rsa.key``.
+ :param init:
+ Default `Transport` constructor kwargs to use for both sides.
+ :param server_init:
+ Extends and/or overrides ``init`` for server transport only.
+ :param client_init:
+ Extends and/or overrides ``init`` for client transport only.
+ :param connect:
+ Kwargs to use for ``connect()`` on the client.
+ :param pubkeys:
+ List of public keys for auth.
+ :param catch_error:
+ Whether to capture connection errors & yield from contextmanager.
+ Necessary for connection_time exception testing.
+ """
+ if init is None:
+ init = {}
+ if server_init is None:
+ server_init = {}
+ if client_init is None:
+ client_init = {}
+ if connect is None:
+ connect = dict(username="slowdive", password="pygmalion")
+ socks = LoopSocket()
+ sockc = LoopSocket()
+ sockc.link(socks)
+ tc = Transport(sockc, **dict(init, **client_init))
+ ts = Transport(socks, **dict(init, **server_init))
+
+ if hostkey is None:
+ hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ ts.add_server_key(hostkey)
+ event = threading.Event()
+ server = NullServer(allowed_keys=pubkeys)
+ assert not event.is_set()
+ assert not ts.is_active()
+ assert tc.get_username() is None
+ assert ts.get_username() is None
+ assert not tc.is_authenticated()
+ assert not ts.is_authenticated()
+
+ err = None
+ # Trap errors and yield instead of raising right away; otherwise callers
+ # cannot usefully deal with problems at connect time which stem from errors
+ # in the server side.
+ try:
+ ts.start_server(event, server)
+ tc.connect(**connect)
+
+ event.wait(1.0)
+ assert event.is_set()
+ assert ts.is_active()
+ assert tc.is_active()
+
+ except Exception as e:
+ if not catch_error:
+ raise
+ err = e
+
+ yield (tc, ts, err) if catch_error else (tc, ts)
+
+ tc.close()
+ ts.close()
+ socks.close()
+ sockc.close()
+
+
+_disable_sha2 = dict(
+ disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
+)
+_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
+_disable_sha2_pubkey = dict(
+ disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
+)
+_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
+
+
+class TestSHA2SignatureKeyExchange(unittest.TestCase):
+ # NOTE: these all rely on the default server() hostkey being RSA
+ # NOTE: these rely on both sides being properly implemented re: agreed-upon
+ # hostkey during kex being what's actually used. Truly proving that eg
+ # SHA512 was used, is quite difficult w/o super gross hacks. However, there
+ # are new tests in test_pkey.py which use known signature blobs to prove
+ # the SHA2 family was in fact used!
+
+ def test_base_case_ssh_rsa_still_used_as_fallback(self):
+ # Prove that ssh-rsa is used if either, or both, participants have SHA2
+ # algorithms disabled
+ for which in ("init", "client_init", "server_init"):
+ with server(**{which: _disable_sha2}) as (tc, _):
+ assert tc.host_key_type == "ssh-rsa"
+
+ def test_kex_with_sha2_512(self):
+ # It's the default!
+ with server() as (tc, _):
+ assert tc.host_key_type == "rsa-sha2-512"
+
+ def test_kex_with_sha2_256(self):
+ # No 512 -> you get 256
+ with server(
+ init=dict(disabled_algorithms=dict(keys=["rsa-sha2-512"]))
+ ) as (tc, _):
+ assert tc.host_key_type == "rsa-sha2-256"
+
+ def _incompatible_peers(self, client_init, server_init):
+ with server(
+ client_init=client_init, server_init=server_init, catch_error=True
+ ) as (tc, ts, err):
+ # If neither side blew up then that's bad!
+ assert err is not None
+ # If client side blew up first, it'll be straightforward
+ if isinstance(err, IncompatiblePeer):
+ pass
+ # If server side blew up first, client sees EOF & we need to check
+ # the server transport for its saved error (otherwise it can only
+ # appear in log output)
+ elif isinstance(err, EOFError):
+ assert ts.saved_exception is not None
+ assert isinstance(ts.saved_exception, IncompatiblePeer)
+ # If it was something else, welp
+ else:
+ raise err
+
+ def test_client_sha2_disabled_server_sha1_disabled_no_match(self):
+ self._incompatible_peers(
+ client_init=_disable_sha2, server_init=_disable_sha1
+ )
+
+ def test_client_sha1_disabled_server_sha2_disabled_no_match(self):
+ self._incompatible_peers(
+ client_init=_disable_sha1, server_init=_disable_sha2
+ )
+
+ def test_explicit_client_hostkey_not_limited(self):
+ # Be very explicit about the hostkey on BOTH ends,
+ # and ensure it still ends up choosing sha2-512.
+ # (This is a regression test vs previous implementation which overwrote
+ # the entire preferred-hostkeys structure when given an explicit key as
+ # a client.)
+ hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _):
+ assert tc.host_key_type == "rsa-sha2-512"
+
+
+class TestExtInfo(unittest.TestCase):
+ def test_ext_info_handshake(self):
+ with server() as (tc, _):
+ kex = tc._get_latest_kex_init()
+ assert kex["kex_algo_list"][-1] == "ext-info-c"
+ assert tc.server_extensions == {
+ "server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss" # noqa
+ }
+
+ def test_client_uses_server_sig_algs_for_pubkey_auth(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ server_init=dict(
+ disabled_algorithms=dict(pubkeys=["rsa-sha2-512"])
+ ),
+ ) as (tc, _):
+ assert tc.is_authenticated()
+ # Client settled on 256 despite itself not having 512 disabled
+ assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"
+
+
+# TODO: these could move into test_auth.py but that badly needs refactoring
+# with this module anyways...
+class TestSHA2SignaturePubkeys(unittest.TestCase):
+ def test_pubkey_auth_honors_disabled_algorithms(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ init=dict(
+ disabled_algorithms=dict(
+ pubkeys=["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"]
+ )
+ ),
+ catch_error=True,
+ ) as (_, _, err):
+ assert isinstance(err, SSHException)
+ assert "no RSA pubkey algorithms" in str(err)
+
+ def test_client_sha2_disabled_server_sha1_disabled_no_match(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ client_init=_disable_sha2_pubkey,
+ server_init=_disable_sha1_pubkey,
+ catch_error=True,
+ ) as (tc, ts, err):
+ assert isinstance(err, AuthenticationException)
+
+ def test_client_sha1_disabled_server_sha2_disabled_no_match(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ client_init=_disable_sha1_pubkey,
+ server_init=_disable_sha2_pubkey,
+ catch_error=True,
+ ) as (tc, ts, err):
+ assert isinstance(err, AuthenticationException)
+
+ def test_ssh_rsa_still_used_when_sha2_disabled(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ # NOTE: this works because key obj comparison uses public bytes
+ # TODO: would be nice for PKey to grow a legit "give me another obj of
+ # same class but just the public bits" using asbytes()
+ with server(
+ pubkeys=[privkey], connect=dict(pkey=privkey), init=_disable_sha2
+ ) as (tc, _):
+ assert tc.is_authenticated()
+
+ def test_sha2_512(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ init=dict(
+ disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-256"])
+ ),
+ ) as (tc, ts):
+ assert tc.is_authenticated()
+ assert tc._agreed_pubkey_algorithm == "rsa-sha2-512"
+
+ def test_sha2_256(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ init=dict(
+ disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"])
+ ),
+ ) as (tc, ts):
+ assert tc.is_authenticated()
+ assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"
+
+ def test_sha2_256_when_client_only_enables_256(self):
+ privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ with server(
+ pubkeys=[privkey],
+ connect=dict(pkey=privkey),
+ # Client-side only; server still accepts all 3.
+ client_init=dict(
+ disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"])
+ ),
+ ) as (tc, ts):
+ assert tc.is_authenticated()
+ assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"