diff options
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 218 |
1 files changed, 147 insertions, 71 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py index 485a18e8..5069e5b0 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -20,22 +20,28 @@ Some unit tests for the ssh2 protocol in Transport. """ +from __future__ import with_statement + from binascii import hexlify import select import socket import time import threading import random +from hashlib import sha1 +import unittest from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ - SSHException, ChannelException + SSHException, ChannelException, Packetizer from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED -from paramiko.common import MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST +from paramiko.common import MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST, \ + MIN_PACKET_SIZE, MIN_WINDOW_SIZE, MAX_WINDOW_SIZE, \ + DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE from paramiko.py3compat import bytes from paramiko.message import Message from tests.loop import LoopSocket -from tests.util import ParamikoTest, test_path +from tests.util import test_path LONG_BANNER = """\ @@ -55,7 +61,7 @@ class NullServer (ServerInterface): paranoid_did_password = False paranoid_did_public_key = False paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key')) - + def get_allowed_auths(self, username): if username == 'slowdive': return 'publickey,password' @@ -72,30 +78,30 @@ class NullServer (ServerInterface): return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != b'yes': return False return True def check_channel_shell_request(self, channel): return True - + def check_global_request(self, kind, msg): self._global_request = kind return False - + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): self._x11_single_connection = single_connection self._x11_auth_protocol = auth_protocol self._x11_auth_cookie = auth_cookie self._x11_screen_number = screen_number return True - + def check_port_forward_request(self, addr, port): self._listen = socket.socket() self._listen.bind(('127.0.0.1', 0)) self._listen.listen(1) return self._listen.getsockname()[1] - + def cancel_port_forward_request(self, addr, port): self._listen.close() self._listen = None @@ -105,7 +111,7 @@ class NullServer (ServerInterface): return OPEN_SUCCEEDED -class TransportTest(ParamikoTest): +class TransportTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -123,20 +129,20 @@ class TransportTest(ParamikoTest): host_key = RSAKey.from_private_key_file(test_path('test_rsa.key')) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) - + if client_options is not None: client_options(self.tc.get_security_options()) if server_options is not None: server_options(self.ts.get_security_options()) - + event = threading.Event() self.server = NullServer() - self.assertTrue(not event.isSet()) + self.assertTrue(not event.is_set()) self.ts.start_server(event, self.server) self.tc.connect(hostkey=public_host_key, username='slowdive', password='pygmalion') event.wait(1.0) - self.assertTrue(event.isSet()) + self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) def test_1_security_options(self): @@ -155,7 +161,7 @@ class TransportTest(ParamikoTest): self.assertTrue(False) except TypeError: pass - + def test_2_compute_key(self): self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3' @@ -175,7 +181,7 @@ class TransportTest(ParamikoTest): self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() - self.assertTrue(not event.isSet()) + self.assertTrue(not event.is_set()) self.assertEqual(None, self.tc.get_username()) self.assertEqual(None, self.ts.get_username()) self.assertEqual(False, self.tc.is_authenticated()) @@ -184,7 +190,7 @@ class TransportTest(ParamikoTest): self.tc.connect(hostkey=public_host_key, username='slowdive', password='pygmalion') event.wait(1.0) - self.assertTrue(event.isSet()) + self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) self.assertEqual('slowdive', self.tc.get_username()) self.assertEqual('slowdive', self.ts.get_username()) @@ -200,15 +206,15 @@ class TransportTest(ParamikoTest): self.ts.add_server_key(host_key) event = threading.Event() server = NullServer() - self.assertTrue(not event.isSet()) + self.assertTrue(not event.is_set()) self.socks.send(LONG_BANNER) self.ts.start_server(event, server) self.tc.connect(hostkey=public_host_key, username='slowdive', password='pygmalion') event.wait(1.0) - self.assertTrue(event.isSet()) + self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) - + def test_4_special(self): """ verify that the client can demand odd handshake settings, and can @@ -222,7 +228,7 @@ class TransportTest(ParamikoTest): self.assertEqual('aes256-cbc', self.tc.remote_cipher) self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) - + self.tc.send_ignore(1024) self.tc.renegotiate_keys() self.ts.send_ignore(1024) @@ -236,7 +242,7 @@ class TransportTest(ParamikoTest): self.tc.set_keepalive(1) time.sleep(2) self.assertEqual('keepalive@lag.net', self.server._global_request) - + def test_6_exec_command(self): """ verify that exec_command() does something reasonable. @@ -246,11 +252,11 @@ class TransportTest(ParamikoTest): chan = self.tc.open_session() schan = self.ts.accept(1.0) try: - chan.exec_command('no') + chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string') self.assertTrue(False) except SSHException: pass - + chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) @@ -264,7 +270,7 @@ class TransportTest(ParamikoTest): f = chan.makefile_stderr() self.assertEqual('This is on stderr.\n', f.readline()) self.assertEqual('', f.readline()) - + # now try it with combined stdout/stderr chan = self.tc.open_session() chan.exec_command('yes') @@ -273,11 +279,27 @@ class TransportTest(ParamikoTest): schan.send_stderr('This is on stderr.\n') schan.close() - chan.set_combine_stderr(True) + chan.set_combine_stderr(True) f = chan.makefile() self.assertEqual('Hello there.\n', f.readline()) self.assertEqual('This is on stderr.\n', f.readline()) self.assertEqual('', f.readline()) + + def test_6a_channel_can_be_used_as_context_manager(self): + """ + verify that exec_command() does something reasonable. + """ + self.setup_test_server() + + with self.tc.open_session() as chan: + with self.ts.accept(1.0) as schan: + chan.exec_command('yes') + schan.send('Hello there.\n') + schan.close() + + f = chan.makefile() + self.assertEqual('Hello there.\n', f.readline()) + self.assertEqual('', f.readline()) def test_7_invoke_shell(self): """ @@ -320,7 +342,7 @@ class TransportTest(ParamikoTest): schan.shutdown_write() schan.send_exit_status(23) schan.close() - + f = chan.makefile() self.assertEqual('Hello there.\n', f.readline()) self.assertEqual('', f.readline()) @@ -342,14 +364,14 @@ class TransportTest(ParamikoTest): chan.invoke_shell() schan = self.ts.accept(1.0) - # nothing should be ready + # nothing should be ready r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) - + schan.send('hello\n') - + # something should be ready now (give it 1 second to appear) for i in range(10): r, w, e = select.select([chan], [], [], 0.1) @@ -361,7 +383,7 @@ class TransportTest(ParamikoTest): self.assertEqual([], e) self.assertEqual(b'hello\n', chan.recv(6)) - + # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) @@ -369,7 +391,7 @@ class TransportTest(ParamikoTest): self.assertEqual([], e) schan.close() - + # detect eof? for i in range(10): r, w, e = select.select([chan], [], [], 0.1) @@ -380,14 +402,14 @@ class TransportTest(ParamikoTest): self.assertEqual([], w) self.assertEqual([], e) self.assertEqual(bytes(), chan.recv(16)) - + # make sure the pipe is still open for now... p = chan._pipe self.assertEqual(False, p._closed) chan.close() # ...and now is closed. self.assertEqual(True, p._closed) - + def test_B_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. @@ -402,7 +424,7 @@ class TransportTest(ParamikoTest): for i in range(20): chan.send('x' * 1024) chan.close() - + # allow a few seconds for the rekeying to complete for i in range(50): if self.tc.H != self.tc.session_id: @@ -426,9 +448,11 @@ class TransportTest(ParamikoTest): bytes = self.tc.packetizer._Packetizer__sent_bytes chan.send('x' * 1024) bytes2 = self.tc.packetizer._Packetizer__sent_bytes + block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size'] + mac_size = self.tc._mac_info[self.tc.local_mac]['size'] # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) self.assertTrue(bytes2 - bytes < 1024) - self.assertEqual(52, bytes2 - bytes) + self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) chan.close() schan.close() @@ -441,28 +465,28 @@ class TransportTest(ParamikoTest): chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) - + requested = [] def handler(c, addr_port): addr, port = addr_port requested.append((addr, port)) self.tc._queue_incoming_channel(c) - + self.assertEqual(None, getattr(self.server, '_x11_screen_number', None)) cookie = chan.request_x11(0, single_connection=True, handler=handler) self.assertEqual(0, self.server._x11_screen_number) self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) self.assertEqual(cookie, self.server._x11_auth_cookie) self.assertEqual(True, self.server._x11_single_connection) - + x11_server = self.ts.open_x11_channel(('localhost', 6093)) x11_client = self.tc.accept() self.assertEqual('localhost', requested[0][0]) self.assertEqual(6093, requested[0][1]) - + x11_server.send('hello') self.assertEqual(b'hello', x11_client.recv(5)) - + x11_server.close() x11_client.close() chan.close() @@ -477,13 +501,13 @@ class TransportTest(ParamikoTest): chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) - + requested = [] def handler(c, origin_addr_port, server_addr_port): requested.append(origin_addr_port) requested.append(server_addr_port) self.tc._queue_incoming_channel(c) - + port = self.tc.request_port_forward('127.0.0.1', 0, handler) self.assertEqual(port, self.server._listen.getsockname()[1]) @@ -492,14 +516,14 @@ class TransportTest(ParamikoTest): ss, _ = self.server._listen.accept() sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) cch = self.tc.accept() - + sch.send('hello') self.assertEqual(b'hello', cch.recv(5)) sch.close() cch.close() ss.close() cs.close() - + # now cancel it. self.tc.cancel_port_forward('127.0.0.1', port) self.assertTrue(self.server._listen is None) @@ -513,7 +537,7 @@ class TransportTest(ParamikoTest): chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) - + # open a port on the "server" that the client will ask to forward to. greeting_server = socket.socket() greeting_server.bind(('127.0.0.1', 0)) @@ -524,13 +548,13 @@ class TransportTest(ParamikoTest): sch = self.ts.accept(1.0) cch = socket.socket() cch.connect(self.server._tcpip_dest) - + ss, _ = greeting_server.accept() ss.send(b'Hello!\n') ss.close() sch.send(cch.recv(8192)) sch.close() - + self.assertEqual(b'Hello!\n', cs.recv(7)) cs.close() @@ -544,14 +568,14 @@ class TransportTest(ParamikoTest): chan.invoke_shell() schan = self.ts.accept(1.0) - # nothing should be ready + # nothing should be ready r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) self.assertEqual([], w) self.assertEqual([], e) - + schan.send_stderr('hello\n') - + # something should be ready now (give it 1 second to appear) for i in range(10): r, w, e = select.select([chan], [], [], 0.1) @@ -563,7 +587,7 @@ class TransportTest(ParamikoTest): self.assertEqual([], e) self.assertEqual(b'hello\n', chan.recv_stderr(6)) - + # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) self.assertEqual([], r) @@ -585,12 +609,13 @@ class TransportTest(ParamikoTest): self.assertEqual(chan.send_ready(), True) total = 0 K = '*' * 1024 - while total < 1024 * 1024: + limit = 1+(64 * 2 ** 15) + while total < limit: chan.send(K) total += len(K) if not chan.send_ready(): break - self.assertTrue(total < 1024 * 1024) + self.assertTrue(total < limit) schan.close() chan.close() @@ -599,10 +624,10 @@ class TransportTest(ParamikoTest): def test_I_rekey_deadlock(self): """ Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent - + Note: When this test fails, it may leak threads. """ - + # Test for an obscure deadlocking bug that can occur if we receive # certain messages while initiating a key exchange. # @@ -619,7 +644,7 @@ class TransportTest(ParamikoTest): # NeedRekeyException. # 4. In response to NeedRekeyException, the transport thread sends # MSG_KEXINIT to the remote host. - # + # # On the remote host (using any SSH implementation): # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent. # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent. @@ -654,11 +679,11 @@ class TransportTest(ParamikoTest): self.done_event = done_event self.watchdog_event = threading.Event() self.last = None - + def run(self): try: for i in range(1, 1+self.iterations): - if self.done_event.isSet(): + if self.done_event.is_set(): break self.watchdog_event.set() #print i, "SEND" @@ -666,7 +691,7 @@ class TransportTest(ParamikoTest): finally: self.done_event.set() self.watchdog_event.set() - + class ReceiveThread(threading.Thread): def __init__(self, chan, done_event): threading.Thread.__init__(self, None, None, self.__class__.__name__) @@ -674,10 +699,10 @@ class TransportTest(ParamikoTest): self.chan = chan self.done_event = done_event self.watchdog_event = threading.Event() - + def run(self): try: - while not self.done_event.isSet(): + while not self.done_event.is_set(): if self.chan.recv_ready(): chan.recv(65536) self.watchdog_event.set() @@ -687,10 +712,10 @@ class TransportTest(ParamikoTest): finally: self.done_event.set() self.watchdog_event.set() - + self.setup_test_server() self.ts.packetizer.REKEY_BYTES = 2048 - + chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) @@ -712,7 +737,7 @@ class TransportTest(ParamikoTest): self._send_message(m2) return _negotiate_keys(self, m) self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper - + # Parameters for the test iterations = 500 # The deadlock does not happen every time, but it # should after many iterations. @@ -724,23 +749,23 @@ class TransportTest(ParamikoTest): # Start the sending thread st = SendThread(schan, iterations, done_event) st.start() - + # Start the receiving thread rt = ReceiveThread(chan, done_event) rt.start() - # Act as a watchdog timer, checking + # Act as a watchdog timer, checking deadlocked = False - while not deadlocked and not done_event.isSet(): + while not deadlocked and not done_event.is_set(): for event in (st.watchdog_event, rt.watchdog_event): event.wait(timeout) - if done_event.isSet(): + if done_event.is_set(): break - if not event.isSet(): + if not event.is_set(): deadlocked = True break event.clear() - + # Tell the threads to stop (if they haven't already stopped). Note # that if one or more threads are deadlocked, they might hang around # forever (until the process exits). @@ -752,3 +777,54 @@ class TransportTest(ParamikoTest): # Close the channels schan.close() chan.close() + + def test_J_sanitze_packet_size(self): + """ + verify that we conform to the rfc of packet and window sizes. + """ + for val, correct in [(4095, MIN_PACKET_SIZE), + (None, DEFAULT_MAX_PACKET_SIZE), + (2**32, MAX_WINDOW_SIZE)]: + self.assertEqual(self.tc._sanitize_packet_size(val), correct) + + def test_K_sanitze_window_size(self): + """ + verify that we conform to the rfc of packet and window sizes. + """ + for val, correct in [(32767, MIN_WINDOW_SIZE), + (None, DEFAULT_WINDOW_SIZE), + (2**32, MAX_WINDOW_SIZE)]: + self.assertEqual(self.tc._sanitize_window_size(val), correct) + + def test_L_handshake_timeout(self): + """ + verify that we can get a hanshake timeout. + """ + # Tweak client Transport instance's Packetizer instance so + # its read_message() sleeps a bit. This helps prevent race conditions + # where the client Transport's timeout timer thread doesn't even have + # time to get scheduled before the main client thread finishes + # handshaking with the server. + # (Doing this on the server's transport *sounds* more 'correct' but + # actually doesn't work nearly as well for whatever reason.) + class SlowPacketizer(Packetizer): + def read_message(self): + time.sleep(1) + return super(SlowPacketizer, self).read_message() + # NOTE: prettttty sure since the replaced .packetizer Packetizer is now + # no longer doing anything with its copy of the socket...everything'll + # be fine. Even tho it's a bit squicky. + self.tc.packetizer = SlowPacketizer(self.tc.sock) + # Continue with regular test red tape. + host_key = RSAKey.from_private_key_file(test_path('test_rsa.key')) + public_host_key = RSAKey(data=host_key.asbytes()) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assertTrue(not event.is_set()) + self.tc.handshake_timeout = 0.000000000001 + self.ts.start_server(event, server) + self.assertRaises(EOFError, self.tc.connect, + hostkey=public_host_key, + username='slowdive', + password='pygmalion') |