diff --git a/op_builder/builder.py b/op_builder/builder.py index 6d593b48217e..bd1608b0f563 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -506,6 +506,9 @@ def jit_load(self, verbose=True): cxx_args.append("-DBF16_AVAILABLE") nvcc_args.append("-DBF16_AVAILABLE") + if self.is_rocm_pytorch(): + cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + op_module = load(name=self.name, sources=self.strip_empty_entries(sources), extra_include_paths=self.strip_empty_entries(extra_include_paths),