From f8a391d148055857fc2647c8c0c528bf0344608e Mon Sep 17 00:00:00 2001 From: Leandro de Souza Date: Tue, 7 Jun 2022 08:59:07 -0300 Subject: [PATCH] Feature (partial providers): Allows injection of parameters at run time to the provider. --- README.md | 62 ++++++++++++++++++++++++++++++++++++++++-- inject_it/__init__.py | 2 +- inject_it/_injector.py | 39 ++++++++++++++++++++------ inject_it/register.py | 25 ++++++++++++++++- tests/test_register.py | 30 +++++++++++++++++++- 5 files changed, 145 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 2bf7563..543ad0e 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,65 @@ This will have the same effect as calling `register_dependency` after creating t Your provider can also `requires` a dependency, but it must be registered before it. -# Depending on abstract classes +## Conditional arguments to Provider + +Sometimes you will want to create a service dinamically, using some attributes for the current context of your application. For example, on a HTTP view passing the current user to the provider, the state of a object on the database. `inject-it` allows you to give this parameters to the provider on the fly using `additional_kwargs_to_provider` context manager. That will apply `functools.partial` into your provider for the given kwargs. Example: + +First, let's define the provider like usual: + +```python +# client.py +from inject_it import provider + + +class Client: + def __init__(self, key): + self.key + + +@provider +def client_provider(api_key: str) -> Client: + return Client(key=api_key) +``` + +Notice that if we don't inject the `api_key` argument, `inject_it` won't be able to call the `client_provider` function, since it will be missing the `api_key` parameter. To solve this let's continue the example: + +```python +# services.py +from client import Client +from inject_it import requires + + +@requires(Client) +def make_request(client: Client): + print(client.key) +``` + +So let's say you use the api_key for each user. And you receive an HTTP request into your view. Using django views, migth look like this: + +```python +# views.py +from client import Client +from services import make_request +from inject_it import additional_kwargs_to_provider + + +def some_view(request): + user = request.user + + with additional_kwargs_to_provider(Client, api_key=user.some_service_key): + make_request() # client will be injected for the given user.some_service_key + + ... +``` + +Two things is happening when you call the `additional_kwargs_to_provider` function: +1- You will be patch the `Client` provider function to receive the kwargs you given. +2- The kwargs must match the `client_provider` arguments. + +This helps if you are using some design patterns like the Strategy Pattern, swapping a service implementation for your current application state. + +## Depending on abstract classes `inject-it` allows you to `register_dependency` to another `bound_type`. This is useful if you don't really care about the concrete implementation, only the abstract one. Consider this example: @@ -176,6 +234,6 @@ def provider_func(): For the moment, you can only have one dependency for each type. So you can't have like two different `str` dependencies. When you register the second `str` you will be overriding the first. You can work around this by using specific types, instead of primitive types. -# Testing +## Testing Testing is made easy with `inject-it`, you just have to register your `mock`, `fake`, `stub` before calling your function. If you are using pytest, use fixtures. diff --git a/inject_it/__init__.py b/inject_it/__init__.py index 50bbc9b..6193e39 100644 --- a/inject_it/__init__.py +++ b/inject_it/__init__.py @@ -6,4 +6,4 @@ InvalidFunctionSignature, InjectedKwargAlreadyGiven, ) -from .register import register_dependency +from .register import register_dependency, additional_kwargs_to_provider diff --git a/inject_it/_injector.py b/inject_it/_injector.py index 291b799..81503f5 100644 --- a/inject_it/_injector.py +++ b/inject_it/_injector.py @@ -1,27 +1,50 @@ import inspect -from typing import Any, Type +from typing import Any, Optional from . import _checks -from .exceptions import DependencyNotRegistered, InvalidFunctionSignature -from .stubs import Dependencies, Providers, Types, Kwargs +from .exceptions import ( + DependencyNotRegistered, + InvalidDependency, + InvalidFunctionSignature, +) +from .objects import Provider +from .stubs import Class, Dependencies, Providers, Types, Kwargs dependencies: Dependencies = {} providers: Providers = {} -def _get_dependency(t: Type) -> Any: - from inject_it.register import register_dependency - +def _get_dependency(t: Class) -> Optional[Any]: dep = dependencies.get(t, Ellipsis) if dep is not Ellipsis: return dep + +def _get_provider(t: Class) -> Provider: provider = providers.get(t) if not provider: raise DependencyNotRegistered( f"Could not found an dependency for: {t}. Did you forgot to register it?" ) - dependency = provider.fnc() + return provider + + +def _resolve_dependency(t: Class) -> Any: + from inject_it.register import register_dependency + + dep = _get_dependency(t) + if dep: + return dep + provider = _get_provider(t) + + try: + dependency = provider.fnc() + except TypeError as e: + raise InvalidDependency( + "Could not properly call the provider function.", + "If the provider requires additional arguments, don't forget to wrap your function call in `additional_kwargs_to_provider`.", + "Or maybe you forgot to `requires` some dependency on the provider?", + ) from e _checks.provider_returned_expected_type( obj=dependency, type_=provider.expected_return_type ) @@ -38,4 +61,4 @@ def get_injected_kwargs_for_signature(sig: inspect.Signature, types: Types) -> K f"The type {typ} was not found on the function signature. Did you forgot to type annotate it?" ) - return {param_for_type[t]: _get_dependency(t) for t in types} + return {param_for_type[t]: _resolve_dependency(t) for t in types} diff --git a/inject_it/register.py b/inject_it/register.py index 1e974a1..6bcf58d 100644 --- a/inject_it/register.py +++ b/inject_it/register.py @@ -1,7 +1,10 @@ +from contextlib import contextmanager +import contextlib +from functools import partial from typing import Any, Optional from .objects import Provider from .stubs import Class, Function -from .exceptions import InvalidDependency +from .exceptions import DependencyNotRegistered, InvalidDependency def register_dependency(obj: Any, bound_type: Optional[Class] = None) -> None: @@ -39,3 +42,23 @@ def register_provider(type_: Class, fnc: Function, cache_dependency: bool) -> No cache_dependency=cache_dependency, expected_return_type=type_, ) + + +@contextmanager +def additional_kwargs_to_provider(type_: Class, **kwargs): + """Context manager that applies the given `kwargs` to the provider function previously registered. + At the end, rollbacks to the original function. It's useful when your dependency is created on the + fly using some additional parameters, like the current user in a HTTP Request, the current state of + some object. + """ + from ._injector import _get_provider, providers + + provider = _get_provider(type_) + + providers[type_] = Provider( + fnc=partial(provider.fnc, **kwargs), + expected_return_type=provider.expected_return_type, + cache_dependency=provider.cache_dependency, + ) + yield + providers[type_] = provider diff --git a/tests/test_register.py b/tests/test_register.py index 7931c2f..5a7ea55 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -1,6 +1,11 @@ import pytest from inject_it.register import register_dependency -from inject_it import exceptions as exc +from inject_it import ( + exceptions as exc, + provider, + requires, + additional_kwargs_to_provider, +) from tests.conftest import T @@ -17,3 +22,26 @@ class Z(T): pass register_dependency(Z(), bound_type=T) + + +def test_additional_kwargs_for_provider_succeeds_for_correct_call(): + class Client: + def __init__(self, key): + self.key = key + + @provider(Client) + def conditional_t_provider(key: str): + return Client(key) + + @requires(Client) + def f(c: Client): + return c.key + + with additional_kwargs_to_provider(Client, key="ABC"): + key = f() + assert key == "ABC" + + # Should rollback, and since we'are not passing any arguments to the provider + # should fail. + with pytest.raises(exc.InvalidDependency): + f()