diff options
-rw-r--r-- | paramiko/kex_gex.py | 28 | ||||
-rw-r--r-- | paramiko/kex_group1.py | 44 |
2 files changed, 47 insertions, 25 deletions
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py index 8f3a800f..c324211c 100644 --- a/paramiko/kex_gex.py +++ b/paramiko/kex_gex.py @@ -72,6 +72,10 @@ class KexGex (object): return self._parse_kexdh_gex_reply(m) raise SSHException('KexGex asked to handle packet type %d' % ptype) + + ### internals... + + def _generate_x(self): # generate an "x" (1 < x < (p-1)/2). q = (self.p - 1) // 2 @@ -82,7 +86,7 @@ class KexGex (object): while not (qhbyte & 0x80): qhbyte <<= 1 qmask >>= 1 - while 1: + while True: self.transport.randpool.stir() x_bytes = self.transport.randpool.get_bytes(bytes) x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:] @@ -152,8 +156,15 @@ class KexGex (object): hm = Message() hm.add(self.transport.remote_version, self.transport.local_version, self.transport.remote_kex_init, self.transport.local_kex_init, - key, self.min_bits, self.preferred_bits, self.max_bits, - self.p, self.g, self.e, self.f, K) + key) + hm.add_int(self.min_bits) + hm.add_int(self.preferred_bits) + hm.add_int(self.max_bits) + hm.add_mpint(self.p) + hm.add_mpint(self.g) + hm.add_mpint(self.e) + hm.add_mpint(self.f) + hm.add_mpint(K) H = SHA.new(str(hm)).digest() self.transport._set_K_H(K, H) # sign it @@ -178,8 +189,15 @@ class KexGex (object): hm = Message() hm.add(self.transport.local_version, self.transport.remote_version, self.transport.local_kex_init, self.transport.remote_kex_init, - host_key, self.min_bits, self.preferred_bits, self.max_bits, - self.p, self.g, self.e, self.f, K) + host_key) + hm.add_int(self.min_bits) + hm.add_int(self.preferred_bits) + hm.add_int(self.max_bits) + hm.add_mpint(self.p) + hm.add_mpint(self.g) + hm.add_mpint(self.e) + hm.add_mpint(self.f) + hm.add_mpint(K) self.transport._set_K_H(K, SHA.new(str(hm)).digest()) self.transport._verify_key(host_key, sig) self.transport._activate_outbound() diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py index fb70648a..9b77a0f1 100644 --- a/paramiko/kex_group1.py +++ b/paramiko/kex_group1.py @@ -44,23 +44,8 @@ class KexGroup1(object): def __init__(self, transport): self.transport = transport - def generate_x(self): - # generate an "x" (1 < x < q), where q is (p-1)/2. - # p is a 128-byte (1024-bit) number, where the first 64 bits are 1. - # therefore q can be approximated as a 2^1023. we drop the subset of - # potential x where the first 63 bits are 1, because some of those will be - # larger than q (but this is a tiny tiny subset of potential x). - while 1: - self.transport.randpool.stir() - x_bytes = self.transport.randpool.get_bytes(128) - x_bytes = chr(ord(x_bytes[0]) & 0x7f) + x_bytes[1:] - if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \ - (x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'): - break - self.x = util.inflate_long(x_bytes) - def start_kex(self): - self.generate_x() + self._generate_x() if self.transport.server_mode: # compute f = g^x mod p, but don't send it yet self.f = pow(G, self.x, P) @@ -76,12 +61,31 @@ class KexGroup1(object): def parse_next(self, ptype, m): if self.transport.server_mode and (ptype == _MSG_KEXDH_INIT): - return self.parse_kexdh_init(m) + return self._parse_kexdh_init(m) elif not self.transport.server_mode and (ptype == _MSG_KEXDH_REPLY): - return self.parse_kexdh_reply(m) + return self._parse_kexdh_reply(m) raise SSHException('KexGroup1 asked to handle packet type %d' % ptype) + + + ### internals... + + + def _generate_x(self): + # generate an "x" (1 < x < q), where q is (p-1)/2. + # p is a 128-byte (1024-bit) number, where the first 64 bits are 1. + # therefore q can be approximated as a 2^1023. we drop the subset of + # potential x where the first 63 bits are 1, because some of those will be + # larger than q (but this is a tiny tiny subset of potential x). + while 1: + self.transport.randpool.stir() + x_bytes = self.transport.randpool.get_bytes(128) + x_bytes = chr(ord(x_bytes[0]) & 0x7f) + x_bytes[1:] + if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \ + (x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'): + break + self.x = util.inflate_long(x_bytes) - def parse_kexdh_reply(self, m): + def _parse_kexdh_reply(self, m): # client mode host_key = m.get_string() self.f = m.get_mpint() @@ -98,7 +102,7 @@ class KexGroup1(object): self.transport._verify_key(host_key, sig) self.transport._activate_outbound() - def parse_kexdh_init(self, m): + def _parse_kexdh_init(self, m): # server mode self.e = m.get_mpint() if (self.e < 1) or (self.e > P - 1): |