Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pydantic): allow str fields to be populated by int #2166

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions antarest/core/cache/business/local_chache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@
import time
from typing import Dict, List, Optional

from pydantic import BaseModel

from antarest.core.config import CacheConfig
from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
from antarest.core.serialization import AntaresBaseModel

logger = logging.getLogger(__name__)


class LocalCacheElement(BaseModel):
class LocalCacheElement(AntaresBaseModel):
timeout: int
duration: int
data: JSON
Expand Down
5 changes: 2 additions & 3 deletions antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@
import logging
from typing import List, Optional

from pydantic import BaseModel
from redis.client import Redis

from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
from antarest.core.serialization import from_json
from antarest.core.serialization import AntaresBaseModel, from_json

logger = logging.getLogger(__name__)


class RedisCacheElement(BaseModel):
class RedisCacheElement(AntaresBaseModel):
duration: int
data: JSON

Expand Down
4 changes: 2 additions & 2 deletions antarest/core/configdata/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel
from sqlalchemy import Column, Integer, String # type: ignore

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel


class ConfigDataDTO(BaseModel):
class ConfigDataDTO(AntaresBaseModel):
key: str
value: Optional[str]

Expand Down
4 changes: 2 additions & 2 deletions antarest/core/core_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from typing import Any

from fastapi import APIRouter
from pydantic import BaseModel

from antarest.core.config import Config
from antarest.core.serialization import AntaresBaseModel
from antarest.core.utils.web import APITag
from antarest.core.version_info import VersionInfoDTO, get_commit_id, get_dependencies


class StatusDTO(BaseModel):
class StatusDTO(AntaresBaseModel):
status: str


Expand Down
6 changes: 3 additions & 3 deletions antarest/core/filetransfer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from http.client import HTTPException
from typing import Optional

from pydantic import BaseModel
from sqlalchemy import Boolean, Column, DateTime, Integer, String # type: ignore

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel


class FileDownloadNotFound(HTTPException):
Expand All @@ -37,7 +37,7 @@ def __init__(self) -> None:
)


class FileDownloadDTO(BaseModel):
class FileDownloadDTO(AntaresBaseModel):
id: str
name: str
filename: str
Expand All @@ -47,7 +47,7 @@ class FileDownloadDTO(BaseModel):
error_message: str = ""


class FileDownloadTaskDTO(BaseModel):
class FileDownloadTaskDTO(AntaresBaseModel):
file: FileDownloadDTO
task: str

Expand Down
5 changes: 2 additions & 3 deletions antarest/core/interfaces/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
from enum import Enum
from typing import Any, Awaitable, Callable, List, Optional

from pydantic import BaseModel

from antarest.core.model import PermissionInfo
from antarest.core.serialization import AntaresBaseModel


class EventType(str, Enum):
Expand Down Expand Up @@ -56,7 +55,7 @@ class EventChannelDirectory:
STUDY_GENERATION = "GENERATION_TASK/"


class Event(BaseModel):
class Event(AntaresBaseModel):
type: EventType
payload: Any
permissions: PermissionInfo
Expand Down
7 changes: 3 additions & 4 deletions antarest/core/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@

from typing import List, Union

from pydantic import BaseModel

from antarest.core.roles import RoleType
from antarest.core.serialization import AntaresBaseModel
from antarest.login.model import ADMIN_ID, Group, Identity


class JWTGroup(BaseModel):
class JWTGroup(AntaresBaseModel):
"""
Sub JWT domain with groups data belongs to user
"""
Expand All @@ -28,7 +27,7 @@ class JWTGroup(BaseModel):
role: RoleType


class JWTUser(BaseModel):
class JWTUser(AntaresBaseModel):
"""
JWT domain with user data.
"""
Expand Down
4 changes: 2 additions & 2 deletions antarest/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import BaseModel
from antarest.core.serialization import AntaresBaseModel

if TYPE_CHECKING:
# These dependencies are only used for type checking with mypy.
Expand Down Expand Up @@ -43,7 +43,7 @@ class StudyPermissionType(str, enum.Enum):
MANAGE_PERMISSIONS = "MANAGE_PERMISSIONS"


class PermissionInfo(BaseModel):
class PermissionInfo(AntaresBaseModel):
owner: Optional[int] = None
groups: List[str] = []
public_mode: PublicMode = PublicMode.NONE
Expand Down
15 changes: 15 additions & 0 deletions antarest/core/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ def to_json(data: t.Any, indent: t.Optional[int] = None) -> bytes:

def to_json_string(data: t.Any, indent: t.Optional[int] = None) -> str:
return to_json(data, indent=indent).decode("utf-8")


class AntaresBaseModel(pydantic.BaseModel):
"""
Due to pydantic migration from v1 to v2, we can have this issue:

class A(BaseModel):
a: str

A(a=2) raises ValidationError as we give an int instead of a str

To avoid this issue we created our own BaseModel class that inherits from BaseModel and allows such object creation.
"""

