#!/usr/bin/env python

# Copyright (C) 2010 Sofian Brabez <sbz@6dev.net>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.

import sys

from binascii import hexlify
from optparse import OptionParser

from paramiko import DSSKey
from paramiko import RSAKey
from paramiko.ssh_exception import SSHException
from paramiko.py3compat import u

# Usage line handed to OptionParser; optparse expands %prog to the program name.
usage = "\n%prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]"

# Fallback values for the command-line options that share these names.
default_values = dict(
    ktype="dsa",
    bits=1024,
    filename="output",
    comment="",
)

# Maps the key type name accepted by -t/--type to the paramiko key class
# used to generate and reload that kind of key.
key_dispatch_table = dict(dsa=DSSKey, rsa=RSAKey)

def progress(arg=None):
    """Write a coarse percentage marker to stdout and flush it.

    A falsy ``arg`` emits the ``0%`` marker.  Otherwise the first element of
    ``arg`` selects the marker: ``'p'`` gives 25%, ``'h'`` gives 50% and
    ``'x'`` gives 75%; any other value writes nothing.  Each marker is
    followed by backspace characters and a single space.
    """
    stage_markers = (
        ('p', '25%\x08\x08\x08\x08 '),
        ('h', '50%\x08\x08\x08\x08 '),
        ('x', '75%\x08\x08\x08\x08 '),
    )

    if not arg:
        marker = '0%\x08\x08\x08 '
    else:
        stage = arg[0]
        for code, text in stage_markers:
            if stage == code:
                marker = text
                break
        else:
            # Unrecognised stage: nothing is written or flushed.
            return

    sys.stdout.write(marker)
    sys.stdout.flush()

def main():
    """Generate an SSH key pair from the command-line options.

    The private key is written to ``<filename>`` (protected with the
    passphrase given via ``-N`` when one is supplied) and the public key to
    ``<filename>.pub``.  Unless ``-q`` is given, the key fingerprint is then
    printed to stdout.  Running the script without any argument prints the
    help text and exits with status 0.

    Raises:
        SSHException: if the requested key type is not in
            ``key_dispatch_table`` or a DSA key larger than 1024 bits is
            requested.
    """
    parser = OptionParser(usage=usage)
    parser.add_option("-t", "--type", type="string", dest="ktype",
        help="Specify type of key to create (dsa or rsa)",
        metavar="ktype", default=default_values["ktype"])
    parser.add_option("-b", "--bits", type="int", dest="bits",
        help="Number of bits in the key to create", metavar="bits",
        default=default_values["bits"])
    parser.add_option("-N", "--new-passphrase", dest="newphrase",
        help="Provide new passphrase", metavar="phrase")
    # -P is parsed but its value is not used anywhere below.
    parser.add_option("-P", "--old-passphrase", dest="oldphrase",
        help="Provide old passphrase", metavar="phrase")
    parser.add_option("-f", "--filename", type="string", dest="filename",
        help="Filename of the key file", metavar="filename",
        default=default_values["filename"])
    # store_true (not store_false): with a False default, store_false could
    # never change the value, which made -q a silent no-op.
    parser.add_option("-q", "--quiet", default=False, action="store_true",
        help="Quiet")
    parser.add_option("-v", "--verbose", default=False, action="store_true",
        help="Verbose")
    parser.add_option("-C", "--comment", type="string", dest="comment",
        help="Provide a new comment", metavar="comment",
        default=default_values["comment"])

    options, _ = parser.parse_args()

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # Every option below has a parser default, so plain attribute access is
    # enough; no need to inject names into globals().
    ktype = options.ktype
    bits = options.bits
    filename = options.filename
    comment = options.comment
    # An empty passphrase is treated as "no passphrase".
    phrase = options.newphrase or None
    # -q wins over -v when both are given.
    verbose = options.verbose and not options.quiet

    # Validate before printing anything, so a rejected request does not
    # first announce that generation has started.
    if ktype not in key_dispatch_table:
        raise SSHException("Unknown %s algorithm to generate keys pair" % ktype)
    # Only the upper bound is enforced for DSA keys.
    if ktype == 'dsa' and bits > 1024:
        raise SSHException("DSA Keys must be 1024 bits")

    key_class = key_dispatch_table[ktype]

    pfunc = None
    if verbose:
        pfunc = progress
        sys.stdout.write(
            "Generating priv/pub %s %d bits key pair (%s/%s.pub)..."
            % (ktype, bits, filename, filename))
        sys.stdout.flush()

    # generating private key
    prv = key_class.generate(bits=bits, progress_func=pfunc)
    prv.write_private_key_file(filename, password=phrase)

    # generating public key: reload the private key file that was just
    # written and export its public half
    pub = key_class(filename=filename, password=phrase)
    pub_line = "%s %s" % (pub.get_name(), pub.get_base64())
    if comment:
        pub_line += " %s" % comment
    with open("%s.pub" % filename, 'w') as f:
        # Terminate the entry with a newline so the file can be appended to
        # a one-key-per-line file without merging with the next entry.
        f.write(pub_line + "\n")

    if verbose:
        print("done.")

    if not options.quiet:
        digest = u(hexlify(pub.get_fingerprint()))
        # Split the hex digest into colon-separated byte pairs.
        pairs = ":".join(digest[i:i + 2] for i in range(0, len(digest), 2))
        print("Fingerprint: %d %s %s.pub (%s)"
              % (bits, pairs, filename, ktype.upper()))


if __name__ == '__main__':
    main()