summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.travis.yml5
-rwxr-xr-xdemos/demo.py85
-rwxr-xr-xdemos/demo_keygen.py130
-rw-r--r--demos/demo_server.py88
-rw-r--r--demos/demo_sftp.py76
-rw-r--r--demos/demo_simple.py50
-rw-r--r--demos/forward.py176
-rw-r--r--demos/interactive.py19
-rwxr-xr-xdemos/rforward.py147
-rw-r--r--dev-requirements.txt4
-rw-r--r--paramiko/__init__.py34
-rw-r--r--paramiko/_version.py2
-rw-r--r--paramiko/_winapi.py154
-rw-r--r--paramiko/agent.py57
-rw-r--r--paramiko/auth_handler.py292
-rw-r--r--paramiko/ber.py13
-rw-r--r--paramiko/buffered_pipe.py9
-rw-r--r--paramiko/channel.py166
-rw-r--r--paramiko/client.py146
-rw-r--r--paramiko/common.py165
-rw-r--r--paramiko/compress.py6
-rw-r--r--paramiko/config.py130
-rw-r--r--paramiko/dsskey.py89
-rw-r--r--paramiko/ecdsakey.py63
-rw-r--r--paramiko/ed25519key.py19
-rw-r--r--paramiko/file.py74
-rw-r--r--paramiko/hostkeys.py50
-rw-r--r--paramiko/kex_ecdh_nist.py20
-rw-r--r--paramiko/kex_gex.py72
-rw-r--r--paramiko/kex_group1.py28
-rw-r--r--paramiko/kex_group14.py6
-rw-r--r--paramiko/kex_gss.py203
-rw-r--r--paramiko/message.py20
-rw-r--r--paramiko/packet.py130
-rw-r--r--paramiko/pipe.py19
-rw-r--r--paramiko/pkey.py104
-rw-r--r--paramiko/primes.py31
-rw-r--r--paramiko/proxy.py21
-rw-r--r--paramiko/py3compat.py64
-rw-r--r--paramiko/rsakey.py42
-rw-r--r--paramiko/server.py49
-rw-r--r--paramiko/sftp.py148
-rw-r--r--paramiko/sftp_attr.py68
-rw-r--r--paramiko/sftp_client.py175
-rw-r--r--paramiko/sftp_file.py93
-rw-r--r--paramiko/sftp_handle.py11
-rw-r--r--paramiko/sftp_server.py161
-rw-r--r--paramiko/sftp_si.py9
-rw-r--r--paramiko/ssh_exception.py54
-rw-r--r--paramiko/ssh_gss.py114
-rw-r--r--paramiko/transport.py1051
-rw-r--r--paramiko/util.py41
-rw-r--r--paramiko/win_pageant.py34
-rw-r--r--setup.cfg4
-rw-r--r--setup.py50
-rw-r--r--setup_helper.py48
-rw-r--r--sites/docs/conf.py13
-rw-r--r--sites/shared_conf.py39
-rw-r--r--sites/www/conf.py16
-rw-r--r--tasks.py15
-rw-r--r--tests/conftest.py12
-rw-r--r--tests/loop.py8
-rw-r--r--tests/stub_sftp.py58
-rw-r--r--tests/test_auth.py118
-rw-r--r--tests/test_buffered_pipe.py27
-rw-r--r--tests/test_client.py253
-rw-r--r--tests/test_file.py139
-rw-r--r--tests/test_gssapi.py44
-rw-r--r--tests/test_hostkeys.py79
-rw-r--r--tests/test_kex.py314
-rw-r--r--tests/test_kex_gss.py48
-rw-r--r--tests/test_message.py39
-rw-r--r--tests/test_packetizer.py34
-rw-r--r--tests/test_pkey.py257
-rw-r--r--tests/test_sftp.py495
-rw-r--r--tests/test_sftp_big.py222
-rw-r--r--tests/test_ssh_exception.py21
-rw-r--r--tests/test_ssh_gss.py61
-rw-r--r--tests/test_transport.py323
-rw-r--r--tests/test_util.py375
80 files changed, 4714 insertions, 3415 deletions
diff --git a/.travis.yml b/.travis.yml
index 4316bd9f..16d33b76 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -23,7 +23,10 @@ install:
- pip install codecov # For codecov specifically
- pip install -r dev-requirements.txt
script:
- # flake8 is now possible!
+ # Fast syntax check failures for more rapid feedback to submitters
+ # (Travis-oriented metatask that version checks Python, installs, runs.)
+ - inv travis.blacken
+ # I have this in my git pre-push hook, but contributors probably don't
- flake8
# All (including slow) tests, w/ coverage!
- inv coverage
diff --git a/demos/demo.py b/demos/demo.py
index fff61784..c9b0a5f5 100755
--- a/demos/demo.py
+++ b/demos/demo.py
@@ -31,6 +31,7 @@ import traceback
from paramiko.py3compat import input
import paramiko
+
try:
import interactive
except ImportError:
@@ -42,71 +43,73 @@ def agent_auth(transport, username):
Attempt to authenticate to the given transport using any of the private
keys available from an SSH agent.
"""
-
+
agent = paramiko.Agent()
agent_keys = agent.get_keys()
if len(agent_keys) == 0:
return
-
+
for key in agent_keys:
- print('Trying ssh-agent key %s' % hexlify(key.get_fingerprint()))
+ print("Trying ssh-agent key %s" % hexlify(key.get_fingerprint()))
try:
transport.auth_publickey(username, key)
- print('... success!')
+ print("... success!")
return
except paramiko.SSHException:
- print('... nope.')
+ print("... nope.")
def manual_auth(username, hostname):
- default_auth = 'p'
- auth = input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
+ default_auth = "p"
+ auth = input(
+ "Auth by (p)assword, (r)sa key, or (d)ss key? [%s] " % default_auth
+ )
if len(auth) == 0:
auth = default_auth
- if auth == 'r':
- default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
- path = input('RSA key [%s]: ' % default_path)
+ if auth == "r":
+ default_path = os.path.join(os.environ["HOME"], ".ssh", "id_rsa")
+ path = input("RSA key [%s]: " % default_path)
if len(path) == 0:
path = default_path
try:
key = paramiko.RSAKey.from_private_key_file(path)
except paramiko.PasswordRequiredException:
- password = getpass.getpass('RSA key password: ')
+ password = getpass.getpass("RSA key password: ")
key = paramiko.RSAKey.from_private_key_file(path, password)
t.auth_publickey(username, key)
- elif auth == 'd':
- default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa')
- path = input('DSS key [%s]: ' % default_path)
+ elif auth == "d":
+ default_path = os.path.join(os.environ["HOME"], ".ssh", "id_dsa")
+ path = input("DSS key [%s]: " % default_path)
if len(path) == 0:
path = default_path
try:
key = paramiko.DSSKey.from_private_key_file(path)
except paramiko.PasswordRequiredException:
- password = getpass.getpass('DSS key password: ')
+ password = getpass.getpass("DSS key password: ")
key = paramiko.DSSKey.from_private_key_file(path, password)
t.auth_publickey(username, key)
else:
- pw = getpass.getpass('Password for %s@%s: ' % (username, hostname))
+ pw = getpass.getpass("Password for %s@%s: " % (username, hostname))
t.auth_password(username, pw)
# setup logging
-paramiko.util.log_to_file('demo.log')
+paramiko.util.log_to_file("demo.log")
-username = ''
+username = ""
if len(sys.argv) > 1:
hostname = sys.argv[1]
- if hostname.find('@') >= 0:
- username, hostname = hostname.split('@')
+ if hostname.find("@") >= 0:
+ username, hostname = hostname.split("@")
else:
- hostname = input('Hostname: ')
+ hostname = input("Hostname: ")
if len(hostname) == 0:
- print('*** Hostname required.')
+ print("*** Hostname required.")
sys.exit(1)
port = 22
-if hostname.find(':') >= 0:
- hostname, portstr = hostname.split(':')
+if hostname.find(":") >= 0:
+ hostname, portstr = hostname.split(":")
port = int(portstr)
# now connect
@@ -114,7 +117,7 @@ try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((hostname, port))
except Exception as e:
- print('*** Connect failed: ' + str(e))
+ print("*** Connect failed: " + str(e))
traceback.print_exc()
sys.exit(1)
@@ -123,34 +126,38 @@ try:
try:
t.start_client()
except paramiko.SSHException:
- print('*** SSH negotiation failed.')
+ print("*** SSH negotiation failed.")
sys.exit(1)
try:
- keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
+ keys = paramiko.util.load_host_keys(
+ os.path.expanduser("~/.ssh/known_hosts")
+ )
except IOError:
try:
- keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
+ keys = paramiko.util.load_host_keys(
+ os.path.expanduser("~/ssh/known_hosts")
+ )
except IOError:
- print('*** Unable to open host keys file')
+ print("*** Unable to open host keys file")
keys = {}
# check server's host key -- this is important.
key = t.get_remote_server_key()
if hostname not in keys:
- print('*** WARNING: Unknown host key!')
+ print("*** WARNING: Unknown host key!")
elif key.get_name() not in keys[hostname]:
- print('*** WARNING: Unknown host key!')
+ print("*** WARNING: Unknown host key!")
elif keys[hostname][key.get_name()] != key:
- print('*** WARNING: Host key has changed!!!')
+ print("*** WARNING: Host key has changed!!!")
sys.exit(1)
else:
- print('*** Host key OK.')
+ print("*** Host key OK.")
# get username
- if username == '':
+ if username == "":
default_username = getpass.getuser()
- username = input('Username [%s]: ' % default_username)
+ username = input("Username [%s]: " % default_username)
if len(username) == 0:
username = default_username
@@ -158,25 +165,23 @@ try:
if not t.is_authenticated():
manual_auth(username, hostname)
if not t.is_authenticated():
- print('*** Authentication failed. :(')
+ print("*** Authentication failed. :(")
t.close()
sys.exit(1)
chan = t.open_session()
chan.get_pty()
chan.invoke_shell()
- print('*** Here we go!\n')
+ print("*** Here we go!\n")
interactive.interactive_shell(chan)
chan.close()
t.close()
except Exception as e:
- print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
+ print("*** Caught exception: " + str(e.__class__) + ": " + str(e))
traceback.print_exc()
try:
t.close()
except:
pass
sys.exit(1)
-
-
diff --git a/demos/demo_keygen.py b/demos/demo_keygen.py
index 860ee4e9..6a80272d 100755
--- a/demos/demo_keygen.py
+++ b/demos/demo_keygen.py
@@ -28,62 +28,97 @@ from paramiko import RSAKey
from paramiko.ssh_exception import SSHException
from paramiko.py3compat import u
-usage="""
+usage = """
%prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]"""
default_values = {
"ktype": "dsa",
"bits": 1024,
"filename": "output",
- "comment": ""
+ "comment": "",
}
-key_dispatch_table = {
- 'dsa': DSSKey,
- 'rsa': RSAKey,
-}
+key_dispatch_table = {"dsa": DSSKey, "rsa": RSAKey}
+
def progress(arg=None):
if not arg:
- sys.stdout.write('0%\x08\x08\x08 ')
+ sys.stdout.write("0%\x08\x08\x08 ")
sys.stdout.flush()
- elif arg[0] == 'p':
- sys.stdout.write('25%\x08\x08\x08\x08 ')
+ elif arg[0] == "p":
+ sys.stdout.write("25%\x08\x08\x08\x08 ")
sys.stdout.flush()
- elif arg[0] == 'h':
- sys.stdout.write('50%\x08\x08\x08\x08 ')
+ elif arg[0] == "h":
+ sys.stdout.write("50%\x08\x08\x08\x08 ")
sys.stdout.flush()
- elif arg[0] == 'x':
- sys.stdout.write('75%\x08\x08\x08\x08 ')
+ elif arg[0] == "x":
+ sys.stdout.write("75%\x08\x08\x08\x08 ")
sys.stdout.flush()
-if __name__ == '__main__':
- phrase=None
- pfunc=None
+if __name__ == "__main__":
+
+ phrase = None
+ pfunc = None
parser = OptionParser(usage=usage)
- parser.add_option("-t", "--type", type="string", dest="ktype",
+ parser.add_option(
+ "-t",
+ "--type",
+ type="string",
+ dest="ktype",
help="Specify type of key to create (dsa or rsa)",
- metavar="ktype", default=default_values["ktype"])
- parser.add_option("-b", "--bits", type="int", dest="bits",
- help="Number of bits in the key to create", metavar="bits",
- default=default_values["bits"])
- parser.add_option("-N", "--new-passphrase", dest="newphrase",
- help="Provide new passphrase", metavar="phrase")
- parser.add_option("-P", "--old-passphrase", dest="oldphrase",
- help="Provide old passphrase", metavar="phrase")
- parser.add_option("-f", "--filename", type="string", dest="filename",
- help="Filename of the key file", metavar="filename",
- default=default_values["filename"])
- parser.add_option("-q", "--quiet", default=False, action="store_false",
- help="Quiet")
- parser.add_option("-v", "--verbose", default=False, action="store_true",
- help="Verbose")
- parser.add_option("-C", "--comment", type="string", dest="comment",
- help="Provide a new comment", metavar="comment",
- default=default_values["comment"])
+ metavar="ktype",
+ default=default_values["ktype"],
+ )
+ parser.add_option(
+ "-b",
+ "--bits",
+ type="int",
+ dest="bits",
+ help="Number of bits in the key to create",
+ metavar="bits",
+ default=default_values["bits"],
+ )
+ parser.add_option(
+ "-N",
+ "--new-passphrase",
+ dest="newphrase",
+ help="Provide new passphrase",
+ metavar="phrase",
+ )
+ parser.add_option(
+ "-P",
+ "--old-passphrase",
+ dest="oldphrase",
+ help="Provide old passphrase",
+ metavar="phrase",
+ )
+ parser.add_option(
+ "-f",
+ "--filename",
+ type="string",
+ dest="filename",
+ help="Filename of the key file",
+ metavar="filename",
+ default=default_values["filename"],
+ )
+ parser.add_option(
+ "-q", "--quiet", default=False, action="store_false", help="Quiet"
+ )
+ parser.add_option(
+ "-v", "--verbose", default=False, action="store_true", help="Verbose"
+ )
+ parser.add_option(
+ "-C",
+ "--comment",
+ type="string",
+ dest="comment",
+ help="Provide a new comment",
+ metavar="comment",
+ default=default_values["comment"],
+ )
(options, args) = parser.parse_args()
@@ -95,18 +130,23 @@ if __name__ == '__main__':
globals()[o] = getattr(options, o, default_values[o.lower()])
if options.newphrase:
- phrase = getattr(options, 'newphrase')
+ phrase = getattr(options, "newphrase")
if options.verbose:
pfunc = progress
- sys.stdout.write("Generating priv/pub %s %d bits key pair (%s/%s.pub)..." % (ktype, bits, filename, filename))
+ sys.stdout.write(
+ "Generating priv/pub %s %d bits key pair (%s/%s.pub)..."
+ % (ktype, bits, filename, filename)
+ )
sys.stdout.flush()
- if ktype == 'dsa' and bits > 1024:
+ if ktype == "dsa" and bits > 1024:
raise SSHException("DSA Keys must be 1024 bits")
if ktype not in key_dispatch_table:
- raise SSHException("Unknown %s algorithm to generate keys pair" % ktype)
+ raise SSHException(
+ "Unknown %s algorithm to generate keys pair" % ktype
+ )
# generating private key
prv = key_dispatch_table[ktype].generate(bits=bits, progress_func=pfunc)
@@ -114,7 +154,7 @@ if __name__ == '__main__':
# generating public key
pub = key_dispatch_table[ktype](filename=filename, password=phrase)
- with open("%s.pub" % filename, 'w') as f:
+ with open("%s.pub" % filename, "w") as f:
f.write("%s %s" % (pub.get_name(), pub.get_base64()))
if options.comment:
f.write(" %s" % comment)
@@ -123,4 +163,12 @@ if __name__ == '__main__':
print("done.")
hash = u(hexlify(pub.get_fingerprint()))
- print("Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, ktype.upper()))
+ print(
+ "Fingerprint: %d %s %s.pub (%s)"
+ % (
+ bits,
+ ":".join([hash[i : 2 + i] for i in range(0, len(hash), 2)]),
+ filename,
+ ktype.upper(),
+ )
+ )
diff --git a/demos/demo_server.py b/demos/demo_server.py
index 3a7ec854..313e5fb2 100644
--- a/demos/demo_server.py
+++ b/demos/demo_server.py
@@ -31,45 +31,47 @@ from paramiko.py3compat import b, u, decodebytes
# setup logging
-paramiko.util.log_to_file('demo_server.log')
+paramiko.util.log_to_file("demo_server.log")
-host_key = paramiko.RSAKey(filename='test_rsa.key')
-#host_key = paramiko.DSSKey(filename='test_dss.key')
+host_key = paramiko.RSAKey(filename="test_rsa.key")
+# host_key = paramiko.DSSKey(filename='test_dss.key')
-print('Read key: ' + u(hexlify(host_key.get_fingerprint())))
+print("Read key: " + u(hexlify(host_key.get_fingerprint())))
-class Server (paramiko.ServerInterface):
+class Server(paramiko.ServerInterface):
# 'data' is the output of base64.b64encode(key)
# (using the "user_rsa_key" files)
- data = (b'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp'
- b'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC'
- b'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT'
- b'UWT10hcuO4Ks8=')
+ data = (
+ b"AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp"
+ b"fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC"
+ b"KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT"
+ b"UWT10hcuO4Ks8="
+ )
good_pub_key = paramiko.RSAKey(data=decodebytes(data))
def __init__(self):
self.event = threading.Event()
def check_channel_request(self, kind, chanid):
- if kind == 'session':
+ if kind == "session":
return paramiko.OPEN_SUCCEEDED
return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
def check_auth_password(self, username, password):
- if (username == 'robey') and (password == 'foo'):
+ if (username == "robey") and (password == "foo"):
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key):
- print('Auth attempt with key: ' + u(hexlify(key.get_fingerprint())))
- if (username == 'robey') and (key == self.good_pub_key):
+ print("Auth attempt with key: " + u(hexlify(key.get_fingerprint())))
+ if (username == "robey") and (key == self.good_pub_key):
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
-
- def check_auth_gssapi_with_mic(self, username,
- gss_authenticated=paramiko.AUTH_FAILED,
- cc_file=None):
+
+ def check_auth_gssapi_with_mic(
+ self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None
+ ):
"""
.. note::
We are just checking in `AuthHandler` that the given user is a
@@ -88,9 +90,9 @@ class Server (paramiko.ServerInterface):
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
- def check_auth_gssapi_keyex(self, username,
- gss_authenticated=paramiko.AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_keyex(
+ self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None
+ ):
if gss_authenticated == paramiko.AUTH_SUCCESSFUL:
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -99,14 +101,15 @@ class Server (paramiko.ServerInterface):
return True
def get_allowed_auths(self, username):
- return 'gssapi-keyex,gssapi-with-mic,password,publickey'
+ return "gssapi-keyex,gssapi-with-mic,password,publickey"
def check_channel_shell_request(self, channel):
self.event.set()
return True
- def check_channel_pty_request(self, channel, term, width, height, pixelwidth,
- pixelheight, modes):
+ def check_channel_pty_request(
+ self, channel, term, width, height, pixelwidth, pixelheight, modes
+ ):
return True
@@ -116,22 +119,22 @@ DoGSSAPIKeyExchange = True
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.bind(('', 2200))
+ sock.bind(("", 2200))
except Exception as e:
- print('*** Bind failed: ' + str(e))
+ print("*** Bind failed: " + str(e))
traceback.print_exc()
sys.exit(1)
try:
sock.listen(100)
- print('Listening for connection ...')
+ print("Listening for connection ...")
client, addr = sock.accept()
except Exception as e:
- print('*** Listen/accept failed: ' + str(e))
+ print("*** Listen/accept failed: " + str(e))
traceback.print_exc()
sys.exit(1)
-print('Got a connection!')
+print("Got a connection!")
try:
t = paramiko.Transport(client, gss_kex=DoGSSAPIKeyExchange)
@@ -139,43 +142,44 @@ try:
try:
t.load_server_moduli()
except:
- print('(Failed to load moduli -- gex will be unsupported.)')
+ print("(Failed to load moduli -- gex will be unsupported.)")
raise
t.add_server_key(host_key)
server = Server()
try:
t.start_server(server=server)
except paramiko.SSHException:
- print('*** SSH negotiation failed.')
+ print("*** SSH negotiation failed.")
sys.exit(1)
# wait for auth
chan = t.accept(20)
if chan is None:
- print('*** No channel.')
+ print("*** No channel.")
sys.exit(1)
- print('Authenticated!')
+ print("Authenticated!")
server.event.wait(10)
if not server.event.is_set():
- print('*** Client never asked for a shell.')
+ print("*** Client never asked for a shell.")
sys.exit(1)
- chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
- chan.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
- chan.send('Happy birthday to Robot Dave!\r\n\r\n')
- chan.send('Username: ')
- f = chan.makefile('rU')
- username = f.readline().strip('\r\n')
- chan.send('\r\nI don\'t like you, ' + username + '.\r\n')
+ chan.send("\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n")
+ chan.send(
+ "We are on fire all the time! Hooray! Candy corn for everyone!\r\n"
+ )
+ chan.send("Happy birthday to Robot Dave!\r\n\r\n")
+ chan.send("Username: ")
+ f = chan.makefile("rU")
+ username = f.readline().strip("\r\n")
+ chan.send("\r\nI don't like you, " + username + ".\r\n")
chan.close()
except Exception as e:
- print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
+ print("*** Caught exception: " + str(e.__class__) + ": " + str(e))
traceback.print_exc()
try:
t.close()
except:
pass
sys.exit(1)
-
diff --git a/demos/demo_sftp.py b/demos/demo_sftp.py
index 2cb44701..7f6a002e 100644
--- a/demos/demo_sftp.py
+++ b/demos/demo_sftp.py
@@ -32,38 +32,38 @@ from paramiko.py3compat import input
# setup logging
-paramiko.util.log_to_file('demo_sftp.log')
+paramiko.util.log_to_file("demo_sftp.log")
# Paramiko client configuration
-UseGSSAPI = True # enable GSS-API / SSPI authentication
+UseGSSAPI = True # enable GSS-API / SSPI authentication
DoGSSAPIKeyExchange = True
Port = 22
# get hostname
-username = ''
+username = ""
if len(sys.argv) > 1:
hostname = sys.argv[1]
- if hostname.find('@') >= 0:
- username, hostname = hostname.split('@')
+ if hostname.find("@") >= 0:
+ username, hostname = hostname.split("@")
else:
- hostname = input('Hostname: ')
+ hostname = input("Hostname: ")
if len(hostname) == 0:
- print('*** Hostname required.')
+ print("*** Hostname required.")
sys.exit(1)
-if hostname.find(':') >= 0:
- hostname, portstr = hostname.split(':')
+if hostname.find(":") >= 0:
+ hostname, portstr = hostname.split(":")
Port = int(portstr)
# get username
-if username == '':
+if username == "":
default_username = getpass.getuser()
- username = input('Username [%s]: ' % default_username)
+ username = input("Username [%s]: " % default_username)
if len(username) == 0:
username = default_username
if not UseGSSAPI:
- password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
+ password = getpass.getpass("Password for %s@%s: " % (username, hostname))
else:
password = None
@@ -72,59 +72,69 @@ else:
hostkeytype = None
hostkey = None
try:
- host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
+ host_keys = paramiko.util.load_host_keys(
+ os.path.expanduser("~/.ssh/known_hosts")
+ )
except IOError:
try:
# try ~/ssh/ too, because windows can't have a folder named ~/.ssh/
- host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
+ host_keys = paramiko.util.load_host_keys(
+ os.path.expanduser("~/ssh/known_hosts")
+ )
except IOError:
- print('*** Unable to open host keys file')
+ print("*** Unable to open host keys file")
host_keys = {}
if hostname in host_keys:
hostkeytype = host_keys[hostname].keys()[0]
hostkey = host_keys[hostname][hostkeytype]
- print('Using host key of type %s' % hostkeytype)
+ print("Using host key of type %s" % hostkeytype)
# now, connect and use paramiko Transport to negotiate SSH2 across the connection
try:
t = paramiko.Transport((hostname, Port))
- t.connect(hostkey, username, password, gss_host=socket.getfqdn(hostname),
- gss_auth=UseGSSAPI, gss_kex=DoGSSAPIKeyExchange)
+ t.connect(
+ hostkey,
+ username,
+ password,
+ gss_host=socket.getfqdn(hostname),
+ gss_auth=UseGSSAPI,
+ gss_kex=DoGSSAPIKeyExchange,
+ )
sftp = paramiko.SFTPClient.from_transport(t)
# dirlist on remote host
- dirlist = sftp.listdir('.')
+ dirlist = sftp.listdir(".")
print("Dirlist: %s" % dirlist)
# copy this demo onto the server
try:
sftp.mkdir("demo_sftp_folder")
except IOError:
- print('(assuming demo_sftp_folder/ already exists)')
- with sftp.open('demo_sftp_folder/README', 'w') as f:
- f.write('This was created by demo_sftp.py.\n')
- with open('demo_sftp.py', 'r') as f:
+ print("(assuming demo_sftp_folder/ already exists)")
+ with sftp.open("demo_sftp_folder/README", "w") as f:
+ f.write("This was created by demo_sftp.py.\n")
+ with open("demo_sftp.py", "r") as f:
data = f.read()
- sftp.open('demo_sftp_folder/demo_sftp.py', 'w').write(data)
- print('created demo_sftp_folder/ on the server')
-
+ sftp.open("demo_sftp_folder/demo_sftp.py", "w").write(data)
+ print("created demo_sftp_folder/ on the server")
+
# copy the README back here
- with sftp.open('demo_sftp_folder/README', 'r') as f:
+ with sftp.open("demo_sftp_folder/README", "r") as f:
data = f.read()
- with open('README_demo_sftp', 'w') as f:
+ with open("README_demo_sftp", "w") as f:
f.write(data)
- print('copied README back here')
-
+ print("copied README back here")
+
# BETTER: use the get() and put() methods
- sftp.put('demo_sftp.py', 'demo_sftp_folder/demo_sftp.py')
- sftp.get('demo_sftp_folder/README', 'README_demo_sftp')
+ sftp.put("demo_sftp.py", "demo_sftp_folder/demo_sftp.py")
+ sftp.get("demo_sftp_folder/README", "README_demo_sftp")
t.close()
except Exception as e:
- print('*** Caught exception: %s: %s' % (e.__class__, e))
+ print("*** Caught exception: %s: %s" % (e.__class__, e))
traceback.print_exc()
try:
t.close()
diff --git a/demos/demo_simple.py b/demos/demo_simple.py
index 9def57f8..5dd4f6c1 100644
--- a/demos/demo_simple.py
+++ b/demos/demo_simple.py
@@ -28,6 +28,7 @@ import traceback
from paramiko.py3compat import input
import paramiko
+
try:
import interactive
except ImportError:
@@ -35,39 +36,43 @@ except ImportError:
# setup logging
-paramiko.util.log_to_file('demo_simple.log')
+paramiko.util.log_to_file("demo_simple.log")
# Paramiko client configuration
-UseGSSAPI = paramiko.GSS_AUTH_AVAILABLE # enable "gssapi-with-mic" authentication, if supported by your python installation
-DoGSSAPIKeyExchange = paramiko.GSS_AUTH_AVAILABLE # enable "gssapi-kex" key exchange, if supported by your python installation
+UseGSSAPI = (
+ paramiko.GSS_AUTH_AVAILABLE
+) # enable "gssapi-with-mic" authentication, if supported by your python installation
+DoGSSAPIKeyExchange = (
+ paramiko.GSS_AUTH_AVAILABLE
+) # enable "gssapi-kex" key exchange, if supported by your python installation
# UseGSSAPI = False
# DoGSSAPIKeyExchange = False
port = 22
# get hostname
-username = ''
+username = ""
if len(sys.argv) > 1:
hostname = sys.argv[1]
- if hostname.find('@') >= 0:
- username, hostname = hostname.split('@')
+ if hostname.find("@") >= 0:
+ username, hostname = hostname.split("@")
else:
- hostname = input('Hostname: ')
+ hostname = input("Hostname: ")
if len(hostname) == 0:
- print('*** Hostname required.')
+ print("*** Hostname required.")
sys.exit(1)
-if hostname.find(':') >= 0:
- hostname, portstr = hostname.split(':')
+if hostname.find(":") >= 0:
+ hostname, portstr = hostname.split(":")
port = int(portstr)
# get username
-if username == '':
+if username == "":
default_username = getpass.getuser()
- username = input('Username [%s]: ' % default_username)
+ username = input("Username [%s]: " % default_username)
if len(username) == 0:
username = default_username
if not UseGSSAPI and not DoGSSAPIKeyExchange:
- password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
+ password = getpass.getpass("Password for %s@%s: " % (username, hostname))
# now, connect and use paramiko Client to negotiate SSH2 across the connection
@@ -75,27 +80,34 @@ try:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
- print('*** Connecting...')
+ print("*** Connecting...")
if not UseGSSAPI and not DoGSSAPIKeyExchange:
client.connect(hostname, port, username, password)
else:
try:
- client.connect(hostname, port, username, gss_auth=UseGSSAPI,
- gss_kex=DoGSSAPIKeyExchange)
+ client.connect(
+ hostname,
+ port,
+ username,
+ gss_auth=UseGSSAPI,
+ gss_kex=DoGSSAPIKeyExchange,
+ )
except Exception:
# traceback.print_exc()
- password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
+ password = getpass.getpass(
+ "Password for %s@%s: " % (username, hostname)
+ )
client.connect(hostname, port, username, password)
chan = client.invoke_shell()
print(repr(client.get_transport()))
- print('*** Here we go!\n')
+ print("*** Here we go!\n")
interactive.interactive_shell(chan)
chan.close()
client.close()
except Exception as e:
- print('*** Caught exception: %s: %s' % (e.__class__, e))
+ print("*** Caught exception: %s: %s" % (e.__class__, e))
traceback.print_exc()
try:
client.close()
diff --git a/demos/forward.py b/demos/forward.py
index 96e1700d..98757911 100644
--- a/demos/forward.py
+++ b/demos/forward.py
@@ -30,6 +30,7 @@ import getpass
import os
import socket
import select
+
try:
import SocketServer
except ImportError:
@@ -46,30 +47,41 @@ DEFAULT_PORT = 4000
g_verbose = True
-class ForwardServer (SocketServer.ThreadingTCPServer):
+class ForwardServer(SocketServer.ThreadingTCPServer):
daemon_threads = True
allow_reuse_address = True
-
-class Handler (SocketServer.BaseRequestHandler):
+
+class Handler(SocketServer.BaseRequestHandler):
def handle(self):
try:
- chan = self.ssh_transport.open_channel('direct-tcpip',
- (self.chain_host, self.chain_port),
- self.request.getpeername())
+ chan = self.ssh_transport.open_channel(
+ "direct-tcpip",
+ (self.chain_host, self.chain_port),
+ self.request.getpeername(),
+ )
except Exception as e:
- verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
- self.chain_port,
- repr(e)))
+ verbose(
+ "Incoming request to %s:%d failed: %s"
+ % (self.chain_host, self.chain_port, repr(e))
+ )
return
if chan is None:
- verbose('Incoming request to %s:%d was rejected by the SSH server.' %
- (self.chain_host, self.chain_port))
+ verbose(
+ "Incoming request to %s:%d was rejected by the SSH server."
+ % (self.chain_host, self.chain_port)
+ )
return
- verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
- chan.getpeername(), (self.chain_host, self.chain_port)))
+ verbose(
+ "Connected! Tunnel open %r -> %r -> %r"
+ % (
+ self.request.getpeername(),
+ chan.getpeername(),
+ (self.chain_host, self.chain_port),
+ )
+ )
while True:
r, w, x = select.select([self.request, chan], [], [])
if self.request in r:
@@ -82,22 +94,23 @@ class Handler (SocketServer.BaseRequestHandler):
if len(data) == 0:
break
self.request.send(data)
-
+
peername = self.request.getpeername()
chan.close()
self.request.close()
- verbose('Tunnel closed from %r' % (peername,))
+ verbose("Tunnel closed from %r" % (peername,))
def forward_tunnel(local_port, remote_host, remote_port, transport):
# this is a little convoluted, but lets me configure things for the Handler
# object. (SocketServer doesn't give Handlers any way to access the outer
# server normally.)
- class SubHander (Handler):
+ class SubHander(Handler):
chain_host = remote_host
chain_port = remote_port
ssh_transport = transport
- ForwardServer(('', local_port), SubHander).serve_forever()
+
+ ForwardServer(("", local_port), SubHander).serve_forever()
def verbose(s):
@@ -114,40 +127,88 @@ the SSH server. This is similar to the openssh -L option.
def get_host_port(spec, default_port):
"parse 'hostname:22' into a host and port, with the port optional"
- args = (spec.split(':', 1) + [default_port])[:2]
+ args = (spec.split(":", 1) + [default_port])[:2]
args[1] = int(args[1])
return args[0], args[1]
def parse_options():
global g_verbose
-
- parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]',
- version='%prog 1.0', description=HELP)
- parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True,
- help='squelch all informational output')
- parser.add_option('-p', '--local-port', action='store', type='int', dest='port',
- default=DEFAULT_PORT,
- help='local port to forward (default: %d)' % DEFAULT_PORT)
- parser.add_option('-u', '--user', action='store', type='string', dest='user',
- default=getpass.getuser(),
- help='username for SSH authentication (default: %s)' % getpass.getuser())
- parser.add_option('-K', '--key', action='store', type='string', dest='keyfile',
- default=None,
- help='private key file to use for SSH authentication')
- parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True,
- help='don\'t look for or use a private key file')
- parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False,
- help='read password (for key or password auth) from stdin')
- parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port',
- help='remote host and port to forward to')
+
+ parser = OptionParser(
+ usage="usage: %prog [options] <ssh-server>[:<server-port>]",
+ version="%prog 1.0",
+ description=HELP,
+ )
+ parser.add_option(
+ "-q",
+ "--quiet",
+ action="store_false",
+ dest="verbose",
+ default=True,
+ help="squelch all informational output",
+ )
+ parser.add_option(
+ "-p",
+ "--local-port",
+ action="store",
+ type="int",
+ dest="port",
+ default=DEFAULT_PORT,
+ help="local port to forward (default: %d)" % DEFAULT_PORT,
+ )
+ parser.add_option(
+ "-u",
+ "--user",
+ action="store",
+ type="string",
+ dest="user",
+ default=getpass.getuser(),
+ help="username for SSH authentication (default: %s)"
+ % getpass.getuser(),
+ )
+ parser.add_option(
+ "-K",
+ "--key",
+ action="store",
+ type="string",
+ dest="keyfile",
+ default=None,
+ help="private key file to use for SSH authentication",
+ )
+ parser.add_option(
+ "",
+ "--no-key",
+ action="store_false",
+ dest="look_for_keys",
+ default=True,
+ help="don't look for or use a private key file",
+ )
+ parser.add_option(
+ "-P",
+ "--password",
+ action="store_true",
+ dest="readpass",
+ default=False,
+ help="read password (for key or password auth) from stdin",
+ )
+ parser.add_option(
+ "-r",
+ "--remote",
+ action="store",
+ type="string",
+ dest="remote",
+ default=None,
+ metavar="host:port",
+ help="remote host and port to forward to",
+ )
options, args = parser.parse_args()
if len(args) != 1:
- parser.error('Incorrect number of arguments.')
+ parser.error("Incorrect number of arguments.")
if options.remote is None:
- parser.error('Remote address required (-r).')
-
+ parser.error("Remote address required (-r).")
+
g_verbose = options.verbose
server_host, server_port = get_host_port(args[0], SSH_PORT)
remote_host, remote_port = get_host_port(options.remote, SSH_PORT)
@@ -156,31 +217,42 @@ def parse_options():
def main():
options, server, remote = parse_options()
-
+
password = None
if options.readpass:
- password = getpass.getpass('Enter SSH password: ')
-
+ password = getpass.getpass("Enter SSH password: ")
+
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
- verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1]))
+ verbose("Connecting to ssh host %s:%d ..." % (server[0], server[1]))
try:
- client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
- look_for_keys=options.look_for_keys, password=password)
+ client.connect(
+ server[0],
+ server[1],
+ username=options.user,
+ key_filename=options.keyfile,
+ look_for_keys=options.look_for_keys,
+ password=password,
+ )
except Exception as e:
- print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
+ print("*** Failed to connect to %s:%d: %r" % (server[0], server[1], e))
sys.exit(1)
- verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
+ verbose(
+ "Now forwarding port %d to %s:%d ..."
+ % (options.port, remote[0], remote[1])
+ )
try:
- forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
+ forward_tunnel(
+ options.port, remote[0], remote[1], client.get_transport()
+ )
except KeyboardInterrupt:
- print('C-c: Port forwarding stopped.')
+ print("C-c: Port forwarding stopped.")
sys.exit(0)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/demos/interactive.py b/demos/interactive.py
index 7138cd6c..037787c4 100644
--- a/demos/interactive.py
+++ b/demos/interactive.py
@@ -25,6 +25,7 @@ from paramiko.py3compat import u
try:
import termios
import tty
+
has_termios = True
except ImportError:
has_termios = False
@@ -39,7 +40,7 @@ def interactive_shell(chan):
def posix_shell(chan):
import select
-
+
oldtty = termios.tcgetattr(sys.stdin)
try:
tty.setraw(sys.stdin.fileno())
@@ -52,7 +53,7 @@ def posix_shell(chan):
try:
x = u(chan.recv(1024))
if len(x) == 0:
- sys.stdout.write('\r\n*** EOF\r\n')
+ sys.stdout.write("\r\n*** EOF\r\n")
break
sys.stdout.write(x)
sys.stdout.flush()
@@ -67,26 +68,28 @@ def posix_shell(chan):
finally:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)
-
+
# thanks to Mike Looijmans for this code
def windows_shell(chan):
import threading
- sys.stdout.write("Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n")
-
+ sys.stdout.write(
+ "Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n"
+ )
+
def writeall(sock):
while True:
data = sock.recv(256)
if not data:
- sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
+ sys.stdout.write("\r\n*** EOF ***\r\n\r\n")
sys.stdout.flush()
break
sys.stdout.write(data)
sys.stdout.flush()
-
+
writer = threading.Thread(target=writeall, args=(chan,))
writer.start()
-
+
try:
while True:
d = sys.stdin.read(1)
diff --git a/demos/rforward.py b/demos/rforward.py
index ae70670c..a2e8a776 100755
--- a/demos/rforward.py
+++ b/demos/rforward.py
@@ -47,11 +47,13 @@ def handler(chan, host, port):
try:
sock.connect((host, port))
except Exception as e:
- verbose('Forwarding request to %s:%d failed: %r' % (host, port, e))
+ verbose("Forwarding request to %s:%d failed: %r" % (host, port, e))
return
-
- verbose('Connected! Tunnel open %r -> %r -> %r' % (chan.origin_addr,
- chan.getpeername(), (host, port)))
+
+ verbose(
+ "Connected! Tunnel open %r -> %r -> %r"
+ % (chan.origin_addr, chan.getpeername(), (host, port))
+ )
while True:
r, w, x = select.select([sock, chan], [], [])
if sock in r:
@@ -66,16 +68,18 @@ def handler(chan, host, port):
sock.send(data)
chan.close()
sock.close()
- verbose('Tunnel closed from %r' % (chan.origin_addr,))
+ verbose("Tunnel closed from %r" % (chan.origin_addr,))
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
- transport.request_port_forward('', server_port)
+ transport.request_port_forward("", server_port)
while True:
chan = transport.accept(1000)
if chan is None:
continue
- thr = threading.Thread(target=handler, args=(chan, remote_host, remote_port))
+ thr = threading.Thread(
+ target=handler, args=(chan, remote_host, remote_port)
+ )
thr.setDaemon(True)
thr.start()
@@ -95,40 +99,88 @@ network. This is similar to the openssh -R option.
def get_host_port(spec, default_port):
"parse 'hostname:22' into a host and port, with the port optional"
- args = (spec.split(':', 1) + [default_port])[:2]
+ args = (spec.split(":", 1) + [default_port])[:2]
args[1] = int(args[1])
return args[0], args[1]
def parse_options():
global g_verbose
-
- parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]',
- version='%prog 1.0', description=HELP)
- parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True,
- help='squelch all informational output')
- parser.add_option('-p', '--remote-port', action='store', type='int', dest='port',
- default=DEFAULT_PORT,
- help='port on server to forward (default: %d)' % DEFAULT_PORT)
- parser.add_option('-u', '--user', action='store', type='string', dest='user',
- default=getpass.getuser(),
- help='username for SSH authentication (default: %s)' % getpass.getuser())
- parser.add_option('-K', '--key', action='store', type='string', dest='keyfile',
- default=None,
- help='private key file to use for SSH authentication')
- parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True,
- help='don\'t look for or use a private key file')
- parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False,
- help='read password (for key or password auth) from stdin')
- parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port',
- help='remote host and port to forward to')
+
+ parser = OptionParser(
+ usage="usage: %prog [options] <ssh-server>[:<server-port>]",
+ version="%prog 1.0",
+ description=HELP,
+ )
+ parser.add_option(
+ "-q",
+ "--quiet",
+ action="store_false",
+ dest="verbose",
+ default=True,
+ help="squelch all informational output",
+ )
+ parser.add_option(
+ "-p",
+ "--remote-port",
+ action="store",
+ type="int",
+ dest="port",
+ default=DEFAULT_PORT,
+ help="port on server to forward (default: %d)" % DEFAULT_PORT,
+ )
+ parser.add_option(
+ "-u",
+ "--user",
+ action="store",
+ type="string",
+ dest="user",
+ default=getpass.getuser(),
+ help="username for SSH authentication (default: %s)"
+ % getpass.getuser(),
+ )
+ parser.add_option(
+ "-K",
+ "--key",
+ action="store",
+ type="string",
+ dest="keyfile",
+ default=None,
+ help="private key file to use for SSH authentication",
+ )
+ parser.add_option(
+ "",
+ "--no-key",
+ action="store_false",
+ dest="look_for_keys",
+ default=True,
+ help="don't look for or use a private key file",
+ )
+ parser.add_option(
+ "-P",
+ "--password",
+ action="store_true",
+ dest="readpass",
+ default=False,
+ help="read password (for key or password auth) from stdin",
+ )
+ parser.add_option(
+ "-r",
+ "--remote",
+ action="store",
+ type="string",
+ dest="remote",
+ default=None,
+ metavar="host:port",
+ help="remote host and port to forward to",
+ )
options, args = parser.parse_args()
if len(args) != 1:
- parser.error('Incorrect number of arguments.')
+ parser.error("Incorrect number of arguments.")
if options.remote is None:
- parser.error('Remote address required (-r).')
-
+ parser.error("Remote address required (-r).")
+
g_verbose = options.verbose
server_host, server_port = get_host_port(args[0], SSH_PORT)
remote_host, remote_port = get_host_port(options.remote, SSH_PORT)
@@ -137,31 +189,42 @@ def parse_options():
def main():
options, server, remote = parse_options()
-
+
password = None
if options.readpass:
- password = getpass.getpass('Enter SSH password: ')
-
+ password = getpass.getpass("Enter SSH password: ")
+
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
- verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1]))
+ verbose("Connecting to ssh host %s:%d ..." % (server[0], server[1]))
try:
- client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
- look_for_keys=options.look_for_keys, password=password)
+ client.connect(
+ server[0],
+ server[1],
+ username=options.user,
+ key_filename=options.keyfile,
+ look_for_keys=options.look_for_keys,
+ password=password,
+ )
except Exception as e:
- print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
+ print("*** Failed to connect to %s:%d: %r" % (server[0], server[1], e))
sys.exit(1)
- verbose('Now forwarding remote port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
+ verbose(
+ "Now forwarding remote port %d to %s:%d ..."
+ % (options.port, remote[0], remote[1])
+ )
try:
- reverse_forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
+ reverse_forward_tunnel(
+ options.port, remote[0], remote[1], client.get_transport()
+ )
except KeyboardInterrupt:
- print('C-c: Port forwarding stopped.')
+ print("C-c: Port forwarding stopped.")
sys.exit(0)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/dev-requirements.txt b/dev-requirements.txt
index d41972b9..e4629187 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,6 +1,6 @@
# Invocations for common project tasks
-invoke>=0.13,<2.0
-invocations>=1.0,<2.0
+invoke>=1.0,<2.0
+invocations>=1.2.0,<2.0
# NOTE: pytest-relaxed currently only works with pytest >=3, <3.3
pytest>=3.2,<3.3
pytest-relaxed==1.1.2
diff --git a/paramiko/__init__.py b/paramiko/__init__.py
index e2f66b8b..ebfa72a8 100644
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -21,15 +21,22 @@ import sys
from paramiko._version import __version__, __version_info__
from paramiko.transport import SecurityOptions, Transport
from paramiko.client import (
- SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy,
+ SSHClient,
+ MissingHostKeyPolicy,
+ AutoAddPolicy,
+ RejectPolicy,
WarningPolicy,
)
from paramiko.auth_handler import AuthHandler
from paramiko.ssh_gss import GSSAuth, GSS_AUTH_AVAILABLE, GSS_EXCEPTIONS
from paramiko.channel import Channel, ChannelFile
from paramiko.ssh_exception import (
- SSHException, PasswordRequiredException, BadAuthenticationType,
- ChannelException, BadHostKeyException, AuthenticationException,
+ SSHException,
+ PasswordRequiredException,
+ BadAuthenticationType,
+ ChannelException,
+ BadHostKeyException,
+ AuthenticationException,
ProxyCommandFailure,
)
from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery
@@ -54,14 +61,25 @@ from paramiko.config import SSHConfig
from paramiko.proxy import ProxyCommand
from paramiko.common import (
- AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, OPEN_SUCCEEDED,
- OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED,
- OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE,
+ AUTH_SUCCESSFUL,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ AUTH_FAILED,
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_FAILED_CONNECT_FAILED,
+ OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
+ OPEN_FAILED_RESOURCE_SHORTAGE,
)
from paramiko.sftp import (
- SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE,
- SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST,
+ SFTP_OK,
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_FAILURE,
+ SFTP_BAD_MESSAGE,
+ SFTP_NO_CONNECTION,
+ SFTP_CONNECTION_LOST,
SFTP_OP_UNSUPPORTED,
)
diff --git a/paramiko/_version.py b/paramiko/_version.py
index a5db89b8..95e86b6a 100644
--- a/paramiko/_version.py
+++ b/paramiko/_version.py
@@ -1,2 +1,2 @@
__version_info__ = (2, 4, 1)
-__version__ = '.'.join(map(str, __version_info__))
+__version__ = ".".join(map(str, __version_info__))
diff --git a/paramiko/_winapi.py b/paramiko/_winapi.py
index a13d7e87..ebcc678a 100644
--- a/paramiko/_winapi.py
+++ b/paramiko/_winapi.py
@@ -15,6 +15,7 @@ from paramiko.py3compat import u, builtins
######################
# jaraco.windows.error
+
def format_system_message(errno):
"""
Call FormatMessage with a system error number to retrieve
@@ -77,7 +78,7 @@ class WindowsError(builtins.WindowsError):
return self.message
def __repr__(self):
- return '{self.__class__.__name__}({self.winerror})'.format(**vars())
+ return "{self.__class__.__name__}({self.winerror})".format(**vars())
def handle_nonzero_success(result):
@@ -95,15 +96,15 @@ GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t
GlobalAlloc.restype = ctypes.wintypes.HANDLE
GlobalLock = ctypes.windll.kernel32.GlobalLock
-GlobalLock.argtypes = ctypes.wintypes.HGLOBAL,
+GlobalLock.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalLock.restype = ctypes.wintypes.LPVOID
GlobalUnlock = ctypes.windll.kernel32.GlobalUnlock
-GlobalUnlock.argtypes = ctypes.wintypes.HGLOBAL,
+GlobalUnlock.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalUnlock.restype = ctypes.wintypes.BOOL
GlobalSize = ctypes.windll.kernel32.GlobalSize
-GlobalSize.argtypes = ctypes.wintypes.HGLOBAL,
+GlobalSize.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalSize.restype = ctypes.c_size_t
CreateFileMapping = ctypes.windll.kernel32.CreateFileMappingW
@@ -121,16 +122,12 @@ MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile
MapViewOfFile.restype = ctypes.wintypes.HANDLE
UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
-UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
+UnmapViewOfFile.argtypes = (ctypes.wintypes.HANDLE,)
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
-RtlMoveMemory.argtypes = (
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_size_t,
-)
+RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t)
-ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,
+ctypes.windll.kernel32.LocalFree.argtypes = (ctypes.wintypes.HLOCAL,)
#####################
# jaraco.windows.mmap
@@ -140,6 +137,7 @@ class MemoryMap(object):
"""
A memory map object which can have security attributes overridden.
"""
+
def __init__(self, name, length, security_attributes=None):
self.name = name
self.length = length
@@ -149,14 +147,20 @@ class MemoryMap(object):
def __enter__(self):
p_SA = (
ctypes.byref(self.security_attributes)
- if self.security_attributes else None
+ if self.security_attributes
+ else None
)
INVALID_HANDLE_VALUE = -1
PAGE_READWRITE = 0x4
FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
- INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length,
- u(self.name))
+ INVALID_HANDLE_VALUE,
+ p_SA,
+ PAGE_READWRITE,
+ 0,
+ self.length,
+ u(self.name),
+ )
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping")
@@ -220,41 +224,45 @@ POLICY_LOOKUP_NAMES = 0x00000800
POLICY_NOTIFICATION = 0x00001000
POLICY_ALL_ACCESS = (
- STANDARD_RIGHTS_REQUIRED |
- POLICY_VIEW_LOCAL_INFORMATION |
- POLICY_VIEW_AUDIT_INFORMATION |
- POLICY_GET_PRIVATE_INFORMATION |
- POLICY_TRUST_ADMIN |
- POLICY_CREATE_ACCOUNT |
- POLICY_CREATE_SECRET |
- POLICY_CREATE_PRIVILEGE |
- POLICY_SET_DEFAULT_QUOTA_LIMITS |
- POLICY_SET_AUDIT_REQUIREMENTS |
- POLICY_AUDIT_LOG_ADMIN |
- POLICY_SERVER_ADMIN |
- POLICY_LOOKUP_NAMES)
+ STANDARD_RIGHTS_REQUIRED
+ | POLICY_VIEW_LOCAL_INFORMATION
+ | POLICY_VIEW_AUDIT_INFORMATION
+ | POLICY_GET_PRIVATE_INFORMATION
+ | POLICY_TRUST_ADMIN
+ | POLICY_CREATE_ACCOUNT
+ | POLICY_CREATE_SECRET
+ | POLICY_CREATE_PRIVILEGE
+ | POLICY_SET_DEFAULT_QUOTA_LIMITS
+ | POLICY_SET_AUDIT_REQUIREMENTS
+ | POLICY_AUDIT_LOG_ADMIN
+ | POLICY_SERVER_ADMIN
+ | POLICY_LOOKUP_NAMES
+)
POLICY_READ = (
- STANDARD_RIGHTS_READ |
- POLICY_VIEW_AUDIT_INFORMATION |
- POLICY_GET_PRIVATE_INFORMATION)
+ STANDARD_RIGHTS_READ
+ | POLICY_VIEW_AUDIT_INFORMATION
+ | POLICY_GET_PRIVATE_INFORMATION
+)
POLICY_WRITE = (
- STANDARD_RIGHTS_WRITE |
- POLICY_TRUST_ADMIN |
- POLICY_CREATE_ACCOUNT |
- POLICY_CREATE_SECRET |
- POLICY_CREATE_PRIVILEGE |
- POLICY_SET_DEFAULT_QUOTA_LIMITS |
- POLICY_SET_AUDIT_REQUIREMENTS |
- POLICY_AUDIT_LOG_ADMIN |
- POLICY_SERVER_ADMIN)
+ STANDARD_RIGHTS_WRITE
+ | POLICY_TRUST_ADMIN
+ | POLICY_CREATE_ACCOUNT
+ | POLICY_CREATE_SECRET
+ | POLICY_CREATE_PRIVILEGE
+ | POLICY_SET_DEFAULT_QUOTA_LIMITS
+ | POLICY_SET_AUDIT_REQUIREMENTS
+ | POLICY_AUDIT_LOG_ADMIN
+ | POLICY_SERVER_ADMIN
+)
POLICY_EXECUTE = (
- STANDARD_RIGHTS_EXECUTE |
- POLICY_VIEW_LOCAL_INFORMATION |
- POLICY_LOOKUP_NAMES)
+ STANDARD_RIGHTS_EXECUTE
+ | POLICY_VIEW_LOCAL_INFORMATION
+ | POLICY_LOOKUP_NAMES
+)
class TokenAccess:
@@ -268,8 +276,8 @@ class TokenInformationClass:
class TOKEN_USER(ctypes.Structure):
num = 1
_fields_ = [
- ('SID', ctypes.c_void_p),
- ('ATTRIBUTES', ctypes.wintypes.DWORD),
+ ("SID", ctypes.c_void_p),
+ ("ATTRIBUTES", ctypes.wintypes.DWORD),
]
@@ -290,13 +298,13 @@ class SECURITY_DESCRIPTOR(ctypes.Structure):
REVISION = 1
_fields_ = [
- ('Revision', ctypes.c_ubyte),
- ('Sbz1', ctypes.c_ubyte),
- ('Control', SECURITY_DESCRIPTOR_CONTROL),
- ('Owner', ctypes.c_void_p),
- ('Group', ctypes.c_void_p),
- ('Sacl', ctypes.c_void_p),
- ('Dacl', ctypes.c_void_p),
+ ("Revision", ctypes.c_ubyte),
+ ("Sbz1", ctypes.c_ubyte),
+ ("Control", SECURITY_DESCRIPTOR_CONTROL),
+ ("Owner", ctypes.c_void_p),
+ ("Group", ctypes.c_void_p),
+ ("Sacl", ctypes.c_void_p),
+ ("Dacl", ctypes.c_void_p),
]
@@ -309,9 +317,9 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
} SECURITY_ATTRIBUTES;
"""
_fields_ = [
- ('nLength', ctypes.wintypes.DWORD),
- ('lpSecurityDescriptor', ctypes.c_void_p),
- ('bInheritHandle', ctypes.wintypes.BOOL),
+ ("nLength", ctypes.wintypes.DWORD),
+ ("lpSecurityDescriptor", ctypes.c_void_p),
+ ("bInheritHandle", ctypes.wintypes.BOOL),
]
def __init__(self, *args, **kwargs):
@@ -343,21 +351,30 @@ def GetTokenInformation(token, information_class):
Given a token, get the token information for it.
"""
data_size = ctypes.wintypes.DWORD()
- ctypes.windll.advapi32.GetTokenInformation(token, information_class.num,
- 0, 0, ctypes.byref(data_size))
+ ctypes.windll.advapi32.GetTokenInformation(
+ token, information_class.num, 0, 0, ctypes.byref(data_size)
+ )
data = ctypes.create_string_buffer(data_size.value)
- handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token,
- information_class.num,
- ctypes.byref(data), ctypes.sizeof(data),
- ctypes.byref(data_size)))
+ handle_nonzero_success(
+ ctypes.windll.advapi32.GetTokenInformation(
+ token,
+ information_class.num,
+ ctypes.byref(data),
+ ctypes.sizeof(data),
+ ctypes.byref(data_size),
+ )
+ )
return ctypes.cast(data, ctypes.POINTER(TOKEN_USER)).contents
def OpenProcessToken(proc_handle, access):
result = ctypes.wintypes.HANDLE()
proc_handle = ctypes.wintypes.HANDLE(proc_handle)
- handle_nonzero_success(ctypes.windll.advapi32.OpenProcessToken(
- proc_handle, access, ctypes.byref(result)))
+ handle_nonzero_success(
+ ctypes.windll.advapi32.OpenProcessToken(
+ proc_handle, access, ctypes.byref(result)
+ )
+ )
return result
@@ -366,8 +383,7 @@ def get_current_user():
Return a TOKEN_USER for the owner of this process.
"""
process = OpenProcessToken(
- ctypes.windll.kernel32.GetCurrentProcess(),
- TokenAccess.TOKEN_QUERY,
+ ctypes.windll.kernel32.GetCurrentProcess(), TokenAccess.TOKEN_QUERY
)
return GetTokenInformation(process, TOKEN_USER)
@@ -389,8 +405,10 @@ def get_security_attributes_for_user(user=None):
SA.descriptor = SD
SA.bInheritHandle = 1
- ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD),
- SECURITY_DESCRIPTOR.REVISION)
- ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD),
- user.SID, 0)
+ ctypes.windll.advapi32.InitializeSecurityDescriptor(
+ ctypes.byref(SD), SECURITY_DESCRIPTOR.REVISION
+ )
+ ctypes.windll.advapi32.SetSecurityDescriptorOwner(
+ ctypes.byref(SD), user.SID, 0
+ )
return SA
diff --git a/paramiko/agent.py b/paramiko/agent.py
index 7a4dde21..62a271d5 100644
--- a/paramiko/agent.py
+++ b/paramiko/agent.py
@@ -43,8 +43,8 @@ cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
SSH2_AGENT_SIGN_RESPONSE = 14
-
class AgentSSH(object):
+
def __init__(self):
self._conn = None
self._keys = ()
@@ -65,7 +65,7 @@ class AgentSSH(object):
self._conn = conn
ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
- raise SSHException('could not get keys from ssh-agent')
+ raise SSHException("could not get keys from ssh-agent")
keys = []
for i in range(result.get_int()):
keys.append(AgentKey(self, result.get_binary()))
@@ -80,19 +80,19 @@ class AgentSSH(object):
def _send_message(self, msg):
msg = asbytes(msg)
- self._conn.send(struct.pack('>I', len(msg)) + msg)
+ self._conn.send(struct.pack(">I", len(msg)) + msg)
l = self._read_all(4)
- msg = Message(self._read_all(struct.unpack('>I', l)[0]))
+ msg = Message(self._read_all(struct.unpack(">I", l)[0]))
return ord(msg.get_byte()), msg
def _read_all(self, wanted):
result = self._conn.recv(wanted)
while len(result) < wanted:
if len(result) == 0:
- raise SSHException('lost ssh-agent')
+ raise SSHException("lost ssh-agent")
extra = self._conn.recv(wanted - len(result))
if len(extra) == 0:
- raise SSHException('lost ssh-agent')
+ raise SSHException("lost ssh-agent")
result += extra
return result
@@ -101,6 +101,7 @@ class AgentProxyThread(threading.Thread):
"""
Class in charge of communication between two channels.
"""
+
def __init__(self, agent):
threading.Thread.__init__(self, target=self.run)
self._agent = agent
@@ -115,12 +116,9 @@ class AgentProxyThread(threading.Thread):
# The address should be an IP address as a string? or None
self.__addr = addr
self._agent.connect()
- if (
- not isinstance(self._agent, int) and
- (
- self._agent._conn is None or
- not hasattr(self._agent._conn, 'fileno')
- )
+ if not isinstance(self._agent, int) and (
+ self._agent._conn is None
+ or not hasattr(self._agent._conn, "fileno")
):
raise AuthenticationException("Unable to connect to SSH agent")
self._communicate()
@@ -130,6 +128,7 @@ class AgentProxyThread(threading.Thread):
def _communicate(self):
import fcntl
+
oldflags = fcntl.fcntl(self.__inr, fcntl.F_GETFL)
fcntl.fcntl(self.__inr, fcntl.F_SETFL, oldflags | os.O_NONBLOCK)
while not self._exit:
@@ -162,6 +161,7 @@ class AgentLocalProxy(AgentProxyThread):
Class to be used when wanting to ask a local SSH Agent being
asked from a remote fake agent (so use a unix socket for ex.)
"""
+
def __init__(self, agent):
AgentProxyThread.__init__(self, agent)
@@ -185,6 +185,7 @@ class AgentRemoteProxy(AgentProxyThread):
"""
Class to be used when wanting to ask a remote SSH Agent
"""
+
def __init__(self, agent, chan):
AgentProxyThread.__init__(self, agent)
self.__chan = chan
@@ -205,6 +206,7 @@ class AgentClientProxy(object):
the remote fake agent and the local agent
#. Communication occurs ...
"""
+
def __init__(self, chanRemote):
self._conn = None
self.__chanR = chanRemote
@@ -218,16 +220,18 @@ class AgentClientProxy(object):
"""
Method automatically called by ``AgentProxyThread.run``.
"""
- if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
+ if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
retry_on_signal(
- lambda: conn.connect(os.environ['SSH_AUTH_SOCK']))
+ lambda: conn.connect(os.environ["SSH_AUTH_SOCK"])
+ )
except:
# probably a dangling env var: the ssh agent is gone
return
- elif sys.platform == 'win32':
+ elif sys.platform == "win32":
import paramiko.win_pageant as win_pageant
+
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@@ -255,12 +259,13 @@ class AgentServerProxy(AgentSSH):
:raises: `.SSHException` -- mostly if we lost the agent
"""
+
def __init__(self, t):
AgentSSH.__init__(self)
self.__t = t
- self._dir = tempfile.mkdtemp('sshproxy')
+ self._dir = tempfile.mkdtemp("sshproxy")
os.chmod(self._dir, stat.S_IRWXU)
- self._file = self._dir + '/sshproxy.ssh'
+ self._file = self._dir + "/sshproxy.ssh"
self.thread = AgentLocalProxy(self)
self.thread.start()
@@ -270,8 +275,8 @@ class AgentServerProxy(AgentSSH):
def connect(self):
conn_sock = self.__t.open_forward_agent_channel()
if conn_sock is None:
- raise SSHException('lost ssh-agent')
- conn_sock.set_name('auth-agent')
+ raise SSHException("lost ssh-agent")
+ conn_sock.set_name("auth-agent")
self._connect(conn_sock)
def close(self):
@@ -292,7 +297,7 @@ class AgentServerProxy(AgentSSH):
:return:
a dict containing the ``SSH_AUTH_SOCK`` environnement variables
"""
- return {'SSH_AUTH_SOCK': self._get_filename()}
+ return {"SSH_AUTH_SOCK": self._get_filename()}
def _get_filename(self):
return self._file
@@ -319,6 +324,7 @@ class AgentRequestHandler(object):
# the remote end.
session.exec_command("git clone https://my.git.repository/")
"""
+
def __init__(self, chanClient):
self._conn = None
self.__chanC = chanClient
@@ -350,18 +356,20 @@ class Agent(AgentSSH):
:raises: `.SSHException` --
if an SSH agent is found, but speaks an incompatible protocol
"""
+
def __init__(self):
AgentSSH.__init__(self)
- if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
+ if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
- conn.connect(os.environ['SSH_AUTH_SOCK'])
+ conn.connect(os.environ["SSH_AUTH_SOCK"])
except:
# probably a dangling env var: the ssh agent is gone
return
- elif sys.platform == 'win32':
+ elif sys.platform == "win32":
from . import win_pageant
+
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@@ -384,6 +392,7 @@ class AgentKey(PKey):
authenticating to a remote server (signing). Most other key operations
work as expected.
"""
+
def __init__(self, agent, blob):
self.agent = agent
self.blob = blob
@@ -407,5 +416,5 @@ class AgentKey(PKey):
msg.add_int(0)
ptype, result = self.agent._send_message(msg)
if ptype != SSH2_AGENT_SIGN_RESPONSE:
- raise SSHException('key cannot be used for signing')
+ raise SSHException("key cannot be used for signing")
return result.get_binary()
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py
index 3b894de7..41724832 100644
--- a/paramiko/auth_handler.py
+++ b/paramiko/auth_handler.py
@@ -24,31 +24,55 @@ import weakref
import time
from paramiko.common import (
- cMSG_SERVICE_REQUEST, cMSG_DISCONNECT, DISCONNECT_SERVICE_NOT_AVAILABLE,
- DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, cMSG_USERAUTH_REQUEST,
- cMSG_SERVICE_ACCEPT, DEBUG, AUTH_SUCCESSFUL, INFO, cMSG_USERAUTH_SUCCESS,
- cMSG_USERAUTH_FAILURE, AUTH_PARTIALLY_SUCCESSFUL,
- cMSG_USERAUTH_INFO_REQUEST, WARNING, AUTH_FAILED, cMSG_USERAUTH_PK_OK,
- cMSG_USERAUTH_INFO_RESPONSE, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT,
- MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS, MSG_USERAUTH_FAILURE,
- MSG_USERAUTH_BANNER, MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE,
- cMSG_USERAUTH_GSSAPI_RESPONSE, cMSG_USERAUTH_GSSAPI_TOKEN,
- cMSG_USERAUTH_GSSAPI_MIC, MSG_USERAUTH_GSSAPI_RESPONSE,
- MSG_USERAUTH_GSSAPI_TOKEN, MSG_USERAUTH_GSSAPI_ERROR,
- MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC, MSG_NAMES,
- cMSG_USERAUTH_BANNER
+ cMSG_SERVICE_REQUEST,
+ cMSG_DISCONNECT,
+ DISCONNECT_SERVICE_NOT_AVAILABLE,
+ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
+ cMSG_USERAUTH_REQUEST,
+ cMSG_SERVICE_ACCEPT,
+ DEBUG,
+ AUTH_SUCCESSFUL,
+ INFO,
+ cMSG_USERAUTH_SUCCESS,
+ cMSG_USERAUTH_FAILURE,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ cMSG_USERAUTH_INFO_REQUEST,
+ WARNING,
+ AUTH_FAILED,
+ cMSG_USERAUTH_PK_OK,
+ cMSG_USERAUTH_INFO_RESPONSE,
+ MSG_SERVICE_REQUEST,
+ MSG_SERVICE_ACCEPT,
+ MSG_USERAUTH_REQUEST,
+ MSG_USERAUTH_SUCCESS,
+ MSG_USERAUTH_FAILURE,
+ MSG_USERAUTH_BANNER,
+ MSG_USERAUTH_INFO_REQUEST,
+ MSG_USERAUTH_INFO_RESPONSE,
+ cMSG_USERAUTH_GSSAPI_RESPONSE,
+ cMSG_USERAUTH_GSSAPI_TOKEN,
+ cMSG_USERAUTH_GSSAPI_MIC,
+ MSG_USERAUTH_GSSAPI_RESPONSE,
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_GSSAPI_ERROR,
+ MSG_USERAUTH_GSSAPI_ERRTOK,
+ MSG_USERAUTH_GSSAPI_MIC,
+ MSG_NAMES,
+ cMSG_USERAUTH_BANNER,
)
from paramiko.message import Message
from paramiko.py3compat import b
from paramiko.ssh_exception import (
- SSHException, AuthenticationException, BadAuthenticationType,
+ SSHException,
+ AuthenticationException,
+ BadAuthenticationType,
PartialAuthentication,
)
from paramiko.server import InteractiveQuery
from paramiko.ssh_gss import GSSAuth, GSS_EXCEPTIONS
-class AuthHandler (object):
+class AuthHandler(object):
"""
Internal class to handle the mechanics of authentication.
"""
@@ -58,7 +82,7 @@ class AuthHandler (object):
self.username = None
self.authenticated = False
self.auth_event = None
- self.auth_method = ''
+ self.auth_method = ""
self.banner = None
self.password = None
self.private_key = None
@@ -87,7 +111,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'none'
+ self.auth_method = "none"
self.username = username
self._request_auth()
finally:
@@ -97,7 +121,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'publickey'
+ self.auth_method = "publickey"
self.username = username
self.private_key = key
self._request_auth()
@@ -108,21 +132,21 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'password'
+ self.auth_method = "password"
self.username = username
self.password = password
self._request_auth()
finally:
self.transport.lock.release()
- def auth_interactive(self, username, handler, event, submethods=''):
+ def auth_interactive(self, username, handler, event, submethods=""):
"""
response_list = handler(title, instructions, prompt_list)
"""
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'keyboard-interactive'
+ self.auth_method = "keyboard-interactive"
self.username = username
self.interactive_handler = handler
self.submethods = submethods
@@ -134,7 +158,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'gssapi-with-mic'
+ self.auth_method = "gssapi-with-mic"
self.username = username
self.gss_host = gss_host
self.gss_deleg_creds = gss_deleg_creds
@@ -146,7 +170,7 @@ class AuthHandler (object):
self.transport.lock.acquire()
try:
self.auth_event = event
- self.auth_method = 'gssapi-keyex'
+ self.auth_method = "gssapi-keyex"
self.username = username
self._request_auth()
finally:
@@ -161,15 +185,15 @@ class AuthHandler (object):
def _request_auth(self):
m = Message()
m.add_byte(cMSG_SERVICE_REQUEST)
- m.add_string('ssh-userauth')
+ m.add_string("ssh-userauth")
self.transport._send_message(m)
def _disconnect_service_not_available(self):
m = Message()
m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
- m.add_string('Service not available')
- m.add_string('en')
+ m.add_string("Service not available")
+ m.add_string("en")
self.transport._send_message(m)
self.transport.close()
@@ -177,8 +201,8 @@ class AuthHandler (object):
m = Message()
m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
- m.add_string('No more auth methods available')
- m.add_string('en')
+ m.add_string("No more auth methods available")
+ m.add_string("en")
self.transport._send_message(m)
self.transport.close()
@@ -188,7 +212,7 @@ class AuthHandler (object):
m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(username)
m.add_string(service)
- m.add_string('publickey')
+ m.add_string("publickey")
m.add_boolean(True)
# Use certificate contents, if available, plain pubkey otherwise
if key.public_blob:
@@ -208,17 +232,17 @@ class AuthHandler (object):
if not self.transport.is_active():
e = self.transport.get_exception()
if (e is None) or issubclass(e.__class__, EOFError):
- e = AuthenticationException('Authentication failed.')
+ e = AuthenticationException("Authentication failed.")
raise e
if event.is_set():
break
if max_ts is not None and max_ts <= time.time():
- raise AuthenticationException('Authentication timeout.')
+ raise AuthenticationException("Authentication timeout.")
if not self.is_authenticated():
e = self.transport.get_exception()
if e is None:
- e = AuthenticationException('Authentication failed.')
+ e = AuthenticationException("Authentication failed.")
# this is horrible. Python Exception isn't yet descended from
# object, so type(e) won't work. :(
if issubclass(e.__class__, PartialAuthentication):
@@ -228,7 +252,7 @@ class AuthHandler (object):
def _parse_service_request(self, m):
service = m.get_text()
- if self.transport.server_mode and (service == 'ssh-userauth'):
+ if self.transport.server_mode and (service == "ssh-userauth"):
# accepted
m = Message()
m.add_byte(cMSG_SERVICE_ACCEPT)
@@ -247,18 +271,18 @@ class AuthHandler (object):
def _parse_service_accept(self, m):
service = m.get_text()
- if service == 'ssh-userauth':
- self._log(DEBUG, 'userauth is OK')
+ if service == "ssh-userauth":
+ self._log(DEBUG, "userauth is OK")
m = Message()
m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(self.username)
- m.add_string('ssh-connection')
+ m.add_string("ssh-connection")
m.add_string(self.auth_method)
- if self.auth_method == 'password':
+ if self.auth_method == "password":
m.add_boolean(False)
password = b(self.password)
m.add_string(password)
- elif self.auth_method == 'publickey':
+ elif self.auth_method == "publickey":
m.add_boolean(True)
# Use certificate contents, if available, plain pubkey
# otherwise
@@ -269,11 +293,12 @@ class AuthHandler (object):
m.add_string(self.private_key.get_name())
m.add_string(self.private_key)
blob = self._get_session_blob(
- self.private_key, 'ssh-connection', self.username)
+ self.private_key, "ssh-connection", self.username
+ )
sig = self.private_key.sign_ssh_data(blob)
m.add_string(sig)
- elif self.auth_method == 'keyboard-interactive':
- m.add_string('')
+ elif self.auth_method == "keyboard-interactive":
+ m.add_string("")
m.add_string(self.submethods)
elif self.auth_method == "gssapi-with-mic":
sshgss = GSSAuth(self.auth_method, self.gss_deleg_creds)
@@ -292,10 +317,11 @@ class AuthHandler (object):
m = Message()
m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
try:
- m.add_string(sshgss.ssh_init_sec_context(
- self.gss_host,
- mech,
- self.username,))
+ m.add_string(
+ sshgss.ssh_init_sec_context(
+ self.gss_host, mech, self.username
+ )
+ )
except GSS_EXCEPTIONS as e:
return self._handle_local_gss_failure(e)
self.transport._send_message(m)
@@ -308,7 +334,8 @@ class AuthHandler (object):
self.gss_host,
mech,
self.username,
- srv_token)
+ srv_token,
+ )
except GSS_EXCEPTIONS as e:
return self._handle_local_gss_failure(e)
# After this step the GSSAPI should not return any
@@ -340,48 +367,55 @@ class AuthHandler (object):
min_status = m.get_int()
err_msg = m.get_string()
m.get_string() # Lang tag - discarded
- raise SSHException("""GSS-API Error:
+ raise SSHException(
+ """GSS-API Error:
Major Status: {}
Minor Status: {}
Error Message: {}
-""".format(maj_status, min_status, err_msg))
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
elif ptype == MSG_USERAUTH_FAILURE:
self._parse_userauth_failure(m)
return
else:
raise SSHException(
- "Received Package: {}".format(MSG_NAMES[ptype]))
+ "Received Package: {}".format(MSG_NAMES[ptype])
+ )
elif (
- self.auth_method == 'gssapi-keyex' and
- self.transport.gss_kex_used
+ self.auth_method == "gssapi-keyex"
+ and self.transport.gss_kex_used
):
kexgss = self.transport.kexgss_ctxt
kexgss.set_username(self.username)
mic_token = kexgss.ssh_get_mic(self.transport.session_id)
m.add_string(mic_token)
- elif self.auth_method == 'none':
+ elif self.auth_method == "none":
pass
else:
raise SSHException(
- 'Unknown auth method "{}"'.format(self.auth_method))
+ 'Unknown auth method "{}"'.format(self.auth_method)
+ )
self.transport._send_message(m)
else:
self._log(
- DEBUG,
- 'Service request "{}" accepted (?)'.format(service))
+ DEBUG, 'Service request "{}" accepted (?)'.format(service)
+ )
def _send_auth_result(self, username, method, result):
# okay, send result
m = Message()
if result == AUTH_SUCCESSFUL:
- self._log(INFO, 'Auth granted ({}).'.format(method))
+ self._log(INFO, "Auth granted ({}).".format(method))
m.add_byte(cMSG_USERAUTH_SUCCESS)
self.authenticated = True
else:
- self._log(INFO, 'Auth rejected ({}).'.format(method))
+ self._log(INFO, "Auth rejected ({}).".format(method))
m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string(
- self.transport.server_object.get_allowed_auths(username))
+ self.transport.server_object.get_allowed_auths(username)
+ )
if result == AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(True)
else:
@@ -411,7 +445,7 @@ Error Message: {}
# er, uh... what?
m = Message()
m.add_byte(cMSG_USERAUTH_FAILURE)
- m.add_string('none')
+ m.add_string("none")
m.add_boolean(False)
self.transport._send_message(m)
return
@@ -423,18 +457,19 @@ Error Message: {}
method = m.get_text()
self._log(
DEBUG,
- 'Auth request (type={}) service={}, username={}'.format(
+ "Auth request (type={}) service={}, username={}".format(
method, service, username
- )
+ ),
)
- if service != 'ssh-connection':
+ if service != "ssh-connection":
self._disconnect_service_not_available()
return
- if ((self.auth_username is not None) and
- (self.auth_username != username)):
+ if (self.auth_username is not None) and (
+ self.auth_username != username
+ ):
self._log(
WARNING,
- 'Auth rejected because the client attempted to change username in mid-flight' # noqa
+ "Auth rejected because the client attempted to change username in mid-flight", # noqa
)
self._disconnect_no_more_auth()
return
@@ -442,13 +477,13 @@ Error Message: {}
# check if GSS-API authentication is enabled
gss_auth = self.transport.server_object.enable_auth_gssapi()
- if method == 'none':
+ if method == "none":
result = self.transport.server_object.check_auth_none(username)
- elif method == 'password':
+ elif method == "password":
changereq = m.get_boolean()
password = m.get_binary()
try:
- password = password.decode('UTF-8')
+ password = password.decode("UTF-8")
except UnicodeError:
# some clients/servers expect non-utf-8 passwords!
# in this case, just return the raw byte string.
@@ -457,31 +492,30 @@ Error Message: {}
# always treated as failure, since we don't support changing
# passwords, but collect the list of valid auth types from
# the callback anyway
- self._log(
- DEBUG,
- 'Auth request to change passwords (rejected)')
+ self._log(DEBUG, "Auth request to change passwords (rejected)")
newpassword = m.get_binary()
try:
- newpassword = newpassword.decode('UTF-8', 'replace')
+ newpassword = newpassword.decode("UTF-8", "replace")
except UnicodeError:
pass
result = AUTH_FAILED
else:
result = self.transport.server_object.check_auth_password(
- username, password)
- elif method == 'publickey':
+ username, password
+ )
+ elif method == "publickey":
sig_attached = m.get_boolean()
keytype = m.get_text()
keyblob = m.get_binary()
try:
key = self.transport._key_info[keytype](Message(keyblob))
except SSHException as e:
- self._log(
- INFO,
- 'Auth rejected: public key: {}'.format(str(e)))
+ self._log(INFO, "Auth rejected: public key: {}".format(str(e)))
key = None
except Exception as e:
- msg = 'Auth rejected: unsupported or mangled public key ({}: {})' # noqa
+ msg = (
+ "Auth rejected: unsupported or mangled public key ({}: {})"
+ ) # noqa
self._log(INFO, msg.format(e.__class__.__name__, e))
key = None
if key is None:
@@ -489,7 +523,8 @@ Error Message: {}
return
# first check if this key is okay... if not, we can skip the verify
result = self.transport.server_object.check_auth_publickey(
- username, key)
+ username, key
+ )
if result != AUTH_FAILED:
# key is okay, verify it
if not sig_attached:
@@ -504,14 +539,13 @@ Error Message: {}
sig = Message(m.get_binary())
blob = self._get_session_blob(key, service, username)
if not key.verify_ssh_sig(blob, sig):
- self._log(
- INFO,
- 'Auth rejected: invalid signature')
+ self._log(INFO, "Auth rejected: invalid signature")
result = AUTH_FAILED
- elif method == 'keyboard-interactive':
+ elif method == "keyboard-interactive":
submethods = m.get_string()
result = self.transport.server_object.check_auth_interactive(
- username, submethods)
+ username, submethods
+ )
if isinstance(result, InteractiveQuery):
# make interactive query instead of response
self._interactive_query(result)
@@ -527,7 +561,8 @@ Error Message: {}
if mechs > 1:
self._log(
INFO,
- 'Disconnect: Received more than one GSS-API OID mechanism')
+ "Disconnect: Received more than one GSS-API OID mechanism",
+ )
self._disconnect_no_more_auth()
desired_mech = m.get_string()
mech_ok = sshgss.ssh_check_mech(desired_mech)
@@ -535,7 +570,8 @@ Error Message: {}
if not mech_ok:
self._log(
INFO,
- 'Disconnect: Received an invalid GSS-API OID mechanism')
+ "Disconnect: Received an invalid GSS-API OID mechanism",
+ )
self._disconnect_no_more_auth()
# send the Kerberos V5 GSSAPI OID to the client
supported_mech = sshgss.ssh_gss_oids("server")
@@ -544,11 +580,14 @@ Error Message: {}
m = Message()
m.add_byte(cMSG_USERAUTH_GSSAPI_RESPONSE)
m.add_bytes(supported_mech)
- self.transport.auth_handler = GssapiWithMicAuthHandler(self,
- sshgss)
- self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN,
- MSG_USERAUTH_REQUEST,
- MSG_SERVICE_REQUEST)
+ self.transport.auth_handler = GssapiWithMicAuthHandler(
+ self, sshgss
+ )
+ self.transport._expected_packet = (
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_REQUEST,
+ MSG_SERVICE_REQUEST,
+ )
self.transport._send_message(m)
return
elif method == "gssapi-keyex" and gss_auth:
@@ -559,16 +598,17 @@ Error Message: {}
result = AUTH_FAILED
self._send_auth_result(username, method, result)
try:
- sshgss.ssh_check_mic(mic_token,
- self.transport.session_id,
- self.auth_username)
+ sshgss.ssh_check_mic(
+ mic_token, self.transport.session_id, self.auth_username
+ )
except Exception:
result = AUTH_FAILED
self._send_auth_result(username, method, result)
raise
result = AUTH_SUCCESSFUL
self.transport.server_object.check_auth_gssapi_keyex(
- username, result)
+ username, result
+ )
else:
result = self.transport.server_object.check_auth_none(username)
# okay, send result
@@ -576,8 +616,8 @@ Error Message: {}
def _parse_userauth_success(self, m):
self._log(
- INFO,
- 'Authentication ({}) successful!'.format(self.auth_method))
+ INFO, "Authentication ({}) successful!".format(self.auth_method)
+ )
self.authenticated = True
self.transport._auth_trigger()
if self.auth_event is not None:
@@ -587,24 +627,23 @@ Error Message: {}
authlist = m.get_list()
partial = m.get_boolean()
if partial:
- self._log(INFO, 'Authentication continues...')
- self._log(DEBUG, 'Methods: ' + str(authlist))
+ self._log(INFO, "Authentication continues...")
+ self._log(DEBUG, "Methods: " + str(authlist))
self.transport.saved_exception = PartialAuthentication(authlist)
elif self.auth_method not in authlist:
for msg in (
- 'Authentication type ({}) not permitted.'.format(
+ "Authentication type ({}) not permitted.".format(
self.auth_method
),
- 'Allowed methods: {}'.format(authlist),
+ "Allowed methods: {}".format(authlist),
):
self._log(DEBUG, msg)
self.transport.saved_exception = BadAuthenticationType(
- 'Bad authentication type', authlist
+ "Bad authentication type", authlist
)
else:
self._log(
- INFO,
- 'Authentication ({}) failed.'.format(self.auth_method)
+ INFO, "Authentication ({}) failed.".format(self.auth_method)
)
self.authenticated = False
self.username = None
@@ -614,12 +653,12 @@ Error Message: {}
def _parse_userauth_banner(self, m):
banner = m.get_string()
self.banner = banner
- self._log(INFO, 'Auth banner: {}'.format(banner))
+ self._log(INFO, "Auth banner: {}".format(banner))
# who cares.
def _parse_userauth_info_request(self, m):
- if self.auth_method != 'keyboard-interactive':
- raise SSHException('Illegal info request from server')
+ if self.auth_method != "keyboard-interactive":
+ raise SSHException("Illegal info request from server")
title = m.get_text()
instructions = m.get_text()
m.get_binary() # lang
@@ -628,7 +667,8 @@ Error Message: {}
for i in range(prompts):
prompt_list.append((m.get_text(), m.get_boolean()))
response_list = self.interactive_handler(
- title, instructions, prompt_list)
+ title, instructions, prompt_list
+ )
m = Message()
m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
@@ -639,25 +679,26 @@ Error Message: {}
def _parse_userauth_info_response(self, m):
if not self.transport.server_mode:
- raise SSHException('Illegal info response from server')
+ raise SSHException("Illegal info response from server")
n = m.get_int()
responses = []
for i in range(n):
responses.append(m.get_text())
result = self.transport.server_object.check_auth_interactive_response(
- responses)
+ responses
+ )
if isinstance(result, InteractiveQuery):
# make interactive query instead of response
self._interactive_query(result)
return
self._send_auth_result(
- self.auth_username, 'keyboard-interactive', result)
+ self.auth_username, "keyboard-interactive", result
+ )
def _handle_local_gss_failure(self, e):
self.transport.saved_exception = e
self._log(DEBUG, "GSSAPI failure: {}".format(e))
- self._log(INFO, 'Authentication ({}) failed.'.format(
- self.auth_method))
+ self._log(INFO, "Authentication ({}) failed.".format(self.auth_method))
self.authenticated = False
self.username = None
if self.auth_event is not None:
@@ -718,9 +759,9 @@ class GssapiWithMicAuthHandler(object):
# context.
sshgss = self.sshgss
try:
- token = sshgss.ssh_accept_sec_context(self.gss_host,
- client_token,
- self.auth_username)
+ token = sshgss.ssh_accept_sec_context(
+ self.gss_host, client_token, self.auth_username
+ )
except Exception as e:
self.transport.saved_exception = e
result = AUTH_FAILED
@@ -731,9 +772,11 @@ class GssapiWithMicAuthHandler(object):
m = Message()
m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
m.add_string(token)
- self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN,
- MSG_USERAUTH_GSSAPI_MIC,
- MSG_USERAUTH_REQUEST)
+ self.transport._expected_packet = (
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_GSSAPI_MIC,
+ MSG_USERAUTH_REQUEST,
+ )
self.transport._send_message(m)
def _parse_userauth_gssapi_mic(self, m):
@@ -742,9 +785,9 @@ class GssapiWithMicAuthHandler(object):
username = self.auth_username
self._restore_delegate_auth_handler()
try:
- sshgss.ssh_check_mic(mic_token,
- self.transport.session_id,
- username)
+ sshgss.ssh_check_mic(
+ mic_token, self.transport.session_id, username
+ )
except Exception as e:
self.transport.saved_exception = e
result = AUTH_FAILED
@@ -754,8 +797,9 @@ class GssapiWithMicAuthHandler(object):
# The OpenSSH server is able to create a TGT with the delegated
# client credentials, but this is not supported by GSS-API.
result = AUTH_SUCCESSFUL
- self.transport.server_object.check_auth_gssapi_with_mic(username,
- result)
+ self.transport.server_object.check_auth_gssapi_with_mic(
+ username, result
+ )
# okay, send result
self._send_auth_result(username, self.method, result)
diff --git a/paramiko/ber.py b/paramiko/ber.py
index 876347e0..92d7121e 100644
--- a/paramiko/ber.py
+++ b/paramiko/ber.py
@@ -21,7 +21,7 @@ from paramiko.py3compat import b, byte_ord, byte_chr, long
import paramiko.util as util
-class BERException (Exception):
+class BERException(Exception):
pass
@@ -41,7 +41,7 @@ class BER(object):
return self.asbytes()
def __repr__(self):
- return 'BER(\'' + repr(self.content) + '\')'
+ return "BER('" + repr(self.content) + "')"
def decode(self):
return self.decode_next()
@@ -72,12 +72,13 @@ class BER(object):
if self.idx + t > len(self.content):
return None
size = util.inflate_long(
- self.content[self.idx: self.idx + t], True)
+ self.content[self.idx : self.idx + t], True
+ )
self.idx += t
if self.idx + size > len(self.content):
# can't fit
return None
- data = self.content[self.idx: self.idx + size]
+ data = self.content[self.idx : self.idx + size]
self.idx += size
# now switch on id
if ident == 0x30:
@@ -88,7 +89,7 @@ class BER(object):
return util.inflate_long(data)
else:
# 1: boolean (00 false, otherwise true)
- msg = 'Unknown ber encoding type {:d} (robey is lazy)'
+ msg = "Unknown ber encoding type {:d} (robey is lazy)"
raise BERException(msg.format(ident))
@staticmethod
@@ -126,7 +127,7 @@ class BER(object):
self.encode_tlv(0x30, self.encode_sequence(x))
else:
raise BERException(
- 'Unknown type for encoding: {!r}'.format(type(x))
+ "Unknown type for encoding: {!r}".format(type(x))
)
@staticmethod
diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py
index d9f5149d..48f5aae5 100644
--- a/paramiko/buffered_pipe.py
+++ b/paramiko/buffered_pipe.py
@@ -28,14 +28,14 @@ import time
from paramiko.py3compat import PY2, b
-class PipeTimeout (IOError):
+class PipeTimeout(IOError):
"""
Indicates that a timeout was reached on a read from a `.BufferedPipe`.
"""
pass
-class BufferedPipe (object):
+class BufferedPipe(object):
"""
A buffer that obeys normal read (with timeout) & close semantics for a
file or socket, but is fed data from another thread. This is used by
@@ -46,16 +46,19 @@ class BufferedPipe (object):
self._lock = threading.Lock()
self._cv = threading.Condition(self._lock)
self._event = None
- self._buffer = array.array('B')
+ self._buffer = array.array("B")
self._closed = False
if PY2:
+
def _buffer_frombytes(self, data):
self._buffer.fromstring(data)
def _buffer_tobytes(self, limit=None):
return self._buffer[:limit].tostring()
+
else:
+
def _buffer_frombytes(self, data):
self._buffer.frombytes(data)
diff --git a/paramiko/channel.py b/paramiko/channel.py
index 91a8f0df..00f86d6e 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -25,14 +25,22 @@ import os
import socket
import time
import threading
+
# TODO: switch as much of py3compat.py to 'six' as possible, then use six.wraps
from functools import wraps
from paramiko import util
from paramiko.common import (
- cMSG_CHANNEL_REQUEST, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_CHANNEL_DATA,
- cMSG_CHANNEL_EXTENDED_DATA, DEBUG, ERROR, cMSG_CHANNEL_SUCCESS,
- cMSG_CHANNEL_FAILURE, cMSG_CHANNEL_EOF, cMSG_CHANNEL_CLOSE,
+ cMSG_CHANNEL_REQUEST,
+ cMSG_CHANNEL_WINDOW_ADJUST,
+ cMSG_CHANNEL_DATA,
+ cMSG_CHANNEL_EXTENDED_DATA,
+ DEBUG,
+ ERROR,
+ cMSG_CHANNEL_SUCCESS,
+ cMSG_CHANNEL_FAILURE,
+ cMSG_CHANNEL_EOF,
+ cMSG_CHANNEL_CLOSE,
)
from paramiko.message import Message
from paramiko.py3compat import bytes_types
@@ -51,20 +59,22 @@ def open_only(func):
`.SSHException` -- If the wrapped method is called on an unopened
`.Channel`.
"""
+
@wraps(func)
def _check(self, *args, **kwds):
if (
- self.closed or
- self.eof_received or
- self.eof_sent or
- not self.active
+ self.closed
+ or self.eof_received
+ or self.eof_sent
+ or not self.active
):
- raise SSHException('Channel is not open')
+ raise SSHException("Channel is not open")
return func(self, *args, **kwds)
+
return _check
-class Channel (ClosingContextManager):
+class Channel(ClosingContextManager):
"""
A secure tunnel across an SSH `.Transport`. A Channel is meant to behave
like a socket, and has an API that should be indistinguishable from the
@@ -117,7 +127,7 @@ class Channel (ClosingContextManager):
self.in_window_sofar = 0
self.status_event = threading.Event()
self._name = str(chanid)
- self.logger = util.get_logger('paramiko.transport')
+ self.logger = util.get_logger("paramiko.transport")
self._pipe = None
self.event = threading.Event()
self.event_ready = False
@@ -135,24 +145,30 @@ class Channel (ClosingContextManager):
"""
Return a string representation of this object, for debugging.
"""
- out = '<paramiko.Channel {}'.format(self.chanid)
+ out = "<paramiko.Channel {}".format(self.chanid)
if self.closed:
- out += ' (closed)'
+ out += " (closed)"
elif self.active:
if self.eof_received:
- out += ' (EOF received)'
+ out += " (EOF received)"
if self.eof_sent:
- out += ' (EOF sent)'
- out += ' (open) window={}'.format(self.out_window_size)
+ out += " (EOF sent)"
+ out += " (open) window={}".format(self.out_window_size)
if len(self.in_buffer) > 0:
- out += ' in-buffer={}'.format(len(self.in_buffer))
- out += ' -> ' + repr(self.transport)
- out += '>'
+ out += " in-buffer={}".format(len(self.in_buffer))
+ out += " -> " + repr(self.transport)
+ out += ">"
return out
@open_only
- def get_pty(self, term='vt100', width=80, height=24, width_pixels=0,
- height_pixels=0):
+ def get_pty(
+ self,
+ term="vt100",
+ width=80,
+ height=24,
+ width_pixels=0,
+ height_pixels=0,
+ ):
"""
Request a pseudo-terminal from the server. This is usually used right
after creating a client channel, to ask the server to provide some
@@ -174,7 +190,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('pty-req')
+ m.add_string("pty-req")
m.add_boolean(True)
m.add_string(term)
m.add_int(width)
@@ -207,7 +223,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('shell')
+ m.add_string("shell")
m.add_boolean(True)
self._event_pending()
self.transport._send_user_message(m)
@@ -233,7 +249,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('exec')
+ m.add_string("exec")
m.add_boolean(True)
m.add_string(command)
self._event_pending()
@@ -259,7 +275,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('subsystem')
+ m.add_string("subsystem")
m.add_boolean(True)
m.add_string(subsystem)
self._event_pending()
@@ -284,7 +300,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('window-change')
+ m.add_string("window-change")
m.add_boolean(False)
m.add_int(width)
m.add_int(height)
@@ -315,7 +331,7 @@ class Channel (ClosingContextManager):
try:
self.set_environment_variable(name, value)
except SSHException as e:
- err = "Failed to set environment variable \"{}\"."
+ err = 'Failed to set environment variable "{}".'
raise SSHException(err.format(name), e)
@open_only
@@ -339,7 +355,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('env')
+ m.add_string("env")
m.add_boolean(False)
m.add_string(name)
m.add_string(value)
@@ -403,19 +419,19 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('exit-status')
+ m.add_string("exit-status")
m.add_boolean(False)
m.add_int(status)
self.transport._send_user_message(m)
@open_only
def request_x11(
- self,
- screen_number=0,
- auth_protocol=None,
- auth_cookie=None,
- single_connection=False,
- handler=None
+ self,
+ screen_number=0,
+ auth_protocol=None,
+ auth_cookie=None,
+ single_connection=False,
+ handler=None,
):
"""
Request an x11 session on this channel. If the server allows it,
@@ -456,14 +472,14 @@ class Channel (ClosingContextManager):
:return: the auth_cookie used
"""
if auth_protocol is None:
- auth_protocol = 'MIT-MAGIC-COOKIE-1'
+ auth_protocol = "MIT-MAGIC-COOKIE-1"
if auth_cookie is None:
auth_cookie = binascii.hexlify(os.urandom(16))
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('x11-req')
+ m.add_string("x11-req")
m.add_boolean(True)
m.add_boolean(single_connection)
m.add_string(auth_protocol)
@@ -493,7 +509,7 @@ class Channel (ClosingContextManager):
m = Message()
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
- m.add_string('auth-agent-req@openssh.com')
+ m.add_string("auth-agent-req@openssh.com")
m.add_boolean(False)
self.transport._send_user_message(m)
self.transport._set_forward_agent_handler(handler)
@@ -808,7 +824,6 @@ class Channel (ClosingContextManager):
m.add_int(1)
return self._send(s, m)
-
def sendall(self, s):
"""
Send data to the channel, without allowing partial results. Unlike
@@ -976,7 +991,7 @@ class Channel (ClosingContextManager):
# a window update
self.in_window_threshold = window_size // 10
self.in_window_sofar = 0
- self._log(DEBUG, 'Max packet in: {} bytes'.format(max_packet_size))
+ self._log(DEBUG, "Max packet in: {} bytes".format(max_packet_size))
def _set_remote_channel(self, chanid, window_size, max_packet_size):
self.remote_chanid = chanid
@@ -985,11 +1000,12 @@ class Channel (ClosingContextManager):
max_packet_size
)
self.active = 1
- self._log(DEBUG,
- 'Max packet out: {} bytes'.format(self.out_max_packet_size))
+ self._log(
+ DEBUG, "Max packet out: {} bytes".format(self.out_max_packet_size)
+ )
def _request_success(self, m):
- self._log(DEBUG, 'Sesch channel {} request ok'.format(self.chanid))
+ self._log(DEBUG, "Sesch channel {} request ok".format(self.chanid))
self.event_ready = True
self.event.set()
return
@@ -1017,8 +1033,7 @@ class Channel (ClosingContextManager):
s = m.get_binary()
if code != 1:
self._log(
- ERROR,
- 'unknown extended_data type {}; discarding'.format(code)
+ ERROR, "unknown extended_data type {}; discarding".format(code)
)
return
if self.combine_stderr:
@@ -1031,7 +1046,7 @@ class Channel (ClosingContextManager):
self.lock.acquire()
try:
if self.ultra_debug:
- self._log(DEBUG, 'window up {}'.format(nbytes))
+ self._log(DEBUG, "window up {}".format(nbytes))
self.out_window_size += nbytes
self.out_buffer_cv.notifyAll()
finally:
@@ -1042,14 +1057,14 @@ class Channel (ClosingContextManager):
want_reply = m.get_boolean()
server = self.transport.server_object
ok = False
- if key == 'exit-status':
+ if key == "exit-status":
self.exit_status = m.get_int()
self.status_event.set()
ok = True
- elif key == 'xon-xoff':
+ elif key == "xon-xoff":
# ignore
ok = True
- elif key == 'pty-req':
+ elif key == "pty-req":
term = m.get_string()
width = m.get_int()
height = m.get_int()
@@ -1060,39 +1075,33 @@ class Channel (ClosingContextManager):
ok = False
else:
ok = server.check_channel_pty_request(
- self,
- term,
- width,
- height,
- pixelwidth,
- pixelheight,
- modes
+ self, term, width, height, pixelwidth, pixelheight, modes
)
- elif key == 'shell':
+ elif key == "shell":
if server is None:
ok = False
else:
ok = server.check_channel_shell_request(self)
- elif key == 'env':
+ elif key == "env":
name = m.get_string()
value = m.get_string()
if server is None:
ok = False
else:
ok = server.check_channel_env_request(self, name, value)
- elif key == 'exec':
+ elif key == "exec":
cmd = m.get_string()
if server is None:
ok = False
else:
ok = server.check_channel_exec_request(self, cmd)
- elif key == 'subsystem':
+ elif key == "subsystem":
name = m.get_text()
if server is None:
ok = False
else:
ok = server.check_channel_subsystem_request(self, name)
- elif key == 'window-change':
+ elif key == "window-change":
width = m.get_int()
height = m.get_int()
pixelwidth = m.get_int()
@@ -1101,8 +1110,9 @@ class Channel (ClosingContextManager):
ok = False
else:
ok = server.check_channel_window_change_request(
- self, width, height, pixelwidth, pixelheight)
- elif key == 'x11-req':
+ self, width, height, pixelwidth, pixelheight
+ )
+ elif key == "x11-req":
single_connection = m.get_boolean()
auth_proto = m.get_text()
auth_cookie = m.get_binary()
@@ -1115,9 +1125,9 @@ class Channel (ClosingContextManager):
single_connection,
auth_proto,
auth_cookie,
- screen_number
+ screen_number,
)
- elif key == 'auth-agent-req@openssh.com':
+ elif key == "auth-agent-req@openssh.com":
if server is None:
ok = False
else:
@@ -1145,7 +1155,7 @@ class Channel (ClosingContextManager):
self._pipe.set_forever()
finally:
self.lock.release()
- self._log(DEBUG, 'EOF received ({})'.format(self._name))
+ self._log(DEBUG, "EOF received ({})".format(self._name))
def _handle_close(self, m):
self.lock.acquire()
@@ -1167,7 +1177,7 @@ class Channel (ClosingContextManager):
if self.closed:
# this doesn't seem useful, but it is the documented behavior
# of Socket
- raise socket.error('Socket is closed')
+ raise socket.error("Socket is closed")
size = self._wait_for_send_window(size)
if size == 0:
# eof or similar
@@ -1194,7 +1204,7 @@ class Channel (ClosingContextManager):
return
e = self.transport.get_exception()
if e is None:
- e = SSHException('Channel closed.')
+ e = SSHException("Channel closed.")
raise e
def _set_closed(self):
@@ -1217,7 +1227,7 @@ class Channel (ClosingContextManager):
m.add_byte(cMSG_CHANNEL_EOF)
m.add_int(self.remote_chanid)
self.eof_sent = True
- self._log(DEBUG, 'EOF sent ({})'.format(self._name))
+ self._log(DEBUG, "EOF sent ({})".format(self._name))
return m
def _close_internal(self):
@@ -1251,13 +1261,14 @@ class Channel (ClosingContextManager):
if self.closed or self.eof_received or not self.active:
return 0
if self.ultra_debug:
- self._log(DEBUG, 'addwindow {}'.format(n))
+ self._log(DEBUG, "addwindow {}".format(n))
self.in_window_sofar += n
if self.in_window_sofar <= self.in_window_threshold:
return 0
if self.ultra_debug:
- self._log(DEBUG,
- 'addwindow send {}'.format(self.in_window_sofar))
+ self._log(
+ DEBUG, "addwindow send {}".format(self.in_window_sofar)
+ )
out = self.in_window_sofar
self.in_window_sofar = 0
return out
@@ -1300,11 +1311,11 @@ class Channel (ClosingContextManager):
size = self.out_max_packet_size - 64
self.out_window_size -= size
if self.ultra_debug:
- self._log(DEBUG, 'window down to {}'.format(self.out_window_size))
+ self._log(DEBUG, "window down to {}".format(self.out_window_size))
return size
-class ChannelFile (BufferedFile):
+class ChannelFile(BufferedFile):
"""
A file-like wrapper around `.Channel`. A ChannelFile is created by calling
`Channel.makefile`.
@@ -1317,7 +1328,7 @@ class ChannelFile (BufferedFile):
flush the buffer.
"""
- def __init__(self, channel, mode='r', bufsize=-1):
+ def __init__(self, channel, mode="r", bufsize=-1):
self.channel = channel
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
@@ -1326,7 +1337,7 @@ class ChannelFile (BufferedFile):
"""
Returns a string representation of this object, for debugging.
"""
- return '<paramiko.ChannelFile from ' + repr(self.channel) + '>'
+ return "<paramiko.ChannelFile from " + repr(self.channel) + ">"
def _read(self, size):
return self.channel.recv(size)
@@ -1336,8 +1347,9 @@ class ChannelFile (BufferedFile):
return len(data)
-class ChannelStderrFile (ChannelFile):
- def __init__(self, channel, mode='r', bufsize=-1):
+class ChannelStderrFile(ChannelFile):
+
+ def __init__(self, channel, mode="r", bufsize=-1):
ChannelFile.__init__(self, channel, mode, bufsize)
def _read(self, size):
diff --git a/paramiko/client.py b/paramiko/client.py
index de0a495e..6bf479d4 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -38,13 +38,15 @@ from paramiko.hostkeys import HostKeys
from paramiko.py3compat import string_types
from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import (
- SSHException, BadHostKeyException, NoValidConnectionsError
+ SSHException,
+ BadHostKeyException,
+ NoValidConnectionsError,
)
from paramiko.transport import Transport
from paramiko.util import retry_on_signal, ClosingContextManager
-class SSHClient (ClosingContextManager):
+class SSHClient(ClosingContextManager):
"""
A high-level representation of a session with an SSH server. This class
wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most
@@ -97,7 +99,7 @@ class SSHClient (ClosingContextManager):
"""
if filename is None:
# try the user's .ssh key file, and mask exceptions
- filename = os.path.expanduser('~/.ssh/known_hosts')
+ filename = os.path.expanduser("~/.ssh/known_hosts")
try:
self._system_host_keys.load(filename)
except IOError:
@@ -140,12 +142,14 @@ class SSHClient (ClosingContextManager):
if self._host_keys_filename is not None:
self.load_host_keys(self._host_keys_filename)
- with open(filename, 'w') as f:
+ with open(filename, "w") as f:
for hostname, keys in self._host_keys.items():
for keytype, key in keys.items():
- f.write('{} {} {}\n'.format(
- hostname, keytype, key.get_base64()
- ))
+ f.write(
+ "{} {} {}\n".format(
+ hostname, keytype, key.get_base64()
+ )
+ )
def get_host_keys(self):
"""
@@ -197,7 +201,8 @@ class SSHClient (ClosingContextManager):
"""
guess = True
addrinfos = socket.getaddrinfo(
- hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
+ hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
+ )
for (family, socktype, proto, canonname, sockaddr) in addrinfos:
if socktype == socket.SOCK_STREAM:
yield family, sockaddr
@@ -419,8 +424,16 @@ class SSHClient (ClosingContextManager):
key_filenames = key_filename
self._auth(
- username, password, pkey, key_filenames, allow_agent,
- look_for_keys, gss_auth, gss_kex, gss_deleg_creds, t.gss_host,
+ username,
+ password,
+ pkey,
+ key_filenames,
+ allow_agent,
+ look_for_keys,
+ gss_auth,
+ gss_kex,
+ gss_deleg_creds,
+ t.gss_host,
passphrase,
)
@@ -490,13 +503,20 @@ class SSHClient (ClosingContextManager):
if environment:
chan.update_environment(environment)
chan.exec_command(command)
- stdin = chan.makefile('wb', bufsize)
- stdout = chan.makefile('r', bufsize)
- stderr = chan.makefile_stderr('r', bufsize)
+ stdin = chan.makefile("wb", bufsize)
+ stdout = chan.makefile("r", bufsize)
+ stderr = chan.makefile_stderr("r", bufsize)
return stdin, stdout, stderr
- def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0,
- height_pixels=0, environment=None):
+ def invoke_shell(
+ self,
+ term="vt100",
+ width=80,
+ height=24,
+ width_pixels=0,
+ height_pixels=0,
+ environment=None,
+ ):
"""
Start an interactive shell session on the SSH server. A new `.Channel`
is opened and connected to a pseudo-terminal using the requested
@@ -545,10 +565,10 @@ class SSHClient (ClosingContextManager):
- Otherwise, the filename is assumed to be a private key, and the
matching public cert will be loaded if it exists.
"""
- cert_suffix = '-cert.pub'
+ cert_suffix = "-cert.pub"
# Assume privkey, not cert, by default
if filename.endswith(cert_suffix):
- key_path = filename[:-len(cert_suffix)]
+ key_path = filename[: -len(cert_suffix)]
cert_path = filename
else:
key_path = filename
@@ -559,7 +579,7 @@ class SSHClient (ClosingContextManager):
# when #387 is released, since this is a critical log message users are
# likely testing/filtering for (bah.)
msg = "Trying discovered key {} in {}".format(
- hexlify(key.get_fingerprint()), key_path,
+ hexlify(key.get_fingerprint()), key_path
)
self._log(DEBUG, msg)
# Attempt to load cert if it exists.
@@ -569,8 +589,17 @@ class SSHClient (ClosingContextManager):
return key
def _auth(
- self, username, password, pkey, key_filenames, allow_agent,
- look_for_keys, gss_auth, gss_kex, gss_deleg_creds, gss_host,
+ self,
+ username,
+ password,
+ pkey,
+ key_filenames,
+ allow_agent,
+ look_for_keys,
+ gss_auth,
+ gss_kex,
+ gss_deleg_creds,
+ gss_host,
passphrase,
):
"""
@@ -589,7 +618,7 @@ class SSHClient (ClosingContextManager):
saved_exception = None
two_factor = False
allowed_types = set()
- two_factor_types = {'keyboard-interactive', 'password'}
+ two_factor_types = {"keyboard-interactive", "password"}
if passphrase is None and password is not None:
passphrase = password
@@ -609,7 +638,7 @@ class SSHClient (ClosingContextManager):
if gss_auth:
try:
return self._transport.auth_gssapi_with_mic(
- username, gss_host, gss_deleg_creds,
+ username, gss_host, gss_deleg_creds
)
except Exception as e:
saved_exception = e
@@ -618,11 +647,14 @@ class SSHClient (ClosingContextManager):
try:
self._log(
DEBUG,
- 'Trying SSH key {}'.format(hexlify(pkey.get_fingerprint()))
+ "Trying SSH key {}".format(
+ hexlify(pkey.get_fingerprint())
+ ),
)
allowed_types = set(
- self._transport.auth_publickey(username, pkey))
- two_factor = (allowed_types & two_factor_types)
+ self._transport.auth_publickey(username, pkey)
+ )
+ two_factor = allowed_types & two_factor_types
if not two_factor:
return
except SSHException as e:
@@ -633,11 +665,12 @@ class SSHClient (ClosingContextManager):
for pkey_class in (RSAKey, DSSKey, ECDSAKey, Ed25519Key):
try:
key = self._key_from_filepath(
- key_filename, pkey_class, passphrase,
+ key_filename, pkey_class, passphrase
)
allowed_types = set(
- self._transport.auth_publickey(username, key))
- two_factor = (allowed_types & two_factor_types)
+ self._transport.auth_publickey(username, key)
+ )
+ two_factor = allowed_types & two_factor_types
if not two_factor:
return
break
@@ -651,12 +684,13 @@ class SSHClient (ClosingContextManager):
for key in self._agent.get_keys():
try:
id_ = hexlify(key.get_fingerprint())
- self._log(DEBUG, 'Trying SSH agent key {}'.format(id_))
+ self._log(DEBUG, "Trying SSH agent key {}".format(id_))
# for 2-factor auth a successfully auth'd key password
# will return an allowed 2fac auth method
allowed_types = set(
- self._transport.auth_publickey(username, key))
- two_factor = (allowed_types & two_factor_types)
+ self._transport.auth_publickey(username, key)
+ )
+ two_factor = allowed_types & two_factor_types
if not two_factor:
return
break
@@ -680,8 +714,8 @@ class SSHClient (ClosingContextManager):
if os.path.isfile(full_path):
# TODO: only do this append if below did not run
keyfiles.append((keytype, full_path))
- if os.path.isfile(full_path + '-cert.pub'):
- keyfiles.append((keytype, full_path + '-cert.pub'))
+ if os.path.isfile(full_path + "-cert.pub"):
+ keyfiles.append((keytype, full_path + "-cert.pub"))
if not look_for_keys:
keyfiles = []
@@ -689,13 +723,14 @@ class SSHClient (ClosingContextManager):
for pkey_class, filename in keyfiles:
try:
key = self._key_from_filepath(
- filename, pkey_class, passphrase,
+ filename, pkey_class, passphrase
)
# for 2-factor auth a successfully auth'd key will result
# in ['password']
allowed_types = set(
- self._transport.auth_publickey(username, key))
- two_factor = (allowed_types & two_factor_types)
+ self._transport.auth_publickey(username, key)
+ )
+ two_factor = allowed_types & two_factor_types
if not two_factor:
return
break
@@ -718,13 +753,13 @@ class SSHClient (ClosingContextManager):
# if we got an auth-failed exception earlier, re-raise it
if saved_exception is not None:
raise saved_exception
- raise SSHException('No authentication methods available')
+ raise SSHException("No authentication methods available")
def _log(self, level, msg):
self._transport._log(level, msg)
-class MissingHostKeyPolicy (object):
+class MissingHostKeyPolicy(object):
"""
Interface for defining the policy that `.SSHClient` should use when the
SSH server's hostname is not in either the system host keys or the
@@ -745,7 +780,7 @@ class MissingHostKeyPolicy (object):
pass
-class AutoAddPolicy (MissingHostKeyPolicy):
+class AutoAddPolicy(MissingHostKeyPolicy):
"""
Policy for automatically adding the hostname and new host key to the
local `.HostKeys` object, and saving it. This is used by `.SSHClient`.
@@ -755,32 +790,41 @@ class AutoAddPolicy (MissingHostKeyPolicy):
client._host_keys.add(hostname, key.get_name(), key)
if client._host_keys_filename is not None:
client.save_host_keys(client._host_keys_filename)
- client._log(DEBUG, 'Adding {} host key for {}: {}'.format(
- key.get_name(), hostname, hexlify(key.get_fingerprint()),
- ))
+ client._log(
+ DEBUG,
+ "Adding {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ ),
+ )
-class RejectPolicy (MissingHostKeyPolicy):
+class RejectPolicy(MissingHostKeyPolicy):
"""
Policy for automatically rejecting the unknown hostname & key. This is
used by `.SSHClient`.
"""
def missing_host_key(self, client, hostname, key):
- client._log(DEBUG, 'Rejecting {} host key for {}: {}'.format(
- key.get_name(), hostname, hexlify(key.get_fingerprint()),
- ))
+ client._log(
+ DEBUG,
+ "Rejecting {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ ),
+ )
raise SSHException(
- 'Server {!r} not found in known_hosts'.format(hostname)
+ "Server {!r} not found in known_hosts".format(hostname)
)
-class WarningPolicy (MissingHostKeyPolicy):
+class WarningPolicy(MissingHostKeyPolicy):
"""
Policy for logging a Python-style warning for an unknown host key, but
accepting it. This is used by `.SSHClient`.
"""
+
def missing_host_key(self, client, hostname, key):
- warnings.warn('Unknown {} host key for {}: {}'.format(
- key.get_name(), hostname, hexlify(key.get_fingerprint()),
- ))
+ warnings.warn(
+ "Unknown {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ )
+ )
diff --git a/paramiko/common.py b/paramiko/common.py
index eab6647e..7bd0cb10 100644
--- a/paramiko/common.py
+++ b/paramiko/common.py
@@ -22,22 +22,45 @@ Common constants and global variables.
import logging
from paramiko.py3compat import byte_chr, PY2, long, b
-MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \
- MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT = range(1, 7)
-MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
-MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \
- MSG_USERAUTH_BANNER = range(50, 54)
+(
+ MSG_DISCONNECT,
+ MSG_IGNORE,
+ MSG_UNIMPLEMENTED,
+ MSG_DEBUG,
+ MSG_SERVICE_REQUEST,
+ MSG_SERVICE_ACCEPT,
+) = range(1, 7)
+(MSG_KEXINIT, MSG_NEWKEYS) = range(20, 22)
+(
+ MSG_USERAUTH_REQUEST,
+ MSG_USERAUTH_FAILURE,
+ MSG_USERAUTH_SUCCESS,
+ MSG_USERAUTH_BANNER,
+) = range(50, 54)
MSG_USERAUTH_PK_OK = 60
-MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
-MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN = range(60, 62)
-MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, MSG_USERAUTH_GSSAPI_ERROR,\
- MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC = range(63, 67)
+(MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE) = range(60, 62)
+(MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN) = range(60, 62)
+(
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE,
+ MSG_USERAUTH_GSSAPI_ERROR,
+ MSG_USERAUTH_GSSAPI_ERRTOK,
+ MSG_USERAUTH_GSSAPI_MIC,
+) = range(63, 67)
HIGHEST_USERAUTH_MESSAGE_ID = 79
-MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
-MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
- MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \
- MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
- MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
+(MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE) = range(80, 83)
+(
+ MSG_CHANNEL_OPEN,
+ MSG_CHANNEL_OPEN_SUCCESS,
+ MSG_CHANNEL_OPEN_FAILURE,
+ MSG_CHANNEL_WINDOW_ADJUST,
+ MSG_CHANNEL_DATA,
+ MSG_CHANNEL_EXTENDED_DATA,
+ MSG_CHANNEL_EOF,
+ MSG_CHANNEL_CLOSE,
+ MSG_CHANNEL_REQUEST,
+ MSG_CHANNEL_SUCCESS,
+ MSG_CHANNEL_FAILURE,
+) = range(90, 101)
cMSG_DISCONNECT = byte_chr(MSG_DISCONNECT)
cMSG_IGNORE = byte_chr(MSG_IGNORE)
@@ -56,8 +79,9 @@ cMSG_USERAUTH_INFO_REQUEST = byte_chr(MSG_USERAUTH_INFO_REQUEST)
cMSG_USERAUTH_INFO_RESPONSE = byte_chr(MSG_USERAUTH_INFO_RESPONSE)
cMSG_USERAUTH_GSSAPI_RESPONSE = byte_chr(MSG_USERAUTH_GSSAPI_RESPONSE)
cMSG_USERAUTH_GSSAPI_TOKEN = byte_chr(MSG_USERAUTH_GSSAPI_TOKEN)
-cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = \
- byte_chr(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)
+cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = byte_chr(
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE
+)
cMSG_USERAUTH_GSSAPI_ERROR = byte_chr(MSG_USERAUTH_GSSAPI_ERROR)
cMSG_USERAUTH_GSSAPI_ERRTOK = byte_chr(MSG_USERAUTH_GSSAPI_ERRTOK)
cMSG_USERAUTH_GSSAPI_MIC = byte_chr(MSG_USERAUTH_GSSAPI_MIC)
@@ -78,47 +102,47 @@ cMSG_CHANNEL_FAILURE = byte_chr(MSG_CHANNEL_FAILURE)
# for debugging:
MSG_NAMES = {
- MSG_DISCONNECT: 'disconnect',
- MSG_IGNORE: 'ignore',
- MSG_UNIMPLEMENTED: 'unimplemented',
- MSG_DEBUG: 'debug',
- MSG_SERVICE_REQUEST: 'service-request',
- MSG_SERVICE_ACCEPT: 'service-accept',
- MSG_KEXINIT: 'kexinit',
- MSG_NEWKEYS: 'newkeys',
- 30: 'kex30',
- 31: 'kex31',
- 32: 'kex32',
- 33: 'kex33',
- 34: 'kex34',
- 40: 'kex40',
- 41: 'kex41',
- MSG_USERAUTH_REQUEST: 'userauth-request',
- MSG_USERAUTH_FAILURE: 'userauth-failure',
- MSG_USERAUTH_SUCCESS: 'userauth-success',
- MSG_USERAUTH_BANNER: 'userauth--banner',
- MSG_USERAUTH_PK_OK: 'userauth-60(pk-ok/info-request)',
- MSG_USERAUTH_INFO_RESPONSE: 'userauth-info-response',
- MSG_GLOBAL_REQUEST: 'global-request',
- MSG_REQUEST_SUCCESS: 'request-success',
- MSG_REQUEST_FAILURE: 'request-failure',
- MSG_CHANNEL_OPEN: 'channel-open',
- MSG_CHANNEL_OPEN_SUCCESS: 'channel-open-success',
- MSG_CHANNEL_OPEN_FAILURE: 'channel-open-failure',
- MSG_CHANNEL_WINDOW_ADJUST: 'channel-window-adjust',
- MSG_CHANNEL_DATA: 'channel-data',
- MSG_CHANNEL_EXTENDED_DATA: 'channel-extended-data',
- MSG_CHANNEL_EOF: 'channel-eof',
- MSG_CHANNEL_CLOSE: 'channel-close',
- MSG_CHANNEL_REQUEST: 'channel-request',
- MSG_CHANNEL_SUCCESS: 'channel-success',
- MSG_CHANNEL_FAILURE: 'channel-failure',
- MSG_USERAUTH_GSSAPI_RESPONSE: 'userauth-gssapi-response',
- MSG_USERAUTH_GSSAPI_TOKEN: 'userauth-gssapi-token',
- MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: 'userauth-gssapi-exchange-complete',
- MSG_USERAUTH_GSSAPI_ERROR: 'userauth-gssapi-error',
- MSG_USERAUTH_GSSAPI_ERRTOK: 'userauth-gssapi-error-token',
- MSG_USERAUTH_GSSAPI_MIC: 'userauth-gssapi-mic'
+ MSG_DISCONNECT: "disconnect",
+ MSG_IGNORE: "ignore",
+ MSG_UNIMPLEMENTED: "unimplemented",
+ MSG_DEBUG: "debug",
+ MSG_SERVICE_REQUEST: "service-request",
+ MSG_SERVICE_ACCEPT: "service-accept",
+ MSG_KEXINIT: "kexinit",
+ MSG_NEWKEYS: "newkeys",
+ 30: "kex30",
+ 31: "kex31",
+ 32: "kex32",
+ 33: "kex33",
+ 34: "kex34",
+ 40: "kex40",
+ 41: "kex41",
+ MSG_USERAUTH_REQUEST: "userauth-request",
+ MSG_USERAUTH_FAILURE: "userauth-failure",
+ MSG_USERAUTH_SUCCESS: "userauth-success",
+ MSG_USERAUTH_BANNER: "userauth--banner",
+ MSG_USERAUTH_PK_OK: "userauth-60(pk-ok/info-request)",
+ MSG_USERAUTH_INFO_RESPONSE: "userauth-info-response",
+ MSG_GLOBAL_REQUEST: "global-request",
+ MSG_REQUEST_SUCCESS: "request-success",
+ MSG_REQUEST_FAILURE: "request-failure",
+ MSG_CHANNEL_OPEN: "channel-open",
+ MSG_CHANNEL_OPEN_SUCCESS: "channel-open-success",
+ MSG_CHANNEL_OPEN_FAILURE: "channel-open-failure",
+ MSG_CHANNEL_WINDOW_ADJUST: "channel-window-adjust",
+ MSG_CHANNEL_DATA: "channel-data",
+ MSG_CHANNEL_EXTENDED_DATA: "channel-extended-data",
+ MSG_CHANNEL_EOF: "channel-eof",
+ MSG_CHANNEL_CLOSE: "channel-close",
+ MSG_CHANNEL_REQUEST: "channel-request",
+ MSG_CHANNEL_SUCCESS: "channel-success",
+ MSG_CHANNEL_FAILURE: "channel-failure",
+ MSG_USERAUTH_GSSAPI_RESPONSE: "userauth-gssapi-response",
+ MSG_USERAUTH_GSSAPI_TOKEN: "userauth-gssapi-token",
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: "userauth-gssapi-exchange-complete",
+ MSG_USERAUTH_GSSAPI_ERROR: "userauth-gssapi-error",
+ MSG_USERAUTH_GSSAPI_ERRTOK: "userauth-gssapi-error-token",
+ MSG_USERAUTH_GSSAPI_MIC: "userauth-gssapi-mic",
}
@@ -127,23 +151,28 @@ AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
# channel request failed reasons:
-(OPEN_SUCCEEDED,
- OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
- OPEN_FAILED_CONNECT_FAILED,
- OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
- OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5)
+(
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_FAILED_CONNECT_FAILED,
+ OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
+ OPEN_FAILED_RESOURCE_SHORTAGE,
+) = range(0, 5)
CONNECTION_FAILED_CODE = {
- 1: 'Administratively prohibited',
- 2: 'Connect failed',
- 3: 'Unknown channel type',
- 4: 'Resource shortage'
+ 1: "Administratively prohibited",
+ 2: "Connect failed",
+ 3: "Unknown channel type",
+ 4: "Resource shortage",
}
-DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \
- DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14
+(
+ DISCONNECT_SERVICE_NOT_AVAILABLE,
+ DISCONNECT_AUTH_CANCELLED_BY_USER,
+ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
+) = (7, 13, 14)
zero_byte = byte_chr(0)
one_byte = byte_chr(1)
diff --git a/paramiko/compress.py b/paramiko/compress.py
index 5073109c..cea2b308 100644
--- a/paramiko/compress.py
+++ b/paramiko/compress.py
@@ -23,7 +23,8 @@ Compression implementations for a Transport.
import zlib
-class ZlibCompressor (object):
+class ZlibCompressor(object):
+
def __init__(self):
# Use the default level of zlib compression
self.z = zlib.compressobj()
@@ -32,7 +33,8 @@ class ZlibCompressor (object):
return self.z.compress(data) + self.z.flush(zlib.Z_FULL_FLUSH)
-class ZlibDecompressor (object):
+class ZlibDecompressor(object):
+
def __init__(self):
self.z = zlib.decompressobj()
diff --git a/paramiko/config.py b/paramiko/config.py
index 038d84ea..21c9dab8 100644
--- a/paramiko/config.py
+++ b/paramiko/config.py
@@ -30,7 +30,7 @@ import socket
SSH_PORT = 22
-class SSHConfig (object):
+class SSHConfig(object):
"""
Representation of config information as stored in the format used by
OpenSSH. Queries can be made via `lookup`. The format is described in
@@ -41,7 +41,7 @@ class SSHConfig (object):
.. versionadded:: 1.6
"""
- SETTINGS_REGEX = re.compile(r'(\w+)(?:\s*=\s*|\s+)(.+)')
+ SETTINGS_REGEX = re.compile(r"(\w+)(?:\s*=\s*|\s+)(.+)")
def __init__(self):
"""
@@ -55,12 +55,12 @@ class SSHConfig (object):
:param file_obj: a file-like object to read the config file from
"""
- host = {"host": ['*'], "config": {}}
+ host = {"host": ["*"], "config": {}}
for line in file_obj:
# Strip any leading or trailing whitespace from the line.
# Refer to https://github.com/paramiko/paramiko/issues/499
line = line.strip()
- if not line or line.startswith('#'):
+ if not line or line.startswith("#"):
continue
match = re.match(self.SETTINGS_REGEX, line)
@@ -69,17 +69,14 @@ class SSHConfig (object):
key = match.group(1).lower()
value = match.group(2)
- if key == 'host':
+ if key == "host":
self._config.append(host)
- host = {
- 'host': self._get_hosts(value),
- 'config': {}
- }
- elif key == 'proxycommand' and value.lower() == 'none':
+ host = {"host": self._get_hosts(value), "config": {}}
+ elif key == "proxycommand" and value.lower() == "none":
# Store 'none' as None; prior to 3.x, it will get stripped out
# at the end (for compatibility with issue #415). After 3.x, it
# will simply not get stripped, leaving a nice explicit marker.
- host['config'][key] = None
+ host["config"][key] = None
else:
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
@@ -87,13 +84,13 @@ class SSHConfig (object):
# identityfile, localforward, remoteforward keys are special
# cases, since they are allowed to be specified multiple times
# and they should be tried in order of specification.
- if key in ['identityfile', 'localforward', 'remoteforward']:
- if key in host['config']:
- host['config'][key].append(value)
+ if key in ["identityfile", "localforward", "remoteforward"]:
+ if key in host["config"]:
+ host["config"][key].append(value)
else:
- host['config'][key] = [value]
- elif key not in host['config']:
- host['config'][key] = value
+ host["config"][key] = [value]
+ elif key not in host["config"]:
+ host["config"][key] = value
self._config.append(host)
def lookup(self, hostname):
@@ -117,25 +114,26 @@ class SSHConfig (object):
:param str hostname: the hostname to lookup
"""
matches = [
- config for config in self._config
- if self._allowed(config['host'], hostname)
+ config
+ for config in self._config
+ if self._allowed(config["host"], hostname)
]
ret = {}
for match in matches:
- for key, value in match['config'].items():
+ for key, value in match["config"].items():
if key not in ret:
# Create a copy of the original value,
# else it will reference the original list
# in self._config and update that value too
# when the extend() is being called.
ret[key] = value[:] if value is not None else value
- elif key == 'identityfile':
+ elif key == "identityfile":
ret[key].extend(value)
ret = self._expand_variables(ret, hostname)
# TODO: remove in 3.x re #670
- if 'proxycommand' in ret and ret['proxycommand'] is None:
- del ret['proxycommand']
+ if "proxycommand" in ret and ret["proxycommand"] is None:
+ del ret["proxycommand"]
return ret
def get_hostnames(self):
@@ -145,13 +143,13 @@ class SSHConfig (object):
"""
hosts = set()
for entry in self._config:
- hosts.update(entry['host'])
+ hosts.update(entry["host"])
return hosts
def _allowed(self, hosts, hostname):
match = False
for host in hosts:
- if host.startswith('!') and fnmatch.fnmatch(hostname, host[1:]):
+ if host.startswith("!") and fnmatch.fnmatch(hostname, host[1:]):
return False
elif fnmatch.fnmatch(hostname, host):
match = True
@@ -169,52 +167,50 @@ class SSHConfig (object):
:param str hostname: the hostname that the config belongs to
"""
- if 'hostname' in config:
- config['hostname'] = config['hostname'].replace('%h', hostname)
+ if "hostname" in config:
+ config["hostname"] = config["hostname"].replace("%h", hostname)
else:
- config['hostname'] = hostname
+ config["hostname"] = hostname
- if 'port' in config:
- port = config['port']
+ if "port" in config:
+ port = config["port"]
else:
port = SSH_PORT
- user = os.getenv('USER')
- if 'user' in config:
- remoteuser = config['user']
+ user = os.getenv("USER")
+ if "user" in config:
+ remoteuser = config["user"]
else:
remoteuser = user
- host = socket.gethostname().split('.')[0]
+ host = socket.gethostname().split(".")[0]
fqdn = LazyFqdn(config, host)
- homedir = os.path.expanduser('~')
- replacements = {'controlpath':
- [
- ('%h', config['hostname']),
- ('%l', fqdn),
- ('%L', host),
- ('%n', hostname),
- ('%p', port),
- ('%r', remoteuser),
- ('%u', user)
- ],
- 'identityfile':
- [
- ('~', homedir),
- ('%d', homedir),
- ('%h', config['hostname']),
- ('%l', fqdn),
- ('%u', user),
- ('%r', remoteuser)
- ],
- 'proxycommand':
- [
- ('~', homedir),
- ('%h', config['hostname']),
- ('%p', port),
- ('%r', remoteuser)
- ]
- }
+ homedir = os.path.expanduser("~")
+ replacements = {
+ "controlpath": [
+ ("%h", config["hostname"]),
+ ("%l", fqdn),
+ ("%L", host),
+ ("%n", hostname),
+ ("%p", port),
+ ("%r", remoteuser),
+ ("%u", user),
+ ],
+ "identityfile": [
+ ("~", homedir),
+ ("%d", homedir),
+ ("%h", config["hostname"]),
+ ("%l", fqdn),
+ ("%u", user),
+ ("%r", remoteuser),
+ ],
+ "proxycommand": [
+ ("~", homedir),
+ ("%h", config["hostname"]),
+ ("%p", port),
+ ("%r", remoteuser),
+ ],
+ }
for k in config:
if config[k] is None:
@@ -265,11 +261,11 @@ class LazyFqdn(object):
# Handle specific option
fqdn = None
- address_family = self.config.get('addressfamily', 'any').lower()
- if address_family != 'any':
+ address_family = self.config.get("addressfamily", "any").lower()
+ if address_family != "any":
try:
family = socket.AF_INET6
- if address_family == 'inet':
+ if address_family == "inet":
socket.AF_INET
results = socket.getaddrinfo(
self.host,
@@ -277,11 +273,11 @@ class LazyFqdn(object):
family,
socket.SOCK_DGRAM,
socket.IPPROTO_IP,
- socket.AI_CANONNAME
+ socket.AI_CANONNAME,
)
for res in results:
af, socktype, proto, canonname, sa = res
- if canonname and '.' in canonname:
+ if canonname and "." in canonname:
fqdn = canonname
break
# giaerror -> socket.getaddrinfo() can't resolve self.host
diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py
index ac1d4c2e..ec358ee2 100644
--- a/paramiko/dsskey.py
+++ b/paramiko/dsskey.py
@@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric.utils import (
- decode_dss_signature, encode_dss_signature
+ decode_dss_signature,
+ encode_dss_signature,
)
from paramiko import util
@@ -42,8 +43,15 @@ class DSSKey(PKey):
data.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None,
- vals=None, file_obj=None):
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ vals=None,
+ file_obj=None,
+ ):
self.p = None
self.q = None
self.g = None
@@ -63,8 +71,8 @@ class DSSKey(PKey):
else:
self._check_type_and_load_cert(
msg=msg,
- key_type='ssh-dss',
- cert_type='ssh-dss-cert-v01@openssh.com',
+ key_type="ssh-dss",
+ cert_type="ssh-dss-cert-v01@openssh.com",
)
self.p = msg.get_mpint()
self.q = msg.get_mpint()
@@ -74,7 +82,7 @@ class DSSKey(PKey):
def asbytes(self):
m = Message()
- m.add_string('ssh-dss')
+ m.add_string("ssh-dss")
m.add_mpint(self.p)
m.add_mpint(self.q)
m.add_mpint(self.g)
@@ -88,7 +96,7 @@ class DSSKey(PKey):
return hash((self.get_name(), self.p, self.q, self.g, self.y))
def get_name(self):
- return 'ssh-dss'
+ return "ssh-dss"
def get_bits(self):
return self.size
@@ -102,17 +110,15 @@ class DSSKey(PKey):
public_numbers=dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- )
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
).private_key(backend=default_backend())
sig = key.sign(data, hashes.SHA1())
r, s = decode_dss_signature(sig)
m = Message()
- m.add_string('ssh-dss')
+ m.add_string("ssh-dss")
# apparently, in rare cases, r or s may be shorter than 20 bytes!
rstr = util.deflate_long(r, 0)
sstr = util.deflate_long(s, 0)
@@ -129,7 +135,7 @@ class DSSKey(PKey):
sig = msg.asbytes()
else:
kind = msg.get_text()
- if kind != 'ssh-dss':
+ if kind != "ssh-dss":
return 0
sig = msg.get_binary()
@@ -142,10 +148,8 @@ class DSSKey(PKey):
key = dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
+ p=self.p, q=self.q, g=self.g
+ ),
).public_key(backend=default_backend())
try:
key.verify(signature, data, hashes.SHA1())
@@ -160,18 +164,16 @@ class DSSKey(PKey):
public_numbers=dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- )
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
).private_key(backend=default_backend())
self._write_private_key_file(
filename,
key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
def write_private_key(self, file_obj, password=None):
@@ -180,18 +182,16 @@ class DSSKey(PKey):
public_numbers=dsa.DSAPublicNumbers(
y=self.y,
parameter_numbers=dsa.DSAParameterNumbers(
- p=self.p,
- q=self.q,
- g=self.g
- )
- )
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
).private_key(backend=default_backend())
self._write_private_key(
file_obj,
key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
@staticmethod
@@ -207,23 +207,25 @@ class DSSKey(PKey):
numbers = dsa.generate_private_key(
bits, backend=default_backend()
).private_numbers()
- key = DSSKey(vals=(
- numbers.public_numbers.parameter_numbers.p,
- numbers.public_numbers.parameter_numbers.q,
- numbers.public_numbers.parameter_numbers.g,
- numbers.public_numbers.y
- ))
+ key = DSSKey(
+ vals=(
+ numbers.public_numbers.parameter_numbers.p,
+ numbers.public_numbers.parameter_numbers.q,
+ numbers.public_numbers.parameter_numbers.g,
+ numbers.public_numbers.y,
+ )
+ )
key.x = numbers.x
return key
# ...internals...
def _from_private_key_file(self, filename, password):
- data = self._read_private_key_file('DSA', filename, password)
+ data = self._read_private_key_file("DSA", filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
- data = self._read_private_key('DSA', file_obj, password)
+ data = self._read_private_key("DSA", file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
@@ -232,14 +234,11 @@ class DSSKey(PKey):
try:
keylist = BER(data).decode()
except BERException as e:
- raise SSHException('Unable to parse key file: ' + str(e))
- if (
- type(keylist) is not list or
- len(keylist) < 6 or
- keylist[0] != 0
- ):
+ raise SSHException("Unable to parse key file: " + str(e))
+ if type(keylist) is not list or len(keylist) < 6 or keylist[0] != 0:
raise SSHException(
- 'not a valid DSA private key file (bad ber encoding)')
+ "not a valid DSA private key file (bad ber encoding)"
+ )
self.p = keylist[1]
self.q = keylist[2]
self.g = keylist[3]
diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py
index 92e01a75..b73a969e 100644
--- a/paramiko/ecdsakey.py
+++ b/paramiko/ecdsakey.py
@@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import (
- decode_dss_signature, encode_dss_signature
+ decode_dss_signature,
+ encode_dss_signature,
)
from paramiko.common import four_byte
@@ -43,6 +44,7 @@ class _ECDSACurve(object):
the proper hash function. Also grabs the proper curve from the 'ecdsa'
package.
"""
+
def __init__(self, curve_class, nist_name):
self.nist_name = nist_name
self.key_length = curve_class.key_size
@@ -67,6 +69,7 @@ class _ECDSACurveSet(object):
format identifier. The two ways in which ECDSAKey needs to be able to look
up curves.
"""
+
def __init__(self, ecdsa_curves):
self.ecdsa_curves = ecdsa_curves
@@ -95,14 +98,24 @@ class ECDSAKey(PKey):
data.
"""
- _ECDSA_CURVES = _ECDSACurveSet([
- _ECDSACurve(ec.SECP256R1, 'nistp256'),
- _ECDSACurve(ec.SECP384R1, 'nistp384'),
- _ECDSACurve(ec.SECP521R1, 'nistp521'),
- ])
-
- def __init__(self, msg=None, data=None, filename=None, password=None,
- vals=None, file_obj=None, validate_point=True):
+ _ECDSA_CURVES = _ECDSACurveSet(
+ [
+ _ECDSACurve(ec.SECP256R1, "nistp256"),
+ _ECDSACurve(ec.SECP384R1, "nistp384"),
+ _ECDSACurve(ec.SECP521R1, "nistp521"),
+ ]
+ )
+
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ vals=None,
+ file_obj=None,
+ validate_point=True,
+ ):
self.verifying_key = None
self.signing_key = None
self.public_blob = None
@@ -126,21 +139,18 @@ class ECDSAKey(PKey):
# identifier, so strip out any cert business. (NOTE: could push
# that into _ECDSACurveSet.get_by_key_format_identifier(), but it
# feels more correct to do it here?)
- suffix = '-cert-v01@openssh.com'
+ suffix = "-cert-v01@openssh.com"
if key_type.endswith(suffix):
- key_type = key_type[:-len(suffix)]
+ key_type = key_type[: -len(suffix)]
self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
key_type
)
key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
cert_types = [
- '{}-cert-v01@openssh.com'.format(x)
- for x in key_types
+ "{}-cert-v01@openssh.com".format(x) for x in key_types
]
self._check_type_and_load_cert(
- msg=msg,
- key_type=key_types,
- cert_type=cert_types,
+ msg=msg, key_type=key_types, cert_type=cert_types
)
curvename = msg.get_text()
if curvename != self.ecdsa_curve.nist_name:
@@ -172,10 +182,10 @@ class ECDSAKey(PKey):
key_size_bytes = (key.curve.key_size + 7) // 8
x_bytes = deflate_long(numbers.x, add_sign_padding=False)
- x_bytes = b'\x00' * (key_size_bytes - len(x_bytes)) + x_bytes
+ x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes
y_bytes = deflate_long(numbers.y, add_sign_padding=False)
- y_bytes = b'\x00' * (key_size_bytes - len(y_bytes)) + y_bytes
+ y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes
point_str = four_byte + x_bytes + y_bytes
m.add_string(point_str)
@@ -185,8 +195,13 @@ class ECDSAKey(PKey):
return self.asbytes()
def __hash__(self):
- return hash((self.get_name(), self.verifying_key.public_numbers().x,
- self.verifying_key.public_numbers().y))
+ return hash(
+ (
+ self.get_name(),
+ self.verifying_key.public_numbers().x,
+ self.verifying_key.public_numbers().y,
+ )
+ )
def get_name(self):
return self.ecdsa_curve.key_format_identifier
@@ -228,7 +243,7 @@ class ECDSAKey(PKey):
filename,
self.signing_key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
def write_private_key(self, file_obj, password=None):
@@ -236,7 +251,7 @@ class ECDSAKey(PKey):
file_obj,
self.signing_key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
@classmethod
@@ -260,11 +275,11 @@ class ECDSAKey(PKey):
# ...internals...
def _from_private_key_file(self, filename, password):
- data = self._read_private_key_file('EC', filename, password)
+ data = self._read_private_key_file("EC", filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
- data = self._read_private_key('EC', file_obj, password)
+ data = self._read_private_key("EC", file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
diff --git a/paramiko/ed25519key.py b/paramiko/ed25519key.py
index 8ad71d08..68ada224 100644
--- a/paramiko/ed25519key.py
+++ b/paramiko/ed25519key.py
@@ -56,8 +56,10 @@ class Ed25519Key(PKey):
.. versionchanged:: 2.3
Added a ``file_obj`` parameter to match other key classes.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None,
- file_obj=None):
+
+ def __init__(
+ self, msg=None, data=None, filename=None, password=None, file_obj=None
+ ):
self.public_blob = None
verifying_key = signing_key = None
if msg is None and data is not None:
@@ -86,6 +88,7 @@ class Ed25519Key(PKey):
def _parse_signing_key_data(self, data, password):
from paramiko.transport import Transport
+
# We may eventually want this to be usable for other key types, as
# OpenSSH moves to it, but for now this is just for Ed25519 keys.
# This format is described here:
@@ -142,9 +145,9 @@ class Ed25519Key(PKey):
ignore_few_rounds=True,
)
decryptor = Cipher(
- cipher["class"](key[:cipher["key-size"]]),
- cipher["mode"](key[cipher["key-size"]:]),
- backend=default_backend()
+ cipher["class"](key[: cipher["key-size"]]),
+ cipher["mode"](key[cipher["key-size"] :]),
+ backend=default_backend(),
).decryptor()
private_data = (
decryptor.update(private_ciphertext) + decryptor.finalize()
@@ -166,8 +169,10 @@ class Ed25519Key(PKey):
signing_key = nacl.signing.SigningKey(key_data[:32])
# Verify that all the public keys are the same...
assert (
- signing_key.verify_key.encode() == public == public_keys[i] ==
- key_data[32:]
+ signing_key.verify_key.encode()
+ == public
+ == public_keys[i]
+ == key_data[32:]
)
signing_keys.append(signing_key)
# Comment, ignore.
diff --git a/paramiko/file.py b/paramiko/file.py
index df9cdac7..9e9f6eb8 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -16,14 +16,18 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from paramiko.common import (
- linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value,
+ linefeed_byte_value,
+ crlf,
+ cr_byte,
+ linefeed_byte,
+ cr_byte_value,
)
from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type
from paramiko.util import ClosingContextManager
-class BufferedFile (ClosingContextManager):
+class BufferedFile(ClosingContextManager):
"""
Reusable base class to implement Python-style file buffering around a
simpler stream.
@@ -70,7 +74,7 @@ class BufferedFile (ClosingContextManager):
:raises: ``ValueError`` -- if the file is closed.
"""
if self._closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
return self
def close(self):
@@ -90,6 +94,7 @@ class BufferedFile (ClosingContextManager):
return
if PY2:
+
def next(self):
"""
Returns the next line from the input, or raises
@@ -104,7 +109,9 @@ class BufferedFile (ClosingContextManager):
if not line:
raise StopIteration
return line
+
else:
+
def __next__(self):
"""
Returns the next line from the input, or raises ``StopIteration``
@@ -159,7 +166,7 @@ class BufferedFile (ClosingContextManager):
The number of bytes read.
"""
data = self.read(len(buff))
- buff[:len(data)] = data
+ buff[: len(data)] = data
return len(data)
def read(self, size=None):
@@ -180,9 +187,9 @@ class BufferedFile (ClosingContextManager):
encountered immediately
"""
if self._closed:
- raise IOError('File is closed')
+ raise IOError("File is closed")
if not (self._flags & self.FLAG_READ):
- raise IOError('File is not open for reading')
+ raise IOError("File is not open for reading")
if (size is None) or (size < 0):
# go for broke
result = self._rbuffer
@@ -245,16 +252,16 @@ class BufferedFile (ClosingContextManager):
"""
# it's almost silly how complex this function is.
if self._closed:
- raise IOError('File is closed')
+ raise IOError("File is closed")
if not (self._flags & self.FLAG_READ):
- raise IOError('File not open for reading')
+ raise IOError("File not open for reading")
line = self._rbuffer
truncated = False
while True:
if (
- self._at_trailing_cr and
- self._flags & self.FLAG_UNIVERSAL_NEWLINE and
- len(line) > 0
+ self._at_trailing_cr
+ and self._flags & self.FLAG_UNIVERSAL_NEWLINE
+ and len(line) > 0
):
# edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time.
@@ -276,12 +283,8 @@ class BufferedFile (ClosingContextManager):
n = size - len(line)
else:
n = self._bufsize
- if (
- linefeed_byte in line or
- (
- self._flags & self.FLAG_UNIVERSAL_NEWLINE and
- cr_byte in line
- )
+ if linefeed_byte in line or (
+ self._flags & self.FLAG_UNIVERSAL_NEWLINE and cr_byte in line
):
break
try:
@@ -306,9 +309,9 @@ class BufferedFile (ClosingContextManager):
return line if self._flags & self.FLAG_BINARY else u(line)
xpos = pos + 1
if (
- line[pos] == cr_byte_value and
- xpos < len(line) and
- line[xpos] == linefeed_byte_value
+ line[pos] == cr_byte_value
+ and xpos < len(line)
+ and line[xpos] == linefeed_byte_value
):
xpos += 1
# if the string was truncated, _rbuffer needs to have the string after
@@ -370,7 +373,7 @@ class BufferedFile (ClosingContextManager):
:raises: ``IOError`` -- if the file doesn't support random access.
"""
- raise IOError('File does not support seeking.')
+ raise IOError("File does not support seeking.")
def tell(self):
"""
@@ -393,11 +396,11 @@ class BufferedFile (ClosingContextManager):
"""
if isinstance(data, text_type):
# Accept text and encode as utf-8 for compatibility only.
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
if self._closed:
- raise IOError('File is closed')
+ raise IOError("File is closed")
if not (self._flags & self.FLAG_WRITE):
- raise IOError('File not open for writing')
+ raise IOError("File not open for writing")
if not (self._flags & self.FLAG_BUFFERED):
self._write_all(data)
return
@@ -408,9 +411,9 @@ class BufferedFile (ClosingContextManager):
if last_newline_pos >= 0:
wbuf = self._wbuffer.getvalue()
last_newline_pos += len(wbuf) - len(data)
- self._write_all(wbuf[:last_newline_pos + 1])
+ self._write_all(wbuf[: last_newline_pos + 1])
self._wbuffer = BytesIO()
- self._wbuffer.write(wbuf[last_newline_pos + 1:])
+ self._wbuffer.write(wbuf[last_newline_pos + 1 :])
return
# even if we're line buffering, if the buffer has grown past the
# buffer size, force a flush.
@@ -457,7 +460,7 @@ class BufferedFile (ClosingContextManager):
(subclass override)
Write data into the stream.
"""
- raise IOError('write not implemented')
+ raise IOError("write not implemented")
def _get_size(self):
"""
@@ -472,7 +475,7 @@ class BufferedFile (ClosingContextManager):
# ...internals...
- def _set_mode(self, mode='r', bufsize=-1):
+ def _set_mode(self, mode="r", bufsize=-1):
"""
Subclasses call this method to initialize the BufferedFile.
"""
@@ -495,17 +498,17 @@ class BufferedFile (ClosingContextManager):
# unbuffered
self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED)
- if ('r' in mode) or ('+' in mode):
+ if ("r" in mode) or ("+" in mode):
self._flags |= self.FLAG_READ
- if ('w' in mode) or ('+' in mode):
+ if ("w" in mode) or ("+" in mode):
self._flags |= self.FLAG_WRITE
- if 'a' in mode:
+ if "a" in mode:
self._flags |= self.FLAG_WRITE | self.FLAG_APPEND
self._size = self._get_size()
self._pos = self._realpos = self._size
- if 'b' in mode:
+ if "b" in mode:
self._flags |= self.FLAG_BINARY
- if 'U' in mode:
+ if "U" in mode:
self._flags |= self.FLAG_UNIVERSAL_NEWLINE
# built-in file objects have this attribute to store which kinds of
# line terminations they've seen:
@@ -534,9 +537,8 @@ class BufferedFile (ClosingContextManager):
return
if self.newlines is None:
self.newlines = newline
- elif (
- self.newlines != newline and
- isinstance(self.newlines, bytes_types)
+ elif self.newlines != newline and isinstance(
+ self.newlines, bytes_types
):
self.newlines = (self.newlines, newline)
elif newline not in self.newlines:
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
index ca185273..c4d611db 100644
--- a/paramiko/hostkeys.py
+++ b/paramiko/hostkeys.py
@@ -34,7 +34,7 @@ from paramiko.ed25519key import Ed25519Key
from paramiko.ssh_exception import SSHException
-class HostKeys (MutableMapping):
+class HostKeys(MutableMapping):
"""
Representation of an OpenSSH-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to
@@ -88,10 +88,10 @@ class HostKeys (MutableMapping):
:raises: ``IOError`` -- if there was an error reading the file
"""
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
for lineno, line in enumerate(f, 1):
line = line.strip()
- if (len(line) == 0) or (line[0] == '#'):
+ if (len(line) == 0) or (line[0] == "#"):
continue
try:
e = HostKeyEntry.from_line(line, lineno)
@@ -118,7 +118,7 @@ class HostKeys (MutableMapping):
.. versionadded:: 1.6.1
"""
- with open(filename, 'w') as f:
+ with open(filename, "w") as f:
for e in self._entries:
line = e.to_line()
if line:
@@ -134,7 +134,9 @@ class HostKeys (MutableMapping):
:return: dict of `str` -> `.PKey` keys associated with this host
(or ``None``)
"""
- class SubDict (MutableMapping):
+
+ class SubDict(MutableMapping):
+
def __init__(self, hostname, entries, hostkeys):
self._hostname = hostname
self._entries = entries
@@ -176,7 +178,8 @@ class HostKeys (MutableMapping):
def keys(self):
return [
- e.key.get_name() for e in self._entries
+ e.key.get_name()
+ for e in self._entries
if e.key is not None
]
@@ -196,10 +199,10 @@ class HostKeys (MutableMapping):
"""
for h in entry.hostnames:
if (
- h == hostname or
- h.startswith('|1|') and
- not hostname.startswith('|1|') and
- constant_time_bytes_eq(self.hash_host(hostname, h), h)
+ h == hostname
+ or h.startswith("|1|")
+ and not hostname.startswith("|1|")
+ and constant_time_bytes_eq(self.hash_host(hostname, h), h)
):
return True
return False
@@ -295,16 +298,17 @@ class HostKeys (MutableMapping):
if salt is None:
salt = os.urandom(sha1().digest_size)
else:
- if salt.startswith('|1|'):
- salt = salt.split('|')[2]
+ if salt.startswith("|1|"):
+ salt = salt.split("|")[2]
salt = decodebytes(b(salt))
assert len(salt) == sha1().digest_size
hmac = HMAC(salt, b(hostname), sha1).digest()
- hostkey = '|1|{}|{}'.format(u(encodebytes(salt)), u(encodebytes(hmac)))
- return hostkey.replace('\n', '')
+ hostkey = "|1|{}|{}".format(u(encodebytes(salt)), u(encodebytes(hmac)))
+ return hostkey.replace("\n", "")
class InvalidHostKey(Exception):
+
def __init__(self, line, exc):
self.line = line
self.exc = exc
@@ -334,8 +338,8 @@ class HostKeyEntry:
:param str line: a line from an OpenSSH known_hosts file
"""
- log = get_logger('paramiko.hostkeys')
- fields = line.split(' ')
+ log = get_logger("paramiko.hostkeys")
+ fields = line.split(" ")
if len(fields) < 3:
# Bad number of fields
msg = "Not enough fields found in known_hosts in line {} ({!r})"
@@ -344,19 +348,19 @@ class HostKeyEntry:
fields = fields[:3]
names, keytype, key = fields
- names = names.split(',')
+ names = names.split(",")
# Decide what kind of key we're looking at and create an object
# to hold it accordingly.
try:
key = b(key)
- if keytype == 'ssh-rsa':
+ if keytype == "ssh-rsa":
key = RSAKey(data=decodebytes(key))
- elif keytype == 'ssh-dss':
+ elif keytype == "ssh-dss":
key = DSSKey(data=decodebytes(key))
elif keytype in ECDSAKey.supported_key_format_identifiers():
key = ECDSAKey(data=decodebytes(key), validate_point=False)
- elif keytype == 'ssh-ed25519':
+ elif keytype == "ssh-ed25519":
key = Ed25519Key(data=decodebytes(key))
else:
log.info("Unable to handle key of type {}".format(keytype))
@@ -374,12 +378,12 @@ class HostKeyEntry:
included.
"""
if self.valid:
- return '{} {} {}\n'.format(
- ','.join(self.hostnames),
+ return "{} {} {}\n".format(
+ ",".join(self.hostnames),
self.key.get_name(),
self.key.get_base64(),
)
return None
def __repr__(self):
- return '<HostKeyEntry {!r}: {!r}>'.format(self.hostnames, self.key)
+ return "<HostKeyEntry {!r}: {!r}>".format(self.hostnames, self.key)
diff --git a/paramiko/kex_ecdh_nist.py b/paramiko/kex_ecdh_nist.py
index 4e8ff35d..1d87442a 100644
--- a/paramiko/kex_ecdh_nist.py
+++ b/paramiko/kex_ecdh_nist.py
@@ -15,7 +15,7 @@ _MSG_KEXECDH_INIT, _MSG_KEXECDH_REPLY = range(30, 32)
c_MSG_KEXECDH_INIT, c_MSG_KEXECDH_REPLY = [byte_chr(c) for c in range(30, 32)]
-class KexNistp256():
+class KexNistp256:
name = "ecdh-sha2-nistp256"
hash_algo = sha256
@@ -46,7 +46,7 @@ class KexNistp256():
elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY):
return self._parse_kexecdh_reply(m)
raise SSHException(
- 'KexECDH asked to handle packet type {:d}'.format(ptype)
+ "KexECDH asked to handle packet type {:d}".format(ptype)
)
def _generate_key_pair(self):
@@ -66,8 +66,12 @@ class KexNistp256():
K = long(hexlify(K), 16)
# compute exchange hash
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
hm.add_string(K_S)
hm.add_string(Q_C_bytes)
# SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
@@ -96,8 +100,12 @@ class KexNistp256():
K = long(hexlify(K), 16)
# compute exchange hash and verify signature
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
hm.add_string(K_S)
# SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
hm.add_string(self.Q_C.public_numbers().encode_point())
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py
index 44030569..fb8f01fd 100644
--- a/paramiko/kex_gex.py
+++ b/paramiko/kex_gex.py
@@ -32,17 +32,26 @@ from paramiko.py3compat import byte_chr, byte_ord, byte_mask
from paramiko.ssh_exception import SSHException
-_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
- _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
+(
+ _MSG_KEXDH_GEX_REQUEST_OLD,
+ _MSG_KEXDH_GEX_GROUP,
+ _MSG_KEXDH_GEX_INIT,
+ _MSG_KEXDH_GEX_REPLY,
+ _MSG_KEXDH_GEX_REQUEST,
+) = range(30, 35)
-c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \
- c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = \
- [byte_chr(c) for c in range(30, 35)]
+(
+ c_MSG_KEXDH_GEX_REQUEST_OLD,
+ c_MSG_KEXDH_GEX_GROUP,
+ c_MSG_KEXDH_GEX_INIT,
+ c_MSG_KEXDH_GEX_REPLY,
+ c_MSG_KEXDH_GEX_REQUEST,
+) = [byte_chr(c) for c in range(30, 35)]
-class KexGex (object):
+class KexGex(object):
- name = 'diffie-hellman-group-exchange-sha1'
+ name = "diffie-hellman-group-exchange-sha1"
min_bits = 1024
max_bits = 8192
preferred_bits = 2048
@@ -61,7 +70,8 @@ class KexGex (object):
def start_kex(self, _test_old_style=False):
if self.transport.server_mode:
self.transport._expect_packet(
- _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD)
+ _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD
+ )
return
# request a bit range: we accept (min_bits) to (max_bits), but prefer
# (preferred_bits). according to the spec, we shouldn't pull the
@@ -137,13 +147,12 @@ class KexGex (object):
# generate prime
pack = self.transport._get_modulus_pack()
if pack is None:
- raise SSHException(
- 'Can\'t do server-side gex with no modulus pack')
+ raise SSHException("Can't do server-side gex with no modulus pack")
self.transport._log(
DEBUG,
- 'Picking p ({} <= {} <= {} bits)'.format(
- minbits, preferredbits, maxbits,
- )
+ "Picking p ({} <= {} <= {} bits)".format(
+ minbits, preferredbits, maxbits
+ ),
)
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message()
@@ -165,13 +174,13 @@ class KexGex (object):
# generate prime
pack = self.transport._get_modulus_pack()
if pack is None:
- raise SSHException(
- 'Can\'t do server-side gex with no modulus pack')
+ raise SSHException("Can't do server-side gex with no modulus pack")
self.transport._log(
- DEBUG, 'Picking p (~ {} bits)'.format(self.preferred_bits)
+ DEBUG, "Picking p (~ {} bits)".format(self.preferred_bits)
)
self.g, self.p = pack.get_modulus(
- self.min_bits, self.preferred_bits, self.max_bits)
+ self.min_bits, self.preferred_bits, self.max_bits
+ )
m = Message()
m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p)
@@ -187,9 +196,10 @@ class KexGex (object):
bitlen = util.bit_length(self.p)
if (bitlen < 1024) or (bitlen > 8192):
raise SSHException(
- 'Server-generated gex p (don\'t ask) is out of range '
- '({} bits)'.format(bitlen))
- self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen))
+ "Server-generated gex p (don't ask) is out of range "
+ "({} bits)".format(bitlen)
+ )
+ self.transport._log(DEBUG, "Got server p ({} bits)".format(bitlen))
self._generate_x()
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
@@ -210,9 +220,13 @@ class KexGex (object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init,
- key)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ key,
+ )
if not self.old_style:
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
@@ -246,9 +260,13 @@ class KexGex (object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init,
- host_key)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ host_key,
+ )
if not self.old_style:
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
@@ -265,5 +283,5 @@ class KexGex (object):
class KexGexSHA256(KexGex):
- name = 'diffie-hellman-group-exchange-sha256'
+ name = "diffie-hellman-group-exchange-sha256"
hash_algo = sha256
diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py
index 1bebd375..66b7bb20 100644
--- a/paramiko/kex_group1.py
+++ b/paramiko/kex_group1.py
@@ -41,10 +41,12 @@ b0000000000000000 = zero_byte * 8
class KexGroup1(object):
# draft-ietf-secsh-transport-09.txt, page 17
- P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
+ P = (
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
+ )
G = 2
- name = 'diffie-hellman-group1-sha1'
+ name = "diffie-hellman-group1-sha1"
hash_algo = sha1
def __init__(self, transport):
@@ -88,8 +90,10 @@ class KexGroup1(object):
while 1:
x_bytes = os.urandom(128)
x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:]
- if (x_bytes[:8] != b7fffffffffffffff and
- x_bytes[:8] != b0000000000000000):
+ if (
+ x_bytes[:8] != b7fffffffffffffff
+ and x_bytes[:8] != b0000000000000000
+ ):
break
self.x = util.inflate_long(x_bytes)
@@ -104,8 +108,12 @@ class KexGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
hm.add_string(host_key)
hm.add_mpint(self.e)
hm.add_mpint(self.f)
@@ -124,8 +132,12 @@ class KexGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
hm.add_string(key)
hm.add_mpint(self.e)
hm.add_mpint(self.f)
diff --git a/paramiko/kex_group14.py b/paramiko/kex_group14.py
index 22955e34..29af2408 100644
--- a/paramiko/kex_group14.py
+++ b/paramiko/kex_group14.py
@@ -28,8 +28,10 @@ from hashlib import sha1
class KexGroup14(KexGroup1):
# http://tools.ietf.org/html/rfc3526#section-3
- P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
+ P = (
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
+ )
G = 2
- name = 'diffie-hellman-group14-sha1'
+ name = "diffie-hellman-group14-sha1"
hash_algo = sha1
diff --git a/paramiko/kex_gss.py b/paramiko/kex_gss.py
index e21620fe..1510ff9c 100644
--- a/paramiko/kex_gss.py
+++ b/paramiko/kex_gss.py
@@ -47,14 +47,22 @@ from paramiko.py3compat import byte_chr, byte_mask, byte_ord
from paramiko.ssh_exception import SSHException
-MSG_KEXGSS_INIT, MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_HOSTKEY,\
- MSG_KEXGSS_ERROR = range(30, 35)
-MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP = range(40, 42)
-c_MSG_KEXGSS_INIT, c_MSG_KEXGSS_CONTINUE, c_MSG_KEXGSS_COMPLETE,\
- c_MSG_KEXGSS_HOSTKEY, c_MSG_KEXGSS_ERROR = [
- byte_chr(c) for c in range(30, 35)
- ]
-c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP = [
+(
+ MSG_KEXGSS_INIT,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_ERROR,
+) = range(30, 35)
+(MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP) = range(40, 42)
+(
+ c_MSG_KEXGSS_INIT,
+ c_MSG_KEXGSS_CONTINUE,
+ c_MSG_KEXGSS_COMPLETE,
+ c_MSG_KEXGSS_HOSTKEY,
+ c_MSG_KEXGSS_ERROR,
+) = [byte_chr(c) for c in range(30, 35)]
+(c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP) = [
byte_chr(c) for c in range(40, 42)
]
@@ -65,7 +73,9 @@ class KexGSSGroup1(object):
4462 Section 2 <https://tools.ietf.org/html/rfc4462.html#section-2>`_
"""
# draft-ietf-secsh-transport-09.txt, page 17
- P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
+ P = (
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
+ )
G = 2
b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7 # noqa
b0000000000000000 = zero_byte * 8 # noqa
@@ -98,10 +108,12 @@ class KexGSSGroup1(object):
m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host))
m.add_mpint(self.e)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_HOSTKEY,
- MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_ERROR,
+ )
def parse_next(self, ptype, m):
"""
@@ -120,7 +132,7 @@ class KexGSSGroup1(object):
return self._parse_kexgss_complete(m)
elif ptype == MSG_KEXGSS_ERROR:
return self._parse_kexgss_error(m)
- msg = 'GSS KexGroup1 asked to handle packet type {:d}'
+ msg = "GSS KexGroup1 asked to handle packet type {:d}"
raise SSHException(msg.format(ptype))
# ## internals...
@@ -152,8 +164,7 @@ class KexGSSGroup1(object):
self.transport.host_key = host_key
sig = m.get_string()
self.transport._verify_key(host_key, sig)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE)
+ self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE)
def _parse_kexgss_continue(self, m):
"""
@@ -166,13 +177,14 @@ class KexGSSGroup1(object):
srv_token = m.get_string()
m = Message()
m.add_byte(c_MSG_KEXGSS_CONTINUE)
- m.add_string(self.kexgss.ssh_init_sec_context(
- target=self.gss_host, recv_token=srv_token))
+ m.add_string(
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ )
self.transport.send_message(m)
self.transport._expect_packet(
- MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
)
else:
pass
@@ -200,8 +212,12 @@ class KexGSSGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init)
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
hm.add_string(self.transport.host_key.__str__())
hm.add_mpint(self.e)
hm.add_mpint(self.f)
@@ -209,8 +225,9 @@ class KexGSSGroup1(object):
H = sha1(str(hm)).digest()
self.transport._set_K_H(K, H)
if srv_token is not None:
- self.kexgss.ssh_init_sec_context(target=self.gss_host,
- recv_token=srv_token)
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
self.kexgss.ssh_check_mic(mic_token, H)
else:
self.kexgss.ssh_check_mic(mic_token, H)
@@ -234,20 +251,26 @@ class KexGSSGroup1(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
hm.add_string(key)
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
H = sha1(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
- srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host,
- client_token)
+ srv_token = self.kexgss.ssh_accept_sec_context(
+ self.gss_host, client_token
+ )
m = Message()
if self.kexgss._gss_srv_ctxt_status:
- mic_token = self.kexgss.ssh_get_mic(self.transport.session_id,
- gss_kex=True)
+ mic_token = self.kexgss.ssh_get_mic(
+ self.transport.session_id, gss_kex=True
+ )
m.add_byte(c_MSG_KEXGSS_COMPLETE)
m.add_mpint(self.f)
m.add_string(mic_token)
@@ -263,9 +286,9 @@ class KexGSSGroup1(object):
m.add_byte(c_MSG_KEXGSS_CONTINUE)
m.add_string(srv_token)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
def _parse_kexgss_error(self, m):
"""
@@ -281,12 +304,16 @@ class KexGSSGroup1(object):
maj_status = m.get_int()
min_status = m.get_int()
err_msg = m.get_string()
- m.get_string() # we don't care about the language!
- raise SSHException("""GSS-API Error:
+ m.get_string() # we don't care about the language!
+ raise SSHException(
+ """GSS-API Error:
Major Status: {}
Minor Status: {}
Error Message: {}
-""".format(maj_status, min_status, err_msg))
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
class KexGSSGroup14(KexGSSGroup1):
@@ -295,7 +322,9 @@ class KexGSSGroup14(KexGSSGroup1):
in `RFC 4462 Section 2
<https://tools.ietf.org/html/rfc4462.html#section-2>`_
"""
- P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
+ P = (
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
+ )
G = 2
NAME = "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g=="
@@ -362,7 +391,7 @@ class KexGSSGex(object):
return self._parse_kexgss_complete(m)
elif ptype == MSG_KEXGSS_ERROR:
return self._parse_kexgss_error(m)
- msg = 'KexGex asked to handle packet type {:d}'
+ msg = "KexGex asked to handle packet type {:d}"
raise SSHException(msg.format(ptype))
# ## internals...
@@ -414,13 +443,12 @@ class KexGSSGex(object):
# generate prime
pack = self.transport._get_modulus_pack()
if pack is None:
- raise SSHException(
- 'Can\'t do server-side gex with no modulus pack')
+ raise SSHException("Can't do server-side gex with no modulus pack")
self.transport._log(
DEBUG, # noqa
- 'Picking p ({} <= {} <= {} bits)'.format(
- minbits, preferredbits, maxbits,
- )
+ "Picking p ({} <= {} <= {} bits)".format(
+ minbits, preferredbits, maxbits
+ ),
)
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message()
@@ -442,9 +470,12 @@ class KexGSSGex(object):
bitlen = util.bit_length(self.p)
if (bitlen < 1024) or (bitlen > 8192):
raise SSHException(
- 'Server-generated gex p (don\'t ask) is out of range '
- '({} bits)'.format(bitlen))
- self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen)) # noqa
+ "Server-generated gex p (don't ask) is out of range "
+ "({} bits)".format(bitlen)
+ )
+ self.transport._log(
+ DEBUG, "Got server p ({} bits)".format(bitlen)
+ ) # noqa
self._generate_x()
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
@@ -453,10 +484,12 @@ class KexGSSGex(object):
m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host))
m.add_mpint(self.e)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_HOSTKEY,
- MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_ERROR,
+ )
def _parse_kexgss_gex_init(self, m):
"""
@@ -476,9 +509,13 @@ class KexGSSGex(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.remote_version, self.transport.local_version,
- self.transport.remote_kex_init, self.transport.local_kex_init,
- key)
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ key,
+ )
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
hm.add_int(self.max_bits)
@@ -489,12 +526,14 @@ class KexGSSGex(object):
hm.add_mpint(K)
H = sha1(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
- srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host,
- client_token)
+ srv_token = self.kexgss.ssh_accept_sec_context(
+ self.gss_host, client_token
+ )
m = Message()
if self.kexgss._gss_srv_ctxt_status:
- mic_token = self.kexgss.ssh_get_mic(self.transport.session_id,
- gss_kex=True)
+ mic_token = self.kexgss.ssh_get_mic(
+ self.transport.session_id, gss_kex=True
+ )
m.add_byte(c_MSG_KEXGSS_COMPLETE)
m.add_mpint(self.f)
m.add_string(mic_token)
@@ -510,9 +549,9 @@ class KexGSSGex(object):
m.add_byte(c_MSG_KEXGSS_CONTINUE)
m.add_string(srv_token)
self.transport._send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
def _parse_kexgss_hostkey(self, m):
"""
@@ -525,8 +564,7 @@ class KexGSSGex(object):
self.transport.host_key = host_key
sig = m.get_string()
self.transport._verify_key(host_key, sig)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE)
+ self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE)
def _parse_kexgss_continue(self, m):
"""
@@ -538,12 +576,15 @@ class KexGSSGex(object):
srv_token = m.get_string()
m = Message()
m.add_byte(c_MSG_KEXGSS_CONTINUE)
- m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host,
- recv_token=srv_token))
+ m.add_string(
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ )
self.transport.send_message(m)
- self.transport._expect_packet(MSG_KEXGSS_CONTINUE,
- MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_ERROR)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
else:
pass
@@ -568,9 +609,13 @@ class KexGSSGex(object):
# okay, build up the hash H of
# (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
hm = Message()
- hm.add(self.transport.local_version, self.transport.remote_version,
- self.transport.local_kex_init, self.transport.remote_kex_init,
- self.transport.host_key.__str__())
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ self.transport.host_key.__str__(),
+ )
if not self.old_style:
hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
@@ -584,8 +629,9 @@ class KexGSSGex(object):
H = sha1(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
if srv_token is not None:
- self.kexgss.ssh_init_sec_context(target=self.gss_host,
- recv_token=srv_token)
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
self.kexgss.ssh_check_mic(mic_token, H)
else:
self.kexgss.ssh_check_mic(mic_token, H)
@@ -606,12 +652,16 @@ class KexGSSGex(object):
maj_status = m.get_int()
min_status = m.get_int()
err_msg = m.get_string()
- m.get_string() # we don't care about the language (lang_tag)!
- raise SSHException("""GSS-API Error:
+ m.get_string() # we don't care about the language (lang_tag)!
+ raise SSHException(
+ """GSS-API Error:
Major Status: {}
Minor Status: {}
Error Message: {}
-""".format(maj_status, min_status, err_msg))
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
class NullHostKey(object):
@@ -620,6 +670,7 @@ class NullHostKey(object):
in `RFC 4462 Section 5
<https://tools.ietf.org/html/rfc4462.html#section-5>`_
"""
+
def __init__(self):
self.key = ""
diff --git a/paramiko/message.py b/paramiko/message.py
index 9af841da..dead3508 100644
--- a/paramiko/message.py
+++ b/paramiko/message.py
@@ -27,7 +27,7 @@ from paramiko.common import zero_byte, max_byte, one_byte, asbytes
from paramiko.py3compat import long, BytesIO, u, integer_types
-class Message (object):
+class Message(object):
"""
An SSH2 message is a stream of bytes that encodes some combination of
strings, integers, bools, and infinite-precision integers (known in Python
@@ -63,7 +63,7 @@ class Message (object):
"""
Returns a string representation of this object, for debugging.
"""
- return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
+ return "paramiko.Message(" + repr(self.packet.getvalue()) + ")"
def asbytes(self):
"""
@@ -139,13 +139,13 @@ class Message (object):
if byte == max_byte:
return util.inflate_long(self.get_binary())
byte += self.get_bytes(3)
- return struct.unpack('>I', byte)[0]
+ return struct.unpack(">I", byte)[0]
def get_int(self):
"""
Fetch an int from the stream.
"""
- return struct.unpack('>I', self.get_bytes(4))[0]
+ return struct.unpack(">I", self.get_bytes(4))[0]
def get_int64(self):
"""
@@ -153,7 +153,7 @@ class Message (object):
:return: a 64-bit unsigned integer (`long`).
"""
- return struct.unpack('>Q', self.get_bytes(8))[0]
+ return struct.unpack(">Q", self.get_bytes(8))[0]
def get_mpint(self):
"""
@@ -191,7 +191,7 @@ class Message (object):
These are trivially encoded as comma-separated values in a string.
"""
- return self.get_text().split(',')
+ return self.get_text().split(",")
def add_bytes(self, b):
"""
@@ -229,7 +229,7 @@ class Message (object):
:param int n: integer to add
"""
- self.packet.write(struct.pack('>I', n))
+ self.packet.write(struct.pack(">I", n))
return self
def add_adaptive_int(self, n):
@@ -242,7 +242,7 @@ class Message (object):
self.packet.write(max_byte)
self.add_string(util.deflate_long(n))
else:
- self.packet.write(struct.pack('>I', n))
+ self.packet.write(struct.pack(">I", n))
return self
def add_int64(self, n):
@@ -251,7 +251,7 @@ class Message (object):
:param long n: long int to add
"""
- self.packet.write(struct.pack('>Q', n))
+ self.packet.write(struct.pack(">Q", n))
return self
def add_mpint(self, z):
@@ -283,7 +283,7 @@ class Message (object):
:param l: list of strings to add
"""
- self.add_string(','.join(l))
+ self.add_string(",".join(l))
return self
def _add(self, i):
diff --git a/paramiko/packet.py b/paramiko/packet.py
index 2a1e91e2..d324fc35 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -30,7 +30,12 @@ from hmac import HMAC
from paramiko import util
from paramiko.common import (
- linefeed_byte, cr_byte_value, asbytes, MSG_NAMES, DEBUG, xffffffff,
+ linefeed_byte,
+ cr_byte_value,
+ asbytes,
+ MSG_NAMES,
+ DEBUG,
+ xffffffff,
zero_byte,
)
from paramiko.py3compat import u, byte_ord
@@ -42,7 +47,7 @@ def compute_hmac(key, message, digest_class):
return HMAC(key, message, digest_class).digest()
-class NeedRekeyException (Exception):
+class NeedRekeyException(Exception):
"""
Exception indicating a rekey is needed.
"""
@@ -56,7 +61,7 @@ def first_arg(e):
return arg
-class Packetizer (object):
+class Packetizer(object):
"""
Implementation of the base SSH packet protocol.
"""
@@ -128,8 +133,15 @@ class Packetizer (object):
"""
self.__logger = log
- def set_outbound_cipher(self, block_engine, block_size, mac_engine,
- mac_size, mac_key, sdctr=False):
+ def set_outbound_cipher(
+ self,
+ block_engine,
+ block_size,
+ mac_engine,
+ mac_size,
+ mac_key,
+ sdctr=False,
+ ):
"""
Switch outbound data cipher.
"""
@@ -149,7 +161,8 @@ class Packetizer (object):
self.__need_rekey = False
def set_inbound_cipher(
- self, block_engine, block_size, mac_engine, mac_size, mac_key):
+ self, block_engine, block_size, mac_engine, mac_size, mac_key
+ ):
"""
Switch inbound data cipher.
"""
@@ -352,7 +365,7 @@ class Packetizer (object):
while linefeed_byte not in buf:
buf += self._read_timeout(timeout)
n = buf.index(linefeed_byte)
- self.__remainder = buf[n + 1:]
+ self.__remainder = buf[n + 1 :]
buf = buf[:n]
if (len(buf) > 0) and (buf[-1] == cr_byte_value):
buf = buf[:-1]
@@ -368,7 +381,7 @@ class Packetizer (object):
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
- cmd_name = '${:x}'.format(cmd)
+ cmd_name = "${:x}".format(cmd)
orig_len = len(data)
self.__write_lock.acquire()
try:
@@ -378,37 +391,38 @@ class Packetizer (object):
if self.__dump_packets:
self._log(
DEBUG,
- 'Write packet <{}>, length {}'.format(cmd_name, orig_len)
+ "Write packet <{}>, length {}".format(cmd_name, orig_len),
)
- self._log(DEBUG, util.format_binary(packet, 'OUT: '))
+ self._log(DEBUG, util.format_binary(packet, "OUT: "))
if self.__block_engine_out is not None:
out = self.__block_engine_out.update(packet)
else:
out = packet
# + mac
if self.__block_engine_out is not None:
- payload = struct.pack(
- '>I', self.__sequence_number_out) + packet
+ payload = (
+ struct.pack(">I", self.__sequence_number_out) + packet
+ )
out += compute_hmac(
- self.__mac_key_out,
- payload,
- self.__mac_engine_out)[:self.__mac_size_out]
- self.__sequence_number_out = \
- (self.__sequence_number_out + 1) & xffffffff
+ self.__mac_key_out, payload, self.__mac_engine_out
+ )[: self.__mac_size_out]
+ self.__sequence_number_out = (
+ self.__sequence_number_out + 1
+ ) & xffffffff
self.write_all(out)
self.__sent_bytes += len(out)
self.__sent_packets += 1
sent_too_much = (
- self.__sent_packets >= self.REKEY_PACKETS or
- self.__sent_bytes >= self.REKEY_BYTES
+ self.__sent_packets >= self.REKEY_PACKETS
+ or self.__sent_bytes >= self.REKEY_BYTES
)
if sent_too_much and not self.__need_rekey:
# only ask once for rekeying
msg = "Rekeying (hit {} packets, {} bytes sent)"
- self._log(DEBUG, msg.format(
- self.__sent_packets, self.__sent_bytes,
- ))
+ self._log(
+ DEBUG, msg.format(self.__sent_packets, self.__sent_bytes)
+ )
self.__received_bytes_overflow = 0
self.__received_packets_overflow = 0
self._trigger_rekey()
@@ -427,41 +441,42 @@ class Packetizer (object):
if self.__block_engine_in is not None:
header = self.__block_engine_in.update(header)
if self.__dump_packets:
- self._log(DEBUG, util.format_binary(header, 'IN: '))
- packet_size = struct.unpack('>I', header[:4])[0]
+ self._log(DEBUG, util.format_binary(header, "IN: "))
+ packet_size = struct.unpack(">I", header[:4])[0]
# leftover contains decrypted bytes from the first block (after the
# length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.__block_size_in != 0:
- raise SSHException('Invalid packet blocking')
+ raise SSHException("Invalid packet blocking")
buf = self.read_all(packet_size + self.__mac_size_in - len(leftover))
- packet = buf[:packet_size - len(leftover)]
- post_packet = buf[packet_size - len(leftover):]
+ packet = buf[: packet_size - len(leftover)]
+ post_packet = buf[packet_size - len(leftover) :]
if self.__block_engine_in is not None:
packet = self.__block_engine_in.update(packet)
if self.__dump_packets:
- self._log(DEBUG, util.format_binary(packet, 'IN: '))
+ self._log(DEBUG, util.format_binary(packet, "IN: "))
packet = leftover + packet
if self.__mac_size_in > 0:
- mac = post_packet[:self.__mac_size_in]
- mac_payload = struct.pack(
- '>II', self.__sequence_number_in, packet_size) + packet
+ mac = post_packet[: self.__mac_size_in]
+ mac_payload = (
+ struct.pack(">II", self.__sequence_number_in, packet_size)
+ + packet
+ )
my_mac = compute_hmac(
- self.__mac_key_in,
- mac_payload,
- self.__mac_engine_in)[:self.__mac_size_in]
+ self.__mac_key_in, mac_payload, self.__mac_engine_in
+ )[: self.__mac_size_in]
if not util.constant_time_bytes_eq(my_mac, mac):
- raise SSHException('Mismatched MAC')
+ raise SSHException("Mismatched MAC")
padding = byte_ord(packet[0])
- payload = packet[1:packet_size - padding]
+ payload = packet[1 : packet_size - padding]
if self.__dump_packets:
self._log(
DEBUG,
- 'Got payload ({} bytes, {} padding)'.format(
+ "Got payload ({} bytes, {} padding)".format(
packet_size, padding
- )
+ ),
)
if self.__compress_engine_in is not None:
@@ -480,19 +495,24 @@ class Packetizer (object):
# dropping the connection
self.__received_bytes_overflow += raw_packet_size
self.__received_packets_overflow += 1
- if (self.__received_packets_overflow >=
- self.REKEY_PACKETS_OVERFLOW_MAX) or \
- (self.__received_bytes_overflow >=
- self.REKEY_BYTES_OVERFLOW_MAX):
+ if (
+ self.__received_packets_overflow
+ >= self.REKEY_PACKETS_OVERFLOW_MAX
+ ) or (
+ self.__received_bytes_overflow >= self.REKEY_BYTES_OVERFLOW_MAX
+ ):
raise SSHException(
- 'Remote transport is ignoring rekey requests')
- elif (self.__received_packets >= self.REKEY_PACKETS) or \
- (self.__received_bytes >= self.REKEY_BYTES):
+ "Remote transport is ignoring rekey requests"
+ )
+ elif (self.__received_packets >= self.REKEY_PACKETS) or (
+ self.__received_bytes >= self.REKEY_BYTES
+ ):
# only ask once for rekeying
err = "Rekeying (hit {} packets, {} bytes received)"
- self._log(DEBUG, err.format(
- self.__received_packets, self.__received_bytes,
- ))
+ self._log(
+ DEBUG,
+ err.format(self.__received_packets, self.__received_bytes),
+ )
self.__received_bytes_overflow = 0
self.__received_packets_overflow = 0
self._trigger_rekey()
@@ -501,11 +521,11 @@ class Packetizer (object):
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
- cmd_name = '${:x}'.format(cmd)
+ cmd_name = "${:x}".format(cmd)
if self.__dump_packets:
self._log(
DEBUG,
- 'Read packet <{}>, length {}'.format(cmd_name, len(payload))
+ "Read packet <{}>, length {}".format(cmd_name, len(payload)),
)
return cmd, msg
@@ -522,9 +542,9 @@ class Packetizer (object):
def _check_keepalive(self):
if (
- not self.__keepalive_interval or
- not self.__block_engine_out or
- self.__need_rekey
+ not self.__keepalive_interval
+ or not self.__block_engine_out
+ or self.__need_rekey
):
# wait till we're encrypting, and not in the middle of rekeying
return
@@ -559,13 +579,13 @@ class Packetizer (object):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.__block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
- packet = struct.pack('>IB', len(payload) + padding + 1, padding)
+ packet = struct.pack(">IB", len(payload) + padding + 1, padding)
packet += payload
if self.__sdctr_out or self.__block_engine_out is None:
# cute trick i caught openssh doing: if we're not encrypting or
# SDCTR mode (RFC4344),
# don't waste random bytes for the padding
- packet += (zero_byte * padding)
+ packet += zero_byte * padding
else:
packet += os.urandom(padding)
return packet
diff --git a/paramiko/pipe.py b/paramiko/pipe.py
index 6ca37703..e88a5e44 100644
--- a/paramiko/pipe.py
+++ b/paramiko/pipe.py
@@ -31,14 +31,15 @@ import socket
def make_pipe():
- if sys.platform[:3] != 'win':
+ if sys.platform[:3] != "win":
p = PosixPipe()
else:
p = WindowsPipe()
return p
-class PosixPipe (object):
+class PosixPipe(object):
+
def __init__(self):
self._rfd, self._wfd = os.pipe()
self._set = False
@@ -64,26 +65,27 @@ class PosixPipe (object):
if self._set or self._closed:
return
self._set = True
- os.write(self._wfd, b'*')
+ os.write(self._wfd, b"*")
def set_forever(self):
self._forever = True
self.set()
-class WindowsPipe (object):
+class WindowsPipe(object):
"""
On Windows, only an OS-level "WinSock" may be used in select(), but reads
and writes must be to the actual socket object.
"""
+
def __init__(self):
serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- serv.bind(('127.0.0.1', 0))
+ serv.bind(("127.0.0.1", 0))
serv.listen(1)
# need to save sockets in _rsock/_wsock so they don't get closed
self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self._rsock.connect(('127.0.0.1', serv.getsockname()[1]))
+ self._rsock.connect(("127.0.0.1", serv.getsockname()[1]))
self._wsock, addr = serv.accept()
serv.close()
@@ -110,14 +112,15 @@ class WindowsPipe (object):
if self._set or self._closed:
return
self._set = True
- self._wsock.send(b'*')
+ self._wsock.send(b"*")
def set_forever(self):
self._forever = True
self.set()
-class OrPipe (object):
+class OrPipe(object):
+
def __init__(self, pipe):
self._set = False
self._partner = None
diff --git a/paramiko/pkey.py b/paramiko/pkey.py
index 808215f8..fa014800 100644
--- a/paramiko/pkey.py
+++ b/paramiko/pkey.py
@@ -43,23 +43,23 @@ class PKey(object):
# known encryption types for private key files:
_CIPHER_TABLE = {
- 'AES-128-CBC': {
- 'cipher': algorithms.AES,
- 'keysize': 16,
- 'blocksize': 16,
- 'mode': modes.CBC
+ "AES-128-CBC": {
+ "cipher": algorithms.AES,
+ "keysize": 16,
+ "blocksize": 16,
+ "mode": modes.CBC,
},
- 'AES-256-CBC': {
- 'cipher': algorithms.AES,
- 'keysize': 32,
- 'blocksize': 16,
- 'mode': modes.CBC
+ "AES-256-CBC": {
+ "cipher": algorithms.AES,
+ "keysize": 32,
+ "blocksize": 16,
+ "mode": modes.CBC,
},
- 'DES-EDE3-CBC': {
- 'cipher': algorithms.TripleDES,
- 'keysize': 24,
- 'blocksize': 8,
- 'mode': modes.CBC
+ "DES-EDE3-CBC": {
+ "cipher": algorithms.TripleDES,
+ "keysize": 24,
+ "blocksize": 8,
+ "mode": modes.CBC,
},
}
@@ -107,7 +107,7 @@ class PKey(object):
hs = hash(self)
ho = hash(other)
if hs != ho:
- return cmp(hs, ho) # noqa
+ return cmp(hs, ho) # noqa
return cmp(self.asbytes(), other.asbytes()) # noqa
def __eq__(self, other):
@@ -121,7 +121,7 @@ class PKey(object):
name of this private key type, in SSH terminology, as a `str` (for
example, ``"ssh-rsa"``).
"""
- return ''
+ return ""
def get_bits(self):
"""
@@ -158,7 +158,7 @@ class PKey(object):
:return: a base64 `string <str>` containing the public part of the key.
"""
- return u(encodebytes(self.asbytes())).replace('\n', '')
+ return u(encodebytes(self.asbytes())).replace("\n", "")
def sign_ssh_data(self, data):
"""
@@ -239,7 +239,7 @@ class PKey(object):
:raises: ``IOError`` -- if there was an error writing the file
:raises: `.SSHException` -- if the key is invalid
"""
- raise Exception('Not implemented in PKey')
+ raise Exception("Not implemented in PKey")
def write_private_key(self, file_obj, password=None):
"""
@@ -252,7 +252,7 @@ class PKey(object):
:raises: ``IOError`` -- if there was an error writing to the file
:raises: `.SSHException` -- if the key is invalid
"""
- raise Exception('Not implemented in PKey')
+ raise Exception("Not implemented in PKey")
def _read_private_key_file(self, tag, filename, password=None):
"""
@@ -275,60 +275,61 @@ class PKey(object):
encrypted, and ``password`` is ``None``.
:raises: `.SSHException` -- if the key file is invalid.
"""
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
data = self._read_private_key(tag, f, password)
return data
def _read_private_key(self, tag, f, password=None):
lines = f.readlines()
start = 0
- beginning_of_key = '-----BEGIN ' + tag + ' PRIVATE KEY-----'
+ beginning_of_key = "-----BEGIN " + tag + " PRIVATE KEY-----"
while start < len(lines) and lines[start].strip() != beginning_of_key:
start += 1
if start >= len(lines):
- raise SSHException('not a valid ' + tag + ' private key file')
+ raise SSHException("not a valid " + tag + " private key file")
# parse any headers first
headers = {}
start += 1
while start < len(lines):
- l = lines[start].split(': ')
+ l = lines[start].split(": ")
if len(l) == 1:
break
headers[l[0].lower()] = l[1].strip()
start += 1
# find end
end = start
- ending_of_key = '-----END ' + tag + ' PRIVATE KEY-----'
+ ending_of_key = "-----END " + tag + " PRIVATE KEY-----"
while end < len(lines) and lines[end].strip() != ending_of_key:
end += 1
# if we trudged to the end of the file, just try to cope.
try:
- data = decodebytes(b(''.join(lines[start:end])))
+ data = decodebytes(b("".join(lines[start:end])))
except base64.binascii.Error as e:
- raise SSHException('base64 decoding error: ' + str(e))
- if 'proc-type' not in headers:
+ raise SSHException("base64 decoding error: " + str(e))
+ if "proc-type" not in headers:
# unencryped: done
return data
# encrypted keyfile: will need a password
- proc_type = headers['proc-type']
- if proc_type != '4,ENCRYPTED':
+ proc_type = headers["proc-type"]
+ if proc_type != "4,ENCRYPTED":
raise SSHException(
'Unknown private key structure "{}"'.format(proc_type)
)
try:
- encryption_type, saltstr = headers['dek-info'].split(',')
+ encryption_type, saltstr = headers["dek-info"].split(",")
except:
raise SSHException("Can't parse DEK-info in private key file")
if encryption_type not in self._CIPHER_TABLE:
raise SSHException(
- 'Unknown private key cipher "{}"'.format(encryption_type))
+ 'Unknown private key cipher "{}"'.format(encryption_type)
+ )
# if no password was passed in,
# raise an exception pointing out that we need one
if password is None:
- raise PasswordRequiredException('Private key file is encrypted')
- cipher = self._CIPHER_TABLE[encryption_type]['cipher']
- keysize = self._CIPHER_TABLE[encryption_type]['keysize']
- mode = self._CIPHER_TABLE[encryption_type]['mode']
+ raise PasswordRequiredException("Private key file is encrypted")
+ cipher = self._CIPHER_TABLE[encryption_type]["cipher"]
+ keysize = self._CIPHER_TABLE[encryption_type]["keysize"]
+ mode = self._CIPHER_TABLE[encryption_type]["mode"]
salt = unhexlify(b(saltstr))
key = util.generate_key_bytes(md5, salt, password, keysize)
decryptor = Cipher(
@@ -351,7 +352,7 @@ class PKey(object):
:raises: ``IOError`` -- if there was an error writing the file.
"""
- with open(filename, 'w') as f:
+ with open(filename, "w") as f:
os.chmod(filename, o600)
self._write_private_key(f, key, format, password=password)
@@ -361,11 +362,11 @@ class PKey(object):
else:
encryption = serialization.BestAvailableEncryption(b(password))
- f.write(key.private_bytes(
- serialization.Encoding.PEM,
- format,
- encryption
- ).decode())
+ f.write(
+ key.private_bytes(
+ serialization.Encoding.PEM, format, encryption
+ ).decode()
+ )
def _check_type_and_load_cert(self, msg, key_type, cert_type):
"""
@@ -388,7 +389,7 @@ class PKey(object):
cert_types = [cert_types]
# Can't do much with no message, that should've been handled elsewhere
if msg is None:
- raise SSHException('Key object may not be empty')
+ raise SSHException("Key object may not be empty")
# First field is always key type, in either kind of object. (make sure
# we rewind before grabbing it - sometimes caller had to do their own
# introspection first!)
@@ -411,7 +412,7 @@ class PKey(object):
# (requires going back into per-type subclasses.)
msg.get_string()
else:
- err = 'Invalid key (class: {}, data type: {}'
+ err = "Invalid key (class: {}, data type: {}"
raise SSHException(err.format(self.__class__.__name__, type_))
def load_certificate(self, value):
@@ -434,11 +435,11 @@ class PKey(object):
successfully.
"""
if isinstance(value, Message):
- constructor = 'from_message'
+ constructor = "from_message"
elif os.path.isfile(value):
- constructor = 'from_file'
+ constructor = "from_file"
else:
- constructor = 'from_string'
+ constructor = "from_string"
blob = getattr(PublicBlob, constructor)(value)
if not blob.key_type.startswith(self.get_name()):
err = "PublicBlob type {} incompatible with key type {}"
@@ -464,6 +465,7 @@ class PublicBlob(object):
`from_message` for useful instantiation, the main constructor is
basically "I should be using ``attrs`` for this."
"""
+
def __init__(self, type_, blob, comment=None):
"""
Create a new public blob of given type and contents.
@@ -505,8 +507,10 @@ class PublicBlob(object):
m = Message(key_blob)
blob_type = m.get_text()
if blob_type != key_type:
- msg = "Invalid PublicBlob contents: key type={!r}, but blob type={!r}" # noqa
- raise ValueError(msg.format(key_type, blob_type))
+ deets = "key type={!r}, but blob type={!r}".format(
+ key_type, blob_type
+ )
+ raise ValueError("Invalid PublicBlob contents: {}".format(deets))
# All good? All good.
return cls(type_=key_type, blob=key_blob, comment=comment)
@@ -522,7 +526,7 @@ class PublicBlob(object):
return cls(type_=type_, blob=message.asbytes())
def __str__(self):
- ret = '{} public key/certificate'.format(self.key_type)
+ ret = "{} public key/certificate".format(self.key_type)
if self.comment:
ret += "- {}".format(self.comment)
return ret
diff --git a/paramiko/primes.py b/paramiko/primes.py
index ca8f9bec..8dff7683 100644
--- a/paramiko/primes.py
+++ b/paramiko/primes.py
@@ -49,7 +49,7 @@ def _roll_random(n):
return num
-class ModulusPack (object):
+class ModulusPack(object):
"""
convenience object for holding the contents of the /etc/ssh/moduli file,
on systems that have such a file.
@@ -61,8 +61,15 @@ class ModulusPack (object):
self.discarded = []
def _parse_modulus(self, line):
- timestamp, mod_type, tests, tries, size, generator, modulus = \
- line.split()
+ (
+ timestamp,
+ mod_type,
+ tests,
+ tries,
+ size,
+ generator,
+ modulus,
+ ) = line.split()
mod_type = int(mod_type)
tests = int(tests)
tries = int(tries)
@@ -75,12 +82,13 @@ class ModulusPack (object):
# test 4 (more than just a small-prime sieve)
# tries < 100 if test & 4 (at least 100 tries of miller-rabin)
if (
- mod_type < 2 or
- tests < 4 or
- (tests & 4 and tests < 8 and tries < 100)
+ mod_type < 2
+ or tests < 4
+ or (tests & 4 and tests < 8 and tries < 100)
):
self.discarded.append(
- (modulus, 'does not meet basic requirements'))
+ (modulus, "does not meet basic requirements")
+ )
return
if generator == 0:
generator = 2
@@ -91,7 +99,8 @@ class ModulusPack (object):
bl = util.bit_length(modulus)
if (bl != size) and (bl != size + 1):
self.discarded.append(
- (modulus, 'incorrectly reported bit length {}'.format(size)))
+ (modulus, "incorrectly reported bit length {}".format(size))
+ )
return
if bl not in self.pack:
self.pack[bl] = []
@@ -102,10 +111,10 @@ class ModulusPack (object):
:raises IOError: passed from any file operations that fail.
"""
self.pack = {}
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
for line in f:
line = line.strip()
- if (len(line) == 0) or (line[0] == '#'):
+ if (len(line) == 0) or (line[0] == "#"):
continue
try:
self._parse_modulus(line)
@@ -115,7 +124,7 @@ class ModulusPack (object):
def get_modulus(self, min, prefer, max):
bitsizes = sorted(self.pack.keys())
if len(bitsizes) == 0:
- raise SSHException('no moduli available')
+ raise SSHException("no moduli available")
good = -1
# find nearest bitsize >= preferred
for b in bitsizes:
diff --git a/paramiko/proxy.py b/paramiko/proxy.py
index c4ec627c..444c47b6 100644
--- a/paramiko/proxy.py
+++ b/paramiko/proxy.py
@@ -39,6 +39,7 @@ class ProxyCommand(ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
def __init__(self, command_line):
"""
Create a new CommandProxy instance. The instance created by this
@@ -50,9 +51,11 @@ class ProxyCommand(ClosingContextManager):
# NOTE: subprocess import done lazily so platforms without it (e.g.
# GAE) can still import us during overall Paramiko load.
from subprocess import Popen, PIPE
+
self.cmd = shlsplit(command_line)
- self.process = Popen(self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE,
- bufsize=0)
+ self.process = Popen(
+ self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=0
+ )
self.timeout = None
def send(self, content):
@@ -69,7 +72,7 @@ class ProxyCommand(ClosingContextManager):
# died and we can't proceed. The best option here is to
# raise an exception informing the user that the informed
# ProxyCommand is not working.
- raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
+ raise ProxyCommandFailure(" ".join(self.cmd), e.strerror)
return len(content)
def recv(self, size):
@@ -81,21 +84,21 @@ class ProxyCommand(ClosingContextManager):
:return: the string of bytes read, which may be shorter than requested
"""
try:
- buffer = b''
+ buffer = b""
start = time.time()
while len(buffer) < size:
select_timeout = None
if self.timeout is not None:
- elapsed = (time.time() - start)
+ elapsed = time.time() - start
if elapsed >= self.timeout:
raise socket.timeout()
select_timeout = self.timeout - elapsed
- r, w, x = select(
- [self.process.stdout], [], [], select_timeout)
+ r, w, x = select([self.process.stdout], [], [], select_timeout)
if r and r[0] == self.process.stdout:
buffer += os.read(
- self.process.stdout.fileno(), size - len(buffer))
+ self.process.stdout.fileno(), size - len(buffer)
+ )
return buffer
except socket.timeout:
if buffer:
@@ -103,7 +106,7 @@ class ProxyCommand(ClosingContextManager):
return buffer
raise # socket.timeout is a subclass of IOError
except IOError as e:
- raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
+ raise ProxyCommandFailure(" ".join(self.cmd), e.strerror)
def close(self):
os.kill(self.process.pid, signal.SIGTERM)
diff --git a/paramiko/py3compat.py b/paramiko/py3compat.py
index 67c0f200..e1f33fe9 100644
--- a/paramiko/py3compat.py
+++ b/paramiko/py3compat.py
@@ -2,10 +2,28 @@ import sys
import base64
__all__ = [
- 'BytesIO', 'MAXSIZE', 'PY2', 'StringIO', 'b', 'b2s', 'builtins',
- 'byte_chr', 'byte_mask', 'byte_ord', 'bytes', 'bytes_types', 'decodebytes',
- 'encodebytes', 'input', 'integer_types', 'is_callable', 'long', 'next',
- 'string_types', 'text_type', 'u',
+ "BytesIO",
+ "MAXSIZE",
+ "PY2",
+ "StringIO",
+ "b",
+ "b2s",
+ "builtins",
+ "byte_chr",
+ "byte_mask",
+ "byte_ord",
+ "bytes",
+ "bytes_types",
+ "decodebytes",
+ "encodebytes",
+ "input",
+ "integer_types",
+ "is_callable",
+ "long",
+ "next",
+ "string_types",
+ "text_type",
+ "u",
]
PY2 = sys.version_info[0] < 3
@@ -23,16 +41,13 @@ if PY2:
import __builtin__ as builtins
-
byte_ord = ord # NOQA
byte_chr = chr # NOQA
-
def byte_mask(c, mask):
return chr(ord(c) & mask)
-
- def b(s, encoding='utf8'): # NOQA
+ def b(s, encoding="utf8"): # NOQA
"""cast unicode or bytes to bytes"""
if isinstance(s, str):
return s
@@ -43,8 +58,7 @@ if PY2:
else:
raise TypeError("Expected unicode or bytes, got {!r}".format(s))
-
- def u(s, encoding='utf8'): # NOQA
+ def u(s, encoding="utf8"): # NOQA
"""cast bytes or unicode to unicode"""
if isinstance(s, str):
return s.decode(encoding)
@@ -55,53 +69,52 @@ if PY2:
else:
raise TypeError("Expected unicode or bytes, got {!r}".format(s))
-
def b2s(s):
return s
-
import cStringIO
+
StringIO = cStringIO.StringIO
BytesIO = StringIO
-
def is_callable(c): # NOQA
return callable(c)
-
def get_next(c): # NOQA
return c.next
-
def next(c):
return c.next()
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
+
def __len__(self):
return 1 << 31
-
try:
len(X())
except OverflowError:
# 32-bit
- MAXSIZE = int((1 << 31) - 1) # NOQA
+ MAXSIZE = int((1 << 31) - 1) # NOQA
else:
# 64-bit
- MAXSIZE = int((1 << 63) - 1) # NOQA
+ MAXSIZE = int((1 << 63) - 1) # NOQA
del X
else:
import collections
import struct
import builtins
+
string_types = str
text_type = str
bytes = bytes
bytes_types = bytes
integer_types = int
+
class long(int):
pass
+
input = input
decodebytes = base64.decodebytes
encodebytes = base64.encodebytes
@@ -114,13 +127,13 @@ else:
def byte_chr(c):
assert isinstance(c, int)
- return struct.pack('B', c)
+ return struct.pack("B", c)
def byte_mask(c, mask):
assert isinstance(c, int)
- return struct.pack('B', c & mask)
+ return struct.pack("B", c & mask)
- def b(s, encoding='utf8'):
+ def b(s, encoding="utf8"):
"""cast unicode or bytes to bytes"""
if isinstance(s, bytes):
return s
@@ -129,7 +142,7 @@ else:
else:
raise TypeError("Expected unicode or bytes, got {!r}".format(s))
- def u(s, encoding='utf8'):
+ def u(s, encoding="utf8"):
"""cast bytes or unicode to unicode"""
if isinstance(s, bytes):
return s.decode(encoding)
@@ -142,8 +155,9 @@ else:
return s.decode() if isinstance(s, bytes) else s
import io
- StringIO = io.StringIO # NOQA
- BytesIO = io.BytesIO # NOQA
+
+ StringIO = io.StringIO # NOQA
+ BytesIO = io.BytesIO # NOQA
def is_callable(c):
return isinstance(c, collections.Callable)
@@ -153,4 +167,4 @@ else:
next = next
- MAXSIZE = sys.maxsize # NOQA
+ MAXSIZE = sys.maxsize # NOQA
diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py
index 8dfcfb01..442bfe1f 100644
--- a/paramiko/rsakey.py
+++ b/paramiko/rsakey.py
@@ -37,8 +37,15 @@ class RSAKey(PKey):
data.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None,
- key=None, file_obj=None):
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ key=None,
+ file_obj=None,
+ ):
self.key = None
self.public_blob = None
if file_obj is not None:
@@ -54,8 +61,8 @@ class RSAKey(PKey):
else:
self._check_type_and_load_cert(
msg=msg,
- key_type='ssh-rsa',
- cert_type='ssh-rsa-cert-v01@openssh.com',
+ key_type="ssh-rsa",
+ cert_type="ssh-rsa-cert-v01@openssh.com",
)
self.key = rsa.RSAPublicNumbers(
e=msg.get_mpint(), n=msg.get_mpint()
@@ -74,7 +81,7 @@ class RSAKey(PKey):
def asbytes(self):
m = Message()
- m.add_string('ssh-rsa')
+ m.add_string("ssh-rsa")
m.add_mpint(self.public_numbers.e)
m.add_mpint(self.public_numbers.n)
return m.asbytes()
@@ -89,14 +96,15 @@ class RSAKey(PKey):
# tries stuffing it into ASCII for whatever godforsaken reason
return self.asbytes()
else:
- return self.asbytes().decode('utf8', errors='ignore')
+ return self.asbytes().decode("utf8", errors="ignore")
def __hash__(self):
- return hash((self.get_name(), self.public_numbers.e,
- self.public_numbers.n))
+ return hash(
+ (self.get_name(), self.public_numbers.e, self.public_numbers.n)
+ )
def get_name(self):
- return 'ssh-rsa'
+ return "ssh-rsa"
def get_bits(self):
return self.size
@@ -106,18 +114,16 @@ class RSAKey(PKey):
def sign_ssh_data(self, data):
sig = self.key.sign(
- data,
- padding=padding.PKCS1v15(),
- algorithm=hashes.SHA1(),
+ data, padding=padding.PKCS1v15(), algorithm=hashes.SHA1()
)
m = Message()
- m.add_string('ssh-rsa')
+ m.add_string("ssh-rsa")
m.add_string(sig)
return m
def verify_ssh_sig(self, data, msg):
- if msg.get_text() != 'ssh-rsa':
+ if msg.get_text() != "ssh-rsa":
return False
key = self.key
if isinstance(key, rsa.RSAPrivateKey):
@@ -137,7 +143,7 @@ class RSAKey(PKey):
filename,
self.key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
def write_private_key(self, file_obj, password=None):
@@ -145,7 +151,7 @@ class RSAKey(PKey):
file_obj,
self.key,
serialization.PrivateFormat.TraditionalOpenSSL,
- password=password
+ password=password,
)
@staticmethod
@@ -166,11 +172,11 @@ class RSAKey(PKey):
# ...internals...
def _from_private_key_file(self, filename, password):
- data = self._read_private_key_file('RSA', filename, password)
+ data = self._read_private_key_file("RSA", filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
- data = self._read_private_key('RSA', file_obj, password)
+ data = self._read_private_key("RSA", file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
diff --git a/paramiko/server.py b/paramiko/server.py
index a7117815..2fe9cc19 100644
--- a/paramiko/server.py
+++ b/paramiko/server.py
@@ -23,13 +23,16 @@
import threading
from paramiko import util
from paramiko.common import (
- DEBUG, ERROR, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, AUTH_FAILED,
+ DEBUG,
+ ERROR,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ AUTH_FAILED,
AUTH_SUCCESSFUL,
)
from paramiko.py3compat import string_types
-class ServerInterface (object):
+class ServerInterface(object):
"""
This class defines an interface for controlling the behavior of Paramiko
in server mode.
@@ -99,7 +102,7 @@ class ServerInterface (object):
:param str username: the username requesting authentication.
:return: a comma-separated `str` of authentication types
"""
- return 'password'
+ return "password"
def check_auth_none(self, username):
"""
@@ -233,9 +236,9 @@ class ServerInterface (object):
"""
return AUTH_FAILED
- def check_auth_gssapi_with_mic(self, username,
- gss_authenticated=AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_with_mic(
+ self, username, gss_authenticated=AUTH_FAILED, cc_file=None
+ ):
"""
Authenticate the given user to the server if he is a valid krb5
principal.
@@ -263,9 +266,9 @@ class ServerInterface (object):
return AUTH_SUCCESSFUL
return AUTH_FAILED
- def check_auth_gssapi_keyex(self, username,
- gss_authenticated=AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_keyex(
+ self, username, gss_authenticated=AUTH_FAILED, cc_file=None
+ ):
"""
Authenticate the given user to the server if he is a valid krb5
principal and GSS-API Key Exchange was performed.
@@ -372,8 +375,8 @@ class ServerInterface (object):
# ...Channel requests...
def check_channel_pty_request(
- self, channel, term, width, height, pixelwidth, pixelheight,
- modes):
+ self, channel, term, width, height, pixelwidth, pixelheight, modes
+ ):
"""
Determine if a pseudo-terminal of the given dimensions (usually
requested for shell access) can be provided on the given channel.
@@ -460,7 +463,8 @@ class ServerInterface (object):
return True
def check_channel_window_change_request(
- self, channel, width, height, pixelwidth, pixelheight):
+ self, channel, width, height, pixelwidth, pixelheight
+ ):
"""
Determine if the pseudo-terminal on the given channel can be resized.
This only makes sense if a pty was previously allocated on it.
@@ -479,8 +483,13 @@ class ServerInterface (object):
return False
def check_channel_x11_request(
- self, channel, single_connection, auth_protocol, auth_cookie,
- screen_number):
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
"""
Determine if the client will be provided with an X11 session. If this
method returns ``True``, X11 applications should be routed through new
@@ -584,12 +593,13 @@ class ServerInterface (object):
"""
return (None, None)
-class InteractiveQuery (object):
+
+class InteractiveQuery(object):
"""
A query (set of prompts) for a user during interactive authentication.
"""
- def __init__(self, name='', instructions='', *prompts):
+ def __init__(self, name="", instructions="", *prompts):
"""
Create a new interactive query to send to the client. The name and
instructions are optional, but are generally displayed to the end
@@ -623,7 +633,7 @@ class InteractiveQuery (object):
self.prompts.append((prompt, echo))
-class SubsystemHandler (threading.Thread):
+class SubsystemHandler(threading.Thread):
"""
Handler for a subsytem in server mode. If you create a subclass of this
class and pass it to `.Transport.set_subsystem_handler`, an object of this
@@ -637,6 +647,7 @@ class SubsystemHandler (threading.Thread):
``MP3Handler`` will be created, and `start_subsystem` will be called on
it from a new thread.
"""
+
def __init__(self, channel, name, server):
"""
Create a new handler for a channel. This is used by `.ServerInterface`
@@ -667,7 +678,7 @@ class SubsystemHandler (threading.Thread):
def _run(self):
try:
self.__transport._log(
- DEBUG, 'Starting handler for subsystem {}'.format(self.__name)
+ DEBUG, "Starting handler for subsystem {}".format(self.__name)
)
self.start_subsystem(self.__name, self.__transport, self.__channel)
except Exception as e:
@@ -675,7 +686,7 @@ class SubsystemHandler (threading.Thread):
ERROR,
'Exception in subsystem handler for "{}": {}'.format(
self.__name, e
- )
+ ),
)
self.__transport._log(ERROR, util.tb_strings())
try:
diff --git a/paramiko/sftp.py b/paramiko/sftp.py
index e6786d10..6aa4ce44 100644
--- a/paramiko/sftp.py
+++ b/paramiko/sftp.py
@@ -26,27 +26,54 @@ from paramiko.message import Message
from paramiko.py3compat import byte_chr, byte_ord
-CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, \
- CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, \
- CMD_REMOVE, CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, \
- CMD_READLINK, CMD_SYMLINK = range(1, 21)
-CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS = range(101, 106)
-CMD_EXTENDED, CMD_EXTENDED_REPLY = range(200, 202)
+(
+ CMD_INIT,
+ CMD_VERSION,
+ CMD_OPEN,
+ CMD_CLOSE,
+ CMD_READ,
+ CMD_WRITE,
+ CMD_LSTAT,
+ CMD_FSTAT,
+ CMD_SETSTAT,
+ CMD_FSETSTAT,
+ CMD_OPENDIR,
+ CMD_READDIR,
+ CMD_REMOVE,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_REALPATH,
+ CMD_STAT,
+ CMD_RENAME,
+ CMD_READLINK,
+ CMD_SYMLINK,
+) = range(1, 21)
+(CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS) = range(101, 106)
+(CMD_EXTENDED, CMD_EXTENDED_REPLY) = range(200, 202)
SFTP_OK = 0
-SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
- SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, \
- SFTP_OP_UNSUPPORTED = range(1, 9)
-
-SFTP_DESC = ['Success',
- 'End of file',
- 'No such file',
- 'Permission denied',
- 'Failure',
- 'Bad message',
- 'No connection',
- 'Connection lost',
- 'Operation unsupported']
+(
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_FAILURE,
+ SFTP_BAD_MESSAGE,
+ SFTP_NO_CONNECTION,
+ SFTP_CONNECTION_LOST,
+ SFTP_OP_UNSUPPORTED,
+) = range(1, 9)
+
+SFTP_DESC = [
+ "Success",
+ "End of file",
+ "No such file",
+ "Permission denied",
+ "Failure",
+ "Bad message",
+ "No connection",
+ "Connection lost",
+ "Operation unsupported",
+]
SFTP_FLAG_READ = 0x1
SFTP_FLAG_WRITE = 0x2
@@ -60,54 +87,55 @@ _VERSION = 3
# for debugging
CMD_NAMES = {
- CMD_INIT: 'init',
- CMD_VERSION: 'version',
- CMD_OPEN: 'open',
- CMD_CLOSE: 'close',
- CMD_READ: 'read',
- CMD_WRITE: 'write',
- CMD_LSTAT: 'lstat',
- CMD_FSTAT: 'fstat',
- CMD_SETSTAT: 'setstat',
- CMD_FSETSTAT: 'fsetstat',
- CMD_OPENDIR: 'opendir',
- CMD_READDIR: 'readdir',
- CMD_REMOVE: 'remove',
- CMD_MKDIR: 'mkdir',
- CMD_RMDIR: 'rmdir',
- CMD_REALPATH: 'realpath',
- CMD_STAT: 'stat',
- CMD_RENAME: 'rename',
- CMD_READLINK: 'readlink',
- CMD_SYMLINK: 'symlink',
- CMD_STATUS: 'status',
- CMD_HANDLE: 'handle',
- CMD_DATA: 'data',
- CMD_NAME: 'name',
- CMD_ATTRS: 'attrs',
- CMD_EXTENDED: 'extended',
- CMD_EXTENDED_REPLY: 'extended_reply'
+ CMD_INIT: "init",
+ CMD_VERSION: "version",
+ CMD_OPEN: "open",
+ CMD_CLOSE: "close",
+ CMD_READ: "read",
+ CMD_WRITE: "write",
+ CMD_LSTAT: "lstat",
+ CMD_FSTAT: "fstat",
+ CMD_SETSTAT: "setstat",
+ CMD_FSETSTAT: "fsetstat",
+ CMD_OPENDIR: "opendir",
+ CMD_READDIR: "readdir",
+ CMD_REMOVE: "remove",
+ CMD_MKDIR: "mkdir",
+ CMD_RMDIR: "rmdir",
+ CMD_REALPATH: "realpath",
+ CMD_STAT: "stat",
+ CMD_RENAME: "rename",
+ CMD_READLINK: "readlink",
+ CMD_SYMLINK: "symlink",
+ CMD_STATUS: "status",
+ CMD_HANDLE: "handle",
+ CMD_DATA: "data",
+ CMD_NAME: "name",
+ CMD_ATTRS: "attrs",
+ CMD_EXTENDED: "extended",
+ CMD_EXTENDED_REPLY: "extended_reply",
}
-class SFTPError (Exception):
+class SFTPError(Exception):
pass
-class BaseSFTP (object):
+class BaseSFTP(object):
+
def __init__(self):
- self.logger = util.get_logger('paramiko.sftp')
+ self.logger = util.get_logger("paramiko.sftp")
self.sock = None
self.ultra_debug = False
# ...internals...
def _send_version(self):
- self._send_packet(CMD_INIT, struct.pack('>I', _VERSION))
+ self._send_packet(CMD_INIT, struct.pack(">I", _VERSION))
t, data = self._read_packet()
if t != CMD_VERSION:
- raise SFTPError('Incompatible sftp protocol')
- version = struct.unpack('>I', data[:4])[0]
+ raise SFTPError("Incompatible sftp protocol")
+ version = struct.unpack(">I", data[:4])[0]
# if version != _VERSION:
# raise SFTPError('Incompatible sftp protocol')
return version
@@ -117,10 +145,10 @@ class BaseSFTP (object):
# client finishes sending INIT.
t, data = self._read_packet()
if t != CMD_INIT:
- raise SFTPError('Incompatible sftp protocol')
- version = struct.unpack('>I', data[:4])[0]
+ raise SFTPError("Incompatible sftp protocol")
+ version = struct.unpack(">I", data[:4])[0]
# advertise that we support "check-file"
- extension_pairs = ['check-file', 'md5,sha1']
+ extension_pairs = ["check-file", "md5,sha1"]
msg = Message()
msg.add_int(_VERSION)
msg.add(*extension_pairs)
@@ -165,9 +193,9 @@ class BaseSFTP (object):
def _send_packet(self, t, packet):
packet = asbytes(packet)
- out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet
+ out = struct.pack(">I", len(packet) + 1) + byte_chr(t) + packet
if self.ultra_debug:
- self._log(DEBUG, util.format_binary(out, 'OUT: '))
+ self._log(DEBUG, util.format_binary(out, "OUT: "))
self._write_all(out)
def _read_packet(self):
@@ -175,11 +203,11 @@ class BaseSFTP (object):
# most sftp servers won't accept packets larger than about 32k, so
# anything with the high byte set (> 16MB) is just garbage.
if byte_ord(x[0]):
- raise SFTPError('Garbage packet received')
- size = struct.unpack('>I', x)[0]
+ raise SFTPError("Garbage packet received")
+ size = struct.unpack(">I", x)[0]
data = self._read_all(size)
if self.ultra_debug:
- self._log(DEBUG, util.format_binary(data, 'IN: '))
+ self._log(DEBUG, util.format_binary(data, "IN: "))
if size > 0:
t = byte_ord(data[0])
return t, data[1:]
diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py
index ea12b2f6..f16ac746 100644
--- a/paramiko/sftp_attr.py
+++ b/paramiko/sftp_attr.py
@@ -22,7 +22,7 @@ from paramiko.common import x80000000, o700, o70, xffffffff
from paramiko.py3compat import long, b
-class SFTPAttributes (object):
+class SFTPAttributes(object):
"""
Representation of the attributes of a file (or proxied file) for SFTP in
client or server mode. It attemps to mirror the object returned by
@@ -82,7 +82,7 @@ class SFTPAttributes (object):
return attr
def __repr__(self):
- return '<SFTPAttributes: {}>'.format(self._debug_str())
+ return "<SFTPAttributes: {}>".format(self._debug_str())
# ...internals...
@classmethod
@@ -144,29 +144,29 @@ class SFTPAttributes (object):
return
def _debug_str(self):
- out = '[ '
+ out = "[ "
if self.st_size is not None:
- out += 'size={} '.format(self.st_size)
+ out += "size={} ".format(self.st_size)
if (self.st_uid is not None) and (self.st_gid is not None):
- out += 'uid={} gid={} '.format(self.st_uid, self.st_gid)
+ out += "uid={} gid={} ".format(self.st_uid, self.st_gid)
if self.st_mode is not None:
- out += 'mode=' + oct(self.st_mode) + ' '
+ out += "mode=" + oct(self.st_mode) + " "
if (self.st_atime is not None) and (self.st_mtime is not None):
- out += 'atime={} mtime={} '.format(self.st_atime, self.st_mtime)
+ out += "atime={} mtime={} ".format(self.st_atime, self.st_mtime)
for k, v in self.attr.items():
out += '"{}"={!r} '.format(str(k), v)
- out += ']'
+ out += "]"
return out
@staticmethod
def _rwx(n, suid, sticky=False):
if suid:
suid = 2
- out = '-r'[n >> 2] + '-w'[(n >> 1) & 1]
+ out = "-r"[n >> 2] + "-w"[(n >> 1) & 1]
if sticky:
- out += '-xTt'[suid + (n & 1)]
+ out += "-xTt"[suid + (n & 1)]
else:
- out += '-xSs'[suid + (n & 1)]
+ out += "-xSs"[suid + (n & 1)]
return out
def __str__(self):
@@ -174,42 +174,47 @@ class SFTPAttributes (object):
if self.st_mode is not None:
kind = stat.S_IFMT(self.st_mode)
if kind == stat.S_IFIFO:
- ks = 'p'
+ ks = "p"
elif kind == stat.S_IFCHR:
- ks = 'c'
+ ks = "c"
elif kind == stat.S_IFDIR:
- ks = 'd'
+ ks = "d"
elif kind == stat.S_IFBLK:
- ks = 'b'
+ ks = "b"
elif kind == stat.S_IFREG:
- ks = '-'
+ ks = "-"
elif kind == stat.S_IFLNK:
- ks = 'l'
+ ks = "l"
elif kind == stat.S_IFSOCK:
- ks = 's'
+ ks = "s"
else:
- ks = '?'
+ ks = "?"
ks += self._rwx(
- (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID)
+ (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID
+ )
ks += self._rwx(
- (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID)
+ (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID
+ )
ks += self._rwx(
- self.st_mode & 7, self.st_mode & stat.S_ISVTX, True)
+ self.st_mode & 7, self.st_mode & stat.S_ISVTX, True
+ )
else:
- ks = '?---------'
+ ks = "?---------"
# compute display date
if (self.st_mtime is None) or (self.st_mtime == xffffffff):
# shouldn't really happen
- datestr = '(unknown date)'
+ datestr = "(unknown date)"
else:
if abs(time.time() - self.st_mtime) > 15552000:
# (15552000 = 6 months)
datestr = time.strftime(
- '%d %b %Y', time.localtime(self.st_mtime))
+ "%d %b %Y", time.localtime(self.st_mtime)
+ )
else:
datestr = time.strftime(
- '%d %b %H:%M', time.localtime(self.st_mtime))
- filename = getattr(self, 'filename', '?')
+ "%d %b %H:%M", time.localtime(self.st_mtime)
+ )
+ filename = getattr(self, "filename", "?")
# not all servers support uid/gid
uid = self.st_uid
@@ -225,8 +230,13 @@ class SFTPAttributes (object):
# TODO: not sure this actually worked as expected beforehand, leaving
# it untouched for the time being, re: .format() upgrade, until someone
# has time to doublecheck
- return '%s 1 %-8d %-8d %8d %-12s %s' % (
- ks, uid, gid, size, datestr, filename,
+ return "%s 1 %-8d %-8d %8d %-12s %s" % (
+ ks,
+ uid,
+ gid,
+ size,
+ datestr,
+ filename,
)
def asbytes(self):
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 31dc234c..de3f9f58 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -30,12 +30,37 @@ from paramiko.message import Message
from paramiko.common import INFO, DEBUG, o777
from paramiko.py3compat import b, u, long
from paramiko.sftp import (
- BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READDIR, CMD_NAME,
- CMD_CLOSE, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_CREATE,
- SFTP_FLAG_TRUNC, SFTP_FLAG_APPEND, SFTP_FLAG_EXCL, CMD_OPEN, CMD_REMOVE,
- CMD_RENAME, CMD_MKDIR, CMD_RMDIR, CMD_STAT, CMD_ATTRS, CMD_LSTAT,
- CMD_SYMLINK, CMD_SETSTAT, CMD_READLINK, CMD_REALPATH, CMD_STATUS,
- CMD_EXTENDED, SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED,
+ BaseSFTP,
+ CMD_OPENDIR,
+ CMD_HANDLE,
+ SFTPError,
+ CMD_READDIR,
+ CMD_NAME,
+ CMD_CLOSE,
+ SFTP_FLAG_READ,
+ SFTP_FLAG_WRITE,
+ SFTP_FLAG_CREATE,
+ SFTP_FLAG_TRUNC,
+ SFTP_FLAG_APPEND,
+ SFTP_FLAG_EXCL,
+ CMD_OPEN,
+ CMD_REMOVE,
+ CMD_RENAME,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_STAT,
+ CMD_ATTRS,
+ CMD_LSTAT,
+ CMD_SYMLINK,
+ CMD_SETSTAT,
+ CMD_READLINK,
+ CMD_REALPATH,
+ CMD_STATUS,
+ CMD_EXTENDED,
+ SFTP_OK,
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
)
from paramiko.sftp_attr import SFTPAttributes
@@ -51,15 +76,15 @@ def _to_unicode(s):
probably doesn't know the filename's encoding.
"""
try:
- return s.encode('ascii')
+ return s.encode("ascii")
except (UnicodeError, AttributeError):
try:
- return s.decode('utf-8')
+ return s.decode("utf-8")
except UnicodeError:
return s
-b_slash = b'/'
+b_slash = b"/"
class SFTPClient(BaseSFTP, ClosingContextManager):
@@ -71,6 +96,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
def __init__(self, sock):
"""
Create an SFTP client from an existing `.Channel`. The channel
@@ -97,15 +123,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
# override default logger
transport = self.sock.get_transport()
self.logger = util.get_logger(
- transport.get_log_channel() + '.sftp')
+ transport.get_log_channel() + ".sftp"
+ )
self.ultra_debug = transport.get_hexdump()
try:
server_version = self._send_version()
except EOFError:
- raise SSHException('EOF during negotiation')
+ raise SSHException("EOF during negotiation")
self._log(
INFO,
- 'Opened sftp connection (server version {})'.format(server_version)
+ "Opened sftp connection (server version {})".format(
+ server_version
+ ),
)
@classmethod
@@ -132,11 +161,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- chan = t.open_session(window_size=window_size,
- max_packet_size=max_packet_size)
+ chan = t.open_session(
+ window_size=window_size, max_packet_size=max_packet_size
+ )
if chan is None:
return None
- chan.invoke_subsystem('sftp')
+ chan.invoke_subsystem("sftp")
return cls(chan)
def _log(self, level, msg, *args):
@@ -148,10 +178,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
# logging.Logger.log() explicitly requires it. Grump.
# escape '%' in msg (they could come from file or directory names)
# before logging
- msg = msg.replace('%', '%%')
+ msg = msg.replace("%", "%%")
super(SFTPClient, self)._log(
level,
- "[chan %s] " + msg, *([self.sock.get_name()] + list(args)))
+ "[chan %s] " + msg,
+ *([self.sock.get_name()] + list(args))
+ )
def close(self):
"""
@@ -159,7 +191,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.4
"""
- self._log(INFO, 'sftp session closed.')
+ self._log(INFO, "sftp session closed.")
self.sock.close()
def get_channel(self):
@@ -171,7 +203,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
return self.sock
- def listdir(self, path='.'):
+ def listdir(self, path="."):
"""
Return a list containing the names of the entries in the given
``path``.
@@ -185,7 +217,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
return [f.filename for f in self.listdir_attr(path)]
- def listdir_attr(self, path='.'):
+ def listdir_attr(self, path="."):
"""
Return a list containing `.SFTPAttributes` objects corresponding to
files in the given ``path``. The list is in arbitrary order. It does
@@ -203,10 +235,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.2
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'listdir({!r})'.format(path))
+ self._log(DEBUG, "listdir({!r})".format(path))
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
- raise SFTPError('Expected handle')
+ raise SFTPError("Expected handle")
handle = msg.get_binary()
filelist = []
while True:
@@ -216,18 +248,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
# done with handle
break
if t != CMD_NAME:
- raise SFTPError('Expected name response')
+ raise SFTPError("Expected name response")
count = msg.get_int()
for i in range(count):
filename = msg.get_text()
longname = msg.get_text()
attr = SFTPAttributes._from_msg(msg, filename, longname)
- if (filename != '.') and (filename != '..'):
+ if (filename != ".") and (filename != ".."):
filelist.append(attr)
self._request(CMD_CLOSE, handle)
return filelist
- def listdir_iter(self, path='.', read_aheads=50):
+ def listdir_iter(self, path=".", read_aheads=50):
"""
Generator version of `.listdir_attr`.
@@ -242,11 +274,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.15
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'listdir({!r})'.format(path))
+ self._log(DEBUG, "listdir({!r})".format(path))
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
- raise SFTPError('Expected handle')
+ raise SFTPError("Expected handle")
handle = msg.get_string()
@@ -261,7 +293,6 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
num = self._async_request(type(None), CMD_READDIR, handle)
nums.append(num)
-
# For each of our sent requests
# Read and parse the corresponding packets
# If we're at the end of our queued requests, then fire off
@@ -280,8 +311,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
filename = msg.get_text()
longname = msg.get_text()
attr = SFTPAttributes._from_msg(
- msg, filename, longname)
- if (filename != '.') and (filename != '..'):
+ msg, filename, longname
+ )
+ if (filename != ".") and (filename != ".."):
yield attr
# If we've hit the end of our queued requests, reset nums.
@@ -291,8 +323,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
self._request(CMD_CLOSE, handle)
return
-
- def open(self, filename, mode='r', bufsize=-1):
+ def open(self, filename, mode="r", bufsize=-1):
"""
Open a file on the remote server. The arguments are the same as for
Python's built-in `python:file` (aka `python:open`). A file-like
@@ -325,26 +356,28 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the file could not be opened.
"""
filename = self._adjust_cwd(filename)
- self._log(DEBUG, 'open({!r}, {!r})'.format(filename, mode))
+ self._log(DEBUG, "open({!r}, {!r})".format(filename, mode))
imode = 0
- if ('r' in mode) or ('+' in mode):
+ if ("r" in mode) or ("+" in mode):
imode |= SFTP_FLAG_READ
- if ('w' in mode) or ('+' in mode) or ('a' in mode):
+ if ("w" in mode) or ("+" in mode) or ("a" in mode):
imode |= SFTP_FLAG_WRITE
- if 'w' in mode:
+ if "w" in mode:
imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC
- if 'a' in mode:
+ if "a" in mode:
imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND
- if 'x' in mode:
+ if "x" in mode:
imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL
attrblock = SFTPAttributes()
t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
if t != CMD_HANDLE:
- raise SFTPError('Expected handle')
+ raise SFTPError("Expected handle")
handle = msg.get_binary()
self._log(
DEBUG,
- 'open({!r}, {!r}) -> {}'.format(filename, mode, u(hexlify(handle)))
+ "open({!r}, {!r}) -> {}".format(
+ filename, mode, u(hexlify(handle))
+ ),
)
return SFTPFile(self, handle, mode, bufsize)
@@ -361,7 +394,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the path refers to a folder (directory)
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'remove({!r})'.format(path))
+ self._log(DEBUG, "remove({!r})".format(path))
self._request(CMD_REMOVE, path)
unlink = remove
@@ -386,7 +419,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
oldpath = self._adjust_cwd(oldpath)
newpath = self._adjust_cwd(newpath)
- self._log(DEBUG, 'rename({!r}, {!r})'.format(oldpath, newpath))
+ self._log(DEBUG, "rename({!r}, {!r})".format(oldpath, newpath))
self._request(CMD_RENAME, oldpath, newpath)
def posix_rename(self, oldpath, newpath):
@@ -406,7 +439,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
"""
oldpath = self._adjust_cwd(oldpath)
newpath = self._adjust_cwd(newpath)
- self._log(DEBUG, 'posix_rename({!r}, {!r})'.format(oldpath, newpath))
+ self._log(DEBUG, "posix_rename({!r}, {!r})".format(oldpath, newpath))
self._request(
CMD_EXTENDED, "posix-rename@openssh.com", oldpath, newpath
)
@@ -421,7 +454,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int mode: permissions (posix-style) for the newly-created folder
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'mkdir({!r}, {!r})'.format(path, mode))
+ self._log(DEBUG, "mkdir({!r}, {!r})".format(path, mode))
attr = SFTPAttributes()
attr.st_mode = mode
self._request(CMD_MKDIR, path, attr)
@@ -433,7 +466,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: name of the folder to remove
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'rmdir({!r})'.format(path))
+ self._log(DEBUG, "rmdir({!r})".format(path))
self._request(CMD_RMDIR, path)
def stat(self, path):
@@ -456,10 +489,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
file
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'stat({!r})'.format(path))
+ self._log(DEBUG, "stat({!r})".format(path))
t, msg = self._request(CMD_STAT, path)
if t != CMD_ATTRS:
- raise SFTPError('Expected attributes')
+ raise SFTPError("Expected attributes")
return SFTPAttributes._from_msg(msg)
def lstat(self, path):
@@ -474,10 +507,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
file
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'lstat({!r})'.format(path))
+ self._log(DEBUG, "lstat({!r})".format(path))
t, msg = self._request(CMD_LSTAT, path)
if t != CMD_ATTRS:
- raise SFTPError('Expected attributes')
+ raise SFTPError("Expected attributes")
return SFTPAttributes._from_msg(msg)
def symlink(self, source, dest):
@@ -488,7 +521,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str dest: path of the newly created symlink
"""
dest = self._adjust_cwd(dest)
- self._log(DEBUG, 'symlink({!r}, {!r})'.format(source, dest))
+ self._log(DEBUG, "symlink({!r}, {!r})".format(source, dest))
source = b(source)
self._request(CMD_SYMLINK, source, dest)
@@ -502,7 +535,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int mode: new permissions
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'chmod({!r}, {!r})'.format(path, mode))
+ self._log(DEBUG, "chmod({!r}, {!r})".format(path, mode))
attr = SFTPAttributes()
attr.st_mode = mode
self._request(CMD_SETSTAT, path, attr)
@@ -519,7 +552,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int gid: new group id
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'chown({!r}, {!r}, {!r})'.format(path, uid, gid))
+ self._log(DEBUG, "chown({!r}, {!r}, {!r})".format(path, uid, gid))
attr = SFTPAttributes()
attr.st_uid, attr.st_gid = uid, gid
self._request(CMD_SETSTAT, path, attr)
@@ -541,7 +574,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
path = self._adjust_cwd(path)
if times is None:
times = (time.time(), time.time())
- self._log(DEBUG, 'utime({!r}, {!r})'.format(path, times))
+ self._log(DEBUG, "utime({!r}, {!r})".format(path, times))
attr = SFTPAttributes()
attr.st_atime, attr.st_mtime = times
self._request(CMD_SETSTAT, path, attr)
@@ -556,7 +589,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int size: the new size of the file
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'truncate({!r}, {!r})'.format(path, size))
+ self._log(DEBUG, "truncate({!r}, {!r})".format(path, size))
attr = SFTPAttributes()
attr.st_size = size
self._request(CMD_SETSTAT, path, attr)
@@ -571,15 +604,15 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:return: target path, as a `str`
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'readlink({!r})'.format(path))
+ self._log(DEBUG, "readlink({!r})".format(path))
t, msg = self._request(CMD_READLINK, path)
if t != CMD_NAME:
- raise SFTPError('Expected name response')
+ raise SFTPError("Expected name response")
count = msg.get_int()
if count == 0:
return None
if count != 1:
- raise SFTPError('Readlink returned {} results'.format(count))
+ raise SFTPError("Readlink returned {} results".format(count))
return _to_unicode(msg.get_string())
def normalize(self, path):
@@ -595,13 +628,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the path can't be resolved on the server
"""
path = self._adjust_cwd(path)
- self._log(DEBUG, 'normalize({!r})'.format(path))
+ self._log(DEBUG, "normalize({!r})".format(path))
t, msg = self._request(CMD_REALPATH, path)
if t != CMD_NAME:
- raise SFTPError('Expected name response')
+ raise SFTPError("Expected name response")
count = msg.get_int()
if count != 1:
- raise SFTPError('Realpath returned {} results'.format(count))
+ raise SFTPError("Realpath returned {} results".format(count))
return msg.get_text()
def chdir(self, path=None):
@@ -625,9 +658,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
return
if not stat.S_ISDIR(self.stat(path).st_mode):
code = errno.ENOTDIR
- raise SFTPError(
- code, "{}: {}".format(os.strerror(code), path)
- )
+ raise SFTPError(code, "{}: {}".format(os.strerror(code), path))
self._cwd = b(self.normalize(path))
def getcwd(self):
@@ -680,7 +711,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.10
"""
- with self.file(remotepath, 'wb') as fr:
+ with self.file(remotepath, "wb") as fr:
fr.set_pipelined(True)
size = self._transfer_with_callback(
reader=fl, writer=fr, file_size=file_size, callback=callback
@@ -689,7 +720,8 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
s = self.stat(remotepath)
if s.st_size != size:
raise IOError(
- 'size mismatch in put! {} != {}'.format(s.st_size, size))
+ "size mismatch in put! {} != {}".format(s.st_size, size)
+ )
else:
s = SFTPAttributes()
return s
@@ -723,7 +755,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
``confirm`` param added.
"""
file_size = os.stat(localpath).st_size
- with open(localpath, 'rb') as fl:
+ with open(localpath, "rb") as fl:
return self.putfo(fl, remotepath, file_size, callback, confirm)
def getfo(self, remotepath, fl, callback=None):
@@ -744,7 +776,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.10
"""
file_size = self.stat(remotepath).st_size
- with self.open(remotepath, 'rb') as fr:
+ with self.open(remotepath, "rb") as fr:
fr.prefetch(file_size)
return self._transfer_with_callback(
reader=fr, writer=fl, file_size=file_size, callback=callback
@@ -766,12 +798,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 1.7.4
Added the ``callback`` param
"""
- with open(localpath, 'wb') as fl:
+ with open(localpath, "wb") as fl:
size = self.getfo(remotepath, fl, callback)
s = os.stat(localpath)
if s.st_size != size:
raise IOError(
- 'size mismatch in get! {} != {}'.format(s.st_size, size))
+ "size mismatch in get! {} != {}".format(s.st_size, size)
+ )
# ...internals...
@@ -809,7 +842,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
try:
t, data = self._read_packet()
except EOFError as e:
- raise SSHException('Server connection dropped: {}'.format(e))
+ raise SSHException("Server connection dropped: {}".format(e))
msg = Message(data)
num = msg.get_int()
self._lock.acquire()
@@ -817,7 +850,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
if num not in self._expecting:
# might be response for a file that was closed before
# responses came back
- self._log(DEBUG, 'Unexpected response #{}'.format(num))
+ self._log(DEBUG, "Unexpected response #{}".format(num))
if waitfor is None:
# just doing a single check
break
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index 52f2bde8..0104d857 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -32,13 +32,21 @@ from paramiko.common import DEBUG
from paramiko.file import BufferedFile
from paramiko.py3compat import u, long
from paramiko.sftp import (
- CMD_CLOSE, CMD_READ, CMD_DATA, SFTPError, CMD_WRITE, CMD_STATUS, CMD_FSTAT,
- CMD_ATTRS, CMD_FSETSTAT, CMD_EXTENDED,
+ CMD_CLOSE,
+ CMD_READ,
+ CMD_DATA,
+ SFTPError,
+ CMD_WRITE,
+ CMD_STATUS,
+ CMD_FSTAT,
+ CMD_ATTRS,
+ CMD_FSETSTAT,
+ CMD_EXTENDED,
)
from paramiko.sftp_attr import SFTPAttributes
-class SFTPFile (BufferedFile):
+class SFTPFile(BufferedFile):
"""
Proxy object for a file on the remote server, in client mode SFTP.
@@ -50,7 +58,7 @@ class SFTPFile (BufferedFile):
# this size.
MAX_REQUEST_SIZE = 32768
- def __init__(self, sftp, handle, mode='r', bufsize=-1):
+ def __init__(self, sftp, handle, mode="r", bufsize=-1):
BufferedFile.__init__(self)
self.sftp = sftp
self.handle = handle
@@ -83,7 +91,7 @@ class SFTPFile (BufferedFile):
# __del__.)
if self._closed:
return
- self.sftp._log(DEBUG, 'close({})'.format(u(hexlify(self.handle))))
+ self.sftp._log(DEBUG, "close({})".format(u(hexlify(self.handle))))
if self.pipelined:
self.sftp._finish_responses(self)
BufferedFile.close(self)
@@ -102,8 +110,9 @@ class SFTPFile (BufferedFile):
pass
def _data_in_prefetch_requests(self, offset, size):
- k = [x for x in list(self._prefetch_extents.values())
- if x[0] <= offset]
+ k = [
+ x for x in list(self._prefetch_extents.values()) if x[0] <= offset
+ ]
if len(k) == 0:
return False
k.sort(key=lambda x: x[0])
@@ -117,8 +126,8 @@ class SFTPFile (BufferedFile):
# well, we have part of the request. see if another chunk has
# the rest.
return self._data_in_prefetch_requests(
- buf_offset + buf_size,
- offset + size - buf_offset - buf_size)
+ buf_offset + buf_size, offset + size - buf_offset - buf_size
+ )
def _data_in_prefetch_buffers(self, offset):
"""
@@ -174,13 +183,10 @@ class SFTPFile (BufferedFile):
if data is not None:
return data
t, msg = self.sftp._request(
- CMD_READ,
- self.handle,
- long(self._realpos),
- int(size)
+ CMD_READ, self.handle, long(self._realpos), int(size)
)
if t != CMD_DATA:
- raise SFTPError('Expected data')
+ raise SFTPError("Expected data")
return msg.get_string()
def _write(self, data):
@@ -191,18 +197,17 @@ class SFTPFile (BufferedFile):
CMD_WRITE,
self.handle,
long(self._realpos),
- data[:chunk]
+ data[:chunk],
)
self._reqs.append(sftp_async_request)
- if (
- not self.pipelined or
- (len(self._reqs) > 100 and self.sftp.sock.recv_ready())
+ if not self.pipelined or (
+ len(self._reqs) > 100 and self.sftp.sock.recv_ready()
):
while len(self._reqs):
req = self._reqs.popleft()
t, msg = self.sftp._read_response(req)
if t != CMD_STATUS:
- raise SFTPError('Expected status')
+ raise SFTPError("Expected status")
# convert_status already called
return chunk
@@ -277,7 +282,7 @@ class SFTPFile (BufferedFile):
"""
t, msg = self.sftp._request(CMD_FSTAT, self.handle)
if t != CMD_ATTRS:
- raise SFTPError('Expected attributes')
+ raise SFTPError("Expected attributes")
return SFTPAttributes._from_msg(msg)
def chmod(self, mode):
@@ -288,8 +293,9 @@ class SFTPFile (BufferedFile):
:param int mode: new permissions
"""
- self.sftp._log(DEBUG, 'chmod({}, {!r})'.format(
- hexlify(self.handle), mode))
+ self.sftp._log(
+ DEBUG, "chmod({}, {!r})".format(hexlify(self.handle), mode)
+ )
attr = SFTPAttributes()
attr.st_mode = mode
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -306,7 +312,8 @@ class SFTPFile (BufferedFile):
"""
self.sftp._log(
DEBUG,
- 'chown({}, {!r}, {!r})'.format(hexlify(self.handle), uid, gid))
+ "chown({}, {!r}, {!r})".format(hexlify(self.handle), uid, gid),
+ )
attr = SFTPAttributes()
attr.st_uid, attr.st_gid = uid, gid
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -326,8 +333,9 @@ class SFTPFile (BufferedFile):
"""
if times is None:
times = (time.time(), time.time())
- self.sftp._log(DEBUG, 'utime({}, {!r})'.format(
- hexlify(self.handle), times))
+ self.sftp._log(
+ DEBUG, "utime({}, {!r})".format(hexlify(self.handle), times)
+ )
attr = SFTPAttributes()
attr.st_atime, attr.st_mtime = times
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -341,8 +349,8 @@ class SFTPFile (BufferedFile):
:param size: the new size of the file
"""
self.sftp._log(
- DEBUG,
- 'truncate({}, {!r})'.format(hexlify(self.handle), size))
+ DEBUG, "truncate({}, {!r})".format(hexlify(self.handle), size)
+ )
attr = SFTPAttributes()
attr.st_size = size
self.sftp._request(CMD_FSETSTAT, self.handle, attr)
@@ -394,8 +402,14 @@ class SFTPFile (BufferedFile):
.. versionadded:: 1.4
"""
t, msg = self.sftp._request(
- CMD_EXTENDED, 'check-file', self.handle,
- hash_algorithm, long(offset), long(length), block_size)
+ CMD_EXTENDED,
+ "check-file",
+ self.handle,
+ hash_algorithm,
+ long(offset),
+ long(length),
+ block_size,
+ )
msg.get_text() # ext
msg.get_text() # alg
data = msg.get_remainder()
@@ -475,16 +489,16 @@ class SFTPFile (BufferedFile):
.. versionadded:: 1.5.4
"""
- self.sftp._log(DEBUG, 'readv({}, {!r})'.format(
- hexlify(self.handle), chunks))
+ self.sftp._log(
+ DEBUG, "readv({}, {!r})".format(hexlify(self.handle), chunks)
+ )
read_chunks = []
for offset, size in chunks:
# don't fetch data that's already in the prefetch buffer
- if (
- self._data_in_prefetch_buffers(offset) or
- self._data_in_prefetch_requests(offset, size)
- ):
+ if self._data_in_prefetch_buffers(
+ offset
+ ) or self._data_in_prefetch_requests(offset, size):
continue
# break up anything larger than the max read size
@@ -521,11 +535,8 @@ class SFTPFile (BufferedFile):
# a lot of them, so it may block.
for offset, length in chunks:
num = self.sftp._async_request(
- self,
- CMD_READ,
- self.handle,
- long(offset),
- int(length))
+ self, CMD_READ, self.handle, long(offset), int(length)
+ )
with self._prefetch_lock:
self._prefetch_extents[num] = (offset, length)
@@ -538,7 +549,7 @@ class SFTPFile (BufferedFile):
self._saved_exception = e
return
if t != CMD_DATA:
- raise SFTPError('Expected data')
+ raise SFTPError("Expected data")
data = msg.get_string()
while True:
with self._prefetch_lock:
diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py
index ca473900..a7e22f01 100644
--- a/paramiko/sftp_handle.py
+++ b/paramiko/sftp_handle.py
@@ -25,7 +25,7 @@ from paramiko.sftp import SFTP_OP_UNSUPPORTED, SFTP_OK
from paramiko.util import ClosingContextManager
-class SFTPHandle (ClosingContextManager):
+class SFTPHandle(ClosingContextManager):
"""
Abstract object representing a handle to an open file (or folder) in an
SFTP server implementation. Each handle has a string representation used
@@ -36,6 +36,7 @@ class SFTPHandle (ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
def __init__(self, flags=0):
"""
Create a new file handle representing a local file being served over
@@ -63,10 +64,10 @@ class SFTPHandle (ClosingContextManager):
using the default implementations of `read` and `write`, this
method's default implementation should be fine also.
"""
- readfile = getattr(self, 'readfile', None)
+ readfile = getattr(self, "readfile", None)
if readfile is not None:
readfile.close()
- writefile = getattr(self, 'writefile', None)
+ writefile = getattr(self, "writefile", None)
if writefile is not None:
writefile.close()
@@ -88,7 +89,7 @@ class SFTPHandle (ClosingContextManager):
:param int length: number of bytes to attempt to read.
:return: data read from the file, or an SFTP error code, as a `str`.
"""
- readfile = getattr(self, 'readfile', None)
+ readfile = getattr(self, "readfile", None)
if readfile is None:
return SFTP_OP_UNSUPPORTED
try:
@@ -122,7 +123,7 @@ class SFTPHandle (ClosingContextManager):
:param str data: data to write into the file.
:return: an SFTP error code like ``SFTP_OK``.
"""
- writefile = getattr(self, 'writefile', None)
+ writefile = getattr(self, "writefile", None)
if writefile is None:
return SFTP_OP_UNSUPPORTED
try:
diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py
index f8c4f727..8265df96 100644
--- a/paramiko/sftp_server.py
+++ b/paramiko/sftp_server.py
@@ -27,7 +27,11 @@ from hashlib import md5, sha1
from paramiko import util
from paramiko.sftp import (
- BaseSFTP, Message, SFTP_FAILURE, SFTP_PERMISSION_DENIED, SFTP_NO_SUCH_FILE,
+ BaseSFTP,
+ Message,
+ SFTP_FAILURE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_NO_SUCH_FILE,
)
from paramiko.sftp_si import SFTPServerInterface
from paramiko.sftp_attr import SFTPAttributes
@@ -38,30 +42,64 @@ from paramiko.server import SubsystemHandler
# known hash algorithms for the "check-file" extension
from paramiko.sftp import (
- CMD_HANDLE, SFTP_DESC, CMD_STATUS, SFTP_EOF, CMD_NAME, SFTP_BAD_MESSAGE,
- CMD_EXTENDED_REPLY, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_APPEND,
- SFTP_FLAG_CREATE, SFTP_FLAG_TRUNC, SFTP_FLAG_EXCL, CMD_NAMES, CMD_OPEN,
- CMD_CLOSE, SFTP_OK, CMD_READ, CMD_DATA, CMD_WRITE, CMD_REMOVE, CMD_RENAME,
- CMD_MKDIR, CMD_RMDIR, CMD_OPENDIR, CMD_READDIR, CMD_STAT, CMD_ATTRS,
- CMD_LSTAT, CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_READLINK, CMD_SYMLINK,
- CMD_REALPATH, CMD_EXTENDED, SFTP_OP_UNSUPPORTED,
+ CMD_HANDLE,
+ SFTP_DESC,
+ CMD_STATUS,
+ SFTP_EOF,
+ CMD_NAME,
+ SFTP_BAD_MESSAGE,
+ CMD_EXTENDED_REPLY,
+ SFTP_FLAG_READ,
+ SFTP_FLAG_WRITE,
+ SFTP_FLAG_APPEND,
+ SFTP_FLAG_CREATE,
+ SFTP_FLAG_TRUNC,
+ SFTP_FLAG_EXCL,
+ CMD_NAMES,
+ CMD_OPEN,
+ CMD_CLOSE,
+ SFTP_OK,
+ CMD_READ,
+ CMD_DATA,
+ CMD_WRITE,
+ CMD_REMOVE,
+ CMD_RENAME,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_OPENDIR,
+ CMD_READDIR,
+ CMD_STAT,
+ CMD_ATTRS,
+ CMD_LSTAT,
+ CMD_FSTAT,
+ CMD_SETSTAT,
+ CMD_FSETSTAT,
+ CMD_READLINK,
+ CMD_SYMLINK,
+ CMD_REALPATH,
+ CMD_EXTENDED,
+ SFTP_OP_UNSUPPORTED,
)
-_hash_class = {
- 'sha1': sha1,
- 'md5': md5,
-}
+_hash_class = {"sha1": sha1, "md5": md5}
-class SFTPServer (BaseSFTP, SubsystemHandler):
+class SFTPServer(BaseSFTP, SubsystemHandler):
"""
Server-side SFTP subsystem support. Since this is a `.SubsystemHandler`,
it can be (and is meant to be) set as the handler for ``"sftp"`` requests.
Use `.Transport.set_subsystem_handler` to activate this class.
"""
- def __init__(self, channel, name, server, sftp_si=SFTPServerInterface,
- *largs, **kwargs):
+ def __init__(
+ self,
+ channel,
+ name,
+ server,
+ sftp_si=SFTPServerInterface,
+ *largs,
+ **kwargs
+ ):
"""
The constructor for SFTPServer is meant to be called from within the
`.Transport` as a subsystem handler. ``server`` and any additional
@@ -79,7 +117,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
BaseSFTP.__init__(self)
SubsystemHandler.__init__(self, channel, name, server)
transport = channel.get_transport()
- self.logger = util.get_logger(transport.get_log_channel() + '.sftp')
+ self.logger = util.get_logger(transport.get_log_channel() + ".sftp")
self.ultra_debug = transport.get_hexdump()
self.next_handle = 1
# map of handle-string to SFTPHandle for files & folders:
@@ -91,26 +129,26 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if issubclass(type(msg), list):
for m in msg:
super(SFTPServer, self)._log(
- level,
- "[chan " + self.sock.get_name() + "] " + m)
+ level, "[chan " + self.sock.get_name() + "] " + m
+ )
else:
super(SFTPServer, self)._log(
- level,
- "[chan " + self.sock.get_name() + "] " + msg)
+ level, "[chan " + self.sock.get_name() + "] " + msg
+ )
def start_subsystem(self, name, transport, channel):
self.sock = channel
- self._log(DEBUG, 'Started sftp server on channel {!r}'.format(channel))
+ self._log(DEBUG, "Started sftp server on channel {!r}".format(channel))
self._send_server_version()
self.server.session_started()
while True:
try:
t, data = self._read_packet()
except EOFError:
- self._log(DEBUG, 'EOF -- end of session')
+ self._log(DEBUG, "EOF -- end of session")
return
except Exception as e:
- self._log(DEBUG, 'Exception on channel: ' + str(e))
+ self._log(DEBUG, "Exception on channel: " + str(e))
self._log(DEBUG, util.tb_strings())
return
msg = Message(data)
@@ -118,7 +156,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
try:
self._process(t, request_number, msg)
except Exception as e:
- self._log(DEBUG, 'Exception in server processing: ' + str(e))
+ self._log(DEBUG, "Exception in server processing: " + str(e))
self._log(DEBUG, util.tb_strings())
# send some kind of failure message, at least
try:
@@ -172,7 +210,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
name of the file to alter (should usually be an absolute path).
:param .SFTPAttributes attr: attributes to change.
"""
- if sys.platform != 'win32':
+ if sys.platform != "win32":
# mode operations are meaningless on win32
if attr._flags & attr.FLAG_PERMISSIONS:
os.chmod(filename, attr.st_mode)
@@ -181,7 +219,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if attr._flags & attr.FLAG_AMTIME:
os.utime(filename, (attr.st_atime, attr.st_mtime))
if attr._flags & attr.FLAG_SIZE:
- with open(filename, 'w+') as f:
+ with open(filename, "w+") as f:
f.truncate(attr.st_size)
# ...internals...
@@ -200,8 +238,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
item._pack(msg)
else:
raise Exception(
- 'unknown type for {!r} type {!r}'.format(
- item, type(item)))
+ "unknown type for {!r} type {!r}".format(item, type(item))
+ )
self._send_packet(t, msg)
def _send_handle_response(self, request_number, handle, folder=False):
@@ -209,7 +247,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
# must be error code
self._send_status(request_number, handle)
return
- handle._set_name(b('hx{:d}'.format(self.next_handle)))
+ handle._set_name(b("hx{:d}".format(self.next_handle)))
self.next_handle += 1
if folder:
self.folder_table[handle._get_name()] = handle
@@ -222,10 +260,10 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
try:
desc = SFTP_DESC[code]
except IndexError:
- desc = 'Unknown'
+ desc = "Unknown"
# some clients expect a "langauge" tag at the end
# (but don't mind it being blank)
- self._response(request_number, CMD_STATUS, code, desc, '')
+ self._response(request_number, CMD_STATUS, code, desc, "")
def _open_folder(self, request_number, path):
resp = self.server.list_folder(path)
@@ -264,7 +302,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
block_size = msg.get_int()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
f = self.file_table[handle]
for x in alg_list:
@@ -274,19 +313,21 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
break
else:
self._send_status(
- request_number, SFTP_FAILURE, 'No supported hash types found')
+ request_number, SFTP_FAILURE, "No supported hash types found"
+ )
return
if length == 0:
st = f.stat()
if not issubclass(type(st), SFTPAttributes):
- self._send_status(request_number, st, 'Unable to stat file')
+ self._send_status(request_number, st, "Unable to stat file")
return
length = st.st_size - start
if block_size == 0:
block_size = length
if block_size < 256:
self._send_status(
- request_number, SFTP_FAILURE, 'Block size too small')
+ request_number, SFTP_FAILURE, "Block size too small"
+ )
return
sum_out = bytes()
@@ -301,7 +342,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
data = f.read(offset, chunklen)
if not isinstance(data, bytes_types):
self._send_status(
- request_number, data, 'Unable to hash file')
+ request_number, data, "Unable to hash file"
+ )
return
hash_obj.update(data)
count += len(data)
@@ -310,7 +352,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg = Message()
msg.add_int(request_number)
- msg.add_string('check-file')
+ msg.add_string("check-file")
msg.add_string(algname)
msg.add_bytes(sum_out)
self._send_packet(CMD_EXTENDED_REPLY, msg)
@@ -334,13 +376,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return flags
def _process(self, t, request_number, msg):
- self._log(DEBUG, 'Request: {}'.format(CMD_NAMES[t]))
+ self._log(DEBUG, "Request: {}".format(CMD_NAMES[t]))
if t == CMD_OPEN:
path = msg.get_text()
flags = self._convert_pflags(msg.get_int())
attr = SFTPAttributes._from_msg(msg)
self._send_handle_response(
- request_number, self.server.open(path, flags, attr))
+ request_number, self.server.open(path, flags, attr)
+ )
elif t == CMD_CLOSE:
handle = msg.get_binary()
if handle in self.folder_table:
@@ -353,14 +396,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_OK)
return
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
elif t == CMD_READ:
handle = msg.get_binary()
offset = msg.get_int64()
length = msg.get_int()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
data = self.file_table[handle].read(offset, length)
if isinstance(data, (bytes_types, string_types)):
@@ -376,10 +421,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
data = msg.get_binary()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
self._send_status(
- request_number, self.file_table[handle].write(offset, data))
+ request_number, self.file_table[handle].write(offset, data)
+ )
elif t == CMD_REMOVE:
path = msg.get_text()
self._send_status(request_number, self.server.remove(path))
@@ -387,7 +434,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
oldpath = msg.get_text()
newpath = msg.get_text()
self._send_status(
- request_number, self.server.rename(oldpath, newpath))
+ request_number, self.server.rename(oldpath, newpath)
+ )
elif t == CMD_MKDIR:
path = msg.get_text()
attr = SFTPAttributes._from_msg(msg)
@@ -403,7 +451,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
handle = msg.get_binary()
if handle not in self.folder_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
folder = self.folder_table[handle]
self._read_folder(request_number, folder)
@@ -425,7 +474,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
handle = msg.get_binary()
if handle not in self.file_table:
self._send_status(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
resp = self.file_table[handle].stat()
if issubclass(type(resp), SFTPAttributes):
@@ -441,16 +491,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
attr = SFTPAttributes._from_msg(msg)
if handle not in self.file_table:
self._response(
- request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
return
self._send_status(
- request_number, self.file_table[handle].chattr(attr))
+ request_number, self.file_table[handle].chattr(attr)
+ )
elif t == CMD_READLINK:
path = msg.get_text()
resp = self.server.readlink(path)
if isinstance(resp, (bytes_types, string_types)):
self._response(
- request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
+ request_number, CMD_NAME, 1, resp, "", SFTPAttributes()
+ )
else:
self._send_status(request_number, resp)
elif t == CMD_SYMLINK:
@@ -459,17 +512,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
target_path = msg.get_text()
path = msg.get_text()
self._send_status(
- request_number, self.server.symlink(target_path, path))
+ request_number, self.server.symlink(target_path, path)
+ )
elif t == CMD_REALPATH:
path = msg.get_text()
rpath = self.server.canonicalize(path)
self._response(
- request_number, CMD_NAME, 1, rpath, '', SFTPAttributes())
+ request_number, CMD_NAME, 1, rpath, "", SFTPAttributes()
+ )
elif t == CMD_EXTENDED:
tag = msg.get_text()
- if tag == 'check-file':
+ if tag == "check-file":
self._check_file(request_number, msg)
- elif tag == 'posix-rename@openssh.com':
+ elif tag == "posix-rename@openssh.com":
oldpath = msg.get_text()
newpath = msg.get_text()
self._send_status(
diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py
index b43b582c..40dc561c 100644
--- a/paramiko/sftp_si.py
+++ b/paramiko/sftp_si.py
@@ -25,7 +25,7 @@ import sys
from paramiko.sftp import SFTP_OP_UNSUPPORTED
-class SFTPServerInterface (object):
+class SFTPServerInterface(object):
"""
This class defines an interface for controlling the behavior of paramiko
when using the `.SFTPServer` subsystem to provide an SFTP server.
@@ -39,6 +39,7 @@ class SFTPServerInterface (object):
All paths are in string form instead of unicode because not all SFTP
clients & servers obey the requirement that paths be encoded in UTF-8.
"""
+
def __init__(self, server, *largs, **kwargs):
"""
Create a new SFTPServerInterface object. This method does nothing by
@@ -281,10 +282,10 @@ class SFTPServerInterface (object):
if os.path.isabs(path):
out = os.path.normpath(path)
else:
- out = os.path.normpath('/' + path)
- if sys.platform == 'win32':
+ out = os.path.normpath("/" + path)
+ if sys.platform == "win32":
# on windows, normalize backslashes to sftp/posix format
- out = out.replace('\\', '/')
+ out = out.replace("\\", "/")
return out
def readlink(self, path):
diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py
index 2df84b65..12407d66 100644
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -19,14 +19,14 @@
import socket
-class SSHException (Exception):
+class SSHException(Exception):
"""
Exception raised by failures in SSH2 protocol negotiation or logic errors.
"""
pass
-class AuthenticationException (SSHException):
+class AuthenticationException(SSHException):
"""
Exception raised when authentication failed for some reason. It may be
possible to retry with different credentials. (Other classes specify more
@@ -37,14 +37,14 @@ class AuthenticationException (SSHException):
pass
-class PasswordRequiredException (AuthenticationException):
+class PasswordRequiredException(AuthenticationException):
"""
Exception raised when a password is needed to unlock a private key file.
"""
pass
-class BadAuthenticationType (AuthenticationException):
+class BadAuthenticationType(AuthenticationException):
"""
Exception raised when an authentication type (like password) is used, but
the server isn't allowing that type. (It may only allow public-key, for
@@ -60,28 +60,28 @@ class BadAuthenticationType (AuthenticationException):
AuthenticationException.__init__(self, explanation)
self.allowed_types = types
# for unpickling
- self.args = (explanation, types, )
+ self.args = (explanation, types)
def __str__(self):
- return '{} (allowed_types={!r})'.format(
+ return "{} (allowed_types={!r})".format(
SSHException.__str__(self), self.allowed_types
)
-class PartialAuthentication (AuthenticationException):
+class PartialAuthentication(AuthenticationException):
"""
An internal exception thrown in the case of partial authentication.
"""
allowed_types = []
def __init__(self, types):
- AuthenticationException.__init__(self, 'partial authentication')
+ AuthenticationException.__init__(self, "partial authentication")
self.allowed_types = types
# for unpickling
- self.args = (types, )
+ self.args = (types,)
-class ChannelException (SSHException):
+class ChannelException(SSHException):
"""
Exception raised when an attempt to open a new `.Channel` fails.
@@ -89,14 +89,15 @@ class ChannelException (SSHException):
.. versionadded:: 1.6
"""
+
def __init__(self, code, text):
SSHException.__init__(self, text)
self.code = code
# for unpickling
- self.args = (code, text, )
+ self.args = (code, text)
-class BadHostKeyException (SSHException):
+class BadHostKeyException(SSHException):
"""
The host key given by the SSH server did not match what we were expecting.
@@ -106,35 +107,40 @@ class BadHostKeyException (SSHException):
.. versionadded:: 1.6
"""
+
def __init__(self, hostname, got_key, expected_key):
- message = 'Host key for server {} does not match: got {}, expected {}' # noqa
+ message = (
+ "Host key for server {} does not match: got {}, expected {}"
+ ) # noqa
message = message.format(
- hostname, got_key.get_base64(),
- expected_key.get_base64())
+ hostname, got_key.get_base64(), expected_key.get_base64()
+ )
SSHException.__init__(self, message)
self.hostname = hostname
self.key = got_key
self.expected_key = expected_key
# for unpickling
- self.args = (hostname, got_key, expected_key, )
+ self.args = (hostname, got_key, expected_key)
-class ProxyCommandFailure (SSHException):
+class ProxyCommandFailure(SSHException):
"""
The "ProxyCommand" found in the .ssh/config file returned an error.
:param str command: The command line that is generating this exception.
:param str error: The error captured from the proxy command output.
"""
+
def __init__(self, command, error):
- SSHException.__init__(self,
+ SSHException.__init__(
+ self,
'"ProxyCommand ({})" returned non-zero exit status: {}'.format(
command, error
- )
+ ),
)
self.error = error
# for unpickling
- self.args = (command, error, )
+ self.args = (command, error)
class NoValidConnectionsError(socket.error):
@@ -159,23 +165,23 @@ class NoValidConnectionsError(socket.error):
.. versionadded:: 1.16
"""
+
def __init__(self, errors):
"""
:param dict errors:
The errors dict to store, as described by class docstring.
"""
addrs = sorted(errors.keys())
- body = ', '.join([x[0] for x in addrs[:-1]])
+ body = ", ".join([x[0] for x in addrs[:-1]])
tail = addrs[-1][0]
if body:
msg = "Unable to connect to port {0} on {1} or {2}"
else:
msg = "Unable to connect to port {0} on {2}"
super(NoValidConnectionsError, self).__init__(
- None, # stand-in for errno
- msg.format(addrs[0][1], body, tail)
+ None, msg.format(addrs[0][1], body, tail) # stand-in for errno
)
self.errors = errors
def __reduce__(self):
- return (self.__class__, (self.errors, ))
+ return (self.__class__, (self.errors,))
diff --git a/paramiko/ssh_gss.py b/paramiko/ssh_gss.py
index aa7cc74d..31601381 100644
--- a/paramiko/ssh_gss.py
+++ b/paramiko/ssh_gss.py
@@ -49,12 +49,14 @@ _API = "MIT"
try:
import gssapi
+
GSS_EXCEPTIONS = (gssapi.GSSException,)
except (ImportError, OSError):
try:
import pywintypes
import sspicon
import sspi
+
_API = "SSPI"
GSS_EXCEPTIONS = (pywintypes.error,)
except ImportError:
@@ -99,6 +101,7 @@ class _SSH_GSSAuth(object):
Contains the shared variables and methods of `._SSH_GSSAPI` and
`._SSH_SSPI`.
"""
+
def __init__(self, auth_method, gss_deleg_creds):
"""
:param str auth_method: The name of the SSH authentication mechanism
@@ -210,7 +213,7 @@ class _SSH_GSSAuth(object):
"""
mic = self._make_uint32(len(session_id))
mic += session_id
- mic += struct.pack('B', MSG_USERAUTH_REQUEST)
+ mic += struct.pack("B", MSG_USERAUTH_REQUEST)
mic += self._make_uint32(len(username))
mic += username.encode()
mic += self._make_uint32(len(service))
@@ -226,6 +229,7 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
:see: `.GSSAuth`
"""
+
def __init__(self, auth_method, gss_deleg_creds):
"""
:param str auth_method: The name of the SSH authentication mechanism
@@ -235,17 +239,22 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
_SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds)
if self._gss_deleg_creds:
- self._gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_MUTUAL_FLAG,
- gssapi.C_DELEG_FLAG)
+ self._gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
else:
- self._gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_MUTUAL_FLAG)
+ self._gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ )
- def ssh_init_sec_context(self, target, desired_mech=None,
- username=None, recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a GSS-API context.
@@ -264,8 +273,9 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
from pyasn1.codec.der import decoder
self._username = username
self._gss_host = target
- targ_name = gssapi.Name("host@" + self._gss_host,
- gssapi.C_NT_HOSTBASED_SERVICE)
+ targ_name = gssapi.Name(
+ "host@" + self._gss_host, gssapi.C_NT_HOSTBASED_SERVICE
+ )
ctx = gssapi.Context()
ctx.flags = self._gss_flags
if desired_mech is None:
@@ -279,15 +289,16 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
token = None
try:
if recv_token is None:
- self._gss_ctxt = gssapi.InitContext(peer_name=targ_name,
- mech_type=krb5_mech,
- req_flags=ctx.flags)
+ self._gss_ctxt = gssapi.InitContext(
+ peer_name=targ_name,
+ mech_type=krb5_mech,
+ req_flags=ctx.flags,
+ )
token = self._gss_ctxt.step(token)
else:
token = self._gss_ctxt.step(recv_token)
except gssapi.GSSException:
- message = "{} Target: {}".format(
- sys.exc_info()[1], self._gss_host)
+ message = "{} Target: {}".format(sys.exc_info()[1], self._gss_host)
raise gssapi.GSSException(message)
self._gss_ctxt_status = self._gss_ctxt.established
return token
@@ -307,10 +318,12 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
"""
self._session_id = session_id
if not gss_kex:
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
mic_token = self._gss_ctxt.get_mic(mic_field)
else:
# for key exchange with gssapi-keyex
@@ -351,16 +364,17 @@ class _SSH_GSSAPI(_SSH_GSSAuth):
self._username = username
if self._username is not None:
# server mode
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
self._gss_srv_ctxt.verify_mic(mic_field, mic_token)
else:
# for key exchange with gssapi-keyex
# client mode
- self._gss_ctxt.verify_mic(self._session_id,
- mic_token)
+ self._gss_ctxt.verify_mic(self._session_id, mic_token)
@property
def credentials_delegated(self):
@@ -393,6 +407,7 @@ class _SSH_SSPI(_SSH_GSSAuth):
:see: `.GSSAuth`
"""
+
def __init__(self, auth_method, gss_deleg_creds):
"""
:param str auth_method: The name of the SSH authentication mechanism
@@ -403,18 +418,18 @@ class _SSH_SSPI(_SSH_GSSAuth):
if self._gss_deleg_creds:
self._gss_flags = (
- sspicon.ISC_REQ_INTEGRITY |
- sspicon.ISC_REQ_MUTUAL_AUTH |
- sspicon.ISC_REQ_DELEGATE
+ sspicon.ISC_REQ_INTEGRITY
+ | sspicon.ISC_REQ_MUTUAL_AUTH
+ | sspicon.ISC_REQ_DELEGATE
)
else:
self._gss_flags = (
- sspicon.ISC_REQ_INTEGRITY |
- sspicon.ISC_REQ_MUTUAL_AUTH
+ sspicon.ISC_REQ_INTEGRITY | sspicon.ISC_REQ_MUTUAL_AUTH
)
- def ssh_init_sec_context(self, target, desired_mech=None,
- username=None, recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a SSPI context.
@@ -441,9 +456,9 @@ class _SSH_SSPI(_SSH_GSSAuth):
raise SSHException("Unsupported mechanism OID.")
try:
if recv_token is None:
- self._gss_ctxt = sspi.ClientAuth("Kerberos",
- scflags=self._gss_flags,
- targetspn=targ_name)
+ self._gss_ctxt = sspi.ClientAuth(
+ "Kerberos", scflags=self._gss_flags, targetspn=targ_name
+ )
error, token = self._gss_ctxt.authorize(recv_token)
token = token[0].Buffer
except pywintypes.error as e:
@@ -478,10 +493,12 @@ class _SSH_SSPI(_SSH_GSSAuth):
"""
self._session_id = session_id
if not gss_kex:
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
mic_token = self._gss_ctxt.sign(mic_field)
else:
# for key exchange with gssapi-keyex
@@ -524,10 +541,12 @@ class _SSH_SSPI(_SSH_GSSAuth):
self._username = username
if username is not None:
# server mode
- mic_field = self._ssh_build_mic(self._session_id,
- self._username,
- self._service,
- self._auth_method)
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
# Verifies data and its signature. If verification fails, an
# sspi.error will be raised.
self._gss_srv_ctxt.verify(mic_field, mic_token)
@@ -545,9 +564,8 @@ class _SSH_SSPI(_SSH_GSSAuth):
:return: ``True`` if credentials are delegated, otherwise ``False``
"""
- return (
- self._gss_flags & sspicon.ISC_REQ_DELEGATE and
- (self._gss_srv_ctxt_status or self._gss_flags)
+ return self._gss_flags & sspicon.ISC_REQ_DELEGATE and (
+ self._gss_srv_ctxt_status or self._gss_flags
)
def save_client_creds(self, client_token):
diff --git a/paramiko/transport.py b/paramiko/transport.py
index ddcb2912..4e6cb2c1 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -39,18 +39,49 @@ from paramiko.auth_handler import AuthHandler
from paramiko.ssh_gss import GSSAuth
from paramiko.channel import Channel
from paramiko.common import (
- xffffffff, cMSG_CHANNEL_OPEN, cMSG_IGNORE, cMSG_GLOBAL_REQUEST, DEBUG,
- MSG_KEXINIT, MSG_IGNORE, MSG_DISCONNECT, MSG_DEBUG, ERROR, WARNING,
- cMSG_UNIMPLEMENTED, INFO, cMSG_KEXINIT, cMSG_NEWKEYS, MSG_NEWKEYS,
- cMSG_REQUEST_SUCCESS, cMSG_REQUEST_FAILURE, CONNECTION_FAILED_CODE,
- OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_SUCCEEDED,
- cMSG_CHANNEL_OPEN_FAILURE, cMSG_CHANNEL_OPEN_SUCCESS, MSG_GLOBAL_REQUEST,
- MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE, MSG_CHANNEL_OPEN_SUCCESS,
- MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_OPEN, MSG_CHANNEL_SUCCESS,
- MSG_CHANNEL_FAILURE, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA,
- MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_REQUEST, MSG_CHANNEL_EOF,
- MSG_CHANNEL_CLOSE, MIN_WINDOW_SIZE, MIN_PACKET_SIZE, MAX_WINDOW_SIZE,
- DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, HIGHEST_USERAUTH_MESSAGE_ID,
+ xffffffff,
+ cMSG_CHANNEL_OPEN,
+ cMSG_IGNORE,
+ cMSG_GLOBAL_REQUEST,
+ DEBUG,
+ MSG_KEXINIT,
+ MSG_IGNORE,
+ MSG_DISCONNECT,
+ MSG_DEBUG,
+ ERROR,
+ WARNING,
+ cMSG_UNIMPLEMENTED,
+ INFO,
+ cMSG_KEXINIT,
+ cMSG_NEWKEYS,
+ MSG_NEWKEYS,
+ cMSG_REQUEST_SUCCESS,
+ cMSG_REQUEST_FAILURE,
+ CONNECTION_FAILED_CODE,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_SUCCEEDED,
+ cMSG_CHANNEL_OPEN_FAILURE,
+ cMSG_CHANNEL_OPEN_SUCCESS,
+ MSG_GLOBAL_REQUEST,
+ MSG_REQUEST_SUCCESS,
+ MSG_REQUEST_FAILURE,
+ MSG_CHANNEL_OPEN_SUCCESS,
+ MSG_CHANNEL_OPEN_FAILURE,
+ MSG_CHANNEL_OPEN,
+ MSG_CHANNEL_SUCCESS,
+ MSG_CHANNEL_FAILURE,
+ MSG_CHANNEL_DATA,
+ MSG_CHANNEL_EXTENDED_DATA,
+ MSG_CHANNEL_WINDOW_ADJUST,
+ MSG_CHANNEL_REQUEST,
+ MSG_CHANNEL_EOF,
+ MSG_CHANNEL_CLOSE,
+ MIN_WINDOW_SIZE,
+ MIN_PACKET_SIZE,
+ MAX_WINDOW_SIZE,
+ DEFAULT_WINDOW_SIZE,
+ DEFAULT_MAX_PACKET_SIZE,
+ HIGHEST_USERAUTH_MESSAGE_ID,
)
from paramiko.compress import ZlibCompressor, ZlibDecompressor
from paramiko.dsskey import DSSKey
@@ -69,7 +100,10 @@ from paramiko.ecdsakey import ECDSAKey
from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import (
- SSHException, BadAuthenticationType, ChannelException, ProxyCommandFailure,
+ SSHException,
+ BadAuthenticationType,
+ ChannelException,
+ ProxyCommandFailure,
)
from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
@@ -77,12 +111,14 @@ from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
# for thread cleanup
_active_threads = []
+
def _join_lingering_threads():
for thr in _active_threads:
thr.stop_thread()
import atexit
+
atexit.register(_join_lingering_threads)
@@ -99,160 +135,161 @@ class Transport(threading.Thread, ClosingContextManager):
_ENCRYPT = object()
_DECRYPT = object()
- _PROTO_ID = '2.0'
- _CLIENT_ID = 'paramiko_{}'.format(paramiko.__version__)
+ _PROTO_ID = "2.0"
+ _CLIENT_ID = "paramiko_{}".format(paramiko.__version__)
# These tuples of algorithm identifiers are in preference order; do not
# reorder without reason!
_preferred_ciphers = (
- 'aes128-ctr',
- 'aes192-ctr',
- 'aes256-ctr',
- 'aes128-cbc',
- 'aes192-cbc',
- 'aes256-cbc',
- 'blowfish-cbc',
- '3des-cbc',
+ "aes128-ctr",
+ "aes192-ctr",
+ "aes256-ctr",
+ "aes128-cbc",
+ "aes192-cbc",
+ "aes256-cbc",
+ "blowfish-cbc",
+ "3des-cbc",
)
_preferred_macs = (
- 'hmac-sha2-256',
- 'hmac-sha2-512',
- 'hmac-sha1',
- 'hmac-md5',
- 'hmac-sha1-96',
- 'hmac-md5-96',
+ "hmac-sha2-256",
+ "hmac-sha2-512",
+ "hmac-sha1",
+ "hmac-md5",
+ "hmac-sha1-96",
+ "hmac-md5-96",
)
_preferred_keys = (
- 'ssh-ed25519',
- 'ecdsa-sha2-nistp256',
- 'ecdsa-sha2-nistp384',
- 'ecdsa-sha2-nistp521',
- 'ssh-rsa',
- 'ssh-dss',
+ "ssh-ed25519",
+ "ecdsa-sha2-nistp256",
+ "ecdsa-sha2-nistp384",
+ "ecdsa-sha2-nistp521",
+ "ssh-rsa",
+ "ssh-dss",
)
_preferred_kex = (
- 'ecdh-sha2-nistp256',
- 'ecdh-sha2-nistp384',
- 'ecdh-sha2-nistp521',
- 'diffie-hellman-group-exchange-sha256',
- 'diffie-hellman-group-exchange-sha1',
- 'diffie-hellman-group14-sha1',
- 'diffie-hellman-group1-sha1',
+ "ecdh-sha2-nistp256",
+ "ecdh-sha2-nistp384",
+ "ecdh-sha2-nistp521",
+ "diffie-hellman-group-exchange-sha256",
+ "diffie-hellman-group-exchange-sha1",
+ "diffie-hellman-group14-sha1",
+ "diffie-hellman-group1-sha1",
)
_preferred_gsskex = (
- 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==',
- 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==',
- 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==',
+ "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==",
)
- _preferred_compression = ('none',)
+ _preferred_compression = ("none",)
_cipher_info = {
- 'aes128-ctr': {
- 'class': algorithms.AES,
- 'mode': modes.CTR,
- 'block-size': 16,
- 'key-size': 16
+ "aes128-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 16,
},
- 'aes192-ctr': {
- 'class': algorithms.AES,
- 'mode': modes.CTR,
- 'block-size': 16,
- 'key-size': 24
+ "aes192-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 24,
},
- 'aes256-ctr': {
- 'class': algorithms.AES,
- 'mode': modes.CTR,
- 'block-size': 16,
- 'key-size': 32
+ "aes256-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 32,
},
- 'blowfish-cbc': {
- 'class': algorithms.Blowfish,
- 'mode': modes.CBC,
- 'block-size': 8,
- 'key-size': 16
+ "blowfish-cbc": {
+ "class": algorithms.Blowfish,
+ "mode": modes.CBC,
+ "block-size": 8,
+ "key-size": 16,
},
- 'aes128-cbc': {
- 'class': algorithms.AES,
- 'mode': modes.CBC,
- 'block-size': 16,
- 'key-size': 16
+ "aes128-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 16,
},
- 'aes192-cbc': {
- 'class': algorithms.AES,
- 'mode': modes.CBC,
- 'block-size': 16,
- 'key-size': 24
+ "aes192-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 24,
},
- 'aes256-cbc': {
- 'class': algorithms.AES,
- 'mode': modes.CBC,
- 'block-size': 16,
- 'key-size': 32
+ "aes256-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 32,
},
- '3des-cbc': {
- 'class': algorithms.TripleDES,
- 'mode': modes.CBC,
- 'block-size': 8,
- 'key-size': 24
+ "3des-cbc": {
+ "class": algorithms.TripleDES,
+ "mode": modes.CBC,
+ "block-size": 8,
+ "key-size": 24,
},
}
-
_mac_info = {
- 'hmac-sha1': {'class': sha1, 'size': 20},
- 'hmac-sha1-96': {'class': sha1, 'size': 12},
- 'hmac-sha2-256': {'class': sha256, 'size': 32},
- 'hmac-sha2-512': {'class': sha512, 'size': 64},
- 'hmac-md5': {'class': md5, 'size': 16},
- 'hmac-md5-96': {'class': md5, 'size': 12},
+ "hmac-sha1": {"class": sha1, "size": 20},
+ "hmac-sha1-96": {"class": sha1, "size": 12},
+ "hmac-sha2-256": {"class": sha256, "size": 32},
+ "hmac-sha2-512": {"class": sha512, "size": 64},
+ "hmac-md5": {"class": md5, "size": 16},
+ "hmac-md5-96": {"class": md5, "size": 12},
}
_key_info = {
- 'ssh-rsa': RSAKey,
- 'ssh-rsa-cert-v01@openssh.com': RSAKey,
- 'ssh-dss': DSSKey,
- 'ssh-dss-cert-v01@openssh.com': DSSKey,
- 'ecdsa-sha2-nistp256': ECDSAKey,
- 'ecdsa-sha2-nistp256-cert-v01@openssh.com': ECDSAKey,
- 'ecdsa-sha2-nistp384': ECDSAKey,
- 'ecdsa-sha2-nistp384-cert-v01@openssh.com': ECDSAKey,
- 'ecdsa-sha2-nistp521': ECDSAKey,
- 'ecdsa-sha2-nistp521-cert-v01@openssh.com': ECDSAKey,
- 'ssh-ed25519': Ed25519Key,
- 'ssh-ed25519-cert-v01@openssh.com': Ed25519Key,
+ "ssh-rsa": RSAKey,
+ "ssh-rsa-cert-v01@openssh.com": RSAKey,
+ "ssh-dss": DSSKey,
+ "ssh-dss-cert-v01@openssh.com": DSSKey,
+ "ecdsa-sha2-nistp256": ECDSAKey,
+ "ecdsa-sha2-nistp256-cert-v01@openssh.com": ECDSAKey,
+ "ecdsa-sha2-nistp384": ECDSAKey,
+ "ecdsa-sha2-nistp384-cert-v01@openssh.com": ECDSAKey,
+ "ecdsa-sha2-nistp521": ECDSAKey,
+ "ecdsa-sha2-nistp521-cert-v01@openssh.com": ECDSAKey,
+ "ssh-ed25519": Ed25519Key,
+ "ssh-ed25519-cert-v01@openssh.com": Ed25519Key,
}
_kex_info = {
- 'diffie-hellman-group1-sha1': KexGroup1,
- 'diffie-hellman-group14-sha1': KexGroup14,
- 'diffie-hellman-group-exchange-sha1': KexGex,
- 'diffie-hellman-group-exchange-sha256': KexGexSHA256,
- 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup1,
- 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup14,
- 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGex,
- 'ecdh-sha2-nistp256': KexNistp256,
- 'ecdh-sha2-nistp384': KexNistp384,
- 'ecdh-sha2-nistp521': KexNistp521,
+ "diffie-hellman-group1-sha1": KexGroup1,
+ "diffie-hellman-group14-sha1": KexGroup14,
+ "diffie-hellman-group-exchange-sha1": KexGex,
+ "diffie-hellman-group-exchange-sha256": KexGexSHA256,
+ "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup1,
+ "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup14,
+ "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGex,
+ "ecdh-sha2-nistp256": KexNistp256,
+ "ecdh-sha2-nistp384": KexNistp384,
+ "ecdh-sha2-nistp521": KexNistp521,
}
_compression_info = {
# zlib@openssh.com is just zlib, but only turned on after a successful
# authentication. openssh servers may only offer this type because
# they've had troubles with security holes in zlib in the past.
- 'zlib@openssh.com': (ZlibCompressor, ZlibDecompressor),
- 'zlib': (ZlibCompressor, ZlibDecompressor),
- 'none': (None, None),
+ "zlib@openssh.com": (ZlibCompressor, ZlibDecompressor),
+ "zlib": (ZlibCompressor, ZlibDecompressor),
+ "none": (None, None),
}
_modulus_pack = None
_active_check_timeout = 0.1
- def __init__(self,
- sock,
- default_window_size=DEFAULT_WINDOW_SIZE,
- default_max_packet_size=DEFAULT_MAX_PACKET_SIZE,
- gss_kex=False,
- gss_deleg_creds=True):
+ def __init__(
+ self,
+ sock,
+ default_window_size=DEFAULT_WINDOW_SIZE,
+ default_max_packet_size=DEFAULT_MAX_PACKET_SIZE,
+ gss_kex=False,
+ gss_deleg_creds=True,
+ ):
"""
Create a new SSH session over an existing socket, or socket-like
object. This only creates the `.Transport` object; it doesn't begin
@@ -302,7 +339,7 @@ class Transport(threading.Thread, ClosingContextManager):
if isinstance(sock, string_types):
# convert "host:port" into (host, port)
- hl = sock.split(':', 1)
+ hl = sock.split(":", 1)
self.hostname = hl[0]
if len(hl) == 1:
sock = (hl[0], 22)
@@ -312,7 +349,7 @@ class Transport(threading.Thread, ClosingContextManager):
# connect to the given (host, port)
hostname, port = sock
self.hostname = hostname
- reason = 'No suitable address family'
+ reason = "No suitable address family"
addrinfos = socket.getaddrinfo(
hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
)
@@ -329,7 +366,8 @@ class Transport(threading.Thread, ClosingContextManager):
break
else:
raise SSHException(
- 'Unable to connect to {}: {}'.format(hostname, reason))
+ "Unable to connect to {}: {}".format(hostname, reason)
+ )
# okay, normal socket-ish flow here...
threading.Thread.__init__(self)
self.setDaemon(True)
@@ -340,9 +378,9 @@ class Transport(threading.Thread, ClosingContextManager):
# negotiated crypto parameters
self.packetizer = Packetizer(sock)
- self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
- self.remote_version = ''
- self.local_cipher = self.remote_cipher = ''
+ self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID
+ self.remote_version = ""
+ self.local_cipher = self.remote_cipher = ""
self.local_kex_init = self.remote_kex_init = None
self.local_mac = self.remote_mac = None
self.local_compression = self.remote_compression = None
@@ -374,8 +412,8 @@ class Transport(threading.Thread, ClosingContextManager):
# tracking open channels
self._channels = ChannelMap()
- self.channel_events = {} # (id -> Event)
- self.channels_seen = {} # (id -> True)
+ self.channel_events = {} # (id -> Event)
+ self.channels_seen = {} # (id -> True)
self._channel_counter = 0
self.default_max_packet_size = default_max_packet_size
self.default_window_size = default_window_size
@@ -387,7 +425,7 @@ class Transport(threading.Thread, ClosingContextManager):
self.clear_to_send = threading.Event()
self.clear_to_send_lock = threading.Lock()
self.clear_to_send_timeout = 30.0
- self.log_name = 'paramiko.transport'
+ self.log_name = "paramiko.transport"
self.logger = util.get_logger(self.log_name)
self.packetizer.set_log(self.logger)
self.auth_handler = None
@@ -416,23 +454,24 @@ class Transport(threading.Thread, ClosingContextManager):
Returns a string representation of this object, for debugging.
"""
id_ = hex(long(id(self)) & xffffffff)
- out = '<paramiko.Transport at {}'.format(id_)
+ out = "<paramiko.Transport at {}".format(id_)
if not self.active:
- out += ' (unconnected)'
+ out += " (unconnected)"
else:
- if self.local_cipher != '':
- out += ' (cipher {}, {:d} bits)'.format(
+ if self.local_cipher != "":
+ out += " (cipher {}, {:d} bits)".format(
self.local_cipher,
- self._cipher_info[self.local_cipher]['key-size'] * 8
+ self._cipher_info[self.local_cipher]["key-size"] * 8,
)
if self.is_authenticated():
- out += ' (active; {} open channel(s))'.format(
- len(self._channels))
+ out += " (active; {} open channel(s))".format(
+ len(self._channels)
+ )
elif self.initial_kex_done:
- out += ' (connected; awaiting auth)'
+ out += " (connected; awaiting auth)"
else:
- out += ' (connecting)'
- out += '>'
+ out += " (connecting)"
+ out += ">"
return out
def atfork(self):
@@ -543,10 +582,9 @@ class Transport(threading.Thread, ClosingContextManager):
e = self.get_exception()
if e is not None:
raise e
- raise SSHException('Negotiation failed.')
- if (
- event.is_set() or
- (timeout is not None and time.time() >= max_time)
+ raise SSHException("Negotiation failed.")
+ if event.is_set() or (
+ timeout is not None and time.time() >= max_time
):
break
@@ -612,7 +650,7 @@ class Transport(threading.Thread, ClosingContextManager):
e = self.get_exception()
if e is not None:
raise e
- raise SSHException('Negotiation failed.')
+ raise SSHException("Negotiation failed.")
if event.is_set():
break
@@ -679,7 +717,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
Transport._modulus_pack = ModulusPack()
# places to look for the openssh "moduli" file
- file_list = ['/etc/ssh/moduli', '/usr/local/etc/moduli']
+ file_list = ["/etc/ssh/moduli", "/usr/local/etc/moduli"]
if filename is not None:
file_list.insert(0, filename)
for fn in file_list:
@@ -717,7 +755,7 @@ class Transport(threading.Thread, ClosingContextManager):
:return: public key (`.PKey`) of the remote server
"""
if (not self.active) or (not self.initial_kex_done):
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
return self.host_key
def is_active(self):
@@ -731,10 +769,7 @@ class Transport(threading.Thread, ClosingContextManager):
return self.active
def open_session(
- self,
- window_size=None,
- max_packet_size=None,
- timeout=None,
+ self, window_size=None, max_packet_size=None, timeout=None
):
"""
Request a new channel to the server, of type ``"session"``. This is
@@ -761,10 +796,12 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- return self.open_channel('session',
- window_size=window_size,
- max_packet_size=max_packet_size,
- timeout=timeout)
+ return self.open_channel(
+ "session",
+ window_size=window_size,
+ max_packet_size=max_packet_size,
+ timeout=timeout,
+ )
def open_x11_channel(self, src_addr=None):
"""
@@ -780,7 +817,7 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if the request is rejected or the session ends
prematurely
"""
- return self.open_channel('x11', src_addr=src_addr)
+ return self.open_channel("x11", src_addr=src_addr)
def open_forward_agent_channel(self):
"""
@@ -794,7 +831,7 @@ class Transport(threading.Thread, ClosingContextManager):
:raises: `.SSHException` --
if the request is rejected or the session ends prematurely
"""
- return self.open_channel('auth-agent@openssh.com')
+ return self.open_channel("auth-agent@openssh.com")
def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
"""
@@ -806,15 +843,17 @@ class Transport(threading.Thread, ClosingContextManager):
:param src_addr: originator's address
:param dest_addr: local (server) connected address
"""
- return self.open_channel('forwarded-tcpip', dest_addr, src_addr)
+ return self.open_channel("forwarded-tcpip", dest_addr, src_addr)
- def open_channel(self,
- kind,
- dest_addr=None,
- src_addr=None,
- window_size=None,
- max_packet_size=None,
- timeout=None):
+ def open_channel(
+ self,
+ kind,
+ dest_addr=None,
+ src_addr=None,
+ window_size=None,
+ max_packet_size=None,
+ timeout=None,
+ ):
"""
Request a new channel to the server. `Channels <.Channel>` are
socket-like objects used for the actual transfer of data across the
@@ -851,7 +890,7 @@ class Transport(threading.Thread, ClosingContextManager):
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
if not self.active:
- raise SSHException('SSH session not active')
+ raise SSHException("SSH session not active")
timeout = 3600 if timeout is None else timeout
self.lock.acquire()
try:
@@ -864,12 +903,12 @@ class Transport(threading.Thread, ClosingContextManager):
m.add_int(chanid)
m.add_int(window_size)
m.add_int(max_packet_size)
- if (kind == 'forwarded-tcpip') or (kind == 'direct-tcpip'):
+ if (kind == "forwarded-tcpip") or (kind == "direct-tcpip"):
m.add_string(dest_addr[0])
m.add_int(dest_addr[1])
m.add_string(src_addr[0])
m.add_int(src_addr[1])
- elif kind == 'x11':
+ elif kind == "x11":
m.add_string(src_addr[0])
m.add_int(src_addr[1])
chan = Channel(chanid)
@@ -887,18 +926,18 @@ class Transport(threading.Thread, ClosingContextManager):
if not self.active:
e = self.get_exception()
if e is None:
- e = SSHException('Unable to open channel.')
+ e = SSHException("Unable to open channel.")
raise e
if event.is_set():
break
elif start_ts + timeout < time.time():
- raise SSHException('Timeout opening channel.')
+ raise SSHException("Timeout opening channel.")
chan = self._channels.get(chanid)
if chan is not None:
return chan
e = self.get_exception()
if e is None:
- e = SSHException('Unable to open channel.')
+ e = SSHException("Unable to open channel.")
raise e
def request_port_forward(self, address, port, handler=None):
@@ -935,20 +974,22 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if the server refused the TCP forward request
"""
if not self.active:
- raise SSHException('SSH session not active')
+ raise SSHException("SSH session not active")
port = int(port)
response = self.global_request(
- 'tcpip-forward', (address, port), wait=True
+ "tcpip-forward", (address, port), wait=True
)
if response is None:
- raise SSHException('TCP forwarding request denied')
+ raise SSHException("TCP forwarding request denied")
if port == 0:
port = response.get_int()
if handler is None:
+
def default_handler(channel, src_addr, dest_addr_port):
# src_addr, src_port = src_addr_port
# dest_addr, dest_port = dest_addr_port
self._queue_incoming_channel(channel)
+
handler = default_handler
self._tcp_handler = handler
return port
@@ -965,7 +1006,7 @@ class Transport(threading.Thread, ClosingContextManager):
if not self.active:
return
self._tcp_handler = None
- self.global_request('cancel-tcpip-forward', (address, port), wait=True)
+ self.global_request("cancel-tcpip-forward", (address, port), wait=True)
def open_sftp_client(self):
"""
@@ -1018,7 +1059,7 @@ class Transport(threading.Thread, ClosingContextManager):
e = self.get_exception()
if e is not None:
raise e
- raise SSHException('Negotiation failed.')
+ raise SSHException("Negotiation failed.")
if self.completion_event.is_set():
break
return
@@ -1034,8 +1075,10 @@ class Transport(threading.Thread, ClosingContextManager):
seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
"""
+
def _request(x=weakref.proxy(self)):
- return x.global_request('keepalive@lag.net', wait=False)
+ return x.global_request("keepalive@lag.net", wait=False)
+
self.packetizer.set_keepalive(interval, _request)
def global_request(self, kind, data=None, wait=True):
@@ -1103,7 +1146,7 @@ class Transport(threading.Thread, ClosingContextManager):
def connect(
self,
hostkey=None,
- username='',
+ username="",
password=None,
pkey=None,
gss_host=None,
@@ -1177,34 +1220,43 @@ class Transport(threading.Thread, ClosingContextManager):
if (hostkey is not None) and not gss_kex:
key = self.get_remote_server_key()
if (
- key.get_name() != hostkey.get_name() or
- key.asbytes() != hostkey.asbytes()
+ key.get_name() != hostkey.get_name()
+ or key.asbytes() != hostkey.asbytes()
):
- self._log(DEBUG, 'Bad host key from server')
- self._log(DEBUG, 'Expected: {}: {}'.format(
- hostkey.get_name(), repr(hostkey.asbytes()),
- ))
- self._log(DEBUG, 'Got : {}: {}'.format(
- key.get_name(), repr(key.asbytes()),
- ))
- raise SSHException('Bad host key from server')
- self._log(DEBUG, 'Host key verified ({})'.format(
- hostkey.get_name()))
+ self._log(DEBUG, "Bad host key from server")
+ self._log(
+ DEBUG,
+ "Expected: {}: {}".format(
+ hostkey.get_name(), repr(hostkey.asbytes())
+ ),
+ )
+ self._log(
+ DEBUG,
+ "Got : {}: {}".format(
+ key.get_name(), repr(key.asbytes())
+ ),
+ )
+ raise SSHException("Bad host key from server")
+ self._log(
+ DEBUG, "Host key verified ({})".format(hostkey.get_name())
+ )
if (pkey is not None) or (password is not None) or gss_auth or gss_kex:
if gss_auth:
- self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-with-mic)') # noqa
+ self._log(
+ DEBUG, "Attempting GSS-API auth... (gssapi-with-mic)"
+ ) # noqa
self.auth_gssapi_with_mic(
- username, self.gss_host, gss_deleg_creds,
+ username, self.gss_host, gss_deleg_creds
)
elif gss_kex:
- self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-keyex)')
+ self._log(DEBUG, "Attempting GSS-API auth... (gssapi-keyex)")
self.auth_gssapi_keyex(username)
elif pkey is not None:
- self._log(DEBUG, 'Attempting public-key auth...')
+ self._log(DEBUG, "Attempting public-key auth...")
self.auth_publickey(username, pkey)
else:
- self._log(DEBUG, 'Attempting password auth...')
+ self._log(DEBUG, "Attempting password auth...")
self.auth_password(username, password)
return
@@ -1259,9 +1311,9 @@ class Transport(threading.Thread, ClosingContextManager):
closed.
"""
return (
- self.active and
- self.auth_handler is not None and
- self.auth_handler.is_authenticated()
+ self.active
+ and self.auth_handler is not None
+ and self.auth_handler.is_authenticated()
)
def get_username(self):
@@ -1311,7 +1363,7 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5
"""
if (not self.active) or (not self.initial_kex_done):
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_none(username, my_event)
@@ -1367,7 +1419,7 @@ class Transport(threading.Thread, ClosingContextManager):
if (not self.active) or (not self.initial_kex_done):
# we should never try to send the password unless we're on a secure
# link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
if event is None:
my_event = threading.Event()
else:
@@ -1382,12 +1434,13 @@ class Transport(threading.Thread, ClosingContextManager):
except BadAuthenticationType as e:
# if password auth isn't allowed, but keyboard-interactive *is*,
# try to fudge it
- if not fallback or ('keyboard-interactive' not in e.allowed_types):
+ if not fallback or ("keyboard-interactive" not in e.allowed_types):
raise
try:
+
def handler(title, instructions, fields):
if len(fields) > 1:
- raise SSHException('Fallback authentication failed.')
+ raise SSHException("Fallback authentication failed.")
if len(fields) == 0:
# for some reason, at least on os x, a 2nd request will
# be made with zero fields requested. maybe it's just
@@ -1395,6 +1448,7 @@ class Transport(threading.Thread, ClosingContextManager):
# type we're doing here. *shrug* :)
return []
return [password]
+
return self.auth_interactive(username, handler)
except SSHException:
# attempt failed; just raise the original exception
@@ -1437,7 +1491,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
if event is None:
my_event = threading.Event()
else:
@@ -1449,7 +1503,7 @@ class Transport(threading.Thread, ClosingContextManager):
return []
return self.auth_handler.wait_for_response(my_event)
- def auth_interactive(self, username, handler, submethods=''):
+ def auth_interactive(self, username, handler, submethods=""):
"""
Authenticate to the server interactively. A handler is used to answer
arbitrary questions from the server. On many servers, this is just a
@@ -1494,7 +1548,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_interactive(
@@ -1502,7 +1556,7 @@ class Transport(threading.Thread, ClosingContextManager):
)
return self.auth_handler.wait_for_response(my_event)
- def auth_interactive_dumb(self, username, handler=None, submethods=''):
+ def auth_interactive_dumb(self, username, handler=None, submethods=""):
"""
Autenticate to the server interactively but dumber.
Just print the prompt and / or instructions to stdout and send back
@@ -1511,6 +1565,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if not handler:
+
def handler(title, instructions, prompt_list):
answers = []
if title:
@@ -1518,9 +1573,10 @@ class Transport(threading.Thread, ClosingContextManager):
if instructions:
print(instructions.strip())
for prompt, show_input in prompt_list:
- print(prompt.strip(), end=' ')
+ print(prompt.strip(), end=" ")
answers.append(input())
return answers
+
return self.auth_interactive(username, handler, submethods)
def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds):
@@ -1541,7 +1597,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_gssapi_with_mic(
@@ -1566,7 +1622,7 @@ class Transport(threading.Thread, ClosingContextManager):
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
- raise SSHException('No existing session')
+ raise SSHException("No existing session")
my_event = threading.Event()
self.auth_handler = AuthHandler(self)
self.auth_handler.auth_gssapi_keyex(username, my_event)
@@ -1633,9 +1689,9 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5.2
"""
if compress:
- self._preferred_compression = ('zlib@openssh.com', 'zlib', 'none')
+ self._preferred_compression = ("zlib@openssh.com", "zlib", "none")
else:
- self._preferred_compression = ('none',)
+ self._preferred_compression = ("none",)
def getpeername(self):
"""
@@ -1649,9 +1705,9 @@ class Transport(threading.Thread, ClosingContextManager):
the address of the remote host, if known, as a ``(str, int)``
tuple.
"""
- gp = getattr(self.sock, 'getpeername', None)
+ gp = getattr(self.sock, "getpeername", None)
if gp is None:
- return 'unknown', 0
+ return "unknown", 0
return gp()
def stop_thread(self):
@@ -1670,10 +1726,10 @@ class Transport(threading.Thread, ClosingContextManager):
# our socket and packetizer are both closed (but where we'd
# otherwise be sitting forever on that recv()).
while (
- self.is_alive() and
- self is not threading.current_thread() and
- not self.sock._closed and
- not self.packetizer.closed
+ self.is_alive()
+ and self is not threading.current_thread()
+ and not self.sock._closed
+ and not self.packetizer.closed
):
self.join(0.1)
@@ -1715,14 +1771,18 @@ class Transport(threading.Thread, ClosingContextManager):
while True:
self.clear_to_send.wait(0.1)
if not self.active:
- self._log(DEBUG, 'Dropping user packet because connection is dead.') # noqa
+ self._log(
+ DEBUG, "Dropping user packet because connection is dead."
+ ) # noqa
return
self.clear_to_send_lock.acquire()
if self.clear_to_send.is_set():
break
self.clear_to_send_lock.release()
if time.time() > start + self.clear_to_send_timeout:
- raise SSHException('Key-exchange timed out waiting for key negotiation') # noqa
+ raise SSHException(
+ "Key-exchange timed out waiting for key negotiation"
+ ) # noqa
try:
self._send_message(data)
finally:
@@ -1746,9 +1806,13 @@ class Transport(threading.Thread, ClosingContextManager):
def _verify_key(self, host_key, sig):
key = self._key_info[self.host_key_type](Message(host_key))
if key is None:
- raise SSHException('Unknown host key type')
+ raise SSHException("Unknown host key type")
if not key.verify_ssh_sig(self.H, Message(sig)):
- raise SSHException('Signature verification ({}) failed.'.format(self.host_key_type)) # noqa
+ raise SSHException(
+ "Signature verification ({}) failed.".format(
+ self.host_key_type
+ )
+ ) # noqa
self.host_key = key
def _compute_key(self, id, nbytes):
@@ -1760,16 +1824,16 @@ class Transport(threading.Thread, ClosingContextManager):
m.add_bytes(self.session_id)
# Fallback to SHA1 for kex engines that fail to specify a hex
# algorithm, or for e.g. transport tests that don't run kexinit.
- hash_algo = getattr(self.kex_engine, 'hash_algo', None)
+ hash_algo = getattr(self.kex_engine, "hash_algo", None)
hash_select_msg = "kex engine {} specified hash_algo {!r}".format(
- self.kex_engine.__class__.__name__, hash_algo,
+ self.kex_engine.__class__.__name__, hash_algo
)
if hash_algo is None:
hash_algo = sha1
hash_select_msg += ", falling back to sha1"
- if not hasattr(self, '_logged_hash_selection'):
+ if not hasattr(self, "_logged_hash_selection"):
self._log(DEBUG, hash_select_msg)
- setattr(self, '_logged_hash_selection', True)
+ setattr(self, "_logged_hash_selection", True)
out = sofar = hash_algo(m.asbytes()).digest()
while len(out) < nbytes:
m = Message()
@@ -1783,11 +1847,11 @@ class Transport(threading.Thread, ClosingContextManager):
def _get_cipher(self, name, key, iv, operation):
if name not in self._cipher_info:
- raise SSHException('Unknown client cipher ' + name)
+ raise SSHException("Unknown client cipher " + name)
else:
cipher = Cipher(
- self._cipher_info[name]['class'](key),
- self._cipher_info[name]['mode'](iv),
+ self._cipher_info[name]["class"](key),
+ self._cipher_info[name]["mode"](iv),
backend=default_backend(),
)
if operation is self._ENCRYPT:
@@ -1797,8 +1861,10 @@ class Transport(threading.Thread, ClosingContextManager):
def _set_forward_agent_handler(self, handler):
if handler is None:
+
def default_handler(channel):
self._queue_incoming_channel(channel)
+
self._forward_agent_handler = default_handler
else:
self._forward_agent_handler = handler
@@ -1809,6 +1875,7 @@ class Transport(threading.Thread, ClosingContextManager):
# by default, use the same mechanism as accept()
def default_handler(channel, src_addr_port):
self._queue_incoming_channel(channel)
+
self._x11_handler = default_handler
else:
self._x11_handler = handler
@@ -1842,9 +1909,9 @@ class Transport(threading.Thread, ClosingContextManager):
Otherwise (client mode, authed, or pre-auth message) returns None.
"""
if (
- not self.server_mode or
- ptype <= HIGHEST_USERAUTH_MESSAGE_ID or
- self.is_authenticated()
+ not self.server_mode
+ or ptype <= HIGHEST_USERAUTH_MESSAGE_ID
+ or self.is_authenticated()
):
return None
# WELP. We must be dealing with someone trying to do non-auth things
@@ -1855,13 +1922,13 @@ class Transport(threading.Thread, ClosingContextManager):
reply.add_byte(cMSG_REQUEST_FAILURE)
# Channel opens let us reject w/ a specific type + message.
elif ptype == MSG_CHANNEL_OPEN:
- kind = message.get_text() # noqa
+ kind = message.get_text() # noqa
chanid = message.get_int()
reply.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
reply.add_int(chanid)
reply.add_int(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
- reply.add_string('')
- reply.add_string('en')
+ reply.add_string("")
+ reply.add_string("en")
# NOTE: Post-open channel messages do not need checking; the above will
# reject attemps to open channels, meaning that even if a malicious
# user tries to send a MSG_CHANNEL_REQUEST, it will simply fall under
@@ -1883,13 +1950,16 @@ class Transport(threading.Thread, ClosingContextManager):
_active_threads.append(self)
tid = hex(long(id(self)) & xffffffff)
if self.server_mode:
- self._log(DEBUG, 'starting thread (server mode): {}'.format(tid))
+ self._log(DEBUG, "starting thread (server mode): {}".format(tid))
else:
- self._log(DEBUG, 'starting thread (client mode): {}'.format(tid))
+ self._log(DEBUG, "starting thread (client mode): {}".format(tid))
try:
try:
- self.packetizer.write_all(b(self.local_version + '\r\n'))
- self._log(DEBUG, 'Local version/idstring: {}'.format(self.local_version)) # noqa
+ self.packetizer.write_all(b(self.local_version + "\r\n"))
+ self._log(
+ DEBUG,
+ "Local version/idstring: {}".format(self.local_version),
+ ) # noqa
self._check_banner()
# The above is actually very much part of the handshake, but
# sometimes the banner can be read but the machine is not
@@ -1919,7 +1989,11 @@ class Transport(threading.Thread, ClosingContextManager):
continue
if len(self._expected_packet) > 0:
if ptype not in self._expected_packet:
- raise SSHException('Expecting packet from {!r}, got {:d}'.format(self._expected_packet, ptype)) # noqa
+ raise SSHException(
+ "Expecting packet from {!r}, got {:d}".format(
+ self._expected_packet, ptype
+ )
+ ) # noqa
self._expected_packet = tuple()
if (ptype >= 30) and (ptype <= 41):
self.kex_engine.parse_next(ptype, m)
@@ -1937,20 +2011,30 @@ class Transport(threading.Thread, ClosingContextManager):
if chan is not None:
self._channel_handler_table[ptype](chan, m)
elif chanid in self.channels_seen:
- self._log(DEBUG, 'Ignoring message for dead channel {:d}'.format(chanid)) # noqa
+ self._log(
+ DEBUG,
+ "Ignoring message for dead channel {:d}".format( # noqa
+ chanid
+ ),
+ )
else:
- self._log(ERROR, 'Channel request for unknown channel {:d}'.format(chanid)) # noqa
+ self._log(
+ ERROR,
+ "Channel request for unknown channel {:d}".format( # noqa
+ chanid
+ ),
+ )
break
elif (
- self.auth_handler is not None and
- ptype in self.auth_handler._handler_table
+ self.auth_handler is not None
+ and ptype in self.auth_handler._handler_table
):
handler = self.auth_handler._handler_table[ptype]
handler(self.auth_handler, m)
if len(self._expected_packet) > 0:
continue
else:
- err = 'Oops, unhandled type {:d}'.format(ptype)
+ err = "Oops, unhandled type {:d}".format(ptype)
self._log(WARNING, err)
msg = Message()
msg.add_byte(cMSG_UNIMPLEMENTED)
@@ -1958,24 +2042,24 @@ class Transport(threading.Thread, ClosingContextManager):
self._send_message(msg)
self.packetizer.complete_handshake()
except SSHException as e:
- self._log(ERROR, 'Exception: ' + str(e))
+ self._log(ERROR, "Exception: " + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
except EOFError as e:
- self._log(DEBUG, 'EOF in transport thread')
+ self._log(DEBUG, "EOF in transport thread")
self.saved_exception = e
except socket.error as e:
if type(e.args) is tuple:
if e.args:
- emsg = '{} ({:d})'.format(e.args[1], e.args[0])
+ emsg = "{} ({:d})".format(e.args[1], e.args[0])
else: # empty tuple, e.g. socket.timeout
emsg = str(e) or repr(e)
else:
emsg = e.args
- self._log(ERROR, 'Socket exception: ' + emsg)
+ self._log(ERROR, "Socket exception: " + emsg)
self.saved_exception = e
except Exception as e:
- self._log(ERROR, 'Unknown exception: ' + str(e))
+ self._log(ERROR, "Unknown exception: " + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
@@ -2004,7 +2088,6 @@ class Transport(threading.Thread, ClosingContextManager):
if self.sys.modules is not None:
raise
-
def _log_agreement(self, which, local, remote):
# Log useful, non-duplicative line re: an agreed-upon algorithm.
# Old code implied algorithms could be asymmetrical (different for
@@ -2046,32 +2129,32 @@ class Transport(threading.Thread, ClosingContextManager):
raise
except Exception as e:
raise SSHException(
- 'Error reading SSH protocol banner' + str(e)
+ "Error reading SSH protocol banner" + str(e)
)
- if buf[:4] == 'SSH-':
+ if buf[:4] == "SSH-":
break
- self._log(DEBUG, 'Banner: ' + buf)
- if buf[:4] != 'SSH-':
+ self._log(DEBUG, "Banner: " + buf)
+ if buf[:4] != "SSH-":
raise SSHException('Indecipherable protocol version "' + buf + '"')
# save this server version string for later
self.remote_version = buf
- self._log(DEBUG, 'Remote version/idstring: {}'.format(buf))
+ self._log(DEBUG, "Remote version/idstring: {}".format(buf))
# pull off any attached comment
# NOTE: comment used to be stored in a variable and then...never used.
# since 2003. ca 877cd974b8182d26fa76d566072917ea67b64e67
- i = buf.find(' ')
+ i = buf.find(" ")
if i >= 0:
buf = buf[:i]
# parse out version string and make sure it matches
- segs = buf.split('-', 2)
+ segs = buf.split("-", 2)
if len(segs) < 3:
- raise SSHException('Invalid SSH banner')
+ raise SSHException("Invalid SSH banner")
version = segs[1]
client = segs[2]
- if version != '1.99' and version != '2.0':
- msg = 'Incompatible version ({} instead of 2.0)'
+ if version != "1.99" and version != "2.0":
+ msg = "Incompatible version ({} instead of 2.0)"
raise SSHException(msg.format(version))
- msg = 'Connected (version {}, client {})'.format(version, client)
+ msg = "Connected (version {}, client {})".format(version, client)
self._log(INFO, msg)
def _send_kex_init(self):
@@ -2087,25 +2170,27 @@ class Transport(threading.Thread, ClosingContextManager):
self.gss_kex_used = False
self.in_kex = True
if self.server_mode:
- mp_required_prefix = 'diffie-hellman-group-exchange-sha'
+ mp_required_prefix = "diffie-hellman-group-exchange-sha"
kex_mp = [
- k for k
- in self._preferred_kex
+ k
+ for k in self._preferred_kex
if k.startswith(mp_required_prefix)
]
if (self._modulus_pack is None) and (len(kex_mp) > 0):
# can't do group-exchange if we don't have a pack of potential
# primes
pkex = [
- k for k
- in self.get_security_options().kex
+ k
+ for k in self.get_security_options().kex
if not k.startswith(mp_required_prefix)
]
self.get_security_options().kex = pkex
- available_server_keys = list(filter(
- list(self.server_key_dict.keys()).__contains__,
- self._preferred_keys
- ))
+ available_server_keys = list(
+ filter(
+ list(self.server_key_dict.keys()).__contains__,
+ self._preferred_keys,
+ )
+ )
else:
available_server_keys = self._preferred_keys
@@ -2129,7 +2214,7 @@ class Transport(threading.Thread, ClosingContextManager):
self._send_message(m)
def _parse_kex_init(self, m):
- m.get_bytes(16) # cookie, discarded
+ m.get_bytes(16) # cookie, discarded
kex_algo_list = m.get_list()
server_key_algo_list = m.get_list()
client_encrypt_algo_list = m.get_list()
@@ -2141,20 +2226,32 @@ class Transport(threading.Thread, ClosingContextManager):
client_lang_list = m.get_list()
server_lang_list = m.get_list()
kex_follows = m.get_boolean()
- m.get_int() # unused
-
- self._log(DEBUG,
- 'kex algos:' + str(kex_algo_list) +
- ' server key:' + str(server_key_algo_list) +
- ' client encrypt:' + str(client_encrypt_algo_list) +
- ' server encrypt:' + str(server_encrypt_algo_list) +
- ' client mac:' + str(client_mac_algo_list) +
- ' server mac:' + str(server_mac_algo_list) +
- ' client compress:' + str(client_compress_algo_list) +
- ' server compress:' + str(server_compress_algo_list) +
- ' client lang:' + str(client_lang_list) +
- ' server lang:' + str(server_lang_list) +
- ' kex follows?' + str(kex_follows)
+ m.get_int() # unused
+
+ self._log(
+ DEBUG,
+ "kex algos:"
+ + str(kex_algo_list)
+ + " server key:"
+ + str(server_key_algo_list)
+ + " client encrypt:"
+ + str(client_encrypt_algo_list)
+ + " server encrypt:"
+ + str(server_encrypt_algo_list)
+ + " client mac:"
+ + str(client_mac_algo_list)
+ + " server mac:"
+ + str(server_mac_algo_list)
+ + " client compress:"
+ + str(client_compress_algo_list)
+ + " server compress:"
+ + str(server_compress_algo_list)
+ + " client lang:"
+ + str(client_lang_list)
+ + " server lang:"
+ + str(server_lang_list)
+ + " kex follows?"
+ + str(kex_follows),
)
# as a server, we pick the first item in the client's list that we
@@ -2162,122 +2259,150 @@ class Transport(threading.Thread, ClosingContextManager):
# as a client, we pick the first item in our list that the server
# supports.
if self.server_mode:
- agreed_kex = list(filter(
- self._preferred_kex.__contains__,
- kex_algo_list
- ))
+ agreed_kex = list(
+ filter(self._preferred_kex.__contains__, kex_algo_list)
+ )
else:
- agreed_kex = list(filter(
- kex_algo_list.__contains__,
- self._preferred_kex
- ))
+ agreed_kex = list(
+ filter(kex_algo_list.__contains__, self._preferred_kex)
+ )
if len(agreed_kex) == 0:
- raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') # noqa
+ raise SSHException(
+ "Incompatible ssh peer (no acceptable kex algorithm)"
+ ) # noqa
self.kex_engine = self._kex_info[agreed_kex[0]](self)
self._log(DEBUG, "Kex agreed: {}".format(agreed_kex[0]))
if self.server_mode:
- available_server_keys = list(filter(
- list(self.server_key_dict.keys()).__contains__,
- self._preferred_keys
- ))
- agreed_keys = list(filter(
- available_server_keys.__contains__, server_key_algo_list
- ))
+ available_server_keys = list(
+ filter(
+ list(self.server_key_dict.keys()).__contains__,
+ self._preferred_keys,
+ )
+ )
+ agreed_keys = list(
+ filter(
+ available_server_keys.__contains__, server_key_algo_list
+ )
+ )
else:
- agreed_keys = list(filter(
- server_key_algo_list.__contains__, self._preferred_keys
- ))
+ agreed_keys = list(
+ filter(server_key_algo_list.__contains__, self._preferred_keys)
+ )
if len(agreed_keys) == 0:
- raise SSHException('Incompatible ssh peer (no acceptable host key)') # noqa
+ raise SSHException(
+ "Incompatible ssh peer (no acceptable host key)"
+ ) # noqa
self.host_key_type = agreed_keys[0]
if self.server_mode and (self.get_server_key() is None):
- raise SSHException('Incompatible ssh peer (can\'t match requested host key type)') # noqa
- self._log_agreement(
- 'HostKey', agreed_keys[0], agreed_keys[0]
- )
+ raise SSHException(
+ "Incompatible ssh peer (can't match requested host key type)"
+ ) # noqa
+ self._log_agreement("HostKey", agreed_keys[0], agreed_keys[0])
if self.server_mode:
- agreed_local_ciphers = list(filter(
- self._preferred_ciphers.__contains__,
- server_encrypt_algo_list
- ))
- agreed_remote_ciphers = list(filter(
- self._preferred_ciphers.__contains__,
- client_encrypt_algo_list
- ))
+ agreed_local_ciphers = list(
+ filter(
+ self._preferred_ciphers.__contains__,
+ server_encrypt_algo_list,
+ )
+ )
+ agreed_remote_ciphers = list(
+ filter(
+ self._preferred_ciphers.__contains__,
+ client_encrypt_algo_list,
+ )
+ )
else:
- agreed_local_ciphers = list(filter(
- client_encrypt_algo_list.__contains__,
- self._preferred_ciphers
- ))
- agreed_remote_ciphers = list(filter(
- server_encrypt_algo_list.__contains__,
- self._preferred_ciphers
- ))
+ agreed_local_ciphers = list(
+ filter(
+ client_encrypt_algo_list.__contains__,
+ self._preferred_ciphers,
+ )
+ )
+ agreed_remote_ciphers = list(
+ filter(
+ server_encrypt_algo_list.__contains__,
+ self._preferred_ciphers,
+ )
+ )
if len(agreed_local_ciphers) == 0 or len(agreed_remote_ciphers) == 0:
- raise SSHException('Incompatible ssh server (no acceptable ciphers)') # noqa
+ raise SSHException(
+ "Incompatible ssh server (no acceptable ciphers)"
+ ) # noqa
self.local_cipher = agreed_local_ciphers[0]
self.remote_cipher = agreed_remote_ciphers[0]
self._log_agreement(
- 'Cipher', local=self.local_cipher, remote=self.remote_cipher
+ "Cipher", local=self.local_cipher, remote=self.remote_cipher
)
if self.server_mode:
- agreed_remote_macs = list(filter(
- self._preferred_macs.__contains__, client_mac_algo_list
- ))
- agreed_local_macs = list(filter(
- self._preferred_macs.__contains__, server_mac_algo_list
- ))
+ agreed_remote_macs = list(
+ filter(self._preferred_macs.__contains__, client_mac_algo_list)
+ )
+ agreed_local_macs = list(
+ filter(self._preferred_macs.__contains__, server_mac_algo_list)
+ )
else:
- agreed_local_macs = list(filter(
- client_mac_algo_list.__contains__, self._preferred_macs
- ))
- agreed_remote_macs = list(filter(
- server_mac_algo_list.__contains__, self._preferred_macs
- ))
+ agreed_local_macs = list(
+ filter(client_mac_algo_list.__contains__, self._preferred_macs)
+ )
+ agreed_remote_macs = list(
+ filter(server_mac_algo_list.__contains__, self._preferred_macs)
+ )
if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0):
- raise SSHException('Incompatible ssh server (no acceptable macs)')
+ raise SSHException("Incompatible ssh server (no acceptable macs)")
self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0]
self._log_agreement(
- 'MAC', local=self.local_mac, remote=self.remote_mac
+ "MAC", local=self.local_mac, remote=self.remote_mac
)
if self.server_mode:
- agreed_remote_compression = list(filter(
- self._preferred_compression.__contains__,
- client_compress_algo_list
- ))
- agreed_local_compression = list(filter(
- self._preferred_compression.__contains__,
- server_compress_algo_list
- ))
+ agreed_remote_compression = list(
+ filter(
+ self._preferred_compression.__contains__,
+ client_compress_algo_list,
+ )
+ )
+ agreed_local_compression = list(
+ filter(
+ self._preferred_compression.__contains__,
+ server_compress_algo_list,
+ )
+ )
else:
- agreed_local_compression = list(filter(
- client_compress_algo_list.__contains__,
- self._preferred_compression
- ))
- agreed_remote_compression = list(filter(
- server_compress_algo_list.__contains__,
- self._preferred_compression
- ))
+ agreed_local_compression = list(
+ filter(
+ client_compress_algo_list.__contains__,
+ self._preferred_compression,
+ )
+ )
+ agreed_remote_compression = list(
+ filter(
+ server_compress_algo_list.__contains__,
+ self._preferred_compression,
+ )
+ )
if (
- len(agreed_local_compression) == 0 or
- len(agreed_remote_compression) == 0
+ len(agreed_local_compression) == 0
+ or len(agreed_remote_compression) == 0
):
- msg = 'Incompatible ssh server (no acceptable compression) {!r} {!r} {!r}' # noqa
- raise SSHException(msg.format(
- agreed_local_compression, agreed_remote_compression,
- self._preferred_compression,
- ))
+ msg = "Incompatible ssh server (no acceptable compression)"
+ msg += " {!r} {!r} {!r}"
+ raise SSHException(
+ msg.format(
+ agreed_local_compression,
+ agreed_remote_compression,
+ self._preferred_compression,
+ )
+ )
self.local_compression = agreed_local_compression[0]
self.remote_compression = agreed_remote_compression[0]
self._log_agreement(
- 'Compression',
+ "Compression",
local=self.local_compression,
- remote=self.remote_compression
+ remote=self.remote_compression,
)
# save for computing hash later...
@@ -2290,40 +2415,36 @@ class Transport(threading.Thread, ClosingContextManager):
def _activate_inbound(self):
"""switch on newly negotiated encryption parameters for
inbound traffic"""
- block_size = self._cipher_info[self.remote_cipher]['block-size']
+ block_size = self._cipher_info[self.remote_cipher]["block-size"]
if self.server_mode:
- IV_in = self._compute_key('A', block_size)
+ IV_in = self._compute_key("A", block_size)
key_in = self._compute_key(
- 'C', self._cipher_info[self.remote_cipher]['key-size']
+ "C", self._cipher_info[self.remote_cipher]["key-size"]
)
else:
- IV_in = self._compute_key('B', block_size)
+ IV_in = self._compute_key("B", block_size)
key_in = self._compute_key(
- 'D', self._cipher_info[self.remote_cipher]['key-size']
+ "D", self._cipher_info[self.remote_cipher]["key-size"]
)
engine = self._get_cipher(
self.remote_cipher, key_in, IV_in, self._DECRYPT
)
- mac_size = self._mac_info[self.remote_mac]['size']
- mac_engine = self._mac_info[self.remote_mac]['class']
+ mac_size = self._mac_info[self.remote_mac]["size"]
+ mac_engine = self._mac_info[self.remote_mac]["class"]
# initial mac keys are done in the hash's natural size (not the
# potentially truncated transmission size)
if self.server_mode:
- mac_key = self._compute_key('E', mac_engine().digest_size)
+ mac_key = self._compute_key("E", mac_engine().digest_size)
else:
- mac_key = self._compute_key('F', mac_engine().digest_size)
+ mac_key = self._compute_key("F", mac_engine().digest_size)
self.packetizer.set_inbound_cipher(
engine, block_size, mac_engine, mac_size, mac_key
)
compress_in = self._compression_info[self.remote_compression][1]
- if (
- compress_in is not None and
- (
- self.remote_compression != 'zlib@openssh.com' or
- self.authenticated
- )
+ if compress_in is not None and (
+ self.remote_compression != "zlib@openssh.com" or self.authenticated
):
- self._log(DEBUG, 'Switching on inbound compression ...')
+ self._log(DEBUG, "Switching on inbound compression ...")
self.packetizer.set_inbound_compressor(compress_in())
def _activate_outbound(self):
@@ -2332,37 +2453,37 @@ class Transport(threading.Thread, ClosingContextManager):
m = Message()
m.add_byte(cMSG_NEWKEYS)
self._send_message(m)
- block_size = self._cipher_info[self.local_cipher]['block-size']
+ block_size = self._cipher_info[self.local_cipher]["block-size"]
if self.server_mode:
- IV_out = self._compute_key('B', block_size)
+ IV_out = self._compute_key("B", block_size)
key_out = self._compute_key(
- 'D', self._cipher_info[self.local_cipher]['key-size'])
+ "D", self._cipher_info[self.local_cipher]["key-size"]
+ )
else:
- IV_out = self._compute_key('A', block_size)
+ IV_out = self._compute_key("A", block_size)
key_out = self._compute_key(
- 'C', self._cipher_info[self.local_cipher]['key-size'])
+ "C", self._cipher_info[self.local_cipher]["key-size"]
+ )
engine = self._get_cipher(
- self.local_cipher, key_out, IV_out, self._ENCRYPT)
- mac_size = self._mac_info[self.local_mac]['size']
- mac_engine = self._mac_info[self.local_mac]['class']
+ self.local_cipher, key_out, IV_out, self._ENCRYPT
+ )
+ mac_size = self._mac_info[self.local_mac]["size"]
+ mac_engine = self._mac_info[self.local_mac]["class"]
# initial mac keys are done in the hash's natural size (not the
# potentially truncated transmission size)
if self.server_mode:
- mac_key = self._compute_key('F', mac_engine().digest_size)
+ mac_key = self._compute_key("F", mac_engine().digest_size)
else:
- mac_key = self._compute_key('E', mac_engine().digest_size)
- sdctr = self.local_cipher.endswith('-ctr')
+ mac_key = self._compute_key("E", mac_engine().digest_size)
+ sdctr = self.local_cipher.endswith("-ctr")
self.packetizer.set_outbound_cipher(
- engine, block_size, mac_engine, mac_size, mac_key, sdctr)
+ engine, block_size, mac_engine, mac_size, mac_key, sdctr
+ )
compress_out = self._compression_info[self.local_compression][0]
- if (
- compress_out is not None and
- (
- self.local_compression != 'zlib@openssh.com' or
- self.authenticated
- )
+ if compress_out is not None and (
+ self.local_compression != "zlib@openssh.com" or self.authenticated
):
- self._log(DEBUG, 'Switching on outbound compression ...')
+ self._log(DEBUG, "Switching on outbound compression ...")
self.packetizer.set_outbound_compressor(compress_out())
if not self.packetizer.need_rekey():
self.in_kex = False
@@ -2372,17 +2493,17 @@ class Transport(threading.Thread, ClosingContextManager):
def _auth_trigger(self):
self.authenticated = True
# delayed initiation of compression
- if self.local_compression == 'zlib@openssh.com':
+ if self.local_compression == "zlib@openssh.com":
compress_out = self._compression_info[self.local_compression][0]
- self._log(DEBUG, 'Switching on outbound compression ...')
+ self._log(DEBUG, "Switching on outbound compression ...")
self.packetizer.set_outbound_compressor(compress_out())
- if self.remote_compression == 'zlib@openssh.com':
+ if self.remote_compression == "zlib@openssh.com":
compress_in = self._compression_info[self.remote_compression][1]
- self._log(DEBUG, 'Switching on inbound compression ...')
+ self._log(DEBUG, "Switching on inbound compression ...")
self.packetizer.set_inbound_compressor(compress_in())
def _parse_newkeys(self, m):
- self._log(DEBUG, 'Switch to new keys ...')
+ self._log(DEBUG, "Switch to new keys ...")
self._activate_inbound()
# can also free a bunch of stuff here
self.local_kex_init = self.remote_kex_init = None
@@ -2410,7 +2531,7 @@ class Transport(threading.Thread, ClosingContextManager):
def _parse_disconnect(self, m):
code = m.get_int()
desc = m.get_text()
- self._log(INFO, 'Disconnect (code {:d}): {}'.format(code, desc))
+ self._log(INFO, "Disconnect (code {:d}): {}".format(code, desc))
def _parse_global_request(self, m):
kind = m.get_text()
@@ -2419,16 +2540,16 @@ class Transport(threading.Thread, ClosingContextManager):
if not self.server_mode:
self._log(
DEBUG,
- 'Rejecting "{}" global request from server.'.format(kind)
+ 'Rejecting "{}" global request from server.'.format(kind),
)
ok = False
- elif kind == 'tcpip-forward':
+ elif kind == "tcpip-forward":
address = m.get_text()
port = m.get_int()
ok = self.server_object.check_port_forward_request(address, port)
if ok:
ok = (ok,)
- elif kind == 'cancel-tcpip-forward':
+ elif kind == "cancel-tcpip-forward":
address = m.get_text()
port = m.get_int()
self.server_object.cancel_port_forward_request(address, port)
@@ -2449,13 +2570,13 @@ class Transport(threading.Thread, ClosingContextManager):
self._send_message(msg)
def _parse_request_success(self, m):
- self._log(DEBUG, 'Global request successful.')
+ self._log(DEBUG, "Global request successful.")
self.global_response = m
if self.completion_event is not None:
self.completion_event.set()
def _parse_request_failure(self, m):
- self._log(DEBUG, 'Global request denied.')
+ self._log(DEBUG, "Global request denied.")
self.global_response = None
if self.completion_event is not None:
self.completion_event.set()
@@ -2467,13 +2588,14 @@ class Transport(threading.Thread, ClosingContextManager):
server_max_packet_size = m.get_int()
chan = self._channels.get(chanid)
if chan is None:
- self._log(WARNING, 'Success for unrequested channel! [??]')
+ self._log(WARNING, "Success for unrequested channel! [??]")
return
self.lock.acquire()
try:
chan._set_remote_channel(
- server_chanid, server_window_size, server_max_packet_size)
- self._log(DEBUG, 'Secsh channel {:d} opened.'.format(chanid))
+ server_chanid, server_window_size, server_max_packet_size
+ )
+ self._log(DEBUG, "Secsh channel {:d} opened.".format(chanid))
if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
@@ -2486,12 +2608,12 @@ class Transport(threading.Thread, ClosingContextManager):
reason = m.get_int()
reason_str = m.get_text()
m.get_text() # ignored language
- reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
+ reason_text = CONNECTION_FAILED_CODE.get(reason, "(unknown code)")
self._log(
ERROR,
- 'Secsh channel {:d} open FAILED: {}: {}'.format(
- chanid, reason_str, reason_text,
- )
+ "Secsh channel {:d} open FAILED: {}: {}".format(
+ chanid, reason_str, reason_text
+ ),
)
self.lock.acquire()
try:
@@ -2512,39 +2634,39 @@ class Transport(threading.Thread, ClosingContextManager):
max_packet_size = m.get_int()
reject = False
if (
- kind == 'auth-agent@openssh.com' and
- self._forward_agent_handler is not None
+ kind == "auth-agent@openssh.com"
+ and self._forward_agent_handler is not None
):
- self._log(DEBUG, 'Incoming forward agent connection')
+ self._log(DEBUG, "Incoming forward agent connection")
self.lock.acquire()
try:
my_chanid = self._next_channel()
finally:
self.lock.release()
- elif (kind == 'x11') and (self._x11_handler is not None):
+ elif (kind == "x11") and (self._x11_handler is not None):
origin_addr = m.get_text()
origin_port = m.get_int()
self._log(
DEBUG,
- 'Incoming x11 connection from {}:{:d}'.format(
- origin_addr, origin_port,
- )
+ "Incoming x11 connection from {}:{:d}".format(
+ origin_addr, origin_port
+ ),
)
self.lock.acquire()
try:
my_chanid = self._next_channel()
finally:
self.lock.release()
- elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
+ elif (kind == "forwarded-tcpip") and (self._tcp_handler is not None):
server_addr = m.get_text()
server_port = m.get_int()
origin_addr = m.get_text()
origin_port = m.get_int()
self._log(
DEBUG,
- 'Incoming tcp forwarded connection from {}:{:d}'.format(
- origin_addr, origin_port,
- )
+ "Incoming tcp forwarded connection from {}:{:d}".format(
+ origin_addr, origin_port
+ ),
)
self.lock.acquire()
try:
@@ -2554,7 +2676,8 @@ class Transport(threading.Thread, ClosingContextManager):
elif not self.server_mode:
self._log(
DEBUG,
- 'Rejecting "{}" channel request from server.'.format(kind))
+ 'Rejecting "{}" channel request from server.'.format(kind),
+ )
reject = True
reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
@@ -2563,7 +2686,7 @@ class Transport(threading.Thread, ClosingContextManager):
my_chanid = self._next_channel()
finally:
self.lock.release()
- if kind == 'direct-tcpip':
+ if kind == "direct-tcpip":
# handle direct-tcpip requests coming from the client
dest_addr = m.get_text()
dest_port = m.get_int()
@@ -2572,23 +2695,25 @@ class Transport(threading.Thread, ClosingContextManager):
reason = self.server_object.check_channel_direct_tcpip_request(
my_chanid,
(origin_addr, origin_port),
- (dest_addr, dest_port)
+ (dest_addr, dest_port),
)
else:
reason = self.server_object.check_channel_request(
- kind, my_chanid)
+ kind, my_chanid
+ )
if reason != OPEN_SUCCEEDED:
self._log(
DEBUG,
- 'Rejecting "{}" channel request from client.'.format(kind))
+ 'Rejecting "{}" channel request from client.'.format(kind),
+ )
reject = True
if reject:
msg = Message()
msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
msg.add_int(chanid)
msg.add_int(reason)
- msg.add_string('')
- msg.add_string('en')
+ msg.add_string("")
+ msg.add_string("en")
self._send_message(msg)
return
@@ -2599,9 +2724,11 @@ class Transport(threading.Thread, ClosingContextManager):
self.channels_seen[my_chanid] = True
chan._set_transport(self)
chan._set_window(
- self.default_window_size, self.default_max_packet_size)
+ self.default_window_size, self.default_max_packet_size
+ )
chan._set_remote_channel(
- chanid, initial_window_size, max_packet_size)
+ chanid, initial_window_size, max_packet_size
+ )
finally:
self.lock.release()
m = Message()
@@ -2611,19 +2738,17 @@ class Transport(threading.Thread, ClosingContextManager):
m.add_int(self.default_window_size)
m.add_int(self.default_max_packet_size)
self._send_message(m)
- self._log(DEBUG,
- 'Secsh channel {:d} ({}) opened.'.format(my_chanid, kind)
+ self._log(
+ DEBUG, "Secsh channel {:d} ({}) opened.".format(my_chanid, kind)
)
- if kind == 'auth-agent@openssh.com':
+ if kind == "auth-agent@openssh.com":
self._forward_agent_handler(chan)
- elif kind == 'x11':
+ elif kind == "x11":
self._x11_handler(chan, (origin_addr, origin_port))
- elif kind == 'forwarded-tcpip':
+ elif kind == "forwarded-tcpip":
chan.origin_addr = (origin_addr, origin_port)
self._tcp_handler(
- chan,
- (origin_addr, origin_port),
- (server_addr, server_port)
+ chan, (origin_addr, origin_port), (server_addr, server_port)
)
else:
self._queue_incoming_channel(chan)
@@ -2632,7 +2757,7 @@ class Transport(threading.Thread, ClosingContextManager):
m.get_boolean() # always_display
msg = m.get_string()
m.get_string() # language
- self._log(DEBUG, 'Debug msg: {}'.format(util.safe_string(msg)))
+ self._log(DEBUG, "Debug msg: {}".format(util.safe_string(msg)))
def _get_subsystem_handler(self, name):
try:
@@ -2666,7 +2791,7 @@ class Transport(threading.Thread, ClosingContextManager):
}
-class SecurityOptions (object):
+class SecurityOptions(object):
"""
Simple object containing the security preferences of an ssh transport.
These are tuples of acceptable ciphers, digests, key types, and key
@@ -2678,7 +2803,7 @@ class SecurityOptions (object):
``ValueError`` will be raised. If you try to assign something besides a
tuple to one of the fields, ``TypeError`` will be raised.
"""
- __slots__ = '_transport'
+ __slots__ = "_transport"
def __init__(self, transport):
self._transport = transport
@@ -2687,17 +2812,17 @@ class SecurityOptions (object):
"""
Returns a string representation of this object, for debugging.
"""
- return '<paramiko.SecurityOptions for {!r}>'.format(self._transport)
+ return "<paramiko.SecurityOptions for {!r}>".format(self._transport)
def _set(self, name, orig, x):
if type(x) is list:
x = tuple(x)
if type(x) is not tuple:
- raise TypeError('expected tuple or list')
+ raise TypeError("expected tuple or list")
possible = list(getattr(self._transport, orig).keys())
forbidden = [n for n in x if n not in possible]
if len(forbidden) > 0:
- raise ValueError('unknown cipher')
+ raise ValueError("unknown cipher")
setattr(self._transport, name, x)
@property
@@ -2707,7 +2832,7 @@ class SecurityOptions (object):
@ciphers.setter
def ciphers(self, x):
- self._set('_preferred_ciphers', '_cipher_info', x)
+ self._set("_preferred_ciphers", "_cipher_info", x)
@property
def digests(self):
@@ -2716,7 +2841,7 @@ class SecurityOptions (object):
@digests.setter
def digests(self, x):
- self._set('_preferred_macs', '_mac_info', x)
+ self._set("_preferred_macs", "_mac_info", x)
@property
def key_types(self):
@@ -2725,8 +2850,7 @@ class SecurityOptions (object):
@key_types.setter
def key_types(self, x):
- self._set('_preferred_keys', '_key_info', x)
-
+ self._set("_preferred_keys", "_key_info", x)
@property
def kex(self):
@@ -2735,7 +2859,7 @@ class SecurityOptions (object):
@kex.setter
def kex(self, x):
- self._set('_preferred_kex', '_kex_info', x)
+ self._set("_preferred_kex", "_kex_info", x)
@property
def compression(self):
@@ -2744,10 +2868,11 @@ class SecurityOptions (object):
@compression.setter
def compression(self, x):
- self._set('_preferred_compression', '_compression_info', x)
+ self._set("_preferred_compression", "_compression_info", x)
+
+class ChannelMap(object):
-class ChannelMap (object):
def __init__(self):
# (id -> Channel)
self._map = weakref.WeakValueDictionary()
diff --git a/paramiko/util.py b/paramiko/util.py
index 2854ef98..399141ad 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -49,9 +49,9 @@ def inflate_long(s, always_positive=False):
# noinspection PyAugmentAssignment
s = filler * (4 - len(s) % 4) + s
for i in range(0, len(s), 4):
- out = (out << 32) + struct.unpack('>I', s[i:i + 4])[0]
+ out = (out << 32) + struct.unpack(">I", s[i : i + 4])[0]
if negative:
- out -= (long(1) << (8 * len(s)))
+ out -= long(1) << (8 * len(s))
return out
@@ -66,7 +66,7 @@ def deflate_long(n, add_sign_padding=True):
s = bytes()
n = long(n)
while (n != 0) and (n != -1):
- s = struct.pack('>I', n & xffffffff) + s
+ s = struct.pack(">I", n & xffffffff) + s
n >>= 32
# strip off leading zeros, FFs
for i in enumerate(s):
@@ -81,7 +81,7 @@ def deflate_long(n, add_sign_padding=True):
s = zero_byte
else:
s = max_byte
- s = s[i[0]:]
+ s = s[i[0] :]
if add_sign_padding:
if (n == 0) and (byte_ord(s[0]) >= 0x80):
s = zero_byte + s
@@ -90,11 +90,11 @@ def deflate_long(n, add_sign_padding=True):
return s
-def format_binary(data, prefix=''):
+def format_binary(data, prefix=""):
x = 0
out = []
while len(data) > x + 16:
- out.append(format_binary_line(data[x:x + 16]))
+ out.append(format_binary_line(data[x : x + 16]))
x += 16
if x < len(data):
out.append(format_binary_line(data[x:]))
@@ -102,22 +102,21 @@ def format_binary(data, prefix=''):
def format_binary_line(data):
- left = ' '.join(['{:02X}'.format(byte_ord(c)) for c in data])
- right = ''.join([
- '.{:c}..'.format(byte_ord(c))[(byte_ord(c) + 63) // 95]
- for c in data
- ])
- return '{:50s} {}'.format(left, right)
+ left = " ".join(["{:02X}".format(byte_ord(c)) for c in data])
+ right = "".join(
+ [".{:c}..".format(byte_ord(c))[(byte_ord(c) + 63) // 95] for c in data]
+ )
+ return "{:50s} {}".format(left, right)
def safe_string(s):
- out = b''
+ out = b""
for c in s:
i = byte_ord(c)
if 32 <= i <= 127:
out += byte_chr(i)
else:
- out += b('%{:02X}'.format(i))
+ out += b("%{:02X}".format(i))
return out
@@ -137,7 +136,7 @@ def bit_length(n):
def tb_strings():
- return ''.join(traceback.format_exception(*sys.exc_info())).split('\n')
+ return "".join(traceback.format_exception(*sys.exc_info())).split("\n")
def generate_key_bytes(hash_alg, salt, key, nbytes):
@@ -188,6 +187,7 @@ def load_host_keys(filename):
nested dict of `.PKey` objects, indexed by hostname and then keytype
"""
from paramiko.hostkeys import HostKeys
+
return HostKeys(filename)
@@ -249,15 +249,17 @@ def log_to_file(filename, level=DEBUG):
if len(l.handlers) > 0:
return
l.setLevel(level)
- f = open(filename, 'a')
+ f = open(filename, "a")
lh = logging.StreamHandler(f)
- frm = '%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s' # noqa
- lh.setFormatter(logging.Formatter(frm, '%Y%m%d-%H:%M:%S'))
+ frm = "%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d"
+ frm += " %(name)s: %(message)s"
+ lh.setFormatter(logging.Formatter(frm, "%Y%m%d-%H:%M:%S"))
l.addHandler(lh)
# make only one filter object, so it doesn't get applied more than once
-class PFilter (object):
+class PFilter(object):
+
def filter(self, record):
record._threadid = get_thread_id()
return True
@@ -293,6 +295,7 @@ def constant_time_bytes_eq(a, b):
class ClosingContextManager(object):
+
def __enter__(self):
return self
diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py
index 661ba575..2bba789d 100644
--- a/paramiko/win_pageant.py
+++ b/paramiko/win_pageant.py
@@ -44,7 +44,7 @@ win32con_WM_COPYDATA = 74
def _get_pageant_window_object():
- return ctypes.windll.user32.FindWindowA(b'Pageant', b'Pageant')
+ return ctypes.windll.user32.FindWindowA(b"Pageant", b"Pageant")
def can_talk_to_agent():
@@ -57,7 +57,7 @@ def can_talk_to_agent():
return bool(_get_pageant_window_object())
-if platform.architecture()[0] == '64bit':
+if platform.architecture()[0] == "64bit":
ULONG_PTR = ctypes.c_uint64
else:
ULONG_PTR = ctypes.c_uint32
@@ -69,9 +69,9 @@ class COPYDATASTRUCT(ctypes.Structure):
http://msdn.microsoft.com/en-us/library/windows/desktop/ms649010%28v=vs.85%29.aspx
"""
_fields_ = [
- ('num_data', ULONG_PTR),
- ('data_size', ctypes.wintypes.DWORD),
- ('data_loc', ctypes.c_void_p),
+ ("num_data", ULONG_PTR),
+ ("data_size", ctypes.wintypes.DWORD),
+ ("data_loc", ctypes.c_void_p),
]
@@ -86,27 +86,29 @@ def _query_pageant(msg):
return None
# create a name for the mmap
- map_name = 'PageantRequest%08x' % thread.get_ident()
+ map_name = "PageantRequest%08x" % thread.get_ident()
- pymap = _winapi.MemoryMap(map_name, _AGENT_MAX_MSGLEN,
- _winapi.get_security_attributes_for_user(),
- )
+ pymap = _winapi.MemoryMap(
+ map_name, _AGENT_MAX_MSGLEN, _winapi.get_security_attributes_for_user()
+ )
with pymap:
pymap.write(msg)
# Create an array buffer containing the mapped filename
char_buffer = array.array("b", b(map_name) + zero_byte) # noqa
char_buffer_address, char_buffer_size = char_buffer.buffer_info()
# Create a string to use for the SendMessage function call
- cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size,
- char_buffer_address)
+ cds = COPYDATASTRUCT(
+ _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address
+ )
- response = ctypes.windll.user32.SendMessageA(hwnd,
- win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds))
+ response = ctypes.windll.user32.SendMessageA(
+ hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds)
+ )
if response > 0:
pymap.seek(0)
datalen = pymap.read(4)
- retlen = struct.unpack('>I', datalen)[0]
+ retlen = struct.unpack(">I", datalen)[0]
return datalen + pymap.read(retlen)
return None
@@ -127,10 +129,10 @@ class PageantConnection(object):
def recv(self, n):
if self._response is None:
- return ''
+ return ""
ret = self._response[:n]
self._response = self._response[n:]
- if self._response == '':
+ if self._response == "":
self._response = None
return ret
diff --git a/setup.cfg b/setup.cfg
index a24844d0..bf86db42 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -9,7 +9,9 @@ omit = paramiko/_winapi.py
[flake8]
exclude = sites,.git,build,dist,demos,tests
-ignore = E124,E125,E128,E261,E301,E302,E303,E402,E721
+# NOTE: W503, E203 are concessions to black 18.0b5 and could be reinstated
+# later if fixed on that end.
+ignore = E124,E125,E128,E261,E301,E302,E303,E402,E721,W503,E203
max-line-length = 79
[tool:pytest]
diff --git a/setup.py b/setup.py
index 01150831..d26d1c3e 100644
--- a/setup.py
+++ b/setup.py
@@ -19,12 +19,12 @@
import sys
from setuptools import setup
-if sys.platform == 'darwin':
+if sys.platform == "darwin":
import setup_helper
setup_helper.install_custom_make_tarball()
-longdesc = '''
+longdesc = """
This is a library for making SSH2 connections (client or server).
Emphasis is on using SSH2 as an alternative to SSL for making secure
connections between python scripts. All major ciphers and hash methods
@@ -35,14 +35,14 @@ Required packages:
To install the development version, ``pip install -e
git+https://github.com/paramiko/paramiko/#egg=paramiko``.
-'''
+"""
# Version info -- read without importing
_locals = {}
-with open('paramiko/_version.py') as fp:
+with open("paramiko/_version.py") as fp:
exec(fp.read(), None, _locals)
-version = _locals['__version__']
+version = _locals["__version__"]
setup(
name="paramiko",
@@ -52,28 +52,28 @@ setup(
author="Jeff Forcier",
author_email="jeff@bitprophet.org",
url="https://github.com/paramiko/paramiko/",
- packages=['paramiko'],
- license='LGPL',
- platforms='Posix; MacOS X; Windows',
+ packages=["paramiko"],
+ license="LGPL",
+ platforms="Posix; MacOS X; Windows",
classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: '
- 'GNU Library or Lesser General Public License (LGPL)',
- 'Operating System :: OS Independent',
- 'Topic :: Internet',
- 'Topic :: Security :: Cryptography',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: "
+ "GNU Library or Lesser General Public License (LGPL)",
+ "Operating System :: OS Independent",
+ "Topic :: Internet",
+ "Topic :: Security :: Cryptography",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.4",
+ "Programming Language :: Python :: 3.5",
+ "Programming Language :: Python :: 3.6",
],
install_requires=[
- 'bcrypt>=3.1.3',
- 'cryptography>=1.5',
- 'pynacl>=1.0.1',
+ "bcrypt>=3.1.3",
+ "cryptography>=1.5",
+ "pynacl>=1.0.1",
],
)
diff --git a/setup_helper.py b/setup_helper.py
index c359a16c..d0a8700e 100644
--- a/setup_helper.py
+++ b/setup_helper.py
@@ -40,6 +40,7 @@ try:
except ImportError:
getgrnam = None
+
def _get_gid(name):
"""Returns a gid, given a group name."""
if getgrnam is None or name is None:
@@ -52,6 +53,7 @@ def _get_gid(name):
return result[2]
return None
+
def _get_uid(name):
"""Returns an uid, given a user name."""
if getpwnam is None or name is None:
@@ -64,8 +66,16 @@ def _get_uid(name):
return result[2]
return None
-def make_tarball(base_name, base_dir, compress='gzip', verbose=0, dry_run=0,
- owner=None, group=None):
+
+def make_tarball(
+ base_name,
+ base_dir,
+ compress="gzip",
+ verbose=0,
+ dry_run=0,
+ owner=None,
+ group=None,
+):
"""Create a tar file from all the files under 'base_dir'.
This file may be compressed.
@@ -87,28 +97,26 @@ def make_tarball(base_name, base_dir, compress='gzip', verbose=0, dry_run=0,
# "create a tree of hardlinks" step! (Would also be nice to
# detect GNU tar to use its 'z' option and save a step.)
- compress_ext = {
- 'gzip': ".gz",
- 'bzip2': '.bz2',
- 'compress': ".Z",
- }
+ compress_ext = {"gzip": ".gz", "bzip2": ".bz2", "compress": ".Z"}
# flags for compression program, each element of list will be an argument
- tarfile_compress_flag = {'gzip': 'gz', 'bzip2': 'bz2'}
- compress_flags = {'compress': ["-f"]}
+ tarfile_compress_flag = {"gzip": "gz", "bzip2": "bz2"}
+ compress_flags = {"compress": ["-f"]}
if compress is not None and compress not in compress_ext.keys():
- raise ValueError("bad value for 'compress': must be None, 'gzip',"
- "'bzip2' or 'compress'")
+ raise ValueError(
+ "bad value for 'compress': must be None, 'gzip',"
+ "'bzip2' or 'compress'"
+ )
archive_name = base_name + ".tar"
if compress and compress in tarfile_compress_flag:
archive_name += compress_ext[compress]
- mode = 'w:' + tarfile_compress_flag.get(compress, '')
+ mode = "w:" + tarfile_compress_flag.get(compress, "")
mkpath(os.path.dirname(archive_name), dry_run=dry_run)
- log.info('Creating tar file %s with mode %s' % (archive_name, mode))
+ log.info("Creating tar file %s with mode %s" % (archive_name, mode))
uid = _get_uid(owner)
gid = _get_gid(group)
@@ -136,18 +144,20 @@ def make_tarball(base_name, base_dir, compress='gzip', verbose=0, dry_run=0,
tar.close()
if compress and compress not in tarfile_compress_flag:
- spawn([compress] + compress_flags[compress] + [archive_name],
- dry_run=dry_run)
+ spawn(
+ [compress] + compress_flags[compress] + [archive_name],
+ dry_run=dry_run,
+ )
return archive_name + compress_ext[compress]
else:
return archive_name
_custom_formats = {
- 'gztar': (make_tarball, [('compress', 'gzip')], "gzip'ed tar-file"),
- 'bztar': (make_tarball, [('compress', 'bzip2')], "bzip2'ed tar-file"),
- 'ztar': (make_tarball, [('compress', 'compress')], "compressed tar file"),
- 'tar': (make_tarball, [('compress', None)], "uncompressed tar file"),
+ "gztar": (make_tarball, [("compress", "gzip")], "gzip'ed tar-file"),
+ "bztar": (make_tarball, [("compress", "bzip2")], "bzip2'ed tar-file"),
+ "ztar": (make_tarball, [("compress", "compress")], "compressed tar file"),
+ "tar": (make_tarball, [("compress", None)], "uncompressed tar file"),
}
# Hack in and insert ourselves into the distutils code base
diff --git a/sites/docs/conf.py b/sites/docs/conf.py
index 5674fed1..eb895804 100644
--- a/sites/docs/conf.py
+++ b/sites/docs/conf.py
@@ -1,16 +1,17 @@
# Obtain shared config values
import os, sys
-sys.path.append(os.path.abspath('..'))
-sys.path.append(os.path.abspath('../..'))
+
+sys.path.append(os.path.abspath(".."))
+sys.path.append(os.path.abspath("../.."))
from shared_conf import *
# Enable autodoc, intersphinx
-extensions.extend(['sphinx.ext.autodoc'])
+extensions.extend(["sphinx.ext.autodoc"])
# Autodoc settings
-autodoc_default_flags = ['members', 'special-members']
+autodoc_default_flags = ["members", "special-members"]
# Sister-site links to WWW
-html_theme_options['extra_nav_links'] = {
- "Main website": 'http://www.paramiko.org',
+html_theme_options["extra_nav_links"] = {
+ "Main website": "http://www.paramiko.org"
}
diff --git a/sites/shared_conf.py b/sites/shared_conf.py
index cf0d77ff..f4806cf1 100644
--- a/sites/shared_conf.py
+++ b/sites/shared_conf.py
@@ -5,36 +5,29 @@ import alabaster
# Alabaster theme + mini-extension
html_theme_path = [alabaster.get_path()]
-extensions = ['alabaster', 'sphinx.ext.intersphinx']
+extensions = ["alabaster", "sphinx.ext.intersphinx"]
# Paths relative to invoking conf.py - not this shared file
-html_theme = 'alabaster'
+html_theme = "alabaster"
html_theme_options = {
- 'description': "A Python implementation of SSHv2.",
- 'github_user': 'paramiko',
- 'github_repo': 'paramiko',
- 'analytics_id': 'UA-18486793-2',
- 'travis_button': True,
+ "description": "A Python implementation of SSHv2.",
+ "github_user": "paramiko",
+ "github_repo": "paramiko",
+ "analytics_id": "UA-18486793-2",
+ "travis_button": True,
}
html_sidebars = {
- '**': [
- 'about.html',
- 'navigation.html',
- 'searchbox.html',
- 'donate.html',
- ]
+ "**": ["about.html", "navigation.html", "searchbox.html", "donate.html"]
}
# Everything intersphinx's to Python
-intersphinx_mapping = {
- 'python': ('https://docs.python.org/2.7/', None),
-}
+intersphinx_mapping = {"python": ("https://docs.python.org/2.7/", None)}
# Regular settings
-project = 'Paramiko'
+project = "Paramiko"
year = datetime.now().year
-copyright = '{} Jeff Forcier'.format(year)
-master_doc = 'index'
-templates_path = ['_templates']
-exclude_trees = ['_build']
-source_suffix = '.rst'
-default_role = 'obj'
+copyright = "{} Jeff Forcier".format(year)
+master_doc = "index"
+templates_path = ["_templates"]
+exclude_trees = ["_build"]
+source_suffix = ".rst"
+default_role = "obj"
diff --git a/sites/www/conf.py b/sites/www/conf.py
index c7ba0a86..00944871 100644
--- a/sites/www/conf.py
+++ b/sites/www/conf.py
@@ -3,22 +3,22 @@ import sys
import os
from os.path import abspath, join, dirname
-sys.path.append(abspath(join(dirname(__file__), '..')))
+sys.path.append(abspath(join(dirname(__file__), "..")))
from shared_conf import *
# Releases changelog extension
-extensions.append('releases')
+extensions.append("releases")
releases_release_uri = "https://github.com/paramiko/paramiko/tree/%s"
releases_issue_uri = "https://github.com/paramiko/paramiko/issues/%s"
# Default is 'local' building, but reference the public docs site when building
# under RTD.
-target = join(dirname(__file__), '..', 'docs', '_build')
-if os.environ.get('READTHEDOCS') == 'True':
- target = 'http://docs.paramiko.org/en/latest/'
-intersphinx_mapping['docs'] = (target, None)
+target = join(dirname(__file__), "..", "docs", "_build")
+if os.environ.get("READTHEDOCS") == "True":
+ target = "http://docs.paramiko.org/en/latest/"
+intersphinx_mapping["docs"] = (target, None)
# Sister-site links to API docs
-html_theme_options['extra_nav_links'] = {
- "API Docs": 'http://docs.paramiko.org',
+html_theme_options["extra_nav_links"] = {
+ "API Docs": "http://docs.paramiko.org"
}
diff --git a/tasks.py b/tasks.py
index 6cf20377..fb15728b 100644
--- a/tasks.py
+++ b/tasks.py
@@ -3,6 +3,8 @@ from os.path import join
from shutil import rmtree, copytree
from invoke import Collection, task
+from invocations import travis
+from invocations.checks import blacken
from invocations.docs import docs, www, sites
from invocations.packaging.release import ns as release_coll, publish
from invocations.testing import count_errors
@@ -106,7 +108,7 @@ def release(ctx, sdist=True, wheel=True, sign=True, dry_run=False, index=None):
copytree("sites/docs/_build", target)
# Publish
publish(
- ctx, sdist=sdist, wheel=wheel, sign=sign, dry_run=dry_run, index=index,
+ ctx, sdist=sdist, wheel=wheel, sign=sign, dry_run=dry_run, index=index
)
# Remind
print(
@@ -121,7 +123,16 @@ def release(ctx, sdist=True, wheel=True, sign=True, dry_run=False, index=None):
release_coll.tasks["publish"] = release
ns = Collection(
- test, coverage, guard, release_coll, docs, www, sites, count_errors
+ test,
+ coverage,
+ guard,
+ release_coll,
+ docs,
+ www,
+ sites,
+ count_errors,
+ travis,
+ blacken,
)
ns.configure(
{
diff --git a/tests/conftest.py b/tests/conftest.py
index d1967a73..2b509c5c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,7 +19,7 @@ from .util import _support
# presenting it on error/failure. (But also allow turning it off when doing
# very pinpoint debugging - e.g. using breakpoints, so you don't want output
# hiding enabled, but also don't want all the logging to gum up the terminal.)
-if not os.environ.get('DISABLE_LOGGING', False):
+if not os.environ.get("DISABLE_LOGGING", False):
logging.basicConfig(
level=logging.DEBUG,
# Also make sure to set up timestamping for more sanity when debugging.
@@ -43,7 +43,7 @@ def make_sftp_folder():
# TODO: if we want to lock ourselves even harder into localhost-only
# testing (probably not?) could use tempdir modules for this for improved
# safety. Then again...why would someone have such a folder???
- path = os.environ.get('TEST_FOLDER', 'paramiko-test-target')
+ path = os.environ.get("TEST_FOLDER", "paramiko-test-target")
# Forcibly nuke this directory locally, since at the moment, the below
# fixtures only ever run with a locally scoped stub test server.
shutil.rmtree(path, ignore_errors=True)
@@ -52,7 +52,7 @@ def make_sftp_folder():
return path
-@pytest.fixture#(scope='session')
+@pytest.fixture # (scope='session')
def sftp_server():
"""
Set up an in-memory SFTP server thread. Yields the client Transport/socket.
@@ -69,17 +69,17 @@ def sftp_server():
tc = Transport(sockc)
ts = Transport(socks)
# Auth
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
ts.add_server_key(host_key)
# Server setup
event = threading.Event()
server = StubServer()
- ts.set_subsystem_handler('sftp', SFTPServer, StubSFTPServer)
+ ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer)
ts.start_server(event, server)
# Wait (so client has time to connect? Not sure. Old.)
event.wait(1.0)
# Make & yield connection.
- tc.connect(username='slowdive', password='pygmalion')
+ tc.connect(username="slowdive", password="pygmalion")
yield tc
# TODO: any need for shutdown? Why didn't old suite do so? Or was that the
# point of the "join all threads from threading module" crap in test.py?
diff --git a/tests/loop.py b/tests/loop.py
index 6c432867..dd1f5a0c 100644
--- a/tests/loop.py
+++ b/tests/loop.py
@@ -22,13 +22,13 @@ import threading
from paramiko.common import asbytes
-class LoopSocket (object):
+class LoopSocket(object):
"""
A LoopSocket looks like a normal socket, but all data written to it is
delivered on the read-end of another LoopSocket, and vice versa. It's
like a software "socketpair".
"""
-
+
def __init__(self):
self.__in_buffer = bytes()
self.__lock = threading.Lock()
@@ -84,7 +84,7 @@ class LoopSocket (object):
self.__cv.notifyAll()
finally:
self.__lock.release()
-
+
def __unlink(self):
m = None
self.__lock.acquire()
@@ -96,5 +96,3 @@ class LoopSocket (object):
self.__lock.release()
if m is not None:
m.__unlink()
-
-
diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py
index 19545865..ffae635d 100644
--- a/tests/stub_sftp.py
+++ b/tests/stub_sftp.py
@@ -24,13 +24,21 @@ import os
import sys
from paramiko import (
- ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes,
- SFTPHandle, SFTP_OK, SFTP_FAILURE, AUTH_SUCCESSFUL, OPEN_SUCCEEDED,
+ ServerInterface,
+ SFTPServerInterface,
+ SFTPServer,
+ SFTPAttributes,
+ SFTPHandle,
+ SFTP_OK,
+ SFTP_FAILURE,
+ AUTH_SUCCESSFUL,
+ OPEN_SUCCEEDED,
)
from paramiko.common import o666
-class StubServer (ServerInterface):
+class StubServer(ServerInterface):
+
def check_auth_password(self, username, password):
# all are allowed
return AUTH_SUCCESSFUL
@@ -39,7 +47,8 @@ class StubServer (ServerInterface):
return OPEN_SUCCEEDED
-class StubSFTPHandle (SFTPHandle):
+class StubSFTPHandle(SFTPHandle):
+
def stat(self):
try:
return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno()))
@@ -56,11 +65,11 @@ class StubSFTPHandle (SFTPHandle):
return SFTPServer.convert_errno(e.errno)
-class StubSFTPServer (SFTPServerInterface):
+class StubSFTPServer(SFTPServerInterface):
# assume current folder is a fine root
# (the tests always create and eventually delete a subfolder, so there shouldn't be any mess)
ROOT = os.getcwd()
-
+
def _realpath(self, path):
return self.ROOT + self.canonicalize(path)
@@ -70,7 +79,9 @@ class StubSFTPServer (SFTPServerInterface):
out = []
flist = os.listdir(path)
for fname in flist:
- attr = SFTPAttributes.from_stat(os.stat(os.path.join(path, fname)))
+ attr = SFTPAttributes.from_stat(
+ os.stat(os.path.join(path, fname))
+ )
attr.filename = fname
out.append(attr)
return out
@@ -94,9 +105,9 @@ class StubSFTPServer (SFTPServerInterface):
def open(self, path, flags, attr):
path = self._realpath(path)
try:
- binary_flag = getattr(os, 'O_BINARY', 0)
+ binary_flag = getattr(os, "O_BINARY", 0)
flags |= binary_flag
- mode = getattr(attr, 'st_mode', None)
+ mode = getattr(attr, "st_mode", None)
if mode is not None:
fd = os.open(path, flags, mode)
else:
@@ -110,17 +121,17 @@ class StubSFTPServer (SFTPServerInterface):
SFTPServer.set_file_attr(path, attr)
if flags & os.O_WRONLY:
if flags & os.O_APPEND:
- fstr = 'ab'
+ fstr = "ab"
else:
- fstr = 'wb'
+ fstr = "wb"
elif flags & os.O_RDWR:
if flags & os.O_APPEND:
- fstr = 'a+b'
+ fstr = "a+b"
else:
- fstr = 'r+b'
+ fstr = "r+b"
else:
# O_RDONLY (== 0)
- fstr = 'rb'
+ fstr = "rb"
try:
f = os.fdopen(fd, fstr)
except OSError as e:
@@ -159,7 +170,6 @@ class StubSFTPServer (SFTPServerInterface):
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
-
def mkdir(self, path, attr):
path = self._realpath(path)
try:
@@ -188,18 +198,18 @@ class StubSFTPServer (SFTPServerInterface):
def symlink(self, target_path, path):
path = self._realpath(path)
- if (len(target_path) > 0) and (target_path[0] == '/'):
+ if (len(target_path) > 0) and (target_path[0] == "/"):
# absolute symlink
target_path = os.path.join(self.ROOT, target_path[1:])
- if target_path[:2] == '//':
+ if target_path[:2] == "//":
# bug in os.path.join
target_path = target_path[1:]
else:
# compute relative to path
abspath = os.path.join(os.path.dirname(path), target_path)
- if abspath[:len(self.ROOT)] != self.ROOT:
+ if abspath[: len(self.ROOT)] != self.ROOT:
# this symlink isn't going to work anyway -- just break it immediately
- target_path = '<error>'
+ target_path = "<error>"
try:
os.symlink(target_path, path)
except OSError as e:
@@ -214,10 +224,10 @@ class StubSFTPServer (SFTPServerInterface):
return SFTPServer.convert_errno(e.errno)
# if it's absolute, remove the root
if os.path.isabs(symlink):
- if symlink[:len(self.ROOT)] == self.ROOT:
- symlink = symlink[len(self.ROOT):]
- if (len(symlink) == 0) or (symlink[0] != '/'):
- symlink = '/' + symlink
+ if symlink[: len(self.ROOT)] == self.ROOT:
+ symlink = symlink[len(self.ROOT) :]
+ if (len(symlink) == 0) or (symlink[0] != "/"):
+ symlink = "/" + symlink
else:
- symlink = '<error>'
+ symlink = "<error>"
return symlink
diff --git a/tests/test_auth.py b/tests/test_auth.py
index dacdd654..acabb1bd 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -26,8 +26,13 @@ import unittest
from time import sleep
from paramiko import (
- Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType,
- InteractiveQuery, AuthenticationException,
+ Transport,
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ BadAuthenticationType,
+ InteractiveQuery,
+ AuthenticationException,
)
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko.py3compat import u
@@ -36,54 +41,57 @@ from .loop import LoopSocket
from .util import _support, slow
-_pwd = u('\u2022')
+_pwd = u("\u2022")
-class NullServer (ServerInterface):
+class NullServer(ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key'))
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- if username == 'paranoid':
- if not self.paranoid_did_password and not self.paranoid_did_public_key:
- return 'publickey,password'
+ if username == "slowdive":
+ return "publickey,password"
+ if username == "paranoid":
+ if (
+ not self.paranoid_did_password
+ and not self.paranoid_did_public_key
+ ):
+ return "publickey,password"
elif self.paranoid_did_password:
- return 'publickey'
+ return "publickey"
else:
- return 'password'
- if username == 'commie':
- return 'keyboard-interactive'
- if username == 'utf8':
- return 'password'
- if username == 'non-utf8':
- return 'password'
- return 'publickey'
+ return "password"
+ if username == "commie":
+ return "keyboard-interactive"
+ if username == "utf8":
+ return "password"
+ if username == "non-utf8":
+ return "password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return AUTH_SUCCESSFUL
- if (username == 'paranoid') and (password == 'paranoid'):
+ if (username == "paranoid") and (password == "paranoid"):
# 2-part auth (even openssh doesn't support this)
self.paranoid_did_password = True
if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
- if (username == 'utf8') and (password == _pwd):
+ if (username == "utf8") and (password == _pwd):
return AUTH_SUCCESSFUL
- if (username == 'non-utf8') and (password == '\xff'):
+ if (username == "non-utf8") and (password == "\xff"):
return AUTH_SUCCESSFUL
- if username == 'bad-server':
+ if username == "bad-server":
raise Exception("Ack!")
- if username == 'unresponsive-server':
+ if username == "unresponsive-server":
sleep(5)
return AUTH_SUCCESSFUL
return AUTH_FAILED
def check_auth_publickey(self, username, key):
- if (username == 'paranoid') and (key == self.paranoid_key):
+ if (username == "paranoid") and (key == self.paranoid_key):
# 2-part auth
self.paranoid_did_public_key = True
if self.paranoid_did_password:
@@ -92,19 +100,21 @@ class NullServer (ServerInterface):
return AUTH_FAILED
def check_auth_interactive(self, username, submethods):
- if username == 'commie':
+ if username == "commie":
self.username = username
- return InteractiveQuery('password', 'Please enter a password.', ('Password', False))
+ return InteractiveQuery(
+ "password", "Please enter a password.", ("Password", False)
+ )
return AUTH_FAILED
def check_auth_interactive_response(self, responses):
- if self.username == 'commie':
- if (len(responses) == 1) and (responses[0] == 'cat'):
+ if self.username == "commie":
+ if (len(responses) == 1) and (responses[0] == "cat"):
return AUTH_SUCCESSFUL
return AUTH_FAILED
-class AuthTest (unittest.TestCase):
+class AuthTest(unittest.TestCase):
def setUp(self):
self.socks = LoopSocket()
@@ -120,7 +130,7 @@ class AuthTest (unittest.TestCase):
self.sockc.close()
def start_server(self):
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
self.event = threading.Event()
@@ -140,13 +150,16 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
try:
- self.tc.connect(hostkey=self.public_host_key,
- username='unknown', password='error')
+ self.tc.connect(
+ hostkey=self.public_host_key,
+ username="unknown",
+ password="error",
+ )
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEqual(BadAuthenticationType, etype)
- self.assertEqual(['publickey'], evalue.allowed_types)
+ self.assertEqual(["publickey"], evalue.allowed_types)
def test_bad_password(self):
"""
@@ -156,12 +169,12 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
try:
- self.tc.auth_password(username='slowdive', password='error')
+ self.tc.auth_password(username="slowdive", password="error")
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
- self.tc.auth_password(username='slowdive', password='pygmalion')
+ self.tc.auth_password(username="slowdive", password="pygmalion")
self.verify_finished()
def test_multipart_auth(self):
@@ -170,10 +183,12 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password(username='paranoid', password='paranoid')
- self.assertEqual(['publickey'], remain)
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
- remain = self.tc.auth_publickey(username='paranoid', key=key)
+ remain = self.tc.auth_password(
+ username="paranoid", password="paranoid"
+ )
+ self.assertEqual(["publickey"], remain)
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ remain = self.tc.auth_publickey(username="paranoid", key=key)
self.assertEqual([], remain)
self.verify_finished()
@@ -188,10 +203,11 @@ class AuthTest (unittest.TestCase):
self.got_title = title
self.got_instructions = instructions
self.got_prompts = prompts
- return ['cat']
- remain = self.tc.auth_interactive('commie', handler)
- self.assertEqual(self.got_title, 'password')
- self.assertEqual(self.got_prompts, [('Password', False)])
+ return ["cat"]
+
+ remain = self.tc.auth_interactive("commie", handler)
+ self.assertEqual(self.got_title, "password")
+ self.assertEqual(self.got_prompts, [("Password", False)])
self.assertEqual([], remain)
self.verify_finished()
@@ -202,7 +218,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('commie', 'cat')
+ remain = self.tc.auth_password("commie", "cat")
self.assertEqual([], remain)
self.verify_finished()
@@ -212,7 +228,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('utf8', _pwd)
+ remain = self.tc.auth_password("utf8", _pwd)
self.assertEqual([], remain)
self.verify_finished()
@@ -223,7 +239,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('non-utf8', '\xff')
+ remain = self.tc.auth_password("non-utf8", "\xff")
self.assertEqual([], remain)
self.verify_finished()
@@ -235,7 +251,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
try:
- remain = self.tc.auth_password('bad-server', 'hello')
+ remain = self.tc.auth_password("bad-server", "hello")
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
@@ -250,8 +266,8 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect()
try:
- remain = self.tc.auth_password('unresponsive-server', 'hello')
+ remain = self.tc.auth_password("unresponsive-server", "hello")
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
- self.assertTrue('Authentication timeout' in str(evalue))
+ self.assertTrue("Authentication timeout" in str(evalue))
diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py
index 03616c55..9f986a5e 100644
--- a/tests/test_buffered_pipe.py
+++ b/tests/test_buffered_pipe.py
@@ -30,9 +30,9 @@ from paramiko.py3compat import b
def delay_thread(p):
- p.feed('a')
+ p.feed("a")
time.sleep(0.5)
- p.feed('b')
+ p.feed("b")
p.close()
@@ -42,41 +42,42 @@ def close_thread(p):
class BufferedPipeTest(unittest.TestCase):
+
def test_1_buffered_pipe(self):
p = BufferedPipe()
self.assertTrue(not p.read_ready())
- p.feed('hello.')
+ p.feed("hello.")
self.assertTrue(p.read_ready())
data = p.read(6)
- self.assertEqual(b'hello.', data)
+ self.assertEqual(b"hello.", data)
- p.feed('plus/minus')
- self.assertEqual(b'plu', p.read(3))
- self.assertEqual(b's/m', p.read(3))
- self.assertEqual(b'inus', p.read(4))
+ p.feed("plus/minus")
+ self.assertEqual(b"plu", p.read(3))
+ self.assertEqual(b"s/m", p.read(3))
+ self.assertEqual(b"inus", p.read(4))
p.close()
self.assertTrue(not p.read_ready())
- self.assertEqual(b'', p.read(1))
+ self.assertEqual(b"", p.read(1))
def test_2_delay(self):
p = BufferedPipe()
self.assertTrue(not p.read_ready())
threading.Thread(target=delay_thread, args=(p,)).start()
- self.assertEqual(b'a', p.read(1, 0.1))
+ self.assertEqual(b"a", p.read(1, 0.1))
try:
p.read(1, 0.1)
self.assertTrue(False)
except PipeTimeout:
pass
- self.assertEqual(b'b', p.read(1, 1.0))
- self.assertEqual(b'', p.read(1))
+ self.assertEqual(b"b", p.read(1, 1.0))
+ self.assertEqual(b"", p.read(1))
def test_3_close_while_reading(self):
p = BufferedPipe()
threading.Thread(target=close_thread, args=(p,)).start()
data = p.read(1, 1.0)
- self.assertEqual(b'', data)
+ self.assertEqual(b"", data)
def test_4_or_pipe(self):
p = pipe.make_pipe()
diff --git a/tests/test_client.py b/tests/test_client.py
index 7163fdcf..4943df29 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -48,30 +48,31 @@ requires_gss_auth = unittest.skipUnless(
)
FINGERPRINTS = {
- 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c',
- 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5',
- 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60',
- 'ssh-ed25519': b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80',
+ "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c",
+ "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5",
+ "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60",
+ "ssh-ed25519": b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80',
}
class NullServer(paramiko.ServerInterface):
+
def __init__(self, *args, **kwargs):
# Allow tests to enable/disable specific key types
- self.__allowed_keys = kwargs.pop('allowed_keys', [])
+ self.__allowed_keys = kwargs.pop("allowed_keys", [])
# And allow them to set a (single...meh) expected public blob (cert)
- self.__expected_public_blob = kwargs.pop('public_blob', None)
+ self.__expected_public_blob = kwargs.pop("public_blob", None)
super(NullServer, self).__init__(*args, **kwargs)
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- return 'publickey'
+ if username == "slowdive":
+ return "publickey,password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return paramiko.AUTH_SUCCESSFUL
- if (username == 'slowdive') and (password == 'unresponsive-server'):
+ if (username == "slowdive") and (password == "unresponsive-server"):
time.sleep(5)
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -83,13 +84,13 @@ class NullServer(paramiko.ServerInterface):
return paramiko.AUTH_FAILED
# Base check: allowed auth type & fingerprint matches
happy = (
- key.get_name() in self.__allowed_keys and
- key.get_fingerprint() == expected
+ key.get_name() in self.__allowed_keys
+ and key.get_fingerprint() == expected
)
# Secondary check: if test wants assertions about cert data
if (
- self.__expected_public_blob is not None and
- key.public_blob != self.__expected_public_blob
+ self.__expected_public_blob is not None
+ and key.public_blob != self.__expected_public_blob
):
happy = False
return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED
@@ -98,31 +99,32 @@ class NullServer(paramiko.ServerInterface):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != b'yes':
+ if command != b"yes":
return False
return True
def check_channel_env_request(self, channel, name, value):
- if name == 'INVALID_ENV':
+ if name == "INVALID_ENV":
return False
- if not hasattr(channel, 'env'):
- setattr(channel, 'env', {})
+ if not hasattr(channel, "env"):
+ setattr(channel, "env", {})
channel.env[name] = value
return True
class ClientTest(unittest.TestCase):
+
def setUp(self):
self.sockl = socket.socket()
- self.sockl.bind(('localhost', 0))
+ self.sockl.bind(("localhost", 0))
self.sockl.listen(1)
self.addr, self.port = self.sockl.getsockname()
self.connect_kwargs = dict(
hostname=self.addr,
port=self.port,
- username='slowdive',
+ username="slowdive",
look_for_keys=False,
)
self.event = threading.Event()
@@ -130,10 +132,10 @@ class ClientTest(unittest.TestCase):
def tearDown(self):
# Shut down client Transport
- if hasattr(self, 'tc'):
+ if hasattr(self, "tc"):
self.tc.close()
# Shut down shared socket
- if hasattr(self, 'sockl'):
+ if hasattr(self, "sockl"):
# Signal to server thread that it should shut down early; it checks
# this immediately after accept(). (In scenarios where connection
# actually succeeded during the test, this becomes a no-op.)
@@ -151,7 +153,7 @@ class ClientTest(unittest.TestCase):
self.sockl.close()
def _run(
- self, allowed_keys=None, delay=0, public_blob=None, kill_event=None,
+ self, allowed_keys=None, delay=0, public_blob=None, kill_event=None
):
if allowed_keys is None:
allowed_keys = FINGERPRINTS.keys()
@@ -163,10 +165,10 @@ class ClientTest(unittest.TestCase):
self.socks.close()
return
self.ts = paramiko.Transport(self.socks)
- keypath = _support('test_rsa.key')
+ keypath = _support("test_rsa.key")
host_key = paramiko.RSAKey.from_private_key_file(keypath)
self.ts.add_server_key(host_key)
- keypath = _support('test_ecdsa_256.key')
+ keypath = _support("test_ecdsa_256.key")
host_key = paramiko.ECDSAKey.from_private_key_file(keypath)
self.ts.add_server_key(host_key)
server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob)
@@ -181,17 +183,21 @@ class ClientTest(unittest.TestCase):
The exception is ``allowed_keys`` which is stripped and handed to the
``NullServer`` used for testing.
"""
- run_kwargs = {'kill_event': self.kill_event}
- for key in ('allowed_keys', 'public_blob'):
+ run_kwargs = {"kill_event": self.kill_event}
+ for key in ("allowed_keys", "public_blob"):
run_kwargs[key] = kwargs.pop(key, None)
# Server setup
threading.Thread(target=self._run, kwargs=run_kwargs).start()
- host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = paramiko.RSAKey.from_private_key_file(
+ _support("test_rsa.key")
+ )
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
# Client setup
self.tc = paramiko.SSHClient()
- self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key
+ )
# Actual connection
self.tc.connect(**dict(self.connect_kwargs, **kwargs))
@@ -200,22 +206,22 @@ class ClientTest(unittest.TestCase):
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
self.assertTrue(self.ts.is_active())
- self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
self.assertEqual(False, self.tc.get_transport().gss_kex_used)
# Command execution functions?
- stdin, stdout, stderr = self.tc.exec_command('yes')
+ stdin, stdout, stderr = self.tc.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
- self.assertEqual('Hello there.\n', stdout.readline())
- self.assertEqual('', stdout.readline())
- self.assertEqual('This is on stderr.\n', stderr.readline())
- self.assertEqual('', stderr.readline())
+ self.assertEqual("Hello there.\n", stdout.readline())
+ self.assertEqual("", stdout.readline())
+ self.assertEqual("This is on stderr.\n", stderr.readline())
+ self.assertEqual("", stderr.readline())
# Cleanup
stdin.close()
@@ -224,32 +230,33 @@ class ClientTest(unittest.TestCase):
class SSHClientTest(ClientTest):
+
def test_1_client(self):
"""
verify that the SSHClient stuff works too.
"""
- self._test_connection(password='pygmalion')
+ self._test_connection(password="pygmalion")
def test_2_client_dsa(self):
"""
verify that SSHClient works with a DSA key.
"""
- self._test_connection(key_filename=_support('test_dss.key'))
+ self._test_connection(key_filename=_support("test_dss.key"))
def test_client_rsa(self):
"""
verify that SSHClient works with an RSA key.
"""
- self._test_connection(key_filename=_support('test_rsa.key'))
+ self._test_connection(key_filename=_support("test_rsa.key"))
def test_2_5_client_ecdsa(self):
"""
verify that SSHClient works with an ECDSA key.
"""
- self._test_connection(key_filename=_support('test_ecdsa_256.key'))
+ self._test_connection(key_filename=_support("test_ecdsa_256.key"))
def test_client_ed25519(self):
- self._test_connection(key_filename=_support('test_ed25519.key'))
+ self._test_connection(key_filename=_support("test_ed25519.key"))
def test_3_multiple_key_files(self):
"""
@@ -257,22 +264,22 @@ class SSHClientTest(ClientTest):
"""
# This is dumb :(
types_ = {
- 'rsa': 'ssh-rsa',
- 'dss': 'ssh-dss',
- 'ecdsa': 'ecdsa-sha2-nistp256',
+ "rsa": "ssh-rsa",
+ "dss": "ssh-dss",
+ "ecdsa": "ecdsa-sha2-nistp256",
}
# Various combos of attempted & valid keys
# TODO: try every possible combo using itertools functions
for attempt, accept in (
- (['rsa', 'dss'], ['dss']), # Original test #3
- (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly
- (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail
- (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success
+ (["rsa", "dss"], ["dss"]), # Original test #3
+ (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly
+ (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail
+ (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success
):
try:
self._test_connection(
key_filename=[
- _support('test_{}.key'.format(x)) for x in attempt
+ _support("test_{}.key".format(x)) for x in attempt
],
allowed_keys=[types_[x] for x in accept],
)
@@ -288,10 +295,11 @@ class SSHClientTest(ClientTest):
"""
# Until #387 is fixed we have to catch a high-up exception since
# various platforms trigger different errors here >_<
- self.assertRaises(SSHException,
+ self.assertRaises(
+ SSHException,
self._test_connection,
- key_filename=[_support('test_rsa.key')],
- allowed_keys=['ecdsa-sha2-nistp256'],
+ key_filename=[_support("test_rsa.key")],
+ allowed_keys=["ecdsa-sha2-nistp256"],
)
def test_certs_allowed_as_key_filename_values(self):
@@ -299,9 +307,9 @@ class SSHClientTest(ClientTest):
# They're similar except for which path is given; the expected auth and
# server-side behavior is 100% identical.)
# NOTE: only bothered whipping up one cert per overall class/family.
- for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'):
- cert_name = 'test_{}.key-cert.pub'.format(type_)
- cert_path = _support(os.path.join('cert_support', cert_name))
+ for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"):
+ cert_name = "test_{}.key-cert.pub".format(type_)
+ cert_path = _support(os.path.join("cert_support", cert_name))
self._test_connection(
key_filename=cert_path,
public_blob=PublicBlob.from_file(cert_path),
@@ -314,13 +322,13 @@ class SSHClientTest(ClientTest):
# about the server-side key object's public blob. Thus, we can prove
# that a specific cert was found, along with regular authorization
# succeeding proving that the overall flow works.
- for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'):
- key_name = 'test_{}.key'.format(type_)
- key_path = _support(os.path.join('cert_support', key_name))
+ for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"):
+ key_name = "test_{}.key".format(type_)
+ key_path = _support(os.path.join("cert_support", key_name))
self._test_connection(
key_filename=key_path,
public_blob=PublicBlob.from_file(
- '{}-cert.pub'.format(key_path)
+ "{}-cert.pub".format(key_path)
),
)
@@ -335,19 +343,19 @@ class SSHClientTest(ClientTest):
verify that SSHClient's AutoAddPolicy works.
"""
threading.Thread(target=self._run).start()
- hostname = '[%s]:%d' % (self.addr, self.port)
- key_file = _support('test_ecdsa_256.key')
+ hostname = "[%s]:%d" % (self.addr, self.port)
+ key_file = _support("test_ecdsa_256.key")
public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEqual(0, len(self.tc.get_host_keys()))
- self.tc.connect(password='pygmalion', **self.connect_kwargs)
+ self.tc.connect(password="pygmalion", **self.connect_kwargs)
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
self.assertTrue(self.ts.is_active())
- self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
self.assertEqual(1, len(self.tc.get_host_keys()))
new_host_key = list(self.tc.get_host_keys()[hostname].values())[0]
@@ -357,9 +365,11 @@ class SSHClientTest(ClientTest):
"""
verify that SSHClient correctly saves a known_hosts file.
"""
- warnings.filterwarnings('ignore', 'tempnam.*')
+ warnings.filterwarnings("ignore", "tempnam.*")
- host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = paramiko.RSAKey.from_private_key_file(
+ _support("test_rsa.key")
+ )
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
fd, localname = mkstemp()
os.close(fd)
@@ -367,11 +377,13 @@ class SSHClientTest(ClientTest):
client = paramiko.SSHClient()
self.assertEquals(0, len(client.get_host_keys()))
- host_id = '[%s]:%d' % (self.addr, self.port)
+ host_id = "[%s]:%d" % (self.addr, self.port)
- client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key)
+ client.get_host_keys().add(host_id, "ssh-rsa", public_host_key)
self.assertEquals(1, len(client.get_host_keys()))
- self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa'])
+ self.assertEquals(
+ public_host_key, client.get_host_keys()[host_id]["ssh-rsa"]
+ )
client.save_host_keys(localname)
@@ -394,7 +406,7 @@ class SSHClientTest(ClientTest):
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEqual(0, len(self.tc.get_host_keys()))
- self.tc.connect(**dict(self.connect_kwargs, password='pygmalion'))
+ self.tc.connect(**dict(self.connect_kwargs, password="pygmalion"))
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
@@ -423,7 +435,7 @@ class SSHClientTest(ClientTest):
self.tc = tc
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys()))
- self.tc.connect(**dict(self.connect_kwargs, password='pygmalion'))
+ self.tc.connect(**dict(self.connect_kwargs, password="pygmalion"))
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
@@ -438,19 +450,19 @@ class SSHClientTest(ClientTest):
verify that the SSHClient has a configurable banner timeout.
"""
# Start the thread with a 1 second wait.
- threading.Thread(target=self._run, kwargs={'delay': 1}).start()
- host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key'))
+ threading.Thread(target=self._run, kwargs={"delay": 1}).start()
+ host_key = paramiko.RSAKey.from_private_key_file(
+ _support("test_rsa.key")
+ )
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
- self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key
+ )
# Connect with a half second banner timeout.
kwargs = dict(self.connect_kwargs, banner_timeout=0.5)
- self.assertRaises(
- paramiko.SSHException,
- self.tc.connect,
- **kwargs
- )
+ self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs)
def test_8_auth_trickledown(self):
"""
@@ -466,9 +478,9 @@ class SSHClientTest(ClientTest):
# 'television' as per tests/test_pkey.py). NOTE: must use
# key_filename, loading the actual key here with PKey will except
# immediately; we're testing the try/except crap within Client.
- key_filename=[_support('test_rsa_password.key')],
+ key_filename=[_support("test_rsa_password.key")],
# Actual password for default 'slowdive' user
- password='pygmalion',
+ password="pygmalion",
)
self._test_connection(**kwargs)
@@ -481,7 +493,7 @@ class SSHClientTest(ClientTest):
self.assertRaises(
AuthenticationException,
self._test_connection,
- password='unresponsive-server',
+ password="unresponsive-server",
auth_timeout=0.5,
)
@@ -490,10 +502,7 @@ class SSHClientTest(ClientTest):
"""
Failed gssapi-keyex auth doesn't prevent subsequent key auth from succeeding
"""
- kwargs = dict(
- gss_kex=True,
- key_filename=[_support('test_rsa.key')],
- )
+ kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")])
self._test_connection(**kwargs)
@requires_gss_auth
@@ -501,10 +510,7 @@ class SSHClientTest(ClientTest):
"""
Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding
"""
- kwargs = dict(
- gss_auth=True,
- key_filename=[_support('test_rsa.key')],
- )
+ kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")])
self._test_connection(**kwargs)
def test_12_reject_policy(self):
@@ -519,7 +525,8 @@ class SSHClientTest(ClientTest):
self.assertRaises(
paramiko.SSHException,
self.tc.connect,
- password='pygmalion', **self.connect_kwargs
+ password="pygmalion",
+ **self.connect_kwargs
)
@requires_gss_auth
@@ -537,14 +544,14 @@ class SSHClientTest(ClientTest):
self.assertRaises(
paramiko.SSHException,
self.tc.connect,
- password='pygmalion',
+ password="pygmalion",
gss_kex=True,
- **self.connect_kwargs
+ **self.connect_kwargs
)
def _client_host_key_bad(self, host_key):
threading.Thread(target=self._run).start()
- hostname = '[%s]:%d' % (self.addr, self.port)
+ hostname = "[%s]:%d" % (self.addr, self.port)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.WarningPolicy())
@@ -554,13 +561,13 @@ class SSHClientTest(ClientTest):
self.assertRaises(
paramiko.BadHostKeyException,
self.tc.connect,
- password='pygmalion',
+ password="pygmalion",
**self.connect_kwargs
)
def _client_host_key_good(self, ktype, kfile):
threading.Thread(target=self._run).start()
- hostname = '[%s]:%d' % (self.addr, self.port)
+ hostname = "[%s]:%d" % (self.addr, self.port)
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.RejectPolicy())
@@ -568,7 +575,7 @@ class SSHClientTest(ClientTest):
known_hosts = self.tc.get_host_keys()
known_hosts.add(hostname, host_key.get_name(), host_key)
- self.tc.connect(password='pygmalion', **self.connect_kwargs)
+ self.tc.connect(password="pygmalion", **self.connect_kwargs)
self.event.wait(1.0)
self.assertTrue(self.event.is_set())
self.assertTrue(self.ts.is_active())
@@ -583,10 +590,10 @@ class SSHClientTest(ClientTest):
self._client_host_key_bad(host_key)
def test_host_key_negotiation_3(self):
- self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key')
+ self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key")
def test_host_key_negotiation_4(self):
- self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key')
+ self._client_host_key_good(paramiko.RSAKey, "test_rsa.key")
def _setup_for_env(self):
threading.Thread(target=self._run).start()
@@ -594,7 +601,9 @@ class SSHClientTest(ClientTest):
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEqual(0, len(self.tc.get_host_keys()))
- self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
+ self.tc.connect(
+ self.addr, self.port, username="slowdive", password="pygmalion"
+ )
self.event.wait(1.0)
self.assertTrue(self.event.isSet())
@@ -605,11 +614,11 @@ class SSHClientTest(ClientTest):
Verify that environment variables can be set by the client.
"""
self._setup_for_env()
- target_env = {b'A': b'B', b'C': b'd'}
+ target_env = {b"A": b"B", b"C": b"d"}
- self.tc.exec_command('yes', environment=target_env)
+ self.tc.exec_command("yes", environment=target_env)
schan = self.ts.accept(1.0)
- self.assertEqual(target_env, getattr(schan, 'env', {}))
+ self.assertEqual(target_env, getattr(schan, "env", {}))
schan.close()
@unittest.skip("Clients normally fail silently, thus so do we, for now")
@@ -617,14 +626,14 @@ class SSHClientTest(ClientTest):
self._setup_for_env()
with self.assertRaises(SSHException) as manager:
# Verify that a rejection by the server can be detected
- self.tc.exec_command('yes', environment={b'INVALID_ENV': b''})
+ self.tc.exec_command("yes", environment={b"INVALID_ENV": b""})
self.assertTrue(
- 'INVALID_ENV' in str(manager.exception),
- 'Expected variable name in error message'
+ "INVALID_ENV" in str(manager.exception),
+ "Expected variable name in error message",
)
self.assertTrue(
isinstance(manager.exception.args[1], SSHException),
- 'Expected original SSHException in exception'
+ "Expected original SSHException in exception",
)
def test_missing_key_policy_accepts_classes_or_instances(self):
@@ -652,35 +661,39 @@ class PasswordPassphraseTests(ClientTest):
def test_password_kwarg_works_for_password_auth(self):
# Straightforward / duplicate of earlier basic password test.
- self._test_connection(password='pygmalion')
+ self._test_connection(password="pygmalion")
# TODO: more granular exception pending #387; should be signaling "no auth
# methods available" because no key and no password
@raises(SSHException)
def test_passphrase_kwarg_not_used_for_password_auth(self):
# Using the "right" password in the "wrong" field shouldn't work.
- self._test_connection(passphrase='pygmalion')
+ self._test_connection(passphrase="pygmalion")
def test_passphrase_kwarg_used_for_key_passphrase(self):
# Straightforward again, with new passphrase kwarg.
self._test_connection(
- key_filename=_support('test_rsa_password.key'),
- passphrase='television',
+ key_filename=_support("test_rsa_password.key"),
+ passphrase="television",
)
- def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(self): # noqa
+ def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(
+ self
+ ): # noqa
# Backwards compatibility: passphrase in the password field.
self._test_connection(
- key_filename=_support('test_rsa_password.key'),
- password='television',
+ key_filename=_support("test_rsa_password.key"),
+ password="television",
)
- @raises(AuthenticationException) # TODO: more granular
- def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(self): # noqa
+ @raises(AuthenticationException) # TODO: more granular
+ def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(
+ self
+ ): # noqa
# Sanity: if we're given both fields, the password field is NOT used as
# a passphrase.
self._test_connection(
- key_filename=_support('test_rsa_password.key'),
- password='television',
- passphrase='wat? lol no',
+ key_filename=_support("test_rsa_password.key"),
+ password="television",
+ passphrase="wat? lol no",
)
diff --git a/tests/test_file.py b/tests/test_file.py
index 3d2c94e6..deacd60a 100644
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -30,18 +30,19 @@ from paramiko.py3compat import BytesIO
from .util import needs_builtin
-class LoopbackFile (BufferedFile):
+class LoopbackFile(BufferedFile):
"""
BufferedFile object that you can write data into, and then read it back.
"""
- def __init__(self, mode='r', bufsize=-1):
+
+ def __init__(self, mode="r", bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
self.buffer = BytesIO()
self.offset = 0
def _read(self, size):
- data = self.buffer.getvalue()[self.offset:self.offset+size]
+ data = self.buffer.getvalue()[self.offset : self.offset + size]
self.offset += len(data)
return data
@@ -50,44 +51,46 @@ class LoopbackFile (BufferedFile):
return len(data)
-class BufferedFileTest (unittest.TestCase):
+class BufferedFileTest(unittest.TestCase):
def test_1_simple(self):
- f = LoopbackFile('r')
+ f = LoopbackFile("r")
try:
- f.write(b'hi')
- self.assertTrue(False, 'no exception on write to read-only file')
+ f.write(b"hi")
+ self.assertTrue(False, "no exception on write to read-only file")
except:
pass
f.close()
- f = LoopbackFile('w')
+ f = LoopbackFile("w")
try:
f.read(1)
- self.assertTrue(False, 'no exception to read from write-only file')
+ self.assertTrue(False, "no exception to read from write-only file")
except:
pass
f.close()
def test_2_readline(self):
- f = LoopbackFile('r+U')
- f.write(b'First line.\nSecond line.\r\nThird line.\n' +
- b'Fourth line.\nFinal line non-terminated.')
+ f = LoopbackFile("r+U")
+ f.write(
+ b"First line.\nSecond line.\r\nThird line.\n"
+ + b"Fourth line.\nFinal line non-terminated."
+ )
- self.assertEqual(f.readline(), 'First line.\n')
+ self.assertEqual(f.readline(), "First line.\n")
# universal newline mode should convert this linefeed:
- self.assertEqual(f.readline(), 'Second line.\n')
+ self.assertEqual(f.readline(), "Second line.\n")
# truncated line:
- self.assertEqual(f.readline(7), 'Third l')
- self.assertEqual(f.readline(), 'ine.\n')
+ self.assertEqual(f.readline(7), "Third l")
+ self.assertEqual(f.readline(), "ine.\n")
# newline should be detected and only the fourth line returned
- self.assertEqual(f.readline(39), 'Fourth line.\n')
- self.assertEqual(f.readline(), 'Final line non-terminated.')
- self.assertEqual(f.readline(), '')
+ self.assertEqual(f.readline(39), "Fourth line.\n")
+ self.assertEqual(f.readline(), "Final line non-terminated.")
+ self.assertEqual(f.readline(), "")
f.close()
try:
f.readline()
- self.assertTrue(False, 'no exception on readline of closed file')
+ self.assertTrue(False, "no exception on readline of closed file")
except IOError:
pass
self.assertTrue(linefeed_byte in f.newlines)
@@ -98,11 +101,11 @@ class BufferedFileTest (unittest.TestCase):
"""
try to trick the linefeed detector.
"""
- f = LoopbackFile('r+U')
- f.write(b'First line.\r')
- self.assertEqual(f.readline(), 'First line.\n')
- f.write(b'\nSecond.\r\n')
- self.assertEqual(f.readline(), 'Second.\n')
+ f = LoopbackFile("r+U")
+ f.write(b"First line.\r")
+ self.assertEqual(f.readline(), "First line.\n")
+ f.write(b"\nSecond.\r\n")
+ self.assertEqual(f.readline(), "Second.\n")
f.close()
self.assertEqual(f.newlines, crlf)
@@ -110,51 +113,54 @@ class BufferedFileTest (unittest.TestCase):
"""
verify that write buffering is on.
"""
- f = LoopbackFile('r+', 1)
- f.write(b'Complete line.\nIncomplete line.')
- self.assertEqual(f.readline(), 'Complete line.\n')
- self.assertEqual(f.readline(), '')
- f.write('..\n')
- self.assertEqual(f.readline(), 'Incomplete line...\n')
+ f = LoopbackFile("r+", 1)
+ f.write(b"Complete line.\nIncomplete line.")
+ self.assertEqual(f.readline(), "Complete line.\n")
+ self.assertEqual(f.readline(), "")
+ f.write("..\n")
+ self.assertEqual(f.readline(), "Incomplete line...\n")
f.close()
def test_5_flush(self):
"""
verify that flush will force a write.
"""
- f = LoopbackFile('r+', 512)
- f.write('Not\nquite\n512 bytes.\n')
- self.assertEqual(f.read(1), b'')
+ f = LoopbackFile("r+", 512)
+ f.write("Not\nquite\n512 bytes.\n")
+ self.assertEqual(f.read(1), b"")
f.flush()
- self.assertEqual(f.read(5), b'Not\nq')
- self.assertEqual(f.read(10), b'uite\n512 b')
- self.assertEqual(f.read(9), b'ytes.\n')
- self.assertEqual(f.read(3), b'')
+ self.assertEqual(f.read(5), b"Not\nq")
+ self.assertEqual(f.read(10), b"uite\n512 b")
+ self.assertEqual(f.read(9), b"ytes.\n")
+ self.assertEqual(f.read(3), b"")
f.close()
def test_6_buffering(self):
"""
verify that flushing happens automatically on buffer crossing.
"""
- f = LoopbackFile('r+', 16)
- f.write(b'Too small.')
- self.assertEqual(f.read(4), b'')
- f.write(b' ')
- self.assertEqual(f.read(4), b'')
- f.write(b'Enough.')
- self.assertEqual(f.read(20), b'Too small. Enough.')
+ f = LoopbackFile("r+", 16)
+ f.write(b"Too small.")
+ self.assertEqual(f.read(4), b"")
+ f.write(b" ")
+ self.assertEqual(f.read(4), b"")
+ f.write(b"Enough.")
+ self.assertEqual(f.read(20), b"Too small. Enough.")
f.close()
def test_7_read_all(self):
"""
verify that read(-1) returns everything left in the file.
"""
- f = LoopbackFile('r+', 16)
- f.write(b'The first thing you need to do is open your eyes. ')
- f.write(b'Then, you need to close them again.\n')
+ f = LoopbackFile("r+", 16)
+ f.write(b"The first thing you need to do is open your eyes. ")
+ f.write(b"Then, you need to close them again.\n")
s = f.read(-1)
- self.assertEqual(s, b'The first thing you need to do is open your eyes. Then, you ' +
- b'need to close them again.\n')
+ self.assertEqual(
+ s,
+ b"The first thing you need to do is open your eyes. Then, you "
+ + b"need to close them again.\n",
+ )
f.close()
def test_8_buffering(self):
@@ -162,19 +168,19 @@ class BufferedFileTest (unittest.TestCase):
verify that buffered objects can be written
"""
if sys.version_info[0] == 2:
- f = LoopbackFile('r+', 16)
- f.write(buffer(b'Too small.'))
+ f = LoopbackFile("r+", 16)
+ f.write(buffer(b"Too small."))
f.close()
def test_9_readable(self):
- f = LoopbackFile('r')
+ f = LoopbackFile("r")
self.assertTrue(f.readable())
self.assertFalse(f.writable())
self.assertFalse(f.seekable())
f.close()
def test_A_writable(self):
- f = LoopbackFile('w')
+ f = LoopbackFile("w")
self.assertTrue(f.writable())
self.assertFalse(f.readable())
self.assertFalse(f.seekable())
@@ -182,48 +188,49 @@ class BufferedFileTest (unittest.TestCase):
def test_B_readinto(self):
data = bytearray(5)
- f = LoopbackFile('r+')
+ f = LoopbackFile("r+")
f._write(b"hello")
f.readinto(data)
- self.assertEqual(data, b'hello')
+ self.assertEqual(data, b"hello")
f.close()
def test_write_bad_type(self):
- with LoopbackFile('wb') as f:
+ with LoopbackFile("wb") as f:
self.assertRaises(TypeError, f.write, object())
def test_write_unicode_as_binary(self):
text = u"\xa7 why is writing text to a binary file allowed?\n"
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
f.write(text)
self.assertEqual(f.read(), text.encode("utf-8"))
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_write_bytearray(self):
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
f.write(bytearray(12))
self.assertEqual(f.read(), 12 * b"\0")
- @needs_builtin('buffer')
+ @needs_builtin("buffer")
def test_write_buffer(self):
data = 3 * b"pretend giant block of data\n"
offsets = range(0, len(data), 8)
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
for offset in offsets:
f.write(buffer(data, offset, 8))
self.assertEqual(f.read(), data)
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_write_memoryview(self):
data = 3 * b"pretend giant block of data\n"
offsets = range(0, len(data), 8)
- with LoopbackFile('rb+') as f:
+ with LoopbackFile("rb+") as f:
view = memoryview(data)
for offset in offsets:
- f.write(view[offset:offset+8])
+ f.write(view[offset : offset + 8])
self.assertEqual(f.read(), data)
-if __name__ == '__main__':
+if __name__ == "__main__":
from unittest import main
+
main()
diff --git a/tests/test_gssapi.py b/tests/test_gssapi.py
index d4b632be..d7fbdd53 100644
--- a/tests/test_gssapi.py
+++ b/tests/test_gssapi.py
@@ -30,6 +30,7 @@ from .util import needs_gssapi
@needs_gssapi
class GSSAPITest(unittest.TestCase):
+
def setup():
# TODO: these vars should all come from os.environ or whatever the
# approved pytest method is for runtime-configuring test data.
@@ -43,6 +44,7 @@ class GSSAPITest(unittest.TestCase):
"""
from pyasn1.type.univ import ObjectIdentifier
from pyasn1.codec.der import encoder, decoder
+
oid = encoder.encode(ObjectIdentifier(self.krb5_mech))
mech, __ = decoder.decode(oid)
self.assertEquals(self.krb5_mech, mech.__str__())
@@ -57,6 +59,7 @@ class GSSAPITest(unittest.TestCase):
except ImportError:
import sspicon
import sspi
+
_API = "SSPI"
c_token = None
@@ -65,23 +68,28 @@ class GSSAPITest(unittest.TestCase):
if _API == "MIT":
if self.server_mode:
- gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_MUTUAL_FLAG,
- gssapi.C_DELEG_FLAG)
+ gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
else:
- gss_flags = (gssapi.C_PROT_READY_FLAG,
- gssapi.C_INTEG_FLAG,
- gssapi.C_DELEG_FLAG)
+ gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
# Initialize a GSS-API context.
ctx = gssapi.Context()
ctx.flags = gss_flags
krb5_oid = gssapi.OID.mech_from_string(self.krb5_mech)
- target_name = gssapi.Name("host@" + self.targ_name,
- gssapi.C_NT_HOSTBASED_SERVICE)
- gss_ctxt = gssapi.InitContext(peer_name=target_name,
- mech_type=krb5_oid,
- req_flags=ctx.flags)
+ target_name = gssapi.Name(
+ "host@" + self.targ_name, gssapi.C_NT_HOSTBASED_SERVICE
+ )
+ gss_ctxt = gssapi.InitContext(
+ peer_name=target_name, mech_type=krb5_oid, req_flags=ctx.flags
+ )
if self.server_mode:
c_token = gss_ctxt.step(c_token)
gss_ctxt_status = gss_ctxt.established
@@ -108,15 +116,15 @@ class GSSAPITest(unittest.TestCase):
self.assertEquals(0, status)
else:
gss_flags = (
- sspicon.ISC_REQ_INTEGRITY |
- sspicon.ISC_REQ_MUTUAL_AUTH |
- sspicon.ISC_REQ_DELEGATE
+ sspicon.ISC_REQ_INTEGRITY
+ | sspicon.ISC_REQ_MUTUAL_AUTH
+ | sspicon.ISC_REQ_DELEGATE
)
# Initialize a GSS-API context.
target_name = "host/" + socket.getfqdn(self.targ_name)
- gss_ctxt = sspi.ClientAuth("Kerberos",
- scflags=gss_flags,
- targetspn=target_name)
+ gss_ctxt = sspi.ClientAuth(
+ "Kerberos", scflags=gss_flags, targetspn=target_name
+ )
if self.server_mode:
error, token = gss_ctxt.authorize(c_token)
c_token = token[0].Buffer
diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py
index cd75f8ab..a1b7a9e0 100644
--- a/tests/test_hostkeys.py
+++ b/tests/test_hostkeys.py
@@ -54,77 +54,80 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\
0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE="""
-class HostKeysTest (unittest.TestCase):
+class HostKeysTest(unittest.TestCase):
def setUp(self):
- with open('hostfile.temp', 'w') as f:
+ with open("hostfile.temp", "w") as f:
f.write(test_hosts_file)
def tearDown(self):
- os.unlink('hostfile.temp')
+ os.unlink("hostfile.temp")
def test_1_load(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
+ hostdict = paramiko.HostKeys("hostfile.temp")
self.assertEqual(2, len(hostdict))
self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
- fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp)
def test_2_add(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
- hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c='
+ hostdict = paramiko.HostKeys("hostfile.temp")
+ hh = "|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c="
key = paramiko.RSAKey(data=decodebytes(keyblob))
- hostdict.add(hh, 'ssh-rsa', key)
+ hostdict.add(hh, "ssh-rsa", key)
self.assertEqual(3, len(list(hostdict)))
- x = hostdict['foo.example.com']
- fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
- self.assertTrue(hostdict.check('foo.example.com', key))
+ x = hostdict["foo.example.com"]
+ fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper()
+ self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp)
+ self.assertTrue(hostdict.check("foo.example.com", key))
def test_3_dict(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
- self.assertTrue('secure.example.com' in hostdict)
- self.assertTrue('not.example.com' not in hostdict)
- self.assertTrue('secure.example.com' in hostdict)
- self.assertTrue('not.example.com' not in hostdict)
- x = hostdict.get('secure.example.com', None)
+ hostdict = paramiko.HostKeys("hostfile.temp")
+ self.assertTrue("secure.example.com" in hostdict)
+ self.assertTrue("not.example.com" not in hostdict)
+ self.assertTrue("secure.example.com" in hostdict)
+ self.assertTrue("not.example.com" not in hostdict)
+ x = hostdict.get("secure.example.com", None)
self.assertTrue(x is not None)
- fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper()
+ self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp)
i = 0
for key in hostdict:
i += 1
self.assertEqual(2, i)
-
+
def test_4_dict_set(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
+ hostdict = paramiko.HostKeys("hostfile.temp")
key = paramiko.RSAKey(data=decodebytes(keyblob))
key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss))
- hostdict['secure.example.com'] = {
- 'ssh-rsa': key,
- 'ssh-dss': key_dss
- }
- hostdict['fake.example.com'] = {}
- hostdict['fake.example.com']['ssh-rsa'] = key
-
+ hostdict["secure.example.com"] = {"ssh-rsa": key, "ssh-dss": key_dss}
+ hostdict["fake.example.com"] = {}
+ hostdict["fake.example.com"]["ssh-rsa"] = key
+
self.assertEqual(3, len(hostdict))
self.assertEqual(2, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
self.assertEqual(1, len(list(hostdict.values())[2]))
- fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
- fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()
- self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-dss"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"4478F0B9A23CC5182009FF755BC1D26C", fp)
def test_delitem(self):
- hostdict = paramiko.HostKeys('hostfile.temp')
- target = 'happy.example.com'
- entry = hostdict[target] # will KeyError if not present
+ hostdict = paramiko.HostKeys("hostfile.temp")
+ target = "happy.example.com"
+ entry = hostdict[target] # will KeyError if not present
del hostdict[target]
try:
entry = hostdict[target]
except KeyError:
- pass # Good
+ pass # Good
else:
assert False, "Entry was not deleted from HostKeys on delitem!"
diff --git a/tests/test_kex.py b/tests/test_kex.py
index b5808e7e..13d19d86 100644
--- a/tests/test_kex.py
+++ b/tests/test_kex.py
@@ -38,30 +38,46 @@ from paramiko.kex_ecdh_nist import KexNistp256
def dummy_urandom(n):
return byte_chr(0xcc) * n
+
def dummy_generate_key_pair(obj):
- private_key_value = 94761803665136558137557783047955027733968423115106677159790289642479432803037
- public_key_numbers = "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989"
- public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers))
- obj.P = ec.EllipticCurvePrivateNumbers(private_value=private_key_value, public_numbers=public_key_numbers_obj).private_key(default_backend())
+ private_key_value = (
+ 94761803665136558137557783047955027733968423115106677159790289642479432803037
+ )
+ public_key_numbers = (
+ "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989"
+ )
+ public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(
+ ec.SECP256R1(), unhexlify(public_key_numbers)
+ )
+ obj.P = ec.EllipticCurvePrivateNumbers(
+ private_value=private_key_value, public_numbers=public_key_numbers_obj
+ ).private_key(default_backend())
if obj.transport.server_mode:
- obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend())
+ obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(
+ ec.SECP256R1(), unhexlify(public_key_numbers)
+ ).public_key(default_backend())
return
- obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend())
+ obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(
+ ec.SECP256R1(), unhexlify(public_key_numbers)
+ ).public_key(default_backend())
+
+class FakeKey(object):
-class FakeKey (object):
def __str__(self):
- return 'fake-key'
+ return "fake-key"
def asbytes(self):
- return b'fake-key'
+ return b"fake-key"
def sign_ssh_data(self, H):
- return b'fake-sig'
+ return b"fake-sig"
-class FakeModulusPack (object):
- P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
+class FakeModulusPack(object):
+ P = (
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
+ )
G = 2
def get_modulus(self, min, ask, max):
@@ -69,10 +85,10 @@ class FakeModulusPack (object):
class FakeTransport(object):
- local_version = 'SSH-2.0-paramiko_1.0'
- remote_version = 'SSH-2.0-lame'
- local_kex_init = 'local-kex-init'
- remote_kex_init = 'remote-kex-init'
+ local_version = "SSH-2.0-paramiko_1.0"
+ remote_version = "SSH-2.0-lame"
+ local_kex_init = "local-kex-init"
+ remote_kex_init = "remote-kex-init"
def _send_message(self, m):
self._message = m
@@ -100,9 +116,11 @@ class FakeTransport(object):
return FakeModulusPack()
-class KexTest (unittest.TestCase):
+class KexTest(unittest.TestCase):
- K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504
+ K = (
+ 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504
+ )
def setUp(self):
self._original_urandom = os.urandom
@@ -119,21 +137,25 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGroup1(transport)
kex.start_kex()
- x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = (
+ b"1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect
+ )
# fake "reply"
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
- H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
+ H = b"03079780F3D3AD0B3C6DB30C8D21685F367A86D2"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_2_group1_server(self):
@@ -141,14 +163,18 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGroup1(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(69)
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
- H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
- x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ H = b"B16BF34DD10945EDE84E9C1EF24A14BFDC843389"
+ x = (
+ b"1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
+ )
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -159,29 +185,35 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex()
- x = b'22000004000000080000002000'
+ x = b"22000004000000080000002000"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = (
+ b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
+ H = b"A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_4_gex_old_client(self):
@@ -189,37 +221,49 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex(_test_old_style=True)
- x = b'1E00000800'
+ x = b"1E00000800"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = (
+ b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'807F87B269EF7AC5EC7E75676808776A27D5864C'
+ H = b"807F87B269EF7AC5EC7E75676808776A27D5864C"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
-
+
def test_5_gex_server(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(1024)
@@ -227,17 +271,25 @@ class KexTest (unittest.TestCase):
msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = (
+ b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
- K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'CE754197C21BF3452863B4F44D0B3951F12516EF'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ K = (
+ 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
+ )
+ H = b"CE754197C21BF3452863B4F44D0B3951F12516EF"
+ x = (
+ b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
+ )
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -248,23 +300,37 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(2048)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = (
+ b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
- K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ K = (
+ 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
+ )
+ H = b"B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B"
+ x = (
+ b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
+ )
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -275,29 +341,35 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGexSHA256(transport)
kex.start_kex()
- x = b'22000004000000080000002000'
+ x = b"22000004000000080000002000"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = (
+ b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4'
+ H = b"AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_8_gex_sha256_old_client(self):
@@ -305,29 +377,35 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGexSHA256(transport)
kex.start_kex(_test_old_style=True)
- x = b'1E00000800'
+ x = b"1E00000800"
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect
+ )
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ x = (
+ b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect
+ )
msg = Message()
- msg.add_string('fake-host-key')
+ msg.add_string("fake-host-key")
msg.add_mpint(69)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = b'518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5'
+ H = b"518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5"
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_9_gex_sha256_server(self):
@@ -335,7 +413,13 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGexSHA256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(1024)
@@ -343,17 +427,25 @@ class KexTest (unittest.TestCase):
msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = (
+ b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
- K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ K = (
+ 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
+ )
+ H = b"CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80"
+ x = (
+ b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
+ )
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
@@ -364,62 +456,88 @@ class KexTest (unittest.TestCase):
transport.server_mode = True
kex = KexGexSHA256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual(
+ (
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST,
+ paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD,
+ ),
+ transport._expect,
+ )
msg = Message()
msg.add_int(2048)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
- x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ x = (
+ b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102"
+ )
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
- self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect
+ )
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
- K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
- H = b'3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB'
- x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ K = (
+ 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
+ )
+ H = b"3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB"
+ x = (
+ b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967"
+ )
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertTrue(transport._activated)
def test_11_kex_nistp256_client(self):
- K = 91610929826364598472338906427792435253694642563583721654249504912114314269754
+ K = (
+ 91610929826364598472338906427792435253694642563583721654249504912114314269754
+ )
transport = FakeTransport()
transport.server_mode = False
kex = KexNistp256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect
+ )
- #fake reply
+ # fake reply
msg = Message()
- msg.add_string('fake-host-key')
- Q_S = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210")
+ msg.add_string("fake-host-key")
+ Q_S = unhexlify(
+ "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210"
+ )
msg.add_string(Q_S)
- msg.add_string('fake-sig')
+ msg.add_string("fake-sig")
msg.rewind()
kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY, msg)
- H = b'BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A'
+ H = b"BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A"
self.assertEqual(K, kex.transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
- self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify)
self.assertTrue(transport._activated)
def test_12_kex_nistp256_server(self):
- K = 91610929826364598472338906427792435253694642563583721654249504912114314269754
+ K = (
+ 91610929826364598472338906427792435253694642563583721654249504912114314269754
+ )
transport = FakeTransport()
transport.server_mode = True
kex = KexNistp256(transport)
kex.start_kex()
- self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect)
+ self.assertEqual(
+ (paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect
+ )
- #fake init
- msg=Message()
- Q_C = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210")
- H = b'2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA'
+ # fake init
+ msg = Message()
+ Q_C = unhexlify(
+ "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210"
+ )
+ H = b"2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA"
msg.add_string(Q_C)
msg.rewind()
kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT, msg)
diff --git a/tests/test_kex_gss.py b/tests/test_kex_gss.py
index 025d1faa..afddee08 100644
--- a/tests/test_kex_gss.py
+++ b/tests/test_kex_gss.py
@@ -34,14 +34,14 @@ import paramiko
from .util import needs_gssapi
-class NullServer (paramiko.ServerInterface):
+class NullServer(paramiko.ServerInterface):
def get_allowed_auths(self, username):
- return 'gssapi-keyex'
+ return "gssapi-keyex"
- def check_auth_gssapi_keyex(self, username,
- gss_authenticated=paramiko.AUTH_FAILED,
- cc_file=None):
+ def check_auth_gssapi_keyex(
+ self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None
+ ):
if gss_authenticated == paramiko.AUTH_SUCCESSFUL:
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -54,13 +54,14 @@ class NullServer (paramiko.ServerInterface):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != 'yes':
+ if command != "yes":
return False
return True
@needs_gssapi
class GSSKexTest(unittest.TestCase):
+
@staticmethod
def init(username, hostname):
global krb5_principal, targ_name
@@ -86,13 +87,13 @@ class GSSKexTest(unittest.TestCase):
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks, gss_kex=True)
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
self.ts.add_server_key(host_key)
self.ts.set_gss_host(targ_name)
try:
self.ts.load_server_moduli()
except:
- print ('(Failed to load moduli -- gex will be unsupported.)')
+ print("(Failed to load moduli -- gex will be unsupported.)")
server = NullServer()
self.ts.start_server(self.event, server)
@@ -102,14 +103,21 @@ class GSSKexTest(unittest.TestCase):
Diffie-Hellman Key Exchange and user authentication with the GSS-API
context created during key exchange.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
- self.tc.get_host_keys().add('[%s]:%d' % (self.hostname, self.port),
- 'ssh-rsa', public_host_key)
- self.tc.connect(self.hostname, self.port, username=self.username,
- gss_auth=True, gss_kex=True, gss_host=gss_host)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.hostname, self.port), "ssh-rsa", public_host_key
+ )
+ self.tc.connect(
+ self.hostname,
+ self.port,
+ username=self.username,
+ gss_auth=True,
+ gss_kex=True,
+ gss_host=gss_host,
+ )
self.event.wait(1.0)
self.assert_(self.event.is_set())
@@ -118,19 +126,19 @@ class GSSKexTest(unittest.TestCase):
self.assertEquals(True, self.ts.is_authenticated())
self.assertEquals(True, self.tc.get_transport().gss_kex_used)
- stdin, stdout, stderr = self.tc.exec_command('yes')
+ stdin, stdout, stderr = self.tc.exec_command("yes")
schan = self.ts.accept(1.0)
if rekey:
self.tc.get_transport().renegotiate_keys()
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
- self.assertEquals('Hello there.\n', stdout.readline())
- self.assertEquals('', stdout.readline())
- self.assertEquals('This is on stderr.\n', stderr.readline())
- self.assertEquals('', stderr.readline())
+ self.assertEquals("Hello there.\n", stdout.readline())
+ self.assertEquals("", stdout.readline())
+ self.assertEquals("This is on stderr.\n", stderr.readline())
+ self.assertEquals("", stderr.readline())
stdin.close()
stdout.close()
diff --git a/tests/test_message.py b/tests/test_message.py
index 645b0509..c292f4e6 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -26,20 +26,29 @@ from paramiko.message import Message
from paramiko.common import byte_chr, zero_byte
-class MessageTest (unittest.TestCase):
+class MessageTest(unittest.TestCase):
- __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000
- __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65'
- __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
- __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62'
+ __a = (
+ b"\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8"
+ + b"x" * 1000
+ )
+ __b = (
+ b"\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65"
+ )
+ __c = (
+ b"\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7"
+ )
+ __d = (
+ b"\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62"
+ )
def test_1_encode(self):
msg = Message()
msg.add_int(23)
msg.add_int(123789456)
- msg.add_string('q')
- msg.add_string('hello')
- msg.add_string('x' * 1000)
+ msg.add_string("q")
+ msg.add_string("hello")
+ msg.add_string("x" * 1000)
self.assertEqual(msg.asbytes(), self.__a)
msg = Message()
@@ -48,7 +57,7 @@ class MessageTest (unittest.TestCase):
msg.add_byte(byte_chr(0xf3))
msg.add_bytes(zero_byte + byte_chr(0x3f))
- msg.add_list(['huey', 'dewey', 'louie'])
+ msg.add_list(["huey", "dewey", "louie"])
self.assertEqual(msg.asbytes(), self.__b)
msg = Message()
@@ -63,16 +72,16 @@ class MessageTest (unittest.TestCase):
msg = Message(self.__a)
self.assertEqual(msg.get_int(), 23)
self.assertEqual(msg.get_int(), 123789456)
- self.assertEqual(msg.get_text(), 'q')
- self.assertEqual(msg.get_text(), 'hello')
- self.assertEqual(msg.get_text(), 'x' * 1000)
+ self.assertEqual(msg.get_text(), "q")
+ self.assertEqual(msg.get_text(), "hello")
+ self.assertEqual(msg.get_text(), "x" * 1000)
msg = Message(self.__b)
self.assertEqual(msg.get_boolean(), True)
self.assertEqual(msg.get_boolean(), False)
self.assertEqual(msg.get_byte(), byte_chr(0xf3))
self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f))
- self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie'])
+ self.assertEqual(msg.get_list(), ["huey", "dewey", "louie"])
msg = Message(self.__c)
self.assertEqual(msg.get_int64(), 5)
@@ -87,8 +96,8 @@ class MessageTest (unittest.TestCase):
msg.add(0x1122334455)
msg.add(0xf00000000000000000)
msg.add(True)
- msg.add('cat')
- msg.add(['a', 'b'])
+ msg.add("cat")
+ msg.add(["a", "b"])
self.assertEqual(msg.asbytes(), self.__d)
def test_4_misc(self):
diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py
index 414b7e38..dbe5993e 100644
--- a/tests/test_packetizer.py
+++ b/tests/test_packetizer.py
@@ -36,19 +36,20 @@ from .loop import LoopSocket
x55 = byte_chr(0x55)
x1f = byte_chr(0x1f)
-class PacketizerTest (unittest.TestCase):
+
+class PacketizerTest(unittest.TestCase):
def test_1_write(self):
rsock = LoopSocket()
wsock = LoopSocket()
rsock.link(wsock)
p = Packetizer(wsock)
- p.set_log(util.get_logger('paramiko.transport'))
+ p.set_log(util.get_logger("paramiko.transport"))
p.set_hexdump(True)
encryptor = Cipher(
algorithms.AES(zero_byte * 16),
modes.CBC(x55 * 16),
- backend=default_backend()
+ backend=default_backend(),
).encryptor()
p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20)
@@ -63,22 +64,27 @@ class PacketizerTest (unittest.TestCase):
data = rsock.recv(100)
# 32 + 12 bytes of MAC = 44
self.assertEqual(44, len(data))
- self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
+ self.assertEqual(
+ b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0",
+ data[:16],
+ )
def test_2_read(self):
rsock = LoopSocket()
wsock = LoopSocket()
rsock.link(wsock)
p = Packetizer(rsock)
- p.set_log(util.get_logger('paramiko.transport'))
+ p.set_log(util.get_logger("paramiko.transport"))
p.set_hexdump(True)
decryptor = Cipher(
algorithms.AES(zero_byte * 16),
modes.CBC(x55 * 16),
- backend=default_backend()
+ backend=default_backend(),
).decryptor()
p.set_inbound_cipher(decryptor, 16, sha1, 12, x1f * 20)
- wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59')
+ wsock.send(
+ b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59"
+ )
cmd, m = p.read_message()
self.assertEqual(100, cmd)
self.assertEqual(100, m.get_int())
@@ -86,18 +92,18 @@ class PacketizerTest (unittest.TestCase):
self.assertEqual(900, m.get_int())
def test_3_closed(self):
- if sys.platform.startswith("win"): # no SIGALRM on windows
+ if sys.platform.startswith("win"): # no SIGALRM on windows
return
rsock = LoopSocket()
wsock = LoopSocket()
rsock.link(wsock)
p = Packetizer(wsock)
- p.set_log(util.get_logger('paramiko.transport'))
+ p.set_log(util.get_logger("paramiko.transport"))
p.set_hexdump(True)
encryptor = Cipher(
algorithms.AES(zero_byte * 16),
modes.CBC(x55 * 16),
- backend=default_backend()
+ backend=default_backend(),
).encryptor()
p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20)
@@ -115,14 +121,17 @@ class PacketizerTest (unittest.TestCase):
import signal
class TimeoutError(Exception):
+
def __init__(self, error_message):
- if hasattr(errno, 'ETIME'):
+ if hasattr(errno, "ETIME"):
self.message = os.sterror(errno.ETIME)
else:
self.messaage = error_message
- def timeout(seconds=1, error_message='Timer expired'):
+ def timeout(seconds=1, error_message="Timer expired"):
+
def decorator(func):
+
def _handle_timeout(signum, frame):
raise TimeoutError(error_message)
@@ -138,5 +147,6 @@ class PacketizerTest (unittest.TestCase):
return wraps(func)(wrapper)
return decorator
+
send = timeout()(p.send_message)
self.assertRaises(EOFError, send, m)
diff --git a/tests/test_pkey.py b/tests/test_pkey.py
index 1827d2a9..4bbfaba1 100644
--- a/tests/test_pkey.py
+++ b/tests/test_pkey.py
@@ -34,18 +34,30 @@ from .util import _support
# from openssh's ssh-keygen
-PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c='
-PUB_DSS = 'ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE='
-PUB_ECDSA_256 = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo='
-PUB_ECDSA_384 = 'ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A=='
-PUB_ECDSA_521 = 'ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA=='
-
-FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5'
-FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c'
-FINGER_ECDSA_256 = '256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60'
-FINGER_ECDSA_384 = '384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65'
-FINGER_ECDSA_521 = '521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e'
-SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8'
+PUB_RSA = (
+ "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c="
+)
+PUB_DSS = (
+ "ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE="
+)
+PUB_ECDSA_256 = (
+ "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo="
+)
+PUB_ECDSA_384 = (
+ "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A=="
+)
+PUB_ECDSA_521 = (
+ "ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA=="
+)
+
+FINGER_RSA = "1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5"
+FINGER_DSS = "1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c"
+FINGER_ECDSA_256 = "256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60"
+FINGER_ECDSA_384 = "384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65"
+FINGER_ECDSA_521 = "521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e"
+SIGNED_RSA = (
+ "20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8"
+)
RSA_PRIVATE_OUT = """\
-----BEGIN RSA PRIVATE KEY-----
@@ -107,10 +119,14 @@ L4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==
-----END EC PRIVATE KEY-----
"""
-x1234 = b'\x01\x02\x03\x04'
+x1234 = b"\x01\x02\x03\x04"
-TEST_KEY_BYTESTR_2 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87'
-TEST_KEY_BYTESTR_3 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f'
+TEST_KEY_BYTESTR_2 = (
+ "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87"
+)
+TEST_KEY_BYTESTR_3 = (
+ "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f"
+)
class KeyTest(unittest.TestCase):
@@ -127,21 +143,22 @@ class KeyTest(unittest.TestCase):
"""
with open(keyfile, "r") as fh:
self.assertEqual(
- fh.readline()[:-1],
- "-----BEGIN RSA PRIVATE KEY-----"
+ fh.readline()[:-1], "-----BEGIN RSA PRIVATE KEY-----"
)
self.assertEqual(fh.readline()[:-1], "Proc-Type: 4,ENCRYPTED")
self.assertEqual(fh.readline()[0:10], "DEK-Info: ")
def test_1_generate_key_bytes(self):
- key = util.generate_key_bytes(md5, x1234, 'happy birthday', 30)
- exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64'
+ key = util.generate_key_bytes(md5, x1234, "happy birthday", 30)
+ exp = (
+ b"\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64"
+ )
self.assertEqual(exp, key)
def test_2_load_rsa(self):
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
- self.assertEqual('ssh-rsa', key.get_name())
- exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ self.assertEqual("ssh-rsa", key.get_name())
+ exp_rsa = b(FINGER_RSA.split()[1].replace(":", ""))
my_rsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_rsa, my_rsa)
self.assertEqual(PUB_RSA.split()[1], key.get_base64())
@@ -155,18 +172,20 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_3_load_rsa_password(self):
- key = RSAKey.from_private_key_file(_support('test_rsa_password.key'), 'television')
- self.assertEqual('ssh-rsa', key.get_name())
- exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
+ key = RSAKey.from_private_key_file(
+ _support("test_rsa_password.key"), "television"
+ )
+ self.assertEqual("ssh-rsa", key.get_name())
+ exp_rsa = b(FINGER_RSA.split()[1].replace(":", ""))
my_rsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_rsa, my_rsa)
self.assertEqual(PUB_RSA.split()[1], key.get_base64())
self.assertEqual(1024, key.get_bits())
def test_4_load_dss(self):
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
- self.assertEqual('ssh-dss', key.get_name())
- exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ self.assertEqual("ssh-dss", key.get_name())
+ exp_dss = b(FINGER_DSS.split()[1].replace(":", ""))
my_dss = hexlify(key.get_fingerprint())
self.assertEqual(exp_dss, my_dss)
self.assertEqual(PUB_DSS.split()[1], key.get_base64())
@@ -180,9 +199,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_5_load_dss_password(self):
- key = DSSKey.from_private_key_file(_support('test_dss_password.key'), 'television')
- self.assertEqual('ssh-dss', key.get_name())
- exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
+ key = DSSKey.from_private_key_file(
+ _support("test_dss_password.key"), "television"
+ )
+ self.assertEqual("ssh-dss", key.get_name())
+ exp_dss = b(FINGER_DSS.split()[1].replace(":", ""))
my_dss = hexlify(key.get_fingerprint())
self.assertEqual(exp_dss, my_dss)
self.assertEqual(PUB_DSS.split()[1], key.get_base64())
@@ -190,7 +211,7 @@ class KeyTest(unittest.TestCase):
def test_6_compare_rsa(self):
# verify that the private & public keys compare equal
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
self.assertEqual(key, key)
pub = RSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -199,7 +220,7 @@ class KeyTest(unittest.TestCase):
def test_7_compare_dss(self):
# verify that the private & public keys compare equal
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
self.assertEqual(key, key)
pub = DSSKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -208,77 +229,79 @@ class KeyTest(unittest.TestCase):
def test_8_sign_rsa(self):
# verify that the rsa private key can sign and verify
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ssh-rsa', msg.get_text())
- sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
+ self.assertEqual("ssh-rsa", msg.get_text())
+ sig = bytes().join(
+ [byte_chr(int(x, 16)) for x in SIGNED_RSA.split(":")]
+ )
self.assertEqual(sig, msg.get_binary())
msg.rewind()
pub = RSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_9_sign_dss(self):
# verify that the dss private key can sign and verify
- key = DSSKey.from_private_key_file(_support('test_dss.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ssh-dss', msg.get_text())
+ self.assertEqual("ssh-dss", msg.get_text())
# can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification
# anyway so it's ok.
self.assertEqual(40, len(msg.get_binary()))
msg.rewind()
pub = DSSKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_A_generate_rsa(self):
key = RSAKey.generate(1024)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
def test_B_generate_dss(self):
key = DSSKey.generate(1024)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
def test_C_generate_ecdsa(self):
key = ECDSAKey.generate()
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 256)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256")
key = ECDSAKey.generate(bits=256)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 256)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256")
key = ECDSAKey.generate(bits=384)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 384)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp384')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp384")
key = ECDSAKey.generate(bits=521)
- msg = key.sign_ssh_data(b'jerri blank')
+ msg = key.sign_ssh_data(b"jerri blank")
msg.rewind()
- self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg))
self.assertEqual(key.get_bits(), 521)
- self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp521')
+ self.assertEqual(key.get_name(), "ecdsa-sha2-nistp521")
def test_10_load_ecdsa_256(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key'))
- self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key"))
+ self.assertEqual("ecdsa-sha2-nistp256", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64())
@@ -292,9 +315,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_11_load_ecdsa_password_256(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_256.key'), b'television')
- self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(
+ _support("test_ecdsa_password_256.key"), b"television"
+ )
+ self.assertEqual("ecdsa-sha2-nistp256", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64())
@@ -302,7 +327,7 @@ class KeyTest(unittest.TestCase):
def test_12_compare_ecdsa_256(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key'))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key"))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -311,11 +336,11 @@ class KeyTest(unittest.TestCase):
def test_13_sign_ecdsa_256(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ecdsa-sha2-nistp256', msg.get_text())
+ self.assertEqual("ecdsa-sha2-nistp256", msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
@@ -323,12 +348,12 @@ class KeyTest(unittest.TestCase):
msg.rewind()
pub = ECDSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_14_load_ecdsa_384(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key'))
- self.assertEqual('ecdsa-sha2-nistp384', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key"))
+ self.assertEqual("ecdsa-sha2-nistp384", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64())
@@ -342,9 +367,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_15_load_ecdsa_password_384(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_384.key'), b'television')
- self.assertEqual('ecdsa-sha2-nistp384', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(
+ _support("test_ecdsa_password_384.key"), b"television"
+ )
+ self.assertEqual("ecdsa-sha2-nistp384", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64())
@@ -352,7 +379,7 @@ class KeyTest(unittest.TestCase):
def test_16_compare_ecdsa_384(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key'))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key"))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -361,11 +388,11 @@ class KeyTest(unittest.TestCase):
def test_17_sign_ecdsa_384(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ecdsa-sha2-nistp384', msg.get_text())
+ self.assertEqual("ecdsa-sha2-nistp384", msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
@@ -373,12 +400,12 @@ class KeyTest(unittest.TestCase):
msg.rewind()
pub = ECDSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_18_load_ecdsa_521(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key'))
- self.assertEqual('ecdsa-sha2-nistp521', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key"))
+ self.assertEqual("ecdsa-sha2-nistp521", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64())
@@ -395,9 +422,11 @@ class KeyTest(unittest.TestCase):
self.assertEqual(key, key2)
def test_19_load_ecdsa_password_521(self):
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_521.key'), b'television')
- self.assertEqual('ecdsa-sha2-nistp521', key.get_name())
- exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', ''))
+ key = ECDSAKey.from_private_key_file(
+ _support("test_ecdsa_password_521.key"), b"television"
+ )
+ self.assertEqual("ecdsa-sha2-nistp521", key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", ""))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64())
@@ -405,7 +434,7 @@ class KeyTest(unittest.TestCase):
def test_20_compare_ecdsa_521(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key'))
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key"))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -414,11 +443,11 @@ class KeyTest(unittest.TestCase):
def test_21_sign_ecdsa_521(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key'))
- msg = key.sign_ssh_data(b'ice weasels')
+ key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key"))
+ msg = key.sign_ssh_data(b"ice weasels")
self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEqual('ecdsa-sha2-nistp521', msg.get_text())
+ self.assertEqual("ecdsa-sha2-nistp521", msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
@@ -426,14 +455,14 @@ class KeyTest(unittest.TestCase):
msg.rewind()
pub = ECDSAKey(data=key.asbytes())
- self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
+ self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg))
def test_salt_size(self):
# Read an existing encrypted private key
- file_ = _support('test_rsa_password.key')
- password = 'television'
- newfile = file_ + '.new'
- newpassword = 'radio'
+ file_ = _support("test_rsa_password.key")
+ password = "television"
+ newfile = file_ + ".new"
+ newpassword = "radio"
key = RSAKey(filename=file_, password=password)
# Write out a newly re-encrypted copy with a new password.
# When the bug under test exists, this will ValueError.
@@ -447,20 +476,20 @@ class KeyTest(unittest.TestCase):
os.remove(newfile)
def test_stringification(self):
- key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ key = RSAKey.from_private_key_file(_support("test_rsa.key"))
comparable = TEST_KEY_BYTESTR_2 if PY2 else TEST_KEY_BYTESTR_3
self.assertEqual(str(key), comparable)
def test_ed25519(self):
- key1 = Ed25519Key.from_private_key_file(_support('test_ed25519.key'))
+ key1 = Ed25519Key.from_private_key_file(_support("test_ed25519.key"))
key2 = Ed25519Key.from_private_key_file(
- _support('test_ed25519_password.key'), b'abc123'
+ _support("test_ed25519_password.key"), b"abc123"
)
self.assertNotEqual(key1.asbytes(), key2.asbytes())
def test_ed25519_compare(self):
# verify that the private & public keys compare equal
- key = Ed25519Key.from_private_key_file(_support('test_ed25519.key'))
+ key = Ed25519Key.from_private_key_file(_support("test_ed25519.key"))
self.assertEqual(key, key)
pub = Ed25519Key(data=key.asbytes())
self.assertTrue(key.can_sign())
@@ -470,25 +499,25 @@ class KeyTest(unittest.TestCase):
def test_ed25519_nonbytes_password(self):
# https://github.com/paramiko/paramiko/issues/1039
key = Ed25519Key.from_private_key_file(
- _support('test_ed25519_password.key'),
+ _support("test_ed25519_password.key"),
# NOTE: not a bytes. Amusingly, the test above for same key DOES
# explicitly cast to bytes...code smell!
- 'abc123',
+ "abc123",
)
# No exception -> it's good. Meh.
def test_ed25519_load_from_file_obj(self):
- with open(_support('test_ed25519.key')) as pkey_fileobj:
+ with open(_support("test_ed25519.key")) as pkey_fileobj:
key = Ed25519Key.from_private_key(pkey_fileobj)
self.assertEqual(key, key)
self.assertTrue(key.can_sign())
def test_keyfile_is_actually_encrypted(self):
# Read an existing encrypted private key
- file_ = _support('test_rsa_password.key')
- password = 'television'
- newfile = file_ + '.new'
- newpassword = 'radio'
+ file_ = _support("test_rsa_password.key")
+ password = "television"
+ newfile = file_ + ".new"
+ newpassword = "radio"
key = RSAKey(filename=file_, password=password)
# Write out a newly re-encrypted copy with a new password.
# When the bug under test exists, this will ValueError.
@@ -503,19 +532,21 @@ class KeyTest(unittest.TestCase):
# test_client.py; this and nearby cert tests are more about the gritty
# details.
# PKey.load_certificate
- key_path = _support(os.path.join('cert_support', 'test_rsa.key'))
+ key_path = _support(os.path.join("cert_support", "test_rsa.key"))
key = RSAKey.from_private_key_file(key_path)
self.assertTrue(key.public_blob is None)
cert_path = _support(
- os.path.join('cert_support', 'test_rsa.key-cert.pub')
+ os.path.join("cert_support", "test_rsa.key-cert.pub")
)
key.load_certificate(cert_path)
self.assertTrue(key.public_blob is not None)
- self.assertEqual(key.public_blob.key_type, 'ssh-rsa-cert-v01@openssh.com')
- self.assertEqual(key.public_blob.comment, 'test_rsa.key.pub')
+ self.assertEqual(
+ key.public_blob.key_type, "ssh-rsa-cert-v01@openssh.com"
+ )
+ self.assertEqual(key.public_blob.comment, "test_rsa.key.pub")
# Delve into blob contents, for test purposes
msg = Message(key.public_blob.key_blob)
- self.assertEqual(msg.get_text(), 'ssh-rsa-cert-v01@openssh.com')
+ self.assertEqual(msg.get_text(), "ssh-rsa-cert-v01@openssh.com")
nonce = msg.get_string()
e = msg.get_mpint()
n = msg.get_mpint()
@@ -525,10 +556,10 @@ class KeyTest(unittest.TestCase):
self.assertEqual(msg.get_int64(), 1234)
# Prevented from loading certificate that doesn't match
- key_path = _support(os.path.join('cert_support', 'test_ed25519.key'))
+ key_path = _support(os.path.join("cert_support", "test_ed25519.key"))
key1 = Ed25519Key.from_private_key_file(key_path)
self.assertRaises(
ValueError,
key1.load_certificate,
- _support('test_rsa.key-cert.pub'),
+ _support("test_rsa.key-cert.pub"),
)
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index 09a50453..576b69b7 100644
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -45,7 +45,7 @@ from .stub_sftp import StubServer, StubSFTPServer
from .util import _support, slow
-ARTICLE = '''
+ARTICLE = """
Insulin sensitivity and liver insulin receptor structure in ducks from two
genera
@@ -70,7 +70,7 @@ receptors. Therefore the ducks from the two genera exhibit an alpha-beta-
structure for liver insulin receptors and a clear difference in the number of
liver insulin receptors. Their sensitivity to insulin is, however, similarly
decreased compared with chicken.
-'''
+"""
# Here is how unicode characters are encoded over 1 to 6 bytes in utf-8
@@ -82,32 +82,33 @@ decreased compared with chicken.
# U-04000000 - U-7FFFFFFF: 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx
# Note that: hex(int('11000011',2)) == '0xc3'
# Thus, the following 2-bytes sequence is not valid utf8: "invalid continuation byte"
-NON_UTF8_DATA = b'\xC3\xC3'
+NON_UTF8_DATA = b"\xC3\xC3"
-unicode_folder = u'\u00fcnic\u00f8de' if PY2 else '\u00fcnic\u00f8de'
-utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65'
+unicode_folder = u"\u00fcnic\u00f8de" if PY2 else "\u00fcnic\u00f8de"
+utf8_folder = b"/\xc3\xbcnic\xc3\xb8\x64\x65"
@slow
class TestSFTP(object):
+
def test_1_file(self, sftp):
"""
verify that we can create a file.
"""
- f = sftp.open(sftp.FOLDER + '/test', 'w')
+ f = sftp.open(sftp.FOLDER + "/test", "w")
try:
assert f.stat().st_size == 0
finally:
f.close()
- sftp.remove(sftp.FOLDER + '/test')
+ sftp.remove(sftp.FOLDER + "/test")
def test_2_close(self, sftp):
"""
Verify that SFTP session close() causes a socket error on next action.
"""
sftp.close()
- with pytest.raises(socket.error, match='Socket is closed'):
- sftp.open(sftp.FOLDER + '/test2', 'w')
+ with pytest.raises(socket.error, match="Socket is closed"):
+ sftp.open(sftp.FOLDER + "/test2", "w")
def test_2_sftp_can_be_used_as_context_manager(self, sftp):
"""
@@ -115,117 +116,117 @@ class TestSFTP(object):
"""
with sftp:
pass
- with pytest.raises(socket.error, match='Socket is closed'):
- sftp.open(sftp.FOLDER + '/test2', 'w')
+ with pytest.raises(socket.error, match="Socket is closed"):
+ sftp.open(sftp.FOLDER + "/test2", "w")
def test_3_write(self, sftp):
"""
verify that a file can be created and written, and the size is correct.
"""
try:
- with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f:
f.write(ARTICLE)
- assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483
+ assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
def test_3_sftp_file_can_be_used_as_context_manager(self, sftp):
"""
verify that an opened file can be used as a context manager
"""
try:
- with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f:
f.write(ARTICLE)
- assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483
+ assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
def test_4_append(self, sftp):
"""
verify that a file can be opened for append, and tell() still works.
"""
try:
- with sftp.open(sftp.FOLDER + '/append.txt', 'w') as f:
- f.write('first line\nsecond line\n')
+ with sftp.open(sftp.FOLDER + "/append.txt", "w") as f:
+ f.write("first line\nsecond line\n")
assert f.tell() == 23
- with sftp.open(sftp.FOLDER + '/append.txt', 'a+') as f:
- f.write('third line!!!\n')
+ with sftp.open(sftp.FOLDER + "/append.txt", "a+") as f:
+ f.write("third line!!!\n")
assert f.tell() == 37
assert f.stat().st_size == 37
f.seek(-26, f.SEEK_CUR)
- assert f.readline() == 'second line\n'
+ assert f.readline() == "second line\n"
finally:
- sftp.remove(sftp.FOLDER + '/append.txt')
+ sftp.remove(sftp.FOLDER + "/append.txt")
def test_5_rename(self, sftp):
"""
verify that renaming a file works.
"""
try:
- with sftp.open(sftp.FOLDER + '/first.txt', 'w') as f:
- f.write('content!\n')
- sftp.rename(sftp.FOLDER + '/first.txt', sftp.FOLDER + '/second.txt')
- with pytest.raises(IOError, match='No such file'):
- sftp.open(sftp.FOLDER + '/first.txt', 'r')
- with sftp.open(sftp.FOLDER + '/second.txt', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/first.txt", "w") as f:
+ f.write("content!\n")
+ sftp.rename(
+ sftp.FOLDER + "/first.txt", sftp.FOLDER + "/second.txt"
+ )
+ with pytest.raises(IOError, match="No such file"):
+ sftp.open(sftp.FOLDER + "/first.txt", "r")
+ with sftp.open(sftp.FOLDER + "/second.txt", "r") as f:
f.seek(-6, f.SEEK_END)
- assert u(f.read(4)) == 'tent'
+ assert u(f.read(4)) == "tent"
finally:
# TODO: this is gross, make some sort of 'remove if possible' / 'rm
# -f' a-like, jeez
try:
- sftp.remove(sftp.FOLDER + '/first.txt')
+ sftp.remove(sftp.FOLDER + "/first.txt")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/second.txt')
+ sftp.remove(sftp.FOLDER + "/second.txt")
except:
pass
-
def test_5a_posix_rename(self, sftp):
"""Test posix-rename@openssh.com protocol extension."""
try:
# first check that the normal rename works as specified
- with sftp.open(sftp.FOLDER + '/a', 'w') as f:
- f.write('one')
- sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b')
- with sftp.open(sftp.FOLDER + '/a', 'w') as f:
- f.write('two')
- with pytest.raises(IOError): # actual message seems generic
- sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b')
+ with sftp.open(sftp.FOLDER + "/a", "w") as f:
+ f.write("one")
+ sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b")
+ with sftp.open(sftp.FOLDER + "/a", "w") as f:
+ f.write("two")
+ with pytest.raises(IOError): # actual message seems generic
+ sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b")
# now check with the posix_rename
- sftp.posix_rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b')
- with sftp.open(sftp.FOLDER + '/b', 'r') as f:
+ sftp.posix_rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b")
+ with sftp.open(sftp.FOLDER + "/b", "r") as f:
data = u(f.read())
err = "Contents of renamed file not the same as original file"
- assert 'two' == data, err
+ assert "two" == data, err
finally:
try:
- sftp.remove(sftp.FOLDER + '/a')
+ sftp.remove(sftp.FOLDER + "/a")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/b')
+ sftp.remove(sftp.FOLDER + "/b")
except:
pass
-
def test_6_folder(self, sftp):
"""
create a temporary folder, verify that we can create a file in it, then
remove the folder and verify that we can't create a file in it anymore.
"""
- sftp.mkdir(sftp.FOLDER + '/subfolder')
- sftp.open(sftp.FOLDER + '/subfolder/test', 'w').close()
- sftp.remove(sftp.FOLDER + '/subfolder/test')
- sftp.rmdir(sftp.FOLDER + '/subfolder')
+ sftp.mkdir(sftp.FOLDER + "/subfolder")
+ sftp.open(sftp.FOLDER + "/subfolder/test", "w").close()
+ sftp.remove(sftp.FOLDER + "/subfolder/test")
+ sftp.rmdir(sftp.FOLDER + "/subfolder")
# shouldn't be able to create that file if dir removed
with pytest.raises(IOError, match="No such file"):
- sftp.open(sftp.FOLDER + '/subfolder/test')
+ sftp.open(sftp.FOLDER + "/subfolder/test")
def test_7_listdir(self, sftp):
"""
@@ -233,57 +234,57 @@ class TestSFTP(object):
it, and those files show up in sftp.listdir.
"""
try:
- sftp.open(sftp.FOLDER + '/duck.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/fish.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close()
+ sftp.open(sftp.FOLDER + "/duck.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/fish.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/tertiary.py", "w").close()
x = sftp.listdir(sftp.FOLDER)
assert len(x) == 3
- assert 'duck.txt' in x
- assert 'fish.txt' in x
- assert 'tertiary.py' in x
- assert 'random' not in x
+ assert "duck.txt" in x
+ assert "fish.txt" in x
+ assert "tertiary.py" in x
+ assert "random" not in x
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
- sftp.remove(sftp.FOLDER + '/fish.txt')
- sftp.remove(sftp.FOLDER + '/tertiary.py')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
+ sftp.remove(sftp.FOLDER + "/fish.txt")
+ sftp.remove(sftp.FOLDER + "/tertiary.py")
def test_7_5_listdir_iter(self, sftp):
"""
listdir_iter version of above test
"""
try:
- sftp.open(sftp.FOLDER + '/duck.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/fish.txt', 'w').close()
- sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close()
+ sftp.open(sftp.FOLDER + "/duck.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/fish.txt", "w").close()
+ sftp.open(sftp.FOLDER + "/tertiary.py", "w").close()
x = [x.filename for x in sftp.listdir_iter(sftp.FOLDER)]
assert len(x) == 3
- assert 'duck.txt' in x
- assert 'fish.txt' in x
- assert 'tertiary.py' in x
- assert 'random' not in x
+ assert "duck.txt" in x
+ assert "fish.txt" in x
+ assert "tertiary.py" in x
+ assert "random" not in x
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
- sftp.remove(sftp.FOLDER + '/fish.txt')
- sftp.remove(sftp.FOLDER + '/tertiary.py')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
+ sftp.remove(sftp.FOLDER + "/fish.txt")
+ sftp.remove(sftp.FOLDER + "/tertiary.py")
def test_8_setstat(self, sftp):
"""
verify that the setstat functions (chown, chmod, utime, truncate) work.
"""
try:
- with sftp.open(sftp.FOLDER + '/special', 'w') as f:
- f.write('x' * 1024)
+ with sftp.open(sftp.FOLDER + "/special", "w") as f:
+ f.write("x" * 1024)
- stat = sftp.stat(sftp.FOLDER + '/special')
- sftp.chmod(sftp.FOLDER + '/special', (stat.st_mode & ~o777) | o600)
- stat = sftp.stat(sftp.FOLDER + '/special')
+ stat = sftp.stat(sftp.FOLDER + "/special")
+ sftp.chmod(sftp.FOLDER + "/special", (stat.st_mode & ~o777) | o600)
+ stat = sftp.stat(sftp.FOLDER + "/special")
expected_mode = o600
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# chmod not really functional on windows
expected_mode = o666
- if sys.platform == 'cygwin':
+ if sys.platform == "cygwin":
# even worse.
expected_mode = o644
assert stat.st_mode & o777 == expected_mode
@@ -291,19 +292,19 @@ class TestSFTP(object):
mtime = stat.st_mtime - 3600
atime = stat.st_atime - 1800
- sftp.utime(sftp.FOLDER + '/special', (atime, mtime))
- stat = sftp.stat(sftp.FOLDER + '/special')
+ sftp.utime(sftp.FOLDER + "/special", (atime, mtime))
+ stat = sftp.stat(sftp.FOLDER + "/special")
assert stat.st_mtime == mtime
- if sys.platform not in ('win32', 'cygwin'):
+ if sys.platform not in ("win32", "cygwin"):
assert stat.st_atime == atime
# can't really test chown, since we'd have to know a valid uid.
- sftp.truncate(sftp.FOLDER + '/special', 512)
- stat = sftp.stat(sftp.FOLDER + '/special')
+ sftp.truncate(sftp.FOLDER + "/special", 512)
+ stat = sftp.stat(sftp.FOLDER + "/special")
assert stat.st_size == 512
finally:
- sftp.remove(sftp.FOLDER + '/special')
+ sftp.remove(sftp.FOLDER + "/special")
def test_9_fsetstat(self, sftp):
"""
@@ -311,19 +312,19 @@ class TestSFTP(object):
work on open files.
"""
try:
- with sftp.open(sftp.FOLDER + '/special', 'w') as f:
- f.write('x' * 1024)
+ with sftp.open(sftp.FOLDER + "/special", "w") as f:
+ f.write("x" * 1024)
- with sftp.open(sftp.FOLDER + '/special', 'r+') as f:
+ with sftp.open(sftp.FOLDER + "/special", "r+") as f:
stat = f.stat()
f.chmod((stat.st_mode & ~o777) | o600)
stat = f.stat()
expected_mode = o600
- if sys.platform == 'win32':
+ if sys.platform == "win32":
# chmod not really functional on windows
expected_mode = o666
- if sys.platform == 'cygwin':
+ if sys.platform == "cygwin":
# even worse.
expected_mode = o644
assert stat.st_mode & o777 == expected_mode
@@ -334,7 +335,7 @@ class TestSFTP(object):
f.utime((atime, mtime))
stat = f.stat()
assert stat.st_mtime == mtime
- if sys.platform not in ('win32', 'cygwin'):
+ if sys.platform not in ("win32", "cygwin"):
assert stat.st_atime == atime
# can't really test chown, since we'd have to know a valid uid.
@@ -343,7 +344,7 @@ class TestSFTP(object):
stat = f.stat()
assert stat.st_size == 512
finally:
- sftp.remove(sftp.FOLDER + '/special')
+ sftp.remove(sftp.FOLDER + "/special")
def test_A_readline_seek(self, sftp):
"""
@@ -353,10 +354,10 @@ class TestSFTP(object):
buffering is reset on 'seek'.
"""
try:
- with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f:
f.write(ARTICLE)
- with sftp.open(sftp.FOLDER + '/duck.txt', 'r+') as f:
+ with sftp.open(sftp.FOLDER + "/duck.txt", "r+") as f:
line_number = 0
loc = 0
pos_list = []
@@ -366,13 +367,16 @@ class TestSFTP(object):
loc = f.tell()
assert f.seekable()
f.seek(pos_list[6], f.SEEK_SET)
- assert f.readline(), 'Nouzilly == France.\n'
+ assert f.readline(), "Nouzilly == France.\n"
f.seek(pos_list[17], f.SEEK_SET)
- assert f.readline()[:4] == 'duck'
+ assert f.readline()[:4] == "duck"
f.seek(pos_list[10], f.SEEK_SET)
- assert f.readline() == 'duck types were equally resistant to exogenous insulin compared with chicken.\n'
+ assert (
+ f.readline()
+ == "duck types were equally resistant to exogenous insulin compared with chicken.\n"
+ )
finally:
- sftp.remove(sftp.FOLDER + '/duck.txt')
+ sftp.remove(sftp.FOLDER + "/duck.txt")
def test_B_write_seek(self, sftp):
"""
@@ -380,17 +384,17 @@ class TestSFTP(object):
changes worked.
"""
try:
- with sftp.open(sftp.FOLDER + '/testing.txt', 'w') as f:
- f.write('hello kitty.\n')
+ with sftp.open(sftp.FOLDER + "/testing.txt", "w") as f:
+ f.write("hello kitty.\n")
f.seek(-5, f.SEEK_CUR)
- f.write('dd')
+ f.write("dd")
- assert sftp.stat(sftp.FOLDER + '/testing.txt').st_size == 13
- with sftp.open(sftp.FOLDER + '/testing.txt', 'r') as f:
+ assert sftp.stat(sftp.FOLDER + "/testing.txt").st_size == 13
+ with sftp.open(sftp.FOLDER + "/testing.txt", "r") as f:
data = f.read(20)
- assert data == b'hello kiddy.\n'
+ assert data == b"hello kiddy.\n"
finally:
- sftp.remove(sftp.FOLDER + '/testing.txt')
+ sftp.remove(sftp.FOLDER + "/testing.txt")
def test_C_symlink(self, sftp):
"""
@@ -401,39 +405,41 @@ class TestSFTP(object):
return
try:
- with sftp.open(sftp.FOLDER + '/original.txt', 'w') as f:
- f.write('original\n')
- sftp.symlink('original.txt', sftp.FOLDER + '/link.txt')
- assert sftp.readlink(sftp.FOLDER + '/link.txt') == 'original.txt'
+ with sftp.open(sftp.FOLDER + "/original.txt", "w") as f:
+ f.write("original\n")
+ sftp.symlink("original.txt", sftp.FOLDER + "/link.txt")
+ assert sftp.readlink(sftp.FOLDER + "/link.txt") == "original.txt"
- with sftp.open(sftp.FOLDER + '/link.txt', 'r') as f:
- assert f.readlines() == ['original\n']
+ with sftp.open(sftp.FOLDER + "/link.txt", "r") as f:
+ assert f.readlines() == ["original\n"]
- cwd = sftp.normalize('.')
- if cwd[-1] == '/':
+ cwd = sftp.normalize(".")
+ if cwd[-1] == "/":
cwd = cwd[:-1]
- abs_path = cwd + '/' + sftp.FOLDER + '/original.txt'
- sftp.symlink(abs_path, sftp.FOLDER + '/link2.txt')
- assert abs_path == sftp.readlink(sftp.FOLDER + '/link2.txt')
+ abs_path = cwd + "/" + sftp.FOLDER + "/original.txt"
+ sftp.symlink(abs_path, sftp.FOLDER + "/link2.txt")
+ assert abs_path == sftp.readlink(sftp.FOLDER + "/link2.txt")
- assert sftp.lstat(sftp.FOLDER + '/link.txt').st_size == 12
- assert sftp.stat(sftp.FOLDER + '/link.txt').st_size == 9
+ assert sftp.lstat(sftp.FOLDER + "/link.txt").st_size == 12
+ assert sftp.stat(sftp.FOLDER + "/link.txt").st_size == 9
# the sftp server may be hiding extra path members from us, so the
# length may be longer than we expect:
- assert sftp.lstat(sftp.FOLDER + '/link2.txt').st_size >= len(abs_path)
- assert sftp.stat(sftp.FOLDER + '/link2.txt').st_size == 9
- assert sftp.stat(sftp.FOLDER + '/original.txt').st_size == 9
+ assert sftp.lstat(sftp.FOLDER + "/link2.txt").st_size >= len(
+ abs_path
+ )
+ assert sftp.stat(sftp.FOLDER + "/link2.txt").st_size == 9
+ assert sftp.stat(sftp.FOLDER + "/original.txt").st_size == 9
finally:
try:
- sftp.remove(sftp.FOLDER + '/link.txt')
+ sftp.remove(sftp.FOLDER + "/link.txt")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/link2.txt')
+ sftp.remove(sftp.FOLDER + "/link2.txt")
except:
pass
try:
- sftp.remove(sftp.FOLDER + '/original.txt')
+ sftp.remove(sftp.FOLDER + "/original.txt")
except:
pass
@@ -442,18 +448,18 @@ class TestSFTP(object):
verify that buffered writes are automatically flushed on seek.
"""
try:
- with sftp.open(sftp.FOLDER + '/happy.txt', 'w', 1) as f:
- f.write('full line.\n')
- f.write('partial')
+ with sftp.open(sftp.FOLDER + "/happy.txt", "w", 1) as f:
+ f.write("full line.\n")
+ f.write("partial")
f.seek(9, f.SEEK_SET)
- f.write('?\n')
+ f.write("?\n")
- with sftp.open(sftp.FOLDER + '/happy.txt', 'r') as f:
- assert f.readline() == u('full line?\n')
- assert f.read(7) == b'partial'
+ with sftp.open(sftp.FOLDER + "/happy.txt", "r") as f:
+ assert f.readline() == u("full line?\n")
+ assert f.read(7) == b"partial"
finally:
try:
- sftp.remove(sftp.FOLDER + '/happy.txt')
+ sftp.remove(sftp.FOLDER + "/happy.txt")
except:
pass
@@ -462,9 +468,9 @@ class TestSFTP(object):
test that realpath is returning something non-empty and not an
error.
"""
- pwd = sftp.normalize('.')
+ pwd = sftp.normalize(".")
assert len(pwd) > 0
- f = sftp.normalize('./' + sftp.FOLDER)
+ f = sftp.normalize("./" + sftp.FOLDER)
assert len(f) > 0
assert os.path.join(pwd, sftp.FOLDER) == f
@@ -472,46 +478,46 @@ class TestSFTP(object):
"""
verify that mkdir/rmdir work.
"""
- sftp.mkdir(sftp.FOLDER + '/subfolder')
- with pytest.raises(IOError): # generic msg only
- sftp.mkdir(sftp.FOLDER + '/subfolder')
- sftp.rmdir(sftp.FOLDER + '/subfolder')
+ sftp.mkdir(sftp.FOLDER + "/subfolder")
+ with pytest.raises(IOError): # generic msg only
+ sftp.mkdir(sftp.FOLDER + "/subfolder")
+ sftp.rmdir(sftp.FOLDER + "/subfolder")
with pytest.raises(IOError, match="No such file"):
- sftp.rmdir(sftp.FOLDER + '/subfolder')
+ sftp.rmdir(sftp.FOLDER + "/subfolder")
def test_G_chdir(self, sftp):
"""
verify that chdir/getcwd work.
"""
- root = sftp.normalize('.')
- if root[-1] != '/':
- root += '/'
+ root = sftp.normalize(".")
+ if root[-1] != "/":
+ root += "/"
try:
- sftp.mkdir(sftp.FOLDER + '/alpha')
- sftp.chdir(sftp.FOLDER + '/alpha')
- sftp.mkdir('beta')
- assert root + sftp.FOLDER + '/alpha' == sftp.getcwd()
- assert ['beta'] == sftp.listdir('.')
-
- sftp.chdir('beta')
- with sftp.open('fish', 'w') as f:
- f.write('hello\n')
- sftp.chdir('..')
- assert ['fish'] == sftp.listdir('beta')
- sftp.chdir('..')
- assert ['fish'] == sftp.listdir('alpha/beta')
+ sftp.mkdir(sftp.FOLDER + "/alpha")
+ sftp.chdir(sftp.FOLDER + "/alpha")
+ sftp.mkdir("beta")
+ assert root + sftp.FOLDER + "/alpha" == sftp.getcwd()
+ assert ["beta"] == sftp.listdir(".")
+
+ sftp.chdir("beta")
+ with sftp.open("fish", "w") as f:
+ f.write("hello\n")
+ sftp.chdir("..")
+ assert ["fish"] == sftp.listdir("beta")
+ sftp.chdir("..")
+ assert ["fish"] == sftp.listdir("alpha/beta")
finally:
sftp.chdir(root)
try:
- sftp.unlink(sftp.FOLDER + '/alpha/beta/fish')
+ sftp.unlink(sftp.FOLDER + "/alpha/beta/fish")
except:
pass
try:
- sftp.rmdir(sftp.FOLDER + '/alpha/beta')
+ sftp.rmdir(sftp.FOLDER + "/alpha/beta")
except:
pass
try:
- sftp.rmdir(sftp.FOLDER + '/alpha')
+ sftp.rmdir(sftp.FOLDER + "/alpha")
except:
pass
@@ -519,20 +525,21 @@ class TestSFTP(object):
"""
verify that get/put work.
"""
- warnings.filterwarnings('ignore', 'tempnam.*')
+ warnings.filterwarnings("ignore", "tempnam.*")
fd, localname = mkstemp()
os.close(fd)
- text = b'All I wanted was a plastic bunny rabbit.\n'
- with open(localname, 'wb') as f:
+ text = b"All I wanted was a plastic bunny rabbit.\n"
+ with open(localname, "wb") as f:
f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
- sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback)
- with sftp.open(sftp.FOLDER + '/bunny.txt', 'rb') as f:
+ sftp.put(localname, sftp.FOLDER + "/bunny.txt", progress_callback)
+
+ with sftp.open(sftp.FOLDER + "/bunny.txt", "rb") as f:
assert text == f.read(128)
assert [(41, 41)] == saved_progress
@@ -540,14 +547,14 @@ class TestSFTP(object):
fd, localname = mkstemp()
os.close(fd)
saved_progress = []
- sftp.get(sftp.FOLDER + '/bunny.txt', localname, progress_callback)
+ sftp.get(sftp.FOLDER + "/bunny.txt", localname, progress_callback)
- with open(localname, 'rb') as f:
+ with open(localname, "rb") as f:
assert text == f.read(128)
assert [(41, 41)] == saved_progress
os.unlink(localname)
- sftp.unlink(sftp.FOLDER + '/bunny.txt')
+ sftp.unlink(sftp.FOLDER + "/bunny.txt")
def test_I_check(self, sftp):
"""
@@ -555,118 +562,132 @@ class TestSFTP(object):
(it's an sftp extension that we support, and may be the only ones who
support it.)
"""
- with sftp.open(sftp.FOLDER + '/kitty.txt', 'w') as f:
- f.write('here kitty kitty' * 64)
+ with sftp.open(sftp.FOLDER + "/kitty.txt", "w") as f:
+ f.write("here kitty kitty" * 64)
try:
- with sftp.open(sftp.FOLDER + '/kitty.txt', 'r') as f:
- sum = f.check('sha1')
- assert '91059CFC6615941378D413CB5ADAF4C5EB293402' == u(hexlify(sum)).upper()
- sum = f.check('md5', 0, 512)
- assert '93DE4788FCA28D471516963A1FE3856A' == u(hexlify(sum)).upper()
- sum = f.check('md5', 0, 0, 510)
- assert u(hexlify(sum)).upper() == 'EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6' # noqa
+ with sftp.open(sftp.FOLDER + "/kitty.txt", "r") as f:
+ sum = f.check("sha1")
+ assert (
+ "91059CFC6615941378D413CB5ADAF4C5EB293402"
+ == u(hexlify(sum)).upper()
+ )
+ sum = f.check("md5", 0, 512)
+ assert (
+ "93DE4788FCA28D471516963A1FE3856A"
+ == u(hexlify(sum)).upper()
+ )
+ sum = f.check("md5", 0, 0, 510)
+ assert (
+ u(hexlify(sum)).upper()
+ == "EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6"
+ ) # noqa
finally:
- sftp.unlink(sftp.FOLDER + '/kitty.txt')
+ sftp.unlink(sftp.FOLDER + "/kitty.txt")
def test_J_x_flag(self, sftp):
"""
verify that the 'x' flag works when opening a file.
"""
- sftp.open(sftp.FOLDER + '/unusual.txt', 'wx').close()
+ sftp.open(sftp.FOLDER + "/unusual.txt", "wx").close()
try:
try:
- sftp.open(sftp.FOLDER + '/unusual.txt', 'wx')
- self.fail('expected exception')
+ sftp.open(sftp.FOLDER + "/unusual.txt", "wx")
+ self.fail("expected exception")
except IOError:
pass
finally:
- sftp.unlink(sftp.FOLDER + '/unusual.txt')
+ sftp.unlink(sftp.FOLDER + "/unusual.txt")
def test_K_utf8(self, sftp):
"""
verify that unicode strings are encoded into utf8 correctly.
"""
- with sftp.open(sftp.FOLDER + '/something', 'w') as f:
- f.write('okay')
+ with sftp.open(sftp.FOLDER + "/something", "w") as f:
+ f.write("okay")
try:
- sftp.rename(sftp.FOLDER + '/something', sftp.FOLDER + '/' + unicode_folder)
- sftp.open(b(sftp.FOLDER) + utf8_folder, 'r')
+ sftp.rename(
+ sftp.FOLDER + "/something", sftp.FOLDER + "/" + unicode_folder
+ )
+ sftp.open(b(sftp.FOLDER) + utf8_folder, "r")
except Exception as e:
- self.fail('exception ' + str(e))
+ self.fail("exception " + str(e))
sftp.unlink(b(sftp.FOLDER) + utf8_folder)
def test_L_utf8_chdir(self, sftp):
- sftp.mkdir(sftp.FOLDER + '/' + unicode_folder)
+ sftp.mkdir(sftp.FOLDER + "/" + unicode_folder)
try:
- sftp.chdir(sftp.FOLDER + '/' + unicode_folder)
- with sftp.open('something', 'w') as f:
- f.write('okay')
- sftp.unlink('something')
+ sftp.chdir(sftp.FOLDER + "/" + unicode_folder)
+ with sftp.open("something", "w") as f:
+ f.write("okay")
+ sftp.unlink("something")
finally:
sftp.chdir()
- sftp.rmdir(sftp.FOLDER + '/' + unicode_folder)
+ sftp.rmdir(sftp.FOLDER + "/" + unicode_folder)
def test_M_bad_readv(self, sftp):
"""
verify that readv at the end of the file doesn't essplode.
"""
- sftp.open(sftp.FOLDER + '/zero', 'w').close()
+ sftp.open(sftp.FOLDER + "/zero", "w").close()
try:
- with sftp.open(sftp.FOLDER + '/zero', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/zero", "r") as f:
f.readv([(0, 12)])
- with sftp.open(sftp.FOLDER + '/zero', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/zero", "r") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
f.read(100)
finally:
- sftp.unlink(sftp.FOLDER + '/zero')
+ sftp.unlink(sftp.FOLDER + "/zero")
def test_N_put_without_confirm(self, sftp):
"""
verify that get/put work without confirmation.
"""
- warnings.filterwarnings('ignore', 'tempnam.*')
+ warnings.filterwarnings("ignore", "tempnam.*")
fd, localname = mkstemp()
os.close(fd)
- text = b'All I wanted was a plastic bunny rabbit.\n'
- with open(localname, 'wb') as f:
+ text = b"All I wanted was a plastic bunny rabbit.\n"
+ with open(localname, "wb") as f:
f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
- res = sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback, False)
+
+ res = sftp.put(
+ localname, sftp.FOLDER + "/bunny.txt", progress_callback, False
+ )
assert SFTPAttributes().attr == res.attr
- with sftp.open(sftp.FOLDER + '/bunny.txt', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/bunny.txt", "r") as f:
assert text == f.read(128)
assert (41, 41) == saved_progress[-1]
os.unlink(localname)
- sftp.unlink(sftp.FOLDER + '/bunny.txt')
+ sftp.unlink(sftp.FOLDER + "/bunny.txt")
def test_O_getcwd(self, sftp):
"""
verify that chdir/getcwd work.
"""
assert sftp.getcwd() == None
- root = sftp.normalize('.')
- if root[-1] != '/':
- root += '/'
+ root = sftp.normalize(".")
+ if root[-1] != "/":
+ root += "/"
try:
- sftp.mkdir(sftp.FOLDER + '/alpha')
- sftp.chdir(sftp.FOLDER + '/alpha')
- assert sftp.getcwd() == '/' + sftp.FOLDER + '/alpha'
+ sftp.mkdir(sftp.FOLDER + "/alpha")
+ sftp.chdir(sftp.FOLDER + "/alpha")
+ assert sftp.getcwd() == "/" + sftp.FOLDER + "/alpha"
finally:
sftp.chdir(root)
try:
- sftp.rmdir(sftp.FOLDER + '/alpha')
+ sftp.rmdir(sftp.FOLDER + "/alpha")
except:
pass
@@ -677,24 +698,24 @@ class TestSFTP(object):
does not work except through paramiko. :( openssh fails.
"""
try:
- with sftp.open(sftp.FOLDER + '/append.txt', 'a') as f:
- f.write('first line\nsecond line\n')
+ with sftp.open(sftp.FOLDER + "/append.txt", "a") as f:
+ f.write("first line\nsecond line\n")
f.seek(11, f.SEEK_SET)
- f.write('third line\n')
+ f.write("third line\n")
- with sftp.open(sftp.FOLDER + '/append.txt', 'r') as f:
+ with sftp.open(sftp.FOLDER + "/append.txt", "r") as f:
assert f.stat().st_size == 34
- assert f.readline() == 'first line\n'
- assert f.readline() == 'second line\n'
- assert f.readline() == 'third line\n'
+ assert f.readline() == "first line\n"
+ assert f.readline() == "second line\n"
+ assert f.readline() == "third line\n"
finally:
- sftp.remove(sftp.FOLDER + '/append.txt')
+ sftp.remove(sftp.FOLDER + "/append.txt")
def test_putfo_empty_file(self, sftp):
"""
Send an empty file and confirm it is sent.
"""
- target = sftp.FOLDER + '/empty file.txt'
+ target = sftp.FOLDER + "/empty file.txt"
stream = StringIO()
try:
attrs = sftp.putfo(stream, target)
@@ -713,59 +734,61 @@ class TestSFTP(object):
verify that we can create a file with a '%' in the filename.
( it needs to be properly escaped by _log() )
"""
- f = sftp.open(sftp.FOLDER + '/test%file', 'w')
+ f = sftp.open(sftp.FOLDER + "/test%file", "w")
try:
assert f.stat().st_size == 0
finally:
f.close()
- sftp.remove(sftp.FOLDER + '/test%file')
+ sftp.remove(sftp.FOLDER + "/test%file")
def test_O_non_utf8_data(self, sftp):
"""Test write() and read() of non utf8 data"""
try:
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'w') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "w") as f:
f.write(NON_UTF8_DATA)
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "r") as f:
data = f.read()
assert data == NON_UTF8_DATA
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "wb") as f:
f.write(NON_UTF8_DATA)
- with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/nonutf8data" % sftp.FOLDER, "rb") as f:
data = f.read()
assert data == NON_UTF8_DATA
finally:
- sftp.remove('%s/nonutf8data' % sftp.FOLDER)
-
+ sftp.remove("%s/nonutf8data" % sftp.FOLDER)
def test_sftp_attributes_empty_str(self, sftp):
sftp_attributes = SFTPAttributes()
- assert str(sftp_attributes) == "?--------- 1 0 0 0 (unknown date) ?"
+ assert (
+ str(sftp_attributes)
+ == "?--------- 1 0 0 0 (unknown date) ?"
+ )
- @needs_builtin('buffer')
+ @needs_builtin("buffer")
def test_write_buffer(self, sftp):
"""Test write() using a buffer instance."""
- data = 3 * b'A potentially large block of data to chunk up.\n'
+ data = 3 * b"A potentially large block of data to chunk up.\n"
try:
- with sftp.open('%s/write_buffer' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/write_buffer" % sftp.FOLDER, "wb") as f:
for offset in range(0, len(data), 8):
f.write(buffer(data, offset, 8))
- with sftp.open('%s/write_buffer' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/write_buffer" % sftp.FOLDER, "rb") as f:
assert f.read() == data
finally:
- sftp.remove('%s/write_buffer' % sftp.FOLDER)
+ sftp.remove("%s/write_buffer" % sftp.FOLDER)
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_write_memoryview(self, sftp):
"""Test write() using a memoryview instance."""
- data = 3 * b'A potentially large block of data to chunk up.\n'
+ data = 3 * b"A potentially large block of data to chunk up.\n"
try:
- with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/write_memoryview" % sftp.FOLDER, "wb") as f:
view = memoryview(data)
for offset in range(0, len(data), 8):
- f.write(view[offset:offset+8])
+ f.write(view[offset : offset + 8])
- with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/write_memoryview" % sftp.FOLDER, "rb") as f:
assert f.read() == data
finally:
- sftp.remove('%s/write_memoryview' % sftp.FOLDER)
+ sftp.remove("%s/write_memoryview" % sftp.FOLDER)
diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py
index a659098d..97c0eb90 100644
--- a/tests/test_sftp_big.py
+++ b/tests/test_sftp_big.py
@@ -37,6 +37,7 @@ from .util import slow
@slow
class TestBigSFTP(object):
+
def test_1_lots_of_files(self, sftp):
"""
create a bunch of files over the same session.
@@ -44,22 +45,24 @@ class TestBigSFTP(object):
numfiles = 100
try:
for i in range(numfiles):
- with sftp.open('%s/file%d.txt' % (sftp.FOLDER, i), 'w', 1) as f:
- f.write('this is file #%d.\n' % i)
- sftp.chmod('%s/file%d.txt' % (sftp.FOLDER, i), o660)
+ with sftp.open(
+ "%s/file%d.txt" % (sftp.FOLDER, i), "w", 1
+ ) as f:
+ f.write("this is file #%d.\n" % i)
+ sftp.chmod("%s/file%d.txt" % (sftp.FOLDER, i), o660)
# now make sure every file is there, by creating a list of filenmes
# and reading them in random order.
numlist = list(range(numfiles))
while len(numlist) > 0:
r = numlist[random.randint(0, len(numlist) - 1)]
- with sftp.open('%s/file%d.txt' % (sftp.FOLDER, r)) as f:
- assert f.readline() == 'this is file #%d.\n' % r
+ with sftp.open("%s/file%d.txt" % (sftp.FOLDER, r)) as f:
+ assert f.readline() == "this is file #%d.\n" % r
numlist.remove(r)
finally:
for i in range(numfiles):
try:
- sftp.remove('%s/file%d.txt' % (sftp.FOLDER, i))
+ sftp.remove("%s/file%d.txt" % (sftp.FOLDER, i))
except:
pass
@@ -67,52 +70,56 @@ class TestBigSFTP(object):
"""
write a 1MB file with no buffering.
"""
- kblob = (1024 * b'x')
+ kblob = 1024 * b"x"
start = time.time()
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f:
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
-
+ sys.stderr.write("%ds " % round(end - start))
+
start = time.time()
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f:
for n in range(1024):
data = f.read(1024)
assert data == kblob
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_3_big_file_pipelined(self, sftp):
"""
write a 1MB file, with no linefeeds, using pipelining.
"""
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
start = time.time()
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
-
+ sys.stderr.write("%ds " % round(end - start))
+
start = time.time()
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
@@ -126,35 +133,39 @@ class TestBigSFTP(object):
chunk = size - n
data = f.read(chunk)
offset = n % 1024
- assert data == k2blob[offset:offset + chunk]
+ assert data == k2blob[offset : offset + chunk]
n += chunk
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_4_prefetch_seek(self, sftp):
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
-
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
-
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
+
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
+
start = time.time()
k2blob = kblob + kblob
chunk = 793
for i in range(10):
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
- base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
+ base_offset = (512 * 1024) + 17 * random.randint(
+ 1000, 2000
+ )
offsets = [base_offset + j * chunk for j in range(100)]
# randomly seek around and read them out
for j in range(100):
@@ -163,32 +174,36 @@ class TestBigSFTP(object):
f.seek(offset)
data = f.read(chunk)
n_offset = offset % 1024
- assert data == k2blob[n_offset:n_offset + chunk]
+ assert data == k2blob[n_offset : n_offset + chunk]
offset += chunk
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_5_readv_seek(self, sftp):
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
start = time.time()
k2blob = kblob + kblob
chunk = 793
for i in range(10):
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
- base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
+ base_offset = (512 * 1024) + 17 * random.randint(
+ 1000, 2000
+ )
# make a bunch of offsets and put them in random order
offsets = [base_offset + j * chunk for j in range(100)]
readv_list = []
@@ -200,62 +215,66 @@ class TestBigSFTP(object):
for i in range(len(readv_list)):
offset = readv_list[i][0]
n_offset = offset % 1024
- assert next(ret) == k2blob[n_offset:n_offset + chunk]
+ assert next(ret) == k2blob[n_offset : n_offset + chunk]
end = time.time()
- sys.stderr.write('%ds ' % round(end - start))
+ sys.stderr.write("%ds " % round(end - start))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
def test_6_lots_of_prefetching(self, sftp):
"""
prefetch a 1MB file a bunch of times, discarding the file object
without using it, to verify that paramiko doesn't get confused.
"""
- kblob = (1024 * b'x')
+ kblob = 1024 * b"x"
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
for i in range(10):
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
for n in range(1024):
data = f.read(1024)
assert data == kblob
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_7_prefetch_readv(self, sftp):
"""
verify that prefetch and readv don't conflict with each other.
"""
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
-
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
+
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
file_size = f.stat().st_size
f.prefetch(file_size)
data = f.read(1024)
@@ -264,79 +283,94 @@ class TestBigSFTP(object):
chunk_size = 793
base_offset = 512 * 1024
k2blob = kblob + kblob
- chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
+ chunks = [
+ (base_offset + (chunk_size * i), chunk_size)
+ for i in range(20)
+ ]
for data in f.readv(chunks):
offset = base_offset % 1024
assert chunk_size == len(data)
- assert k2blob[offset:offset + chunk_size] == data
+ assert k2blob[offset : offset + chunk_size] == data
base_offset += chunk_size
- sys.stderr.write(' ')
+ sys.stderr.write(" ")
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_8_large_readv(self, sftp):
"""
verify that a very large readv is broken up correctly and still
returned as a single blob.
"""
- kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
+ kblob = bytes().join([struct.pack(">H", n) for n in range(512)])
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
- sys.stderr.write('.')
- sys.stderr.write(' ')
+ sys.stderr.write(".")
+ sys.stderr.write(" ")
+
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
-
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f:
+ with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f:
data = list(f.readv([(23 * 1024, 128 * 1024)]))
assert len(data) == 1
data = data[0]
assert len(data) == 128 * 1024
-
- sys.stderr.write(' ')
+
+ sys.stderr.write(" ")
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_9_big_file_big_buffer(self, sftp):
"""
write a 1MB file, with no linefeeds, and a big buffer.
"""
- mblob = (1024 * 1024 * 'x')
+ mblob = 1024 * 1024 * "x"
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f:
+ with sftp.open(
+ "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024
+ ) as f:
f.write(mblob)
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
-
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
+
def test_A_big_file_renegotiate(self, sftp):
"""
write a 1MB file, forcing key renegotiation in the middle.
"""
t = sftp.sock.get_transport()
t.packetizer.REKEY_BYTES = 512 * 1024
- k32blob = (32 * 1024 * 'x')
+ k32blob = 32 * 1024 * "x"
try:
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f:
+ with sftp.open(
+ "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024
+ ) as f:
for i in range(32):
f.write(k32blob)
- assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024
+ assert (
+ sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024
+ )
assert t.H != t.session_id
-
+
# try to read it too.
- with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r', 128 * 1024) as f:
+ with sftp.open(
+ "%s/hongry.txt" % sftp.FOLDER, "r", 128 * 1024
+ ) as f:
file_size = f.stat().st_size
f.prefetch(file_size)
total = 0
while total < 1024 * 1024:
total += len(f.read(32 * 1024))
finally:
- sftp.remove('%s/hongry.txt' % sftp.FOLDER)
+ sftp.remove("%s/hongry.txt" % sftp.FOLDER)
t.packetizer.REKEY_BYTES = pow(2, 30)
diff --git a/tests/test_ssh_exception.py b/tests/test_ssh_exception.py
index 18f2a97d..6cc5d06a 100644
--- a/tests/test_ssh_exception.py
+++ b/tests/test_ssh_exception.py
@@ -4,28 +4,33 @@ import unittest
from paramiko.ssh_exception import NoValidConnectionsError
-class NoValidConnectionsErrorTest (unittest.TestCase):
+class NoValidConnectionsErrorTest(unittest.TestCase):
def test_pickling(self):
# Regression test for https://github.com/paramiko/paramiko/issues/617
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()})
+ exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()})
new_exc = pickle.loads(pickle.dumps(exc))
self.assertEqual(type(exc), type(new_exc))
self.assertEqual(str(exc), str(new_exc))
self.assertEqual(exc.args, new_exc.args)
def test_error_message_for_single_host(self):
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()})
+ exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()})
assert "Unable to connect to port 22 on 127.0.0.1" in str(exc)
def test_error_message_for_two_hosts(self):
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(),
- ('::1', '22'): Exception()})
+ exc = NoValidConnectionsError(
+ {("127.0.0.1", "22"): Exception(), ("::1", "22"): Exception()}
+ )
assert "Unable to connect to port 22 on 127.0.0.1 or ::1" in str(exc)
def test_error_message_for_multiple_hosts(self):
- exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(),
- ('::1', '22'): Exception(),
- ('10.0.0.42', '22'): Exception()})
+ exc = NoValidConnectionsError(
+ {
+ ("127.0.0.1", "22"): Exception(),
+ ("::1", "22"): Exception(),
+ ("10.0.0.42", "22"): Exception(),
+ }
+ )
exp = "Unable to connect to port 22 on 10.0.0.42, 127.0.0.1 or ::1"
assert exp in str(exc)
diff --git a/tests/test_ssh_gss.py b/tests/test_ssh_gss.py
index f0645e0e..cee6ce89 100644
--- a/tests/test_ssh_gss.py
+++ b/tests/test_ssh_gss.py
@@ -33,15 +33,13 @@ from .util import _support, needs_gssapi
from .test_client import FINGERPRINTS
-class NullServer (paramiko.ServerInterface):
+class NullServer(paramiko.ServerInterface):
+
def get_allowed_auths(self, username):
- return 'gssapi-with-mic,publickey'
+ return "gssapi-with-mic,publickey"
def check_auth_gssapi_with_mic(
- self,
- username,
- gss_authenticated=paramiko.AUTH_FAILED,
- cc_file=None,
+ self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None
):
if gss_authenticated == paramiko.AUTH_SUCCESSFUL:
return paramiko.AUTH_SUCCESSFUL
@@ -64,13 +62,14 @@ class NullServer (paramiko.ServerInterface):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != 'yes':
+ if command != "yes":
return False
return True
@needs_gssapi
class GSSAuthTest(unittest.TestCase):
+
def setUp(self):
# TODO: username and targ_name should come from os.environ or whatever
# the approved pytest method is for runtime-configuring test data.
@@ -92,7 +91,7 @@ class GSSAuthTest(unittest.TestCase):
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks)
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
self.ts.add_server_key(host_key)
server = NullServer()
self.ts.start_server(self.event, server)
@@ -103,15 +102,22 @@ class GSSAuthTest(unittest.TestCase):
The exception is ... no exception yet
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key")
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.WarningPolicy())
- self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port),
- 'ssh-rsa', public_host_key)
- self.tc.connect(hostname=self.addr, port=self.port, username=self.username, gss_host=self.hostname,
- gss_auth=True, **kwargs)
+ self.tc.get_host_keys().add(
+ "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key
+ )
+ self.tc.connect(
+ hostname=self.addr,
+ port=self.port,
+ username=self.username,
+ gss_host=self.hostname,
+ gss_auth=True,
+ **kwargs
+ )
self.event.wait(1.0)
self.assert_(self.event.is_set())
@@ -119,17 +125,17 @@ class GSSAuthTest(unittest.TestCase):
self.assertEquals(self.username, self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
- stdin, stdout, stderr = self.tc.exec_command('yes')
+ stdin, stdout, stderr = self.tc.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
- self.assertEquals('Hello there.\n', stdout.readline())
- self.assertEquals('', stdout.readline())
- self.assertEquals('This is on stderr.\n', stderr.readline())
- self.assertEquals('', stderr.readline())
+ self.assertEquals("Hello there.\n", stdout.readline())
+ self.assertEquals("", stdout.readline())
+ self.assertEquals("This is on stderr.\n", stderr.readline())
+ self.assertEquals("", stderr.readline())
stdin.close()
stdout.close()
@@ -140,14 +146,17 @@ class GSSAuthTest(unittest.TestCase):
Verify that Paramiko can handle SSHv2 GSS-API / SSPI authentication
(gssapi-with-mic) in client and server mode.
"""
- self._test_connection(allow_agent=False,
- look_for_keys=False)
+ self._test_connection(allow_agent=False, look_for_keys=False)
def test_2_auth_trickledown(self):
"""
Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding
"""
- self.hostname = "this_host_does_not_exists_and_causes_a_GSSAPI-exception"
- self._test_connection(key_filename=[_support('test_rsa.key')],
- allow_agent=False,
- look_for_keys=False)
+ self.hostname = (
+ "this_host_does_not_exists_and_causes_a_GSSAPI-exception"
+ )
+ self._test_connection(
+ key_filename=[_support("test_rsa.key")],
+ allow_agent=False,
+ look_for_keys=False,
+ )
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 9474acfc..c05d6781 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -32,14 +32,26 @@ from hashlib import sha1
import unittest
from paramiko import (
- Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, SSHException,
- ChannelException, Packetizer, Channel,
+ Transport,
+ SecurityOptions,
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ SSHException,
+ ChannelException,
+ Packetizer,
+ Channel,
)
from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import (
- MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST, MIN_PACKET_SIZE, MIN_WINDOW_SIZE,
- MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE,
+ MSG_KEXINIT,
+ cMSG_CHANNEL_WINDOW_ADJUST,
+ MIN_PACKET_SIZE,
+ MIN_WINDOW_SIZE,
+ MAX_WINDOW_SIZE,
+ DEFAULT_WINDOW_SIZE,
+ DEFAULT_MAX_PACKET_SIZE,
)
from paramiko.py3compat import bytes
from paramiko.message import Message
@@ -61,28 +73,28 @@ Maybe.
"""
-class NullServer (ServerInterface):
+class NullServer(ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key'))
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- return 'publickey'
+ if username == "slowdive":
+ return "publickey,password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return AUTH_SUCCESSFUL
return AUTH_FAILED
def check_channel_request(self, kind, chanid):
- if kind == 'bogus':
+ if kind == "bogus":
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
- if command != b'yes':
+ if command != b"yes":
return False
return True
@@ -95,9 +107,16 @@ class NullServer (ServerInterface):
# tho that's only supposed to occur if the request cannot be served.
# For now, leaving that the default unless test supplies specific
# 'acceptable' request kind
- return kind == 'acceptable'
-
- def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number):
+ return kind == "acceptable"
+
+ def check_channel_x11_request(
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
self._x11_single_connection = single_connection
self._x11_auth_protocol = auth_protocol
self._x11_auth_cookie = auth_cookie
@@ -106,7 +125,7 @@ class NullServer (ServerInterface):
def check_port_forward_request(self, addr, port):
self._listen = socket.socket()
- self._listen.bind(('127.0.0.1', 0))
+ self._listen.bind(("127.0.0.1", 0))
self._listen.listen(1)
return self._listen.getsockname()[1]
@@ -120,6 +139,7 @@ class NullServer (ServerInterface):
class TransportTest(unittest.TestCase):
+
def setUp(self):
self.socks = LoopSocket()
self.sockc = LoopSocket()
@@ -134,9 +154,9 @@ class TransportTest(unittest.TestCase):
self.sockc.close()
def setup_test_server(
- self, client_options=None, server_options=None, connect_kwargs=None,
+ self, client_options=None, server_options=None, connect_kwargs=None
):
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
@@ -152,8 +172,8 @@ class TransportTest(unittest.TestCase):
if connect_kwargs is None:
connect_kwargs = dict(
hostkey=public_host_key,
- username='slowdive',
- password='pygmalion',
+ username="slowdive",
+ password="pygmalion",
)
self.tc.connect(**connect_kwargs)
event.wait(1.0)
@@ -163,11 +183,11 @@ class TransportTest(unittest.TestCase):
def test_1_security_options(self):
o = self.tc.get_security_options()
self.assertEqual(type(o), SecurityOptions)
- self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
- o.ciphers = ('aes256-cbc', 'blowfish-cbc')
- self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
+ self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers)
+ o.ciphers = ("aes256-cbc", "blowfish-cbc")
+ self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers)
try:
- o.ciphers = ('aes256-cbc', 'made-up-cipher')
+ o.ciphers = ("aes256-cbc", "made-up-cipher")
self.assertTrue(False)
except ValueError:
pass
@@ -187,12 +207,18 @@ class TransportTest(unittest.TestCase):
o.compression = o.compression
def test_2_compute_key(self):
- self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
- self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
+ self.tc.K = (
+ 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
+ )
+ self.tc.H = (
+ b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3"
+ )
self.tc.session_id = self.tc.H
- key = self.tc._compute_key('C', 32)
- self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
- hexlify(key).upper())
+ key = self.tc._compute_key("C", 32)
+ self.assertEqual(
+ b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995",
+ hexlify(key).upper(),
+ )
def test_3_simple(self):
"""
@@ -200,7 +226,7 @@ class TransportTest(unittest.TestCase):
loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :)
"""
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -211,13 +237,14 @@ class TransportTest(unittest.TestCase):
self.assertEqual(False, self.tc.is_authenticated())
self.assertEqual(False, self.ts.is_authenticated())
self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
+ self.tc.connect(
+ hostkey=public_host_key, username="slowdive", password="pygmalion"
+ )
event.wait(1.0)
self.assertTrue(event.is_set())
self.assertTrue(self.ts.is_active())
- self.assertEqual('slowdive', self.tc.get_username())
- self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual("slowdive", self.tc.get_username())
+ self.assertEqual("slowdive", self.ts.get_username())
self.assertEqual(True, self.tc.is_authenticated())
self.assertEqual(True, self.ts.is_authenticated())
@@ -225,7 +252,7 @@ class TransportTest(unittest.TestCase):
"""
verify that a long banner doesn't mess up the handshake.
"""
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -233,8 +260,9 @@ class TransportTest(unittest.TestCase):
self.assertTrue(not event.is_set())
self.socks.send(LONG_BANNER)
self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
+ self.tc.connect(
+ hostkey=public_host_key, username="slowdive", password="pygmalion"
+ )
event.wait(1.0)
self.assertTrue(event.is_set())
self.assertTrue(self.ts.is_active())
@@ -244,12 +272,14 @@ class TransportTest(unittest.TestCase):
verify that the client can demand odd handshake settings, and can
renegotiate keys in mid-stream.
"""
+
def force_algorithms(options):
- options.ciphers = ('aes256-cbc',)
- options.digests = ('hmac-md5-96',)
+ options.ciphers = ("aes256-cbc",)
+ options.digests = ("hmac-md5-96",)
+
self.setup_test_server(client_options=force_algorithms)
- self.assertEqual('aes256-cbc', self.tc.local_cipher)
- self.assertEqual('aes256-cbc', self.tc.remote_cipher)
+ self.assertEqual("aes256-cbc", self.tc.local_cipher)
+ self.assertEqual("aes256-cbc", self.tc.remote_cipher)
self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
self.assertEqual(12, self.tc.packetizer.get_mac_size_in())
@@ -263,10 +293,10 @@ class TransportTest(unittest.TestCase):
verify that the keepalive will be sent.
"""
self.setup_test_server()
- self.assertEqual(None, getattr(self.server, '_global_request', None))
+ self.assertEqual(None, getattr(self.server, "_global_request", None))
self.tc.set_keepalive(1)
time.sleep(2)
- self.assertEqual('keepalive@lag.net', self.server._global_request)
+ self.assertEqual("keepalive@lag.net", self.server._global_request)
def test_6_exec_command(self):
"""
@@ -277,39 +307,41 @@ class TransportTest(unittest.TestCase):
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
try:
- chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string')
+ chan.exec_command(
+ b"command contains \xfc and is not a valid UTF-8 string"
+ )
self.assertTrue(False)
except SSHException:
pass
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("", f.readline())
f = chan.makefile_stderr()
- self.assertEqual('This is on stderr.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("This is on stderr.\n", f.readline())
+ self.assertEqual("", f.readline())
# now try it with combined stdout/stderr
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
- schan.send('Hello there.\n')
- schan.send_stderr('This is on stderr.\n')
+ schan.send("Hello there.\n")
+ schan.send_stderr("This is on stderr.\n")
schan.close()
chan.set_combine_stderr(True)
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('This is on stderr.\n', f.readline())
- self.assertEqual('', f.readline())
-
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("This is on stderr.\n", f.readline())
+ self.assertEqual("", f.readline())
+
def test_6a_channel_can_be_used_as_context_manager(self):
"""
verify that exec_command() does something reasonable.
@@ -318,13 +350,13 @@ class TransportTest(unittest.TestCase):
with self.tc.open_session() as chan:
with self.ts.accept(1.0) as schan:
- chan.exec_command('yes')
- schan.send('Hello there.\n')
+ chan.exec_command("yes")
+ schan.send("Hello there.\n")
schan.close()
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("", f.readline())
def test_7_invoke_shell(self):
"""
@@ -334,11 +366,11 @@ class TransportTest(unittest.TestCase):
chan = self.tc.open_session()
chan.invoke_shell()
schan = self.ts.accept(1.0)
- chan.send('communist j. cat\n')
+ chan.send("communist j. cat\n")
f = schan.makefile()
- self.assertEqual('communist j. cat\n', f.readline())
+ self.assertEqual("communist j. cat\n", f.readline())
chan.close()
- self.assertEqual('', f.readline())
+ self.assertEqual("", f.readline())
def test_8_channel_exception(self):
"""
@@ -346,8 +378,8 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
try:
- chan = self.tc.open_channel('bogus')
- self.fail('expected exception')
+ chan = self.tc.open_channel("bogus")
+ self.fail("expected exception")
except ChannelException as e:
self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
@@ -359,8 +391,8 @@ class TransportTest(unittest.TestCase):
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
- chan.exec_command('yes')
- schan.send('Hello there.\n')
+ chan.exec_command("yes")
+ schan.send("Hello there.\n")
self.assertTrue(not chan.exit_status_ready())
# trigger an EOF
schan.shutdown_read()
@@ -369,8 +401,8 @@ class TransportTest(unittest.TestCase):
schan.close()
f = chan.makefile()
- self.assertEqual('Hello there.\n', f.readline())
- self.assertEqual('', f.readline())
+ self.assertEqual("Hello there.\n", f.readline())
+ self.assertEqual("", f.readline())
count = 0
while not chan.exit_status_ready():
time.sleep(0.1)
@@ -395,7 +427,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- schan.send('hello\n')
+ schan.send("hello\n")
# something should be ready now (give it 1 second to appear)
for i in range(10):
@@ -407,7 +439,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- self.assertEqual(b'hello\n', chan.recv(6))
+ self.assertEqual(b"hello\n", chan.recv(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
@@ -442,12 +474,12 @@ class TransportTest(unittest.TestCase):
self.setup_test_server()
self.tc.packetizer.REKEY_BYTES = 16384
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
self.assertEqual(self.tc.H, self.tc.session_id)
for i in range(20):
- chan.send('x' * 1024)
+ chan.send("x" * 1024)
chan.close()
# allow a few seconds for the rekeying to complete
@@ -463,18 +495,20 @@ class TransportTest(unittest.TestCase):
"""
verify that zlib compression is basically working.
"""
+
def force_compression(o):
- o.compression = ('zlib',)
+ o.compression = ("zlib",)
+
self.setup_test_server(force_compression, force_compression)
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
bytes = self.tc.packetizer._Packetizer__sent_bytes
- chan.send('x' * 1024)
+ chan.send("x" * 1024)
bytes2 = self.tc.packetizer._Packetizer__sent_bytes
- block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size']
- mac_size = self.tc._mac_info[self.tc.local_mac]['size']
+ block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"]
+ mac_size = self.tc._mac_info[self.tc.local_mac]["size"]
# tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :)
self.assertTrue(bytes2 - bytes < 1024)
self.assertEqual(16 + block_size + mac_size, bytes2 - bytes)
@@ -488,29 +522,32 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
requested = []
+
def handler(c, addr_port):
addr, port = addr_port
requested.append((addr, port))
self.tc._queue_incoming_channel(c)
- self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
+ self.assertEqual(
+ None, getattr(self.server, "_x11_screen_number", None)
+ )
cookie = chan.request_x11(0, single_connection=True, handler=handler)
self.assertEqual(0, self.server._x11_screen_number)
- self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
+ self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol)
self.assertEqual(cookie, self.server._x11_auth_cookie)
self.assertEqual(True, self.server._x11_single_connection)
- x11_server = self.ts.open_x11_channel(('localhost', 6093))
+ x11_server = self.ts.open_x11_channel(("localhost", 6093))
x11_client = self.tc.accept()
- self.assertEqual('localhost', requested[0][0])
+ self.assertEqual("localhost", requested[0][0])
self.assertEqual(6093, requested[0][1])
- x11_server.send('hello')
- self.assertEqual(b'hello', x11_client.recv(5))
+ x11_server.send("hello")
+ self.assertEqual(b"hello", x11_client.recv(5))
x11_server.close()
x11_client.close()
@@ -524,33 +561,36 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
requested = []
+
def handler(c, origin_addr_port, server_addr_port):
requested.append(origin_addr_port)
requested.append(server_addr_port)
self.tc._queue_incoming_channel(c)
- port = self.tc.request_port_forward('127.0.0.1', 0, handler)
+ port = self.tc.request_port_forward("127.0.0.1", 0, handler)
self.assertEqual(port, self.server._listen.getsockname()[1])
cs = socket.socket()
- cs.connect(('127.0.0.1', port))
+ cs.connect(("127.0.0.1", port))
ss, _ = self.server._listen.accept()
- sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
+ sch = self.ts.open_forwarded_tcpip_channel(
+ ss.getsockname(), ss.getpeername()
+ )
cch = self.tc.accept()
- sch.send('hello')
- self.assertEqual(b'hello', cch.recv(5))
+ sch.send("hello")
+ self.assertEqual(b"hello", cch.recv(5))
sch.close()
cch.close()
ss.close()
cs.close()
# now cancel it.
- self.tc.cancel_port_forward('127.0.0.1', port)
+ self.tc.cancel_port_forward("127.0.0.1", port)
self.assertTrue(self.server._listen is None)
def test_F_port_forwarding(self):
@@ -560,27 +600,29 @@ class TransportTest(unittest.TestCase):
"""
self.setup_test_server()
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
# open a port on the "server" that the client will ask to forward to.
greeting_server = socket.socket()
- greeting_server.bind(('127.0.0.1', 0))
+ greeting_server.bind(("127.0.0.1", 0))
greeting_server.listen(1)
greeting_port = greeting_server.getsockname()[1]
- cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
+ cs = self.tc.open_channel(
+ "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000)
+ )
sch = self.ts.accept(1.0)
cch = socket.socket()
cch.connect(self.server._tcpip_dest)
ss, _ = greeting_server.accept()
- ss.send(b'Hello!\n')
+ ss.send(b"Hello!\n")
ss.close()
sch.send(cch.recv(8192))
sch.close()
- self.assertEqual(b'Hello!\n', cs.recv(7))
+ self.assertEqual(b"Hello!\n", cs.recv(7))
cs.close()
def test_G_stderr_select(self):
@@ -599,7 +641,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- schan.send_stderr('hello\n')
+ schan.send_stderr("hello\n")
# something should be ready now (give it 1 second to appear)
for i in range(10):
@@ -611,7 +653,7 @@ class TransportTest(unittest.TestCase):
self.assertEqual([], w)
self.assertEqual([], e)
- self.assertEqual(b'hello\n', chan.recv_stderr(6))
+ self.assertEqual(b"hello\n", chan.recv_stderr(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
@@ -633,8 +675,8 @@ class TransportTest(unittest.TestCase):
self.assertEqual(chan.send_ready(), True)
total = 0
- K = '*' * 1024
- limit = 1+(64 * 2 ** 15)
+ K = "*" * 1024
+ limit = 1 + (64 * 2 ** 15)
while total < limit:
chan.send(K)
total += len(K)
@@ -696,8 +738,11 @@ class TransportTest(unittest.TestCase):
# expires, a deadlock is assumed.
class SendThread(threading.Thread):
+
def __init__(self, chan, iterations, done_event):
- threading.Thread.__init__(self, None, None, self.__class__.__name__)
+ threading.Thread.__init__(
+ self, None, None, self.__class__.__name__
+ )
self.setDaemon(True)
self.chan = chan
self.iterations = iterations
@@ -707,19 +752,22 @@ class TransportTest(unittest.TestCase):
def run(self):
try:
- for i in range(1, 1+self.iterations):
+ for i in range(1, 1 + self.iterations):
if self.done_event.is_set():
break
self.watchdog_event.set()
- #print i, "SEND"
+ # print i, "SEND"
self.chan.send("x" * 2048)
finally:
self.done_event.set()
self.watchdog_event.set()
class ReceiveThread(threading.Thread):
+
def __init__(self, chan, done_event):
- threading.Thread.__init__(self, None, None, self.__class__.__name__)
+ threading.Thread.__init__(
+ self, None, None, self.__class__.__name__
+ )
self.setDaemon(True)
self.chan = chan
self.done_event = done_event
@@ -742,30 +790,34 @@ class TransportTest(unittest.TestCase):
self.ts.packetizer.REKEY_BYTES = 2048
chan = self.tc.open_session()
- chan.exec_command('yes')
+ chan.exec_command("yes")
schan = self.ts.accept(1.0)
# Monkey patch the client's Transport._handler_table so that the client
# sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
# MSG_KEXINIT. This is used to simulate the effect of network latency
# on a real MSG_CHANNEL_WINDOW_ADJUST message.
- self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary
+ self.tc._handler_table = (
+ self.tc._handler_table.copy()
+ ) # copy per-class dictionary
_negotiate_keys = self.tc._handler_table[MSG_KEXINIT]
+
def _negotiate_keys_wrapper(self, m):
- if self.local_kex_init is None: # Remote side sent KEXINIT
+ if self.local_kex_init is None: # Remote side sent KEXINIT
# Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
# before responding to the incoming MSG_KEXINIT.
m2 = Message()
m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m2.add_int(chan.remote_chanid)
- m2.add_int(1) # bytes to add
+ m2.add_int(1) # bytes to add
self._send_message(m2)
return _negotiate_keys(self, m)
+
self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper
# Parameters for the test
- iterations = 500 # The deadlock does not happen every time, but it
- # should after many iterations.
+ iterations = 500 # The deadlock does not happen every time, but it
+ # should after many iterations.
timeout = 5
# This event is set when the test is completed
@@ -807,18 +859,22 @@ class TransportTest(unittest.TestCase):
"""
verify that we conform to the rfc of packet and window sizes.
"""
- for val, correct in [(4095, MIN_PACKET_SIZE),
- (None, DEFAULT_MAX_PACKET_SIZE),
- (2**32, MAX_WINDOW_SIZE)]:
+ for val, correct in [
+ (4095, MIN_PACKET_SIZE),
+ (None, DEFAULT_MAX_PACKET_SIZE),
+ (2 ** 32, MAX_WINDOW_SIZE),
+ ]:
self.assertEqual(self.tc._sanitize_packet_size(val), correct)
def test_K_sanitze_window_size(self):
"""
verify that we conform to the rfc of packet and window sizes.
"""
- for val, correct in [(32767, MIN_WINDOW_SIZE),
- (None, DEFAULT_WINDOW_SIZE),
- (2**32, MAX_WINDOW_SIZE)]:
+ for val, correct in [
+ (32767, MIN_WINDOW_SIZE),
+ (None, DEFAULT_WINDOW_SIZE),
+ (2 ** 32, MAX_WINDOW_SIZE),
+ ]:
self.assertEqual(self.tc._sanitize_window_size(val), correct)
@slow
@@ -834,15 +890,17 @@ class TransportTest(unittest.TestCase):
# (Doing this on the server's transport *sounds* more 'correct' but
# actually doesn't work nearly as well for whatever reason.)
class SlowPacketizer(Packetizer):
+
def read_message(self):
time.sleep(1)
return super(SlowPacketizer, self).read_message()
+
# NOTE: prettttty sure since the replaced .packetizer Packetizer is now
# no longer doing anything with its copy of the socket...everything'll
# be fine. Even tho it's a bit squicky.
self.tc.packetizer = SlowPacketizer(self.tc.sock)
# Continue with regular test red tape.
- host_key = RSAKey.from_private_key_file(_support('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -850,10 +908,13 @@ class TransportTest(unittest.TestCase):
self.assertTrue(not event.is_set())
self.tc.handshake_timeout = 0.000000000001
self.ts.start_server(event, server)
- self.assertRaises(EOFError, self.tc.connect,
- hostkey=public_host_key,
- username='slowdive',
- password='pygmalion')
+ self.assertRaises(
+ EOFError,
+ self.tc.connect,
+ hostkey=public_host_key,
+ username="slowdive",
+ password="pygmalion",
+ )
def test_M_select_after_close(self):
"""
@@ -894,13 +955,13 @@ class TransportTest(unittest.TestCase):
expected = text.encode("utf-8")
self.assertEqual(sfile.read(len(expected)), expected)
- @needs_builtin('buffer')
+ @needs_builtin("buffer")
def test_channel_send_buffer(self):
"""
verify sending buffer instances to a channel
"""
self.setup_test_server()
- data = 3 * b'some test data\n whole'
+ data = 3 * b"some test data\n whole"
with self.tc.open_session() as chan:
schan = self.ts.accept(1.0)
if schan is None:
@@ -917,13 +978,13 @@ class TransportTest(unittest.TestCase):
chan.sendall(buffer(data))
self.assertEqual(sfile.read(len(data)), data)
- @needs_builtin('memoryview')
+ @needs_builtin("memoryview")
def test_channel_send_memoryview(self):
"""
verify sending memoryview instances to a channel
"""
self.setup_test_server()
- data = 3 * b'some test data\n whole'
+ data = 3 * b"some test data\n whole"
with self.tc.open_session() as chan:
schan = self.ts.accept(1.0)
if schan is None:
@@ -934,7 +995,7 @@ class TransportTest(unittest.TestCase):
sent = 0
view = memoryview(data)
while sent < len(view):
- sent += chan.send(view[sent:sent+8])
+ sent += chan.send(view[sent : sent + 8])
self.assertEqual(sfile.read(len(data)), data)
# sendall() accepts a memoryview instance
@@ -954,7 +1015,7 @@ class TransportTest(unittest.TestCase):
self.setup_test_server(connect_kwargs={})
# NOTE: this dummy global request kind would normally pass muster
# from the test server.
- self.tc.global_request('acceptable')
+ self.tc.global_request("acceptable")
# Global requests never raise exceptions, even on failure (not sure why
# this was the original design...ugh.) Best we can do to tell failure
# happened is that the client transport's global_response was set back
@@ -969,7 +1030,7 @@ class TransportTest(unittest.TestCase):
# an exception on the client side, unlike the general case...)
self.setup_test_server(connect_kwargs={})
try:
- self.tc.request_port_forward('localhost', 1234)
+ self.tc.request_port_forward("localhost", 1234)
except SSHException as e:
assert "forwarding request denied" in str(e)
else:
diff --git a/tests/test_util.py b/tests/test_util.py
index 90473f43..23b2e86a 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -67,53 +67,71 @@ from paramiko import *
class UtilTest(unittest.TestCase):
+
def test_import(self):
"""
verify that all the classes can be imported from paramiko.
"""
symbols = list(globals().keys())
- self.assertTrue('Transport' in symbols)
- self.assertTrue('SSHClient' in symbols)
- self.assertTrue('MissingHostKeyPolicy' in symbols)
- self.assertTrue('AutoAddPolicy' in symbols)
- self.assertTrue('RejectPolicy' in symbols)
- self.assertTrue('WarningPolicy' in symbols)
- self.assertTrue('SecurityOptions' in symbols)
- self.assertTrue('SubsystemHandler' in symbols)
- self.assertTrue('Channel' in symbols)
- self.assertTrue('RSAKey' in symbols)
- self.assertTrue('DSSKey' in symbols)
- self.assertTrue('Message' in symbols)
- self.assertTrue('SSHException' in symbols)
- self.assertTrue('AuthenticationException' in symbols)
- self.assertTrue('PasswordRequiredException' in symbols)
- self.assertTrue('BadAuthenticationType' in symbols)
- self.assertTrue('ChannelException' in symbols)
- self.assertTrue('SFTP' in symbols)
- self.assertTrue('SFTPFile' in symbols)
- self.assertTrue('SFTPHandle' in symbols)
- self.assertTrue('SFTPClient' in symbols)
- self.assertTrue('SFTPServer' in symbols)
- self.assertTrue('SFTPError' in symbols)
- self.assertTrue('SFTPAttributes' in symbols)
- self.assertTrue('SFTPServerInterface' in symbols)
- self.assertTrue('ServerInterface' in symbols)
- self.assertTrue('BufferedFile' in symbols)
- self.assertTrue('Agent' in symbols)
- self.assertTrue('AgentKey' in symbols)
- self.assertTrue('HostKeys' in symbols)
- self.assertTrue('SSHConfig' in symbols)
- self.assertTrue('util' in symbols)
+ self.assertTrue("Transport" in symbols)
+ self.assertTrue("SSHClient" in symbols)
+ self.assertTrue("MissingHostKeyPolicy" in symbols)
+ self.assertTrue("AutoAddPolicy" in symbols)
+ self.assertTrue("RejectPolicy" in symbols)
+ self.assertTrue("WarningPolicy" in symbols)
+ self.assertTrue("SecurityOptions" in symbols)
+ self.assertTrue("SubsystemHandler" in symbols)
+ self.assertTrue("Channel" in symbols)
+ self.assertTrue("RSAKey" in symbols)
+ self.assertTrue("DSSKey" in symbols)
+ self.assertTrue("Message" in symbols)
+ self.assertTrue("SSHException" in symbols)
+ self.assertTrue("AuthenticationException" in symbols)
+ self.assertTrue("PasswordRequiredException" in symbols)
+ self.assertTrue("BadAuthenticationType" in symbols)
+ self.assertTrue("ChannelException" in symbols)
+ self.assertTrue("SFTP" in symbols)
+ self.assertTrue("SFTPFile" in symbols)
+ self.assertTrue("SFTPHandle" in symbols)
+ self.assertTrue("SFTPClient" in symbols)
+ self.assertTrue("SFTPServer" in symbols)
+ self.assertTrue("SFTPError" in symbols)
+ self.assertTrue("SFTPAttributes" in symbols)
+ self.assertTrue("SFTPServerInterface" in symbols)
+ self.assertTrue("ServerInterface" in symbols)
+ self.assertTrue("BufferedFile" in symbols)
+ self.assertTrue("Agent" in symbols)
+ self.assertTrue("AgentKey" in symbols)
+ self.assertTrue("HostKeys" in symbols)
+ self.assertTrue("SSHConfig" in symbols)
+ self.assertTrue("util" in symbols)
def test_parse_config(self):
global test_config_file
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEqual(config._config,
- [{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}},
- {'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}},
- {'host': ['*'], 'config': {'crazy': 'something dumb'}},
- {'host': ['spoo.example.com'], 'config': {'crazy': 'something else'}}])
+ self.assertEqual(
+ config._config,
+ [
+ {"host": ["*"], "config": {}},
+ {
+ "host": ["*"],
+ "config": {
+ "identityfile": ["~/.ssh/id_rsa"],
+ "user": "robey",
+ },
+ },
+ {
+ "host": ["*.example.com"],
+ "config": {"user": "bjork", "port": "3333"},
+ },
+ {"host": ["*"], "config": {"crazy": "something dumb"}},
+ {
+ "host": ["spoo.example.com"],
+ "config": {"crazy": "something else"},
+ },
+ ],
+ )
def test_host_config(self):
global test_config_file
@@ -121,44 +139,57 @@ class UtilTest(unittest.TestCase):
config = paramiko.util.parse_ssh_config(f)
for host, values in {
- 'irc.danger.com': {'crazy': 'something dumb',
- 'hostname': 'irc.danger.com',
- 'user': 'robey'},
- 'irc.example.com': {'crazy': 'something dumb',
- 'hostname': 'irc.example.com',
- 'user': 'robey',
- 'port': '3333'},
- 'spoo.example.com': {'crazy': 'something dumb',
- 'hostname': 'spoo.example.com',
- 'user': 'robey',
- 'port': '3333'}
+ "irc.danger.com": {
+ "crazy": "something dumb",
+ "hostname": "irc.danger.com",
+ "user": "robey",
+ },
+ "irc.example.com": {
+ "crazy": "something dumb",
+ "hostname": "irc.example.com",
+ "user": "robey",
+ "port": "3333",
+ },
+ "spoo.example.com": {
+ "crazy": "something dumb",
+ "hostname": "spoo.example.com",
+ "user": "robey",
+ "port": "3333",
+ },
}.items():
- values = dict(values,
+ values = dict(
+ values,
hostname=host,
- identityfile=[os.path.expanduser("~/.ssh/id_rsa")]
+ identityfile=[os.path.expanduser("~/.ssh/id_rsa")],
)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_generate_key_bytes(self):
- x = paramiko.util.generate_key_bytes(sha1, b'ABCDEFGH', 'This is my secret passphrase.', 64)
- hex = ''.join(['%02x' % byte_ord(c) for c in x])
- self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
+ x = paramiko.util.generate_key_bytes(
+ sha1, b"ABCDEFGH", "This is my secret passphrase.", 64
+ )
+ hex = "".join(["%02x" % byte_ord(c) for c in x])
+ self.assertEqual(
+ hex,
+ "9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b",
+ )
def test_host_keys(self):
- with open('hostfile.temp', 'w') as f:
+ with open("hostfile.temp", "w") as f:
f.write(test_hosts_file)
try:
- hostdict = paramiko.util.load_host_keys('hostfile.temp')
+ hostdict = paramiko.util.load_host_keys("hostfile.temp")
self.assertEqual(2, len(hostdict))
self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
- fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ fp = hexlify(
+ hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint()
+ ).upper()
+ self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp)
finally:
- os.unlink('hostfile.temp')
+ os.unlink("hostfile.temp")
def test_host_config_expose_issue_33(self):
test_config_file = """
@@ -173,36 +204,44 @@ Host *
"""
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- host = 'www13.example.com'
+ host = "www13.example.com"
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
- {'hostname': host, 'port': '22'}
+ {"hostname": host, "port": "22"},
)
def test_eintr_retry(self):
- self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
+ self.assertEqual("foo", paramiko.util.retry_on_signal(lambda: "foo"))
# Variables that are set by raises_intr
intr_errors_remaining = [3]
call_count = [0]
+
def raises_intr():
call_count[0] += 1
if intr_errors_remaining[0] > 0:
intr_errors_remaining[0] -= 1
- raise IOError(errno.EINTR, 'file', 'interrupted system call')
+ raise IOError(errno.EINTR, "file", "interrupted system call")
+
self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None)
self.assertEqual(0, intr_errors_remaining[0])
self.assertEqual(4, call_count[0])
def raises_ioerror_not_eintr():
- raise IOError(errno.ENOENT, 'file', 'file not found')
- self.assertRaises(IOError,
- lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr))
+ raise IOError(errno.ENOENT, "file", "file not found")
+
+ self.assertRaises(
+ IOError,
+ lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr),
+ )
def raises_other_exception():
- raise AssertionError('foo')
- self.assertRaises(AssertionError,
- lambda: paramiko.util.retry_on_signal(raises_other_exception))
+ raise AssertionError("foo")
+
+ self.assertRaises(
+ AssertionError,
+ lambda: paramiko.util.retry_on_signal(raises_other_exception),
+ )
def test_proxycommand_config_equals_parsing(self):
"""
@@ -217,17 +256,18 @@ Host equals-delimited
"""
f = StringIO(conf)
config = paramiko.util.parse_ssh_config(f)
- for host in ('space-delimited', 'equals-delimited'):
+ for host in ("space-delimited", "equals-delimited"):
self.assertEqual(
- host_config(host, config)['proxycommand'],
- 'foo bar=biz baz'
+ host_config(host, config)["proxycommand"], "foo bar=biz baz"
)
def test_proxycommand_interpolation(self):
"""
ProxyCommand should perform interpolation on the value
"""
- config = paramiko.util.parse_ssh_config(StringIO("""
+ config = paramiko.util.parse_ssh_config(
+ StringIO(
+ """
Host specific
Port 37
ProxyCommand host %h port %p lol
@@ -238,28 +278,32 @@ Host portonly
Host *
Port 25
ProxyCommand host %h port %p
-"""))
+"""
+ )
+ )
for host, val in (
- ('foo.com', "host foo.com port 25"),
- ('specific', "host specific port 37 lol"),
- ('portonly', "host portonly port 155"),
+ ("foo.com", "host foo.com port 25"),
+ ("specific", "host specific port 37 lol"),
+ ("portonly", "host portonly port 155"),
):
- self.assertEqual(
- host_config(host, config)['proxycommand'],
- val
- )
+ self.assertEqual(host_config(host, config)["proxycommand"], val)
def test_proxycommand_tilde_expansion(self):
"""
Tilde (~) should be expanded inside ProxyCommand
"""
- config = paramiko.util.parse_ssh_config(StringIO("""
+ config = paramiko.util.parse_ssh_config(
+ StringIO(
+ """
Host test
ProxyCommand ssh -F ~/.ssh/test_config bastion nc %h %p
-"""))
+"""
+ )
+ )
self.assertEqual(
- 'ssh -F %s/.ssh/test_config bastion nc test 22' % os.path.expanduser('~'),
- host_config('test', config)['proxycommand']
+ "ssh -F %s/.ssh/test_config bastion nc test 22"
+ % os.path.expanduser("~"),
+ host_config("test", config)["proxycommand"],
)
def test_host_config_test_negation(self):
@@ -278,10 +322,10 @@ Host *
"""
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- host = 'www13.example.com'
+ host = "www13.example.com"
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
- {'hostname': host, 'port': '8080'}
+ {"hostname": host, "port": "8080"},
)
def test_host_config_test_proxycommand(self):
@@ -296,20 +340,24 @@ Host proxy-without-equal-divisor
ProxyCommand foo=bar:%h-%p
"""
for host, values in {
- 'proxy-with-equal-divisor-and-space' :{'hostname': 'proxy-with-equal-divisor-and-space',
- 'proxycommand': 'foo=bar'},
- 'proxy-with-equal-divisor-and-no-space':{'hostname': 'proxy-with-equal-divisor-and-no-space',
- 'proxycommand': 'foo=bar'},
- 'proxy-without-equal-divisor' :{'hostname': 'proxy-without-equal-divisor',
- 'proxycommand':
- 'foo=bar:proxy-without-equal-divisor-22'}
+ "proxy-with-equal-divisor-and-space": {
+ "hostname": "proxy-with-equal-divisor-and-space",
+ "proxycommand": "foo=bar",
+ },
+ "proxy-with-equal-divisor-and-no-space": {
+ "hostname": "proxy-with-equal-divisor-and-no-space",
+ "proxycommand": "foo=bar",
+ },
+ "proxy-without-equal-divisor": {
+ "hostname": "proxy-without-equal-divisor",
+ "proxycommand": "foo=bar:proxy-without-equal-divisor-22",
+ },
}.items():
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_host_config_test_identityfile(self):
@@ -327,19 +375,21 @@ Host dsa2*
IdentityFile id_dsa22
"""
for host, values in {
- 'foo' :{'hostname': 'foo',
- 'identityfile': ['id_dsa0', 'id_dsa1']},
- 'dsa2' :{'hostname': 'dsa2',
- 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa2', 'id_dsa22']},
- 'dsa22' :{'hostname': 'dsa22',
- 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']}
+ "foo": {"hostname": "foo", "identityfile": ["id_dsa0", "id_dsa1"]},
+ "dsa2": {
+ "hostname": "dsa2",
+ "identityfile": ["id_dsa0", "id_dsa1", "id_dsa2", "id_dsa22"],
+ },
+ "dsa22": {
+ "hostname": "dsa22",
+ "identityfile": ["id_dsa0", "id_dsa1", "id_dsa22"],
+ },
}.items():
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_config_addressfamily_and_lazy_fqdn(self):
@@ -351,7 +401,9 @@ AddressFamily inet
IdentityFile something_%l_using_fqdn
"""
config = paramiko.util.parse_ssh_config(StringIO(test_config))
- assert config.lookup('meh') # will die during lookup() if bug regresses
+ assert config.lookup(
+ "meh"
+ ) # will die during lookup() if bug regresses
def test_clamp_value(self):
self.assertEqual(32768, paramiko.util.clamp_value(32767, 32768, 32769))
@@ -367,7 +419,9 @@ IdentityFile something_%l_using_fqdn
def test_get_hostnames(self):
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEqual(config.get_hostnames(), {'*', '*.example.com', 'spoo.example.com'})
+ self.assertEqual(
+ config.get_hostnames(), {"*", "*.example.com", "spoo.example.com"}
+ )
def test_quoted_host_names(self):
test_config_file = """\
@@ -384,27 +438,23 @@ Host param4 "p a r" "p" "par" para
Port 4444
"""
res = {
- 'param pam': {'hostname': 'param pam', 'port': '1111'},
- 'param': {'hostname': 'param', 'port': '1111'},
- 'pam': {'hostname': 'pam', 'port': '1111'},
-
- 'param2': {'hostname': 'param2', 'port': '2222'},
-
- 'param3': {'hostname': 'param3', 'port': '3333'},
- 'parara': {'hostname': 'parara', 'port': '3333'},
-
- 'param4': {'hostname': 'param4', 'port': '4444'},
- 'p a r': {'hostname': 'p a r', 'port': '4444'},
- 'p': {'hostname': 'p', 'port': '4444'},
- 'par': {'hostname': 'par', 'port': '4444'},
- 'para': {'hostname': 'para', 'port': '4444'},
+ "param pam": {"hostname": "param pam", "port": "1111"},
+ "param": {"hostname": "param", "port": "1111"},
+ "pam": {"hostname": "pam", "port": "1111"},
+ "param2": {"hostname": "param2", "port": "2222"},
+ "param3": {"hostname": "param3", "port": "3333"},
+ "parara": {"hostname": "parara", "port": "3333"},
+ "param4": {"hostname": "param4", "port": "4444"},
+ "p a r": {"hostname": "p a r", "port": "4444"},
+ "p": {"hostname": "p", "port": "4444"},
+ "par": {"hostname": "par", "port": "4444"},
+ "para": {"hostname": "para", "port": "4444"},
}
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
for host, values in res.items():
self.assertEquals(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_quoted_params_in_config(self):
@@ -420,52 +470,44 @@ Host param3 parara
IdentityFile "test rsa key"
"""
res = {
- 'param pam': {'hostname': 'param pam', 'identityfile': ['id_rsa']},
- 'param': {'hostname': 'param', 'identityfile': ['id_rsa']},
- 'pam': {'hostname': 'pam', 'identityfile': ['id_rsa']},
-
- 'param2': {'hostname': 'param2', 'identityfile': ['test rsa key']},
-
- 'param3': {'hostname': 'param3', 'identityfile': ['id_rsa', 'test rsa key']},
- 'parara': {'hostname': 'parara', 'identityfile': ['id_rsa', 'test rsa key']},
+ "param pam": {"hostname": "param pam", "identityfile": ["id_rsa"]},
+ "param": {"hostname": "param", "identityfile": ["id_rsa"]},
+ "pam": {"hostname": "pam", "identityfile": ["id_rsa"]},
+ "param2": {"hostname": "param2", "identityfile": ["test rsa key"]},
+ "param3": {
+ "hostname": "param3",
+ "identityfile": ["id_rsa", "test rsa key"],
+ },
+ "parara": {
+ "hostname": "parara",
+ "identityfile": ["id_rsa", "test rsa key"],
+ },
}
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
for host, values in res.items():
self.assertEquals(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_quoted_host_in_config(self):
conf = SSHConfig()
correct_data = {
- 'param': ['param'],
- '"param"': ['param'],
-
- 'param pam': ['param', 'pam'],
- '"param" "pam"': ['param', 'pam'],
- '"param" pam': ['param', 'pam'],
- 'param "pam"': ['param', 'pam'],
-
- 'param "pam" p': ['param', 'pam', 'p'],
- '"param" pam "p"': ['param', 'pam', 'p'],
-
- '"pa ram"': ['pa ram'],
- '"pa ram" pam': ['pa ram', 'pam'],
- 'param "p a m"': ['param', 'p a m'],
+ "param": ["param"],
+ '"param"': ["param"],
+ "param pam": ["param", "pam"],
+ '"param" "pam"': ["param", "pam"],
+ '"param" pam': ["param", "pam"],
+ 'param "pam"': ["param", "pam"],
+ 'param "pam" p': ["param", "pam", "p"],
+ '"param" pam "p"': ["param", "pam", "p"],
+ '"pa ram"': ["pa ram"],
+ '"pa ram" pam': ["pa ram", "pam"],
+ 'param "p a m"': ["param", "p a m"],
}
- incorrect_data = [
- 'param"',
- '"param',
- 'param "pam',
- 'param "pam" "p a',
- ]
+ incorrect_data = ['param"', '"param', 'param "pam', 'param "pam" "p a']
for host, values in correct_data.items():
- self.assertEquals(
- conf._get_hosts(host),
- values
- )
+ self.assertEquals(conf._get_hosts(host), values)
for host in incorrect_data:
self.assertRaises(Exception, conf._get_hosts, host)
@@ -490,15 +532,18 @@ Host proxycommand-with-equals-none
ProxyCommand=None
"""
for host, values in {
- 'proxycommand-standard-none': {'hostname': 'proxycommand-standard-none'},
- 'proxycommand-with-equals-none': {'hostname': 'proxycommand-with-equals-none'}
+ "proxycommand-standard-none": {
+ "hostname": "proxycommand-standard-none"
+ },
+ "proxycommand-with-equals-none": {
+ "hostname": "proxycommand-with-equals-none"
+ },
}.items():
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEqual(
- paramiko.util.lookup_ssh_host_config(host, config),
- values
+ paramiko.util.lookup_ssh_host_config(host, config), values
)
def test_proxycommand_none_masking(self):
@@ -521,12 +566,10 @@ Host *
# backwards compatibility reasons in 1.x/2.x) appear completely blank,
# as if the host had no ProxyCommand whatsoever.
# Threw another unrelated host in there just for sanity reasons.
- self.assertFalse('proxycommand' in config.lookup('specific-host'))
+ self.assertFalse("proxycommand" in config.lookup("specific-host"))
self.assertEqual(
- config.lookup('other-host')['proxycommand'],
- 'other-proxy'
+ config.lookup("other-host")["proxycommand"], "other-proxy"
)
self.assertEqual(
- config.lookup('some-random-host')['proxycommand'],
- 'default-proxy'
+ config.lookup("some-random-host")["proxycommand"], "default-proxy"
)