# Copyright (C) 2003-2009 Robey Pointer # # This file is part of paramiko. # # Paramiko is free software; you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation; either version 2.1 of the License, or (at your option) # any later version. # # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more # details. # # You should have received a copy of the GNU Lesser General Public License # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. """ Some unit tests for SSHClient. """ from __future__ import with_statement import gc import os import platform import socket import threading import time import unittest import warnings import weakref from tempfile import mkstemp import paramiko from paramiko.pkey import PublicBlob from paramiko.common import PY2 from paramiko.ssh_exception import SSHException, AuthenticationException from .util import _support, slow requires_gss_auth = unittest.skipUnless( paramiko.GSS_AUTH_AVAILABLE, "GSS auth not available" ) FINGERPRINTS = { 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c', 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5', 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60', 'ssh-ed25519': b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', } class NullServer(paramiko.ServerInterface): def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types self.__allowed_keys = kwargs.pop('allowed_keys', []) # And allow them to set a (single...meh) expected public blob (cert) self.__expected_public_blob = kwargs.pop('public_blob', None) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): if username == 'slowdive': return 'publickey,password' return 'publickey' def check_auth_password(self, username, password): if (username == 'slowdive') and (password == 'pygmalion'): return paramiko.AUTH_SUCCESSFUL if (username == 'slowdive') and (password == 'unresponsive-server'): time.sleep(5) return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): try: expected = FINGERPRINTS[key.get_name()] except KeyError: return paramiko.AUTH_FAILED # Base check: allowed auth type & fingerprint matches happy = ( key.get_name() in self.__allowed_keys and key.get_fingerprint() == expected ) # Secondary check: if test wants assertions about cert data if ( self.__expected_public_blob is not None and key.public_blob != self.__expected_public_blob ): happy = False return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED def check_channel_request(self, kind, chanid): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): if command != b'yes': return False return True def check_channel_env_request(self, channel, name, value): if name == 'INVALID_ENV': return False if not hasattr(channel, 'env'): setattr(channel, 'env', {}) channel.env[name] = value return True class ClientTest(unittest.TestCase): def setUp(self): self.sockl = socket.socket() self.sockl.bind(('localhost', 0)) self.sockl.listen(1) self.addr, self.port = self.sockl.getsockname() self.connect_kwargs = dict( hostname=self.addr, port=self.port, username='slowdive', look_for_keys=False, ) self.event = threading.Event() self.kill_event = threading.Event() def tearDown(self): # Shut down client Transport if hasattr(self, 'tc'): self.tc.close() # Shut down shared socket if hasattr(self, 'sockl'): # Signal to server thread that it should shut down early; it checks # this immediately after accept(). (In scenarios where connection # actually succeeded during the test, this becomes a no-op.) self.kill_event.set() # Forcibly connect to server sock in case the server thread is # hanging out in its accept() (e.g. if the client side of the test # fails before it even gets to connecting); there's no other good # way to force an accept() to exit. put_a_sock_in_it = socket.socket() put_a_sock_in_it.connect((self.addr, self.port)) put_a_sock_in_it.close() # Then close "our" end of the socket (which _should_ cause the # accept() to bail out, but does not, for some reason. I blame # threading.) self.sockl.close() def _run( self, allowed_keys=None, delay=0, public_blob=None, kill_event=None, ): if allowed_keys is None: allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() # If the kill event was set at this point, it indicates an early # shutdown, so bail out now and don't even try setting up a Transport # (which will just verbosely die.) if kill_event.is_set(): self.socks.close() return self.ts = paramiko.Transport(self.socks) keypath = _support('test_rsa.key') host_key = paramiko.RSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) keypath = _support('test_ecdsa_256.key') host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob) if delay: time.sleep(delay) self.ts.start_server(self.event, server) def _test_connection(self, **kwargs): """ (Most) kwargs get passed directly into SSHClient.connect(). The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ run_kwargs = {'kill_event': self.kill_event} for key in ('allowed_keys', 'public_blob'): run_kwargs[key] = kwargs.pop(key, None) # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) # Client setup self.tc = paramiko.SSHClient() self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) # Actual connection self.tc.connect(**dict(self.connect_kwargs, **kwargs)) # Authentication successful? self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) self.assertEqual('slowdive', self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(False, self.tc.get_transport().gss_kex_used) # Command execution functions? stdin, stdout, stderr = self.tc.exec_command('yes') schan = self.ts.accept(1.0) schan.send('Hello there.\n') schan.send_stderr('This is on stderr.\n') schan.close() self.assertEqual('Hello there.\n', stdout.readline()) self.assertEqual('', stdout.readline()) self.assertEqual('This is on stderr.\n', stderr.readline()) self.assertEqual('', stderr.readline()) # Cleanup stdin.close() stdout.close() stderr.close() class SSHClientTest(ClientTest): def test_1_client(self): """ verify that the SSHClient stuff works too. """ self._test_connection(password='pygmalion') def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ self._test_connection(key_filename=_support('test_dss.key')) def test_client_rsa(self): """ verify that SSHClient works with an RSA key. """ self._test_connection(key_filename=_support('test_rsa.key')) def test_2_5_client_ecdsa(self): """ verify that SSHClient works with an ECDSA key. """ self._test_connection(key_filename=_support('test_ecdsa_256.key')) def test_client_ed25519(self): self._test_connection(key_filename=_support('test_ed25519.key')) def test_3_multiple_key_files(self): """ verify that SSHClient accepts and tries multiple key files. """ # This is dumb :( types_ = { 'rsa': 'ssh-rsa', 'dss': 'ssh-dss', 'ecdsa': 'ecdsa-sha2-nistp256', } # Various combos of attempted & valid keys # TODO: try every possible combo using itertools functions for attempt, accept in ( (['rsa', 'dss'], ['dss']), # Original test #3 (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success ): try: self._test_connection( key_filename=[ _support('test_{}.key'.format(x)) for x in attempt ], allowed_keys=[types_[x] for x in accept], ) finally: # Clean up to avoid occasional gc-related deadlocks. # TODO: use nose test generators after nose port self.tearDown() self.setUp() def test_multiple_key_files_failure(self): """ Expect failure when multiple keys in play and none are accepted """ # Until #387 is fixed we have to catch a high-up exception since # various platforms trigger different errors here >_< self.assertRaises(SSHException, self._test_connection, key_filename=[_support('test_rsa.key')], allowed_keys=['ecdsa-sha2-nistp256'], ) def test_certs_allowed_as_key_filename_values(self): # NOTE: giving cert path here, not key path. (Key path test is below. # They're similar except for which path is given; the expected auth and # server-side behavior is 100% identical.) # NOTE: only bothered whipping up one cert per overall class/family. for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): cert_name = 'test_{}.key-cert.pub'.format(type_) cert_path = _support(os.path.join('cert_support', cert_name)) self._test_connection( key_filename=cert_path, public_blob=PublicBlob.from_file(cert_path), ) def test_certs_implicitly_loaded_alongside_key_filename_keys(self): # NOTE: a regular test_connection() w/ test_rsa.key would incidentally # test this (because test_xxx.key-cert.pub exists) but incidental tests # stink, so NullServer and friends were updated to allow assertions # about the server-side key object's public blob. Thus, we can prove # that a specific cert was found, along with regular authorization # succeeding proving that the overall flow works. for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): key_name = 'test_{}.key'.format(type_) key_path = _support(os.path.join('cert_support', key_name)) self._test_connection( key_filename=key_path, public_blob=PublicBlob.from_file( '{}-cert.pub'.format(key_path) ), ) def test_default_key_locations_trigger_cert_loads_if_found(self): # TODO: what it says on the tin: ~/.ssh/id_rsa tries to load # ~/.ssh/id_rsa-cert.pub. Right now no other tests actually test that # code path (!) so we're punting too, sob. pass def test_4_auto_add_policy(self): """ verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() hostname = '[%s]:%d' % (self.addr, self.port) key_file = _support('test_ecdsa_256.key') public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) self.tc.connect(password='pygmalion', **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) self.assertEqual('slowdive', self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] self.assertEqual(public_host_key, new_host_key) def test_5_save_host_keys(self): """ verify that SSHClient correctly saves a known_hosts file. """ warnings.filterwarnings('ignore', 'tempnam.*') host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) fd, localname = mkstemp() os.close(fd) client = paramiko.SSHClient() self.assertEquals(0, len(client.get_host_keys())) host_id = '[%s]:%d' % (self.addr, self.port) client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key) self.assertEquals(1, len(client.get_host_keys())) self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa']) client.save_host_keys(localname) with open(localname) as fd: assert host_id in fd.read() os.unlink(localname) def test_6_cleanup(self): """ verify that when an SSHClient is collected, its transport (and the transport's packetizer) is closed. """ # Skipped on PyPy because it fails on travis for unknown reasons if platform.python_implementation() == "PyPy": return threading.Thread(target=self._run).start() self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) p = weakref.ref(self.tc._transport.packetizer) self.assertTrue(p() is not None) self.tc.close() del self.tc # force a collection to see whether the SSHClient object is deallocated # 2 GCs are needed on PyPy, time is needed for Python 3 time.sleep(0.3) gc.collect() gc.collect() self.assertTrue(p() is None) def test_client_can_be_used_as_context_manager(self): """ verify that an SSHClient can be used a context manager """ threading.Thread(target=self._run).start() with paramiko.SSHClient() as tc: self.tc = tc self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) self.assertTrue(self.tc._transport is not None) self.assertTrue(self.tc._transport is None) def test_7_banner_timeout(self): """ verify that the SSHClient has a configurable banner timeout. """ # Start the thread with a 1 second wait. threading.Thread(target=self._run, kwargs={'delay': 1}).start() host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) # Connect with a half second banner timeout. kwargs = dict(self.connect_kwargs, banner_timeout=0.5) self.assertRaises( paramiko.SSHException, self.tc.connect, **kwargs ) def test_8_auth_trickledown(self): """ Failed key auth doesn't prevent subsequent pw auth from succeeding """ # NOTE: re #387, re #394 # If pkey module used within Client._auth isn't correctly handling auth # errors (e.g. if it allows things like ValueError to bubble up as per # midway through #394) client.connect() will fail (at key load step) # instead of succeeding (at password step) kwargs = dict( # Password-protected key whose passphrase is not 'pygmalion' (it's # 'television' as per tests/test_pkey.py). NOTE: must use # key_filename, loading the actual key here with PKey will except # immediately; we're testing the try/except crap within Client. key_filename=[_support('test_rsa_password.key')], # Actual password for default 'slowdive' user password='pygmalion', ) self._test_connection(**kwargs) @slow def test_9_auth_timeout(self): """ verify that the SSHClient has a configurable auth timeout """ # Connect with a half second auth timeout self.assertRaises( AuthenticationException, self._test_connection, password='unresponsive-server', auth_timeout=0.5, ) @requires_gss_auth def test_10_auth_trickledown_gsskex(self): """ Failed gssapi-keyex auth doesn't prevent subsequent key auth from succeeding """ kwargs = dict( gss_kex=True, key_filename=[_support('test_rsa.key')], ) self._test_connection(**kwargs) @requires_gss_auth def test_11_auth_trickledown_gssauth(self): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ kwargs = dict( gss_auth=True, key_filename=[_support('test_rsa.key')], ) self._test_connection(**kwargs) def test_12_reject_policy(self): """ verify that SSHClient's RejectPolicy works. """ threading.Thread(target=self._run).start() self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) self.assertRaises( paramiko.SSHException, self.tc.connect, password='pygmalion', **self.connect_kwargs ) @requires_gss_auth def test_13_reject_policy_gsskex(self): """ verify that SSHClient's RejectPolicy works, even if gssapi-keyex was enabled but not used. """ # Test for a bug present in paramiko versions released before 2017-08-01 threading.Thread(target=self._run).start() self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) self.assertRaises( paramiko.SSHException, self.tc.connect, password='pygmalion', gss_kex=True, **self.connect_kwargs ) def _client_host_key_bad(self, host_key): threading.Thread(target=self._run).start() hostname = '[%s]:%d' % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) known_hosts = self.tc.get_host_keys() known_hosts.add(hostname, host_key.get_name(), host_key) self.assertRaises( paramiko.BadHostKeyException, self.tc.connect, password='pygmalion', **self.connect_kwargs ) def _client_host_key_good(self, ktype, kfile): threading.Thread(target=self._run).start() hostname = '[%s]:%d' % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) host_key = ktype.from_private_key_file(_support(kfile)) known_hosts = self.tc.get_host_keys() known_hosts.add(hostname, host_key.get_name(), host_key) self.tc.connect(password='pygmalion', **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) self.assertEqual(True, self.ts.is_authenticated()) def test_host_key_negotiation_1(self): host_key = paramiko.ECDSAKey.generate() self._client_host_key_bad(host_key) def test_host_key_negotiation_2(self): host_key = paramiko.RSAKey.generate(2048) self._client_host_key_bad(host_key) def test_host_key_negotiation_3(self): self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key') def test_host_key_negotiation_4(self): self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key') def _setup_for_env(self): threading.Thread(target=self._run).start() self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') self.event.wait(1.0) self.assertTrue(self.event.isSet()) self.assertTrue(self.ts.is_active()) def test_update_environment(self): """ Verify that environment variables can be set by the client. """ self._setup_for_env() target_env = {b'A': b'B', b'C': b'd'} self.tc.exec_command('yes', environment=target_env) schan = self.ts.accept(1.0) self.assertEqual(target_env, getattr(schan, 'env', {})) schan.close() @unittest.skip("Clients normally fail silently, thus so do we, for now") def test_env_update_failures(self): self._setup_for_env() with self.assertRaises(SSHException) as manager: # Verify that a rejection by the server can be detected self.tc.exec_command('yes', environment={b'INVALID_ENV': b''}) self.assertTrue( 'INVALID_ENV' in str(manager.exception), 'Expected variable name in error message' ) self.assertTrue( isinstance(manager.exception.args[1], SSHException), 'Expected original SSHException in exception' ) def test_missing_key_policy_accepts_classes_or_instances(self): """ Client.missing_host_key_policy() can take classes or instances. """ # AN ACTUAL UNIT TEST?! GOOD LORD # (But then we have to test a private API...meh.) client = paramiko.SSHClient() # Default assert isinstance(client._policy, paramiko.RejectPolicy) # Hand in an instance (classic behavior) client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) assert isinstance(client._policy, paramiko.AutoAddPolicy) # Hand in just the class (new behavior) client.set_missing_host_key_policy(paramiko.AutoAddPolicy) assert isinstance(client._policy, paramiko.AutoAddPolicy)