summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/sftp_client.py16
-rwxr-xr-xtests/test_sftp.py10
2 files changed, 22 insertions, 4 deletions
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 338d19b8..16728d15 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -522,7 +522,7 @@ class SFTPClient (BaseSFTP):
"""
return self._cwd
- def put(self, localpath, remotepath):
+ def put(self, localpath, remotepath, callback=None):
"""
Copy a local file (C{localpath}) to the SFTP server as C{remotepath}.
Any exception raised by operations will be passed through. This
@@ -534,12 +534,16 @@ class SFTPClient (BaseSFTP):
@type localpath: str
@param remotepath: the destination path on the SFTP server
@type remotepath: str
+ @param callback: optional callback function that accepts the bytes
+ transferred so far and the total bytes to be transferred
+ @type callback: function(int, int)
@return: an object containing attributes about the given file
(since 1.7.4)
@rtype: SFTPAttributes
@since: 1.4
"""
+ file_size = os.stat(localpath).st_size
fl = file(localpath, 'rb')
fr = self.file(remotepath, 'wb')
fr.set_pipelined(True)
@@ -550,6 +554,8 @@ class SFTPClient (BaseSFTP):
break
fr.write(data)
size += len(data)
+ if callback is not None:
+ callback(size, file_size)
fl.close()
fr.close()
s = self.stat(remotepath)
@@ -557,7 +563,7 @@ class SFTPClient (BaseSFTP):
raise IOError('size mismatch in put! %d != %d' % (s.st_size, size))
return s
- def get(self, remotepath, localpath):
+ def get(self, remotepath, localpath, callback=None):
"""
Copy a remote file (C{remotepath}) from the SFTP server to the local
host as C{localpath}. Any exception raised by operations will be
@@ -567,10 +573,14 @@ class SFTPClient (BaseSFTP):
@type remotepath: str
@param localpath: the destination path on the local host
@type localpath: str
+ @param callback: optional callback function that accepts the bytes
+ transferred so far and the total bytes to be transferred
+ @type callback: function(int, int)
@since: 1.4
"""
fr = self.file(remotepath, 'rb')
+ file_size = self.stat(remotepath).st_size
fr.prefetch()
fl = file(localpath, 'wb')
size = 0
@@ -580,6 +590,8 @@ class SFTPClient (BaseSFTP):
break
fl.write(data)
size += len(data)
+ if callback is not None:
+ callback(size, file_size)
fl.close()
fr.close()
s = os.stat(localpath)
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index ab5b8180..edc05990 100755
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -560,19 +560,25 @@ class SFTPTest (unittest.TestCase):
f = open(localname, 'wb')
f.write(text)
f.close()
- sftp.put(localname, FOLDER + '/bunny.txt')
+ saved_progress = []
+ def progress_callback(x, y):
+ saved_progress.append((x, y))
+ sftp.put(localname, FOLDER + '/bunny.txt', progress_callback)
f = sftp.open(FOLDER + '/bunny.txt', 'r')
self.assertEquals(text, f.read(128))
f.close()
+ self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname)
localname = os.tempnam()
- sftp.get(FOLDER + '/bunny.txt', localname)
+ saved_progress = []
+ sftp.get(FOLDER + '/bunny.txt', localname, progress_callback)
f = open(localname, 'rb')
self.assertEquals(text, f.read(128))
f.close()
+ self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt')