diff options
-rw-r--r-- | paramiko/message.py | 75 | ||||
-rw-r--r-- | paramiko/sftp_attr.py | 1 | ||||
-rw-r--r-- | paramiko/transport.py | 26 | ||||
-rw-r--r-- | tests/test_kex.py | 6 | ||||
-rw-r--r-- | tests/test_pkey.py | 4 |
5 files changed, 52 insertions, 60 deletions
diff --git a/paramiko/message.py b/paramiko/message.py index 9a4a9dfa..22ea53d2 100644 --- a/paramiko/message.py +++ b/paramiko/message.py @@ -20,7 +20,7 @@ Implementation of an SSH2 "message". """ -import string, types, struct +import struct, cStringIO import util @@ -31,16 +31,18 @@ class Message (object): as I{long}s). This class builds or breaks down such a byte stream. """ - def __init__(self, content=''): + def __init__(self, content=None): """ Create a new SSH2 Message. - @param content: the byte stream to use as the Message content (usually - passed in only when decomposing a Message). + @param content: the byte stream to use as the Message content (passed + in only when decomposing a Message). @type content: string """ - self.packet = content - self.idx = 0 + if content != None: + self.packet = cStringIO.StringIO(content) + else: + self.packet = cStringIO.StringIO() def __str__(self): """ @@ -49,7 +51,7 @@ class Message (object): @return: the contents of this Message. @rtype: string """ - return self.packet + return self.packet.getvalue() def __repr__(self): """ @@ -57,14 +59,14 @@ class Message (object): @rtype: string """ - return 'paramiko.Message(' + repr(self.packet) + ')' + return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')' def rewind(self): """ Rewind the message to the beginning as if no items had been parsed out of it yet. """ - self.idx = 0 + self.packet.seek(0) def get_remainder(self): """ @@ -74,7 +76,10 @@ class Message (object): @return: a string of the bytes not parsed yet. @rtype: string """ - return self.packet[self.idx:] + position = self.packet.tell() + remainder = self.packet.read() + self.packet.seek(position) + return remainder def get_so_far(self): """ @@ -85,7 +90,9 @@ class Message (object): @return: a string of the bytes parsed so far. @rtype: string """ - return self.packet[:self.idx] + position = self.packet.tell() + self.rewind() + return self.packet.read(position) def get_bytes(self, n): """ @@ -96,10 +103,9 @@ class Message (object): of C{n} zero bytes, if there aren't C{n} bytes remaining. @rtype: string """ - if self.idx + n > len(self.packet): + b = self.packet.read(n) + if len(b) < n: return '\x00'*n - b = self.packet[self.idx:self.idx+n] - self.idx = self.idx + n return b def get_byte(self): @@ -130,13 +136,7 @@ class Message (object): @return: a 32-bit unsigned integer. @rtype: int """ - x = self.packet - i = self.idx - if i + 4 > len(x): - return 0 - n = struct.unpack('>I', x[i:i+4])[0] - self.idx = i+4 - return n + return struct.unpack('>I', self.get_bytes(4))[0] def get_int64(self): """ @@ -145,13 +145,7 @@ class Message (object): @return: a 64-bit unsigned integer. @rtype: long """ - x = self.packet - i = self.idx - if i + 8 > len(x): - return 0L - n = struct.unpack('>Q', x[i:i+8])[0] - self.idx += 8 - return n + return struct.unpack('>Q', self.get_bytes(8))[0] def get_mpint(self): """ @@ -171,12 +165,7 @@ class Message (object): @return: a string. @rtype: string """ - l = self.get_int() - if self.idx + l > len(self.packet): - return '' - str = self.packet[self.idx:self.idx+l] - self.idx = self.idx + l - return str + return self.get_bytes(self.get_int()) def get_list(self): """ @@ -186,16 +175,14 @@ class Message (object): @return: a list of strings. @rtype: list of strings """ - str = self.get_string() - l = string.split(str, ',') - return l + return self.get_string().split(',') def add_bytes(self, b): - self.packet = self.packet + b + self.packet.write(b) return self def add_byte(self, b): - self.packet = self.packet + b + self.packet.write(b) return self def add_boolean(self, b): @@ -206,7 +193,7 @@ class Message (object): return self def add_int(self, n): - self.packet = self.packet + struct.pack('>I', n) + self.packet.write(struct.pack('>I', n)) return self def add_int64(self, n): @@ -216,7 +203,7 @@ class Message (object): @param n: long int to add. @type n: long """ - self.packet = self.packet + struct.pack('>Q', n) + self.packet.write(struct.pack('>Q', n)) return self def add_mpint(self, z): @@ -226,13 +213,11 @@ class Message (object): def add_string(self, s): self.add_int(len(s)) - self.packet = self.packet + s + self.packet.write(s) return self def add_list(self, l): - out = string.join(l, ',') - self.add_int(len(out)) - self.packet = self.packet + out + self.add_string(','.join(l)) return self def _add(self, i): diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py index 5123c914..b2812c6d 100644 --- a/paramiko/sftp_attr.py +++ b/paramiko/sftp_attr.py @@ -112,7 +112,6 @@ class SFTPAttributes (object): count = msg.get_int() for i in range(count): self.attr[msg.get_string()] = msg.get_string() - return msg.get_remainder() def _pack(self, msg): self._flags = 0 diff --git a/paramiko/transport.py b/paramiko/transport.py index d72fb20f..b73a06a5 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -639,8 +639,7 @@ class BaseTransport (threading.Thread): m.add_string(kind) m.add_boolean(wait) if data is not None: - for item in data: - m.add(item) + m.add(*data) self._log(DEBUG, 'Sending global request "%s"' % kind) self._send_user_message(m) if not wait: @@ -1085,16 +1084,16 @@ class BaseTransport (threading.Thread): m = Message() m.add_byte(chr(MSG_KEXINIT)) m.add_bytes(randpool.get_bytes(16)) - m.add(','.join(self._preferred_kex)) - m.add(','.join(available_server_keys)) - m.add(','.join(self._preferred_ciphers)) - m.add(','.join(self._preferred_ciphers)) - m.add(','.join(self._preferred_macs)) - m.add(','.join(self._preferred_macs)) - m.add('none') - m.add('none') - m.add('') - m.add('') + m.add_list(self._preferred_kex) + m.add_list(available_server_keys) + m.add_list(self._preferred_ciphers) + m.add_list(self._preferred_ciphers) + m.add_list(self._preferred_macs) + m.add_list(self._preferred_macs) + m.add_string('none') + m.add_string('none') + m.add_string('') + m.add_string('') m.add_boolean(False) m.add_int(0) # save a copy for later (needed to compute a hash) @@ -1274,8 +1273,7 @@ class BaseTransport (threading.Thread): msg = Message() if ok: msg.add_byte(chr(MSG_REQUEST_SUCCESS)) - for item in extra: - msg.add(item) + msg.add(*extra) else: msg.add_byte(chr(MSG_REQUEST_FAILURE)) self._send_message(msg) diff --git a/tests/test_kex.py b/tests/test_kex.py index 536c3867..2680853e 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -97,6 +97,7 @@ class KexTest (unittest.TestCase): msg.add_string('fake-host-key') msg.add_mpint(69) msg.add_string('fake-sig') + msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2' self.assertEquals(self.K, transport._K) @@ -113,6 +114,7 @@ class KexTest (unittest.TestCase): msg = Message() msg.add_mpint(69) + msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' @@ -133,6 +135,7 @@ class KexTest (unittest.TestCase): msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) + msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' self.assertEquals(x, paramiko.util.hexify(str(transport._message))) @@ -142,6 +145,7 @@ class KexTest (unittest.TestCase): msg.add_string('fake-host-key') msg.add_mpint(69) msg.add_string('fake-sig') + msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' self.assertEquals(self.K, transport._K) @@ -160,6 +164,7 @@ class KexTest (unittest.TestCase): msg.add_int(1024) msg.add_int(2048) msg.add_int(4096) + msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' self.assertEquals(x, paramiko.util.hexify(str(transport._message))) @@ -167,6 +172,7 @@ class KexTest (unittest.TestCase): msg = Message() msg.add_mpint(12345) + msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L H = 'CE754197C21BF3452863B4F44D0B3951F12516EF' diff --git a/tests/test_pkey.py b/tests/test_pkey.py index 3cd051fe..e56edb15 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -104,6 +104,7 @@ class KeyTest (unittest.TestCase): key = RSAKey.from_private_key_file('tests/test_rsa.key') msg = key.sign_ssh_data(randpool, 'ice weasels') self.assert_(type(msg) is Message) + msg.rewind() self.assertEquals('ssh-rsa', msg.get_string()) sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) self.assertEquals(sig, msg.get_string()) @@ -116,6 +117,7 @@ class KeyTest (unittest.TestCase): key = DSSKey.from_private_key_file('tests/test_dss.key') msg = key.sign_ssh_data(randpool, 'ice weasels') self.assert_(type(msg) is Message) + msg.rewind() self.assertEquals('ssh-dss', msg.get_string()) # can't do the same test as we do for RSA, because DSS signatures # are usually different each time. but we can test verification @@ -128,9 +130,11 @@ class KeyTest (unittest.TestCase): def test_A_generate_rsa(self): key = RSAKey.generate(1024) msg = key.sign_ssh_data(randpool, 'jerri blank') + msg.rewind() self.assert_(key.verify_ssh_sig('jerri blank', msg)) def test_B_generate_dss(self): key = DSSKey.generate(1024) msg = key.sign_ssh_data(randpool, 'jerri blank') + msg.rewind() self.assert_(key.verify_ssh_sig('jerri blank', msg)) |