diff options
author | Michael Williamson <mike@zwobble.org> | 2013-09-29 18:10:12 +0100 |
---|---|---|
committer | Michael Williamson <mike@zwobble.org> | 2014-09-07 18:51:37 +0100 |
commit | 0063e64046c732e8c50fc5f54234942feaa313d9 (patch) | |
tree | 33c25c448fff4fb2df0c8de1ff14523b6277d991 | |
parent | 4c3b973c78b209170b2bbd7b50a0de1638a12e4a (diff) |
Convert SSHClient into a context manager
-rw-r--r-- | paramiko/client.py | 6 | ||||
-rw-r--r-- | tests/test_client.py | 24 |
2 files changed, 28 insertions, 2 deletions
diff --git a/paramiko/client.py b/paramiko/client.py index 4326abbd..d718135e 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -37,10 +37,10 @@ from paramiko.resource import ResourceManager from paramiko.rsakey import RSAKey from paramiko.ssh_exception import SSHException, BadHostKeyException from paramiko.transport import Transport -from paramiko.util import retry_on_signal +from paramiko.util import retry_on_signal, ClosingContextManager -class SSHClient (object): +class SSHClient (ClosingContextManager): """ A high-level representation of a session with an SSH server. This class wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most @@ -55,6 +55,8 @@ class SSHClient (object): checking. The default mechanism is to try to use local key files or an SSH agent (if one is running). + Instances of this class may be used as context managers. + .. versionadded:: 1.6 """ diff --git a/tests/test_client.py b/tests/test_client.py index 7c094628..b3635272 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -20,6 +20,8 @@ Some unit tests for SSHClient. """ +from __future__ import with_statement + import socket from tempfile import mkstemp import threading @@ -293,3 +295,25 @@ class SSHClientTest (unittest.TestCase): gc.collect() self.assertTrue(p() is None) + + def test_6_client_can_be_used_as_context_manager(self): + """ + verify that an SSHClient can be used a context manager + """ + threading.Thread(target=self._run).start() + host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) + public_host_key = paramiko.RSAKey(data=host_key.asbytes()) + + with paramiko.SSHClient() as tc: + self.tc = tc + self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.assertEquals(0, len(self.tc.get_host_keys())) + self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + + self.event.wait(1.0) + self.assertTrue(self.event.isSet()) + self.assertTrue(self.ts.is_active()) + + self.assertTrue(self.tc._transport is not None) + + self.assertTrue(self.tc._transport is None) |