summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMartin Packman <gzlist@googlemail.com>2017-05-25 20:37:14 +0100
committerJeff Forcier <jeff@bitprophet.org>2017-06-09 13:41:17 -0700
commite66899a87ed2e510b66116284fc95763f5a2f4ec (patch)
tree97dfe2033c70d97d4f7a30b4f36ec892b50ea269
parent85dd4d777ccfe9e9fd23c650d43386a0b50444b7 (diff)
Allow any buffer type to written to BufferedFile
Fixes #967 Also adds test coverage for writing various types to BufferedFile which required some small changes to the test LoopbackFile subclass. Change against the 1.17 branch.
-rw-r--r--paramiko/file.py7
-rwxr-xr-xtests/test_file.py57
2 files changed, 52 insertions, 12 deletions
diff --git a/paramiko/file.py b/paramiko/file.py
index 5212091a..35637b8e 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -18,7 +18,7 @@
from paramiko.common import (
linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value,
)
-from paramiko.py3compat import BytesIO, PY2, u, b, bytes_types
+from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type
from paramiko.util import ClosingContextManager
@@ -391,7 +391,10 @@ class BufferedFile (ClosingContextManager):
:param data: ``str``/``bytes`` data to write
"""
- data = b(data)
+ if isinstance(data, text_type):
+ # GZ 2017-05-25: Accepting text on a binary stream unconditionally
+ # cooercing to utf-8 seems questionable, but compatibility reasons?
+ data = data.encode('utf-8')
if self._closed:
raise IOError('File is closed')
if not (self._flags & self.FLAG_WRITE):
diff --git a/tests/test_file.py b/tests/test_file.py
index 7fab6985..b33ecd51 100755
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -21,10 +21,14 @@ Some unit tests for the BufferedFile abstraction.
"""
import unittest
-from paramiko.file import BufferedFile
-from paramiko.common import linefeed_byte, crlf, cr_byte
import sys
+from paramiko.common import linefeed_byte, crlf, cr_byte
+from paramiko.file import BufferedFile
+from paramiko.py3compat import BytesIO
+
+from tests import skipUnlessBuiltin
+
class LoopbackFile (BufferedFile):
"""
@@ -33,19 +37,16 @@ class LoopbackFile (BufferedFile):
def __init__(self, mode='r', bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
- self.buffer = bytes()
+ self.buffer = BytesIO()
+ self.offset = 0
def _read(self, size):
- if len(self.buffer) == 0:
- return None
- if size > len(self.buffer):
- size = len(self.buffer)
- data = self.buffer[:size]
- self.buffer = self.buffer[size:]
+ data = self.buffer.getvalue()[self.offset:self.offset+size]
+ self.offset += len(data)
return data
def _write(self, data):
- self.buffer += data
+ self.buffer.write(data)
return len(data)
@@ -187,6 +188,42 @@ class BufferedFileTest (unittest.TestCase):
self.assertEqual(data, b'hello')
f.close()
+ def test_write_bad_type(self):
+ with LoopbackFile('wb') as f:
+ self.assertRaises(TypeError, f.write, object())
+
+ def test_write_unicode_as_binary(self):
+ text = u"\xa7 why is writing text to a binary file allowed?\n"
+ with LoopbackFile('rb+') as f:
+ f.write(text)
+ self.assertEqual(f.read(), text.encode("utf-8"))
+
+ @skipUnlessBuiltin('memoryview')
+ def test_write_bytearray(self):
+ with LoopbackFile('rb+') as f:
+ f.write(bytearray(12))
+ self.assertEqual(f.read(), 12 * b"\0")
+
+ @skipUnlessBuiltin('buffer')
+ def test_write_buffer(self):
+ data = 3 * b"pretend giant block of data\n"
+ offsets = range(0, len(data), 8)
+ with LoopbackFile('rb+') as f:
+ for offset in offsets:
+ f.write(buffer(data, offset, 8))
+ self.assertEqual(f.read(), data)
+
+ @skipUnlessBuiltin('memoryview')
+ def test_write_memoryview(self):
+ data = 3 * b"pretend giant block of data\n"
+ offsets = range(0, len(data), 8)
+ with LoopbackFile('rb+') as f:
+ view = memoryview(data)
+ for offset in offsets:
+ f.write(view[offset:offset+8])
+ self.assertEqual(f.read(), data)
+
+
if __name__ == '__main__':
from unittest import main
main()