diff options
author | Jeff Forcier <jeff@bitprophet.org> | 2023-12-16 16:17:58 -0500 |
---|---|---|
committer | Jeff Forcier <jeff@bitprophet.org> | 2023-12-16 16:17:58 -0500 |
commit | 75e311d3c0845a316b6e7b3fae2488d86ad5a270 (patch) | |
tree | 6702433a8f92b31c6b5ad52786ee3d05ba42c8d0 /tests/test_transport.py | |
parent | 73f079f5a4bbba7f3048dadbe05b24242206745e (diff) |
Enforce zero seqno on kexinit
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 62 |
1 files changed, 56 insertions, 6 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py index 060a6cae..6cd9398a 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1055,6 +1055,16 @@ class TransportTest(unittest.TestCase): # Real fix's behavior self._expect_unimplemented() + def test_can_override_packetizer_used(self): + class MyPacketizer(Packetizer): + pass + + # control case + assert Transport(sock=LoopSocket()).packetizer.__class__ is Packetizer + # overridden case + tweaked = Transport(sock=LoopSocket(), packetizer_class=MyPacketizer) + assert tweaked.packetizer.__class__ is MyPacketizer + # TODO: for now this is purely a regression test. It needs actual tests of the # intentional new behavior too! @@ -1243,6 +1253,20 @@ class TestExtInfo(unittest.TestCase): assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" +class BadSeqPacketizer(Packetizer): + def read_message(self): + cmd, msg = super().read_message() + # Only mess w/ seqno if kexinit. + if cmd is MSG_KEXINIT: + # NOTE: this is /only/ the copy of the seqno which gets + # transmitted up from Packetizer; it's not modifying + # Packetizer's own internal seqno. For these tests, + # modifying the latter isn't required, and is also harder + # to do w/o triggering MAC mismatches. + msg.seqno = 17 # arbitrary nonzero int + return cmd, msg + + class TestStrictKex: def test_kex_algos_includes_kex_strict_c(self): with server() as (tc, _): @@ -1277,9 +1301,6 @@ class TestStrictKex: ) ) - def test_sequence_numbers_reset_on_newkeys(self): - skip() - def test_MessageOrderError_raised_on_out_of_order_messages(self): with raises(MessageOrderError): with server() as (tc, _): @@ -1288,12 +1309,41 @@ class TestStrictKex: tc._expect_packet(MSG_KEXINIT) tc.open_session() - def test_SSHException_raised_on_out_of_order_messages_when_not_strict(self): + def test_SSHException_raised_on_out_of_order_messages_when_not_strict( + self, + ): # This is kind of dumb (either situation is still fatal!) but whatever, # may as well be strict with our new strict flag... with raises(SSHException) as info: # would be true either way, but - with server(client_init=dict(strict_kex=False), - ) as (tc, _): + with server( + client_init=dict(strict_kex=False), + ) as (tc, _): tc._expect_packet(MSG_KEXINIT) tc.open_session() assert info.type is SSHException # NOT MessageOrderError! + + def test_error_not_raised_when_kexinit_not_seq_0_but_unstrict(self): + with server( + client_init=dict( + # Disable strict kex + strict_kex=False, + # Give our clientside a packetizer that sets all kexinit + # Message objects to have .seqno==17, which would trigger the + # new logic if we'd forgotten to wrap it in strict-kex check + packetizer_class=BadSeqPacketizer, + ), + ): + pass # kexinit happens at connect... + + def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self): + with raises(MessageOrderError): + with server( + # Give our clientside a packetizer that sets all kexinit + # Message objects to have .seqno==17, which should trigger the + # new logic (given we are NOT disabling strict-mode) + client_init=dict(packetizer_class=BadSeqPacketizer), + ): + pass # kexinit happens at connect... + + def test_sequence_numbers_reset_on_newkeys(self): + skip() |