summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRobey Pointer <robey@lag.net>2004-04-07 15:52:07 +0000
committerRobey Pointer <robey@lag.net>2004-04-07 15:52:07 +0000
commit17acfb5d28be4c5fec3253ef0f55ebc8007c1863 (patch)
treec557d1a6c86632d775dcd185a1ff6cfba7723f3d
parent1af636000743dcd43f17ad79d5b59aa83e1bc384 (diff)
[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-45]
add set_keepalive() add set_keepalive() to set an automatic keepalive mechanism. (while waiting for a packet on a connection, we periodically check if it's time to send a keepalive packet.)
-rw-r--r--paramiko/transport.py40
1 files changed, 37 insertions, 3 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 29f59e8d..4e2ddbcb 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -22,7 +22,7 @@
L{BaseTransport} handles the core SSH2 protocol.
"""
-import sys, os, string, threading, socket, struct
+import sys, os, string, threading, socket, struct, time
from common import *
from ssh_exception import SSHException
@@ -154,7 +154,7 @@ class BaseTransport (threading.Thread):
# /negotiated crypto parameters
self.expected_packet = 0
self.active = False
- self.initial_kex_done = 0
+ self.initial_kex_done = False
self.write_lock = threading.Lock() # lock around outbound writes (packet computation)
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.channels = { } # (id -> Channel)
@@ -171,6 +171,9 @@ class BaseTransport (threading.Thread):
self.received_packets_overflow = 0
# user-defined event callbacks:
self.completion_event = None
+ # keepalives:
+ self.keepalive_interval = 0
+ self.keepalive_last = time.time()
# server mode:
self.server_mode = 0
self.server_key_dict = { }
@@ -432,6 +435,8 @@ class BaseTransport (threading.Thread):
@param bytes: the number of random bytes to send in the payload of the
ignored packet -- defaults to a random number from 10 to 41.
@type bytes: int
+
+ @since: fearow
"""
m = Message()
m.add_byte(chr(MSG_IGNORE))
@@ -464,6 +469,19 @@ class BaseTransport (threading.Thread):
break
return True
+ def set_keepalive(self, interval):
+ """
+ Turn on/off keepalive packets (default is off). If this is set, after
+ C{interval} seconds without sending any data over the connection, a
+ "keepalive" packet will be sent (and ignored by the remote host). This
+ can be useful to keep connections alive over a NAT, for example.
+
+ @param interval: seconds to wait before sending a keepalive packet (or
+ 0 to disable keepalives).
+ @type interval: int
+ """
+ self.keepalive_interval = interval
+
def global_request(self, kind, data=None, wait=True):
"""
Make a global request to the remote host. These are normally
@@ -481,6 +499,8 @@ class BaseTransport (threading.Thread):
request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied.
@rtype: L{Message}
+
+ @since: fearow
"""
if wait:
self.completion_event = threading.Event()
@@ -491,6 +511,7 @@ class BaseTransport (threading.Thread):
if data is not None:
for item in data:
m.add(item)
+ self._log(DEBUG, 'Sending global request "%s"' % kind)
self._send_message(m)
if not wait:
return True
@@ -691,6 +712,13 @@ class BaseTransport (threading.Thread):
finally:
self.lock.release()
+ def _check_keepalive(self):
+ if (not self.keepalive_interval) or (not self.initial_kex_done):
+ return
+ now = time.time()
+ if now > self.keepalive_last + self.keepalive_interval:
+ self.global_request('keepalive@lag.net', wait=False)
+
def _py22_read_all(self, n):
out = ''
while n > 0:
@@ -698,6 +726,7 @@ class BaseTransport (threading.Thread):
if self.sock not in r:
if not self.active:
raise EOFError()
+ self._check_keepalive()
else:
x = self.sock.recv(n)
if len(x) == 0:
@@ -720,9 +749,11 @@ class BaseTransport (threading.Thread):
except socket.timeout:
if not self.active:
raise EOFError()
+ self._check_keepalive()
return out
def _write_all(self, out):
+ self.keepalive_last = time.time()
while len(out) > 0:
n = self.sock.send(out)
if n <= 0:
@@ -1156,7 +1187,7 @@ class BaseTransport (threading.Thread):
self.e = self.f = self.K = self.x = None
if not self.initial_kex_done:
# this was the first key exchange
- self.initial_kex_done = 1
+ self.initial_kex_done = True
# send an event?
if self.completion_event != None:
self.completion_event.set()
@@ -1169,6 +1200,7 @@ class BaseTransport (threading.Thread):
def _parse_global_request(self, m):
kind = m.get_string()
+ self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean()
ok = self.check_global_request(kind, m)
extra = ()
@@ -1186,11 +1218,13 @@ class BaseTransport (threading.Thread):
self._send_message(msg)
def _parse_request_success(self, m):
+ self._log(DEBUG, 'Global request successful.')
self.global_response = m
if self.completion_event is not None:
self.completion_event.set()
def _parse_request_failure(self, m):
+ self._log(DEBUG, 'Global request denied.')
self.global_response = None
if self.completion_event is not None:
self.completion_event.set()