[wip] make scaling configurable by gemm-argument
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 26b3b8ff4e59cc59bf4580056575e25ca7492d4f
ghstack-comment-id: 2372563439
Pull Request resolved: #940
vkuzo committed Oct 5, 2024
1 parent 3f24d79 commit 371775a
Showing 10 changed files with 563 additions and 427 deletions.
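For orientation before the diffs: the updated benchmarks exercise two float8 configuration styles, per-gemm-argument `CastConfig`s and named recipes. Below is a minimal sketch of both, using only names that appear in the diffs; import paths, scaling-type choices, and the toy model are assumptions for illustration, not taken verbatim from the commit.

```python
# Sketch only: the two float8 config styles exercised by the updated benchmarks.
# Import paths are assumed from the diff; the toy model is a placeholder.
import copy

import torch
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    Float8LinearRecipeName,
    ScalingType,
    recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

m_orig = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

# style 1: configure scaling per gemm argument (input / weight / grad_output)
config_per_arg = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# style 2: build the config from a named recipe (all-axiswise scaling)
config_axiswise = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)

# swap nn.Linear modules for float8 linears and compile, as the benchmarks do
m_fp8 = convert_to_float8_training(copy.deepcopy(m_orig), config=config_axiswise)
m_fp8 = torch.compile(m_fp8)
```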
53 changes: 43 additions & 10 deletions benchmarks/float8/float8_roofline.py
@@ -70,6 +70,7 @@
ScalingType,
CastConfig,
)
from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName


class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
else:
# cache does not exist yet, create it
cache = dict()
else:
cache = dict()
key = f"{M},{K},{N},{fast_accum}"
if key in cache:
return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
)
f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
fast_accum = True # for axiswise
f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

# save to cache if needed
if cache_filename is not None:
cache[key] = [bf16_time_s, f8_time_s]
cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
with open(cache_filename, 'w') as f:
json.dump(cache, f)

return bf16_time_s, f8_time_s
return bf16_time_s, f8_time_s, f8_axs_time_s

def run(
outfile: str,
@@ -231,13 +239,15 @@ def run(
headers = [
'fwd_M', 'fwd_K', 'fwd_N',
# gemm microbenchmarks
'bf16_gemm_s', 'fp8_gemm_s',
'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
# roofline memory overhead estimates
'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
# actual e2e measurements
'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
'fp8_dyn_speedup', 'fp8_del_speedup',
'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
# 'fp8_lw_s',
'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
# 'fp8_lw_sp',
]
results = []

@@ -248,15 +258,18 @@ def run(
break

if gemm_time_strategy == "benchmarks":
bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
else:
assert gemm_time_strategy == "roofline", "unsupported"
bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
# for now, assume axiswise gemm is similar to tensorwise
fp8_axs_gemm_time_s = fp8_gemm_time_s

fp8_mem_time_dyn_limit_s = \
fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,23 +304,43 @@ def run(
cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
m_fp8_del = convert_to_float8_training(m_orig)
m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_del = torch.compile(m_fp8_del)
fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

# get the float8 dynamic axiswise scaling gpu kernel time
torch._dynamo.reset()
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)

# get the lw recipe scaling gpu kernel time
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
# m_fp8_lw = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw = torch.compile(m_fp8_lw)
# fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)

results.append([
M_val, K_val, N_val,
# gemm microbenchmarks
bf16_time_val, fp8_gemm_time_s,
bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
# roofline overhead estimates
fp8_mem_time_dyn_limit_s,
fp8_mem_time_dyn_nolimit_s,
fp8_mem_time_del_limit_s,
fp8_mem_time_del_nolimit_s,
# e2e numbers
bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
fp8_dyn_axs_time_actual_s,
# fp8_lw_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_del_time_actual_s,
bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
# bf16_time_actual_s / fp8_lw_time_actual_s,
])

df = pd.DataFrame(results, columns=headers)
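A small worked example of how the new columns in the results table above relate; the values are placeholders, not measurements, and the column names come from the `headers` list in the diff.

```python
# Placeholder timings standing in for the 'bf16_s', 'fp8_dyn_s', 'fp8_del_s'
# and 'fp8_dyn_axs_s' columns; the speedup columns are simple ratios vs bf16.
bf16_s, fp8_dyn_s, fp8_del_s, fp8_dyn_axs_s = 1.00e-3, 7.2e-4, 7.0e-4, 6.8e-4

fp8_dyn_sp = bf16_s / fp8_dyn_s          # 'fp8_dyn_sp': dynamic tensorwise speedup
fp8_del_sp = bf16_s / fp8_del_s          # 'fp8_del_sp': delayed scaling speedup
fp8_dyn_axs_sp = bf16_s / fp8_dyn_axs_s  # 'fp8_dyn_axs_sp': dynamic axiswise speedup
```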
53 changes: 13 additions & 40 deletions benchmarks/float8/profile_linear_float8.py
@@ -27,12 +27,15 @@
Float8LinearConfig,
ScalingType,
ScalingGranularity,
Float8LinearRecipeName,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.testing.float8.test_utils import get_test_float8_linear_config
from torch.profiler import profile, ProfilerActivity, record_function
from utils import (
kernel_name_to_category,
@@ -257,7 +260,7 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
recipe_name: Optional[str] = None,
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
@@ -269,47 +272,17 @@ def main(
scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
if recipe_name is None:
config = get_test_float8_linear_config(
scaling_type_input,
scaling_type_weight,
scaling_type_grad_output,
emulate=False,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)
elif recipe_name is not None:
recipe_name = Float8LinearRecipeName(recipe_name)
config = recipe_name_to_linear_config(recipe_name)

scaling_repr = "_".join(
[
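Because the rendered diff above interleaves removed and added lines without +/- markers, here is a hedged reconstruction of the new config-selection logic in `main()`, using only names visible in the diff; the wrapper function is for illustration only and is not part of the commit.

```python
# Illustration only: how main() now chooses a Float8LinearConfig.
# If no recipe is given, the shared test helper builds a config from the
# per-gemm-argument scaling types; otherwise the named recipe is used.
from typing import Optional

from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
from torchao.testing.float8.test_utils import get_test_float8_linear_config


def build_config(
    scaling_type_input,
    scaling_type_weight,
    scaling_type_grad_output,
    recipe_name: Optional[str] = None,
):
    if recipe_name is None:
        return get_test_float8_linear_config(
            scaling_type_input,
            scaling_type_weight,
            scaling_type_grad_output,
            emulate=False,
        )
    return recipe_name_to_linear_config(Float8LinearRecipeName(recipe_name))
```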
[diffs for the remaining 8 changed files are not rendered in this view]