Skip to content

Commit

Permalink
Add simple hash and eq methods for gemm_operations. (#1053)
Browse files Browse the repository at this point in the history
  • Loading branch information
ipiszy authored and ttl10101 committed Feb 7, 2024
1 parent f9f5f4e commit 0c44dfb
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions tools/library/scripts/gemm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class GemmOperation:
#
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default):

Expand All @@ -35,7 +35,7 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue,
self.A = A
self.B = B
self.C = C
self.D = D
self.D = D
if self.D == None:
self.D = self.C

Expand All @@ -52,7 +52,7 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue,
#
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex_gaussian,
MathOperation.multiply_add_complex_fast_f32
]
Expand Down Expand Up @@ -81,7 +81,7 @@ def short_math_name(self):
#
def core_name(self):
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''

inst_shape = ''
inst_operation = ''
intermediate_type = ''
Expand Down Expand Up @@ -148,7 +148,7 @@ def extended_name_3x(self):
def layout_name(self):
if self.is_complex() or self.is_planar_complex():
return "%s%s" % (
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
)
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
Expand All @@ -157,7 +157,7 @@ def layout_name(self):
def layout_name_3x(self):
if self.is_complex() or self.is_planar_complex():
return "{}{}{}".format(
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
else:
Expand Down Expand Up @@ -212,6 +212,11 @@ def configuration_name(self):
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
return self.procedural_name()

def __hash__(self):
return hash(self.configuration_name())

def __eq__(self, other):
return self.configuration_name() == other.configuration_name()

###################################################################################################
#
Expand Down Expand Up @@ -324,7 +329,7 @@ def emit(self, operation):
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])

residual = ''

values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
Expand Down Expand Up @@ -414,7 +419,7 @@ def emit(self, operation):
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])

residual = ''

values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
Expand Down Expand Up @@ -481,7 +486,7 @@ def __init__(self, operation_suffix = ''):
"""
self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
Expand All @@ -499,12 +504,12 @@ def __init__(self, operation_suffix = ''):
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
"""
self.gemm_template_interleaved = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
Expand All @@ -522,7 +527,7 @@ def __init__(self, operation_suffix = ''):
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
"""

Expand Down Expand Up @@ -793,7 +798,7 @@ def __init__(self, operation_suffix = ''):
${math_operator}
>::GemmKernel;
struct ${operation_name} :
struct ${operation_name} :
public Operation_${operation_name} { };
"""

Expand Down Expand Up @@ -1170,7 +1175,7 @@ def emit(self, operation):
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
'compile_guard_end': "#endif" \
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
}))

def __exit__(self, exception_type, exception_value, traceback):
Expand All @@ -1190,9 +1195,9 @@ def __exit__(self, exception_type, exception_value, traceback):
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
'configuration_name': self.configuration_name
}))

for instance_wrapper in self.instance_wrappers:
self.configuration_file.write(instance_wrapper)
self.configuration_file.write(instance_wrapper)

self.configuration_file.write(self.epilogue_template)
self.configuration_file.close()
Expand Down

0 comments on commit 0c44dfb

Please sign in to comment.