diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 12 | ||||
-rw-r--r-- | tests/loop.py | 8 | ||||
-rw-r--r-- | tests/stub_sftp.py | 58 | ||||
-rw-r--r-- | tests/test_auth.py | 118 | ||||
-rw-r--r-- | tests/test_buffered_pipe.py | 27 | ||||
-rw-r--r-- | tests/test_client.py | 253 | ||||
-rw-r--r-- | tests/test_file.py | 139 | ||||
-rw-r--r-- | tests/test_gssapi.py | 44 | ||||
-rw-r--r-- | tests/test_hostkeys.py | 79 | ||||
-rw-r--r-- | tests/test_kex.py | 314 | ||||
-rw-r--r-- | tests/test_kex_gss.py | 48 | ||||
-rw-r--r-- | tests/test_message.py | 39 | ||||
-rw-r--r-- | tests/test_packetizer.py | 34 | ||||
-rw-r--r-- | tests/test_pkey.py | 257 | ||||
-rw-r--r-- | tests/test_sftp.py | 495 | ||||
-rw-r--r-- | tests/test_sftp_big.py | 222 | ||||
-rw-r--r-- | tests/test_ssh_exception.py | 21 | ||||
-rw-r--r-- | tests/test_ssh_gss.py | 61 | ||||
-rw-r--r-- | tests/test_transport.py | 323 | ||||
-rw-r--r-- | tests/test_util.py | 375 |
20 files changed, 1667 insertions, 1260 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index d1967a73..2b509c5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ from .util import _support # presenting it on error/failure. (But also allow turning it off when doing # very pinpoint debugging - e.g. using breakpoints, so you don't want output # hiding enabled, but also don't want all the logging to gum up the terminal.) -if not os.environ.get('DISABLE_LOGGING', False): +if not os.environ.get("DISABLE_LOGGING", False): logging.basicConfig( level=logging.DEBUG, # Also make sure to set up timestamping for more sanity when debugging. @@ -43,7 +43,7 @@ def make_sftp_folder(): # TODO: if we want to lock ourselves even harder into localhost-only # testing (probably not?) could use tempdir modules for this for improved # safety. Then again...why would someone have such a folder??? - path = os.environ.get('TEST_FOLDER', 'paramiko-test-target') + path = os.environ.get("TEST_FOLDER", "paramiko-test-target") # Forcibly nuke this directory locally, since at the moment, the below # fixtures only ever run with a locally scoped stub test server. shutil.rmtree(path, ignore_errors=True) @@ -52,7 +52,7 @@ def make_sftp_folder(): return path -@pytest.fixture#(scope='session') +@pytest.fixture # (scope='session') def sftp_server(): """ Set up an in-memory SFTP server thread. Yields the client Transport/socket. @@ -69,17 +69,17 @@ def sftp_server(): tc = Transport(sockc) ts = Transport(socks) # Auth - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) ts.add_server_key(host_key) # Server setup event = threading.Event() server = StubServer() - ts.set_subsystem_handler('sftp', SFTPServer, StubSFTPServer) + ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer) ts.start_server(event, server) # Wait (so client has time to connect? Not sure. Old.) event.wait(1.0) # Make & yield connection. - tc.connect(username='slowdive', password='pygmalion') + tc.connect(username="slowdive", password="pygmalion") yield tc # TODO: any need for shutdown? Why didn't old suite do so? Or was that the # point of the "join all threads from threading module" crap in test.py? diff --git a/tests/loop.py b/tests/loop.py index 6c432867..dd1f5a0c 100644 --- a/tests/loop.py +++ b/tests/loop.py @@ -22,13 +22,13 @@ import threading from paramiko.common import asbytes -class LoopSocket (object): +class LoopSocket(object): """ A LoopSocket looks like a normal socket, but all data written to it is delivered on the read-end of another LoopSocket, and vice versa. It's like a software "socketpair". """ - + def __init__(self): self.__in_buffer = bytes() self.__lock = threading.Lock() @@ -84,7 +84,7 @@ class LoopSocket (object): self.__cv.notifyAll() finally: self.__lock.release() - + def __unlink(self): m = None self.__lock.acquire() @@ -96,5 +96,3 @@ class LoopSocket (object): self.__lock.release() if m is not None: m.__unlink() - - diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py index 19545865..ffae635d 100644 --- a/tests/stub_sftp.py +++ b/tests/stub_sftp.py @@ -24,13 +24,21 @@ import os import sys from paramiko import ( - ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, - SFTPHandle, SFTP_OK, SFTP_FAILURE, AUTH_SUCCESSFUL, OPEN_SUCCEEDED, + ServerInterface, + SFTPServerInterface, + SFTPServer, + SFTPAttributes, + SFTPHandle, + SFTP_OK, + SFTP_FAILURE, + AUTH_SUCCESSFUL, + OPEN_SUCCEEDED, ) from paramiko.common import o666 -class StubServer (ServerInterface): +class StubServer(ServerInterface): + def check_auth_password(self, username, password): # all are allowed return AUTH_SUCCESSFUL @@ -39,7 +47,8 @@ class StubServer (ServerInterface): return OPEN_SUCCEEDED -class StubSFTPHandle (SFTPHandle): +class StubSFTPHandle(SFTPHandle): + def stat(self): try: return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno())) @@ -56,11 +65,11 @@ class StubSFTPHandle (SFTPHandle): return SFTPServer.convert_errno(e.errno) -class StubSFTPServer (SFTPServerInterface): +class StubSFTPServer(SFTPServerInterface): # assume current folder is a fine root # (the tests always create and eventually delete a subfolder, so there shouldn't be any mess) ROOT = os.getcwd() - + def _realpath(self, path): return self.ROOT + self.canonicalize(path) @@ -70,7 +79,9 @@ class StubSFTPServer (SFTPServerInterface): out = [] flist = os.listdir(path) for fname in flist: - attr = SFTPAttributes.from_stat(os.stat(os.path.join(path, fname))) + attr = SFTPAttributes.from_stat( + os.stat(os.path.join(path, fname)) + ) attr.filename = fname out.append(attr) return out @@ -94,9 +105,9 @@ class StubSFTPServer (SFTPServerInterface): def open(self, path, flags, attr): path = self._realpath(path) try: - binary_flag = getattr(os, 'O_BINARY', 0) + binary_flag = getattr(os, "O_BINARY", 0) flags |= binary_flag - mode = getattr(attr, 'st_mode', None) + mode = getattr(attr, "st_mode", None) if mode is not None: fd = os.open(path, flags, mode) else: @@ -110,17 +121,17 @@ class StubSFTPServer (SFTPServerInterface): SFTPServer.set_file_attr(path, attr) if flags & os.O_WRONLY: if flags & os.O_APPEND: - fstr = 'ab' + fstr = "ab" else: - fstr = 'wb' + fstr = "wb" elif flags & os.O_RDWR: if flags & os.O_APPEND: - fstr = 'a+b' + fstr = "a+b" else: - fstr = 'r+b' + fstr = "r+b" else: # O_RDONLY (== 0) - fstr = 'rb' + fstr = "rb" try: f = os.fdopen(fd, fstr) except OSError as e: @@ -159,7 +170,6 @@ class StubSFTPServer (SFTPServerInterface): return SFTPServer.convert_errno(e.errno) return SFTP_OK - def mkdir(self, path, attr): path = self._realpath(path) try: @@ -188,18 +198,18 @@ class StubSFTPServer (SFTPServerInterface): def symlink(self, target_path, path): path = self._realpath(path) - if (len(target_path) > 0) and (target_path[0] == '/'): + if (len(target_path) > 0) and (target_path[0] == "/"): # absolute symlink target_path = os.path.join(self.ROOT, target_path[1:]) - if target_path[:2] == '//': + if target_path[:2] == "//": # bug in os.path.join target_path = target_path[1:] else: # compute relative to path abspath = os.path.join(os.path.dirname(path), target_path) - if abspath[:len(self.ROOT)] != self.ROOT: + if abspath[: len(self.ROOT)] != self.ROOT: # this symlink isn't going to work anyway -- just break it immediately - target_path = '<error>' + target_path = "<error>" try: os.symlink(target_path, path) except OSError as e: @@ -214,10 +224,10 @@ class StubSFTPServer (SFTPServerInterface): return SFTPServer.convert_errno(e.errno) # if it's absolute, remove the root if os.path.isabs(symlink): - if symlink[:len(self.ROOT)] == self.ROOT: - symlink = symlink[len(self.ROOT):] - if (len(symlink) == 0) or (symlink[0] != '/'): - symlink = '/' + symlink + if symlink[: len(self.ROOT)] == self.ROOT: + symlink = symlink[len(self.ROOT) :] + if (len(symlink) == 0) or (symlink[0] != "/"): + symlink = "/" + symlink else: - symlink = '<error>' + symlink = "<error>" return symlink diff --git a/tests/test_auth.py b/tests/test_auth.py index dacdd654..acabb1bd 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -26,8 +26,13 @@ import unittest from time import sleep from paramiko import ( - Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType, - InteractiveQuery, AuthenticationException, + Transport, + ServerInterface, + RSAKey, + DSSKey, + BadAuthenticationType, + InteractiveQuery, + AuthenticationException, ) from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko.py3compat import u @@ -36,54 +41,57 @@ from .loop import LoopSocket from .util import _support, slow -_pwd = u('\u2022') +_pwd = u("\u2022") -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - if username == 'paranoid': - if not self.paranoid_did_password and not self.paranoid_did_public_key: - return 'publickey,password' + if username == "slowdive": + return "publickey,password" + if username == "paranoid": + if ( + not self.paranoid_did_password + and not self.paranoid_did_public_key + ): + return "publickey,password" elif self.paranoid_did_password: - return 'publickey' + return "publickey" else: - return 'password' - if username == 'commie': - return 'keyboard-interactive' - if username == 'utf8': - return 'password' - if username == 'non-utf8': - return 'password' - return 'publickey' + return "password" + if username == "commie": + return "keyboard-interactive" + if username == "utf8": + return "password" + if username == "non-utf8": + return "password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL - if (username == 'paranoid') and (password == 'paranoid'): + if (username == "paranoid") and (password == "paranoid"): # 2-part auth (even openssh doesn't support this) self.paranoid_did_password = True if self.paranoid_did_public_key: return AUTH_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL - if (username == 'utf8') and (password == _pwd): + if (username == "utf8") and (password == _pwd): return AUTH_SUCCESSFUL - if (username == 'non-utf8') and (password == '\xff'): + if (username == "non-utf8") and (password == "\xff"): return AUTH_SUCCESSFUL - if username == 'bad-server': + if username == "bad-server": raise Exception("Ack!") - if username == 'unresponsive-server': + if username == "unresponsive-server": sleep(5) return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): - if (username == 'paranoid') and (key == self.paranoid_key): + if (username == "paranoid") and (key == self.paranoid_key): # 2-part auth self.paranoid_did_public_key = True if self.paranoid_did_password: @@ -92,19 +100,21 @@ class NullServer (ServerInterface): return AUTH_FAILED def check_auth_interactive(self, username, submethods): - if username == 'commie': + if username == "commie": self.username = username - return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return InteractiveQuery( + "password", "Please enter a password.", ("Password", False) + ) return AUTH_FAILED def check_auth_interactive_response(self, responses): - if self.username == 'commie': - if (len(responses) == 1) and (responses[0] == 'cat'): + if self.username == "commie": + if (len(responses) == 1) and (responses[0] == "cat"): return AUTH_SUCCESSFUL return AUTH_FAILED -class AuthTest (unittest.TestCase): +class AuthTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() @@ -120,7 +130,7 @@ class AuthTest (unittest.TestCase): self.sockc.close() def start_server(self): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) self.event = threading.Event() @@ -140,13 +150,16 @@ class AuthTest (unittest.TestCase): """ self.start_server() try: - self.tc.connect(hostkey=self.public_host_key, - username='unknown', password='error') + self.tc.connect( + hostkey=self.public_host_key, + username="unknown", + password="error", + ) self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertEqual(BadAuthenticationType, etype) - self.assertEqual(['publickey'], evalue.allowed_types) + self.assertEqual(["publickey"], evalue.allowed_types) def test_bad_password(self): """ @@ -156,12 +169,12 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - self.tc.auth_password(username='slowdive', password='error') + self.tc.auth_password(username="slowdive", password="error") self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.tc.auth_password(username='slowdive', password='pygmalion') + self.tc.auth_password(username="slowdive", password="pygmalion") self.verify_finished() def test_multipart_auth(self): @@ -170,10 +183,12 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEqual(['publickey'], remain) - key = DSSKey.from_private_key_file(_support('test_dss.key')) - remain = self.tc.auth_publickey(username='paranoid', key=key) + remain = self.tc.auth_password( + username="paranoid", password="paranoid" + ) + self.assertEqual(["publickey"], remain) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + remain = self.tc.auth_publickey(username="paranoid", key=key) self.assertEqual([], remain) self.verify_finished() @@ -188,10 +203,11 @@ class AuthTest (unittest.TestCase): self.got_title = title self.got_instructions = instructions self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEqual(self.got_title, 'password') - self.assertEqual(self.got_prompts, [('Password', False)]) + return ["cat"] + + remain = self.tc.auth_interactive("commie", handler) + self.assertEqual(self.got_title, "password") + self.assertEqual(self.got_prompts, [("Password", False)]) self.assertEqual([], remain) self.verify_finished() @@ -202,7 +218,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('commie', 'cat') + remain = self.tc.auth_password("commie", "cat") self.assertEqual([], remain) self.verify_finished() @@ -212,7 +228,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('utf8', _pwd) + remain = self.tc.auth_password("utf8", _pwd) self.assertEqual([], remain) self.verify_finished() @@ -223,7 +239,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('non-utf8', '\xff') + remain = self.tc.auth_password("non-utf8", "\xff") self.assertEqual([], remain) self.verify_finished() @@ -235,7 +251,7 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - remain = self.tc.auth_password('bad-server', 'hello') + remain = self.tc.auth_password("bad-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) @@ -250,8 +266,8 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect() try: - remain = self.tc.auth_password('unresponsive-server', 'hello') + remain = self.tc.auth_password("unresponsive-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.assertTrue('Authentication timeout' in str(evalue)) + self.assertTrue("Authentication timeout" in str(evalue)) diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py index 03616c55..9f986a5e 100644 --- a/tests/test_buffered_pipe.py +++ b/tests/test_buffered_pipe.py @@ -30,9 +30,9 @@ from paramiko.py3compat import b def delay_thread(p): - p.feed('a') + p.feed("a") time.sleep(0.5) - p.feed('b') + p.feed("b") p.close() @@ -42,41 +42,42 @@ def close_thread(p): class BufferedPipeTest(unittest.TestCase): + def test_1_buffered_pipe(self): p = BufferedPipe() self.assertTrue(not p.read_ready()) - p.feed('hello.') + p.feed("hello.") self.assertTrue(p.read_ready()) data = p.read(6) - self.assertEqual(b'hello.', data) + self.assertEqual(b"hello.", data) - p.feed('plus/minus') - self.assertEqual(b'plu', p.read(3)) - self.assertEqual(b's/m', p.read(3)) - self.assertEqual(b'inus', p.read(4)) + p.feed("plus/minus") + self.assertEqual(b"plu", p.read(3)) + self.assertEqual(b"s/m", p.read(3)) + self.assertEqual(b"inus", p.read(4)) p.close() self.assertTrue(not p.read_ready()) - self.assertEqual(b'', p.read(1)) + self.assertEqual(b"", p.read(1)) def test_2_delay(self): p = BufferedPipe() self.assertTrue(not p.read_ready()) threading.Thread(target=delay_thread, args=(p,)).start() - self.assertEqual(b'a', p.read(1, 0.1)) + self.assertEqual(b"a", p.read(1, 0.1)) try: p.read(1, 0.1) self.assertTrue(False) except PipeTimeout: pass - self.assertEqual(b'b', p.read(1, 1.0)) - self.assertEqual(b'', p.read(1)) + self.assertEqual(b"b", p.read(1, 1.0)) + self.assertEqual(b"", p.read(1)) def test_3_close_while_reading(self): p = BufferedPipe() threading.Thread(target=close_thread, args=(p,)).start() data = p.read(1, 1.0) - self.assertEqual(b'', data) + self.assertEqual(b"", data) def test_4_or_pipe(self): p = pipe.make_pipe() diff --git a/tests/test_client.py b/tests/test_client.py index 7163fdcf..4943df29 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,30 +48,31 @@ requires_gss_auth = unittest.skipUnless( ) FINGERPRINTS = { - 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c', - 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5', - 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60', - 'ssh-ed25519': b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', + "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c", + "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5", + "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60", + "ssh-ed25519": b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', } class NullServer(paramiko.ServerInterface): + def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types - self.__allowed_keys = kwargs.pop('allowed_keys', []) + self.__allowed_keys = kwargs.pop("allowed_keys", []) # And allow them to set a (single...meh) expected public blob (cert) - self.__expected_public_blob = kwargs.pop('public_blob', None) + self.__expected_public_blob = kwargs.pop("public_blob", None) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return paramiko.AUTH_SUCCESSFUL - if (username == 'slowdive') and (password == 'unresponsive-server'): + if (username == "slowdive") and (password == "unresponsive-server"): time.sleep(5) return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -83,13 +84,13 @@ class NullServer(paramiko.ServerInterface): return paramiko.AUTH_FAILED # Base check: allowed auth type & fingerprint matches happy = ( - key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected ) # Secondary check: if test wants assertions about cert data if ( - self.__expected_public_blob is not None and - key.public_blob != self.__expected_public_blob + self.__expected_public_blob is not None + and key.public_blob != self.__expected_public_blob ): happy = False return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED @@ -98,31 +99,32 @@ class NullServer(paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True def check_channel_env_request(self, channel, name, value): - if name == 'INVALID_ENV': + if name == "INVALID_ENV": return False - if not hasattr(channel, 'env'): - setattr(channel, 'env', {}) + if not hasattr(channel, "env"): + setattr(channel, "env", {}) channel.env[name] = value return True class ClientTest(unittest.TestCase): + def setUp(self): self.sockl = socket.socket() - self.sockl.bind(('localhost', 0)) + self.sockl.bind(("localhost", 0)) self.sockl.listen(1) self.addr, self.port = self.sockl.getsockname() self.connect_kwargs = dict( hostname=self.addr, port=self.port, - username='slowdive', + username="slowdive", look_for_keys=False, ) self.event = threading.Event() @@ -130,10 +132,10 @@ class ClientTest(unittest.TestCase): def tearDown(self): # Shut down client Transport - if hasattr(self, 'tc'): + if hasattr(self, "tc"): self.tc.close() # Shut down shared socket - if hasattr(self, 'sockl'): + if hasattr(self, "sockl"): # Signal to server thread that it should shut down early; it checks # this immediately after accept(). (In scenarios where connection # actually succeeded during the test, this becomes a no-op.) @@ -151,7 +153,7 @@ class ClientTest(unittest.TestCase): self.sockl.close() def _run( - self, allowed_keys=None, delay=0, public_blob=None, kill_event=None, + self, allowed_keys=None, delay=0, public_blob=None, kill_event=None ): if allowed_keys is None: allowed_keys = FINGERPRINTS.keys() @@ -163,10 +165,10 @@ class ClientTest(unittest.TestCase): self.socks.close() return self.ts = paramiko.Transport(self.socks) - keypath = _support('test_rsa.key') + keypath = _support("test_rsa.key") host_key = paramiko.RSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) - keypath = _support('test_ecdsa_256.key') + keypath = _support("test_ecdsa_256.key") host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob) @@ -181,17 +183,21 @@ class ClientTest(unittest.TestCase): The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ - run_kwargs = {'kill_event': self.kill_event} - for key in ('allowed_keys', 'public_blob'): + run_kwargs = {"kill_event": self.kill_event} + for key in ("allowed_keys", "public_blob"): run_kwargs[key] = kwargs.pop(key, None) # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) # Client setup self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Actual connection self.tc.connect(**dict(self.connect_kwargs, **kwargs)) @@ -200,22 +206,22 @@ class ClientTest(unittest.TestCase): self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(False, self.tc.get_transport().gss_kex_used) # Command execution functions? - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEqual('Hello there.\n', stdout.readline()) - self.assertEqual('', stdout.readline()) - self.assertEqual('This is on stderr.\n', stderr.readline()) - self.assertEqual('', stderr.readline()) + self.assertEqual("Hello there.\n", stdout.readline()) + self.assertEqual("", stdout.readline()) + self.assertEqual("This is on stderr.\n", stderr.readline()) + self.assertEqual("", stderr.readline()) # Cleanup stdin.close() @@ -224,32 +230,33 @@ class ClientTest(unittest.TestCase): class SSHClientTest(ClientTest): + def test_1_client(self): """ verify that the SSHClient stuff works too. """ - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ - self._test_connection(key_filename=_support('test_dss.key')) + self._test_connection(key_filename=_support("test_dss.key")) def test_client_rsa(self): """ verify that SSHClient works with an RSA key. """ - self._test_connection(key_filename=_support('test_rsa.key')) + self._test_connection(key_filename=_support("test_rsa.key")) def test_2_5_client_ecdsa(self): """ verify that SSHClient works with an ECDSA key. """ - self._test_connection(key_filename=_support('test_ecdsa_256.key')) + self._test_connection(key_filename=_support("test_ecdsa_256.key")) def test_client_ed25519(self): - self._test_connection(key_filename=_support('test_ed25519.key')) + self._test_connection(key_filename=_support("test_ed25519.key")) def test_3_multiple_key_files(self): """ @@ -257,22 +264,22 @@ class SSHClientTest(ClientTest): """ # This is dumb :( types_ = { - 'rsa': 'ssh-rsa', - 'dss': 'ssh-dss', - 'ecdsa': 'ecdsa-sha2-nistp256', + "rsa": "ssh-rsa", + "dss": "ssh-dss", + "ecdsa": "ecdsa-sha2-nistp256", } # Various combos of attempted & valid keys # TODO: try every possible combo using itertools functions for attempt, accept in ( - (['rsa', 'dss'], ['dss']), # Original test #3 - (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly - (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail - (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success + (["rsa", "dss"], ["dss"]), # Original test #3 + (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly + (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail + (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success ): try: self._test_connection( key_filename=[ - _support('test_{}.key'.format(x)) for x in attempt + _support("test_{}.key".format(x)) for x in attempt ], allowed_keys=[types_[x] for x in accept], ) @@ -288,10 +295,11 @@ class SSHClientTest(ClientTest): """ # Until #387 is fixed we have to catch a high-up exception since # various platforms trigger different errors here >_< - self.assertRaises(SSHException, + self.assertRaises( + SSHException, self._test_connection, - key_filename=[_support('test_rsa.key')], - allowed_keys=['ecdsa-sha2-nistp256'], + key_filename=[_support("test_rsa.key")], + allowed_keys=["ecdsa-sha2-nistp256"], ) def test_certs_allowed_as_key_filename_values(self): @@ -299,9 +307,9 @@ class SSHClientTest(ClientTest): # They're similar except for which path is given; the expected auth and # server-side behavior is 100% identical.) # NOTE: only bothered whipping up one cert per overall class/family. - for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): - cert_name = 'test_{}.key-cert.pub'.format(type_) - cert_path = _support(os.path.join('cert_support', cert_name)) + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + cert_name = "test_{}.key-cert.pub".format(type_) + cert_path = _support(os.path.join("cert_support", cert_name)) self._test_connection( key_filename=cert_path, public_blob=PublicBlob.from_file(cert_path), @@ -314,13 +322,13 @@ class SSHClientTest(ClientTest): # about the server-side key object's public blob. Thus, we can prove # that a specific cert was found, along with regular authorization # succeeding proving that the overall flow works. - for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): - key_name = 'test_{}.key'.format(type_) - key_path = _support(os.path.join('cert_support', key_name)) + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + key_name = "test_{}.key".format(type_) + key_path = _support(os.path.join("cert_support", key_name)) self._test_connection( key_filename=key_path, public_blob=PublicBlob.from_file( - '{}-cert.pub'.format(key_path) + "{}-cert.pub".format(key_path) ), ) @@ -335,19 +343,19 @@ class SSHClientTest(ClientTest): verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) - key_file = _support('test_ecdsa_256.key') + hostname = "[%s]:%d" % (self.addr, self.port) + key_file = _support("test_ecdsa_256.key") public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] @@ -357,9 +365,11 @@ class SSHClientTest(ClientTest): """ verify that SSHClient correctly saves a known_hosts file. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) fd, localname = mkstemp() os.close(fd) @@ -367,11 +377,13 @@ class SSHClientTest(ClientTest): client = paramiko.SSHClient() self.assertEquals(0, len(client.get_host_keys())) - host_id = '[%s]:%d' % (self.addr, self.port) + host_id = "[%s]:%d" % (self.addr, self.port) - client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key) + client.get_host_keys().add(host_id, "ssh-rsa", public_host_key) self.assertEquals(1, len(client.get_host_keys())) - self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa']) + self.assertEquals( + public_host_key, client.get_host_keys()[host_id]["ssh-rsa"] + ) client.save_host_keys(localname) @@ -394,7 +406,7 @@ class SSHClientTest(ClientTest): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -423,7 +435,7 @@ class SSHClientTest(ClientTest): self.tc = tc self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -438,19 +450,19 @@ class SSHClientTest(ClientTest): verify that the SSHClient has a configurable banner timeout. """ # Start the thread with a 1 second wait. - threading.Thread(target=self._run, kwargs={'delay': 1}).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + threading.Thread(target=self._run, kwargs={"delay": 1}).start() + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Connect with a half second banner timeout. kwargs = dict(self.connect_kwargs, banner_timeout=0.5) - self.assertRaises( - paramiko.SSHException, - self.tc.connect, - **kwargs - ) + self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs) def test_8_auth_trickledown(self): """ @@ -466,9 +478,9 @@ class SSHClientTest(ClientTest): # 'television' as per tests/test_pkey.py). NOTE: must use # key_filename, loading the actual key here with PKey will except # immediately; we're testing the try/except crap within Client. - key_filename=[_support('test_rsa_password.key')], + key_filename=[_support("test_rsa_password.key")], # Actual password for default 'slowdive' user - password='pygmalion', + password="pygmalion", ) self._test_connection(**kwargs) @@ -481,7 +493,7 @@ class SSHClientTest(ClientTest): self.assertRaises( AuthenticationException, self._test_connection, - password='unresponsive-server', + password="unresponsive-server", auth_timeout=0.5, ) @@ -490,10 +502,7 @@ class SSHClientTest(ClientTest): """ Failed gssapi-keyex auth doesn't prevent subsequent key auth from succeeding """ - kwargs = dict( - gss_kex=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) @requires_gss_auth @@ -501,10 +510,7 @@ class SSHClientTest(ClientTest): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ - kwargs = dict( - gss_auth=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_12_reject_policy(self): @@ -519,7 +525,8 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', **self.connect_kwargs + password="pygmalion", + **self.connect_kwargs ) @requires_gss_auth @@ -537,14 +544,14 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', + password="pygmalion", gss_kex=True, - **self.connect_kwargs + **self.connect_kwargs ) def _client_host_key_bad(self, host_key): threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) + hostname = "[%s]:%d" % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) @@ -554,13 +561,13 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.BadHostKeyException, self.tc.connect, - password='pygmalion', + password="pygmalion", **self.connect_kwargs ) def _client_host_key_good(self, ktype, kfile): threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) + hostname = "[%s]:%d" % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) @@ -568,7 +575,7 @@ class SSHClientTest(ClientTest): known_hosts = self.tc.get_host_keys() known_hosts.add(hostname, host_key.get_name(), host_key) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) @@ -583,10 +590,10 @@ class SSHClientTest(ClientTest): self._client_host_key_bad(host_key) def test_host_key_negotiation_3(self): - self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key') + self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key") def test_host_key_negotiation_4(self): - self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key') + self._client_host_key_good(paramiko.RSAKey, "test_rsa.key") def _setup_for_env(self): threading.Thread(target=self._run).start() @@ -594,7 +601,9 @@ class SSHClientTest(ClientTest): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + self.tc.connect( + self.addr, self.port, username="slowdive", password="pygmalion" + ) self.event.wait(1.0) self.assertTrue(self.event.isSet()) @@ -605,11 +614,11 @@ class SSHClientTest(ClientTest): Verify that environment variables can be set by the client. """ self._setup_for_env() - target_env = {b'A': b'B', b'C': b'd'} + target_env = {b"A": b"B", b"C": b"d"} - self.tc.exec_command('yes', environment=target_env) + self.tc.exec_command("yes", environment=target_env) schan = self.ts.accept(1.0) - self.assertEqual(target_env, getattr(schan, 'env', {})) + self.assertEqual(target_env, getattr(schan, "env", {})) schan.close() @unittest.skip("Clients normally fail silently, thus so do we, for now") @@ -617,14 +626,14 @@ class SSHClientTest(ClientTest): self._setup_for_env() with self.assertRaises(SSHException) as manager: # Verify that a rejection by the server can be detected - self.tc.exec_command('yes', environment={b'INVALID_ENV': b''}) + self.tc.exec_command("yes", environment={b"INVALID_ENV": b""}) self.assertTrue( - 'INVALID_ENV' in str(manager.exception), - 'Expected variable name in error message' + "INVALID_ENV" in str(manager.exception), + "Expected variable name in error message", ) self.assertTrue( isinstance(manager.exception.args[1], SSHException), - 'Expected original SSHException in exception' + "Expected original SSHException in exception", ) def test_missing_key_policy_accepts_classes_or_instances(self): @@ -652,35 +661,39 @@ class PasswordPassphraseTests(ClientTest): def test_password_kwarg_works_for_password_auth(self): # Straightforward / duplicate of earlier basic password test. - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") # TODO: more granular exception pending #387; should be signaling "no auth # methods available" because no key and no password @raises(SSHException) def test_passphrase_kwarg_not_used_for_password_auth(self): # Using the "right" password in the "wrong" field shouldn't work. - self._test_connection(passphrase='pygmalion') + self._test_connection(passphrase="pygmalion") def test_passphrase_kwarg_used_for_key_passphrase(self): # Straightforward again, with new passphrase kwarg. self._test_connection( - key_filename=_support('test_rsa_password.key'), - passphrase='television', + key_filename=_support("test_rsa_password.key"), + passphrase="television", ) - def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(self): # noqa + def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given( + self + ): # noqa # Backwards compatibility: passphrase in the password field. self._test_connection( - key_filename=_support('test_rsa_password.key'), - password='television', + key_filename=_support("test_rsa_password.key"), + password="television", ) - @raises(AuthenticationException) # TODO: more granular - def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(self): # noqa + @raises(AuthenticationException) # TODO: more granular + def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given( + self + ): # noqa # Sanity: if we're given both fields, the password field is NOT used as # a passphrase. self._test_connection( - key_filename=_support('test_rsa_password.key'), - password='television', - passphrase='wat? lol no', + key_filename=_support("test_rsa_password.key"), + password="television", + passphrase="wat? lol no", ) diff --git a/tests/test_file.py b/tests/test_file.py index 3d2c94e6..deacd60a 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -30,18 +30,19 @@ from paramiko.py3compat import BytesIO from .util import needs_builtin -class LoopbackFile (BufferedFile): +class LoopbackFile(BufferedFile): """ BufferedFile object that you can write data into, and then read it back. """ - def __init__(self, mode='r', bufsize=-1): + + def __init__(self, mode="r", bufsize=-1): BufferedFile.__init__(self) self._set_mode(mode, bufsize) self.buffer = BytesIO() self.offset = 0 def _read(self, size): - data = self.buffer.getvalue()[self.offset:self.offset+size] + data = self.buffer.getvalue()[self.offset : self.offset + size] self.offset += len(data) return data @@ -50,44 +51,46 @@ class LoopbackFile (BufferedFile): return len(data) -class BufferedFileTest (unittest.TestCase): +class BufferedFileTest(unittest.TestCase): def test_1_simple(self): - f = LoopbackFile('r') + f = LoopbackFile("r") try: - f.write(b'hi') - self.assertTrue(False, 'no exception on write to read-only file') + f.write(b"hi") + self.assertTrue(False, "no exception on write to read-only file") except: pass f.close() - f = LoopbackFile('w') + f = LoopbackFile("w") try: f.read(1) - self.assertTrue(False, 'no exception to read from write-only file') + self.assertTrue(False, "no exception to read from write-only file") except: pass f.close() def test_2_readline(self): - f = LoopbackFile('r+U') - f.write(b'First line.\nSecond line.\r\nThird line.\n' + - b'Fourth line.\nFinal line non-terminated.') + f = LoopbackFile("r+U") + f.write( + b"First line.\nSecond line.\r\nThird line.\n" + + b"Fourth line.\nFinal line non-terminated." + ) - self.assertEqual(f.readline(), 'First line.\n') + self.assertEqual(f.readline(), "First line.\n") # universal newline mode should convert this linefeed: - self.assertEqual(f.readline(), 'Second line.\n') + self.assertEqual(f.readline(), "Second line.\n") # truncated line: - self.assertEqual(f.readline(7), 'Third l') - self.assertEqual(f.readline(), 'ine.\n') + self.assertEqual(f.readline(7), "Third l") + self.assertEqual(f.readline(), "ine.\n") # newline should be detected and only the fourth line returned - self.assertEqual(f.readline(39), 'Fourth line.\n') - self.assertEqual(f.readline(), 'Final line non-terminated.') - self.assertEqual(f.readline(), '') + self.assertEqual(f.readline(39), "Fourth line.\n") + self.assertEqual(f.readline(), "Final line non-terminated.") + self.assertEqual(f.readline(), "") f.close() try: f.readline() - self.assertTrue(False, 'no exception on readline of closed file') + self.assertTrue(False, "no exception on readline of closed file") except IOError: pass self.assertTrue(linefeed_byte in f.newlines) @@ -98,11 +101,11 @@ class BufferedFileTest (unittest.TestCase): """ try to trick the linefeed detector. """ - f = LoopbackFile('r+U') - f.write(b'First line.\r') - self.assertEqual(f.readline(), 'First line.\n') - f.write(b'\nSecond.\r\n') - self.assertEqual(f.readline(), 'Second.\n') + f = LoopbackFile("r+U") + f.write(b"First line.\r") + self.assertEqual(f.readline(), "First line.\n") + f.write(b"\nSecond.\r\n") + self.assertEqual(f.readline(), "Second.\n") f.close() self.assertEqual(f.newlines, crlf) @@ -110,51 +113,54 @@ class BufferedFileTest (unittest.TestCase): """ verify that write buffering is on. """ - f = LoopbackFile('r+', 1) - f.write(b'Complete line.\nIncomplete line.') - self.assertEqual(f.readline(), 'Complete line.\n') - self.assertEqual(f.readline(), '') - f.write('..\n') - self.assertEqual(f.readline(), 'Incomplete line...\n') + f = LoopbackFile("r+", 1) + f.write(b"Complete line.\nIncomplete line.") + self.assertEqual(f.readline(), "Complete line.\n") + self.assertEqual(f.readline(), "") + f.write("..\n") + self.assertEqual(f.readline(), "Incomplete line...\n") f.close() def test_5_flush(self): """ verify that flush will force a write. """ - f = LoopbackFile('r+', 512) - f.write('Not\nquite\n512 bytes.\n') - self.assertEqual(f.read(1), b'') + f = LoopbackFile("r+", 512) + f.write("Not\nquite\n512 bytes.\n") + self.assertEqual(f.read(1), b"") f.flush() - self.assertEqual(f.read(5), b'Not\nq') - self.assertEqual(f.read(10), b'uite\n512 b') - self.assertEqual(f.read(9), b'ytes.\n') - self.assertEqual(f.read(3), b'') + self.assertEqual(f.read(5), b"Not\nq") + self.assertEqual(f.read(10), b"uite\n512 b") + self.assertEqual(f.read(9), b"ytes.\n") + self.assertEqual(f.read(3), b"") f.close() def test_6_buffering(self): """ verify that flushing happens automatically on buffer crossing. """ - f = LoopbackFile('r+', 16) - f.write(b'Too small.') - self.assertEqual(f.read(4), b'') - f.write(b' ') - self.assertEqual(f.read(4), b'') - f.write(b'Enough.') - self.assertEqual(f.read(20), b'Too small. Enough.') + f = LoopbackFile("r+", 16) + f.write(b"Too small.") + self.assertEqual(f.read(4), b"") + f.write(b" ") + self.assertEqual(f.read(4), b"") + f.write(b"Enough.") + self.assertEqual(f.read(20), b"Too small. Enough.") f.close() def test_7_read_all(self): """ verify that read(-1) returns everything left in the file. """ - f = LoopbackFile('r+', 16) - f.write(b'The first thing you need to do is open your eyes. ') - f.write(b'Then, you need to close them again.\n') + f = LoopbackFile("r+", 16) + f.write(b"The first thing you need to do is open your eyes. ") + f.write(b"Then, you need to close them again.\n") s = f.read(-1) - self.assertEqual(s, b'The first thing you need to do is open your eyes. Then, you ' + - b'need to close them again.\n') + self.assertEqual( + s, + b"The first thing you need to do is open your eyes. Then, you " + + b"need to close them again.\n", + ) f.close() def test_8_buffering(self): @@ -162,19 +168,19 @@ class BufferedFileTest (unittest.TestCase): verify that buffered objects can be written """ if sys.version_info[0] == 2: - f = LoopbackFile('r+', 16) - f.write(buffer(b'Too small.')) + f = LoopbackFile("r+", 16) + f.write(buffer(b"Too small.")) f.close() def test_9_readable(self): - f = LoopbackFile('r') + f = LoopbackFile("r") self.assertTrue(f.readable()) self.assertFalse(f.writable()) self.assertFalse(f.seekable()) f.close() def test_A_writable(self): - f = LoopbackFile('w') + f = LoopbackFile("w") self.assertTrue(f.writable()) self.assertFalse(f.readable()) self.assertFalse(f.seekable()) @@ -182,48 +188,49 @@ class BufferedFileTest (unittest.TestCase): def test_B_readinto(self): data = bytearray(5) - f = LoopbackFile('r+') + f = LoopbackFile("r+") f._write(b"hello") f.readinto(data) - self.assertEqual(data, b'hello') + self.assertEqual(data, b"hello") f.close() def test_write_bad_type(self): - with LoopbackFile('wb') as f: + with LoopbackFile("wb") as f: self.assertRaises(TypeError, f.write, object()) def test_write_unicode_as_binary(self): text = u"\xa7 why is writing text to a binary file allowed?\n" - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: f.write(text) self.assertEqual(f.read(), text.encode("utf-8")) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_bytearray(self): - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: f.write(bytearray(12)) self.assertEqual(f.read(), 12 * b"\0") - @needs_builtin('buffer') + @needs_builtin("buffer") def test_write_buffer(self): data = 3 * b"pretend giant block of data\n" offsets = range(0, len(data), 8) - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: for offset in offsets: f.write(buffer(data, offset, 8)) self.assertEqual(f.read(), data) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_memoryview(self): data = 3 * b"pretend giant block of data\n" offsets = range(0, len(data), 8) - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: view = memoryview(data) for offset in offsets: - f.write(view[offset:offset+8]) + f.write(view[offset : offset + 8]) self.assertEqual(f.read(), data) -if __name__ == '__main__': +if __name__ == "__main__": from unittest import main + main() diff --git a/tests/test_gssapi.py b/tests/test_gssapi.py index d4b632be..d7fbdd53 100644 --- a/tests/test_gssapi.py +++ b/tests/test_gssapi.py @@ -30,6 +30,7 @@ from .util import needs_gssapi @needs_gssapi class GSSAPITest(unittest.TestCase): + def setup(): # TODO: these vars should all come from os.environ or whatever the # approved pytest method is for runtime-configuring test data. @@ -43,6 +44,7 @@ class GSSAPITest(unittest.TestCase): """ from pyasn1.type.univ import ObjectIdentifier from pyasn1.codec.der import encoder, decoder + oid = encoder.encode(ObjectIdentifier(self.krb5_mech)) mech, __ = decoder.decode(oid) self.assertEquals(self.krb5_mech, mech.__str__()) @@ -57,6 +59,7 @@ class GSSAPITest(unittest.TestCase): except ImportError: import sspicon import sspi + _API = "SSPI" c_token = None @@ -65,23 +68,28 @@ class GSSAPITest(unittest.TestCase): if _API == "MIT": if self.server_mode: - gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG, - gssapi.C_DELEG_FLAG) + gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + gssapi.C_DELEG_FLAG, + ) else: - gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_DELEG_FLAG) + gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_DELEG_FLAG, + ) # Initialize a GSS-API context. ctx = gssapi.Context() ctx.flags = gss_flags krb5_oid = gssapi.OID.mech_from_string(self.krb5_mech) - target_name = gssapi.Name("host@" + self.targ_name, - gssapi.C_NT_HOSTBASED_SERVICE) - gss_ctxt = gssapi.InitContext(peer_name=target_name, - mech_type=krb5_oid, - req_flags=ctx.flags) + target_name = gssapi.Name( + "host@" + self.targ_name, gssapi.C_NT_HOSTBASED_SERVICE + ) + gss_ctxt = gssapi.InitContext( + peer_name=target_name, mech_type=krb5_oid, req_flags=ctx.flags + ) if self.server_mode: c_token = gss_ctxt.step(c_token) gss_ctxt_status = gss_ctxt.established @@ -108,15 +116,15 @@ class GSSAPITest(unittest.TestCase): self.assertEquals(0, status) else: gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH | - sspicon.ISC_REQ_DELEGATE + sspicon.ISC_REQ_INTEGRITY + | sspicon.ISC_REQ_MUTUAL_AUTH + | sspicon.ISC_REQ_DELEGATE ) # Initialize a GSS-API context. target_name = "host/" + socket.getfqdn(self.targ_name) - gss_ctxt = sspi.ClientAuth("Kerberos", - scflags=gss_flags, - targetspn=target_name) + gss_ctxt = sspi.ClientAuth( + "Kerberos", scflags=gss_flags, targetspn=target_name + ) if self.server_mode: error, token = gss_ctxt.authorize(c_token) c_token = token[0].Buffer diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index cd75f8ab..a1b7a9e0 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -54,77 +54,80 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\ 0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""" -class HostKeysTest (unittest.TestCase): +class HostKeysTest(unittest.TestCase): def setUp(self): - with open('hostfile.temp', 'w') as f: + with open("hostfile.temp", "w") as f: f.write(test_hosts_file) def tearDown(self): - os.unlink('hostfile.temp') + os.unlink("hostfile.temp") def test_1_load(self): - hostdict = paramiko.HostKeys('hostfile.temp') + hostdict = paramiko.HostKeys("hostfile.temp") self.assertEqual(2, len(hostdict)) self.assertEqual(1, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) def test_2_add(self): - hostdict = paramiko.HostKeys('hostfile.temp') - hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=' + hostdict = paramiko.HostKeys("hostfile.temp") + hh = "|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=" key = paramiko.RSAKey(data=decodebytes(keyblob)) - hostdict.add(hh, 'ssh-rsa', key) + hostdict.add(hh, "ssh-rsa", key) self.assertEqual(3, len(list(hostdict))) - x = hostdict['foo.example.com'] - fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp) - self.assertTrue(hostdict.check('foo.example.com', key)) + x = hostdict["foo.example.com"] + fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper() + self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp) + self.assertTrue(hostdict.check("foo.example.com", key)) def test_3_dict(self): - hostdict = paramiko.HostKeys('hostfile.temp') - self.assertTrue('secure.example.com' in hostdict) - self.assertTrue('not.example.com' not in hostdict) - self.assertTrue('secure.example.com' in hostdict) - self.assertTrue('not.example.com' not in hostdict) - x = hostdict.get('secure.example.com', None) + hostdict = paramiko.HostKeys("hostfile.temp") + self.assertTrue("secure.example.com" in hostdict) + self.assertTrue("not.example.com" not in hostdict) + self.assertTrue("secure.example.com" in hostdict) + self.assertTrue("not.example.com" not in hostdict) + x = hostdict.get("secure.example.com", None) self.assertTrue(x is not None) - fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) i = 0 for key in hostdict: i += 1 self.assertEqual(2, i) - + def test_4_dict_set(self): - hostdict = paramiko.HostKeys('hostfile.temp') + hostdict = paramiko.HostKeys("hostfile.temp") key = paramiko.RSAKey(data=decodebytes(keyblob)) key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss)) - hostdict['secure.example.com'] = { - 'ssh-rsa': key, - 'ssh-dss': key_dss - } - hostdict['fake.example.com'] = {} - hostdict['fake.example.com']['ssh-rsa'] = key - + hostdict["secure.example.com"] = {"ssh-rsa": key, "ssh-dss": key_dss} + hostdict["fake.example.com"] = {} + hostdict["fake.example.com"]["ssh-rsa"] = key + self.assertEqual(3, len(hostdict)) self.assertEqual(2, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) self.assertEqual(1, len(list(hostdict.values())[2])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp) - fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() - self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-dss"].get_fingerprint() + ).upper() + self.assertEqual(b"4478F0B9A23CC5182009FF755BC1D26C", fp) def test_delitem(self): - hostdict = paramiko.HostKeys('hostfile.temp') - target = 'happy.example.com' - entry = hostdict[target] # will KeyError if not present + hostdict = paramiko.HostKeys("hostfile.temp") + target = "happy.example.com" + entry = hostdict[target] # will KeyError if not present del hostdict[target] try: entry = hostdict[target] except KeyError: - pass # Good + pass # Good else: assert False, "Entry was not deleted from HostKeys on delitem!" diff --git a/tests/test_kex.py b/tests/test_kex.py index b5808e7e..13d19d86 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -38,30 +38,46 @@ from paramiko.kex_ecdh_nist import KexNistp256 def dummy_urandom(n): return byte_chr(0xcc) * n + def dummy_generate_key_pair(obj): - private_key_value = 94761803665136558137557783047955027733968423115106677159790289642479432803037 - public_key_numbers = "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989" - public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)) - obj.P = ec.EllipticCurvePrivateNumbers(private_value=private_key_value, public_numbers=public_key_numbers_obj).private_key(default_backend()) + private_key_value = ( + 94761803665136558137557783047955027733968423115106677159790289642479432803037 + ) + public_key_numbers = ( + "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989" + ) + public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ) + obj.P = ec.EllipticCurvePrivateNumbers( + private_value=private_key_value, public_numbers=public_key_numbers_obj + ).private_key(default_backend()) if obj.transport.server_mode: - obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ).public_key(default_backend()) return - obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ).public_key(default_backend()) + +class FakeKey(object): -class FakeKey (object): def __str__(self): - return 'fake-key' + return "fake-key" def asbytes(self): - return b'fake-key' + return b"fake-key" def sign_ssh_data(self, H): - return b'fake-sig' + return b"fake-sig" -class FakeModulusPack (object): - P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF +class FakeModulusPack(object): + P = ( + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF + ) G = 2 def get_modulus(self, min, ask, max): @@ -69,10 +85,10 @@ class FakeModulusPack (object): class FakeTransport(object): - local_version = 'SSH-2.0-paramiko_1.0' - remote_version = 'SSH-2.0-lame' - local_kex_init = 'local-kex-init' - remote_kex_init = 'remote-kex-init' + local_version = "SSH-2.0-paramiko_1.0" + remote_version = "SSH-2.0-lame" + local_kex_init = "local-kex-init" + remote_kex_init = "remote-kex-init" def _send_message(self, m): self._message = m @@ -100,9 +116,11 @@ class FakeTransport(object): return FakeModulusPack() -class KexTest (unittest.TestCase): +class KexTest(unittest.TestCase): - K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504 + K = ( + 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504 + ) def setUp(self): self._original_urandom = os.urandom @@ -119,21 +137,25 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGroup1(transport) kex.start_kex() - x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect + ) # fake "reply" msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) - H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2' + H = b"03079780F3D3AD0B3C6DB30C8D21685F367A86D2" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_2_group1_server(self): @@ -141,14 +163,18 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGroup1(transport) kex.start_kex() - self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect + ) msg = Message() msg.add_mpint(69) msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) - H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' - x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"B16BF34DD10945EDE84E9C1EF24A14BFDC843389" + x = ( + b"1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -159,29 +185,35 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGex(transport) kex.start_kex() - x = b'22000004000000080000002000' + x = b"22000004000000080000002000" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' + H = b"A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_4_gex_old_client(self): @@ -189,37 +221,49 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGex(transport) kex.start_kex(_test_old_style=True) - x = b'1E00000800' + x = b"1E00000800" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'807F87B269EF7AC5EC7E75676808776A27D5864C' + H = b"807F87B269EF7AC5EC7E75676808776A27D5864C" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) - + def test_5_gex_server(self): transport = FakeTransport() transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(1024) @@ -227,17 +271,25 @@ class KexTest (unittest.TestCase): msg.add_int(4096) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'CE754197C21BF3452863B4F44D0B3951F12516EF' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"CE754197C21BF3452863B4F44D0B3951F12516EF" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -248,23 +300,37 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(2048) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -275,29 +341,35 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGexSHA256(transport) kex.start_kex() - x = b'22000004000000080000002000' + x = b"22000004000000080000002000" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4' + H = b"AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_8_gex_sha256_old_client(self): @@ -305,29 +377,35 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGexSHA256(transport) kex.start_kex(_test_old_style=True) - x = b'1E00000800' + x = b"1E00000800" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5' + H = b"518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_9_gex_sha256_server(self): @@ -335,7 +413,13 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGexSHA256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(1024) @@ -343,17 +427,25 @@ class KexTest (unittest.TestCase): msg.add_int(4096) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -364,62 +456,88 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGexSHA256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(2048) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) self.assertTrue(transport._activated) def test_11_kex_nistp256_client(self): - K = 91610929826364598472338906427792435253694642563583721654249504912114314269754 + K = ( + 91610929826364598472338906427792435253694642563583721654249504912114314269754 + ) transport = FakeTransport() transport.server_mode = False kex = KexNistp256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect + ) - #fake reply + # fake reply msg = Message() - msg.add_string('fake-host-key') - Q_S = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") + msg.add_string("fake-host-key") + Q_S = unhexlify( + "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210" + ) msg.add_string(Q_S) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY, msg) - H = b'BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A' + H = b"BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A" self.assertEqual(K, kex.transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_12_kex_nistp256_server(self): - K = 91610929826364598472338906427792435253694642563583721654249504912114314269754 + K = ( + 91610929826364598472338906427792435253694642563583721654249504912114314269754 + ) transport = FakeTransport() transport.server_mode = True kex = KexNistp256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect + ) - #fake init - msg=Message() - Q_C = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") - H = b'2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA' + # fake init + msg = Message() + Q_C = unhexlify( + "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210" + ) + H = b"2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA" msg.add_string(Q_C) msg.rewind() kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT, msg) diff --git a/tests/test_kex_gss.py b/tests/test_kex_gss.py index 025d1faa..afddee08 100644 --- a/tests/test_kex_gss.py +++ b/tests/test_kex_gss.py @@ -34,14 +34,14 @@ import paramiko from .util import needs_gssapi -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): def get_allowed_auths(self, username): - return 'gssapi-keyex' + return "gssapi-keyex" - def check_auth_gssapi_keyex(self, username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_keyex( + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None + ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -54,13 +54,14 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != "yes": return False return True @needs_gssapi class GSSKexTest(unittest.TestCase): + @staticmethod def init(username, hostname): global krb5_principal, targ_name @@ -86,13 +87,13 @@ class GSSKexTest(unittest.TestCase): def _run(self): self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks, gss_kex=True) - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") self.ts.add_server_key(host_key) self.ts.set_gss_host(targ_name) try: self.ts.load_server_moduli() except: - print ('(Failed to load moduli -- gex will be unsupported.)') + print("(Failed to load moduli -- gex will be unsupported.)") server = NullServer() self.ts.start_server(self.event, server) @@ -102,14 +103,21 @@ class GSSKexTest(unittest.TestCase): Diffie-Hellman Key Exchange and user authentication with the GSS-API context created during key exchange. """ - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.hostname, self.port), - 'ssh-rsa', public_host_key) - self.tc.connect(self.hostname, self.port, username=self.username, - gss_auth=True, gss_kex=True, gss_host=gss_host) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.hostname, self.port), "ssh-rsa", public_host_key + ) + self.tc.connect( + self.hostname, + self.port, + username=self.username, + gss_auth=True, + gss_kex=True, + gss_host=gss_host, + ) self.event.wait(1.0) self.assert_(self.event.is_set()) @@ -118,19 +126,19 @@ class GSSKexTest(unittest.TestCase): self.assertEquals(True, self.ts.is_authenticated()) self.assertEquals(True, self.tc.get_transport().gss_kex_used) - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) if rekey: self.tc.get_transport().renegotiate_keys() - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEquals('Hello there.\n', stdout.readline()) - self.assertEquals('', stdout.readline()) - self.assertEquals('This is on stderr.\n', stderr.readline()) - self.assertEquals('', stderr.readline()) + self.assertEquals("Hello there.\n", stdout.readline()) + self.assertEquals("", stdout.readline()) + self.assertEquals("This is on stderr.\n", stderr.readline()) + self.assertEquals("", stderr.readline()) stdin.close() stdout.close() diff --git a/tests/test_message.py b/tests/test_message.py index 645b0509..c292f4e6 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -26,20 +26,29 @@ from paramiko.message import Message from paramiko.common import byte_chr, zero_byte -class MessageTest (unittest.TestCase): +class MessageTest(unittest.TestCase): - __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000 - __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65' - __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7' - __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62' + __a = ( + b"\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8" + + b"x" * 1000 + ) + __b = ( + b"\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65" + ) + __c = ( + b"\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7" + ) + __d = ( + b"\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62" + ) def test_1_encode(self): msg = Message() msg.add_int(23) msg.add_int(123789456) - msg.add_string('q') - msg.add_string('hello') - msg.add_string('x' * 1000) + msg.add_string("q") + msg.add_string("hello") + msg.add_string("x" * 1000) self.assertEqual(msg.asbytes(), self.__a) msg = Message() @@ -48,7 +57,7 @@ class MessageTest (unittest.TestCase): msg.add_byte(byte_chr(0xf3)) msg.add_bytes(zero_byte + byte_chr(0x3f)) - msg.add_list(['huey', 'dewey', 'louie']) + msg.add_list(["huey", "dewey", "louie"]) self.assertEqual(msg.asbytes(), self.__b) msg = Message() @@ -63,16 +72,16 @@ class MessageTest (unittest.TestCase): msg = Message(self.__a) self.assertEqual(msg.get_int(), 23) self.assertEqual(msg.get_int(), 123789456) - self.assertEqual(msg.get_text(), 'q') - self.assertEqual(msg.get_text(), 'hello') - self.assertEqual(msg.get_text(), 'x' * 1000) + self.assertEqual(msg.get_text(), "q") + self.assertEqual(msg.get_text(), "hello") + self.assertEqual(msg.get_text(), "x" * 1000) msg = Message(self.__b) self.assertEqual(msg.get_boolean(), True) self.assertEqual(msg.get_boolean(), False) self.assertEqual(msg.get_byte(), byte_chr(0xf3)) self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f)) - self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie']) + self.assertEqual(msg.get_list(), ["huey", "dewey", "louie"]) msg = Message(self.__c) self.assertEqual(msg.get_int64(), 5) @@ -87,8 +96,8 @@ class MessageTest (unittest.TestCase): msg.add(0x1122334455) msg.add(0xf00000000000000000) msg.add(True) - msg.add('cat') - msg.add(['a', 'b']) + msg.add("cat") + msg.add(["a", "b"]) self.assertEqual(msg.asbytes(), self.__d) def test_4_misc(self): diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py index 414b7e38..dbe5993e 100644 --- a/tests/test_packetizer.py +++ b/tests/test_packetizer.py @@ -36,19 +36,20 @@ from .loop import LoopSocket x55 = byte_chr(0x55) x1f = byte_chr(0x1f) -class PacketizerTest (unittest.TestCase): + +class PacketizerTest(unittest.TestCase): def test_1_write(self): rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(wsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) encryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).encryptor() p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) @@ -63,22 +64,27 @@ class PacketizerTest (unittest.TestCase): data = rsock.recv(100) # 32 + 12 bytes of MAC = 44 self.assertEqual(44, len(data)) - self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16]) + self.assertEqual( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0", + data[:16], + ) def test_2_read(self): rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(rsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) decryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).decryptor() p.set_inbound_cipher(decryptor, 16, sha1, 12, x1f * 20) - wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59') + wsock.send( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59" + ) cmd, m = p.read_message() self.assertEqual(100, cmd) self.assertEqual(100, m.get_int()) @@ -86,18 +92,18 @@ class PacketizerTest (unittest.TestCase): self.assertEqual(900, m.get_int()) def test_3_closed(self): - if sys.platform.startswith("win"): # no SIGALRM on windows + if sys.platform.startswith("win"): # no SIGALRM on windows return rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(wsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) encryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).encryptor() p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) @@ -115,14 +121,17 @@ class PacketizerTest (unittest.TestCase): import signal class TimeoutError(Exception): + def __init__(self, error_message): - if hasattr(errno, 'ETIME'): + if hasattr(errno, "ETIME"): self.message = os.sterror(errno.ETIME) else: self.messaage = error_message - def timeout(seconds=1, error_message='Timer expired'): + def timeout(seconds=1, error_message="Timer expired"): + def decorator(func): + def _handle_timeout(signum, frame): raise TimeoutError(error_message) @@ -138,5 +147,6 @@ class PacketizerTest (unittest.TestCase): return wraps(func)(wrapper) return decorator + send = timeout()(p.send_message) self.assertRaises(EOFError, send, m) diff --git a/tests/test_pkey.py b/tests/test_pkey.py index 1827d2a9..4bbfaba1 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -34,18 +34,30 @@ from .util import _support # from openssh's ssh-keygen -PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=' -PUB_DSS = 'ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=' -PUB_ECDSA_256 = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo=' -PUB_ECDSA_384 = 'ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A==' -PUB_ECDSA_521 = 'ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==' - -FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5' -FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c' -FINGER_ECDSA_256 = '256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60' -FINGER_ECDSA_384 = '384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65' -FINGER_ECDSA_521 = '521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e' -SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8' +PUB_RSA = ( + "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=" +) +PUB_DSS = ( + "ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=" +) +PUB_ECDSA_256 = ( + "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo=" +) +PUB_ECDSA_384 = ( + "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A==" +) +PUB_ECDSA_521 = ( + "ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==" +) + +FINGER_RSA = "1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5" +FINGER_DSS = "1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c" +FINGER_ECDSA_256 = "256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60" +FINGER_ECDSA_384 = "384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65" +FINGER_ECDSA_521 = "521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e" +SIGNED_RSA = ( + "20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8" +) RSA_PRIVATE_OUT = """\ -----BEGIN RSA PRIVATE KEY----- @@ -107,10 +119,14 @@ L4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA== -----END EC PRIVATE KEY----- """ -x1234 = b'\x01\x02\x03\x04' +x1234 = b"\x01\x02\x03\x04" -TEST_KEY_BYTESTR_2 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87' -TEST_KEY_BYTESTR_3 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f' +TEST_KEY_BYTESTR_2 = ( + "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87" +) +TEST_KEY_BYTESTR_3 = ( + "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f" +) class KeyTest(unittest.TestCase): @@ -127,21 +143,22 @@ class KeyTest(unittest.TestCase): """ with open(keyfile, "r") as fh: self.assertEqual( - fh.readline()[:-1], - "-----BEGIN RSA PRIVATE KEY-----" + fh.readline()[:-1], "-----BEGIN RSA PRIVATE KEY-----" ) self.assertEqual(fh.readline()[:-1], "Proc-Type: 4,ENCRYPTED") self.assertEqual(fh.readline()[0:10], "DEK-Info: ") def test_1_generate_key_bytes(self): - key = util.generate_key_bytes(md5, x1234, 'happy birthday', 30) - exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64' + key = util.generate_key_bytes(md5, x1234, "happy birthday", 30) + exp = ( + b"\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64" + ) self.assertEqual(exp, key) def test_2_load_rsa(self): - key = RSAKey.from_private_key_file(_support('test_rsa.key')) - self.assertEqual('ssh-rsa', key.get_name()) - exp_rsa = b(FINGER_RSA.split()[1].replace(':', '')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) + self.assertEqual("ssh-rsa", key.get_name()) + exp_rsa = b(FINGER_RSA.split()[1].replace(":", "")) my_rsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_rsa, my_rsa) self.assertEqual(PUB_RSA.split()[1], key.get_base64()) @@ -155,18 +172,20 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_3_load_rsa_password(self): - key = RSAKey.from_private_key_file(_support('test_rsa_password.key'), 'television') - self.assertEqual('ssh-rsa', key.get_name()) - exp_rsa = b(FINGER_RSA.split()[1].replace(':', '')) + key = RSAKey.from_private_key_file( + _support("test_rsa_password.key"), "television" + ) + self.assertEqual("ssh-rsa", key.get_name()) + exp_rsa = b(FINGER_RSA.split()[1].replace(":", "")) my_rsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_rsa, my_rsa) self.assertEqual(PUB_RSA.split()[1], key.get_base64()) self.assertEqual(1024, key.get_bits()) def test_4_load_dss(self): - key = DSSKey.from_private_key_file(_support('test_dss.key')) - self.assertEqual('ssh-dss', key.get_name()) - exp_dss = b(FINGER_DSS.split()[1].replace(':', '')) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + self.assertEqual("ssh-dss", key.get_name()) + exp_dss = b(FINGER_DSS.split()[1].replace(":", "")) my_dss = hexlify(key.get_fingerprint()) self.assertEqual(exp_dss, my_dss) self.assertEqual(PUB_DSS.split()[1], key.get_base64()) @@ -180,9 +199,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_5_load_dss_password(self): - key = DSSKey.from_private_key_file(_support('test_dss_password.key'), 'television') - self.assertEqual('ssh-dss', key.get_name()) - exp_dss = b(FINGER_DSS.split()[1].replace(':', '')) + key = DSSKey.from_private_key_file( + _support("test_dss_password.key"), "television" + ) + self.assertEqual("ssh-dss", key.get_name()) + exp_dss = b(FINGER_DSS.split()[1].replace(":", "")) my_dss = hexlify(key.get_fingerprint()) self.assertEqual(exp_dss, my_dss) self.assertEqual(PUB_DSS.split()[1], key.get_base64()) @@ -190,7 +211,7 @@ class KeyTest(unittest.TestCase): def test_6_compare_rsa(self): # verify that the private & public keys compare equal - key = RSAKey.from_private_key_file(_support('test_rsa.key')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.assertEqual(key, key) pub = RSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -199,7 +220,7 @@ class KeyTest(unittest.TestCase): def test_7_compare_dss(self): # verify that the private & public keys compare equal - key = DSSKey.from_private_key_file(_support('test_dss.key')) + key = DSSKey.from_private_key_file(_support("test_dss.key")) self.assertEqual(key, key) pub = DSSKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -208,77 +229,79 @@ class KeyTest(unittest.TestCase): def test_8_sign_rsa(self): # verify that the rsa private key can sign and verify - key = RSAKey.from_private_key_file(_support('test_rsa.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = RSAKey.from_private_key_file(_support("test_rsa.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ssh-rsa', msg.get_text()) - sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) + self.assertEqual("ssh-rsa", msg.get_text()) + sig = bytes().join( + [byte_chr(int(x, 16)) for x in SIGNED_RSA.split(":")] + ) self.assertEqual(sig, msg.get_binary()) msg.rewind() pub = RSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_9_sign_dss(self): # verify that the dss private key can sign and verify - key = DSSKey.from_private_key_file(_support('test_dss.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = DSSKey.from_private_key_file(_support("test_dss.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ssh-dss', msg.get_text()) + self.assertEqual("ssh-dss", msg.get_text()) # can't do the same test as we do for RSA, because DSS signatures # are usually different each time. but we can test verification # anyway so it's ok. self.assertEqual(40, len(msg.get_binary())) msg.rewind() pub = DSSKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_A_generate_rsa(self): key = RSAKey.generate(1024) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) def test_B_generate_dss(self): key = DSSKey.generate(1024) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) def test_C_generate_ecdsa(self): key = ECDSAKey.generate() - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 256) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256") key = ECDSAKey.generate(bits=256) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 256) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256") key = ECDSAKey.generate(bits=384) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 384) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp384') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp384") key = ECDSAKey.generate(bits=521) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 521) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp521') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp521") def test_10_load_ecdsa_256(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) - self.assertEqual('ecdsa-sha2-nistp256', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) + self.assertEqual("ecdsa-sha2-nistp256", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64()) @@ -292,9 +315,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_11_load_ecdsa_password_256(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_256.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp256', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_256.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp256", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64()) @@ -302,7 +327,7 @@ class KeyTest(unittest.TestCase): def test_12_compare_ecdsa_256(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -311,11 +336,11 @@ class KeyTest(unittest.TestCase): def test_13_sign_ecdsa_256(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp256', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp256", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -323,12 +348,12 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_14_load_ecdsa_384(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) - self.assertEqual('ecdsa-sha2-nistp384', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) + self.assertEqual("ecdsa-sha2-nistp384", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64()) @@ -342,9 +367,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_15_load_ecdsa_password_384(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_384.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp384', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_384.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp384", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64()) @@ -352,7 +379,7 @@ class KeyTest(unittest.TestCase): def test_16_compare_ecdsa_384(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -361,11 +388,11 @@ class KeyTest(unittest.TestCase): def test_17_sign_ecdsa_384(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp384', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp384", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -373,12 +400,12 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_18_load_ecdsa_521(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) - self.assertEqual('ecdsa-sha2-nistp521', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) + self.assertEqual("ecdsa-sha2-nistp521", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64()) @@ -395,9 +422,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_19_load_ecdsa_password_521(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_521.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp521', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_521.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp521", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64()) @@ -405,7 +434,7 @@ class KeyTest(unittest.TestCase): def test_20_compare_ecdsa_521(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -414,11 +443,11 @@ class KeyTest(unittest.TestCase): def test_21_sign_ecdsa_521(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp521', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp521", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -426,14 +455,14 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_salt_size(self): # Read an existing encrypted private key - file_ = _support('test_rsa_password.key') - password = 'television' - newfile = file_ + '.new' - newpassword = 'radio' + file_ = _support("test_rsa_password.key") + password = "television" + newfile = file_ + ".new" + newpassword = "radio" key = RSAKey(filename=file_, password=password) # Write out a newly re-encrypted copy with a new password. # When the bug under test exists, this will ValueError. @@ -447,20 +476,20 @@ class KeyTest(unittest.TestCase): os.remove(newfile) def test_stringification(self): - key = RSAKey.from_private_key_file(_support('test_rsa.key')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) comparable = TEST_KEY_BYTESTR_2 if PY2 else TEST_KEY_BYTESTR_3 self.assertEqual(str(key), comparable) def test_ed25519(self): - key1 = Ed25519Key.from_private_key_file(_support('test_ed25519.key')) + key1 = Ed25519Key.from_private_key_file(_support("test_ed25519.key")) key2 = Ed25519Key.from_private_key_file( - _support('test_ed25519_password.key'), b'abc123' + _support("test_ed25519_password.key"), b"abc123" ) self.assertNotEqual(key1.asbytes(), key2.asbytes()) def test_ed25519_compare(self): # verify that the private & public keys compare equal - key = Ed25519Key.from_private_key_file(_support('test_ed25519.key')) + key = Ed25519Key.from_private_key_file(_support("test_ed25519.key")) self.assertEqual(key, key) pub = Ed25519Key(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -470,25 +499,25 @@ class KeyTest(unittest.TestCase): def test_ed25519_nonbytes_password(self): # https://github.com/paramiko/paramiko/issues/1039 key = Ed25519Key.from_private_key_file( - _support('test_ed25519_password.key'), + _support("test_ed25519_password.key"), # NOTE: not a bytes. Amusingly, the test above for same key DOES # explicitly cast to bytes...code smell! - 'abc123', + "abc123", ) # No exception -> it's good. Meh. def test_ed25519_load_from_file_obj(self): - with open(_support('test_ed25519.key')) as pkey_fileobj: + with open(_support("test_ed25519.key")) as pkey_fileobj: key = Ed25519Key.from_private_key(pkey_fileobj) self.assertEqual(key, key) self.assertTrue(key.can_sign()) def test_keyfile_is_actually_encrypted(self): # Read an existing encrypted private key - file_ = _support('test_rsa_password.key') - password = 'television' - newfile = file_ + '.new' - newpassword = 'radio' + file_ = _support("test_rsa_password.key") + password = "television" + newfile = file_ + ".new" + newpassword = "radio" key = RSAKey(filename=file_, password=password) # Write out a newly re-encrypted copy with a new password. # When the bug under test exists, this will ValueError. @@ -503,19 +532,21 @@ class KeyTest(unittest.TestCase): # test_client.py; this and nearby cert tests are more about the gritty # details. # PKey.load_certificate - key_path = _support(os.path.join('cert_support', 'test_rsa.key')) + key_path = _support(os.path.join("cert_support", "test_rsa.key")) key = RSAKey.from_private_key_file(key_path) self.assertTrue(key.public_blob is None) cert_path = _support( - os.path.join('cert_support', 'test_rsa.key-cert.pub') + os.path.join("cert_support", "test_rsa.key-cert.pub") ) key.load_certificate(cert_path) self.assertTrue(key.public_blob is not None) - self.assertEqual(key.public_blob.key_type, 'ssh-rsa-cert-v01@openssh.com') - self.assertEqual(key.public_blob.comment, 'test_rsa.key.pub') + self.assertEqual( + key.public_blob.key_type, "ssh-rsa-cert-v01@openssh.com" + ) + self.assertEqual(key.public_blob.comment, "test_rsa.key.pub") # Delve into blob contents, for test purposes msg = Message(key.public_blob.key_blob) - self.assertEqual(msg.get_text(), 'ssh-rsa-cert-v01@openssh.com') + self.assertEqual(msg.get_text(), "ssh-rsa-cert-v01@openssh.com") nonce = msg.get_string() e = msg.get_mpint() n = msg.get_mpint() @@ -525,10 +556,10 @@ class KeyTest(unittest.TestCase): self.assertEqual(msg.get_int64(), 1234) # Prevented from loading certificate that doesn't match - key_path = _support(os.path.join('cert_support', 'test_ed25519.key')) + key_path = _support(os.path.join("cert_support", "test_ed25519.key")) key1 = Ed25519Key.from_private_key_file(key_path) self.assertRaises( ValueError, key1.load_certificate, - _support('test_rsa.key-cert.pub'), + _support("test_rsa.key-cert.pub"), ) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 09a50453..576b69b7 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -45,7 +45,7 @@ from .stub_sftp import StubServer, StubSFTPServer from .util import _support, slow -ARTICLE = ''' +ARTICLE = """ Insulin sensitivity and liver insulin receptor structure in ducks from two genera @@ -70,7 +70,7 @@ receptors. Therefore the ducks from the two genera exhibit an alpha-beta- structure for liver insulin receptors and a clear difference in the number of liver insulin receptors. Their sensitivity to insulin is, however, similarly decreased compared with chicken. -''' +""" # Here is how unicode characters are encoded over 1 to 6 bytes in utf-8 @@ -82,32 +82,33 @@ decreased compared with chicken. # U-04000000 - U-7FFFFFFF: 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx # Note that: hex(int('11000011',2)) == '0xc3' # Thus, the following 2-bytes sequence is not valid utf8: "invalid continuation byte" -NON_UTF8_DATA = b'\xC3\xC3' +NON_UTF8_DATA = b"\xC3\xC3" -unicode_folder = u'\u00fcnic\u00f8de' if PY2 else '\u00fcnic\u00f8de' -utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65' +unicode_folder = u"\u00fcnic\u00f8de" if PY2 else "\u00fcnic\u00f8de" +utf8_folder = b"/\xc3\xbcnic\xc3\xb8\x64\x65" @slow class TestSFTP(object): + def test_1_file(self, sftp): """ verify that we can create a file. """ - f = sftp.open(sftp.FOLDER + '/test', 'w') + f = sftp.open(sftp.FOLDER + "/test", "w") try: assert f.stat().st_size == 0 finally: f.close() - sftp.remove(sftp.FOLDER + '/test') + sftp.remove(sftp.FOLDER + "/test") def test_2_close(self, sftp): """ Verify that SFTP session close() causes a socket error on next action. """ sftp.close() - with pytest.raises(socket.error, match='Socket is closed'): - sftp.open(sftp.FOLDER + '/test2', 'w') + with pytest.raises(socket.error, match="Socket is closed"): + sftp.open(sftp.FOLDER + "/test2", "w") def test_2_sftp_can_be_used_as_context_manager(self, sftp): """ @@ -115,117 +116,117 @@ class TestSFTP(object): """ with sftp: pass - with pytest.raises(socket.error, match='Socket is closed'): - sftp.open(sftp.FOLDER + '/test2', 'w') + with pytest.raises(socket.error, match="Socket is closed"): + sftp.open(sftp.FOLDER + "/test2", "w") def test_3_write(self, sftp): """ verify that a file can be created and written, and the size is correct. """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483 + assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483 finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_3_sftp_file_can_be_used_as_context_manager(self, sftp): """ verify that an opened file can be used as a context manager """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483 + assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483 finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_4_append(self, sftp): """ verify that a file can be opened for append, and tell() still works. """ try: - with sftp.open(sftp.FOLDER + '/append.txt', 'w') as f: - f.write('first line\nsecond line\n') + with sftp.open(sftp.FOLDER + "/append.txt", "w") as f: + f.write("first line\nsecond line\n") assert f.tell() == 23 - with sftp.open(sftp.FOLDER + '/append.txt', 'a+') as f: - f.write('third line!!!\n') + with sftp.open(sftp.FOLDER + "/append.txt", "a+") as f: + f.write("third line!!!\n") assert f.tell() == 37 assert f.stat().st_size == 37 f.seek(-26, f.SEEK_CUR) - assert f.readline() == 'second line\n' + assert f.readline() == "second line\n" finally: - sftp.remove(sftp.FOLDER + '/append.txt') + sftp.remove(sftp.FOLDER + "/append.txt") def test_5_rename(self, sftp): """ verify that renaming a file works. """ try: - with sftp.open(sftp.FOLDER + '/first.txt', 'w') as f: - f.write('content!\n') - sftp.rename(sftp.FOLDER + '/first.txt', sftp.FOLDER + '/second.txt') - with pytest.raises(IOError, match='No such file'): - sftp.open(sftp.FOLDER + '/first.txt', 'r') - with sftp.open(sftp.FOLDER + '/second.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/first.txt", "w") as f: + f.write("content!\n") + sftp.rename( + sftp.FOLDER + "/first.txt", sftp.FOLDER + "/second.txt" + ) + with pytest.raises(IOError, match="No such file"): + sftp.open(sftp.FOLDER + "/first.txt", "r") + with sftp.open(sftp.FOLDER + "/second.txt", "r") as f: f.seek(-6, f.SEEK_END) - assert u(f.read(4)) == 'tent' + assert u(f.read(4)) == "tent" finally: # TODO: this is gross, make some sort of 'remove if possible' / 'rm # -f' a-like, jeez try: - sftp.remove(sftp.FOLDER + '/first.txt') + sftp.remove(sftp.FOLDER + "/first.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/second.txt') + sftp.remove(sftp.FOLDER + "/second.txt") except: pass - def test_5a_posix_rename(self, sftp): """Test posix-rename@openssh.com protocol extension.""" try: # first check that the normal rename works as specified - with sftp.open(sftp.FOLDER + '/a', 'w') as f: - f.write('one') - sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') - with sftp.open(sftp.FOLDER + '/a', 'w') as f: - f.write('two') - with pytest.raises(IOError): # actual message seems generic - sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') + with sftp.open(sftp.FOLDER + "/a", "w") as f: + f.write("one") + sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") + with sftp.open(sftp.FOLDER + "/a", "w") as f: + f.write("two") + with pytest.raises(IOError): # actual message seems generic + sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") # now check with the posix_rename - sftp.posix_rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') - with sftp.open(sftp.FOLDER + '/b', 'r') as f: + sftp.posix_rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") + with sftp.open(sftp.FOLDER + "/b", "r") as f: data = u(f.read()) err = "Contents of renamed file not the same as original file" - assert 'two' == data, err + assert "two" == data, err finally: try: - sftp.remove(sftp.FOLDER + '/a') + sftp.remove(sftp.FOLDER + "/a") except: pass try: - sftp.remove(sftp.FOLDER + '/b') + sftp.remove(sftp.FOLDER + "/b") except: pass - def test_6_folder(self, sftp): """ create a temporary folder, verify that we can create a file in it, then remove the folder and verify that we can't create a file in it anymore. """ - sftp.mkdir(sftp.FOLDER + '/subfolder') - sftp.open(sftp.FOLDER + '/subfolder/test', 'w').close() - sftp.remove(sftp.FOLDER + '/subfolder/test') - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.mkdir(sftp.FOLDER + "/subfolder") + sftp.open(sftp.FOLDER + "/subfolder/test", "w").close() + sftp.remove(sftp.FOLDER + "/subfolder/test") + sftp.rmdir(sftp.FOLDER + "/subfolder") # shouldn't be able to create that file if dir removed with pytest.raises(IOError, match="No such file"): - sftp.open(sftp.FOLDER + '/subfolder/test') + sftp.open(sftp.FOLDER + "/subfolder/test") def test_7_listdir(self, sftp): """ @@ -233,57 +234,57 @@ class TestSFTP(object): it, and those files show up in sftp.listdir. """ try: - sftp.open(sftp.FOLDER + '/duck.txt', 'w').close() - sftp.open(sftp.FOLDER + '/fish.txt', 'w').close() - sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close() + sftp.open(sftp.FOLDER + "/duck.txt", "w").close() + sftp.open(sftp.FOLDER + "/fish.txt", "w").close() + sftp.open(sftp.FOLDER + "/tertiary.py", "w").close() x = sftp.listdir(sftp.FOLDER) assert len(x) == 3 - assert 'duck.txt' in x - assert 'fish.txt' in x - assert 'tertiary.py' in x - assert 'random' not in x + assert "duck.txt" in x + assert "fish.txt" in x + assert "tertiary.py" in x + assert "random" not in x finally: - sftp.remove(sftp.FOLDER + '/duck.txt') - sftp.remove(sftp.FOLDER + '/fish.txt') - sftp.remove(sftp.FOLDER + '/tertiary.py') + sftp.remove(sftp.FOLDER + "/duck.txt") + sftp.remove(sftp.FOLDER + "/fish.txt") + sftp.remove(sftp.FOLDER + "/tertiary.py") def test_7_5_listdir_iter(self, sftp): """ listdir_iter version of above test """ try: - sftp.open(sftp.FOLDER + '/duck.txt', 'w').close() - sftp.open(sftp.FOLDER + '/fish.txt', 'w').close() - sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close() + sftp.open(sftp.FOLDER + "/duck.txt", "w").close() + sftp.open(sftp.FOLDER + "/fish.txt", "w").close() + sftp.open(sftp.FOLDER + "/tertiary.py", "w").close() x = [x.filename for x in sftp.listdir_iter(sftp.FOLDER)] assert len(x) == 3 - assert 'duck.txt' in x - assert 'fish.txt' in x - assert 'tertiary.py' in x - assert 'random' not in x + assert "duck.txt" in x + assert "fish.txt" in x + assert "tertiary.py" in x + assert "random" not in x finally: - sftp.remove(sftp.FOLDER + '/duck.txt') - sftp.remove(sftp.FOLDER + '/fish.txt') - sftp.remove(sftp.FOLDER + '/tertiary.py') + sftp.remove(sftp.FOLDER + "/duck.txt") + sftp.remove(sftp.FOLDER + "/fish.txt") + sftp.remove(sftp.FOLDER + "/tertiary.py") def test_8_setstat(self, sftp): """ verify that the setstat functions (chown, chmod, utime, truncate) work. """ try: - with sftp.open(sftp.FOLDER + '/special', 'w') as f: - f.write('x' * 1024) + with sftp.open(sftp.FOLDER + "/special", "w") as f: + f.write("x" * 1024) - stat = sftp.stat(sftp.FOLDER + '/special') - sftp.chmod(sftp.FOLDER + '/special', (stat.st_mode & ~o777) | o600) - stat = sftp.stat(sftp.FOLDER + '/special') + stat = sftp.stat(sftp.FOLDER + "/special") + sftp.chmod(sftp.FOLDER + "/special", (stat.st_mode & ~o777) | o600) + stat = sftp.stat(sftp.FOLDER + "/special") expected_mode = o600 - if sys.platform == 'win32': + if sys.platform == "win32": # chmod not really functional on windows expected_mode = o666 - if sys.platform == 'cygwin': + if sys.platform == "cygwin": # even worse. expected_mode = o644 assert stat.st_mode & o777 == expected_mode @@ -291,19 +292,19 @@ class TestSFTP(object): mtime = stat.st_mtime - 3600 atime = stat.st_atime - 1800 - sftp.utime(sftp.FOLDER + '/special', (atime, mtime)) - stat = sftp.stat(sftp.FOLDER + '/special') + sftp.utime(sftp.FOLDER + "/special", (atime, mtime)) + stat = sftp.stat(sftp.FOLDER + "/special") assert stat.st_mtime == mtime - if sys.platform not in ('win32', 'cygwin'): + if sys.platform not in ("win32", "cygwin"): assert stat.st_atime == atime # can't really test chown, since we'd have to know a valid uid. - sftp.truncate(sftp.FOLDER + '/special', 512) - stat = sftp.stat(sftp.FOLDER + '/special') + sftp.truncate(sftp.FOLDER + "/special", 512) + stat = sftp.stat(sftp.FOLDER + "/special") assert stat.st_size == 512 finally: - sftp.remove(sftp.FOLDER + '/special') + sftp.remove(sftp.FOLDER + "/special") def test_9_fsetstat(self, sftp): """ @@ -311,19 +312,19 @@ class TestSFTP(object): work on open files. """ try: - with sftp.open(sftp.FOLDER + '/special', 'w') as f: - f.write('x' * 1024) + with sftp.open(sftp.FOLDER + "/special", "w") as f: + f.write("x" * 1024) - with sftp.open(sftp.FOLDER + '/special', 'r+') as f: + with sftp.open(sftp.FOLDER + "/special", "r+") as f: stat = f.stat() f.chmod((stat.st_mode & ~o777) | o600) stat = f.stat() expected_mode = o600 - if sys.platform == 'win32': + if sys.platform == "win32": # chmod not really functional on windows expected_mode = o666 - if sys.platform == 'cygwin': + if sys.platform == "cygwin": # even worse. expected_mode = o644 assert stat.st_mode & o777 == expected_mode @@ -334,7 +335,7 @@ class TestSFTP(object): f.utime((atime, mtime)) stat = f.stat() assert stat.st_mtime == mtime - if sys.platform not in ('win32', 'cygwin'): + if sys.platform not in ("win32", "cygwin"): assert stat.st_atime == atime # can't really test chown, since we'd have to know a valid uid. @@ -343,7 +344,7 @@ class TestSFTP(object): stat = f.stat() assert stat.st_size == 512 finally: - sftp.remove(sftp.FOLDER + '/special') + sftp.remove(sftp.FOLDER + "/special") def test_A_readline_seek(self, sftp): """ @@ -353,10 +354,10 @@ class TestSFTP(object): buffering is reset on 'seek'. """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - with sftp.open(sftp.FOLDER + '/duck.txt', 'r+') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "r+") as f: line_number = 0 loc = 0 pos_list = [] @@ -366,13 +367,16 @@ class TestSFTP(object): loc = f.tell() assert f.seekable() f.seek(pos_list[6], f.SEEK_SET) - assert f.readline(), 'Nouzilly == France.\n' + assert f.readline(), "Nouzilly == France.\n" f.seek(pos_list[17], f.SEEK_SET) - assert f.readline()[:4] == 'duck' + assert f.readline()[:4] == "duck" f.seek(pos_list[10], f.SEEK_SET) - assert f.readline() == 'duck types were equally resistant to exogenous insulin compared with chicken.\n' + assert ( + f.readline() + == "duck types were equally resistant to exogenous insulin compared with chicken.\n" + ) finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_B_write_seek(self, sftp): """ @@ -380,17 +384,17 @@ class TestSFTP(object): changes worked. """ try: - with sftp.open(sftp.FOLDER + '/testing.txt', 'w') as f: - f.write('hello kitty.\n') + with sftp.open(sftp.FOLDER + "/testing.txt", "w") as f: + f.write("hello kitty.\n") f.seek(-5, f.SEEK_CUR) - f.write('dd') + f.write("dd") - assert sftp.stat(sftp.FOLDER + '/testing.txt').st_size == 13 - with sftp.open(sftp.FOLDER + '/testing.txt', 'r') as f: + assert sftp.stat(sftp.FOLDER + "/testing.txt").st_size == 13 + with sftp.open(sftp.FOLDER + "/testing.txt", "r") as f: data = f.read(20) - assert data == b'hello kiddy.\n' + assert data == b"hello kiddy.\n" finally: - sftp.remove(sftp.FOLDER + '/testing.txt') + sftp.remove(sftp.FOLDER + "/testing.txt") def test_C_symlink(self, sftp): """ @@ -401,39 +405,41 @@ class TestSFTP(object): return try: - with sftp.open(sftp.FOLDER + '/original.txt', 'w') as f: - f.write('original\n') - sftp.symlink('original.txt', sftp.FOLDER + '/link.txt') - assert sftp.readlink(sftp.FOLDER + '/link.txt') == 'original.txt' + with sftp.open(sftp.FOLDER + "/original.txt", "w") as f: + f.write("original\n") + sftp.symlink("original.txt", sftp.FOLDER + "/link.txt") + assert sftp.readlink(sftp.FOLDER + "/link.txt") == "original.txt" - with sftp.open(sftp.FOLDER + '/link.txt', 'r') as f: - assert f.readlines() == ['original\n'] + with sftp.open(sftp.FOLDER + "/link.txt", "r") as f: + assert f.readlines() == ["original\n"] - cwd = sftp.normalize('.') - if cwd[-1] == '/': + cwd = sftp.normalize(".") + if cwd[-1] == "/": cwd = cwd[:-1] - abs_path = cwd + '/' + sftp.FOLDER + '/original.txt' - sftp.symlink(abs_path, sftp.FOLDER + '/link2.txt') - assert abs_path == sftp.readlink(sftp.FOLDER + '/link2.txt') + abs_path = cwd + "/" + sftp.FOLDER + "/original.txt" + sftp.symlink(abs_path, sftp.FOLDER + "/link2.txt") + assert abs_path == sftp.readlink(sftp.FOLDER + "/link2.txt") - assert sftp.lstat(sftp.FOLDER + '/link.txt').st_size == 12 - assert sftp.stat(sftp.FOLDER + '/link.txt').st_size == 9 + assert sftp.lstat(sftp.FOLDER + "/link.txt").st_size == 12 + assert sftp.stat(sftp.FOLDER + "/link.txt").st_size == 9 # the sftp server may be hiding extra path members from us, so the # length may be longer than we expect: - assert sftp.lstat(sftp.FOLDER + '/link2.txt').st_size >= len(abs_path) - assert sftp.stat(sftp.FOLDER + '/link2.txt').st_size == 9 - assert sftp.stat(sftp.FOLDER + '/original.txt').st_size == 9 + assert sftp.lstat(sftp.FOLDER + "/link2.txt").st_size >= len( + abs_path + ) + assert sftp.stat(sftp.FOLDER + "/link2.txt").st_size == 9 + assert sftp.stat(sftp.FOLDER + "/original.txt").st_size == 9 finally: try: - sftp.remove(sftp.FOLDER + '/link.txt') + sftp.remove(sftp.FOLDER + "/link.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/link2.txt') + sftp.remove(sftp.FOLDER + "/link2.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/original.txt') + sftp.remove(sftp.FOLDER + "/original.txt") except: pass @@ -442,18 +448,18 @@ class TestSFTP(object): verify that buffered writes are automatically flushed on seek. """ try: - with sftp.open(sftp.FOLDER + '/happy.txt', 'w', 1) as f: - f.write('full line.\n') - f.write('partial') + with sftp.open(sftp.FOLDER + "/happy.txt", "w", 1) as f: + f.write("full line.\n") + f.write("partial") f.seek(9, f.SEEK_SET) - f.write('?\n') + f.write("?\n") - with sftp.open(sftp.FOLDER + '/happy.txt', 'r') as f: - assert f.readline() == u('full line?\n') - assert f.read(7) == b'partial' + with sftp.open(sftp.FOLDER + "/happy.txt", "r") as f: + assert f.readline() == u("full line?\n") + assert f.read(7) == b"partial" finally: try: - sftp.remove(sftp.FOLDER + '/happy.txt') + sftp.remove(sftp.FOLDER + "/happy.txt") except: pass @@ -462,9 +468,9 @@ class TestSFTP(object): test that realpath is returning something non-empty and not an error. """ - pwd = sftp.normalize('.') + pwd = sftp.normalize(".") assert len(pwd) > 0 - f = sftp.normalize('./' + sftp.FOLDER) + f = sftp.normalize("./" + sftp.FOLDER) assert len(f) > 0 assert os.path.join(pwd, sftp.FOLDER) == f @@ -472,46 +478,46 @@ class TestSFTP(object): """ verify that mkdir/rmdir work. """ - sftp.mkdir(sftp.FOLDER + '/subfolder') - with pytest.raises(IOError): # generic msg only - sftp.mkdir(sftp.FOLDER + '/subfolder') - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.mkdir(sftp.FOLDER + "/subfolder") + with pytest.raises(IOError): # generic msg only + sftp.mkdir(sftp.FOLDER + "/subfolder") + sftp.rmdir(sftp.FOLDER + "/subfolder") with pytest.raises(IOError, match="No such file"): - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.rmdir(sftp.FOLDER + "/subfolder") def test_G_chdir(self, sftp): """ verify that chdir/getcwd work. """ - root = sftp.normalize('.') - if root[-1] != '/': - root += '/' + root = sftp.normalize(".") + if root[-1] != "/": + root += "/" try: - sftp.mkdir(sftp.FOLDER + '/alpha') - sftp.chdir(sftp.FOLDER + '/alpha') - sftp.mkdir('beta') - assert root + sftp.FOLDER + '/alpha' == sftp.getcwd() - assert ['beta'] == sftp.listdir('.') - - sftp.chdir('beta') - with sftp.open('fish', 'w') as f: - f.write('hello\n') - sftp.chdir('..') - assert ['fish'] == sftp.listdir('beta') - sftp.chdir('..') - assert ['fish'] == sftp.listdir('alpha/beta') + sftp.mkdir(sftp.FOLDER + "/alpha") + sftp.chdir(sftp.FOLDER + "/alpha") + sftp.mkdir("beta") + assert root + sftp.FOLDER + "/alpha" == sftp.getcwd() + assert ["beta"] == sftp.listdir(".") + + sftp.chdir("beta") + with sftp.open("fish", "w") as f: + f.write("hello\n") + sftp.chdir("..") + assert ["fish"] == sftp.listdir("beta") + sftp.chdir("..") + assert ["fish"] == sftp.listdir("alpha/beta") finally: sftp.chdir(root) try: - sftp.unlink(sftp.FOLDER + '/alpha/beta/fish') + sftp.unlink(sftp.FOLDER + "/alpha/beta/fish") except: pass try: - sftp.rmdir(sftp.FOLDER + '/alpha/beta') + sftp.rmdir(sftp.FOLDER + "/alpha/beta") except: pass try: - sftp.rmdir(sftp.FOLDER + '/alpha') + sftp.rmdir(sftp.FOLDER + "/alpha") except: pass @@ -519,20 +525,21 @@ class TestSFTP(object): """ verify that get/put work. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") fd, localname = mkstemp() os.close(fd) - text = b'All I wanted was a plastic bunny rabbit.\n' - with open(localname, 'wb') as f: + text = b"All I wanted was a plastic bunny rabbit.\n" + with open(localname, "wb") as f: f.write(text) saved_progress = [] def progress_callback(x, y): saved_progress.append((x, y)) - sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback) - with sftp.open(sftp.FOLDER + '/bunny.txt', 'rb') as f: + sftp.put(localname, sftp.FOLDER + "/bunny.txt", progress_callback) + + with sftp.open(sftp.FOLDER + "/bunny.txt", "rb") as f: assert text == f.read(128) assert [(41, 41)] == saved_progress @@ -540,14 +547,14 @@ class TestSFTP(object): fd, localname = mkstemp() os.close(fd) saved_progress = [] - sftp.get(sftp.FOLDER + '/bunny.txt', localname, progress_callback) + sftp.get(sftp.FOLDER + "/bunny.txt", localname, progress_callback) - with open(localname, 'rb') as f: + with open(localname, "rb") as f: assert text == f.read(128) assert [(41, 41)] == saved_progress os.unlink(localname) - sftp.unlink(sftp.FOLDER + '/bunny.txt') + sftp.unlink(sftp.FOLDER + "/bunny.txt") def test_I_check(self, sftp): """ @@ -555,118 +562,132 @@ class TestSFTP(object): (it's an sftp extension that we support, and may be the only ones who support it.) """ - with sftp.open(sftp.FOLDER + '/kitty.txt', 'w') as f: - f.write('here kitty kitty' * 64) + with sftp.open(sftp.FOLDER + "/kitty.txt", "w") as f: + f.write("here kitty kitty" * 64) try: - with sftp.open(sftp.FOLDER + '/kitty.txt', 'r') as f: - sum = f.check('sha1') - assert '91059CFC6615941378D413CB5ADAF4C5EB293402' == u(hexlify(sum)).upper() - sum = f.check('md5', 0, 512) - assert '93DE4788FCA28D471516963A1FE3856A' == u(hexlify(sum)).upper() - sum = f.check('md5', 0, 0, 510) - assert u(hexlify(sum)).upper() == 'EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6' # noqa + with sftp.open(sftp.FOLDER + "/kitty.txt", "r") as f: + sum = f.check("sha1") + assert ( + "91059CFC6615941378D413CB5ADAF4C5EB293402" + == u(hexlify(sum)).upper() + ) + sum = f.check("md5", 0, 512) + assert ( + "93DE4788FCA28D471516963A1FE3856A" + == u(hexlify(sum)).upper() + ) + sum = f.check("md5", 0, 0, 510) + assert ( + u(hexlify(sum)).upper() + == "EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6" + ) # noqa finally: - sftp.unlink(sftp.FOLDER + '/kitty.txt') + sftp.unlink(sftp.FOLDER + "/kitty.txt") def test_J_x_flag(self, sftp): """ verify that the 'x' flag works when opening a file. """ - sftp.open(sftp.FOLDER + '/unusual.txt', 'wx').close() + sftp.open(sftp.FOLDER + "/unusual.txt", "wx").close() try: try: - sftp.open(sftp.FOLDER + '/unusual.txt', 'wx') - self.fail('expected exception') + sftp.open(sftp.FOLDER + "/unusual.txt", "wx") + self.fail("expected exception") except IOError: pass finally: - sftp.unlink(sftp.FOLDER + '/unusual.txt') + sftp.unlink(sftp.FOLDER + "/unusual.txt") def test_K_utf8(self, sftp): """ verify that unicode strings are encoded into utf8 correctly. """ - with sftp.open(sftp.FOLDER + '/something', 'w') as f: - f.write('okay') + with sftp.open(sftp.FOLDER + "/something", "w") as f: + f.write("okay") try: - sftp.rename(sftp.FOLDER + '/something', sftp.FOLDER + '/' + unicode_folder) - sftp.open(b(sftp.FOLDER) + utf8_folder, 'r') + sftp.rename( + sftp.FOLDER + "/something", sftp.FOLDER + "/" + unicode_folder + ) + sftp.open(b(sftp.FOLDER) + utf8_folder, "r") except Exception as e: - self.fail('exception ' + str(e)) + self.fail("exception " + str(e)) sftp.unlink(b(sftp.FOLDER) + utf8_folder) def test_L_utf8_chdir(self, sftp): - sftp.mkdir(sftp.FOLDER + '/' + unicode_folder) + sftp.mkdir(sftp.FOLDER + "/" + unicode_folder) try: - sftp.chdir(sftp.FOLDER + '/' + unicode_folder) - with sftp.open('something', 'w') as f: - f.write('okay') - sftp.unlink('something') + sftp.chdir(sftp.FOLDER + "/" + unicode_folder) + with sftp.open("something", "w") as f: + f.write("okay") + sftp.unlink("something") finally: sftp.chdir() - sftp.rmdir(sftp.FOLDER + '/' + unicode_folder) + sftp.rmdir(sftp.FOLDER + "/" + unicode_folder) def test_M_bad_readv(self, sftp): """ verify that readv at the end of the file doesn't essplode. """ - sftp.open(sftp.FOLDER + '/zero', 'w').close() + sftp.open(sftp.FOLDER + "/zero", "w").close() try: - with sftp.open(sftp.FOLDER + '/zero', 'r') as f: + with sftp.open(sftp.FOLDER + "/zero", "r") as f: f.readv([(0, 12)]) - with sftp.open(sftp.FOLDER + '/zero', 'r') as f: + with sftp.open(sftp.FOLDER + "/zero", "r") as f: file_size = f.stat().st_size f.prefetch(file_size) f.read(100) finally: - sftp.unlink(sftp.FOLDER + '/zero') + sftp.unlink(sftp.FOLDER + "/zero") def test_N_put_without_confirm(self, sftp): """ verify that get/put work without confirmation. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") fd, localname = mkstemp() os.close(fd) - text = b'All I wanted was a plastic bunny rabbit.\n' - with open(localname, 'wb') as f: + text = b"All I wanted was a plastic bunny rabbit.\n" + with open(localname, "wb") as f: f.write(text) saved_progress = [] def progress_callback(x, y): saved_progress.append((x, y)) - res = sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback, False) + + res = sftp.put( + localname, sftp.FOLDER + "/bunny.txt", progress_callback, False + ) assert SFTPAttributes().attr == res.attr - with sftp.open(sftp.FOLDER + '/bunny.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/bunny.txt", "r") as f: assert text == f.read(128) assert (41, 41) == saved_progress[-1] os.unlink(localname) - sftp.unlink(sftp.FOLDER + '/bunny.txt') + sftp.unlink(sftp.FOLDER + "/bunny.txt") def test_O_getcwd(self, sftp): """ verify that chdir/getcwd work. """ assert sftp.getcwd() == None - root = sftp.normalize('.') - if root[-1] != '/': - root += '/' + root = sftp.normalize(".") + if root[-1] != "/": + root += "/" try: - sftp.mkdir(sftp.FOLDER + '/alpha') - sftp.chdir(sftp.FOLDER + '/alpha') - assert sftp.getcwd() == '/' + sftp.FOLDER + '/alpha' + sftp.mkdir(sftp.FOLDER + "/alpha") + sftp.chdir(sftp.FOLDER + "/alpha") + assert sftp.getcwd() == "/" + sftp.FOLDER + "/alpha" finally: sftp.chdir(root) try: - sftp.rmdir(sftp.FOLDER + '/alpha') + sftp.rmdir(sftp.FOLDER + "/alpha") except: pass @@ -677,24 +698,24 @@ class TestSFTP(object): does not work except through paramiko. :( openssh fails. """ try: - with sftp.open(sftp.FOLDER + '/append.txt', 'a') as f: - f.write('first line\nsecond line\n') + with sftp.open(sftp.FOLDER + "/append.txt", "a") as f: + f.write("first line\nsecond line\n") f.seek(11, f.SEEK_SET) - f.write('third line\n') + f.write("third line\n") - with sftp.open(sftp.FOLDER + '/append.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/append.txt", "r") as f: assert f.stat().st_size == 34 - assert f.readline() == 'first line\n' - assert f.readline() == 'second line\n' - assert f.readline() == 'third line\n' + assert f.readline() == "first line\n" + assert f.readline() == "second line\n" + assert f.readline() == "third line\n" finally: - sftp.remove(sftp.FOLDER + '/append.txt') + sftp.remove(sftp.FOLDER + "/append.txt") def test_putfo_empty_file(self, sftp): """ Send an empty file and confirm it is sent. """ - target = sftp.FOLDER + '/empty file.txt' + target = sftp.FOLDER + "/empty file.txt" stream = StringIO() try: attrs = sftp.putfo(stream, target) @@ -713,59 +734,61 @@ class TestSFTP(object): verify that we can create a file with a '%' in the filename. ( it needs to be properly escaped by _log() ) """ - f = sftp.open(sftp.FOLDER + '/test%file', 'w') + f = sftp.open(sftp.FOLDER + "/test%file", "w") try: assert f.stat().st_size == 0 finally: f.close() - sftp.remove(sftp.FOLDER + '/test%file') + sftp.remove(sftp.FOLDER + "/test%file") def test_O_non_utf8_data(self, sftp): """Test write() and read() of non utf8 data""" try: - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "w") as f: f.write(NON_UTF8_DATA) - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "r") as f: data = f.read() assert data == NON_UTF8_DATA - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "wb") as f: f.write(NON_UTF8_DATA) - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "rb") as f: data = f.read() assert data == NON_UTF8_DATA finally: - sftp.remove('%s/nonutf8data' % sftp.FOLDER) - + sftp.remove("%s/nonutf8data" % sftp.FOLDER) def test_sftp_attributes_empty_str(self, sftp): sftp_attributes = SFTPAttributes() - assert str(sftp_attributes) == "?--------- 1 0 0 0 (unknown date) ?" + assert ( + str(sftp_attributes) + == "?--------- 1 0 0 0 (unknown date) ?" + ) - @needs_builtin('buffer') + @needs_builtin("buffer") def test_write_buffer(self, sftp): """Test write() using a buffer instance.""" - data = 3 * b'A potentially large block of data to chunk up.\n' + data = 3 * b"A potentially large block of data to chunk up.\n" try: - with sftp.open('%s/write_buffer' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/write_buffer" % sftp.FOLDER, "wb") as f: for offset in range(0, len(data), 8): f.write(buffer(data, offset, 8)) - with sftp.open('%s/write_buffer' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/write_buffer" % sftp.FOLDER, "rb") as f: assert f.read() == data finally: - sftp.remove('%s/write_buffer' % sftp.FOLDER) + sftp.remove("%s/write_buffer" % sftp.FOLDER) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_memoryview(self, sftp): """Test write() using a memoryview instance.""" - data = 3 * b'A potentially large block of data to chunk up.\n' + data = 3 * b"A potentially large block of data to chunk up.\n" try: - with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/write_memoryview" % sftp.FOLDER, "wb") as f: view = memoryview(data) for offset in range(0, len(data), 8): - f.write(view[offset:offset+8]) + f.write(view[offset : offset + 8]) - with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/write_memoryview" % sftp.FOLDER, "rb") as f: assert f.read() == data finally: - sftp.remove('%s/write_memoryview' % sftp.FOLDER) + sftp.remove("%s/write_memoryview" % sftp.FOLDER) diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py index a659098d..97c0eb90 100644 --- a/tests/test_sftp_big.py +++ b/tests/test_sftp_big.py @@ -37,6 +37,7 @@ from .util import slow @slow class TestBigSFTP(object): + def test_1_lots_of_files(self, sftp): """ create a bunch of files over the same session. @@ -44,22 +45,24 @@ class TestBigSFTP(object): numfiles = 100 try: for i in range(numfiles): - with sftp.open('%s/file%d.txt' % (sftp.FOLDER, i), 'w', 1) as f: - f.write('this is file #%d.\n' % i) - sftp.chmod('%s/file%d.txt' % (sftp.FOLDER, i), o660) + with sftp.open( + "%s/file%d.txt" % (sftp.FOLDER, i), "w", 1 + ) as f: + f.write("this is file #%d.\n" % i) + sftp.chmod("%s/file%d.txt" % (sftp.FOLDER, i), o660) # now make sure every file is there, by creating a list of filenmes # and reading them in random order. numlist = list(range(numfiles)) while len(numlist) > 0: r = numlist[random.randint(0, len(numlist) - 1)] - with sftp.open('%s/file%d.txt' % (sftp.FOLDER, r)) as f: - assert f.readline() == 'this is file #%d.\n' % r + with sftp.open("%s/file%d.txt" % (sftp.FOLDER, r)) as f: + assert f.readline() == "this is file #%d.\n" % r numlist.remove(r) finally: for i in range(numfiles): try: - sftp.remove('%s/file%d.txt' % (sftp.FOLDER, i)) + sftp.remove("%s/file%d.txt" % (sftp.FOLDER, i)) except: pass @@ -67,52 +70,56 @@ class TestBigSFTP(object): """ write a 1MB file with no buffering. """ - kblob = (1024 * b'x') + kblob = 1024 * b"x" start = time.time() try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f: for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - + sys.stderr.write("%ds " % round(end - start)) + start = time.time() - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: for n in range(1024): data = f.read(1024) assert data == kblob end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_3_big_file_pipelined(self, sftp): """ write a 1MB file, with no linefeeds, using pipelining. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) start = time.time() try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - + sys.stderr.write("%ds " % round(end - start)) + start = time.time() - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) @@ -126,35 +133,39 @@ class TestBigSFTP(object): chunk = size - n data = f.read(chunk) offset = n % 1024 - assert data == k2blob[offset:offset + chunk] + assert data == k2blob[offset : offset + chunk] n += chunk end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_4_prefetch_seek(self, sftp): - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') - - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 - + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) + start = time.time() k2blob = kblob + kblob chunk = 793 for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) - base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + base_offset = (512 * 1024) + 17 * random.randint( + 1000, 2000 + ) offsets = [base_offset + j * chunk for j in range(100)] # randomly seek around and read them out for j in range(100): @@ -163,32 +174,36 @@ class TestBigSFTP(object): f.seek(offset) data = f.read(chunk) n_offset = offset % 1024 - assert data == k2blob[n_offset:n_offset + chunk] + assert data == k2blob[n_offset : n_offset + chunk] offset += chunk end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_5_readv_seek(self, sftp): - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) start = time.time() k2blob = kblob + kblob chunk = 793 for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: - base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: + base_offset = (512 * 1024) + 17 * random.randint( + 1000, 2000 + ) # make a bunch of offsets and put them in random order offsets = [base_offset + j * chunk for j in range(100)] readv_list = [] @@ -200,62 +215,66 @@ class TestBigSFTP(object): for i in range(len(readv_list)): offset = readv_list[i][0] n_offset = offset % 1024 - assert next(ret) == k2blob[n_offset:n_offset + chunk] + assert next(ret) == k2blob[n_offset : n_offset + chunk] end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_6_lots_of_prefetching(self, sftp): """ prefetch a 1MB file a bunch of times, discarding the file object without using it, to verify that paramiko doesn't get confused. """ - kblob = (1024 * b'x') + kblob = 1024 * b"x" try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: file_size = f.stat().st_size f.prefetch(file_size) - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: file_size = f.stat().st_size f.prefetch(file_size) for n in range(1024): data = f.read(1024) assert data == kblob if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_7_prefetch_readv(self, sftp): """ verify that prefetch and readv don't conflict with each other. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') - - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) data = f.read(1024) @@ -264,79 +283,94 @@ class TestBigSFTP(object): chunk_size = 793 base_offset = 512 * 1024 k2blob = kblob + kblob - chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)] + chunks = [ + (base_offset + (chunk_size * i), chunk_size) + for i in range(20) + ] for data in f.readv(chunks): offset = base_offset % 1024 assert chunk_size == len(data) - assert k2blob[offset:offset + chunk_size] == data + assert k2blob[offset : offset + chunk_size] == data base_offset += chunk_size - sys.stderr.write(' ') + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_8_large_readv(self, sftp): """ verify that a very large readv is broken up correctly and still returned as a single blob. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 - - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: data = list(f.readv([(23 * 1024, 128 * 1024)])) assert len(data) == 1 data = data[0] assert len(data) == 128 * 1024 - - sys.stderr.write(' ') + + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_9_big_file_big_buffer(self, sftp): """ write a 1MB file, with no linefeeds, and a big buffer. """ - mblob = (1024 * 1024 * 'x') + mblob = 1024 * 1024 * "x" try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024 + ) as f: f.write(mblob) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_A_big_file_renegotiate(self, sftp): """ write a 1MB file, forcing key renegotiation in the middle. """ t = sftp.sock.get_transport() t.packetizer.REKEY_BYTES = 512 * 1024 - k32blob = (32 * 1024 * 'x') + k32blob = 32 * 1024 * "x" try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024 + ) as f: for i in range(32): f.write(k32blob) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) assert t.H != t.session_id - + # try to read it too. - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "r", 128 * 1024 + ) as f: file_size = f.stat().st_size f.prefetch(file_size) total = 0 while total < 1024 * 1024: total += len(f.read(32 * 1024)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) t.packetizer.REKEY_BYTES = pow(2, 30) diff --git a/tests/test_ssh_exception.py b/tests/test_ssh_exception.py index 18f2a97d..6cc5d06a 100644 --- a/tests/test_ssh_exception.py +++ b/tests/test_ssh_exception.py @@ -4,28 +4,33 @@ import unittest from paramiko.ssh_exception import NoValidConnectionsError -class NoValidConnectionsErrorTest (unittest.TestCase): +class NoValidConnectionsErrorTest(unittest.TestCase): def test_pickling(self): # Regression test for https://github.com/paramiko/paramiko/issues/617 - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()}) + exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()}) new_exc = pickle.loads(pickle.dumps(exc)) self.assertEqual(type(exc), type(new_exc)) self.assertEqual(str(exc), str(new_exc)) self.assertEqual(exc.args, new_exc.args) def test_error_message_for_single_host(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()}) + exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()}) assert "Unable to connect to port 22 on 127.0.0.1" in str(exc) def test_error_message_for_two_hosts(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(), - ('::1', '22'): Exception()}) + exc = NoValidConnectionsError( + {("127.0.0.1", "22"): Exception(), ("::1", "22"): Exception()} + ) assert "Unable to connect to port 22 on 127.0.0.1 or ::1" in str(exc) def test_error_message_for_multiple_hosts(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(), - ('::1', '22'): Exception(), - ('10.0.0.42', '22'): Exception()}) + exc = NoValidConnectionsError( + { + ("127.0.0.1", "22"): Exception(), + ("::1", "22"): Exception(), + ("10.0.0.42", "22"): Exception(), + } + ) exp = "Unable to connect to port 22 on 10.0.0.42, 127.0.0.1 or ::1" assert exp in str(exc) diff --git a/tests/test_ssh_gss.py b/tests/test_ssh_gss.py index f0645e0e..cee6ce89 100644 --- a/tests/test_ssh_gss.py +++ b/tests/test_ssh_gss.py @@ -33,15 +33,13 @@ from .util import _support, needs_gssapi from .test_client import FINGERPRINTS -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): + def get_allowed_auths(self, username): - return 'gssapi-with-mic,publickey' + return "gssapi-with-mic,publickey" def check_auth_gssapi_with_mic( - self, - username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None, + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL @@ -64,13 +62,14 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != "yes": return False return True @needs_gssapi class GSSAuthTest(unittest.TestCase): + def setUp(self): # TODO: username and targ_name should come from os.environ or whatever # the approved pytest method is for runtime-configuring test data. @@ -92,7 +91,7 @@ class GSSAuthTest(unittest.TestCase): def _run(self): self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") self.ts.add_server_key(host_key) server = NullServer() self.ts.start_server(self.event, server) @@ -103,15 +102,22 @@ class GSSAuthTest(unittest.TestCase): The exception is ... no exception yet """ - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), - 'ssh-rsa', public_host_key) - self.tc.connect(hostname=self.addr, port=self.port, username=self.username, gss_host=self.hostname, - gss_auth=True, **kwargs) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) + self.tc.connect( + hostname=self.addr, + port=self.port, + username=self.username, + gss_host=self.hostname, + gss_auth=True, + **kwargs + ) self.event.wait(1.0) self.assert_(self.event.is_set()) @@ -119,17 +125,17 @@ class GSSAuthTest(unittest.TestCase): self.assertEquals(self.username, self.ts.get_username()) self.assertEquals(True, self.ts.is_authenticated()) - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEquals('Hello there.\n', stdout.readline()) - self.assertEquals('', stdout.readline()) - self.assertEquals('This is on stderr.\n', stderr.readline()) - self.assertEquals('', stderr.readline()) + self.assertEquals("Hello there.\n", stdout.readline()) + self.assertEquals("", stdout.readline()) + self.assertEquals("This is on stderr.\n", stderr.readline()) + self.assertEquals("", stderr.readline()) stdin.close() stdout.close() @@ -140,14 +146,17 @@ class GSSAuthTest(unittest.TestCase): Verify that Paramiko can handle SSHv2 GSS-API / SSPI authentication (gssapi-with-mic) in client and server mode. """ - self._test_connection(allow_agent=False, - look_for_keys=False) + self._test_connection(allow_agent=False, look_for_keys=False) def test_2_auth_trickledown(self): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ - self.hostname = "this_host_does_not_exists_and_causes_a_GSSAPI-exception" - self._test_connection(key_filename=[_support('test_rsa.key')], - allow_agent=False, - look_for_keys=False) + self.hostname = ( + "this_host_does_not_exists_and_causes_a_GSSAPI-exception" + ) + self._test_connection( + key_filename=[_support("test_rsa.key")], + allow_agent=False, + look_for_keys=False, + ) diff --git a/tests/test_transport.py b/tests/test_transport.py index 9474acfc..c05d6781 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -32,14 +32,26 @@ from hashlib import sha1 import unittest from paramiko import ( - Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, SSHException, - ChannelException, Packetizer, Channel, + Transport, + SecurityOptions, + ServerInterface, + RSAKey, + DSSKey, + SSHException, + ChannelException, + Packetizer, + Channel, ) from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko.common import ( - MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST, MIN_PACKET_SIZE, MIN_WINDOW_SIZE, - MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, + MSG_KEXINIT, + cMSG_CHANNEL_WINDOW_ADJUST, + MIN_PACKET_SIZE, + MIN_WINDOW_SIZE, + MAX_WINDOW_SIZE, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_PACKET_SIZE, ) from paramiko.py3compat import bytes from paramiko.message import Message @@ -61,28 +73,28 @@ Maybe. """ -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL return AUTH_FAILED def check_channel_request(self, kind, chanid): - if kind == 'bogus': + if kind == "bogus": return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True @@ -95,9 +107,16 @@ class NullServer (ServerInterface): # tho that's only supposed to occur if the request cannot be served. # For now, leaving that the default unless test supplies specific # 'acceptable' request kind - return kind == 'acceptable' - - def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + return kind == "acceptable" + + def check_channel_x11_request( + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): self._x11_single_connection = single_connection self._x11_auth_protocol = auth_protocol self._x11_auth_cookie = auth_cookie @@ -106,7 +125,7 @@ class NullServer (ServerInterface): def check_port_forward_request(self, addr, port): self._listen = socket.socket() - self._listen.bind(('127.0.0.1', 0)) + self._listen.bind(("127.0.0.1", 0)) self._listen.listen(1) return self._listen.getsockname()[1] @@ -120,6 +139,7 @@ class NullServer (ServerInterface): class TransportTest(unittest.TestCase): + def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -134,9 +154,9 @@ class TransportTest(unittest.TestCase): self.sockc.close() def setup_test_server( - self, client_options=None, server_options=None, connect_kwargs=None, + self, client_options=None, server_options=None, connect_kwargs=None ): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) @@ -152,8 +172,8 @@ class TransportTest(unittest.TestCase): if connect_kwargs is None: connect_kwargs = dict( hostkey=public_host_key, - username='slowdive', - password='pygmalion', + username="slowdive", + password="pygmalion", ) self.tc.connect(**connect_kwargs) event.wait(1.0) @@ -163,11 +183,11 @@ class TransportTest(unittest.TestCase): def test_1_security_options(self): o = self.tc.get_security_options() self.assertEqual(type(o), SecurityOptions) - self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers) - o.ciphers = ('aes256-cbc', 'blowfish-cbc') - self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers) + self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers) + o.ciphers = ("aes256-cbc", "blowfish-cbc") + self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers) try: - o.ciphers = ('aes256-cbc', 'made-up-cipher') + o.ciphers = ("aes256-cbc", "made-up-cipher") self.assertTrue(False) except ValueError: pass @@ -187,12 +207,18 @@ class TransportTest(unittest.TestCase): o.compression = o.compression def test_2_compute_key(self): - self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 - self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3' + self.tc.K = ( + 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 + ) + self.tc.H = ( + b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3" + ) self.tc.session_id = self.tc.H - key = self.tc._compute_key('C', 32) - self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', - hexlify(key).upper()) + key = self.tc._compute_key("C", 32) + self.assertEqual( + b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995", + hexlify(key).upper(), + ) def test_3_simple(self): """ @@ -200,7 +226,7 @@ class TransportTest(unittest.TestCase): loopback sockets. this is hardly "simple" but it's simpler than the later tests. :) """ - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -211,13 +237,14 @@ class TransportTest(unittest.TestCase): self.assertEqual(False, self.tc.is_authenticated()) self.assertEqual(False, self.ts.is_authenticated()) self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.tc.get_username()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.tc.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.tc.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated()) @@ -225,7 +252,7 @@ class TransportTest(unittest.TestCase): """ verify that a long banner doesn't mess up the handshake. """ - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -233,8 +260,9 @@ class TransportTest(unittest.TestCase): self.assertTrue(not event.is_set()) self.socks.send(LONG_BANNER) self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) @@ -244,12 +272,14 @@ class TransportTest(unittest.TestCase): verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ + def force_algorithms(options): - options.ciphers = ('aes256-cbc',) - options.digests = ('hmac-md5-96',) + options.ciphers = ("aes256-cbc",) + options.digests = ("hmac-md5-96",) + self.setup_test_server(client_options=force_algorithms) - self.assertEqual('aes256-cbc', self.tc.local_cipher) - self.assertEqual('aes256-cbc', self.tc.remote_cipher) + self.assertEqual("aes256-cbc", self.tc.local_cipher) + self.assertEqual("aes256-cbc", self.tc.remote_cipher) self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) @@ -263,10 +293,10 @@ class TransportTest(unittest.TestCase): verify that the keepalive will be sent. """ self.setup_test_server() - self.assertEqual(None, getattr(self.server, '_global_request', None)) + self.assertEqual(None, getattr(self.server, "_global_request", None)) self.tc.set_keepalive(1) time.sleep(2) - self.assertEqual('keepalive@lag.net', self.server._global_request) + self.assertEqual("keepalive@lag.net", self.server._global_request) def test_6_exec_command(self): """ @@ -277,39 +307,41 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() schan = self.ts.accept(1.0) try: - chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string') + chan.exec_command( + b"command contains \xfc and is not a valid UTF-8 string" + ) self.assertTrue(False) except SSHException: pass chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) f = chan.makefile_stderr() - self.assertEqual('This is on stderr.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) # now try it with combined stdout/stderr chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() chan.set_combine_stderr(True) f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('This is on stderr.\n', f.readline()) - self.assertEqual('', f.readline()) - + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) + def test_6a_channel_can_be_used_as_context_manager(self): """ verify that exec_command() does something reasonable. @@ -318,13 +350,13 @@ class TransportTest(unittest.TestCase): with self.tc.open_session() as chan: with self.ts.accept(1.0) as schan: - chan.exec_command('yes') - schan.send('Hello there.\n') + chan.exec_command("yes") + schan.send("Hello there.\n") schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) def test_7_invoke_shell(self): """ @@ -334,11 +366,11 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) - chan.send('communist j. cat\n') + chan.send("communist j. cat\n") f = schan.makefile() - self.assertEqual('communist j. cat\n', f.readline()) + self.assertEqual("communist j. cat\n", f.readline()) chan.close() - self.assertEqual('', f.readline()) + self.assertEqual("", f.readline()) def test_8_channel_exception(self): """ @@ -346,8 +378,8 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() try: - chan = self.tc.open_channel('bogus') - self.fail('expected exception') + chan = self.tc.open_channel("bogus") + self.fail("expected exception") except ChannelException as e: self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) @@ -359,8 +391,8 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() schan = self.ts.accept(1.0) - chan.exec_command('yes') - schan.send('Hello there.\n') + chan.exec_command("yes") + schan.send("Hello there.\n") self.assertTrue(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() @@ -369,8 +401,8 @@ class TransportTest(unittest.TestCase): schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) count = 0 while not chan.exit_status_ready(): time.sleep(0.1) @@ -395,7 +427,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - schan.send('hello\n') + schan.send("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): @@ -407,7 +439,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - self.assertEqual(b'hello\n', chan.recv(6)) + self.assertEqual(b"hello\n", chan.recv(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) @@ -442,12 +474,12 @@ class TransportTest(unittest.TestCase): self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) self.assertEqual(self.tc.H, self.tc.session_id) for i in range(20): - chan.send('x' * 1024) + chan.send("x" * 1024) chan.close() # allow a few seconds for the rekeying to complete @@ -463,18 +495,20 @@ class TransportTest(unittest.TestCase): """ verify that zlib compression is basically working. """ + def force_compression(o): - o.compression = ('zlib',) + o.compression = ("zlib",) + self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes - chan.send('x' * 1024) + chan.send("x" * 1024) bytes2 = self.tc.packetizer._Packetizer__sent_bytes - block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size'] - mac_size = self.tc._mac_info[self.tc.local_mac]['size'] + block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"] + mac_size = self.tc._mac_info[self.tc.local_mac]["size"] # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) self.assertTrue(bytes2 - bytes < 1024) self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) @@ -488,29 +522,32 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] + def handler(c, addr_port): addr, port = addr_port requested.append((addr, port)) self.tc._queue_incoming_channel(c) - self.assertEqual(None, getattr(self.server, '_x11_screen_number', None)) + self.assertEqual( + None, getattr(self.server, "_x11_screen_number", None) + ) cookie = chan.request_x11(0, single_connection=True, handler=handler) self.assertEqual(0, self.server._x11_screen_number) - self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) + self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol) self.assertEqual(cookie, self.server._x11_auth_cookie) self.assertEqual(True, self.server._x11_single_connection) - x11_server = self.ts.open_x11_channel(('localhost', 6093)) + x11_server = self.ts.open_x11_channel(("localhost", 6093)) x11_client = self.tc.accept() - self.assertEqual('localhost', requested[0][0]) + self.assertEqual("localhost", requested[0][0]) self.assertEqual(6093, requested[0][1]) - x11_server.send('hello') - self.assertEqual(b'hello', x11_client.recv(5)) + x11_server.send("hello") + self.assertEqual(b"hello", x11_client.recv(5)) x11_server.close() x11_client.close() @@ -524,33 +561,36 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] + def handler(c, origin_addr_port, server_addr_port): requested.append(origin_addr_port) requested.append(server_addr_port) self.tc._queue_incoming_channel(c) - port = self.tc.request_port_forward('127.0.0.1', 0, handler) + port = self.tc.request_port_forward("127.0.0.1", 0, handler) self.assertEqual(port, self.server._listen.getsockname()[1]) cs = socket.socket() - cs.connect(('127.0.0.1', port)) + cs.connect(("127.0.0.1", port)) ss, _ = self.server._listen.accept() - sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) + sch = self.ts.open_forwarded_tcpip_channel( + ss.getsockname(), ss.getpeername() + ) cch = self.tc.accept() - sch.send('hello') - self.assertEqual(b'hello', cch.recv(5)) + sch.send("hello") + self.assertEqual(b"hello", cch.recv(5)) sch.close() cch.close() ss.close() cs.close() # now cancel it. - self.tc.cancel_port_forward('127.0.0.1', port) + self.tc.cancel_port_forward("127.0.0.1", port) self.assertTrue(self.server._listen is None) def test_F_port_forwarding(self): @@ -560,27 +600,29 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) # open a port on the "server" that the client will ask to forward to. greeting_server = socket.socket() - greeting_server.bind(('127.0.0.1', 0)) + greeting_server.bind(("127.0.0.1", 0)) greeting_server.listen(1) greeting_port = greeting_server.getsockname()[1] - cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000)) + cs = self.tc.open_channel( + "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000) + ) sch = self.ts.accept(1.0) cch = socket.socket() cch.connect(self.server._tcpip_dest) ss, _ = greeting_server.accept() - ss.send(b'Hello!\n') + ss.send(b"Hello!\n") ss.close() sch.send(cch.recv(8192)) sch.close() - self.assertEqual(b'Hello!\n', cs.recv(7)) + self.assertEqual(b"Hello!\n", cs.recv(7)) cs.close() def test_G_stderr_select(self): @@ -599,7 +641,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - schan.send_stderr('hello\n') + schan.send_stderr("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): @@ -611,7 +653,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - self.assertEqual(b'hello\n', chan.recv_stderr(6)) + self.assertEqual(b"hello\n", chan.recv_stderr(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) @@ -633,8 +675,8 @@ class TransportTest(unittest.TestCase): self.assertEqual(chan.send_ready(), True) total = 0 - K = '*' * 1024 - limit = 1+(64 * 2 ** 15) + K = "*" * 1024 + limit = 1 + (64 * 2 ** 15) while total < limit: chan.send(K) total += len(K) @@ -696,8 +738,11 @@ class TransportTest(unittest.TestCase): # expires, a deadlock is assumed. class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): - threading.Thread.__init__(self, None, None, self.__class__.__name__) + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) self.setDaemon(True) self.chan = chan self.iterations = iterations @@ -707,19 +752,22 @@ class TransportTest(unittest.TestCase): def run(self): try: - for i in range(1, 1+self.iterations): + for i in range(1, 1 + self.iterations): if self.done_event.is_set(): break self.watchdog_event.set() - #print i, "SEND" + # print i, "SEND" self.chan.send("x" * 2048) finally: self.done_event.set() self.watchdog_event.set() class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): - threading.Thread.__init__(self, None, None, self.__class__.__name__) + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) self.setDaemon(True) self.chan = chan self.done_event = done_event @@ -742,30 +790,34 @@ class TransportTest(unittest.TestCase): self.ts.packetizer.REKEY_BYTES = 2048 chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) # Monkey patch the client's Transport._handler_table so that the client # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial # MSG_KEXINIT. This is used to simulate the effect of network latency # on a real MSG_CHANNEL_WINDOW_ADJUST message. - self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary + self.tc._handler_table = ( + self.tc._handler_table.copy() + ) # copy per-class dictionary _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + def _negotiate_keys_wrapper(self, m): - if self.local_kex_init is None: # Remote side sent KEXINIT + if self.local_kex_init is None: # Remote side sent KEXINIT # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # before responding to the incoming MSG_KEXINIT. m2 = Message() m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST) m2.add_int(chan.remote_chanid) - m2.add_int(1) # bytes to add + m2.add_int(1) # bytes to add self._send_message(m2) return _negotiate_keys(self, m) + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper # Parameters for the test - iterations = 500 # The deadlock does not happen every time, but it - # should after many iterations. + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. timeout = 5 # This event is set when the test is completed @@ -807,18 +859,22 @@ class TransportTest(unittest.TestCase): """ verify that we conform to the rfc of packet and window sizes. """ - for val, correct in [(4095, MIN_PACKET_SIZE), - (None, DEFAULT_MAX_PACKET_SIZE), - (2**32, MAX_WINDOW_SIZE)]: + for val, correct in [ + (4095, MIN_PACKET_SIZE), + (None, DEFAULT_MAX_PACKET_SIZE), + (2 ** 32, MAX_WINDOW_SIZE), + ]: self.assertEqual(self.tc._sanitize_packet_size(val), correct) def test_K_sanitze_window_size(self): """ verify that we conform to the rfc of packet and window sizes. """ - for val, correct in [(32767, MIN_WINDOW_SIZE), - (None, DEFAULT_WINDOW_SIZE), - (2**32, MAX_WINDOW_SIZE)]: + for val, correct in [ + (32767, MIN_WINDOW_SIZE), + (None, DEFAULT_WINDOW_SIZE), + (2 ** 32, MAX_WINDOW_SIZE), + ]: self.assertEqual(self.tc._sanitize_window_size(val), correct) @slow @@ -834,15 +890,17 @@ class TransportTest(unittest.TestCase): # (Doing this on the server's transport *sounds* more 'correct' but # actually doesn't work nearly as well for whatever reason.) class SlowPacketizer(Packetizer): + def read_message(self): time.sleep(1) return super(SlowPacketizer, self).read_message() + # NOTE: prettttty sure since the replaced .packetizer Packetizer is now # no longer doing anything with its copy of the socket...everything'll # be fine. Even tho it's a bit squicky. self.tc.packetizer = SlowPacketizer(self.tc.sock) # Continue with regular test red tape. - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -850,10 +908,13 @@ class TransportTest(unittest.TestCase): self.assertTrue(not event.is_set()) self.tc.handshake_timeout = 0.000000000001 self.ts.start_server(event, server) - self.assertRaises(EOFError, self.tc.connect, - hostkey=public_host_key, - username='slowdive', - password='pygmalion') + self.assertRaises( + EOFError, + self.tc.connect, + hostkey=public_host_key, + username="slowdive", + password="pygmalion", + ) def test_M_select_after_close(self): """ @@ -894,13 +955,13 @@ class TransportTest(unittest.TestCase): expected = text.encode("utf-8") self.assertEqual(sfile.read(len(expected)), expected) - @needs_builtin('buffer') + @needs_builtin("buffer") def test_channel_send_buffer(self): """ verify sending buffer instances to a channel """ self.setup_test_server() - data = 3 * b'some test data\n whole' + data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: @@ -917,13 +978,13 @@ class TransportTest(unittest.TestCase): chan.sendall(buffer(data)) self.assertEqual(sfile.read(len(data)), data) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_channel_send_memoryview(self): """ verify sending memoryview instances to a channel """ self.setup_test_server() - data = 3 * b'some test data\n whole' + data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: @@ -934,7 +995,7 @@ class TransportTest(unittest.TestCase): sent = 0 view = memoryview(data) while sent < len(view): - sent += chan.send(view[sent:sent+8]) + sent += chan.send(view[sent : sent + 8]) self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a memoryview instance @@ -954,7 +1015,7 @@ class TransportTest(unittest.TestCase): self.setup_test_server(connect_kwargs={}) # NOTE: this dummy global request kind would normally pass muster # from the test server. - self.tc.global_request('acceptable') + self.tc.global_request("acceptable") # Global requests never raise exceptions, even on failure (not sure why # this was the original design...ugh.) Best we can do to tell failure # happened is that the client transport's global_response was set back @@ -969,7 +1030,7 @@ class TransportTest(unittest.TestCase): # an exception on the client side, unlike the general case...) self.setup_test_server(connect_kwargs={}) try: - self.tc.request_port_forward('localhost', 1234) + self.tc.request_port_forward("localhost", 1234) except SSHException as e: assert "forwarding request denied" in str(e) else: diff --git a/tests/test_util.py b/tests/test_util.py index 90473f43..23b2e86a 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -67,53 +67,71 @@ from paramiko import * class UtilTest(unittest.TestCase): + def test_import(self): """ verify that all the classes can be imported from paramiko. """ symbols = list(globals().keys()) - self.assertTrue('Transport' in symbols) - self.assertTrue('SSHClient' in symbols) - self.assertTrue('MissingHostKeyPolicy' in symbols) - self.assertTrue('AutoAddPolicy' in symbols) - self.assertTrue('RejectPolicy' in symbols) - self.assertTrue('WarningPolicy' in symbols) - self.assertTrue('SecurityOptions' in symbols) - self.assertTrue('SubsystemHandler' in symbols) - self.assertTrue('Channel' in symbols) - self.assertTrue('RSAKey' in symbols) - self.assertTrue('DSSKey' in symbols) - self.assertTrue('Message' in symbols) - self.assertTrue('SSHException' in symbols) - self.assertTrue('AuthenticationException' in symbols) - self.assertTrue('PasswordRequiredException' in symbols) - self.assertTrue('BadAuthenticationType' in symbols) - self.assertTrue('ChannelException' in symbols) - self.assertTrue('SFTP' in symbols) - self.assertTrue('SFTPFile' in symbols) - self.assertTrue('SFTPHandle' in symbols) - self.assertTrue('SFTPClient' in symbols) - self.assertTrue('SFTPServer' in symbols) - self.assertTrue('SFTPError' in symbols) - self.assertTrue('SFTPAttributes' in symbols) - self.assertTrue('SFTPServerInterface' in symbols) - self.assertTrue('ServerInterface' in symbols) - self.assertTrue('BufferedFile' in symbols) - self.assertTrue('Agent' in symbols) - self.assertTrue('AgentKey' in symbols) - self.assertTrue('HostKeys' in symbols) - self.assertTrue('SSHConfig' in symbols) - self.assertTrue('util' in symbols) + self.assertTrue("Transport" in symbols) + self.assertTrue("SSHClient" in symbols) + self.assertTrue("MissingHostKeyPolicy" in symbols) + self.assertTrue("AutoAddPolicy" in symbols) + self.assertTrue("RejectPolicy" in symbols) + self.assertTrue("WarningPolicy" in symbols) + self.assertTrue("SecurityOptions" in symbols) + self.assertTrue("SubsystemHandler" in symbols) + self.assertTrue("Channel" in symbols) + self.assertTrue("RSAKey" in symbols) + self.assertTrue("DSSKey" in symbols) + self.assertTrue("Message" in symbols) + self.assertTrue("SSHException" in symbols) + self.assertTrue("AuthenticationException" in symbols) + self.assertTrue("PasswordRequiredException" in symbols) + self.assertTrue("BadAuthenticationType" in symbols) + self.assertTrue("ChannelException" in symbols) + self.assertTrue("SFTP" in symbols) + self.assertTrue("SFTPFile" in symbols) + self.assertTrue("SFTPHandle" in symbols) + self.assertTrue("SFTPClient" in symbols) + self.assertTrue("SFTPServer" in symbols) + self.assertTrue("SFTPError" in symbols) + self.assertTrue("SFTPAttributes" in symbols) + self.assertTrue("SFTPServerInterface" in symbols) + self.assertTrue("ServerInterface" in symbols) + self.assertTrue("BufferedFile" in symbols) + self.assertTrue("Agent" in symbols) + self.assertTrue("AgentKey" in symbols) + self.assertTrue("HostKeys" in symbols) + self.assertTrue("SSHConfig" in symbols) + self.assertTrue("util" in symbols) def test_parse_config(self): global test_config_file f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEqual(config._config, - [{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}}, - {'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}}, - {'host': ['*'], 'config': {'crazy': 'something dumb'}}, - {'host': ['spoo.example.com'], 'config': {'crazy': 'something else'}}]) + self.assertEqual( + config._config, + [ + {"host": ["*"], "config": {}}, + { + "host": ["*"], + "config": { + "identityfile": ["~/.ssh/id_rsa"], + "user": "robey", + }, + }, + { + "host": ["*.example.com"], + "config": {"user": "bjork", "port": "3333"}, + }, + {"host": ["*"], "config": {"crazy": "something dumb"}}, + { + "host": ["spoo.example.com"], + "config": {"crazy": "something else"}, + }, + ], + ) def test_host_config(self): global test_config_file @@ -121,44 +139,57 @@ class UtilTest(unittest.TestCase): config = paramiko.util.parse_ssh_config(f) for host, values in { - 'irc.danger.com': {'crazy': 'something dumb', - 'hostname': 'irc.danger.com', - 'user': 'robey'}, - 'irc.example.com': {'crazy': 'something dumb', - 'hostname': 'irc.example.com', - 'user': 'robey', - 'port': '3333'}, - 'spoo.example.com': {'crazy': 'something dumb', - 'hostname': 'spoo.example.com', - 'user': 'robey', - 'port': '3333'} + "irc.danger.com": { + "crazy": "something dumb", + "hostname": "irc.danger.com", + "user": "robey", + }, + "irc.example.com": { + "crazy": "something dumb", + "hostname": "irc.example.com", + "user": "robey", + "port": "3333", + }, + "spoo.example.com": { + "crazy": "something dumb", + "hostname": "spoo.example.com", + "user": "robey", + "port": "3333", + }, }.items(): - values = dict(values, + values = dict( + values, hostname=host, - identityfile=[os.path.expanduser("~/.ssh/id_rsa")] + identityfile=[os.path.expanduser("~/.ssh/id_rsa")], ) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_generate_key_bytes(self): - x = paramiko.util.generate_key_bytes(sha1, b'ABCDEFGH', 'This is my secret passphrase.', 64) - hex = ''.join(['%02x' % byte_ord(c) for c in x]) - self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b') + x = paramiko.util.generate_key_bytes( + sha1, b"ABCDEFGH", "This is my secret passphrase.", 64 + ) + hex = "".join(["%02x" % byte_ord(c) for c in x]) + self.assertEqual( + hex, + "9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b", + ) def test_host_keys(self): - with open('hostfile.temp', 'w') as f: + with open("hostfile.temp", "w") as f: f.write(test_hosts_file) try: - hostdict = paramiko.util.load_host_keys('hostfile.temp') + hostdict = paramiko.util.load_host_keys("hostfile.temp") self.assertEqual(2, len(hostdict)) self.assertEqual(1, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) finally: - os.unlink('hostfile.temp') + os.unlink("hostfile.temp") def test_host_config_expose_issue_33(self): test_config_file = """ @@ -173,36 +204,44 @@ Host * """ f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - host = 'www13.example.com' + host = "www13.example.com" self.assertEqual( paramiko.util.lookup_ssh_host_config(host, config), - {'hostname': host, 'port': '22'} + {"hostname": host, "port": "22"}, ) def test_eintr_retry(self): - self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo')) + self.assertEqual("foo", paramiko.util.retry_on_signal(lambda: "foo")) # Variables that are set by raises_intr intr_errors_remaining = [3] call_count = [0] + def raises_intr(): call_count[0] += 1 if intr_errors_remaining[0] > 0: intr_errors_remaining[0] -= 1 - raise IOError(errno.EINTR, 'file', 'interrupted system call') + raise IOError(errno.EINTR, "file", "interrupted system call") + self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None) self.assertEqual(0, intr_errors_remaining[0]) self.assertEqual(4, call_count[0]) def raises_ioerror_not_eintr(): - raise IOError(errno.ENOENT, 'file', 'file not found') - self.assertRaises(IOError, - lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr)) + raise IOError(errno.ENOENT, "file", "file not found") + + self.assertRaises( + IOError, + lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr), + ) def raises_other_exception(): - raise AssertionError('foo') - self.assertRaises(AssertionError, - lambda: paramiko.util.retry_on_signal(raises_other_exception)) + raise AssertionError("foo") + + self.assertRaises( + AssertionError, + lambda: paramiko.util.retry_on_signal(raises_other_exception), + ) def test_proxycommand_config_equals_parsing(self): """ @@ -217,17 +256,18 @@ Host equals-delimited """ f = StringIO(conf) config = paramiko.util.parse_ssh_config(f) - for host in ('space-delimited', 'equals-delimited'): + for host in ("space-delimited", "equals-delimited"): self.assertEqual( - host_config(host, config)['proxycommand'], - 'foo bar=biz baz' + host_config(host, config)["proxycommand"], "foo bar=biz baz" ) def test_proxycommand_interpolation(self): """ ProxyCommand should perform interpolation on the value """ - config = paramiko.util.parse_ssh_config(StringIO(""" + config = paramiko.util.parse_ssh_config( + StringIO( + """ Host specific Port 37 ProxyCommand host %h port %p lol @@ -238,28 +278,32 @@ Host portonly Host * Port 25 ProxyCommand host %h port %p -""")) +""" + ) + ) for host, val in ( - ('foo.com', "host foo.com port 25"), - ('specific', "host specific port 37 lol"), - ('portonly', "host portonly port 155"), + ("foo.com", "host foo.com port 25"), + ("specific", "host specific port 37 lol"), + ("portonly", "host portonly port 155"), ): - self.assertEqual( - host_config(host, config)['proxycommand'], - val - ) + self.assertEqual(host_config(host, config)["proxycommand"], val) def test_proxycommand_tilde_expansion(self): """ Tilde (~) should be expanded inside ProxyCommand """ - config = paramiko.util.parse_ssh_config(StringIO(""" + config = paramiko.util.parse_ssh_config( + StringIO( + """ Host test ProxyCommand ssh -F ~/.ssh/test_config bastion nc %h %p -""")) +""" + ) + ) self.assertEqual( - 'ssh -F %s/.ssh/test_config bastion nc test 22' % os.path.expanduser('~'), - host_config('test', config)['proxycommand'] + "ssh -F %s/.ssh/test_config bastion nc test 22" + % os.path.expanduser("~"), + host_config("test", config)["proxycommand"], ) def test_host_config_test_negation(self): @@ -278,10 +322,10 @@ Host * """ f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - host = 'www13.example.com' + host = "www13.example.com" self.assertEqual( paramiko.util.lookup_ssh_host_config(host, config), - {'hostname': host, 'port': '8080'} + {"hostname": host, "port": "8080"}, ) def test_host_config_test_proxycommand(self): @@ -296,20 +340,24 @@ Host proxy-without-equal-divisor ProxyCommand foo=bar:%h-%p """ for host, values in { - 'proxy-with-equal-divisor-and-space' :{'hostname': 'proxy-with-equal-divisor-and-space', - 'proxycommand': 'foo=bar'}, - 'proxy-with-equal-divisor-and-no-space':{'hostname': 'proxy-with-equal-divisor-and-no-space', - 'proxycommand': 'foo=bar'}, - 'proxy-without-equal-divisor' :{'hostname': 'proxy-without-equal-divisor', - 'proxycommand': - 'foo=bar:proxy-without-equal-divisor-22'} + "proxy-with-equal-divisor-and-space": { + "hostname": "proxy-with-equal-divisor-and-space", + "proxycommand": "foo=bar", + }, + "proxy-with-equal-divisor-and-no-space": { + "hostname": "proxy-with-equal-divisor-and-no-space", + "proxycommand": "foo=bar", + }, + "proxy-without-equal-divisor": { + "hostname": "proxy-without-equal-divisor", + "proxycommand": "foo=bar:proxy-without-equal-divisor-22", + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_host_config_test_identityfile(self): @@ -327,19 +375,21 @@ Host dsa2* IdentityFile id_dsa22 """ for host, values in { - 'foo' :{'hostname': 'foo', - 'identityfile': ['id_dsa0', 'id_dsa1']}, - 'dsa2' :{'hostname': 'dsa2', - 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa2', 'id_dsa22']}, - 'dsa22' :{'hostname': 'dsa22', - 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']} + "foo": {"hostname": "foo", "identityfile": ["id_dsa0", "id_dsa1"]}, + "dsa2": { + "hostname": "dsa2", + "identityfile": ["id_dsa0", "id_dsa1", "id_dsa2", "id_dsa22"], + }, + "dsa22": { + "hostname": "dsa22", + "identityfile": ["id_dsa0", "id_dsa1", "id_dsa22"], + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_config_addressfamily_and_lazy_fqdn(self): @@ -351,7 +401,9 @@ AddressFamily inet IdentityFile something_%l_using_fqdn """ config = paramiko.util.parse_ssh_config(StringIO(test_config)) - assert config.lookup('meh') # will die during lookup() if bug regresses + assert config.lookup( + "meh" + ) # will die during lookup() if bug regresses def test_clamp_value(self): self.assertEqual(32768, paramiko.util.clamp_value(32767, 32768, 32769)) @@ -367,7 +419,9 @@ IdentityFile something_%l_using_fqdn def test_get_hostnames(self): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEqual(config.get_hostnames(), {'*', '*.example.com', 'spoo.example.com'}) + self.assertEqual( + config.get_hostnames(), {"*", "*.example.com", "spoo.example.com"} + ) def test_quoted_host_names(self): test_config_file = """\ @@ -384,27 +438,23 @@ Host param4 "p a r" "p" "par" para Port 4444 """ res = { - 'param pam': {'hostname': 'param pam', 'port': '1111'}, - 'param': {'hostname': 'param', 'port': '1111'}, - 'pam': {'hostname': 'pam', 'port': '1111'}, - - 'param2': {'hostname': 'param2', 'port': '2222'}, - - 'param3': {'hostname': 'param3', 'port': '3333'}, - 'parara': {'hostname': 'parara', 'port': '3333'}, - - 'param4': {'hostname': 'param4', 'port': '4444'}, - 'p a r': {'hostname': 'p a r', 'port': '4444'}, - 'p': {'hostname': 'p', 'port': '4444'}, - 'par': {'hostname': 'par', 'port': '4444'}, - 'para': {'hostname': 'para', 'port': '4444'}, + "param pam": {"hostname": "param pam", "port": "1111"}, + "param": {"hostname": "param", "port": "1111"}, + "pam": {"hostname": "pam", "port": "1111"}, + "param2": {"hostname": "param2", "port": "2222"}, + "param3": {"hostname": "param3", "port": "3333"}, + "parara": {"hostname": "parara", "port": "3333"}, + "param4": {"hostname": "param4", "port": "4444"}, + "p a r": {"hostname": "p a r", "port": "4444"}, + "p": {"hostname": "p", "port": "4444"}, + "par": {"hostname": "par", "port": "4444"}, + "para": {"hostname": "para", "port": "4444"}, } f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) for host, values in res.items(): self.assertEquals( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_quoted_params_in_config(self): @@ -420,52 +470,44 @@ Host param3 parara IdentityFile "test rsa key" """ res = { - 'param pam': {'hostname': 'param pam', 'identityfile': ['id_rsa']}, - 'param': {'hostname': 'param', 'identityfile': ['id_rsa']}, - 'pam': {'hostname': 'pam', 'identityfile': ['id_rsa']}, - - 'param2': {'hostname': 'param2', 'identityfile': ['test rsa key']}, - - 'param3': {'hostname': 'param3', 'identityfile': ['id_rsa', 'test rsa key']}, - 'parara': {'hostname': 'parara', 'identityfile': ['id_rsa', 'test rsa key']}, + "param pam": {"hostname": "param pam", "identityfile": ["id_rsa"]}, + "param": {"hostname": "param", "identityfile": ["id_rsa"]}, + "pam": {"hostname": "pam", "identityfile": ["id_rsa"]}, + "param2": {"hostname": "param2", "identityfile": ["test rsa key"]}, + "param3": { + "hostname": "param3", + "identityfile": ["id_rsa", "test rsa key"], + }, + "parara": { + "hostname": "parara", + "identityfile": ["id_rsa", "test rsa key"], + }, } f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) for host, values in res.items(): self.assertEquals( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_quoted_host_in_config(self): conf = SSHConfig() correct_data = { - 'param': ['param'], - '"param"': ['param'], - - 'param pam': ['param', 'pam'], - '"param" "pam"': ['param', 'pam'], - '"param" pam': ['param', 'pam'], - 'param "pam"': ['param', 'pam'], - - 'param "pam" p': ['param', 'pam', 'p'], - '"param" pam "p"': ['param', 'pam', 'p'], - - '"pa ram"': ['pa ram'], - '"pa ram" pam': ['pa ram', 'pam'], - 'param "p a m"': ['param', 'p a m'], + "param": ["param"], + '"param"': ["param"], + "param pam": ["param", "pam"], + '"param" "pam"': ["param", "pam"], + '"param" pam': ["param", "pam"], + 'param "pam"': ["param", "pam"], + 'param "pam" p': ["param", "pam", "p"], + '"param" pam "p"': ["param", "pam", "p"], + '"pa ram"': ["pa ram"], + '"pa ram" pam': ["pa ram", "pam"], + 'param "p a m"': ["param", "p a m"], } - incorrect_data = [ - 'param"', - '"param', - 'param "pam', - 'param "pam" "p a', - ] + incorrect_data = ['param"', '"param', 'param "pam', 'param "pam" "p a'] for host, values in correct_data.items(): - self.assertEquals( - conf._get_hosts(host), - values - ) + self.assertEquals(conf._get_hosts(host), values) for host in incorrect_data: self.assertRaises(Exception, conf._get_hosts, host) @@ -490,15 +532,18 @@ Host proxycommand-with-equals-none ProxyCommand=None """ for host, values in { - 'proxycommand-standard-none': {'hostname': 'proxycommand-standard-none'}, - 'proxycommand-with-equals-none': {'hostname': 'proxycommand-with-equals-none'} + "proxycommand-standard-none": { + "hostname": "proxycommand-standard-none" + }, + "proxycommand-with-equals-none": { + "hostname": "proxycommand-with-equals-none" + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_proxycommand_none_masking(self): @@ -521,12 +566,10 @@ Host * # backwards compatibility reasons in 1.x/2.x) appear completely blank, # as if the host had no ProxyCommand whatsoever. # Threw another unrelated host in there just for sanity reasons. - self.assertFalse('proxycommand' in config.lookup('specific-host')) + self.assertFalse("proxycommand" in config.lookup("specific-host")) self.assertEqual( - config.lookup('other-host')['proxycommand'], - 'other-proxy' + config.lookup("other-host")["proxycommand"], "other-proxy" ) self.assertEqual( - config.lookup('some-random-host')['proxycommand'], - 'default-proxy' + config.lookup("some-random-host")["proxycommand"], "default-proxy" ) |