diff options
Diffstat (limited to 'transport.py')
-rw-r--r-- | transport.py | 107 |
1 files changed, 92 insertions, 15 deletions
diff --git a/transport.py b/transport.py index 2020b279..c5ff252b 100644 --- a/transport.py +++ b/transport.py @@ -11,12 +11,11 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) - import sys, os, string, threading, socket, logging, struct from message import Message from channel import Channel from secsh import SSHException -from util import format_binary, safe_string, inflate_long, deflate_long +from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings from rsakey import RSAKey from dsskey import DSSKey from kex_group1 import KexGroup1 @@ -105,6 +104,9 @@ class BaseTransport(threading.Thread): REKEY_PACKETS = pow(2, 30) REKEY_BYTES = pow(2, 30) + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, \ + OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5) + def __init__(self, sock): threading.Thread.__init__(self) self.randpool = randpool @@ -143,6 +145,8 @@ class BaseTransport(threading.Thread): # server mode: self.server_mode = 0 self.server_key_dict = { } + self.server_accepts = [ ] + self.server_accept_cv = threading.Condition(self.lock) def start_client(self, event=None): self.completion_event = event @@ -196,7 +200,7 @@ class BaseTransport(threading.Thread): for chan in self.channels.values(): chan.unlink() - def get_host_key(self): + def get_remote_server_key(self): 'returns (type, key) where type is like "ssh-rsa" and key is an opaque string' if (not self.active) or (not self.initial_kex_done): raise SSHException('No existing session') @@ -225,8 +229,9 @@ class BaseTransport(threading.Thread): m.add_int(chanid) m.add_int(self.window_size) m.add_int(self.max_packet_size) - self.channels[chanid] = chan = Channel(chanid, self) + self.channels[chanid] = chan = Channel(chanid) self.channel_events[chanid] = event = threading.Event() + chan.set_transport(self) chan.set_window(self.window_size, self.max_packet_size) self.send_message(m) finally: @@ -445,10 +450,12 @@ class BaseTransport(threading.Thread): self.send_message(msg) except SSHException, e: self.log(DEBUG, 'Exception: ' + str(e)) + self.log(DEBUG, tb_strings()) except EOFError, e: self.log(DEBUG, 'EOF') except Exception, e: self.log(DEBUG, 'Unknown exception: ' + str(e)) + self.log(DEBUG, tb_strings()) if self.active: self.active = 0 if self.completion_event != None: @@ -503,7 +510,11 @@ class BaseTransport(threading.Thread): comment = buffer[i+1:] buffer = buffer[:i] # parse out version string and make sure it matches - _unused, version, client = string.split(buffer, '-') + segs = buffer.split('-', 2) + if len(segs) < 3: + raise SSHException('Invalid SSH banner') + version = segs[1] + client = segs[2] if version != '1.99' and version != '2.0': raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) self.log(INFO, 'Connected (version %s, client %s)' % (version, client)) @@ -681,6 +692,7 @@ class BaseTransport(threading.Thread): code = m.get_int() desc = m.get_string() self.log(INFO, 'Disconnect (code %d): %s' % (code, desc)) + def parse_channel_open_success(self, m): chanid = m.get_int() server_chanid = m.get_int() @@ -692,7 +704,7 @@ class BaseTransport(threading.Thread): try: self.lock.acquire() chan = self.channels[chanid] - chan.set_server_channel(server_chanid, server_window_size, server_max_packet_size) + chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size) self.log(INFO, 'Secsh channel %d opened.' % chanid) if self.channel_events.has_key(chanid): self.channel_events[chanid].set() @@ -719,20 +731,85 @@ class BaseTransport(threading.Thread): self.channel_events[chanid].set() del self.channel_events[chanid] finally: - self.lock_release() + self.lock.release() return + def check_channel_request(self, kind, chanid): + "override me! return object descended from Channel to allow, or None to reject" + return None + def parse_channel_open(self, m): kind = m.get_string() - self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) chanid = m.get_int() - msg = Message() - msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) - msg.add_int(chanid) - msg.add_int(1) - msg.add_string('Client connections are not allowed.') - msg.add_string('en') - self.send_message(msg) + initial_window_size = m.get_int() + max_packet_size = m.get_int() + reject = False + if not self.server_mode: + self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) + reject = True + reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + else: + try: + self.lock.acquire() + my_chanid = self.channel_counter + self.channel_counter += 1 + finally: + self.lock.release() + chan = self.check_channel_request(kind, my_chanid) + if (chan is None) or (type(chan) is int): + self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) + reject = True + if type(chan) is int: + reason = chan + else: + reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + if reject: + msg = Message() + msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) + msg.add_int(chanid) + msg.add_int(reason) + msg.add_string('') + msg.add_string('en') + self.send_message(msg) + return + try: + self.lock.acquire() + self.channels[my_chanid] = chan + chan.set_transport(self) + chan.set_window(self.window_size, self.max_packet_size) + chan.set_remote_channel(chanid, initial_window_size, max_packet_size) + finally: + self.lock.release() + m = Message() + m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS)) + m.add_int(chanid) + m.add_int(my_chanid) + m.add_int(self.window_size) + m.add_int(self.max_packet_size) + self.send_message(m) + self.log(INFO, 'Secsh channel %d opened.' % my_chanid) + try: + self.lock.acquire() + self.server_accepts.append(chan) + self.server_accept_cv.notify() + finally: + self.lock.release() + + def accept(self, timeout=None): + try: + self.lock.acquire() + if len(self.server_accepts) > 0: + chan = self.server_accepts.pop(0) + else: + self.server_accept_cv.wait(timeout) + if len(self.server_accepts) > 0: + chan = self.server_accepts.pop(0) + else: + # timeout + chan = None + finally: + self.lock.release() + return chan def parse_debug(self, m): always_display = m.get_boolean() |