From 114c7bd004be2afc8d529cd52403250a1041bbdd Mon Sep 17 00:00:00 2001 From: davidche Date: Tue, 27 Aug 2024 16:46:03 +0800 Subject: [PATCH] Update model primary key for dynamic retrieval (#23) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 删除、更新、查询支持主键名称非ID,从sqlalchemy 模型定义中获取主键 * custom error and update Document * Update docs for primary key * fix lint --- docs/advanced/primary_key.md | 16 +++++++++++++++- docs/usage/delete_model.md | 4 +++- docs/usage/select_model.md | 6 ++++-- docs/usage/update_model.md | 4 +++- sqlalchemy_crud_plus/crud.py | 22 +++++++++++++++++----- sqlalchemy_crud_plus/errors.py | 7 +++++++ 6 files changed, 49 insertions(+), 10 deletions(-) diff --git a/docs/advanced/primary_key.md b/docs/advanced/primary_key.md index 6211408..6b21358 100644 --- a/docs/advanced/primary_key.md +++ b/docs/advanced/primary_key.md @@ -1,8 +1,22 @@ !!! note 主键参数命名 由于在 python 内部 id 的特殊性,我们设定 pk (参考 Django) 作为模型主键命名,所以在 crud 方法中,任何涉及到主键的地方,入参都为 `pk` - + ```py title="e.g." hl_lines="2" async def delete(self, db: AsyncSession, primary_key: int) -> int: return self.delete_model(db, pk=primary_key) ``` + +## 主键定义 + +!!! warning 自动主键 + + 我们在 SQLAlchemy CRUD Plus 内部通过 [inspect()](https://docs.sqlalchemy.org/en/20/core/inspection.html) 自动搜索表主键, + 而非强制绑定主键列必须命名为 id,感谢 [@DavidSche](https://github.com/DavidSche) 提供帮助 + +```py title="e.g." hl_lines="4" +class ModelIns(Base): + # your sqlalchemy model + # define your primary_key + custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True) +``` diff --git a/docs/usage/delete_model.md b/docs/usage/delete_model.md index bd626f1..45cbdc6 100644 --- a/docs/usage/delete_model.md +++ b/docs/usage/delete_model.md @@ -18,13 +18,15 @@ from pydantic import BaseModel from sqlalchemy_crud_plus import CRUDPlus +from sqlalchemy import Mapped, mapped_column from sqlalchemy import DeclarativeBase as Base from sqlalchemy.ext.asyncio import AsyncSession class ModelIns(Base): # your sqlalchemy model - pass + # define your primary_key + custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True) class CreateIns(BaseModel): diff --git a/docs/usage/select_model.md b/docs/usage/select_model.md index 4e6c45d..960e168 100644 --- a/docs/usage/select_model.md +++ b/docs/usage/select_model.md @@ -15,13 +15,15 @@ from pydantic import BaseModel from sqlalchemy_crud_plus import CRUDPlus +from sqlalchemy import Mapped, mapped_column from sqlalchemy import DeclarativeBase as Base from sqlalchemy.ext.asyncio import AsyncSession class ModelIns(Base): # your sqlalchemy model - pass + # define your primary_key + custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True) class CreateIns(BaseModel): @@ -30,6 +32,6 @@ class CreateIns(BaseModel): class CRUDIns(CRUDPlus[ModelIns]): - async def create(self, db: AsyncSession, pk: int) -> ModelIns: + async def select(self, db: AsyncSession, pk: int) -> ModelIns: return await self.select_model(db, pk) ``` diff --git a/docs/usage/update_model.md b/docs/usage/update_model.md index 4ba1bc6..6d3dddd 100644 --- a/docs/usage/update_model.md +++ b/docs/usage/update_model.md @@ -19,13 +19,15 @@ from pydantic import BaseModel from sqlalchemy_crud_plus import CRUDPlus +from sqlalchemy import Mapped, mapped_column from sqlalchemy import DeclarativeBase as Base from sqlalchemy.ext.asyncio import AsyncSession class ModelIns(Base): # your sqlalchemy model - pass + # define your primary_key + custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True) class UpdateIns(BaseModel): diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index 10abc80..2de200e 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -2,10 +2,10 @@ # -*- coding: utf-8 -*- from typing import Any, Generic, Iterable, Sequence, Type -from sqlalchemy import Row, RowMapping, Select, delete, select, update +from sqlalchemy import Row, RowMapping, Select, delete, inspect, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy_crud_plus.errors import MultipleResultsError +from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema from sqlalchemy_crud_plus.utils import apply_sorting, count, parse_filters @@ -13,6 +13,18 @@ class CRUDPlus(Generic[Model]): def __init__(self, model: Type[Model]): self.model = model + self.primary_key = self._get_primary_key() + + def _get_primary_key(self): + """ + Dynamically retrieve the primary key column(s) for the model. + """ + mapper = inspect(self.model) + primary_key = mapper.primary_key + if len(primary_key) == 1: + return primary_key[0] + else: + raise CompositePrimaryKeysError('Composite primary keys are not supported') async def create_model( self, @@ -69,7 +81,7 @@ async def select_model(self, session: AsyncSession, pk: int) -> Model | None: :param pk: The database primary key value. :return: """ - stmt = select(self.model).where(self.model.id == pk) + stmt = select(self.model).where(self.primary_key == pk) query = await session.execute(stmt) return query.scalars().first() @@ -166,7 +178,7 @@ async def update_model( instance_data = obj else: instance_data = obj.model_dump(exclude_unset=True) - stmt = update(self.model).where(self.model.id == pk).values(**instance_data) + stmt = update(self.model).where(self.primary_key == pk).values(**instance_data) result = await session.execute(stmt) if commit: await session.commit() @@ -218,7 +230,7 @@ async def delete_model( :param commit: If `True`, commits the transaction immediately. Default is `False`. :return: """ - stmt = delete(self.model).where(self.model.id == pk) + stmt = delete(self.model).where(self.primary_key == pk) result = await session.execute(stmt) if commit: await session.commit() diff --git a/sqlalchemy_crud_plus/errors.py b/sqlalchemy_crud_plus/errors.py index eb9b884..197c197 100644 --- a/sqlalchemy_crud_plus/errors.py +++ b/sqlalchemy_crud_plus/errors.py @@ -36,3 +36,10 @@ class MultipleResultsError(SQLAlchemyCRUDPlusException): def __init__(self, msg: str) -> None: super().__init__(msg) + + +class CompositePrimaryKeysError(SQLAlchemyCRUDPlusException): + """Error raised when a table have Composite primary keys.""" + + def __init__(self, msg: str) -> None: + super().__init__(msg)