From 3e8550b3c011e0d6415ebac5b8de61bf9b8769cf Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 15 Jul 2024 16:31:06 -0400 Subject: [PATCH] Fix linting --- kronfluence/factor/config.py | 9 +++++---- kronfluence/utils/constants.py | 4 ++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index ac8d32f..39190d2 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -10,6 +10,7 @@ GRADIENT_EIGENVALUES_NAME, GRADIENT_EIGENVECTORS_NAME, HEURISTIC_DAMPING_SCALE, + LAMBDA_DTYPE, LAMBDA_MATRIX_NAME, NUM_LAMBDA_PROCESSED, ) @@ -196,7 +197,7 @@ def requires_lambda_matrices_for_precondition(self) -> bool: return True def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=LAMBDA_DTYPE, device=device) lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) damping_factor = score_args.damping_factor if damping_factor is None: @@ -256,8 +257,8 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) storage[GRADIENT_EIGENVECTORS_NAME] = ( storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous() ) - activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=torch.float64, device=device) - gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=torch.float64, device=device) + activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=LAMBDA_DTYPE, device=device) + gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=LAMBDA_DTYPE, device=device) lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) damping_factor = score_args.damping_factor if damping_factor is None: @@ -327,7 +328,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) ) storage[ACTIVATION_EIGENVALUES_NAME] = None storage[GRADIENT_EIGENVALUES_NAME] = None - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=LAMBDA_DTYPE, device=device) lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) damping_factor = score_args.damping_factor if damping_factor is None: diff --git a/kronfluence/utils/constants.py b/kronfluence/utils/constants.py index 166a26d..1f75cb4 100644 --- a/kronfluence/utils/constants.py +++ b/kronfluence/utils/constants.py @@ -18,6 +18,7 @@ # The total iteration step to synchronize the process when using distributed setting. DISTRIBUTED_SYNC_INTERVAL = 1_000 +# The scale for the heuristic damping term. HEURISTIC_DAMPING_SCALE = 0.1 # Activation covariance matrix. @@ -76,3 +77,6 @@ # The dictionary key for storing summed scores. ALL_MODULE_NAME = "all_modules" + +# Data type when computing the reciprocal. +LAMBDA_DTYPE = torch.float64