Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Nov 22, 2023
1 parent b601911 commit b440cf6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,13 @@
# --------------------------------------------------------------------------

import sys
import ctypes
from dataclasses import dataclass
from itertools import product

import kernel_explorer as ke
import numpy as np
import pytest
from ml_dtypes import float8_e4m3fn, float8_e4m3fnuz, finfo
from utils import (
dtype_to_suffix,
get_gemm_basic_sizes,
get_gemm_bert_sizes,
get_gemm_bound,
matmul,
transab_to_suffix,
dtype_to_bytes,
)
from ml_dtypes import finfo, float8_e4m3fn, float8_e4m3fnuz
from utils import dtype_to_bytes, dtype_to_suffix, get_gemm_bert_sizes, matmul, transab_to_suffix


def create_device_array(a):
Expand All @@ -45,7 +35,7 @@ def compute_scaling_factor(a: np.ndarray, fp8_max: float, margin: int) -> np.nda
def cast_and_scale(a, dtype: str):
if dtype == "float16":
return a.astype(dtype), 1.0
elif dtype in ("float8_e4m3fn", "float8_e4m3fnuz"):
elif np.dtype(dtype) in (float8_e4m3fn, float8_e4m3fnuz):
t = globals()[dtype]
sf = compute_scaling_factor(a, fp8_max=finfo(t).max, margin=4)
return (a * sf).astype(t), sf
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/python/tools/kernel_explorer/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np
import scipy.special
from ml_dtypes import float8_e4m3fnuz


def dtype_to_bytes(dtype):
Expand Down

0 comments on commit b440cf6

Please sign in to comment.