From 3f8acd4bae71555b5d3e0f790a79138faa276a2c Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Tue, 13 Feb 2024 14:33:38 -0500 Subject: [PATCH] Add totalCount to the GraphQL pagination query --- src/nextline_rdb/schema/pagination/connection.py | 6 +++++- src/nextline_rdb/schema/pagination/db.py | 12 ++++++++++++ tests/schema/graphql/queries/RDBRuns.gql | 1 + tests/schema/queries/test_pagination.py | 12 ++++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/nextline_rdb/schema/pagination/connection.py b/src/nextline_rdb/schema/pagination/connection.py index 311b2bb..5c3394d 100644 --- a/src/nextline_rdb/schema/pagination/connection.py +++ b/src/nextline_rdb/schema/pagination/connection.py @@ -29,11 +29,13 @@ class PageInfo: @strawberry.type class Connection(Generic[_T]): page_info: PageInfo + total_count: int edges: list[Edge[_T]] async def query_connection( query_edges: Callable[..., Coroutine[Any, Any, list[Edge[_T]]]], + query_total_count: Callable[..., Coroutine[Any, Any, int]], before: Optional[str] = None, after: Optional[str] = None, first: Optional[int] = None, @@ -64,6 +66,8 @@ async def query_connection( has_previous_page = False has_next_page = False + total_count = await query_total_count() + page_info = PageInfo( has_previous_page=has_previous_page, has_next_page=has_next_page, @@ -71,4 +75,4 @@ async def query_connection( end_cursor=edges[-1].cursor if edges else None, ) - return Connection(page_info=page_info, edges=edges) + return Connection(page_info=page_info, total_count=total_count, edges=edges) diff --git a/src/nextline_rdb/schema/pagination/db.py b/src/nextline_rdb/schema/pagination/db.py index 542476b..0ab2159 100644 --- a/src/nextline_rdb/schema/pagination/db.py +++ b/src/nextline_rdb/schema/pagination/db.py @@ -3,6 +3,8 @@ from functools import partial from typing import Optional, Type, TypeVar +from sqlalchemy import func, select + from nextline_rdb import models as db_models from nextline_rdb.db import DB from nextline_rdb.pagination import Sort, load_models @@ -41,8 +43,11 @@ async def load_connection( sort=sort, ) + query_total_count = partial(load_total_count, db=db, Model=Model) + return await query_connection( query_edges, + query_total_count, before, after, first, @@ -50,6 +55,13 @@ async def load_connection( ) +async def load_total_count(db: DB, Model: Type[db_models.Model]) -> int: + async with db.session() as session: + stmt = select(func.count()).select_from(Model) + total_count = (await session.execute(stmt)).scalar() or 0 + return total_count + + async def load_edges( db: DB, Model: Type[db_models.Model], diff --git a/tests/schema/graphql/queries/RDBRuns.gql b/tests/schema/graphql/queries/RDBRuns.gql index 894bcd0..318a93c 100644 --- a/tests/schema/graphql/queries/RDBRuns.gql +++ b/tests/schema/graphql/queries/RDBRuns.gql @@ -12,6 +12,7 @@ query HistoryRuns( hasNextPage hasPreviousPage } + totalCount edges { cursor node { diff --git a/tests/schema/queries/test_pagination.py b/tests/schema/queries/test_pagination.py index 446767c..2c87de7 100644 --- a/tests/schema/queries/test_pagination.py +++ b/tests/schema/queries/test_pagination.py @@ -52,6 +52,7 @@ async def test_all(runs: list[Run]): note(f'runs: {runs}') nodes_saved = [Node(id=run.id, runNo=run.run_no) for run in runs] + n_nodes_saved = len(nodes_saved) note(f'nodes_saved: {nodes_saved}') resp = await schema.execute(QUERY_RDB_RUNS, context_value={'db': db}) @@ -59,12 +60,15 @@ async def test_all(runs: list[Run]): all_runs = resp.data['rdb']['runs'] page_info: PageInfo = all_runs['pageInfo'] + total_count = all_runs['totalCount'] edges: list[Edge] = all_runs['edges'] if edges: assert page_info['startCursor'] == edges[0]['cursor'] assert page_info['endCursor'] == edges[-1]['cursor'] + assert total_count == n_nodes_saved + nodes = [edge['node'] for edge in edges] assert nodes == nodes_saved @@ -83,6 +87,7 @@ async def test_forward(runs: list[Run], first: int): note(f'runs: {runs}') nodes_saved = [Node(id=run.id, runNo=run.run_no) for run in runs] + n_nodes_saved = len(nodes_saved) note(f'nodes_saved: {nodes_saved}') after = None @@ -100,6 +105,7 @@ async def test_forward(runs: list[Run], first: int): all_runs = resp.data['rdb']['runs'] page_info: PageInfo = all_runs['pageInfo'] + total_count = all_runs['totalCount'] edges: list[Edge] = all_runs['edges'] has_next_page = page_info['hasNextPage'] @@ -113,6 +119,8 @@ async def test_forward(runs: list[Run], first: int): if edges: assert after == edges[-1]['cursor'] + assert total_count == n_nodes_saved + nodes.extend(edge['node'] for edge in edges) assert nodes == nodes_saved @@ -131,6 +139,7 @@ async def test_backward(runs: list[Run], last: int): note(f'runs: {runs}') nodes_saved = [Node(id=run.id, runNo=run.run_no) for run in runs] + n_nodes_saved = len(nodes_saved) note(f'nodes_saved: {nodes_saved}') before = None @@ -148,6 +157,7 @@ async def test_backward(runs: list[Run], last: int): all_runs = resp.data['rdb']['runs'] page_info: PageInfo = all_runs['pageInfo'] + total_count = all_runs['totalCount'] edges: list[Edge] = all_runs['edges'] has_previous_page = page_info['hasPreviousPage'] @@ -161,6 +171,8 @@ async def test_backward(runs: list[Run], last: int): if edges: assert before == edges[0]['cursor'] + assert total_count == n_nodes_saved + nodes.extend(edge['node'] for edge in reversed(edges)) assert nodes == list(reversed(nodes_saved))