def sd_transform(sd):
  """Rename the tuple keys of a flattened state dict, in place.

  Every key is a tuple of name components. The rewrite applied to each key:

  * components containing 'ModuleList' or 'Sequential' are dropped;
  * components starting with 'ConvBlock' are renumbered with one running
    counter, so blocks coming from different containers get consecutive
    indices;
  * 'Conv2d' -> 'Conv', 'ConvTranspose2d' -> 'ConvTranspose' and
    'weight' -> 'kernel' (presumably the naming used by the Jax model this
    script is diffed against -- confirm against `out_diff`).

  Args:
    sd: dict mapping tuples of str to parameter values. Mutated in place.

  Returns:
    The same `sd` object, now holding only the renamed keys.
  """
  # Order in which the top-level containers are visited; this order drives
  # the running ConvBlock counter below.
  top_level_rank = {
      'ModuleList_0': 0,
      'ConvBlock_0': 1,
      'ModuleList_1': 2,
      'ModuleList_2': 3,
  }

  def sort_key(k):
    # Unknown top-level names sort after the known ones. Returning None for
    # them (the previous fall-through) made sorted() raise TypeError.
    return (top_level_rank.get(k[0], len(top_level_rank)), *k)

  keys = sorted(sd, key=sort_key)

  # Collect the renamed entries separately so that a key which maps onto
  # itself is not deleted, and a new key cannot overwrite an entry that has
  # not been processed yet.
  renamed = {}
  c = 0
  for idx, k in enumerate(keys):
    new_key = []
    for idx2, part in enumerate(k):
      if 'ModuleList' in part or 'Sequential' in part:
        continue
      if part.startswith('ConvBlock'):
        # A differing prefix versus the previous key means a new block starts.
        if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
          c += 1
        part = f'ConvBlock_{c}'
      part = part.replace('Conv2d', 'Conv')
      part = part.replace('ConvTranspose2d', 'ConvTranspose')
      part = part.replace('weight', 'kernel')
      new_key.append(part)
    renamed[tuple(new_key)] = sd[k]

  # Keep the in-place contract: callers holding `sd` see the renamed keys.
  sd.clear()
  sd.update(renamed)
  return sd
# No per-key renaming hook is used by this comparison; see sd_transform.
key_transform = None

if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable
  # NOTE(review): the imports visible in this diff bind JaxWorkload and
  # PytWorkload to the base FastMRIWorkload, and this file carries the same
  # blob hash as the fastmri_tanh copy -- confirm whether a LayerNorm-specific
  # workload class was meant to be compared here.

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Feed the very same random input to both frameworks.
  image = torch.randn(2, 320, 320)
  jax_batch = {'inputs': image.detach().numpy()}
  pyt_batch = {'inputs': image}

  # Keyword arguments forwarded by out_diff to the PyTorch side.
  pytorch_model_kwargs = {
      'augmented_and_preprocessed_input_batch': pyt_batch,
      'model_state': None,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': None,
      'update_batch_norm': False,
  }

  # Keyword arguments forwarded by out_diff to the Jax side.
  jax_model_kwargs = {
      'augmented_and_preprocessed_input_batch': jax_batch,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': jax.random.PRNGKey(0),
      'update_batch_norm': False,
  }

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=sd_transform,
  )
def sd_transform(sd):
  """Rename the tuple keys of a flattened state dict, in place.

  Every key is a tuple of name components. The rewrite applied to each key:

  * components containing 'ModuleList' or 'Sequential' are dropped;
  * components starting with 'ConvBlock' are renumbered with one running
    counter, so blocks coming from different containers get consecutive
    indices;
  * 'Conv2d' -> 'Conv', 'ConvTranspose2d' -> 'ConvTranspose' and
    'weight' -> 'kernel' (presumably the naming used by the Jax model this
    script is diffed against -- confirm against `out_diff`).

  Args:
    sd: dict mapping tuples of str to parameter values. Mutated in place.

  Returns:
    The same `sd` object, now holding only the renamed keys.
  """
  # Order in which the top-level containers are visited; this order drives
  # the running ConvBlock counter below.
  top_level_rank = {
      'ModuleList_0': 0,
      'ConvBlock_0': 1,
      'ModuleList_1': 2,
      'ModuleList_2': 3,
  }

  def sort_key(k):
    # Unknown top-level names sort after the known ones. Returning None for
    # them (the previous fall-through) made sorted() raise TypeError.
    return (top_level_rank.get(k[0], len(top_level_rank)), *k)

  keys = sorted(sd, key=sort_key)

  # Collect the renamed entries separately so that a key which maps onto
  # itself is not deleted, and a new key cannot overwrite an entry that has
  # not been processed yet.
  renamed = {}
  c = 0
  for idx, k in enumerate(keys):
    new_key = []
    for idx2, part in enumerate(k):
      if 'ModuleList' in part or 'Sequential' in part:
        continue
      if part.startswith('ConvBlock'):
        # A differing prefix versus the previous key means a new block starts.
        if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
          c += 1
        part = f'ConvBlock_{c}'
      part = part.replace('Conv2d', 'Conv')
      part = part.replace('ConvTranspose2d', 'ConvTranspose')
      part = part.replace('weight', 'kernel')
      new_key.append(part)
    renamed[tuple(new_key)] = sd[k]

  # Keep the in-place contract: callers holding `sd` see the renamed keys.
  sd.clear()
  sd.update(renamed)
  return sd
