summaryrefslogtreecommitdiffhomepage
path: root/tests/test_transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r--tests/test_transport.py212
1 files changed, 210 insertions, 2 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 421c078b..54e1f206 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -22,6 +22,7 @@ Some unit tests for the ssh2 protocol in Transport.
from binascii import hexlify
+import itertools
import select
import socket
import time
@@ -33,10 +34,11 @@ from unittest.mock import Mock
from paramiko import (
AuthHandler,
ChannelException,
+ IncompatiblePeer,
+ MessageOrderError,
Packetizer,
RSAKey,
SSHException,
- IncompatiblePeer,
SecurityOptions,
ServiceRequestingTransport,
Transport,
@@ -49,11 +51,15 @@ from paramiko.common import (
MAX_WINDOW_SIZE,
MIN_PACKET_SIZE,
MIN_WINDOW_SIZE,
+ MSG_CHANNEL_OPEN,
+ MSG_DEBUG,
+ MSG_IGNORE,
MSG_KEXINIT,
+ MSG_UNIMPLEMENTED,
MSG_USERAUTH_SUCCESS,
+ byte_chr,
cMSG_CHANNEL_WINDOW_ADJUST,
cMSG_UNIMPLEMENTED,
- byte_chr,
)
from paramiko.message import Message
@@ -68,6 +74,7 @@ from ._util import (
TestServer as NullServer,
)
from ._loop import LoopSocket
+from pytest import mark, raises
LONG_BANNER = """\
@@ -82,6 +89,10 @@ Note: An SSH banner may eventually appear.
Maybe.
"""
+# Faux 'packet type' we do not implement and are unlikely ever to (but which is
+# technically "within spec" re RFC 4251
+MSG_FUGGEDABOUTIT = 253
+
class TransportTest(unittest.TestCase):
# TODO: this can get nuked once ServiceRequestingTransport becomes the
@@ -1052,6 +1063,16 @@ class TransportTest(unittest.TestCase):
# Real fix's behavior
self._expect_unimplemented()
+ def test_can_override_packetizer_used(self):
+ class MyPacketizer(Packetizer):
+ pass
+
+ # control case
+ assert Transport(sock=LoopSocket()).packetizer.__class__ is Packetizer
+ # overridden case
+ tweaked = Transport(sock=LoopSocket(), packetizer_class=MyPacketizer)
+ assert tweaked.packetizer.__class__ is MyPacketizer
+
# TODO: for now this is purely a regression test. It needs actual tests of the
# intentional new behavior too!
@@ -1238,3 +1259,190 @@ class TestExtInfo(unittest.TestCase):
# Client settled on 256 despite itself not having 512 disabled (and
# otherwise, 512 would have been earlier in the preferred list)
assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"
+
+
+class BadSeqPacketizer(Packetizer):
+ def read_message(self):
+ cmd, msg = super().read_message()
+ # Only mess w/ seqno if kexinit.
+ if cmd is MSG_KEXINIT:
+ # NOTE: this is /only/ the copy of the seqno which gets
+ # transmitted up from Packetizer; it's not modifying
+ # Packetizer's own internal seqno. For these tests,
+ # modifying the latter isn't required, and is also harder
+ # to do w/o triggering MAC mismatches.
+ msg.seqno = 17 # arbitrary nonzero int
+ return cmd, msg
+
+
+class TestStrictKex:
+ def test_kex_algos_includes_kex_strict_c(self):
+ with server() as (tc, _):
+ kex = tc._get_latest_kex_init()
+ assert "kex-strict-c-v00@openssh.com" in kex["kex_algo_list"]
+
+ @mark.parametrize(
+ "server_active,client_active",
+ itertools.product([True, False], repeat=2),
+ )
+ def test_mode_agreement(self, server_active, client_active):
+ with server(
+ server_init=dict(strict_kex=server_active),
+ client_init=dict(strict_kex=client_active),
+ ) as (tc, ts):
+ if server_active and client_active:
+ assert tc.agreed_on_strict_kex is True
+ assert ts.agreed_on_strict_kex is True
+ else:
+ assert tc.agreed_on_strict_kex is False
+ assert ts.agreed_on_strict_kex is False
+
+ def test_mode_advertised_by_default(self):
+ # NOTE: no explicit strict_kex overrides...
+ with server() as (tc, ts):
+ assert all(
+ (
+ tc.advertise_strict_kex,
+ tc.agreed_on_strict_kex,
+ ts.advertise_strict_kex,
+ ts.agreed_on_strict_kex,
+ )
+ )
+
+ @mark.parametrize(
+ "ptype",
+ (
+ # "normal" but definitely out-of-order message
+ MSG_CHANNEL_OPEN,
+ # Normally ignored, but not in this case
+ MSG_IGNORE,
+ # Normally triggers debug parsing, but not in this case
+ MSG_DEBUG,
+ # Normally ignored, but...you get the idea
+ MSG_UNIMPLEMENTED,
+ # Not real, so would normally trigger us /sending/
+ # MSG_UNIMPLEMENTED, but...
+ MSG_FUGGEDABOUTIT,
+ ),
+ )
+ def test_MessageOrderError_non_kex_messages_in_initial_kex(self, ptype):
+ class AttackTransport(Transport):
+ # Easiest apparent spot on server side which is:
+ # - late enough for both ends to have handshook on strict mode
+ # - early enough to be in the window of opportunity for Terrapin
+ # attack; essentially during actual kex, when the engine is
+ # waiting for things like MSG_KEXECDH_REPLY (for eg curve25519).
+ def _negotiate_keys(self, m):
+ self.clear_to_send_lock.acquire()
+ try:
+ self.clear_to_send.clear()
+ finally:
+ self.clear_to_send_lock.release()
+ if self.local_kex_init is None:
+ # remote side wants to renegotiate
+ self._send_kex_init()
+ self._parse_kex_init(m)
+ # Here, we would normally kick over to kex_engine, but instead
+ # we want the server to send the OOO message.
+ m = Message()
+ m.add_byte(byte_chr(ptype))
+ # rest of packet unnecessary...
+ self._send_message(m)
+
+ with raises(MessageOrderError):
+ with server(server_transport_factory=AttackTransport) as (tc, _):
+ pass # above should run and except during connect()
+
+ def test_SSHException_raised_on_out_of_order_messages_when_not_strict(
+ self,
+ ):
+ # This is kind of dumb (either situation is still fatal!) but whatever,
+ # may as well be strict with our new strict flag...
+ with raises(SSHException) as info: # would be true either way, but
+ with server(
+ client_init=dict(strict_kex=False),
+ ) as (tc, _):
+ tc._expect_packet(MSG_KEXINIT)
+ tc.open_session()
+ assert info.type is SSHException # NOT MessageOrderError!
+
+ def test_error_not_raised_when_kexinit_not_seq_0_but_unstrict(self):
+ with server(
+ client_init=dict(
+ # Disable strict kex
+ strict_kex=False,
+ # Give our clientside a packetizer that sets all kexinit
+ # Message objects to have .seqno==17, which would trigger the
+ # new logic if we'd forgotten to wrap it in strict-kex check
+ packetizer_class=BadSeqPacketizer,
+ ),
+ ):
+ pass # kexinit happens at connect...
+
+ def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self):
+ with raises(MessageOrderError):
+ with server(
+ # Give our clientside a packetizer that sets all kexinit
+ # Message objects to have .seqno==17, which should trigger the
+ # new logic (given we are NOT disabling strict-mode)
+ client_init=dict(packetizer_class=BadSeqPacketizer),
+ ):
+ pass # kexinit happens at connect...
+
+ def test_sequence_numbers_reset_on_newkeys_when_strict(self):
+ with server(defer=True) as (tc, ts):
+ # When in strict mode, these should all be zero or close to it
+ # (post-kexinit, pre-auth).
+ # Server->client will be 1 (EXT_INFO got sent after NEWKEYS)
+ assert tc.packetizer._Packetizer__sequence_number_in == 1
+ assert ts.packetizer._Packetizer__sequence_number_out == 1
+ # Client->server will be 0
+ assert tc.packetizer._Packetizer__sequence_number_out == 0
+ assert ts.packetizer._Packetizer__sequence_number_in == 0
+
+ def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self):
+ with server(defer=True, client_init=dict(strict_kex=False)) as (
+ tc,
+ ts,
+ ):
+ # When not in strict mode, these will all be ~3-4 or so
+ # (post-kexinit, pre-auth). Not encoding exact values as it will
+ # change anytime we mess with the test harness...
+ assert tc.packetizer._Packetizer__sequence_number_in != 0
+ assert tc.packetizer._Packetizer__sequence_number_out != 0
+ assert ts.packetizer._Packetizer__sequence_number_in != 0
+ assert ts.packetizer._Packetizer__sequence_number_out != 0
+
+ def test_sequence_number_rollover_detected(self):
+ class RolloverTransport(Transport):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Induce an about-to-rollover seqno, such that it rolls over
+ # during initial kex. (Sequence numbers are uint32, so we need
+ # the largest possible 32bit integer such that incrementing it
+ # will roll over to 0.)
+ last_seq = 2**32 - 1
+ setattr(
+ self.packetizer,
+ "_Packetizer__sequence_number_in",
+ last_seq,
+ )
+ setattr(
+ self.packetizer,
+ "_Packetizer__sequence_number_out",
+ last_seq,
+ )
+
+ with raises(
+ SSHException,
+ match=r"Sequence number rolled over during initial kex!",
+ ):
+ with server(
+ client_init=dict(
+ # Disable strict kex - this should happen always
+ strict_kex=False,
+ ),
+ # Transport which tickles its packetizer seqno's
+ transport_factory=RolloverTransport,
+ ):
+ pass # kexinit happens at connect...