Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pydantic): use pydantic serialization #2139

Merged
merged 12 commits into from
Sep 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""
import collections
import itertools
import json
import secrets
import typing as t

Expand All @@ -16,6 +15,7 @@
from sqlalchemy.engine import Connection # type: ignore

from antarest.study.css4_colors import COLOR_NAMES
from antarest.utils import from_json, to_json

# revision identifiers, used by Alembic.
revision = "dae93f1d9110"
Expand All @@ -34,7 +34,7 @@ def _avoid_duplicates(tags: t.Iterable[str]) -> t.Sequence[str]:
def _load_patch_obj(patch: t.Optional[str]) -> t.MutableMapping[str, t.Any]:
"""Load the patch object from the `patch` field in the `study_additional_data` table."""

obj: t.MutableMapping[str, t.Any] = json.loads(patch or "{}")
obj: t.MutableMapping[str, t.Any] = from_json(patch or "{}")
obj["study"] = obj.get("study") or {}
obj["study"]["tags"] = _avoid_duplicates(obj["study"].get("tags") or [])
return obj
Expand Down Expand Up @@ -113,7 +113,7 @@ def downgrade() -> None:
objects_by_ids[study_id] = obj

# Updating objects in the `study_additional_data` table
bulk_patches = [{"study_id": id_, "patch": json.dumps(obj)} for id_, obj in objects_by_ids.items()]
bulk_patches = [{"study_id": id_, "patch": to_json(obj)} for id_, obj in objects_by_ids.items()]
if bulk_patches:
sql = sa.text("UPDATE study_additional_data SET patch = :patch WHERE study_id = :study_id")
connexion.execute(sql, *bulk_patches)
Expand Down
4 changes: 2 additions & 2 deletions antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import json
import logging
from typing import List, Optional

Expand All @@ -19,6 +18,7 @@

