diff options
80 files changed, 4757 insertions, 3460 deletions
diff --git a/.travis.yml b/.travis.yml index 4316bd9f..16d33b76 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,10 @@ install: - pip install codecov # For codecov specifically - pip install -r dev-requirements.txt script: - # flake8 is now possible! + # Fast syntax check failures for more rapid feedback to submitters + # (Travis-oriented metatask that version checks Python, installs, runs.) + - inv travis.blacken + # I have this in my git pre-push hook, but contributors probably don't - flake8 # All (including slow) tests, w/ coverage! - inv coverage diff --git a/demos/demo.py b/demos/demo.py index fff61784..c9b0a5f5 100755 --- a/demos/demo.py +++ b/demos/demo.py @@ -31,6 +31,7 @@ import traceback from paramiko.py3compat import input import paramiko + try: import interactive except ImportError: @@ -42,71 +43,73 @@ def agent_auth(transport, username): Attempt to authenticate to the given transport using any of the private keys available from an SSH agent. """ - + agent = paramiko.Agent() agent_keys = agent.get_keys() if len(agent_keys) == 0: return - + for key in agent_keys: - print('Trying ssh-agent key %s' % hexlify(key.get_fingerprint())) + print("Trying ssh-agent key %s" % hexlify(key.get_fingerprint())) try: transport.auth_publickey(username, key) - print('... success!') + print("... success!") return except paramiko.SSHException: - print('... nope.') + print("... nope.") def manual_auth(username, hostname): - default_auth = 'p' - auth = input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth) + default_auth = "p" + auth = input( + "Auth by (p)assword, (r)sa key, or (d)ss key? [%s] " % default_auth + ) if len(auth) == 0: auth = default_auth - if auth == 'r': - default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa') - path = input('RSA key [%s]: ' % default_path) + if auth == "r": + default_path = os.path.join(os.environ["HOME"], ".ssh", "id_rsa") + path = input("RSA key [%s]: " % default_path) if len(path) == 0: path = default_path try: key = paramiko.RSAKey.from_private_key_file(path) except paramiko.PasswordRequiredException: - password = getpass.getpass('RSA key password: ') + password = getpass.getpass("RSA key password: ") key = paramiko.RSAKey.from_private_key_file(path, password) t.auth_publickey(username, key) - elif auth == 'd': - default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa') - path = input('DSS key [%s]: ' % default_path) + elif auth == "d": + default_path = os.path.join(os.environ["HOME"], ".ssh", "id_dsa") + path = input("DSS key [%s]: " % default_path) if len(path) == 0: path = default_path try: key = paramiko.DSSKey.from_private_key_file(path) except paramiko.PasswordRequiredException: - password = getpass.getpass('DSS key password: ') + password = getpass.getpass("DSS key password: ") key = paramiko.DSSKey.from_private_key_file(path, password) t.auth_publickey(username, key) else: - pw = getpass.getpass('Password for %s@%s: ' % (username, hostname)) + pw = getpass.getpass("Password for %s@%s: " % (username, hostname)) t.auth_password(username, pw) # setup logging -paramiko.util.log_to_file('demo.log') +paramiko.util.log_to_file("demo.log") -username = '' +username = "" if len(sys.argv) > 1: hostname = sys.argv[1] - if hostname.find('@') >= 0: - username, hostname = hostname.split('@') + if hostname.find("@") >= 0: + username, hostname = hostname.split("@") else: - hostname = input('Hostname: ') + hostname = input("Hostname: ") if len(hostname) == 0: - print('*** Hostname required.') + print("*** Hostname required.") sys.exit(1) port = 22 -if hostname.find(':') >= 0: - hostname, portstr = hostname.split(':') +if hostname.find(":") >= 0: + hostname, portstr = hostname.split(":") port = int(portstr) # now connect @@ -114,7 +117,7 @@ try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect((hostname, port)) except Exception as e: - print('*** Connect failed: ' + str(e)) + print("*** Connect failed: " + str(e)) traceback.print_exc() sys.exit(1) @@ -123,34 +126,38 @@ try: try: t.start_client() except paramiko.SSHException: - print('*** SSH negotiation failed.') + print("*** SSH negotiation failed.") sys.exit(1) try: - keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + keys = paramiko.util.load_host_keys( + os.path.expanduser("~/.ssh/known_hosts") + ) except IOError: try: - keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts')) + keys = paramiko.util.load_host_keys( + os.path.expanduser("~/ssh/known_hosts") + ) except IOError: - print('*** Unable to open host keys file') + print("*** Unable to open host keys file") keys = {} # check server's host key -- this is important. key = t.get_remote_server_key() if hostname not in keys: - print('*** WARNING: Unknown host key!') + print("*** WARNING: Unknown host key!") elif key.get_name() not in keys[hostname]: - print('*** WARNING: Unknown host key!') + print("*** WARNING: Unknown host key!") elif keys[hostname][key.get_name()] != key: - print('*** WARNING: Host key has changed!!!') + print("*** WARNING: Host key has changed!!!") sys.exit(1) else: - print('*** Host key OK.') + print("*** Host key OK.") # get username - if username == '': + if username == "": default_username = getpass.getuser() - username = input('Username [%s]: ' % default_username) + username = input("Username [%s]: " % default_username) if len(username) == 0: username = default_username @@ -158,25 +165,23 @@ try: if not t.is_authenticated(): manual_auth(username, hostname) if not t.is_authenticated(): - print('*** Authentication failed. :(') + print("*** Authentication failed. :(") t.close() sys.exit(1) chan = t.open_session() chan.get_pty() chan.invoke_shell() - print('*** Here we go!\n') + print("*** Here we go!\n") interactive.interactive_shell(chan) chan.close() t.close() except Exception as e: - print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e)) + print("*** Caught exception: " + str(e.__class__) + ": " + str(e)) traceback.print_exc() try: t.close() except: pass sys.exit(1) - - diff --git a/demos/demo_keygen.py b/demos/demo_keygen.py index 860ee4e9..6a80272d 100755 --- a/demos/demo_keygen.py +++ b/demos/demo_keygen.py @@ -28,62 +28,97 @@ from paramiko import RSAKey from paramiko.ssh_exception import SSHException from paramiko.py3compat import u -usage=""" +usage = """ %prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]""" default_values = { "ktype": "dsa", "bits": 1024, "filename": "output", - "comment": "" + "comment": "", } -key_dispatch_table = { - 'dsa': DSSKey, - 'rsa': RSAKey, -} +key_dispatch_table = {"dsa": DSSKey, "rsa": RSAKey} + def progress(arg=None): if not arg: - sys.stdout.write('0%\x08\x08\x08 ') + sys.stdout.write("0%\x08\x08\x08 ") sys.stdout.flush() - elif arg[0] == 'p': - sys.stdout.write('25%\x08\x08\x08\x08 ') + elif arg[0] == "p": + sys.stdout.write("25%\x08\x08\x08\x08 ") sys.stdout.flush() - elif arg[0] == 'h': - sys.stdout.write('50%\x08\x08\x08\x08 ') + elif arg[0] == "h": + sys.stdout.write("50%\x08\x08\x08\x08 ") sys.stdout.flush() - elif arg[0] == 'x': - sys.stdout.write('75%\x08\x08\x08\x08 ') + elif arg[0] == "x": + sys.stdout.write("75%\x08\x08\x08\x08 ") sys.stdout.flush() -if __name__ == '__main__': - phrase=None - pfunc=None +if __name__ == "__main__": + + phrase = None + pfunc = None parser = OptionParser(usage=usage) - parser.add_option("-t", "--type", type="string", dest="ktype", + parser.add_option( + "-t", + "--type", + type="string", + dest="ktype", help="Specify type of key to create (dsa or rsa)", - metavar="ktype", default=default_values["ktype"]) - parser.add_option("-b", "--bits", type="int", dest="bits", - help="Number of bits in the key to create", metavar="bits", - default=default_values["bits"]) - parser.add_option("-N", "--new-passphrase", dest="newphrase", - help="Provide new passphrase", metavar="phrase") - parser.add_option("-P", "--old-passphrase", dest="oldphrase", - help="Provide old passphrase", metavar="phrase") - parser.add_option("-f", "--filename", type="string", dest="filename", - help="Filename of the key file", metavar="filename", - default=default_values["filename"]) - parser.add_option("-q", "--quiet", default=False, action="store_false", - help="Quiet") - parser.add_option("-v", "--verbose", default=False, action="store_true", - help="Verbose") - parser.add_option("-C", "--comment", type="string", dest="comment", - help="Provide a new comment", metavar="comment", - default=default_values["comment"]) + metavar="ktype", + default=default_values["ktype"], + ) + parser.add_option( + "-b", + "--bits", + type="int", + dest="bits", + help="Number of bits in the key to create", + metavar="bits", + default=default_values["bits"], + ) + parser.add_option( + "-N", + "--new-passphrase", + dest="newphrase", + help="Provide new passphrase", + metavar="phrase", + ) + parser.add_option( + "-P", + "--old-passphrase", + dest="oldphrase", + help="Provide old passphrase", + metavar="phrase", + ) + parser.add_option( + "-f", + "--filename", + type="string", + dest="filename", + help="Filename of the key file", + metavar="filename", + default=default_values["filename"], + ) + parser.add_option( + "-q", "--quiet", default=False, action="store_false", help="Quiet" + ) + parser.add_option( + "-v", "--verbose", default=False, action="store_true", help="Verbose" + ) + parser.add_option( + "-C", + "--comment", + type="string", + dest="comment", + help="Provide a new comment", + metavar="comment", + default=default_values["comment"], + ) (options, args) = parser.parse_args() @@ -95,18 +130,23 @@ if __name__ == '__main__': globals()[o] = getattr(options, o, default_values[o.lower()]) if options.newphrase: - phrase = getattr(options, 'newphrase') + phrase = getattr(options, "newphrase") if options.verbose: pfunc = progress - sys.stdout.write("Generating priv/pub %s %d bits key pair (%s/%s.pub)..." % (ktype, bits, filename, filename)) + sys.stdout.write( + "Generating priv/pub %s %d bits key pair (%s/%s.pub)..." + % (ktype, bits, filename, filename) + ) sys.stdout.flush() - if ktype == 'dsa' and bits > 1024: + if ktype == "dsa" and bits > 1024: raise SSHException("DSA Keys must be 1024 bits") if ktype not in key_dispatch_table: - raise SSHException("Unknown %s algorithm to generate keys pair" % ktype) + raise SSHException( + "Unknown %s algorithm to generate keys pair" % ktype + ) # generating private key prv = key_dispatch_table[ktype].generate(bits=bits, progress_func=pfunc) @@ -114,7 +154,7 @@ if __name__ == '__main__': # generating public key pub = key_dispatch_table[ktype](filename=filename, password=phrase) - with open("%s.pub" % filename, 'w') as f: + with open("%s.pub" % filename, "w") as f: f.write("%s %s" % (pub.get_name(), pub.get_base64())) if options.comment: f.write(" %s" % comment) @@ -123,4 +163,12 @@ if __name__ == '__main__': print("done.") hash = u(hexlify(pub.get_fingerprint())) - print("Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, ktype.upper())) + print( + "Fingerprint: %d %s %s.pub (%s)" + % ( + bits, + ":".join([hash[i : 2 + i] for i in range(0, len(hash), 2)]), + filename, + ktype.upper(), + ) + ) diff --git a/demos/demo_server.py b/demos/demo_server.py index 3a7ec854..313e5fb2 100644 --- a/demos/demo_server.py +++ b/demos/demo_server.py @@ -31,45 +31,47 @@ from paramiko.py3compat import b, u, decodebytes # setup logging -paramiko.util.log_to_file('demo_server.log') +paramiko.util.log_to_file("demo_server.log") -host_key = paramiko.RSAKey(filename='test_rsa.key') -#host_key = paramiko.DSSKey(filename='test_dss.key') +host_key = paramiko.RSAKey(filename="test_rsa.key") +# host_key = paramiko.DSSKey(filename='test_dss.key') -print('Read key: ' + u(hexlify(host_key.get_fingerprint()))) +print("Read key: " + u(hexlify(host_key.get_fingerprint()))) -class Server (paramiko.ServerInterface): +class Server(paramiko.ServerInterface): # 'data' is the output of base64.b64encode(key) # (using the "user_rsa_key" files) - data = (b'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp' - b'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC' - b'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT' - b'UWT10hcuO4Ks8=') + data = ( + b"AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp" + b"fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC" + b"KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT" + b"UWT10hcuO4Ks8=" + ) good_pub_key = paramiko.RSAKey(data=decodebytes(data)) def __init__(self): self.event = threading.Event() def check_channel_request(self, kind, chanid): - if kind == 'session': + if kind == "session": return paramiko.OPEN_SUCCEEDED return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED def check_auth_password(self, username, password): - if (username == 'robey') and (password == 'foo'): + if (username == "robey") and (password == "foo"): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): - print('Auth attempt with key: ' + u(hexlify(key.get_fingerprint()))) - if (username == 'robey') and (key == self.good_pub_key): + print("Auth attempt with key: " + u(hexlify(key.get_fingerprint()))) + if (username == "robey") and (key == self.good_pub_key): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED - - def check_auth_gssapi_with_mic(self, username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None): + + def check_auth_gssapi_with_mic( + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None + ): """ .. note:: We are just checking in `AuthHandler` that the given user is a @@ -88,9 +90,9 @@ class Server (paramiko.ServerInterface): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED - def check_auth_gssapi_keyex(self, username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_keyex( + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None + ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -99,14 +101,15 @@ class Server (paramiko.ServerInterface): return True def get_allowed_auths(self, username): - return 'gssapi-keyex,gssapi-with-mic,password,publickey' + return "gssapi-keyex,gssapi-with-mic,password,publickey" def check_channel_shell_request(self, channel): self.event.set() return True - def check_channel_pty_request(self, channel, term, width, height, pixelwidth, - pixelheight, modes): + def check_channel_pty_request( + self, channel, term, width, height, pixelwidth, pixelheight, modes + ): return True @@ -116,22 +119,22 @@ DoGSSAPIKeyExchange = True try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(('', 2200)) + sock.bind(("", 2200)) except Exception as e: - print('*** Bind failed: ' + str(e)) + print("*** Bind failed: " + str(e)) traceback.print_exc() sys.exit(1) try: sock.listen(100) - print('Listening for connection ...') + print("Listening for connection ...") client, addr = sock.accept() except Exception as e: - print('*** Listen/accept failed: ' + str(e)) + print("*** Listen/accept failed: " + str(e)) traceback.print_exc() sys.exit(1) -print('Got a connection!') +print("Got a connection!") try: t = paramiko.Transport(client, gss_kex=DoGSSAPIKeyExchange) @@ -139,43 +142,44 @@ try: try: t.load_server_moduli() except: - print('(Failed to load moduli -- gex will be unsupported.)') + print("(Failed to load moduli -- gex will be unsupported.)") raise t.add_server_key(host_key) server = Server() try: t.start_server(server=server) except paramiko.SSHException: - print('*** SSH negotiation failed.') + print("*** SSH negotiation failed.") sys.exit(1) # wait for auth chan = t.accept(20) if chan is None: - print('*** No channel.') + print("*** No channel.") sys.exit(1) - print('Authenticated!') + print("Authenticated!") server.event.wait(10) if not server.event.is_set(): - print('*** Client never asked for a shell.') + print("*** Client never asked for a shell.") sys.exit(1) - chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') - chan.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n') - chan.send('Happy birthday to Robot Dave!\r\n\r\n') - chan.send('Username: ') - f = chan.makefile('rU') - username = f.readline().strip('\r\n') - chan.send('\r\nI don\'t like you, ' + username + '.\r\n') + chan.send("\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n") + chan.send( + "We are on fire all the time! Hooray! Candy corn for everyone!\r\n" + ) + chan.send("Happy birthday to Robot Dave!\r\n\r\n") + chan.send("Username: ") + f = chan.makefile("rU") + username = f.readline().strip("\r\n") + chan.send("\r\nI don't like you, " + username + ".\r\n") chan.close() except Exception as e: - print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e)) + print("*** Caught exception: " + str(e.__class__) + ": " + str(e)) traceback.print_exc() try: t.close() except: pass sys.exit(1) - diff --git a/demos/demo_sftp.py b/demos/demo_sftp.py index 2cb44701..7f6a002e 100644 --- a/demos/demo_sftp.py +++ b/demos/demo_sftp.py @@ -32,38 +32,38 @@ from paramiko.py3compat import input # setup logging -paramiko.util.log_to_file('demo_sftp.log') +paramiko.util.log_to_file("demo_sftp.log") # Paramiko client configuration -UseGSSAPI = True # enable GSS-API / SSPI authentication +UseGSSAPI = True # enable GSS-API / SSPI authentication DoGSSAPIKeyExchange = True Port = 22 # get hostname -username = '' +username = "" if len(sys.argv) > 1: hostname = sys.argv[1] - if hostname.find('@') >= 0: - username, hostname = hostname.split('@') + if hostname.find("@") >= 0: + username, hostname = hostname.split("@") else: - hostname = input('Hostname: ') + hostname = input("Hostname: ") if len(hostname) == 0: - print('*** Hostname required.') + print("*** Hostname required.") sys.exit(1) -if hostname.find(':') >= 0: - hostname, portstr = hostname.split(':') +if hostname.find(":") >= 0: + hostname, portstr = hostname.split(":") Port = int(portstr) # get username -if username == '': +if username == "": default_username = getpass.getuser() - username = input('Username [%s]: ' % default_username) + username = input("Username [%s]: " % default_username) if len(username) == 0: username = default_username if not UseGSSAPI: - password = getpass.getpass('Password for %s@%s: ' % (username, hostname)) + password = getpass.getpass("Password for %s@%s: " % (username, hostname)) else: password = None @@ -72,59 +72,69 @@ else: hostkeytype = None hostkey = None try: - host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + host_keys = paramiko.util.load_host_keys( + os.path.expanduser("~/.ssh/known_hosts") + ) except IOError: try: # try ~/ssh/ too, because windows can't have a folder named ~/.ssh/ - host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts')) + host_keys = paramiko.util.load_host_keys( + os.path.expanduser("~/ssh/known_hosts") + ) except IOError: - print('*** Unable to open host keys file') + print("*** Unable to open host keys file") host_keys = {} if hostname in host_keys: hostkeytype = host_keys[hostname].keys()[0] hostkey = host_keys[hostname][hostkeytype] - print('Using host key of type %s' % hostkeytype) + print("Using host key of type %s" % hostkeytype) # now, connect and use paramiko Transport to negotiate SSH2 across the connection try: t = paramiko.Transport((hostname, Port)) - t.connect(hostkey, username, password, gss_host=socket.getfqdn(hostname), - gss_auth=UseGSSAPI, gss_kex=DoGSSAPIKeyExchange) + t.connect( + hostkey, + username, + password, + gss_host=socket.getfqdn(hostname), + gss_auth=UseGSSAPI, + gss_kex=DoGSSAPIKeyExchange, + ) sftp = paramiko.SFTPClient.from_transport(t) # dirlist on remote host - dirlist = sftp.listdir('.') + dirlist = sftp.listdir(".") print("Dirlist: %s" % dirlist) # copy this demo onto the server try: sftp.mkdir("demo_sftp_folder") except IOError: - print('(assuming demo_sftp_folder/ already exists)') - with sftp.open('demo_sftp_folder/README', 'w') as f: - f.write('This was created by demo_sftp.py.\n') - with open('demo_sftp.py', 'r') as f: + print("(assuming demo_sftp_folder/ already exists)") + with sftp.open("demo_sftp_folder/README", "w") as f: + f.write("This was created by demo_sftp.py.\n") + with open("demo_sftp.py", "r") as f: data = f.read() - sftp.open('demo_sftp_folder/demo_sftp.py', 'w').write(data) - print('created demo_sftp_folder/ on the server') - + sftp.open("demo_sftp_folder/demo_sftp.py", "w").write(data) + print("created demo_sftp_folder/ on the server") + # copy the README back here - with sftp.open('demo_sftp_folder/README', 'r') as f: + with sftp.open("demo_sftp_folder/README", "r") as f: data = f.read() - with open('README_demo_sftp', 'w') as f: + with open("README_demo_sftp", "w") as f: f.write(data) - print('copied README back here') - + print("copied README back here") + # BETTER: use the get() and put() methods - sftp.put('demo_sftp.py', 'demo_sftp_folder/demo_sftp.py') - sftp.get('demo_sftp_folder/README', 'README_demo_sftp') + sftp.put("demo_sftp.py", "demo_sftp_folder/demo_sftp.py") + sftp.get("demo_sftp_folder/README", "README_demo_sftp") t.close() except Exception as e: - print('*** Caught exception: %s: %s' % (e.__class__, e)) + print("*** Caught exception: %s: %s" % (e.__class__, e)) traceback.print_exc() try: t.close() diff --git a/demos/demo_simple.py b/demos/demo_simple.py index 9def57f8..5dd4f6c1 100644 --- a/demos/demo_simple.py +++ b/demos/demo_simple.py @@ -28,6 +28,7 @@ import traceback from paramiko.py3compat import input import paramiko + try: import interactive except ImportError: @@ -35,39 +36,43 @@ except ImportError: # setup logging -paramiko.util.log_to_file('demo_simple.log') +paramiko.util.log_to_file("demo_simple.log") # Paramiko client configuration -UseGSSAPI = paramiko.GSS_AUTH_AVAILABLE # enable "gssapi-with-mic" authentication, if supported by your python installation -DoGSSAPIKeyExchange = paramiko.GSS_AUTH_AVAILABLE # enable "gssapi-kex" key exchange, if supported by your python installation +UseGSSAPI = ( + paramiko.GSS_AUTH_AVAILABLE +) # enable "gssapi-with-mic" authentication, if supported by your python installation +DoGSSAPIKeyExchange = ( + paramiko.GSS_AUTH_AVAILABLE +) # enable "gssapi-kex" key exchange, if supported by your python installation # UseGSSAPI = False # DoGSSAPIKeyExchange = False port = 22 # get hostname -username = '' +username = "" if len(sys.argv) > 1: hostname = sys.argv[1] - if hostname.find('@') >= 0: - username, hostname = hostname.split('@') + if hostname.find("@") >= 0: + username, hostname = hostname.split("@") else: - hostname = input('Hostname: ') + hostname = input("Hostname: ") if len(hostname) == 0: - print('*** Hostname required.') + print("*** Hostname required.") sys.exit(1) -if hostname.find(':') >= 0: - hostname, portstr = hostname.split(':') +if hostname.find(":") >= 0: + hostname, portstr = hostname.split(":") port = int(portstr) # get username -if username == '': +if username == "": default_username = getpass.getuser() - username = input('Username [%s]: ' % default_username) + username = input("Username [%s]: " % default_username) if len(username) == 0: username = default_username if not UseGSSAPI and not DoGSSAPIKeyExchange: - password = getpass.getpass('Password for %s@%s: ' % (username, hostname)) + password = getpass.getpass("Password for %s@%s: " % (username, hostname)) # now, connect and use paramiko Client to negotiate SSH2 across the connection @@ -75,27 +80,34 @@ try: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy()) - print('*** Connecting...') + print("*** Connecting...") if not UseGSSAPI and not DoGSSAPIKeyExchange: client.connect(hostname, port, username, password) else: try: - client.connect(hostname, port, username, gss_auth=UseGSSAPI, - gss_kex=DoGSSAPIKeyExchange) + client.connect( + hostname, + port, + username, + gss_auth=UseGSSAPI, + gss_kex=DoGSSAPIKeyExchange, + ) except Exception: # traceback.print_exc() - password = getpass.getpass('Password for %s@%s: ' % (username, hostname)) + password = getpass.getpass( + "Password for %s@%s: " % (username, hostname) + ) client.connect(hostname, port, username, password) chan = client.invoke_shell() print(repr(client.get_transport())) - print('*** Here we go!\n') + print("*** Here we go!\n") interactive.interactive_shell(chan) chan.close() client.close() except Exception as e: - print('*** Caught exception: %s: %s' % (e.__class__, e)) + print("*** Caught exception: %s: %s" % (e.__class__, e)) traceback.print_exc() try: client.close() diff --git a/demos/forward.py b/demos/forward.py index 96e1700d..98757911 100644 --- a/demos/forward.py +++ b/demos/forward.py @@ -30,6 +30,7 @@ import getpass import os import socket import select + try: import SocketServer except ImportError: @@ -46,30 +47,41 @@ DEFAULT_PORT = 4000 g_verbose = True -class ForwardServer (SocketServer.ThreadingTCPServer): +class ForwardServer(SocketServer.ThreadingTCPServer): daemon_threads = True allow_reuse_address = True - -class Handler (SocketServer.BaseRequestHandler): + +class Handler(SocketServer.BaseRequestHandler): def handle(self): try: - chan = self.ssh_transport.open_channel('direct-tcpip', - (self.chain_host, self.chain_port), - self.request.getpeername()) + chan = self.ssh_transport.open_channel( + "direct-tcpip", + (self.chain_host, self.chain_port), + self.request.getpeername(), + ) except Exception as e: - verbose('Incoming request to %s:%d failed: %s' % (self.chain_host, - self.chain_port, - repr(e))) + verbose( + "Incoming request to %s:%d failed: %s" + % (self.chain_host, self.chain_port, repr(e)) + ) return if chan is None: - verbose('Incoming request to %s:%d was rejected by the SSH server.' % - (self.chain_host, self.chain_port)) + verbose( + "Incoming request to %s:%d was rejected by the SSH server." + % (self.chain_host, self.chain_port) + ) return - verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(), - chan.getpeername(), (self.chain_host, self.chain_port))) + verbose( + "Connected! Tunnel open %r -> %r -> %r" + % ( + self.request.getpeername(), + chan.getpeername(), + (self.chain_host, self.chain_port), + ) + ) while True: r, w, x = select.select([self.request, chan], [], []) if self.request in r: @@ -82,22 +94,23 @@ class Handler (SocketServer.BaseRequestHandler): if len(data) == 0: break self.request.send(data) - + peername = self.request.getpeername() chan.close() self.request.close() - verbose('Tunnel closed from %r' % (peername,)) + verbose("Tunnel closed from %r" % (peername,)) def forward_tunnel(local_port, remote_host, remote_port, transport): # this is a little convoluted, but lets me configure things for the Handler # object. (SocketServer doesn't give Handlers any way to access the outer # server normally.) - class SubHander (Handler): + class SubHander(Handler): chain_host = remote_host chain_port = remote_port ssh_transport = transport - ForwardServer(('', local_port), SubHander).serve_forever() + + ForwardServer(("", local_port), SubHander).serve_forever() def verbose(s): @@ -114,40 +127,88 @@ the SSH server. This is similar to the openssh -L option. def get_host_port(spec, default_port): "parse 'hostname:22' into a host and port, with the port optional" - args = (spec.split(':', 1) + [default_port])[:2] + args = (spec.split(":", 1) + [default_port])[:2] args[1] = int(args[1]) return args[0], args[1] def parse_options(): global g_verbose - - parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]', - version='%prog 1.0', description=HELP) - parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True, - help='squelch all informational output') - parser.add_option('-p', '--local-port', action='store', type='int', dest='port', - default=DEFAULT_PORT, - help='local port to forward (default: %d)' % DEFAULT_PORT) - parser.add_option('-u', '--user', action='store', type='string', dest='user', - default=getpass.getuser(), - help='username for SSH authentication (default: %s)' % getpass.getuser()) - parser.add_option('-K', '--key', action='store', type='string', dest='keyfile', - default=None, - help='private key file to use for SSH authentication') - parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True, - help='don\'t look for or use a private key file') - parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False, - help='read password (for key or password auth) from stdin') - parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port', - help='remote host and port to forward to') + + parser = OptionParser( + usage="usage: %prog [options] <ssh-server>[:<server-port>]", + version="%prog 1.0", + description=HELP, + ) + parser.add_option( + "-q", + "--quiet", + action="store_false", + dest="verbose", + default=True, + help="squelch all informational output", + ) + parser.add_option( + "-p", + "--local-port", + action="store", + type="int", + dest="port", + default=DEFAULT_PORT, + help="local port to forward (default: %d)" % DEFAULT_PORT, + ) + parser.add_option( + "-u", + "--user", + action="store", + type="string", + dest="user", + default=getpass.getuser(), + help="username for SSH authentication (default: %s)" + % getpass.getuser(), + ) + parser.add_option( + "-K", + "--key", + action="store", + type="string", + dest="keyfile", + default=None, + help="private key file to use for SSH authentication", + ) + parser.add_option( + "", + "--no-key", + action="store_false", + dest="look_for_keys", + default=True, + help="don't look for or use a private key file", + ) + parser.add_option( + "-P", + "--password", + action="store_true", + dest="readpass", + default=False, + help="read password (for key or password auth) from stdin", + ) + parser.add_option( + "-r", + "--remote", + action="store", + type="string", + dest="remote", + default=None, + metavar="host:port", + help="remote host and port to forward to", + ) options, args = parser.parse_args() if len(args) != 1: - parser.error('Incorrect number of arguments.') + parser.error("Incorrect number of arguments.") if options.remote is None: - parser.error('Remote address required (-r).') - + parser.error("Remote address required (-r).") + g_verbose = options.verbose server_host, server_port = get_host_port(args[0], SSH_PORT) remote_host, remote_port = get_host_port(options.remote, SSH_PORT) @@ -156,31 +217,42 @@ def parse_options(): def main(): options, server, remote = parse_options() - + password = None if options.readpass: - password = getpass.getpass('Enter SSH password: ') - + password = getpass.getpass("Enter SSH password: ") + client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy()) - verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1])) + verbose("Connecting to ssh host %s:%d ..." % (server[0], server[1])) try: - client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile, - look_for_keys=options.look_for_keys, password=password) + client.connect( + server[0], + server[1], + username=options.user, + key_filename=options.keyfile, + look_for_keys=options.look_for_keys, + password=password, + ) except Exception as e: - print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e)) + print("*** Failed to connect to %s:%d: %r" % (server[0], server[1], e)) sys.exit(1) - verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1])) + verbose( + "Now forwarding port %d to %s:%d ..." + % (options.port, remote[0], remote[1]) + ) try: - forward_tunnel(options.port, remote[0], remote[1], client.get_transport()) + forward_tunnel( + options.port, remote[0], remote[1], client.get_transport() + ) except KeyboardInterrupt: - print('C-c: Port forwarding stopped.') + print("C-c: Port forwarding stopped.") sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/demos/interactive.py b/demos/interactive.py index 7138cd6c..037787c4 100644 --- a/demos/interactive.py +++ b/demos/interactive.py @@ -25,6 +25,7 @@ from paramiko.py3compat import u try: import termios import tty + has_termios = True except ImportError: has_termios = False @@ -39,7 +40,7 @@ def interactive_shell(chan): def posix_shell(chan): import select - + oldtty = termios.tcgetattr(sys.stdin) try: tty.setraw(sys.stdin.fileno()) @@ -52,7 +53,7 @@ def posix_shell(chan): try: x = u(chan.recv(1024)) if len(x) == 0: - sys.stdout.write('\r\n*** EOF\r\n') + sys.stdout.write("\r\n*** EOF\r\n") break sys.stdout.write(x) sys.stdout.flush() @@ -67,26 +68,28 @@ def posix_shell(chan): finally: termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty) - + # thanks to Mike Looijmans for this code def windows_shell(chan): import threading - sys.stdout.write("Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n") - + sys.stdout.write( + "Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n" + ) + def writeall(sock): while True: data = sock.recv(256) if not data: - sys.stdout.write('\r\n*** EOF ***\r\n\r\n') + sys.stdout.write("\r\n*** EOF ***\r\n\r\n") sys.stdout.flush() break sys.stdout.write(data) sys.stdout.flush() - + writer = threading.Thread(target=writeall, args=(chan,)) writer.start() - + try: while True: d = sys.stdin.read(1) diff --git a/demos/rforward.py b/demos/rforward.py index ae70670c..a2e8a776 100755 --- a/demos/rforward.py +++ b/demos/rforward.py @@ -47,11 +47,13 @@ def handler(chan, host, port): try: sock.connect((host, port)) except Exception as e: - verbose('Forwarding request to %s:%d failed: %r' % (host, port, e)) + verbose("Forwarding request to %s:%d failed: %r" % (host, port, e)) return - - verbose('Connected! Tunnel open %r -> %r -> %r' % (chan.origin_addr, - chan.getpeername(), (host, port))) + + verbose( + "Connected! Tunnel open %r -> %r -> %r" + % (chan.origin_addr, chan.getpeername(), (host, port)) + ) while True: r, w, x = select.select([sock, chan], [], []) if sock in r: @@ -66,16 +68,18 @@ def handler(chan, host, port): sock.send(data) chan.close() sock.close() - verbose('Tunnel closed from %r' % (chan.origin_addr,)) + verbose("Tunnel closed from %r" % (chan.origin_addr,)) def reverse_forward_tunnel(server_port, remote_host, remote_port, transport): - transport.request_port_forward('', server_port) + transport.request_port_forward("", server_port) while True: chan = transport.accept(1000) if chan is None: continue - thr = threading.Thread(target=handler, args=(chan, remote_host, remote_port)) + thr = threading.Thread( + target=handler, args=(chan, remote_host, remote_port) + ) thr.setDaemon(True) thr.start() @@ -95,40 +99,88 @@ network. This is similar to the openssh -R option. def get_host_port(spec, default_port): "parse 'hostname:22' into a host and port, with the port optional" - args = (spec.split(':', 1) + [default_port])[:2] + args = (spec.split(":", 1) + [default_port])[:2] args[1] = int(args[1]) return args[0], args[1] def parse_options(): global g_verbose - - parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]', - version='%prog 1.0', description=HELP) - parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True, - help='squelch all informational output') - parser.add_option('-p', '--remote-port', action='store', type='int', dest='port', - default=DEFAULT_PORT, - help='port on server to forward (default: %d)' % DEFAULT_PORT) - parser.add_option('-u', '--user', action='store', type='string', dest='user', - default=getpass.getuser(), - help='username for SSH authentication (default: %s)' % getpass.getuser()) - parser.add_option('-K', '--key', action='store', type='string', dest='keyfile', - default=None, - help='private key file to use for SSH authentication') - parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True, - help='don\'t look for or use a private key file') - parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False, - help='read password (for key or password auth) from stdin') - parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port', - help='remote host and port to forward to') + + parser = OptionParser( + usage="usage: %prog [options] <ssh-server>[:<server-port>]", + version="%prog 1.0", + description=HELP, + ) + parser.add_option( + "-q", + "--quiet", + action="store_false", + dest="verbose", + default=True, + help="squelch all informational output", + ) + parser.add_option( + "-p", + "--remote-port", + action="store", + type="int", + dest="port", + default=DEFAULT_PORT, + help="port on server to forward (default: %d)" % DEFAULT_PORT, + ) + parser.add_option( + "-u", + "--user", + action="store", + type="string", + dest="user", + default=getpass.getuser(), + help="username for SSH authentication (default: %s)" + % getpass.getuser(), + ) + parser.add_option( + "-K", + "--key", + action="store", + type="string", + dest="keyfile", + default=None, + help="private key file to use for SSH authentication", + ) + parser.add_option( + "", + "--no-key", + action="store_false", + dest="look_for_keys", + default=True, + help="don't look for or use a private key file", + ) + parser.add_option( + "-P", + "--password", + action="store_true", + dest="readpass", + default=False, + help="read password (for key or password auth) from stdin", + ) + parser.add_option( + "-r", + "--remote", + action="store", + type="string", + dest="remote", + default=None, + metavar="host:port", + help="remote host and port to forward to", + ) options, args = parser.parse_args() if len(args) != 1: - parser.error('Incorrect number of arguments.') + parser.error("Incorrect number of arguments.") if options.remote is None: - parser.error('Remote address required (-r).') - + parser.error("Remote address required (-r).") + g_verbose = options.verbose server_host, server_port = get_host_port(args[0], SSH_PORT) remote_host, remote_port = get_host_port(options.remote, SSH_PORT) @@ -137,31 +189,42 @@ def parse_options(): def main(): options, server, remote = parse_options() - + password = None if options.readpass: - password = getpass.getpass('Enter SSH password: ') - + password = getpass.getpass("Enter SSH password: ") + client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy()) - verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1])) + verbose("Connecting to ssh host %s:%d ..." % (server[0], server[1])) try: - client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile, - look_for_keys=options.look_for_keys, password=password) + client.connect( + server[0], + server[1], + username=options.user, + key_filename=options.keyfile, + look_for_keys=options.look_for_keys, + password=password, + ) except Exception as e: - print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e)) + print("*** Failed to connect to %s:%d: %r" % (server[0], server[1], e)) sys.exit(1) - verbose('Now forwarding remote port %d to %s:%d ...' % (options.port, remote[0], remote[1])) + verbose( + "Now forwarding remote port %d to %s:%d ..." + % (options.port, remote[0], remote[1]) + ) try: - reverse_forward_tunnel(options.port, remote[0], remote[1], client.get_transport()) + reverse_forward_tunnel( + options.port, remote[0], remote[1], client.get_transport() + ) except KeyboardInterrupt: - print('C-c: Port forwarding stopped.') + print("C-c: Port forwarding stopped.") sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/dev-requirements.txt b/dev-requirements.txt index d41972b9..e4629187 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ # Invocations for common project tasks -invoke>=0.13,<2.0 -invocations>=1.0,<2.0 +invoke>=1.0,<2.0 +invocations>=1.2.0,<2.0 # NOTE: pytest-relaxed currently only works with pytest >=3, <3.3 pytest>=3.2,<3.3 pytest-relaxed==1.1.2 diff --git a/paramiko/__init__.py b/paramiko/__init__.py index 4d80bf26..ebfa72a8 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -21,15 +21,22 @@ import sys from paramiko._version import __version__, __version_info__ from paramiko.transport import SecurityOptions, Transport from paramiko.client import ( - SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, + SSHClient, + MissingHostKeyPolicy, + AutoAddPolicy, + RejectPolicy, WarningPolicy, ) from paramiko.auth_handler import AuthHandler from paramiko.ssh_gss import GSSAuth, GSS_AUTH_AVAILABLE, GSS_EXCEPTIONS from paramiko.channel import Channel, ChannelFile from paramiko.ssh_exception import ( - SSHException, PasswordRequiredException, BadAuthenticationType, - ChannelException, BadHostKeyException, AuthenticationException, + SSHException, + PasswordRequiredException, + BadAuthenticationType, + ChannelException, + BadHostKeyException, + AuthenticationException, ProxyCommandFailure, ) from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery @@ -54,14 +61,25 @@ from paramiko.config import SSHConfig from paramiko.proxy import ProxyCommand from paramiko.common import ( - AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, OPEN_SUCCEEDED, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, - OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE, + AUTH_SUCCESSFUL, + AUTH_PARTIALLY_SUCCESSFUL, + AUTH_FAILED, + OPEN_SUCCEEDED, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_FAILED_CONNECT_FAILED, + OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, + OPEN_FAILED_RESOURCE_SHORTAGE, ) from paramiko.sftp import ( - SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, - SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, + SFTP_OK, + SFTP_EOF, + SFTP_NO_SUCH_FILE, + SFTP_PERMISSION_DENIED, + SFTP_FAILURE, + SFTP_BAD_MESSAGE, + SFTP_NO_CONNECTION, + SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED, ) @@ -72,43 +90,43 @@ __author__ = "Jeff Forcier <jeff@bitprophet.org>" __license__ = "GNU Lesser General Public License (LGPL)" __all__ = [ - 'Transport', - 'SSHClient', - 'MissingHostKeyPolicy', - 'AutoAddPolicy', - 'RejectPolicy', - 'WarningPolicy', - 'SecurityOptions', - 'SubsystemHandler', - 'Channel', - 'PKey', - 'RSAKey', - 'DSSKey', - 'ECDSAKey', - 'Ed25519Key', - 'Message', - 'SSHException', - 'AuthenticationException', - 'PasswordRequiredException', - 'BadAuthenticationType', - 'ChannelException', - 'BadHostKeyException', - 'ProxyCommand', - 'ProxyCommandFailure', - 'SFTP', - 'SFTPFile', - 'SFTPHandle', - 'SFTPClient', - 'SFTPServer', - 'SFTPError', - 'SFTPAttributes', - 'SFTPServerInterface', - 'ServerInterface', - 'BufferedFile', - 'Agent', - 'AgentKey', - 'HostKeys', - 'SSHConfig', - 'util', - 'io_sleep', + "Agent", + "AgentKey", + "AuthenticationException", + "AutoAddPolicy", + "BadAuthenticationType", + "BadHostKeyException", + "BufferedFile", + "Channel", + "ChannelException", + "DSSKey", + "ECDSAKey", + "Ed25519Key", + "HostKeys", + "Message", + "MissingHostKeyPolicy", + "PKey", + "PasswordRequiredException", + "ProxyCommand", + "ProxyCommandFailure", + "RSAKey", + "RejectPolicy", + "SFTP", + "SFTPAttributes", + "SFTPClient", + "SFTPError", + "SFTPFile", + "SFTPHandle", + "SFTPServer", + "SFTPServerInterface", + "SSHClient", + "SSHConfig", + "SSHException", + "SecurityOptions", + "ServerInterface", + "SubsystemHandler", + "Transport", + "WarningPolicy", + "io_sleep", + "util", ] diff --git a/paramiko/_version.py b/paramiko/_version.py index a5db89b8..95e86b6a 100644 --- a/paramiko/_version.py +++ b/paramiko/_version.py @@ -1,2 +1,2 @@ __version_info__ = (2, 4, 1) -__version__ = '.'.join(map(str, __version_info__)) +__version__ = ".".join(map(str, __version_info__)) diff --git a/paramiko/_winapi.py b/paramiko/_winapi.py index a13d7e87..ebcc678a 100644 --- a/paramiko/_winapi.py +++ b/paramiko/_winapi.py @@ -15,6 +15,7 @@ from paramiko.py3compat import u, builtins ###################### # jaraco.windows.error + def format_system_message(errno): """ Call FormatMessage with a system error number to retrieve @@ -77,7 +78,7 @@ class WindowsError(builtins.WindowsError): return self.message def __repr__(self): - return '{self.__class__.__name__}({self.winerror})'.format(**vars()) + return "{self.__class__.__name__}({self.winerror})".format(**vars()) def handle_nonzero_success(result): @@ -95,15 +96,15 @@ GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t GlobalAlloc.restype = ctypes.wintypes.HANDLE GlobalLock = ctypes.windll.kernel32.GlobalLock -GlobalLock.argtypes = ctypes.wintypes.HGLOBAL, +GlobalLock.argtypes = (ctypes.wintypes.HGLOBAL,) GlobalLock.restype = ctypes.wintypes.LPVOID GlobalUnlock = ctypes.windll.kernel32.GlobalUnlock -GlobalUnlock.argtypes = ctypes.wintypes.HGLOBAL, +GlobalUnlock.argtypes = (ctypes.wintypes.HGLOBAL,) GlobalUnlock.restype = ctypes.wintypes.BOOL GlobalSize = ctypes.windll.kernel32.GlobalSize -GlobalSize.argtypes = ctypes.wintypes.HGLOBAL, +GlobalSize.argtypes = (ctypes.wintypes.HGLOBAL,) GlobalSize.restype = ctypes.c_size_t CreateFileMapping = ctypes.windll.kernel32.CreateFileMappingW @@ -121,16 +122,12 @@ MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile MapViewOfFile.restype = ctypes.wintypes.HANDLE UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile -UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE, +UnmapViewOfFile.argtypes = (ctypes.wintypes.HANDLE,) RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory -RtlMoveMemory.argtypes = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_size_t, -) +RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t) -ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL, +ctypes.windll.kernel32.LocalFree.argtypes = (ctypes.wintypes.HLOCAL,) ##################### # jaraco.windows.mmap @@ -140,6 +137,7 @@ class MemoryMap(object): """ A memory map object which can have security attributes overridden. """ + def __init__(self, name, length, security_attributes=None): self.name = name self.length = length @@ -149,14 +147,20 @@ class MemoryMap(object): def __enter__(self): p_SA = ( ctypes.byref(self.security_attributes) - if self.security_attributes else None + if self.security_attributes + else None ) INVALID_HANDLE_VALUE = -1 PAGE_READWRITE = 0x4 FILE_MAP_WRITE = 0x2 filemap = ctypes.windll.kernel32.CreateFileMappingW( - INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length, - u(self.name)) + INVALID_HANDLE_VALUE, + p_SA, + PAGE_READWRITE, + 0, + self.length, + u(self.name), + ) handle_nonzero_success(filemap) if filemap == INVALID_HANDLE_VALUE: raise Exception("Failed to create file mapping") @@ -220,41 +224,45 @@ POLICY_LOOKUP_NAMES = 0x00000800 POLICY_NOTIFICATION = 0x00001000 POLICY_ALL_ACCESS = ( - STANDARD_RIGHTS_REQUIRED | - POLICY_VIEW_LOCAL_INFORMATION | - POLICY_VIEW_AUDIT_INFORMATION | - POLICY_GET_PRIVATE_INFORMATION | - POLICY_TRUST_ADMIN | - POLICY_CREATE_ACCOUNT | - POLICY_CREATE_SECRET | - POLICY_CREATE_PRIVILEGE | - POLICY_SET_DEFAULT_QUOTA_LIMITS | - POLICY_SET_AUDIT_REQUIREMENTS | - POLICY_AUDIT_LOG_ADMIN | - POLICY_SERVER_ADMIN | - POLICY_LOOKUP_NAMES) + STANDARD_RIGHTS_REQUIRED + | POLICY_VIEW_LOCAL_INFORMATION + | POLICY_VIEW_AUDIT_INFORMATION + | POLICY_GET_PRIVATE_INFORMATION + | POLICY_TRUST_ADMIN + | POLICY_CREATE_ACCOUNT + | POLICY_CREATE_SECRET + | POLICY_CREATE_PRIVILEGE + | POLICY_SET_DEFAULT_QUOTA_LIMITS + | POLICY_SET_AUDIT_REQUIREMENTS + | POLICY_AUDIT_LOG_ADMIN + | POLICY_SERVER_ADMIN + | POLICY_LOOKUP_NAMES +) POLICY_READ = ( - STANDARD_RIGHTS_READ | - POLICY_VIEW_AUDIT_INFORMATION | - POLICY_GET_PRIVATE_INFORMATION) + STANDARD_RIGHTS_READ + | POLICY_VIEW_AUDIT_INFORMATION + | POLICY_GET_PRIVATE_INFORMATION +) POLICY_WRITE = ( - STANDARD_RIGHTS_WRITE | - POLICY_TRUST_ADMIN | - POLICY_CREATE_ACCOUNT | - POLICY_CREATE_SECRET | - POLICY_CREATE_PRIVILEGE | - POLICY_SET_DEFAULT_QUOTA_LIMITS | - POLICY_SET_AUDIT_REQUIREMENTS | - POLICY_AUDIT_LOG_ADMIN | - POLICY_SERVER_ADMIN) + STANDARD_RIGHTS_WRITE + | POLICY_TRUST_ADMIN + | POLICY_CREATE_ACCOUNT + | POLICY_CREATE_SECRET + | POLICY_CREATE_PRIVILEGE + | POLICY_SET_DEFAULT_QUOTA_LIMITS + | POLICY_SET_AUDIT_REQUIREMENTS + | POLICY_AUDIT_LOG_ADMIN + | POLICY_SERVER_ADMIN +) POLICY_EXECUTE = ( - STANDARD_RIGHTS_EXECUTE | - POLICY_VIEW_LOCAL_INFORMATION | - POLICY_LOOKUP_NAMES) + STANDARD_RIGHTS_EXECUTE + | POLICY_VIEW_LOCAL_INFORMATION + | POLICY_LOOKUP_NAMES +) class TokenAccess: @@ -268,8 +276,8 @@ class TokenInformationClass: class TOKEN_USER(ctypes.Structure): num = 1 _fields_ = [ - ('SID', ctypes.c_void_p), - ('ATTRIBUTES', ctypes.wintypes.DWORD), + ("SID", ctypes.c_void_p), + ("ATTRIBUTES", ctypes.wintypes.DWORD), ] @@ -290,13 +298,13 @@ class SECURITY_DESCRIPTOR(ctypes.Structure): REVISION = 1 _fields_ = [ - ('Revision', ctypes.c_ubyte), - ('Sbz1', ctypes.c_ubyte), - ('Control', SECURITY_DESCRIPTOR_CONTROL), - ('Owner', ctypes.c_void_p), - ('Group', ctypes.c_void_p), - ('Sacl', ctypes.c_void_p), - ('Dacl', ctypes.c_void_p), + ("Revision", ctypes.c_ubyte), + ("Sbz1", ctypes.c_ubyte), + ("Control", SECURITY_DESCRIPTOR_CONTROL), + ("Owner", ctypes.c_void_p), + ("Group", ctypes.c_void_p), + ("Sacl", ctypes.c_void_p), + ("Dacl", ctypes.c_void_p), ] @@ -309,9 +317,9 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): } SECURITY_ATTRIBUTES; """ _fields_ = [ - ('nLength', ctypes.wintypes.DWORD), - ('lpSecurityDescriptor', ctypes.c_void_p), - ('bInheritHandle', ctypes.wintypes.BOOL), + ("nLength", ctypes.wintypes.DWORD), + ("lpSecurityDescriptor", ctypes.c_void_p), + ("bInheritHandle", ctypes.wintypes.BOOL), ] def __init__(self, *args, **kwargs): @@ -343,21 +351,30 @@ def GetTokenInformation(token, information_class): Given a token, get the token information for it. """ data_size = ctypes.wintypes.DWORD() - ctypes.windll.advapi32.GetTokenInformation(token, information_class.num, - 0, 0, ctypes.byref(data_size)) + ctypes.windll.advapi32.GetTokenInformation( + token, information_class.num, 0, 0, ctypes.byref(data_size) + ) data = ctypes.create_string_buffer(data_size.value) - handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token, - information_class.num, - ctypes.byref(data), ctypes.sizeof(data), - ctypes.byref(data_size))) + handle_nonzero_success( + ctypes.windll.advapi32.GetTokenInformation( + token, + information_class.num, + ctypes.byref(data), + ctypes.sizeof(data), + ctypes.byref(data_size), + ) + ) return ctypes.cast(data, ctypes.POINTER(TOKEN_USER)).contents def OpenProcessToken(proc_handle, access): result = ctypes.wintypes.HANDLE() proc_handle = ctypes.wintypes.HANDLE(proc_handle) - handle_nonzero_success(ctypes.windll.advapi32.OpenProcessToken( - proc_handle, access, ctypes.byref(result))) + handle_nonzero_success( + ctypes.windll.advapi32.OpenProcessToken( + proc_handle, access, ctypes.byref(result) + ) + ) return result @@ -366,8 +383,7 @@ def get_current_user(): Return a TOKEN_USER for the owner of this process. """ process = OpenProcessToken( - ctypes.windll.kernel32.GetCurrentProcess(), - TokenAccess.TOKEN_QUERY, + ctypes.windll.kernel32.GetCurrentProcess(), TokenAccess.TOKEN_QUERY ) return GetTokenInformation(process, TOKEN_USER) @@ -389,8 +405,10 @@ def get_security_attributes_for_user(user=None): SA.descriptor = SD SA.bInheritHandle = 1 - ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD), - SECURITY_DESCRIPTOR.REVISION) - ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD), - user.SID, 0) + ctypes.windll.advapi32.InitializeSecurityDescriptor( + ctypes.byref(SD), SECURITY_DESCRIPTOR.REVISION + ) + ctypes.windll.advapi32.SetSecurityDescriptorOwner( + ctypes.byref(SD), user.SID, 0 + ) return SA diff --git a/paramiko/agent.py b/paramiko/agent.py index 7a4dde21..62a271d5 100644 --- a/paramiko/agent.py +++ b/paramiko/agent.py @@ -43,8 +43,8 @@ cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13) SSH2_AGENT_SIGN_RESPONSE = 14 - class AgentSSH(object): + def __init__(self): self._conn = None self._keys = () @@ -65,7 +65,7 @@ class AgentSSH(object): self._conn = conn ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES) if ptype != SSH2_AGENT_IDENTITIES_ANSWER: - raise SSHException('could not get keys from ssh-agent') + raise SSHException("could not get keys from ssh-agent") keys = [] for i in range(result.get_int()): keys.append(AgentKey(self, result.get_binary())) @@ -80,19 +80,19 @@ class AgentSSH(object): def _send_message(self, msg): msg = asbytes(msg) - self._conn.send(struct.pack('>I', len(msg)) + msg) + self._conn.send(struct.pack(">I", len(msg)) + msg) l = self._read_all(4) - msg = Message(self._read_all(struct.unpack('>I', l)[0])) + msg = Message(self._read_all(struct.unpack(">I", l)[0])) return ord(msg.get_byte()), msg def _read_all(self, wanted): result = self._conn.recv(wanted) while len(result) < wanted: if len(result) == 0: - raise SSHException('lost ssh-agent') + raise SSHException("lost ssh-agent") extra = self._conn.recv(wanted - len(result)) if len(extra) == 0: - raise SSHException('lost ssh-agent') + raise SSHException("lost ssh-agent") result += extra return result @@ -101,6 +101,7 @@ class AgentProxyThread(threading.Thread): """ Class in charge of communication between two channels. """ + def __init__(self, agent): threading.Thread.__init__(self, target=self.run) self._agent = agent @@ -115,12 +116,9 @@ class AgentProxyThread(threading.Thread): # The address should be an IP address as a string? or None self.__addr = addr self._agent.connect() - if ( - not isinstance(self._agent, int) and - ( - self._agent._conn is None or - not hasattr(self._agent._conn, 'fileno') - ) + if not isinstance(self._agent, int) and ( + self._agent._conn is None + or not hasattr(self._agent._conn, "fileno") ): raise AuthenticationException("Unable to connect to SSH agent") self._communicate() @@ -130,6 +128,7 @@ class AgentProxyThread(threading.Thread): def _communicate(self): import fcntl + oldflags = fcntl.fcntl(self.__inr, fcntl.F_GETFL) fcntl.fcntl(self.__inr, fcntl.F_SETFL, oldflags | os.O_NONBLOCK) while not self._exit: @@ -162,6 +161,7 @@ class AgentLocalProxy(AgentProxyThread): Class to be used when wanting to ask a local SSH Agent being asked from a remote fake agent (so use a unix socket for ex.) """ + def __init__(self, agent): AgentProxyThread.__init__(self, agent) @@ -185,6 +185,7 @@ class AgentRemoteProxy(AgentProxyThread): """ Class to be used when wanting to ask a remote SSH Agent """ + def __init__(self, agent, chan): AgentProxyThread.__init__(self, agent) self.__chan = chan @@ -205,6 +206,7 @@ class AgentClientProxy(object): the remote fake agent and the local agent #. Communication occurs ... """ + def __init__(self, chanRemote): self._conn = None self.__chanR = chanRemote @@ -218,16 +220,18 @@ class AgentClientProxy(object): """ Method automatically called by ``AgentProxyThread.run``. """ - if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): + if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: retry_on_signal( - lambda: conn.connect(os.environ['SSH_AUTH_SOCK'])) + lambda: conn.connect(os.environ["SSH_AUTH_SOCK"]) + ) except: # probably a dangling env var: the ssh agent is gone return - elif sys.platform == 'win32': + elif sys.platform == "win32": import paramiko.win_pageant as win_pageant + if win_pageant.can_talk_to_agent(): conn = win_pageant.PageantConnection() else: @@ -255,12 +259,13 @@ class AgentServerProxy(AgentSSH): :raises: `.SSHException` -- mostly if we lost the agent """ + def __init__(self, t): AgentSSH.__init__(self) self.__t = t - self._dir = tempfile.mkdtemp('sshproxy') + self._dir = tempfile.mkdtemp("sshproxy") os.chmod(self._dir, stat.S_IRWXU) - self._file = self._dir + '/sshproxy.ssh' + self._file = self._dir + "/sshproxy.ssh" self.thread = AgentLocalProxy(self) self.thread.start() @@ -270,8 +275,8 @@ class AgentServerProxy(AgentSSH): def connect(self): conn_sock = self.__t.open_forward_agent_channel() if conn_sock is None: - raise SSHException('lost ssh-agent') - conn_sock.set_name('auth-agent') + raise SSHException("lost ssh-agent") + conn_sock.set_name("auth-agent") self._connect(conn_sock) def close(self): @@ -292,7 +297,7 @@ class AgentServerProxy(AgentSSH): :return: a dict containing the ``SSH_AUTH_SOCK`` environnement variables """ - return {'SSH_AUTH_SOCK': self._get_filename()} + return {"SSH_AUTH_SOCK": self._get_filename()} def _get_filename(self): return self._file @@ -319,6 +324,7 @@ class AgentRequestHandler(object): # the remote end. session.exec_command("git clone https://my.git.repository/") """ + def __init__(self, chanClient): self._conn = None self.__chanC = chanClient @@ -350,18 +356,20 @@ class Agent(AgentSSH): :raises: `.SSHException` -- if an SSH agent is found, but speaks an incompatible protocol """ + def __init__(self): AgentSSH.__init__(self) - if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): + if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: - conn.connect(os.environ['SSH_AUTH_SOCK']) + conn.connect(os.environ["SSH_AUTH_SOCK"]) except: # probably a dangling env var: the ssh agent is gone return - elif sys.platform == 'win32': + elif sys.platform == "win32": from . import win_pageant + if win_pageant.can_talk_to_agent(): conn = win_pageant.PageantConnection() else: @@ -384,6 +392,7 @@ class AgentKey(PKey): authenticating to a remote server (signing). Most other key operations work as expected. """ + def __init__(self, agent, blob): self.agent = agent self.blob = blob @@ -407,5 +416,5 @@ class AgentKey(PKey): msg.add_int(0) ptype, result = self.agent._send_message(msg) if ptype != SSH2_AGENT_SIGN_RESPONSE: - raise SSHException('key cannot be used for signing') + raise SSHException("key cannot be used for signing") return result.get_binary() diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index 3b894de7..41724832 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -24,31 +24,55 @@ import weakref import time from paramiko.common import ( - cMSG_SERVICE_REQUEST, cMSG_DISCONNECT, DISCONNECT_SERVICE_NOT_AVAILABLE, - DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, cMSG_USERAUTH_REQUEST, - cMSG_SERVICE_ACCEPT, DEBUG, AUTH_SUCCESSFUL, INFO, cMSG_USERAUTH_SUCCESS, - cMSG_USERAUTH_FAILURE, AUTH_PARTIALLY_SUCCESSFUL, - cMSG_USERAUTH_INFO_REQUEST, WARNING, AUTH_FAILED, cMSG_USERAUTH_PK_OK, - cMSG_USERAUTH_INFO_RESPONSE, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, - MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS, MSG_USERAUTH_FAILURE, - MSG_USERAUTH_BANNER, MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE, - cMSG_USERAUTH_GSSAPI_RESPONSE, cMSG_USERAUTH_GSSAPI_TOKEN, - cMSG_USERAUTH_GSSAPI_MIC, MSG_USERAUTH_GSSAPI_RESPONSE, - MSG_USERAUTH_GSSAPI_TOKEN, MSG_USERAUTH_GSSAPI_ERROR, - MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC, MSG_NAMES, - cMSG_USERAUTH_BANNER + cMSG_SERVICE_REQUEST, + cMSG_DISCONNECT, + DISCONNECT_SERVICE_NOT_AVAILABLE, + DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, + cMSG_USERAUTH_REQUEST, + cMSG_SERVICE_ACCEPT, + DEBUG, + AUTH_SUCCESSFUL, + INFO, + cMSG_USERAUTH_SUCCESS, + cMSG_USERAUTH_FAILURE, + AUTH_PARTIALLY_SUCCESSFUL, + cMSG_USERAUTH_INFO_REQUEST, + WARNING, + AUTH_FAILED, + cMSG_USERAUTH_PK_OK, + cMSG_USERAUTH_INFO_RESPONSE, + MSG_SERVICE_REQUEST, + MSG_SERVICE_ACCEPT, + MSG_USERAUTH_REQUEST, + MSG_USERAUTH_SUCCESS, + MSG_USERAUTH_FAILURE, + MSG_USERAUTH_BANNER, + MSG_USERAUTH_INFO_REQUEST, + MSG_USERAUTH_INFO_RESPONSE, + cMSG_USERAUTH_GSSAPI_RESPONSE, + cMSG_USERAUTH_GSSAPI_TOKEN, + cMSG_USERAUTH_GSSAPI_MIC, + MSG_USERAUTH_GSSAPI_RESPONSE, + MSG_USERAUTH_GSSAPI_TOKEN, + MSG_USERAUTH_GSSAPI_ERROR, + MSG_USERAUTH_GSSAPI_ERRTOK, + MSG_USERAUTH_GSSAPI_MIC, + MSG_NAMES, + cMSG_USERAUTH_BANNER, ) from paramiko.message import Message from paramiko.py3compat import b from paramiko.ssh_exception import ( - SSHException, AuthenticationException, BadAuthenticationType, + SSHException, + AuthenticationException, + BadAuthenticationType, PartialAuthentication, ) from paramiko.server import InteractiveQuery from paramiko.ssh_gss import GSSAuth, GSS_EXCEPTIONS -class AuthHandler (object): +class AuthHandler(object): """ Internal class to handle the mechanics of authentication. """ @@ -58,7 +82,7 @@ class AuthHandler (object): self.username = None self.authenticated = False self.auth_event = None - self.auth_method = '' + self.auth_method = "" self.banner = None self.password = None self.private_key = None @@ -87,7 +111,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'none' + self.auth_method = "none" self.username = username self._request_auth() finally: @@ -97,7 +121,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'publickey' + self.auth_method = "publickey" self.username = username self.private_key = key self._request_auth() @@ -108,21 +132,21 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'password' + self.auth_method = "password" self.username = username self.password = password self._request_auth() finally: self.transport.lock.release() - def auth_interactive(self, username, handler, event, submethods=''): + def auth_interactive(self, username, handler, event, submethods=""): """ response_list = handler(title, instructions, prompt_list) """ self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'keyboard-interactive' + self.auth_method = "keyboard-interactive" self.username = username self.interactive_handler = handler self.submethods = submethods @@ -134,7 +158,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'gssapi-with-mic' + self.auth_method = "gssapi-with-mic" self.username = username self.gss_host = gss_host self.gss_deleg_creds = gss_deleg_creds @@ -146,7 +170,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'gssapi-keyex' + self.auth_method = "gssapi-keyex" self.username = username self._request_auth() finally: @@ -161,15 +185,15 @@ class AuthHandler (object): def _request_auth(self): m = Message() m.add_byte(cMSG_SERVICE_REQUEST) - m.add_string('ssh-userauth') + m.add_string("ssh-userauth") self.transport._send_message(m) def _disconnect_service_not_available(self): m = Message() m.add_byte(cMSG_DISCONNECT) m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) - m.add_string('Service not available') - m.add_string('en') + m.add_string("Service not available") + m.add_string("en") self.transport._send_message(m) self.transport.close() @@ -177,8 +201,8 @@ class AuthHandler (object): m = Message() m.add_byte(cMSG_DISCONNECT) m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) - m.add_string('No more auth methods available') - m.add_string('en') + m.add_string("No more auth methods available") + m.add_string("en") self.transport._send_message(m) self.transport.close() @@ -188,7 +212,7 @@ class AuthHandler (object): m.add_byte(cMSG_USERAUTH_REQUEST) m.add_string(username) m.add_string(service) - m.add_string('publickey') + m.add_string("publickey") m.add_boolean(True) # Use certificate contents, if available, plain pubkey otherwise if key.public_blob: @@ -208,17 +232,17 @@ class AuthHandler (object): if not self.transport.is_active(): e = self.transport.get_exception() if (e is None) or issubclass(e.__class__, EOFError): - e = AuthenticationException('Authentication failed.') + e = AuthenticationException("Authentication failed.") raise e if event.is_set(): break if max_ts is not None and max_ts <= time.time(): - raise AuthenticationException('Authentication timeout.') + raise AuthenticationException("Authentication timeout.") if not self.is_authenticated(): e = self.transport.get_exception() if e is None: - e = AuthenticationException('Authentication failed.') + e = AuthenticationException("Authentication failed.") # this is horrible. Python Exception isn't yet descended from # object, so type(e) won't work. :( if issubclass(e.__class__, PartialAuthentication): @@ -228,7 +252,7 @@ class AuthHandler (object): def _parse_service_request(self, m): service = m.get_text() - if self.transport.server_mode and (service == 'ssh-userauth'): + if self.transport.server_mode and (service == "ssh-userauth"): # accepted m = Message() m.add_byte(cMSG_SERVICE_ACCEPT) @@ -247,18 +271,18 @@ class AuthHandler (object): def _parse_service_accept(self, m): service = m.get_text() - if service == 'ssh-userauth': - self._log(DEBUG, 'userauth is OK') + if service == "ssh-userauth": + self._log(DEBUG, "userauth is OK") m = Message() m.add_byte(cMSG_USERAUTH_REQUEST) m.add_string(self.username) - m.add_string('ssh-connection') + m.add_string("ssh-connection") m.add_string(self.auth_method) - if self.auth_method == 'password': + if self.auth_method == "password": m.add_boolean(False) password = b(self.password) m.add_string(password) - elif self.auth_method == 'publickey': + elif self.auth_method == "publickey": m.add_boolean(True) # Use certificate contents, if available, plain pubkey # otherwise @@ -269,11 +293,12 @@ class AuthHandler (object): m.add_string(self.private_key.get_name()) m.add_string(self.private_key) blob = self._get_session_blob( - self.private_key, 'ssh-connection', self.username) + self.private_key, "ssh-connection", self.username + ) sig = self.private_key.sign_ssh_data(blob) m.add_string(sig) - elif self.auth_method == 'keyboard-interactive': - m.add_string('') + elif self.auth_method == "keyboard-interactive": + m.add_string("") m.add_string(self.submethods) elif self.auth_method == "gssapi-with-mic": sshgss = GSSAuth(self.auth_method, self.gss_deleg_creds) @@ -292,10 +317,11 @@ class AuthHandler (object): m = Message() m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN) try: - m.add_string(sshgss.ssh_init_sec_context( - self.gss_host, - mech, - self.username,)) + m.add_string( + sshgss.ssh_init_sec_context( + self.gss_host, mech, self.username + ) + ) except GSS_EXCEPTIONS as e: return self._handle_local_gss_failure(e) self.transport._send_message(m) @@ -308,7 +334,8 @@ class AuthHandler (object): self.gss_host, mech, self.username, - srv_token) + srv_token, + ) except GSS_EXCEPTIONS as e: return self._handle_local_gss_failure(e) # After this step the GSSAPI should not return any @@ -340,48 +367,55 @@ class AuthHandler (object): min_status = m.get_int() err_msg = m.get_string() m.get_string() # Lang tag - discarded - raise SSHException("""GSS-API Error: + raise SSHException( + """GSS-API Error: Major Status: {} Minor Status: {} Error Message: {} -""".format(maj_status, min_status, err_msg)) +""".format( + maj_status, min_status, err_msg + ) + ) elif ptype == MSG_USERAUTH_FAILURE: self._parse_userauth_failure(m) return else: raise SSHException( - "Received Package: {}".format(MSG_NAMES[ptype])) + "Received Package: {}".format(MSG_NAMES[ptype]) + ) elif ( - self.auth_method == 'gssapi-keyex' and - self.transport.gss_kex_used + self.auth_method == "gssapi-keyex" + and self.transport.gss_kex_used ): kexgss = self.transport.kexgss_ctxt kexgss.set_username(self.username) mic_token = kexgss.ssh_get_mic(self.transport.session_id) m.add_string(mic_token) - elif self.auth_method == 'none': + elif self.auth_method == "none": pass else: raise SSHException( - 'Unknown auth method "{}"'.format(self.auth_method)) + 'Unknown auth method "{}"'.format(self.auth_method) + ) self.transport._send_message(m) else: self._log( - DEBUG, - 'Service request "{}" accepted (?)'.format(service)) + DEBUG, 'Service request "{}" accepted (?)'.format(service) + ) def _send_auth_result(self, username, method, result): # okay, send result m = Message() if result == AUTH_SUCCESSFUL: - self._log(INFO, 'Auth granted ({}).'.format(method)) + self._log(INFO, "Auth granted ({}).".format(method)) m.add_byte(cMSG_USERAUTH_SUCCESS) self.authenticated = True else: - self._log(INFO, 'Auth rejected ({}).'.format(method)) + self._log(INFO, "Auth rejected ({}).".format(method)) m.add_byte(cMSG_USERAUTH_FAILURE) m.add_string( - self.transport.server_object.get_allowed_auths(username)) + self.transport.server_object.get_allowed_auths(username) + ) if result == AUTH_PARTIALLY_SUCCESSFUL: m.add_boolean(True) else: @@ -411,7 +445,7 @@ Error Message: {} # er, uh... what? m = Message() m.add_byte(cMSG_USERAUTH_FAILURE) - m.add_string('none') + m.add_string("none") m.add_boolean(False) self.transport._send_message(m) return @@ -423,18 +457,19 @@ Error Message: {} method = m.get_text() self._log( DEBUG, - 'Auth request (type={}) service={}, username={}'.format( + "Auth request (type={}) service={}, username={}".format( method, service, username - ) + ), ) - if service != 'ssh-connection': + if service != "ssh-connection": self._disconnect_service_not_available() return - if ((self.auth_username is not None) and - (self.auth_username != username)): + if (self.auth_username is not None) and ( + self.auth_username != username + ): self._log( WARNING, - 'Auth rejected because the client attempted to change username in mid-flight' # noqa + "Auth rejected because the client attempted to change username in mid-flight", # noqa ) self._disconnect_no_more_auth() return @@ -442,13 +477,13 @@ Error Message: {} # check if GSS-API authentication is enabled gss_auth = self.transport.server_object.enable_auth_gssapi() - if method == 'none': + if method == "none": result = self.transport.server_object.check_auth_none(username) - elif method == 'password': + elif method == "password": changereq = m.get_boolean() password = m.get_binary() try: - password = password.decode('UTF-8') + password = password.decode("UTF-8") except UnicodeError: # some clients/servers expect non-utf-8 passwords! # in this case, just return the raw byte string. @@ -457,31 +492,30 @@ Error Message: {} # always treated as failure, since we don't support changing # passwords, but collect the list of valid auth types from # the callback anyway - self._log( - DEBUG, - 'Auth request to change passwords (rejected)') + self._log(DEBUG, "Auth request to change passwords (rejected)") newpassword = m.get_binary() try: - newpassword = newpassword.decode('UTF-8', 'replace') + newpassword = newpassword.decode("UTF-8", "replace") except UnicodeError: pass result = AUTH_FAILED else: result = self.transport.server_object.check_auth_password( - username, password) - elif method == 'publickey': + username, password + ) + elif method == "publickey": sig_attached = m.get_boolean() keytype = m.get_text() keyblob = m.get_binary() try: key = self.transport._key_info[keytype](Message(keyblob)) except SSHException as e: - self._log( - INFO, - 'Auth rejected: public key: {}'.format(str(e))) + self._log(INFO, "Auth rejected: public key: {}".format(str(e))) key = None except Exception as e: - msg = 'Auth rejected: unsupported or mangled public key ({}: {})' # noqa + msg = ( + "Auth rejected: unsupported or mangled public key ({}: {})" + ) # noqa self._log(INFO, msg.format(e.__class__.__name__, e)) key = None if key is None: @@ -489,7 +523,8 @@ Error Message: {} return # first check if this key is okay... if not, we can skip the verify result = self.transport.server_object.check_auth_publickey( - username, key) + username, key + ) if result != AUTH_FAILED: # key is okay, verify it if not sig_attached: @@ -504,14 +539,13 @@ Error Message: {} sig = Message(m.get_binary()) blob = self._get_session_blob(key, service, username) if not key.verify_ssh_sig(blob, sig): - self._log( - INFO, - 'Auth rejected: invalid signature') + self._log(INFO, "Auth rejected: invalid signature") result = AUTH_FAILED - elif method == 'keyboard-interactive': + elif method == "keyboard-interactive": submethods = m.get_string() result = self.transport.server_object.check_auth_interactive( - username, submethods) + username, submethods + ) if isinstance(result, InteractiveQuery): # make interactive query instead of response self._interactive_query(result) @@ -527,7 +561,8 @@ Error Message: {} if mechs > 1: self._log( INFO, - 'Disconnect: Received more than one GSS-API OID mechanism') + "Disconnect: Received more than one GSS-API OID mechanism", + ) self._disconnect_no_more_auth() desired_mech = m.get_string() mech_ok = sshgss.ssh_check_mech(desired_mech) @@ -535,7 +570,8 @@ Error Message: {} if not mech_ok: self._log( INFO, - 'Disconnect: Received an invalid GSS-API OID mechanism') + "Disconnect: Received an invalid GSS-API OID mechanism", + ) self._disconnect_no_more_auth() # send the Kerberos V5 GSSAPI OID to the client supported_mech = sshgss.ssh_gss_oids("server") @@ -544,11 +580,14 @@ Error Message: {} m = Message() m.add_byte(cMSG_USERAUTH_GSSAPI_RESPONSE) m.add_bytes(supported_mech) - self.transport.auth_handler = GssapiWithMicAuthHandler(self, - sshgss) - self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN, - MSG_USERAUTH_REQUEST, - MSG_SERVICE_REQUEST) + self.transport.auth_handler = GssapiWithMicAuthHandler( + self, sshgss + ) + self.transport._expected_packet = ( + MSG_USERAUTH_GSSAPI_TOKEN, + MSG_USERAUTH_REQUEST, + MSG_SERVICE_REQUEST, + ) self.transport._send_message(m) return elif method == "gssapi-keyex" and gss_auth: @@ -559,16 +598,17 @@ Error Message: {} result = AUTH_FAILED self._send_auth_result(username, method, result) try: - sshgss.ssh_check_mic(mic_token, - self.transport.session_id, - self.auth_username) + sshgss.ssh_check_mic( + mic_token, self.transport.session_id, self.auth_username + ) except Exception: result = AUTH_FAILED self._send_auth_result(username, method, result) raise result = AUTH_SUCCESSFUL self.transport.server_object.check_auth_gssapi_keyex( - username, result) + username, result + ) else: result = self.transport.server_object.check_auth_none(username) # okay, send result @@ -576,8 +616,8 @@ Error Message: {} def _parse_userauth_success(self, m): self._log( - INFO, - 'Authentication ({}) successful!'.format(self.auth_method)) + INFO, "Authentication ({}) successful!".format(self.auth_method) + ) self.authenticated = True self.transport._auth_trigger() if self.auth_event is not None: @@ -587,24 +627,23 @@ Error Message: {} authlist = m.get_list() partial = m.get_boolean() if partial: - self._log(INFO, 'Authentication continues...') - self._log(DEBUG, 'Methods: ' + str(authlist)) + self._log(INFO, "Authentication continues...") + self._log(DEBUG, "Methods: " + str(authlist)) self.transport.saved_exception = PartialAuthentication(authlist) elif self.auth_method not in authlist: for msg in ( - 'Authentication type ({}) not permitted.'.format( + "Authentication type ({}) not permitted.".format( self.auth_method ), - 'Allowed methods: {}'.format(authlist), + "Allowed methods: {}".format(authlist), ): self._log(DEBUG, msg) self.transport.saved_exception = BadAuthenticationType( - 'Bad authentication type', authlist + "Bad authentication type", authlist ) else: self._log( - INFO, - 'Authentication ({}) failed.'.format(self.auth_method) + INFO, "Authentication ({}) failed.".format(self.auth_method) ) self.authenticated = False self.username = None @@ -614,12 +653,12 @@ Error Message: {} def _parse_userauth_banner(self, m): banner = m.get_string() self.banner = banner - self._log(INFO, 'Auth banner: {}'.format(banner)) + self._log(INFO, "Auth banner: {}".format(banner)) # who cares. def _parse_userauth_info_request(self, m): - if self.auth_method != 'keyboard-interactive': - raise SSHException('Illegal info request from server') + if self.auth_method != "keyboard-interactive": + raise SSHException("Illegal info request from server") title = m.get_text() instructions = m.get_text() m.get_binary() # lang @@ -628,7 +667,8 @@ Error Message: {} for i in range(prompts): prompt_list.append((m.get_text(), m.get_boolean())) response_list = self.interactive_handler( - title, instructions, prompt_list) + title, instructions, prompt_list + ) m = Message() m.add_byte(cMSG_USERAUTH_INFO_RESPONSE) @@ -639,25 +679,26 @@ Error Message: {} def _parse_userauth_info_response(self, m): if not self.transport.server_mode: - raise SSHException('Illegal info response from server') + raise SSHException("Illegal info response from server") n = m.get_int() responses = [] for i in range(n): responses.append(m.get_text()) result = self.transport.server_object.check_auth_interactive_response( - responses) + responses + ) if isinstance(result, InteractiveQuery): # make interactive query instead of response self._interactive_query(result) return self._send_auth_result( - self.auth_username, 'keyboard-interactive', result) + self.auth_username, "keyboard-interactive", result + ) def _handle_local_gss_failure(self, e): self.transport.saved_exception = e self._log(DEBUG, "GSSAPI failure: {}".format(e)) - self._log(INFO, 'Authentication ({}) failed.'.format( - self.auth_method)) + self._log(INFO, "Authentication ({}) failed.".format(self.auth_method)) self.authenticated = False self.username = None if self.auth_event is not None: @@ -718,9 +759,9 @@ class GssapiWithMicAuthHandler(object): # context. sshgss = self.sshgss try: - token = sshgss.ssh_accept_sec_context(self.gss_host, - client_token, - self.auth_username) + token = sshgss.ssh_accept_sec_context( + self.gss_host, client_token, self.auth_username + ) except Exception as e: self.transport.saved_exception = e result = AUTH_FAILED @@ -731,9 +772,11 @@ class GssapiWithMicAuthHandler(object): m = Message() m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN) m.add_string(token) - self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN, - MSG_USERAUTH_GSSAPI_MIC, - MSG_USERAUTH_REQUEST) + self.transport._expected_packet = ( + MSG_USERAUTH_GSSAPI_TOKEN, + MSG_USERAUTH_GSSAPI_MIC, + MSG_USERAUTH_REQUEST, + ) self.transport._send_message(m) def _parse_userauth_gssapi_mic(self, m): @@ -742,9 +785,9 @@ class GssapiWithMicAuthHandler(object): username = self.auth_username self._restore_delegate_auth_handler() try: - sshgss.ssh_check_mic(mic_token, - self.transport.session_id, - username) + sshgss.ssh_check_mic( + mic_token, self.transport.session_id, username + ) except Exception as e: self.transport.saved_exception = e result = AUTH_FAILED @@ -754,8 +797,9 @@ class GssapiWithMicAuthHandler(object): # The OpenSSH server is able to create a TGT with the delegated # client credentials, but this is not supported by GSS-API. result = AUTH_SUCCESSFUL - self.transport.server_object.check_auth_gssapi_with_mic(username, - result) + self.transport.server_object.check_auth_gssapi_with_mic( + username, result + ) # okay, send result self._send_auth_result(username, self.method, result) diff --git a/paramiko/ber.py b/paramiko/ber.py index 876347e0..92d7121e 100644 --- a/paramiko/ber.py +++ b/paramiko/ber.py @@ -21,7 +21,7 @@ from paramiko.py3compat import b, byte_ord, byte_chr, long import paramiko.util as util -class BERException (Exception): +class BERException(Exception): pass @@ -41,7 +41,7 @@ class BER(object): return self.asbytes() def __repr__(self): - return 'BER(\'' + repr(self.content) + '\')' + return "BER('" + repr(self.content) + "')" def decode(self): return self.decode_next() @@ -72,12 +72,13 @@ class BER(object): if self.idx + t > len(self.content): return None size = util.inflate_long( - self.content[self.idx: self.idx + t], True) + self.content[self.idx : self.idx + t], True + ) self.idx += t if self.idx + size > len(self.content): # can't fit return None - data = self.content[self.idx: self.idx + size] + data = self.content[self.idx : self.idx + size] self.idx += size # now switch on id if ident == 0x30: @@ -88,7 +89,7 @@ class BER(object): return util.inflate_long(data) else: # 1: boolean (00 false, otherwise true) - msg = 'Unknown ber encoding type {:d} (robey is lazy)' + msg = "Unknown ber encoding type {:d} (robey is lazy)" raise BERException(msg.format(ident)) @staticmethod @@ -126,7 +127,7 @@ class BER(object): self.encode_tlv(0x30, self.encode_sequence(x)) else: raise BERException( - 'Unknown type for encoding: {!r}'.format(type(x)) + "Unknown type for encoding: {!r}".format(type(x)) ) @staticmethod diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py index d9f5149d..48f5aae5 100644 --- a/paramiko/buffered_pipe.py +++ b/paramiko/buffered_pipe.py @@ -28,14 +28,14 @@ import time from paramiko.py3compat import PY2, b -class PipeTimeout (IOError): +class PipeTimeout(IOError): """ Indicates that a timeout was reached on a read from a `.BufferedPipe`. """ pass -class BufferedPipe (object): +class BufferedPipe(object): """ A buffer that obeys normal read (with timeout) & close semantics for a file or socket, but is fed data from another thread. This is used by @@ -46,16 +46,19 @@ class BufferedPipe (object): self._lock = threading.Lock() self._cv = threading.Condition(self._lock) self._event = None - self._buffer = array.array('B') + self._buffer = array.array("B") self._closed = False if PY2: + def _buffer_frombytes(self, data): self._buffer.fromstring(data) def _buffer_tobytes(self, limit=None): return self._buffer[:limit].tostring() + else: + def _buffer_frombytes(self, data): self._buffer.frombytes(data) diff --git a/paramiko/channel.py b/paramiko/channel.py index 91a8f0df..00f86d6e 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -25,14 +25,22 @@ import os import socket import time import threading + # TODO: switch as much of py3compat.py to 'six' as possible, then use six.wraps from functools import wraps from paramiko import util from paramiko.common import ( - cMSG_CHANNEL_REQUEST, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_CHANNEL_DATA, - cMSG_CHANNEL_EXTENDED_DATA, DEBUG, ERROR, cMSG_CHANNEL_SUCCESS, - cMSG_CHANNEL_FAILURE, cMSG_CHANNEL_EOF, cMSG_CHANNEL_CLOSE, + cMSG_CHANNEL_REQUEST, + cMSG_CHANNEL_WINDOW_ADJUST, + cMSG_CHANNEL_DATA, + cMSG_CHANNEL_EXTENDED_DATA, + DEBUG, + ERROR, + cMSG_CHANNEL_SUCCESS, + cMSG_CHANNEL_FAILURE, + cMSG_CHANNEL_EOF, + cMSG_CHANNEL_CLOSE, ) from paramiko.message import Message from paramiko.py3compat import bytes_types @@ -51,20 +59,22 @@ def open_only(func): `.SSHException` -- If the wrapped method is called on an unopened `.Channel`. """ + @wraps(func) def _check(self, *args, **kwds): if ( - self.closed or - self.eof_received or - self.eof_sent or - not self.active + self.closed + or self.eof_received + or self.eof_sent + or not self.active ): - raise SSHException('Channel is not open') + raise SSHException("Channel is not open") return func(self, *args, **kwds) + return _check -class Channel (ClosingContextManager): +class Channel(ClosingContextManager): """ A secure tunnel across an SSH `.Transport`. A Channel is meant to behave like a socket, and has an API that should be indistinguishable from the @@ -117,7 +127,7 @@ class Channel (ClosingContextManager): self.in_window_sofar = 0 self.status_event = threading.Event() self._name = str(chanid) - self.logger = util.get_logger('paramiko.transport') + self.logger = util.get_logger("paramiko.transport") self._pipe = None self.event = threading.Event() self.event_ready = False @@ -135,24 +145,30 @@ class Channel (ClosingContextManager): """ Return a string representation of this object, for debugging. """ - out = '<paramiko.Channel {}'.format(self.chanid) + out = "<paramiko.Channel {}".format(self.chanid) if self.closed: - out += ' (closed)' + out += " (closed)" elif self.active: if self.eof_received: - out += ' (EOF received)' + out += " (EOF received)" if self.eof_sent: - out += ' (EOF sent)' - out += ' (open) window={}'.format(self.out_window_size) + out += " (EOF sent)" + out += " (open) window={}".format(self.out_window_size) if len(self.in_buffer) > 0: - out += ' in-buffer={}'.format(len(self.in_buffer)) - out += ' -> ' + repr(self.transport) - out += '>' + out += " in-buffer={}".format(len(self.in_buffer)) + out += " -> " + repr(self.transport) + out += ">" return out @open_only - def get_pty(self, term='vt100', width=80, height=24, width_pixels=0, - height_pixels=0): + def get_pty( + self, + term="vt100", + width=80, + height=24, + width_pixels=0, + height_pixels=0, + ): """ Request a pseudo-terminal from the server. This is usually used right after creating a client channel, to ask the server to provide some @@ -174,7 +190,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('pty-req') + m.add_string("pty-req") m.add_boolean(True) m.add_string(term) m.add_int(width) @@ -207,7 +223,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('shell') + m.add_string("shell") m.add_boolean(True) self._event_pending() self.transport._send_user_message(m) @@ -233,7 +249,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('exec') + m.add_string("exec") m.add_boolean(True) m.add_string(command) self._event_pending() @@ -259,7 +275,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('subsystem') + m.add_string("subsystem") m.add_boolean(True) m.add_string(subsystem) self._event_pending() @@ -284,7 +300,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('window-change') + m.add_string("window-change") m.add_boolean(False) m.add_int(width) m.add_int(height) @@ -315,7 +331,7 @@ class Channel (ClosingContextManager): try: self.set_environment_variable(name, value) except SSHException as e: - err = "Failed to set environment variable \"{}\"." + err = 'Failed to set environment variable "{}".' raise SSHException(err.format(name), e) @open_only @@ -339,7 +355,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('env') + m.add_string("env") m.add_boolean(False) m.add_string(name) m.add_string(value) @@ -403,19 +419,19 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('exit-status') + m.add_string("exit-status") m.add_boolean(False) m.add_int(status) self.transport._send_user_message(m) @open_only def request_x11( - self, - screen_number=0, - auth_protocol=None, - auth_cookie=None, - single_connection=False, - handler=None + self, + screen_number=0, + auth_protocol=None, + auth_cookie=None, + single_connection=False, + handler=None, ): """ Request an x11 session on this channel. If the server allows it, @@ -456,14 +472,14 @@ class Channel (ClosingContextManager): :return: the auth_cookie used """ if auth_protocol is None: - auth_protocol = 'MIT-MAGIC-COOKIE-1' + auth_protocol = "MIT-MAGIC-COOKIE-1" if auth_cookie is None: auth_cookie = binascii.hexlify(os.urandom(16)) m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('x11-req') + m.add_string("x11-req") m.add_boolean(True) m.add_boolean(single_connection) m.add_string(auth_protocol) @@ -493,7 +509,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('auth-agent-req@openssh.com') + m.add_string("auth-agent-req@openssh.com") m.add_boolean(False) self.transport._send_user_message(m) self.transport._set_forward_agent_handler(handler) @@ -808,7 +824,6 @@ class Channel (ClosingContextManager): m.add_int(1) return self._send(s, m) - def sendall(self, s): """ Send data to the channel, without allowing partial results. Unlike @@ -976,7 +991,7 @@ class Channel (ClosingContextManager): # a window update self.in_window_threshold = window_size // 10 self.in_window_sofar = 0 - self._log(DEBUG, 'Max packet in: {} bytes'.format(max_packet_size)) + self._log(DEBUG, "Max packet in: {} bytes".format(max_packet_size)) def _set_remote_channel(self, chanid, window_size, max_packet_size): self.remote_chanid = chanid @@ -985,11 +1000,12 @@ class Channel (ClosingContextManager): max_packet_size ) self.active = 1 - self._log(DEBUG, - 'Max packet out: {} bytes'.format(self.out_max_packet_size)) + self._log( + DEBUG, "Max packet out: {} bytes".format(self.out_max_packet_size) + ) def _request_success(self, m): - self._log(DEBUG, 'Sesch channel {} request ok'.format(self.chanid)) + self._log(DEBUG, "Sesch channel {} request ok".format(self.chanid)) self.event_ready = True self.event.set() return @@ -1017,8 +1033,7 @@ class Channel (ClosingContextManager): s = m.get_binary() if code != 1: self._log( - ERROR, - 'unknown extended_data type {}; discarding'.format(code) + ERROR, "unknown extended_data type {}; discarding".format(code) ) return if self.combine_stderr: @@ -1031,7 +1046,7 @@ class Channel (ClosingContextManager): self.lock.acquire() try: if self.ultra_debug: - self._log(DEBUG, 'window up {}'.format(nbytes)) + self._log(DEBUG, "window up {}".format(nbytes)) self.out_window_size += nbytes self.out_buffer_cv.notifyAll() finally: @@ -1042,14 +1057,14 @@ class Channel (ClosingContextManager): want_reply = m.get_boolean() server = self.transport.server_object ok = False - if key == 'exit-status': + if key == "exit-status": self.exit_status = m.get_int() self.status_event.set() ok = True - elif key == 'xon-xoff': + elif key == "xon-xoff": # ignore ok = True - elif key == 'pty-req': + elif key == "pty-req": term = m.get_string() width = m.get_int() height = m.get_int() @@ -1060,39 +1075,33 @@ class Channel (ClosingContextManager): ok = False else: ok = server.check_channel_pty_request( - self, - term, - width, - height, - pixelwidth, - pixelheight, - modes + self, term, width, height, pixelwidth, pixelheight, modes ) - elif key == 'shell': + elif key == "shell": if server is None: ok = False else: ok = server.check_channel_shell_request(self) - elif key == 'env': + elif key == "env": name = m.get_string() value = m.get_string() if server is None: ok = False else: ok = server.check_channel_env_request(self, name, value) - elif key == 'exec': + elif key == "exec": cmd = m.get_string() if server is None: ok = False else: ok = server.check_channel_exec_request(self, cmd) - elif key == 'subsystem': + elif key == "subsystem": name = m.get_text() if server is None: ok = False else: ok = server.check_channel_subsystem_request(self, name) - elif key == 'window-change': + elif key == "window-change": width = m.get_int() height = m.get_int() pixelwidth = m.get_int() @@ -1101,8 +1110,9 @@ class Channel (ClosingContextManager): ok = False else: ok = server.check_channel_window_change_request( - self, width, height, pixelwidth, pixelheight) - elif key == 'x11-req': + self, width, height, pixelwidth, pixelheight + ) + elif key == "x11-req": single_connection = m.get_boolean() auth_proto = m.get_text() auth_cookie = m.get_binary() @@ -1115,9 +1125,9 @@ class Channel (ClosingContextManager): single_connection, auth_proto, auth_cookie, - screen_number + screen_number, ) - elif key == 'auth-agent-req@openssh.com': + elif key == "auth-agent-req@openssh.com": if server is None: ok = False else: @@ -1145,7 +1155,7 @@ class Channel (ClosingContextManager): self._pipe.set_forever() finally: self.lock.release() - self._log(DEBUG, 'EOF received ({})'.format(self._name)) + self._log(DEBUG, "EOF received ({})".format(self._name)) def _handle_close(self, m): self.lock.acquire() @@ -1167,7 +1177,7 @@ class Channel (ClosingContextManager): if self.closed: # this doesn't seem useful, but it is the documented behavior # of Socket - raise socket.error('Socket is closed') + raise socket.error("Socket is closed") size = self._wait_for_send_window(size) if size == 0: # eof or similar @@ -1194,7 +1204,7 @@ class Channel (ClosingContextManager): return e = self.transport.get_exception() if e is None: - e = SSHException('Channel closed.') + e = SSHException("Channel closed.") raise e def _set_closed(self): @@ -1217,7 +1227,7 @@ class Channel (ClosingContextManager): m.add_byte(cMSG_CHANNEL_EOF) m.add_int(self.remote_chanid) self.eof_sent = True - self._log(DEBUG, 'EOF sent ({})'.format(self._name)) + self._log(DEBUG, "EOF sent ({})".format(self._name)) return m def _close_internal(self): @@ -1251,13 +1261,14 @@ class Channel (ClosingContextManager): if self.closed or self.eof_received or not self.active: return 0 if self.ultra_debug: - self._log(DEBUG, 'addwindow {}'.format(n)) + self._log(DEBUG, "addwindow {}".format(n)) self.in_window_sofar += n if self.in_window_sofar <= self.in_window_threshold: return 0 if self.ultra_debug: - self._log(DEBUG, - 'addwindow send {}'.format(self.in_window_sofar)) + self._log( + DEBUG, "addwindow send {}".format(self.in_window_sofar) + ) out = self.in_window_sofar self.in_window_sofar = 0 return out @@ -1300,11 +1311,11 @@ class Channel (ClosingContextManager): size = self.out_max_packet_size - 64 self.out_window_size -= size if self.ultra_debug: - self._log(DEBUG, 'window down to {}'.format(self.out_window_size)) + self._log(DEBUG, "window down to {}".format(self.out_window_size)) return size -class ChannelFile (BufferedFile): +class ChannelFile(BufferedFile): """ A file-like wrapper around `.Channel`. A ChannelFile is created by calling `Channel.makefile`. @@ -1317,7 +1328,7 @@ class ChannelFile (BufferedFile): flush the buffer. """ - def __init__(self, channel, mode='r', bufsize=-1): + def __init__(self, channel, mode="r", bufsize=-1): self.channel = channel BufferedFile.__init__(self) self._set_mode(mode, bufsize) @@ -1326,7 +1337,7 @@ class ChannelFile (BufferedFile): """ Returns a string representation of this object, for debugging. """ - return '<paramiko.ChannelFile from ' + repr(self.channel) + '>' + return "<paramiko.ChannelFile from " + repr(self.channel) + ">" def _read(self, size): return self.channel.recv(size) @@ -1336,8 +1347,9 @@ class ChannelFile (BufferedFile): return len(data) -class ChannelStderrFile (ChannelFile): - def __init__(self, channel, mode='r', bufsize=-1): +class ChannelStderrFile(ChannelFile): + + def __init__(self, channel, mode="r", bufsize=-1): ChannelFile.__init__(self, channel, mode, bufsize) def _read(self, size): diff --git a/paramiko/client.py b/paramiko/client.py index de0a495e..6bf479d4 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -38,13 +38,15 @@ from paramiko.hostkeys import HostKeys from paramiko.py3compat import string_types from paramiko.rsakey import RSAKey from paramiko.ssh_exception import ( - SSHException, BadHostKeyException, NoValidConnectionsError + SSHException, + BadHostKeyException, + NoValidConnectionsError, ) from paramiko.transport import Transport from paramiko.util import retry_on_signal, ClosingContextManager -class SSHClient (ClosingContextManager): +class SSHClient(ClosingContextManager): """ A high-level representation of a session with an SSH server. This class wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most @@ -97,7 +99,7 @@ class SSHClient (ClosingContextManager): """ if filename is None: # try the user's .ssh key file, and mask exceptions - filename = os.path.expanduser('~/.ssh/known_hosts') + filename = os.path.expanduser("~/.ssh/known_hosts") try: self._system_host_keys.load(filename) except IOError: @@ -140,12 +142,14 @@ class SSHClient (ClosingContextManager): if self._host_keys_filename is not None: self.load_host_keys(self._host_keys_filename) - with open(filename, 'w') as f: + with open(filename, "w") as f: for hostname, keys in self._host_keys.items(): for keytype, key in keys.items(): - f.write('{} {} {}\n'.format( - hostname, keytype, key.get_base64() - )) + f.write( + "{} {} {}\n".format( + hostname, keytype, key.get_base64() + ) + ) def get_host_keys(self): """ @@ -197,7 +201,8 @@ class SSHClient (ClosingContextManager): """ guess = True addrinfos = socket.getaddrinfo( - hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM) + hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) for (family, socktype, proto, canonname, sockaddr) in addrinfos: if socktype == socket.SOCK_STREAM: yield family, sockaddr @@ -419,8 +424,16 @@ class SSHClient (ClosingContextManager): key_filenames = key_filename self._auth( - username, password, pkey, key_filenames, allow_agent, - look_for_keys, gss_auth, gss_kex, gss_deleg_creds, t.gss_host, + username, + password, + pkey, + key_filenames, + allow_agent, + look_for_keys, + gss_auth, + gss_kex, + gss_deleg_creds, + t.gss_host, passphrase, ) @@ -490,13 +503,20 @@ class SSHClient (ClosingContextManager): if environment: chan.update_environment(environment) chan.exec_command(command) - stdin = chan.makefile('wb', bufsize) - stdout = chan.makefile('r', bufsize) - stderr = chan.makefile_stderr('r', bufsize) + stdin = chan.makefile("wb", bufsize) + stdout = chan.makefile("r", bufsize) + stderr = chan.makefile_stderr("r", bufsize) return stdin, stdout, stderr - def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0, - height_pixels=0, environment=None): + def invoke_shell( + self, + term="vt100", + width=80, + height=24, + width_pixels=0, + height_pixels=0, + environment=None, + ): """ Start an interactive shell session on the SSH server. A new `.Channel` is opened and connected to a pseudo-terminal using the requested @@ -545,10 +565,10 @@ class SSHClient (ClosingContextManager): - Otherwise, the filename is assumed to be a private key, and the matching public cert will be loaded if it exists. """ - cert_suffix = '-cert.pub' + cert_suffix = "-cert.pub" # Assume privkey, not cert, by default if filename.endswith(cert_suffix): - key_path = filename[:-len(cert_suffix)] + key_path = filename[: -len(cert_suffix)] cert_path = filename else: key_path = filename @@ -559,7 +579,7 @@ class SSHClient (ClosingContextManager): # when #387 is released, since this is a critical log message users are # likely testing/filtering for (bah.) msg = "Trying discovered key {} in {}".format( - hexlify(key.get_fingerprint()), key_path, + hexlify(key.get_fingerprint()), key_path ) self._log(DEBUG, msg) # Attempt to load cert if it exists. @@ -569,8 +589,17 @@ class SSHClient (ClosingContextManager): return key def _auth( - self, username, password, pkey, key_filenames, allow_agent, - look_for_keys, gss_auth, gss_kex, gss_deleg_creds, gss_host, + self, + username, + password, + pkey, + key_filenames, + allow_agent, + look_for_keys, + gss_auth, + gss_kex, + gss_deleg_creds, + gss_host, passphrase, ): """ @@ -589,7 +618,7 @@ class SSHClient (ClosingContextManager): saved_exception = None two_factor = False allowed_types = set() - two_factor_types = {'keyboard-interactive', 'password'} + two_factor_types = {"keyboard-interactive", "password"} if passphrase is None and password is not None: passphrase = password @@ -609,7 +638,7 @@ class SSHClient (ClosingContextManager): if gss_auth: try: return self._transport.auth_gssapi_with_mic( - username, gss_host, gss_deleg_creds, + username, gss_host, gss_deleg_creds ) except Exception as e: saved_exception = e @@ -618,11 +647,14 @@ class SSHClient (ClosingContextManager): try: self._log( DEBUG, - 'Trying SSH key {}'.format(hexlify(pkey.get_fingerprint())) + "Trying SSH key {}".format( + hexlify(pkey.get_fingerprint()) + ), ) allowed_types = set( - self._transport.auth_publickey(username, pkey)) - two_factor = (allowed_types & two_factor_types) + self._transport.auth_publickey(username, pkey) + ) + two_factor = allowed_types & two_factor_types if not two_factor: return except SSHException as e: @@ -633,11 +665,12 @@ class SSHClient (ClosingContextManager): for pkey_class in (RSAKey, DSSKey, ECDSAKey, Ed25519Key): try: key = self._key_from_filepath( - key_filename, pkey_class, passphrase, + key_filename, pkey_class, passphrase ) allowed_types = set( - self._transport.auth_publickey(username, key)) - two_factor = (allowed_types & two_factor_types) + self._transport.auth_publickey(username, key) + ) + two_factor = allowed_types & two_factor_types if not two_factor: return break @@ -651,12 +684,13 @@ class SSHClient (ClosingContextManager): for key in self._agent.get_keys(): try: id_ = hexlify(key.get_fingerprint()) - self._log(DEBUG, 'Trying SSH agent key {}'.format(id_)) + self._log(DEBUG, "Trying SSH agent key {}".format(id_)) # for 2-factor auth a successfully auth'd key password # will return an allowed 2fac auth method allowed_types = set( - self._transport.auth_publickey(username, key)) - two_factor = (allowed_types & two_factor_types) + self._transport.auth_publickey(username, key) + ) + two_factor = allowed_types & two_factor_types if not two_factor: return break @@ -680,8 +714,8 @@ class SSHClient (ClosingContextManager): if os.path.isfile(full_path): # TODO: only do this append if below did not run keyfiles.append((keytype, full_path)) - if os.path.isfile(full_path + '-cert.pub'): - keyfiles.append((keytype, full_path + '-cert.pub')) + if os.path.isfile(full_path + "-cert.pub"): + keyfiles.append((keytype, full_path + "-cert.pub")) if not look_for_keys: keyfiles = [] @@ -689,13 +723,14 @@ class SSHClient (ClosingContextManager): for pkey_class, filename in keyfiles: try: key = self._key_from_filepath( - filename, pkey_class, passphrase, + filename, pkey_class, passphrase ) # for 2-factor auth a successfully auth'd key will result # in ['password'] allowed_types = set( - self._transport.auth_publickey(username, key)) - two_factor = (allowed_types & two_factor_types) + self._transport.auth_publickey(username, key) + ) + two_factor = allowed_types & two_factor_types if not two_factor: return break @@ -718,13 +753,13 @@ class SSHClient (ClosingContextManager): # if we got an auth-failed exception earlier, re-raise it if saved_exception is not None: raise saved_exception - raise SSHException('No authentication methods available') + raise SSHException("No authentication methods available") def _log(self, level, msg): self._transport._log(level, msg) -class MissingHostKeyPolicy (object): +class MissingHostKeyPolicy(object): """ Interface for defining the policy that `.SSHClient` should use when the SSH server's hostname is not in either the system host keys or the @@ -745,7 +780,7 @@ class MissingHostKeyPolicy (object): pass -class AutoAddPolicy (MissingHostKeyPolicy): +class AutoAddPolicy(MissingHostKeyPolicy): """ Policy for automatically adding the hostname and new host key to the local `.HostKeys` object, and saving it. This is used by `.SSHClient`. @@ -755,32 +790,41 @@ class AutoAddPolicy (MissingHostKeyPolicy): client._host_keys.add(hostname, key.get_name(), key) if client._host_keys_filename is not None: client.save_host_keys(client._host_keys_filename) - client._log(DEBUG, 'Adding {} host key for {}: {}'.format( - key.get_name(), hostname, hexlify(key.get_fingerprint()), - )) + client._log( + DEBUG, + "Adding {} host key for {}: {}".format( + key.get_name(), hostname, hexlify(key.get_fingerprint()) + ), + ) -class RejectPolicy (MissingHostKeyPolicy): +class RejectPolicy(MissingHostKeyPolicy): """ Policy for automatically rejecting the unknown hostname & key. This is used by `.SSHClient`. """ def missing_host_key(self, client, hostname, key): - client._log(DEBUG, 'Rejecting {} host key for {}: {}'.format( - key.get_name(), hostname, hexlify(key.get_fingerprint()), - )) + client._log( + DEBUG, + "Rejecting {} host key for {}: {}".format( + key.get_name(), hostname, hexlify(key.get_fingerprint()) + ), + ) raise SSHException( - 'Server {!r} not found in known_hosts'.format(hostname) + "Server {!r} not found in known_hosts".format(hostname) ) -class WarningPolicy (MissingHostKeyPolicy): +class WarningPolicy(MissingHostKeyPolicy): """ Policy for logging a Python-style warning for an unknown host key, but accepting it. This is used by `.SSHClient`. """ + def missing_host_key(self, client, hostname, key): - warnings.warn('Unknown {} host key for {}: {}'.format( - key.get_name(), hostname, hexlify(key.get_fingerprint()), - )) + warnings.warn( + "Unknown {} host key for {}: {}".format( + key.get_name(), hostname, hexlify(key.get_fingerprint()) + ) + ) diff --git a/paramiko/common.py b/paramiko/common.py index eab6647e..7bd0cb10 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -22,22 +22,45 @@ Common constants and global variables. import logging from paramiko.py3compat import byte_chr, PY2, long, b -MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \ - MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT = range(1, 7) -MSG_KEXINIT, MSG_NEWKEYS = range(20, 22) -MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \ - MSG_USERAUTH_BANNER = range(50, 54) +( + MSG_DISCONNECT, + MSG_IGNORE, + MSG_UNIMPLEMENTED, + MSG_DEBUG, + MSG_SERVICE_REQUEST, + MSG_SERVICE_ACCEPT, +) = range(1, 7) +(MSG_KEXINIT, MSG_NEWKEYS) = range(20, 22) +( + MSG_USERAUTH_REQUEST, + MSG_USERAUTH_FAILURE, + MSG_USERAUTH_SUCCESS, + MSG_USERAUTH_BANNER, +) = range(50, 54) MSG_USERAUTH_PK_OK = 60 -MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62) -MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN = range(60, 62) -MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, MSG_USERAUTH_GSSAPI_ERROR,\ - MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC = range(63, 67) +(MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE) = range(60, 62) +(MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN) = range(60, 62) +( + MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, + MSG_USERAUTH_GSSAPI_ERROR, + MSG_USERAUTH_GSSAPI_ERRTOK, + MSG_USERAUTH_GSSAPI_MIC, +) = range(63, 67) HIGHEST_USERAUTH_MESSAGE_ID = 79 -MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83) -MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ - MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \ - MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ - MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) +(MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE) = range(80, 83) +( + MSG_CHANNEL_OPEN, + MSG_CHANNEL_OPEN_SUCCESS, + MSG_CHANNEL_OPEN_FAILURE, + MSG_CHANNEL_WINDOW_ADJUST, + MSG_CHANNEL_DATA, + MSG_CHANNEL_EXTENDED_DATA, + MSG_CHANNEL_EOF, + MSG_CHANNEL_CLOSE, + MSG_CHANNEL_REQUEST, + MSG_CHANNEL_SUCCESS, + MSG_CHANNEL_FAILURE, +) = range(90, 101) cMSG_DISCONNECT = byte_chr(MSG_DISCONNECT) cMSG_IGNORE = byte_chr(MSG_IGNORE) @@ -56,8 +79,9 @@ cMSG_USERAUTH_INFO_REQUEST = byte_chr(MSG_USERAUTH_INFO_REQUEST) cMSG_USERAUTH_INFO_RESPONSE = byte_chr(MSG_USERAUTH_INFO_RESPONSE) cMSG_USERAUTH_GSSAPI_RESPONSE = byte_chr(MSG_USERAUTH_GSSAPI_RESPONSE) cMSG_USERAUTH_GSSAPI_TOKEN = byte_chr(MSG_USERAUTH_GSSAPI_TOKEN) -cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = \ - byte_chr(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE) +cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = byte_chr( + MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE +) cMSG_USERAUTH_GSSAPI_ERROR = byte_chr(MSG_USERAUTH_GSSAPI_ERROR) cMSG_USERAUTH_GSSAPI_ERRTOK = byte_chr(MSG_USERAUTH_GSSAPI_ERRTOK) cMSG_USERAUTH_GSSAPI_MIC = byte_chr(MSG_USERAUTH_GSSAPI_MIC) @@ -78,47 +102,47 @@ cMSG_CHANNEL_FAILURE = byte_chr(MSG_CHANNEL_FAILURE) # for debugging: MSG_NAMES = { - MSG_DISCONNECT: 'disconnect', - MSG_IGNORE: 'ignore', - MSG_UNIMPLEMENTED: 'unimplemented', - MSG_DEBUG: 'debug', - MSG_SERVICE_REQUEST: 'service-request', - MSG_SERVICE_ACCEPT: 'service-accept', - MSG_KEXINIT: 'kexinit', - MSG_NEWKEYS: 'newkeys', - 30: 'kex30', - 31: 'kex31', - 32: 'kex32', - 33: 'kex33', - 34: 'kex34', - 40: 'kex40', - 41: 'kex41', - MSG_USERAUTH_REQUEST: 'userauth-request', - MSG_USERAUTH_FAILURE: 'userauth-failure', - MSG_USERAUTH_SUCCESS: 'userauth-success', - MSG_USERAUTH_BANNER: 'userauth--banner', - MSG_USERAUTH_PK_OK: 'userauth-60(pk-ok/info-request)', - MSG_USERAUTH_INFO_RESPONSE: 'userauth-info-response', - MSG_GLOBAL_REQUEST: 'global-request', - MSG_REQUEST_SUCCESS: 'request-success', - MSG_REQUEST_FAILURE: 'request-failure', - MSG_CHANNEL_OPEN: 'channel-open', - MSG_CHANNEL_OPEN_SUCCESS: 'channel-open-success', - MSG_CHANNEL_OPEN_FAILURE: 'channel-open-failure', - MSG_CHANNEL_WINDOW_ADJUST: 'channel-window-adjust', - MSG_CHANNEL_DATA: 'channel-data', - MSG_CHANNEL_EXTENDED_DATA: 'channel-extended-data', - MSG_CHANNEL_EOF: 'channel-eof', - MSG_CHANNEL_CLOSE: 'channel-close', - MSG_CHANNEL_REQUEST: 'channel-request', - MSG_CHANNEL_SUCCESS: 'channel-success', - MSG_CHANNEL_FAILURE: 'channel-failure', - MSG_USERAUTH_GSSAPI_RESPONSE: 'userauth-gssapi-response', - MSG_USERAUTH_GSSAPI_TOKEN: 'userauth-gssapi-token', - MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: 'userauth-gssapi-exchange-complete', - MSG_USERAUTH_GSSAPI_ERROR: 'userauth-gssapi-error', - MSG_USERAUTH_GSSAPI_ERRTOK: 'userauth-gssapi-error-token', - MSG_USERAUTH_GSSAPI_MIC: 'userauth-gssapi-mic' + MSG_DISCONNECT: "disconnect", + MSG_IGNORE: "ignore", + MSG_UNIMPLEMENTED: "unimplemented", + MSG_DEBUG: "debug", + MSG_SERVICE_REQUEST: "service-request", + MSG_SERVICE_ACCEPT: "service-accept", + MSG_KEXINIT: "kexinit", + MSG_NEWKEYS: "newkeys", + 30: "kex30", + 31: "kex31", + 32: "kex32", + 33: "kex33", + 34: "kex34", + 40: "kex40", + 41: "kex41", + MSG_USERAUTH_REQUEST: "userauth-request", + MSG_USERAUTH_FAILURE: "userauth-failure", + MSG_USERAUTH_SUCCESS: "userauth-success", + MSG_USERAUTH_BANNER: "userauth--banner", + MSG_USERAUTH_PK_OK: "userauth-60(pk-ok/info-request)", + MSG_USERAUTH_INFO_RESPONSE: "userauth-info-response", + MSG_GLOBAL_REQUEST: "global-request", + MSG_REQUEST_SUCCESS: "request-success", + MSG_REQUEST_FAILURE: "request-failure", + MSG_CHANNEL_OPEN: "channel-open", + MSG_CHANNEL_OPEN_SUCCESS: "channel-open-success", + MSG_CHANNEL_OPEN_FAILURE: "channel-open-failure", + MSG_CHANNEL_WINDOW_ADJUST: "channel-window-adjust", + MSG_CHANNEL_DATA: "channel-data", + MSG_CHANNEL_EXTENDED_DATA: "channel-extended-data", + MSG_CHANNEL_EOF: "channel-eof", + MSG_CHANNEL_CLOSE: "channel-close", + MSG_CHANNEL_REQUEST: "channel-request", + MSG_CHANNEL_SUCCESS: "channel-success", + MSG_CHANNEL_FAILURE: "channel-failure", + MSG_USERAUTH_GSSAPI_RESPONSE: "userauth-gssapi-response", + MSG_USERAUTH_GSSAPI_TOKEN: "userauth-gssapi-token", + MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: "userauth-gssapi-exchange-complete", + MSG_USERAUTH_GSSAPI_ERROR: "userauth-gssapi-error", + MSG_USERAUTH_GSSAPI_ERRTOK: "userauth-gssapi-error-token", + MSG_USERAUTH_GSSAPI_MIC: "userauth-gssapi-mic", } @@ -127,23 +151,28 @@ AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3) # channel request failed reasons: -(OPEN_SUCCEEDED, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, - OPEN_FAILED_CONNECT_FAILED, - OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, - OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5) +( + OPEN_SUCCEEDED, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_FAILED_CONNECT_FAILED, + OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, + OPEN_FAILED_RESOURCE_SHORTAGE, +) = range(0, 5) CONNECTION_FAILED_CODE = { - 1: 'Administratively prohibited', - 2: 'Connect failed', - 3: 'Unknown channel type', - 4: 'Resource shortage' + 1: "Administratively prohibited", + 2: "Connect failed", + 3: "Unknown channel type", + 4: "Resource shortage", } -DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \ - DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 +( + DISCONNECT_SERVICE_NOT_AVAILABLE, + DISCONNECT_AUTH_CANCELLED_BY_USER, + DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, +) = (7, 13, 14) zero_byte = byte_chr(0) one_byte = byte_chr(1) diff --git a/paramiko/compress.py b/paramiko/compress.py index 5073109c..cea2b308 100644 --- a/paramiko/compress.py +++ b/paramiko/compress.py @@ -23,7 +23,8 @@ Compression implementations for a Transport. import zlib -class ZlibCompressor (object): +class ZlibCompressor(object): + def __init__(self): # Use the default level of zlib compression self.z = zlib.compressobj() @@ -32,7 +33,8 @@ class ZlibCompressor (object): return self.z.compress(data) + self.z.flush(zlib.Z_FULL_FLUSH) -class ZlibDecompressor (object): +class ZlibDecompressor(object): + def __init__(self): self.z = zlib.decompressobj() diff --git a/paramiko/config.py b/paramiko/config.py index 038d84ea..21c9dab8 100644 --- a/paramiko/config.py +++ b/paramiko/config.py @@ -30,7 +30,7 @@ import socket SSH_PORT = 22 -class SSHConfig (object): +class SSHConfig(object): """ Representation of config information as stored in the format used by OpenSSH. Queries can be made via `lookup`. The format is described in @@ -41,7 +41,7 @@ class SSHConfig (object): .. versionadded:: 1.6 """ - SETTINGS_REGEX = re.compile(r'(\w+)(?:\s*=\s*|\s+)(.+)') + SETTINGS_REGEX = re.compile(r"(\w+)(?:\s*=\s*|\s+)(.+)") def __init__(self): """ @@ -55,12 +55,12 @@ class SSHConfig (object): :param file_obj: a file-like object to read the config file from """ - host = {"host": ['*'], "config": {}} + host = {"host": ["*"], "config": {}} for line in file_obj: # Strip any leading or trailing whitespace from the line. # Refer to https://github.com/paramiko/paramiko/issues/499 line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue match = re.match(self.SETTINGS_REGEX, line) @@ -69,17 +69,14 @@ class SSHConfig (object): key = match.group(1).lower() value = match.group(2) - if key == 'host': + if key == "host": self._config.append(host) - host = { - 'host': self._get_hosts(value), - 'config': {} - } - elif key == 'proxycommand' and value.lower() == 'none': + host = {"host": self._get_hosts(value), "config": {}} + elif key == "proxycommand" and value.lower() == "none": # Store 'none' as None; prior to 3.x, it will get stripped out # at the end (for compatibility with issue #415). After 3.x, it # will simply not get stripped, leaving a nice explicit marker. - host['config'][key] = None + host["config"][key] = None else: if value.startswith('"') and value.endswith('"'): value = value[1:-1] @@ -87,13 +84,13 @@ class SSHConfig (object): # identityfile, localforward, remoteforward keys are special # cases, since they are allowed to be specified multiple times # and they should be tried in order of specification. - if key in ['identityfile', 'localforward', 'remoteforward']: - if key in host['config']: - host['config'][key].append(value) + if key in ["identityfile", "localforward", "remoteforward"]: + if key in host["config"]: + host["config"][key].append(value) else: - host['config'][key] = [value] - elif key not in host['config']: - host['config'][key] = value + host["config"][key] = [value] + elif key not in host["config"]: + host["config"][key] = value self._config.append(host) def lookup(self, hostname): @@ -117,25 +114,26 @@ class SSHConfig (object): :param str hostname: the hostname to lookup """ matches = [ - config for config in self._config - if self._allowed(config['host'], hostname) + config + for config in self._config + if self._allowed(config["host"], hostname) ] ret = {} for match in matches: - for key, value in match['config'].items(): + for key, value in match["config"].items(): if key not in ret: # Create a copy of the original value, # else it will reference the original list # in self._config and update that value too # when the extend() is being called. ret[key] = value[:] if value is not None else value - elif key == 'identityfile': + elif key == "identityfile": ret[key].extend(value) ret = self._expand_variables(ret, hostname) # TODO: remove in 3.x re #670 - if 'proxycommand' in ret and ret['proxycommand'] is None: - del ret['proxycommand'] + if "proxycommand" in ret and ret["proxycommand"] is None: + del ret["proxycommand"] return ret def get_hostnames(self): @@ -145,13 +143,13 @@ class SSHConfig (object): """ hosts = set() for entry in self._config: - hosts.update(entry['host']) + hosts.update(entry["host"]) return hosts def _allowed(self, hosts, hostname): match = False for host in hosts: - if host.startswith('!') and fnmatch.fnmatch(hostname, host[1:]): + if host.startswith("!") and fnmatch.fnmatch(hostname, host[1:]): return False elif fnmatch.fnmatch(hostname, host): match = True @@ -169,52 +167,50 @@ class SSHConfig (object): :param str hostname: the hostname that the config belongs to """ - if 'hostname' in config: - config['hostname'] = config['hostname'].replace('%h', hostname) + if "hostname" in config: + config["hostname"] = config["hostname"].replace("%h", hostname) else: - config['hostname'] = hostname + config["hostname"] = hostname - if 'port' in config: - port = config['port'] + if "port" in config: + port = config["port"] else: port = SSH_PORT - user = os.getenv('USER') - if 'user' in config: - remoteuser = config['user'] + user = os.getenv("USER") + if "user" in config: + remoteuser = config["user"] else: remoteuser = user - host = socket.gethostname().split('.')[0] + host = socket.gethostname().split(".")[0] fqdn = LazyFqdn(config, host) - homedir = os.path.expanduser('~') - replacements = {'controlpath': - [ - ('%h', config['hostname']), - ('%l', fqdn), - ('%L', host), - ('%n', hostname), - ('%p', port), - ('%r', remoteuser), - ('%u', user) - ], - 'identityfile': - [ - ('~', homedir), - ('%d', homedir), - ('%h', config['hostname']), - ('%l', fqdn), - ('%u', user), - ('%r', remoteuser) - ], - 'proxycommand': - [ - ('~', homedir), - ('%h', config['hostname']), - ('%p', port), - ('%r', remoteuser) - ] - } + homedir = os.path.expanduser("~") + replacements = { + "controlpath": [ + ("%h", config["hostname"]), + ("%l", fqdn), + ("%L", host), + ("%n", hostname), + ("%p", port), + ("%r", remoteuser), + ("%u", user), + ], + "identityfile": [ + ("~", homedir), + ("%d", homedir), + ("%h", config["hostname"]), + ("%l", fqdn), + ("%u", user), + ("%r", remoteuser), + ], + "proxycommand": [ + ("~", homedir), + ("%h", config["hostname"]), + ("%p", port), + ("%r", remoteuser), + ], + } for k in config: if config[k] is None: @@ -265,11 +261,11 @@ class LazyFqdn(object): # Handle specific option fqdn = None - address_family = self.config.get('addressfamily', 'any').lower() - if address_family != 'any': + address_family = self.config.get("addressfamily", "any").lower() + if address_family != "any": try: family = socket.AF_INET6 - if address_family == 'inet': + if address_family == "inet": socket.AF_INET results = socket.getaddrinfo( self.host, @@ -277,11 +273,11 @@ class LazyFqdn(object): family, socket.SOCK_DGRAM, socket.IPPROTO_IP, - socket.AI_CANONNAME + socket.AI_CANONNAME, ) for res in results: af, socktype, proto, canonname, sa = res - if canonname and '.' in canonname: + if canonname and "." in canonname: fqdn = canonname break # giaerror -> socket.getaddrinfo() can't resolve self.host diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py index ac1d4c2e..ec358ee2 100644 --- a/paramiko/dsskey.py +++ b/paramiko/dsskey.py @@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives.asymmetric.utils import ( - decode_dss_signature, encode_dss_signature + decode_dss_signature, + encode_dss_signature, ) from paramiko import util @@ -42,8 +43,15 @@ class DSSKey(PKey): data. """ - def __init__(self, msg=None, data=None, filename=None, password=None, - vals=None, file_obj=None): + def __init__( + self, + msg=None, + data=None, + filename=None, + password=None, + vals=None, + file_obj=None, + ): self.p = None self.q = None self.g = None @@ -63,8 +71,8 @@ class DSSKey(PKey): else: self._check_type_and_load_cert( msg=msg, - key_type='ssh-dss', - cert_type='ssh-dss-cert-v01@openssh.com', + key_type="ssh-dss", + cert_type="ssh-dss-cert-v01@openssh.com", ) self.p = msg.get_mpint() self.q = msg.get_mpint() @@ -74,7 +82,7 @@ class DSSKey(PKey): def asbytes(self): m = Message() - m.add_string('ssh-dss') + m.add_string("ssh-dss") m.add_mpint(self.p) m.add_mpint(self.q) m.add_mpint(self.g) @@ -88,7 +96,7 @@ class DSSKey(PKey): return hash((self.get_name(), self.p, self.q, self.g, self.y)) def get_name(self): - return 'ssh-dss' + return "ssh-dss" def get_bits(self): return self.size @@ -102,17 +110,15 @@ class DSSKey(PKey): public_numbers=dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ) + p=self.p, q=self.q, g=self.g + ), + ), ).private_key(backend=default_backend()) sig = key.sign(data, hashes.SHA1()) r, s = decode_dss_signature(sig) m = Message() - m.add_string('ssh-dss') + m.add_string("ssh-dss") # apparently, in rare cases, r or s may be shorter than 20 bytes! rstr = util.deflate_long(r, 0) sstr = util.deflate_long(s, 0) @@ -129,7 +135,7 @@ class DSSKey(PKey): sig = msg.asbytes() else: kind = msg.get_text() - if kind != 'ssh-dss': + if kind != "ssh-dss": return 0 sig = msg.get_binary() @@ -142,10 +148,8 @@ class DSSKey(PKey): key = dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) + p=self.p, q=self.q, g=self.g + ), ).public_key(backend=default_backend()) try: key.verify(signature, data, hashes.SHA1()) @@ -160,18 +164,16 @@ class DSSKey(PKey): public_numbers=dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ) + p=self.p, q=self.q, g=self.g + ), + ), ).private_key(backend=default_backend()) self._write_private_key_file( filename, key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) def write_private_key(self, file_obj, password=None): @@ -180,18 +182,16 @@ class DSSKey(PKey): public_numbers=dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ) + p=self.p, q=self.q, g=self.g + ), + ), ).private_key(backend=default_backend()) self._write_private_key( file_obj, key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) @staticmethod @@ -207,23 +207,25 @@ class DSSKey(PKey): numbers = dsa.generate_private_key( bits, backend=default_backend() ).private_numbers() - key = DSSKey(vals=( - numbers.public_numbers.parameter_numbers.p, - numbers.public_numbers.parameter_numbers.q, - numbers.public_numbers.parameter_numbers.g, - numbers.public_numbers.y - )) + key = DSSKey( + vals=( + numbers.public_numbers.parameter_numbers.p, + numbers.public_numbers.parameter_numbers.q, + numbers.public_numbers.parameter_numbers.g, + numbers.public_numbers.y, + ) + ) key.x = numbers.x return key # ...internals... def _from_private_key_file(self, filename, password): - data = self._read_private_key_file('DSA', filename, password) + data = self._read_private_key_file("DSA", filename, password) self._decode_key(data) def _from_private_key(self, file_obj, password): - data = self._read_private_key('DSA', file_obj, password) + data = self._read_private_key("DSA", file_obj, password) self._decode_key(data) def _decode_key(self, data): @@ -232,14 +234,11 @@ class DSSKey(PKey): try: keylist = BER(data).decode() except BERException as e: - raise SSHException('Unable to parse key file: ' + str(e)) - if ( - type(keylist) is not list or - len(keylist) < 6 or - keylist[0] != 0 - ): + raise SSHException("Unable to parse key file: " + str(e)) + if type(keylist) is not list or len(keylist) < 6 or keylist[0] != 0: raise SSHException( - 'not a valid DSA private key file (bad ber encoding)') + "not a valid DSA private key file (bad ber encoding)" + ) self.p = keylist[1] self.q = keylist[2] self.g = keylist[3] diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py index 92e01a75..b73a969e 100644 --- a/paramiko/ecdsakey.py +++ b/paramiko/ecdsakey.py @@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric.utils import ( - decode_dss_signature, encode_dss_signature + decode_dss_signature, + encode_dss_signature, ) from paramiko.common import four_byte @@ -43,6 +44,7 @@ class _ECDSACurve(object): the proper hash function. Also grabs the proper curve from the 'ecdsa' package. """ + def __init__(self, curve_class, nist_name): self.nist_name = nist_name self.key_length = curve_class.key_size @@ -67,6 +69,7 @@ class _ECDSACurveSet(object): format identifier. The two ways in which ECDSAKey needs to be able to look up curves. """ + def __init__(self, ecdsa_curves): self.ecdsa_curves = ecdsa_curves @@ -95,14 +98,24 @@ class ECDSAKey(PKey): data. """ - _ECDSA_CURVES = _ECDSACurveSet([ - _ECDSACurve(ec.SECP256R1, 'nistp256'), - _ECDSACurve(ec.SECP384R1, 'nistp384'), - _ECDSACurve(ec.SECP521R1, 'nistp521'), - ]) - - def __init__(self, msg=None, data=None, filename=None, password=None, - vals=None, file_obj=None, validate_point=True): + _ECDSA_CURVES = _ECDSACurveSet( + [ + _ECDSACurve(ec.SECP256R1, "nistp256"), + _ECDSACurve(ec.SECP384R1, "nistp384"), + _ECDSACurve(ec.SECP521R1, "nistp521"), + ] + ) + + def __init__( + self, + msg=None, + data=None, + filename=None, + password=None, + vals=None, + file_obj=None, + validate_point=True, + ): self.verifying_key = None self.signing_key = None self.public_blob = None @@ -126,21 +139,18 @@ class ECDSAKey(PKey): # identifier, so strip out any cert business. (NOTE: could push # that into _ECDSACurveSet.get_by_key_format_identifier(), but it # feels more correct to do it here?) - suffix = '-cert-v01@openssh.com' + suffix = "-cert-v01@openssh.com" if key_type.endswith(suffix): - key_type = key_type[:-len(suffix)] + key_type = key_type[: -len(suffix)] self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier( key_type ) key_types = self._ECDSA_CURVES.get_key_format_identifier_list() cert_types = [ - '{}-cert-v01@openssh.com'.format(x) - for x in key_types + "{}-cert-v01@openssh.com".format(x) for x in key_types ] self._check_type_and_load_cert( - msg=msg, - key_type=key_types, - cert_type=cert_types, + msg=msg, key_type=key_types, cert_type=cert_types ) curvename = msg.get_text() if curvename != self.ecdsa_curve.nist_name: @@ -172,10 +182,10 @@ class ECDSAKey(PKey): key_size_bytes = (key.curve.key_size + 7) // 8 x_bytes = deflate_long(numbers.x, add_sign_padding=False) - x_bytes = b'\x00' * (key_size_bytes - len(x_bytes)) + x_bytes + x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes y_bytes = deflate_long(numbers.y, add_sign_padding=False) - y_bytes = b'\x00' * (key_size_bytes - len(y_bytes)) + y_bytes + y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes point_str = four_byte + x_bytes + y_bytes m.add_string(point_str) @@ -185,8 +195,13 @@ class ECDSAKey(PKey): return self.asbytes() def __hash__(self): - return hash((self.get_name(), self.verifying_key.public_numbers().x, - self.verifying_key.public_numbers().y)) + return hash( + ( + self.get_name(), + self.verifying_key.public_numbers().x, + self.verifying_key.public_numbers().y, + ) + ) def get_name(self): return self.ecdsa_curve.key_format_identifier @@ -228,7 +243,7 @@ class ECDSAKey(PKey): filename, self.signing_key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) def write_private_key(self, file_obj, password=None): @@ -236,7 +251,7 @@ class ECDSAKey(PKey): file_obj, self.signing_key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) @classmethod @@ -260,11 +275,11 @@ class ECDSAKey(PKey): # ...internals... def _from_private_key_file(self, filename, password): - data = self._read_private_key_file('EC', filename, password) + data = self._read_private_key_file("EC", filename, password) self._decode_key(data) def _from_private_key(self, file_obj, password): - data = self._read_private_key('EC', file_obj, password) + data = self._read_private_key("EC", file_obj, password) self._decode_key(data) def _decode_key(self, data): diff --git a/paramiko/ed25519key.py b/paramiko/ed25519key.py index 8ad71d08..68ada224 100644 --- a/paramiko/ed25519key.py +++ b/paramiko/ed25519key.py @@ -56,8 +56,10 @@ class Ed25519Key(PKey): .. versionchanged:: 2.3 Added a ``file_obj`` parameter to match other key classes. """ - def __init__(self, msg=None, data=None, filename=None, password=None, - file_obj=None): + + def __init__( + self, msg=None, data=None, filename=None, password=None, file_obj=None + ): self.public_blob = None verifying_key = signing_key = None if msg is None and data is not None: @@ -86,6 +88,7 @@ class Ed25519Key(PKey): def _parse_signing_key_data(self, data, password): from paramiko.transport import Transport + # We may eventually want this to be usable for other key types, as # OpenSSH moves to it, but for now this is just for Ed25519 keys. # This format is described here: @@ -142,9 +145,9 @@ class Ed25519Key(PKey): ignore_few_rounds=True, ) decryptor = Cipher( - cipher["class"](key[:cipher["key-size"]]), - cipher["mode"](key[cipher["key-size"]:]), - backend=default_backend() + cipher["class"](key[: cipher["key-size"]]), + cipher["mode"](key[cipher["key-size"] :]), + backend=default_backend(), ).decryptor() private_data = ( decryptor.update(private_ciphertext) + decryptor.finalize() @@ -166,8 +169,10 @@ class Ed25519Key(PKey): signing_key = nacl.signing.SigningKey(key_data[:32]) # Verify that all the public keys are the same... assert ( - signing_key.verify_key.encode() == public == public_keys[i] == - key_data[32:] + signing_key.verify_key.encode() + == public + == public_keys[i] + == key_data[32:] ) signing_keys.append(signing_key) # Comment, ignore. diff --git a/paramiko/file.py b/paramiko/file.py index df9cdac7..9e9f6eb8 100644 --- a/paramiko/file.py +++ b/paramiko/file.py @@ -16,14 +16,18 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. from paramiko.common import ( - linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value, + linefeed_byte_value, + crlf, + cr_byte, + linefeed_byte, + cr_byte_value, ) from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type from paramiko.util import ClosingContextManager -class BufferedFile (ClosingContextManager): +class BufferedFile(ClosingContextManager): """ Reusable base class to implement Python-style file buffering around a simpler stream. @@ -70,7 +74,7 @@ class BufferedFile (ClosingContextManager): :raises: ``ValueError`` -- if the file is closed. """ if self._closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return self def close(self): @@ -90,6 +94,7 @@ class BufferedFile (ClosingContextManager): return if PY2: + def next(self): """ Returns the next line from the input, or raises @@ -104,7 +109,9 @@ class BufferedFile (ClosingContextManager): if not line: raise StopIteration return line + else: + def __next__(self): """ Returns the next line from the input, or raises ``StopIteration`` @@ -159,7 +166,7 @@ class BufferedFile (ClosingContextManager): The number of bytes read. """ data = self.read(len(buff)) - buff[:len(data)] = data + buff[: len(data)] = data return len(data) def read(self, size=None): @@ -180,9 +187,9 @@ class BufferedFile (ClosingContextManager): encountered immediately """ if self._closed: - raise IOError('File is closed') + raise IOError("File is closed") if not (self._flags & self.FLAG_READ): - raise IOError('File is not open for reading') + raise IOError("File is not open for reading") if (size is None) or (size < 0): # go for broke result = self._rbuffer @@ -245,16 +252,16 @@ class BufferedFile (ClosingContextManager): """ # it's almost silly how complex this function is. if self._closed: - raise IOError('File is closed') + raise IOError("File is closed") if not (self._flags & self.FLAG_READ): - raise IOError('File not open for reading') + raise IOError("File not open for reading") line = self._rbuffer truncated = False while True: if ( - self._at_trailing_cr and - self._flags & self.FLAG_UNIVERSAL_NEWLINE and - len(line) > 0 + self._at_trailing_cr + and self._flags & self.FLAG_UNIVERSAL_NEWLINE + and len(line) > 0 ): # edge case: the newline may be '\r\n' and we may have read # only the first '\r' last time. @@ -276,12 +283,8 @@ class BufferedFile (ClosingContextManager): n = size - len(line) else: n = self._bufsize - if ( - linefeed_byte in line or - ( - self._flags & self.FLAG_UNIVERSAL_NEWLINE and - cr_byte in line - ) + if linefeed_byte in line or ( + self._flags & self.FLAG_UNIVERSAL_NEWLINE and cr_byte in line ): break try: @@ -306,9 +309,9 @@ class BufferedFile (ClosingContextManager): return line if self._flags & self.FLAG_BINARY else u(line) xpos = pos + 1 if ( - line[pos] == cr_byte_value and - xpos < len(line) and - line[xpos] == linefeed_byte_value + line[pos] == cr_byte_value + and xpos < len(line) + and line[xpos] == linefeed_byte_value ): xpos += 1 # if the string was truncated, _rbuffer needs to have the string after @@ -370,7 +373,7 @@ class BufferedFile (ClosingContextManager): :raises: ``IOError`` -- if the file doesn't support random access. """ - raise IOError('File does not support seeking.') + raise IOError("File does not support seeking.") def tell(self): """ @@ -393,11 +396,11 @@ class BufferedFile (ClosingContextManager): """ if isinstance(data, text_type): # Accept text and encode as utf-8 for compatibility only. - data = data.encode('utf-8') + data = data.encode("utf-8") if self._closed: - raise IOError('File is closed') + raise IOError("File is closed") if not (self._flags & self.FLAG_WRITE): - raise IOError('File not open for writing') + raise IOError("File not open for writing") if not (self._flags & self.FLAG_BUFFERED): self._write_all(data) return @@ -408,9 +411,9 @@ class BufferedFile (ClosingContextManager): if last_newline_pos >= 0: wbuf = self._wbuffer.getvalue() last_newline_pos += len(wbuf) - len(data) - self._write_all(wbuf[:last_newline_pos + 1]) + self._write_all(wbuf[: last_newline_pos + 1]) self._wbuffer = BytesIO() - self._wbuffer.write(wbuf[last_newline_pos + 1:]) + self._wbuffer.write(wbuf[last_newline_pos + 1 :]) return # even if we're line buffering, if the buffer has grown past the # buffer size, force a flush. @@ -457,7 +460,7 @@ class BufferedFile (ClosingContextManager): (subclass override) Write data into the stream. """ - raise IOError('write not implemented') + raise IOError("write not implemented") def _get_size(self): """ @@ -472,7 +475,7 @@ class BufferedFile (ClosingContextManager): # ...internals... - def _set_mode(self, mode='r', bufsize=-1): + def _set_mode(self, mode="r", bufsize=-1): """ Subclasses call this method to initialize the BufferedFile. """ @@ -495,17 +498,17 @@ class BufferedFile (ClosingContextManager): # unbuffered self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED) - if ('r' in mode) or ('+' in mode): + if ("r" in mode) or ("+" in mode): self._flags |= self.FLAG_READ - if ('w' in mode) or ('+' in mode): + if ("w" in mode) or ("+" in mode): self._flags |= self.FLAG_WRITE - if 'a' in mode: + if "a" in mode: self._flags |= self.FLAG_WRITE | self.FLAG_APPEND self._size = self._get_size() self._pos = self._realpos = self._size - if 'b' in mode: + if "b" in mode: self._flags |= self.FLAG_BINARY - if 'U' in mode: + if "U" in mode: self._flags |= self.FLAG_UNIVERSAL_NEWLINE # built-in file objects have this attribute to store which kinds of # line terminations they've seen: @@ -534,9 +537,8 @@ class BufferedFile (ClosingContextManager): return if self.newlines is None: self.newlines = newline - elif ( - self.newlines != newline and - isinstance(self.newlines, bytes_types) + elif self.newlines != newline and isinstance( + self.newlines, bytes_types ): self.newlines = (self.newlines, newline) elif newline not in self.newlines: diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index ca185273..c4d611db 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -34,7 +34,7 @@ from paramiko.ed25519key import Ed25519Key from paramiko.ssh_exception import SSHException -class HostKeys (MutableMapping): +class HostKeys(MutableMapping): """ Representation of an OpenSSH-style "known hosts" file. Host keys can be read from one or more files, and then individual hosts can be looked up to @@ -88,10 +88,10 @@ class HostKeys (MutableMapping): :raises: ``IOError`` -- if there was an error reading the file """ - with open(filename, 'r') as f: + with open(filename, "r") as f: for lineno, line in enumerate(f, 1): line = line.strip() - if (len(line) == 0) or (line[0] == '#'): + if (len(line) == 0) or (line[0] == "#"): continue try: e = HostKeyEntry.from_line(line, lineno) @@ -118,7 +118,7 @@ class HostKeys (MutableMapping): .. versionadded:: 1.6.1 """ - with open(filename, 'w') as f: + with open(filename, "w") as f: for e in self._entries: line = e.to_line() if line: @@ -134,7 +134,9 @@ class HostKeys (MutableMapping): :return: dict of `str` -> `.PKey` keys associated with this host (or ``None``) """ - class SubDict (MutableMapping): + + class SubDict(MutableMapping): + def __init__(self, hostname, entries, hostkeys): self._hostname = hostname self._entries = entries @@ -176,7 +178,8 @@ class HostKeys (MutableMapping): def keys(self): return [ - e.key.get_name() for e in self._entries + e.key.get_name() + for e in self._entries if e.key is not None ] @@ -196,10 +199,10 @@ class HostKeys (MutableMapping): """ for h in entry.hostnames: if ( - h == hostname or - h.startswith('|1|') and - not hostname.startswith('|1|') and - constant_time_bytes_eq(self.hash_host(hostname, h), h) + h == hostname + or h.startswith("|1|") + and not hostname.startswith("|1|") + and constant_time_bytes_eq(self.hash_host(hostname, h), h) ): return True return False @@ -295,16 +298,17 @@ class HostKeys (MutableMapping): if salt is None: salt = os.urandom(sha1().digest_size) else: - if salt.startswith('|1|'): - salt = salt.split('|')[2] + if salt.startswith("|1|"): + salt = salt.split("|")[2] salt = decodebytes(b(salt)) assert len(salt) == sha1().digest_size hmac = HMAC(salt, b(hostname), sha1).digest() - hostkey = '|1|{}|{}'.format(u(encodebytes(salt)), u(encodebytes(hmac))) - return hostkey.replace('\n', '') + hostkey = "|1|{}|{}".format(u(encodebytes(salt)), u(encodebytes(hmac))) + return hostkey.replace("\n", "") class InvalidHostKey(Exception): + def __init__(self, line, exc): self.line = line self.exc = exc @@ -334,8 +338,8 @@ class HostKeyEntry: :param str line: a line from an OpenSSH known_hosts file """ - log = get_logger('paramiko.hostkeys') - fields = line.split(' ') + log = get_logger("paramiko.hostkeys") + fields = line.split(" ") if len(fields) < 3: # Bad number of fields msg = "Not enough fields found in known_hosts in line {} ({!r})" @@ -344,19 +348,19 @@ class HostKeyEntry: fields = fields[:3] names, keytype, key = fields - names = names.split(',') + names = names.split(",") # Decide what kind of key we're looking at and create an object # to hold it accordingly. try: key = b(key) - if keytype == 'ssh-rsa': + if keytype == "ssh-rsa": key = RSAKey(data=decodebytes(key)) - elif keytype == 'ssh-dss': + elif keytype == "ssh-dss": key = DSSKey(data=decodebytes(key)) elif keytype in ECDSAKey.supported_key_format_identifiers(): key = ECDSAKey(data=decodebytes(key), validate_point=False) - elif keytype == 'ssh-ed25519': + elif keytype == "ssh-ed25519": key = Ed25519Key(data=decodebytes(key)) else: log.info("Unable to handle key of type {}".format(keytype)) @@ -374,12 +378,12 @@ class HostKeyEntry: included. """ if self.valid: - return '{} {} {}\n'.format( - ','.join(self.hostnames), + return "{} {} {}\n".format( + ",".join(self.hostnames), self.key.get_name(), self.key.get_base64(), ) return None def __repr__(self): - return '<HostKeyEntry {!r}: {!r}>'.format(self.hostnames, self.key) + return "<HostKeyEntry {!r}: {!r}>".format(self.hostnames, self.key) diff --git a/paramiko/kex_ecdh_nist.py b/paramiko/kex_ecdh_nist.py index 4e8ff35d..1d87442a 100644 --- a/paramiko/kex_ecdh_nist.py +++ b/paramiko/kex_ecdh_nist.py @@ -15,7 +15,7 @@ _MSG_KEXECDH_INIT, _MSG_KEXECDH_REPLY = range(30, 32) c_MSG_KEXECDH_INIT, c_MSG_KEXECDH_REPLY = [byte_chr(c) for c in range(30, 32)] -class KexNistp256(): +class KexNistp256: name = "ecdh-sha2-nistp256" hash_algo = sha256 @@ -46,7 +46,7 @@ class KexNistp256(): elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY): return self._parse_kexecdh_reply(m) raise SSHException( - 'KexECDH asked to handle packet type {:d}'.format(ptype) + "KexECDH asked to handle packet type {:d}".format(ptype) ) def _generate_key_pair(self): @@ -66,8 +66,12 @@ class KexNistp256(): K = long(hexlify(K), 16) # compute exchange hash hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + ) hm.add_string(K_S) hm.add_string(Q_C_bytes) # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion @@ -96,8 +100,12 @@ class KexNistp256(): K = long(hexlify(K), 16) # compute exchange hash and verify signature hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + ) hm.add_string(K_S) # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion hm.add_string(self.Q_C.public_numbers().encode_point()) diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py index 44030569..fb8f01fd 100644 --- a/paramiko/kex_gex.py +++ b/paramiko/kex_gex.py @@ -32,17 +32,26 @@ from paramiko.py3compat import byte_chr, byte_ord, byte_mask from paramiko.ssh_exception import SSHException -_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ - _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) +( + _MSG_KEXDH_GEX_REQUEST_OLD, + _MSG_KEXDH_GEX_GROUP, + _MSG_KEXDH_GEX_INIT, + _MSG_KEXDH_GEX_REPLY, + _MSG_KEXDH_GEX_REQUEST, +) = range(30, 35) -c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \ - c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = \ - [byte_chr(c) for c in range(30, 35)] +( + c_MSG_KEXDH_GEX_REQUEST_OLD, + c_MSG_KEXDH_GEX_GROUP, + c_MSG_KEXDH_GEX_INIT, + c_MSG_KEXDH_GEX_REPLY, + c_MSG_KEXDH_GEX_REQUEST, +) = [byte_chr(c) for c in range(30, 35)] -class KexGex (object): +class KexGex(object): - name = 'diffie-hellman-group-exchange-sha1' + name = "diffie-hellman-group-exchange-sha1" min_bits = 1024 max_bits = 8192 preferred_bits = 2048 @@ -61,7 +70,8 @@ class KexGex (object): def start_kex(self, _test_old_style=False): if self.transport.server_mode: self.transport._expect_packet( - _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD) + _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD + ) return # request a bit range: we accept (min_bits) to (max_bits), but prefer # (preferred_bits). according to the spec, we shouldn't pull the @@ -137,13 +147,12 @@ class KexGex (object): # generate prime pack = self.transport._get_modulus_pack() if pack is None: - raise SSHException( - 'Can\'t do server-side gex with no modulus pack') + raise SSHException("Can't do server-side gex with no modulus pack") self.transport._log( DEBUG, - 'Picking p ({} <= {} <= {} bits)'.format( - minbits, preferredbits, maxbits, - ) + "Picking p ({} <= {} <= {} bits)".format( + minbits, preferredbits, maxbits + ), ) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) m = Message() @@ -165,13 +174,13 @@ class KexGex (object): # generate prime pack = self.transport._get_modulus_pack() if pack is None: - raise SSHException( - 'Can\'t do server-side gex with no modulus pack') + raise SSHException("Can't do server-side gex with no modulus pack") self.transport._log( - DEBUG, 'Picking p (~ {} bits)'.format(self.preferred_bits) + DEBUG, "Picking p (~ {} bits)".format(self.preferred_bits) ) self.g, self.p = pack.get_modulus( - self.min_bits, self.preferred_bits, self.max_bits) + self.min_bits, self.preferred_bits, self.max_bits + ) m = Message() m.add_byte(c_MSG_KEXDH_GEX_GROUP) m.add_mpint(self.p) @@ -187,9 +196,10 @@ class KexGex (object): bitlen = util.bit_length(self.p) if (bitlen < 1024) or (bitlen > 8192): raise SSHException( - 'Server-generated gex p (don\'t ask) is out of range ' - '({} bits)'.format(bitlen)) - self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen)) + "Server-generated gex p (don't ask) is out of range " + "({} bits)".format(bitlen) + ) + self.transport._log(DEBUG, "Got server p ({} bits)".format(bitlen)) self._generate_x() # now compute e = g^x mod p self.e = pow(self.g, self.x, self.p) @@ -210,9 +220,13 @@ class KexGex (object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init, - key) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + key, + ) if not self.old_style: hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) @@ -246,9 +260,13 @@ class KexGex (object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init, - host_key) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + host_key, + ) if not self.old_style: hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) @@ -265,5 +283,5 @@ class KexGex (object): class KexGexSHA256(KexGex): - name = 'diffie-hellman-group-exchange-sha256' + name = "diffie-hellman-group-exchange-sha256" hash_algo = sha256 diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py index 35e30e2e..e0d0bfe4 100644 --- a/paramiko/kex_group1.py +++ b/paramiko/kex_group1.py @@ -41,10 +41,12 @@ b0000000000000000 = zero_byte * 8 class KexGroup1(object): # draft-ietf-secsh-transport-09.txt, page 17 - P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa + P = ( + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa + ) G = 2 - name = 'diffie-hellman-group1-sha1' + name = "diffie-hellman-group1-sha1" hash_algo = sha1 def __init__(self, transport): @@ -88,8 +90,10 @@ class KexGroup1(object): while 1: x_bytes = os.urandom(128) x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:] - if (x_bytes[:8] != b7fffffffffffffff and - x_bytes[:8] != b0000000000000000): + if ( + x_bytes[:8] != b7fffffffffffffff + and x_bytes[:8] != b0000000000000000 + ): break self.x = util.inflate_long(x_bytes) @@ -104,8 +108,12 @@ class KexGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + ) hm.add_string(host_key) hm.add_mpint(self.e) hm.add_mpint(self.f) @@ -124,8 +132,12 @@ class KexGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + ) hm.add_string(key) hm.add_mpint(self.e) hm.add_mpint(self.f) diff --git a/paramiko/kex_group14.py b/paramiko/kex_group14.py index a13e192a..d1f0fc2e 100644 --- a/paramiko/kex_group14.py +++ b/paramiko/kex_group14.py @@ -28,10 +28,12 @@ from hashlib import sha1, sha256 class KexGroup14(KexGroup1): # http://tools.ietf.org/html/rfc3526#section-3 - P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa + P = ( + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa + ) G = 2 - name = 'diffie-hellman-group14-sha1' + name = "diffie-hellman-group14-sha1" hash_algo = sha1 diff --git a/paramiko/kex_gss.py b/paramiko/kex_gss.py index e21620fe..1510ff9c 100644 --- a/paramiko/kex_gss.py +++ b/paramiko/kex_gss.py @@ -47,14 +47,22 @@ from paramiko.py3compat import byte_chr, byte_mask, byte_ord from paramiko.ssh_exception import SSHException -MSG_KEXGSS_INIT, MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_HOSTKEY,\ - MSG_KEXGSS_ERROR = range(30, 35) -MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP = range(40, 42) -c_MSG_KEXGSS_INIT, c_MSG_KEXGSS_CONTINUE, c_MSG_KEXGSS_COMPLETE,\ - c_MSG_KEXGSS_HOSTKEY, c_MSG_KEXGSS_ERROR = [ - byte_chr(c) for c in range(30, 35) - ] -c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP = [ +( + MSG_KEXGSS_INIT, + MSG_KEXGSS_CONTINUE, + MSG_KEXGSS_COMPLETE, + MSG_KEXGSS_HOSTKEY, + MSG_KEXGSS_ERROR, +) = range(30, 35) +(MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP) = range(40, 42) +( + c_MSG_KEXGSS_INIT, + c_MSG_KEXGSS_CONTINUE, + c_MSG_KEXGSS_COMPLETE, + c_MSG_KEXGSS_HOSTKEY, + c_MSG_KEXGSS_ERROR, +) = [byte_chr(c) for c in range(30, 35)] +(c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP) = [ byte_chr(c) for c in range(40, 42) ] @@ -65,7 +73,9 @@ class KexGSSGroup1(object): 4462 Section 2 <https://tools.ietf.org/html/rfc4462.html#section-2>`_ """ # draft-ietf-secsh-transport-09.txt, page 17 - P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa + P = ( + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa + ) G = 2 b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7 # noqa b0000000000000000 = zero_byte * 8 # noqa @@ -98,10 +108,12 @@ class KexGSSGroup1(object): m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host)) m.add_mpint(self.e) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_HOSTKEY, - MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_HOSTKEY, + MSG_KEXGSS_CONTINUE, + MSG_KEXGSS_COMPLETE, + MSG_KEXGSS_ERROR, + ) def parse_next(self, ptype, m): """ @@ -120,7 +132,7 @@ class KexGSSGroup1(object): return self._parse_kexgss_complete(m) elif ptype == MSG_KEXGSS_ERROR: return self._parse_kexgss_error(m) - msg = 'GSS KexGroup1 asked to handle packet type {:d}' + msg = "GSS KexGroup1 asked to handle packet type {:d}" raise SSHException(msg.format(ptype)) # ## internals... @@ -152,8 +164,7 @@ class KexGSSGroup1(object): self.transport.host_key = host_key sig = m.get_string() self.transport._verify_key(host_key, sig) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE) + self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE) def _parse_kexgss_continue(self, m): """ @@ -166,13 +177,14 @@ class KexGSSGroup1(object): srv_token = m.get_string() m = Message() m.add_byte(c_MSG_KEXGSS_CONTINUE) - m.add_string(self.kexgss.ssh_init_sec_context( - target=self.gss_host, recv_token=srv_token)) + m.add_string( + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) + ) self.transport.send_message(m) self.transport._expect_packet( - MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR ) else: pass @@ -200,8 +212,12 @@ class KexGSSGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + ) hm.add_string(self.transport.host_key.__str__()) hm.add_mpint(self.e) hm.add_mpint(self.f) @@ -209,8 +225,9 @@ class KexGSSGroup1(object): H = sha1(str(hm)).digest() self.transport._set_K_H(K, H) if srv_token is not None: - self.kexgss.ssh_init_sec_context(target=self.gss_host, - recv_token=srv_token) + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) self.kexgss.ssh_check_mic(mic_token, H) else: self.kexgss.ssh_check_mic(mic_token, H) @@ -234,20 +251,26 @@ class KexGSSGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + ) hm.add_string(key) hm.add_mpint(self.e) hm.add_mpint(self.f) hm.add_mpint(K) H = sha1(hm.asbytes()).digest() self.transport._set_K_H(K, H) - srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host, - client_token) + srv_token = self.kexgss.ssh_accept_sec_context( + self.gss_host, client_token + ) m = Message() if self.kexgss._gss_srv_ctxt_status: - mic_token = self.kexgss.ssh_get_mic(self.transport.session_id, - gss_kex=True) + mic_token = self.kexgss.ssh_get_mic( + self.transport.session_id, gss_kex=True + ) m.add_byte(c_MSG_KEXGSS_COMPLETE) m.add_mpint(self.f) m.add_string(mic_token) @@ -263,9 +286,9 @@ class KexGSSGroup1(object): m.add_byte(c_MSG_KEXGSS_CONTINUE) m.add_string(srv_token) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR + ) def _parse_kexgss_error(self, m): """ @@ -281,12 +304,16 @@ class KexGSSGroup1(object): maj_status = m.get_int() min_status = m.get_int() err_msg = m.get_string() - m.get_string() # we don't care about the language! - raise SSHException("""GSS-API Error: + m.get_string() # we don't care about the language! + raise SSHException( + """GSS-API Error: Major Status: {} Minor Status: {} Error Message: {} -""".format(maj_status, min_status, err_msg)) +""".format( + maj_status, min_status, err_msg + ) + ) class KexGSSGroup14(KexGSSGroup1): @@ -295,7 +322,9 @@ class KexGSSGroup14(KexGSSGroup1): in `RFC 4462 Section 2 <https://tools.ietf.org/html/rfc4462.html#section-2>`_ """ - P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa + P = ( + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa + ) G = 2 NAME = "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==" @@ -362,7 +391,7 @@ class KexGSSGex(object): return self._parse_kexgss_complete(m) elif ptype == MSG_KEXGSS_ERROR: return self._parse_kexgss_error(m) - msg = 'KexGex asked to handle packet type {:d}' + msg = "KexGex asked to handle packet type {:d}" raise SSHException(msg.format(ptype)) # ## internals... @@ -414,13 +443,12 @@ class KexGSSGex(object): # generate prime pack = self.transport._get_modulus_pack() if pack is None: - raise SSHException( - 'Can\'t do server-side gex with no modulus pack') + raise SSHException("Can't do server-side gex with no modulus pack") self.transport._log( DEBUG, # noqa - 'Picking p ({} <= {} <= {} bits)'.format( - minbits, preferredbits, maxbits, - ) + "Picking p ({} <= {} <= {} bits)".format( + minbits, preferredbits, maxbits + ), ) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) m = Message() @@ -442,9 +470,12 @@ class KexGSSGex(object): bitlen = util.bit_length(self.p) if (bitlen < 1024) or (bitlen > 8192): raise SSHException( - 'Server-generated gex p (don\'t ask) is out of range ' - '({} bits)'.format(bitlen)) - self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen)) # noqa + "Server-generated gex p (don't ask) is out of range " + "({} bits)".format(bitlen) + ) + self.transport._log( + DEBUG, "Got server p ({} bits)".format(bitlen) + ) # noqa self._generate_x() # now compute e = g^x mod p self.e = pow(self.g, self.x, self.p) @@ -453,10 +484,12 @@ class KexGSSGex(object): m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host)) m.add_mpint(self.e) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_HOSTKEY, - MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_HOSTKEY, + MSG_KEXGSS_CONTINUE, + MSG_KEXGSS_COMPLETE, + MSG_KEXGSS_ERROR, + ) def _parse_kexgss_gex_init(self, m): """ @@ -476,9 +509,13 @@ class KexGSSGex(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init, - key) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + key, + ) hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) hm.add_int(self.max_bits) @@ -489,12 +526,14 @@ class KexGSSGex(object): hm.add_mpint(K) H = sha1(hm.asbytes()).digest() self.transport._set_K_H(K, H) - srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host, - client_token) + srv_token = self.kexgss.ssh_accept_sec_context( + self.gss_host, client_token + ) m = Message() if self.kexgss._gss_srv_ctxt_status: - mic_token = self.kexgss.ssh_get_mic(self.transport.session_id, - gss_kex=True) + mic_token = self.kexgss.ssh_get_mic( + self.transport.session_id, gss_kex=True + ) m.add_byte(c_MSG_KEXGSS_COMPLETE) m.add_mpint(self.f) m.add_string(mic_token) @@ -510,9 +549,9 @@ class KexGSSGex(object): m.add_byte(c_MSG_KEXGSS_CONTINUE) m.add_string(srv_token) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR + ) def _parse_kexgss_hostkey(self, m): """ @@ -525,8 +564,7 @@ class KexGSSGex(object): self.transport.host_key = host_key sig = m.get_string() self.transport._verify_key(host_key, sig) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE) + self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE) def _parse_kexgss_continue(self, m): """ @@ -538,12 +576,15 @@ class KexGSSGex(object): srv_token = m.get_string() m = Message() m.add_byte(c_MSG_KEXGSS_CONTINUE) - m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host, - recv_token=srv_token)) + m.add_string( + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) + ) self.transport.send_message(m) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR + ) else: pass @@ -568,9 +609,13 @@ class KexGSSGex(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init, - self.transport.host_key.__str__()) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + self.transport.host_key.__str__(), + ) if not self.old_style: hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) @@ -584,8 +629,9 @@ class KexGSSGex(object): H = sha1(hm.asbytes()).digest() self.transport._set_K_H(K, H) if srv_token is not None: - self.kexgss.ssh_init_sec_context(target=self.gss_host, - recv_token=srv_token) + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) self.kexgss.ssh_check_mic(mic_token, H) else: self.kexgss.ssh_check_mic(mic_token, H) @@ -606,12 +652,16 @@ class KexGSSGex(object): maj_status = m.get_int() min_status = m.get_int() err_msg = m.get_string() - m.get_string() # we don't care about the language (lang_tag)! - raise SSHException("""GSS-API Error: + m.get_string() # we don't care about the language (lang_tag)! + raise SSHException( + """GSS-API Error: Major Status: {} Minor Status: {} Error Message: {} -""".format(maj_status, min_status, err_msg)) +""".format( + maj_status, min_status, err_msg + ) + ) class NullHostKey(object): @@ -620,6 +670,7 @@ class NullHostKey(object): in `RFC 4462 Section 5 <https://tools.ietf.org/html/rfc4462.html#section-5>`_ """ + def __init__(self): self.key = "" diff --git a/paramiko/message.py b/paramiko/message.py index 9af841da..dead3508 100644 --- a/paramiko/message.py +++ b/paramiko/message.py @@ -27,7 +27,7 @@ from paramiko.common import zero_byte, max_byte, one_byte, asbytes from paramiko.py3compat import long, BytesIO, u, integer_types -class Message (object): +class Message(object): """ An SSH2 message is a stream of bytes that encodes some combination of strings, integers, bools, and infinite-precision integers (known in Python @@ -63,7 +63,7 @@ class Message (object): """ Returns a string representation of this object, for debugging. """ - return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')' + return "paramiko.Message(" + repr(self.packet.getvalue()) + ")" def asbytes(self): """ @@ -139,13 +139,13 @@ class Message (object): if byte == max_byte: return util.inflate_long(self.get_binary()) byte += self.get_bytes(3) - return struct.unpack('>I', byte)[0] + return struct.unpack(">I", byte)[0] def get_int(self): """ Fetch an int from the stream. """ - return struct.unpack('>I', self.get_bytes(4))[0] + return struct.unpack(">I", self.get_bytes(4))[0] def get_int64(self): """ @@ -153,7 +153,7 @@ class Message (object): :return: a 64-bit unsigned integer (`long`). """ - return struct.unpack('>Q', self.get_bytes(8))[0] + return struct.unpack(">Q", self.get_bytes(8))[0] def get_mpint(self): """ @@ -191,7 +191,7 @@ class Message (object): These are trivially encoded as comma-separated values in a string. """ - return self.get_text().split(',') + return self.get_text().split(",") def add_bytes(self, b): """ @@ -229,7 +229,7 @@ class Message (object): :param int n: integer to add """ - self.packet.write(struct.pack('>I', n)) + self.packet.write(struct.pack(">I", n)) return self def add_adaptive_int(self, n): @@ -242,7 +242,7 @@ class Message (object): self.packet.write(max_byte) self.add_string(util.deflate_long(n)) else: - self.packet.write(struct.pack('>I', n)) + self.packet.write(struct.pack(">I", n)) return self def add_int64(self, n): @@ -251,7 +251,7 @@ class Message (object): :param long n: long int to add """ - self.packet.write(struct.pack('>Q', n)) + self.packet.write(struct.pack(">Q", n)) return self def add_mpint(self, z): @@ -283,7 +283,7 @@ class Message (object): :param l: list of strings to add """ - self.add_string(','.join(l)) + self.add_string(",".join(l)) return self def _add(self, i): diff --git a/paramiko/packet.py b/paramiko/packet.py index 2a1e91e2..d324fc35 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -30,7 +30,12 @@ from hmac import HMAC from paramiko import util from paramiko.common import ( - linefeed_byte, cr_byte_value, asbytes, MSG_NAMES, DEBUG, xffffffff, + linefeed_byte, + cr_byte_value, + asbytes, + MSG_NAMES, + DEBUG, + xffffffff, zero_byte, ) from paramiko.py3compat import u, byte_ord @@ -42,7 +47,7 @@ def compute_hmac(key, message, digest_class): return HMAC(key, message, digest_class).digest() -class NeedRekeyException (Exception): +class NeedRekeyException(Exception): """ Exception indicating a rekey is needed. """ @@ -56,7 +61,7 @@ def first_arg(e): return arg -class Packetizer (object): +class Packetizer(object): """ Implementation of the base SSH packet protocol. """ @@ -128,8 +133,15 @@ class Packetizer (object): """ self.__logger = log - def set_outbound_cipher(self, block_engine, block_size, mac_engine, - mac_size, mac_key, sdctr=False): + def set_outbound_cipher( + self, + block_engine, + block_size, + mac_engine, + mac_size, + mac_key, + sdctr=False, + ): """ Switch outbound data cipher. """ @@ -149,7 +161,8 @@ class Packetizer (object): self.__need_rekey = False def set_inbound_cipher( - self, block_engine, block_size, mac_engine, mac_size, mac_key): + self, block_engine, block_size, mac_engine, mac_size, mac_key + ): """ Switch inbound data cipher. """ @@ -352,7 +365,7 @@ class Packetizer (object): while linefeed_byte not in buf: buf += self._read_timeout(timeout) n = buf.index(linefeed_byte) - self.__remainder = buf[n + 1:] + self.__remainder = buf[n + 1 :] buf = buf[:n] if (len(buf) > 0) and (buf[-1] == cr_byte_value): buf = buf[:-1] @@ -368,7 +381,7 @@ class Packetizer (object): if cmd in MSG_NAMES: cmd_name = MSG_NAMES[cmd] else: - cmd_name = '${:x}'.format(cmd) + cmd_name = "${:x}".format(cmd) orig_len = len(data) self.__write_lock.acquire() try: @@ -378,37 +391,38 @@ class Packetizer (object): if self.__dump_packets: self._log( DEBUG, - 'Write packet <{}>, length {}'.format(cmd_name, orig_len) + "Write packet <{}>, length {}".format(cmd_name, orig_len), ) - self._log(DEBUG, util.format_binary(packet, 'OUT: ')) + self._log(DEBUG, util.format_binary(packet, "OUT: ")) if self.__block_engine_out is not None: out = self.__block_engine_out.update(packet) else: out = packet # + mac if self.__block_engine_out is not None: - payload = struct.pack( - '>I', self.__sequence_number_out) + packet + payload = ( + struct.pack(">I", self.__sequence_number_out) + packet + ) out += compute_hmac( - self.__mac_key_out, - payload, - self.__mac_engine_out)[:self.__mac_size_out] - self.__sequence_number_out = \ - (self.__sequence_number_out + 1) & xffffffff + self.__mac_key_out, payload, self.__mac_engine_out + )[: self.__mac_size_out] + self.__sequence_number_out = ( + self.__sequence_number_out + 1 + ) & xffffffff self.write_all(out) self.__sent_bytes += len(out) self.__sent_packets += 1 sent_too_much = ( - self.__sent_packets >= self.REKEY_PACKETS or - self.__sent_bytes >= self.REKEY_BYTES + self.__sent_packets >= self.REKEY_PACKETS + or self.__sent_bytes >= self.REKEY_BYTES ) if sent_too_much and not self.__need_rekey: # only ask once for rekeying msg = "Rekeying (hit {} packets, {} bytes sent)" - self._log(DEBUG, msg.format( - self.__sent_packets, self.__sent_bytes, - )) + self._log( + DEBUG, msg.format(self.__sent_packets, self.__sent_bytes) + ) self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() @@ -427,41 +441,42 @@ class Packetizer (object): if self.__block_engine_in is not None: header = self.__block_engine_in.update(header) if self.__dump_packets: - self._log(DEBUG, util.format_binary(header, 'IN: ')) - packet_size = struct.unpack('>I', header[:4])[0] + self._log(DEBUG, util.format_binary(header, "IN: ")) + packet_size = struct.unpack(">I", header[:4])[0] # leftover contains decrypted bytes from the first block (after the # length field) leftover = header[4:] if (packet_size - len(leftover)) % self.__block_size_in != 0: - raise SSHException('Invalid packet blocking') + raise SSHException("Invalid packet blocking") buf = self.read_all(packet_size + self.__mac_size_in - len(leftover)) - packet = buf[:packet_size - len(leftover)] - post_packet = buf[packet_size - len(leftover):] + packet = buf[: packet_size - len(leftover)] + post_packet = buf[packet_size - len(leftover) :] if self.__block_engine_in is not None: packet = self.__block_engine_in.update(packet) if self.__dump_packets: - self._log(DEBUG, util.format_binary(packet, 'IN: ')) + self._log(DEBUG, util.format_binary(packet, "IN: ")) packet = leftover + packet if self.__mac_size_in > 0: - mac = post_packet[:self.__mac_size_in] - mac_payload = struct.pack( - '>II', self.__sequence_number_in, packet_size) + packet + mac = post_packet[: self.__mac_size_in] + mac_payload = ( + struct.pack(">II", self.__sequence_number_in, packet_size) + + packet + ) my_mac = compute_hmac( - self.__mac_key_in, - mac_payload, - self.__mac_engine_in)[:self.__mac_size_in] + self.__mac_key_in, mac_payload, self.__mac_engine_in + )[: self.__mac_size_in] if not util.constant_time_bytes_eq(my_mac, mac): - raise SSHException('Mismatched MAC') + raise SSHException("Mismatched MAC") padding = byte_ord(packet[0]) - payload = packet[1:packet_size - padding] + payload = packet[1 : packet_size - padding] if self.__dump_packets: self._log( DEBUG, - 'Got payload ({} bytes, {} padding)'.format( + "Got payload ({} bytes, {} padding)".format( packet_size, padding - ) + ), ) if self.__compress_engine_in is not None: @@ -480,19 +495,24 @@ class Packetizer (object): # dropping the connection self.__received_bytes_overflow += raw_packet_size self.__received_packets_overflow += 1 - if (self.__received_packets_overflow >= - self.REKEY_PACKETS_OVERFLOW_MAX) or \ - (self.__received_bytes_overflow >= - self.REKEY_BYTES_OVERFLOW_MAX): + if ( + self.__received_packets_overflow + >= self.REKEY_PACKETS_OVERFLOW_MAX + ) or ( + self.__received_bytes_overflow >= self.REKEY_BYTES_OVERFLOW_MAX + ): raise SSHException( - 'Remote transport is ignoring rekey requests') - elif (self.__received_packets >= self.REKEY_PACKETS) or \ - (self.__received_bytes >= self.REKEY_BYTES): + "Remote transport is ignoring rekey requests" + ) + elif (self.__received_packets >= self.REKEY_PACKETS) or ( + self.__received_bytes >= self.REKEY_BYTES + ): # only ask once for rekeying err = "Rekeying (hit {} packets, {} bytes received)" - self._log(DEBUG, err.format( - self.__received_packets, self.__received_bytes, - )) + self._log( + DEBUG, + err.format(self.__received_packets, self.__received_bytes), + ) self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() @@ -501,11 +521,11 @@ class Packetizer (object): if cmd in MSG_NAMES: cmd_name = MSG_NAMES[cmd] else: - cmd_name = '${:x}'.format(cmd) + cmd_name = "${:x}".format(cmd) if self.__dump_packets: self._log( DEBUG, - 'Read packet <{}>, length {}'.format(cmd_name, len(payload)) + "Read packet <{}>, length {}".format(cmd_name, len(payload)), ) return cmd, msg @@ -522,9 +542,9 @@ class Packetizer (object): def _check_keepalive(self): if ( - not self.__keepalive_interval or - not self.__block_engine_out or - self.__need_rekey + not self.__keepalive_interval + or not self.__block_engine_out + or self.__need_rekey ): # wait till we're encrypting, and not in the middle of rekeying return @@ -559,13 +579,13 @@ class Packetizer (object): # pad up at least 4 bytes, to nearest block-size (usually 8) bsize = self.__block_size_out padding = 3 + bsize - ((len(payload) + 8) % bsize) - packet = struct.pack('>IB', len(payload) + padding + 1, padding) + packet = struct.pack(">IB", len(payload) + padding + 1, padding) packet += payload if self.__sdctr_out or self.__block_engine_out is None: # cute trick i caught openssh doing: if we're not encrypting or # SDCTR mode (RFC4344), # don't waste random bytes for the padding - packet += (zero_byte * padding) + packet += zero_byte * padding else: packet += os.urandom(padding) return packet diff --git a/paramiko/pipe.py b/paramiko/pipe.py index 6ca37703..e88a5e44 100644 --- a/paramiko/pipe.py +++ b/paramiko/pipe.py @@ -31,14 +31,15 @@ import socket def make_pipe(): - if sys.platform[:3] != 'win': + if sys.platform[:3] != "win": p = PosixPipe() else: p = WindowsPipe() return p -class PosixPipe (object): +class PosixPipe(object): + def __init__(self): self._rfd, self._wfd = os.pipe() self._set = False @@ -64,26 +65,27 @@ class PosixPipe (object): if self._set or self._closed: return self._set = True - os.write(self._wfd, b'*') + os.write(self._wfd, b"*") def set_forever(self): self._forever = True self.set() -class WindowsPipe (object): +class WindowsPipe(object): """ On Windows, only an OS-level "WinSock" may be used in select(), but reads and writes must be to the actual socket object. """ + def __init__(self): serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - serv.bind(('127.0.0.1', 0)) + serv.bind(("127.0.0.1", 0)) serv.listen(1) # need to save sockets in _rsock/_wsock so they don't get closed self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._rsock.connect(('127.0.0.1', serv.getsockname()[1])) + self._rsock.connect(("127.0.0.1", serv.getsockname()[1])) self._wsock, addr = serv.accept() serv.close() @@ -110,14 +112,15 @@ class WindowsPipe (object): if self._set or self._closed: return self._set = True - self._wsock.send(b'*') + self._wsock.send(b"*") def set_forever(self): self._forever = True self.set() -class OrPipe (object): +class OrPipe(object): + def __init__(self, pipe): self._set = False self._partner = None diff --git a/paramiko/pkey.py b/paramiko/pkey.py index 808215f8..fa014800 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -43,23 +43,23 @@ class PKey(object): # known encryption types for private key files: _CIPHER_TABLE = { - 'AES-128-CBC': { - 'cipher': algorithms.AES, - 'keysize': 16, - 'blocksize': 16, - 'mode': modes.CBC + "AES-128-CBC": { + "cipher": algorithms.AES, + "keysize": 16, + "blocksize": 16, + "mode": modes.CBC, }, - 'AES-256-CBC': { - 'cipher': algorithms.AES, - 'keysize': 32, - 'blocksize': 16, - 'mode': modes.CBC + "AES-256-CBC": { + "cipher": algorithms.AES, + "keysize": 32, + "blocksize": 16, + "mode": modes.CBC, }, - 'DES-EDE3-CBC': { - 'cipher': algorithms.TripleDES, - 'keysize': 24, - 'blocksize': 8, - 'mode': modes.CBC + "DES-EDE3-CBC": { + "cipher": algorithms.TripleDES, + "keysize": 24, + "blocksize": 8, + "mode": modes.CBC, }, } @@ -107,7 +107,7 @@ class PKey(object): hs = hash(self) ho = hash(other) if hs != ho: - return cmp(hs, ho) # noqa + return cmp(hs, ho) # noqa return cmp(self.asbytes(), other.asbytes()) # noqa def __eq__(self, other): @@ -121,7 +121,7 @@ class PKey(object): name of this private key type, in SSH terminology, as a `str` (for example, ``"ssh-rsa"``). """ - return '' + return "" def get_bits(self): """ @@ -158,7 +158,7 @@ class PKey(object): :return: a base64 `string <str>` containing the public part of the key. """ - return u(encodebytes(self.asbytes())).replace('\n', '') + return u(encodebytes(self.asbytes())).replace("\n", "") def sign_ssh_data(self, data): """ @@ -239,7 +239,7 @@ class PKey(object): :raises: ``IOError`` -- if there was an error writing the file :raises: `.SSHException` -- if the key is invalid """ - raise Exception('Not implemented in PKey') + raise Exception("Not implemented in PKey") def write_private_key(self, file_obj, password=None): """ @@ -252,7 +252,7 @@ class PKey(object): :raises: ``IOError`` -- if there was an error writing to the file :raises: `.SSHException` -- if the key is invalid """ - raise Exception('Not implemented in PKey') + raise Exception("Not implemented in PKey") def _read_private_key_file(self, tag, filename, password=None): """ @@ -275,60 +275,61 @@ class PKey(object): encrypted, and ``password`` is ``None``. :raises: `.SSHException` -- if the key file is invalid. """ - with open(filename, 'r') as f: + with open(filename, "r") as f: data = self._read_private_key(tag, f, password) return data def _read_private_key(self, tag, f, password=None): lines = f.readlines() start = 0 - beginning_of_key = '-----BEGIN ' + tag + ' PRIVATE KEY-----' + beginning_of_key = "-----BEGIN " + tag + " PRIVATE KEY-----" while start < len(lines) and lines[start].strip() != beginning_of_key: start += 1 if start >= len(lines): - raise SSHException('not a valid ' + tag + ' private key file') + raise SSHException("not a valid " + tag + " private key file") # parse any headers first headers = {} start += 1 while start < len(lines): - l = lines[start].split(': ') + l = lines[start].split(": ") if len(l) == 1: break headers[l[0].lower()] = l[1].strip() start += 1 # find end end = start - ending_of_key = '-----END ' + tag + ' PRIVATE KEY-----' + ending_of_key = "-----END " + tag + " PRIVATE KEY-----" while end < len(lines) and lines[end].strip() != ending_of_key: end += 1 # if we trudged to the end of the file, just try to cope. try: - data = decodebytes(b(''.join(lines[start:end]))) + data = decodebytes(b("".join(lines[start:end]))) except base64.binascii.Error as e: - raise SSHException('base64 decoding error: ' + str(e)) - if 'proc-type' not in headers: + raise SSHException("base64 decoding error: " + str(e)) + if "proc-type" not in headers: # unencryped: done return data # encrypted keyfile: will need a password - proc_type = headers['proc-type'] - if proc_type != '4,ENCRYPTED': + proc_type = headers["proc-type"] + if proc_type != "4,ENCRYPTED": raise SSHException( 'Unknown private key structure "{}"'.format(proc_type) ) try: - encryption_type, saltstr = headers['dek-info'].split(',') + encryption_type, saltstr = headers["dek-info"].split(",") except: raise SSHException("Can't parse DEK-info in private key file") if encryption_type not in self._CIPHER_TABLE: raise SSHException( - 'Unknown private key cipher "{}"'.format(encryption_type)) + 'Unknown private key cipher "{}"'.format(encryption_type) + ) # if no password was passed in, # raise an exception pointing out that we need one if password is None: - raise PasswordRequiredException('Private key file is encrypted') - cipher = self._CIPHER_TABLE[encryption_type]['cipher'] - keysize = self._CIPHER_TABLE[encryption_type]['keysize'] - mode = self._CIPHER_TABLE[encryption_type]['mode'] + raise PasswordRequiredException("Private key file is encrypted") + cipher = self._CIPHER_TABLE[encryption_type]["cipher"] + keysize = self._CIPHER_TABLE[encryption_type]["keysize"] + mode = self._CIPHER_TABLE[encryption_type]["mode"] salt = unhexlify(b(saltstr)) key = util.generate_key_bytes(md5, salt, password, keysize) decryptor = Cipher( @@ -351,7 +352,7 @@ class PKey(object): :raises: ``IOError`` -- if there was an error writing the file. """ - with open(filename, 'w') as f: + with open(filename, "w") as f: os.chmod(filename, o600) self._write_private_key(f, key, format, password=password) @@ -361,11 +362,11 @@ class PKey(object): else: encryption = serialization.BestAvailableEncryption(b(password)) - f.write(key.private_bytes( - serialization.Encoding.PEM, - format, - encryption - ).decode()) + f.write( + key.private_bytes( + serialization.Encoding.PEM, format, encryption + ).decode() + ) def _check_type_and_load_cert(self, msg, key_type, cert_type): """ @@ -388,7 +389,7 @@ class PKey(object): cert_types = [cert_types] # Can't do much with no message, that should've been handled elsewhere if msg is None: - raise SSHException('Key object may not be empty') + raise SSHException("Key object may not be empty") # First field is always key type, in either kind of object. (make sure # we rewind before grabbing it - sometimes caller had to do their own # introspection first!) @@ -411,7 +412,7 @@ class PKey(object): # (requires going back into per-type subclasses.) msg.get_string() else: - err = 'Invalid key (class: {}, data type: {}' + err = "Invalid key (class: {}, data type: {}" raise SSHException(err.format(self.__class__.__name__, type_)) def load_certificate(self, value): @@ -434,11 +435,11 @@ class PKey(object): successfully. """ if isinstance(value, Message): - constructor = 'from_message' + constructor = "from_message" elif os.path.isfile(value): - constructor = 'from_file' + constructor = "from_file" else: - constructor = 'from_string' + constructor = "from_string" blob = getattr(PublicBlob, constructor)(value) if not blob.key_type.startswith(self.get_name()): err = "PublicBlob type {} incompatible with key type {}" @@ -464,6 +465,7 @@ class PublicBlob(object): `from_message` for useful instantiation, the main constructor is basically "I should be using ``attrs`` for this." """ + def __init__(self, type_, blob, comment=None): """ Create a new public blob of given type and contents. @@ -505,8 +507,10 @@ class PublicBlob(object): m = Message(key_blob) blob_type = m.get_text() if blob_type != key_type: - msg = "Invalid PublicBlob contents: key type={!r}, but blob type={!r}" # noqa - raise ValueError(msg.format(key_type, blob_type)) + deets = "key type={!r}, but blob type={!r}".format( + key_type, blob_type + ) + raise ValueError("Invalid PublicBlob contents: {}".format(deets)) # All good? All good. return cls(type_=key_type, blob=key_blob, comment=comment) @@ -522,7 +526,7 @@ class PublicBlob(object): return cls(type_=type_, blob=message.asbytes()) def __str__(self): - ret = '{} public key/certificate'.format(self.key_type) + ret = "{} public key/certificate".format(self.key_type) if self.comment: ret += "- {}".format(self.comment) return ret diff --git a/paramiko/primes.py b/paramiko/primes.py index ca8f9bec..8dff7683 100644 --- a/paramiko/primes.py +++ b/paramiko/primes.py @@ -49,7 +49,7 @@ def _roll_random(n): return num -class ModulusPack (object): +class ModulusPack(object): """ convenience object for holding the contents of the /etc/ssh/moduli file, on systems that have such a file. @@ -61,8 +61,15 @@ class ModulusPack (object): self.discarded = [] def _parse_modulus(self, line): - timestamp, mod_type, tests, tries, size, generator, modulus = \ - line.split() + ( + timestamp, + mod_type, + tests, + tries, + size, + generator, + modulus, + ) = line.split() mod_type = int(mod_type) tests = int(tests) tries = int(tries) @@ -75,12 +82,13 @@ class ModulusPack (object): # test 4 (more than just a small-prime sieve) # tries < 100 if test & 4 (at least 100 tries of miller-rabin) if ( - mod_type < 2 or - tests < 4 or - (tests & 4 and tests < 8 and tries < 100) + mod_type < 2 + or tests < 4 + or (tests & 4 and tests < 8 and tries < 100) ): self.discarded.append( - (modulus, 'does not meet basic requirements')) + (modulus, "does not meet basic requirements") + ) return if generator == 0: generator = 2 @@ -91,7 +99,8 @@ class ModulusPack (object): bl = util.bit_length(modulus) if (bl != size) and (bl != size + 1): self.discarded.append( - (modulus, 'incorrectly reported bit length {}'.format(size))) + (modulus, "incorrectly reported bit length {}".format(size)) + ) return if bl not in self.pack: self.pack[bl] = [] @@ -102,10 +111,10 @@ class ModulusPack (object): :raises IOError: passed from any file operations that fail. """ self.pack = {} - with open(filename, 'r') as f: + with open(filename, "r") as f: for line in f: line = line.strip() - if (len(line) == 0) or (line[0] == '#'): + if (len(line) == 0) or (line[0] == "#"): continue try: self._parse_modulus(line) @@ -115,7 +124,7 @@ class ModulusPack (object): def get_modulus(self, min, prefer, max): bitsizes = sorted(self.pack.keys()) if len(bitsizes) == 0: - raise SSHException('no moduli available') + raise SSHException("no moduli available") good = -1 # find nearest bitsize >= preferred for b in bitsizes: diff --git a/paramiko/proxy.py b/paramiko/proxy.py index c4ec627c..444c47b6 100644 --- a/paramiko/proxy.py +++ b/paramiko/proxy.py @@ -39,6 +39,7 @@ class ProxyCommand(ClosingContextManager): Instances of this class may be used as context managers. """ + def __init__(self, command_line): """ Create a new CommandProxy instance. The instance created by this @@ -50,9 +51,11 @@ class ProxyCommand(ClosingContextManager): # NOTE: subprocess import done lazily so platforms without it (e.g. # GAE) can still import us during overall Paramiko load. from subprocess import Popen, PIPE + self.cmd = shlsplit(command_line) - self.process = Popen(self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, - bufsize=0) + self.process = Popen( + self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=0 + ) self.timeout = None def send(self, content): @@ -69,7 +72,7 @@ class ProxyCommand(ClosingContextManager): # died and we can't proceed. The best option here is to # raise an exception informing the user that the informed # ProxyCommand is not working. - raise ProxyCommandFailure(' '.join(self.cmd), e.strerror) + raise ProxyCommandFailure(" ".join(self.cmd), e.strerror) return len(content) def recv(self, size): @@ -81,21 +84,21 @@ class ProxyCommand(ClosingContextManager): :return: the string of bytes read, which may be shorter than requested """ try: - buffer = b'' + buffer = b"" start = time.time() while len(buffer) < size: select_timeout = None if self.timeout is not None: - elapsed = (time.time() - start) + elapsed = time.time() - start if elapsed >= self.timeout: raise socket.timeout() select_timeout = self.timeout - elapsed - r, w, x = select( - [self.process.stdout], [], [], select_timeout) + r, w, x = select([self.process.stdout], [], [], select_timeout) if r and r[0] == self.process.stdout: buffer += os.read( - self.process.stdout.fileno(), size - len(buffer)) + self.process.stdout.fileno(), size - len(buffer) + ) return buffer except socket.timeout: if buffer: @@ -103,7 +106,7 @@ class ProxyCommand(ClosingContextManager): return buffer raise # socket.timeout is a subclass of IOError except IOError as e: - raise ProxyCommandFailure(' '.join(self.cmd), e.strerror) + raise ProxyCommandFailure(" ".join(self.cmd), e.strerror) def close(self): os.kill(self.process.pid, signal.SIGTERM) diff --git a/paramiko/py3compat.py b/paramiko/py3compat.py index 67c0f200..e1f33fe9 100644 --- a/paramiko/py3compat.py +++ b/paramiko/py3compat.py @@ -2,10 +2,28 @@ import sys import base64 __all__ = [ - 'BytesIO', 'MAXSIZE', 'PY2', 'StringIO', 'b', 'b2s', 'builtins', - 'byte_chr', 'byte_mask', 'byte_ord', 'bytes', 'bytes_types', 'decodebytes', - 'encodebytes', 'input', 'integer_types', 'is_callable', 'long', 'next', - 'string_types', 'text_type', 'u', + "BytesIO", + "MAXSIZE", + "PY2", + "StringIO", + "b", + "b2s", + "builtins", + "byte_chr", + "byte_mask", + "byte_ord", + "bytes", + "bytes_types", + "decodebytes", + "encodebytes", + "input", + "integer_types", + "is_callable", + "long", + "next", + "string_types", + "text_type", + "u", ] PY2 = sys.version_info[0] < 3 @@ -23,16 +41,13 @@ if PY2: import __builtin__ as builtins - byte_ord = ord # NOQA byte_chr = chr # NOQA - def byte_mask(c, mask): return chr(ord(c) & mask) - - def b(s, encoding='utf8'): # NOQA + def b(s, encoding="utf8"): # NOQA """cast unicode or bytes to bytes""" if isinstance(s, str): return s @@ -43,8 +58,7 @@ if PY2: else: raise TypeError("Expected unicode or bytes, got {!r}".format(s)) - - def u(s, encoding='utf8'): # NOQA + def u(s, encoding="utf8"): # NOQA """cast bytes or unicode to unicode""" if isinstance(s, str): return s.decode(encoding) @@ -55,53 +69,52 @@ if PY2: else: raise TypeError("Expected unicode or bytes, got {!r}".format(s)) - def b2s(s): return s - import cStringIO + StringIO = cStringIO.StringIO BytesIO = StringIO - def is_callable(c): # NOQA return callable(c) - def get_next(c): # NOQA return c.next - def next(c): return c.next() # It's possible to have sizeof(long) != sizeof(Py_ssize_t). class X(object): + def __len__(self): return 1 << 31 - try: len(X()) except OverflowError: # 32-bit - MAXSIZE = int((1 << 31) - 1) # NOQA + MAXSIZE = int((1 << 31) - 1) # NOQA else: # 64-bit - MAXSIZE = int((1 << 63) - 1) # NOQA + MAXSIZE = int((1 << 63) - 1) # NOQA del X else: import collections import struct import builtins + string_types = str text_type = str bytes = bytes bytes_types = bytes integer_types = int + class long(int): pass + input = input decodebytes = base64.decodebytes encodebytes = base64.encodebytes @@ -114,13 +127,13 @@ else: def byte_chr(c): assert isinstance(c, int) - return struct.pack('B', c) + return struct.pack("B", c) def byte_mask(c, mask): assert isinstance(c, int) - return struct.pack('B', c & mask) + return struct.pack("B", c & mask) - def b(s, encoding='utf8'): + def b(s, encoding="utf8"): """cast unicode or bytes to bytes""" if isinstance(s, bytes): return s @@ -129,7 +142,7 @@ else: else: raise TypeError("Expected unicode or bytes, got {!r}".format(s)) - def u(s, encoding='utf8'): + def u(s, encoding="utf8"): """cast bytes or unicode to unicode""" if isinstance(s, bytes): return s.decode(encoding) @@ -142,8 +155,9 @@ else: return s.decode() if isinstance(s, bytes) else s import io - StringIO = io.StringIO # NOQA - BytesIO = io.BytesIO # NOQA + + StringIO = io.StringIO # NOQA + BytesIO = io.BytesIO # NOQA def is_callable(c): return isinstance(c, collections.Callable) @@ -153,4 +167,4 @@ else: next = next - MAXSIZE = sys.maxsize # NOQA + MAXSIZE = sys.maxsize # NOQA diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py index 8dfcfb01..442bfe1f 100644 --- a/paramiko/rsakey.py +++ b/paramiko/rsakey.py @@ -37,8 +37,15 @@ class RSAKey(PKey): data. """ - def __init__(self, msg=None, data=None, filename=None, password=None, - key=None, file_obj=None): + def __init__( + self, + msg=None, + data=None, + filename=None, + password=None, + key=None, + file_obj=None, + ): self.key = None self.public_blob = None if file_obj is not None: @@ -54,8 +61,8 @@ class RSAKey(PKey): else: self._check_type_and_load_cert( msg=msg, - key_type='ssh-rsa', - cert_type='ssh-rsa-cert-v01@openssh.com', + key_type="ssh-rsa", + cert_type="ssh-rsa-cert-v01@openssh.com", ) self.key = rsa.RSAPublicNumbers( e=msg.get_mpint(), n=msg.get_mpint() @@ -74,7 +81,7 @@ class RSAKey(PKey): def asbytes(self): m = Message() - m.add_string('ssh-rsa') + m.add_string("ssh-rsa") m.add_mpint(self.public_numbers.e) m.add_mpint(self.public_numbers.n) return m.asbytes() @@ -89,14 +96,15 @@ class RSAKey(PKey): # tries stuffing it into ASCII for whatever godforsaken reason return self.asbytes() else: - return self.asbytes().decode('utf8', errors='ignore') + return self.asbytes().decode("utf8", errors="ignore") def __hash__(self): - return hash((self.get_name(), self.public_numbers.e, - self.public_numbers.n)) + return hash( + (self.get_name(), self.public_numbers.e, self.public_numbers.n) + ) def get_name(self): - return 'ssh-rsa' + return "ssh-rsa" def get_bits(self): return self.size @@ -106,18 +114,16 @@ class RSAKey(PKey): def sign_ssh_data(self, data): sig = self.key.sign( - data, - padding=padding.PKCS1v15(), - algorithm=hashes.SHA1(), + data, padding=padding.PKCS1v15(), algorithm=hashes.SHA1() ) m = Message() - m.add_string('ssh-rsa') + m.add_string("ssh-rsa") m.add_string(sig) return m def verify_ssh_sig(self, data, msg): - if msg.get_text() != 'ssh-rsa': + if msg.get_text() != "ssh-rsa": return False key = self.key if isinstance(key, rsa.RSAPrivateKey): @@ -137,7 +143,7 @@ class RSAKey(PKey): filename, self.key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) def write_private_key(self, file_obj, password=None): @@ -145,7 +151,7 @@ class RSAKey(PKey): file_obj, self.key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) @staticmethod @@ -166,11 +172,11 @@ class RSAKey(PKey): # ...internals... def _from_private_key_file(self, filename, password): - data = self._read_private_key_file('RSA', filename, password) + data = self._read_private_key_file("RSA", filename, password) self._decode_key(data) def _from_private_key(self, file_obj, password): - data = self._read_private_key('RSA', file_obj, password) + data = self._read_private_key("RSA", file_obj, password) self._decode_key(data) def _decode_key(self, data): diff --git a/paramiko/server.py b/paramiko/server.py index a7117815..2fe9cc19 100644 --- a/paramiko/server.py +++ b/paramiko/server.py @@ -23,13 +23,16 @@ import threading from paramiko import util from paramiko.common import ( - DEBUG, ERROR, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, AUTH_FAILED, + DEBUG, + ERROR, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + AUTH_FAILED, AUTH_SUCCESSFUL, ) from paramiko.py3compat import string_types -class ServerInterface (object): +class ServerInterface(object): """ This class defines an interface for controlling the behavior of Paramiko in server mode. @@ -99,7 +102,7 @@ class ServerInterface (object): :param str username: the username requesting authentication. :return: a comma-separated `str` of authentication types """ - return 'password' + return "password" def check_auth_none(self, username): """ @@ -233,9 +236,9 @@ class ServerInterface (object): """ return AUTH_FAILED - def check_auth_gssapi_with_mic(self, username, - gss_authenticated=AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_with_mic( + self, username, gss_authenticated=AUTH_FAILED, cc_file=None + ): """ Authenticate the given user to the server if he is a valid krb5 principal. @@ -263,9 +266,9 @@ class ServerInterface (object): return AUTH_SUCCESSFUL return AUTH_FAILED - def check_auth_gssapi_keyex(self, username, - gss_authenticated=AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_keyex( + self, username, gss_authenticated=AUTH_FAILED, cc_file=None + ): """ Authenticate the given user to the server if he is a valid krb5 principal and GSS-API Key Exchange was performed. @@ -372,8 +375,8 @@ class ServerInterface (object): # ...Channel requests... def check_channel_pty_request( - self, channel, term, width, height, pixelwidth, pixelheight, - modes): + self, channel, term, width, height, pixelwidth, pixelheight, modes + ): """ Determine if a pseudo-terminal of the given dimensions (usually requested for shell access) can be provided on the given channel. @@ -460,7 +463,8 @@ class ServerInterface (object): return True def check_channel_window_change_request( - self, channel, width, height, pixelwidth, pixelheight): + self, channel, width, height, pixelwidth, pixelheight + ): """ Determine if the pseudo-terminal on the given channel can be resized. This only makes sense if a pty was previously allocated on it. @@ -479,8 +483,13 @@ class ServerInterface (object): return False def check_channel_x11_request( - self, channel, single_connection, auth_protocol, auth_cookie, - screen_number): + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): """ Determine if the client will be provided with an X11 session. If this method returns ``True``, X11 applications should be routed through new @@ -584,12 +593,13 @@ class ServerInterface (object): """ return (None, None) -class InteractiveQuery (object): + +class InteractiveQuery(object): """ A query (set of prompts) for a user during interactive authentication. """ - def __init__(self, name='', instructions='', *prompts): + def __init__(self, name="", instructions="", *prompts): """ Create a new interactive query to send to the client. The name and instructions are optional, but are generally displayed to the end @@ -623,7 +633,7 @@ class InteractiveQuery (object): self.prompts.append((prompt, echo)) -class SubsystemHandler (threading.Thread): +class SubsystemHandler(threading.Thread): """ Handler for a subsytem in server mode. If you create a subclass of this class and pass it to `.Transport.set_subsystem_handler`, an object of this @@ -637,6 +647,7 @@ class SubsystemHandler (threading.Thread): ``MP3Handler`` will be created, and `start_subsystem` will be called on it from a new thread. """ + def __init__(self, channel, name, server): """ Create a new handler for a channel. This is used by `.ServerInterface` @@ -667,7 +678,7 @@ class SubsystemHandler (threading.Thread): def _run(self): try: self.__transport._log( - DEBUG, 'Starting handler for subsystem {}'.format(self.__name) + DEBUG, "Starting handler for subsystem {}".format(self.__name) ) self.start_subsystem(self.__name, self.__transport, self.__channel) except Exception as e: @@ -675,7 +686,7 @@ class SubsystemHandler (threading.Thread): ERROR, 'Exception in subsystem handler for "{}": {}'.format( self.__name, e - ) + ), ) self.__transport._log(ERROR, util.tb_strings()) try: diff --git a/paramiko/sftp.py b/paramiko/sftp.py index e6786d10..6aa4ce44 100644 --- a/paramiko/sftp.py +++ b/paramiko/sftp.py @@ -26,27 +26,54 @@ from paramiko.message import Message from paramiko.py3compat import byte_chr, byte_ord -CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, \ - CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, \ - CMD_REMOVE, CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, \ - CMD_READLINK, CMD_SYMLINK = range(1, 21) -CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS = range(101, 106) -CMD_EXTENDED, CMD_EXTENDED_REPLY = range(200, 202) +( + CMD_INIT, + CMD_VERSION, + CMD_OPEN, + CMD_CLOSE, + CMD_READ, + CMD_WRITE, + CMD_LSTAT, + CMD_FSTAT, + CMD_SETSTAT, + CMD_FSETSTAT, + CMD_OPENDIR, + CMD_READDIR, + CMD_REMOVE, + CMD_MKDIR, + CMD_RMDIR, + CMD_REALPATH, + CMD_STAT, + CMD_RENAME, + CMD_READLINK, + CMD_SYMLINK, +) = range(1, 21) +(CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS) = range(101, 106) +(CMD_EXTENDED, CMD_EXTENDED_REPLY) = range(200, 202) SFTP_OK = 0 -SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \ - SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, \ - SFTP_OP_UNSUPPORTED = range(1, 9) - -SFTP_DESC = ['Success', - 'End of file', - 'No such file', - 'Permission denied', - 'Failure', - 'Bad message', - 'No connection', - 'Connection lost', - 'Operation unsupported'] +( + SFTP_EOF, + SFTP_NO_SUCH_FILE, + SFTP_PERMISSION_DENIED, + SFTP_FAILURE, + SFTP_BAD_MESSAGE, + SFTP_NO_CONNECTION, + SFTP_CONNECTION_LOST, + SFTP_OP_UNSUPPORTED, +) = range(1, 9) + +SFTP_DESC = [ + "Success", + "End of file", + "No such file", + "Permission denied", + "Failure", + "Bad message", + "No connection", + "Connection lost", + "Operation unsupported", +] SFTP_FLAG_READ = 0x1 SFTP_FLAG_WRITE = 0x2 @@ -60,54 +87,55 @@ _VERSION = 3 # for debugging CMD_NAMES = { - CMD_INIT: 'init', - CMD_VERSION: 'version', - CMD_OPEN: 'open', - CMD_CLOSE: 'close', - CMD_READ: 'read', - CMD_WRITE: 'write', - CMD_LSTAT: 'lstat', - CMD_FSTAT: 'fstat', - CMD_SETSTAT: 'setstat', - CMD_FSETSTAT: 'fsetstat', - CMD_OPENDIR: 'opendir', - CMD_READDIR: 'readdir', - CMD_REMOVE: 'remove', - CMD_MKDIR: 'mkdir', - CMD_RMDIR: 'rmdir', - CMD_REALPATH: 'realpath', - CMD_STAT: 'stat', - CMD_RENAME: 'rename', - CMD_READLINK: 'readlink', - CMD_SYMLINK: 'symlink', - CMD_STATUS: 'status', - CMD_HANDLE: 'handle', - CMD_DATA: 'data', - CMD_NAME: 'name', - CMD_ATTRS: 'attrs', - CMD_EXTENDED: 'extended', - CMD_EXTENDED_REPLY: 'extended_reply' + CMD_INIT: "init", + CMD_VERSION: "version", + CMD_OPEN: "open", + CMD_CLOSE: "close", + CMD_READ: "read", + CMD_WRITE: "write", + CMD_LSTAT: "lstat", + CMD_FSTAT: "fstat", + CMD_SETSTAT: "setstat", + CMD_FSETSTAT: "fsetstat", + CMD_OPENDIR: "opendir", + CMD_READDIR: "readdir", + CMD_REMOVE: "remove", + CMD_MKDIR: "mkdir", + CMD_RMDIR: "rmdir", + CMD_REALPATH: "realpath", + CMD_STAT: "stat", + CMD_RENAME: "rename", + CMD_READLINK: "readlink", + CMD_SYMLINK: "symlink", + CMD_STATUS: "status", + CMD_HANDLE: "handle", + CMD_DATA: "data", + CMD_NAME: "name", + CMD_ATTRS: "attrs", + CMD_EXTENDED: "extended", + CMD_EXTENDED_REPLY: "extended_reply", } -class SFTPError (Exception): +class SFTPError(Exception): pass -class BaseSFTP (object): +class BaseSFTP(object): + def __init__(self): - self.logger = util.get_logger('paramiko.sftp') + self.logger = util.get_logger("paramiko.sftp") self.sock = None self.ultra_debug = False # ...internals... def _send_version(self): - self._send_packet(CMD_INIT, struct.pack('>I', _VERSION)) + self._send_packet(CMD_INIT, struct.pack(">I", _VERSION)) t, data = self._read_packet() if t != CMD_VERSION: - raise SFTPError('Incompatible sftp protocol') - version = struct.unpack('>I', data[:4])[0] + raise SFTPError("Incompatible sftp protocol") + version = struct.unpack(">I", data[:4])[0] # if version != _VERSION: # raise SFTPError('Incompatible sftp protocol') return version @@ -117,10 +145,10 @@ class BaseSFTP (object): # client finishes sending INIT. t, data = self._read_packet() if t != CMD_INIT: - raise SFTPError('Incompatible sftp protocol') - version = struct.unpack('>I', data[:4])[0] + raise SFTPError("Incompatible sftp protocol") + version = struct.unpack(">I", data[:4])[0] # advertise that we support "check-file" - extension_pairs = ['check-file', 'md5,sha1'] + extension_pairs = ["check-file", "md5,sha1"] msg = Message() msg.add_int(_VERSION) msg.add(*extension_pairs) @@ -165,9 +193,9 @@ class BaseSFTP (object): def _send_packet(self, t, packet): packet = asbytes(packet) - out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet + out = struct.pack(">I", len(packet) + 1) + byte_chr(t) + packet if self.ultra_debug: - self._log(DEBUG, util.format_binary(out, 'OUT: ')) + self._log(DEBUG, util.format_binary(out, "OUT: ")) self._write_all(out) def _read_packet(self): @@ -175,11 +203,11 @@ class BaseSFTP (object): # most sftp servers won't accept packets larger than about 32k, so # anything with the high byte set (> 16MB) is just garbage. if byte_ord(x[0]): - raise SFTPError('Garbage packet received') - size = struct.unpack('>I', x)[0] + raise SFTPError("Garbage packet received") + size = struct.unpack(">I", x)[0] data = self._read_all(size) if self.ultra_debug: - self._log(DEBUG, util.format_binary(data, 'IN: ')) + self._log(DEBUG, util.format_binary(data, "IN: ")) if size > 0: t = byte_ord(data[0]) return t, data[1:] diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py index ea12b2f6..f16ac746 100644 --- a/paramiko/sftp_attr.py +++ b/paramiko/sftp_attr.py @@ -22,7 +22,7 @@ from paramiko.common import x80000000, o700, o70, xffffffff from paramiko.py3compat import long, b -class SFTPAttributes (object): +class SFTPAttributes(object): """ Representation of the attributes of a file (or proxied file) for SFTP in client or server mode. It attemps to mirror the object returned by @@ -82,7 +82,7 @@ class SFTPAttributes (object): return attr def __repr__(self): - return '<SFTPAttributes: {}>'.format(self._debug_str()) + return "<SFTPAttributes: {}>".format(self._debug_str()) # ...internals... @classmethod @@ -144,29 +144,29 @@ class SFTPAttributes (object): return def _debug_str(self): - out = '[ ' + out = "[ " if self.st_size is not None: - out += 'size={} '.format(self.st_size) + out += "size={} ".format(self.st_size) if (self.st_uid is not None) and (self.st_gid is not None): - out += 'uid={} gid={} '.format(self.st_uid, self.st_gid) + out += "uid={} gid={} ".format(self.st_uid, self.st_gid) if self.st_mode is not None: - out += 'mode=' + oct(self.st_mode) + ' ' + out += "mode=" + oct(self.st_mode) + " " if (self.st_atime is not None) and (self.st_mtime is not None): - out += 'atime={} mtime={} '.format(self.st_atime, self.st_mtime) + out += "atime={} mtime={} ".format(self.st_atime, self.st_mtime) for k, v in self.attr.items(): out += '"{}"={!r} '.format(str(k), v) - out += ']' + out += "]" return out @staticmethod def _rwx(n, suid, sticky=False): if suid: suid = 2 - out = '-r'[n >> 2] + '-w'[(n >> 1) & 1] + out = "-r"[n >> 2] + "-w"[(n >> 1) & 1] if sticky: - out += '-xTt'[suid + (n & 1)] + out += "-xTt"[suid + (n & 1)] else: - out += '-xSs'[suid + (n & 1)] + out += "-xSs"[suid + (n & 1)] return out def __str__(self): @@ -174,42 +174,47 @@ class SFTPAttributes (object): if self.st_mode is not None: kind = stat.S_IFMT(self.st_mode) if kind == stat.S_IFIFO: - ks = 'p' + ks = "p" elif kind == stat.S_IFCHR: - ks = 'c' + ks = "c" elif kind == stat.S_IFDIR: - ks = 'd' + ks = "d" elif kind == stat.S_IFBLK: - ks = 'b' + ks = "b" elif kind == stat.S_IFREG: - ks = '-' + ks = "-" elif kind == stat.S_IFLNK: - ks = 'l' + ks = "l" elif kind == stat.S_IFSOCK: - ks = 's' + ks = "s" else: - ks = '?' + ks = "?" ks += self._rwx( - (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID) + (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID + ) ks += self._rwx( - (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID) + (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID + ) ks += self._rwx( - self.st_mode & 7, self.st_mode & stat.S_ISVTX, True) + self.st_mode & 7, self.st_mode & stat.S_ISVTX, True + ) else: - ks = '?---------' + ks = "?---------" # compute display date if (self.st_mtime is None) or (self.st_mtime == xffffffff): # shouldn't really happen - datestr = '(unknown date)' + datestr = "(unknown date)" else: if abs(time.time() - self.st_mtime) > 15552000: # (15552000 = 6 months) datestr = time.strftime( - '%d %b %Y', time.localtime(self.st_mtime)) + "%d %b %Y", time.localtime(self.st_mtime) + ) else: datestr = time.strftime( - '%d %b %H:%M', time.localtime(self.st_mtime)) - filename = getattr(self, 'filename', '?') + "%d %b %H:%M", time.localtime(self.st_mtime) + ) + filename = getattr(self, "filename", "?") # not all servers support uid/gid uid = self.st_uid @@ -225,8 +230,13 @@ class SFTPAttributes (object): # TODO: not sure this actually worked as expected beforehand, leaving # it untouched for the time being, re: .format() upgrade, until someone # has time to doublecheck - return '%s 1 %-8d %-8d %8d %-12s %s' % ( - ks, uid, gid, size, datestr, filename, + return "%s 1 %-8d %-8d %8d %-12s %s" % ( + ks, + uid, + gid, + size, + datestr, + filename, ) def asbytes(self): diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 31dc234c..de3f9f58 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -30,12 +30,37 @@ from paramiko.message import Message from paramiko.common import INFO, DEBUG, o777 from paramiko.py3compat import b, u, long from paramiko.sftp import ( - BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READDIR, CMD_NAME, - CMD_CLOSE, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_CREATE, - SFTP_FLAG_TRUNC, SFTP_FLAG_APPEND, SFTP_FLAG_EXCL, CMD_OPEN, CMD_REMOVE, - CMD_RENAME, CMD_MKDIR, CMD_RMDIR, CMD_STAT, CMD_ATTRS, CMD_LSTAT, - CMD_SYMLINK, CMD_SETSTAT, CMD_READLINK, CMD_REALPATH, CMD_STATUS, - CMD_EXTENDED, SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, + BaseSFTP, + CMD_OPENDIR, + CMD_HANDLE, + SFTPError, + CMD_READDIR, + CMD_NAME, + CMD_CLOSE, + SFTP_FLAG_READ, + SFTP_FLAG_WRITE, + SFTP_FLAG_CREATE, + SFTP_FLAG_TRUNC, + SFTP_FLAG_APPEND, + SFTP_FLAG_EXCL, + CMD_OPEN, + CMD_REMOVE, + CMD_RENAME, + CMD_MKDIR, + CMD_RMDIR, + CMD_STAT, + CMD_ATTRS, + CMD_LSTAT, + CMD_SYMLINK, + CMD_SETSTAT, + CMD_READLINK, + CMD_REALPATH, + CMD_STATUS, + CMD_EXTENDED, + SFTP_OK, + SFTP_EOF, + SFTP_NO_SUCH_FILE, + SFTP_PERMISSION_DENIED, ) from paramiko.sftp_attr import SFTPAttributes @@ -51,15 +76,15 @@ def _to_unicode(s): probably doesn't know the filename's encoding. """ try: - return s.encode('ascii') + return s.encode("ascii") except (UnicodeError, AttributeError): try: - return s.decode('utf-8') + return s.decode("utf-8") except UnicodeError: return s -b_slash = b'/' +b_slash = b"/" class SFTPClient(BaseSFTP, ClosingContextManager): @@ -71,6 +96,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): Instances of this class may be used as context managers. """ + def __init__(self, sock): """ Create an SFTP client from an existing `.Channel`. The channel @@ -97,15 +123,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager): # override default logger transport = self.sock.get_transport() self.logger = util.get_logger( - transport.get_log_channel() + '.sftp') + transport.get_log_channel() + ".sftp" + ) self.ultra_debug = transport.get_hexdump() try: server_version = self._send_version() except EOFError: - raise SSHException('EOF during negotiation') + raise SSHException("EOF during negotiation") self._log( INFO, - 'Opened sftp connection (server version {})'.format(server_version) + "Opened sftp connection (server version {})".format( + server_version + ), ) @classmethod @@ -132,11 +161,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionchanged:: 1.15 Added the ``window_size`` and ``max_packet_size`` arguments. """ - chan = t.open_session(window_size=window_size, - max_packet_size=max_packet_size) + chan = t.open_session( + window_size=window_size, max_packet_size=max_packet_size + ) if chan is None: return None - chan.invoke_subsystem('sftp') + chan.invoke_subsystem("sftp") return cls(chan) def _log(self, level, msg, *args): @@ -148,10 +178,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager): # logging.Logger.log() explicitly requires it. Grump. # escape '%' in msg (they could come from file or directory names) # before logging - msg = msg.replace('%', '%%') + msg = msg.replace("%", "%%") super(SFTPClient, self)._log( level, - "[chan %s] " + msg, *([self.sock.get_name()] + list(args))) + "[chan %s] " + msg, + *([self.sock.get_name()] + list(args)) + ) def close(self): """ @@ -159,7 +191,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.4 """ - self._log(INFO, 'sftp session closed.') + self._log(INFO, "sftp session closed.") self.sock.close() def get_channel(self): @@ -171,7 +203,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ return self.sock - def listdir(self, path='.'): + def listdir(self, path="."): """ Return a list containing the names of the entries in the given ``path``. @@ -185,7 +217,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ return [f.filename for f in self.listdir_attr(path)] - def listdir_attr(self, path='.'): + def listdir_attr(self, path="."): """ Return a list containing `.SFTPAttributes` objects corresponding to files in the given ``path``. The list is in arbitrary order. It does @@ -203,10 +235,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.2 """ path = self._adjust_cwd(path) - self._log(DEBUG, 'listdir({!r})'.format(path)) + self._log(DEBUG, "listdir({!r})".format(path)) t, msg = self._request(CMD_OPENDIR, path) if t != CMD_HANDLE: - raise SFTPError('Expected handle') + raise SFTPError("Expected handle") handle = msg.get_binary() filelist = [] while True: @@ -216,18 +248,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager): # done with handle break if t != CMD_NAME: - raise SFTPError('Expected name response') + raise SFTPError("Expected name response") count = msg.get_int() for i in range(count): filename = msg.get_text() longname = msg.get_text() attr = SFTPAttributes._from_msg(msg, filename, longname) - if (filename != '.') and (filename != '..'): + if (filename != ".") and (filename != ".."): filelist.append(attr) self._request(CMD_CLOSE, handle) return filelist - def listdir_iter(self, path='.', read_aheads=50): + def listdir_iter(self, path=".", read_aheads=50): """ Generator version of `.listdir_attr`. @@ -242,11 +274,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.15 """ path = self._adjust_cwd(path) - self._log(DEBUG, 'listdir({!r})'.format(path)) + self._log(DEBUG, "listdir({!r})".format(path)) t, msg = self._request(CMD_OPENDIR, path) if t != CMD_HANDLE: - raise SFTPError('Expected handle') + raise SFTPError("Expected handle") handle = msg.get_string() @@ -261,7 +293,6 @@ class SFTPClient(BaseSFTP, ClosingContextManager): num = self._async_request(type(None), CMD_READDIR, handle) nums.append(num) - # For each of our sent requests # Read and parse the corresponding packets # If we're at the end of our queued requests, then fire off @@ -280,8 +311,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager): filename = msg.get_text() longname = msg.get_text() attr = SFTPAttributes._from_msg( - msg, filename, longname) - if (filename != '.') and (filename != '..'): + msg, filename, longname + ) + if (filename != ".") and (filename != ".."): yield attr # If we've hit the end of our queued requests, reset nums. @@ -291,8 +323,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): self._request(CMD_CLOSE, handle) return - - def open(self, filename, mode='r', bufsize=-1): + def open(self, filename, mode="r", bufsize=-1): """ Open a file on the remote server. The arguments are the same as for Python's built-in `python:file` (aka `python:open`). A file-like @@ -325,26 +356,28 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :raises: ``IOError`` -- if the file could not be opened. """ filename = self._adjust_cwd(filename) - self._log(DEBUG, 'open({!r}, {!r})'.format(filename, mode)) + self._log(DEBUG, "open({!r}, {!r})".format(filename, mode)) imode = 0 - if ('r' in mode) or ('+' in mode): + if ("r" in mode) or ("+" in mode): imode |= SFTP_FLAG_READ - if ('w' in mode) or ('+' in mode) or ('a' in mode): + if ("w" in mode) or ("+" in mode) or ("a" in mode): imode |= SFTP_FLAG_WRITE - if 'w' in mode: + if "w" in mode: imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC - if 'a' in mode: + if "a" in mode: imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND - if 'x' in mode: + if "x" in mode: imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL attrblock = SFTPAttributes() t, msg = self._request(CMD_OPEN, filename, imode, attrblock) if t != CMD_HANDLE: - raise SFTPError('Expected handle') + raise SFTPError("Expected handle") handle = msg.get_binary() self._log( DEBUG, - 'open({!r}, {!r}) -> {}'.format(filename, mode, u(hexlify(handle))) + "open({!r}, {!r}) -> {}".format( + filename, mode, u(hexlify(handle)) + ), ) return SFTPFile(self, handle, mode, bufsize) @@ -361,7 +394,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :raises: ``IOError`` -- if the path refers to a folder (directory) """ path = self._adjust_cwd(path) - self._log(DEBUG, 'remove({!r})'.format(path)) + self._log(DEBUG, "remove({!r})".format(path)) self._request(CMD_REMOVE, path) unlink = remove @@ -386,7 +419,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ oldpath = self._adjust_cwd(oldpath) newpath = self._adjust_cwd(newpath) - self._log(DEBUG, 'rename({!r}, {!r})'.format(oldpath, newpath)) + self._log(DEBUG, "rename({!r}, {!r})".format(oldpath, newpath)) self._request(CMD_RENAME, oldpath, newpath) def posix_rename(self, oldpath, newpath): @@ -406,7 +439,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ oldpath = self._adjust_cwd(oldpath) newpath = self._adjust_cwd(newpath) - self._log(DEBUG, 'posix_rename({!r}, {!r})'.format(oldpath, newpath)) + self._log(DEBUG, "posix_rename({!r}, {!r})".format(oldpath, newpath)) self._request( CMD_EXTENDED, "posix-rename@openssh.com", oldpath, newpath ) @@ -421,7 +454,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int mode: permissions (posix-style) for the newly-created folder """ path = self._adjust_cwd(path) - self._log(DEBUG, 'mkdir({!r}, {!r})'.format(path, mode)) + self._log(DEBUG, "mkdir({!r}, {!r})".format(path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_MKDIR, path, attr) @@ -433,7 +466,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param str path: name of the folder to remove """ path = self._adjust_cwd(path) - self._log(DEBUG, 'rmdir({!r})'.format(path)) + self._log(DEBUG, "rmdir({!r})".format(path)) self._request(CMD_RMDIR, path) def stat(self, path): @@ -456,10 +489,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager): file """ path = self._adjust_cwd(path) - self._log(DEBUG, 'stat({!r})'.format(path)) + self._log(DEBUG, "stat({!r})".format(path)) t, msg = self._request(CMD_STAT, path) if t != CMD_ATTRS: - raise SFTPError('Expected attributes') + raise SFTPError("Expected attributes") return SFTPAttributes._from_msg(msg) def lstat(self, path): @@ -474,10 +507,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager): file """ path = self._adjust_cwd(path) - self._log(DEBUG, 'lstat({!r})'.format(path)) + self._log(DEBUG, "lstat({!r})".format(path)) t, msg = self._request(CMD_LSTAT, path) if t != CMD_ATTRS: - raise SFTPError('Expected attributes') + raise SFTPError("Expected attributes") return SFTPAttributes._from_msg(msg) def symlink(self, source, dest): @@ -488,7 +521,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param str dest: path of the newly created symlink """ dest = self._adjust_cwd(dest) - self._log(DEBUG, 'symlink({!r}, {!r})'.format(source, dest)) + self._log(DEBUG, "symlink({!r}, {!r})".format(source, dest)) source = b(source) self._request(CMD_SYMLINK, source, dest) @@ -502,7 +535,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int mode: new permissions """ path = self._adjust_cwd(path) - self._log(DEBUG, 'chmod({!r}, {!r})'.format(path, mode)) + self._log(DEBUG, "chmod({!r}, {!r})".format(path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_SETSTAT, path, attr) @@ -519,7 +552,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int gid: new group id """ path = self._adjust_cwd(path) - self._log(DEBUG, 'chown({!r}, {!r}, {!r})'.format(path, uid, gid)) + self._log(DEBUG, "chown({!r}, {!r}, {!r})".format(path, uid, gid)) attr = SFTPAttributes() attr.st_uid, attr.st_gid = uid, gid self._request(CMD_SETSTAT, path, attr) @@ -541,7 +574,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): path = self._adjust_cwd(path) if times is None: times = (time.time(), time.time()) - self._log(DEBUG, 'utime({!r}, {!r})'.format(path, times)) + self._log(DEBUG, "utime({!r}, {!r})".format(path, times)) attr = SFTPAttributes() attr.st_atime, attr.st_mtime = times self._request(CMD_SETSTAT, path, attr) @@ -556,7 +589,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int size: the new size of the file """ path = self._adjust_cwd(path) - self._log(DEBUG, 'truncate({!r}, {!r})'.format(path, size)) + self._log(DEBUG, "truncate({!r}, {!r})".format(path, size)) attr = SFTPAttributes() attr.st_size = size self._request(CMD_SETSTAT, path, attr) @@ -571,15 +604,15 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :return: target path, as a `str` """ path = self._adjust_cwd(path) - self._log(DEBUG, 'readlink({!r})'.format(path)) + self._log(DEBUG, "readlink({!r})".format(path)) t, msg = self._request(CMD_READLINK, path) if t != CMD_NAME: - raise SFTPError('Expected name response') + raise SFTPError("Expected name response") count = msg.get_int() if count == 0: return None if count != 1: - raise SFTPError('Readlink returned {} results'.format(count)) + raise SFTPError("Readlink returned {} results".format(count)) return _to_unicode(msg.get_string()) def normalize(self, path): @@ -595,13 +628,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :raises: ``IOError`` -- if the path can't be resolved on the server """ path = self._adjust_cwd(path) - self._log(DEBUG, 'normalize({!r})'.format(path)) + self._log(DEBUG, "normalize({!r})".format(path)) t, msg = self._request(CMD_REALPATH, path) if t != CMD_NAME: - raise SFTPError('Expected name response') + raise SFTPError("Expected name response") count = msg.get_int() if count != 1: - raise SFTPError('Realpath returned {} results'.format(count)) + raise SFTPError("Realpath returned {} results".format(count)) return msg.get_text() def chdir(self, path=None): @@ -625,9 +658,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): return if not stat.S_ISDIR(self.stat(path).st_mode): code = errno.ENOTDIR - raise SFTPError( - code, "{}: {}".format(os.strerror(code), path) - ) + raise SFTPError(code, "{}: {}".format(os.strerror(code), path)) self._cwd = b(self.normalize(path)) def getcwd(self): @@ -680,7 +711,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.10 """ - with self.file(remotepath, 'wb') as fr: + with self.file(remotepath, "wb") as fr: fr.set_pipelined(True) size = self._transfer_with_callback( reader=fl, writer=fr, file_size=file_size, callback=callback @@ -689,7 +720,8 @@ class SFTPClient(BaseSFTP, ClosingContextManager): s = self.stat(remotepath) if s.st_size != size: raise IOError( - 'size mismatch in put! {} != {}'.format(s.st_size, size)) + "size mismatch in put! {} != {}".format(s.st_size, size) + ) else: s = SFTPAttributes() return s @@ -723,7 +755,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): ``confirm`` param added. """ file_size = os.stat(localpath).st_size - with open(localpath, 'rb') as fl: + with open(localpath, "rb") as fl: return self.putfo(fl, remotepath, file_size, callback, confirm) def getfo(self, remotepath, fl, callback=None): @@ -744,7 +776,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.10 """ file_size = self.stat(remotepath).st_size - with self.open(remotepath, 'rb') as fr: + with self.open(remotepath, "rb") as fr: fr.prefetch(file_size) return self._transfer_with_callback( reader=fr, writer=fl, file_size=file_size, callback=callback @@ -766,12 +798,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionchanged:: 1.7.4 Added the ``callback`` param """ - with open(localpath, 'wb') as fl: + with open(localpath, "wb") as fl: size = self.getfo(remotepath, fl, callback) s = os.stat(localpath) if s.st_size != size: raise IOError( - 'size mismatch in get! {} != {}'.format(s.st_size, size)) + "size mismatch in get! {} != {}".format(s.st_size, size) + ) # ...internals... @@ -809,7 +842,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): try: t, data = self._read_packet() except EOFError as e: - raise SSHException('Server connection dropped: {}'.format(e)) + raise SSHException("Server connection dropped: {}".format(e)) msg = Message(data) num = msg.get_int() self._lock.acquire() @@ -817,7 +850,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): if num not in self._expecting: # might be response for a file that was closed before # responses came back - self._log(DEBUG, 'Unexpected response #{}'.format(num)) + self._log(DEBUG, "Unexpected response #{}".format(num)) if waitfor is None: # just doing a single check break diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index 52f2bde8..0104d857 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -32,13 +32,21 @@ from paramiko.common import DEBUG from paramiko.file import BufferedFile from paramiko.py3compat import u, long from paramiko.sftp import ( - CMD_CLOSE, CMD_READ, CMD_DATA, SFTPError, CMD_WRITE, CMD_STATUS, CMD_FSTAT, - CMD_ATTRS, CMD_FSETSTAT, CMD_EXTENDED, + CMD_CLOSE, + CMD_READ, + CMD_DATA, + SFTPError, + CMD_WRITE, + CMD_STATUS, + CMD_FSTAT, + CMD_ATTRS, + CMD_FSETSTAT, + CMD_EXTENDED, ) from paramiko.sftp_attr import SFTPAttributes -class SFTPFile (BufferedFile): +class SFTPFile(BufferedFile): """ Proxy object for a file on the remote server, in client mode SFTP. @@ -50,7 +58,7 @@ class SFTPFile (BufferedFile): # this size. MAX_REQUEST_SIZE = 32768 - def __init__(self, sftp, handle, mode='r', bufsize=-1): + def __init__(self, sftp, handle, mode="r", bufsize=-1): BufferedFile.__init__(self) self.sftp = sftp self.handle = handle @@ -83,7 +91,7 @@ class SFTPFile (BufferedFile): # __del__.) if self._closed: return - self.sftp._log(DEBUG, 'close({})'.format(u(hexlify(self.handle)))) + self.sftp._log(DEBUG, "close({})".format(u(hexlify(self.handle)))) if self.pipelined: self.sftp._finish_responses(self) BufferedFile.close(self) @@ -102,8 +110,9 @@ class SFTPFile (BufferedFile): pass def _data_in_prefetch_requests(self, offset, size): - k = [x for x in list(self._prefetch_extents.values()) - if x[0] <= offset] + k = [ + x for x in list(self._prefetch_extents.values()) if x[0] <= offset + ] if len(k) == 0: return False k.sort(key=lambda x: x[0]) @@ -117,8 +126,8 @@ class SFTPFile (BufferedFile): # well, we have part of the request. see if another chunk has # the rest. return self._data_in_prefetch_requests( - buf_offset + buf_size, - offset + size - buf_offset - buf_size) + buf_offset + buf_size, offset + size - buf_offset - buf_size + ) def _data_in_prefetch_buffers(self, offset): """ @@ -174,13 +183,10 @@ class SFTPFile (BufferedFile): if data is not None: return data t, msg = self.sftp._request( - CMD_READ, - self.handle, - long(self._realpos), - int(size) + CMD_READ, self.handle, long(self._realpos), int(size) ) if t != CMD_DATA: - raise SFTPError('Expected data') + raise SFTPError("Expected data") return msg.get_string() def _write(self, data): @@ -191,18 +197,17 @@ class SFTPFile (BufferedFile): CMD_WRITE, self.handle, long(self._realpos), - data[:chunk] + data[:chunk], ) self._reqs.append(sftp_async_request) - if ( - not self.pipelined or - (len(self._reqs) > 100 and self.sftp.sock.recv_ready()) + if not self.pipelined or ( + len(self._reqs) > 100 and self.sftp.sock.recv_ready() ): while len(self._reqs): req = self._reqs.popleft() t, msg = self.sftp._read_response(req) if t != CMD_STATUS: - raise SFTPError('Expected status') + raise SFTPError("Expected status") # convert_status already called return chunk @@ -277,7 +282,7 @@ class SFTPFile (BufferedFile): """ t, msg = self.sftp._request(CMD_FSTAT, self.handle) if t != CMD_ATTRS: - raise SFTPError('Expected attributes') + raise SFTPError("Expected attributes") return SFTPAttributes._from_msg(msg) def chmod(self, mode): @@ -288,8 +293,9 @@ class SFTPFile (BufferedFile): :param int mode: new permissions """ - self.sftp._log(DEBUG, 'chmod({}, {!r})'.format( - hexlify(self.handle), mode)) + self.sftp._log( + DEBUG, "chmod({}, {!r})".format(hexlify(self.handle), mode) + ) attr = SFTPAttributes() attr.st_mode = mode self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -306,7 +312,8 @@ class SFTPFile (BufferedFile): """ self.sftp._log( DEBUG, - 'chown({}, {!r}, {!r})'.format(hexlify(self.handle), uid, gid)) + "chown({}, {!r}, {!r})".format(hexlify(self.handle), uid, gid), + ) attr = SFTPAttributes() attr.st_uid, attr.st_gid = uid, gid self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -326,8 +333,9 @@ class SFTPFile (BufferedFile): """ if times is None: times = (time.time(), time.time()) - self.sftp._log(DEBUG, 'utime({}, {!r})'.format( - hexlify(self.handle), times)) + self.sftp._log( + DEBUG, "utime({}, {!r})".format(hexlify(self.handle), times) + ) attr = SFTPAttributes() attr.st_atime, attr.st_mtime = times self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -341,8 +349,8 @@ class SFTPFile (BufferedFile): :param size: the new size of the file """ self.sftp._log( - DEBUG, - 'truncate({}, {!r})'.format(hexlify(self.handle), size)) + DEBUG, "truncate({}, {!r})".format(hexlify(self.handle), size) + ) attr = SFTPAttributes() attr.st_size = size self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -394,8 +402,14 @@ class SFTPFile (BufferedFile): .. versionadded:: 1.4 """ t, msg = self.sftp._request( - CMD_EXTENDED, 'check-file', self.handle, - hash_algorithm, long(offset), long(length), block_size) + CMD_EXTENDED, + "check-file", + self.handle, + hash_algorithm, + long(offset), + long(length), + block_size, + ) msg.get_text() # ext msg.get_text() # alg data = msg.get_remainder() @@ -475,16 +489,16 @@ class SFTPFile (BufferedFile): .. versionadded:: 1.5.4 """ - self.sftp._log(DEBUG, 'readv({}, {!r})'.format( - hexlify(self.handle), chunks)) + self.sftp._log( + DEBUG, "readv({}, {!r})".format(hexlify(self.handle), chunks) + ) read_chunks = [] for offset, size in chunks: # don't fetch data that's already in the prefetch buffer - if ( - self._data_in_prefetch_buffers(offset) or - self._data_in_prefetch_requests(offset, size) - ): + if self._data_in_prefetch_buffers( + offset + ) or self._data_in_prefetch_requests(offset, size): continue # break up anything larger than the max read size @@ -521,11 +535,8 @@ class SFTPFile (BufferedFile): # a lot of them, so it may block. for offset, length in chunks: num = self.sftp._async_request( - self, - CMD_READ, - self.handle, - long(offset), - int(length)) + self, CMD_READ, self.handle, long(offset), int(length) + ) with self._prefetch_lock: self._prefetch_extents[num] = (offset, length) @@ -538,7 +549,7 @@ class SFTPFile (BufferedFile): self._saved_exception = e return if t != CMD_DATA: - raise SFTPError('Expected data') + raise SFTPError("Expected data") data = msg.get_string() while True: with self._prefetch_lock: diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py index ca473900..a7e22f01 100644 --- a/paramiko/sftp_handle.py +++ b/paramiko/sftp_handle.py @@ -25,7 +25,7 @@ from paramiko.sftp import SFTP_OP_UNSUPPORTED, SFTP_OK from paramiko.util import ClosingContextManager -class SFTPHandle (ClosingContextManager): +class SFTPHandle(ClosingContextManager): """ Abstract object representing a handle to an open file (or folder) in an SFTP server implementation. Each handle has a string representation used @@ -36,6 +36,7 @@ class SFTPHandle (ClosingContextManager): Instances of this class may be used as context managers. """ + def __init__(self, flags=0): """ Create a new file handle representing a local file being served over @@ -63,10 +64,10 @@ class SFTPHandle (ClosingContextManager): using the default implementations of `read` and `write`, this method's default implementation should be fine also. """ - readfile = getattr(self, 'readfile', None) + readfile = getattr(self, "readfile", None) if readfile is not None: readfile.close() - writefile = getattr(self, 'writefile', None) + writefile = getattr(self, "writefile", None) if writefile is not None: writefile.close() @@ -88,7 +89,7 @@ class SFTPHandle (ClosingContextManager): :param int length: number of bytes to attempt to read. :return: data read from the file, or an SFTP error code, as a `str`. """ - readfile = getattr(self, 'readfile', None) + readfile = getattr(self, "readfile", None) if readfile is None: return SFTP_OP_UNSUPPORTED try: @@ -122,7 +123,7 @@ class SFTPHandle (ClosingContextManager): :param str data: data to write into the file. :return: an SFTP error code like ``SFTP_OK``. """ - writefile = getattr(self, 'writefile', None) + writefile = getattr(self, "writefile", None) if writefile is None: return SFTP_OP_UNSUPPORTED try: diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py index f8c4f727..8265df96 100644 --- a/paramiko/sftp_server.py +++ b/paramiko/sftp_server.py @@ -27,7 +27,11 @@ from hashlib import md5, sha1 from paramiko import util from paramiko.sftp import ( - BaseSFTP, Message, SFTP_FAILURE, SFTP_PERMISSION_DENIED, SFTP_NO_SUCH_FILE, + BaseSFTP, + Message, + SFTP_FAILURE, + SFTP_PERMISSION_DENIED, + SFTP_NO_SUCH_FILE, ) from paramiko.sftp_si import SFTPServerInterface from paramiko.sftp_attr import SFTPAttributes @@ -38,30 +42,64 @@ from paramiko.server import SubsystemHandler # known hash algorithms for the "check-file" extension from paramiko.sftp import ( - CMD_HANDLE, SFTP_DESC, CMD_STATUS, SFTP_EOF, CMD_NAME, SFTP_BAD_MESSAGE, - CMD_EXTENDED_REPLY, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_APPEND, - SFTP_FLAG_CREATE, SFTP_FLAG_TRUNC, SFTP_FLAG_EXCL, CMD_NAMES, CMD_OPEN, - CMD_CLOSE, SFTP_OK, CMD_READ, CMD_DATA, CMD_WRITE, CMD_REMOVE, CMD_RENAME, - CMD_MKDIR, CMD_RMDIR, CMD_OPENDIR, CMD_READDIR, CMD_STAT, CMD_ATTRS, - CMD_LSTAT, CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_READLINK, CMD_SYMLINK, - CMD_REALPATH, CMD_EXTENDED, SFTP_OP_UNSUPPORTED, + CMD_HANDLE, + SFTP_DESC, + CMD_STATUS, + SFTP_EOF, + CMD_NAME, + SFTP_BAD_MESSAGE, + CMD_EXTENDED_REPLY, + SFTP_FLAG_READ, + SFTP_FLAG_WRITE, + SFTP_FLAG_APPEND, + SFTP_FLAG_CREATE, + SFTP_FLAG_TRUNC, + SFTP_FLAG_EXCL, + CMD_NAMES, + CMD_OPEN, + CMD_CLOSE, + SFTP_OK, + CMD_READ, + CMD_DATA, + CMD_WRITE, + CMD_REMOVE, + CMD_RENAME, + CMD_MKDIR, + CMD_RMDIR, + CMD_OPENDIR, + CMD_READDIR, + CMD_STAT, + CMD_ATTRS, + CMD_LSTAT, + CMD_FSTAT, + CMD_SETSTAT, + CMD_FSETSTAT, + CMD_READLINK, + CMD_SYMLINK, + CMD_REALPATH, + CMD_EXTENDED, + SFTP_OP_UNSUPPORTED, ) -_hash_class = { - 'sha1': sha1, - 'md5': md5, -} +_hash_class = {"sha1": sha1, "md5": md5} -class SFTPServer (BaseSFTP, SubsystemHandler): +class SFTPServer(BaseSFTP, SubsystemHandler): """ Server-side SFTP subsystem support. Since this is a `.SubsystemHandler`, it can be (and is meant to be) set as the handler for ``"sftp"`` requests. Use `.Transport.set_subsystem_handler` to activate this class. """ - def __init__(self, channel, name, server, sftp_si=SFTPServerInterface, - *largs, **kwargs): + def __init__( + self, + channel, + name, + server, + sftp_si=SFTPServerInterface, + *largs, + **kwargs + ): """ The constructor for SFTPServer is meant to be called from within the `.Transport` as a subsystem handler. ``server`` and any additional @@ -79,7 +117,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): BaseSFTP.__init__(self) SubsystemHandler.__init__(self, channel, name, server) transport = channel.get_transport() - self.logger = util.get_logger(transport.get_log_channel() + '.sftp') + self.logger = util.get_logger(transport.get_log_channel() + ".sftp") self.ultra_debug = transport.get_hexdump() self.next_handle = 1 # map of handle-string to SFTPHandle for files & folders: @@ -91,26 +129,26 @@ class SFTPServer (BaseSFTP, SubsystemHandler): if issubclass(type(msg), list): for m in msg: super(SFTPServer, self)._log( - level, - "[chan " + self.sock.get_name() + "] " + m) + level, "[chan " + self.sock.get_name() + "] " + m + ) else: super(SFTPServer, self)._log( - level, - "[chan " + self.sock.get_name() + "] " + msg) + level, "[chan " + self.sock.get_name() + "] " + msg + ) def start_subsystem(self, name, transport, channel): self.sock = channel - self._log(DEBUG, 'Started sftp server on channel {!r}'.format(channel)) + self._log(DEBUG, "Started sftp server on channel {!r}".format(channel)) self._send_server_version() self.server.session_started() while True: try: t, data = self._read_packet() except EOFError: - self._log(DEBUG, 'EOF -- end of session') + self._log(DEBUG, "EOF -- end of session") return except Exception as e: - self._log(DEBUG, 'Exception on channel: ' + str(e)) + self._log(DEBUG, "Exception on channel: " + str(e)) self._log(DEBUG, util.tb_strings()) return msg = Message(data) @@ -118,7 +156,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): try: self._process(t, request_number, msg) except Exception as e: - self._log(DEBUG, 'Exception in server processing: ' + str(e)) + self._log(DEBUG, "Exception in server processing: " + str(e)) self._log(DEBUG, util.tb_strings()) # send some kind of failure message, at least try: @@ -172,7 +210,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): name of the file to alter (should usually be an absolute path). :param .SFTPAttributes attr: attributes to change. """ - if sys.platform != 'win32': + if sys.platform != "win32": # mode operations are meaningless on win32 if attr._flags & attr.FLAG_PERMISSIONS: os.chmod(filename, attr.st_mode) @@ -181,7 +219,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): if attr._flags & attr.FLAG_AMTIME: os.utime(filename, (attr.st_atime, attr.st_mtime)) if attr._flags & attr.FLAG_SIZE: - with open(filename, 'w+') as f: + with open(filename, "w+") as f: f.truncate(attr.st_size) # ...internals... @@ -200,8 +238,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): item._pack(msg) else: raise Exception( - 'unknown type for {!r} type {!r}'.format( - item, type(item))) + "unknown type for {!r} type {!r}".format(item, type(item)) + ) self._send_packet(t, msg) def _send_handle_response(self, request_number, handle, folder=False): @@ -209,7 +247,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): # must be error code self._send_status(request_number, handle) return - handle._set_name(b('hx{:d}'.format(self.next_handle))) + handle._set_name(b("hx{:d}".format(self.next_handle))) self.next_handle += 1 if folder: self.folder_table[handle._get_name()] = handle @@ -222,10 +260,10 @@ class SFTPServer (BaseSFTP, SubsystemHandler): try: desc = SFTP_DESC[code] except IndexError: - desc = 'Unknown' + desc = "Unknown" # some clients expect a "langauge" tag at the end # (but don't mind it being blank) - self._response(request_number, CMD_STATUS, code, desc, '') + self._response(request_number, CMD_STATUS, code, desc, "") def _open_folder(self, request_number, path): resp = self.server.list_folder(path) @@ -264,7 +302,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): block_size = msg.get_int() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return f = self.file_table[handle] for x in alg_list: @@ -274,19 +313,21 @@ class SFTPServer (BaseSFTP, SubsystemHandler): break else: self._send_status( - request_number, SFTP_FAILURE, 'No supported hash types found') + request_number, SFTP_FAILURE, "No supported hash types found" + ) return if length == 0: st = f.stat() if not issubclass(type(st), SFTPAttributes): - self._send_status(request_number, st, 'Unable to stat file') + self._send_status(request_number, st, "Unable to stat file") return length = st.st_size - start if block_size == 0: block_size = length if block_size < 256: self._send_status( - request_number, SFTP_FAILURE, 'Block size too small') + request_number, SFTP_FAILURE, "Block size too small" + ) return sum_out = bytes() @@ -301,7 +342,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): data = f.read(offset, chunklen) if not isinstance(data, bytes_types): self._send_status( - request_number, data, 'Unable to hash file') + request_number, data, "Unable to hash file" + ) return hash_obj.update(data) count += len(data) @@ -310,7 +352,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): msg = Message() msg.add_int(request_number) - msg.add_string('check-file') + msg.add_string("check-file") msg.add_string(algname) msg.add_bytes(sum_out) self._send_packet(CMD_EXTENDED_REPLY, msg) @@ -334,13 +376,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler): return flags def _process(self, t, request_number, msg): - self._log(DEBUG, 'Request: {}'.format(CMD_NAMES[t])) + self._log(DEBUG, "Request: {}".format(CMD_NAMES[t])) if t == CMD_OPEN: path = msg.get_text() flags = self._convert_pflags(msg.get_int()) attr = SFTPAttributes._from_msg(msg) self._send_handle_response( - request_number, self.server.open(path, flags, attr)) + request_number, self.server.open(path, flags, attr) + ) elif t == CMD_CLOSE: handle = msg.get_binary() if handle in self.folder_table: @@ -353,14 +396,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_status(request_number, SFTP_OK) return self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) elif t == CMD_READ: handle = msg.get_binary() offset = msg.get_int64() length = msg.get_int() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return data = self.file_table[handle].read(offset, length) if isinstance(data, (bytes_types, string_types)): @@ -376,10 +421,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler): data = msg.get_binary() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return self._send_status( - request_number, self.file_table[handle].write(offset, data)) + request_number, self.file_table[handle].write(offset, data) + ) elif t == CMD_REMOVE: path = msg.get_text() self._send_status(request_number, self.server.remove(path)) @@ -387,7 +434,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): oldpath = msg.get_text() newpath = msg.get_text() self._send_status( - request_number, self.server.rename(oldpath, newpath)) + request_number, self.server.rename(oldpath, newpath) + ) elif t == CMD_MKDIR: path = msg.get_text() attr = SFTPAttributes._from_msg(msg) @@ -403,7 +451,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): handle = msg.get_binary() if handle not in self.folder_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return folder = self.folder_table[handle] self._read_folder(request_number, folder) @@ -425,7 +474,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): handle = msg.get_binary() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return resp = self.file_table[handle].stat() if issubclass(type(resp), SFTPAttributes): @@ -441,16 +491,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler): attr = SFTPAttributes._from_msg(msg) if handle not in self.file_table: self._response( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return self._send_status( - request_number, self.file_table[handle].chattr(attr)) + request_number, self.file_table[handle].chattr(attr) + ) elif t == CMD_READLINK: path = msg.get_text() resp = self.server.readlink(path) if isinstance(resp, (bytes_types, string_types)): self._response( - request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) + request_number, CMD_NAME, 1, resp, "", SFTPAttributes() + ) else: self._send_status(request_number, resp) elif t == CMD_SYMLINK: @@ -459,17 +512,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler): target_path = msg.get_text() path = msg.get_text() self._send_status( - request_number, self.server.symlink(target_path, path)) + request_number, self.server.symlink(target_path, path) + ) elif t == CMD_REALPATH: path = msg.get_text() rpath = self.server.canonicalize(path) self._response( - request_number, CMD_NAME, 1, rpath, '', SFTPAttributes()) + request_number, CMD_NAME, 1, rpath, "", SFTPAttributes() + ) elif t == CMD_EXTENDED: tag = msg.get_text() - if tag == 'check-file': + if tag == "check-file": self._check_file(request_number, msg) - elif tag == 'posix-rename@openssh.com': + elif tag == "posix-rename@openssh.com": oldpath = msg.get_text() newpath = msg.get_text() self._send_status( diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py index b43b582c..40dc561c 100644 --- a/paramiko/sftp_si.py +++ b/paramiko/sftp_si.py @@ -25,7 +25,7 @@ import sys from paramiko.sftp import SFTP_OP_UNSUPPORTED -class SFTPServerInterface (object): +class SFTPServerInterface(object): """ This class defines an interface for controlling the behavior of paramiko when using the `.SFTPServer` subsystem to provide an SFTP server. @@ -39,6 +39,7 @@ class SFTPServerInterface (object): All paths are in string form instead of unicode because not all SFTP clients & servers obey the requirement that paths be encoded in UTF-8. """ + def __init__(self, server, *largs, **kwargs): """ Create a new SFTPServerInterface object. This method does nothing by @@ -281,10 +282,10 @@ class SFTPServerInterface (object): if os.path.isabs(path): out = os.path.normpath(path) else: - out = os.path.normpath('/' + path) - if sys.platform == 'win32': + out = os.path.normpath("/" + path) + if sys.platform == "win32": # on windows, normalize backslashes to sftp/posix format - out = out.replace('\\', '/') + out = out.replace("\\", "/") return out def readlink(self, path): diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py index 2df84b65..12407d66 100644 --- a/paramiko/ssh_exception.py +++ b/paramiko/ssh_exception.py @@ -19,14 +19,14 @@ import socket -class SSHException (Exception): +class SSHException(Exception): """ Exception raised by failures in SSH2 protocol negotiation or logic errors. """ pass -class AuthenticationException (SSHException): +class AuthenticationException(SSHException): """ Exception raised when authentication failed for some reason. It may be possible to retry with different credentials. (Other classes specify more @@ -37,14 +37,14 @@ class AuthenticationException (SSHException): pass -class PasswordRequiredException (AuthenticationException): +class PasswordRequiredException(AuthenticationException): """ Exception raised when a password is needed to unlock a private key file. """ pass -class BadAuthenticationType (AuthenticationException): +class BadAuthenticationType(AuthenticationException): """ Exception raised when an authentication type (like password) is used, but the server isn't allowing that type. (It may only allow public-key, for @@ -60,28 +60,28 @@ class BadAuthenticationType (AuthenticationException): AuthenticationException.__init__(self, explanation) self.allowed_types = types # for unpickling - self.args = (explanation, types, ) + self.args = (explanation, types) def __str__(self): - return '{} (allowed_types={!r})'.format( + return "{} (allowed_types={!r})".format( SSHException.__str__(self), self.allowed_types ) -class PartialAuthentication (AuthenticationException): +class PartialAuthentication(AuthenticationException): """ An internal exception thrown in the case of partial authentication. """ allowed_types = [] def __init__(self, types): - AuthenticationException.__init__(self, 'partial authentication') + AuthenticationException.__init__(self, "partial authentication") self.allowed_types = types # for unpickling - self.args = (types, ) + self.args = (types,) -class ChannelException (SSHException): +class ChannelException(SSHException): """ Exception raised when an attempt to open a new `.Channel` fails. @@ -89,14 +89,15 @@ class ChannelException (SSHException): .. versionadded:: 1.6 """ + def __init__(self, code, text): SSHException.__init__(self, text) self.code = code # for unpickling - self.args = (code, text, ) + self.args = (code, text) -class BadHostKeyException (SSHException): +class BadHostKeyException(SSHException): """ The host key given by the SSH server did not match what we were expecting. @@ -106,35 +107,40 @@ class BadHostKeyException (SSHException): .. versionadded:: 1.6 """ + def __init__(self, hostname, got_key, expected_key): - message = 'Host key for server {} does not match: got {}, expected {}' # noqa + message = ( + "Host key for server {} does not match: got {}, expected {}" + ) # noqa message = message.format( - hostname, got_key.get_base64(), - expected_key.get_base64()) + hostname, got_key.get_base64(), expected_key.get_base64() + ) SSHException.__init__(self, message) self.hostname = hostname self.key = got_key self.expected_key = expected_key # for unpickling - self.args = (hostname, got_key, expected_key, ) + self.args = (hostname, got_key, expected_key) -class ProxyCommandFailure (SSHException): +class ProxyCommandFailure(SSHException): """ The "ProxyCommand" found in the .ssh/config file returned an error. :param str command: The command line that is generating this exception. :param str error: The error captured from the proxy command output. """ + def __init__(self, command, error): - SSHException.__init__(self, + SSHException.__init__( + self, '"ProxyCommand ({})" returned non-zero exit status: {}'.format( command, error - ) + ), ) self.error = error # for unpickling - self.args = (command, error, ) + self.args = (command, error) class NoValidConnectionsError(socket.error): @@ -159,23 +165,23 @@ class NoValidConnectionsError(socket.error): .. versionadded:: 1.16 """ + def __init__(self, errors): """ :param dict errors: The errors dict to store, as described by class docstring. """ addrs = sorted(errors.keys()) - body = ', '.join([x[0] for x in addrs[:-1]]) + body = ", ".join([x[0] for x in addrs[:-1]]) tail = addrs[-1][0] if body: msg = "Unable to connect to port {0} on {1} or {2}" else: msg = "Unable to connect to port {0} on {2}" super(NoValidConnectionsError, self).__init__( - None, # stand-in for errno - msg.format(addrs[0][1], body, tail) + None, msg.format(addrs[0][1], body, tail) # stand-in for errno ) self.errors = errors def __reduce__(self): - return (self.__class__, (self.errors, )) + return (self.__class__, (self.errors,)) diff --git a/paramiko/ssh_gss.py b/paramiko/ssh_gss.py index aa7cc74d..3f299aee 100644 --- a/paramiko/ssh_gss.py +++ b/paramiko/ssh_gss.py @@ -42,19 +42,19 @@ GSS_AUTH_AVAILABLE = True GSS_EXCEPTIONS = () - - #: :var str _API: Constraint for the used API _API = "MIT" try: import gssapi + GSS_EXCEPTIONS = (gssapi.GSSException,) except (ImportError, OSError): try: import pywintypes import sspicon import sspi + _API = "SSPI" GSS_EXCEPTIONS = (pywintypes.error,) except ImportError: @@ -99,6 +99,7 @@ class _SSH_GSSAuth(object): Contains the shared variables and methods of `._SSH_GSSAPI` and `._SSH_SSPI`. """ + def __init__(self, auth_method, gss_deleg_creds): """ :param str auth_method: The name of the SSH authentication mechanism @@ -160,6 +161,7 @@ class _SSH_GSSAuth(object): """ from pyasn1.type.univ import ObjectIdentifier from pyasn1.codec.der import encoder + OIDs = self._make_uint32(1) krb5_OID = encoder.encode(ObjectIdentifier(self._krb5_mech)) OID_len = self._make_uint32(len(krb5_OID)) @@ -175,6 +177,7 @@ class _SSH_GSSAuth(object): :return: ``True`` if the given OID is supported, otherwise C{False} """ from pyasn1.codec.der import decoder + mech, __ = decoder.decode(desired_mech) if mech.__str__() != self._krb5_mech: return False @@ -210,7 +213,7 @@ class _SSH_GSSAuth(object): """ mic = self._make_uint32(len(session_id)) mic += session_id - mic += struct.pack('B', MSG_USERAUTH_REQUEST) + mic += struct.pack("B", MSG_USERAUTH_REQUEST) mic += self._make_uint32(len(username)) mic += username.encode() mic += self._make_uint32(len(service)) @@ -226,6 +229,7 @@ class _SSH_GSSAPI(_SSH_GSSAuth): :see: `.GSSAuth` """ + def __init__(self, auth_method, gss_deleg_creds): """ :param str auth_method: The name of the SSH authentication mechanism @@ -235,17 +239,22 @@ class _SSH_GSSAPI(_SSH_GSSAuth): _SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds) if self._gss_deleg_creds: - self._gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG, - gssapi.C_DELEG_FLAG) + self._gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + gssapi.C_DELEG_FLAG, + ) else: - self._gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG) + self._gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + ) - def ssh_init_sec_context(self, target, desired_mech=None, - username=None, recv_token=None): + def ssh_init_sec_context( + self, target, desired_mech=None, username=None, recv_token=None + ): """ Initialize a GSS-API context. @@ -262,10 +271,12 @@ class _SSH_GSSAPI(_SSH_GSSAuth): ``None`` if no token was returned """ from pyasn1.codec.der import decoder + self._username = username self._gss_host = target - targ_name = gssapi.Name("host@" + self._gss_host, - gssapi.C_NT_HOSTBASED_SERVICE) + targ_name = gssapi.Name( + "host@" + self._gss_host, gssapi.C_NT_HOSTBASED_SERVICE + ) ctx = gssapi.Context() ctx.flags = self._gss_flags if desired_mech is None: @@ -279,15 +290,16 @@ class _SSH_GSSAPI(_SSH_GSSAuth): token = None try: if recv_token is None: - self._gss_ctxt = gssapi.InitContext(peer_name=targ_name, - mech_type=krb5_mech, - req_flags=ctx.flags) + self._gss_ctxt = gssapi.InitContext( + peer_name=targ_name, + mech_type=krb5_mech, + req_flags=ctx.flags, + ) token = self._gss_ctxt.step(token) else: token = self._gss_ctxt.step(recv_token) except gssapi.GSSException: - message = "{} Target: {}".format( - sys.exc_info()[1], self._gss_host) + message = "{} Target: {}".format(sys.exc_info()[1], self._gss_host) raise gssapi.GSSException(message) self._gss_ctxt_status = self._gss_ctxt.established return token @@ -307,10 +319,12 @@ class _SSH_GSSAPI(_SSH_GSSAuth): """ self._session_id = session_id if not gss_kex: - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) mic_token = self._gss_ctxt.get_mic(mic_field) else: # for key exchange with gssapi-keyex @@ -351,16 +365,17 @@ class _SSH_GSSAPI(_SSH_GSSAuth): self._username = username if self._username is not None: # server mode - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) self._gss_srv_ctxt.verify_mic(mic_field, mic_token) else: # for key exchange with gssapi-keyex # client mode - self._gss_ctxt.verify_mic(self._session_id, - mic_token) + self._gss_ctxt.verify_mic(self._session_id, mic_token) @property def credentials_delegated(self): @@ -393,6 +408,7 @@ class _SSH_SSPI(_SSH_GSSAuth): :see: `.GSSAuth` """ + def __init__(self, auth_method, gss_deleg_creds): """ :param str auth_method: The name of the SSH authentication mechanism @@ -403,18 +419,18 @@ class _SSH_SSPI(_SSH_GSSAuth): if self._gss_deleg_creds: self._gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH | - sspicon.ISC_REQ_DELEGATE + sspicon.ISC_REQ_INTEGRITY + | sspicon.ISC_REQ_MUTUAL_AUTH + | sspicon.ISC_REQ_DELEGATE ) else: self._gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH + sspicon.ISC_REQ_INTEGRITY | sspicon.ISC_REQ_MUTUAL_AUTH ) - def ssh_init_sec_context(self, target, desired_mech=None, - username=None, recv_token=None): + def ssh_init_sec_context( + self, target, desired_mech=None, username=None, recv_token=None + ): """ Initialize a SSPI context. @@ -431,6 +447,7 @@ class _SSH_SSPI(_SSH_GSSAuth): no token was returned """ from pyasn1.codec.der import decoder + self._username = username self._gss_host = target error = 0 @@ -441,9 +458,9 @@ class _SSH_SSPI(_SSH_GSSAuth): raise SSHException("Unsupported mechanism OID.") try: if recv_token is None: - self._gss_ctxt = sspi.ClientAuth("Kerberos", - scflags=self._gss_flags, - targetspn=targ_name) + self._gss_ctxt = sspi.ClientAuth( + "Kerberos", scflags=self._gss_flags, targetspn=targ_name + ) error, token = self._gss_ctxt.authorize(recv_token) token = token[0].Buffer except pywintypes.error as e: @@ -478,10 +495,12 @@ class _SSH_SSPI(_SSH_GSSAuth): """ self._session_id = session_id if not gss_kex: - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) mic_token = self._gss_ctxt.sign(mic_field) else: # for key exchange with gssapi-keyex @@ -524,10 +543,12 @@ class _SSH_SSPI(_SSH_GSSAuth): self._username = username if username is not None: # server mode - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) # Verifies data and its signature. If verification fails, an # sspi.error will be raised. self._gss_srv_ctxt.verify(mic_field, mic_token) @@ -545,9 +566,8 @@ class _SSH_SSPI(_SSH_GSSAuth): :return: ``True`` if credentials are delegated, otherwise ``False`` """ - return ( - self._gss_flags & sspicon.ISC_REQ_DELEGATE and - (self._gss_srv_ctxt_status or self._gss_flags) + return self._gss_flags & sspicon.ISC_REQ_DELEGATE and ( + self._gss_srv_ctxt_status or self._gss_flags ) def save_client_creds(self, client_token): diff --git a/paramiko/transport.py b/paramiko/transport.py index 4af29c95..9a5d33dd 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -39,18 +39,49 @@ from paramiko.auth_handler import AuthHandler from paramiko.ssh_gss import GSSAuth from paramiko.channel import Channel from paramiko.common import ( - xffffffff, cMSG_CHANNEL_OPEN, cMSG_IGNORE, cMSG_GLOBAL_REQUEST, DEBUG, - MSG_KEXINIT, MSG_IGNORE, MSG_DISCONNECT, MSG_DEBUG, ERROR, WARNING, - cMSG_UNIMPLEMENTED, INFO, cMSG_KEXINIT, cMSG_NEWKEYS, MSG_NEWKEYS, - cMSG_REQUEST_SUCCESS, cMSG_REQUEST_FAILURE, CONNECTION_FAILED_CODE, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_SUCCEEDED, - cMSG_CHANNEL_OPEN_FAILURE, cMSG_CHANNEL_OPEN_SUCCESS, MSG_GLOBAL_REQUEST, - MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE, MSG_CHANNEL_OPEN_SUCCESS, - MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_OPEN, MSG_CHANNEL_SUCCESS, - MSG_CHANNEL_FAILURE, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, - MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_REQUEST, MSG_CHANNEL_EOF, - MSG_CHANNEL_CLOSE, MIN_WINDOW_SIZE, MIN_PACKET_SIZE, MAX_WINDOW_SIZE, - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, HIGHEST_USERAUTH_MESSAGE_ID, + xffffffff, + cMSG_CHANNEL_OPEN, + cMSG_IGNORE, + cMSG_GLOBAL_REQUEST, + DEBUG, + MSG_KEXINIT, + MSG_IGNORE, + MSG_DISCONNECT, + MSG_DEBUG, + ERROR, + WARNING, + cMSG_UNIMPLEMENTED, + INFO, + cMSG_KEXINIT, + cMSG_NEWKEYS, + MSG_NEWKEYS, + cMSG_REQUEST_SUCCESS, + cMSG_REQUEST_FAILURE, + CONNECTION_FAILED_CODE, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_SUCCEEDED, + cMSG_CHANNEL_OPEN_FAILURE, + cMSG_CHANNEL_OPEN_SUCCESS, + MSG_GLOBAL_REQUEST, + MSG_REQUEST_SUCCESS, + MSG_REQUEST_FAILURE, + MSG_CHANNEL_OPEN_SUCCESS, + MSG_CHANNEL_OPEN_FAILURE, + MSG_CHANNEL_OPEN, + MSG_CHANNEL_SUCCESS, + MSG_CHANNEL_FAILURE, + MSG_CHANNEL_DATA, + MSG_CHANNEL_EXTENDED_DATA, + MSG_CHANNEL_WINDOW_ADJUST, + MSG_CHANNEL_REQUEST, + MSG_CHANNEL_EOF, + MSG_CHANNEL_CLOSE, + MIN_WINDOW_SIZE, + MIN_PACKET_SIZE, + MAX_WINDOW_SIZE, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_PACKET_SIZE, + HIGHEST_USERAUTH_MESSAGE_ID, ) from paramiko.compress import ZlibCompressor, ZlibDecompressor from paramiko.dsskey import DSSKey @@ -69,7 +100,10 @@ from paramiko.ecdsakey import ECDSAKey from paramiko.server import ServerInterface from paramiko.sftp_client import SFTPClient from paramiko.ssh_exception import ( - SSHException, BadAuthenticationType, ChannelException, ProxyCommandFailure, + SSHException, + BadAuthenticationType, + ChannelException, + ProxyCommandFailure, ) from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value @@ -77,12 +111,14 @@ from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value # for thread cleanup _active_threads = [] + def _join_lingering_threads(): for thr in _active_threads: thr.stop_thread() import atexit + atexit.register(_join_lingering_threads) @@ -99,162 +135,163 @@ class Transport(threading.Thread, ClosingContextManager): _ENCRYPT = object() _DECRYPT = object() - _PROTO_ID = '2.0' - _CLIENT_ID = 'paramiko_{}'.format(paramiko.__version__) + _PROTO_ID = "2.0" + _CLIENT_ID = "paramiko_{}".format(paramiko.__version__) # These tuples of algorithm identifiers are in preference order; do not # reorder without reason! _preferred_ciphers = ( - 'aes128-ctr', - 'aes192-ctr', - 'aes256-ctr', - 'aes128-cbc', - 'aes192-cbc', - 'aes256-cbc', - 'blowfish-cbc', - '3des-cbc', + "aes128-ctr", + "aes192-ctr", + "aes256-ctr", + "aes128-cbc", + "aes192-cbc", + "aes256-cbc", + "blowfish-cbc", + "3des-cbc", ) _preferred_macs = ( - 'hmac-sha2-256', - 'hmac-sha2-512', - 'hmac-sha1', - 'hmac-md5', - 'hmac-sha1-96', - 'hmac-md5-96', + "hmac-sha2-256", + "hmac-sha2-512", + "hmac-sha1", + "hmac-md5", + "hmac-sha1-96", + "hmac-md5-96", ) _preferred_keys = ( - 'ssh-ed25519', - 'ecdsa-sha2-nistp256', - 'ecdsa-sha2-nistp384', - 'ecdsa-sha2-nistp521', - 'ssh-rsa', - 'ssh-dss', + "ssh-ed25519", + "ecdsa-sha2-nistp256", + "ecdsa-sha2-nistp384", + "ecdsa-sha2-nistp521", + "ssh-rsa", + "ssh-dss", ) _preferred_kex = ( - 'ecdh-sha2-nistp256', - 'ecdh-sha2-nistp384', - 'ecdh-sha2-nistp521', - 'diffie-hellman-group-exchange-sha256', - 'diffie-hellman-group14-sha256', - 'diffie-hellman-group-exchange-sha1', - 'diffie-hellman-group14-sha1', - 'diffie-hellman-group1-sha1', + "ecdh-sha2-nistp256", + "ecdh-sha2-nistp384", + "ecdh-sha2-nistp521", + "diffie-hellman-group-exchange-sha256", + "diffie-hellman-group14-sha256", + "diffie-hellman-group-exchange-sha1", + "diffie-hellman-group14-sha1", + "diffie-hellman-group1-sha1", ) _preferred_gsskex = ( - 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==', - 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==', - 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==', + "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==", + "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==", + "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==", ) - _preferred_compression = ('none',) + _preferred_compression = ("none",) _cipher_info = { - 'aes128-ctr': { - 'class': algorithms.AES, - 'mode': modes.CTR, - 'block-size': 16, - 'key-size': 16 + "aes128-ctr": { + "class": algorithms.AES, + "mode": modes.CTR, + "block-size": 16, + "key-size": 16, }, - 'aes192-ctr': { - 'class': algorithms.AES, - 'mode': modes.CTR, - 'block-size': 16, - 'key-size': 24 + "aes192-ctr": { + "class": algorithms.AES, + "mode": modes.CTR, + "block-size": 16, + "key-size": 24, }, - 'aes256-ctr': { - 'class': algorithms.AES, - 'mode': modes.CTR, - 'block-size': 16, - 'key-size': 32 + "aes256-ctr": { + "class": algorithms.AES, + "mode": modes.CTR, + "block-size": 16, + "key-size": 32, }, - 'blowfish-cbc': { - 'class': algorithms.Blowfish, - 'mode': modes.CBC, - 'block-size': 8, - 'key-size': 16 + "blowfish-cbc": { + "class": algorithms.Blowfish, + "mode": modes.CBC, + "block-size": 8, + "key-size": 16, }, - 'aes128-cbc': { - 'class': algorithms.AES, - 'mode': modes.CBC, - 'block-size': 16, - 'key-size': 16 + "aes128-cbc": { + "class": algorithms.AES, + "mode": modes.CBC, + "block-size": 16, + "key-size": 16, }, - 'aes192-cbc': { - 'class': algorithms.AES, - 'mode': modes.CBC, - 'block-size': 16, - 'key-size': 24 + "aes192-cbc": { + "class": algorithms.AES, + "mode": modes.CBC, + "block-size": 16, + "key-size": 24, }, - 'aes256-cbc': { - 'class': algorithms.AES, - 'mode': modes.CBC, - 'block-size': 16, - 'key-size': 32 + "aes256-cbc": { + "class": algorithms.AES, + "mode": modes.CBC, + "block-size": 16, + "key-size": 32, }, - '3des-cbc': { - 'class': algorithms.TripleDES, - 'mode': modes.CBC, - 'block-size': 8, - 'key-size': 24 + "3des-cbc": { + "class": algorithms.TripleDES, + "mode": modes.CBC, + "block-size": 8, + "key-size": 24, }, } - _mac_info = { - 'hmac-sha1': {'class': sha1, 'size': 20}, - 'hmac-sha1-96': {'class': sha1, 'size': 12}, - 'hmac-sha2-256': {'class': sha256, 'size': 32}, - 'hmac-sha2-512': {'class': sha512, 'size': 64}, - 'hmac-md5': {'class': md5, 'size': 16}, - 'hmac-md5-96': {'class': md5, 'size': 12}, + "hmac-sha1": {"class": sha1, "size": 20}, + "hmac-sha1-96": {"class": sha1, "size": 12}, + "hmac-sha2-256": {"class": sha256, "size": 32}, + "hmac-sha2-512": {"class": sha512, "size": 64}, + "hmac-md5": {"class": md5, "size": 16}, + "hmac-md5-96": {"class": md5, "size": 12}, } _key_info = { - 'ssh-rsa': RSAKey, - 'ssh-rsa-cert-v01@openssh.com': RSAKey, - 'ssh-dss': DSSKey, - 'ssh-dss-cert-v01@openssh.com': DSSKey, - 'ecdsa-sha2-nistp256': ECDSAKey, - 'ecdsa-sha2-nistp256-cert-v01@openssh.com': ECDSAKey, - 'ecdsa-sha2-nistp384': ECDSAKey, - 'ecdsa-sha2-nistp384-cert-v01@openssh.com': ECDSAKey, - 'ecdsa-sha2-nistp521': ECDSAKey, - 'ecdsa-sha2-nistp521-cert-v01@openssh.com': ECDSAKey, - 'ssh-ed25519': Ed25519Key, - 'ssh-ed25519-cert-v01@openssh.com': Ed25519Key, + "ssh-rsa": RSAKey, + "ssh-rsa-cert-v01@openssh.com": RSAKey, + "ssh-dss": DSSKey, + "ssh-dss-cert-v01@openssh.com": DSSKey, + "ecdsa-sha2-nistp256": ECDSAKey, + "ecdsa-sha2-nistp256-cert-v01@openssh.com": ECDSAKey, + "ecdsa-sha2-nistp384": ECDSAKey, + "ecdsa-sha2-nistp384-cert-v01@openssh.com": ECDSAKey, + "ecdsa-sha2-nistp521": ECDSAKey, + "ecdsa-sha2-nistp521-cert-v01@openssh.com": ECDSAKey, + "ssh-ed25519": Ed25519Key, + "ssh-ed25519-cert-v01@openssh.com": Ed25519Key, } _kex_info = { - 'diffie-hellman-group1-sha1': KexGroup1, - 'diffie-hellman-group14-sha1': KexGroup14, - 'diffie-hellman-group-exchange-sha1': KexGex, - 'diffie-hellman-group-exchange-sha256': KexGexSHA256, - 'diffie-hellman-group14-sha256': KexGroup14SHA256, - 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup1, - 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup14, - 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGex, - 'ecdh-sha2-nistp256': KexNistp256, - 'ecdh-sha2-nistp384': KexNistp384, - 'ecdh-sha2-nistp521': KexNistp521, + "diffie-hellman-group1-sha1": KexGroup1, + "diffie-hellman-group14-sha1": KexGroup14, + "diffie-hellman-group-exchange-sha1": KexGex, + "diffie-hellman-group-exchange-sha256": KexGexSHA256, + "diffie-hellman-group14-sha256": KexGroup14SHA256, + "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup1, + "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup14, + "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGex, + "ecdh-sha2-nistp256": KexNistp256, + "ecdh-sha2-nistp384": KexNistp384, + "ecdh-sha2-nistp521": KexNistp521, } _compression_info = { # zlib@openssh.com is just zlib, but only turned on after a successful # authentication. openssh servers may only offer this type because # they've had troubles with security holes in zlib in the past. - 'zlib@openssh.com': (ZlibCompressor, ZlibDecompressor), - 'zlib': (ZlibCompressor, ZlibDecompressor), - 'none': (None, None), + "zlib@openssh.com": (ZlibCompressor, ZlibDecompressor), + "zlib": (ZlibCompressor, ZlibDecompressor), + "none": (None, None), } _modulus_pack = None _active_check_timeout = 0.1 - def __init__(self, - sock, - default_window_size=DEFAULT_WINDOW_SIZE, - default_max_packet_size=DEFAULT_MAX_PACKET_SIZE, - gss_kex=False, - gss_deleg_creds=True): + def __init__( + self, + sock, + default_window_size=DEFAULT_WINDOW_SIZE, + default_max_packet_size=DEFAULT_MAX_PACKET_SIZE, + gss_kex=False, + gss_deleg_creds=True, + ): """ Create a new SSH session over an existing socket, or socket-like object. This only creates the `.Transport` object; it doesn't begin @@ -304,7 +341,7 @@ class Transport(threading.Thread, ClosingContextManager): if isinstance(sock, string_types): # convert "host:port" into (host, port) - hl = sock.split(':', 1) + hl = sock.split(":", 1) self.hostname = hl[0] if len(hl) == 1: sock = (hl[0], 22) @@ -314,7 +351,7 @@ class Transport(threading.Thread, ClosingContextManager): # connect to the given (host, port) hostname, port = sock self.hostname = hostname - reason = 'No suitable address family' + reason = "No suitable address family" addrinfos = socket.getaddrinfo( hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM ) @@ -331,7 +368,8 @@ class Transport(threading.Thread, ClosingContextManager): break else: raise SSHException( - 'Unable to connect to {}: {}'.format(hostname, reason)) + "Unable to connect to {}: {}".format(hostname, reason) + ) # okay, normal socket-ish flow here... threading.Thread.__init__(self) self.setDaemon(True) @@ -342,9 +380,9 @@ class Transport(threading.Thread, ClosingContextManager): # negotiated crypto parameters self.packetizer = Packetizer(sock) - self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID - self.remote_version = '' - self.local_cipher = self.remote_cipher = '' + self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID + self.remote_version = "" + self.local_cipher = self.remote_cipher = "" self.local_kex_init = self.remote_kex_init = None self.local_mac = self.remote_mac = None self.local_compression = self.remote_compression = None @@ -376,8 +414,8 @@ class Transport(threading.Thread, ClosingContextManager): # tracking open channels self._channels = ChannelMap() - self.channel_events = {} # (id -> Event) - self.channels_seen = {} # (id -> True) + self.channel_events = {} # (id -> Event) + self.channels_seen = {} # (id -> True) self._channel_counter = 0 self.default_max_packet_size = default_max_packet_size self.default_window_size = default_window_size @@ -389,7 +427,7 @@ class Transport(threading.Thread, ClosingContextManager): self.clear_to_send = threading.Event() self.clear_to_send_lock = threading.Lock() self.clear_to_send_timeout = 30.0 - self.log_name = 'paramiko.transport' + self.log_name = "paramiko.transport" self.logger = util.get_logger(self.log_name) self.packetizer.set_log(self.logger) self.auth_handler = None @@ -418,23 +456,24 @@ class Transport(threading.Thread, ClosingContextManager): Returns a string representation of this object, for debugging. """ id_ = hex(long(id(self)) & xffffffff) - out = '<paramiko.Transport at {}'.format(id_) + out = "<paramiko.Transport at {}".format(id_) if not self.active: - out += ' (unconnected)' + out += " (unconnected)" else: - if self.local_cipher != '': - out += ' (cipher {}, {:d} bits)'.format( + if self.local_cipher != "": + out += " (cipher {}, {:d} bits)".format( self.local_cipher, - self._cipher_info[self.local_cipher]['key-size'] * 8 + self._cipher_info[self.local_cipher]["key-size"] * 8, ) if self.is_authenticated(): - out += ' (active; {} open channel(s))'.format( - len(self._channels)) + out += " (active; {} open channel(s))".format( + len(self._channels) + ) elif self.initial_kex_done: - out += ' (connected; awaiting auth)' + out += " (connected; awaiting auth)" else: - out += ' (connecting)' - out += '>' + out += " (connecting)" + out += ">" return out def atfork(self): @@ -545,10 +584,9 @@ class Transport(threading.Thread, ClosingContextManager): e = self.get_exception() if e is not None: raise e - raise SSHException('Negotiation failed.') - if ( - event.is_set() or - (timeout is not None and time.time() >= max_time) + raise SSHException("Negotiation failed.") + if event.is_set() or ( + timeout is not None and time.time() >= max_time ): break @@ -614,7 +652,7 @@ class Transport(threading.Thread, ClosingContextManager): e = self.get_exception() if e is not None: raise e - raise SSHException('Negotiation failed.') + raise SSHException("Negotiation failed.") if event.is_set(): break @@ -681,7 +719,7 @@ class Transport(threading.Thread, ClosingContextManager): """ Transport._modulus_pack = ModulusPack() # places to look for the openssh "moduli" file - file_list = ['/etc/ssh/moduli', '/usr/local/etc/moduli'] + file_list = ["/etc/ssh/moduli", "/usr/local/etc/moduli"] if filename is not None: file_list.insert(0, filename) for fn in file_list: @@ -719,7 +757,7 @@ class Transport(threading.Thread, ClosingContextManager): :return: public key (`.PKey`) of the remote server """ if (not self.active) or (not self.initial_kex_done): - raise SSHException('No existing session') + raise SSHException("No existing session") return self.host_key def is_active(self): @@ -733,10 +771,7 @@ class Transport(threading.Thread, ClosingContextManager): return self.active def open_session( - self, - window_size=None, - max_packet_size=None, - timeout=None, + self, window_size=None, max_packet_size=None, timeout=None ): """ Request a new channel to the server, of type ``"session"``. This is @@ -763,10 +798,12 @@ class Transport(threading.Thread, ClosingContextManager): .. versionchanged:: 1.15 Added the ``window_size`` and ``max_packet_size`` arguments. """ - return self.open_channel('session', - window_size=window_size, - max_packet_size=max_packet_size, - timeout=timeout) + return self.open_channel( + "session", + window_size=window_size, + max_packet_size=max_packet_size, + timeout=timeout, + ) def open_x11_channel(self, src_addr=None): """ @@ -782,7 +819,7 @@ class Transport(threading.Thread, ClosingContextManager): `.SSHException` -- if the request is rejected or the session ends prematurely """ - return self.open_channel('x11', src_addr=src_addr) + return self.open_channel("x11", src_addr=src_addr) def open_forward_agent_channel(self): """ @@ -796,7 +833,7 @@ class Transport(threading.Thread, ClosingContextManager): :raises: `.SSHException` -- if the request is rejected or the session ends prematurely """ - return self.open_channel('auth-agent@openssh.com') + return self.open_channel("auth-agent@openssh.com") def open_forwarded_tcpip_channel(self, src_addr, dest_addr): """ @@ -808,15 +845,17 @@ class Transport(threading.Thread, ClosingContextManager): :param src_addr: originator's address :param dest_addr: local (server) connected address """ - return self.open_channel('forwarded-tcpip', dest_addr, src_addr) + return self.open_channel("forwarded-tcpip", dest_addr, src_addr) - def open_channel(self, - kind, - dest_addr=None, - src_addr=None, - window_size=None, - max_packet_size=None, - timeout=None): + def open_channel( + self, + kind, + dest_addr=None, + src_addr=None, + window_size=None, + max_packet_size=None, + timeout=None, + ): """ Request a new channel to the server. `Channels <.Channel>` are socket-like objects used for the actual transfer of data across the @@ -853,7 +892,7 @@ class Transport(threading.Thread, ClosingContextManager): Added the ``window_size`` and ``max_packet_size`` arguments. """ if not self.active: - raise SSHException('SSH session not active') + raise SSHException("SSH session not active") timeout = 3600 if timeout is None else timeout self.lock.acquire() try: @@ -866,12 +905,12 @@ class Transport(threading.Thread, ClosingContextManager): m.add_int(chanid) m.add_int(window_size) m.add_int(max_packet_size) - if (kind == 'forwarded-tcpip') or (kind == 'direct-tcpip'): + if (kind == "forwarded-tcpip") or (kind == "direct-tcpip"): m.add_string(dest_addr[0]) m.add_int(dest_addr[1]) m.add_string(src_addr[0]) m.add_int(src_addr[1]) - elif kind == 'x11': + elif kind == "x11": m.add_string(src_addr[0]) m.add_int(src_addr[1]) chan = Channel(chanid) @@ -889,18 +928,18 @@ class Transport(threading.Thread, ClosingContextManager): if not self.active: e = self.get_exception() if e is None: - e = SSHException('Unable to open channel.') + e = SSHException("Unable to open channel.") raise e if event.is_set(): break elif start_ts + timeout < time.time(): - raise SSHException('Timeout opening channel.') + raise SSHException("Timeout opening channel.") chan = self._channels.get(chanid) if chan is not None: return chan e = self.get_exception() if e is None: - e = SSHException('Unable to open channel.') + e = SSHException("Unable to open channel.") raise e def request_port_forward(self, address, port, handler=None): @@ -937,20 +976,22 @@ class Transport(threading.Thread, ClosingContextManager): `.SSHException` -- if the server refused the TCP forward request """ if not self.active: - raise SSHException('SSH session not active') + raise SSHException("SSH session not active") port = int(port) response = self.global_request( - 'tcpip-forward', (address, port), wait=True + "tcpip-forward", (address, port), wait=True ) if response is None: - raise SSHException('TCP forwarding request denied') + raise SSHException("TCP forwarding request denied") if port == 0: port = response.get_int() if handler is None: + def default_handler(channel, src_addr, dest_addr_port): # src_addr, src_port = src_addr_port # dest_addr, dest_port = dest_addr_port self._queue_incoming_channel(channel) + handler = default_handler self._tcp_handler = handler return port @@ -967,7 +1008,7 @@ class Transport(threading.Thread, ClosingContextManager): if not self.active: return self._tcp_handler = None - self.global_request('cancel-tcpip-forward', (address, port), wait=True) + self.global_request("cancel-tcpip-forward", (address, port), wait=True) def open_sftp_client(self): """ @@ -1020,7 +1061,7 @@ class Transport(threading.Thread, ClosingContextManager): e = self.get_exception() if e is not None: raise e - raise SSHException('Negotiation failed.') + raise SSHException("Negotiation failed.") if self.completion_event.is_set(): break return @@ -1036,8 +1077,10 @@ class Transport(threading.Thread, ClosingContextManager): seconds to wait before sending a keepalive packet (or 0 to disable keepalives). """ + def _request(x=weakref.proxy(self)): - return x.global_request('keepalive@lag.net', wait=False) + return x.global_request("keepalive@lag.net", wait=False) + self.packetizer.set_keepalive(interval, _request) def global_request(self, kind, data=None, wait=True): @@ -1105,7 +1148,7 @@ class Transport(threading.Thread, ClosingContextManager): def connect( self, hostkey=None, - username='', + username="", password=None, pkey=None, gss_host=None, @@ -1179,34 +1222,43 @@ class Transport(threading.Thread, ClosingContextManager): if (hostkey is not None) and not gss_kex: key = self.get_remote_server_key() if ( - key.get_name() != hostkey.get_name() or - key.asbytes() != hostkey.asbytes() + key.get_name() != hostkey.get_name() + or key.asbytes() != hostkey.asbytes() ): - self._log(DEBUG, 'Bad host key from server') - self._log(DEBUG, 'Expected: {}: {}'.format( - hostkey.get_name(), repr(hostkey.asbytes()), - )) - self._log(DEBUG, 'Got : {}: {}'.format( - key.get_name(), repr(key.asbytes()), - )) - raise SSHException('Bad host key from server') - self._log(DEBUG, 'Host key verified ({})'.format( - hostkey.get_name())) + self._log(DEBUG, "Bad host key from server") + self._log( + DEBUG, + "Expected: {}: {}".format( + hostkey.get_name(), repr(hostkey.asbytes()) + ), + ) + self._log( + DEBUG, + "Got : {}: {}".format( + key.get_name(), repr(key.asbytes()) + ), + ) + raise SSHException("Bad host key from server") + self._log( + DEBUG, "Host key verified ({})".format(hostkey.get_name()) + ) if (pkey is not None) or (password is not None) or gss_auth or gss_kex: if gss_auth: - self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-with-mic)') # noqa + self._log( + DEBUG, "Attempting GSS-API auth... (gssapi-with-mic)" + ) # noqa self.auth_gssapi_with_mic( - username, self.gss_host, gss_deleg_creds, + username, self.gss_host, gss_deleg_creds ) elif gss_kex: - self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-keyex)') + self._log(DEBUG, "Attempting GSS-API auth... (gssapi-keyex)") self.auth_gssapi_keyex(username) elif pkey is not None: - self._log(DEBUG, 'Attempting public-key auth...') + self._log(DEBUG, "Attempting public-key auth...") self.auth_publickey(username, pkey) else: - self._log(DEBUG, 'Attempting password auth...') + self._log(DEBUG, "Attempting password auth...") self.auth_password(username, password) return @@ -1261,9 +1313,9 @@ class Transport(threading.Thread, ClosingContextManager): closed. """ return ( - self.active and - self.auth_handler is not None and - self.auth_handler.is_authenticated() + self.active + and self.auth_handler is not None + and self.auth_handler.is_authenticated() ) def get_username(self): @@ -1313,7 +1365,7 @@ class Transport(threading.Thread, ClosingContextManager): .. versionadded:: 1.5 """ if (not self.active) or (not self.initial_kex_done): - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_none(username, my_event) @@ -1369,7 +1421,7 @@ class Transport(threading.Thread, ClosingContextManager): if (not self.active) or (not self.initial_kex_done): # we should never try to send the password unless we're on a secure # link - raise SSHException('No existing session') + raise SSHException("No existing session") if event is None: my_event = threading.Event() else: @@ -1384,12 +1436,13 @@ class Transport(threading.Thread, ClosingContextManager): except BadAuthenticationType as e: # if password auth isn't allowed, but keyboard-interactive *is*, # try to fudge it - if not fallback or ('keyboard-interactive' not in e.allowed_types): + if not fallback or ("keyboard-interactive" not in e.allowed_types): raise try: + def handler(title, instructions, fields): if len(fields) > 1: - raise SSHException('Fallback authentication failed.') + raise SSHException("Fallback authentication failed.") if len(fields) == 0: # for some reason, at least on os x, a 2nd request will # be made with zero fields requested. maybe it's just @@ -1397,6 +1450,7 @@ class Transport(threading.Thread, ClosingContextManager): # type we're doing here. *shrug* :) return [] return [password] + return self.auth_interactive(username, handler) except SSHException: # attempt failed; just raise the original exception @@ -1439,7 +1493,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") if event is None: my_event = threading.Event() else: @@ -1451,7 +1505,7 @@ class Transport(threading.Thread, ClosingContextManager): return [] return self.auth_handler.wait_for_response(my_event) - def auth_interactive(self, username, handler, submethods=''): + def auth_interactive(self, username, handler, submethods=""): """ Authenticate to the server interactively. A handler is used to answer arbitrary questions from the server. On many servers, this is just a @@ -1496,7 +1550,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_interactive( @@ -1504,7 +1558,7 @@ class Transport(threading.Thread, ClosingContextManager): ) return self.auth_handler.wait_for_response(my_event) - def auth_interactive_dumb(self, username, handler=None, submethods=''): + def auth_interactive_dumb(self, username, handler=None, submethods=""): """ Autenticate to the server interactively but dumber. Just print the prompt and / or instructions to stdout and send back @@ -1513,6 +1567,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if not handler: + def handler(title, instructions, prompt_list): answers = [] if title: @@ -1520,9 +1575,10 @@ class Transport(threading.Thread, ClosingContextManager): if instructions: print(instructions.strip()) for prompt, show_input in prompt_list: - print(prompt.strip(), end=' ') + print(prompt.strip(), end=" ") answers.append(input()) return answers + return self.auth_interactive(username, handler, submethods) def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds): @@ -1543,7 +1599,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_gssapi_with_mic( @@ -1568,7 +1624,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_gssapi_keyex(username, my_event) @@ -1635,9 +1691,9 @@ class Transport(threading.Thread, ClosingContextManager): .. versionadded:: 1.5.2 """ if compress: - self._preferred_compression = ('zlib@openssh.com', 'zlib', 'none') + self._preferred_compression = ("zlib@openssh.com", "zlib", "none") else: - self._preferred_compression = ('none',) + self._preferred_compression = ("none",) def getpeername(self): """ @@ -1651,9 +1707,9 @@ class Transport(threading.Thread, ClosingContextManager): the address of the remote host, if known, as a ``(str, int)`` tuple. """ - gp = getattr(self.sock, 'getpeername', None) + gp = getattr(self.sock, "getpeername", None) if gp is None: - return 'unknown', 0 + return "unknown", 0 return gp() def stop_thread(self): @@ -1672,10 +1728,10 @@ class Transport(threading.Thread, ClosingContextManager): # our socket and packetizer are both closed (but where we'd # otherwise be sitting forever on that recv()). while ( - self.is_alive() and - self is not threading.current_thread() and - not self.sock._closed and - not self.packetizer.closed + self.is_alive() + and self is not threading.current_thread() + and not self.sock._closed + and not self.packetizer.closed ): self.join(0.1) @@ -1717,14 +1773,18 @@ class Transport(threading.Thread, ClosingContextManager): while True: self.clear_to_send.wait(0.1) if not self.active: - self._log(DEBUG, 'Dropping user packet because connection is dead.') # noqa + self._log( + DEBUG, "Dropping user packet because connection is dead." + ) # noqa return self.clear_to_send_lock.acquire() if self.clear_to_send.is_set(): break self.clear_to_send_lock.release() if time.time() > start + self.clear_to_send_timeout: - raise SSHException('Key-exchange timed out waiting for key negotiation') # noqa + raise SSHException( + "Key-exchange timed out waiting for key negotiation" + ) # noqa try: self._send_message(data) finally: @@ -1748,9 +1808,13 @@ class Transport(threading.Thread, ClosingContextManager): def _verify_key(self, host_key, sig): key = self._key_info[self.host_key_type](Message(host_key)) if key is None: - raise SSHException('Unknown host key type') + raise SSHException("Unknown host key type") if not key.verify_ssh_sig(self.H, Message(sig)): - raise SSHException('Signature verification ({}) failed.'.format(self.host_key_type)) # noqa + raise SSHException( + "Signature verification ({}) failed.".format( + self.host_key_type + ) + ) # noqa self.host_key = key def _compute_key(self, id, nbytes): @@ -1762,16 +1826,16 @@ class Transport(threading.Thread, ClosingContextManager): m.add_bytes(self.session_id) # Fallback to SHA1 for kex engines that fail to specify a hex # algorithm, or for e.g. transport tests that don't run kexinit. - hash_algo = getattr(self.kex_engine, 'hash_algo', None) + hash_algo = getattr(self.kex_engine, "hash_algo", None) hash_select_msg = "kex engine {} specified hash_algo {!r}".format( - self.kex_engine.__class__.__name__, hash_algo, + self.kex_engine.__class__.__name__, hash_algo ) if hash_algo is None: hash_algo = sha1 hash_select_msg += ", falling back to sha1" - if not hasattr(self, '_logged_hash_selection'): + if not hasattr(self, "_logged_hash_selection"): self._log(DEBUG, hash_select_msg) - setattr(self, '_logged_hash_selection', True) + setattr(self, "_logged_hash_selection", True) out = sofar = hash_algo(m.asbytes()).digest() while len(out) < nbytes: m = Message() @@ -1785,11 +1849,11 @@ class Transport(threading.Thread, ClosingContextManager): def _get_cipher(self, name, key, iv, operation): if name not in self._cipher_info: - raise SSHException('Unknown client cipher ' + name) + raise SSHException("Unknown client cipher " + name) else: cipher = Cipher( - self._cipher_info[name]['class'](key), - self._cipher_info[name]['mode'](iv), + self._cipher_info[name]["class"](key), + self._cipher_info[name]["mode"](iv), backend=default_backend(), ) if operation is self._ENCRYPT: @@ -1799,8 +1863,10 @@ class Transport(threading.Thread, ClosingContextManager): def _set_forward_agent_handler(self, handler): if handler is None: + def default_handler(channel): self._queue_incoming_channel(channel) + self._forward_agent_handler = default_handler else: self._forward_agent_handler = handler @@ -1811,6 +1877,7 @@ class Transport(threading.Thread, ClosingContextManager): # by default, use the same mechanism as accept() def default_handler(channel, src_addr_port): self._queue_incoming_channel(channel) + self._x11_handler = default_handler else: self._x11_handler = handler @@ -1844,9 +1911,9 @@ class Transport(threading.Thread, ClosingContextManager): Otherwise (client mode, authed, or pre-auth message) returns None. """ if ( - not self.server_mode or - ptype <= HIGHEST_USERAUTH_MESSAGE_ID or - self.is_authenticated() + not self.server_mode + or ptype <= HIGHEST_USERAUTH_MESSAGE_ID + or self.is_authenticated() ): return None # WELP. We must be dealing with someone trying to do non-auth things @@ -1857,13 +1924,13 @@ class Transport(threading.Thread, ClosingContextManager): reply.add_byte(cMSG_REQUEST_FAILURE) # Channel opens let us reject w/ a specific type + message. elif ptype == MSG_CHANNEL_OPEN: - kind = message.get_text() # noqa + kind = message.get_text() # noqa chanid = message.get_int() reply.add_byte(cMSG_CHANNEL_OPEN_FAILURE) reply.add_int(chanid) reply.add_int(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) - reply.add_string('') - reply.add_string('en') + reply.add_string("") + reply.add_string("en") # NOTE: Post-open channel messages do not need checking; the above will # reject attemps to open channels, meaning that even if a malicious # user tries to send a MSG_CHANNEL_REQUEST, it will simply fall under @@ -1885,13 +1952,16 @@ class Transport(threading.Thread, ClosingContextManager): _active_threads.append(self) tid = hex(long(id(self)) & xffffffff) if self.server_mode: - self._log(DEBUG, 'starting thread (server mode): {}'.format(tid)) + self._log(DEBUG, "starting thread (server mode): {}".format(tid)) else: - self._log(DEBUG, 'starting thread (client mode): {}'.format(tid)) + self._log(DEBUG, "starting thread (client mode): {}".format(tid)) try: try: - self.packetizer.write_all(b(self.local_version + '\r\n')) - self._log(DEBUG, 'Local version/idstring: {}'.format(self.local_version)) # noqa + self.packetizer.write_all(b(self.local_version + "\r\n")) + self._log( + DEBUG, + "Local version/idstring: {}".format(self.local_version), + ) # noqa self._check_banner() # The above is actually very much part of the handshake, but # sometimes the banner can be read but the machine is not @@ -1921,7 +1991,11 @@ class Transport(threading.Thread, ClosingContextManager): continue if len(self._expected_packet) > 0: if ptype not in self._expected_packet: - raise SSHException('Expecting packet from {!r}, got {:d}'.format(self._expected_packet, ptype)) # noqa + raise SSHException( + "Expecting packet from {!r}, got {:d}".format( + self._expected_packet, ptype + ) + ) # noqa self._expected_packet = tuple() if (ptype >= 30) and (ptype <= 41): self.kex_engine.parse_next(ptype, m) @@ -1939,20 +2013,30 @@ class Transport(threading.Thread, ClosingContextManager): if chan is not None: self._channel_handler_table[ptype](chan, m) elif chanid in self.channels_seen: - self._log(DEBUG, 'Ignoring message for dead channel {:d}'.format(chanid)) # noqa + self._log( + DEBUG, + "Ignoring message for dead channel {:d}".format( # noqa + chanid + ), + ) else: - self._log(ERROR, 'Channel request for unknown channel {:d}'.format(chanid)) # noqa + self._log( + ERROR, + "Channel request for unknown channel {:d}".format( # noqa + chanid + ), + ) break elif ( - self.auth_handler is not None and - ptype in self.auth_handler._handler_table + self.auth_handler is not None + and ptype in self.auth_handler._handler_table ): handler = self.auth_handler._handler_table[ptype] handler(self.auth_handler, m) if len(self._expected_packet) > 0: continue else: - err = 'Oops, unhandled type {:d}'.format(ptype) + err = "Oops, unhandled type {:d}".format(ptype) self._log(WARNING, err) msg = Message() msg.add_byte(cMSG_UNIMPLEMENTED) @@ -1960,24 +2044,24 @@ class Transport(threading.Thread, ClosingContextManager): self._send_message(msg) self.packetizer.complete_handshake() except SSHException as e: - self._log(ERROR, 'Exception: ' + str(e)) + self._log(ERROR, "Exception: " + str(e)) self._log(ERROR, util.tb_strings()) self.saved_exception = e except EOFError as e: - self._log(DEBUG, 'EOF in transport thread') + self._log(DEBUG, "EOF in transport thread") self.saved_exception = e except socket.error as e: if type(e.args) is tuple: if e.args: - emsg = '{} ({:d})'.format(e.args[1], e.args[0]) + emsg = "{} ({:d})".format(e.args[1], e.args[0]) else: # empty tuple, e.g. socket.timeout emsg = str(e) or repr(e) else: emsg = e.args - self._log(ERROR, 'Socket exception: ' + emsg) + self._log(ERROR, "Socket exception: " + emsg) self.saved_exception = e except Exception as e: - self._log(ERROR, 'Unknown exception: ' + str(e)) + self._log(ERROR, "Unknown exception: " + str(e)) self._log(ERROR, util.tb_strings()) self.saved_exception = e _active_threads.remove(self) @@ -2006,7 +2090,6 @@ class Transport(threading.Thread, ClosingContextManager): if self.sys.modules is not None: raise - def _log_agreement(self, which, local, remote): # Log useful, non-duplicative line re: an agreed-upon algorithm. # Old code implied algorithms could be asymmetrical (different for @@ -2048,32 +2131,32 @@ class Transport(threading.Thread, ClosingContextManager): raise except Exception as e: raise SSHException( - 'Error reading SSH protocol banner' + str(e) + "Error reading SSH protocol banner" + str(e) ) - if buf[:4] == 'SSH-': + if buf[:4] == "SSH-": break - self._log(DEBUG, 'Banner: ' + buf) - if buf[:4] != 'SSH-': + self._log(DEBUG, "Banner: " + buf) + if buf[:4] != "SSH-": raise SSHException('Indecipherable protocol version "' + buf + '"') # save this server version string for later self.remote_version = buf - self._log(DEBUG, 'Remote version/idstring: {}'.format(buf)) + self._log(DEBUG, "Remote version/idstring: {}".format(buf)) # pull off any attached comment # NOTE: comment used to be stored in a variable and then...never used. # since 2003. ca 877cd974b8182d26fa76d566072917ea67b64e67 - i = buf.find(' ') + i = buf.find(" ") if i >= 0: buf = buf[:i] # parse out version string and make sure it matches - segs = buf.split('-', 2) + segs = buf.split("-", 2) if len(segs) < 3: - raise SSHException('Invalid SSH banner') + raise SSHException("Invalid SSH banner") version = segs[1] client = segs[2] - if version != '1.99' and version != '2.0': - msg = 'Incompatible version ({} instead of 2.0)' + if version != "1.99" and version != "2.0": + msg = "Incompatible version ({} instead of 2.0)" raise SSHException(msg.format(version)) - msg = 'Connected (version {}, client {})'.format(version, client) + msg = "Connected (version {}, client {})".format(version, client) self._log(INFO, msg) def _send_kex_init(self): @@ -2089,25 +2172,27 @@ class Transport(threading.Thread, ClosingContextManager): self.gss_kex_used = False self.in_kex = True if self.server_mode: - mp_required_prefix = 'diffie-hellman-group-exchange-sha' + mp_required_prefix = "diffie-hellman-group-exchange-sha" kex_mp = [ - k for k - in self._preferred_kex + k + for k in self._preferred_kex if k.startswith(mp_required_prefix) ] if (self._modulus_pack is None) and (len(kex_mp) > 0): # can't do group-exchange if we don't have a pack of potential # primes pkex = [ - k for k - in self.get_security_options().kex + k + for k in self.get_security_options().kex if not k.startswith(mp_required_prefix) ] self.get_security_options().kex = pkex - available_server_keys = list(filter( - list(self.server_key_dict.keys()).__contains__, - self._preferred_keys - )) + available_server_keys = list( + filter( + list(self.server_key_dict.keys()).__contains__, + self._preferred_keys, + ) + ) else: available_server_keys = self._preferred_keys @@ -2131,7 +2216,7 @@ class Transport(threading.Thread, ClosingContextManager): self._send_message(m) def _parse_kex_init(self, m): - m.get_bytes(16) # cookie, discarded + m.get_bytes(16) # cookie, discarded kex_algo_list = m.get_list() server_key_algo_list = m.get_list() client_encrypt_algo_list = m.get_list() @@ -2143,20 +2228,32 @@ class Transport(threading.Thread, ClosingContextManager): client_lang_list = m.get_list() server_lang_list = m.get_list() kex_follows = m.get_boolean() - m.get_int() # unused - - self._log(DEBUG, - 'kex algos:' + str(kex_algo_list) + - ' server key:' + str(server_key_algo_list) + - ' client encrypt:' + str(client_encrypt_algo_list) + - ' server encrypt:' + str(server_encrypt_algo_list) + - ' client mac:' + str(client_mac_algo_list) + - ' server mac:' + str(server_mac_algo_list) + - ' client compress:' + str(client_compress_algo_list) + - ' server compress:' + str(server_compress_algo_list) + - ' client lang:' + str(client_lang_list) + - ' server lang:' + str(server_lang_list) + - ' kex follows?' + str(kex_follows) + m.get_int() # unused + + self._log( + DEBUG, + "kex algos:" + + str(kex_algo_list) + + " server key:" + + str(server_key_algo_list) + + " client encrypt:" + + str(client_encrypt_algo_list) + + " server encrypt:" + + str(server_encrypt_algo_list) + + " client mac:" + + str(client_mac_algo_list) + + " server mac:" + + str(server_mac_algo_list) + + " client compress:" + + str(client_compress_algo_list) + + " server compress:" + + str(server_compress_algo_list) + + " client lang:" + + str(client_lang_list) + + " server lang:" + + str(server_lang_list) + + " kex follows?" + + str(kex_follows), ) # as a server, we pick the first item in the client's list that we @@ -2164,122 +2261,150 @@ class Transport(threading.Thread, ClosingContextManager): # as a client, we pick the first item in our list that the server # supports. if self.server_mode: - agreed_kex = list(filter( - self._preferred_kex.__contains__, - kex_algo_list - )) + agreed_kex = list( + filter(self._preferred_kex.__contains__, kex_algo_list) + ) else: - agreed_kex = list(filter( - kex_algo_list.__contains__, - self._preferred_kex - )) + agreed_kex = list( + filter(kex_algo_list.__contains__, self._preferred_kex) + ) if len(agreed_kex) == 0: - raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') # noqa + raise SSHException( + "Incompatible ssh peer (no acceptable kex algorithm)" + ) # noqa self.kex_engine = self._kex_info[agreed_kex[0]](self) self._log(DEBUG, "Kex agreed: {}".format(agreed_kex[0])) if self.server_mode: - available_server_keys = list(filter( - list(self.server_key_dict.keys()).__contains__, - self._preferred_keys - )) - agreed_keys = list(filter( - available_server_keys.__contains__, server_key_algo_list - )) + available_server_keys = list( + filter( + list(self.server_key_dict.keys()).__contains__, + self._preferred_keys, + ) + ) + agreed_keys = list( + filter( + available_server_keys.__contains__, server_key_algo_list + ) + ) else: - agreed_keys = list(filter( - server_key_algo_list.__contains__, self._preferred_keys - )) + agreed_keys = list( + filter(server_key_algo_list.__contains__, self._preferred_keys) + ) if len(agreed_keys) == 0: - raise SSHException('Incompatible ssh peer (no acceptable host key)') # noqa + raise SSHException( + "Incompatible ssh peer (no acceptable host key)" + ) # noqa self.host_key_type = agreed_keys[0] if self.server_mode and (self.get_server_key() is None): - raise SSHException('Incompatible ssh peer (can\'t match requested host key type)') # noqa - self._log_agreement( - 'HostKey', agreed_keys[0], agreed_keys[0] - ) + raise SSHException( + "Incompatible ssh peer (can't match requested host key type)" + ) # noqa + self._log_agreement("HostKey", agreed_keys[0], agreed_keys[0]) if self.server_mode: - agreed_local_ciphers = list(filter( - self._preferred_ciphers.__contains__, - server_encrypt_algo_list - )) - agreed_remote_ciphers = list(filter( - self._preferred_ciphers.__contains__, - client_encrypt_algo_list - )) + agreed_local_ciphers = list( + filter( + self._preferred_ciphers.__contains__, + server_encrypt_algo_list, + ) + ) + agreed_remote_ciphers = list( + filter( + self._preferred_ciphers.__contains__, + client_encrypt_algo_list, + ) + ) else: - agreed_local_ciphers = list(filter( - client_encrypt_algo_list.__contains__, - self._preferred_ciphers - )) - agreed_remote_ciphers = list(filter( - server_encrypt_algo_list.__contains__, - self._preferred_ciphers - )) + agreed_local_ciphers = list( + filter( + client_encrypt_algo_list.__contains__, + self._preferred_ciphers, + ) + ) + agreed_remote_ciphers = list( + filter( + server_encrypt_algo_list.__contains__, + self._preferred_ciphers, + ) + ) if len(agreed_local_ciphers) == 0 or len(agreed_remote_ciphers) == 0: - raise SSHException('Incompatible ssh server (no acceptable ciphers)') # noqa + raise SSHException( + "Incompatible ssh server (no acceptable ciphers)" + ) # noqa self.local_cipher = agreed_local_ciphers[0] self.remote_cipher = agreed_remote_ciphers[0] self._log_agreement( - 'Cipher', local=self.local_cipher, remote=self.remote_cipher + "Cipher", local=self.local_cipher, remote=self.remote_cipher ) if self.server_mode: - agreed_remote_macs = list(filter( - self._preferred_macs.__contains__, client_mac_algo_list - )) - agreed_local_macs = list(filter( - self._preferred_macs.__contains__, server_mac_algo_list - )) + agreed_remote_macs = list( + filter(self._preferred_macs.__contains__, client_mac_algo_list) + ) + agreed_local_macs = list( + filter(self._preferred_macs.__contains__, server_mac_algo_list) + ) else: - agreed_local_macs = list(filter( - client_mac_algo_list.__contains__, self._preferred_macs - )) - agreed_remote_macs = list(filter( - server_mac_algo_list.__contains__, self._preferred_macs - )) + agreed_local_macs = list( + filter(client_mac_algo_list.__contains__, self._preferred_macs) + ) + agreed_remote_macs = list( + filter(server_mac_algo_list.__contains__, self._preferred_macs) + ) if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0): - raise SSHException('Incompatible ssh server (no acceptable macs)') + raise SSHException("Incompatible ssh server (no acceptable macs)") self.local_mac = agreed_local_macs[0] self.remote_mac = agreed_remote_macs[0] self._log_agreement( - 'MAC', local=self.local_mac, remote=self.remote_mac + "MAC", local=self.local_mac, remote=self.remote_mac ) if self.server_mode: - agreed_remote_compression = list(filter( - self._preferred_compression.__contains__, - client_compress_algo_list - )) - agreed_local_compression = list(filter( - self._preferred_compression.__contains__, - server_compress_algo_list - )) + agreed_remote_compression = list( + filter( + self._preferred_compression.__contains__, + client_compress_algo_list, + ) + ) + agreed_local_compression = list( + filter( + self._preferred_compression.__contains__, + server_compress_algo_list, + ) + ) else: - agreed_local_compression = list(filter( - client_compress_algo_list.__contains__, - self._preferred_compression - )) - agreed_remote_compression = list(filter( - server_compress_algo_list.__contains__, - self._preferred_compression - )) + agreed_local_compression = list( + filter( + client_compress_algo_list.__contains__, + self._preferred_compression, + ) + ) + agreed_remote_compression = list( + filter( + server_compress_algo_list.__contains__, + self._preferred_compression, + ) + ) if ( - len(agreed_local_compression) == 0 or - len(agreed_remote_compression) == 0 + len(agreed_local_compression) == 0 + or len(agreed_remote_compression) == 0 ): - msg = 'Incompatible ssh server (no acceptable compression) {!r} {!r} {!r}' # noqa - raise SSHException(msg.format( - agreed_local_compression, agreed_remote_compression, - self._preferred_compression, - )) + msg = "Incompatible ssh server (no acceptable compression)" + msg += " {!r} {!r} {!r}" + raise SSHException( + msg.format( + agreed_local_compression, + agreed_remote_compression, + self._preferred_compression, + ) + ) self.local_compression = agreed_local_compression[0] self.remote_compression = agreed_remote_compression[0] self._log_agreement( - 'Compression', + "Compression", local=self.local_compression, - remote=self.remote_compression + remote=self.remote_compression, ) # save for computing hash later... @@ -2292,40 +2417,36 @@ class Transport(threading.Thread, ClosingContextManager): def _activate_inbound(self): """switch on newly negotiated encryption parameters for inbound traffic""" - block_size = self._cipher_info[self.remote_cipher]['block-size'] + block_size = self._cipher_info[self.remote_cipher]["block-size"] if self.server_mode: - IV_in = self._compute_key('A', block_size) + IV_in = self._compute_key("A", block_size) key_in = self._compute_key( - 'C', self._cipher_info[self.remote_cipher]['key-size'] + "C", self._cipher_info[self.remote_cipher]["key-size"] ) else: - IV_in = self._compute_key('B', block_size) + IV_in = self._compute_key("B", block_size) key_in = self._compute_key( - 'D', self._cipher_info[self.remote_cipher]['key-size'] + "D", self._cipher_info[self.remote_cipher]["key-size"] ) engine = self._get_cipher( self.remote_cipher, key_in, IV_in, self._DECRYPT ) - mac_size = self._mac_info[self.remote_mac]['size'] - mac_engine = self._mac_info[self.remote_mac]['class'] + mac_size = self._mac_info[self.remote_mac]["size"] + mac_engine = self._mac_info[self.remote_mac]["class"] # initial mac keys are done in the hash's natural size (not the # potentially truncated transmission size) if self.server_mode: - mac_key = self._compute_key('E', mac_engine().digest_size) + mac_key = self._compute_key("E", mac_engine().digest_size) else: - mac_key = self._compute_key('F', mac_engine().digest_size) + mac_key = self._compute_key("F", mac_engine().digest_size) self.packetizer.set_inbound_cipher( engine, block_size, mac_engine, mac_size, mac_key ) compress_in = self._compression_info[self.remote_compression][1] - if ( - compress_in is not None and - ( - self.remote_compression != 'zlib@openssh.com' or - self.authenticated - ) + if compress_in is not None and ( + self.remote_compression != "zlib@openssh.com" or self.authenticated ): - self._log(DEBUG, 'Switching on inbound compression ...') + self._log(DEBUG, "Switching on inbound compression ...") self.packetizer.set_inbound_compressor(compress_in()) def _activate_outbound(self): @@ -2334,37 +2455,37 @@ class Transport(threading.Thread, ClosingContextManager): m = Message() m.add_byte(cMSG_NEWKEYS) self._send_message(m) - block_size = self._cipher_info[self.local_cipher]['block-size'] + block_size = self._cipher_info[self.local_cipher]["block-size"] if self.server_mode: - IV_out = self._compute_key('B', block_size) + IV_out = self._compute_key("B", block_size) key_out = self._compute_key( - 'D', self._cipher_info[self.local_cipher]['key-size']) + "D", self._cipher_info[self.local_cipher]["key-size"] + ) else: - IV_out = self._compute_key('A', block_size) + IV_out = self._compute_key("A", block_size) key_out = self._compute_key( - 'C', self._cipher_info[self.local_cipher]['key-size']) + "C", self._cipher_info[self.local_cipher]["key-size"] + ) engine = self._get_cipher( - self.local_cipher, key_out, IV_out, self._ENCRYPT) - mac_size = self._mac_info[self.local_mac]['size'] - mac_engine = self._mac_info[self.local_mac]['class'] + self.local_cipher, key_out, IV_out, self._ENCRYPT + ) + mac_size = self._mac_info[self.local_mac]["size"] + mac_engine = self._mac_info[self.local_mac]["class"] # initial mac keys are done in the hash's natural size (not the # potentially truncated transmission size) if self.server_mode: - mac_key = self._compute_key('F', mac_engine().digest_size) + mac_key = self._compute_key("F", mac_engine().digest_size) else: - mac_key = self._compute_key('E', mac_engine().digest_size) - sdctr = self.local_cipher.endswith('-ctr') + mac_key = self._compute_key("E", mac_engine().digest_size) + sdctr = self.local_cipher.endswith("-ctr") self.packetizer.set_outbound_cipher( - engine, block_size, mac_engine, mac_size, mac_key, sdctr) + engine, block_size, mac_engine, mac_size, mac_key, sdctr + ) compress_out = self._compression_info[self.local_compression][0] - if ( - compress_out is not None and - ( - self.local_compression != 'zlib@openssh.com' or - self.authenticated - ) + if compress_out is not None and ( + self.local_compression != "zlib@openssh.com" or self.authenticated ): - self._log(DEBUG, 'Switching on outbound compression ...') + self._log(DEBUG, "Switching on outbound compression ...") self.packetizer.set_outbound_compressor(compress_out()) if not self.packetizer.need_rekey(): self.in_kex = False @@ -2374,17 +2495,17 @@ class Transport(threading.Thread, ClosingContextManager): def _auth_trigger(self): self.authenticated = True # delayed initiation of compression - if self.local_compression == 'zlib@openssh.com': + if self.local_compression == "zlib@openssh.com": compress_out = self._compression_info[self.local_compression][0] - self._log(DEBUG, 'Switching on outbound compression ...') + self._log(DEBUG, "Switching on outbound compression ...") self.packetizer.set_outbound_compressor(compress_out()) - if self.remote_compression == 'zlib@openssh.com': + if self.remote_compression == "zlib@openssh.com": compress_in = self._compression_info[self.remote_compression][1] - self._log(DEBUG, 'Switching on inbound compression ...') + self._log(DEBUG, "Switching on inbound compression ...") self.packetizer.set_inbound_compressor(compress_in()) def _parse_newkeys(self, m): - self._log(DEBUG, 'Switch to new keys ...') + self._log(DEBUG, "Switch to new keys ...") self._activate_inbound() # can also free a bunch of stuff here self.local_kex_init = self.remote_kex_init = None @@ -2412,7 +2533,7 @@ class Transport(threading.Thread, ClosingContextManager): def _parse_disconnect(self, m): code = m.get_int() desc = m.get_text() - self._log(INFO, 'Disconnect (code {:d}): {}'.format(code, desc)) + self._log(INFO, "Disconnect (code {:d}): {}".format(code, desc)) def _parse_global_request(self, m): kind = m.get_text() @@ -2421,16 +2542,16 @@ class Transport(threading.Thread, ClosingContextManager): if not self.server_mode: self._log( DEBUG, - 'Rejecting "{}" global request from server.'.format(kind) + 'Rejecting "{}" global request from server.'.format(kind), ) ok = False - elif kind == 'tcpip-forward': + elif kind == "tcpip-forward": address = m.get_text() port = m.get_int() ok = self.server_object.check_port_forward_request(address, port) if ok: ok = (ok,) - elif kind == 'cancel-tcpip-forward': + elif kind == "cancel-tcpip-forward": address = m.get_text() port = m.get_int() self.server_object.cancel_port_forward_request(address, port) @@ -2451,13 +2572,13 @@ class Transport(threading.Thread, ClosingContextManager): self._send_message(msg) def _parse_request_success(self, m): - self._log(DEBUG, 'Global request successful.') + self._log(DEBUG, "Global request successful.") self.global_response = m if self.completion_event is not None: self.completion_event.set() def _parse_request_failure(self, m): - self._log(DEBUG, 'Global request denied.') + self._log(DEBUG, "Global request denied.") self.global_response = None if self.completion_event is not None: self.completion_event.set() @@ -2469,13 +2590,14 @@ class Transport(threading.Thread, ClosingContextManager): server_max_packet_size = m.get_int() chan = self._channels.get(chanid) if chan is None: - self._log(WARNING, 'Success for unrequested channel! [??]') + self._log(WARNING, "Success for unrequested channel! [??]") return self.lock.acquire() try: chan._set_remote_channel( - server_chanid, server_window_size, server_max_packet_size) - self._log(DEBUG, 'Secsh channel {:d} opened.'.format(chanid)) + server_chanid, server_window_size, server_max_packet_size + ) + self._log(DEBUG, "Secsh channel {:d} opened.".format(chanid)) if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] @@ -2488,12 +2610,12 @@ class Transport(threading.Thread, ClosingContextManager): reason = m.get_int() reason_str = m.get_text() m.get_text() # ignored language - reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') + reason_text = CONNECTION_FAILED_CODE.get(reason, "(unknown code)") self._log( ERROR, - 'Secsh channel {:d} open FAILED: {}: {}'.format( - chanid, reason_str, reason_text, - ) + "Secsh channel {:d} open FAILED: {}: {}".format( + chanid, reason_str, reason_text + ), ) self.lock.acquire() try: @@ -2514,39 +2636,39 @@ class Transport(threading.Thread, ClosingContextManager): max_packet_size = m.get_int() reject = False if ( - kind == 'auth-agent@openssh.com' and - self._forward_agent_handler is not None + kind == "auth-agent@openssh.com" + and self._forward_agent_handler is not None ): - self._log(DEBUG, 'Incoming forward agent connection') + self._log(DEBUG, "Incoming forward agent connection") self.lock.acquire() try: my_chanid = self._next_channel() finally: self.lock.release() - elif (kind == 'x11') and (self._x11_handler is not None): + elif (kind == "x11") and (self._x11_handler is not None): origin_addr = m.get_text() origin_port = m.get_int() self._log( DEBUG, - 'Incoming x11 connection from {}:{:d}'.format( - origin_addr, origin_port, - ) + "Incoming x11 connection from {}:{:d}".format( + origin_addr, origin_port + ), ) self.lock.acquire() try: my_chanid = self._next_channel() finally: self.lock.release() - elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): + elif (kind == "forwarded-tcpip") and (self._tcp_handler is not None): server_addr = m.get_text() server_port = m.get_int() origin_addr = m.get_text() origin_port = m.get_int() self._log( DEBUG, - 'Incoming tcp forwarded connection from {}:{:d}'.format( - origin_addr, origin_port, - ) + "Incoming tcp forwarded connection from {}:{:d}".format( + origin_addr, origin_port + ), ) self.lock.acquire() try: @@ -2556,7 +2678,8 @@ class Transport(threading.Thread, ClosingContextManager): elif not self.server_mode: self._log( DEBUG, - 'Rejecting "{}" channel request from server.'.format(kind)) + 'Rejecting "{}" channel request from server.'.format(kind), + ) reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: @@ -2565,7 +2688,7 @@ class Transport(threading.Thread, ClosingContextManager): my_chanid = self._next_channel() finally: self.lock.release() - if kind == 'direct-tcpip': + if kind == "direct-tcpip": # handle direct-tcpip requests coming from the client dest_addr = m.get_text() dest_port = m.get_int() @@ -2574,23 +2697,25 @@ class Transport(threading.Thread, ClosingContextManager): reason = self.server_object.check_channel_direct_tcpip_request( my_chanid, (origin_addr, origin_port), - (dest_addr, dest_port) + (dest_addr, dest_port), ) else: reason = self.server_object.check_channel_request( - kind, my_chanid) + kind, my_chanid + ) if reason != OPEN_SUCCEEDED: self._log( DEBUG, - 'Rejecting "{}" channel request from client.'.format(kind)) + 'Rejecting "{}" channel request from client.'.format(kind), + ) reject = True if reject: msg = Message() msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE) msg.add_int(chanid) msg.add_int(reason) - msg.add_string('') - msg.add_string('en') + msg.add_string("") + msg.add_string("en") self._send_message(msg) return @@ -2601,9 +2726,11 @@ class Transport(threading.Thread, ClosingContextManager): self.channels_seen[my_chanid] = True chan._set_transport(self) chan._set_window( - self.default_window_size, self.default_max_packet_size) + self.default_window_size, self.default_max_packet_size + ) chan._set_remote_channel( - chanid, initial_window_size, max_packet_size) + chanid, initial_window_size, max_packet_size + ) finally: self.lock.release() m = Message() @@ -2613,19 +2740,17 @@ class Transport(threading.Thread, ClosingContextManager): m.add_int(self.default_window_size) m.add_int(self.default_max_packet_size) self._send_message(m) - self._log(DEBUG, - 'Secsh channel {:d} ({}) opened.'.format(my_chanid, kind) + self._log( + DEBUG, "Secsh channel {:d} ({}) opened.".format(my_chanid, kind) ) - if kind == 'auth-agent@openssh.com': + if kind == "auth-agent@openssh.com": self._forward_agent_handler(chan) - elif kind == 'x11': + elif kind == "x11": self._x11_handler(chan, (origin_addr, origin_port)) - elif kind == 'forwarded-tcpip': + elif kind == "forwarded-tcpip": chan.origin_addr = (origin_addr, origin_port) self._tcp_handler( - chan, - (origin_addr, origin_port), - (server_addr, server_port) + chan, (origin_addr, origin_port), (server_addr, server_port) ) else: self._queue_incoming_channel(chan) @@ -2634,7 +2759,7 @@ class Transport(threading.Thread, ClosingContextManager): m.get_boolean() # always_display msg = m.get_string() m.get_string() # language - self._log(DEBUG, 'Debug msg: {}'.format(util.safe_string(msg))) + self._log(DEBUG, "Debug msg: {}".format(util.safe_string(msg))) def _get_subsystem_handler(self, name): try: @@ -2668,7 +2793,7 @@ class Transport(threading.Thread, ClosingContextManager): } -class SecurityOptions (object): +class SecurityOptions(object): """ Simple object containing the security preferences of an ssh transport. These are tuples of acceptable ciphers, digests, key types, and key @@ -2680,7 +2805,7 @@ class SecurityOptions (object): ``ValueError`` will be raised. If you try to assign something besides a tuple to one of the fields, ``TypeError`` will be raised. """ - __slots__ = '_transport' + __slots__ = "_transport" def __init__(self, transport): self._transport = transport @@ -2689,17 +2814,17 @@ class SecurityOptions (object): """ Returns a string representation of this object, for debugging. """ - return '<paramiko.SecurityOptions for {!r}>'.format(self._transport) + return "<paramiko.SecurityOptions for {!r}>".format(self._transport) def _set(self, name, orig, x): if type(x) is list: x = tuple(x) if type(x) is not tuple: - raise TypeError('expected tuple or list') + raise TypeError("expected tuple or list") possible = list(getattr(self._transport, orig).keys()) forbidden = [n for n in x if n not in possible] if len(forbidden) > 0: - raise ValueError('unknown cipher') + raise ValueError("unknown cipher") setattr(self._transport, name, x) @property @@ -2709,7 +2834,7 @@ class SecurityOptions (object): @ciphers.setter def ciphers(self, x): - self._set('_preferred_ciphers', '_cipher_info', x) + self._set("_preferred_ciphers", "_cipher_info", x) @property def digests(self): @@ -2718,7 +2843,7 @@ class SecurityOptions (object): @digests.setter def digests(self, x): - self._set('_preferred_macs', '_mac_info', x) + self._set("_preferred_macs", "_mac_info", x) @property def key_types(self): @@ -2727,8 +2852,7 @@ class SecurityOptions (object): @key_types.setter def key_types(self, x): - self._set('_preferred_keys', '_key_info', x) - + self._set("_preferred_keys", "_key_info", x) @property def kex(self): @@ -2737,7 +2861,7 @@ class SecurityOptions (object): @kex.setter def kex(self, x): - self._set('_preferred_kex', '_kex_info', x) + self._set("_preferred_kex", "_kex_info", x) @property def compression(self): @@ -2746,10 +2870,11 @@ class SecurityOptions (object): @compression.setter def compression(self, x): - self._set('_preferred_compression', '_compression_info', x) + self._set("_preferred_compression", "_compression_info", x) + +class ChannelMap(object): -class ChannelMap (object): def __init__(self): # (id -> Channel) self._map = weakref.WeakValueDictionary() diff --git a/paramiko/util.py b/paramiko/util.py index 2854ef98..399141ad 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -49,9 +49,9 @@ def inflate_long(s, always_positive=False): # noinspection PyAugmentAssignment s = filler * (4 - len(s) % 4) + s for i in range(0, len(s), 4): - out = (out << 32) + struct.unpack('>I', s[i:i + 4])[0] + out = (out << 32) + struct.unpack(">I", s[i : i + 4])[0] if negative: - out -= (long(1) << (8 * len(s))) + out -= long(1) << (8 * len(s)) return out @@ -66,7 +66,7 @@ def deflate_long(n, add_sign_padding=True): s = bytes() n = long(n) while (n != 0) and (n != -1): - s = struct.pack('>I', n & xffffffff) + s + s = struct.pack(">I", n & xffffffff) + s n >>= 32 # strip off leading zeros, FFs for i in enumerate(s): @@ -81,7 +81,7 @@ def deflate_long(n, add_sign_padding=True): s = zero_byte else: s = max_byte - s = s[i[0]:] + s = s[i[0] :] if add_sign_padding: if (n == 0) and (byte_ord(s[0]) >= 0x80): s = zero_byte + s @@ -90,11 +90,11 @@ def deflate_long(n, add_sign_padding=True): return s -def format_binary(data, prefix=''): +def format_binary(data, prefix=""): x = 0 out = [] while len(data) > x + 16: - out.append(format_binary_line(data[x:x + 16])) + out.append(format_binary_line(data[x : x + 16])) x += 16 if x < len(data): out.append(format_binary_line(data[x:])) @@ -102,22 +102,21 @@ def format_binary(data, prefix=''): def format_binary_line(data): - left = ' '.join(['{:02X}'.format(byte_ord(c)) for c in data]) - right = ''.join([ - '.{:c}..'.format(byte_ord(c))[(byte_ord(c) + 63) // 95] - for c in data - ]) - return '{:50s} {}'.format(left, right) + left = " ".join(["{:02X}".format(byte_ord(c)) for c in data]) + right = "".join( + [".{:c}..".format(byte_ord(c))[(byte_ord(c) + 63) // 95] for c in data] + ) + return "{:50s} {}".format(left, right) def safe_string(s): - out = b'' + out = b"" for c in s: i = byte_ord(c) if 32 <= i <= 127: out += byte_chr(i) else: - out += b('%{:02X}'.format(i)) + out += b("%{:02X}".format(i)) return out @@ -137,7 +136,7 @@ def bit_length(n): def tb_strings(): - return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') + return "".join(traceback.format_exception(*sys.exc_info())).split("\n") def generate_key_bytes(hash_alg, salt, key, nbytes): @@ -188,6 +187,7 @@ def load_host_keys(filename): nested dict of `.PKey` objects, indexed by hostname and then keytype """ from paramiko.hostkeys import HostKeys + return HostKeys(filename) @@ -249,15 +249,17 @@ def log_to_file(filename, level=DEBUG): if len(l.handlers) > 0: return l.setLevel(level) - f = open(filename, 'a') + f = open(filename, "a") lh = logging.StreamHandler(f) - frm = '%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s' # noqa - lh.setFormatter(logging.Formatter(frm, '%Y%m%d-%H:%M:%S')) + frm = "%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d" + frm += " %(name)s: %(message)s" + lh.setFormatter(logging.Formatter(frm, "%Y%m%d-%H:%M:%S")) l.addHandler(lh) # make only one filter object, so it doesn't get applied more than once -class PFilter (object): +class PFilter(object): + def filter(self, record): record._threadid = get_thread_id() return True @@ -293,6 +295,7 @@ def constant_time_bytes_eq(a, b): class ClosingContextManager(object): + def __enter__(self): return self diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py index 661ba575..2bba789d 100644 --- a/paramiko/win_pageant.py +++ b/paramiko/win_pageant.py @@ -44,7 +44,7 @@ win32con_WM_COPYDATA = 74 def _get_pageant_window_object(): - return ctypes.windll.user32.FindWindowA(b'Pageant', b'Pageant') + return ctypes.windll.user32.FindWindowA(b"Pageant", b"Pageant") def can_talk_to_agent(): @@ -57,7 +57,7 @@ def can_talk_to_agent(): return bool(_get_pageant_window_object()) -if platform.architecture()[0] == '64bit': +if platform.architecture()[0] == "64bit": ULONG_PTR = ctypes.c_uint64 else: ULONG_PTR = ctypes.c_uint32 @@ -69,9 +69,9 @@ class COPYDATASTRUCT(ctypes.Structure): http://msdn.microsoft.com/en-us/library/windows/desktop/ms649010%28v=vs.85%29.aspx """ _fields_ = [ - ('num_data', ULONG_PTR), - ('data_size', ctypes.wintypes.DWORD), - ('data_loc', ctypes.c_void_p), + ("num_data", ULONG_PTR), + ("data_size", ctypes.wintypes.DWORD), + ("data_loc", ctypes.c_void_p), ] @@ -86,27 +86,29 @@ def _query_pageant(msg): return None # create a name for the mmap - map_name = 'PageantRequest%08x' % thread.get_ident() + map_name = "PageantRequest%08x" % thread.get_ident() - pymap = _winapi.MemoryMap(map_name, _AGENT_MAX_MSGLEN, - _winapi.get_security_attributes_for_user(), - ) + pymap = _winapi.MemoryMap( + map_name, _AGENT_MAX_MSGLEN, _winapi.get_security_attributes_for_user() + ) with pymap: pymap.write(msg) # Create an array buffer containing the mapped filename char_buffer = array.array("b", b(map_name) + zero_byte) # noqa char_buffer_address, char_buffer_size = char_buffer.buffer_info() # Create a string to use for the SendMessage function call - cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size, - char_buffer_address) + cds = COPYDATASTRUCT( + _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address + ) - response = ctypes.windll.user32.SendMessageA(hwnd, - win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds)) + response = ctypes.windll.user32.SendMessageA( + hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds) + ) if response > 0: pymap.seek(0) datalen = pymap.read(4) - retlen = struct.unpack('>I', datalen)[0] + retlen = struct.unpack(">I", datalen)[0] return datalen + pymap.read(retlen) return None @@ -127,10 +129,10 @@ class PageantConnection(object): def recv(self, n): if self._response is None: - return '' + return "" ret = self._response[:n] self._response = self._response[n:] - if self._response == '': + if self._response == "": self._response = None return ret @@ -9,7 +9,9 @@ omit = paramiko/_winapi.py [flake8] exclude = sites,.git,build,dist,demos,tests -ignore = E124,E125,E128,E261,E301,E302,E303,E402,E721 +# NOTE: W503, E203 are concessions to black 18.0b5 and could be reinstated +# later if fixed on that end. +ignore = E124,E125,E128,E261,E301,E302,E303,E402,E721,W503,E203 max-line-length = 79 [tool:pytest] @@ -19,12 +19,12 @@ import sys from setuptools import setup -if sys.platform == 'darwin': +if sys.platform == "darwin": import setup_helper setup_helper.install_custom_make_tarball() -longdesc = ''' +longdesc = """ This is a library for making SSH2 connections (client or server). Emphasis is on using SSH2 as an alternative to SSL for making secure connections between python scripts. All major ciphers and hash methods @@ -35,14 +35,14 @@ Required packages: To install the development version, ``pip install -e git+https://github.com/paramiko/paramiko/#egg=paramiko``. -''' +""" # Version info -- read without importing _locals = {} -with open('paramiko/_version.py') as fp: +with open("paramiko/_version.py") as fp: exec(fp.read(), None, _locals) -version = _locals['__version__'] +version = _locals["__version__"] setup( name="paramiko", @@ -52,28 +52,24 @@ setup( author="Jeff Forcier", author_email="jeff@bitprophet.org", url="https://github.com/paramiko/paramiko/", - packages=['paramiko'], - license='LGPL', - platforms='Posix; MacOS X; Windows', + packages=["paramiko"], + license="LGPL", + platforms="Posix; MacOS X; Windows", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: ' - 'GNU Library or Lesser General Public License (LGPL)', - 'Operating System :: OS Independent', - 'Topic :: Internet', - 'Topic :: Security :: Cryptography', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - ], - install_requires=[ - 'bcrypt>=3.1.3', - 'cryptography>=1.5', - 'pynacl>=1.0.1', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: " + "GNU Library or Lesser General Public License (LGPL)", + "Operating System :: OS Independent", + "Topic :: Internet", + "Topic :: Security :: Cryptography", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", ], + install_requires=["bcrypt>=3.1.3", "cryptography>=1.5", "pynacl>=1.0.1"], ) diff --git a/setup_helper.py b/setup_helper.py index c359a16c..d0a8700e 100644 --- a/setup_helper.py +++ b/setup_helper.py @@ -40,6 +40,7 @@ try: except ImportError: getgrnam = None + def _get_gid(name): """Returns a gid, given a group name.""" if getgrnam is None or name is None: @@ -52,6 +53,7 @@ def _get_gid(name): return result[2] return None + def _get_uid(name): """Returns an uid, given a user name.""" if getpwnam is None or name is None: @@ -64,8 +66,16 @@ def _get_uid(name): return result[2] return None -def make_tarball(base_name, base_dir, compress='gzip', verbose=0, dry_run=0, - owner=None, group=None): + +def make_tarball( + base_name, + base_dir, + compress="gzip", + verbose=0, + dry_run=0, + owner=None, + group=None, +): """Create a tar file from all the files under 'base_dir'. This file may be compressed. @@ -87,28 +97,26 @@ def make_tarball(base_name, base_dir, compress='gzip', verbose=0, dry_run=0, # "create a tree of hardlinks" step! (Would also be nice to # detect GNU tar to use its 'z' option and save a step.) - compress_ext = { - 'gzip': ".gz", - 'bzip2': '.bz2', - 'compress': ".Z", - } + compress_ext = {"gzip": ".gz", "bzip2": ".bz2", "compress": ".Z"} # flags for compression program, each element of list will be an argument - tarfile_compress_flag = {'gzip': 'gz', 'bzip2': 'bz2'} - compress_flags = {'compress': ["-f"]} + tarfile_compress_flag = {"gzip": "gz", "bzip2": "bz2"} + compress_flags = {"compress": ["-f"]} if compress is not None and compress not in compress_ext.keys(): - raise ValueError("bad value for 'compress': must be None, 'gzip'," - "'bzip2' or 'compress'") + raise ValueError( + "bad value for 'compress': must be None, 'gzip'," + "'bzip2' or 'compress'" + ) archive_name = base_name + ".tar" if compress and compress in tarfile_compress_flag: archive_name += compress_ext[compress] - mode = 'w:' + tarfile_compress_flag.get(compress, '') + mode = "w:" + tarfile_compress_flag.get(compress, "") mkpath(os.path.dirname(archive_name), dry_run=dry_run) - log.info('Creating tar file %s with mode %s' % (archive_name, mode)) + log.info("Creating tar file %s with mode %s" % (archive_name, mode)) uid = _get_uid(owner) gid = _get_gid(group) @@ -136,18 +144,20 @@ def make_tarball(base_name, base_dir, compress='gzip', verbose=0, dry_run=0, tar.close() if compress and compress not in tarfile_compress_flag: - spawn([compress] + compress_flags[compress] + [archive_name], - dry_run=dry_run) + spawn( + [compress] + compress_flags[compress] + [archive_name], + dry_run=dry_run, + ) return archive_name + compress_ext[compress] else: return archive_name _custom_formats = { - 'gztar': (make_tarball, [('compress', 'gzip')], "gzip'ed tar-file"), - 'bztar': (make_tarball, [('compress', 'bzip2')], "bzip2'ed tar-file"), - 'ztar': (make_tarball, [('compress', 'compress')], "compressed tar file"), - 'tar': (make_tarball, [('compress', None)], "uncompressed tar file"), + "gztar": (make_tarball, [("compress", "gzip")], "gzip'ed tar-file"), + "bztar": (make_tarball, [("compress", "bzip2")], "bzip2'ed tar-file"), + "ztar": (make_tarball, [("compress", "compress")], "compressed tar file"), + "tar": (make_tarball, [("compress", None)], "uncompressed tar file"), } # Hack in and insert ourselves into the distutils code base diff --git a/sites/docs/conf.py b/sites/docs/conf.py index 5674fed1..eb895804 100644 --- a/sites/docs/conf.py +++ b/sites/docs/conf.py @@ -1,16 +1,17 @@ # Obtain shared config values import os, sys -sys.path.append(os.path.abspath('..')) -sys.path.append(os.path.abspath('../..')) + +sys.path.append(os.path.abspath("..")) +sys.path.append(os.path.abspath("../..")) from shared_conf import * # Enable autodoc, intersphinx -extensions.extend(['sphinx.ext.autodoc']) +extensions.extend(["sphinx.ext.autodoc"]) # Autodoc settings -autodoc_default_flags = ['members', 'special-members'] +autodoc_default_flags = ["members", "special-members"] # Sister-site links to WWW -html_theme_options['extra_nav_links'] = { - "Main website": 'http://www.paramiko.org', +html_theme_options["extra_nav_links"] = { + "Main website": "http://www.paramiko.org" } diff --git a/sites/shared_conf.py b/sites/shared_conf.py index cf0d77ff..f4806cf1 100644 --- a/sites/shared_conf.py +++ b/sites/shared_conf.py @@ -5,36 +5,29 @@ import alabaster # Alabaster theme + mini-extension html_theme_path = [alabaster.get_path()] -extensions = ['alabaster', 'sphinx.ext.intersphinx'] +extensions = ["alabaster", "sphinx.ext.intersphinx"] # Paths relative to invoking conf.py - not this shared file -html_theme = 'alabaster' +html_theme = "alabaster" html_theme_options = { - 'description': "A Python implementation of SSHv2.", - 'github_user': 'paramiko', - 'github_repo': 'paramiko', - 'analytics_id': 'UA-18486793-2', - 'travis_button': True, + "description": "A Python implementation of SSHv2.", + "github_user": "paramiko", + "github_repo": "paramiko", + "analytics_id": "UA-18486793-2", + "travis_button": True, } html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'searchbox.html', - 'donate.html', - ] + "**": ["about.html", "navigation.html", "searchbox.html", "donate.html"] } # Everything intersphinx's to Python -intersphinx_mapping = { - 'python': ('https://docs.python.org/2.7/', None), -} +intersphinx_mapping = {"python": ("https://docs.python.org/2.7/", None)} # Regular settings -project = 'Paramiko' +project = "Paramiko" year = datetime.now().year -copyright = '{} Jeff Forcier'.format(year) -master_doc = 'index' -templates_path = ['_templates'] -exclude_trees = ['_build'] -source_suffix = '.rst' -default_role = 'obj' +copyright = "{} Jeff Forcier".format(year) +master_doc = "index" +templates_path = ["_templates"] +exclude_trees = ["_build"] +source_suffix = ".rst" +default_role = "obj" diff --git a/sites/www/conf.py b/sites/www/conf.py index c7ba0a86..00944871 100644 --- a/sites/www/conf.py +++ b/sites/www/conf.py @@ -3,22 +3,22 @@ import sys import os from os.path import abspath, join, dirname -sys.path.append(abspath(join(dirname(__file__), '..'))) +sys.path.append(abspath(join(dirname(__file__), ".."))) from shared_conf import * # Releases changelog extension -extensions.append('releases') +extensions.append("releases") releases_release_uri = "https://github.com/paramiko/paramiko/tree/%s" releases_issue_uri = "https://github.com/paramiko/paramiko/issues/%s" # Default is 'local' building, but reference the public docs site when building # under RTD. -target = join(dirname(__file__), '..', 'docs', '_build') -if os.environ.get('READTHEDOCS') == 'True': - target = 'http://docs.paramiko.org/en/latest/' -intersphinx_mapping['docs'] = (target, None) +target = join(dirname(__file__), "..", "docs", "_build") +if os.environ.get("READTHEDOCS") == "True": + target = "http://docs.paramiko.org/en/latest/" +intersphinx_mapping["docs"] = (target, None) # Sister-site links to API docs -html_theme_options['extra_nav_links'] = { - "API Docs": 'http://docs.paramiko.org', +html_theme_options["extra_nav_links"] = { + "API Docs": "http://docs.paramiko.org" } @@ -3,6 +3,8 @@ from os.path import join from shutil import rmtree, copytree from invoke import Collection, task +from invocations import travis +from invocations.checks import blacken from invocations.docs import docs, www, sites from invocations.packaging.release import ns as release_coll, publish from invocations.testing import count_errors @@ -106,7 +108,7 @@ def release(ctx, sdist=True, wheel=True, sign=True, dry_run=False, index=None): copytree("sites/docs/_build", target) # Publish publish( - ctx, sdist=sdist, wheel=wheel, sign=sign, dry_run=dry_run, index=index, + ctx, sdist=sdist, wheel=wheel, sign=sign, dry_run=dry_run, index=index ) # Remind print( @@ -121,7 +123,16 @@ def release(ctx, sdist=True, wheel=True, sign=True, dry_run=False, index=None): release_coll.tasks["publish"] = release ns = Collection( - test, coverage, guard, release_coll, docs, www, sites, count_errors + test, + coverage, + guard, + release_coll, + docs, + www, + sites, + count_errors, + travis, + blacken, ) ns.configure( { diff --git a/tests/conftest.py b/tests/conftest.py index d1967a73..2b509c5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ from .util import _support # presenting it on error/failure. (But also allow turning it off when doing # very pinpoint debugging - e.g. using breakpoints, so you don't want output # hiding enabled, but also don't want all the logging to gum up the terminal.) -if not os.environ.get('DISABLE_LOGGING', False): +if not os.environ.get("DISABLE_LOGGING", False): logging.basicConfig( level=logging.DEBUG, # Also make sure to set up timestamping for more sanity when debugging. @@ -43,7 +43,7 @@ def make_sftp_folder(): # TODO: if we want to lock ourselves even harder into localhost-only # testing (probably not?) could use tempdir modules for this for improved # safety. Then again...why would someone have such a folder??? - path = os.environ.get('TEST_FOLDER', 'paramiko-test-target') + path = os.environ.get("TEST_FOLDER", "paramiko-test-target") # Forcibly nuke this directory locally, since at the moment, the below # fixtures only ever run with a locally scoped stub test server. shutil.rmtree(path, ignore_errors=True) @@ -52,7 +52,7 @@ def make_sftp_folder(): return path -@pytest.fixture#(scope='session') +@pytest.fixture # (scope='session') def sftp_server(): """ Set up an in-memory SFTP server thread. Yields the client Transport/socket. @@ -69,17 +69,17 @@ def sftp_server(): tc = Transport(sockc) ts = Transport(socks) # Auth - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) ts.add_server_key(host_key) # Server setup event = threading.Event() server = StubServer() - ts.set_subsystem_handler('sftp', SFTPServer, StubSFTPServer) + ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer) ts.start_server(event, server) # Wait (so client has time to connect? Not sure. Old.) event.wait(1.0) # Make & yield connection. - tc.connect(username='slowdive', password='pygmalion') + tc.connect(username="slowdive", password="pygmalion") yield tc # TODO: any need for shutdown? Why didn't old suite do so? Or was that the # point of the "join all threads from threading module" crap in test.py? diff --git a/tests/loop.py b/tests/loop.py index 6c432867..dd1f5a0c 100644 --- a/tests/loop.py +++ b/tests/loop.py @@ -22,13 +22,13 @@ import threading from paramiko.common import asbytes -class LoopSocket (object): +class LoopSocket(object): """ A LoopSocket looks like a normal socket, but all data written to it is delivered on the read-end of another LoopSocket, and vice versa. It's like a software "socketpair". """ - + def __init__(self): self.__in_buffer = bytes() self.__lock = threading.Lock() @@ -84,7 +84,7 @@ class LoopSocket (object): self.__cv.notifyAll() finally: self.__lock.release() - + def __unlink(self): m = None self.__lock.acquire() @@ -96,5 +96,3 @@ class LoopSocket (object): self.__lock.release() if m is not None: m.__unlink() - - diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py index 19545865..ffae635d 100644 --- a/tests/stub_sftp.py +++ b/tests/stub_sftp.py @@ -24,13 +24,21 @@ import os import sys from paramiko import ( - ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, - SFTPHandle, SFTP_OK, SFTP_FAILURE, AUTH_SUCCESSFUL, OPEN_SUCCEEDED, + ServerInterface, + SFTPServerInterface, + SFTPServer, + SFTPAttributes, + SFTPHandle, + SFTP_OK, + SFTP_FAILURE, + AUTH_SUCCESSFUL, + OPEN_SUCCEEDED, ) from paramiko.common import o666 -class StubServer (ServerInterface): +class StubServer(ServerInterface): + def check_auth_password(self, username, password): # all are allowed return AUTH_SUCCESSFUL @@ -39,7 +47,8 @@ class StubServer (ServerInterface): return OPEN_SUCCEEDED -class StubSFTPHandle (SFTPHandle): +class StubSFTPHandle(SFTPHandle): + def stat(self): try: return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno())) @@ -56,11 +65,11 @@ class StubSFTPHandle (SFTPHandle): return SFTPServer.convert_errno(e.errno) -class StubSFTPServer (SFTPServerInterface): +class StubSFTPServer(SFTPServerInterface): # assume current folder is a fine root # (the tests always create and eventually delete a subfolder, so there shouldn't be any mess) ROOT = os.getcwd() - + def _realpath(self, path): return self.ROOT + self.canonicalize(path) @@ -70,7 +79,9 @@ class StubSFTPServer (SFTPServerInterface): out = [] flist = os.listdir(path) for fname in flist: - attr = SFTPAttributes.from_stat(os.stat(os.path.join(path, fname))) + attr = SFTPAttributes.from_stat( + os.stat(os.path.join(path, fname)) + ) attr.filename = fname out.append(attr) return out @@ -94,9 +105,9 @@ class StubSFTPServer (SFTPServerInterface): def open(self, path, flags, attr): path = self._realpath(path) try: - binary_flag = getattr(os, 'O_BINARY', 0) + binary_flag = getattr(os, "O_BINARY", 0) flags |= binary_flag - mode = getattr(attr, 'st_mode', None) + mode = getattr(attr, "st_mode", None) if mode is not None: fd = os.open(path, flags, mode) else: @@ -110,17 +121,17 @@ class StubSFTPServer (SFTPServerInterface): SFTPServer.set_file_attr(path, attr) if flags & os.O_WRONLY: if flags & os.O_APPEND: - fstr = 'ab' + fstr = "ab" else: - fstr = 'wb' + fstr = "wb" elif flags & os.O_RDWR: if flags & os.O_APPEND: - fstr = 'a+b' + fstr = "a+b" else: - fstr = 'r+b' + fstr = "r+b" else: # O_RDONLY (== 0) - fstr = 'rb' + fstr = "rb" try: f = os.fdopen(fd, fstr) except OSError as e: @@ -159,7 +170,6 @@ class StubSFTPServer (SFTPServerInterface): return SFTPServer.convert_errno(e.errno) return SFTP_OK - def mkdir(self, path, attr): path = self._realpath(path) try: @@ -188,18 +198,18 @@ class StubSFTPServer (SFTPServerInterface): def symlink(self, target_path, path): path = self._realpath(path) - if (len(target_path) > 0) and (target_path[0] == '/'): + if (len(target_path) > 0) and (target_path[0] == "/"): # absolute symlink target_path = os.path.join(self.ROOT, target_path[1:]) - if target_path[:2] == '//': + if target_path[:2] == "//": # bug in os.path.join target_path = target_path[1:] else: # compute relative to path abspath = os.path.join(os.path.dirname(path), target_path) - if abspath[:len(self.ROOT)] != self.ROOT: + if abspath[: len(self.ROOT)] != self.ROOT: # this symlink isn't going to work anyway -- just break it immediately - target_path = '<error>' + target_path = "<error>" try: os.symlink(target_path, path) except OSError as e: @@ -214,10 +224,10 @@ class StubSFTPServer (SFTPServerInterface): return SFTPServer.convert_errno(e.errno) # if it's absolute, remove the root if os.path.isabs(symlink): - if symlink[:len(self.ROOT)] == self.ROOT: - symlink = symlink[len(self.ROOT):] - if (len(symlink) == 0) or (symlink[0] != '/'): - symlink = '/' + symlink + if symlink[: len(self.ROOT)] == self.ROOT: + symlink = symlink[len(self.ROOT) :] + if (len(symlink) == 0) or (symlink[0] != "/"): + symlink = "/" + symlink else: - symlink = '<error>' + symlink = "<error>" return symlink diff --git a/tests/test_auth.py b/tests/test_auth.py index dacdd654..acabb1bd 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -26,8 +26,13 @@ import unittest from time import sleep from paramiko import ( - Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType, - InteractiveQuery, AuthenticationException, + Transport, + ServerInterface, + RSAKey, + DSSKey, + BadAuthenticationType, + InteractiveQuery, + AuthenticationException, ) from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko.py3compat import u @@ -36,54 +41,57 @@ from .loop import LoopSocket from .util import _support, slow -_pwd = u('\u2022') +_pwd = u("\u2022") -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - if username == 'paranoid': - if not self.paranoid_did_password and not self.paranoid_did_public_key: - return 'publickey,password' + if username == "slowdive": + return "publickey,password" + if username == "paranoid": + if ( + not self.paranoid_did_password + and not self.paranoid_did_public_key + ): + return "publickey,password" elif self.paranoid_did_password: - return 'publickey' + return "publickey" else: - return 'password' - if username == 'commie': - return 'keyboard-interactive' - if username == 'utf8': - return 'password' - if username == 'non-utf8': - return 'password' - return 'publickey' + return "password" + if username == "commie": + return "keyboard-interactive" + if username == "utf8": + return "password" + if username == "non-utf8": + return "password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL - if (username == 'paranoid') and (password == 'paranoid'): + if (username == "paranoid") and (password == "paranoid"): # 2-part auth (even openssh doesn't support this) self.paranoid_did_password = True if self.paranoid_did_public_key: return AUTH_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL - if (username == 'utf8') and (password == _pwd): + if (username == "utf8") and (password == _pwd): return AUTH_SUCCESSFUL - if (username == 'non-utf8') and (password == '\xff'): + if (username == "non-utf8") and (password == "\xff"): return AUTH_SUCCESSFUL - if username == 'bad-server': + if username == "bad-server": raise Exception("Ack!") - if username == 'unresponsive-server': + if username == "unresponsive-server": sleep(5) return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): - if (username == 'paranoid') and (key == self.paranoid_key): + if (username == "paranoid") and (key == self.paranoid_key): # 2-part auth self.paranoid_did_public_key = True if self.paranoid_did_password: @@ -92,19 +100,21 @@ class NullServer (ServerInterface): return AUTH_FAILED def check_auth_interactive(self, username, submethods): - if username == 'commie': + if username == "commie": self.username = username - return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return InteractiveQuery( + "password", "Please enter a password.", ("Password", False) + ) return AUTH_FAILED def check_auth_interactive_response(self, responses): - if self.username == 'commie': - if (len(responses) == 1) and (responses[0] == 'cat'): + if self.username == "commie": + if (len(responses) == 1) and (responses[0] == "cat"): return AUTH_SUCCESSFUL return AUTH_FAILED -class AuthTest (unittest.TestCase): +class AuthTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() @@ -120,7 +130,7 @@ class AuthTest (unittest.TestCase): self.sockc.close() def start_server(self): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) self.event = threading.Event() @@ -140,13 +150,16 @@ class AuthTest (unittest.TestCase): """ self.start_server() try: - self.tc.connect(hostkey=self.public_host_key, - username='unknown', password='error') + self.tc.connect( + hostkey=self.public_host_key, + username="unknown", + password="error", + ) self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertEqual(BadAuthenticationType, etype) - self.assertEqual(['publickey'], evalue.allowed_types) + self.assertEqual(["publickey"], evalue.allowed_types) def test_bad_password(self): """ @@ -156,12 +169,12 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - self.tc.auth_password(username='slowdive', password='error') + self.tc.auth_password(username="slowdive", password="error") self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.tc.auth_password(username='slowdive', password='pygmalion') + self.tc.auth_password(username="slowdive", password="pygmalion") self.verify_finished() def test_multipart_auth(self): @@ -170,10 +183,12 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEqual(['publickey'], remain) - key = DSSKey.from_private_key_file(_support('test_dss.key')) - remain = self.tc.auth_publickey(username='paranoid', key=key) + remain = self.tc.auth_password( + username="paranoid", password="paranoid" + ) + self.assertEqual(["publickey"], remain) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + remain = self.tc.auth_publickey(username="paranoid", key=key) self.assertEqual([], remain) self.verify_finished() @@ -188,10 +203,11 @@ class AuthTest (unittest.TestCase): self.got_title = title self.got_instructions = instructions self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEqual(self.got_title, 'password') - self.assertEqual(self.got_prompts, [('Password', False)]) + return ["cat"] + + remain = self.tc.auth_interactive("commie", handler) + self.assertEqual(self.got_title, "password") + self.assertEqual(self.got_prompts, [("Password", False)]) self.assertEqual([], remain) self.verify_finished() @@ -202,7 +218,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('commie', 'cat') + remain = self.tc.auth_password("commie", "cat") self.assertEqual([], remain) self.verify_finished() @@ -212,7 +228,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('utf8', _pwd) + remain = self.tc.auth_password("utf8", _pwd) self.assertEqual([], remain) self.verify_finished() @@ -223,7 +239,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('non-utf8', '\xff') + remain = self.tc.auth_password("non-utf8", "\xff") self.assertEqual([], remain) self.verify_finished() @@ -235,7 +251,7 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - remain = self.tc.auth_password('bad-server', 'hello') + remain = self.tc.auth_password("bad-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) @@ -250,8 +266,8 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect() try: - remain = self.tc.auth_password('unresponsive-server', 'hello') + remain = self.tc.auth_password("unresponsive-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.assertTrue('Authentication timeout' in str(evalue)) + self.assertTrue("Authentication timeout" in str(evalue)) diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py index 03616c55..9f986a5e 100644 --- a/tests/test_buffered_pipe.py +++ b/tests/test_buffered_pipe.py @@ -30,9 +30,9 @@ from paramiko.py3compat import b def delay_thread(p): - p.feed('a') + p.feed("a") time.sleep(0.5) - p.feed('b') + p.feed("b") p.close() @@ -42,41 +42,42 @@ def close_thread(p): class BufferedPipeTest(unittest.TestCase): + def test_1_buffered_pipe(self): p = BufferedPipe() self.assertTrue(not p.read_ready()) - p.feed('hello.') + p.feed("hello.") self.assertTrue(p.read_ready()) data = p.read(6) - self.assertEqual(b'hello.', data) + self.assertEqual(b"hello.", data) - p.feed('plus/minus') - self.assertEqual(b'plu', p.read(3)) - self.assertEqual(b's/m', p.read(3)) - self.assertEqual(b'inus', p.read(4)) + p.feed("plus/minus") + self.assertEqual(b"plu", p.read(3)) + self.assertEqual(b"s/m", p.read(3)) + self.assertEqual(b"inus", p.read(4)) p.close() self.assertTrue(not p.read_ready()) - self.assertEqual(b'', p.read(1)) + self.assertEqual(b"", p.read(1)) def test_2_delay(self): p = BufferedPipe() self.assertTrue(not p.read_ready()) threading.Thread(target=delay_thread, args=(p,)).start() - self.assertEqual(b'a', p.read(1, 0.1)) + self.assertEqual(b"a", p.read(1, 0.1)) try: p.read(1, 0.1) self.assertTrue(False) except PipeTimeout: pass - self.assertEqual(b'b', p.read(1, 1.0)) - self.assertEqual(b'', p.read(1)) + self.assertEqual(b"b", p.read(1, 1.0)) + self.assertEqual(b"", p.read(1)) def test_3_close_while_reading(self): p = BufferedPipe() threading.Thread(target=close_thread, args=(p,)).start() data = p.read(1, 1.0) - self.assertEqual(b'', data) + self.assertEqual(b"", data) def test_4_or_pipe(self): p = pipe.make_pipe() diff --git a/tests/test_client.py b/tests/test_client.py index 7163fdcf..4943df29 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,30 +48,31 @@ requires_gss_auth = unittest.skipUnless( ) FINGERPRINTS = { - 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c', - 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5', - 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60', - 'ssh-ed25519': b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', + "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c", + "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5", + "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60", + "ssh-ed25519": b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', } class NullServer(paramiko.ServerInterface): + def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types - self.__allowed_keys = kwargs.pop('allowed_keys', []) + self.__allowed_keys = kwargs.pop("allowed_keys", []) # And allow them to set a (single...meh) expected public blob (cert) - self.__expected_public_blob = kwargs.pop('public_blob', None) + self.__expected_public_blob = kwargs.pop("public_blob", None) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return paramiko.AUTH_SUCCESSFUL - if (username == 'slowdive') and (password == 'unresponsive-server'): + if (username == "slowdive") and (password == "unresponsive-server"): time.sleep(5) return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -83,13 +84,13 @@ class NullServer(paramiko.ServerInterface): return paramiko.AUTH_FAILED # Base check: allowed auth type & fingerprint matches happy = ( - key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected ) # Secondary check: if test wants assertions about cert data if ( - self.__expected_public_blob is not None and - key.public_blob != self.__expected_public_blob + self.__expected_public_blob is not None + and key.public_blob != self.__expected_public_blob ): happy = False return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED @@ -98,31 +99,32 @@ class NullServer(paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True def check_channel_env_request(self, channel, name, value): - if name == 'INVALID_ENV': + if name == "INVALID_ENV": return False - if not hasattr(channel, 'env'): - setattr(channel, 'env', {}) + if not hasattr(channel, "env"): + setattr(channel, "env", {}) channel.env[name] = value return True class ClientTest(unittest.TestCase): + def setUp(self): self.sockl = socket.socket() - self.sockl.bind(('localhost', 0)) + self.sockl.bind(("localhost", 0)) self.sockl.listen(1) self.addr, self.port = self.sockl.getsockname() self.connect_kwargs = dict( hostname=self.addr, port=self.port, - username='slowdive', + username="slowdive", look_for_keys=False, ) self.event = threading.Event() @@ -130,10 +132,10 @@ class ClientTest(unittest.TestCase): def tearDown(self): # Shut down client Transport - if hasattr(self, 'tc'): + if hasattr(self, "tc"): self.tc.close() # Shut down shared socket - if hasattr(self, 'sockl'): + if hasattr(self, "sockl"): # Signal to server thread that it should shut down early; it checks # this immediately after accept(). (In scenarios where connection # actually succeeded during the test, this becomes a no-op.) @@ -151,7 +153,7 @@ class ClientTest(unittest.TestCase): self.sockl.close() def _run( - self, allowed_keys=None, delay=0, public_blob=None, kill_event=None, + self, allowed_keys=None, delay=0, public_blob=None, kill_event=None ): if allowed_keys is None: allowed_keys = FINGERPRINTS.keys() @@ -163,10 +165,10 @@ class ClientTest(unittest.TestCase): self.socks.close() return self.ts = paramiko.Transport(self.socks) - keypath = _support('test_rsa.key') + keypath = _support("test_rsa.key") host_key = paramiko.RSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) - keypath = _support('test_ecdsa_256.key') + keypath = _support("test_ecdsa_256.key") host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob) @@ -181,17 +183,21 @@ class ClientTest(unittest.TestCase): The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ - run_kwargs = {'kill_event': self.kill_event} - for key in ('allowed_keys', 'public_blob'): + run_kwargs = {"kill_event": self.kill_event} + for key in ("allowed_keys", "public_blob"): run_kwargs[key] = kwargs.pop(key, None) # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) # Client setup self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Actual connection self.tc.connect(**dict(self.connect_kwargs, **kwargs)) @@ -200,22 +206,22 @@ class ClientTest(unittest.TestCase): self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(False, self.tc.get_transport().gss_kex_used) # Command execution functions? - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEqual('Hello there.\n', stdout.readline()) - self.assertEqual('', stdout.readline()) - self.assertEqual('This is on stderr.\n', stderr.readline()) - self.assertEqual('', stderr.readline()) + self.assertEqual("Hello there.\n", stdout.readline()) + self.assertEqual("", stdout.readline()) + self.assertEqual("This is on stderr.\n", stderr.readline()) + self.assertEqual("", stderr.readline()) # Cleanup stdin.close() @@ -224,32 +230,33 @@ class ClientTest(unittest.TestCase): class SSHClientTest(ClientTest): + def test_1_client(self): """ verify that the SSHClient stuff works too. """ - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ - self._test_connection(key_filename=_support('test_dss.key')) + self._test_connection(key_filename=_support("test_dss.key")) def test_client_rsa(self): """ verify that SSHClient works with an RSA key. """ - self._test_connection(key_filename=_support('test_rsa.key')) + self._test_connection(key_filename=_support("test_rsa.key")) def test_2_5_client_ecdsa(self): """ verify that SSHClient works with an ECDSA key. """ - self._test_connection(key_filename=_support('test_ecdsa_256.key')) + self._test_connection(key_filename=_support("test_ecdsa_256.key")) def test_client_ed25519(self): - self._test_connection(key_filename=_support('test_ed25519.key')) + self._test_connection(key_filename=_support("test_ed25519.key")) def test_3_multiple_key_files(self): """ @@ -257,22 +264,22 @@ class SSHClientTest(ClientTest): """ # This is dumb :( types_ = { - 'rsa': 'ssh-rsa', - 'dss': 'ssh-dss', - 'ecdsa': 'ecdsa-sha2-nistp256', + "rsa": "ssh-rsa", + "dss": "ssh-dss", + "ecdsa": "ecdsa-sha2-nistp256", } # Various combos of attempted & valid keys # TODO: try every possible combo using itertools functions for attempt, accept in ( - (['rsa', 'dss'], ['dss']), # Original test #3 - (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly - (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail - (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success + (["rsa", "dss"], ["dss"]), # Original test #3 + (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly + (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail + (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success ): try: self._test_connection( key_filename=[ - _support('test_{}.key'.format(x)) for x in attempt + _support("test_{}.key".format(x)) for x in attempt ], allowed_keys=[types_[x] for x in accept], ) @@ -288,10 +295,11 @@ class SSHClientTest(ClientTest): """ # Until #387 is fixed we have to catch a high-up exception since # various platforms trigger different errors here >_< - self.assertRaises(SSHException, + self.assertRaises( + SSHException, self._test_connection, - key_filename=[_support('test_rsa.key')], - allowed_keys=['ecdsa-sha2-nistp256'], + key_filename=[_support("test_rsa.key")], + allowed_keys=["ecdsa-sha2-nistp256"], ) def test_certs_allowed_as_key_filename_values(self): @@ -299,9 +307,9 @@ class SSHClientTest(ClientTest): # They're similar except for which path is given; the expected auth and # server-side behavior is 100% identical.) # NOTE: only bothered whipping up one cert per overall class/family. - for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): - cert_name = 'test_{}.key-cert.pub'.format(type_) - cert_path = _support(os.path.join('cert_support', cert_name)) + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + cert_name = "test_{}.key-cert.pub".format(type_) + cert_path = _support(os.path.join("cert_support", cert_name)) self._test_connection( key_filename=cert_path, public_blob=PublicBlob.from_file(cert_path), @@ -314,13 +322,13 @@ class SSHClientTest(ClientTest): # about the server-side key object's public blob. Thus, we can prove # that a specific cert was found, along with regular authorization # succeeding proving that the overall flow works. - for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): - key_name = 'test_{}.key'.format(type_) - key_path = _support(os.path.join('cert_support', key_name)) + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + key_name = "test_{}.key".format(type_) + key_path = _support(os.path.join("cert_support", key_name)) self._test_connection( key_filename=key_path, public_blob=PublicBlob.from_file( - '{}-cert.pub'.format(key_path) + "{}-cert.pub".format(key_path) ), ) @@ -335,19 +343,19 @@ class SSHClientTest(ClientTest): verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) - key_file = _support('test_ecdsa_256.key') + hostname = "[%s]:%d" % (self.addr, self.port) + key_file = _support("test_ecdsa_256.key") public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] @@ -357,9 +365,11 @@ class SSHClientTest(ClientTest): """ verify that SSHClient correctly saves a known_hosts file. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) fd, localname = mkstemp() os.close(fd) @@ -367,11 +377,13 @@ class SSHClientTest(ClientTest): client = paramiko.SSHClient() self.assertEquals(0, len(client.get_host_keys())) - host_id = '[%s]:%d' % (self.addr, self.port) + host_id = "[%s]:%d" % (self.addr, self.port) - client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key) + client.get_host_keys().add(host_id, "ssh-rsa", public_host_key) self.assertEquals(1, len(client.get_host_keys())) - self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa']) + self.assertEquals( + public_host_key, client.get_host_keys()[host_id]["ssh-rsa"] + ) client.save_host_keys(localname) @@ -394,7 +406,7 @@ class SSHClientTest(ClientTest): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -423,7 +435,7 @@ class SSHClientTest(ClientTest): self.tc = tc self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -438,19 +450,19 @@ class SSHClientTest(ClientTest): verify that the SSHClient has a configurable banner timeout. """ # Start the thread with a 1 second wait. - threading.Thread(target=self._run, kwargs={'delay': 1}).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + threading.Thread(target=self._run, kwargs={"delay": 1}).start() + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Connect with a half second banner timeout. kwargs = dict(self.connect_kwargs, banner_timeout=0.5) - self.assertRaises( - paramiko.SSHException, - self.tc.connect, - **kwargs - ) + self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs) def test_8_auth_trickledown(self): """ @@ -466,9 +478,9 @@ class SSHClientTest(ClientTest): # 'television' as per tests/test_pkey.py). NOTE: must use # key_filename, loading the actual key here with PKey will except # immediately; we're testing the try/except crap within Client. - key_filename=[_support('test_rsa_password.key')], + key_filename=[_support("test_rsa_password.key")], # Actual password for default 'slowdive' user - password='pygmalion', + password="pygmalion", ) self._test_connection(**kwargs) @@ -481,7 +493,7 @@ class SSHClientTest(ClientTest): self.assertRaises( AuthenticationException, self._test_connection, - password='unresponsive-server', + password="unresponsive-server", auth_timeout=0.5, ) @@ -490,10 +502,7 @@ class SSHClientTest(ClientTest): """ Failed gssapi-keyex auth doesn't prevent subsequent key auth from succeeding """ - kwargs = dict( - gss_kex=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) @requires_gss_auth @@ -501,10 +510,7 @@ class SSHClientTest(ClientTest): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ - kwargs = dict( - gss_auth=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_12_reject_policy(self): @@ -519,7 +525,8 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', **self.connect_kwargs + password="pygmalion", + **self.connect_kwargs ) @requires_gss_auth @@ -537,14 +544,14 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', + password="pygmalion", gss_kex=True, - **self.connect_kwargs + **self.connect_kwargs ) def _client_host_key_bad(self, host_key): threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) + hostname = "[%s]:%d" % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) @@ -554,13 +561,13 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.BadHostKeyException, self.tc.connect, - password='pygmalion', + password="pygmalion", **self.connect_kwargs ) def _client_host_key_good(self, ktype, kfile): threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) + hostname = "[%s]:%d" % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) @@ -568,7 +575,7 @@ class SSHClientTest(ClientTest): known_hosts = self.tc.get_host_keys() known_hosts.add(hostname, host_key.get_name(), host_key) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) @@ -583,10 +590,10 @@ class SSHClientTest(ClientTest): self._client_host_key_bad(host_key) def test_host_key_negotiation_3(self): - self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key') + self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key") def test_host_key_negotiation_4(self): - self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key') + self._client_host_key_good(paramiko.RSAKey, "test_rsa.key") def _setup_for_env(self): threading.Thread(target=self._run).start() @@ -594,7 +601,9 @@ class SSHClientTest(ClientTest): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + self.tc.connect( + self.addr, self.port, username="slowdive", password="pygmalion" + ) self.event.wait(1.0) self.assertTrue(self.event.isSet()) @@ -605,11 +614,11 @@ class SSHClientTest(ClientTest): Verify that environment variables can be set by the client. """ self._setup_for_env() - target_env = {b'A': b'B', b'C': b'd'} + target_env = {b"A": b"B", b"C": b"d"} - self.tc.exec_command('yes', environment=target_env) + self.tc.exec_command("yes", environment=target_env) schan = self.ts.accept(1.0) - self.assertEqual(target_env, getattr(schan, 'env', {})) + self.assertEqual(target_env, getattr(schan, "env", {})) schan.close() @unittest.skip("Clients normally fail silently, thus so do we, for now") @@ -617,14 +626,14 @@ class SSHClientTest(ClientTest): self._setup_for_env() with self.assertRaises(SSHException) as manager: # Verify that a rejection by the server can be detected - self.tc.exec_command('yes', environment={b'INVALID_ENV': b''}) + self.tc.exec_command("yes", environment={b"INVALID_ENV": b""}) self.assertTrue( - 'INVALID_ENV' in str(manager.exception), - 'Expected variable name in error message' + "INVALID_ENV" in str(manager.exception), + "Expected variable name in error message", ) self.assertTrue( isinstance(manager.exception.args[1], SSHException), - 'Expected original SSHException in exception' + "Expected original SSHException in exception", ) def test_missing_key_policy_accepts_classes_or_instances(self): @@ -652,35 +661,39 @@ class PasswordPassphraseTests(ClientTest): def test_password_kwarg_works_for_password_auth(self): # Straightforward / duplicate of earlier basic password test. - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") # TODO: more granular exception pending #387; should be signaling "no auth # methods available" because no key and no password @raises(SSHException) def test_passphrase_kwarg_not_used_for_password_auth(self): # Using the "right" password in the "wrong" field shouldn't work. - self._test_connection(passphrase='pygmalion') + self._test_connection(passphrase="pygmalion") def test_passphrase_kwarg_used_for_key_passphrase(self): # Straightforward again, with new passphrase kwarg. self._test_connection( - key_filename=_support('test_rsa_password.key'), - passphrase='television', + key_filename=_support("test_rsa_password.key"), + passphrase="television", ) - def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(self): # noqa + def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given( + self + ): # noqa # Backwards compatibility: passphrase in the password field. self._test_connection( - key_filename=_support('test_rsa_password.key'), - password='television', + key_filename=_support("test_rsa_password.key"), + password="television", ) - @raises(AuthenticationException) # TODO: more granular - def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(self): # noqa + @raises(AuthenticationException) # TODO: more granular + def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given( + self + ): # noqa # Sanity: if we're given both fields, the password field is NOT used as # a passphrase. self._test_connection( - key_filename=_support('test_rsa_password.key'), - password='television', - passphrase='wat? lol no', + key_filename=_support("test_rsa_password.key"), + password="television", + passphrase="wat? lol no", ) diff --git a/tests/test_file.py b/tests/test_file.py index 3d2c94e6..deacd60a 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -30,18 +30,19 @@ from paramiko.py3compat import BytesIO from .util import needs_builtin -class LoopbackFile (BufferedFile): +class LoopbackFile(BufferedFile): """ BufferedFile object that you can write data into, and then read it back. """ - def __init__(self, mode='r', bufsize=-1): + + def __init__(self, mode="r", bufsize=-1): BufferedFile.__init__(self) self._set_mode(mode, bufsize) self.buffer = BytesIO() self.offset = 0 def _read(self, size): - data = self.buffer.getvalue()[self.offset:self.offset+size] + data = self.buffer.getvalue()[self.offset : self.offset + size] self.offset += len(data) return data @@ -50,44 +51,46 @@ class LoopbackFile (BufferedFile): return len(data) -class BufferedFileTest (unittest.TestCase): +class BufferedFileTest(unittest.TestCase): def test_1_simple(self): - f = LoopbackFile('r') + f = LoopbackFile("r") try: - f.write(b'hi') - self.assertTrue(False, 'no exception on write to read-only file') + f.write(b"hi") + self.assertTrue(False, "no exception on write to read-only file") except: pass f.close() - f = LoopbackFile('w') + f = LoopbackFile("w") try: f.read(1) - self.assertTrue(False, 'no exception to read from write-only file') + self.assertTrue(False, "no exception to read from write-only file") except: pass f.close() def test_2_readline(self): - f = LoopbackFile('r+U') - f.write(b'First line.\nSecond line.\r\nThird line.\n' + - b'Fourth line.\nFinal line non-terminated.') + f = LoopbackFile("r+U") + f.write( + b"First line.\nSecond line.\r\nThird line.\n" + + b"Fourth line.\nFinal line non-terminated." + ) - self.assertEqual(f.readline(), 'First line.\n') + self.assertEqual(f.readline(), "First line.\n") # universal newline mode should convert this linefeed: - self.assertEqual(f.readline(), 'Second line.\n') + self.assertEqual(f.readline(), "Second line.\n") # truncated line: - self.assertEqual(f.readline(7), 'Third l') - self.assertEqual(f.readline(), 'ine.\n') + self.assertEqual(f.readline(7), "Third l") + self.assertEqual(f.readline(), "ine.\n") # newline should be detected and only the fourth line returned - self.assertEqual(f.readline(39), 'Fourth line.\n') - self.assertEqual(f.readline(), 'Final line non-terminated.') - self.assertEqual(f.readline(), '') + self.assertEqual(f.readline(39), "Fourth line.\n") + self.assertEqual(f.readline(), "Final line non-terminated.") + self.assertEqual(f.readline(), "") f.close() try: f.readline() - self.assertTrue(False, 'no exception on readline of closed file') + self.assertTrue(False, "no exception on readline of closed file") except IOError: pass self.assertTrue(linefeed_byte in f.newlines) @@ -98,11 +101,11 @@ class BufferedFileTest (unittest.TestCase): """ try to trick the linefeed detector. """ - f = LoopbackFile('r+U') - f.write(b'First line.\r') - self.assertEqual(f.readline(), 'First line.\n') - f.write(b'\nSecond.\r\n') - self.assertEqual(f.readline(), 'Second.\n') + f = LoopbackFile("r+U") + f.write(b"First line.\r") + self.assertEqual(f.readline(), "First line.\n") + f.write(b"\nSecond.\r\n") + self.assertEqual(f.readline(), "Second.\n") f.close() self.assertEqual(f.newlines, crlf) @@ -110,51 +113,54 @@ class BufferedFileTest (unittest.TestCase): """ verify that write buffering is on. """ - f = LoopbackFile('r+', 1) - f.write(b'Complete line.\nIncomplete line.') - self.assertEqual(f.readline(), 'Complete line.\n') - self.assertEqual(f.readline(), '') - f.write('..\n') - self.assertEqual(f.readline(), 'Incomplete line...\n') + f = LoopbackFile("r+", 1) + f.write(b"Complete line.\nIncomplete line.") + self.assertEqual(f.readline(), "Complete line.\n") + self.assertEqual(f.readline(), "") + f.write("..\n") + self.assertEqual(f.readline(), "Incomplete line...\n") f.close() def test_5_flush(self): """ verify that flush will force a write. """ - f = LoopbackFile('r+', 512) - f.write('Not\nquite\n512 bytes.\n') - self.assertEqual(f.read(1), b'') + f = LoopbackFile("r+", 512) + f.write("Not\nquite\n512 bytes.\n") + self.assertEqual(f.read(1), b"") f.flush() - self.assertEqual(f.read(5), b'Not\nq') - self.assertEqual(f.read(10), b'uite\n512 b') - self.assertEqual(f.read(9), b'ytes.\n') - self.assertEqual(f.read(3), b'') + self.assertEqual(f.read(5), b"Not\nq") + self.assertEqual(f.read(10), b"uite\n512 b") + self.assertEqual(f.read(9), b"ytes.\n") + self.assertEqual(f.read(3), b"") f.close() def test_6_buffering(self): """ verify that flushing happens automatically on buffer crossing. """ - f = LoopbackFile('r+', 16) - f.write(b'Too small.') - self.assertEqual(f.read(4), b'') - f.write(b' ') - self.assertEqual(f.read(4), b'') - f.write(b'Enough.') - self.assertEqual(f.read(20), b'Too small. Enough.') + f = LoopbackFile("r+", 16) + f.write(b"Too small.") + self.assertEqual(f.read(4), b"") + f.write(b" ") + self.assertEqual(f.read(4), b"") + f.write(b"Enough.") + self.assertEqual(f.read(20), b"Too small. Enough.") f.close() def test_7_read_all(self): """ verify that read(-1) returns everything left in the file. """ - f = LoopbackFile('r+', 16) - f.write(b'The first thing you need to do is open your eyes. ') - f.write(b'Then, you need to close them again.\n') + f = LoopbackFile("r+", 16) + f.write(b"The first thing you need to do is open your eyes. ") + f.write(b"Then, you need to close them again.\n") s = f.read(-1) - self.assertEqual(s, b'The first thing you need to do is open your eyes. Then, you ' + - b'need to close them again.\n') + self.assertEqual( + s, + b"The first thing you need to do is open your eyes. Then, you " + + b"need to close them again.\n", + ) f.close() def test_8_buffering(self): @@ -162,19 +168,19 @@ class BufferedFileTest (unittest.TestCase): verify that buffered objects can be written """ if sys.version_info[0] == 2: - f = LoopbackFile('r+', 16) - f.write(buffer(b'Too small.')) + f = LoopbackFile("r+", 16) + f.write(buffer(b"Too small.")) f.close() def test_9_readable(self): - f = LoopbackFile('r') + f = LoopbackFile("r") self.assertTrue(f.readable()) self.assertFalse(f.writable()) self.assertFalse(f.seekable()) f.close() def test_A_writable(self): - f = LoopbackFile('w') + f = LoopbackFile("w") self.assertTrue(f.writable()) self.assertFalse(f.readable()) self.assertFalse(f.seekable()) @@ -182,48 +188,49 @@ class BufferedFileTest (unittest.TestCase): def test_B_readinto(self): data = bytearray(5) - f = LoopbackFile('r+') + f = LoopbackFile("r+") f._write(b"hello") f.readinto(data) - self.assertEqual(data, b'hello') + self.assertEqual(data, b"hello") f.close() def test_write_bad_type(self): - with LoopbackFile('wb') as f: + with LoopbackFile("wb") as f: self.assertRaises(TypeError, f.write, object()) def test_write_unicode_as_binary(self): text = u"\xa7 why is writing text to a binary file allowed?\n" - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: f.write(text) self.assertEqual(f.read(), text.encode("utf-8")) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_bytearray(self): - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: f.write(bytearray(12)) self.assertEqual(f.read(), 12 * b"\0") - @needs_builtin('buffer') + @needs_builtin("buffer") def test_write_buffer(self): data = 3 * b"pretend giant block of data\n" offsets = range(0, len(data), 8) - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: for offset in offsets: f.write(buffer(data, offset, 8)) self.assertEqual(f.read(), data) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_memoryview(self): data = 3 * b"pretend giant block of data\n" offsets = range(0, len(data), 8) - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: view = memoryview(data) for offset in offsets: - f.write(view[offset:offset+8]) + f.write(view[offset : offset + 8]) self.assertEqual(f.read(), data) -if __name__ == '__main__': +if __name__ == "__main__": from unittest import main + main() diff --git a/tests/test_gssapi.py b/tests/test_gssapi.py index d4b632be..d7fbdd53 100644 --- a/tests/test_gssapi.py +++ b/tests/test_gssapi.py @@ -30,6 +30,7 @@ from .util import needs_gssapi @needs_gssapi class GSSAPITest(unittest.TestCase): + def setup(): # TODO: these vars should all come from os.environ or whatever the # approved pytest method is for runtime-configuring test data. @@ -43,6 +44,7 @@ class GSSAPITest(unittest.TestCase): """ from pyasn1.type.univ import ObjectIdentifier from pyasn1.codec.der import encoder, decoder + oid = encoder.encode(ObjectIdentifier(self.krb5_mech)) mech, __ = decoder.decode(oid) self.assertEquals(self.krb5_mech, mech.__str__()) @@ -57,6 +59,7 @@ class GSSAPITest(unittest.TestCase): except ImportError: import sspicon import sspi + _API = "SSPI" c_token = None @@ -65,23 +68,28 @@ class GSSAPITest(unittest.TestCase): if _API == "MIT": if self.server_mode: - gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG, - gssapi.C_DELEG_FLAG) + gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + gssapi.C_DELEG_FLAG, + ) else: - gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_DELEG_FLAG) + gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_DELEG_FLAG, + ) # Initialize a GSS-API context. ctx = gssapi.Context() ctx.flags = gss_flags krb5_oid = gssapi.OID.mech_from_string(self.krb5_mech) - target_name = gssapi.Name("host@" + self.targ_name, - gssapi.C_NT_HOSTBASED_SERVICE) - gss_ctxt = gssapi.InitContext(peer_name=target_name, - mech_type=krb5_oid, - req_flags=ctx.flags) + target_name = gssapi.Name( + "host@" + self.targ_name, gssapi.C_NT_HOSTBASED_SERVICE + ) + gss_ctxt = gssapi.InitContext( + peer_name=target_name, mech_type=krb5_oid, req_flags=ctx.flags + ) if self.server_mode: c_token = gss_ctxt.step(c_token) gss_ctxt_status = gss_ctxt.established @@ -108,15 +116,15 @@ class GSSAPITest(unittest.TestCase): self.assertEquals(0, status) else: gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH | - sspicon.ISC_REQ_DELEGATE + sspicon.ISC_REQ_INTEGRITY + | sspicon.ISC_REQ_MUTUAL_AUTH + | sspicon.ISC_REQ_DELEGATE ) # Initialize a GSS-API context. target_name = "host/" + socket.getfqdn(self.targ_name) - gss_ctxt = sspi.ClientAuth("Kerberos", - scflags=gss_flags, - targetspn=target_name) + gss_ctxt = sspi.ClientAuth( + "Kerberos", scflags=gss_flags, targetspn=target_name + ) if self.server_mode: error, token = gss_ctxt.authorize(c_token) c_token = token[0].Buffer diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index cd75f8ab..a1b7a9e0 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -54,77 +54,80 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\ 0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""" -class HostKeysTest (unittest.TestCase): +class HostKeysTest(unittest.TestCase): def setUp(self): - with open('hostfile.temp', 'w') as f: + with open("hostfile.temp", "w") as f: f.write(test_hosts_file) def tearDown(self): - os.unlink('hostfile.temp') + os.unlink("hostfile.temp") def test_1_load(self): - hostdict = paramiko.HostKeys('hostfile.temp') + hostdict = paramiko.HostKeys("hostfile.temp") self.assertEqual(2, len(hostdict)) self.assertEqual(1, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) def test_2_add(self): - hostdict = paramiko.HostKeys('hostfile.temp') - hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=' + hostdict = paramiko.HostKeys("hostfile.temp") + hh = "|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=" key = paramiko.RSAKey(data=decodebytes(keyblob)) - hostdict.add(hh, 'ssh-rsa', key) + hostdict.add(hh, "ssh-rsa", key) self.assertEqual(3, len(list(hostdict))) - x = hostdict['foo.example.com'] - fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp) - self.assertTrue(hostdict.check('foo.example.com', key)) + x = hostdict["foo.example.com"] + fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper() + self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp) + self.assertTrue(hostdict.check("foo.example.com", key)) def test_3_dict(self): - hostdict = paramiko.HostKeys('hostfile.temp') - self.assertTrue('secure.example.com' in hostdict) - self.assertTrue('not.example.com' not in hostdict) - self.assertTrue('secure.example.com' in hostdict) - self.assertTrue('not.example.com' not in hostdict) - x = hostdict.get('secure.example.com', None) + hostdict = paramiko.HostKeys("hostfile.temp") + self.assertTrue("secure.example.com" in hostdict) + self.assertTrue("not.example.com" not in hostdict) + self.assertTrue("secure.example.com" in hostdict) + self.assertTrue("not.example.com" not in hostdict) + x = hostdict.get("secure.example.com", None) self.assertTrue(x is not None) - fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) i = 0 for key in hostdict: i += 1 self.assertEqual(2, i) - + def test_4_dict_set(self): - hostdict = paramiko.HostKeys('hostfile.temp') + hostdict = paramiko.HostKeys("hostfile.temp") key = paramiko.RSAKey(data=decodebytes(keyblob)) key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss)) - hostdict['secure.example.com'] = { - 'ssh-rsa': key, - 'ssh-dss': key_dss - } - hostdict['fake.example.com'] = {} - hostdict['fake.example.com']['ssh-rsa'] = key - + hostdict["secure.example.com"] = {"ssh-rsa": key, "ssh-dss": key_dss} + hostdict["fake.example.com"] = {} + hostdict["fake.example.com"]["ssh-rsa"] = key + self.assertEqual(3, len(hostdict)) self.assertEqual(2, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) self.assertEqual(1, len(list(hostdict.values())[2])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp) - fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() - self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-dss"].get_fingerprint() + ).upper() + self.assertEqual(b"4478F0B9A23CC5182009FF755BC1D26C", fp) def test_delitem(self): - hostdict = paramiko.HostKeys('hostfile.temp') - target = 'happy.example.com' - entry = hostdict[target] # will KeyError if not present + hostdict = paramiko.HostKeys("hostfile.temp") + target = "happy.example.com" + entry = hostdict[target] # will KeyError if not present del hostdict[target] try: entry = hostdict[target] except KeyError: - pass # Good + pass # Good else: assert False, "Entry was not deleted from HostKeys on delitem!" diff --git a/tests/test_kex.py b/tests/test_kex.py index 241a7665..732410a2 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -39,30 +39,46 @@ from paramiko.kex_ecdh_nist import KexNistp256 def dummy_urandom(n): return byte_chr(0xcc) * n + def dummy_generate_key_pair(obj): - private_key_value = 94761803665136558137557783047955027733968423115106677159790289642479432803037 - public_key_numbers = "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989" - public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)) - obj.P = ec.EllipticCurvePrivateNumbers(private_value=private_key_value, public_numbers=public_key_numbers_obj).private_key(default_backend()) + private_key_value = ( + 94761803665136558137557783047955027733968423115106677159790289642479432803037 + ) + public_key_numbers = ( + "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989" + ) + public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ) + obj.P = ec.EllipticCurvePrivateNumbers( + private_value=private_key_value, public_numbers=public_key_numbers_obj + ).private_key(default_backend()) if obj.transport.server_mode: - obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ).public_key(default_backend()) return - obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ).public_key(default_backend()) + +class FakeKey(object): -class FakeKey (object): def __str__(self): - return 'fake-key' + return "fake-key" def asbytes(self): - return b'fake-key' + return b"fake-key" def sign_ssh_data(self, H): - return b'fake-sig' + return b"fake-sig" -class FakeModulusPack (object): - P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF +class FakeModulusPack(object): + P = ( + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF + ) G = 2 def get_modulus(self, min, ask, max): @@ -70,10 +86,10 @@ class FakeModulusPack (object): class FakeTransport(object): - local_version = 'SSH-2.0-paramiko_1.0' - remote_version = 'SSH-2.0-lame' - local_kex_init = 'local-kex-init' - remote_kex_init = 'remote-kex-init' + local_version = "SSH-2.0-paramiko_1.0" + remote_version = "SSH-2.0-lame" + local_kex_init = "local-kex-init" + remote_kex_init = "remote-kex-init" def _send_message(self, m): self._message = m @@ -101,9 +117,11 @@ class FakeTransport(object): return FakeModulusPack() -class KexTest (unittest.TestCase): +class KexTest(unittest.TestCase): - K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504 + K = ( + 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504 + ) def setUp(self): self._original_urandom = os.urandom @@ -120,21 +138,25 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGroup1(transport) kex.start_kex() - x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect + ) # fake "reply" msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) - H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2' + H = b"03079780F3D3AD0B3C6DB30C8D21685F367A86D2" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_2_group1_server(self): @@ -142,14 +164,18 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGroup1(transport) kex.start_kex() - self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect + ) msg = Message() msg.add_mpint(69) msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) - H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' - x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"B16BF34DD10945EDE84E9C1EF24A14BFDC843389" + x = ( + b"1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -160,29 +186,35 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGex(transport) kex.start_kex() - x = b'22000004000000080000002000' + x = b"22000004000000080000002000" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' + H = b"A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_4_gex_old_client(self): @@ -190,37 +222,49 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGex(transport) kex.start_kex(_test_old_style=True) - x = b'1E00000800' + x = b"1E00000800" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'807F87B269EF7AC5EC7E75676808776A27D5864C' + H = b"807F87B269EF7AC5EC7E75676808776A27D5864C" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) - + def test_5_gex_server(self): transport = FakeTransport() transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(1024) @@ -228,17 +272,25 @@ class KexTest (unittest.TestCase): msg.add_int(4096) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'CE754197C21BF3452863B4F44D0B3951F12516EF' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"CE754197C21BF3452863B4F44D0B3951F12516EF" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -249,23 +301,37 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(2048) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -276,29 +342,35 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGexSHA256(transport) kex.start_kex() - x = b'22000004000000080000002000' + x = b"22000004000000080000002000" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4' + H = b"AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_8_gex_sha256_old_client(self): @@ -306,29 +378,35 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGexSHA256(transport) kex.start_kex(_test_old_style=True) - x = b'1E00000800' + x = b"1E00000800" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = ( + b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5' + H = b"518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_9_gex_sha256_server(self): @@ -336,7 +414,13 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGexSHA256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(1024) @@ -344,17 +428,25 @@ class KexTest (unittest.TestCase): msg.add_int(4096) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -365,62 +457,88 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGexSHA256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(2048) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = ( + b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" + ) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) - K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + K = ( + 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 + ) + H = b"3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB" + x = ( + b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" + ) self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) self.assertTrue(transport._activated) def test_11_kex_nistp256_client(self): - K = 91610929826364598472338906427792435253694642563583721654249504912114314269754 + K = ( + 91610929826364598472338906427792435253694642563583721654249504912114314269754 + ) transport = FakeTransport() transport.server_mode = False kex = KexNistp256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect + ) - #fake reply + # fake reply msg = Message() - msg.add_string('fake-host-key') - Q_S = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") + msg.add_string("fake-host-key") + Q_S = unhexlify( + "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210" + ) msg.add_string(Q_S) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY, msg) - H = b'BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A' + H = b"BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A" self.assertEqual(K, kex.transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_12_kex_nistp256_server(self): - K = 91610929826364598472338906427792435253694642563583721654249504912114314269754 + K = ( + 91610929826364598472338906427792435253694642563583721654249504912114314269754 + ) transport = FakeTransport() transport.server_mode = True kex = KexNistp256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect + ) - #fake init - msg=Message() - Q_C = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") - H = b'2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA' + # fake init + msg = Message() + Q_C = unhexlify( + "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210" + ) + H = b"2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA" msg.add_string(Q_C) msg.rewind() kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT, msg) diff --git a/tests/test_kex_gss.py b/tests/test_kex_gss.py index 025d1faa..afddee08 100644 --- a/tests/test_kex_gss.py +++ b/tests/test_kex_gss.py @@ -34,14 +34,14 @@ import paramiko from .util import needs_gssapi -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): def get_allowed_auths(self, username): - return 'gssapi-keyex' + return "gssapi-keyex" - def check_auth_gssapi_keyex(self, username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_keyex( + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None + ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -54,13 +54,14 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != "yes": return False return True @needs_gssapi class GSSKexTest(unittest.TestCase): + @staticmethod def init(username, hostname): global krb5_principal, targ_name @@ -86,13 +87,13 @@ class GSSKexTest(unittest.TestCase): def _run(self): self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks, gss_kex=True) - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") self.ts.add_server_key(host_key) self.ts.set_gss_host(targ_name) try: self.ts.load_server_moduli() except: - print ('(Failed to load moduli -- gex will be unsupported.)') + print("(Failed to load moduli -- gex will be unsupported.)") server = NullServer() self.ts.start_server(self.event, server) @@ -102,14 +103,21 @@ class GSSKexTest(unittest.TestCase): Diffie-Hellman Key Exchange and user authentication with the GSS-API context created during key exchange. """ - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.hostname, self.port), - 'ssh-rsa', public_host_key) - self.tc.connect(self.hostname, self.port, username=self.username, - gss_auth=True, gss_kex=True, gss_host=gss_host) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.hostname, self.port), "ssh-rsa", public_host_key + ) + self.tc.connect( + self.hostname, + self.port, + username=self.username, + gss_auth=True, + gss_kex=True, + gss_host=gss_host, + ) self.event.wait(1.0) self.assert_(self.event.is_set()) @@ -118,19 +126,19 @@ class GSSKexTest(unittest.TestCase): self.assertEquals(True, self.ts.is_authenticated()) self.assertEquals(True, self.tc.get_transport().gss_kex_used) - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) if rekey: self.tc.get_transport().renegotiate_keys() - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEquals('Hello there.\n', stdout.readline()) - self.assertEquals('', stdout.readline()) - self.assertEquals('This is on stderr.\n', stderr.readline()) - self.assertEquals('', stderr.readline()) + self.assertEquals("Hello there.\n", stdout.readline()) + self.assertEquals("", stdout.readline()) + self.assertEquals("This is on stderr.\n", stderr.readline()) + self.assertEquals("", stderr.readline()) stdin.close() stdout.close() diff --git a/tests/test_message.py b/tests/test_message.py index 645b0509..c292f4e6 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -26,20 +26,29 @@ from paramiko.message import Message from paramiko.common import byte_chr, zero_byte -class MessageTest (unittest.TestCase): +class MessageTest(unittest.TestCase): - __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000 - __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65' - __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7' - __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62' + __a = ( + b"\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8" + + b"x" * 1000 + ) + __b = ( + b"\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65" + ) + __c = ( + b"\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7" + ) + __d = ( + b"\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62" + ) def test_1_encode(self): msg = Message() msg.add_int(23) msg.add_int(123789456) - msg.add_string('q') - msg.add_string('hello') - msg.add_string('x' * 1000) + msg.add_string("q") + msg.add_string("hello") + msg.add_string("x" * 1000) self.assertEqual(msg.asbytes(), self.__a) msg = Message() @@ -48,7 +57,7 @@ class MessageTest (unittest.TestCase): msg.add_byte(byte_chr(0xf3)) msg.add_bytes(zero_byte + byte_chr(0x3f)) - msg.add_list(['huey', 'dewey', 'louie']) + msg.add_list(["huey", "dewey", "louie"]) self.assertEqual(msg.asbytes(), self.__b) msg = Message() @@ -63,16 +72,16 @@ class MessageTest (unittest.TestCase): msg = Message(self.__a) self.assertEqual(msg.get_int(), 23) self.assertEqual(msg.get_int(), 123789456) - self.assertEqual(msg.get_text(), 'q') - self.assertEqual(msg.get_text(), 'hello') - self.assertEqual(msg.get_text(), 'x' * 1000) + self.assertEqual(msg.get_text(), "q") + self.assertEqual(msg.get_text(), "hello") + self.assertEqual(msg.get_text(), "x" * 1000) msg = Message(self.__b) self.assertEqual(msg.get_boolean(), True) self.assertEqual(msg.get_boolean(), False) self.assertEqual(msg.get_byte(), byte_chr(0xf3)) self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f)) - self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie']) + self.assertEqual(msg.get_list(), ["huey", "dewey", "louie"]) msg = Message(self.__c) self.assertEqual(msg.get_int64(), 5) @@ -87,8 +96,8 @@ class MessageTest (unittest.TestCase): msg.add(0x1122334455) msg.add(0xf00000000000000000) msg.add(True) - msg.add('cat') - msg.add(['a', 'b']) + msg.add("cat") + msg.add(["a", "b"]) self.assertEqual(msg.asbytes(), self.__d) def test_4_misc(self): diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py index 414b7e38..dbe5993e 100644 --- a/tests/test_packetizer.py +++ b/tests/test_packetizer.py @@ -36,19 +36,20 @@ from .loop import LoopSocket x55 = byte_chr(0x55) x1f = byte_chr(0x1f) -class PacketizerTest (unittest.TestCase): + +class PacketizerTest(unittest.TestCase): def test_1_write(self): rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(wsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) encryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).encryptor() p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) @@ -63,22 +64,27 @@ class PacketizerTest (unittest.TestCase): data = rsock.recv(100) # 32 + 12 bytes of MAC = 44 self.assertEqual(44, len(data)) - self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16]) + self.assertEqual( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0", + data[:16], + ) def test_2_read(self): rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(rsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) decryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).decryptor() p.set_inbound_cipher(decryptor, 16, sha1, 12, x1f * 20) - wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59') + wsock.send( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59" + ) cmd, m = p.read_message() self.assertEqual(100, cmd) self.assertEqual(100, m.get_int()) @@ -86,18 +92,18 @@ class PacketizerTest (unittest.TestCase): self.assertEqual(900, m.get_int()) def test_3_closed(self): - if sys.platform.startswith("win"): # no SIGALRM on windows + if sys.platform.startswith("win"): # no SIGALRM on windows return rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(wsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) encryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).encryptor() p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) @@ -115,14 +121,17 @@ class PacketizerTest (unittest.TestCase): import signal class TimeoutError(Exception): + def __init__(self, error_message): - if hasattr(errno, 'ETIME'): + if hasattr(errno, "ETIME"): self.message = os.sterror(errno.ETIME) else: self.messaage = error_message - def timeout(seconds=1, error_message='Timer expired'): + def timeout(seconds=1, error_message="Timer expired"): + def decorator(func): + def _handle_timeout(signum, frame): raise TimeoutError(error_message) @@ -138,5 +147,6 @@ class PacketizerTest (unittest.TestCase): return wraps(func)(wrapper) return decorator + send = timeout()(p.send_message) self.assertRaises(EOFError, send, m) diff --git a/tests/test_pkey.py b/tests/test_pkey.py index 1827d2a9..4bbfaba1 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -34,18 +34,30 @@ from .util import _support # from openssh's ssh-keygen -PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=' -PUB_DSS = 'ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=' -PUB_ECDSA_256 = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo=' -PUB_ECDSA_384 = 'ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A==' -PUB_ECDSA_521 = 'ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==' - -FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5' -FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c' -FINGER_ECDSA_256 = '256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60' -FINGER_ECDSA_384 = '384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65' -FINGER_ECDSA_521 = '521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e' -SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8' +PUB_RSA = ( + "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=" +) +PUB_DSS = ( + "ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=" +) +PUB_ECDSA_256 = ( + "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo=" +) +PUB_ECDSA_384 = ( + "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A==" +) +PUB_ECDSA_521 = ( + "ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==" +) + +FINGER_RSA = "1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5" +FINGER_DSS = "1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c" +FINGER_ECDSA_256 = "256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60" +FINGER_ECDSA_384 = "384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65" +FINGER_ECDSA_521 = "521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e" +SIGNED_RSA = ( + "20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8" +) RSA_PRIVATE_OUT = """\ -----BEGIN RSA PRIVATE KEY----- @@ -107,10 +119,14 @@ L4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA== -----END EC PRIVATE KEY----- """ -x1234 = b'\x01\x02\x03\x04' +x1234 = b"\x01\x02\x03\x04" -TEST_KEY_BYTESTR_2 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87' -TEST_KEY_BYTESTR_3 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f' +TEST_KEY_BYTESTR_2 = ( + "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87" +) +TEST_KEY_BYTESTR_3 = ( + "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f" +) class KeyTest(unittest.TestCase): @@ -127,21 +143,22 @@ class KeyTest(unittest.TestCase): """ with open(keyfile, "r") as fh: self.assertEqual( - fh.readline()[:-1], - "-----BEGIN RSA PRIVATE KEY-----" + fh.readline()[:-1], "-----BEGIN RSA PRIVATE KEY-----" ) self.assertEqual(fh.readline()[:-1], "Proc-Type: 4,ENCRYPTED") self.assertEqual(fh.readline()[0:10], "DEK-Info: ") def test_1_generate_key_bytes(self): - key = util.generate_key_bytes(md5, x1234, 'happy birthday', 30) - exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64' + key = util.generate_key_bytes(md5, x1234, "happy birthday", 30) + exp = ( + b"\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64" + ) self.assertEqual(exp, key) def test_2_load_rsa(self): - key = RSAKey.from_private_key_file(_support('test_rsa.key')) - self.assertEqual('ssh-rsa', key.get_name()) - exp_rsa = b(FINGER_RSA.split()[1].replace(':', '')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) + self.assertEqual("ssh-rsa", key.get_name()) + exp_rsa = b(FINGER_RSA.split()[1].replace(":", "")) my_rsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_rsa, my_rsa) self.assertEqual(PUB_RSA.split()[1], key.get_base64()) @@ -155,18 +172,20 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_3_load_rsa_password(self): - key = RSAKey.from_private_key_file(_support('test_rsa_password.key'), 'television') - self.assertEqual('ssh-rsa', key.get_name()) - exp_rsa = b(FINGER_RSA.split()[1].replace(':', '')) + key = RSAKey.from_private_key_file( + _support("test_rsa_password.key"), "television" + ) + self.assertEqual("ssh-rsa", key.get_name()) + exp_rsa = b(FINGER_RSA.split()[1].replace(":", "")) my_rsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_rsa, my_rsa) self.assertEqual(PUB_RSA.split()[1], key.get_base64()) self.assertEqual(1024, key.get_bits()) def test_4_load_dss(self): - key = DSSKey.from_private_key_file(_support('test_dss.key')) - self.assertEqual('ssh-dss', key.get_name()) - exp_dss = b(FINGER_DSS.split()[1].replace(':', '')) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + self.assertEqual("ssh-dss", key.get_name()) + exp_dss = b(FINGER_DSS.split()[1].replace(":", "")) my_dss = hexlify(key.get_fingerprint()) self.assertEqual(exp_dss, my_dss) self.assertEqual(PUB_DSS.split()[1], key.get_base64()) @@ -180,9 +199,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_5_load_dss_password(self): - key = DSSKey.from_private_key_file(_support('test_dss_password.key'), 'television') - self.assertEqual('ssh-dss', key.get_name()) - exp_dss = b(FINGER_DSS.split()[1].replace(':', '')) + key = DSSKey.from_private_key_file( + _support("test_dss_password.key"), "television" + ) + self.assertEqual("ssh-dss", key.get_name()) + exp_dss = b(FINGER_DSS.split()[1].replace(":", "")) my_dss = hexlify(key.get_fingerprint()) self.assertEqual(exp_dss, my_dss) self.assertEqual(PUB_DSS.split()[1], key.get_base64()) @@ -190,7 +211,7 @@ class KeyTest(unittest.TestCase): def test_6_compare_rsa(self): # verify that the private & public keys compare equal - key = RSAKey.from_private_key_file(_support('test_rsa.key')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.assertEqual(key, key) pub = RSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -199,7 +220,7 @@ class KeyTest(unittest.TestCase): def test_7_compare_dss(self): # verify that the private & public keys compare equal - key = DSSKey.from_private_key_file(_support('test_dss.key')) + key = DSSKey.from_private_key_file(_support("test_dss.key")) self.assertEqual(key, key) pub = DSSKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -208,77 +229,79 @@ class KeyTest(unittest.TestCase): def test_8_sign_rsa(self): # verify that the rsa private key can sign and verify - key = RSAKey.from_private_key_file(_support('test_rsa.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = RSAKey.from_private_key_file(_support("test_rsa.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ssh-rsa', msg.get_text()) - sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) + self.assertEqual("ssh-rsa", msg.get_text()) + sig = bytes().join( + [byte_chr(int(x, 16)) for x in SIGNED_RSA.split(":")] + ) self.assertEqual(sig, msg.get_binary()) msg.rewind() pub = RSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_9_sign_dss(self): # verify that the dss private key can sign and verify - key = DSSKey.from_private_key_file(_support('test_dss.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = DSSKey.from_private_key_file(_support("test_dss.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ssh-dss', msg.get_text()) + self.assertEqual("ssh-dss", msg.get_text()) # can't do the same test as we do for RSA, because DSS signatures # are usually different each time. but we can test verification # anyway so it's ok. self.assertEqual(40, len(msg.get_binary())) msg.rewind() pub = DSSKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_A_generate_rsa(self): key = RSAKey.generate(1024) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) def test_B_generate_dss(self): key = DSSKey.generate(1024) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) def test_C_generate_ecdsa(self): key = ECDSAKey.generate() - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 256) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256") key = ECDSAKey.generate(bits=256) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 256) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256") key = ECDSAKey.generate(bits=384) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 384) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp384') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp384") key = ECDSAKey.generate(bits=521) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 521) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp521') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp521") def test_10_load_ecdsa_256(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) - self.assertEqual('ecdsa-sha2-nistp256', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) + self.assertEqual("ecdsa-sha2-nistp256", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64()) @@ -292,9 +315,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_11_load_ecdsa_password_256(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_256.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp256', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_256.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp256", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64()) @@ -302,7 +327,7 @@ class KeyTest(unittest.TestCase): def test_12_compare_ecdsa_256(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -311,11 +336,11 @@ class KeyTest(unittest.TestCase): def test_13_sign_ecdsa_256(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp256', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp256", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -323,12 +348,12 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_14_load_ecdsa_384(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) - self.assertEqual('ecdsa-sha2-nistp384', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) + self.assertEqual("ecdsa-sha2-nistp384", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64()) @@ -342,9 +367,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_15_load_ecdsa_password_384(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_384.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp384', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_384.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp384", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64()) @@ -352,7 +379,7 @@ class KeyTest(unittest.TestCase): def test_16_compare_ecdsa_384(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -361,11 +388,11 @@ class KeyTest(unittest.TestCase): def test_17_sign_ecdsa_384(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp384', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp384", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -373,12 +400,12 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_18_load_ecdsa_521(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) - self.assertEqual('ecdsa-sha2-nistp521', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) + self.assertEqual("ecdsa-sha2-nistp521", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64()) @@ -395,9 +422,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_19_load_ecdsa_password_521(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_521.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp521', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_521.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp521", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64()) @@ -405,7 +434,7 @@ class KeyTest(unittest.TestCase): def test_20_compare_ecdsa_521(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -414,11 +443,11 @@ class KeyTest(unittest.TestCase): def test_21_sign_ecdsa_521(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp521', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp521", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -426,14 +455,14 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_salt_size(self): # Read an existing encrypted private key - file_ = _support('test_rsa_password.key') - password = 'television' - newfile = file_ + '.new' - newpassword = 'radio' + file_ = _support("test_rsa_password.key") + password = "television" + newfile = file_ + ".new" + newpassword = "radio" key = RSAKey(filename=file_, password=password) # Write out a newly re-encrypted copy with a new password. # When the bug under test exists, this will ValueError. @@ -447,20 +476,20 @@ class KeyTest(unittest.TestCase): os.remove(newfile) def test_stringification(self): - key = RSAKey.from_private_key_file(_support('test_rsa.key')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) comparable = TEST_KEY_BYTESTR_2 if PY2 else TEST_KEY_BYTESTR_3 self.assertEqual(str(key), comparable) def test_ed25519(self): - key1 = Ed25519Key.from_private_key_file(_support('test_ed25519.key')) + key1 = Ed25519Key.from_private_key_file(_support("test_ed25519.key")) key2 = Ed25519Key.from_private_key_file( - _support('test_ed25519_password.key'), b'abc123' + _support("test_ed25519_password.key"), b"abc123" ) self.assertNotEqual(key1.asbytes(), key2.asbytes()) def test_ed25519_compare(self): # verify that the private & public keys compare equal - key = Ed25519Key.from_private_key_file(_support('test_ed25519.key')) + key = Ed25519Key.from_private_key_file(_support("test_ed25519.key")) self.assertEqual(key, key) pub = Ed25519Key(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -470,25 +499,25 @@ class KeyTest(unittest.TestCase): def test_ed25519_nonbytes_password(self): # https://github.com/paramiko/paramiko/issues/1039 key = Ed25519Key.from_private_key_file( - _support('test_ed25519_password.key'), + _support("test_ed25519_password.key"), # NOTE: not a bytes. Amusingly, the test above for same key DOES # explicitly cast to bytes...code smell! - 'abc123', + "abc123", ) # No exception -> it's good. Meh. def test_ed25519_load_from_file_obj(self): - with open(_support('test_ed25519.key')) as pkey_fileobj: + with open(_support("test_ed25519.key")) as pkey_fileobj: key = Ed25519Key.from_private_key(pkey_fileobj) self.assertEqual(key, key) self.assertTrue(key.can_sign()) def test_keyfile_is_actually_encrypted(self): # Read an existing encrypted private key - file_ = _support('test_rsa_password.key') - password = 'television' - newfile = file_ + '.new' - newpassword = 'radio' + file_ = _support("test_rsa_password.key") + password = "television" + newfile = file_ + ".new" + newpassword = "radio" key = RSAKey(filename=file_, password=password) # Write out a newly re-encrypted copy with a new password. # When the bug under test exists, this will ValueError. @@ -503,19 +532,21 @@ class KeyTest(unittest.TestCase): # test_client.py; this and nearby cert tests are more about the gritty # details. # PKey.load_certificate - key_path = _support(os.path.join('cert_support', 'test_rsa.key')) + key_path = _support(os.path.join("cert_support", "test_rsa.key")) key = RSAKey.from_private_key_file(key_path) self.assertTrue(key.public_blob is None) cert_path = _support( - os.path.join('cert_support', 'test_rsa.key-cert.pub') + os.path.join("cert_support", "test_rsa.key-cert.pub") ) key.load_certificate(cert_path) self.assertTrue(key.public_blob is not None) - self.assertEqual(key.public_blob.key_type, 'ssh-rsa-cert-v01@openssh.com') - self.assertEqual(key.public_blob.comment, 'test_rsa.key.pub') + self.assertEqual( + key.public_blob.key_type, "ssh-rsa-cert-v01@openssh.com" + ) + self.assertEqual(key.public_blob.comment, "test_rsa.key.pub") # Delve into blob contents, for test purposes msg = Message(key.public_blob.key_blob) - self.assertEqual(msg.get_text(), 'ssh-rsa-cert-v01@openssh.com') + self.assertEqual(msg.get_text(), "ssh-rsa-cert-v01@openssh.com") nonce = msg.get_string() e = msg.get_mpint() n = msg.get_mpint() @@ -525,10 +556,10 @@ class KeyTest(unittest.TestCase): self.assertEqual(msg.get_int64(), 1234) # Prevented from loading certificate that doesn't match - key_path = _support(os.path.join('cert_support', 'test_ed25519.key')) + key_path = _support(os.path.join("cert_support", "test_ed25519.key")) key1 = Ed25519Key.from_private_key_file(key_path) self.assertRaises( ValueError, key1.load_certificate, - _support('test_rsa.key-cert.pub'), + _support("test_rsa.key-cert.pub"), ) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 09a50453..576b69b7 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -45,7 +45,7 @@ from .stub_sftp import StubServer, StubSFTPServer from .util import _support, slow -ARTICLE = ''' +ARTICLE = """ Insulin sensitivity and liver insulin receptor structure in ducks from two genera @@ -70,7 +70,7 @@ receptors. Therefore the ducks from the two genera exhibit an alpha-beta- structure for liver insulin receptors and a clear difference in the number of liver insulin receptors. Their sensitivity to insulin is, however, similarly decreased compared with chicken. -''' +""" # Here is how unicode characters are encoded over 1 to 6 bytes in utf-8 @@ -82,32 +82,33 @@ decreased compared with chicken. # U-04000000 - U-7FFFFFFF: 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx # Note that: hex(int('11000011',2)) == '0xc3' # Thus, the following 2-bytes sequence is not valid utf8: "invalid continuation byte" -NON_UTF8_DATA = b'\xC3\xC3' +NON_UTF8_DATA = b"\xC3\xC3" -unicode_folder = u'\u00fcnic\u00f8de' if PY2 else '\u00fcnic\u00f8de' -utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65' +unicode_folder = u"\u00fcnic\u00f8de" if PY2 else "\u00fcnic\u00f8de" +utf8_folder = b"/\xc3\xbcnic\xc3\xb8\x64\x65" @slow class TestSFTP(object): + def test_1_file(self, sftp): """ verify that we can create a file. """ - f = sftp.open(sftp.FOLDER + '/test', 'w') + f = sftp.open(sftp.FOLDER + "/test", "w") try: assert f.stat().st_size == 0 finally: f.close() - sftp.remove(sftp.FOLDER + '/test') + sftp.remove(sftp.FOLDER + "/test") def test_2_close(self, sftp): """ Verify that SFTP session close() causes a socket error on next action. """ sftp.close() - with pytest.raises(socket.error, match='Socket is closed'): - sftp.open(sftp.FOLDER + '/test2', 'w') + with pytest.raises(socket.error, match="Socket is closed"): + sftp.open(sftp.FOLDER + "/test2", "w") def test_2_sftp_can_be_used_as_context_manager(self, sftp): """ @@ -115,117 +116,117 @@ class TestSFTP(object): """ with sftp: pass - with pytest.raises(socket.error, match='Socket is closed'): - sftp.open(sftp.FOLDER + '/test2', 'w') + with pytest.raises(socket.error, match="Socket is closed"): + sftp.open(sftp.FOLDER + "/test2", "w") def test_3_write(self, sftp): """ verify that a file can be created and written, and the size is correct. """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483 + assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483 finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_3_sftp_file_can_be_used_as_context_manager(self, sftp): """ verify that an opened file can be used as a context manager """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483 + assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483 finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_4_append(self, sftp): """ verify that a file can be opened for append, and tell() still works. """ try: - with sftp.open(sftp.FOLDER + '/append.txt', 'w') as f: - f.write('first line\nsecond line\n') + with sftp.open(sftp.FOLDER + "/append.txt", "w") as f: + f.write("first line\nsecond line\n") assert f.tell() == 23 - with sftp.open(sftp.FOLDER + '/append.txt', 'a+') as f: - f.write('third line!!!\n') + with sftp.open(sftp.FOLDER + "/append.txt", "a+") as f: + f.write("third line!!!\n") assert f.tell() == 37 assert f.stat().st_size == 37 f.seek(-26, f.SEEK_CUR) - assert f.readline() == 'second line\n' + assert f.readline() == "second line\n" finally: - sftp.remove(sftp.FOLDER + '/append.txt') + sftp.remove(sftp.FOLDER + "/append.txt") def test_5_rename(self, sftp): """ verify that renaming a file works. """ try: - with sftp.open(sftp.FOLDER + '/first.txt', 'w') as f: - f.write('content!\n') - sftp.rename(sftp.FOLDER + '/first.txt', sftp.FOLDER + '/second.txt') - with pytest.raises(IOError, match='No such file'): - sftp.open(sftp.FOLDER + '/first.txt', 'r') - with sftp.open(sftp.FOLDER + '/second.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/first.txt", "w") as f: + f.write("content!\n") + sftp.rename( + sftp.FOLDER + "/first.txt", sftp.FOLDER + "/second.txt" + ) + with pytest.raises(IOError, match="No such file"): + sftp.open(sftp.FOLDER + "/first.txt", "r") + with sftp.open(sftp.FOLDER + "/second.txt", "r") as f: f.seek(-6, f.SEEK_END) - assert u(f.read(4)) == 'tent' + assert u(f.read(4)) == "tent" finally: # TODO: this is gross, make some sort of 'remove if possible' / 'rm # -f' a-like, jeez try: - sftp.remove(sftp.FOLDER + '/first.txt') + sftp.remove(sftp.FOLDER + "/first.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/second.txt') + sftp.remove(sftp.FOLDER + "/second.txt") except: pass - def test_5a_posix_rename(self, sftp): """Test posix-rename@openssh.com protocol extension.""" try: # first check that the normal rename works as specified - with sftp.open(sftp.FOLDER + '/a', 'w') as f: - f.write('one') - sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') - with sftp.open(sftp.FOLDER + '/a', 'w') as f: - f.write('two') - with pytest.raises(IOError): # actual message seems generic - sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') + with sftp.open(sftp.FOLDER + "/a", "w") as f: + f.write("one") + sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") + with sftp.open(sftp.FOLDER + "/a", "w") as f: + f.write("two") + with pytest.raises(IOError): # actual message seems generic + sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") # now check with the posix_rename - sftp.posix_rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') - with sftp.open(sftp.FOLDER + '/b', 'r') as f: + sftp.posix_rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") + with sftp.open(sftp.FOLDER + "/b", "r") as f: data = u(f.read()) err = "Contents of renamed file not the same as original file" - assert 'two' == data, err + assert "two" == data, err finally: try: - sftp.remove(sftp.FOLDER + '/a') + sftp.remove(sftp.FOLDER + "/a") except: pass try: - sftp.remove(sftp.FOLDER + '/b') + sftp.remove(sftp.FOLDER + "/b") except: pass - def test_6_folder(self, sftp): """ create a temporary folder, verify that we can create a file in it, then remove the folder and verify that we can't create a file in it anymore. """ - sftp.mkdir(sftp.FOLDER + '/subfolder') - sftp.open(sftp.FOLDER + '/subfolder/test', 'w').close() - sftp.remove(sftp.FOLDER + '/subfolder/test') - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.mkdir(sftp.FOLDER + "/subfolder") + sftp.open(sftp.FOLDER + "/subfolder/test", "w").close() + sftp.remove(sftp.FOLDER + "/subfolder/test") + sftp.rmdir(sftp.FOLDER + "/subfolder") # shouldn't be able to create that file if dir removed with pytest.raises(IOError, match="No such file"): - sftp.open(sftp.FOLDER + '/subfolder/test') + sftp.open(sftp.FOLDER + "/subfolder/test") def test_7_listdir(self, sftp): """ @@ -233,57 +234,57 @@ class TestSFTP(object): it, and those files show up in sftp.listdir. """ try: - sftp.open(sftp.FOLDER + '/duck.txt', 'w').close() - sftp.open(sftp.FOLDER + '/fish.txt', 'w').close() - sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close() + sftp.open(sftp.FOLDER + "/duck.txt", "w").close() + sftp.open(sftp.FOLDER + "/fish.txt", "w").close() + sftp.open(sftp.FOLDER + "/tertiary.py", "w").close() x = sftp.listdir(sftp.FOLDER) assert len(x) == 3 - assert 'duck.txt' in x - assert 'fish.txt' in x - assert 'tertiary.py' in x - assert 'random' not in x + assert "duck.txt" in x + assert "fish.txt" in x + assert "tertiary.py" in x + assert "random" not in x finally: - sftp.remove(sftp.FOLDER + '/duck.txt') - sftp.remove(sftp.FOLDER + '/fish.txt') - sftp.remove(sftp.FOLDER + '/tertiary.py') + sftp.remove(sftp.FOLDER + "/duck.txt") + sftp.remove(sftp.FOLDER + "/fish.txt") + sftp.remove(sftp.FOLDER + "/tertiary.py") def test_7_5_listdir_iter(self, sftp): """ listdir_iter version of above test """ try: - sftp.open(sftp.FOLDER + '/duck.txt', 'w').close() - sftp.open(sftp.FOLDER + '/fish.txt', 'w').close() - sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close() + sftp.open(sftp.FOLDER + "/duck.txt", "w").close() + sftp.open(sftp.FOLDER + "/fish.txt", "w").close() + sftp.open(sftp.FOLDER + "/tertiary.py", "w").close() x = [x.filename for x in sftp.listdir_iter(sftp.FOLDER)] assert len(x) == 3 - assert 'duck.txt' in x - assert 'fish.txt' in x - assert 'tertiary.py' in x - assert 'random' not in x + assert "duck.txt" in x + assert "fish.txt" in x + assert "tertiary.py" in x + assert "random" not in x finally: - sftp.remove(sftp.FOLDER + '/duck.txt') - sftp.remove(sftp.FOLDER + '/fish.txt') - sftp.remove(sftp.FOLDER + '/tertiary.py') + sftp.remove(sftp.FOLDER + "/duck.txt") + sftp.remove(sftp.FOLDER + "/fish.txt") + sftp.remove(sftp.FOLDER + "/tertiary.py") def test_8_setstat(self, sftp): """ verify that the setstat functions (chown, chmod, utime, truncate) work. """ try: - with sftp.open(sftp.FOLDER + '/special', 'w') as f: - f.write('x' * 1024) + with sftp.open(sftp.FOLDER + "/special", "w") as f: + f.write("x" * 1024) - stat = sftp.stat(sftp.FOLDER + '/special') - sftp.chmod(sftp.FOLDER + '/special', (stat.st_mode & ~o777) | o600) - stat = sftp.stat(sftp.FOLDER + '/special') + stat = sftp.stat(sftp.FOLDER + "/special") + sftp.chmod(sftp.FOLDER + "/special", (stat.st_mode & ~o777) | o600) + stat = sftp.stat(sftp.FOLDER + "/special") expected_mode = o600 - if sys.platform == 'win32': + if sys.platform == "win32": # chmod not really functional on windows expected_mode = o666 - if sys.platform == 'cygwin': + if sys.platform == "cygwin": # even worse. expected_mode = o644 assert stat.st_mode & o777 == expected_mode @@ -291,19 +292,19 @@ class TestSFTP(object): mtime = stat.st_mtime - 3600 atime = stat.st_atime - 1800 - sftp.utime(sftp.FOLDER + '/special', (atime, mtime)) - stat = sftp.stat(sftp.FOLDER + '/special') + sftp.utime(sftp.FOLDER + "/special", (atime, mtime)) + stat = sftp.stat(sftp.FOLDER + "/special") assert stat.st_mtime == mtime - if sys.platform not in ('win32', 'cygwin'): + if sys.platform not in ("win32", "cygwin"): assert stat.st_atime == atime # can't really test chown, since we'd have to know a valid uid. - sftp.truncate(sftp.FOLDER + '/special', 512) - stat = sftp.stat(sftp.FOLDER + '/special') + sftp.truncate(sftp.FOLDER + "/special", 512) + stat = sftp.stat(sftp.FOLDER + "/special") assert stat.st_size == 512 finally: - sftp.remove(sftp.FOLDER + '/special') + sftp.remove(sftp.FOLDER + "/special") def test_9_fsetstat(self, sftp): """ @@ -311,19 +312,19 @@ class TestSFTP(object): work on open files. """ try: - with sftp.open(sftp.FOLDER + '/special', 'w') as f: - f.write('x' * 1024) + with sftp.open(sftp.FOLDER + "/special", "w") as f: + f.write("x" * 1024) - with sftp.open(sftp.FOLDER + '/special', 'r+') as f: + with sftp.open(sftp.FOLDER + "/special", "r+") as f: stat = f.stat() f.chmod((stat.st_mode & ~o777) | o600) stat = f.stat() expected_mode = o600 - if sys.platform == 'win32': + if sys.platform == "win32": # chmod not really functional on windows expected_mode = o666 - if sys.platform == 'cygwin': + if sys.platform == "cygwin": # even worse. expected_mode = o644 assert stat.st_mode & o777 == expected_mode @@ -334,7 +335,7 @@ class TestSFTP(object): f.utime((atime, mtime)) stat = f.stat() assert stat.st_mtime == mtime - if sys.platform not in ('win32', 'cygwin'): + if sys.platform not in ("win32", "cygwin"): assert stat.st_atime == atime # can't really test chown, since we'd have to know a valid uid. @@ -343,7 +344,7 @@ class TestSFTP(object): stat = f.stat() assert stat.st_size == 512 finally: - sftp.remove(sftp.FOLDER + '/special') + sftp.remove(sftp.FOLDER + "/special") def test_A_readline_seek(self, sftp): """ @@ -353,10 +354,10 @@ class TestSFTP(object): buffering is reset on 'seek'. """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - with sftp.open(sftp.FOLDER + '/duck.txt', 'r+') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "r+") as f: line_number = 0 loc = 0 pos_list = [] @@ -366,13 +367,16 @@ class TestSFTP(object): loc = f.tell() assert f.seekable() f.seek(pos_list[6], f.SEEK_SET) - assert f.readline(), 'Nouzilly == France.\n' + assert f.readline(), "Nouzilly == France.\n" f.seek(pos_list[17], f.SEEK_SET) - assert f.readline()[:4] == 'duck' + assert f.readline()[:4] == "duck" f.seek(pos_list[10], f.SEEK_SET) - assert f.readline() == 'duck types were equally resistant to exogenous insulin compared with chicken.\n' + assert ( + f.readline() + == "duck types were equally resistant to exogenous insulin compared with chicken.\n" + ) finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_B_write_seek(self, sftp): """ @@ -380,17 +384,17 @@ class TestSFTP(object): changes worked. """ try: - with sftp.open(sftp.FOLDER + '/testing.txt', 'w') as f: - f.write('hello kitty.\n') + with sftp.open(sftp.FOLDER + "/testing.txt", "w") as f: + f.write("hello kitty.\n") f.seek(-5, f.SEEK_CUR) - f.write('dd') + f.write("dd") - assert sftp.stat(sftp.FOLDER + '/testing.txt').st_size == 13 - with sftp.open(sftp.FOLDER + '/testing.txt', 'r') as f: + assert sftp.stat(sftp.FOLDER + "/testing.txt").st_size == 13 + with sftp.open(sftp.FOLDER + "/testing.txt", "r") as f: data = f.read(20) - assert data == b'hello kiddy.\n' + assert data == b"hello kiddy.\n" finally: - sftp.remove(sftp.FOLDER + '/testing.txt') + sftp.remove(sftp.FOLDER + "/testing.txt") def test_C_symlink(self, sftp): """ @@ -401,39 +405,41 @@ class TestSFTP(object): return try: - with sftp.open(sftp.FOLDER + '/original.txt', 'w') as f: - f.write('original\n') - sftp.symlink('original.txt', sftp.FOLDER + '/link.txt') - assert sftp.readlink(sftp.FOLDER + '/link.txt') == 'original.txt' + with sftp.open(sftp.FOLDER + "/original.txt", "w") as f: + f.write("original\n") + sftp.symlink("original.txt", sftp.FOLDER + "/link.txt") + assert sftp.readlink(sftp.FOLDER + "/link.txt") == "original.txt" - with sftp.open(sftp.FOLDER + '/link.txt', 'r') as f: - assert f.readlines() == ['original\n'] + with sftp.open(sftp.FOLDER + "/link.txt", "r") as f: + assert f.readlines() == ["original\n"] - cwd = sftp.normalize('.') - if cwd[-1] == '/': + cwd = sftp.normalize(".") + if cwd[-1] == "/": cwd = cwd[:-1] - abs_path = cwd + '/' + sftp.FOLDER + '/original.txt' - sftp.symlink(abs_path, sftp.FOLDER + '/link2.txt') - assert abs_path == sftp.readlink(sftp.FOLDER + '/link2.txt') + abs_path = cwd + "/" + sftp.FOLDER + "/original.txt" + sftp.symlink(abs_path, sftp.FOLDER + "/link2.txt") + assert abs_path == sftp.readlink(sftp.FOLDER + "/link2.txt") - assert sftp.lstat(sftp.FOLDER + '/link.txt').st_size == 12 - assert sftp.stat(sftp.FOLDER + '/link.txt').st_size == 9 + assert sftp.lstat(sftp.FOLDER + "/link.txt").st_size == 12 + assert sftp.stat(sftp.FOLDER + "/link.txt").st_size == 9 # the sftp server may be hiding extra path members from us, so the # length may be longer than we expect: - assert sftp.lstat(sftp.FOLDER + '/link2.txt').st_size >= len(abs_path) - assert sftp.stat(sftp.FOLDER + '/link2.txt').st_size == 9 - assert sftp.stat(sftp.FOLDER + '/original.txt').st_size == 9 + assert sftp.lstat(sftp.FOLDER + "/link2.txt").st_size >= len( + abs_path + ) + assert sftp.stat(sftp.FOLDER + "/link2.txt").st_size == 9 + assert sftp.stat(sftp.FOLDER + "/original.txt").st_size == 9 finally: try: - sftp.remove(sftp.FOLDER + '/link.txt') + sftp.remove(sftp.FOLDER + "/link.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/link2.txt') + sftp.remove(sftp.FOLDER + "/link2.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/original.txt') + sftp.remove(sftp.FOLDER + "/original.txt") except: pass @@ -442,18 +448,18 @@ class TestSFTP(object): verify that buffered writes are automatically flushed on seek. """ try: - with sftp.open(sftp.FOLDER + '/happy.txt', 'w', 1) as f: - f.write('full line.\n') - f.write('partial') + with sftp.open(sftp.FOLDER + "/happy.txt", "w", 1) as f: + f.write("full line.\n") + f.write("partial") f.seek(9, f.SEEK_SET) - f.write('?\n') + f.write("?\n") - with sftp.open(sftp.FOLDER + '/happy.txt', 'r') as f: - assert f.readline() == u('full line?\n') - assert f.read(7) == b'partial' + with sftp.open(sftp.FOLDER + "/happy.txt", "r") as f: + assert f.readline() == u("full line?\n") + assert f.read(7) == b"partial" finally: try: - sftp.remove(sftp.FOLDER + '/happy.txt') + sftp.remove(sftp.FOLDER + "/happy.txt") except: pass @@ -462,9 +468,9 @@ class TestSFTP(object): test that realpath is returning something non-empty and not an error. """ - pwd = sftp.normalize('.') + pwd = sftp.normalize(".") assert len(pwd) > 0 - f = sftp.normalize('./' + sftp.FOLDER) + f = sftp.normalize("./" + sftp.FOLDER) assert len(f) > 0 assert os.path.join(pwd, sftp.FOLDER) == f @@ -472,46 +478,46 @@ class TestSFTP(object): """ verify that mkdir/rmdir work. """ - sftp.mkdir(sftp.FOLDER + '/subfolder') - with pytest.raises(IOError): # generic msg only - sftp.mkdir(sftp.FOLDER + '/subfolder') - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.mkdir(sftp.FOLDER + "/subfolder") + with pytest.raises(IOError): # generic msg only + sftp.mkdir(sftp.FOLDER + "/subfolder") + sftp.rmdir(sftp.FOLDER + "/subfolder") with pytest.raises(IOError, match="No such file"): - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.rmdir(sftp.FOLDER + "/subfolder") def test_G_chdir(self, sftp): """ verify that chdir/getcwd work. """ - root = sftp.normalize('.') - if root[-1] != '/': - root += '/' + root = sftp.normalize(".") + if root[-1] != "/": + root += "/" try: - sftp.mkdir(sftp.FOLDER + '/alpha') - sftp.chdir(sftp.FOLDER + '/alpha') - sftp.mkdir('beta') - assert root + sftp.FOLDER + '/alpha' == sftp.getcwd() - assert ['beta'] == sftp.listdir('.') - - sftp.chdir('beta') - with sftp.open('fish', 'w') as f: - f.write('hello\n') - sftp.chdir('..') - assert ['fish'] == sftp.listdir('beta') - sftp.chdir('..') - assert ['fish'] == sftp.listdir('alpha/beta') + sftp.mkdir(sftp.FOLDER + "/alpha") + sftp.chdir(sftp.FOLDER + "/alpha") + sftp.mkdir("beta") + assert root + sftp.FOLDER + "/alpha" == sftp.getcwd() + assert ["beta"] == sftp.listdir(".") + + sftp.chdir("beta") + with sftp.open("fish", "w") as f: + f.write("hello\n") + sftp.chdir("..") + assert ["fish"] == sftp.listdir("beta") + sftp.chdir("..") + assert ["fish"] == sftp.listdir("alpha/beta") finally: sftp.chdir(root) try: - sftp.unlink(sftp.FOLDER + '/alpha/beta/fish') + sftp.unlink(sftp.FOLDER + "/alpha/beta/fish") except: pass try: - sftp.rmdir(sftp.FOLDER + '/alpha/beta') + sftp.rmdir(sftp.FOLDER + "/alpha/beta") except: pass try: - sftp.rmdir(sftp.FOLDER + '/alpha') + sftp.rmdir(sftp.FOLDER + "/alpha") except: pass @@ -519,20 +525,21 @@ class TestSFTP(object): """ verify that get/put work. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") fd, localname = mkstemp() os.close(fd) - text = b'All I wanted was a plastic bunny rabbit.\n' - with open(localname, 'wb') as f: + text = b"All I wanted was a plastic bunny rabbit.\n" + with open(localname, "wb") as f: f.write(text) saved_progress = [] def progress_callback(x, y): saved_progress.append((x, y)) - sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback) - with sftp.open(sftp.FOLDER + '/bunny.txt', 'rb') as f: + sftp.put(localname, sftp.FOLDER + "/bunny.txt", progress_callback) + + with sftp.open(sftp.FOLDER + "/bunny.txt", "rb") as f: assert text == f.read(128) assert [(41, 41)] == saved_progress @@ -540,14 +547,14 @@ class TestSFTP(object): fd, localname = mkstemp() os.close(fd) saved_progress = [] - sftp.get(sftp.FOLDER + '/bunny.txt', localname, progress_callback) + sftp.get(sftp.FOLDER + "/bunny.txt", localname, progress_callback) - with open(localname, 'rb') as f: + with open(localname, "rb") as f: assert text == f.read(128) assert [(41, 41)] == saved_progress os.unlink(localname) - sftp.unlink(sftp.FOLDER + '/bunny.txt') + sftp.unlink(sftp.FOLDER + "/bunny.txt") def test_I_check(self, sftp): """ @@ -555,118 +562,132 @@ class TestSFTP(object): (it's an sftp extension that we support, and may be the only ones who support it.) """ - with sftp.open(sftp.FOLDER + '/kitty.txt', 'w') as f: - f.write('here kitty kitty' * 64) + with sftp.open(sftp.FOLDER + "/kitty.txt", "w") as f: + f.write("here kitty kitty" * 64) try: - with sftp.open(sftp.FOLDER + '/kitty.txt', 'r') as f: - sum = f.check('sha1') - assert '91059CFC6615941378D413CB5ADAF4C5EB293402' == u(hexlify(sum)).upper() - sum = f.check('md5', 0, 512) - assert '93DE4788FCA28D471516963A1FE3856A' == u(hexlify(sum)).upper() - sum = f.check('md5', 0, 0, 510) - assert u(hexlify(sum)).upper() == 'EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6' # noqa + with sftp.open(sftp.FOLDER + "/kitty.txt", "r") as f: + sum = f.check("sha1") + assert ( + "91059CFC6615941378D413CB5ADAF4C5EB293402" + == u(hexlify(sum)).upper() + ) + sum = f.check("md5", 0, 512) + assert ( + "93DE4788FCA28D471516963A1FE3856A" + == u(hexlify(sum)).upper() + ) + sum = f.check("md5", 0, 0, 510) + assert ( + u(hexlify(sum)).upper() + == "EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6" + ) # noqa finally: - sftp.unlink(sftp.FOLDER + '/kitty.txt') + sftp.unlink(sftp.FOLDER + "/kitty.txt") def test_J_x_flag(self, sftp): """ verify that the 'x' flag works when opening a file. """ - sftp.open(sftp.FOLDER + '/unusual.txt', 'wx').close() + sftp.open(sftp.FOLDER + "/unusual.txt", "wx").close() try: try: - sftp.open(sftp.FOLDER + '/unusual.txt', 'wx') - self.fail('expected exception') + sftp.open(sftp.FOLDER + "/unusual.txt", "wx") + self.fail("expected exception") except IOError: pass finally: - sftp.unlink(sftp.FOLDER + '/unusual.txt') + sftp.unlink(sftp.FOLDER + "/unusual.txt") def test_K_utf8(self, sftp): """ verify that unicode strings are encoded into utf8 correctly. """ - with sftp.open(sftp.FOLDER + '/something', 'w') as f: - f.write('okay') + with sftp.open(sftp.FOLDER + "/something", "w") as f: + f.write("okay") try: - sftp.rename(sftp.FOLDER + '/something', sftp.FOLDER + '/' + unicode_folder) - sftp.open(b(sftp.FOLDER) + utf8_folder, 'r') + sftp.rename( + sftp.FOLDER + "/something", sftp.FOLDER + "/" + unicode_folder + ) + sftp.open(b(sftp.FOLDER) + utf8_folder, "r") except Exception as e: - self.fail('exception ' + str(e)) + self.fail("exception " + str(e)) sftp.unlink(b(sftp.FOLDER) + utf8_folder) def test_L_utf8_chdir(self, sftp): - sftp.mkdir(sftp.FOLDER + '/' + unicode_folder) + sftp.mkdir(sftp.FOLDER + "/" + unicode_folder) try: - sftp.chdir(sftp.FOLDER + '/' + unicode_folder) - with sftp.open('something', 'w') as f: - f.write('okay') - sftp.unlink('something') + sftp.chdir(sftp.FOLDER + "/" + unicode_folder) + with sftp.open("something", "w") as f: + f.write("okay") + sftp.unlink("something") finally: sftp.chdir() - sftp.rmdir(sftp.FOLDER + '/' + unicode_folder) + sftp.rmdir(sftp.FOLDER + "/" + unicode_folder) def test_M_bad_readv(self, sftp): """ verify that readv at the end of the file doesn't essplode. """ - sftp.open(sftp.FOLDER + '/zero', 'w').close() + sftp.open(sftp.FOLDER + "/zero", "w").close() try: - with sftp.open(sftp.FOLDER + '/zero', 'r') as f: + with sftp.open(sftp.FOLDER + "/zero", "r") as f: f.readv([(0, 12)]) - with sftp.open(sftp.FOLDER + '/zero', 'r') as f: + with sftp.open(sftp.FOLDER + "/zero", "r") as f: file_size = f.stat().st_size f.prefetch(file_size) f.read(100) finally: - sftp.unlink(sftp.FOLDER + '/zero') + sftp.unlink(sftp.FOLDER + "/zero") def test_N_put_without_confirm(self, sftp): """ verify that get/put work without confirmation. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") fd, localname = mkstemp() os.close(fd) - text = b'All I wanted was a plastic bunny rabbit.\n' - with open(localname, 'wb') as f: + text = b"All I wanted was a plastic bunny rabbit.\n" + with open(localname, "wb") as f: f.write(text) saved_progress = [] def progress_callback(x, y): saved_progress.append((x, y)) - res = sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback, False) + + res = sftp.put( + localname, sftp.FOLDER + "/bunny.txt", progress_callback, False + ) assert SFTPAttributes().attr == res.attr - with sftp.open(sftp.FOLDER + '/bunny.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/bunny.txt", "r") as f: assert text == f.read(128) assert (41, 41) == saved_progress[-1] os.unlink(localname) - sftp.unlink(sftp.FOLDER + '/bunny.txt') + sftp.unlink(sftp.FOLDER + "/bunny.txt") def test_O_getcwd(self, sftp): """ verify that chdir/getcwd work. """ assert sftp.getcwd() == None - root = sftp.normalize('.') - if root[-1] != '/': - root += '/' + root = sftp.normalize(".") + if root[-1] != "/": + root += "/" try: - sftp.mkdir(sftp.FOLDER + '/alpha') - sftp.chdir(sftp.FOLDER + '/alpha') - assert sftp.getcwd() == '/' + sftp.FOLDER + '/alpha' + sftp.mkdir(sftp.FOLDER + "/alpha") + sftp.chdir(sftp.FOLDER + "/alpha") + assert sftp.getcwd() == "/" + sftp.FOLDER + "/alpha" finally: sftp.chdir(root) try: - sftp.rmdir(sftp.FOLDER + '/alpha') + sftp.rmdir(sftp.FOLDER + "/alpha") except: pass @@ -677,24 +698,24 @@ class TestSFTP(object): does not work except through paramiko. :( openssh fails. """ try: - with sftp.open(sftp.FOLDER + '/append.txt', 'a') as f: - f.write('first line\nsecond line\n') + with sftp.open(sftp.FOLDER + "/append.txt", "a") as f: + f.write("first line\nsecond line\n") f.seek(11, f.SEEK_SET) - f.write('third line\n') + f.write("third line\n") - with sftp.open(sftp.FOLDER + '/append.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/append.txt", "r") as f: assert f.stat().st_size == 34 - assert f.readline() == 'first line\n' - assert f.readline() == 'second line\n' - assert f.readline() == 'third line\n' + assert f.readline() == "first line\n" + assert f.readline() == "second line\n" + assert f.readline() == "third line\n" finally: - sftp.remove(sftp.FOLDER + '/append.txt') + sftp.remove(sftp.FOLDER + "/append.txt") def test_putfo_empty_file(self, sftp): """ Send an empty file and confirm it is sent. """ - target = sftp.FOLDER + '/empty file.txt' + target = sftp.FOLDER + "/empty file.txt" stream = StringIO() try: attrs = sftp.putfo(stream, target) @@ -713,59 +734,61 @@ class TestSFTP(object): verify that we can create a file with a '%' in the filename. ( it needs to be properly escaped by _log() ) """ - f = sftp.open(sftp.FOLDER + '/test%file', 'w') + f = sftp.open(sftp.FOLDER + "/test%file", "w") try: assert f.stat().st_size == 0 finally: f.close() - sftp.remove(sftp.FOLDER + '/test%file') + sftp.remove(sftp.FOLDER + "/test%file") def test_O_non_utf8_data(self, sftp): """Test write() and read() of non utf8 data""" try: - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "w") as f: f.write(NON_UTF8_DATA) - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "r") as f: data = f.read() assert data == NON_UTF8_DATA - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "wb") as f: f.write(NON_UTF8_DATA) - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "rb") as f: data = f.read() assert data == NON_UTF8_DATA finally: - sftp.remove('%s/nonutf8data' % sftp.FOLDER) - + sftp.remove("%s/nonutf8data" % sftp.FOLDER) def test_sftp_attributes_empty_str(self, sftp): sftp_attributes = SFTPAttributes() - assert str(sftp_attributes) == "?--------- 1 0 0 0 (unknown date) ?" + assert ( + str(sftp_attributes) + == "?--------- 1 0 0 0 (unknown date) ?" + ) - @needs_builtin('buffer') + @needs_builtin("buffer") def test_write_buffer(self, sftp): """Test write() using a buffer instance.""" - data = 3 * b'A potentially large block of data to chunk up.\n' + data = 3 * b"A potentially large block of data to chunk up.\n" try: - with sftp.open('%s/write_buffer' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/write_buffer" % sftp.FOLDER, "wb") as f: for offset in range(0, len(data), 8): f.write(buffer(data, offset, 8)) - with sftp.open('%s/write_buffer' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/write_buffer" % sftp.FOLDER, "rb") as f: assert f.read() == data finally: - sftp.remove('%s/write_buffer' % sftp.FOLDER) + sftp.remove("%s/write_buffer" % sftp.FOLDER) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_memoryview(self, sftp): """Test write() using a memoryview instance.""" - data = 3 * b'A potentially large block of data to chunk up.\n' + data = 3 * b"A potentially large block of data to chunk up.\n" try: - with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/write_memoryview" % sftp.FOLDER, "wb") as f: view = memoryview(data) for offset in range(0, len(data), 8): - f.write(view[offset:offset+8]) + f.write(view[offset : offset + 8]) - with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/write_memoryview" % sftp.FOLDER, "rb") as f: assert f.read() == data finally: - sftp.remove('%s/write_memoryview' % sftp.FOLDER) + sftp.remove("%s/write_memoryview" % sftp.FOLDER) diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py index a659098d..97c0eb90 100644 --- a/tests/test_sftp_big.py +++ b/tests/test_sftp_big.py @@ -37,6 +37,7 @@ from .util import slow @slow class TestBigSFTP(object): + def test_1_lots_of_files(self, sftp): """ create a bunch of files over the same session. @@ -44,22 +45,24 @@ class TestBigSFTP(object): numfiles = 100 try: for i in range(numfiles): - with sftp.open('%s/file%d.txt' % (sftp.FOLDER, i), 'w', 1) as f: - f.write('this is file #%d.\n' % i) - sftp.chmod('%s/file%d.txt' % (sftp.FOLDER, i), o660) + with sftp.open( + "%s/file%d.txt" % (sftp.FOLDER, i), "w", 1 + ) as f: + f.write("this is file #%d.\n" % i) + sftp.chmod("%s/file%d.txt" % (sftp.FOLDER, i), o660) # now make sure every file is there, by creating a list of filenmes # and reading them in random order. numlist = list(range(numfiles)) while len(numlist) > 0: r = numlist[random.randint(0, len(numlist) - 1)] - with sftp.open('%s/file%d.txt' % (sftp.FOLDER, r)) as f: - assert f.readline() == 'this is file #%d.\n' % r + with sftp.open("%s/file%d.txt" % (sftp.FOLDER, r)) as f: + assert f.readline() == "this is file #%d.\n" % r numlist.remove(r) finally: for i in range(numfiles): try: - sftp.remove('%s/file%d.txt' % (sftp.FOLDER, i)) + sftp.remove("%s/file%d.txt" % (sftp.FOLDER, i)) except: pass @@ -67,52 +70,56 @@ class TestBigSFTP(object): """ write a 1MB file with no buffering. """ - kblob = (1024 * b'x') + kblob = 1024 * b"x" start = time.time() try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f: for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - + sys.stderr.write("%ds " % round(end - start)) + start = time.time() - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: for n in range(1024): data = f.read(1024) assert data == kblob end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_3_big_file_pipelined(self, sftp): """ write a 1MB file, with no linefeeds, using pipelining. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) start = time.time() try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - + sys.stderr.write("%ds " % round(end - start)) + start = time.time() - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) @@ -126,35 +133,39 @@ class TestBigSFTP(object): chunk = size - n data = f.read(chunk) offset = n % 1024 - assert data == k2blob[offset:offset + chunk] + assert data == k2blob[offset : offset + chunk] n += chunk end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_4_prefetch_seek(self, sftp): - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') - - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 - + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) + start = time.time() k2blob = kblob + kblob chunk = 793 for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) - base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + base_offset = (512 * 1024) + 17 * random.randint( + 1000, 2000 + ) offsets = [base_offset + j * chunk for j in range(100)] # randomly seek around and read them out for j in range(100): @@ -163,32 +174,36 @@ class TestBigSFTP(object): f.seek(offset) data = f.read(chunk) n_offset = offset % 1024 - assert data == k2blob[n_offset:n_offset + chunk] + assert data == k2blob[n_offset : n_offset + chunk] offset += chunk end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_5_readv_seek(self, sftp): - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) start = time.time() k2blob = kblob + kblob chunk = 793 for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: - base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: + base_offset = (512 * 1024) + 17 * random.randint( + 1000, 2000 + ) # make a bunch of offsets and put them in random order offsets = [base_offset + j * chunk for j in range(100)] readv_list = [] @@ -200,62 +215,66 @@ class TestBigSFTP(object): for i in range(len(readv_list)): offset = readv_list[i][0] n_offset = offset % 1024 - assert next(ret) == k2blob[n_offset:n_offset + chunk] + assert next(ret) == k2blob[n_offset : n_offset + chunk] end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_6_lots_of_prefetching(self, sftp): """ prefetch a 1MB file a bunch of times, discarding the file object without using it, to verify that paramiko doesn't get confused. """ - kblob = (1024 * b'x') + kblob = 1024 * b"x" try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: file_size = f.stat().st_size f.prefetch(file_size) - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: file_size = f.stat().st_size f.prefetch(file_size) for n in range(1024): data = f.read(1024) assert data == kblob if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_7_prefetch_readv(self, sftp): """ verify that prefetch and readv don't conflict with each other. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') - - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) data = f.read(1024) @@ -264,79 +283,94 @@ class TestBigSFTP(object): chunk_size = 793 base_offset = 512 * 1024 k2blob = kblob + kblob - chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)] + chunks = [ + (base_offset + (chunk_size * i), chunk_size) + for i in range(20) + ] for data in f.readv(chunks): offset = base_offset % 1024 assert chunk_size == len(data) - assert k2blob[offset:offset + chunk_size] == data + assert k2blob[offset : offset + chunk_size] == data base_offset += chunk_size - sys.stderr.write(' ') + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_8_large_readv(self, sftp): """ verify that a very large readv is broken up correctly and still returned as a single blob. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 - - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: data = list(f.readv([(23 * 1024, 128 * 1024)])) assert len(data) == 1 data = data[0] assert len(data) == 128 * 1024 - - sys.stderr.write(' ') + + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_9_big_file_big_buffer(self, sftp): """ write a 1MB file, with no linefeeds, and a big buffer. """ - mblob = (1024 * 1024 * 'x') + mblob = 1024 * 1024 * "x" try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024 + ) as f: f.write(mblob) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_A_big_file_renegotiate(self, sftp): """ write a 1MB file, forcing key renegotiation in the middle. """ t = sftp.sock.get_transport() t.packetizer.REKEY_BYTES = 512 * 1024 - k32blob = (32 * 1024 * 'x') + k32blob = 32 * 1024 * "x" try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024 + ) as f: for i in range(32): f.write(k32blob) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) assert t.H != t.session_id - + # try to read it too. - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "r", 128 * 1024 + ) as f: file_size = f.stat().st_size f.prefetch(file_size) total = 0 while total < 1024 * 1024: total += len(f.read(32 * 1024)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) t.packetizer.REKEY_BYTES = pow(2, 30) diff --git a/tests/test_ssh_exception.py b/tests/test_ssh_exception.py index 18f2a97d..6cc5d06a 100644 --- a/tests/test_ssh_exception.py +++ b/tests/test_ssh_exception.py @@ -4,28 +4,33 @@ import unittest from paramiko.ssh_exception import NoValidConnectionsError -class NoValidConnectionsErrorTest (unittest.TestCase): +class NoValidConnectionsErrorTest(unittest.TestCase): def test_pickling(self): # Regression test for https://github.com/paramiko/paramiko/issues/617 - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()}) + exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()}) new_exc = pickle.loads(pickle.dumps(exc)) self.assertEqual(type(exc), type(new_exc)) self.assertEqual(str(exc), str(new_exc)) self.assertEqual(exc.args, new_exc.args) def test_error_message_for_single_host(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()}) + exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()}) assert "Unable to connect to port 22 on 127.0.0.1" in str(exc) def test_error_message_for_two_hosts(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(), - ('::1', '22'): Exception()}) + exc = NoValidConnectionsError( + {("127.0.0.1", "22"): Exception(), ("::1", "22"): Exception()} + ) assert "Unable to connect to port 22 on 127.0.0.1 or ::1" in str(exc) def test_error_message_for_multiple_hosts(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(), - ('::1', '22'): Exception(), - ('10.0.0.42', '22'): Exception()}) + exc = NoValidConnectionsError( + { + ("127.0.0.1", "22"): Exception(), + ("::1", "22"): Exception(), + ("10.0.0.42", "22"): Exception(), + } + ) exp = "Unable to connect to port 22 on 10.0.0.42, 127.0.0.1 or ::1" assert exp in str(exc) diff --git a/tests/test_ssh_gss.py b/tests/test_ssh_gss.py index f0645e0e..cee6ce89 100644 --- a/tests/test_ssh_gss.py +++ b/tests/test_ssh_gss.py @@ -33,15 +33,13 @@ from .util import _support, needs_gssapi from .test_client import FINGERPRINTS -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): + def get_allowed_auths(self, username): - return 'gssapi-with-mic,publickey' + return "gssapi-with-mic,publickey" def check_auth_gssapi_with_mic( - self, - username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None, + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL @@ -64,13 +62,14 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != "yes": return False return True @needs_gssapi class GSSAuthTest(unittest.TestCase): + def setUp(self): # TODO: username and targ_name should come from os.environ or whatever # the approved pytest method is for runtime-configuring test data. @@ -92,7 +91,7 @@ class GSSAuthTest(unittest.TestCase): def _run(self): self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") self.ts.add_server_key(host_key) server = NullServer() self.ts.start_server(self.event, server) @@ -103,15 +102,22 @@ class GSSAuthTest(unittest.TestCase): The exception is ... no exception yet """ - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), - 'ssh-rsa', public_host_key) - self.tc.connect(hostname=self.addr, port=self.port, username=self.username, gss_host=self.hostname, - gss_auth=True, **kwargs) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) + self.tc.connect( + hostname=self.addr, + port=self.port, + username=self.username, + gss_host=self.hostname, + gss_auth=True, + **kwargs + ) self.event.wait(1.0) self.assert_(self.event.is_set()) @@ -119,17 +125,17 @@ class GSSAuthTest(unittest.TestCase): self.assertEquals(self.username, self.ts.get_username()) self.assertEquals(True, self.ts.is_authenticated()) - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEquals('Hello there.\n', stdout.readline()) - self.assertEquals('', stdout.readline()) - self.assertEquals('This is on stderr.\n', stderr.readline()) - self.assertEquals('', stderr.readline()) + self.assertEquals("Hello there.\n", stdout.readline()) + self.assertEquals("", stdout.readline()) + self.assertEquals("This is on stderr.\n", stderr.readline()) + self.assertEquals("", stderr.readline()) stdin.close() stdout.close() @@ -140,14 +146,17 @@ class GSSAuthTest(unittest.TestCase): Verify that Paramiko can handle SSHv2 GSS-API / SSPI authentication (gssapi-with-mic) in client and server mode. """ - self._test_connection(allow_agent=False, - look_for_keys=False) + self._test_connection(allow_agent=False, look_for_keys=False) def test_2_auth_trickledown(self): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ - self.hostname = "this_host_does_not_exists_and_causes_a_GSSAPI-exception" - self._test_connection(key_filename=[_support('test_rsa.key')], - allow_agent=False, - look_for_keys=False) + self.hostname = ( + "this_host_does_not_exists_and_causes_a_GSSAPI-exception" + ) + self._test_connection( + key_filename=[_support("test_rsa.key")], + allow_agent=False, + look_for_keys=False, + ) diff --git a/tests/test_transport.py b/tests/test_transport.py index 9474acfc..c05d6781 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -32,14 +32,26 @@ from hashlib import sha1 import unittest from paramiko import ( - Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, SSHException, - ChannelException, Packetizer, Channel, + Transport, + SecurityOptions, + ServerInterface, + RSAKey, + DSSKey, + SSHException, + ChannelException, + Packetizer, + Channel, ) from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko.common import ( - MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST, MIN_PACKET_SIZE, MIN_WINDOW_SIZE, - MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, + MSG_KEXINIT, + cMSG_CHANNEL_WINDOW_ADJUST, + MIN_PACKET_SIZE, + MIN_WINDOW_SIZE, + MAX_WINDOW_SIZE, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_PACKET_SIZE, ) from paramiko.py3compat import bytes from paramiko.message import Message @@ -61,28 +73,28 @@ Maybe. """ -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL return AUTH_FAILED def check_channel_request(self, kind, chanid): - if kind == 'bogus': + if kind == "bogus": return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True @@ -95,9 +107,16 @@ class NullServer (ServerInterface): # tho that's only supposed to occur if the request cannot be served. # For now, leaving that the default unless test supplies specific # 'acceptable' request kind - return kind == 'acceptable' - - def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + return kind == "acceptable" + + def check_channel_x11_request( + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): self._x11_single_connection = single_connection self._x11_auth_protocol = auth_protocol self._x11_auth_cookie = auth_cookie @@ -106,7 +125,7 @@ class NullServer (ServerInterface): def check_port_forward_request(self, addr, port): self._listen = socket.socket() - self._listen.bind(('127.0.0.1', 0)) + self._listen.bind(("127.0.0.1", 0)) self._listen.listen(1) return self._listen.getsockname()[1] @@ -120,6 +139,7 @@ class NullServer (ServerInterface): class TransportTest(unittest.TestCase): + def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -134,9 +154,9 @@ class TransportTest(unittest.TestCase): self.sockc.close() def setup_test_server( - self, client_options=None, server_options=None, connect_kwargs=None, + self, client_options=None, server_options=None, connect_kwargs=None ): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) @@ -152,8 +172,8 @@ class TransportTest(unittest.TestCase): if connect_kwargs is None: connect_kwargs = dict( hostkey=public_host_key, - username='slowdive', - password='pygmalion', + username="slowdive", + password="pygmalion", ) self.tc.connect(**connect_kwargs) event.wait(1.0) @@ -163,11 +183,11 @@ class TransportTest(unittest.TestCase): def test_1_security_options(self): o = self.tc.get_security_options() self.assertEqual(type(o), SecurityOptions) - self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers) - o.ciphers = ('aes256-cbc', 'blowfish-cbc') - self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers) + self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers) + o.ciphers = ("aes256-cbc", "blowfish-cbc") + self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers) try: - o.ciphers = ('aes256-cbc', 'made-up-cipher') + o.ciphers = ("aes256-cbc", "made-up-cipher") self.assertTrue(False) except ValueError: pass @@ -187,12 +207,18 @@ class TransportTest(unittest.TestCase): o.compression = o.compression def test_2_compute_key(self): - self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 - self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3' + self.tc.K = ( + 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 + ) + self.tc.H = ( + b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3" + ) self.tc.session_id = self.tc.H - key = self.tc._compute_key('C', 32) - self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', - hexlify(key).upper()) + key = self.tc._compute_key("C", 32) + self.assertEqual( + b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995", + hexlify(key).upper(), + ) def test_3_simple(self): """ @@ -200,7 +226,7 @@ class TransportTest(unittest.TestCase): loopback sockets. this is hardly "simple" but it's simpler than the later tests. :) """ - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -211,13 +237,14 @@ class TransportTest(unittest.TestCase): self.assertEqual(False, self.tc.is_authenticated()) self.assertEqual(False, self.ts.is_authenticated()) self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.tc.get_username()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.tc.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.tc.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated()) @@ -225,7 +252,7 @@ class TransportTest(unittest.TestCase): """ verify that a long banner doesn't mess up the handshake. """ - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -233,8 +260,9 @@ class TransportTest(unittest.TestCase): self.assertTrue(not event.is_set()) self.socks.send(LONG_BANNER) self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) @@ -244,12 +272,14 @@ class TransportTest(unittest.TestCase): verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ + def force_algorithms(options): - options.ciphers = ('aes256-cbc',) - options.digests = ('hmac-md5-96',) + options.ciphers = ("aes256-cbc",) + options.digests = ("hmac-md5-96",) + self.setup_test_server(client_options=force_algorithms) - self.assertEqual('aes256-cbc', self.tc.local_cipher) - self.assertEqual('aes256-cbc', self.tc.remote_cipher) + self.assertEqual("aes256-cbc", self.tc.local_cipher) + self.assertEqual("aes256-cbc", self.tc.remote_cipher) self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) @@ -263,10 +293,10 @@ class TransportTest(unittest.TestCase): verify that the keepalive will be sent. """ self.setup_test_server() - self.assertEqual(None, getattr(self.server, '_global_request', None)) + self.assertEqual(None, getattr(self.server, "_global_request", None)) self.tc.set_keepalive(1) time.sleep(2) - self.assertEqual('keepalive@lag.net', self.server._global_request) + self.assertEqual("keepalive@lag.net", self.server._global_request) def test_6_exec_command(self): """ @@ -277,39 +307,41 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() schan = self.ts.accept(1.0) try: - chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string') + chan.exec_command( + b"command contains \xfc and is not a valid UTF-8 string" + ) self.assertTrue(False) except SSHException: pass chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) f = chan.makefile_stderr() - self.assertEqual('This is on stderr.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) # now try it with combined stdout/stderr chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() chan.set_combine_stderr(True) f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('This is on stderr.\n', f.readline()) - self.assertEqual('', f.readline()) - + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) + def test_6a_channel_can_be_used_as_context_manager(self): """ verify that exec_command() does something reasonable. @@ -318,13 +350,13 @@ class TransportTest(unittest.TestCase): with self.tc.open_session() as chan: with self.ts.accept(1.0) as schan: - chan.exec_command('yes') - schan.send('Hello there.\n') + chan.exec_command("yes") + schan.send("Hello there.\n") schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) def test_7_invoke_shell(self): """ @@ -334,11 +366,11 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) - chan.send('communist j. cat\n') + chan.send("communist j. cat\n") f = schan.makefile() - self.assertEqual('communist j. cat\n', f.readline()) + self.assertEqual("communist j. cat\n", f.readline()) chan.close() - self.assertEqual('', f.readline()) + self.assertEqual("", f.readline()) def test_8_channel_exception(self): """ @@ -346,8 +378,8 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() try: - chan = self.tc.open_channel('bogus') - self.fail('expected exception') + chan = self.tc.open_channel("bogus") + self.fail("expected exception") except ChannelException as e: self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) @@ -359,8 +391,8 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() schan = self.ts.accept(1.0) - chan.exec_command('yes') - schan.send('Hello there.\n') + chan.exec_command("yes") + schan.send("Hello there.\n") self.assertTrue(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() @@ -369,8 +401,8 @@ class TransportTest(unittest.TestCase): schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) count = 0 while not chan.exit_status_ready(): time.sleep(0.1) @@ -395,7 +427,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - schan.send('hello\n') + schan.send("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): @@ -407,7 +439,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - self.assertEqual(b'hello\n', chan.recv(6)) + self.assertEqual(b"hello\n", chan.recv(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) @@ -442,12 +474,12 @@ class TransportTest(unittest.TestCase): self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) self.assertEqual(self.tc.H, self.tc.session_id) for i in range(20): - chan.send('x' * 1024) + chan.send("x" * 1024) chan.close() # allow a few seconds for the rekeying to complete @@ -463,18 +495,20 @@ class TransportTest(unittest.TestCase): """ verify that zlib compression is basically working. """ + def force_compression(o): - o.compression = ('zlib',) + o.compression = ("zlib",) + self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes - chan.send('x' * 1024) + chan.send("x" * 1024) bytes2 = self.tc.packetizer._Packetizer__sent_bytes - block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size'] - mac_size = self.tc._mac_info[self.tc.local_mac]['size'] + block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"] + mac_size = self.tc._mac_info[self.tc.local_mac]["size"] # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) self.assertTrue(bytes2 - bytes < 1024) self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) @@ -488,29 +522,32 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] + def handler(c, addr_port): addr, port = addr_port requested.append((addr, port)) self.tc._queue_incoming_channel(c) - self.assertEqual(None, getattr(self.server, '_x11_screen_number', None)) + self.assertEqual( + None, getattr(self.server, "_x11_screen_number", None) + ) cookie = chan.request_x11(0, single_connection=True, handler=handler) self.assertEqual(0, self.server._x11_screen_number) - self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) + self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol) self.assertEqual(cookie, self.server._x11_auth_cookie) self.assertEqual(True, self.server._x11_single_connection) - x11_server = self.ts.open_x11_channel(('localhost', 6093)) + x11_server = self.ts.open_x11_channel(("localhost", 6093)) x11_client = self.tc.accept() - self.assertEqual('localhost', requested[0][0]) + self.assertEqual("localhost", requested[0][0]) self.assertEqual(6093, requested[0][1]) - x11_server.send('hello') - self.assertEqual(b'hello', x11_client.recv(5)) + x11_server.send("hello") + self.assertEqual(b"hello", x11_client.recv(5)) x11_server.close() x11_client.close() @@ -524,33 +561,36 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] + def handler(c, origin_addr_port, server_addr_port): requested.append(origin_addr_port) requested.append(server_addr_port) self.tc._queue_incoming_channel(c) - port = self.tc.request_port_forward('127.0.0.1', 0, handler) + port = self.tc.request_port_forward("127.0.0.1", 0, handler) self.assertEqual(port, self.server._listen.getsockname()[1]) cs = socket.socket() - cs.connect(('127.0.0.1', port)) + cs.connect(("127.0.0.1", port)) ss, _ = self.server._listen.accept() - sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) + sch = self.ts.open_forwarded_tcpip_channel( + ss.getsockname(), ss.getpeername() + ) cch = self.tc.accept() - sch.send('hello') - self.assertEqual(b'hello', cch.recv(5)) + sch.send("hello") + self.assertEqual(b"hello", cch.recv(5)) sch.close() cch.close() ss.close() cs.close() # now cancel it. - self.tc.cancel_port_forward('127.0.0.1', port) + self.tc.cancel_port_forward("127.0.0.1", port) self.assertTrue(self.server._listen is None) def test_F_port_forwarding(self): @@ -560,27 +600,29 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) # open a port on the "server" that the client will ask to forward to. greeting_server = socket.socket() - greeting_server.bind(('127.0.0.1', 0)) + greeting_server.bind(("127.0.0.1", 0)) greeting_server.listen(1) greeting_port = greeting_server.getsockname()[1] - cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000)) + cs = self.tc.open_channel( + "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000) + ) sch = self.ts.accept(1.0) cch = socket.socket() cch.connect(self.server._tcpip_dest) ss, _ = greeting_server.accept() - ss.send(b'Hello!\n') + ss.send(b"Hello!\n") ss.close() sch.send(cch.recv(8192)) sch.close() - self.assertEqual(b'Hello!\n', cs.recv(7)) + self.assertEqual(b"Hello!\n", cs.recv(7)) cs.close() def test_G_stderr_select(self): @@ -599,7 +641,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - schan.send_stderr('hello\n') + schan.send_stderr("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): @@ -611,7 +653,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - self.assertEqual(b'hello\n', chan.recv_stderr(6)) + self.assertEqual(b"hello\n", chan.recv_stderr(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) @@ -633,8 +675,8 @@ class TransportTest(unittest.TestCase): self.assertEqual(chan.send_ready(), True) total = 0 - K = '*' * 1024 - limit = 1+(64 * 2 ** 15) + K = "*" * 1024 + limit = 1 + (64 * 2 ** 15) while total < limit: chan.send(K) total += len(K) @@ -696,8 +738,11 @@ class TransportTest(unittest.TestCase): # expires, a deadlock is assumed. class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): - threading.Thread.__init__(self, None, None, self.__class__.__name__) + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) self.setDaemon(True) self.chan = chan self.iterations = iterations @@ -707,19 +752,22 @@ class TransportTest(unittest.TestCase): def run(self): try: - for i in range(1, 1+self.iterations): + for i in range(1, 1 + self.iterations): if self.done_event.is_set(): break self.watchdog_event.set() - #print i, "SEND" + # print i, "SEND" self.chan.send("x" * 2048) finally: self.done_event.set() self.watchdog_event.set() class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): - threading.Thread.__init__(self, None, None, self.__class__.__name__) + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) self.setDaemon(True) self.chan = chan self.done_event = done_event @@ -742,30 +790,34 @@ class TransportTest(unittest.TestCase): self.ts.packetizer.REKEY_BYTES = 2048 chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) # Monkey patch the client's Transport._handler_table so that the client # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial # MSG_KEXINIT. This is used to simulate the effect of network latency # on a real MSG_CHANNEL_WINDOW_ADJUST message. - self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary + self.tc._handler_table = ( + self.tc._handler_table.copy() + ) # copy per-class dictionary _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + def _negotiate_keys_wrapper(self, m): - if self.local_kex_init is None: # Remote side sent KEXINIT + if self.local_kex_init is None: # Remote side sent KEXINIT # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # before responding to the incoming MSG_KEXINIT. m2 = Message() m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST) m2.add_int(chan.remote_chanid) - m2.add_int(1) # bytes to add + m2.add_int(1) # bytes to add self._send_message(m2) return _negotiate_keys(self, m) + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper # Parameters for the test - iterations = 500 # The deadlock does not happen every time, but it - # should after many iterations. + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. timeout = 5 # This event is set when the test is completed @@ -807,18 +859,22 @@ class TransportTest(unittest.TestCase): """ verify that we conform to the rfc of packet and window sizes. """ - for val, correct in [(4095, MIN_PACKET_SIZE), - (None, DEFAULT_MAX_PACKET_SIZE), - (2**32, MAX_WINDOW_SIZE)]: + for val, correct in [ + (4095, MIN_PACKET_SIZE), + (None, DEFAULT_MAX_PACKET_SIZE), + (2 ** 32, MAX_WINDOW_SIZE), + ]: self.assertEqual(self.tc._sanitize_packet_size(val), correct) def test_K_sanitze_window_size(self): """ verify that we conform to the rfc of packet and window sizes. """ - for val, correct in [(32767, MIN_WINDOW_SIZE), - (None, DEFAULT_WINDOW_SIZE), - (2**32, MAX_WINDOW_SIZE)]: + for val, correct in [ + (32767, MIN_WINDOW_SIZE), + (None, DEFAULT_WINDOW_SIZE), + (2 ** 32, MAX_WINDOW_SIZE), + ]: self.assertEqual(self.tc._sanitize_window_size(val), correct) @slow @@ -834,15 +890,17 @@ class TransportTest(unittest.TestCase): # (Doing this on the server's transport *sounds* more 'correct' but # actually doesn't work nearly as well for whatever reason.) class SlowPacketizer(Packetizer): + def read_message(self): time.sleep(1) return super(SlowPacketizer, self).read_message() + # NOTE: prettttty sure since the replaced .packetizer Packetizer is now # no longer doing anything with its copy of the socket...everything'll # be fine. Even tho it's a bit squicky. self.tc.packetizer = SlowPacketizer(self.tc.sock) # Continue with regular test red tape. - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -850,10 +908,13 @@ class TransportTest(unittest.TestCase): self.assertTrue(not event.is_set()) self.tc.handshake_timeout = 0.000000000001 self.ts.start_server(event, server) - self.assertRaises(EOFError, self.tc.connect, - hostkey=public_host_key, - username='slowdive', - password='pygmalion') + self.assertRaises( + EOFError, + self.tc.connect, + hostkey=public_host_key, + username="slowdive", + password="pygmalion", + ) def test_M_select_after_close(self): """ @@ -894,13 +955,13 @@ class TransportTest(unittest.TestCase): expected = text.encode("utf-8") self.assertEqual(sfile.read(len(expected)), expected) - @needs_builtin('buffer') + @needs_builtin("buffer") def test_channel_send_buffer(self): """ verify sending buffer instances to a channel """ self.setup_test_server() - data = 3 * b'some test data\n whole' + data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: @@ -917,13 +978,13 @@ class TransportTest(unittest.TestCase): chan.sendall(buffer(data)) self.assertEqual(sfile.read(len(data)), data) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_channel_send_memoryview(self): """ verify sending memoryview instances to a channel """ self.setup_test_server() - data = 3 * b'some test data\n whole' + data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: @@ -934,7 +995,7 @@ class TransportTest(unittest.TestCase): sent = 0 view = memoryview(data) while sent < len(view): - sent += chan.send(view[sent:sent+8]) + sent += chan.send(view[sent : sent + 8]) self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a memoryview instance @@ -954,7 +1015,7 @@ class TransportTest(unittest.TestCase): self.setup_test_server(connect_kwargs={}) # NOTE: this dummy global request kind would normally pass muster # from the test server. - self.tc.global_request('acceptable') + self.tc.global_request("acceptable") # Global requests never raise exceptions, even on failure (not sure why # this was the original design...ugh.) Best we can do to tell failure # happened is that the client transport's global_response was set back @@ -969,7 +1030,7 @@ class TransportTest(unittest.TestCase): # an exception on the client side, unlike the general case...) self.setup_test_server(connect_kwargs={}) try: - self.tc.request_port_forward('localhost', 1234) + self.tc.request_port_forward("localhost", 1234) except SSHException as e: assert "forwarding request denied" in str(e) else: diff --git a/tests/test_util.py b/tests/test_util.py index 90473f43..23b2e86a 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -67,53 +67,71 @@ from paramiko import * class UtilTest(unittest.TestCase): + def test_import(self): """ verify that all the classes can be imported from paramiko. """ symbols = list(globals().keys()) - self.assertTrue('Transport' in symbols) - self.assertTrue('SSHClient' in symbols) - self.assertTrue('MissingHostKeyPolicy' in symbols) - self.assertTrue('AutoAddPolicy' in symbols) - self.assertTrue('RejectPolicy' in symbols) - self.assertTrue('WarningPolicy' in symbols) - self.assertTrue('SecurityOptions' in symbols) - self.assertTrue('SubsystemHandler' in symbols) - self.assertTrue('Channel' in symbols) - self.assertTrue('RSAKey' in symbols) - self.assertTrue('DSSKey' in symbols) - self.assertTrue('Message' in symbols) - self.assertTrue('SSHException' in symbols) - self.assertTrue('AuthenticationException' in symbols) - self.assertTrue('PasswordRequiredException' in symbols) - self.assertTrue('BadAuthenticationType' in symbols) - self.assertTrue('ChannelException' in symbols) - self.assertTrue('SFTP' in symbols) - self.assertTrue('SFTPFile' in symbols) - self.assertTrue('SFTPHandle' in symbols) - self.assertTrue('SFTPClient' in symbols) - self.assertTrue('SFTPServer' in symbols) - self.assertTrue('SFTPError' in symbols) - self.assertTrue('SFTPAttributes' in symbols) - self.assertTrue('SFTPServerInterface' in symbols) - self.assertTrue('ServerInterface' in symbols) - self.assertTrue('BufferedFile' in symbols) - self.assertTrue('Agent' in symbols) - self.assertTrue('AgentKey' in symbols) - self.assertTrue('HostKeys' in symbols) - self.assertTrue('SSHConfig' in symbols) - self.assertTrue('util' in symbols) + self.assertTrue("Transport" in symbols) + self.assertTrue("SSHClient" in symbols) + self.assertTrue("MissingHostKeyPolicy" in symbols) + self.assertTrue("AutoAddPolicy" in symbols) + self.assertTrue("RejectPolicy" in symbols) + self.assertTrue("WarningPolicy" in symbols) + self.assertTrue("SecurityOptions" in symbols) + self.assertTrue("SubsystemHandler" in symbols) + self.assertTrue("Channel" in symbols) + self.assertTrue("RSAKey" in symbols) + self.assertTrue("DSSKey" in symbols) + self.assertTrue("Message" in symbols) + self.assertTrue("SSHException" in symbols) + self.assertTrue("AuthenticationException" in symbols) + self.assertTrue("PasswordRequiredException" in symbols) + self.assertTrue("BadAuthenticationType" in symbols) + self.assertTrue("ChannelException" in symbols) + self.assertTrue("SFTP" in symbols) + self.assertTrue("SFTPFile" in symbols) + self.assertTrue("SFTPHandle" in symbols) + self.assertTrue("SFTPClient" in symbols) + self.assertTrue("SFTPServer" in symbols) + self.assertTrue("SFTPError" in symbols) + self.assertTrue("SFTPAttributes" in symbols) + self.assertTrue("SFTPServerInterface" in symbols) + self.assertTrue("ServerInterface" in symbols) + self.assertTrue("BufferedFile" in symbols) + self.assertTrue("Agent" in symbols) + self.assertTrue("AgentKey" in symbols) + self.assertTrue("HostKeys" in symbols) + self.assertTrue("SSHConfig" in symbols) + self.assertTrue("util" in symbols) def test_parse_config(self): global test_config_file f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEqual(config._config, - [{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}}, - {'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}}, - {'host': ['*'], 'config': {'crazy': 'something dumb'}}, - {'host': ['spoo.example.com'], 'config': {'crazy': 'something else'}}]) + self.assertEqual( + config._config, + [ + {"host": ["*"], "config": {}}, + { + "host": ["*"], + "config": { + "identityfile": ["~/.ssh/id_rsa"], + "user": "robey", + }, + }, + { + "host": ["*.example.com"], + "config": {"user": "bjork", "port": "3333"}, + }, + {"host": ["*"], "config": {"crazy": "something dumb"}}, + { + "host": ["spoo.example.com"], + "config": {"crazy": "something else"}, + }, + ], + ) def test_host_config(self): global test_config_file @@ -121,44 +139,57 @@ class UtilTest(unittest.TestCase): config = paramiko.util.parse_ssh_config(f) for host, values in { - 'irc.danger.com': {'crazy': 'something dumb', - 'hostname': 'irc.danger.com', - 'user': 'robey'}, - 'irc.example.com': {'crazy': 'something dumb', - 'hostname': 'irc.example.com', - 'user': 'robey', - 'port': '3333'}, - 'spoo.example.com': {'crazy': 'something dumb', - 'hostname': 'spoo.example.com', - 'user': 'robey', - 'port': '3333'} + "irc.danger.com": { + "crazy": "something dumb", + "hostname": "irc.danger.com", + "user": "robey", + }, + "irc.example.com": { + "crazy": "something dumb", + "hostname": "irc.example.com", + "user": "robey", + "port": "3333", + }, + "spoo.example.com": { + "crazy": "something dumb", + "hostname": "spoo.example.com", + "user": "robey", + "port": "3333", + }, }.items(): - values = dict(values, + values = dict( + values, hostname=host, - identityfile=[os.path.expanduser("~/.ssh/id_rsa")] + identityfile=[os.path.expanduser("~/.ssh/id_rsa")], ) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_generate_key_bytes(self): - x = paramiko.util.generate_key_bytes(sha1, b'ABCDEFGH', 'This is my secret passphrase.', 64) - hex = ''.join(['%02x' % byte_ord(c) for c in x]) - self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b') + x = paramiko.util.generate_key_bytes( + sha1, b"ABCDEFGH", "This is my secret passphrase.", 64 + ) + hex = "".join(["%02x" % byte_ord(c) for c in x]) + self.assertEqual( + hex, + "9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b", + ) def test_host_keys(self): - with open('hostfile.temp', 'w') as f: + with open("hostfile.temp", "w") as f: f.write(test_hosts_file) try: - hostdict = paramiko.util.load_host_keys('hostfile.temp') + hostdict = paramiko.util.load_host_keys("hostfile.temp") self.assertEqual(2, len(hostdict)) self.assertEqual(1, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) finally: - os.unlink('hostfile.temp') + os.unlink("hostfile.temp") def test_host_config_expose_issue_33(self): test_config_file = """ @@ -173,36 +204,44 @@ Host * """ f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - host = 'www13.example.com' + host = "www13.example.com" self.assertEqual( paramiko.util.lookup_ssh_host_config(host, config), - {'hostname': host, 'port': '22'} + {"hostname": host, "port": "22"}, ) def test_eintr_retry(self): - self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo')) + self.assertEqual("foo", paramiko.util.retry_on_signal(lambda: "foo")) # Variables that are set by raises_intr intr_errors_remaining = [3] call_count = [0] + def raises_intr(): call_count[0] += 1 if intr_errors_remaining[0] > 0: intr_errors_remaining[0] -= 1 - raise IOError(errno.EINTR, 'file', 'interrupted system call') + raise IOError(errno.EINTR, "file", "interrupted system call") + self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None) self.assertEqual(0, intr_errors_remaining[0]) self.assertEqual(4, call_count[0]) def raises_ioerror_not_eintr(): - raise IOError(errno.ENOENT, 'file', 'file not found') - self.assertRaises(IOError, - lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr)) + raise IOError(errno.ENOENT, "file", "file not found") + + self.assertRaises( + IOError, + lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr), + ) def raises_other_exception(): - raise AssertionError('foo') - self.assertRaises(AssertionError, - lambda: paramiko.util.retry_on_signal(raises_other_exception)) + raise AssertionError("foo") + + self.assertRaises( + AssertionError, + lambda: paramiko.util.retry_on_signal(raises_other_exception), + ) def test_proxycommand_config_equals_parsing(self): """ @@ -217,17 +256,18 @@ Host equals-delimited """ f = StringIO(conf) config = paramiko.util.parse_ssh_config(f) - for host in ('space-delimited', 'equals-delimited'): + for host in ("space-delimited", "equals-delimited"): self.assertEqual( - host_config(host, config)['proxycommand'], - 'foo bar=biz baz' + host_config(host, config)["proxycommand"], "foo bar=biz baz" ) def test_proxycommand_interpolation(self): """ ProxyCommand should perform interpolation on the value """ - config = paramiko.util.parse_ssh_config(StringIO(""" + config = paramiko.util.parse_ssh_config( + StringIO( + """ Host specific Port 37 ProxyCommand host %h port %p lol @@ -238,28 +278,32 @@ Host portonly Host * Port 25 ProxyCommand host %h port %p -""")) +""" + ) + ) for host, val in ( - ('foo.com', "host foo.com port 25"), - ('specific', "host specific port 37 lol"), - ('portonly', "host portonly port 155"), + ("foo.com", "host foo.com port 25"), + ("specific", "host specific port 37 lol"), + ("portonly", "host portonly port 155"), ): - self.assertEqual( - host_config(host, config)['proxycommand'], - val - ) + self.assertEqual(host_config(host, config)["proxycommand"], val) def test_proxycommand_tilde_expansion(self): """ Tilde (~) should be expanded inside ProxyCommand """ - config = paramiko.util.parse_ssh_config(StringIO(""" + config = paramiko.util.parse_ssh_config( + StringIO( + """ Host test ProxyCommand ssh -F ~/.ssh/test_config bastion nc %h %p -""")) +""" + ) + ) self.assertEqual( - 'ssh -F %s/.ssh/test_config bastion nc test 22' % os.path.expanduser('~'), - host_config('test', config)['proxycommand'] + "ssh -F %s/.ssh/test_config bastion nc test 22" + % os.path.expanduser("~"), + host_config("test", config)["proxycommand"], ) def test_host_config_test_negation(self): @@ -278,10 +322,10 @@ Host * """ f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - host = 'www13.example.com' + host = "www13.example.com" self.assertEqual( paramiko.util.lookup_ssh_host_config(host, config), - {'hostname': host, 'port': '8080'} + {"hostname": host, "port": "8080"}, ) def test_host_config_test_proxycommand(self): @@ -296,20 +340,24 @@ Host proxy-without-equal-divisor ProxyCommand foo=bar:%h-%p """ for host, values in { - 'proxy-with-equal-divisor-and-space' :{'hostname': 'proxy-with-equal-divisor-and-space', - 'proxycommand': 'foo=bar'}, - 'proxy-with-equal-divisor-and-no-space':{'hostname': 'proxy-with-equal-divisor-and-no-space', - 'proxycommand': 'foo=bar'}, - 'proxy-without-equal-divisor' :{'hostname': 'proxy-without-equal-divisor', - 'proxycommand': - 'foo=bar:proxy-without-equal-divisor-22'} + "proxy-with-equal-divisor-and-space": { + "hostname": "proxy-with-equal-divisor-and-space", + "proxycommand": "foo=bar", + }, + "proxy-with-equal-divisor-and-no-space": { + "hostname": "proxy-with-equal-divisor-and-no-space", + "proxycommand": "foo=bar", + }, + "proxy-without-equal-divisor": { + "hostname": "proxy-without-equal-divisor", + "proxycommand": "foo=bar:proxy-without-equal-divisor-22", + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_host_config_test_identityfile(self): @@ -327,19 +375,21 @@ Host dsa2* IdentityFile id_dsa22 """ for host, values in { - 'foo' :{'hostname': 'foo', - 'identityfile': ['id_dsa0', 'id_dsa1']}, - 'dsa2' :{'hostname': 'dsa2', - 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa2', 'id_dsa22']}, - 'dsa22' :{'hostname': 'dsa22', - 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']} + "foo": {"hostname": "foo", "identityfile": ["id_dsa0", "id_dsa1"]}, + "dsa2": { + "hostname": "dsa2", + "identityfile": ["id_dsa0", "id_dsa1", "id_dsa2", "id_dsa22"], + }, + "dsa22": { + "hostname": "dsa22", + "identityfile": ["id_dsa0", "id_dsa1", "id_dsa22"], + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_config_addressfamily_and_lazy_fqdn(self): @@ -351,7 +401,9 @@ AddressFamily inet IdentityFile something_%l_using_fqdn """ config = paramiko.util.parse_ssh_config(StringIO(test_config)) - assert config.lookup('meh') # will die during lookup() if bug regresses + assert config.lookup( + "meh" + ) # will die during lookup() if bug regresses def test_clamp_value(self): self.assertEqual(32768, paramiko.util.clamp_value(32767, 32768, 32769)) @@ -367,7 +419,9 @@ IdentityFile something_%l_using_fqdn def test_get_hostnames(self): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEqual(config.get_hostnames(), {'*', '*.example.com', 'spoo.example.com'}) + self.assertEqual( + config.get_hostnames(), {"*", "*.example.com", "spoo.example.com"} + ) def test_quoted_host_names(self): test_config_file = """\ @@ -384,27 +438,23 @@ Host param4 "p a r" "p" "par" para Port 4444 """ res = { - 'param pam': {'hostname': 'param pam', 'port': '1111'}, - 'param': {'hostname': 'param', 'port': '1111'}, - 'pam': {'hostname': 'pam', 'port': '1111'}, - - 'param2': {'hostname': 'param2', 'port': '2222'}, - - 'param3': {'hostname': 'param3', 'port': '3333'}, - 'parara': {'hostname': 'parara', 'port': '3333'}, - - 'param4': {'hostname': 'param4', 'port': '4444'}, - 'p a r': {'hostname': 'p a r', 'port': '4444'}, - 'p': {'hostname': 'p', 'port': '4444'}, - 'par': {'hostname': 'par', 'port': '4444'}, - 'para': {'hostname': 'para', 'port': '4444'}, + "param pam": {"hostname": "param pam", "port": "1111"}, + "param": {"hostname": "param", "port": "1111"}, + "pam": {"hostname": "pam", "port": "1111"}, + "param2": {"hostname": "param2", "port": "2222"}, + "param3": {"hostname": "param3", "port": "3333"}, + "parara": {"hostname": "parara", "port": "3333"}, + "param4": {"hostname": "param4", "port": "4444"}, + "p a r": {"hostname": "p a r", "port": "4444"}, + "p": {"hostname": "p", "port": "4444"}, + "par": {"hostname": "par", "port": "4444"}, + "para": {"hostname": "para", "port": "4444"}, } f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) for host, values in res.items(): self.assertEquals( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_quoted_params_in_config(self): @@ -420,52 +470,44 @@ Host param3 parara IdentityFile "test rsa key" """ res = { - 'param pam': {'hostname': 'param pam', 'identityfile': ['id_rsa']}, - 'param': {'hostname': 'param', 'identityfile': ['id_rsa']}, - 'pam': {'hostname': 'pam', 'identityfile': ['id_rsa']}, - - 'param2': {'hostname': 'param2', 'identityfile': ['test rsa key']}, - - 'param3': {'hostname': 'param3', 'identityfile': ['id_rsa', 'test rsa key']}, - 'parara': {'hostname': 'parara', 'identityfile': ['id_rsa', 'test rsa key']}, + "param pam": {"hostname": "param pam", "identityfile": ["id_rsa"]}, + "param": {"hostname": "param", "identityfile": ["id_rsa"]}, + "pam": {"hostname": "pam", "identityfile": ["id_rsa"]}, + "param2": {"hostname": "param2", "identityfile": ["test rsa key"]}, + "param3": { + "hostname": "param3", + "identityfile": ["id_rsa", "test rsa key"], + }, + "parara": { + "hostname": "parara", + "identityfile": ["id_rsa", "test rsa key"], + }, } f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) for host, values in res.items(): self.assertEquals( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_quoted_host_in_config(self): conf = SSHConfig() correct_data = { - 'param': ['param'], - '"param"': ['param'], - - 'param pam': ['param', 'pam'], - '"param" "pam"': ['param', 'pam'], - '"param" pam': ['param', 'pam'], - 'param "pam"': ['param', 'pam'], - - 'param "pam" p': ['param', 'pam', 'p'], - '"param" pam "p"': ['param', 'pam', 'p'], - - '"pa ram"': ['pa ram'], - '"pa ram" pam': ['pa ram', 'pam'], - 'param "p a m"': ['param', 'p a m'], + "param": ["param"], + '"param"': ["param"], + "param pam": ["param", "pam"], + '"param" "pam"': ["param", "pam"], + '"param" pam': ["param", "pam"], + 'param "pam"': ["param", "pam"], + 'param "pam" p': ["param", "pam", "p"], + '"param" pam "p"': ["param", "pam", "p"], + '"pa ram"': ["pa ram"], + '"pa ram" pam': ["pa ram", "pam"], + 'param "p a m"': ["param", "p a m"], } - incorrect_data = [ - 'param"', - '"param', - 'param "pam', - 'param "pam" "p a', - ] + incorrect_data = ['param"', '"param', 'param "pam', 'param "pam" "p a'] for host, values in correct_data.items(): - self.assertEquals( - conf._get_hosts(host), - values - ) + self.assertEquals(conf._get_hosts(host), values) for host in incorrect_data: self.assertRaises(Exception, conf._get_hosts, host) @@ -490,15 +532,18 @@ Host proxycommand-with-equals-none ProxyCommand=None """ for host, values in { - 'proxycommand-standard-none': {'hostname': 'proxycommand-standard-none'}, - 'proxycommand-with-equals-none': {'hostname': 'proxycommand-with-equals-none'} + "proxycommand-standard-none": { + "hostname": "proxycommand-standard-none" + }, + "proxycommand-with-equals-none": { + "hostname": "proxycommand-with-equals-none" + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_proxycommand_none_masking(self): @@ -521,12 +566,10 @@ Host * # backwards compatibility reasons in 1.x/2.x) appear completely blank, # as if the host had no ProxyCommand whatsoever. # Threw another unrelated host in there just for sanity reasons. - self.assertFalse('proxycommand' in config.lookup('specific-host')) + self.assertFalse("proxycommand" in config.lookup("specific-host")) self.assertEqual( - config.lookup('other-host')['proxycommand'], - 'other-proxy' + config.lookup("other-host")["proxycommand"], "other-proxy" ) self.assertEqual( - config.lookup('some-random-host')['proxycommand'], - 'default-proxy' + config.lookup("some-random-host")["proxycommand"], "default-proxy" ) |