diff options
Diffstat (limited to 'paramiko/transport.py')
-rw-r--r-- | paramiko/transport.py | 340 |
1 files changed, 90 insertions, 250 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py index da252ce7..b436847c 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -1,5 +1,3 @@ -#!/usr/bin/python - # Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> # # This file is part of paramiko. @@ -30,6 +28,7 @@ from message import Message from channel import Channel from sftp_client import SFTPClient import util +from packet import Packetizer from rsakey import RSAKey from dsskey import DSSKey from kex_group1 import KexGroup1 @@ -49,7 +48,7 @@ from Crypto.Hash import SHA, MD5, HMAC _active_threads = [] def _join_lingering_threads(): for thr in _active_threads: - thr.active = False + thr.stop_thread() import atexit atexit.register(_join_lingering_threads) @@ -162,10 +161,6 @@ class BaseTransport (threading.Thread): 'diffie-hellman-group-exchange-sha1': KexGex, } - # READ the secsh RFC's before raising these values. if anything, - # they should probably be lower. - REKEY_PACKETS = pow(2, 30) - REKEY_BYTES = pow(2, 30) _modulus_pack = None @@ -222,42 +217,29 @@ class BaseTransport (threading.Thread): except AttributeError: pass # negotiated crypto parameters + self.packetizer = Packetizer(sock) self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID self.remote_version = '' - self.block_size_out = self.block_size_in = 8 - self.local_mac_len = self.remote_mac_len = 0 - self.engine_in = self.engine_out = None self.local_cipher = self.remote_cipher = '' - self.sequence_number_in = self.sequence_number_out = 0L self.local_kex_init = self.remote_kex_init = None self.session_id = None # /negotiated crypto parameters self.expected_packet = 0 self.active = False self.initial_kex_done = False - self.write_lock = threading.RLock() # lock around outbound writes (packet computation) - self.lock = threading.Lock() # synchronization (always higher level than write_lock) - self.channels = { } # (id -> Channel) - self.channel_events = { } # (id -> Event) + self.lock = threading.Lock() # synchronization (always higher level than write_lock) + self.channels = { } # (id -> Channel) + self.channel_events = { } # (id -> Event) self.channel_counter = 1 self.window_size = 65536 self.max_packet_size = 32768 - self.ultra_debug = False self.saved_exception = None self.clear_to_send = threading.Event() self.log_name = 'paramiko.transport' self.logger = util.get_logger(self.log_name) - # used for noticing when to re-key: - self.received_bytes = 0 - self.received_packets = 0 - self.received_packets_overflow = 0 - self.sent_bytes = 0 - self.sent_packets = 0 + self.packetizer.set_log(self.logger) # user-defined event callbacks: self.completion_event = None - # keepalives: - self.keepalive_interval = 0 - self.keepalive_last = time.time() # server mode: self.server_mode = False self.server_object = None @@ -293,7 +275,7 @@ class BaseTransport (threading.Thread): preference for them. @return: an object that can be used to change the preferred algorithms - for encryption, digest (hash), public key, and key exchange. + for encryption, digest (hash), public key, and key exchange. @rtype: L{SecurityOptions} @since: ivysaur @@ -316,8 +298,8 @@ class BaseTransport (threading.Thread): @note: L{connect} is a simpler method for connecting as a client. @note: After calling this method (or L{start_server} or L{connect}), - you should no longer directly read from or write to the original socket - object. + you should no longer directly read from or write to the original + socket object. @param event: an event to trigger when negotiation is complete. @type event: threading.Event @@ -350,13 +332,13 @@ class BaseTransport (threading.Thread): given C{server} object to allow channels to be opened. @note: After calling this method (or L{start_client} or L{connect}), - you should no longer directly read from or write to the original socket - object. + you should no longer directly read from or write to the original + socket object. @param event: an event to trigger when negotiation is complete. @type event: threading.Event @param server: an object used to perform authentication and create - L{Channel}s. + L{Channel}s. @type server: L{server.ServerInterface} """ if server is None: @@ -376,7 +358,7 @@ class BaseTransport (threading.Thread): key info, not just the public half. @param key: the host key to add, usually an L{RSAKey <rsakey.RSAKey>} or - L{DSSKey <dsskey.DSSKey>}. + L{DSSKey <dsskey.DSSKey>}. @type key: L{PKey <pkey.PKey>} """ self.server_key_dict[key.get_name()] = key @@ -418,10 +400,10 @@ class BaseTransport (threading.Thread): support that method of key negotiation. @param filename: optional path to the moduli file, if you happen to - know that it's not in a standard location. + know that it's not in a standard location. @type filename: str @return: True if a moduli file was successfully loaded; False - otherwise. + otherwise. @rtype: bool @since: doduo @@ -449,8 +431,7 @@ class BaseTransport (threading.Thread): Close this session, and any open channels that are tied to it. """ self.active = False - self.engine_in = self.engine_out = None - self.sequence_number_in = self.sequence_number_out = 0L + self.packetizer.close() for chan in self.channels.values(): chan._unlink() @@ -459,9 +440,9 @@ class BaseTransport (threading.Thread): Return the host key of the server (in client mode). @note: Previously this call returned a tuple of (key type, key string). - You can get the same effect by calling - L{PKey.get_name <pkey.PKey.get_name>} for the key type, and C{str(key)} - for the key string. + You can get the same effect by calling + L{PKey.get_name <pkey.PKey.get_name>} for the key type, and + C{str(key)} for the key string. @raise SSHException: if no session is currently active. @@ -476,7 +457,8 @@ class BaseTransport (threading.Thread): """ Return true if this session is active (open). - @return: True if the session is still active (open); False if the session is closed. + @return: True if the session is still active (open); False if the + session is closed. @rtype: bool """ return self.active @@ -487,7 +469,7 @@ class BaseTransport (threading.Thread): is just an alias for C{open_channel('session')}. @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + rejected or the session ends prematurely. @rtype: L{Channel} """ return self.open_channel('session') @@ -500,17 +482,17 @@ class BaseTransport (threading.Thread): L{connect} or L{start_client}) and authenticating. @param kind: the kind of channel requested (usually C{"session"}, - C{"forwarded-tcpip"} or C{"direct-tcpip"}). + C{"forwarded-tcpip"} or C{"direct-tcpip"}). @type kind: str @param dest_addr: the destination address of this port forwarding, - if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored - for other channel types). + if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored + for other channel types). @type dest_addr: (str, int) @param src_addr: the source address of this port forwarding, if - C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. + C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. @type src_addr: (str, int) @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + rejected or the session ends prematurely. @rtype: L{Channel} """ chan = None @@ -599,7 +581,7 @@ class BaseTransport (threading.Thread): session has died mid-negotiation. @return: True if the renegotiation was successful, and the link is - using new keys; False if the session dropped during renegotiation. + using new keys; False if the session dropped during renegotiation. @rtype: bool """ self.completion_event = threading.Event() @@ -620,12 +602,13 @@ class BaseTransport (threading.Thread): can be useful to keep connections alive over a NAT, for example. @param interval: seconds to wait before sending a keepalive packet (or - 0 to disable keepalives). + 0 to disable keepalives). @type interval: int @since: fearow """ - self.keepalive_interval = interval + self.packetizer.set_keepalive(interval, + lambda x=self: x.global_request('keepalive@lag.net', wait=False)) def global_request(self, kind, data=None, wait=True): """ @@ -635,14 +618,14 @@ class BaseTransport (threading.Thread): @param kind: name of the request. @type kind: str @param data: an optional tuple containing additional data to attach - to the request. + to the request. @type data: tuple @param wait: C{True} if this method should not return until a response - is received; C{False} otherwise. + is received; C{False} otherwise. @type wait: bool @return: a L{Message} containing possible additional data if the - request was successful (or an empty L{Message} if C{wait} was - C{False}); C{None} if the request was denied. + request was successful (or an empty L{Message} if C{wait} was + C{False}); C{None} if the request was denied. @rtype: L{Message} @since: fearow @@ -833,7 +816,23 @@ class BaseTransport (threading.Thread): C{False} otherwise. @type hexdump: bool """ - self.ultra_debug = hexdump + self.packetizer.set_hexdump(hexdump) + + def get_hexdump(self): + """ + Return C{True} if the transport is currently logging hex dumps of + protocol traffic. + + @return: C{True} if hex dumps are being logged + @rtype: bool + + @since: 1.4 + """ + return self.packetizer.get_hexdump() + + def stop_thread(self): + self.active = False + self.packetizer.close() ### internals... @@ -859,113 +858,10 @@ class BaseTransport (threading.Thread): finally: self.lock.release() - def _check_keepalive(self): - if (not self.keepalive_interval) or (not self.initial_kex_done): - return - now = time.time() - if now > self.keepalive_last + self.keepalive_interval: - self.global_request('keepalive@lag.net', wait=False) - - def _py22_read_all(self, n): - out = '' - while n > 0: - r, w, e = select.select([self.sock], [], [], 0.1) - if self.sock not in r: - if not self.active: - raise EOFError() - self._check_keepalive() - else: - x = self.sock.recv(n) - if len(x) == 0: - raise EOFError() - out += x - n -= len(x) - return out - - def _read_all(self, n): - if PY22: - return self._py22_read_all(n) - out = '' - while n > 0: - try: - x = self.sock.recv(n) - if len(x) == 0: - raise EOFError() - out += x - n -= len(x) - except socket.timeout: - if not self.active: - raise EOFError() - self._check_keepalive() - return out - - def _write_all(self, out): - self.keepalive_last = time.time() - while len(out) > 0: - try: - n = self.sock.send(out) - except socket.timeout: - n = 0 - if not self.active: - n = -1 - except Exception, x: - # could be: (32, 'Broken pipe') - n = -1 - if n < 0: - raise EOFError() - if n == len(out): - return - out = out[n:] - return - - def _build_packet(self, payload): - # pad up at least 4 bytes, to nearest block-size (usually 8) - bsize = self.block_size_out - padding = 3 + bsize - ((len(payload) + 8) % bsize) - packet = struct.pack('>I', len(payload) + padding + 1) - packet += chr(padding) - packet += payload - packet += randpool.get_bytes(padding) - return packet - def _send_message(self, data): - # encrypt this sucka - data = str(data) - cmd = ord(data[0]) - if cmd in MSG_NAMES: - cmd_name = MSG_NAMES[cmd] - else: - cmd_name = '$%x' % cmd - self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data))) - packet = self._build_packet(data) - if self.ultra_debug: - self._log(DEBUG, util.format_binary(packet, 'OUT: ')) - if self.engine_out != None: - out = self.engine_out.encrypt(packet) - else: - out = packet - # + mac - try: - self.write_lock.acquire() - if self.engine_out != None: - payload = struct.pack('>I', self.sequence_number_out) + packet - out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len] - self.sequence_number_out += 1L - self.sequence_number_out %= 0x100000000L - self._write_all(out) - - self.sent_bytes += len(out) - self.sent_packets += 1 - if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \ - and (self.local_kex_init is None): - # only ask once for rekeying - self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' % - (self.sent_packets, self.sent_bytes)) - self.received_packets_overflow = 0 - # this may do a recursive lock, but that's okay: - self._send_kex_init() - finally: - self.write_lock.release() + self.packetizer.send_message(data) + if self.packetizer.need_rekey(): + self._send_kex_init() def _send_user_message(self, data): """ @@ -981,65 +877,6 @@ class BaseTransport (threading.Thread): break self._send_message(data) - def _read_message(self): - "only one thread will ever be in this function" - header = self._read_all(self.block_size_in) - if self.engine_in != None: - header = self.engine_in.decrypt(header) - if self.ultra_debug: - self._log(DEBUG, util.format_binary(header, 'IN: ')); - packet_size = struct.unpack('>I', header[:4])[0] - # leftover contains decrypted bytes from the first block (after the length field) - leftover = header[4:] - if (packet_size - len(leftover)) % self.block_size_in != 0: - raise SSHException('Invalid packet blocking') - buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover)) - packet = buffer[:packet_size - len(leftover)] - post_packet = buffer[packet_size - len(leftover):] - if self.engine_in != None: - packet = self.engine_in.decrypt(packet) - if self.ultra_debug: - self._log(DEBUG, util.format_binary(packet, 'IN: ')); - packet = leftover + packet - if self.remote_mac_len > 0: - mac = post_packet[:self.remote_mac_len] - mac_payload = struct.pack('>II', self.sequence_number_in, packet_size) + packet - my_mac = HMAC.HMAC(self.mac_key_in, mac_payload, self.remote_mac_engine).digest()[:self.remote_mac_len] - if my_mac != mac: - raise SSHException('Mismatched MAC') - padding = ord(packet[0]) - payload = packet[1:packet_size - padding + 1] - randpool.add_event(packet[packet_size - padding + 1]) - if self.ultra_debug: - self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) - msg = Message(payload[1:]) - msg.seqno = self.sequence_number_in - self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL - # check for rekey - self.received_bytes += packet_size + self.remote_mac_len + 4 - self.received_packets += 1 - if self.local_kex_init is not None: - # we've asked to rekey -- give them 20 packets to comply before - # dropping the connection - self.received_packets_overflow += 1 - if self.received_packets_overflow >= 20: - raise SSHException('Remote transport is ignoring rekey requests') - elif (self.received_packets >= self.REKEY_PACKETS) or \ - (self.received_bytes >= self.REKEY_BYTES): - # only ask once for rekeying - self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' % - (self.received_packets, self.received_bytes)) - self.received_packets_overflow = 0 - self._send_kex_init() - - cmd = ord(payload[0]) - if cmd in MSG_NAMES: - cmd_name = MSG_NAMES[cmd] - else: - cmd_name = '$%x' % cmd - self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) - return cmd, msg - def _set_K_H(self, k, h): "used by a kex object to set the K (root key) and H (exchange hash)" self.K = k @@ -1090,18 +927,21 @@ class BaseTransport (threading.Thread): else: self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL)) try: - self._write_all(self.local_version + '\r\n') + self.packetizer.write_all(self.local_version + '\r\n') self._check_banner() self._send_kex_init() self.expected_packet = MSG_KEXINIT while self.active: - ptype, m = self._read_message() + if self.packetizer.need_rekey(): + self._send_kex_init() + ptype, m = self.packetizer.read_message() if ptype == MSG_IGNORE: continue elif ptype == MSG_DISCONNECT: self._parse_disconnect(m) self.active = False + self.packetizer.close() break elif ptype == MSG_DEBUG: self._parse_debug(m) @@ -1123,6 +963,7 @@ class BaseTransport (threading.Thread): else: self._log(ERROR, 'Channel request for unknown channel %d' % chanid) self.active = False + self.packetizer.close() else: self._log(WARNING, 'Oops, unhandled type %d' % ptype) msg = Message() @@ -1146,6 +987,7 @@ class BaseTransport (threading.Thread): chan._unlink() if self.active: self.active = False + self.packetizer.close() if self.completion_event != None: self.completion_event.set() if self.auth_event != None: @@ -1170,12 +1012,15 @@ class BaseTransport (threading.Thread): def _check_banner(self): # this is slow, but we only have to do it once for i in range(5): - buffer = '' - while not '\n' in buffer: - buffer += self._read_all(1) - buffer = buffer[:-1] - if (len(buffer) > 0) and (buffer[-1] == '\r'): - buffer = buffer[:-1] + # give them 5 seconds for the first line, then just 2 seconds each additional line + if i == 0: + timeout = 5 + else: + timeout = 2 + try: + buffer = self.packetizer.readline(timeout) + except Exception, x: + raise SSHException('Error reading SSH protocol banner' + str(x)) if buffer[:4] == 'SSH-': break self._log(DEBUG, 'Banner: ' + buffer) @@ -1236,13 +1081,6 @@ class BaseTransport (threading.Thread): self._send_message(m) def _parse_kex_init(self, m): - # reset counters of when to re-key, since we are now re-keying - self.received_bytes = 0 - self.received_packets = 0 - self.received_packets_overflow = 0 - self.sent_bytes = 0 - self.sent_packets = 0 - cookie = m.get_bytes(16) kex_algo_list = m.get_list() server_key_algo_list = m.get_list() @@ -1334,44 +1172,46 @@ class BaseTransport (threading.Thread): def _activate_inbound(self): "switch on newly negotiated encryption parameters for inbound traffic" - self.block_size_in = self._cipher_info[self.remote_cipher]['block-size'] + block_size = self._cipher_info[self.remote_cipher]['block-size'] if self.server_mode: - IV_in = self._compute_key('A', self.block_size_in) + IV_in = self._compute_key('A', block_size) key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size']) else: - IV_in = self._compute_key('B', self.block_size_in) + IV_in = self._compute_key('B', block_size) key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size']) - self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in) - self.remote_mac_len = self._mac_info[self.remote_mac]['size'] - self.remote_mac_engine = self._mac_info[self.remote_mac]['class'] + engine = self._get_cipher(self.remote_cipher, key_in, IV_in) + mac_size = self._mac_info[self.remote_mac]['size'] + mac_engine = self._mac_info[self.remote_mac]['class'] # initial mac keys are done in the hash's natural size (not the potentially truncated # transmission size) if self.server_mode: - self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size) + mac_key = self._compute_key('E', mac_engine.digest_size) else: - self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size) + mac_key = self._compute_key('F', mac_engine.digest_size) + self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) def _activate_outbound(self): "switch on newly negotiated encryption parameters for outbound traffic" m = Message() m.add_byte(chr(MSG_NEWKEYS)) self._send_message(m) - self.block_size_out = self._cipher_info[self.local_cipher]['block-size'] + block_size = self._cipher_info[self.local_cipher]['block-size'] if self.server_mode: - IV_out = self._compute_key('B', self.block_size_out) + IV_out = self._compute_key('B', block_size) key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size']) else: - IV_out = self._compute_key('A', self.block_size_out) + IV_out = self._compute_key('A', block_size) key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size']) - self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out) - self.local_mac_len = self._mac_info[self.local_mac]['size'] - self.local_mac_engine = self._mac_info[self.local_mac]['class'] + engine = self._get_cipher(self.local_cipher, key_out, IV_out) + mac_size = self._mac_info[self.local_mac]['size'] + mac_engine = self._mac_info[self.local_mac]['class'] # initial mac keys are done in the hash's natural size (not the potentially truncated # transmission size) if self.server_mode: - self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size) + mac_key = self._compute_key('F', mac_engine.digest_size) else: - self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size) + mac_key = self._compute_key('E', mac_engine.digest_size) + self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) # we always expect to receive NEWKEYS now self.expected_packet = MSG_NEWKEYS |