summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRobey Pointer <robey@lag.net>2005-05-01 08:04:59 +0000
committerRobey Pointer <robey@lag.net>2005-05-01 08:04:59 +0000
commit36055c5ac2bd786a21aa05d248935a77a8fbccec (patch)
treed14484cff9237385fed9af7e2091524691365187
parent2f2d7bdee88c9f9b14dc2495fb77d7abd1587d64 (diff)
[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-5]
split out Packetizer, fix banner detection bug, new unit test split out a chunk of BaseTransport into a Packetizer class, which handles the in/out packet data, ciphers, etc. it didn't make the code any smaller (transport.py is still close to 1500 lines, which is awful) but it did split out a coherent chunk of functionality into a discrete unit. in the process, fixed a bug that alain spineux pointed out: the banner check was too forgiving and would block forever waiting for an SSH banner. now it waits 5 seconds for the first line, and 2 seconds for each subsequent line, before giving up. added a unit test to test keepalive, since i wasn't sure that was still working after pulling out Packetizer.
-rw-r--r--README2
-rw-r--r--paramiko/channel.py5
-rw-r--r--paramiko/packet.py401
-rw-r--r--paramiko/sftp_client.py2
-rw-r--r--paramiko/sftp_server.py2
-rw-r--r--paramiko/transport.py340
-rwxr-xr-xtest.py6
-rwxr-xr-xtests/test_sftp.py34
-rw-r--r--tests/test_transport.py46
9 files changed, 553 insertions, 285 deletions
diff --git a/README b/README
index ffecafec..01b89396 100644
--- a/README
+++ b/README
@@ -231,3 +231,5 @@ v0.9 FEAROW
* would be nice to have an ftp-like interface to sftp (put, get, chdir...)
+* speed up file transfers!
+* what is psyco?
diff --git a/paramiko/channel.py b/paramiko/channel.py
index c6915e1a..cd866c09 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -298,7 +298,6 @@ class Channel (object):
m.add_boolean(0)
m.add_int(status)
self.transport._send_user_message(m)
- self._log(DEBUG, 'EXIT-STATUS')
def get_transport(self):
"""
@@ -468,7 +467,7 @@ class Channel (object):
it means you may need to wait before more data arrives.
@return: C{True} if a L{recv} call on this channel would immediately
- return at least one byte; C{False} otherwise.
+ return at least one byte; C{False} otherwise.
@rtype: boolean
"""
self.lock.acquire()
@@ -492,7 +491,7 @@ class Channel (object):
@rtype: str
@raise socket.timeout: if no data is ready before the timeout set by
- L{settimeout}.
+ L{settimeout}.
"""
out = ''
self.lock.acquire()
diff --git a/paramiko/packet.py b/paramiko/packet.py
new file mode 100644
index 00000000..d93227cb
--- /dev/null
+++ b/paramiko/packet.py
@@ -0,0 +1,401 @@
+#!/usr/bin/python
+
+# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+Packetizer.
+"""
+
+import select, socket, struct, threading, time
+from Crypto.Hash import HMAC
+from common import *
+from message import Message
+import util
+
+
+class Packetizer (object):
+ """
+ Implementation of the base SSH packet protocol.
+ """
+
+ # READ the secsh RFC's before raising these values. if anything,
+ # they should probably be lower.
+ REKEY_PACKETS = pow(2, 30)
+ REKEY_BYTES = pow(2, 30)
+
+ def __init__(self, socket):
+ self.__socket = socket
+ self.__logger = None
+ self.__closed = False
+ self.__dump_packets = False
+ self.__need_rekey = False
+
+ # used for noticing when to re-key:
+ self.__sent_bytes = 0
+ self.__sent_packets = 0
+ self.__received_bytes = 0
+ self.__received_packets = 0
+ self.__received_packets_overflow = 0
+
+ # current inbound/outbound ciphering:
+ self.__block_size_out = 8
+ self.__block_size_in = 8
+ self.__mac_size_out = 0
+ self.__mac_size_in = 0
+ self.__block_engine_out = None
+ self.__block_engine_in = None
+ self.__mac_engine_out = None
+ self.__mac_engine_in = None
+ self.__mac_key_out = ''
+ self.__mac_key_in = ''
+ self.__sequence_number_out = 0L
+ self.__sequence_number_in = 0L
+
+ # lock around outbound writes (packet computation)
+ self.__write_lock = threading.RLock()
+
+ # keepalives:
+ self.__keepalive_interval = 0
+ self.__keepalive_last = time.time()
+ self.__keepalive_callback = None
+
+
+ def set_log(self, log):
+ """
+ Set the python log object to use for logging.
+ """
+ self.__logger = log
+
+ def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
+ """
+ Switch outbound data cipher.
+ """
+ self.__block_engine_out = block_engine
+ self.__block_size_out = block_size
+ self.__mac_engine_out = mac_engine
+ self.__mac_size_out = mac_size
+ self.__mac_key_out = mac_key
+ self.__sent_bytes = 0
+ self.__sent_packets = 0
+ self.__need_rekey = False
+
+ def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
+ """
+ Switch inbound data cipher.
+ """
+ self.__block_engine_in = block_engine
+ self.__block_size_in = block_size
+ self.__mac_engine_in = mac_engine
+ self.__mac_size_in = mac_size
+ self.__mac_key_in = mac_key
+ self.__received_bytes = 0
+ self.__received_packets = 0
+ self.__received_packets_overflow = 0
+ self.__need_rekey = False
+
+ def close(self):
+ self.__closed = True
+ self.__block_engine_in = None
+ self.__block_engine_out = None
+ self.__socket.close()
+
+ def set_hexdump(self, hexdump):
+ self.__dump_packets = hexdump
+
+ def get_hexdump(self):
+ return self.__dump_packets
+
+ def get_mac_size_in(self):
+ return self.__mac_size_in
+
+ def get_mac_size_out(self):
+ return self.__mac_size_out
+
+ def need_rekey(self):
+ """
+ Returns C{True} if a new set of keys needs to be negotiated. This
+ will be triggered during a packet read or write, so it should be
+ checked after every read or write, or at least after every few.
+
+ @return: C{True} if a new set of keys needs to be negotiated
+ """
+ return self.__need_rekey
+
+ def set_keepalive(self, interval, callback):
+ """
+ Turn on/off the callback keepalive. If C{interval} seconds pass with
+ no data read from or written to the socket, the callback will be
+ executed and the timer will be reset.
+ """
+ self.__keepalive_interval = interval
+ self.__keepalive_callback = callback
+ self.__keepalive_last = time.time()
+ self._log(DEBUG, 'SET KEEPALIVE %r' % interval)
+
+ def read_all(self, n):
+ """
+ Read as close to N bytes as possible, blocking as long as necessary.
+
+ @param n: number of bytes to read
+ @type n: int
+ @return: the data read
+ @rtype: str
+ @throw EOFError: if the socket was closed before all the bytes could
+ be read
+ """
+ if PY22:
+ return self._py22_read_all(n)
+ out = ''
+ while n > 0:
+ try:
+ x = self.__socket.recv(n)
+ if len(x) == 0:
+ raise EOFError()
+ out += x
+ n -= len(x)
+ except socket.timeout:
+ if self.__closed:
+ raise EOFError()
+ self._check_keepalive()
+ return out
+
+ def write_all(self, out):
+ self.__keepalive_last = time.time()
+ while len(out) > 0:
+ try:
+ n = self.__socket.send(out)
+ except socket.timeout:
+ n = 0
+ if self.__closed:
+ n = -1
+ except Exception, x:
+ # could be: (32, 'Broken pipe')
+ n = -1
+ if n < 0:
+ raise EOFError()
+ if n == len(out):
+ return
+ out = out[n:]
+ return
+
+ def readline(self, timeout):
+ """
+ Read a line from the socket. This is done in a fairly inefficient
+ way, but is only used for initial banner negotiation so it's not worth
+ optimising.
+ """
+ buffer = ''
+ while not '\n' in buffer:
+ buffer += self._read_timeout(timeout)
+ buffer = buffer[:-1]
+ if (len(buffer) > 0) and (buffer[-1] == '\r'):
+ buffer = buffer[:-1]
+ return buffer
+
+ def send_message(self, data):
+ """
+ Write a block of data using the current cipher, as an SSH block.
+ """
+ # encrypt this sucka
+ data = str(data)
+ cmd = ord(data[0])
+ if cmd in MSG_NAMES:
+ cmd_name = MSG_NAMES[cmd]
+ else:
+ cmd_name = '$%x' % cmd
+ self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
+ packet = self._build_packet(data)
+ if self.__dump_packets:
+ self._log(DEBUG, util.format_binary(packet, 'OUT: '))
+ if self.__block_engine_out != None:
+ out = self.__block_engine_out.encrypt(packet)
+ else:
+ out = packet
+ # + mac
+ self.__write_lock.acquire()
+ try:
+ if self.__block_engine_out != None:
+ payload = struct.pack('>I', self.__sequence_number_out) + packet
+ out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out]
+ self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
+ self.write_all(out)
+
+ self.__sent_bytes += len(out)
+ self.__sent_packets += 1
+ if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \
+ and not self.__need_rekey:
+ # only ask once for rekeying
+ self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
+ (self.__sent_packets, self.__sent_bytes))
+ self.__received_packets_overflow = 0
+ self._trigger_rekey()
+ finally:
+ self.__write_lock.release()
+
+ def read_message(self):
+ """
+ Only one thread should ever be in this function (no other locking is
+ done).
+
+ @throw SSHException: if the packet is mangled
+ """
+ header = self.read_all(self.__block_size_in)
+ if self.__block_engine_in != None:
+ header = self.__block_engine_in.decrypt(header)
+ if self.__dump_packets:
+ self._log(DEBUG, util.format_binary(header, 'IN: '));
+ packet_size = struct.unpack('>I', header[:4])[0]
+ # leftover contains decrypted bytes from the first block (after the length field)
+ leftover = header[4:]
+ if (packet_size - len(leftover)) % self.__block_size_in != 0:
+ raise SSHException('Invalid packet blocking')
+ buffer = self.read_all(packet_size + self.__mac_size_in - len(leftover))
+ packet = buffer[:packet_size - len(leftover)]
+ post_packet = buffer[packet_size - len(leftover):]
+ if self.__block_engine_in != None:
+ packet = self.__block_engine_in.decrypt(packet)
+ if self.__dump_packets:
+ self._log(DEBUG, util.format_binary(packet, 'IN: '));
+ packet = leftover + packet
+
+ if self.__mac_size_in > 0:
+ mac = post_packet[:self.__mac_size_in]
+ mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet
+ my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in]
+ if my_mac != mac:
+ raise SSHException('Mismatched MAC')
+ padding = ord(packet[0])
+ payload = packet[1:packet_size - padding + 1]
+ randpool.add_event(packet[packet_size - padding + 1])
+ if self.__dump_packets:
+ self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
+
+ msg = Message(payload[1:])
+ msg.seqno = self.__sequence_number_in
+ self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL
+
+ # check for rekey
+ self.__received_bytes += packet_size + self.__mac_size_in + 4
+ self.__received_packets += 1
+ if self.__need_rekey:
+ # we've asked to rekey -- give them 20 packets to comply before
+ # dropping the connection
+ self.__received_packets_overflow += 1
+ if self.__received_packets_overflow >= 20:
+ raise SSHException('Remote transport is ignoring rekey requests')
+ elif (self.__received_packets >= self.REKEY_PACKETS) or \
+ (self.__received_bytes >= self.REKEY_BYTES):
+ # only ask once for rekeying
+ self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
+ (self.__received_packets, self.__received_bytes))
+ self.__received_packets_overflow = 0
+ self._trigger_rekey()
+
+ cmd = ord(payload[0])
+ if cmd in MSG_NAMES:
+ cmd_name = MSG_NAMES[cmd]
+ else:
+ cmd_name = '$%x' % cmd
+ self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
+ return cmd, msg
+
+
+ ########## protected
+
+
+ def _log(self, level, msg):
+ if self.__logger is None:
+ return
+ if issubclass(type(msg), list):
+ for m in msg:
+ self.__logger.log(level, m)
+ else:
+ self.__logger.log(level, msg)
+
+ def _check_keepalive(self):
+ if (not self.__keepalive_interval) or (not self.__block_engine_out) or \
+ self.__need_rekey:
+ # wait till we're encrypting, and not in the middle of rekeying
+ return
+ now = time.time()
+ if now > self.__keepalive_last + self.__keepalive_interval:
+ self.__keepalive_callback()
+ self.__keepalive_last = now
+
+ def _py22_read_all(self, n):
+ out = ''
+ while n > 0:
+ r, w, e = select.select([self.__socket], [], [], 0.1)
+ if self.__socket not in r:
+ if self.__closed:
+ raise EOFError()
+ self._check_keepalive()
+ else:
+ x = self.__socket.recv(n)
+ if len(x) == 0:
+ raise EOFError()
+ out += x
+ n -= len(x)
+ return out
+
+ def _py22_read_timeout(self, timeout):
+ start = time.time()
+ while True:
+ r, w, e = select.select([self.__socket], [], [], 0.1)
+ if self.__socket in r:
+ x = self.__socket.recv(1)
+ if len(x) == 0:
+ raise EOFError()
+ return x
+ if self.__closed:
+ raise EOFError()
+ now = time.time()
+ if now - start >= timeout:
+ raise socket.timeout()
+
+ def _read_timeout(self, timeout):
+ if PY22:
+ return self._py22_read_timeout(n)
+ start = time.time()
+ while True:
+ try:
+ x = self.__socket.recv(1)
+ if len(x) == 0:
+ raise EOFError()
+ return x
+ except socket.timeout:
+ pass
+ if self.__closed:
+ raise EOFError()
+ now = time.time()
+ if now - start >= timeout:
+ raise socket.timeout()
+
+ def _build_packet(self, payload):
+ # pad up at least 4 bytes, to nearest block-size (usually 8)
+ bsize = self.__block_size_out
+ padding = 3 + bsize - ((len(payload) + 8) % bsize)
+ packet = struct.pack('>IB', len(payload) + padding + 1, padding)
+ packet += payload
+ packet += randpool.get_bytes(padding)
+ return packet
+
+ def _trigger_rekey(self):
+ # outside code should check for this flag
+ self.__need_rekey = True
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index b29f71e7..fcf77063 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -53,7 +53,7 @@ class SFTPClient (BaseSFTP):
transport = self.sock.get_transport()
self.logger = util.get_logger(transport.get_log_channel() + '.' +
self.sock.get_name() + '.sftp')
- self.ultra_debug = transport.ultra_debug
+ self.ultra_debug = transport.get_hexdump()
self._send_version()
def from_transport(selfclass, t):
diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py
index 67971848..94a9e6c4 100644
--- a/paramiko/sftp_server.py
+++ b/paramiko/sftp_server.py
@@ -60,7 +60,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
transport = channel.get_transport()
self.logger = util.get_logger(transport.get_log_channel() + '.' +
channel.get_name() + '.sftp')
- self.ultra_debug = transport.ultra_debug
+ self.ultra_debug = transport.get_hexdump()
self.next_handle = 1
# map of handle-string to SFTPHandle for files & folders:
self.file_table = { }
diff --git a/paramiko/transport.py b/paramiko/transport.py
index da252ce7..b436847c 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -1,5 +1,3 @@
-#!/usr/bin/python
-
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
@@ -30,6 +28,7 @@ from message import Message
from channel import Channel
from sftp_client import SFTPClient
import util
+from packet import Packetizer
from rsakey import RSAKey
from dsskey import DSSKey
from kex_group1 import KexGroup1
@@ -49,7 +48,7 @@ from Crypto.Hash import SHA, MD5, HMAC
_active_threads = []
def _join_lingering_threads():
for thr in _active_threads:
- thr.active = False
+ thr.stop_thread()
import atexit
atexit.register(_join_lingering_threads)
@@ -162,10 +161,6 @@ class BaseTransport (threading.Thread):
'diffie-hellman-group-exchange-sha1': KexGex,
}
- # READ the secsh RFC's before raising these values. if anything,
- # they should probably be lower.
- REKEY_PACKETS = pow(2, 30)
- REKEY_BYTES = pow(2, 30)
_modulus_pack = None
@@ -222,42 +217,29 @@ class BaseTransport (threading.Thread):
except AttributeError:
pass
# negotiated crypto parameters
+ self.packetizer = Packetizer(sock)
self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
self.remote_version = ''
- self.block_size_out = self.block_size_in = 8
- self.local_mac_len = self.remote_mac_len = 0
- self.engine_in = self.engine_out = None
self.local_cipher = self.remote_cipher = ''
- self.sequence_number_in = self.sequence_number_out = 0L
self.local_kex_init = self.remote_kex_init = None
self.session_id = None
# /negotiated crypto parameters
self.expected_packet = 0
self.active = False
self.initial_kex_done = False
- self.write_lock = threading.RLock() # lock around outbound writes (packet computation)
- self.lock = threading.Lock() # synchronization (always higher level than write_lock)
- self.channels = { } # (id -> Channel)
- self.channel_events = { } # (id -> Event)
+ self.lock = threading.Lock() # synchronization (always higher level than write_lock)
+ self.channels = { } # (id -> Channel)
+ self.channel_events = { } # (id -> Event)
self.channel_counter = 1
self.window_size = 65536
self.max_packet_size = 32768
- self.ultra_debug = False
self.saved_exception = None
self.clear_to_send = threading.Event()
self.log_name = 'paramiko.transport'
self.logger = util.get_logger(self.log_name)
- # used for noticing when to re-key:
- self.received_bytes = 0
- self.received_packets = 0
- self.received_packets_overflow = 0
- self.sent_bytes = 0
- self.sent_packets = 0
+ self.packetizer.set_log(self.logger)
# user-defined event callbacks:
self.completion_event = None
- # keepalives:
- self.keepalive_interval = 0
- self.keepalive_last = time.time()
# server mode:
self.server_mode = False
self.server_object = None
@@ -293,7 +275,7 @@ class BaseTransport (threading.Thread):
preference for them.
@return: an object that can be used to change the preferred algorithms
- for encryption, digest (hash), public key, and key exchange.
+ for encryption, digest (hash), public key, and key exchange.
@rtype: L{SecurityOptions}
@since: ivysaur
@@ -316,8 +298,8 @@ class BaseTransport (threading.Thread):
@note: L{connect} is a simpler method for connecting as a client.
@note: After calling this method (or L{start_server} or L{connect}),
- you should no longer directly read from or write to the original socket
- object.
+ you should no longer directly read from or write to the original
+ socket object.
@param event: an event to trigger when negotiation is complete.
@type event: threading.Event
@@ -350,13 +332,13 @@ class BaseTransport (threading.Thread):
given C{server} object to allow channels to be opened.
@note: After calling this method (or L{start_client} or L{connect}),
- you should no longer directly read from or write to the original socket
- object.
+ you should no longer directly read from or write to the original
+ socket object.
@param event: an event to trigger when negotiation is complete.
@type event: threading.Event
@param server: an object used to perform authentication and create
- L{Channel}s.
+ L{Channel}s.
@type server: L{server.ServerInterface}
"""
if server is None:
@@ -376,7 +358,7 @@ class BaseTransport (threading.Thread):
key info, not just the public half.
@param key: the host key to add, usually an L{RSAKey <rsakey.RSAKey>} or
- L{DSSKey <dsskey.DSSKey>}.
+ L{DSSKey <dsskey.DSSKey>}.
@type key: L{PKey <pkey.PKey>}
"""
self.server_key_dict[key.get_name()] = key
@@ -418,10 +400,10 @@ class BaseTransport (threading.Thread):
support that method of key negotiation.
@param filename: optional path to the moduli file, if you happen to
- know that it's not in a standard location.
+ know that it's not in a standard location.
@type filename: str
@return: True if a moduli file was successfully loaded; False
- otherwise.
+ otherwise.
@rtype: bool
@since: doduo
@@ -449,8 +431,7 @@ class BaseTransport (threading.Thread):
Close this session, and any open channels that are tied to it.
"""
self.active = False
- self.engine_in = self.engine_out = None
- self.sequence_number_in = self.sequence_number_out = 0L
+ self.packetizer.close()
for chan in self.channels.values():
chan._unlink()
@@ -459,9 +440,9 @@ class BaseTransport (threading.Thread):
Return the host key of the server (in client mode).
@note: Previously this call returned a tuple of (key type, key string).
- You can get the same effect by calling
- L{PKey.get_name <pkey.PKey.get_name>} for the key type, and C{str(key)}
- for the key string.
+ You can get the same effect by calling
+ L{PKey.get_name <pkey.PKey.get_name>} for the key type, and
+ C{str(key)} for the key string.
@raise SSHException: if no session is currently active.
@@ -476,7 +457,8 @@ class BaseTransport (threading.Thread):
"""
Return true if this session is active (open).
- @return: True if the session is still active (open); False if the session is closed.
+ @return: True if the session is still active (open); False if the
+ session is closed.
@rtype: bool
"""
return self.active
@@ -487,7 +469,7 @@ class BaseTransport (threading.Thread):
is just an alias for C{open_channel('session')}.
@return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ rejected or the session ends prematurely.
@rtype: L{Channel}
"""
return self.open_channel('session')
@@ -500,17 +482,17 @@ class BaseTransport (threading.Thread):
L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"},
- C{"forwarded-tcpip"} or C{"direct-tcpip"}).
+ C{"forwarded-tcpip"} or C{"direct-tcpip"}).
@type kind: str
@param dest_addr: the destination address of this port forwarding,
- if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
- for other channel types).
+ if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
+ for other channel types).
@type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if
- C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
+ C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
@type src_addr: (str, int)
@return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ rejected or the session ends prematurely.
@rtype: L{Channel}
"""
chan = None
@@ -599,7 +581,7 @@ class BaseTransport (threading.Thread):
session has died mid-negotiation.
@return: True if the renegotiation was successful, and the link is
- using new keys; False if the session dropped during renegotiation.
+ using new keys; False if the session dropped during renegotiation.
@rtype: bool
"""
self.completion_event = threading.Event()
@@ -620,12 +602,13 @@ class BaseTransport (threading.Thread):
can be useful to keep connections alive over a NAT, for example.
@param interval: seconds to wait before sending a keepalive packet (or
- 0 to disable keepalives).
+ 0 to disable keepalives).
@type interval: int
@since: fearow
"""
- self.keepalive_interval = interval
+ self.packetizer.set_keepalive(interval,
+ lambda x=self: x.global_request('keepalive@lag.net', wait=False))
def global_request(self, kind, data=None, wait=True):
"""
@@ -635,14 +618,14 @@ class BaseTransport (threading.Thread):
@param kind: name of the request.
@type kind: str
@param data: an optional tuple containing additional data to attach
- to the request.
+ to the request.
@type data: tuple
@param wait: C{True} if this method should not return until a response
- is received; C{False} otherwise.
+ is received; C{False} otherwise.
@type wait: bool
@return: a L{Message} containing possible additional data if the
- request was successful (or an empty L{Message} if C{wait} was
- C{False}); C{None} if the request was denied.
+ request was successful (or an empty L{Message} if C{wait} was
+ C{False}); C{None} if the request was denied.
@rtype: L{Message}
@since: fearow
@@ -833,7 +816,23 @@ class BaseTransport (threading.Thread):
C{False} otherwise.
@type hexdump: bool
"""
- self.ultra_debug = hexdump
+ self.packetizer.set_hexdump(hexdump)
+
+ def get_hexdump(self):
+ """
+ Return C{True} if the transport is currently logging hex dumps of
+ protocol traffic.
+
+ @return: C{True} if hex dumps are being logged
+ @rtype: bool
+
+ @since: 1.4
+ """
+ return self.packetizer.get_hexdump()
+
+ def stop_thread(self):
+ self.active = False
+ self.packetizer.close()
### internals...
@@ -859,113 +858,10 @@ class BaseTransport (threading.Thread):
finally:
self.lock.release()
- def _check_keepalive(self):
- if (not self.keepalive_interval) or (not self.initial_kex_done):
- return
- now = time.time()
- if now > self.keepalive_last + self.keepalive_interval:
- self.global_request('keepalive@lag.net', wait=False)
-
- def _py22_read_all(self, n):
- out = ''
- while n > 0:
- r, w, e = select.select([self.sock], [], [], 0.1)
- if self.sock not in r:
- if not self.active:
- raise EOFError()
- self._check_keepalive()
- else:
- x = self.sock.recv(n)
- if len(x) == 0:
- raise EOFError()
- out += x
- n -= len(x)
- return out
-
- def _read_all(self, n):
- if PY22:
- return self._py22_read_all(n)
- out = ''
- while n > 0:
- try:
- x = self.sock.recv(n)
- if len(x) == 0:
- raise EOFError()
- out += x
- n -= len(x)
- except socket.timeout:
- if not self.active:
- raise EOFError()
- self._check_keepalive()
- return out
-
- def _write_all(self, out):
- self.keepalive_last = time.time()
- while len(out) > 0:
- try:
- n = self.sock.send(out)
- except socket.timeout:
- n = 0
- if not self.active:
- n = -1
- except Exception, x:
- # could be: (32, 'Broken pipe')
- n = -1
- if n < 0:
- raise EOFError()
- if n == len(out):
- return
- out = out[n:]
- return
-
- def _build_packet(self, payload):
- # pad up at least 4 bytes, to nearest block-size (usually 8)
- bsize = self.block_size_out
- padding = 3 + bsize - ((len(payload) + 8) % bsize)
- packet = struct.pack('>I', len(payload) + padding + 1)
- packet += chr(padding)
- packet += payload
- packet += randpool.get_bytes(padding)
- return packet
-
def _send_message(self, data):
- # encrypt this sucka
- data = str(data)
- cmd = ord(data[0])
- if cmd in MSG_NAMES:
- cmd_name = MSG_NAMES[cmd]
- else:
- cmd_name = '$%x' % cmd
- self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
- packet = self._build_packet(data)
- if self.ultra_debug:
- self._log(DEBUG, util.format_binary(packet, 'OUT: '))
- if self.engine_out != None:
- out = self.engine_out.encrypt(packet)
- else:
- out = packet
- # + mac
- try:
- self.write_lock.acquire()
- if self.engine_out != None:
- payload = struct.pack('>I', self.sequence_number_out) + packet
- out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len]
- self.sequence_number_out += 1L
- self.sequence_number_out %= 0x100000000L
- self._write_all(out)
-
- self.sent_bytes += len(out)
- self.sent_packets += 1
- if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \
- and (self.local_kex_init is None):
- # only ask once for rekeying
- self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
- (self.sent_packets, self.sent_bytes))
- self.received_packets_overflow = 0
- # this may do a recursive lock, but that's okay:
- self._send_kex_init()
- finally:
- self.write_lock.release()
+ self.packetizer.send_message(data)
+ if self.packetizer.need_rekey():
+ self._send_kex_init()
def _send_user_message(self, data):
"""
@@ -981,65 +877,6 @@ class BaseTransport (threading.Thread):
break
self._send_message(data)
- def _read_message(self):
- "only one thread will ever be in this function"
- header = self._read_all(self.block_size_in)
- if self.engine_in != None:
- header = self.engine_in.decrypt(header)
- if self.ultra_debug:
- self._log(DEBUG, util.format_binary(header, 'IN: '));
- packet_size = struct.unpack('>I', header[:4])[0]
- # leftover contains decrypted bytes from the first block (after the length field)
- leftover = header[4:]
- if (packet_size - len(leftover)) % self.block_size_in != 0:
- raise SSHException('Invalid packet blocking')
- buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover))
- packet = buffer[:packet_size - len(leftover)]
- post_packet = buffer[packet_size - len(leftover):]
- if self.engine_in != None:
- packet = self.engine_in.decrypt(packet)
- if self.ultra_debug:
- self._log(DEBUG, util.format_binary(packet, 'IN: '));
- packet = leftover + packet
- if self.remote_mac_len > 0:
- mac = post_packet[:self.remote_mac_len]
- mac_payload = struct.pack('>II', self.sequence_number_in, packet_size) + packet
- my_mac = HMAC.HMAC(self.mac_key_in, mac_payload, self.remote_mac_engine).digest()[:self.remote_mac_len]
- if my_mac != mac:
- raise SSHException('Mismatched MAC')
- padding = ord(packet[0])
- payload = packet[1:packet_size - padding + 1]
- randpool.add_event(packet[packet_size - padding + 1])
- if self.ultra_debug:
- self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
- msg = Message(payload[1:])
- msg.seqno = self.sequence_number_in
- self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL
- # check for rekey
- self.received_bytes += packet_size + self.remote_mac_len + 4
- self.received_packets += 1
- if self.local_kex_init is not None:
- # we've asked to rekey -- give them 20 packets to comply before
- # dropping the connection
- self.received_packets_overflow += 1
- if self.received_packets_overflow >= 20:
- raise SSHException('Remote transport is ignoring rekey requests')
- elif (self.received_packets >= self.REKEY_PACKETS) or \
- (self.received_bytes >= self.REKEY_BYTES):
- # only ask once for rekeying
- self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
- (self.received_packets, self.received_bytes))
- self.received_packets_overflow = 0
- self._send_kex_init()
-
- cmd = ord(payload[0])
- if cmd in MSG_NAMES:
- cmd_name = MSG_NAMES[cmd]
- else:
- cmd_name = '$%x' % cmd
- self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
- return cmd, msg
-
def _set_K_H(self, k, h):
"used by a kex object to set the K (root key) and H (exchange hash)"
self.K = k
@@ -1090,18 +927,21 @@ class BaseTransport (threading.Thread):
else:
self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL))
try:
- self._write_all(self.local_version + '\r\n')
+ self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner()
self._send_kex_init()
self.expected_packet = MSG_KEXINIT
while self.active:
- ptype, m = self._read_message()
+ if self.packetizer.need_rekey():
+ self._send_kex_init()
+ ptype, m = self.packetizer.read_message()
if ptype == MSG_IGNORE:
continue
elif ptype == MSG_DISCONNECT:
self._parse_disconnect(m)
self.active = False
+ self.packetizer.close()
break
elif ptype == MSG_DEBUG:
self._parse_debug(m)
@@ -1123,6 +963,7 @@ class BaseTransport (threading.Thread):
else:
self._log(ERROR, 'Channel request for unknown channel %d' % chanid)
self.active = False
+ self.packetizer.close()
else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message()
@@ -1146,6 +987,7 @@ class BaseTransport (threading.Thread):
chan._unlink()
if self.active:
self.active = False
+ self.packetizer.close()
if self.completion_event != None:
self.completion_event.set()
if self.auth_event != None:
@@ -1170,12 +1012,15 @@ class BaseTransport (threading.Thread):
def _check_banner(self):
# this is slow, but we only have to do it once
for i in range(5):
- buffer = ''
- while not '\n' in buffer:
- buffer += self._read_all(1)
- buffer = buffer[:-1]
- if (len(buffer) > 0) and (buffer[-1] == '\r'):
- buffer = buffer[:-1]
+ # give them 5 seconds for the first line, then just 2 seconds each additional line
+ if i == 0:
+ timeout = 5
+ else:
+ timeout = 2
+ try:
+ buffer = self.packetizer.readline(timeout)
+ except Exception, x:
+ raise SSHException('Error reading SSH protocol banner' + str(x))
if buffer[:4] == 'SSH-':
break
self._log(DEBUG, 'Banner: ' + buffer)
@@ -1236,13 +1081,6 @@ class BaseTransport (threading.Thread):
self._send_message(m)
def _parse_kex_init(self, m):
- # reset counters of when to re-key, since we are now re-keying
- self.received_bytes = 0
- self.received_packets = 0
- self.received_packets_overflow = 0
- self.sent_bytes = 0
- self.sent_packets = 0
-
cookie = m.get_bytes(16)
kex_algo_list = m.get_list()
server_key_algo_list = m.get_list()
@@ -1334,44 +1172,46 @@ class BaseTransport (threading.Thread):
def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic"
- self.block_size_in = self._cipher_info[self.remote_cipher]['block-size']
+ block_size = self._cipher_info[self.remote_cipher]['block-size']
if self.server_mode:
- IV_in = self._compute_key('A', self.block_size_in)
+ IV_in = self._compute_key('A', block_size)
key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size'])
else:
- IV_in = self._compute_key('B', self.block_size_in)
+ IV_in = self._compute_key('B', block_size)
key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size'])
- self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in)
- self.remote_mac_len = self._mac_info[self.remote_mac]['size']
- self.remote_mac_engine = self._mac_info[self.remote_mac]['class']
+ engine = self._get_cipher(self.remote_cipher, key_in, IV_in)
+ mac_size = self._mac_info[self.remote_mac]['size']
+ mac_engine = self._mac_info[self.remote_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size)
if self.server_mode:
- self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size)
+ mac_key = self._compute_key('E', mac_engine.digest_size)
else:
- self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size)
+ mac_key = self._compute_key('F', mac_engine.digest_size)
+ self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key)
def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic"
m = Message()
m.add_byte(chr(MSG_NEWKEYS))
self._send_message(m)
- self.block_size_out = self._cipher_info[self.local_cipher]['block-size']
+ block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode:
- IV_out = self._compute_key('B', self.block_size_out)
+ IV_out = self._compute_key('B', block_size)
key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size'])
else:
- IV_out = self._compute_key('A', self.block_size_out)
+ IV_out = self._compute_key('A', block_size)
key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size'])
- self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out)
- self.local_mac_len = self._mac_info[self.local_mac]['size']
- self.local_mac_engine = self._mac_info[self.local_mac]['class']
+ engine = self._get_cipher(self.local_cipher, key_out, IV_out)
+ mac_size = self._mac_info[self.local_mac]['size']
+ mac_engine = self._mac_info[self.local_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size)
if self.server_mode:
- self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size)
+ mac_key = self._compute_key('F', mac_engine.digest_size)
else:
- self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size)
+ mac_key = self._compute_key('E', mac_engine.digest_size)
+ self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key)
# we always expect to receive NEWKEYS now
self.expected_packet = MSG_NEWKEYS
diff --git a/test.py b/test.py
index 2c4a28ad..a97354ee 100755
--- a/test.py
+++ b/test.py
@@ -70,6 +70,9 @@ options, args = parser.parse_args()
if len(args) > 0:
parser.error('unknown argument(s)')
+# setup logging
+paramiko.util.log_to_file('test.log')
+
if options.use_sftp:
if options.use_loopback_sftp:
SFTPTest.init_loopback()
@@ -78,9 +81,6 @@ if options.use_sftp:
if not options.use_big_file:
SFTPTest.set_big_file_test(False)
-# setup logging
-paramiko.util.log_to_file('test.log')
-
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MessageTest))
suite.addTest(unittest.makeSuite(BufferedFileTest))
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index 5d4d921c..5031f02f 100755
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -1,5 +1,3 @@
-#!/usr/bin/python
-
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
@@ -145,7 +143,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/test')
- def test_1a_close(self):
+ def test_2_close(self):
"""
verify that closing the sftp session doesn't do anything bad, and that
a new one can be opened.
@@ -159,7 +157,7 @@ class SFTPTest (unittest.TestCase):
pass
sftp = paramiko.SFTP.from_transport(tc)
- def test_2_write(self):
+ def test_3_write(self):
"""
verify that a file can be created and written, and the size is correct.
"""
@@ -171,7 +169,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/duck.txt')
- def test_3_append(self):
+ def test_4_append(self):
"""
verify that a file can be opened for append, and tell() still works.
"""
@@ -191,7 +189,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/append.txt')
- def test_4_rename(self):
+ def test_5_rename(self):
"""
verify that renaming a file works.
"""
@@ -219,7 +217,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
- def test_5_folder(self):
+ def test_6_folder(self):
"""
create a temporary folder, verify that we can create a file in it, then
remove the folder and verify that we can't create a file in it anymore.
@@ -236,7 +234,7 @@ class SFTPTest (unittest.TestCase):
except IOError:
pass
- def test_6_listdir(self):
+ def test_7_listdir(self):
"""
verify that a folder can be created, a bunch of files can be placed in it,
and those files show up in sftp.listdir.
@@ -262,7 +260,7 @@ class SFTPTest (unittest.TestCase):
sftp.remove(FOLDER + '/fish.txt')
sftp.remove(FOLDER + '/tertiary.py')
- def test_7_setstat(self):
+ def test_8_setstat(self):
"""
verify that the setstat functions (chown, chmod, utime) work.
"""
@@ -285,7 +283,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/special')
- def test_8_readline_seek(self):
+ def test_9_readline_seek(self):
"""
create a text file and write a bunch of text into it. then count the lines
in the file, and seek around to retreive particular lines. this should
@@ -315,7 +313,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/duck.txt')
- def test_9_write_seek(self):
+ def test_A_write_seek(self):
"""
create a text file, seek back and change part of it, and verify that the
changes worked.
@@ -335,7 +333,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/testing.txt')
- def test_A_symlink(self):
+ def test_B_symlink(self):
"""
create a symlink and then check that lstat doesn't follow it.
"""
@@ -378,7 +376,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
- def test_B_flush_seek(self):
+ def test_C_flush_seek(self):
"""
verify that buffered writes are automatically flushed on seek.
"""
@@ -400,7 +398,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
- def test_C_lots_of_files(self):
+ def test_D_lots_of_files(self):
"""
create a bunch of files over the same session.
"""
@@ -431,7 +429,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
- def test_D_big_file(self):
+ def test_E_big_file(self):
"""
write a 1MB file, with no linefeeds, using line buffering.
FIXME: this is slow! what causes the slowness?
@@ -453,7 +451,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
- def test_E_big_file_big_buffer(self):
+ def test_F_big_file_big_buffer(self):
"""
write a 1MB file, with no linefeeds, and a big buffer.
"""
@@ -470,7 +468,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
- def test_F_realpath(self):
+ def test_G_realpath(self):
"""
test that realpath is returning something non-empty and not an
error.
@@ -481,7 +479,7 @@ class SFTPTest (unittest.TestCase):
self.assert_(len(f) > 0)
self.assertEquals(os.path.join(pwd, FOLDER), f)
- def test_G_mkdir(self):
+ def test_H_mkdir(self):
"""
verify that mkdir/rmdir work.
"""
diff --git a/tests/test_transport.py b/tests/test_transport.py
index bd11487f..5afc2e12 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -22,7 +22,7 @@
Some unit tests for the ssh2 protocol in Transport.
"""
-import sys, unittest, threading
+import sys, time, threading, unittest
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
@@ -77,6 +77,10 @@ class NullServer (ServerInterface):
def check_channel_shell_request(self, channel):
return True
+
+ def check_global_request(self, kind, msg):
+ self._global_request = kind
+ return False
class TransportTest (unittest.TestCase):
@@ -160,14 +164,38 @@ class TransportTest (unittest.TestCase):
self.assert_(self.ts.is_active())
self.assertEquals('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher)
- self.assertEquals(12, self.tc.local_mac_len)
- self.assertEquals(12, self.tc.remote_mac_len)
+ self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
+ self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
self.assert_(self.tc.renegotiate_keys())
self.ts.send_ignore(1024)
- def test_4_bad_auth_type(self):
+ def test_4_keepalive(self):
+ """
+ verify that the keepalive will be sent.
+ """
+ self.tc.set_hexdump(True)
+
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+ event = threading.Event()
+ server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, server)
+ self.tc.connect(hostkey=public_host_key,
+ username='slowdive', password='pygmalion')
+ event.wait(1.0)
+ self.assert_(event.isSet())
+ self.assert_(self.ts.is_active())
+
+ self.assertEquals(None, getattr(server, '_global_request', None))
+ self.tc.set_keepalive(1)
+ time.sleep(2)
+ self.assertEquals('keepalive@lag.net', server._global_request)
+
+ def test_5_bad_auth_type(self):
"""
verify that we get the right exception when an unsupported auth
type is requested.
@@ -188,7 +216,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types)
- def test_5_bad_password(self):
+ def test_6_bad_password(self):
"""
verify that a bad password gets the right exception, and that a retry
with the right password works.
@@ -213,7 +241,7 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
- def test_6_multipart_auth(self):
+ def test_7_multipart_auth(self):
"""
verify that multipart auth works.
"""
@@ -235,7 +263,7 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
- def test_7_exec_command(self):
+ def test_8_exec_command(self):
"""
verify that exec_command() does something reasonable.
"""
@@ -285,7 +313,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline())
- def test_8_invoke_shell(self):
+ def test_9_invoke_shell(self):
"""
verify that invoke_shell() does something reasonable.
"""
@@ -312,7 +340,7 @@ class TransportTest (unittest.TestCase):
chan.close()
self.assertEquals('', f.readline())
- def test_9_exit_status(self):
+ def test_A_exit_status(self):
"""
verify that get_exit_status() works.
"""