Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow queryset slicing and running a function on the materialised relationship elements #85

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions django_readers/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,33 @@ def forward_relationship(name, related_queryset, relationship_pair, to_attr=None


def reverse_relationship(
name, related_name, related_queryset, relationship_pair, to_attr=None
name,
related_name,
related_queryset,
relationship_pair,
to_attr=None,
post_fn=None,
slice=None,
):
prepare_related_queryset, project_relationship = relationship_pair
prepare = qs.prefetch_reverse_relationship(
name, related_name, related_queryset, prepare_related_queryset, to_attr
name, related_name, related_queryset, prepare_related_queryset, to_attr, slice
)
return prepare, producers.relationship(
to_attr or name, project_relationship, post_fn
)
return prepare, producers.relationship(to_attr or name, project_relationship)


def many_to_many_relationship(name, related_queryset, relationship_pair, to_attr=None):
def many_to_many_relationship(
name, related_queryset, relationship_pair, to_attr=None, post_fn=None, slice=None
):
prepare_related_queryset, project_relationship = relationship_pair
prepare = qs.prefetch_many_to_many_relationship(
name, related_queryset, prepare_related_queryset, to_attr
name, related_queryset, prepare_related_queryset, to_attr, slice
)
return prepare, producers.relationship(
to_attr or name, project_relationship, post_fn
)
return prepare, producers.relationship(to_attr or name, project_relationship)


def relationship(name, relationship_pair, to_attr=None):
Expand Down
9 changes: 7 additions & 2 deletions django_readers/producers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@ def producer(instance):
method = methodcaller


def relationship(name, related_projector):
def relationship(name, related_projector, post_fn=None):
"""
Given an attribute name and a projector, return a producer which plucks
the attribute off the instance, figures out whether it represents a single
object or an iterable/queryset of objects, and applies the given projector
to the related object or objects.
"""

if not post_fn:

def post_fn(x):
return x

def producer(instance):
try:
related = none_safe_attrgetter(name)(instance)
except ObjectDoesNotExist:
return None
return map_or_apply(related, related_projector)
return post_fn(map_or_apply(related, related_projector))

return producer

Expand Down
45 changes: 23 additions & 22 deletions django_readers/qs.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ def prefetch_forward_relationship(


def prefetch_reverse_relationship(
name, related_name, related_queryset, prepare_related_queryset=noop, to_attr=None
name,
related_name,
related_queryset,
prepare_related_queryset=noop,
to_attr=None,
slice=None,
):
"""
Efficiently prefetch a reverse relationship: one where the field on the "parent"
Expand All @@ -152,41 +157,37 @@ def prefetch_reverse_relationship(
as Django will need it when it comes to stitch them together when the query
is executed.
"""

prefetch_qs = pipe(include_fields(related_name), prepare_related_queryset)(
related_queryset
)

if slice:
prefetch_qs = prefetch_qs.__getitem__(slice)

return pipe(
include_fields("pk"),
prefetch_related(
Prefetch(
name,
pipe(
include_fields(related_name),
prepare_related_queryset,
)(related_queryset),
to_attr,
)
),
prefetch_related(Prefetch(name, prefetch_qs, to_attr)),
)


def prefetch_many_to_many_relationship(
name, related_queryset, prepare_related_queryset=noop, to_attr=None
name, related_queryset, prepare_related_queryset=noop, to_attr=None, slice=None
):
"""
For many-to-many relationships, both sides of the relationship are non-concrete,
so we don't need to do anything special with including fields. They are also
symmetrical, so no need to differentiate between forward and reverse direction.
"""

prefetch_qs = pipe(include_fields("pk"), prepare_related_queryset)(related_queryset)

if slice:
prefetch_qs = prefetch_qs.__getitem__(slice)

return pipe(
include_fields("pk"),
prefetch_related(
Prefetch(
name,
pipe(
include_fields("pk"),
prepare_related_queryset,
)(related_queryset),
to_attr,
)
),
prefetch_related(Prefetch(name, prefetch_qs, to_attr)),
)


Expand Down
29 changes: 26 additions & 3 deletions django_readers/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from itertools import islice

try:
import zen_queries
except ImportError:
zen_queries = None


def map_or_apply(obj, fn):
def map_or_apply(obj, fn, slice=None):
"""
If the first argument is iterable, map the function across each item in it and
return the result. If it looks like a queryset or manager, call `.all()` and
Expand All @@ -16,11 +18,22 @@ def map_or_apply(obj, fn):

try:
# Is the object itself iterable?
return [fn(item) for item in iter(obj)]
print(obj)
if slice:
iterable = islice(obj, slice.start, slice.stop, slice.step)
else:
iterable = iter(obj)

return [fn(item) for item in iterable]
except TypeError:
try:
# Does the object have a `.all()` method (is it a manager?)
return [fn(item) for item in obj.all()]
qs = obj.all()

if slice:
qs = qs.__gettiem__(slice)

return [fn(item) for item in qs]
except AttributeError:
# It must be a single object
return fn(obj)
Expand Down Expand Up @@ -106,3 +119,13 @@ def visit_dict_item_tuple(self, key, value):

def visit_dict_item_callable(self, key, value):
return key, self.visit_callable(value)


def collapse_list(res):
if not res:
return None

if len(res) == 1:
return res[0]

return res
135 changes: 134 additions & 1 deletion tests/test_pairs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.db.models import Count
from django.db.models.functions import Length
from django.test import TestCase
from django_readers import pairs, producers, projectors, qs
from django_readers import pairs, producers, projectors, qs, utils
from tests.models import Category, Group, Owner, Thing, Widget
from tests.test_producers import title_and_reverse

Expand Down Expand Up @@ -231,6 +231,90 @@ def test_reverse_many_to_one_relationship_with_to_attr(self):
},
)

def test_reverse_many_to_one_relationship_with_slice(self):
group = Group.objects.create(name="test group")
owner = Owner.objects.create(name="test owner", group=group)
Widget.objects.create(name="widget 1", value=1, owner=owner)
Widget.objects.create(name="widget 2", value=100, owner=owner)

prepare, project = pairs.combine(
pairs.producer_to_projector("name", pairs.field("name")),
pairs.producer_to_projector(
"widget_set_attr",
pairs.reverse_relationship(
"widget_set",
"owner",
Widget.objects.all().order_by("value"),
pairs.combine(
pairs.producer_to_projector("name", pairs.field("name")),
pairs.producer_to_projector("value", pairs.field("value")),
),
to_attr="widget_set_attr",
slice=slice(0, 1),
),
),
)

with self.assertNumQueries(0):
queryset = prepare(Owner.objects.all())

with self.assertNumQueries(2):
instance = queryset.first()

with self.assertNumQueries(0):
result = project(instance)

self.assertEqual(
result,
{
"name": "test owner",
"widget_set_attr": [
{"name": "widget 1", "value": 1},
],
},
)

def test_reverse_many_to_one_relationship_with_slice_post_fn(self):
owner = Owner.objects.create(name="test owner")
Widget.objects.create(name="widget 1", value=1, owner=owner)
Widget.objects.create(name="widget 2", value=100, owner=owner)

prepare, project = pairs.combine(
pairs.producer_to_projector("name", pairs.field("name")),
pairs.producer_to_projector(
"widget_set_attr",
pairs.reverse_relationship(
"widget_set",
"owner",
Widget.objects.all().order_by("value"),
pairs.combine(
pairs.producer_to_projector("name", pairs.field("name")),
pairs.producer_to_projector("value", pairs.field("value")),
),
to_attr="widget_set_attr",
post_fn=utils.collapse_list,
slice=slice(1, 2),
),
),
)

with self.assertNumQueries(0):
queryset = prepare(Owner.objects.all())

with self.assertNumQueries(2):
instance = queryset.first()

with self.assertNumQueries(0):
result = project(instance)

self.assertEqual(
result,
{
"name": "test owner",
"widget_set_attr": {"name": "widget 2", "value": 100},
},
)

def test_one_to_one_relationship(self):
widget = Widget.objects.create(name="test widget")
Thing.objects.create(name="test thing", widget=widget)
Expand Down Expand Up @@ -417,6 +501,55 @@ def test_many_to_many_relationship_with_to_attr(self):
},
)

def test_many_to_many_relationship_with_to_attr_slice_post_fn(self):
widget_1 = Widget.objects.create(name="test widget 1")
widget_2 = Widget.objects.create(name="test widget 2")
category = Category.objects.create(name="test category")
category.widget_set.add(widget_1)
category.widget_set.add(widget_2)

prepare, project = pairs.combine(
pairs.producer_to_projector("name", pairs.field("name")),
pairs.producer_to_projector(
"widget_set_attr",
pairs.many_to_many_relationship(
"widget_set",
Widget.objects.all().order_by("-name"),
pairs.combine(
pairs.producer_to_projector("name", pairs.field("name")),
pairs.producer_to_projector(
"category_set_attr",
pairs.many_to_many_relationship(
"category_set",
Category.objects.all().order_by("-name"),
pairs.producer_to_projector(
"name", pairs.field("name")
),
to_attr="category_set_attr",
),
),
),
to_attr="widget_set_attr",
post_fn=utils.collapse_list,
slice=slice(1),
),
),
)

instance = prepare(Category.objects.all()).first()
result = project(instance)

self.assertEqual(
result,
{
"name": "test category",
"widget_set_attr": {
"name": "test widget 2",
"category_set_attr": [{"name": "test category"}],
},
},
)

def test_relationship(self):
owner = Owner.objects.create(name="test owner")
widget = Widget.objects.create(name="test widget", owner=owner)
Expand Down