diff options
author | Paul Kehrer <paul.l.kehrer@gmail.com> | 2017-06-25 20:47:34 -1000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-06-25 20:47:34 -1000 |
commit | cf14b9ff3004b9c9316417cd657a77797675b628 (patch) | |
tree | 523656d9d9c820cd3023bf89132e33c732cfa555 /tests | |
parent | fdc09c9f93fd189a6398d5b350a3c91011d9b4cb (diff) | |
parent | 842caba00262a81975cbfd186b846c83f72354e3 (diff) |
Merge branch 'master' into one-shot-methods
Diffstat (limited to 'tests')
-rw-r--r-- | tests/__init__.py | 36 | ||||
-rw-r--r-- | tests/stub_sftp.py | 14 | ||||
-rw-r--r-- | tests/test_auth.py | 19 | ||||
-rw-r--r-- | tests/test_client.py | 107 | ||||
-rwxr-xr-x | tests/test_file.py | 57 | ||||
-rw-r--r-- | tests/test_kex.py | 57 | ||||
-rw-r--r-- | tests/test_pkey.py | 45 | ||||
-rwxr-xr-x | tests/test_sftp.py | 63 | ||||
-rw-r--r-- | tests/test_transport.py | 78 |
9 files changed, 450 insertions, 26 deletions
diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..8878f14d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,36 @@ +# Copyright (C) 2017 Martin Packman <gzlist@googlemail.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +"""Base classes and helpers for testing paramiko.""" + +import unittest + +from paramiko.py3compat import ( + builtins, + ) + + +def skipUnlessBuiltin(name): + """Skip decorated test if builtin name does not exist.""" + if getattr(builtins, name, None) is None: + skip = getattr(unittest, "skip", None) + if skip is None: + # Python 2.6 pseudo-skip + return lambda func: None + return skip("No builtin " + repr(name)) + return lambda func: func diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py index 334af561..0d673091 100644 --- a/tests/stub_sftp.py +++ b/tests/stub_sftp.py @@ -24,7 +24,7 @@ import os import sys from paramiko import ( ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, - SFTPHandle, SFTP_OK, AUTH_SUCCESSFUL, OPEN_SUCCEEDED, + SFTPHandle, SFTP_OK, SFTP_FAILURE, AUTH_SUCCESSFUL, OPEN_SUCCEEDED, ) from paramiko.common import o666 @@ -141,12 +141,24 @@ class StubSFTPServer (SFTPServerInterface): def rename(self, oldpath, newpath): oldpath = self._realpath(oldpath) newpath = self._realpath(newpath) + if os.path.exists(newpath): + return SFTP_FAILURE try: os.rename(oldpath, newpath) except OSError as e: return SFTPServer.convert_errno(e.errno) return SFTP_OK + def posix_rename(self, oldpath, newpath): + oldpath = self._realpath(oldpath) + newpath = self._realpath(newpath) + try: + os.rename(oldpath, newpath) + except OSError as e: + return SFTPServer.convert_errno(e.errno) + return SFTP_OK + + def mkdir(self, path, attr): path = self._realpath(path) try: diff --git a/tests/test_auth.py b/tests/test_auth.py index 96f7611c..e78397c6 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -23,6 +23,7 @@ Some unit tests for authenticating over a Transport. import sys import threading import unittest +from time import sleep from paramiko import ( Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType, @@ -74,6 +75,9 @@ class NullServer (ServerInterface): return AUTH_SUCCESSFUL if username == 'bad-server': raise Exception("Ack!") + if username == 'unresponsive-server': + sleep(5) + return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): @@ -233,3 +237,18 @@ class AuthTest (unittest.TestCase): except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) + + def test_9_auth_non_responsive(self): + """ + verify that authentication times out if server takes to long to + respond (or never responds). + """ + self.tc.auth_timeout = 1 # 1 second, to speed up test + self.start_server() + self.tc.connect() + try: + remain = self.tc.auth_password('unresponsive-server', 'hello') + except: + etype, evalue, etb = sys.exc_info() + self.assertTrue(issubclass(etype, AuthenticationException)) + self.assertTrue('Authentication timeout' in str(evalue)) diff --git a/tests/test_client.py b/tests/test_client.py index a340be00..e912d5b2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -36,7 +36,7 @@ from tests.util import test_path import paramiko from paramiko.common import PY2 -from paramiko.ssh_exception import SSHException +from paramiko.ssh_exception import SSHException, AuthenticationException FINGERPRINTS = { @@ -61,6 +61,9 @@ class NullServer (paramiko.ServerInterface): def check_auth_password(self, username, password): if (username == 'slowdive') and (password == 'pygmalion'): return paramiko.AUTH_SUCCESSFUL + if (username == 'slowdive') and (password == 'unresponsive-server'): + time.sleep(5) + return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): @@ -119,7 +122,11 @@ class SSHClientTest (unittest.TestCase): allowed_keys = FINGERPRINTS.keys() self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + keypath = test_path('test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file(keypath) + self.ts.add_server_key(host_key) + keypath = test_path('test_ecdsa_256.key') + host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys) if delay: @@ -246,8 +253,9 @@ class SSHClientTest (unittest.TestCase): verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() - host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) - public_host_key = paramiko.RSAKey(data=host_key.asbytes()) + hostname = '[%s]:%d' % (self.addr, self.port) + key_file = test_path('test_ecdsa_256.key') + public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -260,7 +268,8 @@ class SSHClientTest (unittest.TestCase): self.assertEqual('slowdive', self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) - self.assertEqual(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa']) + new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] + self.assertEqual(public_host_key, new_host_key) def test_5_save_host_keys(self): """ @@ -294,13 +303,10 @@ class SSHClientTest (unittest.TestCase): verify that when an SSHClient is collected, its transport (and the transport's packetizer) is closed. """ - # Unclear why this is borked on Py3, but it is, and does not seem worth - # pursuing at the moment. Skipped on PyPy because it fails on travis - # for unknown reasons, works fine locally. - # XXX: It's the release of the references to e.g packetizer that fails - # in py3... - if not PY2 or platform.python_implementation() == "PyPy": + # Skipped on PyPy because it fails on travis for unknown reasons + if platform.python_implementation() == "PyPy": return + threading.Thread(target=self._run).start() self.tc = paramiko.SSHClient() @@ -318,8 +324,8 @@ class SSHClientTest (unittest.TestCase): del self.tc # force a collection to see whether the SSHClient object is deallocated - # correctly. 2 GCs are needed to make sure it's really collected on - # PyPy + # 2 GCs are needed on PyPy, time is needed for Python 3 + time.sleep(0.3) gc.collect() gc.collect() @@ -384,6 +390,64 @@ class SSHClientTest (unittest.TestCase): ) self._test_connection(**kwargs) + def test_9_auth_timeout(self): + """ + verify that the SSHClient has a configurable auth timeout + """ + # Connect with a half second auth timeout + self.assertRaises( + AuthenticationException, + self._test_connection, + password='unresponsive-server', + auth_timeout=0.5, + ) + + def _client_host_key_bad(self, host_key): + threading.Thread(target=self._run).start() + hostname = '[%s]:%d' % (self.addr, self.port) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) + known_hosts = self.tc.get_host_keys() + known_hosts.add(hostname, host_key.get_name(), host_key) + + self.assertRaises( + paramiko.BadHostKeyException, + self.tc.connect, + password='pygmalion', + **self.connect_kwargs + ) + + def _client_host_key_good(self, ktype, kfile): + threading.Thread(target=self._run).start() + hostname = '[%s]:%d' % (self.addr, self.port) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) + host_key = ktype.from_private_key_file(test_path(kfile)) + known_hosts = self.tc.get_host_keys() + known_hosts.add(hostname, host_key.get_name(), host_key) + + self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.event.wait(1.0) + self.assertTrue(self.event.is_set()) + self.assertTrue(self.ts.is_active()) + self.assertEqual(True, self.ts.is_authenticated()) + + def test_host_key_negotiation_1(self): + host_key = paramiko.ECDSAKey.generate() + self._client_host_key_bad(host_key) + + def test_host_key_negotiation_2(self): + host_key = paramiko.RSAKey.generate(2048) + self._client_host_key_bad(host_key) + + def test_host_key_negotiation_3(self): + self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key') + + def test_host_key_negotiation_4(self): + self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key') + def test_update_environment(self): """ Verify that environment variables can be set by the client. @@ -418,3 +482,20 @@ class SSHClientTest (unittest.TestCase): 'Expected original SSHException in exception') else: self.assertFalse(False, 'SSHException was not thrown.') + + + def test_missing_key_policy_accepts_classes_or_instances(self): + """ + Client.missing_host_key_policy() can take classes or instances. + """ + # AN ACTUAL UNIT TEST?! GOOD LORD + # (But then we have to test a private API...meh.) + client = paramiko.SSHClient() + # Default + assert isinstance(client._policy, paramiko.RejectPolicy) + # Hand in an instance (classic behavior) + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + assert isinstance(client._policy, paramiko.AutoAddPolicy) + # Hand in just the class (new behavior) + client.set_missing_host_key_policy(paramiko.AutoAddPolicy) + assert isinstance(client._policy, paramiko.AutoAddPolicy) diff --git a/tests/test_file.py b/tests/test_file.py index 7fab6985..b33ecd51 100755 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -21,10 +21,14 @@ Some unit tests for the BufferedFile abstraction. """ import unittest -from paramiko.file import BufferedFile -from paramiko.common import linefeed_byte, crlf, cr_byte import sys +from paramiko.common import linefeed_byte, crlf, cr_byte +from paramiko.file import BufferedFile +from paramiko.py3compat import BytesIO + +from tests import skipUnlessBuiltin + class LoopbackFile (BufferedFile): """ @@ -33,19 +37,16 @@ class LoopbackFile (BufferedFile): def __init__(self, mode='r', bufsize=-1): BufferedFile.__init__(self) self._set_mode(mode, bufsize) - self.buffer = bytes() + self.buffer = BytesIO() + self.offset = 0 def _read(self, size): - if len(self.buffer) == 0: - return None - if size > len(self.buffer): - size = len(self.buffer) - data = self.buffer[:size] - self.buffer = self.buffer[size:] + data = self.buffer.getvalue()[self.offset:self.offset+size] + self.offset += len(data) return data def _write(self, data): - self.buffer += data + self.buffer.write(data) return len(data) @@ -187,6 +188,42 @@ class BufferedFileTest (unittest.TestCase): self.assertEqual(data, b'hello') f.close() + def test_write_bad_type(self): + with LoopbackFile('wb') as f: + self.assertRaises(TypeError, f.write, object()) + + def test_write_unicode_as_binary(self): + text = u"\xa7 why is writing text to a binary file allowed?\n" + with LoopbackFile('rb+') as f: + f.write(text) + self.assertEqual(f.read(), text.encode("utf-8")) + + @skipUnlessBuiltin('memoryview') + def test_write_bytearray(self): + with LoopbackFile('rb+') as f: + f.write(bytearray(12)) + self.assertEqual(f.read(), 12 * b"\0") + + @skipUnlessBuiltin('buffer') + def test_write_buffer(self): + data = 3 * b"pretend giant block of data\n" + offsets = range(0, len(data), 8) + with LoopbackFile('rb+') as f: + for offset in offsets: + f.write(buffer(data, offset, 8)) + self.assertEqual(f.read(), data) + + @skipUnlessBuiltin('memoryview') + def test_write_memoryview(self): + data = 3 * b"pretend giant block of data\n" + offsets = range(0, len(data), 8) + with LoopbackFile('rb+') as f: + view = memoryview(data) + for offset in offsets: + f.write(view[offset:offset+8]) + self.assertEqual(f.read(), data) + + if __name__ == '__main__': from unittest import main main() diff --git a/tests/test_kex.py b/tests/test_kex.py index 19804fbf..b7f588f7 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -20,7 +20,7 @@ Some unit tests for the key exchange protocols. """ -from binascii import hexlify +from binascii import hexlify, unhexlify import os import unittest @@ -29,11 +29,24 @@ from paramiko.kex_group1 import KexGroup1 from paramiko.kex_gex import KexGex, KexGexSHA256 from paramiko import Message from paramiko.common import byte_chr +from paramiko.kex_ecdh_nist import KexNistp256 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec def dummy_urandom(n): return byte_chr(0xcc) * n +def dummy_generate_key_pair(obj): + private_key_value = 94761803665136558137557783047955027733968423115106677159790289642479432803037 + public_key_numbers = "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989" + public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)) + obj.P = ec.EllipticCurvePrivateNumbers(private_value=private_key_value, public_numbers=public_key_numbers_obj).private_key(default_backend()) + if obj.transport.server_mode: + obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + return + obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + class FakeKey (object): def __str__(self): @@ -93,9 +106,12 @@ class KexTest (unittest.TestCase): def setUp(self): self._original_urandom = os.urandom os.urandom = dummy_urandom + self._original_generate_key_pair = KexNistp256._generate_key_pair + KexNistp256._generate_key_pair = dummy_generate_key_pair def tearDown(self): os.urandom = self._original_urandom + KexNistp256._generate_key_pair = self._original_generate_key_pair def test_1_group1_client(self): transport = FakeTransport() @@ -369,4 +385,43 @@ class KexTest (unittest.TestCase): self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) self.assertTrue(transport._activated) + def test_11_kex_nistp256_client(self): + K = 91610929826364598472338906427792435253694642563583721654249504912114314269754 + transport = FakeTransport() + transport.server_mode = False + kex = KexNistp256(transport) + kex.start_kex() + self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect) + + #fake reply + msg = Message() + msg.add_string('fake-host-key') + Q_S = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") + msg.add_string(Q_S) + msg.add_string('fake-sig') + msg.rewind() + kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY, msg) + H = b'BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A' + self.assertEqual(K, kex.transport._K) + self.assertEqual(H, hexlify(transport._H).upper()) + self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertTrue(transport._activated) + + def test_12_kex_nistp256_server(self): + K = 91610929826364598472338906427792435253694642563583721654249504912114314269754 + transport = FakeTransport() + transport.server_mode = True + kex = KexNistp256(transport) + kex.start_kex() + self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect) + #fake init + msg=Message() + Q_C = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") + H = b'2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA' + msg.add_string(Q_C) + msg.rewind() + kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT, msg) + self.assertEqual(K, transport._K) + self.assertTrue(transport._activated) + self.assertEqual(H, hexlify(transport._H).upper()) diff --git a/tests/test_pkey.py b/tests/test_pkey.py index a26ff170..9bb3c44c 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -113,6 +113,25 @@ TEST_KEY_BYTESTR_3 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ÓŹ class KeyTest(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def assert_keyfile_is_encrypted(self, keyfile): + """ + A quick check that filename looks like an encrypted key. + """ + with open(keyfile, "r") as fh: + self.assertEqual( + fh.readline()[:-1], + "-----BEGIN RSA PRIVATE KEY-----" + ) + self.assertEqual(fh.readline()[:-1], "Proc-Type: 4,ENCRYPTED") + self.assertEqual(fh.readline()[0:10], "DEK-Info: ") + def test_1_generate_key_bytes(self): key = util.generate_key_bytes(md5, x1234, 'happy birthday', 30) exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64' @@ -419,6 +438,7 @@ class KeyTest(unittest.TestCase): # When the bug under test exists, this will ValueError. try: key.write_private_key_file(newfile, password=newpassword) + self.assert_keyfile_is_encrypted(newfile) # Verify the inner key data still matches (when no ValueError) key2 = RSAKey(filename=newfile, password=newpassword) self.assertEqual(key, key2) @@ -435,5 +455,28 @@ class KeyTest(unittest.TestCase): key2 = Ed25519Key.from_private_key_file( test_path('test_ed25519_password.key'), b'abc123' ) - self.assertNotEqual(key1.asbytes(), key2.asbytes()) + + def test_ed25519_compare(self): + # verify that the private & public keys compare equal + key = Ed25519Key.from_private_key_file(test_path('test_ed25519.key')) + self.assertEqual(key, key) + pub = Ed25519Key(data=key.asbytes()) + self.assertTrue(key.can_sign()) + self.assertTrue(not pub.can_sign()) + self.assertEqual(key, pub) + + def test_keyfile_is_actually_encrypted(self): + # Read an existing encrypted private key + file_ = test_path('test_rsa_password.key') + password = 'television' + newfile = file_ + '.new' + newpassword = 'radio' + key = RSAKey(filename=file_, password=password) + # Write out a newly re-encrypted copy with a new password. + # When the bug under test exists, this will ValueError. + try: + key.write_private_key_file(newfile, password=newpassword) + self.assert_keyfile_is_encrypted(newfile) + finally: + os.remove(newfile) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index d3064fff..b3c7bf98 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -35,6 +35,7 @@ from tempfile import mkstemp import paramiko from paramiko.py3compat import PY2, b, u, StringIO from paramiko.common import o777, o600, o666, o644 +from tests import skipUnlessBuiltin from tests.stub_sftp import StubServer, StubSFTPServer from tests.loop import LoopSocket from tests.util import test_path @@ -276,6 +277,39 @@ class SFTPTest (unittest.TestCase): except: pass + + def test_5a_posix_rename(self): + """Test posix-rename@openssh.com protocol extension.""" + try: + # first check that the normal rename works as specified + with sftp.open(FOLDER + '/a', 'w') as f: + f.write('one') + sftp.rename(FOLDER + '/a', FOLDER + '/b') + with sftp.open(FOLDER + '/a', 'w') as f: + f.write('two') + try: + sftp.rename(FOLDER + '/a', FOLDER + '/b') + self.assertTrue(False, 'no exception when rename-ing onto existing file') + except (OSError, IOError): + pass + + # now check with the posix_rename + sftp.posix_rename(FOLDER + '/a', FOLDER + '/b') + with sftp.open(FOLDER + '/b', 'r') as f: + data = u(f.read()) + self.assertEqual('two', data, "Contents of renamed file not the same as original file") + + finally: + try: + sftp.remove(FOLDER + '/a') + except: + pass + try: + sftp.remove(FOLDER + '/b') + except: + pass + + def test_6_folder(self): """ create a temporary folder, verify that we can create a file in it, then @@ -817,6 +851,35 @@ class SFTPTest (unittest.TestCase): sftp_attributes = SFTPAttributes() self.assertEqual(str(sftp_attributes), "?--------- 1 0 0 0 (unknown date) ?") + @skipUnlessBuiltin('buffer') + def test_write_buffer(self): + """Test write() using a buffer instance.""" + data = 3 * b'A potentially large block of data to chunk up.\n' + try: + with sftp.open('%s/write_buffer' % FOLDER, 'wb') as f: + for offset in range(0, len(data), 8): + f.write(buffer(data, offset, 8)) + + with sftp.open('%s/write_buffer' % FOLDER, 'rb') as f: + self.assertEqual(f.read(), data) + finally: + sftp.remove('%s/write_buffer' % FOLDER) + + @skipUnlessBuiltin('memoryview') + def test_write_memoryview(self): + """Test write() using a memoryview instance.""" + data = 3 * b'A potentially large block of data to chunk up.\n' + try: + with sftp.open('%s/write_memoryview' % FOLDER, 'wb') as f: + view = memoryview(data) + for offset in range(0, len(data), 8): + f.write(view[offset:offset+8]) + + with sftp.open('%s/write_memoryview' % FOLDER, 'rb') as f: + self.assertEqual(f.read(), data) + finally: + sftp.remove('%s/write_memoryview' % FOLDER) + if __name__ == '__main__': SFTPTest.init_loopback() diff --git a/tests/test_transport.py b/tests/test_transport.py index 2ebdf854..3e352919 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -43,6 +43,7 @@ from paramiko.common import ( ) from paramiko.py3compat import bytes from paramiko.message import Message +from tests import skipUnlessBuiltin from tests.loop import LoopSocket from tests.util import test_path @@ -165,6 +166,15 @@ class TransportTest(unittest.TestCase): except TypeError: pass + def test_1b_security_options_reset(self): + o = self.tc.get_security_options() + # should not throw any exceptions + o.ciphers = o.ciphers + o.digests = o.digests + o.key_types = o.key_types + o.kex = o.kex + o.compression = o.compression + def test_2_compute_key(self): self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3' @@ -849,3 +859,71 @@ class TransportTest(unittest.TestCase): self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) + + def test_channel_send_misc(self): + """ + verify behaviours sending various instances to a channel + """ + self.setup_test_server() + text = u"\xa7 slice me nicely" + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # TypeError raised on non string or buffer type + self.assertRaises(TypeError, chan.send, object()) + self.assertRaises(TypeError, chan.sendall, object()) + + # sendall() accepts a unicode instance + chan.sendall(text) + expected = text.encode("utf-8") + self.assertEqual(sfile.read(len(expected)), expected) + + @skipUnlessBuiltin('buffer') + def test_channel_send_buffer(self): + """ + verify sending buffer instances to a channel + """ + self.setup_test_server() + data = 3 * b'some test data\n whole' + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts buffer instances + sent = 0 + while sent < len(data): + sent += chan.send(buffer(data, sent, 8)) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a buffer instance + chan.sendall(buffer(data)) + self.assertEqual(sfile.read(len(data)), data) + + @skipUnlessBuiltin('memoryview') + def test_channel_send_memoryview(self): + """ + verify sending memoryview instances to a channel + """ + self.setup_test_server() + data = 3 * b'some test data\n whole' + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts memoryview slices + sent = 0 + view = memoryview(data) + while sent < len(view): + sent += chan.send(view[sent:sent+8]) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a memoryview instance + chan.sendall(memoryview(data)) + self.assertEqual(sfile.read(len(data)), data) |