summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRobey Pointer <robey@lag.net>2006-05-03 19:52:37 -0700
committerRobey Pointer <robey@lag.net>2006-05-03 19:52:37 -0700
commit581103665b82f50d71aacb12881f9fd0b3fcca88 (patch)
tree829431f48164a232421bb4a6a7396540c40d21bc
parentaac434e9b08f436dd742c744a56e7eb3c62306b8 (diff)
[project @ robey@lag.net-20060504025237-a015ee747d9a2e75]
if open_channel fails, it now raises ChannelException. added a unit test for that too. renegotiate_keys will also raise an exception now instead of returning a bool.
-rw-r--r--paramiko/__init__.py5
-rw-r--r--paramiko/ssh_exception.py11
-rw-r--r--paramiko/transport.py73
-rw-r--r--tests/test_transport.py40
4 files changed, 88 insertions, 41 deletions
diff --git a/paramiko/__init__.py b/paramiko/__init__.py
index c341d2bd..e9d504a4 100644
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -69,7 +69,7 @@ from transport import randpool, SecurityOptions, Transport
from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy
from auth_handler import AuthHandler
from channel import Channel, ChannelFile
-from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType
+from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType, ChannelException
from server import ServerInterface, SubsystemHandler, InteractiveQuery
from rsakey import RSAKey
from dsskey import DSSKey
@@ -94,7 +94,7 @@ for x in (Transport, SecurityOptions, Channel, SFTPServer, SSHException,
SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes,
SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey,
PKey, BaseSFTP, SFTPFile, ServerInterface, HostKeys, SSHClient,
- MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy):
+ MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, ChannelException):
x.__module__ = 'paramiko'
from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
@@ -119,6 +119,7 @@ __all__ = [ 'Transport',
'SSHException',
'PasswordRequiredException',
'BadAuthenticationType',
+ 'ChannelException',
'SFTP',
'SFTPFile',
'SFTPHandle',
diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py
index 3aa4860d..99eaa648 100644
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -67,3 +67,14 @@ class PartialAuthentication (SSHException):
def __init__(self, types):
SSHException.__init__(self, 'partial authentication')
self.allowed_types = types
+
+
+class ChannelException (SSHException):
+ """
+ Exception raised when an attempt to open a new L{Channel} fails.
+
+ @since: 1.6
+ """
+ def __init__(self, code, text):
+ SSHException.__init__(self, text)
+ self.code = code
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 6fe72189..31a5423a 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -43,7 +43,7 @@ from paramiko.primes import ModulusPack
from paramiko.rsakey import RSAKey
from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
-from paramiko.ssh_exception import SSHException, BadAuthenticationType
+from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
# these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/
@@ -558,7 +558,7 @@ class Transport (threading.Thread):
@raise SSHException: if no session is currently active.
- @return: public key of the remote server.
+ @return: public key of the remote server
@rtype: L{PKey <pkey.PKey>}
"""
if (not self.active) or (not self.initial_kex_done):
@@ -570,7 +570,7 @@ class Transport (threading.Thread):
Return true if this session is active (open).
@return: True if the session is still active (open); False if the
- session is closed.
+ session is closed
@rtype: bool
"""
return self.active
@@ -580,9 +580,11 @@ class Transport (threading.Thread):
Request a new channel to the server, of type C{"session"}. This
is just an alias for C{open_channel('session')}.
- @return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ @return: a new L{Channel}
@rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
"""
return self.open_channel('session')
@@ -594,18 +596,20 @@ class Transport (threading.Thread):
L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"},
- C{"forwarded-tcpip"} or C{"direct-tcpip"}).
+ C{"forwarded-tcpip"} or C{"direct-tcpip"})
@type kind: str
@param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
- for other channel types).
+ for other channel types)
@type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if
- C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
+ C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}
@type src_addr: (str, int)
- @return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ @return: a new L{Channel} on success
@rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
"""
chan = None
if not self.active:
@@ -637,19 +641,25 @@ class Transport (threading.Thread):
finally:
self.lock.release()
self._send_user_message(m)
- while 1:
+ while True:
event.wait(0.1);
if not self.active:
- return None
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Unable to open channel.')
+ raise e
if event.isSet():
break
+ self.lock.acquire()
try:
- self.lock.acquire()
- if not self.channels.has_key(chanid):
- chan = None
+ if self.channels.has_key(chanid):
+ return chan
finally:
self.lock.release()
- return chan
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Unable to open channel.')
+ raise e
def open_sftp_client(self):
"""
@@ -689,22 +699,23 @@ class Transport (threading.Thread):
bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This
- method returns when the session has switched to new keys, or the
- session has died mid-negotiation.
+ method returns when the session has switched to new keys.
- @return: True if the renegotiation was successful, and the link is
- using new keys; False if the session dropped during renegotiation.
- @rtype: bool
+ @raise SSHException: if the key renegotiation failed (which causes the
+ session to end)
"""
self.completion_event = threading.Event()
self._send_kex_init()
- while 1:
- self.completion_event.wait(0.1);
+ while True:
+ self.completion_event.wait(0.1)
if not self.active:
- return False
+ e = self.get_exception()
+ if e is not None:
+ raise e
+ raise SSHException('Negotiation failed.')
if self.completion_event.isSet():
break
- return True
+ return
def set_keepalive(self, interval):
"""
@@ -1017,7 +1028,7 @@ class Transport (threading.Thread):
except SSHException, ignored:
# attempt failed; just raise the original exception
raise x
- return None
+ return None
def auth_publickey(self, username, key, event=None):
"""
@@ -1741,14 +1752,12 @@ class Transport (threading.Thread):
reason = m.get_int()
reason_str = m.get_string()
lang = m.get_string()
- if CONNECTION_FAILED_CODE.has_key(reason):
- reason_text = CONNECTION_FAILED_CODE[reason]
- else:
- reason_text = '(unknown code)'
+ reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
+ self.lock.acquire()
try:
- self.lock.aquire()
- if self.channels.has_key(chanid):
+ self.saved_exception = ChannelException(reason, reason_text)
+ if self.channel_events.has_key(chanid):
del self.channels[chanid]
if self.channel_events.has_key(chanid):
self.channel_events[chanid].set()
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 5fcc7865..b2e8b6f6 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -23,9 +23,9 @@ Some unit tests for the ssh2 protocol in Transport.
import sys, time, threading, unittest
import select
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
- SSHException, BadAuthenticationType, InteractiveQuery, util
+ SSHException, BadAuthenticationType, InteractiveQuery, util, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
-from paramiko import OPEN_SUCCEEDED
+from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from loop import LoopSocket
@@ -81,6 +81,8 @@ class NullServer (ServerInterface):
return AUTH_FAILED
def check_channel_request(self, kind, chanid):
+ if kind == 'bogus':
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
@@ -189,7 +191,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
- self.assert_(self.tc.renegotiate_keys())
+ self.tc.renegotiate_keys()
self.ts.send_ignore(1024)
def test_5_keepalive(self):
@@ -408,7 +410,31 @@ class TransportTest (unittest.TestCase):
chan.close()
self.assertEquals('', f.readline())
- def test_D_exit_status(self):
+ def test_D_channel_exception(self):
+ """
+ verify that ChannelException is thrown for a bad open-channel request.
+ """
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+ event = threading.Event()
+ server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, server)
+ self.tc.ultra_debug = True
+ self.tc.connect(hostkey=public_host_key)
+ self.tc.auth_password(username='slowdive', password='pygmalion')
+ event.wait(1.0)
+ self.assert_(event.isSet())
+ self.assert_(self.ts.is_active())
+
+ try:
+ chan = self.tc.open_channel('bogus')
+ self.fail('expected exception')
+ except ChannelException, x:
+ self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
+
+ def test_E_exit_status(self):
"""
verify that get_exit_status() works.
"""
@@ -442,7 +468,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(23, chan.recv_exit_status())
chan.close()
- def test_E_select(self):
+ def test_F_select(self):
"""
verify that select() on a channel works.
"""
@@ -505,7 +531,7 @@ class TransportTest (unittest.TestCase):
chan.close()
- def test_F_renegotiate(self):
+ def test_G_renegotiate(self):
"""
verify that a transport can correctly renegotiate mid-stream.
"""
@@ -541,7 +567,7 @@ class TransportTest (unittest.TestCase):
schan.close()
- def test_G_compression(self):
+ def test_H_compression(self):
"""
verify that zlib compression is basically working.
"""