diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 285869a..1875a9f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,18 +1,6 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 'v4.5.0' - hooks: - - id: check-toml - - id: check-json - - id: end-of-file-fixer - - id: pretty-format-json - args: - - '--autofix' - - id: trailing-whitespace - exclude: '.bumpversion.cfg' - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.0 + rev: v0.8.1 hooks: - id: ruff entry: ruff check src tests --fix --exit-non-zero-on-fix --show-fixes diff --git a/docs/dev/migration-from-dependency-injector.md b/docs/dev/migration-from-dependency-injector.md index 5399e93..57127d0 100644 --- a/docs/dev/migration-from-dependency-injector.md +++ b/docs/dev/migration-from-dependency-injector.md @@ -13,7 +13,7 @@ and eliminates its shortcomings, which will make migrating very easy. ⚠️ **IMPORTANT** ❗ [Injection](https://github.com/nightblure/injection) **does not implement** **some** [providers](https://python-dependency-injector.ets-labs.org/providers/index.html) -(Resource, List, Dict, Aggregate and etc.) because the developer considered them to be **rarely used** in practice. +(List, Dict and etc.) because the developer considered them to be **rarely used** in practice. In this case, you don't need to do the migration, but if you really want to use my package, I'd love to see your [issues](https://github.com/nightblure/injection/issues) and/or [merge requests](https://github.com/nightblure/injection/pulls)! diff --git a/docs/index.md b/docs/index.md index 12b8783..fc3b7b4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -29,6 +29,12 @@ providers/object providers/provided_instance +.. toctree:: + :maxdepth: 1 + :caption: Dependency injection + + injection/injection.md + .. toctree:: :maxdepth: 1 :caption: Integration with web frameworks diff --git a/docs/injection/injection.md b/docs/injection/injection.md new file mode 100644 index 0000000..02eb8c4 --- /dev/null +++ b/docs/injection/injection.md @@ -0,0 +1,6 @@ +# Dependency injection + +soon... + +## Auto injection + diff --git a/docs/providers/coroutine.md b/docs/providers/coroutine.md index eb78b67..70c846d 100644 --- a/docs/providers/coroutine.md +++ b/docs/providers/coroutine.md @@ -6,23 +6,28 @@ Can be resolved only with using the `async_resolve` method. ## Example ```python3 - import asyncio - from typing import Tuple +import asyncio +from typing import Tuple - from injection import DeclarativeContainer, providers +from injection import DeclarativeContainer, providers - async def coroutine(arg1: int, arg2: int) -> Tuple[int, int]: - return arg1, arg2 - class DIContainer(DeclarativeContainer): - provider = providers.Coroutine(coroutine, arg1=1, arg2=2) +async def coroutine(arg1: int, arg2: int) -> Tuple[int, int]: + return arg1, arg2 - arg1, arg2 = asyncio.run(DIContainer.provider.async_resolve()) - assert (arg1, arg2) == (1, 2) - async def main() -> None: - arg1, arg2 = await DIContainer.provider.async_resolve(arg1=500, arg2=600) - assert (arg1, arg2) == (500, 600) +class DIContainer(DeclarativeContainer): + provider = providers.Coroutine(coroutine, arg1=1, arg2=2) - asyncio.run(main()) + +arg1, arg2 = asyncio.run(DIContainer.provider.async_resolve()) +assert (arg1, arg2) == (1, 2) + + +async def main() -> None: + arg1, arg2 = await DIContainer.provider.async_resolve(arg1=500, arg2=600) + assert (arg1, arg2) == (500, 600) + + +asyncio.run(main()) ``` diff --git a/docs/providers/factory.md b/docs/providers/factory.md index da46333..917d5bb 100644 --- a/docs/providers/factory.md +++ b/docs/providers/factory.md @@ -9,6 +9,7 @@ Also supports **asynchronous** dependencies. ```python3 import asyncio from dataclasses import dataclass +from typing import Tuple from injection import DeclarativeContainer, providers @@ -35,7 +36,6 @@ async def main() -> None: instance1 = DIContainer.sync_factory() instance2 = DIContainer.sync_factory() - assert instance1 is not instance2 asyncio.run(main()) diff --git a/docs/providers/resource.md b/docs/providers/resource.md index 8e22c80..c1a1a7f 100644 --- a/docs/providers/resource.md +++ b/docs/providers/resource.md @@ -1,3 +1,63 @@ # Resource -soon... +**Resource provider** provides a component with **initialization** and **closing**. +**Resource providers** supports next **initializers**: +* **sync** and **async** **generators**; +* **inheritors** of `ContextManager` and `AsyncContextManager` classes; +* functions wrapped into `@contextmanager` and `@asynccontextmanager` **decorators**. + +## Working scope +Resource provider can works with two scopes: **singleton** and **function-scope**. + +**Function-scope** requires to set parameter of `Resource` provider `function_scope=True`. +**Function-scope** resources can works only with `@inject` decorator! + +## Example +```python +from typing import Tuple, Iterator, AsyncIterator + +from injection import DeclarativeContainer, Provide, inject, providers + + +def sync_func() -> Iterator[str]: + yield "sync_func" + + +async def async_func() -> AsyncIterator[str]: + yield "async_func" + + +class DIContainer(DeclarativeContainer): + sync_resource = providers.Resource(sync_func) + async_resource = providers.Resource(async_func) + + sync_resource_func_scope = providers.Resource(sync_func, function_scope=True) + async_resource_func_scope = providers.Resource(async_func, function_scope=True) + + +@inject +async def func_with_injections( + sync_value: str = Provide[DIContainer.sync_resource], + async_value: str = Provide[DIContainer.async_resource], + sync_func_scope_value: str = Provide[DIContainer.sync_resource_func_scope], + async_func_scope_value: str = Provide[DIContainer.async_resource_func_scope] +) -> Tuple[str, str, str, str]: + return sync_value, async_value, sync_func_scope_value, async_func_scope_value + + +async def main() -> None: + values = await func_with_injections() + + assert values == ("sync_func", "async_func", "sync_func", "async_func") + + assert DIContainer.sync_resource.initialized + assert DIContainer.async_resource.initialized + + # Resources with function scope were closed after dependency injection + assert not DIContainer.sync_resource_func_scope.initialized + assert not DIContainer.async_resource_func_scope.initialized + + +if __name__ == "__main__": + await main() +``` diff --git a/docs/providers/singleton.md b/docs/providers/singleton.md index f9d28c0..6894363 100644 --- a/docs/providers/singleton.md +++ b/docs/providers/singleton.md @@ -26,7 +26,6 @@ if __name__ == "__main__": assert instance1 is instance2 assert instance1.field == 15 - ``` ## Resetting memoized object diff --git a/docs/providers/transient.md b/docs/providers/transient.md index 5e92aa6..2a5ace6 100644 --- a/docs/providers/transient.md +++ b/docs/providers/transient.md @@ -10,6 +10,7 @@ Also supports **asynchronous** dependencies. ```python3 import asyncio from dataclasses import dataclass +from typing import Tuple from injection import DeclarativeContainer, providers @@ -36,7 +37,6 @@ async def main() -> None: instance1 = DIContainer.sync_transient() instance2 = DIContainer.sync_transient() - assert instance1 is not instance2 asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index f64d84f..d220162 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ ignore = [ [tool.pytest.ini_options] pythonpath = ["src"] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" filterwarnings = [ "ignore::DeprecationWarning:pkg_resources.*", ] @@ -166,7 +167,7 @@ build-backend = "hatchling.build" [tool.coverage.run] omit = [ "src/injection/__version__.py", - "*/tests/*" + "*/tests/*", ] [tool.coverage.report] @@ -175,7 +176,8 @@ exclude_lines = [ "if TYPE_CHECKING:", "sys.version_info", "raise NotImplementedError", - "ImportError" + "ImportError", + "# pragma: no cover" ] [tool.mypy] diff --git a/src/injection/base_container.py b/src/injection/base_container.py index 0576f6f..5d1f342 100644 --- a/src/injection/base_container.py +++ b/src/injection/base_container.py @@ -1,3 +1,4 @@ +import asyncio import inspect from collections import defaultdict from contextlib import contextmanager @@ -138,30 +139,80 @@ def init_resources(cls) -> None: @classmethod async def init_resources_async(cls) -> None: - for provider in cls.get_resource_providers(): - if provider.async_mode: - await provider.async_resolve() + await asyncio.gather( + *[ + provider.async_resolve() + for provider in cls.get_resource_providers() + if provider.async_mode + ], + ) + + @classmethod + async def init_all_resources(cls) -> None: + resource_providers = cls.get_resource_providers() + + await asyncio.gather( + *[ + provider.async_resolve() + for provider in resource_providers + if provider.async_mode + ], + ) + + for provider in resource_providers: + if not provider.async_mode: + provider() @classmethod def close_resources(cls) -> None: for provider in cls.get_resource_providers(): - if not provider.async_mode: + if provider.initialized and not provider.async_mode: provider.close() @classmethod - async def close_resources_async(cls) -> None: - for provider in cls.get_resource_providers(): - if provider.async_mode: - await provider.async_close() + async def close_async_resources(cls) -> None: + await asyncio.gather( + *[ + provider.async_close() + for provider in cls.get_resource_providers() + if provider.initialized and provider.async_mode + ], + ) + + @classmethod + async def close_function_scope_async_resources(cls) -> None: + await asyncio.gather( + *[ + provider.async_close() + for provider in cls.get_resource_providers() + if provider.initialized + and provider.async_mode + and provider.function_scope + ], + ) @classmethod def close_function_scope_resources(cls) -> None: for provider in cls.get_resource_providers(): - if not provider.async_mode and provider.function_scope: + if ( + provider.initialized + and provider.function_scope + and not provider.async_mode + ): provider.close() @classmethod - async def close_function_scope_resources_async(cls) -> None: - for provider in cls.get_resource_providers(): - if provider.async_mode and provider.function_scope: - await provider.async_close() + async def close_all_resources(cls) -> None: + resource_providers = cls.get_resource_providers() + + await asyncio.gather( + *[ + provider.async_close() + for provider in resource_providers + if provider.initialized and provider.async_mode + ], + ) + + for provider in resource_providers: + if provider.initialized and not provider.async_mode: + provider.close() diff --git a/src/injection/inject/auto_inject.py b/src/injection/inject/auto_inject.py index e94108c..0539039 100644 --- a/src/injection/inject/auto_inject.py +++ b/src/injection/inject/auto_inject.py @@ -72,35 +72,39 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def auto_inject( - f: Callable[P, T], target_container: Optional[_ContainerType] = None, -) -> Callable[P, T]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Decorate callable with injecting decorator. Inject objects by types""" - if target_container is None: - container_subclasses = DeclarativeContainer.__subclasses__() + def wrapper(f: Callable[P, T]) -> Callable[P, T]: + nonlocal target_container - if len(container_subclasses) > 1: - msg = ( - f"Found {len(container_subclasses)} containers, please specify " - f"the required container explicitly in the parameter 'target_container'" - ) - raise Exception(msg) + if target_container is None: + container_subclasses = DeclarativeContainer.__subclasses__() - target_container = container_subclasses[0] + if len(container_subclasses) > 1: + msg = ( + f"Found {len(container_subclasses)} containers, please specify " + f"the required container explicitly in the parameter 'target_container'" + ) + raise Exception(msg) - signature = inspect.signature(f) + target_container = container_subclasses[0] # pragma: no cover - if inspect.iscoroutinefunction(f): - func_with_injected_params = _get_async_injected( - f=f, - signature=signature, - target_container=target_container, - ) - return cast(Callable[P, T], func_with_injected_params) + signature = inspect.signature(f) - return _get_sync_injected( - f=f, - signature=signature, - target_container=target_container, - ) + if inspect.iscoroutinefunction(f): + func_with_injected_params = _get_async_injected( + f=f, + signature=signature, + target_container=target_container, + ) + return cast(Callable[P, T], func_with_injected_params) + else: + return _get_sync_injected( + f=f, + signature=signature, + target_container=target_container, + ) + + return wrapper diff --git a/src/injection/inject/inject.py b/src/injection/inject/inject.py index 8191cd3..21fae53 100644 --- a/src/injection/inject/inject.py +++ b/src/injection/inject/inject.py @@ -42,7 +42,8 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: result = await f(*args, **kwargs) for container in _get_all_di_containers(): - await container.close_function_scope_resources_async() + await container.close_function_scope_async_resources() + container.close_function_scope_resources() return result diff --git a/src/injection/providers/base.py b/src/injection/providers/base.py index 125f557..348e522 100644 --- a/src/injection/providers/base.py +++ b/src/injection/providers/base.py @@ -22,8 +22,9 @@ def __init__(self) -> None: def _resolve(self, *args: Any, **kwargs: Any) -> T: raise NotImplementedError + @abstractmethod async def _async_resolve(self, *args: Any, **kwargs: Any) -> T: - return self._resolve(*args, **kwargs) + raise NotImplementedError async def async_resolve(self, *args: Any, **kwargs: Any) -> T: if self._mocks: diff --git a/src/injection/providers/object.py b/src/injection/providers/object.py index a402cdd..412e338 100644 --- a/src/injection/providers/object.py +++ b/src/injection/providers/object.py @@ -1,4 +1,4 @@ -from typing import TypeVar, cast +from typing import TypeVar from injection.providers.base import BaseProvider @@ -6,14 +6,12 @@ class Object(BaseProvider[T]): - def __init__(self, obj: T) -> None: + def __init__(self, value: T) -> None: super().__init__() - self._obj = obj + self._value = value def _resolve(self) -> T: - return self._obj + return self._value - def __call__(self) -> T: - if self._mocks: - return cast(T, self._mocks[-1]) - return self._resolve() + async def _async_resolve(self) -> T: + return self._value diff --git a/src/injection/providers/resource.py b/src/injection/providers/resource.py index f1a0f8c..a0d1d88 100644 --- a/src/injection/providers/resource.py +++ b/src/injection/providers/resource.py @@ -7,6 +7,7 @@ AsyncIterator, Callable, ContextManager, + Final, Iterator, Optional, Tuple, @@ -69,6 +70,9 @@ def _create_context_factory( return context, async_mode +_resource_not_initialized_error_msg: Final[str] = "Resource is not initialized" + + class Resource(BaseProvider[T]): def __init__( # type: ignore[valid-type] self, @@ -98,11 +102,10 @@ def __init__( # type: ignore[valid-type] self._instance: Optional[T] = None self._function_scope = function_scope - def _create_context(self) -> None: + def __create_context(self) -> None: self._context = self._context_factory(*self._args, **self._kwargs) - self._initialized = True - def _reset_context(self) -> None: + def reset(self) -> None: self._context = None self._initialized = False @@ -123,28 +126,38 @@ def instance(self) -> T: return cast(T, self._instance) def _resolve(self) -> T: - if self.initialized and not self.function_scope: + if self.initialized: return self.instance - self._create_context() + self.__create_context() self._instance = self._context.__enter__() - return self.instance - - async def async_resolve(self) -> T: - self._create_context() - self._instance = await self._context.__aenter__() + self._initialized = True return self.instance def close(self) -> None: if not self._initialized: - return None + raise RuntimeError(_resource_not_initialized_error_msg) self._context.__exit__(None, None, None) - self._reset_context() + self.reset() + + async def _async_resolve(self) -> T: + if self.initialized: + return self.instance + + self.__create_context() + + if self.async_mode: + self._instance = await self._context.__aenter__() + else: + self._instance = self._context.__enter__() + + self._initialized = True + return self.instance async def async_close(self) -> None: if not self._initialized: - return None + raise RuntimeError(_resource_not_initialized_error_msg) await self._context.__aexit__(None, None, None) - self._reset_context() + self.reset() diff --git a/tests/conftest.py b/tests/conftest.py index b9fe893..b3cdd19 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import AsyncIterator, Type import pytest @@ -8,3 +8,9 @@ @pytest.fixture(scope="session") def container() -> Type[Container]: return Container + + +@pytest.fixture(autouse=True) +async def _close_resources(container: Type[Container]) -> AsyncIterator[None]: + yield + await container.close_all_resources() diff --git a/tests/container_objects.py b/tests/container_objects.py index a05b337..9017f02 100644 --- a/tests/container_objects.py +++ b/tests/container_objects.py @@ -213,7 +213,7 @@ def func_with_injections( return redis.url -@auto_inject +@auto_inject(target_container=Container) def func_with_auto_injections( sfs: Any, redis: Redis, @@ -232,7 +232,7 @@ def func_with_auto_injections( return redis.url -@auto_inject +@auto_inject(target_container=Container) def func_with_auto_injections_mixed( sfs: Any, *, diff --git a/tests/integration/test_drf/drf_test_project/views.py b/tests/integration/test_drf/drf_test_project/views.py index ff4794b..1e29d4f 100644 --- a/tests/integration/test_drf/drf_test_project/views.py +++ b/tests/integration/test_drf/drf_test_project/views.py @@ -17,7 +17,7 @@ def get(self, _: Request, redis: Redis = Provide[Container.redis]) -> Response: response_body = {"redis_url": redis.url} return Response(response_body, status=status.HTTP_200_OK) - @auto_inject + @auto_inject(target_container=Container) def post(self, request: Request, redis: Redis) -> Response: body_serializer = PostEndpointBodySerializer(data=request.data) body_serializer.is_valid() diff --git a/tests/integration/test_flask/test_integration.py b/tests/integration/test_flask/test_integration.py index 2f7c2e3..928d67e 100644 --- a/tests/integration/test_flask/test_integration.py +++ b/tests/integration/test_flask/test_integration.py @@ -20,7 +20,7 @@ def flask_endpoint(redis: Redis = Provide[Container.redis]) -> Dict[str, Any]: @app.route("/auto-inject-endpoint", methods=["POST"]) -@auto_inject +@auto_inject(target_container=Container) def flask_endpoint_auto_inject(redis: Redis) -> Dict[str, Any]: value = redis.get(-900) return {"detail": value} diff --git a/tests/test_auto_inject.py b/tests/test_auto_inject.py index d14e8d8..96aa198 100644 --- a/tests/test_auto_inject.py +++ b/tests/test_auto_inject.py @@ -32,10 +32,10 @@ def test_auto_inject_expect_error_with_more_than_one_di_container_and_empty_targ return_value=subclasses, ): with pytest.raises(Exception, match=match): - auto_inject(lambda: None) + auto_inject()(lambda: None) -@auto_inject +@auto_inject(target_container=Container) async def _async_func( redis: Redis, *, @@ -73,7 +73,7 @@ async def test_auto_inject_expect_error_on_duplicated_provider_types( def test_auto_injection_with_args_overriding(container: Type[Container]) -> None: - @auto_inject + @auto_inject(target_container=Container) def _inner( arg1: bool, # noqa: FBT001 arg2: Service, @@ -93,7 +93,7 @@ def _inner( async def test_auto_injection_with_args_overriding_async( container: Type[Container], ) -> None: - @auto_inject + @auto_inject(target_container=Container) async def _inner( arg1: bool, # noqa: FBT001 arg2: Service, @@ -112,16 +112,16 @@ async def _inner( def test_auto_injection_expect_error_on_unknown_provider() -> None: - @auto_inject + @auto_inject(target_container=Container) def inner(_: object) -> Any: ... with pytest.raises(UnknownProviderTypeAutoInjectionError): - inner() # type:ignore[call-arg] + inner() # type: ignore[call-arg] async def test_auto_injection_expect_error_on_unknown_provider_async() -> None: - @auto_inject + @auto_inject(target_container=Container) async def inner(_: object) -> Any: ... with pytest.raises(UnknownProviderTypeAutoInjectionError): - await inner() # type:ignore[call-arg] + await inner() # type: ignore[call-arg] diff --git a/tests/test_base_container.py b/tests/test_base_container.py index 7b8c9f3..a0e27f0 100644 --- a/tests/test_base_container.py +++ b/tests/test_base_container.py @@ -130,7 +130,6 @@ def test_sync_resources_lifecycle(container: Type[Container]) -> None: for provider in container.get_resource_providers(): if not provider.async_mode: assert provider.initialized - _ = provider() container.close_resources() @@ -145,10 +144,16 @@ async def test_async_resources_lifecycle(container: Type[Container]) -> None: for provider in container.get_resource_providers(): if provider.async_mode: assert provider.initialized - _ = await provider.async_resolve() - await container.close_resources_async() + await container.close_async_resources() for provider in container.get_resource_providers(): if provider.async_mode: assert not provider.initialized + + +async def test_init_all_resources(container: Type[Container]) -> None: + await container.init_all_resources() + + for provider in container.get_resource_providers(): + assert provider.initialized diff --git a/tests/test_providers/test_coroutine.py b/tests/test_providers/test_coroutine.py index 9056fc5..4caf79f 100644 --- a/tests/test_providers/test_coroutine.py +++ b/tests/test_providers/test_coroutine.py @@ -1,5 +1,6 @@ import asyncio -from typing import Tuple, Type +from typing import Any, Tuple, Type +from unittest.mock import Mock import pytest @@ -86,3 +87,16 @@ async def test_coroutine_provider_injecting_to_sync_function( value = await container.sync_func_with_coro_dependency.async_resolve() assert value == (1, 2) # type: ignore[comparison-overlap] + + +async def test_coroutine_provider_overriding(container: Type[Container]) -> None: + @inject + async def _inner( + v: Any = Provide[container.coroutine_provider], + ) -> str: + return v() # type: ignore[no-any-return] + + with container.coroutine_provider.override_context(Mock(return_value="mock")): + value = await _inner() + + assert value == "mock" diff --git a/tests/test_providers/test_object.py b/tests/test_providers/test_object.py index 9c9a980..0b933d8 100644 --- a/tests/test_providers/test_object.py +++ b/tests/test_providers/test_object.py @@ -1,9 +1,11 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Type +from unittest.mock import Mock import pytest -from injection import providers +from injection import Provide, inject, providers +from tests.container_objects import Container @dataclass @@ -38,3 +40,33 @@ def test_object_provider_resolve_with_expected_value(obj: Any, expected: Any) -> provider = providers.Object(obj) assert provider() == expected + + +def test_object_provider_overriding_with_sync_injection( + container: Type[Container], +) -> None: + @inject + def _inner( + v: Any = Provide[container.num], + ) -> str: + return v() # type: ignore[no-any-return] + + with container.num.override_context(Mock(return_value="mock")): + value = _inner() + + assert value == "mock" + + +async def test_object_provider_overriding_with_async_injection( + container: Type[Container], +) -> None: + @inject + async def _inner( + v: Any = Provide[container.num], + ) -> str: + return v() # type: ignore[no-any-return] + + with container.num.override_context(Mock(return_value="mock")): + value = await _inner() + + assert value == "mock" diff --git a/tests/test_providers/test_resource.py b/tests/test_providers/test_resource.py index 9742863..b829223 100644 --- a/tests/test_providers/test_resource.py +++ b/tests/test_providers/test_resource.py @@ -1,5 +1,6 @@ from types import TracebackType -from typing import AsyncContextManager, ContextManager, Optional, Type +from typing import Any, AsyncContextManager, ContextManager, Optional, Type +from unittest.mock import Mock import pytest @@ -240,3 +241,120 @@ async def __aexit__( provider = Resource(_AsyncCtxManager) assert await provider.async_resolve() == 2 + + +def test_resource_provider_overriding(container: Type[Container]) -> None: + @inject + def _inner( + v: Any = Provide[container.sync_resource], + ) -> str: + return v() # type: ignore[no-any-return] + + with container.sync_resource.override_context(Mock(return_value="mock")): + value = _inner() + + assert value == "mock" + + +async def test_resource_provider_overriding_with_async_func_and_sync_resource( + container: Type[Container], +) -> None: + @inject + async def _inner( + v: Any = Provide[container.sync_resource], + ) -> str: + return v() # type: ignore[no-any-return] + + with container.sync_resource.override_context(Mock(return_value="mock")): + value = await _inner() + + assert value == "mock" + + +async def test_resource_provider_overriding_with_async_func_and_async_resource( + container: Type[Container], +) -> None: + @inject + async def _inner( + v: Any = Provide[container.async_resource], + ) -> str: + return v() # type: ignore[no-any-return] + + with container.async_resource.override_context(Mock(return_value="mock")): + value = await _inner() + + assert value == "mock" + + +def test_resource_provider_closing_expect_error_when_not_initialized( + container: Type[Container], +) -> None: + with pytest.raises(RuntimeError, match="Resource is not initialized"): + container.sync_resource.close() + + +async def test_resource_provider_async_closing_expect_error_when_not_initialized( + container: Type[Container], +) -> None: + with pytest.raises(RuntimeError, match="Resource is not initialized"): + await container.async_resource.async_close() + + +def test_resource_provider_successful_repeat_resolving( + container: Type[Container], +) -> None: + container.init_resources() + + assert isinstance(container.sync_resource(), Resources) + + +async def test_resource_provider_successful_repeat_async_resolving( + container: Type[Container], +) -> None: + await container.init_resources_async() + + value = await container.async_resource.async_resolve() + + assert isinstance(value, Resources) + + +async def test_resource_provider_docs_code() -> None: + from typing import AsyncIterator, Iterator, Tuple + + from injection import DeclarativeContainer, Provide, inject, providers + + def sync_func() -> Iterator[str]: + yield "sync_func" + + async def async_func() -> AsyncIterator[str]: + yield "async_func" + + class DIContainer(DeclarativeContainer): + sync_resource = providers.Resource(sync_func) + async_resource = providers.Resource(async_func) + + sync_resource_func_scope = providers.Resource(sync_func, function_scope=True) + async_resource_func_scope = providers.Resource(async_func, function_scope=True) + + @inject + async def func_with_injections( + sync_value: str = Provide[DIContainer.sync_resource], + async_value: str = Provide[DIContainer.async_resource], + sync_func_scope_value: str = Provide[DIContainer.sync_resource_func_scope], + async_func_scope_value: str = Provide[DIContainer.async_resource_func_scope], + ) -> Tuple[str, str, str, str]: + return sync_value, async_value, sync_func_scope_value, async_func_scope_value + + async def main() -> None: + values = await func_with_injections() + + assert values == ("sync_func", "async_func", "sync_func", "async_func") + + assert DIContainer.sync_resource.initialized + assert DIContainer.async_resource.initialized + + # Resources with function scope were closed after dependency injection + assert not DIContainer.sync_resource_func_scope.initialized + assert not DIContainer.async_resource_func_scope.initialized + + await main()