# Patch summary for tests/gpu_tests/ddp_test.py
#
# Hunk 1 (imports): removes the import of FactorArguments / ScoreArguments
#   from kronfluence.arguments and adds pytest_score_arguments from
#   kronfluence.utils.common.score_arguments.
#
# Hunk 2 (score tests): deletes six commented-out score tests and adds active
#   versions of them plus a new test_aggregate_scores. The active versions:
#     * build their arguments with pytest_score_arguments() and then set
#       individual attributes, instead of constructing ScoreArguments with
#       explicit float64 dtypes;
#     * compare against the stored reference scores with the shared ATOL / RTOL
#       constants instead of per-test literal tolerances;
#     * test_self_scores additionally re-runs with
#       use_measurement_for_self_influence = True and compares against
#       OLD_SCORE_NAME + "_measurement".
#
# NOTE(review): test_per_module_pairwise_scores computes scores under
#   NEW_SCORE_NAME + "_per_module" but then loads and compares "ddp_qb" (the
#   name written by test_lr_pairwise_scores). The removed commented-out version
#   did the same. It looks like the per-module result is never actually
#   checked -- confirm whether this is intended.
# NOTE(review): in test_aggregate_scores the equivalence asserts run only when
#   LOCAL_RANK == 0 -- confirm that skipping the check on other ranks is
#   intended.
# NOTE(review): the reference scores (OLD_SCORE_NAME, "single_gpu_qb",
#   "single_gpu_train_agg", "single_gpu") are only loaded here, never computed;
#   presumably a separate single-GPU run must have produced them first --
#   verify against the test setup.
#
diff --git a/tests/gpu_tests/ddp_test.py b/tests/gpu_tests/ddp_test.py
index 2ff0499..5abf37c 100644
--- a/tests/gpu_tests/ddp_test.py
+++ b/tests/gpu_tests/ddp_test.py
@@ -9,8 +9,8 @@
 from torch.utils import data
 
 from kronfluence.analyzer import Analyzer, prepare_model
-from kronfluence.arguments import FactorArguments, ScoreArguments
 from kronfluence.utils.common.factor_arguments import pytest_factor_arguments
+from kronfluence.utils.common.score_arguments import pytest_score_arguments
 from kronfluence.utils.constants import (
     ALL_MODULE_NAME,
     COVARIANCE_FACTOR_NAMES,
@@ -169,233 +169,280 @@ def test_lambda_shared_matrices(self) -> None:
                 rtol=RTOL,
             )
 
