Skip to content

Commit

Permalink
Improve unions and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zmievsa committed Jul 16, 2023
1 parent a12f054 commit 7b092ae
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 21 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dirty-equals = "^0.6.0"
asyncio_mode = "auto"

[tool.coverage.report]
fail_under = 99
fail_under = 100
skip_covered = true
skip_empty = true
# Taken from https://coverage.readthedocs.io/en/7.1.0/excluding.html#advanced-exclusion
Expand All @@ -56,6 +56,7 @@ exclude_lines = [
"__rich_repr__",
"__repr__",
]
omit=["tests/test_tutorial/test_users_example003/*"]


[tool.ruff]
Expand Down
13 changes: 10 additions & 3 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,17 +398,24 @@ def test__codegen_unions__init_file():
from tests._data import v2000_01_01, v2001_01_01
from tests._data.unions import EnumWithOneMemberUnion, SchemaWithOneIntFieldUnion

assert EnumWithOneMemberUnion == v2000_01_01.EnumWithOneMember | v2001_01_01.EnumWithOneMember
assert SchemaWithOneIntFieldUnion == v2000_01_01.SchemaWithOneIntField | v2001_01_01.SchemaWithOneIntField
assert (
EnumWithOneMemberUnion
== v2000_01_01.EnumWithOneMember | v2001_01_01.EnumWithOneMember | latest.EnumWithOneMember
)
assert (
SchemaWithOneIntFieldUnion
== v2000_01_01.SchemaWithOneIntField | v2001_01_01.SchemaWithOneIntField | latest.SchemaWithOneIntField
)


def test__codegen_unions__regular_file():
generate_test_version_packages()
from tests._data.unions.some_schema import MySchemaUnion
from tests._data.v2000_01_01.some_schema import MySchema as MySchema2000
from tests._data.v2001_01_01.some_schema import MySchema as MySchema2001
from tests._data.latest.some_schema import MySchema as MySchemaLatest

assert MySchemaUnion == MySchema2000 | MySchema2001
assert MySchemaUnion == MySchema2000 | MySchema2001 | MySchemaLatest


def test__codegen_property():
Expand Down
Empty file.
21 changes: 21 additions & 0 deletions tests/test_tutorial/test_users_example003/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
if __name__ == "__main__":
from datetime import date
from pathlib import Path

import uvicorn
from fastapi import FastAPI

from tests.test_tutorial.test_users_example003.schemas import latest
from tests.test_tutorial.test_users_example003.users import router, versions
from tests.test_tutorial.utils import clean_versions
from universi import api_version_var, regenerate_dir_to_all_versions

try:
regenerate_dir_to_all_versions(latest, versions)
router_versions = router.create_versioned_copies(versions, latest_schemas_module=latest)
app = FastAPI()
api_version_var.set(date(2000, 1, 1))
app.include_router(router_versions[date(2000, 1, 1)])
uvicorn.run(app)
finally:
clean_versions(Path(__file__).parent / "schemas")
23 changes: 23 additions & 0 deletions tests/test_tutorial/test_users_example003/scenario.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .schemas.unions.users import UserCreateRequestUnion


class UserScenario:
async def create_user(self, payload: UserCreateRequestUnion):
return {
"id": 83,
"_prefetched_addresses": [{"id": 100, "value": payload.default_address}],
}

async def get_user(self, user_id: int):
return {
"id": user_id,
"_prefetched_addresses": (await self.get_user_addresses(user_id))["data"],
}

async def get_user_addresses(self, user_id: int):
return {
"data": [
{"id": 83, "value": "123 Example St"},
{"id": 91, "value": "456 Main St"},
],
}
Empty file.
Empty file.
18 changes: 18 additions & 0 deletions tests/test_tutorial/test_users_example003/schemas/latest/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pydantic import BaseModel


class UserCreateRequest(BaseModel):
default_address: str


class UserResource(BaseModel):
id: int


class UserAddressResource(BaseModel):
id: int
value: str


class UserAddressResourceList(BaseModel):
data: list[UserAddressResource]
79 changes: 79 additions & 0 deletions tests/test_tutorial/test_users_example003/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from datetime import date
from typing import Any

from universi import Field, VersionedAPIRouter
from universi.structure import (
Version,
VersionChange,
Versions,
convert_response_to_previous_version_for,
endpoint,
schema,
)

from .schemas.latest.users import (
UserAddressResourceList,
UserCreateRequest,
UserResource,
)
from .scenario import UserScenario

router = VersionedAPIRouter()


@router.post("/users", response_model=UserResource)
async def create_user(user: UserCreateRequest):
return await UserScenario().create_user(user)


@router.get("/users/{user_id}", response_model=UserResource)
async def get_user(user_id: int):
return await UserScenario().get_user(user_id)


