Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
solaslin committed Aug 20, 2024
1 parent b92f9e9 commit 45ea7ad
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 64 deletions.
2 changes: 1 addition & 1 deletion tensilelite/Tensile/ClientWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def param(key, value):
if problemType.useE:
param('e-type', problemType.eType.toEnum())
if problemType.outputAmaxD:
param('amaxD-type', problemType.amaxDType.toEnum())
param('amaxD-type', problemType.alphaType.toEnum())
param('alpha-type', problemType.alphaType.toEnum())
param('beta-type', problemType.betaType.toEnum())
param('f32-xdl-math-op', problemType.f32XdlMathOp.toEnum())
Expand Down
1 change: 0 additions & 1 deletion tensilelite/Tensile/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,6 @@ def supportedCompiler(compiler: str) -> bool:
"DataTypeA": 0, # A data type can specified by a variety of ways, such as "s", as listed in SolutionStructs.py::DataType
"DataTypeB": 0, # B data type can specified by a variety of ways, such as "s", as listed in SolutionStructs.py::DataType
"DataTypeE": 0, # E data type can specified by a variety of ways, such as "s", as listed in SolutionStructs.py::DataType
"DataTypeAmaxD": 0, # AmaxD data type can specified by a variety of ways, such as "s", as listed in SolutionStructs.py::DataType
"DestDataType": 0, # destination data types can specified by a variety of ways, such as "s", as listed in SolutionStructs.py::DataType
"ComputeDataType": 0, # compute data types can specified by a variety of ways, such as "s", as listed in SolutionStructs.py::DataType
"F32XdlMathOp": 0, # reducing intermediate precision from f32 to a specific type, such as "x", as listed in SolutionStructs.py::DataType.
Expand Down
7 changes: 1 addition & 6 deletions tensilelite/Tensile/Contractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,6 @@ def FromOriginalState(cls, d):
else:
rv.eType = dstType

if 'DataTypeAmaxD' in d:
rv.amaxDType = DataType(d['DataTypeAmaxD'])
else:
rv.amaxDType = computeType

