diff options
Diffstat (limited to 'paramiko/transport.py')
-rw-r--r-- | paramiko/transport.py | 136 |
1 files changed, 78 insertions, 58 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py index 06f75801..7931a11f 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -138,6 +138,9 @@ class BaseTransport (threading.Thread): self.sock = sock # Python < 2.3 doesn't have the settimeout method - RogerB try: + # we set the timeout so we can check self.active periodically to + # see if we should bail. socket.timeout exception is never + # propagated. self.sock.settimeout(0.1) except AttributeError: pass @@ -155,7 +158,7 @@ class BaseTransport (threading.Thread): self.expected_packet = 0 self.active = False self.initial_kex_done = False - self.write_lock = threading.Lock() # lock around outbound writes (packet computation) + self.write_lock = threading.RLock() # lock around outbound writes (packet computation) self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.channels = { } # (id -> Channel) self.channel_events = { } # (id -> Event) @@ -165,10 +168,13 @@ class BaseTransport (threading.Thread): self.max_packet_size = 32768 self.ultra_debug = False self.saved_exception = None + self.clear_to_send = threading.Event() # used for noticing when to re-key: self.received_bytes = 0 self.received_packets = 0 self.received_packets_overflow = 0 + self.sent_bytes = 0 + self.sent_packets = 0 # user-defined event callbacks: self.completion_event = None # keepalives: @@ -176,6 +182,7 @@ class BaseTransport (threading.Thread): self.keepalive_last = time.time() # server mode: self.server_mode = 0 + self.server_object = None self.server_key_dict = { } self.server_accepts = [ ] self.server_accept_cv = threading.Condition(self.lock) @@ -223,7 +230,7 @@ class BaseTransport (threading.Thread): self.completion_event = event self.start() - def start_server(self, event=None): + def start_server(self, event=None, server=None): """ Negotiate a new SSH2 session as a server. This is the first step after creating a new L{Transport} and setting up your server host key(s). A @@ -235,15 +242,16 @@ class BaseTransport (threading.Thread): After a successful negotiation, the client will need to authenticate. Override the methods - L{get_allowed_auths <Transport.get_allowed_auths>}, - L{check_auth_none <Transport.check_auth_none>}, - L{check_auth_password <Transport.check_auth_password>}, and - L{check_auth_publickey <Transport.check_auth_publickey>} to control the - authentication process. + L{get_allowed_auths <ServerInterface.get_allowed_auths>}, + L{check_auth_none <ServerInterface.check_auth_none>}, + L{check_auth_password <ServerInterface.check_auth_password>}, and + L{check_auth_publickey <ServerInterface.check_auth_publickey>} in the + given C{server} object to control the authentication process. After a successful authentication, the client should request to open - a channel. Override L{check_channel_request} to allow channels to - be opened. + a channel. Override + L{check_channel_request <ServerInterface.check_channel_request>} in the + given C{server} object to allow channels to be opened. @note: After calling this method (or L{start_client} or L{connect}), you should no longer directly read from or write to the original socket @@ -251,8 +259,14 @@ class BaseTransport (threading.Thread): @param event: an event to trigger when negotiation is complete. @type event: threading.Event + @param server: an object used to perform authentication and create + L{Channel}s. + @type server: L{server.ServerInterface} """ + if server is None: + server = ServerInterface() self.server_mode = 1 + self.server_object = server self.completion_event = event self.start() @@ -422,7 +436,7 @@ class BaseTransport (threading.Thread): self.channel_events[chanid] = event = threading.Event() chan._set_transport(self) chan._set_window(self.window_size, self.max_packet_size) - self._send_message(m) + self._send_user_message(m) finally: self.lock.release() while 1: @@ -457,7 +471,7 @@ class BaseTransport (threading.Thread): if bytes is None: bytes = (ord(randpool.get_bytes(1)) % 32) + 10 m.add_bytes(randpool.get_bytes(bytes)) - self._send_message(m) + self._send_user_message(m) def renegotiate_keys(self): """ @@ -528,7 +542,7 @@ class BaseTransport (threading.Thread): for item in data: m.add(item) self._log(DEBUG, 'Sending global request "%s"' % kind) - self._send_message(m) + self._send_user_message(m) if not wait: return True while True: @@ -539,36 +553,6 @@ class BaseTransport (threading.Thread): break return self.global_response - def check_channel_request(self, kind, chanid): - """ - I{(subclass override)} - Determine if a channel request of a given type will be granted, and - return a suitable L{Channel} object. This method is called in server - mode when the client requests a channel, after authentication is - complete. - - In server mode, you will generally want to subclass L{Channel} to - override some of the methods for handling client requests (such as - connecting to a subsystem or opening a shell) to determine what you - want to allow or disallow. For this reason, L{check_channel_request} - must return a new object of that type. The C{chanid} parameter is - passed so that you can use it in L{Channel}'s constructor. - - The default implementation always returns C{None}, rejecting any - channel requests. A useful server must override this method. - - @param kind: the kind of channel the client would like to open - (usually C{"session"}). - @type kind: string - @param chanid: ID of the channel, required to create a new L{Channel} - object. - @type chanid: int - @return: a new L{Channel} object (or subclass thereof), or C{None} to - refuse the request. - @rtype: L{Channel} - """ - return None - def check_global_request(self, kind, msg): """ I{(subclass override)} @@ -771,7 +755,11 @@ class BaseTransport (threading.Thread): def _write_all(self, out): self.keepalive_last = time.time() while len(out) > 0: - n = self.sock.send(out) + try: + n = self.sock.send(out) + except: + # could be: (32, 'Broken pipe') + n = -1 if n < 0: raise EOFError() if n == len(out): @@ -810,9 +798,33 @@ class BaseTransport (threading.Thread): self.sequence_number_out += 1L self.sequence_number_out %= 0x100000000L self._write_all(out) + + self.sent_bytes += len(out) + self.sent_packets += 1 + if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \ + and (self.local_kex_init is None): + # only ask once for rekeying + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' % + (self.sent_packets, self.sent_bytes)) + self.received_packets_overflow = 0 + self._send_kex_init() finally: self.write_lock.release() + def _send_user_message(self, data): + """ + send a message, but block if we're in key negotiation. this is used + for user-initiated requests. + """ + while 1: + self.clear_to_send.wait(0.1) + if not self.active: + self._log(DEBUG, 'Dropping user packet because connection is dead.') + return + if self.clear_to_send.isSet(): + break + self._send_message(data) + def _read_message(self): "only one thread will ever be in this function" header = self._read_all(self.block_size_in) @@ -850,19 +862,19 @@ class BaseTransport (threading.Thread): # check for rekey self.received_bytes += packet_size + self.remote_mac_len + 4 self.received_packets += 1 - if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES): + if self.local_kex_init is not None: + # we've asked to rekey -- give them 20 packets to comply before + # dropping the connection + self.received_packets_overflow += 1 + if self.received_packets_overflow >= 20: + raise SSHException('Remote transport is ignoring rekey requests') + elif (self.received_packets >= self.REKEY_PACKETS) or \ + (self.received_bytes >= self.REKEY_BYTES): # only ask once for rekeying - if self.local_kex_init is None: - self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets, - self.received_bytes)) - self.received_packets_overflow = 0 - self._send_kex_init() - else: - # we've asked to rekey already -- give them 20 packets to - # comply, then just drop the connection - self.received_packets_overflow += 1 - if self.received_packets_overflow >= 20: - raise SSHException('Remote transport is ignoring rekey requests') + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' % + (self.received_packets, self.received_bytes)) + self.received_packets_overflow = 0 + self._send_kex_init() cmd = ord(payload[0]) self._log(DEBUG, 'Read packet $%x, length %d' % (cmd, len(payload))) @@ -964,8 +976,8 @@ class BaseTransport (threading.Thread): self._log(ERROR, util.tb_strings()) self.saved_exception = e except EOFError, e: - self._log(DEBUG, 'EOF') - self._log(DEBUG, util.tb_strings()) + self._log(DEBUG, 'EOF in transport thread') + #self._log(DEBUG, util.tb_strings()) self.saved_exception = e except Exception, e: self._log(ERROR, 'Unknown exception: ' + str(e)) @@ -990,6 +1002,7 @@ class BaseTransport (threading.Thread): def _negotiate_keys(self, m): # throws SSHException on anything unusual + self.clear_to_send.clear() if self.local_kex_init == None: # remote side wants to renegotiate self._send_kex_init() @@ -1033,6 +1046,7 @@ class BaseTransport (threading.Thread): announce to the other side that we'd like to negotiate keys, and what kind of key negotiation we support. """ + self.clear_to_send.clear() if self.server_mode: if (self._modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self.preferred_kex): # can't do group-exchange if we don't have a pack of potential primes @@ -1066,6 +1080,8 @@ class BaseTransport (threading.Thread): self.received_bytes = 0 self.received_packets = 0 self.received_packets_overflow = 0 + self.sent_bytes = 0 + self.sent_packets = 0 cookie = m.get_bytes(16) kex_algo_list = m.get_list() @@ -1211,6 +1227,8 @@ class BaseTransport (threading.Thread): # send an event? if self.completion_event != None: self.completion_event.set() + # it's now okay to send data again (if this was a re-key) + self.clear_to_send.set() return def _parse_disconnect(self, m): @@ -1307,7 +1325,7 @@ class BaseTransport (threading.Thread): self.channel_counter += 1 finally: self.lock.release() - chan = self.check_channel_request(kind, my_chanid) + chan = self.server_object.check_channel_request(kind, my_chanid) if (chan is None) or (type(chan) is int): self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) reject = True @@ -1373,3 +1391,5 @@ class BaseTransport (threading.Thread): MSG_CHANNEL_EOF: Channel._handle_eof, MSG_CHANNEL_CLOSE: Channel._handle_close, } + +from server import ServerInterface |