From 6b1d53174324f80cf7daa393eb09ccbb9d8a635a Mon Sep 17 00:00:00 2001 From: jiangzishan Date: Fri, 27 Sep 2024 20:22:11 +0800 Subject: [PATCH] [micro_perf] fix log and sqrt in gpu backend. --- byte_micro_perf/backends/GPU/backend_gpu.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py index a6bbefaf..fa6e8983 100644 --- a/byte_micro_perf/backends/GPU/backend_gpu.py +++ b/byte_micro_perf/backends/GPU/backend_gpu.py @@ -88,11 +88,7 @@ def initialize_ccl(self, rank, world_size): self.setup_2d_group() return True - def log(self): - self.op = LogOp() - def sqrt(self): - self.op = SqrtOp() def setup_2d_group(self): # get rank and set device @@ -247,6 +243,11 @@ def swiglu(self): def cast(self): self.op = module_store.CastOp() + def log(self): + self.op = module_store.LogOp() + + def sqrt(self): + self.op = module_store.SqrtOp() # binary ops def add(self):