summaryrefslogtreecommitdiffhomepage
path: root/transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'transport.py')
-rw-r--r--transport.py107
1 files changed, 92 insertions, 15 deletions
diff --git a/transport.py b/transport.py
index 2020b279..c5ff252b 100644
--- a/transport.py
+++ b/transport.py
@@ -11,12 +11,11 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
-
import sys, os, string, threading, socket, logging, struct
from message import Message
from channel import Channel
from secsh import SSHException
-from util import format_binary, safe_string, inflate_long, deflate_long
+from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings
from rsakey import RSAKey
from dsskey import DSSKey
from kex_group1 import KexGroup1
@@ -105,6 +104,9 @@ class BaseTransport(threading.Thread):
REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30)
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, \
+ OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5)
+
def __init__(self, sock):
threading.Thread.__init__(self)
self.randpool = randpool
@@ -143,6 +145,8 @@ class BaseTransport(threading.Thread):
# server mode:
self.server_mode = 0
self.server_key_dict = { }
+ self.server_accepts = [ ]
+ self.server_accept_cv = threading.Condition(self.lock)
def start_client(self, event=None):
self.completion_event = event
@@ -196,7 +200,7 @@ class BaseTransport(threading.Thread):
for chan in self.channels.values():
chan.unlink()
- def get_host_key(self):
+ def get_remote_server_key(self):
'returns (type, key) where type is like "ssh-rsa" and key is an opaque string'
if (not self.active) or (not self.initial_kex_done):
raise SSHException('No existing session')
@@ -225,8 +229,9 @@ class BaseTransport(threading.Thread):
m.add_int(chanid)
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
- self.channels[chanid] = chan = Channel(chanid, self)
+ self.channels[chanid] = chan = Channel(chanid)
self.channel_events[chanid] = event = threading.Event()
+ chan.set_transport(self)
chan.set_window(self.window_size, self.max_packet_size)
self.send_message(m)
finally:
@@ -445,10 +450,12 @@ class BaseTransport(threading.Thread):
self.send_message(msg)
except SSHException, e:
self.log(DEBUG, 'Exception: ' + str(e))
+ self.log(DEBUG, tb_strings())
except EOFError, e:
self.log(DEBUG, 'EOF')
except Exception, e:
self.log(DEBUG, 'Unknown exception: ' + str(e))
+ self.log(DEBUG, tb_strings())
if self.active:
self.active = 0
if self.completion_event != None:
@@ -503,7 +510,11 @@ class BaseTransport(threading.Thread):
comment = buffer[i+1:]
buffer = buffer[:i]
# parse out version string and make sure it matches
- _unused, version, client = string.split(buffer, '-')
+ segs = buffer.split('-', 2)
+ if len(segs) < 3:
+ raise SSHException('Invalid SSH banner')
+ version = segs[1]
+ client = segs[2]
if version != '1.99' and version != '2.0':
raise SSHException('Incompatible version (%s instead of 2.0)' % (version,))
self.log(INFO, 'Connected (version %s, client %s)' % (version, client))
@@ -681,6 +692,7 @@ class BaseTransport(threading.Thread):
code = m.get_int()
desc = m.get_string()
self.log(INFO, 'Disconnect (code %d): %s' % (code, desc))
+
def parse_channel_open_success(self, m):
chanid = m.get_int()
server_chanid = m.get_int()
@@ -692,7 +704,7 @@ class BaseTransport(threading.Thread):
try:
self.lock.acquire()
chan = self.channels[chanid]
- chan.set_server_channel(server_chanid, server_window_size, server_max_packet_size)
+ chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self.log(INFO, 'Secsh channel %d opened.' % chanid)
if self.channel_events.has_key(chanid):
self.channel_events[chanid].set()
@@ -719,20 +731,85 @@ class BaseTransport(threading.Thread):
self.channel_events[chanid].set()
del self.channel_events[chanid]
finally:
- self.lock_release()
+ self.lock.release()
return
+ def check_channel_request(self, kind, chanid):
+ "override me! return object descended from Channel to allow, or None to reject"
+ return None
+
def parse_channel_open(self, m):
kind = m.get_string()
- self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
chanid = m.get_int()
- msg = Message()
- msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
- msg.add_int(chanid)
- msg.add_int(1)
- msg.add_string('Client connections are not allowed.')
- msg.add_string('en')
- self.send_message(msg)
+ initial_window_size = m.get_int()
+ max_packet_size = m.get_int()
+ reject = False
+ if not self.server_mode:
+ self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
+ reject = True
+ reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ else:
+ try:
+ self.lock.acquire()
+ my_chanid = self.channel_counter
+ self.channel_counter += 1
+ finally:
+ self.lock.release()
+ chan = self.check_channel_request(kind, my_chanid)
+ if (chan is None) or (type(chan) is int):
+ self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
+ reject = True
+ if type(chan) is int:
+ reason = chan
+ else:
+ reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ if reject:
+ msg = Message()
+ msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
+ msg.add_int(chanid)
+ msg.add_int(reason)
+ msg.add_string('')
+ msg.add_string('en')
+ self.send_message(msg)
+ return
+ try:
+ self.lock.acquire()
+ self.channels[my_chanid] = chan
+ chan.set_transport(self)
+ chan.set_window(self.window_size, self.max_packet_size)
+ chan.set_remote_channel(chanid, initial_window_size, max_packet_size)
+ finally:
+ self.lock.release()
+ m = Message()
+ m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS))
+ m.add_int(chanid)
+ m.add_int(my_chanid)
+ m.add_int(self.window_size)
+ m.add_int(self.max_packet_size)
+ self.send_message(m)
+ self.log(INFO, 'Secsh channel %d opened.' % my_chanid)
+ try:
+ self.lock.acquire()
+ self.server_accepts.append(chan)
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
+
+ def accept(self, timeout=None):
+ try:
+ self.lock.acquire()
+ if len(self.server_accepts) > 0:
+ chan = self.server_accepts.pop(0)
+ else:
+ self.server_accept_cv.wait(timeout)
+ if len(self.server_accepts) > 0:
+ chan = self.server_accepts.pop(0)
+ else:
+ # timeout
+ chan = None
+ finally:
+ self.lock.release()
+ return chan
def parse_debug(self, m):
always_display = m.get_boolean()