From 7b292da57e8364e0ec191daf1f640dfeade8992d Mon Sep 17 00:00:00 2001
From: Hanno Becker
Date: Mon, 25 Mar 2024 04:09:00 +0000
Subject: [PATCH] Allow pre->pre and late->late cross loop dependencies

---
 slothy/core/core.py     | 57 +++++++++++++++++++++++++++++++----------
 slothy/core/dataflow.py |  4 +++
 2 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/slothy/core/core.py b/slothy/core/core.py
index e490832f..08c97caf 100644
--- a/slothy/core/core.py
+++ b/slothy/core/core.py
@@ -1746,6 +1746,13 @@ def _add_path_constraint( self, consumer, producer, cb):
             cb()
             return
 
+        if self._is_low(consumer) and self._is_high(producer):
+            ct = cb()
+            ct.OnlyEnforceIf([consumer.pre_var, producer.pre_var])
+            ct = cb()
+            ct.OnlyEnforceIf([consumer.post_var, producer.post_var])
+            return
+
         if self._is_input(producer) and self._is_low(consumer):
             return
         if self._is_output(consumer) and self._is_high(producer):
@@ -1773,6 +1780,9 @@ def _add_path_constraint_from( self, consumer, producer, cb_lst):
         bvars = [ self._NewBoolVar("") for _ in cb_lst ]
         self._AddExactlyOne(bvars)
 
+        if self._is_low(consumer) and self._is_high(producer):
+            raise Exception("Not yet implemented")
+
         if not self.config.sw_pipelining.enabled or producer.is_virtual or consumer.is_virtual:
             for (cb, bvar) in zip(cb_lst, bvars, strict=True):
                 cb().OnlyEnforceIf(bvar)
@@ -2112,15 +2122,29 @@ def _is_output(self, t):
         assert isinstance(t, ComputationNode)
         return t.is_virtual_output
 
-    def _iter_dependencies(self, with_virt=True):
-        def f(t):
-            if with_virt:
-                return True
+    def _iter_dependencies(self, with_virt=True, with_duals=True):
+        def check_dep(t):
             (consumer, producer, _, _) = t
-            return consumer in self._get_nodes() and \
-                   producer.src in self._get_nodes()
+            if with_virt:
+                yield t
+            elif consumer in self._get_nodes() and \
+                 producer.src in self._get_nodes():
+                yield t
+
+        def is_cross_iteration_dependency(producer, consumer):
+            if not self.config.sw_pipelining.enabled is True:
+                return False
+            return self._is_low(producer.src) and self._is_high(consumer)
 
-        yield from filter(f, self._model.tree.iter_dependencies())
+        for t in self._model.tree.iter_dependencies():
+            yield from check_dep(t)
+
+            if with_duals is False:
+                continue
+
+            (consumer, producer, a, b) = t
+            if is_cross_iteration_dependency(producer, consumer):
+                yield from check_dep((consumer.sibling, producer.sibling(), a, b))
 
     def _iter_dependencies_with_lifetime(self):
 
@@ -2129,7 +2153,7 @@ def _get_lifetime_start(src):
                 return src.src.out_lifetime_start[src.idx]
             if isinstance(src, InstructionInOut):
                 return src.src.inout_lifetime_start[src.idx]
-            raise SlothyException("Unknown register source")
+            raise SlothyException(f"Unknown register source {src}")
 
         def _get_lifetime_end(src):
             if isinstance(src, InstructionOutput):
@@ -2139,9 +2163,9 @@ def _get_lifetime_end(src):
             raise SlothyException("Unknown register source")
 
         for (consumer, producer, ty, idx) in self._iter_dependencies():
-            start_var = _get_lifetime_start(producer)
-            end_var = _get_lifetime_end(producer)
-            yield (consumer, producer, ty, idx, start_var, end_var, producer.alloc())
+            producer_start_var = _get_lifetime_start(producer)
+            producer_end_var = _get_lifetime_end(producer)
+            yield (consumer, producer, ty, idx, producer_start_var, producer_end_var, producer.alloc())
 
     def _iter_cross_iteration_dependencies(self):
         def is_cross_iteration_dependency(dep):
@@ -2358,15 +2382,20 @@ def _add_constraints_loop_optimization(self):
                 self._AddImplication( producer.src.post_var, consumer.post_var )
                 self._AddImplication( consumer.pre_var, producer.src.pre_var )
                 self._AddImplication( producer.src.pre_var, consumer.post_var.Not() )
-            elif self._is_low(producer.src):
+            elif self._is_low(producer.src) and self._is_high(consumer):
+                self._AddImplication( producer.src.pre_var, consumer.pre_var )
+                self._AddImplication( consumer.post_var, producer.src.post_var )
+                # self._AddImplication(producer.src.pre_var
+                # pass
+
                 # An instruction with forward dependency to the next iteration
                 # cannot be an early instruction, and an instruction depending
                 # on an instruction from a previous iteration cannot be late.
 
                 # pylint:disable=singleton-comparison
-                self._Add(producer.src.pre_var == False)
+                # self._Add(producer.src.pre_var == False)
                 # pylint:disable=singleton-comparison
-                self._Add(consumer.post_var == False)
+                # self._Add(consumer.post_var == False)
 
         # ================================================================
         # CONSTRAINTS (Single issuing) #
diff --git a/slothy/core/dataflow.py b/slothy/core/dataflow.py
index e2e24b4b..0bb152b4 100644
--- a/slothy/core/dataflow.py
+++ b/slothy/core/dataflow.py
@@ -74,6 +74,8 @@ def alloc(self):
         return self.src.alloc_out_var[self.idx]
     def reduce(self):
         return self
+    def sibling(self):
+        return InstructionOutput(self.src.sibling, self.idx)
 
 class InstructionInOut(RegisterSource):
     """Represents an input/output of a node in the data flow graph"""
@@ -87,6 +89,8 @@ def alloc(self):
         return self.src.alloc_in_out_var[self.idx]
     def reduce(self):
         return self.src.src_in_out[self.idx].reduce()
+    def sibling(self):
+        return InstructionInOut(self.src.sibling, self.idx)
 
 class VirtualInstruction:
     """A 'virtual' instruction node for inputs and outputs."""