diff options
-rw-r--r-- | paramiko/common.py | 22 | ||||
-rw-r--r-- | paramiko/file.py | 6 | ||||
-rw-r--r-- | paramiko/sftp_client.py | 11 | ||||
-rw-r--r-- | sites/www/changelog.rst | 5 | ||||
-rw-r--r-- | tests/__init__.py | 36 | ||||
-rw-r--r-- | tests/test_client.py | 2 | ||||
-rwxr-xr-x | tests/test_file.py | 57 | ||||
-rwxr-xr-x | tests/test_sftp.py | 30 | ||||
-rw-r--r-- | tests/test_transport.py | 69 |
9 files changed, 207 insertions, 31 deletions
diff --git a/paramiko/common.py b/paramiko/common.py index 556f046a..0012372a 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -20,9 +20,7 @@ Common constants and global variables. """ import logging -from paramiko.py3compat import ( - byte_chr, PY2, bytes_types, string_types, b, long, -) +from paramiko.py3compat import byte_chr, PY2, bytes_types, text_type, long MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \ MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT = range(1, 7) @@ -163,14 +161,16 @@ else: def asbytes(s): - if not isinstance(s, bytes_types): - if isinstance(s, string_types): - s = b(s) - else: - try: - s = s.asbytes() - except Exception: - raise Exception('Unknown type') + """Coerce to bytes if possible or return unchanged.""" + if isinstance(s, bytes_types): + return s + if isinstance(s, text_type): + # Accept text and encode as utf-8 for compatibility only. + return s.encode("utf-8") + asbytes = getattr(s, "asbytes", None) + if asbytes is not None: + return asbytes() + # May be an object that implements the buffer api, let callers handle. return s diff --git a/paramiko/file.py b/paramiko/file.py index 5212091a..a1bdafbe 100644 --- a/paramiko/file.py +++ b/paramiko/file.py @@ -18,7 +18,7 @@ from paramiko.common import ( linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value, ) -from paramiko.py3compat import BytesIO, PY2, u, b, bytes_types +from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type from paramiko.util import ClosingContextManager @@ -391,7 +391,9 @@ class BufferedFile (ClosingContextManager): :param data: ``str``/``bytes`` data to write """ - data = b(data) + if isinstance(data, text_type): + # Accept text and encode as utf-8 for compatibility only. + data = data.encode('utf-8') if self._closed: raise IOError('File is closed') if not (self._flags & self.FLAG_WRITE): diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 12fccb2f..ee5ab073 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -28,9 +28,7 @@ from paramiko import util from paramiko.channel import Channel from paramiko.message import Message from paramiko.common import INFO, DEBUG, o777 -from paramiko.py3compat import ( - bytestring, b, u, long, string_types, bytes_types, -) +from paramiko.py3compat import bytestring, b, u, long from paramiko.sftp import ( BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READDIR, CMD_NAME, CMD_CLOSE, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_CREATE, @@ -758,13 +756,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager): msg.add_int64(item) elif isinstance(item, int): msg.add_int(item) - elif isinstance(item, (string_types, bytes_types)): - msg.add_string(item) elif isinstance(item, SFTPAttributes): item._pack(msg) else: - raise Exception( - 'unknown type for %r type %r' % (item, type(item))) + # For all other types, rely on as_string() to either coerce + # to bytes before writing or raise a suitable exception. + msg.add_string(item) num = self.request_number self._expecting[num] = fileobj self.request_number += 1 diff --git a/sites/www/changelog.rst b/sites/www/changelog.rst index e9adee21..ec547b58 100644 --- a/sites/www/changelog.rst +++ b/sites/www/changelog.rst @@ -2,6 +2,11 @@ Changelog ========= +* :bug:`971 (1.17+)` Allow any type implementing the buffer API to be used with + `BufferedFile <paramiko.file.BufferedFile>`, `Channel + <paramiko.channel.Channel>`, and `SFTPFile <paramiko.sftp_file.SFTPFile>`. + This resolves a regression introduced in 1.13 with the Python 3 porting + changes, when using types such as ``memoryview``. Credit: Martin Packman. * :bug:`741` (also :issue:`809`, :issue:`772`; all via :issue:`912`) Writing encrypted/password-protected private key files was silently broken since 2.0 due to an incorrect API call; this has been fixed. diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..8878f14d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,36 @@ +# Copyright (C) 2017 Martin Packman <gzlist@googlemail.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +"""Base classes and helpers for testing paramiko.""" + +import unittest + +from paramiko.py3compat import ( + builtins, + ) + + +def skipUnlessBuiltin(name): + """Skip decorated test if builtin name does not exist.""" + if getattr(builtins, name, None) is None: + skip = getattr(unittest, "skip", None) + if skip is None: + # Python 2.6 pseudo-skip + return lambda func: None + return skip("No builtin " + repr(name)) + return lambda func: func diff --git a/tests/test_client.py b/tests/test_client.py index 229df991..cc2d067f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -35,7 +35,7 @@ import time from tests.util import test_path import paramiko -from paramiko.common import PY2 +from paramiko.py3compat import PY2, b from paramiko.ssh_exception import SSHException diff --git a/tests/test_file.py b/tests/test_file.py index 7fab6985..b33ecd51 100755 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -21,10 +21,14 @@ Some unit tests for the BufferedFile abstraction. """ import unittest -from paramiko.file import BufferedFile -from paramiko.common import linefeed_byte, crlf, cr_byte import sys +from paramiko.common import linefeed_byte, crlf, cr_byte +from paramiko.file import BufferedFile +from paramiko.py3compat import BytesIO + +from tests import skipUnlessBuiltin + class LoopbackFile (BufferedFile): """ @@ -33,19 +37,16 @@ class LoopbackFile (BufferedFile): def __init__(self, mode='r', bufsize=-1): BufferedFile.__init__(self) self._set_mode(mode, bufsize) - self.buffer = bytes() + self.buffer = BytesIO() + self.offset = 0 def _read(self, size): - if len(self.buffer) == 0: - return None - if size > len(self.buffer): - size = len(self.buffer) - data = self.buffer[:size] - self.buffer = self.buffer[size:] + data = self.buffer.getvalue()[self.offset:self.offset+size] + self.offset += len(data) return data def _write(self, data): - self.buffer += data + self.buffer.write(data) return len(data) @@ -187,6 +188,42 @@ class BufferedFileTest (unittest.TestCase): self.assertEqual(data, b'hello') f.close() + def test_write_bad_type(self): + with LoopbackFile('wb') as f: + self.assertRaises(TypeError, f.write, object()) + + def test_write_unicode_as_binary(self): + text = u"\xa7 why is writing text to a binary file allowed?\n" + with LoopbackFile('rb+') as f: + f.write(text) + self.assertEqual(f.read(), text.encode("utf-8")) + + @skipUnlessBuiltin('memoryview') + def test_write_bytearray(self): + with LoopbackFile('rb+') as f: + f.write(bytearray(12)) + self.assertEqual(f.read(), 12 * b"\0") + + @skipUnlessBuiltin('buffer') + def test_write_buffer(self): + data = 3 * b"pretend giant block of data\n" + offsets = range(0, len(data), 8) + with LoopbackFile('rb+') as f: + for offset in offsets: + f.write(buffer(data, offset, 8)) + self.assertEqual(f.read(), data) + + @skipUnlessBuiltin('memoryview') + def test_write_memoryview(self): + data = 3 * b"pretend giant block of data\n" + offsets = range(0, len(data), 8) + with LoopbackFile('rb+') as f: + view = memoryview(data) + for offset in offsets: + f.write(view[offset:offset+8]) + self.assertEqual(f.read(), data) + + if __name__ == '__main__': from unittest import main main() diff --git a/tests/test_sftp.py b/tests/test_sftp.py index d3064fff..98a9cebb 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -35,6 +35,7 @@ from tempfile import mkstemp import paramiko from paramiko.py3compat import PY2, b, u, StringIO from paramiko.common import o777, o600, o666, o644 +from tests import skipUnlessBuiltin from tests.stub_sftp import StubServer, StubSFTPServer from tests.loop import LoopSocket from tests.util import test_path @@ -817,6 +818,35 @@ class SFTPTest (unittest.TestCase): sftp_attributes = SFTPAttributes() self.assertEqual(str(sftp_attributes), "?--------- 1 0 0 0 (unknown date) ?") + @skipUnlessBuiltin('buffer') + def test_write_buffer(self): + """Test write() using a buffer instance.""" + data = 3 * b'A potentially large block of data to chunk up.\n' + try: + with sftp.open('%s/write_buffer' % FOLDER, 'wb') as f: + for offset in range(0, len(data), 8): + f.write(buffer(data, offset, 8)) + + with sftp.open('%s/write_buffer' % FOLDER, 'rb') as f: + self.assertEqual(f.read(), data) + finally: + sftp.remove('%s/write_buffer' % FOLDER) + + @skipUnlessBuiltin('memoryview') + def test_write_memoryview(self): + """Test write() using a memoryview instance.""" + data = 3 * b'A potentially large block of data to chunk up.\n' + try: + with sftp.open('%s/write_memoryview' % FOLDER, 'wb') as f: + view = memoryview(data) + for offset in range(0, len(data), 8): + f.write(view[offset:offset+8]) + + with sftp.open('%s/write_memoryview' % FOLDER, 'rb') as f: + self.assertEqual(f.read(), data) + finally: + sftp.remove('%s/write_memoryview' % FOLDER) + if __name__ == '__main__': SFTPTest.init_loopback() diff --git a/tests/test_transport.py b/tests/test_transport.py index c426cef1..3e352919 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -43,6 +43,7 @@ from paramiko.common import ( ) from paramiko.py3compat import bytes from paramiko.message import Message +from tests import skipUnlessBuiltin from tests.loop import LoopSocket from tests.util import test_path @@ -858,3 +859,71 @@ class TransportTest(unittest.TestCase): self.assertEqual([chan], r) self.assertEqual([], w) self.assertEqual([], e) + + def test_channel_send_misc(self): + """ + verify behaviours sending various instances to a channel + """ + self.setup_test_server() + text = u"\xa7 slice me nicely" + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # TypeError raised on non string or buffer type + self.assertRaises(TypeError, chan.send, object()) + self.assertRaises(TypeError, chan.sendall, object()) + + # sendall() accepts a unicode instance + chan.sendall(text) + expected = text.encode("utf-8") + self.assertEqual(sfile.read(len(expected)), expected) + + @skipUnlessBuiltin('buffer') + def test_channel_send_buffer(self): + """ + verify sending buffer instances to a channel + """ + self.setup_test_server() + data = 3 * b'some test data\n whole' + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts buffer instances + sent = 0 + while sent < len(data): + sent += chan.send(buffer(data, sent, 8)) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a buffer instance + chan.sendall(buffer(data)) + self.assertEqual(sfile.read(len(data)), data) + + @skipUnlessBuiltin('memoryview') + def test_channel_send_memoryview(self): + """ + verify sending memoryview instances to a channel + """ + self.setup_test_server() + data = 3 * b'some test data\n whole' + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts memoryview slices + sent = 0 + view = memoryview(data) + while sent < len(view): + sent += chan.send(view[sent:sent+8]) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a memoryview instance + chan.sendall(memoryview(data)) + self.assertEqual(sfile.read(len(data)), data) |