
Commit
add placeholder for modeldiff tests
priyakasimbeg committed Dec 7, 2023
1 parent 04de789 commit 993c8da
Showing 6 changed files with 258 additions and 0 deletions.
Empty file.
86 changes: 86 additions & 0 deletions tests/modeldiffs/fastmri_layernorm/compare.py
@@ -0,0 +1,86 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
    FastMRIWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
    FastMRIWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def sd_transform(sd):
  """Rename PyTorch state-dict keys so they line up with the Jax model's
  parameter tree: drop ModuleList/Sequential container levels, renumber
  ConvBlock modules sequentially, and map Conv2d -> Conv,
  ConvTranspose2d -> ConvTranspose and weight -> kernel.
  """

  def sort_key(k):
    if k[0] == 'ModuleList_0':
      return (0, *k)
    if k[0] == 'ConvBlock_0':
      return (1, *k)
    if k[0] == 'ModuleList_1':
      return (2, *k)
    if k[0] == 'ModuleList_2':
      return (3, *k)

  keys = sorted(sd.keys(), key=sort_key)
  c = 0
  for idx, k in enumerate(keys):
    new_key = []
    for idx2, i in enumerate(k):
      if 'ModuleList' in i or 'Sequential' in i:
        continue
      if i.startswith('ConvBlock'):
        if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
          c += 1
        i = f'ConvBlock_{c}'
      if 'Conv2d' in i:
        i = i.replace('Conv2d', 'Conv')
      if 'ConvTranspose2d' in i:
        i = i.replace('ConvTranspose2d', 'ConvTranspose')
      if 'weight' in i:
        i = i.replace('weight', 'kernel')
      new_key.append(i)
    new_key = tuple(new_key)
    sd[new_key] = sd[k]
    del sd[k]
  return sd


key_transform = None
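# When run directly, the block below builds one random input batch, runs both
# the Jax and PyTorch workloads on it in EVAL mode, and compares the
# forward-pass outputs via out_diff (imported from tests.modeldiffs.diff).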
if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Test outputs for identical weights and inputs.
  image = torch.randn(2, 320, 320)

  jax_batch = {'inputs': image.detach().numpy()}
  pyt_batch = {'inputs': image}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pyt_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)

  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=sd_transform,
  )
Empty file.
86 changes: 86 additions & 0 deletions tests/modeldiffs/fastmri_model_size/compare.py
@@ -0,0 +1,86 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
    FastMRIModelSizeWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
    FastMRIModelSizeWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def sd_transform(sd):
  """Rename PyTorch state-dict keys so they line up with the Jax model's
  parameter tree: drop ModuleList/Sequential container levels, renumber
  ConvBlock modules sequentially, and map Conv2d -> Conv,
  ConvTranspose2d -> ConvTranspose and weight -> kernel.
  """

  def sort_key(k):
    if k[0] == 'ModuleList_0':
      return (0, *k)
    if k[0] == 'ConvBlock_0':
      return (1, *k)
    if k[0] == 'ModuleList_1':
      return (2, *k)
    if k[0] == 'ModuleList_2':
      return (3, *k)

  keys = sorted(sd.keys(), key=sort_key)
  c = 0
  for idx, k in enumerate(keys):
    new_key = []
    for idx2, i in enumerate(k):
      if 'ModuleList' in i or 'Sequential' in i:
        continue
      if i.startswith('ConvBlock'):
        if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
          c += 1
        i = f'ConvBlock_{c}'
      if 'Conv2d' in i:
        i = i.replace('Conv2d', 'Conv')
      if 'ConvTranspose2d' in i:
        i = i.replace('ConvTranspose2d', 'ConvTranspose')
      if 'weight' in i:
        i = i.replace('weight', 'kernel')
      new_key.append(i)
    new_key = tuple(new_key)
    sd[new_key] = sd[k]
    del sd[k]
  return sd


key_transform = None
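# When run directly, the block below builds one random input batch, runs both
# the Jax and PyTorch workloads on it in EVAL mode, and compares the
# forward-pass outputs via out_diff (imported from tests.modeldiffs.diff).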
if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Test outputs for identical weights and inputs.
  image = torch.randn(2, 320, 320)

  jax_batch = {'inputs': image.detach().numpy()}
  pyt_batch = {'inputs': image}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pyt_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)

  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=sd_transform,
  )
Empty file.
86 changes: 86 additions & 0 deletions tests/modeldiffs/fastmri_tanh/compare.py
@@ -0,0 +1,86 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
    FastMRIWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
    FastMRIWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def sd_transform(sd):
  """Rename PyTorch state-dict keys so they line up with the Jax model's
  parameter tree: drop ModuleList/Sequential container levels, renumber
  ConvBlock modules sequentially, and map Conv2d -> Conv,
  ConvTranspose2d -> ConvTranspose and weight -> kernel.
  """

  def sort_key(k):
    if k[0] == 'ModuleList_0':
      return (0, *k)
    if k[0] == 'ConvBlock_0':
      return (1, *k)
    if k[0] == 'ModuleList_1':
      return (2, *k)
    if k[0] == 'ModuleList_2':
      return (3, *k)

  keys = sorted(sd.keys(), key=sort_key)
  c = 0
  for idx, k in enumerate(keys):
    new_key = []
    for idx2, i in enumerate(k):
      if 'ModuleList' in i or 'Sequential' in i:
        continue
      if i.startswith('ConvBlock'):
        if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
          c += 1
        i = f'ConvBlock_{c}'
      if 'Conv2d' in i:
        i = i.replace('Conv2d', 'Conv')
      if 'ConvTranspose2d' in i:
        i = i.replace('ConvTranspose2d', 'ConvTranspose')
      if 'weight' in i:
        i = i.replace('weight', 'kernel')
      new_key.append(i)
    new_key = tuple(new_key)
    sd[new_key] = sd[k]
    del sd[k]
  return sd


key_transform = None
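# When run directly, the block below builds one random input batch, runs both
# the Jax and PyTorch workloads on it in EVAL mode, and compares the
# forward-pass outputs via out_diff (imported from tests.modeldiffs.diff).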
if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Test outputs for identical weights and inputs.
  image = torch.randn(2, 320, 320)

  jax_batch = {'inputs': image.detach().numpy()}
  pyt_batch = {'inputs': image}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pyt_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)

  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=sd_transform,
  )
