Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
opt store_xe.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Aug 26, 2024
1 parent bba4180 commit 468c1ea
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 145 deletions.
3 changes: 3 additions & 0 deletions include/common/core/arch_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
// BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords.
static constexpr uint32_t max_trans_load_width_in_bytes = 32;

// BlockHeight must be 8 for qwords and be in range [1..32] for dwords.
static constexpr uint32_t max_trans_load_height_in_elem = 32;

// If Transformed is true,
// BlockWidth must be in range [4..16] for bytes and [2..16] for words.
static constexpr uint32_t max_vnni_load_width_in_elems = 16;
Expand Down
23 changes: 8 additions & 15 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ tile_load(tile_t& tile, payload_t& payload) {

static constexpr uint32_t num_block_x = tile_desc::num_block_x;
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
// static constexpr uint32_t num_block = tile_desc::num_block;

static constexpr gpu_arch arch_tag = payload_t::arch_tag;

Expand Down Expand Up @@ -181,19 +180,9 @@ tile_load(tile_t& tile, payload_t& payload) {
for (uint32_t i = 0; i < num_block_y; ++i) {
constexpr uint32_t load_block_elems = block_elems * arr_len;
int offset_y = i * block_size_y;
// auto payload_row =
// payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
// detail::reset_tile_desc_core<
// num_block_x,
// block_size_x,
// ld_blk_size_y,
// scale_factor,
// arr_len,
// mem_transpose>(payload_row);
#pragma unroll
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
int32_t offset_x = j * block_size_x;
// xetla_tdescriptor tdesc = payload_row.row(j);
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
(i * num_block_x + j) * block_elems);
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
Expand All @@ -215,7 +204,8 @@ tile_load(tile_t& tile, payload_t& payload) {
mem_transform,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<const native_type_t<load_dtype*>>(
payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down Expand Up @@ -273,7 +263,8 @@ tile_load(tile_t& tile, payload_t& payload) {
mem_transform,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<const native_type_t<load_dtype*>>(
payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down Expand Up @@ -335,7 +326,8 @@ tile_load(tile_t& tile, payload_t& payload) {
mem_transform,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<const native_type_t<load_dtype*>>(
payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down Expand Up @@ -402,7 +394,8 @@ tile_load(tile_t& tile, payload_t& payload) {
mem_transform,
L1,
L2>(
payload.base_ptr,
reinterpret_cast<const native_type_t<load_dtype*>>(
payload.base_ptr),
payload.surface_width,
payload.surface_height,
payload.surface_pitch,
Expand Down
31 changes: 26 additions & 5 deletions include/subgroup/tile/impl/payload_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,28 @@ struct mem_payload_t<
using mem_dtype = typename std::
conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
mem_dtype* base_ptr;

// Shorthand for the 2D-block load/store attribute table selected by arch_tag.
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;

// Maximum load width, expressed in dtype elements (the attribute's byte limit
// divided by sizeof(dtype)). When `trans` is true the transposed-load byte
// limit is used instead of the regular one.
static constexpr uint32_t max_load_width_in_elem = trans
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
// Maximum load height, taken directly from the attribute table; the
// transposed-load limit is selected when `trans` is true.
static constexpr uint32_t max_load_height_in_elem = trans
? load_store_attr::max_trans_load_height_in_elem
: load_store_attr::max_load_height_in_elem;

// Maximum store width in dtype elements (byte limit / sizeof(dtype)) and
// maximum store height as given by the attribute table.
static constexpr uint32_t max_store_width_in_elem =
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
static constexpr uint32_t max_store_height_in_elem =
load_store_attr::max_store_height_in_elem;

// Number of dtype elements per cache line
// (cache_line_size_in_bytes / sizeof(dtype)).
static constexpr uint32_t elems_per_CL =
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);

// Number of dtype elements per register for arch_tag
// (reg_in_bytes / sizeof(dtype)).
static constexpr uint32_t elems_per_reg =
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);

dtype* base_ptr;
uint32_t surface_width;
uint32_t surface_height;
uint32_t surface_pitch;
Expand All @@ -105,7 +126,7 @@ struct mem_payload_t<
}

inline mem_payload_t(mem_desc_t& mem_desc) {
this->base_ptr = (mem_dtype*)mem_desc.base.base;
this->base_ptr = (dtype*)mem_desc.base.base;
this->surface_width =
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
this->surface_height =
Expand All @@ -130,7 +151,7 @@ struct mem_payload_t<
uint32_t surface_pitch,
int32_t surface_offset_x = 0,
int32_t surface_offset_y = 0) {
this->base_ptr = (mem_dtype*)p;
this->base_ptr = p;
this->surface_width = surface_width * sizeof(dtype);
this->surface_height = surface_height;
this->surface_pitch = surface_pitch * sizeof(dtype);
Expand All @@ -151,7 +172,7 @@ struct mem_payload_t<
}

__XETLA_API void init(mem_desc_t& mem_desc) {
this->base_ptr = (mem_dtype*)mem_desc.base.base;
this->base_ptr = (dtype*)mem_desc.base.base;
this->surface_width =
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
this->surface_height =
Expand Down Expand Up @@ -184,7 +205,7 @@ struct mem_payload_t<
uint32_t surface_pitch,
int32_t surface_offset_x = 0,
int32_t surface_offset_y = 0) {
this->base_ptr = (mem_dtype*)p;
this->base_ptr = p;
this->surface_width = surface_width * sizeof(dtype);
this->surface_height = surface_height;
this->surface_pitch = surface_pitch * sizeof(dtype);
Expand Down
Loading

0 comments on commit 468c1ea

Please sign in to comment.