diff --git a/ERD.png b/ERD.png index 6a01d2c..04b2e8f 100644 Binary files a/ERD.png and b/ERD.png differ diff --git a/migrate/versions/2024_03_26_7102c8734154_archive_integration.py b/migrate/versions/2024_04_17_e9ac68f05e6b_archive_integration.py similarity index 94% rename from migrate/versions/2024_03_26_7102c8734154_archive_integration.py rename to migrate/versions/2024_04_17_e9ac68f05e6b_archive_integration.py index 997dd2d..5af1531 100644 --- a/migrate/versions/2024_03_26_7102c8734154_archive_integration.py +++ b/migrate/versions/2024_04_17_e9ac68f05e6b_archive_integration.py @@ -1,8 +1,8 @@ """Archive integration -Revision ID: 7102c8734154 +Revision ID: e9ac68f05e6b Revises: df57d06e1ee5 -Create Date: 2024-03-26 14:43:52.793498 +Create Date: 2024-04-17 13:25:02.261646 """ from alembic import op @@ -10,7 +10,7 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision = '7102c8734154' +revision = 'e9ac68f05e6b' down_revision = 'df57d06e1ee5' branch_labels = None depends_on = None @@ -38,6 +38,13 @@ def upgrade(): ondelete='RESTRICT'), sa.PrimaryKeyConstraint('id') ) + op.create_table('provider_user', + sa.Column('provider_id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['provider_id'], ['provider.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('provider_id', 'user_id') + ) op.create_table('resource', sa.Column('id', sa.String(), nullable=False), sa.Column('title', sa.String(), nullable=False), @@ -51,13 +58,6 @@ def upgrade(): sa.ForeignKeyConstraint(['provider_id'], ['provider.id'], ondelete='RESTRICT'), sa.PrimaryKeyConstraint('id') ) - op.create_table('user_provider', - sa.Column('user_id', sa.String(), nullable=False), - sa.Column('provider_id', sa.String(), nullable=False), - sa.ForeignKeyConstraint(['provider_id'], ['provider.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('user_id', 'provider_id') - ) op.create_table('archive_resource', sa.Column('archive_id', sa.String(), nullable=False), sa.Column('resource_id', sa.String(), nullable=False), @@ -88,13 +88,13 @@ def upgrade(): op.add_column('client', sa.Column('provider_id', sa.String(), nullable=True)) op.create_foreign_key('client_provider_id_fkey', 'client', 'provider', ['provider_id'], ['id'], ondelete='SET NULL') op.drop_column('client', 'collection_specific') - op.add_column('identity_audit', sa.Column('_providers', sa.ARRAY(sa.String()), nullable=True)) + op.add_column('provider_audit', sa.Column('_users', sa.ARRAY(sa.String()), nullable=True)) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - adjusted ### - op.drop_column('identity_audit', '_providers') + op.drop_column('provider_audit', '_users') op.add_column('client', sa.Column('collection_specific', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False)) op.drop_constraint('client_provider_id_fkey', 'client', type_='foreignkey') op.drop_column('client', 'provider_id') @@ -109,8 +109,8 @@ def downgrade(): op.drop_table('record_package') op.drop_table('package_resource') op.drop_table('archive_resource') - op.drop_table('user_provider') op.drop_table('resource') + op.drop_table('provider_user') op.drop_table('package') op.drop_table('archive') # ### end Alembic commands ### diff --git a/odp/api/routers/provider.py b/odp/api/routers/provider.py index 1a6921e..dee3e86 100644 --- a/odp/api/routers/provider.py +++ b/odp/api/routers/provider.py @@ -53,6 +53,7 @@ def output_audit_model(row) -> ProviderAuditModel: provider_id=row.ProviderAudit._id, provider_key=row.ProviderAudit._key, provider_name=row.ProviderAudit._name, + provider_users=row.ProviderAudit._users or [], ) @@ -70,6 +71,7 @@ def create_audit_record( _id=provider.id, _key=provider.key, _name=provider.name, + _users=[user.id for user in provider.users], ).save() @@ -142,6 +144,10 @@ async def create_provider( provider = Provider( key=provider_in.key, name=provider_in.name, + users=[ + Session.get(User, user_id) + for user_id in provider_in.user_ids + ], timestamp=(timestamp := datetime.now(timezone.utc)), ) provider.save() @@ -179,10 +185,14 @@ async def update_provider( if ( provider.key != provider_in.key or - provider.name != provider_in.name + provider.name != provider_in.name # todo: or user_ids != ... ): provider.key = provider_in.key provider.name = provider_in.name + provider.users = [ + Session.get(User, user_id) + for user_id in provider_in.user_ids + ] provider.timestamp = (timestamp := datetime.now(timezone.utc)) provider.save() create_audit_record(auth, provider, timestamp, AuditCommand.update) diff --git a/odp/api/routers/user.py b/odp/api/routers/user.py index 3f42040..c124e17 100644 --- a/odp/api/routers/user.py +++ b/odp/api/routers/user.py @@ -11,7 +11,7 @@ from odp.const import ODPScope from odp.const.db import IdentityCommand from odp.db import Session -from odp.db.models import IdentityAudit, Provider, Role, User +from odp.db.models import IdentityAudit, Role, User router = APIRouter() @@ -49,7 +49,6 @@ def create_audit_record( _email=user.email, _active=user.active, _roles=[role.id for role in user.roles], - _providers=[provider.id for provider in user.providers] ).save() @@ -67,7 +66,6 @@ def output_audit_model(row) -> IdentityAuditModel: user_email=row.IdentityAudit._email, user_active=row.IdentityAudit._active, user_roles=row.IdentityAudit._roles, - user_providers=row.IdentityAudit._providers, ) @@ -109,15 +107,12 @@ async def update_user( if not (user := Session.get(User, user_in.id)): raise HTTPException(HTTP_404_NOT_FOUND) + # todo: if different... user.active = user_in.active user.roles = [ Session.get(Role, role_id) for role_id in user_in.role_ids ] - user.providers = [ - Session.get(Provider, provider_id) - for provider_id in user_in.provider_ids - ] user.save() create_audit_record(auth, user, IdentityCommand.edit) diff --git a/odp/db/models/__init__.py b/odp/db/models/__init__.py index 789c269..a60210d 100644 --- a/odp/db/models/__init__.py +++ b/odp/db/models/__init__.py @@ -3,12 +3,12 @@ from .client import Client, ClientScope from .collection import Collection, CollectionAudit, CollectionTag, CollectionTagAudit from .package import Package, PackageResource -from .provider import Provider, ProviderAudit +from .provider import Provider, ProviderAudit, ProviderUser from .record import PublishedRecord, Record, RecordAudit, RecordPackage, RecordTag, RecordTagAudit from .resource import Resource from .role import Role, RoleCollection, RoleScope from .schema import Schema from .scope import Scope from .tag import Tag -from .user import IdentityAudit, User, UserProvider, UserRole +from .user import IdentityAudit, User, UserRole from .vocabulary import Vocabulary, VocabularyTerm, VocabularyTermAudit diff --git a/odp/db/models/provider.py b/odp/db/models/provider.py index 098e110..b11fece 100644 --- a/odp/db/models/provider.py +++ b/odp/db/models/provider.py @@ -1,6 +1,6 @@ import uuid -from sqlalchemy import Column, Enum, Identity, Integer, String, TIMESTAMP +from sqlalchemy import ARRAY, Column, Enum, ForeignKey, Identity, Integer, String, TIMESTAMP from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import relationship @@ -21,19 +21,35 @@ class Provider(Base): name = Column(String, nullable=False) timestamp = Column(TIMESTAMP(timezone=True), nullable=False) - # view of associated users via many-to-many user_provider relation - provider_users = relationship('UserProvider', viewonly=True) - users = association_proxy('provider_users', 'user') - # view of associated collections (one-to-many) collections = relationship('Collection', viewonly=True) # view of associated clients (one-to-many) clients = relationship('Client', viewonly=True) + # many-to-many provider_user entities are persisted by + # assigning/removing User instances to/from users + provider_users = relationship('ProviderUser', cascade='all, delete-orphan', passive_deletes=True) + users = association_proxy('provider_users', 'user', creator=lambda u: ProviderUser(user=u)) + _repr_ = 'id', 'key', 'name' +class ProviderUser(Base): + """Model of a many-to-many provider-user association, which enables + partitioning of package/resource access by (groups of) user.""" + + __tablename__ = 'provider_user' + + provider_id = Column(String, ForeignKey('provider.id', ondelete='CASCADE'), primary_key=True) + user_id = Column(String, ForeignKey('user.id', ondelete='CASCADE'), primary_key=True) + + provider = relationship('Provider', viewonly=True) + user = relationship('User') + + _repr_ = 'provider_id', 'user_id' + + class ProviderAudit(Base): """Provider audit log.""" @@ -48,3 +64,4 @@ class ProviderAudit(Base): _id = Column(String, nullable=False) _key = Column(String, nullable=False) _name = Column(String, nullable=False) + _users = Column(ARRAY(String)) diff --git a/odp/db/models/user.py b/odp/db/models/user.py index 6e7b610..f4cfdb3 100644 --- a/odp/db/models/user.py +++ b/odp/db/models/user.py @@ -27,10 +27,9 @@ class User(Base): user_roles = relationship('UserRole', cascade='all, delete-orphan', passive_deletes=True) roles = association_proxy('user_roles', 'role', creator=lambda r: UserRole(role=r)) - # many-to-many user_provider entities are persisted by - # assigning/removing Provider instances to/from providers - user_providers = relationship('UserProvider', cascade='all, delete-orphan', passive_deletes=True) - providers = association_proxy('user_providers', 'provider', creator=lambda p: UserProvider(provider=p)) + # view of associated providers via many-to-many provider_user relation + user_providers = relationship('ProviderUser', viewonly=True) + providers = association_proxy('user_providers', 'provider') _repr_ = 'id', 'email', 'name', 'active', 'verified' @@ -49,21 +48,6 @@ class UserRole(Base): _repr_ = 'user_id', 'role_id' -class UserProvider(Base): - """A user-provider association, which enables partitioning of - package/resource access by (groups of) user.""" - - __tablename__ = 'user_provider' - - user_id = Column(String, ForeignKey('user.id', ondelete='CASCADE'), primary_key=True) - provider_id = Column(String, ForeignKey('provider.id', ondelete='CASCADE'), primary_key=True) - - user = relationship('User', viewonly=True) - provider = relationship('Provider') - - _repr_ = 'user_id', 'provider_id' - - class IdentityAudit(Base): """User identity audit log.""" @@ -81,4 +65,3 @@ class IdentityAudit(Base): _email = Column(String) _active = Column(Boolean) _roles = Column(ARRAY(String)) - _providers = Column(ARRAY(String)) diff --git a/test/api/test_provider.py b/test/api/test_provider.py index 4ef59ba..aec43b4 100644 --- a/test/api/test_provider.py +++ b/test/api/test_provider.py @@ -6,7 +6,7 @@ from sqlalchemy import select from odp.const import ODPScope -from odp.db.models import Provider, ProviderAudit +from odp.db.models import Provider, ProviderAudit, ProviderUser from test import TestSession from test.api import ( assert_conflict, assert_empty_result, assert_forbidden, assert_new_timestamp, @@ -18,37 +18,37 @@ @pytest.fixture def provider_batch(): """Create and commit a batch of Provider instances, - with associated collections, users and clients.""" + with associated users, clients and collections.""" providers = [ - ProviderFactory() + ProviderFactory(users=UserFactory.create_batch(randint(0, 3))) for _ in range(randint(3, 5)) ] for provider in providers: - CollectionFactory.create_batch(randint(0, 3), provider=provider) - UserFactory.create_batch(randint(0, 3), providers=[provider]) ClientFactory.create_batch(randint(0, 3), provider=provider) + CollectionFactory.create_batch(randint(0, 3), provider=provider) + provider.user_names = {user.id: user.name for user in provider.users} + provider.client_ids = [client.id for client in provider.clients] + provider.collection_keys = {collection.id: collection.key for collection in provider.collections} + return providers def provider_build(**id): - """Build and return an uncommitted Provider instance.""" - return ProviderFactory.build(**id) - - -def collection_keys(provider): - return {collection.id: collection.key for collection in provider.collections} - - -def user_names(provider): - return {user.id: user.name for user in provider.users} - - -def client_ids(provider): - return tuple(sorted(client.id for client in provider.clients)) + """Build and return an uncommitted Provider instance. + Associated users are however committed.""" + provider = ProviderFactory.build( + **id, + users=UserFactory.create_batch(randint(0, 3)), + ) + provider.user_names = {user.id: user.name for user in provider.users} + provider.client_ids = [] + provider.collection_keys = {} + return provider def assert_db_state(providers): - """Verify that the DB provider table contains the given provider batch.""" + """Verify that the provider table contains the given provider batch, + and that the provider_user table contains the associated user references.""" result = TestSession.execute(select(Provider)).scalars().all() result.sort(key=lambda p: p.id) providers.sort(key=lambda p: p.id) @@ -59,6 +59,15 @@ def assert_db_state(providers): assert row.name == providers[n].name assert_new_timestamp(row.timestamp) + result = TestSession.execute(select(ProviderUser.provider_id, ProviderUser.user_id)).all() + result.sort(key=lambda pu: (pu.provider_id, pu.user_id)) + provider_users = [] + for provider in providers: + for user_id in provider.user_names: + provider_users += [(provider.id, user_id)] + provider_users.sort() + assert result == provider_users + def assert_audit_log(command, provider, grant_type): result = TestSession.execute(select(ProviderAudit)).scalar_one() @@ -69,6 +78,7 @@ def assert_audit_log(command, provider, grant_type): assert result._id == provider.id assert result._key == provider.key assert result._name == provider.name + assert sorted(result._users) == sorted(provider.user_names) def assert_no_audit_log(): @@ -81,9 +91,9 @@ def assert_json_result(response, json, provider): assert json['id'] == provider.id assert json['key'] == provider.key assert json['name'] == provider.name - assert json['collection_keys'] == collection_keys(provider) - assert json['user_names'] == user_names(provider) - assert tuple(sorted(json['client_ids'])) == client_ids(provider) + assert json['collection_keys'] == provider.collection_keys + assert json['user_names'] == provider.user_names + assert sorted(json['client_ids']) == sorted(provider.client_ids) assert_new_timestamp(datetime.fromisoformat(json['timestamp'])) @@ -132,15 +142,18 @@ def test_get_provider_not_found(api, provider_batch): @pytest.mark.require_scope(ODPScope.PROVIDER_ADMIN) def test_create_provider(api, provider_batch, scopes): authorized = ODPScope.PROVIDER_ADMIN in scopes - modified_provider_batch = provider_batch + [provider := provider_build()] + provider = provider_build() + r = api(scopes).post('/provider/', json=dict( key=provider.key, name=provider.name, + user_ids=list(provider.user_names), )) + if authorized: provider.id = r.json().get('id') assert_json_result(r, r.json(), provider) - assert_db_state(modified_provider_batch) + assert_db_state(provider_batch + [provider]) assert_audit_log('insert', provider, api.grant_type) else: assert_forbidden(r) @@ -150,11 +163,16 @@ def test_create_provider(api, provider_batch, scopes): def test_create_provider_conflict(api, provider_batch): scopes = [ODPScope.PROVIDER_ADMIN] - provider = provider_build(key=provider_batch[2].key) + provider = provider_build( + key=provider_batch[2].key, + ) + r = api(scopes).post('/provider/', json=dict( key=provider.key, name=provider.name, + user_ids=list(provider.user_names), )) + assert_conflict(r, 'Provider key is already in use') assert_db_state(provider_batch) assert_no_audit_log() @@ -163,17 +181,19 @@ def test_create_provider_conflict(api, provider_batch): @pytest.mark.require_scope(ODPScope.PROVIDER_ADMIN) def test_update_provider(api, provider_batch, scopes): authorized = ODPScope.PROVIDER_ADMIN in scopes - modified_provider_batch = provider_batch.copy() - modified_provider_batch[2] = (provider := provider_build( + provider = provider_build( id=provider_batch[2].id, - )) + ) + r = api(scopes).put(f'/provider/{provider.id}', json=dict( key=provider.key, name=provider.name, + user_ids=list(provider.user_names), )) + if authorized: assert_empty_result(r) - assert_db_state(modified_provider_batch) + assert_db_state(provider_batch[:2] + [provider] + provider_batch[3:]) assert_audit_log('update', provider, api.grant_type) else: assert_forbidden(r) @@ -183,11 +203,16 @@ def test_update_provider(api, provider_batch, scopes): def test_update_provider_not_found(api, provider_batch): scopes = [ODPScope.PROVIDER_ADMIN] - provider = provider_build(id=str(uuid.uuid4())) + provider = provider_build( + id=str(uuid.uuid4()), + ) + r = api(scopes).put(f'/provider/{provider.id}', json=dict( key=provider.key, name=provider.name, + user_ids=list(provider.user_names), )) + assert_not_found(r) assert_db_state(provider_batch) assert_no_audit_log() @@ -199,10 +224,13 @@ def test_update_provider_conflict(api, provider_batch): id=provider_batch[2].id, key=provider_batch[0].key, ) + r = api(scopes).put(f'/provider/{provider.id}', json=dict( key=provider.key, name=provider.name, + user_ids=list(provider.user_names), )) + assert_conflict(r, 'Provider key is already in use') assert_db_state(provider_batch) assert_no_audit_log() @@ -226,23 +254,21 @@ def has_package(request): @pytest.mark.require_scope(ODPScope.PROVIDER_ADMIN) def test_delete_provider(api, provider_batch, scopes, has_record, has_resource, has_package): authorized = ODPScope.PROVIDER_ADMIN in scopes - modified_provider_batch = provider_batch.copy() - deleted_provider = modified_provider_batch[2] - del modified_provider_batch[2] + deleted_provider = provider_batch[2] if has_record: - if collection := next((c for c in provider_batch[2].collections), None): + if collection := next((c for c in deleted_provider.collections), None): RecordFactory(collection=collection) else: has_record = False if has_resource: - ResourceFactory(provider=provider_batch[2]) + ResourceFactory(provider=deleted_provider) if has_package: - PackageFactory(provider=provider_batch[2]) + PackageFactory(provider=deleted_provider) - r = api(scopes).delete(f'/provider/{provider_batch[2].id}') + r = api(scopes).delete(f'/provider/{deleted_provider.id}') if authorized: if has_record or has_resource or has_package: @@ -254,7 +280,7 @@ def test_delete_provider(api, provider_batch, scopes, has_record, has_resource, assert_no_audit_log() else: assert_empty_result(r) - assert_db_state(modified_provider_batch) + assert_db_state(provider_batch[:2] + provider_batch[3:]) assert_audit_log('delete', deleted_provider, api.grant_type) else: assert_forbidden(r) diff --git a/test/api/test_user.py b/test/api/test_user.py index 4091014..f29926c 100644 --- a/test/api/test_user.py +++ b/test/api/test_user.py @@ -4,7 +4,7 @@ from sqlalchemy import select from odp.const import ODPScope -from odp.db.models import IdentityAudit, User +from odp.db.models import IdentityAudit, User, UserRole from test import TestSession from test.api import ( all_scopes, assert_empty_result, assert_forbidden, assert_method_not_allowed, assert_new_timestamp, @@ -15,31 +15,37 @@ @pytest.fixture def user_batch(): - """Create and commit a batch of User instances.""" - return [ - UserFactory( - roles=RoleFactory.create_batch(randint(0, 3)), - providers=ProviderFactory.create_batch(randint(0, 3)), - ) + """Create and commit a batch of User instances, with + associated roles and providers.""" + users = [ + UserFactory(roles=RoleFactory.create_batch(randint(0, 3))) for _ in range(randint(3, 5)) ] + for user in users: + ProviderFactory.create_batch(randint(0, 3), users=[user]), + user.role_ids = [role.id for role in user.roles] + user.provider_keys = {provider.id: provider.key for provider in user.providers} + return users -def role_ids(user): - return tuple(sorted(role.id for role in user.roles)) - -def provider_ids(user): - return tuple(sorted(provider.id for provider in user.providers)) - - -def provider_keys(user): - return {provider.id: provider.key for provider in user.providers} +def user_build(**attr): + """Build and return an uncommitted User instance. + Associated roles are however committed.""" + user = UserFactory.build( + **attr, + roles=RoleFactory.create_batch(randint(0, 3)), + ) + user.role_ids = [role.id for role in user.roles] + user.provider_keys = {} + return user def assert_db_state(users): - """Verify that the DB user table contains the given user batch.""" - result = TestSession.execute(select(User).where(User.id != 'odp.test.user')).scalars().all() + """Verify that the user table contains the given user batch, and + that the user_role table contains the associated role references.""" + result = TestSession.execute( + select(User).where(User.id != 'odp.test.user')).scalars().all() result.sort(key=lambda u: u.id) users.sort(key=lambda u: u.id) assert len(result) == len(users) @@ -49,11 +55,20 @@ def assert_db_state(users): assert row.email == users[n].email assert row.active == users[n].active assert row.verified == users[n].verified - assert role_ids(row) == role_ids(users[n]) - assert provider_ids(row) == provider_ids(users[n]) + assert row.picture == users[n].picture + + result = TestSession.execute( + select(UserRole.user_id, UserRole.role_id).where(UserRole.user_id != 'odp.test.user')).all() + result.sort(key=lambda ur: (ur.user_id, ur.role_id)) + user_roles = [] + for user in users: + for role_id in user.role_ids: + user_roles += [(user.id, role_id)] + user_roles.sort() + assert result == user_roles -def assert_audit_log(command, user, user_role_ids, user_provider_ids, grant_type): +def assert_audit_log(command, user, grant_type): """Verify that the identity audit table contains the given entry.""" result = TestSession.execute(select(IdentityAudit)).scalar_one() assert result.client_id == 'odp.test.client' @@ -65,8 +80,7 @@ def assert_audit_log(command, user, user_role_ids, user_provider_ids, grant_type assert result._id == user.id assert result._email == user.email assert result._active == user.active - assert tuple(sorted(result._roles)) == user_role_ids - assert tuple(sorted(result._providers)) == user_provider_ids + assert sorted(result._roles) == sorted(user.role_ids) def assert_no_audit_log(): @@ -82,8 +96,9 @@ def assert_json_result(response, json, user): assert json['email'] == user.email assert json['active'] == user.active assert json['verified'] == user.verified - assert tuple(sorted(json['role_ids'])) == role_ids(user) - assert json['provider_keys'] == provider_keys(user) + assert json['picture'] == user.picture + assert sorted(json['role_ids']) == sorted(user.role_ids) + assert json['provider_keys'] == user.provider_keys def assert_json_results(response, json, users): @@ -137,25 +152,26 @@ def test_create_user(api): @pytest.mark.require_scope(ODPScope.USER_ADMIN) def test_update_user(api, user_batch, scopes): authorized = ODPScope.USER_ADMIN in scopes - modified_user_batch = user_batch.copy() - modified_user_batch[2] = (user := UserFactory.build( + # the user update API can only modify `active` and `role_ids`; + # everything else must stay the same on the rebuilt user object + user = user_build( id=user_batch[2].id, name=user_batch[2].name, email=user_batch[2].email, verified=user_batch[2].verified, - roles=RoleFactory.create_batch(randint(0, 3)), - providers=ProviderFactory.create_batch(randint(0, 3)), - )) + picture=user_batch[2].picture, + ) + r = api(scopes).put('/user/', json=dict( id=user.id, active=user.active, - role_ids=role_ids(user), - provider_ids=provider_ids(user), + role_ids=user.role_ids, )) + if authorized: assert_empty_result(r) - assert_db_state(modified_user_batch) - assert_audit_log('edit', user, role_ids(user), provider_ids(user), api.grant_type) + assert_db_state(user_batch[:2] + [user] + user_batch[3:]) + assert_audit_log('edit', user, api.grant_type) else: assert_forbidden(r) assert_db_state(user_batch) @@ -164,19 +180,17 @@ def test_update_user(api, user_batch, scopes): def test_update_user_not_found(api, user_batch): scopes = [ODPScope.USER_ADMIN] - user = UserFactory.build( + user = user_build( id='foo', name=user_batch[2].name, email=user_batch[2].email, verified=user_batch[2].verified, - roles=RoleFactory.create_batch(randint(0, 3)), - providers=ProviderFactory.create_batch(randint(0, 3)), + picture=user_batch[2].picture, ) r = api(scopes).put('/user/', json=dict( id=user.id, active=user.active, - role_ids=role_ids(user), - provider_ids=provider_ids(user), + role_ids=user.role_ids, )) assert_not_found(r) assert_db_state(user_batch) @@ -191,18 +205,14 @@ def has_tag_instance(request): @pytest.mark.require_scope(ODPScope.USER_ADMIN) def test_delete_user(api, user_batch, scopes, has_tag_instance): authorized = ODPScope.USER_ADMIN in scopes - modified_user_batch = user_batch.copy() - deleted_user = modified_user_batch[2] - deleted_user_role_ids = role_ids(deleted_user) - deleted_user_provider_ids = provider_ids(deleted_user) - del modified_user_batch[2] + deleted_user = user_batch[2] if has_tag_instance in ('collection', 'both'): - CollectionTagFactory(user=user_batch[2]) + CollectionTagFactory(user=deleted_user) if has_tag_instance in ('record', 'both'): - RecordTagFactory(user=user_batch[2]) + RecordTagFactory(user=deleted_user) - r = api(scopes).delete(f'/user/{user_batch[2].id}') + r = api(scopes).delete(f'/user/{deleted_user.id}') if authorized: if has_tag_instance in ('collection', 'record', 'both'): @@ -211,8 +221,8 @@ def test_delete_user(api, user_batch, scopes, has_tag_instance): assert_no_audit_log() else: assert_empty_result(r) - assert_db_state(modified_user_batch) - assert_audit_log('delete', deleted_user, deleted_user_role_ids, deleted_user_provider_ids, api.grant_type) + assert_db_state(user_batch[:2] + user_batch[3:]) + assert_audit_log('delete', deleted_user, api.grant_type) else: assert_forbidden(r) assert_db_state(user_batch) diff --git a/test/factories.py b/test/factories.py index 908d8b8..5c2d610 100644 --- a/test/factories.py +++ b/test/factories.py @@ -183,6 +183,14 @@ class Meta: name = factory.Sequence(lambda n: f'{fake.company()}.{n}') timestamp = factory.LazyFunction(lambda: datetime.now(timezone.utc)) + @factory.post_generation + def users(obj, create, users): + if users: + for user in users: + obj.users.append(user) + if create: + FactorySession.commit() + class PackageFactory(ODPModelFactory): class Meta: @@ -311,14 +319,6 @@ class Meta: verified = factory.LazyFunction(lambda: randint(0, 1)) picture = factory.Faker('image_url') - @factory.post_generation - def providers(obj, create, providers): - if providers: - for provider in providers: - obj.providers.append(provider) - if create: - FactorySession.commit() - @factory.post_generation def roles(obj, create, roles): if roles: diff --git a/test/test_db.py b/test/test_db.py index 38d1c98..3ef889e 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -14,6 +14,7 @@ CollectionTag, Package, Provider, + ProviderUser, Record, RecordTag, Resource, @@ -24,7 +25,6 @@ Scope, Tag, User, - UserProvider, UserRole, Vocabulary, VocabularyTerm, @@ -183,6 +183,14 @@ def test_create_provider(): assert (result.id, result.key, result.name) == (provider.id, provider.key, provider.name) +def test_create_provider_with_users(): + users = UserFactory.create_batch(5) + provider = ProviderFactory(users=users) + result = TestSession.execute(select(ProviderUser)).scalars() + assert sorted_tuples((row.provider_id, row.user_id) for row in result) \ + == sorted_tuples((provider.id, user.id) for user in users) + + def test_create_record(): record = RecordFactory(is_child_record=randint(0, 1)) result = TestSession.execute( @@ -319,14 +327,6 @@ def test_create_user(): == (user.id, user.name, user.email, user.active, user.verified) -def test_create_user_with_providers(): - providers = ProviderFactory.create_batch(5) - user = UserFactory(providers=providers) - result = TestSession.execute(select(UserProvider)).scalars() - assert sorted_tuples((row.user_id, row.provider_id) for row in result) \ - == sorted_tuples((user.id, provider.id) for provider in providers) - - def test_create_user_with_roles(): roles = RoleFactory.create_batch(5) user = UserFactory(roles=roles)