diff --git a/example/openface/adams.png b/example/openface/adams.png
new file mode 100755
index 0000000..ac90a2e
Binary files /dev/null and b/example/openface/adams.png differ
diff --git a/example/openface/convert.py b/example/openface/convert.py
new file mode 100644
index 0000000..0fb0042
--- /dev/null
+++ b/example/openface/convert.py
@@ -0,0 +1,27 @@
+import sys
+import os
+sys.path.append(
+    os.path.dirname(os.path.realpath(__file__)) + "/../../torch2coreml"
+)
+
+from _torch_converter import convert
+
+
+def convert_unknown(builder, name, layer, input_names, output_names):
+    # Fallback for layers without a registered converter: report them and
+    # keep going instead of aborting the whole conversion.
+    print("!! No converter yet for layer: " + name)
+    return output_names
+
+coreml_model = convert(
+    "openface.t7",
+    [(3, 96, 96)],
+    image_input_names=['input'],
+    output_shapes=[[128]],
+    mode=None,
+    unknown_layer_converter_fn=convert_unknown
+)
+
+coreml_model.author = 'Leonardo Galli feat. OpenFace'
+coreml_model.license = 'Free for personal or research use'
+coreml_model.save("openface.mlmodel")
diff --git a/example/openface/coremlmodel.png b/example/openface/coremlmodel.png
new file mode 100644
index 0000000..7d16681
Binary files /dev/null and b/example/openface/coremlmodel.png differ
diff --git a/example/openface/prepare.lua b/example/openface/prepare.lua
new file mode 100644
index 0000000..a985ad2
--- /dev/null
+++ b/example/openface/prepare.lua
@@ -0,0 +1,70 @@
+require 'torch'
+require 'nn'
+require 'dpnn'
+
+-- Recursively replace every module of type `name` inside container `x` with
+-- whatever `create_fn` builds from the old module.
+local function replaceModule(x, name, create_fn)
+    if not x.modules then
+        return
+    end
+    for i = 1, #x.modules do
+        local m = x.modules[i]
+        if m.__typename == name then
+            x.modules[i] = create_fn(m)
+        end
+        replaceModule(m, name, create_fn)
+    end
+end
+
+-- dpnn's nn.Inception keeps its actual network in the `module` field;
+-- unwrap it so the converter only sees plain nn containers.
+local function replaceInception(x)
+    if not x.modules then
+        return
+    end
+    for i = 1, #x.modules do
+        local m = x.modules[i]
+        if m.__typename == 'nn.Inception' then
+            print(m.module)
+            x.modules[i] = m.module
+        end
+        replaceInception(m)
+    end
+end
+
+local function main()
+    local cmd = torch.CmdLine()
+    cmd:option('-input', '')
+    cmd:option('-output', '')
+    local opt = cmd:parse(arg)
+    local model = torch.load(opt.input)
+
+    -- torch2coreml has no converter for SpatialConvolutionMM, so swap in an
+    -- equivalent SpatialConvolution that reuses the trained parameters.
+    replaceModule(model, 'nn.SpatialConvolutionMM', function(n)
+        torch.setdefaulttensortype('torch.FloatTensor')
+        local new = nn.SpatialConvolution(n.nInputPlane, n.nOutputPlane,
+            n.kW, n.kH, n.dW, n.dH, n.padW, n.padH)
+        torch.setdefaulttensortype('torch.LongTensor')
+        new.weight = n.weight
+        new.bias = n.bias
+        new.gradWeight = n.gradWeight
+        new.gradBias = n.gradBias
+        return new
+    end)
+
+    replaceInception(model)
+
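+    -- Run a throwaway forward pass so every module caches an `output`
+    -- tensor: the converter's DepthConcat handler and visualize_torch.py
+    -- both read layer.output.shape, which is only meaningful after forward().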
+    local x = torch.FloatTensor()
+    model:forward(x:resize(1, 3, 96, 96))
+    print(model)
+
+    -- Save prepared model
+    torch.save(opt.output, model)
+end
+
+main()
diff --git a/example/openface/print.py b/example/openface/print.py
new file mode 100644
index 0000000..5f83f09
--- /dev/null
+++ b/example/openface/print.py
@@ -0,0 +1,4 @@
+from coremltools.models import MLModel
+
+m = MLModel("openface.mlmodel")
+print(m.get_spec())
diff --git a/example/openface/test.py b/example/openface/test.py
new file mode 100644
index 0000000..4e9884c
--- /dev/null
+++ b/example/openface/test.py
@@ -0,0 +1,13 @@
+from coremltools.models import MLModel
+from PIL import Image
+
+print("Loading model...")
+m = MLModel("openface.mlmodel")
+
+test = Image.open("adams.png")
+
+print("Predicting...")
+
+pred = m.predict({"input": test})
+
+print("Prediction: ", pred)
diff --git a/example/openface/torchmodel.png b/example/openface/torchmodel.png
new file mode 100644
index 0000000..9b12a92
Binary files /dev/null and b/example/openface/torchmodel.png differ
diff --git a/example/openface/visualize.py b/example/openface/visualize.py
new file mode 100644
index 0000000..aa38ebb
--- /dev/null
+++ b/example/openface/visualize.py
@@ -0,0 +1,19 @@
+import pydot  # requires pydot (and graphviz) to be installed
+from coremltools.models import MLModel
+
+# Create a graph with one edge per layer input
+graph = pydot.Dot(graph_type='graph')
+
+print("Loading model...")
+m = MLModel("openface.mlmodel")
+
+print("Loading layers...")
+layers = m.get_spec().neuralNetwork.layers
+
+print("Drawing graph...")
+for layer in layers[::-1]:  # walk the layers in reverse order
+    for input_name in layer.input:
+        edge = pydot.Edge(input_name, layer.name)
+        graph.add_edge(edge)
+
+graph.write_png('coremlmodel.png')
diff --git a/example/openface/visualize_torch.py b/example/openface/visualize_torch.py
new file mode 100644
index 0000000..9eca955
--- /dev/null
+++ b/example/openface/visualize_torch.py
@@ -0,0 +1,27 @@
+import pydot
+from torch.utils.serialization import load_lua
+
+graph = pydot.Dot(graph_type='graph')
+
+m = load_lua("openface.t7")
+
+last_layer = None
+
+num = len(m.modules)
+
+# Walk the top-level modules back to front and connect each one to its
+# successor.  Nodes are labelled with the module repr and its output shape
+# (`output` is populated by the forward pass that prepare.lua runs).
+for layer in m.modules[::-1]:
+    if last_layer is None:
+        last_layer = layer
+        continue
+    fmt = "{0}_{1}\n{0.output.shape}"
+    input_name = fmt.format(layer, num)
+    output_name = fmt.format(last_layer, num + 1)
+    edge = pydot.Edge(input_name, output_name)
+    graph.add_edge(edge)
+    last_layer = layer
+    num -= 1
+
+graph.write_png('torchmodel.png')
diff --git a/torch2coreml/_layers.py b/torch2coreml/_layers.py
index 17d5117..74ad18c 100644
--- a/torch2coreml/_layers.py
+++ b/torch2coreml/_layers.py
@@ -153,6 +153,57 @@ def _convert_concat_table(builder, name, layer, input_names, output_names):
         result_outputs += l_outputs
     return result_outputs
 
+def _convert_depth_concat(builder, name, layer, input_names, output_names):
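+    # Torch's DepthConcat runs each branch on the same input and concatenates
+    # the branch outputs along the depth (channel) axis, zero-padding smaller
+    # spatial maps up to the size of the largest branch.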
+    # Each DepthConcat holds one sub-Sequential net per branch, each with
+    # its own layers.
+    nets = layer.modules
+    base_name = output_names[0]
+    # Branches can produce smaller spatial outputs, so keep track of the
+    # largest size seen and pad smaller ones up to it.  NOTE: this assumes
+    # the largest branch comes first; a bigger later branch would leave
+    # earlier branches unpadded.
+    max_size = (0, 0)
+    net_num = ord('a')  # label the sub-nets with lowercase letters
+    concat_inputs = []  # output name of every branch, in order
+    for net in nets:
+        net_name = base_name + chr(net_num) + "_"
+        layer_num = 1
+        last_name = input_names[0]
+        for l in net.modules:
+            l_name = net_name + _torch_typename(l) + "_" + str(layer_num)
+            _convert_layer(builder, l_name, l, [last_name], [l_name])
+            size = l.output.shape
+            last_name = l_name
+            layer_num += 1
+
+        # Grow the running maximum spatial size
+        if size[2] > max_size[0]:
+            max_size = (size[2], size[3])
+
+        # Pad this branch if it came out smaller than the maximum
+        if size[2] < max_size[0]:
+            diff = max_size[0] - size[2]
+            right_bottom = diff // 2
+            left_top = diff // 2
+            if diff % 2 != 0:
+                right_bottom += 1
+
+            padding_name = net_name + "Padding"
+            builder.add_padding(padding_name,
+                                left=left_top, right=right_bottom,
+                                top=left_top, bottom=right_bottom,
+                                input_name=last_name, output_name=padding_name)
+            last_name = padding_name
+        net_num += 1
+        concat_inputs.append(last_name)
+
+    builder.add_elementwise(name=name, input_names=concat_inputs,
+                            output_name=output_names[0], mode="CONCAT")
+    return output_names
+
 def _convert_batch_norm(builder, name, layer, input_names, output_names):
     epsilon = layer.eps
@@ -215,11 +266,17 @@ def _convert_pooling(builder, name, layer, input_names, output_names):
     elif typename == 'SpatialAveragePooling':
         layer_type = 'AVERAGE'
         exclude_pad_area = not layer.count_include_pad
+    elif typename == 'SpatialLPPooling':
+        # LP pooling with pnorm == 2 is the same as L2 pooling
+        layer_type = 'L2'
     else:
         raise TypeError("Unknown type '{}'".format(typename,))
 
     k_h, k_w = layer.kH, layer.kW
-    pad_h, pad_w = layer.padH, layer.padW
+    # SpatialLPPooling has no padH/padW attributes
+    pad_h, pad_w = 0, 0
+    if layer_type != "L2":
+        pad_h, pad_w = layer.padH, layer.padW
     d_h, d_w = layer.dH, layer.dW
 
     builder.add_pooling(
@@ -266,6 +323,7 @@ def _convert_linear(builder, name, layer, input_names, output_names):
 
 def _convert_view(builder, name, layer, input_names, output_names):
     shape = tuple(layer.size)
+
     if len(shape) == 1 or (len(shape) == 2 and shape[0] == 1):
         builder.add_flatten(
             name=name,
@@ -436,12 +494,25 @@ def _convert_split_table(builder, name, layer, input_names, output_names):
 
     return output_names
 
+def _convert_lrn(builder, name, layer, input_names, output_names):
+    builder.add_lrn(name, input_name=input_names[0], output_name=output_names[0],
+                    alpha=layer.alpha, beta=layer.beta, local_size=layer.size,
+                    k=layer.k)
+    return output_names
+
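+# NOTE (assumption): Torch's nn.Normalize(p) only matches CoreML's
+# l2_normalize when p == 2; other values of p have no equivalent here.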
+def _convert_normalize(builder, name, layer, input_names, output_names):
+    builder.add_l2_normalize(name, input_name=input_names[0],
+                             output_name=output_names[0], epsilon=layer.eps)
+    return output_names
 
 _TORCH_LAYER_REGISTRY = {
     'Sequential': _convert_sequential,
     'SpatialConvolution': _convert_convolution,
     'ELU': _convert_elu,
     'ConcatTable': _convert_concat_table,
+    'DepthConcat': _convert_depth_concat,
     'SpatialBatchNormalization': _convert_batch_norm,
     'Identity': _convert_identity,
     'CAddTable': _convert_cadd_table,
@@ -451,12 +522,16 @@ def _convert_split_table(builder, name, layer, input_names, output_names):
     'ReLU': _convert_relu,
     'SpatialMaxPooling': _convert_pooling,
     'SpatialAveragePooling': _convert_pooling,
+    'SpatialLPPooling': _convert_pooling,
     'View': _convert_view,
+    'Reshape': _convert_view,
     'Linear': _convert_linear,
     'Tanh': _convert_tanh,
     'MulConstant': _convert_mul_constant,
     'Narrow': _convert_narrow,
     'SpatialReflectionPadding': _convert_reflection_padding,
     'SpatialUpSamplingNearest': _convert_upsampling_nearest,
-    'SplitTable': _convert_split_table
+    'SplitTable': _convert_split_table,
+    'SpatialCrossMapLRN': _convert_lrn,
+    'Normalize': _convert_normalize,
 }
diff --git a/torch2coreml/_torch_converter.py b/torch2coreml/_torch_converter.py
index c086e45..791397d 100644
--- a/torch2coreml/_torch_converter.py
+++ b/torch2coreml/_torch_converter.py
@@ -119,6 +119,7 @@ def _set_deprocessing(is_grayscale,
 
 def convert(model,
             input_shapes,
+            output_shapes=None,
             input_names=['input'],
             output_names=['output'],
             mode=None,
@@ -183,7 +184,7 @@ def convert(model,
         _get_layer_converter_fn.unknown_converter_fn = unknown_layer_converter_fn
 
     if isinstance(model, basestring):
-        torch_model = load_lua(model)
+        torch_model = load_lua(model, unknown_classes=True)
     elif isinstance(model, torch.legacy.nn.Sequential):
         torch_model = model
     else:
@@ -192,7 +193,7 @@ def convert(model,
             with torch.legacy.nn.Sequential module as root"
         )
 
-    torch_model.evaluate()
+    # torch_model.evaluate()  # disabled: may fail on models loaded with unknown_classes=True
 
     if not isinstance(input_shapes, list):
         raise TypeError("Input shapes should be a list of tuples.")
@@ -206,10 +207,12 @@ def convert(model,
         "Input names count must be equal to input shapes count"
     )
 
-    output_shapes = _infer_torch_output_shapes(
-        torch_model,
-        input_shapes
-    )
+    # Only infer output shapes when the caller did not supply them
+    if output_shapes is None:
+        output_shapes = _infer_torch_output_shapes(
+            torch_model,
+            input_shapes
+        )
 
     if len(output_shapes) != len(output_names):
         raise ValueError(
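
Usage sketch (not part of the diff above; "second.png" and the script name are
assumptions): OpenFace maps an aligned 96x96 face image to a 128-dimensional
embedding, so once the model is converted, two faces can be compared through
the distance between their embeddings. "output" is the converter's default
output name, since convert.py does not override output_names.

    # compare.py -- hypothetical helper, builds on convert.py's output
    import numpy as np
    from coremltools.models import MLModel
    from PIL import Image

    m = MLModel("openface.mlmodel")

    def embedding(path):
        # Images must already be 96x96 aligned faces, like adams.png
        return np.asarray(m.predict({"input": Image.open(path)})["output"]).flatten()

    a = embedding("adams.png")
    b = embedding("second.png")
    print("Squared L2 distance:", float(np.sum((a - b) ** 2)))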