summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/agent.py3
-rw-r--r--paramiko/client.py3
-rw-r--r--paramiko/packet.py16
-rw-r--r--paramiko/transport.py3
-rw-r--r--paramiko/util.py9
-rw-r--r--tests/test_util.py26
6 files changed, 52 insertions, 8 deletions
diff --git a/paramiko/agent.py b/paramiko/agent.py
index 7115f17b..5d04dce8 100644
--- a/paramiko/agent.py
+++ b/paramiko/agent.py
@@ -35,6 +35,7 @@ from paramiko.message import Message
from paramiko.pkey import PKey
from paramiko.channel import Channel
from paramiko.common import io_sleep
+from paramiko.util import retry_on_signal
SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \
SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15)
@@ -202,7 +203,7 @@ class AgentClientProxy(object):
if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
- conn.connect(os.environ['SSH_AUTH_SOCK'])
+ retry_on_signal(lambda: conn.connect(os.environ['SSH_AUTH_SOCK']))
except:
# probably a dangling env var: the ssh agent is gone
return
diff --git a/paramiko/client.py b/paramiko/client.py
index 557cbb73..3ccb52bf 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -34,6 +34,7 @@ from paramiko.resource import ResourceManager
from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import SSHException, BadHostKeyException
from paramiko.transport import Transport
+from paramiko.util import retry_on_signal
SSH_PORT = 22
@@ -293,7 +294,7 @@ class SSHClient (object):
sock.settimeout(timeout)
except:
pass
- sock.connect(addr)
+ retry_on_signal(lambda: sock.connect(addr))
t = self._transport = Transport(sock)
t.use_compression(compress=compress)
if self._log_channel is not None:
diff --git a/paramiko/packet.py b/paramiko/packet.py
index 2f6d692c..97820619 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -241,23 +241,23 @@ class Packetizer (object):
def write_all(self, out):
self.__keepalive_last = time.time()
while len(out) > 0:
- got_timeout = False
+ retry_write = False
try:
n = self.__socket.send(out)
except socket.timeout:
- got_timeout = True
+ retry_write = True
except socket.error, e:
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
- got_timeout = True
+ retry_write = True
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
# syscall interrupted; try again
- pass
+ retry_write = True
else:
n = -1
except Exception:
# could be: (32, 'Broken pipe')
n = -1
- if got_timeout:
+ if retry_write:
n = 0
if self.__closed:
n = -1
@@ -469,6 +469,12 @@ class Packetizer (object):
break
except socket.timeout:
pass
+ except EnvironmentError, e:
+ if ((type(e.args) is tuple) and (len(e.args) > 0) and
+ (e.args[0] == errno.EINTR)):
+ pass
+ else:
+ raise
if self.__closed:
raise EOFError()
now = time.time()
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 8174a4cf..dd389a88 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -45,6 +45,7 @@ from paramiko.rsakey import RSAKey
from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
+from paramiko.util import retry_on_signal
from Crypto import Random
from Crypto.Cipher import Blowfish, AES, DES3, ARC4
@@ -289,7 +290,7 @@ class Transport (threading.Thread):
addr = sockaddr
sock = socket.socket(af, socket.SOCK_STREAM)
try:
- sock.connect((hostname, port))
+ retry_on_signal(lambda: sock.connect((hostname, port)))
except socket.error, e:
reason = str(e)
else:
diff --git a/paramiko/util.py b/paramiko/util.py
index 0d6a5348..f4bfbecd 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -24,6 +24,7 @@ from __future__ import generators
import array
from binascii import hexlify, unhexlify
+import errno
import sys
import struct
import traceback
@@ -270,6 +271,14 @@ def get_logger(name):
l.addFilter(_pfilter)
return l
+def retry_on_signal(function):
+ """Retries function until it doesn't raise an EINTR error"""
+ while True:
+ try:
+ return function()
+ except EnvironmentError, e:
+ if e.errno != errno.EINTR:
+ raise
class Counter (object):
"""Stateful counter for CTR mode crypto"""
diff --git a/tests/test_util.py b/tests/test_util.py
index ed0607fa..59a3d99e 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -22,6 +22,7 @@ Some unit tests for utility functions.
from binascii import hexlify
import cStringIO
+import errno
import os
import unittest
from Crypto.Hash import SHA
@@ -177,3 +178,28 @@ Host *
ssh.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '22'}
)
+
+ def test_8_eintr_retry(self):
+ self.assertEquals('foo', ssh.util.retry_on_signal(lambda: 'foo'))
+
+ # Variables that are set by raises_intr
+ intr_errors_remaining = [3]
+ call_count = [0]
+ def raises_intr():
+ call_count[0] += 1
+ if intr_errors_remaining[0] > 0:
+ intr_errors_remaining[0] -= 1
+ raise IOError(errno.EINTR, 'file', 'interrupted system call')
+ self.assertTrue(ssh.util.retry_on_signal(raises_intr) is None)
+ self.assertEquals(0, intr_errors_remaining[0])
+ self.assertEquals(4, call_count[0])
+
+ def raises_ioerror_not_eintr():
+ raise IOError(errno.ENOENT, 'file', 'file not found')
+ self.assertRaises(IOError,
+ lambda: ssh.util.retry_on_signal(raises_ioerror_not_eintr))
+
+ def raises_other_exception():
+ raise AssertionError('foo')
+ self.assertRaises(AssertionError,
+ lambda: ssh.util.retry_on_signal(raises_other_exception))