Skip to content

Commit

Permalink
add config
Browse files Browse the repository at this point in the history
  • Loading branch information
odelcroi committed Dec 5, 2024
1 parent f3c9785 commit ffcd03d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 61 deletions.
25 changes: 8 additions & 17 deletions synapse_sso_proconnect/proconnect_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import attr
from authlib.oidc.core import UserInfo
from typing import Any, Dict, List, Optional, Tuple

from synapse.handlers.oidc import OidcMappingProvider, Token, UserAttributeDict
from synapse.handlers.sso import MappingException
Expand All @@ -14,16 +15,16 @@

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ProConnectMappingConfig:
pass

mapper_new_old_domain:Dict[str, str]= {}

class ProConnectMappingProvider(OidcMappingProvider[ProConnectMappingConfig]):
def __init__(self, config: ProConnectMappingConfig, module_api: ModuleApi):
def __init__(self, config: Dict[str, str], module_api: ModuleApi):
self.module_api = module_api
self.config_mapper = self.parse_config(config)

@staticmethod
def parse_config(config: dict) -> ProConnectMappingConfig:
return ProConnectMappingConfig()
def parse_config(config_dict: Dict[str, str]) -> ProConnectMappingConfig:
return ProConnectMappingConfig(**config_dict)

def get_remote_user_id(self, userinfo: UserInfo) -> str:
return userinfo.sub
Expand Down Expand Up @@ -77,26 +78,16 @@ async def map_user_attributes(
display_name=display_name,
)

# Return a dict with specific email replacements mappings.
async def getReplaceMapping(self):
return {
# Specific email replacement
"[email protected]" : "[email protected]",
# General domain replacement
"numerique.gouv.fr": "beta.gouv.fr"
}

# Search user ID by its email, retrying with replacements if necessary.
async def search_user_id_by_threepid(self, email: str):
# Try to find the user ID using the provided email
userId = await self.module_api._store.get_user_id_by_threepid("email", email)

# If userId is not found, attempt replacements
if not userId:
replace_mapping = await self.getReplaceMapping() # Get the mapping of replacements


# Iterate through all mappings
for old_value, new_value in replace_mapping.items():
for old_value, new_value in self.config_mapper.mapper_new_old_domain.items():
# Check if the key (old_value) exists within the email
if old_value in email:
# Replace the old value with the new value
Expand Down
29 changes: 0 additions & 29 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +0,0 @@
from typing import Any, Dict, Optional
from unittest.mock import AsyncMock, Mock

import attr
from synapse.module_api import ModuleApi, UserID

from synapse_sso_proconnect.proconnect_mapping import ProConnectMappingProvider

class MockHomeserver:
def get_datastores(self):
return Mock(spec=["main"])

def get_task_scheduler(self):
return Mock(spec=["register_action"])

def create_module(
config_override: Optional[Dict[str, Any]] = None, server_name: str = "example.com"
) -> ProConnectMappingProvider:
# Create a mock based on the ModuleApi spec, but override some mocked functions
# because some capabilities are needed for running the tests.
module_api = Mock(spec=ModuleApi)

if config_override is None:
config_override = {}
config_override["id_server"] = "example.com"

config = ProConnectMappingProvider.parse_config(config_override)

return ProConnectMappingProvider(config, module_api)
56 changes: 41 additions & 15 deletions tests/test_proconnect_mapping.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,63 @@
from unittest.mock import AsyncMock
import aiounittest
from tests import create_module
from synapse_sso_proconnect.proconnect_mapping import ProConnectMappingProvider


def create_module(
def create_module(config
) -> ProConnectMappingProvider:
module_api = AsyncMock()
# Adding _store to the module_api object
module_api._store = AsyncMock()
module_api.getReplaceMapping = AsyncMock(
return_value={"numerique.gouv.fr": "beta.gouv.fr"}
)
module_api._store.get_user_id_by_threepid.side_effect = lambda typ, email: (
"test-beta" if email == "[email protected]"
else "test-exemple" if email == "[email protected]"
"test-beta" if email == "[email protected]"
else "test-exemple" if email == "[email protected]"
else "test-numerique" if email == "[email protected]"
else "test-old" if email == "[email protected]"
else None
)
config= {}
return ProConnectMappingProvider(config, module_api)


class ProConnectMappingTest(aiounittest.AsyncTestCase):
def setUp(self) -> None:
self.module = create_module()
#def setUp(self) -> None:

async def test_with_map_should_replace(self):
self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}})
# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("[email protected]")
# Assertions
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]")
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]")
self.assertEqual(user_id, "test-beta") # Should match the replaced email


async def test_replace_by_priority(self):
self.module = create_module({"mapper_new_old_domain":{
"[email protected]":"[email protected]",
"new.fr": "beta.fr"}})#replace by domain leads to a dead-end but it lower in the list

# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("[email protected]")
# Assertions
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]")
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]")
self.assertEqual(user_id, "test-old") # Should match the replaced email

async def test_with_email_replacement(self):
async def test_with_map_should_not_replace(self):
self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}})

# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("test@numerique.gouv.fr")
user_id = await self.module.search_user_id_by_threepid("[email protected]")

# Assertions
self.assertEqual(user_id, "test-beta") # Should match the replaced email
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]")
self.module.module_api._store.get_user_id_by_threepid.assert_any_call("email", "[email protected]")
self.assertEqual(user_id, "test-numerique")

async def test_with_empty_map(self):

self.module = create_module({"mapper_new_old_domain":{}})

# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("[email protected]")

# Assertions
self.assertEqual(user_id, "test-numerique")

0 comments on commit ffcd03d

Please sign in to comment.