From c2c402cb6e4a4f86ff3053389bb6300c4b4505f1 Mon Sep 17 00:00:00 2001 From: Martin Packman Date: Thu, 25 May 2017 22:14:19 +0100 Subject: Allow any buffer type to be sent to Channel Fixes #968 Changes the behaviour of the underlying asbytes helper to pass along unknown types. Most callers already handle this by passing the bytes along to a file or socket-like object which will raise TypeError anyway. Adds test coverage through the Transport implementation. Change against the 1.17 branch. --- paramiko/common.py | 19 +++++++------- tests/test_transport.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/paramiko/common.py b/paramiko/common.py index 556f046a..3b18858b 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -21,7 +21,7 @@ Common constants and global variables. """ import logging from paramiko.py3compat import ( - byte_chr, PY2, bytes_types, string_types, b, long, + byte_chr, PY2, bytes_types, text_type, string_types, b, long, ) MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \ @@ -163,14 +163,15 @@ else: def asbytes(s): - if not isinstance(s, bytes_types): - if isinstance(s, string_types): - s = b(s) - else: - try: - s = s.asbytes() - except Exception: - raise Exception('Unknown type') + if isinstance(s, bytes_types): + return s + if isinstance(s, text_type): + # GZ 2017-05-25: Accept text and encode as utf-8 for compatibilty only. + return s.encode("utf-8") + asbytes = getattr(s, "asbytes", None) + if asbytes is not None: + return asbytes() + # May be an object that implements the buffer api, let callers decide return s diff --git a/tests/test_transport.py b/tests/test_transport.py index c426cef1..3e352919 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -43,6 +43,7 @@ from paramiko.common import ( ) from paramiko.py3compat import bytes from paramiko.message import Message +from tests import skipUnlessBuiltin from tests.loop import LoopSocket from tests.util import test_path @@ -858,3 +859,71 @@ class TransportTest(unittest.TestCase): self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) + + def test_channel_send_misc(self): + """ + verify behaviours sending various instances to a channel + """ + self.setup_test_server() + text = u"\xa7 slice me nicely" + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # TypeError raised on non string or buffer type + self.assertRaises(TypeError, chan.send, object()) + self.assertRaises(TypeError, chan.sendall, object()) + + # sendall() accepts a unicode instance + chan.sendall(text) + expected = text.encode("utf-8") + self.assertEqual(sfile.read(len(expected)), expected) + + @skipUnlessBuiltin('buffer') + def test_channel_send_buffer(self): + """ + verify sending buffer instances to a channel + """ + self.setup_test_server() + data = 3 * b'some test data\n whole' + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts buffer instances + sent = 0 + while sent < len(data): + sent += chan.send(buffer(data, sent, 8)) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a buffer instance + chan.sendall(buffer(data)) + self.assertEqual(sfile.read(len(data)), data) + + @skipUnlessBuiltin('memoryview') + def test_channel_send_memoryview(self): + """ + verify sending memoryview instances to a channel + """ + self.setup_test_server() + data = 3 * b'some test data\n whole' + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts memoryview slices + sent = 0 + view = memoryview(data) + while sent < len(view): + sent += chan.send(view[sent:sent+8]) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a memoryview instance + chan.sendall(memoryview(data)) + self.assertEqual(sfile.read(len(data)), data) -- cgit v1.2.3