Skip to content

Commit

Permalink
[XLA]
Browse files Browse the repository at this point in the history
No functional change: Refactoring cost analysis for fusions.

PiperOrigin-RevId: 586412178
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Nov 29, 2023
1 parent 17d9572 commit 9e172db
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 41 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3975,6 +3975,7 @@ cc_library(
hdrs = ["hlo_cost_analysis.h"],
deps = [
"//xla:shape_util",
"//xla:status",
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
Expand Down
115 changes: 77 additions & 38 deletions xla/service/hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/window_util.h"
Expand Down Expand Up @@ -222,7 +223,11 @@ Status HloCostAnalysis::FusionCalculateUtilizations(
// instruction.
for (const HloInstruction* instr :
fusion->fused_instructions_computation()->instructions()) {
hlo_properties_[instr][kUtilizationKey] = 1.f;
if (ShouldFilterFusionInstruction(fusion, instr)) {
hlo_properties_[instr][kUtilizationKey] = 0.f;
} else {
hlo_properties_[instr][kUtilizationKey] = 1.f;
}
}
return OkStatus();
}
Expand Down Expand Up @@ -1009,28 +1014,11 @@ Status HloCostAnalysis::HandleRngGetAndUpdateState(
return OkStatus();
}

Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
VLOG(8) << "Processing fusion " << fusion->ToString();

if (fusion->IsCustomFusion()) {
for (const HloInstruction* hlo :
fusion->fused_instructions_computation()->instructions()) {
if (hlo->opcode() == HloOpcode::kGather) {
return HandleGather(hlo);
}
if (hlo->opcode() == HloOpcode::kScatter) {
return HandleScatter(hlo);
}
}
}
TF_ASSIGN_OR_RETURN(
current_properties_,
ProcessSubcomputation(fusion->fused_instructions_computation()));

Status HloCostAnalysis::FusionProcessOutputBytesAccessed(
const HloInstruction* fusion) {
// Fusion nodes that produce a tuple also produce the entries in the tuple.
// Ignore the memory accessed inside fused ops, since fusion is supposed to
// prevent intermediate data from touching slow memory.
current_properties_[kBytesAccessedKey] = 0;
ShapeUtil::ForEachSubshape(
fusion->shape(),
[this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
Expand All @@ -1039,7 +1027,17 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
}

const HloInstruction* root = fusion->fused_expression_root();
if (shape_index.size() == 1 && root->opcode() == HloOpcode::kTuple) {

auto further_examine_index =
shape_index.size() == 1 && root->opcode() == HloOpcode::kTuple;
if (further_examine_index &&
ShouldFilterFusionOutputIndex(fusion, shape_index)) {
current_properties_.set_output_bytes_accessed(shape_index, 0);
hlo_properties_[root->operand(shape_index[0])]
[GetOperandUtilizationKey(0)] = 0;
return;
}
if (further_examine_index) {
root = root->operand(shape_index[0]);
}

Expand Down Expand Up @@ -1072,6 +1070,9 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
}
for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
if (!subshape.IsTuple() && ShouldFilterFusionOutputIndex(fusion, {i})) {
continue;
}
ShapeIndex subshape_index(shape_index);
subshape_index.push_back(i);
bytes_accessed +=
Expand All @@ -1082,27 +1083,20 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
current_properties_[GetOutputBytesAccessedKey()] = 0;
propagate_output_size_to_parent(fusion->shape(), {});
}
return OkStatus();
}

TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion));

// Count memory access to all large constants.
for (const HloInstruction* instr :
fusion->fused_instructions_computation()->instructions()) {
if (instr->opcode() == HloOpcode::kConstant &&
ShapeUtil::ElementsIn(instr->shape()) >
immediate_constant_max_elements()) {
float utilization = hlo_properties_[instr][kUtilizationKey];
if (!options_.count_multiple_input_accesses) {
utilization = fmin(utilization, 1.0);
}
current_properties_[kBytesAccessedKey] +=
GetShapeSize(instr->shape()) * utilization;
}
}

Status HloCostAnalysis::FusionProcessOperandBytesRead(
const HloInstruction* fusion) {
for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) {
const HloInstruction* operand = fusion->fused_parameter(i);
int64_t operand_size = 0;
if (ShouldFilterFusionInput(fusion, i)) {
current_properties_.set_operand_bytes_accessed(i, operand_size);
current_properties_.set_operand_utilization(
i, hlo_properties_[operand][kUtilizationKey]);
continue;
}
if (!operand->shape().IsTuple()) {
operand_size = FusionParameterReadBytes(operand);
} else {
Expand Down Expand Up @@ -1131,6 +1125,51 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
current_properties_.set_operand_utilization(
i, hlo_properties_[operand][kUtilizationKey]);
}
return OkStatus();
}