rv.computeInputType = srcType
rv.cType = dstType
rv.dType = dstType
Expand Down Expand Up @@ -348,7 +343,7 @@ def predicates(self, includeBatch=False, includeOperation=False, includeType=Fal
predicates.append(ProblemPredicate("BetaZero"))
predicates.append(ProblemPredicate("BiasDataTypeWhiteList", value=self.biasDataTypeWhiteList))
predicates.append(ProblemPredicate("BiasSrcWhiteList", value=self.biasSrcWhiteList))
predicates.append(ProblemPredicate("AmaxDCheck", value=self.outputAmaxD))
predicates.append(ProblemPredicate("AmaxDCheck"))
if self.activationType == 'all':
exportType = ActivationType.Export.GRADONLY if self.useGradient else ActivationType.Export.NORMAL
enumList = [actEnum.capitalize() for actEnum in ActivationType.getEnumStrList(self.activationComputeDataType, exportType=exportType)]
Expand Down
15 changes: 9 additions & 6 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -10627,6 +10627,9 @@ def amax_define_load_res(self) -> Module:
module.add(SWaitCnt(lgkmcnt=0))
module.addSpaceLine()

module.add(SCmpEQU64(sgpr("AddrAmaxOut", 2), hex(0), "amaxD == nullptr ?"))
module.add(SCBranchSCC1(self.label_amax_end.getLabelName(), "skip amaxD if nullptr"))

return module

def amax_intra_wave_reduction(self, kernel, postfix) -> Module:
Expand Down Expand Up @@ -10657,7 +10660,7 @@ def amax_intra_wave_reduction(self, kernel, postfix) -> Module:
def amax_inter_wave_reduction(self, kernel) -> Module:
wave_size = kernel["WavefrontSize"]
numWorkItems = kernel["NumThreads"]
amaxOutType = kernel["ProblemType"]["DataTypeAmaxD"]
amaxOutType = kernel["ProblemType"]["ComputeDataType"]
amax_lds_start = kernel["LdsBytesNoAmax"]

label_wave_inter = Label("wave_inter", 'wave_inter')
Expand Down Expand Up @@ -10743,19 +10746,18 @@ def amax_broadcast(self, kernel) -> Module:
def amax_output_result(self, kernel) -> Module:
wave_size = kernel["WavefrontSize"]
amaxInType = kernel["ProblemType"]["ComputeDataType"]
amaxOutType = kernel["ProblemType"]["DataTypeAmaxD"]
amaxOutType = kernel["ProblemType"]["ComputeDataType"]

mod = Module("output_result")
mod.addComment0("output_result")

label_end = Label("end", 'end')
label_final_loop = Label("final_loop", 'final_loop')
label_final_output = Label("final_output", 'final_output')
mod.addSpaceLine()

mod.add(VReadfirstlaneB32(sgpr("Tmp"), vgpr("Serial")))
mod.add(SCmpEQU32(sgpr("Tmp"), 0))
mod.add(SCBranchSCC0(label_end.getLabelName()))
mod.add(SCBranchSCC0(self.label_amax_end.getLabelName()))
mod.addSpaceLine()

# if self.arch.find("gfx94") != -1:
Expand All @@ -10781,7 +10783,7 @@ def amax_output_result(self, kernel) -> Module:
mod.add(SAtomicDec(sgpr("Tmp"), sgpr("AddressSy",2), SMEMModifiers(glc=True)))
mod.add(SWaitCnt(vmcnt=0, lgkmcnt=0))
mod.add(SCmpEQU32(sgpr("Tmp"), 1))
mod.add(SCBranchSCC0(label_end.getLabelName()))
mod.add(SCBranchSCC0(self.label_amax_end.getLabelName()))
mod.addSpaceLine()

mod.add(SLShiftLeftB32(sgpr("Tmp"), int(log2(amaxInType.numBytes())), sgpr("NumGroup")))
Expand Down Expand Up @@ -10831,7 +10833,7 @@ def amax_output_result(self, kernel) -> Module:
# TODO- select inst
mod.add(BufferStoreB32(vgpr("AmaxOut"), vgpr("Offset"), sgpr("Dst",4), 0, MUBUFModifiers(offen=True)))
mod.addSpaceLine()
mod.add(label_end)
mod.add(self.label_amax_end)
mod.addSpaceLine()

return mod
Expand All @@ -10844,6 +10846,7 @@ def insertAmaxD(self, kernel):
self.amaxVgprSizes = [1, 1, 1, 1]
self.amaxSgprArgNames = ["AddrAmaxOut", "AddressWk", "AddressSy"]
self.amaxSgprArgSizes = [2, 2, 2]
self.label_amax_end = Label("amax_end", 'amax_end')

module.addSpaceLine()
module.add(SBarrier())
Expand Down
8 changes: 0 additions & 8 deletions tensilelite/Tensile/LibraryIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ def writeSolutions(filename, problemSizes, biasTypeArgs, activationArgs, solutio
solutionState["ProblemType"]["DataTypeB"].value
solutionState["ProblemType"]["DataTypeE"] = \
solutionState["ProblemType"]["DataTypeE"].value
solutionState["ProblemType"]["DataTypeAmaxD"] = \
solutionState["ProblemType"]["DataTypeAmaxD"].value
solutionState["ProblemType"]["DestDataType"] = \
solutionState["ProblemType"]["DestDataType"].value
solutionState["ProblemType"]["ComputeDataType"] = \
Expand Down Expand Up @@ -405,8 +403,6 @@ def createLibraryLogic(schedulePrefix, architectureName, deviceNames, libraryTyp
problemTypeState["DataTypeB"].value
problemTypeState["DataTypeE"] = \
problemTypeState["DataTypeE"].value
problemTypeState["DataTypeAmaxD"] = \
problemTypeState["DataTypeAmaxD"].value
problemTypeState["DestDataType"] = \
problemTypeState["DestDataType"].value
problemTypeState["ComputeDataType"] = \
Expand All @@ -433,8 +429,6 @@ def createLibraryLogic(schedulePrefix, architectureName, deviceNames, libraryTyp
solutionState["ProblemType"]["DataTypeB"].value
solutionState["ProblemType"]["DataTypeE"] = \
solutionState["ProblemType"]["DataTypeE"].value
solutionState["ProblemType"]["DataTypeAmaxD"] = \
solutionState["ProblemType"]["DataTypeAmaxD"].value
solutionState["ProblemType"]["DestDataType"] = \
solutionState["ProblemType"]["DestDataType"].value
solutionState["ProblemType"]["ComputeDataType"] = \
Expand Down Expand Up @@ -462,8 +456,6 @@ def createLibraryLogic(schedulePrefix, architectureName, deviceNames, libraryTyp
solutionState["ProblemType"]["DataTypeB"].value
solutionState["ProblemType"]["DataTypeE"] = \
solutionState["ProblemType"]["DataTypeE"].value
solutionState["ProblemType"]["DataTypeAmaxD"] = \
solutionState["ProblemType"]["DataTypeAmaxD"].value
solutionState["ProblemType"]["DestDataType"] = \
solutionState["ProblemType"]["DestDataType"].value
solutionState["ProblemType"]["ComputeDataType"] = \
Expand Down
8 changes: 1 addition & 7 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,6 @@ def __init__(self, config):
printExit("NO compute data type, or dest data type, or data type specified")
self["DataType"] = DataType(0)

# Just as DataTypeE defaults to DestDataType, DataTypeAmaxD defaults to ComputeDataType.
# So far it does not have to be set in the config YAMLs.
self["DataTypeAmaxD"] = self["ComputeDataType"]
if "DataTypeAmaxD" in config:
self["DataTypeAmaxD"] = DataType(config["DataTypeAmaxD"])

if self["Sparse"]:
self["DataTypeMetadata"] = DataType("I8")

Expand Down Expand Up @@ -3465,7 +3459,7 @@ def subCheckLdsBlockSizePerPad(tc, idx):
# 4 values * number of half-waves (NumThreads // half wave size) * bytes per amax element
num_workItems = state["NumThreads"]
half_wave_size = state["WavefrontSize"] // 2
amaxBPE = state["ProblemType"]["DataTypeAmaxD"].numBytes()
amaxBPE = state["ProblemType"]["ComputeDataType"].numBytes() # amax type = compute type
ldsAmaxDBytes = 4 * (num_workItems // half_wave_size) * amaxBPE
ldsNumBytes += ldsAmaxDBytes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,15 +850,10 @@ namespace Tensile
enum
{
HasIndex = false,
HasValue = true
HasValue = false
};
bool value;

AmaxDCheck() = default;
AmaxDCheck(bool value)
: value(value)
{
}

static std::string Type()
{
Expand All @@ -867,38 +862,24 @@ namespace Tensile

virtual bool operator()(ContractionProblemGemm const& problem) const override
{
bool amaxDStatusEqual = (problem.outputAmaxD() == value);

// if value is true, then we also need to check gsu
// otherwise we just check outputAmaxD
if(value)
return amaxDStatusEqual && problem.getParams().gsu() <= 1;
if(problem.outputAmaxD())
return problem.getParams().gsu() <= 1;
else
return amaxDStatusEqual;
return true;
}

virtual bool debugEval(ContractionProblemGemm const& problem,
std::ostream& stream) const override
{
return (value) ? debugEvalCmp(problem,
stream,
"prob_amaxD",
problem.outputAmaxD(),
"==",
"sol_amaxD",
value,
"prob_gsu",
(int)(problem.getParams().gsu()),
"<=",
"sol_gsu",
1)
: debugEvalCmp(problem,
stream,
"prob_amaxD",
problem.outputAmaxD(),
"==",
"sol_amaxD",
value);
return (problem.outputAmaxD()) ? debugEvalCmp(problem,
stream,
"prob_gsu",
(int)(problem.getParams().gsu()),
"<=",
"sol_gsu",
1)
: true;
}
};

Expand Down
1 change: 0 additions & 1 deletion tensilelite/Tensile/TensileCreateLibrary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,6 @@ def WriteClientLibraryFromSolutions(solutionList, libraryWorkingPath, tensileSou
problemType["DataTypeA"] = problemType["DataTypeA"].value
problemType["DataTypeB"] = problemType["DataTypeB"].value
problemType["DataTypeE"] = problemType["DataTypeE"].value
problemType["DataTypeAmaxD"] = problemType["DataTypeAmaxD"].value
problemType["DestDataType"] = problemType["DestDataType"].value
problemType["ComputeDataType"] = problemType["ComputeDataType"].value
problemType["F32XdlMathOp"] = problemType["F32XdlMathOp"].value
Expand Down
3 changes: 0 additions & 3 deletions tensilelite/Tensile/TensileUpdateLibrary.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def UpdateLogic(filename, logicPath, outputPath):
problemTypeState["DataTypeA"] = problemTypeState["DataTypeA"].value
problemTypeState["DataTypeB"] = problemTypeState["DataTypeB"].value
problemTypeState["DataTypeE"] = problemTypeState["DataTypeE"].value
problemTypeState["DataTypeAmaxD"] = problemTypeState["DataTypeAmaxD"].value
problemTypeState["DestDataType"] = problemTypeState["DestDataType"].value
problemTypeState["ComputeDataType"] = problemTypeState["ComputeDataType"].value
problemTypeState["BiasDataTypeList"] = [btype.value for btype in problemTypeState["BiasDataTypeList"]]
Expand All @@ -68,8 +67,6 @@ def UpdateLogic(filename, logicPath, outputPath):
solutionState["ProblemType"]["DataTypeB"].value
solutionState["ProblemType"]["DataTypeE"] = \
solutionState["ProblemType"]["DataTypeE"].value
solutionState["ProblemType"]["DataTypeAmaxD"] = \
solutionState["ProblemType"]["DataTypeAmaxD"].value
solutionState["ProblemType"]["DestDataType"] = \
solutionState["ProblemType"]["DestDataType"].value
solutionState["ProblemType"]["ComputeDataType"] = \
Expand Down

0 comments on commit 45ea7ad

Please sign in to comment.