Merge pull request #4 from ucb-bar/refactor

Major Refactor

T-K-233 authored Jul 11, 2024
2 parents e4e90a2 + 5ef4e92 commit aa71c8a

Showing 221 changed files with 4,018 additions and 3,247 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -28,3 +28,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
+
+
+example/llama2/checkpoints/stories15M.bin
+
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -27,7 +27,7 @@ set(WRAP_SPECS_FILE "htif_wrap.specs")
set(SPECS_FILE "htif_nano.specs")
set(LIBGLOSS_DIR "$ENV{RISCV}/riscv64-unknown-elf/lib/")

set(MARCH "rv64gcv_zfh")
set(MARCH "rv64gcv_zfh_zvfh")
set(MABI "lp64d")
set(MCMODEL "medany")

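The new -march string adds Zvfh, the RISC-V vector extension for half-precision floating point, alongside the scalar Zfh extension. As a rough illustration only (this guard is not part of the diff), a compile-time check could catch a toolchain that does not actually provide Zvfh, assuming a recent GCC or Clang that defines the standard RISC-V extension test macros:

```c
/* Minimal sketch: fail the build early if the compiler selected by the CMake
 * MARCH setting does not support vector half-precision floats.
 * Assumption: the toolchain defines the RISC-V extension test macro
 * __riscv_zvfh when -march=rv64gcv_zfh_zvfh is accepted. */
#if !defined(__riscv_zvfh)
#error "This build expects the Zvfh (vector half-precision) extension; check MARCH in CMakeLists.txt"
#endif
```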
82 changes: 41 additions & 41 deletions barstools/src/barstools/converter.py
@@ -13,15 +13,15 @@

class TorchConverter(torch.fx.Interpreter):
@staticmethod
-def toNumpy(tensor: torch.Tensor):
+def to_numpy(tensor: torch.Tensor):
return tensor.cpu().detach().contiguous().numpy()

@staticmethod
-def toBytes(ndarray: np.ndarray):
+def to_bytes(ndarray: np.ndarray):
return ndarray.astype(np.float32).flatten().tobytes()

@staticmethod
-def dtypeToStr(dtype: torch.dtype):
+def dtype_to_str(dtype: torch.dtype):
if dtype == torch.float16:
return "DTYPE_F16"
elif dtype == torch.float32:
@@ -67,12 +67,12 @@ def __init__(self, model):
def print(self):
self.gm.graph.print_tabular()

-def getModuleInSequential(self, module, indicies):
+def get_module_in_sequential(self, module, indicies):
if len(indicies) == 0:
return module
-return self.getModuleInSequential(module[indicies[0]], indicies[1:])
+return self.get_module_in_sequential(module[indicies[0]], indicies[1:])

-def getModule(self, module_name):
+def get_module(self, module_name):
if "." in module_name:
# if we have nn.Sequential layers
target_hierarchy = module_name.split(".")
@@ -82,36 +82,36 @@ def getModule(self, module_name):
indicies = [int(x) for x in target_hierarchy[1:]]

module = getattr(self.model, sequential_name)
-return self.getModuleInSequential(module, indicies)
+return self.get_module_in_sequential(module, indicies)

return getattr(self.model, module_name)

-def addDataTensor(self, name, tensor):
+def add_data_tensor(self, name, tensor):
self.model_struct += INDENT + "Tensor {name};\n".format(
name=name
)
-data = TorchConverter.toNumpy(tensor)
+data = TorchConverter.to_numpy(tensor)

self.model_init += INDENT + "NN_initTensor(&model->{name}, {dim}, (size_t[]){{{shape}}}, {dtype}, array_pointer);\n".format(
self.model_init += INDENT + "NN_init_tensor(&model->{name}, {dim}, (size_t[]){{{shape}}}, {dtype}, array_pointer);\n".format(
name=name,
dim=len(tensor.shape),
shape=", ".join(str(x) for x in tensor.shape),
-dtype=TorchConverter.dtypeToStr(tensor.dtype)
+dtype=TorchConverter.dtype_to_str(tensor.dtype)
)
self.model_init += INDENT + "array_pointer += {increment};\n".format(
increment=np.prod(tensor.shape)
)
-self.weight_content += TorchConverter.toBytes(data)
+self.weight_content += TorchConverter.to_bytes(data)

-def addOutputTensor(self, name, shape, dtype=torch.float32):
+def add_output_tensor(self, name, shape, dtype=torch.float32):
self.model_struct += INDENT + "Tensor {name};\n".format(
name=name
)
self.model_init += INDENT + "NN_initTensor(&model->{name}, {dim}, (size_t[]){{{shape}}}, {dtype}, NULL);\n".format(
self.model_init += INDENT + "NN_init_tensor(&model->{name}, {dim}, (size_t[]){{{shape}}}, {dtype}, NULL);\n".format(
name=name,
dim=len(shape),
shape=", ".join(str(x) for x in shape),
-dtype=TorchConverter.dtypeToStr(dtype)
+dtype=TorchConverter.dtype_to_str(dtype)
)

def placeholder(self, target, args, kwargs):
@@ -134,7 +134,7 @@ def placeholder(self, target, args, kwargs):

self.model_struct += INDENT + "Tensor {name};\n".format(name=name)

self.model_init += INDENT + "NN_initTensor(&model->{name}, {dim}, (size_t[]){{{shape}}}, DTYPE_F32, NULL);\n".format(
self.model_init += INDENT + "NN_init_tensor(&model->{name}, {dim}, (size_t[]){{{shape}}}, DTYPE_F32, NULL);\n".format(
name=name,
dim=len(shape),
shape=", ".join(str(x) for x in shape)
@@ -160,7 +160,7 @@ def call_function(self, target, args, kwargs):
layer_name=layer_name,
input_names=self.node_info[layer_name][0]
)
-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

elif target == torch.nn.functional.interpolate:
layer_name = "interpolate_{count}".format(count=count) if count > 0 else "interpolate"
@@ -170,20 +170,20 @@
input_names=self.node_info[layer_name][0],
scale_factor=kwargs.get("scale_factor")
)
-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

elif target == torch.nn.functional.relu:
layer_name = "relu_{count}".format(count=count) if count > 0 else "relu"
self.model_forward += INDENT + "// F.{layer_name}\n".format(layer_name=layer_name)
self.model_forward += INDENT + "NN_ReLU(&model->{layer_name}, &model->{input_names[0]});\n".format(
self.model_forward += INDENT + "NN_relu(&model->{layer_name}, &model->{input_names[0]});\n".format(
layer_name=layer_name,
input_names=self.node_info[layer_name][0]
)

elif target == torch.nn.functional.relu6:
layer_name = "relu6_{count}".format(count=count) if count > 0 else "relu6"
self.model_forward += INDENT + "// F.{layer_name}\n".format(layer_name=layer_name)
self.model_forward += INDENT + "NN_ReLU6(&model->{layer_name}, &model->{input_names[0]});\n".format(
self.model_forward += INDENT + "NN_relu6(&model->{layer_name}, &model->{input_names[0]});\n".format(
layer_name=layer_name,
input_names=self.node_info[layer_name][0]
)
@@ -206,7 +206,7 @@ def call_module(self, target, args, kwargs):
if len(output_shape) == 4:
output_shape = (output_shape[0], output_shape[2], output_shape[3], output_shape[1])

-module = self.getModule(target)
+module = self.get_module(target)
layer_name = target.replace(".", "_")
input_names = self.node_info[layer_name][0]

@@ -217,24 +217,24 @@ def call_module(self, target, args, kwargs):
)

if type(module) == torch.nn.Linear:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_weight".format(layer_name=layer_name),
module.state_dict().get("weight")
)

if module.bias is not None:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_bias".format(layer_name=layer_name),
module.state_dict().get("bias")
)

batch_size = int(output_shape[0])
-self.addOutputTensor(
+self.add_output_tensor(
layer_name,
(batch_size, module.out_features)
)

self.model_forward += INDENT + "NN_Linear(&model->{layer_name}, &model->{input_names[0]}, {weight}, {bias});\n".format(
self.model_forward += INDENT + "NN_linear(&model->{layer_name}, &model->{input_names[0]}, {weight}, {bias});\n".format(
layer_name=layer_name,
input_names=input_names,
weight="&model->{layer_name}_weight".format(layer_name=layer_name),
@@ -243,30 +243,30 @@ def call_module(self, target, args, kwargs):

elif type(module) == torch.nn.BatchNorm2d:
if module.weight is not None:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_weight".format(layer_name=layer_name),
module.state_dict().get("weight")
)
if module.bias is not None:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_bias".format(layer_name=layer_name),
module.state_dict().get("bias")
)
if module.running_mean is not None:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_running_mean".format(layer_name=layer_name),
module.state_dict().get("running_mean")
)
if module.running_var is not None:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_running_var".format(layer_name=layer_name),
module.state_dict().get("running_var")
)

batch_size = int(output_shape[0])
-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

self.model_forward += INDENT + """NN_BatchNorm2d(
self.model_forward += INDENT + """NN_batch_norm2d(
&model->{layer_name}, &model->{input_name[0]},
{weight}, {bias},
{eps}, {running_mean}, {running_var});\n""".format(
@@ -282,19 +282,19 @@ def call_module(self, target, args, kwargs):
elif type(module) == torch.nn.Conv2d:
if module.weight is not None:
# weight need to be converted from (out_ch, in_ch, kh, kw) to (kh, kw, in_ch, out_ch)
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_weight".format(layer_name=layer_name),
module.state_dict().get("weight").permute(2, 3, 1, 0)
)
if module.bias is not None:
-self.addDataTensor(
+self.add_data_tensor(
"{layer_name}_bias".format(layer_name=layer_name),
module.state_dict().get("bias")
)

-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

self.model_forward += INDENT + """NN_Conv2d(
self.model_forward += INDENT + """NN_conv2d(
&model->{layer_name}, &model->{input_names[0]},
{weight}, {bias}, (size_t[]){{{stride}}}, (size_t[]){{{padding}}}, (size_t[]){{{dilation}}}, {groups});\n""".format(
layer_name=layer_name,
@@ -309,26 +309,26 @@ def call_module(self, target, args, kwargs):
self.prev_layer_name = "{layer_name}".format(layer_name=layer_name)

elif type(module) == torch.nn.ReLU:
self.model_forward += INDENT + "NN_ReLU(&model->{layer_name}, &model->{input_names[0]});\n".format(
self.model_forward += INDENT + "NN_relu(&model->{layer_name}, &model->{input_names[0]});\n".format(
layer_name=layer_name,
input_names=input_names
)
-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

elif type(module) == torch.nn.ReLU6:
self.model_forward += INDENT + "NN_ReLU6(&model->{layer_name}, &model->{input_names[0]});\n".format(
self.model_forward += INDENT + "NN_relu6(&model->{layer_name}, &model->{input_names[0]});\n".format(
layer_name=layer_name,
input_names=input_names
)
-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

elif type(module) == torch.nn.ELU:
self.model_forward += INDENT + "NN_ELU(&model->{layer_name}, &model->{input_names[0]}, {eps});\n".format(
self.model_forward += INDENT + "NN_elu(&model->{layer_name}, &model->{input_names[0]}, {eps});\n".format(
layer_name=layer_name,
input_names=input_names,
eps=module.alpha
)
-self.addOutputTensor(layer_name, output_shape)
+self.add_output_tensor(layer_name, output_shape)

else:
print("[WARNING] Unsupported module call:", target)
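For reference, the sketch below pieces together roughly what the renamed converter emits for a single nn.Linear(3, 5) layer, based on the format strings visible above. The wrapper names (`Model`, `model_init`, `model_forward`), the layer name `fc`, the shapes, and the `nn.h` include are illustrative assumptions, not output copied from the tool.

```c
#include "nn.h"  /* assumed header providing Tensor, DTYPE_F32, NN_init_tensor, NN_linear */

/* Hypothetical generated code for a single nn.Linear(3, 5) module named "fc". */
typedef struct {
  Tensor input;      /* placeholder tensor for the model input */
  Tensor fc_weight;  /* (5, 3) weight, read from the packed weight blob */
  Tensor fc_bias;    /* (5,) bias, read from the packed weight blob */
  Tensor fc;         /* layer output */
} Model;

void model_init(Model *model, float *array_pointer) {
  NN_init_tensor(&model->input, 2, (size_t[]){1, 3}, DTYPE_F32, NULL);
  NN_init_tensor(&model->fc_weight, 2, (size_t[]){5, 3}, DTYPE_F32, array_pointer);
  array_pointer += 15;
  NN_init_tensor(&model->fc_bias, 1, (size_t[]){5}, DTYPE_F32, array_pointer);
  array_pointer += 5;
  NN_init_tensor(&model->fc, 2, (size_t[]){1, 5}, DTYPE_F32, NULL);
}

void model_forward(Model *model) {
  /* nn.Linear */
  NN_linear(&model->fc, &model->input, &model->fc_weight, &model->fc_bias);
}
```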
4 changes: 2 additions & 2 deletions example/char-rnn/runtime_test_c/char-rnn.c
@@ -57,7 +57,7 @@ int main() {

for (int j=1; j<strlen(str); j+=1) {
encodeOneHot(&input, str[j]);
NN_Linear(&hidden, &i2h_weight_transposed, &i2h_bias, &input);
NN_linear(&hidden, &i2h_weight_transposed, &i2h_bias, &input);

forward(&output, &hidden, &input);
}
@@ -67,7 +67,7 @@ int main() {

printf("\n> %s\n", str);
printf("score: (");
-NN_printFloat(output.data[index], 2);
+NN_print_f32(output.data[index], 2);
printf("), predicted: (%d, %s)\n", index, categories[index]);
}

4 changes: 2 additions & 2 deletions example/char-rnn/runtime_test_c/model.h
@@ -35,9 +35,9 @@ static void forward(Matrix *output, Matrix *hidden, Matrix *input) {
// Input
Matrix *input_out = input;
// Linear
-NN_Linear(hidden, &i2h_weight_transposed, &i2h_bias, input_out);
+NN_linear(hidden, &i2h_weight_transposed, &i2h_bias, input_out);
// Linear
-NN_Linear(output, &h2o_weight_transposed, &h2o_bias, hidden);
+NN_linear(output, &h2o_weight_transposed, &h2o_bias, hidden);
// Log Softmax
NN_logSoftmax(output, output);
}
8 changes: 4 additions & 4 deletions example/char-rnn/runtime_test_c/nn.h
@@ -31,7 +31,7 @@ void NN_assert(int condition, char *message) {
* These functions assumes that printf is available.
*/

-void NN_printFloat(float v, int16_t num_digits) {
+void NN_print_f32(float v, int16_t num_digits) {
int32_t scale = 1;
int32_t integer_part, fractional_part;
while (num_digits != 0) {
@@ -46,14 +46,14 @@ void NN_printFloat(float v, int16_t num_digits) {
printf("%i.%i", integer_part, fractional_part);
}

-void NN_printShape(Matrix *a) {
+void NN_print_shape(Matrix *a) {
printf("(%d, %d)\n", a->rows, a->cols);
}

void NN_printMatrix(Matrix *a) {
for (size_t i = 0; i < a->rows; i++) {
for (size_t j = 0; j < a->cols; j++) {
-NN_printFloat(a->data[i * a->cols + j], 2);
+NN_print_f32(a->data[i * a->cols + j], 2);
printf(" ");
}
printf("\n");
@@ -128,7 +128,7 @@ size_t NN_argmax(Matrix *a) {
* ====== Operators ======
*/

-void NN_Linear(Matrix *out, Matrix *weight, Matrix *bias, Matrix *input) {
+void NN_linear(Matrix *out, Matrix *weight, Matrix *bias, Matrix *input) {
NN_matmul(out, input, weight);
NN_matadd(out, out, bias);
}
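A short usage sketch of the renamed Matrix helpers in this example-local header. It assumes `Matrix` is a plain struct with `rows`, `cols`, and `data` members (as used by `NN_printMatrix` above) that can be brace-initialized, and that the header is included as `nn.h`; the sizes and values are made up for illustration.

```c
#include "nn.h"  /* the example-local header shown above (assumed include path) */

int main(void) {
  /* 1x2 input, 2x3 transposed weight, 1x3 bias; values chosen for illustration */
  float in_data[]  = {1.0f, 2.0f};
  float w_t_data[] = {0.5f, 0.0f, 1.0f,
                      0.0f, 0.5f, 1.0f};
  float b_data[]   = {0.1f, 0.1f, 0.1f};
  float out_data[3];

  Matrix input = { .rows = 1, .cols = 2, .data = in_data };
  Matrix w_t   = { .rows = 2, .cols = 3, .data = w_t_data };
  Matrix bias  = { .rows = 1, .cols = 3, .data = b_data };
  Matrix out   = { .rows = 1, .cols = 3, .data = out_data };

  NN_linear(&out, &w_t, &bias, &input);  /* out = input @ w_t + bias */
  NN_printMatrix(&out);                  /* prints each entry via NN_print_f32 */
  return 0;
}
```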
4 changes: 2 additions & 2 deletions example/char-rnn/runtime_test_np/model.py
@@ -8,9 +8,9 @@ def forward(input):
# Input
input_out = input
# Linear
-i2h_out = NN_Linear(input_out, i2h_weight_transposed, i2h_bias)
+i2h_out = NN_linear(input_out, i2h_weight_transposed, i2h_bias)
# Linear
-h2o_out = NN_Linear(i2h_out, h2o_weight_transposed, h2o_bias)
+h2o_out = NN_linear(i2h_out, h2o_weight_transposed, h2o_bias)
# Log Softmax
softmax_out = nn_logsoftmax(h2o_out)
return softmax_out, i2h_out
2 changes: 1 addition & 1 deletion example/char-rnn/runtime_test_np/nn.py
@@ -1,6 +1,6 @@
import numpy as np

-def NN_Linear(input, weight_T, bias):
+def NN_linear(input, weight_T, bias):
return np.matmul(input, weight_T) + bias

def nn_logsoftmax(input):