from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
from antarest.utils import from_json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +49,7 @@ def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
logger.info(f"Trying to retrieve cache key {id}")
if result is not None:
logger.info(f"Cache key {id} found")
json_result = json.loads(result)
json_result = from_json(result)
redis_element = RedisCacheElement(duration=json_result["duration"], data=json_result["data"])
self.redis.expire(
redis_key,
Expand Down
6 changes: 3 additions & 3 deletions antarest/core/configdata/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
#
# This file is part of the Antares project.

import json
from operator import and_
from typing import Optional

from antarest.core.configdata.model import ConfigData
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.model import JSON
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.utils import from_json, to_json


class ConfigDataRepository:
Expand All @@ -43,14 +43,14 @@ def get(self, key: str, owner: Optional[int] = None) -> Optional[ConfigData]:
def get_json(self, key: str, owner: Optional[int] = None) -> Optional[JSON]:
configdata = self.get(key, owner)
if configdata:
data: JSON = json.loads(configdata.value)
data: JSON = from_json(configdata.value)
return data
return None

def put_json(self, key: str, data: JSON, owner: Optional[int] = None) -> None:
configdata = ConfigData(
key=key,
value=json.dumps(data),
value=to_json(data).decode("utf-8"),
MartinBelthle marked this conversation as resolved.
Show resolved Hide resolved
owner=owner or DEFAULT_ADMIN_USER.id,
)
configdata = db.session.merge(configdata)
Expand Down
3 changes: 2 additions & 1 deletion antarest/eventbus/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from antarest.core.permissions import check_permission
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.login.auth import Auth
from antarest.utils import to_json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,7 +100,7 @@ async def send_event_to_ws(event: Event) -> None:
event_data = event.model_dump()
del event_data["permissions"]
del event_data["channel"]
await manager.broadcast(json.dumps(event_data), event.permissions, event.channel)
await manager.broadcast(to_json(event_data).decode("utf-8"), event.permissions, event.channel)
MartinBelthle marked this conversation as resolved.
Show resolved Hide resolved

@app_ctxt.api_root.websocket("/ws")
async def connect(
Expand Down
6 changes: 3 additions & 3 deletions antarest/launcher/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
# This file is part of the Antares project.

import enum
import json
import typing as t
from datetime import datetime

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.orm import relationship # type: ignore

from antarest.core.persistence import Base
from antarest.login.model import Identity, UserInfo
from antarest.study.business.all_optional_meta import camel_case_model
from antarest.utils import from_json


class XpansionParametersDTO(BaseModel):
Expand Down Expand Up @@ -54,7 +54,7 @@ def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO
"""
if params is None:
return cls()
return cls.model_validate(json.loads(params))
return cls.model_validate(from_json(params))


class LogType(str, enum.Enum):
Expand Down
6 changes: 3 additions & 3 deletions antarest/login/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import json
import logging
from datetime import timedelta
from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union
Expand All @@ -23,6 +22,7 @@
from antarest.core.config import Config
from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.utils import from_json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,14 +66,14 @@ def get_current_user(self, auth_jwt: AuthJWT = Depends()) -> JWTUser:

auth_jwt.jwt_required()

user = JWTUser.model_validate(json.loads(auth_jwt.get_jwt_subject()))
user = JWTUser.model_validate(from_json(auth_jwt.get_jwt_subject()))
return user

@staticmethod
def get_user_from_token(token: str, jwt_manager: AuthJWT) -> Optional[JWTUser]:
try:
token_data = jwt_manager._verified_token(token)
return JWTUser.model_validate(json.loads(token_data["sub"]))
return JWTUser.model_validate(from_json(token_data["sub"]))
except Exception as e:
logger.debug("Failed to retrieve user from token", exc_info=e)
return None
Expand Down
5 changes: 2 additions & 3 deletions antarest/login/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
#
# This file is part of the Antares project.

import json
from http import HTTPStatus
from typing import Any, Optional

from fastapi import APIRouter, FastAPI
from starlette.requests import Request
from starlette.responses import JSONResponse

Expand All @@ -28,6 +26,7 @@
from antarest.login.repository import BotRepository, GroupRepository, RoleRepository, UserLdapRepository, UserRepository
from antarest.login.service import LoginService
from antarest.login.web import create_login_api
from antarest.utils import from_json


def build_login(
Expand Down Expand Up @@ -78,7 +77,7 @@ def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> Any:

@AuthJWT.token_in_denylist_loader # type: ignore
def check_if_token_is_revoked(decrypted_token: Any) -> bool:
subject = json.loads(decrypted_token["sub"])
subject = from_json(decrypted_token["sub"])
user_id = subject["id"]
token_type = subject["type"]
with db():
Expand Down
4 changes: 2 additions & 2 deletions antarest/login/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import json
import logging
from datetime import timedelta
from typing import Any, List, Optional, Union
Expand Down Expand Up @@ -42,6 +41,7 @@
UserInfo,
)
from antarest.login.service import LoginService
from antarest.utils import from_json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,7 +103,7 @@ def login(
)
def refresh(jwt_manager: AuthJWT = Depends()) -> Any:
jwt_manager.jwt_refresh_token_required()
identity = json.loads(jwt_manager.get_jwt_subject())
identity = from_json(jwt_manager.get_jwt_subject())
logger.debug(f"Refreshing access token for {identity['id']}")
user = service.get_jwt(identity["id"])
if user:
Expand Down
2 changes: 1 addition & 1 deletion antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
from antarest.login.auth import Auth, JwtSettings
from antarest.login.model import init_admin_user
from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector
from antarest.service_creator import SESSION_ARGS, Module, create_services, init_db_engine
from antarest.singleton_services import start_all_services
from antarest.study.storage.auto_archive_service import AutoArchiveService
from antarest.study.storage.rawstudy.watcher import Watcher
from antarest.tools.admin_lib import clean_locks
from antarest.utils import SESSION_ARGS, Module, create_services, init_db_engine

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import contextlib
import io
import json
import logging
import tempfile
import typing as t
Expand Down Expand Up @@ -49,6 +48,7 @@
MatrixInfoDTO,
)
from antarest.matrixstore.repository import MatrixContentRepository, MatrixDataSetRepository, MatrixRepository
from antarest.utils import from_json

# List of files to exclude from ZIP archives
EXCLUDED_FILES = {
Expand Down Expand Up @@ -263,7 +263,7 @@ def _file_importation(self, file: bytes, *, is_json: bool = False) -> str:
A SHA256 hash that identifies the imported matrix.
"""
if is_json:
obj = json.loads(file)
obj = from_json(file)
content = MatrixContent(**obj)
return self.create(content.data)
# noinspection PyTypeChecker
Expand Down
Loading
Loading