summaryrefslogtreecommitdiffhomepage
path: root/kex_gex.py
blob: 2b6e11cd1f30361892590c830838de3c8d0075e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/python

# variant on group1 (see kex_group1.py) where the prime "p" and generator "g"
# are provided by the server.  a bit more work is required on our side (and a
# LOT more on the server side).

from message import Message
from util import inflate_long, deflate_long, generate_prime
from secsh import SSHException
from transport import MSG_NEWKEYS
from Crypto.Hash import SHA
from Crypto.Util import number
from logging import DEBUG

# SSH message numbers used by diffie-hellman group exchange (RFC 4419).
MSG_KEXDH_GEX_GROUP = 31
MSG_KEXDH_GEX_INIT = 32
MSG_KEXDH_GEX_REPLY = 33
MSG_KEXDH_GEX_REQUEST = 34


class KexGex(object):
    """
    Diffie-Hellman group-exchange key exchange
    ("diffie-hellman-group-exchange-sha1"), where the server chooses the
    prime modulus ``p`` and generator ``g``.  The same object drives either
    side of the exchange, depending on ``transport.server_mode``.
    """

    name = 'diffie-hellman-group-exchange-sha1'
    # default group-size limits in bits.  On the server side these are
    # overwritten with the client's request (see parse_kexdh_gex_request),
    # because the exchange hash must cover the client's values.
    min_bits = 1024
    max_bits = 8192
    preferred_bits = 2048

    def __init__(self, transport):
        self.transport = transport

    def start_kex(self):
        """
        Begin the exchange.  A server just waits for the client's group
        request; a client sends its (min, preferred, max) group-size request.
        """
        if self.transport.server_mode:
            self.transport.expected_packet = MSG_KEXDH_GEX_REQUEST
            return
        # request a bit range: we accept (min_bits) to (max_bits), but prefer
        # (preferred_bits).  according to the spec, we shouldn't pull the
        # minimum up above 1024.
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_REQUEST))
        m.add_int(self.min_bits)
        m.add_int(self.preferred_bits)
        m.add_int(self.max_bits)
        self.transport.send_message(m)
        self.transport.expected_packet = MSG_KEXDH_GEX_GROUP

    def parse_next(self, ptype, m):
        """
        Dispatch an incoming kex packet of type ``ptype`` to its handler.

        Raises SSHException for a packet type this kex does not handle.
        """
        if ptype == MSG_KEXDH_GEX_REQUEST:
            return self.parse_kexdh_gex_request(m)
        elif ptype == MSG_KEXDH_GEX_GROUP:
            return self.parse_kexdh_gex_group(m)
        elif ptype == MSG_KEXDH_GEX_INIT:
            return self.parse_kexdh_gex_init(m)
        elif ptype == MSG_KEXDH_GEX_REPLY:
            return self.parse_kexdh_gex_reply(m)
        raise SSHException('KexGex asked to handle packet type %d' % ptype)

    def bit_length(n):
        """
        Return the number of significant bits in ``n``; 0 when ``n <= 0``.

        Pure integer arithmetic: the previous byte-scanning version looped
        forever for n == 0, which a hostile peer could trigger by sending
        p = 0.
        """
        if n <= 0:
            return 0
        bitlen = 0
        # strip whole bytes first so large moduli need few iterations
        while n >= 256:
            n >>= 8
            bitlen += 8
        while n:
            n >>= 1
            bitlen += 1
        return bitlen
    bit_length = staticmethod(bit_length)

    def generate_x(self):
        """
        Pick a random private exponent and store it in ``self.x``,
        with 1 < x < (p-1)/2.
        """
        q = (self.p - 1) // 2
        qnorm = deflate_long(q, 0)
        qhbyte = ord(qnorm[0])
        nbytes = len(qnorm)
        # mask off the leading bits of the top random byte that q itself
        # doesn't use, so most candidates already fall below q
        qmask = 0xff
        while not (qhbyte & 0x80):
            qhbyte <<= 1
            qmask >>= 1
        while 1:
            self.transport.randpool.stir()
            x_bytes = self.transport.randpool.get_bytes(nbytes)
            x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:]
            x = inflate_long(x_bytes, 1)
            if (x > 1) and (x < q):
                break
        self.x = x

    def parse_kexdh_gex_request(self, m):
        """
        Server side: read the client's (min, preferred, max) request,
        generate a group and send MSG_KEXDH_GEX_GROUP.
        """
        minbits = m.get_int()
        preferredbits = m.get_int()
        maxbits = m.get_int()
        # smoosh the client's preferred size into our own limits.  this only
        # decides how big a prime we generate.
        bits = preferredbits
        if bits > self.max_bits:
            bits = self.max_bits
        if bits < self.min_bits:
            bits = self.min_bits
        # save the request exactly as the client sent it: the client hashes
        # its own values, so hashing the clamped size would break H
        self.min_bits = minbits
        self.preferred_bits = preferredbits
        self.max_bits = maxbits
        # generate a safe prime (p such that (p-1)/2 is also prime)
        while 1:
            # FIXME: works in principle, but is far too slow for real use
            self.transport.log(DEBUG, 'stir...')
            self.transport.randpool.stir()
            self.transport.log(DEBUG, 'get-prime %d...' % bits)
            self.p = generate_prime(bits, self.transport.randpool)
            self.transport.log(DEBUG, 'got ' + repr(self.p))
            if number.isPrime((self.p - 1) // 2):
                break
        self.g = 2
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_GROUP))
        m.add_mpint(self.p)
        m.add_mpint(self.g)
        self.transport.send_message(m)
        self.transport.expected_packet = MSG_KEXDH_GEX_INIT

    def parse_kexdh_gex_group(self, m):
        """
        Client side: validate the server's (p, g), pick x, and send
        e = g^x mod p in MSG_KEXDH_GEX_INIT.

        Raises SSHException if p or g is unacceptable.
        """
        self.p = m.get_mpint()
        self.g = m.get_mpint()
        # reject if p's bit length is outside the range we asked for
        bitlen = self.bit_length(self.p)
        if (bitlen < self.min_bits) or (bitlen > self.max_bits):
            raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen)
        # a degenerate generator (0, 1 or p-1) makes the shared secret
        # trivially predictable
        if (self.g < 2) or (self.g > self.p - 2):
            raise SSHException('Server-generated gex g is out of range')
        self.transport.log(DEBUG, 'Got server p (%d bits)' % bitlen)
        self.generate_x()
        # now compute e = g^x mod p
        self.e = pow(self.g, self.x, self.p)
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_INIT))
        m.add_mpint(self.e)
        self.transport.send_message(m)
        self.transport.expected_packet = MSG_KEXDH_GEX_REPLY

    def parse_kexdh_gex_init(self, m):
        """
        Server side: receive the client's e, compute f and the shared
        secret K, sign the exchange hash H and send MSG_KEXDH_GEX_REPLY.

        Raises SSHException if e is out of range.
        """
        self.e = m.get_mpint()
        if (self.e < 1) or (self.e > self.p - 1):
            raise SSHException('Client kex "e" is out of range')
        self.generate_x()
        # our public value f = g^x mod p (previously never computed)
        self.f = pow(self.g, self.x, self.p)
        # shared secret K = e^x mod p
        K = pow(self.e, self.x, self.p)
        server_key = self.transport.get_server_key()
        key = str(server_key)
        # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
        # (we are the server, so the client's strings are the "remote" ones)
        hm = Message().add(self.transport.remote_version).add(self.transport.local_version)
        hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key)
        hm.add_int(self.min_bits)
        hm.add_int(self.preferred_bits)
        hm.add_int(self.max_bits)
        hm.add_mpint(self.p)
        hm.add_mpint(self.g)
        # e, f and K are always mpints in the exchange hash
        hm.add_mpint(self.e)
        hm.add_mpint(self.f)
        hm.add_mpint(K)
        H = SHA.new(str(hm)).digest()
        self.transport.set_K_H(K, H)
        # sign it
        sig = server_key.sign_ssh_data(H)
        # send reply
        m = Message()
        m.add_byte(chr(MSG_KEXDH_GEX_REPLY))
        m.add_string(key)
        m.add_mpint(self.f)
        m.add_string(sig)
        self.transport.send_message(m)
        self.transport.activate_outbound()
        self.transport.expected_packet = MSG_NEWKEYS

    def parse_kexdh_gex_reply(self, m):
        """
        Client side: read the server's host key, f and signature, compute
        K and H, and ask the transport to verify the host key signature.

        Raises SSHException if f is out of range.
        """
        host_key = m.get_string()
        self.f = m.get_mpint()
        sig = m.get_string()
        if (self.f < 1) or (self.f > self.p - 1):
            raise SSHException('Server kex "f" is out of range')
        K = pow(self.f, self.x, self.p)
        # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
        hm = Message().add(self.transport.local_version).add(self.transport.remote_version)
        hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key)
        hm.add_int(self.min_bits)
        hm.add_int(self.preferred_bits)
        hm.add_int(self.max_bits)
        hm.add_mpint(self.p)
        hm.add_mpint(self.g)
        # e, f and K are always mpints in the exchange hash
        hm.add_mpint(self.e)
        hm.add_mpint(self.f)
        hm.add_mpint(K)
        self.transport.set_K_H(K, SHA.new(str(hm)).digest())
        self.transport.verify_key(host_key, sig)
        self.transport.activate_outbound()
        self.transport.expected_packet = MSG_NEWKEYS