summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRobey Pointer <robey@lag.net>2008-03-23 01:21:10 -0700
committerRobey Pointer <robey@lag.net>2008-03-23 01:21:10 -0700
commit9a6ffec93fea7627a5c1d85e19c5060c1c0e0943 (patch)
tree355f40d95453ac75885a658a2b657db2dac1695c
parente5a1b4bf569599b53016374f96e53639799ed6d3 (diff)
[project @ robey@lag.net-20080323082110-o9fglwyiozn08tm9]
bug 191657: clean up usage of the channel map by making a special object to hold the weak value dict.
-rw-r--r--paramiko/transport.py92
1 files changed, 64 insertions, 28 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 4ceb4a03..af4c3076 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -140,6 +140,51 @@ class SecurityOptions (object):
"Compression algorithms")
+class ChannelMap (object):
+ def __init__(self):
+ # (id -> Channel)
+ self._map = weakref.WeakValueDictionary()
+ self._lock = threading.Lock()
+
+ def put(self, chanid, chan):
+ self._lock.acquire()
+ try:
+ self._map[chanid] = chan
+ finally:
+ self._lock.release()
+
+ def get(self, chanid):
+ self._lock.acquire()
+ try:
+ return self._map.get(chanid, None)
+ finally:
+ self._lock.release()
+
+ def delete(self, chanid):
+ self._lock.acquire()
+ try:
+ try:
+ del self._map[chanid]
+ except KeyError:
+ pass
+ finally:
+ self._lock.release()
+
+ def values(self):
+ self._lock.acquire()
+ try:
+ return self._map.values()
+ finally:
+ self._lock.release()
+
+ def __len__(self):
+ self._lock.acquire()
+ try:
+ return len(self._map)
+ finally:
+ self._lock.release()
+
+
class Transport (threading.Thread):
"""
An SSH Transport attaches to a stream (usually a socket), negotiates an
@@ -271,7 +316,7 @@ class Transport (threading.Thread):
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
# tracking open channels
- self.channels = weakref.WeakValueDictionary() # (id -> Channel)
+ self._channels = ChannelMap()
self.channel_events = { } # (id -> Event)
self.channels_seen = { } # (id -> True)
self._channel_counter = 1
@@ -313,10 +358,7 @@ class Transport (threading.Thread):
out += ' (cipher %s, %d bits)' % (self.local_cipher,
self._cipher_info[self.local_cipher]['key-size'] * 8)
if self.is_authenticated():
- if len(self.channels) == 1:
- out += ' (active; 1 open channel)'
- else:
- out += ' (active; %d open channels)' % len(self.channels)
+ out += ' (active; %d open channel(s))' % len(self._channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else:
@@ -550,7 +592,7 @@ class Transport (threading.Thread):
self.active = False
self.packetizer.close()
self.join()
- for chan in self.channels.values():
+ for chan in self._channels.values():
chan._unlink()
def get_remote_server_key(self):
@@ -667,7 +709,8 @@ class Transport (threading.Thread):
elif kind == 'x11':
m.add_string(src_addr[0])
m.add_int(src_addr[1])
- self.channels[chanid] = chan = Channel(chanid)
+ chan = Channel(chanid)
+ self._channels.put(chanid, chan)
self.channel_events[chanid] = event = threading.Event()
self.channels_seen[chanid] = True
chan._set_transport(self)
@@ -684,12 +727,9 @@ class Transport (threading.Thread):
raise e
if event.isSet():
break
- self.lock.acquire()
- try:
- if chanid in self.channels:
- return chan
- finally:
- self.lock.release()
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ return chan
e = self.get_exception()
if e is None:
e = SSHException('Unable to open channel.')
@@ -1334,7 +1374,7 @@ class Transport (threading.Thread):
def _next_channel(self):
"you are holding the lock"
chanid = self._channel_counter
- while chanid in self.channels:
+ while self._channels.get(chanid) is not None:
self._channel_counter = (self._channel_counter + 1) & 0xffffff
chanid = self._channel_counter
self._channel_counter = (self._channel_counter + 1) & 0xffffff
@@ -1342,12 +1382,7 @@ class Transport (threading.Thread):
def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list"
- try:
- self.lock.acquire()
- if chanid in self.channels:
- del self.channels[chanid]
- finally:
- self.lock.release()
+ self._channels.delete(chanid)
def _send_message(self, data):
self.packetizer.send_message(data)
@@ -1478,8 +1513,9 @@ class Transport (threading.Thread):
self._handler_table[ptype](self, m)
elif ptype in self._channel_handler_table:
chanid = m.get_int()
- if chanid in self.channels:
- self._channel_handler_table[ptype](self.channels[chanid], m)
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ self._channel_handler_table[ptype](chan, m)
elif chanid in self.channels_seen:
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
else:
@@ -1514,7 +1550,7 @@ class Transport (threading.Thread):
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
- for chan in self.channels.values():
+ for chan in self._channels.values():
chan._unlink()
if self.active:
self.active = False
@@ -1872,12 +1908,12 @@ class Transport (threading.Thread):
server_chanid = m.get_int()
server_window_size = m.get_int()
server_max_packet_size = m.get_int()
- if chanid not in self.channels:
+ chan = self._channels.get(chanid)
+ if chan is None:
self._log(WARNING, 'Success for unrequested channel! [??]')
return
self.lock.acquire()
try:
- chan = self.channels[chanid]
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self._log(INFO, 'Secsh channel %d opened.' % chanid)
if chanid in self.channel_events:
@@ -1898,7 +1934,7 @@ class Transport (threading.Thread):
try:
self.saved_exception = ChannelException(reason, reason_text)
if chanid in self.channel_events:
- del self.channels[chanid]
+ self._channels.delete(chanid)
if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
@@ -1967,9 +2003,9 @@ class Transport (threading.Thread):
return
chan = Channel(my_chanid)
+ self.lock.acquire()
try:
- self.lock.acquire()
- self.channels[my_chanid] = chan
+ self._channels.put(my_chanid, chan)
self.channels_seen[my_chanid] = True
chan._set_transport(self)
chan._set_window(self.window_size, self.max_packet_size)