Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Pass foffload-fp32-prec-[div/sqrt] options to device's BE #16107

Draft
wants to merge 2 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion clang/lib/Driver/ToolChains/SYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1950,9 +1950,18 @@ void SYCLToolChain::AddImpliedTargetArgs(const llvm::Triple &Triple,
if (Args.hasFlag(options::OPT_ftarget_export_symbols,
options::OPT_fno_target_export_symbols, false))
BeArgs.push_back("-library-compilation");
} else if (IsJIT)
// -foffload-fp32-prec-[sqrt/div]
Copy link
Contributor Author

@MrSidims MrSidims Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mdtoguchi please take a look at these few lines, to check if I have correctly figured out what SYCL.cpp does to pass options from FE to BE.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like your intention is to pass -options -ze-fp32-correctly-rounded-device-sqrt to ocloc and for JIT, pass the respective -foffload-fp32-prec* option to be embedded in the compile options when the JIT binary is wrapped.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also - please add a driver test that checks these behaviors.

if (Args.hasArg(options::OPT_foffload_fp32_prec_div) ||
Args.hasArg(options::OPT_foffload_fp32_prec_sqrt))
BeArgs.push_back("-ze-fp32-correctly-rounded-divide-sqrt");
} else if (IsJIT) {
// -ftarget-compile-fast JIT
Args.AddLastArg(BeArgs, options::OPT_ftarget_compile_fast);
// -foffload-fp32-prec-div JIT
Args.AddLastArg(BeArgs, options::OPT_foffload_fp32_prec_div);
// -foffload-fp32-prec-sqrt JIT
Args.AddLastArg(BeArgs, options::OPT_OPT_foffload_fp32_prec_sqrt);
}
if (IsGen) {
for (auto [DeviceName, BackendArgStr] : PerDeviceArgs) {
CmdArgs.push_back("-device_options");
Expand Down
34 changes: 34 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/SYCLSqrtFDivMaxErrorCleanUp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===-- SYCLSqrtFDivMaxErrorCleanUp.h - SYCLSqrtFDivMaxErrorCleanUp Pass --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Remove llvm.fpbuiltin.[sqrt/fdiv] intrinsics to ensure compatibility with the
// old drivers (that don't support SPV_INTEL_fp_max_error extension).
// The intrinsic functions are removed in case if they are used with standard
// for OpenCL max-error (e.g [3.0/2.5] ULP) and there are no:
// - other llvm.fpbuiltin.* intrinsic functions;
// - fdiv instructions
// - @sqrt builtins (both C and C++-styles)/llvm intrinsic in the module.
//===----------------------------------------------------------------------===//
#ifndef LLVM_SYCL_SQRT_FDIV_MAX_ERROR_CLEAN_UP_H
#define LLVM_SYCL_SQRT_FDIV_MAX_ERROR_CLEAN_UP_H

#include "llvm/IR/PassManager.h"

namespace llvm {

// FIXME: remove this pass, it's not really needed.
class SYCLSqrtFDivMaxErrorCleanUpPass
: public PassInfoMixin<SYCLSqrtFDivMaxErrorCleanUpPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);

static bool isRequired() { return true; }
};

} // namespace llvm

#endif // LLVM_SYCL_SQRT_FDIV_MAX_ERROR_CLEAN_UP_H
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
#include "llvm/SYCLLowerIR/SYCLJointMatrixTransform.h"
#include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h"
#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h"
#include "llvm/SYCLLowerIR/SYCLSqrtFDivMaxErrorCleanUp.h"
#include "llvm/SYCLLowerIR/SYCLVirtualFunctionsAnalysis.h"
#include "llvm/SYCLLowerIR/SpecConstants.h"
#include "llvm/Support/CommandLine.h"
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ MODULE_PASS("esimd-remove-host-code", ESIMDRemoveHostCodePass());
MODULE_PASS("esimd-remove-optnone-noinline", ESIMDRemoveOptnoneNoinlinePass());
MODULE_PASS("sycl-conditional-call-on-device", SYCLConditionalCallOnDevicePass())
MODULE_PASS("sycl-joint-matrix-transform", SYCLJointMatrixTransformPass())
MODULE_PASS("sycl-sqrt-fdiv-max-error-clean-up", SYCLSqrtFDivMaxErrorCleanUpPass())
MODULE_PASS("sycl-propagate-aspects-usage", SYCLPropagateAspectsUsagePass())
MODULE_PASS("sycl-propagate-joint-matrix-usage", SYCLPropagateJointMatrixUsagePass())
MODULE_PASS("sycl-add-opt-level-attribute", SYCLAddOptLevelAttributePass())
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
SYCLJointMatrixTransform.cpp
SYCLPropagateAspectsUsage.cpp
SYCLPropagateJointMatrixUsage.cpp
SYCLSqrtFDivMaxErrorCleanUp.cpp
SYCLVirtualFunctionsAnalysis.cpp
SYCLUtils.cpp
SanitizeDeviceGlobal.cpp
Expand Down
160 changes: 160 additions & 0 deletions llvm/lib/SYCLLowerIR/SYCLSqrtFDivMaxErrorCleanUp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
//===- SYCLSqrtFDivMaxErrorCleanUp.cpp - SYCLSqrtFDivMaxErrorCleanUp Pass -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Remove llvm.fpbuiltin.[sqrt/fdiv] intrinsics to ensure compatibility with the
Copy link
Contributor Author

