summaryrefslogtreecommitdiffhomepage
path: root/tests/test_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_client.py')
-rw-r--r--tests/test_client.py120
1 files changed, 107 insertions, 13 deletions
diff --git a/tests/test_client.py b/tests/test_client.py
index 7d431384..cb578394 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -83,6 +83,16 @@ class NullServer(paramiko.ServerInterface):
return False
return True
+ def check_channel_env_request(self, channel, name, value):
+ if name == "INVALID_ENV":
+ return False
+
+ if not hasattr(channel, "env"):
+ setattr(channel, "env", {})
+
+ channel.env[name] = value
+ return True
+
class SSHClientTest(unittest.TestCase):
def setUp(self):
@@ -108,9 +118,11 @@ class SSHClientTest(unittest.TestCase):
allowed_keys = FINGERPRINTS.keys()
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks)
- host_key = paramiko.RSAKey.from_private_key_file(
- _support("test_rsa.key")
- )
+ keypath = _support("test_rsa.key")
+ host_key = paramiko.RSAKey.from_private_key_file(keypath)
+ self.ts.add_server_key(host_key)
+ keypath = _support("test_ecdsa_256.key")
+ host_key = paramiko.ECDSAKey.from_private_key_file(keypath)
self.ts.add_server_key(host_key)
server = NullServer(allowed_keys=allowed_keys)
if delay:
@@ -240,10 +252,9 @@ class SSHClientTest(unittest.TestCase):
verify that SSHClient's AutoAddPolicy works.
"""
threading.Thread(target=self._run).start()
- host_key = paramiko.RSAKey.from_private_key_file(
- _support("test_rsa.key")
- )
- public_host_key = paramiko.RSAKey(data=host_key.asbytes())
+ hostname = "[%s]:%d" % (self.addr, self.port)
+ key_file = _support("test_ecdsa_256.key")
+ public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
@@ -256,12 +267,8 @@ class SSHClientTest(unittest.TestCase):
self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
self.assertEqual(1, len(self.tc.get_host_keys()))
- self.assertEqual(
- public_host_key,
- self.tc.get_host_keys()["[%s]:%d" % (self.addr, self.port)][
- "ssh-rsa"
- ],
- )
+ new_host_key = list(self.tc.get_host_keys()[hostname].values())[0]
+ self.assertEqual(public_host_key, new_host_key)
def test_5_save_host_keys(self):
"""
@@ -438,3 +445,90 @@ class SSHClientTest(unittest.TestCase):
gss_kex=True,
**self.connect_kwargs
)
+
+ def _client_host_key_bad(self, host_key):
+ threading.Thread(target=self._run).start()
+ hostname = "[%s]:%d" % (self.addr, self.port)
+
+ self.tc = paramiko.SSHClient()
+ self.tc.set_missing_host_key_policy(paramiko.WarningPolicy())
+ known_hosts = self.tc.get_host_keys()
+ known_hosts.add(hostname, host_key.get_name(), host_key)
+
+ self.assertRaises(
+ paramiko.BadHostKeyException,
+ self.tc.connect,
+ password="pygmalion",
+ **self.connect_kwargs
+ )
+
+ def _client_host_key_good(self, ktype, kfile):
+ threading.Thread(target=self._run).start()
+ hostname = "[%s]:%d" % (self.addr, self.port)
+
+ self.tc = paramiko.SSHClient()
+ self.tc.set_missing_host_key_policy(paramiko.RejectPolicy())
+ host_key = ktype.from_private_key_file(_support(kfile))
+ known_hosts = self.tc.get_host_keys()
+ known_hosts.add(hostname, host_key.get_name(), host_key)
+
+ self.tc.connect(password="pygmalion", **self.connect_kwargs)
+ self.event.wait(1.0)
+ self.assertTrue(self.event.is_set())
+ self.assertTrue(self.ts.is_active())
+ self.assertEqual(True, self.ts.is_authenticated())
+
+ def test_host_key_negotiation_1(self):
+ host_key = paramiko.ECDSAKey.generate()
+ self._client_host_key_bad(host_key)
+
+ def test_host_key_negotiation_2(self):
+ host_key = paramiko.RSAKey.generate(2048)
+ self._client_host_key_bad(host_key)
+
+ def test_host_key_negotiation_3(self):
+ self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key")
+
+ def test_host_key_negotiation_4(self):
+ self._client_host_key_good(paramiko.RSAKey, "test_rsa.key")
+
+ def test_update_environment(self):
+ """
+ Verify that environment variables can be set by the client.
+ """
+ threading.Thread(target=self._run).start()
+
+ self.tc = paramiko.SSHClient()
+ self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ self.assertEqual(0, len(self.tc.get_host_keys()))
+ self.tc.connect(
+ self.addr, self.port, username="slowdive", password="pygmalion"
+ )
+
+ self.event.wait(1.0)
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+
+ target_env = {b"A": b"B", b"C": b"d"}
+
+ self.tc.exec_command("yes", environment=target_env)
+ schan = self.ts.accept(1.0)
+ self.assertEqual(target_env, getattr(schan, "env", {}))
+ schan.close()
+
+ # Cannot use assertRaises in context manager mode as it is not supported
+ # in Python 2.6.
+ try:
+ # Verify that a rejection by the server can be detected
+ self.tc.exec_command("yes", environment={b"INVALID_ENV": b""})
+ except SSHException as e:
+ self.assertTrue(
+ "INVALID_ENV" in str(e),
+ "Expected variable name in error message",
+ )
+ self.assertTrue(
+ isinstance(e.args[1], SSHException),
+ "Expected original SSHException in exception",
+ )
+ else:
+ self.assertFalse(False, "SSHException was not thrown.")