Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for adding endpoints to server conditionally #3588

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions src/nominatim_api/server/falcon/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,36 @@ async def process_response(self, req: Request, resp: Response,
f'{resource.name} "{params}"\n')


class APIShutdown:
""" Middleware that closes any open database connections.
class APIMiddleware:
""" Middleware managing the Nominatim database connection.
"""

def __init__(self, api: NominatimAPIAsync) -> None:
self.api = api
def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None:
self.api = NominatimAPIAsync(project_dir, environ)
self.app: Optional[App] = None

@property
def config(self) -> Configuration:
""" Get the configuration for Nominatim.
"""
return self.api.config

def set_app(self, app: App) -> None:
""" Set the Falcon application this middleware is connected to.
"""
self.app = app

async def process_startup(self, *_: Any) -> None:
""" Process the ASGI lifespan startup event.
"""
assert self.app is not None
legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', self.api.config.project_dir)
for name, func in await api_impl.get_routes(self.api):
endpoint = EndpointWrapper(name, func, self.api, formatter)
self.app.add_route(f"/{name}", endpoint)
if legacy_urls:
self.app.add_route(f"/{name}.php", endpoint)

async def process_shutdown(self, *_: Any) -> None:
"""Process the ASGI lifespan shutdown event.
Expand All @@ -164,28 +188,22 @@ def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> App:
""" Create a Nominatim Falcon ASGI application.
"""
api = NominatimAPIAsync(project_dir, environ)
apimw = APIMiddleware(project_dir, environ)

middleware: List[object] = [APIShutdown(api)]
log_file = api.config.LOG_FILE
middleware: List[object] = [apimw]
log_file = apimw.config.LOG_FILE
if log_file:
middleware.append(FileLoggingMiddleware(log_file))

app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware)

apimw.set_app(app)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler)
# different from TimeoutError in Python <= 3.10
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type]

legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, api, formatter)
app.add_route(f"/{name}", endpoint)
if legacy_urls:
app.add_route(f"/{name}.php", endpoint)

return app


Expand Down
34 changes: 20 additions & 14 deletions src/nominatim_api/server/starlette/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"""
Server implementation using the starlette webserver framework.
"""
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
Awaitable, AsyncIterator
from pathlib import Path
import datetime as dt
import asyncio
import contextlib

from starlette.applications import Starlette
from starlette.routing import Route
Expand Down Expand Up @@ -66,7 +68,7 @@ def config(self) -> Configuration:
return cast(Configuration, self.request.app.state.API.config)

def formatting(self) -> FormatDispatcher:
return cast(FormatDispatcher, self.request.app.state.API.formatter)
return cast(FormatDispatcher, self.request.app.state.formatter)


def _wrap_endpoint(func: EndpointFunc)\
Expand Down Expand Up @@ -132,14 +134,6 @@ def get_application(project_dir: Path,
"""
config = Configuration(project_dir, environ)

routes = []
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
routes.append(Route(f"/{name}.php", endpoint=endpoint))

middleware = []
if config.get_bool('CORS_NOACCESSCONTROL'):
middleware.append(Middleware(CORSMiddleware,
Expand All @@ -156,14 +150,26 @@ def get_application(project_dir: Path,
asyncio.TimeoutError: timeout_error
}

async def _shutdown() -> None:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[Any]:
app.state.API = NominatimAPIAsync(project_dir, environ)
config = app.state.API.config

legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in await api_impl.get_routes(app.state.API):
endpoint = _wrap_endpoint(func)
app.routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

yield

await app.state.API.close()

app = Starlette(debug=debug, routes=routes, middleware=middleware,
app = Starlette(debug=debug, middleware=middleware,
exception_handlers=exceptions,
on_shutdown=[_shutdown])
lifespan=lifespan)

app.state.API = NominatimAPIAsync(project_dir, environ)
app.state.formatter = load_format_dispatcher('v1', project_dir)

return app
Expand Down
2 changes: 1 addition & 1 deletion src/nominatim_api/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
Implementation of API version v1 (aka the legacy version).
"""

from .server_glue import ROUTES as ROUTES
from .server_glue import get_routes as get_routes
36 changes: 25 additions & 11 deletions src/nominatim_api/v1/server_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Generic part of the server implementation of the v1 API.
Combine with the scaffolding provided for the various Python ASGI frameworks.
"""
from typing import Optional, Any, Type, Dict, cast
from typing import Optional, Any, Type, Dict, cast, Sequence, Tuple
from functools import reduce
import dataclasses
from urllib.parse import urlencode
Expand All @@ -25,7 +25,8 @@
from ..localization import Locales
from . import helpers
from ..server import content_types as ct
from ..server.asgi_adaptor import ASGIAdaptor
from ..server.asgi_adaptor import ASGIAdaptor, EndpointFunc
from ..sql.async_core_library import PGCORE_ERROR


def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200,
Expand Down Expand Up @@ -417,12 +418,25 @@ async def polygons_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
return build_response(params, params.formatting().format_result(results, fmt, {}))


ROUTES = [
('status', status_endpoint),
('details', details_endpoint),
('reverse', reverse_endpoint),
('lookup', lookup_endpoint),
('search', search_endpoint),
('deletable', deletable_endpoint),
('polygons', polygons_endpoint),
]
async def get_routes(api: NominatimAPIAsync) -> Sequence[Tuple[str, EndpointFunc]]:
routes = [
('status', status_endpoint),
('details', details_endpoint),
('reverse', reverse_endpoint),
('lookup', lookup_endpoint),
('deletable', deletable_endpoint),
('polygons', polygons_endpoint),
]

def has_search_name(conn: sa.engine.Connection) -> bool:
insp = sa.inspect(conn)
return insp.has_table('search_name')

try:
async with api.begin() as conn:
if await conn.connection.run_sync(has_search_name):
routes.append(('search', search_endpoint))
except (PGCORE_ERROR, sa.exc.OperationalError):
pass # ignored

return routes