summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/channel.py5
-rw-r--r--paramiko/client.py6
-rw-r--r--paramiko/file.py4
-rw-r--r--paramiko/proxy.py5
-rw-r--r--paramiko/sftp_client.py5
-rw-r--r--paramiko/sftp_file.py6
-rw-r--r--paramiko/sftp_handle.py5
-rw-r--r--paramiko/transport.py6
-rw-r--r--paramiko/util.py9
-rw-r--r--tests/test_client.py24
-rwxr-xr-xtests/test_sftp.py20
-rw-r--r--tests/test_transport.py18
12 files changed, 98 insertions, 15 deletions
diff --git a/paramiko/channel.py b/paramiko/channel.py
index 78a14795..9de278cb 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -38,6 +38,7 @@ from paramiko.ssh_exception import SSHException
from paramiko.file import BufferedFile
from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe
+from paramiko.util import ClosingContextManager
def open_only(func):
@@ -60,7 +61,7 @@ def open_only(func):
return _check
-class Channel (object):
+class Channel (ClosingContextManager):
"""
A secure tunnel across an SSH `.Transport`. A Channel is meant to behave
like a socket, and has an API that should be indistinguishable from the
@@ -73,6 +74,8 @@ class Channel (object):
flow-controlled independently.) Similarly, if the server isn't reading
data you send, calls to `send` may block, unless you set a timeout. This
is exactly like a normal network socket, so it shouldn't be too surprising.
+
+ Instances of this class may be used as context managers.
"""
def __init__(self, chanid):
diff --git a/paramiko/client.py b/paramiko/client.py
index 265389de..05686d97 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -37,10 +37,10 @@ from paramiko.resource import ResourceManager
from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import SSHException, BadHostKeyException
from paramiko.transport import Transport
-from paramiko.util import retry_on_signal
+from paramiko.util import retry_on_signal, ClosingContextManager
-class SSHClient (object):
+class SSHClient (ClosingContextManager):
"""
A high-level representation of a session with an SSH server. This class
wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most
@@ -55,6 +55,8 @@ class SSHClient (object):
checking. The default mechanism is to try to use local key files or an
SSH agent (if one is running).
+ Instances of this class may be used as context managers.
+
.. versionadded:: 1.6
"""
diff --git a/paramiko/file.py b/paramiko/file.py
index 2238f0bf..09998829 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -19,8 +19,10 @@ from paramiko.common import linefeed_byte_value, crlf, cr_byte, linefeed_byte, \
cr_byte_value
from paramiko.py3compat import BytesIO, PY2, u, b, bytes_types
+from paramiko.util import ClosingContextManager
-class BufferedFile (object):
+
+class BufferedFile (ClosingContextManager):
"""
Reusable base class to implement Python-style file buffering around a
simpler stream.
diff --git a/paramiko/proxy.py b/paramiko/proxy.py
index 8959b244..0664ac6e 100644
--- a/paramiko/proxy.py
+++ b/paramiko/proxy.py
@@ -26,9 +26,10 @@ from select import select
import socket
from paramiko.ssh_exception import ProxyCommandFailure
+from paramiko.util import ClosingContextManager
-class ProxyCommand(object):
+class ProxyCommand(ClosingContextManager):
"""
Wraps a subprocess running ProxyCommand-driven programs.
@@ -36,6 +37,8 @@ class ProxyCommand(object):
`.Transport` and `.Packetizer` classes. Using this class instead of a
regular socket makes it possible to talk with a Popen'd command that will
proxy traffic between the client and a server hosted in another machine.
+
+ Instances of this class may be used as context managers.
"""
def __init__(self, command_line):
"""
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 9c30426d..62127cc2 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -39,6 +39,7 @@ from paramiko.sftp import BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READ
from paramiko.sftp_attr import SFTPAttributes
from paramiko.ssh_exception import SSHException
from paramiko.sftp_file import SFTPFile
+from paramiko.util import ClosingContextManager
def _to_unicode(s):
@@ -58,12 +59,14 @@ def _to_unicode(s):
b_slash = b'/'
-class SFTPClient(BaseSFTP):
+class SFTPClient(BaseSFTP, ClosingContextManager):
"""
SFTP client object.
Used to open an SFTP session across an open SSH `.Transport` and perform
remote file operations.
+
+ Instances of this class may be used as context managers.
"""
def __init__(self, sock):
"""
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index 03d67b33..d0a37da3 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -488,9 +488,3 @@ class SFTPFile (BufferedFile):
x = self._saved_exception
self._saved_exception = None
raise x
-
- def __enter__(self):
- return self
-
- def __exit__(self, type, value, traceback):
- self.close()
diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py
index 92dd9cfe..edceb5ad 100644
--- a/paramiko/sftp_handle.py
+++ b/paramiko/sftp_handle.py
@@ -22,9 +22,10 @@ Abstraction of an SFTP file handle (for server mode).
import os
from paramiko.sftp import SFTP_OP_UNSUPPORTED, SFTP_OK
+from paramiko.util import ClosingContextManager
-class SFTPHandle (object):
+class SFTPHandle (ClosingContextManager):
"""
Abstract object representing a handle to an open file (or folder) in an
SFTP server implementation. Each handle has a string representation used
@@ -32,6 +33,8 @@ class SFTPHandle (object):
Server implementations can (and should) subclass SFTPHandle to implement
features of a file handle, like `stat` or `chattr`.
+
+ Instances of this class may be used as context managers.
"""
def __init__(self, flags=0):
"""
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 0f747343..b465aa00 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -61,7 +61,7 @@ from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import (SSHException, BadAuthenticationType,
ChannelException, ProxyCommandFailure)
-from paramiko.util import retry_on_signal, clamp_value
+from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
from Crypto.Cipher import Blowfish, AES, DES3, ARC4
try:
@@ -81,13 +81,15 @@ import atexit
atexit.register(_join_lingering_threads)
-class Transport (threading.Thread):
+class Transport (threading.Thread, ClosingContextManager):
"""
An SSH Transport attaches to a stream (usually a socket), negotiates an
encrypted session, authenticates, and then creates stream tunnels, called
`channels <.Channel>`, across the session. Multiple channels can be
multiplexed across a single session (and often are, in the case of port
forwardings).
+
+ Instances of this class may be used as context managers.
"""
_PROTO_ID = '2.0'
_CLIENT_ID = 'paramiko_%s' % paramiko.__version__
diff --git a/paramiko/util.py b/paramiko/util.py
index d029f52e..88ca2bc4 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -321,5 +321,14 @@ def constant_time_bytes_eq(a, b):
res |= byte_ord(a[i]) ^ byte_ord(b[i])
return res == 0
+
+class ClosingContextManager(object):
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+
def clamp_value(minimum, val, maximum):
return max(minimum, min(val, maximum))
diff --git a/tests/test_client.py b/tests/test_client.py
index 6fda7f5e..28d1cb46 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -20,6 +20,8 @@
Some unit tests for SSHClient.
"""
+from __future__ import with_statement
+
import socket
from tempfile import mkstemp
import threading
@@ -299,6 +301,28 @@ class SSHClientTest (unittest.TestCase):
self.assertTrue(p() is None)
+ def test_client_can_be_used_as_context_manager(self):
+ """
+ verify that an SSHClient can be used a context manager
+ """
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
+
+ with paramiko.SSHClient() as tc:
+ self.tc = tc
+ self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ self.assertEquals(0, len(self.tc.get_host_keys()))
+ self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
+
+ self.event.wait(1.0)
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+
+ self.assertTrue(self.tc._transport is not None)
+
+ self.assertTrue(self.tc._transport is None)
+
def test_7_banner_timeout(self):
"""
verify that the SSHClient has a configurable banner timeout.
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index 1ae9781d..58013cfd 100755
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -195,6 +195,18 @@ class SFTPTest (unittest.TestCase):
pass
sftp = paramiko.SFTP.from_transport(tc)
+ def test_2_sftp_can_be_used_as_context_manager(self):
+ """
+ verify that the sftp session is closed when exiting the context manager
+ """
+ global sftp
+ with sftp:
+ pass
+ try:
+ self._assert_opening_file_raises_error(sftp)
+ finally:
+ sftp = paramiko.SFTP.from_transport(tc)
+
def test_3_write(self):
"""
verify that a file can be created and written, and the size is correct.
@@ -796,6 +808,14 @@ class SFTPTest (unittest.TestCase):
sftp.remove('%s/nonutf8data' % FOLDER)
+ def _assert_opening_file_raises_error(self, sftp):
+ try:
+ sftp.open(FOLDER + '/test2', 'w')
+ self.fail('expected exception')
+ except EOFError:
+ pass
+
+
if __name__ == '__main__':
SFTPTest.init_loopback()
# logging is required by test_N_file_with_percent
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 344d64b8..50b1d86b 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -20,6 +20,8 @@
Some unit tests for the ssh2 protocol in Transport.
"""
+from __future__ import with_statement
+
from binascii import hexlify
import select
import socket
@@ -281,6 +283,22 @@ class TransportTest(unittest.TestCase):
self.assertEqual('Hello there.\n', f.readline())
self.assertEqual('This is on stderr.\n', f.readline())
self.assertEqual('', f.readline())
+
+ def test_6a_channel_can_be_used_as_context_manager(self):
+ """
+ verify that exec_command() does something reasonable.
+ """
+ self.setup_test_server()
+
+ with self.tc.open_session() as chan:
+ with self.ts.accept(1.0) as schan:
+ chan.exec_command('yes')
+ schan.send('Hello there.\n')
+ schan.close()
+
+ f = chan.makefile()
+ self.assertEqual('Hello there.\n', f.readline())
+ self.assertEqual('', f.readline())
def test_7_invoke_shell(self):
"""