Skip to content

Commit

Permalink
Merge pull request #44 from dls-controls/strawberry_graphql_api_memor…
Browse files Browse the repository at this point in the history
…y_leak_fix

Change graphql library to use Strawberry and fix memory leak with large numbers subscription
  • Loading branch information
aawdls authored Apr 18, 2023
2 parents f9a05b6 + ee1304c commit 43b61f9
Show file tree
Hide file tree
Showing 28 changed files with 3,072 additions and 3,845 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ on:
jobs:
docs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8"]

steps:
- name: Avoid git conflicts when tag and branch pushed at same time
Expand All @@ -26,6 +29,7 @@ jobs:
- name: Build docs
uses: dls-controls/pipenv-run-action@v1
with:
python-version: "3.8"
pipenv-run: docs

- name: Move to versioned directory
Expand Down
3,004 changes: 1,699 additions & 1,305 deletions Pipfile.lock

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ package_dir =

# Specify any package dependencies below.
install_requires =
tartiflette-aiohttp
strawberry-graphql
aioca
p4p<4.0.0
ruamel-yaml
p4p
ruamel.yaml
pydantic
aiohttp-cors


[options.extras_require]
# For development tests/docs
Expand All @@ -43,18 +44,16 @@ dev =
flake8 <= 3.9.2
flake8-isort
sphinx-rtd-theme-github-versions
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
pre-commit
websockets
pytest-asyncio>0.17
pytest-aiohttp
wheel

[options.packages.find]
where = src

# Specify any package data to be included in the wheel below.
[options.package_data]
coniql =
*.gql

[options.entry_points]
# Include a command line script
console_scripts =
Expand All @@ -63,6 +62,7 @@ console_scripts =
[mypy]
# Ignore missing stubs for modules we use
ignore_missing_imports = True
plugins = strawberry.ext.mypy_plugin

[isort]
profile=black
Expand Down
145 changes: 93 additions & 52 deletions src/coniql/app.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,92 @@
import traceback
import logging
from argparse import ArgumentParser
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict
from typing import Any, Optional

import aiohttp_cors
import strawberry
from aiohttp import web
from tartiflette import Engine, TartifletteError
from tartiflette_aiohttp import register_graphql_handlers
from strawberry.aiohttp.views import GraphQLView
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL

from coniql.caplugin import CAPlugin
from coniql.plugin import PluginStore
from coniql.pvaplugin import PVAPlugin
from coniql.simplugin import SimPlugin
import coniql.strawberry_schema as schema

from . import __version__


async def error_coercer(exception: Exception, error: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(exception, TartifletteError):
e = exception.original_error
else:
e = exception
if e:
traceback.print_exception(type(e), e, e.__traceback__)
return error
def create_schema(debug: bool):
# Create the schema
return strawberry.Schema(
query=schema.Query,
subscription=schema.Subscription,
mutation=schema.Mutation,
)


def create_app(
use_cors: bool,
debug: bool,
graphiql: bool,
connection_init_wait_timeout: Optional[timedelta] = None,
):
# Create the schema
strawberry_schema = create_schema(debug)

def make_engine() -> Engine:
engine = Engine(
sdl=Path(__file__).resolve().parent / "schema.gql",
error_coercer=error_coercer,
modules=["coniql.resolvers"],
kwargs: Any = {}
if connection_init_wait_timeout:
kwargs["connection_init_wait_timeout"] = connection_init_wait_timeout

# Create the GraphQL view to attach to the app
view = GraphQLView(
schema=strawberry_schema,
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL],
graphiql=graphiql,
**kwargs
)
return engine

# Create app
app = web.Application()
# Add routes
app.router.add_route("GET", "/ws", view)
app.router.add_route("POST", "/ws", view)
app.router.add_route("POST", "/graphql", view)
# Enable CORS for all origins on all routes (if applicable)
if use_cors:
cors = aiohttp_cors.setup(app)
for route in app.router.routes():
allow_all = {
"*": aiohttp_cors.ResourceOptions(
allow_headers=("*"), max_age=3600, allow_credentials=True
)
}
cors.add(route, allow_all)

return app


def configure_logger(debug: bool = False, fmt: Optional[str] = None) -> None:
class OptionalTraceFormatter(logging.Formatter):
def __init__(self, debug: bool = False, fmt: Optional[str] = None) -> None:
self.debug = debug
super().__init__(fmt)

def formatStack(self, stack_info: str) -> str:
"""Option to suppress the stack trace output"""
if not self.debug:
return ""
return super().formatStack(stack_info)

def make_context(*schema_paths: Path) -> Dict[str, Any]:
store = PluginStore()
store.add_plugin("ssim", SimPlugin())
store.add_plugin("pva", PVAPlugin())
store.add_plugin("ca", CAPlugin(), set_default=True)
for path in schema_paths:
store.add_device_config(path)
context = dict(store=store)
return context
# Handler to print to stderr
console = logging.StreamHandler()
console.setLevel(logging.DEBUG if debug else logging.ERROR)
console.setFormatter(OptionalTraceFormatter(debug, fmt))

# Attach it to both Coniql and Strawberry loggers
strawberry_logger = logging.getLogger("strawberry")
strawberry_logger.addHandler(console)
coniql_logger = logging.getLogger("coniql")
coniql_logger.addHandler(console)


def main(args=None) -> None:
Expand All @@ -60,29 +103,27 @@ def main(args=None) -> None:
help="Paths to .coniql.yaml files describing Channels and Devices",
)
parser.add_argument(
"--cors", action="store_true", help="Allow CORS for all origins and routes"
"--cors",
action="store_true",
default=False,
help="Allow CORS for all origins and routes",
)
parsed_args = parser.parse_args(args)

context = make_context(*parsed_args.config_paths)
app = register_graphql_handlers(
app=web.Application(),
executor_context=context,
executor_http_endpoint="/graphql",
subscription_ws_endpoint="/ws",
graphiql_enabled=True,
engine=make_engine(),
parser.add_argument(
"--debug",
action="store_true",
default=False,
help="Print stack trace on errors",
)
parser.add_argument(
"--graphiql",
action="store_true",
default=False,
help="Enable GraphiQL for testing at localhost:8080/ws",
)
parsed_args = parser.parse_args(args)

if parsed_args.cors:
# Enable CORS for all origins on all routes.
cors = aiohttp_cors.setup(app)
for route in app.router.routes():
allow_all = {
"*": aiohttp_cors.ResourceOptions(
allow_headers=("*"), max_age=3600, allow_credentials=True
)
}
cors.add(route, allow_all)
logger_fmt = "[%(asctime)s::%(name)s::%(levelname)s]: %(message)s"
configure_logger(parsed_args.debug, logger_fmt)

app = create_app(parsed_args.cors, parsed_args.debug, parsed_args.graphiql)
web.run_app(app)
Loading

0 comments on commit 43b61f9

Please sign in to comment.