diff --git a/tests/storage/repository/test_study.py b/tests/storage/repository/test_study.py index 7de63bc508..6fbda37a9c 100644 --- a/tests/storage/repository/test_study.py +++ b/tests/storage/repository/test_study.py @@ -4,7 +4,7 @@ from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyContentStatus -from antarest.study.repository import QueryUser, StudyMetadataRepository +from antarest.study.repository import QueryUser, StudyFilter, StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy from tests.helpers import with_db_context @@ -64,7 +64,7 @@ def test_lifecycle() -> None: c = repo.one(a.id) assert a == c - assert len(repo.get_all(query_user=QueryUser(is_admin=True))) == 4 + assert len(repo.get_all(study_filter=StudyFilter(query_user=QueryUser(is_admin=True)))) == 4 assert len(repo.get_all_raw(exists=True)) == 1 assert len(repo.get_all_raw(exists=False)) == 1 assert len(repo.get_all_raw()) == 2 diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index 12f61e6489..165a5f42ba 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -44,7 +44,7 @@ TimeSerie, TimeSeriesData, ) -from antarest.study.repository import StudyFilter, StudyMetadataRepository +from antarest.study.repository import StudyFilter, StudyMetadataRepository, build_query_user_from_params from antarest.study.service import MAX_MISSING_STUDY_TIMEOUT, StudyService, StudyUpgraderTask, UserHasNotPermissionError from antarest.study.storage.patch_service import PatchService from antarest.study.storage.rawstudy.model.filesystem.config.model import ( @@ -172,6 +172,7 @@ def test_study_listing(db_session: Session) -> None: config = Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig()})) repository = StudyMetadataRepository(cache_service=Mock(spec=ICache), session=db_session) service = build_study_service(raw_study_service, repository, config, cache_service=cache) + params: RequestParameters = RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")) # retrieve studies that are not managed # use the db recorder to check that: @@ -179,10 +180,7 @@ def test_study_listing(db_session: Session) -> None: # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=False, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=False, query_user=build_query_user_from_params(params)), params=params ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -196,10 +194,7 @@ def test_study_listing(db_session: Session) -> None: # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=True, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=True, query_user=build_query_user_from_params(params)), params=params ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -213,10 +208,7 @@ def test_study_listing(db_session: Session) -> None: # 2- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=None, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=None, query_user=build_query_user_from_params(params)), params=params ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) @@ -230,10 +222,7 @@ def test_study_listing(db_session: Session) -> None: # 2- the `put` method of `cache` was never used with DBStatementRecorder(db_session.bind) as db_recorder: studies = service.get_studies_information( - study_filter=StudyFilter( - managed=None, - ), - params=RequestParameters(user=JWTUser(id=2, impersonator=2, type="users")), + study_filter=StudyFilter(managed=None, query_user=build_query_user_from_params(params)), params=params ) assert len(db_recorder.sql_statements) == 1, str(db_recorder) with contextlib.suppress(AssertionError): diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index f4041d46b4..7a8364c319 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -67,8 +67,12 @@ def test_repository_get_all__general_case( # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all( - study_filter=StudyFilter(managed=managed, study_ids=study_ids, exists=exists), - query_user=QueryUser(is_admin=True), + study_filter=StudyFilter( + managed=managed, + study_ids=study_ids, + exists=exists, + query_user=QueryUser(is_admin=True), + ) ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] @@ -100,9 +104,9 @@ def test_repository_get_all__incompatible_case( db_session.commit() # case 1 - study_filter = StudyFilter(managed=False, variant=True) + study_filter = StudyFilter(managed=False, variant=True, query_user=QueryUser(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -111,9 +115,9 @@ def test_repository_get_all__incompatible_case( assert not {s.id for s in all_studies} # case 2 - study_filter = StudyFilter(workspace=test_workspace, variant=True) + study_filter = StudyFilter(workspace=test_workspace, variant=True, query_user=QueryUser(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -122,9 +126,9 @@ def test_repository_get_all__incompatible_case( assert not {s.id for s in all_studies} # case 3 - study_filter = StudyFilter(exists=False, variant=True) + study_filter = StudyFilter(exists=False, variant=True, query_user=QueryUser(is_admin=True)) with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=study_filter, query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=study_filter) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -172,7 +176,7 @@ def test_repository_get_all__study_name_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(name=name), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(name=name, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -217,7 +221,7 @@ def test_repository_get_all__managed_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(managed=managed), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(managed=managed, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -258,7 +262,7 @@ def test_repository_get_all__archived_study_filter( # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all( - study_filter=StudyFilter(archived=archived), query_user=QueryUser(is_admin=True) + study_filter=StudyFilter(archived=archived, query_user=QueryUser(is_admin=True)) ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] @@ -299,7 +303,7 @@ def test_repository_get_all__variant_study_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(variant=variant), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(variant=variant, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -342,7 +346,7 @@ def test_repository_get_all__study_version_filter( # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all( - study_filter=StudyFilter(versions=versions), query_user=QueryUser(is_admin=True) + study_filter=StudyFilter(versions=versions, query_user=QueryUser(is_admin=True)) ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] @@ -391,7 +395,7 @@ def test_repository_get_all__study_users_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(users=users), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(users=users, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -439,7 +443,7 @@ def test_repository_get_all__study_groups_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(groups=groups), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(groups=groups, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -483,7 +487,7 @@ def test_repository_get_all__study_ids_filter( # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all( - study_filter=StudyFilter(study_ids=study_ids), query_user=QueryUser(is_admin=True) + study_filter=StudyFilter(study_ids=study_ids, query_user=QueryUser(is_admin=True)) ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] @@ -524,7 +528,7 @@ def test_repository_get_all__study_existence_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(exists=exists), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(exists=exists, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -566,7 +570,7 @@ def test_repository_get_all__study_workspace_filter( # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: all_studies = repository.get_all( - study_filter=StudyFilter(workspace=workspace), query_user=QueryUser(is_admin=True) + study_filter=StudyFilter(workspace=workspace, query_user=QueryUser(is_admin=True)) ) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] @@ -610,7 +614,7 @@ def test_repository_get_all__study_folder_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(folder=folder), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(folder=folder, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies] @@ -661,7 +665,7 @@ def test_repository_get_all__study_tags_filter( # 2- accessing studies attributes does not require additional queries to db # 3- having an exact total of queries equals to 1 with DBStatementRecorder(db_session.bind) as db_recorder: - all_studies = repository.get_all(study_filter=StudyFilter(tags=tags), query_user=QueryUser(is_admin=True)) + all_studies = repository.get_all(study_filter=StudyFilter(tags=tags, query_user=QueryUser(is_admin=True))) _ = [s.owner for s in all_studies] _ = [s.groups for s in all_studies] _ = [s.additional_data for s in all_studies]