From 0c44dfbbec595f0e7887c401962d130d2a112099 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Sun, 27 Aug 2023 17:41:57 -0700 Subject: [PATCH] Add simple hash and eq methods for gemm_operations. (#1053) --- tools/library/scripts/gemm_operation.py | 37 ++++++++++++++----------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index d248643d..58dba0ff 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -23,7 +23,7 @@ class GemmOperation: # def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, tile_scheduler = TileSchedulerType.Default): @@ -35,7 +35,7 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, self.A = A self.B = B self.C = C - self.D = D + self.D = D if self.D == None: self.D = self.C @@ -52,7 +52,7 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, # def is_complex(self): complex_operators = [ - MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex, MathOperation.multiply_add_complex_gaussian, MathOperation.multiply_add_complex_fast_f32 ] @@ -81,7 +81,7 @@ def short_math_name(self): # def core_name(self): ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' - + inst_shape = '' inst_operation = '' intermediate_type = '' @@ -148,7 +148,7 @@ def extended_name_3x(self): def layout_name(self): if self.is_complex() or self.is_planar_complex(): return "%s%s" % ( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] ) return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) @@ -157,7 +157,7 @@ def layout_name(self): def layout_name_3x(self): if self.is_complex() or self.is_planar_complex(): return "{}{}{}".format( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) else: @@ -212,6 +212,11 @@ def configuration_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' return self.procedural_name() + def __hash__(self): + return hash(self.configuration_name()) + + def __eq__(self, other): + return self.configuration_name() == other.configuration_name() ################################################################################################### # @@ -324,7 +329,7 @@ def emit(self, operation): epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) residual = '' - + values = { 'operation_name': operation.procedural_name(), 'element_a': DataTypeTag[operation.A.element], @@ -414,7 +419,7 @@ def emit(self, operation): epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) residual = '' - + values = { 'operation_name': operation.procedural_name(), 'element_a': DataTypeTag[operation.A.element], @@ -481,7 +486,7 @@ def __init__(self, operation_suffix = ''): """ self.gemm_template = """ // Gemm operator ${operation_name} -using ${operation_name}_base = +using ${operation_name}_base = typename cutlass::gemm::kernel::DefaultGemmUniversal< ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand @@ -499,12 +504,12 @@ def __init__(self, operation_suffix = ''): >::GemmKernel; // Define named type -struct ${operation_name}${operation_suffix} : +struct ${operation_name}${operation_suffix} : public ${operation_name}_base { }; """ self.gemm_template_interleaved = """ // Gemm operator ${operation_name} -using ${operation_name}_base = +using ${operation_name}_base = typename cutlass::gemm::kernel::DefaultGemmUniversal< ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, @@ -522,7 +527,7 @@ def __init__(self, operation_suffix = ''): >::GemmKernel; // Define named type -struct ${operation_name}${operation_suffix} : +struct ${operation_name}${operation_suffix} : public ${operation_name}_base { }; """ @@ -793,7 +798,7 @@ def __init__(self, operation_suffix = ''): ${math_operator} >::GemmKernel; - struct ${operation_name} : + struct ${operation_name} : public Operation_${operation_name} { }; """ @@ -1170,7 +1175,7 @@ def emit(self, operation): 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", 'compile_guard_end': "#endif" \ - if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" })) def __exit__(self, exception_type, exception_value, traceback): @@ -1190,9 +1195,9 @@ def __exit__(self, exception_type, exception_value, traceback): self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { 'configuration_name': self.configuration_name })) - + for instance_wrapper in self.instance_wrappers: - self.configuration_file.write(instance_wrapper) + self.configuration_file.write(instance_wrapper) self.configuration_file.write(self.epilogue_template) self.configuration_file.close()