@router.get("/users/{user_id}/addresses", response_model=UserAddressResourceList)
async def get_user_addresses(user_id: int):
return await UserScenario().get_user_addresses(user_id)


class ChangeAddressToList(VersionChange):
description = "Change vat id to list"
instructions_to_migrate_to_previous_version = (
schema(UserCreateRequest).field("addresses").didnt_exist,
schema(UserCreateRequest).field("address").existed_with(type=str, info=Field()),
schema(UserResource).field("addresses").didnt_exist,
schema(UserResource).field("address").existed_with(type=str, info=Field()),
)

@convert_response_to_previous_version_for(get_user, create_user)
def change_addresses_to_single_item(cls, data: dict[str, Any]) -> None:
data["address"] = data.pop("addresses")[0]

@schema(UserCreateRequest).had_property("addresses")
def addresses_property(parsed_schema):
return [parsed_schema.address] # pragma: no cover


class ChangeAddressesToSubresource(VersionChange):
description = "Change vat ids to subresource"
instructions_to_migrate_to_previous_version = (
schema(UserCreateRequest).field("addresses").existed_with(type=list[str], info=Field()),
schema(UserCreateRequest).field("default_address").didnt_exist,
schema(UserResource).field("addresses").existed_with(type=list[str], info=Field()),
endpoint(get_user_addresses).didnt_exist,
)

@convert_response_to_previous_version_for(get_user, create_user)
def change_addresses_to_list(cls, data: dict[str, Any]) -> None:
data["addresses"] = [id["value"] for id in data.pop("_prefetched_addresses")]

@schema(UserCreateRequest).had_property("default_address")
def default_address_property(parsed_schema):
return parsed_schema.addresses[0] # pragma: no cover


versions = Versions(
Version(date(2002, 1, 1), ChangeAddressesToSubresource),
Version(date(2001, 1, 1), ChangeAddressToList),
Version(date(2000, 1, 1)),
)
2 changes: 1 addition & 1 deletion universi/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _get_unionized_version_of_module(
import_pythonpath_template.format(_get_version_dir_name(version.date))
for version in versions.versions
]

imported_modules += [import_pythonpath_template.format("latest")]
parsed_file = _parse_python_module(original_module)

body = ast.Module(
Expand Down
19 changes: 4 additions & 15 deletions universi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def create_versioned_copies(
route.response_model,
version_dir,
)
# TODO: Write a test for this line
route.dependencies = _change_versions_of_all_annotations(
route.dependencies,
version_dir,
Expand Down Expand Up @@ -197,9 +196,7 @@ def _change_versions_of_all_annotations(annotation: Any, version_dir: Path) -> A
}

elif isinstance(annotation, list | tuple):
return type(annotation)(
_change_versions_of_all_annotations(v, version_dir) for v in annotation
)
return type(annotation)(_change_versions_of_all_annotations(v, version_dir) for v in annotation)
else:
return _memoized_change_versions_of_all_annotations(annotation, version_dir)

Expand All @@ -213,10 +210,7 @@ def _memoized_change_versions_of_all_annotations(
) -> Any:
if isinstance(annotation, _BaseGenericAlias | GenericAlias):
return _change_versions_of_all_annotations(get_origin(annotation), version_dir)[
tuple(
_change_versions_of_all_annotations(arg, version_dir)
for arg in get_args(annotation)
)
tuple(_change_versions_of_all_annotations(arg, version_dir) for arg in get_args(annotation))
]
elif isinstance(annotation, Depends):
return Depends(
Expand Down Expand Up @@ -258,11 +252,7 @@ def new_callable( # pyright: ignore[reportGeneralTypeIssues]
version_dir,
)
new_callable.__defaults__ = _change_versions_of_all_annotations(
tuple(
p.default
for p in old_params.values()
if p.default is not inspect.Signature.empty
),
tuple(p.default for p in old_params.values() if p.default is not inspect.Signature.empty),
version_dir=version_dir,
)
new_callable.__signature__ = _generate_signature(new_callable, old_params)
Expand Down Expand Up @@ -306,8 +296,7 @@ def _generate_signature(
def _get_route_index(routes: list[BaseRoute], endpoint: Endpoint):
for index, route in enumerate(routes):
if isinstance(route, APIRoute) and (
route.endpoint == endpoint
or getattr(route.endpoint, "func", None) == endpoint
route.endpoint == endpoint or getattr(route.endpoint, "func", None) == endpoint
):
return index
return None
1 change: 0 additions & 1 deletion universi/structure/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ class SchemaPropertyDidntExistInstruction:
name: str


# TODO: Validate that the function has the correct definition
@dataclass
class SchemaPropertyDefinitionInstruction:
schema: type[BaseModel]
Expand Down

0 comments on commit 7b092ae

Please sign in to comment.