Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pydantic): use pydantic serialization #2139

Merged
merged 12 commits into from
Sep 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""
import collections
import itertools
import json
import secrets
import typing as t

Expand All @@ -16,6 +15,7 @@
from sqlalchemy.engine import Connection # type: ignore

from antarest.study.css4_colors import COLOR_NAMES
from antarest.core.serialization import from_json, to_json

# revision identifiers, used by Alembic.
revision = "dae93f1d9110"
Expand All @@ -34,7 +34,7 @@ def _avoid_duplicates(tags: t.Iterable[str]) -> t.Sequence[str]:
def _load_patch_obj(patch: t.Optional[str]) -> t.MutableMapping[str, t.Any]:
"""Load the patch object from the `patch` field in the `study_additional_data` table."""

obj: t.MutableMapping[str, t.Any] = json.loads(patch or "{}")
obj: t.MutableMapping[str, t.Any] = from_json(patch or "{}")
obj["study"] = obj.get("study") or {}
obj["study"]["tags"] = _avoid_duplicates(obj["study"].get("tags") or [])
return obj
Expand Down Expand Up @@ -113,7 +113,7 @@ def downgrade() -> None:
objects_by_ids[study_id] = obj

# Updating objects in the `study_additional_data` table
bulk_patches = [{"study_id": id_, "patch": json.dumps(obj)} for id_, obj in objects_by_ids.items()]
bulk_patches = [{"study_id": id_, "patch": to_json(obj)} for id_, obj in objects_by_ids.items()]
if bulk_patches:
sql = sa.text("UPDATE study_additional_data SET patch = :patch WHERE study_id = :study_id")
connexion.execute(sql, *bulk_patches)
Expand Down
4 changes: 2 additions & 2 deletions antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import json
import logging
from typing import List, Optional

Expand All @@ -19,6 +18,7 @@

