float8 training axiswise scaling support with per-gemm-argument configuration #940

Merged 51 commits on Oct 7, 2024
Commits (51; the diff below shows changes from 38 of them)
183dbec  Update  (vkuzo, Sep 23, 2024)
241f815  Update  (vkuzo, Sep 23, 2024)
d0b1002  Update  (vkuzo, Sep 23, 2024)
f15c2a0  Update  (vkuzo, Sep 23, 2024)
c816ce9  Update  (vkuzo, Sep 23, 2024)
9150b4f  Update  (vkuzo, Sep 23, 2024)
40279fb  Update  (vkuzo, Sep 23, 2024)
459e92c  Update  (vkuzo, Sep 23, 2024)
732b231  Update  (vkuzo, Sep 23, 2024)
e7c15d1  Update  (vkuzo, Sep 24, 2024)
1d01df3  Update  (vkuzo, Sep 25, 2024)
62acdaf  Update  (vkuzo, Sep 25, 2024)
381e16e  Update  (vkuzo, Sep 27, 2024)
afdf660  Update  (vkuzo, Sep 27, 2024)
d53f2ce  Update  (vkuzo, Sep 27, 2024)
0737eb8  Update  (vkuzo, Sep 27, 2024)
2791eb3  Update  (vkuzo, Sep 27, 2024)
6cfd1cd  Update  (vkuzo, Sep 27, 2024)
94907e5  Update  (vkuzo, Sep 27, 2024)
423760a  Update  (vkuzo, Sep 27, 2024)
552db23  Update  (vkuzo, Sep 27, 2024)
953bc2f  Update  (vkuzo, Sep 27, 2024)
c5d19e0  Update  (vkuzo, Sep 27, 2024)
24f0f3b  Update  (vkuzo, Sep 27, 2024)
7e0fe97  Update  (vkuzo, Sep 27, 2024)
10f2e0f  Update  (vkuzo, Sep 27, 2024)
4437054  Update  (vkuzo, Sep 30, 2024)
31a017b  Update  (vkuzo, Sep 30, 2024)
743b4c1  Update  (vkuzo, Sep 30, 2024)
0c473c4  Update  (vkuzo, Oct 2, 2024)
c1be278  Update  (vkuzo, Oct 2, 2024)
179e3b3  Update  (vkuzo, Oct 2, 2024)
fc8d4ef  Update  (vkuzo, Oct 2, 2024)
76322de  Update  (vkuzo, Oct 2, 2024)
4eec2bd  Update  (vkuzo, Oct 2, 2024)
ac6f768  Update  (vkuzo, Oct 2, 2024)
02b9fca  Update  (vkuzo, Oct 2, 2024)
5756aa5  Update  (vkuzo, Oct 2, 2024)
1f01df9  Update  (vkuzo, Oct 4, 2024)
b9bcc30  Update  (vkuzo, Oct 4, 2024)
e595f30  Update  (vkuzo, Oct 4, 2024)
4bb59a6  Update  (vkuzo, Oct 4, 2024)
c1c218f  Update  (vkuzo, Oct 4, 2024)
024fe94  Update  (vkuzo, Oct 4, 2024)
4027694  Update  (vkuzo, Oct 4, 2024)
076de91  Update  (vkuzo, Oct 4, 2024)
ca127f0  Update  (vkuzo, Oct 4, 2024)
712fd5d  Update  (vkuzo, Oct 5, 2024)
d70326c  Update  (vkuzo, Oct 7, 2024)
f2d104a  Update  (vkuzo, Oct 7, 2024)
b536435  Update  (vkuzo, Oct 7, 2024)
32 changes: 27 additions & 5 deletions benchmarks/float8/bench_linear_float8.py
@@ -14,7 +14,12 @@

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
@@ -107,35 +112,49 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
@@ -167,7 +186,7 @@ def main(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = linear_float8.scaling_repr()
scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -310,6 +329,7 @@ def invoke_main() -> None:
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
parser.add_argument("--scaling_granularity", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
@@ -327,6 +347,8 @@ def invoke_main() -> None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
if args.scaling_granularity is not None:
kwargs["scaling_granularity"] = args.scaling_granularity
main(
output_path,
not args.disable_compile,
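For context on how the new benchmark flag maps onto the library configuration, here is a minimal sketch that builds the same objects outside the benchmark. It mirrors the code above; the string values "dynamic" and "axiswise" come from the enums imported there, and the shapes and names beyond that are only illustrative:

import torch
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    ScalingGranularity,
)

# dynamic scaling with axiswise granularity for all three gemm arguments,
# which is what the benchmark builds when --scaling_granularity axiswise is passed
granularity = ScalingGranularity("axiswise")  # or "tensorwise"
cast_config = CastConfig(
    scaling_type=ScalingType("dynamic"),
    scaling_granularity=granularity,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)

With the new argparse flag, the benchmark itself can also be pointed at axiswise scaling directly, e.g. by adding --scaling_granularity axiswise alongside the existing --scaling_type_* flags.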
13 changes: 11 additions & 2 deletions benchmarks/float8/bench_matmul.py
@@ -13,6 +13,8 @@
import torch.nn as nn
import torch.utils.benchmark as benchmark

from torchao.float8.config import ScalingGranularity

from utils import (
get_name_to_shapes_iter,
profiler_output_to_filtered_time_by_kernel_name,
@@ -75,6 +77,7 @@ def run(
K: Optional[int] = None,
N: Optional[int] = None,
use_gpu_kernel_time: bool = False,
scaling_granularity: str = "tensorwise",
):
device = "cuda"

@@ -84,6 +87,7 @@ def run(
dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]
scaling_granularity = ScalingGranularity(scaling_granularity)

for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
if n_limit is not None and idx >= n_limit:
@@ -109,8 +113,13 @@ def run(
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
else:
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)

def do_matmul(A, B):
nonlocal scale_a
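The body of do_matmul is not shown in this excerpt, but the scale shapes above follow the convention torch._scaled_mm expects: a single-element scale per tensor for tensorwise scaling, and a per-row scale for the first operand plus a per-column scale for the second operand for axiswise (rowwise) scaling. A minimal sketch, assuming a recent PyTorch build with rowwise-scaled _scaled_mm support and float8-capable hardware (H100-class); the sizes are illustrative:

import torch

device = "cuda"
M, K, N = 1024, 2048, 4096
A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)
# the second operand is laid out column-major, as in the benchmark above
B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()

# tensorwise: one fp32 scale per tensor
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)

# axiswise: one scale per row of A, one per column of B
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True)

Note that _scaled_mm is a private PyTorch API, so the exact signature and the availability of rowwise scales depend on the PyTorch version in use.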
82 changes: 55 additions & 27 deletions benchmarks/float8/profile_linear_float8.py
@@ -22,7 +22,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
ScalingGranularity,
_Float8LinearRecipeName,
_recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
@@ -252,6 +259,8 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
recipe_name: Optional[str] = None,
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
@@ -263,34 +272,53 @@ def main(
scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_input=CastConfig(scaling_type=scaling_type_input)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity = ScalingGranularity(scaling_granularity)

if recipe_name is None:

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)
else:
cast_config_weight=CastConfig(scaling_type=scaling_type_weight)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
)
else:
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)
elif recipe_name is not None:
recipe_name = _Float8LinearRecipeName(recipe_name)
config = _recipe_name_to_linear_config(recipe_name)

scaling_repr = "_".join(
[
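The other path added to the profiling script goes through recipes: when --recipe_name is passed, the name is mapped straight to a full Float8LinearConfig via the two helpers imported above. A minimal sketch of that path follows; both helpers are underscore-prefixed (private), so their names, the enum members, and the exact keyword arguments are assumptions that may change, and the toy model is illustrative only:

import copy
import torch
import torch.nn as nn
from torchao.float8.config import (
    _Float8LinearRecipeName,
    _recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

m_ref = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to("cuda", torch.bfloat16)

# iterate over whatever recipes the enum defines and convert a copy of the model with each
for recipe_name in _Float8LinearRecipeName:
    config = _recipe_name_to_linear_config(recipe_name)
    m_fp8 = convert_to_float8_training(copy.deepcopy(m_ref), config=config)
    print(recipe_name.value, m_fp8)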