Commit 59d6e69 (1 parent: da3c5fa)
Showing 11 changed files with 1,130 additions and 0 deletions.
tests/tt_eager/python_api_testing/unit_testing/test_moreh_sgd.py (152 additions & 0 deletions)
@@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
import torch.optim as optim

import tt_lib as ttl
import pytest
from models.utility_functions import (
    comp_allclose_and_pcc,
)
from loguru import logger


def create_tt_tensor(tensor, device):
    # Wrap a torch tensor as a bfloat16 TT tensor, tilize it, and move it to the device.
    ret = (
        ttl.tensor.Tensor(
            tensor,
            ttl.tensor.DataType.BFLOAT16,
        )
        .to(ttl.tensor.Layout.TILE)
        .to(device)
    )

    return ret


@pytest.mark.parametrize(
    "shape",
    (
        (1, 1, 32, 32),  # single tile
        (12, 6, 64, 64),  # multiple tiles
    ),
)
@pytest.mark.parametrize("lr", [3.0])
@pytest.mark.parametrize("momentum", [0.0, 7.7])
@pytest.mark.parametrize("dampening", [0.0, 0.5])
@pytest.mark.parametrize("weight_decay", [0.0, 2.2])
@pytest.mark.parametrize("nesterov", [True, False], ids=["NESTEROV_TRUE", "NESTEROV_FALSE"])
@pytest.mark.parametrize(
    "momentum_initialized", [True, False], ids=["MOMENTUM_INITIALIZED", "MOMENTUM_NOT_INITIALIZED"]
)
def test_moreh_sgd(shape, lr, momentum, dampening, weight_decay, nesterov, momentum_initialized, device):
    # torch.optim.SGD rejects this combination, so there is no reference to compare against.
    if nesterov and (momentum <= 0 or dampening != 0):
        pytest.skip("nesterov requires momentum > 0 and dampening == 0")

    torch.manual_seed(0)

    # make model and compute grad
    N, C, H, W = shape

    x_data = torch.rand((N, C, H, W)).to(torch.bfloat16)
    y_data = torch.rand((N, C, H, W)).to(torch.bfloat16)

    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.weight = nn.Parameter(torch.randn(N, C, H, W).to(torch.bfloat16))

        def forward(self, x):
            return torch.mul(x, self.weight)

    model = SimpleModel()

    criterion = nn.L1Loss()
    optimizer = optim.SGD(
        [model.weight], lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov
    )
    optimizer.zero_grad()

    outputs = model(x_data)
    loss = criterion(outputs, y_data)
    loss.backward()

    # run an extra step in the momentum_initialized case so a momentum buffer already exists
    step_cnt = 2 if momentum_initialized else 1

    cpu_momentum_in = None
    cpu_momentum_out = None
    for _ in range(step_cnt):
        cpu_param_in = model.weight.clone()
        dev_param_in = create_tt_tensor(cpu_param_in, device)

        # snapshot the momentum buffer before and after the reference step
        optimizer_state_dict = optimizer.state_dict()
        if momentum != 0:
            if 0 in optimizer_state_dict["state"]:
                cpu_momentum_in = optimizer_state_dict["state"][0]["momentum_buffer"].clone()

        optimizer.step()

        optimizer_state_dict = optimizer.state_dict()
        if momentum != 0:
            if 0 in optimizer_state_dict["state"]:
                cpu_momentum_out = optimizer_state_dict["state"][0]["momentum_buffer"].clone()

    # create other dev tensors
    dev_param_out = create_tt_tensor(cpu_param_in, device)

    cpu_grad = model.weight.grad
    dev_grad = create_tt_tensor(cpu_grad, device)

    dev_momentum_buffer_in = None
    dev_momentum_buffer_out = None
    if momentum != 0:
        if momentum_initialized:
            if cpu_momentum_in is not None:
                dev_momentum_buffer_in = create_tt_tensor(cpu_momentum_in, device)
            else:
                dev_momentum_buffer_in = create_tt_tensor(cpu_param_in, device)

        dev_momentum_buffer_out = create_tt_tensor(cpu_param_in, device)

    ttl.operations.primary.moreh_sgd(
        dev_param_in,
        dev_grad,
        dev_momentum_buffer_in,
        dev_param_out,
        dev_momentum_buffer_out,
        lr,
        momentum,
        dampening,
        weight_decay,
        nesterov,
        momentum_initialized,
    )

    assert dev_param_in.shape() == list(model.weight.shape)

    # check param_out
    param_result = dev_param_out.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)

    rtol = atol = 0.05
    passing, out = comp_allclose_and_pcc(model.weight, param_result, pcc=0.99, rtol=rtol, atol=atol)

    logger.info(f"Out passing (param)={passing}")
    logger.info(f"Output pcc={out}")

    assert passing

    # check momentum_out
    if momentum != 0:
        momentum_buffer_result = (
            dev_momentum_buffer_out.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch().to(torch.bfloat16)
        )

        passing, out = comp_allclose_and_pcc(cpu_momentum_out, momentum_buffer_result, pcc=0.99, rtol=rtol, atol=atol)
        logger.info(f"Momentum_out passing (param)={passing}")
        logger.info(f"Momentum_out pcc={out}")

        assert passing
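
For reference, the device op is being checked against the update rule that torch.optim.SGD documents. A minimal single-step sketch of that rule (plain torch; the function and variable names are illustrative, not part of the test file):

import torch


def sgd_step_reference(param, grad, buf, lr, momentum, dampening, weight_decay, nesterov):
    # One SGD step, following the torch.optim.SGD documentation.
    d_p = grad + weight_decay * param  # weight decay folds into the gradient
    if momentum != 0:
        if buf is None:
            buf = d_p.clone()  # uninitialized buffer starts as the (decayed) gradient
        else:
            buf = momentum * buf + (1 - dampening) * d_p
        d_p = d_p + momentum * buf if nesterov else buf
    return param - lr * d_p, buf

Note the branch on whether the buffer exists: running the loop twice in the MOMENTUM_INITIALIZED case guarantees the reference optimizer already holds a momentum buffer, which appears to be what the op's momentum_initialized flag signals.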
tt_eager/tt_dnn/op_library/moreh_sgd/kernels/common.hpp (201 additions & 0 deletions)
@@ -0,0 +1,201 @@
/*
 * SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#include <stdint.h>
#include "dataflow_api.h"

// Builds a tile-granularity interleaved address generator for the buffer
// behind `cb`. The `idx` parameter is unused here; the wrapper macro below
// uses it to select the compile-time arg that says whether the buffer is in DRAM.
template <bool DRAM>
InterleavedAddrGenFast<DRAM> InterleavedAddrGenFastHelper_(uint32_t addr, tt::CB cb, uint32_t idx) {
    uint32_t tile_bytes = get_tile_size(cb);
    auto data_format = get_dataformat(cb);

    const InterleavedAddrGenFast<DRAM> x = {
        .bank_base_address = addr,
        .page_size = tile_bytes,
        .data_format = data_format
    };

    return x;
}

// Statement expression (GCC extension) so the DRAM/L1 choice, which must be a
// template argument, can be derived from a compile-time arg at the call site.
#define InterleavedAddrGenFastHelper(addr, cb, idx) \
    ({ \
        constexpr bool is_dram = (get_compile_time_arg_val(idx) == 1); \
        const InterleavedAddrGenFast<is_dram> ret = InterleavedAddrGenFastHelper_<is_dram>(addr, cb, idx); \
        ret; \
    })

// Blocking read of tile `tile_idx` through `addr_gen` into circular buffer `cb`.
template<typename AddrGen>
void noc_async_read_tile_helper(tt::CB cb, uint32_t num_tiles, uint32_t tile_idx, AddrGen addr_gen) {
    cb_reserve_back(cb, num_tiles);
    uint32_t addr = get_write_ptr(cb);
    noc_async_read_tile(tile_idx, addr_gen, addr);
    noc_async_read_barrier();
    cb_push_back(cb, num_tiles);
}

// Blocking write of the tile at the front of `cb` to tile slot `tile_idx`.
template<typename AddrGen>
void noc_async_write_tile_helper(tt::CB cb, uint32_t num_tiles, uint32_t tile_idx, AddrGen addr_gen) {
    cb_wait_front(cb, num_tiles);
    uint32_t l1_read_addr = get_read_ptr(cb);
    noc_async_write_tile(tile_idx, addr_gen, l1_read_addr);
    noc_async_write_barrier();
    cb_pop_front(cb, num_tiles);
}

// Fill one tile of `cb_scaler` with a broadcast scaler: zero everywhere except
// the first row of each 16x16 face, which holds the bfloat16 form of `scaler`
// (the high 16 bits of its float32 bit pattern).
void generate_bcast_scaler(
    uint32_t cb_scaler,
    uint32_t scaler) {
    union { float f; uint32_t u; } u; u.u = scaler;
    cb_reserve_back(cb_scaler, 1);
    auto ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_scaler));

    for (int j = 0; j < 1024; j++)
        ptr[j] = uint16_t(0);

    for (int k = 0; k < 4; k++)
        for (int j = 0; j < 16; j++)
            ptr[k*256 + j] = uint16_t(u.u >> 16);
    cb_push_back(cb_scaler, 1);
}

// Fill an entire tile of `cb_id` with `value`, interpreted as a float32 bit
// pattern truncated to bfloat16.
void fill_cb_with_value(uint32_t cb_id, uint32_t value) {
    cb_reserve_back(cb_id, 1);
    auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id));
    for (int j = 0; j < 1024; j++) {
        ptr[j] = uint16_t(value >> 16);
    }
    cb_push_back(cb_id, 1);
}

// Build a 32x32 mask tile in `cb_mask`: columns [0, mask_w) are 1.0f and the
// rest are 0.0f, stored as bfloat16. The tile is laid out as four 16x16 faces
// of 256 values each, so each face handles its half of the 32-wide row.
void generate_mask_w(
    uint32_t cb_mask,
    uint32_t mask_w) {
    union { float f; uint32_t u; } one; one.f = 1.0f;
    union { float f; uint32_t u; } zero; zero.f = 0.0f;

    cb_reserve_back(cb_mask, 1);
    auto ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_mask));

    for (uint32_t h = 0; h < 16; h++) {
        // sub tile 0
        {
            uint32_t mask_w_0 = mask_w;
            if (mask_w_0 >= 16) mask_w_0 = 16;
            uint32_t w = 0;
            for (; w < mask_w_0; w++) {
                ptr[h * 16 + w] = uint16_t(one.u >> 16);
            }
            for (; w < 16; w++) {
                ptr[h * 16 + w] = uint16_t(zero.u >> 16);
            }
        }

        // sub tile 1
        {
            uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16;
            uint32_t w = 0;
            for (; w < mask_w_1; w++) {
                ptr[h * 16 + w + 256] = uint16_t(one.u >> 16);
            }
            for (; w < 16; w++) {
                ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16);
            }
        }

        // sub tile 2
        {
            uint32_t mask_w_0 = mask_w;
            if (mask_w_0 >= 16) mask_w_0 = 16;
            uint32_t w = 0;
            for (; w < mask_w_0; w++) {
                ptr[h * 16 + w + 512] = uint16_t(one.u >> 16);
            }
            for (; w < 16; w++) {
                ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16);
            }
        }

        // sub tile 3
        {
            uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16;
            uint32_t w = 0;
            for (; w < mask_w_1; w++) {
                ptr[h * 16 + w + 768] = uint16_t(one.u >> 16);
            }
            for (; w < 16; w++) {
                ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16);
            }
        }
    }

    cb_push_back(cb_mask, 1);
}

// Same as generate_mask_w, but along the height: rows [0, mask_h) are 1.0f
// and the rest are 0.0f.
void generate_mask_h(
    uint32_t cb_mask,
    uint32_t mask_h) {
    union { float f; uint32_t u; } one; one.f = 1.0f;
    union { float f; uint32_t u; } zero; zero.f = 0.0f;

    cb_reserve_back(cb_mask, 1);
    auto ptr = reinterpret_cast<uint16_t*>(get_write_ptr(cb_mask));

    for (uint32_t w = 0; w < 16; w++) {
        // sub tile 0
        {
            uint32_t mask_h_0 = mask_h;
            if (mask_h_0 >= 16) mask_h_0 = 16;
            uint32_t h = 0;
            for (; h < mask_h_0; h++) {
                ptr[h * 16 + w] = uint16_t(one.u >> 16);
            }
            for (; h < 16; h++) {
                ptr[h * 16 + w] = uint16_t(zero.u >> 16);
            }
        }

        // sub tile 1
        {
            uint32_t mask_h_0 = mask_h;
            if (mask_h_0 >= 16) mask_h_0 = 16;
            uint32_t h = 0;
            for (; h < mask_h_0; h++) {
                ptr[h * 16 + w + 256] = uint16_t(one.u >> 16);
            }
            for (; h < 16; h++) {
                ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16);
            }
        }

        // sub tile 2
        {
            uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16;
            uint32_t h = 0;
            for (; h < mask_h_1; h++) {
                ptr[h * 16 + w + 512] = uint16_t(one.u >> 16);
            }
            for (; h < 16; h++) {
                ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16);
            }
        }

        // sub tile 3
        {
            uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16;
            uint32_t h = 0;
            for (; h < mask_h_1; h++) {
                ptr[h * 16 + w + 768] = uint16_t(one.u >> 16);
            }
            for (; h < 16; h++) {
                ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16);
            }
        }
    }

    cb_push_back(cb_mask, 1);
}
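
For orientation, here is a sketch of how a dataflow kernel might consume these helpers. This fragment is not part of the commit; the runtime-arg layout, the compile-time-arg index, and the choice of c_in0 are illustrative assumptions.

// Hypothetical reader-kernel fragment: pull one tile of a tensor from an
// interleaved buffer into circular buffer c_in0.
#include "common.hpp"

void kernel_main() {
    uint32_t src_addr = get_arg_val<uint32_t>(0);  // tensor base address (assumed arg layout)
    uint32_t tile_idx = get_arg_val<uint32_t>(1);  // which tile to fetch (assumed arg layout)

    constexpr auto cb_in = tt::CB::c_in0;
    // Compile-time arg 0 (assumed) encodes DRAM (1) vs. L1 (0) placement; the
    // macro turns it into the template parameter of the address generator.
    const auto src_gen = InterleavedAddrGenFastHelper(src_addr, cb_in, 0);
    noc_async_read_tile_helper(cb_in, 1, tile_idx, src_gen);
}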