Skip to content

Commit

Permalink
Flops profiler support einops.einsum (#6755)
Browse files Browse the repository at this point in the history
- Added support for FlopsProfiler to include einops.einsum operation
- Added _patch_miscellaneous_operations() and
_reload_miscellaneous_operations() to include this operation and
potentially include other miscellaneous operations in the future
- Added _einops_einsum_flops_compute() that mimic already-existed
_einsum_flops_compute()

---------

Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
lvhoaa and loadams authored Dec 9, 2024
1 parent 9ca6016 commit 9a41cca
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from deepspeed.moe.layer import MoE
from deepspeed.utils.timer import FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER
from deepspeed.utils.torch import required_torch_version
import einops

Tensor = torch.Tensor

Expand Down Expand Up @@ -82,6 +83,7 @@ def start_profile(self, ignore_list=None):
self.reset_profile()
_patch_functionals()
_patch_tensor_methods()
_patch_miscellaneous_operations()

def register_module_hooks(module, ignore_list):
if ignore_list and type(module) in ignore_list:
Expand Down Expand Up @@ -137,6 +139,7 @@ def stop_profile(self):
if self.started and self.func_patched:
_reload_functionals()
_reload_tensor_methods()
_reload_miscellaneous_operations()
self.func_patched = False

def remove_profile_attrs(module):
Expand Down Expand Up @@ -787,6 +790,29 @@ def _einsum_flops_compute(equation, *operands):
raise NotImplementedError("Unsupported einsum operation.")


def _einops_einsum_flops_compute(*args):
"""
Count flops for the einops.einsum operation.
"""
*operands, equation = args
input_shapes = [o.shape for o in operands]

# Re-map equation so that same equation with different alphabet
# representations will look the same.
letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
equation = equation.translate(mapping)

np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
if "optimized flop" in line.lower():
flop = int(float(line.split(":")[-1]))
return flop, 0

raise NotImplementedError("Unsupported einops.einsum operation.")


def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None):
"""
Count flops for the tensor addmm operation.
Expand Down Expand Up @@ -937,6 +963,10 @@ def _patch_tensor_methods():
torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute)


def _patch_miscellaneous_operations():
einops.einsum = wrapFunc(einops.einsum, _einops_einsum_flops_compute)


def _reload_functionals():
# torch.nn.functional does not support importlib.reload()
F.linear = old_functions[F.linear.__str__]
Expand Down Expand Up @@ -995,6 +1025,10 @@ def _reload_tensor_methods():
torch.baddbmm = old_functions[torch.baddbmm.__str__]


def _reload_miscellaneous_operations():
einops.einsum = old_functions[einops.einsum.__str__]


def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
gates_size = w_ih.shape[0]
# matrix matrix mult ih state and internal state
Expand Down

0 comments on commit 9a41cca

Please sign in to comment.