@MrSidims MrSidims Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gmlueck without the deep dive into the pass, may I ask you to check if the logic of the pass described in the comment makes sense to you? Note, I'm not adding annotation of the kernels with some optional kernel feature metadata, that could help discarding 'precise' options from the list of the BE options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason is: currently we either have the intrinsics in the module or don't have them at all. And when we have the - non-precise option was already passed, so there is nothing to rewrite for BE options.

// old drivers (that don't support SPV_INTEL_fp_max_error extension).
// The intrinsic functions are removed in case if they are used with standard
// for OpenCL max-error (e.g [3.0/2.5] ULP) and there are no:
// - other llvm.fpbuiltin.* intrinsic functions;
// - fdiv instructions
// - @sqrt builtins (both C and C++-styles)/llvm intrinsic in the module.
//===----------------------------------------------------------------------===//

#include "llvm/SYCLLowerIR/SYCLSqrtFDivMaxErrorCleanUp.h"

#include "llvm/ADT/SmallSet.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

namespace {
static constexpr char SQRT_ERROR[] = "3.0";
static constexpr char FDIV_ERROR[] = "2.5";
} // namespace

PreservedAnalyses
SYCLSqrtFDivMaxErrorCleanUpPass::run(Module &M,
ModuleAnalysisManager &MAM) {
SmallVector<IntrinsicInst *, 16> WorkListSqrt;
SmallVector<IntrinsicInst *, 16> WorkListFDiv;

// Add all llvm.fpbuiltin.sqrt with 3.0 error and llvm.fpbuiltin.fdiv with
// 2.5 error to the work list to remove them later. If attributes with other
// values or other llvm.fpbuiltin.* intrinsic functions found - abort the
// pass.
for (auto &F : M) {
if (!F.isDeclaration())
continue;
const auto ID = F.getIntrinsicID();
if (ID != llvm::Intrinsic::fpbuiltin_sqrt &&
ID != llvm::Intrinsic::fpbuiltin_fdiv)
continue;

for (auto *Use : F.users()) {
auto *II = cast<IntrinsicInst>(Use);
if (II && II->getCalledFunction()->getName().
starts_with("llvm.fpbuiltin")) {
// llvm.fpbuiltin.* intrinsics should always have fpbuiltin-max-error
// attribute, but it's not a concern of the pass, so just do an early
// exit here if the attribute is not attached.
if (!II->getAttributes().hasFnAttr("fpbuiltin-max-error"))
return PreservedAnalyses::none();
StringRef MaxError = II->getAttributes().getFnAttr(
"fpbuiltin-max-error").getValueAsString();

if (ID == llvm::Intrinsic::fpbuiltin_sqrt) {
if (MaxError != SQRT_ERROR)
return PreservedAnalyses::none();
WorkListSqrt.push_back(II);
}
else if (ID == llvm::Intrinsic::fpbuiltin_fdiv) {
if (MaxError != FDIV_ERROR)
return PreservedAnalyses::none();
WorkListFDiv.push_back(II);
} else {
// Another llvm.fpbuiltin.* intrinsic was found - the module is
// already not backward compatible.
return PreservedAnalyses::none();
}
}
}
}

// No intrinsics at all - do an early exist.
if (WorkListSqrt.empty() && WorkListFDiv.empty())
return PreservedAnalyses::none();

// If @sqrt, @_Z4sqrt*, @llvm.sqrt. or fdiv present in the module - do
// nothing.
for (auto &F : M) {
if (F.isDeclaration())
continue;
for (auto &BB : F) {
for (auto &II : BB) {
if (auto *CI = dyn_cast<CallInst>(&II)) {
auto *SqrtF = CI->getCalledFunction();
if (SqrtF->getName() == "sqrt" ||
SqrtF->getName().starts_with("_Z4sqrt") ||
SqrtF->getIntrinsicID() == llvm::Intrinsic::sqrt)
return PreservedAnalyses::none();
}
if (auto *FPI = dyn_cast<FPMathOperator>(&II)) {
auto Opcode = FPI->getOpcode();
if (Opcode == Instruction::FDiv)
return PreservedAnalyses::none();
}
}
}
}

// Replace @llvm.fpbuiltin.sqrt call with @llvm.sqrt. llvm-spirv will handle
// it later.
SmallSet<Function *, 2> DeclToRemove;
for (auto *Sqrt : WorkListSqrt) {
DeclToRemove.insert(Sqrt->getCalledFunction());
IRBuilder Builder(Sqrt);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To move outside the loop

Builder.SetInsertPoint(Sqrt);
Type *Ty = Sqrt->getType();
AttributeList Attrs = Sqrt->getAttributes();
Function *NewSqrtF =
Intrinsic::getDeclaration(&M, llvm::Intrinsic::sqrt, Ty);
auto *NewSqrt = Builder.CreateCall(NewSqrtF, { Sqrt->getOperand(0) },
Sqrt->getName());

// Copy FP flags, metadata and attributes. Replace old call with a new call.
Attrs = Attrs.removeFnAttribute(Sqrt->getContext(), "fpbuiltin-max-error");
NewSqrt->setAttributes(Attrs);
NewSqrt->copyMetadata(*Sqrt);
FPMathOperator *FPOp = cast<FPMathOperator>(Sqrt);
FastMathFlags FMF = FPOp->getFastMathFlags();
NewSqrt->setFastMathFlags(FMF);
Sqrt->replaceAllUsesWith(NewSqrt);
Sqrt->dropAllReferences();
Sqrt->eraseFromParent();
}

// Replace @llvm.fpbuiltin.fdiv call with fdiv.
for (auto *FDiv : WorkListFDiv) {
DeclToRemove.insert(FDiv->getCalledFunction());
IRBuilder Builder(FDiv);
Builder.SetInsertPoint(FDiv);
Instruction *NewFDiv =
cast<Instruction>(Builder.CreateFDiv(
FDiv->getOperand(0), FDiv->getOperand(1), FDiv->getName()));

// Copy FP flags and metadata. Replace old call with a new instruction.
cast<Instruction>(NewFDiv)->copyMetadata(*FDiv);
FPMathOperator *FPOp = cast<FPMathOperator>(FDiv);
FastMathFlags FMF = FPOp->getFastMathFlags();
NewFDiv->setFastMathFlags(FMF);
FDiv->replaceAllUsesWith(NewFDiv);
FDiv->dropAllReferences();
FDiv->eraseFromParent();
}

// Clear old declarations.
for (auto *Decl : DeclToRemove) {
assert(Decl->isDeclaration() &&
"attempting to remove a function definition");
Decl->dropAllReferences();
Decl->eraseFromParent();
}

return PreservedAnalyses::all();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt are removed from
; the module.

; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s

; CHECK-NOT: llvm.fpbuiltin.fdiv.f32
; CHECK-NOT: llvm.fpbuiltin.sqrt.f32
; CHECK-NOT: fpbuiltin-max-error

; CHECK: test_fp_max_error_decoration(float [[F1:[%0-9a-z.]+]], float [[F2:[%0-9a-z.]+]])
; CHECK: [[V1:[%0-9a-z.]+]] = fdiv float [[F1]], [[F2]]
; CHECK: call float @llvm.sqrt.f32(float [[V1]])

; CHECK: test_fp_max_error_decoration_fast(float [[F1:[%0-9a-z.]+]], float [[F2:[%0-9a-z.]+]])
; CHECK: [[V1:[%0-9a-z.]+]] = fdiv fast float [[F1]], [[F2]]
; CHECK: call fast float @llvm.sqrt.f32(float [[V1]])

; CHECK: test_fp_max_error_decoration_debug(float [[F1:[%0-9a-z.]+]], float [[F2:[%0-9a-z.]+]])
; CHECK: [[V1:[%0-9a-z.]+]] = fdiv float [[F1]], [[F2]], !dbg ![[#Loc1:]]
; CHECK: call float @llvm.sqrt.f32(float [[V1]]), !dbg ![[#Loc2:]]

; CHECK: [[#Loc1]] = !DILocation(line: 1, column: 1, scope: ![[#]])
; CHECK: [[#Loc2]] = !DILocation(line: 2, column: 1, scope: ![[#]])

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define void @test_fp_max_error_decoration(float %f1, float %f2) {
entry:
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
ret void
}

define void @test_fp_max_error_decoration_fast(float %f1, float %f2) {
entry:
%v1 = call fast float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
%v2 = call fast float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
ret void
}

define void @test_fp_max_error_decoration_debug(float %f1, float %f2) {
entry:
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0, !dbg !7
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1, !dbg !8
ret void
}

declare float @llvm.fpbuiltin.fdiv.f32(float, float)

declare float @llvm.fpbuiltin.sqrt.f32(float)

attributes #0 = { "fpbuiltin-max-error"="2.5" }
attributes #1 = { "fpbuiltin-max-error"="3.0" }

!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!9}

!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, nameTableKind: None)
!1 = !DIFile(filename: "test.c", directory: "/tmp", checksumkind: CSK_MD5, checksum: "2a034da6937f5b9cf6dd2d89127f57fd")
!2 = distinct !DISubprogram(name: "test_fp_max_error_decoration_debug", scope: !1, file: !1, line: 1, type: !3, scopeLine: 2, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition, unit: !0)
!3 = !DISubroutineType(types: !4)
!4 = !{!5, !6, !6}
!5 = !DIBasicType(name: "int", size: 32, encoding: DW_ATE_signed)
!6 = !DIBasicType(name: "float", size: 32, encoding: DW_ATE_float)
!7 = !DILocation(line: 1, column: 1, scope: !2)
!8 = !DILocation(line: 2, column: 1, scope: !2)
!9 = !{i32 2, !"Debug Info Version", i32 3}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt remain if
; non-standart for OpenCL max-error is used.

; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s

; CHECK: llvm.fpbuiltin.fdiv.f32
; CHECK: llvm.fpbuiltin.sqrt.f32
; CHECK: fpbuiltin-max-error

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define void @test_fp_max_error_decoration(float %f1, float %f2) {
entry:
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
ret void
}

declare float @llvm.fpbuiltin.fdiv.f32(float, float)

declare float @llvm.fpbuiltin.sqrt.f32(float)

attributes #0 = { "fpbuiltin-max-error"="2.0" }
attributes #1 = { "fpbuiltin-max-error"="3.0" }
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt remain if
; fdiv instruction was in the module.

; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s

; CHECK: llvm.fpbuiltin.fdiv.f32
; CHECK: llvm.fpbuiltin.sqrt.f32
; CHECK: fpbuiltin-max-error

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define void @test_fp_max_error_decoration(float %f1, float %f2) {
entry:
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
%v3 = fdiv float %v2, %f2
ret void
}

declare float @llvm.fpbuiltin.fdiv.f32(float, float)

declare float @llvm.fpbuiltin.sqrt.f32(float)

attributes #0 = { "fpbuiltin-max-error"="2.0" }
attributes #1 = { "fpbuiltin-max-error"="3.0" }
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; Test checks if @llvm.fpbuiltin.fdiv and @llvm.fpbuiltin.sqrt remain if
; other fpbuiltin intrinsic is used in the module.

; RUN: opt -passes=sycl-sqrt-fdiv-max-error-clean-up < %s -S | FileCheck %s

; CHECK: llvm.fpbuiltin.fdiv.f32
; CHECK: llvm.fpbuiltin.sqrt.f32
; CHECK: fpbuiltin-max-error

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define void @test_fp_max_error_decoration(float %f1, float %f2) {
entry:
%v1 = call float @llvm.fpbuiltin.fdiv.f32(float %f1, float %f2) #0
%v2 = call float @llvm.fpbuiltin.sqrt.f32(float %v1) #1
%v3 = call float @llvm.fpbuiltin.exp.f32(float %v2)
ret void
}

declare float @llvm.fpbuiltin.fdiv.f32(float, float)

declare float @llvm.fpbuiltin.sqrt.f32(float)

declare float @llvm.fpbuiltin.exp.f32(float)

attributes #0 = { "fpbuiltin-max-error"="2.0" }
attributes #1 = { "fpbuiltin-max-error"="3.0" }
Loading
Loading