Skip to content

Commit

Permalink
add conformer modeldiff tests
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Nov 28, 2023
1 parent d727522 commit e61c1d9
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 0 deletions.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \
LibriSpeechConformerAttentionTemperatureWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def key_transform(k):
  """Rename one PyTorch state-dict key tuple into its JAX/Flax counterpart.

  ModuleList levels are dropped (Flax parameter trees have no such level),
  PyTorch layer names are mapped to their Flax-style equivalents
  (Linear -> Dense, Conv1d -> Conv, MHSAwithQS -> SelfAttention, the
  non-dynamically-quantizable output projection -> 'out'), and 'weight'
  becomes 'kernel'.
  """

  def _rename(part):
    # Priority order matters: a 'Linear' component is never also checked
    # against the 'weight' rule, mirroring the one-branch-per-component rule.
    if 'Linear' in part:
      if 'NonDynamicallyQuantizableLinear' in part:
        return 'out'  # attention output projection
      return part.replace('Linear', 'Dense')
    if 'Conv1d' in part:
      return part.replace('Conv1d', 'Conv')
    if 'MHSAwithQS' in part:
      return part.replace('MHSAwithQS', 'SelfAttention')
    if 'weight' in part:
      return part.replace('weight', 'kernel')
    return part

  return tuple(_rename(p) for p in k if 'ModuleList' not in p)


def sd_transform(sd):
  """Reshape attention projections of a (key-transformed) state dict.

  For keys inside an attention module, the fused in-projection ('Dense_0')
  is split into separate 'query'/'key'/'value' entries via a 3-way chunk,
  and the out-projection ('Dense_1') is renamed to 'out'. All other entries
  pass through unchanged.
  """
  transformed = {}
  for key, value in sd.items():
    in_attention = 'Attention' in ''.join(key)
    if in_attention and 'Dense_0' in key[-2]:
      # In-proj: one fused tensor holds q, k and v stacked along dim 0.
      prefix = key[:-2]
      for name, chunk in zip(('query', 'key', 'value'), value.chunk(3)):
        transformed[prefix + (name, key[-1])] = chunk
    elif in_attention and 'Dense_1' in key[-2]:
      # Out-proj.
      transformed[key[:-2] + ('out', key[-1])] = value
    else:
      transformed[key] = value
  return transformed


if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Identical inputs for both frameworks: two 320k-sample waveforms, with
  # the tail of the first example flagged as padding.
  waveforms = torch.randn(2, 320000)
  paddings = torch.zeros_like(waveforms)
  paddings[0, 200000:] = 1

  pyt_batch = {'inputs': (waveforms, paddings)}
  jax_batch = {'inputs': (waveforms.detach().numpy(), paddings.detach().numpy())}

  pytorch_model_kwargs = {
      'augmented_and_preprocessed_input_batch': pyt_batch,
      'model_state': None,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': None,
      'update_batch_norm': False,
  }

  jax_model_kwargs = {
      'augmented_and_preprocessed_input_batch': jax_batch,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': jax.random.PRNGKey(0),
      'update_batch_norm': False,
  }

  def _mask_padded_frames(out_outpad):
    # Zero out padded output frames before comparing the two models.
    # NOTE(review): assumes logits are (batch, frames, vocab) and paddings
    # (batch, frames) — confirm against the workload's model_fn.
    logits, logit_paddings = out_outpad
    return logits * (1 - logit_paddings[:, :, None])

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=key_transform,
      sd_transform=sd_transform,
      out_transform=_mask_padded_frames)
Empty file.
92 changes: 92 additions & 0 deletions tests/modeldiffs/librispeech_conformer_gelu/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerGeluWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \
LibriSpeechConformerGeluWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def key_transform(k):
  """Rename one PyTorch state-dict key tuple into its JAX/Flax counterpart.

  ModuleList levels are dropped (Flax parameter trees have no such level),
  PyTorch layer names are mapped to their Flax-style equivalents
  (Linear -> Dense, Conv1d -> Conv, MHSAwithQS -> SelfAttention, the
  non-dynamically-quantizable output projection -> 'out'), and 'weight'
  becomes 'kernel'.
  """

  def _rename(part):
    # Priority order matters: a 'Linear' component is never also checked
    # against the 'weight' rule, mirroring the one-branch-per-component rule.
    if 'Linear' in part:
      if 'NonDynamicallyQuantizableLinear' in part:
        return 'out'  # attention output projection
      return part.replace('Linear', 'Dense')
    if 'Conv1d' in part:
      return part.replace('Conv1d', 'Conv')
    if 'MHSAwithQS' in part:
      return part.replace('MHSAwithQS', 'SelfAttention')
    if 'weight' in part:
      return part.replace('weight', 'kernel')
    return part

  return tuple(_rename(p) for p in k if 'ModuleList' not in p)


def sd_transform(sd):
  """Reshape attention projections of a (key-transformed) state dict.

  For keys inside an attention module, the fused in-projection ('Dense_0')
  is split into separate 'query'/'key'/'value' entries via a 3-way chunk,
  and the out-projection ('Dense_1') is renamed to 'out'. All other entries
  pass through unchanged.
  """
  transformed = {}
  for key, value in sd.items():
    in_attention = 'Attention' in ''.join(key)
    if in_attention and 'Dense_0' in key[-2]:
      # In-proj: one fused tensor holds q, k and v stacked along dim 0.
      prefix = key[:-2]
      for name, chunk in zip(('query', 'key', 'value'), value.chunk(3)):
        transformed[prefix + (name, key[-1])] = chunk
    elif in_attention and 'Dense_1' in key[-2]:
      # Out-proj.
      transformed[key[:-2] + ('out', key[-1])] = value
    else:
      transformed[key] = value
  return transformed


