From e22c5ea330814801d8487dc3da347f987bafe5ec Mon Sep 17 00:00:00 2001 From: Jeff Forcier Date: Thu, 4 May 2023 13:52:40 -0400 Subject: [PATCH] Start consolidating test server nonsense Reference:https://github.com/paramiko/paramiko/commit/e22c5ea330814801d8487dc3da347f987bafe5ec Conflict: _util.py does not exist in this version, so the refactored test server code is still kept in test_transport.py. The key file names (e.g. test_rsa.key) must remain the same as the current ones. --- tests/test_transport.py | 198 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 181 insertions(+), 17 deletions(-) diff --git a/tests/test_transport.py b/tests/test_transport.py index 4ed712e..6cdbfd6 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -33,6 +33,7 @@ import random import sys import unittest from mock import Mock +from time import sleep from paramiko import ( AuthHandler, @@ -1196,6 +1197,146 @@ class AlgorithmDisablingTests(unittest.TestCase): assert "diffie-hellman-group14-sha256" not in kexen assert "zlib" not in compressions +_disable_sha2 = dict( + disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"]) +) +_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"])) +_disable_sha2_pubkey = dict( + disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"]) +) +_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"])) + + +unicodey = "\u2022" + + +class TestServer(ServerInterface): + paranoid_did_password = False + paranoid_did_public_key = False + # TODO: make this ed25519 or something else modern? (_is_ this used??) 
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key")) + + def __init__(self, allowed_keys=None): + self.allowed_keys = allowed_keys if allowed_keys is not None else [] + + def check_channel_request(self, kind, chanid): + if kind == "bogus": + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + return OPEN_SUCCEEDED + + def check_channel_exec_request(self, channel, command): + if command != b"yes": + return False + return True + + def check_channel_shell_request(self, channel): + return True + + def check_global_request(self, kind, msg): + self._global_request = kind + # NOTE: for w/e reason, older impl of this returned False always, even + # tho that's only supposed to occur if the request cannot be served. + # For now, leaving that the default unless test supplies specific + # 'acceptable' request kind + return kind == "acceptable" + + def check_channel_x11_request( + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): + self._x11_single_connection = single_connection + self._x11_auth_protocol = auth_protocol + self._x11_auth_cookie = auth_cookie + self._x11_screen_number = screen_number + return True + + def check_port_forward_request(self, addr, port): + self._listen = socket.socket() + self._listen.bind(("127.0.0.1", 0)) + self._listen.listen(1) + return self._listen.getsockname()[1] + + def cancel_port_forward_request(self, addr, port): + self._listen.close() + self._listen = None + + def check_channel_direct_tcpip_request(self, chanid, origin, destination): + self._tcpip_dest = destination + return OPEN_SUCCEEDED + + def get_allowed_auths(self, username): + if username == "slowdive": + return "publickey,password" + if username == "paranoid": + if ( + not self.paranoid_did_password + and not self.paranoid_did_public_key + ): + return "publickey,password" + elif self.paranoid_did_password: + return "publickey" + else: + return "password" + if username == "commie": + return "keyboard-interactive" + if 
username == "utf8": + return "password" + if username == "non-utf8": + return "password" + return "publickey" + + def check_auth_password(self, username, password): + if (username == "slowdive") and (password == "pygmalion"): + return AUTH_SUCCESSFUL + if (username == "paranoid") and (password == "paranoid"): + # 2-part auth (even openssh doesn't support this) + self.paranoid_did_password = True + if self.paranoid_did_public_key: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + if (username == "utf8") and (password == unicodey): + return AUTH_SUCCESSFUL + if (username == "non-utf8") and (password == "\xff"): + return AUTH_SUCCESSFUL + if username == "bad-server": + raise Exception("Ack!") + if username == "unresponsive-server": + time.sleep(5) + return AUTH_SUCCESSFUL + return AUTH_FAILED + + def check_auth_publickey(self, username, key): + if (username == "paranoid") and (key == self.paranoid_key): + # 2-part auth + self.paranoid_did_public_key = True + if self.paranoid_did_password: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + # TODO: make sure all tests incidentally using this to pass, _without + # sending a username oops_, get updated somehow - probably via server() + # default always injecting a username + elif key in self.allowed_keys: + return AUTH_SUCCESSFUL + return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + if username == "commie": + self.username = username + return InteractiveQuery( + "password", "Please enter a password.", ("Password", False) + ) + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + if self.username == "commie": + if (len(responses) == 1) and (responses[0] == "cat"): + return AUTH_SUCCESSFUL + return AUTH_FAILED + @contextmanager def server( @@ -1206,13 +1347,20 @@ def server( connect=None, pubkeys=None, catch_error=False, + transport_factory=None, + server_transport_factory=None, + defer=False, + skip_verify=False, ): """ SSH server contextmanager 
for testing. + Yields a tuple of ``(tc, ts)`` (client- and server-side `Transport` + objects), or ``(tc, ts, err)`` when ``catch_error==True``. + :param hostkey: Host key to use for the server; if None, loads ``test_rsa.key``. :param init: Default `Transport` constructor kwargs to use for both sides. :param server_init: @@ -1226,6 +1374,17 @@ def server( :param catch_error: Whether to capture connection errors & yield from contextmanager. Necessary for connection_time exception testing. + :param transport_factory: + Like the same-named param in SSHClient: which Transport class to use. + :param server_transport_factory: + Like ``transport_factory``, but only impacts the server transport. + :param bool defer: + Whether to defer authentication during connecting. + + This is really just shorthand for ``connect={}`` which would do roughly + the same thing. Also: this implies skip_verify=True automatically! + :param bool skip_verify: + Whether NOT to do the default "make sure auth passed" check. 
""" if init is None: init = {} @@ -1234,18 +1393,27 @@ def server( if client_init is None: client_init = {} if connect is None: - connect = dict(username="slowdive", password="pygmalion") + # No auth at all please + if defer: + connect = dict() + # Default username based auth + else: + connect = dict(username="slowdive", password="pygmalion") socks = LoopSocket() sockc = LoopSocket() sockc.link(socks) - tc = Transport(sockc, **dict(init, **client_init)) - ts = Transport(socks, **dict(init, **server_init)) + if transport_factory is None: + transport_factory = Transport + if server_transport_factory is None: + server_transport_factory = transport_factory + tc = transport_factory(sockc, **dict(init, **client_init)) + ts = server_transport_factory(socks, **dict(init, **server_init)) if hostkey is None: hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) ts.add_server_key(hostkey) event = threading.Event() - server = NullServer(allowed_keys=pubkeys) + server = TestServer(allowed_keys=pubkeys) assert not event.is_set() assert not ts.is_active() assert tc.get_username() is None @@ -1273,22 +1441,15 @@ def server( yield (tc, ts, err) if catch_error else (tc, ts) + if not (catch_error or skip_verify or defer): + assert ts.is_authenticated() + assert tc.is_authenticated() + tc.close() ts.close() socks.close() sockc.close() - -_disable_sha2 = dict( - disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"]) -) -_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"])) -_disable_sha2_pubkey = dict( - disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"]) -) -_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"])) - - class TestSHA2SignatureKeyExchange(unittest.TestCase): # NOTE: these all rely on the default server() hostkey being RSA # NOTE: these rely on both sides being properly implemented re: agreed-upon @@ -1352,7 +1513,10 @@ class TestSHA2SignatureKeyExchange(unittest.TestCase): # the entire preferred-hostkeys 
structure when given an explicit key as # a client.) hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) - with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _): + connect = dict( + hostkey=hostkey, username="slowdive", password="pygmalion" + ) + with server(hostkey=hostkey, connect=connect) as (tc, _): assert tc.host_key_type == "rsa-sha2-512" -- 2.33.0