fix conv2d per sample gradient compute
sangkeun00 committed Nov 27, 2023
1 parent de31a39 commit ee7f9e7
Showing 2 changed files with 101 additions and 1 deletion.
9 changes: 8 additions & 1 deletion analog/logging.py
@@ -17,7 +17,14 @@ def compute_per_sample_gradient(fwd, bwd, module):
         return reduce(outer_product, "n ... i j -> n i j", "sum")
     elif isinstance(module, nn.Conv2d):
         bsz = fwd.shape[0]
-        fwd_unfold = torch.nn.functional.unfold(fwd, module.kernel_size)
+        fwd_unfold = torch.nn.functional.unfold(
+            fwd,
+            module.kernel_size,
+            dilation=module.dilation,
+            padding=module.padding,
+            stride=module.stride,
+        )
         fwd_unfold = fwd_unfold.reshape(bsz, fwd_unfold.shape[1], -1)
         bwd = bwd.reshape(bsz, -1, fwd_unfold.shape[-1])
         grad = torch.einsum("ijk,ilk->ijl", bwd, fwd_unfold)
+        shape = [bsz] + list(module.weight.shape)
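Why the new arguments matter: the per-sample gradient of a Conv2d weight is computed here by contracting the output gradient with the unfolded input patches (the einsum above), so unfold must reproduce exactly the patches the convolution consumed. Called with only the kernel size, unfold defaults to padding=0, dilation=1, stride=1, and the number of extracted patches then disagrees with the number of output positions whenever the module uses padding, dilation, or a non-unit stride, breaking the subsequent reshape. A minimal sketch of the mismatch for a padded convolution (illustrative, not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
x = torch.randn(4, 3, 8, 8)

out = conv(x)  # (4, 8, 8, 8): 8 * 8 = 64 output positions per sample
old = F.unfold(x, conv.kernel_size)  # padding ignored -> (4, 27, 36)
new = F.unfold(
    x,
    conv.kernel_size,
    dilation=conv.dilation,
    padding=conv.padding,
    stride=conv.stride,
)  # (4, 27, 64): one patch per output position

print(old.shape[-1], new.shape[-1], out.flatten(2).shape[-1])  # 36 64 64

With the module's own settings passed through, the patch count matches the 64 output positions, so bwd.reshape(bsz, -1, fwd_unfold.shape[-1]) succeeds and the einsum pairs each output gradient with the patch that produced it.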
93 changes: 93 additions & 0 deletions tests/logger/test_conv.py
@@ -0,0 +1,93 @@
import unittest
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from analog import AnaLog


class Simple2DCNN(nn.Module):
def __init__(self, num_channels, hidden_size, num_classes):
super(Simple2DCNN, self).__init__()
self.conv1 = nn.Conv2d(
num_channels, hidden_size, kernel_size=3, stride=1, padding=1, bias=False
)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(
hidden_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False
)
        self.fc = nn.Linear(
            hidden_size * 4 * 4, num_classes
        )  # Assuming input size is (num_channels, 8, 8): one 2x2 pool gives 4x4

def forward(self, x):
out = self.pool(self.relu(self.conv1(x)))
out = self.relu(self.conv2(out))
out = out.view(out.size(0), -1) # Flatten the output
out = self.fc(out)
return out


class Test2DCNNGradients(unittest.TestCase):
def setUp(self):
num_channels = 1
hidden_size = 8
num_classes = 10
self.model = Simple2DCNN(num_channels, hidden_size, num_classes)
self.func_model = Simple2DCNN(num_channels, hidden_size, num_classes)
self.func_model.load_state_dict(copy.deepcopy(self.model.state_dict()))
self.func_params = dict(self.func_model.named_parameters())
self.func_buffers = dict(self.func_model.named_buffers())

self.model.eval()
self.func_model.eval()

def test_per_sample_gradient(self):
# Instantiate AnaLog
analog = AnaLog(project="test")
analog.watch(self.model)

# Input and target for batch size of 4
inputs = torch.randn(
4, 1, 8, 8
) # Dummy input for 2D CNN (batch_size, channels, height, width)
labels = torch.tensor([1, 0, 1, 0]) # Dummy labels
batch = (inputs, labels)

# functorch per-sample gradient
def compute_loss_func(_params, _buffers, _batch):
_output = torch.func.functional_call(
self.func_model,
(_params, _buffers),
args=(_batch[0].unsqueeze(0),),
)
_loss = F.cross_entropy(_output, _batch[1].unsqueeze(0))
return _loss

func_compute_grad = torch.func.grad(compute_loss_func, has_aux=False)

grads_dict = torch.func.vmap(
func_compute_grad,
in_dims=(None, None, 0),
randomness="same",
)(self.func_params, self.func_buffers, batch)

# Forward pass with original model
with analog(data_id=inputs, log=["grad"], hessian=False, save=False):
self.model.zero_grad()
output = self.model(inputs)
loss = F.cross_entropy(output, labels, reduction="sum")
loss.backward()
analog_grads_dict = analog.get_log()

for module_name in analog_grads_dict:
analog_grad = analog_grads_dict[module_name]
func_grad = grads_dict[module_name + ".weight"]
self.assertTrue(torch.allclose(analog_grad, func_grad, atol=1e-6))


if __name__ == "__main__":
unittest.main()
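For reference, the functorch comparison above can also be cross-checked without torch.func by looping over the batch — a minimal sketch (illustrative, not part of the test; per_sample_grads_loop is a hypothetical helper), reusing the Simple2DCNN model defined above:

import torch
import torch.nn.functional as F

def per_sample_grads_loop(model, inputs, labels, param_name="conv1.weight"):
    # One backward pass per example; stack that parameter's gradients.
    param = dict(model.named_parameters())[param_name]
    grads = []
    for i in range(inputs.shape[0]):
        model.zero_grad()
        loss = F.cross_entropy(model(inputs[i : i + 1]), labels[i : i + 1])
        loss.backward()
        grads.append(param.grad.detach().clone())
    return torch.stack(grads)  # (batch_size, *param.shape)

Each slice keeps the batch dimension (inputs[i : i + 1]), so cross_entropy with its default mean reduction returns the loss of that single example, matching the per-sample losses differentiated by torch.func.grad.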
