Skip to content

Commit

Permalink
Add more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 22, 2024
1 parent 12df935 commit dbbcdc4
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 36 deletions.
42 changes: 9 additions & 33 deletions tests/test_per_sample_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def for_loop_per_sample_gradient(
"conv",
"conv_bn",
"bert",
"gpt",
],
)
@pytest.mark.parametrize("use_measurement", [True, False])
Expand Down Expand Up @@ -177,7 +178,11 @@ def test_for_loop_per_sample_gradient_equivalence(
task=task,
use_measurement=use_measurement,
)

for i in range(num_batches):
if "lm_head" in for_loop_per_sample_gradients[i]:
del for_loop_per_sample_gradients[i]["lm_head"]

assert check_tensor_dict_equivalence(
per_sample_gradients[i],
for_loop_per_sample_gradients[i],
Expand All @@ -192,6 +197,7 @@ def test_for_loop_per_sample_gradient_equivalence(
"mlp",
"repeated_mlp",
"conv",
"gpt",
],
)
@pytest.mark.parametrize("train_size", [32])
Expand Down Expand Up @@ -271,9 +277,8 @@ def test_lambda_equivalence(
)


@pytest.mark.parametrize("seed", [0])
def test_precondition_gradient(
seed: int,
seed: int = 0,
) -> None:
input_dim = 128
output_dim = 256
Expand Down Expand Up @@ -326,9 +331,8 @@ def test_precondition_gradient(
assert torch.allclose(raw_results, results, atol=1e-5, rtol=1e-3)


@pytest.mark.parametrize("seed", [0])
def test_query_gradient_svd(
seed: int,
seed: int = 0,
) -> None:
input_dim = 2048
output_dim = 1024
Expand Down Expand Up @@ -388,7 +392,6 @@ def test_query_gradient_svd(
assert torch.allclose(score, lr_score_reconst_matmul)

# These should be able to avoid explicit reconstruction.

# This should be used when input_dim > output_dim.
intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient)
final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat)
Expand Down Expand Up @@ -474,9 +477,8 @@ def test_query_gradient_svd_reconst(
assert intermediate.numel() <= reconst_numel


@pytest.mark.parametrize("seed", [0])
def test_compute_score_matmul(
seed: int,
seed: int = 0,
) -> None:
input_dim = 1024
output_dim = 2048
Expand All @@ -495,29 +497,3 @@ def test_compute_score_matmul(
assert torch.allclose(score, unsqueeze_score)
path = opt_einsum.contract_path("t...,q...->tq", gradient, new_gradient)
print(path)


@pytest.mark.parametrize("seed", [0])
def test_compute_score_fast_matmul(
seed: int,
) -> None:
input_dim = 512
output_dim = 1024
seq_len = 32
batch_dim = 8
query_batch_dim = 16

set_seed(seed)

input_activation = torch.rand(size=(batch_dim, seq_len, input_dim), dtype=torch.float64)
output_gradient = torch.rand(size=(batch_dim, seq_len, output_dim), dtype=torch.float64)
per_sample_gradient = opt_einsum.contract("b...i,b...o->bio", output_gradient, input_activation)
gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64)
score = opt_einsum.contract("toi,qoi->tq", per_sample_gradient, gradient)
print(score)

all_score = opt_einsum.contract("tco,tci,qoi->tq", output_gradient, input_activation, gradient)
assert torch.allclose(score, all_score)

path = opt_einsum.contract_path("tco,tci,qoi->tq", output_gradient, input_activation, gradient, optimize="optimal")
print(path)
71 changes: 70 additions & 1 deletion tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tests.utils import prepare_test


def test_regression(
def test_mlp_regression(
test_name: str = "mlp",
strategy: str = "ekfac",
seed: int = 0,
Expand Down Expand Up @@ -75,3 +75,72 @@ def test_regression(
)
scores = analyzer.load_pairwise_scores("pairwise")
assert round(scores["all_modules"].sum().item(), 2) == 145.53


def test_conv_regression(
test_name: str = "conv",
strategy: str = "ekfac",
seed: int = 0,
train_size: int = 32,
query_size: int = 16,
) -> None:
model, train_dataset, query_dataset, data_collator, task = prepare_test(
test_name=test_name,
train_size=train_size,
query_size=query_size,
seed=seed,
)
assert round(list(model.named_parameters())[0][1].sum().item(), 2) == -0.75

model = prepare_model(model=model, task=task)
analyzer = Analyzer(
analysis_name=f"pytest_regression_{test_name}",
model=model,
task=task,
disable_model_save=True,
cpu=True,
)
kwargs = DataLoaderKwargs(collate_fn=data_collator)
factor_args = FactorArguments(strategy=strategy, use_empirical_fisher=True)
analyzer.fit_covariance_matrices(
factors_name=f"pytest_{test_name}",
dataset=train_dataset,
per_device_batch_size=1,
dataloader_kwargs=kwargs,
factor_args=factor_args,
overwrite_output_dir=True,
)
covariance_matrices = analyzer.load_covariance_matrices(f"pytest_{test_name}")
assert round(torch.sum(covariance_matrices["activation_covariance"]["0"] / train_size).item(), 2) == 42299.42

analyzer.perform_eigendecomposition(
factors_name=f"pytest_{test_name}",
factor_args=factor_args,
overwrite_output_dir=True,
)
eigen_factors = analyzer.load_eigendecomposition(f"pytest_{test_name}")
assert round(eigen_factors["activation_eigenvectors"]["0"].sum().item(), 2) == 4.34

analyzer.fit_lambda_matrices(
factors_name=f"pytest_{test_name}",
dataset=train_dataset,
per_device_batch_size=1,
dataloader_kwargs=kwargs,
factor_args=factor_args,
overwrite_output_dir=True,
)
lambda_matrices = analyzer.load_lambda_matrices(f"pytest_{test_name}")
assert round((lambda_matrices["lambda_matrix"]["0"] / train_size).sum().item(), 2) == 0.18

analyzer.compute_pairwise_scores(
scores_name="pairwise",
factors_name=f"pytest_{test_name}",
query_dataset=query_dataset,
per_device_query_batch_size=1,
train_dataset=train_dataset,
per_device_train_batch_size=1,
dataloader_kwargs=kwargs,
overwrite_output_dir=True,
)
scores = analyzer.load_pairwise_scores("pairwise")
assert round(scores["all_modules"].sum().item(), 2) == 6268.84
4 changes: 2 additions & 2 deletions tests/testable_tasks/language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def compute_measurement(
def tracked_modules(self) -> List[str]:
total_modules = []

for i in range(4):
for i in range(5):
total_modules.append(f"transformer.h.{i}.attn.c_attn")
total_modules.append(f"transformer.h.{i}.attn.c_proj")

for i in range(4):
for i in range(5):
total_modules.append(f"transformer.h.{i}.mlp.c_fc")
total_modules.append(f"transformer.h.{i}.mlp.c_proj")

Expand Down
3 changes: 3 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def reshape_parameter_gradient_to_module_matrix(
remove_gradient: bool = True,
) -> torch.Tensor:
if isinstance(module, nn.Linear):
if module_name == "lm_head":
# Edge case for small GPT model.
return
gradient_matrix = gradient_dict[module_name + ".weight"]
if remove_gradient:
del gradient_dict[module_name + ".weight"]
Expand Down

0 comments on commit dbbcdc4

Please sign in to comment.