Skip to content

Commit

Permalink
Uneven support
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gilbert committed Oct 16, 2024
1 parent 5558522 commit 89cac66
Show file tree
Hide file tree
Showing 12 changed files with 485 additions and 369 deletions.
77 changes: 77 additions & 0 deletions bindings/looptree/ir.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "pytimeloop/bindings/looptree.h"

#include <sstream>

#include <applications/looptree-model/model.hpp>
#include <workload/fused-workload.hpp>
#include <workload/fused-workload-dependency-analyzer.hpp>

#include <pybind11/stl.h>


#define FUSED_WORKLOAD_METHOD(python_name, cpp_name) \
def(#python_name, &problem::FusedWorkload::cpp_name)

#define FUSED_WORKLOAD_ANALYZER_METHOD(python_name, cpp_name) \
def(#python_name, &problem::FusedWorkloadDependencyAnalyzer::cpp_name)

namespace py = pybind11;

#define DEFINE_REPR_VIA_STRINGSTREAM(class) \
def("__repr__", &print_via_stringstream<class>)

#define DEFINE_PROPERTY(class, name) \
def_readwrite(#name, &analysis::class::name)


template<typename T>
std::string print_via_stringstream(const T& t)
{
std::stringstream buf;
buf << t;
return buf.str();
}


namespace pytimeloop::looptree_bindings
{

void BindIr(py::module& m)
{
py::class_<analysis::Temporal>(m, "Temporal")
.def(py::init<>())
.DEFINE_REPR_VIA_STRINGSTREAM(analysis::Temporal);

py::class_<analysis::Spatial>(m, "Spatial")
.def(py::init<int, analysis::BufferId>())
.DEFINE_REPR_VIA_STRINGSTREAM(analysis::Spatial);

py::class_<analysis::Sequential>(m, "Sequential")
.def(py::init<>())
.DEFINE_REPR_VIA_STRINGSTREAM(analysis::Sequential);

py::class_<analysis::PipelineTemporal>(m, "PipelineTemporal")
.def(py::init<>())
.DEFINE_REPR_VIA_STRINGSTREAM(analysis::PipelineTemporal);

py::class_<analysis::PipelineSpatial>(m, "PipelineSpatial")
.def(py::init<>())
.DEFINE_REPR_VIA_STRINGSTREAM(analysis::PipelineSpatial);

py::class_<analysis::LogicalBuffer>(m, "LogicalBuffer")
.def(py::init<>())
.DEFINE_PROPERTY(LogicalBuffer, buffer_id)
.DEFINE_PROPERTY(LogicalBuffer, dspace_id)
.DEFINE_PROPERTY(LogicalBuffer, branch_leaf_id)
.DEFINE_REPR_VIA_STRINGSTREAM(analysis::LogicalBuffer);

py::class_<analysis::Occupancy>(m, "Occupancy")
.def(py::init<>())
.DEFINE_PROPERTY(Occupancy, dim_in_tags);

py::class_<analysis::Fill>(m, "Fill")
.def(py::init<>())
.DEFINE_PROPERTY(Fill, dim_in_tags);
}

}
218 changes: 6 additions & 212 deletions pytimeloop/fastfusion/fusionset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Generator
from pareto import Pareto
from compatibility import OpCompatibility
from collections import defaultdict
import unittest
import itertools
from util import fzs
from typing import Any, Generator

from .compatibility import OpCompatibility
from .pareto import Pareto
from .util import fzs


class FusionSet:
Expand Down Expand Up @@ -202,210 +202,4 @@ def __lt__(self, other: "FusionSet") -> bool:
return self.compatibility < other.compatibility

def __repr__(self):
return f"FusionSet({self.compatibility})"


class TestFusionSet(unittest.TestCase):
def test_vertical_combine(self):
fs = []
for i in range(2):
comp = OpCompatibility(
einsum_id=f"einsum1",
fused_tensors=fzs(),
fused_loops=(),
fused_ranks=fzs(),
ranks=fzs(),
tensors=fzs(),
neighbors=fzs(),
)
fs.append(FusionSet({comp}, Pareto(data={})))
new_fs = FusionSet.vertical_combine(fs)
self.assertEqual(len(new_fs.compatibility), 1)
self.assertEqual(new_fs.payload.data, {})

def test_combine(self):
comp1 = OpCompatibility(
einsum_id=f"einsum1",
fused_tensors=fzs(),
fused_loops=(),
fused_ranks=fzs(),
ranks=fzs("R"),
tensors=fzs("Q"),
neighbors=fzs("123"),
)
comp2 = OpCompatibility(
einsum_id=f"einsum2",
fused_tensors=fzs(),
fused_loops=(),
fused_ranks=fzs(),
ranks=fzs("S"),
tensors=fzs("V"),
neighbors=fzs("ABC"),
)
fs1 = FusionSet({comp1}, Pareto(data={}))
fs2 = FusionSet({comp2}, Pareto(data={}))
new_fs = fs1.combine(fs2)
self.assertEqual(len(new_fs.compatibility), 2)
self.assertIn(comp1, new_fs.compatibility)
self.assertIn(comp2, new_fs.compatibility)
self.assertEqual(new_fs.payload.data, {})
self.assertEqual(new_fs.tensors, {"Q", "V"})
self.assertEqual(new_fs.ranks, {"R", "S"})

def test_compatibile_with(self):
for neighbors in fzs("ABC"), fzs():
kwargs = dict(
fused_tensors=fzs("T1"),
fused_ranks=fzs(),
ranks=fzs("A"),
tensors=fzs(),
neighbors=neighbors,
)

comp1 = OpCompatibility(einsum_id="A", fused_loops=(("A", 1),), **kwargs)
comp2 = OpCompatibility(einsum_id="B", fused_loops=(("A", 2),), **kwargs)

comp4 = OpCompatibility(einsum_id="C", fused_loops=(("A", 4),), **kwargs)
comp5 = OpCompatibility(einsum_id="C", fused_loops=(("A", 3),), **kwargs)

fs1 = FusionSet({comp1, comp2}, Pareto(data={}))
fs2 = FusionSet({comp4}, Pareto(data={}))
self.assertEqual(fs1.compatible_with(fs2), True)

fs2 = FusionSet({comp5}, Pareto(data={}))
# Not neighbors --> compatible becuase there's nothing overlapping to check
self.assertEqual(fs1.compatible_with(fs2), not neighbors)

# Test:
# - Drop dead
# - Finding live neighbors
# -
def test_drop_dead(self):
comp1 = OpCompatibility(
einsum_id=f"einsum1",
fused_tensors=fzs(),
fused_loops=(),
fused_ranks=fzs(),
ranks=fzs("R"),
tensors=fzs("Q"),
neighbors=fzs("123"),
)
comp2 = OpCompatibility(
einsum_id=f"einsum2",
fused_tensors=fzs(),
fused_loops=(),
fused_ranks=fzs(),
ranks=fzs("S"),
tensors=fzs("V"),
neighbors=fzs("ABC"),
)
fs = FusionSet({comp1, comp2}, Pareto(data={}))
fs.drop_dead({"einsum1"})
self.assertEqual(len(fs.compatibility), 1)
self.assertIn(comp1, fs.compatibility)
self.assertEqual(fs.payload.data, {})
fs.drop_dead(set())
self.assertEqual(len(fs.compatibility), 0)
self.assertEqual(fs.payload.data, {})

def test_live_partition(self):
kwargs = dict(
fused_tensors=fzs("T1"),
ranks=fzs("A"),
tensors=fzs(),
fused_loops=(),
)

a = OpCompatibility(
einsum_id="A", fused_ranks=fzs("A"), neighbors=fzs("B"), **kwargs
)
b = OpCompatibility(
einsum_id="B", fused_ranks=fzs("A"), neighbors=fzs("AC"), **kwargs
)
c = OpCompatibility(
einsum_id="C", fused_ranks=fzs(), neighbors=fzs("BD"), **kwargs
)
d = OpCompatibility(
einsum_id="D", fused_ranks=fzs("A"), neighbors=fzs("CE"), **kwargs
)
e = OpCompatibility(
einsum_id="E", fused_ranks=fzs("A"), neighbors=fzs("DF"), **kwargs
)
f = OpCompatibility(
einsum_id="F", fused_ranks=fzs("A"), neighbors=fzs("E"), **kwargs
)

for live, partition in [
("A", ("AB",)),
("B", ("AB", "C")),
("C", ("AB", "C", "DEF")),
("D", ("C", "DEF")),
("E", ("DEF",)),
("F", ("DEF",)),
("AF", ("AB", "DEF")),
("ABF", ("AB", "C", "DEF")),
]:
fs = FusionSet({a, b, c, d, e, f}, Pareto(data={}))
fs.drop_dead(set(live))
partitions = OpCompatibility.get_tiled_partitions(fs.compatibility)
ids = tuple(
sorted("".join(sorted(p.einsum_id for p in p2)) for p2 in partitions)
)
msg = f"Failed with {live} {partition}, got {ids}"
self.assertEqual(len(fs.compatibility), sum(len(l) for l in partition), msg)
self.assertEqual(ids, partition, msg)

def test_bucketing(self):
tensor_choices = ["A", "B", "BC"]
rank_choices = ["MN", "NM"]
rank_size_choices = [1, 2]
has_other_einsum = [True, False]

comps = []
for t in tensor_choices:
for r in rank_choices:
for rs in rank_size_choices:
for other in has_other_einsum:
kwargs = dict(
fused_tensors=fzs(t),
fused_loops=tuple((x, rs) for x in r),
fused_ranks=fzs(r),
ranks=fzs("MN"),
tensors=fzs(t),
neighbors=fzs(),
)
x = {OpCompatibility(einsum_id="einsum1", **kwargs)}
if other:
x.add(OpCompatibility(einsum_id="einsum2", **kwargs))
comps.append(FusionSet(x, Pareto(data={})))

def check_bucket_sizes(bucketed, expected):
if expected:
self.assertEqual(len(bucketed), expected[0])
for b in bucketed.values():
check_bucket_sizes(b, expected[1:])

fusion_sets = FusionSet.bucket_multi_level(
comps, {"einsum1"}, {"A"}, {"M", "N"}
)
check_bucket_sizes(fusion_sets, [2, 2, 2])
fusion_sets = FusionSet.bucket_multi_level(
comps, {"einsum1", "einsum2"}, {"A"}, {"M", "N"}
)
check_bucket_sizes(fusion_sets, [4, 2, 2])
fusion_sets = FusionSet.bucket_multi_level(
comps, {"einsum1"}, {"A", "C"}, {"M"}
)
check_bucket_sizes(fusion_sets, [3, 2, 2])
fusion_sets = FusionSet.bucket_multi_level(
comps, {"einsum1"}, {"A", "C"}, set()
)
check_bucket_sizes(fusion_sets, [3, 1, 1])
fusion_sets = FusionSet.bucket_multi_level(
comps, {"einsum1", "einsum2"}, set(), set()
)
check_bucket_sizes(fusion_sets, [2, 1, 1])


if __name__ == "__main__":
unittest.main()
return f"FusionSet({self.compatibility})"
Loading

0 comments on commit 89cac66

Please sign in to comment.