-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
49 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
import attr | ||
from authlib.oidc.core import UserInfo | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
from synapse.handlers.oidc import OidcMappingProvider, Token, UserAttributeDict | ||
from synapse.handlers.sso import MappingException | ||
|
@@ -14,16 +15,16 @@ | |
|
||
@attr.s(slots=True, frozen=True, auto_attribs=True) | ||
class ProConnectMappingConfig: | ||
pass | ||
|
||
mapper_new_old_domain:Dict[str, str]= {} | ||
|
||
class ProConnectMappingProvider(OidcMappingProvider[ProConnectMappingConfig]): | ||
def __init__(self, config: ProConnectMappingConfig, module_api: ModuleApi): | ||
def __init__(self, config: Dict[str, str], module_api: ModuleApi): | ||
self.module_api = module_api | ||
self.config_mapper = self.parse_config(config) | ||
|
||
@staticmethod | ||
def parse_config(config: dict) -> ProConnectMappingConfig: | ||
return ProConnectMappingConfig() | ||
def parse_config(config_dict: Dict[str, str]) -> ProConnectMappingConfig: | ||
return ProConnectMappingConfig(**config_dict) | ||
|
||
def get_remote_user_id(self, userinfo: UserInfo) -> str: | ||
return userinfo.sub | ||
|
@@ -77,26 +78,16 @@ async def map_user_attributes( | |
display_name=display_name, | ||
) | ||
|
||
# Return a dict with specific email replacements mappings. | ||
async def getReplaceMapping(self): | ||
return { | ||
# Specific email replacement | ||
"[email protected]" : "[email protected]", | ||
# General domain replacement | ||
"numerique.gouv.fr": "beta.gouv.fr" | ||
} | ||
|
||
# Search user ID by its email, retrying with replacements if necessary. | ||
async def search_user_id_by_threepid(self, email: str): | ||
# Try to find the user ID using the provided email | ||
userId = await self.module_api._store.get_user_id_by_threepid("email", email) | ||
|
||
# If userId is not found, attempt replacements | ||
if not userId: | ||
replace_mapping = await self.getReplaceMapping() # Get the mapping of replacements | ||
|
||
|
||
# Iterate through all mappings | ||
for old_value, new_value in replace_mapping.items(): | ||
for old_value, new_value in self.config_mapper.mapper_new_old_domain.items(): | ||
# Check if the key (old_value) exists within the email | ||
if old_value in email: | ||
# Replace the old value with the new value | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +0,0 @@ | ||
from typing import Any, Dict, Optional | ||
from unittest.mock import AsyncMock, Mock | ||
|
||
import attr | ||
from synapse.module_api import ModuleApi, UserID | ||
|
||
from synapse_sso_proconnect.proconnect_mapping import ProConnectMappingProvider | ||
|
||
class MockHomeserver: | ||
def get_datastores(self): | ||
return Mock(spec=["main"]) | ||
|
||
def get_task_scheduler(self): | ||
return Mock(spec=["register_action"]) | ||
|
||
def create_module( | ||
config_override: Optional[Dict[str, Any]] = None, server_name: str = "example.com" | ||
) -> ProConnectMappingProvider: | ||
# Create a mock based on the ModuleApi spec, but override some mocked functions | ||
# because some capabilities are needed for running the tests. | ||
module_api = Mock(spec=ModuleApi) | ||
|
||
if config_override is None: | ||
config_override = {} | ||
config_override["id_server"] = "example.com" | ||
|
||
config = ProConnectMappingProvider.parse_config(config_override) | ||
|
||
return ProConnectMappingProvider(config, module_api) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,63 @@ | ||
from unittest.mock import AsyncMock | ||
import aiounittest | ||
from tests import create_module | ||
from synapse_sso_proconnect.proconnect_mapping import ProConnectMappingProvider | ||
|
||
|
||
def create_module( | ||
def create_module(config | ||
) -> ProConnectMappingProvider: | ||
module_api = AsyncMock() | ||
# Adding _store to the module_api object | ||
module_api._store = AsyncMock() | ||
module_api.getReplaceMapping = AsyncMock( | ||
return_value={"numerique.gouv.fr": "beta.gouv.fr"} | ||
) | ||
module_api._store.get_user_id_by_threepid.side_effect = lambda typ, email: ( | ||
"test-beta" if email == "[email protected]" | ||
else "test-exemple" if email == "[email protected]" | ||
"test-beta" if email == "[email protected]" | ||
else "test-exemple" if email == "[email protected]" | ||
else "test-numerique" if email == "[email protected]" | ||
else "test-old" if email == "[email protected]" | ||
else None | ||
) | ||
config= {} | ||
return ProConnectMappingProvider(config, module_api) | ||
|
||
|
||
class ProConnectMappingTest(aiounittest.AsyncTestCase): | ||
def setUp(self) -> None: | ||
self.module = create_module() | ||
#def setUp(self) -> None: | ||
|
||
async def test_with_map_should_replace(self): | ||
self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}}) | ||
# Call the tested function with an email that requires replacement | ||
user_id = await self.module.search_user_id_by_threepid("[email protected]") | ||
# Assertions | ||
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]") | ||
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]") | ||
self.assertEqual(user_id, "test-beta") # Should match the replaced email | ||
|
||
|
||
async def test_replace_by_priority(self): | ||
self.module = create_module({"mapper_new_old_domain":{ | ||
"[email protected]":"[email protected]", | ||
"new.fr": "beta.fr"}})#replace by domain leads to a dead-end but it lower in the list | ||
|
||
# Call the tested function with an email that requires replacement | ||
user_id = await self.module.search_user_id_by_threepid("[email protected]") | ||
# Assertions | ||
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]") | ||
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]") | ||
self.assertEqual(user_id, "test-old") # Should match the replaced email | ||
|
||
async def test_with_email_replacement(self): | ||
async def test_with_map_should_not_replace(self): | ||
self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}}) | ||
|
||
# Call the tested function with an email that requires replacement | ||
user_id = await self.module.search_user_id_by_threepid("test@numerique.gouv.fr") | ||
user_id = await self.module.search_user_id_by_threepid("[email protected]") | ||
|
||
# Assertions | ||
self.assertEqual(user_id, "test-beta") # Should match the replaced email | ||
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]") | ||
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]") | ||
self.assertEqual(user_id, "test-numerique") | ||
|
||
async def test_with_empty_map(self): | ||
|
||
self.module = create_module({"mapper_new_old_domain":{}}) | ||
|
||
# Call the tested function with an email that requires replacement | ||
user_id = await self.module.search_user_id_by_threepid("[email protected]") | ||
|
||
# Assertions | ||
self.assertEqual(user_id, "test-numerique") |