Skip to content

Commit

Permalink
Fix parallel split-k (#1116)
Browse files Browse the repository at this point in the history
  • Loading branch information
Manish Gupta authored and ttl10101 committed Feb 7, 2024
1 parent af0a5c8 commit f17b199
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions tools/library/src/handle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1168,15 +1168,30 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope
return nullptr;
}

// A and B use the same alignment in generator.py
int alignment = gemm_desc.A.alignment;
// gemm operation for same compute capability and max operand alignment
int alignment = std::max(
gemm_desc.A.alignment,
gemm_desc.B.alignment);

// gemm operation for same compute capability and iterator algorithm
GemmPreferenceKey preference_key(
gemm_desc.tile_description.minimum_compute_capability,
alignment);

return find_gemm_operation(operators_it, preference_key);
auto it = operators_it->second.find(preference_key);

if(it == operators_it->second.end()) {
return nullptr;
}

// return matching gemm operation (same tile shape, stages, warp count, and instruction)
for (auto op : it->second) {
if (op->description().tile_description == operation->description().tile_description) {
return op;
}
}

// return nullptr if no matching gemm operation found for parallel split-k reduction
return nullptr;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit f17b199

Please sign in to comment.