Skip to content

Commit

Permalink
Allow pre->pre and late->late cross loop dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
hanno-becker committed Mar 25, 2024
1 parent 981b55b commit 7b292da
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
57 changes: 43 additions & 14 deletions slothy/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,6 +1746,13 @@ def _add_path_constraint( self, consumer, producer, cb):
cb()
return

if self._is_low(consumer) and self._is_high(producer):
ct = cb()
ct.OnlyEnforceIf([consumer.pre_var, producer.pre_var])
ct = cb()
ct.OnlyEnforceIf([consumer.post_var, producer.post_var])
return

if self._is_input(producer) and self._is_low(consumer):
return
if self._is_output(consumer) and self._is_high(producer):
Expand Down Expand Up @@ -1773,6 +1780,9 @@ def _add_path_constraint_from( self, consumer, producer, cb_lst):
bvars = [ self._NewBoolVar("") for _ in cb_lst ]
self._AddExactlyOne(bvars)

if self._is_low(consumer) and self._is_high(producer):
raise Exception("Not yet implemented")

if not self.config.sw_pipelining.enabled or producer.is_virtual or consumer.is_virtual:
for (cb, bvar) in zip(cb_lst, bvars, strict=True):
cb().OnlyEnforceIf(bvar)
Expand Down Expand Up @@ -2112,15 +2122,29 @@ def _is_output(self, t):
assert isinstance(t, ComputationNode)
return t.is_virtual_output

def _iter_dependencies(self, with_virt=True):
def f(t):
if with_virt:
return True
def _iter_dependencies(self, with_virt=True, with_duals=True):
def check_dep(t):
(consumer, producer, _, _) = t
return consumer in self._get_nodes() and \
producer.src in self._get_nodes()
if with_virt:
yield t
elif consumer in self._get_nodes() and \
producer.src in self._get_nodes():
yield t

def is_cross_iteration_dependency(producer, consumer):
if not self.config.sw_pipelining.enabled is True:
return False
return self._is_low(producer.src) and self._is_high(consumer)

yield from filter(f, self._model.tree.iter_dependencies())
for t in self._model.tree.iter_dependencies():
yield from check_dep(t)

if with_duals is False:
continue

(consumer, producer, a, b) = t
if is_cross_iteration_dependency(producer, consumer):
yield from check_dep((consumer.sibling, producer.sibling(), a, b))

def _iter_dependencies_with_lifetime(self):

Expand All @@ -2129,7 +2153,7 @@ def _get_lifetime_start(src):
return src.src.out_lifetime_start[src.idx]
if isinstance(src, InstructionInOut):
return src.src.inout_lifetime_start[src.idx]
raise SlothyException("Unknown register source")
raise SlothyException(f"Unknown register source {src}")

def _get_lifetime_end(src):
if isinstance(src, InstructionOutput):
Expand All @@ -2139,9 +2163,9 @@ def _get_lifetime_end(src):
raise SlothyException("Unknown register source")

for (consumer, producer, ty, idx) in self._iter_dependencies():
start_var = _get_lifetime_start(producer)
end_var = _get_lifetime_end(producer)
yield (consumer, producer, ty, idx, start_var, end_var, producer.alloc())
producer_start_var = _get_lifetime_start(producer)
producer_end_var = _get_lifetime_end(producer)
yield (consumer, producer, ty, idx, producer_start_var, producer_end_var, producer.alloc())

def _iter_cross_iteration_dependencies(self):
def is_cross_iteration_dependency(dep):
Expand Down Expand Up @@ -2358,15 +2382,20 @@ def _add_constraints_loop_optimization(self):
self._AddImplication( producer.src.post_var, consumer.post_var )
self._AddImplication( consumer.pre_var, producer.src.pre_var )
self._AddImplication( producer.src.pre_var, consumer.post_var.Not() )
elif self._is_low(producer.src):
elif self._is_low(producer.src) and self._is_high(consumer):
self._AddImplication( producer.src.pre_var, consumer.pre_var )
self._AddImplication( consumer.post_var, producer.src.post_var )
# self._AddImplication(producer.src.pre_var
# pass

# An instruction with forward dependency to the next iteration
# cannot be an early instruction, and an instruction depending
# on an instruction from a previous iteration cannot be late.

# pylint:disable=singleton-comparison
self._Add(producer.src.pre_var == False)
# self._Add(producer.src.pre_var == False)
# pylint:disable=singleton-comparison
self._Add(consumer.post_var == False)
# self._Add(consumer.post_var == False)

# ================================================================
# CONSTRAINTS (Single issuing) #
Expand Down
4 changes: 4 additions & 0 deletions slothy/core/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def alloc(self):
return self.src.alloc_out_var[self.idx]
def reduce(self):
return self
def sibling(self):
return InstructionOutput(self.src.sibling, self.idx)

class InstructionInOut(RegisterSource):
"""Represents an input/output of a node in the data flow graph"""
Expand All @@ -87,6 +89,8 @@ def alloc(self):
return self.src.alloc_in_out_var[self.idx]
def reduce(self):
return self.src.src_in_out[self.idx].reduce()
def sibling(self):
return InstructionInOut(self.src.sibling, self.idx)

class VirtualInstruction:
"""A 'virtual' instruction node for inputs and outputs."""
Expand Down

0 comments on commit 7b292da

Please sign in to comment.