Skip to content

Commit

Permalink
Merge branch 'initial-sync'
Browse files Browse the repository at this point in the history
  • Loading branch information
jerith committed Jul 13, 2023
2 parents 07034ab + d02fa15 commit 8f7c225
Show file tree
Hide file tree
Showing 12 changed files with 760 additions and 76 deletions.
128 changes: 127 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@ include = [

[tool.poetry.dependencies]
python = "^3.11"
attrs = "^23.1.0"
sqlalchemy = "^2.0.16"
httpx = "^0.24.1"

[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
mypy = "^1.2.0"
pytest = "^7.3.1"
pytest-cov = "^4.0.0"
pytest-httpx = "^0.22.0"
pytest-postgresql = "^5.0.0"
ruff = "^0.0.261"

Expand Down
97 changes: 97 additions & 0 deletions src/aaq_sync/data_export_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from collections.abc import Generator
from contextlib import AbstractContextManager
from typing import Any, Self, TypedDict, TypeVar

from attrs import define, field
from httpx import URL, Client

from .data_models import Base

T = TypeVar("T")
# Note: `TGen` on its own is equivalent to `TGen[Any]`.
TGen = Generator[T, None, None]
TBase = TypeVar("TBase", bound=Base)

JSONDict = dict[str, Any]


class PageMeta(TypedDict):
size: int
offset: int
limit: int


class ResponseJSON(TypedDict):
metadata: PageMeta
result: list[JSONDict]


@define
class PaginatedResponse:
client: "ExportClient"
table: str
page_meta: PageMeta
items: list[JSONDict]

@property
def is_last_page(self):
return self.page_meta["size"] < self.page_meta["limit"]

@classmethod
def from_json(
cls, client: "ExportClient", table: str, resp_json: ResponseJSON
) -> Self:
return cls(client, table, resp_json["metadata"], resp_json["result"])

def __iter__(self) -> TGen[JSONDict]:
yield from self.items

def iter_all(self) -> TGen[JSONDict]:
yield from self.items
if self.is_last_page:
return
limit = self.page_meta["limit"]
offset = self.page_meta["offset"] + limit
nextpage = self.client._get_data_export(self.table, limit=limit, offset=offset)
yield from nextpage.iter_all()


@define
class ExportClient(AbstractContextManager):
base_url: URL = field(converter=URL)
auth_token: str
_cached_client: Client | None = None

@property
def _client(self) -> Client:
if self._cached_client is None:
headers = {
"Accept": "application/json",
"Authorization": f"Bearer {self.auth_token}",
}
self._cached_client = Client(headers=headers)
return self._cached_client

def close(self):
if self._cached_client:
self._cached_client.close()
self._cached_client = None

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def get_faqmatches(self, **kw) -> PaginatedResponse:
return self._get_data_export("faqmatches", **kw)

def get_model_items(self, model: type[TBase], **kw) -> TGen[TBase]:
paginated_items = self._get_data_export(model.__tablename__, **kw)
for item in paginated_items.iter_all():
yield model.from_json(item)

def _get_data_export(
self, table: str, limit: int = 1000, offset: int = 0
) -> PaginatedResponse:
params = {"limit": limit, "offset": offset}
resp = self._client.get(self.base_url.join(table), params=params)
resp.raise_for_status()
return PaginatedResponse.from_json(self, table, resp.json())
32 changes: 21 additions & 11 deletions src/aaq_sync/data_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime
from typing import Any, Self, TypeVar

from sqlalchemy import ARRAY, ColumnElement, Float, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import ARRAY, ColumnElement, Float, String
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

T = TypeVar("T")

Expand All @@ -16,20 +16,21 @@ def _translate_json_field(col: ColumnElement, value: T) -> T | datetime | None:
# TypeDecorator, for example) but nothing in our existing models does that.
col_pyt = col.type.python_type
match value:
case None if col.nullable:
# Nullability isn't part of the python_type, so we need to check
# for it separately.
return value
case int(v) if col_pyt is datetime:
# Timestamps are represented as milliseconds since the unix epoch.
return datetime.utcfromtimestamp(v / 1000)
case "None" if col.nullable and col_pyt is not str:
# NULL values are apparently represented as the string "None".
return None
case _ if not isinstance(value, col_pyt):
[vtype, ctype] = [t.__name__ for t in [type(value), col_pyt]]
raise TypeError(f"{col.name} has type {vtype}, expected {ctype}")
# Everything else is already in an appropriate form.
return value


class Base(DeclarativeBase):
class Base(MappedAsDataclass, DeclarativeBase):
type_annotation_map = {
list[str]: ARRAY(String),
list[float]: ARRAY(Float),
Expand All @@ -54,8 +55,17 @@ def from_json(cls, json_dict: dict[str, Any]) -> Self:
raise ValueError(f"Extra keys in JSON for {cls.__tablename__}: {keys}")
return cls(**json_fixed)

def pkey_value(self) -> tuple:
"""
Return the value of the identity key for this instance, suitable for
matching instances from different sources (db vs json, for example) in
order to filter or compare them.
"""
return self.__mapper__.primary_key_from_instance(self)


class FAQModel(Base):
# dataclass options (such as kw_only) aren't inherited from parent classes.
class FAQModel(Base, kw_only=True):
"""
SQLAlchemy data model for FAQ
Expand All @@ -64,14 +74,14 @@ class FAQModel(Base):

__tablename__ = "faqmatches"

faq_id: Mapped[int] = mapped_column(Integer, primary_key=True)
faq_id: Mapped[int] = mapped_column(default=None, primary_key=True)
faq_added_utc: Mapped[datetime]
faq_updated_utc: Mapped[datetime]
faq_author: Mapped[str]
faq_title: Mapped[str]
faq_content_to_send: Mapped[str]
faq_tags: Mapped[list[str] | None]
faq_tags: Mapped[list[str] | None] = mapped_column(default=None)
faq_questions: Mapped[list[str]]
faq_contexts: Mapped[list[str] | None]
faq_thresholds: Mapped[list[float] | None]
faq_contexts: Mapped[list[str] | None] = mapped_column(default=None)
faq_thresholds: Mapped[list[float] | None] = mapped_column(default=None)
faq_weight: Mapped[int]
43 changes: 43 additions & 0 deletions src/aaq_sync/itertools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from collections.abc import Iterable, Iterator
from typing import Generic, TypeVar

T = TypeVar("T")


class IteratorWithFinishedCheck(Generic[T]):
"""
An iterator that knows if it's reached its end.
"""

_iter: Iterator[T]
_next_item: T
_finished: bool = False

@property
def finished(self) -> bool:
return self._finished

def __init__(self, iterable: Iterable[T]):
self._iter = iter(iterable)
self._set_next()

def __iter__(self):
return self

def __next__(self) -> T:
if self._finished:
raise StopIteration
next_item = self._next_item
self._set_next()
return next_item

def _set_next(self):
try:
self._next_item = next(self._iter)
except StopIteration:
self._finished = True

def peek_next(self) -> T:
if self.finished:
raise StopIteration
return self._next_item
Loading

0 comments on commit 8f7c225

Please sign in to comment.