diff --git a/clorm/orm/query.py b/clorm/orm/query.py index b0ad1e5..ab072b2 100644 --- a/clorm/orm/query.py +++ b/clorm/orm/query.py @@ -3415,7 +3415,7 @@ def make_query(qp, factsets, factindexes): # ------------------------------------------------------------------------------ -def make_outputter(insig, outsig): +def make_outputter(queryroots, insig, outsig): def make_simple_outputter(): af = make_input_alignment_functor(insig, outsig) return lambda intuple, af=af: af(intuple) @@ -3432,7 +3432,8 @@ def make_complex_outputter(): tmp = make_input_alignment_functor(insig, out.paths) metasig.append(lambda x, af=tmp, f=out.functor: f(*af(x))) elif callable(out): - metasig.append(lambda x, f=out: f(*x)) + tmp = make_input_alignment_functor(insig, queryroots) + metasig.append(lambda x, af=tmp, f=out: f(*af(x))) else: metasign.append(lambda x, out=out: out) @@ -4232,10 +4233,11 @@ def all(self) -> Generator[Any, None, None]: (self._qplan, self._query) = self._make_plan_and_query() outsig = self._qspec.select + roots = self._qspec.roots if outsig is None or not outsig: outsig = self._qspec.roots - self._outputter = make_outputter(self._qplan.output_signature, outsig) + self._outputter = make_outputter(roots, self._qplan.output_signature, outsig) self._unwrap = not self._qspec.tuple and len(outsig) == 1 self._distinct = self._qspec.distinct @@ -4387,9 +4389,10 @@ def execute_fn(facts): ) outsig = self._qspec.select + roots = self._qspec.roots if outsig is None or not outsig: outsig = self._qspec.roots - self._outputter = make_outputter(self._qplan.output_signature, outsig) + self._outputter = make_outputter(roots, self._qplan.output_signature, outsig) self._unwrap = False self._distinct = self._qspec.distinct diff --git a/tests/test_orm_query.py b/tests/test_orm_query.py index 96ebf2a..d4b772b 100644 --- a/tests/test_orm_query.py +++ b/tests/test_orm_query.py @@ -2593,9 +2593,11 @@ def f_in_factmaps(f, factmaps): return f in fm.factset -def factmaps_dict(facts, indexes=[]): +def factmaps_dict(facts, indexes=None): from itertools import groupby + if indexes is None: + indexes = [] indexes = sorted([hashable_path(p) for p in indexes]) predicate2indexes = {} for k, g in groupby(indexes, lambda p: path(p).meta.predicate): @@ -3431,6 +3433,57 @@ def replace(subroots, fn): self.assertTrue(f_in_factmaps(F(35, "a"), factmaps)) self.assertTrue(f_in_factmaps(F(35, "foo"), factmaps)) + # ---------------------------------------------------------------------------------------- + # Test the output order of the predicates for a complex join matches the input order. This + # is to track down a bug that is not re-aligning the order when the query execution plan + # internally changes the order. + # + # NOTE: the original bug is to do with passing a function/lambda to the "select" clause. + # Internally since no input signature is associated with the function/lambda it should be + # given the signature of the query roots. This wasn't happening and was instead being + # passed the inputs as they were generated by the query plan. The fix is to use the query + # root as the input signature for a function/lambda. + # ---------------------------------------------------------------------------------------- + def test_api_QueryExecutor_output_alignment(self): + class F(Predicate): + aint = IntegerField + + class G(Predicate): + astr = StringField + + class H(Predicate): + aint = IntegerField + astr = StringField + + pw = process_where + pj = process_join + roots = [F, H, G] + + factmaps = factmaps_dict( + [F(1), F(2), G("a"), G("b"), H(1, "a"), H(2, "b")], [G.astr, F.aint, H.astr] + ) + + qspec = QuerySpec( + roots=roots, + join=pj([F.aint == H.aint, G.astr == H.astr], roots), + where=pw(F.aint == 1, roots), + select=(lambda f, h, g: (f, h, g),), + ) + + qe = QueryExecutor(factmaps, qspec) + plan, _ = qe._make_plan_and_query() + plan_roots = [x.root.meta.predicate for x in plan] + + # To check that the bug is fixed and the order of the output is the same as the input + # we need to make sure the order of the input roots is different to the plan roots. + self.assertNotEqual(roots, plan_roots) + output = list(qe.all()) + self.assertEqual(len(output), 1) + f, h, g = output[0] + self.assertEqual(type(f), F) + self.assertEqual(type(g), G) + self.assertEqual(type(h), H) + # ------------------------------------------------------------------------------ # main