summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJeff Forcier <jeff@bitprophet.org>2014-03-05 17:03:37 -0800
committerJeff Forcier <jeff@bitprophet.org>2014-03-05 17:03:37 -0800
commitb2be63ec623b5944f9b84cac8b8f41aeb2b42fb7 (patch)
tree389e17b0c08cd34872a2e3afbc34860ab44fb3ed
parentbd61c7c0a9a4a2020d0acfb6a01e9ec85bb43b8e (diff)
parentae078f51d622931954e47e78029a889c4e721a05 (diff)
Merge remote-tracking branch 'scottkmaxwell/py3-support-without-py25' into python3
Conflicts: dev-requirements.txt paramiko/__init__.py paramiko/file.py paramiko/hostkeys.py paramiko/message.py paramiko/proxy.py paramiko/server.py paramiko/transport.py paramiko/util.py paramiko/win_pageant.py setup.py
-rw-r--r--.travis.yml2
-rw-r--r--README4
-rwxr-xr-xdemos/demo.py53
-rwxr-xr-xdemos/demo_keygen.py23
-rw-r--r--demos/demo_server.py43
-rwxr-xr-xdemos/demo_sftp.py37
-rwxr-xr-xdemos/demo_simple.py23
-rw-r--r--demos/forward.py16
-rw-r--r--demos/interactive.py5
-rwxr-xr-xdemos/rforward.py10
-rw-r--r--paramiko/__init__.py60
-rw-r--r--paramiko/_winapi.py14
-rw-r--r--paramiko/agent.py30
-rw-r--r--paramiko/auth_handler.py72
-rw-r--r--paramiko/ber.py30
-rw-r--r--paramiko/buffered_pipe.py25
-rw-r--r--paramiko/channel.py56
-rw-r--r--paramiko/client.py27
-rw-r--r--paramiko/common.py45
-rw-r--r--paramiko/config.py2
-rw-r--r--paramiko/dsskey.py27
-rw-r--r--paramiko/ecdsakey.py34
-rw-r--r--paramiko/file.py72
-rw-r--r--paramiko/hostkeys.py90
-rw-r--r--paramiko/kex_gex.py30
-rw-r--r--paramiko/kex_group1.py33
-rw-r--r--paramiko/message.py103
-rw-r--r--paramiko/packet.py42
-rw-r--r--paramiko/pipe.py5
-rw-r--r--paramiko/pkey.py44
-rw-r--r--paramiko/primes.py15
-rw-r--r--paramiko/proxy.py4
-rw-r--r--paramiko/py3compat.py160
-rw-r--r--paramiko/rsakey.py24
-rw-r--r--paramiko/server.py4
-rw-r--r--paramiko/sftp.py15
-rw-r--r--paramiko/sftp_attr.py14
-rw-r--r--paramiko/sftp_client.py73
-rw-r--r--paramiko/sftp_file.py14
-rw-r--r--paramiko/sftp_handle.py4
-rw-r--r--paramiko/sftp_server.py83
-rw-r--r--paramiko/transport.py183
-rw-r--r--paramiko/util.py81
-rw-r--r--paramiko/win_pageant.py3
-rw-r--r--setup.py2
-rwxr-xr-xtest.py35
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/loop.py8
-rw-r--r--tests/stub_sftp.py32
-rw-r--r--tests/test_auth.py54
-rw-r--r--tests/test_buffered_pipe.py31
-rw-r--r--tests/test_client.py126
-rwxr-xr-xtests/test_file.py17
-rw-r--r--tests/test_hostkeys.py56
-rw-r--r--tests/test_kex.py129
-rw-r--r--tests/test_message.py77
-rw-r--r--tests/test_packetizer.py35
-rw-r--r--tests/test_pkey.py191
-rwxr-xr-xtests/test_sftp.py440
-rw-r--r--tests/test_sftp_big.py313
-rw-r--r--tests/test_transport.py223
-rw-r--r--tests/test_util.py67
-rw-r--r--tests/util.py7
-rw-r--r--tox.ini2
64 files changed, 1978 insertions, 1601 deletions
diff --git a/.travis.yml b/.travis.yml
index 97165c47..7042570f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,6 +2,8 @@ language: python
python:
- "2.6"
- "2.7"
+ - "3.2"
+ - "3.3"
install:
# Self-install for setup.py-driven deps
- pip install -e .
diff --git a/README b/README
index 537956e8..77866b8f 100644
--- a/README
+++ b/README
@@ -15,7 +15,7 @@ What
----
"paramiko" is a combination of the esperanto words for "paranoid" and
-"friend". it's a module for python 2.5+ that implements the SSH2 protocol
+"friend". it's a module for python 2.6+ that implements the SSH2 protocol
for secure (encrypted and authenticated) connections to remote machines.
unlike SSL (aka TLS), SSH2 protocol does not require hierarchical
certificates signed by a powerful central authority. you may know SSH2 as
@@ -34,7 +34,7 @@ that should have come with this archive.
Requirements
------------
- - python 2.5 or better <http://www.python.org/>
+ - python 2.6 or better <http://www.python.org/>
- pycrypto 2.1 or better <https://www.dlitz.net/software/pycrypto/>
- ecdsa 0.9 or better <https://pypi.python.org/pypi/ecdsa>
diff --git a/demos/demo.py b/demos/demo.py
index aa4bdaa5..fff61784 100755
--- a/demos/demo.py
+++ b/demos/demo.py
@@ -28,9 +28,13 @@ import socket
import sys
import time
import traceback
+from paramiko.py3compat import input
import paramiko
-import interactive
+try:
+ import interactive
+except ImportError:
+ from . import interactive
def agent_auth(transport, username):
@@ -45,24 +49,24 @@ def agent_auth(transport, username):
return
for key in agent_keys:
- print 'Trying ssh-agent key %s' % hexlify(key.get_fingerprint()),
+ print('Trying ssh-agent key %s' % hexlify(key.get_fingerprint()))
try:
transport.auth_publickey(username, key)
- print '... success!'
+ print('... success!')
return
except paramiko.SSHException:
- print '... nope.'
+ print('... nope.')
def manual_auth(username, hostname):
default_auth = 'p'
- auth = raw_input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
+ auth = input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
if len(auth) == 0:
auth = default_auth
if auth == 'r':
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
- path = raw_input('RSA key [%s]: ' % default_path)
+ path = input('RSA key [%s]: ' % default_path)
if len(path) == 0:
path = default_path
try:
@@ -73,7 +77,7 @@ def manual_auth(username, hostname):
t.auth_publickey(username, key)
elif auth == 'd':
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa')
- path = raw_input('DSS key [%s]: ' % default_path)
+ path = input('DSS key [%s]: ' % default_path)
if len(path) == 0:
path = default_path
try:
@@ -96,9 +100,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0:
username, hostname = hostname.split('@')
else:
- hostname = raw_input('Hostname: ')
+ hostname = input('Hostname: ')
if len(hostname) == 0:
- print '*** Hostname required.'
+ print('*** Hostname required.')
sys.exit(1)
port = 22
if hostname.find(':') >= 0:
@@ -109,8 +113,8 @@ if hostname.find(':') >= 0:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((hostname, port))
-except Exception, e:
- print '*** Connect failed: ' + str(e)
+except Exception as e:
+ print('*** Connect failed: ' + str(e))
traceback.print_exc()
sys.exit(1)
@@ -119,7 +123,7 @@ try:
try:
t.start_client()
except paramiko.SSHException:
- print '*** SSH negotiation failed.'
+ print('*** SSH negotiation failed.')
sys.exit(1)
try:
@@ -128,25 +132,25 @@ try:
try:
keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
except IOError:
- print '*** Unable to open host keys file'
+ print('*** Unable to open host keys file')
keys = {}
# check server's host key -- this is important.
key = t.get_remote_server_key()
- if not keys.has_key(hostname):
- print '*** WARNING: Unknown host key!'
- elif not keys[hostname].has_key(key.get_name()):
- print '*** WARNING: Unknown host key!'
+ if hostname not in keys:
+ print('*** WARNING: Unknown host key!')
+ elif key.get_name() not in keys[hostname]:
+ print('*** WARNING: Unknown host key!')
elif keys[hostname][key.get_name()] != key:
- print '*** WARNING: Host key has changed!!!'
+ print('*** WARNING: Host key has changed!!!')
sys.exit(1)
else:
- print '*** Host key OK.'
+ print('*** Host key OK.')
# get username
if username == '':
default_username = getpass.getuser()
- username = raw_input('Username [%s]: ' % default_username)
+ username = input('Username [%s]: ' % default_username)
if len(username) == 0:
username = default_username
@@ -154,21 +158,20 @@ try:
if not t.is_authenticated():
manual_auth(username, hostname)
if not t.is_authenticated():
- print '*** Authentication failed. :('
+ print('*** Authentication failed. :(')
t.close()
sys.exit(1)
chan = t.open_session()
chan.get_pty()
chan.invoke_shell()
- print '*** Here we go!'
- print
+ print('*** Here we go!\n')
interactive.interactive_shell(chan)
chan.close()
t.close()
-except Exception, e:
- print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e)
+except Exception as e:
+ print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc()
try:
t.close()
diff --git a/demos/demo_keygen.py b/demos/demo_keygen.py
index bdd7388d..860ee4e9 100755
--- a/demos/demo_keygen.py
+++ b/demos/demo_keygen.py
@@ -17,9 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
-from __future__ import with_statement
-import string
import sys
from binascii import hexlify
@@ -28,6 +26,7 @@ from optparse import OptionParser
from paramiko import DSSKey
from paramiko import RSAKey
from paramiko.ssh_exception import SSHException
+from paramiko.py3compat import u
usage="""
%prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]"""
@@ -47,16 +46,16 @@ key_dispatch_table = {
def progress(arg=None):
if not arg:
- print '0%\x08\x08\x08',
+ sys.stdout.write('0%\x08\x08\x08 ')
sys.stdout.flush()
elif arg[0] == 'p':
- print '25%\x08\x08\x08\x08',
+ sys.stdout.write('25%\x08\x08\x08\x08 ')
sys.stdout.flush()
elif arg[0] == 'h':
- print '50%\x08\x08\x08\x08',
+ sys.stdout.write('50%\x08\x08\x08\x08 ')
sys.stdout.flush()
elif arg[0] == 'x':
- print '75%\x08\x08\x08\x08',
+ sys.stdout.write('75%\x08\x08\x08\x08 ')
sys.stdout.flush()
if __name__ == '__main__':
@@ -92,8 +91,8 @@ if __name__ == '__main__':
parser.print_help()
sys.exit(0)
- for o in default_values.keys():
- globals()[o] = getattr(options, o, default_values[string.lower(o)])
+ for o in list(default_values.keys()):
+ globals()[o] = getattr(options, o, default_values[o.lower()])
if options.newphrase:
phrase = getattr(options, 'newphrase')
@@ -106,7 +105,7 @@ if __name__ == '__main__':
if ktype == 'dsa' and bits > 1024:
raise SSHException("DSA Keys must be 1024 bits")
- if not key_dispatch_table.has_key(ktype):
+ if ktype not in key_dispatch_table:
raise SSHException("Unknown %s algorithm to generate keys pair" % ktype)
# generating private key
@@ -121,7 +120,7 @@ if __name__ == '__main__':
f.write(" %s" % comment)
if options.verbose:
- print "done."
+ print("done.")
- hash = hexlify(pub.get_fingerprint())
- print "Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, string.upper(ktype))
+ hash = u(hexlify(pub.get_fingerprint()))
+ print("Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, ktype.upper()))
diff --git a/demos/demo_server.py b/demos/demo_server.py
index 915b0c67..bb35258b 100644
--- a/demos/demo_server.py
+++ b/demos/demo_server.py
@@ -27,6 +27,7 @@ import threading
import traceback
import paramiko
+from paramiko.py3compat import b, u, decodebytes
# setup logging
@@ -35,17 +36,17 @@ paramiko.util.log_to_file('demo_server.log')
host_key = paramiko.RSAKey(filename='test_rsa.key')
#host_key = paramiko.DSSKey(filename='test_dss.key')
-print 'Read key: ' + hexlify(host_key.get_fingerprint())
+print('Read key: ' + u(hexlify(host_key.get_fingerprint())))
class Server (paramiko.ServerInterface):
# 'data' is the output of base64.encodestring(str(key))
# (using the "user_rsa_key" files)
- data = 'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp' + \
- 'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC' + \
- 'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT' + \
- 'UWT10hcuO4Ks8='
- good_pub_key = paramiko.RSAKey(data=base64.decodestring(data))
+ data = (b'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp'
+ b'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC'
+ b'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT'
+ b'UWT10hcuO4Ks8=')
+ good_pub_key = paramiko.RSAKey(data=decodebytes(data))
def __init__(self):
self.event = threading.Event()
@@ -61,7 +62,7 @@ class Server (paramiko.ServerInterface):
return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key):
- print 'Auth attempt with key: ' + hexlify(key.get_fingerprint())
+ print('Auth attempt with key: ' + u(hexlify(key.get_fingerprint())))
if (username == 'robey') and (key == self.good_pub_key):
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -83,47 +84,47 @@ try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', 2200))
-except Exception, e:
- print '*** Bind failed: ' + str(e)
+except Exception as e:
+ print('*** Bind failed: ' + str(e))
traceback.print_exc()
sys.exit(1)
try:
sock.listen(100)
- print 'Listening for connection ...'
+ print('Listening for connection ...')
client, addr = sock.accept()
-except Exception, e:
- print '*** Listen/accept failed: ' + str(e)
+except Exception as e:
+ print('*** Listen/accept failed: ' + str(e))
traceback.print_exc()
sys.exit(1)
-print 'Got a connection!'
+print('Got a connection!')
try:
t = paramiko.Transport(client)
try:
t.load_server_moduli()
except:
- print '(Failed to load moduli -- gex will be unsupported.)'
+ print('(Failed to load moduli -- gex will be unsupported.)')
raise
t.add_server_key(host_key)
server = Server()
try:
t.start_server(server=server)
- except paramiko.SSHException, x:
- print '*** SSH negotiation failed.'
+ except paramiko.SSHException:
+ print('*** SSH negotiation failed.')
sys.exit(1)
# wait for auth
chan = t.accept(20)
if chan is None:
- print '*** No channel.'
+ print('*** No channel.')
sys.exit(1)
- print 'Authenticated!'
+ print('Authenticated!')
server.event.wait(10)
if not server.event.isSet():
- print '*** Client never asked for a shell.'
+ print('*** Client never asked for a shell.')
sys.exit(1)
chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
@@ -135,8 +136,8 @@ try:
chan.send('\r\nI don\'t like you, ' + username + '.\r\n')
chan.close()
-except Exception, e:
- print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e)
+except Exception as e:
+ print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc()
try:
t.close()
diff --git a/demos/demo_sftp.py b/demos/demo_sftp.py
index 7c4aaba0..a34f2b19 100755
--- a/demos/demo_sftp.py
+++ b/demos/demo_sftp.py
@@ -28,6 +28,7 @@ import sys
import traceback
import paramiko
+from paramiko.py3compat import input
# setup logging
@@ -40,9 +41,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0:
username, hostname = hostname.split('@')
else:
- hostname = raw_input('Hostname: ')
+ hostname = input('Hostname: ')
if len(hostname) == 0:
- print '*** Hostname required.'
+ print('*** Hostname required.')
sys.exit(1)
port = 22
if hostname.find(':') >= 0:
@@ -53,7 +54,7 @@ if hostname.find(':') >= 0:
# get username
if username == '':
default_username = getpass.getuser()
- username = raw_input('Username [%s]: ' % default_username)
+ username = input('Username [%s]: ' % default_username)
if len(username) == 0:
username = default_username
password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
@@ -69,13 +70,13 @@ except IOError:
# try ~/ssh/ too, because windows can't have a folder named ~/.ssh/
host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
except IOError:
- print '*** Unable to open host keys file'
+ print('*** Unable to open host keys file')
host_keys = {}
-if host_keys.has_key(hostname):
+if hostname in host_keys:
hostkeytype = host_keys[hostname].keys()[0]
hostkey = host_keys[hostname][hostkeytype]
- print 'Using host key of type %s' % hostkeytype
+ print('Using host key of type %s' % hostkeytype)
# now, connect and use paramiko Transport to negotiate SSH2 across the connection
@@ -86,22 +87,26 @@ try:
# dirlist on remote host
dirlist = sftp.listdir('.')
- print "Dirlist:", dirlist
+ print("Dirlist: %s" % dirlist)
# copy this demo onto the server
try:
sftp.mkdir("demo_sftp_folder")
except IOError:
- print '(assuming demo_sftp_folder/ already exists)'
- sftp.open('demo_sftp_folder/README', 'w').write('This was created by demo_sftp.py.\n')
- data = open('demo_sftp.py', 'r').read()
+ print('(assuming demo_sftp_folder/ already exists)')
+ with sftp.open('demo_sftp_folder/README', 'w') as f:
+ f.write('This was created by demo_sftp.py.\n')
+ with open('demo_sftp.py', 'r') as f:
+ data = f.read()
sftp.open('demo_sftp_folder/demo_sftp.py', 'w').write(data)
- print 'created demo_sftp_folder/ on the server'
+ print('created demo_sftp_folder/ on the server')
# copy the README back here
- data = sftp.open('demo_sftp_folder/README', 'r').read()
- open('README_demo_sftp', 'w').write(data)
- print 'copied README back here'
+ with sftp.open('demo_sftp_folder/README', 'r') as f:
+ data = f.read()
+ with open('README_demo_sftp', 'w') as f:
+ f.write(data)
+ print('copied README back here')
# BETTER: use the get() and put() methods
sftp.put('demo_sftp.py', 'demo_sftp_folder/demo_sftp.py')
@@ -109,8 +114,8 @@ try:
t.close()
-except Exception, e:
- print '*** Caught exception: %s: %s' % (e.__class__, e)
+except Exception as e:
+ print('*** Caught exception: %s: %s' % (e.__class__, e))
traceback.print_exc()
try:
t.close()
diff --git a/demos/demo_simple.py b/demos/demo_simple.py
index 50f344a7..ae631e43 100755
--- a/demos/demo_simple.py
+++ b/demos/demo_simple.py
@@ -25,9 +25,13 @@ import os
import socket
import sys
import traceback
+from paramiko.py3compat import input
import paramiko
-import interactive
+try:
+ import interactive
+except ImportError:
+ from . import interactive
# setup logging
@@ -40,9 +44,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0:
username, hostname = hostname.split('@')
else:
- hostname = raw_input('Hostname: ')
+ hostname = input('Hostname: ')
if len(hostname) == 0:
- print '*** Hostname required.'
+ print('*** Hostname required.')
sys.exit(1)
port = 22
if hostname.find(':') >= 0:
@@ -53,7 +57,7 @@ if hostname.find(':') >= 0:
# get username
if username == '':
default_username = getpass.getuser()
- username = raw_input('Username [%s]: ' % default_username)
+ username = input('Username [%s]: ' % default_username)
if len(username) == 0:
username = default_username
password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
@@ -64,18 +68,17 @@ try:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
- print '*** Connecting...'
+ print('*** Connecting...')
client.connect(hostname, port, username, password)
chan = client.invoke_shell()
- print repr(client.get_transport())
- print '*** Here we go!'
- print
+ print(repr(client.get_transport()))
+ print('*** Here we go!\n')
interactive.interactive_shell(chan)
chan.close()
client.close()
-except Exception, e:
- print '*** Caught exception: %s: %s' % (e.__class__, e)
+except Exception as e:
+ print('*** Caught exception: %s: %s' % (e.__class__, e))
traceback.print_exc()
try:
client.close()
diff --git a/demos/forward.py b/demos/forward.py
index 5048c775..96e1700d 100644
--- a/demos/forward.py
+++ b/demos/forward.py
@@ -30,7 +30,11 @@ import getpass
import os
import socket
import select
-import SocketServer
+try:
+ import SocketServer
+except ImportError:
+ import socketserver as SocketServer
+
import sys
from optparse import OptionParser
@@ -54,7 +58,7 @@ class Handler (SocketServer.BaseRequestHandler):
chan = self.ssh_transport.open_channel('direct-tcpip',
(self.chain_host, self.chain_port),
self.request.getpeername())
- except Exception, e:
+ except Exception as e:
verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
self.chain_port,
repr(e)))
@@ -98,7 +102,7 @@ def forward_tunnel(local_port, remote_host, remote_port, transport):
def verbose(s):
if g_verbose:
- print s
+ print(s)
HELP = """\
@@ -165,8 +169,8 @@ def main():
try:
client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
look_for_keys=options.look_for_keys, password=password)
- except Exception, e:
- print '*** Failed to connect to %s:%d: %r' % (server[0], server[1], e)
+ except Exception as e:
+ print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
sys.exit(1)
verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
@@ -174,7 +178,7 @@ def main():
try:
forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
except KeyboardInterrupt:
- print 'C-c: Port forwarding stopped.'
+ print('C-c: Port forwarding stopped.')
sys.exit(0)
diff --git a/demos/interactive.py b/demos/interactive.py
index f3be74d2..7138cd6c 100644
--- a/demos/interactive.py
+++ b/demos/interactive.py
@@ -19,6 +19,7 @@
import socket
import sys
+from paramiko.py3compat import u
# windows does not have termios...
try:
@@ -49,9 +50,9 @@ def posix_shell(chan):
r, w, e = select.select([chan, sys.stdin], [], [])
if chan in r:
try:
- x = chan.recv(1024)
+ x = u(chan.recv(1024))
if len(x) == 0:
- print '\r\n*** EOF\r\n',
+ sys.stdout.write('\r\n*** EOF\r\n')
break
sys.stdout.write(x)
sys.stdout.flush()
diff --git a/demos/rforward.py b/demos/rforward.py
index 4a5d2e43..ae70670c 100755
--- a/demos/rforward.py
+++ b/demos/rforward.py
@@ -46,7 +46,7 @@ def handler(chan, host, port):
sock = socket.socket()
try:
sock.connect((host, port))
- except Exception, e:
+ except Exception as e:
verbose('Forwarding request to %s:%d failed: %r' % (host, port, e))
return
@@ -82,7 +82,7 @@ def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
def verbose(s):
if g_verbose:
- print s
+ print(s)
HELP = """\
@@ -150,8 +150,8 @@ def main():
try:
client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
look_for_keys=options.look_for_keys, password=password)
- except Exception, e:
- print '*** Failed to connect to %s:%d: %r' % (server[0], server[1], e)
+ except Exception as e:
+ print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
sys.exit(1)
verbose('Now forwarding remote port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
@@ -159,7 +159,7 @@ def main():
try:
reverse_forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
except KeyboardInterrupt:
- print 'C-c: Port forwarding stopped.'
+ print('C-c: Port forwarding stopped.')
sys.exit(0)
diff --git a/paramiko/__init__.py b/paramiko/__init__.py
index 0e8f9de7..64b2e0b1 100644
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -18,51 +18,51 @@
import sys
-if sys.version_info < (2, 5):
- raise RuntimeError('You need Python 2.5+ for this module.')
+if sys.version_info < (2, 6):
+ raise RuntimeError('You need Python 2.6+ for this module.')
__author__ = "Jeff Forcier <jeff@bitprophet.org>"
-__version__ = "1.12.2"
+__version__ = "1.13.0"
__version_info__ = tuple([ int(d) for d in __version__.split(".") ])
__license__ = "GNU Lesser General Public License (LGPL)"
-from transport import SecurityOptions, Transport
-from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy
-from auth_handler import AuthHandler
-from channel import Channel, ChannelFile
-from ssh_exception import SSHException, PasswordRequiredException, \
+from paramiko.transport import SecurityOptions, Transport
+from paramiko.client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy
+from paramiko.auth_handler import AuthHandler
+from paramiko.channel import Channel, ChannelFile
+from paramiko.ssh_exception import SSHException, PasswordRequiredException, \
BadAuthenticationType, ChannelException, BadHostKeyException, \
AuthenticationException, ProxyCommandFailure
-from server import ServerInterface, SubsystemHandler, InteractiveQuery
-from rsakey import RSAKey
-from dsskey import DSSKey
-from ecdsakey import ECDSAKey
-from sftp import SFTPError, BaseSFTP
-from sftp_client import SFTP, SFTPClient
-from sftp_server import SFTPServer
-from sftp_attr import SFTPAttributes
-from sftp_handle import SFTPHandle
-from sftp_si import SFTPServerInterface
-from sftp_file import SFTPFile
-from message import Message
-from packet import Packetizer
-from file import BufferedFile
-from agent import Agent, AgentKey
-from pkey import PKey
-from hostkeys import HostKeys
-from config import SSHConfig
-from proxy import ProxyCommand
+from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery
+from paramiko.rsakey import RSAKey
+from paramiko.dsskey import DSSKey
+from paramiko.ecdsakey import ECDSAKey
+from paramiko.sftp import SFTPError, BaseSFTP
+from paramiko.sftp_client import SFTP, SFTPClient
+from paramiko.sftp_server import SFTPServer
+from paramiko.sftp_attr import SFTPAttributes
+from paramiko.sftp_handle import SFTPHandle
+from paramiko.sftp_si import SFTPServerInterface
+from paramiko.sftp_file import SFTPFile
+from paramiko.message import Message
+from paramiko.packet import Packetizer
+from paramiko.file import BufferedFile
+from paramiko.agent import Agent, AgentKey
+from paramiko.pkey import PKey
+from paramiko.hostkeys import HostKeys
+from paramiko.config import SSHConfig
+from paramiko.proxy import ProxyCommand
-from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
+from paramiko.common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \
OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE
-from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
+from paramiko.sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED
-from common import io_sleep
+from paramiko.common import io_sleep
__all__ = [ 'Transport',
'SSHClient',
diff --git a/paramiko/_winapi.py b/paramiko/_winapi.py
index b8759245..1be117c3 100644
--- a/paramiko/_winapi.py
+++ b/paramiko/_winapi.py
@@ -8,7 +8,11 @@ in jaraco.windows and asking the author to port the fixes back here.
import ctypes
import ctypes.wintypes
-import __builtin__
+from paramiko.py3compat import u
+try:
+ import builtins
+except ImportError:
+ import __builtin__ as builtins
try:
USHORT = ctypes.wintypes.USHORT
@@ -40,7 +44,7 @@ def format_system_message(errno):
result_buffer = ctypes.wintypes.LPWSTR()
buffer_size = 0
arguments = None
- bytes = ctypes.windll.kernel32.FormatMessageW(
+ format_bytes = ctypes.windll.kernel32.FormatMessageW(
flags,
source,
message_id,
@@ -52,13 +56,13 @@ def format_system_message(errno):
# note the following will cause an infinite loop if GetLastError
# repeatedly returns an error that cannot be formatted, although
# this should not happen.
- handle_nonzero_success(bytes)
+ handle_nonzero_success(format_bytes)
message = result_buffer.value
ctypes.windll.kernel32.LocalFree(result_buffer)
return message
-class WindowsError(__builtin__.WindowsError):
+class WindowsError(builtins.WindowsError):
"more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx"
def __init__(self, value=None):
@@ -120,7 +124,7 @@ class MemoryMap(object):
FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length,
- unicode(self.name))
+ u(self.name))
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping")
diff --git a/paramiko/agent.py b/paramiko/agent.py
index d9f4b1bc..3aa58bea 100644
--- a/paramiko/agent.py
+++ b/paramiko/agent.py
@@ -34,11 +34,14 @@ from paramiko.ssh_exception import SSHException
from paramiko.message import Message
from paramiko.pkey import PKey
from paramiko.channel import Channel
-from paramiko.common import io_sleep
+from paramiko.common import *
from paramiko.util import retry_on_signal
-SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \
- SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15)
+cSSH2_AGENTC_REQUEST_IDENTITIES = byte_chr(11)
+SSH2_AGENT_IDENTITIES_ANSWER = 12
+cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
+SSH2_AGENT_SIGN_RESPONSE = 14
+
class AgentSSH(object):
@@ -60,12 +63,12 @@ class AgentSSH(object):
def _connect(self, conn):
self._conn = conn
- ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES))
+ ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
raise SSHException('could not get keys from ssh-agent')
keys = []
for i in range(result.get_int()):
- keys.append(AgentKey(self, result.get_string()))
+ keys.append(AgentKey(self, result.get_binary()))
result.get_string()
self._keys = tuple(keys)
@@ -75,7 +78,7 @@ class AgentSSH(object):
self._keys = ()
def _send_message(self, msg):
- msg = str(msg)
+ msg = asbytes(msg)
self._conn.send(struct.pack('>I', len(msg)) + msg)
l = self._read_all(4)
msg = Message(self._read_all(struct.unpack('>I', l)[0]))
@@ -212,7 +215,7 @@ class AgentClientProxy(object):
# probably a dangling env var: the ssh agent is gone
return
elif sys.platform == 'win32':
- import win_pageant
+ import paramiko.win_pageant as win_pageant
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@@ -328,7 +331,7 @@ class Agent(AgentSSH):
# probably a dangling env var: the ssh agent is gone
return
elif sys.platform == 'win32':
- import win_pageant
+ from . import win_pageant
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@@ -354,21 +357,24 @@ class AgentKey(PKey):
def __init__(self, agent, blob):
self.agent = agent
self.blob = blob
- self.name = Message(blob).get_string()
+ self.name = Message(blob).get_text()
- def __str__(self):
+ def asbytes(self):
return self.blob
+ def __str__(self):
+ return self.asbytes()
+
def get_name(self):
return self.name
def sign_ssh_data(self, rng, data):
msg = Message()
- msg.add_byte(chr(SSH2_AGENTC_SIGN_REQUEST))
+ msg.add_byte(cSSH2_AGENTC_SIGN_REQUEST)
msg.add_string(self.blob)
msg.add_string(data)
msg.add_int(0)
ptype, result = self.agent._send_message(msg)
if ptype != SSH2_AGENT_SIGN_RESPONSE:
raise SSHException('key cannot be used for signing')
- return result.get_string()
+ return result.get_binary()
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py
index a6f52550..2cc09353 100644
--- a/paramiko/auth_handler.py
+++ b/paramiko/auth_handler.py
@@ -120,13 +120,13 @@ class AuthHandler (object):
def _request_auth(self):
m = Message()
- m.add_byte(chr(MSG_SERVICE_REQUEST))
+ m.add_byte(cMSG_SERVICE_REQUEST)
m.add_string('ssh-userauth')
self.transport._send_message(m)
def _disconnect_service_not_available(self):
m = Message()
- m.add_byte(chr(MSG_DISCONNECT))
+ m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
m.add_string('Service not available')
m.add_string('en')
@@ -135,7 +135,7 @@ class AuthHandler (object):
def _disconnect_no_more_auth(self):
m = Message()
- m.add_byte(chr(MSG_DISCONNECT))
+ m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
m.add_string('No more auth methods available')
m.add_string('en')
@@ -145,14 +145,14 @@ class AuthHandler (object):
def _get_session_blob(self, key, service, username):
m = Message()
m.add_string(self.transport.session_id)
- m.add_byte(chr(MSG_USERAUTH_REQUEST))
+ m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(username)
m.add_string(service)
m.add_string('publickey')
m.add_boolean(1)
m.add_string(key.get_name())
- m.add_string(str(key))
- return str(m)
+ m.add_string(key)
+ return m.asbytes()
def wait_for_response(self, event):
while True:
@@ -176,11 +176,11 @@ class AuthHandler (object):
return []
def _parse_service_request(self, m):
- service = m.get_string()
+ service = m.get_text()
if self.transport.server_mode and (service == 'ssh-userauth'):
# accepted
m = Message()
- m.add_byte(chr(MSG_SERVICE_ACCEPT))
+ m.add_byte(cMSG_SERVICE_ACCEPT)
m.add_string(service)
self.transport._send_message(m)
return
@@ -188,27 +188,25 @@ class AuthHandler (object):
self._disconnect_service_not_available()
def _parse_service_accept(self, m):
- service = m.get_string()
+ service = m.get_text()
if service == 'ssh-userauth':
self.transport._log(DEBUG, 'userauth is OK')
m = Message()
- m.add_byte(chr(MSG_USERAUTH_REQUEST))
+ m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(self.username)
m.add_string('ssh-connection')
m.add_string(self.auth_method)
if self.auth_method == 'password':
m.add_boolean(False)
- password = self.password
- if isinstance(password, unicode):
- password = password.encode('UTF-8')
+ password = bytestring(self.password)
m.add_string(password)
elif self.auth_method == 'publickey':
m.add_boolean(True)
m.add_string(self.private_key.get_name())
- m.add_string(str(self.private_key))
+ m.add_string(self.private_key)
blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username)
sig = self.private_key.sign_ssh_data(self.transport.rng, blob)
- m.add_string(str(sig))
+ m.add_string(sig)
elif self.auth_method == 'keyboard-interactive':
m.add_string('')
m.add_string(self.submethods)
@@ -225,11 +223,11 @@ class AuthHandler (object):
m = Message()
if result == AUTH_SUCCESSFUL:
self.transport._log(INFO, 'Auth granted (%s).' % method)
- m.add_byte(chr(MSG_USERAUTH_SUCCESS))
+ m.add_byte(cMSG_USERAUTH_SUCCESS)
self.authenticated = True
else:
self.transport._log(INFO, 'Auth rejected (%s).' % method)
- m.add_byte(chr(MSG_USERAUTH_FAILURE))
+ m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string(self.transport.server_object.get_allowed_auths(username))
if result == AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(1)
@@ -245,10 +243,10 @@ class AuthHandler (object):
def _interactive_query(self, q):
# make interactive query instead of response
m = Message()
- m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST))
+ m.add_byte(cMSG_USERAUTH_INFO_REQUEST)
m.add_string(q.name)
m.add_string(q.instructions)
- m.add_string('')
+ m.add_string(bytes())
m.add_int(len(q.prompts))
for p in q.prompts:
m.add_string(p[0])
@@ -259,7 +257,7 @@ class AuthHandler (object):
if not self.transport.server_mode:
# er, uh... what?
m = Message()
- m.add_byte(chr(MSG_USERAUTH_FAILURE))
+ m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string('none')
m.add_boolean(0)
self.transport._send_message(m)
@@ -267,9 +265,9 @@ class AuthHandler (object):
if self.authenticated:
# ignore
return
- username = m.get_string()
- service = m.get_string()
- method = m.get_string()
+ username = m.get_text()
+ service = m.get_text()
+ method = m.get_text()
self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection':
self._disconnect_service_not_available()
@@ -284,7 +282,7 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_none(username)
elif method == 'password':
changereq = m.get_boolean()
- password = m.get_string()
+ password = m.get_binary()
try:
password = password.decode('UTF-8')
except UnicodeError:
@@ -295,7 +293,7 @@ class AuthHandler (object):
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
self.transport._log(DEBUG, 'Auth request to change passwords (rejected)')
- newpassword = m.get_string()
+ newpassword = m.get_binary()
try:
newpassword = newpassword.decode('UTF-8', 'replace')
except UnicodeError:
@@ -305,11 +303,11 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_password(username, password)
elif method == 'publickey':
sig_attached = m.get_boolean()
- keytype = m.get_string()
- keyblob = m.get_string()
+ keytype = m.get_text()
+ keyblob = m.get_binary()
try:
key = self.transport._key_info[keytype](Message(keyblob))
- except SSHException, e:
+ except SSHException as e:
self.transport._log(INFO, 'Auth rejected: public key: %s' % str(e))
key = None
except:
@@ -326,12 +324,12 @@ class AuthHandler (object):
# client wants to know if this key is acceptable, before it
# signs anything... send special "ok" message
m = Message()
- m.add_byte(chr(MSG_USERAUTH_PK_OK))
+ m.add_byte(cMSG_USERAUTH_PK_OK)
m.add_string(keytype)
m.add_string(keyblob)
self.transport._send_message(m)
return
- sig = Message(m.get_string())
+ sig = Message(m.get_binary())
blob = self._get_session_blob(key, service, username)
if not key.verify_ssh_sig(blob, sig):
self.transport._log(INFO, 'Auth rejected: invalid signature')
@@ -378,23 +376,23 @@ class AuthHandler (object):
banner = m.get_string()
self.banner = banner
lang = m.get_string()
- self.transport._log(INFO, 'Auth banner: ' + banner)
+ self.transport._log(INFO, 'Auth banner: %s' % banner)
# who cares.
def _parse_userauth_info_request(self, m):
if self.auth_method != 'keyboard-interactive':
raise SSHException('Illegal info request from server')
- title = m.get_string()
- instructions = m.get_string()
- m.get_string() # lang
+ title = m.get_text()
+ instructions = m.get_text()
+ m.get_binary() # lang
prompts = m.get_int()
prompt_list = []
for i in range(prompts):
- prompt_list.append((m.get_string(), m.get_boolean()))
+ prompt_list.append((m.get_text(), m.get_boolean()))
response_list = self.interactive_handler(title, instructions, prompt_list)
m = Message()
- m.add_byte(chr(MSG_USERAUTH_INFO_RESPONSE))
+ m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
m.add_int(len(response_list))
for r in response_list:
m.add_string(r)
@@ -406,7 +404,7 @@ class AuthHandler (object):
n = m.get_int()
responses = []
for i in range(n):
- responses.append(m.get_string())
+ responses.append(m.get_text())
result = self.transport.server_object.check_auth_interactive_response(responses)
if isinstance(type(result), InteractiveQuery):
# make interactive query instead of response
diff --git a/paramiko/ber.py b/paramiko/ber.py
index 3941581c..c4f35210 100644
--- a/paramiko/ber.py
+++ b/paramiko/ber.py
@@ -17,7 +17,8 @@
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
-import util
+import paramiko.util as util
+from paramiko.common import *
class BERException (Exception):
@@ -29,13 +30,16 @@ class BER(object):
Robey's tiny little attempt at a BER decoder.
"""
- def __init__(self, content=''):
- self.content = content
+ def __init__(self, content=bytes()):
+ self.content = b(content)
self.idx = 0
- def __str__(self):
+ def asbytes(self):
return self.content
+ def __str__(self):
+ return self.asbytes()
+
def __repr__(self):
return 'BER(\'' + repr(self.content) + '\')'
@@ -45,13 +49,13 @@ class BER(object):
def decode_next(self):
if self.idx >= len(self.content):
return None
- ident = ord(self.content[self.idx])
+ ident = byte_ord(self.content[self.idx])
self.idx += 1
if (ident & 31) == 31:
# identifier > 30
ident = 0
while self.idx < len(self.content):
- t = ord(self.content[self.idx])
+ t = byte_ord(self.content[self.idx])
self.idx += 1
ident = (ident << 7) | (t & 0x7f)
if not (t & 0x80):
@@ -59,7 +63,7 @@ class BER(object):
if self.idx >= len(self.content):
return None
# now fetch length
- size = ord(self.content[self.idx])
+ size = byte_ord(self.content[self.idx])
self.idx += 1
if size & 0x80:
# more complimicated...
@@ -98,20 +102,20 @@ class BER(object):
def encode_tlv(self, ident, val):
# no need to support ident > 31 here
- self.content += chr(ident)
+ self.content += byte_chr(ident)
if len(val) > 0x7f:
lenstr = util.deflate_long(len(val))
- self.content += chr(0x80 + len(lenstr)) + lenstr
+ self.content += byte_chr(0x80 + len(lenstr)) + lenstr
else:
- self.content += chr(len(val))
+ self.content += byte_chr(len(val))
self.content += val
def encode(self, x):
if type(x) is bool:
if x:
- self.encode_tlv(1, '\xff')
+ self.encode_tlv(1, max_byte)
else:
- self.encode_tlv(1, '\x00')
+ self.encode_tlv(1, zero_byte)
elif (type(x) is int) or (type(x) is long):
self.encode_tlv(2, util.deflate_long(x))
elif type(x) is str:
@@ -125,5 +129,5 @@ class BER(object):
b = BER()
for item in data:
b.encode(item)
- return str(b)
+ return b.asbytes()
encode_sequence = staticmethod(encode_sequence)
diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py
index a4be5d8d..94514b67 100644
--- a/paramiko/buffered_pipe.py
+++ b/paramiko/buffered_pipe.py
@@ -25,6 +25,7 @@ read operations are blocking and can have a timeout set.
import array
import threading
import time
+from paramiko.common import *
class PipeTimeout (IOError):
@@ -48,6 +49,20 @@ class BufferedPipe (object):
self._buffer = array.array('B')
self._closed = False
+ if PY2:
+ def _buffer_frombytes(self, data):
+ self._buffer.fromstring(data)
+
+ def _buffer_tobytes(self, limit=None):
+ return self._buffer[:limit].tostring()
+ else:
+ def _buffer_frombytes(self, data):
+ self._buffer.frombytes(data)
+
+ def _buffer_tobytes(self, limit=None):
+ return self._buffer[:limit].tobytes()
+
+
def set_event(self, event):
"""
Set an event on this buffer. When data is ready to be read (or the
@@ -73,7 +88,7 @@ class BufferedPipe (object):
try:
if self._event is not None:
self._event.set()
- self._buffer.fromstring(data)
+ self._buffer_frombytes(b(data))
self._cv.notifyAll()
finally:
self._lock.release()
@@ -117,7 +132,7 @@ class BufferedPipe (object):
if a timeout was specified and no data was ready before that
timeout
"""
- out = ''
+ out = bytes()
self._lock.acquire()
try:
if len(self._buffer) == 0:
@@ -138,12 +153,12 @@ class BufferedPipe (object):
# something's in the buffer and we have the lock!
if len(self._buffer) <= nbytes:
- out = self._buffer.tostring()
+ out = self._buffer_tobytes()
del self._buffer[:]
if (self._event is not None) and not self._closed:
self._event.clear()
else:
- out = self._buffer[:nbytes].tostring()
+ out = self._buffer_tobytes(nbytes)
del self._buffer[:nbytes]
finally:
self._lock.release()
@@ -160,7 +175,7 @@ class BufferedPipe (object):
"""
self._lock.acquire()
try:
- out = self._buffer.tostring()
+ out = self._buffer_tobytes()
del self._buffer[:]
if (self._event is not None) and not self._closed:
self._event.clear()
diff --git a/paramiko/channel.py b/paramiko/channel.py
index 20f487a4..107786c4 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -140,7 +140,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('pty-req')
m.add_boolean(True)
@@ -149,7 +149,7 @@ class Channel (object):
m.add_int(height)
m.add_int(width_pixels)
m.add_int(height_pixels)
- m.add_string('')
+ m.add_string(bytes())
self._event_pending()
self.transport._send_user_message(m)
self._wait_for_event()
@@ -173,7 +173,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('shell')
m.add_boolean(1)
@@ -199,7 +199,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('exec')
m.add_boolean(True)
@@ -225,7 +225,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('subsystem')
m.add_boolean(True)
@@ -250,7 +250,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('window-change')
m.add_boolean(False)
@@ -304,7 +304,7 @@ class Channel (object):
# in many cases, the channel will not still be open here.
# that's fine.
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('exit-status')
m.add_boolean(False)
@@ -359,7 +359,7 @@ class Channel (object):
auth_cookie = binascii.hexlify(self.transport.rng.read(16))
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('x11-req')
m.add_boolean(True)
@@ -389,7 +389,7 @@ class Channel (object):
raise SSHException('Channel is not open')
m = Message()
- m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('auth-agent-req@openssh.com')
m.add_boolean(False)
@@ -451,7 +451,7 @@ class Channel (object):
.. versionadded:: 1.1
"""
- data = ''
+ data = bytes()
self.lock.acquire()
try:
old = self.combine_stderr
@@ -581,14 +581,14 @@ class Channel (object):
"""
try:
out = self.in_buffer.read(nbytes, self.timeout)
- except PipeTimeout, e:
+ except PipeTimeout:
raise socket.timeout()
ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this
if ack > 0:
m = Message()
- m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
+ m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid)
m.add_int(ack)
self.transport._send_user_message(m)
@@ -629,14 +629,14 @@ class Channel (object):
"""
try:
out = self.in_stderr_buffer.read(nbytes, self.timeout)
- except PipeTimeout, e:
+ except PipeTimeout:
raise socket.timeout()
ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this
if ack > 0:
m = Message()
- m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
+ m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid)
m.add_int(ack)
self.transport._send_user_message(m)
@@ -686,7 +686,7 @@ class Channel (object):
# eof or similar
return 0
m = Message()
- m.add_byte(chr(MSG_CHANNEL_DATA))
+ m.add_byte(cMSG_CHANNEL_DATA)
m.add_int(self.remote_chanid)
m.add_string(s[:size])
finally:
@@ -721,7 +721,7 @@ class Channel (object):
# eof or similar
return 0
m = Message()
- m.add_byte(chr(MSG_CHANNEL_EXTENDED_DATA))
+ m.add_byte(cMSG_CHANNEL_EXTENDED_DATA)
m.add_int(self.remote_chanid)
m.add_int(1)
m.add_string(s[:size])
@@ -925,16 +925,16 @@ class Channel (object):
self.transport._send_user_message(m)
def _feed(self, m):
- if type(m) is str:
+ if isinstance(m, bytes_types):
# passed from _feed_extended
s = m
else:
- s = m.get_string()
+ s = m.get_binary()
self.in_buffer.feed(s)
def _feed_extended(self, m):
code = m.get_int()
- s = m.get_string()
+ s = m.get_binary()
if code != 1:
self._log(ERROR, 'unknown extended_data type %d; discarding' % code)
return
@@ -955,7 +955,7 @@ class Channel (object):
self.lock.release()
def _handle_request(self, m):
- key = m.get_string()
+ key = m.get_text()
want_reply = m.get_boolean()
server = self.transport.server_object
ok = False
@@ -991,13 +991,13 @@ class Channel (object):
else:
ok = server.check_channel_env_request(self, name, value)
elif key == 'exec':
- cmd = m.get_string()
+ cmd = m.get_text()
if server is None:
ok = False
else:
ok = server.check_channel_exec_request(self, cmd)
elif key == 'subsystem':
- name = m.get_string()
+ name = m.get_text()
if server is None:
ok = False
else:
@@ -1014,8 +1014,8 @@ class Channel (object):
pixelheight)
elif key == 'x11-req':
single_connection = m.get_boolean()
- auth_proto = m.get_string()
- auth_cookie = m.get_string()
+ auth_proto = m.get_text()
+ auth_cookie = m.get_binary()
screen_number = m.get_int()
if server is None:
ok = False
@@ -1033,9 +1033,9 @@ class Channel (object):
if want_reply:
m = Message()
if ok:
- m.add_byte(chr(MSG_CHANNEL_SUCCESS))
+ m.add_byte(cMSG_CHANNEL_SUCCESS)
else:
- m.add_byte(chr(MSG_CHANNEL_FAILURE))
+ m.add_byte(cMSG_CHANNEL_FAILURE)
m.add_int(self.remote_chanid)
self.transport._send_user_message(m)
@@ -1101,7 +1101,7 @@ class Channel (object):
if self.eof_sent:
return None
m = Message()
- m.add_byte(chr(MSG_CHANNEL_EOF))
+ m.add_byte(cMSG_CHANNEL_EOF)
m.add_int(self.remote_chanid)
self.eof_sent = True
self._log(DEBUG, 'EOF sent (%s)', self._name)
@@ -1113,7 +1113,7 @@ class Channel (object):
return None, None
m1 = self._send_eof()
m2 = Message()
- m2.add_byte(chr(MSG_CHANNEL_CLOSE))
+ m2.add_byte(cMSG_CHANNEL_CLOSE)
m2.add_int(self.remote_chanid)
self._set_closed()
# can't unlink from the Transport yet -- the remote side may still
diff --git a/paramiko/client.py b/paramiko/client.py
index b5929e6e..2bb7c4bc 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -132,11 +132,10 @@ class SSHClient (object):
if self._host_keys_filename is not None:
self.load_host_keys(self._host_keys_filename)
- f = open(filename, 'w')
- for hostname, keys in self._host_keys.iteritems():
- for keytype, key in keys.iteritems():
+ with open(filename, 'w') as f:
+ for hostname, keys in self._host_keys.items():
+ for keytype, key in keys.items():
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
- f.close()
def get_host_keys(self):
"""
@@ -266,7 +265,7 @@ class SSHClient (object):
if key_filename is None:
key_filenames = []
- elif isinstance(key_filename, (str, unicode)):
+ elif isinstance(key_filename, string_types):
key_filenames = [ key_filename ]
else:
key_filenames = key_filename
@@ -310,8 +309,8 @@ class SSHClient (object):
chan.settimeout(timeout)
chan.exec_command(command)
stdin = chan.makefile('wb', bufsize)
- stdout = chan.makefile('rb', bufsize)
- stderr = chan.makefile_stderr('rb', bufsize)
+ stdout = chan.makefile('r', bufsize)
+ stderr = chan.makefile_stderr('r', bufsize)
return stdin, stdout, stderr
def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0,
@@ -377,7 +376,7 @@ class SSHClient (object):
two_factor = (allowed_types == ['password'])
if not two_factor:
return
- except SSHException, e:
+ except SSHException as e:
saved_exception = e
if not two_factor:
@@ -391,7 +390,7 @@ class SSHClient (object):
if not two_factor:
return
break
- except SSHException, e:
+ except SSHException as e:
saved_exception = e
if not two_factor and allow_agent:
@@ -407,7 +406,7 @@ class SSHClient (object):
if not two_factor:
return
break
- except SSHException, e:
+ except SSHException as e:
saved_exception = e
if not two_factor:
@@ -439,17 +438,15 @@ class SSHClient (object):
if not two_factor:
return
break
- except SSHException, e:
- saved_exception = e
- except IOError, e:
+ except (SSHException, IOError) as e:
saved_exception = e
if password is not None:
try:
self._transport.auth_password(username, password)
return
- except SSHException, e:
- saved_exception = e
+ except SSHException:
+ saved_exception = sys.exc_info()[1]
elif two_factor:
raise SSHException('Two-factor authentication requires a password')
diff --git a/paramiko/common.py b/paramiko/common.py
index 3d7ca588..e30df73a 100644
--- a/paramiko/common.py
+++ b/paramiko/common.py
@@ -19,12 +19,13 @@
"""
Common constants and global variables.
"""
+from paramiko.py3compat import *
MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \
MSG_SERVICE_ACCEPT = range(1, 7)
MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \
- MSG_USERAUTH_BANNER = range(50, 54)
+ MSG_USERAUTH_BANNER = range(50, 54)
MSG_USERAUTH_PK_OK = 60
MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
@@ -33,6 +34,10 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
+for key in list(locals().keys()):
+ if key.startswith('MSG_'):
+ locals()['c' + key] = byte_chr(locals()[key])
+del key
# for debugging:
MSG_NAMES = {
@@ -69,7 +74,7 @@ MSG_NAMES = {
MSG_CHANNEL_REQUEST: 'channel-request',
MSG_CHANNEL_SUCCESS: 'channel-success',
MSG_CHANNEL_FAILURE: 'channel-failure'
- }
+}
# authentication request return codes:
@@ -118,6 +123,42 @@ else:
import logging
PY22 = False
+zero_byte = byte_chr(0)
+one_byte = byte_chr(1)
+four_byte = byte_chr(4)
+max_byte = byte_chr(0xff)
+cr_byte = byte_chr(13)
+linefeed_byte = byte_chr(10)
+crlf = cr_byte + linefeed_byte
+
+if PY2:
+ cr_byte_value = cr_byte
+ linefeed_byte_value = linefeed_byte
+else:
+ cr_byte_value = 13
+ linefeed_byte_value = 10
+
+
+def asbytes(s):
+ if not isinstance(s, bytes_types):
+ if isinstance(s, string_types):
+ s = b(s)
+ else:
+ try:
+ s = s.asbytes()
+ except Exception:
+ raise Exception('Unknown type')
+ return s
+
+xffffffff = long(0xffffffff)
+x80000000 = long(0x80000000)
+o666 = 438
+o660 = 432
+o644 = 420
+o600 = 384
+o777 = 511
+o700 = 448
+o70 = 56
DEBUG = logging.DEBUG
INFO = logging.INFO
diff --git a/paramiko/config.py b/paramiko/config.py
index bc2816da..77fa13d7 100644
--- a/paramiko/config.py
+++ b/paramiko/config.py
@@ -116,7 +116,7 @@ class SSHConfig (object):
ret = {}
for match in matches:
- for key, value in match['config'].iteritems():
+ for key, value in match['config'].items():
if key not in ret:
# Create a copy of the original value,
# else it will reference the original list
diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py
index bac3dfed..6ab298ac 100644
--- a/paramiko/dsskey.py
+++ b/paramiko/dsskey.py
@@ -56,7 +56,7 @@ class DSSKey (PKey):
else:
if msg is None:
raise SSHException('Key object may not be empty')
- if msg.get_string() != 'ssh-dss':
+ if msg.get_text() != 'ssh-dss':
raise SSHException('Invalid key')
self.p = msg.get_mpint()
self.q = msg.get_mpint()
@@ -64,14 +64,17 @@ class DSSKey (PKey):
self.y = msg.get_mpint()
self.size = util.bit_length(self.p)
- def __str__(self):
+ def asbytes(self):
m = Message()
m.add_string('ssh-dss')
m.add_mpint(self.p)
m.add_mpint(self.q)
m.add_mpint(self.g)
m.add_mpint(self.y)
- return str(m)
+ return m.asbytes()
+
+ def __str__(self):
+ return self.asbytes()
def __hash__(self):
h = hash(self.get_name())
@@ -107,21 +110,21 @@ class DSSKey (PKey):
rstr = util.deflate_long(r, 0)
sstr = util.deflate_long(s, 0)
if len(rstr) < 20:
- rstr = '\x00' * (20 - len(rstr)) + rstr
+ rstr = zero_byte * (20 - len(rstr)) + rstr
if len(sstr) < 20:
- sstr = '\x00' * (20 - len(sstr)) + sstr
+ sstr = zero_byte * (20 - len(sstr)) + sstr
m.add_string(rstr + sstr)
return m
def verify_ssh_sig(self, data, msg):
- if len(str(msg)) == 40:
+ if len(msg.asbytes()) == 40:
# spies.com bug: signature has no header
- sig = str(msg)
+ sig = msg.asbytes()
else:
- kind = msg.get_string()
+ kind = msg.get_text()
if kind != 'ssh-dss':
return 0
- sig = msg.get_string()
+ sig = msg.get_binary()
# pull out (r, s) which are NOT encoded as mpints
sigR = util.inflate_long(sig[:20], 1)
@@ -140,7 +143,7 @@ class DSSKey (PKey):
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
- return str(b)
+ return b.asbytes()
def write_private_key_file(self, filename, password=None):
self._write_private_key_file('DSA', filename, self._encode_key(), password)
@@ -182,8 +185,8 @@ class DSSKey (PKey):
# DSAPrivateKey = { version = 0, p, q, g, y, x }
try:
keylist = BER(data).decode()
- except BERException, x:
- raise SSHException('Unable to parse key file: ' + str(x))
+ except BERException as e:
+ raise SSHException('Unable to parse key file: ' + str(e))
if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0):
raise SSHException('not a valid DSA private key file (bad ber encoding)')
self.p = keylist[1]
diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py
index ac840ab7..3ecf0a58 100644
--- a/paramiko/ecdsakey.py
+++ b/paramiko/ecdsakey.py
@@ -56,30 +56,33 @@ class ECDSAKey (PKey):
else:
if msg is None:
raise SSHException('Key object may not be empty')
- if msg.get_string() != 'ecdsa-sha2-nistp256':
+ if msg.get_text() != 'ecdsa-sha2-nistp256':
raise SSHException('Invalid key')
- curvename = msg.get_string()
+ curvename = msg.get_text()
if curvename != 'nistp256':
raise SSHException("Can't handle curve of type %s" % curvename)
- pointinfo = msg.get_string()
- if pointinfo[0] != "\x04":
- raise SSHException('Point compression is being used: %s'%
+ pointinfo = msg.get_binary()
+ if pointinfo[0:1] != four_byte:
+ raise SSHException('Point compression is being used: %s' %
binascii.hexlify(pointinfo))
self.verifying_key = VerifyingKey.from_string(pointinfo[1:],
- curve=curves.NIST256p)
+ curve=curves.NIST256p)
self.size = 256
- def __str__(self):
+ def asbytes(self):
key = self.verifying_key
m = Message()
m.add_string('ecdsa-sha2-nistp256')
m.add_string('nistp256')
- point_str = "\x04" + key.to_string()
+ point_str = four_byte + key.to_string()
m.add_string(point_str)
- return str(m)
+ return m.asbytes()
+
+ def __str__(self):
+ return self.asbytes()
def __hash__(self):
h = hash(self.get_name())
@@ -106,9 +109,9 @@ class ECDSAKey (PKey):
return m
def verify_ssh_sig(self, data, msg):
- if msg.get_string() != 'ecdsa-sha2-nistp256':
+ if msg.get_text() != 'ecdsa-sha2-nistp256':
return False
- sig = msg.get_string()
+ sig = msg.get_binary()
# verify the signature by SHA'ing the data and encrypting it
# using the public key.
@@ -154,14 +157,13 @@ class ECDSAKey (PKey):
data = self._read_private_key('EC', file_obj, password)
self._decode_key(data)
- ALLOWED_PADDINGS = ['\x01', '\x02\x02', '\x03\x03\x03', '\x04\x04\x04\x04',
- '\x05\x05\x05\x05\x05', '\x06\x06\x06\x06\x06\x06',
- '\x07\x07\x07\x07\x07\x07\x07']
+ ALLOWED_PADDINGS = [one_byte, byte_chr(2) * 2, byte_chr(3) * 3, byte_chr(4) * 4,
+ byte_chr(5) * 5, byte_chr(6) * 6, byte_chr(7) * 7]
def _decode_key(self, data):
s, padding = der.remove_sequence(data)
if padding:
if padding not in self.ALLOWED_PADDINGS:
- raise ValueError, "weird padding: %s" % (binascii.hexlify(empty))
+ raise ValueError("weird padding: %s" % u(binascii.hexlify(data)))
data = data[:-len(padding)]
key = SigningKey.from_der(data)
self.signing_key = key
@@ -172,7 +174,7 @@ class ECDSAKey (PKey):
msg = Message()
msg.add_mpint(r)
msg.add_mpint(s)
- return str(msg)
+ return msg.asbytes()
def _sigdecode(self, sig, order):
msg = Message(sig)
diff --git a/paramiko/file.py b/paramiko/file.py
index 253ffcd0..9f002423 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -16,7 +16,7 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
-from cStringIO import StringIO
+from paramiko.common import *
class BufferedFile (object):
@@ -43,8 +43,8 @@ class BufferedFile (object):
self.newlines = None
self._flags = 0
self._bufsize = self._DEFAULT_BUFSIZE
- self._wbuffer = StringIO()
- self._rbuffer = ''
+ self._wbuffer = BytesIO()
+ self._rbuffer = bytes()
self._at_trailing_cr = False
self._closed = False
# pos - position within the file, according to the user
@@ -82,9 +82,10 @@ class BufferedFile (object):
buffering is not turned on.
"""
self._write_all(self._wbuffer.getvalue())
- self._wbuffer = StringIO()
+ self._wbuffer = BytesIO()
return
+ if PY2:
def next(self):
"""
Returns the next line from the input, or raises
@@ -99,6 +100,22 @@ class BufferedFile (object):
if not line:
raise StopIteration
return line
+ else:
+ def __next__(self):
+ """
+ Returns the next line from the input, or raises L{StopIteration} when
+ EOF is hit. Unlike python file objects, it's okay to mix calls to
+ C{next} and L{readline}.
+
+ @raise StopIteration: when the end of the file is reached.
+
+ @return: a line read from the file.
+ @rtype: str
+ """
+ line = self.readline()
+ if not line:
+ raise StopIteration
+ return line
def read(self, size=None):
"""
@@ -118,7 +135,7 @@ class BufferedFile (object):
if (size is None) or (size < 0):
# go for broke
result = self._rbuffer
- self._rbuffer = ''
+ self._rbuffer = bytes()
self._pos += len(result)
while True:
try:
@@ -130,12 +147,12 @@ class BufferedFile (object):
result += new_data
self._realpos += len(new_data)
self._pos += len(new_data)
- return result
+ return result if self._flags & self.FLAG_BINARY else u(result)
if size <= len(self._rbuffer):
result = self._rbuffer[:size]
self._rbuffer = self._rbuffer[size:]
self._pos += len(result)
- return result
+ return result if self._flags & self.FLAG_BINARY else u(result)
while len(self._rbuffer) < size:
read_size = size - len(self._rbuffer)
if self._flags & self.FLAG_BUFFERED:
@@ -151,7 +168,7 @@ class BufferedFile (object):
result = self._rbuffer[:size]
self._rbuffer = self._rbuffer[size:]
self._pos += len(result)
- return result
+ return result if self._flags & self.FLAG_BINARY else u(result)
def readline(self, size=None):
"""
@@ -181,11 +198,11 @@ class BufferedFile (object):
if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
# edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time.
- if line[0] == '\n':
+ if line[0] == linefeed_byte_value:
line = line[1:]
- self._record_newline('\r\n')
+ self._record_newline(crlf)
else:
- self._record_newline('\r')
+ self._record_newline(cr_byte)
self._at_trailing_cr = False
# check size before looking for a linefeed, in case we already have
# enough.
@@ -195,42 +212,42 @@ class BufferedFile (object):
self._rbuffer = line[size:]
line = line[:size]
self._pos += len(line)
- return line
+ return line if self._flags & self.FLAG_BINARY else u(line)
n = size - len(line)
else:
n = self._bufsize
- if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
+ if (linefeed_byte in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (cr_byte in line)):
break
try:
new_data = self._read(n)
except EOFError:
new_data = None
if (new_data is None) or (len(new_data) == 0):
- self._rbuffer = ''
+ self._rbuffer = bytes()
self._pos += len(line)
- return line
+ return line if self._flags & self.FLAG_BINARY else u(line)
line += new_data
self._realpos += len(new_data)
# find the newline
- pos = line.find('\n')
+ pos = line.find(linefeed_byte)
if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
- rpos = line.find('\r')
+ rpos = line.find(cr_byte)
if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
pos = rpos
xpos = pos + 1
- if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'):
+ if (line[pos] == cr_byte_value) and (xpos < len(line)) and (line[xpos] == linefeed_byte_value):
xpos += 1
self._rbuffer = line[xpos:]
lf = line[pos:xpos]
- line = line[:pos] + '\n'
- if (len(self._rbuffer) == 0) and (lf == '\r'):
+ line = line[:pos] + linefeed_byte
+ if (len(self._rbuffer) == 0) and (lf == cr_byte):
# we could read the line up to a '\r' and there could still be a
# '\n' following that we read next time. note that and eat it.
self._at_trailing_cr = True
else:
self._record_newline(lf)
self._pos += len(line)
- return line
+ return line if self._flags & self.FLAG_BINARY else u(line)
def readlines(self, sizehint=None):
"""
@@ -243,14 +260,14 @@ class BufferedFile (object):
:return: `list` of lines read from the file.
"""
lines = []
- bytes = 0
+ byte_count = 0
while True:
line = self.readline()
if len(line) == 0:
break
lines.append(line)
- bytes += len(line)
- if (sizehint is not None) and (bytes >= sizehint):
+ byte_count += len(line)
+ if (sizehint is not None) and (byte_count >= sizehint):
break
return lines
@@ -292,6 +309,7 @@ class BufferedFile (object):
:param str data: data to write
"""
+ data = b(data)
if self._closed:
raise IOError('File is closed')
if not (self._flags & self.FLAG_WRITE):
@@ -302,12 +320,12 @@ class BufferedFile (object):
self._wbuffer.write(data)
if self._flags & self.FLAG_LINE_BUFFERED:
# only scan the new data for linefeed, to avoid wasting time.
- last_newline_pos = data.rfind('\n')
+ last_newline_pos = data.rfind(linefeed_byte)
if last_newline_pos >= 0:
wbuf = self._wbuffer.getvalue()
last_newline_pos += len(wbuf) - len(data)
self._write_all(wbuf[:last_newline_pos + 1])
- self._wbuffer = StringIO()
+ self._wbuffer = BytesIO()
self._wbuffer.write(wbuf[last_newline_pos + 1:])
return
# even if we're line buffering, if the buffer has grown past the
@@ -436,7 +454,7 @@ class BufferedFile (object):
return
if self.newlines is None:
self.newlines = newline
- elif (type(self.newlines) is str) and (self.newlines != newline):
+ elif self.newlines != newline and isinstance(self.newlines, bytes_types):
self.newlines = (self.newlines, newline)
elif newline not in self.newlines:
self.newlines += (newline,)
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
index d436703c..072a393a 100644
--- a/paramiko/hostkeys.py
+++ b/paramiko/hostkeys.py
@@ -20,7 +20,10 @@
import base64
import binascii
from Crypto.Hash import SHA, HMAC
-import UserDict
+try:
+ from collections import MutableMapping
+except ImportError:
+ from UserDict import DictMixin as MutableMapping
from paramiko.common import *
from paramiko.dsskey import DSSKey
@@ -29,7 +32,7 @@ from paramiko.util import get_logger, constant_time_bytes_eq
from paramiko.ecdsakey import ECDSAKey
-class HostKeys (UserDict.DictMixin):
+class HostKeys (MutableMapping):
"""
Representation of an OpenSSH-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to
@@ -83,20 +86,19 @@ class HostKeys (UserDict.DictMixin):
:raises IOError: if there was an error reading the file
"""
- f = open(filename, 'r')
- for lineno, line in enumerate(f):
- line = line.strip()
- if (len(line) == 0) or (line[0] == '#'):
- continue
- e = HostKeyEntry.from_line(line, lineno)
- if e is not None:
- _hostnames = e.hostnames
- for h in _hostnames:
- if self.check(h, e.key):
- e.hostnames.remove(h)
- if len(e.hostnames):
- self._entries.append(e)
- f.close()
+ with open(filename, 'r') as f:
+ for lineno, line in enumerate(f):
+ line = line.strip()
+ if (len(line) == 0) or (line[0] == '#'):
+ continue
+ e = HostKeyEntry.from_line(line, lineno)
+ if e is not None:
+ _hostnames = e.hostnames
+ for h in _hostnames:
+ if self.check(h, e.key):
+ e.hostnames.remove(h)
+ if len(e.hostnames):
+ self._entries.append(e)
def save(self, filename):
"""
@@ -111,12 +113,11 @@ class HostKeys (UserDict.DictMixin):
.. versionadded:: 1.6.1
"""
- f = open(filename, 'w')
- for e in self._entries:
- line = e.to_line()
- if line:
- f.write(line)
- f.close()
+ with open(filename, 'w') as f:
+ for e in self._entries:
+ line = e.to_line()
+ if line:
+ f.write(line)
def lookup(self, hostname):
"""
@@ -127,12 +128,26 @@ class HostKeys (UserDict.DictMixin):
:param str hostname: the hostname (or IP) to lookup
:return: dict of `str` -> `.PKey` keys associated with this host (or ``None``)
"""
- class SubDict (UserDict.DictMixin):
+ class SubDict (MutableMapping):
def __init__(self, hostname, entries, hostkeys):
self._hostname = hostname
self._entries = entries
self._hostkeys = hostkeys
+ def __iter__(self):
+ for k in self.keys():
+ yield k
+
+ def __len__(self):
+ return len(self.keys())
+
+ def __delitem__(self, key):
+ for e in list(self._entries):
+ if e.key.get_name() == key:
+ self._entries.remove(e)
+ else:
+ raise KeyError(key)
+
def __getitem__(self, key):
for e in self._entries:
if e.key.get_name() == key:
@@ -181,7 +196,7 @@ class HostKeys (UserDict.DictMixin):
host_key = k.get(key.get_name(), None)
if host_key is None:
return False
- return str(host_key) == str(key)
+ return host_key.asbytes() == key.asbytes()
def clear(self):
"""
@@ -189,6 +204,17 @@ class HostKeys (UserDict.DictMixin):
"""
self._entries = []
+ def __iter__(self):
+ for k in self.keys():
+ yield k
+
+ def __len__(self):
+ return len(self.keys())
+
+ def __delitem__(self, key):
+ k = self[key]
+ pass
+
def __getitem__(self, key):
ret = self.lookup(key)
if ret is None:
@@ -239,10 +265,10 @@ class HostKeys (UserDict.DictMixin):
else:
if salt.startswith('|1|'):
salt = salt.split('|')[2]
- salt = base64.decodestring(salt)
+ salt = decodebytes(b(salt))
assert len(salt) == SHA.digest_size
- hmac = HMAC.HMAC(salt, hostname, SHA).digest()
- hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac))
+ hmac = HMAC.HMAC(salt, b(hostname), SHA).digest()
+ hostkey = '|1|%s|%s' % (u(encodebytes(salt)), u(encodebytes(hmac)))
return hostkey.replace('\n', '')
hash_host = staticmethod(hash_host)
@@ -292,17 +318,17 @@ class HostKeyEntry:
# to hold it accordingly.
try:
if keytype == 'ssh-rsa':
- key = RSAKey(data=base64.decodestring(key))
+ key = RSAKey(data=decodebytes(key))
elif keytype == 'ssh-dss':
- key = DSSKey(data=base64.decodestring(key))
+ key = DSSKey(data=decodebytes(key))
elif keytype == 'ecdsa-sha2-nistp256':
- key = ECDSAKey(data=base64.decodestring(key))
+ key = ECDSAKey(data=decodebytes(key))
else:
log.info("Unable to handle key of type %s" % (keytype,))
return None
- except binascii.Error, e:
- raise InvalidHostKey(line, e)
+ except binascii.Error as e:
+ raise InvalidHostKey(line, sys.exc_info()[1])
return cls(names, key)
from_line = classmethod(from_line)
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py
index 27287300..8ac23212 100644
--- a/paramiko/kex_gex.py
+++ b/paramiko/kex_gex.py
@@ -33,6 +33,8 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
_MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
+c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \
+ c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = [byte_chr(c) for c in range(30, 35)]
class KexGex (object):
@@ -62,11 +64,11 @@ class KexGex (object):
m = Message()
if _test_old_style:
# only used for unit tests: we shouldn't ever send this
- m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD))
+ m.add_byte(c_MSG_KEXDH_GEX_REQUEST_OLD)
m.add_int(self.preferred_bits)
self.old_style = True
else:
- m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST))
+ m.add_byte(c_MSG_KEXDH_GEX_REQUEST)
m.add_int(self.min_bits)
m.add_int(self.preferred_bits)
m.add_int(self.max_bits)
@@ -94,15 +96,15 @@ class KexGex (object):
# generate an "x" (1 < x < (p-1)/2).
q = (self.p - 1) // 2
qnorm = util.deflate_long(q, 0)
- qhbyte = ord(qnorm[0])
- bytes = len(qnorm)
+ qhbyte = byte_ord(qnorm[0])
+ byte_count = len(qnorm)
qmask = 0xff
while not (qhbyte & 0x80):
qhbyte <<= 1
qmask >>= 1
while True:
- x_bytes = self.transport.rng.read(bytes)
- x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:]
+ x_bytes = self.transport.rng.read(byte_count)
+ x_bytes = byte_mask(x_bytes[0], qmask) + x_bytes[1:]
x = util.inflate_long(x_bytes, 1)
if (x > 1) and (x < q):
break
@@ -135,7 +137,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits))
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message()
- m.add_byte(chr(_MSG_KEXDH_GEX_GROUP))
+ m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p)
m.add_mpint(self.g)
self.transport._send_message(m)
@@ -156,7 +158,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,))
self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits)
m = Message()
- m.add_byte(chr(_MSG_KEXDH_GEX_GROUP))
+ m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p)
m.add_mpint(self.g)
self.transport._send_message(m)
@@ -175,7 +177,7 @@ class KexGex (object):
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
m = Message()
- m.add_byte(chr(_MSG_KEXDH_GEX_INIT))
+ m.add_byte(c_MSG_KEXDH_GEX_INIT)
m.add_mpint(self.e)
self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY)
@@ -187,7 +189,7 @@ class KexGex (object):
self._generate_x()
self.f = pow(self.g, self.x, self.p)
K = pow(self.e, self.x, self.p)
- key = str(self.transport.get_server_key())
+ key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version,
@@ -203,16 +205,16 @@ class KexGex (object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
- H = SHA.new(str(hm)).digest()
+ H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
# sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply
m = Message()
- m.add_byte(chr(_MSG_KEXDH_GEX_REPLY))
+ m.add_byte(c_MSG_KEXDH_GEX_REPLY)
m.add_string(key)
m.add_mpint(self.f)
- m.add_string(str(sig))
+ m.add_string(sig)
self.transport._send_message(m)
self.transport._activate_outbound()
@@ -238,6 +240,6 @@ class KexGex (object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
- self.transport._set_K_H(K, SHA.new(str(hm)).digest())
+ self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig)
self.transport._activate_outbound()
diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py
index 6e89b6dc..05693a1f 100644
--- a/paramiko/kex_group1.py
+++ b/paramiko/kex_group1.py
@@ -30,11 +30,14 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32)
+c_MSG_KEXDH_INIT, c_MSG_KEXDH_REPLY = [byte_chr(c) for c in range(30, 32)]
# draft-ietf-secsh-transport-09.txt, page 17
-P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL
+P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2
+b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7
+b0000000000000000 = zero_byte * 8
class KexGroup1(object):
@@ -42,9 +45,9 @@ class KexGroup1(object):
def __init__(self, transport):
self.transport = transport
- self.x = 0L
- self.e = 0L
- self.f = 0L
+ self.x = long(0)
+ self.e = long(0)
+ self.f = long(0)
def start_kex(self):
self._generate_x()
@@ -56,7 +59,7 @@ class KexGroup1(object):
# compute e = g^x mod p (where g=2), and send it
self.e = pow(G, self.x, P)
m = Message()
- m.add_byte(chr(_MSG_KEXDH_INIT))
+ m.add_byte(c_MSG_KEXDH_INIT)
m.add_mpint(self.e)
self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_REPLY)
@@ -67,7 +70,7 @@ class KexGroup1(object):
elif not self.transport.server_mode and (ptype == _MSG_KEXDH_REPLY):
return self._parse_kexdh_reply(m)
raise SSHException('KexGroup1 asked to handle packet type %d' % ptype)
-
+
### internals...
@@ -80,9 +83,9 @@ class KexGroup1(object):
# larger than q (but this is a tiny tiny subset of potential x).
while 1:
x_bytes = self.transport.rng.read(128)
- x_bytes = chr(ord(x_bytes[0]) & 0x7f) + x_bytes[1:]
- if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \
- (x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'):
+ x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:]
+ if (x_bytes[:8] != b7fffffffffffffff) and \
+ (x_bytes[:8] != b0000000000000000):
break
self.x = util.inflate_long(x_bytes)
@@ -92,7 +95,7 @@ class KexGroup1(object):
self.f = m.get_mpint()
if (self.f < 1) or (self.f > P - 1):
raise SSHException('Server kex "f" is out of range')
- sig = m.get_string()
+ sig = m.get_binary()
K = pow(self.f, self.x, P)
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
@@ -102,7 +105,7 @@ class KexGroup1(object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
- self.transport._set_K_H(K, SHA.new(str(hm)).digest())
+ self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig)
self.transport._activate_outbound()
@@ -112,7 +115,7 @@ class KexGroup1(object):
if (self.e < 1) or (self.e > P - 1):
raise SSHException('Client kex "e" is out of range')
K = pow(self.e, self.x, P)
- key = str(self.transport.get_server_key())
+ key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version,
@@ -121,15 +124,15 @@ class KexGroup1(object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
- H = SHA.new(str(hm)).digest()
+ H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
# sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply
m = Message()
- m.add_byte(chr(_MSG_KEXDH_REPLY))
+ m.add_byte(c_MSG_KEXDH_REPLY)
m.add_string(key)
m.add_mpint(self.f)
- m.add_string(str(sig))
+ m.add_string(sig)
self.transport._send_message(m)
self.transport._activate_outbound()
diff --git a/paramiko/message.py b/paramiko/message.py
index 213b2e79..a487f2e8 100644
--- a/paramiko/message.py
+++ b/paramiko/message.py
@@ -21,9 +21,9 @@ Implementation of an SSH2 "message".
"""
import struct
-import cStringIO
from paramiko import util
+from paramiko.common import *
class Message (object):
@@ -37,6 +37,8 @@ class Message (object):
paramiko doesn't support yet.
"""
+ big_int = long(0xff000000)
+
def __init__(self, content=None):
"""
Create a new SSH2 message.
@@ -46,15 +48,15 @@ class Message (object):
decomposing a message).
"""
if content != None:
- self.packet = cStringIO.StringIO(content)
+ self.packet = BytesIO(content)
else:
- self.packet = cStringIO.StringIO()
+ self.packet = BytesIO()
def __str__(self):
"""
- Return the byte stream content of this message, as a string.
+ Return the byte stream content of this message, as a string/bytes obj.
"""
- return self.packet.getvalue()
+ return self.asbytes()
def __repr__(self):
"""
@@ -62,6 +64,15 @@ class Message (object):
"""
return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
+ def asbytes(self):
+ """
+ Return the byte stream content of this Message, as bytes.
+
+ @return: the contents of this Message.
+ @rtype: bytes
+ """
+ return self.packet.getvalue()
+
def rewind(self):
"""
Rewind the message to the beginning as if no items had been parsed
@@ -99,7 +110,7 @@ class Message (object):
b = self.packet.read(n)
max_pad_size = 1<<20 # Limit padding to 1 MB
if len(b) < n and n < max_pad_size:
- return b + '\x00' * (n - len(b))
+ return b + zero_byte * (n - len(b))
return b
def get_byte(self):
@@ -118,7 +129,7 @@ class Message (object):
Fetch a boolean from the stream.
"""
b = self.get_bytes(1)
- return b != '\x00'
+ return b != zero_byte
def get_int(self):
"""
@@ -126,6 +137,19 @@ class Message (object):
:return: a 32-bit unsigned `int`.
"""
+ byte = self.get_bytes(1)
+ if byte == max_byte:
+ return util.inflate_long(self.get_binary())
+ byte += self.get_bytes(3)
+ return struct.unpack('>I', byte)[0]
+
+ def get_size(self):
+ """
+ Fetch an int from the stream.
+
+ @return: a 32-bit unsigned integer.
+ @rtype: int
+ """
return struct.unpack('>I', self.get_bytes(4))[0]
def get_int64(self):
@@ -142,7 +166,7 @@ class Message (object):
:return: an arbitrary-length integer (`long`).
"""
- return util.inflate_long(self.get_string())
+ return util.inflate_long(self.get_binary())
def get_string(self):
"""
@@ -150,7 +174,30 @@ class Message (object):
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream message.)
"""
- return self.get_bytes(self.get_int())
+ return self.get_bytes(self.get_size())
+
+ def get_text(self):
+ """
+ Fetch a string from the stream. This could be a byte string and may
+ contain unprintable characters. (It's not unheard of for a string to
+ contain another byte-stream Message.)
+
+ @return: a string.
+ @rtype: string
+ """
+ return u(self.get_bytes(self.get_size()))
+ #return self.get_bytes(self.get_size())
+
+ def get_binary(self):
+ """
+ Fetch a string from the stream. This could be a byte string and may
+ contain unprintable characters. (It's not unheard of for a string to
+ contain another byte-stream Message.)
+
+ @return: a string.
+ @rtype: string
+ """
+ return self.get_bytes(self.get_size())
def get_list(self):
"""
@@ -158,7 +205,7 @@ class Message (object):
These are trivially encoded as comma-separated values in a string.
"""
- return self.get_string().split(',')
+ return self.get_text().split(',')
def add_bytes(self, b):
"""
@@ -185,9 +232,19 @@ class Message (object):
:param bool b: boolean value to add
"""
if b:
- self.add_byte('\x01')
+ self.packet.write(one_byte)
else:
- self.add_byte('\x00')
+ self.packet.write(zero_byte)
+ return self
+
+ def add_size(self, n):
+ """
+ Add an integer to the stream.
+
+ @param n: integer to add
+ @type n: int
+ """
+ self.packet.write(struct.pack('>I', n))
return self
def add_int(self, n):
@@ -196,6 +253,10 @@ class Message (object):
:param int n: integer to add
"""
+ if n >= Message.big_int:
+ self.packet.write(max_byte)
+ self.add_string(util.deflate_long(n))
+ else:
self.packet.write(struct.pack('>I', n))
return self
@@ -224,7 +285,8 @@ class Message (object):
:param str s: string to add
"""
- self.add_int(len(s))
+ s = asbytes(s)
+ self.add_size(len(s))
self.packet.write(s)
return self
@@ -240,21 +302,14 @@ class Message (object):
return self
def _add(self, i):
- if type(i) is str:
- return self.add_string(i)
- elif type(i) is int:
- return self.add_int(i)
- elif type(i) is long:
- if i > 0xffffffffL:
- return self.add_mpint(i)
- else:
- return self.add_int(i)
- elif type(i) is bool:
+ if type(i) is bool:
return self.add_boolean(i)
+ elif isinstance(i, integer_types):
+ return self.add_int(i)
elif type(i) is list:
return self.add_list(i)
else:
- raise Exception('Unknown type')
+ return self.add_string(i)
def add(self, *seq):
"""
diff --git a/paramiko/packet.py b/paramiko/packet.py
index 62cda219..fd1f0197 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -38,6 +38,7 @@ try:
except ImportError:
from Crypto.Hash.HMAC import HMAC
+
def compute_hmac(key, message, digest_class):
return HMAC(key, message, digest_class).digest()
@@ -66,7 +67,7 @@ class Packetizer (object):
self.__dump_packets = False
self.__need_rekey = False
self.__init_count = 0
- self.__remainder = ''
+ self.__remainder = bytes()
# used for noticing when to re-key:
self.__sent_bytes = 0
@@ -86,12 +87,12 @@ class Packetizer (object):
self.__sdctr_out = False
self.__mac_engine_out = None
self.__mac_engine_in = None
- self.__mac_key_out = ''
- self.__mac_key_in = ''
+ self.__mac_key_out = bytes()
+ self.__mac_key_in = bytes()
self.__compress_engine_out = None
self.__compress_engine_in = None
- self.__sequence_number_out = 0L
- self.__sequence_number_in = 0L
+ self.__sequence_number_out = 0
+ self.__sequence_number_in = 0
# lock around outbound writes (packet computation)
self.__write_lock = threading.RLock()
@@ -152,6 +153,7 @@ class Packetizer (object):
def close(self):
self.__closed = True
+ self.__socket.close()
def set_hexdump(self, hexdump):
self.__dump_packets = hexdump
@@ -193,7 +195,7 @@ class Packetizer (object):
:raises EOFError:
if the socket was closed before all the bytes could be read
"""
- out = ''
+ out = bytes()
# handle over-reading from reading the banner line
if len(self.__remainder) > 0:
out = self.__remainder[:n]
@@ -211,7 +213,7 @@ class Packetizer (object):
n -= len(x)
except socket.timeout:
got_timeout = True
- except socket.error, e:
+ except socket.error as e:
# on Linux, sometimes instead of socket.timeout, we get
# EAGAIN. this is a bug in recent (> 2.6.9) kernels but
# we need to work around it.
@@ -240,7 +242,7 @@ class Packetizer (object):
n = self.__socket.send(out)
except socket.timeout:
retry_write = True
- except socket.error, e:
+ except socket.error as e:
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
retry_write = True
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
@@ -270,22 +272,22 @@ class Packetizer (object):
line, so it's okay to attempt large reads.
"""
buf = self.__remainder
- while not '\n' in buf:
+ while not linefeed_byte in buf:
buf += self._read_timeout(timeout)
- n = buf.index('\n')
+ n = buf.index(linefeed_byte)
self.__remainder = buf[n+1:]
buf = buf[:n]
- if (len(buf) > 0) and (buf[-1] == '\r'):
+ if (len(buf) > 0) and (buf[-1] == cr_byte_value):
buf = buf[:-1]
- return buf
+ return u(buf)
def send_message(self, data):
"""
Write a block of data using the current cipher, as an SSH block.
"""
# encrypt this sucka
- data = str(data)
- cmd = ord(data[0])
+ data = asbytes(data)
+ cmd = byte_ord(data[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
@@ -307,7 +309,7 @@ class Packetizer (object):
if self.__block_engine_out != None:
payload = struct.pack('>I', self.__sequence_number_out) + packet
out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out]
- self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
+ self.__sequence_number_out = (self.__sequence_number_out + 1) & xffffffff
self.write_all(out)
self.__sent_bytes += len(out)
@@ -356,7 +358,7 @@ class Packetizer (object):
my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in]
if not util.constant_time_bytes_eq(my_mac, mac):
raise SSHException('Mismatched MAC')
- padding = ord(packet[0])
+ padding = byte_ord(packet[0])
payload = packet[1:packet_size - padding]
if self.__dump_packets:
@@ -367,7 +369,7 @@ class Packetizer (object):
msg = Message(payload[1:])
msg.seqno = self.__sequence_number_in
- self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL
+ self.__sequence_number_in = (self.__sequence_number_in + 1) & xffffffff
# check for rekey
raw_packet_size = packet_size + self.__mac_size_in + 4
@@ -390,7 +392,7 @@ class Packetizer (object):
self.__received_packets_overflow = 0
self._trigger_rekey()
- cmd = ord(payload[0])
+ cmd = byte_ord(payload[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
@@ -465,7 +467,7 @@ class Packetizer (object):
break
except socket.timeout:
pass
- except EnvironmentError, e:
+ except EnvironmentError as e:
if ((type(e.args) is tuple) and (len(e.args) > 0) and
(e.args[0] == errno.EINTR)):
pass
@@ -487,7 +489,7 @@ class Packetizer (object):
if self.__sdctr_out or self.__block_engine_out is None:
# cute trick i caught openssh doing: if we're not encrypting or SDCTR mode (RFC4344),
# don't waste random bytes for the padding
- packet += (chr(0) * padding)
+ packet += (zero_byte * padding)
else:
packet += rng.read(padding)
return packet
diff --git a/paramiko/pipe.py b/paramiko/pipe.py
index 705f8d49..568aca6b 100644
--- a/paramiko/pipe.py
+++ b/paramiko/pipe.py
@@ -28,6 +28,7 @@ will trigger as readable in `select <select.select>`.
import sys
import os
import socket
+from paramiko.py3compat import b
def make_pipe ():
@@ -64,7 +65,7 @@ class PosixPipe (object):
if self._set or self._closed:
return
self._set = True
- os.write(self._wfd, '*')
+ os.write(self._wfd, b'*')
def set_forever (self):
self._forever = True
@@ -110,7 +111,7 @@ class WindowsPipe (object):
if self._set or self._closed:
return
self._set = True
- self._wsock.send('*')
+ self._wsock.send(b'*')
def set_forever (self):
self._forever = True
diff --git a/paramiko/pkey.py b/paramiko/pkey.py
index ea4d2140..31b4e7c1 100644
--- a/paramiko/pkey.py
+++ b/paramiko/pkey.py
@@ -62,13 +62,16 @@ class PKey (object):
"""
pass
- def __str__(self):
+ def asbytes(self):
"""
Return a string of an SSH `.Message` made up of the public part(s) of
this key. This string is suitable for passing to `__init__` to
re-create the key object later.
"""
- return ''
+ return bytes()
+
+ def __str__(self):
+ return self.asbytes()
def __cmp__(self, other):
"""
@@ -83,7 +86,10 @@ class PKey (object):
ho = hash(other)
if hs != ho:
return cmp(hs, ho)
- return cmp(str(self), str(other))
+ return cmp(self.asbytes(), other.asbytes())
+
+ def __eq__(self, other):
+ return hash(self) == hash(other)
def get_name(self):
"""
@@ -120,7 +126,7 @@ class PKey (object):
a 16-byte `string <str>` (binary) of the MD5 fingerprint, in SSH
format.
"""
- return MD5.new(str(self)).digest()
+ return MD5.new(self.asbytes()).digest()
def get_base64(self):
"""
@@ -130,7 +136,7 @@ class PKey (object):
:return: a base64 `string <str>` containing the public part of the key.
"""
- return base64.encodestring(str(self)).replace('\n', '')
+ return u(encodebytes(self.asbytes())).replace('\n', '')
def sign_ssh_data(self, rng, data):
"""
@@ -141,7 +147,7 @@ class PKey (object):
:param str data: the data to sign.
:return: an SSH signature `message <.Message>`.
"""
- return ''
+ return bytes()
def verify_ssh_sig(self, data, msg):
"""
@@ -246,9 +252,8 @@ class PKey (object):
encrypted, and ``password`` is ``None``.
:raises SSHException: if the key file is invalid.
"""
- f = open(filename, 'r')
+ with open(filename, 'r') as f:
data = self._read_private_key(tag, f, password)
- f.close()
return data
def _read_private_key(self, tag, f, password=None):
@@ -273,8 +278,8 @@ class PKey (object):
end += 1
# if we trudged to the end of the file, just try to cope.
try:
- data = base64.decodestring(''.join(lines[start:end]))
- except base64.binascii.Error, e:
+ data = decodebytes(b(''.join(lines[start:end])))
+ except base64.binascii.Error as e:
raise SSHException('base64 decoding error: ' + str(e))
if 'proc-type' not in headers:
# unencryped: done
@@ -285,7 +290,7 @@ class PKey (object):
try:
encryption_type, saltstr = headers['dek-info'].split(',')
except:
- raise SSHException('Can\'t parse DEK-info in private key file')
+ raise SSHException("Can't parse DEK-info in private key file")
if encryption_type not in self._CIPHER_TABLE:
raise SSHException('Unknown private key cipher "%s"' % encryption_type)
# if no password was passed in, raise an exception pointing out that we need one
@@ -294,7 +299,7 @@ class PKey (object):
cipher = self._CIPHER_TABLE[encryption_type]['cipher']
keysize = self._CIPHER_TABLE[encryption_type]['keysize']
mode = self._CIPHER_TABLE[encryption_type]['mode']
- salt = unhexlify(saltstr)
+ salt = unhexlify(b(saltstr))
key = util.generate_key_bytes(MD5, salt, password, keysize)
return cipher.new(key, mode, salt).decrypt(data)
@@ -312,33 +317,32 @@ class PKey (object):
:raises IOError: if there was an error writing the file.
"""
- f = open(filename, 'w', 0600)
+ with open(filename, 'w', o600) as f:
# grrr... the mode doesn't always take hold
- os.chmod(filename, 0600)
+ os.chmod(filename, o600)
self._write_private_key(tag, f, data, password)
- f.close()
def _write_private_key(self, tag, f, data, password=None):
f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag)
if password is not None:
# since we only support one cipher here, use it
- cipher_name = self._CIPHER_TABLE.keys()[0]
+ cipher_name = list(self._CIPHER_TABLE.keys())[0]
cipher = self._CIPHER_TABLE[cipher_name]['cipher']
keysize = self._CIPHER_TABLE[cipher_name]['keysize']
blocksize = self._CIPHER_TABLE[cipher_name]['blocksize']
mode = self._CIPHER_TABLE[cipher_name]['mode']
- salt = rng.read(8)
+ salt = rng.read(16)
key = util.generate_key_bytes(MD5, salt, password, keysize)
if len(data) % blocksize != 0:
n = blocksize - len(data) % blocksize
#data += rng.read(n)
# that would make more sense ^, but it confuses openssh.
- data += '\0' * n
+ data += zero_byte * n
data = cipher.new(key, mode, salt).encrypt(data)
f.write('Proc-Type: 4,ENCRYPTED\n')
- f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper()))
+ f.write('DEK-Info: %s,%s\n' % (cipher_name, u(hexlify(salt)).upper()))
f.write('\n')
- s = base64.encodestring(data)
+ s = u(encodebytes(data))
# re-wrap to 64-char lines
s = ''.join(s.split('\n'))
s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)])
diff --git a/paramiko/primes.py b/paramiko/primes.py
index 86b9953a..34b9877e 100644
--- a/paramiko/primes.py
+++ b/paramiko/primes.py
@@ -24,6 +24,7 @@ from Crypto.Util import number
from paramiko import util
from paramiko.ssh_exception import SSHException
+from paramiko.common import *
def _generate_prime(bits, rng):
@@ -33,7 +34,7 @@ def _generate_prime(bits, rng):
# loop catches the case where we increment n into a higher bit-range
x = rng.read((bits+7) // 8)
if hbyte_mask > 0:
- x = chr(ord(x[0]) & hbyte_mask) + x[1:]
+ x = byte_mask(x[0], hbyte_mask) + x[1:]
n = util.inflate_long(x, 1)
n |= 1
n |= (1 << (bits - 1))
@@ -46,7 +47,7 @@ def _generate_prime(bits, rng):
def _roll_random(rng, n):
"returns a random # from 0 to N-1"
bits = util.bit_length(n-1)
- bytes = (bits + 7) // 8
+ byte_count = (bits + 7) // 8
hbyte_mask = pow(2, bits % 8) - 1
# so here's the plan:
@@ -56,9 +57,9 @@ def _roll_random(rng, n):
# fits, so i can't guarantee that this loop will ever finish, but the odds
# of it looping forever should be infinitesimal.
while True:
- x = rng.read(bytes)
+ x = rng.read(byte_count)
if hbyte_mask > 0:
- x = chr(ord(x[0]) & hbyte_mask) + x[1:]
+ x = byte_mask(x[0], hbyte_mask) + x[1:]
num = util.inflate_long(x, 1)
if num < n:
break
@@ -112,7 +113,7 @@ class ModulusPack (object):
:raises IOError: passed from any file operations that fail.
"""
self.pack = {}
- f = open(filename, 'r')
+ with open(filename, 'r') as f:
for line in f:
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
@@ -121,11 +122,9 @@ class ModulusPack (object):
self._parse_modulus(line)
except:
continue
- f.close()
def get_modulus(self, min, prefer, max):
- bitsizes = self.pack.keys()
- bitsizes.sort()
+ bitsizes = sorted(self.pack.keys())
if len(bitsizes) == 0:
raise SSHException('no moduli available')
good = -1
diff --git a/paramiko/proxy.py b/paramiko/proxy.py
index 10f0728f..c7e93efa 100644
--- a/paramiko/proxy.py
+++ b/paramiko/proxy.py
@@ -59,7 +59,7 @@ class ProxyCommand(object):
"""
try:
self.process.stdin.write(content)
- except IOError, e:
+ except IOError as e:
# There was a problem with the child process. It probably
# died and we can't proceed. The best option here is to
# raise an exception informing the user that the informed
@@ -95,7 +95,7 @@ class ProxyCommand(object):
return result
except socket.timeout:
raise # socket.timeout is a subclass of IOError
- except IOError, e:
+ except IOError as e:
raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
def close(self):
diff --git a/paramiko/py3compat.py b/paramiko/py3compat.py
new file mode 100644
index 00000000..22285992
--- /dev/null
+++ b/paramiko/py3compat.py
@@ -0,0 +1,160 @@
+import sys
+import base64
+
+__all__ = ['PY2', 'string_types', 'integer_types', 'text_type', 'bytes_types', 'bytes', 'long', 'input',
+ 'decodebytes', 'encodebytes', 'bytestring', 'byte_ord', 'byte_chr', 'byte_mask',
+ 'b', 'u', 'b2s', 'StringIO', 'BytesIO', 'is_callable', 'MAXSIZE', 'next']
+
+PY2 = sys.version_info[0] < 3
+
+if PY2:
+ string_types = basestring
+ text_type = unicode
+ bytes_types = str
+ bytes = str
+ integer_types = (int, long)
+ long = long
+ input = raw_input
+ decodebytes = base64.decodestring
+ encodebytes = base64.encodestring
+
+
+ def bytestring(s): # NOQA
+ if isinstance(s, unicode):
+ return s.encode('utf-8')
+ return s
+
+
+ byte_ord = ord # NOQA
+ byte_chr = chr # NOQA
+
+
+ def byte_mask(c, mask):
+ return chr(ord(c) & mask)
+
+
+ def b(s, encoding='utf8'): # NOQA
+ """cast unicode or bytes to bytes"""
+ if isinstance(s, str):
+ return s
+ elif isinstance(s, unicode):
+ return s.encode(encoding)
+ else:
+ raise TypeError("Expected unicode or bytes, got %r" % s)
+
+
+ def u(s, encoding='utf8'): # NOQA
+ """cast bytes or unicode to unicode"""
+ if isinstance(s, str):
+ return s.decode(encoding)
+ elif isinstance(s, unicode):
+ return s
+ else:
+ raise TypeError("Expected unicode or bytes, got %r" % s)
+
+
+ def b2s(s):
+ return s
+
+
+ try:
+ import cStringIO
+
+ StringIO = cStringIO.StringIO # NOQA
+ except ImportError:
+ import StringIO
+
+ StringIO = StringIO.StringIO # NOQA
+
+ BytesIO = StringIO
+
+
+ def is_callable(c): # NOQA
+ return callable(c)
+
+
+ def get_next(c): # NOQA
+ return c.next
+
+
+ def next(c):
+ return c.next()
+
+ # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
+ class X(object):
+ def __len__(self):
+ return 1 << 31
+
+
+ try:
+ len(X())
+ except OverflowError:
+ # 32-bit
+ MAXSIZE = int((1 << 31) - 1) # NOQA
+ else:
+ # 64-bit
+ MAXSIZE = int((1 << 63) - 1) # NOQA
+ del X
+else:
+ import collections
+ import struct
+ string_types = str
+ text_type = str
+ bytes = bytes
+ bytes_types = bytes
+ integer_types = int
+ class long(int):
+ pass
+ input = input
+ decodebytes = base64.decodebytes
+ encodebytes = base64.encodebytes
+
+ def bytestring(s):
+ return s
+
+ def byte_ord(c):
+ assert isinstance(c, int)
+ return c
+
+ def byte_chr(c):
+ assert isinstance(c, int)
+ return struct.pack('B', c)
+
+ def byte_mask(c, mask):
+ assert isinstance(c, int)
+ return struct.pack('B', c & mask)
+
+ def b(s, encoding='utf8'):
+ """cast unicode or bytes to bytes"""
+ if isinstance(s, bytes):
+ return s
+ elif isinstance(s, str):
+ return s.encode(encoding)
+ else:
+ raise TypeError("Expected unicode or bytes, got %r" % s)
+
+ def u(s, encoding='utf8'):
+ """cast bytes or unicode to unicode"""
+ if isinstance(s, bytes):
+ return s.decode(encoding)
+ elif isinstance(s, str):
+ return s
+ else:
+ raise TypeError("Expected unicode or bytes, got %r" % s)
+
+ def b2s(s):
+ return s.decode() if isinstance(s, bytes) else s
+
+ import io
+ StringIO = io.StringIO # NOQA
+ BytesIO = io.BytesIO # NOQA
+
+ def is_callable(c):
+ return isinstance(c, collections.Callable)
+
+ def get_next(c):
+ return c.__next__
+
+ next = next
+
+ MAXSIZE = sys.maxsize # NOQA
diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py
index 8c2aae91..06f0085d 100644
--- a/paramiko/rsakey.py
+++ b/paramiko/rsakey.py
@@ -31,6 +31,8 @@ from paramiko.ber import BER, BERException
from paramiko.pkey import PKey
from paramiko.ssh_exception import SSHException
+SHA1_DIGESTINFO = b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
+
class RSAKey (PKey):
"""
@@ -57,18 +59,21 @@ class RSAKey (PKey):
else:
if msg is None:
raise SSHException('Key object may not be empty')
- if msg.get_string() != 'ssh-rsa':
+ if msg.get_text() != 'ssh-rsa':
raise SSHException('Invalid key')
self.e = msg.get_mpint()
self.n = msg.get_mpint()
self.size = util.bit_length(self.n)
- def __str__(self):
+ def asbytes(self):
m = Message()
m.add_string('ssh-rsa')
m.add_mpint(self.e)
m.add_mpint(self.n)
- return str(m)
+ return m.asbytes()
+
+ def __str__(self):
+ return self.asbytes()
def __hash__(self):
h = hash(self.get_name())
@@ -88,16 +93,16 @@ class RSAKey (PKey):
def sign_ssh_data(self, rpool, data):
digest = SHA.new(data).digest()
rsa = RSA.construct((long(self.n), long(self.e), long(self.d)))
- sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0)
+ sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), bytes())[0], 0)
m = Message()
m.add_string('ssh-rsa')
m.add_string(sig)
return m
def verify_ssh_sig(self, data, msg):
- if msg.get_string() != 'ssh-rsa':
+ if msg.get_text() != 'ssh-rsa':
return False
- sig = util.inflate_long(msg.get_string(), True)
+ sig = util.inflate_long(msg.get_binary(), True)
# verify the signature by SHA'ing the data and encrypting it using the
# public key. some wackiness ensues where we "pkcs1imify" the 20-byte
# hash into a string as long as the RSA key.
@@ -116,7 +121,7 @@ class RSAKey (PKey):
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
- return str(b)
+ return b.asbytes()
def write_private_key_file(self, filename, password=None):
self._write_private_key_file('RSA', filename, self._encode_key(), password)
@@ -152,10 +157,9 @@ class RSAKey (PKey):
turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre.
"""
- SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
size = len(util.deflate_long(self.n, 0))
- filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
- return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data
+ filler = max_byte * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
+ return zero_byte + one_byte + filler + zero_byte + SHA1_DIGESTINFO + data
def _from_private_key_file(self, filename, password):
data = self._read_private_key_file('RSA', filename, password)
diff --git a/paramiko/server.py b/paramiko/server.py
index bf11cda3..ad0acb94 100644
--- a/paramiko/server.py
+++ b/paramiko/server.py
@@ -514,7 +514,7 @@ class InteractiveQuery (object):
self.instructions = instructions
self.prompts = []
for x in prompts:
- if (type(x) is str) or (type(x) is unicode):
+ if isinstance(x, string_types):
self.add_prompt(x)
else:
self.add_prompt(x[0], x[1])
@@ -576,7 +576,7 @@ class SubsystemHandler (threading.Thread):
try:
self.__transport._log(DEBUG, 'Starting handler for subsystem %s' % self.__name)
self.start_subsystem(self.__name, self.__transport, self.__channel)
- except Exception, e:
+ except Exception as e:
self.__transport._log(ERROR, 'Exception in subsystem handler for "%s": %s' %
(self.__name, str(e)))
self.__transport._log(ERROR, util.tb_strings())
diff --git a/paramiko/sftp.py b/paramiko/sftp.py
index a97c300f..3e05de9f 100644
--- a/paramiko/sftp.py
+++ b/paramiko/sftp.py
@@ -86,7 +86,7 @@ CMD_NAMES = {
CMD_ATTRS: 'attrs',
CMD_EXTENDED: 'extended',
CMD_EXTENDED_REPLY: 'extended_reply'
- }
+}
class SFTPError (Exception):
@@ -125,7 +125,7 @@ class BaseSFTP (object):
msg = Message()
msg.add_int(_VERSION)
msg.add(*extension_pairs)
- self._send_packet(CMD_VERSION, str(msg))
+ self._send_packet(CMD_VERSION, msg)
return version
def _log(self, level, msg, *args):
@@ -142,7 +142,7 @@ class BaseSFTP (object):
return
def _read_all(self, n):
- out = ''
+ out = bytes()
while n > 0:
if isinstance(self.sock, socket.socket):
# sometimes sftp is used directly over a socket instead of
@@ -166,7 +166,8 @@ class BaseSFTP (object):
def _send_packet(self, t, packet):
#self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet)))
- out = struct.pack('>I', len(packet) + 1) + chr(t) + packet
+ packet = asbytes(packet)
+ out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet
if self.ultra_debug:
self._log(DEBUG, util.format_binary(out, 'OUT: '))
self._write_all(out)
@@ -175,14 +176,14 @@ class BaseSFTP (object):
x = self._read_all(4)
# most sftp servers won't accept packets larger than about 32k, so
# anything with the high byte set (> 16MB) is just garbage.
- if x[0] != '\x00':
+ if byte_ord(x[0]):
raise SFTPError('Garbage packet received')
size = struct.unpack('>I', x)[0]
data = self._read_all(size)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(data, 'IN: '));
if size > 0:
- t = ord(data[0])
+ t = byte_ord(data[0])
#self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1))
return t, data[1:]
- return 0, ''
+ return 0, bytes()
diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py
index 3ef9703b..ffdaa864 100644
--- a/paramiko/sftp_attr.py
+++ b/paramiko/sftp_attr.py
@@ -45,7 +45,7 @@ class SFTPAttributes (object):
FLAG_UIDGID = 2
FLAG_PERMISSIONS = 4
FLAG_AMTIME = 8
- FLAG_EXTENDED = 0x80000000L
+ FLAG_EXTENDED = x80000000
def __init__(self):
"""
@@ -141,7 +141,7 @@ class SFTPAttributes (object):
msg.add_int(long(self.st_mtime))
if self._flags & self.FLAG_EXTENDED:
msg.add_int(len(self.attr))
- for key, val in self.attr.iteritems():
+ for key, val in self.attr.items():
msg.add_string(key)
msg.add_string(val)
return
@@ -156,7 +156,7 @@ class SFTPAttributes (object):
out += 'mode=' + oct(self.st_mode) + ' '
if (self.st_atime is not None) and (self.st_mtime is not None):
out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime)
- for k, v in self.attr.iteritems():
+ for k, v in self.attr.items():
out += '"%s"=%r ' % (str(k), v)
out += ']'
return out
@@ -192,13 +192,13 @@ class SFTPAttributes (object):
ks = 's'
else:
ks = '?'
- ks += self._rwx((self.st_mode & 0700) >> 6, self.st_mode & stat.S_ISUID)
- ks += self._rwx((self.st_mode & 070) >> 3, self.st_mode & stat.S_ISGID)
+ ks += self._rwx((self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID)
+ ks += self._rwx((self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID)
ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True)
else:
ks = '?---------'
# compute display date
- if (self.st_mtime is None) or (self.st_mtime == 0xffffffffL):
+ if (self.st_mtime is None) or (self.st_mtime == xffffffff):
# shouldn't really happen
datestr = '(unknown date)'
else:
@@ -219,3 +219,5 @@ class SFTPAttributes (object):
return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename)
+ def asbytes(self):
+ return b(str(self))
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 0580bc43..90571234 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -39,12 +39,13 @@ def _to_unicode(s):
"""
try:
return s.encode('ascii')
- except UnicodeError:
+ except (UnicodeError, AttributeError):
try:
return s.decode('utf-8')
except UnicodeError:
return s
+b_slash = b'/'
class SFTPClient(BaseSFTP):
"""
@@ -82,7 +83,7 @@ class SFTPClient(BaseSFTP):
self.ultra_debug = transport.get_hexdump()
try:
server_version = self._send_version()
- except EOFError, x:
+ except EOFError:
raise SSHException('EOF during negotiation')
self._log(INFO, 'Opened sftp connection (server version %d)' % server_version)
@@ -162,20 +163,20 @@ class SFTPClient(BaseSFTP):
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
raise SFTPError('Expected handle')
- handle = msg.get_string()
+ handle = msg.get_binary()
filelist = []
while True:
try:
t, msg = self._request(CMD_READDIR, handle)
- except EOFError, e:
+ except EOFError:
# done with handle
break
if t != CMD_NAME:
raise SFTPError('Expected name response')
count = msg.get_int()
for i in range(count):
- filename = _to_unicode(msg.get_string())
- longname = _to_unicode(msg.get_string())
+ filename = msg.get_text()
+ longname = msg.get_text()
attr = SFTPAttributes._from_msg(msg, filename, longname)
if (filename != '.') and (filename != '..'):
filelist.append(attr)
@@ -231,7 +232,7 @@ class SFTPClient(BaseSFTP):
t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
if t != CMD_HANDLE:
raise SFTPError('Expected handle')
- handle = msg.get_string()
+ handle = msg.get_binary()
self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle)))
return SFTPFile(self, handle, mode, bufsize)
@@ -268,7 +269,7 @@ class SFTPClient(BaseSFTP):
self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath))
self._request(CMD_RENAME, oldpath, newpath)
- def mkdir(self, path, mode=0777):
+ def mkdir(self, path, mode=o777):
"""
Create a folder (directory) named ``path`` with numeric mode ``mode``.
The default mode is 0777 (octal). On some systems, mode is ignored.
@@ -347,8 +348,7 @@ class SFTPClient(BaseSFTP):
"""
dest = self._adjust_cwd(dest)
self._log(DEBUG, 'symlink(%r, %r)' % (source, dest))
- if type(source) is unicode:
- source = source.encode('utf-8')
+ source = bytestring(source)
self._request(CMD_SYMLINK, source, dest)
def chmod(self, path, mode):
@@ -462,9 +462,9 @@ class SFTPClient(BaseSFTP):
count = msg.get_int()
if count != 1:
raise SFTPError('Realpath returned %d results' % count)
- return _to_unicode(msg.get_string())
+ return msg.get_text()
- def chdir(self, path):
+ def chdir(self, path=None):
"""
Change the "current directory" of this SFTP session. Since SFTP
doesn't really have the concept of a current working directory, this is
@@ -484,7 +484,7 @@ class SFTPClient(BaseSFTP):
return
if not stat.S_ISDIR(self.stat(path).st_mode):
raise SFTPError(errno.ENOTDIR, "%s: %s" % (os.strerror(errno.ENOTDIR), path))
- self._cwd = self.normalize(path).encode('utf-8')
+ self._cwd = b(self.normalize(path))
def getcwd(self):
"""
@@ -494,7 +494,7 @@ class SFTPClient(BaseSFTP):
.. versionadded:: 1.4
"""
- return self._cwd
+ return self._cwd and u(self._cwd)
def putfo(self, fl, remotepath, file_size=0, callback=None, confirm=True):
"""
@@ -525,10 +525,9 @@ class SFTPClient(BaseSFTP):
.. versionchanged:: 1.7.4
Began returning rich attribute objects.
"""
- fr = self.file(remotepath, 'wb')
+ with self.file(remotepath, 'wb') as fr:
fr.set_pipelined(True)
size = 0
- try:
while True:
data = fl.read(32768)
fr.write(data)
@@ -537,8 +536,6 @@ class SFTPClient(BaseSFTP):
callback(size, file_size)
if len(data) == 0:
break
- finally:
- fr.close()
if confirm:
s = self.stat(remotepath)
if s.st_size != size:
@@ -573,11 +570,8 @@ class SFTPClient(BaseSFTP):
``confirm`` param added.
"""
file_size = os.stat(localpath).st_size
- fl = file(localpath, 'rb')
- try:
+ with open(localpath, 'rb') as fl:
return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm)
- finally:
- fl.close()
def getfo(self, remotepath, fl, callback=None):
"""
@@ -598,10 +592,9 @@ class SFTPClient(BaseSFTP):
.. versionchanged:: 1.7.4
Added the ``callable`` param.
"""
- fr = self.file(remotepath, 'rb')
+ with self.open(remotepath, 'rb') as fr:
file_size = self.stat(remotepath).st_size
fr.prefetch()
- try:
size = 0
while True:
data = fr.read(32768)
@@ -611,8 +604,6 @@ class SFTPClient(BaseSFTP):
callback(size, file_size)
if len(data) == 0:
break
- finally:
- fr.close()
return size
def get(self, remotepath, localpath, callback=None):
@@ -632,11 +623,8 @@ class SFTPClient(BaseSFTP):
Added the ``callback`` param
"""
file_size = self.stat(remotepath).st_size
- fl = file(localpath, 'wb')
- try:
+ with open(localpath, 'wb') as fl:
size = self.getfo(remotepath, fl, callback)
- finally:
- fl.close()
s = os.stat(localpath)
if s.st_size != size:
raise IOError('size mismatch in get! %d != %d' % (s.st_size, size))
@@ -656,11 +644,11 @@ class SFTPClient(BaseSFTP):
msg = Message()
msg.add_int(self.request_number)
for item in arg:
- if isinstance(item, int):
- msg.add_int(item)
- elif isinstance(item, long):
+ if isinstance(item, long):
msg.add_int64(item)
- elif isinstance(item, str):
+ elif isinstance(item, int):
+ msg.add_int(item)
+ elif isinstance(item, (string_types, bytes_types)):
msg.add_string(item)
elif isinstance(item, SFTPAttributes):
item._pack(msg)
@@ -668,7 +656,7 @@ class SFTPClient(BaseSFTP):
raise Exception('unknown type for %r type %r' % (item, type(item)))
num = self.request_number
self._expecting[num] = fileobj
- self._send_packet(t, str(msg))
+ self._send_packet(t, msg)
self.request_number += 1
finally:
self._lock.release()
@@ -678,8 +666,8 @@ class SFTPClient(BaseSFTP):
while True:
try:
t, data = self._read_packet()
- except EOFError, e:
- raise SSHException('Server connection dropped: %s' % (str(e),))
+ except EOFError as e:
+ raise SSHException('Server connection dropped: %s' % str(e))
msg = Message(data)
num = msg.get_int()
if num not in self._expecting:
@@ -713,7 +701,7 @@ class SFTPClient(BaseSFTP):
Raises EOFError or IOError on error status; otherwise does nothing.
"""
code = msg.get_int()
- text = msg.get_string()
+ text = msg.get_text()
if code == SFTP_OK:
return
elif code == SFTP_EOF:
@@ -731,16 +719,15 @@ class SFTPClient(BaseSFTP):
Return an adjusted path if we're emulating a "current working
directory" for the server.
"""
- if type(path) is unicode:
- path = path.encode('utf-8')
+ path = b(path)
if self._cwd is None:
return path
- if (len(path) > 0) and (path[0] == '/'):
+ if len(path) and path[0:1] == b_slash:
# absolute path
return path
- if self._cwd == '/':
+ if self._cwd == b_slash:
return self._cwd + path
- return self._cwd + '/' + path
+ return self._cwd + b_slash + path
class SFTP(SFTPClient):
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index 9add3c91..5d75bf59 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -100,7 +100,7 @@ class SFTPFile (BufferedFile):
k = [x for x in self._prefetch_extents.values() if x[0] <= offset]
if len(k) == 0:
return False
- k.sort(lambda x, y: cmp(x[0], y[0]))
+ k.sort(key=lambda x: x[0])
buf_offset, buf_size = k[-1]
if buf_offset + buf_size <= offset:
# prefetch request ends before this one begins
@@ -171,7 +171,7 @@ class SFTPFile (BufferedFile):
def _write(self, data):
# may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE)
- self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])))
+ self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), data[:chunk]))
if not self.pipelined or (len(self._reqs) > 100 and self.sftp.sock.recv_ready()):
while len(self._reqs):
req = self._reqs.popleft()
@@ -224,7 +224,7 @@ class SFTPFile (BufferedFile):
self._realpos = self._pos
else:
self._realpos = self._pos = self._get_size() + offset
- self._rbuffer = ''
+ self._rbuffer = bytes()
def stat(self):
"""
@@ -352,8 +352,8 @@ class SFTPFile (BufferedFile):
"""
t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle,
hash_algorithm, long(offset), long(length), block_size)
- ext = msg.get_string()
- alg = msg.get_string()
+ ext = msg.get_text()
+ alg = msg.get_text()
data = msg.get_remainder()
return data
@@ -469,8 +469,8 @@ class SFTPFile (BufferedFile):
# save exception and re-raise it on next file operation
try:
self.sftp._convert_status(msg)
- except Exception, x:
- self._saved_exception = x
+ except Exception as e:
+ self._saved_exception = e
return
if t != CMD_DATA:
raise SFTPError('Expected data')
diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py
index a799d57c..79c0045c 100644
--- a/paramiko/sftp_handle.py
+++ b/paramiko/sftp_handle.py
@@ -97,7 +97,7 @@ class SFTPHandle (object):
readfile.seek(offset)
self.__tell = offset
data = readfile.read(length)
- except IOError, e:
+ except IOError as e:
self.__tell = None
return SFTPServer.convert_errno(e.errno)
self.__tell += len(data)
@@ -135,7 +135,7 @@ class SFTPHandle (object):
self.__tell = offset
writefile.write(data)
writefile.flush()
- except IOError, e:
+ except IOError as e:
self.__tell = None
return SFTPServer.convert_errno(e.errno)
if self.__tell is not None:
diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py
index 0456e0a6..1c197dfd 100644
--- a/paramiko/sftp_server.py
+++ b/paramiko/sftp_server.py
@@ -89,7 +89,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
except EOFError:
self._log(DEBUG, 'EOF -- end of session')
return
- except Exception, e:
+ except Exception as e:
self._log(DEBUG, 'Exception on channel: ' + str(e))
self._log(DEBUG, util.tb_strings())
return
@@ -97,7 +97,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
request_number = msg.get_int()
try:
self._process(t, request_number, msg)
- except Exception, e:
+ except Exception as e:
self._log(DEBUG, 'Exception in server processing: ' + str(e))
self._log(DEBUG, util.tb_strings())
# send some kind of failure message, at least
@@ -110,9 +110,9 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self.server.session_ended()
super(SFTPServer, self).finish_subsystem()
# close any file handles that were left open (so we can return them to the OS quickly)
- for f in self.file_table.itervalues():
+ for f in self.file_table.values():
f.close()
- for f in self.folder_table.itervalues():
+ for f in self.folder_table.values():
f.close()
self.file_table = {}
self.folder_table = {}
@@ -159,7 +159,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if attr._flags & attr.FLAG_AMTIME:
os.utime(filename, (attr.st_atime, attr.st_mtime))
if attr._flags & attr.FLAG_SIZE:
- open(filename, 'w+').truncate(attr.st_size)
+ with open(filename, 'w+') as f:
+ f.truncate(attr.st_size)
set_file_attr = staticmethod(set_file_attr)
@@ -170,24 +171,24 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg = Message()
msg.add_int(request_number)
for item in arg:
- if type(item) is int:
- msg.add_int(item)
- elif type(item) is long:
+ if isinstance(item, long):
msg.add_int64(item)
- elif type(item) is str:
+ elif isinstance(item, int):
+ msg.add_int(item)
+ elif isinstance(item, (string_types, bytes_types)):
msg.add_string(item)
elif type(item) is SFTPAttributes:
item._pack(msg)
else:
raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item)))
- self._send_packet(t, str(msg))
+ self._send_packet(t, msg)
def _send_handle_response(self, request_number, handle, folder=False):
if not issubclass(type(handle), SFTPHandle):
# must be error code
self._send_status(request_number, handle)
return
- handle._set_name('hx%d' % self.next_handle)
+ handle._set_name(b('hx%d' % self.next_handle))
self.next_handle += 1
if folder:
self.folder_table[handle._get_name()] = handle
@@ -225,16 +226,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_int(len(flist))
for attr in flist:
msg.add_string(attr.filename)
- msg.add_string(str(attr))
+ msg.add_string(attr)
attr._pack(msg)
- self._send_packet(CMD_NAME, str(msg))
+ self._send_packet(CMD_NAME, msg)
def _check_file(self, request_number, msg):
# this extension actually comes from v6 protocol, but since it's an
# extension, i feel like we can reasonably support it backported.
# it's very useful for verifying uploaded files or checking for
# rsync-like differences between local and remote files.
- handle = msg.get_string()
+ handle = msg.get_binary()
alg_list = msg.get_list()
start = msg.get_int64()
length = msg.get_int64()
@@ -263,7 +264,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_FAILURE, 'Block size too small')
return
- sum_out = ''
+ sum_out = bytes()
offset = start
while offset < start + length:
blocklen = min(block_size, start + length - offset)
@@ -273,7 +274,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
hash_obj = alg.new()
while count < blocklen:
data = f.read(offset, chunklen)
- if not type(data) is str:
+ if not isinstance(data, bytes_types):
self._send_status(request_number, data, 'Unable to hash file')
return
hash_obj.update(data)
@@ -286,7 +287,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_string('check-file')
msg.add_string(algname)
msg.add_bytes(sum_out)
- self._send_packet(CMD_EXTENDED_REPLY, str(msg))
+ self._send_packet(CMD_EXTENDED_REPLY, msg)
def _convert_pflags(self, pflags):
"convert SFTP-style open() flags to Python's os.open() flags"
@@ -309,12 +310,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
def _process(self, t, request_number, msg):
self._log(DEBUG, 'Request: %s' % CMD_NAMES[t])
if t == CMD_OPEN:
- path = msg.get_string()
+ path = msg.get_text()
flags = self._convert_pflags(msg.get_int())
attr = SFTPAttributes._from_msg(msg)
self._send_handle_response(request_number, self.server.open(path, flags, attr))
elif t == CMD_CLOSE:
- handle = msg.get_string()
+ handle = msg.get_binary()
if handle in self.folder_table:
del self.folder_table[handle]
self._send_status(request_number, SFTP_OK)
@@ -326,14 +327,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
elif t == CMD_READ:
- handle = msg.get_string()
+ handle = msg.get_binary()
offset = msg.get_int64()
length = msg.get_int()
if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
data = self.file_table[handle].read(offset, length)
- if type(data) is str:
+ if isinstance(data, (bytes_types, string_types)):
if len(data) == 0:
self._send_status(request_number, SFTP_EOF)
else:
@@ -341,54 +342,54 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else:
self._send_status(request_number, data)
elif t == CMD_WRITE:
- handle = msg.get_string()
+ handle = msg.get_binary()
offset = msg.get_int64()
- data = msg.get_string()
+ data = msg.get_binary()
if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
self._send_status(request_number, self.file_table[handle].write(offset, data))
elif t == CMD_REMOVE:
- path = msg.get_string()
+ path = msg.get_text()
self._send_status(request_number, self.server.remove(path))
elif t == CMD_RENAME:
- oldpath = msg.get_string()
- newpath = msg.get_string()
+ oldpath = msg.get_text()
+ newpath = msg.get_text()
self._send_status(request_number, self.server.rename(oldpath, newpath))
elif t == CMD_MKDIR:
- path = msg.get_string()
+ path = msg.get_text()
attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.mkdir(path, attr))
elif t == CMD_RMDIR:
- path = msg.get_string()
+ path = msg.get_text()
self._send_status(request_number, self.server.rmdir(path))
elif t == CMD_OPENDIR:
- path = msg.get_string()
+ path = msg.get_text()
self._open_folder(request_number, path)
return
elif t == CMD_READDIR:
- handle = msg.get_string()
+ handle = msg.get_binary()
if handle not in self.folder_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
folder = self.folder_table[handle]
self._read_folder(request_number, folder)
elif t == CMD_STAT:
- path = msg.get_string()
+ path = msg.get_text()
resp = self.server.stat(path)
if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp)
else:
self._send_status(request_number, resp)
elif t == CMD_LSTAT:
- path = msg.get_string()
+ path = msg.get_text()
resp = self.server.lstat(path)
if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp)
else:
self._send_status(request_number, resp)
elif t == CMD_FSTAT:
- handle = msg.get_string()
+ handle = msg.get_binary()
if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
@@ -398,34 +399,34 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else:
self._send_status(request_number, resp)
elif t == CMD_SETSTAT:
- path = msg.get_string()
+ path = msg.get_text()
attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.chattr(path, attr))
elif t == CMD_FSETSTAT:
- handle = msg.get_string()
+ handle = msg.get_binary()
attr = SFTPAttributes._from_msg(msg)
if handle not in self.file_table:
self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
self._send_status(request_number, self.file_table[handle].chattr(attr))
elif t == CMD_READLINK:
- path = msg.get_string()
+ path = msg.get_text()
resp = self.server.readlink(path)
- if type(resp) is str:
+ if isinstance(resp, (bytes_types, string_types)):
self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
else:
self._send_status(request_number, resp)
elif t == CMD_SYMLINK:
# the sftp 2 draft is incorrect here! path always follows target_path
- target_path = msg.get_string()
- path = msg.get_string()
+ target_path = msg.get_text()
+ path = msg.get_text()
self._send_status(request_number, self.server.symlink(target_path, path))
elif t == CMD_REALPATH:
- path = msg.get_string()
+ path = msg.get_text()
rpath = self.server.canonicalize(path)
self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes())
elif t == CMD_EXTENDED:
- tag = msg.get_string()
+ tag = msg.get_text()
if tag == 'check-file':
self._check_file(request_number, msg)
else:
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 9f5c7098..692a0c68 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -155,7 +155,7 @@ class Transport (threading.Thread):
:param socket sock:
a socket or socket-like object to create the session over.
"""
- if isinstance(sock, (str, unicode)):
+ if isinstance(sock, string_types):
# convert "host:port" into (host, port)
hl = sock.split(':', 1)
if len(hl) == 1:
@@ -173,7 +173,7 @@ class Transport (threading.Thread):
sock = socket.socket(af, socket.SOCK_STREAM)
try:
retry_on_signal(lambda: sock.connect((hostname, port)))
- except socket.error, e:
+ except socket.error as e:
reason = str(e)
else:
break
@@ -253,7 +253,7 @@ class Transport (threading.Thread):
"""
Returns a string representation of this object, for debugging.
"""
- out = '<paramiko.Transport at %s' % hex(long(id(self)) & 0xffffffffL)
+ out = '<paramiko.Transport at %s' % hex(long(id(self)) & xffffffff)
if not self.active:
out += ' (unconnected)'
else:
@@ -279,6 +279,7 @@ class Transport (threading.Thread):
.. versionadded:: 1.5.3
"""
+ self.sock.close()
self.close()
def get_security_options(self):
@@ -489,7 +490,7 @@ class Transport (threading.Thread):
if not self.active:
return
self.stop_thread()
- for chan in self._channels.values():
+ for chan in list(self._channels.values()):
chan._unlink()
self.sock.close()
@@ -562,18 +563,16 @@ class Transport (threading.Thread):
"""
return self.open_channel('auth-agent@openssh.com')
- def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)):
+ def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
"""
Request a new channel back to the client, of type ``"forwarded-tcpip"``.
This is used after a client has requested port forwarding, for sending
incoming connections back to the client.
:param src_addr: originator's address
- :param src_port: originator's port
:param dest_addr: local (server) connected address
- :param dest_port: local (server) connected port
"""
- return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port))
+ return self.open_channel('forwarded-tcpip', dest_addr, src_addr)
def open_channel(self, kind, dest_addr=None, src_addr=None):
"""
@@ -602,7 +601,7 @@ class Transport (threading.Thread):
try:
chanid = self._next_channel()
m = Message()
- m.add_byte(chr(MSG_CHANNEL_OPEN))
+ m.add_byte(cMSG_CHANNEL_OPEN)
m.add_string(kind)
m.add_int(chanid)
m.add_int(self.window_size)
@@ -670,7 +669,6 @@ class Transport (threading.Thread):
"""
if not self.active:
raise SSHException('SSH session not active')
- address = str(address)
port = int(port)
response = self.global_request('tcpip-forward', (address, port), wait=True)
if response is None:
@@ -678,7 +676,9 @@ class Transport (threading.Thread):
if port == 0:
port = response.get_int()
if handler is None:
- def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)):
+ def default_handler(channel, src_addr, dest_addr):
+ #src_addr, src_port = src_addr_port
+ #dest_addr, dest_port = dest_addr_port
self._queue_incoming_channel(channel)
handler = default_handler
self._tcp_handler = handler
@@ -710,22 +710,22 @@ class Transport (threading.Thread):
"""
return SFTPClient.from_transport(self)
- def send_ignore(self, bytes=None):
+ def send_ignore(self, byte_count=None):
"""
Send a junk packet across the encrypted link. This is sometimes used
to add "noise" to a connection to confuse would-be attackers. It can
also be used as a keep-alive for long lived connections traversing
firewalls.
- :param int bytes:
+ :param int byte_count:
the number of random bytes to send in the payload of the ignored
packet -- defaults to a random number from 10 to 41.
"""
m = Message()
- m.add_byte(chr(MSG_IGNORE))
- if bytes is None:
- bytes = (ord(rng.read(1)) % 32) + 10
- m.add_bytes(rng.read(bytes))
+ m.add_byte(cMSG_IGNORE)
+ if byte_count is None:
+ byte_count = (ord(rng.read(1)) % 32) + 10
+ m.add_bytes(rng.read(byte_count))
self._send_user_message(m)
def renegotiate_keys(self):
@@ -787,7 +787,7 @@ class Transport (threading.Thread):
if wait:
self.completion_event = threading.Event()
m = Message()
- m.add_byte(chr(MSG_GLOBAL_REQUEST))
+ m.add_byte(cMSG_GLOBAL_REQUEST)
m.add_string(kind)
m.add_boolean(wait)
if data is not None:
@@ -871,10 +871,10 @@ class Transport (threading.Thread):
# check host key if we were given one
if (hostkey is not None):
key = self.get_remote_server_key()
- if (key.get_name() != hostkey.get_name()) or (str(key) != str(hostkey)):
+ if (key.get_name() != hostkey.get_name()) or (key.asbytes() != hostkey.asbytes()):
self._log(DEBUG, 'Bad host key from server')
- self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(str(hostkey))))
- self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(str(key))))
+ self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(hostkey.asbytes())))
+ self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(key.asbytes())))
raise SSHException('Bad host key from server')
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
@@ -1048,9 +1048,9 @@ class Transport (threading.Thread):
return []
try:
return self.auth_handler.wait_for_response(my_event)
- except BadAuthenticationType, x:
+ except BadAuthenticationType as e:
# if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it
- if not fallback or ('keyboard-interactive' not in x.allowed_types):
+ if not fallback or ('keyboard-interactive' not in e.allowed_types):
raise
try:
def handler(title, instructions, fields):
@@ -1064,9 +1064,9 @@ class Transport (threading.Thread):
return []
return [ password ]
return self.auth_interactive(username, handler)
- except SSHException, ignored:
+ except SSHException:
# attempt failed; just raise the original exception
- raise x
+ raise e
return None
def auth_publickey(self, username, key, event=None):
@@ -1331,15 +1331,15 @@ class Transport (threading.Thread):
m = Message()
m.add_mpint(self.K)
m.add_bytes(self.H)
- m.add_byte(id)
+ m.add_byte(b(id))
m.add_bytes(self.session_id)
- out = sofar = SHA.new(str(m)).digest()
+ out = sofar = SHA.new(m.asbytes()).digest()
while len(out) < nbytes:
m = Message()
m.add_mpint(self.K)
m.add_bytes(self.H)
m.add_bytes(sofar)
- digest = SHA.new(str(m)).digest()
+ digest = SHA.new(m.asbytes()).digest()
out += digest
sofar += digest
return out[:nbytes]
@@ -1373,7 +1373,7 @@ class Transport (threading.Thread):
# only called if a channel has turned on x11 forwarding
if handler is None:
# by default, use the same mechanism as accept()
- def default_handler(channel, (src_addr, src_port)):
+ def default_handler(channel, src_addr_port):
self._queue_incoming_channel(channel)
self._x11_handler = default_handler
else:
@@ -1404,12 +1404,12 @@ class Transport (threading.Thread):
# active=True occurs before the thread is launched, to avoid a race
_active_threads.append(self)
if self.server_mode:
- self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & 0xffffffffL))
+ self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & xffffffff))
else:
- self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL))
+ self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & xffffffff))
try:
try:
- self.packetizer.write_all(self.local_version + '\r\n')
+ self.packetizer.write_all(b(self.local_version + '\r\n'))
self._check_banner()
self._send_kex_init()
self._expect_packet(MSG_KEXINIT)
@@ -1457,18 +1457,18 @@ class Transport (threading.Thread):
else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message()
- msg.add_byte(chr(MSG_UNIMPLEMENTED))
+ msg.add_byte(cMSG_UNIMPLEMENTED)
msg.add_int(m.seqno)
self._send_message(msg)
- except SSHException, e:
+ except SSHException as e:
self._log(ERROR, 'Exception: ' + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
- except EOFError, e:
+ except EOFError as e:
self._log(DEBUG, 'EOF in transport thread')
#self._log(DEBUG, util.tb_strings())
self.saved_exception = e
- except socket.error, e:
+ except socket.error as e:
if type(e.args) is tuple:
if e.args:
emsg = '%s (%d)' % (e.args[1], e.args[0])
@@ -1478,12 +1478,12 @@ class Transport (threading.Thread):
emsg = e.args
self._log(ERROR, 'Socket exception: ' + emsg)
self.saved_exception = e
- except Exception, e:
+ except Exception as e:
self._log(ERROR, 'Unknown exception: ' + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
- for chan in self._channels.values():
+ for chan in list(self._channels.values()):
chan._unlink()
if self.active:
self.active = False
@@ -1538,8 +1538,8 @@ class Transport (threading.Thread):
buf = self.packetizer.readline(timeout)
except ProxyCommandFailure:
raise
- except Exception, x:
- raise SSHException('Error reading SSH protocol banner' + str(x))
+ except Exception as e:
+ raise SSHException('Error reading SSH protocol banner' + str(e))
if buf[:4] == 'SSH-':
break
self._log(DEBUG, 'Banner: ' + buf)
@@ -1549,7 +1549,7 @@ class Transport (threading.Thread):
self.remote_version = buf
# pull off any attached comment
comment = ''
- i = string.find(buf, ' ')
+ i = buf.find(' ')
if i >= 0:
comment = buf[i+1:]
buf = buf[:i]
@@ -1580,13 +1580,13 @@ class Transport (threading.Thread):
pkex = list(self.get_security_options().kex)
pkex.remove('diffie-hellman-group-exchange-sha1')
self.get_security_options().kex = pkex
- available_server_keys = filter(self.server_key_dict.keys().__contains__,
- self._preferred_keys)
+ available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
+ self._preferred_keys))
else:
available_server_keys = self._preferred_keys
m = Message()
- m.add_byte(chr(MSG_KEXINIT))
+ m.add_byte(cMSG_KEXINIT)
m.add_bytes(rng.read(16))
m.add_list(self._preferred_kex)
m.add_list(available_server_keys)
@@ -1596,12 +1596,12 @@ class Transport (threading.Thread):
m.add_list(self._preferred_macs)
m.add_list(self._preferred_compression)
m.add_list(self._preferred_compression)
- m.add_string('')
- m.add_string('')
+ m.add_string(bytes())
+ m.add_string(bytes())
m.add_boolean(False)
m.add_int(0)
# save a copy for later (needed to compute a hash)
- self.local_kex_init = str(m)
+ self.local_kex_init = m.asbytes()
self._send_message(m)
def _parse_kex_init(self, m):
@@ -1633,19 +1633,19 @@ class Transport (threading.Thread):
# as a server, we pick the first item in the client's list that we support.
# as a client, we pick the first item in our list that the server supports.
if self.server_mode:
- agreed_kex = filter(self._preferred_kex.__contains__, kex_algo_list)
+ agreed_kex = list(filter(self._preferred_kex.__contains__, kex_algo_list))
else:
- agreed_kex = filter(kex_algo_list.__contains__, self._preferred_kex)
+ agreed_kex = list(filter(kex_algo_list.__contains__, self._preferred_kex))
if len(agreed_kex) == 0:
raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)')
self.kex_engine = self._kex_info[agreed_kex[0]](self)
if self.server_mode:
- available_server_keys = filter(self.server_key_dict.keys().__contains__,
- self._preferred_keys)
- agreed_keys = filter(available_server_keys.__contains__, server_key_algo_list)
+ available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
+ self._preferred_keys))
+ agreed_keys = list(filter(available_server_keys.__contains__, server_key_algo_list))
else:
- agreed_keys = filter(server_key_algo_list.__contains__, self._preferred_keys)
+ agreed_keys = list(filter(server_key_algo_list.__contains__, self._preferred_keys))
if len(agreed_keys) == 0:
raise SSHException('Incompatible ssh peer (no acceptable host key)')
self.host_key_type = agreed_keys[0]
@@ -1653,15 +1653,15 @@ class Transport (threading.Thread):
raise SSHException('Incompatible ssh peer (can\'t match requested host key type)')
if self.server_mode:
- agreed_local_ciphers = filter(self._preferred_ciphers.__contains__,
- server_encrypt_algo_list)
- agreed_remote_ciphers = filter(self._preferred_ciphers.__contains__,
- client_encrypt_algo_list)
+ agreed_local_ciphers = list(filter(self._preferred_ciphers.__contains__,
+ server_encrypt_algo_list))
+ agreed_remote_ciphers = list(filter(self._preferred_ciphers.__contains__,
+ client_encrypt_algo_list))
else:
- agreed_local_ciphers = filter(client_encrypt_algo_list.__contains__,
- self._preferred_ciphers)
- agreed_remote_ciphers = filter(server_encrypt_algo_list.__contains__,
- self._preferred_ciphers)
+ agreed_local_ciphers = list(filter(client_encrypt_algo_list.__contains__,
+ self._preferred_ciphers))
+ agreed_remote_ciphers = list(filter(server_encrypt_algo_list.__contains__,
+ self._preferred_ciphers))
if (len(agreed_local_ciphers) == 0) or (len(agreed_remote_ciphers) == 0):
raise SSHException('Incompatible ssh server (no acceptable ciphers)')
self.local_cipher = agreed_local_ciphers[0]
@@ -1669,22 +1669,22 @@ class Transport (threading.Thread):
self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher))
if self.server_mode:
- agreed_remote_macs = filter(self._preferred_macs.__contains__, client_mac_algo_list)
- agreed_local_macs = filter(self._preferred_macs.__contains__, server_mac_algo_list)
+ agreed_remote_macs = list(filter(self._preferred_macs.__contains__, client_mac_algo_list))
+ agreed_local_macs = list(filter(self._preferred_macs.__contains__, server_mac_algo_list))
else:
- agreed_local_macs = filter(client_mac_algo_list.__contains__, self._preferred_macs)
- agreed_remote_macs = filter(server_mac_algo_list.__contains__, self._preferred_macs)
+ agreed_local_macs = list(filter(client_mac_algo_list.__contains__, self._preferred_macs))
+ agreed_remote_macs = list(filter(server_mac_algo_list.__contains__, self._preferred_macs))
if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0):
raise SSHException('Incompatible ssh server (no acceptable macs)')
self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0]
if self.server_mode:
- agreed_remote_compression = filter(self._preferred_compression.__contains__, client_compress_algo_list)
- agreed_local_compression = filter(self._preferred_compression.__contains__, server_compress_algo_list)
+ agreed_remote_compression = list(filter(self._preferred_compression.__contains__, client_compress_algo_list))
+ agreed_local_compression = list(filter(self._preferred_compression.__contains__, server_compress_algo_list))
else:
- agreed_local_compression = filter(client_compress_algo_list.__contains__, self._preferred_compression)
- agreed_remote_compression = filter(server_compress_algo_list.__contains__, self._preferred_compression)
+ agreed_local_compression = list(filter(client_compress_algo_list.__contains__, self._preferred_compression))
+ agreed_remote_compression = list(filter(server_compress_algo_list.__contains__, self._preferred_compression))
if (len(agreed_local_compression) == 0) or (len(agreed_remote_compression) == 0):
raise SSHException('Incompatible ssh server (no acceptable compression) %r %r %r' % (agreed_local_compression, agreed_remote_compression, self._preferred_compression))
self.local_compression = agreed_local_compression[0]
@@ -1699,7 +1699,7 @@ class Transport (threading.Thread):
# actually some extra bytes (one NUL byte in openssh's case) added to
# the end of the packet but not parsed. turns out we need to throw
# away those bytes because they aren't part of the hash.
- self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far()
+ self.remote_kex_init = cMSG_KEXINIT + m.get_so_far()
def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic"
@@ -1728,7 +1728,7 @@ class Transport (threading.Thread):
def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic"
m = Message()
- m.add_byte(chr(MSG_NEWKEYS))
+ m.add_byte(MSG_NEWKEYS)
self._send_message(m)
block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode:
@@ -1797,24 +1797,24 @@ class Transport (threading.Thread):
def _parse_disconnect(self, m):
code = m.get_int()
- desc = m.get_string()
+ desc = m.get_text()
self._log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def _parse_global_request(self, m):
- kind = m.get_string()
+ kind = m.get_text()
self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean()
if not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
ok = False
elif kind == 'tcpip-forward':
- address = m.get_string()
+ address = m.get_text()
port = m.get_int()
ok = self.server_object.check_port_forward_request(address, port)
if ok != False:
ok = (ok,)
elif kind == 'cancel-tcpip-forward':
- address = m.get_string()
+ address = m.get_test()
port = m.get_int()
self.server_object.cancel_port_forward_request(address, port)
ok = True
@@ -1827,10 +1827,10 @@ class Transport (threading.Thread):
if want_reply:
msg = Message()
if ok:
- msg.add_byte(chr(MSG_REQUEST_SUCCESS))
+ msg.add_byte(cMSG_REQUEST_SUCCESS)
msg.add(*extra)
else:
- msg.add_byte(chr(MSG_REQUEST_FAILURE))
+ msg.add_byte(cMSG_REQUEST_FAILURE)
self._send_message(msg)
def _parse_request_success(self, m):
@@ -1868,8 +1868,8 @@ class Transport (threading.Thread):
def _parse_channel_open_failure(self, m):
chanid = m.get_int()
reason = m.get_int()
- reason_str = m.get_string()
- lang = m.get_string()
+ reason_str = m.get_text()
+ lang = m.get_text()
reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self.lock.acquire()
@@ -1885,7 +1885,7 @@ class Transport (threading.Thread):
return
def _parse_channel_open(self, m):
- kind = m.get_string()
+ kind = m.get_text()
chanid = m.get_int()
initial_window_size = m.get_int()
max_packet_size = m.get_int()
@@ -1898,7 +1898,7 @@ class Transport (threading.Thread):
finally:
self.lock.release()
elif (kind == 'x11') and (self._x11_handler is not None):
- origin_addr = m.get_string()
+ origin_addr = m.get_text()
origin_port = m.get_int()
self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire()
@@ -1907,9 +1907,9 @@ class Transport (threading.Thread):
finally:
self.lock.release()
elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
- server_addr = m.get_string()
+ server_addr = m.get_text()
server_port = m.get_int()
- origin_addr = m.get_string()
+ origin_addr = m.get_text()
origin_port = m.get_int()
self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire()
@@ -1929,9 +1929,9 @@ class Transport (threading.Thread):
self.lock.release()
if kind == 'direct-tcpip':
# handle direct-tcpip requests comming from the client
- dest_addr = m.get_string()
+ dest_addr = m.get_text()
dest_port = m.get_int()
- origin_addr = m.get_string()
+ origin_addr = m.get_text()
origin_port = m.get_int()
reason = self.server_object.check_channel_direct_tcpip_request(
my_chanid, (origin_addr, origin_port),
@@ -1943,7 +1943,7 @@ class Transport (threading.Thread):
reject = True
if reject:
msg = Message()
- msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
+ msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
msg.add_int(chanid)
msg.add_int(reason)
msg.add_string('')
@@ -1962,7 +1962,7 @@ class Transport (threading.Thread):
finally:
self.lock.release()
m = Message()
- m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS))
+ m.add_byte(cMSG_CHANNEL_OPEN_SUCCESS)
m.add_int(chanid)
m.add_int(my_chanid)
m.add_int(self.window_size)
@@ -2029,7 +2029,8 @@ class SecurityOptions (object):
``ValueError`` will be raised. If you try to assign something besides a
tuple to one of the fields, ``TypeError`` will be raised.
"""
- __slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
+ #__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
+ __slots__ = '_transport'
def __init__(self, transport):
self._transport = transport
@@ -2060,8 +2061,8 @@ class SecurityOptions (object):
x = tuple(x)
if type(x) is not tuple:
raise TypeError('expected tuple or list')
- possible = getattr(self._transport, orig).keys()
- forbidden = filter(lambda n: n not in possible, x)
+ possible = list(getattr(self._transport, orig).keys())
+ forbidden = [n for n in x if n not in possible]
if len(forbidden) > 0:
raise ValueError('unknown cipher')
setattr(self._transport, name, x)
@@ -2125,7 +2126,7 @@ class ChannelMap (object):
def values(self):
self._lock.acquire()
try:
- return self._map.values()
+ return list(self._map.values())
finally:
self._lock.release()
diff --git a/paramiko/util.py b/paramiko/util.py
index e0ef3b7c..c4b87e3b 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -48,60 +48,53 @@ if sys.version_info < (2,3):
def inflate_long(s, always_positive=False):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
- out = 0L
+ out = long(0)
negative = 0
- if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80):
+ if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
negative = 1
if len(s) % 4:
- filler = '\x00'
+ filler = zero_byte
if negative:
- filler = '\xff'
+ filler = max_byte
s = filler * (4 - len(s) % 4) + s
for i in range(0, len(s), 4):
out = (out << 32) + struct.unpack('>I', s[i:i+4])[0]
if negative:
- out -= (1L << (8 * len(s)))
+ out -= (long(1) << (8 * len(s)))
return out
+deflate_zero = zero_byte if PY2 else 0
+deflate_ff = max_byte if PY2 else 0xff
+
def deflate_long(n, add_sign_padding=True):
"turns a long-int into a normalized byte string (adapted from Crypto.Util.number)"
# after much testing, this algorithm was deemed to be the fastest
- s = ''
+ s = bytes()
n = long(n)
while (n != 0) and (n != -1):
- s = struct.pack('>I', n & 0xffffffffL) + s
+ s = struct.pack('>I', n & xffffffff) + s
n = n >> 32
# strip off leading zeros, FFs
for i in enumerate(s):
- if (n == 0) and (i[1] != '\000'):
+ if (n == 0) and (i[1] != deflate_zero):
break
- if (n == -1) and (i[1] != '\xff'):
+ if (n == -1) and (i[1] != deflate_ff):
break
else:
# degenerate case, n was either 0 or -1
i = (0,)
if n == 0:
- s = '\000'
+ s = zero_byte
else:
- s = '\xff'
+ s = max_byte
s = s[i[0]:]
if add_sign_padding:
- if (n == 0) and (ord(s[0]) >= 0x80):
- s = '\x00' + s
- if (n == -1) and (ord(s[0]) < 0x80):
- s = '\xff' + s
+ if (n == 0) and (byte_ord(s[0]) >= 0x80):
+ s = zero_byte + s
+ if (n == -1) and (byte_ord(s[0]) < 0x80):
+ s = max_byte + s
return s
-def format_binary_weird(data):
- out = ''
- for i in enumerate(data):
- out += '%02X' % ord(i[1])
- if i[0] % 2:
- out += ' '
- if i[0] % 16 == 15:
- out += '\n'
- return out
-
def format_binary(data, prefix=''):
x = 0
out = []
@@ -113,8 +106,8 @@ def format_binary(data, prefix=''):
return [prefix + x for x in out]
def format_binary_line(data):
- left = ' '.join(['%02X' % ord(c) for c in data])
- right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data])
+ left = ' '.join(['%02X' % byte_ord(c) for c in data])
+ right = ''.join([('.%c..' % c)[(byte_ord(c)+63)//95] for c in data])
return '%-50s %s' % (left, right)
def hexify(s):
@@ -126,17 +119,20 @@ def unhexify(s):
def safe_string(s):
out = ''
for c in s:
- if (ord(c) >= 32) and (ord(c) <= 127):
+ if (byte_ord(c) >= 32) and (byte_ord(c) <= 127):
out += c
else:
- out += '%%%02X' % ord(c)
+ out += '%%%02X' % byte_ord(c)
return out
# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s])
def bit_length(n):
+ try:
+ return n.bitlength()
+ except AttributeError:
norm = deflate_long(n, 0)
- hbyte = ord(norm[0])
+ hbyte = byte_ord(norm[0])
if hbyte == 0:
return 1
bitlen = len(norm) * 8
@@ -157,20 +153,21 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
:param class hashclass:
class from `Crypto.Hash` that can be used as a secure hashing function
(like ``MD5`` or ``SHA``).
- :param str salt: data to salt the hash with.
+ :param salt: data to salt the hash with.
+ :type salt: byte string
:param str key: human-entered password or passphrase.
:param int nbytes: number of bytes to generate.
:return: Key data `str`
"""
- keydata = ''
- digest = ''
+ keydata = bytes()
+ digest = bytes()
if len(salt) > 8:
salt = salt[:8]
while nbytes > 0:
hash_obj = hashclass.new()
if len(digest) > 0:
hash_obj.update(digest)
- hash_obj.update(key)
+ hash_obj.update(b(key))
hash_obj.update(salt)
digest = hash_obj.digest()
size = min(nbytes, len(digest))
@@ -271,37 +268,37 @@ def retry_on_signal(function):
while True:
try:
return function()
- except EnvironmentError, e:
+ except EnvironmentError as e:
if e.errno != errno.EINTR:
raise
class Counter (object):
"""Stateful counter for CTR mode crypto"""
- def __init__(self, nbits, initial_value=1L, overflow=0L):
+ def __init__(self, nbits, initial_value=long(1), overflow=long(0)):
self.blocksize = nbits / 8
self.overflow = overflow
# start with value - 1 so we don't have to store intermediate values when counting
# could the iv be 0?
if initial_value == 0:
- self.value = array.array('c', '\xFF' * self.blocksize)
+ self.value = array.array('c', max_byte * self.blocksize)
else:
x = deflate_long(initial_value - 1, add_sign_padding=False)
- self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x)
+ self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
def __call__(self):
"""Increament the counter and return the new value"""
i = self.blocksize - 1
while i > -1:
- c = self.value[i] = chr((ord(self.value[i]) + 1) % 256)
- if c != '\x00':
+ c = self.value[i] = byte_chr((byte_ord(self.value[i]) + 1) % 256)
+ if c != zero_byte:
return self.value.tostring()
i -= 1
# counter reset
x = deflate_long(self.overflow, add_sign_padding=False)
- self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x)
+ self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
return self.value.tostring()
- def new(cls, nbits, initial_value=1L, overflow=0L):
+ def new(cls, nbits, initial_value=long(1), overflow=long(0)):
return cls(nbits, initial_value=initial_value, overflow=overflow)
new = classmethod(new)
diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py
index d588e81d..d815a322 100644
--- a/paramiko/win_pageant.py
+++ b/paramiko/win_pageant.py
@@ -27,6 +27,7 @@ import array
import ctypes.wintypes
import platform
import struct
+from paramiko.util import *
try:
import _thread as thread # Python 3.x
@@ -91,7 +92,7 @@ def _query_pageant(msg):
with pymap:
pymap.write(msg)
# Create an array buffer containing the mapped filename
- char_buffer = array.array("c", map_name + '\0')
+ char_buffer = array.array("c", b(map_name) + zero_byte)
char_buffer_address, char_buffer_size = char_buffer.buffer_info()
# Create a string to use for the SendMessage function call
cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size,
diff --git a/setup.py b/setup.py
index 7d6706ed..be404b1b 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@ if sys.platform == 'darwin':
setup(name = "paramiko",
- version = "1.12.2",
+ version = "1.13.0",
description = "SSH2 protocol library",
author = "Jeff Forcier",
author_email = "jeff@bitprophet.org",
diff --git a/test.py b/test.py
index 6702e53a..bd966d1e 100755
--- a/test.py
+++ b/test.py
@@ -29,22 +29,21 @@ import unittest
from optparse import OptionParser
import paramiko
import threading
+from paramiko.py3compat import PY2
sys.path.append('tests')
-from test_message import MessageTest
-from test_file import BufferedFileTest
-from test_buffered_pipe import BufferedPipeTest
-from test_util import UtilTest
-from test_hostkeys import HostKeysTest
-from test_pkey import KeyTest
-from test_kex import KexTest
-from test_packetizer import PacketizerTest
-from test_auth import AuthTest
-from test_transport import TransportTest
-from test_sftp import SFTPTest
-from test_sftp_big import BigSFTPTest
-from test_client import SSHClientTest
+from tests.test_message import MessageTest
+from tests.test_file import BufferedFileTest
+from tests.test_buffered_pipe import BufferedPipeTest
+from tests.test_util import UtilTest
+from tests.test_hostkeys import HostKeysTest
+from tests.test_pkey import KeyTest
+from tests.test_kex import KexTest
+from tests.test_packetizer import PacketizerTest
+from tests.test_auth import AuthTest
+from tests.test_transport import TransportTest
+from tests.test_client import SSHClientTest
default_host = 'localhost'
default_user = os.environ.get('USER', 'nobody')
@@ -109,13 +108,16 @@ def main():
paramiko.util.log_to_file('test.log')
if options.use_sftp:
+ from tests.test_sftp import SFTPTest
if options.use_loopback_sftp:
SFTPTest.init_loopback()
else:
SFTPTest.init(options.hostname, options.username, options.keyfile, options.password)
if not options.use_big_file:
SFTPTest.set_big_file_test(False)
-
+ if options.use_big_file:
+ from tests.test_sftp_big import BigSFTPTest
+
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MessageTest))
suite.addTest(unittest.makeSuite(BufferedFileTest))
@@ -147,7 +149,10 @@ def main():
# TODO: make that not a problem, jeez
for thread in threading.enumerate():
if thread is not threading.currentThread():
- thread._Thread__stop()
+ if PY2:
+ thread._Thread__stop()
+ else:
+ thread._stop()
# Exit correctly
if not result.wasSuccessful():
sys.exit(1)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/loop.py b/tests/loop.py
index 91c216d2..6e933f83 100644
--- a/tests/loop.py
+++ b/tests/loop.py
@@ -21,6 +21,7 @@
"""
import threading, socket
+from paramiko.common import *
class LoopSocket (object):
@@ -31,7 +32,7 @@ class LoopSocket (object):
"""
def __init__(self):
- self.__in_buffer = ''
+ self.__in_buffer = bytes()
self.__lock = threading.Lock()
self.__cv = threading.Condition(self.__lock)
self.__timeout = None
@@ -41,11 +42,12 @@ class LoopSocket (object):
self.__unlink()
try:
self.__lock.acquire()
- self.__in_buffer = ''
+ self.__in_buffer = bytes()
finally:
self.__lock.release()
def send(self, data):
+ data = asbytes(data)
if self.__mate is None:
# EOF
raise EOFError()
@@ -57,7 +59,7 @@ class LoopSocket (object):
try:
if self.__mate is None:
# EOF
- return ''
+ return bytes()
if len(self.__in_buffer) == 0:
self.__cv.wait(self.__timeout)
if len(self.__in_buffer) == 0:
diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py
index 3021d816..58e4be26 100644
--- a/tests/stub_sftp.py
+++ b/tests/stub_sftp.py
@@ -21,8 +21,10 @@ A stub SFTP server for loopback SFTP testing.
"""
import os
+import sys
from paramiko import ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, \
SFTPHandle, SFTP_OK, AUTH_SUCCESSFUL, OPEN_SUCCEEDED
+from paramiko.common import *
class StubServer (ServerInterface):
@@ -38,7 +40,7 @@ class StubSFTPHandle (SFTPHandle):
def stat(self):
try:
return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno()))
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
def chattr(self, attr):
@@ -47,7 +49,7 @@ class StubSFTPHandle (SFTPHandle):
try:
SFTPServer.set_file_attr(self.filename, attr)
return SFTP_OK
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
@@ -69,21 +71,21 @@ class StubSFTPServer (SFTPServerInterface):
attr.filename = fname
out.append(attr)
return out
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
def stat(self, path):
path = self._realpath(path)
try:
return SFTPAttributes.from_stat(os.stat(path))
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
def lstat(self, path):
path = self._realpath(path)
try:
return SFTPAttributes.from_stat(os.lstat(path))
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
def open(self, path, flags, attr):
@@ -97,8 +99,8 @@ class StubSFTPServer (SFTPServerInterface):
else:
# os.open() defaults to 0777 which is
# an odd default mode for files
- fd = os.open(path, flags, 0666)
- except OSError, e:
+ fd = os.open(path, flags, o666)
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
if (flags & os.O_CREAT) and (attr is not None):
attr._flags &= ~attr.FLAG_PERMISSIONS
@@ -118,7 +120,7 @@ class StubSFTPServer (SFTPServerInterface):
fstr = 'rb'
try:
f = os.fdopen(fd, fstr)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
fobj = StubSFTPHandle(flags)
fobj.filename = path
@@ -130,7 +132,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
os.remove(path)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@@ -139,7 +141,7 @@ class StubSFTPServer (SFTPServerInterface):
newpath = self._realpath(newpath)
try:
os.rename(oldpath, newpath)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@@ -149,7 +151,7 @@ class StubSFTPServer (SFTPServerInterface):
os.mkdir(path)
if attr is not None:
SFTPServer.set_file_attr(path, attr)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@@ -157,7 +159,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
os.rmdir(path)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@@ -165,7 +167,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
SFTPServer.set_file_attr(path, attr)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@@ -185,7 +187,7 @@ class StubSFTPServer (SFTPServerInterface):
target_path = '<error>'
try:
os.symlink(target_path, path)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@@ -193,7 +195,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
symlink = os.readlink(path)
- except OSError, e:
+ except OSError as e:
return SFTPServer.convert_errno(e.errno)
# if it's absolute, remove the root
if os.path.isabs(symlink):
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 61fe63f4..d26b1807 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -29,14 +29,18 @@ from paramiko import Transport, ServerInterface, RSAKey, DSSKey, \
AuthenticationException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
-from loop import LoopSocket
+from paramiko.py3compat import u
+from tests.loop import LoopSocket
+from tests.util import test_path
+
+_pwd = u('\u2022')
class NullServer (ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key')
-
+ paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
+
def get_allowed_auths(self, username):
if username == 'slowdive':
return 'publickey,password'
@@ -64,7 +68,7 @@ class NullServer (ServerInterface):
if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
- if (username == 'utf8') and (password == u'\u2022'):
+ if (username == 'utf8') and (password == _pwd):
return AUTH_SUCCESSFUL
if (username == 'non-utf8') and (password == '\xff'):
return AUTH_SUCCESSFUL
@@ -110,18 +114,18 @@ class AuthTest (unittest.TestCase):
self.sockc.close()
def start_server(self):
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- self.public_host_key = RSAKey(data=str(host_key))
+ host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
self.event = threading.Event()
self.server = NullServer()
- self.assert_(not self.event.isSet())
+ self.assertTrue(not self.event.isSet())
self.ts.start_server(self.event, self.server)
def verify_finished(self):
self.event.wait(1.0)
- self.assert_(self.event.isSet())
- self.assert_(self.ts.is_active())
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
def test_1_bad_auth_type(self):
"""
@@ -132,11 +136,11 @@ class AuthTest (unittest.TestCase):
try:
self.tc.connect(hostkey=self.public_host_key,
username='unknown', password='error')
- self.assert_(False)
+ self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
- self.assertEquals(BadAuthenticationType, etype)
- self.assertEquals(['publickey'], evalue.allowed_types)
+ self.assertEqual(BadAuthenticationType, etype)
+ self.assertEqual(['publickey'], evalue.allowed_types)
def test_2_bad_password(self):
"""
@@ -147,10 +151,10 @@ class AuthTest (unittest.TestCase):
self.tc.connect(hostkey=self.public_host_key)
try:
self.tc.auth_password(username='slowdive', password='error')
- self.assert_(False)
+ self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
- self.assert_(issubclass(etype, AuthenticationException))
+ self.assertTrue(issubclass(etype, AuthenticationException))
self.tc.auth_password(username='slowdive', password='pygmalion')
self.verify_finished()
@@ -161,10 +165,10 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password(username='paranoid', password='paranoid')
- self.assertEquals(['publickey'], remain)
- key = DSSKey.from_private_key_file('tests/test_dss.key')
+ self.assertEqual(['publickey'], remain)
+ key = DSSKey.from_private_key_file(test_path('test_dss.key'))
remain = self.tc.auth_publickey(username='paranoid', key=key)
- self.assertEquals([], remain)
+ self.assertEqual([], remain)
self.verify_finished()
def test_4_interactive_auth(self):
@@ -180,9 +184,9 @@ class AuthTest (unittest.TestCase):
self.got_prompts = prompts
return ['cat']
remain = self.tc.auth_interactive('commie', handler)
- self.assertEquals(self.got_title, 'password')
- self.assertEquals(self.got_prompts, [('Password', False)])
- self.assertEquals([], remain)
+ self.assertEqual(self.got_title, 'password')
+ self.assertEqual(self.got_prompts, [('Password', False)])
+ self.assertEqual([], remain)
self.verify_finished()
def test_5_interactive_auth_fallback(self):
@@ -193,7 +197,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('commie', 'cat')
- self.assertEquals([], remain)
+ self.assertEqual([], remain)
self.verify_finished()
def test_6_auth_utf8(self):
@@ -202,8 +206,8 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password('utf8', u'\u2022')
- self.assertEquals([], remain)
+ remain = self.tc.auth_password('utf8', _pwd)
+ self.assertEqual([], remain)
self.verify_finished()
def test_7_auth_non_utf8(self):
@@ -214,7 +218,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('non-utf8', '\xff')
- self.assertEquals([], remain)
+ self.assertEqual([], remain)
self.verify_finished()
def test_8_auth_gets_disconnected(self):
@@ -228,4 +232,4 @@ class AuthTest (unittest.TestCase):
remain = self.tc.auth_password('bad-server', 'hello')
except:
etype, evalue, etb = sys.exc_info()
- self.assert_(issubclass(etype, AuthenticationException))
+ self.assertTrue(issubclass(etype, AuthenticationException))
diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py
index 47ece936..b9d2bef4 100644
--- a/tests/test_buffered_pipe.py
+++ b/tests/test_buffered_pipe.py
@@ -25,8 +25,9 @@ import time
import unittest
from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe
+from paramiko.py3compat import b
-from util import ParamikoTest
+from tests.util import ParamikoTest
def delay_thread(pipe):
@@ -44,39 +45,39 @@ def close_thread(pipe):
class BufferedPipeTest(ParamikoTest):
def test_1_buffered_pipe(self):
p = BufferedPipe()
- self.assert_(not p.read_ready())
+ self.assertTrue(not p.read_ready())
p.feed('hello.')
- self.assert_(p.read_ready())
+ self.assertTrue(p.read_ready())
data = p.read(6)
- self.assertEquals('hello.', data)
+ self.assertEqual(b'hello.', data)
p.feed('plus/minus')
- self.assertEquals('plu', p.read(3))
- self.assertEquals('s/m', p.read(3))
- self.assertEquals('inus', p.read(4))
+ self.assertEqual(b'plu', p.read(3))
+ self.assertEqual(b's/m', p.read(3))
+ self.assertEqual(b'inus', p.read(4))
p.close()
- self.assert_(not p.read_ready())
- self.assertEquals('', p.read(1))
+ self.assertTrue(not p.read_ready())
+ self.assertEqual(b'', p.read(1))
def test_2_delay(self):
p = BufferedPipe()
- self.assert_(not p.read_ready())
+ self.assertTrue(not p.read_ready())
threading.Thread(target=delay_thread, args=(p,)).start()
- self.assertEquals('a', p.read(1, 0.1))
+ self.assertEqual(b'a', p.read(1, 0.1))
try:
p.read(1, 0.1)
- self.assert_(False)
+ self.assertTrue(False)
except PipeTimeout:
pass
- self.assertEquals('b', p.read(1, 1.0))
- self.assertEquals('', p.read(1))
+ self.assertEqual(b'b', p.read(1, 1.0))
+ self.assertEqual(b'', p.read(1))
def test_3_close_while_reading(self):
p = BufferedPipe()
threading.Thread(target=close_thread, args=(p,)).start()
data = p.read(1, 1.0)
- self.assertEquals('', data)
+ self.assertEqual(b'', data)
def test_4_or_pipe(self):
p = pipe.make_pipe()
diff --git a/tests/test_client.py b/tests/test_client.py
index fae1d329..97150979 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -20,16 +20,14 @@
Some unit tests for SSHClient.
"""
-from __future__ import with_statement # Python 2.5 support
import socket
+from tempfile import mkstemp
import threading
-import time
import unittest
import weakref
import warnings
import os
-from binascii import hexlify
-
+from tests.util import test_path
import paramiko
@@ -46,7 +44,7 @@ class NullServer (paramiko.ServerInterface):
return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key):
- if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'):
+ if (key.get_name() == 'ssh-dss') and key.get_fingerprint() == b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c':
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@@ -67,8 +65,6 @@ class SSHClientTest (unittest.TestCase):
self.sockl.listen(1)
self.addr, self.port = self.sockl.getsockname()
self.event = threading.Event()
- thread = threading.Thread(target=self._run)
- thread.start()
def tearDown(self):
for attr in "tc ts socks sockl".split():
@@ -78,28 +74,28 @@ class SSHClientTest (unittest.TestCase):
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks)
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.ts.add_server_key(host_key)
server = NullServer()
self.ts.start_server(self.event, server)
-
def test_1_client(self):
"""
verify that the SSHClient stuff works too.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = paramiko.RSAKey(data=str(host_key))
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
- self.assert_(self.event.isSet())
- self.assert_(self.ts.is_active())
- self.assertEquals('slowdive', self.ts.get_username())
- self.assertEquals(True, self.ts.is_authenticated())
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+ self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0)
@@ -108,10 +104,10 @@ class SSHClientTest (unittest.TestCase):
schan.send_stderr('This is on stderr.\n')
schan.close()
- self.assertEquals('Hello there.\n', stdout.readline())
- self.assertEquals('', stdout.readline())
- self.assertEquals('This is on stderr.\n', stderr.readline())
- self.assertEquals('', stderr.readline())
+ self.assertEqual('Hello there.\n', stdout.readline())
+ self.assertEqual('', stdout.readline())
+ self.assertEqual('This is on stderr.\n', stderr.readline())
+ self.assertEqual('', stderr.readline())
stdin.close()
stdout.close()
@@ -121,18 +117,19 @@ class SSHClientTest (unittest.TestCase):
"""
verify that SSHClient works with a DSA key.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = paramiko.RSAKey(data=str(host_key))
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
- self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key')
+ self.tc.connect(self.addr, self.port, username='slowdive', key_filename=test_path('test_dss.key'))
self.event.wait(1.0)
- self.assert_(self.event.isSet())
- self.assert_(self.ts.is_active())
- self.assertEquals('slowdive', self.ts.get_username())
- self.assertEquals(True, self.ts.is_authenticated())
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+ self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0)
@@ -141,10 +138,10 @@ class SSHClientTest (unittest.TestCase):
schan.send_stderr('This is on stderr.\n')
schan.close()
- self.assertEquals('Hello there.\n', stdout.readline())
- self.assertEquals('', stdout.readline())
- self.assertEquals('This is on stderr.\n', stderr.readline())
- self.assertEquals('', stderr.readline())
+ self.assertEqual('Hello there.\n', stdout.readline())
+ self.assertEqual('', stdout.readline())
+ self.assertEqual('This is on stderr.\n', stderr.readline())
+ self.assertEqual('', stderr.readline())
stdin.close()
stdout.close()
@@ -154,38 +151,40 @@ class SSHClientTest (unittest.TestCase):
"""
verify that SSHClient accepts and tries multiple key files.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = paramiko.RSAKey(data=str(host_key))
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
- self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
+ self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ test_path('test_rsa.key'), test_path('test_dss.key') ])
self.event.wait(1.0)
- self.assert_(self.event.isSet())
- self.assert_(self.ts.is_active())
- self.assertEquals('slowdive', self.ts.get_username())
- self.assertEquals(True, self.ts.is_authenticated())
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+ self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual(True, self.ts.is_authenticated())
def test_4_auto_add_policy(self):
"""
verify that SSHClient's AutoAddPolicy works.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = paramiko.RSAKey(data=str(host_key))
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- self.assertEquals(0, len(self.tc.get_host_keys()))
+ self.assertEqual(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
- self.assert_(self.event.isSet())
- self.assert_(self.ts.is_active())
- self.assertEquals('slowdive', self.ts.get_username())
- self.assertEquals(True, self.ts.is_authenticated())
- self.assertEquals(1, len(self.tc.get_host_keys()))
- self.assertEquals(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
+ self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual(True, self.ts.is_authenticated())
+ self.assertEqual(1, len(self.tc.get_host_keys()))
+ self.assertEqual(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
def test_5_save_host_keys(self):
"""
@@ -193,9 +192,10 @@ class SSHClientTest (unittest.TestCase):
"""
warnings.filterwarnings('ignore', 'tempnam.*')
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = paramiko.RSAKey(data=str(host_key))
- localname = os.tempnam()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
+ fd, localname = mkstemp()
+ os.close(fd)
client = paramiko.SSHClient()
self.assertEquals(0, len(client.get_host_keys()))
@@ -218,24 +218,32 @@ class SSHClientTest (unittest.TestCase):
verify that when an SSHClient is collected, its transport (and the
transport's packetizer) is closed.
"""
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = paramiko.RSAKey(data=str(host_key))
+ threading.Thread(target=self._run).start()
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- self.assertEquals(0, len(self.tc.get_host_keys()))
+ self.assertEqual(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
- self.assert_(self.event.isSet())
- self.assert_(self.ts.is_active())
+ self.assertTrue(self.event.isSet())
+ self.assertTrue(self.ts.is_active())
p = weakref.ref(self.tc._transport.packetizer)
- self.assert_(p() is not None)
+ self.assertTrue(p() is not None)
+ self.tc.close()
del self.tc
+
# hrm, sometimes p isn't cleared right away. why is that?
- st = time.time()
- while (time.time() - st < 5.0) and (p() is not None):
- time.sleep(0.1)
- self.assert_(p() is None)
+ #st = time.time()
+ #while (time.time() - st < 5.0) and (p() is not None):
+ # time.sleep(0.1)
+
+ # instead of dumbly waiting for the GC to collect, force a collection
+ # to see whether the SSHClient object is deallocated correctly
+ import gc
+ gc.collect()
+ self.assertTrue(p() is None)
diff --git a/tests/test_file.py b/tests/test_file.py
index 6cb35070..33a49130 100755
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -22,6 +22,7 @@ Some unit tests for the BufferedFile abstraction.
import unittest
from paramiko.file import BufferedFile
+from paramiko.common import *
class LoopbackFile (BufferedFile):
@@ -31,7 +32,7 @@ class LoopbackFile (BufferedFile):
def __init__(self, mode='r', bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
- self.buffer = ''
+ self.buffer = bytes()
def _read(self, size):
if len(self.buffer) == 0:
@@ -53,7 +54,7 @@ class BufferedFileTest (unittest.TestCase):
f = LoopbackFile('r')
try:
f.write('hi')
- self.assert_(False, 'no exception on write to read-only file')
+ self.assertTrue(False, 'no exception on write to read-only file')
except:
pass
f.close()
@@ -61,7 +62,7 @@ class BufferedFileTest (unittest.TestCase):
f = LoopbackFile('w')
try:
f.read(1)
- self.assert_(False, 'no exception to read from write-only file')
+ self.assertTrue(False, 'no exception to read from write-only file')
except:
pass
f.close()
@@ -80,12 +81,12 @@ class BufferedFileTest (unittest.TestCase):
f.close()
try:
f.readline()
- self.assert_(False, 'no exception on readline of closed file')
+ self.assertTrue(False, 'no exception on readline of closed file')
except IOError:
pass
- self.assert_('\n' in f.newlines)
- self.assert_('\r\n' in f.newlines)
- self.assert_('\r' not in f.newlines)
+ self.assertTrue(linefeed_byte in f.newlines)
+ self.assertTrue(crlf in f.newlines)
+ self.assertTrue(cr_byte not in f.newlines)
def test_3_lf(self):
"""
@@ -97,7 +98,7 @@ class BufferedFileTest (unittest.TestCase):
f.write('\nSecond.\r\n')
self.assertEqual(f.readline(), 'Second.\n')
f.close()
- self.assertEqual(f.newlines, '\r\n')
+ self.assertEqual(f.newlines, crlf)
def test_4_write(self):
"""
diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py
index 44070cbe..9a7e3689 100644
--- a/tests/test_hostkeys.py
+++ b/tests/test_hostkeys.py
@@ -25,6 +25,7 @@ from binascii import hexlify
import os
import unittest
import paramiko
+from paramiko.py3compat import b, decodebytes
test_hosts_file = """\
@@ -36,12 +37,12 @@ BGQ3GQ/Fc7SX6gkpXkwcZryoi4kNFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW\
5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=
"""
-keyblob = """\
+keyblob = b"""\
AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\
NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\
oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M="""
-keyblob_dss = """\
+keyblob_dss = b"""\
AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\
h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\
8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\
@@ -55,51 +56,50 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\
class HostKeysTest (unittest.TestCase):
def setUp(self):
- f = open('hostfile.temp', 'w')
- f.write(test_hosts_file)
- f.close()
+ with open('hostfile.temp', 'w') as f:
+ f.write(test_hosts_file)
def tearDown(self):
os.unlink('hostfile.temp')
def test_1_load(self):
hostdict = paramiko.HostKeys('hostfile.temp')
- self.assertEquals(2, len(hostdict))
- self.assertEquals(1, len(hostdict.values()[0]))
- self.assertEquals(1, len(hostdict.values()[1]))
+ self.assertEqual(2, len(hostdict))
+ self.assertEqual(1, len(list(hostdict.values())[0]))
+ self.assertEqual(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
def test_2_add(self):
hostdict = paramiko.HostKeys('hostfile.temp')
hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c='
- key = paramiko.RSAKey(data=base64.decodestring(keyblob))
+ key = paramiko.RSAKey(data=decodebytes(keyblob))
hostdict.add(hh, 'ssh-rsa', key)
- self.assertEquals(3, len(hostdict))
+ self.assertEqual(3, len(list(hostdict)))
x = hostdict['foo.example.com']
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
- self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
- self.assert_(hostdict.check('foo.example.com', key))
+ self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
+ self.assertTrue(hostdict.check('foo.example.com', key))
def test_3_dict(self):
hostdict = paramiko.HostKeys('hostfile.temp')
- self.assert_('secure.example.com' in hostdict)
- self.assert_('not.example.com' not in hostdict)
- self.assert_(hostdict.has_key('secure.example.com'))
- self.assert_(not hostdict.has_key('not.example.com'))
+ self.assertTrue('secure.example.com' in hostdict)
+ self.assertTrue('not.example.com' not in hostdict)
+ self.assertTrue('secure.example.com' in hostdict)
+ self.assertTrue('not.example.com' not in hostdict)
x = hostdict.get('secure.example.com', None)
- self.assert_(x is not None)
+ self.assertTrue(x is not None)
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
- self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
i = 0
for key in hostdict:
i += 1
- self.assertEquals(2, i)
+ self.assertEqual(2, i)
def test_4_dict_set(self):
hostdict = paramiko.HostKeys('hostfile.temp')
- key = paramiko.RSAKey(data=base64.decodestring(keyblob))
- key_dss = paramiko.DSSKey(data=base64.decodestring(keyblob_dss))
+ key = paramiko.RSAKey(data=decodebytes(keyblob))
+ key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss))
hostdict['secure.example.com'] = {
'ssh-rsa': key,
'ssh-dss': key_dss
@@ -107,11 +107,11 @@ class HostKeysTest (unittest.TestCase):
hostdict['fake.example.com'] = {}
hostdict['fake.example.com']['ssh-rsa'] = key
- self.assertEquals(3, len(hostdict))
- self.assertEquals(2, len(hostdict.values()[0]))
- self.assertEquals(1, len(hostdict.values()[1]))
- self.assertEquals(1, len(hostdict.values()[2]))
+ self.assertEqual(3, len(hostdict))
+ self.assertEqual(2, len(list(hostdict.values())[0]))
+ self.assertEqual(1, len(list(hostdict.values())[1]))
+ self.assertEqual(1, len(list(hostdict.values())[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
+ self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()
- self.assertEquals('4478F0B9A23CC5182009FF755BC1D26C', fp)
+ self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp)
diff --git a/tests/test_kex.py b/tests/test_kex.py
index 39d2e17e..e69c051b 100644
--- a/tests/test_kex.py
+++ b/tests/test_kex.py
@@ -26,22 +26,25 @@ import paramiko.util
from paramiko.kex_group1 import KexGroup1
from paramiko.kex_gex import KexGex
from paramiko import Message
+from paramiko.common import *
class FakeRng (object):
def read(self, n):
- return chr(0xcc) * n
+ return byte_chr(0xcc) * n
class FakeKey (object):
def __str__(self):
return 'fake-key'
+ def asbytes(self):
+ return b'fake-key'
def sign_ssh_data(self, rng, H):
- return 'fake-sig'
+ return b'fake-sig'
class FakeModulusPack (object):
- P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL
+ P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2
def get_modulus(self, min, ask, max):
return self.G, self.P
@@ -75,7 +78,7 @@ class FakeTransport (object):
class KexTest (unittest.TestCase):
- K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504L
+ K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504
def setUp(self):
pass
@@ -88,9 +91,9 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGroup1(transport)
kex.start_kex()
- x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
+ x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
# fake "reply"
msg = Message()
@@ -99,47 +102,47 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
- H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
- self.assertEquals(self.K, transport._K)
- self.assertEquals(H, hexlify(transport._H).upper())
- self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify)
- self.assert_(transport._activated)
+ H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
+ self.assertEqual(self.K, transport._K)
+ self.assertEqual(H, hexlify(transport._H).upper())
+ self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertTrue(transport._activated)
def test_2_group1_server(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGroup1(transport)
kex.start_kex()
- self.assertEquals((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
+ self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
msg = Message()
msg.add_mpint(69)
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
- H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
- x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
- self.assertEquals(self.K, transport._K)
- self.assertEquals(H, hexlify(transport._H).upper())
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assert_(transport._activated)
+ H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
+ x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ self.assertEqual(self.K, transport._K)
+ self.assertEqual(H, hexlify(transport._H).upper())
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertTrue(transport._activated)
def test_3_gex_client(self):
transport = FakeTransport()
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex()
- x = '22000004000000080000002000'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ x = b'22000004000000080000002000'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message()
msg.add_string('fake-host-key')
@@ -147,29 +150,29 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
- self.assertEquals(self.K, transport._K)
- self.assertEquals(H, hexlify(transport._H).upper())
- self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify)
- self.assert_(transport._activated)
+ H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
+ self.assertEqual(self.K, transport._K)
+ self.assertEqual(H, hexlify(transport._H).upper())
+ self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertTrue(transport._activated)
def test_4_gex_old_client(self):
transport = FakeTransport()
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex(_test_old_style=True)
- x = '1E00000800'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
+ x = b'1E00000800'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
- x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
+ x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message()
msg.add_string('fake-host-key')
@@ -177,18 +180,18 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
- H = '807F87B269EF7AC5EC7E75676808776A27D5864C'
- self.assertEquals(self.K, transport._K)
- self.assertEquals(H, hexlify(transport._H).upper())
- self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify)
- self.assert_(transport._activated)
+ H = b'807F87B269EF7AC5EC7E75676808776A27D5864C'
+ self.assertEqual(self.K, transport._K)
+ self.assertEqual(H, hexlify(transport._H).upper())
+ self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
+ self.assertTrue(transport._activated)
def test_5_gex_server(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
msg = Message()
msg.add_int(1024)
@@ -196,45 +199,45 @@ class KexTest (unittest.TestCase):
msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
- x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
- K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L
- H = 'CE754197C21BF3452863B4F44D0B3951F12516EF'
- x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
- self.assertEquals(K, transport._K)
- self.assertEquals(H, hexlify(transport._H).upper())
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assert_(transport._activated)
+ K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
+ H = b'CE754197C21BF3452863B4F44D0B3951F12516EF'
+ x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ self.assertEqual(K, transport._K)
+ self.assertEqual(H, hexlify(transport._H).upper())
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertTrue(transport._activated)
def test_6_gex_server_with_old_client(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
msg = Message()
msg.add_int(2048)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
- x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
+ x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
- K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L
- H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
- x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
- self.assertEquals(K, transport._K)
- self.assertEquals(H, hexlify(transport._H).upper())
- self.assertEquals(x, hexlify(str(transport._message)).upper())
- self.assert_(transport._activated)
+ K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
+ H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
+ x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
+ self.assertEqual(K, transport._K)
+ self.assertEqual(H, hexlify(transport._H).upper())
+ self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
+ self.assertTrue(transport._activated)
diff --git a/tests/test_message.py b/tests/test_message.py
index ad622a27..4da52cfb 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -22,14 +22,15 @@ Some unit tests for ssh protocol message blocks.
import unittest
from paramiko.message import Message
+from paramiko.common import *
class MessageTest (unittest.TestCase):
- __a = '\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01q\x00\x00\x00\x05hello\x00\x00\x03\xe8' + ('x' * 1000)
- __b = '\x01\x00\xf3\x00\x3f\x00\x00\x00\x10huey,dewey,louie'
- __c = '\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
- __d = '\x00\x00\x00\x05\x00\x00\x00\x05\x11\x22\x33\x44\x55\x01\x00\x00\x00\x03cat\x00\x00\x00\x03a,b'
+ __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000
+ __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65'
+ __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
+ __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62'
def test_1_encode(self):
msg = Message()
@@ -38,63 +39,65 @@ class MessageTest (unittest.TestCase):
msg.add_string('q')
msg.add_string('hello')
msg.add_string('x' * 1000)
- self.assertEquals(str(msg), self.__a)
+ self.assertEqual(msg.asbytes(), self.__a)
msg = Message()
msg.add_boolean(True)
msg.add_boolean(False)
- msg.add_byte('\xf3')
- msg.add_bytes('\x00\x3f')
+ msg.add_byte(byte_chr(0xf3))
+
+ msg.add_bytes(zero_byte + byte_chr(0x3f))
msg.add_list(['huey', 'dewey', 'louie'])
- self.assertEquals(str(msg), self.__b)
+ self.assertEqual(msg.asbytes(), self.__b)
msg = Message()
msg.add_int64(5)
- msg.add_int64(0xf5e4d3c2b109L)
+ msg.add_int64(0xf5e4d3c2b109)
msg.add_mpint(17)
- msg.add_mpint(0xf5e4d3c2b109L)
- msg.add_mpint(-0x65e4d3c2b109L)
- self.assertEquals(str(msg), self.__c)
+ msg.add_mpint(0xf5e4d3c2b109)
+ msg.add_mpint(-0x65e4d3c2b109)
+ self.assertEqual(msg.asbytes(), self.__c)
def test_2_decode(self):
msg = Message(self.__a)
- self.assertEquals(msg.get_int(), 23)
- self.assertEquals(msg.get_int(), 123789456)
- self.assertEquals(msg.get_string(), 'q')
- self.assertEquals(msg.get_string(), 'hello')
- self.assertEquals(msg.get_string(), 'x' * 1000)
+ self.assertEqual(msg.get_int(), 23)
+ self.assertEqual(msg.get_int(), 123789456)
+ self.assertEqual(msg.get_text(), 'q')
+ self.assertEqual(msg.get_text(), 'hello')
+ self.assertEqual(msg.get_text(), 'x' * 1000)
msg = Message(self.__b)
- self.assertEquals(msg.get_boolean(), True)
- self.assertEquals(msg.get_boolean(), False)
- self.assertEquals(msg.get_byte(), '\xf3')
- self.assertEquals(msg.get_bytes(2), '\x00\x3f')
- self.assertEquals(msg.get_list(), ['huey', 'dewey', 'louie'])
+ self.assertEqual(msg.get_boolean(), True)
+ self.assertEqual(msg.get_boolean(), False)
+ self.assertEqual(msg.get_byte(), byte_chr(0xf3))
+ self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f))
+ self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie'])
msg = Message(self.__c)
- self.assertEquals(msg.get_int64(), 5)
- self.assertEquals(msg.get_int64(), 0xf5e4d3c2b109L)
- self.assertEquals(msg.get_mpint(), 17)
- self.assertEquals(msg.get_mpint(), 0xf5e4d3c2b109L)
- self.assertEquals(msg.get_mpint(), -0x65e4d3c2b109L)
+ self.assertEqual(msg.get_int64(), 5)
+ self.assertEqual(msg.get_int64(), 0xf5e4d3c2b109)
+ self.assertEqual(msg.get_mpint(), 17)
+ self.assertEqual(msg.get_mpint(), 0xf5e4d3c2b109)
+ self.assertEqual(msg.get_mpint(), -0x65e4d3c2b109)
def test_3_add(self):
msg = Message()
msg.add(5)
- msg.add(0x1122334455L)
+ msg.add(0x1122334455)
+ msg.add(0xf00000000000000000)
msg.add(True)
msg.add('cat')
msg.add(['a', 'b'])
- self.assertEquals(str(msg), self.__d)
+ self.assertEqual(msg.asbytes(), self.__d)
def test_4_misc(self):
msg = Message(self.__d)
- self.assertEquals(msg.get_int(), 5)
- self.assertEquals(msg.get_mpint(), 0x1122334455L)
- self.assertEquals(msg.get_so_far(), self.__d[:13])
- self.assertEquals(msg.get_remainder(), self.__d[13:])
+ self.assertEqual(msg.get_int(), 5)
+ self.assertEqual(msg.get_int(), 0x1122334455)
+ self.assertEqual(msg.get_int(), 0xf00000000000000000)
+ self.assertEqual(msg.get_so_far(), self.__d[:29])
+ self.assertEqual(msg.get_remainder(), self.__d[29:])
msg.rewind()
- self.assertEquals(msg.get_int(), 5)
- self.assertEquals(msg.get_so_far(), self.__d[:4])
- self.assertEquals(msg.get_remainder(), self.__d[4:])
-
+ self.assertEqual(msg.get_int(), 5)
+ self.assertEqual(msg.get_so_far(), self.__d[:4])
+ self.assertEqual(msg.get_remainder(), self.__d[4:])
diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py
index 1f5bec05..5c36fed6 100644
--- a/tests/test_packetizer.py
+++ b/tests/test_packetizer.py
@@ -21,10 +21,15 @@ Some unit tests for the ssh2 protocol in Transport.
"""
import unittest
-from loop import LoopSocket
+from tests.loop import LoopSocket
from Crypto.Cipher import AES
from Crypto.Hash import SHA, HMAC
from paramiko import Message, Packetizer, util
+from paramiko.common import *
+
+x55 = byte_chr(0x55)
+x1f = byte_chr(0x1f)
+
class PacketizerTest (unittest.TestCase):
@@ -35,22 +40,22 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(wsock)
p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True)
- cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16)
- p.set_outbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20)
+ cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
+ p.set_outbound_cipher(cipher, 16, SHA, 12, x1f * 20)
# message has to be at least 16 bytes long, so we'll have at least one
# block of data encrypted that contains zero random padding bytes
m = Message()
- m.add_byte(chr(100))
+ m.add_byte(byte_chr(100))
m.add_int(100)
m.add_int(1)
m.add_int(900)
p.send_message(m)
data = rsock.recv(100)
# 32 + 12 bytes of MAC = 44
- self.assertEquals(44, len(data))
- self.assertEquals('\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
-
+ self.assertEqual(44, len(data))
+ self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
+
def test_2_read (self):
rsock = LoopSocket()
wsock = LoopSocket()
@@ -58,13 +63,11 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(rsock)
p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True)
- cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16)
- p.set_inbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20)
-
- wsock.send('C\x91\x97\xbd[P\xac%\x87\xc2\xc4k\xc7\xe98\xc0' + \
- '\x90\xd2\x16V\rqsa8|L=\xfb\x97}\xe2n\x03\xb1\xa0\xc2\x1c\xd6AAL\xb4Y')
+ cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
+ p.set_inbound_cipher(cipher, 16, SHA, 12, x1f * 20)
+ wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59')
cmd, m = p.read_message()
- self.assertEquals(100, cmd)
- self.assertEquals(100, m.get_int())
- self.assertEquals(1, m.get_int())
- self.assertEquals(900, m.get_int())
+ self.assertEqual(100, cmd)
+ self.assertEqual(100, m.get_int())
+ self.assertEqual(1, m.get_int())
+ self.assertEqual(900, m.get_int())
diff --git a/tests/test_pkey.py b/tests/test_pkey.py
index 8e8c4aa7..2e565a5f 100644
--- a/tests/test_pkey.py
+++ b/tests/test_pkey.py
@@ -20,11 +20,11 @@
Some unit tests for public/private key objects.
"""
-from binascii import hexlify, unhexlify
-import StringIO
+from binascii import hexlify
import unittest
from paramiko import RSAKey, DSSKey, ECDSAKey, Message, util
-from paramiko.common import rng
+from paramiko.common import rng, StringIO, byte_chr, b, bytes
+from tests.util import test_path
# from openssh's ssh-keygen
PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c='
@@ -77,6 +77,9 @@ ADRvOqQ5R98Sxst765CAqXmRtz8vwoD96g==
-----END EC PRIVATE KEY-----
"""
+x1234 = b'\x01\x02\x03\x04'
+
+
class KeyTest (unittest.TestCase):
def setUp(self):
@@ -87,164 +90,164 @@ class KeyTest (unittest.TestCase):
def test_1_generate_key_bytes(self):
from Crypto.Hash import MD5
- key = util.generate_key_bytes(MD5, '\x01\x02\x03\x04', 'happy birthday', 30)
- exp = unhexlify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64')
- self.assertEquals(exp, key)
+ key = util.generate_key_bytes(MD5, x1234, 'happy birthday', 30)
+ exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64'
+ self.assertEqual(exp, key)
def test_2_load_rsa(self):
- key = RSAKey.from_private_key_file('tests/test_rsa.key')
- self.assertEquals('ssh-rsa', key.get_name())
- exp_rsa = FINGER_RSA.split()[1].replace(':', '')
+ key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ self.assertEqual('ssh-rsa', key.get_name())
+ exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint())
- self.assertEquals(exp_rsa, my_rsa)
- self.assertEquals(PUB_RSA.split()[1], key.get_base64())
- self.assertEquals(1024, key.get_bits())
+ self.assertEqual(exp_rsa, my_rsa)
+ self.assertEqual(PUB_RSA.split()[1], key.get_base64())
+ self.assertEqual(1024, key.get_bits())
- s = StringIO.StringIO()
+ s = StringIO()
key.write_private_key(s)
- self.assertEquals(RSA_PRIVATE_OUT, s.getvalue())
+ self.assertEqual(RSA_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = RSAKey.from_private_key(s)
- self.assertEquals(key, key2)
+ self.assertEqual(key, key2)
def test_3_load_rsa_password(self):
- key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television')
- self.assertEquals('ssh-rsa', key.get_name())
- exp_rsa = FINGER_RSA.split()[1].replace(':', '')
+ key = RSAKey.from_private_key_file(test_path('test_rsa_password.key'), 'television')
+ self.assertEqual('ssh-rsa', key.get_name())
+ exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint())
- self.assertEquals(exp_rsa, my_rsa)
- self.assertEquals(PUB_RSA.split()[1], key.get_base64())
- self.assertEquals(1024, key.get_bits())
+ self.assertEqual(exp_rsa, my_rsa)
+ self.assertEqual(PUB_RSA.split()[1], key.get_base64())
+ self.assertEqual(1024, key.get_bits())
def test_4_load_dss(self):
- key = DSSKey.from_private_key_file('tests/test_dss.key')
- self.assertEquals('ssh-dss', key.get_name())
- exp_dss = FINGER_DSS.split()[1].replace(':', '')
+ key = DSSKey.from_private_key_file(test_path('test_dss.key'))
+ self.assertEqual('ssh-dss', key.get_name())
+ exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint())
- self.assertEquals(exp_dss, my_dss)
- self.assertEquals(PUB_DSS.split()[1], key.get_base64())
- self.assertEquals(1024, key.get_bits())
+ self.assertEqual(exp_dss, my_dss)
+ self.assertEqual(PUB_DSS.split()[1], key.get_base64())
+ self.assertEqual(1024, key.get_bits())
- s = StringIO.StringIO()
+ s = StringIO()
key.write_private_key(s)
- self.assertEquals(DSS_PRIVATE_OUT, s.getvalue())
+ self.assertEqual(DSS_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = DSSKey.from_private_key(s)
- self.assertEquals(key, key2)
+ self.assertEqual(key, key2)
def test_5_load_dss_password(self):
- key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television')
- self.assertEquals('ssh-dss', key.get_name())
- exp_dss = FINGER_DSS.split()[1].replace(':', '')
+ key = DSSKey.from_private_key_file(test_path('test_dss_password.key'), 'television')
+ self.assertEqual('ssh-dss', key.get_name())
+ exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint())
- self.assertEquals(exp_dss, my_dss)
- self.assertEquals(PUB_DSS.split()[1], key.get_base64())
- self.assertEquals(1024, key.get_bits())
+ self.assertEqual(exp_dss, my_dss)
+ self.assertEqual(PUB_DSS.split()[1], key.get_base64())
+ self.assertEqual(1024, key.get_bits())
def test_6_compare_rsa(self):
# verify that the private & public keys compare equal
- key = RSAKey.from_private_key_file('tests/test_rsa.key')
- self.assertEquals(key, key)
- pub = RSAKey(data=str(key))
- self.assert_(key.can_sign())
- self.assert_(not pub.can_sign())
- self.assertEquals(key, pub)
+ key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ self.assertEqual(key, key)
+ pub = RSAKey(data=key.asbytes())
+ self.assertTrue(key.can_sign())
+ self.assertTrue(not pub.can_sign())
+ self.assertEqual(key, pub)
def test_7_compare_dss(self):
# verify that the private & public keys compare equal
- key = DSSKey.from_private_key_file('tests/test_dss.key')
- self.assertEquals(key, key)
- pub = DSSKey(data=str(key))
- self.assert_(key.can_sign())
- self.assert_(not pub.can_sign())
- self.assertEquals(key, pub)
+ key = DSSKey.from_private_key_file(test_path('test_dss.key'))
+ self.assertEqual(key, key)
+ pub = DSSKey(data=key.asbytes())
+ self.assertTrue(key.can_sign())
+ self.assertTrue(not pub.can_sign())
+ self.assertEqual(key, pub)
def test_8_sign_rsa(self):
# verify that the rsa private key can sign and verify
- key = RSAKey.from_private_key_file('tests/test_rsa.key')
- msg = key.sign_ssh_data(rng, 'ice weasels')
- self.assert_(type(msg) is Message)
+ key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ msg = key.sign_ssh_data(rng, b'ice weasels')
+ self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEquals('ssh-rsa', msg.get_string())
- sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
- self.assertEquals(sig, msg.get_string())
+ self.assertEqual('ssh-rsa', msg.get_text())
+ sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
+ self.assertEqual(sig, msg.get_binary())
msg.rewind()
- pub = RSAKey(data=str(key))
- self.assert_(pub.verify_ssh_sig('ice weasels', msg))
+ pub = RSAKey(data=key.asbytes())
+ self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
def test_9_sign_dss(self):
# verify that the dss private key can sign and verify
- key = DSSKey.from_private_key_file('tests/test_dss.key')
- msg = key.sign_ssh_data(rng, 'ice weasels')
- self.assert_(type(msg) is Message)
+ key = DSSKey.from_private_key_file(test_path('test_dss.key'))
+ msg = key.sign_ssh_data(rng, b'ice weasels')
+ self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEquals('ssh-dss', msg.get_string())
+ self.assertEqual('ssh-dss', msg.get_text())
# can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification
# anyway so it's ok.
- self.assertEquals(40, len(msg.get_string()))
+ self.assertEqual(40, len(msg.get_binary()))
msg.rewind()
- pub = DSSKey(data=str(key))
- self.assert_(pub.verify_ssh_sig('ice weasels', msg))
+ pub = DSSKey(data=key.asbytes())
+ self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
def test_A_generate_rsa(self):
key = RSAKey.generate(1024)
- msg = key.sign_ssh_data(rng, 'jerri blank')
+ msg = key.sign_ssh_data(rng, b'jerri blank')
msg.rewind()
- self.assert_(key.verify_ssh_sig('jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
def test_B_generate_dss(self):
key = DSSKey.generate(1024)
- msg = key.sign_ssh_data(rng, 'jerri blank')
+ msg = key.sign_ssh_data(rng, b'jerri blank')
msg.rewind()
- self.assert_(key.verify_ssh_sig('jerri blank', msg))
+ self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
def test_10_load_ecdsa(self):
- key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
- self.assertEquals('ecdsa-sha2-nistp256', key.get_name())
- exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '')
+ key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
+ self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint())
- self.assertEquals(exp_ecdsa, my_ecdsa)
- self.assertEquals(PUB_ECDSA.split()[1], key.get_base64())
- self.assertEquals(256, key.get_bits())
+ self.assertEqual(exp_ecdsa, my_ecdsa)
+ self.assertEqual(PUB_ECDSA.split()[1], key.get_base64())
+ self.assertEqual(256, key.get_bits())
- s = StringIO.StringIO()
+ s = StringIO()
key.write_private_key(s)
- self.assertEquals(ECDSA_PRIVATE_OUT, s.getvalue())
+ self.assertEqual(ECDSA_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = ECDSAKey.from_private_key(s)
- self.assertEquals(key, key2)
+ self.assertEqual(key, key2)
def test_11_load_ecdsa_password(self):
- key = ECDSAKey.from_private_key_file('tests/test_ecdsa_password.key', 'television')
- self.assertEquals('ecdsa-sha2-nistp256', key.get_name())
- exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '')
+ key = ECDSAKey.from_private_key_file(test_path('test_ecdsa_password.key'), b'television')
+ self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
+ exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint())
- self.assertEquals(exp_ecdsa, my_ecdsa)
- self.assertEquals(PUB_ECDSA.split()[1], key.get_base64())
- self.assertEquals(256, key.get_bits())
+ self.assertEqual(exp_ecdsa, my_ecdsa)
+ self.assertEqual(PUB_ECDSA.split()[1], key.get_base64())
+ self.assertEqual(256, key.get_bits())
def test_12_compare_ecdsa(self):
# verify that the private & public keys compare equal
- key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
- self.assertEquals(key, key)
- pub = ECDSAKey(data=str(key))
- self.assert_(key.can_sign())
- self.assert_(not pub.can_sign())
- self.assertEquals(key, pub)
+ key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
+ self.assertEqual(key, key)
+ pub = ECDSAKey(data=key.asbytes())
+ self.assertTrue(key.can_sign())
+ self.assertTrue(not pub.can_sign())
+ self.assertEqual(key, pub)
def test_13_sign_ecdsa(self):
# verify that the rsa private key can sign and verify
- key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
- msg = key.sign_ssh_data(rng, 'ice weasels')
- self.assert_(type(msg) is Message)
+ key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
+ msg = key.sign_ssh_data(rng, b'ice weasels')
+ self.assertTrue(type(msg) is Message)
msg.rewind()
- self.assertEquals('ecdsa-sha2-nistp256', msg.get_string())
+ self.assertEqual('ecdsa-sha2-nistp256', msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
# Even the length of the signature can change.
msg.rewind()
- pub = ECDSAKey(data=str(key))
- self.assert_(pub.verify_ssh_sig('ice weasels', msg))
+ pub = ECDSAKey(data=key.asbytes())
+ self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
diff --git a/tests/test_sftp.py b/tests/test_sftp.py
index cc512c18..f8fab1ce 100755
--- a/tests/test_sftp.py
+++ b/tests/test_sftp.py
@@ -23,19 +23,18 @@ a real actual sftp server is contacted, and a new folder is created there to
do test file operations in (so no existing files will be harmed).
"""
-from __future__ import with_statement
-
from binascii import hexlify
import os
import warnings
-import sys
import threading
import unittest
-import StringIO
+from tempfile import mkstemp
import paramiko
-from stub_sftp import StubServer, StubSFTPServer
-from loop import LoopSocket
+from paramiko.common import *
+from tests.stub_sftp import StubServer, StubSFTPServer
+from tests.loop import LoopSocket
+from tests.util import test_path
from paramiko.sftp_attr import SFTPAttributes
ARTICLE = '''
@@ -70,6 +69,10 @@ FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000')
sftp = None
tc = None
g_big_file_test = True
+# we need to use eval(compile()) here because Py3.2 doesn't support the 'u' marker for unicode
+# this test is the only line in the entire program that has to be treated specially to support Py3.2
+unicode_folder = eval(compile(r"u'\u00fcnic\u00f8de'" if PY2 else r"'\u00fcnic\u00f8de'", 'test_sftp.py', 'eval'))
+utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65'
def get_sftp():
@@ -121,7 +124,7 @@ class SFTPTest (unittest.TestCase):
tc = paramiko.Transport(sockc)
ts = paramiko.Transport(socks)
- host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
+ host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
ts.add_server_key(host_key)
event = threading.Event()
server = StubServer()
@@ -140,7 +143,7 @@ class SFTPTest (unittest.TestCase):
def setUp(self):
global FOLDER
- for i in xrange(1000):
+ for i in range(1000):
FOLDER = FOLDER[:-3] + '%03d' % i
try:
sftp.mkdir(FOLDER)
@@ -149,6 +152,7 @@ class SFTPTest (unittest.TestCase):
pass
def tearDown(self):
+ #sftp.chdir()
sftp.rmdir(FOLDER)
def test_1_file(self):
@@ -158,8 +162,8 @@ class SFTPTest (unittest.TestCase):
f = sftp.open(FOLDER + '/test', 'w')
try:
self.assertEqual(f.stat().st_size, 0)
- f.close()
finally:
+ f.close()
sftp.remove(FOLDER + '/test')
def test_2_close(self):
@@ -180,10 +184,9 @@ class SFTPTest (unittest.TestCase):
"""
verify that a file can be created and written, and the size is correct.
"""
- f = sftp.open(FOLDER + '/duck.txt', 'w')
try:
- f.write(ARTICLE)
- f.close()
+ with sftp.open(FOLDER + '/duck.txt', 'w') as f:
+ f.write(ARTICLE)
self.assertEqual(sftp.stat(FOLDER + '/duck.txt').st_size, 1483)
finally:
sftp.remove(FOLDER + '/duck.txt')
@@ -203,19 +206,17 @@ class SFTPTest (unittest.TestCase):
"""
verify that a file can be opened for append, and tell() still works.
"""
- f = sftp.open(FOLDER + '/append.txt', 'w')
try:
- f.write('first line\nsecond line\n')
- self.assertEqual(f.tell(), 23)
- f.close()
-
- f = sftp.open(FOLDER + '/append.txt', 'a+')
- f.write('third line!!!\n')
- self.assertEqual(f.tell(), 37)
- self.assertEqual(f.stat().st_size, 37)
- f.seek(-26, f.SEEK_CUR)
- self.assertEqual(f.readline(), 'second line\n')
- f.close()
+ with sftp.open(FOLDER + '/append.txt', 'w') as f:
+ f.write('first line\nsecond line\n')
+ self.assertEqual(f.tell(), 23)
+
+ with sftp.open(FOLDER + '/append.txt', 'a+') as f:
+ f.write('third line!!!\n')
+ self.assertEqual(f.tell(), 37)
+ self.assertEqual(f.stat().st_size, 37)
+ f.seek(-26, f.SEEK_CUR)
+ self.assertEqual(f.readline(), 'second line\n')
finally:
sftp.remove(FOLDER + '/append.txt')
@@ -223,20 +224,18 @@ class SFTPTest (unittest.TestCase):
"""
verify that renaming a file works.
"""
- f = sftp.open(FOLDER + '/first.txt', 'w')
try:
- f.write('content!\n')
- f.close()
+ with sftp.open(FOLDER + '/first.txt', 'w') as f:
+ f.write('content!\n')
sftp.rename(FOLDER + '/first.txt', FOLDER + '/second.txt')
try:
- f = sftp.open(FOLDER + '/first.txt', 'r')
- self.assert_(False, 'no exception on reading nonexistent file')
+ sftp.open(FOLDER + '/first.txt', 'r')
+ self.assertTrue(False, 'no exception on reading nonexistent file')
except IOError:
pass
- f = sftp.open(FOLDER + '/second.txt', 'r')
- f.seek(-6, f.SEEK_END)
- self.assertEqual(f.read(4), 'tent')
- f.close()
+ with sftp.open(FOLDER + '/second.txt', 'r') as f:
+ f.seek(-6, f.SEEK_END)
+ self.assertEqual(u(f.read(4)), 'tent')
finally:
try:
sftp.remove(FOLDER + '/first.txt')
@@ -253,14 +252,13 @@ class SFTPTest (unittest.TestCase):
remove the folder and verify that we can't create a file in it anymore.
"""
sftp.mkdir(FOLDER + '/subfolder')
- f = sftp.open(FOLDER + '/subfolder/test', 'w')
- f.close()
+ sftp.open(FOLDER + '/subfolder/test', 'w').close()
sftp.remove(FOLDER + '/subfolder/test')
sftp.rmdir(FOLDER + '/subfolder')
try:
- f = sftp.open(FOLDER + '/subfolder/test')
+ sftp.open(FOLDER + '/subfolder/test')
# shouldn't be able to create that file
- self.assert_(False, 'no exception at dummy file creation')
+ self.assertTrue(False, 'no exception at dummy file creation')
except IOError:
pass
@@ -270,21 +268,16 @@ class SFTPTest (unittest.TestCase):
and those files show up in sftp.listdir.
"""
try:
- f = sftp.open(FOLDER + '/duck.txt', 'w')
- f.close()
-
- f = sftp.open(FOLDER + '/fish.txt', 'w')
- f.close()
-
- f = sftp.open(FOLDER + '/tertiary.py', 'w')
- f.close()
+ sftp.open(FOLDER + '/duck.txt', 'w').close()
+ sftp.open(FOLDER + '/fish.txt', 'w').close()
+ sftp.open(FOLDER + '/tertiary.py', 'w').close()
x = sftp.listdir(FOLDER)
self.assertEqual(len(x), 3)
- self.assert_('duck.txt' in x)
- self.assert_('fish.txt' in x)
- self.assert_('tertiary.py' in x)
- self.assert_('random' not in x)
+ self.assertTrue('duck.txt' in x)
+ self.assertTrue('fish.txt' in x)
+ self.assertTrue('tertiary.py' in x)
+ self.assertTrue('random' not in x)
finally:
sftp.remove(FOLDER + '/duck.txt')
sftp.remove(FOLDER + '/fish.txt')
@@ -294,22 +287,21 @@ class SFTPTest (unittest.TestCase):
"""
verify that the setstat functions (chown, chmod, utime, truncate) work.
"""
- f = sftp.open(FOLDER + '/special', 'w')
try:
- f.write('x' * 1024)
- f.close()
+ with sftp.open(FOLDER + '/special', 'w') as f:
+ f.write('x' * 1024)
stat = sftp.stat(FOLDER + '/special')
- sftp.chmod(FOLDER + '/special', (stat.st_mode & ~0777) | 0600)
+ sftp.chmod(FOLDER + '/special', (stat.st_mode & ~o777) | o600)
stat = sftp.stat(FOLDER + '/special')
- expected_mode = 0600
+ expected_mode = o600
if sys.platform == 'win32':
# chmod not really functional on windows
- expected_mode = 0666
+ expected_mode = o666
if sys.platform == 'cygwin':
# even worse.
- expected_mode = 0644
- self.assertEqual(stat.st_mode & 0777, expected_mode)
+ expected_mode = o644
+ self.assertEqual(stat.st_mode & o777, expected_mode)
self.assertEqual(stat.st_size, 1024)
mtime = stat.st_mtime - 3600
@@ -333,40 +325,38 @@ class SFTPTest (unittest.TestCase):
verify that the fsetstat functions (chown, chmod, utime, truncate)
work on open files.
"""
- f = sftp.open(FOLDER + '/special', 'w')
try:
- f.write('x' * 1024)
- f.close()
-
- f = sftp.open(FOLDER + '/special', 'r+')
- stat = f.stat()
- f.chmod((stat.st_mode & ~0777) | 0600)
- stat = f.stat()
-
- expected_mode = 0600
- if sys.platform == 'win32':
- # chmod not really functional on windows
- expected_mode = 0666
- if sys.platform == 'cygwin':
- # even worse.
- expected_mode = 0644
- self.assertEqual(stat.st_mode & 0777, expected_mode)
- self.assertEqual(stat.st_size, 1024)
-
- mtime = stat.st_mtime - 3600
- atime = stat.st_atime - 1800
- f.utime((atime, mtime))
- stat = f.stat()
- self.assertEqual(stat.st_mtime, mtime)
- if sys.platform not in ('win32', 'cygwin'):
- self.assertEqual(stat.st_atime, atime)
-
- # can't really test chown, since we'd have to know a valid uid.
-
- f.truncate(512)
- stat = f.stat()
- self.assertEqual(stat.st_size, 512)
- f.close()
+ with sftp.open(FOLDER + '/special', 'w') as f:
+ f.write('x' * 1024)
+
+ with sftp.open(FOLDER + '/special', 'r+') as f:
+ stat = f.stat()
+ f.chmod((stat.st_mode & ~o777) | o600)
+ stat = f.stat()
+
+ expected_mode = o600
+ if sys.platform == 'win32':
+ # chmod not really functional on windows
+ expected_mode = o666
+ if sys.platform == 'cygwin':
+ # even worse.
+ expected_mode = o644
+ self.assertEqual(stat.st_mode & o777, expected_mode)
+ self.assertEqual(stat.st_size, 1024)
+
+ mtime = stat.st_mtime - 3600
+ atime = stat.st_atime - 1800
+ f.utime((atime, mtime))
+ stat = f.stat()
+ self.assertEqual(stat.st_mtime, mtime)
+ if sys.platform not in ('win32', 'cygwin'):
+ self.assertEqual(stat.st_atime, atime)
+
+ # can't really test chown, since we'd have to know a valid uid.
+
+ f.truncate(512)
+ stat = f.stat()
+ self.assertEqual(stat.st_size, 512)
finally:
sftp.remove(FOLDER + '/special')
@@ -378,25 +368,23 @@ class SFTPTest (unittest.TestCase):
buffering is reset on 'seek'.
"""
try:
- f = sftp.open(FOLDER + '/duck.txt', 'w')
- f.write(ARTICLE)
- f.close()
+ with sftp.open(FOLDER + '/duck.txt', 'w') as f:
+ f.write(ARTICLE)
- f = sftp.open(FOLDER + '/duck.txt', 'r+')
- line_number = 0
- loc = 0
- pos_list = []
- for line in f:
- line_number += 1
- pos_list.append(loc)
- loc = f.tell()
- f.seek(pos_list[6], f.SEEK_SET)
- self.assertEqual(f.readline(), 'Nouzilly, France.\n')
- f.seek(pos_list[17], f.SEEK_SET)
- self.assertEqual(f.readline()[:4], 'duck')
- f.seek(pos_list[10], f.SEEK_SET)
- self.assertEqual(f.readline(), 'duck types were equally resistant to exogenous insulin compared with chicken.\n')
- f.close()
+ with sftp.open(FOLDER + '/duck.txt', 'r+') as f:
+ line_number = 0
+ loc = 0
+ pos_list = []
+ for line in f:
+ line_number += 1
+ pos_list.append(loc)
+ loc = f.tell()
+ f.seek(pos_list[6], f.SEEK_SET)
+ self.assertEqual(f.readline(), 'Nouzilly, France.\n')
+ f.seek(pos_list[17], f.SEEK_SET)
+ self.assertEqual(f.readline()[:4], 'duck')
+ f.seek(pos_list[10], f.SEEK_SET)
+ self.assertEqual(f.readline(), 'duck types were equally resistant to exogenous insulin compared with chicken.\n')
finally:
sftp.remove(FOLDER + '/duck.txt')
@@ -405,17 +393,15 @@ class SFTPTest (unittest.TestCase):
create a text file, seek back and change part of it, and verify that the
changes worked.
"""
- f = sftp.open(FOLDER + '/testing.txt', 'w')
try:
- f.write('hello kitty.\n')
- f.seek(-5, f.SEEK_CUR)
- f.write('dd')
- f.close()
+ with sftp.open(FOLDER + '/testing.txt', 'w') as f:
+ f.write('hello kitty.\n')
+ f.seek(-5, f.SEEK_CUR)
+ f.write('dd')
self.assertEqual(sftp.stat(FOLDER + '/testing.txt').st_size, 13)
- f = sftp.open(FOLDER + '/testing.txt', 'r')
- data = f.read(20)
- f.close()
+ with sftp.open(FOLDER + '/testing.txt', 'r') as f:
+ data = f.read(20)
self.assertEqual(data, 'hello kiddy.\n')
finally:
sftp.remove(FOLDER + '/testing.txt')
@@ -428,16 +414,14 @@ class SFTPTest (unittest.TestCase):
# skip symlink tests on windows
return
- f = sftp.open(FOLDER + '/original.txt', 'w')
try:
- f.write('original\n')
- f.close()
+ with sftp.open(FOLDER + '/original.txt', 'w') as f:
+ f.write('original\n')
sftp.symlink('original.txt', FOLDER + '/link.txt')
self.assertEqual(sftp.readlink(FOLDER + '/link.txt'), 'original.txt')
- f = sftp.open(FOLDER + '/link.txt', 'r')
- self.assertEqual(f.readlines(), ['original\n'])
- f.close()
+ with sftp.open(FOLDER + '/link.txt', 'r') as f:
+ self.assertEqual(f.readlines(), ['original\n'])
cwd = sftp.normalize('.')
if cwd[-1] == '/':
@@ -450,7 +434,7 @@ class SFTPTest (unittest.TestCase):
self.assertEqual(sftp.stat(FOLDER + '/link.txt').st_size, 9)
# the sftp server may be hiding extra path members from us, so the
# length may be longer than we expect:
- self.assert_(sftp.lstat(FOLDER + '/link2.txt').st_size >= len(abs_path))
+ self.assertTrue(sftp.lstat(FOLDER + '/link2.txt').st_size >= len(abs_path))
self.assertEqual(sftp.stat(FOLDER + '/link2.txt').st_size, 9)
self.assertEqual(sftp.stat(FOLDER + '/original.txt').st_size, 9)
finally:
@@ -471,18 +455,16 @@ class SFTPTest (unittest.TestCase):
"""
verify that buffered writes are automatically flushed on seek.
"""
- f = sftp.open(FOLDER + '/happy.txt', 'w', 1)
try:
- f.write('full line.\n')
- f.write('partial')
- f.seek(9, f.SEEK_SET)
- f.write('?\n')
- f.close()
-
- f = sftp.open(FOLDER + '/happy.txt', 'r')
- self.assertEqual(f.readline(), 'full line?\n')
- self.assertEqual(f.read(7), 'partial')
- f.close()
+ with sftp.open(FOLDER + '/happy.txt', 'w', 1) as f:
+ f.write('full line.\n')
+ f.write('partial')
+ f.seek(9, f.SEEK_SET)
+ f.write('?\n')
+
+ with sftp.open(FOLDER + '/happy.txt', 'r') as f:
+ self.assertEqual(f.readline(), 'full line?\n')
+ self.assertEqual(f.read(7), 'partial')
finally:
try:
sftp.remove(FOLDER + '/happy.txt')
@@ -495,10 +477,10 @@ class SFTPTest (unittest.TestCase):
error.
"""
pwd = sftp.normalize('.')
- self.assert_(len(pwd) > 0)
+ self.assertTrue(len(pwd) > 0)
f = sftp.normalize('./' + FOLDER)
- self.assert_(len(f) > 0)
- self.assertEquals(os.path.join(pwd, FOLDER), f)
+ self.assertTrue(len(f) > 0)
+ self.assertEqual(os.path.join(pwd, FOLDER), f)
def test_F_mkdir(self):
"""
@@ -507,19 +489,19 @@ class SFTPTest (unittest.TestCase):
try:
sftp.mkdir(FOLDER + '/subfolder')
except:
- self.assert_(False, 'exception creating subfolder')
+ self.assertTrue(False, 'exception creating subfolder')
try:
sftp.mkdir(FOLDER + '/subfolder')
- self.assert_(False, 'no exception overwriting subfolder')
+ self.assertTrue(False, 'no exception overwriting subfolder')
except IOError:
pass
try:
sftp.rmdir(FOLDER + '/subfolder')
except:
- self.assert_(False, 'exception removing subfolder')
+ self.assertTrue(False, 'exception removing subfolder')
try:
sftp.rmdir(FOLDER + '/subfolder')
- self.assert_(False, 'no exception removing nonexistent subfolder')
+ self.assertTrue(False, 'no exception removing nonexistent subfolder')
except IOError:
pass
@@ -534,17 +516,16 @@ class SFTPTest (unittest.TestCase):
sftp.mkdir(FOLDER + '/alpha')
sftp.chdir(FOLDER + '/alpha')
sftp.mkdir('beta')
- self.assertEquals(root + FOLDER + '/alpha', sftp.getcwd())
- self.assertEquals(['beta'], sftp.listdir('.'))
+ self.assertEqual(root + FOLDER + '/alpha', sftp.getcwd())
+ self.assertEqual(['beta'], sftp.listdir('.'))
sftp.chdir('beta')
- f = sftp.open('fish', 'w')
- f.write('hello\n')
- f.close()
+ with sftp.open('fish', 'w') as f:
+ f.write('hello\n')
sftp.chdir('..')
- self.assertEquals(['fish'], sftp.listdir('beta'))
+ self.assertEqual(['fish'], sftp.listdir('beta'))
sftp.chdir('..')
- self.assertEquals(['fish'], sftp.listdir('alpha/beta'))
+ self.assertEqual(['fish'], sftp.listdir('alpha/beta'))
finally:
sftp.chdir(root)
try:
@@ -566,30 +547,29 @@ class SFTPTest (unittest.TestCase):
"""
warnings.filterwarnings('ignore', 'tempnam.*')
- localname = os.tempnam()
- text = 'All I wanted was a plastic bunny rabbit.\n'
- f = open(localname, 'wb')
- f.write(text)
- f.close()
+ fd, localname = mkstemp()
+ os.close(fd)
+ text = b'All I wanted was a plastic bunny rabbit.\n'
+ with open(localname, 'wb') as f:
+ f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
sftp.put(localname, FOLDER + '/bunny.txt', progress_callback)
- f = sftp.open(FOLDER + '/bunny.txt', 'r')
- self.assertEquals(text, f.read(128))
- f.close()
- self.assertEquals((41, 41), saved_progress[-1])
+ with sftp.open(FOLDER + '/bunny.txt', 'rb') as f:
+ self.assertEqual(text, f.read(128))
+ self.assertEqual((41, 41), saved_progress[-1])
os.unlink(localname)
- localname = os.tempnam()
+ fd, localname = mkstemp()
+ os.close(fd)
saved_progress = []
sftp.get(FOLDER + '/bunny.txt', localname, progress_callback)
- f = open(localname, 'rb')
- self.assertEquals(text, f.read(128))
- f.close()
- self.assertEquals((41, 41), saved_progress[-1])
+ with open(localname, 'rb') as f:
+ self.assertEqual(text, f.read(128))
+ self.assertEqual((41, 41), saved_progress[-1])
os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt')
@@ -600,20 +580,18 @@ class SFTPTest (unittest.TestCase):
(it's an sftp extension that we support, and may be the only ones who
support it.)
"""
- f = sftp.open(FOLDER + '/kitty.txt', 'w')
- f.write('here kitty kitty' * 64)
- f.close()
+ with sftp.open(FOLDER + '/kitty.txt', 'w') as f:
+ f.write('here kitty kitty' * 64)
try:
- f = sftp.open(FOLDER + '/kitty.txt', 'r')
- sum = f.check('sha1')
- self.assertEquals('91059CFC6615941378D413CB5ADAF4C5EB293402', hexlify(sum).upper())
- sum = f.check('md5', 0, 512)
- self.assertEquals('93DE4788FCA28D471516963A1FE3856A', hexlify(sum).upper())
- sum = f.check('md5', 0, 0, 510)
- self.assertEquals('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6',
- hexlify(sum).upper())
- f.close()
+ with sftp.open(FOLDER + '/kitty.txt', 'r') as f:
+ sum = f.check('sha1')
+ self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', u(hexlify(sum)).upper())
+ sum = f.check('md5', 0, 512)
+ self.assertEqual('93DE4788FCA28D471516963A1FE3856A', u(hexlify(sum)).upper())
+ sum = f.check('md5', 0, 0, 510)
+ self.assertEqual('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6',
+ u(hexlify(sum)).upper())
finally:
sftp.unlink(FOLDER + '/kitty.txt')
@@ -621,12 +599,11 @@ class SFTPTest (unittest.TestCase):
"""
verify that the 'x' flag works when opening a file.
"""
- f = sftp.open(FOLDER + '/unusual.txt', 'wx')
- f.close()
+ sftp.open(FOLDER + '/unusual.txt', 'wx').close()
try:
try:
- f = sftp.open(FOLDER + '/unusual.txt', 'wx')
+ sftp.open(FOLDER + '/unusual.txt', 'wx')
self.fail('expected exception')
except IOError:
pass
@@ -637,44 +614,39 @@ class SFTPTest (unittest.TestCase):
"""
verify that unicode strings are encoded into utf8 correctly.
"""
- f = sftp.open(FOLDER + '/something', 'w')
- f.write('okay')
- f.close()
+ with sftp.open(FOLDER + '/something', 'w') as f:
+ f.write('okay')
try:
- sftp.rename(FOLDER + '/something', FOLDER + u'/\u00fcnic\u00f8de')
- sftp.open(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65', 'r')
- except Exception, e:
- self.fail('exception ' + e)
- sftp.unlink(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65')
+ sftp.rename(FOLDER + '/something', FOLDER + '/' + unicode_folder)
+ sftp.open(b(FOLDER) + utf8_folder, 'r')
+ except Exception as e:
+ self.fail('exception ' + str(e))
+ sftp.unlink(b(FOLDER) + utf8_folder)
def test_L_utf8_chdir(self):
- sftp.mkdir(FOLDER + u'\u00fcnic\u00f8de')
+ sftp.mkdir(FOLDER + '/' + unicode_folder)
try:
- sftp.chdir(FOLDER + u'\u00fcnic\u00f8de')
- f = sftp.open('something', 'w')
- f.write('okay')
- f.close()
+ sftp.chdir(FOLDER + '/' + unicode_folder)
+ with sftp.open('something', 'w') as f:
+ f.write('okay')
sftp.unlink('something')
finally:
- sftp.chdir(None)
- sftp.rmdir(FOLDER + u'\u00fcnic\u00f8de')
+ sftp.chdir()
+ sftp.rmdir(FOLDER + '/' + unicode_folder)
def test_M_bad_readv(self):
"""
verify that readv at the end of the file doesn't essplode.
"""
- f = sftp.open(FOLDER + '/zero', 'w')
- f.close()
+ sftp.open(FOLDER + '/zero', 'w').close()
try:
- f = sftp.open(FOLDER + '/zero', 'r')
- f.readv([(0, 12)])
- f.close()
+ with sftp.open(FOLDER + '/zero', 'r') as f:
+ f.readv([(0, 12)])
- f = sftp.open(FOLDER + '/zero', 'r')
- f.prefetch()
- f.read(100)
- f.close()
+ with sftp.open(FOLDER + '/zero', 'r') as f:
+ f.prefetch()
+ f.read(100)
finally:
sftp.unlink(FOLDER + '/zero')
@@ -684,45 +656,61 @@ class SFTPTest (unittest.TestCase):
"""
warnings.filterwarnings('ignore', 'tempnam.*')
- localname = os.tempnam()
+ fd, localname = mkstemp()
+ os.close(fd)
text = 'All I wanted was a plastic bunny rabbit.\n'
- f = open(localname, 'wb')
- f.write(text)
- f.close()
+ with open(localname, 'w') as f:
+ f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
res = sftp.put(localname, FOLDER + '/bunny.txt', progress_callback, False)
- self.assertEquals(SFTPAttributes().attr, res.attr)
+ self.assertEqual(SFTPAttributes().attr, res.attr)
- f = sftp.open(FOLDER + '/bunny.txt', 'r')
- self.assertEquals(text, f.read(128))
- f.close()
- self.assertEquals((41, 41), saved_progress[-1])
+ with sftp.open(FOLDER + '/bunny.txt', 'r') as f:
+ self.assertEqual(text, f.read(128))
+ self.assertEqual((41, 41), saved_progress[-1])
os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt')
+ def test_O_getcwd(self):
+ """
+ verify that chdir/getcwd work.
+ """
+ self.assertEqual(None, sftp.getcwd())
+ root = sftp.normalize('.')
+ if root[-1] != '/':
+ root += '/'
+ try:
+ sftp.mkdir(FOLDER + '/alpha')
+ sftp.chdir(FOLDER + '/alpha')
+ self.assertEqual('/' + FOLDER + '/alpha', sftp.getcwd())
+ finally:
+ sftp.chdir(root)
+ try:
+ sftp.rmdir(FOLDER + '/alpha')
+ except:
+ pass
+
def XXX_test_M_seek_append(self):
"""
verify that seek does't affect writes during append.
does not work except through paramiko. :( openssh fails.
"""
- f = sftp.open(FOLDER + '/append.txt', 'a')
try:
- f.write('first line\nsecond line\n')
- f.seek(11, f.SEEK_SET)
- f.write('third line\n')
- f.close()
-
- f = sftp.open(FOLDER + '/append.txt', 'r')
- self.assertEqual(f.stat().st_size, 34)
- self.assertEqual(f.readline(), 'first line\n')
- self.assertEqual(f.readline(), 'second line\n')
- self.assertEqual(f.readline(), 'third line\n')
- f.close()
+ with sftp.open(FOLDER + '/append.txt', 'a') as f:
+ f.write('first line\nsecond line\n')
+ f.seek(11, f.SEEK_SET)
+ f.write('third line\n')
+
+ with sftp.open(FOLDER + '/append.txt', 'r') as f:
+ self.assertEqual(f.stat().st_size, 34)
+ self.assertEqual(f.readline(), 'first line\n')
+ self.assertEqual(f.readline(), 'second line\n')
+ self.assertEqual(f.readline(), 'third line\n')
finally:
sftp.remove(FOLDER + '/append.txt')
@@ -731,10 +719,16 @@ class SFTPTest (unittest.TestCase):
Send an empty file and confirm it is sent.
"""
target = FOLDER + '/empty file.txt'
- stream = StringIO.StringIO()
+ stream = StringIO()
try:
attrs = sftp.putfo(stream, target)
# the returned attributes should not be null
self.assertNotEqual(attrs, None)
finally:
sftp.remove(target)
+
+
+if __name__ == '__main__':
+ SFTPTest.init_loopback()
+ from unittest import main
+ main()
diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py
index 04b15b0d..6870c6b4 100644
--- a/tests/test_sftp_big.py
+++ b/tests/test_sftp_big.py
@@ -33,9 +33,10 @@ import time
import unittest
import paramiko
-from stub_sftp import StubServer, StubSFTPServer
-from loop import LoopSocket
-from test_sftp import get_sftp
+from paramiko.common import *
+from tests.stub_sftp import StubServer, StubSFTPServer
+from tests.loop import LoopSocket
+from tests.test_sftp import get_sftp
FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000')
@@ -45,7 +46,7 @@ class BigSFTPTest (unittest.TestCase):
def setUp(self):
global FOLDER
sftp = get_sftp()
- for i in xrange(1000):
+ for i in range(1000):
FOLDER = FOLDER[:-3] + '%03d' % i
try:
sftp.mkdir(FOLDER)
@@ -65,19 +66,17 @@ class BigSFTPTest (unittest.TestCase):
numfiles = 100
try:
for i in range(numfiles):
- f = sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1)
- f.write('this is file #%d.\n' % i)
- f.close()
- sftp.chmod('%s/file%d.txt' % (FOLDER, i), 0660)
+ with sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1) as f:
+ f.write('this is file #%d.\n' % i)
+ sftp.chmod('%s/file%d.txt' % (FOLDER, i), o660)
# now make sure every file is there, by creating a list of filenmes
# and reading them in random order.
- numlist = range(numfiles)
+ numlist = list(range(numfiles))
while len(numlist) > 0:
r = numlist[random.randint(0, len(numlist) - 1)]
- f = sftp.open('%s/file%d.txt' % (FOLDER, r))
- self.assertEqual(f.readline(), 'this is file #%d.\n' % r)
- f.close()
+ with sftp.open('%s/file%d.txt' % (FOLDER, r)) as f:
+ self.assertEqual(f.readline(), 'this is file #%d.\n' % r)
numlist.remove(r)
finally:
for i in range(numfiles):
@@ -94,12 +93,11 @@ class BigSFTPTest (unittest.TestCase):
kblob = (1024 * 'x')
start = time.time()
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'w') as f:
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@@ -107,11 +105,10 @@ class BigSFTPTest (unittest.TestCase):
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- for n in range(1024):
- data = f.read(1024)
- self.assertEqual(data, kblob)
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
+ for n in range(1024):
+ data = f.read(1024)
+ self.assertEqual(data, kblob)
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
@@ -123,16 +120,15 @@ class BigSFTPTest (unittest.TestCase):
write a 1MB file, with no linefeeds, using pipelining.
"""
sftp = get_sftp()
- kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
+ kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
start = time.time()
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- f.set_pipelined(True)
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
+ f.set_pipelined(True)
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@@ -140,22 +136,21 @@ class BigSFTPTest (unittest.TestCase):
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- f.prefetch()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
+ f.prefetch()
- # read on odd boundaries to make sure the bytes aren't getting scrambled
- n = 0
- k2blob = kblob + kblob
- chunk = 629
- size = 1024 * 1024
- while n < size:
- if n + chunk > size:
- chunk = size - n
- data = f.read(chunk)
- offset = n % 1024
- self.assertEqual(data, k2blob[offset:offset + chunk])
- n += chunk
- f.close()
+ # read on odd boundaries to make sure the bytes aren't getting scrambled
+ n = 0
+ k2blob = kblob + kblob
+ chunk = 629
+ size = 1024 * 1024
+ while n < size:
+ if n + chunk > size:
+ chunk = size - n
+ data = f.read(chunk)
+ offset = n % 1024
+ self.assertEqual(data, k2blob[offset:offset + chunk])
+ n += chunk
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
@@ -164,15 +159,14 @@ class BigSFTPTest (unittest.TestCase):
def test_4_prefetch_seek(self):
sftp = get_sftp()
- kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
+ kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- f.set_pipelined(True)
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
+ f.set_pipelined(True)
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@@ -180,21 +174,20 @@ class BigSFTPTest (unittest.TestCase):
start = time.time()
k2blob = kblob + kblob
chunk = 793
- for i in xrange(10):
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- f.prefetch()
- base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
- offsets = [base_offset + j * chunk for j in xrange(100)]
- # randomly seek around and read them out
- for j in xrange(100):
- offset = offsets[random.randint(0, len(offsets) - 1)]
- offsets.remove(offset)
- f.seek(offset)
- data = f.read(chunk)
- n_offset = offset % 1024
- self.assertEqual(data, k2blob[n_offset:n_offset + chunk])
- offset += chunk
- f.close()
+ for i in range(10):
+ with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
+ f.prefetch()
+ base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
+ offsets = [base_offset + j * chunk for j in range(100)]
+ # randomly seek around and read them out
+ for j in range(100):
+ offset = offsets[random.randint(0, len(offsets) - 1)]
+ offsets.remove(offset)
+ f.seek(offset)
+ data = f.read(chunk)
+ n_offset = offset % 1024
+ self.assertEqual(data, k2blob[n_offset:n_offset + chunk])
+ offset += chunk
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally:
@@ -202,15 +195,14 @@ class BigSFTPTest (unittest.TestCase):
def test_5_readv_seek(self):
sftp = get_sftp()
- kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
+ kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- f.set_pipelined(True)
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
+ f.set_pipelined(True)
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@@ -218,22 +210,21 @@ class BigSFTPTest (unittest.TestCase):
start = time.time()
k2blob = kblob + kblob
chunk = 793
- for i in xrange(10):
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
- # make a bunch of offsets and put them in random order
- offsets = [base_offset + j * chunk for j in xrange(100)]
- readv_list = []
- for j in xrange(100):
- o = offsets[random.randint(0, len(offsets) - 1)]
- offsets.remove(o)
- readv_list.append((o, chunk))
- ret = f.readv(readv_list)
- for i in xrange(len(readv_list)):
- offset = readv_list[i][0]
- n_offset = offset % 1024
- self.assertEqual(ret.next(), k2blob[n_offset:n_offset + chunk])
- f.close()
+ for i in range(10):
+ with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
+ base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
+ # make a bunch of offsets and put them in random order
+ offsets = [base_offset + j * chunk for j in range(100)]
+ readv_list = []
+ for j in range(100):
+ o = offsets[random.randint(0, len(offsets) - 1)]
+ offsets.remove(o)
+ readv_list.append((o, chunk))
+ ret = f.readv(readv_list)
+ for i in range(len(readv_list)):
+ offset = readv_list[i][0]
+ n_offset = offset % 1024
+ self.assertEqual(next(ret), k2blob[n_offset:n_offset + chunk])
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally:
@@ -247,28 +238,26 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp()
kblob = (1024 * 'x')
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- f.set_pipelined(True)
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'w') as f:
+ f.set_pipelined(True)
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
for i in range(10):
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
+ with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
+ f.prefetch()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
f.prefetch()
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- f.prefetch()
- for n in range(1024):
- data = f.read(1024)
- self.assertEqual(data, kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ for n in range(1024):
+ data = f.read(1024)
+ self.assertEqual(data, kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@@ -278,35 +267,33 @@ class BigSFTPTest (unittest.TestCase):
verify that prefetch and readv don't conflict with each other.
"""
sftp = get_sftp()
- kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
+ kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- f.set_pipelined(True)
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
+ f.set_pipelined(True)
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- f.prefetch()
- data = f.read(1024)
- self.assertEqual(data, kblob)
-
- chunk_size = 793
- base_offset = 512 * 1024
- k2blob = kblob + kblob
- chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
- for data in f.readv(chunks):
- offset = base_offset % 1024
- self.assertEqual(chunk_size, len(data))
- self.assertEqual(k2blob[offset:offset + chunk_size], data)
- base_offset += chunk_size
-
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
+ f.prefetch()
+ data = f.read(1024)
+ self.assertEqual(data, kblob)
+
+ chunk_size = 793
+ base_offset = 512 * 1024
+ k2blob = kblob + kblob
+ chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
+ for data in f.readv(chunks):
+ offset = base_offset % 1024
+ self.assertEqual(chunk_size, len(data))
+ self.assertEqual(k2blob[offset:offset + chunk_size], data)
+ base_offset += chunk_size
+
sys.stderr.write(' ')
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@@ -317,26 +304,24 @@ class BigSFTPTest (unittest.TestCase):
returned as a single blob.
"""
sftp = get_sftp()
- kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
+ kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
- f.set_pipelined(True)
- for n in range(1024):
- f.write(kblob)
- if n % 128 == 0:
- sys.stderr.write('.')
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
+ f.set_pipelined(True)
+ for n in range(1024):
+ f.write(kblob)
+ if n % 128 == 0:
+ sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
- data = list(f.readv([(23 * 1024, 128 * 1024)]))
- self.assertEqual(1, len(data))
- data = data[0]
- self.assertEqual(128 * 1024, len(data))
+ with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
+ data = list(f.readv([(23 * 1024, 128 * 1024)]))
+ self.assertEqual(1, len(data))
+ data = data[0]
+ self.assertEqual(128 * 1024, len(data))
- f.close()
sys.stderr.write(' ')
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@@ -348,9 +333,8 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp()
mblob = (1024 * 1024 * 'x')
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024)
- f.write(mblob)
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) as f:
+ f.write(mblob)
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
finally:
@@ -365,21 +349,26 @@ class BigSFTPTest (unittest.TestCase):
t.packetizer.REKEY_BYTES = 512 * 1024
k32blob = (32 * 1024 * 'x')
try:
- f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024)
- for i in xrange(32):
- f.write(k32blob)
- f.close()
-
+ with sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) as f:
+ for i in range(32):
+ f.write(k32blob)
+
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
- self.assertNotEquals(t.H, t.session_id)
+ self.assertNotEqual(t.H, t.session_id)
# try to read it too.
- f = sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024)
- f.prefetch()
- total = 0
- while total < 1024 * 1024:
- total += len(f.read(32 * 1024))
- f.close()
+ with sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024) as f:
+ f.prefetch()
+ total = 0
+ while total < 1024 * 1024:
+ total += len(f.read(32 * 1024))
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
t.packetizer.REKEY_BYTES = pow(2, 30)
+
+
+if __name__ == '__main__':
+ from tests.test_sftp import SFTPTest
+ SFTPTest.init_loopback()
+ from unittest import main
+ main()
diff --git a/tests/test_transport.py b/tests/test_transport.py
index e8f7f366..876759c8 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -20,7 +20,7 @@
Some unit tests for the ssh2 protocol in Transport.
"""
-from binascii import hexlify, unhexlify
+from binascii import hexlify
import select
import socket
import sys
@@ -33,10 +33,10 @@ from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey
SSHException, BadAuthenticationType, InteractiveQuery, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
-from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST
+from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST, b, bytes
from paramiko.message import Message
-from loop import LoopSocket
-from util import ParamikoTest
+from tests.loop import LoopSocket
+from tests.util import ParamikoTest, test_path
LONG_BANNER = """\
@@ -55,7 +55,7 @@ Maybe.
class NullServer (ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key')
+ paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
def get_allowed_auths(self, username):
if username == 'slowdive':
@@ -121,8 +121,8 @@ class TransportTest(ParamikoTest):
self.sockc.close()
def setup_test_server(self, client_options=None, server_options=None):
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
+ host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
if client_options is not None:
@@ -132,37 +132,37 @@ class TransportTest(ParamikoTest):
event = threading.Event()
self.server = NullServer()
- self.assert_(not event.isSet())
+ self.assertTrue(not event.isSet())
self.ts.start_server(event, self.server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
+ self.assertTrue(event.isSet())
+ self.assertTrue(self.ts.is_active())
def test_1_security_options(self):
o = self.tc.get_security_options()
- self.assertEquals(type(o), SecurityOptions)
- self.assert_(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
+ self.assertEqual(type(o), SecurityOptions)
+ self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
o.ciphers = ('aes256-cbc', 'blowfish-cbc')
- self.assertEquals(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
+ self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
try:
o.ciphers = ('aes256-cbc', 'made-up-cipher')
- self.assert_(False)
+ self.assertTrue(False)
except ValueError:
pass
try:
o.ciphers = 23
- self.assert_(False)
+ self.assertTrue(False)
except TypeError:
pass
def test_2_compute_key(self):
- self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L
- self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
+ self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
+ self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
self.tc.session_id = self.tc.H
key = self.tc._compute_key('C', 32)
- self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
+ self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
hexlify(key).upper())
def test_3_simple(self):
@@ -171,44 +171,44 @@ class TransportTest(ParamikoTest):
loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :)
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
+ host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
- self.assert_(not event.isSet())
- self.assertEquals(None, self.tc.get_username())
- self.assertEquals(None, self.ts.get_username())
- self.assertEquals(False, self.tc.is_authenticated())
- self.assertEquals(False, self.ts.is_authenticated())
+ self.assertTrue(not event.isSet())
+ self.assertEqual(None, self.tc.get_username())
+ self.assertEqual(None, self.ts.get_username())
+ self.assertEqual(False, self.tc.is_authenticated())
+ self.assertEqual(False, self.ts.is_authenticated())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
- self.assertEquals('slowdive', self.tc.get_username())
- self.assertEquals('slowdive', self.ts.get_username())
- self.assertEquals(True, self.tc.is_authenticated())
- self.assertEquals(True, self.ts.is_authenticated())
+ self.assertTrue(event.isSet())
+ self.assertTrue(self.ts.is_active())
+ self.assertEqual('slowdive', self.tc.get_username())
+ self.assertEqual('slowdive', self.ts.get_username())
+ self.assertEqual(True, self.tc.is_authenticated())
+ self.assertEqual(True, self.ts.is_authenticated())
def test_3a_long_banner(self):
"""
verify that a long banner doesn't mess up the handshake.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
+ host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
+ public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
- self.assert_(not event.isSet())
+ self.assertTrue(not event.isSet())
self.socks.send(LONG_BANNER)
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
+ self.assertTrue(event.isSet())
+ self.assertTrue(self.ts.is_active())
def test_4_special(self):
"""
@@ -219,10 +219,10 @@ class TransportTest(ParamikoTest):
options.ciphers = ('aes256-cbc',)
options.digests = ('hmac-md5-96',)
self.setup_test_server(client_options=force_algorithms)
- self.assertEquals('aes256-cbc', self.tc.local_cipher)
- self.assertEquals('aes256-cbc', self.tc.remote_cipher)
- self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
- self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
+ self.assertEqual('aes256-cbc', self.tc.local_cipher)
+ self.assertEqual('aes256-cbc', self.tc.remote_cipher)
+ self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
+ self.assertEqual(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
self.tc.renegotiate_keys()
@@ -233,10 +233,10 @@ class TransportTest(ParamikoTest):
verify that the keepalive will be sent.
"""
self.setup_test_server()
- self.assertEquals(None, getattr(self.server, '_global_request', None))
+ self.assertEqual(None, getattr(self.server, '_global_request', None))
self.tc.set_keepalive(1)
time.sleep(2)
- self.assertEquals('keepalive@lag.net', self.server._global_request)
+ self.assertEqual('keepalive@lag.net', self.server._global_request)
def test_6_exec_command(self):
"""
@@ -248,8 +248,8 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
try:
chan.exec_command('no')
- self.assert_(False)
- except SSHException, x:
+ self.assertTrue(False)
+ except SSHException:
pass
chan = self.tc.open_session()
@@ -260,11 +260,11 @@ class TransportTest(ParamikoTest):
schan.close()
f = chan.makefile()
- self.assertEquals('Hello there.\n', f.readline())
- self.assertEquals('', f.readline())
+ self.assertEqual('Hello there.\n', f.readline())
+ self.assertEqual('', f.readline())
f = chan.makefile_stderr()
- self.assertEquals('This is on stderr.\n', f.readline())
- self.assertEquals('', f.readline())
+ self.assertEqual('This is on stderr.\n', f.readline())
+ self.assertEqual('', f.readline())
# now try it with combined stdout/stderr
chan = self.tc.open_session()
@@ -276,9 +276,9 @@ class TransportTest(ParamikoTest):
chan.set_combine_stderr(True)
f = chan.makefile()
- self.assertEquals('Hello there.\n', f.readline())
- self.assertEquals('This is on stderr.\n', f.readline())
- self.assertEquals('', f.readline())
+ self.assertEqual('Hello there.\n', f.readline())
+ self.assertEqual('This is on stderr.\n', f.readline())
+ self.assertEqual('', f.readline())
def test_7_invoke_shell(self):
"""
@@ -290,9 +290,9 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
chan.send('communist j. cat\n')
f = schan.makefile()
- self.assertEquals('communist j. cat\n', f.readline())
+ self.assertEqual('communist j. cat\n', f.readline())
chan.close()
- self.assertEquals('', f.readline())
+ self.assertEqual('', f.readline())
def test_8_channel_exception(self):
"""
@@ -302,8 +302,8 @@ class TransportTest(ParamikoTest):
try:
chan = self.tc.open_channel('bogus')
self.fail('expected exception')
- except ChannelException, x:
- self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
+ except ChannelException as e:
+ self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
def test_9_exit_status(self):
"""
@@ -315,7 +315,7 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
chan.exec_command('yes')
schan.send('Hello there.\n')
- self.assert_(not chan.exit_status_ready())
+ self.assertTrue(not chan.exit_status_ready())
# trigger an EOF
schan.shutdown_read()
schan.shutdown_write()
@@ -323,15 +323,15 @@ class TransportTest(ParamikoTest):
schan.close()
f = chan.makefile()
- self.assertEquals('Hello there.\n', f.readline())
- self.assertEquals('', f.readline())
+ self.assertEqual('Hello there.\n', f.readline())
+ self.assertEqual('', f.readline())
count = 0
while not chan.exit_status_ready():
time.sleep(0.1)
count += 1
if count > 50:
raise Exception("timeout")
- self.assertEquals(23, chan.recv_exit_status())
+ self.assertEqual(23, chan.recv_exit_status())
chan.close()
def test_A_select(self):
@@ -345,9 +345,9 @@ class TransportTest(ParamikoTest):
# nothing should be ready
r, w, e = select.select([chan], [], [], 0.1)
- self.assertEquals([], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
+ self.assertEqual([], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
schan.send('hello\n')
@@ -357,17 +357,17 @@ class TransportTest(ParamikoTest):
if chan in r:
break
time.sleep(0.1)
- self.assertEquals([chan], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
+ self.assertEqual([chan], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
- self.assertEquals('hello\n', chan.recv(6))
+ self.assertEqual(b'hello\n', chan.recv(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
- self.assertEquals([], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
+ self.assertEqual([], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
schan.close()
@@ -377,17 +377,17 @@ class TransportTest(ParamikoTest):
if chan in r:
break
time.sleep(0.1)
- self.assertEquals([chan], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
- self.assertEquals('', chan.recv(16))
+ self.assertEqual([chan], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
+ self.assertEqual(bytes(), chan.recv(16))
# make sure the pipe is still open for now...
p = chan._pipe
- self.assertEquals(False, p._closed)
+ self.assertEqual(False, p._closed)
chan.close()
# ...and now is closed.
- self.assertEquals(True, p._closed)
+ self.assertEqual(True, p._closed)
def test_B_renegotiate(self):
"""
@@ -399,17 +399,17 @@ class TransportTest(ParamikoTest):
chan.exec_command('yes')
schan = self.ts.accept(1.0)
- self.assertEquals(self.tc.H, self.tc.session_id)
+ self.assertEqual(self.tc.H, self.tc.session_id)
for i in range(20):
chan.send('x' * 1024)
chan.close()
# allow a few seconds for the rekeying to complete
- for i in xrange(50):
+ for i in range(50):
if self.tc.H != self.tc.session_id:
break
time.sleep(0.1)
- self.assertNotEquals(self.tc.H, self.tc.session_id)
+ self.assertNotEqual(self.tc.H, self.tc.session_id)
schan.close()
@@ -428,8 +428,8 @@ class TransportTest(ParamikoTest):
chan.send('x' * 1024)
bytes2 = self.tc.packetizer._Packetizer__sent_bytes
# tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :)
- self.assert_(bytes2 - bytes < 1024)
- self.assertEquals(52, bytes2 - bytes)
+ self.assertTrue(bytes2 - bytes < 1024)
+ self.assertEqual(52, bytes2 - bytes)
chan.close()
schan.close()
@@ -444,24 +444,25 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
requested = []
- def handler(c, (addr, port)):
+ def handler(c, addr_port):
+ addr, port = addr_port
requested.append((addr, port))
self.tc._queue_incoming_channel(c)
- self.assertEquals(None, getattr(self.server, '_x11_screen_number', None))
+ self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
cookie = chan.request_x11(0, single_connection=True, handler=handler)
- self.assertEquals(0, self.server._x11_screen_number)
- self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
- self.assertEquals(cookie, self.server._x11_auth_cookie)
- self.assertEquals(True, self.server._x11_single_connection)
+ self.assertEqual(0, self.server._x11_screen_number)
+ self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
+ self.assertEqual(cookie, self.server._x11_auth_cookie)
+ self.assertEqual(True, self.server._x11_single_connection)
x11_server = self.ts.open_x11_channel(('localhost', 6093))
x11_client = self.tc.accept()
- self.assertEquals('localhost', requested[0][0])
- self.assertEquals(6093, requested[0][1])
+ self.assertEqual('localhost', requested[0][0])
+ self.assertEqual(6093, requested[0][1])
x11_server.send('hello')
- self.assertEquals('hello', x11_client.recv(5))
+ self.assertEqual(b'hello', x11_client.recv(5))
x11_server.close()
x11_client.close()
@@ -479,13 +480,13 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
requested = []
- def handler(c, (origin_addr, origin_port), (server_addr, server_port)):
- requested.append((origin_addr, origin_port))
- requested.append((server_addr, server_port))
+ def handler(c, origin_addr_port, server_addr_port):
+ requested.append(origin_addr_port)
+ requested.append(server_addr_port)
self.tc._queue_incoming_channel(c)
port = self.tc.request_port_forward('127.0.0.1', 0, handler)
- self.assertEquals(port, self.server._listen.getsockname()[1])
+ self.assertEqual(port, self.server._listen.getsockname()[1])
cs = socket.socket()
cs.connect(('127.0.0.1', port))
@@ -494,7 +495,7 @@ class TransportTest(ParamikoTest):
cch = self.tc.accept()
sch.send('hello')
- self.assertEquals('hello', cch.recv(5))
+ self.assertEqual(b'hello', cch.recv(5))
sch.close()
cch.close()
ss.close()
@@ -526,12 +527,12 @@ class TransportTest(ParamikoTest):
cch.connect(self.server._tcpip_dest)
ss, _ = greeting_server.accept()
- ss.send('Hello!\n')
+ ss.send(b'Hello!\n')
ss.close()
sch.send(cch.recv(8192))
sch.close()
- self.assertEquals('Hello!\n', cs.recv(7))
+ self.assertEqual(b'Hello!\n', cs.recv(7))
cs.close()
def test_G_stderr_select(self):
@@ -546,9 +547,9 @@ class TransportTest(ParamikoTest):
# nothing should be ready
r, w, e = select.select([chan], [], [], 0.1)
- self.assertEquals([], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
+ self.assertEqual([], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
schan.send_stderr('hello\n')
@@ -558,17 +559,17 @@ class TransportTest(ParamikoTest):
if chan in r:
break
time.sleep(0.1)
- self.assertEquals([chan], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
+ self.assertEqual([chan], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
- self.assertEquals('hello\n', chan.recv_stderr(6))
+ self.assertEqual(b'hello\n', chan.recv_stderr(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
- self.assertEquals([], r)
- self.assertEquals([], w)
- self.assertEquals([], e)
+ self.assertEqual([], r)
+ self.assertEqual([], w)
+ self.assertEqual([], e)
schan.close()
chan.close()
@@ -582,7 +583,7 @@ class TransportTest(ParamikoTest):
chan.invoke_shell()
schan = self.ts.accept(1.0)
- self.assertEquals(chan.send_ready(), True)
+ self.assertEqual(chan.send_ready(), True)
total = 0
K = '*' * 1024
while total < 1024 * 1024:
@@ -590,11 +591,11 @@ class TransportTest(ParamikoTest):
total += len(K)
if not chan.send_ready():
break
- self.assert_(total < 1024 * 1024)
+ self.assertTrue(total < 1024 * 1024)
schan.close()
chan.close()
- self.assertEquals(chan.send_ready(), True)
+ self.assertEqual(chan.send_ready(), True)
def test_I_rekey_deadlock(self):
"""
@@ -657,7 +658,7 @@ class TransportTest(ParamikoTest):
def run(self):
try:
- for i in xrange(1, 1+self.iterations):
+ for i in range(1, 1+self.iterations):
if self.done_event.isSet():
break
self.watchdog_event.set()
@@ -706,7 +707,7 @@ class TransportTest(ParamikoTest):
# Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
# before responding to the incoming MSG_KEXINIT.
m2 = Message()
- m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
+ m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m2.add_int(chan.remote_chanid)
m2.add_int(1) # bytes to add
self._send_message(m2)
diff --git a/tests/test_util.py b/tests/test_util.py
index 12677a9b..4f85c391 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -21,15 +21,15 @@ Some unit tests for utility functions.
"""
from binascii import hexlify
-import cStringIO
import errno
import os
import unittest
from Crypto.Hash import SHA
import paramiko.util
from paramiko.util import lookup_ssh_host_config as host_config
+from paramiko.py3compat import StringIO, byte_ord, b
-from util import ParamikoTest
+from tests.util import ParamikoTest
test_config_file = """\
Host *
@@ -65,7 +65,7 @@ class UtilTest(ParamikoTest):
"""
verify that all the classes can be imported from paramiko.
"""
- symbols = globals().keys()
+ symbols = list(globals().keys())
self.assertTrue('Transport' in symbols)
self.assertTrue('SSHClient' in symbols)
self.assertTrue('MissingHostKeyPolicy' in symbols)
@@ -101,9 +101,9 @@ class UtilTest(ParamikoTest):
def test_2_parse_config(self):
global test_config_file
- f = cStringIO.StringIO(test_config_file)
+ f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEquals(config._config,
+ self.assertEqual(config._config,
[{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}},
{'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}},
{'host': ['*'], 'config': {'crazy': 'something dumb '}},
@@ -111,7 +111,7 @@ class UtilTest(ParamikoTest):
def test_3_host_config(self):
global test_config_file
- f = cStringIO.StringIO(test_config_file)
+ f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
for host, values in {
@@ -131,27 +131,26 @@ class UtilTest(ParamikoTest):
hostname=host,
identityfile=[os.path.expanduser("~/.ssh/id_rsa")]
)
- self.assertEquals(
+ self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
values
)
def test_4_generate_key_bytes(self):
- x = paramiko.util.generate_key_bytes(SHA, 'ABCDEFGH', 'This is my secret passphrase.', 64)
- hex = ''.join(['%02x' % ord(c) for c in x])
- self.assertEquals(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
+ x = paramiko.util.generate_key_bytes(SHA, b'ABCDEFGH', 'This is my secret passphrase.', 64)
+ hex = ''.join(['%02x' % byte_ord(c) for c in x])
+ self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
def test_5_host_keys(self):
- f = open('hostfile.temp', 'w')
- f.write(test_hosts_file)
- f.close()
+ with open('hostfile.temp', 'w') as f:
+ f.write(test_hosts_file)
try:
hostdict = paramiko.util.load_host_keys('hostfile.temp')
- self.assertEquals(2, len(hostdict))
- self.assertEquals(1, len(hostdict.values()[0]))
- self.assertEquals(1, len(hostdict.values()[1]))
+ self.assertEqual(2, len(hostdict))
+ self.assertEqual(1, len(list(hostdict.values())[0]))
+ self.assertEqual(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
- self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
+ self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
finally:
os.unlink('hostfile.temp')
@@ -159,7 +158,7 @@ class UtilTest(ParamikoTest):
from paramiko.common import rng
# just verify that we can pull out 32 bytes and not get an exception.
x = rng.read(32)
- self.assertEquals(len(x), 32)
+ self.assertEqual(len(x), 32)
def test_7_host_config_expose_issue_33(self):
test_config_file = """
@@ -172,16 +171,16 @@ Host *.example.com
Host *
Port 3333
"""
- f = cStringIO.StringIO(test_config_file)
+ f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
host = 'www13.example.com'
- self.assertEquals(
+ self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '22'}
)
def test_8_eintr_retry(self):
- self.assertEquals('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
+ self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
# Variables that are set by raises_intr
intr_errors_remaining = [3]
@@ -192,8 +191,8 @@ Host *
intr_errors_remaining[0] -= 1
raise IOError(errno.EINTR, 'file', 'interrupted system call')
self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None)
- self.assertEquals(0, intr_errors_remaining[0])
- self.assertEquals(4, call_count[0])
+ self.assertEqual(0, intr_errors_remaining[0])
+ self.assertEqual(4, call_count[0])
def raises_ioerror_not_eintr():
raise IOError(errno.ENOENT, 'file', 'file not found')
@@ -216,10 +215,10 @@ Host space-delimited
Host equals-delimited
ProxyCommand=foo bar=biz baz
"""
- f = cStringIO.StringIO(conf)
+ f = StringIO(conf)
config = paramiko.util.parse_ssh_config(f)
for host in ('space-delimited', 'equals-delimited'):
- self.assertEquals(
+ self.assertEqual(
host_config(host, config)['proxycommand'],
'foo bar=biz baz'
)
@@ -228,7 +227,7 @@ Host equals-delimited
"""
ProxyCommand should perform interpolation on the value
"""
- config = paramiko.util.parse_ssh_config(cStringIO.StringIO("""
+ config = paramiko.util.parse_ssh_config(StringIO("""
Host specific
Port 37
ProxyCommand host %h port %p lol
@@ -245,7 +244,7 @@ Host *
('specific', "host specific port 37 lol"),
('portonly', "host portonly port 155"),
):
- self.assertEquals(
+ self.assertEqual(
host_config(host, config)['proxycommand'],
val
)
@@ -264,10 +263,10 @@ Host www13.*
Host *
Port 3333
"""
- f = cStringIO.StringIO(test_config_file)
+ f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
host = 'www13.example.com'
- self.assertEquals(
+ self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '8080'}
)
@@ -293,9 +292,9 @@ ProxyCommand foo=bar:%h-%p
'foo=bar:proxy-without-equal-divisor-22'}
}.items():
- f = cStringIO.StringIO(test_config_file)
+ f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEquals(
+ self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
values
)
@@ -323,9 +322,9 @@ IdentityFile id_dsa22
'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']}
}.items():
- f = cStringIO.StringIO(test_config_file)
+ f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
- self.assertEquals(
+ self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
values
)
@@ -338,5 +337,5 @@ IdentityFile id_dsa22
AddressFamily inet
IdentityFile something_%l_using_fqdn
"""
- config = paramiko.util.parse_ssh_config(cStringIO.StringIO(test_config))
+ config = paramiko.util.parse_ssh_config(StringIO(test_config))
assert config.lookup('meh') # will die during lookup() if bug regresses
diff --git a/tests/util.py b/tests/util.py
index 2e0be087..66d2696c 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -1,5 +1,8 @@
+import os
import unittest
+root_path = os.path.dirname(os.path.realpath(__file__))
+
class ParamikoTest(unittest.TestCase):
# for Python 2.3 and below
@@ -8,3 +11,7 @@ class ParamikoTest(unittest.TestCase):
if not hasattr(unittest.TestCase, 'assertFalse'):
assertFalse = unittest.TestCase.failIf
+
+def test_path(filename):
+ return os.path.join(root_path, filename)
+
diff --git a/tox.ini b/tox.ini
index af4fbf20..55e3fe64 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = py25,py26,py27
+envlist = py25,py26,py27,py32,py33
[testenv]
commands = pip install --use-mirrors -q -r tox-requirements.txt