from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
from antarest.core.serialization import from_json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +49,7 @@ def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
logger.info(f"Trying to retrieve cache key {id}")
if result is not None:
logger.info(f"Cache key {id} found")
json_result = json.loads(result)
json_result = from_json(result)
redis_element = RedisCacheElement(duration=json_result["duration"], data=json_result["data"])
self.redis.expire(
redis_key,
Expand Down
6 changes: 3 additions & 3 deletions antarest/core/configdata/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
#
# This file is part of the Antares project.

import json
from operator import and_
from typing import Optional

from antarest.core.configdata.model import ConfigData
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.model import JSON
from antarest.core.serialization import from_json, to_json_string
from antarest.core.utils.fastapi_sqlalchemy import db


Expand All @@ -43,14 +43,14 @@ def get(self, key: str, owner: Optional[int] = None) -> Optional[ConfigData]:
def get_json(self, key: str, owner: Optional[int] = None) -> Optional[JSON]:
configdata = self.get(key, owner)
if configdata:
data: JSON = json.loads(configdata.value)
data: JSON = from_json(configdata.value)
return data
return None

def put_json(self, key: str, data: JSON, owner: Optional[int] = None) -> None:
configdata = ConfigData(
key=key,
value=json.dumps(data),
value=to_json_string(data),
owner=owner or DEFAULT_ADMIN_USER.id,
)
configdata = db.session.merge(configdata)
Expand Down
34 changes: 34 additions & 0 deletions antarest/core/serialization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
#
# See AUTHORS.txt
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
import typing as t

import pydantic

ADAPTER: pydantic.TypeAdapter[t.Any] = pydantic.TypeAdapter(
type=t.Any, config=pydantic.config.ConfigDict(ser_json_inf_nan="constants")
) # ser_json_inf_nan="constants" means infinity and NaN values will be serialized as `Infinity` and `NaN`.


# These utility functions allow to serialize with pydantic instead of using the built-in python "json" library.
# Since pydantic v2 is written in RUST it's way faster.


def from_json(data: t.Union[str, bytes, bytearray]) -> t.Dict[str, t.Any]:
return ADAPTER.validate_json(data) # type: ignore


def to_json(data: t.Any, indent: t.Optional[int] = None) -> bytes:
return ADAPTER.dump_json(data, indent=indent)


def to_json_string(data: t.Any, indent: t.Optional[int] = None) -> str:
return to_json(data, indent=indent).decode("utf-8")
6 changes: 3 additions & 3 deletions antarest/eventbus/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
# This file is part of the Antares project.

import dataclasses
import json
import logging
from enum import Enum
from http import HTTPStatus
from typing import List, Optional

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from starlette.websockets import WebSocket, WebSocketDisconnect

Expand All @@ -27,6 +26,7 @@
from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser
from antarest.core.model import PermissionInfo, StudyPermissionType
from antarest.core.permissions import check_permission
from antarest.core.serialization import to_json_string
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.login.auth import Auth

Expand Down Expand Up @@ -99,7 +99,7 @@ async def send_event_to_ws(event: Event) -> None:
event_data = event.model_dump()
del event_data["permissions"]
del event_data["channel"]
await manager.broadcast(json.dumps(event_data), event.permissions, event.channel)
await manager.broadcast(to_json_string(event_data), event.permissions, event.channel)

@app_ctxt.api_root.websocket("/ws")
async def connect(
Expand Down
6 changes: 3 additions & 3 deletions antarest/launcher/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# This file is part of the Antares project.

import enum
import json
import typing as t
from datetime import datetime

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.orm import relationship # type: ignore

from antarest.core.persistence import Base
from antarest.core.serialization import from_json
from antarest.login.model import Identity, UserInfo
from antarest.study.business.all_optional_meta import camel_case_model

Expand Down Expand Up @@ -54,7 +54,7 @@ def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO
"""
if params is None:
return cls()
return cls.model_validate(json.loads(params))
return cls.model_validate(from_json(params))


class LogType(str, enum.Enum):
Expand Down
6 changes: 3 additions & 3 deletions antarest/login/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import json
import logging
from datetime import timedelta
from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union
Expand All @@ -22,6 +21,7 @@

from antarest.core.config import Config
from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser
from antarest.core.serialization import from_json
from antarest.fastapi_jwt_auth import AuthJWT

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,14 +66,14 @@ def get_current_user(self, auth_jwt: AuthJWT = Depends()) -> JWTUser:

auth_jwt.jwt_required()

user = JWTUser.model_validate(json.loads(auth_jwt.get_jwt_subject()))
user = JWTUser.model_validate(from_json(auth_jwt.get_jwt_subject()))
return user

@staticmethod
def get_user_from_token(token: str, jwt_manager: AuthJWT) -> Optional[JWTUser]:
try:
token_data = jwt_manager._verified_token(token)
return JWTUser.model_validate(json.loads(token_data["sub"]))
return JWTUser.model_validate(from_json(token_data["sub"]))
except Exception as e:
logger.debug("Failed to retrieve user from token", exc_info=e)
return None
Expand Down
5 changes: 2 additions & 3 deletions antarest/login/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
#
# This file is part of the Antares project.

import json
from http import HTTPStatus
from typing import Any, Optional

from fastapi import APIRouter, FastAPI
from starlette.requests import Request
from starlette.responses import JSONResponse

from antarest.core.application import AppBuildContext
from antarest.core.config import Config
from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus
from antarest.core.serialization import from_json
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.fastapi_jwt_auth.exceptions import AuthJWTException
Expand Down Expand Up @@ -78,7 +77,7 @@ def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> Any:

@AuthJWT.token_in_denylist_loader # type: ignore
def check_if_token_is_revoked(decrypted_token: Any) -> bool:
subject = json.loads(decrypted_token["sub"])
subject = from_json(decrypted_token["sub"])
user_id = subject["id"]
token_type = subject["type"]
with db():
Expand Down
4 changes: 2 additions & 2 deletions antarest/login/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import json
import logging
from datetime import timedelta
from typing import Any, List, Optional, Union
Expand All @@ -23,6 +22,7 @@
from antarest.core.jwt import JWTGroup, JWTUser
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.roles import RoleType
from antarest.core.serialization import from_json
from antarest.core.utils.web import APITag
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.login.auth import Auth
Expand Down Expand Up @@ -103,7 +103,7 @@ def login(
)
def refresh(jwt_manager: AuthJWT = Depends()) -> Any:
jwt_manager.jwt_refresh_token_required()
identity = json.loads(jwt_manager.get_jwt_subject())
identity = from_json(jwt_manager.get_jwt_subject())
logger.debug(f"Refreshing access token for {identity['id']}")
user = service.get_jwt(identity["id"])
if user:
Expand Down
2 changes: 1 addition & 1 deletion antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
from antarest.login.auth import Auth, JwtSettings
from antarest.login.model import init_admin_user
from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector
from antarest.service_creator import SESSION_ARGS, Module, create_services, init_db_engine
from antarest.singleton_services import start_all_services
from antarest.study.storage.auto_archive_service import AutoArchiveService
from antarest.study.storage.rawstudy.watcher import Watcher
from antarest.tools.admin_lib import clean_locks
from antarest.utils import SESSION_ARGS, Module, create_services, init_db_engine

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import contextlib
import io
import json
import logging
import tempfile
import typing as t
Expand All @@ -31,6 +30,7 @@
from antarest.core.filetransfer.service import FileTransferManager
from antarest.core.jwt import JWTUser
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.serialization import from_json
from antarest.core.tasks.model import TaskResult, TaskType
from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier
from antarest.core.utils.fastapi_sqlalchemy import db
Expand Down Expand Up @@ -263,7 +263,7 @@ def _file_importation(self, file: bytes, *, is_json: bool = False) -> str:
A SHA256 hash that identifies the imported matrix.
"""
if is_json:
obj = json.loads(file)
obj = from_json(file)
content = MatrixContent(**obj)
return self.create(content.data)
# noinspection PyTypeChecker
Expand Down
Loading
Loading