diff --git a/synapse_sso_proconnect/proconnect_mapping.py b/synapse_sso_proconnect/proconnect_mapping.py index ee8ff06..c3401ef 100644 --- a/synapse_sso_proconnect/proconnect_mapping.py +++ b/synapse_sso_proconnect/proconnect_mapping.py @@ -2,6 +2,7 @@ import attr from authlib.oidc.core import UserInfo +from typing import Any, Dict, List, Optional, Tuple from synapse.handlers.oidc import OidcMappingProvider, Token, UserAttributeDict from synapse.handlers.sso import MappingException @@ -14,16 +15,16 @@ @attr.s(slots=True, frozen=True, auto_attribs=True) class ProConnectMappingConfig: - pass - + mapper_new_old_domain:Dict[str, str]= {} class ProConnectMappingProvider(OidcMappingProvider[ProConnectMappingConfig]): - def __init__(self, config: ProConnectMappingConfig, module_api: ModuleApi): + def __init__(self, config: Dict[str, str], module_api: ModuleApi): self.module_api = module_api + self.config_mapper = self.parse_config(config) @staticmethod - def parse_config(config: dict) -> ProConnectMappingConfig: - return ProConnectMappingConfig() + def parse_config(config_dict: Dict[str, str]) -> ProConnectMappingConfig: + return ProConnectMappingConfig(**config_dict) def get_remote_user_id(self, userinfo: UserInfo) -> str: return userinfo.sub @@ -77,15 +78,6 @@ async def map_user_attributes( display_name=display_name, ) - # Return a dict with specific email replacements mappings. - async def getReplaceMapping(self): - return { - # Specific email replacement - "aaa.externe@numerique.gouv.fr" : "aaa@beta.gouv.fr", - # General domain replacement - "numerique.gouv.fr": "beta.gouv.fr" - } - # Search user ID by its email, retrying with replacements if necessary. async def search_user_id_by_threepid(self, email: str): # Try to find the user ID using the provided email @@ -93,10 +85,9 @@ async def search_user_id_by_threepid(self, email: str): # If userId is not found, attempt replacements if not userId: - replace_mapping = await self.getReplaceMapping() # Get the mapping of replacements - + # Iterate through all mappings - for old_value, new_value in replace_mapping.items(): + for old_value, new_value in self.config_mapper.mapper_new_old_domain.items(): # Check if the key (old_value) exists within the email if old_value in email: # Replace the old value with the new value diff --git a/tests/__init__.py b/tests/__init__.py index fcdd259..e69de29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,29 +0,0 @@ -from typing import Any, Dict, Optional -from unittest.mock import AsyncMock, Mock - -import attr -from synapse.module_api import ModuleApi, UserID - -from synapse_sso_proconnect.proconnect_mapping import ProConnectMappingProvider - -class MockHomeserver: - def get_datastores(self): - return Mock(spec=["main"]) - - def get_task_scheduler(self): - return Mock(spec=["register_action"]) - -def create_module( - config_override: Optional[Dict[str, Any]] = None, server_name: str = "example.com" -) -> ProConnectMappingProvider: - # Create a mock based on the ModuleApi spec, but override some mocked functions - # because some capabilities are needed for running the tests. - module_api = Mock(spec=ModuleApi) - - if config_override is None: - config_override = {} - config_override["id_server"] = "example.com" - - config = ProConnectMappingProvider.parse_config(config_override) - - return ProConnectMappingProvider(config, module_api) \ No newline at end of file diff --git a/tests/test_proconnect_mapping.py b/tests/test_proconnect_mapping.py index 81332d5..548ba6b 100644 --- a/tests/test_proconnect_mapping.py +++ b/tests/test_proconnect_mapping.py @@ -1,37 +1,63 @@ from unittest.mock import AsyncMock import aiounittest -from tests import create_module from synapse_sso_proconnect.proconnect_mapping import ProConnectMappingProvider -def create_module( +def create_module(config ) -> ProConnectMappingProvider: module_api = AsyncMock() # Adding _store to the module_api object module_api._store = AsyncMock() - module_api.getReplaceMapping = AsyncMock( - return_value={"numerique.gouv.fr": "beta.gouv.fr"} - ) module_api._store.get_user_id_by_threepid.side_effect = lambda typ, email: ( - "test-beta" if email == "test@beta.gouv.fr" - else "test-exemple" if email == "test@example.com" + "test-beta" if email == "test@beta.fr" + else "test-exemple" if email == "test@example.com" + else "test-numerique" if email == "test@numerique.fr" + else "test-old" if email == "test@old.fr" else None ) - config= {} return ProConnectMappingProvider(config, module_api) class ProConnectMappingTest(aiounittest.AsyncTestCase): - def setUp(self) -> None: - self.module = create_module() + #def setUp(self) -> None: + async def test_with_map_should_replace(self): + self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}}) + # Call the tested function with an email that requires replacement + user_id = await self.module.search_user_id_by_threepid("test@new.fr") + # Assertions + self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "test@new.fr") + self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "test@beta.fr") + self.assertEqual(user_id, "test-beta") # Should match the replaced email + + + async def test_replace_by_priority(self): + self.module = create_module({"mapper_new_old_domain":{ + "test@new.fr":"test@old.fr", + "new.fr": "beta.fr"}})#replace by domain leads to a dead-end but it lower in the list + + # Call the tested function with an email that requires replacement + user_id = await self.module.search_user_id_by_threepid("test@new.fr") + # Assertions + self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "test@new.fr") + self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "test@old.fr") + self.assertEqual(user_id, "test-old") # Should match the replaced email - async def test_with_email_replacement(self): + async def test_with_map_should_not_replace(self): + self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}}) # Call the tested function with an email that requires replacement - user_id = await self.module.search_user_id_by_threepid("test@numerique.gouv.fr") + user_id = await self.module.search_user_id_by_threepid("test@numerique.fr") # Assertions - self.assertEqual(user_id, "test-beta") # Should match the replaced email - self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "test@numerique.gouv.fr") - self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "test@beta.gouv.fr") \ No newline at end of file + self.assertEqual(user_id, "test-numerique") + + async def test_with_empty_map(self): + + self.module = create_module({"mapper_new_old_domain":{}}) + + # Call the tested function with an email that requires replacement + user_id = await self.module.search_user_id_by_threepid("test@numerique.fr") + + # Assertions + self.assertEqual(user_id, "test-numerique") \ No newline at end of file