Status HloCostAnalysis::FusionCountConstantsMemoryAccess(
    const HloInstruction* fusion) {
  // Adds the size of every large fused constant, scaled by its utilization, to
  // the bytes-accessed total of the fusion currently being processed.
  for (const HloInstruction* fused_instr :
       fusion->fused_instructions_computation()->instructions()) {
    if (fused_instr->opcode() != HloOpcode::kConstant) {
      continue;
    }
    // Only constants with more elements than the immediate threshold are
    // counted.
    const bool is_large_constant =
        ShapeUtil::ElementsIn(fused_instr->shape()) >
        immediate_constant_max_elements();
    if (!is_large_constant) {
      continue;
    }

    float constant_utilization = hlo_properties_[fused_instr][kUtilizationKey];
    // Unless multiple input accesses are being counted, a constant is charged
    // at most once.
    if (!options_.count_multiple_input_accesses) {
      constant_utilization = fmin(constant_utilization, 1.0);
    }
    current_properties_[kBytesAccessedKey] +=
        GetShapeSize(fused_instr->shape()) * constant_utilization;
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  VLOG(8) << "Processing fusion " << fusion->ToString();

  // A custom fusion is costed as the first gather or scatter encountered among
  // its fused instructions. If it contains neither, it falls through to the
  // generic fusion path below.
  if (fusion->IsCustomFusion()) {
    for (const HloInstruction* fused_instr :
         fusion->fused_instructions_computation()->instructions()) {
      switch (fused_instr->opcode()) {
        case HloOpcode::kGather:
          return HandleGather(fused_instr);
        case HloOpcode::kScatter:
          return HandleScatter(fused_instr);
        default:
          break;
      }
    }
  }

  // Start from the properties computed for the fused computation.
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(fusion->fused_instructions_computation()));

  // Discard the bytes accessed reported by the fused computation and rebuild
  // the total from the fusion's outputs, large constants, and operands.
  current_properties_[kBytesAccessedKey] = 0;
  TF_RETURN_IF_ERROR(FusionProcessOutputBytesAccessed(fusion));
  TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion));
  TF_RETURN_IF_ERROR(FusionCountConstantsMemoryAccess(fusion));
  TF_RETURN_IF_ERROR(FusionProcessOperandBytesRead(fusion));

  return OkStatus();
}
Expand Down
39 changes: 36 additions & 3 deletions xla/service/hlo_cost_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_SERVICE_HLO_COST_ANALYSIS_H_
#define XLA_SERVICE_HLO_COST_ANALYSIS_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -247,7 +248,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
// props[kFlopsKey] gets optimized to `return flops_` just fine.

// Getters/setters for more complex properties like operand utilization,
// where we have a fastpath for e.g. operand 0/1 + shape_index {}.
// where we have a fastpath, e.g., operand 0/1 + shape_index {}.
float operand_utilization(int64_t operand,
const ShapeIndex& shape_index = {}) {
if (operand == 0 && shape_index.empty()) {
Expand Down Expand Up @@ -571,6 +572,37 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
const DotDimensionNumbers& dnums);

protected:
// Computes the bytes accessed based on the outputs produced by the fusion
// instruction.
virtual Status FusionProcessOutputBytesAccessed(const HloInstruction* fusion);

// Computes the bytes accessed (read) based on the inputs consumed by the
// fusion instruction.
virtual Status FusionProcessOperandBytesRead(const HloInstruction* fusion);

// Adds to the bytes accessed the memory accesses of all large constants (those
// with more elements than immediate_constant_max_elements()) in the fusion
// instruction.
virtual Status FusionCountConstantsMemoryAccess(const HloInstruction* fusion);

// Hook that lets subclasses exclude the fusion operand at `input_index` from
// the bytes read computed by FusionProcessOperandBytesRead. When this returns
// true, the operand is recorded with zero bytes accessed; its utilization is
// still recorded. The default implementation filters nothing.
virtual bool ShouldFilterFusionInput(const HloInstruction* fusion,
int64_t input_index) {
return false;
}

// Hook that lets subclasses exclude a fused `instruction` in
// FusionCalculateUtilizations. When this returns true, the instruction's
// utilization is set to 0 instead of 1. The default implementation filters
// nothing.
virtual bool ShouldFilterFusionInstruction(
const HloInstruction* fusion, const HloInstruction* instruction) {
return false;
}

// Hook that lets subclasses exclude the fusion output at `output_index` from
// the bytes written computed by FusionProcessOutputBytesAccessed. When this
// returns true for an element of a tuple-rooted fusion, the output bytes
// accessed at that index are recorded as 0 and the element is left out of the
// tuple's summed output size. The default implementation filters nothing.
virtual bool ShouldFilterFusionOutputIndex(const HloInstruction* fusion,
const ShapeIndex& output_index) {
return false;
}

typedef absl::flat_hash_map<const HloInstruction*, Properties>
HloToProperties;

Expand All @@ -588,7 +620,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
// given hlo. The cost of visited sub HLO instructions is saved to
// hlo_properties_, which will be used by functions such as
// flop_count(hlo_instruction) to return cost of a particular HLO instruction.
StatusOr<Properties> ProcessSubcomputation(HloComputation* computation);
virtual StatusOr<Properties> ProcessSubcomputation(
HloComputation* computation);

// Utility function to handle all element-wise operations.
Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
Expand All @@ -615,7 +648,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
// bottleneck.
bool current_should_compute_bottleneck_time_;

// The properties of the currently visited instruction. A HandleFoo method can
// The properties of the currently visited instruction. A HandleFoo method
// can modify these to change the default values computed in Preprocess.
Properties current_properties_;

Expand Down

0 comments on commit 9e172db

Please sign in to comment.