summaryrefslogtreecommitdiffhomepage
path: root/tests
diff options
context:
space:
mode:
authorJeff Forcier <jeff@bitprophet.org>2023-05-04 13:52:40 -0400
committerJeff Forcier <jeff@bitprophet.org>2023-05-05 12:27:20 -0400
commite22c5ea330814801d8487dc3da347f987bafe5ec (patch)
treecc14df5d632b26c1ef581cfc76063a0cb51e43fc /tests
parent4803d68d3df2dd8d3391293cca9d4c6cbd503135 (diff)
Start consolidating test server nonsense
Diffstat (limited to 'tests')
-rw-r--r--tests/_util.py245
-rw-r--r--tests/test_transport.py198
2 files changed, 264 insertions, 179 deletions
diff --git a/tests/_util.py b/tests/_util.py
index 2f1c5ac2..2bfe314d 100644
--- a/tests/_util.py
+++ b/tests/_util.py
@@ -1,13 +1,29 @@
+from contextlib import contextmanager
from os.path import dirname, realpath, join
import builtins
import os
from pathlib import Path
+import socket
import struct
import sys
import unittest
+from time import sleep
+import threading
import pytest
+from paramiko import (
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ AUTH_FAILED,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ AUTH_SUCCESSFUL,
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ InteractiveQuery,
+ Transport,
+)
from paramiko.ssh_gss import GSS_AUTH_AVAILABLE
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
@@ -17,6 +33,8 @@ from cryptography.hazmat.primitives.asymmetric import padding, rsa
tests_dir = dirname(realpath(__file__))
+from ._loop import LoopSocket
+
def _support(filename):
base = Path(tests_dir)
@@ -176,3 +194,230 @@ def sha1_signing_unsupported():
requires_sha1_signing = unittest.skipIf(
sha1_signing_unsupported(), "SHA-1 signing not supported"
)
+
+_disable_sha2 = dict(
+ disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
+)
+_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
+_disable_sha2_pubkey = dict(
+ disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
+)
+_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
+
+
+unicodey = "\u2022"
+
+
+class TestServer(ServerInterface):
+ paranoid_did_password = False
+ paranoid_did_public_key = False
+ # TODO: make this ed25519 or something else modern? (_is_ this used??)
+ paranoid_key = DSSKey.from_private_key_file(_support("dss.key"))
+
+ def __init__(self, allowed_keys=None):
+ self.allowed_keys = allowed_keys if allowed_keys is not None else []
+
+ def check_channel_request(self, kind, chanid):
+ if kind == "bogus":
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ return OPEN_SUCCEEDED
+
+ def check_channel_exec_request(self, channel, command):
+ if command != b"yes":
+ return False
+ return True
+
+ def check_channel_shell_request(self, channel):
+ return True
+
+ def check_global_request(self, kind, msg):
+ self._global_request = kind
+ # NOTE: for w/e reason, older impl of this returned False always, even
+ # tho that's only supposed to occur if the request cannot be served.
+ # For now, leaving that the default unless test supplies specific
+ # 'acceptable' request kind
+ return kind == "acceptable"
+
+ def check_channel_x11_request(
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
+ self._x11_single_connection = single_connection
+ self._x11_auth_protocol = auth_protocol
+ self._x11_auth_cookie = auth_cookie
+ self._x11_screen_number = screen_number
+ return True
+
+ def check_port_forward_request(self, addr, port):
+ self._listen = socket.socket()
+ self._listen.bind(("127.0.0.1", 0))
+ self._listen.listen(1)
+ return self._listen.getsockname()[1]
+
+ def cancel_port_forward_request(self, addr, port):
+ self._listen.close()
+ self._listen = None
+
+ def check_channel_direct_tcpip_request(self, chanid, origin, destination):
+ self._tcpip_dest = destination
+ return OPEN_SUCCEEDED
+
+ def get_allowed_auths(self, username):
+ if username == "slowdive":
+ return "publickey,password"
+ if username == "paranoid":
+ if (
+ not self.paranoid_did_password
+ and not self.paranoid_did_public_key
+ ):
+ return "publickey,password"
+ elif self.paranoid_did_password:
+ return "publickey"
+ else:
+ return "password"
+ if username == "commie":
+ return "keyboard-interactive"
+ if username == "utf8":
+ return "password"
+ if username == "non-utf8":
+ return "password"
+ return "publickey"
+
+ def check_auth_password(self, username, password):
+ if (username == "slowdive") and (password == "pygmalion"):
+ return AUTH_SUCCESSFUL
+ if (username == "paranoid") and (password == "paranoid"):
+ # 2-part auth (even openssh doesn't support this)
+ self.paranoid_did_password = True
+ if self.paranoid_did_public_key:
+ return AUTH_SUCCESSFUL
+ return AUTH_PARTIALLY_SUCCESSFUL
+ if (username == "utf8") and (password == unicodey):
+ return AUTH_SUCCESSFUL
+ if (username == "non-utf8") and (password == "\xff"):
+ return AUTH_SUCCESSFUL
+ if username == "bad-server":
+ raise Exception("Ack!")
+ if username == "unresponsive-server":
+ sleep(5)
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
+ def check_auth_publickey(self, username, key):
+ if (username == "paranoid") and (key == self.paranoid_key):
+ # 2-part auth
+ self.paranoid_did_public_key = True
+ if self.paranoid_did_password:
+ return AUTH_SUCCESSFUL
+ return AUTH_PARTIALLY_SUCCESSFUL
+ # TODO: make sure all tests incidentally using this to pass, _without
+ # sending a username oops_, get updated somehow - probably via server()
+ # default always injecting a username
+ elif key in self.allowed_keys:
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
+ def check_auth_interactive(self, username, submethods):
+ if username == "commie":
+ self.username = username
+ return InteractiveQuery(
+ "password", "Please enter a password.", ("Password", False)
+ )
+ return AUTH_FAILED
+
+ def check_auth_interactive_response(self, responses):
+ if self.username == "commie":
+ if (len(responses) == 1) and (responses[0] == "cat"):
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
+
+@contextmanager
+def server(
+ hostkey=None,
+ init=None,
+ server_init=None,
+ client_init=None,
+ connect=None,
+ pubkeys=None,
+ catch_error=False,
+ transport_factory=None,
+):
+ """
+ SSH server contextmanager for testing.
+
+ :param hostkey:
+ Host key to use for the server; if None, loads
+ ``rsa.key``.
+ :param init:
+ Default `Transport` constructor kwargs to use for both sides.
+ :param server_init:
+ Extends and/or overrides ``init`` for server transport only.
+ :param client_init:
+ Extends and/or overrides ``init`` for client transport only.
+ :param connect:
+ Kwargs to use for ``connect()`` on the client.
+ :param pubkeys:
+ List of public keys for auth.
+ :param catch_error:
+ Whether to capture connection errors & yield from contextmanager.
+ Necessary for connection_time exception testing.
+ :param transport_factory:
+ Like the same-named param in SSHClient: which Transport class to use.
+ """
+ if init is None:
+ init = {}
+ if server_init is None:
+ server_init = {}
+ if client_init is None:
+ client_init = {}
+ if connect is None:
+ connect = dict(username="slowdive", password="pygmalion")
+ socks = LoopSocket()
+ sockc = LoopSocket()
+ sockc.link(socks)
+ if transport_factory is None:
+ transport_factory = Transport
+ tc = transport_factory(sockc, **dict(init, **client_init))
+ ts = transport_factory(socks, **dict(init, **server_init))
+
+ if hostkey is None:
+ hostkey = RSAKey.from_private_key_file(_support("rsa.key"))
+ ts.add_server_key(hostkey)
+ event = threading.Event()
+ server = TestServer(allowed_keys=pubkeys)
+ assert not event.is_set()
+ assert not ts.is_active()
+ assert tc.get_username() is None
+ assert ts.get_username() is None
+ assert not tc.is_authenticated()
+ assert not ts.is_authenticated()
+
+ err = None
+ # Trap errors and yield instead of raising right away; otherwise callers
+ # cannot usefully deal with problems at connect time which stem from errors
+ # in the server side.
+ try:
+ ts.start_server(event, server)
+ tc.connect(**connect)
+
+ event.wait(1.0)
+ assert event.is_set()
+ assert ts.is_active()
+ assert tc.is_active()
+
+ except Exception as e:
+ if not catch_error:
+ raise
+ err = e
+
+ yield (tc, ts, err) if catch_error else (tc, ts)
+
+ tc.close()
+ ts.close()
+ socks.close()
+ sockc.close()
diff --git a/tests/test_transport.py b/tests/test_transport.py
index d7704af6..ee00830a 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -22,8 +22,6 @@ Some unit tests for the ssh2 protocol in Transport.
from binascii import hexlify
-from contextlib import contextmanager
-import pytest
import select
import socket
import time
@@ -35,18 +33,15 @@ from unittest.mock import Mock
from paramiko import (
AuthHandler,
ChannelException,
- DSSKey,
Packetizer,
RSAKey,
SSHException,
AuthenticationException,
IncompatiblePeer,
SecurityOptions,
- ServerInterface,
Transport,
)
-from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
-from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+from paramiko import OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import (
DEFAULT_MAX_PACKET_SIZE,
DEFAULT_WINDOW_SIZE,
@@ -61,7 +56,18 @@ from paramiko.common import (
)
from paramiko.message import Message
-from ._util import needs_builtin, _support, requires_sha1_signing, slow
+from ._util import (
+ needs_builtin,
+ _support,
+ requires_sha1_signing,
+ slow,
+ server,
+ _disable_sha2,
+ _disable_sha1,
+ _disable_sha2_pubkey,
+ _disable_sha1_pubkey,
+ TestServer as NullServer,
+)
from ._loop import LoopSocket
@@ -78,79 +84,6 @@ Maybe.
"""
-class NullServer(ServerInterface):
- paranoid_did_password = False
- paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support("dss.key"))
-
- def __init__(self, allowed_keys=None):
- self.allowed_keys = allowed_keys if allowed_keys is not None else []
-
- def get_allowed_auths(self, username):
- if username == "slowdive":
- return "publickey,password"
- return "publickey"
-
- def check_auth_password(self, username, password):
- if (username == "slowdive") and (password == "pygmalion"):
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_publickey(self, username, key):
- if key in self.allowed_keys:
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
- def check_channel_request(self, kind, chanid):
- if kind == "bogus":
- return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
- return OPEN_SUCCEEDED
-
- def check_channel_exec_request(self, channel, command):
- if command != b"yes":
- return False
- return True
-
- def check_channel_shell_request(self, channel):
- return True
-
- def check_global_request(self, kind, msg):
- self._global_request = kind
- # NOTE: for w/e reason, older impl of this returned False always, even
- # tho that's only supposed to occur if the request cannot be served.
- # For now, leaving that the default unless test supplies specific
- # 'acceptable' request kind
- return kind == "acceptable"
-
- def check_channel_x11_request(
- self,
- channel,
- single_connection,
- auth_protocol,
- auth_cookie,
- screen_number,
- ):
- self._x11_single_connection = single_connection
- self._x11_auth_protocol = auth_protocol
- self._x11_auth_cookie = auth_cookie
- self._x11_screen_number = screen_number
- return True
-
- def check_port_forward_request(self, addr, port):
- self._listen = socket.socket()
- self._listen.bind(("127.0.0.1", 0))
- self._listen.listen(1)
- return self._listen.getsockname()[1]
-
- def cancel_port_forward_request(self, addr, port):
- self._listen.close()
- self._listen = None
-
- def check_channel_direct_tcpip_request(self, chanid, origin, destination):
- self._tcpip_dest = destination
- return OPEN_SUCCEEDED
-
-
class TransportTest(unittest.TestCase):
def setUp(self):
self.socks = LoopSocket()
@@ -1190,103 +1123,6 @@ class AlgorithmDisablingTests(unittest.TestCase):
assert "zlib" not in compressions
-@contextmanager
-def server(
- hostkey=None,
- init=None,
- server_init=None,
- client_init=None,
- connect=None,
- pubkeys=None,
- catch_error=False,
- transport_factory=None,
-):
- """
- SSH server contextmanager for testing.
-
- :param hostkey:
- Host key to use for the server; if None, loads
- ``rsa.key``.
- :param init:
- Default `Transport` constructor kwargs to use for both sides.
- :param server_init:
- Extends and/or overrides ``init`` for server transport only.
- :param client_init:
- Extends and/or overrides ``init`` for client transport only.
- :param connect:
- Kwargs to use for ``connect()`` on the client.
- :param pubkeys:
- List of public keys for auth.
- :param catch_error:
- Whether to capture connection errors & yield from contextmanager.
- Necessary for connection_time exception testing.
- :param transport_factory:
- Like the same-named param in SSHClient: which Transport class to use.
- """
- if init is None:
- init = {}
- if server_init is None:
- server_init = {}
- if client_init is None:
- client_init = {}
- if connect is None:
- connect = dict(username="slowdive", password="pygmalion")
- socks = LoopSocket()
- sockc = LoopSocket()
- sockc.link(socks)
- if transport_factory is None:
- transport_factory = Transport
- tc = transport_factory(sockc, **dict(init, **client_init))
- ts = transport_factory(socks, **dict(init, **server_init))
-
- if hostkey is None:
- hostkey = RSAKey.from_private_key_file(_support("rsa.key"))
- ts.add_server_key(hostkey)
- event = threading.Event()
- server = NullServer(allowed_keys=pubkeys)
- assert not event.is_set()
- assert not ts.is_active()
- assert tc.get_username() is None
- assert ts.get_username() is None
- assert not tc.is_authenticated()
- assert not ts.is_authenticated()
-
- err = None
- # Trap errors and yield instead of raising right away; otherwise callers
- # cannot usefully deal with problems at connect time which stem from errors
- # in the server side.
- try:
- ts.start_server(event, server)
- tc.connect(**connect)
-
- event.wait(1.0)
- assert event.is_set()
- assert ts.is_active()
- assert tc.is_active()
-
- except Exception as e:
- if not catch_error:
- raise
- err = e
-
- yield (tc, ts, err) if catch_error else (tc, ts)
-
- tc.close()
- ts.close()
- socks.close()
- sockc.close()
-
-
-_disable_sha2 = dict(
- disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
-)
-_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
-_disable_sha2_pubkey = dict(
- disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
-)
-_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
-
-
class TestSHA2SignatureKeyExchange(unittest.TestCase):
# NOTE: these all rely on the default server() hostkey being RSA
# NOTE: these rely on both sides being properly implemented re: agreed-upon
@@ -1351,7 +1187,10 @@ class TestSHA2SignatureKeyExchange(unittest.TestCase):
# the entire preferred-hostkeys structure when given an explicit key as
# a client.)
hostkey = RSAKey.from_private_key_file(_support("rsa.key"))
- with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _):
+ connect = dict(
+ hostkey=hostkey, username="slowdive", password="pygmalion"
+ )
+ with server(hostkey=hostkey, connect=connect) as (tc, _):
assert tc.host_key_type == "rsa-sha2-512"
@@ -1442,7 +1281,7 @@ class TestSHA2SignaturePubkeys(unittest.TestCase):
server_init = dict(_disable_sha2_pubkey, server_sig_algs=False)
with server(
pubkeys=[privkey],
- connect=dict(pkey=privkey),
+ connect=dict(username="slowdive", pkey=privkey),
server_init=server_init,
catch_error=True,
) as (tc, ts, err):
@@ -1455,6 +1294,7 @@ class TestSHA2SignaturePubkeys(unittest.TestCase):
privkey = RSAKey.from_private_key_file(_support("rsa.key"))
with server(
pubkeys=[privkey],
+ # TODO: why is this passing without a username?
connect=dict(pkey=privkey),
init=dict(
disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-256"])