Skip to content

Commit

Permalink
fix config parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
odelcroi committed Dec 5, 2024
1 parent ffcd03d commit 6cc9fd2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
13 changes: 6 additions & 7 deletions synapse_sso_proconnect/proconnect_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ProConnectMappingConfig:
mapper_new_old_domain:Dict[str, str]= {}
user_id_lookup_fallback_rules:Dict[str, str]= {}

class ProConnectMappingProvider(OidcMappingProvider[ProConnectMappingConfig]):
def __init__(self, config: Dict[str, str], module_api: ModuleApi):
def __init__(self, config: ProConnectMappingConfig, module_api: ModuleApi):
self.module_api = module_api
self.config_mapper = self.parse_config(config)
self._config=config

@staticmethod
def parse_config(config_dict: Dict[str, str]) -> ProConnectMappingConfig:
return ProConnectMappingConfig(**config_dict)
def parse_config(config: Dict[str, Any]) -> ProConnectMappingConfig:
return ProConnectMappingConfig(**config)

def get_remote_user_id(self, userinfo: UserInfo) -> str:
return userinfo.sub
Expand Down Expand Up @@ -85,9 +85,8 @@ async def search_user_id_by_threepid(self, email: str):

# If userId is not found, attempt replacements
if not userId:

# Iterate through all mappings
for old_value, new_value in self.config_mapper.mapper_new_old_domain.items():
for old_value, new_value in self._config.user_id_lookup_fallback_rules.items():
# Check if the key (old_value) exists within the email
if old_value in email:
# Replace the old value with the new value
Expand Down
11 changes: 6 additions & 5 deletions tests/test_proconnect_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ def create_module(config
else "test-old" if email == "[email protected]"
else None
)
return ProConnectMappingProvider(config, module_api)
parsed_config = ProConnectMappingProvider.parse_config(config)
return ProConnectMappingProvider(parsed_config, module_api)


class ProConnectMappingTest(aiounittest.AsyncTestCase):
#def setUp(self) -> None:

async def test_with_map_should_replace(self):
self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}})
self.module = create_module({"user_id_lookup_fallback_rules":{"new.fr": "beta.fr"}})
# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("[email protected]")
# Assertions
Expand All @@ -32,7 +33,7 @@ async def test_with_map_should_replace(self):


async def test_replace_by_priority(self):
self.module = create_module({"mapper_new_old_domain":{
self.module = create_module({"user_id_lookup_fallback_rules":{
"[email protected]":"[email protected]",
"new.fr": "beta.fr"}})#replace by domain leads to a dead-end but it lower in the list

Expand All @@ -44,7 +45,7 @@ async def test_replace_by_priority(self):
self.assertEqual(user_id, "test-old") # Should match the replaced email

async def test_with_map_should_not_replace(self):
self.module = create_module({"mapper_new_old_domain":{"new.fr": "beta.fr"}})
self.module = create_module({"user_id_lookup_fallback_rules":{"new.fr": "beta.fr"}})

# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("[email protected]")
Expand All @@ -54,7 +55,7 @@ async def test_with_map_should_not_replace(self):

async def test_with_empty_map(self):

self.module = create_module({"mapper_new_old_domain":{}})
self.module = create_module({"user_id_lookup_fallback_rules":{}})

# Call the tested function with an email that requires replacement
user_id = await self.module.search_user_id_by_threepid("[email protected]")
Expand Down

0 comments on commit 6cc9fd2

Please sign in to comment.