Skip to content

Commit

Permalink
[Haiku] Add shard_map flops estimate to the jaxpr_info.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669305139
  • Loading branch information
Haiku Contributor authored and copybara-github committed Aug 30, 2024
1 parent 54a1eba commit 399ec5b
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions haiku/_src/jaxpr_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,19 @@ def _process_eqn(
name_stack=name_stack.split('/') if name_stack else [])

if eqn.primitive.name in [
'named_call', 'custom_jvp_call_jaxpr', 'custom_vjp_call_jaxpr',
'custom_jvp_call', 'custom_vjp_call', 'remat_call', 'scan', 'while',
'xla_call', 'xla_pmap', 'pjit', 'remat2'
'named_call',
'custom_jvp_call_jaxpr',
'custom_vjp_call_jaxpr',
'custom_jvp_call',
'custom_vjp_call',
'remat_call',
'scan',
'while',
'xla_call',
'xla_pmap',
'pjit',
'remat2',
'shard_map',
]:
flops_multiplier = 1
if eqn.primitive.name in ['named_call', 'xla_call']:
Expand All @@ -318,6 +328,10 @@ def _process_eqn(
elif eqn.primitive.name == 'pjit':
name = eqn.params['name']
jaxpr = eqn.params['jaxpr'].jaxpr
elif eqn.primitive.name == 'shard_map':
name = 'shard_map'
jaxpr = eqn.params['jaxpr']
flops_multiplier = eqn.params['mesh'].size
elif eqn.primitive.name in ['custom_jvp_call', 'custom_vjp_call']:
name = eqn.primitive.name.replace('_call', '')
jaxpr = eqn.params['call_jaxpr']
Expand Down

0 comments on commit 399ec5b

Please sign in to comment.