diff options
-rw-r--r-- | NOTES | 4 | ||||
-rwxr-xr-x | demo.py | 2 | ||||
-rwxr-xr-x | demo_server.py | 8 | ||||
-rw-r--r-- | paramiko/__init__.py | 2 | ||||
-rw-r--r-- | paramiko/auth_transport.py | 80 | ||||
-rw-r--r-- | paramiko/channel.py | 48 | ||||
-rw-r--r-- | paramiko/kex_gex.py | 80 | ||||
-rw-r--r-- | paramiko/kex_group1.py | 20 | ||||
-rw-r--r-- | paramiko/primes.py | 128 | ||||
-rw-r--r-- | paramiko/transport.py | 394 | ||||
-rw-r--r-- | paramiko/util.py | 16 |
11 files changed, 507 insertions, 275 deletions
@@ -22,14 +22,14 @@ from BaseTransport: get_server_key close get_remote_server_key - is_active - is_authenticated +* is_active open_session open_channel renegotiate_keys check_channel_request from Transport: +* is_authenticated auth_key auth_password get_allowed_auths @@ -69,7 +69,7 @@ try: t.ultra_debug = 0 t.start_client(event) # print repr(t) - event.wait(10) + event.wait(15) if not t.is_active(): print '*** SSH negotiation failed.' sys.exit(1) diff --git a/demo_server.py b/demo_server.py index 90ab9bac..e04c8027 100755 --- a/demo_server.py +++ b/demo_server.py @@ -70,6 +70,11 @@ print 'Got a connection!' try: event = threading.Event() t = ServerTransport(client) + try: + t.load_server_moduli() + except: + print '(Failed to load moduli -- gex will be unsupported.)' + raise t.add_server_key(host_key) t.ultra_debug = 0 t.start_server(event) @@ -81,10 +86,11 @@ try: # print repr(t) # wait for auth - chan = t.accept(10) + chan = t.accept(20) if chan is None: print '*** No channel.' sys.exit(1) + print 'Authenticated!' chan.event.wait(10) if not chan.event.isSet(): print '*** Client never asked for a shell.' diff --git a/paramiko/__init__.py b/paramiko/__init__.py index e6d5e9fa..0e96a92a 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -17,4 +17,4 @@ from rsakey import RSAKey from dsskey import DSSKey from util import hexify -__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ] +#__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ] diff --git a/paramiko/auth_transport.py b/paramiko/auth_transport.py index 23f2d02c..c39c4630 100644 --- a/paramiko/auth_transport.py +++ b/paramiko/auth_transport.py @@ -19,17 +19,45 @@ class Transport(BaseTransport): def __init__(self, sock): BaseTransport.__init__(self, sock) + self.authenticated = False self.auth_event = None # for server mode: self.auth_username = None self.auth_fail_count = 0 self.auth_complete = 0 - def request_auth(self): + def __repr__(self): + if not self.active: + return '<paramiko.Transport (unconnected)>' + out = '<paramiko.Transport' + if self.local_cipher != '': + out += ' (cipher %s)' % self.local_cipher + if self.authenticated: + if len(self.channels) == 1: + out += ' (active; 1 open channel)' + else: + out += ' (active; %d open channels)' % len(self.channels) + elif self.initial_kex_done: + out += ' (connected; awaiting auth)' + else: + out += ' (connecting)' + out += '>' + return out + + def is_authenticated(self): + """ + Return true if this session is active and authenticated. + + @return: True if the session is still open and has been authenticated successfully; + False if authentication failed and/or the session is closed. + """ + return self.authenticated and self.active + + def _request_auth(self): m = Message() m.add_byte(chr(MSG_SERVICE_REQUEST)) m.add_string('ssh-userauth') - self.send_message(m) + self._send_message(m) def auth_key(self, username, key, event): if (not self.active) or (not self.initial_kex_done): @@ -41,7 +69,7 @@ class Transport(BaseTransport): self.auth_method = 'publickey' self.username = username self.private_key = key - self.request_auth() + self._request_auth() finally: self.lock.release() @@ -56,7 +84,7 @@ class Transport(BaseTransport): self.auth_method = 'password' self.username = username self.password = password - self.request_auth() + self._request_auth() finally: self.lock.release() @@ -66,7 +94,7 @@ class Transport(BaseTransport): m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) m.add_string('Service not available') m.add_string('en') - self.send_message(m) + self._send_message(m) self.close() def disconnect_no_more_auth(self): @@ -75,7 +103,7 @@ class Transport(BaseTransport): m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) m.add_string('No more auth methods available') m.add_string('en') - self.send_message(m) + self._send_message(m) self.close() def parse_service_request(self, m): @@ -85,7 +113,7 @@ class Transport(BaseTransport): m = Message() m.add_byte(chr(MSG_SERVICE_ACCEPT)) m.add_string(service) - self.send_message(m) + self._send_message(m) return # dunno this one self.disconnect_service_not_available() @@ -93,7 +121,7 @@ class Transport(BaseTransport): def parse_service_accept(self, m): service = m.get_string() if service == 'ssh-userauth': - self.log(DEBUG, 'userauth is OK') + self._log(DEBUG, 'userauth is OK') m = Message() m.add_byte(chr(MSG_USERAUTH_REQUEST)) m.add_string(self.username) @@ -109,9 +137,9 @@ class Transport(BaseTransport): m.add_string(self.private_key.sign_ssh_session(self.randpool, self.H, self.username)) else: raise SSHException('Unknown auth method "%s"' % self.auth_method) - self.send_message(m) + self._send_message(m) else: - self.log(DEBUG, 'Service request "%s" accepted (?)' % service) + self._log(DEBUG, 'Service request "%s" accepted (?)' % service) def get_allowed_auths(self, username): "override me!" @@ -136,7 +164,7 @@ class Transport(BaseTransport): m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_string('none') m.add_boolean(0) - self.send_message(m) + self._send_message(m) return if self.auth_complete: # ignore @@ -144,12 +172,12 @@ class Transport(BaseTransport): username = m.get_string() service = m.get_string() method = m.get_string() - self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) + self._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) if service != 'ssh-connection': self.disconnect_service_not_available() return if (self.auth_username is not None) and (self.auth_username != username): - self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') + self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') self.disconnect_no_more_auth() return if method == 'none': @@ -160,7 +188,7 @@ class Transport(BaseTransport): if changereq: # always treated as failure, since we don't support changing passwords, but collect # the list of valid auth types from the callback anyway - self.log(DEBUG, 'Auth request to change passwords (rejected)') + self._log(DEBUG, 'Auth request to change passwords (rejected)') newpassword = m.get_string().decode('UTF-8') result = self.AUTH_FAILED else: @@ -173,11 +201,11 @@ class Transport(BaseTransport): # okay, send result m = Message() if result == self.AUTH_SUCCESSFUL: - self.log(DEBUG, 'Auth granted.') + self._log(DEBUG, 'Auth granted.') m.add_byte(chr(MSG_USERAUTH_SUCCESS)) self.auth_complete = 1 else: - self.log(DEBUG, 'Auth rejected.') + self._log(DEBUG, 'Auth rejected.') m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_string(self.get_allowed_auths(username)) if result == self.AUTH_PARTIALLY_SUCCESSFUL: @@ -185,13 +213,13 @@ class Transport(BaseTransport): else: m.add_boolean(0) self.auth_fail_count += 1 - self.send_message(m) + self._send_message(m) if self.auth_fail_count >= 10: self.disconnect_no_more_auth() def parse_userauth_success(self, m): - self.log(INFO, 'Authentication successful!') - self.authenticated = 1 + self._log(INFO, 'Authentication successful!') + self.authenticated = True if self.auth_event != None: self.auth_event.set() @@ -199,12 +227,12 @@ class Transport(BaseTransport): authlist = m.get_list() partial = m.get_boolean() if partial: - self.log(INFO, 'Authentication continues...') - self.log(DEBUG, 'Methods: ' + str(partial)) + self._log(INFO, 'Authentication continues...') + self._log(DEBUG, 'Methods: ' + str(partial)) # FIXME - do something pass - self.log(INFO, 'Authentication failed.') - self.authenticated = 0 + self._log(INFO, 'Authentication failed.') + self.authenticated = False self.close() if self.auth_event != None: self.auth_event.set() @@ -212,11 +240,11 @@ class Transport(BaseTransport): def parse_userauth_banner(self, m): banner = m.get_string() lang = m.get_string() - self.log(INFO, 'Auth banner: ' + banner) + self._log(INFO, 'Auth banner: ' + banner) # who cares. - handler_table = BaseTransport.handler_table.copy() - handler_table.update({ + _handler_table = BaseTransport._handler_table.copy() + _handler_table.update({ MSG_SERVICE_REQUEST: parse_service_request, MSG_SERVICE_ACCEPT: parse_service_accept, MSG_USERAUTH_REQUEST: parse_userauth_request, diff --git a/paramiko/channel.py b/paramiko/channel.py index 43ddcd16..2ac0866b 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -50,10 +50,10 @@ class Channel(object): out += '>' return out - def set_transport(self, transport): + def _set_transport(self, transport): self.transport = transport - def log(self, level, msg): + def _log(self, level, msg): self.logger.log(level, msg) def set_window(self, window_size, max_packet_size): @@ -70,7 +70,7 @@ class Channel(object): self.active = 1 def request_success(self, m): - self.log(DEBUG, 'Sesch channel %d request ok' % self.chanid) + self._log(DEBUG, 'Sesch channel %d request ok' % self.chanid) return def request_failed(self, m): @@ -80,13 +80,13 @@ class Channel(object): s = m.get_string() try: self.lock.acquire() - self.log(DEBUG, 'fed %d bytes' % len(s)) + self._log(DEBUG, 'fed %d bytes' % len(s)) if self.pipe_wfd != None: self.feed_pipe(s) else: self.in_buffer += s self.in_buffer_cv.notifyAll() - self.log(DEBUG, '(out from feed)') + self._log(DEBUG, '(out from feed)') finally: self.lock.release() @@ -94,7 +94,7 @@ class Channel(object): nbytes = m.get_int() try: self.lock.acquire() - self.log(DEBUG, 'window up %d' % nbytes) + self._log(DEBUG, 'window up %d' % nbytes) self.out_window_size += nbytes self.out_buffer_cv.notifyAll() finally: @@ -146,7 +146,7 @@ class Channel(object): pixelheight = m.get_int() ok = self.check_window_change_request(width, height, pixelwidth, pixelheight) else: - self.log(DEBUG, 'Unhandled channel request "%s"' % key) + self._log(DEBUG, 'Unhandled channel request "%s"' % key) ok = False if want_reply: m = Message() @@ -155,7 +155,7 @@ class Channel(object): else: m.add_byte(chr(MSG_CHANNEL_FAILURE)) m.add_int(self.remote_chanid) - self.transport.send_message(m) + self.transport._send_message(m) def handle_eof(self, m): try: @@ -168,7 +168,7 @@ class Channel(object): self.pipe_wfd = None finally: self.lock.release() - self.log(DEBUG, 'EOF received') + self._log(DEBUG, 'EOF received') def handle_close(self, m): self.close() @@ -199,7 +199,7 @@ class Channel(object): # pixel height, width (usually useless) m.add_int(0).add_int(0) m.add_string('') - self.transport.send_message(m) + self.transport._send_message(m) def invoke_shell(self): if self.closed or self.eof_received or self.eof_sent or not self.active: @@ -209,7 +209,7 @@ class Channel(object): m.add_int(self.remote_chanid) m.add_string('shell') m.add_boolean(1) - self.transport.send_message(m) + self.transport._send_message(m) def exec_command(self, command): if self.closed or self.eof_received or self.eof_sent or not self.active: @@ -220,7 +220,7 @@ class Channel(object): m.add_string('exec') m.add_boolean(1) m.add_string(command) - self.transport.send_message(m) + self.transport._send_message(m) def invoke_subsystem(self, subsystem): if self.closed or self.eof_received or self.eof_sent or not self.active: @@ -231,7 +231,7 @@ class Channel(object): m.add_string('subsystem') m.add_boolean(1) m.add_string(subsystem) - self.transport.send_message(m) + self.transport._send_message(m) def resize_pty(self, width=80, height=24): if self.closed or self.eof_received or self.eof_sent or not self.active: @@ -244,7 +244,7 @@ class Channel(object): m.add_int(width) m.add_int(height) m.add_int(0).add_int(0) - self.transport.send_message(m) + self.transport._send_message(m) def get_transport(self): return self.transport @@ -262,9 +262,9 @@ class Channel(object): m = Message() m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_int(self.remote_chanid) - self.transport.send_message(m) + self.transport._send_message(m) self.eof_sent = 1 - self.log(DEBUG, 'EOF sent') + self._log(DEBUG, 'EOF sent') return @@ -290,9 +290,9 @@ class Channel(object): m = Message() m.add_byte(chr(MSG_CHANNEL_CLOSE)) m.add_int(self.remote_chanid) - self.transport.send_message(m) + self.transport._send_message(m) self.closed = 1 - self.transport.unlink_channel(self.chanid) + self.transport._unlink_channel(self.chanid) finally: self.lock.release() @@ -371,7 +371,7 @@ class Channel(object): m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_int(self.remote_chanid) m.add_string(s[:size]) - self.transport.send_message(m) + self.transport._send_message(m) self.out_window_size -= size finally: self.lock.release() @@ -506,25 +506,25 @@ class Channel(object): self.in_buffer = self.in_buffer[nbytes:] os.write(self.pipd_wfd, x) - def unlink(self): + def _unlink(self): if self.closed or not self.active: return self.closed = 1 - self.transport.unlink_channel(self.chanid) + self.transport._unlink_channel(self.chanid) def check_add_window(self, n): # already holding the lock! if self.closed or self.eof_received or not self.active: return - self.log(DEBUG, 'addwindow %d' % n) + self._log(DEBUG, 'addwindow %d' % n) self.in_window_sofar += n if self.in_window_sofar > self.in_window_threshold: - self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar) + self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) m = Message() m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_int(self.remote_chanid) m.add_int(self.in_window_sofar) - self.transport.send_message(m) + self.transport._send_message(m) self.in_window_sofar = 0 diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py index f1b8058d..5e02bab7 100644 --- a/paramiko/kex_gex.py +++ b/paramiko/kex_gex.py @@ -5,7 +5,7 @@ # LOT more on the server side). from message import Message -from util import inflate_long, deflate_long, generate_prime, bit_length +from util import inflate_long, deflate_long, bit_length from ssh_exception import SSHException from transport import MSG_NEWKEYS from Crypto.Hash import SHA @@ -27,7 +27,7 @@ class KexGex(object): def start_kex(self): if self.transport.server_mode: - self.transport.expected_packet = MSG_KEXDH_GEX_REQUEST + self.transport._expect_packet(MSG_KEXDH_GEX_REQUEST) return # request a bit range: we accept (min_bits) to (max_bits), but prefer # (preferred_bits). according to the spec, we shouldn't pull the @@ -37,21 +37,21 @@ class KexGex(object): m.add_int(self.min_bits) m.add_int(self.preferred_bits) m.add_int(self.max_bits) - self.transport.send_message(m) - self.transport.expected_packet = MSG_KEXDH_GEX_GROUP + self.transport._send_message(m) + self.transport._expect_packet(MSG_KEXDH_GEX_GROUP) def parse_next(self, ptype, m): if ptype == MSG_KEXDH_GEX_REQUEST: - return self.parse_kexdh_gex_request(m) + return self._parse_kexdh_gex_request(m) elif ptype == MSG_KEXDH_GEX_GROUP: - return self.parse_kexdh_gex_group(m) + return self._parse_kexdh_gex_group(m) elif ptype == MSG_KEXDH_GEX_INIT: - return self.parse_kexdh_gex_init(m) + return self._parse_kexdh_gex_init(m) elif ptype == MSG_KEXDH_GEX_REPLY: - return self.parse_kexdh_gex_reply(m) + return self._parse_kexdh_gex_reply(m) raise SSHException('KexGex asked to handle packet type %d' % ptype) - def generate_x(self): + def _generate_x(self): # generate an "x" (1 < x < (p-1)/2). q = (self.p - 1) // 2 qnorm = deflate_long(q, 0) @@ -70,7 +70,7 @@ class KexGex(object): break self.x = x - def parse_kexdh_gex_request(self, m): + def _parse_kexdh_gex_request(self, m): min = m.get_int() preferred = m.get_int() max = m.get_int() @@ -79,52 +79,53 @@ class KexGex(object): preferred = self.max_bits if preferred < self.min_bits: preferred = self.min_bits + # fix min/max if they're inconsistent. technically, we could just pout + # and hang up, but there's no harm in giving them the benefit of the + # doubt and just picking a bitsize for them. + if min > preferred: + min = preferred + if max < preferred: + max = preferred # now save a copy self.min_bits = min self.preferred_bits = preferred self.max_bits = max # generate prime - while 1: - # does not work FIXME - # the problem is that it's too fscking SLOW - self.transport.log(DEBUG, 'stir...') - self.transport.randpool.stir() - self.transport.log(DEBUG, 'get-prime %d...' % preferred) - self.p = generate_prime(preferred, self.transport.randpool) - self.transport.log(DEBUG, 'got ' + repr(self.p)) - if number.isPrime((self.p - 1) // 2): - break - self.g = 2 + pack = self.transport._get_modulus_pack() + if pack is None: + raise SSHException('Can\'t do server-side gex with no modulus pack') + self.g, self.p = pack.get_modulus(min, preferred, max) m = Message() m.add_byte(chr(MSG_KEXDH_GEX_GROUP)) m.add_mpint(self.p) m.add_mpint(self.g) - self.transport.send_message(m) - self.transport.expected_packet = MSG_KEXDH_GEX_INIT + self.transport._send_message(m) + self.transport._expect_packet(MSG_KEXDH_GEX_INIT) - def parse_kexdh_gex_group(self, m): + def _parse_kexdh_gex_group(self, m): self.p = m.get_mpint() self.g = m.get_mpint() # reject if p's bit length < 1024 or > 8192 bitlen = bit_length(self.p) if (bitlen < 1024) or (bitlen > 8192): raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen) - self.transport.log(DEBUG, 'Got server p (%d bits)' % bitlen) - self.generate_x() + self.transport._log(DEBUG, 'Got server p (%d bits)' % bitlen) + self._generate_x() # now compute e = g^x mod p self.e = pow(self.g, self.x, self.p) m = Message() m.add_byte(chr(MSG_KEXDH_GEX_INIT)) m.add_mpint(self.e) - self.transport.send_message(m) - self.transport.expected_packet = MSG_KEXDH_GEX_REPLY + self.transport._send_message(m) + self.transport._expect_packet(MSG_KEXDH_GEX_REPLY) - def parse_kexdh_gex_init(self, m): + def _parse_kexdh_gex_init(self, m): self.e = m.get_mpint() if (self.e < 1) or (self.e > self.p - 1): raise SSHException('Client kex "e" is out of range') - self.generate_x() - K = pow(self.e, self.x, P) + self._generate_x() + self.f = pow(self.g, self.x, self.p) + K = pow(self.e, self.x, self.p) key = str(self.transport.get_server_key()) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) hm = Message().add(self.transport.remote_version).add(self.transport.local_version) @@ -136,7 +137,7 @@ class KexGex(object): hm.add_mpint(self.g) hm.add(self.e).add(self.f).add(K) H = SHA.new(str(hm)).digest() - self.transport.set_K_H(K, H) + self.transport._set_K_H(K, H) # sign it sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) # send reply @@ -145,11 +146,10 @@ class KexGex(object): m.add_string(key) m.add_mpint(self.f) m.add_string(sig) - self.transport.send_message(m) - self.transport.activate_outbound() - self.transport.expected_packet = MSG_NEWKEYS + self.transport._send_message(m) + self.transport._activate_outbound() - def parse_kexdh_gex_reply(self, m): + def _parse_kexdh_gex_reply(self, m): host_key = m.get_string() self.f = m.get_mpint() sig = m.get_string() @@ -165,9 +165,9 @@ class KexGex(object): hm.add_mpint(self.p) hm.add_mpint(self.g) hm.add(self.e).add(self.f).add(K) - self.transport.set_K_H(K, SHA.new(str(hm)).digest()) - self.transport.verify_key(host_key, sig) - self.transport.activate_outbound() - self.transport.expected_packet = MSG_NEWKEYS + self.transport._set_K_H(K, SHA.new(str(hm)).digest()) + self.transport._verify_key(host_key, sig) + self.transport._activate_outbound() + diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py index a29c79d0..de1a2546 100644 --- a/paramiko/kex_group1.py +++ b/paramiko/kex_group1.py @@ -44,15 +44,15 @@ class KexGroup1(object): if self.transport.server_mode: # compute f = g^x mod p, but don't send it yet self.f = pow(G, self.x, P) - self.transport.expected_packet = MSG_KEXDH_INIT + self.transport._expect_packet(MSG_KEXDH_INIT) return # compute e = g^x mod p (where g=2), and send it self.e = pow(G, self.x, P) m = Message() m.add_byte(chr(MSG_KEXDH_INIT)) m.add_mpint(self.e) - self.transport.send_message(m) - self.transport.expected_packet = MSG_KEXDH_REPLY + self.transport._send_message(m) + self.transport._expect_packet(MSG_KEXDH_REPLY) def parse_next(self, ptype, m): if self.transport.server_mode and (ptype == MSG_KEXDH_INIT): @@ -73,10 +73,9 @@ class KexGroup1(object): hm = Message().add(self.transport.local_version).add(self.transport.remote_version) hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key) hm.add(self.e).add(self.f).add(K) - self.transport.set_K_H(K, SHA.new(str(hm)).digest()) - self.transport.verify_key(host_key, sig) - self.transport.activate_outbound() - self.transport.expected_packet = MSG_NEWKEYS + self.transport._set_K_H(K, SHA.new(str(hm)).digest()) + self.transport._verify_key(host_key, sig) + self.transport._activate_outbound() def parse_kexdh_init(self, m): # server mode @@ -90,7 +89,7 @@ class KexGroup1(object): hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key) hm.add(self.e).add(self.f).add(K) H = SHA.new(str(hm)).digest() - self.transport.set_K_H(K, H) + self.transport._set_K_H(K, H) # sign it sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) # send reply @@ -99,6 +98,5 @@ class KexGroup1(object): m.add_string(key) m.add_mpint(self.f) m.add_string(sig) - self.transport.send_message(m) - self.transport.activate_outbound() - self.transport.expected_packet = MSG_NEWKEYS + self.transport._send_message(m) + self.transport._activate_outbound() diff --git a/paramiko/primes.py b/paramiko/primes.py new file mode 100644 index 00000000..68e7fc19 --- /dev/null +++ b/paramiko/primes.py @@ -0,0 +1,128 @@ + +# utility functions for dealing with primes + +from Crypto.Util import number +from util import bit_length, inflate_long + + +def generate_prime(bits, randpool): + hbyte_mask = pow(2, bits % 8) - 1 + while 1: + # loop catches the case where we increment n into a higher bit-range + x = randpool.get_bytes((bits+7) // 8) + if hbyte_mask > 0: + x = chr(ord(x[0]) & hbyte_mask) + x[1:] + n = inflate_long(x, 1) + n |= 1 + n |= (1 << (bits - 1)) + while not number.isPrime(n): + n += 2 + if bit_length(n) == bits: + return n + +def roll_random(randpool, n): + "returns a random # from 0 to N-1" + bits = bit_length(n-1) + bytes = (bits + 7) // 8 + hbyte_mask = pow(2, bits % 8) - 1 + + # so here's the plan: + # we fetch as many random bits as we'd need to fit N-1, and if the + # generated number is >= N, we try again. in the worst case (N-1 is a + # power of 2), we have slightly better than 50% odds of getting one that + # fits, so i can't guarantee that this loop will ever finish, but the odds + # of it looping forever should be infinitesimal. + while 1: + x = randpool.get_bytes(bytes) + if hbyte_mask > 0: + x = chr(ord(x[0]) & hbyte_mask) + x[1:] + num = inflate_long(x, 1) + if num < n: + return num + + +class ModulusPack (object): + """ + convenience object for holding the contents of the /etc/ssh/moduli file, + on systems that have such a file. + """ + + def __init__(self, randpool): + # pack is a hash of: bits -> [ (generator, modulus) ... ] + self.pack = {} + self.discarded = [] + self.randpool = randpool + + def _parse_modulus(self, line): + timestamp, type, tests, tries, size, generator, modulus = line.split() + type = int(type) + tests = int(tests) + tries = int(tries) + size = int(size) + generator = int(generator) + modulus = long(modulus, 16) + + # weed out primes that aren't at least: + # type 2 (meets basic structural requirements) + # test 4 (more than just a small-prime sieve) + # tries < 100 if test & 4 (at least 100 tries of miller-rabin) + if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): + self.discarded.append((modulus, 'does not meet basic requirements')) + return + if generator == 0: + generator = 2 + + # there's a bug in the ssh "moduli" file (yeah, i know: shock! dismay! + # call cnn!) where it understates the bit lengths of these primes by 1. + # this is okay. + bl = bit_length(modulus) + if (bl != size) and (bl != size + 1): + self.discarded.append((modulus, 'incorrectly reported bit length %d' % size)) + return + if not self.pack.has_key(bl): + self.pack[bl] = [] + self.pack[bl].append((generator, modulus)) + + def read_file(self, filename): + """ + @raise IOError: passed from any file operations that fail. + """ + self.pack = {} + f = open(filename, 'r') + for line in f: + line = line.strip() + if (len(line) == 0) or (line[0] == '#'): + continue + try: + self._parse_modulus(line) + except: + continue + f.close() + + def get_modulus(self, min, prefer, max): + bitsizes = self.pack.keys() + bitsizes.sort() + if len(bitsizes) == 0: + raise SSHException('no moduli available') + good = -1 + # find nearest bitsize >= preferred + for b in bitsizes: + if (b >= prefer) and (b < max) and ((b < good) or (good == -1)): + good = b + # if that failed, find greatest bitsize >= min + if good == -1: + for b in bitsizes: + if (b >= min) and (b < max) and (b > good): + good = b + if good == -1: + # their entire (min, max) range has no intersection with our range. + # if their range is below ours, pick the smallest. otherwise pick + # the largest. it'll be out of their range requirement either way, + # but we'll be sending them the closest one we have. + good = bitsizes[0] + if min > good: + good = bitsizes[-1] + # now pick a random modulus of this bitsize + n = roll_random(self.randpool, len(self.pack[good])) + return self.pack[good][n] + diff --git a/paramiko/transport.py b/paramiko/transport.py index 6bbfa757..f982d78f 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -12,6 +12,7 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) import sys, os, string, threading, socket, logging, struct +from ssh_exception import SSHException from message import Message from channel import Channel from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings @@ -19,6 +20,7 @@ from rsakey import RSAKey from dsskey import DSSKey from kex_group1 import KexGroup1 from kex_gex import KexGex +from primes import ModulusPack # these come from PyCrypt # http://www.amk.ca/python/writing/pycrypt/ @@ -81,21 +83,21 @@ class BaseTransport(threading.Thread): preferred_keys = [ 'ssh-rsa', 'ssh-dss' ] preferred_kex = [ 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ] - cipher_info = { + _cipher_info = { 'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 }, 'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 }, 'aes256-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 32 }, '3des-cbc': { 'class': DES3, 'mode': DES3.MODE_CBC, 'block-size': 8, 'key-size': 24 }, } - mac_info = { + _mac_info = { 'hmac-sha1': { 'class': SHA, 'size': 20 }, 'hmac-sha1-96': { 'class': SHA, 'size': 12 }, 'hmac-md5': { 'class': MD5, 'size': 16 }, 'hmac-md5-96': { 'class': MD5, 'size': 12 }, } - kex_info = { + _kex_info = { 'diffie-hellman-group1-sha1': KexGroup1, 'diffie-hellman-group-exchange-sha1': KexGex, } @@ -107,7 +109,7 @@ class BaseTransport(threading.Thread): OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5) def __init__(self, sock): - threading.Thread.__init__(self) + threading.Thread.__init__(self, target=self._run) self.randpool = randpool self.sock = sock self.sock.settimeout(0.1) @@ -123,11 +125,10 @@ class BaseTransport(threading.Thread): self.session_id = None # /negotiated crypto parameters self.expected_packet = 0 - self.active = 0 + self.active = False self.initial_kex_done = 0 self.write_lock = threading.Lock() # lock around outbound writes (packet computation) self.lock = threading.Lock() # synchronization (always higher level than write_lock) - self.authenticated = 0 self.channels = { } # (id -> Channel) self.channel_events = { } # (id -> Event) self.channel_counter = 1 @@ -135,6 +136,7 @@ class BaseTransport(threading.Thread): self.window_size = 65536 self.max_packet_size = 2048 self.ultra_debug = 0 + self.modulus_pack = None # used for noticing when to re-key: self.received_bytes = 0 self.received_packets = 0 @@ -165,27 +167,69 @@ class BaseTransport(threading.Thread): except KeyError: return None + def load_server_moduli(self, filename=None): + """ + I{(optional)} + Load a file of prime moduli for use in doing group-exchange key + negotiation in server mode. It's a rather obscure option and can be + safely ignored. + + In server mode, the remote client may request "group-exchange" key + negotiation, which asks the server to send a random prime number that + fits certain criteria. These primes are pretty difficult to compute, + so they can't be generated on demand. But many systems contain a file + of suitable primes (usually named something like C{/etc/ssh/moduli}). + If you call C{load_server_moduli} and it returns C{True}, then this + file of primes has been loaded and we will support "group-exchange" in + server mode. Otherwise server mode will just claim that it doesn't + support that method of key negotiation. + + @param filename: optional path to the moduli file, if you happen to + know that it's not in a standard location. + @type filename: string + @return: True if a moduli file was successfully loaded; False + otherwise. + @rtype: boolean + + @since: doduo + + @note: This has no effect when used in client mode. + """ + self.modulus_pack = ModulusPack(self.randpool) + # places to look for the openssh "moduli" file + file_list = [ '/etc/ssh/moduli', '/usr/local/etc/moduli' ] + if filename is not None: + file_list.insert(0, filename) + for fn in file_list: + try: + self.modulus_pack.read_file(fn) + return True + except IOError: + pass + # none succeeded + self.modulus_pack = None + return False + + def _get_modulus_pack(self): + "used by KexGex to find primes for group exchange" + return self.modulus_pack + def __repr__(self): if not self.active: - return '<paramiko.Transport (unconnected)>' - out = '<sesch.Transport' + return '<paramiko.BaseTransport (unconnected)>' + out = '<paramiko.BaseTransport' #if self.remote_version != '': # out += ' (server version "%s")' % self.remote_version if self.local_cipher != '': out += ' (cipher %s)' % self.local_cipher - if self.authenticated: - if len(self.channels) == 1: - out += ' (active; 1 open channel)' - else: - out += ' (active; %d open channels)' % len(self.channels) - elif self.initial_kex_done: - out += ' (connected; awaiting auth)' + if len(self.channels) == 1: + out += ' (active; 1 open channel)' else: - out += ' (connecting)' + out += ' (active; %d open channels)' % len(self.channels) out += '>' return out - def log(self, level, msg): + def _log(self, level, msg): if type(msg) == type([]): for m in msg: self.logger.log(level, m) @@ -193,14 +237,29 @@ class BaseTransport(threading.Thread): self.logger.log(level, msg) def close(self): - self.active = 0 + """ + Close this session, and any open channels that are tied to it. + """ + self.active = False self.engine_in = self.engine_out = None self.sequence_number_in = self.sequence_number_out = 0L for chan in self.channels.values(): - chan.unlink() + chan._unlink() def get_remote_server_key(self): - 'returns (type, key) where type is like "ssh-rsa" and key is an opaque string' + """ + Return the host key of the server (in client mode). + The type string is usually either C{"ssh-rsa"} or C{"ssh-dss"} and the + key is an opaque string, which may be saved or used for comparison with + previously-seen keys. (In other words, you don't need to worry about + the content of the key, only that it compares equal to the key you + expected to see.) + + @raise SSHException: if no session is currently active. + + @return: tuple of (key type, key) + @rtype: (string, string) + """ if (not self.active) or (not self.initial_kex_done): raise SSHException('No existing session') key_msg = Message(self.host_key) @@ -208,10 +267,13 @@ class BaseTransport(threading.Thread): return key_type, self.host_key def is_active(self): - return self.active + """ + Return true if this session is active (open). - def is_authenticated(self): - return self.authenticated and self.active + @return: True if the session is still active (open); False if the session is closed. + @rtype: boolean + """ + return self.active def open_session(self): return self.open_channel('session') @@ -230,9 +292,9 @@ class BaseTransport(threading.Thread): m.add_int(self.max_packet_size) self.channels[chanid] = chan = Channel(chanid) self.channel_events[chanid] = event = threading.Event() - chan.set_transport(self) + chan._set_transport(self) chan.set_window(self.window_size, self.max_packet_size) - self.send_message(m) + self._send_message(m) finally: self.lock.release() while 1: @@ -249,7 +311,8 @@ class BaseTransport(threading.Thread): self.lock.release() return chan - def unlink_channel(self, chanid): + def _unlink_channel(self, chanid): + "used by a Channel to remove itself from the active channel list" try: self.lock.acquire() if self.channels.has_key(chanid): @@ -257,7 +320,7 @@ class BaseTransport(threading.Thread): finally: self.lock.release() - def read_all(self, n): + def _read_all(self, n): out = '' while n > 0: try: @@ -271,7 +334,7 @@ class BaseTransport(threading.Thread): raise EOFError() return out - def write_all(self, out): + def _write_all(self, out): while len(out) > 0: n = self.sock.send(out) if n <= 0: @@ -281,7 +344,7 @@ class BaseTransport(threading.Thread): out = out[n:] return - def build_packet(self, payload): + def _build_packet(self, payload): # pad up at least 4 bytes, to nearest block-size (usually 8) bsize = self.block_size_out padding = 3 + bsize - ((len(payload) + 8) % bsize) @@ -291,11 +354,12 @@ class BaseTransport(threading.Thread): packet += randpool.get_bytes(padding) return packet - def send_message(self, data): + def _send_message(self, data): + # FIXME: should we check for rekeying here too? # encrypt this sucka - packet = self.build_packet(str(data)) + packet = self._build_packet(str(data)) if self.ultra_debug: - self.log(DEBUG, format_binary(packet, 'OUT: ')) + self._log(DEBUG, format_binary(packet, 'OUT: ')) if self.engine_out != None: out = self.engine_out.encrypt(packet) else: @@ -308,29 +372,29 @@ class BaseTransport(threading.Thread): out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len] self.sequence_number_out += 1L self.sequence_number_out %= 0x100000000L - self.write_all(out) + self._write_all(out) finally: self.write_lock.release() - def read_message(self): + def _read_message(self): "only one thread will ever be in this function" - header = self.read_all(self.block_size_in) + header = self._read_all(self.block_size_in) if self.engine_in != None: header = self.engine_in.decrypt(header) if self.ultra_debug: - self.log(DEBUG, format_binary(header, 'IN: ')); + self._log(DEBUG, format_binary(header, 'IN: ')); packet_size = struct.unpack('>I', header[:4])[0] # leftover contains decrypted bytes from the first block (after the length field) leftover = header[4:] if (packet_size - len(leftover)) % self.block_size_in != 0: raise SSHException('Invalid packet blocking') - buffer = self.read_all(packet_size + self.remote_mac_len - len(leftover)) + buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover)) packet = buffer[:packet_size - len(leftover)] post_packet = buffer[packet_size - len(leftover):] if self.engine_in != None: packet = self.engine_in.decrypt(packet) if self.ultra_debug: - self.log(DEBUG, format_binary(packet, 'IN: ')); + self._log(DEBUG, format_binary(packet, 'IN: ')); packet = leftover + packet if self.remote_mac_len > 0: mac = post_packet[:self.remote_mac_len] @@ -341,7 +405,7 @@ class BaseTransport(threading.Thread): padding = ord(packet[0]) payload = packet[1:packet_size - padding + 1] randpool.add_event(packet[packet_size - padding + 1]) - #self.log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) + #self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) msg = Message(payload[1:]) msg.seqno = self.sequence_number_in self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL @@ -351,10 +415,10 @@ class BaseTransport(threading.Thread): if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES): # only ask once for rekeying if self.local_kex_init is None: - self.log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets, - self.received_bytes)) + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets, + self.received_bytes)) self.received_packets_overflow = 0 - self.send_kex_init() + self._send_kex_init() else: # we've asked to rekey already -- give them 20 packets to # comply, then just drop the connection @@ -364,14 +428,18 @@ class BaseTransport(threading.Thread): return ord(payload[0]), msg - def set_K_H(self, k, h): + def _set_K_H(self, k, h): "used by a kex object to set the K (root key) and H (exchange hash)" self.K = k self.H = h if self.session_id == None: self.session_id = h - def verify_key(self, host_key, sig): + def _expect_packet(self, type): + "used by a kex object to register the next packet type it expects to see" + self.expected_packet = type + + def _verify_key(self, host_key, sig): if self.host_key_type == 'ssh-rsa': key = RSAKey(Message(host_key)) elif self.host_key_type == 'ssh-dss': @@ -384,7 +452,7 @@ class BaseTransport(threading.Thread): raise SSHException('Signature verification (%s) failed. Boo. Robey should debug this.' % self.host_key_type) self.host_key = host_key - def compute_key(self, id, nbytes): + def _compute_key(self, id, nbytes): "id is 'A' - 'F' for the various keys used by ssh" m = Message() m.add_mpint(self.K) @@ -402,30 +470,30 @@ class BaseTransport(threading.Thread): sofar += hash return out[:nbytes] - def get_cipher(self, name, key, iv): - if not self.cipher_info.has_key(name): + def _get_cipher(self, name, key, iv): + if not self._cipher_info.has_key(name): raise SSHException('Unknown client cipher ' + name) - return self.cipher_info[name]['class'].new(key, self.cipher_info[name]['mode'], iv) + return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) - def run(self): - self.active = 1 + def _run(self): + self.active = True try: # SSH-1.99-OpenSSH_2.9p2 - self.write_all(self.local_version + '\r\n') - self.check_banner() - self.send_kex_init() + self._write_all(self.local_version + '\r\n') + self._check_banner() + self._send_kex_init() self.expected_packet = MSG_KEXINIT while self.active: - ptype, m = self.read_message() + ptype, m = self._read_message() if ptype == MSG_IGNORE: continue elif ptype == MSG_DISCONNECT: - self.parse_disconnect(m) - self.active = 0 + self._parse_disconnect(m) + self.active = False break elif ptype == MSG_DEBUG: - self.parse_debug(m) + self._parse_debug(m) continue if self.expected_packet != 0: if ptype != self.expected_packet: @@ -435,28 +503,29 @@ class BaseTransport(threading.Thread): self.kex_engine.parse_next(ptype, m) continue - if self.handler_table.has_key(ptype): - self.handler_table[ptype](self, m) - elif self.channel_handler_table.has_key(ptype): + if self._handler_table.has_key(ptype): + self._handler_table[ptype](self, m) + elif self._channel_handler_table.has_key(ptype): chanid = m.get_int() if self.channels.has_key(chanid): - self.channel_handler_table[ptype](self.channels[chanid], m) + self._channel_handler_table[ptype](self.channels[chanid], m) else: - self.log(WARNING, 'Oops, unhandled type %d' % ptype) + self._log(WARNING, 'Oops, unhandled type %d' % ptype) msg = Message() msg.add_byte(chr(MSG_UNIMPLEMENTED)) msg.add_int(m.seqno) - self.send_message(msg) + self._send_message(msg) except SSHException, e: - self.log(DEBUG, 'Exception: ' + str(e)) - self.log(DEBUG, tb_strings()) + self._log(DEBUG, 'Exception: ' + str(e)) + self._log(DEBUG, tb_strings()) except EOFError, e: - self.log(DEBUG, 'EOF') + self._log(DEBUG, 'EOF') + self._log(DEBUG, tb_strings()) except Exception, e: - self.log(DEBUG, 'Unknown exception: ' + str(e)) - self.log(DEBUG, tb_strings()) + self._log(DEBUG, 'Unknown exception: ' + str(e)) + self._log(DEBUG, tb_strings()) if self.active: - self.active = 0 + self.active = False if self.completion_event != None: self.completion_event.set() if self.auth_event != None: @@ -468,36 +537,49 @@ class BaseTransport(threading.Thread): ### protocol stages def renegotiate_keys(self): + """ + Force this session to switch to new keys. Normally this is done + automatically after the session hits a certain number of packets or + bytes sent or received, but this method gives you the option of forcing + new keys whenever you want. Negotiating new keys causes a pause in + traffic both ways as the two sides swap keys and do computations. This + method returns when the session has switched to new keys, or the + session has died mid-negotiation. + + @return: True if the renegotiation was successful, and the link is + using new keys; False if the session dropped during renegotiation. + @rtype: boolean + """ self.completion_event = threading.Event() - self.send_kex_init() + self._send_kex_init() while 1: self.completion_event.wait(0.1); if not self.active: - return 0 + return False if self.completion_event.isSet(): break - return 1 + return True - def negotiate_keys(self, m): + def _negotiate_keys(self, m): # throws SSHException on anything unusual if self.local_kex_init == None: # remote side wants to renegotiate - self.send_kex_init() - self.parse_kex_init(m) + self._send_kex_init() + self._parse_kex_init(m) self.kex_engine.start_kex() - def check_banner(self): + def _check_banner(self): # this is slow, but we only have to do it once for i in range(5): buffer = '' while not '\n' in buffer: - buffer += self.read_all(1) + buffer += self._read_all(1) buffer = buffer[:-1] if (len(buffer) > 0) and (buffer[-1] == '\r'): buffer = buffer[:-1] if buffer[:4] == 'SSH-': break - self.log(DEBUG, 'Banner: ' + buffer) + self._log(DEBUG, 'Banner: ' + buffer) if buffer[:4] != 'SSH-': raise SSHException('Indecipherable protocol version "' + buffer + '"') # save this server version string for later @@ -516,17 +598,21 @@ class BaseTransport(threading.Thread): client = segs[2] if version != '1.99' and version != '2.0': raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) - self.log(INFO, 'Connected (version %s, client %s)' % (version, client)) + self._log(INFO, 'Connected (version %s, client %s)' % (version, client)) - def send_kex_init(self): - # send a really wimpy kex-init packet that says we're a bare-bones ssh client + def _send_kex_init(self): + """ + announce to the other side that we'd like to negotiate keys, and what + kind of key negotiation we support. + """ if self.server_mode: - # FIXME: can't do group-exchange (gex) yet -- too slow - if 'diffie-hellman-group-exchange-sha1' in self.preferred_kex: + if (self.modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self.preferred_kex): + # can't do group-exchange if we don't have a pack of potential primes self.preferred_kex.remove('diffie-hellman-group-exchange-sha1') - - available_server_keys = filter(self.server_key_dict.keys().__contains__, - self.preferred_keys) + available_server_keys = filter(self.server_key_dict.keys().__contains__, + self.preferred_keys) + else: + available_server_keys = self.preferred_keys m = Message() m.add_byte(chr(MSG_KEXINIT)) @@ -545,9 +631,9 @@ class BaseTransport(threading.Thread): m.add_int(0) # save a copy for later (needed to compute a hash) self.local_kex_init = str(m) - self.send_message(m) + self._send_message(m) - def parse_kex_init(self, m): + def _parse_kex_init(self, m): # reset counters of when to re-key, since we are now re-keying self.received_bytes = 0 self.received_packets = 0 @@ -580,7 +666,7 @@ class BaseTransport(threading.Thread): agreed_kex = filter(kex_algo_list.__contains__, self.preferred_kex) if len(agreed_kex) == 0: raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') - self.kex_engine = self.kex_info[agreed_kex[0]](self) + self.kex_engine = self._kex_info[agreed_kex[0]](self) if self.server_mode: available_server_keys = filter(self.server_key_dict.keys().__contains__, @@ -608,7 +694,7 @@ class BaseTransport(threading.Thread): raise SSHException('Incompatible ssh server (no acceptable ciphers)') self.local_cipher = agreed_local_ciphers[0] self.remote_cipher = agreed_remote_ciphers[0] - self.log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) + self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) if self.server_mode: agreed_remote_macs = filter(self.preferred_macs.__contains__, client_mac_algo_list) @@ -621,19 +707,19 @@ class BaseTransport(threading.Thread): self.local_mac = agreed_local_macs[0] self.remote_mac = agreed_remote_macs[0] - self.log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ - ' client encrypt:' + str(client_encrypt_algo_list) + \ - ' server encrypt:' + str(server_encrypt_algo_list) + \ - ' client mac:' + str(client_mac_algo_list) + \ - ' server mac:' + str(server_mac_algo_list) + \ - ' client compress:' + str(client_compress_algo_list) + \ - ' server compress:' + str(server_compress_algo_list) + \ - ' client lang:' + str(client_lang_list) + \ - ' server lang:' + str(server_lang_list) + \ - ' kex follows?' + str(kex_follows)) - self.log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' % - (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, - self.remote_mac)) + self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ + ' client encrypt:' + str(client_encrypt_algo_list) + \ + ' server encrypt:' + str(server_encrypt_algo_list) + \ + ' client mac:' + str(client_mac_algo_list) + \ + ' server mac:' + str(server_mac_algo_list) + \ + ' client compress:' + str(client_compress_algo_list) + \ + ' server compress:' + str(server_compress_algo_list) + \ + ' client lang:' + str(client_lang_list) + \ + ' server lang:' + str(server_lang_list) + \ + ' kex follows?' + str(kex_follows)) + self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' % + (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, + self.remote_mac)) # save for computing hash later... # now wait! openssh has a bug (and others might too) where there are @@ -642,50 +728,52 @@ class BaseTransport(threading.Thread): # away those bytes because they aren't part of the hash. self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far() - def activate_inbound(self): + def _activate_inbound(self): "switch on newly negotiated encryption parameters for inbound traffic" - self.block_size_in = self.cipher_info[self.remote_cipher]['block-size'] + self.block_size_in = self._cipher_info[self.remote_cipher]['block-size'] if self.server_mode: - IV_in = self.compute_key('A', self.block_size_in) - key_in = self.compute_key('C', self.cipher_info[self.remote_cipher]['key-size']) + IV_in = self._compute_key('A', self.block_size_in) + key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size']) else: - IV_in = self.compute_key('B', self.block_size_in) - key_in = self.compute_key('D', self.cipher_info[self.remote_cipher]['key-size']) - self.engine_in = self.get_cipher(self.remote_cipher, key_in, IV_in) - self.remote_mac_len = self.mac_info[self.remote_mac]['size'] - self.remote_mac_engine = self.mac_info[self.remote_mac]['class'] + IV_in = self._compute_key('B', self.block_size_in) + key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size']) + self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in) + self.remote_mac_len = self._mac_info[self.remote_mac]['size'] + self.remote_mac_engine = self._mac_info[self.remote_mac]['class'] # initial mac keys are done in the hash's natural size (not the potentially truncated # transmission size) if self.server_mode: - self.mac_key_in = self.compute_key('E', self.remote_mac_engine.digest_size) + self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size) else: - self.mac_key_in = self.compute_key('F', self.remote_mac_engine.digest_size) + self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size) - def activate_outbound(self): + def _activate_outbound(self): "switch on newly negotiated encryption parameters for outbound traffic" m = Message() m.add_byte(chr(MSG_NEWKEYS)) - self.send_message(m) - self.block_size_out = self.cipher_info[self.local_cipher]['block-size'] + self._send_message(m) + self.block_size_out = self._cipher_info[self.local_cipher]['block-size'] if self.server_mode: - IV_out = self.compute_key('B', self.block_size_out) - key_out = self.compute_key('D', self.cipher_info[self.local_cipher]['key-size']) + IV_out = self._compute_key('B', self.block_size_out) + key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size']) else: - IV_out = self.compute_key('A', self.block_size_out) - key_out = self.compute_key('C', self.cipher_info[self.local_cipher]['key-size']) - self.engine_out = self.get_cipher(self.local_cipher, key_out, IV_out) - self.local_mac_len = self.mac_info[self.local_mac]['size'] - self.local_mac_engine = self.mac_info[self.local_mac]['class'] + IV_out = self._compute_key('A', self.block_size_out) + key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size']) + self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out) + self.local_mac_len = self._mac_info[self.local_mac]['size'] + self.local_mac_engine = self._mac_info[self.local_mac]['class'] # initial mac keys are done in the hash's natural size (not the potentially truncated # transmission size) if self.server_mode: - self.mac_key_out = self.compute_key('F', self.local_mac_engine.digest_size) + self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size) else: - self.mac_key_out = self.compute_key('E', self.local_mac_engine.digest_size) + self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size) + # we always expect to receive NEWKEYS now + self.expected_packet = MSG_NEWKEYS - def parse_newkeys(self, m): - self.log(DEBUG, 'Switch to new keys ...') - self.activate_inbound() + def _parse_newkeys(self, m): + self._log(DEBUG, 'Switch to new keys ...') + self._activate_inbound() # can also free a bunch of stuff here self.local_kex_init = self.remote_kex_init = None self.e = self.f = self.K = self.x = None @@ -697,24 +785,24 @@ class BaseTransport(threading.Thread): self.completion_event.set() return - def parse_disconnect(self, m): + def _parse_disconnect(self, m): code = m.get_int() desc = m.get_string() - self.log(INFO, 'Disconnect (code %d): %s' % (code, desc)) + self._log(INFO, 'Disconnect (code %d): %s' % (code, desc)) - def parse_channel_open_success(self, m): + def _parse_channel_open_success(self, m): chanid = m.get_int() server_chanid = m.get_int() server_window_size = m.get_int() server_max_packet_size = m.get_int() if not self.channels.has_key(chanid): - self.log(WARNING, 'Success for unrequested channel! [??]') + self._log(WARNING, 'Success for unrequested channel! [??]') return try: self.lock.acquire() chan = self.channels[chanid] chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size) - self.log(INFO, 'Secsh channel %d opened.' % chanid) + self._log(INFO, 'Secsh channel %d opened.' % chanid) if self.channel_events.has_key(chanid): self.channel_events[chanid].set() del self.channel_events[chanid] @@ -722,7 +810,7 @@ class BaseTransport(threading.Thread): self.lock.release() return - def parse_channel_open_failure(self, m): + def _parse_channel_open_failure(self, m): chanid = m.get_int() reason = m.get_int() reason_str = m.get_string() @@ -731,7 +819,7 @@ class BaseTransport(threading.Thread): reason_text = CONNECTION_FAILED_CODE[reason] else: reason_text = '(unknown code)' - self.log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) + self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) try: self.lock.aquire() if self.channels.has_key(chanid): @@ -747,14 +835,14 @@ class BaseTransport(threading.Thread): "override me! return object descended from Channel to allow, or None to reject" return None - def parse_channel_open(self, m): + def _parse_channel_open(self, m): kind = m.get_string() chanid = m.get_int() initial_window_size = m.get_int() max_packet_size = m.get_int() reject = False if not self.server_mode: - self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) + self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) reject = True reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: @@ -766,7 +854,7 @@ class BaseTransport(threading.Thread): self.lock.release() chan = self.check_channel_request(kind, my_chanid) if (chan is None) or (type(chan) is int): - self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) + self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) reject = True if type(chan) is int: reason = chan @@ -779,12 +867,12 @@ class BaseTransport(threading.Thread): msg.add_int(reason) msg.add_string('') msg.add_string('en') - self.send_message(msg) + self._send_message(msg) return try: self.lock.acquire() self.channels[my_chanid] = chan - chan.set_transport(self) + chan._set_transport(self) chan.set_window(self.window_size, self.max_packet_size) chan.set_remote_channel(chanid, initial_window_size, max_packet_size) finally: @@ -795,8 +883,8 @@ class BaseTransport(threading.Thread): m.add_int(my_chanid) m.add_int(self.window_size) m.add_int(self.max_packet_size) - self.send_message(m) - self.log(INFO, 'Secsh channel %d opened.' % my_chanid) + self._send_message(m) + self._log(INFO, 'Secsh channel %d opened.' % my_chanid) try: self.lock.acquire() self.server_accepts.append(chan) @@ -820,21 +908,21 @@ class BaseTransport(threading.Thread): self.lock.release() return chan - def parse_debug(self, m): + def _parse_debug(self, m): always_display = m.get_boolean() msg = m.get_string() lang = m.get_string() - self.log(DEBUG, 'Debug msg: ' + safe_string(msg)) - - handler_table = { - MSG_NEWKEYS: parse_newkeys, - MSG_CHANNEL_OPEN_SUCCESS: parse_channel_open_success, - MSG_CHANNEL_OPEN_FAILURE: parse_channel_open_failure, - MSG_CHANNEL_OPEN: parse_channel_open, - MSG_KEXINIT: negotiate_keys, + self._log(DEBUG, 'Debug msg: ' + safe_string(msg)) + + _handler_table = { + MSG_NEWKEYS: _parse_newkeys, + MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success, + MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure, + MSG_CHANNEL_OPEN: _parse_channel_open, + MSG_KEXINIT: _negotiate_keys, } - channel_handler_table = { + _channel_handler_table = { MSG_CHANNEL_SUCCESS: Channel.request_success, MSG_CHANNEL_FAILURE: Channel.request_failed, MSG_CHANNEL_DATA: Channel.feed, diff --git a/paramiko/util.py b/paramiko/util.py index fd78af38..33b671c6 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -1,7 +1,6 @@ #!/usr/bin/python import sys, struct, traceback -from Crypto.Util import number def inflate_long(s, always_positive=0): "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" @@ -98,20 +97,5 @@ def bit_length(n): bitlen -= 1 return bitlen -def generate_prime(bits, randpool): - hbyte_mask = pow(2, bits % 8) - 1 - x = randpool.get_bytes((bits+7) // 8) - if hbyte_mask > 0: - x = chr(ord(x[0]) & hbyte_mask) + x[1:] - n = inflate_long(x, 1) - n |= 1 - n |= (1 << (bits - 1)) - while 1: - # loop catches the case where we increment n into a higher bit-range - while not number.isPrime(n): - n += 2 - if bit_length(n) == bits: - return n - def tb_strings(): return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') |