Skip to content

Commit

Permalink
Merge pull request #12 from nickatnight/nickatnight-GH-9
Browse files Browse the repository at this point in the history
[GH-9] Add repository and interfaces
  • Loading branch information
nickatnight authored Feb 7, 2023
2 parents 189124a + 5c55058 commit d84b44d
Show file tree
Hide file tree
Showing 16 changed files with 209 additions and 80 deletions.
2 changes: 0 additions & 2 deletions cookiecutter.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
"project_slug_db": "{{ cookiecutter.project_name|lower|replace(' ', '') }}",

"db_container_name": "db",

"backend_container_name": "backend",

"nginx_container_name": "nginx",

"doctl_version": "1.92.0",
Expand Down
2 changes: 1 addition & 1 deletion hooks/post_gen_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"%smodels/meme.py" % BASE_BACKEND_SRC_PATH,
"%sschemas/meme.py" % BASE_BACKEND_SRC_PATH,
"%sapi/v1/meme.py" % BASE_BACKEND_SRC_PATH,
"%smigrations/versions/3577cec8a2bb_init.py" % BASE_BACKEND_SRC_PATH,
"%srepository/meme.py" % BASE_BACKEND_SRC_PATH,
"%sdb/init_db.py" % BASE_BACKEND_SRC_PATH,
]
DEPLOYMENT_FILES = [
Expand Down
10 changes: 5 additions & 5 deletions {{ cookiecutter.project_slug }}/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ repos:
hooks:
- id: isort
args: ["--settings-path=./{{ cookiecutter.backend_container_name }}/pyproject.toml"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.982
hooks:
- id: mypy
args: ["--config-file=./{{ cookiecutter.backend_container_name }}/pyproject.toml"]
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v0.982
# hooks:
# - id: mypy
# args: ["--config-file=./{{ cookiecutter.backend_container_name }}/pyproject.toml"]
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,16 @@ omit = ['*tests/*']
exclude = ["migrations/"]
# --strict
disallow_any_generics = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_return_any = true
warn_return_any = true
implicit_reexport = false
strict_equality = true
# --strict end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Optional, List

from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from src.core.enums import SortOrder
from src.db.session import get_session
from src.repositories.sqlalchemy import SQLAlchemyRepository
from src.models.meme import Meme
from src.schemas.common import IGetResponseBase
from src.schemas.meme import IMemeRead
Expand All @@ -14,17 +17,17 @@
@router.get(
"/memes",
response_description="List all meme instances",
response_model=IGetResponseBase[IMemeRead],
response_model=IGetResponseBase[List[IMemeRead]],
tags=["memes"],
)
async def memes(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1),
limit: int = Query(50, ge=1),
sort_field: Optional[str] = "created_at",
sort_order: Optional[str] = SortOrder.DESC,
session: AsyncSession = Depends(get_session),
) -> IGetResponseBase[IMemeRead]:
result = await session.execute(
select(Meme).offset(skip).limit(limit).order_by(Meme.created_at.desc()) # type: ignore
)
memes = result.scalars().all()
) -> IGetResponseBase[List[IMemeRead]]:
meme_repo = SQLAlchemyRepository(model=Meme, db=session)
memes = await meme_repo.all(skip=skip, limit=limit, sort_field=sort_field, sort_order=sort_order)

return IGetResponseBase[IMemeRead](data=memes)
return IGetResponseBase[List[IMemeRead]](data=memes)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class SortOrder:
ASC = "asc"
DESC = "desc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

class {{ cookiecutter.project_name|title|replace(' ', '') }}Exception(Exception):
pass


class ObjectNotFound({{ cookiecutter.project_name|title|replace(' ', '') }}Exception):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from src.db.session import SessionLocal
from src.models.meme import Meme
from src.schemas.meme import IMemeCreate

logging.basicConfig(level=logging.INFO)

