#4379: support moreh sgd
hschoi4448 committed Jan 25, 2024
1 parent da3c5fa commit 59d6e69
Showing 11 changed files with 1,130 additions and 0 deletions.
152 changes: 152 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/test_moreh_sgd.py
@@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
import torch.optim as optim

import tt_lib as ttl
import pytest
from models.utility_functions import (
comp_allclose_and_pcc,
)
from loguru import logger


def create_tt_tensor(tensor, device):
ret = (
ttl.tensor.Tensor(
tensor,
ttl.tensor.DataType.BFLOAT16,
)
.to(ttl.tensor.Layout.TILE)
.to(device)
)

return ret


@pytest.mark.parametrize(
"shape",
(
(1, 1, 32, 32), # single
(12, 6, 64, 64), # multiple tiles
),
)
@pytest.mark.parametrize("lr", [3.0])
@pytest.mark.parametrize("momentum", [0.0, 7.7])
@pytest.mark.parametrize("dampening", [0.0, 0.5])
@pytest.mark.parametrize("weight_decay", [0.0, 2.2])
@pytest.mark.parametrize("nesterov", [True, False], ids=["NESTEROV_TRUE", "NESTEROV_FALSE"])
@pytest.mark.parametrize(
"momentum_initialized", [True, False], ids=["MOMENTUM_INITIALIZED", "MOMENTUM_NOT_INITIALIZED"]
)
def test_moreh_sgd(shape, lr, momentum, dampening, weight_decay, nesterov, momentum_initialized, device):
if nesterov and (momentum <= 0 or dampening != 0):
pytest.skip()

torch.manual_seed(0)

# make model and compute grad
N, C, H, W = shape

x_data = torch.rand((N, C, H, W)).to(torch.bfloat16)
y_data = torch.rand((N, C, H, W)).to(torch.bfloat16)

class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
            self.weight = nn.Parameter(torch.randn(N, C, H, W).to(torch.bfloat16))

def forward(self, x):
return torch.mul(x, self.weight)

model = SimpleModel()

criterion = nn.L1Loss()
optimizer = optim.SGD(
        [model.weight], lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov
)
optimizer.zero_grad()

outputs = model(x_data)
loss = criterion(outputs, y_data)
loss.backward()

    # take an extra optimizer step first when testing with an already-initialized momentum buffer
step_cnt = 2 if momentum_initialized else 1

cpu_momentum_in = None
cpu_momentum_out = None
for i in range(0, step_cnt):
cpu_param_in = model.weight.clone()
dev_param_in = create_tt_tensor(cpu_param_in, device)

optimizer_state_dict = optimizer.state_dict()
if momentum != 0:
if 0 in optimizer_state_dict["state"]:
cpu_momentum_in = optimizer_state_dict["state"][0]["momentum_buffer"].clone()

optimizer.step()

optimizer_state_dict = optimizer.state_dict()
if momentum != 0:
if 0 in optimizer_state_dict["state"]:
cpu_momentum_out = optimizer_state_dict["state"][0]["momentum_buffer"].clone()

# create other dev tensors
dev_param_out = create_tt_tensor(cpu_param_in, device)

cpu_grad = model.weight.grad
dev_grad = create_tt_tensor(cpu_grad, device)

dev_momentum_buffer_in = None
dev_momentum_buffer_out = None
if momentum != 0:
if momentum_initialized:
if cpu_momentum_in is not None:
dev_momentum_buffer_in = create_tt_tensor(cpu_momentum_in, device)
else:
dev_momentum_buffer_in = create_tt_tensor(cpu_param_in, device)

dev_momentum_buffer_out = create_tt_tensor(cpu_param_in, device)

ttl.operations.primary.moreh_sgd(
dev_param_in,
dev_grad,
dev_momentum_buffer_in,
dev_param_out,
dev_momentum_buffer_out,
lr,
momentum,
dampening,
weight_decay,
nesterov,
momentum_initialized,
)
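        # the op fills dev_param_out (and dev_momentum_buffer_out when momentum is used); both are checked below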

assert dev_param_in.shape() == list(model.weight.shape)

# check param_out
param_result = dev_param_out.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

rtol = atol = 0.05
passing, out = comp_allclose_and_pcc(model.weight, param_result, pcc=0.99, rtol=rtol, atol=atol)

logger.info(f"Out passing (param)={passing}")
logger.info(f"Output pcc={out}")

assert passing

# check momentum_out
if momentum != 0:
momentum_buffer_result = (
dev_momentum_buffer_out.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)
)

passing, out = comp_allclose_and_pcc(cpu_momentum_out, momentum_buffer_result, pcc=0.99, rtol=rtol, atol=atol)
logger.info(f"Momentum_out passing (param)={passing}")
logger.info(f"Momentum_out pcc={out}")

assert passing
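
For reference, the values the test compares against come from torch.optim.SGD's standard update rule. A minimal sketch of that rule follows, written per the PyTorch documentation; the helper name reference_sgd_step is illustrative and not part of this commit.

def reference_sgd_step(param, grad, buf, lr, momentum, dampening, weight_decay, nesterov):
    # One SGD step following the torch.optim.SGD documentation.
    d_p = grad + weight_decay * param                      # weight decay folds into the gradient
    if momentum != 0:
        if buf is None:
            buf = d_p.clone()                              # first step: buffer starts as the gradient
        else:
            buf = momentum * buf + (1 - dampening) * d_p   # update the running momentum buffer
        d_p = d_p + momentum * buf if nesterov else buf    # Nesterov look-ahead or plain momentum
    return param - lr * d_p, buf                           # new parameter value and new buffer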
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/module.mk
@@ -105,6 +105,8 @@ TT_DNN_SRCS = \
tt_eager/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp \
tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp \
tt_eager/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.cpp \
tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd_op.cpp \
tt_eager/tt_dnn/op_library/moreh_sgd/moreh_sgd.cpp \
tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.cpp \
tt_eager/tt_dnn/op_library/reshape/reshape_op.cpp \
tt_eager/tt_dnn/op_library/permute/permute_op.cpp \
201 changes: 201 additions & 0 deletions tt_eager/tt_dnn/op_library/moreh_sgd/kernels/common.hpp
@@ -0,0 +1,201 @@
/*
* SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include <stdint.h>
#include "dataflow_api.h"

template <bool DRAM>
InterleavedAddrGenFast<DRAM> InterleavedAddrGenFastHelper_(uint32_t addr, tt::CB cb, uint32_t idx) {
uint32_t tile_bytes = get_tile_size(cb);
auto data_format = get_dataformat(cb);

const InterleavedAddrGenFast<DRAM> x = {
.bank_base_address = addr,
.page_size = tile_bytes,
.data_format = data_format
};

return x;
}
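
// The macro below selects the DRAM or L1 InterleavedAddrGenFast instantiation from
// compile-time arg `idx` (a value of 1 means the buffer is interleaved in DRAM); the
// helper above only needs the address and circular buffer, so it does not use `idx` itself.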

#define InterleavedAddrGenFastHelper(addr, cb, idx) \
({ \
constexpr bool is_dram = (get_compile_time_arg_val(idx) == 1); \
const InterleavedAddrGenFast<is_dram> ret = InterleavedAddrGenFastHelper_<is_dram>(addr, cb, idx); \
ret; \
})


template<typename AddrGen>
void noc_async_read_tile_helper(tt::CB cb, uint32_t num_tiles, uint32_t tile_idx, AddrGen addr_gen) {
cb_reserve_back(cb, num_tiles);
uint32_t addr = get_write_ptr(cb);
noc_async_read_tile(tile_idx, addr_gen, addr);
noc_async_read_barrier();
cb_push_back(cb, num_tiles);
}

template<typename AddrGen>
void noc_async_write_tile_helper(tt::CB cb, uint32_t num_tiles, uint32_t tile_idx, AddrGen addr_gen) {
cb_wait_front(cb, num_tiles);
uint32_t l1_read_addr = get_read_ptr(cb);
noc_async_write_tile(tile_idx, addr_gen, l1_read_addr);
noc_async_write_barrier();
cb_pop_front(cb, num_tiles);
}
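
// The fill helpers below treat a 32x32 tile as four 16x16 faces of 256 values each
// (base offsets 0, 256, 512, and 768) and write bfloat16 values by keeping the upper
// 16 bits of the corresponding float bit pattern.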

void generate_bcast_scaler(
uint32_t cb_scaler,
uint32_t scaler) {
union { float f; uint32_t u; } u; u.u = scaler;
cb_reserve_back(cb_scaler, 1);
auto ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_scaler));

for (int j = 0; j < 1024; j++)
ptr[j] = uint16_t(0);

for (int k = 0; k < 4; k++)
for (int j = 0; j < 16; j++)
ptr[k*256 + j] = uint16_t(u.u>>16);
cb_push_back(cb_scaler, 1);
}

void fill_cb_with_value(uint32_t cb_id, uint32_t value) {
cb_reserve_back(cb_id, 1);
auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id));
for (int j = 0; j < 1024; j++) {
ptr[j] = uint16_t(value >> 16);
}
cb_push_back(cb_id, 1);
}

void generate_mask_w(
uint32_t cb_mask,
uint32_t mask_w) {
union { float f; uint32_t u; } one; one.f = 1.0f;
union { float f; uint32_t u; } zero; zero.f = 0.0f;

cb_reserve_back(cb_mask, 1);
auto ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_mask));

for(uint32_t h = 0 ; h < 16; h++) {
// sub tile 0
{
uint32_t mask_w_0 = mask_w;
if (mask_w_0 >= 16) mask_w_0 = 16;
uint32_t w = 0;
for(; w < mask_w_0; w++){
ptr[h * 16 + w] = uint16_t(one.u >> 16);
}
for(; w < 16; w++){
ptr[h * 16 + w] = uint16_t(zero.u >> 16);
}
}

// sub tile 1
{
uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16;
uint32_t w = 0;
for(; w < mask_w_1; w++){
ptr[h * 16 + w + 256] = uint16_t(one.u >> 16);
}
for(; w < 16; w++){
ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16);
}
}

// sub tile 2
{
uint32_t mask_w_0 = mask_w;
if (mask_w_0 >= 16) mask_w_0 = 16;
uint32_t w = 0;
for(; w < mask_w_0; w++){
ptr[h * 16 + w + 512] = uint16_t(one.u >> 16);
}
for(; w < 16; w++){
ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16);
}
}

// sub tile 3
{
uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16;
uint32_t w = 0;
for(; w < mask_w_1; w++){
ptr[h * 16 + w + 768] = uint16_t(one.u >> 16);
}
for(; w < 16; w++){
ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16);
}
}
}

cb_push_back(cb_mask, 1);
}

void generate_mask_h(
uint32_t cb_mask,
uint32_t mask_h) {
union { float f; uint32_t u; } one; one.f = 1.0f;
union { float f; uint32_t u; } zero; zero.f = 0.0f;

cb_reserve_back(cb_mask, 1);
auto ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_mask));

for(uint32_t w = 0; w < 16; w++) {
// sub tile 0
{
uint32_t mask_h_0 = mask_h;
if (mask_h_0 >= 16) mask_h_0 = 16;
uint32_t h = 0;
for(; h < mask_h_0; h++){
ptr[h * 16 + w] = uint16_t(one.u >> 16);
}
for(; h < 16; h++){
ptr[h * 16 + w] = uint16_t(zero.u >> 16);
}
}

// sub tile 1
{
uint32_t mask_h_0 = mask_h;
if (mask_h_0 >= 16) mask_h_0 = 16;
uint32_t h = 0;
for(; h < mask_h_0; h++){
ptr[h * 16 + w + 256] = uint16_t(one.u >> 16);
}
for(; h < 16; h++){
ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16);
}
}

// sub tile 2
{
uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16;
uint32_t h = 0;
for(; h < mask_h_1; h++){
ptr[h * 16 + w + 512] = uint16_t(one.u >> 16);
}
for(; h < 16; h++){
ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16);
}
}

// sub tile 3
{
uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16;
uint32_t h = 0;
for(; h < mask_h_1; h++){
ptr[h * 16 + w + 768] = uint16_t(one.u >> 16);
}
for(; h < 16; h++){
ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16);
}
}
}

cb_push_back(cb_mask, 1);
}