diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index e257ff261..78ee5f2b8 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -33,7 +33,8 @@ def register(background_tasks: BackgroundTasks, response: Response, user_schema: name=f"{user.email}'s Organisation", email=user.email ) - user_org = organisation_service.create(db=db, schema=org, user=user) + organisation_service.create(db=db, schema=org, user=user) + user_organizations = organisation_service.retrieve_user_organizations(user, db) # Create access and refresh tokens access_token = user_service.create_access_token(user_id=user.id) @@ -61,7 +62,8 @@ def register(background_tasks: BackgroundTasks, response: Response, user_schema: 'user': jsonable_encoder( user, exclude=['password', 'is_deleted', 'is_verified', 'updated_at'] - ) + ), + 'organizations': user_organizations } ) @@ -121,6 +123,7 @@ def login(login_request: LoginRequest, db: Session = Depends(get_db)): user = user_service.authenticate_user( db=db, email=login_request.email, password=login_request.password ) + user_organizations = organisation_service.retrieve_user_organizations(user, db) # Generate access and refresh tokens access_token = user_service.create_access_token(user_id=user.id) @@ -134,7 +137,8 @@ def login(login_request: LoginRequest, db: Session = Depends(get_db)): 'user': jsonable_encoder( user, exclude=['password', 'is_deleted', 'is_verified', 'updated_at'] - ) + ), + 'organizations': user_organizations } ) @@ -287,6 +291,7 @@ def request_magic_link( @auth.post("/magic-link/verify") async def verify_magic_link(token_schema: Token, db: Session = Depends(get_db)): user, access_token = AuthService.verify_magic_token(token_schema.access_token, db) + user_organizations = organisation_service.retrieve_user_organizations(user, db) refresh_token = user_service.create_refresh_token(user_id=user.id) @@ -298,7 +303,8 @@ async def verify_magic_link(token_schema: Token, db: Session = Depends(get_db)): 'user': jsonable_encoder( user, exclude=['password', 'is_deleted', 'is_verified', 'updated_at'] - ) + ), + 'organizations': user_organizations } ) diff --git a/api/v1/schemas/organisation.py b/api/v1/schemas/organisation.py index fe8767800..b0166475b 100644 --- a/api/v1/schemas/organisation.py +++ b/api/v1/schemas/organisation.py @@ -1,6 +1,6 @@ from datetime import datetime from typing import Dict, List -from pydantic import BaseModel, EmailStr, field_validator +from pydantic import BaseModel, EmailStr, field_validator, ConfigDict from typing import Optional from api.utils.success_response import success_response @@ -65,3 +65,19 @@ class PaginatedOrgUsers(BaseModel): success: bool message: str data: List[Dict] + +class OrganisationData(BaseModel): + """Base organisation schema""" + id: str + created_at: datetime + updated_at: datetime + name: str + email: Optional[EmailStr] = None + industry: Optional[str] = None + type: Optional[str] = None + country: Optional[str] = None + state: Optional[str] = None + address: Optional[str] = None + description: Optional[str] = None + + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/api/v1/schemas/user.py b/api/v1/schemas/user.py index df4ff98b6..696aa993c 100644 --- a/api/v1/schemas/user.py +++ b/api/v1/schemas/user.py @@ -136,7 +136,7 @@ class EmailRequest(BaseModel): class Token(BaseModel): access_token: str - token_type: str + token_type: str = None class TokenData(BaseModel): diff --git a/api/v1/services/organisation.py b/api/v1/services/organisation.py index 8d0bf0cb4..e0e0873d6 100644 --- a/api/v1/services/organisation.py +++ b/api/v1/services/organisation.py @@ -1,8 +1,8 @@ import csv from io import StringIO import logging -from typing import Any, Optional -from fastapi import HTTPException +from typing import Any, Optional, Annotated +from fastapi import HTTPException, Depends, status from sqlalchemy.orm import Session from fastapi import HTTPException, status from sqlalchemy import select @@ -18,8 +18,10 @@ from api.v1.schemas.organisation import ( CreateUpdateOrganisation, AddUpdateOrganisationRole, - RemoveUserFromOrganisation + RemoveUserFromOrganisation, + OrganisationData ) +from api.db.database import get_db class OrganisationService(Service): @@ -299,6 +301,25 @@ def export_organisation_members(self, db: Session, org_id: str): csv_file.seek(0) return csv_file + + def retrieve_user_organizations(self, user: User, + db: Annotated[Session, Depends(get_db)]): + """ + Retrieves all organizations a user belongs to. + + Args: + user: the user to retrieve the organizations + """ + user_organisations = db.query(Organisation).join( + user_organisation_association, + user_organisation_association.c.user_id == user.id + ).filter( + user_organisation_association.c.user_id == user.id + ).all() + + if user_organisations: + return [OrganisationData.model_validate(org, from_attributes=True) for org in user_organisations] + return [{"name": '', "id": '', "description": '', "created_at": "", "updated_at": ''}] organisation_service = OrganisationService()