diff options
author | Chris Rose <offline@offby1.net> | 2018-05-17 10:13:38 -0400 |
---|---|---|
committer | Chris Rose <offline@offby1.net> | 2018-05-17 10:13:38 -0400 |
commit | 7f2c35052183b400827d9949a68b41c90f90a32d (patch) | |
tree | fea4a1ec04b7ee3ced14d61e8b6cf3f479e22704 | |
parent | 52551321a2297bdb966869fa719e584c868dd857 (diff) |
Blacken Paramiko on 2.4
63 files changed, 4042 insertions, 3032 deletions
diff --git a/paramiko/__init__.py b/paramiko/__init__.py index c4c69a45..58763b47 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -21,15 +21,22 @@ import sys from paramiko._version import __version__, __version_info__ from paramiko.transport import SecurityOptions, Transport from paramiko.client import ( - SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, + SSHClient, + MissingHostKeyPolicy, + AutoAddPolicy, + RejectPolicy, WarningPolicy, ) from paramiko.auth_handler import AuthHandler from paramiko.ssh_gss import GSSAuth, GSS_AUTH_AVAILABLE, GSS_EXCEPTIONS from paramiko.channel import Channel, ChannelFile from paramiko.ssh_exception import ( - SSHException, PasswordRequiredException, BadAuthenticationType, - ChannelException, BadHostKeyException, AuthenticationException, + SSHException, + PasswordRequiredException, + BadAuthenticationType, + ChannelException, + BadHostKeyException, + AuthenticationException, ProxyCommandFailure, ) from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery @@ -54,14 +61,25 @@ from paramiko.config import SSHConfig from paramiko.proxy import ProxyCommand from paramiko.common import ( - AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, OPEN_SUCCEEDED, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, - OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE, + AUTH_SUCCESSFUL, + AUTH_PARTIALLY_SUCCESSFUL, + AUTH_FAILED, + OPEN_SUCCEEDED, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_FAILED_CONNECT_FAILED, + OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, + OPEN_FAILED_RESOURCE_SHORTAGE, ) from paramiko.sftp import ( - SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, - SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, + SFTP_OK, + SFTP_EOF, + SFTP_NO_SUCH_FILE, + SFTP_PERMISSION_DENIED, + SFTP_FAILURE, + SFTP_BAD_MESSAGE, + SFTP_NO_CONNECTION, + SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED, ) @@ -72,41 +90,41 @@ __author__ = "Jeff Forcier <jeff@bitprophet.org>" __license__ = "GNU Lesser General Public License (LGPL)" __all__ = [ - 'Transport', - 'SSHClient', - 'MissingHostKeyPolicy', - 'AutoAddPolicy', - 'RejectPolicy', - 'WarningPolicy', - 'SecurityOptions', - 'SubsystemHandler', - 'Channel', - 'PKey', - 'RSAKey', - 'DSSKey', - 'Message', - 'SSHException', - 'AuthenticationException', - 'PasswordRequiredException', - 'BadAuthenticationType', - 'ChannelException', - 'BadHostKeyException', - 'ProxyCommand', - 'ProxyCommandFailure', - 'SFTP', - 'SFTPFile', - 'SFTPHandle', - 'SFTPClient', - 'SFTPServer', - 'SFTPError', - 'SFTPAttributes', - 'SFTPServerInterface', - 'ServerInterface', - 'BufferedFile', - 'Agent', - 'AgentKey', - 'HostKeys', - 'SSHConfig', - 'util', - 'io_sleep', + "Transport", + "SSHClient", + "MissingHostKeyPolicy", + "AutoAddPolicy", + "RejectPolicy", + "WarningPolicy", + "SecurityOptions", + "SubsystemHandler", + "Channel", + "PKey", + "RSAKey", + "DSSKey", + "Message", + "SSHException", + "AuthenticationException", + "PasswordRequiredException", + "BadAuthenticationType", + "ChannelException", + "BadHostKeyException", + "ProxyCommand", + "ProxyCommandFailure", + "SFTP", + "SFTPFile", + "SFTPHandle", + "SFTPClient", + "SFTPServer", + "SFTPError", + "SFTPAttributes", + "SFTPServerInterface", + "ServerInterface", + "BufferedFile", + "Agent", + "AgentKey", + "HostKeys", + "SSHConfig", + "util", + "io_sleep", ] diff --git a/paramiko/_version.py b/paramiko/_version.py index a5db89b8..95e86b6a 100644 --- a/paramiko/_version.py +++ b/paramiko/_version.py @@ -1,2 +1,2 @@ __version_info__ = (2, 4, 1) -__version__ = '.'.join(map(str, __version_info__)) +__version__ = ".".join(map(str, __version_info__)) diff --git a/paramiko/_winapi.py b/paramiko/_winapi.py index a13d7e87..c996ec46 100644 --- a/paramiko/_winapi.py +++ b/paramiko/_winapi.py @@ -15,6 +15,7 @@ from paramiko.py3compat import u, builtins ###################### # jaraco.windows.error + def format_system_message(errno): """ Call FormatMessage with a system error number to retrieve @@ -77,7 +78,7 @@ class WindowsError(builtins.WindowsError): return self.message def __repr__(self): - return '{self.__class__.__name__}({self.winerror})'.format(**vars()) + return "{self.__class__.__name__}({self.winerror})".format(**vars()) def handle_nonzero_success(result): @@ -124,11 +125,7 @@ UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE, RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory -RtlMoveMemory.argtypes = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_size_t, -) +RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t) ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL, @@ -140,6 +137,7 @@ class MemoryMap(object): """ A memory map object which can have security attributes overridden. """ + def __init__(self, name, length, security_attributes=None): self.name = name self.length = length @@ -149,14 +147,20 @@ class MemoryMap(object): def __enter__(self): p_SA = ( ctypes.byref(self.security_attributes) - if self.security_attributes else None + if self.security_attributes + else None ) INVALID_HANDLE_VALUE = -1 PAGE_READWRITE = 0x4 FILE_MAP_WRITE = 0x2 filemap = ctypes.windll.kernel32.CreateFileMappingW( - INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length, - u(self.name)) + INVALID_HANDLE_VALUE, + p_SA, + PAGE_READWRITE, + 0, + self.length, + u(self.name), + ) handle_nonzero_success(filemap) if filemap == INVALID_HANDLE_VALUE: raise Exception("Failed to create file mapping") @@ -220,41 +224,45 @@ POLICY_LOOKUP_NAMES = 0x00000800 POLICY_NOTIFICATION = 0x00001000 POLICY_ALL_ACCESS = ( - STANDARD_RIGHTS_REQUIRED | - POLICY_VIEW_LOCAL_INFORMATION | - POLICY_VIEW_AUDIT_INFORMATION | - POLICY_GET_PRIVATE_INFORMATION | - POLICY_TRUST_ADMIN | - POLICY_CREATE_ACCOUNT | - POLICY_CREATE_SECRET | - POLICY_CREATE_PRIVILEGE | - POLICY_SET_DEFAULT_QUOTA_LIMITS | - POLICY_SET_AUDIT_REQUIREMENTS | - POLICY_AUDIT_LOG_ADMIN | - POLICY_SERVER_ADMIN | - POLICY_LOOKUP_NAMES) + STANDARD_RIGHTS_REQUIRED + | POLICY_VIEW_LOCAL_INFORMATION + | POLICY_VIEW_AUDIT_INFORMATION + | POLICY_GET_PRIVATE_INFORMATION + | POLICY_TRUST_ADMIN + | POLICY_CREATE_ACCOUNT + | POLICY_CREATE_SECRET + | POLICY_CREATE_PRIVILEGE + | POLICY_SET_DEFAULT_QUOTA_LIMITS + | POLICY_SET_AUDIT_REQUIREMENTS + | POLICY_AUDIT_LOG_ADMIN + | POLICY_SERVER_ADMIN + | POLICY_LOOKUP_NAMES +) POLICY_READ = ( - STANDARD_RIGHTS_READ | - POLICY_VIEW_AUDIT_INFORMATION | - POLICY_GET_PRIVATE_INFORMATION) + STANDARD_RIGHTS_READ + | POLICY_VIEW_AUDIT_INFORMATION + | POLICY_GET_PRIVATE_INFORMATION +) POLICY_WRITE = ( - STANDARD_RIGHTS_WRITE | - POLICY_TRUST_ADMIN | - POLICY_CREATE_ACCOUNT | - POLICY_CREATE_SECRET | - POLICY_CREATE_PRIVILEGE | - POLICY_SET_DEFAULT_QUOTA_LIMITS | - POLICY_SET_AUDIT_REQUIREMENTS | - POLICY_AUDIT_LOG_ADMIN | - POLICY_SERVER_ADMIN) + STANDARD_RIGHTS_WRITE + | POLICY_TRUST_ADMIN + | POLICY_CREATE_ACCOUNT + | POLICY_CREATE_SECRET + | POLICY_CREATE_PRIVILEGE + | POLICY_SET_DEFAULT_QUOTA_LIMITS + | POLICY_SET_AUDIT_REQUIREMENTS + | POLICY_AUDIT_LOG_ADMIN + | POLICY_SERVER_ADMIN +) POLICY_EXECUTE = ( - STANDARD_RIGHTS_EXECUTE | - POLICY_VIEW_LOCAL_INFORMATION | - POLICY_LOOKUP_NAMES) + STANDARD_RIGHTS_EXECUTE + | POLICY_VIEW_LOCAL_INFORMATION + | POLICY_LOOKUP_NAMES +) class TokenAccess: @@ -268,8 +276,7 @@ class TokenInformationClass: class TOKEN_USER(ctypes.Structure): num = 1 _fields_ = [ - ('SID', ctypes.c_void_p), - ('ATTRIBUTES', ctypes.wintypes.DWORD), + ("SID", ctypes.c_void_p), ("ATTRIBUTES", ctypes.wintypes.DWORD) ] @@ -290,13 +297,13 @@ class SECURITY_DESCRIPTOR(ctypes.Structure): REVISION = 1 _fields_ = [ - ('Revision', ctypes.c_ubyte), - ('Sbz1', ctypes.c_ubyte), - ('Control', SECURITY_DESCRIPTOR_CONTROL), - ('Owner', ctypes.c_void_p), - ('Group', ctypes.c_void_p), - ('Sacl', ctypes.c_void_p), - ('Dacl', ctypes.c_void_p), + ("Revision", ctypes.c_ubyte), + ("Sbz1", ctypes.c_ubyte), + ("Control", SECURITY_DESCRIPTOR_CONTROL), + ("Owner", ctypes.c_void_p), + ("Group", ctypes.c_void_p), + ("Sacl", ctypes.c_void_p), + ("Dacl", ctypes.c_void_p), ] @@ -309,9 +316,9 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): } SECURITY_ATTRIBUTES; """ _fields_ = [ - ('nLength', ctypes.wintypes.DWORD), - ('lpSecurityDescriptor', ctypes.c_void_p), - ('bInheritHandle', ctypes.wintypes.BOOL), + ("nLength", ctypes.wintypes.DWORD), + ("lpSecurityDescriptor", ctypes.c_void_p), + ("bInheritHandle", ctypes.wintypes.BOOL), ] def __init__(self, *args, **kwargs): @@ -329,9 +336,7 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = ( - ctypes.POINTER(SECURITY_DESCRIPTOR), - ctypes.c_void_p, - ctypes.wintypes.BOOL, + ctypes.POINTER(SECURITY_DESCRIPTOR), ctypes.c_void_p, ctypes.wintypes.BOOL ) ######################### @@ -343,21 +348,30 @@ def GetTokenInformation(token, information_class): Given a token, get the token information for it. """ data_size = ctypes.wintypes.DWORD() - ctypes.windll.advapi32.GetTokenInformation(token, information_class.num, - 0, 0, ctypes.byref(data_size)) + ctypes.windll.advapi32.GetTokenInformation( + token, information_class.num, 0, 0, ctypes.byref(data_size) + ) data = ctypes.create_string_buffer(data_size.value) - handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token, - information_class.num, - ctypes.byref(data), ctypes.sizeof(data), - ctypes.byref(data_size))) + handle_nonzero_success( + ctypes.windll.advapi32.GetTokenInformation( + token, + information_class.num, + ctypes.byref(data), + ctypes.sizeof(data), + ctypes.byref(data_size), + ) + ) return ctypes.cast(data, ctypes.POINTER(TOKEN_USER)).contents def OpenProcessToken(proc_handle, access): result = ctypes.wintypes.HANDLE() proc_handle = ctypes.wintypes.HANDLE(proc_handle) - handle_nonzero_success(ctypes.windll.advapi32.OpenProcessToken( - proc_handle, access, ctypes.byref(result))) + handle_nonzero_success( + ctypes.windll.advapi32.OpenProcessToken( + proc_handle, access, ctypes.byref(result) + ) + ) return result @@ -366,8 +380,7 @@ def get_current_user(): Return a TOKEN_USER for the owner of this process. """ process = OpenProcessToken( - ctypes.windll.kernel32.GetCurrentProcess(), - TokenAccess.TOKEN_QUERY, + ctypes.windll.kernel32.GetCurrentProcess(), TokenAccess.TOKEN_QUERY ) return GetTokenInformation(process, TOKEN_USER) @@ -389,8 +402,10 @@ def get_security_attributes_for_user(user=None): SA.descriptor = SD SA.bInheritHandle = 1 - ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD), - SECURITY_DESCRIPTOR.REVISION) - ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD), - user.SID, 0) + ctypes.windll.advapi32.InitializeSecurityDescriptor( + ctypes.byref(SD), SECURITY_DESCRIPTOR.REVISION + ) + ctypes.windll.advapi32.SetSecurityDescriptorOwner( + ctypes.byref(SD), user.SID, 0 + ) return SA diff --git a/paramiko/agent.py b/paramiko/agent.py index 7a4dde21..00baf85c 100644 --- a/paramiko/agent.py +++ b/paramiko/agent.py @@ -43,8 +43,8 @@ cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13) SSH2_AGENT_SIGN_RESPONSE = 14 - class AgentSSH(object): + def __init__(self): self._conn = None self._keys = () @@ -65,7 +65,7 @@ class AgentSSH(object): self._conn = conn ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES) if ptype != SSH2_AGENT_IDENTITIES_ANSWER: - raise SSHException('could not get keys from ssh-agent') + raise SSHException("could not get keys from ssh-agent") keys = [] for i in range(result.get_int()): keys.append(AgentKey(self, result.get_binary())) @@ -80,19 +80,19 @@ class AgentSSH(object): def _send_message(self, msg): msg = asbytes(msg) - self._conn.send(struct.pack('>I', len(msg)) + msg) + self._conn.send(struct.pack(">I", len(msg)) + msg) l = self._read_all(4) - msg = Message(self._read_all(struct.unpack('>I', l)[0])) + msg = Message(self._read_all(struct.unpack(">I", l)[0])) return ord(msg.get_byte()), msg def _read_all(self, wanted): result = self._conn.recv(wanted) while len(result) < wanted: if len(result) == 0: - raise SSHException('lost ssh-agent') + raise SSHException("lost ssh-agent") extra = self._conn.recv(wanted - len(result)) if len(extra) == 0: - raise SSHException('lost ssh-agent') + raise SSHException("lost ssh-agent") result += extra return result @@ -101,6 +101,7 @@ class AgentProxyThread(threading.Thread): """ Class in charge of communication between two channels. """ + def __init__(self, agent): threading.Thread.__init__(self, target=self.run) self._agent = agent @@ -116,10 +117,10 @@ class AgentProxyThread(threading.Thread): self.__addr = addr self._agent.connect() if ( - not isinstance(self._agent, int) and - ( - self._agent._conn is None or - not hasattr(self._agent._conn, 'fileno') + not isinstance(self._agent, int) + and ( + self._agent._conn is None + or not hasattr(self._agent._conn, "fileno") ) ): raise AuthenticationException("Unable to connect to SSH agent") @@ -130,6 +131,7 @@ class AgentProxyThread(threading.Thread): def _communicate(self): import fcntl + oldflags = fcntl.fcntl(self.__inr, fcntl.F_GETFL) fcntl.fcntl(self.__inr, fcntl.F_SETFL, oldflags | os.O_NONBLOCK) while not self._exit: @@ -162,6 +164,7 @@ class AgentLocalProxy(AgentProxyThread): Class to be used when wanting to ask a local SSH Agent being asked from a remote fake agent (so use a unix socket for ex.) """ + def __init__(self, agent): AgentProxyThread.__init__(self, agent) @@ -185,6 +188,7 @@ class AgentRemoteProxy(AgentProxyThread): """ Class to be used when wanting to ask a remote SSH Agent """ + def __init__(self, agent, chan): AgentProxyThread.__init__(self, agent) self.__chan = chan @@ -205,6 +209,7 @@ class AgentClientProxy(object): the remote fake agent and the local agent #. Communication occurs ... """ + def __init__(self, chanRemote): self._conn = None self.__chanR = chanRemote @@ -218,16 +223,18 @@ class AgentClientProxy(object): """ Method automatically called by ``AgentProxyThread.run``. """ - if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): + if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: retry_on_signal( - lambda: conn.connect(os.environ['SSH_AUTH_SOCK'])) + lambda: conn.connect(os.environ["SSH_AUTH_SOCK"]) + ) except: # probably a dangling env var: the ssh agent is gone return - elif sys.platform == 'win32': + elif sys.platform == "win32": import paramiko.win_pageant as win_pageant + if win_pageant.can_talk_to_agent(): conn = win_pageant.PageantConnection() else: @@ -255,12 +262,13 @@ class AgentServerProxy(AgentSSH): :raises: `.SSHException` -- mostly if we lost the agent """ + def __init__(self, t): AgentSSH.__init__(self) self.__t = t - self._dir = tempfile.mkdtemp('sshproxy') + self._dir = tempfile.mkdtemp("sshproxy") os.chmod(self._dir, stat.S_IRWXU) - self._file = self._dir + '/sshproxy.ssh' + self._file = self._dir + "/sshproxy.ssh" self.thread = AgentLocalProxy(self) self.thread.start() @@ -270,8 +278,8 @@ class AgentServerProxy(AgentSSH): def connect(self): conn_sock = self.__t.open_forward_agent_channel() if conn_sock is None: - raise SSHException('lost ssh-agent') - conn_sock.set_name('auth-agent') + raise SSHException("lost ssh-agent") + conn_sock.set_name("auth-agent") self._connect(conn_sock) def close(self): @@ -292,7 +300,7 @@ class AgentServerProxy(AgentSSH): :return: a dict containing the ``SSH_AUTH_SOCK`` environnement variables """ - return {'SSH_AUTH_SOCK': self._get_filename()} + return {"SSH_AUTH_SOCK": self._get_filename()} def _get_filename(self): return self._file @@ -319,6 +327,7 @@ class AgentRequestHandler(object): # the remote end. session.exec_command("git clone https://my.git.repository/") """ + def __init__(self, chanClient): self._conn = None self.__chanC = chanClient @@ -350,18 +359,20 @@ class Agent(AgentSSH): :raises: `.SSHException` -- if an SSH agent is found, but speaks an incompatible protocol """ + def __init__(self): AgentSSH.__init__(self) - if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): + if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: - conn.connect(os.environ['SSH_AUTH_SOCK']) + conn.connect(os.environ["SSH_AUTH_SOCK"]) except: # probably a dangling env var: the ssh agent is gone return - elif sys.platform == 'win32': + elif sys.platform == "win32": from . import win_pageant + if win_pageant.can_talk_to_agent(): conn = win_pageant.PageantConnection() else: @@ -384,6 +395,7 @@ class AgentKey(PKey): authenticating to a remote server (signing). Most other key operations work as expected. """ + def __init__(self, agent, blob): self.agent = agent self.blob = blob @@ -407,5 +419,5 @@ class AgentKey(PKey): msg.add_int(0) ptype, result = self.agent._send_message(msg) if ptype != SSH2_AGENT_SIGN_RESPONSE: - raise SSHException('key cannot be used for signing') + raise SSHException("key cannot be used for signing") return result.get_binary() diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index a1ce5e3b..416657e2 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -24,31 +24,55 @@ import weakref import time from paramiko.common import ( - cMSG_SERVICE_REQUEST, cMSG_DISCONNECT, DISCONNECT_SERVICE_NOT_AVAILABLE, - DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, cMSG_USERAUTH_REQUEST, - cMSG_SERVICE_ACCEPT, DEBUG, AUTH_SUCCESSFUL, INFO, cMSG_USERAUTH_SUCCESS, - cMSG_USERAUTH_FAILURE, AUTH_PARTIALLY_SUCCESSFUL, - cMSG_USERAUTH_INFO_REQUEST, WARNING, AUTH_FAILED, cMSG_USERAUTH_PK_OK, - cMSG_USERAUTH_INFO_RESPONSE, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, - MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS, MSG_USERAUTH_FAILURE, - MSG_USERAUTH_BANNER, MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE, - cMSG_USERAUTH_GSSAPI_RESPONSE, cMSG_USERAUTH_GSSAPI_TOKEN, - cMSG_USERAUTH_GSSAPI_MIC, MSG_USERAUTH_GSSAPI_RESPONSE, - MSG_USERAUTH_GSSAPI_TOKEN, MSG_USERAUTH_GSSAPI_ERROR, - MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC, MSG_NAMES, - cMSG_USERAUTH_BANNER + cMSG_SERVICE_REQUEST, + cMSG_DISCONNECT, + DISCONNECT_SERVICE_NOT_AVAILABLE, + DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, + cMSG_USERAUTH_REQUEST, + cMSG_SERVICE_ACCEPT, + DEBUG, + AUTH_SUCCESSFUL, + INFO, + cMSG_USERAUTH_SUCCESS, + cMSG_USERAUTH_FAILURE, + AUTH_PARTIALLY_SUCCESSFUL, + cMSG_USERAUTH_INFO_REQUEST, + WARNING, + AUTH_FAILED, + cMSG_USERAUTH_PK_OK, + cMSG_USERAUTH_INFO_RESPONSE, + MSG_SERVICE_REQUEST, + MSG_SERVICE_ACCEPT, + MSG_USERAUTH_REQUEST, + MSG_USERAUTH_SUCCESS, + MSG_USERAUTH_FAILURE, + MSG_USERAUTH_BANNER, + MSG_USERAUTH_INFO_REQUEST, + MSG_USERAUTH_INFO_RESPONSE, + cMSG_USERAUTH_GSSAPI_RESPONSE, + cMSG_USERAUTH_GSSAPI_TOKEN, + cMSG_USERAUTH_GSSAPI_MIC, + MSG_USERAUTH_GSSAPI_RESPONSE, + MSG_USERAUTH_GSSAPI_TOKEN, + MSG_USERAUTH_GSSAPI_ERROR, + MSG_USERAUTH_GSSAPI_ERRTOK, + MSG_USERAUTH_GSSAPI_MIC, + MSG_NAMES, + cMSG_USERAUTH_BANNER, ) from paramiko.message import Message from paramiko.py3compat import bytestring from paramiko.ssh_exception import ( - SSHException, AuthenticationException, BadAuthenticationType, + SSHException, + AuthenticationException, + BadAuthenticationType, PartialAuthentication, ) from paramiko.server import InteractiveQuery from paramiko.ssh_gss import GSSAuth, GSS_EXCEPTIONS -class AuthHandler (object): +class AuthHandler(object): """ Internal class to handle the mechanics of authentication. """ @@ -58,7 +82,7 @@ class AuthHandler (object): self.username = None self.authenticated = False self.auth_event = None - self.auth_method = '' + self.auth_method = "" self.banner = None self.password = None self.private_key = None @@ -87,7 +111,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'none' + self.auth_method = "none" self.username = username self._request_auth() finally: @@ -97,7 +121,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'publickey' + self.auth_method = "publickey" self.username = username self.private_key = key self._request_auth() @@ -108,21 +132,21 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'password' + self.auth_method = "password" self.username = username self.password = password self._request_auth() finally: self.transport.lock.release() - def auth_interactive(self, username, handler, event, submethods=''): + def auth_interactive(self, username, handler, event, submethods=""): """ response_list = handler(title, instructions, prompt_list) """ self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'keyboard-interactive' + self.auth_method = "keyboard-interactive" self.username = username self.interactive_handler = handler self.submethods = submethods @@ -134,7 +158,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'gssapi-with-mic' + self.auth_method = "gssapi-with-mic" self.username = username self.gss_host = gss_host self.gss_deleg_creds = gss_deleg_creds @@ -146,7 +170,7 @@ class AuthHandler (object): self.transport.lock.acquire() try: self.auth_event = event - self.auth_method = 'gssapi-keyex' + self.auth_method = "gssapi-keyex" self.username = username self._request_auth() finally: @@ -161,15 +185,15 @@ class AuthHandler (object): def _request_auth(self): m = Message() m.add_byte(cMSG_SERVICE_REQUEST) - m.add_string('ssh-userauth') + m.add_string("ssh-userauth") self.transport._send_message(m) def _disconnect_service_not_available(self): m = Message() m.add_byte(cMSG_DISCONNECT) m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) - m.add_string('Service not available') - m.add_string('en') + m.add_string("Service not available") + m.add_string("en") self.transport._send_message(m) self.transport.close() @@ -177,8 +201,8 @@ class AuthHandler (object): m = Message() m.add_byte(cMSG_DISCONNECT) m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) - m.add_string('No more auth methods available') - m.add_string('en') + m.add_string("No more auth methods available") + m.add_string("en") self.transport._send_message(m) self.transport.close() @@ -188,7 +212,7 @@ class AuthHandler (object): m.add_byte(cMSG_USERAUTH_REQUEST) m.add_string(username) m.add_string(service) - m.add_string('publickey') + m.add_string("publickey") m.add_boolean(True) # Use certificate contents, if available, plain pubkey otherwise if key.public_blob: @@ -208,17 +232,17 @@ class AuthHandler (object): if not self.transport.is_active(): e = self.transport.get_exception() if (e is None) or issubclass(e.__class__, EOFError): - e = AuthenticationException('Authentication failed.') + e = AuthenticationException("Authentication failed.") raise e if event.is_set(): break if max_ts is not None and max_ts <= time.time(): - raise AuthenticationException('Authentication timeout.') + raise AuthenticationException("Authentication timeout.") if not self.is_authenticated(): e = self.transport.get_exception() if e is None: - e = AuthenticationException('Authentication failed.') + e = AuthenticationException("Authentication failed.") # this is horrible. Python Exception isn't yet descended from # object, so type(e) won't work. :( if issubclass(e.__class__, PartialAuthentication): @@ -228,7 +252,7 @@ class AuthHandler (object): def _parse_service_request(self, m): service = m.get_text() - if self.transport.server_mode and (service == 'ssh-userauth'): + if self.transport.server_mode and (service == "ssh-userauth"): # accepted m = Message() m.add_byte(cMSG_SERVICE_ACCEPT) @@ -247,18 +271,18 @@ class AuthHandler (object): def _parse_service_accept(self, m): service = m.get_text() - if service == 'ssh-userauth': - self._log(DEBUG, 'userauth is OK') + if service == "ssh-userauth": + self._log(DEBUG, "userauth is OK") m = Message() m.add_byte(cMSG_USERAUTH_REQUEST) m.add_string(self.username) - m.add_string('ssh-connection') + m.add_string("ssh-connection") m.add_string(self.auth_method) - if self.auth_method == 'password': + if self.auth_method == "password": m.add_boolean(False) password = bytestring(self.password) m.add_string(password) - elif self.auth_method == 'publickey': + elif self.auth_method == "publickey": m.add_boolean(True) # Use certificate contents, if available, plain pubkey # otherwise @@ -269,11 +293,12 @@ class AuthHandler (object): m.add_string(self.private_key.get_name()) m.add_string(self.private_key) blob = self._get_session_blob( - self.private_key, 'ssh-connection', self.username) + self.private_key, "ssh-connection", self.username + ) sig = self.private_key.sign_ssh_data(blob) m.add_string(sig) - elif self.auth_method == 'keyboard-interactive': - m.add_string('') + elif self.auth_method == "keyboard-interactive": + m.add_string("") m.add_string(self.submethods) elif self.auth_method == "gssapi-with-mic": sshgss = GSSAuth(self.auth_method, self.gss_deleg_creds) @@ -292,10 +317,11 @@ class AuthHandler (object): m = Message() m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN) try: - m.add_string(sshgss.ssh_init_sec_context( - self.gss_host, - mech, - self.username,)) + m.add_string( + sshgss.ssh_init_sec_context( + self.gss_host, mech, self.username + ) + ) except GSS_EXCEPTIONS as e: return self._handle_local_gss_failure(e) self.transport._send_message(m) @@ -308,7 +334,8 @@ class AuthHandler (object): self.gss_host, mech, self.username, - srv_token) + srv_token, + ) except GSS_EXCEPTIONS as e: return self._handle_local_gss_failure(e) # After this step the GSSAPI should not return any @@ -340,48 +367,55 @@ class AuthHandler (object): min_status = m.get_int() err_msg = m.get_string() m.get_string() # Lang tag - discarded - raise SSHException("""GSS-API Error: + raise SSHException( + """GSS-API Error: Major Status: {} Minor Status: {} Error Message: {} -""".format(maj_status, min_status, err_msg)) +""".format( + maj_status, min_status, err_msg + ) + ) elif ptype == MSG_USERAUTH_FAILURE: self._parse_userauth_failure(m) return else: raise SSHException( - "Received Package: {}".format(MSG_NAMES[ptype])) + "Received Package: {}".format(MSG_NAMES[ptype]) + ) elif ( - self.auth_method == 'gssapi-keyex' and - self.transport.gss_kex_used + self.auth_method == "gssapi-keyex" + and self.transport.gss_kex_used ): kexgss = self.transport.kexgss_ctxt kexgss.set_username(self.username) mic_token = kexgss.ssh_get_mic(self.transport.session_id) m.add_string(mic_token) - elif self.auth_method == 'none': + elif self.auth_method == "none": pass else: raise SSHException( - 'Unknown auth method "{}"'.format(self.auth_method)) + 'Unknown auth method "{}"'.format(self.auth_method) + ) self.transport._send_message(m) else: self._log( - DEBUG, - 'Service request "{}" accepted (?)'.format(service)) + DEBUG, 'Service request "{}" accepted (?)'.format(service) + ) def _send_auth_result(self, username, method, result): # okay, send result m = Message() if result == AUTH_SUCCESSFUL: - self._log(INFO, 'Auth granted ({}).'.format(method)) + self._log(INFO, "Auth granted ({}).".format(method)) m.add_byte(cMSG_USERAUTH_SUCCESS) self.authenticated = True else: - self._log(INFO, 'Auth rejected ({}).'.format(method)) + self._log(INFO, "Auth rejected ({}).".format(method)) m.add_byte(cMSG_USERAUTH_FAILURE) m.add_string( - self.transport.server_object.get_allowed_auths(username)) + self.transport.server_object.get_allowed_auths(username) + ) if result == AUTH_PARTIALLY_SUCCESSFUL: m.add_boolean(True) else: @@ -411,7 +445,7 @@ Error Message: {} # er, uh... what? m = Message() m.add_byte(cMSG_USERAUTH_FAILURE) - m.add_string('none') + m.add_string("none") m.add_boolean(False) self.transport._send_message(m) return @@ -423,18 +457,20 @@ Error Message: {} method = m.get_text() self._log( DEBUG, - 'Auth request (type={}) service={}, username={}'.format( + "Auth request (type={}) service={}, username={}".format( method, service, username - ) + ), ) - if service != 'ssh-connection': + if service != "ssh-connection": self._disconnect_service_not_available() return - if ((self.auth_username is not None) and - (self.auth_username != username)): + if ( + (self.auth_username is not None) + and (self.auth_username != username) + ): self._log( WARNING, - 'Auth rejected because the client attempted to change username in mid-flight' # noqa + "Auth rejected because the client attempted to change username in mid-flight", # noqa ) self._disconnect_no_more_auth() return @@ -442,13 +478,13 @@ Error Message: {} # check if GSS-API authentication is enabled gss_auth = self.transport.server_object.enable_auth_gssapi() - if method == 'none': + if method == "none": result = self.transport.server_object.check_auth_none(username) - elif method == 'password': + elif method == "password": changereq = m.get_boolean() password = m.get_binary() try: - password = password.decode('UTF-8') + password = password.decode("UTF-8") except UnicodeError: # some clients/servers expect non-utf-8 passwords! # in this case, just return the raw byte string. @@ -457,31 +493,28 @@ Error Message: {} # always treated as failure, since we don't support changing # passwords, but collect the list of valid auth types from # the callback anyway - self._log( - DEBUG, - 'Auth request to change passwords (rejected)') + self._log(DEBUG, "Auth request to change passwords (rejected)") newpassword = m.get_binary() try: - newpassword = newpassword.decode('UTF-8', 'replace') + newpassword = newpassword.decode("UTF-8", "replace") except UnicodeError: pass result = AUTH_FAILED else: result = self.transport.server_object.check_auth_password( - username, password) - elif method == 'publickey': + username, password + ) + elif method == "publickey": sig_attached = m.get_boolean() keytype = m.get_text() keyblob = m.get_binary() try: key = self.transport._key_info[keytype](Message(keyblob)) except SSHException as e: - self._log( - INFO, - 'Auth rejected: public key: {}'.format(str(e))) + self._log(INFO, "Auth rejected: public key: {}".format(str(e))) key = None except Exception as e: - msg = 'Auth rejected: unsupported or mangled public key ({}: {})' # noqa + msg = "Auth rejected: unsupported or mangled public key ({}: {})" # noqa self._log(INFO, msg.format(e.__class__.__name__, e)) key = None if key is None: @@ -489,7 +522,8 @@ Error Message: {} return # first check if this key is okay... if not, we can skip the verify result = self.transport.server_object.check_auth_publickey( - username, key) + username, key + ) if result != AUTH_FAILED: # key is okay, verify it if not sig_attached: @@ -504,14 +538,13 @@ Error Message: {} sig = Message(m.get_binary()) blob = self._get_session_blob(key, service, username) if not key.verify_ssh_sig(blob, sig): - self._log( - INFO, - 'Auth rejected: invalid signature') + self._log(INFO, "Auth rejected: invalid signature") result = AUTH_FAILED - elif method == 'keyboard-interactive': + elif method == "keyboard-interactive": submethods = m.get_string() result = self.transport.server_object.check_auth_interactive( - username, submethods) + username, submethods + ) if isinstance(result, InteractiveQuery): # make interactive query instead of response self._interactive_query(result) @@ -527,7 +560,8 @@ Error Message: {} if mechs > 1: self._log( INFO, - 'Disconnect: Received more than one GSS-API OID mechanism') + "Disconnect: Received more than one GSS-API OID mechanism", + ) self._disconnect_no_more_auth() desired_mech = m.get_string() mech_ok = sshgss.ssh_check_mech(desired_mech) @@ -535,7 +569,8 @@ Error Message: {} if not mech_ok: self._log( INFO, - 'Disconnect: Received an invalid GSS-API OID mechanism') + "Disconnect: Received an invalid GSS-API OID mechanism", + ) self._disconnect_no_more_auth() # send the Kerberos V5 GSSAPI OID to the client supported_mech = sshgss.ssh_gss_oids("server") @@ -544,11 +579,14 @@ Error Message: {} m = Message() m.add_byte(cMSG_USERAUTH_GSSAPI_RESPONSE) m.add_bytes(supported_mech) - self.transport.auth_handler = GssapiWithMicAuthHandler(self, - sshgss) - self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN, - MSG_USERAUTH_REQUEST, - MSG_SERVICE_REQUEST) + self.transport.auth_handler = GssapiWithMicAuthHandler( + self, sshgss + ) + self.transport._expected_packet = ( + MSG_USERAUTH_GSSAPI_TOKEN, + MSG_USERAUTH_REQUEST, + MSG_SERVICE_REQUEST, + ) self.transport._send_message(m) return elif method == "gssapi-keyex" and gss_auth: @@ -559,16 +597,17 @@ Error Message: {} result = AUTH_FAILED self._send_auth_result(username, method, result) try: - sshgss.ssh_check_mic(mic_token, - self.transport.session_id, - self.auth_username) + sshgss.ssh_check_mic( + mic_token, self.transport.session_id, self.auth_username + ) except Exception: result = AUTH_FAILED self._send_auth_result(username, method, result) raise result = AUTH_SUCCESSFUL self.transport.server_object.check_auth_gssapi_keyex( - username, result) + username, result + ) else: result = self.transport.server_object.check_auth_none(username) # okay, send result @@ -576,8 +615,8 @@ Error Message: {} def _parse_userauth_success(self, m): self._log( - INFO, - 'Authentication ({}) successful!'.format(self.auth_method)) + INFO, "Authentication ({}) successful!".format(self.auth_method) + ) self.authenticated = True self.transport._auth_trigger() if self.auth_event is not None: @@ -587,24 +626,23 @@ Error Message: {} authlist = m.get_list() partial = m.get_boolean() if partial: - self._log(INFO, 'Authentication continues...') - self._log(DEBUG, 'Methods: ' + str(authlist)) + self._log(INFO, "Authentication continues...") + self._log(DEBUG, "Methods: " + str(authlist)) self.transport.saved_exception = PartialAuthentication(authlist) elif self.auth_method not in authlist: for msg in ( - 'Authentication type ({}) not permitted.'.format( + "Authentication type ({}) not permitted.".format( self.auth_method ), - 'Allowed methods: {}'.format(authlist), + "Allowed methods: {}".format(authlist), ): self._log(DEBUG, msg) self.transport.saved_exception = BadAuthenticationType( - 'Bad authentication type', authlist + "Bad authentication type", authlist ) else: self._log( - INFO, - 'Authentication ({}) failed.'.format(self.auth_method) + INFO, "Authentication ({}) failed.".format(self.auth_method) ) self.authenticated = False self.username = None @@ -614,12 +652,12 @@ Error Message: {} def _parse_userauth_banner(self, m): banner = m.get_string() self.banner = banner - self._log(INFO, 'Auth banner: {}'.format(banner)) + self._log(INFO, "Auth banner: {}".format(banner)) # who cares. def _parse_userauth_info_request(self, m): - if self.auth_method != 'keyboard-interactive': - raise SSHException('Illegal info request from server') + if self.auth_method != "keyboard-interactive": + raise SSHException("Illegal info request from server") title = m.get_text() instructions = m.get_text() m.get_binary() # lang @@ -628,7 +666,8 @@ Error Message: {} for i in range(prompts): prompt_list.append((m.get_text(), m.get_boolean())) response_list = self.interactive_handler( - title, instructions, prompt_list) + title, instructions, prompt_list + ) m = Message() m.add_byte(cMSG_USERAUTH_INFO_RESPONSE) @@ -639,25 +678,26 @@ Error Message: {} def _parse_userauth_info_response(self, m): if not self.transport.server_mode: - raise SSHException('Illegal info response from server') + raise SSHException("Illegal info response from server") n = m.get_int() responses = [] for i in range(n): responses.append(m.get_text()) result = self.transport.server_object.check_auth_interactive_response( - responses) + responses + ) if isinstance(result, InteractiveQuery): # make interactive query instead of response self._interactive_query(result) return self._send_auth_result( - self.auth_username, 'keyboard-interactive', result) + self.auth_username, "keyboard-interactive", result + ) def _handle_local_gss_failure(self, e): self.transport.saved_exception = e self._log(DEBUG, "GSSAPI failure: {}".format(e)) - self._log(INFO, 'Authentication ({}) failed.'.format( - self.auth_method)) + self._log(INFO, "Authentication ({}) failed.".format(self.auth_method)) self.authenticated = False self.username = None if self.auth_event is not None: @@ -718,9 +758,9 @@ class GssapiWithMicAuthHandler(object): # context. sshgss = self.sshgss try: - token = sshgss.ssh_accept_sec_context(self.gss_host, - client_token, - self.auth_username) + token = sshgss.ssh_accept_sec_context( + self.gss_host, client_token, self.auth_username + ) except Exception as e: self.transport.saved_exception = e result = AUTH_FAILED @@ -731,9 +771,11 @@ class GssapiWithMicAuthHandler(object): m = Message() m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN) m.add_string(token) - self.transport._expected_packet = (MSG_USERAUTH_GSSAPI_TOKEN, - MSG_USERAUTH_GSSAPI_MIC, - MSG_USERAUTH_REQUEST) + self.transport._expected_packet = ( + MSG_USERAUTH_GSSAPI_TOKEN, + MSG_USERAUTH_GSSAPI_MIC, + MSG_USERAUTH_REQUEST, + ) self.transport._send_message(m) def _parse_userauth_gssapi_mic(self, m): @@ -742,9 +784,9 @@ class GssapiWithMicAuthHandler(object): username = self.auth_username self._restore_delegate_auth_handler() try: - sshgss.ssh_check_mic(mic_token, - self.transport.session_id, - username) + sshgss.ssh_check_mic( + mic_token, self.transport.session_id, username + ) except Exception as e: self.transport.saved_exception = e result = AUTH_FAILED @@ -754,8 +796,9 @@ class GssapiWithMicAuthHandler(object): # The OpenSSH server is able to create a TGT with the delegated # client credentials, but this is not supported by GSS-API. result = AUTH_SUCCESSFUL - self.transport.server_object.check_auth_gssapi_with_mic(username, - result) + self.transport.server_object.check_auth_gssapi_with_mic( + username, result + ) # okay, send result self._send_auth_result(username, self.method, result) diff --git a/paramiko/ber.py b/paramiko/ber.py index 876347e0..fb6ee71d 100644 --- a/paramiko/ber.py +++ b/paramiko/ber.py @@ -21,7 +21,7 @@ from paramiko.py3compat import b, byte_ord, byte_chr, long import paramiko.util as util -class BERException (Exception): +class BERException(Exception): pass @@ -41,7 +41,7 @@ class BER(object): return self.asbytes() def __repr__(self): - return 'BER(\'' + repr(self.content) + '\')' + return "BER('" + repr(self.content) + "')" def decode(self): return self.decode_next() @@ -71,13 +71,12 @@ class BER(object): t = size & 0x7f if self.idx + t > len(self.content): return None - size = util.inflate_long( - self.content[self.idx: self.idx + t], True) + size = util.inflate_long(self.content[self.idx:self.idx + t], True) self.idx += t if self.idx + size > len(self.content): # can't fit return None - data = self.content[self.idx: self.idx + size] + data = self.content[self.idx:self.idx + size] self.idx += size # now switch on id if ident == 0x30: @@ -88,7 +87,7 @@ class BER(object): return util.inflate_long(data) else: # 1: boolean (00 false, otherwise true) - msg = 'Unknown ber encoding type {:d} (robey is lazy)' + msg = "Unknown ber encoding type {:d} (robey is lazy)" raise BERException(msg.format(ident)) @staticmethod @@ -126,7 +125,7 @@ class BER(object): self.encode_tlv(0x30, self.encode_sequence(x)) else: raise BERException( - 'Unknown type for encoding: {!r}'.format(type(x)) + "Unknown type for encoding: {!r}".format(type(x)) ) @staticmethod diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py index d9f5149d..48f5aae5 100644 --- a/paramiko/buffered_pipe.py +++ b/paramiko/buffered_pipe.py @@ -28,14 +28,14 @@ import time from paramiko.py3compat import PY2, b -class PipeTimeout (IOError): +class PipeTimeout(IOError): """ Indicates that a timeout was reached on a read from a `.BufferedPipe`. """ pass -class BufferedPipe (object): +class BufferedPipe(object): """ A buffer that obeys normal read (with timeout) & close semantics for a file or socket, but is fed data from another thread. This is used by @@ -46,16 +46,19 @@ class BufferedPipe (object): self._lock = threading.Lock() self._cv = threading.Condition(self._lock) self._event = None - self._buffer = array.array('B') + self._buffer = array.array("B") self._closed = False if PY2: + def _buffer_frombytes(self, data): self._buffer.fromstring(data) def _buffer_tobytes(self, limit=None): return self._buffer[:limit].tostring() + else: + def _buffer_frombytes(self, data): self._buffer.frombytes(data) diff --git a/paramiko/channel.py b/paramiko/channel.py index 91a8f0df..00f86d6e 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -25,14 +25,22 @@ import os import socket import time import threading + # TODO: switch as much of py3compat.py to 'six' as possible, then use six.wraps from functools import wraps from paramiko import util from paramiko.common import ( - cMSG_CHANNEL_REQUEST, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_CHANNEL_DATA, - cMSG_CHANNEL_EXTENDED_DATA, DEBUG, ERROR, cMSG_CHANNEL_SUCCESS, - cMSG_CHANNEL_FAILURE, cMSG_CHANNEL_EOF, cMSG_CHANNEL_CLOSE, + cMSG_CHANNEL_REQUEST, + cMSG_CHANNEL_WINDOW_ADJUST, + cMSG_CHANNEL_DATA, + cMSG_CHANNEL_EXTENDED_DATA, + DEBUG, + ERROR, + cMSG_CHANNEL_SUCCESS, + cMSG_CHANNEL_FAILURE, + cMSG_CHANNEL_EOF, + cMSG_CHANNEL_CLOSE, ) from paramiko.message import Message from paramiko.py3compat import bytes_types @@ -51,20 +59,22 @@ def open_only(func): `.SSHException` -- If the wrapped method is called on an unopened `.Channel`. """ + @wraps(func) def _check(self, *args, **kwds): if ( - self.closed or - self.eof_received or - self.eof_sent or - not self.active + self.closed + or self.eof_received + or self.eof_sent + or not self.active ): - raise SSHException('Channel is not open') + raise SSHException("Channel is not open") return func(self, *args, **kwds) + return _check -class Channel (ClosingContextManager): +class Channel(ClosingContextManager): """ A secure tunnel across an SSH `.Transport`. A Channel is meant to behave like a socket, and has an API that should be indistinguishable from the @@ -117,7 +127,7 @@ class Channel (ClosingContextManager): self.in_window_sofar = 0 self.status_event = threading.Event() self._name = str(chanid) - self.logger = util.get_logger('paramiko.transport') + self.logger = util.get_logger("paramiko.transport") self._pipe = None self.event = threading.Event() self.event_ready = False @@ -135,24 +145,30 @@ class Channel (ClosingContextManager): """ Return a string representation of this object, for debugging. """ - out = '<paramiko.Channel {}'.format(self.chanid) + out = "<paramiko.Channel {}".format(self.chanid) if self.closed: - out += ' (closed)' + out += " (closed)" elif self.active: if self.eof_received: - out += ' (EOF received)' + out += " (EOF received)" if self.eof_sent: - out += ' (EOF sent)' - out += ' (open) window={}'.format(self.out_window_size) + out += " (EOF sent)" + out += " (open) window={}".format(self.out_window_size) if len(self.in_buffer) > 0: - out += ' in-buffer={}'.format(len(self.in_buffer)) - out += ' -> ' + repr(self.transport) - out += '>' + out += " in-buffer={}".format(len(self.in_buffer)) + out += " -> " + repr(self.transport) + out += ">" return out @open_only - def get_pty(self, term='vt100', width=80, height=24, width_pixels=0, - height_pixels=0): + def get_pty( + self, + term="vt100", + width=80, + height=24, + width_pixels=0, + height_pixels=0, + ): """ Request a pseudo-terminal from the server. This is usually used right after creating a client channel, to ask the server to provide some @@ -174,7 +190,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('pty-req') + m.add_string("pty-req") m.add_boolean(True) m.add_string(term) m.add_int(width) @@ -207,7 +223,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('shell') + m.add_string("shell") m.add_boolean(True) self._event_pending() self.transport._send_user_message(m) @@ -233,7 +249,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('exec') + m.add_string("exec") m.add_boolean(True) m.add_string(command) self._event_pending() @@ -259,7 +275,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('subsystem') + m.add_string("subsystem") m.add_boolean(True) m.add_string(subsystem) self._event_pending() @@ -284,7 +300,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('window-change') + m.add_string("window-change") m.add_boolean(False) m.add_int(width) m.add_int(height) @@ -315,7 +331,7 @@ class Channel (ClosingContextManager): try: self.set_environment_variable(name, value) except SSHException as e: - err = "Failed to set environment variable \"{}\"." + err = 'Failed to set environment variable "{}".' raise SSHException(err.format(name), e) @open_only @@ -339,7 +355,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('env') + m.add_string("env") m.add_boolean(False) m.add_string(name) m.add_string(value) @@ -403,19 +419,19 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('exit-status') + m.add_string("exit-status") m.add_boolean(False) m.add_int(status) self.transport._send_user_message(m) @open_only def request_x11( - self, - screen_number=0, - auth_protocol=None, - auth_cookie=None, - single_connection=False, - handler=None + self, + screen_number=0, + auth_protocol=None, + auth_cookie=None, + single_connection=False, + handler=None, ): """ Request an x11 session on this channel. If the server allows it, @@ -456,14 +472,14 @@ class Channel (ClosingContextManager): :return: the auth_cookie used """ if auth_protocol is None: - auth_protocol = 'MIT-MAGIC-COOKIE-1' + auth_protocol = "MIT-MAGIC-COOKIE-1" if auth_cookie is None: auth_cookie = binascii.hexlify(os.urandom(16)) m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('x11-req') + m.add_string("x11-req") m.add_boolean(True) m.add_boolean(single_connection) m.add_string(auth_protocol) @@ -493,7 +509,7 @@ class Channel (ClosingContextManager): m = Message() m.add_byte(cMSG_CHANNEL_REQUEST) m.add_int(self.remote_chanid) - m.add_string('auth-agent-req@openssh.com') + m.add_string("auth-agent-req@openssh.com") m.add_boolean(False) self.transport._send_user_message(m) self.transport._set_forward_agent_handler(handler) @@ -808,7 +824,6 @@ class Channel (ClosingContextManager): m.add_int(1) return self._send(s, m) - def sendall(self, s): """ Send data to the channel, without allowing partial results. Unlike @@ -976,7 +991,7 @@ class Channel (ClosingContextManager): # a window update self.in_window_threshold = window_size // 10 self.in_window_sofar = 0 - self._log(DEBUG, 'Max packet in: {} bytes'.format(max_packet_size)) + self._log(DEBUG, "Max packet in: {} bytes".format(max_packet_size)) def _set_remote_channel(self, chanid, window_size, max_packet_size): self.remote_chanid = chanid @@ -985,11 +1000,12 @@ class Channel (ClosingContextManager): max_packet_size ) self.active = 1 - self._log(DEBUG, - 'Max packet out: {} bytes'.format(self.out_max_packet_size)) + self._log( + DEBUG, "Max packet out: {} bytes".format(self.out_max_packet_size) + ) def _request_success(self, m): - self._log(DEBUG, 'Sesch channel {} request ok'.format(self.chanid)) + self._log(DEBUG, "Sesch channel {} request ok".format(self.chanid)) self.event_ready = True self.event.set() return @@ -1017,8 +1033,7 @@ class Channel (ClosingContextManager): s = m.get_binary() if code != 1: self._log( - ERROR, - 'unknown extended_data type {}; discarding'.format(code) + ERROR, "unknown extended_data type {}; discarding".format(code) ) return if self.combine_stderr: @@ -1031,7 +1046,7 @@ class Channel (ClosingContextManager): self.lock.acquire() try: if self.ultra_debug: - self._log(DEBUG, 'window up {}'.format(nbytes)) + self._log(DEBUG, "window up {}".format(nbytes)) self.out_window_size += nbytes self.out_buffer_cv.notifyAll() finally: @@ -1042,14 +1057,14 @@ class Channel (ClosingContextManager): want_reply = m.get_boolean() server = self.transport.server_object ok = False - if key == 'exit-status': + if key == "exit-status": self.exit_status = m.get_int() self.status_event.set() ok = True - elif key == 'xon-xoff': + elif key == "xon-xoff": # ignore ok = True - elif key == 'pty-req': + elif key == "pty-req": term = m.get_string() width = m.get_int() height = m.get_int() @@ -1060,39 +1075,33 @@ class Channel (ClosingContextManager): ok = False else: ok = server.check_channel_pty_request( - self, - term, - width, - height, - pixelwidth, - pixelheight, - modes + self, term, width, height, pixelwidth, pixelheight, modes ) - elif key == 'shell': + elif key == "shell": if server is None: ok = False else: ok = server.check_channel_shell_request(self) - elif key == 'env': + elif key == "env": name = m.get_string() value = m.get_string() if server is None: ok = False else: ok = server.check_channel_env_request(self, name, value) - elif key == 'exec': + elif key == "exec": cmd = m.get_string() if server is None: ok = False else: ok = server.check_channel_exec_request(self, cmd) - elif key == 'subsystem': + elif key == "subsystem": name = m.get_text() if server is None: ok = False else: ok = server.check_channel_subsystem_request(self, name) - elif key == 'window-change': + elif key == "window-change": width = m.get_int() height = m.get_int() pixelwidth = m.get_int() @@ -1101,8 +1110,9 @@ class Channel (ClosingContextManager): ok = False else: ok = server.check_channel_window_change_request( - self, width, height, pixelwidth, pixelheight) - elif key == 'x11-req': + self, width, height, pixelwidth, pixelheight + ) + elif key == "x11-req": single_connection = m.get_boolean() auth_proto = m.get_text() auth_cookie = m.get_binary() @@ -1115,9 +1125,9 @@ class Channel (ClosingContextManager): single_connection, auth_proto, auth_cookie, - screen_number + screen_number, ) - elif key == 'auth-agent-req@openssh.com': + elif key == "auth-agent-req@openssh.com": if server is None: ok = False else: @@ -1145,7 +1155,7 @@ class Channel (ClosingContextManager): self._pipe.set_forever() finally: self.lock.release() - self._log(DEBUG, 'EOF received ({})'.format(self._name)) + self._log(DEBUG, "EOF received ({})".format(self._name)) def _handle_close(self, m): self.lock.acquire() @@ -1167,7 +1177,7 @@ class Channel (ClosingContextManager): if self.closed: # this doesn't seem useful, but it is the documented behavior # of Socket - raise socket.error('Socket is closed') + raise socket.error("Socket is closed") size = self._wait_for_send_window(size) if size == 0: # eof or similar @@ -1194,7 +1204,7 @@ class Channel (ClosingContextManager): return e = self.transport.get_exception() if e is None: - e = SSHException('Channel closed.') + e = SSHException("Channel closed.") raise e def _set_closed(self): @@ -1217,7 +1227,7 @@ class Channel (ClosingContextManager): m.add_byte(cMSG_CHANNEL_EOF) m.add_int(self.remote_chanid) self.eof_sent = True - self._log(DEBUG, 'EOF sent ({})'.format(self._name)) + self._log(DEBUG, "EOF sent ({})".format(self._name)) return m def _close_internal(self): @@ -1251,13 +1261,14 @@ class Channel (ClosingContextManager): if self.closed or self.eof_received or not self.active: return 0 if self.ultra_debug: - self._log(DEBUG, 'addwindow {}'.format(n)) + self._log(DEBUG, "addwindow {}".format(n)) self.in_window_sofar += n if self.in_window_sofar <= self.in_window_threshold: return 0 if self.ultra_debug: - self._log(DEBUG, - 'addwindow send {}'.format(self.in_window_sofar)) + self._log( + DEBUG, "addwindow send {}".format(self.in_window_sofar) + ) out = self.in_window_sofar self.in_window_sofar = 0 return out @@ -1300,11 +1311,11 @@ class Channel (ClosingContextManager): size = self.out_max_packet_size - 64 self.out_window_size -= size if self.ultra_debug: - self._log(DEBUG, 'window down to {}'.format(self.out_window_size)) + self._log(DEBUG, "window down to {}".format(self.out_window_size)) return size -class ChannelFile (BufferedFile): +class ChannelFile(BufferedFile): """ A file-like wrapper around `.Channel`. A ChannelFile is created by calling `Channel.makefile`. @@ -1317,7 +1328,7 @@ class ChannelFile (BufferedFile): flush the buffer. """ - def __init__(self, channel, mode='r', bufsize=-1): + def __init__(self, channel, mode="r", bufsize=-1): self.channel = channel BufferedFile.__init__(self) self._set_mode(mode, bufsize) @@ -1326,7 +1337,7 @@ class ChannelFile (BufferedFile): """ Returns a string representation of this object, for debugging. """ - return '<paramiko.ChannelFile from ' + repr(self.channel) + '>' + return "<paramiko.ChannelFile from " + repr(self.channel) + ">" def _read(self, size): return self.channel.recv(size) @@ -1336,8 +1347,9 @@ class ChannelFile (BufferedFile): return len(data) -class ChannelStderrFile (ChannelFile): - def __init__(self, channel, mode='r', bufsize=-1): +class ChannelStderrFile(ChannelFile): + + def __init__(self, channel, mode="r", bufsize=-1): ChannelFile.__init__(self, channel, mode, bufsize) def _read(self, size): diff --git a/paramiko/client.py b/paramiko/client.py index 6f0cb847..8690c86d 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -38,13 +38,15 @@ from paramiko.hostkeys import HostKeys from paramiko.py3compat import string_types from paramiko.rsakey import RSAKey from paramiko.ssh_exception import ( - SSHException, BadHostKeyException, NoValidConnectionsError + SSHException, + BadHostKeyException, + NoValidConnectionsError, ) from paramiko.transport import Transport from paramiko.util import retry_on_signal, ClosingContextManager -class SSHClient (ClosingContextManager): +class SSHClient(ClosingContextManager): """ A high-level representation of a session with an SSH server. This class wraps `.Transport`, `.Channel`, and `.SFTPClient` to take care of most @@ -97,7 +99,7 @@ class SSHClient (ClosingContextManager): """ if filename is None: # try the user's .ssh key file, and mask exceptions - filename = os.path.expanduser('~/.ssh/known_hosts') + filename = os.path.expanduser("~/.ssh/known_hosts") try: self._system_host_keys.load(filename) except IOError: @@ -140,12 +142,14 @@ class SSHClient (ClosingContextManager): if self._host_keys_filename is not None: self.load_host_keys(self._host_keys_filename) - with open(filename, 'w') as f: + with open(filename, "w") as f: for hostname, keys in self._host_keys.items(): for keytype, key in keys.items(): - f.write('{} {} {}\n'.format( - hostname, keytype, key.get_base64() - )) + f.write( + "{} {} {}\n".format( + hostname, keytype, key.get_base64() + ) + ) def get_host_keys(self): """ @@ -197,7 +201,8 @@ class SSHClient (ClosingContextManager): """ guess = True addrinfos = socket.getaddrinfo( - hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM) + hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) for (family, socktype, proto, canonname, sockaddr) in addrinfos: if socktype == socket.SOCK_STREAM: yield family, sockaddr @@ -419,8 +424,16 @@ class SSHClient (ClosingContextManager): key_filenames = key_filename self._auth( - username, password, pkey, key_filenames, allow_agent, - look_for_keys, gss_auth, gss_kex, gss_deleg_creds, t.gss_host, + username, + password, + pkey, + key_filenames, + allow_agent, + look_for_keys, + gss_auth, + gss_kex, + gss_deleg_creds, + t.gss_host, passphrase, ) @@ -484,13 +497,20 @@ class SSHClient (ClosingContextManager): if environment: chan.update_environment(environment) chan.exec_command(command) - stdin = chan.makefile('wb', bufsize) - stdout = chan.makefile('r', bufsize) - stderr = chan.makefile_stderr('r', bufsize) + stdin = chan.makefile("wb", bufsize) + stdout = chan.makefile("r", bufsize) + stderr = chan.makefile_stderr("r", bufsize) return stdin, stdout, stderr - def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0, - height_pixels=0, environment=None): + def invoke_shell( + self, + term="vt100", + width=80, + height=24, + width_pixels=0, + height_pixels=0, + environment=None, + ): """ Start an interactive shell session on the SSH server. A new `.Channel` is opened and connected to a pseudo-terminal using the requested @@ -539,7 +559,7 @@ class SSHClient (ClosingContextManager): - Otherwise, the filename is assumed to be a private key, and the matching public cert will be loaded if it exists. """ - cert_suffix = '-cert.pub' + cert_suffix = "-cert.pub" # Assume privkey, not cert, by default if filename.endswith(cert_suffix): key_path = filename[:-len(cert_suffix)] @@ -553,7 +573,7 @@ class SSHClient (ClosingContextManager): # when #387 is released, since this is a critical log message users are # likely testing/filtering for (bah.) msg = "Trying discovered key {} in {}".format( - hexlify(key.get_fingerprint()), key_path, + hexlify(key.get_fingerprint()), key_path ) self._log(DEBUG, msg) # Attempt to load cert if it exists. @@ -563,8 +583,17 @@ class SSHClient (ClosingContextManager): return key def _auth( - self, username, password, pkey, key_filenames, allow_agent, - look_for_keys, gss_auth, gss_kex, gss_deleg_creds, gss_host, + self, + username, + password, + pkey, + key_filenames, + allow_agent, + look_for_keys, + gss_auth, + gss_kex, + gss_deleg_creds, + gss_host, passphrase, ): """ @@ -583,7 +612,7 @@ class SSHClient (ClosingContextManager): saved_exception = None two_factor = False allowed_types = set() - two_factor_types = {'keyboard-interactive', 'password'} + two_factor_types = {"keyboard-interactive", "password"} if passphrase is None and password is not None: passphrase = password @@ -603,7 +632,7 @@ class SSHClient (ClosingContextManager): if gss_auth: try: return self._transport.auth_gssapi_with_mic( - username, gss_host, gss_deleg_creds, + username, gss_host, gss_deleg_creds ) except Exception as e: saved_exception = e @@ -612,10 +641,13 @@ class SSHClient (ClosingContextManager): try: self._log( DEBUG, - 'Trying SSH key {}'.format(hexlify(pkey.get_fingerprint())) + "Trying SSH key {}".format( + hexlify(pkey.get_fingerprint()) + ), ) allowed_types = set( - self._transport.auth_publickey(username, pkey)) + self._transport.auth_publickey(username, pkey) + ) two_factor = (allowed_types & two_factor_types) if not two_factor: return @@ -627,10 +659,11 @@ class SSHClient (ClosingContextManager): for pkey_class in (RSAKey, DSSKey, ECDSAKey, Ed25519Key): try: key = self._key_from_filepath( - key_filename, pkey_class, passphrase, + key_filename, pkey_class, passphrase ) allowed_types = set( - self._transport.auth_publickey(username, key)) + self._transport.auth_publickey(username, key) + ) two_factor = (allowed_types & two_factor_types) if not two_factor: return @@ -645,11 +678,12 @@ class SSHClient (ClosingContextManager): for key in self._agent.get_keys(): try: id_ = hexlify(key.get_fingerprint()) - self._log(DEBUG, 'Trying SSH agent key {}'.format(id_)) + self._log(DEBUG, "Trying SSH agent key {}".format(id_)) # for 2-factor auth a successfully auth'd key password # will return an allowed 2fac auth method allowed_types = set( - self._transport.auth_publickey(username, key)) + self._transport.auth_publickey(username, key) + ) two_factor = (allowed_types & two_factor_types) if not two_factor: return @@ -674,8 +708,8 @@ class SSHClient (ClosingContextManager): if os.path.isfile(full_path): # TODO: only do this append if below did not run keyfiles.append((keytype, full_path)) - if os.path.isfile(full_path + '-cert.pub'): - keyfiles.append((keytype, full_path + '-cert.pub')) + if os.path.isfile(full_path + "-cert.pub"): + keyfiles.append((keytype, full_path + "-cert.pub")) if not look_for_keys: keyfiles = [] @@ -683,12 +717,13 @@ class SSHClient (ClosingContextManager): for pkey_class, filename in keyfiles: try: key = self._key_from_filepath( - filename, pkey_class, passphrase, + filename, pkey_class, passphrase ) # for 2-factor auth a successfully auth'd key will result # in ['password'] allowed_types = set( - self._transport.auth_publickey(username, key)) + self._transport.auth_publickey(username, key) + ) two_factor = (allowed_types & two_factor_types) if not two_factor: return @@ -712,13 +747,13 @@ class SSHClient (ClosingContextManager): # if we got an auth-failed exception earlier, re-raise it if saved_exception is not None: raise saved_exception - raise SSHException('No authentication methods available') + raise SSHException("No authentication methods available") def _log(self, level, msg): self._transport._log(level, msg) -class MissingHostKeyPolicy (object): +class MissingHostKeyPolicy(object): """ Interface for defining the policy that `.SSHClient` should use when the SSH server's hostname is not in either the system host keys or the @@ -739,7 +774,7 @@ class MissingHostKeyPolicy (object): pass -class AutoAddPolicy (MissingHostKeyPolicy): +class AutoAddPolicy(MissingHostKeyPolicy): """ Policy for automatically adding the hostname and new host key to the local `.HostKeys` object, and saving it. This is used by `.SSHClient`. @@ -749,32 +784,41 @@ class AutoAddPolicy (MissingHostKeyPolicy): client._host_keys.add(hostname, key.get_name(), key) if client._host_keys_filename is not None: client.save_host_keys(client._host_keys_filename) - client._log(DEBUG, 'Adding {} host key for {}: {}'.format( - key.get_name(), hostname, hexlify(key.get_fingerprint()), - )) + client._log( + DEBUG, + "Adding {} host key for {}: {}".format( + key.get_name(), hostname, hexlify(key.get_fingerprint()) + ), + ) -class RejectPolicy (MissingHostKeyPolicy): +class RejectPolicy(MissingHostKeyPolicy): """ Policy for automatically rejecting the unknown hostname & key. This is used by `.SSHClient`. """ def missing_host_key(self, client, hostname, key): - client._log(DEBUG, 'Rejecting {} host key for {}: {}'.format( - key.get_name(), hostname, hexlify(key.get_fingerprint()), - )) + client._log( + DEBUG, + "Rejecting {} host key for {}: {}".format( + key.get_name(), hostname, hexlify(key.get_fingerprint()) + ), + ) raise SSHException( - 'Server {!r} not found in known_hosts'.format(hostname) + "Server {!r} not found in known_hosts".format(hostname) ) -class WarningPolicy (MissingHostKeyPolicy): +class WarningPolicy(MissingHostKeyPolicy): """ Policy for logging a Python-style warning for an unknown host key, but accepting it. This is used by `.SSHClient`. """ + def missing_host_key(self, client, hostname, key): - warnings.warn('Unknown {} host key for {}: {}'.format( - key.get_name(), hostname, hexlify(key.get_fingerprint()), - )) + warnings.warn( + "Unknown {} host key for {}: {}".format( + key.get_name(), hostname, hexlify(key.get_fingerprint()) + ) + ) diff --git a/paramiko/common.py b/paramiko/common.py index 11c4121d..7e9510b9 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -22,22 +22,53 @@ Common constants and global variables. import logging from paramiko.py3compat import byte_chr, PY2, bytes_types, text_type, long -MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, \ - MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT = range(1, 7) -MSG_KEXINIT, MSG_NEWKEYS = range(20, 22) -MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \ - MSG_USERAUTH_BANNER = range(50, 54) +( + MSG_DISCONNECT, + MSG_IGNORE, + MSG_UNIMPLEMENTED, + MSG_DEBUG, + MSG_SERVICE_REQUEST, + MSG_SERVICE_ACCEPT, +) = range( + 1, 7 +) +(MSG_KEXINIT, MSG_NEWKEYS) = range(20, 22) +( + MSG_USERAUTH_REQUEST, + MSG_USERAUTH_FAILURE, + MSG_USERAUTH_SUCCESS, + MSG_USERAUTH_BANNER, +) = range( + 50, 54 +) MSG_USERAUTH_PK_OK = 60 -MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62) -MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN = range(60, 62) -MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, MSG_USERAUTH_GSSAPI_ERROR,\ - MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC = range(63, 67) +(MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE) = range(60, 62) +(MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN) = range(60, 62) +( + MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, + MSG_USERAUTH_GSSAPI_ERROR, + MSG_USERAUTH_GSSAPI_ERRTOK, + MSG_USERAUTH_GSSAPI_MIC, +) = range( + 63, 67 +) HIGHEST_USERAUTH_MESSAGE_ID = 79 -MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83) -MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ - MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \ - MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ - MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) +(MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE) = range(80, 83) +( + MSG_CHANNEL_OPEN, + MSG_CHANNEL_OPEN_SUCCESS, + MSG_CHANNEL_OPEN_FAILURE, + MSG_CHANNEL_WINDOW_ADJUST, + MSG_CHANNEL_DATA, + MSG_CHANNEL_EXTENDED_DATA, + MSG_CHANNEL_EOF, + MSG_CHANNEL_CLOSE, + MSG_CHANNEL_REQUEST, + MSG_CHANNEL_SUCCESS, + MSG_CHANNEL_FAILURE, +) = range( + 90, 101 +) cMSG_DISCONNECT = byte_chr(MSG_DISCONNECT) cMSG_IGNORE = byte_chr(MSG_IGNORE) @@ -56,8 +87,9 @@ cMSG_USERAUTH_INFO_REQUEST = byte_chr(MSG_USERAUTH_INFO_REQUEST) cMSG_USERAUTH_INFO_RESPONSE = byte_chr(MSG_USERAUTH_INFO_RESPONSE) cMSG_USERAUTH_GSSAPI_RESPONSE = byte_chr(MSG_USERAUTH_GSSAPI_RESPONSE) cMSG_USERAUTH_GSSAPI_TOKEN = byte_chr(MSG_USERAUTH_GSSAPI_TOKEN) -cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = \ - byte_chr(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE) +cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = byte_chr( + MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE +) cMSG_USERAUTH_GSSAPI_ERROR = byte_chr(MSG_USERAUTH_GSSAPI_ERROR) cMSG_USERAUTH_GSSAPI_ERRTOK = byte_chr(MSG_USERAUTH_GSSAPI_ERRTOK) cMSG_USERAUTH_GSSAPI_MIC = byte_chr(MSG_USERAUTH_GSSAPI_MIC) @@ -78,47 +110,47 @@ cMSG_CHANNEL_FAILURE = byte_chr(MSG_CHANNEL_FAILURE) # for debugging: MSG_NAMES = { - MSG_DISCONNECT: 'disconnect', - MSG_IGNORE: 'ignore', - MSG_UNIMPLEMENTED: 'unimplemented', - MSG_DEBUG: 'debug', - MSG_SERVICE_REQUEST: 'service-request', - MSG_SERVICE_ACCEPT: 'service-accept', - MSG_KEXINIT: 'kexinit', - MSG_NEWKEYS: 'newkeys', - 30: 'kex30', - 31: 'kex31', - 32: 'kex32', - 33: 'kex33', - 34: 'kex34', - 40: 'kex40', - 41: 'kex41', - MSG_USERAUTH_REQUEST: 'userauth-request', - MSG_USERAUTH_FAILURE: 'userauth-failure', - MSG_USERAUTH_SUCCESS: 'userauth-success', - MSG_USERAUTH_BANNER: 'userauth--banner', - MSG_USERAUTH_PK_OK: 'userauth-60(pk-ok/info-request)', - MSG_USERAUTH_INFO_RESPONSE: 'userauth-info-response', - MSG_GLOBAL_REQUEST: 'global-request', - MSG_REQUEST_SUCCESS: 'request-success', - MSG_REQUEST_FAILURE: 'request-failure', - MSG_CHANNEL_OPEN: 'channel-open', - MSG_CHANNEL_OPEN_SUCCESS: 'channel-open-success', - MSG_CHANNEL_OPEN_FAILURE: 'channel-open-failure', - MSG_CHANNEL_WINDOW_ADJUST: 'channel-window-adjust', - MSG_CHANNEL_DATA: 'channel-data', - MSG_CHANNEL_EXTENDED_DATA: 'channel-extended-data', - MSG_CHANNEL_EOF: 'channel-eof', - MSG_CHANNEL_CLOSE: 'channel-close', - MSG_CHANNEL_REQUEST: 'channel-request', - MSG_CHANNEL_SUCCESS: 'channel-success', - MSG_CHANNEL_FAILURE: 'channel-failure', - MSG_USERAUTH_GSSAPI_RESPONSE: 'userauth-gssapi-response', - MSG_USERAUTH_GSSAPI_TOKEN: 'userauth-gssapi-token', - MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: 'userauth-gssapi-exchange-complete', - MSG_USERAUTH_GSSAPI_ERROR: 'userauth-gssapi-error', - MSG_USERAUTH_GSSAPI_ERRTOK: 'userauth-gssapi-error-token', - MSG_USERAUTH_GSSAPI_MIC: 'userauth-gssapi-mic' + MSG_DISCONNECT: "disconnect", + MSG_IGNORE: "ignore", + MSG_UNIMPLEMENTED: "unimplemented", + MSG_DEBUG: "debug", + MSG_SERVICE_REQUEST: "service-request", + MSG_SERVICE_ACCEPT: "service-accept", + MSG_KEXINIT: "kexinit", + MSG_NEWKEYS: "newkeys", + 30: "kex30", + 31: "kex31", + 32: "kex32", + 33: "kex33", + 34: "kex34", + 40: "kex40", + 41: "kex41", + MSG_USERAUTH_REQUEST: "userauth-request", + MSG_USERAUTH_FAILURE: "userauth-failure", + MSG_USERAUTH_SUCCESS: "userauth-success", + MSG_USERAUTH_BANNER: "userauth--banner", + MSG_USERAUTH_PK_OK: "userauth-60(pk-ok/info-request)", + MSG_USERAUTH_INFO_RESPONSE: "userauth-info-response", + MSG_GLOBAL_REQUEST: "global-request", + MSG_REQUEST_SUCCESS: "request-success", + MSG_REQUEST_FAILURE: "request-failure", + MSG_CHANNEL_OPEN: "channel-open", + MSG_CHANNEL_OPEN_SUCCESS: "channel-open-success", + MSG_CHANNEL_OPEN_FAILURE: "channel-open-failure", + MSG_CHANNEL_WINDOW_ADJUST: "channel-window-adjust", + MSG_CHANNEL_DATA: "channel-data", + MSG_CHANNEL_EXTENDED_DATA: "channel-extended-data", + MSG_CHANNEL_EOF: "channel-eof", + MSG_CHANNEL_CLOSE: "channel-close", + MSG_CHANNEL_REQUEST: "channel-request", + MSG_CHANNEL_SUCCESS: "channel-success", + MSG_CHANNEL_FAILURE: "channel-failure", + MSG_USERAUTH_GSSAPI_RESPONSE: "userauth-gssapi-response", + MSG_USERAUTH_GSSAPI_TOKEN: "userauth-gssapi-token", + MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: "userauth-gssapi-exchange-complete", + MSG_USERAUTH_GSSAPI_ERROR: "userauth-gssapi-error", + MSG_USERAUTH_GSSAPI_ERRTOK: "userauth-gssapi-error-token", + MSG_USERAUTH_GSSAPI_MIC: "userauth-gssapi-mic", } @@ -127,23 +159,30 @@ AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3) # channel request failed reasons: -(OPEN_SUCCEEDED, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, - OPEN_FAILED_CONNECT_FAILED, - OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, - OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5) +( + OPEN_SUCCEEDED, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_FAILED_CONNECT_FAILED, + OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, + OPEN_FAILED_RESOURCE_SHORTAGE, +) = range( + 0, 5 +) CONNECTION_FAILED_CODE = { - 1: 'Administratively prohibited', - 2: 'Connect failed', - 3: 'Unknown channel type', - 4: 'Resource shortage' + 1: "Administratively prohibited", + 2: "Connect failed", + 3: "Unknown channel type", + 4: "Resource shortage", } -DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \ - DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 +( + DISCONNECT_SERVICE_NOT_AVAILABLE, + DISCONNECT_AUTH_CANCELLED_BY_USER, + DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, +) = 7, 13, 14 zero_byte = byte_chr(0) one_byte = byte_chr(1) diff --git a/paramiko/compress.py b/paramiko/compress.py index 5073109c..cea2b308 100644 --- a/paramiko/compress.py +++ b/paramiko/compress.py @@ -23,7 +23,8 @@ Compression implementations for a Transport. import zlib -class ZlibCompressor (object): +class ZlibCompressor(object): + def __init__(self): # Use the default level of zlib compression self.z = zlib.compressobj() @@ -32,7 +33,8 @@ class ZlibCompressor (object): return self.z.compress(data) + self.z.flush(zlib.Z_FULL_FLUSH) -class ZlibDecompressor (object): +class ZlibDecompressor(object): + def __init__(self): self.z = zlib.decompressobj() diff --git a/paramiko/config.py b/paramiko/config.py index 038d84ea..21c9dab8 100644 --- a/paramiko/config.py +++ b/paramiko/config.py @@ -30,7 +30,7 @@ import socket SSH_PORT = 22 -class SSHConfig (object): +class SSHConfig(object): """ Representation of config information as stored in the format used by OpenSSH. Queries can be made via `lookup`. The format is described in @@ -41,7 +41,7 @@ class SSHConfig (object): .. versionadded:: 1.6 """ - SETTINGS_REGEX = re.compile(r'(\w+)(?:\s*=\s*|\s+)(.+)') + SETTINGS_REGEX = re.compile(r"(\w+)(?:\s*=\s*|\s+)(.+)") def __init__(self): """ @@ -55,12 +55,12 @@ class SSHConfig (object): :param file_obj: a file-like object to read the config file from """ - host = {"host": ['*'], "config": {}} + host = {"host": ["*"], "config": {}} for line in file_obj: # Strip any leading or trailing whitespace from the line. # Refer to https://github.com/paramiko/paramiko/issues/499 line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue match = re.match(self.SETTINGS_REGEX, line) @@ -69,17 +69,14 @@ class SSHConfig (object): key = match.group(1).lower() value = match.group(2) - if key == 'host': + if key == "host": self._config.append(host) - host = { - 'host': self._get_hosts(value), - 'config': {} - } - elif key == 'proxycommand' and value.lower() == 'none': + host = {"host": self._get_hosts(value), "config": {}} + elif key == "proxycommand" and value.lower() == "none": # Store 'none' as None; prior to 3.x, it will get stripped out # at the end (for compatibility with issue #415). After 3.x, it # will simply not get stripped, leaving a nice explicit marker. - host['config'][key] = None + host["config"][key] = None else: if value.startswith('"') and value.endswith('"'): value = value[1:-1] @@ -87,13 +84,13 @@ class SSHConfig (object): # identityfile, localforward, remoteforward keys are special # cases, since they are allowed to be specified multiple times # and they should be tried in order of specification. - if key in ['identityfile', 'localforward', 'remoteforward']: - if key in host['config']: - host['config'][key].append(value) + if key in ["identityfile", "localforward", "remoteforward"]: + if key in host["config"]: + host["config"][key].append(value) else: - host['config'][key] = [value] - elif key not in host['config']: - host['config'][key] = value + host["config"][key] = [value] + elif key not in host["config"]: + host["config"][key] = value self._config.append(host) def lookup(self, hostname): @@ -117,25 +114,26 @@ class SSHConfig (object): :param str hostname: the hostname to lookup """ matches = [ - config for config in self._config - if self._allowed(config['host'], hostname) + config + for config in self._config + if self._allowed(config["host"], hostname) ] ret = {} for match in matches: - for key, value in match['config'].items(): + for key, value in match["config"].items(): if key not in ret: # Create a copy of the original value, # else it will reference the original list # in self._config and update that value too # when the extend() is being called. ret[key] = value[:] if value is not None else value - elif key == 'identityfile': + elif key == "identityfile": ret[key].extend(value) ret = self._expand_variables(ret, hostname) # TODO: remove in 3.x re #670 - if 'proxycommand' in ret and ret['proxycommand'] is None: - del ret['proxycommand'] + if "proxycommand" in ret and ret["proxycommand"] is None: + del ret["proxycommand"] return ret def get_hostnames(self): @@ -145,13 +143,13 @@ class SSHConfig (object): """ hosts = set() for entry in self._config: - hosts.update(entry['host']) + hosts.update(entry["host"]) return hosts def _allowed(self, hosts, hostname): match = False for host in hosts: - if host.startswith('!') and fnmatch.fnmatch(hostname, host[1:]): + if host.startswith("!") and fnmatch.fnmatch(hostname, host[1:]): return False elif fnmatch.fnmatch(hostname, host): match = True @@ -169,52 +167,50 @@ class SSHConfig (object): :param str hostname: the hostname that the config belongs to """ - if 'hostname' in config: - config['hostname'] = config['hostname'].replace('%h', hostname) + if "hostname" in config: + config["hostname"] = config["hostname"].replace("%h", hostname) else: - config['hostname'] = hostname + config["hostname"] = hostname - if 'port' in config: - port = config['port'] + if "port" in config: + port = config["port"] else: port = SSH_PORT - user = os.getenv('USER') - if 'user' in config: - remoteuser = config['user'] + user = os.getenv("USER") + if "user" in config: + remoteuser = config["user"] else: remoteuser = user - host = socket.gethostname().split('.')[0] + host = socket.gethostname().split(".")[0] fqdn = LazyFqdn(config, host) - homedir = os.path.expanduser('~') - replacements = {'controlpath': - [ - ('%h', config['hostname']), - ('%l', fqdn), - ('%L', host), - ('%n', hostname), - ('%p', port), - ('%r', remoteuser), - ('%u', user) - ], - 'identityfile': - [ - ('~', homedir), - ('%d', homedir), - ('%h', config['hostname']), - ('%l', fqdn), - ('%u', user), - ('%r', remoteuser) - ], - 'proxycommand': - [ - ('~', homedir), - ('%h', config['hostname']), - ('%p', port), - ('%r', remoteuser) - ] - } + homedir = os.path.expanduser("~") + replacements = { + "controlpath": [ + ("%h", config["hostname"]), + ("%l", fqdn), + ("%L", host), + ("%n", hostname), + ("%p", port), + ("%r", remoteuser), + ("%u", user), + ], + "identityfile": [ + ("~", homedir), + ("%d", homedir), + ("%h", config["hostname"]), + ("%l", fqdn), + ("%u", user), + ("%r", remoteuser), + ], + "proxycommand": [ + ("~", homedir), + ("%h", config["hostname"]), + ("%p", port), + ("%r", remoteuser), + ], + } for k in config: if config[k] is None: @@ -265,11 +261,11 @@ class LazyFqdn(object): # Handle specific option fqdn = None - address_family = self.config.get('addressfamily', 'any').lower() - if address_family != 'any': + address_family = self.config.get("addressfamily", "any").lower() + if address_family != "any": try: family = socket.AF_INET6 - if address_family == 'inet': + if address_family == "inet": socket.AF_INET results = socket.getaddrinfo( self.host, @@ -277,11 +273,11 @@ class LazyFqdn(object): family, socket.SOCK_DGRAM, socket.IPPROTO_IP, - socket.AI_CANONNAME + socket.AI_CANONNAME, ) for res in results: af, socktype, proto, canonname, sa = res - if canonname and '.' in canonname: + if canonname and "." in canonname: fqdn = canonname break # giaerror -> socket.getaddrinfo() can't resolve self.host diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py index ac1d4c2e..fc9e7b54 100644 --- a/paramiko/dsskey.py +++ b/paramiko/dsskey.py @@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives.asymmetric.utils import ( - decode_dss_signature, encode_dss_signature + decode_dss_signature, + encode_dss_signature, ) from paramiko import util @@ -42,8 +43,15 @@ class DSSKey(PKey): data. """ - def __init__(self, msg=None, data=None, filename=None, password=None, - vals=None, file_obj=None): + def __init__( + self, + msg=None, + data=None, + filename=None, + password=None, + vals=None, + file_obj=None, + ): self.p = None self.q = None self.g = None @@ -63,8 +71,8 @@ class DSSKey(PKey): else: self._check_type_and_load_cert( msg=msg, - key_type='ssh-dss', - cert_type='ssh-dss-cert-v01@openssh.com', + key_type="ssh-dss", + cert_type="ssh-dss-cert-v01@openssh.com", ) self.p = msg.get_mpint() self.q = msg.get_mpint() @@ -74,7 +82,7 @@ class DSSKey(PKey): def asbytes(self): m = Message() - m.add_string('ssh-dss') + m.add_string("ssh-dss") m.add_mpint(self.p) m.add_mpint(self.q) m.add_mpint(self.g) @@ -88,7 +96,7 @@ class DSSKey(PKey): return hash((self.get_name(), self.p, self.q, self.g, self.y)) def get_name(self): - return 'ssh-dss' + return "ssh-dss" def get_bits(self): return self.size @@ -102,17 +110,17 @@ class DSSKey(PKey): public_numbers=dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ) - ).private_key(backend=default_backend()) + p=self.p, q=self.q, g=self.g + ), + ), + ).private_key( + backend=default_backend() + ) sig = key.sign(data, hashes.SHA1()) r, s = decode_dss_signature(sig) m = Message() - m.add_string('ssh-dss') + m.add_string("ssh-dss") # apparently, in rare cases, r or s may be shorter than 20 bytes! rstr = util.deflate_long(r, 0) sstr = util.deflate_long(s, 0) @@ -129,7 +137,7 @@ class DSSKey(PKey): sig = msg.asbytes() else: kind = msg.get_text() - if kind != 'ssh-dss': + if kind != "ssh-dss": return 0 sig = msg.get_binary() @@ -142,11 +150,11 @@ class DSSKey(PKey): key = dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ).public_key(backend=default_backend()) + p=self.p, q=self.q, g=self.g + ), + ).public_key( + backend=default_backend() + ) try: key.verify(signature, data, hashes.SHA1()) except InvalidSignature: @@ -160,18 +168,18 @@ class DSSKey(PKey): public_numbers=dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ) - ).private_key(backend=default_backend()) + p=self.p, q=self.q, g=self.g + ), + ), + ).private_key( + backend=default_backend() + ) self._write_private_key_file( filename, key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) def write_private_key(self, file_obj, password=None): @@ -180,18 +188,18 @@ class DSSKey(PKey): public_numbers=dsa.DSAPublicNumbers( y=self.y, parameter_numbers=dsa.DSAParameterNumbers( - p=self.p, - q=self.q, - g=self.g - ) - ) - ).private_key(backend=default_backend()) + p=self.p, q=self.q, g=self.g + ), + ), + ).private_key( + backend=default_backend() + ) self._write_private_key( file_obj, key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) @staticmethod @@ -207,23 +215,25 @@ class DSSKey(PKey): numbers = dsa.generate_private_key( bits, backend=default_backend() ).private_numbers() - key = DSSKey(vals=( - numbers.public_numbers.parameter_numbers.p, - numbers.public_numbers.parameter_numbers.q, - numbers.public_numbers.parameter_numbers.g, - numbers.public_numbers.y - )) + key = DSSKey( + vals=( + numbers.public_numbers.parameter_numbers.p, + numbers.public_numbers.parameter_numbers.q, + numbers.public_numbers.parameter_numbers.g, + numbers.public_numbers.y, + ) + ) key.x = numbers.x return key # ...internals... def _from_private_key_file(self, filename, password): - data = self._read_private_key_file('DSA', filename, password) + data = self._read_private_key_file("DSA", filename, password) self._decode_key(data) def _from_private_key(self, file_obj, password): - data = self._read_private_key('DSA', file_obj, password) + data = self._read_private_key("DSA", file_obj, password) self._decode_key(data) def _decode_key(self, data): @@ -232,14 +242,11 @@ class DSSKey(PKey): try: keylist = BER(data).decode() except BERException as e: - raise SSHException('Unable to parse key file: ' + str(e)) - if ( - type(keylist) is not list or - len(keylist) < 6 or - keylist[0] != 0 - ): + raise SSHException("Unable to parse key file: " + str(e)) + if type(keylist) is not list or len(keylist) < 6 or keylist[0] != 0: raise SSHException( - 'not a valid DSA private key file (bad ber encoding)') + "not a valid DSA private key file (bad ber encoding)" + ) self.p = keylist[1] self.q = keylist[2] self.g = keylist[3] diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py index 92e01a75..4b7984ca 100644 --- a/paramiko/ecdsakey.py +++ b/paramiko/ecdsakey.py @@ -25,7 +25,8 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric.utils import ( - decode_dss_signature, encode_dss_signature + decode_dss_signature, + encode_dss_signature, ) from paramiko.common import four_byte @@ -43,6 +44,7 @@ class _ECDSACurve(object): the proper hash function. Also grabs the proper curve from the 'ecdsa' package. """ + def __init__(self, curve_class, nist_name): self.nist_name = nist_name self.key_length = curve_class.key_size @@ -67,6 +69,7 @@ class _ECDSACurveSet(object): format identifier. The two ways in which ECDSAKey needs to be able to look up curves. """ + def __init__(self, ecdsa_curves): self.ecdsa_curves = ecdsa_curves @@ -95,14 +98,24 @@ class ECDSAKey(PKey): data. """ - _ECDSA_CURVES = _ECDSACurveSet([ - _ECDSACurve(ec.SECP256R1, 'nistp256'), - _ECDSACurve(ec.SECP384R1, 'nistp384'), - _ECDSACurve(ec.SECP521R1, 'nistp521'), - ]) - - def __init__(self, msg=None, data=None, filename=None, password=None, - vals=None, file_obj=None, validate_point=True): + _ECDSA_CURVES = _ECDSACurveSet( + [ + _ECDSACurve(ec.SECP256R1, "nistp256"), + _ECDSACurve(ec.SECP384R1, "nistp384"), + _ECDSACurve(ec.SECP521R1, "nistp521"), + ] + ) + + def __init__( + self, + msg=None, + data=None, + filename=None, + password=None, + vals=None, + file_obj=None, + validate_point=True, + ): self.verifying_key = None self.signing_key = None self.public_blob = None @@ -126,7 +139,7 @@ class ECDSAKey(PKey): # identifier, so strip out any cert business. (NOTE: could push # that into _ECDSACurveSet.get_by_key_format_identifier(), but it # feels more correct to do it here?) - suffix = '-cert-v01@openssh.com' + suffix = "-cert-v01@openssh.com" if key_type.endswith(suffix): key_type = key_type[:-len(suffix)] self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier( @@ -134,13 +147,10 @@ class ECDSAKey(PKey): ) key_types = self._ECDSA_CURVES.get_key_format_identifier_list() cert_types = [ - '{}-cert-v01@openssh.com'.format(x) - for x in key_types + "{}-cert-v01@openssh.com".format(x) for x in key_types ] self._check_type_and_load_cert( - msg=msg, - key_type=key_types, - cert_type=cert_types, + msg=msg, key_type=key_types, cert_type=cert_types ) curvename = msg.get_text() if curvename != self.ecdsa_curve.nist_name: @@ -172,10 +182,10 @@ class ECDSAKey(PKey): key_size_bytes = (key.curve.key_size + 7) // 8 x_bytes = deflate_long(numbers.x, add_sign_padding=False) - x_bytes = b'\x00' * (key_size_bytes - len(x_bytes)) + x_bytes + x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes y_bytes = deflate_long(numbers.y, add_sign_padding=False) - y_bytes = b'\x00' * (key_size_bytes - len(y_bytes)) + y_bytes + y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes point_str = four_byte + x_bytes + y_bytes m.add_string(point_str) @@ -185,8 +195,13 @@ class ECDSAKey(PKey): return self.asbytes() def __hash__(self): - return hash((self.get_name(), self.verifying_key.public_numbers().x, - self.verifying_key.public_numbers().y)) + return hash( + ( + self.get_name(), + self.verifying_key.public_numbers().x, + self.verifying_key.public_numbers().y, + ) + ) def get_name(self): return self.ecdsa_curve.key_format_identifier @@ -228,7 +243,7 @@ class ECDSAKey(PKey): filename, self.signing_key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) def write_private_key(self, file_obj, password=None): @@ -236,7 +251,7 @@ class ECDSAKey(PKey): file_obj, self.signing_key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) @classmethod @@ -260,11 +275,11 @@ class ECDSAKey(PKey): # ...internals... def _from_private_key_file(self, filename, password): - data = self._read_private_key_file('EC', filename, password) + data = self._read_private_key_file("EC", filename, password) self._decode_key(data) def _from_private_key(self, file_obj, password): - data = self._read_private_key('EC', file_obj, password) + data = self._read_private_key("EC", file_obj, password) self._decode_key(data) def _decode_key(self, data): diff --git a/paramiko/ed25519key.py b/paramiko/ed25519key.py index 8ad71d08..c8f6dd34 100644 --- a/paramiko/ed25519key.py +++ b/paramiko/ed25519key.py @@ -56,8 +56,10 @@ class Ed25519Key(PKey): .. versionchanged:: 2.3 Added a ``file_obj`` parameter to match other key classes. """ - def __init__(self, msg=None, data=None, filename=None, password=None, - file_obj=None): + + def __init__( + self, msg=None, data=None, filename=None, password=None, file_obj=None + ): self.public_blob = None verifying_key = signing_key = None if msg is None and data is not None: @@ -86,6 +88,7 @@ class Ed25519Key(PKey): def _parse_signing_key_data(self, data, password): from paramiko.transport import Transport + # We may eventually want this to be usable for other key types, as # OpenSSH moves to it, but for now this is just for Ed25519 keys. # This format is described here: @@ -144,7 +147,7 @@ class Ed25519Key(PKey): decryptor = Cipher( cipher["class"](key[:cipher["key-size"]]), cipher["mode"](key[cipher["key-size"]:]), - backend=default_backend() + backend=default_backend(), ).decryptor() private_data = ( decryptor.update(private_ciphertext) + decryptor.finalize() @@ -166,8 +169,10 @@ class Ed25519Key(PKey): signing_key = nacl.signing.SigningKey(key_data[:32]) # Verify that all the public keys are the same... assert ( - signing_key.verify_key.encode() == public == public_keys[i] == - key_data[32:] + signing_key.verify_key.encode() + == public + == public_keys[i] + == key_data[32:] ) signing_keys.append(signing_key) # Comment, ignore. diff --git a/paramiko/file.py b/paramiko/file.py index df9cdac7..62686b53 100644 --- a/paramiko/file.py +++ b/paramiko/file.py @@ -16,14 +16,18 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. from paramiko.common import ( - linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value, + linefeed_byte_value, + crlf, + cr_byte, + linefeed_byte, + cr_byte_value, ) from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type from paramiko.util import ClosingContextManager -class BufferedFile (ClosingContextManager): +class BufferedFile(ClosingContextManager): """ Reusable base class to implement Python-style file buffering around a simpler stream. @@ -70,7 +74,7 @@ class BufferedFile (ClosingContextManager): :raises: ``ValueError`` -- if the file is closed. """ if self._closed: - raise ValueError('I/O operation on closed file') + raise ValueError("I/O operation on closed file") return self def close(self): @@ -90,6 +94,7 @@ class BufferedFile (ClosingContextManager): return if PY2: + def next(self): """ Returns the next line from the input, or raises @@ -104,7 +109,9 @@ class BufferedFile (ClosingContextManager): if not line: raise StopIteration return line + else: + def __next__(self): """ Returns the next line from the input, or raises ``StopIteration`` @@ -180,9 +187,9 @@ class BufferedFile (ClosingContextManager): encountered immediately """ if self._closed: - raise IOError('File is closed') + raise IOError("File is closed") if not (self._flags & self.FLAG_READ): - raise IOError('File is not open for reading') + raise IOError("File is not open for reading") if (size is None) or (size < 0): # go for broke result = self._rbuffer @@ -245,16 +252,16 @@ class BufferedFile (ClosingContextManager): """ # it's almost silly how complex this function is. if self._closed: - raise IOError('File is closed') + raise IOError("File is closed") if not (self._flags & self.FLAG_READ): - raise IOError('File not open for reading') + raise IOError("File not open for reading") line = self._rbuffer truncated = False while True: if ( - self._at_trailing_cr and - self._flags & self.FLAG_UNIVERSAL_NEWLINE and - len(line) > 0 + self._at_trailing_cr + and self._flags & self.FLAG_UNIVERSAL_NEWLINE + and len(line) > 0 ): # edge case: the newline may be '\r\n' and we may have read # only the first '\r' last time. @@ -277,10 +284,10 @@ class BufferedFile (ClosingContextManager): else: n = self._bufsize if ( - linefeed_byte in line or - ( - self._flags & self.FLAG_UNIVERSAL_NEWLINE and - cr_byte in line + linefeed_byte in line + or ( + self._flags & self.FLAG_UNIVERSAL_NEWLINE + and cr_byte in line ) ): break @@ -306,9 +313,9 @@ class BufferedFile (ClosingContextManager): return line if self._flags & self.FLAG_BINARY else u(line) xpos = pos + 1 if ( - line[pos] == cr_byte_value and - xpos < len(line) and - line[xpos] == linefeed_byte_value + line[pos] == cr_byte_value + and xpos < len(line) + and line[xpos] == linefeed_byte_value ): xpos += 1 # if the string was truncated, _rbuffer needs to have the string after @@ -370,7 +377,7 @@ class BufferedFile (ClosingContextManager): :raises: ``IOError`` -- if the file doesn't support random access. """ - raise IOError('File does not support seeking.') + raise IOError("File does not support seeking.") def tell(self): """ @@ -393,11 +400,11 @@ class BufferedFile (ClosingContextManager): """ if isinstance(data, text_type): # Accept text and encode as utf-8 for compatibility only. - data = data.encode('utf-8') + data = data.encode("utf-8") if self._closed: - raise IOError('File is closed') + raise IOError("File is closed") if not (self._flags & self.FLAG_WRITE): - raise IOError('File not open for writing') + raise IOError("File not open for writing") if not (self._flags & self.FLAG_BUFFERED): self._write_all(data) return @@ -457,7 +464,7 @@ class BufferedFile (ClosingContextManager): (subclass override) Write data into the stream. """ - raise IOError('write not implemented') + raise IOError("write not implemented") def _get_size(self): """ @@ -472,7 +479,7 @@ class BufferedFile (ClosingContextManager): # ...internals... - def _set_mode(self, mode='r', bufsize=-1): + def _set_mode(self, mode="r", bufsize=-1): """ Subclasses call this method to initialize the BufferedFile. """ @@ -495,17 +502,17 @@ class BufferedFile (ClosingContextManager): # unbuffered self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED) - if ('r' in mode) or ('+' in mode): + if ("r" in mode) or ("+" in mode): self._flags |= self.FLAG_READ - if ('w' in mode) or ('+' in mode): + if ("w" in mode) or ("+" in mode): self._flags |= self.FLAG_WRITE - if 'a' in mode: + if "a" in mode: self._flags |= self.FLAG_WRITE | self.FLAG_APPEND self._size = self._get_size() self._pos = self._realpos = self._size - if 'b' in mode: + if "b" in mode: self._flags |= self.FLAG_BINARY - if 'U' in mode: + if "U" in mode: self._flags |= self.FLAG_UNIVERSAL_NEWLINE # built-in file objects have this attribute to store which kinds of # line terminations they've seen: @@ -535,8 +542,7 @@ class BufferedFile (ClosingContextManager): if self.newlines is None: self.newlines = newline elif ( - self.newlines != newline and - isinstance(self.newlines, bytes_types) + self.newlines != newline and isinstance(self.newlines, bytes_types) ): self.newlines = (self.newlines, newline) elif newline not in self.newlines: diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index ca185273..c4d611db 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -34,7 +34,7 @@ from paramiko.ed25519key import Ed25519Key from paramiko.ssh_exception import SSHException -class HostKeys (MutableMapping): +class HostKeys(MutableMapping): """ Representation of an OpenSSH-style "known hosts" file. Host keys can be read from one or more files, and then individual hosts can be looked up to @@ -88,10 +88,10 @@ class HostKeys (MutableMapping): :raises: ``IOError`` -- if there was an error reading the file """ - with open(filename, 'r') as f: + with open(filename, "r") as f: for lineno, line in enumerate(f, 1): line = line.strip() - if (len(line) == 0) or (line[0] == '#'): + if (len(line) == 0) or (line[0] == "#"): continue try: e = HostKeyEntry.from_line(line, lineno) @@ -118,7 +118,7 @@ class HostKeys (MutableMapping): .. versionadded:: 1.6.1 """ - with open(filename, 'w') as f: + with open(filename, "w") as f: for e in self._entries: line = e.to_line() if line: @@ -134,7 +134,9 @@ class HostKeys (MutableMapping): :return: dict of `str` -> `.PKey` keys associated with this host (or ``None``) """ - class SubDict (MutableMapping): + + class SubDict(MutableMapping): + def __init__(self, hostname, entries, hostkeys): self._hostname = hostname self._entries = entries @@ -176,7 +178,8 @@ class HostKeys (MutableMapping): def keys(self): return [ - e.key.get_name() for e in self._entries + e.key.get_name() + for e in self._entries if e.key is not None ] @@ -196,10 +199,10 @@ class HostKeys (MutableMapping): """ for h in entry.hostnames: if ( - h == hostname or - h.startswith('|1|') and - not hostname.startswith('|1|') and - constant_time_bytes_eq(self.hash_host(hostname, h), h) + h == hostname + or h.startswith("|1|") + and not hostname.startswith("|1|") + and constant_time_bytes_eq(self.hash_host(hostname, h), h) ): return True return False @@ -295,16 +298,17 @@ class HostKeys (MutableMapping): if salt is None: salt = os.urandom(sha1().digest_size) else: - if salt.startswith('|1|'): - salt = salt.split('|')[2] + if salt.startswith("|1|"): + salt = salt.split("|")[2] salt = decodebytes(b(salt)) assert len(salt) == sha1().digest_size hmac = HMAC(salt, b(hostname), sha1).digest() - hostkey = '|1|{}|{}'.format(u(encodebytes(salt)), u(encodebytes(hmac))) - return hostkey.replace('\n', '') + hostkey = "|1|{}|{}".format(u(encodebytes(salt)), u(encodebytes(hmac))) + return hostkey.replace("\n", "") class InvalidHostKey(Exception): + def __init__(self, line, exc): self.line = line self.exc = exc @@ -334,8 +338,8 @@ class HostKeyEntry: :param str line: a line from an OpenSSH known_hosts file """ - log = get_logger('paramiko.hostkeys') - fields = line.split(' ') + log = get_logger("paramiko.hostkeys") + fields = line.split(" ") if len(fields) < 3: # Bad number of fields msg = "Not enough fields found in known_hosts in line {} ({!r})" @@ -344,19 +348,19 @@ class HostKeyEntry: fields = fields[:3] names, keytype, key = fields - names = names.split(',') + names = names.split(",") # Decide what kind of key we're looking at and create an object # to hold it accordingly. try: key = b(key) - if keytype == 'ssh-rsa': + if keytype == "ssh-rsa": key = RSAKey(data=decodebytes(key)) - elif keytype == 'ssh-dss': + elif keytype == "ssh-dss": key = DSSKey(data=decodebytes(key)) elif keytype in ECDSAKey.supported_key_format_identifiers(): key = ECDSAKey(data=decodebytes(key), validate_point=False) - elif keytype == 'ssh-ed25519': + elif keytype == "ssh-ed25519": key = Ed25519Key(data=decodebytes(key)) else: log.info("Unable to handle key of type {}".format(keytype)) @@ -374,12 +378,12 @@ class HostKeyEntry: included. """ if self.valid: - return '{} {} {}\n'.format( - ','.join(self.hostnames), + return "{} {} {}\n".format( + ",".join(self.hostnames), self.key.get_name(), self.key.get_base64(), ) return None def __repr__(self): - return '<HostKeyEntry {!r}: {!r}>'.format(self.hostnames, self.key) + return "<HostKeyEntry {!r}: {!r}>".format(self.hostnames, self.key) diff --git a/paramiko/kex_ecdh_nist.py b/paramiko/kex_ecdh_nist.py index 4e8ff35d..ca32404f 100644 --- a/paramiko/kex_ecdh_nist.py +++ b/paramiko/kex_ecdh_nist.py @@ -46,7 +46,7 @@ class KexNistp256(): elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY): return self._parse_kexecdh_reply(m) raise SSHException( - 'KexECDH asked to handle packet type {:d}'.format(ptype) + "KexECDH asked to handle packet type {:d}".format(ptype) ) def _generate_key_pair(self): @@ -66,8 +66,12 @@ class KexNistp256(): K = long(hexlify(K), 16) # compute exchange hash hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + ) hm.add_string(K_S) hm.add_string(Q_C_bytes) # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion @@ -96,8 +100,12 @@ class KexNistp256(): K = long(hexlify(K), 16) # compute exchange hash and verify signature hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + ) hm.add_string(K_S) # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion hm.add_string(self.Q_C.public_numbers().encode_point()) diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py index 44030569..6b24c1ff 100644 --- a/paramiko/kex_gex.py +++ b/paramiko/kex_gex.py @@ -32,17 +32,30 @@ from paramiko.py3compat import byte_chr, byte_ord, byte_mask from paramiko.ssh_exception import SSHException -_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ - _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) +( + _MSG_KEXDH_GEX_REQUEST_OLD, + _MSG_KEXDH_GEX_GROUP, + _MSG_KEXDH_GEX_INIT, + _MSG_KEXDH_GEX_REPLY, + _MSG_KEXDH_GEX_REQUEST, +) = range( + 30, 35 +) -c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \ - c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = \ - [byte_chr(c) for c in range(30, 35)] +( + c_MSG_KEXDH_GEX_REQUEST_OLD, + c_MSG_KEXDH_GEX_GROUP, + c_MSG_KEXDH_GEX_INIT, + c_MSG_KEXDH_GEX_REPLY, + c_MSG_KEXDH_GEX_REQUEST, +) = [ + byte_chr(c) for c in range(30, 35) +] -class KexGex (object): +class KexGex(object): - name = 'diffie-hellman-group-exchange-sha1' + name = "diffie-hellman-group-exchange-sha1" min_bits = 1024 max_bits = 8192 preferred_bits = 2048 @@ -61,7 +74,8 @@ class KexGex (object): def start_kex(self, _test_old_style=False): if self.transport.server_mode: self.transport._expect_packet( - _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD) + _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD + ) return # request a bit range: we accept (min_bits) to (max_bits), but prefer # (preferred_bits). according to the spec, we shouldn't pull the @@ -137,13 +151,12 @@ class KexGex (object): # generate prime pack = self.transport._get_modulus_pack() if pack is None: - raise SSHException( - 'Can\'t do server-side gex with no modulus pack') + raise SSHException("Can't do server-side gex with no modulus pack") self.transport._log( DEBUG, - 'Picking p ({} <= {} <= {} bits)'.format( - minbits, preferredbits, maxbits, - ) + "Picking p ({} <= {} <= {} bits)".format( + minbits, preferredbits, maxbits + ), ) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) m = Message() @@ -165,13 +178,13 @@ class KexGex (object): # generate prime pack = self.transport._get_modulus_pack() if pack is None: - raise SSHException( - 'Can\'t do server-side gex with no modulus pack') + raise SSHException("Can't do server-side gex with no modulus pack") self.transport._log( - DEBUG, 'Picking p (~ {} bits)'.format(self.preferred_bits) + DEBUG, "Picking p (~ {} bits)".format(self.preferred_bits) ) self.g, self.p = pack.get_modulus( - self.min_bits, self.preferred_bits, self.max_bits) + self.min_bits, self.preferred_bits, self.max_bits + ) m = Message() m.add_byte(c_MSG_KEXDH_GEX_GROUP) m.add_mpint(self.p) @@ -187,9 +200,10 @@ class KexGex (object): bitlen = util.bit_length(self.p) if (bitlen < 1024) or (bitlen > 8192): raise SSHException( - 'Server-generated gex p (don\'t ask) is out of range ' - '({} bits)'.format(bitlen)) - self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen)) + "Server-generated gex p (don't ask) is out of range " + "({} bits)".format(bitlen) + ) + self.transport._log(DEBUG, "Got server p ({} bits)".format(bitlen)) self._generate_x() # now compute e = g^x mod p self.e = pow(self.g, self.x, self.p) @@ -210,9 +224,13 @@ class KexGex (object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init, - key) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + key, + ) if not self.old_style: hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) @@ -246,9 +264,13 @@ class KexGex (object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init, - host_key) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + host_key, + ) if not self.old_style: hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) @@ -265,5 +287,5 @@ class KexGex (object): class KexGexSHA256(KexGex): - name = 'diffie-hellman-group-exchange-sha256' + name = "diffie-hellman-group-exchange-sha256" hash_algo = sha256 diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py index 1bebd375..904835d7 100644 --- a/paramiko/kex_group1.py +++ b/paramiko/kex_group1.py @@ -44,7 +44,7 @@ class KexGroup1(object): P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa G = 2 - name = 'diffie-hellman-group1-sha1' + name = "diffie-hellman-group1-sha1" hash_algo = sha1 def __init__(self, transport): @@ -88,8 +88,10 @@ class KexGroup1(object): while 1: x_bytes = os.urandom(128) x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:] - if (x_bytes[:8] != b7fffffffffffffff and - x_bytes[:8] != b0000000000000000): + if ( + x_bytes[:8] != b7fffffffffffffff + and x_bytes[:8] != b0000000000000000 + ): break self.x = util.inflate_long(x_bytes) @@ -104,8 +106,12 @@ class KexGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + ) hm.add_string(host_key) hm.add_mpint(self.e) hm.add_mpint(self.f) @@ -124,8 +130,12 @@ class KexGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + ) hm.add_string(key) hm.add_mpint(self.e) hm.add_mpint(self.f) diff --git a/paramiko/kex_group14.py b/paramiko/kex_group14.py index 22955e34..0df302e3 100644 --- a/paramiko/kex_group14.py +++ b/paramiko/kex_group14.py @@ -31,5 +31,5 @@ class KexGroup14(KexGroup1): P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa G = 2 - name = 'diffie-hellman-group14-sha1' + name = "diffie-hellman-group14-sha1" hash_algo = sha1 diff --git a/paramiko/kex_gss.py b/paramiko/kex_gss.py index e21620fe..5eaaa5d5 100644 --- a/paramiko/kex_gss.py +++ b/paramiko/kex_gss.py @@ -47,14 +47,26 @@ from paramiko.py3compat import byte_chr, byte_mask, byte_ord from paramiko.ssh_exception import SSHException -MSG_KEXGSS_INIT, MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_HOSTKEY,\ - MSG_KEXGSS_ERROR = range(30, 35) -MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP = range(40, 42) -c_MSG_KEXGSS_INIT, c_MSG_KEXGSS_CONTINUE, c_MSG_KEXGSS_COMPLETE,\ - c_MSG_KEXGSS_HOSTKEY, c_MSG_KEXGSS_ERROR = [ - byte_chr(c) for c in range(30, 35) - ] -c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP = [ +( + MSG_KEXGSS_INIT, + MSG_KEXGSS_CONTINUE, + MSG_KEXGSS_COMPLETE, + MSG_KEXGSS_HOSTKEY, + MSG_KEXGSS_ERROR, +) = range( + 30, 35 +) +(MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP) = range(40, 42) +( + c_MSG_KEXGSS_INIT, + c_MSG_KEXGSS_CONTINUE, + c_MSG_KEXGSS_COMPLETE, + c_MSG_KEXGSS_HOSTKEY, + c_MSG_KEXGSS_ERROR, +) = [ + byte_chr(c) for c in range(30, 35) +] +(c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP) = [ byte_chr(c) for c in range(40, 42) ] @@ -98,10 +110,12 @@ class KexGSSGroup1(object): m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host)) m.add_mpint(self.e) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_HOSTKEY, - MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_HOSTKEY, + MSG_KEXGSS_CONTINUE, + MSG_KEXGSS_COMPLETE, + MSG_KEXGSS_ERROR, + ) def parse_next(self, ptype, m): """ @@ -120,7 +134,7 @@ class KexGSSGroup1(object): return self._parse_kexgss_complete(m) elif ptype == MSG_KEXGSS_ERROR: return self._parse_kexgss_error(m) - msg = 'GSS KexGroup1 asked to handle packet type {:d}' + msg = "GSS KexGroup1 asked to handle packet type {:d}" raise SSHException(msg.format(ptype)) # ## internals... @@ -152,8 +166,7 @@ class KexGSSGroup1(object): self.transport.host_key = host_key sig = m.get_string() self.transport._verify_key(host_key, sig) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE) + self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE) def _parse_kexgss_continue(self, m): """ @@ -166,13 +179,14 @@ class KexGSSGroup1(object): srv_token = m.get_string() m = Message() m.add_byte(c_MSG_KEXGSS_CONTINUE) - m.add_string(self.kexgss.ssh_init_sec_context( - target=self.gss_host, recv_token=srv_token)) + m.add_string( + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) + ) self.transport.send_message(m) self.transport._expect_packet( - MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR ) else: pass @@ -200,8 +214,12 @@ class KexGSSGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + ) hm.add_string(self.transport.host_key.__str__()) hm.add_mpint(self.e) hm.add_mpint(self.f) @@ -209,8 +227,9 @@ class KexGSSGroup1(object): H = sha1(str(hm)).digest() self.transport._set_K_H(K, H) if srv_token is not None: - self.kexgss.ssh_init_sec_context(target=self.gss_host, - recv_token=srv_token) + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) self.kexgss.ssh_check_mic(mic_token, H) else: self.kexgss.ssh_check_mic(mic_token, H) @@ -234,20 +253,26 @@ class KexGSSGroup1(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || e || f || K) hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + ) hm.add_string(key) hm.add_mpint(self.e) hm.add_mpint(self.f) hm.add_mpint(K) H = sha1(hm.asbytes()).digest() self.transport._set_K_H(K, H) - srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host, - client_token) + srv_token = self.kexgss.ssh_accept_sec_context( + self.gss_host, client_token + ) m = Message() if self.kexgss._gss_srv_ctxt_status: - mic_token = self.kexgss.ssh_get_mic(self.transport.session_id, - gss_kex=True) + mic_token = self.kexgss.ssh_get_mic( + self.transport.session_id, gss_kex=True + ) m.add_byte(c_MSG_KEXGSS_COMPLETE) m.add_mpint(self.f) m.add_string(mic_token) @@ -263,9 +288,9 @@ class KexGSSGroup1(object): m.add_byte(c_MSG_KEXGSS_CONTINUE) m.add_string(srv_token) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR + ) def _parse_kexgss_error(self, m): """ @@ -281,12 +306,16 @@ class KexGSSGroup1(object): maj_status = m.get_int() min_status = m.get_int() err_msg = m.get_string() - m.get_string() # we don't care about the language! - raise SSHException("""GSS-API Error: + m.get_string() # we don't care about the language! + raise SSHException( + """GSS-API Error: Major Status: {} Minor Status: {} Error Message: {} -""".format(maj_status, min_status, err_msg)) +""".format( + maj_status, min_status, err_msg + ) + ) class KexGSSGroup14(KexGSSGroup1): @@ -362,7 +391,7 @@ class KexGSSGex(object): return self._parse_kexgss_complete(m) elif ptype == MSG_KEXGSS_ERROR: return self._parse_kexgss_error(m) - msg = 'KexGex asked to handle packet type {:d}' + msg = "KexGex asked to handle packet type {:d}" raise SSHException(msg.format(ptype)) # ## internals... @@ -414,13 +443,12 @@ class KexGSSGex(object): # generate prime pack = self.transport._get_modulus_pack() if pack is None: - raise SSHException( - 'Can\'t do server-side gex with no modulus pack') + raise SSHException("Can't do server-side gex with no modulus pack") self.transport._log( DEBUG, # noqa - 'Picking p ({} <= {} <= {} bits)'.format( - minbits, preferredbits, maxbits, - ) + "Picking p ({} <= {} <= {} bits)".format( + minbits, preferredbits, maxbits + ), ) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) m = Message() @@ -442,9 +470,12 @@ class KexGSSGex(object): bitlen = util.bit_length(self.p) if (bitlen < 1024) or (bitlen > 8192): raise SSHException( - 'Server-generated gex p (don\'t ask) is out of range ' - '({} bits)'.format(bitlen)) - self.transport._log(DEBUG, 'Got server p ({} bits)'.format(bitlen)) # noqa + "Server-generated gex p (don't ask) is out of range " + "({} bits)".format(bitlen) + ) + self.transport._log( + DEBUG, "Got server p ({} bits)".format(bitlen) + ) # noqa self._generate_x() # now compute e = g^x mod p self.e = pow(self.g, self.x, self.p) @@ -453,10 +484,12 @@ class KexGSSGex(object): m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host)) m.add_mpint(self.e) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_HOSTKEY, - MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_HOSTKEY, + MSG_KEXGSS_CONTINUE, + MSG_KEXGSS_COMPLETE, + MSG_KEXGSS_ERROR, + ) def _parse_kexgss_gex_init(self, m): """ @@ -476,9 +509,13 @@ class KexGSSGex(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.remote_version, self.transport.local_version, - self.transport.remote_kex_init, self.transport.local_kex_init, - key) + hm.add( + self.transport.remote_version, + self.transport.local_version, + self.transport.remote_kex_init, + self.transport.local_kex_init, + key, + ) hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) hm.add_int(self.max_bits) @@ -489,12 +526,14 @@ class KexGSSGex(object): hm.add_mpint(K) H = sha1(hm.asbytes()).digest() self.transport._set_K_H(K, H) - srv_token = self.kexgss.ssh_accept_sec_context(self.gss_host, - client_token) + srv_token = self.kexgss.ssh_accept_sec_context( + self.gss_host, client_token + ) m = Message() if self.kexgss._gss_srv_ctxt_status: - mic_token = self.kexgss.ssh_get_mic(self.transport.session_id, - gss_kex=True) + mic_token = self.kexgss.ssh_get_mic( + self.transport.session_id, gss_kex=True + ) m.add_byte(c_MSG_KEXGSS_COMPLETE) m.add_mpint(self.f) m.add_string(mic_token) @@ -510,9 +549,9 @@ class KexGSSGex(object): m.add_byte(c_MSG_KEXGSS_CONTINUE) m.add_string(srv_token) self.transport._send_message(m) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR + ) def _parse_kexgss_hostkey(self, m): """ @@ -525,8 +564,7 @@ class KexGSSGex(object): self.transport.host_key = host_key sig = m.get_string() self.transport._verify_key(host_key, sig) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE) + self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE) def _parse_kexgss_continue(self, m): """ @@ -538,12 +576,15 @@ class KexGSSGex(object): srv_token = m.get_string() m = Message() m.add_byte(c_MSG_KEXGSS_CONTINUE) - m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host, - recv_token=srv_token)) + m.add_string( + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) + ) self.transport.send_message(m) - self.transport._expect_packet(MSG_KEXGSS_CONTINUE, - MSG_KEXGSS_COMPLETE, - MSG_KEXGSS_ERROR) + self.transport._expect_packet( + MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR + ) else: pass @@ -568,9 +609,13 @@ class KexGSSGex(object): # okay, build up the hash H of # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa hm = Message() - hm.add(self.transport.local_version, self.transport.remote_version, - self.transport.local_kex_init, self.transport.remote_kex_init, - self.transport.host_key.__str__()) + hm.add( + self.transport.local_version, + self.transport.remote_version, + self.transport.local_kex_init, + self.transport.remote_kex_init, + self.transport.host_key.__str__(), + ) if not self.old_style: hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) @@ -584,8 +629,9 @@ class KexGSSGex(object): H = sha1(hm.asbytes()).digest() self.transport._set_K_H(K, H) if srv_token is not None: - self.kexgss.ssh_init_sec_context(target=self.gss_host, - recv_token=srv_token) + self.kexgss.ssh_init_sec_context( + target=self.gss_host, recv_token=srv_token + ) self.kexgss.ssh_check_mic(mic_token, H) else: self.kexgss.ssh_check_mic(mic_token, H) @@ -606,12 +652,16 @@ class KexGSSGex(object): maj_status = m.get_int() min_status = m.get_int() err_msg = m.get_string() - m.get_string() # we don't care about the language (lang_tag)! - raise SSHException("""GSS-API Error: + m.get_string() # we don't care about the language (lang_tag)! + raise SSHException( + """GSS-API Error: Major Status: {} Minor Status: {} Error Message: {} -""".format(maj_status, min_status, err_msg)) +""".format( + maj_status, min_status, err_msg + ) + ) class NullHostKey(object): @@ -620,6 +670,7 @@ class NullHostKey(object): in `RFC 4462 Section 5 <https://tools.ietf.org/html/rfc4462.html#section-5>`_ """ + def __init__(self): self.key = "" diff --git a/paramiko/message.py b/paramiko/message.py index 9af841da..dead3508 100644 --- a/paramiko/message.py +++ b/paramiko/message.py @@ -27,7 +27,7 @@ from paramiko.common import zero_byte, max_byte, one_byte, asbytes from paramiko.py3compat import long, BytesIO, u, integer_types -class Message (object): +class Message(object): """ An SSH2 message is a stream of bytes that encodes some combination of strings, integers, bools, and infinite-precision integers (known in Python @@ -63,7 +63,7 @@ class Message (object): """ Returns a string representation of this object, for debugging. """ - return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')' + return "paramiko.Message(" + repr(self.packet.getvalue()) + ")" def asbytes(self): """ @@ -139,13 +139,13 @@ class Message (object): if byte == max_byte: return util.inflate_long(self.get_binary()) byte += self.get_bytes(3) - return struct.unpack('>I', byte)[0] + return struct.unpack(">I", byte)[0] def get_int(self): """ Fetch an int from the stream. """ - return struct.unpack('>I', self.get_bytes(4))[0] + return struct.unpack(">I", self.get_bytes(4))[0] def get_int64(self): """ @@ -153,7 +153,7 @@ class Message (object): :return: a 64-bit unsigned integer (`long`). """ - return struct.unpack('>Q', self.get_bytes(8))[0] + return struct.unpack(">Q", self.get_bytes(8))[0] def get_mpint(self): """ @@ -191,7 +191,7 @@ class Message (object): These are trivially encoded as comma-separated values in a string. """ - return self.get_text().split(',') + return self.get_text().split(",") def add_bytes(self, b): """ @@ -229,7 +229,7 @@ class Message (object): :param int n: integer to add """ - self.packet.write(struct.pack('>I', n)) + self.packet.write(struct.pack(">I", n)) return self def add_adaptive_int(self, n): @@ -242,7 +242,7 @@ class Message (object): self.packet.write(max_byte) self.add_string(util.deflate_long(n)) else: - self.packet.write(struct.pack('>I', n)) + self.packet.write(struct.pack(">I", n)) return self def add_int64(self, n): @@ -251,7 +251,7 @@ class Message (object): :param long n: long int to add """ - self.packet.write(struct.pack('>Q', n)) + self.packet.write(struct.pack(">Q", n)) return self def add_mpint(self, z): @@ -283,7 +283,7 @@ class Message (object): :param l: list of strings to add """ - self.add_string(','.join(l)) + self.add_string(",".join(l)) return self def _add(self, i): diff --git a/paramiko/packet.py b/paramiko/packet.py index 2a1e91e2..20b37ac4 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -30,7 +30,12 @@ from hmac import HMAC from paramiko import util from paramiko.common import ( - linefeed_byte, cr_byte_value, asbytes, MSG_NAMES, DEBUG, xffffffff, + linefeed_byte, + cr_byte_value, + asbytes, + MSG_NAMES, + DEBUG, + xffffffff, zero_byte, ) from paramiko.py3compat import u, byte_ord @@ -42,7 +47,7 @@ def compute_hmac(key, message, digest_class): return HMAC(key, message, digest_class).digest() -class NeedRekeyException (Exception): +class NeedRekeyException(Exception): """ Exception indicating a rekey is needed. """ @@ -56,7 +61,7 @@ def first_arg(e): return arg -class Packetizer (object): +class Packetizer(object): """ Implementation of the base SSH packet protocol. """ @@ -128,8 +133,15 @@ class Packetizer (object): """ self.__logger = log - def set_outbound_cipher(self, block_engine, block_size, mac_engine, - mac_size, mac_key, sdctr=False): + def set_outbound_cipher( + self, + block_engine, + block_size, + mac_engine, + mac_size, + mac_key, + sdctr=False, + ): """ Switch outbound data cipher. """ @@ -149,7 +161,8 @@ class Packetizer (object): self.__need_rekey = False def set_inbound_cipher( - self, block_engine, block_size, mac_engine, mac_size, mac_key): + self, block_engine, block_size, mac_engine, mac_size, mac_key + ): """ Switch inbound data cipher. """ @@ -368,7 +381,7 @@ class Packetizer (object): if cmd in MSG_NAMES: cmd_name = MSG_NAMES[cmd] else: - cmd_name = '${:x}'.format(cmd) + cmd_name = "${:x}".format(cmd) orig_len = len(data) self.__write_lock.acquire() try: @@ -378,9 +391,9 @@ class Packetizer (object): if self.__dump_packets: self._log( DEBUG, - 'Write packet <{}>, length {}'.format(cmd_name, orig_len) + "Write packet <{}>, length {}".format(cmd_name, orig_len), ) - self._log(DEBUG, util.format_binary(packet, 'OUT: ')) + self._log(DEBUG, util.format_binary(packet, "OUT: ")) if self.__block_engine_out is not None: out = self.__block_engine_out.update(packet) else: @@ -388,27 +401,30 @@ class Packetizer (object): # + mac if self.__block_engine_out is not None: payload = struct.pack( - '>I', self.__sequence_number_out) + packet + ">I", self.__sequence_number_out + ) + packet out += compute_hmac( - self.__mac_key_out, - payload, - self.__mac_engine_out)[:self.__mac_size_out] - self.__sequence_number_out = \ - (self.__sequence_number_out + 1) & xffffffff + self.__mac_key_out, payload, self.__mac_engine_out + )[ + :self.__mac_size_out + ] + self.__sequence_number_out = ( + self.__sequence_number_out + 1 + ) & xffffffff self.write_all(out) self.__sent_bytes += len(out) self.__sent_packets += 1 sent_too_much = ( - self.__sent_packets >= self.REKEY_PACKETS or - self.__sent_bytes >= self.REKEY_BYTES + self.__sent_packets >= self.REKEY_PACKETS + or self.__sent_bytes >= self.REKEY_BYTES ) if sent_too_much and not self.__need_rekey: # only ask once for rekeying msg = "Rekeying (hit {} packets, {} bytes sent)" - self._log(DEBUG, msg.format( - self.__sent_packets, self.__sent_bytes, - )) + self._log( + DEBUG, msg.format(self.__sent_packets, self.__sent_bytes) + ) self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() @@ -427,41 +443,43 @@ class Packetizer (object): if self.__block_engine_in is not None: header = self.__block_engine_in.update(header) if self.__dump_packets: - self._log(DEBUG, util.format_binary(header, 'IN: ')) - packet_size = struct.unpack('>I', header[:4])[0] + self._log(DEBUG, util.format_binary(header, "IN: ")) + packet_size = struct.unpack(">I", header[:4])[0] # leftover contains decrypted bytes from the first block (after the # length field) leftover = header[4:] if (packet_size - len(leftover)) % self.__block_size_in != 0: - raise SSHException('Invalid packet blocking') + raise SSHException("Invalid packet blocking") buf = self.read_all(packet_size + self.__mac_size_in - len(leftover)) packet = buf[:packet_size - len(leftover)] post_packet = buf[packet_size - len(leftover):] if self.__block_engine_in is not None: packet = self.__block_engine_in.update(packet) if self.__dump_packets: - self._log(DEBUG, util.format_binary(packet, 'IN: ')) + self._log(DEBUG, util.format_binary(packet, "IN: ")) packet = leftover + packet if self.__mac_size_in > 0: mac = post_packet[:self.__mac_size_in] mac_payload = struct.pack( - '>II', self.__sequence_number_in, packet_size) + packet + ">II", self.__sequence_number_in, packet_size + ) + packet my_mac = compute_hmac( - self.__mac_key_in, - mac_payload, - self.__mac_engine_in)[:self.__mac_size_in] + self.__mac_key_in, mac_payload, self.__mac_engine_in + )[ + :self.__mac_size_in + ] if not util.constant_time_bytes_eq(my_mac, mac): - raise SSHException('Mismatched MAC') + raise SSHException("Mismatched MAC") padding = byte_ord(packet[0]) payload = packet[1:packet_size - padding] if self.__dump_packets: self._log( DEBUG, - 'Got payload ({} bytes, {} padding)'.format( + "Got payload ({} bytes, {} padding)".format( packet_size, padding - ) + ), ) if self.__compress_engine_in is not None: @@ -480,19 +498,28 @@ class Packetizer (object): # dropping the connection self.__received_bytes_overflow += raw_packet_size self.__received_packets_overflow += 1 - if (self.__received_packets_overflow >= - self.REKEY_PACKETS_OVERFLOW_MAX) or \ - (self.__received_bytes_overflow >= - self.REKEY_BYTES_OVERFLOW_MAX): + if ( + ( + self.__received_packets_overflow + >= self.REKEY_PACKETS_OVERFLOW_MAX + ) + or ( + self.__received_bytes_overflow + >= self.REKEY_BYTES_OVERFLOW_MAX + ) + ): raise SSHException( - 'Remote transport is ignoring rekey requests') - elif (self.__received_packets >= self.REKEY_PACKETS) or \ - (self.__received_bytes >= self.REKEY_BYTES): + "Remote transport is ignoring rekey requests" + ) + elif (self.__received_packets >= self.REKEY_PACKETS) or ( + self.__received_bytes >= self.REKEY_BYTES + ): # only ask once for rekeying err = "Rekeying (hit {} packets, {} bytes received)" - self._log(DEBUG, err.format( - self.__received_packets, self.__received_bytes, - )) + self._log( + DEBUG, + err.format(self.__received_packets, self.__received_bytes), + ) self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() @@ -501,11 +528,11 @@ class Packetizer (object): if cmd in MSG_NAMES: cmd_name = MSG_NAMES[cmd] else: - cmd_name = '${:x}'.format(cmd) + cmd_name = "${:x}".format(cmd) if self.__dump_packets: self._log( DEBUG, - 'Read packet <{}>, length {}'.format(cmd_name, len(payload)) + "Read packet <{}>, length {}".format(cmd_name, len(payload)), ) return cmd, msg @@ -522,9 +549,9 @@ class Packetizer (object): def _check_keepalive(self): if ( - not self.__keepalive_interval or - not self.__block_engine_out or - self.__need_rekey + not self.__keepalive_interval + or not self.__block_engine_out + or self.__need_rekey ): # wait till we're encrypting, and not in the middle of rekeying return @@ -559,7 +586,7 @@ class Packetizer (object): # pad up at least 4 bytes, to nearest block-size (usually 8) bsize = self.__block_size_out padding = 3 + bsize - ((len(payload) + 8) % bsize) - packet = struct.pack('>IB', len(payload) + padding + 1, padding) + packet = struct.pack(">IB", len(payload) + padding + 1, padding) packet += payload if self.__sdctr_out or self.__block_engine_out is None: # cute trick i caught openssh doing: if we're not encrypting or diff --git a/paramiko/pipe.py b/paramiko/pipe.py index 6ca37703..e88a5e44 100644 --- a/paramiko/pipe.py +++ b/paramiko/pipe.py @@ -31,14 +31,15 @@ import socket def make_pipe(): - if sys.platform[:3] != 'win': + if sys.platform[:3] != "win": p = PosixPipe() else: p = WindowsPipe() return p -class PosixPipe (object): +class PosixPipe(object): + def __init__(self): self._rfd, self._wfd = os.pipe() self._set = False @@ -64,26 +65,27 @@ class PosixPipe (object): if self._set or self._closed: return self._set = True - os.write(self._wfd, b'*') + os.write(self._wfd, b"*") def set_forever(self): self._forever = True self.set() -class WindowsPipe (object): +class WindowsPipe(object): """ On Windows, only an OS-level "WinSock" may be used in select(), but reads and writes must be to the actual socket object. """ + def __init__(self): serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - serv.bind(('127.0.0.1', 0)) + serv.bind(("127.0.0.1", 0)) serv.listen(1) # need to save sockets in _rsock/_wsock so they don't get closed self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._rsock.connect(('127.0.0.1', serv.getsockname()[1])) + self._rsock.connect(("127.0.0.1", serv.getsockname()[1])) self._wsock, addr = serv.accept() serv.close() @@ -110,14 +112,15 @@ class WindowsPipe (object): if self._set or self._closed: return self._set = True - self._wsock.send(b'*') + self._wsock.send(b"*") def set_forever(self): self._forever = True self.set() -class OrPipe (object): +class OrPipe(object): + def __init__(self, pipe): self._set = False self._partner = None diff --git a/paramiko/pkey.py b/paramiko/pkey.py index 808215f8..a01d4fd8 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -43,23 +43,23 @@ class PKey(object): # known encryption types for private key files: _CIPHER_TABLE = { - 'AES-128-CBC': { - 'cipher': algorithms.AES, - 'keysize': 16, - 'blocksize': 16, - 'mode': modes.CBC + "AES-128-CBC": { + "cipher": algorithms.AES, + "keysize": 16, + "blocksize": 16, + "mode": modes.CBC, }, - 'AES-256-CBC': { - 'cipher': algorithms.AES, - 'keysize': 32, - 'blocksize': 16, - 'mode': modes.CBC + "AES-256-CBC": { + "cipher": algorithms.AES, + "keysize": 32, + "blocksize": 16, + "mode": modes.CBC, }, - 'DES-EDE3-CBC': { - 'cipher': algorithms.TripleDES, - 'keysize': 24, - 'blocksize': 8, - 'mode': modes.CBC + "DES-EDE3-CBC": { + "cipher": algorithms.TripleDES, + "keysize": 24, + "blocksize": 8, + "mode": modes.CBC, }, } @@ -107,7 +107,7 @@ class PKey(object): hs = hash(self) ho = hash(other) if hs != ho: - return cmp(hs, ho) # noqa + return cmp(hs, ho) # noqa return cmp(self.asbytes(), other.asbytes()) # noqa def __eq__(self, other): @@ -121,7 +121,7 @@ class PKey(object): name of this private key type, in SSH terminology, as a `str` (for example, ``"ssh-rsa"``). """ - return '' + return "" def get_bits(self): """ @@ -158,7 +158,7 @@ class PKey(object): :return: a base64 `string <str>` containing the public part of the key. """ - return u(encodebytes(self.asbytes())).replace('\n', '') + return u(encodebytes(self.asbytes())).replace("\n", "") def sign_ssh_data(self, data): """ @@ -239,7 +239,7 @@ class PKey(object): :raises: ``IOError`` -- if there was an error writing the file :raises: `.SSHException` -- if the key is invalid """ - raise Exception('Not implemented in PKey') + raise Exception("Not implemented in PKey") def write_private_key(self, file_obj, password=None): """ @@ -252,7 +252,7 @@ class PKey(object): :raises: ``IOError`` -- if there was an error writing to the file :raises: `.SSHException` -- if the key is invalid """ - raise Exception('Not implemented in PKey') + raise Exception("Not implemented in PKey") def _read_private_key_file(self, tag, filename, password=None): """ @@ -275,60 +275,61 @@ class PKey(object): encrypted, and ``password`` is ``None``. :raises: `.SSHException` -- if the key file is invalid. """ - with open(filename, 'r') as f: + with open(filename, "r") as f: data = self._read_private_key(tag, f, password) return data def _read_private_key(self, tag, f, password=None): lines = f.readlines() start = 0 - beginning_of_key = '-----BEGIN ' + tag + ' PRIVATE KEY-----' + beginning_of_key = "-----BEGIN " + tag + " PRIVATE KEY-----" while start < len(lines) and lines[start].strip() != beginning_of_key: start += 1 if start >= len(lines): - raise SSHException('not a valid ' + tag + ' private key file') + raise SSHException("not a valid " + tag + " private key file") # parse any headers first headers = {} start += 1 while start < len(lines): - l = lines[start].split(': ') + l = lines[start].split(": ") if len(l) == 1: break headers[l[0].lower()] = l[1].strip() start += 1 # find end end = start - ending_of_key = '-----END ' + tag + ' PRIVATE KEY-----' + ending_of_key = "-----END " + tag + " PRIVATE KEY-----" while end < len(lines) and lines[end].strip() != ending_of_key: end += 1 # if we trudged to the end of the file, just try to cope. try: - data = decodebytes(b(''.join(lines[start:end]))) + data = decodebytes(b("".join(lines[start:end]))) except base64.binascii.Error as e: - raise SSHException('base64 decoding error: ' + str(e)) - if 'proc-type' not in headers: + raise SSHException("base64 decoding error: " + str(e)) + if "proc-type" not in headers: # unencryped: done return data # encrypted keyfile: will need a password - proc_type = headers['proc-type'] - if proc_type != '4,ENCRYPTED': + proc_type = headers["proc-type"] + if proc_type != "4,ENCRYPTED": raise SSHException( 'Unknown private key structure "{}"'.format(proc_type) ) try: - encryption_type, saltstr = headers['dek-info'].split(',') + encryption_type, saltstr = headers["dek-info"].split(",") except: raise SSHException("Can't parse DEK-info in private key file") if encryption_type not in self._CIPHER_TABLE: raise SSHException( - 'Unknown private key cipher "{}"'.format(encryption_type)) + 'Unknown private key cipher "{}"'.format(encryption_type) + ) # if no password was passed in, # raise an exception pointing out that we need one if password is None: - raise PasswordRequiredException('Private key file is encrypted') - cipher = self._CIPHER_TABLE[encryption_type]['cipher'] - keysize = self._CIPHER_TABLE[encryption_type]['keysize'] - mode = self._CIPHER_TABLE[encryption_type]['mode'] + raise PasswordRequiredException("Private key file is encrypted") + cipher = self._CIPHER_TABLE[encryption_type]["cipher"] + keysize = self._CIPHER_TABLE[encryption_type]["keysize"] + mode = self._CIPHER_TABLE[encryption_type]["mode"] salt = unhexlify(b(saltstr)) key = util.generate_key_bytes(md5, salt, password, keysize) decryptor = Cipher( @@ -351,7 +352,7 @@ class PKey(object): :raises: ``IOError`` -- if there was an error writing the file. """ - with open(filename, 'w') as f: + with open(filename, "w") as f: os.chmod(filename, o600) self._write_private_key(f, key, format, password=password) @@ -361,11 +362,11 @@ class PKey(object): else: encryption = serialization.BestAvailableEncryption(b(password)) - f.write(key.private_bytes( - serialization.Encoding.PEM, - format, - encryption - ).decode()) + f.write( + key.private_bytes( + serialization.Encoding.PEM, format, encryption + ).decode() + ) def _check_type_and_load_cert(self, msg, key_type, cert_type): """ @@ -388,7 +389,7 @@ class PKey(object): cert_types = [cert_types] # Can't do much with no message, that should've been handled elsewhere if msg is None: - raise SSHException('Key object may not be empty') + raise SSHException("Key object may not be empty") # First field is always key type, in either kind of object. (make sure # we rewind before grabbing it - sometimes caller had to do their own # introspection first!) @@ -411,7 +412,7 @@ class PKey(object): # (requires going back into per-type subclasses.) msg.get_string() else: - err = 'Invalid key (class: {}, data type: {}' + err = "Invalid key (class: {}, data type: {}" raise SSHException(err.format(self.__class__.__name__, type_)) def load_certificate(self, value): @@ -434,11 +435,11 @@ class PKey(object): successfully. """ if isinstance(value, Message): - constructor = 'from_message' + constructor = "from_message" elif os.path.isfile(value): - constructor = 'from_file' + constructor = "from_file" else: - constructor = 'from_string' + constructor = "from_string" blob = getattr(PublicBlob, constructor)(value) if not blob.key_type.startswith(self.get_name()): err = "PublicBlob type {} incompatible with key type {}" @@ -464,6 +465,7 @@ class PublicBlob(object): `from_message` for useful instantiation, the main constructor is basically "I should be using ``attrs`` for this." """ + def __init__(self, type_, blob, comment=None): """ Create a new public blob of given type and contents. @@ -505,7 +507,7 @@ class PublicBlob(object): m = Message(key_blob) blob_type = m.get_text() if blob_type != key_type: - msg = "Invalid PublicBlob contents: key type={!r}, but blob type={!r}" # noqa + msg = "Invalid PublicBlob contents: key type={!r}, but blob type={!r}" # noqa raise ValueError(msg.format(key_type, blob_type)) # All good? All good. return cls(type_=key_type, blob=key_blob, comment=comment) @@ -522,7 +524,7 @@ class PublicBlob(object): return cls(type_=type_, blob=message.asbytes()) def __str__(self): - ret = '{} public key/certificate'.format(self.key_type) + ret = "{} public key/certificate".format(self.key_type) if self.comment: ret += "- {}".format(self.comment) return ret diff --git a/paramiko/primes.py b/paramiko/primes.py index ca8f9bec..7496abbd 100644 --- a/paramiko/primes.py +++ b/paramiko/primes.py @@ -49,7 +49,7 @@ def _roll_random(n): return num -class ModulusPack (object): +class ModulusPack(object): """ convenience object for holding the contents of the /etc/ssh/moduli file, on systems that have such a file. @@ -61,8 +61,9 @@ class ModulusPack (object): self.discarded = [] def _parse_modulus(self, line): - timestamp, mod_type, tests, tries, size, generator, modulus = \ - line.split() + ( + timestamp, mod_type, tests, tries, size, generator, modulus + ) = line.split() mod_type = int(mod_type) tests = int(tests) tries = int(tries) @@ -75,12 +76,13 @@ class ModulusPack (object): # test 4 (more than just a small-prime sieve) # tries < 100 if test & 4 (at least 100 tries of miller-rabin) if ( - mod_type < 2 or - tests < 4 or - (tests & 4 and tests < 8 and tries < 100) + mod_type < 2 + or tests < 4 + or (tests & 4 and tests < 8 and tries < 100) ): self.discarded.append( - (modulus, 'does not meet basic requirements')) + (modulus, "does not meet basic requirements") + ) return if generator == 0: generator = 2 @@ -91,7 +93,8 @@ class ModulusPack (object): bl = util.bit_length(modulus) if (bl != size) and (bl != size + 1): self.discarded.append( - (modulus, 'incorrectly reported bit length {}'.format(size))) + (modulus, "incorrectly reported bit length {}".format(size)) + ) return if bl not in self.pack: self.pack[bl] = [] @@ -102,10 +105,10 @@ class ModulusPack (object): :raises IOError: passed from any file operations that fail. """ self.pack = {} - with open(filename, 'r') as f: + with open(filename, "r") as f: for line in f: line = line.strip() - if (len(line) == 0) or (line[0] == '#'): + if (len(line) == 0) or (line[0] == "#"): continue try: self._parse_modulus(line) @@ -115,7 +118,7 @@ class ModulusPack (object): def get_modulus(self, min, prefer, max): bitsizes = sorted(self.pack.keys()) if len(bitsizes) == 0: - raise SSHException('no moduli available') + raise SSHException("no moduli available") good = -1 # find nearest bitsize >= preferred for b in bitsizes: diff --git a/paramiko/proxy.py b/paramiko/proxy.py index c4ec627c..d0ef8784 100644 --- a/paramiko/proxy.py +++ b/paramiko/proxy.py @@ -39,6 +39,7 @@ class ProxyCommand(ClosingContextManager): Instances of this class may be used as context managers. """ + def __init__(self, command_line): """ Create a new CommandProxy instance. The instance created by this @@ -50,9 +51,11 @@ class ProxyCommand(ClosingContextManager): # NOTE: subprocess import done lazily so platforms without it (e.g. # GAE) can still import us during overall Paramiko load. from subprocess import Popen, PIPE + self.cmd = shlsplit(command_line) - self.process = Popen(self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, - bufsize=0) + self.process = Popen( + self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=0 + ) self.timeout = None def send(self, content): @@ -69,7 +72,7 @@ class ProxyCommand(ClosingContextManager): # died and we can't proceed. The best option here is to # raise an exception informing the user that the informed # ProxyCommand is not working. - raise ProxyCommandFailure(' '.join(self.cmd), e.strerror) + raise ProxyCommandFailure(" ".join(self.cmd), e.strerror) return len(content) def recv(self, size): @@ -81,7 +84,7 @@ class ProxyCommand(ClosingContextManager): :return: the string of bytes read, which may be shorter than requested """ try: - buffer = b'' + buffer = b"" start = time.time() while len(buffer) < size: select_timeout = None @@ -91,11 +94,11 @@ class ProxyCommand(ClosingContextManager): raise socket.timeout() select_timeout = self.timeout - elapsed - r, w, x = select( - [self.process.stdout], [], [], select_timeout) + r, w, x = select([self.process.stdout], [], [], select_timeout) if r and r[0] == self.process.stdout: buffer += os.read( - self.process.stdout.fileno(), size - len(buffer)) + self.process.stdout.fileno(), size - len(buffer) + ) return buffer except socket.timeout: if buffer: @@ -103,7 +106,7 @@ class ProxyCommand(ClosingContextManager): return buffer raise # socket.timeout is a subclass of IOError except IOError as e: - raise ProxyCommandFailure(' '.join(self.cmd), e.strerror) + raise ProxyCommandFailure(" ".join(self.cmd), e.strerror) def close(self): os.kill(self.process.pid, signal.SIGTERM) diff --git a/paramiko/py3compat.py b/paramiko/py3compat.py index cb9de412..795b9b0e 100644 --- a/paramiko/py3compat.py +++ b/paramiko/py3compat.py @@ -1,11 +1,31 @@ import sys import base64 -__all__ = ['PY2', 'string_types', 'integer_types', 'text_type', 'bytes_types', - 'bytes', 'long', 'input', 'decodebytes', 'encodebytes', - 'bytestring', 'byte_ord', 'byte_chr', 'byte_mask', 'b', 'u', 'b2s', - 'StringIO', 'BytesIO', 'is_callable', 'MAXSIZE', - 'next', 'builtins'] +__all__ = [ + "PY2", + "string_types", + "integer_types", + "text_type", + "bytes_types", + "bytes", + "long", + "input", + "decodebytes", + "encodebytes", + "bytestring", + "byte_ord", + "byte_chr", + "byte_mask", + "b", + "u", + "b2s", + "StringIO", + "BytesIO", + "is_callable", + "MAXSIZE", + "next", + "builtins", +] PY2 = sys.version_info[0] < 3 @@ -22,22 +42,18 @@ if PY2: import __builtin__ as builtins - def bytestring(s): # NOQA if isinstance(s, unicode): # NOQA - return s.encode('utf-8') + return s.encode("utf-8") return s - byte_ord = ord # NOQA byte_chr = chr # NOQA - def byte_mask(c, mask): return chr(ord(c) & mask) - - def b(s, encoding='utf8'): # NOQA + def b(s, encoding="utf8"): # NOQA """cast unicode or bytes to bytes""" if isinstance(s, str): return s @@ -48,8 +64,7 @@ if PY2: else: raise TypeError("Expected unicode or bytes, got {!r}".format(s)) - - def u(s, encoding='utf8'): # NOQA + def u(s, encoding="utf8"): # NOQA """cast bytes or unicode to unicode""" if isinstance(s, str): return s.decode(encoding) @@ -60,53 +75,52 @@ if PY2: else: raise TypeError("Expected unicode or bytes, got {!r}".format(s)) - def b2s(s): return s - import cStringIO + StringIO = cStringIO.StringIO BytesIO = StringIO - def is_callable(c): # NOQA return callable(c) - def get_next(c): # NOQA return c.next - def next(c): return c.next() # It's possible to have sizeof(long) != sizeof(Py_ssize_t). class X(object): + def __len__(self): return 1 << 31 - try: len(X()) except OverflowError: # 32-bit - MAXSIZE = int((1 << 31) - 1) # NOQA + MAXSIZE = int((1 << 31) - 1) # NOQA else: # 64-bit - MAXSIZE = int((1 << 63) - 1) # NOQA + MAXSIZE = int((1 << 63) - 1) # NOQA del X else: import collections import struct import builtins + string_types = str text_type = str bytes = bytes bytes_types = bytes integer_types = int + class long(int): pass + input = input decodebytes = base64.decodebytes encodebytes = base64.encodebytes @@ -122,13 +136,13 @@ else: def byte_chr(c): assert isinstance(c, int) - return struct.pack('B', c) + return struct.pack("B", c) def byte_mask(c, mask): assert isinstance(c, int) - return struct.pack('B', c & mask) + return struct.pack("B", c & mask) - def b(s, encoding='utf8'): + def b(s, encoding="utf8"): """cast unicode or bytes to bytes""" if isinstance(s, bytes): return s @@ -137,7 +151,7 @@ else: else: raise TypeError("Expected unicode or bytes, got {!r}".format(s)) - def u(s, encoding='utf8'): + def u(s, encoding="utf8"): """cast bytes or unicode to unicode""" if isinstance(s, bytes): return s.decode(encoding) @@ -150,8 +164,9 @@ else: return s.decode() if isinstance(s, bytes) else s import io - StringIO = io.StringIO # NOQA - BytesIO = io.BytesIO # NOQA + + StringIO = io.StringIO # NOQA + BytesIO = io.BytesIO # NOQA def is_callable(c): return isinstance(c, collections.Callable) @@ -161,4 +176,4 @@ else: next = next - MAXSIZE = sys.maxsize # NOQA + MAXSIZE = sys.maxsize # NOQA diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py index 8dfcfb01..b0fce1f1 100644 --- a/paramiko/rsakey.py +++ b/paramiko/rsakey.py @@ -37,8 +37,15 @@ class RSAKey(PKey): data. """ - def __init__(self, msg=None, data=None, filename=None, password=None, - key=None, file_obj=None): + def __init__( + self, + msg=None, + data=None, + filename=None, + password=None, + key=None, + file_obj=None, + ): self.key = None self.public_blob = None if file_obj is not None: @@ -54,12 +61,14 @@ class RSAKey(PKey): else: self._check_type_and_load_cert( msg=msg, - key_type='ssh-rsa', - cert_type='ssh-rsa-cert-v01@openssh.com', + key_type="ssh-rsa", + cert_type="ssh-rsa-cert-v01@openssh.com", ) self.key = rsa.RSAPublicNumbers( e=msg.get_mpint(), n=msg.get_mpint() - ).public_key(default_backend()) + ).public_key( + default_backend() + ) @property def size(self): @@ -74,7 +83,7 @@ class RSAKey(PKey): def asbytes(self): m = Message() - m.add_string('ssh-rsa') + m.add_string("ssh-rsa") m.add_mpint(self.public_numbers.e) m.add_mpint(self.public_numbers.n) return m.asbytes() @@ -89,14 +98,15 @@ class RSAKey(PKey): # tries stuffing it into ASCII for whatever godforsaken reason return self.asbytes() else: - return self.asbytes().decode('utf8', errors='ignore') + return self.asbytes().decode("utf8", errors="ignore") def __hash__(self): - return hash((self.get_name(), self.public_numbers.e, - self.public_numbers.n)) + return hash( + (self.get_name(), self.public_numbers.e, self.public_numbers.n) + ) def get_name(self): - return 'ssh-rsa' + return "ssh-rsa" def get_bits(self): return self.size @@ -106,18 +116,16 @@ class RSAKey(PKey): def sign_ssh_data(self, data): sig = self.key.sign( - data, - padding=padding.PKCS1v15(), - algorithm=hashes.SHA1(), + data, padding=padding.PKCS1v15(), algorithm=hashes.SHA1() ) m = Message() - m.add_string('ssh-rsa') + m.add_string("ssh-rsa") m.add_string(sig) return m def verify_ssh_sig(self, data, msg): - if msg.get_text() != 'ssh-rsa': + if msg.get_text() != "ssh-rsa": return False key = self.key if isinstance(key, rsa.RSAPrivateKey): @@ -137,7 +145,7 @@ class RSAKey(PKey): filename, self.key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) def write_private_key(self, file_obj, password=None): @@ -145,7 +153,7 @@ class RSAKey(PKey): file_obj, self.key, serialization.PrivateFormat.TraditionalOpenSSL, - password=password + password=password, ) @staticmethod @@ -166,11 +174,11 @@ class RSAKey(PKey): # ...internals... def _from_private_key_file(self, filename, password): - data = self._read_private_key_file('RSA', filename, password) + data = self._read_private_key_file("RSA", filename, password) self._decode_key(data) def _from_private_key(self, file_obj, password): - data = self._read_private_key('RSA', file_obj, password) + data = self._read_private_key("RSA", file_obj, password) self._decode_key(data) def _decode_key(self, data): diff --git a/paramiko/server.py b/paramiko/server.py index a7117815..2fe9cc19 100644 --- a/paramiko/server.py +++ b/paramiko/server.py @@ -23,13 +23,16 @@ import threading from paramiko import util from paramiko.common import ( - DEBUG, ERROR, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, AUTH_FAILED, + DEBUG, + ERROR, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + AUTH_FAILED, AUTH_SUCCESSFUL, ) from paramiko.py3compat import string_types -class ServerInterface (object): +class ServerInterface(object): """ This class defines an interface for controlling the behavior of Paramiko in server mode. @@ -99,7 +102,7 @@ class ServerInterface (object): :param str username: the username requesting authentication. :return: a comma-separated `str` of authentication types """ - return 'password' + return "password" def check_auth_none(self, username): """ @@ -233,9 +236,9 @@ class ServerInterface (object): """ return AUTH_FAILED - def check_auth_gssapi_with_mic(self, username, - gss_authenticated=AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_with_mic( + self, username, gss_authenticated=AUTH_FAILED, cc_file=None + ): """ Authenticate the given user to the server if he is a valid krb5 principal. @@ -263,9 +266,9 @@ class ServerInterface (object): return AUTH_SUCCESSFUL return AUTH_FAILED - def check_auth_gssapi_keyex(self, username, - gss_authenticated=AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_keyex( + self, username, gss_authenticated=AUTH_FAILED, cc_file=None + ): """ Authenticate the given user to the server if he is a valid krb5 principal and GSS-API Key Exchange was performed. @@ -372,8 +375,8 @@ class ServerInterface (object): # ...Channel requests... def check_channel_pty_request( - self, channel, term, width, height, pixelwidth, pixelheight, - modes): + self, channel, term, width, height, pixelwidth, pixelheight, modes + ): """ Determine if a pseudo-terminal of the given dimensions (usually requested for shell access) can be provided on the given channel. @@ -460,7 +463,8 @@ class ServerInterface (object): return True def check_channel_window_change_request( - self, channel, width, height, pixelwidth, pixelheight): + self, channel, width, height, pixelwidth, pixelheight + ): """ Determine if the pseudo-terminal on the given channel can be resized. This only makes sense if a pty was previously allocated on it. @@ -479,8 +483,13 @@ class ServerInterface (object): return False def check_channel_x11_request( - self, channel, single_connection, auth_protocol, auth_cookie, - screen_number): + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): """ Determine if the client will be provided with an X11 session. If this method returns ``True``, X11 applications should be routed through new @@ -584,12 +593,13 @@ class ServerInterface (object): """ return (None, None) -class InteractiveQuery (object): + +class InteractiveQuery(object): """ A query (set of prompts) for a user during interactive authentication. """ - def __init__(self, name='', instructions='', *prompts): + def __init__(self, name="", instructions="", *prompts): """ Create a new interactive query to send to the client. The name and instructions are optional, but are generally displayed to the end @@ -623,7 +633,7 @@ class InteractiveQuery (object): self.prompts.append((prompt, echo)) -class SubsystemHandler (threading.Thread): +class SubsystemHandler(threading.Thread): """ Handler for a subsytem in server mode. If you create a subclass of this class and pass it to `.Transport.set_subsystem_handler`, an object of this @@ -637,6 +647,7 @@ class SubsystemHandler (threading.Thread): ``MP3Handler`` will be created, and `start_subsystem` will be called on it from a new thread. """ + def __init__(self, channel, name, server): """ Create a new handler for a channel. This is used by `.ServerInterface` @@ -667,7 +678,7 @@ class SubsystemHandler (threading.Thread): def _run(self): try: self.__transport._log( - DEBUG, 'Starting handler for subsystem {}'.format(self.__name) + DEBUG, "Starting handler for subsystem {}".format(self.__name) ) self.start_subsystem(self.__name, self.__transport, self.__channel) except Exception as e: @@ -675,7 +686,7 @@ class SubsystemHandler (threading.Thread): ERROR, 'Exception in subsystem handler for "{}": {}'.format( self.__name, e - ) + ), ) self.__transport._log(ERROR, util.tb_strings()) try: diff --git a/paramiko/sftp.py b/paramiko/sftp.py index e6786d10..ac32f6bd 100644 --- a/paramiko/sftp.py +++ b/paramiko/sftp.py @@ -26,27 +26,58 @@ from paramiko.message import Message from paramiko.py3compat import byte_chr, byte_ord -CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, \ - CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, \ - CMD_REMOVE, CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, \ - CMD_READLINK, CMD_SYMLINK = range(1, 21) -CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS = range(101, 106) -CMD_EXTENDED, CMD_EXTENDED_REPLY = range(200, 202) +( + CMD_INIT, + CMD_VERSION, + CMD_OPEN, + CMD_CLOSE, + CMD_READ, + CMD_WRITE, + CMD_LSTAT, + CMD_FSTAT, + CMD_SETSTAT, + CMD_FSETSTAT, + CMD_OPENDIR, + CMD_READDIR, + CMD_REMOVE, + CMD_MKDIR, + CMD_RMDIR, + CMD_REALPATH, + CMD_STAT, + CMD_RENAME, + CMD_READLINK, + CMD_SYMLINK, +) = range( + 1, 21 +) +(CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS) = range(101, 106) +(CMD_EXTENDED, CMD_EXTENDED_REPLY) = range(200, 202) SFTP_OK = 0 -SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \ - SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, \ - SFTP_OP_UNSUPPORTED = range(1, 9) - -SFTP_DESC = ['Success', - 'End of file', - 'No such file', - 'Permission denied', - 'Failure', - 'Bad message', - 'No connection', - 'Connection lost', - 'Operation unsupported'] +( + SFTP_EOF, + SFTP_NO_SUCH_FILE, + SFTP_PERMISSION_DENIED, + SFTP_FAILURE, + SFTP_BAD_MESSAGE, + SFTP_NO_CONNECTION, + SFTP_CONNECTION_LOST, + SFTP_OP_UNSUPPORTED, +) = range( + 1, 9 +) + +SFTP_DESC = [ + "Success", + "End of file", + "No such file", + "Permission denied", + "Failure", + "Bad message", + "No connection", + "Connection lost", + "Operation unsupported", +] SFTP_FLAG_READ = 0x1 SFTP_FLAG_WRITE = 0x2 @@ -60,54 +91,55 @@ _VERSION = 3 # for debugging CMD_NAMES = { - CMD_INIT: 'init', - CMD_VERSION: 'version', - CMD_OPEN: 'open', - CMD_CLOSE: 'close', - CMD_READ: 'read', - CMD_WRITE: 'write', - CMD_LSTAT: 'lstat', - CMD_FSTAT: 'fstat', - CMD_SETSTAT: 'setstat', - CMD_FSETSTAT: 'fsetstat', - CMD_OPENDIR: 'opendir', - CMD_READDIR: 'readdir', - CMD_REMOVE: 'remove', - CMD_MKDIR: 'mkdir', - CMD_RMDIR: 'rmdir', - CMD_REALPATH: 'realpath', - CMD_STAT: 'stat', - CMD_RENAME: 'rename', - CMD_READLINK: 'readlink', - CMD_SYMLINK: 'symlink', - CMD_STATUS: 'status', - CMD_HANDLE: 'handle', - CMD_DATA: 'data', - CMD_NAME: 'name', - CMD_ATTRS: 'attrs', - CMD_EXTENDED: 'extended', - CMD_EXTENDED_REPLY: 'extended_reply' + CMD_INIT: "init", + CMD_VERSION: "version", + CMD_OPEN: "open", + CMD_CLOSE: "close", + CMD_READ: "read", + CMD_WRITE: "write", + CMD_LSTAT: "lstat", + CMD_FSTAT: "fstat", + CMD_SETSTAT: "setstat", + CMD_FSETSTAT: "fsetstat", + CMD_OPENDIR: "opendir", + CMD_READDIR: "readdir", + CMD_REMOVE: "remove", + CMD_MKDIR: "mkdir", + CMD_RMDIR: "rmdir", + CMD_REALPATH: "realpath", + CMD_STAT: "stat", + CMD_RENAME: "rename", + CMD_READLINK: "readlink", + CMD_SYMLINK: "symlink", + CMD_STATUS: "status", + CMD_HANDLE: "handle", + CMD_DATA: "data", + CMD_NAME: "name", + CMD_ATTRS: "attrs", + CMD_EXTENDED: "extended", + CMD_EXTENDED_REPLY: "extended_reply", } -class SFTPError (Exception): +class SFTPError(Exception): pass -class BaseSFTP (object): +class BaseSFTP(object): + def __init__(self): - self.logger = util.get_logger('paramiko.sftp') + self.logger = util.get_logger("paramiko.sftp") self.sock = None self.ultra_debug = False # ...internals... def _send_version(self): - self._send_packet(CMD_INIT, struct.pack('>I', _VERSION)) + self._send_packet(CMD_INIT, struct.pack(">I", _VERSION)) t, data = self._read_packet() if t != CMD_VERSION: - raise SFTPError('Incompatible sftp protocol') - version = struct.unpack('>I', data[:4])[0] + raise SFTPError("Incompatible sftp protocol") + version = struct.unpack(">I", data[:4])[0] # if version != _VERSION: # raise SFTPError('Incompatible sftp protocol') return version @@ -117,10 +149,10 @@ class BaseSFTP (object): # client finishes sending INIT. t, data = self._read_packet() if t != CMD_INIT: - raise SFTPError('Incompatible sftp protocol') - version = struct.unpack('>I', data[:4])[0] + raise SFTPError("Incompatible sftp protocol") + version = struct.unpack(">I", data[:4])[0] # advertise that we support "check-file" - extension_pairs = ['check-file', 'md5,sha1'] + extension_pairs = ["check-file", "md5,sha1"] msg = Message() msg.add_int(_VERSION) msg.add(*extension_pairs) @@ -165,9 +197,9 @@ class BaseSFTP (object): def _send_packet(self, t, packet): packet = asbytes(packet) - out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet + out = struct.pack(">I", len(packet) + 1) + byte_chr(t) + packet if self.ultra_debug: - self._log(DEBUG, util.format_binary(out, 'OUT: ')) + self._log(DEBUG, util.format_binary(out, "OUT: ")) self._write_all(out) def _read_packet(self): @@ -175,11 +207,11 @@ class BaseSFTP (object): # most sftp servers won't accept packets larger than about 32k, so # anything with the high byte set (> 16MB) is just garbage. if byte_ord(x[0]): - raise SFTPError('Garbage packet received') - size = struct.unpack('>I', x)[0] + raise SFTPError("Garbage packet received") + size = struct.unpack(">I", x)[0] data = self._read_all(size) if self.ultra_debug: - self._log(DEBUG, util.format_binary(data, 'IN: ')) + self._log(DEBUG, util.format_binary(data, "IN: ")) if size > 0: t = byte_ord(data[0]) return t, data[1:] diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py index ea12b2f6..8e483739 100644 --- a/paramiko/sftp_attr.py +++ b/paramiko/sftp_attr.py @@ -22,7 +22,7 @@ from paramiko.common import x80000000, o700, o70, xffffffff from paramiko.py3compat import long, b -class SFTPAttributes (object): +class SFTPAttributes(object): """ Representation of the attributes of a file (or proxied file) for SFTP in client or server mode. It attemps to mirror the object returned by @@ -82,7 +82,7 @@ class SFTPAttributes (object): return attr def __repr__(self): - return '<SFTPAttributes: {}>'.format(self._debug_str()) + return "<SFTPAttributes: {}>".format(self._debug_str()) # ...internals... @classmethod @@ -144,29 +144,29 @@ class SFTPAttributes (object): return def _debug_str(self): - out = '[ ' + out = "[ " if self.st_size is not None: - out += 'size={} '.format(self.st_size) + out += "size={} ".format(self.st_size) if (self.st_uid is not None) and (self.st_gid is not None): - out += 'uid={} gid={} '.format(self.st_uid, self.st_gid) + out += "uid={} gid={} ".format(self.st_uid, self.st_gid) if self.st_mode is not None: - out += 'mode=' + oct(self.st_mode) + ' ' + out += "mode=" + oct(self.st_mode) + " " if (self.st_atime is not None) and (self.st_mtime is not None): - out += 'atime={} mtime={} '.format(self.st_atime, self.st_mtime) + out += "atime={} mtime={} ".format(self.st_atime, self.st_mtime) for k, v in self.attr.items(): out += '"{}"={!r} '.format(str(k), v) - out += ']' + out += "]" return out @staticmethod def _rwx(n, suid, sticky=False): if suid: suid = 2 - out = '-r'[n >> 2] + '-w'[(n >> 1) & 1] + out = "-r"[n >> 2] + "-w"[(n >> 1) & 1] if sticky: - out += '-xTt'[suid + (n & 1)] + out += "-xTt"[suid + (n & 1)] else: - out += '-xSs'[suid + (n & 1)] + out += "-xSs"[suid + (n & 1)] return out def __str__(self): @@ -174,42 +174,47 @@ class SFTPAttributes (object): if self.st_mode is not None: kind = stat.S_IFMT(self.st_mode) if kind == stat.S_IFIFO: - ks = 'p' + ks = "p" elif kind == stat.S_IFCHR: - ks = 'c' + ks = "c" elif kind == stat.S_IFDIR: - ks = 'd' + ks = "d" elif kind == stat.S_IFBLK: - ks = 'b' + ks = "b" elif kind == stat.S_IFREG: - ks = '-' + ks = "-" elif kind == stat.S_IFLNK: - ks = 'l' + ks = "l" elif kind == stat.S_IFSOCK: - ks = 's' + ks = "s" else: - ks = '?' + ks = "?" ks += self._rwx( - (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID) + (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID + ) ks += self._rwx( - (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID) + (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID + ) ks += self._rwx( - self.st_mode & 7, self.st_mode & stat.S_ISVTX, True) + self.st_mode & 7, self.st_mode & stat.S_ISVTX, True + ) else: - ks = '?---------' + ks = "?---------" # compute display date if (self.st_mtime is None) or (self.st_mtime == xffffffff): # shouldn't really happen - datestr = '(unknown date)' + datestr = "(unknown date)" else: if abs(time.time() - self.st_mtime) > 15552000: # (15552000 = 6 months) datestr = time.strftime( - '%d %b %Y', time.localtime(self.st_mtime)) + "%d %b %Y", time.localtime(self.st_mtime) + ) else: datestr = time.strftime( - '%d %b %H:%M', time.localtime(self.st_mtime)) - filename = getattr(self, 'filename', '?') + "%d %b %H:%M", time.localtime(self.st_mtime) + ) + filename = getattr(self, "filename", "?") # not all servers support uid/gid uid = self.st_uid @@ -225,8 +230,8 @@ class SFTPAttributes (object): # TODO: not sure this actually worked as expected beforehand, leaving # it untouched for the time being, re: .format() upgrade, until someone # has time to doublecheck - return '%s 1 %-8d %-8d %8d %-12s %s' % ( - ks, uid, gid, size, datestr, filename, + return "%s 1 %-8d %-8d %8d %-12s %s" % ( + ks, uid, gid, size, datestr, filename ) def asbytes(self): diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index b344dff3..f6d59d54 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -30,12 +30,37 @@ from paramiko.message import Message from paramiko.common import INFO, DEBUG, o777 from paramiko.py3compat import bytestring, b, u, long from paramiko.sftp import ( - BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READDIR, CMD_NAME, - CMD_CLOSE, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_CREATE, - SFTP_FLAG_TRUNC, SFTP_FLAG_APPEND, SFTP_FLAG_EXCL, CMD_OPEN, CMD_REMOVE, - CMD_RENAME, CMD_MKDIR, CMD_RMDIR, CMD_STAT, CMD_ATTRS, CMD_LSTAT, - CMD_SYMLINK, CMD_SETSTAT, CMD_READLINK, CMD_REALPATH, CMD_STATUS, - CMD_EXTENDED, SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, + BaseSFTP, + CMD_OPENDIR, + CMD_HANDLE, + SFTPError, + CMD_READDIR, + CMD_NAME, + CMD_CLOSE, + SFTP_FLAG_READ, + SFTP_FLAG_WRITE, + SFTP_FLAG_CREATE, + SFTP_FLAG_TRUNC, + SFTP_FLAG_APPEND, + SFTP_FLAG_EXCL, + CMD_OPEN, + CMD_REMOVE, + CMD_RENAME, + CMD_MKDIR, + CMD_RMDIR, + CMD_STAT, + CMD_ATTRS, + CMD_LSTAT, + CMD_SYMLINK, + CMD_SETSTAT, + CMD_READLINK, + CMD_REALPATH, + CMD_STATUS, + CMD_EXTENDED, + SFTP_OK, + SFTP_EOF, + SFTP_NO_SUCH_FILE, + SFTP_PERMISSION_DENIED, ) from paramiko.sftp_attr import SFTPAttributes @@ -51,15 +76,15 @@ def _to_unicode(s): probably doesn't know the filename's encoding. """ try: - return s.encode('ascii') + return s.encode("ascii") except (UnicodeError, AttributeError): try: - return s.decode('utf-8') + return s.decode("utf-8") except UnicodeError: return s -b_slash = b'/' +b_slash = b"/" class SFTPClient(BaseSFTP, ClosingContextManager): @@ -71,6 +96,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): Instances of this class may be used as context managers. """ + def __init__(self, sock): """ Create an SFTP client from an existing `.Channel`. The channel @@ -97,15 +123,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager): # override default logger transport = self.sock.get_transport() self.logger = util.get_logger( - transport.get_log_channel() + '.sftp') + transport.get_log_channel() + ".sftp" + ) self.ultra_debug = transport.get_hexdump() try: server_version = self._send_version() except EOFError: - raise SSHException('EOF during negotiation') + raise SSHException("EOF during negotiation") self._log( INFO, - 'Opened sftp connection (server version {})'.format(server_version) + "Opened sftp connection (server version {})".format( + server_version + ), ) @classmethod @@ -132,11 +161,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionchanged:: 1.15 Added the ``window_size`` and ``max_packet_size`` arguments. """ - chan = t.open_session(window_size=window_size, - max_packet_size=max_packet_size) + chan = t.open_session( + window_size=window_size, max_packet_size=max_packet_size + ) if chan is None: return None - chan.invoke_subsystem('sftp') + chan.invoke_subsystem("sftp") return cls(chan) def _log(self, level, msg, *args): @@ -148,10 +178,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager): # logging.Logger.log() explicitly requires it. Grump. # escape '%' in msg (they could come from file or directory names) # before logging - msg = msg.replace('%', '%%') + msg = msg.replace("%", "%%") super(SFTPClient, self)._log( level, - "[chan %s] " + msg, *([self.sock.get_name()] + list(args))) + "[chan %s] " + msg, + *([self.sock.get_name()] + list(args)) + ) def close(self): """ @@ -159,7 +191,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.4 """ - self._log(INFO, 'sftp session closed.') + self._log(INFO, "sftp session closed.") self.sock.close() def get_channel(self): @@ -171,7 +203,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ return self.sock - def listdir(self, path='.'): + def listdir(self, path="."): """ Return a list containing the names of the entries in the given ``path``. @@ -185,7 +217,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ return [f.filename for f in self.listdir_attr(path)] - def listdir_attr(self, path='.'): + def listdir_attr(self, path="."): """ Return a list containing `.SFTPAttributes` objects corresponding to files in the given ``path``. The list is in arbitrary order. It does @@ -203,10 +235,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.2 """ path = self._adjust_cwd(path) - self._log(DEBUG, 'listdir({!r})'.format(path)) + self._log(DEBUG, "listdir({!r})".format(path)) t, msg = self._request(CMD_OPENDIR, path) if t != CMD_HANDLE: - raise SFTPError('Expected handle') + raise SFTPError("Expected handle") handle = msg.get_binary() filelist = [] while True: @@ -216,18 +248,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager): # done with handle break if t != CMD_NAME: - raise SFTPError('Expected name response') + raise SFTPError("Expected name response") count = msg.get_int() for i in range(count): filename = msg.get_text() longname = msg.get_text() attr = SFTPAttributes._from_msg(msg, filename, longname) - if (filename != '.') and (filename != '..'): + if (filename != ".") and (filename != ".."): filelist.append(attr) self._request(CMD_CLOSE, handle) return filelist - def listdir_iter(self, path='.', read_aheads=50): + def listdir_iter(self, path=".", read_aheads=50): """ Generator version of `.listdir_attr`. @@ -242,11 +274,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.15 """ path = self._adjust_cwd(path) - self._log(DEBUG, 'listdir({!r})'.format(path)) + self._log(DEBUG, "listdir({!r})".format(path)) t, msg = self._request(CMD_OPENDIR, path) if t != CMD_HANDLE: - raise SFTPError('Expected handle') + raise SFTPError("Expected handle") handle = msg.get_string() @@ -261,7 +293,6 @@ class SFTPClient(BaseSFTP, ClosingContextManager): num = self._async_request(type(None), CMD_READDIR, handle) nums.append(num) - # For each of our sent requests # Read and parse the corresponding packets # If we're at the end of our queued requests, then fire off @@ -280,8 +311,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager): filename = msg.get_text() longname = msg.get_text() attr = SFTPAttributes._from_msg( - msg, filename, longname) - if (filename != '.') and (filename != '..'): + msg, filename, longname + ) + if (filename != ".") and (filename != ".."): yield attr # If we've hit the end of our queued requests, reset nums. @@ -291,8 +323,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): self._request(CMD_CLOSE, handle) return - - def open(self, filename, mode='r', bufsize=-1): + def open(self, filename, mode="r", bufsize=-1): """ Open a file on the remote server. The arguments are the same as for Python's built-in `python:file` (aka `python:open`). A file-like @@ -325,26 +356,28 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :raises: ``IOError`` -- if the file could not be opened. """ filename = self._adjust_cwd(filename) - self._log(DEBUG, 'open({!r}, {!r})'.format(filename, mode)) + self._log(DEBUG, "open({!r}, {!r})".format(filename, mode)) imode = 0 - if ('r' in mode) or ('+' in mode): + if ("r" in mode) or ("+" in mode): imode |= SFTP_FLAG_READ - if ('w' in mode) or ('+' in mode) or ('a' in mode): + if ("w" in mode) or ("+" in mode) or ("a" in mode): imode |= SFTP_FLAG_WRITE - if 'w' in mode: + if "w" in mode: imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC - if 'a' in mode: + if "a" in mode: imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND - if 'x' in mode: + if "x" in mode: imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL attrblock = SFTPAttributes() t, msg = self._request(CMD_OPEN, filename, imode, attrblock) if t != CMD_HANDLE: - raise SFTPError('Expected handle') + raise SFTPError("Expected handle") handle = msg.get_binary() self._log( DEBUG, - 'open({!r}, {!r}) -> {}'.format(filename, mode, u(hexlify(handle))) + "open({!r}, {!r}) -> {}".format( + filename, mode, u(hexlify(handle)) + ), ) return SFTPFile(self, handle, mode, bufsize) @@ -361,7 +394,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :raises: ``IOError`` -- if the path refers to a folder (directory) """ path = self._adjust_cwd(path) - self._log(DEBUG, 'remove({!r})'.format(path)) + self._log(DEBUG, "remove({!r})".format(path)) self._request(CMD_REMOVE, path) unlink = remove @@ -386,7 +419,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ oldpath = self._adjust_cwd(oldpath) newpath = self._adjust_cwd(newpath) - self._log(DEBUG, 'rename({!r}, {!r})'.format(oldpath, newpath)) + self._log(DEBUG, "rename({!r}, {!r})".format(oldpath, newpath)) self._request(CMD_RENAME, oldpath, newpath) def posix_rename(self, oldpath, newpath): @@ -406,7 +439,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): """ oldpath = self._adjust_cwd(oldpath) newpath = self._adjust_cwd(newpath) - self._log(DEBUG, 'posix_rename({!r}, {!r})'.format(oldpath, newpath)) + self._log(DEBUG, "posix_rename({!r}, {!r})".format(oldpath, newpath)) self._request( CMD_EXTENDED, "posix-rename@openssh.com", oldpath, newpath ) @@ -421,7 +454,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int mode: permissions (posix-style) for the newly-created folder """ path = self._adjust_cwd(path) - self._log(DEBUG, 'mkdir({!r}, {!r})'.format(path, mode)) + self._log(DEBUG, "mkdir({!r}, {!r})".format(path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_MKDIR, path, attr) @@ -433,7 +466,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param str path: name of the folder to remove """ path = self._adjust_cwd(path) - self._log(DEBUG, 'rmdir({!r})'.format(path)) + self._log(DEBUG, "rmdir({!r})".format(path)) self._request(CMD_RMDIR, path) def stat(self, path): @@ -456,10 +489,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager): file """ path = self._adjust_cwd(path) - self._log(DEBUG, 'stat({!r})'.format(path)) + self._log(DEBUG, "stat({!r})".format(path)) t, msg = self._request(CMD_STAT, path) if t != CMD_ATTRS: - raise SFTPError('Expected attributes') + raise SFTPError("Expected attributes") return SFTPAttributes._from_msg(msg) def lstat(self, path): @@ -474,10 +507,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager): file """ path = self._adjust_cwd(path) - self._log(DEBUG, 'lstat({!r})'.format(path)) + self._log(DEBUG, "lstat({!r})".format(path)) t, msg = self._request(CMD_LSTAT, path) if t != CMD_ATTRS: - raise SFTPError('Expected attributes') + raise SFTPError("Expected attributes") return SFTPAttributes._from_msg(msg) def symlink(self, source, dest): @@ -488,7 +521,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param str dest: path of the newly created symlink """ dest = self._adjust_cwd(dest) - self._log(DEBUG, 'symlink({!r}, {!r})'.format(source, dest)) + self._log(DEBUG, "symlink({!r}, {!r})".format(source, dest)) source = bytestring(source) self._request(CMD_SYMLINK, source, dest) @@ -502,7 +535,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int mode: new permissions """ path = self._adjust_cwd(path) - self._log(DEBUG, 'chmod({!r}, {!r})'.format(path, mode)) + self._log(DEBUG, "chmod({!r}, {!r})".format(path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_SETSTAT, path, attr) @@ -519,7 +552,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int gid: new group id """ path = self._adjust_cwd(path) - self._log(DEBUG, 'chown({!r}, {!r}, {!r})'.format(path, uid, gid)) + self._log(DEBUG, "chown({!r}, {!r}, {!r})".format(path, uid, gid)) attr = SFTPAttributes() attr.st_uid, attr.st_gid = uid, gid self._request(CMD_SETSTAT, path, attr) @@ -541,7 +574,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): path = self._adjust_cwd(path) if times is None: times = (time.time(), time.time()) - self._log(DEBUG, 'utime({!r}, {!r})'.format(path, times)) + self._log(DEBUG, "utime({!r}, {!r})".format(path, times)) attr = SFTPAttributes() attr.st_atime, attr.st_mtime = times self._request(CMD_SETSTAT, path, attr) @@ -556,7 +589,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :param int size: the new size of the file """ path = self._adjust_cwd(path) - self._log(DEBUG, 'truncate({!r}, {!r})'.format(path, size)) + self._log(DEBUG, "truncate({!r}, {!r})".format(path, size)) attr = SFTPAttributes() attr.st_size = size self._request(CMD_SETSTAT, path, attr) @@ -571,15 +604,15 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :return: target path, as a `str` """ path = self._adjust_cwd(path) - self._log(DEBUG, 'readlink({!r})'.format(path)) + self._log(DEBUG, "readlink({!r})".format(path)) t, msg = self._request(CMD_READLINK, path) if t != CMD_NAME: - raise SFTPError('Expected name response') + raise SFTPError("Expected name response") count = msg.get_int() if count == 0: return None if count != 1: - raise SFTPError('Readlink returned {} results'.format(count)) + raise SFTPError("Readlink returned {} results".format(count)) return _to_unicode(msg.get_string()) def normalize(self, path): @@ -595,13 +628,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager): :raises: ``IOError`` -- if the path can't be resolved on the server """ path = self._adjust_cwd(path) - self._log(DEBUG, 'normalize({!r})'.format(path)) + self._log(DEBUG, "normalize({!r})".format(path)) t, msg = self._request(CMD_REALPATH, path) if t != CMD_NAME: - raise SFTPError('Expected name response') + raise SFTPError("Expected name response") count = msg.get_int() if count != 1: - raise SFTPError('Realpath returned {} results'.format(count)) + raise SFTPError("Realpath returned {} results".format(count)) return msg.get_text() def chdir(self, path=None): @@ -625,9 +658,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): return if not stat.S_ISDIR(self.stat(path).st_mode): code = errno.ENOTDIR - raise SFTPError( - code, "{}: {}".format(os.strerror(code), path) - ) + raise SFTPError(code, "{}: {}".format(os.strerror(code), path)) self._cwd = b(self.normalize(path)) def getcwd(self): @@ -680,7 +711,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.10 """ - with self.file(remotepath, 'wb') as fr: + with self.file(remotepath, "wb") as fr: fr.set_pipelined(True) size = self._transfer_with_callback( reader=fl, writer=fr, file_size=file_size, callback=callback @@ -689,7 +720,8 @@ class SFTPClient(BaseSFTP, ClosingContextManager): s = self.stat(remotepath) if s.st_size != size: raise IOError( - 'size mismatch in put! {} != {}'.format(s.st_size, size)) + "size mismatch in put! {} != {}".format(s.st_size, size) + ) else: s = SFTPAttributes() return s @@ -723,7 +755,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): ``confirm`` param added. """ file_size = os.stat(localpath).st_size - with open(localpath, 'rb') as fl: + with open(localpath, "rb") as fl: return self.putfo(fl, remotepath, file_size, callback, confirm) def getfo(self, remotepath, fl, callback=None): @@ -744,7 +776,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionadded:: 1.10 """ file_size = self.stat(remotepath).st_size - with self.open(remotepath, 'rb') as fr: + with self.open(remotepath, "rb") as fr: fr.prefetch(file_size) return self._transfer_with_callback( reader=fr, writer=fl, file_size=file_size, callback=callback @@ -766,12 +798,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager): .. versionchanged:: 1.7.4 Added the ``callback`` param """ - with open(localpath, 'wb') as fl: + with open(localpath, "wb") as fl: size = self.getfo(remotepath, fl, callback) s = os.stat(localpath) if s.st_size != size: raise IOError( - 'size mismatch in get! {} != {}'.format(s.st_size, size)) + "size mismatch in get! {} != {}".format(s.st_size, size) + ) # ...internals... @@ -809,7 +842,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): try: t, data = self._read_packet() except EOFError as e: - raise SSHException('Server connection dropped: {}'.format(e)) + raise SSHException("Server connection dropped: {}".format(e)) msg = Message(data) num = msg.get_int() self._lock.acquire() @@ -817,7 +850,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): if num not in self._expecting: # might be response for a file that was closed before # responses came back - self._log(DEBUG, 'Unexpected response #{}'.format(num)) + self._log(DEBUG, "Unexpected response #{}".format(num)) if waitfor is None: # just doing a single check break diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index 52f2bde8..049e804d 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -32,13 +32,21 @@ from paramiko.common import DEBUG from paramiko.file import BufferedFile from paramiko.py3compat import u, long from paramiko.sftp import ( - CMD_CLOSE, CMD_READ, CMD_DATA, SFTPError, CMD_WRITE, CMD_STATUS, CMD_FSTAT, - CMD_ATTRS, CMD_FSETSTAT, CMD_EXTENDED, + CMD_CLOSE, + CMD_READ, + CMD_DATA, + SFTPError, + CMD_WRITE, + CMD_STATUS, + CMD_FSTAT, + CMD_ATTRS, + CMD_FSETSTAT, + CMD_EXTENDED, ) from paramiko.sftp_attr import SFTPAttributes -class SFTPFile (BufferedFile): +class SFTPFile(BufferedFile): """ Proxy object for a file on the remote server, in client mode SFTP. @@ -50,7 +58,7 @@ class SFTPFile (BufferedFile): # this size. MAX_REQUEST_SIZE = 32768 - def __init__(self, sftp, handle, mode='r', bufsize=-1): + def __init__(self, sftp, handle, mode="r", bufsize=-1): BufferedFile.__init__(self) self.sftp = sftp self.handle = handle @@ -83,7 +91,7 @@ class SFTPFile (BufferedFile): # __del__.) if self._closed: return - self.sftp._log(DEBUG, 'close({})'.format(u(hexlify(self.handle)))) + self.sftp._log(DEBUG, "close({})".format(u(hexlify(self.handle)))) if self.pipelined: self.sftp._finish_responses(self) BufferedFile.close(self) @@ -102,8 +110,9 @@ class SFTPFile (BufferedFile): pass def _data_in_prefetch_requests(self, offset, size): - k = [x for x in list(self._prefetch_extents.values()) - if x[0] <= offset] + k = [ + x for x in list(self._prefetch_extents.values()) if x[0] <= offset + ] if len(k) == 0: return False k.sort(key=lambda x: x[0]) @@ -117,8 +126,8 @@ class SFTPFile (BufferedFile): # well, we have part of the request. see if another chunk has # the rest. return self._data_in_prefetch_requests( - buf_offset + buf_size, - offset + size - buf_offset - buf_size) + buf_offset + buf_size, offset + size - buf_offset - buf_size + ) def _data_in_prefetch_buffers(self, offset): """ @@ -174,13 +183,10 @@ class SFTPFile (BufferedFile): if data is not None: return data t, msg = self.sftp._request( - CMD_READ, - self.handle, - long(self._realpos), - int(size) + CMD_READ, self.handle, long(self._realpos), int(size) ) if t != CMD_DATA: - raise SFTPError('Expected data') + raise SFTPError("Expected data") return msg.get_string() def _write(self, data): @@ -191,18 +197,18 @@ class SFTPFile (BufferedFile): CMD_WRITE, self.handle, long(self._realpos), - data[:chunk] + data[:chunk], ) self._reqs.append(sftp_async_request) if ( - not self.pipelined or - (len(self._reqs) > 100 and self.sftp.sock.recv_ready()) + not self.pipelined + or (len(self._reqs) > 100 and self.sftp.sock.recv_ready()) ): while len(self._reqs): req = self._reqs.popleft() t, msg = self.sftp._read_response(req) if t != CMD_STATUS: - raise SFTPError('Expected status') + raise SFTPError("Expected status") # convert_status already called return chunk @@ -277,7 +283,7 @@ class SFTPFile (BufferedFile): """ t, msg = self.sftp._request(CMD_FSTAT, self.handle) if t != CMD_ATTRS: - raise SFTPError('Expected attributes') + raise SFTPError("Expected attributes") return SFTPAttributes._from_msg(msg) def chmod(self, mode): @@ -288,8 +294,9 @@ class SFTPFile (BufferedFile): :param int mode: new permissions """ - self.sftp._log(DEBUG, 'chmod({}, {!r})'.format( - hexlify(self.handle), mode)) + self.sftp._log( + DEBUG, "chmod({}, {!r})".format(hexlify(self.handle), mode) + ) attr = SFTPAttributes() attr.st_mode = mode self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -306,7 +313,8 @@ class SFTPFile (BufferedFile): """ self.sftp._log( DEBUG, - 'chown({}, {!r}, {!r})'.format(hexlify(self.handle), uid, gid)) + "chown({}, {!r}, {!r})".format(hexlify(self.handle), uid, gid), + ) attr = SFTPAttributes() attr.st_uid, attr.st_gid = uid, gid self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -326,8 +334,9 @@ class SFTPFile (BufferedFile): """ if times is None: times = (time.time(), time.time()) - self.sftp._log(DEBUG, 'utime({}, {!r})'.format( - hexlify(self.handle), times)) + self.sftp._log( + DEBUG, "utime({}, {!r})".format(hexlify(self.handle), times) + ) attr = SFTPAttributes() attr.st_atime, attr.st_mtime = times self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -341,8 +350,8 @@ class SFTPFile (BufferedFile): :param size: the new size of the file """ self.sftp._log( - DEBUG, - 'truncate({}, {!r})'.format(hexlify(self.handle), size)) + DEBUG, "truncate({}, {!r})".format(hexlify(self.handle), size) + ) attr = SFTPAttributes() attr.st_size = size self.sftp._request(CMD_FSETSTAT, self.handle, attr) @@ -394,8 +403,14 @@ class SFTPFile (BufferedFile): .. versionadded:: 1.4 """ t, msg = self.sftp._request( - CMD_EXTENDED, 'check-file', self.handle, - hash_algorithm, long(offset), long(length), block_size) + CMD_EXTENDED, + "check-file", + self.handle, + hash_algorithm, + long(offset), + long(length), + block_size, + ) msg.get_text() # ext msg.get_text() # alg data = msg.get_remainder() @@ -475,15 +490,16 @@ class SFTPFile (BufferedFile): .. versionadded:: 1.5.4 """ - self.sftp._log(DEBUG, 'readv({}, {!r})'.format( - hexlify(self.handle), chunks)) + self.sftp._log( + DEBUG, "readv({}, {!r})".format(hexlify(self.handle), chunks) + ) read_chunks = [] for offset, size in chunks: # don't fetch data that's already in the prefetch buffer if ( - self._data_in_prefetch_buffers(offset) or - self._data_in_prefetch_requests(offset, size) + self._data_in_prefetch_buffers(offset) + or self._data_in_prefetch_requests(offset, size) ): continue @@ -521,11 +537,8 @@ class SFTPFile (BufferedFile): # a lot of them, so it may block. for offset, length in chunks: num = self.sftp._async_request( - self, - CMD_READ, - self.handle, - long(offset), - int(length)) + self, CMD_READ, self.handle, long(offset), int(length) + ) with self._prefetch_lock: self._prefetch_extents[num] = (offset, length) @@ -538,7 +551,7 @@ class SFTPFile (BufferedFile): self._saved_exception = e return if t != CMD_DATA: - raise SFTPError('Expected data') + raise SFTPError("Expected data") data = msg.get_string() while True: with self._prefetch_lock: diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py index ca473900..a7e22f01 100644 --- a/paramiko/sftp_handle.py +++ b/paramiko/sftp_handle.py @@ -25,7 +25,7 @@ from paramiko.sftp import SFTP_OP_UNSUPPORTED, SFTP_OK from paramiko.util import ClosingContextManager -class SFTPHandle (ClosingContextManager): +class SFTPHandle(ClosingContextManager): """ Abstract object representing a handle to an open file (or folder) in an SFTP server implementation. Each handle has a string representation used @@ -36,6 +36,7 @@ class SFTPHandle (ClosingContextManager): Instances of this class may be used as context managers. """ + def __init__(self, flags=0): """ Create a new file handle representing a local file being served over @@ -63,10 +64,10 @@ class SFTPHandle (ClosingContextManager): using the default implementations of `read` and `write`, this method's default implementation should be fine also. """ - readfile = getattr(self, 'readfile', None) + readfile = getattr(self, "readfile", None) if readfile is not None: readfile.close() - writefile = getattr(self, 'writefile', None) + writefile = getattr(self, "writefile", None) if writefile is not None: writefile.close() @@ -88,7 +89,7 @@ class SFTPHandle (ClosingContextManager): :param int length: number of bytes to attempt to read. :return: data read from the file, or an SFTP error code, as a `str`. """ - readfile = getattr(self, 'readfile', None) + readfile = getattr(self, "readfile", None) if readfile is None: return SFTP_OP_UNSUPPORTED try: @@ -122,7 +123,7 @@ class SFTPHandle (ClosingContextManager): :param str data: data to write into the file. :return: an SFTP error code like ``SFTP_OK``. """ - writefile = getattr(self, 'writefile', None) + writefile = getattr(self, "writefile", None) if writefile is None: return SFTP_OP_UNSUPPORTED try: diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py index f8c4f727..8265df96 100644 --- a/paramiko/sftp_server.py +++ b/paramiko/sftp_server.py @@ -27,7 +27,11 @@ from hashlib import md5, sha1 from paramiko import util from paramiko.sftp import ( - BaseSFTP, Message, SFTP_FAILURE, SFTP_PERMISSION_DENIED, SFTP_NO_SUCH_FILE, + BaseSFTP, + Message, + SFTP_FAILURE, + SFTP_PERMISSION_DENIED, + SFTP_NO_SUCH_FILE, ) from paramiko.sftp_si import SFTPServerInterface from paramiko.sftp_attr import SFTPAttributes @@ -38,30 +42,64 @@ from paramiko.server import SubsystemHandler # known hash algorithms for the "check-file" extension from paramiko.sftp import ( - CMD_HANDLE, SFTP_DESC, CMD_STATUS, SFTP_EOF, CMD_NAME, SFTP_BAD_MESSAGE, - CMD_EXTENDED_REPLY, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_APPEND, - SFTP_FLAG_CREATE, SFTP_FLAG_TRUNC, SFTP_FLAG_EXCL, CMD_NAMES, CMD_OPEN, - CMD_CLOSE, SFTP_OK, CMD_READ, CMD_DATA, CMD_WRITE, CMD_REMOVE, CMD_RENAME, - CMD_MKDIR, CMD_RMDIR, CMD_OPENDIR, CMD_READDIR, CMD_STAT, CMD_ATTRS, - CMD_LSTAT, CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_READLINK, CMD_SYMLINK, - CMD_REALPATH, CMD_EXTENDED, SFTP_OP_UNSUPPORTED, + CMD_HANDLE, + SFTP_DESC, + CMD_STATUS, + SFTP_EOF, + CMD_NAME, + SFTP_BAD_MESSAGE, + CMD_EXTENDED_REPLY, + SFTP_FLAG_READ, + SFTP_FLAG_WRITE, + SFTP_FLAG_APPEND, + SFTP_FLAG_CREATE, + SFTP_FLAG_TRUNC, + SFTP_FLAG_EXCL, + CMD_NAMES, + CMD_OPEN, + CMD_CLOSE, + SFTP_OK, + CMD_READ, + CMD_DATA, + CMD_WRITE, + CMD_REMOVE, + CMD_RENAME, + CMD_MKDIR, + CMD_RMDIR, + CMD_OPENDIR, + CMD_READDIR, + CMD_STAT, + CMD_ATTRS, + CMD_LSTAT, + CMD_FSTAT, + CMD_SETSTAT, + CMD_FSETSTAT, + CMD_READLINK, + CMD_SYMLINK, + CMD_REALPATH, + CMD_EXTENDED, + SFTP_OP_UNSUPPORTED, ) -_hash_class = { - 'sha1': sha1, - 'md5': md5, -} +_hash_class = {"sha1": sha1, "md5": md5} -class SFTPServer (BaseSFTP, SubsystemHandler): +class SFTPServer(BaseSFTP, SubsystemHandler): """ Server-side SFTP subsystem support. Since this is a `.SubsystemHandler`, it can be (and is meant to be) set as the handler for ``"sftp"`` requests. Use `.Transport.set_subsystem_handler` to activate this class. """ - def __init__(self, channel, name, server, sftp_si=SFTPServerInterface, - *largs, **kwargs): + def __init__( + self, + channel, + name, + server, + sftp_si=SFTPServerInterface, + *largs, + **kwargs + ): """ The constructor for SFTPServer is meant to be called from within the `.Transport` as a subsystem handler. ``server`` and any additional @@ -79,7 +117,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): BaseSFTP.__init__(self) SubsystemHandler.__init__(self, channel, name, server) transport = channel.get_transport() - self.logger = util.get_logger(transport.get_log_channel() + '.sftp') + self.logger = util.get_logger(transport.get_log_channel() + ".sftp") self.ultra_debug = transport.get_hexdump() self.next_handle = 1 # map of handle-string to SFTPHandle for files & folders: @@ -91,26 +129,26 @@ class SFTPServer (BaseSFTP, SubsystemHandler): if issubclass(type(msg), list): for m in msg: super(SFTPServer, self)._log( - level, - "[chan " + self.sock.get_name() + "] " + m) + level, "[chan " + self.sock.get_name() + "] " + m + ) else: super(SFTPServer, self)._log( - level, - "[chan " + self.sock.get_name() + "] " + msg) + level, "[chan " + self.sock.get_name() + "] " + msg + ) def start_subsystem(self, name, transport, channel): self.sock = channel - self._log(DEBUG, 'Started sftp server on channel {!r}'.format(channel)) + self._log(DEBUG, "Started sftp server on channel {!r}".format(channel)) self._send_server_version() self.server.session_started() while True: try: t, data = self._read_packet() except EOFError: - self._log(DEBUG, 'EOF -- end of session') + self._log(DEBUG, "EOF -- end of session") return except Exception as e: - self._log(DEBUG, 'Exception on channel: ' + str(e)) + self._log(DEBUG, "Exception on channel: " + str(e)) self._log(DEBUG, util.tb_strings()) return msg = Message(data) @@ -118,7 +156,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): try: self._process(t, request_number, msg) except Exception as e: - self._log(DEBUG, 'Exception in server processing: ' + str(e)) + self._log(DEBUG, "Exception in server processing: " + str(e)) self._log(DEBUG, util.tb_strings()) # send some kind of failure message, at least try: @@ -172,7 +210,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): name of the file to alter (should usually be an absolute path). :param .SFTPAttributes attr: attributes to change. """ - if sys.platform != 'win32': + if sys.platform != "win32": # mode operations are meaningless on win32 if attr._flags & attr.FLAG_PERMISSIONS: os.chmod(filename, attr.st_mode) @@ -181,7 +219,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): if attr._flags & attr.FLAG_AMTIME: os.utime(filename, (attr.st_atime, attr.st_mtime)) if attr._flags & attr.FLAG_SIZE: - with open(filename, 'w+') as f: + with open(filename, "w+") as f: f.truncate(attr.st_size) # ...internals... @@ -200,8 +238,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): item._pack(msg) else: raise Exception( - 'unknown type for {!r} type {!r}'.format( - item, type(item))) + "unknown type for {!r} type {!r}".format(item, type(item)) + ) self._send_packet(t, msg) def _send_handle_response(self, request_number, handle, folder=False): @@ -209,7 +247,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): # must be error code self._send_status(request_number, handle) return - handle._set_name(b('hx{:d}'.format(self.next_handle))) + handle._set_name(b("hx{:d}".format(self.next_handle))) self.next_handle += 1 if folder: self.folder_table[handle._get_name()] = handle @@ -222,10 +260,10 @@ class SFTPServer (BaseSFTP, SubsystemHandler): try: desc = SFTP_DESC[code] except IndexError: - desc = 'Unknown' + desc = "Unknown" # some clients expect a "langauge" tag at the end # (but don't mind it being blank) - self._response(request_number, CMD_STATUS, code, desc, '') + self._response(request_number, CMD_STATUS, code, desc, "") def _open_folder(self, request_number, path): resp = self.server.list_folder(path) @@ -264,7 +302,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): block_size = msg.get_int() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return f = self.file_table[handle] for x in alg_list: @@ -274,19 +313,21 @@ class SFTPServer (BaseSFTP, SubsystemHandler): break else: self._send_status( - request_number, SFTP_FAILURE, 'No supported hash types found') + request_number, SFTP_FAILURE, "No supported hash types found" + ) return if length == 0: st = f.stat() if not issubclass(type(st), SFTPAttributes): - self._send_status(request_number, st, 'Unable to stat file') + self._send_status(request_number, st, "Unable to stat file") return length = st.st_size - start if block_size == 0: block_size = length if block_size < 256: self._send_status( - request_number, SFTP_FAILURE, 'Block size too small') + request_number, SFTP_FAILURE, "Block size too small" + ) return sum_out = bytes() @@ -301,7 +342,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): data = f.read(offset, chunklen) if not isinstance(data, bytes_types): self._send_status( - request_number, data, 'Unable to hash file') + request_number, data, "Unable to hash file" + ) return hash_obj.update(data) count += len(data) @@ -310,7 +352,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): msg = Message() msg.add_int(request_number) - msg.add_string('check-file') + msg.add_string("check-file") msg.add_string(algname) msg.add_bytes(sum_out) self._send_packet(CMD_EXTENDED_REPLY, msg) @@ -334,13 +376,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler): return flags def _process(self, t, request_number, msg): - self._log(DEBUG, 'Request: {}'.format(CMD_NAMES[t])) + self._log(DEBUG, "Request: {}".format(CMD_NAMES[t])) if t == CMD_OPEN: path = msg.get_text() flags = self._convert_pflags(msg.get_int()) attr = SFTPAttributes._from_msg(msg) self._send_handle_response( - request_number, self.server.open(path, flags, attr)) + request_number, self.server.open(path, flags, attr) + ) elif t == CMD_CLOSE: handle = msg.get_binary() if handle in self.folder_table: @@ -353,14 +396,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_status(request_number, SFTP_OK) return self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) elif t == CMD_READ: handle = msg.get_binary() offset = msg.get_int64() length = msg.get_int() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return data = self.file_table[handle].read(offset, length) if isinstance(data, (bytes_types, string_types)): @@ -376,10 +421,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler): data = msg.get_binary() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return self._send_status( - request_number, self.file_table[handle].write(offset, data)) + request_number, self.file_table[handle].write(offset, data) + ) elif t == CMD_REMOVE: path = msg.get_text() self._send_status(request_number, self.server.remove(path)) @@ -387,7 +434,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): oldpath = msg.get_text() newpath = msg.get_text() self._send_status( - request_number, self.server.rename(oldpath, newpath)) + request_number, self.server.rename(oldpath, newpath) + ) elif t == CMD_MKDIR: path = msg.get_text() attr = SFTPAttributes._from_msg(msg) @@ -403,7 +451,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): handle = msg.get_binary() if handle not in self.folder_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return folder = self.folder_table[handle] self._read_folder(request_number, folder) @@ -425,7 +474,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler): handle = msg.get_binary() if handle not in self.file_table: self._send_status( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return resp = self.file_table[handle].stat() if issubclass(type(resp), SFTPAttributes): @@ -441,16 +491,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler): attr = SFTPAttributes._from_msg(msg) if handle not in self.file_table: self._response( - request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + request_number, SFTP_BAD_MESSAGE, "Invalid handle" + ) return self._send_status( - request_number, self.file_table[handle].chattr(attr)) + request_number, self.file_table[handle].chattr(attr) + ) elif t == CMD_READLINK: path = msg.get_text() resp = self.server.readlink(path) if isinstance(resp, (bytes_types, string_types)): self._response( - request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) + request_number, CMD_NAME, 1, resp, "", SFTPAttributes() + ) else: self._send_status(request_number, resp) elif t == CMD_SYMLINK: @@ -459,17 +512,19 @@ class SFTPServer (BaseSFTP, SubsystemHandler): target_path = msg.get_text() path = msg.get_text() self._send_status( - request_number, self.server.symlink(target_path, path)) + request_number, self.server.symlink(target_path, path) + ) elif t == CMD_REALPATH: path = msg.get_text() rpath = self.server.canonicalize(path) self._response( - request_number, CMD_NAME, 1, rpath, '', SFTPAttributes()) + request_number, CMD_NAME, 1, rpath, "", SFTPAttributes() + ) elif t == CMD_EXTENDED: tag = msg.get_text() - if tag == 'check-file': + if tag == "check-file": self._check_file(request_number, msg) - elif tag == 'posix-rename@openssh.com': + elif tag == "posix-rename@openssh.com": oldpath = msg.get_text() newpath = msg.get_text() self._send_status( diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py index b43b582c..40dc561c 100644 --- a/paramiko/sftp_si.py +++ b/paramiko/sftp_si.py @@ -25,7 +25,7 @@ import sys from paramiko.sftp import SFTP_OP_UNSUPPORTED -class SFTPServerInterface (object): +class SFTPServerInterface(object): """ This class defines an interface for controlling the behavior of paramiko when using the `.SFTPServer` subsystem to provide an SFTP server. @@ -39,6 +39,7 @@ class SFTPServerInterface (object): All paths are in string form instead of unicode because not all SFTP clients & servers obey the requirement that paths be encoded in UTF-8. """ + def __init__(self, server, *largs, **kwargs): """ Create a new SFTPServerInterface object. This method does nothing by @@ -281,10 +282,10 @@ class SFTPServerInterface (object): if os.path.isabs(path): out = os.path.normpath(path) else: - out = os.path.normpath('/' + path) - if sys.platform == 'win32': + out = os.path.normpath("/" + path) + if sys.platform == "win32": # on windows, normalize backslashes to sftp/posix format - out = out.replace('\\', '/') + out = out.replace("\\", "/") return out def readlink(self, path): diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py index 2df84b65..c1276c69 100644 --- a/paramiko/ssh_exception.py +++ b/paramiko/ssh_exception.py @@ -19,14 +19,14 @@ import socket -class SSHException (Exception): +class SSHException(Exception): """ Exception raised by failures in SSH2 protocol negotiation or logic errors. """ pass -class AuthenticationException (SSHException): +class AuthenticationException(SSHException): """ Exception raised when authentication failed for some reason. It may be possible to retry with different credentials. (Other classes specify more @@ -37,14 +37,14 @@ class AuthenticationException (SSHException): pass -class PasswordRequiredException (AuthenticationException): +class PasswordRequiredException(AuthenticationException): """ Exception raised when a password is needed to unlock a private key file. """ pass -class BadAuthenticationType (AuthenticationException): +class BadAuthenticationType(AuthenticationException): """ Exception raised when an authentication type (like password) is used, but the server isn't allowing that type. (It may only allow public-key, for @@ -60,28 +60,28 @@ class BadAuthenticationType (AuthenticationException): AuthenticationException.__init__(self, explanation) self.allowed_types = types # for unpickling - self.args = (explanation, types, ) + self.args = (explanation, types) def __str__(self): - return '{} (allowed_types={!r})'.format( + return "{} (allowed_types={!r})".format( SSHException.__str__(self), self.allowed_types ) -class PartialAuthentication (AuthenticationException): +class PartialAuthentication(AuthenticationException): """ An internal exception thrown in the case of partial authentication. """ allowed_types = [] def __init__(self, types): - AuthenticationException.__init__(self, 'partial authentication') + AuthenticationException.__init__(self, "partial authentication") self.allowed_types = types # for unpickling - self.args = (types, ) + self.args = (types,) -class ChannelException (SSHException): +class ChannelException(SSHException): """ Exception raised when an attempt to open a new `.Channel` fails. @@ -89,14 +89,15 @@ class ChannelException (SSHException): .. versionadded:: 1.6 """ + def __init__(self, code, text): SSHException.__init__(self, text) self.code = code # for unpickling - self.args = (code, text, ) + self.args = (code, text) -class BadHostKeyException (SSHException): +class BadHostKeyException(SSHException): """ The host key given by the SSH server did not match what we were expecting. @@ -106,35 +107,38 @@ class BadHostKeyException (SSHException): .. versionadded:: 1.6 """ + def __init__(self, hostname, got_key, expected_key): - message = 'Host key for server {} does not match: got {}, expected {}' # noqa + message = "Host key for server {} does not match: got {}, expected {}" # noqa message = message.format( - hostname, got_key.get_base64(), - expected_key.get_base64()) + hostname, got_key.get_base64(), expected_key.get_base64() + ) SSHException.__init__(self, message) self.hostname = hostname self.key = got_key self.expected_key = expected_key # for unpickling - self.args = (hostname, got_key, expected_key, ) + self.args = (hostname, got_key, expected_key) -class ProxyCommandFailure (SSHException): +class ProxyCommandFailure(SSHException): """ The "ProxyCommand" found in the .ssh/config file returned an error. :param str command: The command line that is generating this exception. :param str error: The error captured from the proxy command output. """ + def __init__(self, command, error): - SSHException.__init__(self, + SSHException.__init__( + self, '"ProxyCommand ({})" returned non-zero exit status: {}'.format( command, error - ) + ), ) self.error = error # for unpickling - self.args = (command, error, ) + self.args = (command, error) class NoValidConnectionsError(socket.error): @@ -159,23 +163,23 @@ class NoValidConnectionsError(socket.error): .. versionadded:: 1.16 """ + def __init__(self, errors): """ :param dict errors: The errors dict to store, as described by class docstring. """ addrs = sorted(errors.keys()) - body = ', '.join([x[0] for x in addrs[:-1]]) + body = ", ".join([x[0] for x in addrs[:-1]]) tail = addrs[-1][0] if body: msg = "Unable to connect to port {0} on {1} or {2}" else: msg = "Unable to connect to port {0} on {2}" super(NoValidConnectionsError, self).__init__( - None, # stand-in for errno - msg.format(addrs[0][1], body, tail) + None, msg.format(addrs[0][1], body, tail) # stand-in for errno ) self.errors = errors def __reduce__(self): - return (self.__class__, (self.errors, )) + return (self.__class__, (self.errors,)) diff --git a/paramiko/ssh_gss.py b/paramiko/ssh_gss.py index 88dedf7e..14087042 100644 --- a/paramiko/ssh_gss.py +++ b/paramiko/ssh_gss.py @@ -51,12 +51,14 @@ _API = "MIT" try: import gssapi + GSS_EXCEPTIONS = (gssapi.GSSException,) except (ImportError, OSError): try: import pywintypes import sspicon import sspi + _API = "SSPI" GSS_EXCEPTIONS = (pywintypes.error,) except ImportError: @@ -101,6 +103,7 @@ class _SSH_GSSAuth(object): Contains the shared variables and methods of `._SSH_GSSAPI` and `._SSH_SSPI`. """ + def __init__(self, auth_method, gss_deleg_creds): """ :param str auth_method: The name of the SSH authentication mechanism @@ -209,7 +212,7 @@ class _SSH_GSSAuth(object): """ mic = self._make_uint32(len(session_id)) mic += session_id - mic += struct.pack('B', MSG_USERAUTH_REQUEST) + mic += struct.pack("B", MSG_USERAUTH_REQUEST) mic += self._make_uint32(len(username)) mic += username.encode() mic += self._make_uint32(len(service)) @@ -225,6 +228,7 @@ class _SSH_GSSAPI(_SSH_GSSAuth): :see: `.GSSAuth` """ + def __init__(self, auth_method, gss_deleg_creds): """ :param str auth_method: The name of the SSH authentication mechanism @@ -234,17 +238,22 @@ class _SSH_GSSAPI(_SSH_GSSAuth): _SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds) if self._gss_deleg_creds: - self._gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG, - gssapi.C_DELEG_FLAG) + self._gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + gssapi.C_DELEG_FLAG, + ) else: - self._gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG) + self._gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + ) - def ssh_init_sec_context(self, target, desired_mech=None, - username=None, recv_token=None): + def ssh_init_sec_context( + self, target, desired_mech=None, username=None, recv_token=None + ): """ Initialize a GSS-API context. @@ -262,8 +271,9 @@ class _SSH_GSSAPI(_SSH_GSSAuth): """ self._username = username self._gss_host = target - targ_name = gssapi.Name("host@" + self._gss_host, - gssapi.C_NT_HOSTBASED_SERVICE) + targ_name = gssapi.Name( + "host@" + self._gss_host, gssapi.C_NT_HOSTBASED_SERVICE + ) ctx = gssapi.Context() ctx.flags = self._gss_flags if desired_mech is None: @@ -277,15 +287,16 @@ class _SSH_GSSAPI(_SSH_GSSAuth): token = None try: if recv_token is None: - self._gss_ctxt = gssapi.InitContext(peer_name=targ_name, - mech_type=krb5_mech, - req_flags=ctx.flags) + self._gss_ctxt = gssapi.InitContext( + peer_name=targ_name, + mech_type=krb5_mech, + req_flags=ctx.flags, + ) token = self._gss_ctxt.step(token) else: token = self._gss_ctxt.step(recv_token) except gssapi.GSSException: - message = "{} Target: {}".format( - sys.exc_info()[1], self._gss_host) + message = "{} Target: {}".format(sys.exc_info()[1], self._gss_host) raise gssapi.GSSException(message) self._gss_ctxt_status = self._gss_ctxt.established return token @@ -305,10 +316,12 @@ class _SSH_GSSAPI(_SSH_GSSAuth): """ self._session_id = session_id if not gss_kex: - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) mic_token = self._gss_ctxt.get_mic(mic_field) else: # for key exchange with gssapi-keyex @@ -349,16 +362,17 @@ class _SSH_GSSAPI(_SSH_GSSAuth): self._username = username if self._username is not None: # server mode - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) self._gss_srv_ctxt.verify_mic(mic_field, mic_token) else: # for key exchange with gssapi-keyex # client mode - self._gss_ctxt.verify_mic(self._session_id, - mic_token) + self._gss_ctxt.verify_mic(self._session_id, mic_token) @property def credentials_delegated(self): @@ -391,6 +405,7 @@ class _SSH_SSPI(_SSH_GSSAuth): :see: `.GSSAuth` """ + def __init__(self, auth_method, gss_deleg_creds): """ :param str auth_method: The name of the SSH authentication mechanism @@ -401,18 +416,18 @@ class _SSH_SSPI(_SSH_GSSAuth): if self._gss_deleg_creds: self._gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH | - sspicon.ISC_REQ_DELEGATE + sspicon.ISC_REQ_INTEGRITY + | sspicon.ISC_REQ_MUTUAL_AUTH + | sspicon.ISC_REQ_DELEGATE ) else: self._gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH + sspicon.ISC_REQ_INTEGRITY | sspicon.ISC_REQ_MUTUAL_AUTH ) - def ssh_init_sec_context(self, target, desired_mech=None, - username=None, recv_token=None): + def ssh_init_sec_context( + self, target, desired_mech=None, username=None, recv_token=None + ): """ Initialize a SSPI context. @@ -438,9 +453,9 @@ class _SSH_SSPI(_SSH_GSSAuth): raise SSHException("Unsupported mechanism OID.") try: if recv_token is None: - self._gss_ctxt = sspi.ClientAuth("Kerberos", - scflags=self._gss_flags, - targetspn=targ_name) + self._gss_ctxt = sspi.ClientAuth( + "Kerberos", scflags=self._gss_flags, targetspn=targ_name + ) error, token = self._gss_ctxt.authorize(recv_token) token = token[0].Buffer except pywintypes.error as e: @@ -475,10 +490,12 @@ class _SSH_SSPI(_SSH_GSSAuth): """ self._session_id = session_id if not gss_kex: - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) mic_token = self._gss_ctxt.sign(mic_field) else: # for key exchange with gssapi-keyex @@ -521,10 +538,12 @@ class _SSH_SSPI(_SSH_GSSAuth): self._username = username if username is not None: # server mode - mic_field = self._ssh_build_mic(self._session_id, - self._username, - self._service, - self._auth_method) + mic_field = self._ssh_build_mic( + self._session_id, + self._username, + self._service, + self._auth_method, + ) # Verifies data and its signature. If verification fails, an # sspi.error will be raised. self._gss_srv_ctxt.verify(mic_field, mic_token) @@ -543,8 +562,8 @@ class _SSH_SSPI(_SSH_GSSAuth): :return: ``True`` if credentials are delegated, otherwise ``False`` """ return ( - self._gss_flags & sspicon.ISC_REQ_DELEGATE and - (self._gss_srv_ctxt_status or self._gss_flags) + self._gss_flags & sspicon.ISC_REQ_DELEGATE + and (self._gss_srv_ctxt_status or self._gss_flags) ) def save_client_creds(self, client_token): diff --git a/paramiko/transport.py b/paramiko/transport.py index ddcb2912..ea303d37 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -39,18 +39,49 @@ from paramiko.auth_handler import AuthHandler from paramiko.ssh_gss import GSSAuth from paramiko.channel import Channel from paramiko.common import ( - xffffffff, cMSG_CHANNEL_OPEN, cMSG_IGNORE, cMSG_GLOBAL_REQUEST, DEBUG, - MSG_KEXINIT, MSG_IGNORE, MSG_DISCONNECT, MSG_DEBUG, ERROR, WARNING, - cMSG_UNIMPLEMENTED, INFO, cMSG_KEXINIT, cMSG_NEWKEYS, MSG_NEWKEYS, - cMSG_REQUEST_SUCCESS, cMSG_REQUEST_FAILURE, CONNECTION_FAILED_CODE, - OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_SUCCEEDED, - cMSG_CHANNEL_OPEN_FAILURE, cMSG_CHANNEL_OPEN_SUCCESS, MSG_GLOBAL_REQUEST, - MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE, MSG_CHANNEL_OPEN_SUCCESS, - MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_OPEN, MSG_CHANNEL_SUCCESS, - MSG_CHANNEL_FAILURE, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, - MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_REQUEST, MSG_CHANNEL_EOF, - MSG_CHANNEL_CLOSE, MIN_WINDOW_SIZE, MIN_PACKET_SIZE, MAX_WINDOW_SIZE, - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, HIGHEST_USERAUTH_MESSAGE_ID, + xffffffff, + cMSG_CHANNEL_OPEN, + cMSG_IGNORE, + cMSG_GLOBAL_REQUEST, + DEBUG, + MSG_KEXINIT, + MSG_IGNORE, + MSG_DISCONNECT, + MSG_DEBUG, + ERROR, + WARNING, + cMSG_UNIMPLEMENTED, + INFO, + cMSG_KEXINIT, + cMSG_NEWKEYS, + MSG_NEWKEYS, + cMSG_REQUEST_SUCCESS, + cMSG_REQUEST_FAILURE, + CONNECTION_FAILED_CODE, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_SUCCEEDED, + cMSG_CHANNEL_OPEN_FAILURE, + cMSG_CHANNEL_OPEN_SUCCESS, + MSG_GLOBAL_REQUEST, + MSG_REQUEST_SUCCESS, + MSG_REQUEST_FAILURE, + MSG_CHANNEL_OPEN_SUCCESS, + MSG_CHANNEL_OPEN_FAILURE, + MSG_CHANNEL_OPEN, + MSG_CHANNEL_SUCCESS, + MSG_CHANNEL_FAILURE, + MSG_CHANNEL_DATA, + MSG_CHANNEL_EXTENDED_DATA, + MSG_CHANNEL_WINDOW_ADJUST, + MSG_CHANNEL_REQUEST, + MSG_CHANNEL_EOF, + MSG_CHANNEL_CLOSE, + MIN_WINDOW_SIZE, + MIN_PACKET_SIZE, + MAX_WINDOW_SIZE, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_PACKET_SIZE, + HIGHEST_USERAUTH_MESSAGE_ID, ) from paramiko.compress import ZlibCompressor, ZlibDecompressor from paramiko.dsskey import DSSKey @@ -69,7 +100,10 @@ from paramiko.ecdsakey import ECDSAKey from paramiko.server import ServerInterface from paramiko.sftp_client import SFTPClient from paramiko.ssh_exception import ( - SSHException, BadAuthenticationType, ChannelException, ProxyCommandFailure, + SSHException, + BadAuthenticationType, + ChannelException, + ProxyCommandFailure, ) from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value @@ -77,12 +111,14 @@ from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value # for thread cleanup _active_threads = [] + def _join_lingering_threads(): for thr in _active_threads: thr.stop_thread() import atexit + atexit.register(_join_lingering_threads) @@ -99,160 +135,161 @@ class Transport(threading.Thread, ClosingContextManager): _ENCRYPT = object() _DECRYPT = object() - _PROTO_ID = '2.0' - _CLIENT_ID = 'paramiko_{}'.format(paramiko.__version__) + _PROTO_ID = "2.0" + _CLIENT_ID = "paramiko_{}".format(paramiko.__version__) # These tuples of algorithm identifiers are in preference order; do not # reorder without reason! _preferred_ciphers = ( - 'aes128-ctr', - 'aes192-ctr', - 'aes256-ctr', - 'aes128-cbc', - 'aes192-cbc', - 'aes256-cbc', - 'blowfish-cbc', - '3des-cbc', + "aes128-ctr", + "aes192-ctr", + "aes256-ctr", + "aes128-cbc", + "aes192-cbc", + "aes256-cbc", + "blowfish-cbc", + "3des-cbc", ) _preferred_macs = ( - 'hmac-sha2-256', - 'hmac-sha2-512', - 'hmac-sha1', - 'hmac-md5', - 'hmac-sha1-96', - 'hmac-md5-96', + "hmac-sha2-256", + "hmac-sha2-512", + "hmac-sha1", + "hmac-md5", + "hmac-sha1-96", + "hmac-md5-96", ) _preferred_keys = ( - 'ssh-ed25519', - 'ecdsa-sha2-nistp256', - 'ecdsa-sha2-nistp384', - 'ecdsa-sha2-nistp521', - 'ssh-rsa', - 'ssh-dss', + "ssh-ed25519", + "ecdsa-sha2-nistp256", + "ecdsa-sha2-nistp384", + "ecdsa-sha2-nistp521", + "ssh-rsa", + "ssh-dss", ) _preferred_kex = ( - 'ecdh-sha2-nistp256', - 'ecdh-sha2-nistp384', - 'ecdh-sha2-nistp521', - 'diffie-hellman-group-exchange-sha256', - 'diffie-hellman-group-exchange-sha1', - 'diffie-hellman-group14-sha1', - 'diffie-hellman-group1-sha1', + "ecdh-sha2-nistp256", + "ecdh-sha2-nistp384", + "ecdh-sha2-nistp521", + "diffie-hellman-group-exchange-sha256", + "diffie-hellman-group-exchange-sha1", + "diffie-hellman-group14-sha1", + "diffie-hellman-group1-sha1", ) _preferred_gsskex = ( - 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==', - 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==', - 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==', + "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==", + "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==", + "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==", ) - _preferred_compression = ('none',) + _preferred_compression = ("none",) _cipher_info = { - 'aes128-ctr': { - 'class': algorithms.AES, - 'mode': modes.CTR, - 'block-size': 16, - 'key-size': 16 + "aes128-ctr": { + "class": algorithms.AES, + "mode": modes.CTR, + "block-size": 16, + "key-size": 16, }, - 'aes192-ctr': { - 'class': algorithms.AES, - 'mode': modes.CTR, - 'block-size': 16, - 'key-size': 24 + "aes192-ctr": { + "class": algorithms.AES, + "mode": modes.CTR, + "block-size": 16, + "key-size": 24, }, - 'aes256-ctr': { - 'class': algorithms.AES, - 'mode': modes.CTR, - 'block-size': 16, - 'key-size': 32 + "aes256-ctr": { + "class": algorithms.AES, + "mode": modes.CTR, + "block-size": 16, + "key-size": 32, }, - 'blowfish-cbc': { - 'class': algorithms.Blowfish, - 'mode': modes.CBC, - 'block-size': 8, - 'key-size': 16 + "blowfish-cbc": { + "class": algorithms.Blowfish, + "mode": modes.CBC, + "block-size": 8, + "key-size": 16, }, - 'aes128-cbc': { - 'class': algorithms.AES, - 'mode': modes.CBC, - 'block-size': 16, - 'key-size': 16 + "aes128-cbc": { + "class": algorithms.AES, + "mode": modes.CBC, + "block-size": 16, + "key-size": 16, }, - 'aes192-cbc': { - 'class': algorithms.AES, - 'mode': modes.CBC, - 'block-size': 16, - 'key-size': 24 + "aes192-cbc": { + "class": algorithms.AES, + "mode": modes.CBC, + "block-size": 16, + "key-size": 24, }, - 'aes256-cbc': { - 'class': algorithms.AES, - 'mode': modes.CBC, - 'block-size': 16, - 'key-size': 32 + "aes256-cbc": { + "class": algorithms.AES, + "mode": modes.CBC, + "block-size": 16, + "key-size": 32, }, - '3des-cbc': { - 'class': algorithms.TripleDES, - 'mode': modes.CBC, - 'block-size': 8, - 'key-size': 24 + "3des-cbc": { + "class": algorithms.TripleDES, + "mode": modes.CBC, + "block-size": 8, + "key-size": 24, }, } - _mac_info = { - 'hmac-sha1': {'class': sha1, 'size': 20}, - 'hmac-sha1-96': {'class': sha1, 'size': 12}, - 'hmac-sha2-256': {'class': sha256, 'size': 32}, - 'hmac-sha2-512': {'class': sha512, 'size': 64}, - 'hmac-md5': {'class': md5, 'size': 16}, - 'hmac-md5-96': {'class': md5, 'size': 12}, + "hmac-sha1": {"class": sha1, "size": 20}, + "hmac-sha1-96": {"class": sha1, "size": 12}, + "hmac-sha2-256": {"class": sha256, "size": 32}, + "hmac-sha2-512": {"class": sha512, "size": 64}, + "hmac-md5": {"class": md5, "size": 16}, + "hmac-md5-96": {"class": md5, "size": 12}, } _key_info = { - 'ssh-rsa': RSAKey, - 'ssh-rsa-cert-v01@openssh.com': RSAKey, - 'ssh-dss': DSSKey, - 'ssh-dss-cert-v01@openssh.com': DSSKey, - 'ecdsa-sha2-nistp256': ECDSAKey, - 'ecdsa-sha2-nistp256-cert-v01@openssh.com': ECDSAKey, - 'ecdsa-sha2-nistp384': ECDSAKey, - 'ecdsa-sha2-nistp384-cert-v01@openssh.com': ECDSAKey, - 'ecdsa-sha2-nistp521': ECDSAKey, - 'ecdsa-sha2-nistp521-cert-v01@openssh.com': ECDSAKey, - 'ssh-ed25519': Ed25519Key, - 'ssh-ed25519-cert-v01@openssh.com': Ed25519Key, + "ssh-rsa": RSAKey, + "ssh-rsa-cert-v01@openssh.com": RSAKey, + "ssh-dss": DSSKey, + "ssh-dss-cert-v01@openssh.com": DSSKey, + "ecdsa-sha2-nistp256": ECDSAKey, + "ecdsa-sha2-nistp256-cert-v01@openssh.com": ECDSAKey, + "ecdsa-sha2-nistp384": ECDSAKey, + "ecdsa-sha2-nistp384-cert-v01@openssh.com": ECDSAKey, + "ecdsa-sha2-nistp521": ECDSAKey, + "ecdsa-sha2-nistp521-cert-v01@openssh.com": ECDSAKey, + "ssh-ed25519": Ed25519Key, + "ssh-ed25519-cert-v01@openssh.com": Ed25519Key, } _kex_info = { - 'diffie-hellman-group1-sha1': KexGroup1, - 'diffie-hellman-group14-sha1': KexGroup14, - 'diffie-hellman-group-exchange-sha1': KexGex, - 'diffie-hellman-group-exchange-sha256': KexGexSHA256, - 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup1, - 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup14, - 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGex, - 'ecdh-sha2-nistp256': KexNistp256, - 'ecdh-sha2-nistp384': KexNistp384, - 'ecdh-sha2-nistp521': KexNistp521, + "diffie-hellman-group1-sha1": KexGroup1, + "diffie-hellman-group14-sha1": KexGroup14, + "diffie-hellman-group-exchange-sha1": KexGex, + "diffie-hellman-group-exchange-sha256": KexGexSHA256, + "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup1, + "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup14, + "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGex, + "ecdh-sha2-nistp256": KexNistp256, + "ecdh-sha2-nistp384": KexNistp384, + "ecdh-sha2-nistp521": KexNistp521, } _compression_info = { # zlib@openssh.com is just zlib, but only turned on after a successful # authentication. openssh servers may only offer this type because # they've had troubles with security holes in zlib in the past. - 'zlib@openssh.com': (ZlibCompressor, ZlibDecompressor), - 'zlib': (ZlibCompressor, ZlibDecompressor), - 'none': (None, None), + "zlib@openssh.com": (ZlibCompressor, ZlibDecompressor), + "zlib": (ZlibCompressor, ZlibDecompressor), + "none": (None, None), } _modulus_pack = None _active_check_timeout = 0.1 - def __init__(self, - sock, - default_window_size=DEFAULT_WINDOW_SIZE, - default_max_packet_size=DEFAULT_MAX_PACKET_SIZE, - gss_kex=False, - gss_deleg_creds=True): + def __init__( + self, + sock, + default_window_size=DEFAULT_WINDOW_SIZE, + default_max_packet_size=DEFAULT_MAX_PACKET_SIZE, + gss_kex=False, + gss_deleg_creds=True, + ): """ Create a new SSH session over an existing socket, or socket-like object. This only creates the `.Transport` object; it doesn't begin @@ -302,7 +339,7 @@ class Transport(threading.Thread, ClosingContextManager): if isinstance(sock, string_types): # convert "host:port" into (host, port) - hl = sock.split(':', 1) + hl = sock.split(":", 1) self.hostname = hl[0] if len(hl) == 1: sock = (hl[0], 22) @@ -312,7 +349,7 @@ class Transport(threading.Thread, ClosingContextManager): # connect to the given (host, port) hostname, port = sock self.hostname = hostname - reason = 'No suitable address family' + reason = "No suitable address family" addrinfos = socket.getaddrinfo( hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM ) @@ -329,7 +366,8 @@ class Transport(threading.Thread, ClosingContextManager): break else: raise SSHException( - 'Unable to connect to {}: {}'.format(hostname, reason)) + "Unable to connect to {}: {}".format(hostname, reason) + ) # okay, normal socket-ish flow here... threading.Thread.__init__(self) self.setDaemon(True) @@ -340,9 +378,9 @@ class Transport(threading.Thread, ClosingContextManager): # negotiated crypto parameters self.packetizer = Packetizer(sock) - self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID - self.remote_version = '' - self.local_cipher = self.remote_cipher = '' + self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID + self.remote_version = "" + self.local_cipher = self.remote_cipher = "" self.local_kex_init = self.remote_kex_init = None self.local_mac = self.remote_mac = None self.local_compression = self.remote_compression = None @@ -374,8 +412,8 @@ class Transport(threading.Thread, ClosingContextManager): # tracking open channels self._channels = ChannelMap() - self.channel_events = {} # (id -> Event) - self.channels_seen = {} # (id -> True) + self.channel_events = {} # (id -> Event) + self.channels_seen = {} # (id -> True) self._channel_counter = 0 self.default_max_packet_size = default_max_packet_size self.default_window_size = default_window_size @@ -387,7 +425,7 @@ class Transport(threading.Thread, ClosingContextManager): self.clear_to_send = threading.Event() self.clear_to_send_lock = threading.Lock() self.clear_to_send_timeout = 30.0 - self.log_name = 'paramiko.transport' + self.log_name = "paramiko.transport" self.logger = util.get_logger(self.log_name) self.packetizer.set_log(self.logger) self.auth_handler = None @@ -416,23 +454,24 @@ class Transport(threading.Thread, ClosingContextManager): Returns a string representation of this object, for debugging. """ id_ = hex(long(id(self)) & xffffffff) - out = '<paramiko.Transport at {}'.format(id_) + out = "<paramiko.Transport at {}".format(id_) if not self.active: - out += ' (unconnected)' + out += " (unconnected)" else: - if self.local_cipher != '': - out += ' (cipher {}, {:d} bits)'.format( + if self.local_cipher != "": + out += " (cipher {}, {:d} bits)".format( self.local_cipher, - self._cipher_info[self.local_cipher]['key-size'] * 8 + self._cipher_info[self.local_cipher]["key-size"] * 8, ) if self.is_authenticated(): - out += ' (active; {} open channel(s))'.format( - len(self._channels)) + out += " (active; {} open channel(s))".format( + len(self._channels) + ) elif self.initial_kex_done: - out += ' (connected; awaiting auth)' + out += " (connected; awaiting auth)" else: - out += ' (connecting)' - out += '>' + out += " (connecting)" + out += ">" return out def atfork(self): @@ -543,10 +582,10 @@ class Transport(threading.Thread, ClosingContextManager): e = self.get_exception() if e is not None: raise e - raise SSHException('Negotiation failed.') + raise SSHException("Negotiation failed.") if ( - event.is_set() or - (timeout is not None and time.time() >= max_time) + event.is_set() + or (timeout is not None and time.time() >= max_time) ): break @@ -612,7 +651,7 @@ class Transport(threading.Thread, ClosingContextManager): e = self.get_exception() if e is not None: raise e - raise SSHException('Negotiation failed.') + raise SSHException("Negotiation failed.") if event.is_set(): break @@ -679,7 +718,7 @@ class Transport(threading.Thread, ClosingContextManager): """ Transport._modulus_pack = ModulusPack() # places to look for the openssh "moduli" file - file_list = ['/etc/ssh/moduli', '/usr/local/etc/moduli'] + file_list = ["/etc/ssh/moduli", "/usr/local/etc/moduli"] if filename is not None: file_list.insert(0, filename) for fn in file_list: @@ -717,7 +756,7 @@ class Transport(threading.Thread, ClosingContextManager): :return: public key (`.PKey`) of the remote server """ if (not self.active) or (not self.initial_kex_done): - raise SSHException('No existing session') + raise SSHException("No existing session") return self.host_key def is_active(self): @@ -731,10 +770,7 @@ class Transport(threading.Thread, ClosingContextManager): return self.active def open_session( - self, - window_size=None, - max_packet_size=None, - timeout=None, + self, window_size=None, max_packet_size=None, timeout=None ): """ Request a new channel to the server, of type ``"session"``. This is @@ -761,10 +797,12 @@ class Transport(threading.Thread, ClosingContextManager): .. versionchanged:: 1.15 Added the ``window_size`` and ``max_packet_size`` arguments. """ - return self.open_channel('session', - window_size=window_size, - max_packet_size=max_packet_size, - timeout=timeout) + return self.open_channel( + "session", + window_size=window_size, + max_packet_size=max_packet_size, + timeout=timeout, + ) def open_x11_channel(self, src_addr=None): """ @@ -780,7 +818,7 @@ class Transport(threading.Thread, ClosingContextManager): `.SSHException` -- if the request is rejected or the session ends prematurely """ - return self.open_channel('x11', src_addr=src_addr) + return self.open_channel("x11", src_addr=src_addr) def open_forward_agent_channel(self): """ @@ -794,7 +832,7 @@ class Transport(threading.Thread, ClosingContextManager): :raises: `.SSHException` -- if the request is rejected or the session ends prematurely """ - return self.open_channel('auth-agent@openssh.com') + return self.open_channel("auth-agent@openssh.com") def open_forwarded_tcpip_channel(self, src_addr, dest_addr): """ @@ -806,15 +844,17 @@ class Transport(threading.Thread, ClosingContextManager): :param src_addr: originator's address :param dest_addr: local (server) connected address """ - return self.open_channel('forwarded-tcpip', dest_addr, src_addr) + return self.open_channel("forwarded-tcpip", dest_addr, src_addr) - def open_channel(self, - kind, - dest_addr=None, - src_addr=None, - window_size=None, - max_packet_size=None, - timeout=None): + def open_channel( + self, + kind, + dest_addr=None, + src_addr=None, + window_size=None, + max_packet_size=None, + timeout=None, + ): """ Request a new channel to the server. `Channels <.Channel>` are socket-like objects used for the actual transfer of data across the @@ -851,7 +891,7 @@ class Transport(threading.Thread, ClosingContextManager): Added the ``window_size`` and ``max_packet_size`` arguments. """ if not self.active: - raise SSHException('SSH session not active') + raise SSHException("SSH session not active") timeout = 3600 if timeout is None else timeout self.lock.acquire() try: @@ -864,12 +904,12 @@ class Transport(threading.Thread, ClosingContextManager): m.add_int(chanid) m.add_int(window_size) m.add_int(max_packet_size) - if (kind == 'forwarded-tcpip') or (kind == 'direct-tcpip'): + if (kind == "forwarded-tcpip") or (kind == "direct-tcpip"): m.add_string(dest_addr[0]) m.add_int(dest_addr[1]) m.add_string(src_addr[0]) m.add_int(src_addr[1]) - elif kind == 'x11': + elif kind == "x11": m.add_string(src_addr[0]) m.add_int(src_addr[1]) chan = Channel(chanid) @@ -887,18 +927,18 @@ class Transport(threading.Thread, ClosingContextManager): if not self.active: e = self.get_exception() if e is None: - e = SSHException('Unable to open channel.') + e = SSHException("Unable to open channel.") raise e if event.is_set(): break elif start_ts + timeout < time.time(): - raise SSHException('Timeout opening channel.') + raise SSHException("Timeout opening channel.") chan = self._channels.get(chanid) if chan is not None: return chan e = self.get_exception() if e is None: - e = SSHException('Unable to open channel.') + e = SSHException("Unable to open channel.") raise e def request_port_forward(self, address, port, handler=None): @@ -935,20 +975,22 @@ class Transport(threading.Thread, ClosingContextManager): `.SSHException` -- if the server refused the TCP forward request """ if not self.active: - raise SSHException('SSH session not active') + raise SSHException("SSH session not active") port = int(port) response = self.global_request( - 'tcpip-forward', (address, port), wait=True + "tcpip-forward", (address, port), wait=True ) if response is None: - raise SSHException('TCP forwarding request denied') + raise SSHException("TCP forwarding request denied") if port == 0: port = response.get_int() if handler is None: + def default_handler(channel, src_addr, dest_addr_port): # src_addr, src_port = src_addr_port # dest_addr, dest_port = dest_addr_port self._queue_incoming_channel(channel) + handler = default_handler self._tcp_handler = handler return port @@ -965,7 +1007,7 @@ class Transport(threading.Thread, ClosingContextManager): if not self.active: return self._tcp_handler = None - self.global_request('cancel-tcpip-forward', (address, port), wait=True) + self.global_request("cancel-tcpip-forward", (address, port), wait=True) def open_sftp_client(self): """ @@ -1018,7 +1060,7 @@ class Transport(threading.Thread, ClosingContextManager): e = self.get_exception() if e is not None: raise e - raise SSHException('Negotiation failed.') + raise SSHException("Negotiation failed.") if self.completion_event.is_set(): break return @@ -1034,8 +1076,10 @@ class Transport(threading.Thread, ClosingContextManager): seconds to wait before sending a keepalive packet (or 0 to disable keepalives). """ + def _request(x=weakref.proxy(self)): - return x.global_request('keepalive@lag.net', wait=False) + return x.global_request("keepalive@lag.net", wait=False) + self.packetizer.set_keepalive(interval, _request) def global_request(self, kind, data=None, wait=True): @@ -1103,7 +1147,7 @@ class Transport(threading.Thread, ClosingContextManager): def connect( self, hostkey=None, - username='', + username="", password=None, pkey=None, gss_host=None, @@ -1177,34 +1221,43 @@ class Transport(threading.Thread, ClosingContextManager): if (hostkey is not None) and not gss_kex: key = self.get_remote_server_key() if ( - key.get_name() != hostkey.get_name() or - key.asbytes() != hostkey.asbytes() + key.get_name() != hostkey.get_name() + or key.asbytes() != hostkey.asbytes() ): - self._log(DEBUG, 'Bad host key from server') - self._log(DEBUG, 'Expected: {}: {}'.format( - hostkey.get_name(), repr(hostkey.asbytes()), - )) - self._log(DEBUG, 'Got : {}: {}'.format( - key.get_name(), repr(key.asbytes()), - )) - raise SSHException('Bad host key from server') - self._log(DEBUG, 'Host key verified ({})'.format( - hostkey.get_name())) + self._log(DEBUG, "Bad host key from server") + self._log( + DEBUG, + "Expected: {}: {}".format( + hostkey.get_name(), repr(hostkey.asbytes()) + ), + ) + self._log( + DEBUG, + "Got : {}: {}".format( + key.get_name(), repr(key.asbytes()) + ), + ) + raise SSHException("Bad host key from server") + self._log( + DEBUG, "Host key verified ({})".format(hostkey.get_name()) + ) if (pkey is not None) or (password is not None) or gss_auth or gss_kex: if gss_auth: - self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-with-mic)') # noqa + self._log( + DEBUG, "Attempting GSS-API auth... (gssapi-with-mic)" + ) # noqa self.auth_gssapi_with_mic( - username, self.gss_host, gss_deleg_creds, + username, self.gss_host, gss_deleg_creds ) elif gss_kex: - self._log(DEBUG, 'Attempting GSS-API auth... (gssapi-keyex)') + self._log(DEBUG, "Attempting GSS-API auth... (gssapi-keyex)") self.auth_gssapi_keyex(username) elif pkey is not None: - self._log(DEBUG, 'Attempting public-key auth...') + self._log(DEBUG, "Attempting public-key auth...") self.auth_publickey(username, pkey) else: - self._log(DEBUG, 'Attempting password auth...') + self._log(DEBUG, "Attempting password auth...") self.auth_password(username, password) return @@ -1259,9 +1312,9 @@ class Transport(threading.Thread, ClosingContextManager): closed. """ return ( - self.active and - self.auth_handler is not None and - self.auth_handler.is_authenticated() + self.active + and self.auth_handler is not None + and self.auth_handler.is_authenticated() ) def get_username(self): @@ -1311,7 +1364,7 @@ class Transport(threading.Thread, ClosingContextManager): .. versionadded:: 1.5 """ if (not self.active) or (not self.initial_kex_done): - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_none(username, my_event) @@ -1367,7 +1420,7 @@ class Transport(threading.Thread, ClosingContextManager): if (not self.active) or (not self.initial_kex_done): # we should never try to send the password unless we're on a secure # link - raise SSHException('No existing session') + raise SSHException("No existing session") if event is None: my_event = threading.Event() else: @@ -1382,12 +1435,13 @@ class Transport(threading.Thread, ClosingContextManager): except BadAuthenticationType as e: # if password auth isn't allowed, but keyboard-interactive *is*, # try to fudge it - if not fallback or ('keyboard-interactive' not in e.allowed_types): + if not fallback or ("keyboard-interactive" not in e.allowed_types): raise try: + def handler(title, instructions, fields): if len(fields) > 1: - raise SSHException('Fallback authentication failed.') + raise SSHException("Fallback authentication failed.") if len(fields) == 0: # for some reason, at least on os x, a 2nd request will # be made with zero fields requested. maybe it's just @@ -1395,6 +1449,7 @@ class Transport(threading.Thread, ClosingContextManager): # type we're doing here. *shrug* :) return [] return [password] + return self.auth_interactive(username, handler) except SSHException: # attempt failed; just raise the original exception @@ -1437,7 +1492,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") if event is None: my_event = threading.Event() else: @@ -1449,7 +1504,7 @@ class Transport(threading.Thread, ClosingContextManager): return [] return self.auth_handler.wait_for_response(my_event) - def auth_interactive(self, username, handler, submethods=''): + def auth_interactive(self, username, handler, submethods=""): """ Authenticate to the server interactively. A handler is used to answer arbitrary questions from the server. On many servers, this is just a @@ -1494,7 +1549,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_interactive( @@ -1502,7 +1557,7 @@ class Transport(threading.Thread, ClosingContextManager): ) return self.auth_handler.wait_for_response(my_event) - def auth_interactive_dumb(self, username, handler=None, submethods=''): + def auth_interactive_dumb(self, username, handler=None, submethods=""): """ Autenticate to the server interactively but dumber. Just print the prompt and / or instructions to stdout and send back @@ -1511,6 +1566,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if not handler: + def handler(title, instructions, prompt_list): answers = [] if title: @@ -1518,9 +1574,10 @@ class Transport(threading.Thread, ClosingContextManager): if instructions: print(instructions.strip()) for prompt, show_input in prompt_list: - print(prompt.strip(), end=' ') + print(prompt.strip(), end=" ") answers.append(input()) return answers + return self.auth_interactive(username, handler, submethods) def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds): @@ -1541,7 +1598,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_gssapi_with_mic( @@ -1566,7 +1623,7 @@ class Transport(threading.Thread, ClosingContextManager): """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link - raise SSHException('No existing session') + raise SSHException("No existing session") my_event = threading.Event() self.auth_handler = AuthHandler(self) self.auth_handler.auth_gssapi_keyex(username, my_event) @@ -1633,9 +1690,9 @@ class Transport(threading.Thread, ClosingContextManager): .. versionadded:: 1.5.2 """ if compress: - self._preferred_compression = ('zlib@openssh.com', 'zlib', 'none') + self._preferred_compression = ("zlib@openssh.com", "zlib", "none") else: - self._preferred_compression = ('none',) + self._preferred_compression = ("none",) def getpeername(self): """ @@ -1649,9 +1706,9 @@ class Transport(threading.Thread, ClosingContextManager): the address of the remote host, if known, as a ``(str, int)`` tuple. """ - gp = getattr(self.sock, 'getpeername', None) + gp = getattr(self.sock, "getpeername", None) if gp is None: - return 'unknown', 0 + return "unknown", 0 return gp() def stop_thread(self): @@ -1670,10 +1727,10 @@ class Transport(threading.Thread, ClosingContextManager): # our socket and packetizer are both closed (but where we'd # otherwise be sitting forever on that recv()). while ( - self.is_alive() and - self is not threading.current_thread() and - not self.sock._closed and - not self.packetizer.closed + self.is_alive() + and self is not threading.current_thread() + and not self.sock._closed + and not self.packetizer.closed ): self.join(0.1) @@ -1715,14 +1772,18 @@ class Transport(threading.Thread, ClosingContextManager): while True: self.clear_to_send.wait(0.1) if not self.active: - self._log(DEBUG, 'Dropping user packet because connection is dead.') # noqa + self._log( + DEBUG, "Dropping user packet because connection is dead." + ) # noqa return self.clear_to_send_lock.acquire() if self.clear_to_send.is_set(): break self.clear_to_send_lock.release() if time.time() > start + self.clear_to_send_timeout: - raise SSHException('Key-exchange timed out waiting for key negotiation') # noqa + raise SSHException( + "Key-exchange timed out waiting for key negotiation" + ) # noqa try: self._send_message(data) finally: @@ -1746,9 +1807,13 @@ class Transport(threading.Thread, ClosingContextManager): def _verify_key(self, host_key, sig): key = self._key_info[self.host_key_type](Message(host_key)) if key is None: - raise SSHException('Unknown host key type') + raise SSHException("Unknown host key type") if not key.verify_ssh_sig(self.H, Message(sig)): - raise SSHException('Signature verification ({}) failed.'.format(self.host_key_type)) # noqa + raise SSHException( + "Signature verification ({}) failed.".format( + self.host_key_type + ) + ) # noqa self.host_key = key def _compute_key(self, id, nbytes): @@ -1760,16 +1825,16 @@ class Transport(threading.Thread, ClosingContextManager): m.add_bytes(self.session_id) # Fallback to SHA1 for kex engines that fail to specify a hex # algorithm, or for e.g. transport tests that don't run kexinit. - hash_algo = getattr(self.kex_engine, 'hash_algo', None) + hash_algo = getattr(self.kex_engine, "hash_algo", None) hash_select_msg = "kex engine {} specified hash_algo {!r}".format( - self.kex_engine.__class__.__name__, hash_algo, + self.kex_engine.__class__.__name__, hash_algo ) if hash_algo is None: hash_algo = sha1 hash_select_msg += ", falling back to sha1" - if not hasattr(self, '_logged_hash_selection'): + if not hasattr(self, "_logged_hash_selection"): self._log(DEBUG, hash_select_msg) - setattr(self, '_logged_hash_selection', True) + setattr(self, "_logged_hash_selection", True) out = sofar = hash_algo(m.asbytes()).digest() while len(out) < nbytes: m = Message() @@ -1783,11 +1848,11 @@ class Transport(threading.Thread, ClosingContextManager): def _get_cipher(self, name, key, iv, operation): if name not in self._cipher_info: - raise SSHException('Unknown client cipher ' + name) + raise SSHException("Unknown client cipher " + name) else: cipher = Cipher( - self._cipher_info[name]['class'](key), - self._cipher_info[name]['mode'](iv), + self._cipher_info[name]["class"](key), + self._cipher_info[name]["mode"](iv), backend=default_backend(), ) if operation is self._ENCRYPT: @@ -1797,8 +1862,10 @@ class Transport(threading.Thread, ClosingContextManager): def _set_forward_agent_handler(self, handler): if handler is None: + def default_handler(channel): self._queue_incoming_channel(channel) + self._forward_agent_handler = default_handler else: self._forward_agent_handler = handler @@ -1809,6 +1876,7 @@ class Transport(threading.Thread, ClosingContextManager): # by default, use the same mechanism as accept() def default_handler(channel, src_addr_port): self._queue_incoming_channel(channel) + self._x11_handler = default_handler else: self._x11_handler = handler @@ -1842,9 +1910,9 @@ class Transport(threading.Thread, ClosingContextManager): Otherwise (client mode, authed, or pre-auth message) returns None. """ if ( - not self.server_mode or - ptype <= HIGHEST_USERAUTH_MESSAGE_ID or - self.is_authenticated() + not self.server_mode + or ptype <= HIGHEST_USERAUTH_MESSAGE_ID + or self.is_authenticated() ): return None # WELP. We must be dealing with someone trying to do non-auth things @@ -1855,13 +1923,13 @@ class Transport(threading.Thread, ClosingContextManager): reply.add_byte(cMSG_REQUEST_FAILURE) # Channel opens let us reject w/ a specific type + message. elif ptype == MSG_CHANNEL_OPEN: - kind = message.get_text() # noqa + kind = message.get_text() # noqa chanid = message.get_int() reply.add_byte(cMSG_CHANNEL_OPEN_FAILURE) reply.add_int(chanid) reply.add_int(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) - reply.add_string('') - reply.add_string('en') + reply.add_string("") + reply.add_string("en") # NOTE: Post-open channel messages do not need checking; the above will # reject attemps to open channels, meaning that even if a malicious # user tries to send a MSG_CHANNEL_REQUEST, it will simply fall under @@ -1883,13 +1951,16 @@ class Transport(threading.Thread, ClosingContextManager): _active_threads.append(self) tid = hex(long(id(self)) & xffffffff) if self.server_mode: - self._log(DEBUG, 'starting thread (server mode): {}'.format(tid)) + self._log(DEBUG, "starting thread (server mode): {}".format(tid)) else: - self._log(DEBUG, 'starting thread (client mode): {}'.format(tid)) + self._log(DEBUG, "starting thread (client mode): {}".format(tid)) try: try: - self.packetizer.write_all(b(self.local_version + '\r\n')) - self._log(DEBUG, 'Local version/idstring: {}'.format(self.local_version)) # noqa + self.packetizer.write_all(b(self.local_version + "\r\n")) + self._log( + DEBUG, + "Local version/idstring: {}".format(self.local_version), + ) # noqa self._check_banner() # The above is actually very much part of the handshake, but # sometimes the banner can be read but the machine is not @@ -1919,7 +1990,11 @@ class Transport(threading.Thread, ClosingContextManager): continue if len(self._expected_packet) > 0: if ptype not in self._expected_packet: - raise SSHException('Expecting packet from {!r}, got {:d}'.format(self._expected_packet, ptype)) # noqa + raise SSHException( + "Expecting packet from {!r}, got {:d}".format( + self._expected_packet, ptype + ) + ) # noqa self._expected_packet = tuple() if (ptype >= 30) and (ptype <= 41): self.kex_engine.parse_next(ptype, m) @@ -1937,20 +2012,30 @@ class Transport(threading.Thread, ClosingContextManager): if chan is not None: self._channel_handler_table[ptype](chan, m) elif chanid in self.channels_seen: - self._log(DEBUG, 'Ignoring message for dead channel {:d}'.format(chanid)) # noqa + self._log( + DEBUG, + "Ignoring message for dead channel {:d}".format( # noqa + chanid + ), + ) else: - self._log(ERROR, 'Channel request for unknown channel {:d}'.format(chanid)) # noqa + self._log( + ERROR, + "Channel request for unknown channel {:d}".format( # noqa + chanid + ), + ) break elif ( - self.auth_handler is not None and - ptype in self.auth_handler._handler_table + self.auth_handler is not None + and ptype in self.auth_handler._handler_table ): handler = self.auth_handler._handler_table[ptype] handler(self.auth_handler, m) if len(self._expected_packet) > 0: continue else: - err = 'Oops, unhandled type {:d}'.format(ptype) + err = "Oops, unhandled type {:d}".format(ptype) self._log(WARNING, err) msg = Message() msg.add_byte(cMSG_UNIMPLEMENTED) @@ -1958,24 +2043,24 @@ class Transport(threading.Thread, ClosingContextManager): self._send_message(msg) self.packetizer.complete_handshake() except SSHException as e: - self._log(ERROR, 'Exception: ' + str(e)) + self._log(ERROR, "Exception: " + str(e)) self._log(ERROR, util.tb_strings()) self.saved_exception = e except EOFError as e: - self._log(DEBUG, 'EOF in transport thread') + self._log(DEBUG, "EOF in transport thread") self.saved_exception = e except socket.error as e: if type(e.args) is tuple: if e.args: - emsg = '{} ({:d})'.format(e.args[1], e.args[0]) + emsg = "{} ({:d})".format(e.args[1], e.args[0]) else: # empty tuple, e.g. socket.timeout emsg = str(e) or repr(e) else: emsg = e.args - self._log(ERROR, 'Socket exception: ' + emsg) + self._log(ERROR, "Socket exception: " + emsg) self.saved_exception = e except Exception as e: - self._log(ERROR, 'Unknown exception: ' + str(e)) + self._log(ERROR, "Unknown exception: " + str(e)) self._log(ERROR, util.tb_strings()) self.saved_exception = e _active_threads.remove(self) @@ -2004,7 +2089,6 @@ class Transport(threading.Thread, ClosingContextManager): if self.sys.modules is not None: raise - def _log_agreement(self, which, local, remote): # Log useful, non-duplicative line re: an agreed-upon algorithm. # Old code implied algorithms could be asymmetrical (different for @@ -2046,32 +2130,32 @@ class Transport(threading.Thread, ClosingContextManager): raise except Exception as e: raise SSHException( - 'Error reading SSH protocol banner' + str(e) + "Error reading SSH protocol banner" + str(e) ) - if buf[:4] == 'SSH-': + if buf[:4] == "SSH-": break - self._log(DEBUG, 'Banner: ' + buf) - if buf[:4] != 'SSH-': + self._log(DEBUG, "Banner: " + buf) + if buf[:4] != "SSH-": raise SSHException('Indecipherable protocol version "' + buf + '"') # save this server version string for later self.remote_version = buf - self._log(DEBUG, 'Remote version/idstring: {}'.format(buf)) + self._log(DEBUG, "Remote version/idstring: {}".format(buf)) # pull off any attached comment # NOTE: comment used to be stored in a variable and then...never used. # since 2003. ca 877cd974b8182d26fa76d566072917ea67b64e67 - i = buf.find(' ') + i = buf.find(" ") if i >= 0: buf = buf[:i] # parse out version string and make sure it matches - segs = buf.split('-', 2) + segs = buf.split("-", 2) if len(segs) < 3: - raise SSHException('Invalid SSH banner') + raise SSHException("Invalid SSH banner") version = segs[1] client = segs[2] - if version != '1.99' and version != '2.0': - msg = 'Incompatible version ({} instead of 2.0)' + if version != "1.99" and version != "2.0": + msg = "Incompatible version ({} instead of 2.0)" raise SSHException(msg.format(version)) - msg = 'Connected (version {}, client {})'.format(version, client) + msg = "Connected (version {}, client {})".format(version, client) self._log(INFO, msg) def _send_kex_init(self): @@ -2087,25 +2171,27 @@ class Transport(threading.Thread, ClosingContextManager): self.gss_kex_used = False self.in_kex = True if self.server_mode: - mp_required_prefix = 'diffie-hellman-group-exchange-sha' + mp_required_prefix = "diffie-hellman-group-exchange-sha" kex_mp = [ - k for k - in self._preferred_kex + k + for k in self._preferred_kex if k.startswith(mp_required_prefix) ] if (self._modulus_pack is None) and (len(kex_mp) > 0): # can't do group-exchange if we don't have a pack of potential # primes pkex = [ - k for k - in self.get_security_options().kex + k + for k in self.get_security_options().kex if not k.startswith(mp_required_prefix) ] self.get_security_options().kex = pkex - available_server_keys = list(filter( - list(self.server_key_dict.keys()).__contains__, - self._preferred_keys - )) + available_server_keys = list( + filter( + list(self.server_key_dict.keys()).__contains__, + self._preferred_keys, + ) + ) else: available_server_keys = self._preferred_keys @@ -2129,7 +2215,7 @@ class Transport(threading.Thread, ClosingContextManager): self._send_message(m) def _parse_kex_init(self, m): - m.get_bytes(16) # cookie, discarded + m.get_bytes(16) # cookie, discarded kex_algo_list = m.get_list() server_key_algo_list = m.get_list() client_encrypt_algo_list = m.get_list() @@ -2141,20 +2227,32 @@ class Transport(threading.Thread, ClosingContextManager): client_lang_list = m.get_list() server_lang_list = m.get_list() kex_follows = m.get_boolean() - m.get_int() # unused - - self._log(DEBUG, - 'kex algos:' + str(kex_algo_list) + - ' server key:' + str(server_key_algo_list) + - ' client encrypt:' + str(client_encrypt_algo_list) + - ' server encrypt:' + str(server_encrypt_algo_list) + - ' client mac:' + str(client_mac_algo_list) + - ' server mac:' + str(server_mac_algo_list) + - ' client compress:' + str(client_compress_algo_list) + - ' server compress:' + str(server_compress_algo_list) + - ' client lang:' + str(client_lang_list) + - ' server lang:' + str(server_lang_list) + - ' kex follows?' + str(kex_follows) + m.get_int() # unused + + self._log( + DEBUG, + "kex algos:" + + str(kex_algo_list) + + " server key:" + + str(server_key_algo_list) + + " client encrypt:" + + str(client_encrypt_algo_list) + + " server encrypt:" + + str(server_encrypt_algo_list) + + " client mac:" + + str(client_mac_algo_list) + + " server mac:" + + str(server_mac_algo_list) + + " client compress:" + + str(client_compress_algo_list) + + " server compress:" + + str(server_compress_algo_list) + + " client lang:" + + str(client_lang_list) + + " server lang:" + + str(server_lang_list) + + " kex follows?" + + str(kex_follows), ) # as a server, we pick the first item in the client's list that we @@ -2162,122 +2260,149 @@ class Transport(threading.Thread, ClosingContextManager): # as a client, we pick the first item in our list that the server # supports. if self.server_mode: - agreed_kex = list(filter( - self._preferred_kex.__contains__, - kex_algo_list - )) + agreed_kex = list( + filter(self._preferred_kex.__contains__, kex_algo_list) + ) else: - agreed_kex = list(filter( - kex_algo_list.__contains__, - self._preferred_kex - )) + agreed_kex = list( + filter(kex_algo_list.__contains__, self._preferred_kex) + ) if len(agreed_kex) == 0: - raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') # noqa + raise SSHException( + "Incompatible ssh peer (no acceptable kex algorithm)" + ) # noqa self.kex_engine = self._kex_info[agreed_kex[0]](self) self._log(DEBUG, "Kex agreed: {}".format(agreed_kex[0])) if self.server_mode: - available_server_keys = list(filter( - list(self.server_key_dict.keys()).__contains__, - self._preferred_keys - )) - agreed_keys = list(filter( - available_server_keys.__contains__, server_key_algo_list - )) + available_server_keys = list( + filter( + list(self.server_key_dict.keys()).__contains__, + self._preferred_keys, + ) + ) + agreed_keys = list( + filter( + available_server_keys.__contains__, server_key_algo_list + ) + ) else: - agreed_keys = list(filter( - server_key_algo_list.__contains__, self._preferred_keys - )) + agreed_keys = list( + filter(server_key_algo_list.__contains__, self._preferred_keys) + ) if len(agreed_keys) == 0: - raise SSHException('Incompatible ssh peer (no acceptable host key)') # noqa + raise SSHException( + "Incompatible ssh peer (no acceptable host key)" + ) # noqa self.host_key_type = agreed_keys[0] if self.server_mode and (self.get_server_key() is None): - raise SSHException('Incompatible ssh peer (can\'t match requested host key type)') # noqa - self._log_agreement( - 'HostKey', agreed_keys[0], agreed_keys[0] - ) + raise SSHException( + "Incompatible ssh peer (can't match requested host key type)" + ) # noqa + self._log_agreement("HostKey", agreed_keys[0], agreed_keys[0]) if self.server_mode: - agreed_local_ciphers = list(filter( - self._preferred_ciphers.__contains__, - server_encrypt_algo_list - )) - agreed_remote_ciphers = list(filter( - self._preferred_ciphers.__contains__, - client_encrypt_algo_list - )) + agreed_local_ciphers = list( + filter( + self._preferred_ciphers.__contains__, + server_encrypt_algo_list, + ) + ) + agreed_remote_ciphers = list( + filter( + self._preferred_ciphers.__contains__, + client_encrypt_algo_list, + ) + ) else: - agreed_local_ciphers = list(filter( - client_encrypt_algo_list.__contains__, - self._preferred_ciphers - )) - agreed_remote_ciphers = list(filter( - server_encrypt_algo_list.__contains__, - self._preferred_ciphers - )) + agreed_local_ciphers = list( + filter( + client_encrypt_algo_list.__contains__, + self._preferred_ciphers, + ) + ) + agreed_remote_ciphers = list( + filter( + server_encrypt_algo_list.__contains__, + self._preferred_ciphers, + ) + ) if len(agreed_local_ciphers) == 0 or len(agreed_remote_ciphers) == 0: - raise SSHException('Incompatible ssh server (no acceptable ciphers)') # noqa + raise SSHException( + "Incompatible ssh server (no acceptable ciphers)" + ) # noqa self.local_cipher = agreed_local_ciphers[0] self.remote_cipher = agreed_remote_ciphers[0] self._log_agreement( - 'Cipher', local=self.local_cipher, remote=self.remote_cipher + "Cipher", local=self.local_cipher, remote=self.remote_cipher ) if self.server_mode: - agreed_remote_macs = list(filter( - self._preferred_macs.__contains__, client_mac_algo_list - )) - agreed_local_macs = list(filter( - self._preferred_macs.__contains__, server_mac_algo_list - )) + agreed_remote_macs = list( + filter(self._preferred_macs.__contains__, client_mac_algo_list) + ) + agreed_local_macs = list( + filter(self._preferred_macs.__contains__, server_mac_algo_list) + ) else: - agreed_local_macs = list(filter( - client_mac_algo_list.__contains__, self._preferred_macs - )) - agreed_remote_macs = list(filter( - server_mac_algo_list.__contains__, self._preferred_macs - )) + agreed_local_macs = list( + filter(client_mac_algo_list.__contains__, self._preferred_macs) + ) + agreed_remote_macs = list( + filter(server_mac_algo_list.__contains__, self._preferred_macs) + ) if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0): - raise SSHException('Incompatible ssh server (no acceptable macs)') + raise SSHException("Incompatible ssh server (no acceptable macs)") self.local_mac = agreed_local_macs[0] self.remote_mac = agreed_remote_macs[0] self._log_agreement( - 'MAC', local=self.local_mac, remote=self.remote_mac + "MAC", local=self.local_mac, remote=self.remote_mac ) if self.server_mode: - agreed_remote_compression = list(filter( - self._preferred_compression.__contains__, - client_compress_algo_list - )) - agreed_local_compression = list(filter( - self._preferred_compression.__contains__, - server_compress_algo_list - )) + agreed_remote_compression = list( + filter( + self._preferred_compression.__contains__, + client_compress_algo_list, + ) + ) + agreed_local_compression = list( + filter( + self._preferred_compression.__contains__, + server_compress_algo_list, + ) + ) else: - agreed_local_compression = list(filter( - client_compress_algo_list.__contains__, - self._preferred_compression - )) - agreed_remote_compression = list(filter( - server_compress_algo_list.__contains__, - self._preferred_compression - )) + agreed_local_compression = list( + filter( + client_compress_algo_list.__contains__, + self._preferred_compression, + ) + ) + agreed_remote_compression = list( + filter( + server_compress_algo_list.__contains__, + self._preferred_compression, + ) + ) if ( - len(agreed_local_compression) == 0 or - len(agreed_remote_compression) == 0 + len(agreed_local_compression) == 0 + or len(agreed_remote_compression) == 0 ): - msg = 'Incompatible ssh server (no acceptable compression) {!r} {!r} {!r}' # noqa - raise SSHException(msg.format( - agreed_local_compression, agreed_remote_compression, - self._preferred_compression, - )) + msg = "Incompatible ssh server (no acceptable compression) {!r} {!r} {!r}" # noqa + raise SSHException( + msg.format( + agreed_local_compression, + agreed_remote_compression, + self._preferred_compression, + ) + ) self.local_compression = agreed_local_compression[0] self.remote_compression = agreed_remote_compression[0] self._log_agreement( - 'Compression', + "Compression", local=self.local_compression, - remote=self.remote_compression + remote=self.remote_compression, ) # save for computing hash later... @@ -2290,40 +2415,40 @@ class Transport(threading.Thread, ClosingContextManager): def _activate_inbound(self): """switch on newly negotiated encryption parameters for inbound traffic""" - block_size = self._cipher_info[self.remote_cipher]['block-size'] + block_size = self._cipher_info[self.remote_cipher]["block-size"] if self.server_mode: - IV_in = self._compute_key('A', block_size) + IV_in = self._compute_key("A", block_size) key_in = self._compute_key( - 'C', self._cipher_info[self.remote_cipher]['key-size'] + "C", self._cipher_info[self.remote_cipher]["key-size"] ) else: - IV_in = self._compute_key('B', block_size) + IV_in = self._compute_key("B", block_size) key_in = self._compute_key( - 'D', self._cipher_info[self.remote_cipher]['key-size'] + "D", self._cipher_info[self.remote_cipher]["key-size"] ) engine = self._get_cipher( self.remote_cipher, key_in, IV_in, self._DECRYPT ) - mac_size = self._mac_info[self.remote_mac]['size'] - mac_engine = self._mac_info[self.remote_mac]['class'] + mac_size = self._mac_info[self.remote_mac]["size"] + mac_engine = self._mac_info[self.remote_mac]["class"] # initial mac keys are done in the hash's natural size (not the # potentially truncated transmission size) if self.server_mode: - mac_key = self._compute_key('E', mac_engine().digest_size) + mac_key = self._compute_key("E", mac_engine().digest_size) else: - mac_key = self._compute_key('F', mac_engine().digest_size) + mac_key = self._compute_key("F", mac_engine().digest_size) self.packetizer.set_inbound_cipher( engine, block_size, mac_engine, mac_size, mac_key ) compress_in = self._compression_info[self.remote_compression][1] if ( - compress_in is not None and - ( - self.remote_compression != 'zlib@openssh.com' or - self.authenticated + compress_in is not None + and ( + self.remote_compression != "zlib@openssh.com" + or self.authenticated ) ): - self._log(DEBUG, 'Switching on inbound compression ...') + self._log(DEBUG, "Switching on inbound compression ...") self.packetizer.set_inbound_compressor(compress_in()) def _activate_outbound(self): @@ -2332,37 +2457,41 @@ class Transport(threading.Thread, ClosingContextManager): m = Message() m.add_byte(cMSG_NEWKEYS) self._send_message(m) - block_size = self._cipher_info[self.local_cipher]['block-size'] + block_size = self._cipher_info[self.local_cipher]["block-size"] if self.server_mode: - IV_out = self._compute_key('B', block_size) + IV_out = self._compute_key("B", block_size) key_out = self._compute_key( - 'D', self._cipher_info[self.local_cipher]['key-size']) + "D", self._cipher_info[self.local_cipher]["key-size"] + ) else: - IV_out = self._compute_key('A', block_size) + IV_out = self._compute_key("A", block_size) key_out = self._compute_key( - 'C', self._cipher_info[self.local_cipher]['key-size']) + "C", self._cipher_info[self.local_cipher]["key-size"] + ) engine = self._get_cipher( - self.local_cipher, key_out, IV_out, self._ENCRYPT) - mac_size = self._mac_info[self.local_mac]['size'] - mac_engine = self._mac_info[self.local_mac]['class'] + self.local_cipher, key_out, IV_out, self._ENCRYPT + ) + mac_size = self._mac_info[self.local_mac]["size"] + mac_engine = self._mac_info[self.local_mac]["class"] # initial mac keys are done in the hash's natural size (not the # potentially truncated transmission size) if self.server_mode: - mac_key = self._compute_key('F', mac_engine().digest_size) + mac_key = self._compute_key("F", mac_engine().digest_size) else: - mac_key = self._compute_key('E', mac_engine().digest_size) - sdctr = self.local_cipher.endswith('-ctr') + mac_key = self._compute_key("E", mac_engine().digest_size) + sdctr = self.local_cipher.endswith("-ctr") self.packetizer.set_outbound_cipher( - engine, block_size, mac_engine, mac_size, mac_key, sdctr) + engine, block_size, mac_engine, mac_size, mac_key, sdctr + ) compress_out = self._compression_info[self.local_compression][0] if ( - compress_out is not None and - ( - self.local_compression != 'zlib@openssh.com' or - self.authenticated + compress_out is not None + and ( + self.local_compression != "zlib@openssh.com" + or self.authenticated ) ): - self._log(DEBUG, 'Switching on outbound compression ...') + self._log(DEBUG, "Switching on outbound compression ...") self.packetizer.set_outbound_compressor(compress_out()) if not self.packetizer.need_rekey(): self.in_kex = False @@ -2372,17 +2501,17 @@ class Transport(threading.Thread, ClosingContextManager): def _auth_trigger(self): self.authenticated = True # delayed initiation of compression - if self.local_compression == 'zlib@openssh.com': + if self.local_compression == "zlib@openssh.com": compress_out = self._compression_info[self.local_compression][0] - self._log(DEBUG, 'Switching on outbound compression ...') + self._log(DEBUG, "Switching on outbound compression ...") self.packetizer.set_outbound_compressor(compress_out()) - if self.remote_compression == 'zlib@openssh.com': + if self.remote_compression == "zlib@openssh.com": compress_in = self._compression_info[self.remote_compression][1] - self._log(DEBUG, 'Switching on inbound compression ...') + self._log(DEBUG, "Switching on inbound compression ...") self.packetizer.set_inbound_compressor(compress_in()) def _parse_newkeys(self, m): - self._log(DEBUG, 'Switch to new keys ...') + self._log(DEBUG, "Switch to new keys ...") self._activate_inbound() # can also free a bunch of stuff here self.local_kex_init = self.remote_kex_init = None @@ -2410,7 +2539,7 @@ class Transport(threading.Thread, ClosingContextManager): def _parse_disconnect(self, m): code = m.get_int() desc = m.get_text() - self._log(INFO, 'Disconnect (code {:d}): {}'.format(code, desc)) + self._log(INFO, "Disconnect (code {:d}): {}".format(code, desc)) def _parse_global_request(self, m): kind = m.get_text() @@ -2419,16 +2548,16 @@ class Transport(threading.Thread, ClosingContextManager): if not self.server_mode: self._log( DEBUG, - 'Rejecting "{}" global request from server.'.format(kind) + 'Rejecting "{}" global request from server.'.format(kind), ) ok = False - elif kind == 'tcpip-forward': + elif kind == "tcpip-forward": address = m.get_text() port = m.get_int() ok = self.server_object.check_port_forward_request(address, port) if ok: ok = (ok,) - elif kind == 'cancel-tcpip-forward': + elif kind == "cancel-tcpip-forward": address = m.get_text() port = m.get_int() self.server_object.cancel_port_forward_request(address, port) @@ -2449,13 +2578,13 @@ class Transport(threading.Thread, ClosingContextManager): self._send_message(msg) def _parse_request_success(self, m): - self._log(DEBUG, 'Global request successful.') + self._log(DEBUG, "Global request successful.") self.global_response = m if self.completion_event is not None: self.completion_event.set() def _parse_request_failure(self, m): - self._log(DEBUG, 'Global request denied.') + self._log(DEBUG, "Global request denied.") self.global_response = None if self.completion_event is not None: self.completion_event.set() @@ -2467,13 +2596,14 @@ class Transport(threading.Thread, ClosingContextManager): server_max_packet_size = m.get_int() chan = self._channels.get(chanid) if chan is None: - self._log(WARNING, 'Success for unrequested channel! [??]') + self._log(WARNING, "Success for unrequested channel! [??]") return self.lock.acquire() try: chan._set_remote_channel( - server_chanid, server_window_size, server_max_packet_size) - self._log(DEBUG, 'Secsh channel {:d} opened.'.format(chanid)) + server_chanid, server_window_size, server_max_packet_size + ) + self._log(DEBUG, "Secsh channel {:d} opened.".format(chanid)) if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] @@ -2486,12 +2616,12 @@ class Transport(threading.Thread, ClosingContextManager): reason = m.get_int() reason_str = m.get_text() m.get_text() # ignored language - reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') + reason_text = CONNECTION_FAILED_CODE.get(reason, "(unknown code)") self._log( ERROR, - 'Secsh channel {:d} open FAILED: {}: {}'.format( - chanid, reason_str, reason_text, - ) + "Secsh channel {:d} open FAILED: {}: {}".format( + chanid, reason_str, reason_text + ), ) self.lock.acquire() try: @@ -2512,39 +2642,39 @@ class Transport(threading.Thread, ClosingContextManager): max_packet_size = m.get_int() reject = False if ( - kind == 'auth-agent@openssh.com' and - self._forward_agent_handler is not None + kind == "auth-agent@openssh.com" + and self._forward_agent_handler is not None ): - self._log(DEBUG, 'Incoming forward agent connection') + self._log(DEBUG, "Incoming forward agent connection") self.lock.acquire() try: my_chanid = self._next_channel() finally: self.lock.release() - elif (kind == 'x11') and (self._x11_handler is not None): + elif (kind == "x11") and (self._x11_handler is not None): origin_addr = m.get_text() origin_port = m.get_int() self._log( DEBUG, - 'Incoming x11 connection from {}:{:d}'.format( - origin_addr, origin_port, - ) + "Incoming x11 connection from {}:{:d}".format( + origin_addr, origin_port + ), ) self.lock.acquire() try: my_chanid = self._next_channel() finally: self.lock.release() - elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): + elif (kind == "forwarded-tcpip") and (self._tcp_handler is not None): server_addr = m.get_text() server_port = m.get_int() origin_addr = m.get_text() origin_port = m.get_int() self._log( DEBUG, - 'Incoming tcp forwarded connection from {}:{:d}'.format( - origin_addr, origin_port, - ) + "Incoming tcp forwarded connection from {}:{:d}".format( + origin_addr, origin_port + ), ) self.lock.acquire() try: @@ -2554,7 +2684,8 @@ class Transport(threading.Thread, ClosingContextManager): elif not self.server_mode: self._log( DEBUG, - 'Rejecting "{}" channel request from server.'.format(kind)) + 'Rejecting "{}" channel request from server.'.format(kind), + ) reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: @@ -2563,7 +2694,7 @@ class Transport(threading.Thread, ClosingContextManager): my_chanid = self._next_channel() finally: self.lock.release() - if kind == 'direct-tcpip': + if kind == "direct-tcpip": # handle direct-tcpip requests coming from the client dest_addr = m.get_text() dest_port = m.get_int() @@ -2572,23 +2703,25 @@ class Transport(threading.Thread, ClosingContextManager): reason = self.server_object.check_channel_direct_tcpip_request( my_chanid, (origin_addr, origin_port), - (dest_addr, dest_port) + (dest_addr, dest_port), ) else: reason = self.server_object.check_channel_request( - kind, my_chanid) + kind, my_chanid + ) if reason != OPEN_SUCCEEDED: self._log( DEBUG, - 'Rejecting "{}" channel request from client.'.format(kind)) + 'Rejecting "{}" channel request from client.'.format(kind), + ) reject = True if reject: msg = Message() msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE) msg.add_int(chanid) msg.add_int(reason) - msg.add_string('') - msg.add_string('en') + msg.add_string("") + msg.add_string("en") self._send_message(msg) return @@ -2599,9 +2732,11 @@ class Transport(threading.Thread, ClosingContextManager): self.channels_seen[my_chanid] = True chan._set_transport(self) chan._set_window( - self.default_window_size, self.default_max_packet_size) + self.default_window_size, self.default_max_packet_size + ) chan._set_remote_channel( - chanid, initial_window_size, max_packet_size) + chanid, initial_window_size, max_packet_size + ) finally: self.lock.release() m = Message() @@ -2611,19 +2746,17 @@ class Transport(threading.Thread, ClosingContextManager): m.add_int(self.default_window_size) m.add_int(self.default_max_packet_size) self._send_message(m) - self._log(DEBUG, - 'Secsh channel {:d} ({}) opened.'.format(my_chanid, kind) + self._log( + DEBUG, "Secsh channel {:d} ({}) opened.".format(my_chanid, kind) ) - if kind == 'auth-agent@openssh.com': + if kind == "auth-agent@openssh.com": self._forward_agent_handler(chan) - elif kind == 'x11': + elif kind == "x11": self._x11_handler(chan, (origin_addr, origin_port)) - elif kind == 'forwarded-tcpip': + elif kind == "forwarded-tcpip": chan.origin_addr = (origin_addr, origin_port) self._tcp_handler( - chan, - (origin_addr, origin_port), - (server_addr, server_port) + chan, (origin_addr, origin_port), (server_addr, server_port) ) else: self._queue_incoming_channel(chan) @@ -2632,7 +2765,7 @@ class Transport(threading.Thread, ClosingContextManager): m.get_boolean() # always_display msg = m.get_string() m.get_string() # language - self._log(DEBUG, 'Debug msg: {}'.format(util.safe_string(msg))) + self._log(DEBUG, "Debug msg: {}".format(util.safe_string(msg))) def _get_subsystem_handler(self, name): try: @@ -2666,7 +2799,7 @@ class Transport(threading.Thread, ClosingContextManager): } -class SecurityOptions (object): +class SecurityOptions(object): """ Simple object containing the security preferences of an ssh transport. These are tuples of acceptable ciphers, digests, key types, and key @@ -2678,7 +2811,7 @@ class SecurityOptions (object): ``ValueError`` will be raised. If you try to assign something besides a tuple to one of the fields, ``TypeError`` will be raised. """ - __slots__ = '_transport' + __slots__ = "_transport" def __init__(self, transport): self._transport = transport @@ -2687,17 +2820,17 @@ class SecurityOptions (object): """ Returns a string representation of this object, for debugging. """ - return '<paramiko.SecurityOptions for {!r}>'.format(self._transport) + return "<paramiko.SecurityOptions for {!r}>".format(self._transport) def _set(self, name, orig, x): if type(x) is list: x = tuple(x) if type(x) is not tuple: - raise TypeError('expected tuple or list') + raise TypeError("expected tuple or list") possible = list(getattr(self._transport, orig).keys()) forbidden = [n for n in x if n not in possible] if len(forbidden) > 0: - raise ValueError('unknown cipher') + raise ValueError("unknown cipher") setattr(self._transport, name, x) @property @@ -2707,7 +2840,7 @@ class SecurityOptions (object): @ciphers.setter def ciphers(self, x): - self._set('_preferred_ciphers', '_cipher_info', x) + self._set("_preferred_ciphers", "_cipher_info", x) @property def digests(self): @@ -2716,7 +2849,7 @@ class SecurityOptions (object): @digests.setter def digests(self, x): - self._set('_preferred_macs', '_mac_info', x) + self._set("_preferred_macs", "_mac_info", x) @property def key_types(self): @@ -2725,8 +2858,7 @@ class SecurityOptions (object): @key_types.setter def key_types(self, x): - self._set('_preferred_keys', '_key_info', x) - + self._set("_preferred_keys", "_key_info", x) @property def kex(self): @@ -2735,7 +2867,7 @@ class SecurityOptions (object): @kex.setter def kex(self, x): - self._set('_preferred_kex', '_kex_info', x) + self._set("_preferred_kex", "_kex_info", x) @property def compression(self): @@ -2744,10 +2876,11 @@ class SecurityOptions (object): @compression.setter def compression(self, x): - self._set('_preferred_compression', '_compression_info', x) + self._set("_preferred_compression", "_compression_info", x) + +class ChannelMap(object): -class ChannelMap (object): def __init__(self): # (id -> Channel) self._map = weakref.WeakValueDictionary() diff --git a/paramiko/util.py b/paramiko/util.py index 2854ef98..c60c040c 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -49,7 +49,7 @@ def inflate_long(s, always_positive=False): # noinspection PyAugmentAssignment s = filler * (4 - len(s) % 4) + s for i in range(0, len(s), 4): - out = (out << 32) + struct.unpack('>I', s[i:i + 4])[0] + out = (out << 32) + struct.unpack(">I", s[i:i + 4])[0] if negative: out -= (long(1) << (8 * len(s))) return out @@ -66,7 +66,7 @@ def deflate_long(n, add_sign_padding=True): s = bytes() n = long(n) while (n != 0) and (n != -1): - s = struct.pack('>I', n & xffffffff) + s + s = struct.pack(">I", n & xffffffff) + s n >>= 32 # strip off leading zeros, FFs for i in enumerate(s): @@ -90,7 +90,7 @@ def deflate_long(n, add_sign_padding=True): return s -def format_binary(data, prefix=''): +def format_binary(data, prefix=""): x = 0 out = [] while len(data) > x + 16: @@ -102,22 +102,21 @@ def format_binary(data, prefix=''): def format_binary_line(data): - left = ' '.join(['{:02X}'.format(byte_ord(c)) for c in data]) - right = ''.join([ - '.{:c}..'.format(byte_ord(c))[(byte_ord(c) + 63) // 95] - for c in data - ]) - return '{:50s} {}'.format(left, right) + left = " ".join(["{:02X}".format(byte_ord(c)) for c in data]) + right = "".join( + [".{:c}..".format(byte_ord(c))[(byte_ord(c) + 63) // 95] for c in data] + ) + return "{:50s} {}".format(left, right) def safe_string(s): - out = b'' + out = b"" for c in s: i = byte_ord(c) if 32 <= i <= 127: out += byte_chr(i) else: - out += b('%{:02X}'.format(i)) + out += b("%{:02X}".format(i)) return out @@ -137,7 +136,7 @@ def bit_length(n): def tb_strings(): - return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') + return "".join(traceback.format_exception(*sys.exc_info())).split("\n") def generate_key_bytes(hash_alg, salt, key, nbytes): @@ -188,6 +187,7 @@ def load_host_keys(filename): nested dict of `.PKey` objects, indexed by hostname and then keytype """ from paramiko.hostkeys import HostKeys + return HostKeys(filename) @@ -249,15 +249,16 @@ def log_to_file(filename, level=DEBUG): if len(l.handlers) > 0: return l.setLevel(level) - f = open(filename, 'a') + f = open(filename, "a") lh = logging.StreamHandler(f) - frm = '%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s' # noqa - lh.setFormatter(logging.Formatter(frm, '%Y%m%d-%H:%M:%S')) + frm = "%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s" # noqa + lh.setFormatter(logging.Formatter(frm, "%Y%m%d-%H:%M:%S")) l.addHandler(lh) # make only one filter object, so it doesn't get applied more than once -class PFilter (object): +class PFilter(object): + def filter(self, record): record._threadid = get_thread_id() return True @@ -293,6 +294,7 @@ def constant_time_bytes_eq(a, b): class ClosingContextManager(object): + def __enter__(self): return self diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py index 661ba575..2bba789d 100644 --- a/paramiko/win_pageant.py +++ b/paramiko/win_pageant.py @@ -44,7 +44,7 @@ win32con_WM_COPYDATA = 74 def _get_pageant_window_object(): - return ctypes.windll.user32.FindWindowA(b'Pageant', b'Pageant') + return ctypes.windll.user32.FindWindowA(b"Pageant", b"Pageant") def can_talk_to_agent(): @@ -57,7 +57,7 @@ def can_talk_to_agent(): return bool(_get_pageant_window_object()) -if platform.architecture()[0] == '64bit': +if platform.architecture()[0] == "64bit": ULONG_PTR = ctypes.c_uint64 else: ULONG_PTR = ctypes.c_uint32 @@ -69,9 +69,9 @@ class COPYDATASTRUCT(ctypes.Structure): http://msdn.microsoft.com/en-us/library/windows/desktop/ms649010%28v=vs.85%29.aspx """ _fields_ = [ - ('num_data', ULONG_PTR), - ('data_size', ctypes.wintypes.DWORD), - ('data_loc', ctypes.c_void_p), + ("num_data", ULONG_PTR), + ("data_size", ctypes.wintypes.DWORD), + ("data_loc", ctypes.c_void_p), ] @@ -86,27 +86,29 @@ def _query_pageant(msg): return None # create a name for the mmap - map_name = 'PageantRequest%08x' % thread.get_ident() + map_name = "PageantRequest%08x" % thread.get_ident() - pymap = _winapi.MemoryMap(map_name, _AGENT_MAX_MSGLEN, - _winapi.get_security_attributes_for_user(), - ) + pymap = _winapi.MemoryMap( + map_name, _AGENT_MAX_MSGLEN, _winapi.get_security_attributes_for_user() + ) with pymap: pymap.write(msg) # Create an array buffer containing the mapped filename char_buffer = array.array("b", b(map_name) + zero_byte) # noqa char_buffer_address, char_buffer_size = char_buffer.buffer_info() # Create a string to use for the SendMessage function call - cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size, - char_buffer_address) + cds = COPYDATASTRUCT( + _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address + ) - response = ctypes.windll.user32.SendMessageA(hwnd, - win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds)) + response = ctypes.windll.user32.SendMessageA( + hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds) + ) if response > 0: pymap.seek(0) datalen = pymap.read(4) - retlen = struct.unpack('>I', datalen)[0] + retlen = struct.unpack(">I", datalen)[0] return datalen + pymap.read(retlen) return None @@ -127,10 +129,10 @@ class PageantConnection(object): def recv(self, n): if self._response is None: - return '' + return "" ret = self._response[:n] self._response = self._response[n:] - if self._response == '': + if self._response == "": self._response = None return ret diff --git a/tests/conftest.py b/tests/conftest.py index d1967a73..2b509c5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ from .util import _support # presenting it on error/failure. (But also allow turning it off when doing # very pinpoint debugging - e.g. using breakpoints, so you don't want output # hiding enabled, but also don't want all the logging to gum up the terminal.) -if not os.environ.get('DISABLE_LOGGING', False): +if not os.environ.get("DISABLE_LOGGING", False): logging.basicConfig( level=logging.DEBUG, # Also make sure to set up timestamping for more sanity when debugging. @@ -43,7 +43,7 @@ def make_sftp_folder(): # TODO: if we want to lock ourselves even harder into localhost-only # testing (probably not?) could use tempdir modules for this for improved # safety. Then again...why would someone have such a folder??? - path = os.environ.get('TEST_FOLDER', 'paramiko-test-target') + path = os.environ.get("TEST_FOLDER", "paramiko-test-target") # Forcibly nuke this directory locally, since at the moment, the below # fixtures only ever run with a locally scoped stub test server. shutil.rmtree(path, ignore_errors=True) @@ -52,7 +52,7 @@ def make_sftp_folder(): return path -@pytest.fixture#(scope='session') +@pytest.fixture # (scope='session') def sftp_server(): """ Set up an in-memory SFTP server thread. Yields the client Transport/socket. @@ -69,17 +69,17 @@ def sftp_server(): tc = Transport(sockc) ts = Transport(socks) # Auth - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) ts.add_server_key(host_key) # Server setup event = threading.Event() server = StubServer() - ts.set_subsystem_handler('sftp', SFTPServer, StubSFTPServer) + ts.set_subsystem_handler("sftp", SFTPServer, StubSFTPServer) ts.start_server(event, server) # Wait (so client has time to connect? Not sure. Old.) event.wait(1.0) # Make & yield connection. - tc.connect(username='slowdive', password='pygmalion') + tc.connect(username="slowdive", password="pygmalion") yield tc # TODO: any need for shutdown? Why didn't old suite do so? Or was that the # point of the "join all threads from threading module" crap in test.py? diff --git a/tests/loop.py b/tests/loop.py index 6c432867..dd1f5a0c 100644 --- a/tests/loop.py +++ b/tests/loop.py @@ -22,13 +22,13 @@ import threading from paramiko.common import asbytes -class LoopSocket (object): +class LoopSocket(object): """ A LoopSocket looks like a normal socket, but all data written to it is delivered on the read-end of another LoopSocket, and vice versa. It's like a software "socketpair". """ - + def __init__(self): self.__in_buffer = bytes() self.__lock = threading.Lock() @@ -84,7 +84,7 @@ class LoopSocket (object): self.__cv.notifyAll() finally: self.__lock.release() - + def __unlink(self): m = None self.__lock.acquire() @@ -96,5 +96,3 @@ class LoopSocket (object): self.__lock.release() if m is not None: m.__unlink() - - diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py index 19545865..170304ab 100644 --- a/tests/stub_sftp.py +++ b/tests/stub_sftp.py @@ -24,13 +24,21 @@ import os import sys from paramiko import ( - ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, - SFTPHandle, SFTP_OK, SFTP_FAILURE, AUTH_SUCCESSFUL, OPEN_SUCCEEDED, + ServerInterface, + SFTPServerInterface, + SFTPServer, + SFTPAttributes, + SFTPHandle, + SFTP_OK, + SFTP_FAILURE, + AUTH_SUCCESSFUL, + OPEN_SUCCEEDED, ) from paramiko.common import o666 -class StubServer (ServerInterface): +class StubServer(ServerInterface): + def check_auth_password(self, username, password): # all are allowed return AUTH_SUCCESSFUL @@ -39,7 +47,8 @@ class StubServer (ServerInterface): return OPEN_SUCCEEDED -class StubSFTPHandle (SFTPHandle): +class StubSFTPHandle(SFTPHandle): + def stat(self): try: return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno())) @@ -56,11 +65,11 @@ class StubSFTPHandle (SFTPHandle): return SFTPServer.convert_errno(e.errno) -class StubSFTPServer (SFTPServerInterface): +class StubSFTPServer(SFTPServerInterface): # assume current folder is a fine root # (the tests always create and eventually delete a subfolder, so there shouldn't be any mess) ROOT = os.getcwd() - + def _realpath(self, path): return self.ROOT + self.canonicalize(path) @@ -70,7 +79,9 @@ class StubSFTPServer (SFTPServerInterface): out = [] flist = os.listdir(path) for fname in flist: - attr = SFTPAttributes.from_stat(os.stat(os.path.join(path, fname))) + attr = SFTPAttributes.from_stat( + os.stat(os.path.join(path, fname)) + ) attr.filename = fname out.append(attr) return out @@ -94,9 +105,9 @@ class StubSFTPServer (SFTPServerInterface): def open(self, path, flags, attr): path = self._realpath(path) try: - binary_flag = getattr(os, 'O_BINARY', 0) + binary_flag = getattr(os, "O_BINARY", 0) flags |= binary_flag - mode = getattr(attr, 'st_mode', None) + mode = getattr(attr, "st_mode", None) if mode is not None: fd = os.open(path, flags, mode) else: @@ -110,17 +121,17 @@ class StubSFTPServer (SFTPServerInterface): SFTPServer.set_file_attr(path, attr) if flags & os.O_WRONLY: if flags & os.O_APPEND: - fstr = 'ab' + fstr = "ab" else: - fstr = 'wb' + fstr = "wb" elif flags & os.O_RDWR: if flags & os.O_APPEND: - fstr = 'a+b' + fstr = "a+b" else: - fstr = 'r+b' + fstr = "r+b" else: # O_RDONLY (== 0) - fstr = 'rb' + fstr = "rb" try: f = os.fdopen(fd, fstr) except OSError as e: @@ -159,7 +170,6 @@ class StubSFTPServer (SFTPServerInterface): return SFTPServer.convert_errno(e.errno) return SFTP_OK - def mkdir(self, path, attr): path = self._realpath(path) try: @@ -188,10 +198,10 @@ class StubSFTPServer (SFTPServerInterface): def symlink(self, target_path, path): path = self._realpath(path) - if (len(target_path) > 0) and (target_path[0] == '/'): + if (len(target_path) > 0) and (target_path[0] == "/"): # absolute symlink target_path = os.path.join(self.ROOT, target_path[1:]) - if target_path[:2] == '//': + if target_path[:2] == "//": # bug in os.path.join target_path = target_path[1:] else: @@ -199,7 +209,7 @@ class StubSFTPServer (SFTPServerInterface): abspath = os.path.join(os.path.dirname(path), target_path) if abspath[:len(self.ROOT)] != self.ROOT: # this symlink isn't going to work anyway -- just break it immediately - target_path = '<error>' + target_path = "<error>" try: os.symlink(target_path, path) except OSError as e: @@ -216,8 +226,8 @@ class StubSFTPServer (SFTPServerInterface): if os.path.isabs(symlink): if symlink[:len(self.ROOT)] == self.ROOT: symlink = symlink[len(self.ROOT):] - if (len(symlink) == 0) or (symlink[0] != '/'): - symlink = '/' + symlink + if (len(symlink) == 0) or (symlink[0] != "/"): + symlink = "/" + symlink else: - symlink = '<error>' + symlink = "<error>" return symlink diff --git a/tests/test_auth.py b/tests/test_auth.py index 4eade610..8ad0ac3b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -26,8 +26,13 @@ import unittest from time import sleep from paramiko import ( - Transport, ServerInterface, RSAKey, DSSKey, BadAuthenticationType, - InteractiveQuery, AuthenticationException, + Transport, + ServerInterface, + RSAKey, + DSSKey, + BadAuthenticationType, + InteractiveQuery, + AuthenticationException, ) from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko.py3compat import u @@ -36,54 +41,57 @@ from .loop import LoopSocket from .util import _support, slow -_pwd = u('\u2022') +_pwd = u("\u2022") -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - if username == 'paranoid': - if not self.paranoid_did_password and not self.paranoid_did_public_key: - return 'publickey,password' + if username == "slowdive": + return "publickey,password" + if username == "paranoid": + if ( + not self.paranoid_did_password + and not self.paranoid_did_public_key + ): + return "publickey,password" elif self.paranoid_did_password: - return 'publickey' + return "publickey" else: - return 'password' - if username == 'commie': - return 'keyboard-interactive' - if username == 'utf8': - return 'password' - if username == 'non-utf8': - return 'password' - return 'publickey' + return "password" + if username == "commie": + return "keyboard-interactive" + if username == "utf8": + return "password" + if username == "non-utf8": + return "password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL - if (username == 'paranoid') and (password == 'paranoid'): + if (username == "paranoid") and (password == "paranoid"): # 2-part auth (even openssh doesn't support this) self.paranoid_did_password = True if self.paranoid_did_public_key: return AUTH_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL - if (username == 'utf8') and (password == _pwd): + if (username == "utf8") and (password == _pwd): return AUTH_SUCCESSFUL - if (username == 'non-utf8') and (password == '\xff'): + if (username == "non-utf8") and (password == "\xff"): return AUTH_SUCCESSFUL - if username == 'bad-server': + if username == "bad-server": raise Exception("Ack!") - if username == 'unresponsive-server': + if username == "unresponsive-server": sleep(5) return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): - if (username == 'paranoid') and (key == self.paranoid_key): + if (username == "paranoid") and (key == self.paranoid_key): # 2-part auth self.paranoid_did_public_key = True if self.paranoid_did_password: @@ -92,19 +100,21 @@ class NullServer (ServerInterface): return AUTH_FAILED def check_auth_interactive(self, username, submethods): - if username == 'commie': + if username == "commie": self.username = username - return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return InteractiveQuery( + "password", "Please enter a password.", ("Password", False) + ) return AUTH_FAILED def check_auth_interactive_response(self, responses): - if self.username == 'commie': - if (len(responses) == 1) and (responses[0] == 'cat'): + if self.username == "commie": + if (len(responses) == 1) and (responses[0] == "cat"): return AUTH_SUCCESSFUL return AUTH_FAILED -class AuthTest (unittest.TestCase): +class AuthTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() @@ -120,7 +130,7 @@ class AuthTest (unittest.TestCase): self.sockc.close() def start_server(self): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) self.event = threading.Event() @@ -140,13 +150,16 @@ class AuthTest (unittest.TestCase): """ self.start_server() try: - self.tc.connect(hostkey=self.public_host_key, - username='unknown', password='error') + self.tc.connect( + hostkey=self.public_host_key, + username="unknown", + password="error", + ) self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertEqual(BadAuthenticationType, etype) - self.assertEqual(['publickey'], evalue.allowed_types) + self.assertEqual(["publickey"], evalue.allowed_types) def test_2_bad_password(self): """ @@ -156,12 +169,12 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - self.tc.auth_password(username='slowdive', password='error') + self.tc.auth_password(username="slowdive", password="error") self.assertTrue(False) except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.tc.auth_password(username='slowdive', password='pygmalion') + self.tc.auth_password(username="slowdive", password="pygmalion") self.verify_finished() def test_3_multipart_auth(self): @@ -170,10 +183,12 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEqual(['publickey'], remain) - key = DSSKey.from_private_key_file(_support('test_dss.key')) - remain = self.tc.auth_publickey(username='paranoid', key=key) + remain = self.tc.auth_password( + username="paranoid", password="paranoid" + ) + self.assertEqual(["publickey"], remain) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + remain = self.tc.auth_publickey(username="paranoid", key=key) self.assertEqual([], remain) self.verify_finished() @@ -188,10 +203,11 @@ class AuthTest (unittest.TestCase): self.got_title = title self.got_instructions = instructions self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEqual(self.got_title, 'password') - self.assertEqual(self.got_prompts, [('Password', False)]) + return ["cat"] + + remain = self.tc.auth_interactive("commie", handler) + self.assertEqual(self.got_title, "password") + self.assertEqual(self.got_prompts, [("Password", False)]) self.assertEqual([], remain) self.verify_finished() @@ -202,7 +218,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('commie', 'cat') + remain = self.tc.auth_password("commie", "cat") self.assertEqual([], remain) self.verify_finished() @@ -212,7 +228,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('utf8', _pwd) + remain = self.tc.auth_password("utf8", _pwd) self.assertEqual([], remain) self.verify_finished() @@ -223,7 +239,7 @@ class AuthTest (unittest.TestCase): """ self.start_server() self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password('non-utf8', '\xff') + remain = self.tc.auth_password("non-utf8", "\xff") self.assertEqual([], remain) self.verify_finished() @@ -235,7 +251,7 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect(hostkey=self.public_host_key) try: - remain = self.tc.auth_password('bad-server', 'hello') + remain = self.tc.auth_password("bad-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) @@ -250,8 +266,8 @@ class AuthTest (unittest.TestCase): self.start_server() self.tc.connect() try: - remain = self.tc.auth_password('unresponsive-server', 'hello') + remain = self.tc.auth_password("unresponsive-server", "hello") except: etype, evalue, etb = sys.exc_info() self.assertTrue(issubclass(etype, AuthenticationException)) - self.assertTrue('Authentication timeout' in str(evalue)) + self.assertTrue("Authentication timeout" in str(evalue)) diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py index 03616c55..9f986a5e 100644 --- a/tests/test_buffered_pipe.py +++ b/tests/test_buffered_pipe.py @@ -30,9 +30,9 @@ from paramiko.py3compat import b def delay_thread(p): - p.feed('a') + p.feed("a") time.sleep(0.5) - p.feed('b') + p.feed("b") p.close() @@ -42,41 +42,42 @@ def close_thread(p): class BufferedPipeTest(unittest.TestCase): + def test_1_buffered_pipe(self): p = BufferedPipe() self.assertTrue(not p.read_ready()) - p.feed('hello.') + p.feed("hello.") self.assertTrue(p.read_ready()) data = p.read(6) - self.assertEqual(b'hello.', data) + self.assertEqual(b"hello.", data) - p.feed('plus/minus') - self.assertEqual(b'plu', p.read(3)) - self.assertEqual(b's/m', p.read(3)) - self.assertEqual(b'inus', p.read(4)) + p.feed("plus/minus") + self.assertEqual(b"plu", p.read(3)) + self.assertEqual(b"s/m", p.read(3)) + self.assertEqual(b"inus", p.read(4)) p.close() self.assertTrue(not p.read_ready()) - self.assertEqual(b'', p.read(1)) + self.assertEqual(b"", p.read(1)) def test_2_delay(self): p = BufferedPipe() self.assertTrue(not p.read_ready()) threading.Thread(target=delay_thread, args=(p,)).start() - self.assertEqual(b'a', p.read(1, 0.1)) + self.assertEqual(b"a", p.read(1, 0.1)) try: p.read(1, 0.1) self.assertTrue(False) except PipeTimeout: pass - self.assertEqual(b'b', p.read(1, 1.0)) - self.assertEqual(b'', p.read(1)) + self.assertEqual(b"b", p.read(1, 1.0)) + self.assertEqual(b"", p.read(1)) def test_3_close_while_reading(self): p = BufferedPipe() threading.Thread(target=close_thread, args=(p,)).start() data = p.read(1, 1.0) - self.assertEqual(b'', data) + self.assertEqual(b"", data) def test_4_or_pipe(self): p = pipe.make_pipe() diff --git a/tests/test_client.py b/tests/test_client.py index 7163fdcf..4a1e9829 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,30 +48,31 @@ requires_gss_auth = unittest.skipUnless( ) FINGERPRINTS = { - 'ssh-dss': b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c', - 'ssh-rsa': b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5', - 'ecdsa-sha2-nistp256': b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60', - 'ssh-ed25519': b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', + "ssh-dss": b"\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c", + "ssh-rsa": b"\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5", + "ecdsa-sha2-nistp256": b"\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60", + "ssh-ed25519": b'\xb3\xd5"\xaa\xf9u^\xe8\xcd\x0e\xea\x02\xb9)\xa2\x80', } class NullServer(paramiko.ServerInterface): + def __init__(self, *args, **kwargs): # Allow tests to enable/disable specific key types - self.__allowed_keys = kwargs.pop('allowed_keys', []) + self.__allowed_keys = kwargs.pop("allowed_keys", []) # And allow them to set a (single...meh) expected public blob (cert) - self.__expected_public_blob = kwargs.pop('public_blob', None) + self.__expected_public_blob = kwargs.pop("public_blob", None) super(NullServer, self).__init__(*args, **kwargs) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return paramiko.AUTH_SUCCESSFUL - if (username == 'slowdive') and (password == 'unresponsive-server'): + if (username == "slowdive") and (password == "unresponsive-server"): time.sleep(5) return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -83,13 +84,13 @@ class NullServer(paramiko.ServerInterface): return paramiko.AUTH_FAILED # Base check: allowed auth type & fingerprint matches happy = ( - key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected + key.get_name() in self.__allowed_keys + and key.get_fingerprint() == expected ) # Secondary check: if test wants assertions about cert data if ( - self.__expected_public_blob is not None and - key.public_blob != self.__expected_public_blob + self.__expected_public_blob is not None + and key.public_blob != self.__expected_public_blob ): happy = False return paramiko.AUTH_SUCCESSFUL if happy else paramiko.AUTH_FAILED @@ -98,31 +99,32 @@ class NullServer(paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True def check_channel_env_request(self, channel, name, value): - if name == 'INVALID_ENV': + if name == "INVALID_ENV": return False - if not hasattr(channel, 'env'): - setattr(channel, 'env', {}) + if not hasattr(channel, "env"): + setattr(channel, "env", {}) channel.env[name] = value return True class ClientTest(unittest.TestCase): + def setUp(self): self.sockl = socket.socket() - self.sockl.bind(('localhost', 0)) + self.sockl.bind(("localhost", 0)) self.sockl.listen(1) self.addr, self.port = self.sockl.getsockname() self.connect_kwargs = dict( hostname=self.addr, port=self.port, - username='slowdive', + username="slowdive", look_for_keys=False, ) self.event = threading.Event() @@ -130,10 +132,10 @@ class ClientTest(unittest.TestCase): def tearDown(self): # Shut down client Transport - if hasattr(self, 'tc'): + if hasattr(self, "tc"): self.tc.close() # Shut down shared socket - if hasattr(self, 'sockl'): + if hasattr(self, "sockl"): # Signal to server thread that it should shut down early; it checks # this immediately after accept(). (In scenarios where connection # actually succeeded during the test, this becomes a no-op.) @@ -151,7 +153,7 @@ class ClientTest(unittest.TestCase): self.sockl.close() def _run( - self, allowed_keys=None, delay=0, public_blob=None, kill_event=None, + self, allowed_keys=None, delay=0, public_blob=None, kill_event=None ): if allowed_keys is None: allowed_keys = FINGERPRINTS.keys() @@ -163,10 +165,10 @@ class ClientTest(unittest.TestCase): self.socks.close() return self.ts = paramiko.Transport(self.socks) - keypath = _support('test_rsa.key') + keypath = _support("test_rsa.key") host_key = paramiko.RSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) - keypath = _support('test_ecdsa_256.key') + keypath = _support("test_ecdsa_256.key") host_key = paramiko.ECDSAKey.from_private_key_file(keypath) self.ts.add_server_key(host_key) server = NullServer(allowed_keys=allowed_keys, public_blob=public_blob) @@ -181,17 +183,21 @@ class ClientTest(unittest.TestCase): The exception is ``allowed_keys`` which is stripped and handed to the ``NullServer`` used for testing. """ - run_kwargs = {'kill_event': self.kill_event} - for key in ('allowed_keys', 'public_blob'): + run_kwargs = {"kill_event": self.kill_event} + for key in ("allowed_keys", "public_blob"): run_kwargs[key] = kwargs.pop(key, None) # Server setup threading.Thread(target=self._run, kwargs=run_kwargs).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) # Client setup self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Actual connection self.tc.connect(**dict(self.connect_kwargs, **kwargs)) @@ -200,22 +206,22 @@ class ClientTest(unittest.TestCase): self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(False, self.tc.get_transport().gss_kex_used) # Command execution functions? - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEqual('Hello there.\n', stdout.readline()) - self.assertEqual('', stdout.readline()) - self.assertEqual('This is on stderr.\n', stderr.readline()) - self.assertEqual('', stderr.readline()) + self.assertEqual("Hello there.\n", stdout.readline()) + self.assertEqual("", stdout.readline()) + self.assertEqual("This is on stderr.\n", stderr.readline()) + self.assertEqual("", stderr.readline()) # Cleanup stdin.close() @@ -224,32 +230,33 @@ class ClientTest(unittest.TestCase): class SSHClientTest(ClientTest): + def test_1_client(self): """ verify that the SSHClient stuff works too. """ - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ - self._test_connection(key_filename=_support('test_dss.key')) + self._test_connection(key_filename=_support("test_dss.key")) def test_client_rsa(self): """ verify that SSHClient works with an RSA key. """ - self._test_connection(key_filename=_support('test_rsa.key')) + self._test_connection(key_filename=_support("test_rsa.key")) def test_2_5_client_ecdsa(self): """ verify that SSHClient works with an ECDSA key. """ - self._test_connection(key_filename=_support('test_ecdsa_256.key')) + self._test_connection(key_filename=_support("test_ecdsa_256.key")) def test_client_ed25519(self): - self._test_connection(key_filename=_support('test_ed25519.key')) + self._test_connection(key_filename=_support("test_ed25519.key")) def test_3_multiple_key_files(self): """ @@ -257,22 +264,20 @@ class SSHClientTest(ClientTest): """ # This is dumb :( types_ = { - 'rsa': 'ssh-rsa', - 'dss': 'ssh-dss', - 'ecdsa': 'ecdsa-sha2-nistp256', + "rsa": "ssh-rsa", "dss": "ssh-dss", "ecdsa": "ecdsa-sha2-nistp256" } # Various combos of attempted & valid keys # TODO: try every possible combo using itertools functions for attempt, accept in ( - (['rsa', 'dss'], ['dss']), # Original test #3 - (['dss', 'rsa'], ['dss']), # Ordering matters sometimes, sadly - (['dss', 'rsa', 'ecdsa_256'], ['dss']), # Try ECDSA but fail - (['rsa', 'ecdsa_256'], ['ecdsa']), # ECDSA success + (["rsa", "dss"], ["dss"]), # Original test #3 + (["dss", "rsa"], ["dss"]), # Ordering matters sometimes, sadly + (["dss", "rsa", "ecdsa_256"], ["dss"]), # Try ECDSA but fail + (["rsa", "ecdsa_256"], ["ecdsa"]), # ECDSA success ): try: self._test_connection( key_filename=[ - _support('test_{}.key'.format(x)) for x in attempt + _support("test_{}.key".format(x)) for x in attempt ], allowed_keys=[types_[x] for x in accept], ) @@ -288,10 +293,11 @@ class SSHClientTest(ClientTest): """ # Until #387 is fixed we have to catch a high-up exception since # various platforms trigger different errors here >_< - self.assertRaises(SSHException, + self.assertRaises( + SSHException, self._test_connection, - key_filename=[_support('test_rsa.key')], - allowed_keys=['ecdsa-sha2-nistp256'], + key_filename=[_support("test_rsa.key")], + allowed_keys=["ecdsa-sha2-nistp256"], ) def test_certs_allowed_as_key_filename_values(self): @@ -299,9 +305,9 @@ class SSHClientTest(ClientTest): # They're similar except for which path is given; the expected auth and # server-side behavior is 100% identical.) # NOTE: only bothered whipping up one cert per overall class/family. - for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): - cert_name = 'test_{}.key-cert.pub'.format(type_) - cert_path = _support(os.path.join('cert_support', cert_name)) + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + cert_name = "test_{}.key-cert.pub".format(type_) + cert_path = _support(os.path.join("cert_support", cert_name)) self._test_connection( key_filename=cert_path, public_blob=PublicBlob.from_file(cert_path), @@ -314,13 +320,13 @@ class SSHClientTest(ClientTest): # about the server-side key object's public blob. Thus, we can prove # that a specific cert was found, along with regular authorization # succeeding proving that the overall flow works. - for type_ in ('rsa', 'dss', 'ecdsa_256', 'ed25519'): - key_name = 'test_{}.key'.format(type_) - key_path = _support(os.path.join('cert_support', key_name)) + for type_ in ("rsa", "dss", "ecdsa_256", "ed25519"): + key_name = "test_{}.key".format(type_) + key_path = _support(os.path.join("cert_support", key_name)) self._test_connection( key_filename=key_path, public_blob=PublicBlob.from_file( - '{}-cert.pub'.format(key_path) + "{}-cert.pub".format(key_path) ), ) @@ -335,19 +341,19 @@ class SSHClientTest(ClientTest): verify that SSHClient's AutoAddPolicy works. """ threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) - key_file = _support('test_ecdsa_256.key') + hostname = "[%s]:%d" % (self.addr, self.port) + key_file = _support("test_ecdsa_256.key") public_host_key = paramiko.ECDSAKey.from_private_key_file(key_file) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.ts.is_authenticated()) self.assertEqual(1, len(self.tc.get_host_keys())) new_host_key = list(self.tc.get_host_keys()[hostname].values())[0] @@ -357,9 +363,11 @@ class SSHClientTest(ClientTest): """ verify that SSHClient correctly saves a known_hosts file. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) fd, localname = mkstemp() os.close(fd) @@ -367,11 +375,13 @@ class SSHClientTest(ClientTest): client = paramiko.SSHClient() self.assertEquals(0, len(client.get_host_keys())) - host_id = '[%s]:%d' % (self.addr, self.port) + host_id = "[%s]:%d" % (self.addr, self.port) - client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key) + client.get_host_keys().add(host_id, "ssh-rsa", public_host_key) self.assertEquals(1, len(client.get_host_keys())) - self.assertEquals(public_host_key, client.get_host_keys()[host_id]['ssh-rsa']) + self.assertEquals( + public_host_key, client.get_host_keys()[host_id]["ssh-rsa"] + ) client.save_host_keys(localname) @@ -394,7 +404,7 @@ class SSHClientTest(ClientTest): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -423,7 +433,7 @@ class SSHClientTest(ClientTest): self.tc = tc self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) - self.tc.connect(**dict(self.connect_kwargs, password='pygmalion')) + self.tc.connect(**dict(self.connect_kwargs, password="pygmalion")) self.event.wait(1.0) self.assertTrue(self.event.is_set()) @@ -438,19 +448,19 @@ class SSHClientTest(ClientTest): verify that the SSHClient has a configurable banner timeout. """ # Start the thread with a 1 second wait. - threading.Thread(target=self._run, kwargs={'delay': 1}).start() - host_key = paramiko.RSAKey.from_private_key_file(_support('test_rsa.key')) + threading.Thread(target=self._run, kwargs={"delay": 1}).start() + host_key = paramiko.RSAKey.from_private_key_file( + _support("test_rsa.key") + ) public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) # Connect with a half second banner timeout. kwargs = dict(self.connect_kwargs, banner_timeout=0.5) - self.assertRaises( - paramiko.SSHException, - self.tc.connect, - **kwargs - ) + self.assertRaises(paramiko.SSHException, self.tc.connect, **kwargs) def test_8_auth_trickledown(self): """ @@ -466,9 +476,9 @@ class SSHClientTest(ClientTest): # 'television' as per tests/test_pkey.py). NOTE: must use # key_filename, loading the actual key here with PKey will except # immediately; we're testing the try/except crap within Client. - key_filename=[_support('test_rsa_password.key')], + key_filename=[_support("test_rsa_password.key")], # Actual password for default 'slowdive' user - password='pygmalion', + password="pygmalion", ) self._test_connection(**kwargs) @@ -481,7 +491,7 @@ class SSHClientTest(ClientTest): self.assertRaises( AuthenticationException, self._test_connection, - password='unresponsive-server', + password="unresponsive-server", auth_timeout=0.5, ) @@ -490,10 +500,7 @@ class SSHClientTest(ClientTest): """ Failed gssapi-keyex auth doesn't prevent subsequent key auth from succeeding """ - kwargs = dict( - gss_kex=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_kex=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) @requires_gss_auth @@ -501,10 +508,7 @@ class SSHClientTest(ClientTest): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ - kwargs = dict( - gss_auth=True, - key_filename=[_support('test_rsa.key')], - ) + kwargs = dict(gss_auth=True, key_filename=[_support("test_rsa.key")]) self._test_connection(**kwargs) def test_12_reject_policy(self): @@ -519,7 +523,8 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', **self.connect_kwargs + password="pygmalion", + **self.connect_kwargs ) @requires_gss_auth @@ -537,14 +542,14 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.SSHException, self.tc.connect, - password='pygmalion', + password="pygmalion", gss_kex=True, - **self.connect_kwargs + **self.connect_kwargs ) def _client_host_key_bad(self, host_key): threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) + hostname = "[%s]:%d" % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) @@ -554,13 +559,13 @@ class SSHClientTest(ClientTest): self.assertRaises( paramiko.BadHostKeyException, self.tc.connect, - password='pygmalion', + password="pygmalion", **self.connect_kwargs ) def _client_host_key_good(self, ktype, kfile): threading.Thread(target=self._run).start() - hostname = '[%s]:%d' % (self.addr, self.port) + hostname = "[%s]:%d" % (self.addr, self.port) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.RejectPolicy()) @@ -568,7 +573,7 @@ class SSHClientTest(ClientTest): known_hosts = self.tc.get_host_keys() known_hosts.add(hostname, host_key.get_name(), host_key) - self.tc.connect(password='pygmalion', **self.connect_kwargs) + self.tc.connect(password="pygmalion", **self.connect_kwargs) self.event.wait(1.0) self.assertTrue(self.event.is_set()) self.assertTrue(self.ts.is_active()) @@ -583,10 +588,10 @@ class SSHClientTest(ClientTest): self._client_host_key_bad(host_key) def test_host_key_negotiation_3(self): - self._client_host_key_good(paramiko.ECDSAKey, 'test_ecdsa_256.key') + self._client_host_key_good(paramiko.ECDSAKey, "test_ecdsa_256.key") def test_host_key_negotiation_4(self): - self._client_host_key_good(paramiko.RSAKey, 'test_rsa.key') + self._client_host_key_good(paramiko.RSAKey, "test_rsa.key") def _setup_for_env(self): threading.Thread(target=self._run).start() @@ -594,7 +599,9 @@ class SSHClientTest(ClientTest): self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEqual(0, len(self.tc.get_host_keys())) - self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + self.tc.connect( + self.addr, self.port, username="slowdive", password="pygmalion" + ) self.event.wait(1.0) self.assertTrue(self.event.isSet()) @@ -605,11 +612,11 @@ class SSHClientTest(ClientTest): Verify that environment variables can be set by the client. """ self._setup_for_env() - target_env = {b'A': b'B', b'C': b'd'} + target_env = {b"A": b"B", b"C": b"d"} - self.tc.exec_command('yes', environment=target_env) + self.tc.exec_command("yes", environment=target_env) schan = self.ts.accept(1.0) - self.assertEqual(target_env, getattr(schan, 'env', {})) + self.assertEqual(target_env, getattr(schan, "env", {})) schan.close() @unittest.skip("Clients normally fail silently, thus so do we, for now") @@ -617,14 +624,14 @@ class SSHClientTest(ClientTest): self._setup_for_env() with self.assertRaises(SSHException) as manager: # Verify that a rejection by the server can be detected - self.tc.exec_command('yes', environment={b'INVALID_ENV': b''}) + self.tc.exec_command("yes", environment={b"INVALID_ENV": b""}) self.assertTrue( - 'INVALID_ENV' in str(manager.exception), - 'Expected variable name in error message' + "INVALID_ENV" in str(manager.exception), + "Expected variable name in error message", ) self.assertTrue( isinstance(manager.exception.args[1], SSHException), - 'Expected original SSHException in exception' + "Expected original SSHException in exception", ) def test_missing_key_policy_accepts_classes_or_instances(self): @@ -652,35 +659,39 @@ class PasswordPassphraseTests(ClientTest): def test_password_kwarg_works_for_password_auth(self): # Straightforward / duplicate of earlier basic password test. - self._test_connection(password='pygmalion') + self._test_connection(password="pygmalion") # TODO: more granular exception pending #387; should be signaling "no auth # methods available" because no key and no password @raises(SSHException) def test_passphrase_kwarg_not_used_for_password_auth(self): # Using the "right" password in the "wrong" field shouldn't work. - self._test_connection(passphrase='pygmalion') + self._test_connection(passphrase="pygmalion") def test_passphrase_kwarg_used_for_key_passphrase(self): # Straightforward again, with new passphrase kwarg. self._test_connection( - key_filename=_support('test_rsa_password.key'), - passphrase='television', + key_filename=_support("test_rsa_password.key"), + passphrase="television", ) - def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given(self): # noqa + def test_password_kwarg_used_for_passphrase_when_no_passphrase_kwarg_given( + self + ): # noqa # Backwards compatibility: passphrase in the password field. self._test_connection( - key_filename=_support('test_rsa_password.key'), - password='television', + key_filename=_support("test_rsa_password.key"), + password="television", ) - @raises(AuthenticationException) # TODO: more granular - def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given(self): # noqa + @raises(AuthenticationException) # TODO: more granular + def test_password_kwarg_not_used_for_passphrase_when_passphrase_kwarg_given( + self + ): # noqa # Sanity: if we're given both fields, the password field is NOT used as # a passphrase. self._test_connection( - key_filename=_support('test_rsa_password.key'), - password='television', - passphrase='wat? lol no', + key_filename=_support("test_rsa_password.key"), + password="television", + passphrase="wat? lol no", ) diff --git a/tests/test_file.py b/tests/test_file.py index 3d2c94e6..d2990118 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -30,18 +30,19 @@ from paramiko.py3compat import BytesIO from .util import needs_builtin -class LoopbackFile (BufferedFile): +class LoopbackFile(BufferedFile): """ BufferedFile object that you can write data into, and then read it back. """ - def __init__(self, mode='r', bufsize=-1): + + def __init__(self, mode="r", bufsize=-1): BufferedFile.__init__(self) self._set_mode(mode, bufsize) self.buffer = BytesIO() self.offset = 0 def _read(self, size): - data = self.buffer.getvalue()[self.offset:self.offset+size] + data = self.buffer.getvalue()[self.offset:self.offset + size] self.offset += len(data) return data @@ -50,44 +51,46 @@ class LoopbackFile (BufferedFile): return len(data) -class BufferedFileTest (unittest.TestCase): +class BufferedFileTest(unittest.TestCase): def test_1_simple(self): - f = LoopbackFile('r') + f = LoopbackFile("r") try: - f.write(b'hi') - self.assertTrue(False, 'no exception on write to read-only file') + f.write(b"hi") + self.assertTrue(False, "no exception on write to read-only file") except: pass f.close() - f = LoopbackFile('w') + f = LoopbackFile("w") try: f.read(1) - self.assertTrue(False, 'no exception to read from write-only file') + self.assertTrue(False, "no exception to read from write-only file") except: pass f.close() def test_2_readline(self): - f = LoopbackFile('r+U') - f.write(b'First line.\nSecond line.\r\nThird line.\n' + - b'Fourth line.\nFinal line non-terminated.') + f = LoopbackFile("r+U") + f.write( + b"First line.\nSecond line.\r\nThird line.\n" + + b"Fourth line.\nFinal line non-terminated." + ) - self.assertEqual(f.readline(), 'First line.\n') + self.assertEqual(f.readline(), "First line.\n") # universal newline mode should convert this linefeed: - self.assertEqual(f.readline(), 'Second line.\n') + self.assertEqual(f.readline(), "Second line.\n") # truncated line: - self.assertEqual(f.readline(7), 'Third l') - self.assertEqual(f.readline(), 'ine.\n') + self.assertEqual(f.readline(7), "Third l") + self.assertEqual(f.readline(), "ine.\n") # newline should be detected and only the fourth line returned - self.assertEqual(f.readline(39), 'Fourth line.\n') - self.assertEqual(f.readline(), 'Final line non-terminated.') - self.assertEqual(f.readline(), '') + self.assertEqual(f.readline(39), "Fourth line.\n") + self.assertEqual(f.readline(), "Final line non-terminated.") + self.assertEqual(f.readline(), "") f.close() try: f.readline() - self.assertTrue(False, 'no exception on readline of closed file') + self.assertTrue(False, "no exception on readline of closed file") except IOError: pass self.assertTrue(linefeed_byte in f.newlines) @@ -98,11 +101,11 @@ class BufferedFileTest (unittest.TestCase): """ try to trick the linefeed detector. """ - f = LoopbackFile('r+U') - f.write(b'First line.\r') - self.assertEqual(f.readline(), 'First line.\n') - f.write(b'\nSecond.\r\n') - self.assertEqual(f.readline(), 'Second.\n') + f = LoopbackFile("r+U") + f.write(b"First line.\r") + self.assertEqual(f.readline(), "First line.\n") + f.write(b"\nSecond.\r\n") + self.assertEqual(f.readline(), "Second.\n") f.close() self.assertEqual(f.newlines, crlf) @@ -110,51 +113,54 @@ class BufferedFileTest (unittest.TestCase): """ verify that write buffering is on. """ - f = LoopbackFile('r+', 1) - f.write(b'Complete line.\nIncomplete line.') - self.assertEqual(f.readline(), 'Complete line.\n') - self.assertEqual(f.readline(), '') - f.write('..\n') - self.assertEqual(f.readline(), 'Incomplete line...\n') + f = LoopbackFile("r+", 1) + f.write(b"Complete line.\nIncomplete line.") + self.assertEqual(f.readline(), "Complete line.\n") + self.assertEqual(f.readline(), "") + f.write("..\n") + self.assertEqual(f.readline(), "Incomplete line...\n") f.close() def test_5_flush(self): """ verify that flush will force a write. """ - f = LoopbackFile('r+', 512) - f.write('Not\nquite\n512 bytes.\n') - self.assertEqual(f.read(1), b'') + f = LoopbackFile("r+", 512) + f.write("Not\nquite\n512 bytes.\n") + self.assertEqual(f.read(1), b"") f.flush() - self.assertEqual(f.read(5), b'Not\nq') - self.assertEqual(f.read(10), b'uite\n512 b') - self.assertEqual(f.read(9), b'ytes.\n') - self.assertEqual(f.read(3), b'') + self.assertEqual(f.read(5), b"Not\nq") + self.assertEqual(f.read(10), b"uite\n512 b") + self.assertEqual(f.read(9), b"ytes.\n") + self.assertEqual(f.read(3), b"") f.close() def test_6_buffering(self): """ verify that flushing happens automatically on buffer crossing. """ - f = LoopbackFile('r+', 16) - f.write(b'Too small.') - self.assertEqual(f.read(4), b'') - f.write(b' ') - self.assertEqual(f.read(4), b'') - f.write(b'Enough.') - self.assertEqual(f.read(20), b'Too small. Enough.') + f = LoopbackFile("r+", 16) + f.write(b"Too small.") + self.assertEqual(f.read(4), b"") + f.write(b" ") + self.assertEqual(f.read(4), b"") + f.write(b"Enough.") + self.assertEqual(f.read(20), b"Too small. Enough.") f.close() def test_7_read_all(self): """ verify that read(-1) returns everything left in the file. """ - f = LoopbackFile('r+', 16) - f.write(b'The first thing you need to do is open your eyes. ') - f.write(b'Then, you need to close them again.\n') + f = LoopbackFile("r+", 16) + f.write(b"The first thing you need to do is open your eyes. ") + f.write(b"Then, you need to close them again.\n") s = f.read(-1) - self.assertEqual(s, b'The first thing you need to do is open your eyes. Then, you ' + - b'need to close them again.\n') + self.assertEqual( + s, + b"The first thing you need to do is open your eyes. Then, you " + + b"need to close them again.\n", + ) f.close() def test_8_buffering(self): @@ -162,19 +168,19 @@ class BufferedFileTest (unittest.TestCase): verify that buffered objects can be written """ if sys.version_info[0] == 2: - f = LoopbackFile('r+', 16) - f.write(buffer(b'Too small.')) + f = LoopbackFile("r+", 16) + f.write(buffer(b"Too small.")) f.close() def test_9_readable(self): - f = LoopbackFile('r') + f = LoopbackFile("r") self.assertTrue(f.readable()) self.assertFalse(f.writable()) self.assertFalse(f.seekable()) f.close() def test_A_writable(self): - f = LoopbackFile('w') + f = LoopbackFile("w") self.assertTrue(f.writable()) self.assertFalse(f.readable()) self.assertFalse(f.seekable()) @@ -182,48 +188,49 @@ class BufferedFileTest (unittest.TestCase): def test_B_readinto(self): data = bytearray(5) - f = LoopbackFile('r+') + f = LoopbackFile("r+") f._write(b"hello") f.readinto(data) - self.assertEqual(data, b'hello') + self.assertEqual(data, b"hello") f.close() def test_write_bad_type(self): - with LoopbackFile('wb') as f: + with LoopbackFile("wb") as f: self.assertRaises(TypeError, f.write, object()) def test_write_unicode_as_binary(self): text = u"\xa7 why is writing text to a binary file allowed?\n" - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: f.write(text) self.assertEqual(f.read(), text.encode("utf-8")) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_bytearray(self): - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: f.write(bytearray(12)) self.assertEqual(f.read(), 12 * b"\0") - @needs_builtin('buffer') + @needs_builtin("buffer") def test_write_buffer(self): data = 3 * b"pretend giant block of data\n" offsets = range(0, len(data), 8) - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: for offset in offsets: f.write(buffer(data, offset, 8)) self.assertEqual(f.read(), data) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_memoryview(self): data = 3 * b"pretend giant block of data\n" offsets = range(0, len(data), 8) - with LoopbackFile('rb+') as f: + with LoopbackFile("rb+") as f: view = memoryview(data) for offset in offsets: - f.write(view[offset:offset+8]) + f.write(view[offset:offset + 8]) self.assertEqual(f.read(), data) -if __name__ == '__main__': +if __name__ == "__main__": from unittest import main + main() diff --git a/tests/test_gssapi.py b/tests/test_gssapi.py index d4b632be..d7fbdd53 100644 --- a/tests/test_gssapi.py +++ b/tests/test_gssapi.py @@ -30,6 +30,7 @@ from .util import needs_gssapi @needs_gssapi class GSSAPITest(unittest.TestCase): + def setup(): # TODO: these vars should all come from os.environ or whatever the # approved pytest method is for runtime-configuring test data. @@ -43,6 +44,7 @@ class GSSAPITest(unittest.TestCase): """ from pyasn1.type.univ import ObjectIdentifier from pyasn1.codec.der import encoder, decoder + oid = encoder.encode(ObjectIdentifier(self.krb5_mech)) mech, __ = decoder.decode(oid) self.assertEquals(self.krb5_mech, mech.__str__()) @@ -57,6 +59,7 @@ class GSSAPITest(unittest.TestCase): except ImportError: import sspicon import sspi + _API = "SSPI" c_token = None @@ -65,23 +68,28 @@ class GSSAPITest(unittest.TestCase): if _API == "MIT": if self.server_mode: - gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_MUTUAL_FLAG, - gssapi.C_DELEG_FLAG) + gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_MUTUAL_FLAG, + gssapi.C_DELEG_FLAG, + ) else: - gss_flags = (gssapi.C_PROT_READY_FLAG, - gssapi.C_INTEG_FLAG, - gssapi.C_DELEG_FLAG) + gss_flags = ( + gssapi.C_PROT_READY_FLAG, + gssapi.C_INTEG_FLAG, + gssapi.C_DELEG_FLAG, + ) # Initialize a GSS-API context. ctx = gssapi.Context() ctx.flags = gss_flags krb5_oid = gssapi.OID.mech_from_string(self.krb5_mech) - target_name = gssapi.Name("host@" + self.targ_name, - gssapi.C_NT_HOSTBASED_SERVICE) - gss_ctxt = gssapi.InitContext(peer_name=target_name, - mech_type=krb5_oid, - req_flags=ctx.flags) + target_name = gssapi.Name( + "host@" + self.targ_name, gssapi.C_NT_HOSTBASED_SERVICE + ) + gss_ctxt = gssapi.InitContext( + peer_name=target_name, mech_type=krb5_oid, req_flags=ctx.flags + ) if self.server_mode: c_token = gss_ctxt.step(c_token) gss_ctxt_status = gss_ctxt.established @@ -108,15 +116,15 @@ class GSSAPITest(unittest.TestCase): self.assertEquals(0, status) else: gss_flags = ( - sspicon.ISC_REQ_INTEGRITY | - sspicon.ISC_REQ_MUTUAL_AUTH | - sspicon.ISC_REQ_DELEGATE + sspicon.ISC_REQ_INTEGRITY + | sspicon.ISC_REQ_MUTUAL_AUTH + | sspicon.ISC_REQ_DELEGATE ) # Initialize a GSS-API context. target_name = "host/" + socket.getfqdn(self.targ_name) - gss_ctxt = sspi.ClientAuth("Kerberos", - scflags=gss_flags, - targetspn=target_name) + gss_ctxt = sspi.ClientAuth( + "Kerberos", scflags=gss_flags, targetspn=target_name + ) if self.server_mode: error, token = gss_ctxt.authorize(c_token) c_token = token[0].Buffer diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index cd75f8ab..a1b7a9e0 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -54,77 +54,80 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\ 0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""" -class HostKeysTest (unittest.TestCase): +class HostKeysTest(unittest.TestCase): def setUp(self): - with open('hostfile.temp', 'w') as f: + with open("hostfile.temp", "w") as f: f.write(test_hosts_file) def tearDown(self): - os.unlink('hostfile.temp') + os.unlink("hostfile.temp") def test_1_load(self): - hostdict = paramiko.HostKeys('hostfile.temp') + hostdict = paramiko.HostKeys("hostfile.temp") self.assertEqual(2, len(hostdict)) self.assertEqual(1, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) def test_2_add(self): - hostdict = paramiko.HostKeys('hostfile.temp') - hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=' + hostdict = paramiko.HostKeys("hostfile.temp") + hh = "|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=" key = paramiko.RSAKey(data=decodebytes(keyblob)) - hostdict.add(hh, 'ssh-rsa', key) + hostdict.add(hh, "ssh-rsa", key) self.assertEqual(3, len(list(hostdict))) - x = hostdict['foo.example.com'] - fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp) - self.assertTrue(hostdict.check('foo.example.com', key)) + x = hostdict["foo.example.com"] + fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper() + self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp) + self.assertTrue(hostdict.check("foo.example.com", key)) def test_3_dict(self): - hostdict = paramiko.HostKeys('hostfile.temp') - self.assertTrue('secure.example.com' in hostdict) - self.assertTrue('not.example.com' not in hostdict) - self.assertTrue('secure.example.com' in hostdict) - self.assertTrue('not.example.com' not in hostdict) - x = hostdict.get('secure.example.com', None) + hostdict = paramiko.HostKeys("hostfile.temp") + self.assertTrue("secure.example.com" in hostdict) + self.assertTrue("not.example.com" not in hostdict) + self.assertTrue("secure.example.com" in hostdict) + self.assertTrue("not.example.com" not in hostdict) + x = hostdict.get("secure.example.com", None) self.assertTrue(x is not None) - fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify(x["ssh-rsa"].get_fingerprint()).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) i = 0 for key in hostdict: i += 1 self.assertEqual(2, i) - + def test_4_dict_set(self): - hostdict = paramiko.HostKeys('hostfile.temp') + hostdict = paramiko.HostKeys("hostfile.temp") key = paramiko.RSAKey(data=decodebytes(keyblob)) key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss)) - hostdict['secure.example.com'] = { - 'ssh-rsa': key, - 'ssh-dss': key_dss - } - hostdict['fake.example.com'] = {} - hostdict['fake.example.com']['ssh-rsa'] = key - + hostdict["secure.example.com"] = {"ssh-rsa": key, "ssh-dss": key_dss} + hostdict["fake.example.com"] = {} + hostdict["fake.example.com"]["ssh-rsa"] = key + self.assertEqual(3, len(hostdict)) self.assertEqual(2, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) self.assertEqual(1, len(list(hostdict.values())[2])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp) - fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() - self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"7EC91BB336CB6D810B124B1353C32396", fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-dss"].get_fingerprint() + ).upper() + self.assertEqual(b"4478F0B9A23CC5182009FF755BC1D26C", fp) def test_delitem(self): - hostdict = paramiko.HostKeys('hostfile.temp') - target = 'happy.example.com' - entry = hostdict[target] # will KeyError if not present + hostdict = paramiko.HostKeys("hostfile.temp") + target = "happy.example.com" + entry = hostdict[target] # will KeyError if not present del hostdict[target] try: entry = hostdict[target] except KeyError: - pass # Good + pass # Good else: assert False, "Entry was not deleted from HostKeys on delitem!" diff --git a/tests/test_kex.py b/tests/test_kex.py index b5808e7e..41e2dea2 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -38,29 +38,45 @@ from paramiko.kex_ecdh_nist import KexNistp256 def dummy_urandom(n): return byte_chr(0xcc) * n + def dummy_generate_key_pair(obj): private_key_value = 94761803665136558137557783047955027733968423115106677159790289642479432803037 public_key_numbers = "042bdab212fa8ba1b7c843301682a4db424d307246c7e1e6083c41d9ca7b098bf30b3d63e2ec6278488c135360456cc054b3444ecc45998c08894cbc1370f5f989" - public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)) - obj.P = ec.EllipticCurvePrivateNumbers(private_value=private_key_value, public_numbers=public_key_numbers_obj).private_key(default_backend()) + public_key_numbers_obj = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ) + obj.P = ec.EllipticCurvePrivateNumbers( + private_value=private_key_value, public_numbers=public_key_numbers_obj + ).private_key( + default_backend() + ) if obj.transport.server_mode: - obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + obj.Q_S = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ).public_key( + default_backend() + ) return - obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point(ec.SECP256R1(), unhexlify(public_key_numbers)).public_key(default_backend()) + obj.Q_C = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), unhexlify(public_key_numbers) + ).public_key( + default_backend() + ) + +class FakeKey(object): -class FakeKey (object): def __str__(self): - return 'fake-key' + return "fake-key" def asbytes(self): - return b'fake-key' + return b"fake-key" def sign_ssh_data(self, H): - return b'fake-sig' + return b"fake-sig" -class FakeModulusPack (object): +class FakeModulusPack(object): P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF G = 2 @@ -69,10 +85,10 @@ class FakeModulusPack (object): class FakeTransport(object): - local_version = 'SSH-2.0-paramiko_1.0' - remote_version = 'SSH-2.0-lame' - local_kex_init = 'local-kex-init' - remote_kex_init = 'remote-kex-init' + local_version = "SSH-2.0-paramiko_1.0" + remote_version = "SSH-2.0-lame" + local_kex_init = "local-kex-init" + remote_kex_init = "remote-kex-init" def _send_message(self, m): self._message = m @@ -100,7 +116,7 @@ class FakeTransport(object): return FakeModulusPack() -class KexTest (unittest.TestCase): +class KexTest(unittest.TestCase): K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504 @@ -119,21 +135,23 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGroup1(transport) kex.start_kex() - x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = b"1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect + ) # fake "reply" msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) - H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2' + H = b"03079780F3D3AD0B3C6DB30C8D21685F367A86D2" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_2_group1_server(self): @@ -141,14 +159,16 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGroup1(transport) kex.start_kex() - self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect + ) msg = Message() msg.add_mpint(69) msg.rewind() kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) - H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' - x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"B16BF34DD10945EDE84E9C1EF24A14BFDC843389" + x = b"1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -159,29 +179,33 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGex(transport) kex.start_kex() - x = b'22000004000000080000002000' + x = b"22000004000000080000002000" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' + H = b"A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_4_gex_old_client(self): @@ -189,37 +213,47 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGex(transport) kex.start_kex(_test_old_style=True) - x = b'1E00000800' + x = b"1E00000800" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'807F87B269EF7AC5EC7E75676808776A27D5864C' + H = b"807F87B269EF7AC5EC7E75676808776A27D5864C" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) - + def test_5_gex_server(self): transport = FakeTransport() transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(1024) @@ -227,17 +261,19 @@ class KexTest (unittest.TestCase): msg.add_int(4096) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'CE754197C21BF3452863B4F44D0B3951F12516EF' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"CE754197C21BF3452863B4F44D0B3951F12516EF" + x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -248,23 +284,31 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(2048) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B" + x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -275,29 +319,33 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGexSHA256(transport) kex.start_kex() - x = b'22000004000000080000002000' + x = b"22000004000000080000002000" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4' + H = b"AD1A9365A67B4496F05594AD1BF656E3CDA0851289A4C1AFF549FEAE50896DF4" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_8_gex_sha256_old_client(self): @@ -305,29 +353,33 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexGexSHA256(transport) kex.start_kex(_test_old_style=True) - x = b'1E00000800' + x = b"1E00000800" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect + ) msg = Message() msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.G) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) - x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + x = b"20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect + ) msg = Message() - msg.add_string('fake-host-key') + msg.add_string("fake-host-key") msg.add_mpint(69) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) - H = b'518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5' + H = b"518386608B15891AE5237DEE08DCADDE76A0BCEFCE7F6DB3AD66BC41D256DFE5" self.assertEqual(self.K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_9_gex_sha256_server(self): @@ -335,7 +387,13 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGexSHA256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(1024) @@ -343,17 +401,19 @@ class KexTest (unittest.TestCase): msg.add_int(4096) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"CCAC0497CF0ABA1DBF55E1A3995D17F4CC31824B0E8D95CDF8A06F169D050D80" + x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -364,23 +424,31 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGexSHA256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + self.assertEqual( + ( + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, + paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, + ), + transport._expect, + ) msg = Message() msg.add_int(2048) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) - x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + x = b"1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102" self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) - self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect + ) msg = Message() msg.add_mpint(12345) msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 - H = b'3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB' - x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + H = b"3DDD2AD840AD095E397BA4D0573972DC60F6461FD38A187CACA6615A5BC8ADBB" + x = b"210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967" self.assertEqual(K, transport._K) self.assertEqual(H, hexlify(transport._H).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper()) @@ -392,20 +460,24 @@ class KexTest (unittest.TestCase): transport.server_mode = False kex = KexNistp256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect) + self.assertEqual( + (paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY,), transport._expect + ) - #fake reply + # fake reply msg = Message() - msg.add_string('fake-host-key') - Q_S = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") + msg.add_string("fake-host-key") + Q_S = unhexlify( + "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210" + ) msg.add_string(Q_S) - msg.add_string('fake-sig') + msg.add_string("fake-sig") msg.rewind() kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_REPLY, msg) - H = b'BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A' + H = b"BAF7CE243A836037EB5D2221420F35C02B9AB6C957FE3BDE3369307B9612570A" self.assertEqual(K, kex.transport._K) self.assertEqual(H, hexlify(transport._H).upper()) - self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify) + self.assertEqual((b"fake-host-key", b"fake-sig"), transport._verify) self.assertTrue(transport._activated) def test_12_kex_nistp256_server(self): @@ -414,12 +486,16 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexNistp256(transport) kex.start_kex() - self.assertEqual((paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect) + self.assertEqual( + (paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT,), transport._expect + ) - #fake init - msg=Message() - Q_C = unhexlify("043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210") - H = b'2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA' + # fake init + msg = Message() + Q_C = unhexlify( + "043ae159594ba062efa121480e9ef136203fa9ec6b6e1f8723a321c16e62b945f573f3b822258cbcd094b9fa1c125cbfe5f043280893e66863cc0cb4dccbe70210" + ) + H = b"2EF4957AFD530DD3F05DBEABF68D724FACC060974DA9704F2AEE4C3DE861E7CA" msg.add_string(Q_C) msg.rewind() kex.parse_next(paramiko.kex_ecdh_nist._MSG_KEXECDH_INIT, msg) diff --git a/tests/test_kex_gss.py b/tests/test_kex_gss.py index 025d1faa..afddee08 100644 --- a/tests/test_kex_gss.py +++ b/tests/test_kex_gss.py @@ -34,14 +34,14 @@ import paramiko from .util import needs_gssapi -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): def get_allowed_auths(self, username): - return 'gssapi-keyex' + return "gssapi-keyex" - def check_auth_gssapi_keyex(self, username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None): + def check_auth_gssapi_keyex( + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None + ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED @@ -54,13 +54,14 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != "yes": return False return True @needs_gssapi class GSSKexTest(unittest.TestCase): + @staticmethod def init(username, hostname): global krb5_principal, targ_name @@ -86,13 +87,13 @@ class GSSKexTest(unittest.TestCase): def _run(self): self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks, gss_kex=True) - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") self.ts.add_server_key(host_key) self.ts.set_gss_host(targ_name) try: self.ts.load_server_moduli() except: - print ('(Failed to load moduli -- gex will be unsupported.)') + print("(Failed to load moduli -- gex will be unsupported.)") server = NullServer() self.ts.start_server(self.event, server) @@ -102,14 +103,21 @@ class GSSKexTest(unittest.TestCase): Diffie-Hellman Key Exchange and user authentication with the GSS-API context created during key exchange. """ - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add('[%s]:%d' % (self.hostname, self.port), - 'ssh-rsa', public_host_key) - self.tc.connect(self.hostname, self.port, username=self.username, - gss_auth=True, gss_kex=True, gss_host=gss_host) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.hostname, self.port), "ssh-rsa", public_host_key + ) + self.tc.connect( + self.hostname, + self.port, + username=self.username, + gss_auth=True, + gss_kex=True, + gss_host=gss_host, + ) self.event.wait(1.0) self.assert_(self.event.is_set()) @@ -118,19 +126,19 @@ class GSSKexTest(unittest.TestCase): self.assertEquals(True, self.ts.is_authenticated()) self.assertEquals(True, self.tc.get_transport().gss_kex_used) - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) if rekey: self.tc.get_transport().renegotiate_keys() - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEquals('Hello there.\n', stdout.readline()) - self.assertEquals('', stdout.readline()) - self.assertEquals('This is on stderr.\n', stderr.readline()) - self.assertEquals('', stderr.readline()) + self.assertEquals("Hello there.\n", stdout.readline()) + self.assertEquals("", stdout.readline()) + self.assertEquals("This is on stderr.\n", stderr.readline()) + self.assertEquals("", stderr.readline()) stdin.close() stdout.close() diff --git a/tests/test_message.py b/tests/test_message.py index 645b0509..e6b80f3b 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -26,20 +26,20 @@ from paramiko.message import Message from paramiko.common import byte_chr, zero_byte -class MessageTest (unittest.TestCase): +class MessageTest(unittest.TestCase): - __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000 - __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65' - __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7' - __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62' + __a = b"\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8" + b"x" * 1000 + __b = b"\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65" + __c = b"\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7" + __d = b"\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62" def test_1_encode(self): msg = Message() msg.add_int(23) msg.add_int(123789456) - msg.add_string('q') - msg.add_string('hello') - msg.add_string('x' * 1000) + msg.add_string("q") + msg.add_string("hello") + msg.add_string("x" * 1000) self.assertEqual(msg.asbytes(), self.__a) msg = Message() @@ -48,7 +48,7 @@ class MessageTest (unittest.TestCase): msg.add_byte(byte_chr(0xf3)) msg.add_bytes(zero_byte + byte_chr(0x3f)) - msg.add_list(['huey', 'dewey', 'louie']) + msg.add_list(["huey", "dewey", "louie"]) self.assertEqual(msg.asbytes(), self.__b) msg = Message() @@ -63,16 +63,16 @@ class MessageTest (unittest.TestCase): msg = Message(self.__a) self.assertEqual(msg.get_int(), 23) self.assertEqual(msg.get_int(), 123789456) - self.assertEqual(msg.get_text(), 'q') - self.assertEqual(msg.get_text(), 'hello') - self.assertEqual(msg.get_text(), 'x' * 1000) + self.assertEqual(msg.get_text(), "q") + self.assertEqual(msg.get_text(), "hello") + self.assertEqual(msg.get_text(), "x" * 1000) msg = Message(self.__b) self.assertEqual(msg.get_boolean(), True) self.assertEqual(msg.get_boolean(), False) self.assertEqual(msg.get_byte(), byte_chr(0xf3)) self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f)) - self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie']) + self.assertEqual(msg.get_list(), ["huey", "dewey", "louie"]) msg = Message(self.__c) self.assertEqual(msg.get_int64(), 5) @@ -87,8 +87,8 @@ class MessageTest (unittest.TestCase): msg.add(0x1122334455) msg.add(0xf00000000000000000) msg.add(True) - msg.add('cat') - msg.add(['a', 'b']) + msg.add("cat") + msg.add(["a", "b"]) self.assertEqual(msg.asbytes(), self.__d) def test_4_misc(self): diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py index 414b7e38..dbe5993e 100644 --- a/tests/test_packetizer.py +++ b/tests/test_packetizer.py @@ -36,19 +36,20 @@ from .loop import LoopSocket x55 = byte_chr(0x55) x1f = byte_chr(0x1f) -class PacketizerTest (unittest.TestCase): + +class PacketizerTest(unittest.TestCase): def test_1_write(self): rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(wsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) encryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).encryptor() p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) @@ -63,22 +64,27 @@ class PacketizerTest (unittest.TestCase): data = rsock.recv(100) # 32 + 12 bytes of MAC = 44 self.assertEqual(44, len(data)) - self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16]) + self.assertEqual( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0", + data[:16], + ) def test_2_read(self): rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(rsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) decryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).decryptor() p.set_inbound_cipher(decryptor, 16, sha1, 12, x1f * 20) - wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59') + wsock.send( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59" + ) cmd, m = p.read_message() self.assertEqual(100, cmd) self.assertEqual(100, m.get_int()) @@ -86,18 +92,18 @@ class PacketizerTest (unittest.TestCase): self.assertEqual(900, m.get_int()) def test_3_closed(self): - if sys.platform.startswith("win"): # no SIGALRM on windows + if sys.platform.startswith("win"): # no SIGALRM on windows return rsock = LoopSocket() wsock = LoopSocket() rsock.link(wsock) p = Packetizer(wsock) - p.set_log(util.get_logger('paramiko.transport')) + p.set_log(util.get_logger("paramiko.transport")) p.set_hexdump(True) encryptor = Cipher( algorithms.AES(zero_byte * 16), modes.CBC(x55 * 16), - backend=default_backend() + backend=default_backend(), ).encryptor() p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) @@ -115,14 +121,17 @@ class PacketizerTest (unittest.TestCase): import signal class TimeoutError(Exception): + def __init__(self, error_message): - if hasattr(errno, 'ETIME'): + if hasattr(errno, "ETIME"): self.message = os.sterror(errno.ETIME) else: self.messaage = error_message - def timeout(seconds=1, error_message='Timer expired'): + def timeout(seconds=1, error_message="Timer expired"): + def decorator(func): + def _handle_timeout(signum, frame): raise TimeoutError(error_message) @@ -138,5 +147,6 @@ class PacketizerTest (unittest.TestCase): return wraps(func)(wrapper) return decorator + send = timeout()(p.send_message) self.assertRaises(EOFError, send, m) diff --git a/tests/test_pkey.py b/tests/test_pkey.py index 1827d2a9..66bebc43 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -34,18 +34,18 @@ from .util import _support # from openssh's ssh-keygen -PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=' -PUB_DSS = 'ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=' -PUB_ECDSA_256 = 'ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo=' -PUB_ECDSA_384 = 'ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A==' -PUB_ECDSA_521 = 'ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==' - -FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5' -FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c' -FINGER_ECDSA_256 = '256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60' -FINGER_ECDSA_384 = '384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65' -FINGER_ECDSA_521 = '521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e' -SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8' +PUB_RSA = "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=" +PUB_DSS = "ssh-dss AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF608EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=" +PUB_ECDSA_256 = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJSPZm3ZWkvk/Zx8WP+fZRZ5/NBBHnGQwR6uIC6XHGPDIHuWUzIjAwA0bzqkOUffEsbLe+uQgKl5kbc/L8KA/eo=" +PUB_ECDSA_384 = "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBBbGibQLW9AAZiGN2hEQxWYYoFaWKwN3PKSaDJSMqmIn1Z9sgRUuw8Y/w502OGvXL/wFk0i2z50l3pWZjD7gfMH7gX5TUiCzwrQkS+Hn1U2S9aF5WJp0NcIzYxXw2r4M2A==" +PUB_ECDSA_521 = "ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACFBACaOaFLZGuxa5AW16qj6VLypFbLrEWrt9AZUloCMefxO8bNLjK/O5g0rAVasar1TnyHE9qj4NwzANZASWjQNbc4MAG8vzqezFwLIn/kNyNTsXNfqEko9OgHZknlj2Z79dwTJcRAL4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA==" + +FINGER_RSA = "1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5" +FINGER_DSS = "1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c" +FINGER_ECDSA_256 = "256 25:19:eb:55:e6:a1:47:ff:4f:38:d2:75:6f:a5:d5:60" +FINGER_ECDSA_384 = "384 c1:8d:a0:59:09:47:41:8e:a8:a6:07:01:29:23:b4:65" +FINGER_ECDSA_521 = "521 44:58:22:52:12:33:16:0e:ce:0e:be:2c:7c:7e:cc:1e" +SIGNED_RSA = "20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8" RSA_PRIVATE_OUT = """\ -----BEGIN RSA PRIVATE KEY----- @@ -107,10 +107,10 @@ L4QLcT5aND0EHZLB2fAUDXiWIb2j4rg1mwPlBMiBXA== -----END EC PRIVATE KEY----- """ -x1234 = b'\x01\x02\x03\x04' +x1234 = b"\x01\x02\x03\x04" -TEST_KEY_BYTESTR_2 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87' -TEST_KEY_BYTESTR_3 = '\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f' +TEST_KEY_BYTESTR_2 = "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x81\x00\xd3\x8fV\xea\x07\x85\xa6k%\x8d<\x1f\xbc\x8dT\x98\xa5\x96$\xf3E#\xbe>\xbc\xd2\x93\x93\x87f\xceD\x18\xdb \x0c\xb3\xa1a\x96\xf8e#\xcc\xacS\x8a#\xefVlE\x83\x1epv\xc1o\x17M\xef\xdf\x89DUXL\xa6\x8b\xaa<\x06\x10\xd7\x93w\xec\xaf\xe2\xaf\x95\xd8\xfb\xd9\xbfw\xcb\x9f0)#y{\x10\x90\xaa\x85l\tPru\x8c\t\x19\xce\xa0\xf1\xd2\xdc\x8e/\x8b\xa8f\x9c0\xdey\x84\xd2F\xf7\xcbmm\x1f\x87" +TEST_KEY_BYTESTR_3 = "\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00\x00ӏV\x07k%<\x1fT$E#>ғfD\x18 \x0cae#̬S#VlE\x1epvo\x17M߉DUXL<\x06\x10דw\u2bd5ٿw˟0)#y{\x10l\tPru\t\x19Π\u070e/f0yFmm\x1f" class KeyTest(unittest.TestCase): @@ -127,21 +127,20 @@ class KeyTest(unittest.TestCase): """ with open(keyfile, "r") as fh: self.assertEqual( - fh.readline()[:-1], - "-----BEGIN RSA PRIVATE KEY-----" + fh.readline()[:-1], "-----BEGIN RSA PRIVATE KEY-----" ) self.assertEqual(fh.readline()[:-1], "Proc-Type: 4,ENCRYPTED") self.assertEqual(fh.readline()[0:10], "DEK-Info: ") def test_1_generate_key_bytes(self): - key = util.generate_key_bytes(md5, x1234, 'happy birthday', 30) - exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64' + key = util.generate_key_bytes(md5, x1234, "happy birthday", 30) + exp = b"\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64" self.assertEqual(exp, key) def test_2_load_rsa(self): - key = RSAKey.from_private_key_file(_support('test_rsa.key')) - self.assertEqual('ssh-rsa', key.get_name()) - exp_rsa = b(FINGER_RSA.split()[1].replace(':', '')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) + self.assertEqual("ssh-rsa", key.get_name()) + exp_rsa = b(FINGER_RSA.split()[1].replace(":", "")) my_rsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_rsa, my_rsa) self.assertEqual(PUB_RSA.split()[1], key.get_base64()) @@ -155,18 +154,20 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_3_load_rsa_password(self): - key = RSAKey.from_private_key_file(_support('test_rsa_password.key'), 'television') - self.assertEqual('ssh-rsa', key.get_name()) - exp_rsa = b(FINGER_RSA.split()[1].replace(':', '')) + key = RSAKey.from_private_key_file( + _support("test_rsa_password.key"), "television" + ) + self.assertEqual("ssh-rsa", key.get_name()) + exp_rsa = b(FINGER_RSA.split()[1].replace(":", "")) my_rsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_rsa, my_rsa) self.assertEqual(PUB_RSA.split()[1], key.get_base64()) self.assertEqual(1024, key.get_bits()) def test_4_load_dss(self): - key = DSSKey.from_private_key_file(_support('test_dss.key')) - self.assertEqual('ssh-dss', key.get_name()) - exp_dss = b(FINGER_DSS.split()[1].replace(':', '')) + key = DSSKey.from_private_key_file(_support("test_dss.key")) + self.assertEqual("ssh-dss", key.get_name()) + exp_dss = b(FINGER_DSS.split()[1].replace(":", "")) my_dss = hexlify(key.get_fingerprint()) self.assertEqual(exp_dss, my_dss) self.assertEqual(PUB_DSS.split()[1], key.get_base64()) @@ -180,9 +181,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_5_load_dss_password(self): - key = DSSKey.from_private_key_file(_support('test_dss_password.key'), 'television') - self.assertEqual('ssh-dss', key.get_name()) - exp_dss = b(FINGER_DSS.split()[1].replace(':', '')) + key = DSSKey.from_private_key_file( + _support("test_dss_password.key"), "television" + ) + self.assertEqual("ssh-dss", key.get_name()) + exp_dss = b(FINGER_DSS.split()[1].replace(":", "")) my_dss = hexlify(key.get_fingerprint()) self.assertEqual(exp_dss, my_dss) self.assertEqual(PUB_DSS.split()[1], key.get_base64()) @@ -190,7 +193,7 @@ class KeyTest(unittest.TestCase): def test_6_compare_rsa(self): # verify that the private & public keys compare equal - key = RSAKey.from_private_key_file(_support('test_rsa.key')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) self.assertEqual(key, key) pub = RSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -199,7 +202,7 @@ class KeyTest(unittest.TestCase): def test_7_compare_dss(self): # verify that the private & public keys compare equal - key = DSSKey.from_private_key_file(_support('test_dss.key')) + key = DSSKey.from_private_key_file(_support("test_dss.key")) self.assertEqual(key, key) pub = DSSKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -208,77 +211,79 @@ class KeyTest(unittest.TestCase): def test_8_sign_rsa(self): # verify that the rsa private key can sign and verify - key = RSAKey.from_private_key_file(_support('test_rsa.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = RSAKey.from_private_key_file(_support("test_rsa.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ssh-rsa', msg.get_text()) - sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) + self.assertEqual("ssh-rsa", msg.get_text()) + sig = bytes().join( + [byte_chr(int(x, 16)) for x in SIGNED_RSA.split(":")] + ) self.assertEqual(sig, msg.get_binary()) msg.rewind() pub = RSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_9_sign_dss(self): # verify that the dss private key can sign and verify - key = DSSKey.from_private_key_file(_support('test_dss.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = DSSKey.from_private_key_file(_support("test_dss.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ssh-dss', msg.get_text()) + self.assertEqual("ssh-dss", msg.get_text()) # can't do the same test as we do for RSA, because DSS signatures # are usually different each time. but we can test verification # anyway so it's ok. self.assertEqual(40, len(msg.get_binary())) msg.rewind() pub = DSSKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_A_generate_rsa(self): key = RSAKey.generate(1024) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) def test_B_generate_dss(self): key = DSSKey.generate(1024) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) def test_C_generate_ecdsa(self): key = ECDSAKey.generate() - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 256) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256") key = ECDSAKey.generate(bits=256) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 256) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp256') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp256") key = ECDSAKey.generate(bits=384) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 384) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp384') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp384") key = ECDSAKey.generate(bits=521) - msg = key.sign_ssh_data(b'jerri blank') + msg = key.sign_ssh_data(b"jerri blank") msg.rewind() - self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg)) + self.assertTrue(key.verify_ssh_sig(b"jerri blank", msg)) self.assertEqual(key.get_bits(), 521) - self.assertEqual(key.get_name(), 'ecdsa-sha2-nistp521') + self.assertEqual(key.get_name(), "ecdsa-sha2-nistp521") def test_10_load_ecdsa_256(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) - self.assertEqual('ecdsa-sha2-nistp256', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) + self.assertEqual("ecdsa-sha2-nistp256", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64()) @@ -292,9 +297,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_11_load_ecdsa_password_256(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_256.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp256', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_256.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp256", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_256.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_256.split()[1], key.get_base64()) @@ -302,7 +309,7 @@ class KeyTest(unittest.TestCase): def test_12_compare_ecdsa_256(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -311,11 +318,11 @@ class KeyTest(unittest.TestCase): def test_13_sign_ecdsa_256(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_256.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_256.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp256', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp256", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -323,12 +330,12 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_14_load_ecdsa_384(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) - self.assertEqual('ecdsa-sha2-nistp384', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) + self.assertEqual("ecdsa-sha2-nistp384", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64()) @@ -342,9 +349,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_15_load_ecdsa_password_384(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_384.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp384', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_384.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp384", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_384.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_384.split()[1], key.get_base64()) @@ -352,7 +361,7 @@ class KeyTest(unittest.TestCase): def test_16_compare_ecdsa_384(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -361,11 +370,11 @@ class KeyTest(unittest.TestCase): def test_17_sign_ecdsa_384(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_384.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_384.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp384', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp384", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -373,12 +382,12 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_18_load_ecdsa_521(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) - self.assertEqual('ecdsa-sha2-nistp521', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) + self.assertEqual("ecdsa-sha2-nistp521", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64()) @@ -395,9 +404,11 @@ class KeyTest(unittest.TestCase): self.assertEqual(key, key2) def test_19_load_ecdsa_password_521(self): - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_password_521.key'), b'television') - self.assertEqual('ecdsa-sha2-nistp521', key.get_name()) - exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(':', '')) + key = ECDSAKey.from_private_key_file( + _support("test_ecdsa_password_521.key"), b"television" + ) + self.assertEqual("ecdsa-sha2-nistp521", key.get_name()) + exp_ecdsa = b(FINGER_ECDSA_521.split()[1].replace(":", "")) my_ecdsa = hexlify(key.get_fingerprint()) self.assertEqual(exp_ecdsa, my_ecdsa) self.assertEqual(PUB_ECDSA_521.split()[1], key.get_base64()) @@ -405,7 +416,7 @@ class KeyTest(unittest.TestCase): def test_20_compare_ecdsa_521(self): # verify that the private & public keys compare equal - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) self.assertEqual(key, key) pub = ECDSAKey(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -414,11 +425,11 @@ class KeyTest(unittest.TestCase): def test_21_sign_ecdsa_521(self): # verify that the rsa private key can sign and verify - key = ECDSAKey.from_private_key_file(_support('test_ecdsa_521.key')) - msg = key.sign_ssh_data(b'ice weasels') + key = ECDSAKey.from_private_key_file(_support("test_ecdsa_521.key")) + msg = key.sign_ssh_data(b"ice weasels") self.assertTrue(type(msg) is Message) msg.rewind() - self.assertEqual('ecdsa-sha2-nistp521', msg.get_text()) + self.assertEqual("ecdsa-sha2-nistp521", msg.get_text()) # ECDSA signatures, like DSS signatures, tend to be different # each time, so we can't compare against a "known correct" # signature. @@ -426,14 +437,14 @@ class KeyTest(unittest.TestCase): msg.rewind() pub = ECDSAKey(data=key.asbytes()) - self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg)) + self.assertTrue(pub.verify_ssh_sig(b"ice weasels", msg)) def test_salt_size(self): # Read an existing encrypted private key - file_ = _support('test_rsa_password.key') - password = 'television' - newfile = file_ + '.new' - newpassword = 'radio' + file_ = _support("test_rsa_password.key") + password = "television" + newfile = file_ + ".new" + newpassword = "radio" key = RSAKey(filename=file_, password=password) # Write out a newly re-encrypted copy with a new password. # When the bug under test exists, this will ValueError. @@ -447,20 +458,20 @@ class KeyTest(unittest.TestCase): os.remove(newfile) def test_stringification(self): - key = RSAKey.from_private_key_file(_support('test_rsa.key')) + key = RSAKey.from_private_key_file(_support("test_rsa.key")) comparable = TEST_KEY_BYTESTR_2 if PY2 else TEST_KEY_BYTESTR_3 self.assertEqual(str(key), comparable) def test_ed25519(self): - key1 = Ed25519Key.from_private_key_file(_support('test_ed25519.key')) + key1 = Ed25519Key.from_private_key_file(_support("test_ed25519.key")) key2 = Ed25519Key.from_private_key_file( - _support('test_ed25519_password.key'), b'abc123' + _support("test_ed25519_password.key"), b"abc123" ) self.assertNotEqual(key1.asbytes(), key2.asbytes()) def test_ed25519_compare(self): # verify that the private & public keys compare equal - key = Ed25519Key.from_private_key_file(_support('test_ed25519.key')) + key = Ed25519Key.from_private_key_file(_support("test_ed25519.key")) self.assertEqual(key, key) pub = Ed25519Key(data=key.asbytes()) self.assertTrue(key.can_sign()) @@ -470,25 +481,25 @@ class KeyTest(unittest.TestCase): def test_ed25519_nonbytes_password(self): # https://github.com/paramiko/paramiko/issues/1039 key = Ed25519Key.from_private_key_file( - _support('test_ed25519_password.key'), + _support("test_ed25519_password.key"), # NOTE: not a bytes. Amusingly, the test above for same key DOES # explicitly cast to bytes...code smell! - 'abc123', + "abc123", ) # No exception -> it's good. Meh. def test_ed25519_load_from_file_obj(self): - with open(_support('test_ed25519.key')) as pkey_fileobj: + with open(_support("test_ed25519.key")) as pkey_fileobj: key = Ed25519Key.from_private_key(pkey_fileobj) self.assertEqual(key, key) self.assertTrue(key.can_sign()) def test_keyfile_is_actually_encrypted(self): # Read an existing encrypted private key - file_ = _support('test_rsa_password.key') - password = 'television' - newfile = file_ + '.new' - newpassword = 'radio' + file_ = _support("test_rsa_password.key") + password = "television" + newfile = file_ + ".new" + newpassword = "radio" key = RSAKey(filename=file_, password=password) # Write out a newly re-encrypted copy with a new password. # When the bug under test exists, this will ValueError. @@ -503,19 +514,21 @@ class KeyTest(unittest.TestCase): # test_client.py; this and nearby cert tests are more about the gritty # details. # PKey.load_certificate - key_path = _support(os.path.join('cert_support', 'test_rsa.key')) + key_path = _support(os.path.join("cert_support", "test_rsa.key")) key = RSAKey.from_private_key_file(key_path) self.assertTrue(key.public_blob is None) cert_path = _support( - os.path.join('cert_support', 'test_rsa.key-cert.pub') + os.path.join("cert_support", "test_rsa.key-cert.pub") ) key.load_certificate(cert_path) self.assertTrue(key.public_blob is not None) - self.assertEqual(key.public_blob.key_type, 'ssh-rsa-cert-v01@openssh.com') - self.assertEqual(key.public_blob.comment, 'test_rsa.key.pub') + self.assertEqual( + key.public_blob.key_type, "ssh-rsa-cert-v01@openssh.com" + ) + self.assertEqual(key.public_blob.comment, "test_rsa.key.pub") # Delve into blob contents, for test purposes msg = Message(key.public_blob.key_blob) - self.assertEqual(msg.get_text(), 'ssh-rsa-cert-v01@openssh.com') + self.assertEqual(msg.get_text(), "ssh-rsa-cert-v01@openssh.com") nonce = msg.get_string() e = msg.get_mpint() n = msg.get_mpint() @@ -525,10 +538,10 @@ class KeyTest(unittest.TestCase): self.assertEqual(msg.get_int64(), 1234) # Prevented from loading certificate that doesn't match - key_path = _support(os.path.join('cert_support', 'test_ed25519.key')) + key_path = _support(os.path.join("cert_support", "test_ed25519.key")) key1 = Ed25519Key.from_private_key_file(key_path) self.assertRaises( ValueError, key1.load_certificate, - _support('test_rsa.key-cert.pub'), + _support("test_rsa.key-cert.pub"), ) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 09a50453..a03961d6 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -45,7 +45,7 @@ from .stub_sftp import StubServer, StubSFTPServer from .util import _support, slow -ARTICLE = ''' +ARTICLE = """ Insulin sensitivity and liver insulin receptor structure in ducks from two genera @@ -70,7 +70,7 @@ receptors. Therefore the ducks from the two genera exhibit an alpha-beta- structure for liver insulin receptors and a clear difference in the number of liver insulin receptors. Their sensitivity to insulin is, however, similarly decreased compared with chicken. -''' +""" # Here is how unicode characters are encoded over 1 to 6 bytes in utf-8 @@ -82,32 +82,33 @@ decreased compared with chicken. # U-04000000 - U-7FFFFFFF: 1111110x 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx 10xxxxxx # Note that: hex(int('11000011',2)) == '0xc3' # Thus, the following 2-bytes sequence is not valid utf8: "invalid continuation byte" -NON_UTF8_DATA = b'\xC3\xC3' +NON_UTF8_DATA = b"\xC3\xC3" -unicode_folder = u'\u00fcnic\u00f8de' if PY2 else '\u00fcnic\u00f8de' -utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65' +unicode_folder = u"\u00fcnic\u00f8de" if PY2 else "\u00fcnic\u00f8de" +utf8_folder = b"/\xc3\xbcnic\xc3\xb8\x64\x65" @slow class TestSFTP(object): + def test_1_file(self, sftp): """ verify that we can create a file. """ - f = sftp.open(sftp.FOLDER + '/test', 'w') + f = sftp.open(sftp.FOLDER + "/test", "w") try: assert f.stat().st_size == 0 finally: f.close() - sftp.remove(sftp.FOLDER + '/test') + sftp.remove(sftp.FOLDER + "/test") def test_2_close(self, sftp): """ Verify that SFTP session close() causes a socket error on next action. """ sftp.close() - with pytest.raises(socket.error, match='Socket is closed'): - sftp.open(sftp.FOLDER + '/test2', 'w') + with pytest.raises(socket.error, match="Socket is closed"): + sftp.open(sftp.FOLDER + "/test2", "w") def test_2_sftp_can_be_used_as_context_manager(self, sftp): """ @@ -115,117 +116,117 @@ class TestSFTP(object): """ with sftp: pass - with pytest.raises(socket.error, match='Socket is closed'): - sftp.open(sftp.FOLDER + '/test2', 'w') + with pytest.raises(socket.error, match="Socket is closed"): + sftp.open(sftp.FOLDER + "/test2", "w") def test_3_write(self, sftp): """ verify that a file can be created and written, and the size is correct. """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483 + assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483 finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_3_sftp_file_can_be_used_as_context_manager(self, sftp): """ verify that an opened file can be used as a context manager """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - assert sftp.stat(sftp.FOLDER + '/duck.txt').st_size == 1483 + assert sftp.stat(sftp.FOLDER + "/duck.txt").st_size == 1483 finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_4_append(self, sftp): """ verify that a file can be opened for append, and tell() still works. """ try: - with sftp.open(sftp.FOLDER + '/append.txt', 'w') as f: - f.write('first line\nsecond line\n') + with sftp.open(sftp.FOLDER + "/append.txt", "w") as f: + f.write("first line\nsecond line\n") assert f.tell() == 23 - with sftp.open(sftp.FOLDER + '/append.txt', 'a+') as f: - f.write('third line!!!\n') + with sftp.open(sftp.FOLDER + "/append.txt", "a+") as f: + f.write("third line!!!\n") assert f.tell() == 37 assert f.stat().st_size == 37 f.seek(-26, f.SEEK_CUR) - assert f.readline() == 'second line\n' + assert f.readline() == "second line\n" finally: - sftp.remove(sftp.FOLDER + '/append.txt') + sftp.remove(sftp.FOLDER + "/append.txt") def test_5_rename(self, sftp): """ verify that renaming a file works. """ try: - with sftp.open(sftp.FOLDER + '/first.txt', 'w') as f: - f.write('content!\n') - sftp.rename(sftp.FOLDER + '/first.txt', sftp.FOLDER + '/second.txt') - with pytest.raises(IOError, match='No such file'): - sftp.open(sftp.FOLDER + '/first.txt', 'r') - with sftp.open(sftp.FOLDER + '/second.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/first.txt", "w") as f: + f.write("content!\n") + sftp.rename( + sftp.FOLDER + "/first.txt", sftp.FOLDER + "/second.txt" + ) + with pytest.raises(IOError, match="No such file"): + sftp.open(sftp.FOLDER + "/first.txt", "r") + with sftp.open(sftp.FOLDER + "/second.txt", "r") as f: f.seek(-6, f.SEEK_END) - assert u(f.read(4)) == 'tent' + assert u(f.read(4)) == "tent" finally: # TODO: this is gross, make some sort of 'remove if possible' / 'rm # -f' a-like, jeez try: - sftp.remove(sftp.FOLDER + '/first.txt') + sftp.remove(sftp.FOLDER + "/first.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/second.txt') + sftp.remove(sftp.FOLDER + "/second.txt") except: pass - def test_5a_posix_rename(self, sftp): """Test posix-rename@openssh.com protocol extension.""" try: # first check that the normal rename works as specified - with sftp.open(sftp.FOLDER + '/a', 'w') as f: - f.write('one') - sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') - with sftp.open(sftp.FOLDER + '/a', 'w') as f: - f.write('two') - with pytest.raises(IOError): # actual message seems generic - sftp.rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') + with sftp.open(sftp.FOLDER + "/a", "w") as f: + f.write("one") + sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") + with sftp.open(sftp.FOLDER + "/a", "w") as f: + f.write("two") + with pytest.raises(IOError): # actual message seems generic + sftp.rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") # now check with the posix_rename - sftp.posix_rename(sftp.FOLDER + '/a', sftp.FOLDER + '/b') - with sftp.open(sftp.FOLDER + '/b', 'r') as f: + sftp.posix_rename(sftp.FOLDER + "/a", sftp.FOLDER + "/b") + with sftp.open(sftp.FOLDER + "/b", "r") as f: data = u(f.read()) err = "Contents of renamed file not the same as original file" - assert 'two' == data, err + assert "two" == data, err finally: try: - sftp.remove(sftp.FOLDER + '/a') + sftp.remove(sftp.FOLDER + "/a") except: pass try: - sftp.remove(sftp.FOLDER + '/b') + sftp.remove(sftp.FOLDER + "/b") except: pass - def test_6_folder(self, sftp): """ create a temporary folder, verify that we can create a file in it, then remove the folder and verify that we can't create a file in it anymore. """ - sftp.mkdir(sftp.FOLDER + '/subfolder') - sftp.open(sftp.FOLDER + '/subfolder/test', 'w').close() - sftp.remove(sftp.FOLDER + '/subfolder/test') - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.mkdir(sftp.FOLDER + "/subfolder") + sftp.open(sftp.FOLDER + "/subfolder/test", "w").close() + sftp.remove(sftp.FOLDER + "/subfolder/test") + sftp.rmdir(sftp.FOLDER + "/subfolder") # shouldn't be able to create that file if dir removed with pytest.raises(IOError, match="No such file"): - sftp.open(sftp.FOLDER + '/subfolder/test') + sftp.open(sftp.FOLDER + "/subfolder/test") def test_7_listdir(self, sftp): """ @@ -233,57 +234,57 @@ class TestSFTP(object): it, and those files show up in sftp.listdir. """ try: - sftp.open(sftp.FOLDER + '/duck.txt', 'w').close() - sftp.open(sftp.FOLDER + '/fish.txt', 'w').close() - sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close() + sftp.open(sftp.FOLDER + "/duck.txt", "w").close() + sftp.open(sftp.FOLDER + "/fish.txt", "w").close() + sftp.open(sftp.FOLDER + "/tertiary.py", "w").close() x = sftp.listdir(sftp.FOLDER) assert len(x) == 3 - assert 'duck.txt' in x - assert 'fish.txt' in x - assert 'tertiary.py' in x - assert 'random' not in x + assert "duck.txt" in x + assert "fish.txt" in x + assert "tertiary.py" in x + assert "random" not in x finally: - sftp.remove(sftp.FOLDER + '/duck.txt') - sftp.remove(sftp.FOLDER + '/fish.txt') - sftp.remove(sftp.FOLDER + '/tertiary.py') + sftp.remove(sftp.FOLDER + "/duck.txt") + sftp.remove(sftp.FOLDER + "/fish.txt") + sftp.remove(sftp.FOLDER + "/tertiary.py") def test_7_5_listdir_iter(self, sftp): """ listdir_iter version of above test """ try: - sftp.open(sftp.FOLDER + '/duck.txt', 'w').close() - sftp.open(sftp.FOLDER + '/fish.txt', 'w').close() - sftp.open(sftp.FOLDER + '/tertiary.py', 'w').close() + sftp.open(sftp.FOLDER + "/duck.txt", "w").close() + sftp.open(sftp.FOLDER + "/fish.txt", "w").close() + sftp.open(sftp.FOLDER + "/tertiary.py", "w").close() x = [x.filename for x in sftp.listdir_iter(sftp.FOLDER)] assert len(x) == 3 - assert 'duck.txt' in x - assert 'fish.txt' in x - assert 'tertiary.py' in x - assert 'random' not in x + assert "duck.txt" in x + assert "fish.txt" in x + assert "tertiary.py" in x + assert "random" not in x finally: - sftp.remove(sftp.FOLDER + '/duck.txt') - sftp.remove(sftp.FOLDER + '/fish.txt') - sftp.remove(sftp.FOLDER + '/tertiary.py') + sftp.remove(sftp.FOLDER + "/duck.txt") + sftp.remove(sftp.FOLDER + "/fish.txt") + sftp.remove(sftp.FOLDER + "/tertiary.py") def test_8_setstat(self, sftp): """ verify that the setstat functions (chown, chmod, utime, truncate) work. """ try: - with sftp.open(sftp.FOLDER + '/special', 'w') as f: - f.write('x' * 1024) + with sftp.open(sftp.FOLDER + "/special", "w") as f: + f.write("x" * 1024) - stat = sftp.stat(sftp.FOLDER + '/special') - sftp.chmod(sftp.FOLDER + '/special', (stat.st_mode & ~o777) | o600) - stat = sftp.stat(sftp.FOLDER + '/special') + stat = sftp.stat(sftp.FOLDER + "/special") + sftp.chmod(sftp.FOLDER + "/special", (stat.st_mode & ~o777) | o600) + stat = sftp.stat(sftp.FOLDER + "/special") expected_mode = o600 - if sys.platform == 'win32': + if sys.platform == "win32": # chmod not really functional on windows expected_mode = o666 - if sys.platform == 'cygwin': + if sys.platform == "cygwin": # even worse. expected_mode = o644 assert stat.st_mode & o777 == expected_mode @@ -291,19 +292,19 @@ class TestSFTP(object): mtime = stat.st_mtime - 3600 atime = stat.st_atime - 1800 - sftp.utime(sftp.FOLDER + '/special', (atime, mtime)) - stat = sftp.stat(sftp.FOLDER + '/special') + sftp.utime(sftp.FOLDER + "/special", (atime, mtime)) + stat = sftp.stat(sftp.FOLDER + "/special") assert stat.st_mtime == mtime - if sys.platform not in ('win32', 'cygwin'): + if sys.platform not in ("win32", "cygwin"): assert stat.st_atime == atime # can't really test chown, since we'd have to know a valid uid. - sftp.truncate(sftp.FOLDER + '/special', 512) - stat = sftp.stat(sftp.FOLDER + '/special') + sftp.truncate(sftp.FOLDER + "/special", 512) + stat = sftp.stat(sftp.FOLDER + "/special") assert stat.st_size == 512 finally: - sftp.remove(sftp.FOLDER + '/special') + sftp.remove(sftp.FOLDER + "/special") def test_9_fsetstat(self, sftp): """ @@ -311,19 +312,19 @@ class TestSFTP(object): work on open files. """ try: - with sftp.open(sftp.FOLDER + '/special', 'w') as f: - f.write('x' * 1024) + with sftp.open(sftp.FOLDER + "/special", "w") as f: + f.write("x" * 1024) - with sftp.open(sftp.FOLDER + '/special', 'r+') as f: + with sftp.open(sftp.FOLDER + "/special", "r+") as f: stat = f.stat() f.chmod((stat.st_mode & ~o777) | o600) stat = f.stat() expected_mode = o600 - if sys.platform == 'win32': + if sys.platform == "win32": # chmod not really functional on windows expected_mode = o666 - if sys.platform == 'cygwin': + if sys.platform == "cygwin": # even worse. expected_mode = o644 assert stat.st_mode & o777 == expected_mode @@ -334,7 +335,7 @@ class TestSFTP(object): f.utime((atime, mtime)) stat = f.stat() assert stat.st_mtime == mtime - if sys.platform not in ('win32', 'cygwin'): + if sys.platform not in ("win32", "cygwin"): assert stat.st_atime == atime # can't really test chown, since we'd have to know a valid uid. @@ -343,7 +344,7 @@ class TestSFTP(object): stat = f.stat() assert stat.st_size == 512 finally: - sftp.remove(sftp.FOLDER + '/special') + sftp.remove(sftp.FOLDER + "/special") def test_A_readline_seek(self, sftp): """ @@ -353,10 +354,10 @@ class TestSFTP(object): buffering is reset on 'seek'. """ try: - with sftp.open(sftp.FOLDER + '/duck.txt', 'w') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "w") as f: f.write(ARTICLE) - with sftp.open(sftp.FOLDER + '/duck.txt', 'r+') as f: + with sftp.open(sftp.FOLDER + "/duck.txt", "r+") as f: line_number = 0 loc = 0 pos_list = [] @@ -366,13 +367,16 @@ class TestSFTP(object): loc = f.tell() assert f.seekable() f.seek(pos_list[6], f.SEEK_SET) - assert f.readline(), 'Nouzilly == France.\n' + assert f.readline(), "Nouzilly == France.\n" f.seek(pos_list[17], f.SEEK_SET) - assert f.readline()[:4] == 'duck' + assert f.readline()[:4] == "duck" f.seek(pos_list[10], f.SEEK_SET) - assert f.readline() == 'duck types were equally resistant to exogenous insulin compared with chicken.\n' + assert ( + f.readline() + == "duck types were equally resistant to exogenous insulin compared with chicken.\n" + ) finally: - sftp.remove(sftp.FOLDER + '/duck.txt') + sftp.remove(sftp.FOLDER + "/duck.txt") def test_B_write_seek(self, sftp): """ @@ -380,17 +384,17 @@ class TestSFTP(object): changes worked. """ try: - with sftp.open(sftp.FOLDER + '/testing.txt', 'w') as f: - f.write('hello kitty.\n') + with sftp.open(sftp.FOLDER + "/testing.txt", "w") as f: + f.write("hello kitty.\n") f.seek(-5, f.SEEK_CUR) - f.write('dd') + f.write("dd") - assert sftp.stat(sftp.FOLDER + '/testing.txt').st_size == 13 - with sftp.open(sftp.FOLDER + '/testing.txt', 'r') as f: + assert sftp.stat(sftp.FOLDER + "/testing.txt").st_size == 13 + with sftp.open(sftp.FOLDER + "/testing.txt", "r") as f: data = f.read(20) - assert data == b'hello kiddy.\n' + assert data == b"hello kiddy.\n" finally: - sftp.remove(sftp.FOLDER + '/testing.txt') + sftp.remove(sftp.FOLDER + "/testing.txt") def test_C_symlink(self, sftp): """ @@ -401,39 +405,41 @@ class TestSFTP(object): return try: - with sftp.open(sftp.FOLDER + '/original.txt', 'w') as f: - f.write('original\n') - sftp.symlink('original.txt', sftp.FOLDER + '/link.txt') - assert sftp.readlink(sftp.FOLDER + '/link.txt') == 'original.txt' + with sftp.open(sftp.FOLDER + "/original.txt", "w") as f: + f.write("original\n") + sftp.symlink("original.txt", sftp.FOLDER + "/link.txt") + assert sftp.readlink(sftp.FOLDER + "/link.txt") == "original.txt" - with sftp.open(sftp.FOLDER + '/link.txt', 'r') as f: - assert f.readlines() == ['original\n'] + with sftp.open(sftp.FOLDER + "/link.txt", "r") as f: + assert f.readlines() == ["original\n"] - cwd = sftp.normalize('.') - if cwd[-1] == '/': + cwd = sftp.normalize(".") + if cwd[-1] == "/": cwd = cwd[:-1] - abs_path = cwd + '/' + sftp.FOLDER + '/original.txt' - sftp.symlink(abs_path, sftp.FOLDER + '/link2.txt') - assert abs_path == sftp.readlink(sftp.FOLDER + '/link2.txt') + abs_path = cwd + "/" + sftp.FOLDER + "/original.txt" + sftp.symlink(abs_path, sftp.FOLDER + "/link2.txt") + assert abs_path == sftp.readlink(sftp.FOLDER + "/link2.txt") - assert sftp.lstat(sftp.FOLDER + '/link.txt').st_size == 12 - assert sftp.stat(sftp.FOLDER + '/link.txt').st_size == 9 + assert sftp.lstat(sftp.FOLDER + "/link.txt").st_size == 12 + assert sftp.stat(sftp.FOLDER + "/link.txt").st_size == 9 # the sftp server may be hiding extra path members from us, so the # length may be longer than we expect: - assert sftp.lstat(sftp.FOLDER + '/link2.txt').st_size >= len(abs_path) - assert sftp.stat(sftp.FOLDER + '/link2.txt').st_size == 9 - assert sftp.stat(sftp.FOLDER + '/original.txt').st_size == 9 + assert ( + sftp.lstat(sftp.FOLDER + "/link2.txt").st_size >= len(abs_path) + ) + assert sftp.stat(sftp.FOLDER + "/link2.txt").st_size == 9 + assert sftp.stat(sftp.FOLDER + "/original.txt").st_size == 9 finally: try: - sftp.remove(sftp.FOLDER + '/link.txt') + sftp.remove(sftp.FOLDER + "/link.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/link2.txt') + sftp.remove(sftp.FOLDER + "/link2.txt") except: pass try: - sftp.remove(sftp.FOLDER + '/original.txt') + sftp.remove(sftp.FOLDER + "/original.txt") except: pass @@ -442,18 +448,18 @@ class TestSFTP(object): verify that buffered writes are automatically flushed on seek. """ try: - with sftp.open(sftp.FOLDER + '/happy.txt', 'w', 1) as f: - f.write('full line.\n') - f.write('partial') + with sftp.open(sftp.FOLDER + "/happy.txt", "w", 1) as f: + f.write("full line.\n") + f.write("partial") f.seek(9, f.SEEK_SET) - f.write('?\n') + f.write("?\n") - with sftp.open(sftp.FOLDER + '/happy.txt', 'r') as f: - assert f.readline() == u('full line?\n') - assert f.read(7) == b'partial' + with sftp.open(sftp.FOLDER + "/happy.txt", "r") as f: + assert f.readline() == u("full line?\n") + assert f.read(7) == b"partial" finally: try: - sftp.remove(sftp.FOLDER + '/happy.txt') + sftp.remove(sftp.FOLDER + "/happy.txt") except: pass @@ -462,9 +468,9 @@ class TestSFTP(object): test that realpath is returning something non-empty and not an error. """ - pwd = sftp.normalize('.') + pwd = sftp.normalize(".") assert len(pwd) > 0 - f = sftp.normalize('./' + sftp.FOLDER) + f = sftp.normalize("./" + sftp.FOLDER) assert len(f) > 0 assert os.path.join(pwd, sftp.FOLDER) == f @@ -472,46 +478,46 @@ class TestSFTP(object): """ verify that mkdir/rmdir work. """ - sftp.mkdir(sftp.FOLDER + '/subfolder') - with pytest.raises(IOError): # generic msg only - sftp.mkdir(sftp.FOLDER + '/subfolder') - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.mkdir(sftp.FOLDER + "/subfolder") + with pytest.raises(IOError): # generic msg only + sftp.mkdir(sftp.FOLDER + "/subfolder") + sftp.rmdir(sftp.FOLDER + "/subfolder") with pytest.raises(IOError, match="No such file"): - sftp.rmdir(sftp.FOLDER + '/subfolder') + sftp.rmdir(sftp.FOLDER + "/subfolder") def test_G_chdir(self, sftp): """ verify that chdir/getcwd work. """ - root = sftp.normalize('.') - if root[-1] != '/': - root += '/' + root = sftp.normalize(".") + if root[-1] != "/": + root += "/" try: - sftp.mkdir(sftp.FOLDER + '/alpha') - sftp.chdir(sftp.FOLDER + '/alpha') - sftp.mkdir('beta') - assert root + sftp.FOLDER + '/alpha' == sftp.getcwd() - assert ['beta'] == sftp.listdir('.') - - sftp.chdir('beta') - with sftp.open('fish', 'w') as f: - f.write('hello\n') - sftp.chdir('..') - assert ['fish'] == sftp.listdir('beta') - sftp.chdir('..') - assert ['fish'] == sftp.listdir('alpha/beta') + sftp.mkdir(sftp.FOLDER + "/alpha") + sftp.chdir(sftp.FOLDER + "/alpha") + sftp.mkdir("beta") + assert root + sftp.FOLDER + "/alpha" == sftp.getcwd() + assert ["beta"] == sftp.listdir(".") + + sftp.chdir("beta") + with sftp.open("fish", "w") as f: + f.write("hello\n") + sftp.chdir("..") + assert ["fish"] == sftp.listdir("beta") + sftp.chdir("..") + assert ["fish"] == sftp.listdir("alpha/beta") finally: sftp.chdir(root) try: - sftp.unlink(sftp.FOLDER + '/alpha/beta/fish') + sftp.unlink(sftp.FOLDER + "/alpha/beta/fish") except: pass try: - sftp.rmdir(sftp.FOLDER + '/alpha/beta') + sftp.rmdir(sftp.FOLDER + "/alpha/beta") except: pass try: - sftp.rmdir(sftp.FOLDER + '/alpha') + sftp.rmdir(sftp.FOLDER + "/alpha") except: pass @@ -519,20 +525,21 @@ class TestSFTP(object): """ verify that get/put work. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") fd, localname = mkstemp() os.close(fd) - text = b'All I wanted was a plastic bunny rabbit.\n' - with open(localname, 'wb') as f: + text = b"All I wanted was a plastic bunny rabbit.\n" + with open(localname, "wb") as f: f.write(text) saved_progress = [] def progress_callback(x, y): saved_progress.append((x, y)) - sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback) - with sftp.open(sftp.FOLDER + '/bunny.txt', 'rb') as f: + sftp.put(localname, sftp.FOLDER + "/bunny.txt", progress_callback) + + with sftp.open(sftp.FOLDER + "/bunny.txt", "rb") as f: assert text == f.read(128) assert [(41, 41)] == saved_progress @@ -540,14 +547,14 @@ class TestSFTP(object): fd, localname = mkstemp() os.close(fd) saved_progress = [] - sftp.get(sftp.FOLDER + '/bunny.txt', localname, progress_callback) + sftp.get(sftp.FOLDER + "/bunny.txt", localname, progress_callback) - with open(localname, 'rb') as f: + with open(localname, "rb") as f: assert text == f.read(128) assert [(41, 41)] == saved_progress os.unlink(localname) - sftp.unlink(sftp.FOLDER + '/bunny.txt') + sftp.unlink(sftp.FOLDER + "/bunny.txt") def test_I_check(self, sftp): """ @@ -555,118 +562,132 @@ class TestSFTP(object): (it's an sftp extension that we support, and may be the only ones who support it.) """ - with sftp.open(sftp.FOLDER + '/kitty.txt', 'w') as f: - f.write('here kitty kitty' * 64) + with sftp.open(sftp.FOLDER + "/kitty.txt", "w") as f: + f.write("here kitty kitty" * 64) try: - with sftp.open(sftp.FOLDER + '/kitty.txt', 'r') as f: - sum = f.check('sha1') - assert '91059CFC6615941378D413CB5ADAF4C5EB293402' == u(hexlify(sum)).upper() - sum = f.check('md5', 0, 512) - assert '93DE4788FCA28D471516963A1FE3856A' == u(hexlify(sum)).upper() - sum = f.check('md5', 0, 0, 510) - assert u(hexlify(sum)).upper() == 'EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6' # noqa + with sftp.open(sftp.FOLDER + "/kitty.txt", "r") as f: + sum = f.check("sha1") + assert ( + "91059CFC6615941378D413CB5ADAF4C5EB293402" + == u(hexlify(sum)).upper() + ) + sum = f.check("md5", 0, 512) + assert ( + "93DE4788FCA28D471516963A1FE3856A" + == u(hexlify(sum)).upper() + ) + sum = f.check("md5", 0, 0, 510) + assert ( + u(hexlify(sum)).upper() + == "EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6" + ) # noqa finally: - sftp.unlink(sftp.FOLDER + '/kitty.txt') + sftp.unlink(sftp.FOLDER + "/kitty.txt") def test_J_x_flag(self, sftp): """ verify that the 'x' flag works when opening a file. """ - sftp.open(sftp.FOLDER + '/unusual.txt', 'wx').close() + sftp.open(sftp.FOLDER + "/unusual.txt", "wx").close() try: try: - sftp.open(sftp.FOLDER + '/unusual.txt', 'wx') - self.fail('expected exception') + sftp.open(sftp.FOLDER + "/unusual.txt", "wx") + self.fail("expected exception") except IOError: pass finally: - sftp.unlink(sftp.FOLDER + '/unusual.txt') + sftp.unlink(sftp.FOLDER + "/unusual.txt") def test_K_utf8(self, sftp): """ verify that unicode strings are encoded into utf8 correctly. """ - with sftp.open(sftp.FOLDER + '/something', 'w') as f: - f.write('okay') + with sftp.open(sftp.FOLDER + "/something", "w") as f: + f.write("okay") try: - sftp.rename(sftp.FOLDER + '/something', sftp.FOLDER + '/' + unicode_folder) - sftp.open(b(sftp.FOLDER) + utf8_folder, 'r') + sftp.rename( + sftp.FOLDER + "/something", sftp.FOLDER + "/" + unicode_folder + ) + sftp.open(b(sftp.FOLDER) + utf8_folder, "r") except Exception as e: - self.fail('exception ' + str(e)) + self.fail("exception " + str(e)) sftp.unlink(b(sftp.FOLDER) + utf8_folder) def test_L_utf8_chdir(self, sftp): - sftp.mkdir(sftp.FOLDER + '/' + unicode_folder) + sftp.mkdir(sftp.FOLDER + "/" + unicode_folder) try: - sftp.chdir(sftp.FOLDER + '/' + unicode_folder) - with sftp.open('something', 'w') as f: - f.write('okay') - sftp.unlink('something') + sftp.chdir(sftp.FOLDER + "/" + unicode_folder) + with sftp.open("something", "w") as f: + f.write("okay") + sftp.unlink("something") finally: sftp.chdir() - sftp.rmdir(sftp.FOLDER + '/' + unicode_folder) + sftp.rmdir(sftp.FOLDER + "/" + unicode_folder) def test_M_bad_readv(self, sftp): """ verify that readv at the end of the file doesn't essplode. """ - sftp.open(sftp.FOLDER + '/zero', 'w').close() + sftp.open(sftp.FOLDER + "/zero", "w").close() try: - with sftp.open(sftp.FOLDER + '/zero', 'r') as f: + with sftp.open(sftp.FOLDER + "/zero", "r") as f: f.readv([(0, 12)]) - with sftp.open(sftp.FOLDER + '/zero', 'r') as f: + with sftp.open(sftp.FOLDER + "/zero", "r") as f: file_size = f.stat().st_size f.prefetch(file_size) f.read(100) finally: - sftp.unlink(sftp.FOLDER + '/zero') + sftp.unlink(sftp.FOLDER + "/zero") def test_N_put_without_confirm(self, sftp): """ verify that get/put work without confirmation. """ - warnings.filterwarnings('ignore', 'tempnam.*') + warnings.filterwarnings("ignore", "tempnam.*") fd, localname = mkstemp() os.close(fd) - text = b'All I wanted was a plastic bunny rabbit.\n' - with open(localname, 'wb') as f: + text = b"All I wanted was a plastic bunny rabbit.\n" + with open(localname, "wb") as f: f.write(text) saved_progress = [] def progress_callback(x, y): saved_progress.append((x, y)) - res = sftp.put(localname, sftp.FOLDER + '/bunny.txt', progress_callback, False) + + res = sftp.put( + localname, sftp.FOLDER + "/bunny.txt", progress_callback, False + ) assert SFTPAttributes().attr == res.attr - with sftp.open(sftp.FOLDER + '/bunny.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/bunny.txt", "r") as f: assert text == f.read(128) assert (41, 41) == saved_progress[-1] os.unlink(localname) - sftp.unlink(sftp.FOLDER + '/bunny.txt') + sftp.unlink(sftp.FOLDER + "/bunny.txt") def test_O_getcwd(self, sftp): """ verify that chdir/getcwd work. """ assert sftp.getcwd() == None - root = sftp.normalize('.') - if root[-1] != '/': - root += '/' + root = sftp.normalize(".") + if root[-1] != "/": + root += "/" try: - sftp.mkdir(sftp.FOLDER + '/alpha') - sftp.chdir(sftp.FOLDER + '/alpha') - assert sftp.getcwd() == '/' + sftp.FOLDER + '/alpha' + sftp.mkdir(sftp.FOLDER + "/alpha") + sftp.chdir(sftp.FOLDER + "/alpha") + assert sftp.getcwd() == "/" + sftp.FOLDER + "/alpha" finally: sftp.chdir(root) try: - sftp.rmdir(sftp.FOLDER + '/alpha') + sftp.rmdir(sftp.FOLDER + "/alpha") except: pass @@ -677,24 +698,24 @@ class TestSFTP(object): does not work except through paramiko. :( openssh fails. """ try: - with sftp.open(sftp.FOLDER + '/append.txt', 'a') as f: - f.write('first line\nsecond line\n') + with sftp.open(sftp.FOLDER + "/append.txt", "a") as f: + f.write("first line\nsecond line\n") f.seek(11, f.SEEK_SET) - f.write('third line\n') + f.write("third line\n") - with sftp.open(sftp.FOLDER + '/append.txt', 'r') as f: + with sftp.open(sftp.FOLDER + "/append.txt", "r") as f: assert f.stat().st_size == 34 - assert f.readline() == 'first line\n' - assert f.readline() == 'second line\n' - assert f.readline() == 'third line\n' + assert f.readline() == "first line\n" + assert f.readline() == "second line\n" + assert f.readline() == "third line\n" finally: - sftp.remove(sftp.FOLDER + '/append.txt') + sftp.remove(sftp.FOLDER + "/append.txt") def test_putfo_empty_file(self, sftp): """ Send an empty file and confirm it is sent. """ - target = sftp.FOLDER + '/empty file.txt' + target = sftp.FOLDER + "/empty file.txt" stream = StringIO() try: attrs = sftp.putfo(stream, target) @@ -713,59 +734,61 @@ class TestSFTP(object): verify that we can create a file with a '%' in the filename. ( it needs to be properly escaped by _log() ) """ - f = sftp.open(sftp.FOLDER + '/test%file', 'w') + f = sftp.open(sftp.FOLDER + "/test%file", "w") try: assert f.stat().st_size == 0 finally: f.close() - sftp.remove(sftp.FOLDER + '/test%file') + sftp.remove(sftp.FOLDER + "/test%file") def test_O_non_utf8_data(self, sftp): """Test write() and read() of non utf8 data""" try: - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "w") as f: f.write(NON_UTF8_DATA) - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "r") as f: data = f.read() assert data == NON_UTF8_DATA - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "wb") as f: f.write(NON_UTF8_DATA) - with sftp.open('%s/nonutf8data' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/nonutf8data" % sftp.FOLDER, "rb") as f: data = f.read() assert data == NON_UTF8_DATA finally: - sftp.remove('%s/nonutf8data' % sftp.FOLDER) - + sftp.remove("%s/nonutf8data" % sftp.FOLDER) def test_sftp_attributes_empty_str(self, sftp): sftp_attributes = SFTPAttributes() - assert str(sftp_attributes) == "?--------- 1 0 0 0 (unknown date) ?" + assert ( + str(sftp_attributes) + == "?--------- 1 0 0 0 (unknown date) ?" + ) - @needs_builtin('buffer') + @needs_builtin("buffer") def test_write_buffer(self, sftp): """Test write() using a buffer instance.""" - data = 3 * b'A potentially large block of data to chunk up.\n' + data = 3 * b"A potentially large block of data to chunk up.\n" try: - with sftp.open('%s/write_buffer' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/write_buffer" % sftp.FOLDER, "wb") as f: for offset in range(0, len(data), 8): f.write(buffer(data, offset, 8)) - with sftp.open('%s/write_buffer' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/write_buffer" % sftp.FOLDER, "rb") as f: assert f.read() == data finally: - sftp.remove('%s/write_buffer' % sftp.FOLDER) + sftp.remove("%s/write_buffer" % sftp.FOLDER) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_write_memoryview(self, sftp): """Test write() using a memoryview instance.""" - data = 3 * b'A potentially large block of data to chunk up.\n' + data = 3 * b"A potentially large block of data to chunk up.\n" try: - with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/write_memoryview" % sftp.FOLDER, "wb") as f: view = memoryview(data) for offset in range(0, len(data), 8): - f.write(view[offset:offset+8]) + f.write(view[offset:offset + 8]) - with sftp.open('%s/write_memoryview' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/write_memoryview" % sftp.FOLDER, "rb") as f: assert f.read() == data finally: - sftp.remove('%s/write_memoryview' % sftp.FOLDER) + sftp.remove("%s/write_memoryview" % sftp.FOLDER) diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py index a659098d..7f74d5f6 100644 --- a/tests/test_sftp_big.py +++ b/tests/test_sftp_big.py @@ -37,6 +37,7 @@ from .util import slow @slow class TestBigSFTP(object): + def test_1_lots_of_files(self, sftp): """ create a bunch of files over the same session. @@ -44,22 +45,24 @@ class TestBigSFTP(object): numfiles = 100 try: for i in range(numfiles): - with sftp.open('%s/file%d.txt' % (sftp.FOLDER, i), 'w', 1) as f: - f.write('this is file #%d.\n' % i) - sftp.chmod('%s/file%d.txt' % (sftp.FOLDER, i), o660) + with sftp.open( + "%s/file%d.txt" % (sftp.FOLDER, i), "w", 1 + ) as f: + f.write("this is file #%d.\n" % i) + sftp.chmod("%s/file%d.txt" % (sftp.FOLDER, i), o660) # now make sure every file is there, by creating a list of filenmes # and reading them in random order. numlist = list(range(numfiles)) while len(numlist) > 0: r = numlist[random.randint(0, len(numlist) - 1)] - with sftp.open('%s/file%d.txt' % (sftp.FOLDER, r)) as f: - assert f.readline() == 'this is file #%d.\n' % r + with sftp.open("%s/file%d.txt" % (sftp.FOLDER, r)) as f: + assert f.readline() == "this is file #%d.\n" % r numlist.remove(r) finally: for i in range(numfiles): try: - sftp.remove('%s/file%d.txt' % (sftp.FOLDER, i)) + sftp.remove("%s/file%d.txt" % (sftp.FOLDER, i)) except: pass @@ -67,52 +70,56 @@ class TestBigSFTP(object): """ write a 1MB file with no buffering. """ - kblob = (1024 * b'x') + kblob = (1024 * b"x") start = time.time() try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f: for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - + sys.stderr.write("%ds " % round(end - start)) + start = time.time() - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: for n in range(1024): data = f.read(1024) assert data == kblob end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_3_big_file_pipelined(self, sftp): """ write a 1MB file, with no linefeeds, using pipelining. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) start = time.time() try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - + sys.stderr.write("%ds " % round(end - start)) + start = time.time() - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) @@ -130,31 +137,35 @@ class TestBigSFTP(object): n += chunk end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_4_prefetch_seek(self, sftp): - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') - - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 - + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) + start = time.time() k2blob = kblob + kblob chunk = 793 for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) - base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + base_offset = (512 * 1024) + 17 * random.randint( + 1000, 2000 + ) offsets = [base_offset + j * chunk for j in range(100)] # randomly seek around and read them out for j in range(100): @@ -166,29 +177,33 @@ class TestBigSFTP(object): assert data == k2blob[n_offset:n_offset + chunk] offset += chunk end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_5_readv_seek(self, sftp): - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) start = time.time() k2blob = kblob + kblob chunk = 793 for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: - base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: + base_offset = (512 * 1024) + 17 * random.randint( + 1000, 2000 + ) # make a bunch of offsets and put them in random order offsets = [base_offset + j * chunk for j in range(100)] readv_list = [] @@ -202,60 +217,64 @@ class TestBigSFTP(object): n_offset = offset % 1024 assert next(ret) == k2blob[n_offset:n_offset + chunk] end = time.time() - sys.stderr.write('%ds ' % round(end - start)) + sys.stderr.write("%ds " % round(end - start)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) def test_6_lots_of_prefetching(self, sftp): """ prefetch a 1MB file a bunch of times, discarding the file object without using it, to verify that paramiko doesn't get confused. """ - kblob = (1024 * b'x') + kblob = (1024 * b"x") try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "w") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) for i in range(10): - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: file_size = f.stat().st_size f.prefetch(file_size) - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "r") as f: file_size = f.stat().st_size f.prefetch(file_size) for n in range(1024): data = f.read(1024) assert data == kblob if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_7_prefetch_readv(self, sftp): """ verify that prefetch and readv don't conflict with each other. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') - - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: file_size = f.stat().st_size f.prefetch(file_size) data = f.read(1024) @@ -264,79 +283,94 @@ class TestBigSFTP(object): chunk_size = 793 base_offset = 512 * 1024 k2blob = kblob + kblob - chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)] + chunks = [ + (base_offset + (chunk_size * i), chunk_size) + for i in range(20) + ] for data in f.readv(chunks): offset = base_offset % 1024 assert chunk_size == len(data) assert k2blob[offset:offset + chunk_size] == data base_offset += chunk_size - sys.stderr.write(' ') + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_8_large_readv(self, sftp): """ verify that a very large readv is broken up correctly and still returned as a single blob. """ - kblob = bytes().join([struct.pack('>H', n) for n in range(512)]) + kblob = bytes().join([struct.pack(">H", n) for n in range(512)]) try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'wb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "wb") as f: f.set_pipelined(True) for n in range(1024): f.write(kblob) if n % 128 == 0: - sys.stderr.write('.') - sys.stderr.write(' ') + sys.stderr.write(".") + sys.stderr.write(" ") + + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 - - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'rb') as f: + with sftp.open("%s/hongry.txt" % sftp.FOLDER, "rb") as f: data = list(f.readv([(23 * 1024, 128 * 1024)])) assert len(data) == 1 data = data[0] assert len(data) == 128 * 1024 - - sys.stderr.write(' ') + + sys.stderr.write(" ") finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_9_big_file_big_buffer(self, sftp): """ write a 1MB file, with no linefeeds, and a big buffer. """ - mblob = (1024 * 1024 * 'x') + mblob = (1024 * 1024 * "x") try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024 + ) as f: f.write(mblob) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) - + sftp.remove("%s/hongry.txt" % sftp.FOLDER) + def test_A_big_file_renegotiate(self, sftp): """ write a 1MB file, forcing key renegotiation in the middle. """ t = sftp.sock.get_transport() t.packetizer.REKEY_BYTES = 512 * 1024 - k32blob = (32 * 1024 * 'x') + k32blob = (32 * 1024 * "x") try: - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'w', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "w", 128 * 1024 + ) as f: for i in range(32): f.write(k32blob) - assert sftp.stat('%s/hongry.txt' % sftp.FOLDER).st_size == 1024 * 1024 + assert ( + sftp.stat("%s/hongry.txt" % sftp.FOLDER).st_size == 1024 * 1024 + ) assert t.H != t.session_id - + # try to read it too. - with sftp.open('%s/hongry.txt' % sftp.FOLDER, 'r', 128 * 1024) as f: + with sftp.open( + "%s/hongry.txt" % sftp.FOLDER, "r", 128 * 1024 + ) as f: file_size = f.stat().st_size f.prefetch(file_size) total = 0 while total < 1024 * 1024: total += len(f.read(32 * 1024)) finally: - sftp.remove('%s/hongry.txt' % sftp.FOLDER) + sftp.remove("%s/hongry.txt" % sftp.FOLDER) t.packetizer.REKEY_BYTES = pow(2, 30) diff --git a/tests/test_ssh_exception.py b/tests/test_ssh_exception.py index 18f2a97d..6cc5d06a 100644 --- a/tests/test_ssh_exception.py +++ b/tests/test_ssh_exception.py @@ -4,28 +4,33 @@ import unittest from paramiko.ssh_exception import NoValidConnectionsError -class NoValidConnectionsErrorTest (unittest.TestCase): +class NoValidConnectionsErrorTest(unittest.TestCase): def test_pickling(self): # Regression test for https://github.com/paramiko/paramiko/issues/617 - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()}) + exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()}) new_exc = pickle.loads(pickle.dumps(exc)) self.assertEqual(type(exc), type(new_exc)) self.assertEqual(str(exc), str(new_exc)) self.assertEqual(exc.args, new_exc.args) def test_error_message_for_single_host(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception()}) + exc = NoValidConnectionsError({("127.0.0.1", "22"): Exception()}) assert "Unable to connect to port 22 on 127.0.0.1" in str(exc) def test_error_message_for_two_hosts(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(), - ('::1', '22'): Exception()}) + exc = NoValidConnectionsError( + {("127.0.0.1", "22"): Exception(), ("::1", "22"): Exception()} + ) assert "Unable to connect to port 22 on 127.0.0.1 or ::1" in str(exc) def test_error_message_for_multiple_hosts(self): - exc = NoValidConnectionsError({('127.0.0.1', '22'): Exception(), - ('::1', '22'): Exception(), - ('10.0.0.42', '22'): Exception()}) + exc = NoValidConnectionsError( + { + ("127.0.0.1", "22"): Exception(), + ("::1", "22"): Exception(), + ("10.0.0.42", "22"): Exception(), + } + ) exp = "Unable to connect to port 22 on 10.0.0.42, 127.0.0.1 or ::1" assert exp in str(exc) diff --git a/tests/test_ssh_gss.py b/tests/test_ssh_gss.py index f0645e0e..1e08b361 100644 --- a/tests/test_ssh_gss.py +++ b/tests/test_ssh_gss.py @@ -33,15 +33,13 @@ from .util import _support, needs_gssapi from .test_client import FINGERPRINTS -class NullServer (paramiko.ServerInterface): +class NullServer(paramiko.ServerInterface): + def get_allowed_auths(self, username): - return 'gssapi-with-mic,publickey' + return "gssapi-with-mic,publickey" def check_auth_gssapi_with_mic( - self, - username, - gss_authenticated=paramiko.AUTH_FAILED, - cc_file=None, + self, username, gss_authenticated=paramiko.AUTH_FAILED, cc_file=None ): if gss_authenticated == paramiko.AUTH_SUCCESSFUL: return paramiko.AUTH_SUCCESSFUL @@ -64,13 +62,14 @@ class NullServer (paramiko.ServerInterface): return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != 'yes': + if command != "yes": return False return True @needs_gssapi class GSSAuthTest(unittest.TestCase): + def setUp(self): # TODO: username and targ_name should come from os.environ or whatever # the approved pytest method is for runtime-configuring test data. @@ -92,7 +91,7 @@ class GSSAuthTest(unittest.TestCase): def _run(self): self.socks, addr = self.sockl.accept() self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") self.ts.add_server_key(host_key) server = NullServer() self.ts.start_server(self.event, server) @@ -103,15 +102,22 @@ class GSSAuthTest(unittest.TestCase): The exception is ... no exception yet """ - host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + host_key = paramiko.RSAKey.from_private_key_file("tests/test_rsa.key") public_host_key = paramiko.RSAKey(data=host_key.asbytes()) self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.WarningPolicy()) - self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), - 'ssh-rsa', public_host_key) - self.tc.connect(hostname=self.addr, port=self.port, username=self.username, gss_host=self.hostname, - gss_auth=True, **kwargs) + self.tc.get_host_keys().add( + "[%s]:%d" % (self.addr, self.port), "ssh-rsa", public_host_key + ) + self.tc.connect( + hostname=self.addr, + port=self.port, + username=self.username, + gss_host=self.hostname, + gss_auth=True, + **kwargs + ) self.event.wait(1.0) self.assert_(self.event.is_set()) @@ -119,17 +125,17 @@ class GSSAuthTest(unittest.TestCase): self.assertEquals(self.username, self.ts.get_username()) self.assertEquals(True, self.ts.is_authenticated()) - stdin, stdout, stderr = self.tc.exec_command('yes') + stdin, stdout, stderr = self.tc.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() - self.assertEquals('Hello there.\n', stdout.readline()) - self.assertEquals('', stdout.readline()) - self.assertEquals('This is on stderr.\n', stderr.readline()) - self.assertEquals('', stderr.readline()) + self.assertEquals("Hello there.\n", stdout.readline()) + self.assertEquals("", stdout.readline()) + self.assertEquals("This is on stderr.\n", stderr.readline()) + self.assertEquals("", stderr.readline()) stdin.close() stdout.close() @@ -140,14 +146,15 @@ class GSSAuthTest(unittest.TestCase): Verify that Paramiko can handle SSHv2 GSS-API / SSPI authentication (gssapi-with-mic) in client and server mode. """ - self._test_connection(allow_agent=False, - look_for_keys=False) + self._test_connection(allow_agent=False, look_for_keys=False) def test_2_auth_trickledown(self): """ Failed gssapi-with-mic auth doesn't prevent subsequent key auth from succeeding """ self.hostname = "this_host_does_not_exists_and_causes_a_GSSAPI-exception" - self._test_connection(key_filename=[_support('test_rsa.key')], - allow_agent=False, - look_for_keys=False) + self._test_connection( + key_filename=[_support("test_rsa.key")], + allow_agent=False, + look_for_keys=False, + ) diff --git a/tests/test_transport.py b/tests/test_transport.py index 9474acfc..e09c5e92 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -32,14 +32,26 @@ from hashlib import sha1 import unittest from paramiko import ( - Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, SSHException, - ChannelException, Packetizer, Channel, + Transport, + SecurityOptions, + ServerInterface, + RSAKey, + DSSKey, + SSHException, + ChannelException, + Packetizer, + Channel, ) from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko.common import ( - MSG_KEXINIT, cMSG_CHANNEL_WINDOW_ADJUST, MIN_PACKET_SIZE, MIN_WINDOW_SIZE, - MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, + MSG_KEXINIT, + cMSG_CHANNEL_WINDOW_ADJUST, + MIN_PACKET_SIZE, + MIN_WINDOW_SIZE, + MAX_WINDOW_SIZE, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_PACKET_SIZE, ) from paramiko.py3compat import bytes from paramiko.message import Message @@ -61,28 +73,28 @@ Maybe. """ -class NullServer (ServerInterface): +class NullServer(ServerInterface): paranoid_did_password = False paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support('test_dss.key')) + paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) def get_allowed_auths(self, username): - if username == 'slowdive': - return 'publickey,password' - return 'publickey' + if username == "slowdive": + return "publickey,password" + return "publickey" def check_auth_password(self, username, password): - if (username == 'slowdive') and (password == 'pygmalion'): + if (username == "slowdive") and (password == "pygmalion"): return AUTH_SUCCESSFUL return AUTH_FAILED def check_channel_request(self, kind, chanid): - if kind == 'bogus': + if kind == "bogus": return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - if command != b'yes': + if command != b"yes": return False return True @@ -95,9 +107,16 @@ class NullServer (ServerInterface): # tho that's only supposed to occur if the request cannot be served. # For now, leaving that the default unless test supplies specific # 'acceptable' request kind - return kind == 'acceptable' - - def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + return kind == "acceptable" + + def check_channel_x11_request( + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): self._x11_single_connection = single_connection self._x11_auth_protocol = auth_protocol self._x11_auth_cookie = auth_cookie @@ -106,7 +125,7 @@ class NullServer (ServerInterface): def check_port_forward_request(self, addr, port): self._listen = socket.socket() - self._listen.bind(('127.0.0.1', 0)) + self._listen.bind(("127.0.0.1", 0)) self._listen.listen(1) return self._listen.getsockname()[1] @@ -120,6 +139,7 @@ class NullServer (ServerInterface): class TransportTest(unittest.TestCase): + def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -134,9 +154,9 @@ class TransportTest(unittest.TestCase): self.sockc.close() def setup_test_server( - self, client_options=None, server_options=None, connect_kwargs=None, + self, client_options=None, server_options=None, connect_kwargs=None ): - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) @@ -152,8 +172,8 @@ class TransportTest(unittest.TestCase): if connect_kwargs is None: connect_kwargs = dict( hostkey=public_host_key, - username='slowdive', - password='pygmalion', + username="slowdive", + password="pygmalion", ) self.tc.connect(**connect_kwargs) event.wait(1.0) @@ -163,11 +183,11 @@ class TransportTest(unittest.TestCase): def test_1_security_options(self): o = self.tc.get_security_options() self.assertEqual(type(o), SecurityOptions) - self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers) - o.ciphers = ('aes256-cbc', 'blowfish-cbc') - self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers) + self.assertTrue(("aes256-cbc", "blowfish-cbc") != o.ciphers) + o.ciphers = ("aes256-cbc", "blowfish-cbc") + self.assertEqual(("aes256-cbc", "blowfish-cbc"), o.ciphers) try: - o.ciphers = ('aes256-cbc', 'made-up-cipher') + o.ciphers = ("aes256-cbc", "made-up-cipher") self.assertTrue(False) except ValueError: pass @@ -188,11 +208,13 @@ class TransportTest(unittest.TestCase): def test_2_compute_key(self): self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 - self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3' + self.tc.H = b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3" self.tc.session_id = self.tc.H - key = self.tc._compute_key('C', 32) - self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', - hexlify(key).upper()) + key = self.tc._compute_key("C", 32) + self.assertEqual( + b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995", + hexlify(key).upper(), + ) def test_3_simple(self): """ @@ -200,7 +222,7 @@ class TransportTest(unittest.TestCase): loopback sockets. this is hardly "simple" but it's simpler than the later tests. :) """ - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -211,13 +233,14 @@ class TransportTest(unittest.TestCase): self.assertEqual(False, self.tc.is_authenticated()) self.assertEqual(False, self.ts.is_authenticated()) self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual('slowdive', self.tc.get_username()) - self.assertEqual('slowdive', self.ts.get_username()) + self.assertEqual("slowdive", self.tc.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) self.assertEqual(True, self.tc.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated()) @@ -225,7 +248,7 @@ class TransportTest(unittest.TestCase): """ verify that a long banner doesn't mess up the handshake. """ - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -233,8 +256,9 @@ class TransportTest(unittest.TestCase): self.assertTrue(not event.is_set()) self.socks.send(LONG_BANNER) self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) event.wait(1.0) self.assertTrue(event.is_set()) self.assertTrue(self.ts.is_active()) @@ -244,12 +268,14 @@ class TransportTest(unittest.TestCase): verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ + def force_algorithms(options): - options.ciphers = ('aes256-cbc',) - options.digests = ('hmac-md5-96',) + options.ciphers = ("aes256-cbc",) + options.digests = ("hmac-md5-96",) + self.setup_test_server(client_options=force_algorithms) - self.assertEqual('aes256-cbc', self.tc.local_cipher) - self.assertEqual('aes256-cbc', self.tc.remote_cipher) + self.assertEqual("aes256-cbc", self.tc.local_cipher) + self.assertEqual("aes256-cbc", self.tc.remote_cipher) self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) @@ -263,10 +289,10 @@ class TransportTest(unittest.TestCase): verify that the keepalive will be sent. """ self.setup_test_server() - self.assertEqual(None, getattr(self.server, '_global_request', None)) + self.assertEqual(None, getattr(self.server, "_global_request", None)) self.tc.set_keepalive(1) time.sleep(2) - self.assertEqual('keepalive@lag.net', self.server._global_request) + self.assertEqual("keepalive@lag.net", self.server._global_request) def test_6_exec_command(self): """ @@ -277,39 +303,41 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() schan = self.ts.accept(1.0) try: - chan.exec_command(b'command contains \xfc and is not a valid UTF-8 string') + chan.exec_command( + b"command contains \xfc and is not a valid UTF-8 string" + ) self.assertTrue(False) except SSHException: pass chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) f = chan.makefile_stderr() - self.assertEqual('This is on stderr.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) # now try it with combined stdout/stderr chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) - schan.send('Hello there.\n') - schan.send_stderr('This is on stderr.\n') + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") schan.close() chan.set_combine_stderr(True) f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('This is on stderr.\n', f.readline()) - self.assertEqual('', f.readline()) - + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) + def test_6a_channel_can_be_used_as_context_manager(self): """ verify that exec_command() does something reasonable. @@ -318,13 +346,13 @@ class TransportTest(unittest.TestCase): with self.tc.open_session() as chan: with self.ts.accept(1.0) as schan: - chan.exec_command('yes') - schan.send('Hello there.\n') + chan.exec_command("yes") + schan.send("Hello there.\n") schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) def test_7_invoke_shell(self): """ @@ -334,11 +362,11 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) - chan.send('communist j. cat\n') + chan.send("communist j. cat\n") f = schan.makefile() - self.assertEqual('communist j. cat\n', f.readline()) + self.assertEqual("communist j. cat\n", f.readline()) chan.close() - self.assertEqual('', f.readline()) + self.assertEqual("", f.readline()) def test_8_channel_exception(self): """ @@ -346,8 +374,8 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() try: - chan = self.tc.open_channel('bogus') - self.fail('expected exception') + chan = self.tc.open_channel("bogus") + self.fail("expected exception") except ChannelException as e: self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) @@ -359,8 +387,8 @@ class TransportTest(unittest.TestCase): chan = self.tc.open_session() schan = self.ts.accept(1.0) - chan.exec_command('yes') - schan.send('Hello there.\n') + chan.exec_command("yes") + schan.send("Hello there.\n") self.assertTrue(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() @@ -369,8 +397,8 @@ class TransportTest(unittest.TestCase): schan.close() f = chan.makefile() - self.assertEqual('Hello there.\n', f.readline()) - self.assertEqual('', f.readline()) + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) count = 0 while not chan.exit_status_ready(): time.sleep(0.1) @@ -395,7 +423,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - schan.send('hello\n') + schan.send("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): @@ -407,7 +435,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - self.assertEqual(b'hello\n', chan.recv(6)) + self.assertEqual(b"hello\n", chan.recv(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) @@ -442,12 +470,12 @@ class TransportTest(unittest.TestCase): self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) self.assertEqual(self.tc.H, self.tc.session_id) for i in range(20): - chan.send('x' * 1024) + chan.send("x" * 1024) chan.close() # allow a few seconds for the rekeying to complete @@ -463,18 +491,20 @@ class TransportTest(unittest.TestCase): """ verify that zlib compression is basically working. """ + def force_compression(o): - o.compression = ('zlib',) + o.compression = ("zlib",) + self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes - chan.send('x' * 1024) + chan.send("x" * 1024) bytes2 = self.tc.packetizer._Packetizer__sent_bytes - block_size = self.tc._cipher_info[self.tc.local_cipher]['block-size'] - mac_size = self.tc._mac_info[self.tc.local_mac]['size'] + block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"] + mac_size = self.tc._mac_info[self.tc.local_mac]["size"] # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) self.assertTrue(bytes2 - bytes < 1024) self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) @@ -488,29 +518,32 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] + def handler(c, addr_port): addr, port = addr_port requested.append((addr, port)) self.tc._queue_incoming_channel(c) - self.assertEqual(None, getattr(self.server, '_x11_screen_number', None)) + self.assertEqual( + None, getattr(self.server, "_x11_screen_number", None) + ) cookie = chan.request_x11(0, single_connection=True, handler=handler) self.assertEqual(0, self.server._x11_screen_number) - self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) + self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol) self.assertEqual(cookie, self.server._x11_auth_cookie) self.assertEqual(True, self.server._x11_single_connection) - x11_server = self.ts.open_x11_channel(('localhost', 6093)) + x11_server = self.ts.open_x11_channel(("localhost", 6093)) x11_client = self.tc.accept() - self.assertEqual('localhost', requested[0][0]) + self.assertEqual("localhost", requested[0][0]) self.assertEqual(6093, requested[0][1]) - x11_server.send('hello') - self.assertEqual(b'hello', x11_client.recv(5)) + x11_server.send("hello") + self.assertEqual(b"hello", x11_client.recv(5)) x11_server.close() x11_client.close() @@ -524,33 +557,36 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) requested = [] + def handler(c, origin_addr_port, server_addr_port): requested.append(origin_addr_port) requested.append(server_addr_port) self.tc._queue_incoming_channel(c) - port = self.tc.request_port_forward('127.0.0.1', 0, handler) + port = self.tc.request_port_forward("127.0.0.1", 0, handler) self.assertEqual(port, self.server._listen.getsockname()[1]) cs = socket.socket() - cs.connect(('127.0.0.1', port)) + cs.connect(("127.0.0.1", port)) ss, _ = self.server._listen.accept() - sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) + sch = self.ts.open_forwarded_tcpip_channel( + ss.getsockname(), ss.getpeername() + ) cch = self.tc.accept() - sch.send('hello') - self.assertEqual(b'hello', cch.recv(5)) + sch.send("hello") + self.assertEqual(b"hello", cch.recv(5)) sch.close() cch.close() ss.close() cs.close() # now cancel it. - self.tc.cancel_port_forward('127.0.0.1', port) + self.tc.cancel_port_forward("127.0.0.1", port) self.assertTrue(self.server._listen is None) def test_F_port_forwarding(self): @@ -560,27 +596,29 @@ class TransportTest(unittest.TestCase): """ self.setup_test_server() chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) # open a port on the "server" that the client will ask to forward to. greeting_server = socket.socket() - greeting_server.bind(('127.0.0.1', 0)) + greeting_server.bind(("127.0.0.1", 0)) greeting_server.listen(1) greeting_port = greeting_server.getsockname()[1] - cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000)) + cs = self.tc.open_channel( + "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000) + ) sch = self.ts.accept(1.0) cch = socket.socket() cch.connect(self.server._tcpip_dest) ss, _ = greeting_server.accept() - ss.send(b'Hello!\n') + ss.send(b"Hello!\n") ss.close() sch.send(cch.recv(8192)) sch.close() - self.assertEqual(b'Hello!\n', cs.recv(7)) + self.assertEqual(b"Hello!\n", cs.recv(7)) cs.close() def test_G_stderr_select(self): @@ -599,7 +637,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - schan.send_stderr('hello\n') + schan.send_stderr("hello\n") # something should be ready now (give it 1 second to appear) for i in range(10): @@ -611,7 +649,7 @@ class TransportTest(unittest.TestCase): self.assertEqual([], w) self.assertEqual([], e) - self.assertEqual(b'hello\n', chan.recv_stderr(6)) + self.assertEqual(b"hello\n", chan.recv_stderr(6)) # and, should be dead again now r, w, e = select.select([chan], [], [], 0.1) @@ -633,8 +671,8 @@ class TransportTest(unittest.TestCase): self.assertEqual(chan.send_ready(), True) total = 0 - K = '*' * 1024 - limit = 1+(64 * 2 ** 15) + K = "*" * 1024 + limit = 1 + (64 * 2 ** 15) while total < limit: chan.send(K) total += len(K) @@ -696,8 +734,11 @@ class TransportTest(unittest.TestCase): # expires, a deadlock is assumed. class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): - threading.Thread.__init__(self, None, None, self.__class__.__name__) + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) self.setDaemon(True) self.chan = chan self.iterations = iterations @@ -707,19 +748,22 @@ class TransportTest(unittest.TestCase): def run(self): try: - for i in range(1, 1+self.iterations): + for i in range(1, 1 + self.iterations): if self.done_event.is_set(): break self.watchdog_event.set() - #print i, "SEND" + # print i, "SEND" self.chan.send("x" * 2048) finally: self.done_event.set() self.watchdog_event.set() class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): - threading.Thread.__init__(self, None, None, self.__class__.__name__) + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) self.setDaemon(True) self.chan = chan self.done_event = done_event @@ -742,7 +786,7 @@ class TransportTest(unittest.TestCase): self.ts.packetizer.REKEY_BYTES = 2048 chan = self.tc.open_session() - chan.exec_command('yes') + chan.exec_command("yes") schan = self.ts.accept(1.0) # Monkey patch the client's Transport._handler_table so that the client @@ -751,21 +795,23 @@ class TransportTest(unittest.TestCase): # on a real MSG_CHANNEL_WINDOW_ADJUST message. self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + def _negotiate_keys_wrapper(self, m): - if self.local_kex_init is None: # Remote side sent KEXINIT + if self.local_kex_init is None: # Remote side sent KEXINIT # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # before responding to the incoming MSG_KEXINIT. m2 = Message() m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST) m2.add_int(chan.remote_chanid) - m2.add_int(1) # bytes to add + m2.add_int(1) # bytes to add self._send_message(m2) return _negotiate_keys(self, m) + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper # Parameters for the test - iterations = 500 # The deadlock does not happen every time, but it - # should after many iterations. + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. timeout = 5 # This event is set when the test is completed @@ -807,18 +853,22 @@ class TransportTest(unittest.TestCase): """ verify that we conform to the rfc of packet and window sizes. """ - for val, correct in [(4095, MIN_PACKET_SIZE), - (None, DEFAULT_MAX_PACKET_SIZE), - (2**32, MAX_WINDOW_SIZE)]: + for val, correct in [ + (4095, MIN_PACKET_SIZE), + (None, DEFAULT_MAX_PACKET_SIZE), + (2 ** 32, MAX_WINDOW_SIZE), + ]: self.assertEqual(self.tc._sanitize_packet_size(val), correct) def test_K_sanitze_window_size(self): """ verify that we conform to the rfc of packet and window sizes. """ - for val, correct in [(32767, MIN_WINDOW_SIZE), - (None, DEFAULT_WINDOW_SIZE), - (2**32, MAX_WINDOW_SIZE)]: + for val, correct in [ + (32767, MIN_WINDOW_SIZE), + (None, DEFAULT_WINDOW_SIZE), + (2 ** 32, MAX_WINDOW_SIZE), + ]: self.assertEqual(self.tc._sanitize_window_size(val), correct) @slow @@ -834,15 +884,17 @@ class TransportTest(unittest.TestCase): # (Doing this on the server's transport *sounds* more 'correct' but # actually doesn't work nearly as well for whatever reason.) class SlowPacketizer(Packetizer): + def read_message(self): time.sleep(1) return super(SlowPacketizer, self).read_message() + # NOTE: prettttty sure since the replaced .packetizer Packetizer is now # no longer doing anything with its copy of the socket...everything'll # be fine. Even tho it's a bit squicky. self.tc.packetizer = SlowPacketizer(self.tc.sock) # Continue with regular test red tape. - host_key = RSAKey.from_private_key_file(_support('test_rsa.key')) + host_key = RSAKey.from_private_key_file(_support("test_rsa.key")) public_host_key = RSAKey(data=host_key.asbytes()) self.ts.add_server_key(host_key) event = threading.Event() @@ -850,10 +902,13 @@ class TransportTest(unittest.TestCase): self.assertTrue(not event.is_set()) self.tc.handshake_timeout = 0.000000000001 self.ts.start_server(event, server) - self.assertRaises(EOFError, self.tc.connect, - hostkey=public_host_key, - username='slowdive', - password='pygmalion') + self.assertRaises( + EOFError, + self.tc.connect, + hostkey=public_host_key, + username="slowdive", + password="pygmalion", + ) def test_M_select_after_close(self): """ @@ -894,13 +949,13 @@ class TransportTest(unittest.TestCase): expected = text.encode("utf-8") self.assertEqual(sfile.read(len(expected)), expected) - @needs_builtin('buffer') + @needs_builtin("buffer") def test_channel_send_buffer(self): """ verify sending buffer instances to a channel """ self.setup_test_server() - data = 3 * b'some test data\n whole' + data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: @@ -917,13 +972,13 @@ class TransportTest(unittest.TestCase): chan.sendall(buffer(data)) self.assertEqual(sfile.read(len(data)), data) - @needs_builtin('memoryview') + @needs_builtin("memoryview") def test_channel_send_memoryview(self): """ verify sending memoryview instances to a channel """ self.setup_test_server() - data = 3 * b'some test data\n whole' + data = 3 * b"some test data\n whole" with self.tc.open_session() as chan: schan = self.ts.accept(1.0) if schan is None: @@ -934,7 +989,7 @@ class TransportTest(unittest.TestCase): sent = 0 view = memoryview(data) while sent < len(view): - sent += chan.send(view[sent:sent+8]) + sent += chan.send(view[sent:sent + 8]) self.assertEqual(sfile.read(len(data)), data) # sendall() accepts a memoryview instance @@ -954,7 +1009,7 @@ class TransportTest(unittest.TestCase): self.setup_test_server(connect_kwargs={}) # NOTE: this dummy global request kind would normally pass muster # from the test server. - self.tc.global_request('acceptable') + self.tc.global_request("acceptable") # Global requests never raise exceptions, even on failure (not sure why # this was the original design...ugh.) Best we can do to tell failure # happened is that the client transport's global_response was set back @@ -969,7 +1024,7 @@ class TransportTest(unittest.TestCase): # an exception on the client side, unlike the general case...) self.setup_test_server(connect_kwargs={}) try: - self.tc.request_port_forward('localhost', 1234) + self.tc.request_port_forward("localhost", 1234) except SSHException as e: assert "forwarding request denied" in str(e) else: diff --git a/tests/test_util.py b/tests/test_util.py index 90473f43..6431b9c1 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -67,53 +67,70 @@ from paramiko import * class UtilTest(unittest.TestCase): + def test_import(self): """ verify that all the classes can be imported from paramiko. """ symbols = list(globals().keys()) - self.assertTrue('Transport' in symbols) - self.assertTrue('SSHClient' in symbols) - self.assertTrue('MissingHostKeyPolicy' in symbols) - self.assertTrue('AutoAddPolicy' in symbols) - self.assertTrue('RejectPolicy' in symbols) - self.assertTrue('WarningPolicy' in symbols) - self.assertTrue('SecurityOptions' in symbols) - self.assertTrue('SubsystemHandler' in symbols) - self.assertTrue('Channel' in symbols) - self.assertTrue('RSAKey' in symbols) - self.assertTrue('DSSKey' in symbols) - self.assertTrue('Message' in symbols) - self.assertTrue('SSHException' in symbols) - self.assertTrue('AuthenticationException' in symbols) - self.assertTrue('PasswordRequiredException' in symbols) - self.assertTrue('BadAuthenticationType' in symbols) - self.assertTrue('ChannelException' in symbols) - self.assertTrue('SFTP' in symbols) - self.assertTrue('SFTPFile' in symbols) - self.assertTrue('SFTPHandle' in symbols) - self.assertTrue('SFTPClient' in symbols) - self.assertTrue('SFTPServer' in symbols) - self.assertTrue('SFTPError' in symbols) - self.assertTrue('SFTPAttributes' in symbols) - self.assertTrue('SFTPServerInterface' in symbols) - self.assertTrue('ServerInterface' in symbols) - self.assertTrue('BufferedFile' in symbols) - self.assertTrue('Agent' in symbols) - self.assertTrue('AgentKey' in symbols) - self.assertTrue('HostKeys' in symbols) - self.assertTrue('SSHConfig' in symbols) - self.assertTrue('util' in symbols) + self.assertTrue("Transport" in symbols) + self.assertTrue("SSHClient" in symbols) + self.assertTrue("MissingHostKeyPolicy" in symbols) + self.assertTrue("AutoAddPolicy" in symbols) + self.assertTrue("RejectPolicy" in symbols) + self.assertTrue("WarningPolicy" in symbols) + self.assertTrue("SecurityOptions" in symbols) + self.assertTrue("SubsystemHandler" in symbols) + self.assertTrue("Channel" in symbols) + self.assertTrue("RSAKey" in symbols) + self.assertTrue("DSSKey" in symbols) + self.assertTrue("Message" in symbols) + self.assertTrue("SSHException" in symbols) + self.assertTrue("AuthenticationException" in symbols) + self.assertTrue("PasswordRequiredException" in symbols) + self.assertTrue("BadAuthenticationType" in symbols) + self.assertTrue("ChannelException" in symbols) + self.assertTrue("SFTP" in symbols) + self.assertTrue("SFTPFile" in symbols) + self.assertTrue("SFTPHandle" in symbols) + self.assertTrue("SFTPClient" in symbols) + self.assertTrue("SFTPServer" in symbols) + self.assertTrue("SFTPError" in symbols) + self.assertTrue("SFTPAttributes" in symbols) + self.assertTrue("SFTPServerInterface" in symbols) + self.assertTrue("ServerInterface" in symbols) + self.assertTrue("BufferedFile" in symbols) + self.assertTrue("Agent" in symbols) + self.assertTrue("AgentKey" in symbols) + self.assertTrue("HostKeys" in symbols) + self.assertTrue("SSHConfig" in symbols) + self.assertTrue("util" in symbols) def test_parse_config(self): global test_config_file f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEqual(config._config, - [{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}}, - {'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}}, - {'host': ['*'], 'config': {'crazy': 'something dumb'}}, - {'host': ['spoo.example.com'], 'config': {'crazy': 'something else'}}]) + self.assertEqual( + config._config, + [ + {"host": ["*"], "config": {}}, + { + "host": ["*"], + "config": { + "identityfile": ["~/.ssh/id_rsa"], "user": "robey" + }, + }, + { + "host": ["*.example.com"], + "config": {"user": "bjork", "port": "3333"}, + }, + {"host": ["*"], "config": {"crazy": "something dumb"}}, + { + "host": ["spoo.example.com"], + "config": {"crazy": "something else"}, + }, + ], + ) def test_host_config(self): global test_config_file @@ -121,44 +138,57 @@ class UtilTest(unittest.TestCase): config = paramiko.util.parse_ssh_config(f) for host, values in { - 'irc.danger.com': {'crazy': 'something dumb', - 'hostname': 'irc.danger.com', - 'user': 'robey'}, - 'irc.example.com': {'crazy': 'something dumb', - 'hostname': 'irc.example.com', - 'user': 'robey', - 'port': '3333'}, - 'spoo.example.com': {'crazy': 'something dumb', - 'hostname': 'spoo.example.com', - 'user': 'robey', - 'port': '3333'} + "irc.danger.com": { + "crazy": "something dumb", + "hostname": "irc.danger.com", + "user": "robey", + }, + "irc.example.com": { + "crazy": "something dumb", + "hostname": "irc.example.com", + "user": "robey", + "port": "3333", + }, + "spoo.example.com": { + "crazy": "something dumb", + "hostname": "spoo.example.com", + "user": "robey", + "port": "3333", + }, }.items(): - values = dict(values, + values = dict( + values, hostname=host, - identityfile=[os.path.expanduser("~/.ssh/id_rsa")] + identityfile=[os.path.expanduser("~/.ssh/id_rsa")], ) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_generate_key_bytes(self): - x = paramiko.util.generate_key_bytes(sha1, b'ABCDEFGH', 'This is my secret passphrase.', 64) - hex = ''.join(['%02x' % byte_ord(c) for c in x]) - self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b') + x = paramiko.util.generate_key_bytes( + sha1, b"ABCDEFGH", "This is my secret passphrase.", 64 + ) + hex = "".join(["%02x" % byte_ord(c) for c in x]) + self.assertEqual( + hex, + "9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b", + ) def test_host_keys(self): - with open('hostfile.temp', 'w') as f: + with open("hostfile.temp", "w") as f: f.write(test_hosts_file) try: - hostdict = paramiko.util.load_host_keys('hostfile.temp') + hostdict = paramiko.util.load_host_keys("hostfile.temp") self.assertEqual(2, len(hostdict)) self.assertEqual(1, len(list(hostdict.values())[0])) self.assertEqual(1, len(list(hostdict.values())[1])) - fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() - self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp) + fp = hexlify( + hostdict["secure.example.com"]["ssh-rsa"].get_fingerprint() + ).upper() + self.assertEqual(b"E6684DB30E109B67B70FF1DC5C7F1363", fp) finally: - os.unlink('hostfile.temp') + os.unlink("hostfile.temp") def test_host_config_expose_issue_33(self): test_config_file = """ @@ -173,36 +203,44 @@ Host * """ f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - host = 'www13.example.com' + host = "www13.example.com" self.assertEqual( paramiko.util.lookup_ssh_host_config(host, config), - {'hostname': host, 'port': '22'} + {"hostname": host, "port": "22"}, ) def test_eintr_retry(self): - self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo')) + self.assertEqual("foo", paramiko.util.retry_on_signal(lambda: "foo")) # Variables that are set by raises_intr intr_errors_remaining = [3] call_count = [0] + def raises_intr(): call_count[0] += 1 if intr_errors_remaining[0] > 0: intr_errors_remaining[0] -= 1 - raise IOError(errno.EINTR, 'file', 'interrupted system call') + raise IOError(errno.EINTR, "file", "interrupted system call") + self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None) self.assertEqual(0, intr_errors_remaining[0]) self.assertEqual(4, call_count[0]) def raises_ioerror_not_eintr(): - raise IOError(errno.ENOENT, 'file', 'file not found') - self.assertRaises(IOError, - lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr)) + raise IOError(errno.ENOENT, "file", "file not found") + + self.assertRaises( + IOError, + lambda: paramiko.util.retry_on_signal(raises_ioerror_not_eintr), + ) def raises_other_exception(): - raise AssertionError('foo') - self.assertRaises(AssertionError, - lambda: paramiko.util.retry_on_signal(raises_other_exception)) + raise AssertionError("foo") + + self.assertRaises( + AssertionError, + lambda: paramiko.util.retry_on_signal(raises_other_exception), + ) def test_proxycommand_config_equals_parsing(self): """ @@ -217,17 +255,18 @@ Host equals-delimited """ f = StringIO(conf) config = paramiko.util.parse_ssh_config(f) - for host in ('space-delimited', 'equals-delimited'): + for host in ("space-delimited", "equals-delimited"): self.assertEqual( - host_config(host, config)['proxycommand'], - 'foo bar=biz baz' + host_config(host, config)["proxycommand"], "foo bar=biz baz" ) def test_proxycommand_interpolation(self): """ ProxyCommand should perform interpolation on the value """ - config = paramiko.util.parse_ssh_config(StringIO(""" + config = paramiko.util.parse_ssh_config( + StringIO( + """ Host specific Port 37 ProxyCommand host %h port %p lol @@ -238,28 +277,32 @@ Host portonly Host * Port 25 ProxyCommand host %h port %p -""")) +""" + ) + ) for host, val in ( - ('foo.com', "host foo.com port 25"), - ('specific', "host specific port 37 lol"), - ('portonly', "host portonly port 155"), + ("foo.com", "host foo.com port 25"), + ("specific", "host specific port 37 lol"), + ("portonly", "host portonly port 155"), ): - self.assertEqual( - host_config(host, config)['proxycommand'], - val - ) + self.assertEqual(host_config(host, config)["proxycommand"], val) def test_proxycommand_tilde_expansion(self): """ Tilde (~) should be expanded inside ProxyCommand """ - config = paramiko.util.parse_ssh_config(StringIO(""" + config = paramiko.util.parse_ssh_config( + StringIO( + """ Host test ProxyCommand ssh -F ~/.ssh/test_config bastion nc %h %p -""")) +""" + ) + ) self.assertEqual( - 'ssh -F %s/.ssh/test_config bastion nc test 22' % os.path.expanduser('~'), - host_config('test', config)['proxycommand'] + "ssh -F %s/.ssh/test_config bastion nc test 22" + % os.path.expanduser("~"), + host_config("test", config)["proxycommand"], ) def test_host_config_test_negation(self): @@ -278,10 +321,10 @@ Host * """ f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - host = 'www13.example.com' + host = "www13.example.com" self.assertEqual( paramiko.util.lookup_ssh_host_config(host, config), - {'hostname': host, 'port': '8080'} + {"hostname": host, "port": "8080"}, ) def test_host_config_test_proxycommand(self): @@ -296,20 +339,24 @@ Host proxy-without-equal-divisor ProxyCommand foo=bar:%h-%p """ for host, values in { - 'proxy-with-equal-divisor-and-space' :{'hostname': 'proxy-with-equal-divisor-and-space', - 'proxycommand': 'foo=bar'}, - 'proxy-with-equal-divisor-and-no-space':{'hostname': 'proxy-with-equal-divisor-and-no-space', - 'proxycommand': 'foo=bar'}, - 'proxy-without-equal-divisor' :{'hostname': 'proxy-without-equal-divisor', - 'proxycommand': - 'foo=bar:proxy-without-equal-divisor-22'} + "proxy-with-equal-divisor-and-space": { + "hostname": "proxy-with-equal-divisor-and-space", + "proxycommand": "foo=bar", + }, + "proxy-with-equal-divisor-and-no-space": { + "hostname": "proxy-with-equal-divisor-and-no-space", + "proxycommand": "foo=bar", + }, + "proxy-without-equal-divisor": { + "hostname": "proxy-without-equal-divisor", + "proxycommand": "foo=bar:proxy-without-equal-divisor-22", + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_host_config_test_identityfile(self): @@ -327,19 +374,21 @@ Host dsa2* IdentityFile id_dsa22 """ for host, values in { - 'foo' :{'hostname': 'foo', - 'identityfile': ['id_dsa0', 'id_dsa1']}, - 'dsa2' :{'hostname': 'dsa2', - 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa2', 'id_dsa22']}, - 'dsa22' :{'hostname': 'dsa22', - 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']} + "foo": {"hostname": "foo", "identityfile": ["id_dsa0", "id_dsa1"]}, + "dsa2": { + "hostname": "dsa2", + "identityfile": ["id_dsa0", "id_dsa1", "id_dsa2", "id_dsa22"], + }, + "dsa22": { + "hostname": "dsa22", + "identityfile": ["id_dsa0", "id_dsa1", "id_dsa22"], + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_config_addressfamily_and_lazy_fqdn(self): @@ -351,7 +400,9 @@ AddressFamily inet IdentityFile something_%l_using_fqdn """ config = paramiko.util.parse_ssh_config(StringIO(test_config)) - assert config.lookup('meh') # will die during lookup() if bug regresses + assert config.lookup( + "meh" + ) # will die during lookup() if bug regresses def test_clamp_value(self): self.assertEqual(32768, paramiko.util.clamp_value(32767, 32768, 32769)) @@ -367,7 +418,9 @@ IdentityFile something_%l_using_fqdn def test_get_hostnames(self): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEqual(config.get_hostnames(), {'*', '*.example.com', 'spoo.example.com'}) + self.assertEqual( + config.get_hostnames(), {"*", "*.example.com", "spoo.example.com"} + ) def test_quoted_host_names(self): test_config_file = """\ @@ -384,27 +437,23 @@ Host param4 "p a r" "p" "par" para Port 4444 """ res = { - 'param pam': {'hostname': 'param pam', 'port': '1111'}, - 'param': {'hostname': 'param', 'port': '1111'}, - 'pam': {'hostname': 'pam', 'port': '1111'}, - - 'param2': {'hostname': 'param2', 'port': '2222'}, - - 'param3': {'hostname': 'param3', 'port': '3333'}, - 'parara': {'hostname': 'parara', 'port': '3333'}, - - 'param4': {'hostname': 'param4', 'port': '4444'}, - 'p a r': {'hostname': 'p a r', 'port': '4444'}, - 'p': {'hostname': 'p', 'port': '4444'}, - 'par': {'hostname': 'par', 'port': '4444'}, - 'para': {'hostname': 'para', 'port': '4444'}, + "param pam": {"hostname": "param pam", "port": "1111"}, + "param": {"hostname": "param", "port": "1111"}, + "pam": {"hostname": "pam", "port": "1111"}, + "param2": {"hostname": "param2", "port": "2222"}, + "param3": {"hostname": "param3", "port": "3333"}, + "parara": {"hostname": "parara", "port": "3333"}, + "param4": {"hostname": "param4", "port": "4444"}, + "p a r": {"hostname": "p a r", "port": "4444"}, + "p": {"hostname": "p", "port": "4444"}, + "par": {"hostname": "par", "port": "4444"}, + "para": {"hostname": "para", "port": "4444"}, } f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) for host, values in res.items(): self.assertEquals( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_quoted_params_in_config(self): @@ -420,52 +469,44 @@ Host param3 parara IdentityFile "test rsa key" """ res = { - 'param pam': {'hostname': 'param pam', 'identityfile': ['id_rsa']}, - 'param': {'hostname': 'param', 'identityfile': ['id_rsa']}, - 'pam': {'hostname': 'pam', 'identityfile': ['id_rsa']}, - - 'param2': {'hostname': 'param2', 'identityfile': ['test rsa key']}, - - 'param3': {'hostname': 'param3', 'identityfile': ['id_rsa', 'test rsa key']}, - 'parara': {'hostname': 'parara', 'identityfile': ['id_rsa', 'test rsa key']}, + "param pam": {"hostname": "param pam", "identityfile": ["id_rsa"]}, + "param": {"hostname": "param", "identityfile": ["id_rsa"]}, + "pam": {"hostname": "pam", "identityfile": ["id_rsa"]}, + "param2": {"hostname": "param2", "identityfile": ["test rsa key"]}, + "param3": { + "hostname": "param3", + "identityfile": ["id_rsa", "test rsa key"], + }, + "parara": { + "hostname": "parara", + "identityfile": ["id_rsa", "test rsa key"], + }, } f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) for host, values in res.items(): self.assertEquals( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_quoted_host_in_config(self): conf = SSHConfig() correct_data = { - 'param': ['param'], - '"param"': ['param'], - - 'param pam': ['param', 'pam'], - '"param" "pam"': ['param', 'pam'], - '"param" pam': ['param', 'pam'], - 'param "pam"': ['param', 'pam'], - - 'param "pam" p': ['param', 'pam', 'p'], - '"param" pam "p"': ['param', 'pam', 'p'], - - '"pa ram"': ['pa ram'], - '"pa ram" pam': ['pa ram', 'pam'], - 'param "p a m"': ['param', 'p a m'], + "param": ["param"], + '"param"': ["param"], + "param pam": ["param", "pam"], + '"param" "pam"': ["param", "pam"], + '"param" pam': ["param", "pam"], + 'param "pam"': ["param", "pam"], + 'param "pam" p': ["param", "pam", "p"], + '"param" pam "p"': ["param", "pam", "p"], + '"pa ram"': ["pa ram"], + '"pa ram" pam': ["pa ram", "pam"], + 'param "p a m"': ["param", "p a m"], } - incorrect_data = [ - 'param"', - '"param', - 'param "pam', - 'param "pam" "p a', - ] + incorrect_data = ['param"', '"param', 'param "pam', 'param "pam" "p a'] for host, values in correct_data.items(): - self.assertEquals( - conf._get_hosts(host), - values - ) + self.assertEquals(conf._get_hosts(host), values) for host in incorrect_data: self.assertRaises(Exception, conf._get_hosts, host) @@ -490,15 +531,18 @@ Host proxycommand-with-equals-none ProxyCommand=None """ for host, values in { - 'proxycommand-standard-none': {'hostname': 'proxycommand-standard-none'}, - 'proxycommand-with-equals-none': {'hostname': 'proxycommand-with-equals-none'} + "proxycommand-standard-none": { + "hostname": "proxycommand-standard-none" + }, + "proxycommand-with-equals-none": { + "hostname": "proxycommand-with-equals-none" + }, }.items(): f = StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) self.assertEqual( - paramiko.util.lookup_ssh_host_config(host, config), - values + paramiko.util.lookup_ssh_host_config(host, config), values ) def test_proxycommand_none_masking(self): @@ -521,12 +565,10 @@ Host * # backwards compatibility reasons in 1.x/2.x) appear completely blank, # as if the host had no ProxyCommand whatsoever. # Threw another unrelated host in there just for sanity reasons. - self.assertFalse('proxycommand' in config.lookup('specific-host')) + self.assertFalse("proxycommand" in config.lookup("specific-host")) self.assertEqual( - config.lookup('other-host')['proxycommand'], - 'other-proxy' + config.lookup("other-host")["proxycommand"], "other-proxy" ) self.assertEqual( - config.lookup('some-random-host')['proxycommand'], - 'default-proxy' + config.lookup("some-random-host")["proxycommand"], "default-proxy" ) |