summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/client.py2
-rw-r--r--paramiko/resource.py72
-rw-r--r--tests/test_client.py28
3 files changed, 101 insertions, 1 deletions
diff --git a/paramiko/client.py b/paramiko/client.py
index c13105e9..ae257152 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -28,6 +28,7 @@ from paramiko.agent import Agent
from paramiko.common import *
from paramiko.dsskey import DSSKey
from paramiko.hostkeys import HostKeys
+from paramiko.resource import ResourceManager
from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import SSHException, BadHostKeyException
from paramiko.transport import Transport
@@ -256,6 +257,7 @@ class SSHClient (object):
if self._log_channel is not None:
t.set_log_channel(self._log_channel)
t.start_client()
+ ResourceManager.register(self, t)
server_key = t.get_remote_server_key()
keytype = server_key.get_name()
diff --git a/paramiko/resource.py b/paramiko/resource.py
new file mode 100644
index 00000000..135af157
--- /dev/null
+++ b/paramiko/resource.py
@@ -0,0 +1,72 @@
+# Copyright (C) 2003-2006 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+Resource manager.
+"""
+
+import weakref
+
+
+class ResourceManager (object):
+ """
+ A registry of objects and resources that should be closed when those
+ objects are deleted.
+
+ This is meant to be a safer alternative to python's C{__del__} method,
+ which can cause reference cycles to never be collected. Objects registered
+ with the ResourceManager can be collected but still free resources when
+ they die.
+
+ Resources are registered using L{register}, and when an object is garbage
+ collected, each registered resource is closed by having its C{close()}
+ method called. Multiple resources may be registered per object, but a
+ resource will only be closed once, even if multiple objects register it.
+ (The last object to register it wins.)
+ """
+
+ def __init__(self):
+ self._table = {}
+
+ def register(self, obj, resource):
+ """
+ Register a resource to be closed with an object is collected.
+
+ When the given C{obj} is garbage-collected by the python interpreter,
+ the C{resource} will be closed by having its C{close()} method called.
+ Any exceptions are ignored.
+
+ @param obj: the object to track
+ @type obj: object
+ @param resource: the resource to close when the object is collected
+ @type resource: object
+ """
+ def callback(ref):
+ try:
+ resource.close()
+ except:
+ pass
+ del self._table[id(resource)]
+
+ # keep the weakref in a table so it sticks around long enough to get
+ # its callback called. :)
+ self._table[id(resource)] = weakref.ref(obj, callback)
+
+
+# singleton
+ResourceManager = ResourceManager()
diff --git a/tests/test_client.py b/tests/test_client.py
index 32f1a304..a53ff0e9 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -23,6 +23,8 @@ Some unit tests for SSHClient.
import socket
import threading
import unittest
+import weakref
+
import paramiko
@@ -59,7 +61,8 @@ class SSHClientTest (unittest.TestCase):
thread.start()
def tearDown(self):
- self.tc.close()
+ if hasattr(self, 'tc'):
+ self.tc.close()
self.ts.close()
self.socks.close()
self.sockl.close()
@@ -125,3 +128,26 @@ class SSHClientTest (unittest.TestCase):
self.assertEquals(True, self.ts.is_authenticated())
self.assertEquals(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
+
+ def test_3_cleanup(self):
+ """
+ verify that when an SSHClient is collected, its transport (and the
+ transport's packetizer) is closed.
+ """
+ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = paramiko.RSAKey(data=str(host_key))
+
+ self.tc = paramiko.SSHClient()
+ self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ self.assertEquals(0, len(self.tc.get_host_keys()))
+ self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
+
+ self.event.wait(1.0)
+ self.assert_(self.event.isSet())
+ self.assert_(self.ts.is_active())
+
+ p = weakref.ref(self.tc._transport.packetizer)
+ self.assert_(p() is not None)
+ del self.tc
+ self.assert_(p() is None)
+