Skip to content

Commit

Permalink
Improve callbacks performance by prebaking context awareness
Browse files Browse the repository at this point in the history
  • Loading branch information
maximkulkin committed Nov 30, 2016
1 parent 6459d28 commit 6a48f89
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 34 deletions.
55 changes: 31 additions & 24 deletions lollipop/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from lollipop.errors import ValidationError, ValidationErrorBuilder, \
ErrorMessagesMixin, merge_errors
from lollipop.utils import is_list, is_dict, call_with_context, constant, identity
from lollipop.utils import is_list, is_dict, make_context_aware, constant, identity
from lollipop.compat import string_types, int_types, iteritems, OrderedDict
import datetime

Expand Down Expand Up @@ -60,7 +60,8 @@ def __init__(self, validate=None, *args, **kwargs):
elif callable(validate):
validate = [validate]

self._validators = validate
self._validators = [make_context_aware(validator, 1)
for validator in validate]

def validate(self, data, context=None):
"""Takes serialized data and returns validation errors or None.
Expand All @@ -87,7 +88,7 @@ def load(self, data, context=None):
errors_builder = ValidationErrorBuilder()
for validator in self._validators:
try:
call_with_context(validator, context, data)
validator(data, context)
except ValidationError as ve:
errors_builder.add_errors(ve.messages)
errors_builder.raise_errors()
Expand Down Expand Up @@ -873,7 +874,7 @@ def get_value(self, name, obj, context=None, *args, **kwargs):
method = getattr(obj, method_name)
if not callable(method):
raise ValueError('Value of %s is not callable' % method_name)
return call_with_context(method, context)
return make_context_aware(method, 0)(context)

def set_value(self, name, obj, value, context=None, *args, **kwargs):
if not self.set_method:
Expand All @@ -885,7 +886,7 @@ def set_value(self, name, obj, value, context=None, *args, **kwargs):
method = getattr(obj, method_name)
if not callable(method):
raise ValueError('Value of %s is not callable' % method_name)
return call_with_context(method, context, value)
return make_context_aware(method, 1)(value, context)


class FunctionField(Field):
Expand Down Expand Up @@ -918,18 +919,24 @@ def __init__(self, field_type, get=None, set=None, *args, **kwargs):
raise ValueError("Get function is not callable")
if set is not None and not callable(set):
raise ValueError("Set function is not callable")

if get is not None:
get = make_context_aware(get, 1)
if set is not None:
set = make_context_aware(set, 2)

self.get_func = get
self.set_func = set

def get_value(self, name, obj, context=None, *args, **kwargs):
if self.get_func is None:
return MISSING
return call_with_context(self.get_func, context, obj)
return self.get_func(obj, context)

def set_value(self, name, obj, value, context=None, *args, **kwargs):
if self.set_func is None:
return MISSING
call_with_context(self.set_func, context, obj, value)
self.set_func(obj, value, context)


def inheritable_property(name):
Expand Down Expand Up @@ -1271,20 +1278,20 @@ def __init__(self, inner_type,
load_default = constant(load_default)
if not callable(dump_default):
dump_default = constant(dump_default)
self.load_default = load_default
self.dump_default = dump_default
self.load_default = make_context_aware(load_default, 0)
self.dump_default = make_context_aware(dump_default, 0)

def load(self, data, context=None, *args, **kwargs):
if data is MISSING or data is None:
return call_with_context(self.load_default, context)
return self.load_default(context)
return super(Optional, self).load(
self.inner_type.load(data, context=context, *args, **kwargs),
*args, **kwargs
)

def dump(self, data, context=None, *args, **kwargs):
if data is MISSING or data is None:
return call_with_context(self.dump_default, context)
return self.dump_default(context)
return super(Optional, self).dump(
self.inner_type.dump(data, context=context, *args, **kwargs),
*args, **kwargs
Expand Down Expand Up @@ -1396,27 +1403,27 @@ def __init__(self, inner_type,
pre_dump=identity, post_dump=identity):
super(Transform, self).__init__()
self.inner_type = inner_type
self.pre_load = pre_load
self.post_load = post_load
self.pre_dump = pre_dump
self.post_dump = post_dump
self.pre_load = make_context_aware(pre_load, 1)
self.post_load = make_context_aware(post_load, 1)
self.pre_dump = make_context_aware(pre_dump, 1)
self.post_dump = make_context_aware(post_dump, 1)

def load(self, data, context=None):
return call_with_context(
self.post_load, context,
return self.post_load(
self.inner_type.load(
call_with_context(self.pre_load, context, data),
self.pre_load(data, context),
context,
)
),
context,
)

def dump(self, value, context=None):
return call_with_context(
self.post_dump, context,
return self.post_dump(
self.inner_type.dump(
call_with_context(self.pre_dump, context, value),
self.pre_dump(value, context),
context,
)
),
context,
)


Expand Down Expand Up @@ -1463,6 +1470,6 @@ class ValidatedSubtype(base_type):
def __init__(self, *args, **kwargs):
super(ValidatedSubtype, self).__init__(*args, **kwargs)
for validator in reversed(validate):
self._validators.insert(0, validator)
self._validators.insert(0, make_context_aware(validator, 1))

return ValidatedSubtype
25 changes: 18 additions & 7 deletions lollipop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def is_dict(value):
return isinstance(value, dict)


def call_with_context(func, context, *args):
def make_context_aware(func, numargs):
"""
Check if given function has more arguments than given. Call it with context
as last argument or without it.
Check if given function has no more arguments than given. If so, wrap it
into another function that takes extra argument and drops it.
Used to support user providing callback functions that are not context aware.
"""
if inspect.ismethod(func):
arg_count = len(inspect.getargspec(func).args) - 1
Expand All @@ -35,11 +36,21 @@ def call_with_context(func, context, *args):
else:
arg_count = len(inspect.getargspec(func.__call__).args) - 1

if len(args) < arg_count:
args = list(args)
args.append(context)
if arg_count <= numargs:
def normalized(*args):
return func(*args[:-1])

return normalized

return func


return func(*args)
def call_with_context(func, context, *args):
"""
Check if given function has more arguments than given. Call it with context
as last argument or without it.
"""
return make_context_aware(func, len(args))(*args + (context,))


def to_snake_case(s):
Expand Down
6 changes: 3 additions & 3 deletions lollipop/validators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lollipop.errors import ValidationError, ValidationErrorBuilder, \
ErrorMessagesMixin
from lollipop.compat import string_types
from lollipop.utils import call_with_context, is_list, identity
from lollipop.utils import make_context_aware, is_list, identity
import re


Expand Down Expand Up @@ -53,13 +53,13 @@ class Predicate(Validator):

def __init__(self, predicate, error=None, **kwargs):
super(Predicate, self).__init__(**kwargs)
self.predicate = predicate
self.predicate = make_context_aware(predicate, 1)
if error is not None:
self._error_messages['invalid'] = error
self.error = error

def __call__(self, value, context=None):
if not call_with_context(self.predicate, context, value):
if not self.predicate(value, context):
self._fail('invalid', data=value)

def __repr__(self):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,16 @@ def test_returns_type_that_has_single_given_validator(self):
assert OddInteger().validate(1) is None
assert OddInteger().validate(2) == is_odd_validator.message

def test_accepts_context_unaware_validators(self):
error_message = 'Value should be odd'
def context_unaware_is_odd_validator(value):
if value % 2 == 0:
raise ValidationError(error_message)

OddInteger = validated_type(Integer, validate=context_unaware_is_odd_validator)
assert OddInteger().validate(1) is None
assert OddInteger().validate(2) == error_message

def test_returns_type_that_has_multiple_given_validators(self):
MyInteger = validated_type(Integer, validate=[divisible_by_validator(3),
divisible_by_validator(5)])
Expand Down

0 comments on commit 6a48f89

Please sign in to comment.