summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/__init__.py2
-rw-r--r--paramiko/auth_transport.py94
-rw-r--r--paramiko/ssh_exception.py23
-rw-r--r--paramiko/transport.py44
-rw-r--r--tests/test_transport.py56
5 files changed, 177 insertions, 42 deletions
diff --git a/paramiko/__init__.py b/paramiko/__init__.py
index f7ddbc1c..2c430d78 100644
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -77,6 +77,7 @@ DSSKey = dsskey.DSSKey
SSHException = ssh_exception.SSHException
Message = message.Message
PasswordRequiredException = ssh_exception.PasswordRequiredException
+BadAuthenticationType = ssh_exception.BadAuthenticationType
SFTP = sftp_client.SFTP
SFTPClient = sftp_client.SFTPClient
SFTPServer = sftp_server.SFTPServer
@@ -105,6 +106,7 @@ __all__ = [ 'Transport',
'Message',
'SSHException',
'PasswordRequiredException',
+ 'BadAuthenticationType',
'SFTP',
'SFTPHandle',
'SFTPClient',
diff --git a/paramiko/auth_transport.py b/paramiko/auth_transport.py
index 2a6a1a29..45ef8ce9 100644
--- a/paramiko/auth_transport.py
+++ b/paramiko/auth_transport.py
@@ -23,6 +23,8 @@ L{Transport} is a subclass of L{BaseTransport} that handles authentication.
This separation keeps either class file from being too unwieldy.
"""
+import threading
+
# this helps freezing utils
import encodings.utf_8
@@ -30,7 +32,7 @@ from common import *
import util
from transport import BaseTransport
from message import Message
-from ssh_exception import SSHException
+from ssh_exception import SSHException, BadAuthenticationType
class Transport (BaseTransport):
@@ -102,12 +104,19 @@ class Transport (BaseTransport):
else:
return self.username
- def auth_publickey(self, username, key, event):
+ def auth_publickey(self, username, key, event=None):
"""
Authenticate to the server using a private key. The key is used to
- sign data from the server, so it must include the private part. The
- given L{event} is triggered on success or failure. On success,
- L{is_authenticated} will return C{True}.
+ sign data from the server, so it must include the private part.
+
+ If an C{event} is passed in, this method will return immediately, and
+ the event will be triggered once authentication succeeds or fails. On
+ success, L{is_authenticated} will return C{True}. On failure, you may
+ use L{get_exception} to get more detailed error information.
+
+ Since 1.1, if no event is passed, this method will block until the
+ authentication succeeds or fails. On failure, an exception is raised.
+ Otherwise, the method simply returns.
@param username: the username to authenticate as.
@type username: string
@@ -116,26 +125,46 @@ class Transport (BaseTransport):
@param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not)
@type event: threading.Event
+
+ @raise BadAuthenticationType: if public-key authentication isn't
+ allowed by the server for this user (and no event was passed in).
+ @raise SSHException: if the authentication failed (and no event was
+ passed in).
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
raise SSHException('No existing session')
+ if event is None:
+ my_event = threading.Event()
+ else:
+ my_event = event
+ self.lock.acquire()
try:
- self.lock.acquire()
- self.auth_event = event
+ self.auth_event = my_event
self.auth_method = 'publickey'
self.username = username
self.private_key = key
self._request_auth()
finally:
self.lock.release()
+ if event is not None:
+ # caller wants to wait for event themselves
+ return
+ self._wait_for_response(my_event)
- def auth_password(self, username, password, event):
+ def auth_password(self, username, password, event=None):
"""
Authenticate to the server using a password. The username and password
- are sent over an encrypted link, and the given L{event} is triggered on
- success or failure. On success, L{is_authenticated} will return
- C{True}.
+ are sent over an encrypted link.
+
+ If an C{event} is passed in, this method will return immediately, and
+ the event will be triggered once authentication succeeds or fails. On
+ success, L{is_authenticated} will return C{True}. On failure, you may
+ use L{get_exception} to get more detailed error information.
+
+ Since 1.1, if no event is passed, this method will block until the
+ authentication succeeds or fails. On failure, an exception is raised.
+ Otherwise, the method simply returns.
@param username: the username to authenticate as.
@type username: string
@@ -144,19 +173,32 @@ class Transport (BaseTransport):
@param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not)
@type event: threading.Event
+
+ @raise BadAuthenticationType: if password authentication isn't
+ allowed by the server for this user (and no event was passed in).
+ @raise SSHException: if the authentication failed (and no event was
+ passed in).
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to send the password unless we're on a secure link
raise SSHException('No existing session')
+ if event is None:
+ my_event = threading.Event()
+ else:
+ my_event = event
+ self.lock.acquire()
try:
- self.lock.acquire()
- self.auth_event = event
+ self.auth_event = my_event
self.auth_method = 'password'
self.username = username
self.password = password
self._request_auth()
finally:
self.lock.release()
+ if event is not None:
+ # caller wants to wait for event themselves
+ return
+ self._wait_for_response(my_event)
### internals...
@@ -198,6 +240,22 @@ class Transport (BaseTransport):
m.add_string(str(key))
return str(m)
+ def _wait_for_response(self, event):
+ while True:
+ event.wait(0.1)
+ if not self.active:
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Authentication failed.')
+ raise e
+ if event.isSet():
+ break
+ if not self.is_authenticated():
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Authentication failed.')
+ raise e
+
def _parse_service_request(self, m):
service = m.get_string()
if self.server_mode and (service == 'ssh-userauth'):
@@ -264,12 +322,12 @@ class Transport (BaseTransport):
result = self.server_object.check_auth_none(username)
elif method == 'password':
changereq = m.get_boolean()
- password = m.get_string().decode('UTF-8')
+ password = m.get_string().decode('UTF-8', 'replace')
if changereq:
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
self._log(DEBUG, 'Auth request to change passwords (rejected)')
- newpassword = m.get_string().decode('UTF-8')
+ newpassword = m.get_string().decode('UTF-8', 'replace')
result = AUTH_FAILED
else:
result = self.server_object.check_auth_password(username, password)
@@ -339,13 +397,13 @@ class Transport (BaseTransport):
if partial:
self._log(INFO, 'Authentication continues...')
self._log(DEBUG, 'Methods: ' + str(partial))
- # FIXME - do something
+ # FIXME: multi-part auth not supported
pass
+ if self.auth_method not in authlist:
+ self.saved_exception = BadAuthenticationType('Bad authentication type', authlist)
self._log(INFO, 'Authentication failed.')
self.authenticated = False
- # FIXME: i don't think we need to close() necessarily here
self.username = None
- self.close()
if self.auth_event != None:
self.auth_event.set()
diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py
index 6321821c..1f9173e1 100644
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -25,12 +25,31 @@ Exceptions defined by paramiko.
class SSHException (Exception):
"""
- Exception thrown by failures in SSH2 protocol negotiation or logic errors.
+ Exception raised by failures in SSH2 protocol negotiation or logic errors.
"""
pass
class PasswordRequiredException (SSHException):
"""
- Exception thrown when a password is needed to unlock a private key file.
+ Exception raised when a password is needed to unlock a private key file.
"""
pass
+
+class BadAuthenticationType (SSHException):
+ """
+ Exception raised when an authentication type (like password) is used, but
+ the server isn't allowing that type. (It may only allow public-key, for
+ example.)
+
+ @ivar allowed_types: list of allowed authentication types provided by the
+ server (possible values are: C{"none"}, C{"password"}, and
+ C{"publickey"}).
+ @type allowed_types: list
+
+ @since: 1.1
+ """
+ allowed_types = []
+
+ def __init__(self, explanation, types):
+ SSHException.__init__(self, explanation)
+ self.allowed_types = types
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 3e50ffcf..7e457410 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -681,8 +681,8 @@ class BaseTransport (threading.Thread):
return False
def accept(self, timeout=None):
+ self.lock.acquire()
try:
- self.lock.acquire()
if len(self.server_accepts) > 0:
chan = self.server_accepts.pop(0)
else:
@@ -740,8 +740,7 @@ class BaseTransport (threading.Thread):
while 1:
event.wait(0.1)
if not self.active:
- e = self.saved_exception
- self.saved_exception = None
+ e = self.get_exception()
if e is not None:
raise e
raise SSHException('Negotiation failed.')
@@ -759,27 +758,34 @@ class BaseTransport (threading.Thread):
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
if (pkey is not None) or (password is not None):
- event.clear()
if password is not None:
self._log(DEBUG, 'Attempting password auth...')
- self.auth_password(username, password, event)
+ self.auth_password(username, password)
else:
- self._log(DEBUG, 'Attempting pkey auth...')
- self.auth_publickey(username, pkey, event)
- while 1:
- event.wait(0.1)
- if not self.active:
- e = self.saved_exception
- self.saved_exception = None
- if e is not None:
- raise e
- raise SSHException('Authentication failed.')
- if event.isSet():
- break
- if not self.is_authenticated():
- raise SSHException('Authentication failed.')
+ self._log(DEBUG, 'Attempting public-key auth...')
+ self.auth_publickey(username, pkey)
return
+
+ def get_exception(self):
+ """
+ Return any exception that happened during the last server request.
+ This can be used to fetch more specific error information after using
+ calls like L{start_client}. The exception (if any) is cleared after
+ this call.
+
+ @return: an exception, or C{None} if there is no stored exception.
+ @rtype: Exception
+
+ @since: 1.1
+ """
+ self.lock.acquire()
+ try:
+ e = self.saved_exception
+ self.saved_exception = None
+ return e
+ finally:
+ self.lock.release()
def set_subsystem_handler(self, name, handler, *larg, **kwarg):
"""
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 93dc8b7f..b55160a7 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -22,14 +22,17 @@
Some unit tests for the ssh2 protocol in Transport.
"""
-import unittest, threading
-from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey
+import sys, unittest, threading
+from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
+ SSHException, BadAuthenticationType
from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
from loop import LoopSocket
class NullServer (ServerInterface):
def get_allowed_auths(self, username):
+ if username == 'slowdive':
+ return 'publickey,password'
return 'publickey'
def check_auth_password(self, username, password):
@@ -90,4 +93,51 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
-
+ def test_3_bad_auth_type(self):
+ """
+ verify that we get the right exception when an unsupported auth
+ type is requested.
+ """
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+ event = threading.Event()
+ server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, server)
+ self.tc.ultra_debug = True
+ try:
+ self.tc.connect(hostkey=public_host_key,
+ username='unknown', password='error')
+ self.assert_(False)
+ except:
+ etype, evalue, etb = sys.exc_info()
+ self.assertEquals(BadAuthenticationType, etype)
+ self.assertEquals(['publickey'], evalue.allowed_types)
+
+ def test_4_bad_password(self):
+ """
+ verify that a bad password gets the right exception, and that a retry
+ with the right password works.
+ """
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+ event = threading.Event()
+ server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, server)
+ self.tc.ultra_debug = True
+ self.tc.connect(hostkey=public_host_key)
+ try:
+ self.tc.auth_password(username='slowdive', password='error')
+ self.assert_(False)
+ except:
+ etype, evalue, etb = sys.exc_info()
+ self.assertEquals(SSHException, etype)
+ self.tc.auth_password(username='slowdive', password='pygmalion')
+ event.wait(1.0)
+ self.assert_(event.isSet())
+ self.assert_(self.ts.is_active())
+
+