summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--auth_transport.py37
-rw-r--r--channel.py119
-rwxr-xr-xdemo-server.py29
-rwxr-xr-xdemo.py2
-rw-r--r--transport.py107
5 files changed, 226 insertions, 68 deletions
diff --git a/auth_transport.py b/auth_transport.py
index 1a06326d..78ce8d70 100644
--- a/auth_transport.py
+++ b/auth_transport.py
@@ -10,11 +10,13 @@ from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL
DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \
DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14
-AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
class Transport(BaseTransport):
"BaseTransport with the auth framework hooked up"
+
+ AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
+
def __init__(self, sock):
BaseTransport.__init__(self, sock)
self.auth_event = None
@@ -111,21 +113,21 @@ class Transport(BaseTransport):
else:
self.log(DEBUG, 'Service request "%s" accepted (?)' % service)
- def get_allowed_auths(self):
+ def get_allowed_auths(self, username):
"override me!"
return 'password'
def check_auth_none(self, username):
- "override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)"
- return (AUTH_FAILED, self.get_allowed_auths())
+ "override me! return int ==> auth status"
+ return self.AUTH_FAILED
def check_auth_password(self, username, password):
- "override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)"
- return (AUTH_FAILED, self.get_allowed_auths())
+ "override me! return int ==> auth status"
+ return self.AUTH_FAILED
def check_auth_publickey(self, username, key):
- "override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)"
- return (AUTH_FAILED, self.get_allowed_auths())
+ "override me! return int ==> auth status"
+ return self.AUTH_FAILED
def parse_userauth_request(self, m):
if not self.server_mode:
@@ -142,11 +144,12 @@ class Transport(BaseTransport):
username = m.get_string()
service = m.get_string()
method = m.get_string()
+ self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection':
self.disconnect_service_not_available()
return
if (self.auth_username is not None) and (self.auth_username != username):
- # trying to change username in mid-flight!
+ self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
self.disconnect_no_more_auth()
return
if method == 'none':
@@ -157,27 +160,27 @@ class Transport(BaseTransport):
if changereq:
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
+ self.log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string().decode('UTF-8')
- result = self.check_auth_password(username, password)
- result = (AUTH_FAILED, result[1])
+ result = self.AUTH_FAILED
else:
result = self.check_auth_password(username, password)
elif method == 'publickey':
# FIXME
result = self.check_auth_none(username)
- result = (AUTH_FAILED, result[1])
else:
result = self.check_auth_none(username)
- result = (AUTH_FAILED, result[1])
# okay, send result
m = Message()
- if result[0] == AUTH_SUCCESSFUL:
- m.add_byte(chr(MSG_USERAUTH_SUCCESSFUL))
+ if result == self.AUTH_SUCCESSFUL:
+ self.log(DEBUG, 'Auth granted.')
+ m.add_byte(chr(MSG_USERAUTH_SUCCESS))
self.auth_complete = 1
else:
+ self.log(DEBUG, 'Auth rejected.')
m.add_byte(chr(MSG_USERAUTH_FAILURE))
- m.add_string(result[1])
- if result[0] == AUTH_PARTIALLY_SUCCESSFUL:
+ m.add_string(self.get_allowed_auths(username))
+ if result == self.AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(1)
else:
m.add_boolean(0)
diff --git a/channel.py b/channel.py
index 275c0a26..8f53d379 100644
--- a/channel.py
+++ b/channel.py
@@ -1,7 +1,7 @@
from message import Message
from secsh import SSHException
from transport import MSG_CHANNEL_REQUEST, MSG_CHANNEL_CLOSE, MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, \
- MSG_CHANNEL_EOF
+ MSG_CHANNEL_EOF, MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE
import time, threading, logging, socket, os
from logging import DEBUG
@@ -18,9 +18,9 @@ class Channel(object):
Abstraction for a secsh channel.
"""
- def __init__(self, chanid, transport):
+ def __init__(self, chanid):
self.chanid = chanid
- self.transport = transport
+ self.transport = None
self.active = 0
self.eof_received = 0
self.eof_sent = 0
@@ -50,6 +50,9 @@ class Channel(object):
out += '>'
return out
+ def set_transport(self, transport):
+ self.transport = transport
+
def log(self, level, msg):
self.logger.log(level, msg)
@@ -60,8 +63,8 @@ class Channel(object):
self.in_window_threshold = window_size // 10
self.in_window_sofar = 0
- def set_server_channel(self, chanid, window_size, max_packet_size):
- self.server_chanid = chanid
+ def set_remote_channel(self, chanid, window_size, max_packet_size):
+ self.remote_chanid = chanid
self.out_window_size = window_size
self.out_max_packet_size = max_packet_size
self.active = 1
@@ -99,14 +102,29 @@ class Channel(object):
def handle_request(self, m):
key = m.get_string()
+ want_reply = m.get_boolean()
+ ok = False
if key == 'exit-status':
self.exit_status = m.get_int()
- return
+ ok = True
elif key == 'xon-xoff':
# ignore
- return
+ ok = True
+ elif (key == 'pty-req') or (key == 'shell'):
+ if self.transport.server_mode:
+ # humor them
+ ok = True
else:
self.log(DEBUG, 'Unhandled channel request "%s"' % key)
+ ok = False
+ if want_reply:
+ m = Message()
+ if ok:
+ m.add_byte(chr(MSG_CHANNEL_SUCCESS))
+ else:
+ m.add_byte(chr(MSG_CHANNEL_FAILURE))
+ m.add_int(self.remote_chanid)
+ self.transport.send_message(m)
def handle_eof(self, m):
self.eof_received = 1
@@ -140,7 +158,7 @@ class Channel(object):
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_string('pty-req')
m.add_boolean(0)
m.add_string(term)
@@ -156,7 +174,7 @@ class Channel(object):
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_string('shell')
m.add_boolean(1)
self.transport.send_message(m)
@@ -166,7 +184,7 @@ class Channel(object):
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_string('exec')
m.add_boolean(1)
m.add_string(command)
@@ -177,7 +195,7 @@ class Channel(object):
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_string('subsystem')
m.add_boolean(1)
m.add_string(subsystem)
@@ -188,7 +206,7 @@ class Channel(object):
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_string('window-change')
m.add_boolean(0)
m.add_int(width)
@@ -211,7 +229,7 @@ class Channel(object):
return
m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
self.transport.send_message(m)
self.eof_sent = 1
self.log(DEBUG, 'EOF sent')
@@ -238,7 +256,7 @@ class Channel(object):
self.send_eof()
m = Message()
m.add_byte(chr(MSG_CHANNEL_CLOSE))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
self.transport.send_message(m)
self.closed = 1
self.transport.unlink_channel(self.chanid)
@@ -316,7 +334,7 @@ class Channel(object):
size = self.out_max_packet_size
m = Message()
m.add_byte(chr(MSG_CHANNEL_DATA))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_string(s[:size])
self.transport.send_message(m)
self.out_window_size -= size
@@ -469,7 +487,7 @@ class Channel(object):
self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
- m.add_int(self.server_chanid)
+ m.add_int(self.remote_chanid)
m.add_int(self.in_window_sofar)
self.transport.send_message(m)
self.in_window_sofar = 0
@@ -490,7 +508,7 @@ class ChannelFile(object):
def __init__(self, channel, mode = "r", buf_size = -1):
self.channel = channel
self.mode = mode
- if buf_size < 0:
+ if buf_size <= 0:
self.buf_size = 1024
self.line_buffered = 0
elif buf_size == 1:
@@ -503,10 +521,12 @@ class ChannelFile(object):
self.rbuffer = ""
self.readable = ("r" in mode)
self.writable = ("w" in mode) or ("+" in mode) or ("a" in mode)
+ self.universal_newlines = ('U' in mode)
self.binary = ("b" in mode)
- if not self.binary:
- raise NotImplementedError("text mode not supported")
- self.softspace = 0
+ self.at_trailing_cr = False
+ self.name = '<file from ' + repr(self.channel) + '>'
+ self.newlines = None
+ self.softspace = False
def __iter__(self):
return self
@@ -570,23 +590,56 @@ class ChannelFile(object):
self.rbuffer[size:]
return result
- def readline(self, size = None):
- line = ""
- while "\n" not in line:
+ def readline(self, size=None):
+ line = self.rbuffer
+ while 1:
+ if self.at_trailing_cr and (len(line) > 0):
+ if line[0] == '\n':
+ line = line[1:]
+ self.at_trailing_cr = False
+ if self.universal_newlines:
+ if ('\n' in line) or ('\r' in line):
+ break
+ else:
+ if '\n' in line:
+ break
if size >= 0:
- new_data = self.read(size - len(line))
+ if len(line) >= size:
+ # truncate line and return
+ self.rbuffer = line[size:]
+ line = line[:size]
+ return line
+ n = size - len(line)
else:
- new_data = self.read(64)
+ n = 64
+ new_data = self.channel.recv(n)
if not new_data:
- break
+ self.rbuffer = ''
+ return line
line += new_data
- newline_pos = line.find("\n")
- if newline_pos >= 0:
- self.rbuffer = line[newline_pos+1:] + self.rbuffer
- return line[:newline_pos+1]
- elif len(line) > size:
- self.rbuffer = line[size:] + self.rbuffer
- return line[:size]
+ # find the newline
+ pos = line.find('\n')
+ if self.universal_newlines:
+ rpos = line.find('\r')
+ if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
+ pos = rpos
+ xpos = pos + 1
+ if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'):
+ xpos += 1
+ self.rbuffer = line[xpos:]
+ lf = line[pos:xpos]
+ line = line[:xpos]
+ if (len(self.rbuffer) == 0) and (lf == '\r'):
+ # we could read the line up to a '\r' and there could still be a
+ # '\n' following that we read next time. note that and eat it.
+ self.at_trailing_cr = True
+ # silliness about tracking what kinds of newlines we've seen
+ if self.newlines is None:
+ self.newlines = lf
+ elif (type(self.newlines) is str) and (self.newlines != lf):
+ self.newlines = (self.newlines, lf)
+ elif lf not in self.newlines:
+ self.newlines += (lf,)
return line
def readlines(self, sizehint = None):
diff --git a/demo-server.py b/demo-server.py
index 1db02230..b0f8326a 100755
--- a/demo-server.py
+++ b/demo-server.py
@@ -1,6 +1,6 @@
#!/usr/bin/python
-import sys, os, socket, threading, logging, traceback
+import sys, os, socket, threading, logging, traceback, time
import secsh
# setup logging
@@ -15,6 +15,19 @@ if len(l.handlers) == 0:
host_key = secsh.RSAKey()
host_key.read_private_key_file('demo-host-key')
+
+class ServerTransport(secsh.Transport):
+ def check_channel_request(self, kind, chanid):
+ if kind == 'session':
+ return secsh.Channel(chanid)
+ return self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+
+ def check_auth_password(self, username, password):
+ if (username == 'robey') and (password == 'foo'):
+ return self.AUTH_SUCCESSFUL
+ return self.AUTH_FAILED
+
+
# now connect
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -35,7 +48,7 @@ except Exception, e:
try:
event = threading.Event()
- t = secsh.Transport(client)
+ t = ServerTransport(client)
t.add_server_key(host_key)
t.ultra_debug = 1
t.start_server(event)
@@ -45,6 +58,18 @@ try:
print '*** SSH negotiation failed.'
sys.exit(1)
# print repr(t)
+
+ chan = t.accept()
+ time.sleep(2)
+ chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
+ chan.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
+ chan.send('Happy birthday to Robot Dave!\r\n\r\n')
+ chan.send('Username: ')
+ f = chan.makefile('rU')
+ username = f.readline().strip('\r\n')
+ chan.send('\r\nI don\'t like you, ' + username + '.\r\n')
+ chan.close()
+
except Exception, e:
print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e)
traceback.print_exc()
diff --git a/demo.py b/demo.py
index fc707e4a..069077d7 100755
--- a/demo.py
+++ b/demo.py
@@ -76,7 +76,7 @@ try:
# print repr(t)
keys = load_host_keys()
- keytype, hostkey = t.get_host_key()
+ keytype, hostkey = t.get_remote_server_key()
if not keys.has_key(hostname):
print '*** WARNING: Unknown host key!'
elif not keys[hostname].has_key(keytype):
diff --git a/transport.py b/transport.py
index 2020b279..c5ff252b 100644
--- a/transport.py
+++ b/transport.py
@@ -11,12 +11,11 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
-
import sys, os, string, threading, socket, logging, struct
from message import Message
from channel import Channel
from secsh import SSHException
-from util import format_binary, safe_string, inflate_long, deflate_long
+from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings
from rsakey import RSAKey
from dsskey import DSSKey
from kex_group1 import KexGroup1
@@ -105,6 +104,9 @@ class BaseTransport(threading.Thread):
REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30)
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, \
+ OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5)
+
def __init__(self, sock):
threading.Thread.__init__(self)
self.randpool = randpool
@@ -143,6 +145,8 @@ class BaseTransport(threading.Thread):
# server mode:
self.server_mode = 0
self.server_key_dict = { }
+ self.server_accepts = [ ]
+ self.server_accept_cv = threading.Condition(self.lock)
def start_client(self, event=None):
self.completion_event = event
@@ -196,7 +200,7 @@ class BaseTransport(threading.Thread):
for chan in self.channels.values():
chan.unlink()
- def get_host_key(self):
+ def get_remote_server_key(self):
'returns (type, key) where type is like "ssh-rsa" and key is an opaque string'
if (not self.active) or (not self.initial_kex_done):
raise SSHException('No existing session')
@@ -225,8 +229,9 @@ class BaseTransport(threading.Thread):
m.add_int(chanid)
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
- self.channels[chanid] = chan = Channel(chanid, self)
+ self.channels[chanid] = chan = Channel(chanid)
self.channel_events[chanid] = event = threading.Event()
+ chan.set_transport(self)
chan.set_window(self.window_size, self.max_packet_size)
self.send_message(m)
finally:
@@ -445,10 +450,12 @@ class BaseTransport(threading.Thread):
self.send_message(msg)
except SSHException, e:
self.log(DEBUG, 'Exception: ' + str(e))
+ self.log(DEBUG, tb_strings())
except EOFError, e:
self.log(DEBUG, 'EOF')
except Exception, e:
self.log(DEBUG, 'Unknown exception: ' + str(e))
+ self.log(DEBUG, tb_strings())
if self.active:
self.active = 0
if self.completion_event != None:
@@ -503,7 +510,11 @@ class BaseTransport(threading.Thread):
comment = buffer[i+1:]
buffer = buffer[:i]
# parse out version string and make sure it matches
- _unused, version, client = string.split(buffer, '-')
+ segs = buffer.split('-', 2)
+ if len(segs) < 3:
+ raise SSHException('Invalid SSH banner')
+ version = segs[1]
+ client = segs[2]
if version != '1.99' and version != '2.0':
raise SSHException('Incompatible version (%s instead of 2.0)' % (version,))
self.log(INFO, 'Connected (version %s, client %s)' % (version, client))
@@ -681,6 +692,7 @@ class BaseTransport(threading.Thread):
code = m.get_int()
desc = m.get_string()
self.log(INFO, 'Disconnect (code %d): %s' % (code, desc))
+
def parse_channel_open_success(self, m):
chanid = m.get_int()
server_chanid = m.get_int()
@@ -692,7 +704,7 @@ class BaseTransport(threading.Thread):
try:
self.lock.acquire()
chan = self.channels[chanid]
- chan.set_server_channel(server_chanid, server_window_size, server_max_packet_size)
+ chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self.log(INFO, 'Secsh channel %d opened.' % chanid)
if self.channel_events.has_key(chanid):
self.channel_events[chanid].set()
@@ -719,20 +731,85 @@ class BaseTransport(threading.Thread):
self.channel_events[chanid].set()
del self.channel_events[chanid]
finally:
- self.lock_release()
+ self.lock.release()
return
+ def check_channel_request(self, kind, chanid):
+ "override me! return object descended from Channel to allow, or None to reject"
+ return None
+
def parse_channel_open(self, m):
kind = m.get_string()
- self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
chanid = m.get_int()
- msg = Message()
- msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
- msg.add_int(chanid)
- msg.add_int(1)
- msg.add_string('Client connections are not allowed.')
- msg.add_string('en')
- self.send_message(msg)
+ initial_window_size = m.get_int()
+ max_packet_size = m.get_int()
+ reject = False
+ if not self.server_mode:
+ self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
+ reject = True
+ reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ else:
+ try:
+ self.lock.acquire()
+ my_chanid = self.channel_counter
+ self.channel_counter += 1
+ finally:
+ self.lock.release()
+ chan = self.check_channel_request(kind, my_chanid)
+ if (chan is None) or (type(chan) is int):
+ self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
+ reject = True
+ if type(chan) is int:
+ reason = chan
+ else:
+ reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ if reject:
+ msg = Message()
+ msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
+ msg.add_int(chanid)
+ msg.add_int(reason)
+ msg.add_string('')
+ msg.add_string('en')
+ self.send_message(msg)
+ return
+ try:
+ self.lock.acquire()
+ self.channels[my_chanid] = chan
+ chan.set_transport(self)
+ chan.set_window(self.window_size, self.max_packet_size)
+ chan.set_remote_channel(chanid, initial_window_size, max_packet_size)
+ finally:
+ self.lock.release()
+ m = Message()
+ m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS))
+ m.add_int(chanid)
+ m.add_int(my_chanid)
+ m.add_int(self.window_size)
+ m.add_int(self.max_packet_size)
+ self.send_message(m)
+ self.log(INFO, 'Secsh channel %d opened.' % my_chanid)
+ try:
+ self.lock.acquire()
+ self.server_accepts.append(chan)
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
+
+ def accept(self, timeout=None):
+ try:
+ self.lock.acquire()
+ if len(self.server_accepts) > 0:
+ chan = self.server_accepts.pop(0)
+ else:
+ self.server_accept_cv.wait(timeout)
+ if len(self.server_accepts) > 0:
+ chan = self.server_accepts.pop(0)
+ else:
+ # timeout
+ chan = None
+ finally:
+ self.lock.release()
+ return chan
def parse_debug(self, m):
always_display = m.get_boolean()