-    # def test_pairwise_scores(self) -> None:
-    #     pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME)
-    #
-    #     score_args = ScoreArguments(
-    #         score_dtype=torch.float64,
-    #         per_sample_gradient_dtype=torch.float64,
-    #         precondition_dtype=torch.float64,
-    #     )
-    #     self.analyzer.compute_pairwise_scores(
-    #         scores_name=NEW_SCORE_NAME,
-    #         factors_name=OLD_FACTOR_NAME,
-    #         query_dataset=self.eval_dataset,
-    #         train_dataset=self.train_dataset,
-    #         train_indices=list(range(TRAIN_INDICES)),
-    #         query_indices=list(range(QUERY_INDICES)),
-    #         per_device_query_batch_size=12,
-    #         per_device_train_batch_size=512,
-    #         score_args=score_args,
-    #         overwrite_output_dir=True,
-    #     )
-    #     new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)
-    #
-    #     if LOCAL_RANK == 0:
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #     assert check_tensor_dict_equivalence(
-    #         pairwise_scores,
-    #         new_pairwise_scores,
-    #         atol=1e-5,
-    #         rtol=1e-3,
-    #     )
-    #
-    # def test_pairwise_partition_scores(self) -> None:
-    #     pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME)
-    #
-    #     score_args = ScoreArguments(
-    #         module_partition_size=2,
-    #         data_partition_size=2,
-    #         score_dtype=torch.float64,
-    #         per_sample_gradient_dtype=torch.float64,
-    #         precondition_dtype=torch.float64,
-    #     )
-    #     self.analyzer.compute_pairwise_scores(
-    #         scores_name=NEW_SCORE_NAME,
-    #         factors_name=OLD_FACTOR_NAME,
-    #         query_dataset=self.eval_dataset,
-    #         train_dataset=self.train_dataset,
-    #         train_indices=list(range(TRAIN_INDICES)),
-    #         query_indices=list(range(QUERY_INDICES)),
-    #         per_device_query_batch_size=12,
-    #         per_device_train_batch_size=512,
-    #         score_args=score_args,
-    #         overwrite_output_dir=True,
-    #     )
-    #     new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)
-    #
-    #     if LOCAL_RANK == 0:
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #     assert check_tensor_dict_equivalence(
-    #         pairwise_scores,
-    #         new_pairwise_scores,
-    #         atol=1e-5,
-    #         rtol=1e-3,
-    #     )
-    #
-    # def test_self_scores(self) -> None:
-    #     self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME)
-    #
-    #     score_args = ScoreArguments(
-    #         score_dtype=torch.float64,
-    #         per_sample_gradient_dtype=torch.float64,
-    #         precondition_dtype=torch.float64,
-    #     )
-    #     self.analyzer.compute_self_scores(
-    #         scores_name=NEW_SCORE_NAME,
-    #         factors_name=OLD_FACTOR_NAME,
-    #         train_dataset=self.train_dataset,
-    #         train_indices=list(range(TRAIN_INDICES)),
-    #         per_device_train_batch_size=512,
-    #         score_args=score_args,
-    #         overwrite_output_dir=True,
-    #     )
-    #     new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME)
-    #
-    #     if LOCAL_RANK == 0:
-    #         print(f"Previous score: {self_scores[ALL_MODULE_NAME]}")
-    #         print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_self_scores[ALL_MODULE_NAME]}")
-    #         print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}")
-    #     assert check_tensor_dict_equivalence(
-    #         self_scores,
-    #         new_self_scores,
-    #         atol=1e-5,
-    #         rtol=1e-3,
-    #     )
-    #
-    # def test_lr_pairwise_scores(self) -> None:
-    #     pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")
-    #
-    #     score_args = ScoreArguments(
-    #         query_gradient_rank=32,
-    #         score_dtype=torch.float64,
-    #         per_sample_gradient_dtype=torch.float64,
-    #         precondition_dtype=torch.float64,
-    #         query_gradient_svd_dtype=torch.float64,
-    #     )
-    #     self.analyzer.compute_pairwise_scores(
-    #         scores_name="ddp_qb",
-    #         factors_name=OLD_FACTOR_NAME,
-    #         query_dataset=self.eval_dataset,
-    #         train_dataset=self.train_dataset,
-    #         train_indices=list(range(TRAIN_INDICES)),
-    #         query_indices=list(range(QUERY_INDICES)),
-    #         per_device_query_batch_size=12,
-    #         per_device_train_batch_size=512,
-    #         score_args=score_args,
-    #         overwrite_output_dir=True,
-    #     )
-    #     new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb")
-    #
-    #     if LOCAL_RANK == 0:
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #     assert check_tensor_dict_equivalence(
-    #         pairwise_scores,
-    #         new_pairwise_scores,
-    #         atol=1e-3,
-    #         rtol=1e-1,
-    #     )
-    #
-    # def test_per_module_pairwise_scores(self) -> None:
-    #     pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")
-    #
-    #     score_args = ScoreArguments(
-    #         per_module_score=True,
-    #         score_dtype=torch.float64,
-    #         per_sample_gradient_dtype=torch.float64,
-    #         precondition_dtype=torch.float64,
-    #         query_gradient_svd_dtype=torch.float64,
-    #     )
-    #     self.analyzer.compute_pairwise_scores(
-    #         scores_name=NEW_SCORE_NAME + "_per_module",
-    #         factors_name=OLD_FACTOR_NAME,
-    #         query_dataset=self.eval_dataset,
-    #         train_dataset=self.train_dataset,
-    #         train_indices=list(range(TRAIN_INDICES)),
-    #         query_indices=list(range(QUERY_INDICES)),
-    #         per_device_query_batch_size=12,
-    #         per_device_train_batch_size=512,
-    #         score_args=score_args,
-    #         overwrite_output_dir=True,
-    #     )
-    #     new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb")
-    #
-    #     if LOCAL_RANK == 0:
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #     assert check_tensor_dict_equivalence(
-    #         pairwise_scores,
-    #         new_pairwise_scores,
-    #         atol=1e-3,
-    #         rtol=1e-1,
-    #     )
-    #
-    # def test_lr_accumulate_pairwise_scores(self) -> None:
-    #     pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")
-    #
-    #     score_args = ScoreArguments(
-    #         query_gradient_rank=32,
-    #         num_query_gradient_accumulations=3,
-    #         score_dtype=torch.float64,
-    #         per_sample_gradient_dtype=torch.float64,
-    #         precondition_dtype=torch.float64,
-    #         query_gradient_svd_dtype=torch.float64,
-    #     )
-    #     self.analyzer.compute_pairwise_scores(
-    #         scores_name="ddp_qb_agg",
-    #         factors_name=OLD_FACTOR_NAME,
-    #         query_dataset=self.eval_dataset,
-    #         train_dataset=self.train_dataset,
-    #         train_indices=list(range(TRAIN_INDICES)),
-    #         query_indices=list(range(QUERY_INDICES)),
-    #         per_device_query_batch_size=2,
-    #         per_device_train_batch_size=512,
-    #         score_args=score_args,
-    #         overwrite_output_dir=True,
-    #     )
-    #     new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb_agg")
-    #
-    #     if LOCAL_RANK == 0:
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
-    #         print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
-    #         print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][44]}")
-    #         print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][44]}")
-    #     assert check_tensor_dict_equivalence(
-    #         pairwise_scores,
-    #         new_pairwise_scores,
-    #         atol=1e-1,
-    #         rtol=1e-1,
-    #     )
+    def test_pairwise_scores(self) -> None:
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME)
+
+        score_args = pytest_score_arguments()
+        self.analyzer.compute_pairwise_scores(
+            scores_name=NEW_SCORE_NAME,
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=12,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+        assert check_tensor_dict_equivalence(
+            pairwise_scores,
+            new_pairwise_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+    def test_pairwise_partition_scores(self) -> None:
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME)
+
+        score_args = pytest_score_arguments()
+        score_args.module_partitions = 2
+        score_args.data_partitions = 2
+        self.analyzer.compute_pairwise_scores(
+            scores_name=NEW_SCORE_NAME,
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=12,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+        assert check_tensor_dict_equivalence(
+            pairwise_scores,
+            new_pairwise_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+    def test_self_scores(self) -> None:
+        score_args = pytest_score_arguments()
+        self.analyzer.compute_self_scores(
+            scores_name=NEW_SCORE_NAME,
+            factors_name=OLD_FACTOR_NAME,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME)
+        self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME)
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {self_scores[ALL_MODULE_NAME]}")
+            print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_self_scores[ALL_MODULE_NAME]}")
+            print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}")
+        assert check_tensor_dict_equivalence(
+            self_scores,
+            new_self_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+        score_args.use_measurement_for_self_influence = True
+        self.analyzer.compute_self_scores(
+            scores_name=NEW_SCORE_NAME + "_measurement",
+            factors_name=OLD_FACTOR_NAME,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME + "_measurement")
+        self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME + "_measurement")
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {self_scores[ALL_MODULE_NAME]}")
+            print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_self_scores[ALL_MODULE_NAME]}")
+            print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}")
+        assert check_tensor_dict_equivalence(
+            self_scores,
+            new_self_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+    def test_lr_pairwise_scores(self) -> None:
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")
+
+        score_args = pytest_score_arguments()
+        score_args.query_gradient_low_rank = 32
+        self.analyzer.compute_pairwise_scores(
+            scores_name="ddp_qb",
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=12,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb")
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+        assert check_tensor_dict_equivalence(
+            pairwise_scores,
+            new_pairwise_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+    def test_per_module_pairwise_scores(self) -> None:
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")
+
+        score_args = pytest_score_arguments()
+        score_args.compute_per_module_scores = True
+        self.analyzer.compute_pairwise_scores(
+            scores_name=NEW_SCORE_NAME + "_per_module",
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=12,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb")
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+        assert check_tensor_dict_equivalence(
+            pairwise_scores,
+            new_pairwise_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+    def test_lr_accumulate_pairwise_scores(self) -> None:
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")
+
+        score_args = pytest_score_arguments()
+        score_args.query_gradient_low_rank = 32
+        score_args.query_gradient_accumulation_steps = 3
+        self.analyzer.compute_pairwise_scores(
+            scores_name="ddp_qb_agg",
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=2,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb_agg")
+
+        if LOCAL_RANK == 0:
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
+            print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
+            print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][44]}")
+            print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][44]}")
+        assert check_tensor_dict_equivalence(
+            pairwise_scores,
+            new_pairwise_scores,
+            atol=ATOL,
+            rtol=RTOL,
+        )
+
+    def test_aggregate_scores(self) -> None:
+        score_args = pytest_score_arguments()
+        score_args.aggregate_train_gradients = True
+        self.analyzer.compute_pairwise_scores(
+            scores_name="ddp",
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=2,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp")
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_train_agg")
+
+        if LOCAL_RANK == 0:
+            assert check_tensor_dict_equivalence(
+                pairwise_scores,
+                new_pairwise_scores,
+                atol=ATOL,
+                rtol=RTOL,
+            )
+
+        score_args.aggregate_query_gradients = True
+        self.analyzer.compute_pairwise_scores(
+            scores_name="ddp",
+            factors_name=OLD_FACTOR_NAME,
+            query_dataset=self.eval_dataset,
+            train_dataset=self.train_dataset,
+            train_indices=list(range(TRAIN_INDICES)),
+            query_indices=list(range(QUERY_INDICES)),
+            per_device_query_batch_size=2,
+            per_device_train_batch_size=512,
+            score_args=score_args,
+            overwrite_output_dir=True,
+        )
+        new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp")
+        pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu")
+
+        if LOCAL_RANK == 0:
+            assert check_tensor_dict_equivalence(
+                pairwise_scores,
+                new_pairwise_scores,
+                atol=ATOL,
+                rtol=RTOL,
+            )
 
     @classmethod
     def tearDownClass(cls) -> None: