diff --git a/django_readers/pairs.py b/django_readers/pairs.py index 0c85af1..6943607 100644 --- a/django_readers/pairs.py +++ b/django_readers/pairs.py @@ -116,21 +116,33 @@ def forward_relationship(name, related_queryset, relationship_pair, to_attr=None def reverse_relationship( - name, related_name, related_queryset, relationship_pair, to_attr=None + name, + related_name, + related_queryset, + relationship_pair, + to_attr=None, + post_fn=None, + slice=None, ): prepare_related_queryset, project_relationship = relationship_pair prepare = qs.prefetch_reverse_relationship( - name, related_name, related_queryset, prepare_related_queryset, to_attr + name, related_name, related_queryset, prepare_related_queryset, to_attr, slice + ) + return prepare, producers.relationship( + to_attr or name, project_relationship, post_fn ) - return prepare, producers.relationship(to_attr or name, project_relationship) -def many_to_many_relationship(name, related_queryset, relationship_pair, to_attr=None): +def many_to_many_relationship( + name, related_queryset, relationship_pair, to_attr=None, post_fn=None, slice=None +): prepare_related_queryset, project_relationship = relationship_pair prepare = qs.prefetch_many_to_many_relationship( - name, related_queryset, prepare_related_queryset, to_attr + name, related_queryset, prepare_related_queryset, to_attr, slice + ) + return prepare, producers.relationship( + to_attr or name, project_relationship, post_fn ) - return prepare, producers.relationship(to_attr or name, project_relationship) def relationship(name, relationship_pair, to_attr=None): diff --git a/django_readers/producers.py b/django_readers/producers.py index a61a268..62e3c9c 100644 --- a/django_readers/producers.py +++ b/django_readers/producers.py @@ -16,7 +16,7 @@ def producer(instance): method = methodcaller -def relationship(name, related_projector): +def relationship(name, related_projector, post_fn=None): """ Given an attribute name and a projector, return a producer which plucks the attribute off the instance, figures out whether it represents a single @@ -24,12 +24,17 @@ def relationship(name, related_projector): to the related object or objects. """ + if not post_fn: + + def post_fn(x): + return x + def producer(instance): try: related = none_safe_attrgetter(name)(instance) except ObjectDoesNotExist: return None - return map_or_apply(related, related_projector) + return post_fn(map_or_apply(related, related_projector)) return producer diff --git a/django_readers/qs.py b/django_readers/qs.py index 59038f4..7e39afe 100644 --- a/django_readers/qs.py +++ b/django_readers/qs.py @@ -142,7 +142,12 @@ def prefetch_forward_relationship( def prefetch_reverse_relationship( - name, related_name, related_queryset, prepare_related_queryset=noop, to_attr=None + name, + related_name, + related_queryset, + prepare_related_queryset=noop, + to_attr=None, + slice=None, ): """ Efficiently prefetch a reverse relationship: one where the field on the "parent" @@ -152,41 +157,37 @@ def prefetch_reverse_relationship( as Django will need it when it comes to stitch them together when the query is executed. """ + + prefetch_qs = pipe(include_fields(related_name), prepare_related_queryset)( + related_queryset + ) + + if slice: + prefetch_qs = prefetch_qs.__getitem__(slice) + return pipe( include_fields("pk"), - prefetch_related( - Prefetch( - name, - pipe( - include_fields(related_name), - prepare_related_queryset, - )(related_queryset), - to_attr, - ) - ), + prefetch_related(Prefetch(name, prefetch_qs, to_attr)), ) def prefetch_many_to_many_relationship( - name, related_queryset, prepare_related_queryset=noop, to_attr=None + name, related_queryset, prepare_related_queryset=noop, to_attr=None, slice=None ): """ For many-to-many relationships, both sides of the relationship are non-concrete, so we don't need to do anything special with including fields. They are also symmetrical, so no need to differentiate between forward and reverse direction. """ + + prefetch_qs = pipe(include_fields("pk"), prepare_related_queryset)(related_queryset) + + if slice: + prefetch_qs = prefetch_qs.__getitem__(slice) + return pipe( include_fields("pk"), - prefetch_related( - Prefetch( - name, - pipe( - include_fields("pk"), - prepare_related_queryset, - )(related_queryset), - to_attr, - ) - ), + prefetch_related(Prefetch(name, prefetch_qs, to_attr)), ) diff --git a/django_readers/utils.py b/django_readers/utils.py index 6953d50..08ab8ed 100644 --- a/django_readers/utils.py +++ b/django_readers/utils.py @@ -1,10 +1,12 @@ +from itertools import islice + try: import zen_queries except ImportError: zen_queries = None -def map_or_apply(obj, fn): +def map_or_apply(obj, fn, slice=None): """ If the first argument is iterable, map the function across each item in it and return the result. If it looks like a queryset or manager, call `.all()` and @@ -16,11 +18,22 @@ def map_or_apply(obj, fn): try: # Is the object itself iterable? - return [fn(item) for item in iter(obj)] + print(obj) + if slice: + iterable = islice(obj, slice.start, slice.stop, slice.step) + else: + iterable = iter(obj) + + return [fn(item) for item in iterable] except TypeError: try: # Does the object have a `.all()` method (is it a manager?) - return [fn(item) for item in obj.all()] + qs = obj.all() + + if slice: + qs = qs.__gettiem__(slice) + + return [fn(item) for item in qs] except AttributeError: # It must be a single object return fn(obj) @@ -106,3 +119,13 @@ def visit_dict_item_tuple(self, key, value): def visit_dict_item_callable(self, key, value): return key, self.visit_callable(value) + + +def collapse_list(res): + if not res: + return None + + if len(res) == 1: + return res[0] + + return res diff --git a/tests/test_pairs.py b/tests/test_pairs.py index 0c9bbfb..4c463ea 100644 --- a/tests/test_pairs.py +++ b/tests/test_pairs.py @@ -1,7 +1,7 @@ from django.db.models import Count from django.db.models.functions import Length from django.test import TestCase -from django_readers import pairs, producers, projectors, qs +from django_readers import pairs, producers, projectors, qs, utils from tests.models import Category, Group, Owner, Thing, Widget from tests.test_producers import title_and_reverse @@ -231,6 +231,90 @@ def test_reverse_many_to_one_relationship_with_to_attr(self): }, ) + def test_reverse_many_to_one_relationship_with_slice(self): + group = Group.objects.create(name="test group") + owner = Owner.objects.create(name="test owner", group=group) + Widget.objects.create(name="widget 1", value=1, owner=owner) + Widget.objects.create(name="widget 2", value=100, owner=owner) + + prepare, project = pairs.combine( + pairs.producer_to_projector("name", pairs.field("name")), + pairs.producer_to_projector( + "widget_set_attr", + pairs.reverse_relationship( + "widget_set", + "owner", + Widget.objects.all().order_by("value"), + pairs.combine( + pairs.producer_to_projector("name", pairs.field("name")), + pairs.producer_to_projector("value", pairs.field("value")), + ), + to_attr="widget_set_attr", + slice=slice(0, 1), + ), + ), + ) + + with self.assertNumQueries(0): + queryset = prepare(Owner.objects.all()) + + with self.assertNumQueries(2): + instance = queryset.first() + + with self.assertNumQueries(0): + result = project(instance) + + self.assertEqual( + result, + { + "name": "test owner", + "widget_set_attr": [ + {"name": "widget 1", "value": 1}, + ], + }, + ) + + def test_reverse_many_to_one_relationship_with_slice_post_fn(self): + owner = Owner.objects.create(name="test owner") + Widget.objects.create(name="widget 1", value=1, owner=owner) + Widget.objects.create(name="widget 2", value=100, owner=owner) + + prepare, project = pairs.combine( + pairs.producer_to_projector("name", pairs.field("name")), + pairs.producer_to_projector( + "widget_set_attr", + pairs.reverse_relationship( + "widget_set", + "owner", + Widget.objects.all().order_by("value"), + pairs.combine( + pairs.producer_to_projector("name", pairs.field("name")), + pairs.producer_to_projector("value", pairs.field("value")), + ), + to_attr="widget_set_attr", + post_fn=utils.collapse_list, + slice=slice(1, 2), + ), + ), + ) + + with self.assertNumQueries(0): + queryset = prepare(Owner.objects.all()) + + with self.assertNumQueries(2): + instance = queryset.first() + + with self.assertNumQueries(0): + result = project(instance) + + self.assertEqual( + result, + { + "name": "test owner", + "widget_set_attr": {"name": "widget 2", "value": 100}, + }, + ) + def test_one_to_one_relationship(self): widget = Widget.objects.create(name="test widget") Thing.objects.create(name="test thing", widget=widget) @@ -417,6 +501,55 @@ def test_many_to_many_relationship_with_to_attr(self): }, ) + def test_many_to_many_relationship_with_to_attr_slice_post_fn(self): + widget_1 = Widget.objects.create(name="test widget 1") + widget_2 = Widget.objects.create(name="test widget 2") + category = Category.objects.create(name="test category") + category.widget_set.add(widget_1) + category.widget_set.add(widget_2) + + prepare, project = pairs.combine( + pairs.producer_to_projector("name", pairs.field("name")), + pairs.producer_to_projector( + "widget_set_attr", + pairs.many_to_many_relationship( + "widget_set", + Widget.objects.all().order_by("-name"), + pairs.combine( + pairs.producer_to_projector("name", pairs.field("name")), + pairs.producer_to_projector( + "category_set_attr", + pairs.many_to_many_relationship( + "category_set", + Category.objects.all().order_by("-name"), + pairs.producer_to_projector( + "name", pairs.field("name") + ), + to_attr="category_set_attr", + ), + ), + ), + to_attr="widget_set_attr", + post_fn=utils.collapse_list, + slice=slice(1), + ), + ), + ) + + instance = prepare(Category.objects.all()).first() + result = project(instance) + + self.assertEqual( + result, + { + "name": "test category", + "widget_set_attr": { + "name": "test widget 2", + "category_set_attr": [{"name": "test category"}], + }, + }, + ) + def test_relationship(self): owner = Owner.objects.create(name="test owner") widget = Widget.objects.create(name="test widget", owner=owner)