1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
|
#!/usr/bin/python
# variant on group1 (see kex_group1.py) where the prime "p" and generator "g"
# are provided by the server. a bit more work is required on our side (and a
# LOT more on the server side).
from message import Message
from util import inflate_long, deflate_long, generate_prime, bit_length
from paramiko import SSHException
from transport import MSG_NEWKEYS
from Crypto.Hash import SHA
from Crypto.Util import number
from logging import DEBUG
# Packet type numbers used by the group-exchange messages handled below.
(
    MSG_KEXDH_GEX_GROUP,
    MSG_KEXDH_GEX_INIT,
    MSG_KEXDH_GEX_REPLY,
    MSG_KEXDH_GEX_REQUEST,
) = 31, 32, 33, 34
class KexGex(object):
    """
    Diffie-Hellman group-exchange key agreement ("gex").

    Unlike the fixed-group exchange, the prime C{p} and generator C{g} are
    supplied by the server.  The same class drives both ends of the exchange:
    which packets it sends and expects depends on C{transport.server_mode}.

    Packet flow, as implemented by the handlers below::

        client                                server
          GEX_REQUEST (min, n, max)  ---->
                                     <----  GEX_GROUP (p, g)
          GEX_INIT (e)               ---->
                                     <----  GEX_REPLY (host key, f, sig)
    """

    name = 'diffie-hellman-group-exchange-sha1'
    # group sizes (in bits) that we request as a client / accept as a server
    min_bits = 1024
    max_bits = 8192
    preferred_bits = 2048

    def __init__(self, transport):
        """
        @param transport: the transport this key exchange runs on.  It is
            used for sending messages, randomness, logging, and for storing
            the negotiated K and H.
        """
        self.transport = transport

    def start_kex(self):
        """
        Begin the exchange.  In server mode we only wait for the client's
        request; in client mode we send the request for a group ourselves.
        """
        if self.transport.server_mode:
            self.transport.expected_packet = MSG_KEXDH_GEX_REQUEST
            return
        # request a bit range: we accept (min_bits) to (max_bits), but prefer
        # (preferred_bits).  according to the spec, we shouldn't pull the
        # minimum up above 1024.
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_REQUEST))
        m.add_int(self.min_bits)
        m.add_int(self.preferred_bits)
        m.add_int(self.max_bits)
        self.transport.send_message(m)
        self.transport.expected_packet = MSG_KEXDH_GEX_GROUP

    def parse_next(self, ptype, m):
        """
        Dispatch an incoming key-exchange packet to its handler.

        @param ptype: the packet type number.
        @param m: the message to hand to the handler.
        @raise SSHException: if C{ptype} is not one of the four gex packets.
        """
        if ptype == MSG_KEXDH_GEX_REQUEST:
            return self.parse_kexdh_gex_request(m)
        elif ptype == MSG_KEXDH_GEX_GROUP:
            return self.parse_kexdh_gex_group(m)
        elif ptype == MSG_KEXDH_GEX_INIT:
            return self.parse_kexdh_gex_init(m)
        elif ptype == MSG_KEXDH_GEX_REPLY:
            return self.parse_kexdh_gex_reply(m)
        raise SSHException('KexGex asked to handle packet type %d' % ptype)

    def generate_x(self):
        """
        Pick a random private exponent C{x} with 1 < x < (p-1)/2 and store it
        in C{self.x}.  Requires C{self.p} to be set already.
        """
        q = (self.p - 1) // 2
        qnorm = deflate_long(q, 0)
        qhbyte = ord(qnorm[0])
        byte_count = len(qnorm)
        # build a mask that clears, in the first random byte, one bit for
        # every leading zero bit of q's first byte, so that candidates are
        # rarely larger than q.
        # NOTE(review): this loop never ends if qnorm[0] is a zero byte;
        # presumably deflate_long() strips leading zero bytes -- confirm.
        qmask = 0xff
        while not (qhbyte & 0x80):
            qhbyte <<= 1
            qmask >>= 1
        # rejection sampling: draw candidates until one lands in (1, q)
        while True:
            self.transport.randpool.stir()
            x_bytes = self.transport.randpool.get_bytes(byte_count)
            x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:]
            x = inflate_long(x_bytes, 1)
            if (x > 1) and (x < q):
                break
        self.x = x

    def parse_kexdh_gex_request(self, m):
        """
        Server side: read the client's (min, preferred, max) bit sizes, pick
        a prime C{p} and generator C{g}, and send them back as a GEX_GROUP.

        @param m: the GEX_REQUEST message.
        """
        requested_min = m.get_int()
        preferred = m.get_int()
        requested_max = m.get_int()
        # smoosh the user's preferred size into our own limits
        preferred = max(self.min_bits, min(preferred, self.max_bits))
        # now save a copy; these values are fed into the exchange hash later.
        # NOTE(review): the *clamped* preferred size is what gets stored (and
        # therefore hashed).  If the peer hashes the value it actually sent,
        # the two sides would disagree whenever clamping kicks in -- confirm
        # against RFC 4419.
        self.min_bits = requested_min
        self.preferred_bits = preferred
        self.max_bits = requested_max
        # generate prime: keep trying until (p-1)/2 is prime as well
        while True:
            # does not work FIXME
            # the problem is that it's too fscking SLOW
            self.transport.log(DEBUG, 'stir...')
            self.transport.randpool.stir()
            self.transport.log(DEBUG, 'get-prime %d...' % preferred)
            self.p = generate_prime(preferred, self.transport.randpool)
            self.transport.log(DEBUG, 'got ' + repr(self.p))
            if number.isPrime((self.p - 1) // 2):
                break
        self.g = 2
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_GROUP))
        m.add_mpint(self.p)
        m.add_mpint(self.g)
        self.transport.send_message(m)
        self.transport.expected_packet = MSG_KEXDH_GEX_INIT

    def parse_kexdh_gex_group(self, m):
        """
        Client side: receive the server's C{p} and C{g}, compute our public
        value C{e = g^x mod p}, and send it as a GEX_INIT.

        @param m: the GEX_GROUP message.
        @raise SSHException: if C{p} is shorter than 1024 or longer than 8192
            bits.
        """
        self.p = m.get_mpint()
        self.g = m.get_mpint()
        # reject if p's bit length < 1024 or > 8192
        bitlen = bit_length(self.p)
        if (bitlen < 1024) or (bitlen > 8192):
            raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen)
        self.transport.log(DEBUG, 'Got server p (%d bits)' % bitlen)
        self.generate_x()
        # now compute e = g^x mod p
        self.e = pow(self.g, self.x, self.p)
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_INIT))
        m.add_mpint(self.e)
        self.transport.send_message(m)
        self.transport.expected_packet = MSG_KEXDH_GEX_REPLY

    def parse_kexdh_gex_init(self, m):
        """
        Server side: receive the client's C{e}, compute our public value
        C{f} and the shared secret C{K}, hash and sign the exchange, and send
        the GEX_REPLY.

        @param m: the GEX_INIT message.
        @raise SSHException: if C{e} is not in the range [1, p-1].
        """
        self.e = m.get_mpint()
        if (self.e < 1) or (self.e > self.p - 1):
            raise SSHException('Client kex "e" is out of range')
        self.generate_x()
        # our public value f = g^x mod p; it is hashed and sent below
        self.f = pow(self.g, self.x, self.p)
        # shared secret K = e^x mod p
        K = pow(self.e, self.x, self.p)
        key = str(self.transport.get_server_key())
        # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
        # (we are the server here, so the client's values are the "remote" ones)
        hm = Message().add(self.transport.remote_version).add(self.transport.local_version)
        hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key)
        hm.add_int(self.min_bits)
        hm.add_int(self.preferred_bits)
        hm.add_int(self.max_bits)
        hm.add_mpint(self.p)
        hm.add_mpint(self.g)
        hm.add(self.e).add(self.f).add(K)
        H = SHA.new(str(hm)).digest()
        self.transport.set_K_H(K, H)
        # sign it
        sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H)
        # send reply
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_REPLY))
        m.add_string(key)
        m.add_mpint(self.f)
        m.add_string(sig)
        self.transport.send_message(m)
        self.transport.activate_outbound()
        self.transport.expected_packet = MSG_NEWKEYS

    def parse_kexdh_gex_reply(self, m):
        """
        Client side: receive the server's host key, C{f} and signature,
        compute the shared secret C{K} and the exchange hash, and hand the
        host key and signature to the transport for verification.

        @param m: the GEX_REPLY message.
        @raise SSHException: if C{f} is not in the range [1, p-1].
        """
        host_key = m.get_string()
        self.f = m.get_mpint()
        sig = m.get_string()
        if (self.f < 1) or (self.f > self.p - 1):
            raise SSHException('Server kex "f" is out of range')
        # shared secret K = f^x mod p
        K = pow(self.f, self.x, self.p)
        # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
        # (we are the client here, so the client's values are the "local" ones)
        hm = Message().add(self.transport.local_version).add(self.transport.remote_version)
        hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key)
        hm.add_int(self.min_bits)
        hm.add_int(self.preferred_bits)
        hm.add_int(self.max_bits)
        hm.add_mpint(self.p)
        hm.add_mpint(self.g)
        hm.add(self.e).add(self.f).add(K)
        self.transport.set_K_H(K, SHA.new(str(hm)).digest())
        self.transport.verify_key(host_key, sig)
        self.transport.activate_outbound()
        self.transport.expected_packet = MSG_NEWKEYS
|