Rollback of [XLA:GPU] Allow reduction users in multi-output fusions with buffer aliasing (FusionCanShareBufferHint)

Reverts fbed9b7

PiperOrigin-RevId: 571017871
tensorflower-gardener authored and copybara-github committed Oct 5, 2023
1 parent 3a34150 commit e73af42
Showing 3 changed files with 13 additions and 72 deletions.
xla/service/gpu/BUILD: 3 changes (0 additions, 3 deletions)

@@ -3428,12 +3428,9 @@ cc_library(
     deps = [
         ":backend_configs_cc",
         ":cublas_cudnn",
-        ":hlo_fusion_analysis",
         ":ir_emission_utils",
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
-        "//xla/stream_executor:device_description",
-        "//xla/stream_executor:device_description_proto_cc",
         "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
xla/service/gpu/buffer_sharing.cc: 42 changes (10 additions, 32 deletions)

@@ -21,27 +21,21 @@ limitations under the License.
 #include <utility>
 
 #include "absl/container/flat_hash_set.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/cublas_cudnn.h"
-#include "xla/service/gpu/hlo_fusion_analysis.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/device_description.pb.h"
 
 namespace xla {
 namespace gpu {
 
 std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
                                              const HloInstruction* operand,
                                              const ShapeIndex& user_index) {
-  const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(user);
-  if (fusion == nullptr) {
+  if (user->opcode() != HloOpcode::kFusion) {
     return std::nullopt;
   }
 
@@ -71,18 +65,10 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
     }
   }
 
-  // Allow multiple output users, if they end in reductions.
-  // This only works for the reduction emitter, as it calculates the reduction
-  // first, i.e. before processing other outputs (that may overwrite the input).
-  stream_executor::GpuDeviceInfoProto device_info;
-  stream_executor::DeviceDescription device_description(device_info);
-  auto analysis = HloFusionAnalysis::Create(fusion, &device_description);
-  bool is_reduction_emitter = analysis->GetEmitterFusionKind() ==
-                              HloFusionAnalysis::EmitterFusionKind::kReduction;
-
   // We need to make sure that the fusion parameter is accessed in the same
-  // iteration order as the fusion output. Also, there should not be any other
-  // fusion output that accesses it in a different iteration order. To make sure
+  // iteration order as the fusion output. Also, there should not be two fusion
+  // outputs that consume the fusion parameter, because we do not want to share
+  // the same fusion operand with two different fusion outputs. To make sure
   // that the iteration order is the same, we only allow ops on the path from
   // fusion parameter to fusion output which are elementwise (no copy) or
   // bitcast or an elementwise dynamic update slice (i.e. with the first operand
@@ -102,17 +88,17 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
   q.push(fusion_param);
   visited.insert(fusion_param);
   bool found_path_to_output = false;
-  int reached_root = 0;
   while (!q.empty()) {
     HloInstruction* hlo_operand = q.front();
     q.pop();
     if (hlo_operand == output) {
       found_path_to_output = true;
-      // We still need to process the users of 'hlo_operand'. There can be other
-      // reduction users in addition to the tuple user.
-      if (hlo_operand->user_count() > 1 && !is_reduction_emitter) {
+      // The output should have at most 1 user: the tuple op (in case of a
+      // multi-output fusion)
+      if (hlo_operand->user_count() > 1) {
         return false;
       }
+      continue;
     }
     for (HloInstruction* hlo : hlo_operand->users()) {
       if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
@@ -136,15 +122,10 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
             }
           }
         }
-      } else if (hlo->opcode() == HloOpcode::kReduce && is_reduction_emitter) {
-        // Reduction emitter processes the reduction first, so the values below
-        // it will not interfere with buffer sharing.
-        continue;
       } else if ((!hlo->IsElementwiseOnOperand(
                       hlo->operand_index(hlo_operand)) ||
                   hlo->opcode() == HloOpcode::kCopy) &&
-                 hlo->opcode() != HloOpcode::kBitcast &&
-                 hlo->opcode() != HloOpcode::kTuple) {
+                 hlo->opcode() != HloOpcode::kBitcast) {
         // This check also catches the case that we reach a different fusion
         // output, as that fusion output would have a tuple op as user, which we
         // do not allow here.
@@ -165,11 +146,8 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
         q.push(hlo);
       }
     }
-    if (hlo_operand->IsRoot()) {
-      ++reached_root;
-    }
   }
-  return found_path_to_output && reached_root == 1;
+  return found_path_to_output;
 }
 
 std::optional<bool> CanShareBufferHint(const HloInstruction* user,
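
As a usage sketch (not part of this commit): the hint above is meant to be consulted by buffer-assignment / copy-insertion style code, with std::nullopt meaning "no opinion". Assuming the declaration lives in a buffer_sharing.h header alongside the .cc shown here, and using the same call pattern as the tests in the next file, a caller might look like this; the helper name is illustrative only.

#include <cstdint>
#include <optional>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/buffer_sharing.h"
#include "xla/shape_util.h"

namespace xla {
namespace gpu {

// Illustrative helper, not from this commit: may output `user_index` of
// `fusion` reuse the buffer of operand `operand_index`?
bool MayReuseOperandBuffer(const HloInstruction* fusion, int64_t operand_index,
                           const ShapeIndex& user_index) {
  std::optional<bool> hint = FusionCanShareBufferHint(
      fusion, fusion->operand(operand_index), user_index);
  // std::nullopt: the hint has no opinion (e.g. `fusion` is not a fusion
  // instruction); a conservative caller treats that as "do not share".
  return hint.value_or(false);
}

}  // namespace gpu
}  // namespace xla
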
xla/service/gpu/gpu_copy_insertion_test.cc: 40 changes (3 additions, 37 deletions)

@@ -201,14 +201,13 @@ fused_computation {
   param_1.1 = f32[2,3]{1,0} parameter(1)
   neg = f32[2,3]{1,0} negate(param_1.1)
   mul = f32[2,3]{1,0} multiply(param_0.1, neg)
-  transpose = f32[3,2]{1,0} transpose(neg), dimensions={1,0}
-  ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) tuple(mul, neg, transpose)
+  ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(mul, neg)
 }
 
 ENTRY main {
   param_0 = f32[2,3]{1,0} parameter(0)
   param_1 = f32[2,3]{1,0} parameter(1)
-  ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation
+  ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation
 }
 )";
 
@@ -217,7 +216,7 @@
   HloInstruction* fusion = module->entry_computation()->root_instruction();
   ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {0}));
   // The second operand cannot share the buffer with the second fusion output,
-  // because the 'neg' op is also used by a non-elementwise op.
+  // because the 'neg' op is also used on the path to the first fusion output.
   ExpectOptionalFalse(
       FusionCanShareBufferHint(fusion, fusion->operand(1), {1}));
   // The first operand cannot share the buffer with the second fusion output,
@@ -226,39 +225,6 @@
   ExpectOptionalFalse(
       FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
 }
 
-TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedReductionEmitter) {
-  constexpr char kModuleString[] = R"(
-HloModule TestModule
-
-%maximum {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %res = f32[] maximum(%lhs, %rhs)
-}
-
-%fused_computation {
-  %lhs = f32[3,40] parameter(0)
-  %rhs = f32[3,40] parameter(1)
-  %add = f32[3,40] add(%lhs, %rhs)
-  %bc = f32[120] bitcast(%add)
-  %init = f32[] constant(-inf)
-  %max = f32[] reduce(%bc, %init), dimensions={0}, to_apply=%maximum
-  ROOT %result = (f32[], f32[3,40]) tuple(%max, %add)
-}
-
-ENTRY %main {
-  %lhs = f32[3,40] parameter(0)
-  %rhs = f32[3,40] parameter(1)
-  ROOT %fusion = (f32[], f32[3,40]) fusion(%lhs, %rhs),
-    kind=kLoop, calls=%fused_computation
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
-                          ParseAndReturnVerifiedModule(kModuleString));
-  HloInstruction* fusion = module->entry_computation()->root_instruction();
-  ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
-}
-
 TEST_F(FusionCanShareBufferHintTest,
        BufferCannotBeSharedConvertedShapeDifferentByteWidth) {
   const char* const kModuleString = R"(
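
A note on the deleted test: under the restored check that a fusion output may have at most one user, the hint should no longer allow operand 0 to alias output {1} for that module, because %add is consumed both by the bitcast feeding the reduce and by the root tuple. A hypothetical expectation illustrating the post-rollback behavior (not asserted anywhere in this commit; kModuleString here refers to the HLO string of the deleted test, and the fixture helpers are those of gpu_copy_insertion_test.cc):

  // Reusing the HLO from the deleted BufferCanBeSharedReductionEmitter test,
  // the restored traversal bails out once it sees that %add has two users.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleString));
  HloInstruction* fusion = module->entry_computation()->root_instruction();
  ExpectOptionalFalse(
      FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
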
