Skip to content

Commit

Permalink
Fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 15, 2024
1 parent 18f5a00 commit 3e8550b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 5 additions & 4 deletions kronfluence/factor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
GRADIENT_EIGENVALUES_NAME,
GRADIENT_EIGENVECTORS_NAME,
HEURISTIC_DAMPING_SCALE,
LAMBDA_DTYPE,
LAMBDA_MATRIX_NAME,
NUM_LAMBDA_PROCESSED,
)
Expand Down Expand Up @@ -196,7 +197,7 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
return True

def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device)
lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=LAMBDA_DTYPE, device=device)
lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
damping_factor = score_args.damping_factor
if damping_factor is None:
Expand Down Expand Up @@ -256,8 +257,8 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
storage[GRADIENT_EIGENVECTORS_NAME] = (
storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
)
activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=torch.float64, device=device)
gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=torch.float64, device=device)
activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=LAMBDA_DTYPE, device=device)
gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=LAMBDA_DTYPE, device=device)
lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0)
damping_factor = score_args.damping_factor
if damping_factor is None:
Expand Down Expand Up @@ -327,7 +328,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
)
storage[ACTIVATION_EIGENVALUES_NAME] = None
storage[GRADIENT_EIGENVALUES_NAME] = None
lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device)
lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=LAMBDA_DTYPE, device=device)
lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
damping_factor = score_args.damping_factor
if damping_factor is None:
Expand Down
4 changes: 4 additions & 0 deletions kronfluence/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# The total iteration step to synchronize the process when using distributed setting.
DISTRIBUTED_SYNC_INTERVAL = 1_000

# The scale for the heuristic damping term.
HEURISTIC_DAMPING_SCALE = 0.1

# Activation covariance matrix.
Expand Down Expand Up @@ -76,3 +77,6 @@

# The dictionary key for storing summed scores.
ALL_MODULE_NAME = "all_modules"

# Data type used when computing the reciprocal of the Lambda matrix and eigenvalues.
LAMBDA_DTYPE = torch.float64

0 comments on commit 3e8550b

Please sign in to comment.