model_config = pydantic.config.ConfigDict(coerce_numbers_to_str=True)
14 changes: 7 additions & 7 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from datetime import datetime
from enum import Enum

from pydantic import BaseModel
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import relationship, sessionmaker # type: ignore

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel

if t.TYPE_CHECKING:
# avoid circular import
Expand Down Expand Up @@ -57,30 +57,30 @@ def is_final(self) -> bool:
]


class TaskResult(BaseModel, extra="forbid"):
class TaskResult(AntaresBaseModel, extra="forbid"):
success: bool
message: str
# Can be used to store json serialized result
return_value: t.Optional[str] = None


class TaskLogDTO(BaseModel, extra="forbid"):
class TaskLogDTO(AntaresBaseModel, extra="forbid"):
id: str
message: str


class CustomTaskEventMessages(BaseModel, extra="forbid"):
class CustomTaskEventMessages(AntaresBaseModel, extra="forbid"):
start: str
running: str
end: str


class TaskEventPayload(BaseModel, extra="forbid"):
class TaskEventPayload(AntaresBaseModel, extra="forbid"):
id: str
message: str


class TaskDTO(BaseModel, extra="forbid"):
class TaskDTO(AntaresBaseModel, extra="forbid"):
id: str
name: str
owner: t.Optional[int] = None
Expand All @@ -93,7 +93,7 @@ class TaskDTO(BaseModel, extra="forbid"):
ref_id: t.Optional[str] = None


class TaskListFilter(BaseModel, extra="forbid"):
class TaskListFilter(AntaresBaseModel, extra="forbid"):
status: t.List[TaskStatus] = []
name: t.Optional[str] = None
type: t.List[TaskType] = []
Expand Down
2 changes: 2 additions & 0 deletions antarest/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.

__all__ = "AntaresBaseModel"
13 changes: 0 additions & 13 deletions antarest/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
from pathlib import Path

import py7zr
import redis
from fastapi import HTTPException

from antarest.core.config import RedisConfig
from antarest.core.exceptions import ShouldNotHappenException

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -131,17 +129,6 @@ def get_local_path() -> Path:
return filepath


def new_redis_instance(config: RedisConfig) -> redis.Redis: # type: ignore
redis_client = redis.Redis(
host=config.host,
port=config.port,
password=config.password,
db=0,
retry_on_error=[redis.ConnectionError, redis.TimeoutError], # type: ignore
)
return redis_client # type: ignore


class StopWatch:
def __init__(self) -> None:
self.current_time: float = time.time()
Expand Down
4 changes: 2 additions & 2 deletions antarest/core/version_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from pathlib import Path
from typing import Dict

from pydantic import BaseModel
from antarest.core.serialization import AntaresBaseModel


class VersionInfoDTO(BaseModel):
class VersionInfoDTO(AntaresBaseModel):
name: str = "AntaREST"
version: str
gitcommit: str
Expand Down
5 changes: 2 additions & 3 deletions antarest/eventbus/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import List, Optional

from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from starlette.websockets import WebSocket, WebSocketDisconnect

from antarest.core.application import AppBuildContext
Expand All @@ -26,7 +25,7 @@
from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser
from antarest.core.model import PermissionInfo, StudyPermissionType
from antarest.core.permissions import check_permission
from antarest.core.serialization import to_json_string
from antarest.core.serialization import AntaresBaseModel, to_json_string
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.login.auth import Auth

Expand All @@ -38,7 +37,7 @@ class WebsocketMessageAction(str, Enum):
UNSUBSCRIBE = "UNSUBSCRIBE"


class WebsocketMessage(BaseModel):
class WebsocketMessage(AntaresBaseModel):
action: WebsocketMessageAction
payload: str

Expand Down
4 changes: 2 additions & 2 deletions antarest/front.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
from typing import Any, Optional, Sequence

from fastapi import FastAPI
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import FileResponse
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp

from antarest.core.serialization import AntaresBaseModel
from antarest.core.utils.string import to_camel_case


Expand Down Expand Up @@ -77,7 +77,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
return await call_next(request)


class BackEndConfig(BaseModel):
class BackEndConfig(AntaresBaseModel):
"""
Configuration about backend URLs served to the frontend.
"""
Expand Down
4 changes: 2 additions & 2 deletions antarest/launcher/adapters/log_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import re
import typing as t

from pydantic import BaseModel
from antarest.core.serialization import AntaresBaseModel

_SearchFunc = t.Callable[[str], t.Optional[t.Match[str]]]

Expand Down Expand Up @@ -63,7 +63,7 @@
)


class LaunchProgressDTO(BaseModel):
class LaunchProgressDTO(AntaresBaseModel):
"""
Measure the progress of a study simulation.

Expand Down
Loading
Loading