summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/client.py6
-rw-r--r--tests/test_client.py24
2 files changed, 28 insertions, 2 deletions
diff --git a/paramiko/client.py b/paramiko/client.py
index 4326abbd..d718135e 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -37,10 +37,10 @@ from paramiko.resource import ResourceManager
from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import SSHException, BadHostKeyException
from paramiko.transport import Transport
-from paramiko.util import retry_on_signal
+from paramiko.util import retry_on_signal, ClosingContextManager
-class SSHClient (object):
+class SSHClient (ClosingContextManager):
"""
A high-level representation of a session with an SSH server. This class
wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most
@@ -55,6 +55,8 @@ class SSHClient (object):
checking. The default mechanism is to try to use local key files or an
SSH agent (if one is running).
+ Instances of this class may be used as context managers.
+
.. versionadded:: 1.6
"""
diff --git a/tests/test_client.py b/tests/test_client.py
index 7c094628..b3635272 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -20,6 +20,8 @@
Some unit tests for SSHClient.
"""
+from __future__ import with_statement
+
import socket
from tempfile import mkstemp
import threading
@@ -293,3 +295,25 @@ class SSHClientTest (unittest.TestCase):
gc.collect()
self.assertTrue(p() is None)
+
+ def test_6_client_can_be_used_as_context_manager(self):
+ """
+ verify that an SSHClient can be used a context manager
+ """
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
+
+ with paramiko.SSHClient() as tc:
+ self.tc = tc
+ self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ self.assertEquals(0, len(self.tc.get_host_keys()))
+ self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
+
+ self.event.wait(1.0)
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+
+ self.assertTrue(self.tc._transport is not None)
+
+ self.assertTrue(self.tc._transport is None)