summaryrefslogtreecommitdiffhomepage
path: root/dsskey.py
blob: a5e5c9a10d84c524271ea530ce840acc24c9aeb1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/python

import base64
from paramiko import SSHException
from message import Message
from transport import MSG_USERAUTH_REQUEST
from util import inflate_long, deflate_long
from Crypto.PublicKey import DSA
from Crypto.Hash import SHA, MD5
from ber import BER

from util import format_binary


class DSSKey(object):
    """
    A DSS (DSA) key in the form used by the SSH2 'ssh-dss' key type.

    A public key is loaded from a Message passed to the constructor; a
    private key is loaded with read_private_key_file().  The 'valid'
    attribute is 1 once key material has been loaded and 0 otherwise.
    """

    # An 'ssh-dss' signature blob is r followed by s, each a fixed-width
    # 160-bit (20-byte) unsigned big-endian integer (RFC 4253, section 6.6).
    _SIG_PART_LEN = 20

    def __init__(self, msg=None):
        """
        Create a key, optionally loading the public part from msg.

        msg, if given, must yield the string 'ssh-dss' followed by the
        mpints p, q, g and y.  If msg is None or names another key type,
        the key is left invalid (self.valid == 0).
        """
        self.valid = 0
        if (msg is None) or (msg.get_string() != 'ssh-dss'):
            return
        self.p = msg.get_mpint()
        self.q = msg.get_mpint()
        self.g = msg.get_mpint()
        self.y = msg.get_mpint()
        self.size = len(deflate_long(self.p, 0))
        self.valid = 1

    def __str__(self):
        """
        Return the public key serialized as 'ssh-dss', p, q, g, y, or an
        empty string if no key has been loaded.
        """
        if not self.valid:
            return ''
        m = Message()
        m.add_string('ssh-dss')
        m.add_mpint(self.p)
        m.add_mpint(self.q)
        m.add_mpint(self.g)
        m.add_mpint(self.y)
        return str(m)

    def get_name(self):
        """Return the SSH name of this key type ('ssh-dss')."""
        return 'ssh-dss'

    def get_fingerprint(self):
        """Return the raw MD5 digest of the serialized public key."""
        return MD5.new(str(self)).digest()

    def _pad_sig_part(part, size=_SIG_PART_LEN):
        """
        Left-pad the byte string 'part' with zero bytes up to 'size' bytes.
        A string that is already 'size' bytes or longer is returned as-is.
        """
        return '\x00' * (size - len(part)) + part
    # assignment form rather than a decorator: nothing else in this file
    # needs decorator syntax, so keep working on interpreters without it
    _pad_sig_part = staticmethod(_pad_sig_part)

    def verify_ssh_sig(self, data, msg):
        """
        Check a signature over data.

        msg is the Message holding the signature: normally the string
        'ssh-dss' followed by a string containing the signature blob.
        Returns a false value if the key is invalid, the signature names
        another key type, or the blob is not exactly 40 bytes; otherwise
        returns the result of the DSA verification.
        """
        if not self.valid:
            return 0
        raw = str(msg)
        if len(raw) == 2 * self._SIG_PART_LEN:
            # spies.com bug: signature has no header
            sig = raw
        else:
            kind = msg.get_string()
            if kind != 'ssh-dss':
                return 0
            sig = msg.get_string()
        if len(sig) != 2 * self._SIG_PART_LEN:
            # a blob of any other size cannot be split into (r, s); reject
            # it instead of cutting it at the wrong place
            return 0

        # pull out (r, s) which are NOT encoded as mpints
        sigR = inflate_long(sig[:self._SIG_PART_LEN], 1)
        sigS = inflate_long(sig[self._SIG_PART_LEN:], 1)
        sigM = inflate_long(SHA.new(data).digest(), 1)

        dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q)))
        return dss.verify(sigM, (sigR, sigS))

    def sign_ssh_data(self, randpool, data):
        """
        Sign data with the private key and return the serialized signature
        ('ssh-dss' followed by the 40-byte r||s blob).

        Needs self.x, which is only set by read_private_key_file().
        randpool must provide get_bytes(n); it is used to pick the
        per-signature value k.
        """
        digest = SHA.new(data).digest()
        dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q), long(self.x)))
        # generate a suitable k
        qsize = len(deflate_long(self.q, 0))
        while 1:
            k = inflate_long(randpool.get_bytes(qsize), 1)
            if (k > 2) and (k < self.q):
                break
        r, s = dss.sign(inflate_long(digest, 1), k)
        # r and s must each occupy exactly 20 bytes on the wire.
        # NOTE(review): deflate_long presumably drops leading zero bytes
        # (confirm in util), which would make the blob shorter than 40
        # bytes and cause the peer to mis-split it -- so pad each half.
        r_bytes = self._pad_sig_part(deflate_long(r, 0))
        s_bytes = self._pad_sig_part(deflate_long(s, 0))
        m = Message()
        m.add_string('ssh-dss')
        m.add_string(r_bytes + s_bytes)
        return str(m)

    def read_private_key_file(self, filename):
        """
        Load a private key from filename.

        The first line must be '-----BEGIN DSA PRIVATE KEY-----'; the lines
        between the first and the last are base64-decoded and BER-decoded
        into the list { version = 0, p, q, g, y, x }.

        Throws a file exception if the file cannot be read, or SSHException
        if the file is empty, has the wrong header, fails base64 decoding,
        or does not contain the expected BER structure.  The key is marked
        invalid until loading has succeeded.
        """
        # local import: only this method needs binascii
        import binascii

        # private key file contains:
        # DSAPrivateKey = { version = 0, p, q, g, y, x }
        self.valid = 0
        f = open(filename, 'r')
        try:
            lines = f.readlines()
        finally:
            # close the file even if reading fails
            f.close()
        if (len(lines) == 0) or (lines[0].strip() != '-----BEGIN DSA PRIVATE KEY-----'):
            raise SSHException('not a valid DSA private key file')
        try:
            data = base64.decodestring(''.join(lines[1:-1]))
        except binascii.Error:
            raise SSHException('not a valid DSA private key file (base64 decoding error)')
        keylist = BER(data).decode()
        if (not isinstance(keylist, list)) or (len(keylist) < 6) or (keylist[0] != 0):
            raise SSHException('not a valid DSA private key file (bad ber encoding)')
        self.p = keylist[1]
        self.q = keylist[2]
        self.g = keylist[3]
        self.y = keylist[4]
        self.x = keylist[5]
        self.size = len(deflate_long(self.p, 0))
        self.valid = 1

    def sign_ssh_session(self, randpool, sid, username):
        """
        Build the public-key userauth request for username in the session
        identified by sid, and return its signature from sign_ssh_data().
        """
        m = Message()
        m.add_string(sid)
        m.add_byte(chr(MSG_USERAUTH_REQUEST))
        m.add_string(username)
        m.add_string('ssh-connection')
        m.add_string('publickey')
        m.add_boolean(1)
        m.add_string('ssh-dss')
        m.add_string(str(self))
        return self.sign_ssh_data(randpool, str(m))