summaryrefslogtreecommitdiffhomepage
path: root/paramiko/transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'paramiko/transport.py')
-rw-r--r--paramiko/transport.py136
1 files changed, 78 insertions, 58 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 06f75801..7931a11f 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -138,6 +138,9 @@ class BaseTransport (threading.Thread):
self.sock = sock
# Python < 2.3 doesn't have the settimeout method - RogerB
try:
+ # we set the timeout so we can check self.active periodically to
+ # see if we should bail. socket.timeout exception is never
+ # propagated.
self.sock.settimeout(0.1)
except AttributeError:
pass
@@ -155,7 +158,7 @@ class BaseTransport (threading.Thread):
self.expected_packet = 0
self.active = False
self.initial_kex_done = False
- self.write_lock = threading.Lock() # lock around outbound writes (packet computation)
+ self.write_lock = threading.RLock() # lock around outbound writes (packet computation)
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.channels = { } # (id -> Channel)
self.channel_events = { } # (id -> Event)
@@ -165,10 +168,13 @@ class BaseTransport (threading.Thread):
self.max_packet_size = 32768
self.ultra_debug = False
self.saved_exception = None
+ self.clear_to_send = threading.Event()
# used for noticing when to re-key:
self.received_bytes = 0
self.received_packets = 0
self.received_packets_overflow = 0
+ self.sent_bytes = 0
+ self.sent_packets = 0
# user-defined event callbacks:
self.completion_event = None
# keepalives:
@@ -176,6 +182,7 @@ class BaseTransport (threading.Thread):
self.keepalive_last = time.time()
# server mode:
self.server_mode = 0
+ self.server_object = None
self.server_key_dict = { }
self.server_accepts = [ ]
self.server_accept_cv = threading.Condition(self.lock)
@@ -223,7 +230,7 @@ class BaseTransport (threading.Thread):
self.completion_event = event
self.start()
- def start_server(self, event=None):
+ def start_server(self, event=None, server=None):
"""
Negotiate a new SSH2 session as a server. This is the first step after
creating a new L{Transport} and setting up your server host key(s). A
@@ -235,15 +242,16 @@ class BaseTransport (threading.Thread):
After a successful negotiation, the client will need to authenticate.
Override the methods
- L{get_allowed_auths <Transport.get_allowed_auths>},
- L{check_auth_none <Transport.check_auth_none>},
- L{check_auth_password <Transport.check_auth_password>}, and
- L{check_auth_publickey <Transport.check_auth_publickey>} to control the
- authentication process.
+ L{get_allowed_auths <ServerInterface.get_allowed_auths>},
+ L{check_auth_none <ServerInterface.check_auth_none>},
+ L{check_auth_password <ServerInterface.check_auth_password>}, and
+ L{check_auth_publickey <ServerInterface.check_auth_publickey>} in the
+ given C{server} object to control the authentication process.
After a successful authentication, the client should request to open
- a channel. Override L{check_channel_request} to allow channels to
- be opened.
+ a channel. Override
+ L{check_channel_request <ServerInterface.check_channel_request>} in the
+ given C{server} object to allow channels to be opened.
@note: After calling this method (or L{start_client} or L{connect}),
you should no longer directly read from or write to the original socket
@@ -251,8 +259,14 @@ class BaseTransport (threading.Thread):
@param event: an event to trigger when negotiation is complete.
@type event: threading.Event
+ @param server: an object used to perform authentication and create
+ L{Channel}s.
+ @type server: L{server.ServerInterface}
"""
+ if server is None:
+ server = ServerInterface()
self.server_mode = 1
+ self.server_object = server
self.completion_event = event
self.start()
@@ -422,7 +436,7 @@ class BaseTransport (threading.Thread):
self.channel_events[chanid] = event = threading.Event()
chan._set_transport(self)
chan._set_window(self.window_size, self.max_packet_size)
- self._send_message(m)
+ self._send_user_message(m)
finally:
self.lock.release()
while 1:
@@ -457,7 +471,7 @@ class BaseTransport (threading.Thread):
if bytes is None:
bytes = (ord(randpool.get_bytes(1)) % 32) + 10
m.add_bytes(randpool.get_bytes(bytes))
- self._send_message(m)
+ self._send_user_message(m)
def renegotiate_keys(self):
"""
@@ -528,7 +542,7 @@ class BaseTransport (threading.Thread):
for item in data:
m.add(item)
self._log(DEBUG, 'Sending global request "%s"' % kind)
- self._send_message(m)
+ self._send_user_message(m)
if not wait:
return True
while True:
@@ -539,36 +553,6 @@ class BaseTransport (threading.Thread):
break
return self.global_response
- def check_channel_request(self, kind, chanid):
- """
- I{(subclass override)}
- Determine if a channel request of a given type will be granted, and
- return a suitable L{Channel} object. This method is called in server
- mode when the client requests a channel, after authentication is
- complete.
-
- In server mode, you will generally want to subclass L{Channel} to
- override some of the methods for handling client requests (such as
- connecting to a subsystem or opening a shell) to determine what you
- want to allow or disallow. For this reason, L{check_channel_request}
- must return a new object of that type. The C{chanid} parameter is
- passed so that you can use it in L{Channel}'s constructor.
-
- The default implementation always returns C{None}, rejecting any
- channel requests. A useful server must override this method.
-
- @param kind: the kind of channel the client would like to open
- (usually C{"session"}).
- @type kind: string
- @param chanid: ID of the channel, required to create a new L{Channel}
- object.
- @type chanid: int
- @return: a new L{Channel} object (or subclass thereof), or C{None} to
- refuse the request.
- @rtype: L{Channel}
- """
- return None
-
def check_global_request(self, kind, msg):
"""
I{(subclass override)}
@@ -771,7 +755,11 @@ class BaseTransport (threading.Thread):
def _write_all(self, out):
self.keepalive_last = time.time()
while len(out) > 0:
- n = self.sock.send(out)
+ try:
+ n = self.sock.send(out)
+ except:
+ # could be: (32, 'Broken pipe')
+ n = -1
if n < 0:
raise EOFError()
if n == len(out):
@@ -810,9 +798,33 @@ class BaseTransport (threading.Thread):
self.sequence_number_out += 1L
self.sequence_number_out %= 0x100000000L
self._write_all(out)
+
+ self.sent_bytes += len(out)
+ self.sent_packets += 1
+ if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \
+ and (self.local_kex_init is None):
+ # only ask once for rekeying
+ self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
+ (self.sent_packets, self.sent_bytes))
+ self.received_packets_overflow = 0
+ self._send_kex_init()
finally:
self.write_lock.release()
+ def _send_user_message(self, data):
+ """
+ send a message, but block if we're in key negotiation. this is used
+ for user-initiated requests.
+ """
+ while 1:
+ self.clear_to_send.wait(0.1)
+ if not self.active:
+ self._log(DEBUG, 'Dropping user packet because connection is dead.')
+ return
+ if self.clear_to_send.isSet():
+ break
+ self._send_message(data)
+
def _read_message(self):
"only one thread will ever be in this function"
header = self._read_all(self.block_size_in)
@@ -850,19 +862,19 @@ class BaseTransport (threading.Thread):
# check for rekey
self.received_bytes += packet_size + self.remote_mac_len + 4
self.received_packets += 1
- if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES):
+ if self.local_kex_init is not None:
+ # we've asked to rekey -- give them 20 packets to comply before
+ # dropping the connection
+ self.received_packets_overflow += 1
+ if self.received_packets_overflow >= 20:
+ raise SSHException('Remote transport is ignoring rekey requests')
+ elif (self.received_packets >= self.REKEY_PACKETS) or \
+ (self.received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying
- if self.local_kex_init is None:
- self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets,
- self.received_bytes))
- self.received_packets_overflow = 0
- self._send_kex_init()
- else:
- # we've asked to rekey already -- give them 20 packets to
- # comply, then just drop the connection
- self.received_packets_overflow += 1
- if self.received_packets_overflow >= 20:
- raise SSHException('Remote transport is ignoring rekey requests')
+ self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
+ (self.received_packets, self.received_bytes))
+ self.received_packets_overflow = 0
+ self._send_kex_init()
cmd = ord(payload[0])
self._log(DEBUG, 'Read packet $%x, length %d' % (cmd, len(payload)))
@@ -964,8 +976,8 @@ class BaseTransport (threading.Thread):
self._log(ERROR, util.tb_strings())
self.saved_exception = e
except EOFError, e:
- self._log(DEBUG, 'EOF')
- self._log(DEBUG, util.tb_strings())
+ self._log(DEBUG, 'EOF in transport thread')
+ #self._log(DEBUG, util.tb_strings())
self.saved_exception = e
except Exception, e:
self._log(ERROR, 'Unknown exception: ' + str(e))
@@ -990,6 +1002,7 @@ class BaseTransport (threading.Thread):
def _negotiate_keys(self, m):
# throws SSHException on anything unusual
+ self.clear_to_send.clear()
if self.local_kex_init == None:
# remote side wants to renegotiate
self._send_kex_init()
@@ -1033,6 +1046,7 @@ class BaseTransport (threading.Thread):
announce to the other side that we'd like to negotiate keys, and what
kind of key negotiation we support.
"""
+ self.clear_to_send.clear()
if self.server_mode:
if (self._modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self.preferred_kex):
# can't do group-exchange if we don't have a pack of potential primes
@@ -1066,6 +1080,8 @@ class BaseTransport (threading.Thread):
self.received_bytes = 0
self.received_packets = 0
self.received_packets_overflow = 0
+ self.sent_bytes = 0
+ self.sent_packets = 0
cookie = m.get_bytes(16)
kex_algo_list = m.get_list()
@@ -1211,6 +1227,8 @@ class BaseTransport (threading.Thread):
# send an event?
if self.completion_event != None:
self.completion_event.set()
+ # it's now okay to send data again (if this was a re-key)
+ self.clear_to_send.set()
return
def _parse_disconnect(self, m):
@@ -1307,7 +1325,7 @@ class BaseTransport (threading.Thread):
self.channel_counter += 1
finally:
self.lock.release()
- chan = self.check_channel_request(kind, my_chanid)
+ chan = self.server_object.check_channel_request(kind, my_chanid)
if (chan is None) or (type(chan) is int):
self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
reject = True
@@ -1373,3 +1391,5 @@ class BaseTransport (threading.Thread):
MSG_CHANNEL_EOF: Channel._handle_eof,
MSG_CHANNEL_CLOSE: Channel._handle_close,
}
+
+from server import ServerInterface