Skip to content

Commit

Permalink
Migrate to Pydantic V2 (#1495)
Browse files Browse the repository at this point in the history
* chore: update pydantic to version 2.10.2 and refactor model validators

* refactor: simplify field validators by removing unnecessary pre and always flags

* remove typing_extensions==4.9.0 from requirements.txt

* refactor: remove allow_reuse flag from status field validator

* refactor: simplify field validators by removing pre and always flags

* refactor: update user model imports and enhance account class with abstract method

* refactor: update model_config to use dictionary format in Admin and SubscriptionUserResponse classes

* fix typo in UserDataResetByNext

* change pre=True to mode="before"

* refactor: update validation methods and model configuration in User and Proxy classes

* change pre=False with mode="after"

* Migrated to Pydantic V2

* fix: custom subscriptions not workong

* some small changes

* add missing properties to example schema

* replace from_orm with model_validate

---------

Co-authored-by: MahdiButcher <[email protected]>
Co-authored-by: Mahdi Butcher <[email protected]>
  • Loading branch information
3 people authored Dec 9, 2024
1 parent afa6bc4 commit ea6a3d2
Show file tree
Hide file tree
Showing 24 changed files with 263 additions and 268 deletions.
2 changes: 1 addition & 1 deletion app/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def revoke_user_sub(db: Session, dbuser: User) -> User:
"""
dbuser.sub_revoked_at = datetime.utcnow()

user = UserResponse.from_orm(dbuser)
user = UserResponse.model_validate(dbuser)
for proxy_type, settings in user.proxies.copy().items():
settings.revoke()
user.proxies[proxy_type] = settings
Expand Down
5 changes: 3 additions & 2 deletions app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def validate_admin(db: Session, username: str, password: str) -> Optional[AdminV
return AdminValidationResult(username=username, is_sudo=True)

dbadmin = crud.get_admin(db, username)
if dbadmin and AdminInDB.from_orm(dbadmin).verify_password(password):
if dbadmin and AdminInDB.model_validate(dbadmin).verify_password(password):
return AdminValidationResult(username=dbadmin.username, is_sudo=dbadmin.is_sudo)

return None
Expand All @@ -40,7 +40,8 @@ def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[st
"""Validate if start and end dates are correct and if end is after start."""
try:
if start:
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start).astimezone(timezone.utc)
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(
start).astimezone(timezone.utc)
else:
start_date = datetime.now(timezone.utc) - timedelta(days=30)
if end:
Expand Down
2 changes: 1 addition & 1 deletion app/jobs/remove_expired_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def remove_expired_users():
deleted_users = crud.autodelete_expired_users(db, USER_AUTODELETE_INCLUDE_LIMITED_ACCOUNTS)

for user in deleted_users:
report.user_deleted(user.username, SYSTEM_ADMIN, user_admin=Admin.from_orm(user.admin))
report.user_deleted(user.username, SYSTEM_ADMIN, user_admin=Admin.model_validate(user.admin))
logger.log(logging.INFO, "Expired user %s deleted." % user.username)


Expand Down
22 changes: 12 additions & 10 deletions app/jobs/review_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def add_notification_reminders(db: Session, user: "User", now: datetime = dateti
if usage_percent >= percent:
if not get_notification_reminder(db, user.id, ReminderType.data_usage, threshold=percent):
report.data_usage_percent_reached(
db, usage_percent, UserResponse.from_orm(user),
db, usage_percent, UserResponse.model_validate(user),
user.id, user.expire, threshold=percent
)
break
Expand All @@ -37,17 +37,19 @@ def add_notification_reminders(db: Session, user: "User", now: datetime = dateti
if expire_days <= days_left:
if not get_notification_reminder(db, user.id, ReminderType.expiration_date, threshold=days_left):
report.expire_days_reached(
db, expire_days, UserResponse.from_orm(user),
db, expire_days, UserResponse.model_validate(user),
user.id, user.expire, threshold=days_left
)
break


def reset_user_by_next_report(db: Session, user: "User"):
user = reset_user_by_next(db, user)

xray.operations.update_user(user)

report.user_data_reset_by_next(user=UserResponse.from_orm(user), user_admin=user.admin)

report.user_data_reset_by_next(user=UserResponse.model_validate(user), user_admin=user.admin)


def review():
now = datetime.utcnow()
Expand All @@ -60,15 +62,15 @@ def review():

if (limited or expired) and user.next_plan is not None:
if user.next_plan is not None:

if user.next_plan.fire_on_either:
reset_user_by_next_report(db, user)
continue

elif limited and expired:
reset_user_by_next_report(db, user)
continue

if limited:
status = UserStatus.limited
elif expired:
Expand All @@ -82,7 +84,7 @@ def review():
update_user_status(db, user, status)

report.status_change(username=user.username, status=status,
user=UserResponse.from_orm(user), user_admin=user.admin)
user=UserResponse.model_validate(user), user_admin=user.admin)

logger.info(f"User \"{user.username}\" status changed to {status}")

Expand All @@ -108,7 +110,7 @@ def review():
start_user_expire(db, user)

report.status_change(username=user.username, status=status,
user=UserResponse.from_orm(user), user_admin=user.admin)
user=UserResponse.model_validate(user), user_admin=user.admin)

logger.info(f"User \"{user.username}\" status changed to {status}")

Expand Down
37 changes: 19 additions & 18 deletions app/models/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel, validator
from pydantic import ConfigDict, field_validator, BaseModel

from app.db import Session, crud, get_db
from app.utils.jwt import get_admin_payload
Expand All @@ -21,12 +21,10 @@ class Token(BaseModel):
class Admin(BaseModel):
username: str
is_sudo: bool
telegram_id: Optional[int]
discord_webhook: Optional[str]
users_usage: Optional[int]

class Config:
orm_mode = True
telegram_id: Optional[int] = None
discord_webhook: Optional[str] = None
users_usage: Optional[int] = None
model_config = ConfigDict(from_attributes=True)

@classmethod
def get_admin(cls, token: str, db: Session):
Expand All @@ -47,7 +45,7 @@ def get_admin(cls, token: str, db: Session):
if dbadmin.password_reset_at > payload.get("created_at"):
return

return cls.from_orm(dbadmin)
return cls.model_validate(dbadmin)

@classmethod
def get_current(cls,
Expand All @@ -60,13 +58,12 @@ def get_current(cls,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

return admin

@classmethod
def check_sudo_admin(cls,
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)):
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)):
admin = cls.get_admin(token, db)
if not admin:
raise HTTPException(
Expand All @@ -81,34 +78,37 @@ def check_sudo_admin(cls,
)
return admin


class AdminCreate(Admin):
password: str
telegram_id: Optional[int]
discord_webhook: Optional[str]
telegram_id: Optional[int] = None
discord_webhook: Optional[str] = None

@property
def hashed_password(self):
return pwd_context.hash(self.password)

@validator("discord_webhook")
@field_validator("discord_webhook")
@classmethod
def validate_discord_webhook(cls, value):
if value and not value.startswith("https://discord.com"):
raise ValueError("Discord webhook must start with 'https://discord.com'")
return value


class AdminModify(BaseModel):
password: Optional[str]
password: Optional[str] = None
is_sudo: bool
telegram_id: Optional[int]
discord_webhook: Optional[str]
telegram_id: Optional[int] = None
discord_webhook: Optional[str] = None

@property
def hashed_password(self):
if self.password:
return pwd_context.hash(self.password)

@validator("discord_webhook")
@field_validator("discord_webhook")
@classmethod
def validate_discord_webhook(cls, value):
if value and not value.startswith("https://discord.com"):
raise ValueError("Discord webhook must start with 'https://discord.com'")
Expand All @@ -126,6 +126,7 @@ class AdminInDB(Admin):
def verify_password(self, plain_password):
return pwd_context.verify(plain_password, self.hashed_password)


class AdminValidationResult(BaseModel):
username: str
is_sudo: bool
52 changes: 23 additions & 29 deletions app/models/node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field
from pydantic import ConfigDict, BaseModel, Field


class NodeStatus(str, Enum):
Expand All @@ -26,18 +26,16 @@ class Node(BaseModel):

class NodeCreate(Node):
add_as_new_host: bool = True

class Config:
schema_extra = {
"example": {
"name": "DE node",
"address": "192.168.1.1",
"port": 62050,
"api_port": 62051,
"add_as_new_host": True,
"usage_coefficient": 1
}
model_config = ConfigDict(json_schema_extra={
"example": {
"name": "DE node",
"address": "192.168.1.1",
"port": 62050,
"api_port": 62051,
"add_as_new_host": True,
"usage_coefficient": 1
}
})


class NodeModify(Node):
Expand All @@ -47,32 +45,28 @@ class NodeModify(Node):
api_port: Optional[int] = Field(None, nullable=True)
status: Optional[NodeStatus] = Field(None, nullable=True)
usage_coefficient: Optional[float] = Field(None, nullable=True)

class Config:
schema_extra = {
"example": {
"name": "DE node",
"address": "192.168.1.1",
"port": 62050,
"api_port": 62051,
"status": "disabled",
"usage_coefficient": 1.0
}
model_config = ConfigDict(json_schema_extra={
"example": {
"name": "DE node",
"address": "192.168.1.1",
"port": 62050,
"api_port": 62051,
"status": "disabled",
"usage_coefficient": 1.0
}
})


class NodeResponse(Node):
id: int
xray_version: Optional[str]
xray_version: Optional[str] = None
status: NodeStatus
message: Optional[str]

class Config:
orm_mode = True
message: Optional[str] = None
model_config = ConfigDict(from_attributes=True)


class NodeUsageResponse(BaseModel):
node_id: Optional[int]
node_id: Optional[int] = None
node_name: str
uplink: int
downlink: int
Expand Down
20 changes: 10 additions & 10 deletions app/models/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Union
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, validator
from pydantic import field_validator, ConfigDict, BaseModel, Field

from app.utils.system import random_password
from xray_api.types.account import (
Expand Down Expand Up @@ -53,10 +53,10 @@ def settings_model(self):
return ShadowsocksSettings


class ProxySettings(BaseModel):
class ProxySettings(BaseModel, use_enum_values=True):
@classmethod
def from_dict(cls, proxy_type: ProxyTypes, _dict: dict):
return ProxyTypes(proxy_type).settings_model.parse_obj(_dict)
return ProxyTypes(proxy_type).settings_model.model_validate(_dict)

def dict(self, *, no_obj=False, **kwargs):
if no_obj:
Expand Down Expand Up @@ -154,11 +154,9 @@ class ProxyHost(BaseModel):
fragment_setting: Optional[str] = Field(None, nullable=True)
noise_setting: Optional[str] = Field(None, nullable=True)
random_user_agent: Union[bool, None] = None
model_config = ConfigDict(from_attributes=True)

class Config:
orm_mode = True

@validator("remark", pre=False, always=True)
@field_validator("remark", mode="after")
def validate_remark(cls, v):
try:
v.format_map(FormatVariables())
Expand All @@ -167,7 +165,7 @@ def validate_remark(cls, v):

return v

@validator("address", pre=False, always=True)
@field_validator("address", mode="after")
def validate_address(cls, v):
try:
v.format_map(FormatVariables())
Expand All @@ -176,15 +174,17 @@ def validate_address(cls, v):

return v

@validator("fragment_setting", check_fields=False)
@field_validator("fragment_setting", check_fields=False)
@classmethod
def validate_fragment(cls, v):
if v and not FRAGMENT_PATTERN.match(v):
raise ValueError(
"Fragment setting must be like this: length,interval,packet (10-100,100-200,tlshello)."
)
return v

@validator("noise_setting", check_fields=False)
@field_validator("noise_setting", check_fields=False)
@classmethod
def validate_noise(cls, v):
if v:
if not NOISE_PATTERN.match(v):
Expand Down
Loading

0 comments on commit ea6a3d2

Please sign in to comment.