Skip to content

Commit

Permalink
dicst as module output for stride
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Aug 24, 2024
1 parent e4ce8d6 commit 3f087ab
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions modelforge/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,11 +582,22 @@ def check_strides(module, grad_input, grad_output):
print(
f"Grad input {i}: size {grad.size()}, strides {grad.stride()}"
)
for i, grad in enumerate(grad_output):
if grad is not None:
print(
f"Grad output {i}: size {grad.size()}, strides {grad.stride()}"
)
# Handle grad_output
if isinstance(grad_output, tuple) and isinstance(grad_output[0], dict):
# If the output is a dict wrapped in a tuple, extract the dict
grad_output = grad_output[0]
if isinstance(grad_output, dict):
for key, grad in grad_output.items():
if grad is not None:
print(
f"Grad output [{key}]: size {grad.size()}, strides {grad.stride()}"
)
else:
for i, grad in enumerate(grad_output):
if grad is not None:
print(
f"Grad output {i}: size {grad.size()}, strides {grad.stride()}"
)

# Register the full backward hook
if debugging is True:
Expand Down

0 comments on commit 3f087ab

Please sign in to comment.