summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/file.py60
-rw-r--r--paramiko/sftp_file.py11
-rwxr-xr-xtests/test_sftp.py21
3 files changed, 61 insertions, 31 deletions
diff --git a/paramiko/file.py b/paramiko/file.py
index 1971ce08..98a28f6f 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -23,13 +23,6 @@ BufferedFile.
from cStringIO import StringIO
-_FLAG_READ = 0x1
-_FLAG_WRITE = 0x2
-_FLAG_APPEND = 0x4
-_FLAG_BINARY = 0x10
-_FLAG_BUFFERED = 0x20
-_FLAG_LINE_BUFFERED = 0x40
-_FLAG_UNIVERSAL_NEWLINE = 0x80
class BufferedFile (object):
@@ -44,6 +37,14 @@ class BufferedFile (object):
SEEK_CUR = 1
SEEK_END = 2
+ FLAG_READ = 0x1
+ FLAG_WRITE = 0x2
+ FLAG_APPEND = 0x4
+ FLAG_BINARY = 0x10
+ FLAG_BUFFERED = 0x20
+ FLAG_LINE_BUFFERED = 0x40
+ FLAG_UNIVERSAL_NEWLINE = 0x80
+
def __init__(self):
self.newlines = None
self._flags = 0
@@ -123,7 +124,7 @@ class BufferedFile (object):
"""
if self._closed:
raise IOError('File is closed')
- if not (self._flags & _FLAG_READ):
+ if not (self._flags & self.FLAG_READ):
raise IOError('File not open for reading')
if (size is None) or (size < 0):
# go for broke
@@ -148,7 +149,7 @@ class BufferedFile (object):
return result
while len(self._rbuffer) < size:
read_size = size - len(self._rbuffer)
- if self._flags & _FLAG_BUFFERED:
+ if self._flags & self.FLAG_BUFFERED:
read_size = max(self._bufsize, read_size)
try:
new_data = self._read(read_size)
@@ -184,11 +185,11 @@ class BufferedFile (object):
# it's almost silly how complex this function is.
if self._closed:
raise IOError('File is closed')
- if not (self._flags & _FLAG_READ):
+ if not (self._flags & self.FLAG_READ):
raise IOError('File not open for reading')
line = self._rbuffer
while True:
- if self._at_trailing_cr and (self._flags & _FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
+ if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
# edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time.
if line[0] == '\n':
@@ -209,7 +210,7 @@ class BufferedFile (object):
n = size - len(line)
else:
n = self._bufsize
- if ('\n' in line) or ((self._flags & _FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
+ if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
break
try:
new_data = self._read(n)
@@ -223,7 +224,7 @@ class BufferedFile (object):
self._realpos += len(new_data)
# find the newline
pos = line.find('\n')
- if self._flags & _FLAG_UNIVERSAL_NEWLINE:
+ if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
rpos = line.find('\r')
if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
pos = rpos
@@ -295,6 +296,8 @@ class BufferedFile (object):
@return: file position (in bytes).
@rtype: int
"""
+ if self._flags & self.FLAG_APPEND:
+ return self._get_size()
return self._pos
def write(self, data):
@@ -309,13 +312,13 @@ class BufferedFile (object):
"""
if self._closed:
raise IOError('File is closed')
- if not (self._flags & _FLAG_WRITE):
+ if not (self._flags & self.FLAG_WRITE):
raise IOError('File not open for writing')
- if not (self._flags & _FLAG_BUFFERED):
+ if not (self._flags & self.FLAG_BUFFERED):
self._write_all(data)
return
self._wbuffer.write(data)
- if self._flags & _FLAG_LINE_BUFFERED:
+ if self._flags & self.FLAG_LINE_BUFFERED:
# only scan the new data for linefeed, to avoid wasting time.
last_newline_pos = data.rfind('\n')
if last_newline_pos >= 0:
@@ -397,22 +400,21 @@ class BufferedFile (object):
# apparently, line buffering only affects writes. reads are only
# buffered if you call readline (directly or indirectly: iterating
# over a file will indirectly call readline).
- self._flags |= _FLAG_BUFFERED | _FLAG_LINE_BUFFERED
+ self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED
elif bufsize > 1:
self._bufsize = bufsize
- self._flags |= _FLAG_BUFFERED
+ self._flags |= self.FLAG_BUFFERED
if ('r' in mode) or ('+' in mode):
- self._flags |= _FLAG_READ
+ self._flags |= self.FLAG_READ
if ('w' in mode) or ('+' in mode):
- self._flags |= _FLAG_WRITE
+ self._flags |= self.FLAG_WRITE
if ('a' in mode):
- self._flags |= _FLAG_WRITE | _FLAG_APPEND
- self._size = self._get_size()
- self._pos = self._realpos = self._size
+ self._flags |= self.FLAG_WRITE | self.FLAG_APPEND
+ self._pos = self._realpos = -1
if ('b' in mode):
- self._flags |= _FLAG_BINARY
+ self._flags |= self.FLAG_BINARY
if ('U' in mode):
- self._flags |= _FLAG_UNIVERSAL_NEWLINE
+ self._flags |= self.FLAG_UNIVERSAL_NEWLINE
# built-in file objects have this attribute to store which kinds of
# line terminations they've seen:
# <http://www.python.org/doc/current/lib/built-in-funcs.html>
@@ -424,9 +426,9 @@ class BufferedFile (object):
while len(data) > 0:
count = self._write(data)
data = data[count:]
- if self._flags & _FLAG_APPEND:
- self._size += count
- self._pos = self._realpos = self._size
+ if self._flags & self.FLAG_APPEND:
+ # even if we used to know our seek position, we don't now.
+ self._pos = self._realpos = -1
else:
self._pos += count
self._realpos += count
@@ -436,7 +438,7 @@ class BufferedFile (object):
# silliness about tracking what kinds of newlines we've seen.
# i don't understand why it can be None, a string, or a tuple, instead
# of just always being a tuple, but we'll emulate that behavior anyway.
- if not (self._flags & _FLAG_UNIVERSAL_NEWLINE):
+ if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE):
return
if self.newlines is None:
self.newlines = newline
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index 1e5478b2..e29d5574 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -159,8 +159,11 @@ class SFTPFile (BufferedFile):
def _write(self, data):
# may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE)
- req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos),
- str(data[:chunk]))
+ if self._flags & self.FLAG_APPEND:
+ pos = 0
+ else:
+ pos = self._realpos
+ req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(pos), str(data[:chunk]))
if not self.pipelined or self.sftp.sock.recv_ready():
t, msg = self.sftp._read_response(req)
if t != CMD_STATUS:
@@ -204,6 +207,10 @@ class SFTPFile (BufferedFile):
def seek(self, offset, whence=0):
self.flush()
+ if (self._flags & self.FLAG_APPEND) and (self._realpos == -1) and (whence != self.SEEK_END):
+ # this is still legal for O_RDWR ('a+'), but we need to figure out
+ # where we are -- we lost track of it during writes.
+ self._realpos = self._pos = self._get_size()
if whence == self.SEEK_SET:
self._realpos = self._pos = offset
elif whence == self.SEEK_CUR:
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index 1e2785d1..db210131 100755
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -648,3 +648,24 @@ class SFTPTest (unittest.TestCase):
f.close()
finally:
sftp.unlink(FOLDER + '/zero')
+
+ def test_M_seek_append(self):
+ """
+ verify that seek does't affect writes during append.
+ """
+ f = sftp.open(FOLDER + '/append.txt', 'a')
+ try:
+ f.write('first line\nsecond line\n')
+ f.seek(11, f.SEEK_SET)
+ f.write('third line\n')
+ f.close()
+
+ f = sftp.open(FOLDER + '/append.txt', 'r')
+ self.assertEqual(f.stat().st_size, 34)
+ self.assertEqual(f.readline(), 'first line\n')
+ self.assertEqual(f.readline(), 'second line\n')
+ self.assertEqual(f.readline(), 'third line\n')
+ f.close()
+ finally:
+ sftp.remove(FOLDER + '/append.txt')
+