
Commit

Minimize size for dot product
pomonam committed Jul 9, 2024
1 parent 0535cf4 commit 6645ca9
Showing 5 changed files with 5 additions and 37 deletions.
10 changes: 1 addition & 9 deletions kronfluence/module/conv2d.py
@@ -211,12 +211,4 @@ def compute_self_measurement_score(
         input_activation = self._flatten_input_activation(input_activation=input_activation)
         input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
         output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
-        if self.einsum_expression is None:
-            self.einsum_expression = contract_expression(
-                "bio,bci,bco->b",
-                preconditioned_gradient.shape,
-                output_gradient.shape,
-                input_activation.shape,
-                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-            )
-        return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+        return contract("bio,bci,bco->b", preconditioned_gradient, output_gradient, input_activation)
12 changes: 2 additions & 10 deletions kronfluence/module/linear.py
@@ -103,7 +103,7 @@ def compute_pairwise_score(
minimize = "size"
else:
expr = "qio,b...i,b...o->qb"
minimize = "flops"
minimize = "size"
self.einsum_expression = contract_expression(
expr,
preconditioned_gradient.shape,
@@ -117,12 +117,4 @@ def compute_self_measurement_score(
         self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor
     ) -> torch.Tensor:
         input_activation = self._flatten_input_activation(input_activation=input_activation)
-        if self.einsum_expression is None:
-            self.einsum_expression = contract_expression(
-                "bio,b...i,b...o->b",
-                preconditioned_gradient.shape,
-                output_gradient.shape,
-                input_activation.shape,
-                optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-            )
-        return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
+        return contract("bio,b...i,b...o->b", preconditioned_gradient, output_gradient, input_activation)
2 changes: 1 addition & 1 deletion kronfluence/module/tracker/pairwise_score.py
@@ -88,7 +88,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                     input_activation=cached_activation.to(device=output_gradient.device),
                     output_gradient=output_gradient,
                 )
-                del cached_activation, output_gradient
+                del self.cached_activations, cached_activation, output_gradient
                 self.clear_all_cache()
             else:
                 per_sample_gradient = self.module.compute_per_sample_gradient(
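
The extra del of self.cached_activations matters because removing only the local alias does not release the tensor while the tracker attribute still points at it. A standalone sketch with hypothetical names (not the actual tracker) illustrates the reference-counting point:

# Hypothetical, standalone illustration; the real tracker stores activations
# captured in a forward hook.
import torch

class TinyTracker:
    def __init__(self) -> None:
        self.cached_activations = torch.randn(1024, 1024)

    def consume(self) -> None:
        cached_activation = self.cached_activations  # local alias
        del cached_activation                        # tensor still alive via the attribute
        assert self.cached_activations is not None
        del self.cached_activations                  # drop the last reference so memory can be freed

TinyTracker().consume()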
2 changes: 1 addition & 1 deletion kronfluence/module/tracker/precondition.py
@@ -236,7 +236,7 @@ def accumulate_iterations(self) -> None:
             self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat(
                 (accumulated_gradient, gradient), dim=0
             ).contiguous()
-            del gradient, self.module.storage[PRECONDITIONED_GRADIENT_NAME]
+            del self.module.storage[PRECONDITIONED_GRADIENT_NAME], gradient
             self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None
 
     @torch.no_grad()
16 changes: 0 additions & 16 deletions kronfluence/score/dot_product.py
@@ -73,10 +73,6 @@ def compute_dot_products_with_loader(
     total_steps = 0
     enable_amp = score_args.amp_dtype is not None
 
-    print("Start")
-    print(torch.cuda.memory_allocated())
-    print(torch.cuda.memory_reserved())
-
     with tqdm(
         total=len(train_loader),
         desc="Computing pairwise scores (training gradient)",
@@ -86,10 +82,6 @@
         for batch in train_loader:
             batch = send_to_device(tensor=batch, device=state.device)
 
-            print("Begin")
-            print(torch.cuda.memory_allocated())
-            print(torch.cuda.memory_reserved())
-
             with no_sync(model=model, state=state):
                 model.zero_grad(set_to_none=True)
                 with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype):
@@ -103,10 +95,6 @@
             if factor_args.has_shared_parameters:
                 finalize_iteration(model=model, tracked_module_names=tracked_module_names)
 
-            print("Middle")
-            print(torch.cuda.memory_allocated())
-            print(torch.cuda.memory_reserved())
-
             with torch.no_grad():
                 if score_args.compute_per_module_scores:
                     for module in cached_module_lst:
@@ -137,10 +125,6 @@
             total_steps += 1
             pbar.update(1)
 
-    print("End")
-    print(torch.cuda.memory_allocated())
-    print(torch.cuda.memory_reserved())
-
     model.zero_grad(set_to_none=True)
     finalize_all_iterations(model=model, tracked_module_names=tracked_module_names)
     set_mode(
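
The deleted lines were ad-hoc CUDA memory printouts used while chasing the peak-memory issue this commit addresses. If similar instrumentation is needed again, a throwaway helper along these lines (hypothetical, not part of the repository) keeps prints out of the hot loop:

# Hypothetical debugging helper, not part of kronfluence: report current and
# peak CUDA memory at a labeled checkpoint.
import torch

def report_cuda_memory(label: str) -> None:
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print(f"[{label}] allocated={allocated:.1f} MiB reserved={reserved:.1f} MiB peak={peak:.1f} MiB")

# Example usage: reset the peak counter before the loop, then report once at the
# end instead of printing on every batch.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
report_cuda_memory("pairwise scores")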