# No per-key renaming hook is used by this comparison; see sd_transform.
key_transform = None

if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Feed the very same random input to both frameworks.
  image = torch.randn(2, 320, 320)
  jax_batch = {'inputs': image.detach().numpy()}
  pyt_batch = {'inputs': image}

  # Keyword arguments forwarded by out_diff to the PyTorch side.
  pytorch_model_kwargs = {
      'augmented_and_preprocessed_input_batch': pyt_batch,
      'model_state': None,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': None,
      'update_batch_norm': False,
  }

  # Keyword arguments forwarded by out_diff to the Jax side.
  jax_model_kwargs = {
      'augmented_and_preprocessed_input_batch': jax_batch,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': jax.random.PRNGKey(0),
      'update_batch_norm': False,
  }

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=sd_transform,
  )
def sd_transform(sd):
  """Rename the tuple keys of a flattened state dict, in place.

  Every key is a tuple of name components. The rewrite applied to each key:

  * components containing 'ModuleList' or 'Sequential' are dropped;
  * components starting with 'ConvBlock' are renumbered with one running
    counter, so blocks coming from different containers get consecutive
    indices;
  * 'Conv2d' -> 'Conv', 'ConvTranspose2d' -> 'ConvTranspose' and
    'weight' -> 'kernel' (presumably the naming used by the Jax model this
    script is diffed against -- confirm against `out_diff`).

  Args:
    sd: dict mapping tuples of str to parameter values. Mutated in place.

  Returns:
    The same `sd` object, now holding only the renamed keys.
  """
  # Order in which the top-level containers are visited; this order drives
  # the running ConvBlock counter below.
  top_level_rank = {
      'ModuleList_0': 0,
      'ConvBlock_0': 1,
      'ModuleList_1': 2,
      'ModuleList_2': 3,
  }

  def sort_key(k):
    # Unknown top-level names sort after the known ones. Returning None for
    # them (the previous fall-through) made sorted() raise TypeError.
    return (top_level_rank.get(k[0], len(top_level_rank)), *k)

  keys = sorted(sd, key=sort_key)

  # Collect the renamed entries separately so that a key which maps onto
  # itself is not deleted, and a new key cannot overwrite an entry that has
  # not been processed yet.
  renamed = {}
  c = 0
  for idx, k in enumerate(keys):
    new_key = []
    for idx2, part in enumerate(k):
      if 'ModuleList' in part or 'Sequential' in part:
        continue
      if part.startswith('ConvBlock'):
        # A differing prefix versus the previous key means a new block starts.
        if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
          c += 1
        part = f'ConvBlock_{c}'
      part = part.replace('Conv2d', 'Conv')
      part = part.replace('ConvTranspose2d', 'ConvTranspose')
      part = part.replace('weight', 'kernel')
      new_key.append(part)
    renamed[tuple(new_key)] = sd[k]

  # Keep the in-place contract: callers holding `sd` see the renamed keys.
  sd.clear()
  sd.update(renamed)
  return sd
# No per-key renaming hook is used by this comparison; see sd_transform.
key_transform = None

if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable
  # NOTE(review): the imports visible in this diff bind JaxWorkload and
  # PytWorkload to the base FastMRIWorkload, and this file carries the same
  # blob hash as the fastmri_layernorm copy -- confirm whether a Tanh-specific
  # workload class was meant to be compared here.

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Feed the very same random input to both frameworks.
  image = torch.randn(2, 320, 320)
  jax_batch = {'inputs': image.detach().numpy()}
  pyt_batch = {'inputs': image}

  # Keyword arguments forwarded by out_diff to the PyTorch side.
  pytorch_model_kwargs = {
      'augmented_and_preprocessed_input_batch': pyt_batch,
      'model_state': None,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': None,
      'update_batch_norm': False,
  }

  # Keyword arguments forwarded by out_diff to the Jax side.
  jax_model_kwargs = {
      'augmented_and_preprocessed_input_batch': jax_batch,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': jax.random.PRNGKey(0),
      'update_batch_norm': False,
  }

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=sd_transform,
  )