From 06de38bf8c8c3d04ce4e753ed3aa05863031ce17 Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Fri, 14 Jun 2024 12:03:49 -0400 Subject: [PATCH 1/2] Fix all mypy errors with the `disallow_untyped_defs` option --- .../alembic/models/rev_4dc6a93dfed8/model_script.py | 2 +- .../rev_4dc6a93dfed8/tests/test_model_prompt.py | 2 +- .../models/rev_4dc6a93dfed8/tests/test_model_run.py | 4 ++-- .../rev_4dc6a93dfed8/tests/test_model_stdout.py | 2 +- .../rev_4dc6a93dfed8/tests/test_model_trace.py | 2 +- .../models/rev_4dc6a93dfed8/tests/test_repr_val.py | 6 +++--- .../rev_5a08750d6760/tests/test_model_prompt.py | 2 +- .../models/rev_5a08750d6760/tests/test_model_run.py | 2 +- .../rev_5a08750d6760/tests/test_model_stdout.py | 2 +- .../rev_5a08750d6760/tests/test_model_trace.py | 2 +- .../models/rev_5a08750d6760/tests/test_repr_val.py | 6 +++--- .../alembic/models/rev_f3edea6dbde2/model_script.py | 2 +- .../rev_f3edea6dbde2/tests/test_model_prompt.py | 2 +- .../models/rev_f3edea6dbde2/tests/test_model_run.py | 4 ++-- .../rev_f3edea6dbde2/tests/test_model_stdout.py | 2 +- .../rev_f3edea6dbde2/tests/test_model_trace.py | 2 +- .../models/rev_f3edea6dbde2/tests/test_repr_val.py | 6 +++--- .../rev_f9a742bb2297/tests/test_model_prompt.py | 2 +- .../models/rev_f9a742bb2297/tests/test_model_run.py | 4 ++-- .../rev_f9a742bb2297/tests/test_model_stdout.py | 2 +- .../rev_f9a742bb2297/tests/test_model_trace.py | 2 +- .../models/rev_f9a742bb2297/tests/test_repr_val.py | 6 +++--- src/nextline_rdb/db.py | 6 +++--- src/nextline_rdb/models/tests/test_cascade.py | 6 +++--- src/nextline_rdb/models/tests/test_repr_val.py | 6 +++--- src/nextline_rdb/pagination.py | 4 ++-- src/nextline_rdb/schema/nodes/prompt_node.py | 2 +- src/nextline_rdb/schema/nodes/run_node.py | 2 +- src/nextline_rdb/schema/nodes/stdout_node.py | 2 +- src/nextline_rdb/schema/nodes/trace_call_node.py | 4 +++- src/nextline_rdb/schema/nodes/trace_node.py | 2 +- tests/db/models.py | 6 ++++-- tests/db/test_db.py | 12 ++++++++---- tests/pagination/funcs.py | 2 +- tests/pagination/test_load_models.py | 6 +++--- tests/schema/queries/test_pagination.py | 8 ++++---- tests/test_plugin.py | 8 +++++--- tests/test_version.py | 2 +- tests/utils/test_mark_last.py | 2 +- tests/utils/test_until.py | 6 ++++-- 40 files changed, 82 insertions(+), 70 deletions(-) diff --git a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/model_script.py b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/model_script.py index fc69366..7e443c4 100644 --- a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/model_script.py +++ b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/model_script.py @@ -29,7 +29,7 @@ def current(self) -> bool: return self._current is not None @current.setter - def current(self, value: bool): + def current(self, value: bool) -> None: if value and self._current is None: self._current = CurrentScript(script=self) elif not value: diff --git a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_prompt.py b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_prompt.py index 342a2e5..207e3d2 100644 --- a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_prompt.py +++ b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_prompt.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_prompt()) diff --git a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_run.py b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_run.py index e592281..cc3248f 100644 --- a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_run.py +++ b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_run.py @@ -10,7 +10,7 @@ @given(run=st_model_run(generate_traces=False)) -async def test_repr(run: Run): +async def test_repr(run: Run) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: session.add(run) @@ -24,7 +24,7 @@ async def test_repr(run: Run): @given(run=st_model_run(generate_traces=True)) -async def test_cascade(run: Run): +async def test_cascade(run: Run) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: session.add(run) diff --git a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_stdout.py b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_stdout.py index 3e38593..91e7f33 100644 --- a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_stdout.py +++ b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_stdout.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_stdout()) diff --git a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_trace.py b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_trace.py index 932e186..422a19e 100644 --- a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_trace.py +++ b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_model_trace.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_trace()) diff --git a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_repr_val.py b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_repr_val.py index 6632177..b552a37 100644 --- a/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_repr_val.py +++ b/src/nextline_rdb/alembic/models/rev_4dc6a93dfed8/tests/test_repr_val.py @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None: @st.composite -def st_enum_type(draw: st.DrawFn): +def st_enum_type(draw: st.DrawFn) -> type[Enum]: '''Generate an Enum type. >>> enum_type = st_enum_type().example() @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn): names_ = st.text(ascii_lowercase, min_size=1) names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier) values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True) - return draw(st.builds(Enum, names, values)) + return draw(st.builds(Enum, names, values)) # type: ignore @given(st.data()) -def test_arbitrary_enum(data: st.DataObject): +def test_arbitrary_enum(data: st.DataObject) -> None: enum_type = data.draw(st_enum_type()) item = data.draw(st.sampled_from(enum_type)) repr_ = repr_val(item) diff --git a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_prompt.py b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_prompt.py index 342a2e5..207e3d2 100644 --- a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_prompt.py +++ b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_prompt.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_prompt()) diff --git a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_run.py b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_run.py index da2d723..976dfa0 100644 --- a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_run.py +++ b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_run.py @@ -11,7 +11,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_run(generate_traces=False)) diff --git a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_stdout.py b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_stdout.py index 3e38593..91e7f33 100644 --- a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_stdout.py +++ b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_stdout.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_stdout()) diff --git a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_trace.py b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_trace.py index 932e186..422a19e 100644 --- a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_trace.py +++ b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_model_trace.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_trace()) diff --git a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_repr_val.py b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_repr_val.py index 6632177..b552a37 100644 --- a/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_repr_val.py +++ b/src/nextline_rdb/alembic/models/rev_5a08750d6760/tests/test_repr_val.py @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None: @st.composite -def st_enum_type(draw: st.DrawFn): +def st_enum_type(draw: st.DrawFn) -> type[Enum]: '''Generate an Enum type. >>> enum_type = st_enum_type().example() @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn): names_ = st.text(ascii_lowercase, min_size=1) names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier) values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True) - return draw(st.builds(Enum, names, values)) + return draw(st.builds(Enum, names, values)) # type: ignore @given(st.data()) -def test_arbitrary_enum(data: st.DataObject): +def test_arbitrary_enum(data: st.DataObject) -> None: enum_type = data.draw(st_enum_type()) item = data.draw(st.sampled_from(enum_type)) repr_ = repr_val(item) diff --git a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/model_script.py b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/model_script.py index fc69366..7e443c4 100644 --- a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/model_script.py +++ b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/model_script.py @@ -29,7 +29,7 @@ def current(self) -> bool: return self._current is not None @current.setter - def current(self, value: bool): + def current(self, value: bool) -> None: if value and self._current is None: self._current = CurrentScript(script=self) elif not value: diff --git a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_prompt.py b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_prompt.py index 342a2e5..207e3d2 100644 --- a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_prompt.py +++ b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_prompt.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_prompt()) diff --git a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_run.py b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_run.py index 8e8d5b4..b167453 100644 --- a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_run.py +++ b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_run.py @@ -10,7 +10,7 @@ @given(run=st_model_run(generate_traces=False)) -async def test_repr(run: Run): +async def test_repr(run: Run) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: session.add(run) @@ -25,7 +25,7 @@ async def test_repr(run: Run): @settings(phases=(Phase.generate,)) # Avoid shrinking @given(run=st_model_run(generate_traces=True)) -async def test_cascade(run: Run): +async def test_cascade(run: Run) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: session.add(run) diff --git a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_stdout.py b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_stdout.py index 3e38593..91e7f33 100644 --- a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_stdout.py +++ b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_stdout.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_stdout()) diff --git a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_trace.py b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_trace.py index 932e186..422a19e 100644 --- a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_trace.py +++ b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_model_trace.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_trace()) diff --git a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_repr_val.py b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_repr_val.py index 6632177..b552a37 100644 --- a/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_repr_val.py +++ b/src/nextline_rdb/alembic/models/rev_f3edea6dbde2/tests/test_repr_val.py @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None: @st.composite -def st_enum_type(draw: st.DrawFn): +def st_enum_type(draw: st.DrawFn) -> type[Enum]: '''Generate an Enum type. >>> enum_type = st_enum_type().example() @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn): names_ = st.text(ascii_lowercase, min_size=1) names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier) values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True) - return draw(st.builds(Enum, names, values)) + return draw(st.builds(Enum, names, values)) # type: ignore @given(st.data()) -def test_arbitrary_enum(data: st.DataObject): +def test_arbitrary_enum(data: st.DataObject) -> None: enum_type = data.draw(st_enum_type()) item = data.draw(st.sampled_from(enum_type)) repr_ = repr_val(item) diff --git a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_prompt.py b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_prompt.py index 342a2e5..207e3d2 100644 --- a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_prompt.py +++ b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_prompt.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_prompt()) diff --git a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_run.py b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_run.py index e592281..cc3248f 100644 --- a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_run.py +++ b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_run.py @@ -10,7 +10,7 @@ @given(run=st_model_run(generate_traces=False)) -async def test_repr(run: Run): +async def test_repr(run: Run) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: session.add(run) @@ -24,7 +24,7 @@ async def test_repr(run: Run): @given(run=st_model_run(generate_traces=True)) -async def test_cascade(run: Run): +async def test_cascade(run: Run) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: session.add(run) diff --git a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_stdout.py b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_stdout.py index 3e38593..91e7f33 100644 --- a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_stdout.py +++ b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_stdout.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_stdout()) diff --git a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_trace.py b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_trace.py index 932e186..422a19e 100644 --- a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_trace.py +++ b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_model_trace.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_repr(data: st.DataObject): +async def test_repr(data: st.DataObject) -> None: async with DB(use_migration=False, model_base_class=Model) as db: async with db.session.begin() as session: model = data.draw(st_model_trace()) diff --git a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_repr_val.py b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_repr_val.py index 6632177..b552a37 100644 --- a/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_repr_val.py +++ b/src/nextline_rdb/alembic/models/rev_f9a742bb2297/tests/test_repr_val.py @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None: @st.composite -def st_enum_type(draw: st.DrawFn): +def st_enum_type(draw: st.DrawFn) -> type[Enum]: '''Generate an Enum type. >>> enum_type = st_enum_type().example() @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn): names_ = st.text(ascii_lowercase, min_size=1) names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier) values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True) - return draw(st.builds(Enum, names, values)) + return draw(st.builds(Enum, names, values)) # type: ignore @given(st.data()) -def test_arbitrary_enum(data: st.DataObject): +def test_arbitrary_enum(data: st.DataObject) -> None: enum_type = data.draw(st_enum_type()) item = data.draw(st.sampled_from(enum_type)) repr_ = repr_val(item) diff --git a/src/nextline_rdb/db.py b/src/nextline_rdb/db.py index 2b55b59..a8ec7aa 100644 --- a/src/nextline_rdb/db.py +++ b/src/nextline_rdb/db.py @@ -2,7 +2,7 @@ from logging import getLogger from os import PathLike from pathlib import Path -from typing import Optional, Type +from typing import Any, Optional, Type from alembic.config import Config from alembic.migration import MigrationContext @@ -136,7 +136,7 @@ async def __aenter__(self) -> 'DB': await self.start() return self - async def __aexit__(self, *_, **__) -> None: + async def __aexit__(self, *_: Any, **__: Any) -> None: await self.aclose() @@ -168,7 +168,7 @@ def do_run_migrations(connection: Connection) -> None: @event.listens_for(Engine, 'connect') -def set_sqlite_pragma(dbapi_connection, connection_record): +def set_sqlite_pragma(dbapi_connection: Any, connection_record: Any) -> None: '''Enable foreign key constraints in SQLite. The code copied from the SQLAlchemy documentation: diff --git a/src/nextline_rdb/models/tests/test_cascade.py b/src/nextline_rdb/models/tests/test_cascade.py index 4c3c94d..d7754da 100644 --- a/src/nextline_rdb/models/tests/test_cascade.py +++ b/src/nextline_rdb/models/tests/test_cascade.py @@ -13,19 +13,19 @@ @settings(phases=(Phase.generate,)) # Avoid shrinking @given(parent=st_model_run(generate_traces=True)) -async def test_run(parent: Run): +async def test_run(parent: Run) -> None: await assert_cascade(parent, Run, [Prompt, Trace, TraceCall, Stdout]) @settings(phases=(Phase.generate,)) # Avoid shrinking @given(parent=st_model_trace(generate_trace_calls=True, generate_prompts=True)) -async def test_trace(parent: Trace): +async def test_trace(parent: Trace) -> None: await assert_cascade(parent, Trace, [Prompt, TraceCall]) @settings(phases=(Phase.generate,)) # Avoid shrinking @given(parent=st_model_trace_call(generate_prompts=True)) -async def test_trace_call(parent: TraceCall): +async def test_trace_call(parent: TraceCall) -> None: await assert_cascade(parent, TraceCall, [Prompt]) # type: ignore diff --git a/src/nextline_rdb/models/tests/test_repr_val.py b/src/nextline_rdb/models/tests/test_repr_val.py index 6632177..b552a37 100644 --- a/src/nextline_rdb/models/tests/test_repr_val.py +++ b/src/nextline_rdb/models/tests/test_repr_val.py @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None: @st.composite -def st_enum_type(draw: st.DrawFn): +def st_enum_type(draw: st.DrawFn) -> type[Enum]: '''Generate an Enum type. >>> enum_type = st_enum_type().example() @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn): names_ = st.text(ascii_lowercase, min_size=1) names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier) values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True) - return draw(st.builds(Enum, names, values)) + return draw(st.builds(Enum, names, values)) # type: ignore @given(st.data()) -def test_arbitrary_enum(data: st.DataObject): +def test_arbitrary_enum(data: st.DataObject) -> None: enum_type = data.draw(st_enum_type()) item = data.draw(st.sampled_from(enum_type)) repr_ = repr_val(item) diff --git a/src/nextline_rdb/pagination.py b/src/nextline_rdb/pagination.py index 6db86df..846499d 100644 --- a/src/nextline_rdb/pagination.py +++ b/src/nextline_rdb/pagination.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any, NamedTuple, Optional, Type, TypeVar -from sqlalchemy import func, select +from sqlalchemy import ScalarResult, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase, aliased, selectinload from sqlalchemy.sql.selectable import Select @@ -36,7 +36,7 @@ async def load_models( after: Optional[_Id] = None, first: Optional[int] = None, last: Optional[int] = None, -): +) -> ScalarResult[T]: sort = sort or [] if id_field not in {s.field for s in sort}: diff --git a/src/nextline_rdb/schema/nodes/prompt_node.py b/src/nextline_rdb/schema/nodes/prompt_node.py index 267e67c..e45dfe1 100644 --- a/src/nextline_rdb/schema/nodes/prompt_node.py +++ b/src/nextline_rdb/schema/nodes/prompt_node.py @@ -48,7 +48,7 @@ def trace_call( return TraceCallNode.from_model(self._model.trace_call) @classmethod - def from_model(cls: type['PromptNode'], model: db_models.Prompt): + def from_model(cls: type['PromptNode'], model: db_models.Prompt) -> 'PromptNode': return cls( _model=model, id=model.id, diff --git a/src/nextline_rdb/schema/nodes/run_node.py b/src/nextline_rdb/schema/nodes/run_node.py index 11a4fa6..6123e0e 100644 --- a/src/nextline_rdb/schema/nodes/run_node.py +++ b/src/nextline_rdb/schema/nodes/run_node.py @@ -154,7 +154,7 @@ class RunNode: ] = strawberry.field(resolver=_resolve_stdouts) @classmethod - def from_model(cls: type['RunNode'], model: Run): + def from_model(cls: type['RunNode'], model: Run) -> 'RunNode': script = model.script.script if model.script else None return cls( _model=model, diff --git a/src/nextline_rdb/schema/nodes/stdout_node.py b/src/nextline_rdb/schema/nodes/stdout_node.py index 1a2f40e..ffc05d7 100644 --- a/src/nextline_rdb/schema/nodes/stdout_node.py +++ b/src/nextline_rdb/schema/nodes/stdout_node.py @@ -32,7 +32,7 @@ def trace(self) -> Annotated['TraceNode', strawberry.lazy('.trace_node')]: return TraceNode.from_model(self._model.trace) @classmethod - def from_model(cls: type['StdoutNode'], model: db_models.Stdout): + def from_model(cls: type['StdoutNode'], model: db_models.Stdout) -> 'StdoutNode': return cls( _model=model, id=model.id, diff --git a/src/nextline_rdb/schema/nodes/trace_call_node.py b/src/nextline_rdb/schema/nodes/trace_call_node.py index 48614fb..6ba1013 100644 --- a/src/nextline_rdb/schema/nodes/trace_call_node.py +++ b/src/nextline_rdb/schema/nodes/trace_call_node.py @@ -77,7 +77,9 @@ def trace(self) -> Annotated['TraceNode', strawberry.lazy('.trace_node')]: ] = strawberry.field(resolver=_resolve_prompts) @classmethod - def from_model(cls: type['TraceCallNode'], model: db_models.TraceCall): + def from_model( + cls: type['TraceCallNode'], model: db_models.TraceCall + ) -> 'TraceCallNode': return cls( _model=model, id=model.id, diff --git a/src/nextline_rdb/schema/nodes/trace_node.py b/src/nextline_rdb/schema/nodes/trace_node.py index c72d5d4..734d59f 100644 --- a/src/nextline_rdb/schema/nodes/trace_node.py +++ b/src/nextline_rdb/schema/nodes/trace_node.py @@ -99,7 +99,7 @@ def run(self) -> Annotated['RunNode', strawberry.lazy('.run_node')]: ] = strawberry.field(resolver=_resolve_stdouts) @classmethod - def from_model(cls: type['TraceNode'], model: db_models.Trace): + def from_model(cls: type['TraceNode'], model: db_models.Trace) -> 'TraceNode': return cls( _model=model, id=model.id, diff --git a/tests/db/models.py b/tests/db/models.py index 112e803..cbf2fb4 100644 --- a/tests/db/models.py +++ b/tests/db/models.py @@ -1,3 +1,5 @@ +from typing import Any + from sqlalchemy import ForeignKey, event from sqlalchemy.orm import ( DeclarativeBase, @@ -31,10 +33,10 @@ def register_session_events(session: sessionmaker[Session]) -> None: new: set[Model] @event.listens_for(session, 'before_flush') - def _before_flush(session: Session, flush_context, instances): + def _before_flush(session: Session, flush_context: Any, instances: Any) -> None: nonlocal new new = set(session.new) @event.listens_for(session, 'after_flush_postexec') - def _after_flush_postexec(session: Session, flush_context): + def _after_flush_postexec(session: Session, flush_context: Any) -> None: pass diff --git a/tests/db/test_db.py b/tests/db/test_db.py index b061bc1..6bdb09b 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -10,7 +10,7 @@ from .models import Bar, Foo, Model, register_session_events -async def test_ensure_sync_url(tmp_url_factory: Callable[[], str]): +async def test_ensure_sync_url(tmp_url_factory: Callable[[], str]) -> None: url = tmp_url_factory() sync_url = ensure_sync_url(url) @@ -18,7 +18,7 @@ async def test_ensure_sync_url(tmp_url_factory: Callable[[], str]): assert db.url == url -async def test_fields(): +async def test_fields() -> None: db = DB() assert db.url assert db.metadata @@ -38,7 +38,9 @@ async def test_migration_revision(use_migration: bool) -> None: @given(st.lists(st.integers(min_value=0, max_value=4), min_size=0, max_size=4)) -async def test_session_nested(tmp_url_factory: Callable[[], str], sizes: list[int]): +async def test_session_nested( + tmp_url_factory: Callable[[], str], sizes: list[int] +) -> None: url = tmp_url_factory() objs = [Foo(bars=[Bar() for _ in range(size)]) for size in sizes] @@ -68,7 +70,9 @@ async def test_session_nested(tmp_url_factory: Callable[[], str], sizes: list[in @given(st.lists(st.integers(min_value=0, max_value=4), min_size=0, max_size=4)) -async def test_session_begin(tmp_url_factory: Callable[[], str], sizes: list[int]): +async def test_session_begin( + tmp_url_factory: Callable[[], str], sizes: list[int] +) -> None: url = tmp_url_factory() objs = [Foo(bars=[Bar() for _ in range(size)]) for size in sizes] diff --git a/tests/pagination/funcs.py b/tests/pagination/funcs.py index 688fac4..0635beb 100644 --- a/tests/pagination/funcs.py +++ b/tests/pagination/funcs.py @@ -25,7 +25,7 @@ def st_sort(draw: st.DrawFn) -> Sort: @given(st_sort()) -def test_st_sort(sort: Sort): +def test_st_sort(sort: Sort) -> None: # ic(sort) pass diff --git a/tests/pagination/test_load_models.py b/tests/pagination/test_load_models.py index c14df46..9c49de3 100644 --- a/tests/pagination/test_load_models.py +++ b/tests/pagination/test_load_models.py @@ -12,7 +12,7 @@ @given(st.data()) -async def test_all(data: st.DataObject): +async def test_all(data: st.DataObject) -> None: n_max = 10 entities = data.draw(st.lists(st_entity(), min_size=0, max_size=n_max)) @@ -35,7 +35,7 @@ async def test_all(data: st.DataObject): @given(st.data()) -async def test_forward(data: st.DataObject): +async def test_forward(data: st.DataObject) -> None: n_max = 10 entities = data.draw(st.lists(st_entity(), min_size=0, max_size=n_max)) @@ -73,7 +73,7 @@ async def test_forward(data: st.DataObject): @given(st.data()) -async def test_backward(data: st.DataObject): +async def test_backward(data: st.DataObject) -> None: n_max = 10 entities = data.draw(st.lists(st_entity(), min_size=0, max_size=n_max)) diff --git a/tests/schema/queries/test_pagination.py b/tests/schema/queries/test_pagination.py index d9233d0..173e401 100644 --- a/tests/schema/queries/test_pagination.py +++ b/tests/schema/queries/test_pagination.py @@ -43,7 +43,7 @@ class Edge(TypedDict): @given(runs=st_model_run_list(generate_traces=False, min_size=0, max_size=12)) -async def test_all(runs: list[Run]): +async def test_all(runs: list[Run]) -> None: schema = strawberry.Schema(query=Query) async with DB() as db: @@ -78,7 +78,7 @@ async def test_all(runs: list[Run]): runs=st_model_run_list(generate_traces=False, min_size=0, max_size=12), first=st.integers(min_value=1, max_value=15), ) -async def test_forward(runs: list[Run], first: int): +async def test_forward(runs: list[Run], first: int) -> None: schema = strawberry.Schema(query=Query) async with DB() as db: @@ -130,7 +130,7 @@ async def test_forward(runs: list[Run], first: int): runs=st_model_run_list(generate_traces=False, min_size=0, max_size=12), last=st.integers(min_value=1, max_value=15), ) -async def test_backward(runs: list[Run], last: int): +async def test_backward(runs: list[Run], last: int) -> None: schema = strawberry.Schema(query=Query) async with DB() as db: @@ -179,7 +179,7 @@ async def test_backward(runs: list[Run], last: int): @given(runs=st_model_run_list(generate_traces=False, min_size=0, max_size=12)) -async def test_cursor(runs: list[Run]): +async def test_cursor(runs: list[Run]) -> None: schema = strawberry.Schema(query=Query) async with DB() as db: diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 2d25d12..b83566e 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -14,7 +14,7 @@ from .schema.graphql import QUERY_RDB_CONNECTIONS -async def test_plugin(set_new_url: Callable[[], str]): +async def test_plugin(set_new_url: Callable[[], str]) -> None: # Enter some runs into the database. runs = st_model_run_list(generate_traces=False, max_size=2).example() @@ -37,7 +37,7 @@ async def test_plugin(set_new_url: Callable[[], str]): assert n_runs == expected_n_runs -def test_fixture(settings_path: Path, set_new_url: Callable[[], str]): +def test_fixture(settings_path: Path, set_new_url: Callable[[], str]) -> None: from nextline_rdb import plugin assert str(settings_path) in plugin.SETTINGS @@ -67,7 +67,9 @@ def _f() -> str: @pytest.fixture(autouse=True) -def monkeypatch_settings_path(monkeypatch: pytest.MonkeyPatch, settings_path: Path): +def monkeypatch_settings_path( + monkeypatch: pytest.MonkeyPatch, settings_path: Path +) -> None: from nextline_rdb import plugin monkeypatch.setattr(plugin, 'SETTINGS', (str(settings_path),)) diff --git a/tests/test_version.py b/tests/test_version.py index 331d5d3..e1744c9 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,6 +1,6 @@ import nextline_rdb -def test_version(): +def test_version() -> None: '''Confirm that the version string is attached to the module''' nextline_rdb.__version__ diff --git a/tests/utils/test_mark_last.py b/tests/utils/test_mark_last.py index 9a0bfeb..ddc28c0 100644 --- a/tests/utils/test_mark_last.py +++ b/tests/utils/test_mark_last.py @@ -5,7 +5,7 @@ @given(st.lists(st.integers())) -def test_one(items: list[int]): +def test_one(items: list[int]) -> None: it = iter(items) actual = list(mark_last(it)) diff --git a/tests/utils/test_until.py b/tests/utils/test_until.py index 454bb4c..128d293 100644 --- a/tests/utils/test_until.py +++ b/tests/utils/test_until.py @@ -1,4 +1,6 @@ import asyncio +from collections.abc import AsyncIterator +from typing import NoReturn import pytest from hypothesis import given @@ -18,7 +20,7 @@ async def test_return(data: st.DataObject) -> None: async def test_timeout() -> None: - async def gen_none(): + async def gen_none() -> AsyncIterator[None]: while True: await asyncio.sleep(0) yield None @@ -31,7 +33,7 @@ async def gen_none(): @pytest.mark.timeout(5) async def test_timeout_never_return() -> None: - async def func(): + async def func() -> NoReturn: while True: await asyncio.sleep(0) From b9d18e0de647be93ef3bffdf541e6f4a6e008725 Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Fri, 14 Jun 2024 12:04:14 -0400 Subject: [PATCH 2/2] Enable the `disallow_untyped_defs` option in `pyproject.toml` --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 501321a..00af0c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ exclude = '''(?x)( src/nextline_rdb/alembic/versions/.*\.py$ | example_script/.*\.py$ )''' +disallow_untyped_defs = true [[tool.mypy.overrides]] module = ["dynaconf.*", "async_asgi_testclient.*", "apluggy.*"]