summaryrefslogtreecommitdiffhomepage
path: root/paramiko/kex_curve25519.py
blob: b305165d413b51431e40f095bbb7636a0823e808 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import binascii
import hashlib

from cryptography.hazmat.primitives import constant_time, serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import (
    X25519PrivateKey, X25519PublicKey
)

from paramiko.message import Message
from paramiko.py3compat import byte_chr, long
from paramiko.ssh_exception import SSHException


# Packet type numbers for the ECDH init/reply messages (RFC 5656: 30 and 31).
# The plain ints are what _expect_packet()/parse_next() dispatch on; the
# byte_chr() forms are what Message.add_byte() writes onto the wire.
_MSG_KEXECDH_INIT = 30
_MSG_KEXECDH_REPLY = 31
c_MSG_KEXECDH_INIT = byte_chr(_MSG_KEXECDH_INIT)
c_MSG_KEXECDH_REPLY = byte_chr(_MSG_KEXECDH_REPLY)

class KexCurve25519(object):
    """
    Curve25519 (X25519) elliptic-curve Diffie-Hellman key exchange.

    `start_kex` generates a fresh ephemeral X25519 key pair.

    - In client mode, our public value is sent in a KEXECDH_INIT packet and
      the server's KEXECDH_REPLY (host key, server public value, signature)
      is then awaited.
    - In server mode, the client's INIT is awaited and answered with a REPLY.

    Both sides hand the shared secret ``K`` and the exchange hash ``H`` to
    the transport via ``_set_K_H``. The role played is chosen by
    ``transport.server_mode``.
    """

    # Hash constructor (hashlib-style callable) used to compute the exchange
    # hash H. Defaults to SHA-256; a subclass may override it to reuse this
    # exchange with a different digest.
    hash_algo = hashlib.sha256

    def __init__(self, transport):
        self.transport = transport
        # Ephemeral X25519 private key; created by start_kex().
        self.key = None

    def _perform_exchange(self, peer_key):
        """
        Run X25519 between our private key and ``peer_key``.

        :return: the raw shared-secret bytes.
        :raises SSHException:
            if the shared secret is all zeros (the peer's public value has
            the wrong order), which must never be used as a key.
        """
        secret = self.key.exchange(peer_key)
        # Constant-time comparison so the check itself does not leak timing
        # information about the secret.
        if constant_time.bytes_eq(secret, b"\x00" * 32):
            raise SSHException(
                "peer's curve25519 public value has wrong order"
            )
        return secret

    def _own_public_bytes(self):
        """Return our ephemeral public key serialized in raw form."""
        return self.key.public_key().public_bytes(
            serialization.Encoding.Raw, serialization.PublicFormat.Raw
        )

    def _compute_K(self, peer_key_bytes):
        """
        Derive the shared secret from the peer's raw public value.

        :return:
            the secret as an integer, reading its bytes big-endian (the
            form later encoded as an mpint).
        :raises SSHException: propagated from `_perform_exchange`.
        """
        peer_key = X25519PublicKey.from_public_bytes(peer_key_bytes)
        secret = self._perform_exchange(peer_key)
        return long(binascii.hexlify(secret), 16)

    def _exchange_hash(self, v_c, v_s, i_c, i_s, host_key, q_c, q_s, K):
        """
        Hash the key-exchange transcript into the exchange hash ``H``.

        The fields are serialized in this fixed order:

        1. client and server version strings;
        2. client and server KEXINIT payloads;
        3. the server host key;
        4. the client and server ephemeral public values (as strings);
        5. the shared secret ``K`` (as an mpint).

        Client and server must feed the fields in the same order, so callers
        map their local/remote values onto the client/server slots.
        """
        hm = Message()
        hm.add(v_c, v_s, i_c, i_s)
        hm.add_string(host_key)
        hm.add_string(q_c)
        hm.add_string(q_s)
        hm.add_mpint(K)
        return self.hash_algo(hm.asbytes()).digest()

    def start_kex(self):
        """
        Begin the exchange.

        Generate our ephemeral key. In server mode, wait for the client's
        INIT. In client mode, send our public value in an INIT and wait for
        the server's REPLY.
        """
        self.key = X25519PrivateKey.generate()
        if self.transport.server_mode:
            self.transport._expect_packet(_MSG_KEXECDH_INIT)
            return

        m = Message()
        m.add_byte(c_MSG_KEXECDH_INIT)
        m.add_string(self._own_public_bytes())
        self.transport._send_message(m)
        self.transport._expect_packet(_MSG_KEXECDH_REPLY)

    def parse_next(self, ptype, m):
        """
        Dispatch an incoming key-exchange packet.

        A server accepts only KEXECDH_INIT and a client accepts only
        KEXECDH_REPLY.

        :raises SSHException: for any other packet type / mode combination.
        """
        if self.transport.server_mode and (ptype == _MSG_KEXECDH_INIT):
            return self._parse_kexecdh_init(m)
        elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY):
            return self._parse_kexecdh_reply(m)
        raise SSHException(
            "KexCurve25519 asked to handle packet type {:d}".format(ptype)
        )

    def _parse_kexecdh_init(self, m):
        """
        Server side: consume the client's INIT and answer with a REPLY.

        The REPLY carries our host key, our ephemeral public value and the
        host-key signature over ``H``.
        """
        peer_key_bytes = m.get_string()
        K = self._compute_K(peer_key_bytes)
        server_key_bytes = self.transport.get_server_key().asbytes()
        exchange_key_bytes = self._own_public_bytes()
        # We are the server: "remote" values are the client's.
        H = self._exchange_hash(
            self.transport.remote_version,
            self.transport.local_version,
            self.transport.remote_kex_init,
            self.transport.local_kex_init,
            server_key_bytes,
            peer_key_bytes,
            exchange_key_bytes,
            K,
        )
        self.transport._set_K_H(K, H)
        sig = self.transport.get_server_key().sign_ssh_data(H)
        # construct reply
        m = Message()
        m.add_byte(c_MSG_KEXECDH_REPLY)
        m.add_string(server_key_bytes)
        m.add_string(exchange_key_bytes)
        m.add_string(sig)
        self.transport._send_message(m)
        self.transport._activate_outbound()

    def _parse_kexecdh_reply(self, m):
        """
        Client side: consume the server's REPLY.

        Derive ``K`` and ``H``, then have the transport check the server's
        host key and signature.
        """
        peer_host_key_bytes = m.get_string()
        peer_key_bytes = m.get_string()
        sig = m.get_binary()
        K = self._compute_K(peer_key_bytes)
        # We are the client: "local" values are the client's.
        H = self._exchange_hash(
            self.transport.local_version,
            self.transport.remote_version,
            self.transport.local_kex_init,
            self.transport.remote_kex_init,
            peer_host_key_bytes,
            self._own_public_bytes(),
            peer_key_bytes,
            K,
        )
        self.transport._set_K_H(K, H)
        # NOTE(review): _verify_key is presumably what checks ``sig`` over the
        # H stored above with the server host key — confirm in Transport.
        self.transport._verify_key(peer_host_key_bytes, sig)
        self.transport._activate_outbound()