summaryrefslogtreecommitdiffhomepage
path: root/tests/test_transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r--tests/test_transport.py71
1 files changed, 66 insertions, 5 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py
index ecea6e19..81254b48 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -23,7 +23,7 @@ Some unit tests for the ssh2 protocol in Transport.
import sys, time, threading, unittest
import select
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
- SSHException, BadAuthenticationType, util
+ SSHException, BadAuthenticationType, InteractiveQuery, util
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED
from loop import LoopSocket
@@ -44,6 +44,8 @@ class NullServer (ServerInterface):
return 'publickey'
else:
return 'password'
+ if username == 'commie':
+ return 'keyboard-interactive'
return 'publickey'
def check_auth_password(self, username, password):
@@ -66,6 +68,18 @@ class NullServer (ServerInterface):
return AUTH_PARTIALLY_SUCCESSFUL
return AUTH_FAILED
+ def check_auth_interactive(self, username, submethods):
+ if username == 'commie':
+ self.username = username
+ return InteractiveQuery('password', 'Please enter a password.', ('Password', False))
+ return AUTH_FAILED
+
+ def check_auth_interactive_response(self, responses):
+ if self.username == 'commie':
+ if (len(responses) == 1) and (responses[0] == 'cat'):
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
def check_channel_request(self, kind, chanid):
return OPEN_SUCCEEDED
@@ -270,7 +284,54 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
- def test_9_exec_command(self):
+ def test_9_interactive_auth(self):
+ """
+ verify keyboard-interactive auth works.
+ """
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+ event = threading.Event()
+ server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, server)
+ self.tc.ultra_debug = True
+ self.tc.connect(hostkey=public_host_key)
+
+ def handler(title, instructions, prompts):
+ self.got_title = title
+ self.got_instructions = instructions
+ self.got_prompts = prompts
+ return ['cat']
+ remain = self.tc.auth_interactive('commie', handler)
+ self.assertEquals(self.got_title, 'password')
+ self.assertEquals(self.got_prompts, [('Password', False)])
+ self.assertEquals([], remain)
+ event.wait(1.0)
+ self.assert_(event.isSet())
+ self.assert_(self.ts.is_active())
+
+ def test_A_interactive_auth_fallback(self):
+ """
+ verify that a password auth attempt will fallback to "interactive"
+ if password auth isn't supported but interactive is.
+ """
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+ event = threading.Event()
+ server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, server)
+ self.tc.ultra_debug = True
+ self.tc.connect(hostkey=public_host_key)
+ remain = self.tc.auth_password('commie', 'cat')
+ self.assertEquals([], remain)
+ event.wait(1.0)
+ self.assert_(event.isSet())
+ self.assert_(self.ts.is_active())
+
+ def test_B_exec_command(self):
"""
verify that exec_command() does something reasonable.
"""
@@ -320,7 +381,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline())
- def test_A_invoke_shell(self):
+ def test_C_invoke_shell(self):
"""
verify that invoke_shell() does something reasonable.
"""
@@ -347,7 +408,7 @@ class TransportTest (unittest.TestCase):
chan.close()
self.assertEquals('', f.readline())
- def test_B_exit_status(self):
+ def test_D_exit_status(self):
"""
verify that get_exit_status() works.
"""
@@ -381,7 +442,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(23, chan.recv_exit_status())
chan.close()
- def test_C_select(self):
+ def test_E_select(self):
"""
verify that select() on a channel works.
"""