Skip to content

Commit

Permalink
Merge pull request #39 from simonsobs/dev
Browse files Browse the repository at this point in the history
Add `totalCount` to the GraphQL pagination query
  • Loading branch information
TaiSakuma authored Feb 13, 2024
2 parents 2dfda4a + 3f8acd4 commit 71fa530
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/nextline_rdb/schema/pagination/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ class PageInfo:
@strawberry.type
class Connection(Generic[_T]):
page_info: PageInfo
total_count: int
edges: list[Edge[_T]]


async def query_connection(
query_edges: Callable[..., Coroutine[Any, Any, list[Edge[_T]]]],
query_total_count: Callable[..., Coroutine[Any, Any, int]],
before: Optional[str] = None,
after: Optional[str] = None,
first: Optional[int] = None,
Expand Down Expand Up @@ -64,11 +66,13 @@ async def query_connection(
has_previous_page = False
has_next_page = False

total_count = await query_total_count()

page_info = PageInfo(
has_previous_page=has_previous_page,
has_next_page=has_next_page,
start_cursor=edges[0].cursor if edges else None,
end_cursor=edges[-1].cursor if edges else None,
)

return Connection(page_info=page_info, edges=edges)
return Connection(page_info=page_info, total_count=total_count, edges=edges)
12 changes: 12 additions & 0 deletions src/nextline_rdb/schema/pagination/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from functools import partial
from typing import Optional, Type, TypeVar

from sqlalchemy import func, select

from nextline_rdb import models as db_models
from nextline_rdb.db import DB
from nextline_rdb.pagination import Sort, load_models
Expand Down Expand Up @@ -41,15 +43,25 @@ async def load_connection(
sort=sort,
)

query_total_count = partial(load_total_count, db=db, Model=Model)

return await query_connection(
query_edges,
query_total_count,
before,
after,
first,
last,
)


async def load_total_count(db: DB, Model: Type[db_models.Model]) -> int:
async with db.session() as session:
stmt = select(func.count()).select_from(Model)
total_count = (await session.execute(stmt)).scalar() or 0
return total_count


async def load_edges(
db: DB,
Model: Type[db_models.Model],
Expand Down
1 change: 1 addition & 0 deletions tests/schema/graphql/queries/RDBRuns.gql
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ query HistoryRuns(
hasNextPage
hasPreviousPage
}
totalCount
edges {
cursor
node {
Expand Down
12 changes: 12 additions & 0 deletions tests/schema/queries/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,23 @@ async def test_all(runs: list[Run]):
note(f'runs: {runs}')

nodes_saved = [Node(id=run.id, runNo=run.run_no) for run in runs]
n_nodes_saved = len(nodes_saved)
note(f'nodes_saved: {nodes_saved}')

resp = await schema.execute(QUERY_RDB_RUNS, context_value={'db': db})
assert resp.data

all_runs = resp.data['rdb']['runs']
page_info: PageInfo = all_runs['pageInfo']
total_count = all_runs['totalCount']
edges: list[Edge] = all_runs['edges']

if edges:
assert page_info['startCursor'] == edges[0]['cursor']
assert page_info['endCursor'] == edges[-1]['cursor']

assert total_count == n_nodes_saved

nodes = [edge['node'] for edge in edges]

assert nodes == nodes_saved
Expand All @@ -83,6 +87,7 @@ async def test_forward(runs: list[Run], first: int):
note(f'runs: {runs}')

nodes_saved = [Node(id=run.id, runNo=run.run_no) for run in runs]
n_nodes_saved = len(nodes_saved)
note(f'nodes_saved: {nodes_saved}')

after = None
Expand All @@ -100,6 +105,7 @@ async def test_forward(runs: list[Run], first: int):

all_runs = resp.data['rdb']['runs']
page_info: PageInfo = all_runs['pageInfo']
total_count = all_runs['totalCount']
edges: list[Edge] = all_runs['edges']

has_next_page = page_info['hasNextPage']
Expand All @@ -113,6 +119,8 @@ async def test_forward(runs: list[Run], first: int):
if edges:
assert after == edges[-1]['cursor']

assert total_count == n_nodes_saved

nodes.extend(edge['node'] for edge in edges)

assert nodes == nodes_saved
Expand All @@ -131,6 +139,7 @@ async def test_backward(runs: list[Run], last: int):
note(f'runs: {runs}')

nodes_saved = [Node(id=run.id, runNo=run.run_no) for run in runs]
n_nodes_saved = len(nodes_saved)
note(f'nodes_saved: {nodes_saved}')

before = None
Expand All @@ -148,6 +157,7 @@ async def test_backward(runs: list[Run], last: int):

all_runs = resp.data['rdb']['runs']
page_info: PageInfo = all_runs['pageInfo']
total_count = all_runs['totalCount']
edges: list[Edge] = all_runs['edges']

has_previous_page = page_info['hasPreviousPage']
Expand All @@ -161,6 +171,8 @@ async def test_backward(runs: list[Run], last: int):
if edges:
assert before == edges[0]['cursor']

assert total_count == n_nodes_saved

nodes.extend(edge['node'] for edge in reversed(edges))

assert nodes == list(reversed(nodes_saved))
Expand Down

0 comments on commit 71fa530

Please sign in to comment.