diff options
Diffstat (limited to 'paramiko/transport.py')
-rw-r--r-- | paramiko/transport.py | 92 |
1 files changed, 70 insertions, 22 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py index 762781b2..0a1daf38 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -274,9 +274,10 @@ class Transport (threading.Thread): self.channels = weakref.WeakValueDictionary() # (id -> Channel) self.channel_events = { } # (id -> Event) self.channels_seen = { } # (id -> True) - self.channel_counter = 1 + self._channel_counter = 1 self.window_size = 65536 self.max_packet_size = 34816 + self._x11_handler = None self.saved_exception = None self.clear_to_send = threading.Event() @@ -592,6 +593,22 @@ class Transport (threading.Thread): """ return self.open_channel('session') + def open_x11_channel(self, src_addr=None): + """ + Request a new channel to the client, of type C{"x11"}. This + is just an alias for C{open_channel('x11', src_addr=src_addr)}. + + @param src_addr: the source address of the x11 server (port is the + x11 port, ie. 6010) + @type src_addr: (str, int) + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('x11', src_addr=src_addr) + def open_channel(self, kind, dest_addr=None, src_addr=None): """ Request a new channel to the server. L{Channel}s are socket-like @@ -621,11 +638,7 @@ class Transport (threading.Thread): return None self.lock.acquire() try: - chanid = self.channel_counter - while chanid in self.channels: - self.channel_counter = (self.channel_counter + 1) & 0xffffff - chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + chanid = self._next_channel() m = Message() m.add_byte(chr(MSG_CHANNEL_OPEN)) m.add_string(kind) @@ -637,6 +650,9 @@ class Transport (threading.Thread): m.add_int(dest_addr[1]) m.add_string(src_addr[0]) m.add_int(src_addr[1]) + elif kind == 'x11': + m.add_string(src_addr[0]) + m.add_int(src_addr[1]) self.channels[chanid] = chan = Channel(chanid) self.channel_events[chanid] = event = threading.Event() self.channels_seen[chanid] = True @@ -1230,17 +1246,26 @@ class Transport (threading.Thread): ### internals... - def _log(self, level, msg): + def _log(self, level, msg, *args): if issubclass(type(msg), list): for m in msg: self.logger.log(level, m) else: - self.logger.log(level, msg) + self.logger.log(level, msg, *args) def _get_modulus_pack(self): "used by KexGex to find primes for group exchange" return self._modulus_pack + def _next_channel(self): + "you are holding the lock" + chanid = self._channel_counter + while chanid in self.channels: + self._channel_counter = (self._channel_counter + 1) & 0xffffff + chanid = self._channel_counter + self._channel_counter = (self._channel_counter + 1) & 0xffffff + return chanid + def _unlink_channel(self, chanid): "used by a Channel to remove itself from the active channel list" try: @@ -1314,6 +1339,25 @@ class Transport (threading.Thread): raise SSHException('Unknown client cipher ' + name) return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) + def _set_x11_handler(self, handler): + # only called if a channel has turned on x11 forwarding + if handler is None: + # by default, use the same mechanism as accept() + self._x11_handler = self._default_x11_handler + else: + self._x11_hanlder = handler + + def _default_x11_handler(self, channel, (src_addr, src_port)): + self._queue_incoming_channel(channel) + + def _queue_incoming_channel(self, channel): + self.lock.acquire() + try: + self.server_accepts.append(channel) + self.server_accept_cv.notify() + finally: + self.lock.release() + def run(self): # (use the exposed "run" method, because if we specify a thread target # of a private method, threading.Thread will keep a reference to it @@ -1710,7 +1754,7 @@ class Transport (threading.Thread): self._log(DEBUG, 'Received global request "%s"' % kind) want_reply = m.get_boolean() if not self.server_mode: - self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) + self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) ok = False else: ok = self.server_object.check_global_request(kind, m) @@ -1784,18 +1828,23 @@ class Transport (threading.Thread): initial_window_size = m.get_int() max_packet_size = m.get_int() reject = False - if not self.server_mode: + if (kind == 'x11') and (self._x11_handler is not None): + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif not self.server_mode: self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: self.lock.acquire() try: - my_chanid = self.channel_counter - while my_chanid in self.channels: - self.channel_counter = (self.channel_counter + 1) & 0xffffff - my_chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + my_chanid = self._next_channel() finally: self.lock.release() reason = self.server_object.check_channel_request(kind, my_chanid) @@ -1811,6 +1860,7 @@ class Transport (threading.Thread): msg.add_string('en') self._send_message(msg) return + chan = Channel(my_chanid) try: self.lock.acquire() @@ -1828,13 +1878,11 @@ class Transport (threading.Thread): m.add_int(self.window_size) m.add_int(self.max_packet_size) self._send_message(m) - self._log(INFO, 'Secsh channel %d opened.' % my_chanid) - try: - self.lock.acquire() - self.server_accepts.append(chan) - self.server_accept_cv.notify() - finally: - self.lock.release() + self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind) + if kind == 'x11': + self._x11_handler(chan, (origin_addr, origin_port)) + else: + self._queue_incoming_channel(chan) def _parse_debug(self, m): always_display = m.get_boolean() |