summaryrefslogtreecommitdiffhomepage
path: root/tests/test_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_auth.py')
-rw-r--r--tests/test_auth.py119
1 files changed, 68 insertions, 51 deletions
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 96f7611c..45dcb3a4 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -25,59 +25,69 @@ import threading
import unittest
from paramiko import (
- Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType,
- InteractiveQuery, AuthenticationException,
+ Transport,
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ BadAuthenticationType,
+ InteractiveQuery,
+ AuthenticationException,
)
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko.py3compat import u
-from tests.loop import LoopSocket
-from tests.util import test_path
-_pwd = u('\u2022')
+from .loop import LoopSocket
+from .util import _support, slow
-class NullServer (ServerInterface):
+_pwd = u("\u2022")
+
+
+class NullServer(ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
def get_allowed_auths(self, username):
- if username == 'slowdive':
- return 'publickey,password'
- if username == 'paranoid':
- if not self.paranoid_did_password and not self.paranoid_did_public_key:
- return 'publickey,password'
+ if username == "slowdive":
+ return "publickey,password"
+ if username == "paranoid":
+ if (
+ not self.paranoid_did_password
+ and not self.paranoid_did_public_key
+ ):
+ return "publickey,password"
elif self.paranoid_did_password:
- return 'publickey'
+ return "publickey"
else:
- return 'password'
- if username == 'commie':
- return 'keyboard-interactive'
- if username == 'utf8':
- return 'password'
- if username == 'non-utf8':
- return 'password'
- return 'publickey'
+ return "password"
+ if username == "commie":
+ return "keyboard-interactive"
+ if username == "utf8":
+ return "password"
+ if username == "non-utf8":
+ return "password"
+ return "publickey"
def check_auth_password(self, username, password):
- if (username == 'slowdive') and (password == 'pygmalion'):
+ if (username == "slowdive") and (password == "pygmalion"):
return AUTH_SUCCESSFUL
- if (username == 'paranoid') and (password == 'paranoid'):
+ if (username == "paranoid") and (password == "paranoid"):
# 2-part auth (even openssh doesn't support this)
self.paranoid_did_password = True
if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
- if (username == 'utf8') and (password == _pwd):
+ if (username == "utf8") and (password == _pwd):
return AUTH_SUCCESSFUL
- if (username == 'non-utf8') and (password == '\xff'):
+ if (username == "non-utf8") and (password == "\xff"):
return AUTH_SUCCESSFUL
- if username == 'bad-server':
+ if username == "bad-server":
raise Exception("Ack!")
return AUTH_FAILED
def check_auth_publickey(self, username, key):
- if (username == 'paranoid') and (key == self.paranoid_key):
+ if (username == "paranoid") and (key == self.paranoid_key):
# 2-part auth
self.paranoid_did_public_key = True
if self.paranoid_did_password:
@@ -86,20 +96,21 @@ class NullServer (ServerInterface):
return AUTH_FAILED
def check_auth_interactive(self, username, submethods):
- if username == 'commie':
+ if username == "commie":
self.username = username
- return InteractiveQuery('password', 'Please enter a password.', ('Password', False))
+ return InteractiveQuery(
+ "password", "Please enter a password.", ("Password", False)
+ )
return AUTH_FAILED
def check_auth_interactive_response(self, responses):
- if self.username == 'commie':
- if (len(responses) == 1) and (responses[0] == 'cat'):
+ if self.username == "commie":
+ if (len(responses) == 1) and (responses[0] == "cat"):
return AUTH_SUCCESSFUL
return AUTH_FAILED
-class AuthTest (unittest.TestCase):
-
+class AuthTest(unittest.TestCase):
def setUp(self):
self.socks = LoopSocket()
self.sockc = LoopSocket()
@@ -114,7 +125,7 @@ class AuthTest (unittest.TestCase):
self.sockc.close()
def start_server(self):
- host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
self.event = threading.Event()
@@ -134,13 +145,16 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
try:
- self.tc.connect(hostkey=self.public_host_key,
- username='unknown', password='error')
+ self.tc.connect(
+ hostkey=self.public_host_key,
+ username="unknown",
+ password="error",
+ )
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEqual(BadAuthenticationType, etype)
- self.assertEqual(['publickey'], evalue.allowed_types)
+ self.assertEqual(["publickey"], evalue.allowed_types)
def test_2_bad_password(self):
"""
@@ -150,12 +164,12 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
try:
- self.tc.auth_password(username='slowdive', password='error')
+ self.tc.auth_password(username="slowdive", password="error")
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))
- self.tc.auth_password(username='slowdive', password='pygmalion')
+ self.tc.auth_password(username="slowdive", password="pygmalion")
self.verify_finished()
def test_3_multipart_auth(self):
@@ -164,10 +178,12 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password(username='paranoid', password='paranoid')
- self.assertEqual(['publickey'], remain)
- key = DSSKey.from_private_key_file(test_path('test_dss.key'))
- remain = self.tc.auth_publickey(username='paranoid', key=key)
+ remain = self.tc.auth_password(
+ username="paranoid", password="paranoid"
+ )
+ self.assertEqual(["publickey"], remain)
+ key = DSSKey.from_private_key_file(_support("test_dss.key"))
+ remain = self.tc.auth_publickey(username="paranoid", key=key)
self.assertEqual([], remain)
self.verify_finished()
@@ -182,10 +198,11 @@ class AuthTest (unittest.TestCase):
self.got_title = title
self.got_instructions = instructions
self.got_prompts = prompts
- return ['cat']
- remain = self.tc.auth_interactive('commie', handler)
- self.assertEqual(self.got_title, 'password')
- self.assertEqual(self.got_prompts, [('Password', False)])
+ return ["cat"]
+
+ remain = self.tc.auth_interactive("commie", handler)
+ self.assertEqual(self.got_title, "password")
+ self.assertEqual(self.got_prompts, [("Password", False)])
self.assertEqual([], remain)
self.verify_finished()
@@ -196,7 +213,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('commie', 'cat')
+ remain = self.tc.auth_password("commie", "cat")
self.assertEqual([], remain)
self.verify_finished()
@@ -206,7 +223,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('utf8', _pwd)
+ remain = self.tc.auth_password("utf8", _pwd)
self.assertEqual([], remain)
self.verify_finished()
@@ -217,7 +234,7 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('non-utf8', '\xff')
+ remain = self.tc.auth_password("non-utf8", "\xff")
self.assertEqual([], remain)
self.verify_finished()
@@ -229,7 +246,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
try:
- remain = self.tc.auth_password('bad-server', 'hello')
+ remain = self.tc.auth_password("bad-server", "hello")
except:
etype, evalue, etb = sys.exc_info()
self.assertTrue(issubclass(etype, AuthenticationException))