Skip loading triton.nvidia.cublas if not found
Summary:
We have an old Triton internally that doesn't have the cublasLt bindings.

Reviewed By: adamomainz

Differential Revision: D63643619

fbshipit-source-id: 39aece74b52f7747fe2100d7bb905bad49ba1fa0
bertmaher authored and facebook-github-bot committed Oct 1, 2024
1 parent b6b67a4 commit 0611c41
Showing 1 changed file with 9 additions and 5 deletions.

torchbenchmark/operators/fp8_gemm/persistent.py

@@ -5,13 +5,17 @@
 import triton.language as tl
 import triton.tools.experimental_descriptor
 
+cublas = None
 if torch.cuda.is_available():
-    from triton._C.libtriton import nvidia
+    try:
+        from triton._C.libtriton import nvidia
 
-    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
-    cublas = nvidia.cublas.CublasLt(cublas_workspace)
-else:
-    cublas = None
+        cublas_workspace = torch.empty(
+            32 * 1024 * 1024, device="cuda", dtype=torch.uint8
+        )
+        cublas = nvidia.cublas.CublasLt(cublas_workspace)
+    except (ImportError, IOError, AttributeError):
+        pass
 
 
 def is_cuda():
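The patch replaces an unconditional import with a guarded one, so the module still loads on Triton builds that lack the cublasLt bindings. A minimal, self-contained sketch of the same optional-dependency pattern; `fancy_backend` is a hypothetical module name standing in for `triton._C.libtriton.nvidia`, and the exception tuple mirrors the one in the patch:

```python
# Optional-dependency guard: default the handle to None, then try to
# upgrade it. Callers check for None instead of assuming the import worked.
backend = None
try:
    import fancy_backend  # hypothetical optional accelerator binding

    backend = fancy_backend.Handle()
except (ImportError, IOError, AttributeError):
    # Older installs may lack the module entirely (ImportError), fail to
    # load a shared library (IOError), or be missing the newer attribute
    # on an old version of the package (AttributeError).
    pass


def has_backend() -> bool:
    """True only when the optional binding actually loaded."""
    return backend is not None
```

Downstream code (e.g. a cuBLAS benchmark variant) can then be skipped cleanly when `has_backend()` is false, rather than crashing at import time.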
