diff options
Diffstat (limited to 'tests/test_client.py')
-rw-r--r-- | tests/test_client.py | 120 |
1 files changed, 107 insertions, 13 deletions
diff --git a/tests/test_client.py b/tests/test_client.py index 7d431384..cb578394 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -83,6 +83,16 @@ class NullServer(paramiko.ServerInterface): return False return True + def check_channel_env_request(self, channel, name, value): + if name == "INVALID_ENV": + return False + + if not hasattr(channel, "env"): + setattr(channel, "env", {}) + + channel.env[name] = value + return True + class SSHClientTest(unittest.TestCase): def setUp(self): @@ -108,9 +118,11 @@ class SSHClientTest(unittest.TestCase): allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file( - _support("test_rsa.key") - ) + keypath = _support("test_rsa.key") + host_key = paramiko.RSAKey.from_private_key_file(keypath) + self.ts.add_server_key(host_key) + keypath = _support("test_ecdsa_256.key") + host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys) if delay: @@ -240,10 +252,9 @@ class SSHClientTest(unittest.TestCase): verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() - host_key = paramiko.RSAKey.from_private_key_file( - _support("test_rsa.key") - ) - public_host_key = paramiko.RSAKey(data=host_key.asbytes()) + hostname = "[%s]:%d" % (self.addr, self.port) + key_file = _support("test_ecdsa_256.key") + public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -256,12 +267,8 @@ class SSHClientTest(unittest.TestCase): self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) - self.assertEqual( - public_host_key, - self.tc.get_host_keys()["[%s]:%d" % (self.addr, self.port)][ - "ssh-rsa" - ], - ) + new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] + self.assertEqual(public_host_key, new_host_key) def test_5_save_host_keys(self): """ @@ -438,3 +445,90 @@ class SSHClientTest(unittest.TestCase): gss_kex=True, **self.connect_kwargs ) + + def _client_host_key_bad(self, host_key): + threading.Thread(target=self._run).start() + hostname = "[%s]:%d" % (self.addr, self.port) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) + known_hosts = self.tc.get_host_keys() + known_hosts.add(hostname, host_key.get_name(), host_key) + + self.assertRaises( + paramiko.BadHostKeyException, + self.tc.connect, + password="pygmalion", + **self.connect_kwargs + ) + + def _client_host_key_good(self, ktype, kfile): + threading.Thread(target=self._run).start() + hostname = "[%s]:%d" % (self.addr, self.port) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) + host_key = ktype.from_private_key_file(_support(kfile)) + known_hosts = self.tc.get_host_keys() + known_hosts.add(hostname, host_key.get_name(), host_key) + + self.tc.connect(password="pygmalion", **self.connect_kwargs) + self.event.wait(1.0) + self.assertTrue(self.event.is_set()) + self.assertTrue(self.ts.is_active()) + self.assertEqual(True, self.ts.is_authenticated()) + + def test_host_key_negotiation_1(self): + host_key = paramiko.ECDSAKey.generate() + self._client_host_key_bad(host_key) + + def test_host_key_negotiation_2(self): + host_key = paramiko.RSAKey.generate(2048) + self._client_host_key_bad(host_key) + + def test_host_key_negotiation_3(self): + self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key") + + def test_host_key_negotiation_4(self): + self._client_host_key_good(paramiko.RSAKey, "test_rsa.key") + + def test_update_environment(self): + """ + Verify that environment variables can be set by the client. + """ + threading.Thread(target=self._run).start() + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.assertEqual(0, len(self.tc.get_host_keys())) + self.tc.connect( + self.addr, self.port, username="slowdive", password="pygmalion" + ) + + self.event.wait(1.0) + self.assertTrue(self.event.isSet()) + self.assertTrue(self.ts.is_active()) + + target_env = {b"A": b"B", b"C": b"d"} + + self.tc.exec_command("yes", environment=target_env) + schan = self.ts.accept(1.0) + self.assertEqual(target_env, getattr(schan, "env", {})) + schan.close() + + # Cannot use assertRaises in context manager mode as it is not supported + # in Python 2.6. + try: + # Verify that a rejection by the server can be detected + self.tc.exec_command("yes", environment={b"INVALID_ENV": b""}) + except SSHException as e: + self.assertTrue( + "INVALID_ENV" in str(e), + "Expected variable name in error message", + ) + self.assertTrue( + isinstance(e.args[1], SSHException), + "Expected original SSHException in exception", + ) + else: + self.assertFalse(False, "SSHException was not thrown.") |