diff options
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 212 |
1 files changed, 210 insertions, 2 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py index 421c078b..54e1f206 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -22,6 +22,7 @@ Some unit tests for the ssh2 protocol in Transport. from binascii import hexlify +import itertools import select import socket import time @@ -33,10 +34,11 @@ from unittest.mock import Mock from paramiko import ( AuthHandler, ChannelException, + IncompatiblePeer, + MessageOrderError, Packetizer, RSAKey, SSHException, - IncompatiblePeer, SecurityOptions, ServiceRequestingTransport, Transport, @@ -49,11 +51,15 @@ from paramiko.common import ( MAX_WINDOW_SIZE, MIN_PACKET_SIZE, MIN_WINDOW_SIZE, + MSG_CHANNEL_OPEN, + MSG_DEBUG, + MSG_IGNORE, MSG_KEXINIT, + MSG_UNIMPLEMENTED, MSG_USERAUTH_SUCCESS, + byte_chr, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_UNIMPLEMENTED, - byte_chr, ) from paramiko.message import Message @@ -68,6 +74,7 @@ from ._util import ( TestServer as NullServer, ) from ._loop import LoopSocket +from pytest import mark, raises LONG_BANNER = """\ @@ -82,6 +89,10 @@ Note: An SSH banner may eventually appear. Maybe. """ +# Faux 'packet type' we do not implement and are unlikely ever to (but which is +# technically "within spec" re RFC 4251 +MSG_FUGGEDABOUTIT = 253 + class TransportTest(unittest.TestCase): # TODO: this can get nuked once ServiceRequestingTransport becomes the @@ -1052,6 +1063,16 @@ class TransportTest(unittest.TestCase): # Real fix's behavior self._expect_unimplemented() + def test_can_override_packetizer_used(self): + class MyPacketizer(Packetizer): + pass + + # control case + assert Transport(sock=LoopSocket()).packetizer.__class__ is Packetizer + # overridden case + tweaked = Transport(sock=LoopSocket(), packetizer_class=MyPacketizer) + assert tweaked.packetizer.__class__ is MyPacketizer + # TODO: for now this is purely a regression test. It needs actual tests of the # intentional new behavior too! @@ -1238,3 +1259,190 @@ class TestExtInfo(unittest.TestCase): # Client settled on 256 despite itself not having 512 disabled (and # otherwise, 512 would have been earlier in the preferred list) assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" + + +class BadSeqPacketizer(Packetizer): + def read_message(self): + cmd, msg = super().read_message() + # Only mess w/ seqno if kexinit. + if cmd is MSG_KEXINIT: + # NOTE: this is /only/ the copy of the seqno which gets + # transmitted up from Packetizer; it's not modifying + # Packetizer's own internal seqno. For these tests, + # modifying the latter isn't required, and is also harder + # to do w/o triggering MAC mismatches. + msg.seqno = 17 # arbitrary nonzero int + return cmd, msg + + +class TestStrictKex: + def test_kex_algos_includes_kex_strict_c(self): + with server() as (tc, _): + kex = tc._get_latest_kex_init() + assert "kex-strict-c-v00@openssh.com" in kex["kex_algo_list"] + + @mark.parametrize( + "server_active,client_active", + itertools.product([True, False], repeat=2), + ) + def test_mode_agreement(self, server_active, client_active): + with server( + server_init=dict(strict_kex=server_active), + client_init=dict(strict_kex=client_active), + ) as (tc, ts): + if server_active and client_active: + assert tc.agreed_on_strict_kex is True + assert ts.agreed_on_strict_kex is True + else: + assert tc.agreed_on_strict_kex is False + assert ts.agreed_on_strict_kex is False + + def test_mode_advertised_by_default(self): + # NOTE: no explicit strict_kex overrides... + with server() as (tc, ts): + assert all( + ( + tc.advertise_strict_kex, + tc.agreed_on_strict_kex, + ts.advertise_strict_kex, + ts.agreed_on_strict_kex, + ) + ) + + @mark.parametrize( + "ptype", + ( + # "normal" but definitely out-of-order message + MSG_CHANNEL_OPEN, + # Normally ignored, but not in this case + MSG_IGNORE, + # Normally triggers debug parsing, but not in this case + MSG_DEBUG, + # Normally ignored, but...you get the idea + MSG_UNIMPLEMENTED, + # Not real, so would normally trigger us /sending/ + # MSG_UNIMPLEMENTED, but... + MSG_FUGGEDABOUTIT, + ), + ) + def test_MessageOrderError_non_kex_messages_in_initial_kex(self, ptype): + class AttackTransport(Transport): + # Easiest apparent spot on server side which is: + # - late enough for both ends to have handshook on strict mode + # - early enough to be in the window of opportunity for Terrapin + # attack; essentially during actual kex, when the engine is + # waiting for things like MSG_KEXECDH_REPLY (for eg curve25519). + def _negotiate_keys(self, m): + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.clear() + finally: + self.clear_to_send_lock.release() + if self.local_kex_init is None: + # remote side wants to renegotiate + self._send_kex_init() + self._parse_kex_init(m) + # Here, we would normally kick over to kex_engine, but instead + # we want the server to send the OOO message. + m = Message() + m.add_byte(byte_chr(ptype)) + # rest of packet unnecessary... + self._send_message(m) + + with raises(MessageOrderError): + with server(server_transport_factory=AttackTransport) as (tc, _): + pass # above should run and except during connect() + + def test_SSHException_raised_on_out_of_order_messages_when_not_strict( + self, + ): + # This is kind of dumb (either situation is still fatal!) but whatever, + # may as well be strict with our new strict flag... + with raises(SSHException) as info: # would be true either way, but + with server( + client_init=dict(strict_kex=False), + ) as (tc, _): + tc._expect_packet(MSG_KEXINIT) + tc.open_session() + assert info.type is SSHException # NOT MessageOrderError! + + def test_error_not_raised_when_kexinit_not_seq_0_but_unstrict(self): + with server( + client_init=dict( + # Disable strict kex + strict_kex=False, + # Give our clientside a packetizer that sets all kexinit + # Message objects to have .seqno==17, which would trigger the + # new logic if we'd forgotten to wrap it in strict-kex check + packetizer_class=BadSeqPacketizer, + ), + ): + pass # kexinit happens at connect... + + def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self): + with raises(MessageOrderError): + with server( + # Give our clientside a packetizer that sets all kexinit + # Message objects to have .seqno==17, which should trigger the + # new logic (given we are NOT disabling strict-mode) + client_init=dict(packetizer_class=BadSeqPacketizer), + ): + pass # kexinit happens at connect... + + def test_sequence_numbers_reset_on_newkeys_when_strict(self): + with server(defer=True) as (tc, ts): + # When in strict mode, these should all be zero or close to it + # (post-kexinit, pre-auth). + # Server->client will be 1 (EXT_INFO got sent after NEWKEYS) + assert tc.packetizer._Packetizer__sequence_number_in == 1 + assert ts.packetizer._Packetizer__sequence_number_out == 1 + # Client->server will be 0 + assert tc.packetizer._Packetizer__sequence_number_out == 0 + assert ts.packetizer._Packetizer__sequence_number_in == 0 + + def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self): + with server(defer=True, client_init=dict(strict_kex=False)) as ( + tc, + ts, + ): + # When not in strict mode, these will all be ~3-4 or so + # (post-kexinit, pre-auth). Not encoding exact values as it will + # change anytime we mess with the test harness... + assert tc.packetizer._Packetizer__sequence_number_in != 0 + assert tc.packetizer._Packetizer__sequence_number_out != 0 + assert ts.packetizer._Packetizer__sequence_number_in != 0 + assert ts.packetizer._Packetizer__sequence_number_out != 0 + + def test_sequence_number_rollover_detected(self): + class RolloverTransport(Transport): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Induce an about-to-rollover seqno, such that it rolls over + # during initial kex. (Sequence numbers are uint32, so we need + # the largest possible 32bit integer such that incrementing it + # will roll over to 0.) + last_seq = 2**32 - 1 + setattr( + self.packetizer, + "_Packetizer__sequence_number_in", + last_seq, + ) + setattr( + self.packetizer, + "_Packetizer__sequence_number_out", + last_seq, + ) + + with raises( + SSHException, + match=r"Sequence number rolled over during initial kex!", + ): + with server( + client_init=dict( + # Disable strict kex - this should happen always + strict_kex=False, + ), + # Transport which tickles its packetizer seqno's + transport_factory=RolloverTransport, + ): + pass # kexinit happens at connect... |