summaryrefslogtreecommitdiffhomepage
path: root/tests/test_file.py
diff options
context:
space:
mode:
authorJeff Forcier <jeff@bitprophet.org>2017-06-09 13:53:43 -0700
committerJeff Forcier <jeff@bitprophet.org>2017-06-09 13:53:43 -0700
commit38d9f5da2e41f63b8a2c54b3a9b603489dc624e3 (patch)
tree312c5c873f03f1d0dd4ea4895c6d90955aa4de81 /tests/test_file.py
parentc11ec97a756afcf60058a4214f56c0fdd4f6a587 (diff)
parent35e75f7dba78ca93065eb138e746863892258a2a (diff)
Merge branch '2.1'
Diffstat (limited to 'tests/test_file.py')
-rwxr-xr-xtests/test_file.py57
1 files changed, 47 insertions, 10 deletions
diff --git a/tests/test_file.py b/tests/test_file.py
index 7fab6985..b33ecd51 100755
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -21,10 +21,14 @@ Some unit tests for the BufferedFile abstraction.
"""
import unittest
-from paramiko.file import BufferedFile
-from paramiko.common import linefeed_byte, crlf, cr_byte
import sys
+from paramiko.common import linefeed_byte, crlf, cr_byte
+from paramiko.file import BufferedFile
+from paramiko.py3compat import BytesIO
+
+from tests import skipUnlessBuiltin
+
class LoopbackFile (BufferedFile):
"""
@@ -33,19 +37,16 @@ class LoopbackFile (BufferedFile):
def __init__(self, mode='r', bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
- self.buffer = bytes()
+ self.buffer = BytesIO()
+ self.offset = 0
def _read(self, size):
- if len(self.buffer) == 0:
- return None
- if size > len(self.buffer):
- size = len(self.buffer)
- data = self.buffer[:size]
- self.buffer = self.buffer[size:]
+ data = self.buffer.getvalue()[self.offset:self.offset+size]
+ self.offset += len(data)
return data
def _write(self, data):
- self.buffer += data
+ self.buffer.write(data)
return len(data)
@@ -187,6 +188,42 @@ class BufferedFileTest (unittest.TestCase):
self.assertEqual(data, b'hello')
f.close()
+ def test_write_bad_type(self):
+ with LoopbackFile('wb') as f:
+ self.assertRaises(TypeError, f.write, object())
+
+ def test_write_unicode_as_binary(self):
+ text = u"\xa7 why is writing text to a binary file allowed?\n"
+ with LoopbackFile('rb+') as f:
+ f.write(text)
+ self.assertEqual(f.read(), text.encode("utf-8"))
+
+ @skipUnlessBuiltin('memoryview')
+ def test_write_bytearray(self):
+ with LoopbackFile('rb+') as f:
+ f.write(bytearray(12))
+ self.assertEqual(f.read(), 12 * b"\0")
+
+ @skipUnlessBuiltin('buffer')
+ def test_write_buffer(self):
+ data = 3 * b"pretend giant block of data\n"
+ offsets = range(0, len(data), 8)
+ with LoopbackFile('rb+') as f:
+ for offset in offsets:
+ f.write(buffer(data, offset, 8))
+ self.assertEqual(f.read(), data)
+
+ @skipUnlessBuiltin('memoryview')
+ def test_write_memoryview(self):
+ data = 3 * b"pretend giant block of data\n"
+ offsets = range(0, len(data), 8)
+ with LoopbackFile('rb+') as f:
+ view = memoryview(data)
+ for offset in offsets:
+ f.write(view[offset:offset+8])
+ self.assertEqual(f.read(), data)
+
+
if __name__ == '__main__':
from unittest import main
main()