summaryrefslogtreecommitdiffhomepage
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_transport.py76
1 files changed, 75 insertions, 1 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 13fb302e..2b8ee3bc 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -30,6 +30,7 @@ import threading
import random
from hashlib import sha1
import unittest
+from mock import Mock
from paramiko import (
Transport,
@@ -41,19 +42,24 @@ from paramiko import (
ChannelException,
Packetizer,
Channel,
+ AuthHandler,
)
from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import (
MSG_KEXINIT,
cMSG_CHANNEL_WINDOW_ADJUST,
+ cMSG_UNIMPLEMENTED,
MIN_PACKET_SIZE,
MIN_WINDOW_SIZE,
MAX_WINDOW_SIZE,
DEFAULT_WINDOW_SIZE,
DEFAULT_MAX_PACKET_SIZE,
+ MSG_NAMES,
+ MSG_UNIMPLEMENTED,
+ MSG_USERAUTH_SUCCESS,
)
-from paramiko.py3compat import bytes
+from paramiko.py3compat import bytes, byte_chr
from paramiko.message import Message
from .util import needs_builtin, _support, slow
@@ -1027,3 +1033,71 @@ class TransportTest(unittest.TestCase):
assert "forwarding request denied" in str(e)
else:
assert False, "Did not raise SSHException!"
+
+ def _send_unimplemented(self, server_is_sender):
+ self.setup_test_server()
+ sender, recipient = self.tc, self.ts
+ if server_is_sender:
+ sender, recipient = self.ts, self.tc
+ recipient._send_message = Mock()
+ msg = Message()
+ msg.add_byte(cMSG_UNIMPLEMENTED)
+ sender._send_message(msg)
+ # TODO: I hate this but I literally don't see a good way to know when
+ # the recipient has received the sender's message (there are no
+ # existing threading events in play that work for this), esp in this
+ # case where we don't WANT a response (as otherwise we could
+ # potentially try blocking on the sender's receipt of a reply...maybe).
+ time.sleep(0.1)
+ assert not recipient._send_message.called
+
+ def test_server_does_not_respond_to_MSG_UNIMPLEMENTED(self):
+ self._send_unimplemented(server_is_sender=False)
+
+ def test_client_does_not_respond_to_MSG_UNIMPLEMENTED(self):
+ self._send_unimplemented(server_is_sender=True)
+
+ def _send_client_message(self, message_type):
+ self.setup_test_server(connect_kwargs={})
+ self.ts._send_message = Mock()
+ # NOTE: this isn't 100% realistic (most of these message types would
+ # have actual other fields in 'em) but it suffices to test the level of
+ # message dispatch we're interested in here.
+ msg = Message()
+ # TODO: really not liking the whole cMSG_XXX vs MSG_XXX duality right
+ # now, esp since the former is almost always just byte_chr(the
+ # latter)...but since that's the case...
+ msg.add_byte(byte_chr(message_type))
+ self.tc._send_message(msg)
+ # No good way to actually wait for server action (see above tests re:
+ # MSG_UNIMPLEMENTED). Grump.
+ time.sleep(0.1)
+
+ def _expect_unimplemented(self):
+ # Ensure MSG_UNIMPLEMENTED was sent (implies it hit end of loop instead
+ # of truly handling the given message).
+ # NOTE: When bug present, this will actually be the first thing that
+ # fails (since in many cases actual message handling doesn't involve
+ # sending a message back right away).
+ assert self.ts._send_message.call_count == 1
+ reply = self.ts._send_message.call_args[0][0]
+ reply.rewind() # Because it's pre-send, not post-receive
+ assert reply.get_byte() == cMSG_UNIMPLEMENTED
+
+ def test_server_transports_reject_client_message_types(self):
+ # TODO: handle Transport's own tables too, not just its inner auth
+ # handler's table. See TODOs in auth_handler.py
+ for message_type in AuthHandler._client_handler_table:
+ self._send_client_message(message_type)
+ self._expect_unimplemented()
+ # Reset for rest of loop
+ self.tearDown()
+ self.setUp()
+
+ def test_server_rejects_client_MSG_USERAUTH_SUCCESS(self):
+ self._send_client_message(MSG_USERAUTH_SUCCESS)
+ # Sanity checks
+ assert not self.ts.authenticated
+ assert not self.ts.auth_handler.authenticated
+ # Real fix's behavior
+ self._expect_unimplemented()