Skip to content

Commit

Permalink
Make test batch sizes more difficult (dataset sizes not evenly divisible by batch size)
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 12, 2024
1 parent 05daad4 commit 6b9b736
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
20 changes: 9 additions & 11 deletions tests/gpu_tests/cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
construct_mnist_mlp,
get_mnist_dataset,
)
from tests.gpu_tests.prepare_tests import TRAIN_INDICES, QUERY_INDICES
from tests.utils import check_tensor_dict_equivalence

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -33,9 +34,7 @@ def setUpClass(cls) -> None:
cls.model = cls.model.double()

cls.train_dataset = get_mnist_dataset(split="train", data_path="data")
cls.train_dataset = data.Subset(cls.train_dataset, indices=list(range(200)))
cls.eval_dataset = get_mnist_dataset(split="valid", data_path="data")
cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(100)))

cls.task = ClassificationTask()
cls.model = prepare_model(cls.model, cls.task)
Expand All @@ -54,7 +53,7 @@ def test_covariance_matrices(self) -> None:
factors_name=NEW_FACTOR_NAME,
dataset=self.train_dataset,
factor_args=factor_args,
per_device_batch_size=16,
per_device_batch_size=512,
overwrite_output_dir=True,
)
new_covariance_factors = self.analyzer.load_covariance_matrices(factors_name=NEW_FACTOR_NAME)
Expand Down Expand Up @@ -83,7 +82,7 @@ def test_lambda_matrices(self):
factors_name=NEW_FACTOR_NAME,
dataset=self.train_dataset,
factor_args=factor_args,
per_device_batch_size=16,
per_device_batch_size=512,
overwrite_output_dir=True,
load_from_factors_name=OLD_FACTOR_NAME,
)
Expand Down Expand Up @@ -114,10 +113,10 @@ def test_pairwise_scores(self) -> None:
factors_name=OLD_FACTOR_NAME,
query_dataset=self.eval_dataset,
train_dataset=self.train_dataset,
train_indices=list(range(42)),
query_indices=list(range(23)),
per_device_query_batch_size=2,
per_device_train_batch_size=4,
train_indices=list(range(TRAIN_INDICES)),
query_indices=list(range(QUERY_INDICES)),
per_device_query_batch_size=12,
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
Expand Down Expand Up @@ -145,15 +144,14 @@ def test_self_scores(self) -> None:
scores_name=NEW_SCORE_NAME,
factors_name=OLD_FACTOR_NAME,
train_dataset=self.train_dataset,
train_indices=list(range(42)),
per_device_train_batch_size=4,
train_indices=list(range(TRAIN_INDICES)),
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME)

self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME)
torch.set_printoptions(threshold=30_000)
print(f"Previous score: {self_scores[ALL_MODULE_NAME]}")
print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}")
print(f"New score: {new_self_scores[ALL_MODULE_NAME]}")
Expand Down
15 changes: 9 additions & 6 deletions tests/gpu_tests/prepare_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
)


# Pick difficult cases where the dataset is not perfectly divisible by batch size.
TRAIN_INDICES = 59_999
QUERY_INDICES = 50


def train() -> None:
assert torch.cuda.is_available()
device = torch.device("cuda")
Expand Down Expand Up @@ -76,9 +81,7 @@ def run_analysis() -> None:
model.load_state_dict(torch.load("model.pth"))

train_dataset = get_mnist_dataset(split="train", data_path="data")
train_dataset = Subset(train_dataset, indices=list(range(200)))
eval_dataset = get_mnist_dataset(split="valid", data_path="data")
eval_dataset = Subset(eval_dataset, indices=list(range(100)))

task = ClassificationTask()
model = model.double()
Expand All @@ -100,7 +103,7 @@ def run_analysis() -> None:
factors_name="single_gpu",
dataset=train_dataset,
factor_args=factor_args,
per_device_batch_size=32,
per_device_batch_size=512,
overwrite_output_dir=True,
)

Expand All @@ -114,8 +117,8 @@ def run_analysis() -> None:
factors_name="single_gpu",
query_dataset=eval_dataset,
train_dataset=train_dataset,
train_indices=list(range(59_999)),
query_indices=list(range(50)),
train_indices=list(range(TRAIN_INDICES)),
query_indices=list(range(QUERY_INDICES)),
per_device_query_batch_size=12,
per_device_train_batch_size=512,
score_args=score_args,
Expand All @@ -125,7 +128,7 @@ def run_analysis() -> None:
scores_name="single_gpu",
factors_name="single_gpu",
train_dataset=train_dataset,
train_indices=list(range(59_999)),
train_indices=list(range(TRAIN_INDICES)),
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
Expand Down

0 comments on commit 6b9b736

Please sign in to comment.