diff --git a/pinject/errors.py b/pinject/errors.py index 755cf38..8a2a9f2 100644 --- a/pinject/errors.py +++ b/pinject/errors.py @@ -164,3 +164,19 @@ class UnknownScopeError(Error): def __init__(self, scope_id, binding_loc): Error.__init__(self, 'unknown scope ID {0} in binding created at' ' {1}'.format(scope_id, binding_loc)) + + +class WrongArgElementTypeError(Error): + + def __init__(self, arg_name, idx, expected_type_desc, actual_type_desc): + Error.__init__( + self, 'wrong type for element {0} of arg {1}: expected {2} but got' + ' {3}'.format(idx, arg_name, expected_type_desc, actual_type_desc)) + + +class WrongArgTypeError(Error): + + def __init__(self, arg_name, expected_type_desc, actual_type_desc): + Error.__init__( + self, 'wrong type for arg {0}: expected {1} but got {2}'.format( + arg_name, expected_type_desc, actual_type_desc)) diff --git a/pinject/object_graph.py b/pinject/object_graph.py index 2004fdc..4f33dfe 100644 --- a/pinject/object_graph.py +++ b/pinject/object_graph.py @@ -14,6 +14,7 @@ """ +import collections import functools import inspect import types @@ -60,7 +61,7 @@ def new_object_graph( a provider (if any) id_to_scope: a map from scope ID to the concrete Scope implementation instance for that scope - is_scope_usable_from_scope_fn: a function taking two scope IDs and + is_scope_usable_from_scope: a function taking two scope IDs and returning whether an object in the first scope can be injected into an object from the second scope; by default, injection is allowed from any scope into any other scope @@ -74,6 +75,22 @@ def new_object_graph( """ try: + if modules is not None and modules is not finding.ALL_IMPORTED_MODULES: + _verify_types(modules, types.ModuleType, 'modules') + if classes is not None: + _verify_types(classes, types.TypeType, 'classes') + if binding_specs is not None: + _verify_subclasses( + binding_specs, bindings.BindingSpec, 'binding_specs') + if get_arg_names_from_class_name is not None: + _verify_callable(get_arg_names_from_class_name, + 'get_arg_names_from_class_name') + if get_arg_names_from_provider_fn_name is not None: + _verify_callable(get_arg_names_from_provider_fn_name, + 'get_arg_names_from_provider_fn_name') + if is_scope_usable_from_scope is not None: + _verify_callable(is_scope_usable_from_scope, + 'is_scope_usable_from_scope') injection_context_factory = injection_contexts.InjectionContextFactory( is_scope_usable_from_scope) id_to_scope = scoping.get_id_to_scope_with_defaults(id_to_scope) @@ -131,6 +148,37 @@ def new_object_graph( use_short_stack_traces) +def _verify_types(seq, required_type, arg_name): + if not isinstance(seq, collections.Sequence): + raise errors.WrongArgTypeError( + arg_name, 'sequence (of {0})'.format(required_type.__name__), + type(seq).__name__) + for idx, elt in enumerate(seq): + if type(elt) != required_type: + raise errors.WrongArgElementTypeError( + arg_name, idx, required_type.__name__, type(elt).__name__) + + +def _verify_subclasses(seq, required_superclass, arg_name): + if not isinstance(seq, collections.Sequence): + raise errors.WrongArgTypeError( + arg_name, + 'sequence (of subclasses of {0})'.format( + required_superclass.__name__), + type(seq).__name__) + for idx, elt in enumerate(seq): + if not isinstance(elt, required_superclass): + raise errors.WrongArgElementTypeError( + arg_name, idx, + 'subclass of {0}'.format(required_superclass.__name__), + type(elt).__name__) + + +def _verify_callable(fn, arg_name): + if not callable(fn): + raise errors.WrongArgTypeError(arg_name, 'callable', type(fn).__name__) + + class ObjectGraph(object): """A graph of objects instantiable with dependency injection.""" diff --git a/pinject/object_graph_test.py b/pinject/object_graph_test.py index 14f3708..b2cf1b7 100644 --- a/pinject/object_graph_test.py +++ b/pinject/object_graph_test.py @@ -15,6 +15,7 @@ import inspect +import types import unittest from pinject import bindings @@ -116,6 +117,85 @@ def provide_foo(self): some_class_two = obj_graph.provide(SomeClass) self.assertIs(some_class_one.foo, some_class_two.foo) + def test_raises_exception_if_modules_is_wrong_type(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph.new_object_graph, modules=42) + + def test_raises_exception_if_classes_is_wrong_type(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph.new_object_graph, classes=42) + + def test_raises_exception_if_binding_specs_is_wrong_type(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph.new_object_graph, binding_specs=42) + + def test_raises_exception_if_get_arg_names_from_class_name_is_wrong_type(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph.new_object_graph, + get_arg_names_from_class_name=42) + + def test_raises_exception_if_get_arg_names_from_provider_fn_name_is_wrong_type(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph.new_object_graph, + get_arg_names_from_provider_fn_name=42) + + def test_raises_exception_if_is_scope_usable_from_scope_is_wrong_type(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph.new_object_graph, + is_scope_usable_from_scope=42) + + +class VerifyTypesTest(unittest.TestCase): + + def test_verifies_empty_sequence_ok(self): + object_graph._verify_types([], types.ModuleType, 'unused') + + def test_verifies_correct_type_ok(self): + object_graph._verify_types([types], types.ModuleType, 'unused') + + def test_raises_exception_if_not_sequence(self): + self.assertRaises(errors.WrongArgTypeError, object_graph._verify_types, + 42, types.ModuleType, 'an-arg-name') + + def test_raises_exception_if_element_is_incorrect_type(self): + self.assertRaises(errors.WrongArgElementTypeError, + object_graph._verify_types, + ['not-a-module'], types.ModuleType, 'an-arg-name') + + +class VerifySubclassesTest(unittest.TestCase): + + def test_verifies_empty_sequence_ok(self): + object_graph._verify_subclasses([], bindings.BindingSpec, 'unused') + + def test_verifies_correct_type_ok(self): + class SomeBindingSpec(bindings.BindingSpec): + pass + object_graph._verify_subclasses( + [SomeBindingSpec()], bindings.BindingSpec, 'unused') + + def test_raises_exception_if_not_sequence(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph._verify_subclasses, + 42, bindings.BindingSpec, 'an-arg-name') + + def test_raises_exception_if_element_is_not_subclass(self): + class NotBindingSpec(object): + pass + self.assertRaises( + errors.WrongArgElementTypeError, object_graph._verify_subclasses, + [NotBindingSpec()], bindings.BindingSpec, 'an-arg-name') + + +class VerifyCallable(unittest.TestCase): + + def test_verifies_callable_ok(self): + object_graph._verify_callable(lambda: None, 'unused') + + def test_raises_exception_if_not_callable(self): + self.assertRaises(errors.WrongArgTypeError, + object_graph._verify_callable, 42, 'an-arg-name') + class ObjectGraphProvideTest(unittest.TestCase): diff --git a/test_errors.py b/test_errors.py index e731ee1..84499a3 100755 --- a/test_errors.py +++ b/test_errors.py @@ -37,12 +37,12 @@ def _print_raised_exception(exc, fn, *pargs, **kwargs): def print_ambiguous_arg_name_error(): - class SomeClass(): + class SomeClass(object): def __init__(self, foo): pass - class Foo(): + class Foo(object): pass - class _Foo(): + class _Foo(object): pass obj_graph = object_graph.new_object_graph( modules=None, classes=[SomeClass, Foo, _Foo]) @@ -215,6 +215,17 @@ def configure(self, bind): modules=None, binding_specs=[SomeBindingSpec()]) +def print_wrong_arg_element_type_error(): + _print_raised_exception( + errors.WrongArgElementTypeError, object_graph.new_object_graph, + modules=[42]) + + +def print_wrong_arg_type_error(): + _print_raised_exception( + errors.WrongArgTypeError, object_graph.new_object_graph, modules=42) + + all_print_method_pairs = inspect.getmembers( sys.modules[__name__], lambda x: (type(x) == types.FunctionType and