summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/packet.py6
-rw-r--r--paramiko/transport.py22
-rw-r--r--tests/test_transport.py25
3 files changed, 49 insertions, 4 deletions
diff --git a/paramiko/packet.py b/paramiko/packet.py
index e40355e3..fc3d2de1 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -130,6 +130,12 @@ class Packetizer:
def closed(self):
return self.__closed
+ def reset_seqno_out(self):
+ self.__sequence_number_out = 0
+
+ def reset_seqno_in(self):
+ self.__sequence_number_in = 0
+
def set_log(self, log):
"""
Set the Python log object to use for logging.
diff --git a/paramiko/transport.py b/paramiko/transport.py
index c819d9a6..f3925861 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -2499,9 +2499,13 @@ class Transport(threading.Thread, ClosingContextManager):
# CVE mitigation: expect zeroed-out seqno anytime we are performing kex
# init phase, if strict mode was negotiated.
- if self.agreed_on_strict_kex and m.seqno != 0:
+ if (
+ self.agreed_on_strict_kex
+ and not self.initial_kex_done
+ and m.seqno != 0
+ ):
raise MessageOrderError(
- f"Got nonzero seqno ({m.seqno}) during strict KEXINIT!"
+ "In strict-kex mode, but KEXINIT was not the first packet!"
)
# as a server, we pick the first item in the client's list that we
@@ -2703,6 +2707,13 @@ class Transport(threading.Thread, ClosingContextManager):
):
self._log(DEBUG, "Switching on inbound compression ...")
self.packetizer.set_inbound_compressor(compress_in())
+ # Reset inbound sequence number if strict mode.
+ if self.agreed_on_strict_kex:
+ self._log(
+ DEBUG,
+ f"Resetting inbound seqno after NEWKEYS due to strict mode",
+ )
+ self.packetizer.reset_seqno_in()
def _activate_outbound(self):
"""switch on newly negotiated encryption parameters for
@@ -2710,6 +2721,13 @@ class Transport(threading.Thread, ClosingContextManager):
m = Message()
m.add_byte(cMSG_NEWKEYS)
self._send_message(m)
+ # Reset outbound sequence number if strict mode.
+ if self.agreed_on_strict_kex:
+ self._log(
+ DEBUG,
+ f"Resetting outbound sequence number after NEWKEYS due to strict mode",
+ )
+ self.packetizer.reset_seqno_out()
block_size = self._cipher_info[self.local_cipher]["block-size"]
if self.server_mode:
IV_out = self._compute_key("B", block_size)
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 6cd9398a..f9bb89db 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -1345,5 +1345,26 @@ class TestStrictKex:
):
pass # kexinit happens at connect...
- def test_sequence_numbers_reset_on_newkeys(self):
- skip()
+ def test_sequence_numbers_reset_on_newkeys_when_strict(self):
+ with server(defer=True) as (tc, ts):
+ # When in strict mode, these should all be zero or close to it
+ # (post-kexinit, pre-auth).
+ # Server->client will be 1 (EXT_INFO got sent after NEWKEYS)
+ assert tc.packetizer._Packetizer__sequence_number_in == 1
+ assert ts.packetizer._Packetizer__sequence_number_out == 1
+ # Client->server will be 0
+ assert tc.packetizer._Packetizer__sequence_number_out == 0
+ assert ts.packetizer._Packetizer__sequence_number_in == 0
+
+ def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self):
+ with server(defer=True, client_init=dict(strict_kex=False)) as (
+ tc,
+ ts,
+ ):
+ # When not in strict mode, these will all be ~3-4 or so
+ # (post-kexinit, pre-auth). Not encoding exact values as it will
+ # change anytime we mess with the test harness...
+ assert tc.packetizer._Packetizer__sequence_number_in != 0
+ assert tc.packetizer._Packetizer__sequence_number_out != 0
+ assert ts.packetizer._Packetizer__sequence_number_in != 0
+ assert ts.packetizer._Packetizer__sequence_number_out != 0