Skip to content

Commit

Permalink
Merge pull request #3588 from lonvia/optional-reverse-api
Browse files Browse the repository at this point in the history
Add support for adding endpoints to server conditionally
  • Loading branch information
lonvia authored Nov 14, 2024
2 parents 04d5f67 + 20d0fb3 commit 3acd7df
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 42 deletions.
50 changes: 34 additions & 16 deletions src/nominatim_api/server/falcon/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,36 @@ async def process_response(self, req: Request, resp: Response,
f'{resource.name} "{params}"\n')


class APIShutdown:
""" Middleware that closes any open database connections.
class APIMiddleware:
""" Middleware managing the Nominatim database connection.
"""

def __init__(self, api: NominatimAPIAsync) -> None:
self.api = api
def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None:
self.api = NominatimAPIAsync(project_dir, environ)
self.app: Optional[App] = None

@property
def config(self) -> Configuration:
""" Get the configuration for Nominatim.
"""
return self.api.config

def set_app(self, app: App) -> None:
""" Set the Falcon application this middleware is connected to.
"""
self.app = app

async def process_startup(self, *_: Any) -> None:
""" Process the ASGI lifespan startup event.
"""
assert self.app is not None
legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', self.api.config.project_dir)
for name, func in await api_impl.get_routes(self.api):
endpoint = EndpointWrapper(name, func, self.api, formatter)
self.app.add_route(f"/{name}", endpoint)
if legacy_urls:
self.app.add_route(f"/{name}.php", endpoint)

async def process_shutdown(self, *_: Any) -> None:
"""Process the ASGI lifespan shutdown event.
Expand All @@ -164,28 +188,22 @@ def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> App:
""" Create a Nominatim Falcon ASGI application.
"""
api = NominatimAPIAsync(project_dir, environ)
apimw = APIMiddleware(project_dir, environ)

middleware: List[object] = [APIShutdown(api)]
log_file = api.config.LOG_FILE
middleware: List[object] = [apimw]
log_file = apimw.config.LOG_FILE
if log_file:
middleware.append(FileLoggingMiddleware(log_file))

app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware)

apimw.set_app(app)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler)
# different from TimeoutError in Python <= 3.10
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type]

legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, api, formatter)
app.add_route(f"/{name}", endpoint)
if legacy_urls:
app.add_route(f"/{name}.php", endpoint)

return app


Expand Down
34 changes: 20 additions & 14 deletions src/nominatim_api/server/starlette/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"""
Server implementation using the starlette webserver framework.
"""
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
Awaitable, AsyncIterator
from pathlib import Path
import datetime as dt
import asyncio
import contextlib

from starlette.applications import Starlette
from starlette.routing import Route
Expand Down Expand Up @@ -66,7 +68,7 @@ def config(self) -> Configuration:
return cast(Configuration, self.request.app.state.API.config)

def formatting(self) -> FormatDispatcher:
return cast(FormatDispatcher, self.request.app.state.API.formatter)
return cast(FormatDispatcher, self.request.app.state.formatter)


def _wrap_endpoint(func: EndpointFunc)\
Expand Down Expand Up @@ -132,14 +134,6 @@ def get_application(project_dir: Path,
"""
config = Configuration(project_dir, environ)

routes = []
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
routes.append(Route(f"/{name}.php", endpoint=endpoint))

middleware = []
if config.get_bool('CORS_NOACCESSCONTROL'):
middleware.append(Middleware(CORSMiddleware,
Expand All @@ -156,14 +150,26 @@ def get_application(project_dir: Path,
asyncio.TimeoutError: timeout_error
}

async def _shutdown() -> None:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[Any]:
app.state.API = NominatimAPIAsync(project_dir, environ)
config = app.state.API.config

legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in await api_impl.get_routes(app.state.API):
endpoint = _wrap_endpoint(func)
app.routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

yield

await app.state.API.close()

app = Starlette(debug=debug, routes=routes, middleware=middleware,
app = Starlette(debug=debug, middleware=middleware,
exception_handlers=exceptions,
on_shutdown=[_shutdown])
lifespan=lifespan)

app.state.API = NominatimAPIAsync(project_dir, environ)
app.state.formatter = load_format_dispatcher('v1', project_dir)

return app
Expand Down
2 changes: 1 addition & 1 deletion src/nominatim_api/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
Implementation of API version v1 (aka the legacy version).
"""

from .server_glue import ROUTES as ROUTES
from .server_glue import get_routes as get_routes
36 changes: 25 additions & 11 deletions src/nominatim_api/v1/server_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Generic part of the server implementation of the v1 API.
Combine with the scaffolding provided for the various Python ASGI frameworks.
"""
from typing import Optional, Any, Type, Dict, cast
from typing import Optional, Any, Type, Dict, cast, Sequence, Tuple
from functools import reduce
import dataclasses
from urllib.parse import urlencode
Expand All @@ -25,7 +25,8 @@
from ..localization import Locales
from . import helpers
from ..server import content_types as ct
from ..server.asgi_adaptor import ASGIAdaptor
from ..server.asgi_adaptor import ASGIAdaptor, EndpointFunc
from ..sql.async_core_library import PGCORE_ERROR


def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200,
Expand Down Expand Up @@ -417,12 +418,25 @@ async def polygons_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
return build_response(params, params.formatting().format_result(results, fmt, {}))


ROUTES = [
('status', status_endpoint),
('details', details_endpoint),
('reverse', reverse_endpoint),
('lookup', lookup_endpoint),
('search', search_endpoint),
('deletable', deletable_endpoint),
('polygons', polygons_endpoint),
]
async def get_routes(api: NominatimAPIAsync) -> Sequence[Tuple[str, EndpointFunc]]:
routes = [
('status', status_endpoint),
('details', details_endpoint),
('reverse', reverse_endpoint),
('lookup', lookup_endpoint),
('deletable', deletable_endpoint),
('polygons', polygons_endpoint),
]

def has_search_name(conn: sa.engine.Connection) -> bool:
insp = sa.inspect(conn)
return insp.has_table('search_name')

try:
async with api.begin() as conn:
if await conn.connection.run_sync(has_search_name):
routes.append(('search', search_endpoint))
except (PGCORE_ERROR, sa.exc.OperationalError):
pass # ignored

return routes

0 comments on commit 3acd7df

Please sign in to comment.