summaryrefslogtreecommitdiffhomepage
path: root/tests/test_transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r--tests/test_transport.py326
1 files changed, 53 insertions, 273 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 4062d767..421c078b 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -22,7 +22,6 @@ Some unit tests for the ssh2 protocol in Transport.
from binascii import hexlify
-from contextlib import contextmanager
import select
import socket
import time
@@ -34,18 +33,16 @@ from unittest.mock import Mock
from paramiko import (
AuthHandler,
ChannelException,
- DSSKey,
Packetizer,
RSAKey,
SSHException,
- AuthenticationException,
IncompatiblePeer,
SecurityOptions,
- ServerInterface,
+ ServiceRequestingTransport,
Transport,
)
-from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
-from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+from paramiko.auth_handler import AuthOnlyHandler
+from paramiko import OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import (
DEFAULT_MAX_PACKET_SIZE,
DEFAULT_WINDOW_SIZE,
@@ -60,8 +57,17 @@ from paramiko.common import (
)
from paramiko.message import Message
-from .util import needs_builtin, _support, requires_sha1_signing, slow
-from .loop import LoopSocket
+from ._util import (
+ needs_builtin,
+ _support,
+ requires_sha1_signing,
+ slow,
+ server,
+ _disable_sha2,
+ _disable_sha1,
+ TestServer as NullServer,
+)
+from ._loop import LoopSocket
LONG_BANNER = """\
@@ -77,80 +83,11 @@ Maybe.
"""
-class NullServer(ServerInterface):
- paranoid_did_password = False
- paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
-
- def __init__(self, allowed_keys=None):
- self.allowed_keys = allowed_keys if allowed_keys is not None else []
-
- def get_allowed_auths(self, username):
- if username == "slowdive":
- return "publickey,password"
- return "publickey"
-
- def check_auth_password(self, username, password):
- if (username == "slowdive") and (password == "pygmalion"):
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_publickey(self, username, key):
- if key in self.allowed_keys:
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
- def check_channel_request(self, kind, chanid):
- if kind == "bogus":
- return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
- return OPEN_SUCCEEDED
-
- def check_channel_exec_request(self, channel, command):
- if command != b"yes":
- return False
- return True
-
- def check_channel_shell_request(self, channel):
- return True
-
- def check_global_request(self, kind, msg):
- self._global_request = kind
- # NOTE: for w/e reason, older impl of this returned False always, even
- # tho that's only supposed to occur if the request cannot be served.
- # For now, leaving that the default unless test supplies specific
- # 'acceptable' request kind
- return kind == "acceptable"
-
- def check_channel_x11_request(
- self,
- channel,
- single_connection,
- auth_protocol,
- auth_cookie,
- screen_number,
- ):
- self._x11_single_connection = single_connection
- self._x11_auth_protocol = auth_protocol
- self._x11_auth_cookie = auth_cookie
- self._x11_screen_number = screen_number
- return True
-
- def check_port_forward_request(self, addr, port):
- self._listen = socket.socket()
- self._listen.bind(("127.0.0.1", 0))
- self._listen.listen(1)
- return self._listen.getsockname()[1]
-
- def cancel_port_forward_request(self, addr, port):
- self._listen.close()
- self._listen = None
-
- def check_channel_direct_tcpip_request(self, chanid, origin, destination):
- self._tcpip_dest = destination
- return OPEN_SUCCEEDED
-
-
class TransportTest(unittest.TestCase):
+ # TODO: this can get nuked once ServiceRequestingTransport becomes the
+ # only Transport, as it has this baked in.
+ _auth_handler_class = AuthHandler
+
def setUp(self):
self.socks = LoopSocket()
self.sockc = LoopSocket()
@@ -168,7 +105,7 @@ class TransportTest(unittest.TestCase):
def setup_test_server(
self, client_options=None, server_options=None, connect_kwargs=None
):
- host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ host_key = RSAKey.from_private_key_file(_support("rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
@@ -234,7 +171,7 @@ class TransportTest(unittest.TestCase):
loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :)
"""
- host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ host_key = RSAKey.from_private_key_file(_support("rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -260,7 +197,7 @@ class TransportTest(unittest.TestCase):
"""
verify that a long banner doesn't mess up the handshake.
"""
- host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ host_key = RSAKey.from_private_key_file(_support("rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -910,7 +847,7 @@ class TransportTest(unittest.TestCase):
# be fine. Even tho it's a bit squicky.
self.tc.packetizer = SlowPacketizer(self.tc.sock)
# Continue with regular test red tape.
- host_key = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ host_key = RSAKey.from_private_key_file(_support("rsa.key"))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
@@ -1099,7 +1036,8 @@ class TransportTest(unittest.TestCase):
def test_server_transports_reject_client_message_types(self):
# TODO: handle Transport's own tables too, not just its inner auth
# handler's table. See TODOs in auth_handler.py
- for message_type in AuthHandler._client_handler_table:
+ some_handler = self._auth_handler_class(self.tc)
+ for message_type in some_handler._client_handler_table:
self._send_client_message(message_type)
self._expect_unimplemented()
# Reset for rest of loop
@@ -1115,6 +1053,21 @@ class TransportTest(unittest.TestCase):
self._expect_unimplemented()
+# TODO: for now this is purely a regression test. It needs actual tests of the
+# intentional new behavior too!
+class ServiceRequestingTransportTest(TransportTest):
+ _auth_handler_class = AuthOnlyHandler
+
+ def setUp(self):
+ # Copypasta (Transport init is load-bearing)
+ self.socks = LoopSocket()
+ self.sockc = LoopSocket()
+ self.sockc.link(self.socks)
+ # New class who dis
+ self.tc = ServiceRequestingTransport(self.sockc)
+ self.ts = ServiceRequestingTransport(self.socks)
+
+
class AlgorithmDisablingTests(unittest.TestCase):
def test_preferred_lists_default_to_private_attribute_contents(self):
t = Transport(sock=Mock())
@@ -1188,98 +1141,6 @@ class AlgorithmDisablingTests(unittest.TestCase):
assert "zlib" not in compressions
-@contextmanager
-def server(
- hostkey=None,
- init=None,
- server_init=None,
- client_init=None,
- connect=None,
- pubkeys=None,
- catch_error=False,
-):
- """
- SSH server contextmanager for testing.
-
- :param hostkey:
- Host key to use for the server; if None, loads
- ``test_rsa.key``.
- :param init:
- Default `Transport` constructor kwargs to use for both sides.
- :param server_init:
- Extends and/or overrides ``init`` for server transport only.
- :param client_init:
- Extends and/or overrides ``init`` for client transport only.
- :param connect:
- Kwargs to use for ``connect()`` on the client.
- :param pubkeys:
- List of public keys for auth.
- :param catch_error:
- Whether to capture connection errors & yield from contextmanager.
- Necessary for connection_time exception testing.
- """
- if init is None:
- init = {}
- if server_init is None:
- server_init = {}
- if client_init is None:
- client_init = {}
- if connect is None:
- connect = dict(username="slowdive", password="pygmalion")
- socks = LoopSocket()
- sockc = LoopSocket()
- sockc.link(socks)
- tc = Transport(sockc, **dict(init, **client_init))
- ts = Transport(socks, **dict(init, **server_init))
-
- if hostkey is None:
- hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- ts.add_server_key(hostkey)
- event = threading.Event()
- server = NullServer(allowed_keys=pubkeys)
- assert not event.is_set()
- assert not ts.is_active()
- assert tc.get_username() is None
- assert ts.get_username() is None
- assert not tc.is_authenticated()
- assert not ts.is_authenticated()
-
- err = None
- # Trap errors and yield instead of raising right away; otherwise callers
- # cannot usefully deal with problems at connect time which stem from errors
- # in the server side.
- try:
- ts.start_server(event, server)
- tc.connect(**connect)
-
- event.wait(1.0)
- assert event.is_set()
- assert ts.is_active()
- assert tc.is_active()
-
- except Exception as e:
- if not catch_error:
- raise
- err = e
-
- yield (tc, ts, err) if catch_error else (tc, ts)
-
- tc.close()
- ts.close()
- socks.close()
- sockc.close()
-
-
-_disable_sha2 = dict(
- disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
-)
-_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
-_disable_sha2_pubkey = dict(
- disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
-)
-_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
-
-
class TestSHA2SignatureKeyExchange(unittest.TestCase):
# NOTE: these all rely on the default server() hostkey being RSA
# NOTE: these rely on both sides being properly implemented re: agreed-upon
@@ -1343,22 +1204,29 @@ class TestSHA2SignatureKeyExchange(unittest.TestCase):
# (This is a regression test vs previous implementation which overwrote
# the entire preferred-hostkeys structure when given an explicit key as
# a client.)
- hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _):
+ hostkey = RSAKey.from_private_key_file(_support("rsa.key"))
+ connect = dict(
+ hostkey=hostkey, username="slowdive", password="pygmalion"
+ )
+ with server(hostkey=hostkey, connect=connect) as (tc, _):
assert tc.host_key_type == "rsa-sha2-512"
class TestExtInfo(unittest.TestCase):
- def test_ext_info_handshake(self):
+ def test_ext_info_handshake_exposed_in_client_kexinit(self):
with server() as (tc, _):
+ # NOTE: this is latest KEXINIT /sent by us/ (Transport retains it)
kex = tc._get_latest_kex_init()
- assert kex["kex_algo_list"][-1] == "ext-info-c"
+ # flag in KexAlgorithms list
+ assert "ext-info-c" in kex["kex_algo_list"]
+ # data stored on Transport after hearing back from a compatible
+ # server (such as ourselves in server mode)
assert tc.server_extensions == {
"server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss" # noqa
}
def test_client_uses_server_sig_algs_for_pubkey_auth(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
+ privkey = RSAKey.from_private_key_file(_support("rsa.key"))
with server(
pubkeys=[privkey],
connect=dict(pkey=privkey),
@@ -1367,94 +1235,6 @@ class TestExtInfo(unittest.TestCase):
),
) as (tc, _):
assert tc.is_authenticated()
- # Client settled on 256 despite itself not having 512 disabled
- assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"
-
-
-# TODO: these could move into test_auth.py but that badly needs refactoring
-# with this module anyways...
-class TestSHA2SignaturePubkeys(unittest.TestCase):
- def test_pubkey_auth_honors_disabled_algorithms(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(
- pubkeys=[privkey],
- connect=dict(pkey=privkey),
- init=dict(
- disabled_algorithms=dict(
- pubkeys=["ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"]
- )
- ),
- catch_error=True,
- ) as (_, _, err):
- assert isinstance(err, SSHException)
- assert "no RSA pubkey algorithms" in str(err)
-
- def test_client_sha2_disabled_server_sha1_disabled_no_match(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(
- pubkeys=[privkey],
- connect=dict(pkey=privkey),
- client_init=_disable_sha2_pubkey,
- server_init=_disable_sha1_pubkey,
- catch_error=True,
- ) as (tc, ts, err):
- assert isinstance(err, AuthenticationException)
-
- def test_client_sha1_disabled_server_sha2_disabled_no_match(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(
- pubkeys=[privkey],
- connect=dict(pkey=privkey),
- client_init=_disable_sha1_pubkey,
- server_init=_disable_sha2_pubkey,
- catch_error=True,
- ) as (tc, ts, err):
- assert isinstance(err, AuthenticationException)
-
- @requires_sha1_signing
- def test_ssh_rsa_still_used_when_sha2_disabled(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- # NOTE: this works because key obj comparison uses public bytes
- # TODO: would be nice for PKey to grow a legit "give me another obj of
- # same class but just the public bits" using asbytes()
- with server(
- pubkeys=[privkey], connect=dict(pkey=privkey), init=_disable_sha2
- ) as (tc, _):
- assert tc.is_authenticated()
-
- def test_sha2_512(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(
- pubkeys=[privkey],
- connect=dict(pkey=privkey),
- init=dict(
- disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-256"])
- ),
- ) as (tc, ts):
- assert tc.is_authenticated()
- assert tc._agreed_pubkey_algorithm == "rsa-sha2-512"
-
- def test_sha2_256(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(
- pubkeys=[privkey],
- connect=dict(pkey=privkey),
- init=dict(
- disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"])
- ),
- ) as (tc, ts):
- assert tc.is_authenticated()
- assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"
-
- def test_sha2_256_when_client_only_enables_256(self):
- privkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
- with server(
- pubkeys=[privkey],
- connect=dict(pkey=privkey),
- # Client-side only; server still accepts all 3.
- client_init=dict(
- disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-512"])
- ),
- ) as (tc, ts):
- assert tc.is_authenticated()
+ # Client settled on 256 despite itself not having 512 disabled (and
+ # otherwise, 512 would have been earlier in the preferred list)
assert tc._agreed_pubkey_algorithm == "rsa-sha2-256"