Skip to content

Commit

Permalink
Merge pull request #82 from simonsobs/dev
Browse files Browse the repository at this point in the history
Fix all mypy errors with the `disallow_untyped_defs` option
  • Loading branch information
TaiSakuma authored Jun 14, 2024
2 parents f52348e + b9d18e0 commit e0a2c07
Show file tree
Hide file tree
Showing 41 changed files with 83 additions and 70 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ exclude = '''(?x)(
src/nextline_rdb/alembic/versions/.*\.py$
| example_script/.*\.py$
)'''
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = ["dynaconf.*", "async_asgi_testclient.*", "apluggy.*"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def current(self) -> bool:
return self._current is not None

@current.setter
def current(self, value: bool):
def current(self, value: bool) -> None:
if value and self._current is None:
self._current = CurrentScript(script=self)
elif not value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_prompt())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@given(run=st_model_run(generate_traces=False))
async def test_repr(run: Run):
async def test_repr(run: Run) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(run)
Expand All @@ -24,7 +24,7 @@ async def test_repr(run: Run):


@given(run=st_model_run(generate_traces=True))
async def test_cascade(run: Run):
async def test_cascade(run: Run) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(run)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_stdout())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_trace())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None:


@st.composite
def st_enum_type(draw: st.DrawFn):
def st_enum_type(draw: st.DrawFn) -> type[Enum]:
'''Generate an Enum type.
>>> enum_type = st_enum_type().example()
Expand All @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn):
names_ = st.text(ascii_lowercase, min_size=1)
names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier)
values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True)
return draw(st.builds(Enum, names, values))
return draw(st.builds(Enum, names, values)) # type: ignore


@given(st.data())
def test_arbitrary_enum(data: st.DataObject):
def test_arbitrary_enum(data: st.DataObject) -> None:
enum_type = data.draw(st_enum_type())
item = data.draw(st.sampled_from(enum_type))
repr_ = repr_val(item)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_prompt())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_run(generate_traces=False))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_stdout())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_trace())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None:


@st.composite
def st_enum_type(draw: st.DrawFn):
def st_enum_type(draw: st.DrawFn) -> type[Enum]:
'''Generate an Enum type.
>>> enum_type = st_enum_type().example()
Expand All @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn):
names_ = st.text(ascii_lowercase, min_size=1)
names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier)
values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True)
return draw(st.builds(Enum, names, values))
return draw(st.builds(Enum, names, values)) # type: ignore


@given(st.data())
def test_arbitrary_enum(data: st.DataObject):
def test_arbitrary_enum(data: st.DataObject) -> None:
enum_type = data.draw(st_enum_type())
item = data.draw(st.sampled_from(enum_type))
repr_ = repr_val(item)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def current(self) -> bool:
return self._current is not None

@current.setter
def current(self, value: bool):
def current(self, value: bool) -> None:
if value and self._current is None:
self._current = CurrentScript(script=self)
elif not value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_prompt())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@given(run=st_model_run(generate_traces=False))
async def test_repr(run: Run):
async def test_repr(run: Run) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(run)
Expand All @@ -25,7 +25,7 @@ async def test_repr(run: Run):

@settings(phases=(Phase.generate,)) # Avoid shrinking
@given(run=st_model_run(generate_traces=True))
async def test_cascade(run: Run):
async def test_cascade(run: Run) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(run)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_stdout())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_trace())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None:


@st.composite
def st_enum_type(draw: st.DrawFn):
def st_enum_type(draw: st.DrawFn) -> type[Enum]:
'''Generate an Enum type.
>>> enum_type = st_enum_type().example()
Expand All @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn):
names_ = st.text(ascii_lowercase, min_size=1)
names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier)
values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True)
return draw(st.builds(Enum, names, values))
return draw(st.builds(Enum, names, values)) # type: ignore


@given(st.data())
def test_arbitrary_enum(data: st.DataObject):
def test_arbitrary_enum(data: st.DataObject) -> None:
enum_type = data.draw(st_enum_type())
item = data.draw(st.sampled_from(enum_type))
repr_ = repr_val(item)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_prompt())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@given(run=st_model_run(generate_traces=False))
async def test_repr(run: Run):
async def test_repr(run: Run) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(run)
Expand All @@ -24,7 +24,7 @@ async def test_repr(run: Run):


@given(run=st_model_run(generate_traces=True))
async def test_cascade(run: Run):
async def test_cascade(run: Run) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(run)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_stdout())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@given(st.data())
async def test_repr(data: st.DataObject):
async def test_repr(data: st.DataObject) -> None:
async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
model = data.draw(st_model_trace())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None:


@st.composite
def st_enum_type(draw: st.DrawFn):
def st_enum_type(draw: st.DrawFn) -> type[Enum]:
'''Generate an Enum type.
>>> enum_type = st_enum_type().example()
Expand All @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn):
names_ = st.text(ascii_lowercase, min_size=1)
names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier)
values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True)
return draw(st.builds(Enum, names, values))
return draw(st.builds(Enum, names, values)) # type: ignore


