From ce68e203d5e5acc12fa6f7362de7cad69c0d4898 Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Fri, 24 May 2024 13:09:50 +0200 Subject: [PATCH] Add tests for MFA support --- .../executors_t/test_sshexecutor.py | 106 +++++++++++++++++- 1 file changed, 100 insertions(+), 6 deletions(-) diff --git a/tests/utilities_t/executors_t/test_sshexecutor.py b/tests/utilities_t/executors_t/test_sshexecutor.py index c789ece6..4f4ae296 100644 --- a/tests/utilities_t/executors_t/test_sshexecutor.py +++ b/tests/utilities_t/executors_t/test_sshexecutor.py @@ -1,7 +1,12 @@ from tests.utilities.utilities import async_return, run_async from tardis.utilities.attributedict import AttributeDict -from tardis.utilities.executors.sshexecutor import SSHExecutor, probe_max_session +from tardis.utilities.executors.sshexecutor import ( + SSHExecutor, + probe_max_session, + MFASSHClient, +) from tardis.exceptions.executorexceptions import CommandExecutionFailure +from tardis.exceptions.tardisexceptions import TardisAuthError from asyncssh import ChannelOpenError, ConnectionLost, DisconnectError, ProcessError @@ -11,6 +16,7 @@ import asyncio import yaml import contextlib +import logging from asyncstdlib import contextmanager as asynccontextmanager @@ -67,6 +73,63 @@ def test_max_sessions(self): ) +class TestMFASSHClient(TestCase): + def setUp(self): + mfa_secrets = [ + { + "prompt": "Enter MFA token:", + "secret": "EJL2DAWFOH7QPJ3D6I2DK2ARTBEJDBIB", + }, + { + "prompt": "Yet another token:", + "secret": "D22246GDKKEDK7AAM77ZH5VRDRL7Z6W7", + }, + ] + self.mfa_ssh_client = MFASSHClient(mfa_secrets=mfa_secrets) + + def test_kbdint_auth_requested(self): + self.assertEqual(run_async(self.mfa_ssh_client.kbdint_auth_requested), "") + + def test_kbdint_challenge_received(self): + def test_responses(prompts, num_of_expected_responses): + responses = run_async( + self.mfa_ssh_client.kbdint_challenge_received, + name="test", + instructions="no", + lang="en", + prompts=prompts, + ) + + self.assertEqual(len(responses), num_of_expected_responses) + for response in responses: + self.assertTrue(response.isdigit()) + + for prompts, num_of_expected_responses in ( + ([("Enter MFA token:", False)], 1), + ([("Enter MFA token:", False), ("Yet another token: ", False)], 2), + ([], 0), + ): + test_responses( + prompts=prompts, num_of_expected_responses=num_of_expected_responses + ) + + prompts_to_fail = [("Enter MFA token:", False), ("Unknown token: ", False)] + + with self.assertRaises(TardisAuthError) as tae: + with self.assertLogs(level=logging.ERROR): + run_async( + self.mfa_ssh_client.kbdint_challenge_received, + name="test", + instructions="no", + lang="en", + prompts=prompts_to_fail, + ) + self.assertIn( + "Keyboard interactive authentication failed: Unexpected Prompt", + str(tae.exception), + ) + + class TestSSHExecutor(TestCase): mock_asyncssh = None @@ -208,6 +271,17 @@ def test_run_command(self): run_async(raising_executor.run_command, command="Test", stdin_input="Test") def test_construction_by_yaml(self): + def test_yaml_construction(test_executor, *args, **kwargs): + self.assertEqual( + run_async( + test_executor.run_command, command="Test", stdin_input="Test" + ).stdout, + "Test", + ) + self.mock_asyncssh.connect.assert_called_with(*args, **kwargs) + + self.mock_asyncssh.reset_mock() + executor = yaml.safe_load( """ !SSHExecutor @@ -218,10 +292,30 @@ def test_construction_by_yaml(self): """ ) - self.assertEqual( - run_async(executor.run_command, command="Test", stdin_input="Test").stdout, - "Test", + test_yaml_construction( + executor, + host="test_host", + username="test", + client_keys=["TestKey"], ) - self.mock_asyncssh.connect.assert_called_with( - host="test_host", username="test", client_keys=["TestKey"] + + mfa_executor = yaml.safe_load( + """ + !SSHExecutor + host: test_host + username: test + client_keys: + - TestKey + mfa_secrets: + - prompt: 'Token: ' + secret: 123TopSecret + """ + ) + + test_yaml_construction( + mfa_executor, + host="test_host", + username="test", + client_keys=["TestKey"], + client_factory=mfa_executor._parameters["client_factory"], )