if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Identical inputs for both frameworks: two 320k-sample waveforms, with
  # the tail of the first example flagged as padding.
  waveforms = torch.randn(2, 320000)
  paddings = torch.zeros_like(waveforms)
  paddings[0, 200000:] = 1

  pyt_batch = {'inputs': (waveforms, paddings)}
  jax_batch = {'inputs': (waveforms.detach().numpy(), paddings.detach().numpy())}

  pytorch_model_kwargs = {
      'augmented_and_preprocessed_input_batch': pyt_batch,
      'model_state': None,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': None,
      'update_batch_norm': False,
  }

  jax_model_kwargs = {
      'augmented_and_preprocessed_input_batch': jax_batch,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': jax.random.PRNGKey(0),
      'update_batch_norm': False,
  }

  def _mask_padded_frames(out_outpad):
    # Zero out padded output frames before comparing the two models.
    # NOTE(review): assumes logits are (batch, frames, vocab) and paddings
    # (batch, frames) — confirm against the workload's model_fn.
    logits, logit_paddings = out_outpad
    return logits * (1 - logit_paddings[:, :, None])

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=key_transform,
      sd_transform=sd_transform,
      out_transform=_mask_padded_frames)
Empty file.
92 changes: 92 additions & 0 deletions tests/modeldiffs/librispeech_conformer_layernorm/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os

# Disable GPU access for both jax and pytorch.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax
import torch

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerLayerNormWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \
LibriSpeechConformerLayerNormWorkload as PytWorkload
from tests.modeldiffs.diff import out_diff


def key_transform(k):
  """Rename one PyTorch state-dict key tuple into its JAX/Flax counterpart.

  ModuleList levels are dropped (Flax parameter trees have no such level),
  PyTorch layer names are mapped to their Flax-style equivalents
  (Linear -> Dense, Conv1d -> Conv, MHSAwithQS -> SelfAttention, the
  non-dynamically-quantizable output projection -> 'out'), and 'weight'
  becomes 'kernel'.
  """

  def _rename(part):
    # Priority order matters: a 'Linear' component is never also checked
    # against the 'weight' rule, mirroring the one-branch-per-component rule.
    if 'Linear' in part:
      if 'NonDynamicallyQuantizableLinear' in part:
        return 'out'  # attention output projection
      return part.replace('Linear', 'Dense')
    if 'Conv1d' in part:
      return part.replace('Conv1d', 'Conv')
    if 'MHSAwithQS' in part:
      return part.replace('MHSAwithQS', 'SelfAttention')
    if 'weight' in part:
      return part.replace('weight', 'kernel')
    return part

  return tuple(_rename(p) for p in k if 'ModuleList' not in p)


def sd_transform(sd):
  """Reshape attention projections of a (key-transformed) state dict.

  For keys inside an attention module, the fused in-projection ('Dense_0')
  is split into separate 'query'/'key'/'value' entries via a 3-way chunk,
  and the out-projection ('Dense_1') is renamed to 'out'. All other entries
  pass through unchanged.
  """
  transformed = {}
  for key, value in sd.items():
    in_attention = 'Attention' in ''.join(key)
    if in_attention and 'Dense_0' in key[-2]:
      # In-proj: one fused tensor holds q, k and v stacked along dim 0.
      prefix = key[:-2]
      for name, chunk in zip(('query', 'key', 'value'), value.chunk(3)):
        transformed[prefix + (name, key[-1])] = chunk
    elif in_attention and 'Dense_1' in key[-2]:
      # Out-proj.
      transformed[key[:-2] + ('out', key[-1])] = value
    else:
      transformed[key] = value
  return transformed


if __name__ == '__main__':
  # pylint: disable=locally-disabled, not-callable

  jax_workload = JaxWorkload()
  pytorch_workload = PytWorkload()

  # Identical inputs for both frameworks: two 320k-sample waveforms, with
  # the tail of the first example flagged as padding.
  waveforms = torch.randn(2, 320000)
  paddings = torch.zeros_like(waveforms)
  paddings[0, 200000:] = 1

  pyt_batch = {'inputs': (waveforms, paddings)}
  jax_batch = {'inputs': (waveforms.detach().numpy(), paddings.detach().numpy())}

  pytorch_model_kwargs = {
      'augmented_and_preprocessed_input_batch': pyt_batch,
      'model_state': None,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': None,
      'update_batch_norm': False,
  }

  jax_model_kwargs = {
      'augmented_and_preprocessed_input_batch': jax_batch,
      'mode': spec.ForwardPassMode.EVAL,
      'rng': jax.random.PRNGKey(0),
      'update_batch_norm': False,
  }

  def _mask_padded_frames(out_outpad):
    # Zero out padded output frames before comparing the two models.
    # NOTE(review): assumes logits are (batch, frames, vocab) and paddings
    # (batch, frames) — confirm against the workload's model_fn.
    logits, logit_paddings = out_outpad
    return logits * (1 - logit_paddings[:, :, None])

  out_diff(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=key_transform,
      sd_transform=sd_transform,
      out_transform=_mask_padded_frames)

0 comments on commit e61c1d9

Please sign in to comment.