@given(st.data())
def test_arbitrary_enum(data: st.DataObject):
def test_arbitrary_enum(data: st.DataObject) -> None:
enum_type = data.draw(st_enum_type())
item = data.draw(st.sampled_from(enum_type))
repr_ = repr_val(item)
Expand Down
6 changes: 3 additions & 3 deletions src/nextline_rdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from logging import getLogger
from os import PathLike
from pathlib import Path
from typing import Optional, Type
from typing import Any, Optional, Type

from alembic.config import Config
from alembic.migration import MigrationContext
Expand Down Expand Up @@ -136,7 +136,7 @@ async def __aenter__(self) -> 'DB':
await self.start()
return self

async def __aexit__(self, *_, **__) -> None:
async def __aexit__(self, *_: Any, **__: Any) -> None:
await self.aclose()


Expand Down Expand Up @@ -168,7 +168,7 @@ def do_run_migrations(connection: Connection) -> None:


@event.listens_for(Engine, 'connect')
def set_sqlite_pragma(dbapi_connection, connection_record):
def set_sqlite_pragma(dbapi_connection: Any, connection_record: Any) -> None:
'''Enable foreign key constraints in SQLite.
The code copied from the SQLAlchemy documentation:
Expand Down
6 changes: 3 additions & 3 deletions src/nextline_rdb/models/tests/test_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@

@settings(phases=(Phase.generate,)) # Avoid shrinking
@given(parent=st_model_run(generate_traces=True))
async def test_run(parent: Run):
async def test_run(parent: Run) -> None:
await assert_cascade(parent, Run, [Prompt, Trace, TraceCall, Stdout])


@settings(phases=(Phase.generate,)) # Avoid shrinking
@given(parent=st_model_trace(generate_trace_calls=True, generate_prompts=True))
async def test_trace(parent: Trace):
async def test_trace(parent: Trace) -> None:
await assert_cascade(parent, Trace, [Prompt, TraceCall])


@settings(phases=(Phase.generate,)) # Avoid shrinking
@given(parent=st_model_trace_call(generate_prompts=True))
async def test_trace_call(parent: TraceCall):
async def test_trace_call(parent: TraceCall) -> None:
await assert_cascade(parent, TraceCall, [Prompt]) # type: ignore


Expand Down
6 changes: 3 additions & 3 deletions src/nextline_rdb/models/tests/test_repr_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_enum(value: Color) -> None:


@st.composite
def st_enum_type(draw: st.DrawFn):
def st_enum_type(draw: st.DrawFn) -> type[Enum]:
'''Generate an Enum type.
>>> enum_type = st_enum_type().example()
Expand All @@ -59,11 +59,11 @@ def st_enum_type(draw: st.DrawFn):
names_ = st.text(ascii_lowercase, min_size=1)
names = st.builds(lambda x: x.capitalize(), names_).filter(str.isidentifier)
values = st.lists(st.text(ascii_uppercase, min_size=1), min_size=1, unique=True)
return draw(st.builds(Enum, names, values))
return draw(st.builds(Enum, names, values)) # type: ignore


@given(st.data())
def test_arbitrary_enum(data: st.DataObject):
def test_arbitrary_enum(data: st.DataObject) -> None:
enum_type = data.draw(st_enum_type())
item = data.draw(st.sampled_from(enum_type))
repr_ = repr_val(item)
Expand Down
4 changes: 2 additions & 2 deletions src/nextline_rdb/pagination.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Type, TypeVar

from sqlalchemy import func, select
from sqlalchemy import ScalarResult, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, aliased, selectinload
from sqlalchemy.sql.selectable import Select
Expand Down Expand Up @@ -36,7 +36,7 @@ async def load_models(
after: Optional[_Id] = None,
first: Optional[int] = None,
last: Optional[int] = None,
):
) -> ScalarResult[T]:
sort = sort or []

if id_field not in {s.field for s in sort}:
Expand Down
2 changes: 1 addition & 1 deletion src/nextline_rdb/schema/nodes/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def trace_call(
return TraceCallNode.from_model(self._model.trace_call)

@classmethod
def from_model(cls: type['PromptNode'], model: db_models.Prompt):
def from_model(cls: type['PromptNode'], model: db_models.Prompt) -> 'PromptNode':
return cls(
_model=model,
id=model.id,
Expand Down
2 changes: 1 addition & 1 deletion src/nextline_rdb/schema/nodes/run_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class RunNode:
] = strawberry.field(resolver=_resolve_stdouts)

@classmethod
def from_model(cls: type['RunNode'], model: Run):
def from_model(cls: type['RunNode'], model: Run) -> 'RunNode':
script = model.script.script if model.script else None
return cls(
_model=model,
Expand Down
Loading

0 comments on commit e0a2c07

Please sign in to comment.