FIX: fixing converter for torch optimized case
T-K-233 committed Jul 21, 2024
1 parent 0b957e6 commit 30d0878
Showing 3 changed files with 117 additions and 186 deletions.
221 changes: 117 additions & 104 deletions barstools/src/barstools/converter.py
@@ -1,12 +1,14 @@
import operator
import os
import inspect
from typing import Any, Dict, List

import numpy as np
import torch
import torch.nn
import torch.fx
import jinja2
import tabulate


INDENT = " "
@@ -103,6 +105,24 @@ def extract_graph_module(model: torch.nn.Module) -> list[torch.fx.Graph, torch.f
gm = torch.fx.GraphModule(model, graph)
return graph, gm

@staticmethod
def to_functional_torch(module: torch.nn.Module) -> Any:
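# Map an nn.Module instance to its torch.nn.functional equivalent, so that
# call_module nodes can be traced through the same path as call_function nodes.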
if type(module) == torch.nn.Linear:
return torch.nn.functional.linear
elif type(module) == torch.nn.Conv2d:
return torch.nn.functional.conv2d
elif type(module) == torch.nn.BatchNorm2d:
return torch.nn.functional.batch_norm
elif type(module) == torch.nn.ReLU:
return torch.nn.functional.relu
elif type(module) == torch.nn.ReLU6:
return torch.nn.functional.relu6
elif type(module) == torch.nn.ELU:
return torch.nn.functional.elu
else:
print("[WARNING] Unsupported module call:", module)


def __init__(self, model: torch.nn.Module):
graph, gm = TorchConverter.extract_graph_module(model)
super().__init__(gm)
@@ -134,7 +154,7 @@ def __init__(self, model: torch.nn.Module):
self.placeholder_counter = {}
self.function_counter = {}

def print(self):
def print_graph(self):
self.gm.graph.print_tabular()

def get_module_in_sequential(self, module, indicies):
@@ -186,78 +206,50 @@ def add_output_tensor(self, name, shape, dtype=torch.float32):
data="NULL"
)

def placeholder(self, target, args, kwargs):
print("placeholder:", target)

# HACK: assume the example input is NCHW and reorder the shape to NHWC
shape = self.example_input.shape
if len(shape) == 4:
shape = (shape[0], shape[2], shape[3], shape[1])

# this is also hacky

name = target

if name == "input":
# torch.fx renames the input tensor to input_1 to avoid shadowing the Python builtin input
name = "input_1"

self.model_struct += INDENT + TEMPLATE_TENSOR_DECLARE.format(name=name)

self.model_init += INDENT + TEMPLATE_TENSOR_INIT.format(
name=name,
dim=len(shape),
shape=", ".join(str(x) for x in shape),
dtype=TorchConverter.dtype_to_str(self.example_input.dtype),
data="NULL"
)

return super().placeholder(target, args, kwargs)

def trace_functional(self, layer_name, function, args, kwargs):
print(" trace:", function)
output = super().call_function(function, args, kwargs)
def trace_functional(self, layer_name, target, args, kwargs, out):
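# Emit the generated code for a single layer: the forward-pass NN_* call plus the
# tensor declarations/initializations for its parameters and output.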
print(" trace:", layer_name, target)

output_shape = output.shape
output_shape = out.shape
if len(output_shape) == 4:
output_shape = (output_shape[0], output_shape[2], output_shape[3], output_shape[1])

input_names = self.node_info[layer_name][0]

if function == operator.__add__:
if target == operator.__add__:
self.model_forward += INDENT + "NN_add(&model->{layer_name}, &model->{input_names[0]}, &model->{input_names[1]});\n".format(
layer_name=layer_name,
input_names=input_names
)
self.add_output_tensor(layer_name, output_shape)
elif function == torch.nn.functional.interpolate:
elif target == torch.nn.functional.interpolate:
self.model_forward += INDENT + "NN_interpolate(&model->{layer_name}, &model->{input_names[0]}, (float []){{{scale_factor}, {scale_factor}}});\n".format(
layer_name=layer_name,
input_names=input_names,
scale_factor=kwargs.get("scale_factor")
)
self.add_output_tensor(layer_name, output_shape)
elif function == torch.nn.functional.relu:
elif target == torch.nn.functional.relu:
self.model_forward += INDENT + "NN_relu(&model->{layer_name}, &model->{input_names[0]});\n".format(
layer_name=layer_name,
input_names=input_names
)
self.add_output_tensor(layer_name, output_shape)
elif function == torch.nn.functional.relu6:
elif target == torch.nn.functional.relu6:
self.model_forward += INDENT + "NN_relu6(&model->{layer_name}, &model->{input_names[0]});\n".format(
layer_name=layer_name,
input_names=input_names
)
self.add_output_tensor(layer_name, output_shape)
elif function == torch.nn.functional.conv2d:
elif target == torch.nn.functional.conv2d:
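# positional args mirror torch.nn.functional.conv2d:
# (input, weight, bias, stride, padding, dilation, groups)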
weight = args[1]
bias = args[2]
stride = args[3]
padding = args[4]
dilation = args[5]
groups = args[6]


if weight is not None:
# weight needs to be converted from (out_ch, in_ch, kh, kw) to (kh, kw, in_ch, out_ch)
self.add_data_tensor(
@@ -286,88 +278,108 @@ def trace_functional(self, layer_name, function, args, kwargs):
)
self.prev_layer_name = "{layer_name}".format(layer_name=layer_name)

elif function == torch.nn.functional.linear:
elif target == torch.nn.functional.linear:
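# positional args mirror torch.nn.functional.linear: (input, weight, bias)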
input_names = input_names
weight = args[1]
bias = args[2]
add_linear(self, layer_name, output_shape, input_names, weight, bias)
# self.model_forward += INDENT + "NN_linear(&model->{layer_name}, &model->{input_names[0]}, {weight}, {bias});\n".format(
# layer_name=layer_name,
# input_names=self.node_info[layer_name][0],
# weight="&model->{layer_name}_weight".format(layer_name=layer_name),
# bias="&model->{layer_name}_bias".format(layer_name=layer_name)
# )
# self.add_output_tensor(layer_name, output_shape)
elif function == torch.nn.functional.elu:
self.model_forward += INDENT + "NN_linear(&model->{layer_name}, &model->{input_names[0]}, {weight}, {bias});\n".format(
layer_name=layer_name,
input_names=self.node_info[layer_name][0],
weight="&model->{layer_name}_weight".format(layer_name=layer_name),
bias="&model->{layer_name}_bias".format(layer_name=layer_name)
)
if weight is not None:
# weight needs to be converted from (out_ch, in_ch, kh, kw) to (kh, kw, in_ch, out_ch)
self.add_data_tensor(
"{layer_name}_weight".format(layer_name=layer_name),
weight
)
if bias is not None:
self.add_data_tensor(
"{layer_name}_bias".format(layer_name=layer_name),
bias
)

self.add_output_tensor(layer_name, output_shape)
elif target == torch.nn.functional.elu:
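# positional args mirror torch.nn.functional.elu: (input, alpha);
# alpha is forwarded as the last argument of the generated NN_elu call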
self.model_forward += INDENT + "NN_elu(&model->{layer_name}, &model->{input_names[0]}, {eps});\n".format(
layer_name=layer_name,
input_names=input_names,
eps=args[1]
)
self.add_output_tensor(layer_name, output_shape)
else:
print("[WARNING] Unsupported function call:", function)

return out

def call_function(self, target, args, kwargs):
print("call function:", target)
output = super().call_function(target, args, kwargs)

# because in the functional API the function name is not unique,
# we need to add a counter to the layer name
count = self.function_counter.get(target.__name__, 0)
self.function_counter[target.__name__] = count + 1

layer_name = target.__name__ + "_{count}".format(count=count) if count > 0 else target.__name__

self.trace_functional(layer_name, target, args, kwargs)

return output

def call_method(self, target, args, kwargs):
print("call method:", target)
return super().call_method(target, args, kwargs)

def call_module(self, target, args, kwargs):
print("call module:", target)
output = super().call_module(target, args, kwargs)

module = self.get_module(target)
layer_name = target.replace(".", "_")
def run_node(self, n: torch.fx.node.Node) -> Any:
out = super().run_node(n)
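# Execute the node first so its real output (and output shape) is available,
# then translate it into generated code based on the opcode.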

self.model_init += "\n"
self.model_init += INDENT + "// {module}: {layer_name}\n".format(
module=type(module),
layer_name=layer_name
if n.op == "placeholder":
print("placeholder:", n.name)

# HACK: assume the example input is NCHW and reorder the shape to NHWC
shape = self.example_input.shape
if len(shape) == 4:
shape = (shape[0], shape[2], shape[3], shape[1])

self.model_struct += INDENT + TEMPLATE_TENSOR_DECLARE.format(name=n.name)

self.model_init += INDENT + TEMPLATE_TENSOR_INIT.format(
name=n.name,
dim=len(shape),
shape=", ".join(str(x) for x in shape),
dtype=TorchConverter.dtype_to_str(self.example_input.dtype),
data="NULL"
)

if type(module) == torch.nn.Linear:
args = (args[0], module.weight, module.bias)
self.trace_functional(layer_name, torch.nn.functional.linear, args, kwargs)

elif type(module) == torch.nn.BatchNorm2d:
args = (args[0], module.weight, module.bias, module.running_mean, module.running_var, module.eps)
self.trace_functional(layer_name, torch.nn.functional.batch_norm, args, kwargs)

elif type(module) == torch.nn.Conv2d:
args = (args[0], module.weight, module.bias, module.stride, module.padding, module.dilation, module.groups)
self.trace_functional(layer_name, torch.nn.functional.conv2d, args, kwargs)

# elif n.op == "get_attr":
# breakpoint()

elif type(module) == torch.nn.ReLU:
self.trace_functional(layer_name, torch.nn.functional.relu, args, kwargs)
elif n.op == "call_function":
print("call function:", n.name, n.target)
args = n.args

elif type(module) == torch.nn.ReLU6:
self.trace_functional(layer_name, torch.nn.functional.relu6, args, kwargs)
if n.target == torch.nn.functional.linear:
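# in the optimized graph weight and bias arrive as get_attr node references,
# so look up the actual parameter tensors in the model state dict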
weight = self.model.state_dict()[n.args[1].target]
bias = self.model.state_dict()[n.args[2].target]
args = (n.args[0], weight, bias)

self.trace_functional(n.name, n.target, args, n.kwargs, out)

elif type(module) == torch.nn.ELU:
args = (args[0], module.alpha)
self.trace_functional(layer_name, torch.nn.functional.elu, args, kwargs)

else:
print("[WARNING] Unsupported module call:", target)

return output
elif n.op == "call method":
print("call method:", n.name, n.target)
raise NotImplementedError()

elif n.op == "call_module":
print("call module:", n.name, n.target)
args = n.args
layer_name = n.name

module = self.get_module(n.target)
target = TorchConverter.to_functional_torch(module)
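# resolve the module to its functional counterpart so the layer can be traced
# with trace_functional, just like a call_function node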

self.model_init += "\n"
self.model_init += INDENT + "// {module}: {layer_name}\n".format(
module=type(module),
layer_name=layer_name
)

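# rebuild the positional argument tuple expected by the functional call,
# pulling parameters (weight, bias, etc.) off the module itself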
if type(module) == torch.nn.Linear:
args = (n.args[0], module.weight, module.bias)

elif type(module) == torch.nn.BatchNorm2d:
args = (n.args[0], module.weight, module.bias, module.running_mean, module.running_var, module.eps)

elif type(module) == torch.nn.Conv2d:
args = (n.args[0], module.weight, module.bias, module.stride, module.padding, module.dilation, module.groups)

elif type(module) == torch.nn.ELU:
args = (n.args[0], module.alpha)

else:
print("[WARNING] Unsupported module call:", n.target)

self.trace_functional(layer_name, target, args, n.kwargs, out)

return out

def convert(self, example_input, output_dir="."):
self.example_input = example_input
@@ -420,6 +432,7 @@ def forward(self, input):
test_input = torch.zeros((48, )).unsqueeze(0)
print("input:", test_input)

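# dump the traced graph, then run the conversion on the example input
# (generated output goes to output_dir, which defaults to the current directory)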
TorchConverter(m).print_graph()
output = TorchConverter(m).convert(test_input)
print("output:", output)

Binary file removed barstools/src/barstools/model.bin
