diff options
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 326 |
1 files changed, 53 insertions, 273 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py index 4062d767..421c078b 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -22,7 +22,6 @@ Some unit tests for the ssh2 protocol in Transport. from binascii import hexlify -from contextlib import contextmanager import select import socket import time @@ -34,18 +33,16 @@ from unittest.mock import Mock from paramiko import ( AuthHandler, ChannelException, - DSSKey, Packetizer, RSAKey, SSHException, - AuthenticationException, IncompatiblePeer, SecurityOptions, - ServerInterface, + ServiceRequestingTransport, Transport, ) -from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL -from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from paramiko.auth_handler import AuthOnlyHandler +from paramiko import OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko.common import ( DEFAULT_MAX_PACKET_SIZE, DEFAULT_WINDOW_SIZE, @@ -60,8 +57,17 @@ from paramiko.common import ( ) from paramiko.message import Message -from .util import needs_builtin, _support, requires_sha1_signing, slow -from .loop import LoopSocket +from ._util import ( + needs_builtin, + _support, + requires_sha1_signing, + slow, + server, + _disable_sha2, + _disable_sha1, + TestServer as NullServer, +) +from ._loop import LoopSocket LONG_BANNER = """\ @@ -77,80 +83,11 @@ Maybe. """ -class NullServer(ServerInterface): - paranoid_did_password = False - paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) - - def __init__(self, allowed_keys=None): - self.allowed_keys = allowed_keys if allowed_keys is not None else [] - - def get_allowed_auths(self, username): - if username == "slowdive": - return "publickey,password" - return "publickey" - - def check_auth_password(self, username, password): - if (username == "slowdive") and (password == "pygmalion"): - return AUTH_SUCCESSFUL - return AUTH_FAILED - - def check_auth_publickey(self, username, key): - if key in self.allowed_keys: - return AUTH_SUCCESSFUL - return AUTH_FAILED - - def check_channel_request(self, kind, chanid): - if kind == "bogus": - return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED - return OPEN_SUCCEEDED - - def check_channel_exec_request(self, channel, command): - if command != b"yes": - return False - return True - - def check_channel_shell_request(self, channel): - return True - - def check_global_request(self, kind, msg): - self._global_request = kind - # NOTE: for w/e reason, older impl of this returned False always, even - # tho that's only supposed to occur if the request cannot be served. - # For now, leaving that the default unless test supplies specific - # 'acceptable' request kind - return kind == "acceptable" - - def check_channel_x11_request( - self, - channel, - single_connection, - auth_protocol, - auth_cookie, - screen_number, - ): - self._x11_single_connection = single_connection - self._x11_auth_protocol = auth_protocol - self._x11_auth_cookie = auth_cookie - self._x11_screen_number = screen_number - return True - - def check_port_forward_request(self, addr, port): - self._listen = socket.socket() - self._listen.bind(("127.0.0.1", 0)) - self._listen.listen(1) - return self._listen.getsockname()[1] - - def cancel_port_forward_request(self, addr, port): - self._listen.close() - self._listen = None - - def check_channel_direct_tcpip_request(self, chanid, origin, destination): - self._tcpip_dest = destination - return OPEN_SUCCEEDED - - class TransportTest(unittest.TestCase): + # TODO: this can get nuked once ServiceRequestingTransport becomes the + # only Transport, as it has this baked in. + _auth_handler_class = AuthHandler + def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -168,7 +105,7 @@ class TransportTest(unittest.TestCase): def setup_test_server( self, client_options=None, server_options=None, connect_kwargs=None ): - host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) + host_key = RSAKey.from_private_key_file(_support("rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) @@ -234,7 +171,7 @@ class TransportTest(unittest.TestCase): loopback sockets. this is hardly "simple" but it's simpler than the later tests. :) """ - host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) + host_key = RSAKey.from_private_key_file(_support("rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -260,7 +197,7 @@ class TransportTest(unittest.TestCase): """ verify that a long banner doesn't mess up the handshake. """ - host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) + host_key = RSAKey.from_private_key_file(_support("rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -910,7 +847,7 @@ class TransportTest(unittest.TestCase): # be fine. Even tho it's a bit squicky. self.tc.packetizer = SlowPacketizer(self.tc.sock) # Continue with regular test red tape. - host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) + host_key = RSAKey.from_private_key_file(_support("rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -1099,7 +1036,8 @@ class TransportTest(unittest.TestCase): def test_server_transports_reject_client_message_types(self): # TODO: handle Transport's own tables too, not just its inner auth # handler's table. See TODOs in auth_handler.py - for message_type in AuthHandler._client_handler_table: + some_handler = self._auth_handler_class(self.tc) + for message_type in some_handler._client_handler_table: self._send_client_message(message_type) self._expect_unimplemented() # Reset for rest of loop @@ -1115,6 +1053,21 @@ class TransportTest(unittest.TestCase): self._expect_unimplemented() +# TODO: for now this is purely a regression test. It needs actual tests of the +# intentional new behavior too! +class ServiceRequestingTransportTest(TransportTest): + _auth_handler_class = AuthOnlyHandler + + def setUp(self): + # Copypasta (Transport init is load-bearing) + self.socks = LoopSocket() + self.sockc = LoopSocket() + self.sockc.link(self.socks) + # New class who dis + self.tc = ServiceRequestingTransport(self.sockc) + self.ts = ServiceRequestingTransport(self.socks) + + class AlgorithmDisablingTests(unittest.TestCase): def test_preferred_lists_default_to_private_attribute_contents(self): t = Transport(sock=Mock()) @@ -1188,98 +1141,6 @@ class AlgorithmDisablingTests(unittest.TestCase): assert "zlib" not in compressions -@contextmanager -def server( - hostkey=None, - init=None, - server_init=None, - client_init=None, - connect=None, - pubkeys=None, - catch_error=False, -): - """ - SSH server contextmanager for testing. - - :param hostkey: - Host key to use for the server; if None, loads - ``test_rsa.key``. - :param init: - Default `Transport` constructor kwargs to use for both sides. - :param server_init: - Extends and/or overrides ``init`` for server transport only. - :param client_init: - Extends and/or overrides ``init`` for client transport only. - :param connect: - Kwargs to use for ``connect()`` on the client. - :param pubkeys: - List of public keys for auth. - :param catch_error: - Whether to capture connection errors & yield from contextmanager. - Necessary for connection_time exception testing. - """ - if init is None: - init = {} - if server_init is None: - server_init = {} - if client_init is None: - client_init = {} - if connect is None: - connect = dict(username="slowdive", password="pygmalion") - socks = LoopSocket() - sockc = LoopSocket() - sockc.link(socks) - tc = Transport(sockc, **dict(init, **client_init)) - ts = Transport(socks, **dict(init, **server_init)) - - if hostkey is None: - hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - ts.add_server_key(hostkey) - event = threading.Event() - server = NullServer(allowed_keys=pubkeys) - assert not event.is_set() - assert not ts.is_active() - assert tc.get_username() is None - assert ts.get_username() is None - assert not tc.is_authenticated() - assert not ts.is_authenticated() - - err = None - # Trap errors and yield instead of raising right away; otherwise callers - # cannot usefully deal with problems at connect time which stem from errors - # in the server side. - try: - ts.start_server(event, server) - tc.connect(**connect) - - event.wait(1.0) - assert event.is_set() - assert ts.is_active() - assert tc.is_active() - - except Exception as e: - if not catch_error: - raise - err = e - - yield (tc, ts, err) if catch_error else (tc, ts) - - tc.close() - ts.close() - socks.close() - sockc.close() - - -_disable_sha2 = dict( - disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"]) -) -_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"])) -_disable_sha2_pubkey = dict( - disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"]) -) -_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"])) - - class TestSHA2SignatureKeyExchange(unittest.TestCase): # NOTE: these all rely on the default server() hostkey being RSA # NOTE: these rely on both sides being properly implemented re: agreed-upon @@ -1343,22 +1204,29 @@ class TestSHA2SignatureKeyExchange(unittest.TestCase): # (This is a regression test vs previous implementation which overwrote # the entire preferred-hostkeys structure when given an explicit key as # a client.) - hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _): + hostkey = RSAKey.from_private_key_file(_support("rsa.key")) + connect = dict( + hostkey=hostkey, username="slowdive", password="pygmalion" + ) + with server(hostkey=hostkey, connect=connect) as (tc, _): assert tc.host_key_type == "rsa-sha2-512" class TestExtInfo(unittest.TestCase): - def test_ext_info_handshake(self): + def test_ext_info_handshake_exposed_in_client_kexinit(self): with server() as (tc, _): + # NOTE: this is latest KEXINIT /sent by us/ (Transport retains it) kex = tc._get_latest_kex_init() - assert kex["kex_algo_list"][-1] == "ext-info-c" + # flag in KexAlgorithms list + assert "ext-info-c" in kex["kex_algo_list"] + # data stored on Transport after hearing back from a compatible + # server (such as ourselves in server mode) assert tc.server_extensions == { "server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss" # noqa } def test_client_uses_server_sig_algs_for_pubkey_auth(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) + privkey = RSAKey.from_private_key_file(_support("rsa.key")) with server( pubkeys=[privkey], connect=dict(pkey=privkey), @@ -1367,94 +1235,6 @@ class TestExtInfo(unittest.TestCase): ), ) as (tc, _): assert tc.is_authenticated() - # Client settled on 256 despite itself not having 512 disabled - assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" - - -# TODO: these could move into test_auth.py but that badly needs refactoring -# with this module anyways... -class TestSHA2SignaturePubkeys(unittest.TestCase): - def test_pubkey_auth_honors_disabled_algorithms(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server( - pubkeys=[privkey], - connect=dict(pkey=privkey), - init=dict( - disabled_algorithms=dict( - pubkeys=["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"] - ) - ), - catch_error=True, - ) as (_, _, err): - assert isinstance(err, SSHException) - assert "no RSA pubkey algorithms" in str(err) - - def test_client_sha2_disabled_server_sha1_disabled_no_match(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server( - pubkeys=[privkey], - connect=dict(pkey=privkey), - client_init=_disable_sha2_pubkey, - server_init=_disable_sha1_pubkey, - catch_error=True, - ) as (tc, ts, err): - assert isinstance(err, AuthenticationException) - - def test_client_sha1_disabled_server_sha2_disabled_no_match(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server( - pubkeys=[privkey], - connect=dict(pkey=privkey), - client_init=_disable_sha1_pubkey, - server_init=_disable_sha2_pubkey, - catch_error=True, - ) as (tc, ts, err): - assert isinstance(err, AuthenticationException) - - @requires_sha1_signing - def test_ssh_rsa_still_used_when_sha2_disabled(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - # NOTE: this works because key obj comparison uses public bytes - # TODO: would be nice for PKey to grow a legit "give me another obj of - # same class but just the public bits" using asbytes() - with server( - pubkeys=[privkey], connect=dict(pkey=privkey), init=_disable_sha2 - ) as (tc, _): - assert tc.is_authenticated() - - def test_sha2_512(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server( - pubkeys=[privkey], - connect=dict(pkey=privkey), - init=dict( - disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-256"]) - ), - ) as (tc, ts): - assert tc.is_authenticated() - assert tc._agreed_pubkey_algorithm == "rsa-sha2-512" - - def test_sha2_256(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server( - pubkeys=[privkey], - connect=dict(pkey=privkey), - init=dict( - disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"]) - ), - ) as (tc, ts): - assert tc.is_authenticated() - assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" - - def test_sha2_256_when_client_only_enables_256(self): - privkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server( - pubkeys=[privkey], - connect=dict(pkey=privkey), - # Client-side only; server still accepts all 3. - client_init=dict( - disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"]) - ), - ) as (tc, ts): - assert tc.is_authenticated() + # Client settled on 256 despite itself not having 512 disabled (and + # otherwise, 512 would have been earlier in the preferred list) assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" |