Skip to content

Commit

Permalink
remove redundant logic from resolving of object provider, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nightblure committed Nov 22, 2024
1 parent 551b63b commit b59191c
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
exclude: '.bumpversion.cfg'

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.8.0
hooks:
- id: ruff
entry: ruff check src tests --fix --exit-non-zero-on-fix --show-fixes
Expand Down
10 changes: 3 additions & 7 deletions src/injection/providers/object.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, TypeVar, Union, cast
from typing import Any, TypeVar, Union

from injection.providers.base import BaseProvider
from injection.resolving import resolve_value

T = TypeVar("T")

Expand All @@ -12,12 +11,9 @@ def __init__(self, obj: T) -> None:
self._obj = obj

def _resolve(self) -> T:
value = cast(T, resolve_value(self._obj))
return value
return self._obj

def __call__(self, **_: Any) -> Union[T, Any]:
# **_ - workaround for working DI with Litestar
# It's ok because there should be no arguments in the __call__ method
def __call__(self) -> Union[T, Any]:
if self._mocks:
return self._mocks[-1]
return self._resolve()
2 changes: 1 addition & 1 deletion tests/container_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Container(DeclarativeContainer):
)
service = providers.Transient(Service, redis_client=redis)
some_service = providers.Singleton(SomeService, 1, redis, svc=service)
num = providers.Object(settings.provided.nested_settings.some_const)
num = providers.Object(1234)
num2 = providers.Object(9402)
callable_obj = providers.Callable(func, 1, c="string2", nums=num, d={"d": 500})
coroutine_provider = providers.Coroutine(coroutine, arg1=1, arg2=2)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_base_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_override_providers_success(container):
assert container.num() == 999

assert not isinstance(container.redis(), Mock)
assert container.num() == 144
assert container.num() == 1234


def test_container_instance_is_singleton(container):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_e2e_success(container):
assert container.some_service() is some_svc
assert container.redis() is redis
assert container.service() is not service
assert container.num() == 144
assert container.num() == 1234

coroutine_result = asyncio.run(container.coroutine_provider())
assert coroutine_result == (1, 2)
Expand Down
18 changes: 17 additions & 1 deletion tests/test_providers/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ class SomeClass: ...
(234525, int),
],
)
def test_object_provider_resolving(obj, expected):
def test_object_provider_resolve_with_expected_type(obj, expected):
provider = providers.Object(obj)

assert isinstance(provider(), expected)


@pytest.mark.parametrize(
("obj", "expected"),
[
(type, type),
(234525, 234525),
(object, object),
("some_class", "some_class"),
],
)
def test_object_provider_resolve_with_expected_value(obj, expected):
provider = providers.Object(obj)

assert provider() == expected

0 comments on commit b59191c

Please sign in to comment.