Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

请教一下tile_to_shape这个函数如何和swizzle配合使用的 #6

Open
Ddd195 opened this issue Jul 19, 2024 · 10 comments
Open

Comments

@Ddd195
Copy link

Ddd195 commented Jul 19, 2024

RT,tile_to_shape这个函数的作用是什么

@66RING
Copy link
Owner

66RING commented Jul 19, 2024

@Ddd195 根据我的理解,大概意思就是你已经有了一个global memory的tile,现在要在smem上也做tile, 那这个tile的shape要和gmem的shape贴合又要和swizzle atom的shape贴合,所以就有了这么一个tile_to_shape来自动创建shape: 给一个gmem的shape和一个swizzle atom的shape就自动创建一个两边都贴合的shape.

我个人理解是这样,也就只能比较高层次的理解了,具体细节也不很清楚。

@Ddd195
Copy link
Author

Ddd195 commented Jul 24, 2024

@66RING 谢谢您的回复!我请教了一下reed,他说仅仅用于把tile扩展到更大的块,后来我知道了使用方法大概就是先定义经过swizzle的atom,然后用这个atom进行tile_to_shape来构造一个存放在share的tensor,虽然具体细节我也不太懂。

@Ddd195
Copy link
Author

Ddd195 commented Jul 24, 2024

@66RING 请问一下size(Kernel_traits::kNThreads)这个是128吗,为什么在写blockdim时不直接写Kernel_traits::kNThreads呢

@66RING
Copy link
Owner

66RING commented Jul 24, 2024

@Ddd195 是应该直接写成Kernel_traits::kNThreads的,这个可能是standalone用来调试时改着改着忘了

@Ddd195
Copy link
Author

Ddd195 commented Jul 25, 2024

@66RING 谢谢谢谢!,还想再请教一下tiled_mma.get_slice(tidx);这个函数,我读了您的代码,以Q为例,gQ块是6464,mma能力是6416,block线程数是128,这个gQ块是如何通过get_slice和线程id以mma能力为基础分配给每个线程的,在这个例子中一个线程处理怎么样shape的矩阵呢。

@66RING
Copy link
Owner

66RING commented Jul 25, 2024

@Ddd195 这些就是看cutlass抽象出来的mma原语了,可以用这个脚本可视化,https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15

比如:

  {
    auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                    Layout<Shape<_1,_1, _1>>{},  // AtomLayoutMNK
                                    Layout<Shape<_1,_2, _1>>{}   // ValLayoutMNK
    );
    print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
  }

然后改AtomLayoutMNK和ValLayoutMNK可以微调线程的处理的任务

@Ddd195
Copy link
Author

Ddd195 commented Jul 29, 2024

@66RING 谢谢您的答复,我还想再请教一下retile_S和partition_S有什么区别,没太明白retile的作用。
比如auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA);

@66RING
Copy link
Owner

66RING commented Jul 29, 2024

@Ddd195 partition_S/D是创建copy对象所需的src和dst。作用在gmem, smem这些"内存"上。而retile是作用在寄存器的,retile来变成寄存器私有数据所需的形状。而partition_fragment_A是构造一个空的寄存器的view,数据需要从smem拷贝进去

比如从smem拷贝到寄存器,那就用partition构造一个对接smem的src,然后用retile构造一个和寄存器对接的dst。

反之同理

  Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
  Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

  // NOTE: 先拷贝到smem
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

我的理解是这样,如果有错麻烦帮我指正一下

@Ddd195
Copy link
Author

Ddd195 commented Jul 30, 2024

@Ddd195 partition_S/D是创建copy对象所需的src和dst。作用在gmem, smem这些"内存"上。而retile是作用在寄存器的,retile来变成寄存器私有数据所需的形状。而partition_fragment_A是构造一个空的寄存器的view,数据需要从smem拷贝进去

比如从smem拷贝到寄存器,那就用partition构造一个对接smem的src,然后用retile构造一个和寄存器对接的dst。

反之同理

  Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
  Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

  // NOTE: 先拷贝到smem
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

我的理解是这样,如果有错麻烦帮我指正一下

我也认为是这样,但是使用起来灵活,感觉我还是无法完全搞明白其中的线程和寄存器如何变换,只能知道这些步骤在做什么,放弃了。谢谢大佬

@sleepwalker2017
Copy link

@Ddd195 这些就是看cutlass抽象出来的mma原语了,可以用这个脚本可视化,https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15

比如:

  {
    auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                    Layout<Shape<_1,_1, _1>>{},  // AtomLayoutMNK
                                    Layout<Shape<_1,_2, _1>>{}   // ValLayoutMNK
    );
    print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
  }

然后改AtomLayoutMNK和ValLayoutMNK可以微调线程的处理的任务

大佬请教下,这个图里的各种颜色是啥意思啊?为什么A的同一列是一个颜色呢?

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants