Skip to content

Commit

Permalink
Add print_svg for mma (#1733)
Browse files Browse the repository at this point in the history
* add print_svg for mma

* correct the code indentation
  • Loading branch information
reed-lau authored Sep 18, 2024
1 parent 1ebda1c commit 2991ce1
Showing 1 changed file with 177 additions and 0 deletions.
177 changes: 177 additions & 0 deletions include/cute/atom/mma_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,183 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
printf(latex_footer);
}

// MNK MMA Layout to SVG -- 8-value color coded by thread
template <class LayoutC, class ThrIDC,
class LayoutA, class ThrIDA,
class LayoutB, class ThrIDB>
CUTE_HOST_DEVICE
void
print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
{
char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175",
"255,175,175", "210,210,255", "210,255,210",
"255,255,210", "255,210,210"};

const int cell_width = 20;
const int cell_height = 20;

const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width;
const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height;

// header
printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
"preserveAspectRatio=\"xMidYMid meet\" "
"xmlns=\"http://www.w3.org/2000/svg\">\n",
page_width, page_height);

// C
int c_base_x = (size<1>(A) + 2) * cell_width;
int c_base_y = (size<1>(B) + 2) * cell_height;
for (int m = 0; m < cute::size<0>(C); ++m) {
for (int n = 0; n < cute::size<1>(C); ++n) {

int thrid = C(m, n) % size(TC);
int val_idx = C(m, n) / size(TC);
int thr_idx = TC(thrid);

int x = n * cell_width + c_base_x;
int y = m * cell_height + c_base_y;

int thr_x = x + cell_width / 2;
int thr_y = y + cell_height / 4;
int val_x = x + cell_width / 2;
int val_y = y + cell_height * 3 / 4;

printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\"/>\n",
x, y, cell_width, cell_height, color_map[thr_idx % 8]);

printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
thr_x, thr_y, thr_idx);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
val_x, val_y, val_idx);
}
}

// A
int a_base_x = cell_width;
int a_base_y = (size<1>(B) + 2) * cell_height;
for (int m = 0; m < size<0>(A); ++m) {
for (int k = 0; k < size<1>(A); ++k) {
int thrid = A(m, k) % size(TA);
int val_idx = A(m, k) / size(TA);
int thr_idx = TA(thrid);

int x = k * cell_width + a_base_x;
int y = m * cell_height + a_base_y;

int thr_x = x + cell_width / 2;
int thr_y = y + cell_height / 4;
int val_x = x + cell_width / 2;
int val_y = y + cell_height * 3 / 4;

printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
thr_x, thr_y, thr_idx);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
val_x, val_y, val_idx);
}
}

// B
int b_base_x = (size<1>(A) + 2) * cell_width;
int b_base_y = cell_height;
for (int n = 0; n < size<0>(B); ++n) {
for (int k = 0; k < size<1>(B); ++k) {
int thrid = B(n, k) % size(TB);
int val_idx = B(n, k) / size(TB);
int thr_idx = TB(thrid);

int x = n * cell_width + b_base_x;
int y = k * cell_height + b_base_y;

int thr_x = x + cell_width / 2;
int thr_y = y + cell_height / 4;
int val_x = x + cell_width / 2;
int val_y = y + cell_height * 3 / 4;

printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
thr_x, thr_y, thr_idx);
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
val_x, val_y, val_idx);
}
}

// A labels
for (int m = 0; m < size<0>(A); ++m) {
int x = cell_width / 2;
int y = m * cell_height + cell_height / 2 + a_base_y;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, m);
}
for (int k = 0; k < size<1>(A); ++k) {
int x = cell_width + k * cell_width + cell_width / 2;
int y = -cell_height / 2 + a_base_y;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, k);
}

// B labels
for (int n = 0; n < size<0>(B); ++n) {
int x = b_base_x + cell_width * n + cell_width / 2;
int y = cell_height / 2;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, n);
}
for (int k = 0; k < size<1>(B); ++k) {
int x = b_base_x - cell_width / 2;
int y = cell_height * (k + 1) + cell_height / 2;
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
x, y, k);
}

// footer
printf("</svg>");
}

template <class... Args>
CUTE_HOST_DEVICE
void
print_svg(MMA_Atom<Args...> const &mma_atom) {
print_svg(make_tiled_mma(mma_atom));
}

template <class... Args>
CUTE_HOST_DEVICE
void
print_svg(TiledMMA<Args...> const &mma) {
auto layout_and_thrid_C = mma.get_layoutC_MN();
auto layoutC_MN = get<0>(layout_and_thrid_C);
auto thrID_C = get<1>(layout_and_thrid_C);

auto layout_and_thrid_A = mma.get_layoutA_MK();
auto layoutA_MK = get<0>(layout_and_thrid_A);
auto thrID_A = get<1>(layout_and_thrid_A);

auto layout_and_thrid_B = mma.get_layoutB_NK();
auto layoutB_NK = get<0>(layout_and_thrid_B);
auto thrID_B = get<1>(layout_and_thrid_B);

print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
}

} // namespace cute

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 2991ce1

Please sign in to comment.