286 lines
10 KiB
Diff
286 lines
10 KiB
Diff
From e22c5ea330814801d8487dc3da347f987bafe5ec Mon Sep 17 00:00:00 2001
|
|
From: Jeff Forcier <jeff@bitprophet.org>
|
|
Date: Thu, 4 May 2023 13:52:40 -0400
|
|
Subject: [PATCH] Start consolidating test server nonsense
|
|
|
|
Reference:https://github.com/paramiko/paramiko/commit/e22c5ea330814801d8487dc3da347f987bafe5ec
|
|
Conflict: _util.py does not exist in this version, so the refactored test-server code is kept in test_transport.py instead of being moved to _util.py.
|
|
Key file names are kept the same as in the current version (e.g. test_rsa.key, test_dss.key).
|
|
|
|
---
|
|
tests/test_transport.py | 198 ++++++++++++++++++++++++++++++++++++----
|
|
1 file changed, 181 insertions(+), 17 deletions(-)
|
|
|
|
diff --git a/tests/test_transport.py b/tests/test_transport.py
|
|
index 4ed712e..6cdbfd6 100644
|
|
--- a/tests/test_transport.py
|
|
+++ b/tests/test_transport.py
|
|
@@ -33,6 +33,7 @@ import random
|
|
import sys
|
|
import unittest
|
|
from mock import Mock
|
|
+from time import sleep
|
|
|
|
from paramiko import (
|
|
AuthHandler,
|
|
@@ -1196,6 +1197,146 @@ class AlgorithmDisablingTests(unittest.TestCase):
|
|
assert "diffie-hellman-group14-sha256" not in kexen
|
|
assert "zlib" not in compressions
|
|
|
|
+_disable_sha2 = dict(
|
|
+ disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
|
|
+)
|
|
+_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
|
|
+_disable_sha2_pubkey = dict(
|
|
+ disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
|
|
+)
|
|
+_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
|
|
+
|
|
+
|
|
+unicodey = "\u2022"  # non-ASCII password accepted for the "utf8" user
|
|
+
|
|
+
|
|
+class TestServer(ServerInterface):  # canned-auth server used by server()
|
|
+ paranoid_did_password = False  # set once "paranoid" passes password auth
|
|
+ paranoid_did_public_key = False  # set once "paranoid" passes pubkey auth
|
|
+ # TODO: use ed25519 or another modern key? (read by check_auth_publickey)
|
|
+ paranoid_key = DSSKey.from_private_key_file(_support("test_dss.key"))
|
|
+
|
|
+ def __init__(self, allowed_keys=None):  # NOTE(review): pytest warns on "Test*" classes with __init__
|
|
+ self.allowed_keys = allowed_keys if allowed_keys is not None else []
|
|
+
|
|
+ def check_channel_request(self, kind, chanid):
|
|
+ if kind == "bogus":  # the one channel kind this server refuses
|
|
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
|
|
+ return OPEN_SUCCEEDED
|
|
+
|
|
+ def check_channel_exec_request(self, channel, command):
|
|
+ if command != b"yes":  # only the exact command b"yes" is accepted
|
|
+ return False
|
|
+ return True
|
|
+
|
|
+ def check_channel_shell_request(self, channel):
|
|
+ return True
|
|
+
|
|
+ def check_global_request(self, kind, msg):
|
|
+ self._global_request = kind  # remember the last request kind seen
|
|
+ # NOTE: the older implementation always returned False here, even
|
|
+ # though that should only happen when the request cannot be served.
|
|
+ # That stays the default; only the 'acceptable' request kind gets
|
|
+ # True, so a test can opt in to a granted global request.
|
|
+ return kind == "acceptable"
|
|
+
|
|
+ def check_channel_x11_request(
|
|
+ self,
|
|
+ channel,
|
|
+ single_connection,
|
|
+ auth_protocol,
|
|
+ auth_cookie,
|
|
+ screen_number,
|
|
+ ):
|
|
+ self._x11_single_connection = single_connection
|
|
+ self._x11_auth_protocol = auth_protocol
|
|
+ self._x11_auth_cookie = auth_cookie
|
|
+ self._x11_screen_number = screen_number
|
|
+ return True
|
|
+
|
|
+ def check_port_forward_request(self, addr, port):
|
|
+ self._listen = socket.socket()  # local listener on an OS-chosen port
|
|
+ self._listen.bind(("127.0.0.1", 0))
|
|
+ self._listen.listen(1)
|
|
+ return self._listen.getsockname()[1]  # report the port actually bound
|
|
+
|
|
+ def cancel_port_forward_request(self, addr, port):
|
|
+ self._listen.close()  # NOTE(review): assumes a forward was requested first
|
|
+ self._listen = None
|
|
+
|
|
+ def check_channel_direct_tcpip_request(self, chanid, origin, destination):
|
|
+ self._tcpip_dest = destination
|
|
+ return OPEN_SUCCEEDED
|
|
+
|
|
+ def get_allowed_auths(self, username):
|
|
+ if username == "slowdive":
|
|
+ return "publickey,password"
|
|
+ if username == "paranoid":
|
|
+ if (
|
|
+ not self.paranoid_did_password
|
|
+ and not self.paranoid_did_public_key
|
|
+ ):
|
|
+ return "publickey,password"
|
|
+ elif self.paranoid_did_password:  # password step done; offer publickey
|
|
+ return "publickey"
|
|
+ else:  # publickey step done; offer password
|
|
+ return "password"
|
|
+ if username == "commie":
|
|
+ return "keyboard-interactive"
|
|
+ if username == "utf8":
|
|
+ return "password"
|
|
+ if username == "non-utf8":
|
|
+ return "password"
|
|
+ return "publickey"
|
|
+
|
|
+ def check_auth_password(self, username, password):
|
|
+ if (username == "slowdive") and (password == "pygmalion"):
|
|
+ return AUTH_SUCCESSFUL
|
|
+ if (username == "paranoid") and (password == "paranoid"):
|
|
+ # 2-part auth: password and publickey must both pass, in either order
|
|
+ self.paranoid_did_password = True
|
|
+ if self.paranoid_did_public_key:
|
|
+ return AUTH_SUCCESSFUL
|
|
+ return AUTH_PARTIALLY_SUCCESSFUL
|
|
+ if (username == "utf8") and (password == unicodey):
|
|
+ return AUTH_SUCCESSFUL
|
|
+ if (username == "non-utf8") and (password == "\xff"):
|
|
+ return AUTH_SUCCESSFUL
|
|
+ if username == "bad-server":
|
|
+ raise Exception("Ack!")  # simulate an error inside the server
|
|
+ if username == "unresponsive-server":
|
|
+ time.sleep(5)  # NOTE(review): assumes "import time" exists in this file; confirm
|
|
+ return AUTH_SUCCESSFUL
|
|
+ return AUTH_FAILED
|
|
+
|
|
+ def check_auth_publickey(self, username, key):
|
|
+ if (username == "paranoid") and (key == self.paranoid_key):
|
|
+ # 2-part auth: publickey half of the "paranoid" password+key pair
|
|
+ self.paranoid_did_public_key = True
|
|
+ if self.paranoid_did_password:
|
|
+ return AUTH_SUCCESSFUL
|
|
+ return AUTH_PARTIALLY_SUCCESSFUL
|
|
+ # TODO: tests that incidentally pass via this branch _without sending
|
|
+ # a username_ should be updated somehow - probably by having server()
|
|
+ # always inject a default username.
|
|
+ elif key in self.allowed_keys:
|
|
+ return AUTH_SUCCESSFUL
|
|
+ return AUTH_FAILED
|
|
+
|
|
+ def check_auth_interactive(self, username, submethods):
|
|
+ if username == "commie":
|
|
+ self.username = username  # read back in check_auth_interactive_response
|
|
+ return InteractiveQuery(
|
|
+ "password", "Please enter a password.", ("Password", False)
|
|
+ )
|
|
+ return AUTH_FAILED
|
|
+
|
|
+ def check_auth_interactive_response(self, responses):
|
|
+ if self.username == "commie":
|
|
+ if (len(responses) == 1) and (responses[0] == "cat"):
|
|
+ return AUTH_SUCCESSFUL
|
|
+ return AUTH_FAILED
|
|
+
|
|
|
|
@contextmanager
|
|
def server(
|
|
@@ -1206,13 +1347,20 @@ def server(
|
|
connect=None,
|
|
pubkeys=None,
|
|
catch_error=False,
|
|
+ transport_factory=None,
|
|
+ server_transport_factory=None,
|
|
+ defer=False,
|
|
+ skip_verify=False,
|
|
):
|
|
"""
|
|
SSH server contextmanager for testing.
|
|
|
|
+ Yields a tuple of ``(tc, ts)`` (client- and server-side `Transport`
|
|
+ objects), or ``(tc, ts, err)`` when ``catch_error==True``.
|
|
+
|
|
:param hostkey:
|
|
Host key to use for the server; if None, loads
|
|
``test_rsa.key``.
|
|
:param init:
|
|
Default `Transport` constructor kwargs to use for both sides.
|
|
:param server_init:
|
|
@@ -1226,6 +1374,17 @@ def server(
|
|
:param catch_error:
|
|
Whether to capture connection errors & yield from contextmanager.
|
|
Necessary for connection_time exception testing.
|
|
+ :param transport_factory:
|
|
+ Like the same-named param in SSHClient: which Transport class to use.
|
|
+ :param server_transport_factory:
|
|
+ Like ``transport_factory``, but only impacts the server transport.
|
|
+ :param bool defer:
|
|
+ Whether to defer authentication during connecting.
|
|
+
|
|
+ This is really just shorthand for ``connect={}`` which would do roughly
|
|
+ the same thing. Also: this implies skip_verify=True automatically!
|
|
+ :param bool skip_verify:
|
|
+ Whether NOT to do the default "make sure auth passed" check.
|
|
"""
|
|
if init is None:
|
|
init = {}
|
|
@@ -1234,18 +1393,27 @@ def server(
|
|
if client_init is None:
|
|
client_init = {}
|
|
if connect is None:
|
|
- connect = dict(username="slowdive", password="pygmalion")
|
|
+ # No auth at all please
|
|
+ if defer:
|
|
+ connect = dict()
|
|
+ # Default username based auth
|
|
+ else:
|
|
+ connect = dict(username="slowdive", password="pygmalion")
|
|
socks = LoopSocket()
|
|
sockc = LoopSocket()
|
|
sockc.link(socks)
|
|
- tc = Transport(sockc, **dict(init, **client_init))
|
|
- ts = Transport(socks, **dict(init, **server_init))
|
|
+ if transport_factory is None:
|
|
+ transport_factory = Transport
|
|
+ if server_transport_factory is None:
|
|
+ server_transport_factory = transport_factory
|
|
+ tc = transport_factory(sockc, **dict(init, **client_init))
|
|
+ ts = server_transport_factory(socks, **dict(init, **server_init))
|
|
|
|
if hostkey is None:
|
|
hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
|
|
ts.add_server_key(hostkey)
|
|
event = threading.Event()
|
|
- server = NullServer(allowed_keys=pubkeys)
|
|
+ server = TestServer(allowed_keys=pubkeys)
|
|
assert not event.is_set()
|
|
assert not ts.is_active()
|
|
assert tc.get_username() is None
|
|
@@ -1273,22 +1441,15 @@ def server(
|
|
|
|
yield (tc, ts, err) if catch_error else (tc, ts)
|
|
|
|
+ if not (catch_error or skip_verify or defer):
|
|
+ assert ts.is_authenticated()
|
|
+ assert tc.is_authenticated()
|
|
+
|
|
tc.close()
|
|
ts.close()
|
|
socks.close()
|
|
sockc.close()
|
|
|
|
-
|
|
-_disable_sha2 = dict(
|
|
- disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
|
|
-)
|
|
-_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
|
|
-_disable_sha2_pubkey = dict(
|
|
- disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
|
|
-)
|
|
-_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
|
|
-
|
|
-
|
|
class TestSHA2SignatureKeyExchange(unittest.TestCase):
|
|
# NOTE: these all rely on the default server() hostkey being RSA
|
|
# NOTE: these rely on both sides being properly implemented re: agreed-upon
|
|
@@ -1352,7 +1513,10 @@ class TestSHA2SignatureKeyExchange(unittest.TestCase):
|
|
# the entire preferred-hostkeys structure when given an explicit key as
|
|
# a client.)
|
|
hostkey = RSAKey.from_private_key_file(_support("test_rsa.key"))
|
|
- with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _):
|
|
+ connect = dict(
|
|
+ hostkey=hostkey, username="slowdive", password="pygmalion"
|
|
+ )
|
|
+ with server(hostkey=hostkey, connect=connect) as (tc, _):
|
|
assert tc.host_key_type == "rsa-sha2-512"
|
|
|
|
|
|
--
|
|
2.33.0
|