Expand All @@ -24,7 +23,7 @@ async def create_init_data() -> None:
db_obj2 = Meme(
submission_id="10t11t6",
submission_title="just paid for wives bf vacation",
submission_url="https://www.reddit.com/r/dogecoin/comments/10t11t6/just_paid_for_wives_bf_vacation/",
submission_url="https://www.reddit.com/r/dogecoin/comments/10t11t6/just_paid_for_wives_bf_vacation/", # noqa
permalink="/r/dogecoin/comments/10t11t6/just_paid_for_wives_bf_vacation/",
author="DynamicHordeOnion",
timestamp=1675473133.0,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from abc import ABCMeta, abstractmethod
from typing import Generic, Optional, TypeVar, List


T = TypeVar("T")


class IRepository(Generic[T], metaclass=ABCMeta):
"""Class representing the repository interface."""

@abstractmethod
async def create(self, obj_in: T, **kwargs: int) -> T:
"""Create new entity and returns the saved instance."""
raise NotImplementedError

@abstractmethod
async def update(self, instance: T, obj_in: T) -> T:
"""Updates an entity and returns the saved instance."""
raise NotImplementedError

@abstractmethod
async def get(self, **kwargs: int) -> T:
"""Get and return one instance by filter."""
raise NotImplementedError

@abstractmethod
async def delete(self, **kwargs: int) -> None:
"""Delete one instance by filter."""
raise NotImplementedError

@abstractmethod
async def all(self, skip: int = 0, limit: int = 50, sort_field: Optional[str] = None, sort_order: Optional[str] = None) -> List[T]:
"""Delete one instance by filter."""
raise NotImplementedError

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,47 @@
from datetime import datetime
from typing import Optional

from sqlalchemy import text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression
from sqlmodel import Column, DateTime, Field, SQLModel


# https://docs.sqlalchemy.org/en/20/core/compiler.html#utc-timestamp-function
class utcnow(expression.FunctionElement): # type: ignore
type = DateTime()
inherit_cache = True


@compiles(utcnow, "postgresql") # type: ignore
def pg_utcnow(element, compiler, **kw) -> str:
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


class BaseModel(SQLModel):
id: uuid_pkg.UUID = Field(
default_factory=uuid_pkg.uuid4,
id: Optional[int] = Field(
default=None,
primary_key=True,
index=True,
)
ref_id: Optional[uuid_pkg.UUID] = Field(
default_factory=uuid_pkg.uuid4,
index=True,
nullable=False,
sa_column_kwargs={"server_default": text("gen_random_uuid()"), "unique": True},
)
updated_at: Optional[datetime] = Field(
created_at: Optional[datetime] = Field(
sa_column=Column(
DateTime(timezone=True),
server_default=utcnow(),
nullable=True,
)
)
created_at: Optional[datetime] = Field(
updated_at: Optional[datetime] = Field(
default_factory=datetime.utcnow,
sa_column=Column(
DateTime(timezone=True),
onupdate=utcnow(),
nullable=True,
)
),
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timezone

from pydantic import BaseConfig, validator
from pydantic import BaseConfig
from sqlmodel import Column, DateTime, Field, SQLModel

from src.models.base import BaseModel
Expand All @@ -25,7 +25,8 @@ class Config(BaseConfig):
}
schema_extra = {
"example": {
"id": "1234-43143-3134-13423",
"id": 1,
"ref_id": "1234-43143-3134-13423",
"submission_id": "nny218",
"submission_title": "This community is so nice. Helps me hodl.",
"submission_url": "https://i.redd.it/gdv6tbamkb271.jpg",
Expand All @@ -38,6 +39,4 @@ class Config(BaseConfig):


class Meme(BaseModel, MemeBase, table=True):
@validator("created_at", pre=True, always=True)
def set_created_at_now(cls, v: datetime) -> datetime:
return v or datetime.now(timezone.utc)
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import logging
from typing import Optional, Type, TypeVar, List

from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel, select

from src.core.exceptions import ObjectNotFound
from src.interfaces.repository import IRepository

ModelType = TypeVar("ModelType", bound=SQLModel)
logger: logging.Logger = logging.getLogger(__name__)


class SQLAlchemyRepository(IRepository[ModelType]):

def __init__(self, model: Type[ModelType], db: AsyncSession) -> None:
self.model = model
self.db = db

async def create(self, obj_in: ModelType, **kwargs: int) -> ModelType:
logger.info(f"Inserting new object[{obj_in.__class__.__name__}]")

db_obj = self.model.from_orm(obj_in)
add = kwargs.get("add", True)
flush = kwargs.get("flush", True)
commit = kwargs.get("commit", True)

if add:
self.db.add(db_obj)

# Navigate these with caution
if add and commit:
try:
await self.db.commit()
await self.db.refresh(db_obj)
except Exception as exc:
logger.error(exc)
await self.db.rollback()

elif add and flush:
await self.db.flush()

return db_obj

async def get(self, **kwargs: int) -> ModelType:
logger.info(f"Fetching [{self.model}] object by [{kwargs}]")

query = select(self.model).filter_by(**kwargs)
response = await self.db.execute(query)
scalar: Optional[ModelType] = response.scalar_one_or_none()

if not scalar:
raise ObjectNotFound(f"Object with [{kwargs}] not found.")

return scalar

async def update(self, obj_current: ModelType, obj_in: ModelType) -> ModelType:
logger.info(f"Updating [{self.model}] object with [{obj_in}]")

update_data = obj_in.dict(
exclude_unset=True
) # This tells Pydantic to not include the values that were not sent

for field in update_data:
setattr(obj_current, field, update_data[field])

self.db.add(obj_current)
await self.db.commit()
await self.db.refresh(obj_current)

return obj_current

async def delete(self, **kwargs: int) -> None:
obj = self.get(**kwargs)

await self.db.delete(obj)
await self.db.commit()

async def all(
self,
skip: int = 0,
limit: int = 100,
sort_field: Optional[str] = None,
sort_order: Optional[str] = None,
) -> List[ModelType]:
columns = self.model.__table__.columns # type: ignore

if not sort_field:
sort_field = "created_at"

if not sort_order:
sort_order = "desc"

order_by = getattr(columns[sort_field], sort_order)()
query = (
select(self.model)
.offset(skip)
.limit(limit)
.order_by(order_by)
)

response = await self.db.execute(query)
return response.scalars().all()
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@ class IMemeCreate(MemeBase):


class IMemeRead(MemeBase):
id: UUID
ref_id: UUID


class IMemeUpdate(MemeBase):
pass

0 comments on commit d84b44d

Please sign in to comment.