Skip to content

Commit

Permalink
#7694: add support for non-4D tensors in moreh_sgd
Browse files (browse the repository at this point in the history)
  • Loading branch information
hschoi4448 committed Apr 22, 2024
1 parent 49350c9 commit 43fb6d6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_tt_tensor(tensor, device):
@pytest.mark.parametrize(
"shape",
(
(1, 1, 32, 32), # single
(32, 32), # single
(12, 6, 64, 64), # multiple tiles
),
)
Expand All @@ -49,15 +49,13 @@ def test_moreh_sgd(shape, lr, momentum, dampening, weight_decay, nesterov, momen
torch.manual_seed(0)

# make model and compute grad
N, C, H, W = shape

x_data = torch.rand((N, C, H, W)).to(torch.bfloat16)
y_data = torch.rand((N, C, H, W)).to(torch.bfloat16)
x_data = torch.rand(shape).to(torch.bfloat16)
y_data = torch.rand(shape).to(torch.bfloat16)

class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.weight = nn.Parameter(torch.randn(N, C, H, W).to(torch.bfloat16)).to(torch.bfloat16)
self.weight = nn.Parameter(torch.randn(shape).to(torch.bfloat16)).to(torch.bfloat16)

def forward(self, x):
return torch.mul(x, self.weight)
Expand Down
9 changes: 4 additions & 5 deletions tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ operation::ProgramWithCallbacks moreh_sgd_(
const CoreRange core_range) {
// split work
auto shape = param_in.get_legacy_shape();
auto N = shape[0];
auto C = shape[1];
auto H = shape[2];
auto W = shape[3];
auto H = shape[-2];
auto W = shape[-1];
auto num = param_in.volume() / H / W;
auto Ht = H / TILE_HEIGHT;
auto Wt = W / TILE_WIDTH;

bool has_momentum_buffer = momentum_buffer_in.has_value() && momentum_buffer_out.has_value();

uint32_t units_to_divide = N * C * Ht * Wt;
uint32_t units_to_divide = num * Ht * Wt;
uint32_t core_w = core_range.end.x - core_range.start.x + 1;
uint32_t core_h = core_range.end.y - core_range.start.y + 1;

Expand Down

0 comments on commit 43fb6d6

Please sign in to comment.