Skip to content

Commit

Permalink
add stride hook for backward
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Aug 24, 2024
1 parent 8da97be commit e4ce8d6
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions modelforge/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def forward(
per_molecule_squared_error,
batch.metadata.atomic_subsystem_counts,
prefactor=per_atom_prediction.shape[-1],
)
).contiguous()

return per_molecule_square_error_scaled

Expand Down Expand Up @@ -542,6 +542,7 @@ def __init__(
dataset_statistic: Dict[str, float],
training_parameter: TrainingParameters,
potential_seed: Optional[int] = None,
debugging: bool = True,
):
"""
Initializes the TrainingAdapter with the specified model and training configuration.
Expand Down Expand Up @@ -576,16 +577,21 @@ def __init__(

def check_strides(module, grad_input, grad_output):
print(f"Layer: {module.__class__.__name__}")

for i, grad in enumerate(grad_input):
if grad is not None:
print(
f"Grad input {i}: size {grad.size()}, strides {grad.stride()}"
)
for i, grad in enumerate(grad_output):
if grad is not None:
print(
f"Grad output {i}: size {grad.size()}, strides {grad.stride()}"
)

# Register the hook
for module in self.potential.modules():
module.register_backward_hook(check_strides)
# Register the full backward hook
if debugging is True:
for module in self.potential.modules():
module.register_full_backward_hook(check_strides)

self.calculate_predictions = CalculateProperties(
training_parameter.loss_parameter.loss_property
Expand Down

0 comments on commit e4ce8d6

Please sign in to comment.