Commit
Merge branch 'master' into nncase-studio
FusionBolt authored Nov 20, 2023
2 parents 96e15e7 + 3820af1 commit 5c08f36
Showing 30 changed files with 947 additions and 184 deletions.
12 changes: 10 additions & 2 deletions src/Native/src/kernels/stackvm/reference/pad.cpp
@@ -171,8 +171,7 @@ void padding_impl_opt(T *in, T *out, gsl::span<const size_t> in_shape,
dh = out_shape[1];
hh = out_shape[2];
wh = out_shape[3];
} else // size ==2
{
} else if (in_shape.size() == 2) {
cl = 1;
dl = 1;
hl = in_shape[0];
@@ -181,6 +180,15 @@ void padding_impl_opt(T *in, T *out, gsl::span<const size_t> in_shape,
dh = 1;
hh = out_shape[0];
wh = out_shape[1];
} else {
cl = 1;
dl = 1;
hl = 1;
wl = in_shape[0];
ch = 1;
dh = 1;
hh = 1;
wh = out_shape[1];
}

pad_data2(in, out, cl, dl, hl, wl, ch, dh, hh, wh, value);
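The new else branch extends padding_impl_opt to rank-1 inputs: lower-rank shapes are mapped onto the fixed (c, d, h, w) bounds that pad_data2 consumes, with the missing leading axes treated as size 1. A minimal C# sketch of that right-aligned promotion (illustrative only; the hypothetical PromoteTo4D helper is not part of this commit, which assigns the eight bounds explicitly as shown above):

using System;

// Right-align a rank-1..4 shape into (c, d, h, w) by prepending 1s.
static int[] PromoteTo4D(int[] shape)
{
    if (shape.Length is < 1 or > 4)
        throw new ArgumentException("rank must be between 1 and 4", nameof(shape));
    var promoted = new[] { 1, 1, 1, 1 };
    Array.Copy(shape, 0, promoted, 4 - shape.Length, shape.Length);
    return promoted; // e.g. {5} -> {1, 1, 1, 5}, {3, 7} -> {1, 1, 3, 7}
}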
9 changes: 0 additions & 9 deletions src/Native/src/kernels/stackvm/tensor_ops.cpp
@@ -793,14 +793,6 @@ result<value_t> nncase::kernels::stackvm::bucket_pad(
auto in_tensor = input.as<tensor>().expect("input is not a tensor");
auto in_shape = in_tensor->shape();
if (compute_size(in_shape) > compute_size(shape_value)) {
std::cout << "in shape" << std::endl;
for (int i = 0; i < in_shape.size(); ++i) {
std::cout << in_shape[i] << std::endl;
}
std::cout << "shape_value shape" << std::endl;
for (int i = 0; i < shape_value.size(); ++i) {
std::cout << shape_value[i] << std::endl;
}
return err(std::errc::invalid_argument);
}

@@ -1138,7 +1130,6 @@ nncase::kernels::stackvm::squeeze(value_t input, value_t dim, value_t output,
try_var(in_tensor, input.as<tensor>());
auto in_shape = in_tensor->shape();
not_impl_no_contiguous(in_tensor);
// todo: dim is scalar
try_positive_axes(axes, dim, in_tensor->shape().size());
auto new_shape = squeeze_infer_shape(in_shape, axes);
output = tensor_reshape(in_tensor, new_shape);
4 changes: 4 additions & 0 deletions src/Nncase.Core/PatternMatch/PatternUtility.cs
@@ -285,4 +285,8 @@ public static Pattern IsCallWildcardMaybeSwappable<TOp>(string callName, Pattern
IsAlt(
IsCallWildcard(callName, IsOp<TOp>(callName + "Op"), input),
IsCallWildcardSwappable(callName, IsOp<TOp>(callName + "Op"), input, swappableOther ?? IsWildcard()));

public static Pattern MaybeMarker(Pattern input) => IsAlt(input, IsRangeOfMarker(input, IsWildcard()));

public static Pattern HasMarker(Pattern input, string? markerName = null) => IsRangeOfMarker(markerName, input, IsWildcard());
}
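The two new helpers wrap the existing IsRangeOfMarker pattern: MaybeMarker matches a sub-pattern whether or not a range-of marker surrounds it, and HasMarker requires the marker and optionally names it. A rough usage sketch (fragment; the pattern names are illustrative and it assumes the usual static imports of the Nncase.PatternMatch utilities):

// Accept an input with or without a range marker, and capture a named marker
// when a rewrite rule needs to re-attach it afterwards.
Pattern input = IsWildcard("input");
Pattern tolerant = MaybeMarker(input);                        // input, or marker(input, range)
Pattern required = HasMarker(IsWildcard("lhs"), "lhsMarker"); // marker must be present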
15 changes: 12 additions & 3 deletions src/Nncase.Evaluator/EvaluatorDumpManager.cs
@@ -7,6 +7,7 @@
using System.Linq;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.IR.Tensors;
using Nncase.Utilities;
using CallbacksRegister = System.Action<string, System.Action<Nncase.IR.Expr>>;
using TensorGetter = System.Func<Nncase.IR.Expr, Nncase.Tensor[]>;
@@ -27,6 +28,7 @@ public EvaluatorDumpManager(IDumpper dumpper, TensorGetter tensorGetter)
_dumpper = dumpper;
_tensorGetter = tensorGetter;

// todo: has bug when evaluate sub function
if (_dumpper.IsEnabled(DumpFlags.Evaluator))
{
_shapeWriter = new StreamWriter(_dumpper.OpenFile("!out_shape_list"));
@@ -57,6 +59,12 @@ private static string GetTargetName(Call call)

private void DumpCallArgs(Call call)
{
// todo: fix this
if (call.Target is not Op)
{
return;
}

string target = GetTargetName(call);
var paramsInfo = ((Op)call.Target).Parameters.ToArray();

@@ -75,11 +83,12 @@ private void DumpCall(Call call)
string target = GetTargetName(call);

// a bad tmp change
var shape = !(call.CheckedType is TensorType) ? Shape.Scalar : call.CheckedShape;
var result = _tensorGetter(call);

// todo: when tuple maybe bug
var shape = result.Length == 1 ? result[0].Shape : result[0].Shape.ToValueArray();
DumpCall(target, shape, sr =>
{
sr.WriteLine(target);
var result = _tensorGetter(call);
ValueDumper.DumpTensors(result, sr);
});
}
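Two changes land here: argument dumping now bails out when the call target is not an Op (calls into sub-functions, per the todo), and the dumped shape is derived from the evaluated tensors instead of CheckedShape. A small sketch of the guard's effect (illustrative; the pattern-variable form below is an equivalent way to write the check in the diff, not the committed code):

// Without the guard, the unconditional cast throws for targets that are not an Op,
// e.g. a call into a sub-function.
if (call.Target is not Op op)
{
    return; // sub-function calls are skipped for now
}

var paramsInfo = op.Parameters.ToArray();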
2 changes: 1 addition & 1 deletion src/Nncase.Evaluator/NN/BatchToSpace.cs
@@ -203,7 +203,7 @@ private IRType Visit(ITypeInferenceContext context, BatchToSpace target, TensorT
var m = blockShape.Shape[0].FixedValue;
var cropsV = cropsValue.Value.Cast<int>();
var cropSection = Enumerable.Range(0, m).Select(
i => (inShape[i + 1] * blockShapeArr[0]) - cropsV[i, 0] - cropsV[i, 1]);
i => (inShape[i + 1] * blockShapeArr[i]) - cropsV[i, 0] - cropsV[i, 1]);

var remainSize = inShape.Rank - 1 - m;
var remainShape = remainSize > 0 ? inShape.Skip(1 + m) : Array.Empty<Dimension>();
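This one-character fix matters whenever the block shape differs per spatial axis: the old expression scaled every cropped dimension by blockShapeArr[0]. A worked example with illustrative numbers (assumes System.Linq):

// blockShape = {2, 3}, spatial input dims = {4, 5}, zero crops.
// Old: 4 * 2 = 8, 5 * 2 = 10  (second axis scaled by the wrong block size)
// New: 4 * 2 = 8, 5 * 3 = 15  (each axis scaled by its own block size)
var blockShapeArr = new[] { 2, 3 };
var spatial = new[] { 4, 5 };
var cropSection = Enumerable.Range(0, 2)
    .Select(i => spatial[i] * blockShapeArr[i])
    .ToArray(); // {8, 15}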
96 changes: 92 additions & 4 deletions src/Nncase.Evaluator/NN/SpaceToBatch.cs
@@ -45,6 +45,7 @@ public Metric Visit(IMetricEvaluateContext context, SpaceToBatch target)
public IValue Visit(IEvaluateContext context, SpaceToBatch s)
{
var input = context.GetOrtArgumentValue(s, SpaceToBatch.Input);
input = NCHWToNHWC(input.ToTensor()).Evaluate().AsTensor().ToOrtTensor();
var blockShape = context.GetArgumentValueAsTensor<long>(s, SpaceToBatch.BlockShape);
var paddings = context.GetArgumentValueAsArray<long>(s, SpaceToBatch.Paddings);
var spatialSize = blockShape.Length;
@@ -82,7 +83,8 @@ public IValue Visit(IEvaluateContext context, SpaceToBatch s)
var reshape1 = OrtKI.Reshape(p, (OrtKISharp.Tensor)reshappedShape1, 0);
var rt = OrtKI.Transpose(reshape1, perm);
var reshape2 = OrtKI.Reshape(rt, (OrtKISharp.Tensor)reshappedShape2, 0);
return reshape2.ToValue();

return NHWCToNCHW(reshape2.ToTensor()).Evaluate();
}

/// <inheritdoc/>
@@ -97,6 +99,9 @@ public IRType Visit(ITypeInferenceContext context, SpaceToBatch target)
public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target)
{
var inShape = context.GetArgumentShape(target, SpaceToBatch.Input);
var inputExpr = context.GetArgument(target, SpaceToBatch.Input);
inShape = ShapeValueNCHWToNHWC(inputExpr, inShape);

var blockShape = context.GetArgument(target, SpaceToBatch.BlockShape);
var padding = Cast(context.GetArgument(target, SpaceToBatch.Paddings), DataTypes.Int64);
var input = context.GetArgument(target, SpaceToBatch.Input);
Expand Down Expand Up @@ -125,12 +130,89 @@ public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target)
var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty<long>());
var outLast = remainShape;
var outShape = Concat(new IR.Tuple(Stack(new IR.Tuple(outFirst.Concat(outMid).ToArray()), 0), outLast), 0);

outShape = ShapeValueNHWCToNCHW(inputExpr, outShape);

return outShape;
}

throw new NotImplementedException();
}

private static Call ShapeValueNHWCToNCHW(Expr inputExpr, Call outShape)
{
if (inputExpr.CheckedShape.Rank == 4)
{
outShape = Stack(new IR.Tuple(new[] { outShape[0], outShape[3], outShape[1], outShape[2] }), 0);
}
else if (inputExpr.CheckedShape.Rank == 3)
{
outShape = Stack(new IR.Tuple(new[] { outShape[0], outShape[2], outShape[1] }), 0);
}
else
{
throw new InvalidOperationException();
}

return outShape;
}

private static Expr ShapeValueNCHWToNHWC(Expr inputExpr, Expr inShape)
{
if (inputExpr.CheckedShape.Rank == 4)
{
inShape = Stack(new IR.Tuple(new[] { inShape[0], inShape[2], inShape[3], inShape[1] }), 0);
}
else if (inputExpr.CheckedShape.Rank == 3)
{
inShape = Stack(new IR.Tuple(new[] { inShape[0], inShape[2], inShape[1] }), 0);
}
else
{
throw new InvalidOperationException();
}

return inShape;
}

private static Dimension[] ShapeNHWCToNCHW(List<Dimension> inShape, List<Dimension> outshape)
{
Dimension[] outputShape;

// nhwc to nchw
if (inShape.Count == 4)
{
outputShape = new[] { outshape[0], outshape[3], outshape[1], outshape[2] };
}
else
{
outputShape = new[] { inShape[0], inShape[2], inShape[1] };
}

return outputShape;
}

private static Dimension[] ShapeNCHWToNHWC(List<Dimension> inShape)
{
Dimension[] padded_shape;

// nchw to nhwc
if (inShape.Count == 4)
{
padded_shape = new[] { inShape[0], inShape[2], inShape[3], inShape[1] };
}
else if (inShape.Count == 3)
{
padded_shape = new[] { inShape[0], inShape[2], inShape[1] };
}
else
{
throw new InvalidOperationException();
}

return padded_shape;
}

private T[] RangeExec<T>(long end, Func<int, T> f)
{
return EndRange(0, (int)end).Select(f).ToArray();
@@ -149,7 +231,11 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT
var ts_block_shape = block_shape_con.Value.Cast<int>();
var ts_paddings = paddings_con.Value.ToArray<int>();
int m = (int)ts_block_shape.Length;
var padded_shape = input.Shape.ToList();

// var padded_shape = input.Shape.ToList();
var inShape = input.Shape.ToList();
var padded_shape = ShapeNCHWToNHWC(inShape);

for (int i = 0; i < m; i++)
{
if (!padded_shape[1 + i].IsUnknown)
@@ -168,7 +254,7 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT
new InvalidType($"The Padded Shape Must Divides BlockShape!")));
}

foreach (var i in Enumerable.Range(m + 1, padded_shape.Count - (m + 1)))
foreach (var i in Enumerable.Range(m + 1, padded_shape.Length - (m + 1)))
{
outshape.Add(padded_shape[i]);
}
@@ -178,7 +264,9 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT
outshape[0] = outshape[0].IsUnknown ? Dimension.Unknown : outshape[0].FixedValue * block;
}

return input with { Shape = new Shape(outshape) };
var outputShape = ShapeNHWCToNCHW(inShape, outshape);

return input with { Shape = new Shape(outputShape) };
}

return new TensorType(input.DType, Enumerable.Repeat(Dimension.Unknown, input.Shape.Count).ToArray());
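The evaluator, shape evaluator and type inference now agree on layout handling: the input is permuted from NCHW to NHWC before the ONNX-style space-to-batch arithmetic and permuted back afterwards, with rank-3 tensors treated as NCW <-> NWC. A sketch of the permutations the new helpers encode (illustrative helper, not from the repository; assumes System.Linq):

// Rank 4: NCHW -> NHWC uses perm (0, 2, 3, 1); the inverse NHWC -> NCHW uses (0, 3, 1, 2).
// Rank 3: NCW -> NWC uses (0, 2, 1), which is its own inverse.
static int[] Permute(int[] shape, int[] perm) => perm.Select(axis => shape[axis]).ToArray();

// Example: Permute(new[] { 1, 3, 16, 16 }, new[] { 0, 2, 3, 1 }) yields { 1, 16, 16, 3 }.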
7 changes: 7 additions & 0 deletions src/Nncase.Evaluator/Tensors/BucketPad.cs
@@ -1,6 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using DryIoc.ImTools;
using Nncase.CostModel;
@@ -9,6 +10,7 @@
using Nncase.IR.Tensors;
using OrtKISharp;
using static Nncase.IR.F.Tensors;
using Tuple = Nncase.IR.Tuple;

namespace Nncase.Evaluator.Tensors;

@@ -27,6 +29,11 @@ public IValue Visit(IEvaluateContext context, BucketPad bucketPad)
}

var shape = context.GetArgumentValueAsArray<int>(bucketPad, BucketPad.Shape);
if (input.Shape.Size > shape.Aggregate((x, sum) => x * sum))
{
throw new InvalidOperationException();
}

var pads = shape - (Expr)input.Shape;
var paddings = Transpose(
Stack(new Tuple(Enumerable.Repeat(0, shape.Length).ToArray(), pads), 0),
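The new check throws when the input holds more elements than the requested bucket shape, mirroring the guard in the native bucket_pad kernel (which returns std::errc::invalid_argument, as in the tensor_ops.cpp hunk above). A sketch of the element-count comparison with illustrative values (assumes System.Linq):

// The guard compares element counts, not ranks: input {1, 3, 40} has 120 elements,
// bucket shape {1, 3, 32} has 96, so padding up to the bucket is impossible.
var shape = new[] { 1, 3, 32 };
var bucketSize = shape.Aggregate(1, (acc, dim) => acc * dim); // 96
var inputSize = 1 * 3 * 40;                                   // 120
if (inputSize > bucketSize)
{
    throw new InvalidOperationException("input does not fit the bucket shape");
}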
31 changes: 29 additions & 2 deletions src/Nncase.Importer/TFLite/SpaceToBatchND.cs
@@ -4,6 +4,7 @@
using Nncase.IR.Tensors;
using static Nncase.IR.F.NN;
using static Nncase.IR.F.Tensors;
using Unsqueeze = Nncase.IR.Tensors.Unsqueeze;

namespace Nncase.Importer.TFLite
{
@@ -13,14 +14,40 @@ private Expr VisitSpaceToBatchND(in tflite.Operator op)
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var paddings = GetInputExprs(op, 2);
return SpaceToBatch(input, blockShape, paddings);
if (input.CheckedShape.Rank == 3)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
paddings = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, paddings }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var stb = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), blockShape, paddings));
if (input.CheckedShape.Rank == 3)
{
return Squeeze(stb, new[] { 1 });
}

return stb;
}

private Expr VisitBatchToSpaceND(in tflite.Operator op)
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var crops = GetInputExprs(op, 2);
return NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops));
if (input.CheckedShape.Rank == 3)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
crops = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, crops }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var bts = NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops));
if (input.CheckedShape.Rank == 3)
{
return Squeeze(bts, new[] { 1 });
}

return bts;
}
}
}
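The importer previously assumed rank-4 NHWC inputs; rank-3 inputs are now promoted by prepending a unit spatial axis (block size 1, zero paddings/crops), running the rank-4 NCHW SpaceToBatch/BatchToSpace path, and squeezing the unit axis back out. A shape-level sketch with illustrative values (assumes System.Linq; not part of the commit):

// A rank-3 input {N=1, W=8, C=4} with blockShape {2} and paddings {{1, 1}} is handled as
// a rank-4 input {1, 1, 8, 4} with blockShape {1, 2} and paddings {{0, 0}, {1, 1}};
// the inserted unit axis is squeezed away after the rank-4 op runs.
var blockShape = new[] { 2L };
var promotedBlock = new[] { 1L }.Concat(blockShape).ToArray(); // {1, 2}
var paddings = new[,] { { 1L, 1L } };                          // promoted to {{0, 0}, {1, 1}}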