summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/transport.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 05d30873..df871514 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -267,7 +267,7 @@ class Transport (threading.Thread):
self.initial_kex_done = False
self.in_kex = False
self.authenticated = False
- self.expected_packet = 0
+ self._expected_packet = tuple()
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
# tracking open channels
@@ -1275,9 +1275,9 @@ class Transport (threading.Thread):
if self.session_id == None:
self.session_id = h
- def _expect_packet(self, type):
+ def _expect_packet(self, *ptypes):
"used by a kex object to register the next packet type it expects to see"
- self.expected_packet = type
+ self._expected_packet = tuple(ptypes)
def _verify_key(self, host_key, sig):
key = self._key_info[self.host_key_type](Message(host_key))
@@ -1326,7 +1326,7 @@ class Transport (threading.Thread):
self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner()
self._send_kex_init()
- self.expected_packet = MSG_KEXINIT
+ self._expect_packet(MSG_KEXINIT)
while self.active:
if self.packetizer.need_rekey() and not self.in_kex:
@@ -1345,10 +1345,10 @@ class Transport (threading.Thread):
elif ptype == MSG_DEBUG:
self._parse_debug(m)
continue
- if self.expected_packet != 0:
- if ptype != self.expected_packet:
- raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype))
- self.expected_packet = 0
+ if len(self._expected_packet) > 0:
+ if ptype not in self._expected_packet:
+ raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype))
+ self._expected_packet = tuple()
if (ptype >= 30) and (ptype <= 39):
self.kex_engine.parse_next(ptype, m)
continue
@@ -1651,7 +1651,7 @@ class Transport (threading.Thread):
if not self.packetizer.need_rekey():
self.in_kex = False
# we always expect to receive NEWKEYS now
- self.expected_packet = MSG_NEWKEYS
+ self._expect_packet(MSG_NEWKEYS)
def _auth_trigger(self):
self.authenticated = True