summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMartin Packman <gzlist@googlemail.com>2017-05-25 22:14:19 +0100
committerJeff Forcier <jeff@bitprophet.org>2017-06-09 13:42:49 -0700
commitc2c402cb6e4a4f86ff3053389bb6300c4b4505f1 (patch)
treefe7cc0aa3252f97c2c836fecfacb1eaacfbe67ea
parent02b0aef5932438c2dd0b7d649a6e61450fb71be3 (diff)
Allow any buffer type to be sent to Channel
Fixes #968 Changes the behaviour of the underlying asbytes helper to pass along unknown types. Most callers already handle this by passing the bytes along to a file or socket-like object which will raise TypeError anyway. Adds test coverage through the Transport implementation. Change against the 1.17 branch.
-rw-r--r--paramiko/common.py19
-rw-r--r--tests/test_transport.py69
2 files changed, 79 insertions, 9 deletions
diff --git a/paramiko/common.py b/paramiko/common.py
index 556f046a..3b18858b 100644
--- a/paramiko/common.py
+++ b/paramiko/common.py
@@ -21,7 +21,7 @@ Common constants and global variables.
"""
import logging
from paramiko.py3compat import (
- byte_chr, PY2, bytes_types, string_types, b, long,
+ byte_chr, PY2, bytes_types, text_type, string_types, b, long,
)
MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \
@@ -163,14 +163,15 @@ else:
def asbytes(s):
- if not isinstance(s, bytes_types):
- if isinstance(s, string_types):
- s = b(s)
- else:
- try:
- s = s.asbytes()
- except Exception:
- raise Exception('Unknown type')
+ if isinstance(s, bytes_types):
+ return s
+ if isinstance(s, text_type):
+ # GZ 2017-05-25: Accept text and encode as utf-8 for compatibilty only.
+ return s.encode("utf-8")
+ asbytes = getattr(s, "asbytes", None)
+ if asbytes is not None:
+ return asbytes()
+ # May be an object that implements the buffer api, let callers decide
return s
diff --git a/tests/test_transport.py b/tests/test_transport.py
index c426cef1..3e352919 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -43,6 +43,7 @@ from paramiko.common import (
)
from paramiko.py3compat import bytes
from paramiko.message import Message
+from tests import skipUnlessBuiltin
from tests.loop import LoopSocket
from tests.util import test_path
@@ -858,3 +859,71 @@ class TransportTest(unittest.TestCase):
self.assertEqual([chan], r)
self.assertEqual([], w)
self.assertEqual([], e)
+
+ def test_channel_send_misc(self):
+ """
+ verify behaviours sending various instances to a channel
+ """
+ self.setup_test_server()
+ text = u"\xa7 slice me nicely"
+ with self.tc.open_session() as chan:
+ schan = self.ts.accept(1.0)
+ if schan is None:
+ self.fail("Test server transport failed to accept")
+ sfile = schan.makefile()
+
+ # TypeError raised on non string or buffer type
+ self.assertRaises(TypeError, chan.send, object())
+ self.assertRaises(TypeError, chan.sendall, object())
+
+ # sendall() accepts a unicode instance
+ chan.sendall(text)
+ expected = text.encode("utf-8")
+ self.assertEqual(sfile.read(len(expected)), expected)
+
+ @skipUnlessBuiltin('buffer')
+ def test_channel_send_buffer(self):
+ """
+ verify sending buffer instances to a channel
+ """
+ self.setup_test_server()
+ data = 3 * b'some test data\n whole'
+ with self.tc.open_session() as chan:
+ schan = self.ts.accept(1.0)
+ if schan is None:
+ self.fail("Test server transport failed to accept")
+ sfile = schan.makefile()
+
+ # send() accepts buffer instances
+ sent = 0
+ while sent < len(data):
+ sent += chan.send(buffer(data, sent, 8))
+ self.assertEqual(sfile.read(len(data)), data)
+
+ # sendall() accepts a buffer instance
+ chan.sendall(buffer(data))
+ self.assertEqual(sfile.read(len(data)), data)
+
+ @skipUnlessBuiltin('memoryview')
+ def test_channel_send_memoryview(self):
+ """
+ verify sending memoryview instances to a channel
+ """
+ self.setup_test_server()
+ data = 3 * b'some test data\n whole'
+ with self.tc.open_session() as chan:
+ schan = self.ts.accept(1.0)
+ if schan is None:
+ self.fail("Test server transport failed to accept")
+ sfile = schan.makefile()
+
+ # send() accepts memoryview slices
+ sent = 0
+ view = memoryview(data)
+ while sent < len(view):
+ sent += chan.send(view[sent:sent+8])
+ self.assertEqual(sfile.read(len(data)), data)
+
+ # sendall() accepts a memoryview instance
+ chan.sendall(memoryview(data))
+ self.assertEqual(sfile.read(len(data)), data)