Extend XeTile.tile and XeGPU.nd_tdesc to accept n-d memref as input #812

Open. Wants to merge 4 commits into base: main
6 changes: 5 additions & 1 deletion docs/rfcs/XeGPU.md
@@ -61,7 +61,7 @@ create_nd_tdesc creates a tensor descriptor that covers an array of 2D subtensor
into tensor_desc<8x16xbf16, array_length=2>
```

create_nd_tdesc also accepts a memref as input instead of a memory address, shapes, and sizes.
create_nd_tdesc also accepts a memref as input instead of a memory address, shapes, and sizes. The memref can have more than two dimensions.
```mlir
#sg_map_a = xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>
%tdesc1 = XeGPU.create_nd_tdesc %mref, %offsets:2
@@ -71,6 +71,10 @@ create_nd_tdesc also accepts a memref as input instead of a memory address, shap
%tdesc2 = XeGPU.create_nd_tdesc %mref, %offsets:2 {mode =vc}
: memref<1024x1024xbf16>, index, index
into tensor_desc<8x16xbf16>

%tdesc3 = XeGPU.create_nd_tdesc %mref, %offsets:4 {mode =vc}
: memref<4x4x1024x1024xbf16>, index, index, index, index
into tensor_desc<8x16xbf16>
```
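
To make the 4-d case concrete, here is a minimal Python sketch of how four offsets could locate the first element of the 8x16 block inside `memref<4x4x1024x1024xbf16>`. It assumes a contiguous row-major memref (the default identity layout); the helper names `row_major_strides` and `block_origin` and the offset values are hypothetical and not part of XeGPU.

```python
from math import prod


def row_major_strides(shape):
    """Element strides of a contiguous row-major array with the given shape."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))


def block_origin(shape, offsets):
    """Linear element offset of the block's first element inside the memref.

    Assumption: one offset per memref dimension, identity (row-major) layout.
    """
    if len(shape) != len(offsets):
        raise ValueError("one offset is needed per memref dimension")
    for dim, (size, off) in enumerate(zip(shape, offsets)):
        if not 0 <= off < size:
            raise ValueError(f"offset {off} is out of range for dimension {dim}")
    return sum(o * s for o, s in zip(offsets, row_major_strides(shape)))


# memref<4x4x1024x1024xbf16> with illustrative offsets (1, 2, 8, 16)
origin = block_origin((4, 4, 1024, 1024), (1, 2, 8, 16))
```

With these inputs the two leading offsets select one 1024x1024 matrix and the last two select the block's position inside it, so the 2D case is the same computation with no leading dimensions.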

The example below accepts a memory address and an offset and creates a 1D tensor_desc. The tensor_desc describes a 1D vector that is loaded by all WI threads combined within the subgroup.
15 changes: 14 additions & 1 deletion docs/rfcs/XeTile.md
@@ -45,6 +45,13 @@ To create a 2D Tile memory descriptor, the user needs to set up a tile (init_til
%tile0 = XeTile.init_tile %base_memref, [%tile_offset:2] :
memref<128x128xbf16> into tile<8x16xbf16>
```

`init_tile` can take a high-dimensional memref as input. The innermost two dimensions of the input memref are used to derive the tile's base_shape and base_strides.
```mlir
%tile0 = XeTile.init_tile %base_memref, [%tile_offset:4] :
memref<4x4x128x128xbf16> into tile<8x16xbf16>
```
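
As an illustration of deriving base_shape and base_strides from the innermost two dimensions, the following Python sketch splits a 4-d static shape into a leading element offset plus a 2D base. It assumes a contiguous row-major memref and assumes that the leading offsets simply fold into the base address; `Base2D` and `derive_base` are hypothetical names, not XeTile API.

```python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class Base2D:
    base_offset: int     # elements skipped by the leading (outer) offsets
    base_shape: tuple    # shape of the innermost 2D matrix
    base_strides: tuple  # element strides of the innermost 2D matrix
    tile_offset: tuple   # remaining 2D offset of the tile in that matrix


def derive_base(shape, offsets):
    """Split an n-d row-major memref access into outer offset + 2D base."""
    if len(shape) < 2 or len(shape) != len(offsets):
        raise ValueError("need at least 2 dims and one offset per dim")
    # Row-major strides: each stride is the product of the trailing sizes.
    strides = tuple(prod(shape[i + 1:]) for i in range(len(shape)))
    # Assumption: leading offsets only move the base address.
    outer = sum(o * s for o, s in zip(offsets[:-2], strides[:-2]))
    return Base2D(outer, tuple(shape[-2:]), strides[-2:], tuple(offsets[-2:]))


# memref<4x4x128x128xbf16> with illustrative tile offsets (1, 2, 8, 16)
base = derive_base((4, 4, 128, 128), (1, 2, 8, 16))
```

Under these assumptions the tile sees an ordinary 128x128 matrix with strides (128, 1), which is why the resulting tile type stays `tile<8x16xbf16>` regardless of the memref rank.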

`init_tile` with a memref of dynamic shape. Because the memref's shape is dynamic, its shape and strides must be passed to init_tile as runtime parameters.
```mlir
%tile0 = XeTile.init_tile %base_memref, [%tile_offset:2], [%base_shape:2], [%base_strides:2]:
@@ -61,9 +68,15 @@ To create a 2D Tile memory descriptor, the user needs to set up a tile (init_til
```mlir
#tile_attr = #xetile.tile_attr<order = [0, 1]>
%tile0 = XeTile.init_tile %base_memref, [%tile_offset:2]:
memref<128x128xbf16, affine_map=<(d0, d1)->(d1, d0)> into tile<64x32xbf16, #tile_attr>
memref<128x128xbf16, affine_map=<(d0, d1)->(d1, d0)>> into tile<64x32xbf16, #tile_attr>
```

A tile with the `order` attribute can also be created from a high-dimensional memref.
```mlir
#tile_attr = #xetile.tile_attr<order = [0, 1]>
%tile0 = XeTile.init_tile %base_memref, [%tile_offset:4]:
memref<4x4x128x128xbf16, affine_map=<(d3, d2, d0, d1)->(d3, d2, d1, d0)>> into tile<64x32xbf16, #tile_attr>
```
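
One way to read the layout in this example is that the two innermost dimensions are swapped in memory while the outer dimensions stay in place. The Python sketch below computes element strides for such a layout. This reading of the affine map is an assumption rather than something the RFC text states, and `strides_for_memory_order` is a hypothetical helper.

```python
def strides_for_memory_order(shape, memory_order):
    """Element strides when dims are laid out in `memory_order`.

    `memory_order` lists dimension indices from slowest-varying to
    fastest-varying in memory (an assumed convention for this sketch).
    """
    if sorted(memory_order) != list(range(len(shape))):
        raise ValueError("memory_order must be a permutation of the dims")
    strides = [0] * len(shape)
    running = 1
    # Walk from the fastest-varying dimension outwards.
    for dim in reversed(memory_order):
        strides[dim] = running
        running *= shape[dim]
    return tuple(strides)


shape = (4, 4, 128, 128)
row_major = strides_for_memory_order(shape, (0, 1, 2, 3))
inner_swapped = strides_for_memory_order(shape, (0, 1, 3, 2))
```

With the innermost pair swapped, the 2D base strides become (1, 128) instead of (128, 1), i.e. the first of the two tile dimensions is the contiguous one, while the outer strides are unchanged.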

With the tile data type, XeTile supports load_tile, prefetch_tile, and store_tile.
