Skip to content

Commit

Permalink
Refactor diffs module to use ModelDiffRunner class
Browse files Browse the repository at this point in the history
  • Loading branch information
Sujata Goswami authored and Sujata Goswami committed Nov 27, 2024
1 parent 86d2a0d commit b9b6e66
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/modeldiffs/criteo1tb/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
Criteo1TbDlrmSmallWorkload as JaxWorkload
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner



def key_transform(k):
Expand Down Expand Up @@ -74,11 +75,11 @@ def sd_transform(sd):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=key_transform,
sd_transform=sd_transform,
out_transform=None)
out_transform=None).run()
38 changes: 38 additions & 0 deletions tests/modeldiffs/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,41 @@ def out_diff(jax_workload,

print(f'Max fprop difference between jax and pytorch: {max_diff}')
print(f'Min fprop difference between jax and pytorch: {min_diff}')


class ModelDiffRunner:
def __init__(self, jax_workload,
pytorch_workload,
jax_model_kwargs,
pytorch_model_kwargs,
key_transform=None,
sd_transform=None,
out_transform=None) -> None:
"""Initializes the instance based on diffing logic.
Args:
jax_workload: Workload implementation using JAX
pytorch_workload: Workload implementation using PyTorch
jax_model_kwargs: Arguments to be used for model_fn in jax workload
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
key_transform: Transformation function for keys.
sd_transform: Transformation function for State Dictionary.
out_transform: Transformation function for the output.
"""

self.jax_workload = jax_workload
self.pytorch_workload = pytorch_workload
self.jax_model_kwargs = jax_model_kwargs
self.pytorch_model_kwargs = pytorch_model_kwargs
self.key_transform = key_transform
self.sd_transform = sd_transform
self.out_transform = out_transform

def run(self):
out_diff(self.jax_workload,
self.pytorch_workload,
self.jax_model_kwargs,
self.pytorch_model_kwargs,
self.key_transform,
self.sd_transform,
self.out_transform)

0 comments on commit b9b6e66

Please sign in to comment.