diff --git a/xla/service/BUILD b/xla/service/BUILD index 5d59d10dc6c25..81c966566adaa 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -3975,6 +3975,7 @@ cc_library( hdrs = ["hlo_cost_analysis.h"], deps = [ "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:util", diff --git a/xla/service/hlo_cost_analysis.cc b/xla/service/hlo_cost_analysis.cc index 8207046fa6388..9fe36cd2eb16e 100644 --- a/xla/service/hlo_cost_analysis.cc +++ b/xla/service/hlo_cost_analysis.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/window_util.h" @@ -222,7 +223,11 @@ Status HloCostAnalysis::FusionCalculateUtilizations( // instruction. for (const HloInstruction* instr : fusion->fused_instructions_computation()->instructions()) { - hlo_properties_[instr][kUtilizationKey] = 1.f; + if (ShouldFilterFusionInstruction(fusion, instr)) { + hlo_properties_[instr][kUtilizationKey] = 0.f; + } else { + hlo_properties_[instr][kUtilizationKey] = 1.f; + } } return OkStatus(); } @@ -1009,28 +1014,11 @@ Status HloCostAnalysis::HandleRngGetAndUpdateState( return OkStatus(); } -Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { - VLOG(8) << "Processing fusion " << fusion->ToString(); - - if (fusion->IsCustomFusion()) { - for (const HloInstruction* hlo : - fusion->fused_instructions_computation()->instructions()) { - if (hlo->opcode() == HloOpcode::kGather) { - return HandleGather(hlo); - } - if (hlo->opcode() == HloOpcode::kScatter) { - return HandleScatter(hlo); - } - } - } - TF_ASSIGN_OR_RETURN( - current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation())); - +Status HloCostAnalysis::FusionProcessOutputBytesAccessed( + const HloInstruction* fusion) { // Fusion nodes that produce a tuple also produce the entries in the tuple. 
// Ignore the memory accessed inside fused ops, since fusion is supposed to // prevent intermediate data from touching slow memory. - current_properties_[kBytesAccessedKey] = 0; ShapeUtil::ForEachSubshape( fusion->shape(), [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) { @@ -1039,7 +1027,17 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { } const HloInstruction* root = fusion->fused_expression_root(); - if (shape_index.size() == 1 && root->opcode() == HloOpcode::kTuple) { + + auto further_examine_index = + shape_index.size() == 1 && root->opcode() == HloOpcode::kTuple; + if (further_examine_index && + ShouldFilterFusionOutputIndex(fusion, shape_index)) { + current_properties_.set_output_bytes_accessed(shape_index, 0); + hlo_properties_[root->operand(shape_index[0])] + [GetOperandUtilizationKey(0)] = 0; + return; + } + if (further_examine_index) { root = root->operand(shape_index[0]); } @@ -1072,6 +1070,9 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { } for (int i = 0; i < shape.tuple_shapes_size(); ++i) { const Shape& subshape = shape.tuple_shapes(i); + if (!subshape.IsTuple() && ShouldFilterFusionOutputIndex(fusion, {i})) { + continue; + } ShapeIndex subshape_index(shape_index); subshape_index.push_back(i); bytes_accessed += @@ -1082,27 +1083,20 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_[GetOutputBytesAccessedKey()] = 0; propagate_output_size_to_parent(fusion->shape(), {}); } + return OkStatus(); +} - TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion)); - - // Count memory access to all large constants. 
- for (const HloInstruction* instr : - fusion->fused_instructions_computation()->instructions()) { - if (instr->opcode() == HloOpcode::kConstant && - ShapeUtil::ElementsIn(instr->shape()) > - immediate_constant_max_elements()) { - float utilization = hlo_properties_[instr][kUtilizationKey]; - if (!options_.count_multiple_input_accesses) { - utilization = fmin(utilization, 1.0); - } - current_properties_[kBytesAccessedKey] += - GetShapeSize(instr->shape()) * utilization; - } - } - +Status HloCostAnalysis::FusionProcessOperandBytesRead( + const HloInstruction* fusion) { for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) { const HloInstruction* operand = fusion->fused_parameter(i); int64_t operand_size = 0; + if (ShouldFilterFusionInput(fusion, i)) { + current_properties_.set_operand_bytes_accessed(i, operand_size); + current_properties_.set_operand_utilization( + i, hlo_properties_[operand][kUtilizationKey]); + continue; + } if (!operand->shape().IsTuple()) { operand_size = FusionParameterReadBytes(operand); } else { @@ -1131,6 +1125,51 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_.set_operand_utilization( i, hlo_properties_[operand][kUtilizationKey]); } + return OkStatus(); +} + +Status HloCostAnalysis::FusionCountConstantsMemoryAccess( + const HloInstruction* fusion) { + // Count memory access to all large constants. 
+ for (const HloInstruction* instr : + fusion->fused_instructions_computation()->instructions()) { + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::ElementsIn(instr->shape()) > + immediate_constant_max_elements()) { + float utilization = hlo_properties_[instr][kUtilizationKey]; + if (!options_.count_multiple_input_accesses) { + utilization = fmin(utilization, 1.0); + } + current_properties_[kBytesAccessedKey] += + GetShapeSize(instr->shape()) * utilization; + } + } + return OkStatus(); +} + +Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { + VLOG(8) << "Processing fusion " << fusion->ToString(); + + if (fusion->IsCustomFusion()) { + for (const HloInstruction* hlo : + fusion->fused_instructions_computation()->instructions()) { + if (hlo->opcode() == HloOpcode::kGather) { + return HandleGather(hlo); + } + if (hlo->opcode() == HloOpcode::kScatter) { + return HandleScatter(hlo); + } + } + } + TF_ASSIGN_OR_RETURN( + current_properties_, + ProcessSubcomputation(fusion->fused_instructions_computation())); + + current_properties_[kBytesAccessedKey] = 0; + TF_RETURN_IF_ERROR(FusionProcessOutputBytesAccessed(fusion)); + TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion)); + TF_RETURN_IF_ERROR(FusionCountConstantsMemoryAccess(fusion)); + TF_RETURN_IF_ERROR(FusionProcessOperandBytesRead(fusion)); return OkStatus(); } diff --git a/xla/service/hlo_cost_analysis.h b/xla/service/hlo_cost_analysis.h index 8305b0fadd215..705eb2437860f 100644 --- a/xla/service/hlo_cost_analysis.h +++ b/xla/service/hlo_cost_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_COST_ANALYSIS_H_ #define XLA_SERVICE_HLO_COST_ANALYSIS_H_ +#include #include #include #include @@ -247,7 +248,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // props[kFlopsKey] gets optimized to `return flops_` just fine. // Getters/setters for more complex properties like operand utilization, - // where we have a fastpath for e.g. 
operand 0/1 + shape_index {}. + // where we have a fastpath, e.g., operand 0/1 + shape_index {}. float operand_utilization(int64_t operand, const ShapeIndex& shape_index = {}) { if (operand == 0 && shape_index.empty()) { @@ -571,6 +572,37 @@ class HloCostAnalysis : public ConstDfsHloVisitor { const DotDimensionNumbers& dnums); protected: + // Computes the bytes accessed based on the outputs produced by the fusion + // instruction. + virtual Status FusionProcessOutputBytesAccessed(const HloInstruction* fusion); + + // Computes the bytes accessed (read) based on the inputs consumed by the + // fusion instruction. + virtual Status FusionProcessOperandBytesRead(const HloInstruction* fusion); + + // Computes memory access to all large constants in the fusion instruction. + virtual Status FusionCountConstantsMemoryAccess(const HloInstruction* fusion); + + // Allows exclusion of certain types of inputs from bytes accessed during + // FusionProcessOperandBytesRead. + virtual bool ShouldFilterFusionInput(const HloInstruction* fusion, + int64_t input_index) { + return false; + } + + // Allows exclusion of certain instructions from FusionCalculateUtilizations. + virtual bool ShouldFilterFusionInstruction( + const HloInstruction* fusion, const HloInstruction* instruction) { + return false; + } + + // Allows exclusion of certain types of output from bytes written during + // FusionProcessOutputBytesAccessed. + virtual bool ShouldFilterFusionOutputIndex(const HloInstruction* fusion, + const ShapeIndex& output_index) { + return false; + } + typedef absl::flat_hash_map HloToProperties; @@ -588,7 +620,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // given hlo. The cost of visited sub HLO instructions is saved to // hlo_properties_, which will be used by functions such as // flop_count(hlo_instruction) to return cost of a particular HLO instruction.
- StatusOr ProcessSubcomputation(HloComputation* computation); + virtual StatusOr ProcessSubcomputation( + HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); @@ -615,7 +648,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // bottleneck. bool current_should_compute_bottleneck_time_; - // The properties of the currently visited instruction. A HandleFoo method can + // The properties of the currently visited instruction. A HandleFoo method may // modify these to change the default values computed in Preprocess. Properties current_properties_;