From 3820af1eafea1e54a4df7d3bdbc424bfa05e1438 Mon Sep 17 00:00:00 2001 From: FusionBolt <59008347+FusionBolt@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:47:58 +0800 Subject: [PATCH] Fix/melgan (#1126) * add FoldDilatedConv2D * add SwapBinaryArg * update merge strategy(static shape, merge multiuser for normal, rebuild split from bucket, extract bucket condition) * fix SpaceToBatch * fix ShapeBucket * fix * add test * fix * fix * fix * fix marker * fix some bug * fix some bug * add FoldTransposeActTranspose * update * format and fix some build * fix build * Apply code-format changes * remove unused * fix test * Apply code-format changes * fix * fix test * fix importer * fix something * update * update --------- Co-authored-by: FusionBolt --- .../src/kernels/stackvm/reference/pad.cpp | 12 +- src/Native/src/kernels/stackvm/tensor_ops.cpp | 9 - .../PatternMatch/PatternUtility.cs | 4 + src/Nncase.Evaluator/EvaluatorDumpManager.cs | 15 +- src/Nncase.Evaluator/NN/BatchToSpace.cs | 2 +- src/Nncase.Evaluator/NN/SpaceToBatch.cs | 96 +++++- src/Nncase.Evaluator/Tensors/BucketPad.cs | 7 + src/Nncase.Importer/TFLite/SpaceToBatchND.cs | 31 +- .../Rules/Neutral/FoldDilatedConv2D.cs | 182 +++++++++++ src/Nncase.Passes/Rules/Neutral/FoldPad.cs | 4 +- .../Rules/Neutral/FoldReshape.cs | 2 +- .../Rules/Neutral/FoldTranspose.cs | 2 +- .../Rules/Neutral/SpaceToBatchTransform.cs | 4 +- .../Rules/Neutral/SplitSpaceToBatch.cs | 4 +- .../Rules/Neutral/SwapBinaryArgs.cs | 29 ++ .../Rules/ShapeBucket/MergeBucketFusion.cs | 20 +- .../Rules/ShapeBucket/MergeCallToFusion.cs | 9 +- .../Rules/ShapeBucket/RecordFusionShape.cs | 31 +- .../Rules/ShapeBucket/ShapeBucket.cs | 308 ++++++++++++------ .../Rules/ShapeBucket/ShapeBucketHelper.cs | 54 +-- .../Rules/ShapeExpr/FoldSplitShapeOf.cs | 15 +- .../Rules/WithMarker/CombineTranspose.cs | 259 +++++++++++++++ .../Runtime/Interop/RTHostMemoryManager.cs | 2 +- .../TransformBase/Compare.cs | 5 + .../Evaluator/UnitTestEvaluatorNN.cs | 6 +- .../Evaluator/UnitTestShapeEvaluator.cs | 2 +- .../Neutral/UnitTestSpaceToBatchTransform.cs | 3 +- .../Neutral/UnitTestSplitSpaceToBatch.cs | 3 +- .../Rules/ShapeBucket/ShapeBucketTest.cs | 2 +- .../ShapeExpr/UnitTestFoldSplitShapeOf.cs | 9 + 30 files changed, 947 insertions(+), 184 deletions(-) create mode 100644 src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs create mode 100644 src/Nncase.Passes/Rules/Neutral/SwapBinaryArgs.cs create mode 100644 src/Nncase.Passes/Rules/WithMarker/CombineTranspose.cs diff --git a/src/Native/src/kernels/stackvm/reference/pad.cpp b/src/Native/src/kernels/stackvm/reference/pad.cpp index 2b27400fab..d0fc488b38 100644 --- a/src/Native/src/kernels/stackvm/reference/pad.cpp +++ b/src/Native/src/kernels/stackvm/reference/pad.cpp @@ -171,8 +171,7 @@ void padding_impl_opt(T *in, T *out, gsl::span in_shape, dh = out_shape[1]; hh = out_shape[2]; wh = out_shape[3]; - } else // size ==2 - { + } else if (in_shape.size() == 2) { cl = 1; dl = 1; hl = in_shape[0]; @@ -181,6 +180,15 @@ void padding_impl_opt(T *in, T *out, gsl::span in_shape, dh = 1; hh = out_shape[0]; wh = out_shape[1]; + } else { + cl = 1; + dl = 1; + hl = 1; + wl = in_shape[0]; + ch = 1; + dh = 1; + hh = 1; + wh = out_shape[1]; } pad_data2(in, out, cl, dl, hl, wl, ch, dh, hh, wh, value); diff --git a/src/Native/src/kernels/stackvm/tensor_ops.cpp b/src/Native/src/kernels/stackvm/tensor_ops.cpp index 120361645b..cfd6b7600d 100644 --- a/src/Native/src/kernels/stackvm/tensor_ops.cpp +++ b/src/Native/src/kernels/stackvm/tensor_ops.cpp @@ -793,14 +793,6 @@ 
result nncase::kernels::stackvm::bucket_pad( auto in_tensor = input.as().expect("input is not a tensor"); auto in_shape = in_tensor->shape(); if (compute_size(in_shape) > compute_size(shape_value)) { - std::cout << "in shape" << std::endl; - for (int i = 0; i < in_shape.size(); ++i) { - std::cout << in_shape[i] << std::endl; - } - std::cout << "shape_value shape" << std::endl; - for (int i = 0; i < shape_value.size(); ++i) { - std::cout << shape_value[i] << std::endl; - } return err(std::errc::invalid_argument); } @@ -1138,7 +1130,6 @@ nncase::kernels::stackvm::squeeze(value_t input, value_t dim, value_t output, try_var(in_tensor, input.as()); auto in_shape = in_tensor->shape(); not_impl_no_contiguous(in_tensor); - // todo: dim is scalar try_positive_axes(axes, dim, in_tensor->shape().size()); auto new_shape = squeeze_infer_shape(in_shape, axes); output = tensor_reshape(in_tensor, new_shape); diff --git a/src/Nncase.Core/PatternMatch/PatternUtility.cs b/src/Nncase.Core/PatternMatch/PatternUtility.cs index 17c3ee73af..f142fc21cd 100644 --- a/src/Nncase.Core/PatternMatch/PatternUtility.cs +++ b/src/Nncase.Core/PatternMatch/PatternUtility.cs @@ -285,4 +285,8 @@ public static Pattern IsCallWildcardMaybeSwappable(string callName, Pattern IsAlt( IsCallWildcard(callName, IsOp(callName + "Op"), input), IsCallWildcardSwappable(callName, IsOp(callName + "Op"), input, swappableOther ?? IsWildcard())); + + public static Pattern MaybeMarker(Pattern input) => IsAlt(input, IsRangeOfMarker(input, IsWildcard())); + + public static Pattern HasMarker(Pattern input, string? markerName = null) => IsRangeOfMarker(markerName, input, IsWildcard()); } diff --git a/src/Nncase.Evaluator/EvaluatorDumpManager.cs b/src/Nncase.Evaluator/EvaluatorDumpManager.cs index 976cd5e133..187e1f6615 100644 --- a/src/Nncase.Evaluator/EvaluatorDumpManager.cs +++ b/src/Nncase.Evaluator/EvaluatorDumpManager.cs @@ -7,6 +7,7 @@ using System.Linq; using Nncase.Diagnostics; using Nncase.IR; +using Nncase.IR.Tensors; using Nncase.Utilities; using CallbacksRegister = System.Action>; using TensorGetter = System.Func; @@ -27,6 +28,7 @@ public EvaluatorDumpManager(IDumpper dumpper, TensorGetter tensorGetter) _dumpper = dumpper; _tensorGetter = tensorGetter; + // todo: has bug when evaluate sub function if (_dumpper.IsEnabled(DumpFlags.Evaluator)) { _shapeWriter = new StreamWriter(_dumpper.OpenFile("!out_shape_list")); @@ -57,6 +59,12 @@ private static string GetTargetName(Call call) private void DumpCallArgs(Call call) { + // todo: fix this + if (call.Target is not Op) + { + return; + } + string target = GetTargetName(call); var paramsInfo = ((Op)call.Target).Parameters.ToArray(); @@ -75,11 +83,12 @@ private void DumpCall(Call call) string target = GetTargetName(call); // a bad tmp change - var shape = !(call.CheckedType is TensorType) ? Shape.Scalar : call.CheckedShape; + var result = _tensorGetter(call); + + // todo: when tuple maybe bug + var shape = result.Length == 1 ? 
result[0].Shape : result[0].Shape.ToValueArray(); DumpCall(target, shape, sr => { - sr.WriteLine(target); - var result = _tensorGetter(call); ValueDumper.DumpTensors(result, sr); }); } diff --git a/src/Nncase.Evaluator/NN/BatchToSpace.cs b/src/Nncase.Evaluator/NN/BatchToSpace.cs index e53297c7ab..4a7fc3a18c 100644 --- a/src/Nncase.Evaluator/NN/BatchToSpace.cs +++ b/src/Nncase.Evaluator/NN/BatchToSpace.cs @@ -203,7 +203,7 @@ private IRType Visit(ITypeInferenceContext context, BatchToSpace target, TensorT var m = blockShape.Shape[0].FixedValue; var cropsV = cropsValue.Value.Cast(); var cropSection = Enumerable.Range(0, m).Select( - i => (inShape[i + 1] * blockShapeArr[0]) - cropsV[i, 0] - cropsV[i, 1]); + i => (inShape[i + 1] * blockShapeArr[i]) - cropsV[i, 0] - cropsV[i, 1]); var remainSize = inShape.Rank - 1 - m; var remainShape = remainSize > 0 ? inShape.Skip(1 + m) : Array.Empty(); diff --git a/src/Nncase.Evaluator/NN/SpaceToBatch.cs b/src/Nncase.Evaluator/NN/SpaceToBatch.cs index 98ec949151..ce41b08af7 100644 --- a/src/Nncase.Evaluator/NN/SpaceToBatch.cs +++ b/src/Nncase.Evaluator/NN/SpaceToBatch.cs @@ -45,6 +45,7 @@ public Metric Visit(IMetricEvaluateContext context, SpaceToBatch target) public IValue Visit(IEvaluateContext context, SpaceToBatch s) { var input = context.GetOrtArgumentValue(s, SpaceToBatch.Input); + input = NCHWToNHWC(input.ToTensor()).Evaluate().AsTensor().ToOrtTensor(); var blockShape = context.GetArgumentValueAsTensor(s, SpaceToBatch.BlockShape); var paddings = context.GetArgumentValueAsArray(s, SpaceToBatch.Paddings); var spatialSize = blockShape.Length; @@ -82,7 +83,8 @@ public IValue Visit(IEvaluateContext context, SpaceToBatch s) var reshape1 = OrtKI.Reshape(p, (OrtKISharp.Tensor)reshappedShape1, 0); var rt = OrtKI.Transpose(reshape1, perm); var reshape2 = OrtKI.Reshape(rt, (OrtKISharp.Tensor)reshappedShape2, 0); - return reshape2.ToValue(); + + return NHWCToNCHW(reshape2.ToTensor()).Evaluate(); } /// @@ -97,6 +99,9 @@ public IRType Visit(ITypeInferenceContext context, SpaceToBatch target) public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target) { var inShape = context.GetArgumentShape(target, SpaceToBatch.Input); + var inputExpr = context.GetArgument(target, SpaceToBatch.Input); + inShape = ShapeValueNCHWToNHWC(inputExpr, inShape); + var blockShape = context.GetArgument(target, SpaceToBatch.BlockShape); var padding = Cast(context.GetArgument(target, SpaceToBatch.Paddings), DataTypes.Int64); var input = context.GetArgument(target, SpaceToBatch.Input); @@ -125,12 +130,89 @@ public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target) var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty()); var outLast = remainShape; var outShape = Concat(new IR.Tuple(Stack(new IR.Tuple(outFirst.Concat(outMid).ToArray()), 0), outLast), 0); + + outShape = ShapeValueNHWCToNCHW(inputExpr, outShape); + return outShape; } throw new NotImplementedException(); } + private static Call ShapeValueNHWCToNCHW(Expr inputExpr, Call outShape) + { + if (inputExpr.CheckedShape.Rank == 4) + { + outShape = Stack(new IR.Tuple(new[] { outShape[0], outShape[3], outShape[1], outShape[2] }), 0); + } + else if (inputExpr.CheckedShape.Rank == 3) + { + outShape = Stack(new IR.Tuple(new[] { outShape[0], outShape[2], outShape[1] }), 0); + } + else + { + throw new InvalidOperationException(); + } + + return outShape; + } + + private static Expr ShapeValueNCHWToNHWC(Expr inputExpr, Expr inShape) + { + if (inputExpr.CheckedShape.Rank == 
4) + { + inShape = Stack(new IR.Tuple(new[] { inShape[0], inShape[2], inShape[3], inShape[1] }), 0); + } + else if (inputExpr.CheckedShape.Rank == 3) + { + inShape = Stack(new IR.Tuple(new[] { inShape[0], inShape[2], inShape[1] }), 0); + } + else + { + throw new InvalidOperationException(); + } + + return inShape; + } + + private static Dimension[] ShapeNHWCToNCHW(List inShape, List outshape) + { + Dimension[] outputShape; + + // nhwc to nchw + if (inShape.Count == 4) + { + outputShape = new[] { outshape[0], outshape[3], outshape[1], outshape[2] }; + } + else + { + outputShape = new[] { inShape[0], inShape[2], inShape[1] }; + } + + return outputShape; + } + + private static Dimension[] ShapeNCHWToNHWC(List inShape) + { + Dimension[] padded_shape; + + // nchw to nhwc + if (inShape.Count == 4) + { + padded_shape = new[] { inShape[0], inShape[2], inShape[3], inShape[1] }; + } + else if (inShape.Count == 3) + { + padded_shape = new[] { inShape[0], inShape[2], inShape[1] }; + } + else + { + throw new InvalidOperationException(); + } + + return padded_shape; + } + private T[] RangeExec(long end, Func f) { return EndRange(0, (int)end).Select(f).ToArray(); @@ -149,7 +231,11 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT var ts_block_shape = block_shape_con.Value.Cast(); var ts_paddings = paddings_con.Value.ToArray(); int m = (int)ts_block_shape.Length; - var padded_shape = input.Shape.ToList(); + + // var padded_shape = input.Shape.ToList(); + var inShape = input.Shape.ToList(); + var padded_shape = ShapeNCHWToNHWC(inShape); + for (int i = 0; i < m; i++) { if (!padded_shape[1 + i].IsUnknown) @@ -168,7 +254,7 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT new InvalidType($"The Padded Shape Must Divides BlockShape!"))); } - foreach (var i in Enumerable.Range(m + 1, padded_shape.Count - (m + 1))) + foreach (var i in Enumerable.Range(m + 1, padded_shape.Length - (m + 1))) { outshape.Add(padded_shape[i]); } @@ -178,7 +264,9 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT outshape[0] = outshape[0].IsUnknown ? Dimension.Unknown : outshape[0].FixedValue * block; } - return input with { Shape = new Shape(outshape) }; + var outputShape = ShapeNHWCToNCHW(inShape, outshape); + + return input with { Shape = new Shape(outputShape) }; } return new TensorType(input.DType, Enumerable.Repeat(Dimension.Unknown, input.Shape.Count).ToArray()); diff --git a/src/Nncase.Evaluator/Tensors/BucketPad.cs b/src/Nncase.Evaluator/Tensors/BucketPad.cs index 30595c208e..d8b0350dd9 100644 --- a/src/Nncase.Evaluator/Tensors/BucketPad.cs +++ b/src/Nncase.Evaluator/Tensors/BucketPad.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. 
+using System; using System.Linq; using DryIoc.ImTools; using Nncase.CostModel; @@ -9,6 +10,7 @@ using Nncase.IR.Tensors; using OrtKISharp; using static Nncase.IR.F.Tensors; +using Tuple = Nncase.IR.Tuple; namespace Nncase.Evaluator.Tensors; @@ -27,6 +29,11 @@ public IValue Visit(IEvaluateContext context, BucketPad bucketPad) } var shape = context.GetArgumentValueAsArray(bucketPad, BucketPad.Shape); + if (input.Shape.Size > shape.Aggregate((x, sum) => x * sum)) + { + throw new InvalidOperationException(); + } + var pads = shape - (Expr)input.Shape; var paddings = Transpose( Stack(new Tuple(Enumerable.Repeat(0, shape.Length).ToArray(), pads), 0), diff --git a/src/Nncase.Importer/TFLite/SpaceToBatchND.cs b/src/Nncase.Importer/TFLite/SpaceToBatchND.cs index 14d8f524dd..1bf51dc176 100644 --- a/src/Nncase.Importer/TFLite/SpaceToBatchND.cs +++ b/src/Nncase.Importer/TFLite/SpaceToBatchND.cs @@ -4,6 +4,7 @@ using Nncase.IR.Tensors; using static Nncase.IR.F.NN; using static Nncase.IR.F.Tensors; +using Unsqueeze = Nncase.IR.Tensors.Unsqueeze; namespace Nncase.Importer.TFLite { @@ -13,14 +14,40 @@ private Expr VisitSpaceToBatchND(in tflite.Operator op) { var (input, blockShape) = GetInputExprs(op, 0, 1); var paddings = GetInputExprs(op, 2); - return SpaceToBatch(input, blockShape, paddings); + if (input.CheckedShape.Rank == 3) + { + blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0); + paddings = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, paddings }), 0); + input = Unsqueeze(input, new[] { -3 }); + } + + var stb = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), blockShape, paddings)); + if (input.CheckedShape.Rank == 3) + { + return Squeeze(stb, new[] { 1 }); + } + + return stb; } private Expr VisitBatchToSpaceND(in tflite.Operator op) { var (input, blockShape) = GetInputExprs(op, 0, 1); var crops = GetInputExprs(op, 2); - return NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops)); + if (input.CheckedShape.Rank == 3) + { + blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0); + crops = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, crops }), 0); + input = Unsqueeze(input, new[] { -3 }); + } + + var bts = NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops)); + if (input.CheckedShape.Rank == 3) + { + return Squeeze(bts, new[] { 1 }); + } + + return bts; } } } diff --git a/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs b/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs new file mode 100644 index 0000000000..6bc2b09c30 --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs @@ -0,0 +1,182 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.NN; +using Nncase.IR.Tensors; +using Nncase.Passes.Rules.ShapeExpr; +using Nncase.PatternMatch; +using Nncase.Utilities; +using OrtKISharp; +using static Nncase.IR.F.Math; +using static Nncase.IR.F.NN; +using static Nncase.IR.F.Tensors; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.NN; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; +using Conv2D = Nncase.IR.NN.Conv2D; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public partial class FoldDilatedConv2D : RewriteRule +{ + /// + public override Pattern Pattern { get; } = + IsBatchToSpace( + "bts", + "btsCall", + IsRangeOfMarker( + "btsInput", + Conv2DPattern(), + IsTensorConst()), + IsTensorConst("btsBlockShape"), + IsTensorConst("originCrop")); + + private static CallPattern Conv2DPattern() => + IsCallWildcard( + "conv", + IsOp(), + IsRangeOfMarker( + IsSpaceToBatch( + "sbt", + "stbCall", + IsWildcard("stbInput") with { TypePattern = HasFixedShape() }, + IsTensorConst("stbBlockShape"), + IsTensorConst("originPaddings")), + IsTensorConst())); + + private Expr? GetReplace(Call conv, Call btsCall, Call stbCall, Expr btsInput, Expr stbInput, int[] btsBlockShape, int[] stbBlockShape, int[] originPaddings, int[] originCrop) + { + var btsShape = btsCall.CheckedShape.ToValueArray(); + var btsInputShape = btsInput.CheckedShape.ToValueArray(); + var stbInputShape = stbInput.CheckedShape.ToValueArray(); + + var paddings = new[,] { { originPaddings[0], originPaddings[1] }, { originPaddings[2], originPaddings[3] } }; + var crop = new[,] { { originCrop[0], originCrop[1] }, { originCrop[2], originCrop[3] } }; + + var padIfH = paddings[0, 0] + paddings[0, 1] + stbInputShape[2]; + var padIfW = paddings[1, 0] + paddings[1, 1] + stbInputShape[3]; + var dilationH = stbBlockShape[0]; + var dilationW = stbBlockShape[1]; + var weightsShape = conv.Arguments[Conv2D.Weights.Index].CheckedShape.ToValueArray(); + var wH = weightsShape[2]; + var wW = weightsShape[3]; + var outH = btsShape[2] + crop[0, 0] + crop[0, 1]; + var outW = btsShape[3] + crop[1, 0] + crop[1, 1]; + var strideH = outH == 1 ? 1 : (padIfH - (dilationH * (wH - 1)) - 1) / (outH - 1); + var strideW = outW == 1 ? 
1 : (padIfW - (dilationW * (wW - 1)) - 1) / (outW - 1); + + var (begin, end) = GetBeginEnd(btsBlockShape, crop, btsInputShape); + var slicePadding = new[,] + { + { -begin[0], end[0] - btsShape[0] }, + { -begin[3], end[3] - btsShape[1] }, + { -begin[1], end[1] - btsShape[2] }, + { -begin[2], end[2] - btsShape[3] }, + }; + + var newPaddings = new[,] + { + { 0, 0 }, + { 0, 0 }, + { paddings[0, 0] + (strideH * slicePadding[2, 0]) - crop[0, 0], paddings[0, 1] + (strideH * slicePadding[2, 1]) - crop[0, 1] }, + { paddings[1, 0] + (strideH * slicePadding[3, 0]) - crop[1, 0], paddings[1, 1] + (strideH * slicePadding[3, 1]) - crop[1, 1] }, + }; + + var pairs = new[] + { + (Conv2D.Input.Index, stbInput), + (Conv2D.Padding.Index, (Expr)newPaddings), + (Conv2D.Stride.Index, (Expr)new[] { strideH, strideW }), + (Conv2D.Dilation.Index, (Expr)new[] { dilationH, dilationW }), + }; + return ReplaceUtility.ReplaceCallParams(conv, conv.Arguments.ToArray(), pairs); + } + + private (int[] Begin, int[] End) GetBeginEnd(int[] btsBlockShape, int[,] crop, int[] btsInputShape) + { + List shape_expend = new(); + var block_shape_produt = btsBlockShape.Aggregate((x, sum) => x * sum); + for (var i = 0; i < btsBlockShape.Length; i++) + { + shape_expend.Add(btsBlockShape[i]); + } + + shape_expend.Add(btsInputShape[0] / block_shape_produt); + for (var i = 1; i < btsInputShape.Length; i++) + { + shape_expend.Add(btsInputShape[i]); + } + + List shape_shrink = new(); + shape_shrink.Add(shape_expend[btsBlockShape.Length]); + for (var i = 0; i < btsBlockShape.Length; i++) + { + shape_shrink.Add(btsBlockShape[i] * btsInputShape[i + 1]); + } + + for (var i = btsBlockShape.Length + 1; i < btsInputShape.Length; i++) + { + shape_shrink.Add(btsInputShape[i]); + } + + List crop_begs = new(), crop_ends = new(); + crop_begs.Add(0); + crop_ends.Add(shape_shrink[0]); + for (var i = 0; i < crop.GetLength(0); i++) + { + crop_begs.Add(crop[i, 0]); + crop_ends.Add(shape_shrink[i + 1] - crop[i, 1]); + } + + for (var i = btsBlockShape.Length + 1; i < btsInputShape.Length; i++) + { + crop_begs.Add(0); + crop_ends.Add(shape_shrink[i]); + } + + var cropBegin = crop_begs.ToArray(); + var cropEnd = crop_ends.ToArray(); + var strides = Enumerable.Repeat(1, crop_begs.Count).ToArray(); + var begin = NormalizeStridedSliceBegin(btsInputShape, cropBegin, strides, 0); + var end = NormalizeStridedSliceEndEnd(btsInputShape, begin, cropEnd, strides, 0, 0); + return (begin, end); + } + + private int[] NormalizeStridedSliceEndEnd(int[] in_shape, int[] begin, int[] end, int[] strides, int end_mask, int shrink_axis_mask) + { + var new_shape = Enumerable.Range(0, strides.Length).ToArray(); + for (var i = 0; i < new_shape.Length; i++) + { + var stride = strides[i]; + var end_val = (end_mask & (1 << i)) != 0 + ? stride > 0 ? in_shape[i] : -1 + : (shrink_axis_mask & (1 << i)) == 0 ? (end[i] >= 0 ? end[i] : in_shape[i] + end[i] + 1) + : begin[i] + 1; + new_shape[i] = end_val; + } + + return new_shape; + } + + private int[] NormalizeStridedSliceBegin(int[] in_shape, int[] begin, int[] strides, int begin_mask) + { + var new_shape = Enumerable.Range(0, strides.Length).ToArray(); + for (var i = 0; i < new_shape.Length; i++) + { + var stride = strides[i]; + new_shape[i] = (begin_mask & (1 << i)) != 0 + ? stride > 0 ? 0 : in_shape[i] + : (begin[i] >= 0 ? 
begin[i] : in_shape[i] + begin[i]); + } + + return new_shape; + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/FoldPad.cs b/src/Nncase.Passes/Rules/Neutral/FoldPad.cs index 39b34aab39..2ea62feff0 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldPad.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldPad.cs @@ -80,11 +80,11 @@ public sealed partial class FoldConv2DPads : IRewriteRule public IPattern Pattern { get; } = IsConv2D( "conv", conv => conv.PadMode == PadMode.Constant, - IsPad( + MaybeMarker(IsPad( pad => pad.PadMode == PadMode.Constant, IsWildcard("input"), IsTensorConst("ext_pad"), - IsTensorConst("ext_pad_init")), + IsTensorConst("ext_pad_init"))), IsWildcard("weights"), IsWildcard("bias"), IsWildcard("stride"), diff --git a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs index 013d76f86e..4de3b7f3cc 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs @@ -54,7 +54,7 @@ public sealed partial class FoldTwoReshapes : IRewriteRule { /// public IPattern Pattern { get; } = IsReshape( - IsReshape(IsWildcard("input"), IsWildcard()), IsWildcard("newShape")); + MaybeMarker(IsReshape(IsWildcard("input"), IsWildcard())), IsWildcard("newShape")); private Expr? GetReplace(Expr input, Expr newShape) { diff --git a/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs b/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs index d5e9d6256f..2e719f416c 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs @@ -45,7 +45,7 @@ public sealed partial class FoldTwoTransposes : IRewriteRule { /// public IPattern Pattern { get; } = IsTranspose( - IsTranspose(IsWildcard("input"), IsWildcard("perm1") with { TypePattern = HasRank() }), + MaybeMarker(IsTranspose(IsWildcard("input"), IsWildcard("perm1") with { TypePattern = HasRank() })), IsWildcard("perm2") with { TypePattern = HasRank() }); private Expr? 
GetReplace(Expr input, Expr perm1, Expr perm2) diff --git a/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs b/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs index 56a0e84796..821ffba660 100644 --- a/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs +++ b/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs @@ -41,9 +41,11 @@ public sealed partial class SpaceToBatchToPad : IRewriteRule if (input.CheckedShape.Rank == 4 && blockShapeArray.Length == 2 && blockShapeArray[0] == 1 && blockShape[1] == 1) { var newPaddingsArray = new int[8]; + + // pad for hw for (var i = 0; i < paddingsArray.Length; i++) { - newPaddingsArray[i + 2] = paddingsArray[i]; + newPaddingsArray[i + 4] = paddingsArray[i]; } var newPaddings = Tensor.From(newPaddingsArray, new[] { 4, 2 }); diff --git a/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs index 4361b2db15..8494e72427 100644 --- a/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs +++ b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs @@ -44,7 +44,7 @@ public partial class SplitSpaceToBatch : RewriteRule var tmpPaddings = Stack(new IR.Tuple(newPaddings), 0); var newPaddingsTensor = Transpose(Reshape(tmpPaddings, new long[] { 2, 1 + spatialSize + remainShapeSize }), new long[] { 1, 0 }); - var p = Pad(input, newPaddingsTensor, PadMode.Constant, 0f); + var p = Pad(NCHWToNHWC(input), newPaddingsTensor, PadMode.Constant, 0f); var padShape = Cast(ShapeOf(p), DataTypes.Int32); var batchShape1 = StackScalar(padShape[0]); @@ -77,7 +77,7 @@ public partial class SplitSpaceToBatch : RewriteRule var reshape1 = Reshape(p, reshappedShape1); var rt = Transpose(reshape1, perm); var reshape2 = Reshape(rt, reshappedShape2); - return reshape2; + return NHWCToNCHW(reshape2); } private T[] RangeExec(long end, Func f) diff --git a/src/Nncase.Passes/Rules/Neutral/SwapBinaryArgs.cs b/src/Nncase.Passes/Rules/Neutral/SwapBinaryArgs.cs new file mode 100644 index 0000000000..7626b9840d --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/SwapBinaryArgs.cs @@ -0,0 +1,29 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.IR; +using Nncase.PatternMatch; +using static Nncase.Passes.Utility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public partial class SwapBinaryArgs : RewriteRule +{ + public override Pattern Pattern => IsBinary( + "bn", + "bnCall", + op => op.BinaryOp == BinaryOp.Add || op.BinaryOp == BinaryOp.Mul || op.BinaryOp == BinaryOp.Min || + op.BinaryOp == BinaryOp.Max, + IsTensorConst("lhs"), + IsWildcard("rhs")); + + private Expr? 
GetReplace(Call bnCall, Expr lhs, Expr rhs) + { + return bnCall.With(arguments: new[] { rhs, lhs }); + } +} diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs index be7dad641b..b34c47e170 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs @@ -40,20 +40,12 @@ protected override async Task RunCoreAsync(BaseFunction input, Run while (true) { var preHash = main.GetHashCode(); - if (_greedy) - { - CompilerServices.Rewrite(main, new IRewriteRule[] { new MultiUserCallToFusion(false, _greedy), new MergeTupleFusion() }, new()); - await new MergeSeqBucketFusion().RunAsync(main, context); - IRHelpers.DCE(main); - await new MergeMultiUsersFusion().RunAsync(main, context); - DumpIR(main, $"{i}_before", "FoldNopTuple"); - await new FoldNopTuple().RunAsync(main, context); - } - else - { - await new MergeSeqBucketFusion().RunAsync(main, context); - IRHelpers.DCE(main); - } + CompilerServices.Rewrite(main, new IRewriteRule[] { new MultiUserCallToFusion(!_greedy, _greedy), new MergeTupleFusion() }, new()); + await new MergeSeqBucketFusion().RunAsync(main, context); + IRHelpers.DCE(main); + await new MergeMultiUsersFusion().RunAsync(main, context); + DumpIR(main, $"{i}_before", "FoldNopTuple"); + await new FoldNopTuple().RunAsync(main, context); CheckRepeat(main); CheckErrorVar(main, main.Parameters.ToArray()); diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs index 08e232585e..72fb29e737 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs @@ -135,8 +135,8 @@ public MergeNextCallToFusion() // nextCall(marker(fusion(x))) -> fusion(nextCall(marker(x))) public Expr? 
GetReplace(Call nextCall, Expr maybeFusionCallMarker, Expr target, Call fusionOuterCall, BucketFusion fusion) { - var singleVar = SingleDimVar(CompileSession.CompileOptions.ShapeBucketOptions); - if (!singleVar && nextCall.Arguments.ToArray().OfType().Count() > 1) + _ = SingleDimVar(CompileSession.CompileOptions.ShapeBucketOptions); + if (!AllConst(nextCall)) { return null; } @@ -159,11 +159,6 @@ public MergeNextCallToFusion() return null; } - if (!AllConst(nextCall)) - { - return null; - } - DumpIR(nextCall, $"{Counter}_{fusion.Name}_{target.GetType().Name}_origin"); // 将call里面call fusion的部分替换为fusion的body diff --git a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs index 94f7977f2f..59775c2c3a 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs @@ -7,10 +7,14 @@ using System.Reactive; using System.Threading.Tasks; using Google.OrTools.Algorithms; +using Google.OrTools.Graph; +using Microsoft.Toolkit.HighPerformance; using NetFabric.Hyperlinq; using Nncase.Diagnostics; using Nncase.Evaluator; using Nncase.IR; +using Nncase.IR.Tensors; +using Nncase.Utilities; using static Nncase.IR.F.Tensors; using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper; @@ -72,11 +76,14 @@ private IValue GetShape(IValue value) public class RecordFusionShape : FunctionPass { + private readonly bool _once; + private Dictionary _dimVarValues = new(); - public RecordFusionShape(Dictionary shapeList) + public RecordFusionShape(Dictionary shapeList, bool once = false) { FusionShapeInfo = shapeList; + _once = once; } public Dictionary FusionShapeInfo { get; set; } @@ -90,6 +97,13 @@ public static Dictionary pair => pair.Key, pair => { + if (pair.Key.CheckedShape.IsFixed) + { + return ConstantOfShape( + pair.Key.CheckedShape.ToValueArray(), + Cast(1, pair.Key.CheckedDataType)).Evaluate(); + } + // todo: dummy input可能会有问题... var shapeExpr = pair.Key.CheckedShape.IsScalar ? (Expr)Array.Empty() @@ -106,14 +120,22 @@ protected override Task RunCoreAsync(BaseFunction main, RunPassCon { var options = CompileSession.CompileOptions.ShapeBucketOptions; var varMap = options.VarMap; - _dimVarValues = MakeVarValuesForAllSegment(options); + + var staticShape = IsStaticShpae; + var segmentCount = staticShape + && SingleDimVar(options) + ? options.RangeInfo.First().Value.Max + : options.SegmentsCount; + + _dimVarValues = MakeVarValuesForAllSegment(options, segmentCount, staticShape); // 一共有多组key seg - var list = Enumerable.Range(0, _dimVarValues.First().Value.Length).Select(i => + var tmpList = Enumerable.Range(0, _dimVarValues.First().Value.Length).Select(i => { // 一组里面多个key seg return _dimVarValues.Select(pair => (pair.Key, Value: pair.Value[i])).ToArray(); - }).ToArray(); + }); + var list = _once ? 
tmpList.TakeLast(1).ToArray() : tmpList.ToArray(); var body = ((Function)main).Body; var tmpFusionShapeList = list.Select((seg, i) => @@ -129,7 +151,6 @@ protected override Task RunCoreAsync(BaseFunction main, RunPassCon .ToLookup(x => x.Key, x => x.Value) .ToDictionary(pair => pair.Key, pair => pair.ToArray()); - GC.Collect(); foreach (var (f, shapeInfo) in tmpFusionShapeList) { FusionShapeInfo[f] = shapeInfo; diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs index a0cf03380a..72c85244ea 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs @@ -153,7 +153,6 @@ public virtual bool Check(Call call) Init(matchResult); - Console.WriteLine(call.Target.GetType().Name); var argsMarkerData = CollectInputs(call); var args = argsMarkerData.Select(pair => pair.Item1).ToArray(); var varMap = CompileSession.CompileOptions.ShapeBucketOptions.VarMap; @@ -622,7 +621,7 @@ public FusionBucketContext(Call outerCall, BucketFusion fusion, ShapeBucketOptio Arguments = OuterCall.Arguments.ToArray(); Parameters = Fusion.Parameters.ToArray(); FixedShapeCache = new(); - SliceShape = ComputeSliceShape(shapeInfos); + SliceShape = ComputeSliceShape(shapeInfos, options.RangeInfo.First().Value.Max); _index = index; } @@ -659,7 +658,7 @@ private static Dictionary MakeShapeOfFusionInput(Var[] parameters, .Zip(args) .ToDictionary(pair => pair.First, pair => { - var shape = Cast((Expr)ShapeOf(pair.Second), DataTypes.Int32); + var shape = Cast((Expr)ShapeOf(pair.Second), DataTypes.Int64); return Enumerable.Range(0, pair.Second.CheckedShape.Rank).Select(i => shape[i]).ToArray(); }); return fusionInputShapes; @@ -775,9 +774,38 @@ private static Expr SimplifyShape(Expr body) => }, new()); - private Expr ComputeSliceShape(FusionShapeData[] shapeInfos) + private Expr ComputeSliceShape(FusionShapeData[] shapeInfos, long max) { var originBody = FusionBody; + var staticShape = IsStaticShpae; + if (staticShape) + { + var len = FusionBucket.GetVarValue(this); + + // todo: reverse + if (originBody.CheckedType is TupleType tuple) + { + var outputCount = tuple.Fields.Count; + var outShapes = Enumerable.Range(0, outputCount).Select(i => + { + var arr = new[] { shapeInfos.First().Outshape.AsTensors()[i] } + .Concat(shapeInfos.Select(info => info.Outshape.AsTensors()[i]).Reverse()).ToArray(); + var outShapeList = Cast(Stack(new IR.Tuple(arr.Select(x => (Expr)x).ToArray()), 0), DataTypes.Int64); + return outShapeList[len]; + }).ToArray(); + return new IR.Tuple(outShapes); + } + else + { + // 80 79 ... 1 -> 1 1 ... 
79 80 + // len is 1 - 80 + var arr = new[] { shapeInfos.First().Outshape.AsTensor() } + .Concat(shapeInfos.Select(x => x.Outshape.AsTensor()).Reverse()).ToArray(); + var outShapeList = Cast(Stack(new IR.Tuple(arr.Select(x => (Expr)x).ToArray()), 0), DataTypes.Int64); + return outShapeList[len]; + } + } + var shapeOfFusionInput = MakeShapeOfFusionInput(Parameters, Arguments); var originShape = originBody.EvaluateShapeExpr(shapeOfFusionInput); originShape.InferenceType(); @@ -832,9 +860,7 @@ public FusionBucket(Dictionary list) FusionShapeInfo = list; } - public Dictionary FusionShapeInfo { get; set; } - - public override Pattern Pattern => IsCall( + public static Pattern BucketFusionPattern => IsCall( "outerCall", IsFusion( "fusion", @@ -843,11 +869,16 @@ public FusionBucket(Dictionary list) GenerateParameters(null)), GenerateParameters(null)); + public Dictionary FusionShapeInfo { get; set; } + + public override Pattern Pattern => FusionBucket.BucketFusionPattern; + public static Expr PreProcess(FusionBucketContext context, Var param, Dictionary inputInfo, Dictionary varValues, Dictionary fusionInputData, int segIndex, int inputIndex) { // Console.WriteLine($"seg index{segIndex}"); if (context.FixedShapeCache.TryGetValue(segIndex, out var cachedFixedShape)) { + // replace index by value var shape = cachedFixedShape[inputIndex]; if ((param.CheckedShape.IsFixed && shape.SequenceEqual(param.CheckedShape.ToValueArray())) || param.CheckedShape.IsScalar) { @@ -876,22 +907,25 @@ public static (Dictionary MinDict, Dictionary MaxDict) return (minDict, maxDict); } - public static Expr Split(FusionBucketContext context) + public static (Expr Body, List CondList) Split(FusionBucketContext context, SegmentInfo? info = null) { - var failure = MakeFailure(context.FusionBody); + var restore = MakeFailure(context); // todo: test this var value = GetVarValue(context); int i = 0; - // todo: only used for same range + var condList = new List(); + + // todo: only used for same range, should add check var body = context.DimVarValues.First().Value.OrderByDescending(x => x).Aggregate( - failure, + restore, (sum, seg) => { // 根据var,也就是target为这个fusion的call的参数来进行判断落在哪个段 var cond = value <= (long)seg; + condList.Add(cond); var sameCond = IR.F.Math.Equal(value, (long)seg); // select var value for current segment @@ -904,12 +938,35 @@ public static Expr Split(FusionBucketContext context) return result; }); - return body; + body.InferenceType(); + + if (body.CheckedType is InvalidType) + { + DumpIR(body, "InvalidBody"); + throw new InvalidOperationException(); + } + + if (body.Users.Count > 1) + { + throw new InvalidOperationException(); + } + + return (body, condList); } public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary varInfo, int segIndex, Expr sameCond, bool sameOpt = false) { var originBody = context.FusionBody; + var call = MakeNewBody(context, varInfo, segIndex); + + var slice = MakeSlice(context, call, originBody); + return sameOpt + ? 
new If(sameCond, call, slice) + : slice; + } + + public static Expr MakeNewBody(FusionBucketContext context, Dictionary varInfo, int segIndex) + { var fusionVars = context.Parameters; var fixInputs = fusionVars .Select((arg, i) => @@ -918,17 +975,14 @@ public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary fusion原始的var -> target为fusion的call的input // 本质上只是对这个body的所有输入做替换 // 避免这里的修改影响到原始的body,每个分支需要进行自己的修改,所以要clone处理 - var call = ReplaceClone(originBody, fusionVars.Zip(fixInputs).ToArray()); + var call = ReplaceClone(context.FusionBody, fusionVars.Zip(fixInputs).ToArray()); if (!call.InferenceType()) { DumpIR(call, "InvalidType"); throw new InvalidOperationException(); } - var slice = MakeSlice(context, call, originBody); - return sameOpt - ? new If(sameCond, call, slice) - : slice; + return call; } public static Expr GetVarValue(FusionBucketContext context) @@ -948,6 +1002,46 @@ public static Expr GetVarValue(FusionBucketContext context) return varList.First(); } + public static Expr MakeSliceImpl(Expr body, Expr sliceShape) + { + var rank = body.CheckedShape.Rank; + var axes = Tensor.From(Enumerable.Range(0, rank).Select(x => (long)x).ToArray()); + var strides = Tensor.FromScalar(1L, rank); + return Slice(body, Enumerable.Repeat(0L, rank).ToArray(), Cast(sliceShape, DataTypes.Int64), axes, strides); + } + + public static Expr RestoreBodyWithArgs(Expr[] args, Var[] parameters, Expr body) => + ReplaceClone(body, parameters.Zip(args).ToArray()); + + public static int[][][] UpdateShapeCache(FusionShapeData[] shapeInfos, ShapeBucketOptions options, FusionBucketContext context) + { + var allFixedShapes = shapeInfos + .Select(x => + x.InputShapes.Select(iShape => iShape.AsTensor().ToArray().ToArray()).ToArray()).ToArray(); + if (!SingleDimVar(options)) + { + for (int i = 0; i < shapeInfos.Length; i++) + { + for (int j = 0; j < allFixedShapes.Length; j++) + { + context.FixedShapeCache[j] = allFixedShapes[j]; + } + } + } + else + { + allFixedShapes = new[] { allFixedShapes[0] }.Concat(allFixedShapes.Reverse()).ToArray(); + var segments = context.DimVarValues.First().Value.Reverse().ToArray(); + + for (int i = 0; i < segments.Length; i++) + { + context.FixedShapeCache[segments.Length - 1 - i] = allFixedShapes[segments[i]]; + } + } + + return allFixedShapes; + } + public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody) { if (ShouldRestore(outerCall, fusion)) @@ -976,21 +1070,10 @@ public static Expr GetVarValue(FusionBucketContext context) // 每个段的output var context = new FusionBucketContext(outerCall, fusion, options, _cache, _counter, shapeInfos); - var allFixedShapes = shapeInfos - .Select(x => - x.InputShapes.Select(iShape => iShape.AsTensor().ToArray().ToArray()).ToArray()).ToArray(); - for (int i = 0; i < shapeInfos.Length; i++) - { - for (int j = 0; j < allFixedShapes.Length; j++) - { - context.FixedShapeCache[j] = allFixedShapes[j]; - } - } + int[][][] allFixedShapes = UpdateShapeCache(shapeInfos, options, context); - // todo: fix min max - // reverse var minFixedShapeList = allFixedShapes[^1]; - var maxFixedShapeList = allFixedShapes[0]; + var maxFixedShapeList = allFixedShapes[1]; // PrintMinMaxShape(minFixedShapeList, maxFixedShapeList, _relPath); @@ -1010,30 +1093,8 @@ public static Expr GetVarValue(FusionBucketContext context) // Console.WriteLine($"{fusion.Name} totalCount > 1"); } - // 1. 普通情况不应该rebuild - // 2. 
rebuild的正确性 - if (ShouldBeRebuild(context)) - { - _counter++; - Console.WriteLine("Rebuild"); - var rebuild = RestoreBodyWithArgs(context.Arguments, context.Parameters, context.FusionBody); - DumpIR(rebuild, "Rebuild", _relPath); - return rebuild; - } - - var body = Split(context); - body.InferenceType(); - - if (body.CheckedType is InvalidType) - { - DumpIR(body, "InvalidBody"); - throw new InvalidOperationException(); - } - - if (body.Users.Count > 1) - { - throw new InvalidOperationException(); - } + var info = ComputeSegmentInfo(counts, options); + var (body, condList) = Split(context, info); // FixInput Replace Var var newBody = ReplaceFusionVarWithCallArgs(fusion, context.Arguments, body); @@ -1041,7 +1102,8 @@ public static Expr GetVarValue(FusionBucketContext context) // let bind if (newBody is If @if) { - newBody = IR.F.Math.Require(true, @if.With(paramList: context.Arguments)); + var parameters = context.Arguments.ToArray().Concat(condList).ToArray(); + newBody = IR.F.Math.Require(true, @if.With(paramList: parameters)); } DumpIR(newBody, "BucketResult", _relPath); @@ -1094,7 +1156,6 @@ private static Expr MakeSlice(FusionBucketContext context, Expr call, Expr origi private static Expr MakeSliceForTensor(Expr sliceShape, Expr call, FusionBucketContext context) { - var rank = call.CheckedShape.Rank; var simplifyCall = CompilerServices.Rewrite( call, new IRewriteRule[] @@ -1114,9 +1175,7 @@ private static Expr MakeSliceForTensor(Expr sliceShape, Expr call, FusionBucketC }, new()); - var axes = Tensor.From(Enumerable.Range(0, rank).Select(x => (long)x).ToArray()); - var strides = Tensor.FromScalar(1L, rank); - var body = (Expr)Slice(simplifyCall, Enumerable.Repeat(0L, rank).ToArray(), Cast(sliceShape, DataTypes.Int64), axes, strides); + var body = MakeSliceImpl(simplifyCall, sliceShape); return body; } @@ -1148,9 +1207,6 @@ private static bool ShouldRestore(Call outerCall, BucketFusion fusion) return false; } - private static Expr RestoreBodyWithArgs(Expr[] args, Var[] parameters, Expr body) => - ReplaceClone(body, parameters.Zip(args).ToArray()); - private static void PrintMinMaxShape(int[][] minFixedShapeList, int[][] maxFixedShapeList, string relPath) { string str = string.Empty; @@ -1216,9 +1272,10 @@ private static Expr ReplaceFusionVarWithCallArgs(BucketFusion fusion, Expr[] arg return result; }); - private static Expr MakeFailure(Expr fusionBody) + private static Expr MakeFailure(FusionBucketContext context) { - var failure = fusionBody.CheckedType switch + // return RestoreBodyWithArgs(context.Arguments, context.Parameters, context.FusionBody); + var failure = context.FusionBody.CheckedType switch { TupleType tuple => new IR.Tuple(tuple.Fields.ToArray() .Select(x => @@ -1226,45 +1283,102 @@ private static Expr MakeFailure(Expr fusionBody) return ConstantOfShape(new[] { 1 }, Cast(0, ((TensorType)x).DType)); }).ToArray()), TensorType tensorType => (Expr)ConstantOfShape(new[] { 1 }, Cast(0, tensorType.DType)), - _ => throw new ArgumentOutOfRangeException("fusionBody"), + _ => throw new ArgumentOutOfRangeException("context"), }; return IR.F.Math.Require(false, failure, "input dim large than limit"); } +} + +[RuleGenerator] +public partial class RebuildBucket : RewriteRule +{ + private static int _counter; - private static bool ShouldBeRebuild(FusionBucketContext context) + private readonly Dictionary _shapeInfo; + + private string _name = string.Empty; + + public RebuildBucket(Dictionary shapeInfo) { - var varInfo = context.DimVarValue(0); - var entry = MakeSplitEntry(context, 
varInfo, 0, false, false); - return entry switch - { - IR.Tuple tuple => tuple.Fields.ToArray().Any(ShouldBeRebuild), - Call => ShouldBeRebuild(entry), - _ => DumpError(entry), - }; + _shapeInfo = shapeInfo; } - private static bool DumpError(Expr entry) + public override Pattern Pattern => FusionBucket.BucketFusionPattern; + + public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody) { - DumpIR(entry, "FailedEntry"); - throw new InvalidOperationException(); + // only once RecordShape + var options = CompileSession.CompileOptions.ShapeBucketOptions; + + var shapeInfos = Array.Empty(); + if (!_shapeInfo.TryGetValue(fusion, out shapeInfos)) + { + // todo: 不知道为什么有的时候无法从key中获取 + var list = _shapeInfo.Where(x => x.Key == fusion).ToArray(); + if (list.Length != 1) + { + throw new InvalidOperationException($"NoKey{fusion.Name}"); + } + + shapeInfos = list[0].Value; + } + + // 1. 普通情况不应该rebuild + // 2. rebuild的正确性 + var context = new FusionBucketContext(outerCall, fusion, options, ShapeExprCache.Default, _counter, shapeInfos); + + var allFixedShapes = shapeInfos + .Select(x => + x.InputShapes.Select(iShape => iShape.AsTensor().ToArray().ToArray()).ToArray()).ToArray(); + for (int i = 0; i < shapeInfos.Length; i++) + { + for (int j = 0; j < allFixedShapes.Length; j++) + { + context.FixedShapeCache[j] = allFixedShapes[j]; + } + } + + _name = fusion.Name; + if (ShouldBeRebuild(context)) + { + var rebuild = FusionBucket.RestoreBodyWithArgs(context.Arguments, context.Parameters, context.FusionBody); + DumpIR(rebuild, $"{_counter++}_{_name}"); + return rebuild; + } + + return null; } private static bool ShouldBeRebuild(Expr entry) { - if (entry is Call { Target: IR.Tensors.Slice } c) + if (entry.CheckedShape.IsFixed) { - var body = c.Arguments[IR.Tensors.Slice.Input.Index]; - if (body.CheckedShape.IsFixed) - { - var visitor = new DynamicCheckVisitor(); - visitor.Visit(body); - return visitor.HasDynamic; - } + var visitor = new DynamicCheckVisitor(); + visitor.Visit(entry); + return visitor.HasDynamic; } return true; } + private static bool DumpError(Expr entry) + { + DumpIR(entry, "FailedEntry"); + throw new InvalidOperationException(); + } + + private bool ShouldBeRebuild(FusionBucketContext context) + { + var varInfo = context.DimVarValue(0); + var entry = FusionBucket.MakeNewBody(context, varInfo, 0); + DumpIR(entry, $"{_counter}_{_name}"); + return entry switch + { + IR.Tuple tuple => tuple.Fields.ToArray().Any(ShouldBeRebuild), + _ => ShouldBeRebuild(entry), + }; + } + public class DynamicCheckVisitor : ExprVisitor { private bool _hasDynamic; @@ -1288,7 +1402,7 @@ protected override Expr VisitLeafCall(Call expr) } } -internal record SegmentInfo(int InputIndex, int DimIndex, int[] Segments); +public record SegmentInfo(int InputIndex, int DimIndex, int[] Segments); public class FullBucket : FunctionPass { @@ -1305,24 +1419,20 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo var options = CompileSession.CompileOptions.ShapeBucketOptions; var tmpFusion = new BucketFusion("stackvm", cloneMain.Body, cloneMain.Parameters, Array.Empty()); var call = new Call(tmpFusion, main.Parameters.ToArray()); + DumpIR(tmpFusion, "FullBucketResult"); + + var shapeList = new Dictionary(); + var tmpF = new Function(call, main.Parameters.ToArray()); + new RecordFusionShape(shapeList).RunAsync(tmpF, ctx).Wait(); + var dimVarValues = MakeVarValuesForAllSegment(options); - var list = InputConfList(dimVarValues); - var shapeData = MakeShapeData(list, options); + var shapeData = 
shapeList.First().Value; var context = new FusionBucketContext(call, tmpFusion, options, new ShapeExprCache(options.VarMap), 0, shapeData); - var allFixedShapes = shapeData - .Select(x => - x.InputShapes.Select(iShape => iShape.AsTensor().ToArray().ToArray()).ToArray()).ToArray(); - for (int i = 0; i < shapeData.Length; i++) - { - for (int j = 0; j < allFixedShapes.Length; j++) - { - context.FixedShapeCache[j] = allFixedShapes[j]; - } - } + FusionBucket.UpdateShapeCache(shapeData, options, context); - var newBody = FusionBucket.Split(context); + var (newBody, condList) = FusionBucket.Split(context); foreach (var (oldVar, tmpVar) in replaceItem) { ReplaceExpr(newBody, tmpVar, oldVar); diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs index 0275771c3c..3f7abdd479 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs @@ -30,7 +30,9 @@ public static class CallValidator typeof(MatMul).TypeHandle, typeof(Transpose).TypeHandle, typeof(Pad).TypeHandle, - typeof(Tile).TypeHandle, + typeof(Unsqueeze).TypeHandle, + typeof(Squeeze).TypeHandle, + typeof(Unary).TypeHandle, }; private static readonly HashSet MaybeDynamic = new() @@ -42,21 +44,18 @@ public static class CallValidator typeof(Gather).TypeHandle, typeof(ShapeOf).TypeHandle, - typeof(Unsqueeze).TypeHandle, - typeof(Squeeze).TypeHandle, typeof(Cast).TypeHandle, - typeof(Unary).TypeHandle, typeof(Reshape).TypeHandle, typeof(Expand).TypeHandle, typeof(ConstantOfShape).TypeHandle, - typeof(Where).TypeHandle, + + // typeof(Where).TypeHandle, typeof(Compare).TypeHandle, typeof(Reduce).TypeHandle, typeof(Clamp).TypeHandle, typeof(Tile).TypeHandle, typeof(CumSum).TypeHandle, - typeof(IR.Tensors.Range).TypeHandle, }; public static bool IsMaybeDynamic(Expr target) => MaybeDynamic.Contains(target.GetType().TypeHandle); @@ -71,7 +70,7 @@ public static bool ValidTarget(Call call, bool greedy) ShapeBucketHelper.SingleDimVar( CompileSessionScope.GetCurrentThrowIfNull().CompileOptions.ShapeBucketOptions); - if (target is Binary && call.Arguments.ToArray().OfType().Any()) + if (target is Binary && call.Arguments.ToArray().Any(arg => arg is TensorConst || arg is Marker { Target: TensorConst })) { return true; } @@ -179,7 +178,13 @@ public static void Bucket(IPassManager p) public static void Rebuild(IPassManager p, bool singleVar) { - // rebuild + var shapeList = new Dictionary(); + p.Add(shapeList, true); + p.AddWithName("RestoreDynamic").Configure(p => + { + p.Add(shapeList); + }); + ToFusion(p, true); MergeOp(p, false); @@ -191,7 +196,6 @@ public static void Rebuild(IPassManager p, bool singleVar) }); MergeFusion(p, singleVar, false); - Bucket(p); } public static void MergeFusion(IPassManager p, bool singleVar, bool greedy) @@ -226,9 +230,9 @@ public static void ClearMarker(IPassManager p) => public static void Simplify(IPassManager p) => p.AddWithName("Simplify").Configure(c => { + c.Add(); c.Add(); c.Add(); - c.Add(); c.Add(); c.Add(); c.Add(); @@ -246,6 +250,15 @@ public static void Simplify(IPassManager p) => public static class ShapeBucketHelper { + public static bool IsStaticShpae + { + get + { + var options = CompileSessionScope.GetCurrentThrowIfNull().CompileOptions.ShapeBucketOptions; + return SingleDimVar(options); + } + } + public static Dictionary ConcatDictionary(Dictionary memo, Dictionary exprValues) where T : Expr { @@ -257,18 +270,21 @@ public static Dictionary 
ConcatDictionary(Dictionary me return memo; } - public static Dictionary MakeVarValuesForAllSegment(ShapeBucketOptions options) + public static Dictionary MakeVarValuesForAllSegment(ShapeBucketOptions options, bool staticShape = false) + { + return MakeVarValuesForAllSegment(options, options.SegmentsCount, staticShape); + } + + public static Dictionary MakeVarValuesForAllSegment(ShapeBucketOptions options, int segmentCount, bool staticShape) { - int segmentCount = options.SegmentsCount; var varRange = options.RangeInfo; var varMap = options.VarMap; - var staticShape = false; var varAndInputAllSegment = varRange.ToDictionary(pair => pair.Key, pair => { var (min, max) = pair.Value; if (staticShape) { - return Enumerable.Range(min, max - min).ToArray(); + return Enumerable.Range(min, max - min + 1).ToArray(); } var segments = ComputeSegmentList(segmentCount, min, max); @@ -280,7 +296,10 @@ public static Dictionary MakeVarValuesForAllSegment(ShapeBucketOptio // DimVarName -> Dict.key -> Dict.Value var varValues = varAndInputAllSegment.ToDictionary( pair => vars.FindFirst(v => v.Name == pair.Key), - pair => { return pair.Value.OrderByDescending(x => x).ToArray(); }); + pair => + { + return pair.Value.OrderByDescending(x => x).ToArray(); + }); return varValues; } @@ -302,11 +321,6 @@ public static void ArgsChecker(Expr[] newArgs) throw new InvalidOperationException("Args has Var in fusion"); } - if (newArgs.Any(arg => arg is Marker m && m.Target is Const)) - { - throw new InvalidOperationException("Args has tuple"); - } - if (newArgs.Any(arg => arg is IR.Tuple)) { throw new InvalidOperationException("Args has tuple"); diff --git a/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs b/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs index 737e371356..dd530a7c54 100644 --- a/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs +++ b/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs @@ -25,16 +25,25 @@ public partial class FoldSplitShapeOf : RewriteRule new VArgsPattern( list => Enumerable.Range(0, list.Length) - .Select(_ => IsGetItem(InputPattern, IsTensorConst())) + .Select(_ => IsAlt(IsCast(c => c.NewType == DataTypes.Int64, InputPattern), InputPattern)) .ToArray(), "args")), IsTensorConst(tensor => tensor.Value.ToScalar() == 0)); - public Pattern InputPattern => IsShapeOf(IsWildcard()); + public Pattern InputPattern => IsGetItem(IsShapeOf(IsWildcard()), IsTensorConst()); private Expr? GetReplace(IR.Tuple tuple) { - var getItemList = tuple.Fields.ToArray().OfType().ToArray(); + var getItemList = tuple.Fields.ToArray().OfType().Select(x => + { + if (x.Target is Cast) + { + return x.Arguments[Cast.Input.Index]; + } + + return x; + }).OfType().ToArray(); + var getItemIndices = getItemList.Select(x => x.Arguments[GetItem.Index.Index]).OfType().Select(x => x.Value.ToScalar()).ToArray(); if (getItemIndices.Length == 0) { diff --git a/src/Nncase.Passes/Rules/WithMarker/CombineTranspose.cs b/src/Nncase.Passes/Rules/WithMarker/CombineTranspose.cs new file mode 100644 index 0000000000..eb50c92a69 --- /dev/null +++ b/src/Nncase.Passes/Rules/WithMarker/CombineTranspose.cs @@ -0,0 +1,259 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
+ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.IR.NN; +using Nncase.PatternMatch; +using Nncase.Utilities; +using static Nncase.IR.F.Math; +using static Nncase.IR.F.NN; +using static Nncase.IR.F.Tensors; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.NN; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; +using static Nncase.Utilities.MetadataUtility; +using Tuple = System.Tuple; + +namespace Nncase.Passes.Rules.WithMarker; + +/// +/// transpose(activation(x),perm) => activation(transpose(x,perm)). +/// +[RuleGenerator] +public sealed partial class CombineTransposeActivations : IRewriteRule +{ + /// + public IPattern Pattern { get; } = + HasMarker( + IsTranspose( + HasMarker( + IsCall("actCall", IsOp("activation", op => true), IsVArgsRepeat("arguments", () => IsWildcard() with { TypePattern = HasFixedShape() })), + "outputMarker"), + IsTensorConst("perm")), + "transposeMarker"); + + private Expr? GetReplace(Call actCall, ActivationOp activation, IReadOnlyList arguments, int[] perm, Marker transposeMarker, Marker outputMarker) + { + // todo: argument 1 is marker + var newArgs = new List(); + foreach (var arg in arguments) + { + if (arg.CheckedShape.IsScalar) + { + newArgs.Add(arg); + continue; + } + else if (arg.CheckedShape.Rank <= perm.Length) + { + newArgs.Add(transposeMarker.With(target: Transpose(arg, perm.Select(p => p - (perm.Length - arg.CheckedShape.Rank)).Where(p => p >= 0).ToArray()))); + continue; + } + else + { + return null; + } + } + + var newcall = new Call(activation, newArgs.ToArray()); + newcall.InheritMetaData(actCall); + return outputMarker.With(target: newcall); + } +} + +/// +/// activations(transpose(input,p),args...) => transpose(activations(input,args...),p). +/// +[RuleGenerator] +public sealed partial class CombineActivationsTranspose : IRewriteRule +{ + /// + public IPattern Pattern { get; } = + HasMarker( + IsCall("actCall", IsOp("activation", op => true), IsVArgsRepeat("parameters", (inputs) => + { + var patterns = new Pattern[inputs.Length]; + patterns[0] = HasMarker( + IsTranspose(IsWildcard("input"), IsWildcard("perm")), + "inputMarker"); + for (int i = 1; i < inputs.Length; i++) + { + patterns[i] = IsWildcard(); + } + + return patterns; + })), + "outputMarker"); + + private Expr? GetReplace(Call actCall, ActivationOp activation, Expr input, IReadOnlyList parameters, Expr perm, Marker inputMarker, Marker outputMarker) + { + // note the prelu scope can be broadcast with inputs. + if (activation is PRelu && parameters[1].CheckedShape.Rank > 1) + { + if (perm is not TensorConst const_perm || parameters[1] is not TensorConst slope) + { + return null; + } + + // eg. 
+            Expr new_slope;
+            var perms = const_perm.Value.ToArray<int>();
+            if (slope.Value.Shape.Rank == input.CheckedShape.Rank - 1)
+            {
+                if (perms[0] != 0)
+                {
+                    return null;
+                }
+
+                var inv_perm = perms.Skip(1).Select((p, i) => (p - 1, i)).OrderBy(tp => tp.Item1).Select(tp => tp.i).ToArray();
+                new_slope = Const.FromValue(Transpose(slope, inv_perm).Evaluate());
+                return outputMarker.With(target: Transpose(outputMarker.With(target: new Call(activation, inputMarker.With(target: input), new_slope)), perm));
+            }
+            else if (slope.Value.Shape.Rank == input.CheckedShape.Rank)
+            {
+                var inv_perm = perms.Select((p, i) => (p, i)).OrderBy(tp => tp.p).Select(tp => tp.i).ToArray();
+                new_slope = Const.FromValue(Transpose(slope, inv_perm).Evaluate());
+            }
+            else
+            {
+                return null;
+            }
+
+            return outputMarker.With(target: Transpose(outputMarker.With(target: new Call(activation, inputMarker.With(target: input), new_slope)), perm));
+        }
+
+        var newCall = new Call(activation, new Expr[] { input }.Concat(parameters.Skip(1)).ToArray());
+        newCall.InheritMetaData(actCall);
+        return outputMarker.With(target: Transpose(
+            outputMarker.With(target: newCall),
+            perm));
+    }
+}
+
+/// <summary>
+/// activations(reshape(input, shape), args...) => reshape(activations(input, args...), shape).
+/// </summary>
+[RuleGenerator]
+public sealed partial class CombineActivationsReshape : IRewriteRule
+{
+    /// <inheritdoc/>
+    public IPattern Pattern { get; } =
+        HasMarker(
+            IsCall("call", IsOp<ActivationOp>("activation", op => true), IsVArgsRepeat("parameters", (inputs) =>
+            {
+                var patterns = new Pattern[inputs.Length];
+                patterns[0] = HasMarker(IsReshape(IsWildcard("input"), IsWildcard("shape")), "inputMarker");
+                for (int i = 1; i < inputs.Length; i++)
+                {
+                    patterns[i] = IsWildcard();
+                }
+
+                return patterns;
+            })),
+            "outMarker");
+
+    private Expr? GetReplace(ActivationOp activation, Call call, Expr input, IReadOnlyList<Expr> parameters, Expr shape, Marker inputMarker, Marker outMarker)
+    {
+        // TODO: Not support PRelu for now.
+        if (activation is PRelu)
+        {
+            return null;
+        }
+
+        return outMarker.With(target: Reshape(
+            new Call(activation, new Expr[] { inputMarker.With(target: input) }.Concat(parameters.Skip(1)).ToArray()).InheritMetaData(call),
+            shape));
+    }
+}
+
+[RuleGenerator]
+public partial class FoldTransposeActTranspose : RewriteRule<Pattern>
+{
+    public override Pattern Pattern => IsTranspose(
+        "outTr",
+        "outTrCall",
+        LeakyReluPattern,
+        IsWildcard("perm2"));
+
+    public Pattern LeakyReluPattern => HasMarker(
+        IsLeakyRelu(
+            "target",
+            "call",
+            HasMarker(
+                IsTranspose(IsWildcard("input") with { TypePattern = HasFixedShape() }, IsWildcard("perm1")),
+                "inMarker"),
+            IsWildcard("alpha")),
+        "outMarker");
+
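+    // If the outer transpose's result has the same shape as the original (pre-transpose)
+    // input, both transposes are dropped and LeakyRelu is rebuilt directly on that input;
+    // otherwise LeakyRelu is moved innermost so the two transposes become adjacent.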
+    private Expr? GetReplace(Call call, Expr input, Marker inMarker, Marker outMarker, int[] perm1, int[] perm2, Call outTrCall, Expr alpha, Expr target)
+    {
+        if (perm1.Length != perm2.Length)
+        {
+            return null;
+        }
+
+        if (outTrCall.CheckedShape.SequenceEqual(input.CheckedShape))
+        {
+            return outMarker.With(target: ReplaceUtility.ReplaceCallFirstParam(target, call.Arguments.ToArray(), inMarker.With(target: input)));
+        }
+
+        // transpose(leakyrelu(transpose(input, perm1)), perm2) => transpose(transpose(leakyrelu(input), perm1), perm2)
+        else
+        {
+            return outMarker.With(target: Transpose(outMarker.With(target: Transpose(outMarker.With(target: LeakyRelu(inMarker.With(target: input), alpha)), perm1)), perm2));
+        }
+    }
+}
+
+[RuleGenerator]
+public partial class FoldTransposeBinaryActTranspose : RewriteRule<Pattern>
+{
+    public override Pattern Pattern => IsTranspose(
+        HasMarker(
+            IsReshape(
+                HasMarker(
+                    IsLeakyRelu(
+                        "op",
+                        "call",
+                        HasMarker(
+                            IsBinary(
+                                "bn",
+                                "bnCall",
+                                BinaryOp.Add,
+                                HasMarker(
+                                    IsReshape(
+                                        HasMarker(
+                                            IsTranspose(HasMarker(IsWildcard(), "input"), IsWildcard("perm1"))),
+                                        IsWildcard())),
+                                IsWildcard("rhs")),
+                            "bnMarker"),
+                        IsWildcard("alpha"))),
+                IsWildcard()),
+            "outMarker"),
+        IsWildcard("perm2"));
+
+    private Expr? GetReplace(int[] perm1, int[] perm2, Expr input, Expr rhs, Marker bnMarker, Marker outMarker, Expr alpha)
+    {
+        if (perm1.SequenceEqual(new[] { 0, 2, 3, 1 }) && perm2.SequenceEqual(new[] { 0, 3, 1, 2 }))
+        {
+            // transpose shape check
+            // input no marker
+            if (rhs is Marker m)
+            {
+                var constRhs = m.With(target: Reshape(m.Target, new[] { rhs.CheckedShape.Size, 1, 1 }).Evaluate().AsTensor());
+                return outMarker.With(target: LeakyRelu(bnMarker.With(target: Add(input, constRhs)), alpha));
+            }
+
+            return outMarker.With(target: LeakyRelu(bnMarker.With(target: Add(input, Reshape(rhs, new[] { rhs.CheckedShape.Size, 1, 1 }))), alpha));
+        }
+
+        return null;
+    }
+}
diff --git a/src/Nncase.Simulator/Runtime/Interop/RTHostMemoryManager.cs b/src/Nncase.Simulator/Runtime/Interop/RTHostMemoryManager.cs
index 049cb5cd5a..81ceba06f2 100644
--- a/src/Nncase.Simulator/Runtime/Interop/RTHostMemoryManager.cs
+++ b/src/Nncase.Simulator/Runtime/Interop/RTHostMemoryManager.cs
@@ -59,7 +59,7 @@ public override void Unpin()
     protected override void Dispose(bool disposing)
     {
         var pointer = Interlocked.Exchange(ref _pointer, IntPtr.Zero);
-        if (pointer != IntPtr.Zero && _buffer != null)
+        if (pointer != IntPtr.Zero && _buffer != null && _length != 0)
         {
             Native.HostBufferUnmap(_buffer.DangerousGetHandle());
             GC.RemoveMemoryPressure(_length);
diff --git a/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs b/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs
index e8f7213cb6..feed797ce6 100644
--- a/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs
+++ b/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs
@@ -94,6 +94,11 @@ public static float CosSimilarity(Tensor a, Tensor b)
             return 1f;
         }
 
+        if (!a.Shape.ToValueArray().SequenceEqual(b.Shape.ToValueArray()))
+        {
+            throw new InvalidOperationException();
+        }
+
         var va = a.ToArray<float>();
         var vb = b.ToArray<float>();
         var v1 = Math.Sqrt(Prod(va, va));
diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
index 74296b8cfb..802fc6bfed 100755
--- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
+++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
@@ -660,10 +660,10 @@ public void TestSpaceToBatch()
         var output = new float[] { 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16 };
         var expect = Tensor.From<float>(output, new[] { 4, 2, 2, 1 });
         var crops = new long[] { 0, 0, 0, 0 };
-        var expr = IR.F.NN.SpaceToBatch(
-            input,
+        var expr = NCHWToNHWC(IR.F.NN.SpaceToBatch(
+            NHWCToNCHW(input).Evaluate().AsTensor(),
             Tensor.From<long>(shape, new[] { 2 }),
-            Tensor.From<long>(crops, new[] { 2, 2 }));
+            Tensor.From<long>(crops, new[] { 2, 2 })));
         CompilerServices.InferenceType(expr);
         Assert.Equal(expect, expr.Evaluate().AsTensor());
     }
diff --git a/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs
index 97c9a8a5c0..9bbbf554f5 100644
--- a/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs
+++ b/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs
@@ -262,7 +262,7 @@ public void TestSpaceTobatch()
         var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar));
         var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 192 }));
         var paddings = Tensor.From(new[] { 0, 1 }, new[] { 1, 2 });
-        var expr = SpaceToBatch(input, new[] { 3 }, paddings);
+        var expr = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), new[] { 3 }, paddings));
         var dict = new Dictionary<Var, Expr[]> { { input, new Expr[] { 1, dimVar, 192 } } };
         var shape = expr.EvaluateShapeExpr(dict);
         var varValues = new Dictionary<Var, IValue> { { dimVar, Value.FromTensor(8) } };
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
index 4fa0dd0c0a..563693c66f 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
@@ -12,6 +12,7 @@
 using Nncase.Passes.Rules.Neutral;
 using Nncase.Tests.TestFixture;
 using Xunit;
+using static Nncase.IR.F.Tensors;
 using Math = Nncase.IR.F.Math;
 using NN = Nncase.IR.F.NN;
 using Random = Nncase.IR.F.Random;
@@ -41,7 +42,7 @@ public class UnitTestSpaceToBatchToPad : TransformTestBase
     public void TestSpaceToBatchToPadPositive(int[] shape, int[] blockShape, int[,] paddings)
     {
         var a = Random.Normal(DataTypes.Float32, 0, 1, 0, shape);
-        var rootPre = NN.SpaceToBatch(a, blockShape, paddings);
+        var rootPre = NCHWToNHWC(NN.SpaceToBatch(NHWCToNCHW(a), blockShape, paddings));
         TestMatched(rootPre);
     }
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs
index 4ef0186aad..210d63fdfe 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs
@@ -8,6 +8,7 @@
 using Nncase.Tests.TestFixture;
 using Xunit;
 using static Nncase.IR.F.NN;
+using static Nncase.IR.F.Tensors;
 
 namespace Nncase.Tests.Rules.NeutralTest;
 
@@ -17,7 +18,7 @@ public class UnitTestSpaceToBatch : TransformTestBase
     [Fact]
     public void TestSplitSpaceToBatch()
     {
-        var i = SpaceToBatch(Testing.Rand<float>(1, 206, 192), new[] { 3 }, new[,] { { 0, 1 } });
+        var i = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(Testing.Rand<float>(1, 206, 192)), new[] { 3 }, new[,] { { 0, 1 } }));
         var originEvaluateResult = i.Evaluate();
         var newBody = TestMatched(i);
         var ev = newBody.Evaluate();
diff --git a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs
index 1e0ae120fd..f6964fb288 100644
--- a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs
+++ b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs
@@ -106,7 +106,7 @@ public async Task TestRebuild()
         var newBody = TestMatchedCore(
             main.Body!,
             new Dictionary<Var, IValue> { { mainVar, Value.FromTensor(input) } },
-            new FusionBucket(shape));
+            new RebuildBucket(shape));
         Assert.True(newBody is Call { Target: IR.Math.MatMul });
     }
diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs
index 845755e2f8..9f947992ce 100644
--- a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs
+++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs
@@ -25,4 +25,13 @@ public void TestFoldSplitShapeOf()
         var newShape = Stack(new IR.Tuple(shape[0], shape[1], shape[2], shape[3]), 0);
         TestMatched<FoldSplitShapeOf>(newShape);
     }
+
+    [Fact]
+    public void TestFoldSplitCastShapeOf()
+    {
+        var input = Testing.Rand<float>(1, 3, 24, 24);
+        var shape = ShapeOf(input);
+        var newShape = Stack(new IR.Tuple(Cast(shape[0], DataTypes.Int64), Cast(shape[1], DataTypes.Int64), Cast(shape[2], DataTypes.Int64), Cast(shape[3], DataTypes.Int64)), 0);
+        TestMatched<FoldSplitShapeOf>(newShape);
+    }
 }
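
Note on the slope re-layout in CombineActivationsTranspose: the PRelu branch transposes the constant slope with the inverse of the batch-stripped permutation before moving the activation below the transpose. The standalone C# sketch below reproduces that index arithmetic with plain arrays so it can be checked in isolation; the class and method names are illustrative and not part of the patch.

using System;
using System.Linq;

// Standalone sketch (not from the patch): the inverse-permutation arithmetic used
// when the PRelu slope has rank == input rank - 1, i.e. the leading batch axis is
// not permuted and is absent from the slope.
internal static class SlopePermDemo
{
    private static void Main()
    {
        // An illustrative 4-D permutation that keeps axis 0 in place.
        int[] perms = { 0, 3, 1, 2 };

        // Same expression as in CombineActivationsTranspose: drop axis 0, shift the
        // remaining axes down by one, then argsort to obtain the inverse permutation
        // that is applied to the constant slope.
        var inv_perm = perms.Skip(1)
            .Select((p, i) => (p - 1, i))
            .OrderBy(tp => tp.Item1)
            .Select(tp => tp.i)
            .ToArray();

        // The reduced permutation is [2,0,1]; its inverse is [1,2,0].
        Console.WriteLine(string.Join(",", inv_perm)); // prints 1,2,0
    }
}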