basic graphql server #265

Merged · 6 commits · Sep 23, 2024
.gitignore (2 additions, 0 deletions)
@@ -21,3 +21,5 @@ dev-env/
/data/
tmp/
.task/
node_modules/
package-lock.json
Taskfile.yaml (14 additions, 0 deletions)
@@ -48,6 +48,20 @@ tasks:
- a-file-not-exists-so-it-always-rerun
cmd: python start_grpc_server.py

graphql:
dotenv:
- .env
sources:
- '*.py'
- 'chii/**/*.py'
- 'rpc/**/*.py'
- 'gql/**/*.py'
- 'gql/**/*.graphql'
generates:
- a-file-not-exists-so-it-always-rerun
cmds:
- uvicorn gql.app:app

mypy: mypy --show-column-numbers chii rpc

lint:
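
The new graphql task loads .env and serves gql.app:app through uvicorn. For quick local runs outside of Task, the same app can be started programmatically; a minimal sketch (host, port and reload are illustrative choices, not project settings, and nothing here loads .env for you):

# run_gql.py: start the GraphQL app without going through Task (illustrative only)
import uvicorn

if __name__ == "__main__":
    # "gql.app:app" matches the `uvicorn gql.app:app` command in the task above.
    uvicorn.run("gql.app:app", host="127.0.0.1", port=8000, reload=True)
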
chii/compat/phpseralize.py (3 additions, 3 deletions)
@@ -24,7 +24,7 @@
import six


def load(fp: BytesIO) -> Any:
def __load(fp: BytesIO) -> Any:
"""Read a string from the open file object `fp` and interpret it as a
data stream of PHP-serialized objects, reconstructing and returning
the original object hierarchy.
@@ -34,7 +34,7 @@ def load(fp: BytesIO) -> Any:
reading, a `StringIO` object (`BytesIO` on Python 3), or any other custom
object that meets this interface.

`load` will read exactly one object from the stream. See the docstring of
`__load` will read exactly one object from the stream. See the docstring of
the module for this chained behavior.

If an object hook is given object-opcodes are supported in the serilization
@@ -116,7 +116,7 @@ def loads(data: bytes | str) -> Any:
string must be a bytestring.
"""
with BytesIO(six.ensure_binary(data)) as fp:
return load(fp)
return __load(fp)


def dict_to_list(d: dict[int, Any]) -> list[Any]:
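
The rename only hides the single-object reader; loads stays the public entry point and now calls __load internally. A rough sketch of how loads behaves on a PHP-serialized array, assuming this module keeps the semantics of the phpserialize library it adapts (the payload below is illustrative):

from chii.compat import phpseralize

# PHP's serialize(array(5, 7)) produces this byte string; arrays come back as
# int-keyed dicts, which the dict_to_list() helper shown above flattens.
data = b"a:2:{i:0;i:5;i:1;i:7;}"
decoded = phpseralize.loads(data)         # expected: {0: 5, 1: 7}
print(phpseralize.dict_to_list(decoded))  # expected: [5, 7]
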
chii/config.py (9 additions, 0 deletions)
@@ -29,6 +29,15 @@ def MYSQL_SYNC_DSN(self) -> str:
self.MYSQL_DB,
)

def MYSQL_ASYNC_DSN(self) -> str:
return "mysql+aiomysql://{}:{}@{}:{}/{}".format(
self.MYSQL_USER,
self.MYSQL_PASS,
self.MYSQL_HOST,
self.MYSQL_PORT,
self.MYSQL_DB,
)


config = Settings()

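
MYSQL_ASYNC_DSN mirrors MYSQL_SYNC_DSN but targets the aiomysql driver, which is what the async engine in chii/db/sa.py needs. A quick sketch of the string it builds, with made-up settings purely for illustration:

from chii.config import config

# With illustrative values MYSQL_USER="bangumi", MYSQL_PASS="secret",
# MYSQL_HOST="127.0.0.1", MYSQL_PORT=3306 and MYSQL_DB="bangumi" this prints:
#   mysql+aiomysql://bangumi:secret@127.0.0.1:3306/bangumi
print(config.MYSQL_ASYNC_DSN())
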
chii/db/sa.py (38 additions, 0 deletions)
@@ -1,3 +1,4 @@
import logging.config
import time

from sqlalchemy import (
@@ -19,6 +20,7 @@
update,
)
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import joinedload, selectinload, sessionmaker, subqueryload
from sslog import logger

@@ -58,6 +60,29 @@ def get(T, *where, order=None):
return s


if config.debug:
# redirect echo logger to sslog
logging.config.dictConfig(
{
"version": 1,
"handlers": {
"sslog": {
"class": "sslog.InterceptHandler",
"level": "DEBUG",
}
},
"loggers": {
"": {"level": "INFO", "handlers": ["sslog"]},
"sqlalchemy.engine.Engine": {
"level": "INFO",
"handlers": ["sslog"],
"propagate": False,
},
},
}
)


def sync_session_maker():
engine = create_engine(
config.MYSQL_SYNC_DSN,
@@ -75,6 +100,19 @@ def sync_session_maker():
return sessionmaker(engine)


def async_session_maker():
engine = create_async_engine(
config.MYSQL_ASYNC_DSN(),
pool_recycle=14400,
pool_size=10,
max_overflow=20,
echo=config.debug,
execution_options={"statement_timeout": config.MYSQL_STMT_TIMEOUT},
)

return async_sessionmaker(engine)


def before_cursor_execute(
conn: Connection, cursor, statement, parameters, context, executemany
):
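
async_session_maker() is the async twin of sync_session_maker(): the same pool settings, but built on create_async_engine and async_sessionmaker so sessions are used with async with and awaited. A minimal usage sketch, assuming the database configured in .env is reachable:

import asyncio

from sqlalchemy import select

from chii.db import sa
from chii.db.tables import ChiiTimeline


async def main() -> None:
    Session = sa.async_session_maker()
    async with Session() as session:
        # scalars() is awaited on an async session; iterating the result is synchronous.
        rows = await session.scalars(select(ChiiTimeline).limit(3))
        for row in rows:
            print(row.id, row.created_at)


asyncio.run(main())
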
chii/db/tables.py (4 additions, 1 deletion)
@@ -91,7 +91,7 @@ class ChiiTimeline:
)
},
)
dateline: int = field(
created_at: int = field(
default_factory=lambda: int(time.time()),
metadata={
"sa": Column(
@@ -111,7 +111,10 @@

# type helper for ChiiTimeline.uid.desc()
ChiiTimeline_column_id: Column[int] = cast(Column[int], ChiiTimeline.id)
ChiiTimeline_column_cat: Column[int] = cast(Column[int], ChiiTimeline.cat)
ChiiTimeline_column_type: Column[int] = cast(Column[int], ChiiTimeline.type)
ChiiTimeline_column_uid: Column[int] = cast(Column[int], ChiiTimeline.uid)
ChiiTimeline_column_created_at: Column[int] = cast(Column[int], ChiiTimeline.created_at)


class HTMLEscapedString(types.TypeDecorator):
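
Besides renaming dateline to created_at, the block of ChiiTimeline_column_* casts grows to cover cat, type, uid and the new created_at; the Column[int] casts exist so operators like .desc() and == type-check against the dataclass-mapped attributes. A small sketch of the kind of query they enable, mirroring the resolver added in gql/app.py below (the user id is made up):

from sqlalchemy import select

from chii.db.tables import (
    ChiiTimeline,
    ChiiTimeline_column_created_at,
    ChiiTimeline_column_uid,
)

# Latest 10 timeline rows for one user, newest first.
query = (
    select(ChiiTimeline)
    .where(ChiiTimeline_column_uid == 1)
    .order_by(ChiiTimeline_column_created_at.desc())
    .limit(10)
)
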
gql/app.py (86 additions, 0 deletions)
@@ -0,0 +1,86 @@
from pathlib import Path
from typing import Any, Iterator

from ariadne import ObjectType, QueryType, gql, make_executable_schema
from ariadne.asgi import GraphQL
from sqlalchemy import select
from starlette.applications import Starlette
from starlette.routing import Mount

from chii.compat import phpseralize
from chii.const import CollectionType
from chii.db import sa
from chii.db.tables import ChiiTimeline, ChiiTimeline_column_cat, ChiiTimeline_column_id
from chii.timeline import TimelineCat
from gql.model import CollectTimeline
from gql.rules import depth_limit_validator

# Define types using Schema Definition Language (https://graphql.org/learn/schema/)
# Wrapping string in gql function provides validation and better error traceback
type_defs = gql(
Path(__file__, "..", "schema.graphql").resolve().read_text(encoding="utf8")
)

CreateSession = sa.async_session_maker()

# Map resolver functions to Query fields using QueryType
gql_query = QueryType()


# Resolvers are simple python functions
@gql_query.field("timeline_collection")
async def timeline_collection(*_: Any) -> list[CollectTimeline]:
async with CreateSession() as session:
rows: Iterator[ChiiTimeline] = await session.scalars(
select(ChiiTimeline)
.where(ChiiTimeline_column_cat == TimelineCat.Subject)
.order_by(ChiiTimeline_column_id.desc())
.limit(10)
)

result = []
for row in rows:
meme = phpseralize.loads(row.memo.encode())
if not row.batch:
result.append(
CollectTimeline(
id=row.id,
action=CollectionType.wish,
user_id=row.uid,
subject_id=[int(meme["subject_id"])],
created_at=row.created_at,
)
)
else:
result.append(
CollectTimeline(
id=row.id,
action=CollectionType.wish,
user_id=row.uid,
subject_id=[int(x) for x in meme],
created_at=row.created_at,
)
)

return result


# Map resolver functions to custom type fields using ObjectType
gql_collect_timeline = ObjectType("CollectTimeline")

# Create executable GraphQL schema
schema = make_executable_schema(type_defs, gql_query, gql_collect_timeline)

app = Starlette(
debug=True,
routes=[
Mount(
"/graphql",
GraphQL(
schema,
debug=True,
validation_rules=[depth_limit_validator(max_depth=5)],
),
),
],
)
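
With the app mounted at /graphql, the resolver above is reachable over ordinary GraphQL-over-HTTP. A hedged client sketch, assuming the server is running locally (for example via the graphql task) on uvicorn's default port 8000 and that httpx is installed:

import httpx

query = """
{
  timeline_collection {
    id
    action
    user_id
    subject_id
    created_at
  }
}
"""

resp = httpx.post("http://127.0.0.1:8000/graphql", json={"query": query})
print(resp.json())  # expected shape: {"data": {"timeline_collection": [...]}}
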
gql/model.py (9 additions, 0 deletions)
@@ -0,0 +1,9 @@
from typing import TypedDict


class CollectTimeline(TypedDict):
id: int
action: int
user_id: int
subject_id: list[int]
created_at: int
gql/rules.py (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
from typing import Dict, Tuple

from ariadne.contrib.tracing.utils import is_introspection_key
from graphql import (
ASTValidationRule,
DefinitionNode,
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
GraphQLError,
InlineFragmentNode,
Node,
OperationDefinitionNode,
ValidationContext,
)
from graphql.validation.validate import ValidationAbortedError


def depth_limit_validator(max_depth: int):
class DepthLimitValidator(ASTValidationRule):
def __init__(self, validation_context: ValidationContext):
document = validation_context.document
definitions = document.definitions

fragments = get_fragments(definitions)
queries = get_queries_and_mutations(definitions)
query_depths = {}

for name in queries:
query_depths[name] = determine_depth(
node=queries[name],
fragments=fragments,
depth_so_far=0,
max_depth=max_depth,
context=validation_context,
operation_name=name,
)
super().__init__(validation_context)

return DepthLimitValidator


def get_fragments(
definitions: Tuple[DefinitionNode, ...],
) -> Dict[str, FragmentDefinitionNode]:
fragments = {}
for definition in definitions:
if isinstance(definition, FragmentDefinitionNode):
fragments[definition.name.value] = definition
return fragments


def get_queries_and_mutations(
definitions: Tuple[DefinitionNode, ...],
) -> Dict[str, OperationDefinitionNode]:
operations = {}

for definition in definitions:
if isinstance(definition, OperationDefinitionNode):
operation = definition.name.value if definition.name else "anonymous"
operations[operation] = definition
return operations


def determine_depth(
node: Node,
fragments: Dict[str, FragmentDefinitionNode],
depth_so_far: int,
max_depth: int,
context: ValidationContext,
operation_name: str,
) -> int:
if depth_so_far > max_depth:
context.report_error(
GraphQLError(
f"'{operation_name}' exceeds maximum operation depth of {max_depth}.",
[node],
)
)
raise ValidationAbortedError
if isinstance(node, FieldNode):
should_ignore = is_introspection_key(node.name.value)

if should_ignore or not node.selection_set:
return 0
return 1 + max(
determine_depth(
node=selection,
fragments=fragments,
depth_so_far=depth_so_far + 1,
max_depth=max_depth,
context=context,
operation_name=operation_name,
)
for selection in node.selection_set.selections
)
if isinstance(node, FragmentSpreadNode):
return determine_depth(
node=fragments[node.name.value],
fragments=fragments,
depth_so_far=depth_so_far,
max_depth=max_depth,
context=context,
operation_name=operation_name,
)
if isinstance(
node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode)
):
return max(
determine_depth(
node=selection,
fragments=fragments,
depth_so_far=depth_so_far,
max_depth=max_depth,
context=context,
operation_name=operation_name,
)
for selection in node.selection_set.selections
)
raise Exception(f"Depth crawler cannot handle: {node.kind}.") # pragma: no cover
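
depth_limit_validator caps query nesting: each nested selection set adds one level (introspection fields are skipped), and anything deeper than max_depth is reported as a GraphQLError and aborts validation. It is wired into the ASGI app above through validation_rules. A standalone sketch of the depth counting itself, using the module's own helpers; the stub context below only needs report_error(), which determine_depth never calls while the limit is respected:

from graphql import parse

from gql.rules import determine_depth, get_fragments, get_queries_and_mutations

document = parse("{ timeline_collection { id subject_id created_at } }")
definitions = document.definitions


class _StubContext:
    # Only consulted when the limit is exceeded, which does not happen here.
    def report_error(self, error):
        raise AssertionError(error.message)


for name, operation in get_queries_and_mutations(definitions).items():
    depth = determine_depth(
        node=operation,
        fragments=get_fragments(definitions),
        depth_so_far=0,
        max_depth=5,
        context=_StubContext(),  # type: ignore[arg-type]
        operation_name=name,
    )
    print(name, "has depth", depth)  # expected: anonymous has depth 1
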
gql/schema.graphql (12 additions, 0 deletions)
@@ -0,0 +1,12 @@
type Query {
timeline_collection:[CollectTimeline!]!
}

type CollectTimeline {
id: Int!
action: Int!
user_id: Int!
subject_id: [Int!]!
# unix timestamp in seconds
created_at: Int!
}