From e39f97198a66b79e620320be87ac88823b888be3 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Wed, 24 May 2023 17:49:09 +0800 Subject: [PATCH 001/308] Add cpu module --- .../Nncase.Modules.CPU/CPUApplicationPart.cs | 30 +++ modules/Nncase.Modules.CPU/CPUModule.cs | 19 ++ .../Evaluator/CPU/CPUModule.cs | 18 ++ .../Evaluator/CPU/CPUUnary.cs | 141 ++++++++++ modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs | 23 ++ .../Nncase.Modules.CPU/IR/CPU/Functional.cs | 26 ++ .../Nncase.Modules.CPU.csproj | 21 ++ .../Passes/Rules/LowerUnary.cs | 33 +++ .../Targets/CPUTarget.cs | 5 + modules/Nncase.Modules.CPU/packages.lock.json | 243 ++++++++++++++++++ .../Nncase.Modules.StackVM/StackVMModule.cs | 1 - nncase.sln | 7 + python/nncase/__init__.py | 2 +- src/Nncase.Cli/Nncase.Cli.csproj | 4 + src/Nncase.Cli/packages.lock.json | 9 + .../Hosting/CompilerHostBuilderExtensions.cs | 3 +- src/Nncase.Compiler/Hosting/PluginLoader.cs | 1 + src/Nncase.Compiler/Nncase.Compiler.csproj | 1 + src/Nncase.Compiler/packages.lock.json | 8 + .../Nncase.Tests.TestFixture.csproj | 1 + .../packages.lock.json | 8 + src/Nncase.Tests/Targets/UnitTestCPUTarget.cs | 10 + src/Nncase.Tests/packages.lock.json | 10 + 23 files changed, 621 insertions(+), 3 deletions(-) create mode 100644 modules/Nncase.Modules.CPU/CPUApplicationPart.cs create mode 100644 modules/Nncase.Modules.CPU/CPUModule.cs create mode 100644 modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs create mode 100644 modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs create mode 100644 modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs create mode 100644 modules/Nncase.Modules.CPU/IR/CPU/Functional.cs create mode 100644 modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj create mode 100644 modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs rename modules/{Nncase.Modules.StackVM => Nncase.Modules.CPU}/Targets/CPUTarget.cs (94%) create mode 100644 modules/Nncase.Modules.CPU/packages.lock.json diff --git a/modules/Nncase.Modules.CPU/CPUApplicationPart.cs b/modules/Nncase.Modules.CPU/CPUApplicationPart.cs new file mode 100644 index 0000000000..655138ddd4 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CPUApplicationPart.cs @@ -0,0 +1,30 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using DryIoc; +using Nncase.Hosting; + +namespace Nncase; + +/// +/// CPU application part extensions. +/// +public static class CPUApplicationPart +{ + /// + /// Add CPU assembly. + /// + /// Service registrator. + /// Configured service registrator. + public static IRegistrator AddCPU(this IRegistrator registrator) + { + return registrator.RegisterModule() + .RegisterModule(); + } +} diff --git a/modules/Nncase.Modules.CPU/CPUModule.cs b/modules/Nncase.Modules.CPU/CPUModule.cs new file mode 100644 index 0000000000..5e91015cef --- /dev/null +++ b/modules/Nncase.Modules.CPU/CPUModule.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using DryIoc; +using Nncase.Hosting; +using Nncase.Targets; + +namespace Nncase; + +/// +/// CPU module. +/// +internal class CPUModule : IApplicationPart +{ + public void ConfigureServices(IRegistrator registrator) + { + registrator.Register(reuse: Reuse.Singleton); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs new file mode 100644 index 0000000000..025c6e3b1b --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs @@ -0,0 +1,18 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using DryIoc; +using Nncase.Hosting; + +namespace Nncase.Evaluator.CPU; + +/// +/// CPU module. +/// +internal class CPUModule : IApplicationPart +{ + public void ConfigureServices(IRegistrator registrator) + { + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs new file mode 100644 index 0000000000..d48a45b8e0 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs @@ -0,0 +1,141 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; +using OrtKISharp; + +namespace Nncase.Evaluator.CPU; + +/// +/// Evaluator for . +/// +public class CPUUnaryEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IOpPrinter +{ + /// + public IValue Visit(IEvaluateContext context, CPUUnary unary) + { + var input_tensor = context.GetArgumentValueAsTensor(unary, CPUUnary.Input); + if (input_tensor.Shape.IsScalar) + { + if (input_tensor.ElementType == DataTypes.Int32) + { + return Value.FromTensor(Tensor.FromScalar(Compute_int(input_tensor.ToScalar(), unary.UnaryOp))); + } + else if (input_tensor.ElementType == DataTypes.Float32) + { + return Value.FromTensor(Tensor.FromScalar(Compute_float(input_tensor.ToScalar(), unary.UnaryOp))); + } + } + + var input = context.GetOrtArgumentValue(unary, CPUUnary.Input); + var result = unary.UnaryOp switch + { + UnaryOp.Abs => OrtKI.Abs(input), + UnaryOp.Acos => OrtKI.Acos(input), + UnaryOp.Acosh => OrtKI.Acosh(input), + UnaryOp.Asin => OrtKI.Asin(input), + UnaryOp.Asinh => OrtKI.Asinh(input), + UnaryOp.Ceil => OrtKI.Ceil(input), + UnaryOp.Cos => OrtKI.Cos(input), + UnaryOp.Cosh => OrtKI.Cosh(input), + UnaryOp.Exp => OrtKI.Exp(input), + UnaryOp.Floor => OrtKI.Floor(input), + UnaryOp.Log => OrtKI.Log(input), + UnaryOp.Neg => OrtKI.Neg(input), + UnaryOp.Round => OrtKI.Round(input), + UnaryOp.Rsqrt => OrtKI.Rsqrt(input), + UnaryOp.Sin => OrtKI.Sin(input), + UnaryOp.Sinh => OrtKI.Sinh(input), + UnaryOp.Sign => OrtKI.Sign(input), + UnaryOp.Sqrt => OrtKI.Sqrt(input), + UnaryOp.Square => OrtKI.Square(input), + UnaryOp.Tanh => OrtKI.Tanh(input), + UnaryOp.BitwiseNot => throw new NotSupportedException("NotSupported UnaryOp BitwiseNot"), + UnaryOp.LogicalNot => OrtKI.Not(input), + _ => throw new ArgumentOutOfRangeException(nameof(unary)), + }; + return result.ToValue(); + } + + /// + public IRType Visit(ITypeInferenceContext context, CPUUnary target) + { + var input = context.CheckArgumentType(target, CPUUnary.Input); + return Visit(input); + } + + /// + public Cost Visit(ICostEvaluateContext context, CPUUnary target) + { + var inputType = context.GetArgumentType(target, CPUUnary.Input); + var outputType = context.GetReturnType(); + + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, CostUtility.GetCPUCyclesOfUnary(target.UnaryOp)), + }; + } + + /// + public string Visit(IIRPrinterContext context, CPUUnary target, bool iLmode) + { + var op_str = target.UnaryOp switch + { + UnaryOp.BitwiseNot => "!", + UnaryOp.LogicalNot => "!", + var op => op.ToString(), + }; + if (!iLmode) + { + return $"{op_str}({string.Join(", ", target.Parameters.Select(p => p.Name + ": " + context.GetArgument(target, p).Serialize()))})"; + } + + throw new NotSupportedException("ILmode = true"); + } + + private int Compute_int(int input, UnaryOp op) => op switch + { + UnaryOp.Ceil => input, + UnaryOp.Floor => input, + UnaryOp.Neg => -input, + UnaryOp.Abs => System.Math.Abs(input), + UnaryOp.Square => input * input, + _ => throw new ArgumentOutOfRangeException(nameof(op), $"NotSupported {nameof(op)} For Int"), + }; + + private float Compute_float(float input, UnaryOp op) => op switch + { + UnaryOp.Abs => System.MathF.Abs(input), + UnaryOp.Acos => System.MathF.Acos(input), + UnaryOp.Acosh => System.MathF.Acosh(input), + UnaryOp.Asin => System.MathF.Asin(input), + UnaryOp.Asinh => System.MathF.Asinh(input), + UnaryOp.Ceil => System.MathF.Ceiling(input), + UnaryOp.Cos => System.MathF.Cos(input), + UnaryOp.Cosh => System.MathF.Cosh(input), + UnaryOp.Exp => System.MathF.Exp(input), + UnaryOp.Floor => System.MathF.Floor(input), + UnaryOp.Log => System.MathF.Log(input), + UnaryOp.Neg => -input, + UnaryOp.Round => System.MathF.Round(input), + UnaryOp.Rsqrt => 1.0f / System.MathF.Sqrt(input), + UnaryOp.Sin => System.MathF.Sin(input), + UnaryOp.Sinh => System.MathF.Sinh(input), + UnaryOp.Sign => System.MathF.Sign(input), + UnaryOp.Sqrt => System.MathF.Sqrt(input), + UnaryOp.Square => input * input, + UnaryOp.Tanh => System.MathF.Tanh(input), + _ => throw new ArgumentOutOfRangeException(nameof(op), $"NotSupported {nameof(op)} For Float"), + }; + + private IRType Visit(TensorType input) + { + return input; + } +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs b/modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs new file mode 100644 index 0000000000..44dbb59fd8 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs @@ -0,0 +1,23 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +namespace Nncase.IR.CPU; + +[PatternFunctionalGenerator] +public sealed partial class CPUUnary : Op +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo Input = new(typeof(CPUUnary), 0, "input"); + + public UnaryOp UnaryOp { get; } +} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs b/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs new file mode 100644 index 0000000000..9fd63b63a6 --- /dev/null +++ b/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs @@ -0,0 +1,26 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR.CPU; +using Nncase.IR.Math; + +namespace Nncase.IR.F; + +public partial class CPU +{ + /// + /// Call unary. + /// + /// Unary operator. + /// Source expression. + /// Result expression. + public static Call CPUUnary(UnaryOp unaryOp, Expr expr) + { + return new Call(new CPUUnary(unaryOp), expr); + } +} diff --git a/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj b/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj new file mode 100644 index 0000000000..fb4674b51c --- /dev/null +++ b/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj @@ -0,0 +1,21 @@ + + + + Nncase + enable + true + true + True + + + + + + + + + + + + + diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs new file mode 100644 index 0000000000..bb77cc1354 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs @@ -0,0 +1,33 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +using static Nncase.IR.F.CPU; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules; + +[RuleGenerator] +public partial class LowerUnary : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsUnary( + target_name: "unary", + _ => true, + IsWildcard("input")); + + private Expr GetReplace(Unary unary, Expr input) + { + return CPUUnary(unary.UnaryOp, input); + } +} diff --git a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs similarity index 94% rename from modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs rename to modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index ba77ae58f5..27fb312bfd 100644 --- a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -12,6 +12,7 @@ using Nncase.CodeGen.StackVM; using Nncase.IR; using Nncase.Passes; +using Nncase.Passes.Rules; using Nncase.Quantization; namespace Nncase.Targets; @@ -41,6 +42,10 @@ public void RegisterTargetInDependentPass(IPassManager passManager, CompileOptio /// public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options) { + passManager.AddWithName("LowerIR").Configure(p => + { + p.Add(); + }); } /// diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json new file mode 100644 index 0000000000..444b371070 --- /dev/null +++ b/modules/Nncase.Modules.CPU/packages.lock.json @@ -0,0 +1,243 @@ +{ + "version": 2, + "dependencies": { + "net7.0": { + "StyleCop.Analyzers": { + "type": "Direct", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", + "dependencies": { + "StyleCop.Analyzers.Unstable": "1.2.0.435" + } + }, + "libortki": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "svfuG5mxGY/QC/5DVheHOCELmdSP90RtxQ73j23KarPXZ9ZXW+7v1l5J77hGDyQbEh1BGrnGgKBlyn76RauGHg==", + "dependencies": { + "libortki-linux": "0.0.2", + "libortki-osx": "0.0.2", + "libortki-osx-arm64": "0.0.2", + "libortki-win": "0.0.2" + } + }, + "libortki-linux": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "b04LWD4lgGy60tys3hPFhnUpgWDM6dN5r1PI7GOcPj8VupXCaI70LKNQ5/5twbDE6rkowOGanVTw0S2wBGBqBQ==" + }, + "libortki-osx": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "O6Q9GLULkDkZEPAZJVKLPH0ROXGVOE7BxuddgOcHNK2oiTEM7wIRnzp2OIlYgLpaOLyxJMisbGOhtWgdzt2Wng==" + }, + "libortki-osx-arm64": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "4Qn2dirJmRicnUG945oWpq7HVGwgqCKKxYPMISv/MRvmpZBbXrZ1cVvRaF8WwTu4XXgfKTa1sLv+i8zLifUMeQ==" + }, + "libortki-win": { + "type": "Transitive", + "resolved": "0.0.2", + "contentHash": "HAoROgAKn8XBun11X43HZuspKlo5JGy8/OYw5IUPo7FVh5TCaPrLjGmyGYYZ2dqLlv31yv/b6s254PIRGn95cA==" + }, + "Microsoft.Extensions.Configuration.Abstractions": { + "type": "Transitive", + "resolved": "6.0.0", + "contentHash": "qWzV9o+ZRWq+pGm+1dF+R7qTgTYoXvbyowRoBxQJGfqTpqDun2eteerjRQhq5PQ/14S+lqto3Ft4gYaRyl4rdQ==", + "dependencies": { + "Microsoft.Extensions.Primitives": "6.0.0" + } + }, + "Microsoft.Extensions.DependencyInjection.Abstractions": { + "type": "Transitive", + "resolved": "6.0.0", + "contentHash": "xlzi2IYREJH3/m6+lUrQlujzX8wDitm4QGnUu6kUXTQAWPuZY8i+ticFJbzfqaetLA6KR/rO6Ew/HuYD+bxifg==" + }, + "Microsoft.Extensions.FileProviders.Abstractions": { + "type": "Transitive", + "resolved": "6.0.0", + "contentHash": "0pd4/fho0gC12rQswaGQxbU34jOS1TPS8lZPpkFCH68ppQjHNHYle9iRuHeev1LhrJ94YPvzcRd8UmIuFk23Qw==", + "dependencies": { + "Microsoft.Extensions.Primitives": "6.0.0" + } + }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "6.0.0", + "contentHash": "9+PnzmQFfEFNR9J2aDTfJGGupShHjOuGw4VUv+JB044biSHrnmCIMD+mJHmb2H7YryrfBEXDurxQ47gJZdCKNQ==", + "dependencies": { + "System.Runtime.CompilerServices.Unsafe": "6.0.0" + } + }, + "NetFabric.Hyperlinq.Abstractions": { + "type": "Transitive", + "resolved": "1.3.0", + "contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA==" + }, + "StyleCop.Analyzers.Unstable": { + "type": "Transitive", + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" + }, + "System.Buffers": { + "type": "Transitive", + "resolved": "4.5.1", + "contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg==" + }, + "System.Runtime.CompilerServices.Unsafe": { + "type": "Transitive", + "resolved": "6.0.0", + "contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg==" + }, + "nncase.codegen": { + "type": "Project", + "dependencies": { + "Extension.Mathematics": "[1.2.12, )", + "Nncase.Core": "[1.0.0, )", + "Nncase.IO": "[1.0.0, )" + } + }, + "nncase.core": { + "type": "Project", + "dependencies": { + "DryIoc.dll": "[5.3.1, )", + "GiGraph.Dot": "[2.0.0, )", + "Microsoft.Extensions.Hosting.Abstractions": "[6.0.0, )", + "Microsoft.Extensions.Logging.Abstractions": "[6.0.0, )", + "Microsoft.Extensions.Options": "[6.0.0, )", + "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", + "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.Reactive": "[5.0.0, )" + } + }, + "nncase.egraph": { + "type": "Project", + "dependencies": { + "GiGraph.Dot": "[2.0.0, )", + "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "Nncase.Core": "[1.0.0, )", + "Nncase.Evaluator": "[1.0.0, )", + "Singulink.Collections.Weak": "[1.0.2, )" + } + }, + "nncase.evaluator": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )", + "OrtKISharp": "[0.0.2, )" + } + }, + "nncase.graph": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )", + "Nncase.Evaluator": "[1.0.0, )" + } + }, + "nncase.io": { + "type": "Project" + }, + "nncase.modules.stackvm": { + "type": "Project", + "dependencies": { + "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, + "nncase.passes": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )", + "Nncase.EGraph": "[1.0.0, )", + "Nncase.Evaluator": "[1.0.0, )", + "Nncase.Graph": "[1.0.0, )" + } + }, + "DryIoc.dll": { + "type": "CentralTransitive", + "requested": "[5.3.1, )", + "resolved": "5.3.1", + "contentHash": "E3zclUh2CIBks1t2uBD1k18pyGFJ1YSKCrbCDbB7qCdl2RAB+k68AyDpjeplhF1ot2XPV82AgyCWBXMf0ggL1g==" + }, + "Extension.Mathematics": { + "type": "CentralTransitive", + "requested": "[1.2.12, )", + "resolved": "1.2.12", + "contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg==" + }, + "GiGraph.Dot": { + "type": "CentralTransitive", + "requested": "[2.0.0, )", + "resolved": "2.0.0", + "contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ==" + }, + "Microsoft.Extensions.Hosting.Abstractions": { + "type": "CentralTransitive", + "requested": "[6.0.0, )", + "resolved": "6.0.0", + "contentHash": "GcT5l2CYXL6Sa27KCSh0TixsRfADUgth+ojQSD5EkzisZxmGFh7CwzkcYuGwvmXLjr27uWRNrJ2vuuEjMhU05Q==", + "dependencies": { + "Microsoft.Extensions.Configuration.Abstractions": "6.0.0", + "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0", + "Microsoft.Extensions.FileProviders.Abstractions": "6.0.0" + } + }, + "Microsoft.Extensions.Logging.Abstractions": { + "type": "CentralTransitive", + "requested": "[6.0.0, )", + "resolved": "6.0.0", + "contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA==" + }, + "Microsoft.Extensions.Options": { + "type": "CentralTransitive", + "requested": "[6.0.0, )", + "resolved": "6.0.0", + "contentHash": "dzXN0+V1AyjOe2xcJ86Qbo233KHuLEY0njf/P2Kw8SfJU+d45HNS2ctJdnEnrWbM9Ye2eFgaC5Mj9otRMU6IsQ==", + "dependencies": { + "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0", + "Microsoft.Extensions.Primitives": "6.0.0" + } + }, + "Microsoft.Toolkit.HighPerformance": { + "type": "CentralTransitive", + "requested": "[7.1.1, )", + "resolved": "7.1.1", + "contentHash": "TRnvDpZPXO30hTOtjfLw6Y9BtTKtTpzk9lefeh4RMCaUihWrVKQR454nYH4/mMJAh+LXqfAPyk0kfkJs0Amopw==" + }, + "NetFabric.Hyperlinq": { + "type": "CentralTransitive", + "requested": "[3.0.0-beta48, )", + "resolved": "3.0.0-beta48", + "contentHash": "oYUhXvxNS8bBJWqNkvx5g8y0P/0LtyqS2pN0w4OWjVDNWEpLbdbvPy9w/9z1n2PrqIjX3jxUsEnoCmxxGnI3gw==", + "dependencies": { + "NetFabric.Hyperlinq.Abstractions": "1.3.0", + "System.Buffers": "4.5.1", + "System.Runtime.CompilerServices.Unsafe": "5.0.0" + } + }, + "OrtKISharp": { + "type": "CentralTransitive", + "requested": "[0.0.2, )", + "resolved": "0.0.2", + "contentHash": "q8j0yR5836Zhv9WB9BFkQt1UaEFyibq8bqJcTiULlILF6/sz8z7Wy2N8sgYdDKsdW25zncIz7j6IDbKM5ynePg==", + "dependencies": { + "libortki": "0.0.2" + } + }, + "Singulink.Collections.Weak": { + "type": "CentralTransitive", + "requested": "[1.0.2, )", + "resolved": "1.0.2", + "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" + }, + "System.Reactive": { + "type": "CentralTransitive", + "requested": "[5.0.0, )", + "resolved": "5.0.0", + "contentHash": "erBZjkQHWL9jpasCE/0qKAryzVBJFxGHVBAvgRN1bzM0q2s1S4oYREEEL0Vb+1kA/6BKb5FjUZMp5VXmy+gzkQ==" + } + } + } +} \ No newline at end of file diff --git a/modules/Nncase.Modules.StackVM/StackVMModule.cs b/modules/Nncase.Modules.StackVM/StackVMModule.cs index 44d5c616e8..fcbeb7a0f5 100644 --- a/modules/Nncase.Modules.StackVM/StackVMModule.cs +++ b/modules/Nncase.Modules.StackVM/StackVMModule.cs @@ -14,6 +14,5 @@ internal class StackVMModule : IApplicationPart { public void ConfigureServices(IRegistrator registrator) { - registrator.Register(reuse: Reuse.Singleton); } } diff --git a/nncase.sln b/nncase.sln index afde606dc0..e77c50726c 100644 --- a/nncase.sln +++ b/nncase.sln @@ -79,6 +79,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Tests.TestFixture", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Passes", "src\Nncase.Passes\Nncase.Passes.csproj", "{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Nncase.Modules.CPU", "modules\Nncase.Modules.CPU\Nncase.Modules.CPU.csproj", "{97DA8EED-F382-4A2E-AE0E-F297A00DA19D}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -173,6 +175,10 @@ Global {E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Debug|Any CPU.Build.0 = Debug|Any CPU {E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.ActiveCfg = Release|Any CPU {E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.Build.0 = Release|Any CPU + {97DA8EED-F382-4A2E-AE0E-F297A00DA19D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {97DA8EED-F382-4A2E-AE0E-F297A00DA19D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {97DA8EED-F382-4A2E-AE0E-F297A00DA19D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {97DA8EED-F382-4A2E-AE0E-F297A00DA19D}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -203,6 +209,7 @@ Global {E365B1B1-4D13-4839-9763-A7A7C5F32FD4} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822} {98A03405-CA53-4EC4-9B18-94D1C8DF9453} = {E5A4516C-4080-4346-991D-57A7AA76ADA6} {E6462E82-B48F-4AFA-AE34-725EF0A9CB42} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822} + {97DA8EED-F382-4A2E-AE0E-F297A00DA19D} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {9492E141-292E-4D60-9C6E-3738AB234DB2} diff --git a/python/nncase/__init__.py b/python/nncase/__init__.py index f8bc443b06..7e186030e6 100644 --- a/python/nncase/__init__.py +++ b/python/nncase/__init__.py @@ -43,7 +43,7 @@ def _initialize(): _initialize() -# _nncase.launch_debugger() +_nncase.launch_debugger() class ImportOptions: diff --git a/src/Nncase.Cli/Nncase.Cli.csproj b/src/Nncase.Cli/Nncase.Cli.csproj index 3070806ba5..17e370d4ca 100644 --- a/src/Nncase.Cli/Nncase.Cli.csproj +++ b/src/Nncase.Cli/Nncase.Cli.csproj @@ -26,4 +26,8 @@ PreserveNewest + + + + diff --git a/src/Nncase.Cli/packages.lock.json b/src/Nncase.Cli/packages.lock.json index 40fb79a75e..1677ce9a38 100644 --- a/src/Nncase.Cli/packages.lock.json +++ b/src/Nncase.Cli/packages.lock.json @@ -655,6 +655,7 @@ "Nncase.Evaluator": "[1.0.0, )", "Nncase.Graph": "[1.0.0, )", "Nncase.Importer": "[1.0.0, )", + "Nncase.Modules.CPU": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )", "Nncase.Quantization": "[1.0.0, )", @@ -716,6 +717,14 @@ "nncase.io": { "type": "Project" }, + "nncase.modules.cpu": { + "type": "Project", + "dependencies": { + "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Modules.StackVM": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, "nncase.modules.stackvm": { "type": "Project", "dependencies": { diff --git a/src/Nncase.Compiler/Hosting/CompilerHostBuilderExtensions.cs b/src/Nncase.Compiler/Hosting/CompilerHostBuilderExtensions.cs index 5fa0e4ba0f..70f110f2e5 100644 --- a/src/Nncase.Compiler/Hosting/CompilerHostBuilderExtensions.cs +++ b/src/Nncase.Compiler/Hosting/CompilerHostBuilderExtensions.cs @@ -53,7 +53,8 @@ private static void ConfigureBuiltinModules(Container builder) .AddEGraph() .AddCodeGen() .AddPasses() - .AddStackVM(); + .AddStackVM() + .AddCPU(); } private static void ConfigureServices(HostBuilderContext context, IServiceCollection services) diff --git a/src/Nncase.Compiler/Hosting/PluginLoader.cs b/src/Nncase.Compiler/Hosting/PluginLoader.cs index 73f10f367a..0264ccc1b9 100644 --- a/src/Nncase.Compiler/Hosting/PluginLoader.cs +++ b/src/Nncase.Compiler/Hosting/PluginLoader.cs @@ -25,6 +25,7 @@ public sealed class PluginLoader private static readonly string[] _builtinModules = new[] { "Nncase.Modules.StackVM.dll", + "Nncase.Modules.CPU.dll", "Nncase.Modules.K210.dll", }; diff --git a/src/Nncase.Compiler/Nncase.Compiler.csproj b/src/Nncase.Compiler/Nncase.Compiler.csproj index 8d02d21ec4..ab93c75220 100644 --- a/src/Nncase.Compiler/Nncase.Compiler.csproj +++ b/src/Nncase.Compiler/Nncase.Compiler.csproj @@ -14,6 +14,7 @@ + diff --git a/src/Nncase.Compiler/packages.lock.json b/src/Nncase.Compiler/packages.lock.json index c58fa234e0..a00db8ff64 100644 --- a/src/Nncase.Compiler/packages.lock.json +++ b/src/Nncase.Compiler/packages.lock.json @@ -694,6 +694,14 @@ "nncase.io": { "type": "Project" }, + "nncase.modules.cpu": { + "type": "Project", + "dependencies": { + "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Modules.StackVM": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, "nncase.modules.stackvm": { "type": "Project", "dependencies": { diff --git a/src/Nncase.Tests.TestFixture/Nncase.Tests.TestFixture.csproj b/src/Nncase.Tests.TestFixture/Nncase.Tests.TestFixture.csproj index 3e1336e99c..d3b0b7e32b 100644 --- a/src/Nncase.Tests.TestFixture/Nncase.Tests.TestFixture.csproj +++ b/src/Nncase.Tests.TestFixture/Nncase.Tests.TestFixture.csproj @@ -22,6 +22,7 @@ + diff --git a/src/Nncase.Tests.TestFixture/packages.lock.json b/src/Nncase.Tests.TestFixture/packages.lock.json index 02bfd47a3b..692f52a7fe 100644 --- a/src/Nncase.Tests.TestFixture/packages.lock.json +++ b/src/Nncase.Tests.TestFixture/packages.lock.json @@ -1094,6 +1094,14 @@ "nncase.io": { "type": "Project" }, + "nncase.modules.cpu": { + "type": "Project", + "dependencies": { + "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Modules.StackVM": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, "nncase.modules.stackvm": { "type": "Project", "dependencies": { diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs b/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs index b3be7dff7e..89c6555e5c 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs @@ -111,6 +111,16 @@ public void TestSimpleBinary() GenerateKModelAndRun(module, new[] { 1.0f }, new[] { 2.0f }); } + [Fact] + public void TestSimpleUnary() + { + var x = new Var("x", new TensorType(DataTypes.Float32, new[] { 1 })); + var y = IR.F.Math.Log(x); + var main = new Function("main", y, new[] { x }); + var module = new IRModule(main); + GenerateKModelAndRun(module, new[] { 1.0f }, new[] { MathF.Log(1.0f) }); + } + [Fact] public void TestCodegenCallParamOrder() { diff --git a/src/Nncase.Tests/packages.lock.json b/src/Nncase.Tests/packages.lock.json index b2b9315b3b..0b497aebc2 100644 --- a/src/Nncase.Tests/packages.lock.json +++ b/src/Nncase.Tests/packages.lock.json @@ -1445,6 +1445,7 @@ "Nncase.Evaluator": "[1.0.0, )", "Nncase.Graph": "[1.0.0, )", "Nncase.Importer": "[1.0.0, )", + "Nncase.Modules.CPU": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )", "Nncase.Quantization": "[1.0.0, )", @@ -1506,6 +1507,14 @@ "nncase.io": { "type": "Project" }, + "nncase.modules.cpu": { + "type": "Project", + "dependencies": { + "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Modules.StackVM": "[1.0.0, )", + "Nncase.Passes": "[1.0.0, )" + } + }, "nncase.modules.stackvm": { "type": "Project", "dependencies": { @@ -1544,6 +1553,7 @@ "MethodBoundaryAspect.Fody": "[2.0.148, )", "Nncase.CodeGen": "[1.0.0, )", "Nncase.Core": "[1.0.0, )", + "Nncase.Modules.CPU": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )", "Nncase.Simulator": "[1.0.0, )", From 296e70b6c71df1808b83435f27a4df8b8f3308a4 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Wed, 24 May 2023 18:46:23 +0800 Subject: [PATCH 002/308] Add make fusion --- .../Passes/Rules/MakeFusion.cs | 43 +++++++++++++++++++ .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 15 ++++--- .../Rule/RuleGenerator.cs | 1 + 3 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs new file mode 100644 index 0000000000..fb1bd25a22 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs @@ -0,0 +1,43 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.Passes.Rules.Neutral; +using Nncase.PatternMatch; +using Nncase.Targets; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; +using static Nncase.Utilities.ReplaceUtility; + +namespace Nncase.Passes.Rules; + +[RuleGenerator] +internal partial class CPUSingleInputFusion : FusionMaker + where T : Op +{ + public override string ModuleKind { get; } = CPUTarget.Kind; + + public override Pattern Pattern { get; } = IsCallWildcard( + "call", + IsOp("op"), + IsWildcard("input")); + + private Call? GetReplace(Call call, IReadOnlyList callParams, Op op, Expr input) + { + var newInput = new Var(input.CheckedType!); + var newCall = ReplaceCallParams(op, callParams, (input, newInput)); + var fusion = new Call(new Fusion(FullName, ModuleKind, newCall, new[] { newInput }), input); + return fusion; + } +} + +internal sealed class CPUUnaryFusion : CPUSingleInputFusion +{ + public override string Name => "Unary"; +} diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 27fb312bfd..c82f0d24d1 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Options; +using Microsoft.VisualBasic; using Nncase.CodeGen; using Nncase.CodeGen.StackVM; using Nncase.IR; @@ -63,11 +64,6 @@ public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List /// public void RegisterQuantizePass(IPassManager passManager, CompileOptions options) - { - } - - /// - public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options) { if (options.QuantizeOptions.ModelQuantMode == ModelQuantMode.UsePTQ) { @@ -78,6 +74,15 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp } } + /// + public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options) + { + passManager.AddWithName("MakeFusion").Configure(p => + { + p.Add(); + }); + } + public void RegisterTargetDependentBeforeCodeGen(IPassManager passManager, CompileOptions options) { } diff --git a/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs b/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs index 57cf983aba..f2af3c5a2d 100644 --- a/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs +++ b/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs @@ -213,6 +213,7 @@ private void Execute(SourceProductionContext context, ImmutableArray(method)) .WithAttributeLists(new SyntaxList() { }) From a36cd0845d0a66f22b189e2c879cacd3b41f3852 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Thu, 6 Jul 2023 12:23:59 +0800 Subject: [PATCH 003/308] GNNE-1881 Add tiling basics --- .../Passes/CPUFusionToTirPass.cs | 117 ++ .../Passes/Tile/CPUFusionGroupMutator.cs | 187 +++ .../Passes/Tile/IFusionChecker.cs | 26 + .../Passes/Tile/LayerFusionConverter.cs | 1262 +++++++++++++++++ .../Passes/Tile/MultiFusionChecker.cs | 251 ++++ .../Passes/Tile/MultiLayerFusionConverter.cs | 228 +++ .../Passes/Tile/TileOptions.cs | 21 + .../Passes/Tile/TwoFusionChecker.cs | 131 ++ .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 2 - modules/Nncase.Modules.CPU/packages.lock.json | 46 + .../Properties/launchSettings.json | 2 +- 11 files changed, 2270 insertions(+), 3 deletions(-) create mode 100644 modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs new file mode 100644 index 0000000000..128ca15c91 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs @@ -0,0 +1,117 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Analysis; +using Nncase.Passes.Mutators; +using Nncase.Passes.Tile; +using Nncase.TIR; + +namespace Nncase.Passes; + +internal sealed class CPUFusionToTirPass : ModulePass +{ + private readonly TileOptions _tileOptions; + private readonly Dictionary _fusionMacsMap; + + public CPUFusionToTirPass(TileOptions tileOptions) + { + _tileOptions = tileOptions; + _fusionMacsMap = new(ReferenceEqualityComparer.Instance); + } + + private IAnalyzerManager AnalyzerManager => CompileSession.GetRequiredService(); + + /// + protected override Task RunCoreAsync(IRModule module, RunPassContext options) + { + Dictionary fusionConertedCache = new(ReferenceEqualityComparer.Instance); + + // convert the fusion as entry. + for (int i = 0; i < module.Functions.Count; i++) + { + if (module.Functions[i] is Fusion { ModuleKind: "cpu" } fusion) + { + TIR.PrimFunction primFunction; + var visitor = new MultiLayerFusionConverter(_tileOptions); + primFunction = visitor.VisitToPrimFunc(fusion); + + CompilerServices.InferenceType(primFunction); + fusionConertedCache[fusion] = primFunction; + module.Replace(i, primFunction); + } + } + + // convert the stackvm function call k510 fusion + for (int i = 0; i < module.Functions.Count; i++) + { + if (module.Functions[i] is Function { ModuleKind: "stackvm" } func) + { + var analysis = new Dictionary + { + [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(func), + }; + var rewriter = new DataFlowMergeRewriter(); + var fusionCheckCache = new Dictionary(ReferenceEqualityComparer.Instance); + + var post = (Function)rewriter.Rewrite(func, new Mutators.IMergeRewriteRule[] { new GNNESameInputFusionMergeRule(), }, (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); + + // if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) + // { + // DumpScope.Current.DumpDotIR(post, "MultiLayer"); + // } + // post = (Function)rewriter.Rewrite( + // post, + // new Mutators.IMergeRewriteRule[] { + // new GNNESameInputFusionMergeRule(), + // }, + // (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), + // new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); + + // if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) + // { + // DumpScope.Current.DumpDotIR(post, "TwoLayer"); + // } + // var post = func; + var mutator = new CheckedConvertMutator(fusionConertedCache, _fusionMacsMap, fusionCheckCache, _tileOptions, options); + var new_func = (Function)mutator.Rewrite(post); + CompilerServices.InferenceType(new_func); + if (mutator.IsMutated) + { + module.Replace(i, new_func); + } + } + } + + // add all prim func. + foreach (var item in fusionConertedCache.Values) + { + if (item is PrimFunctionWrapper wrapper) + { + module.Add(wrapper); + module.Add(wrapper.Target); + } + } + + return Task.FromResult(module); + } + + protected override async Task OnPassEndAsync(IRModule post, RunPassContext context) + { + await base.OnPassEndAsync(post, context); + if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) + { + using var writer = new StreamWriter(DumpScope.Current.OpenFile("mac.csv")); + foreach (var (fusion, mac) in _fusionMacsMap) + { + writer.WriteLine($"mac: {fusion.Name},{mac}"); + } + } + + _fusionMacsMap.Clear(); + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs new file mode 100644 index 0000000000..4b206ca688 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs @@ -0,0 +1,187 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; +using Nncase.IR; +using Nncase.PatternMatch; +using Nncase.Targets; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Tile; + +internal sealed class GNNESameInputFusionMergeRule : Mutators.SameInputFusionMergeRule +{ + public override string ModuleKind => CPUTarget.Kind; + + // todo enable multi input fusion merge pattern. + public override Pattern CreatePattern(string target_module_kind) + { + var inputPat = IsWildcard("input"); + + var callerPattern = IsCall( + "caller", + IsFusion( + "caller_fusion", + target_module_kind, + IsWildcard(), + IsVArgs(IsWildcard())), + IsVArgs("caller_inputs", new[] { + IsCall( + $"callee_{0}", + IsFusion($"callee_fusion_{0}", target_module_kind, IsWildcard(), IsVArgs(IsWildcard())), + inputPat), + })); + return callerPattern; + } +} + +internal sealed class CPUFusionGroupMutator : Mutators.FusionGroupMutator + where T : IFusionChecker +{ + private readonly TileOptions _tileOptions; + + public CPUFusionGroupMutator( + Dictionary fusioncheckerCache, + TileOptions tileOptions, + Mutators.IMergeRewriteRule rule, + RunPassContext passOptions) + : base(rule, passOptions) + { + _tileOptions = tileOptions; + FusioncheckerCache = fusioncheckerCache; + } + + public Dictionary FusioncheckerCache { get; } + + /// + public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet candidateFusions) + { + // note the gnne activate must be first layer. + if (mergedFusion.Body is Call { Target: IR.K510.GNNEStore } st_call && + st_call[IR.K510.GNNEStore.Input] is Call { Target: IR.K510.GNNEActivation }) + { + return false; + } + + var checker = (IFusionChecker)Activator.CreateInstance(typeof(T), new object[] { _tileOptions })!; + var ret = checker.Check(mergedFusion, PassOptions); + if (ret) + { + FusioncheckerCache.Add(mergedFusion, checker); + foreach (var cand in candidateFusions) + { // release the merged fusion. + FusioncheckerCache.Remove(cand); + } + } + + return ret; + } + + public override Expr MergedFusionRewriteCallBack(Expr mergedFusionBody) + { + return CompilerServices.Rewrite(mergedFusionBody, new[] { new Rules.GNNE.Opt.FoldLoadStore() }, new()); + } +} + +internal sealed class CheckedConvertMutator : ExprRewriter +{ + private readonly Dictionary _fusionConertedCache; + private readonly IDictionary _fusionMacsMap; + private readonly IReadOnlyDictionary _fusionCheckerCache; + private readonly TileOptions _tileOptions; + private readonly RunPassContext _passOptions; + + public CheckedConvertMutator(Dictionary fusion_converted_cache, Dictionary fusionMacsMap, IReadOnlyDictionary fusionchecker_cache, TileOptions tileOptions, RunPassContext passOptions) + { + _fusionConertedCache = fusion_converted_cache; + _fusionMacsMap = fusionMacsMap; + _fusionCheckerCache = fusionchecker_cache; + _tileOptions = tileOptions; + _passOptions = passOptions; + } + + /// + protected override Expr RewriteLeafFusion(Fusion expr) + { + if (expr is Fusion { ModuleKind: K510Target.Kind } fusion) + { + if (!_fusionConertedCache.TryGetValue(fusion, out _)) + { + TIR.PrimFunction prim_func; + if (_fusionCheckerCache.TryGetValue(fusion, out var checker)) + { + prim_func = checker.Convert(_passOptions); + } + else + { + if (CompilerServices.TryMatchRoot(fusion, Conv2DFusionConverter.Conv2DFusionPattern, out var matchResult)) + { + prim_func = Conv2DFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); + } + else if (CompilerServices.TryMatchRoot(fusion, Conv2DTransposeFusionConverter.Conv2DFusionPattern, out matchResult)) + { + prim_func = Conv2DTransposeFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); + } + else if (!_tileOptions.ForceMultiLayer && CompilerServices.TryMatchRoot(fusion, LSTMFusionConverter.FusionPattern, out matchResult)) + { + prim_func = LSTMFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); + } + else + { + var visitor = new MultiLayerFusionConverter(_tileOptions); + prim_func = visitor.VisitToPrimFunc(fusion); + } + } + + BaseFunction? convert_func = prim_func; + _fusionConertedCache.Add(fusion, convert_func); + new DDrMacCalcVisitor(_fusionMacsMap).Visit(fusion); + } + } + + return expr; + } + + protected override Expr RewriteLeafCall(Call expr) + { + if (expr.Target is Fusion { ModuleKind: K510Target.Kind } fusion) + { + var convert_func = _fusionConertedCache[fusion]; + PrimFunctionWrapper wrapper; + if (convert_func is TIR.PrimFunction prim_func) + { + bool is_input = true; + int param_count = 0; + foreach (var b in prim_func.Parameters) + { + if (b.MemLocation == Schedule.MemoryLocation.Input) + { + if (is_input) + { + param_count += 1; + } + else + { + throw new InvalidOperationException("The output buffer must behind the input buffer"); + } + } + else + { + is_input = false; + } + } + + wrapper = new PrimFunctionWrapper(prim_func, param_count); + _fusionConertedCache[fusion] = wrapper; + } + else + { + wrapper = (PrimFunctionWrapper)convert_func; + } + + return expr.With(target: wrapper); + } + + return expr; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs new file mode 100644 index 0000000000..84b2d37072 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; + +namespace Nncase.Passes.Tile; + +internal interface IFusionChecker +{ + /// + /// 检查fusion是否可以正常执行. + /// + /// fusion. + /// passOptions. + /// . + public bool Check(Fusion fusion, RunPassContext passOptions); + + /// + /// 通常当check过一个fusion之后, 可以cache部分的内容, 此时通过convert复用. + /// + /// passOptions. + /// . + public TIR.PrimFunction Convert(RunPassContext passOptions); +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs new file mode 100644 index 0000000000..8cf139f3fc --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs @@ -0,0 +1,1262 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Reactive; +using System.Runtime.CompilerServices; +using NetFabric.Hyperlinq; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.Buffers; +using Nncase.IR.F; +using Nncase.IR.K510; +using Nncase.IR.Math; +using Nncase.Passes.BufferSchedule; +using Nncase.Passes.Mutators; +using Nncase.Runtime.K510; +using Nncase.Schedule; +using Nncase.TIR; +using Nncase.TIR.Builders; +using Nncase.TIR.K510; +using Nncase.TIR.K510.Builders; +using Nncase.TIR.K510.Instructions; +using Buffer = Nncase.TIR.Buffer; +using MathF = Nncase.IR.F.Math; +using Range = Nncase.TIR.Range; +using Tuple = Nncase.IR.Tuple; + +namespace Nncase.Passes.Tile; + +public sealed class BufferRegionView +{ + private Expr? _cache; + + private Expr[]? _condtion_buffer_regions; + + private Expr[]? _region_size; + + public BufferRegionView(IEnumerable buffers, IEnumerable bounds, IEnumerable region, IndexMapKey key) + : this(buffers, bounds, region, key, 0, null) + { + } + + public BufferRegionView(IEnumerable buffers, IEnumerable bounds, IEnumerable region, IndexMapKey key, Expr loopCount, int? promote) + { + Buffers = buffers.ToArray(); + Region = region.ToArray(); + LoopCount = loopCount; + Parent = null; + Key = key; + Promote = promote; + Bounds = bounds.ToArray(); + } + + public IndexMapKey Key { get; } + + /// + /// Gets 记录他的loop count. + /// + public Expr LoopCount { get; } + + public int? Promote { get; } + + public IReadOnlyList Bounds { get; } + + public IReadOnlyList Buffers { get; } + + public IReadOnlyList Region { get; } + + public BufferRegionView? Parent { get; set; } + + public ReadOnlySpan Dimensions => Buffers[0].Dimensions; + + /// + /// Gets 返回带有condition的buffer region的表达式. + /// + public IReadOnlyList BufferRegions + { + get + { + _condtion_buffer_regions ??= Buffers.Count == 0 ? Array.Empty() : Buffers.Select(b => new BufferRegion(b, Region.ToArray())).ToArray(); + return _condtion_buffer_regions; + } + } + + public BufferRegionView this[params Range[] ranges] + { + get => new(Buffers, Bounds, Region.Zip(ranges).Select(tp => tp.Second.Equals(Range.All) ? tp.First : tp.Second.Stop switch { Call { Target: Unary { UnaryOp: UnaryOp.Neg } } => throw new NotSupportedException("Neg Region!"), _ => tp.Second, }), Key, LoopCount, Promote) { Parent = Parent is null ? this : Parent, }; // if stop is neg, add the shape, else return the origin range. + } + + /// + /// convert the BufferRegionView to expr. + /// + /// 当开启ping pong时,如果 + /// + /// + /// view. + public static implicit operator Expr(BufferRegionView view) + { + if (view._cache is not null) + { + return view._cache; + } + + Expr expr; + if (view.Buffers.Count == 0) + { + expr = IR.None.Default; + } + else if (view.Buffers.Count == 1) + { + expr = view.BufferRegions[0]; + } + else if (view.Buffers.Count >= 2) + { + expr = new Tuple(view.BufferRegions.ToArray())[view.LoopCount % view.Buffers.Count]; + } + else + { + throw new NotSupportedException(); + } + + view._cache = expr; + return view._cache; + } + + public static BufferRegionView None(IndexMapKey key) => new(Array.Empty(), new IRArray(), new IRArray(), key); + + public ReadOnlySpan RegionSize() + { + _region_size ??= Region.AsValueEnumerable().Select(r => r.Stop - r.Start).ToArray(); + return _region_size; + } + + public Expr RegionSize(int i) => RegionSize()[i]; +} + +/// +/// name 分配器. +/// +internal sealed class NameAllocator +{ + public Dictionary NamePool { get; } = new(); + + public string Get(string name) + { + if (!NamePool.TryGetValue(name, out var count)) + { + count = 0; + } + + NamePool[name] = count + 1; + return count == 0 ? name : $"{name}_{count}"; + } +} + +/// +/// buffer region view. +/// +internal sealed class LogiclPrimFuncCloner : ExprCloner +{ + protected override Expr VisitLeafLogicalBuffer(LogicalBuffer buffer, Unit context) + { + return buffer; + } + + protected override Expr VisitLeafPhysicalBuffer(PhysicalBuffer buffer, Unit context) + { + return buffer; + } + + protected override Expr VisitVar(Var var, Unit context) + { + return var; + } +} + +internal sealed record ReIndexCacheKey(IBoundsInferGraph BoundsInferGraph, IndexMapKey From, IndexMapKey To, IRArray FromRegion, int? Promote) +{ +} + +internal abstract class LayerFusionConverter +{ + public NameAllocator NameAllocator { get; } = new(); + + /// + /// Gets map the graph expression and it's bufferRegion. + /// + public Dictionary KeyToViewMap { get; } = new(); + + /// + /// Gets because of the index map can't create by var, so need other map save the relationship. + /// + public Dictionary VarToKeyMap { get; } = new(ReferenceEqualityComparer.Instance); + + /// + /// Gets tile size 的变量. + /// + public List TileSizeVars { get; } = new(); + + /// + /// Gets loop 变量. + /// + public List LoopVars { get; } = new(); + + /// + /// Gets loop domains. + /// + public List LoopDomains { get; } = new(); + + /// + /// Gets 所有的blocks + /// 最终是: + /// mainBlock + /// loop n + /// block n + /// loop c + /// block c + /// . + /// . + /// + public List NestedBlocks { get; } = new(); + + /// + /// Gets nested loops. + /// + public List> NestedLoops { get; } = new(); + + public TileOptions TileOptions { get; protected set; } = null!; + + /// + /// Gets or sets 默认的bounds infer graph. + /// + public abstract IBoundsInferGraph BoundsInferGraph { get; protected set; } + + /// + /// Gets or sets 总的loop count. + /// / + public abstract Expr LoopCount { get; protected set; } + + /// + /// Gets or sets ping pong 外层的tiling. + /// + public abstract Expr LoopCountOuter { get; protected set; } + + /// + /// Gets or sets ping pong 内侧的tiling. + /// + public abstract Expr LoopCountInner { get; protected set; } + + /// + /// Gets or sets 当前的fusion. + /// + public abstract Fusion CurrentFusion { get; protected set; } + + /// + /// Gets glb reindex cache. + /// + protected Dictionary ToRegion, IReadOnlyList<(Expr Before, Expr After)> Paddings)> GlbReindexCache { get; } = new(); + + public abstract Expr Visit(Fusion fusion); + + public virtual PrimFunction BuildLogicalPrimFunc(Expr bodySeq) + { + var inputs_buffer = CurrentFusion.Parameters.ToArray().Select(p => (PhysicalBuffer)KeyToViewMap[VarToKeyMap[p]].Buffers[0]); + var primFuncBuilder = T.PrimFunc(CurrentFusion.Name, K510RTModule.Kind, inputs_buffer.Concat(new[] { (PhysicalBuffer)KeyToViewMap[(Call)CurrentFusion.Body].Buffers[0] }).ToArray()); + + NestedBlocks[^1].Body(bodySeq); + primFuncBuilder.Body( + I.MmuConf(0, 0, MMU_CONF_WIDTH._8, 0, ExtCompilerServices.Env.GlbDepth), // 把整个glb当作连续内存使用. + NestedBlocks[0], + I.Fence()); + + var logicalPrimFunc = primFuncBuilder.Build(); + logicalPrimFunc = (PrimFunction)new Mutators.SimplifyBounds().Rewrite(logicalPrimFunc); + logicalPrimFunc.InferenceType(); + GlbReindexCache.Clear(); + return logicalPrimFunc; + } + + public abstract bool BalanceTileSize(int[] tile_size, Segment[] search_spaces); + + public virtual PrimFunction BuildPhysicalPrimFunc(int[] final_tile_size, IReadOnlyDictionary sched_result, PrimFunction logicalPrimFunc) + { + var physicalizer = new BufferPhysicalizer(final_tile_size, sched_result, TileSizeVars); + var physicalPrimFunc = (PrimFunction)physicalizer.Rewrite(logicalPrimFunc); + return physicalPrimFunc; + } + + public virtual int[] SearchTileSize(ISearchTileGenerator tile_generator, PrimFunction logicalPrimFunc, bool multi_workers, bool hasResult, out ScheduledResponse response) + { + AllocationCache response_cache = new(); + bool schedule_status = false; + int[] final_tile = Array.Empty(); + + while (true) + { + var next_tile = tile_generator.GetNextTile(schedule_status).ToArray(); + if (next_tile.Length == 0) + { + break; + } + + schedule_status = TryScheduleNextTileSize(next_tile, logicalPrimFunc, response_cache, multi_workers, hasResult); + if (schedule_status) + { + final_tile = next_tile; + response_cache.CheckIn(); + } + } + + if (!final_tile.Any()) + { + response = new(new Dictionary(), null!, null!, logicalPrimFunc, null!, 0, false); + return final_tile; + } + + // take back last success allocation result + response = response_cache.GetLastSuccess(final_tile); + return final_tile; + } + + public virtual Expr Visit(IndexMapKey mapKey, string prefix, int? promote = null) + { + prefix = prefix + mapKey.Prefix; + return mapKey.Expr switch + { + Call call => (call.Target switch + { + GNNELoad op => LowerGnneLoad(mapKey, call, op, NameAllocator.Get(nameof(GNNELoad)), prefix, promote, true), + GNNEStore op => LowerGnneStore(call, op, NameAllocator.Get(nameof(GNNEStore)), prefix), + GNNEConv2D op => LowerGnneConv2D(mapKey, call, op, NameAllocator.Get(nameof(GNNEConv2D)), prefix), + GNNEConv2DTranspose op => LowerGnneConv2DTranspose(mapKey, call, op, NameAllocator.Get(nameof(GNNEConv2D)), prefix), + GNNEReduce op => LowerGnneReduce(mapKey, call, op, NameAllocator.Get(nameof(GNNEReduce)), prefix), + GNNEMeshNet op => LowerGnneMeshNet(mapKey, call, op, NameAllocator.Get(nameof(GNNEMeshNet)), prefix), + GNNETranspose op => LowerGnneTranspose(mapKey, call, op, NameAllocator.Get(nameof(GNNETranspose)), prefix), + GNNEActivation op => LowerGnneActivation(mapKey, call, op, NameAllocator.Get(nameof(GNNEActivation)), prefix), + GNNEPdpReduce op => LowerGnnePdpReduce(mapKey, call, op, NameAllocator.Get(nameof(GNNEPdpReduce)), prefix), + GNNECrop op => LowerGnneCrop(mapKey, call, op, NameAllocator.Get(nameof(GNNECrop)), prefix), + Uninitialized => T.Sequential(), + _ => throw new NotSupportedException(), + }).Build(), + _ => T.Nop(), + }; + } + + /// + /// 子偏移输入到bounds infer后反推子偏移. + /// + /// from. + /// to. + /// sub_paddings. + /// the partial compute funcs. + /// . + public virtual BufferRegionView GlbReIndex(BufferRegionView from, BufferRegionView to, out IReadOnlyList<(Expr Before, Expr After)> sub_paddings, params (int Axis, Func CallBack)[] partialFuncs) + { + var key = new ReIndexCacheKey(BoundsInferGraph, from.Key, to.Key, new(from.Region), to.Promote); + IReadOnlyList to_region; + if (partialFuncs.Length == 0 && GlbReindexCache.TryGetValue(key, out var result)) + { + to_region = result.ToRegion; + sub_paddings = result.Paddings; + } + else + { + to_region = TileUtilities.GetRelativeNoPadBounds(BoundsInferGraph, from.Key, to.Key, from.Region, to.Promote, partialFuncs, out sub_paddings); + if (partialFuncs.Length == 0) + { + GlbReindexCache.Add(key, (new(to_region), sub_paddings)); + } + } + + return to[to_region.ToArray()]; + } + + protected virtual bool TryScheduleNextTileSize(int[] next_tile_size, PrimFunction logicalPrimFunc, AllocationCache response_cache, bool multi_workers, bool hasResult) + { + // 1. make one tile feed dict + var feed_dict = next_tile_size.Select((s, i) => + new[] { (LoopVars[i], (IValue)Value.FromTensor(Tensor.FromScalar(0))), + (TileSizeVars[i], (IValue)Value.FromTensor(Tensor.FromScalar(s))), }). + SelectMany(i => i). + ToDictionary(kv => kv.Item1, kv => kv.Item2); + var sched_candidate = new Dictionary(ReferenceEqualityComparer.Instance); + + // 2. folding the tileblock op to the block + PrimFunction new_logical_primfunc; + using (var dumpScope = new DumpScope(NullDumpper.Instance)) + { + var pass = new PrimFuncPass { Name = "FoldingTileBlock" }; + pass.Add(feed_dict); + pass.Add(); + pass.Add(); + var task = pass.RunAsync(new LogiclPrimFuncCloner().Clone(logicalPrimFunc, default), new()); + task.Wait(); + new_logical_primfunc = task.Result; + } + + BufferScheduler bufferScheduler = new(new_logical_primfunc); + + // 3. clloction buffers + bufferScheduler.LifeTimeAnalysis(); + + // compute the size in bytes + foreach (var buffer in bufferScheduler.RecordBuffers) + { + var dimensions = buffer.Dimensions.ToArray().Select(d => d.Evaluate(feed_dict).AsTensor().ToScalar()).ToArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var glb_strides = strides.Select(s => s * buffer.ElemType.SizeInBytes).ToArray(); + + if (bufferScheduler.InnerConstraints[buffer] == ConstraintsMode.Channel && + + // 当load psum的时候,如果shape过小, 那么不额外添加stride. + !(buffer.Name.Split(".").Last().StartsWith(GNNEConv2D.PSum.Name) && + dimensions[2] * dimensions[3] < 14 * 14)) + { + glb_strides = TileUtilities.PaddingAvoidConflict(dimensions, glb_strides, 1); + strides = glb_strides.Select(s => + { + if (s % buffer.ElemType.SizeInBytes != 0) + { + throw new NotSupportedException(); + } + + return s / buffer.ElemType.SizeInBytes; + }).ToArray(); + } + + var size_n_byte = dimensions[0] * glb_strides[0]; + + // todo 可以不用align到一整行, 到一个bank即可. + var glb_size = TileUtilities.AlignBy(size_n_byte, ExtCompilerServices.Env.GlbBankWidth * ExtCompilerServices.Env.GlbWidth); + var physical_candidate = new PhysicalBuffer(buffer.Name, buffer.ElemType, buffer.MemLocation, dimensions, strides, start: 0, size: glb_size); + sched_candidate.Add(buffer, physical_candidate); + } + + var respose = bufferScheduler.Schedule(sched_candidate, multi_workers, hasResult); + response_cache.Add(next_tile_size, respose); + return respose.Success; + } + + /// + /// 申请 buffer. + /// NOTE 会自动添加到buffer map, 同时会记录他ddr 上的padding到字典中. + /// 如果给定 ddr buf region, 那么默认glb buffer region则是通过ddr buffer load 进来的,此时glb buffer的region是减去过padding的. + /// 如果promote到对应的循环后,那么申请buffer的时候在promote内部的循环都应该被调整到最大值. + /// + /// mapKey. + /// region. + /// 开启ping pong就会多开一块相同的buffer. + /// 如果promote为int,那么就会提升buffer到指定循环, 为-1那么就是整个计算块, 会忽略ping pong. + /// specificLoopBounds. + /// name. + /// . + /// NotSupportedException. + /// System.ArgumentOutOfRangeException. + protected virtual Expr GetBufferRegion(IndexMapKey mapKey, out BufferRegionView region, bool ping_pong = false, int? promote = null, Dictionary>? specificLoopBounds = null, [CallerArgumentExpression("region")] string name = "region") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + if (KeyToViewMap.ContainsKey(mapKey)) + { + region = KeyToViewMap[mapKey]; + return T.Nop(); + } + + name = NameAllocator.Get(name); + switch (mapKey.Expr) + { + case TensorConst con: + { + IEnumerable bounds; + IEnumerable clampedBounds; + if (promote is int promoteInt) + { + if (promoteInt == -1) + { + clampedBounds = mapKey.Expr.CheckedShape.Select(s => new Range(0, s.FixedValue, 1)); + bounds = BoundsInferGraph[mapKey].Bounds; + } + else + { + if (specificLoopBounds is null || !specificLoopBounds.TryGetValue(mapKey, out var newBounds)) + { + newBounds = K510TIRExtensions.PromotedBounds(promoteInt, BoundsInferGraph, mapKey, LoopVars, LoopDomains).ToList(); + } + + bounds = newBounds; + clampedBounds = TIRUtilities.ClampBounds(newBounds, mapKey.Expr.CheckedShape); + } + } + else + { + bounds = BoundsInferGraph[mapKey].Bounds; + clampedBounds = BoundsInferGraph[mapKey].ClampedBounds; + } + + T.ConstBuffer(con, out var ddr_buffer, name); + if (ping_pong) + { + throw new NotSupportedException(); + } + + region = new BufferRegionView(new[] { ddr_buffer }, bounds, clampedBounds, mapKey); + break; + } + + case Call call: + { + // 1. 对于glb buffer来说, 他的总大小要跟着申请buffer维度来变化. + // note 实际上对于 + List bounds; + Expr loopCount; + if (promote is int promoteInt) + { + if (promoteInt == -1) + { + bounds = mapKey.Expr.CheckedShape.Select(s => new Range(0, s.FixedValue, 1)).ToList(); + loopCount = 0; + } + else + { + if (specificLoopBounds is null || !specificLoopBounds.TryGetValue(mapKey, out var newBounds)) + { + newBounds = K510TIRExtensions.PromotedBounds(promoteInt, BoundsInferGraph, mapKey, LoopVars, LoopDomains).ToList(); + } + + bounds = newBounds; + loopCount = K510TIRExtensions.PromotedLoopCount(promoteInt, LoopVars, LoopDomains); + } + } + else + { + bounds = BoundsInferGraph[mapKey].Bounds.ToList(); + loopCount = LoopCount; + } + + // note 这里的bounds实际上会因为输入不同的var而被改变, 所以后面要获取dimension的地方需要注意. + var dimensions = bounds.Select(r => r.Stop - r.Start).Select((b, i) => MathF.Min(b, call.CheckedShape[i].FixedValue)).ToArray(); + List glb_buffers = new(); + if (ping_pong) + { + for (int i = 0; i < TileOptions.PingPongNum; i++) + { + glb_buffers.Add(new LogicalBuffer(name + $"(p{i})", call.CheckedDataType, MemoryLocation.L2Data, dimensions)); + } + } + else + { + glb_buffers.Add(new LogicalBuffer(name, call.CheckedDataType, MemoryLocation.L2Data, dimensions)); + } + + // 对于glb_buffer来说, 默认region 从0 开始, 但是要减去输入ddr index 的padding. + var noPadBounds = TIRUtilities.ComputeNoPadBounds(bounds, TIRUtilities.ComputePaddings(bounds, mapKey.Expr.CheckedShape)); + region = new BufferRegionView(glb_buffers, bounds, noPadBounds, mapKey, loopCount, promote); + break; + } + + case Var v: + { + // the different mapkey will point to same var: add(v,conv(v)) + if (!VarToKeyMap.TryGetValue(v, out var old_map_key)) + { + T.PhysicalBuffer(v.CheckedDataType, MemoryLocation.Input, v.CheckedShape.ToValueArray(), out var ddr_buffer, name); + IEnumerable clampedBounds; + IReadOnlyList bounds; + if (promote is int promoteInt) + { + if (promoteInt != -1) + { + if (specificLoopBounds is null || !specificLoopBounds.TryGetValue(mapKey, out var newBounds)) + { + newBounds = K510TIRExtensions.PromotedBounds(promoteInt, BoundsInferGraph, mapKey, LoopVars, LoopDomains).ToList(); + } + + bounds = newBounds; + } + else + { + bounds = mapKey.Expr.CheckedShape.Select(s => new Range(0, s.FixedValue, 1)).ToList(); + } + + clampedBounds = TIRUtilities.ClampBounds(bounds, mapKey.Expr.CheckedShape); + } + else + { + bounds = BoundsInferGraph[mapKey].Bounds; + clampedBounds = BoundsInferGraph[mapKey].ClampedBounds; + } + + if (ping_pong) + { + throw new NotSupportedException(); + } + + region = new BufferRegionView(new[] { ddr_buffer }, bounds, clampedBounds, mapKey); + VarToKeyMap.Add(v, mapKey); + } + else + { + region = KeyToViewMap[old_map_key]; + } + + break; + } + + case None none: + region = BufferRegionView.None(mapKey); + break; + default: + throw new NotSupportedException(); + } + + KeyToViewMap.Add(mapKey, region); + return T.Nop(); + } + + /// + /// promote 的逻辑, 根据值选择移动当前的buffer开在哪个循环. + /// -1 表示在所有循环之外 + /// 0 表示在N循环内 + /// 3 表示在W循环内. + /// + /// + /// 上一级传入的key. + /// call. + /// op. + /// block_name. + /// prefix. + /// promote. + /// is enable soft pipe line. + /// . + protected virtual ISequentialBuilder LowerGnneLoad(IndexMapKey parentKey, Call call, GNNELoad op, string block_name, string prefix, int? promote, bool softPipeLine) + { + var call_input = IndexMapKey.Create(call, GNNELoad.Input); + var call_deq = IndexMapKey.Create(call, GNNELoad.DeqParams); + + var seq = T.Sequential().Body( + Visit(call_deq, prefix, promote), + GetBufferRegion(call_input, out var ddr_ld_input, name: prefix + "." + TileNames.DdrInput, promote: promote), // loadif 的输入可能来自于const或输入 + GetBufferRegion(call_deq, out var glb_ld_qarg_input), // 只有promote到n循环外时,才不进行ping pong. + GetBufferRegion(parentKey, out var glb_ld_output, promote == -1 ? false : TileOptions.PingPong, name: prefix, promote: promote)); // load if 要用的glb buffer + + var block = EAction.TileBlock(block_name). + Alloc(promote is null ? glb_ld_output.Buffers : None.Default). + Reads(ddr_ld_input.BufferRegions, glb_ld_qarg_input.BufferRegions). + Writes(glb_ld_output.BufferRegions). + Predicate(true).// todo 这里先不做局部加载, 后面再实现 + Body(// promote这里load用的是full region, 但是在字典中存的还应该是partial的, 因为后面是每个glb的tile在使用. + softPipeLine ? + (TileOptions.PingPong & (promote != -1) ? K510.PingPongSlot(block_name, glb_ld_output.LoopCount / TileOptions.PingPongNum, glb_ld_output.LoopCount % TileOptions.PingPongNum) : T.Nop()) : + T.Nop(), + EAction.LoadT(ddr_ld_input, glb_ld_output, glb_ld_qarg_input, op.DeqAxis)); + + if (promote is int promoteIndex) + { + // 如果promote, 那么在这个循环的所有block外执行 + NestedBlocks[promoteIndex + 1].Init(block); + NestedBlocks[promoteIndex + 1].Alloc(glb_ld_output.Buffers); + } + else + { + seq.Body(block); + } + + return seq; + } + + protected virtual ISequentialBuilder LowerGnneMeshNet(IndexMapKey parentKey, Call call, GNNEMeshNet target, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNEMeshNet)); + var call_in_a = IndexMapKey.Create(call, GNNEMeshNet.InputA); + var call_in_b = IndexMapKey.Create(call, GNNEMeshNet.InputB); + var call_in_seg0 = IndexMapKey.Create(call, GNNEMeshNet.SegFittingParam0); + var call_in_seg1 = IndexMapKey.Create(call, GNNEMeshNet.SegFittingParam1); + var seq = T.Sequential().Body( + Visit(call_in_a, prefix), + Visit(call_in_b, prefix), + GetBufferRegion(call_in_a, out var meshnet_input_a), + GetBufferRegion(call_in_b, out var meshnet_input_b), + GetBufferRegion(call_in_seg0, out var meshnet_input_seg0), + GetBufferRegion(call_in_seg1, out var meshnet_input_seg1), + GetBufferRegion(parentKey, out var meshnet_output, TileOptions.PingPong, name: prefix), + EAction.TileBlock(block_name). + Alloc(meshnet_output.Buffers). + Reads( + meshnet_input_a.BufferRegions, + meshnet_input_b.BufferRegions, + meshnet_input_seg0.BufferRegions, + meshnet_input_seg1.BufferRegions). + Body( + EAction.MeshNetCompute( + (Fusion)call[GNNEMeshNet.MeshFunc], + meshnet_input_a, + meshnet_input_b, + meshnet_input_seg0, + meshnet_input_seg1, + meshnet_output))); + + if (!(call[GNNEMeshNet.InputB] is None && call[GNNEMeshNet.NewShape] is None && call[GNNEMeshNet.SegFittingParam0] is None && call[GNNEMeshNet.SegFittingParam1] is None && !TileUtilities.MeshFuncHasConstants((Fusion)call[GNNEMeshNet.MeshFunc]))) + { + foreach (var item in meshnet_output.Buffers) + { + item.Metadata = new TileMetadata() { StrideByShape = true }; + } + } + + return seq; + } + + protected virtual ISequentialBuilder LowerGnneStore(Call call, GNNEStore op, string block_name, string prefix, bool promoteQarg = true) + { + prefix = NameAllocator.Get(nameof(GNNEStore)); + var cropPadding = ((TensorConst)call[GNNEStore.CropPadding]).Value.Cast(); + var channel = call.CheckedShape[1].FixedValue; + bool is_quant_by_channel = false; + if (call[GNNEStore.QuantParams] is Call { Target: GNNELoad } l_qarg && l_qarg[GNNELoad.Input] is TensorConst qarg) + { + _ = qarg.Value.Cast(); + if (qarg[0] != qarg[channel - 1]) + { + is_quant_by_channel = true; + } + } + + var outputShape = call.CheckedShape.ToValueArray(); + T.PhysicalBuffer(call.CheckedDataType, MemoryLocation.Output, outputShape, out var ddr_st_buffer, name: prefix + ".ddr_buffer"); + var bounds = BoundsInferGraph[call].Bounds; + + var (paddingHBefore, paddingHafter) = TileUtilities.ComputePadding(bounds[2] - cropPadding[0, 0], outputShape[2]); + var (paddingWBefore, paddingWafter) = TileUtilities.ComputePadding(bounds[3] - cropPadding[1, 0], outputShape[3]); + + var newBounds = bounds.ToArray(); + newBounds[2] = newBounds[2] - cropPadding[0, 0]; + newBounds[3] = newBounds[3] - cropPadding[1, 0]; + var ddrRegion = TIRUtilities.ClampBounds(newBounds, outputShape); + var ddr_st_output = new BufferRegionView(new[] { ddr_st_buffer }, BoundsInferGraph[call].Bounds, ddrRegion, call); + KeyToViewMap.Add(call, ddr_st_output); + + var call_in = IndexMapKey.Create(call, GNNEStore.Input); + var call_qarg = IndexMapKey.Create(call, GNNEStore.QuantParams); + return T.Sequential().Body( + Visit(call_in, prefix), + Visit(call_qarg, prefix, promoteQarg ? -1 : null), // 多层是只按h切, 此时oc满的,默认promote. + GetBufferRegion(call_in, out var glb_st_input), + GetBufferRegion(call_qarg, out var glb_st_qarg_input), + EAction.TileBlock(block_name).Reads(glb_st_input.BufferRegions, glb_st_qarg_input.BufferRegions).Body( + EAction.StoreT(ddr_st_output, glb_st_input[.., .., (glb_st_input.Region[2].Start + paddingHBefore, glb_st_input.Region[2].Stop - paddingHafter), (glb_st_input.Region[3].Start + paddingWBefore, glb_st_input.Region[3].Stop - paddingWafter)], glb_st_qarg_input, null, is_quant_by_channel))); + } + + protected virtual ISequentialBuilder LowerGnneReduce(IndexMapKey parentKey, Call call, GNNEReduce op, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNEReduce)); + var reduce_in = IndexMapKey.Create(call, GNNEReduce.Input); + var seq = T.Sequential().Body( + Visit(reduce_in, prefix), + GetBufferRegion(reduce_in, out var gnne_reduce_input), + GetBufferRegion(parentKey, out var gnne_reduce_output, TileOptions.PingPong, name: prefix), + EAction.TileBlock(block_name). + Alloc(gnne_reduce_output.Buffers). + Reads(gnne_reduce_input.BufferRegions).Body( + EAction.Reduce( + gnne_reduce_input, + gnne_reduce_output, + call[GNNEReduce.InitValue], + op.ReduceOp, + op.ReduceDim))); + + return seq; + } + + /// + /// 对于dw卷积来说,每个ic对应一个oc, 因此让每个tcu计算一半的if. + /// + protected ITileBlockBuilder GNNEConv2DSharedNone(Call call, string block_name, BufferRegionView glb_w, BufferRegionView glb_if, BufferRegionView glb_act, BufferRegionView glb_psum, BufferRegionView glb_of, bool is_init_psum, string prefix, int? promote = null) + { + var init_psums = GetInitPSumBufferRegion(call, IndexMapKey.Create(call, GNNEConv2D.PSum), glb_psum, promote, prefix, 1, ExtCompilerServices.Env.TcuActNum, out var part_condition); + + var block = EAction.TileBlock(block_name).Reads(glb_w.BufferRegions, is_init_psum ? init_psums[0].BufferRegions.Concat(init_psums[1].BufferRegions.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : glb_psum.BufferRegions, glb_if.BufferRegions, glb_act.BufferRegions).Writes(glb_of.BufferRegions).Body( + T.Unrolled(out var kh, new(0, glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight)).Body( + T.Unrolled(out var kw, new(0, glb_w.Dimensions[3], ExtCompilerServices.Env.PuKernelSpad)).Body( + T.Let(out var m_once, 1).Body( + T.Let(out var c_once, MathF.Select(MathF.Equal(m_once, glb_w.RegionSize(0)), MathF.Min(MathF.Min(ExtCompilerServices.Env.PuWidth / m_once, ExtCompilerServices.Env.PuHeight / MathF.Min(glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight)), glb_w.RegionSize(0)), 1)).Body(// note 我这里没有实现dw卷积的多输出channel的, 默认都是 1 ic : 1 oc. + T.Let(out var tcu_oc_chunk, TileUtilities.Split(glb_of.RegionSize(1), ExtCompilerServices.Env.TcuActNum)).Body(// 1. determine tcu act num + T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(1), tcu_oc_chunk)).Body(// 2. broadcast action + EAction.TcuDmBroadCast(TcuDivideStrategy.NoShare), + T.Unrolled(out var tcu_oc, new(glb_of.Region[1].Start, glb_of.Region[1].Stop, tcu_oc_chunk)).Body(// 3. loop over tcus and config each tcu + T.Let(out var m_once_tcu, 1).Body( + T.Let(out var c_once_tcu, MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc)).Body( + EAction.TcuPuConfAct( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_act, out _), + call[GNNEConv2D.FusedClamp][0], + call[GNNEConv2D.FusedClamp][1]), + EAction.TcuPuConf( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_if, out _), + glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], + IR.F.Math.Min(glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight), + IR.F.Math.Min(glb_w.Dimensions[3], ExtCompilerServices.Env.PuKernelSpad), + m_once: m_once_tcu, + c_once: c_once_tcu, + groups: 1, + mode: TcuComputeMode.DwConv2d), + EAction.TcuDmConfOf(// todo 这里hardcode两个tcu, 后面需要改进 + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + is_init_psum ? MathF.Select(MathF.Equal(tcu_oc, glb_of.Region[1].Start), init_psums[0][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..], init_psums[1][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..]) : GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_psum, out _), + glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], + 0), + EAction.TcuDmConfIf( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_if, out var if_paddings), + stride_w: call[GNNEConv2D.Stride][1], + stride_h: call[GNNEConv2D.Stride][0], + input_c_pre_pu: c_once_tcu, + dilation_h: call[GNNEConv2D.Dilation][0], + padding_top: if_paddings[2].Before, + padding_bottom: if_paddings[2].After, + padding_left: if_paddings[3].Before, + padding_right: if_paddings[3].After), + EAction.TcuDmConfW( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _)), + EAction.TcuDmFetchW(// 4. fetch weights + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _)), + EAction.TcuDmFetchIf(// 5. loop over tcus and fetch if for each tcu + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_if, out _))))), + EAction.TcuPuCompute(// 6. pu compute NOTE 这里我没有在weight的kh和kw上切,所以默认都是一次算完的 + TileUtilities.GetNTcuIndexBits(n_active_tcu), + true, + true, + call[GNNEConv2D.PSum] is not Call { Target: Uninitialized }, + TileUtilities.GetNTcuIndexBits(n_active_tcu))))))))); + + if (promote is null) + { + block.Alloc(glb_of.Buffers, is_init_psum ? init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : None.Default); + } + else if (promote is int promoteIndex) + { + if (is_init_psum) + { + NestedBlocks[promoteIndex + 1].Alloc(init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray()); + } + } + + return block; + } + + protected virtual BufferRegionView[] GetInitPSumBufferRegion(Call call, IndexMapKey key, BufferRegionView glb_psum, int? promote, string prefix, int split_axis, int tcuActNum, out Expr part_condition) + { + var chunk = TileUtilities.Split(glb_psum.Dimensions[split_axis], tcuActNum); + part_condition = TileUtilities.SplitTimes(glb_psum.Dimensions[split_axis], chunk); + var views = new BufferRegionView[2]; + var psum_dimensions = glb_psum.Dimensions.ToArray(); + psum_dimensions[split_axis] = chunk; + var dimensions = psum_dimensions; + + // build psum a and b + foreach (var (part, i) in new[] { "_a", "_b" }.Select((p, i) => (p, i))) + { + var name = prefix + "." + TileNames.InitPSum + part; + var glb_init_psums = new List(); + if (TileOptions.PingPong) + { + for (int p = 0; p < TileOptions.PingPongNum; p++) + { + glb_init_psums.Add(new LogicalBuffer(name + $"(p{p})", DataTypes.Float32, MemoryLocation.L2Data, dimensions)); + } + } + else + { + glb_init_psums.Add(new LogicalBuffer(name, DataTypes.Float32, MemoryLocation.L2Data, dimensions)); + } + + Expr loopCount; + if (promote is int promoteInt) + { + if (promoteInt != -1) + { + loopCount = K510TIRExtensions.PromotedLoopCount(promoteInt, LoopVars, LoopDomains); + } + else + { + loopCount = LoopCount; + } + } + else + { + loopCount = LoopCount; + } + + views[i] = new BufferRegionView(glb_init_psums, glb_psum.Bounds, psum_dimensions.Select(d => new Range(0, d, 1)), key, loopCount, promote); + } + + return views; + } + + protected virtual Expr GNNEConv2DComputeActEnable(Call call, BufferRegionView glb_w, Expr khStop, Expr kHBounds, Expr kwStop, Expr kWBounds) + { + return MathF.LogicalAnd(MathF.GreaterEqual(khStop, kHBounds), MathF.GreaterEqual(kwStop, kWBounds)); + } + + protected virtual Expr GNNEConv2DComputeOfEnable(Call call, BufferRegionView glb_w, Expr khStop, Expr kHBounds, Expr kwStop, Expr kWBounds) + { + return GNNEConv2DComputeActEnable(call, glb_w, khStop, kHBounds, kwStop, kWBounds); + } + + protected virtual Expr GNNEConv2DComputeLoadPsumEnable(Call call, BufferRegionView glb_w, Expr kh, Expr kw) + { + if (call[IR.K510.GNNEConv2D.PSum] is Call { Target: IR.Buffers.Uninitialized }) + { + return IR.F.Math.LogicalNot(IR.F.Math.LogicalAnd(IR.F.Math.Equal(kh, 0), IR.F.Math.Equal(kw, 0))); + } + + return true; + } + + /// + /// share if 是每个tcu计算一半的oc, 此时他们共享同一个if. + /// + protected ITileBlockBuilder GNNEConv2DSharedIF(Call call, string block_name, BufferRegionView glb_w, BufferRegionView glb_if, BufferRegionView glb_act, BufferRegionView glb_psum, BufferRegionView glb_of, bool is_init_psum, string prefix, int? promote = null) + { + var reGlbIf = GlbReIndex(glb_of, glb_if, out var sub_paddings); + + var init_psums = GetInitPSumBufferRegion(call, IndexMapKey.Create(call, GNNEConv2D.PSum), glb_psum, promote, prefix, 1, ExtCompilerServices.Env.TcuActNum, out var part_condition); + + var block = EAction.TileBlock(block_name).Reads(glb_w.BufferRegions, is_init_psum ? init_psums[0].BufferRegions.Concat(init_psums[1].BufferRegions.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : glb_psum.BufferRegions, glb_if.BufferRegions, glb_act.BufferRegions).Body( + T.Let(out var khChunck, MathF.Min(glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight)).Body( + T.Let(out var kwChunck, MathF.Min(glb_w.Dimensions[3], ExtCompilerServices.Env.PuKernelSpad)).Body( + T.Unrolled(out var kh, new(0, glb_w.Dimensions[2], khChunck)).Body(// 对kernel h/w进行tiling 暂时先不考虑 + T.Unrolled(out var kw, new(0, glb_w.Dimensions[3], kwChunck)).Body( + T.Let(out var tcu_oc_chunk, TileUtilities.Split(glb_of.RegionSize(1), ExtCompilerServices.Env.TcuActNum)).Body(// 1. determine tcu act num + T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(1), tcu_oc_chunk)).Body( + T.If(MathF.Equal(n_active_tcu, 1)).Then(// 3. broadcast action + EAction.TcuDmBroadCast(TcuDivideStrategy.NoShare)).Else( + EAction.TcuDmBroadCast(TcuDivideStrategy.ShareIf)), + EAction.TcuDmConfIf(// 4. conf if + TileUtilities.GetNTcuIndexBits(n_active_tcu), + reGlbIf, + stride_w: call[GNNEConv2D.Stride][1], + stride_h: call[GNNEConv2D.Stride][0], + input_c_pre_pu: MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_if.RegionSize(1)), // todo 这里可能有问题. + dilation_h: call[GNNEConv2D.Dilation][0], + padding_top: sub_paddings[2].Before, + padding_bottom: sub_paddings[2].After, + padding_left: sub_paddings[3].Before, + padding_right: sub_paddings[3].After), + T.Unrolled(out var tcu_oc, new(glb_of.Region[1].Start, glb_of.Region[1].Stop, tcu_oc_chunk)).Body(// 5. loop over tcus and config each tcu + EAction.TcuPuConfAct( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_act, out _), + call[GNNEConv2D.FusedClamp][0], + call[GNNEConv2D.FusedClamp][1]), + EAction.TcuPuConf( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + reGlbIf, // 切oc对于if不影响 + glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], + khChunck, + kwChunck, + m_once: MathF.Min(ExtCompilerServices.Env.PuWidth, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), + c_once: MathF.Min(MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_w.RegionSize(1)), glb_if.RegionSize(1)), + groups: call[GNNEConv2D.Groups], + mode: TcuComputeMode.NormalConv2d), + EAction.TcuDmConfOf(// todo 这里hardcode两个tcu, 后面需要改进 + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + is_init_psum ? MathF.Select(MathF.Equal(tcu_oc, glb_of.Region[1].Start), init_psums[0][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..], init_psums[1][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..]) : GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_psum, out _), + glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], + 0), + EAction.TcuDmConfW( + TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _))), + T.Unrolled(out var tcu_oc2, new(glb_of.Region[1].Start, glb_of.Region[1].Stop, tcu_oc_chunk)).Body( + EAction.TcuDmFetchW(// 6. fetch weights. + TileUtilities.GetTcuIndexBits(tcu_oc2 / tcu_oc_chunk), + GlbReIndex(glb_of[.., (tcu_oc2, MathF.Min(tcu_oc2 + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _))), + EAction.TcuDmFetchIf(// 7. fetch if. + TileUtilities.GetNTcuIndexBits(n_active_tcu), + reGlbIf), + EAction.TcuPuCompute(// 8. tcu compute + TileUtilities.GetNTcuIndexBits(n_active_tcu), + act_enable: GNNEConv2DComputeActEnable(call, glb_w, kh + khChunck, glb_w.Dimensions[2], kw + kwChunck, glb_w.Dimensions[3]), + of_enable: GNNEConv2DComputeOfEnable(call, glb_w, kh + khChunck, glb_w.Dimensions[2], kw + kwChunck, glb_w.Dimensions[3]), + load_psum: GNNEConv2DComputeLoadPsumEnable(call, glb_w, kh, kw), + TileUtilities.GetNTcuIndexBits(n_active_tcu))))))))); + + if (promote is null) + { + block.Alloc(glb_of.Buffers, is_init_psum ? init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : None.Default); + } + else if (promote is int promoteIndex) + { + if (is_init_psum) + { + NestedBlocks[promoteIndex + 1].Alloc(init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray()); + } + } + + return block; + } + + /// + /// 假设oc为32被拆分之后每个tcu只能映射一半, 那么两个tcu共享一份weights, 在ofmap的h上进行切分. + /// + protected ITileBlockBuilder GNNEConv2DSharedW(Call call, string block_name, BufferRegionView glb_w, BufferRegionView glb_if, BufferRegionView glb_act, BufferRegionView glb_psum, BufferRegionView glb_of, bool is_depthwise, bool is_init_psum, string prefix, int? promote = null) + { + var init_psums = GetInitPSumBufferRegion(call, IndexMapKey.Create(call, GNNEConv2D.PSum), glb_psum, promote, prefix, 2, ExtCompilerServices.Env.TcuActNum, out var part_condition); + + var reGlbW = GlbReIndex(glb_of[.., .., .., ..], glb_w, out _); + var (iH, iW) = (glb_if.Dimensions[2], glb_if.Dimensions[3]); + var (kH, kW) = (glb_w.Dimensions[2], glb_w.Dimensions[3]); + var stride = ((TensorConst)call[IR.K510.GNNEConv2D.Stride]).Value.Cast(); + var padding = ((TensorConst)call[IR.K510.GNNEConv2D.Padding]).Value.Cast(); + var dilation = ((TensorConst)call[IR.K510.GNNEConv2D.Dilation]).Value.Cast(); + + var block = EAction.TileBlock(block_name).Reads(glb_w.BufferRegions, is_init_psum ? init_psums[0].BufferRegions.Concat(init_psums[1].BufferRegions.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : glb_psum.BufferRegions, glb_if.BufferRegions, glb_act.BufferRegions).Body( + T.Let(out var khChunck, MathF.Min(kH, ExtCompilerServices.Env.PuHeight)).Body( + T.Let(out var kwChunck, dilation[1] != 1 ? 1 : MathF.Min(kW, ExtCompilerServices.Env.PuKernelSpad)).Body( + T.Unrolled(out var kh, new(0, kH, khChunck)).Body( + T.Unrolled(out var kw, new(0, kW, kwChunck)).Body(// NOTE dw卷积时m once指的是一次ic对应输出多少个oc, 所以默认为1 + T.Let(out var m_once, is_depthwise ? 1 : MathF.Min(ExtCompilerServices.Env.PuWidth, glb_w.RegionSize(0))).Body( + T.Let(out var c_once, is_depthwise ? MathF.Min(MathF.Min(ExtCompilerServices.Env.PuWidth / m_once, ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2)), glb_w.RegionSize(0)) : MathF.Min(MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_w.RegionSize(1)), glb_if.RegionSize(1))).Body(// NOTE dw卷积时, if是按对角线排列的, 所以要小于min(pu w/pu h) + T.Let(out var tcu_oh_chunk, TileUtilities.Split(glb_of.RegionSize(2), ExtCompilerServices.Env.TcuActNum)).Body(// segment tcu h in output_h + T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(2), tcu_oh_chunk)).Body( + T.If(MathF.Equal(n_active_tcu, 1)).Then(// NOTE 这里的psum已经被load好了, 可能到时候会存在psum大小和后续不匹配的问题.// 3. broadcast action + EAction.TcuDmBroadCast(TcuDivideStrategy.NoShare)) + .Else( + EAction.TcuDmBroadCast(TcuDivideStrategy.ShareW)), + EAction.TcuDmConfW(TileUtilities.GetNTcuIndexBits(n_active_tcu), reGlbW[.., .., (kh, MathF.Min(kh + khChunck, kH)), (kw, MathF.Min(kw + kwChunck, kW))]), + T.Unrolled(out var tcu_oh, new(glb_of.Region[2].Start, glb_of.Region[2].Stop, tcu_oh_chunk)).Body(// 4. conf_w action + T.Let(out var tcu_index_bits, TileUtilities.GetTcuIndexBits(tcu_oh / tcu_oh_chunk)).Body(// 5. loop over tcus and config each tcu + EAction.TcuDmConfIf(// conf if + tcu_index_bits, + GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_if, out var if_padding, (2, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kh, IR.F.Math.Min(kh + khChunck, kH), 1), stride[0], padding[0, 0], dilation[0])), (3, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kw, IR.F.Math.Min(kw + kwChunck, kW), 1), stride[1], padding[1, 0], dilation[1]))), + stride_w: stride[1], + stride_h: stride[0], + input_c_pre_pu: MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_if.RegionSize(1)), // todo 这里可能有问题. + dilation_h: call[GNNEConv2D.Dilation][0], + padding_top: if_padding[2].Before, + padding_bottom: if_padding[2].After, + padding_left: if_padding[3].Before, + padding_right: if_padding[3].After), + EAction.TcuPuConfAct(// conf act + tcu_index_bits, + GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_act, out _), + call[GNNEConv2D.FusedClamp][0], + call[GNNEConv2D.FusedClamp][1]), + EAction.TcuPuConf(// conf pu + tcu_index_bits, + GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_if, out var _, (2, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kh, IR.F.Math.Min(kh + khChunck, kH), 1), stride[0], padding[0, 0], dilation[0])), (3, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kw, IR.F.Math.Min(kw + kwChunck, kW), 1), stride[1], padding[1, 0], dilation[1]))), + glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], + khChunck, + kwChunck, + m_once, + c_once, + groups: is_depthwise ? 1 : call[GNNEConv2D.Groups], // NOTE tcu pu conf 的group其实是multiplier的意思,就是一个ic会输出多个oc, 并不是标准conv定义的groups. + mode: is_depthwise ? TcuComputeMode.DwConv2d : TcuComputeMode.NormalConv2d), + EAction.TcuDmConfOf(// conf of + tcu_index_bits, + is_init_psum ? MathF.Select(MathF.Equal(tcu_oh, glb_of.Region[2].Start), init_psums[0][.., .., (0, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop) - tcu_oh), .., ..], init_psums[1][.., .., (0, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop) - tcu_oh), .., ..]) : GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_psum, out _), + glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], + 0))), + EAction.TcuDmFetchW(TileUtilities.GetNTcuIndexBits(n_active_tcu), reGlbW[.., .., (kh, MathF.Min(kh + khChunck, kH)), (kw, MathF.Min(kw + kwChunck, kW))]), + T.Unrolled(out var tcu_oh2, new(glb_of.Region[2].Start, glb_of.Region[2].Stop, tcu_oh_chunk)).Body(// 6. fetch weights + EAction.TcuDmFetchIf( + TileUtilities.GetTcuIndexBits(tcu_oh2 / tcu_oh_chunk), // 7. loop over tcus and fetch if for each tcu + GlbReIndex(glb_of[.., .., (tcu_oh2, MathF.Min(tcu_oh2 + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_if, out _, (2, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kh, IR.F.Math.Min(kh + khChunck, kH), 1), stride[0], padding[0, 0], dilation[0])), (3, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kw, IR.F.Math.Min(kw + kwChunck, kW), 1), stride[1], padding[1, 0], dilation[1]))))), + EAction.TcuPuCompute(// 8. tcu compute. + TileUtilities.GetNTcuIndexBits(n_active_tcu), + GNNEConv2DComputeOfEnable(call, glb_w, kh + khChunck, kH, kw + kwChunck, kW), + GNNEConv2DComputeActEnable(call, glb_w, kh + khChunck, kH, kw + kwChunck, kW), + GNNEConv2DComputeLoadPsumEnable(call, glb_w, kh, kw), + TileUtilities.GetNTcuIndexBits(n_active_tcu))))))))))); + + if (promote is null) + { + block.Alloc(glb_of.Buffers, is_init_psum ? init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : None.Default); + } + else if (promote is int promoteIndex) + { + if (is_init_psum) + { + NestedBlocks[promoteIndex + 1].Alloc(init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray()); + } + } + + return block; + } + + protected virtual ISequentialBuilder LowerGnneConv2D(IndexMapKey parentKey, Call call, GNNEConv2D op, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNEConv2D)); + bool is_depthwise; + { + var groups = ((TensorConst)call[GNNEConv2D.Groups]).Value.ToScalar(); + var input_channels = call[GNNEConv2D.Input].CheckedShape[1].FixedValue; + var output_channels = call.CheckedShape[1].FixedValue; + is_depthwise = input_channels == output_channels && output_channels == groups && groups != 1; + } + + if (is_depthwise) + { + prefix = prefix + "(dw)"; + block_name += "(dw)"; + } + + var call_w = IndexMapKey.Create(call, GNNEConv2D.Weights); + var call_in = IndexMapKey.Create(call, GNNEConv2D.Input); + var call_act = IndexMapKey.Create(call, GNNEConv2D.Act); + var call_psum = IndexMapKey.Create(call, GNNEConv2D.PSum); + + bool is_init_psum = call_psum.Expr is Call { Target: Uninitialized }; + + TcuDivideStrategy tcu_strategy; + if (!is_depthwise) + { + // 优先让每个tcu的width用满 + var out_shape = call.CheckedShape.ToValueArray(); + if (out_shape[1] >= ExtCompilerServices.Env.PuWidth * ExtCompilerServices.Env.TcuActNum) + { + tcu_strategy = TcuDivideStrategy.ShareIf; + } + else + { + tcu_strategy = TcuDivideStrategy.ShareW; + } + } + else + { // TODO 需要一种量化的方法来决定dw卷积用什么策略. + tcu_strategy = TcuDivideStrategy.NoShare; + } + + prefix = prefix + "." + tcu_strategy; + + // 默认是layer group的做法, 也就是w/act全部promote + return T.Sequential().Body( + Visit(call_w, prefix, -1), + Visit(call_in, prefix), + Visit(call_act, prefix, -1), + Visit(call_psum, prefix), + GetBufferRegion(call_w, out var glb_w), + GetBufferRegion(call_in, out var glb_if), // glb if 存在padding的情况. + GetBufferRegion(call_act, out var glb_act), + GetBufferRegion(call_psum, out var glb_psum, TileOptions.PingPong, name: prefix + "." + GNNEConv2D.PSum.Name), // note 这里的pusm申请了但不记录到allocs中,仅用于给psum apart使用. + GetBufferRegion(parentKey, out var glb_of, TileOptions.PingPong, name: prefix + "." + TileNames.Output), + tcu_strategy switch { TcuDivideStrategy.ShareIf => GNNEConv2DSharedIF(call, block_name, glb_w, glb_if, glb_act, glb_psum, glb_of, is_init_psum, prefix), TcuDivideStrategy.ShareW => GNNEConv2DSharedW(call, block_name, glb_w, glb_if, glb_act, glb_psum, glb_of, is_depthwise, is_init_psum, prefix), TcuDivideStrategy.NoShare => GNNEConv2DSharedNone(call, block_name, glb_w, glb_if, glb_act, glb_psum, glb_of, is_init_psum, prefix), _ => throw new NotSupportedException(), }); + } + + protected virtual ISequentialBuilder LowerGnneTranspose(IndexMapKey parentKey, Call call, GNNETranspose op, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNETranspose)); + var call_in = IndexMapKey.Create(call, GNNETranspose.Input); + var seq = T.Sequential().Body( + Visit(call_in, prefix), GetBufferRegion(call_in, out var glb_trans_input), GetBufferRegion(parentKey, out var glb_trans_output, TileOptions.PingPong, name: prefix), EAction.TileBlock(block_name).Alloc(glb_trans_output.Buffers).Reads(glb_trans_input.BufferRegions).Body(EAction.MfuTranspose(glb_trans_input, glb_trans_output, op.Perm))); + + return seq; + } + + protected virtual ISequentialBuilder LowerGnneCrop(IndexMapKey parentKey, Call call, GNNECrop op, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNECrop)); + var call_in = IndexMapKey.Create(call, GNNECrop.Input); + var call_in_bbox = IndexMapKey.Create(call, GNNECrop.InputBBox); + var seq = T.Sequential().Body( + Visit(call_in, prefix), + Visit(call_in_bbox, prefix), + GetBufferRegion(call_in, out var glb_crop_input), + GetBufferRegion(call_in_bbox, out var glb_crop_bbox), + GetBufferRegion(parentKey, out var glb_crop_output, TileOptions.PingPong, name: prefix), + EAction.TileBlock(block_name).Alloc(glb_crop_output.Buffers). + Reads(glb_crop_input.BufferRegions, glb_crop_bbox.BufferRegions). + Body( + EAction.MfuCrop( + glb_crop_input, + glb_crop_output, + glb_crop_bbox, + op.ResizeMethod, + op.AlignMethod, + op.HalfPixelCenters))); + + return seq; + } + + protected virtual ISequentialBuilder LowerGnneActivation(IndexMapKey parentKey, Call call, GNNEActivation op, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNEActivation)); + var fusedclamps = ((TensorConst)call[GNNEActivation.FusedClamp]).Value.Cast(); + var call_in = IndexMapKey.Create(call, GNNEActivation.Input); + var call_in_act = IndexMapKey.Create(call, GNNEActivation.Act); + + var seq = T.Sequential().Body( + Visit(call_in, prefix), + Visit(call_in_act, prefix), + GetBufferRegion(call_in, out var glb_if), + GetBufferRegion(call_in_act, out var glb_act), + GetBufferRegion(parentKey, out var glb_of, TileOptions.PingPong, name: prefix), + EAction.TileBlock(block_name).Alloc(glb_of.Buffers).Reads(glb_if.BufferRegions, glb_act.BufferRegions).Body( + T.Let(out var m_once, 1).Body( + T.Let(out var c_once, MathF.Min(glb_if.RegionSize(1), ExtCompilerServices.Env.TcuActNum)).Body( + T.Let(out var tcu_h_chunk, TileUtilities.Split(glb_of.RegionSize(2), ExtCompilerServices.Env.TcuActNum)).Body(// segment tcu h in output_h + T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(2), tcu_h_chunk)).Body( + T.Unrolled(out var tcu_oh, new(glb_of.Region[2].Start, glb_of.Region[2].Stop, tcu_h_chunk)).Body( + T.Let(out var tcu_index_bits, TileUtilities.GetTcuIndexBits(tcu_oh / tcu_h_chunk)).Body( + EAction.TcuPuConfAct(// 1. conf act + tcu_index_bits, + glb_act, + fusedclamps[0], + fusedclamps[1]), + EAction.TcuPuConf( + tcu_index_bits, + glb_if[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_if.Region[2].Stop)), ..], + glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_of.Region[2].Stop)), ..], + 1, + 1, + m_once, + c_once, + 1, + TcuComputeMode.Activation), + EAction.TcuDmConfOf( + tcu_index_bits, + glb_if[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_if.Region[2].Stop)), ..], + glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_of.Region[2].Stop)), ..], + 0))), + EAction.TcuPuComputeDummy(TileUtilities.GetNTcuIndexBits(n_active_tcu), true))))))); + return seq; + } + + protected virtual ISequentialBuilder LowerGnnePdpReduce(IndexMapKey parentKey, Call call, GNNEPdpReduce op, string block_name, string prefix) + { + prefix = NameAllocator.Get(nameof(GNNEPdpReduce)); + var call_in = IndexMapKey.Create(call, GNNEPdpReduce.Input); + + // var ddr_if = BoundsInferGraph[call_in]; + // GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_of.Region[2].Stop)), ..], glb_if, out var if_paddings) + var seq = T.Sequential().Body( + Visit(call_in, prefix), + GetBufferRegion(call_in, out var glb_if), + GetBufferRegion(parentKey, out var glb_of, TileOptions.PingPong, name: prefix)); + GlbReIndex(glb_of, glb_if, out var sub_paddings); + seq.Body( + EAction.TileBlock(block_name).Alloc(glb_of.Buffers).Reads(glb_if.BufferRegions).Body( + EAction.PdpReduce( + glb_if, + glb_of, + call[GNNEPdpReduce.Filter], + call[GNNEPdpReduce.Stride], + sub_paddings[2].Before, + sub_paddings[2].After, + sub_paddings[3].Before, + sub_paddings[3].After, + op.ReduceOp))); + return seq; + } + + protected virtual ISequentialBuilder LowerGnneConv2DTranspose(IndexMapKey parentKey, Call call, GNNEConv2DTranspose op, string block_name, string prefix) + { + throw new NotSupportedException(); + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs new file mode 100644 index 0000000000..a3a4380d74 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs @@ -0,0 +1,251 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.TIR.CPU; + +namespace Nncase.Passes.Tile; + +/// +/// the multi Fusion checker. +/// +internal sealed class MultiFusionChecker : IFusionChecker +{ + private readonly List<(MultiLayerFusionConverter, int[], BufferSchedule.ScheduledResponse)> _caches = new(); + + private readonly TileOptions _tileOptions; + + public MultiFusionChecker(TileOptions tileOptions) + { + _tileOptions = tileOptions; + } + + [Flags] + public enum DeviceKind + { + Load, + Store, + Mfu, + Tcu, + None, + } + + public TIR.PrimFunction Convert(RunPassContext passOptions) + { + var (convertVisitor, final_tile_size, response) = _caches.First(); + if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) + { + response.Dump($"{response.LogicalPrimfunc.Name}_{string.Join("_", final_tile_size)}", convertVisitor.GetType().Name); + } + + return convertVisitor.BuildPhysicalPrimFunc(final_tile_size, response.SchedCandidate, response.LogicalPrimfunc); + } + + public bool Check(Fusion fusion, RunPassContext passOptions) + { + // 1. check all conv2d weights size less than glb size + var visitor = new MultiFusionPreCheckVisitor(); + visitor.Visit(fusion); + if (visitor.AllWeightSizeInBytes > ExtCompilerServices.Env.GlbSize) + { + return false; + } + + // note not support conv2d transpose in layer group. + if (visitor.CountCallOp() > 0) + { + return false; + } + + var curTileOptions = _tileOptions; + if ((visitor.DeviceUsage[DeviceKind.Mfu], visitor.DeviceUsage[DeviceKind.Tcu]) switch + { + (> 1, > 1) => true, + (> 1, 1) => true, + (1, > 1) => true, + _ => false, + }) + { + curTileOptions = curTileOptions with { PingPongNum = 3 }; + } + + // 2. try convert + var convertVisitor = new MultiLayerFusionConverter(curTileOptions); // note the grouped fusion must pingpong input. + var bodySeq = convertVisitor.Visit(fusion); + + // 3. search the tile size + var originLogicalPrimFunc = convertVisitor.BuildLogicalPrimFunc(bodySeq); + + var output_shape = fusion.Body.CheckedShape.ToValueArray(); + var search_space = convertVisitor.BoundsInferGraph.RootTileStep.ToArray(); + var candidate_tile_size = convertVisitor.SearchTileSize(TileOhSearchGenerator(curTileOptions, search_space, convertVisitor.BoundsInferGraph.RootPerm.ToArray()), originLogicalPrimFunc, curTileOptions.MultiWorkers, false, out var sched_response); + if (!candidate_tile_size.Any()) + { + return false; + } + + int[] final_tile_size = new int[candidate_tile_size.Length]; + if (convertVisitor.BalanceTileSize(candidate_tile_size, search_space)) + { + final_tile_size = convertVisitor.SearchTileSize(new TargetTileGenerator(candidate_tile_size), originLogicalPrimFunc, curTileOptions.MultiWorkers, true, out sched_response); + } + else + { + Array.Copy(candidate_tile_size, final_tile_size, candidate_tile_size.Length); + } + + // 5. check the input load usage and compute overlap + var input_shape = fusion.Parameters[0].CheckedShape.ToValueArray(); + var each_axis_tile_nums = final_tile_size.Zip(output_shape).Select(p => (int)System.Math.Ceiling(p.Second / (float)p.First)).ToArray(); + var total_tile_nums = TensorUtilities.GetProduct(each_axis_tile_nums); + if (total_tile_nums > 1) + { + var clamp = (TIR.K510.Segment seg, int i) => + { + return new TIR.K510.Segment(Math.Max(0, seg.Start), Math.Min(input_shape[i], seg.Stop), 1); + }; + + var first_segment = convertVisitor.BoundsInferGraph[convertVisitor.VarToKeyMap[fusion.Parameters[0]]]. + Eval(final_tile_size.Select(t => new TIR.K510.Segment(0, t, 1)).ToArray()). + Select((s, i) => clamp(s, i)). + ToArray(); + + int first_split_axis = 0; + for (int i = input_shape.Length - 1; i >= 0; i--) + { + if (first_segment[i].Length != input_shape[i]) + { + first_split_axis = i; + break; + } + } + + // when once load less than load burst, false + var burst_load_data = TensorUtilities.GetProduct(first_segment.Skip(first_split_axis).Select(s => s.Length).ToArray()); + if (burst_load_data < ExtCompilerServices.Env.LoadBurst) + { + return false; + } + + var second_segment = convertVisitor.BoundsInferGraph[convertVisitor.VarToKeyMap[fusion.Parameters[0]]].Eval(final_tile_size.Select((t, i) => + t < output_shape[i] ? + new TIR.K510.Segment(t, System.Math.Min(t * 2, output_shape[i]), 1) : + new TIR.K510.Segment(0, t, 1)).ToArray()). + Select((s, i) => clamp(s, i)). + ToArray(); + + // Todo 因为我无法知道在当前维度切分会影响哪个维度的变化, 比如带有transpose的, 可能我在c上切,只影响 h w. 所以直接计算所有的的交集 + var overlaps = first_segment.Zip(second_segment).Select(p => p.First.Intersect(p.Second)).ToArray(); + if (Array.IndexOf(convertVisitor.BoundsInferGraph.RootPerm.ToArray(), TIR.K510.NamedAxis.H) is int h && h != -1) + { + // 如果只在h上切分, 只需要考虑h上的overlap有没有超过0.3 + if (overlaps[h] > (input_shape[h] * 0.3)) + { + return false; + } + } + else + { + if (TensorUtilities.GetProduct(overlaps) > TensorUtilities.GetProduct(input_shape) * 0.3) + { + return false; + } + } + } + + _caches.Add((convertVisitor, candidate_tile_size, sched_response)); + if (_caches.Count > 1) + { + _caches.RemoveAt(0); + } + + return true; + } + + /// + /// + /// 只在oh上切分. + /// + /// + private ISearchTileGenerator TileOhSearchGenerator(TileOptions tileOptions, Segment[] search_spaces, TIR.K510.NamedAxis[] rootPerm) + { + if (Array.IndexOf(rootPerm, TIR.K510.NamedAxis.C) is int c && c != -1) + { + search_spaces[c].Start = search_spaces[c].Stop; // not tile oc + } + + if (Array.IndexOf(rootPerm, TIR.K510.NamedAxis.H) is int h && h != -1) + { + // 因为在h上切分, 如果ping pong那么需要限制大小 + if (tileOptions.PingPong) + { + if (search_spaces[h].ClampStop(2, out var new_h_seg)) + { + // assume tile h must > 8 for tcu use. + search_spaces[h] = new Segment(Math.Min(Math.Max(search_spaces[h].Step, ExtCompilerServices.Env.TcuActNum), new_h_seg.Stop), new_h_seg.Stop, new_h_seg.Step); + } + } + } + + if (Array.IndexOf(rootPerm, TIR.K510.NamedAxis.W) is int w && w != -1) + { + search_spaces[w].Start = Math.Min(search_spaces[w].Stop, Math.Max(search_spaces[w].Step, ExtCompilerServices.Env.PuWidth)); // no tile w + } + + // 如果有perm, 那就是 c w h n 方式搜, 没有perm就是从最后搜到最前 所以在有ping pong的时候需要限制切分. + if (rootPerm.All(r => r == NamedAxis.UnKnow) && tileOptions.PingPong) + { + if (search_spaces[^1].ClampStop(2, out var new_seg)) + { + search_spaces[^1] = new_seg; + } + } + + return new DefaultSearchTileGenerator(search_spaces, rootPerm); + } + + internal sealed class MultiFusionPreCheckVisitor : ExprVisitor + { + public Dictionary DeviceUsage { get; } = new() + { + { DeviceKind.Load, 0 }, + { DeviceKind.Store, 0 }, + { DeviceKind.Mfu, 0 }, + { DeviceKind.Tcu, 0 }, + { DeviceKind.None, 0 }, + }; + + public int AllWeightSizeInBytes { get; private set; } + + public int CountCallOp() + where T : Op + { + return ExprMemo.Keys.Count(e => e is Call { Target: Op t } && t.GetType() == typeof(T)); + } + + protected override bool DefaultVisitLeaf(Expr expr) => true; + + protected override bool VisitLeafCall(Call expr) + { + if (expr is Call { Target: IR.K510.GNNEConv2D } && expr[IR.K510.GNNEConv2D.Weights] is Expr weights) + { + AllWeightSizeInBytes += weights.CheckedShape.Prod().FixedValue * weights.CheckedDataType.SizeInBytes; + } + + DeviceUsage[GetDeviceType(expr.Target)]++; + return true; + } + + private static DeviceKind GetDeviceType(Expr op) => op switch + { + IR.K510.GNNELoad => DeviceKind.Load, + IR.K510.GNNEStore => DeviceKind.Store, + IR.K510.GNNEConv2D or IR.K510.GNNEActivation => DeviceKind.Tcu, + IR.K510.GNNEReduce or IR.K510.GNNEMeshNet or IR.K510.GNNETranspose or IR.K510.GNNEPdpReduce or IR.K510.GNNECrop => DeviceKind.Mfu, + _ => DeviceKind.None, + }; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs new file mode 100644 index 0000000000..4a07a4a17c --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs @@ -0,0 +1,228 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.BufferSchedule; +using Nncase.TIR; +using Nncase.TIR.Builders; +using Nncase.TIR.K510; +using Nncase.TIR.K510.Builders; +using MathF = Nncase.IR.F.Math; + +namespace Nncase.Passes.Tile; + +internal class MultiLayerFusionConverter : LayerFusionConverter +{ + public MultiLayerFusionConverter(TileOptions tileOptions) + { + TileOptions = tileOptions; + } + + public override Fusion CurrentFusion { get; protected set; } = null!; + + public override IBoundsInferGraph BoundsInferGraph { get; protected set; } = null!; + + /// + /// Gets or sets calc the loop count. + /// + public override Expr LoopCount { get; protected set; } = null!; + + public override Expr LoopCountOuter { get; protected set; } = null!; // LoopCount / TileOptions.PingPongNum; + + public override Expr LoopCountInner { get; protected set; } = null!; // LoopCount % TileOptions.PingPongNum; + + public override Expr Visit(Fusion fusion) + { + if (CurrentFusion is null) + { + CurrentFusion = fusion; + } + else + { + throw new InvalidOperationException("Can't Visit More Than One Fusion!"); + } + + // 0. init the fields + var output_shape = Tile.TileUtilities.GetFusionRealOutputShape(fusion.Body); + BoundsInferGraph = ExtCompilerServices.MakeBoundsInferGraph((Call)fusion.Body); + TileSizeVars.AddRange(output_shape.Select((_, i) => new Var($"dim{i}_tile", new TensorType(DataTypes.Int32, Shape.Scalar)))); + + // 1. make the tile gird loop + NestedBlocks.AddRange(new[] { EAction.TileBlock("MainBlock") }.Concat(Enumerable.Range(0, TileSizeVars.Count).Select(i => EAction.TileBlock($"TileBlock_{i}")))); + + LoopDomains.AddRange(output_shape.Zip(TileSizeVars).Select(t => new TIR.Range(0, t.First, t.Second))); + for (int i = 0; i < TileSizeVars.Count; i++) + { + NestedLoops.Add(T.ForLoop(out var loopVar, LoopDomains[i], LoopMode.Unrolled, $"loop_var_{i}")); + LoopVars.Add(loopVar); + } + + object lastBody = NestedBlocks[^1]; + for (int i = NestedLoops.Count - 1; i >= 0; i--) + { + lastBody = NestedLoops[i].Body(lastBody); + lastBody = NestedBlocks[i].Body(lastBody); + } + + // 2. create the bounds infer function input arguments with the new loop var. + BoundsInferGraph.RootBounds = output_shape.Select((s, i) => + { + var loopVar = LoopVars[i]; + return new TIR.Range(loopVar, IR.F.Math.Min(loopVar + TileSizeVars[i], s), 1); + }).ToList(); + + // 3. set up loop count + var shape = new Expr[LoopVars.Count]; + var upbounds = CurrentFusion.Body.CheckedShape.ToValueArray(); + for (int j = LoopVars.Count - 1; j >= 0; j--) + { + shape[j] = TileUtilities.SplitTimes(upbounds[j], TileSizeVars[j]); + } + + var strides = TensorUtilities.GetStrides(shape).ToArray(); + var indices = LoopVars.Select((v, j) => (Expr)(v / TileSizeVars[j])).ToArray(); + LoopCount = TensorUtilities.GetIndex(strides, indices); + LoopCountOuter = LoopCount / TileOptions.PingPongNum; + LoopCountInner = LoopCount % TileOptions.PingPongNum; + + return Visit((Call)fusion.Body, "root"); + } + + /// + /// convert to the final prim func. + /// + /// . + public PrimFunction VisitToPrimFunc(Fusion fusion) + { + // 1. visit the fusion + var bodySeq = Visit(fusion); + + // 2. build the prim func with tile size vars. + var logicalPrimFunc = BuildLogicalPrimFunc(bodySeq); + + // 3. seach the tiling size + var search_spaces = BoundsInferGraph.RootTileStep.ToArray(); + ISearchTileGenerator tileGenerator; + if (TileOptions.TargetTileSize.Any()) + { + for (int i = 0; i < TileOptions.TargetTileSize.Length; i++) + { + System.Diagnostics.Trace.Assert(TileOptions.TargetTileSize[i] <= search_spaces[i].Stop); + } + + tileGenerator = new TargetTileGenerator(TileOptions.TargetTileSize); + } + else + { + var perm = BoundsInferGraph.RootPerm.ToArray(); + + // when ping pong all, clamp the upper bounds by perm order. + if (TileOptions.PingPong) + { + var re_perm = perm.Zip(Enumerable.Range(0, perm.Length)).OrderBy(t => t.First).Select(t => t.Second).ToArray(); + + var pp_axis = NamedAxis.H; + + // 如果已知维度, 那么在pp axis上进行切分 + if (Array.IndexOf(perm, NamedAxis.H) is var h && h != -1 && Array.IndexOf(perm, NamedAxis.W) is var w && w != -1) + { + // if split the h will less than one burst, split on c. + if ((int)System.Math.Ceiling(search_spaces[h].Stop / (float)TileOptions.PingPongNum) * search_spaces[w].Stop < 128) + { + pp_axis = NamedAxis.C; + } + else + { + pp_axis = NamedAxis.H; + } + } + + for (int i = 0; i < perm.Length; i++) + { + var p = re_perm[i]; + if (perm[i] == pp_axis && search_spaces[p].ClampStop(TileOptions.PingPongNum, out var new_seg)) + { + search_spaces[p] = new_seg; + break; + } + } + } + + { + if (Array.IndexOf(perm, NamedAxis.H) is var h && h != -1 && Array.IndexOf(perm, NamedAxis.W) is var w && w != -1) + { + // if one layer output less than 128, don't split the hw + if (search_spaces[h].Stop * search_spaces[w].Stop < ExtCompilerServices.Env.LoadBurst) + { + search_spaces[h].Start = search_spaces[h].Stop; + search_spaces[w].Start = search_spaces[w].Stop; + } + else + { + search_spaces[h].Start = Math.Min(Math.Max(search_spaces[h].Step, ExtCompilerServices.Env.TcuActNum), search_spaces[h].Stop); + search_spaces[w].Start = Math.Min(Math.Max(search_spaces[w].Step, ExtCompilerServices.Env.PuWidth), search_spaces[w].Stop); + } + } + } + + tileGenerator = new DefaultSearchTileGenerator(search_spaces, BoundsInferGraph.RootPerm); + } + + int[] candidate_tile_size = SearchTileSize(tileGenerator, logicalPrimFunc, TileOptions.MultiWorkers, false, out var response); + if (!candidate_tile_size.Any()) + { + throw new TileFailedException(); + } + + int[] final_tile_size = Array.Empty(); + if (!TileOptions.TargetTileSize.Any() && TileOptions.PingPong && BalanceTileSize(candidate_tile_size, search_spaces)) + { + final_tile_size = SearchTileSize(new TargetTileGenerator(candidate_tile_size), logicalPrimFunc, TileOptions.MultiWorkers, true, out response); + } + else + { + final_tile_size = candidate_tile_size; + } + + if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) + { + response.Dump($"{CurrentFusion.Name}_{string.Join("_", final_tile_size)}", GetType().Name); + } + + // 4. the local logical buffer to phsy buffer + return BuildPhysicalPrimFunc(final_tile_size, response.SchedCandidate, response.LogicalPrimfunc); + } + + /// + /// 1. if inner loop var > half, balance it. + /// 2. find the highest axis loop var == up_bounds, split it. + /// + public override bool BalanceTileSize(int[] tile_size, Segment[] search_spaces) + { + bool changed = false; + + // balance tile + for (int i = search_spaces.Length - 1; i >= 0; i--) + { + if (search_spaces[i].BalanceTile(tile_size[i], out var newTile)) + { + tile_size[i] = newTile; + return true; + } + } + + // force ping pong + for (int i = 0; i < search_spaces.Length; i++) + { + if (search_spaces[i].ClampStop(2, out var new_seg) && new_seg.Stop < tile_size[i]) + { + tile_size[i] = new_seg.Stop; + return true; + } + } + + return changed; + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs new file mode 100644 index 0000000000..1016a0c681 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Nncase.Passes.Tile; + +/// +/// TileOptions. +/// +/// TargetTileSize. +/// ForceFence. +/// 是否进行ping pong. +/// PingPongNum. +/// 对于测试. +/// 是否开启多线程搜索. +public sealed record TileOptions(int[] TargetTileSize, bool ForceFence, bool PingPong, int PingPongNum, bool ForceMultiLayer, bool MultiWorkers) +{ + public static TileOptions Default { get; } = new(Array.Empty(), false, true, 2, false, true); +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs new file mode 100644 index 0000000000..b0a8a858c3 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs @@ -0,0 +1,131 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +#if false +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.PatternMatch; +using Nncase.TIR.K510; +using static Nncase.PatternMatch.Utility; + +[assembly: InternalsVisibleTo("Nncase.Tests.K510")] + +namespace Nncase.Passes.Tile; + +/// +/// the two Fusion checker. +/// 专门前一层可以在conv的ic上tiling的情况. +/// +internal sealed class TwoFusionChecker : IFusionChecker +{ + private readonly List<(LayerFusionConverter, int[], BufferSchedule.ScheduledResponse)> _caches = new(); + + private readonly TileOptions _tileOptions; + + public TwoFusionChecker(TileOptions tileOptions) + { + _tileOptions = tileOptions; + } + + /// + /// Gets 匹配 conv2d + 非reduction. + /// + public static Pattern TwoFusionPattern { get; } = IsCallWildcard( + null, + IsOp(), + IsCallWildcard( + "conv2d", + IsOp(), + IsCallWildcard(null, IsOp("calleeOp", op => op is IR.K510.GNNEMeshNet or IR.K510.GNNEPdpReduce), IsCallWildcard(null, IsOp())))); + + public TIR.PrimFunction Convert(RunPassContext passOptions) + { + var (convertVisitor, final_tile_size, response) = _caches.First(); + if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) + { + response.Dump($"{response.LogicalPrimfunc.Name}_{string.Join("_", final_tile_size)}", nameof(LayerFusionOcIcConverter)); + } + + return convertVisitor.BuildPhysicalPrimFunc(final_tile_size, response.SchedCandidate, response.LogicalPrimfunc); + } + + public bool Check(Fusion fusion, RunPassContext passOptions) + { + // 1. try match pattern + if (!CompilerServices.TryMatchRoot(fusion.Body, TwoFusionPattern, out var matchResult)) + { + return false; + } + + // 2. try convert + var convertVisitor = new LayerFusionOcIcConverter(_tileOptions, TileUtilities.ChoiceTcuStrategy((Call)matchResult["conv2d"], out _), false); // note the grouped fusion must pingpong input. + var bodySeq = convertVisitor.Visit(fusion); + + // 3. search the tile size + var originLogicalPrimFunc = convertVisitor.BuildLogicalPrimFunc(bodySeq); + _ = fusion.Body.CheckedShape.ToValueArray(); + var search_space = convertVisitor.OCBoundsInferGraph.RootTileStep.ToArray(); + search_space = search_space.Concat(new[] { new Segment(1, convertVisitor.Conv2DInShape[1], 1) }).ToArray(); + var candidate_tile_size = convertVisitor.SearchTileSize( + SearchGenerator(search_space), + originLogicalPrimFunc, + _tileOptions.MultiWorkers, + false, + out var sched_response); + if (!candidate_tile_size.Any()) + { + return false; + } + + if (_tileOptions.PingPong && convertVisitor.BalanceTileSize(candidate_tile_size, search_space)) + { + _ = convertVisitor.SearchTileSize(new TargetTileGenerator(candidate_tile_size), originLogicalPrimFunc, _tileOptions.MultiWorkers, true, out sched_response); + } + else + { + } + + _caches.Add((convertVisitor, candidate_tile_size, sched_response)); + if (_caches.Count > 1) + { + _caches.RemoveAt(0); + } + + return true; + } + + /// + /// + /// do not tile on w dimension. + /// + /// + /// . + private ISearchTileGenerator SearchGenerator(Segment[] search_spaces) + { + var newSpaces = search_spaces.ToArray(); + + // 这里就优先在ic上切分ping pong. 因为在ic上切分对于if来说都是不一样的. + var ic = newSpaces.Length - 1; + if (newSpaces[ic].ClampStop(2, out var new_seg)) + { + newSpaces[ic] = new_seg; + } + + // ic最小也得分两个tcu. + newSpaces[ic].Start = Math.Min(ExtCompilerServices.Env.TcuActNum * ExtCompilerServices.Env.PuHeight, newSpaces[ic].Stop); + + var w = 3; + newSpaces[w].Start = newSpaces[w].Stop; // no tile w + + return new QueuedSearchTileGenerator(newSpaces, g => + { + g.Queue.Add((1, System.Math.Min(ExtCompilerServices.Env.PuWidth * ExtCompilerServices.Env.TcuActNum, g.UpperBounds[1]))); // oc + g.Queue.Add((4, g.UpperBounds[4])); // ic + g.Queue.Add((2, g.UpperBounds[2])); // h + g.Queue.Add((1, g.UpperBounds[1])); // oc + }); + } +} +#endif diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index c82f0d24d1..fc7c392e6f 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -7,8 +7,6 @@ using System.Text; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Options; -using Microsoft.VisualBasic; using Nncase.CodeGen; using Nncase.CodeGen.StackVM; using Nncase.IR; diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json index 444b371070..4ed711f699 100644 --- a/modules/Nncase.Modules.CPU/packages.lock.json +++ b/modules/Nncase.Modules.CPU/packages.lock.json @@ -11,6 +11,31 @@ "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, + "Google.OrTools.runtime.linux-arm64": { + "type": "Transitive", + "resolved": "9.4.1874", + "contentHash": "Z46ndZcZa2Lt5b76xU9kxVYbPLg/LfuMufhUVsu3Qo3L7Bibf7WXd9j7RRldjnuv8RIHWTqb0b+2FwwMxs0c5A==" + }, + "Google.OrTools.runtime.linux-x64": { + "type": "Transitive", + "resolved": "9.4.1874", + "contentHash": "zGeDb8FuvP9HXjrsU7krVXtSDFpR+DUGNEsH51k94jL9tzf2vWYI8+WUBRHZ/cGe50dpLr+vIjfcNo3gFyOpkQ==" + }, + "Google.OrTools.runtime.osx-arm64": { + "type": "Transitive", + "resolved": "9.4.1874", + "contentHash": "Wo0ZfDaH6DhiQw0jZm4HWJm/oPGPpWNwOLUz+EYaoH3MLtocSxItHGQj/Ta3HyhXnYNOv+TliAH8L+8RCXu/2w==" + }, + "Google.OrTools.runtime.osx-x64": { + "type": "Transitive", + "resolved": "9.4.1874", + "contentHash": "IAfGgKR1og6vU87axK1d37Ak/4jy8B4NMoElovG/KZc/2UY+cJEAQDA709UMegtI4lBhuxTWFNUiHQYmRIB9yQ==" + }, + "Google.OrTools.runtime.win-x64": { + "type": "Transitive", + "resolved": "9.4.1874", + "contentHash": "fUs5qDnZA6itygolcX6nPuachQkY9CVvQbakIzIiRAWKcaj8umQAbFdGwbkyzp3qp34BKW5mtPVsmMyfQBBjOQ==" + }, "libortki": { "type": "Transitive", "resolved": "0.0.2", @@ -116,6 +141,7 @@ "type": "Project", "dependencies": { "GiGraph.Dot": "[2.0.0, )", + "Google.OrTools": "[9.4.1874, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", @@ -173,6 +199,26 @@ "resolved": "2.0.0", "contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ==" }, + "Google.OrTools": { + "type": "CentralTransitive", + "requested": "[9.4.1874, )", + "resolved": "9.4.1874", + "contentHash": "jqRoI+pYlym+fhoU25u+13oti5h+772bllQ9zDitTVMclDXVTiG6pxzvmYO74wnADBMdpb2SQlgiNQxoNk5dlA==", + "dependencies": { + "Google.OrTools.runtime.linux-arm64": "9.4.1874", + "Google.OrTools.runtime.linux-x64": "9.4.1874", + "Google.OrTools.runtime.osx-arm64": "9.4.1874", + "Google.OrTools.runtime.osx-x64": "9.4.1874", + "Google.OrTools.runtime.win-x64": "9.4.1874", + "Google.Protobuf": "3.19.4" + } + }, + "Google.Protobuf": { + "type": "CentralTransitive", + "requested": "[3.19.4, )", + "resolved": "3.19.4", + "contentHash": "fd07/ykL4O4FhqrZIELm5lmiyOHfdPg9+o+hWr6tcfRdS7tHXnImg/2wtogLzlW2eEmr0J7j6ZrZvaWOLiJbxQ==" + }, "Microsoft.Extensions.Hosting.Abstractions": { "type": "CentralTransitive", "requested": "[6.0.0, )", diff --git a/src/Nncase.Tests/Properties/launchSettings.json b/src/Nncase.Tests/Properties/launchSettings.json index d081379b07..3109588c45 100644 --- a/src/Nncase.Tests/Properties/launchSettings.json +++ b/src/Nncase.Tests/Properties/launchSettings.json @@ -2,7 +2,7 @@ "profiles": { "Nncase.Tests": { "commandName": "Project", - "nativeDebugging": true + "nativeDebugging": false } } } \ No newline at end of file From dc1598dab0ca9482d3b7aa936318bf71ea8af2fb Mon Sep 17 00:00:00 2001 From: huochenghai Date: Fri, 14 Jul 2023 15:31:21 +0800 Subject: [PATCH 004/308] fix build --- .../Passes/CPUFusionToTirPass.cs | 43 ++++++++------- .../Passes/Tile/CPUFusionGroupMutator.cs | 55 ++++++++++--------- .../Passes/Tile/IFusionChecker.cs | 5 +- .../Passes/Tile/LayerFusionConverter.cs | 2 + .../Passes/Tile/MultiFusionChecker.cs | 3 +- .../Passes/Tile/MultiLayerFusionConverter.cs | 4 +- .../Passes/Tile/TileOptions.cs | 5 +- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 2 +- 8 files changed, 67 insertions(+), 52 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs index 128ca15c91..5e1ff781b0 100644 --- a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs +++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs @@ -1,4 +1,8 @@ -using System; +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +#if false +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -9,6 +13,7 @@ using Nncase.Passes.Analysis; using Nncase.Passes.Mutators; using Nncase.Passes.Tile; +using Nncase.Targets; using Nncase.TIR; namespace Nncase.Passes; @@ -32,33 +37,30 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o Dictionary fusionConertedCache = new(ReferenceEqualityComparer.Instance); // convert the fusion as entry. - for (int i = 0; i < module.Functions.Count; i++) - { - if (module.Functions[i] is Fusion { ModuleKind: "cpu" } fusion) - { - TIR.PrimFunction primFunction; - var visitor = new MultiLayerFusionConverter(_tileOptions); - primFunction = visitor.VisitToPrimFunc(fusion); - - CompilerServices.InferenceType(primFunction); - fusionConertedCache[fusion] = primFunction; - module.Replace(i, primFunction); - } - } + // for (int i = 0; i < module.Functions.Count; i++) + // { + // if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion) + // { + // TIR.PrimFunction primFunction; + // var visitor = new MultiLayerFusionConverter(_tileOptions); + // primFunction = visitor.VisitToPrimFunc(fusion); + // + // CompilerServices.InferenceType(primFunction); + // fusionConertedCache[fusion] = primFunction; + // module.Replace(i, primFunction); + // } + // } // convert the stackvm function call k510 fusion for (int i = 0; i < module.Functions.Count; i++) { - if (module.Functions[i] is Function { ModuleKind: "stackvm" } func) + if (module.Functions[i] is Function { ModuleKind: CPUTarget.Kind } func) { - var analysis = new Dictionary - { - [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(func), - }; + var analysis = new Dictionary { [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(func), }; var rewriter = new DataFlowMergeRewriter(); var fusionCheckCache = new Dictionary(ReferenceEqualityComparer.Instance); - var post = (Function)rewriter.Rewrite(func, new Mutators.IMergeRewriteRule[] { new GNNESameInputFusionMergeRule(), }, (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); + // var post = (Function)rewriter.Rewrite(func, new Mutators.IMergeRewriteRule[] { new GNNESameInputFusionMergeRule(), }, (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); // if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) // { @@ -115,3 +117,4 @@ protected override async Task OnPassEndAsync(IRModule post, RunPassContext conte _fusionMacsMap.Clear(); } } +#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs index 4b206ca688..3cd7ab7c9d 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. - +#if false using System.Runtime.CompilerServices; using Nncase.IR; using Nncase.PatternMatch; @@ -9,7 +9,7 @@ namespace Nncase.Passes.Tile; -internal sealed class GNNESameInputFusionMergeRule : Mutators.SameInputFusionMergeRule +internal sealed class CPUSameInputFusionMergeRule : Mutators.SameInputFusionMergeRule { public override string ModuleKind => CPUTarget.Kind; @@ -57,11 +57,11 @@ public CPUFusionGroupMutator( public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet candidateFusions) { // note the gnne activate must be first layer. - if (mergedFusion.Body is Call { Target: IR.K510.GNNEStore } st_call && - st_call[IR.K510.GNNEStore.Input] is Call { Target: IR.K510.GNNEActivation }) - { - return false; - } + // if (mergedFusion.Body is Call { Target: IR.K510.GNNEStore } st_call && + // st_call[IR.K510.GNNEStore.Input] is Call { Target: IR.K510.GNNEActivation }) + // { + // return false; + // } var checker = (IFusionChecker)Activator.CreateInstance(typeof(T), new object[] { _tileOptions })!; var ret = checker.Check(mergedFusion, PassOptions); @@ -103,7 +103,7 @@ public CheckedConvertMutator(Dictionary fusion_converted_c /// protected override Expr RewriteLeafFusion(Fusion expr) { - if (expr is Fusion { ModuleKind: K510Target.Kind } fusion) + if (expr is Fusion { ModuleKind: CPUTarget.Kind } fusion) { if (!_fusionConertedCache.TryGetValue(fusion, out _)) { @@ -114,28 +114,27 @@ protected override Expr RewriteLeafFusion(Fusion expr) } else { - if (CompilerServices.TryMatchRoot(fusion, Conv2DFusionConverter.Conv2DFusionPattern, out var matchResult)) - { - prim_func = Conv2DFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); - } - else if (CompilerServices.TryMatchRoot(fusion, Conv2DTransposeFusionConverter.Conv2DFusionPattern, out matchResult)) - { - prim_func = Conv2DTransposeFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); - } - else if (!_tileOptions.ForceMultiLayer && CompilerServices.TryMatchRoot(fusion, LSTMFusionConverter.FusionPattern, out matchResult)) - { - prim_func = LSTMFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); - } - else - { - var visitor = new MultiLayerFusionConverter(_tileOptions); - prim_func = visitor.VisitToPrimFunc(fusion); - } + // if (CompilerServices.TryMatchRoot(fusion, Conv2DFusionConverter.Conv2DFusionPattern, out var matchResult)) + // { + // prim_func = Conv2DFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); + // } + // else if (CompilerServices.TryMatchRoot(fusion, Conv2DTransposeFusionConverter.Conv2DFusionPattern, out matchResult)) + // { + // prim_func = Conv2DTransposeFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); + // } + // else if (!_tileOptions.ForceMultiLayer && CompilerServices.TryMatchRoot(fusion, LSTMFusionConverter.FusionPattern, out matchResult)) + // { + // prim_func = LSTMFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); + // } + // else + // { + // var visitor = new MultiLayerFusionConverter(_tileOptions); + // prim_func = visitor.VisitToPrimFunc(fusion); + // } } BaseFunction? convert_func = prim_func; _fusionConertedCache.Add(fusion, convert_func); - new DDrMacCalcVisitor(_fusionMacsMap).Visit(fusion); } } @@ -144,7 +143,7 @@ protected override Expr RewriteLeafFusion(Fusion expr) protected override Expr RewriteLeafCall(Call expr) { - if (expr.Target is Fusion { ModuleKind: K510Target.Kind } fusion) + if (expr.Target is Fusion { ModuleKind: CPUTarget.Kind } fusion) { var convert_func = _fusionConertedCache[fusion]; PrimFunctionWrapper wrapper; @@ -185,3 +184,5 @@ protected override Expr RewriteLeafCall(Call expr) return expr; } } + +#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs index 84b2d37072..9c7bd70876 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/IFusionChecker.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs index 8cf139f3fc..f205692a57 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +#if false using System.Reactive; using System.Runtime.CompilerServices; using NetFabric.Hyperlinq; @@ -1260,3 +1261,4 @@ protected virtual ISequentialBuilder LowerGnneConv2DTranspose(IndexM throw new NotSupportedException(); } } +#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs index a3a4380d74..7cbeab67c8 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. - +#if false using System.Collections.Immutable; using System.Runtime.CompilerServices; using Nncase.Diagnostics; @@ -249,3 +249,4 @@ protected override bool VisitLeafCall(Call expr) }; } } +#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs index 4a07a4a17c..a99248cf14 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs @@ -1,13 +1,14 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +#if false using System.Runtime.CompilerServices; using Nncase.Diagnostics; using Nncase.IR; using Nncase.Passes.BufferSchedule; using Nncase.TIR; using Nncase.TIR.Builders; -using Nncase.TIR.K510; +using Nncase.TIR.CPU; using Nncase.TIR.K510.Builders; using MathF = Nncase.IR.F.Math; @@ -226,3 +227,4 @@ public override bool BalanceTileSize(int[] tile_size, Segment[] search_spaces) return changed; } } +#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs index 1016a0c681..ba3368a679 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index fc7c392e6f..6bee77d2a7 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -24,7 +24,7 @@ public class CPUTarget : ITarget /// /// Gets kind. /// - public static readonly string Kind = "cpu"; + public const string Kind = "cpu"; string ITarget.Kind => Kind; From ba3ca85ca017d0d7b098a1d8eeb2c77534351178 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Fri, 14 Jul 2023 15:33:26 +0800 Subject: [PATCH 005/308] basic cpu module builder --- modules/Nncase.Modules.CPU/CodeGen/CSource.cs | 283 ++++++++++++++ .../CodeGen/CSourceVisitor.cs | 365 ++++++++++++++++++ .../CodeGen/FunctionBuilder.cs | 119 ++++++ .../CodeGen/LinkableFunction.cs | 30 ++ .../CodeGen/LinkableModule.cs | 45 +++ .../CodeGen/LinkedModule.cs | 31 ++ .../CodeGen/ModuleBuilder.cs | 39 ++ .../Nncase.Modules.CPU.csproj | 1 + .../Passes/Tile/CPUFusionConverter.cs | 148 +++++++ .../Runtime/CPU/CPURTModule.cs | 23 ++ modules/Nncase.Modules.CPU/packages.lock.json | 6 + src/Nncase.Cli/packages.lock.json | 1 + src/Nncase.Compiler/packages.lock.json | 1 + .../packages.lock.json | 7 + src/Nncase.Tests/packages.lock.json | 1 + 15 files changed, 1100 insertions(+) create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSource.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs create mode 100644 modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSource.cs b/modules/Nncase.Modules.CPU/CodeGen/CSource.cs new file mode 100644 index 0000000000..8570adf641 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSource.cs @@ -0,0 +1,283 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +#if false +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using Nncase.IR; +using Nncase.Schedule; +using Nncase.TIR; + +namespace Nncase.CodeGen; + +/// +/// the c source runtime function. +/// +/// +/// +public record CSourceRTFunction(string name, Delegate handle) : IRTFunction +{ + public string Name { get => name; set { } } + public Delegate Handle { get => handle; set { } } +} + +public class CSourceSerializeResult : ISerializeResult +{ + +} + +/// +/// c runtime module impl +/// +public class CSourceRTModel : IRTModule, IRTModel +{ + /// + public ModuleType ModuleType { get => CodeGen.ModuleType.Create("CSource"); set { } } + + /// + public ITarget Target { get; set; } + + /// + public IReadOnlyList Modules => throw new NotImplementedException(); + + /// + public string SourcePath { get; private set; } + + public IRModel Model { get; set; } + IRTFunction? _entry = null; + + /// + public bool IsSerialized { get; private set; } + + readonly List _functions = new(); + + /// + /// + /// + public CSourceRTModel(IRModel model, ITarget target) + { + SourcePath = CodeGenUtil.GetTempFileName("c"); + Model = model; + Target = target; + } + + /// + public byte[] Source { get => File.ReadAllBytes(SourcePath); set { } } + + /// + public string SourceExt { get => "c"; set { } } + + /// + public IRTFunction? Entry => _entry; + + /// + public IReadOnlyList Functions => _functions; + + /// + string _dllPath = ""; + + /// + /// write the c source code into source path. + /// + /// + void BuildCode() + { + if (File.Exists(SourcePath)) + File.Delete(SourcePath); + using (var writer = new StreamWriter(SourcePath, false, Encoding.UTF8)) + { + var visior = new CSourceHostBuildVisior(writer); + if (Model.Entry is null) { throw new InvalidProgramException("The Model Entry Is Null!"); } + if (Model.Entry.CheckedType is null && Model.Entry.InferenceType() == false) { throw new InvalidProgramException("The Model Entry Can't Inference Type!"); } + visior.Visit(Model.Entry); + } + } + + public void CompileCode() + { + if (!File.Exists(SourcePath)) + throw new InvalidProgramException("The Source Code Path Is Invalid!"); + var compiler = new CSourceCompiler(); + _dllPath = compiler.Compile(SourcePath); + } + + /// + /// bind each IR.Funtion with C function + /// + /// + public void ExportCode() + { + if (!File.Exists(_dllPath)) + throw new InvalidProgramException("The DLL Path Is Invalid!"); + var dllPtr = NativeLibrary.Load(_dllPath); + foreach (var module in Model.Modules) + { + foreach (var f in module.Callables) + { + var funcType = f.ToDelegateType(Path.GetFileName(_dllPath)); + var funPtr = NativeLibrary.GetExport(dllPtr, f.Name); + _functions.Add(new CSourceRTFunction(f.Name, funPtr.BindDelegate(funcType))); + if (f == Model.Entry) { _entry = _functions.Last(); } + } + } + } + + /// + public ISerializeResult Serialize() + { + if (IsSerialized) { return new CSourceSerializeResult(); } + BuildCode(); + CompileCode(); + ExportCode(); + return new CSourceSerializeResult(); + } + + /// + /// invoke the module entry + /// + /// input args + /// results + /// + public object? Invoke(params object?[]? args) + { + if (Entry is null) + throw new InvalidOperationException("This RTModule Have No Entry Function!"); + return Entry.Handle.DynamicInvoke(args); + } + + public string Dump(string name, string DumpDirPath) + { + var dump_path = $"{DumpDirPath}/{name}.{SourceExt}"; + using var file = File.Open(dump_path, FileMode.OpenOrCreate, FileAccess.Write); + using var writer = new StreamWriter(file); + writer.Write(Source); + return dump_path; + } + +} + +/// +/// the csource code compiler. +/// +public class CSourceCompiler +{ + /// + /// compiler exe name + /// + string _exe = "", _arch = "", _ext = ""; + + /// + /// select current pattern's exe + /// + /// + void PlatformSpecific() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + _exe = "gcc"; + _ext = "so"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + _exe = "clang"; + _ext = "dylib"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _exe = "cmd"; + _ext = "dll"; + } + } + + void ArchSpecific() + { + _arch = RuntimeInformation.OSArchitecture switch + { + Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", + Architecture.Arm64 => "arm64", + _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), + }; + } + + string ArgumentsSpecific(string sourcePath, string outPath) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return $"{sourcePath} -fPIC -shared -march={Arch} -o {outPath}"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return $"{sourcePath} -fPIC -shared -arch {Arch} -o {outPath}"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); + var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); + return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; + } + throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); + } + + protected string Exe + { + get => _exe; + } + + protected string Arch + { + get => _arch; + } + + protected string Ext + { + get => _ext; + } + + public CSourceCompiler() + { + PlatformSpecific(); + ArchSpecific(); + } + + /// + /// compile the source txt, write to the out_path + /// + /// c source code + /// out .so path + /// outPath + public string Compile(string sourcePath, string outPath) + { + var errMsg = new StringBuilder(); + using (var errWriter = new StringWriter(errMsg)) + { + using (var proc = new Process()) + { + proc.StartInfo.FileName = Exe; + proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); + proc.StartInfo.RedirectStandardError = true; + proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); + proc.Start(); + proc.BeginErrorReadLine(); + proc.WaitForExit(); + if (proc.ExitCode != 0) + { + throw new InvalidOperationException(errMsg.ToString()); + } + } + } + return outPath; + } + + /// + /// create the temp dll file and compile source + /// + /// + public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); +} +#endif diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs new file mode 100644 index 0000000000..60453f1563 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs @@ -0,0 +1,365 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using NetFabric.Hyperlinq; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Runtime; +using Nncase.TIR; + +namespace Nncase.CodeGen; + +/// +/// the c symbol define. +/// +internal struct CSymbol +{ + public string Type; + public StringBuilder Doc; + + public CSymbol(string type, StringBuilder doc) + { + Type = type; + Doc = doc; + } + + public override string ToString() => $"{Type} {Doc}"; +} + +/// +/// convert the type/op to c name. +/// +internal static class NameConverter +{ + private static readonly Dictionary _primTypeToC = new() + { + { DataTypes.Boolean, "bool" }, + { DataTypes.Int8, "int8_t" }, + { DataTypes.Int16, "int16_t" }, + { DataTypes.Int32, "int32_t" }, + { DataTypes.Int64, "int64_t" }, + { DataTypes.UInt8, "uint8_t" }, + { DataTypes.UInt16, "uint16_t" }, + { DataTypes.UInt32, "uint32_t" }, + { DataTypes.UInt64, "uint64_t" }, + { DataTypes.Float32, "float" }, + { DataTypes.Float64, "double" }, + }; + + public static string ToC(this PrimType primType) => + _primTypeToC[primType]; + + public static string ToC(this DataType dataType) => dataType switch + { + PrimType ptype => ptype.ToC(), + PointerType { ElemType: PrimType etype } => etype.ToC() + "*", + _ => throw new NotSupportedException(dataType.ToString()), + }; +} + +/// +/// collect the csymbol's parameter. +/// +internal class CSymbolParamList : IParameterList, IEnumerable +{ + private CSymbol[] _symbols; + + public CSymbolParamList(CSymbol[] symbols) + { + this._symbols = symbols; + } + + public CSymbol this[ParameterInfo parameter] => _symbols[parameter.Index]; + + public CSymbol this[int index] => _symbols[index]; + + public IEnumerator GetEnumerator() + { + return ((IEnumerable)_symbols).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _symbols.GetEnumerator(); + } +} + +/// +/// visitor for the build c source code, the expr vistor return (type string , name string). +/// +internal class CSourceHostBuildVisior : ExprFunctor +{ + /// + /// source writer . + /// TODO we need the decl writer. + /// + private readonly ScopeWriter _scope; + + /// + /// symbols name memo. + /// + private readonly Dictionary _symbols = new(ReferenceEqualityComparer.Instance); + + /// + /// Initializes a new instance of the class. + /// . + /// + /// TextWriter. + public CSourceHostBuildVisior(TextWriter textWriter) + { + _scope = new ScopeWriter(textWriter); + + // insert some declare + _scope.IndWriteLine(@" +#ifdef _WIN32 +#define EXPORT_API __declspec(dllexport) +#else +#define EXPORT_API +#endif"); + _scope.IndWriteLine("#include "); + } + + /// + /// void (*fun_ptr)(int). + /// + public string CallableTypeToPtr(CallableType type, string name) => $"{VisitType(type.ReturnType)} (*{name}_ptr)({string.Join(",", type.Parameters.Select(VisitType))})"; + + /// + public override string VisitType(TensorType type) + { + if (!type.IsScalar) + { + throw new NotSupportedException($"{type}"); + } + + return type.DType.ToC(); + } + + /// + public override string VisitType(TupleType type) => type == TupleType.Void ? "void" : throw new InvalidProgramException($"The C Source Must Not Have TupleType {type}!"); + + /// + protected override CSymbol VisitCall(Call expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var target = Visit(expr.Target); + var args = new CSymbolParamList(expr.Arguments.AsValueEnumerable().Select(Visit).ToArray()); + var type = VisitType(expr.CheckedType!); + _scope.Push(); + switch (expr.Target) + { + case IR.Math.Binary: + _scope.Append($"({args[0].Doc} {target.Doc} {args[1].Doc})"); + break; + case Store: + _scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}] = {args[Store.Value].Doc}"); + break; + case Load: + _scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}]"); + break; + case IR.Tensors.Cast: + _scope.Append($"(({type}){args[IR.Tensors.Cast.Input].Doc})"); + break; + default: + _scope.Append($"{target.Doc}({string.Join(", ", args.Select(x => x.Doc))})"); + break; + } + + symbol = new(type, _scope.Pop()); + _symbols.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitConst(Const expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + if (expr.CheckedType is TensorType ttype && ttype.IsScalar) + { + var literal = $"{expr}" switch + { + "True" => "1", + "False" => "0", + var x => x, + }; + symbol = new(VisitType(ttype), new(literal)); + } + else + { + throw new NotSupportedException($"Not Support {expr.CheckedType} Const"); + } + + _symbols.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitPrimFunction(PrimFunction expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var retType = VisitType(((CallableType)expr.CheckedType!).ReturnType); + _scope.Push(); + + // 1. Function signature + _scope.IndWrite($"EXPORT_API {retType} {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).ToArray())}) {{"); + + // 2. Function body + using (_scope.IndentUp()) + { + _scope.Append(Visit(expr.Body).Doc); + } + + // 3. Function closing + _scope.IndWrite("}"); + symbol = new(CallableTypeToPtr((CallableType)expr.CheckedType!, expr.Name), _scope.Pop()); + + // 4. write whole code + _scope.IndWrite(symbol.Doc); + return symbol; + } + + /// + protected override CSymbol VisitOp(Op expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + symbol = new("Invalid Op", new(expr switch + { + IR.Math.Binary op => op.BinaryOp switch + { + BinaryOp.Add => "+", + BinaryOp.Sub => "-", + BinaryOp.Mul => "*", + BinaryOp.Div => "/", + BinaryOp.Mod => "%", + _ => throw new ArgumentOutOfRangeException(op.BinaryOp.ToString()), + }, + TIR.Store op => "Store", + TIR.Load op => "Load", + IR.Tensors.Cast op => op.NewType.ToC(), + _ => throw new NotSupportedException($"{expr.GetType().Name}"), + })); + _symbols.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitVar(Var expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var isymbol = _scope.GetUniqueVarSymbol(expr); + symbol = new(VisitType(expr.CheckedType!), isymbol.Span); + _symbols.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitFor(For expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + _scope.Push(); + + // 1. For Loop signature + var loopVar = Visit(expr.LoopVar); + _scope.Append($"for ({loopVar} = {Visit(expr.Domain.Start).Doc}; {loopVar.Doc} < {Visit(expr.Domain.Stop).Doc}; {loopVar.Doc}+={expr.Domain.Step}) {{"); + + // 2. For Body + _scope.Append(Visit(expr.Body).Doc); + + // 3. For closing + _scope.IndWrite("}"); + symbol = new(VisitType(expr.CheckedType!), _scope.Pop()); + _symbols.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitSequential(Sequential expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + _scope.Push(); + _scope.AppendLine(string.Empty); + using (_scope.IndentUp()) + { + foreach (var i in Enumerable.Range(0, expr.Count)) + { + if (i == expr.Count - 1 && + expr.Fields[i].CheckedType is TensorType) + { + _scope.IndWrite("return "); + } + else + { + _scope.IndWrite(string.Empty); + } + + _scope.Append(Visit(expr.Fields[i]).Doc); + if (expr.Fields[i] is Call) + { + _scope.AppendLine(";"); + } + else + { + _scope.AppendLine(string.Empty); + } + } + } + + symbol = new(VisitType(expr.CheckedType!), _scope.Pop()); + _symbols.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitIfThenElse(IfThenElse expr) + { + if (_symbols.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + _scope.Push(); + _scope.Append($"if({Visit(expr.Condition).Doc}) {{"); + _scope.Append(Visit(expr.Then).Doc); + _scope.IndWrite("} else {"); + _scope.Append(Visit(expr.Else).Doc); + _scope.IndWrite("}"); + symbol = new(VisitType(expr.CheckedType!), _scope.Pop()); + _symbols.Add(expr, symbol); + return symbol; + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs new file mode 100644 index 0000000000..854a8582b0 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs @@ -0,0 +1,119 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. +#pragma warning disable +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using NetFabric.Hyperlinq; +using Nncase.IR; + +namespace Nncase.CodeGen.CPU; + +/// +/// StackVM function builder. +/// +internal class FunctionBuilder : IDisposable +{ + private readonly uint _id; + private readonly MemoryStream _textContent = new MemoryStream(); + private readonly BinaryWriter _textWriter; + private readonly BinaryWriter _rdataWriter; + + /// + /// NOTE sync with the cpu runtime function. + /// + [StructLayout(LayoutKind.Sequential)] + private struct MemoryRange + { + public uint Start; + public uint Size; + } + + /// + /// NOTE sync with the cpu runtime function. + /// + [StructLayout(LayoutKind.Sequential)] + private unsafe struct DescHeader + { + /// + /// input pool size. + /// + public uint InputPoolSize; + + /// + /// output pool size. + /// + public uint OutputPoolSize; + + /// + /// input numbers. + /// + public uint Inputs; + + /// + /// output numbers. + /// + public uint Outputs; + } + + public FunctionBuilder(uint id, BinaryWriter rdataWriter) + { + _id = id; + _textWriter = new BinaryWriter(_textContent, Encoding.UTF8, leaveOpen: true); + _rdataWriter = rdataWriter; + } + + public unsafe LinkableFunction Build(TIR.PrimFunction function) + { + // 1. write the inst + // new InstSerializeVisitor(_textWriter).Visit(function.Body); + + // 2. write the desc + var descContent = new MemoryStream(); + using (var descWriter = new BinaryWriter(descContent, Encoding.UTF8)) + { + DescHeader header = new() { InputPoolSize = 0, OutputPoolSize = 0, Inputs = 0, Outputs = 0 }; + long headerStart = descWriter.Position(); + descWriter.Skip((ulong)sizeof(DescHeader)); + + foreach (var input in function.Parameters.AsValueEnumerable() + .Where(buf => buf.MemLocation == Schedule.MemoryLocation.Input)) + { + header.Inputs++; + var rg = new MemoryRange { Start = checked((uint)input.Start), Size = checked((uint)input.Size) }; + descWriter.Write(ref rg); + header.InputPoolSize = Math.Max(header.InputPoolSize, rg.Start + rg.Size); + } + + foreach (var output in function.Parameters.AsValueEnumerable().Where(buf => buf.MemLocation == Schedule.MemoryLocation.Output)) + { + header.Outputs++; + var rg = new MemoryRange { Start = checked((uint)output.Start), Size = checked((uint)output.Size) }; + descWriter.Write(ref rg); + header.OutputPoolSize = Math.Max(header.OutputPoolSize, rg.Start + rg.Size); + } + + descWriter.Position(headerStart); + descWriter.Write(ref header); + } + + // 3. write the rdata + foreach (var buffer in function.SchedResult.Rdatas) + { + var bytes = buffer.Const!.Value.BytesBuffer; + if ((uint)bytes.Length != buffer.Size) + { + throw new InvalidDataException("The Buffer Szie Not Equal!"); + } + + _rdataWriter.Position((uint)buffer.Start); + _rdataWriter.Write(bytes); + } + + return new LinkableFunction(_id, function, _textContent.ToArray(), descContent.ToArray()); + } + + public void Dispose() + { + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs new file mode 100644 index 0000000000..e92d1f5b4e --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -0,0 +1,30 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; + +namespace Nncase.CodeGen.CPU; + +internal sealed class LinkableFunction : ILinkableFunction +{ + private readonly byte[] _desc; + + public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, byte[] text, byte[] desc) + { + Id = id; + SourceFunction = sourceFunction; + Text = text; + _desc = desc; + Sections = new LinkedSection[] { new(_desc, ".desc", 0, 8, (uint)_desc.Length) }; + } + + public uint Id { get; } + + public BaseFunction SourceFunction { get; } + + public byte[] Text { get; } + + public IEnumerable FunctionRefs => Enumerable.Empty(); + + public IReadOnlyList Sections { get; } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs new file mode 100644 index 0000000000..53d8c79b18 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -0,0 +1,45 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.Runtime.StackVM; + +namespace Nncase.CodeGen.CPU; + +internal sealed class LinkableModule : ILinkableModule +{ + private const int _textAlignment = 8; + + private readonly byte[] _rdata; + + private readonly IReadOnlyList _functions; + + public LinkableModule(byte[] rdata, IReadOnlyList functions) + { + _rdata = rdata; + _functions = functions; + } + + public ILinkedModule Link(ILinkContext linkContext) + { + var linkedFunctions = new List(); + var text = new MemoryStream(); + using (var bw = new BinaryWriter(text, Encoding.UTF8, true)) + { + foreach (var func in _functions) + { + // FixFunctionRefs(func, linkContext); + bw.AlignPosition(_textAlignment); + var textBegin = bw.Position(); + bw.Write(func.Text); + linkedFunctions.Add(new LinkedFunction(func.Id, func.SourceFunction, (uint)textBegin, (uint)func.Text.Length, func.Sections)); + } + } + + return new LinkedModule(linkedFunctions, text.ToArray(), _rdata); + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs new file mode 100644 index 0000000000..4ad317298f --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs @@ -0,0 +1,31 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.Runtime.StackVM; + +namespace Nncase.CodeGen.CPU; + +internal sealed class LinkedModule : ILinkedModule +{ + public LinkedModule(IReadOnlyList functions, byte[] text, byte[] rdata) + { + Functions = functions; + Sections = new[] { + new LinkedSection(text, ".text", 0, 8, (uint)text.Length), + new LinkedSection(rdata, ".rdata", 0, 8, (uint)rdata.Length), + }; + } + + public string ModuleKind => Runtime.CPU.CPURTModule.Kind; + + public uint Version => Runtime.CPU.CPURTModule.Version; + + public IReadOnlyList Functions { get; } + + public IReadOnlyList Sections { get; } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs new file mode 100644 index 0000000000..659a16dc9b --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs @@ -0,0 +1,39 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Text; +using Nncase.Diagnostics; +using Nncase.IR; + +namespace Nncase.CodeGen.CPU; + +/// +/// K230CoreModule builder. +/// +public sealed class ModuleBuilder : IModuleBuilder, IDisposable +{ + private readonly MemoryStream _rdataContent = new MemoryStream(); + private readonly BinaryWriter _rdataWriter; + + public ModuleBuilder(CompileOptions options) + { + _rdataWriter = new BinaryWriter(_rdataContent, Encoding.UTF8, leaveOpen: true); + CompileOptions = options; + } + + public CompileOptions CompileOptions { get; } + + /// + public string ModuleKind => Runtime.CPU.CPURTModule.Kind; + + /// + public ILinkableModule Build(IReadOnlyList functions) + { + var linkableFunctions = functions.OfType().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter).Build(f)).ToArray(); + _rdataWriter.Flush(); + + return new LinkableModule(_rdataContent.ToArray(), linkableFunctions); + } + + public void Dispose() => ((IDisposable)_rdataContent).Dispose(); +} diff --git a/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj b/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj index fb4674b51c..60eb18e05d 100644 --- a/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj +++ b/modules/Nncase.Modules.CPU/Nncase.Modules.CPU.csproj @@ -9,6 +9,7 @@ + diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs new file mode 100644 index 0000000000..2c85b4deca --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs @@ -0,0 +1,148 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +#if true +using System.Reactive; +using System.Runtime.CompilerServices; +using NetFabric.Hyperlinq; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.Buffers; +using Nncase.IR.CPU; +using Nncase.IR.F; +using Nncase.IR.Math; +using Nncase.Passes.Mutators; +using Nncase.PatternMatch; +using Nncase.Schedule; +using Nncase.Targets; +using Nncase.TIR; +using Nncase.TIR.Builders; +using Buffer = Nncase.TIR.Buffer; +using MathF = Nncase.IR.F.Math; +using Range = Nncase.TIR.Range; +using Tuple = Nncase.IR.Tuple; + +namespace Nncase.Passes.Tile; + +/// +/// name 分配器. +/// +internal sealed class NameAllocator +{ + public Dictionary NamePool { get; } = new(); + + public string Get(string name) + { + if (!NamePool.TryGetValue(name, out var count)) + { + count = 0; + } + + NamePool[name] = count + 1; + return count == 0 ? name : $"{name}_{count}"; + } +} + +internal class CPUFusionConverter +{ + public NameAllocator NameAllocator { get; } = new(); + + /// + /// Gets tile size 的变量. + /// + public List TileSizeVars { get; } = new(); + + /// + /// Gets loop 变量. + /// + // public List LoopVars { get; } = new(); + + /// + /// Gets loop domains. + /// + public List LoopDomains { get; } = new(); + + /// + /// Gets nested loops. + /// + // public List> NestedLoops { get; } = new(); + + public TileOptions TileOptions { get; protected set; } = null!; + + /// + /// Gets or sets 总的loop count. + /// / + public virtual Expr LoopCount { get; protected set; } + + /// + /// Gets or sets ping pong 外层的tiling. + /// + public virtual Expr LoopCountOuter { get; protected set; } + + /// + /// Gets or sets ping pong 内侧的tiling. + /// + public virtual Expr LoopCountInner { get; protected set; } + + /// + /// Gets or sets 当前的fusion. + /// + public virtual Fusion CurrentFusion { get; protected set; } + + public virtual PrimFunction BuildPrimFunc(Fusion fusion) + { + // TODO: buffer顺序可能需要调整以保持原图的顺序 + var primFuncBuilder = T.PrimFunc(CurrentFusion.Name, CPUTarget.Kind, _ifBufferMap.Values.Union(_ofBufferMap.Values).Select(b => (PhysicalBuffer)b).ToArray()); + return primFuncBuilder.Build(); + } + + public virtual Expr Visit(Expr root) + { + return root switch + { + Call call => (call.Target switch + { + CPUUnary op => LowerCPUUnary(call, op), + _ => throw new NotSupportedException(), + }).Build(), + _ => T.Nop(), + }; + } + + protected virtual ISequentialBuilder LowerCPUUnary(Call call, CPUUnary op) + { + var prefix = NameAllocator.Get(nameof(CPUUnary)); + var inputCall = call[CPUUnary.Input]; + T.PhysicalBuffer(inputCall.CheckedDataType, MemoryLocation.Input, inputCall.CheckedShape, out var ddrIf); + T.PhysicalBuffer(call.CheckedDataType, MemoryLocation.Output, call.CheckedShape.ToValueArray(), out var ddrOf); + _ifBufferMap.Add(call, ddrIf); + _ofBufferMap.Add(call, ddrOf); + + List LoopVars = new(); + List> NestedLoops = new(); + List LoopDomains = call.CheckedShape.Select(s => new Range(0, 1, s.FixedValue)).ToList(); + + var seq = T.Sequential().Body(Visit(inputCall)); + + for (int i = 0; i < call.CheckedShape.Rank; i++) + { + NestedLoops.Add(T.ForLoop(out var loopVar, LoopDomains[i], LoopMode.Unrolled, $"loop_var_{i}")); + LoopVars.Add(loopVar); + } + + NestedLoops[^1].Body( + op.UnaryOp switch + { + // TODO: body的实现 + UnaryOp.Abs => T.Nop(), + _ => throw new NotSupportedException(), + }); + + seq.Body(NestedLoops[0].Body()); + return seq; + } + + private readonly Dictionary _ifBufferMap = new(ReferenceEqualityComparer.Instance); + private readonly Dictionary _ofBufferMap = new(ReferenceEqualityComparer.Instance); +} +#endif diff --git a/modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs b/modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs new file mode 100644 index 0000000000..0ed89e20f3 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs @@ -0,0 +1,23 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Nncase.Runtime.CPU; + +internal class CPURTModule +{ + /// + /// KPU module kind. + /// + public static readonly string Kind = "cpu"; + + /// + /// KPU module version. + /// + public static readonly uint Version = 1; +} diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json index 4ed711f699..5dd73e5bda 100644 --- a/modules/Nncase.Modules.CPU/packages.lock.json +++ b/modules/Nncase.Modules.CPU/packages.lock.json @@ -137,6 +137,12 @@ "System.Reactive": "[5.0.0, )" } }, + "nncase.diagnostics": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )" + } + }, "nncase.egraph": { "type": "Project", "dependencies": { diff --git a/src/Nncase.Cli/packages.lock.json b/src/Nncase.Cli/packages.lock.json index 7ba8ee7204..712cc915f1 100644 --- a/src/Nncase.Cli/packages.lock.json +++ b/src/Nncase.Cli/packages.lock.json @@ -747,6 +747,7 @@ "type": "Project", "dependencies": { "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Diagnostics": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )" } diff --git a/src/Nncase.Compiler/packages.lock.json b/src/Nncase.Compiler/packages.lock.json index 1655cc14b1..8cede23417 100644 --- a/src/Nncase.Compiler/packages.lock.json +++ b/src/Nncase.Compiler/packages.lock.json @@ -724,6 +724,7 @@ "type": "Project", "dependencies": { "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Diagnostics": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )" } diff --git a/src/Nncase.Tests.TestFixture/packages.lock.json b/src/Nncase.Tests.TestFixture/packages.lock.json index cafcae2049..94827ead5b 100644 --- a/src/Nncase.Tests.TestFixture/packages.lock.json +++ b/src/Nncase.Tests.TestFixture/packages.lock.json @@ -1098,6 +1098,12 @@ "System.Reactive": "[5.0.0, )" } }, + "nncase.diagnostics": { + "type": "Project", + "dependencies": { + "Nncase.Core": "[1.0.0, )" + } + }, "nncase.egraph": { "type": "Project", "dependencies": { @@ -1130,6 +1136,7 @@ "type": "Project", "dependencies": { "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Diagnostics": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )" } diff --git a/src/Nncase.Tests/packages.lock.json b/src/Nncase.Tests/packages.lock.json index 6d25e0bccb..02b0e823c7 100644 --- a/src/Nncase.Tests/packages.lock.json +++ b/src/Nncase.Tests/packages.lock.json @@ -1543,6 +1543,7 @@ "type": "Project", "dependencies": { "Nncase.CodeGen": "[1.0.0, )", + "Nncase.Diagnostics": "[1.0.0, )", "Nncase.Modules.StackVM": "[1.0.0, )", "Nncase.Passes": "[1.0.0, )" } From 7de0b568a88bc04b1e63791bfbe165b95bd6ce32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Fri, 21 Jul 2023 16:27:35 +0800 Subject: [PATCH 006/308] add csource codegen --- modules/Nncase.Modules.CPU/CodeGen/CSource.cs | 283 ---- .../CodeGen/CSourceBuiltn.cs | 22 + .../CodeGen/CSourceConvertVisitor.cs | 323 +++++ .../CodeGen/CSourceExtensions.cs | 45 + .../CodeGen/CSourceUtilities.cs | 40 + .../CodeGen/CSourceVisitor.cs | 365 ----- .../CodeGen/FunctionBuilder.cs | 76 +- .../CodeGen/FunctionCSource.cs | 149 ++ .../CodeGen/LinkableFunction.cs | 10 +- .../CodeGen/LinkableModule.cs | 53 +- .../CodeGen/LinkedModule.cs | 9 +- .../CodeGen/ModuleBuilder.cs | 2 +- .../Evaluator/CPU/CPUKernelOp.cs | 34 + .../Evaluator/CPU/CPUModule.cs | 2 +- .../Evaluator/CPU/CPUUnary.cs | 141 -- .../IR/CPU/{CPUUnary.cs => CPUKernelOp.cs} | 8 +- .../Nncase.Modules.CPU/IR/CPU/Functional.cs | 10 +- .../Passes/CPUFusionToTirPass.cs | 63 +- .../Passes/Rules/LowerUnary.cs | 2 +- .../Passes/Rules/MakeFusion.cs | 4 +- .../Passes/Tile/CPUFusionConverter.cs | 148 -- .../Passes/Tile/CPUFusionGroupMutator.cs | 57 +- .../Passes/Tile/LayerFusionConverter.cs | 1264 ----------------- .../Passes/Tile/MultiFusionChecker.cs | 240 +--- .../Passes/Tile/MultiLayerFusionConverter.cs | 230 --- .../Passes/Tile/SingleCPUFusionConverter.cs | 119 ++ .../Passes/Tile/TileOptions.cs | 10 +- .../Passes/Tile/TwoFusionChecker.cs | 131 -- .../Runtime/CPU/CPURTModule.cs | 23 - .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 16 +- src/Nncase.CodeGen/CodeGen/LinkedFunction.cs | 1 - src/Nncase.Core/IR/Buffers/BufferOf.cs | 4 +- src/Nncase.Core/IR/Buffers/Functional.cs | 2 +- src/Nncase.Core/IR/Buffers/Uninitialized.cs | 4 +- src/Nncase.Core/IR/ExprCloner.g.cs | 8 + src/Nncase.Core/IR/ExprFunctor.g.cs | 11 + src/Nncase.Core/IR/ExprRewriter.g.cs | 11 + src/Nncase.Core/IR/ExprVisitor.g.cs | 27 + src/Nncase.Core/Schedule/ScheduleTypes.cs | 45 - src/Nncase.Core/TIR/Buffer.cs | 20 +- src/Nncase.Core/TIR/MemSpan.cs | 86 ++ src/Nncase.Core/TIR/Ops.cs | 16 +- src/Nncase.Core/TIR/Script.cs | 30 +- src/Nncase.Evaluator/TIR/Load.cs | 8 +- src/Nncase.Evaluator/TIR/Store.cs | 8 +- src/Nncase.Passes/DDrBufferSchdeulePass.cs | 18 +- .../Rules/Neutral/PrimFuncMergeRule.cs | 10 +- src/Nncase.Tests/CodeGen/CSourceHostCases.cs | 8 +- src/Nncase.Tests/Core/UnitTestExpression.cs | 10 +- .../Core/UnitTestStringUtility.cs | 6 +- src/Nncase.Tests/Core/UnitTestTIR.cs | 23 +- .../Diagnostics/UnitTestDumpper.cs | 2 +- .../Evaluator/UnitTestEvaluator.cs | 6 +- .../Evaluator/UnitTestEvaluatorBuffers.cs | 2 +- .../TIR/PrimFunc/IDataFlowPrimFuncCase.cs | 36 +- .../TIR/PrimFunc/UnitTestPrimFuncMerge.cs | 10 +- src/Nncase.Tests/TIR/UnitTestMutators.cs | 20 +- .../Targets/UnitTestCPUTargetTiling.cs | 51 + .../Transform/UnitTestPassManager.cs | 12 +- .../Transform/UnitTestSubstitutor.cs | 15 +- .../Nncase.Targets.CSource/CSourceTarget.cs | 34 - .../Nncase.Targets.CSource/CodeGen/CSource.cs | 278 ---- .../CodeGen/CSourceVisitor.cs | 317 ----- .../Nncase.Targets.CSource/CodeGen/Interop.cs | 140 -- .../Nncase.Targets.CSource.csproj | 16 - .../Schedule/CSourceScheduler.cs | 22 - .../Pattern/PatternGenerator.cs | 4 +- 67 files changed, 1165 insertions(+), 4035 deletions(-) delete mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSource.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs delete mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs create mode 100644 modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs create mode 100644 modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs delete mode 100644 modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs rename modules/Nncase.Modules.CPU/IR/CPU/{CPUUnary.cs => CPUKernelOp.cs} (69%) delete mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs delete mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs delete mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs create mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs delete mode 100644 modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs delete mode 100644 modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs create mode 100644 src/Nncase.Core/TIR/MemSpan.cs create mode 100644 src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs delete mode 100644 targets/Nncase.Targets.CSource/CSourceTarget.cs delete mode 100644 targets/Nncase.Targets.CSource/CodeGen/CSource.cs delete mode 100644 targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs delete mode 100644 targets/Nncase.Targets.CSource/CodeGen/Interop.cs delete mode 100644 targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj delete mode 100644 targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSource.cs b/modules/Nncase.Modules.CPU/CodeGen/CSource.cs deleted file mode 100644 index 8570adf641..0000000000 --- a/modules/Nncase.Modules.CPU/CodeGen/CSource.cs +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -#if false -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Nncase.IR; -using Nncase.Schedule; -using Nncase.TIR; - -namespace Nncase.CodeGen; - -/// -/// the c source runtime function. -/// -/// -/// -public record CSourceRTFunction(string name, Delegate handle) : IRTFunction -{ - public string Name { get => name; set { } } - public Delegate Handle { get => handle; set { } } -} - -public class CSourceSerializeResult : ISerializeResult -{ - -} - -/// -/// c runtime module impl -/// -public class CSourceRTModel : IRTModule, IRTModel -{ - /// - public ModuleType ModuleType { get => CodeGen.ModuleType.Create("CSource"); set { } } - - /// - public ITarget Target { get; set; } - - /// - public IReadOnlyList Modules => throw new NotImplementedException(); - - /// - public string SourcePath { get; private set; } - - public IRModel Model { get; set; } - IRTFunction? _entry = null; - - /// - public bool IsSerialized { get; private set; } - - readonly List _functions = new(); - - /// - /// - /// - public CSourceRTModel(IRModel model, ITarget target) - { - SourcePath = CodeGenUtil.GetTempFileName("c"); - Model = model; - Target = target; - } - - /// - public byte[] Source { get => File.ReadAllBytes(SourcePath); set { } } - - /// - public string SourceExt { get => "c"; set { } } - - /// - public IRTFunction? Entry => _entry; - - /// - public IReadOnlyList Functions => _functions; - - /// - string _dllPath = ""; - - /// - /// write the c source code into source path. - /// - /// - void BuildCode() - { - if (File.Exists(SourcePath)) - File.Delete(SourcePath); - using (var writer = new StreamWriter(SourcePath, false, Encoding.UTF8)) - { - var visior = new CSourceHostBuildVisior(writer); - if (Model.Entry is null) { throw new InvalidProgramException("The Model Entry Is Null!"); } - if (Model.Entry.CheckedType is null && Model.Entry.InferenceType() == false) { throw new InvalidProgramException("The Model Entry Can't Inference Type!"); } - visior.Visit(Model.Entry); - } - } - - public void CompileCode() - { - if (!File.Exists(SourcePath)) - throw new InvalidProgramException("The Source Code Path Is Invalid!"); - var compiler = new CSourceCompiler(); - _dllPath = compiler.Compile(SourcePath); - } - - /// - /// bind each IR.Funtion with C function - /// - /// - public void ExportCode() - { - if (!File.Exists(_dllPath)) - throw new InvalidProgramException("The DLL Path Is Invalid!"); - var dllPtr = NativeLibrary.Load(_dllPath); - foreach (var module in Model.Modules) - { - foreach (var f in module.Callables) - { - var funcType = f.ToDelegateType(Path.GetFileName(_dllPath)); - var funPtr = NativeLibrary.GetExport(dllPtr, f.Name); - _functions.Add(new CSourceRTFunction(f.Name, funPtr.BindDelegate(funcType))); - if (f == Model.Entry) { _entry = _functions.Last(); } - } - } - } - - /// - public ISerializeResult Serialize() - { - if (IsSerialized) { return new CSourceSerializeResult(); } - BuildCode(); - CompileCode(); - ExportCode(); - return new CSourceSerializeResult(); - } - - /// - /// invoke the module entry - /// - /// input args - /// results - /// - public object? Invoke(params object?[]? args) - { - if (Entry is null) - throw new InvalidOperationException("This RTModule Have No Entry Function!"); - return Entry.Handle.DynamicInvoke(args); - } - - public string Dump(string name, string DumpDirPath) - { - var dump_path = $"{DumpDirPath}/{name}.{SourceExt}"; - using var file = File.Open(dump_path, FileMode.OpenOrCreate, FileAccess.Write); - using var writer = new StreamWriter(file); - writer.Write(Source); - return dump_path; - } - -} - -/// -/// the csource code compiler. -/// -public class CSourceCompiler -{ - /// - /// compiler exe name - /// - string _exe = "", _arch = "", _ext = ""; - - /// - /// select current pattern's exe - /// - /// - void PlatformSpecific() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - _exe = "gcc"; - _ext = "so"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - _exe = "clang"; - _ext = "dylib"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - _exe = "cmd"; - _ext = "dll"; - } - } - - void ArchSpecific() - { - _arch = RuntimeInformation.OSArchitecture switch - { - Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", - Architecture.Arm64 => "arm64", - _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), - }; - } - - string ArgumentsSpecific(string sourcePath, string outPath) - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - return $"{sourcePath} -fPIC -shared -march={Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return $"{sourcePath} -fPIC -shared -arch {Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); - var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); - return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; - } - throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); - } - - protected string Exe - { - get => _exe; - } - - protected string Arch - { - get => _arch; - } - - protected string Ext - { - get => _ext; - } - - public CSourceCompiler() - { - PlatformSpecific(); - ArchSpecific(); - } - - /// - /// compile the source txt, write to the out_path - /// - /// c source code - /// out .so path - /// outPath - public string Compile(string sourcePath, string outPath) - { - var errMsg = new StringBuilder(); - using (var errWriter = new StringWriter(errMsg)) - { - using (var proc = new Process()) - { - proc.StartInfo.FileName = Exe; - proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); - proc.StartInfo.RedirectStandardError = true; - proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); - proc.Start(); - proc.BeginErrorReadLine(); - proc.WaitForExit(); - if (proc.ExitCode != 0) - { - throw new InvalidOperationException(errMsg.ToString()); - } - } - } - return outPath; - } - - /// - /// create the temp dll file and compile source - /// - /// - public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); -} -#endif diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs new file mode 100644 index 0000000000..eda007050f --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -0,0 +1,22 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +namespace Nncase.CodeGen.CPU; + +public static class CSourceBuiltn +{ + + public const string BufferType = "buffer_t"; + + public const string BufferStruct = @"typedef struct buffer { + void *ptr; + int *shape; + int *stride; + int rank; +} buffer_t;"; + + public const string Include = @"#include"; + + public static string Header => Include + "\n" + BufferStruct; + +} \ No newline at end of file diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs new file mode 100644 index 0000000000..71b8d5ee65 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -0,0 +1,323 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reactive; +using System.Runtime.InteropServices; +using System.Text; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.Runtime; +using Nncase.TIR; + +namespace Nncase.CodeGen.CPU; + +/// +/// the c symbol define. +/// +internal sealed class CSymbol +{ + public CSymbol(string type, string name) + { + Type = type; + Name = name; + } + + public string Type { get; } + public string Name { get; } + + public override string ToString() => $"{Type} {Name}"; +} + +internal struct IndentScope : IDisposable +{ + private static readonly AsyncLocal _writer = new AsyncLocal(); + + private readonly bool _initialized; + + private readonly IndentWriter? _originalWriter; + + public IndentScope(StringBuilder sb) + { + _initialized = true; + _writer.Value = new IndentWriter(sb); + _originalWriter = null; + } + + public IndentScope() + { + _initialized = true; + if (_writer.Value is null) + { + return; + } + + _originalWriter = _writer.Value; + _writer.Value = new(_originalWriter.GetStringBuilder(), _originalWriter.Indent + 2); + } + + public static IndentWriter Writer => _writer.Value!; + + public void Dispose() + { + if (_initialized) + { + _writer.Value = _originalWriter; + } + } +} + + +internal sealed class IndentWriter : StringWriter +{ + public int Indent; + + public IndentWriter(StringBuilder sb, int indent = 0) : base(sb) + { + Indent = indent; + } + + public void IndWrite(string? value) + { + for (int i = 0; i < Indent; i++) + { + this.Write(' '); + } + this.Write(value); + } +} + +/// +/// convert single prim function to c source +/// +internal sealed class CSourceConvertVisitor : ExprFunctor +{ + private readonly StringBuilder _implBuilder; + public readonly Dictionary ExprMemo; + + public CSourceConvertVisitor() + { + _implBuilder = new StringBuilder(); + ExprMemo = new(ReferenceEqualityComparer.Instance); + } + + public FunctionCSource GetFunctionCSource() + { + return new(ExprMemo[VisitRoot!].Type + ";", _implBuilder.ToString()); + } + + /// + protected override CSymbol VisitPrimFunction(PrimFunction expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + if (expr.CheckedType is not CallableType { ReturnType: TupleType r } || r != TupleType.Void) + { + throw new NotSupportedException("The PrimFunction must return void!"); + } + + var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b).ToString()).ToArray())})"; + + using (var scope = new IndentScope(_implBuilder)) + { + // 1. Function signature + IndentScope.Writer.IndWrite($"{type} {{\n"); + // 2. Function body + using (var _ = new IndentScope()) + { + Visit(expr.Body); + } + // 3. Function closing + IndentScope.Writer.IndWrite("}\n"); + } + + symbol = new(type, new(expr.Name)); + ExprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitCall(Call expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var arguments = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); + string type = expr.CheckedType switch + { + TupleType x when x == TupleType.Void => "", + TensorType { IsScalar: true } x => x.DType.ToC(), + _ => throw new NotSupportedException() + }; + + string str; + switch (expr.Target) + { + case IR.Math.Binary op: + str = CSourceUtilities.ContertBinary(op, arguments); + break; + case IR.Math.Unary op: + str = CSourceUtilities.ContertUnary(op, arguments); + break; + case Store: + str = $"((({arguments[2].Type} *){arguments[0].Name}->ptr)[{arguments[1].Name}] = {arguments[2].Name})"; + break; + case Load: + str = $"((({type} *){arguments[0].Name}->ptr)[{arguments[1].Name}])"; + break; + default: + throw new NotSupportedException(); + } + + symbol = new(type, str); + ExprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitConst(Const expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + string type; + string str; + if (expr is TensorConst { Value: Tensor { ElementType: PrimType ptype, Shape: { IsScalar: true } } scalar }) + { + str = scalar[0].ToString() switch + { + "True" => "1", + "False" => "0", + null => string.Empty, + var x => x, + }; + + type = ptype.ToC(); + } + else + { + throw new NotSupportedException($"Not Support {expr.CheckedType} Const"); + } + + symbol = new(type, str); + ExprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitVar(Var expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + if (expr.CheckedType is not TensorType { Shape: { IsScalar: true } } ttype) + { + throw new NotSupportedException(); + } + + symbol = new(ttype.DType.ToC(), new($"{expr.Name}_{expr.GlobalVarIndex}")); + ExprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitFor(For expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + // 1. For Loop signature + var loopVar = Visit(expr.LoopVar); + IndentScope.Writer.IndWrite($"for ({loopVar.Type} {loopVar.Name} = {Visit(expr.Domain.Start).Name}; {loopVar.Name} < {Visit(expr.Domain.Stop).Name}; {loopVar.Name}+={Visit(expr.Domain.Step).Name}) {{\n"); + using (_ = new IndentScope()) + { + // 2. For Body + Visit(expr.Body); + } + // 3. For closing + IndentScope.Writer.IndWrite("}\n"); + + symbol = new(string.Empty, string.Empty); + ExprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitSequential(Sequential expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + foreach (var field in expr.Fields) + { + if (field is Call call) + { + IndentScope.Writer.IndWrite(Visit(call).Name); + IndentScope.Writer.Write(";\n"); + } + else + { + Visit(field); + } + } + + symbol = new(string.Empty, string.Empty); + ExprMemo.Add(expr, symbol); + return symbol; + } + + /// + protected override CSymbol VisitIfThenElse(IfThenElse expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + IndentScope.Writer.IndWrite($"if({Visit(expr.Condition).Name}) {{\n"); + using (var _ = new IndentScope()) + { + Visit(expr.Then); + } + + IndentScope.Writer.IndWrite("} else {\n"); + using (var _ = new IndentScope()) + { + Visit(expr.Else); + } + + IndentScope.Writer.IndWrite("}\n"); + + symbol = new(string.Empty, string.Empty); + ExprMemo.Add(expr, symbol); + return symbol; + } + + protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr) + { + if (ExprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + symbol = new(CSourceBuiltn.BufferType + "*", expr.Name); + ExprMemo.Add(expr, symbol); + return symbol; + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs new file mode 100644 index 0000000000..63618df79e --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs @@ -0,0 +1,45 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + + +namespace Nncase.CodeGen.CPU; + +/// +/// convert the type/op to c name. +/// +internal static class CSourceExtensions +{ + private static readonly Dictionary _primTypeToC = new() + { + { DataTypes.Boolean, "uint8_t" }, + { DataTypes.Int8, "int8_t" }, + { DataTypes.Int16, "int16_t" }, + { DataTypes.Int32, "int32_t" }, + { DataTypes.Int64, "int64_t" }, + { DataTypes.UInt8, "uint8_t" }, + { DataTypes.UInt16, "uint16_t" }, + { DataTypes.UInt32, "uint32_t" }, + { DataTypes.UInt64, "uint64_t" }, + { DataTypes.Float32, "float" }, + { DataTypes.Float64, "double" }, + }; + + public static string ToC(this PrimType primType) => + _primTypeToC[primType]; + + public static string ToC(this DataType dataType) => dataType switch + { + PrimType ptype => ptype.ToC(), + PointerType { ElemType: PrimType etype } => etype.ToC() + "*", + _ => throw new NotSupportedException(dataType.ToString()), + }; + + public static string ToC(this BinaryOp binaryOp) => binaryOp switch + { + BinaryOp.Add => "+", + BinaryOp.Sub => "-", + BinaryOp.Mul => "*", + BinaryOp.Div => "/", + _ => throw new NotSupportedException(binaryOp.ToString()) + }; +} \ No newline at end of file diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs new file mode 100644 index 0000000000..fbcf4e372d --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs @@ -0,0 +1,40 @@ +using Nncase.Diagnostics; +using Nncase.IR.Math; +namespace Nncase.CodeGen.CPU; + +internal static class CSourceUtilities +{ + public static string ContertBinary(Binary binary, CSymbol[] arguments) + { + var lhs = arguments[Binary.Lhs.Index].Name; + var rhs = arguments[Binary.Rhs.Index].Name; + string str; + switch (binary.BinaryOp) + { + case BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div: + str = ($"({lhs} {binary.BinaryOp.ToC()} {rhs})"); + break; + default: + throw new NotSupportedException(); + } + + return str; + } + + internal static string ContertUnary(Unary op, CSymbol[] arguments) + { + var input = arguments[Unary.Input.Index].Name; + string str; + switch (op.UnaryOp) + { + case UnaryOp.Neg: + str = ($"!{input}"); + break; + default: + str = ($"nncase_mt->{op.UnaryOp.ToString()}{input}"); + break; + } + + return str; + } +} \ No newline at end of file diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs deleted file mode 100644 index 60453f1563..0000000000 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceVisitor.cs +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using NetFabric.Hyperlinq; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.Runtime; -using Nncase.TIR; - -namespace Nncase.CodeGen; - -/// -/// the c symbol define. -/// -internal struct CSymbol -{ - public string Type; - public StringBuilder Doc; - - public CSymbol(string type, StringBuilder doc) - { - Type = type; - Doc = doc; - } - - public override string ToString() => $"{Type} {Doc}"; -} - -/// -/// convert the type/op to c name. -/// -internal static class NameConverter -{ - private static readonly Dictionary _primTypeToC = new() - { - { DataTypes.Boolean, "bool" }, - { DataTypes.Int8, "int8_t" }, - { DataTypes.Int16, "int16_t" }, - { DataTypes.Int32, "int32_t" }, - { DataTypes.Int64, "int64_t" }, - { DataTypes.UInt8, "uint8_t" }, - { DataTypes.UInt16, "uint16_t" }, - { DataTypes.UInt32, "uint32_t" }, - { DataTypes.UInt64, "uint64_t" }, - { DataTypes.Float32, "float" }, - { DataTypes.Float64, "double" }, - }; - - public static string ToC(this PrimType primType) => - _primTypeToC[primType]; - - public static string ToC(this DataType dataType) => dataType switch - { - PrimType ptype => ptype.ToC(), - PointerType { ElemType: PrimType etype } => etype.ToC() + "*", - _ => throw new NotSupportedException(dataType.ToString()), - }; -} - -/// -/// collect the csymbol's parameter. -/// -internal class CSymbolParamList : IParameterList, IEnumerable -{ - private CSymbol[] _symbols; - - public CSymbolParamList(CSymbol[] symbols) - { - this._symbols = symbols; - } - - public CSymbol this[ParameterInfo parameter] => _symbols[parameter.Index]; - - public CSymbol this[int index] => _symbols[index]; - - public IEnumerator GetEnumerator() - { - return ((IEnumerable)_symbols).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return _symbols.GetEnumerator(); - } -} - -/// -/// visitor for the build c source code, the expr vistor return (type string , name string). -/// -internal class CSourceHostBuildVisior : ExprFunctor -{ - /// - /// source writer . - /// TODO we need the decl writer. - /// - private readonly ScopeWriter _scope; - - /// - /// symbols name memo. - /// - private readonly Dictionary _symbols = new(ReferenceEqualityComparer.Instance); - - /// - /// Initializes a new instance of the class. - /// . - /// - /// TextWriter. - public CSourceHostBuildVisior(TextWriter textWriter) - { - _scope = new ScopeWriter(textWriter); - - // insert some declare - _scope.IndWriteLine(@" -#ifdef _WIN32 -#define EXPORT_API __declspec(dllexport) -#else -#define EXPORT_API -#endif"); - _scope.IndWriteLine("#include "); - } - - /// - /// void (*fun_ptr)(int). - /// - public string CallableTypeToPtr(CallableType type, string name) => $"{VisitType(type.ReturnType)} (*{name}_ptr)({string.Join(",", type.Parameters.Select(VisitType))})"; - - /// - public override string VisitType(TensorType type) - { - if (!type.IsScalar) - { - throw new NotSupportedException($"{type}"); - } - - return type.DType.ToC(); - } - - /// - public override string VisitType(TupleType type) => type == TupleType.Void ? "void" : throw new InvalidProgramException($"The C Source Must Not Have TupleType {type}!"); - - /// - protected override CSymbol VisitCall(Call expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - var target = Visit(expr.Target); - var args = new CSymbolParamList(expr.Arguments.AsValueEnumerable().Select(Visit).ToArray()); - var type = VisitType(expr.CheckedType!); - _scope.Push(); - switch (expr.Target) - { - case IR.Math.Binary: - _scope.Append($"({args[0].Doc} {target.Doc} {args[1].Doc})"); - break; - case Store: - _scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}] = {args[Store.Value].Doc}"); - break; - case Load: - _scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}]"); - break; - case IR.Tensors.Cast: - _scope.Append($"(({type}){args[IR.Tensors.Cast.Input].Doc})"); - break; - default: - _scope.Append($"{target.Doc}({string.Join(", ", args.Select(x => x.Doc))})"); - break; - } - - symbol = new(type, _scope.Pop()); - _symbols.Add(expr, symbol); - return symbol; - } - - /// - protected override CSymbol VisitConst(Const expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - if (expr.CheckedType is TensorType ttype && ttype.IsScalar) - { - var literal = $"{expr}" switch - { - "True" => "1", - "False" => "0", - var x => x, - }; - symbol = new(VisitType(ttype), new(literal)); - } - else - { - throw new NotSupportedException($"Not Support {expr.CheckedType} Const"); - } - - _symbols.Add(expr, symbol); - return symbol; - } - - /// - protected override CSymbol VisitPrimFunction(PrimFunction expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - var retType = VisitType(((CallableType)expr.CheckedType!).ReturnType); - _scope.Push(); - - // 1. Function signature - _scope.IndWrite($"EXPORT_API {retType} {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).ToArray())}) {{"); - - // 2. Function body - using (_scope.IndentUp()) - { - _scope.Append(Visit(expr.Body).Doc); - } - - // 3. Function closing - _scope.IndWrite("}"); - symbol = new(CallableTypeToPtr((CallableType)expr.CheckedType!, expr.Name), _scope.Pop()); - - // 4. write whole code - _scope.IndWrite(symbol.Doc); - return symbol; - } - - /// - protected override CSymbol VisitOp(Op expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - symbol = new("Invalid Op", new(expr switch - { - IR.Math.Binary op => op.BinaryOp switch - { - BinaryOp.Add => "+", - BinaryOp.Sub => "-", - BinaryOp.Mul => "*", - BinaryOp.Div => "/", - BinaryOp.Mod => "%", - _ => throw new ArgumentOutOfRangeException(op.BinaryOp.ToString()), - }, - TIR.Store op => "Store", - TIR.Load op => "Load", - IR.Tensors.Cast op => op.NewType.ToC(), - _ => throw new NotSupportedException($"{expr.GetType().Name}"), - })); - _symbols.Add(expr, symbol); - return symbol; - } - - /// - protected override CSymbol VisitVar(Var expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - var isymbol = _scope.GetUniqueVarSymbol(expr); - symbol = new(VisitType(expr.CheckedType!), isymbol.Span); - _symbols.Add(expr, symbol); - return symbol; - } - - /// - protected override CSymbol VisitFor(For expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - _scope.Push(); - - // 1. For Loop signature - var loopVar = Visit(expr.LoopVar); - _scope.Append($"for ({loopVar} = {Visit(expr.Domain.Start).Doc}; {loopVar.Doc} < {Visit(expr.Domain.Stop).Doc}; {loopVar.Doc}+={expr.Domain.Step}) {{"); - - // 2. For Body - _scope.Append(Visit(expr.Body).Doc); - - // 3. For closing - _scope.IndWrite("}"); - symbol = new(VisitType(expr.CheckedType!), _scope.Pop()); - _symbols.Add(expr, symbol); - return symbol; - } - - /// - protected override CSymbol VisitSequential(Sequential expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - _scope.Push(); - _scope.AppendLine(string.Empty); - using (_scope.IndentUp()) - { - foreach (var i in Enumerable.Range(0, expr.Count)) - { - if (i == expr.Count - 1 && - expr.Fields[i].CheckedType is TensorType) - { - _scope.IndWrite("return "); - } - else - { - _scope.IndWrite(string.Empty); - } - - _scope.Append(Visit(expr.Fields[i]).Doc); - if (expr.Fields[i] is Call) - { - _scope.AppendLine(";"); - } - else - { - _scope.AppendLine(string.Empty); - } - } - } - - symbol = new(VisitType(expr.CheckedType!), _scope.Pop()); - _symbols.Add(expr, symbol); - return symbol; - } - - /// - protected override CSymbol VisitIfThenElse(IfThenElse expr) - { - if (_symbols.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - _scope.Push(); - _scope.Append($"if({Visit(expr.Condition).Doc}) {{"); - _scope.Append(Visit(expr.Then).Doc); - _scope.IndWrite("} else {"); - _scope.Append(Visit(expr.Else).Doc); - _scope.IndWrite("}"); - symbol = new(VisitType(expr.CheckedType!), _scope.Pop()); - _symbols.Add(expr, symbol); - return symbol; - } -} diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs index 854a8582b0..969fbcec47 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs @@ -19,43 +19,6 @@ internal class FunctionBuilder : IDisposable private readonly BinaryWriter _textWriter; private readonly BinaryWriter _rdataWriter; - /// - /// NOTE sync with the cpu runtime function. - /// - [StructLayout(LayoutKind.Sequential)] - private struct MemoryRange - { - public uint Start; - public uint Size; - } - - /// - /// NOTE sync with the cpu runtime function. - /// - [StructLayout(LayoutKind.Sequential)] - private unsafe struct DescHeader - { - /// - /// input pool size. - /// - public uint InputPoolSize; - - /// - /// output pool size. - /// - public uint OutputPoolSize; - - /// - /// input numbers. - /// - public uint Inputs; - - /// - /// output numbers. - /// - public uint Outputs; - } - public FunctionBuilder(uint id, BinaryWriter rdataWriter) { _id = id; @@ -65,39 +28,12 @@ public FunctionBuilder(uint id, BinaryWriter rdataWriter) public unsafe LinkableFunction Build(TIR.PrimFunction function) { - // 1. write the inst - // new InstSerializeVisitor(_textWriter).Visit(function.Body); - - // 2. write the desc - var descContent = new MemoryStream(); - using (var descWriter = new BinaryWriter(descContent, Encoding.UTF8)) - { - DescHeader header = new() { InputPoolSize = 0, OutputPoolSize = 0, Inputs = 0, Outputs = 0 }; - long headerStart = descWriter.Position(); - descWriter.Skip((ulong)sizeof(DescHeader)); - - foreach (var input in function.Parameters.AsValueEnumerable() - .Where(buf => buf.MemLocation == Schedule.MemoryLocation.Input)) - { - header.Inputs++; - var rg = new MemoryRange { Start = checked((uint)input.Start), Size = checked((uint)input.Size) }; - descWriter.Write(ref rg); - header.InputPoolSize = Math.Max(header.InputPoolSize, rg.Start + rg.Size); - } - - foreach (var output in function.Parameters.AsValueEnumerable().Where(buf => buf.MemLocation == Schedule.MemoryLocation.Output)) - { - header.Outputs++; - var rg = new MemoryRange { Start = checked((uint)output.Start), Size = checked((uint)output.Size) }; - descWriter.Write(ref rg); - header.OutputPoolSize = Math.Max(header.OutputPoolSize, rg.Start + rg.Size); - } - - descWriter.Position(headerStart); - descWriter.Write(ref header); - } + // 1. convert func to csource + var visitor = new CSourceConvertVisitor(); + visitor.Visit(function); + var functionCSource = visitor.GetFunctionCSource(); - // 3. write the rdata + // 2. write the rdata foreach (var buffer in function.SchedResult.Rdatas) { var bytes = buffer.Const!.Value.BytesBuffer; @@ -110,7 +46,7 @@ public unsafe LinkableFunction Build(TIR.PrimFunction function) _rdataWriter.Write(bytes); } - return new LinkableFunction(_id, function, _textContent.ToArray(), descContent.ToArray()); + return new LinkableFunction(_id, function, functionCSource); } public void Dispose() diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs new file mode 100644 index 0000000000..c3a1e15066 --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -0,0 +1,149 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using Nncase.IR; +using Nncase.Schedule; +using Nncase.TIR; + +namespace Nncase.CodeGen; + +internal sealed class FunctionCSource +{ + public FunctionCSource(string declaration, string implementation) + { + Declaration = declaration; + Implementation = implementation; + } + + public string Declaration { get; } + public string Implementation { get; } +} + + +/// +/// the csource code compiler. +/// +public class CSourceCompiler +{ + /// + /// compiler exe name + /// + string _exe = "", _arch = "", _ext = ""; + + /// + /// select current pattern's exe + /// + /// + void PlatformSpecific() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + _exe = "gcc"; + _ext = "so"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + _exe = "clang"; + _ext = "dylib"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _exe = "cmd"; + _ext = "dll"; + } + } + + void ArchSpecific() + { + _arch = RuntimeInformation.OSArchitecture switch + { + Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", + Architecture.Arm64 => "arm64", + _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), + }; + } + + string ArgumentsSpecific(string sourcePath, string outPath) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return $"{sourcePath} -fPIC -shared -march={Arch} -o {outPath}"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return $"{sourcePath} -fPIC -shared -arch {Arch} -o {outPath}"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); + var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); + return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; + } + throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); + } + + protected string Exe + { + get => _exe; + } + + protected string Arch + { + get => _arch; + } + + protected string Ext + { + get => _ext; + } + + public CSourceCompiler() + { + PlatformSpecific(); + ArchSpecific(); + } + + /// + /// compile the source txt, write to the out_path + /// + /// c source code + /// out .so path + /// outPath + public string Compile(string sourcePath, string outPath) + { + var errMsg = new StringBuilder(); + using (var errWriter = new StringWriter(errMsg)) + { + using (var proc = new Process()) + { + proc.StartInfo.FileName = Exe; + proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); + proc.StartInfo.RedirectStandardError = true; + proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); + proc.Start(); + proc.BeginErrorReadLine(); + proc.WaitForExit(); + if (proc.ExitCode != 0) + { + throw new InvalidOperationException(errMsg.ToString()); + } + } + } + return outPath; + } + + /// + /// create the temp dll file and compile source + /// + /// + public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index e92d1f5b4e..bd6dc4c270 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -9,19 +9,21 @@ internal sealed class LinkableFunction : ILinkableFunction { private readonly byte[] _desc; - public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, byte[] text, byte[] desc) + public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, FunctionCSource funcCSource) { Id = id; SourceFunction = sourceFunction; - Text = text; - _desc = desc; - Sections = new LinkedSection[] { new(_desc, ".desc", 0, 8, (uint)_desc.Length) }; + FunctionCSource = funcCSource; + Text = Array.Empty(); + Sections = new LinkedSection[] { }; } public uint Id { get; } public BaseFunction SourceFunction { get; } + public FunctionCSource FunctionCSource { get; } + public byte[] Text { get; } public IEnumerable FunctionRefs => Enumerable.Empty(); diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index 53d8c79b18..8c8f5952a3 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Nncase.Diagnostics; using Nncase.Runtime.StackVM; namespace Nncase.CodeGen.CPU; @@ -26,20 +27,54 @@ public LinkableModule(byte[] rdata, IReadOnlyList functions) public ILinkedModule Link(ILinkContext linkContext) { + var csourcePath = LinkCSources(); + var elfPath = CompileCSource(csourcePath); + var text = File.ReadAllBytes(elfPath); + + if (DumpScope.Current.IsEnabled(DumpFlags.CodeGen)) + { + using (var fs = DumpScope.Current.OpenFile("cpuModule.c")) + { + File.Open(csourcePath, FileMode.Open, FileAccess.Read).CopyTo(fs); + } + } + var linkedFunctions = new List(); - var text = new MemoryStream(); - using (var bw = new BinaryWriter(text, Encoding.UTF8, true)) + foreach (var func in _functions) { - foreach (var func in _functions) + linkedFunctions.Add(new LinkedFunction(func.Id, func.SourceFunction, 0, 0, func.Sections)); + } + + return new LinkedModule(linkedFunctions, text, _rdata); + } + + private string LinkCSources() + { + var path = Path.GetTempFileName(); + using (var fs = File.OpenWrite(path)) + { + using (var writer = new StreamWriter(fs)) { - // FixFunctionRefs(func, linkContext); - bw.AlignPosition(_textAlignment); - var textBegin = bw.Position(); - bw.Write(func.Text); - linkedFunctions.Add(new LinkedFunction(func.Id, func.SourceFunction, (uint)textBegin, (uint)func.Text.Length, func.Sections)); + writer.WriteLine(CSourceBuiltn.Header); + foreach (var func in _functions) + { + writer.WriteLine(func.FunctionCSource.Declaration); + } + + foreach (var func in _functions) + { + writer.WriteLine(func.FunctionCSource.Implementation); + } } } - return new LinkedModule(linkedFunctions, text.ToArray(), _rdata); + return path; } + + private string CompileCSource(string sourcePath) + { + var compiler = new CSourceCompiler(); + return compiler.Compile(sourcePath, Path.GetTempFileName()); + } + } diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs index 4ad317298f..8a300e55bf 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs @@ -15,15 +15,12 @@ internal sealed class LinkedModule : ILinkedModule public LinkedModule(IReadOnlyList functions, byte[] text, byte[] rdata) { Functions = functions; - Sections = new[] { - new LinkedSection(text, ".text", 0, 8, (uint)text.Length), - new LinkedSection(rdata, ".rdata", 0, 8, (uint)rdata.Length), - }; + Sections = new[] { new LinkedSection(text, ".text", 0, 8, (uint)text.Length) }; } - public string ModuleKind => Runtime.CPU.CPURTModule.Kind; + public string ModuleKind => Targets.CPUTarget.Kind; - public uint Version => Runtime.CPU.CPURTModule.Version; + public uint Version => 0; public IReadOnlyList Functions { get; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs index 659a16dc9b..c68b7b16da 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs @@ -24,7 +24,7 @@ public ModuleBuilder(CompileOptions options) public CompileOptions CompileOptions { get; } /// - public string ModuleKind => Runtime.CPU.CPURTModule.Kind; + public string ModuleKind => Targets.CPUTarget.Kind; /// public ILinkableModule Build(IReadOnlyList functions) diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs new file mode 100644 index 0000000000..fa18c5439c --- /dev/null +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUKernelOp.cs @@ -0,0 +1,34 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.IR; +using Nncase.IR.CPU; + +namespace Nncase.Evaluator.CPU; + +/// +/// Evaluator for . +/// +public class CPUKernelOpEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator +{ + /// + public IValue Visit(IEvaluateContext context, CPUKernelOp target) + { + return CompilerServices.EvaluateOp(target.Target, context); + } + + /// + public IRType Visit(ITypeInferenceContext context, CPUKernelOp target) + { + return CompilerServices.InferenceOp(target.Target, context, new()); + } + + /// + public Cost Visit(ICostEvaluateContext context, CPUKernelOp target) + { + return CompilerServices.EvaluateOpCost(target.Target, context); + } +} diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs index 025c6e3b1b..aa01122091 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUModule.cs @@ -13,6 +13,6 @@ internal class CPUModule : IApplicationPart { public void ConfigureServices(IRegistrator registrator) { - registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); } } diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs deleted file mode 100644 index d48a45b8e0..0000000000 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/CPUUnary.cs +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Linq; -using Nncase.CostModel; -using Nncase.IR; -using Nncase.IR.CPU; -using OrtKISharp; - -namespace Nncase.Evaluator.CPU; - -/// -/// Evaluator for . -/// -public class CPUUnaryEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IOpPrinter -{ - /// - public IValue Visit(IEvaluateContext context, CPUUnary unary) - { - var input_tensor = context.GetArgumentValueAsTensor(unary, CPUUnary.Input); - if (input_tensor.Shape.IsScalar) - { - if (input_tensor.ElementType == DataTypes.Int32) - { - return Value.FromTensor(Tensor.FromScalar(Compute_int(input_tensor.ToScalar(), unary.UnaryOp))); - } - else if (input_tensor.ElementType == DataTypes.Float32) - { - return Value.FromTensor(Tensor.FromScalar(Compute_float(input_tensor.ToScalar(), unary.UnaryOp))); - } - } - - var input = context.GetOrtArgumentValue(unary, CPUUnary.Input); - var result = unary.UnaryOp switch - { - UnaryOp.Abs => OrtKI.Abs(input), - UnaryOp.Acos => OrtKI.Acos(input), - UnaryOp.Acosh => OrtKI.Acosh(input), - UnaryOp.Asin => OrtKI.Asin(input), - UnaryOp.Asinh => OrtKI.Asinh(input), - UnaryOp.Ceil => OrtKI.Ceil(input), - UnaryOp.Cos => OrtKI.Cos(input), - UnaryOp.Cosh => OrtKI.Cosh(input), - UnaryOp.Exp => OrtKI.Exp(input), - UnaryOp.Floor => OrtKI.Floor(input), - UnaryOp.Log => OrtKI.Log(input), - UnaryOp.Neg => OrtKI.Neg(input), - UnaryOp.Round => OrtKI.Round(input), - UnaryOp.Rsqrt => OrtKI.Rsqrt(input), - UnaryOp.Sin => OrtKI.Sin(input), - UnaryOp.Sinh => OrtKI.Sinh(input), - UnaryOp.Sign => OrtKI.Sign(input), - UnaryOp.Sqrt => OrtKI.Sqrt(input), - UnaryOp.Square => OrtKI.Square(input), - UnaryOp.Tanh => OrtKI.Tanh(input), - UnaryOp.BitwiseNot => throw new NotSupportedException("NotSupported UnaryOp BitwiseNot"), - UnaryOp.LogicalNot => OrtKI.Not(input), - _ => throw new ArgumentOutOfRangeException(nameof(unary)), - }; - return result.ToValue(); - } - - /// - public IRType Visit(ITypeInferenceContext context, CPUUnary target) - { - var input = context.CheckArgumentType(target, CPUUnary.Input); - return Visit(input); - } - - /// - public Cost Visit(ICostEvaluateContext context, CPUUnary target) - { - var inputType = context.GetArgumentType(target, CPUUnary.Input); - var outputType = context.GetReturnType(); - - return new() - { - [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), - [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), - [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, CostUtility.GetCPUCyclesOfUnary(target.UnaryOp)), - }; - } - - /// - public string Visit(IIRPrinterContext context, CPUUnary target, bool iLmode) - { - var op_str = target.UnaryOp switch - { - UnaryOp.BitwiseNot => "!", - UnaryOp.LogicalNot => "!", - var op => op.ToString(), - }; - if (!iLmode) - { - return $"{op_str}({string.Join(", ", target.Parameters.Select(p => p.Name + ": " + context.GetArgument(target, p).Serialize()))})"; - } - - throw new NotSupportedException("ILmode = true"); - } - - private int Compute_int(int input, UnaryOp op) => op switch - { - UnaryOp.Ceil => input, - UnaryOp.Floor => input, - UnaryOp.Neg => -input, - UnaryOp.Abs => System.Math.Abs(input), - UnaryOp.Square => input * input, - _ => throw new ArgumentOutOfRangeException(nameof(op), $"NotSupported {nameof(op)} For Int"), - }; - - private float Compute_float(float input, UnaryOp op) => op switch - { - UnaryOp.Abs => System.MathF.Abs(input), - UnaryOp.Acos => System.MathF.Acos(input), - UnaryOp.Acosh => System.MathF.Acosh(input), - UnaryOp.Asin => System.MathF.Asin(input), - UnaryOp.Asinh => System.MathF.Asinh(input), - UnaryOp.Ceil => System.MathF.Ceiling(input), - UnaryOp.Cos => System.MathF.Cos(input), - UnaryOp.Cosh => System.MathF.Cosh(input), - UnaryOp.Exp => System.MathF.Exp(input), - UnaryOp.Floor => System.MathF.Floor(input), - UnaryOp.Log => System.MathF.Log(input), - UnaryOp.Neg => -input, - UnaryOp.Round => System.MathF.Round(input), - UnaryOp.Rsqrt => 1.0f / System.MathF.Sqrt(input), - UnaryOp.Sin => System.MathF.Sin(input), - UnaryOp.Sinh => System.MathF.Sinh(input), - UnaryOp.Sign => System.MathF.Sign(input), - UnaryOp.Sqrt => System.MathF.Sqrt(input), - UnaryOp.Square => input * input, - UnaryOp.Tanh => System.MathF.Tanh(input), - _ => throw new ArgumentOutOfRangeException(nameof(op), $"NotSupported {nameof(op)} For Float"), - }; - - private IRType Visit(TensorType input) - { - return input; - } -} diff --git a/modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs b/modules/Nncase.Modules.CPU/IR/CPU/CPUKernelOp.cs similarity index 69% rename from modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs rename to modules/Nncase.Modules.CPU/IR/CPU/CPUKernelOp.cs index 44dbb59fd8..1113d70e55 100644 --- a/modules/Nncase.Modules.CPU/IR/CPU/CPUUnary.cs +++ b/modules/Nncase.Modules.CPU/IR/CPU/CPUKernelOp.cs @@ -12,12 +12,12 @@ namespace Nncase.IR.CPU; [PatternFunctionalGenerator] -public sealed partial class CPUUnary : Op +public sealed partial class CPUKernelOp : Op { /// - /// Gets input. + /// Gets the target. /// - public static readonly ParameterInfo Input = new(typeof(CPUUnary), 0, "input"); + public Op Target { get; } - public UnaryOp UnaryOp { get; } + public override string DisplayProperty() => Target.GetType().Name; } diff --git a/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs b/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs index 9fd63b63a6..52869d0153 100644 --- a/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs +++ b/modules/Nncase.Modules.CPU/IR/CPU/Functional.cs @@ -14,13 +14,13 @@ namespace Nncase.IR.F; public partial class CPU { /// - /// Call unary. + /// Call cpu kernel. /// - /// Unary operator. - /// Source expression. + /// Unary operator. + /// Source inputs. /// Result expression. - public static Call CPUUnary(UnaryOp unaryOp, Expr expr) + public static Call CPUKernel(Op target, params Expr[] inputs) { - return new Call(new CPUUnary(unaryOp), expr); + return new Call(new CPUKernelOp(target), inputs); } } diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs index 5e1ff781b0..4afc56cb6f 100644 --- a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs +++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs @@ -1,7 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -#if false using System; using System.Collections.Generic; using System.Linq; @@ -21,12 +20,10 @@ namespace Nncase.Passes; internal sealed class CPUFusionToTirPass : ModulePass { private readonly TileOptions _tileOptions; - private readonly Dictionary _fusionMacsMap; public CPUFusionToTirPass(TileOptions tileOptions) { _tileOptions = tileOptions; - _fusionMacsMap = new(ReferenceEqualityComparer.Instance); } private IAnalyzerManager AnalyzerManager => CompileSession.GetRequiredService(); @@ -36,50 +33,21 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o { Dictionary fusionConertedCache = new(ReferenceEqualityComparer.Instance); - // convert the fusion as entry. - // for (int i = 0; i < module.Functions.Count; i++) - // { - // if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion) - // { - // TIR.PrimFunction primFunction; - // var visitor = new MultiLayerFusionConverter(_tileOptions); - // primFunction = visitor.VisitToPrimFunc(fusion); - // - // CompilerServices.InferenceType(primFunction); - // fusionConertedCache[fusion] = primFunction; - // module.Replace(i, primFunction); - // } - // } - - // convert the stackvm function call k510 fusion for (int i = 0; i < module.Functions.Count; i++) { - if (module.Functions[i] is Function { ModuleKind: CPUTarget.Kind } func) + if (module.Functions[i] is Function { ModuleKind: string kind } func && kind == Callable.StackVMModuleKind) { var analysis = new Dictionary { [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(func), }; var rewriter = new DataFlowMergeRewriter(); var fusionCheckCache = new Dictionary(ReferenceEqualityComparer.Instance); - // var post = (Function)rewriter.Rewrite(func, new Mutators.IMergeRewriteRule[] { new GNNESameInputFusionMergeRule(), }, (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); - - // if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) - // { - // DumpScope.Current.DumpDotIR(post, "MultiLayer"); - // } - // post = (Function)rewriter.Rewrite( - // post, - // new Mutators.IMergeRewriteRule[] { - // new GNNESameInputFusionMergeRule(), - // }, - // (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), - // new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); + var post = (Function)rewriter.Rewrite( + func, + new Mutators.IMergeRewriteRule[] { new CPUSameInputFusionMergeRule() }, + (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), + new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); - // if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) - // { - // DumpScope.Current.DumpDotIR(post, "TwoLayer"); - // } - // var post = func; - var mutator = new CheckedConvertMutator(fusionConertedCache, _fusionMacsMap, fusionCheckCache, _tileOptions, options); + var mutator = new CheckedConvertMutator(fusionConertedCache, fusionCheckCache, _tileOptions, options); var new_func = (Function)mutator.Rewrite(post); CompilerServices.InferenceType(new_func); if (mutator.IsMutated) @@ -89,7 +57,6 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o } } - // add all prim func. foreach (var item in fusionConertedCache.Values) { if (item is PrimFunctionWrapper wrapper) @@ -101,20 +68,4 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o return Task.FromResult(module); } - - protected override async Task OnPassEndAsync(IRModule post, RunPassContext context) - { - await base.OnPassEndAsync(post, context); - if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) - { - using var writer = new StreamWriter(DumpScope.Current.OpenFile("mac.csv")); - foreach (var (fusion, mac) in _fusionMacsMap) - { - writer.WriteLine($"mac: {fusion.Name},{mac}"); - } - } - - _fusionMacsMap.Clear(); - } } -#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs index bb77cc1354..f0cbd5a862 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs @@ -28,6 +28,6 @@ public partial class LowerUnary : RewriteRule private Expr GetReplace(Unary unary, Expr input) { - return CPUUnary(unary.UnaryOp, input); + return CPUKernel(unary, input); } } diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs index fb1bd25a22..d7644bf2b5 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs @@ -37,7 +37,7 @@ internal partial class CPUSingleInputFusion : FusionMaker } } -internal sealed class CPUUnaryFusion : CPUSingleInputFusion +internal sealed class CPUFusion : CPUSingleInputFusion { - public override string Name => "Unary"; + public override string Name => nameof(CPUFusion); } diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs deleted file mode 100644 index 2c85b4deca..0000000000 --- a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionConverter.cs +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -#if true -using System.Reactive; -using System.Runtime.CompilerServices; -using NetFabric.Hyperlinq; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.IR.Buffers; -using Nncase.IR.CPU; -using Nncase.IR.F; -using Nncase.IR.Math; -using Nncase.Passes.Mutators; -using Nncase.PatternMatch; -using Nncase.Schedule; -using Nncase.Targets; -using Nncase.TIR; -using Nncase.TIR.Builders; -using Buffer = Nncase.TIR.Buffer; -using MathF = Nncase.IR.F.Math; -using Range = Nncase.TIR.Range; -using Tuple = Nncase.IR.Tuple; - -namespace Nncase.Passes.Tile; - -/// -/// name 分配器. -/// -internal sealed class NameAllocator -{ - public Dictionary NamePool { get; } = new(); - - public string Get(string name) - { - if (!NamePool.TryGetValue(name, out var count)) - { - count = 0; - } - - NamePool[name] = count + 1; - return count == 0 ? name : $"{name}_{count}"; - } -} - -internal class CPUFusionConverter -{ - public NameAllocator NameAllocator { get; } = new(); - - /// - /// Gets tile size 的变量. - /// - public List TileSizeVars { get; } = new(); - - /// - /// Gets loop 变量. - /// - // public List LoopVars { get; } = new(); - - /// - /// Gets loop domains. - /// - public List LoopDomains { get; } = new(); - - /// - /// Gets nested loops. - /// - // public List> NestedLoops { get; } = new(); - - public TileOptions TileOptions { get; protected set; } = null!; - - /// - /// Gets or sets 总的loop count. - /// / - public virtual Expr LoopCount { get; protected set; } - - /// - /// Gets or sets ping pong 外层的tiling. - /// - public virtual Expr LoopCountOuter { get; protected set; } - - /// - /// Gets or sets ping pong 内侧的tiling. - /// - public virtual Expr LoopCountInner { get; protected set; } - - /// - /// Gets or sets 当前的fusion. - /// - public virtual Fusion CurrentFusion { get; protected set; } - - public virtual PrimFunction BuildPrimFunc(Fusion fusion) - { - // TODO: buffer顺序可能需要调整以保持原图的顺序 - var primFuncBuilder = T.PrimFunc(CurrentFusion.Name, CPUTarget.Kind, _ifBufferMap.Values.Union(_ofBufferMap.Values).Select(b => (PhysicalBuffer)b).ToArray()); - return primFuncBuilder.Build(); - } - - public virtual Expr Visit(Expr root) - { - return root switch - { - Call call => (call.Target switch - { - CPUUnary op => LowerCPUUnary(call, op), - _ => throw new NotSupportedException(), - }).Build(), - _ => T.Nop(), - }; - } - - protected virtual ISequentialBuilder LowerCPUUnary(Call call, CPUUnary op) - { - var prefix = NameAllocator.Get(nameof(CPUUnary)); - var inputCall = call[CPUUnary.Input]; - T.PhysicalBuffer(inputCall.CheckedDataType, MemoryLocation.Input, inputCall.CheckedShape, out var ddrIf); - T.PhysicalBuffer(call.CheckedDataType, MemoryLocation.Output, call.CheckedShape.ToValueArray(), out var ddrOf); - _ifBufferMap.Add(call, ddrIf); - _ofBufferMap.Add(call, ddrOf); - - List LoopVars = new(); - List> NestedLoops = new(); - List LoopDomains = call.CheckedShape.Select(s => new Range(0, 1, s.FixedValue)).ToList(); - - var seq = T.Sequential().Body(Visit(inputCall)); - - for (int i = 0; i < call.CheckedShape.Rank; i++) - { - NestedLoops.Add(T.ForLoop(out var loopVar, LoopDomains[i], LoopMode.Unrolled, $"loop_var_{i}")); - LoopVars.Add(loopVar); - } - - NestedLoops[^1].Body( - op.UnaryOp switch - { - // TODO: body的实现 - UnaryOp.Abs => T.Nop(), - _ => throw new NotSupportedException(), - }); - - seq.Body(NestedLoops[0].Body()); - return seq; - } - - private readonly Dictionary _ifBufferMap = new(ReferenceEqualityComparer.Instance); - private readonly Dictionary _ofBufferMap = new(ReferenceEqualityComparer.Instance); -} -#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs index 3cd7ab7c9d..d11af85bb5 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs @@ -1,6 +1,5 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -#if false using System.Runtime.CompilerServices; using Nncase.IR; using Nncase.PatternMatch; @@ -56,45 +55,38 @@ public CPUFusionGroupMutator( /// public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet candidateFusions) { - // note the gnne activate must be first layer. - // if (mergedFusion.Body is Call { Target: IR.K510.GNNEStore } st_call && - // st_call[IR.K510.GNNEStore.Input] is Call { Target: IR.K510.GNNEActivation }) + + // var checker = (IFusionChecker)Activator.CreateInstance(typeof(T), new object[] { _tileOptions })!; + // var ret = checker.Check(mergedFusion, PassOptions); + // if (ret) // { - // return false; + // FusioncheckerCache.Add(mergedFusion, checker); + // foreach (var cand in candidateFusions) + // { // release the merged fusion. + // FusioncheckerCache.Remove(cand); + // } // } - var checker = (IFusionChecker)Activator.CreateInstance(typeof(T), new object[] { _tileOptions })!; - var ret = checker.Check(mergedFusion, PassOptions); - if (ret) - { - FusioncheckerCache.Add(mergedFusion, checker); - foreach (var cand in candidateFusions) - { // release the merged fusion. - FusioncheckerCache.Remove(cand); - } - } - - return ret; + // return ret; + return false; } public override Expr MergedFusionRewriteCallBack(Expr mergedFusionBody) { - return CompilerServices.Rewrite(mergedFusionBody, new[] { new Rules.GNNE.Opt.FoldLoadStore() }, new()); + return mergedFusionBody; } } internal sealed class CheckedConvertMutator : ExprRewriter { private readonly Dictionary _fusionConertedCache; - private readonly IDictionary _fusionMacsMap; private readonly IReadOnlyDictionary _fusionCheckerCache; private readonly TileOptions _tileOptions; private readonly RunPassContext _passOptions; - public CheckedConvertMutator(Dictionary fusion_converted_cache, Dictionary fusionMacsMap, IReadOnlyDictionary fusionchecker_cache, TileOptions tileOptions, RunPassContext passOptions) + public CheckedConvertMutator(Dictionary fusion_converted_cache, IReadOnlyDictionary fusionchecker_cache, TileOptions tileOptions, RunPassContext passOptions) { _fusionConertedCache = fusion_converted_cache; - _fusionMacsMap = fusionMacsMap; _fusionCheckerCache = fusionchecker_cache; _tileOptions = tileOptions; _passOptions = passOptions; @@ -114,23 +106,8 @@ protected override Expr RewriteLeafFusion(Fusion expr) } else { - // if (CompilerServices.TryMatchRoot(fusion, Conv2DFusionConverter.Conv2DFusionPattern, out var matchResult)) - // { - // prim_func = Conv2DFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); - // } - // else if (CompilerServices.TryMatchRoot(fusion, Conv2DTransposeFusionConverter.Conv2DFusionPattern, out matchResult)) - // { - // prim_func = Conv2DTransposeFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); - // } - // else if (!_tileOptions.ForceMultiLayer && CompilerServices.TryMatchRoot(fusion, LSTMFusionConverter.FusionPattern, out matchResult)) - // { - // prim_func = LSTMFusionConverter.VisitToPrimFunc(_tileOptions, fusion, matchResult, out _, out _); - // } - // else - // { - // var visitor = new MultiLayerFusionConverter(_tileOptions); - // prim_func = visitor.VisitToPrimFunc(fusion); - // } + var converter = new SingleCPUFusionConverter(); + prim_func = converter.Visit(fusion); } BaseFunction? convert_func = prim_func; @@ -153,7 +130,7 @@ protected override Expr RewriteLeafCall(Call expr) int param_count = 0; foreach (var b in prim_func.Parameters) { - if (b.MemLocation == Schedule.MemoryLocation.Input) + if (b.MemLocation == TIR.MemoryLocation.Input) { if (is_input) { @@ -184,5 +161,3 @@ protected override Expr RewriteLeafCall(Call expr) return expr; } } - -#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs deleted file mode 100644 index f205692a57..0000000000 --- a/modules/Nncase.Modules.CPU/Passes/Tile/LayerFusionConverter.cs +++ /dev/null @@ -1,1264 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -#if false -using System.Reactive; -using System.Runtime.CompilerServices; -using NetFabric.Hyperlinq; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.IR.Buffers; -using Nncase.IR.F; -using Nncase.IR.K510; -using Nncase.IR.Math; -using Nncase.Passes.BufferSchedule; -using Nncase.Passes.Mutators; -using Nncase.Runtime.K510; -using Nncase.Schedule; -using Nncase.TIR; -using Nncase.TIR.Builders; -using Nncase.TIR.K510; -using Nncase.TIR.K510.Builders; -using Nncase.TIR.K510.Instructions; -using Buffer = Nncase.TIR.Buffer; -using MathF = Nncase.IR.F.Math; -using Range = Nncase.TIR.Range; -using Tuple = Nncase.IR.Tuple; - -namespace Nncase.Passes.Tile; - -public sealed class BufferRegionView -{ - private Expr? _cache; - - private Expr[]? _condtion_buffer_regions; - - private Expr[]? _region_size; - - public BufferRegionView(IEnumerable buffers, IEnumerable bounds, IEnumerable region, IndexMapKey key) - : this(buffers, bounds, region, key, 0, null) - { - } - - public BufferRegionView(IEnumerable buffers, IEnumerable bounds, IEnumerable region, IndexMapKey key, Expr loopCount, int? promote) - { - Buffers = buffers.ToArray(); - Region = region.ToArray(); - LoopCount = loopCount; - Parent = null; - Key = key; - Promote = promote; - Bounds = bounds.ToArray(); - } - - public IndexMapKey Key { get; } - - /// - /// Gets 记录他的loop count. - /// - public Expr LoopCount { get; } - - public int? Promote { get; } - - public IReadOnlyList Bounds { get; } - - public IReadOnlyList Buffers { get; } - - public IReadOnlyList Region { get; } - - public BufferRegionView? Parent { get; set; } - - public ReadOnlySpan Dimensions => Buffers[0].Dimensions; - - /// - /// Gets 返回带有condition的buffer region的表达式. - /// - public IReadOnlyList BufferRegions - { - get - { - _condtion_buffer_regions ??= Buffers.Count == 0 ? Array.Empty() : Buffers.Select(b => new BufferRegion(b, Region.ToArray())).ToArray(); - return _condtion_buffer_regions; - } - } - - public BufferRegionView this[params Range[] ranges] - { - get => new(Buffers, Bounds, Region.Zip(ranges).Select(tp => tp.Second.Equals(Range.All) ? tp.First : tp.Second.Stop switch { Call { Target: Unary { UnaryOp: UnaryOp.Neg } } => throw new NotSupportedException("Neg Region!"), _ => tp.Second, }), Key, LoopCount, Promote) { Parent = Parent is null ? this : Parent, }; // if stop is neg, add the shape, else return the origin range. - } - - /// - /// convert the BufferRegionView to expr. - /// - /// 当开启ping pong时,如果 - /// - /// - /// view. - public static implicit operator Expr(BufferRegionView view) - { - if (view._cache is not null) - { - return view._cache; - } - - Expr expr; - if (view.Buffers.Count == 0) - { - expr = IR.None.Default; - } - else if (view.Buffers.Count == 1) - { - expr = view.BufferRegions[0]; - } - else if (view.Buffers.Count >= 2) - { - expr = new Tuple(view.BufferRegions.ToArray())[view.LoopCount % view.Buffers.Count]; - } - else - { - throw new NotSupportedException(); - } - - view._cache = expr; - return view._cache; - } - - public static BufferRegionView None(IndexMapKey key) => new(Array.Empty(), new IRArray(), new IRArray(), key); - - public ReadOnlySpan RegionSize() - { - _region_size ??= Region.AsValueEnumerable().Select(r => r.Stop - r.Start).ToArray(); - return _region_size; - } - - public Expr RegionSize(int i) => RegionSize()[i]; -} - -/// -/// name 分配器. -/// -internal sealed class NameAllocator -{ - public Dictionary NamePool { get; } = new(); - - public string Get(string name) - { - if (!NamePool.TryGetValue(name, out var count)) - { - count = 0; - } - - NamePool[name] = count + 1; - return count == 0 ? name : $"{name}_{count}"; - } -} - -/// -/// buffer region view. -/// -internal sealed class LogiclPrimFuncCloner : ExprCloner -{ - protected override Expr VisitLeafLogicalBuffer(LogicalBuffer buffer, Unit context) - { - return buffer; - } - - protected override Expr VisitLeafPhysicalBuffer(PhysicalBuffer buffer, Unit context) - { - return buffer; - } - - protected override Expr VisitVar(Var var, Unit context) - { - return var; - } -} - -internal sealed record ReIndexCacheKey(IBoundsInferGraph BoundsInferGraph, IndexMapKey From, IndexMapKey To, IRArray FromRegion, int? Promote) -{ -} - -internal abstract class LayerFusionConverter -{ - public NameAllocator NameAllocator { get; } = new(); - - /// - /// Gets map the graph expression and it's bufferRegion. - /// - public Dictionary KeyToViewMap { get; } = new(); - - /// - /// Gets because of the index map can't create by var, so need other map save the relationship. - /// - public Dictionary VarToKeyMap { get; } = new(ReferenceEqualityComparer.Instance); - - /// - /// Gets tile size 的变量. - /// - public List TileSizeVars { get; } = new(); - - /// - /// Gets loop 变量. - /// - public List LoopVars { get; } = new(); - - /// - /// Gets loop domains. - /// - public List LoopDomains { get; } = new(); - - /// - /// Gets 所有的blocks - /// 最终是: - /// mainBlock - /// loop n - /// block n - /// loop c - /// block c - /// . - /// . - /// - public List NestedBlocks { get; } = new(); - - /// - /// Gets nested loops. - /// - public List> NestedLoops { get; } = new(); - - public TileOptions TileOptions { get; protected set; } = null!; - - /// - /// Gets or sets 默认的bounds infer graph. - /// - public abstract IBoundsInferGraph BoundsInferGraph { get; protected set; } - - /// - /// Gets or sets 总的loop count. - /// / - public abstract Expr LoopCount { get; protected set; } - - /// - /// Gets or sets ping pong 外层的tiling. - /// - public abstract Expr LoopCountOuter { get; protected set; } - - /// - /// Gets or sets ping pong 内侧的tiling. - /// - public abstract Expr LoopCountInner { get; protected set; } - - /// - /// Gets or sets 当前的fusion. - /// - public abstract Fusion CurrentFusion { get; protected set; } - - /// - /// Gets glb reindex cache. - /// - protected Dictionary ToRegion, IReadOnlyList<(Expr Before, Expr After)> Paddings)> GlbReindexCache { get; } = new(); - - public abstract Expr Visit(Fusion fusion); - - public virtual PrimFunction BuildLogicalPrimFunc(Expr bodySeq) - { - var inputs_buffer = CurrentFusion.Parameters.ToArray().Select(p => (PhysicalBuffer)KeyToViewMap[VarToKeyMap[p]].Buffers[0]); - var primFuncBuilder = T.PrimFunc(CurrentFusion.Name, K510RTModule.Kind, inputs_buffer.Concat(new[] { (PhysicalBuffer)KeyToViewMap[(Call)CurrentFusion.Body].Buffers[0] }).ToArray()); - - NestedBlocks[^1].Body(bodySeq); - primFuncBuilder.Body( - I.MmuConf(0, 0, MMU_CONF_WIDTH._8, 0, ExtCompilerServices.Env.GlbDepth), // 把整个glb当作连续内存使用. - NestedBlocks[0], - I.Fence()); - - var logicalPrimFunc = primFuncBuilder.Build(); - logicalPrimFunc = (PrimFunction)new Mutators.SimplifyBounds().Rewrite(logicalPrimFunc); - logicalPrimFunc.InferenceType(); - GlbReindexCache.Clear(); - return logicalPrimFunc; - } - - public abstract bool BalanceTileSize(int[] tile_size, Segment[] search_spaces); - - public virtual PrimFunction BuildPhysicalPrimFunc(int[] final_tile_size, IReadOnlyDictionary sched_result, PrimFunction logicalPrimFunc) - { - var physicalizer = new BufferPhysicalizer(final_tile_size, sched_result, TileSizeVars); - var physicalPrimFunc = (PrimFunction)physicalizer.Rewrite(logicalPrimFunc); - return physicalPrimFunc; - } - - public virtual int[] SearchTileSize(ISearchTileGenerator tile_generator, PrimFunction logicalPrimFunc, bool multi_workers, bool hasResult, out ScheduledResponse response) - { - AllocationCache response_cache = new(); - bool schedule_status = false; - int[] final_tile = Array.Empty(); - - while (true) - { - var next_tile = tile_generator.GetNextTile(schedule_status).ToArray(); - if (next_tile.Length == 0) - { - break; - } - - schedule_status = TryScheduleNextTileSize(next_tile, logicalPrimFunc, response_cache, multi_workers, hasResult); - if (schedule_status) - { - final_tile = next_tile; - response_cache.CheckIn(); - } - } - - if (!final_tile.Any()) - { - response = new(new Dictionary(), null!, null!, logicalPrimFunc, null!, 0, false); - return final_tile; - } - - // take back last success allocation result - response = response_cache.GetLastSuccess(final_tile); - return final_tile; - } - - public virtual Expr Visit(IndexMapKey mapKey, string prefix, int? promote = null) - { - prefix = prefix + mapKey.Prefix; - return mapKey.Expr switch - { - Call call => (call.Target switch - { - GNNELoad op => LowerGnneLoad(mapKey, call, op, NameAllocator.Get(nameof(GNNELoad)), prefix, promote, true), - GNNEStore op => LowerGnneStore(call, op, NameAllocator.Get(nameof(GNNEStore)), prefix), - GNNEConv2D op => LowerGnneConv2D(mapKey, call, op, NameAllocator.Get(nameof(GNNEConv2D)), prefix), - GNNEConv2DTranspose op => LowerGnneConv2DTranspose(mapKey, call, op, NameAllocator.Get(nameof(GNNEConv2D)), prefix), - GNNEReduce op => LowerGnneReduce(mapKey, call, op, NameAllocator.Get(nameof(GNNEReduce)), prefix), - GNNEMeshNet op => LowerGnneMeshNet(mapKey, call, op, NameAllocator.Get(nameof(GNNEMeshNet)), prefix), - GNNETranspose op => LowerGnneTranspose(mapKey, call, op, NameAllocator.Get(nameof(GNNETranspose)), prefix), - GNNEActivation op => LowerGnneActivation(mapKey, call, op, NameAllocator.Get(nameof(GNNEActivation)), prefix), - GNNEPdpReduce op => LowerGnnePdpReduce(mapKey, call, op, NameAllocator.Get(nameof(GNNEPdpReduce)), prefix), - GNNECrop op => LowerGnneCrop(mapKey, call, op, NameAllocator.Get(nameof(GNNECrop)), prefix), - Uninitialized => T.Sequential(), - _ => throw new NotSupportedException(), - }).Build(), - _ => T.Nop(), - }; - } - - /// - /// 子偏移输入到bounds infer后反推子偏移. - /// - /// from. - /// to. - /// sub_paddings. - /// the partial compute funcs. - /// . - public virtual BufferRegionView GlbReIndex(BufferRegionView from, BufferRegionView to, out IReadOnlyList<(Expr Before, Expr After)> sub_paddings, params (int Axis, Func CallBack)[] partialFuncs) - { - var key = new ReIndexCacheKey(BoundsInferGraph, from.Key, to.Key, new(from.Region), to.Promote); - IReadOnlyList to_region; - if (partialFuncs.Length == 0 && GlbReindexCache.TryGetValue(key, out var result)) - { - to_region = result.ToRegion; - sub_paddings = result.Paddings; - } - else - { - to_region = TileUtilities.GetRelativeNoPadBounds(BoundsInferGraph, from.Key, to.Key, from.Region, to.Promote, partialFuncs, out sub_paddings); - if (partialFuncs.Length == 0) - { - GlbReindexCache.Add(key, (new(to_region), sub_paddings)); - } - } - - return to[to_region.ToArray()]; - } - - protected virtual bool TryScheduleNextTileSize(int[] next_tile_size, PrimFunction logicalPrimFunc, AllocationCache response_cache, bool multi_workers, bool hasResult) - { - // 1. make one tile feed dict - var feed_dict = next_tile_size.Select((s, i) => - new[] { (LoopVars[i], (IValue)Value.FromTensor(Tensor.FromScalar(0))), - (TileSizeVars[i], (IValue)Value.FromTensor(Tensor.FromScalar(s))), }). - SelectMany(i => i). - ToDictionary(kv => kv.Item1, kv => kv.Item2); - var sched_candidate = new Dictionary(ReferenceEqualityComparer.Instance); - - // 2. folding the tileblock op to the block - PrimFunction new_logical_primfunc; - using (var dumpScope = new DumpScope(NullDumpper.Instance)) - { - var pass = new PrimFuncPass { Name = "FoldingTileBlock" }; - pass.Add(feed_dict); - pass.Add(); - pass.Add(); - var task = pass.RunAsync(new LogiclPrimFuncCloner().Clone(logicalPrimFunc, default), new()); - task.Wait(); - new_logical_primfunc = task.Result; - } - - BufferScheduler bufferScheduler = new(new_logical_primfunc); - - // 3. clloction buffers - bufferScheduler.LifeTimeAnalysis(); - - // compute the size in bytes - foreach (var buffer in bufferScheduler.RecordBuffers) - { - var dimensions = buffer.Dimensions.ToArray().Select(d => d.Evaluate(feed_dict).AsTensor().ToScalar()).ToArray(); - var strides = TensorUtilities.GetStrides(dimensions); - var glb_strides = strides.Select(s => s * buffer.ElemType.SizeInBytes).ToArray(); - - if (bufferScheduler.InnerConstraints[buffer] == ConstraintsMode.Channel && - - // 当load psum的时候,如果shape过小, 那么不额外添加stride. - !(buffer.Name.Split(".").Last().StartsWith(GNNEConv2D.PSum.Name) && - dimensions[2] * dimensions[3] < 14 * 14)) - { - glb_strides = TileUtilities.PaddingAvoidConflict(dimensions, glb_strides, 1); - strides = glb_strides.Select(s => - { - if (s % buffer.ElemType.SizeInBytes != 0) - { - throw new NotSupportedException(); - } - - return s / buffer.ElemType.SizeInBytes; - }).ToArray(); - } - - var size_n_byte = dimensions[0] * glb_strides[0]; - - // todo 可以不用align到一整行, 到一个bank即可. - var glb_size = TileUtilities.AlignBy(size_n_byte, ExtCompilerServices.Env.GlbBankWidth * ExtCompilerServices.Env.GlbWidth); - var physical_candidate = new PhysicalBuffer(buffer.Name, buffer.ElemType, buffer.MemLocation, dimensions, strides, start: 0, size: glb_size); - sched_candidate.Add(buffer, physical_candidate); - } - - var respose = bufferScheduler.Schedule(sched_candidate, multi_workers, hasResult); - response_cache.Add(next_tile_size, respose); - return respose.Success; - } - - /// - /// 申请 buffer. - /// NOTE 会自动添加到buffer map, 同时会记录他ddr 上的padding到字典中. - /// 如果给定 ddr buf region, 那么默认glb buffer region则是通过ddr buffer load 进来的,此时glb buffer的region是减去过padding的. - /// 如果promote到对应的循环后,那么申请buffer的时候在promote内部的循环都应该被调整到最大值. - /// - /// mapKey. - /// region. - /// 开启ping pong就会多开一块相同的buffer. - /// 如果promote为int,那么就会提升buffer到指定循环, 为-1那么就是整个计算块, 会忽略ping pong. - /// specificLoopBounds. - /// name. - /// . - /// NotSupportedException. - /// System.ArgumentOutOfRangeException. - protected virtual Expr GetBufferRegion(IndexMapKey mapKey, out BufferRegionView region, bool ping_pong = false, int? promote = null, Dictionary>? specificLoopBounds = null, [CallerArgumentExpression("region")] string name = "region") - { - if (name.StartsWith("var ")) - { - name = name[4..]; - } - - if (KeyToViewMap.ContainsKey(mapKey)) - { - region = KeyToViewMap[mapKey]; - return T.Nop(); - } - - name = NameAllocator.Get(name); - switch (mapKey.Expr) - { - case TensorConst con: - { - IEnumerable bounds; - IEnumerable clampedBounds; - if (promote is int promoteInt) - { - if (promoteInt == -1) - { - clampedBounds = mapKey.Expr.CheckedShape.Select(s => new Range(0, s.FixedValue, 1)); - bounds = BoundsInferGraph[mapKey].Bounds; - } - else - { - if (specificLoopBounds is null || !specificLoopBounds.TryGetValue(mapKey, out var newBounds)) - { - newBounds = K510TIRExtensions.PromotedBounds(promoteInt, BoundsInferGraph, mapKey, LoopVars, LoopDomains).ToList(); - } - - bounds = newBounds; - clampedBounds = TIRUtilities.ClampBounds(newBounds, mapKey.Expr.CheckedShape); - } - } - else - { - bounds = BoundsInferGraph[mapKey].Bounds; - clampedBounds = BoundsInferGraph[mapKey].ClampedBounds; - } - - T.ConstBuffer(con, out var ddr_buffer, name); - if (ping_pong) - { - throw new NotSupportedException(); - } - - region = new BufferRegionView(new[] { ddr_buffer }, bounds, clampedBounds, mapKey); - break; - } - - case Call call: - { - // 1. 对于glb buffer来说, 他的总大小要跟着申请buffer维度来变化. - // note 实际上对于 - List bounds; - Expr loopCount; - if (promote is int promoteInt) - { - if (promoteInt == -1) - { - bounds = mapKey.Expr.CheckedShape.Select(s => new Range(0, s.FixedValue, 1)).ToList(); - loopCount = 0; - } - else - { - if (specificLoopBounds is null || !specificLoopBounds.TryGetValue(mapKey, out var newBounds)) - { - newBounds = K510TIRExtensions.PromotedBounds(promoteInt, BoundsInferGraph, mapKey, LoopVars, LoopDomains).ToList(); - } - - bounds = newBounds; - loopCount = K510TIRExtensions.PromotedLoopCount(promoteInt, LoopVars, LoopDomains); - } - } - else - { - bounds = BoundsInferGraph[mapKey].Bounds.ToList(); - loopCount = LoopCount; - } - - // note 这里的bounds实际上会因为输入不同的var而被改变, 所以后面要获取dimension的地方需要注意. - var dimensions = bounds.Select(r => r.Stop - r.Start).Select((b, i) => MathF.Min(b, call.CheckedShape[i].FixedValue)).ToArray(); - List glb_buffers = new(); - if (ping_pong) - { - for (int i = 0; i < TileOptions.PingPongNum; i++) - { - glb_buffers.Add(new LogicalBuffer(name + $"(p{i})", call.CheckedDataType, MemoryLocation.L2Data, dimensions)); - } - } - else - { - glb_buffers.Add(new LogicalBuffer(name, call.CheckedDataType, MemoryLocation.L2Data, dimensions)); - } - - // 对于glb_buffer来说, 默认region 从0 开始, 但是要减去输入ddr index 的padding. - var noPadBounds = TIRUtilities.ComputeNoPadBounds(bounds, TIRUtilities.ComputePaddings(bounds, mapKey.Expr.CheckedShape)); - region = new BufferRegionView(glb_buffers, bounds, noPadBounds, mapKey, loopCount, promote); - break; - } - - case Var v: - { - // the different mapkey will point to same var: add(v,conv(v)) - if (!VarToKeyMap.TryGetValue(v, out var old_map_key)) - { - T.PhysicalBuffer(v.CheckedDataType, MemoryLocation.Input, v.CheckedShape.ToValueArray(), out var ddr_buffer, name); - IEnumerable clampedBounds; - IReadOnlyList bounds; - if (promote is int promoteInt) - { - if (promoteInt != -1) - { - if (specificLoopBounds is null || !specificLoopBounds.TryGetValue(mapKey, out var newBounds)) - { - newBounds = K510TIRExtensions.PromotedBounds(promoteInt, BoundsInferGraph, mapKey, LoopVars, LoopDomains).ToList(); - } - - bounds = newBounds; - } - else - { - bounds = mapKey.Expr.CheckedShape.Select(s => new Range(0, s.FixedValue, 1)).ToList(); - } - - clampedBounds = TIRUtilities.ClampBounds(bounds, mapKey.Expr.CheckedShape); - } - else - { - bounds = BoundsInferGraph[mapKey].Bounds; - clampedBounds = BoundsInferGraph[mapKey].ClampedBounds; - } - - if (ping_pong) - { - throw new NotSupportedException(); - } - - region = new BufferRegionView(new[] { ddr_buffer }, bounds, clampedBounds, mapKey); - VarToKeyMap.Add(v, mapKey); - } - else - { - region = KeyToViewMap[old_map_key]; - } - - break; - } - - case None none: - region = BufferRegionView.None(mapKey); - break; - default: - throw new NotSupportedException(); - } - - KeyToViewMap.Add(mapKey, region); - return T.Nop(); - } - - /// - /// promote 的逻辑, 根据值选择移动当前的buffer开在哪个循环. - /// -1 表示在所有循环之外 - /// 0 表示在N循环内 - /// 3 表示在W循环内. - /// - /// - /// 上一级传入的key. - /// call. - /// op. - /// block_name. - /// prefix. - /// promote. - /// is enable soft pipe line. - /// . - protected virtual ISequentialBuilder LowerGnneLoad(IndexMapKey parentKey, Call call, GNNELoad op, string block_name, string prefix, int? promote, bool softPipeLine) - { - var call_input = IndexMapKey.Create(call, GNNELoad.Input); - var call_deq = IndexMapKey.Create(call, GNNELoad.DeqParams); - - var seq = T.Sequential().Body( - Visit(call_deq, prefix, promote), - GetBufferRegion(call_input, out var ddr_ld_input, name: prefix + "." + TileNames.DdrInput, promote: promote), // loadif 的输入可能来自于const或输入 - GetBufferRegion(call_deq, out var glb_ld_qarg_input), // 只有promote到n循环外时,才不进行ping pong. - GetBufferRegion(parentKey, out var glb_ld_output, promote == -1 ? false : TileOptions.PingPong, name: prefix, promote: promote)); // load if 要用的glb buffer - - var block = EAction.TileBlock(block_name). - Alloc(promote is null ? glb_ld_output.Buffers : None.Default). - Reads(ddr_ld_input.BufferRegions, glb_ld_qarg_input.BufferRegions). - Writes(glb_ld_output.BufferRegions). - Predicate(true).// todo 这里先不做局部加载, 后面再实现 - Body(// promote这里load用的是full region, 但是在字典中存的还应该是partial的, 因为后面是每个glb的tile在使用. - softPipeLine ? - (TileOptions.PingPong & (promote != -1) ? K510.PingPongSlot(block_name, glb_ld_output.LoopCount / TileOptions.PingPongNum, glb_ld_output.LoopCount % TileOptions.PingPongNum) : T.Nop()) : - T.Nop(), - EAction.LoadT(ddr_ld_input, glb_ld_output, glb_ld_qarg_input, op.DeqAxis)); - - if (promote is int promoteIndex) - { - // 如果promote, 那么在这个循环的所有block外执行 - NestedBlocks[promoteIndex + 1].Init(block); - NestedBlocks[promoteIndex + 1].Alloc(glb_ld_output.Buffers); - } - else - { - seq.Body(block); - } - - return seq; - } - - protected virtual ISequentialBuilder LowerGnneMeshNet(IndexMapKey parentKey, Call call, GNNEMeshNet target, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNEMeshNet)); - var call_in_a = IndexMapKey.Create(call, GNNEMeshNet.InputA); - var call_in_b = IndexMapKey.Create(call, GNNEMeshNet.InputB); - var call_in_seg0 = IndexMapKey.Create(call, GNNEMeshNet.SegFittingParam0); - var call_in_seg1 = IndexMapKey.Create(call, GNNEMeshNet.SegFittingParam1); - var seq = T.Sequential().Body( - Visit(call_in_a, prefix), - Visit(call_in_b, prefix), - GetBufferRegion(call_in_a, out var meshnet_input_a), - GetBufferRegion(call_in_b, out var meshnet_input_b), - GetBufferRegion(call_in_seg0, out var meshnet_input_seg0), - GetBufferRegion(call_in_seg1, out var meshnet_input_seg1), - GetBufferRegion(parentKey, out var meshnet_output, TileOptions.PingPong, name: prefix), - EAction.TileBlock(block_name). - Alloc(meshnet_output.Buffers). - Reads( - meshnet_input_a.BufferRegions, - meshnet_input_b.BufferRegions, - meshnet_input_seg0.BufferRegions, - meshnet_input_seg1.BufferRegions). - Body( - EAction.MeshNetCompute( - (Fusion)call[GNNEMeshNet.MeshFunc], - meshnet_input_a, - meshnet_input_b, - meshnet_input_seg0, - meshnet_input_seg1, - meshnet_output))); - - if (!(call[GNNEMeshNet.InputB] is None && call[GNNEMeshNet.NewShape] is None && call[GNNEMeshNet.SegFittingParam0] is None && call[GNNEMeshNet.SegFittingParam1] is None && !TileUtilities.MeshFuncHasConstants((Fusion)call[GNNEMeshNet.MeshFunc]))) - { - foreach (var item in meshnet_output.Buffers) - { - item.Metadata = new TileMetadata() { StrideByShape = true }; - } - } - - return seq; - } - - protected virtual ISequentialBuilder LowerGnneStore(Call call, GNNEStore op, string block_name, string prefix, bool promoteQarg = true) - { - prefix = NameAllocator.Get(nameof(GNNEStore)); - var cropPadding = ((TensorConst)call[GNNEStore.CropPadding]).Value.Cast(); - var channel = call.CheckedShape[1].FixedValue; - bool is_quant_by_channel = false; - if (call[GNNEStore.QuantParams] is Call { Target: GNNELoad } l_qarg && l_qarg[GNNELoad.Input] is TensorConst qarg) - { - _ = qarg.Value.Cast(); - if (qarg[0] != qarg[channel - 1]) - { - is_quant_by_channel = true; - } - } - - var outputShape = call.CheckedShape.ToValueArray(); - T.PhysicalBuffer(call.CheckedDataType, MemoryLocation.Output, outputShape, out var ddr_st_buffer, name: prefix + ".ddr_buffer"); - var bounds = BoundsInferGraph[call].Bounds; - - var (paddingHBefore, paddingHafter) = TileUtilities.ComputePadding(bounds[2] - cropPadding[0, 0], outputShape[2]); - var (paddingWBefore, paddingWafter) = TileUtilities.ComputePadding(bounds[3] - cropPadding[1, 0], outputShape[3]); - - var newBounds = bounds.ToArray(); - newBounds[2] = newBounds[2] - cropPadding[0, 0]; - newBounds[3] = newBounds[3] - cropPadding[1, 0]; - var ddrRegion = TIRUtilities.ClampBounds(newBounds, outputShape); - var ddr_st_output = new BufferRegionView(new[] { ddr_st_buffer }, BoundsInferGraph[call].Bounds, ddrRegion, call); - KeyToViewMap.Add(call, ddr_st_output); - - var call_in = IndexMapKey.Create(call, GNNEStore.Input); - var call_qarg = IndexMapKey.Create(call, GNNEStore.QuantParams); - return T.Sequential().Body( - Visit(call_in, prefix), - Visit(call_qarg, prefix, promoteQarg ? -1 : null), // 多层是只按h切, 此时oc满的,默认promote. - GetBufferRegion(call_in, out var glb_st_input), - GetBufferRegion(call_qarg, out var glb_st_qarg_input), - EAction.TileBlock(block_name).Reads(glb_st_input.BufferRegions, glb_st_qarg_input.BufferRegions).Body( - EAction.StoreT(ddr_st_output, glb_st_input[.., .., (glb_st_input.Region[2].Start + paddingHBefore, glb_st_input.Region[2].Stop - paddingHafter), (glb_st_input.Region[3].Start + paddingWBefore, glb_st_input.Region[3].Stop - paddingWafter)], glb_st_qarg_input, null, is_quant_by_channel))); - } - - protected virtual ISequentialBuilder LowerGnneReduce(IndexMapKey parentKey, Call call, GNNEReduce op, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNEReduce)); - var reduce_in = IndexMapKey.Create(call, GNNEReduce.Input); - var seq = T.Sequential().Body( - Visit(reduce_in, prefix), - GetBufferRegion(reduce_in, out var gnne_reduce_input), - GetBufferRegion(parentKey, out var gnne_reduce_output, TileOptions.PingPong, name: prefix), - EAction.TileBlock(block_name). - Alloc(gnne_reduce_output.Buffers). - Reads(gnne_reduce_input.BufferRegions).Body( - EAction.Reduce( - gnne_reduce_input, - gnne_reduce_output, - call[GNNEReduce.InitValue], - op.ReduceOp, - op.ReduceDim))); - - return seq; - } - - /// - /// 对于dw卷积来说,每个ic对应一个oc, 因此让每个tcu计算一半的if. - /// - protected ITileBlockBuilder GNNEConv2DSharedNone(Call call, string block_name, BufferRegionView glb_w, BufferRegionView glb_if, BufferRegionView glb_act, BufferRegionView glb_psum, BufferRegionView glb_of, bool is_init_psum, string prefix, int? promote = null) - { - var init_psums = GetInitPSumBufferRegion(call, IndexMapKey.Create(call, GNNEConv2D.PSum), glb_psum, promote, prefix, 1, ExtCompilerServices.Env.TcuActNum, out var part_condition); - - var block = EAction.TileBlock(block_name).Reads(glb_w.BufferRegions, is_init_psum ? init_psums[0].BufferRegions.Concat(init_psums[1].BufferRegions.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : glb_psum.BufferRegions, glb_if.BufferRegions, glb_act.BufferRegions).Writes(glb_of.BufferRegions).Body( - T.Unrolled(out var kh, new(0, glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight)).Body( - T.Unrolled(out var kw, new(0, glb_w.Dimensions[3], ExtCompilerServices.Env.PuKernelSpad)).Body( - T.Let(out var m_once, 1).Body( - T.Let(out var c_once, MathF.Select(MathF.Equal(m_once, glb_w.RegionSize(0)), MathF.Min(MathF.Min(ExtCompilerServices.Env.PuWidth / m_once, ExtCompilerServices.Env.PuHeight / MathF.Min(glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight)), glb_w.RegionSize(0)), 1)).Body(// note 我这里没有实现dw卷积的多输出channel的, 默认都是 1 ic : 1 oc. - T.Let(out var tcu_oc_chunk, TileUtilities.Split(glb_of.RegionSize(1), ExtCompilerServices.Env.TcuActNum)).Body(// 1. determine tcu act num - T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(1), tcu_oc_chunk)).Body(// 2. broadcast action - EAction.TcuDmBroadCast(TcuDivideStrategy.NoShare), - T.Unrolled(out var tcu_oc, new(glb_of.Region[1].Start, glb_of.Region[1].Stop, tcu_oc_chunk)).Body(// 3. loop over tcus and config each tcu - T.Let(out var m_once_tcu, 1).Body( - T.Let(out var c_once_tcu, MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc)).Body( - EAction.TcuPuConfAct( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_act, out _), - call[GNNEConv2D.FusedClamp][0], - call[GNNEConv2D.FusedClamp][1]), - EAction.TcuPuConf( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_if, out _), - glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], - IR.F.Math.Min(glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight), - IR.F.Math.Min(glb_w.Dimensions[3], ExtCompilerServices.Env.PuKernelSpad), - m_once: m_once_tcu, - c_once: c_once_tcu, - groups: 1, - mode: TcuComputeMode.DwConv2d), - EAction.TcuDmConfOf(// todo 这里hardcode两个tcu, 后面需要改进 - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - is_init_psum ? MathF.Select(MathF.Equal(tcu_oc, glb_of.Region[1].Start), init_psums[0][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..], init_psums[1][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..]) : GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_psum, out _), - glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], - 0), - EAction.TcuDmConfIf( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_if, out var if_paddings), - stride_w: call[GNNEConv2D.Stride][1], - stride_h: call[GNNEConv2D.Stride][0], - input_c_pre_pu: c_once_tcu, - dilation_h: call[GNNEConv2D.Dilation][0], - padding_top: if_paddings[2].Before, - padding_bottom: if_paddings[2].After, - padding_left: if_paddings[3].Before, - padding_right: if_paddings[3].After), - EAction.TcuDmConfW( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _)), - EAction.TcuDmFetchW(// 4. fetch weights - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _)), - EAction.TcuDmFetchIf(// 5. loop over tcus and fetch if for each tcu - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_if, out _))))), - EAction.TcuPuCompute(// 6. pu compute NOTE 这里我没有在weight的kh和kw上切,所以默认都是一次算完的 - TileUtilities.GetNTcuIndexBits(n_active_tcu), - true, - true, - call[GNNEConv2D.PSum] is not Call { Target: Uninitialized }, - TileUtilities.GetNTcuIndexBits(n_active_tcu))))))))); - - if (promote is null) - { - block.Alloc(glb_of.Buffers, is_init_psum ? init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : None.Default); - } - else if (promote is int promoteIndex) - { - if (is_init_psum) - { - NestedBlocks[promoteIndex + 1].Alloc(init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray()); - } - } - - return block; - } - - protected virtual BufferRegionView[] GetInitPSumBufferRegion(Call call, IndexMapKey key, BufferRegionView glb_psum, int? promote, string prefix, int split_axis, int tcuActNum, out Expr part_condition) - { - var chunk = TileUtilities.Split(glb_psum.Dimensions[split_axis], tcuActNum); - part_condition = TileUtilities.SplitTimes(glb_psum.Dimensions[split_axis], chunk); - var views = new BufferRegionView[2]; - var psum_dimensions = glb_psum.Dimensions.ToArray(); - psum_dimensions[split_axis] = chunk; - var dimensions = psum_dimensions; - - // build psum a and b - foreach (var (part, i) in new[] { "_a", "_b" }.Select((p, i) => (p, i))) - { - var name = prefix + "." + TileNames.InitPSum + part; - var glb_init_psums = new List(); - if (TileOptions.PingPong) - { - for (int p = 0; p < TileOptions.PingPongNum; p++) - { - glb_init_psums.Add(new LogicalBuffer(name + $"(p{p})", DataTypes.Float32, MemoryLocation.L2Data, dimensions)); - } - } - else - { - glb_init_psums.Add(new LogicalBuffer(name, DataTypes.Float32, MemoryLocation.L2Data, dimensions)); - } - - Expr loopCount; - if (promote is int promoteInt) - { - if (promoteInt != -1) - { - loopCount = K510TIRExtensions.PromotedLoopCount(promoteInt, LoopVars, LoopDomains); - } - else - { - loopCount = LoopCount; - } - } - else - { - loopCount = LoopCount; - } - - views[i] = new BufferRegionView(glb_init_psums, glb_psum.Bounds, psum_dimensions.Select(d => new Range(0, d, 1)), key, loopCount, promote); - } - - return views; - } - - protected virtual Expr GNNEConv2DComputeActEnable(Call call, BufferRegionView glb_w, Expr khStop, Expr kHBounds, Expr kwStop, Expr kWBounds) - { - return MathF.LogicalAnd(MathF.GreaterEqual(khStop, kHBounds), MathF.GreaterEqual(kwStop, kWBounds)); - } - - protected virtual Expr GNNEConv2DComputeOfEnable(Call call, BufferRegionView glb_w, Expr khStop, Expr kHBounds, Expr kwStop, Expr kWBounds) - { - return GNNEConv2DComputeActEnable(call, glb_w, khStop, kHBounds, kwStop, kWBounds); - } - - protected virtual Expr GNNEConv2DComputeLoadPsumEnable(Call call, BufferRegionView glb_w, Expr kh, Expr kw) - { - if (call[IR.K510.GNNEConv2D.PSum] is Call { Target: IR.Buffers.Uninitialized }) - { - return IR.F.Math.LogicalNot(IR.F.Math.LogicalAnd(IR.F.Math.Equal(kh, 0), IR.F.Math.Equal(kw, 0))); - } - - return true; - } - - /// - /// share if 是每个tcu计算一半的oc, 此时他们共享同一个if. - /// - protected ITileBlockBuilder GNNEConv2DSharedIF(Call call, string block_name, BufferRegionView glb_w, BufferRegionView glb_if, BufferRegionView glb_act, BufferRegionView glb_psum, BufferRegionView glb_of, bool is_init_psum, string prefix, int? promote = null) - { - var reGlbIf = GlbReIndex(glb_of, glb_if, out var sub_paddings); - - var init_psums = GetInitPSumBufferRegion(call, IndexMapKey.Create(call, GNNEConv2D.PSum), glb_psum, promote, prefix, 1, ExtCompilerServices.Env.TcuActNum, out var part_condition); - - var block = EAction.TileBlock(block_name).Reads(glb_w.BufferRegions, is_init_psum ? init_psums[0].BufferRegions.Concat(init_psums[1].BufferRegions.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : glb_psum.BufferRegions, glb_if.BufferRegions, glb_act.BufferRegions).Body( - T.Let(out var khChunck, MathF.Min(glb_w.Dimensions[2], ExtCompilerServices.Env.PuHeight)).Body( - T.Let(out var kwChunck, MathF.Min(glb_w.Dimensions[3], ExtCompilerServices.Env.PuKernelSpad)).Body( - T.Unrolled(out var kh, new(0, glb_w.Dimensions[2], khChunck)).Body(// 对kernel h/w进行tiling 暂时先不考虑 - T.Unrolled(out var kw, new(0, glb_w.Dimensions[3], kwChunck)).Body( - T.Let(out var tcu_oc_chunk, TileUtilities.Split(glb_of.RegionSize(1), ExtCompilerServices.Env.TcuActNum)).Body(// 1. determine tcu act num - T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(1), tcu_oc_chunk)).Body( - T.If(MathF.Equal(n_active_tcu, 1)).Then(// 3. broadcast action - EAction.TcuDmBroadCast(TcuDivideStrategy.NoShare)).Else( - EAction.TcuDmBroadCast(TcuDivideStrategy.ShareIf)), - EAction.TcuDmConfIf(// 4. conf if - TileUtilities.GetNTcuIndexBits(n_active_tcu), - reGlbIf, - stride_w: call[GNNEConv2D.Stride][1], - stride_h: call[GNNEConv2D.Stride][0], - input_c_pre_pu: MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_if.RegionSize(1)), // todo 这里可能有问题. - dilation_h: call[GNNEConv2D.Dilation][0], - padding_top: sub_paddings[2].Before, - padding_bottom: sub_paddings[2].After, - padding_left: sub_paddings[3].Before, - padding_right: sub_paddings[3].After), - T.Unrolled(out var tcu_oc, new(glb_of.Region[1].Start, glb_of.Region[1].Stop, tcu_oc_chunk)).Body(// 5. loop over tcus and config each tcu - EAction.TcuPuConfAct( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_act, out _), - call[GNNEConv2D.FusedClamp][0], - call[GNNEConv2D.FusedClamp][1]), - EAction.TcuPuConf( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - reGlbIf, // 切oc对于if不影响 - glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], - khChunck, - kwChunck, - m_once: MathF.Min(ExtCompilerServices.Env.PuWidth, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), - c_once: MathF.Min(MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_w.RegionSize(1)), glb_if.RegionSize(1)), - groups: call[GNNEConv2D.Groups], - mode: TcuComputeMode.NormalConv2d), - EAction.TcuDmConfOf(// todo 这里hardcode两个tcu, 后面需要改进 - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - is_init_psum ? MathF.Select(MathF.Equal(tcu_oc, glb_of.Region[1].Start), init_psums[0][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..], init_psums[1][.., (0, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop) - tcu_oc), .., ..]) : GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_psum, out _), - glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], - 0), - EAction.TcuDmConfW( - TileUtilities.GetTcuIndexBits(tcu_oc / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc, MathF.Min(tcu_oc + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _))), - T.Unrolled(out var tcu_oc2, new(glb_of.Region[1].Start, glb_of.Region[1].Stop, tcu_oc_chunk)).Body( - EAction.TcuDmFetchW(// 6. fetch weights. - TileUtilities.GetTcuIndexBits(tcu_oc2 / tcu_oc_chunk), - GlbReIndex(glb_of[.., (tcu_oc2, MathF.Min(tcu_oc2 + tcu_oc_chunk, glb_of.Region[1].Stop)), .., ..], glb_w, out _))), - EAction.TcuDmFetchIf(// 7. fetch if. - TileUtilities.GetNTcuIndexBits(n_active_tcu), - reGlbIf), - EAction.TcuPuCompute(// 8. tcu compute - TileUtilities.GetNTcuIndexBits(n_active_tcu), - act_enable: GNNEConv2DComputeActEnable(call, glb_w, kh + khChunck, glb_w.Dimensions[2], kw + kwChunck, glb_w.Dimensions[3]), - of_enable: GNNEConv2DComputeOfEnable(call, glb_w, kh + khChunck, glb_w.Dimensions[2], kw + kwChunck, glb_w.Dimensions[3]), - load_psum: GNNEConv2DComputeLoadPsumEnable(call, glb_w, kh, kw), - TileUtilities.GetNTcuIndexBits(n_active_tcu))))))))); - - if (promote is null) - { - block.Alloc(glb_of.Buffers, is_init_psum ? init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : None.Default); - } - else if (promote is int promoteIndex) - { - if (is_init_psum) - { - NestedBlocks[promoteIndex + 1].Alloc(init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray()); - } - } - - return block; - } - - /// - /// 假设oc为32被拆分之后每个tcu只能映射一半, 那么两个tcu共享一份weights, 在ofmap的h上进行切分. - /// - protected ITileBlockBuilder GNNEConv2DSharedW(Call call, string block_name, BufferRegionView glb_w, BufferRegionView glb_if, BufferRegionView glb_act, BufferRegionView glb_psum, BufferRegionView glb_of, bool is_depthwise, bool is_init_psum, string prefix, int? promote = null) - { - var init_psums = GetInitPSumBufferRegion(call, IndexMapKey.Create(call, GNNEConv2D.PSum), glb_psum, promote, prefix, 2, ExtCompilerServices.Env.TcuActNum, out var part_condition); - - var reGlbW = GlbReIndex(glb_of[.., .., .., ..], glb_w, out _); - var (iH, iW) = (glb_if.Dimensions[2], glb_if.Dimensions[3]); - var (kH, kW) = (glb_w.Dimensions[2], glb_w.Dimensions[3]); - var stride = ((TensorConst)call[IR.K510.GNNEConv2D.Stride]).Value.Cast(); - var padding = ((TensorConst)call[IR.K510.GNNEConv2D.Padding]).Value.Cast(); - var dilation = ((TensorConst)call[IR.K510.GNNEConv2D.Dilation]).Value.Cast(); - - var block = EAction.TileBlock(block_name).Reads(glb_w.BufferRegions, is_init_psum ? init_psums[0].BufferRegions.Concat(init_psums[1].BufferRegions.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : glb_psum.BufferRegions, glb_if.BufferRegions, glb_act.BufferRegions).Body( - T.Let(out var khChunck, MathF.Min(kH, ExtCompilerServices.Env.PuHeight)).Body( - T.Let(out var kwChunck, dilation[1] != 1 ? 1 : MathF.Min(kW, ExtCompilerServices.Env.PuKernelSpad)).Body( - T.Unrolled(out var kh, new(0, kH, khChunck)).Body( - T.Unrolled(out var kw, new(0, kW, kwChunck)).Body(// NOTE dw卷积时m once指的是一次ic对应输出多少个oc, 所以默认为1 - T.Let(out var m_once, is_depthwise ? 1 : MathF.Min(ExtCompilerServices.Env.PuWidth, glb_w.RegionSize(0))).Body( - T.Let(out var c_once, is_depthwise ? MathF.Min(MathF.Min(ExtCompilerServices.Env.PuWidth / m_once, ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2)), glb_w.RegionSize(0)) : MathF.Min(MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_w.RegionSize(1)), glb_if.RegionSize(1))).Body(// NOTE dw卷积时, if是按对角线排列的, 所以要小于min(pu w/pu h) - T.Let(out var tcu_oh_chunk, TileUtilities.Split(glb_of.RegionSize(2), ExtCompilerServices.Env.TcuActNum)).Body(// segment tcu h in output_h - T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(2), tcu_oh_chunk)).Body( - T.If(MathF.Equal(n_active_tcu, 1)).Then(// NOTE 这里的psum已经被load好了, 可能到时候会存在psum大小和后续不匹配的问题.// 3. broadcast action - EAction.TcuDmBroadCast(TcuDivideStrategy.NoShare)) - .Else( - EAction.TcuDmBroadCast(TcuDivideStrategy.ShareW)), - EAction.TcuDmConfW(TileUtilities.GetNTcuIndexBits(n_active_tcu), reGlbW[.., .., (kh, MathF.Min(kh + khChunck, kH)), (kw, MathF.Min(kw + kwChunck, kW))]), - T.Unrolled(out var tcu_oh, new(glb_of.Region[2].Start, glb_of.Region[2].Stop, tcu_oh_chunk)).Body(// 4. conf_w action - T.Let(out var tcu_index_bits, TileUtilities.GetTcuIndexBits(tcu_oh / tcu_oh_chunk)).Body(// 5. loop over tcus and config each tcu - EAction.TcuDmConfIf(// conf if - tcu_index_bits, - GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_if, out var if_padding, (2, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kh, IR.F.Math.Min(kh + khChunck, kH), 1), stride[0], padding[0, 0], dilation[0])), (3, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kw, IR.F.Math.Min(kw + kwChunck, kW), 1), stride[1], padding[1, 0], dilation[1]))), - stride_w: stride[1], - stride_h: stride[0], - input_c_pre_pu: MathF.Min(ExtCompilerServices.Env.PuHeight / glb_w.RegionSize(2), glb_if.RegionSize(1)), // todo 这里可能有问题. - dilation_h: call[GNNEConv2D.Dilation][0], - padding_top: if_padding[2].Before, - padding_bottom: if_padding[2].After, - padding_left: if_padding[3].Before, - padding_right: if_padding[3].After), - EAction.TcuPuConfAct(// conf act - tcu_index_bits, - GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_act, out _), - call[GNNEConv2D.FusedClamp][0], - call[GNNEConv2D.FusedClamp][1]), - EAction.TcuPuConf(// conf pu - tcu_index_bits, - GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_if, out var _, (2, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kh, IR.F.Math.Min(kh + khChunck, kH), 1), stride[0], padding[0, 0], dilation[0])), (3, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kw, IR.F.Math.Min(kw + kwChunck, kW), 1), stride[1], padding[1, 0], dilation[1]))), - glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], - khChunck, - kwChunck, - m_once, - c_once, - groups: is_depthwise ? 1 : call[GNNEConv2D.Groups], // NOTE tcu pu conf 的group其实是multiplier的意思,就是一个ic会输出多个oc, 并不是标准conv定义的groups. - mode: is_depthwise ? TcuComputeMode.DwConv2d : TcuComputeMode.NormalConv2d), - EAction.TcuDmConfOf(// conf of - tcu_index_bits, - is_init_psum ? MathF.Select(MathF.Equal(tcu_oh, glb_of.Region[2].Start), init_psums[0][.., .., (0, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop) - tcu_oh), .., ..], init_psums[1][.., .., (0, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop) - tcu_oh), .., ..]) : GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_psum, out _), - glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], - 0))), - EAction.TcuDmFetchW(TileUtilities.GetNTcuIndexBits(n_active_tcu), reGlbW[.., .., (kh, MathF.Min(kh + khChunck, kH)), (kw, MathF.Min(kw + kwChunck, kW))]), - T.Unrolled(out var tcu_oh2, new(glb_of.Region[2].Start, glb_of.Region[2].Stop, tcu_oh_chunk)).Body(// 6. fetch weights - EAction.TcuDmFetchIf( - TileUtilities.GetTcuIndexBits(tcu_oh2 / tcu_oh_chunk), // 7. loop over tcus and fetch if for each tcu - GlbReIndex(glb_of[.., .., (tcu_oh2, MathF.Min(tcu_oh2 + tcu_oh_chunk, glb_of.Region[2].Stop)), ..], glb_if, out _, (2, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kh, IR.F.Math.Min(kh + khChunck, kH), 1), stride[0], padding[0, 0], dilation[0])), (3, r => TileUtilities.Conv2DSubSlice(r, new TIR.Range(kw, IR.F.Math.Min(kw + kwChunck, kW), 1), stride[1], padding[1, 0], dilation[1]))))), - EAction.TcuPuCompute(// 8. tcu compute. - TileUtilities.GetNTcuIndexBits(n_active_tcu), - GNNEConv2DComputeOfEnable(call, glb_w, kh + khChunck, kH, kw + kwChunck, kW), - GNNEConv2DComputeActEnable(call, glb_w, kh + khChunck, kH, kw + kwChunck, kW), - GNNEConv2DComputeLoadPsumEnable(call, glb_w, kh, kw), - TileUtilities.GetNTcuIndexBits(n_active_tcu))))))))))); - - if (promote is null) - { - block.Alloc(glb_of.Buffers, is_init_psum ? init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray() : None.Default); - } - else if (promote is int promoteIndex) - { - if (is_init_psum) - { - NestedBlocks[promoteIndex + 1].Alloc(init_psums[0].Buffers.OfType().Concat(init_psums[1].Buffers.Select(b => MathF.Condition(EAction.FoldOfMarker(part_condition > 1), b))).ToArray()); - } - } - - return block; - } - - protected virtual ISequentialBuilder LowerGnneConv2D(IndexMapKey parentKey, Call call, GNNEConv2D op, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNEConv2D)); - bool is_depthwise; - { - var groups = ((TensorConst)call[GNNEConv2D.Groups]).Value.ToScalar(); - var input_channels = call[GNNEConv2D.Input].CheckedShape[1].FixedValue; - var output_channels = call.CheckedShape[1].FixedValue; - is_depthwise = input_channels == output_channels && output_channels == groups && groups != 1; - } - - if (is_depthwise) - { - prefix = prefix + "(dw)"; - block_name += "(dw)"; - } - - var call_w = IndexMapKey.Create(call, GNNEConv2D.Weights); - var call_in = IndexMapKey.Create(call, GNNEConv2D.Input); - var call_act = IndexMapKey.Create(call, GNNEConv2D.Act); - var call_psum = IndexMapKey.Create(call, GNNEConv2D.PSum); - - bool is_init_psum = call_psum.Expr is Call { Target: Uninitialized }; - - TcuDivideStrategy tcu_strategy; - if (!is_depthwise) - { - // 优先让每个tcu的width用满 - var out_shape = call.CheckedShape.ToValueArray(); - if (out_shape[1] >= ExtCompilerServices.Env.PuWidth * ExtCompilerServices.Env.TcuActNum) - { - tcu_strategy = TcuDivideStrategy.ShareIf; - } - else - { - tcu_strategy = TcuDivideStrategy.ShareW; - } - } - else - { // TODO 需要一种量化的方法来决定dw卷积用什么策略. - tcu_strategy = TcuDivideStrategy.NoShare; - } - - prefix = prefix + "." + tcu_strategy; - - // 默认是layer group的做法, 也就是w/act全部promote - return T.Sequential().Body( - Visit(call_w, prefix, -1), - Visit(call_in, prefix), - Visit(call_act, prefix, -1), - Visit(call_psum, prefix), - GetBufferRegion(call_w, out var glb_w), - GetBufferRegion(call_in, out var glb_if), // glb if 存在padding的情况. - GetBufferRegion(call_act, out var glb_act), - GetBufferRegion(call_psum, out var glb_psum, TileOptions.PingPong, name: prefix + "." + GNNEConv2D.PSum.Name), // note 这里的pusm申请了但不记录到allocs中,仅用于给psum apart使用. - GetBufferRegion(parentKey, out var glb_of, TileOptions.PingPong, name: prefix + "." + TileNames.Output), - tcu_strategy switch { TcuDivideStrategy.ShareIf => GNNEConv2DSharedIF(call, block_name, glb_w, glb_if, glb_act, glb_psum, glb_of, is_init_psum, prefix), TcuDivideStrategy.ShareW => GNNEConv2DSharedW(call, block_name, glb_w, glb_if, glb_act, glb_psum, glb_of, is_depthwise, is_init_psum, prefix), TcuDivideStrategy.NoShare => GNNEConv2DSharedNone(call, block_name, glb_w, glb_if, glb_act, glb_psum, glb_of, is_init_psum, prefix), _ => throw new NotSupportedException(), }); - } - - protected virtual ISequentialBuilder LowerGnneTranspose(IndexMapKey parentKey, Call call, GNNETranspose op, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNETranspose)); - var call_in = IndexMapKey.Create(call, GNNETranspose.Input); - var seq = T.Sequential().Body( - Visit(call_in, prefix), GetBufferRegion(call_in, out var glb_trans_input), GetBufferRegion(parentKey, out var glb_trans_output, TileOptions.PingPong, name: prefix), EAction.TileBlock(block_name).Alloc(glb_trans_output.Buffers).Reads(glb_trans_input.BufferRegions).Body(EAction.MfuTranspose(glb_trans_input, glb_trans_output, op.Perm))); - - return seq; - } - - protected virtual ISequentialBuilder LowerGnneCrop(IndexMapKey parentKey, Call call, GNNECrop op, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNECrop)); - var call_in = IndexMapKey.Create(call, GNNECrop.Input); - var call_in_bbox = IndexMapKey.Create(call, GNNECrop.InputBBox); - var seq = T.Sequential().Body( - Visit(call_in, prefix), - Visit(call_in_bbox, prefix), - GetBufferRegion(call_in, out var glb_crop_input), - GetBufferRegion(call_in_bbox, out var glb_crop_bbox), - GetBufferRegion(parentKey, out var glb_crop_output, TileOptions.PingPong, name: prefix), - EAction.TileBlock(block_name).Alloc(glb_crop_output.Buffers). - Reads(glb_crop_input.BufferRegions, glb_crop_bbox.BufferRegions). - Body( - EAction.MfuCrop( - glb_crop_input, - glb_crop_output, - glb_crop_bbox, - op.ResizeMethod, - op.AlignMethod, - op.HalfPixelCenters))); - - return seq; - } - - protected virtual ISequentialBuilder LowerGnneActivation(IndexMapKey parentKey, Call call, GNNEActivation op, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNEActivation)); - var fusedclamps = ((TensorConst)call[GNNEActivation.FusedClamp]).Value.Cast(); - var call_in = IndexMapKey.Create(call, GNNEActivation.Input); - var call_in_act = IndexMapKey.Create(call, GNNEActivation.Act); - - var seq = T.Sequential().Body( - Visit(call_in, prefix), - Visit(call_in_act, prefix), - GetBufferRegion(call_in, out var glb_if), - GetBufferRegion(call_in_act, out var glb_act), - GetBufferRegion(parentKey, out var glb_of, TileOptions.PingPong, name: prefix), - EAction.TileBlock(block_name).Alloc(glb_of.Buffers).Reads(glb_if.BufferRegions, glb_act.BufferRegions).Body( - T.Let(out var m_once, 1).Body( - T.Let(out var c_once, MathF.Min(glb_if.RegionSize(1), ExtCompilerServices.Env.TcuActNum)).Body( - T.Let(out var tcu_h_chunk, TileUtilities.Split(glb_of.RegionSize(2), ExtCompilerServices.Env.TcuActNum)).Body(// segment tcu h in output_h - T.Let(out var n_active_tcu, TileUtilities.SplitTimes(glb_of.RegionSize(2), tcu_h_chunk)).Body( - T.Unrolled(out var tcu_oh, new(glb_of.Region[2].Start, glb_of.Region[2].Stop, tcu_h_chunk)).Body( - T.Let(out var tcu_index_bits, TileUtilities.GetTcuIndexBits(tcu_oh / tcu_h_chunk)).Body( - EAction.TcuPuConfAct(// 1. conf act - tcu_index_bits, - glb_act, - fusedclamps[0], - fusedclamps[1]), - EAction.TcuPuConf( - tcu_index_bits, - glb_if[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_if.Region[2].Stop)), ..], - glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_of.Region[2].Stop)), ..], - 1, - 1, - m_once, - c_once, - 1, - TcuComputeMode.Activation), - EAction.TcuDmConfOf( - tcu_index_bits, - glb_if[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_if.Region[2].Stop)), ..], - glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_of.Region[2].Stop)), ..], - 0))), - EAction.TcuPuComputeDummy(TileUtilities.GetNTcuIndexBits(n_active_tcu), true))))))); - return seq; - } - - protected virtual ISequentialBuilder LowerGnnePdpReduce(IndexMapKey parentKey, Call call, GNNEPdpReduce op, string block_name, string prefix) - { - prefix = NameAllocator.Get(nameof(GNNEPdpReduce)); - var call_in = IndexMapKey.Create(call, GNNEPdpReduce.Input); - - // var ddr_if = BoundsInferGraph[call_in]; - // GlbReIndex(glb_of[.., .., (tcu_oh, MathF.Min(tcu_oh + tcu_h_chunk, glb_of.Region[2].Stop)), ..], glb_if, out var if_paddings) - var seq = T.Sequential().Body( - Visit(call_in, prefix), - GetBufferRegion(call_in, out var glb_if), - GetBufferRegion(parentKey, out var glb_of, TileOptions.PingPong, name: prefix)); - GlbReIndex(glb_of, glb_if, out var sub_paddings); - seq.Body( - EAction.TileBlock(block_name).Alloc(glb_of.Buffers).Reads(glb_if.BufferRegions).Body( - EAction.PdpReduce( - glb_if, - glb_of, - call[GNNEPdpReduce.Filter], - call[GNNEPdpReduce.Stride], - sub_paddings[2].Before, - sub_paddings[2].After, - sub_paddings[3].Before, - sub_paddings[3].After, - op.ReduceOp))); - return seq; - } - - protected virtual ISequentialBuilder LowerGnneConv2DTranspose(IndexMapKey parentKey, Call call, GNNEConv2DTranspose op, string block_name, string prefix) - { - throw new NotSupportedException(); - } -} -#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs index 7cbeab67c8..0a5d5a31e3 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/MultiFusionChecker.cs @@ -1,11 +1,11 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -#if false + using System.Collections.Immutable; using System.Runtime.CompilerServices; using Nncase.Diagnostics; using Nncase.IR; -using Nncase.TIR.CPU; +using Nncase.TIR; namespace Nncase.Passes.Tile; @@ -14,239 +14,7 @@ namespace Nncase.Passes.Tile; /// internal sealed class MultiFusionChecker : IFusionChecker { - private readonly List<(MultiLayerFusionConverter, int[], BufferSchedule.ScheduledResponse)> _caches = new(); - - private readonly TileOptions _tileOptions; - - public MultiFusionChecker(TileOptions tileOptions) - { - _tileOptions = tileOptions; - } - - [Flags] - public enum DeviceKind - { - Load, - Store, - Mfu, - Tcu, - None, - } - - public TIR.PrimFunction Convert(RunPassContext passOptions) - { - var (convertVisitor, final_tile_size, response) = _caches.First(); - if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) - { - response.Dump($"{response.LogicalPrimfunc.Name}_{string.Join("_", final_tile_size)}", convertVisitor.GetType().Name); - } - - return convertVisitor.BuildPhysicalPrimFunc(final_tile_size, response.SchedCandidate, response.LogicalPrimfunc); - } - - public bool Check(Fusion fusion, RunPassContext passOptions) - { - // 1. check all conv2d weights size less than glb size - var visitor = new MultiFusionPreCheckVisitor(); - visitor.Visit(fusion); - if (visitor.AllWeightSizeInBytes > ExtCompilerServices.Env.GlbSize) - { - return false; - } - - // note not support conv2d transpose in layer group. - if (visitor.CountCallOp() > 0) - { - return false; - } - - var curTileOptions = _tileOptions; - if ((visitor.DeviceUsage[DeviceKind.Mfu], visitor.DeviceUsage[DeviceKind.Tcu]) switch - { - (> 1, > 1) => true, - (> 1, 1) => true, - (1, > 1) => true, - _ => false, - }) - { - curTileOptions = curTileOptions with { PingPongNum = 3 }; - } - - // 2. try convert - var convertVisitor = new MultiLayerFusionConverter(curTileOptions); // note the grouped fusion must pingpong input. - var bodySeq = convertVisitor.Visit(fusion); - - // 3. search the tile size - var originLogicalPrimFunc = convertVisitor.BuildLogicalPrimFunc(bodySeq); - - var output_shape = fusion.Body.CheckedShape.ToValueArray(); - var search_space = convertVisitor.BoundsInferGraph.RootTileStep.ToArray(); - var candidate_tile_size = convertVisitor.SearchTileSize(TileOhSearchGenerator(curTileOptions, search_space, convertVisitor.BoundsInferGraph.RootPerm.ToArray()), originLogicalPrimFunc, curTileOptions.MultiWorkers, false, out var sched_response); - if (!candidate_tile_size.Any()) - { - return false; - } - - int[] final_tile_size = new int[candidate_tile_size.Length]; - if (convertVisitor.BalanceTileSize(candidate_tile_size, search_space)) - { - final_tile_size = convertVisitor.SearchTileSize(new TargetTileGenerator(candidate_tile_size), originLogicalPrimFunc, curTileOptions.MultiWorkers, true, out sched_response); - } - else - { - Array.Copy(candidate_tile_size, final_tile_size, candidate_tile_size.Length); - } - - // 5. check the input load usage and compute overlap - var input_shape = fusion.Parameters[0].CheckedShape.ToValueArray(); - var each_axis_tile_nums = final_tile_size.Zip(output_shape).Select(p => (int)System.Math.Ceiling(p.Second / (float)p.First)).ToArray(); - var total_tile_nums = TensorUtilities.GetProduct(each_axis_tile_nums); - if (total_tile_nums > 1) - { - var clamp = (TIR.K510.Segment seg, int i) => - { - return new TIR.K510.Segment(Math.Max(0, seg.Start), Math.Min(input_shape[i], seg.Stop), 1); - }; - - var first_segment = convertVisitor.BoundsInferGraph[convertVisitor.VarToKeyMap[fusion.Parameters[0]]]. - Eval(final_tile_size.Select(t => new TIR.K510.Segment(0, t, 1)).ToArray()). - Select((s, i) => clamp(s, i)). - ToArray(); - - int first_split_axis = 0; - for (int i = input_shape.Length - 1; i >= 0; i--) - { - if (first_segment[i].Length != input_shape[i]) - { - first_split_axis = i; - break; - } - } - - // when once load less than load burst, false - var burst_load_data = TensorUtilities.GetProduct(first_segment.Skip(first_split_axis).Select(s => s.Length).ToArray()); - if (burst_load_data < ExtCompilerServices.Env.LoadBurst) - { - return false; - } - - var second_segment = convertVisitor.BoundsInferGraph[convertVisitor.VarToKeyMap[fusion.Parameters[0]]].Eval(final_tile_size.Select((t, i) => - t < output_shape[i] ? - new TIR.K510.Segment(t, System.Math.Min(t * 2, output_shape[i]), 1) : - new TIR.K510.Segment(0, t, 1)).ToArray()). - Select((s, i) => clamp(s, i)). - ToArray(); - - // Todo 因为我无法知道在当前维度切分会影响哪个维度的变化, 比如带有transpose的, 可能我在c上切,只影响 h w. 所以直接计算所有的的交集 - var overlaps = first_segment.Zip(second_segment).Select(p => p.First.Intersect(p.Second)).ToArray(); - if (Array.IndexOf(convertVisitor.BoundsInferGraph.RootPerm.ToArray(), TIR.K510.NamedAxis.H) is int h && h != -1) - { - // 如果只在h上切分, 只需要考虑h上的overlap有没有超过0.3 - if (overlaps[h] > (input_shape[h] * 0.3)) - { - return false; - } - } - else - { - if (TensorUtilities.GetProduct(overlaps) > TensorUtilities.GetProduct(input_shape) * 0.3) - { - return false; - } - } - } - - _caches.Add((convertVisitor, candidate_tile_size, sched_response)); - if (_caches.Count > 1) - { - _caches.RemoveAt(0); - } - - return true; - } - - /// - /// - /// 只在oh上切分. - /// - /// - private ISearchTileGenerator TileOhSearchGenerator(TileOptions tileOptions, Segment[] search_spaces, TIR.K510.NamedAxis[] rootPerm) - { - if (Array.IndexOf(rootPerm, TIR.K510.NamedAxis.C) is int c && c != -1) - { - search_spaces[c].Start = search_spaces[c].Stop; // not tile oc - } - - if (Array.IndexOf(rootPerm, TIR.K510.NamedAxis.H) is int h && h != -1) - { - // 因为在h上切分, 如果ping pong那么需要限制大小 - if (tileOptions.PingPong) - { - if (search_spaces[h].ClampStop(2, out var new_h_seg)) - { - // assume tile h must > 8 for tcu use. - search_spaces[h] = new Segment(Math.Min(Math.Max(search_spaces[h].Step, ExtCompilerServices.Env.TcuActNum), new_h_seg.Stop), new_h_seg.Stop, new_h_seg.Step); - } - } - } - - if (Array.IndexOf(rootPerm, TIR.K510.NamedAxis.W) is int w && w != -1) - { - search_spaces[w].Start = Math.Min(search_spaces[w].Stop, Math.Max(search_spaces[w].Step, ExtCompilerServices.Env.PuWidth)); // no tile w - } - - // 如果有perm, 那就是 c w h n 方式搜, 没有perm就是从最后搜到最前 所以在有ping pong的时候需要限制切分. - if (rootPerm.All(r => r == NamedAxis.UnKnow) && tileOptions.PingPong) - { - if (search_spaces[^1].ClampStop(2, out var new_seg)) - { - search_spaces[^1] = new_seg; - } - } - - return new DefaultSearchTileGenerator(search_spaces, rootPerm); - } - - internal sealed class MultiFusionPreCheckVisitor : ExprVisitor - { - public Dictionary DeviceUsage { get; } = new() - { - { DeviceKind.Load, 0 }, - { DeviceKind.Store, 0 }, - { DeviceKind.Mfu, 0 }, - { DeviceKind.Tcu, 0 }, - { DeviceKind.None, 0 }, - }; - - public int AllWeightSizeInBytes { get; private set; } - - public int CountCallOp() - where T : Op - { - return ExprMemo.Keys.Count(e => e is Call { Target: Op t } && t.GetType() == typeof(T)); - } - - protected override bool DefaultVisitLeaf(Expr expr) => true; - - protected override bool VisitLeafCall(Call expr) - { - if (expr is Call { Target: IR.K510.GNNEConv2D } && expr[IR.K510.GNNEConv2D.Weights] is Expr weights) - { - AllWeightSizeInBytes += weights.CheckedShape.Prod().FixedValue * weights.CheckedDataType.SizeInBytes; - } - - DeviceUsage[GetDeviceType(expr.Target)]++; - return true; - } + public bool Check(Fusion fusion, RunPassContext passOptions) => false; - private static DeviceKind GetDeviceType(Expr op) => op switch - { - IR.K510.GNNELoad => DeviceKind.Load, - IR.K510.GNNEStore => DeviceKind.Store, - IR.K510.GNNEConv2D or IR.K510.GNNEActivation => DeviceKind.Tcu, - IR.K510.GNNEReduce or IR.K510.GNNEMeshNet or IR.K510.GNNETranspose or IR.K510.GNNEPdpReduce or IR.K510.GNNECrop => DeviceKind.Mfu, - _ => DeviceKind.None, - }; - } + public PrimFunction Convert(RunPassContext passOptions) => throw new NotImplementedException(); } -#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs deleted file mode 100644 index a99248cf14..0000000000 --- a/modules/Nncase.Modules.CPU/Passes/Tile/MultiLayerFusionConverter.cs +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -#if false -using System.Runtime.CompilerServices; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.Passes.BufferSchedule; -using Nncase.TIR; -using Nncase.TIR.Builders; -using Nncase.TIR.CPU; -using Nncase.TIR.K510.Builders; -using MathF = Nncase.IR.F.Math; - -namespace Nncase.Passes.Tile; - -internal class MultiLayerFusionConverter : LayerFusionConverter -{ - public MultiLayerFusionConverter(TileOptions tileOptions) - { - TileOptions = tileOptions; - } - - public override Fusion CurrentFusion { get; protected set; } = null!; - - public override IBoundsInferGraph BoundsInferGraph { get; protected set; } = null!; - - /// - /// Gets or sets calc the loop count. - /// - public override Expr LoopCount { get; protected set; } = null!; - - public override Expr LoopCountOuter { get; protected set; } = null!; // LoopCount / TileOptions.PingPongNum; - - public override Expr LoopCountInner { get; protected set; } = null!; // LoopCount % TileOptions.PingPongNum; - - public override Expr Visit(Fusion fusion) - { - if (CurrentFusion is null) - { - CurrentFusion = fusion; - } - else - { - throw new InvalidOperationException("Can't Visit More Than One Fusion!"); - } - - // 0. init the fields - var output_shape = Tile.TileUtilities.GetFusionRealOutputShape(fusion.Body); - BoundsInferGraph = ExtCompilerServices.MakeBoundsInferGraph((Call)fusion.Body); - TileSizeVars.AddRange(output_shape.Select((_, i) => new Var($"dim{i}_tile", new TensorType(DataTypes.Int32, Shape.Scalar)))); - - // 1. make the tile gird loop - NestedBlocks.AddRange(new[] { EAction.TileBlock("MainBlock") }.Concat(Enumerable.Range(0, TileSizeVars.Count).Select(i => EAction.TileBlock($"TileBlock_{i}")))); - - LoopDomains.AddRange(output_shape.Zip(TileSizeVars).Select(t => new TIR.Range(0, t.First, t.Second))); - for (int i = 0; i < TileSizeVars.Count; i++) - { - NestedLoops.Add(T.ForLoop(out var loopVar, LoopDomains[i], LoopMode.Unrolled, $"loop_var_{i}")); - LoopVars.Add(loopVar); - } - - object lastBody = NestedBlocks[^1]; - for (int i = NestedLoops.Count - 1; i >= 0; i--) - { - lastBody = NestedLoops[i].Body(lastBody); - lastBody = NestedBlocks[i].Body(lastBody); - } - - // 2. create the bounds infer function input arguments with the new loop var. - BoundsInferGraph.RootBounds = output_shape.Select((s, i) => - { - var loopVar = LoopVars[i]; - return new TIR.Range(loopVar, IR.F.Math.Min(loopVar + TileSizeVars[i], s), 1); - }).ToList(); - - // 3. set up loop count - var shape = new Expr[LoopVars.Count]; - var upbounds = CurrentFusion.Body.CheckedShape.ToValueArray(); - for (int j = LoopVars.Count - 1; j >= 0; j--) - { - shape[j] = TileUtilities.SplitTimes(upbounds[j], TileSizeVars[j]); - } - - var strides = TensorUtilities.GetStrides(shape).ToArray(); - var indices = LoopVars.Select((v, j) => (Expr)(v / TileSizeVars[j])).ToArray(); - LoopCount = TensorUtilities.GetIndex(strides, indices); - LoopCountOuter = LoopCount / TileOptions.PingPongNum; - LoopCountInner = LoopCount % TileOptions.PingPongNum; - - return Visit((Call)fusion.Body, "root"); - } - - /// - /// convert to the final prim func. - /// - /// . - public PrimFunction VisitToPrimFunc(Fusion fusion) - { - // 1. visit the fusion - var bodySeq = Visit(fusion); - - // 2. build the prim func with tile size vars. - var logicalPrimFunc = BuildLogicalPrimFunc(bodySeq); - - // 3. seach the tiling size - var search_spaces = BoundsInferGraph.RootTileStep.ToArray(); - ISearchTileGenerator tileGenerator; - if (TileOptions.TargetTileSize.Any()) - { - for (int i = 0; i < TileOptions.TargetTileSize.Length; i++) - { - System.Diagnostics.Trace.Assert(TileOptions.TargetTileSize[i] <= search_spaces[i].Stop); - } - - tileGenerator = new TargetTileGenerator(TileOptions.TargetTileSize); - } - else - { - var perm = BoundsInferGraph.RootPerm.ToArray(); - - // when ping pong all, clamp the upper bounds by perm order. - if (TileOptions.PingPong) - { - var re_perm = perm.Zip(Enumerable.Range(0, perm.Length)).OrderBy(t => t.First).Select(t => t.Second).ToArray(); - - var pp_axis = NamedAxis.H; - - // 如果已知维度, 那么在pp axis上进行切分 - if (Array.IndexOf(perm, NamedAxis.H) is var h && h != -1 && Array.IndexOf(perm, NamedAxis.W) is var w && w != -1) - { - // if split the h will less than one burst, split on c. - if ((int)System.Math.Ceiling(search_spaces[h].Stop / (float)TileOptions.PingPongNum) * search_spaces[w].Stop < 128) - { - pp_axis = NamedAxis.C; - } - else - { - pp_axis = NamedAxis.H; - } - } - - for (int i = 0; i < perm.Length; i++) - { - var p = re_perm[i]; - if (perm[i] == pp_axis && search_spaces[p].ClampStop(TileOptions.PingPongNum, out var new_seg)) - { - search_spaces[p] = new_seg; - break; - } - } - } - - { - if (Array.IndexOf(perm, NamedAxis.H) is var h && h != -1 && Array.IndexOf(perm, NamedAxis.W) is var w && w != -1) - { - // if one layer output less than 128, don't split the hw - if (search_spaces[h].Stop * search_spaces[w].Stop < ExtCompilerServices.Env.LoadBurst) - { - search_spaces[h].Start = search_spaces[h].Stop; - search_spaces[w].Start = search_spaces[w].Stop; - } - else - { - search_spaces[h].Start = Math.Min(Math.Max(search_spaces[h].Step, ExtCompilerServices.Env.TcuActNum), search_spaces[h].Stop); - search_spaces[w].Start = Math.Min(Math.Max(search_spaces[w].Step, ExtCompilerServices.Env.PuWidth), search_spaces[w].Stop); - } - } - } - - tileGenerator = new DefaultSearchTileGenerator(search_spaces, BoundsInferGraph.RootPerm); - } - - int[] candidate_tile_size = SearchTileSize(tileGenerator, logicalPrimFunc, TileOptions.MultiWorkers, false, out var response); - if (!candidate_tile_size.Any()) - { - throw new TileFailedException(); - } - - int[] final_tile_size = Array.Empty(); - if (!TileOptions.TargetTileSize.Any() && TileOptions.PingPong && BalanceTileSize(candidate_tile_size, search_spaces)) - { - final_tile_size = SearchTileSize(new TargetTileGenerator(candidate_tile_size), logicalPrimFunc, TileOptions.MultiWorkers, true, out response); - } - else - { - final_tile_size = candidate_tile_size; - } - - if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) - { - response.Dump($"{CurrentFusion.Name}_{string.Join("_", final_tile_size)}", GetType().Name); - } - - // 4. the local logical buffer to phsy buffer - return BuildPhysicalPrimFunc(final_tile_size, response.SchedCandidate, response.LogicalPrimfunc); - } - - /// - /// 1. if inner loop var > half, balance it. - /// 2. find the highest axis loop var == up_bounds, split it. - /// - public override bool BalanceTileSize(int[] tile_size, Segment[] search_spaces) - { - bool changed = false; - - // balance tile - for (int i = search_spaces.Length - 1; i >= 0; i--) - { - if (search_spaces[i].BalanceTile(tile_size[i], out var newTile)) - { - tile_size[i] = newTile; - return true; - } - } - - // force ping pong - for (int i = 0; i < search_spaces.Length; i++) - { - if (search_spaces[i].ClampStop(2, out var new_seg) && new_seg.Stop < tile_size[i]) - { - tile_size[i] = new_seg.Stop; - return true; - } - } - - return changed; - } -} -#endif diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs new file mode 100644 index 0000000000..a2c932eb92 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -0,0 +1,119 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Reactive; +using System.Runtime.CompilerServices; +using NetFabric.Hyperlinq; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.Buffers; +using Nncase.IR.CPU; +using Nncase.IR.F; +using Nncase.IR.Math; +using Nncase.Passes.Mutators; +using Nncase.PatternMatch; +using Nncase.Schedule; +using Nncase.Targets; +using Nncase.TIR; +using Nncase.TIR.Builders; +using Buffer = Nncase.TIR.Buffer; +using MathF = Nncase.IR.F.Math; +using Range = Nncase.TIR.Range; +using Tuple = Nncase.IR.Tuple; + +namespace Nncase.Passes.Tile; + +/// +/// convert the fusion to prim func. +/// +internal sealed class SingleCPUFusionConverter +{ + public TIR.PrimFunction Visit(Fusion fusion) + { + var body = new List(); + var visitor = new ConvertVisitor(body); + visitor.Visit(fusion); + return T.PrimFunc(fusion.Name, fusion.ModuleKind, visitor.InputBuffers.Concat(visitor.OutputBuffers).ToArray()).Body(body.ToArray()).Build(); + } + + private sealed class ConvertVisitor : ExprVisitor + { + private readonly Dictionary _buffersMap = new(ReferenceEqualityComparer.Instance); + private readonly List _mainBody; + + public ConvertVisitor(List mainBody) + { + _mainBody = mainBody; + } + + public Fusion VisitRootFusion => (Fusion)(VisitRoot!); + + public IEnumerable OutputBuffers => _buffersMap.Values.OfType().Where(b => b.MemLocation == MemoryLocation.Output); + + public IEnumerable InputBuffers => _buffersMap.Values.OfType().Where(b => b.MemLocation == MemoryLocation.Input); + + protected override Unit DefaultVisitLeaf(Expr expr) + { + return new(); + } + + protected override Unit VisitLeafCall(Call expr) + { + var arguments = expr.Arguments.AsValueEnumerable().Select(TryAllocateBuffer).ToArray(); + var ret = TryAllocateBuffer(expr); + var op = ((CPUKernelOp)expr.Target).Target; + + switch (op) + { + case Unary unary: + GenerateUnary(unary, arguments, ret); + break; + default: + throw new NotSupportedException(); + } + return new(); + } + + private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer ret) + { + var input = arguments[Unary.Input.Index]; + var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + var input_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + input.Strides[i] * loops[i].Item2); + var output_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + ret.Strides[i] * loops[i].Item2); + Expr stmt = T.Store(ret, output_index, IR.F.Math.Unary(unary.UnaryOp, T.Load(input, output_index))); + var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); + _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); + } + + private TIR.Buffer TryAllocateBuffer(Expr expr) + { + var name = $"buffer_{_buffersMap.Keys.Count}"; + if (!_buffersMap.TryGetValue(expr, out var buffer)) + { + switch (expr) + { + case Call c: + if (ReferenceEquals(c, VisitRootFusion.Body)) + { + buffer = T.PhysicalBuffer(c.CheckedDataType, MemoryLocation.Output, c.CheckedShape.ToValueArray(), out _, name); + } + else + { + buffer = T.Buffer(c.CheckedDataType, MemoryLocation.Data, c.CheckedShape.ToValueArray().Select(i => (Expr)i).ToArray(), out _, name); + } + break; + case Var v: + buffer = T.PhysicalBuffer(v.CheckedDataType, MemoryLocation.Input, v.CheckedShape.ToValueArray(), out _, name); + break; + case TensorConst c: + buffer = T.PhysicalBuffer(c.Value.ElementType, MemoryLocation.Rdata, c.Value.Dimensions, out _, name); + break; + default: + throw new NotSupportedException(); + } + _buffersMap.Add(expr, buffer); + } + return buffer; + } + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs index ba3368a679..749355b5da 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/TileOptions.cs @@ -13,12 +13,8 @@ namespace Nncase.Passes.Tile; /// TileOptions. /// /// TargetTileSize. -/// ForceFence. -/// 是否进行ping pong. -/// PingPongNum. -/// 对于测试. -/// 是否开启多线程搜索. -public sealed record TileOptions(int[] TargetTileSize, bool ForceFence, bool PingPong, int PingPongNum, bool ForceMultiLayer, bool MultiWorkers) +/// the cache size. +public sealed record TileOptions(int[] TargetTileSize, int CacheSize) { - public static TileOptions Default { get; } = new(Array.Empty(), false, true, 2, false, true); + public static TileOptions Default { get; } = new(Array.Empty(), 4 * 1024 * 1024 * 8); } diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs deleted file mode 100644 index b0a8a858c3..0000000000 --- a/modules/Nncase.Modules.CPU/Passes/Tile/TwoFusionChecker.cs +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -#if false -using System.Collections.Immutable; -using System.Runtime.CompilerServices; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.PatternMatch; -using Nncase.TIR.K510; -using static Nncase.PatternMatch.Utility; - -[assembly: InternalsVisibleTo("Nncase.Tests.K510")] - -namespace Nncase.Passes.Tile; - -/// -/// the two Fusion checker. -/// 专门前一层可以在conv的ic上tiling的情况. -/// -internal sealed class TwoFusionChecker : IFusionChecker -{ - private readonly List<(LayerFusionConverter, int[], BufferSchedule.ScheduledResponse)> _caches = new(); - - private readonly TileOptions _tileOptions; - - public TwoFusionChecker(TileOptions tileOptions) - { - _tileOptions = tileOptions; - } - - /// - /// Gets 匹配 conv2d + 非reduction. - /// - public static Pattern TwoFusionPattern { get; } = IsCallWildcard( - null, - IsOp(), - IsCallWildcard( - "conv2d", - IsOp(), - IsCallWildcard(null, IsOp("calleeOp", op => op is IR.K510.GNNEMeshNet or IR.K510.GNNEPdpReduce), IsCallWildcard(null, IsOp())))); - - public TIR.PrimFunction Convert(RunPassContext passOptions) - { - var (convertVisitor, final_tile_size, response) = _caches.First(); - if (DumpScope.Current.IsEnabled(DumpFlags.PassIR)) - { - response.Dump($"{response.LogicalPrimfunc.Name}_{string.Join("_", final_tile_size)}", nameof(LayerFusionOcIcConverter)); - } - - return convertVisitor.BuildPhysicalPrimFunc(final_tile_size, response.SchedCandidate, response.LogicalPrimfunc); - } - - public bool Check(Fusion fusion, RunPassContext passOptions) - { - // 1. try match pattern - if (!CompilerServices.TryMatchRoot(fusion.Body, TwoFusionPattern, out var matchResult)) - { - return false; - } - - // 2. try convert - var convertVisitor = new LayerFusionOcIcConverter(_tileOptions, TileUtilities.ChoiceTcuStrategy((Call)matchResult["conv2d"], out _), false); // note the grouped fusion must pingpong input. - var bodySeq = convertVisitor.Visit(fusion); - - // 3. search the tile size - var originLogicalPrimFunc = convertVisitor.BuildLogicalPrimFunc(bodySeq); - _ = fusion.Body.CheckedShape.ToValueArray(); - var search_space = convertVisitor.OCBoundsInferGraph.RootTileStep.ToArray(); - search_space = search_space.Concat(new[] { new Segment(1, convertVisitor.Conv2DInShape[1], 1) }).ToArray(); - var candidate_tile_size = convertVisitor.SearchTileSize( - SearchGenerator(search_space), - originLogicalPrimFunc, - _tileOptions.MultiWorkers, - false, - out var sched_response); - if (!candidate_tile_size.Any()) - { - return false; - } - - if (_tileOptions.PingPong && convertVisitor.BalanceTileSize(candidate_tile_size, search_space)) - { - _ = convertVisitor.SearchTileSize(new TargetTileGenerator(candidate_tile_size), originLogicalPrimFunc, _tileOptions.MultiWorkers, true, out sched_response); - } - else - { - } - - _caches.Add((convertVisitor, candidate_tile_size, sched_response)); - if (_caches.Count > 1) - { - _caches.RemoveAt(0); - } - - return true; - } - - /// - /// - /// do not tile on w dimension. - /// - /// - /// . - private ISearchTileGenerator SearchGenerator(Segment[] search_spaces) - { - var newSpaces = search_spaces.ToArray(); - - // 这里就优先在ic上切分ping pong. 因为在ic上切分对于if来说都是不一样的. - var ic = newSpaces.Length - 1; - if (newSpaces[ic].ClampStop(2, out var new_seg)) - { - newSpaces[ic] = new_seg; - } - - // ic最小也得分两个tcu. - newSpaces[ic].Start = Math.Min(ExtCompilerServices.Env.TcuActNum * ExtCompilerServices.Env.PuHeight, newSpaces[ic].Stop); - - var w = 3; - newSpaces[w].Start = newSpaces[w].Stop; // no tile w - - return new QueuedSearchTileGenerator(newSpaces, g => - { - g.Queue.Add((1, System.Math.Min(ExtCompilerServices.Env.PuWidth * ExtCompilerServices.Env.TcuActNum, g.UpperBounds[1]))); // oc - g.Queue.Add((4, g.UpperBounds[4])); // ic - g.Queue.Add((2, g.UpperBounds[2])); // h - g.Queue.Add((1, g.UpperBounds[1])); // oc - }); - } -} -#endif diff --git a/modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs b/modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs deleted file mode 100644 index 0ed89e20f3..0000000000 --- a/modules/Nncase.Modules.CPU/Runtime/CPU/CPURTModule.cs +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Nncase.Runtime.CPU; - -internal class CPURTModule -{ - /// - /// KPU module kind. - /// - public static readonly string Kind = "cpu"; - - /// - /// KPU module version. - /// - public static readonly uint Version = 1; -} diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 6bee77d2a7..3f9c2bc9cc 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -77,7 +77,15 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp { passManager.AddWithName("MakeFusion").Configure(p => { - p.Add(); + p.Add(); + }); + + passManager.Add(Passes.Tile.TileOptions.Default); + + passManager.Add().Configure(p => + { + p.Add(); + p.Add(); }); } @@ -92,9 +100,11 @@ public IModuleBuilder CreateModuleBuilder(string moduleKind, CompileOptions opti { return new StackVMModuleBuilder(); } - else + else if (moduleKind == CPUTarget.Kind) { - throw new NotSupportedException($"{moduleKind} module is not supported."); + return new CodeGen.CPU.ModuleBuilder(options); } + + throw new NotSupportedException(); } } diff --git a/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs b/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs index df9ff51a28..e1fe76b67a 100644 --- a/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs +++ b/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs @@ -15,7 +15,6 @@ public class LinkedFunction : ILinkedFunction public LinkedFunction(uint id, Callable sourceFunction, uint textBegin, uint textLength, IReadOnlyList sections) { Id = id; - CompilerServices.InferenceType(sourceFunction); ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray(); ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType; TextBegin = textBegin; diff --git a/src/Nncase.Core/IR/Buffers/BufferOf.cs b/src/Nncase.Core/IR/Buffers/BufferOf.cs index a3bb033275..baf982d17e 100644 --- a/src/Nncase.Core/IR/Buffers/BufferOf.cs +++ b/src/Nncase.Core/IR/Buffers/BufferOf.cs @@ -16,8 +16,8 @@ public sealed partial class BufferOf : Op /// public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor()); - public Schedule.MemoryLocation MemoryLocation { get; } + public TIR.MemoryLocation MemoryLocation { get; } /// - public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}"; + public override string DisplayProperty() => $"MemoryLocation.{MemoryLocation}"; } diff --git a/src/Nncase.Core/IR/Buffers/Functional.cs b/src/Nncase.Core/IR/Buffers/Functional.cs index 54cd9f59cf..a2e3507a5f 100644 --- a/src/Nncase.Core/IR/Buffers/Functional.cs +++ b/src/Nncase.Core/IR/Buffers/Functional.cs @@ -41,5 +41,5 @@ public static Call BaseMentOf(Expr input) => /// /// create the uninitialized buffer. /// - public static Call Uninitialized(DataType dataType, Schedule.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape); + public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape); } diff --git a/src/Nncase.Core/IR/Buffers/Uninitialized.cs b/src/Nncase.Core/IR/Buffers/Uninitialized.cs index 2f529638b9..42564bbc4c 100644 --- a/src/Nncase.Core/IR/Buffers/Uninitialized.cs +++ b/src/Nncase.Core/IR/Buffers/Uninitialized.cs @@ -19,11 +19,11 @@ public sealed partial class Uninitialized : Op public DataType DType { get; } - public Schedule.MemoryLocation MemoryLocation { get; } + public TIR.MemoryLocation MemoryLocation { get; } /// public override bool CanFoldConstCall => false; /// - public override string DisplayProperty() => $"{DType.GetCSharpName()}, Schedule.MemoryLocation.{MemoryLocation}"; + public override string DisplayProperty() => $"{DType.GetCSharpName()}, MemoryLocation.{MemoryLocation}"; } diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs index 214605c69e..78d5d2bda4 100644 --- a/src/Nncase.Core/IR/ExprCloner.g.cs +++ b/src/Nncase.Core/IR/ExprCloner.g.cs @@ -256,4 +256,12 @@ protected override Expr VisitLeafIterVar(TIR.IterVar expr, TContext context) ); } + /// + protected override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context) + { + return expr.With( + start: Clone(expr.Start, context), + size: Clone(expr.Size, context) + ); + } } diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs index 642b2709e4..57e1ef86d5 100644 --- a/src/Nncase.Core/IR/ExprFunctor.g.cs +++ b/src/Nncase.Core/IR/ExprFunctor.g.cs @@ -154,6 +154,10 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr, TContext context) => DefaultVisit(expr, context); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context); } public partial class ExprFunctor @@ -354,4 +358,11 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); + + /// + internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); } diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs index 4c8cece3f2..2f0b6c1233 100644 --- a/src/Nncase.Core/IR/ExprRewriter.g.cs +++ b/src/Nncase.Core/IR/ExprRewriter.g.cs @@ -322,6 +322,10 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafIterVar(TIR.IterVar expr, TContext context) => DefaultRewriteLeaf(expr, context); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context); } public partial class ExprRewriter @@ -550,4 +554,11 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafIterVar(TIR.IterVar expr, Unit context) => RewriteLeafIterVar(expr); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr); + + /// + protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr); } diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index c296f5f7e0..5e5a609a6a 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -205,6 +205,13 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext return VisitLeafIterVar(expr, context); } + /// + protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) + { + VisitOperands(expr, context); + return VisitLeafMemSpan(expr, context); + } + /// /// Visit leaf . /// @@ -345,6 +352,11 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext /// protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr, TContext context) => DefaultVisitLeaf(expr, context); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context); + } public partial class ExprVisitor @@ -524,6 +536,14 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); + + /// + internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); + /// /// Visit leaf . /// @@ -748,4 +768,11 @@ public partial class ExprVisitor /// protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default); + + /// + protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr); } diff --git a/src/Nncase.Core/Schedule/ScheduleTypes.cs b/src/Nncase.Core/Schedule/ScheduleTypes.cs index 55b89ce8a0..9c2e47f8e2 100644 --- a/src/Nncase.Core/Schedule/ScheduleTypes.cs +++ b/src/Nncase.Core/Schedule/ScheduleTypes.cs @@ -10,51 +10,6 @@ namespace Nncase.Schedule; -/// -/// the memory type. -/// -public enum MemoryLocation : byte -{ - /// - /// input. - /// - Input = 0, - - /// - /// output. - /// - Output = 1, - - /// - /// constant data. - /// - Rdata = 2, - - /// - /// compute temp data. - /// - Data = 3, - - /// - /// shared data. - /// - SharedData = 4, - - /// - /// l2 data. - /// - L2Data = 5, - - /// - /// L1 data. - /// - L1Data = 6, - - /// - /// base addr. - /// - PrivateBase = 64, -} /// /// the scheduler interface. diff --git a/src/Nncase.Core/TIR/Buffer.cs b/src/Nncase.Core/TIR/Buffer.cs index a3a35aac52..9289d1afec 100644 --- a/src/Nncase.Core/TIR/Buffer.cs +++ b/src/Nncase.Core/TIR/Buffer.cs @@ -269,7 +269,7 @@ public SelectedRange Slice(Segment1D segment) /// public abstract class Buffer : Expr { - public Buffer(string name, DataType elemType, Schedule.MemoryLocation memoryLocation, Expr[] operands) + public Buffer(string name, DataType elemType, MemoryLocation memoryLocation, Expr[] operands) : base(operands.ToArray()) { Name = name; @@ -281,7 +281,7 @@ public Buffer(string name, DataType elemType, Schedule.MemoryLocation memoryLoca public DataType ElemType { get; } - public Schedule.MemoryLocation MemLocation { get; } + public MemoryLocation MemLocation { get; } /// /// Gets if this buffer from the constant !. @@ -341,7 +341,7 @@ public sealed class LogicalBuffer : Buffer /// prim type. /// the shape. /// the strides. - public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides) + public LogicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides) : base(name, elemType, location, ArrayUtility.Concat(dimensions, strides)) { Rank = dimensions.Length; @@ -351,7 +351,7 @@ public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation loc /// Initializes a new instance of the class. /// . /// - public LogicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor) + public LogicalBuffer(string name, MemoryLocation location, TensorConst tensor) : this(name, tensor.Value.ElementType, location, ArrayUtility.ToExprArray(tensor.Value.Dimensions), ArrayUtility.ToExprArray(tensor.Value.Strides)) { Const = tensor; @@ -361,7 +361,7 @@ public LogicalBuffer(string name, Schedule.MemoryLocation location, TensorConst /// Initializes a new instance of the class. /// /// - public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions) + public LogicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions) : this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions)) { } @@ -394,7 +394,7 @@ public override string ToString() public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitLogicalBuffer(this, context); - public LogicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, Expr[]? dimensions = null, Expr[]? strides = null) + public LogicalBuffer With(string? name = null, DataType? elemType = null, MemoryLocation? location = null, Expr[]? dimensions = null, Expr[]? strides = null) => new LogicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? Dimensions, strides ?? Strides) { Const = Const }; } @@ -410,7 +410,7 @@ public sealed class PhysicalBuffer : Buffer /// Initializes a new instance of the class. /// ctor for physical buffer. /// - public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides, int start, int size) + public PhysicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides, int start, int size) : base(name, elemType, location, Array.Empty()) { Start = start; @@ -423,7 +423,7 @@ public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation lo /// Initializes a new instance of the class. /// . /// - public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions, int start, int size) + public PhysicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions, int start, int size) : this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions), start, size) { } @@ -432,7 +432,7 @@ public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation lo /// Initializes a new instance of the class. /// . /// - public PhysicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor, int start, int size) + public PhysicalBuffer(string name, MemoryLocation location, TensorConst tensor, int start, int size) : this(name, tensor.Value.ElementType, location, tensor.Value.Dimensions, tensor.Value.Strides, start, size) { Const = tensor; @@ -494,6 +494,6 @@ public override bool Equals(object? obj) public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitPhysicalBuffer(this, context); - public PhysicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, int[]? dimensions = null, int[]? strides = null, int? start = null, int? size = null) + public PhysicalBuffer With(string? name = null, DataType? elemType = null, MemoryLocation? location = null, int[]? dimensions = null, int[]? strides = null, int? start = null, int? size = null) => new PhysicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? FixedDimensions, strides ?? FixedStrides, start ?? Start, size ?? Size) { Const = Const }; } diff --git a/src/Nncase.Core/TIR/MemSpan.cs b/src/Nncase.Core/TIR/MemSpan.cs new file mode 100644 index 0000000000..c360932c46 --- /dev/null +++ b/src/Nncase.Core/TIR/MemSpan.cs @@ -0,0 +1,86 @@ +using Nncase; +using Nncase.IR; + +namespace Nncase.TIR; + + +/// +/// the memory type. +/// +public enum MemoryLocation : byte +{ + /// + /// input. + /// + Input = 0, + + /// + /// output. + /// + Output = 1, + + /// + /// constant data. + /// + Rdata = 2, + + /// + /// compute temp data. + /// + Data = 3, + + /// + /// shared data. + /// + SharedData = 4, + + /// + /// l2 data. + /// + L2Data = 5, + + /// + /// L1 data. + /// + L1Data = 6, + + /// + /// base addr. + /// + PrivateBase = 64, +} + +public sealed class MemSpan : Expr +{ + public MemSpan(Expr size, MemoryLocation location) : base(new[] { None.Default, size }) + { + Location = location; + } + + public MemSpan(Expr start, Expr size, MemoryLocation location) : base(new[] { start, size }) + { + Location = location; + } + + /// + /// Gets the start. + /// + public Expr Start => Operands[0]; + + /// + /// Gets the size of bytes. + /// + public Expr Size => Operands[1]; + + /// + /// Gets the memory location. + /// + public MemoryLocation Location { get; } + + /// + public override TExprResult Accept(ExprFunctor functor, TContext context) + => functor.VisitMemSpan(this, context); + + + public MemSpan With(Expr? start = null, Expr? size = null, MemoryLocation? location = null) => new(start ?? Start, size ?? Size, location ?? Location); +} \ No newline at end of file diff --git a/src/Nncase.Core/TIR/Ops.cs b/src/Nncase.Core/TIR/Ops.cs index 76f9e395b6..cb6446fb76 100644 --- a/src/Nncase.Core/TIR/Ops.cs +++ b/src/Nncase.Core/TIR/Ops.cs @@ -12,7 +12,7 @@ namespace Nncase.TIR; /// -/// . +/// Load op. /// public sealed partial class Load : Op { @@ -24,7 +24,10 @@ public sealed partial class Load : Op /// /// Gets index. /// - public static readonly ParameterInfo Index = new(typeof(Load), 1, "index", HasDataType(DataTypes.Int32) & (IsScalar() | HasRank(1))); + public static readonly ParameterInfo Index = new(typeof(Load), 1, "index", IsIntegralScalar()); + + /// + public override bool CanFoldConstCall => false; } /// @@ -53,17 +56,20 @@ public sealed partial class Store : Op /// /// The buffer variable handle. /// - public static readonly ParameterInfo Handle = new(typeof(Store), 0, "handle", IsPointer()); + public static readonly ParameterInfo Handle = new(typeof(Store), 0, "handle"); /// /// The index locations to be stored. /// - public static readonly ParameterInfo Index = new(typeof(Store), 1, "index", HasDataType(DataTypes.Int32)); + public static readonly ParameterInfo Index = new(typeof(Store), 1, "index", IsIntegralScalar()); /// /// The value to be stored. /// - public static readonly ParameterInfo Value = new(typeof(Store), 2, "value"); + public static readonly ParameterInfo Value = new(typeof(Store), 2, "value", IsScalar()); + + /// + public override bool CanFoldConstCall => false; } /// diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index 9d9a212e46..8f640a0419 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -52,7 +52,7 @@ public static class T /// /// The buffer handle variable in the load expression. /// The index in the load. - public static Call Load(Var handle, Expr index) => new Call(new Load(), handle, index); + public static Call Load(TIR.Buffer handle, Expr index) => new Call(new Load(), handle, index); /// /// get the nop op. @@ -76,25 +76,7 @@ public static class T /// The buffer Variable. /// The index in the store expression. /// The value we want to store. - public static Call Store(Var handle, Expr index, Expr value) => new Call(new Store(), handle, index, value); - - /// - /// If the op is BufferLoad, it will return BufferStore - /// If the op is Load, it will return Store. - /// - /// the op call. - /// update value. - /// new store call. - public static Expr Store(Expr op, Expr value) => op switch - { - Call load => load.Target switch - { - TIR.Load => T.Store((Var)load[TIR.Load.Handle], load[TIR.Load.Index], value), - _ => throw new InvalidOperationException("Only Can build Store Op from Load!"), - }, - TIR.BufferLoad bufload => new BufferStore(bufload.Buffer, bufload.Indices, value), - _ => throw new InvalidOperationException("Only Can build Store Op from Load!"), - }; + public static Call Store(TIR.Buffer handle, Expr index, Expr value) => new Call(new Store(), handle, index, value); /// /// build for loop. @@ -226,7 +208,7 @@ public static IIfThenElseBuilder If(Expr condition) /// /// create the memRef by tensortype. /// - public static LogicalBuffer Buffer(DataType elem_type, Schedule.MemoryLocation location, ReadOnlySpan dimensions, out LogicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static LogicalBuffer Buffer(DataType elem_type, MemoryLocation location, ReadOnlySpan dimensions, out LogicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { @@ -240,7 +222,7 @@ public static LogicalBuffer Buffer(DataType elem_type, Schedule.MemoryLocation l /// /// ctor for physical buffer. /// - public static PhysicalBuffer PhysicalBuffer(DataType elem_type, Schedule.MemoryLocation location, ReadOnlySpan dimensions, out PhysicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static PhysicalBuffer PhysicalBuffer(DataType elem_type, MemoryLocation location, ReadOnlySpan dimensions, out PhysicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { @@ -271,7 +253,7 @@ public static PhysicalBuffer ConstBuffer(Const expr, out PhysicalBuffer buffer, throw new NotSupportedException(); } - buffer = new PhysicalBuffer(name, Schedule.MemoryLocation.Rdata, (TensorConst)expr, 0, size); + buffer = new PhysicalBuffer(name, MemoryLocation.Rdata, (TensorConst)expr, 0, size); return buffer; } @@ -294,7 +276,7 @@ public static Expr MayBeConst(Const? expr, out Buffer? buffer, [CallerArgumentEx { name = name[4..]; } - buffer = new Buffer(name, Schedule.MemoryLocation.Rdata, (TensorType)expr.ValueType) + buffer = new Buffer(name, MemoryLocation.Rdata, (TensorType)expr.ValueType) { Const = expr, }; diff --git a/src/Nncase.Evaluator/TIR/Load.cs b/src/Nncase.Evaluator/TIR/Load.cs index 6ea6faddff..86885bcc48 100644 --- a/src/Nncase.Evaluator/TIR/Load.cs +++ b/src/Nncase.Evaluator/TIR/Load.cs @@ -30,12 +30,6 @@ public string Visit(IIRPrinterContext context, Load target, bool iLmode) private IRType Visit(Load target, TensorType handle, TensorType index) { - if (!handle.IsScalar && handle.DType is not PointerType) - { - throw new NotSupportedException(handle.DType.ToString()); - } - - _ = index.IsScalar ? 1 : index.Shape[0].FixedValue; - return TensorType.Scalar(((PointerType)handle.DType).ElemType); + return TensorType.Scalar(handle.DType); } } diff --git a/src/Nncase.Evaluator/TIR/Store.cs b/src/Nncase.Evaluator/TIR/Store.cs index b29459bfe2..573a5e8660 100644 --- a/src/Nncase.Evaluator/TIR/Store.cs +++ b/src/Nncase.Evaluator/TIR/Store.cs @@ -33,12 +33,10 @@ public string Visit(IIRPrinterContext context, Store target, bool iLmode) private IRType Visit(Store target, TensorType handle, TensorType index, TensorType value) { - _ = index.IsScalar ? 1 : index.Shape[0].FixedValue; - - var elemType = ((PointerType)handle.DType).ElemType; - if (elemType != value.DType) + + if (handle.DType != value.DType) { - return new InvalidType($"You Can't Load The {value.DType} To {elemType}"); + return new InvalidType($"You Can't Load The {value.DType} To {handle.DType}"); } return TupleType.Void; diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index 8afdb3c5e0..e26a62ec0c 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -23,7 +23,7 @@ namespace Nncase.Passes; /// public sealed class DDrBufferSchdeulePass : ModulePass { - private readonly Dictionary> _module_usage = new(); + private readonly Dictionary> _module_usage = new(); private readonly Dictionary> _module_hashset = new(); @@ -106,12 +106,12 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon /// internal sealed class DDrBufferAllocator : ExprVisitor { - private readonly Dictionary _functionUsage; + private readonly Dictionary _functionUsage; private readonly HashSet _functionHashset; private PrimFunction? _entry; - public DDrBufferAllocator(Dictionary> module_usage, Dictionary> module_hashset) + public DDrBufferAllocator(Dictionary> module_usage, Dictionary> module_hashset) { ModuleUsage = module_usage; ModuleHashSet = module_hashset; @@ -120,13 +120,13 @@ public DDrBufferAllocator(Dictionary> ModuleUsage { get; } + public Dictionary> ModuleUsage { get; } public Dictionary> ModuleHashSet { get; } public bool Changed { get; private set; } - public int DataUsage => _functionUsage.GetValueOrDefault(Schedule.MemoryLocation.Data, 0); + public int DataUsage => _functionUsage.GetValueOrDefault(MemoryLocation.Data, 0); /// /// only visit one prim func. @@ -138,7 +138,7 @@ protected override bool VisitPrimFunction(PrimFunction primFunction) { foreach (var physical in primFunction.Parameters) { - if (physical.MemLocation is Schedule.MemoryLocation.Input or Schedule.MemoryLocation.Output) + if (physical.MemLocation is MemoryLocation.Input or MemoryLocation.Output) { // avoid visit same buffer if (!_functionHashset.Contains(physical)) @@ -175,7 +175,7 @@ protected override bool VisitLeafBuffer(TIR.Buffer buffer) } // rdata write into the moduleUsage - if (physical.MemLocation is Schedule.MemoryLocation.Rdata) + if (physical.MemLocation is MemoryLocation.Rdata) { if (!ModuleHashSet.TryGetValue(_entry!.ModuleKind, out var module_hashset)) { @@ -204,7 +204,7 @@ protected override bool VisitLeafBuffer(TIR.Buffer buffer) Changed = true; } } - else if (physical.MemLocation is Schedule.MemoryLocation.Data) + else if (physical.MemLocation is MemoryLocation.Data) { // data write into the FunctionUsage if (!_functionHashset.Contains(physical)) @@ -220,7 +220,7 @@ protected override bool VisitLeafBuffer(TIR.Buffer buffer) Changed = true; } } - else if (physical.MemLocation is Schedule.MemoryLocation.SharedData) + else if (physical.MemLocation is MemoryLocation.SharedData) { throw new NotSupportedException("Current Not Support!"); } diff --git a/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs b/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs index 30cd890f12..9e1c2bc22c 100644 --- a/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs +++ b/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs @@ -98,7 +98,7 @@ public PrimFuncMergeRule(HashSet mergedFuncs) } // 2. chack and create the data buffer - if (calleeFunc.Parameters.ToArray().Count(b => b.MemLocation == Schedule.MemoryLocation.Output) != 1) + if (calleeFunc.Parameters.ToArray().Count(b => b.MemLocation == MemoryLocation.Output) != 1) { // the direct call mean the callee function only have one output. return null; @@ -128,7 +128,7 @@ public PrimFuncMergeRule(HashSet mergedFuncs) // 5. build the new call. var nameWrapper = callerWrapper.Name; // + '_' + calleeWrapper.Name; - var newWrapper = new PrimFunctionWrapper(nameWrapper, newFunc, newFuncParams.Count(b => b.MemLocation == Schedule.MemoryLocation.Input)); + var newWrapper = new PrimFunctionWrapper(nameWrapper, newFunc, newFuncParams.Count(b => b.MemLocation == MemoryLocation.Input)); var newCallParams = new List(); newCallParams.AddRange(callerParams.Take(calleeBufferIndexs[0])); @@ -151,10 +151,10 @@ private bool BufferCanMerge(TIR.PhysicalBuffer retBuffer, TIR.PhysicalBuffer inB retBuffer.FixedStrides.SequenceEqual(inBuffer.FixedStrides) && retBuffer.ElemType == inBuffer.ElemType && retBuffer.Size == inBuffer.Size && - retBuffer.MemLocation == Schedule.MemoryLocation.Output && - inBuffer.MemLocation == Schedule.MemoryLocation.Input) + retBuffer.MemLocation == MemoryLocation.Output && + inBuffer.MemLocation == MemoryLocation.Input) { - dataBuffer = new TIR.PhysicalBuffer(inBuffer.Name, inBuffer.ElemType, Schedule.MemoryLocation.Data, inBuffer.FixedDimensions, inBuffer.FixedStrides, inBuffer.Start, inBuffer.Size); + dataBuffer = new TIR.PhysicalBuffer(inBuffer.Name, inBuffer.ElemType, MemoryLocation.Data, inBuffer.FixedDimensions, inBuffer.FixedStrides, inBuffer.Start, inBuffer.Size); return true; } diff --git a/src/Nncase.Tests/CodeGen/CSourceHostCases.cs b/src/Nncase.Tests/CodeGen/CSourceHostCases.cs index d6be29ed3b..97101d3c15 100644 --- a/src/Nncase.Tests/CodeGen/CSourceHostCases.cs +++ b/src/Nncase.Tests/CodeGen/CSourceHostCases.cs @@ -27,8 +27,8 @@ public class SubCase : ICodeGenCase public override PrimFunction GetEntry() { var func = T.PrimFunc("sub", - T.Buffer(TensorType.Scalar(DataTypes.Float32), Schedule.MemoryLocation.Input, out var x), - T.Buffer(TensorType.Scalar(DataTypes.Float32), Schedule.MemoryLocation.Input, out var y)).Body( + T.Buffer(TensorType.Scalar(DataTypes.Float32), MemoryLocation.Input, out var x), + T.Buffer(TensorType.Scalar(DataTypes.Float32), MemoryLocation.Input, out var y)).Body( x - y ); return func; @@ -71,8 +71,8 @@ public override void CompareEqual(IRTModel rtmod) public override PrimFunction GetEntry() { return T.PrimFunc("for_loop", - T.Buffer(new(DataTypes.Int32, new[] { 100 }), Schedule.MemoryLocation.Input, out var A), - T.Buffer(TensorType.Scalar(DataTypes.Int32), Schedule.MemoryLocation.Input, out var n) + T.Buffer(new(DataTypes.Int32, new[] { 100 }), MemoryLocation.Input, out var A), + T.Buffer(TensorType.Scalar(DataTypes.Int32), MemoryLocation.Input, out var n) ).Body( T.Serial(out var i, n).Body( T.Store(A[i], A[i] + 1), diff --git a/src/Nncase.Tests/Core/UnitTestExpression.cs b/src/Nncase.Tests/Core/UnitTestExpression.cs index 8f29fbdab5..9cbe60f849 100644 --- a/src/Nncase.Tests/Core/UnitTestExpression.cs +++ b/src/Nncase.Tests/Core/UnitTestExpression.cs @@ -262,7 +262,7 @@ public void TestConstBufferNotEqual() { var c = IR.F.Random.Normal(DataTypes.Float32, 1, 0, 0, new[] { 1, 16, 64, 400 }).Evaluate().AsTensor(); var ddr_ld_input = new TIR.BufferRegion(Nncase.TIR.T.ConstBuffer(Const.FromTensor(c), out _, "ddr_ld_input"), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.NotEqual(ddr_ld_input.Buffer, ddr_ld_output.Buffer); Assert.NotEqual(ddr_ld_input, ddr_ld_output); } @@ -270,8 +270,8 @@ public void TestConstBufferNotEqual() [Fact] public void TestBufferEqual() { - var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.Equal(ddr_ld_input.Buffer, ddr_ld_output.Buffer); Assert.Equal(ddr_ld_input, ddr_ld_output); } @@ -279,8 +279,8 @@ public void TestBufferEqual() [Fact] public void TestBufferNotEqual() { - var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var glb_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("glb_ld_output", DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var glb_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("glb_ld_output", DataTypes.BFloat16, MemoryLocation.Data, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.False(ddr_ld_input.Buffer.Equals(glb_ld_output.Buffer)); Assert.False(ddr_ld_input.Equals(glb_ld_output)); } diff --git a/src/Nncase.Tests/Core/UnitTestStringUtility.cs b/src/Nncase.Tests/Core/UnitTestStringUtility.cs index 0b01ae0fd6..c17f577ee2 100644 --- a/src/Nncase.Tests/Core/UnitTestStringUtility.cs +++ b/src/Nncase.Tests/Core/UnitTestStringUtility.cs @@ -16,14 +16,14 @@ namespace Nncase.Tests.CoreTest; public static class TestExtensions { - public static ArrayExtensions.SpanWhereEnumerable> InputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == Schedule.MemoryLocation.Input); + public static ArrayExtensions.SpanWhereEnumerable> InputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == MemoryLocation.Input); - public static ArrayExtensions.SpanWhereEnumerable> OutputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == Schedule.MemoryLocation.Output); + public static ArrayExtensions.SpanWhereEnumerable> OutputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == MemoryLocation.Output); } public sealed class UnitTestStringUtility { - private readonly TIR.PrimFunction _entry = new("test_module", new Sequential(1), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0)); + private readonly TIR.PrimFunction _entry = new("test_module", new Sequential(1), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0)); [Fact] public void TestJoin() diff --git a/src/Nncase.Tests/Core/UnitTestTIR.cs b/src/Nncase.Tests/Core/UnitTestTIR.cs index ba7131f64e..0fce4d672d 100644 --- a/src/Nncase.Tests/Core/UnitTestTIR.cs +++ b/src/Nncase.Tests/Core/UnitTestTIR.cs @@ -47,21 +47,10 @@ public void TestScheduler() [Fact] public void TestBufferStore() { - Assert.Throws(() => T.Store(null!, null!)); - - var variable = new Var("x", DataTypes.Int32); - int index = 0; - Expr loadOp = T.Load(variable, index); Expr value = 42; - _ = T.Store(loadOp, value); - - var physicalBuffer = new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0); + var physicalBuffer = new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0); var indices = new Expr[] { 0, 1 }; - Expr storeOp = T.Store(new BufferLoad(physicalBuffer, indices), value); - var store = (BufferStore)storeOp; - Assert.Equal(physicalBuffer, store.Buffer); - Assert.Equal(value, store.Value); - Assert.Equal(new Expr[] { 0 }, store.Indices.ToArray()); + Call store = T.Store(physicalBuffer, 0, value); } [Fact] @@ -165,8 +154,8 @@ public void TestPrimFunction() { var primFunc = new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), new[] { - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), }); var primFuncParameters = primFunc.Parameters; @@ -178,8 +167,8 @@ public void TestPrimFunction() var newBody = new Sequential(new Expr[] { 3 }); var newParams = new[] { - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), }; var newPrimFunc = primFunc.With(moduleKind: newModuleKind, body: newBody, parameters: newParams); diff --git a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs index 31c6d763ae..bd9bcc3c5c 100644 --- a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs +++ b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs @@ -65,7 +65,7 @@ public void TestDumpFusion() [Fact] public void TestDumpScript() { - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); Assert.True(CompilerServices.InferenceType(prim_func_1)); diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs index 3f230b49c5..0fdb6fe4da 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs @@ -97,10 +97,10 @@ public void TestOnnxResizeImage() public void TestLoadStore() { var loop_i = new Var(TensorType.Scalar(DataTypes.Int32)); - var load = T.Load(T.Handle("hd", DataTypes.Float32), loop_i); + T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3 }, out var bf); + var load = T.Load(bf, loop_i); CompilerServices.InferenceType(load); - - var store = T.Store((Var)load[TIR.Load.Handle], load[TIR.Load.Index], loop_i); + var store = T.Store(bf, loop_i, IR.F.Tensors.Cast(loop_i, DataTypes.Float32)); CompilerServices.InferenceType(store); } diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs index 04c7a43c44..bed1177d0d 100644 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs @@ -29,7 +29,7 @@ public class UnitTestEvaluatorBuffers : TestClassBase public void TestUninitialized() { var shape = new[] { 1 }; - var expr = IR.F.Buffer.Uninitialized(DataTypes.Float32, MemoryLocation.Input, shape); + var expr = IR.F.Buffer.Uninitialized(DataTypes.Float32, TIR.MemoryLocation.Input, shape); CompilerServices.InferenceType(expr); Assert.Equal(Value.None, expr.Evaluate()); } diff --git a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs index 9b791049d0..e514c3e09a 100644 --- a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs +++ b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs @@ -33,11 +33,11 @@ internal static class PrimFuncBuilder public static PrimFunctionWrapper MakeLoadStoreFunc(bool mask) { var allocator = new Allocator(); - var fusion_input = allocator.Allocate($"fusion_{_count}_input", Schedule.MemoryLocation.Input); + var fusion_input = allocator.Allocate($"fusion_{_count}_input", TIR.MemoryLocation.Input); - var glb = allocator.Allocate($"fusion_{_count}_glb", Schedule.MemoryLocation.L2Data); + var glb = allocator.Allocate($"fusion_{_count}_glb", TIR.MemoryLocation.L2Data); - var fusion_output = allocator.Allocate($"fusion_{_count}_output", Schedule.MemoryLocation.Output); + var fusion_output = allocator.Allocate($"fusion_{_count}_output", TIR.MemoryLocation.Output); var fusion_1 = TIR.T.PrimFunc($"fusion_{_count}_{mask}", Callable.StackVMModuleKind, fusion_input, fusion_output).Body( new Call(new TIRTest.LoadT(), fusion_input, glb), @@ -50,12 +50,12 @@ public static PrimFunctionWrapper MakeLoadStoreFunc(bool mask) public static PrimFunctionWrapper MakeBinaryFunc(BinaryOp binaryOp, bool mask) { var allocator = new Allocator(); - var fusion_input_lhs = allocator.Allocate($"fusion_{_count}_input_lhs", Schedule.MemoryLocation.Input); - var fusion_input_rhs = allocator.Allocate($"fusion_{_count}_input_rhs", Schedule.MemoryLocation.Input); - var glb_lhs = allocator.Allocate($"fusion_{_count}_glb_lhs", Schedule.MemoryLocation.L2Data); - var glb_rhs = allocator.Allocate($"fusion_{_count}_glb_rhs", Schedule.MemoryLocation.L2Data); - var glb_output = allocator.Allocate($"fusion_{_count}_glb_output", Schedule.MemoryLocation.L2Data); - var fusion_output = allocator.Allocate($"fusion_{_count}_output", Schedule.MemoryLocation.Output); + var fusion_input_lhs = allocator.Allocate($"fusion_{_count}_input_lhs", TIR.MemoryLocation.Input); + var fusion_input_rhs = allocator.Allocate($"fusion_{_count}_input_rhs", TIR.MemoryLocation.Input); + var glb_lhs = allocator.Allocate($"fusion_{_count}_glb_lhs", TIR.MemoryLocation.L2Data); + var glb_rhs = allocator.Allocate($"fusion_{_count}_glb_rhs", TIR.MemoryLocation.L2Data); + var glb_output = allocator.Allocate($"fusion_{_count}_glb_output", TIR.MemoryLocation.L2Data); + var fusion_output = allocator.Allocate($"fusion_{_count}_output", TIR.MemoryLocation.Output); var fusion = TIR.T.PrimFunc($"fusion_{_count}_{mask}", Callable.StackVMModuleKind, fusion_input_lhs, fusion_input_rhs, fusion_output).Body( new Call(new TIRTest.LoadT(), fusion_input_lhs, glb_lhs), @@ -74,13 +74,13 @@ public static PrimFunctionWrapper MakeMultiInputFunc(int length, bool mask) var fusion_inputs = new List(); for (int i = 0; i < length; i++) { - var fusion_input_i = allocator.Allocate($"fusion_{_count}_input_{i}", Schedule.MemoryLocation.Input); + var fusion_input_i = allocator.Allocate($"fusion_{_count}_input_{i}", TIR.MemoryLocation.Input); fusion_inputs.Add(fusion_input_i); } - var glb1 = allocator.Allocate($"fusion_{_count}_glb1", Schedule.MemoryLocation.L2Data); - var glb2 = allocator.Allocate($"fusion_{_count}_glb2", Schedule.MemoryLocation.L2Data); - var fusion_output = allocator.Allocate($"fusion_{_count}_output", Schedule.MemoryLocation.Output); + var glb1 = allocator.Allocate($"fusion_{_count}_glb1", TIR.MemoryLocation.L2Data); + var glb2 = allocator.Allocate($"fusion_{_count}_glb2", TIR.MemoryLocation.L2Data); + var fusion_output = allocator.Allocate($"fusion_{_count}_output", TIR.MemoryLocation.Output); var fusion = TIR.T.PrimFunc($"multi_fusion_{_count}_{mask}", Callable.StackVMModuleKind, fusion_inputs.Concat(new[] { fusion_output }).ToArray()); @@ -124,13 +124,13 @@ private static IEnumerable GetBinaryOp(int length) private sealed class Allocator { - private readonly Dictionary _useage = new() { - { Schedule.MemoryLocation.Input, 0 }, - { Schedule.MemoryLocation.Output, 0 }, - { Schedule.MemoryLocation.L2Data, 0 }, + private readonly Dictionary _useage = new() { + { TIR.MemoryLocation.Input, 0 }, + { TIR.MemoryLocation.Output, 0 }, + { TIR.MemoryLocation.L2Data, 0 }, }; - public TIR.PhysicalBuffer Allocate(string name, Schedule.MemoryLocation location) + public TIR.PhysicalBuffer Allocate(string name, TIR.MemoryLocation location) { var strides = TensorUtilities.GetStrides(Dimensions); var size = TensorUtilities.GetSize(Dimensions, strides, DataTypes.Float32.SizeInBytes); diff --git a/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs b/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs index 9cd1eb7139..f5183e5127 100644 --- a/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs +++ b/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs @@ -121,11 +121,11 @@ internal sealed class PrimFuncEvaluateVisitor private static readonly int _pool_size = 1 * 4 * 8 * 9 * 4 * 30; private readonly PrimFunctionWrapper _wrapper; private readonly IValue[] _args; - private readonly Dictionary _poolMap = new() { - { Schedule.MemoryLocation.Input, new byte[_pool_size] }, - { Schedule.MemoryLocation.L2Data, new byte[_pool_size] }, - { Schedule.MemoryLocation.Data, new byte[_pool_size] }, - { Schedule.MemoryLocation.Output, new byte[_pool_size] }, + private readonly Dictionary _poolMap = new() { + { TIR.MemoryLocation.Input, new byte[_pool_size] }, + { TIR.MemoryLocation.L2Data, new byte[_pool_size] }, + { TIR.MemoryLocation.Data, new byte[_pool_size] }, + { TIR.MemoryLocation.Output, new byte[_pool_size] }, }; public PrimFuncEvaluateVisitor(PrimFunctionWrapper wrapper, params IValue[] args) diff --git a/src/Nncase.Tests/TIR/UnitTestMutators.cs b/src/Nncase.Tests/TIR/UnitTestMutators.cs index 1df67941ae..a20d5f7295 100644 --- a/src/Nncase.Tests/TIR/UnitTestMutators.cs +++ b/src/Nncase.Tests/TIR/UnitTestMutators.cs @@ -30,9 +30,9 @@ public UnitTestMutators() [Fact] public async Task TestFoldConstCallWithTuple() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 48 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 9 }, out var glb_if_ping); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 9 }, out var glb_if_pong); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 48 }, out var ddr_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 9 }, out var glb_if_ping); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 9 }, out var glb_if_pong); PrimFunction main; { main = T.PrimFunc("main", Callable.StackVMModuleKind, ddr_if).Body( @@ -118,8 +118,8 @@ public async Task TestUnRollLoopSequential() [Fact] public async Task TestUnRollLoopSequential2() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); PrimFunction main; { @@ -201,8 +201,8 @@ public async Task TestUnRollLoopSequential2() [Fact] public async Task TestUnRollLoopSequential3() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); PrimFunction main; { @@ -362,9 +362,9 @@ public async Task TestFoldLet2() [Fact] public async Task TestFoldBufferIndex() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Output, new[] { 3, 16, 24, 24 }, out var ddr_of); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Output, new[] { 3, 16, 24, 24 }, out var ddr_of); + T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); var bufferIndexMap = new Dictionary() { { ddr_if, 2 }, { ddr_of, 4 }, diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs new file mode 100644 index 0000000000..28ff7ccf07 --- /dev/null +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -0,0 +1,51 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Nncase.CodeGen; +using Nncase.IR; +using Nncase.IR.Tensors; +using Nncase.Runtime.Interop; +using Nncase.Targets; +using Nncase.Tests.TestFixture; +using Nncase.Utilities; +using Xunit; +using static Nncase.IR.F.Tensors; +using GetItem = Nncase.IR.Tensors.GetItem; + +namespace Nncase.Tests.Targets; + +[AutoSetupTestMethod(InitSession = true)] +public class UnitTestCPUTargetTiling : TestClassBase +{ + public UnitTestCPUTargetTiling() + { + DefaultTargetName = CPUTarget.Kind; +#if DEBUG + CompileOptions.DumpFlags = Diagnostics.DumpFlags.PassIR | Diagnostics.DumpFlags.Rewrite | Diagnostics.DumpFlags.CodeGen; +#endif + } + + [Fact] + public async Task TestCpuUnary() + { + var input = new Var("input", new TensorType(DataTypes.Float32, new[] { 1, 2, 3, 4, 5 })); + var main = new Function("main", IR.F.Math.Unary(UnaryOp.Asin, input), new[] { input }); + var module = new IR.IRModule(main); + + var compiler = CompileSession.Compiler; + compiler.ImportIRModule(module); + await compiler.CompileAsync(); + using (var fs = new MemoryStream()) + { + compiler.Gencode(fs); + } + } +} diff --git a/src/Nncase.Tests/Transform/UnitTestPassManager.cs b/src/Nncase.Tests/Transform/UnitTestPassManager.cs index bc7cb98896..85d4cf1442 100644 --- a/src/Nncase.Tests/Transform/UnitTestPassManager.cs +++ b/src/Nncase.Tests/Transform/UnitTestPassManager.cs @@ -22,7 +22,7 @@ public sealed class UnitTestPassManager : TestClassBase [Fact] public void TestPassMangerUpdateDependence() { - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); @@ -30,7 +30,7 @@ public void TestPassMangerUpdateDependence() var main_func = new Function("main", new Call(prim_wrapper, input), input); // prim_func_2 for update - var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body( + var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body( T.Nop(), T.Nop()).Build(); @@ -54,15 +54,15 @@ public void TestPassMangerUpdateDependence2() %3 = %func_3(%2): // f16[1,23,30,16] */ - var prim_func_0 = T.PrimFunc("prim_func_0", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 24, 32, 3 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_0 = T.PrimFunc("prim_func_0", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 24, 32, 3 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( T.Nop()).Build(); var func_0 = new PrimFunctionWrapper(prim_func_0, 1); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( T.Nop()).Build(); var func_1 = new PrimFunctionWrapper(prim_func_1, 1); - var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 23, 30, 16 }, out var _)).Body( + var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 23, 30, 16 }, out var _)).Body( T.Nop()).Build(); var func_2 = new PrimFunctionWrapper(prim_func_2, 1); @@ -74,7 +74,7 @@ public void TestPassMangerUpdateDependence2() Assert.True(CompilerServices.InferenceType(main_func)); // prim_func_2 for update - var prim_func_1_update = T.PrimFunc("prim_func_1_update", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_1_update = T.PrimFunc("prim_func_1_update", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( T.Nop(), T.Nop()).Build(); diff --git a/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs b/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs index 9313303252..f6a1ed969e 100644 --- a/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs +++ b/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs @@ -24,8 +24,9 @@ public sealed class UnitTestSubstitutor : TestClassBase public void TestSubstitutorFailed() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( - T.Load(T.Handle("hd", DataTypes.Float32), loop_i)).Build(); + T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3, 4 }, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( + T.Load(hd, loop_i)).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); @@ -48,8 +49,9 @@ public void TestSubstitutorFailed() public void TestSubstitutorTrue() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( - T.Load(T.Handle("hd", DataTypes.Float32), loop_i)).Build(); + T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3, 4 }, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( + T.Load(hd, loop_i)).Build(); Dictionary vmap = new() { { loop_i, 1 } }; var substitutor = Mutator.Substitute(e => vmap.TryGetValue(e, out var res) ? res : null)(); @@ -65,8 +67,9 @@ public void TestSubstitutorTrue() public void TestSubstitutorTrue2() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Int32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( - T.Load(T.Handle("hd", DataTypes.Float32), loop_i)).Build(); + T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3, 4 }, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Int32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( + T.Load(hd, loop_i)).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); diff --git a/targets/Nncase.Targets.CSource/CSourceTarget.cs b/targets/Nncase.Targets.CSource/CSourceTarget.cs deleted file mode 100644 index e6f3e19a8f..0000000000 --- a/targets/Nncase.Targets.CSource/CSourceTarget.cs +++ /dev/null @@ -1,34 +0,0 @@ - -namespace Nncase.Targets; - -public class CSourceTarget : ITarget -{ - /// - public string Kind { get => "CSource"; set { } } - /// - public Dictionary Options { get; set; } = new(); - /// - public Dictionary Attrs { get; set; } = new(); - /// - public void ConfigOptions() { } - /// - public void ConfigAttrs() { } - - /// - public Schedule.IScheduler CreateScheduler(IR.IRModule main_module) - { - return new Schedule.CSourceScheduler(main_module, this); - } - - /// - public CodeGen.IRTModel CreateRTModel(IR.IRModel model) - { - return new CodeGen.CSourceRTModel(model, this); - } - - /// - public CodeGen.IRTModule CreateRTModule(IR.IRModel model, IR.IRModule module) - { - throw new NotImplementedException("The CSource Target Only Have Runtime Model!"); - } -} diff --git a/targets/Nncase.Targets.CSource/CodeGen/CSource.cs b/targets/Nncase.Targets.CSource/CodeGen/CSource.cs deleted file mode 100644 index afc0de3b68..0000000000 --- a/targets/Nncase.Targets.CSource/CodeGen/CSource.cs +++ /dev/null @@ -1,278 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Nncase.IR; -using Nncase.Schedule; -using Nncase.TIR; - -namespace Nncase.CodeGen; - -/// -/// the c source runtime function. -/// -/// -/// -public record CSourceRTFunction(string name, Delegate handle) : IRTFunction -{ - public string Name { get => name; set { } } - public Delegate Handle { get => handle; set { } } -} - -public class CSourceSerializeResult : ISerializeResult -{ - -} - -/// -/// c runtime module impl -/// -public class CSourceRTModel : IRTModule, IRTModel -{ - /// - public ModuleType ModuleType { get => CodeGen.ModuleType.Create("CSource"); set { } } - - /// - public ITarget Target { get; set; } - - /// - public IReadOnlyList Modules => throw new NotImplementedException(); - - /// - public string SourcePath { get; private set; } - - public IRModel Model { get; set; } - IRTFunction? _entry = null; - - /// - public bool IsSerialized { get; private set; } - - readonly List _functions = new(); - - /// - /// - /// - public CSourceRTModel(IRModel model, ITarget target) - { - SourcePath = CodeGenUtil.GetTempFileName("c"); - Model = model; - Target = target; - } - - /// - public byte[] Source { get => File.ReadAllBytes(SourcePath); set { } } - - /// - public string SourceExt { get => "c"; set { } } - - /// - public IRTFunction? Entry => _entry; - - /// - public IReadOnlyList Functions => _functions; - - /// - string _dllPath = ""; - - /// - /// write the c source code into source path. - /// - /// - void BuildCode() - { - if (File.Exists(SourcePath)) - File.Delete(SourcePath); - using (var writer = new StreamWriter(SourcePath, false, Encoding.UTF8)) - { - var visior = new CSourceHostBuildVisior(writer); - if (Model.Entry is null) { throw new InvalidProgramException("The Model Entry Is Null!"); } - if (Model.Entry.CheckedType is null && Model.Entry.InferenceType() == false) { throw new InvalidProgramException("The Model Entry Can't Inference Type!"); } - visior.Visit(Model.Entry); - } - } - - public void CompileCode() - { - if (!File.Exists(SourcePath)) - throw new InvalidProgramException("The Source Code Path Is Invalid!"); - var compiler = new CSourceCompiler(); - _dllPath = compiler.Compile(SourcePath); - } - - /// - /// bind each IR.Funtion with C function - /// - /// - public void ExportCode() - { - if (!File.Exists(_dllPath)) - throw new InvalidProgramException("The DLL Path Is Invalid!"); - var dllPtr = NativeLibrary.Load(_dllPath); - foreach (var module in Model.Modules) - { - foreach (var f in module.Callables) - { - var funcType = f.ToDelegateType(Path.GetFileName(_dllPath)); - var funPtr = NativeLibrary.GetExport(dllPtr, f.Name); - _functions.Add(new CSourceRTFunction(f.Name, funPtr.BindDelegate(funcType))); - if (f == Model.Entry) { _entry = _functions.Last(); } - } - } - } - - /// - public ISerializeResult Serialize() - { - if (IsSerialized) { return new CSourceSerializeResult(); } - BuildCode(); - CompileCode(); - ExportCode(); - return new CSourceSerializeResult(); - } - - /// - /// invoke the module entry - /// - /// input args - /// results - /// - public object? Invoke(params object?[]? args) - { - if (Entry is null) - throw new InvalidOperationException("This RTModule Have No Entry Function!"); - return Entry.Handle.DynamicInvoke(args); - } - - public string Dump(string name, string DumpDirPath) - { - var dump_path = $"{DumpDirPath}/{name}.{SourceExt}"; - using var file = File.Open(dump_path, FileMode.OpenOrCreate, FileAccess.Write); - using var writer = new StreamWriter(file); - writer.Write(Source); - return dump_path; - } - -} - -/// -/// the csource code compiler. -/// -public class CSourceCompiler -{ - /// - /// compiler exe name - /// - string _exe = "", _arch = "", _ext = ""; - - /// - /// select current pattern's exe - /// - /// - void PlatformSpecific() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - _exe = "gcc"; - _ext = "so"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - _exe = "clang"; - _ext = "dylib"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - _exe = "cmd"; - _ext = "dll"; - } - } - - void ArchSpecific() - { - _arch = RuntimeInformation.OSArchitecture switch - { - Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", - Architecture.Arm64 => "arm64", - _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), - }; - } - - string ArgumentsSpecific(string sourcePath, string outPath) - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - return $"{sourcePath} -fPIC -shared -march={Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return $"{sourcePath} -fPIC -shared -arch {Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); - var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); - return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; - } - throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); - } - - protected string Exe - { - get => _exe; - } - - protected string Arch - { - get => _arch; - } - - protected string Ext - { - get => _ext; - } - - public CSourceCompiler() - { - PlatformSpecific(); - ArchSpecific(); - } - - /// - /// compile the source txt, write to the out_path - /// - /// c source code - /// out .so path - /// outPath - public string Compile(string sourcePath, string outPath) - { - var errMsg = new StringBuilder(); - using (var errWriter = new StringWriter(errMsg)) - { - using (var proc = new Process()) - { - proc.StartInfo.FileName = Exe; - proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); - proc.StartInfo.RedirectStandardError = true; - proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); - proc.Start(); - proc.BeginErrorReadLine(); - proc.WaitForExit(); - if (proc.ExitCode != 0) - { - throw new InvalidOperationException(errMsg.ToString()); - } - } - } - return outPath; - } - - /// - /// create the temp dll file and compile source - /// - /// - public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); -} \ No newline at end of file diff --git a/targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs b/targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs deleted file mode 100644 index 352b9dc6ea..0000000000 --- a/targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs +++ /dev/null @@ -1,317 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Nncase.IR; -using Nncase.Runtime; -using Nncase.TIR; - -namespace Nncase.CodeGen; - -/// -/// convert the type/op to c name -/// -internal static class NameConverter -{ - private static readonly Dictionary _primTypeToC = new() - { - { DataTypes.Boolean, "bool" }, - { DataTypes.Int8, "int8_t" }, - { DataTypes.Int16, "int16_t" }, - { DataTypes.Int32, "int32_t" }, - { DataTypes.Int64, "int64_t" }, - { DataTypes.UInt8, "uint8_t" }, - { DataTypes.UInt16, "uint16_t" }, - { DataTypes.UInt32, "uint32_t" }, - { DataTypes.UInt64, "uint64_t" }, - { DataTypes.Float32, "float" }, - { DataTypes.Float64, "double" }, - }; - - public static string toC(this PrimType primType) => - _primTypeToC[primType]; - - public static string toC(this DataType dataType) => dataType switch - { - PrimType ptype => ptype.toC(), - PointerType { ElemType: PrimType etype } => etype.toC() + "*", - _ => throw new NotSupportedException(dataType.ToString()) - }; -} - -/// -/// the c symbol define -/// -internal struct CSymbol -{ - public string Type; - public StringBuilder Doc; - public CSymbol(string type, StringBuilder doc) - { - Type = type; - Doc = doc; - } - public override string ToString() => $"{Type} {Doc}"; -} - -/// -/// collect the csymbol's parameter -/// -internal class CSymbolParamList : IParameterList, IEnumerable -{ - CSymbol[] Symbols; - public CSymbolParamList(CSymbol[] symbols) - { - Symbols = symbols; - } - - public CSymbol this[ParameterInfo parameter] => Symbols[parameter.Index]; - public CSymbol this[int index] => Symbols[index]; - - public IEnumerator GetEnumerator() - { - return ((IEnumerable)Symbols).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return Symbols.GetEnumerator(); - } -} - - -/// -/// visitor for the build c source code, the expr vistor return (type string , name string) -/// -internal class CSourceHostBuildVisior : ExprFunctor -{ - - /// - /// source writer . - /// TODO we need the decl writer - /// - readonly ScopeWriter Scope; - - /// - /// symbols name memo - /// - readonly Dictionary Symbols = new(ReferenceEqualityComparer.Instance); - - /// - /// - /// - /// - public CSourceHostBuildVisior(TextWriter textWriter) - { - Scope = new ScopeWriter(textWriter); - // insert some declare - Scope.IndWriteLine(@" -#ifdef _WIN32 -#define EXPORT_API __declspec(dllexport) -#else -#define EXPORT_API -#endif"); - Scope.IndWriteLine("#include "); - } - - /// - public override CSymbol Visit(Call expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - var target = Visit(expr.Target); - var args = new CSymbolParamList(expr.Parameters.Select(Visit).ToArray()); - var type = VisitType(expr.CheckedType!); - Scope.Push(); - switch (expr.Target) - { - case IR.Math.Binary: - Scope.Append($"({args[0].Doc} {target.Doc} {args[1].Doc})"); - break; - case Store: - Scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}] = {args[Store.Value].Doc}"); - break; - case Load: - Scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}]"); - break; - case IR.Tensors.Cast: - Scope.Append($"(({type}){args[IR.Tensors.Cast.Input].Doc})"); - break; - default: - Scope.Append($"{target.Doc}({string.Join(", ", args.Select(x => x.Doc))})"); - break; - } - symbol = new(type, Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Const expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - if (expr.CheckedType is TensorType ttype && ttype.IsScalar) - { - var literal = $"{expr}" switch - { - "True" => "1", - "False" => "0", - var x => x - }; - symbol = new(VisitType(ttype), new(literal)); - } - else - { - throw new NotSupportedException($"Not Support {expr.CheckedType} Const"); - } - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Function expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - var retType = VisitType(((CallableType)expr.CheckedType!).ReturnType); - Scope.Push(); - // 1. Function signature - Scope.IndWrite($"EXPORT_API {retType} {expr.Name}({string.Join(", ", expr.Parameters.Select(Visit))}) {{"); - // 2. Function body - using (Scope.IndentUp()) - { - Scope.Append(Visit(expr.Body).Doc); - } - // 3. Function closing - Scope.IndWrite("}"); - symbol = new(CallableTypeToPtr((CallableType)expr.CheckedType!, expr.Name), Scope.Pop()); - // 4. write whole code - Scope.IndWrite(symbol.Doc); - return symbol; - } - - /// - public override CSymbol Visit(Op expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - symbol = new("Invalid Op", new(expr switch - { - IR.Math.Binary op => op.BinaryOp switch - { - BinaryOp.Add => "+", - BinaryOp.Sub => "-", - BinaryOp.Mul => "*", - BinaryOp.Div => "/", - BinaryOp.Mod => "%", - _ => throw new ArgumentOutOfRangeException(op.BinaryOp.ToString()) - }, - TIR.Store op => "Store", - TIR.Load op => "Load", - IR.Tensors.Cast op => op.NewType.toC(), - _ => throw new NotSupportedException($"{expr.GetType().Name}") - })); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Var expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - var isymbol = Scope.GetUniqueVarSymbol(expr); - symbol = new(VisitType(expr.CheckedType!), isymbol.Span); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(For expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - Scope.Push(); - // 1. For Loop signature - var loopVar = Visit(expr.LoopVar); - Scope.Append($"for ({loopVar} = {Visit(expr.Dom.Start).Doc}; {loopVar.Doc} < {Visit(expr.Dom.Stop).Doc}; {loopVar.Doc}+={expr.Dom.Step}) {{"); - // 2. For Body - Scope.Append(Visit(expr.Body).Doc); - // 3. For closing - Scope.IndWrite("}"); - symbol = new(VisitType(expr.CheckedType!), Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Sequential expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - Scope.Push(); - Scope.AppendLine(""); - using (Scope.IndentUp()) - { - foreach (var i in Enumerable.Range(0, expr.Fields.Count)) - { - if (i == expr.Fields.Count - 1 && - expr.Fields[i].CheckedType is TensorType) - { - Scope.IndWrite("return "); - } - else - { - Scope.IndWrite(string.Empty); - } - Scope.Append(Visit(expr.Fields[i]).Doc); - if (expr.Fields[i] is Call) - { - Scope.AppendLine(";"); - } - else - { - Scope.AppendLine(string.Empty); - } - } - } - symbol = new(VisitType(expr.CheckedType!), Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(IfThenElse expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - Scope.Push(); - Scope.Append($"if({Visit(expr.Condition).Doc}) {{"); - Scope.Append(Visit(expr.Then).Doc); - Scope.IndWrite("} else {"); - Scope.Append(Visit(expr.Else).Doc); - Scope.IndWrite("}"); - symbol = new(VisitType(expr.CheckedType!), Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - /// - /// void (*fun_ptr)(int) - /// - public string CallableTypeToPtr(CallableType type, string name) => $"{VisitType(type.ReturnType)} (*{name}_ptr)({string.Join(",", type.Parameters.Select(VisitType))})"; - - - /// - public override string VisitType(TensorType type) - { - if (!type.IsScalar) - { - throw new NotSupportedException($"{type}"); - } - return type.DType.toC(); - } - - /// - public override string VisitType(TupleType type) => type == TupleType.Void ? - "void" : - throw new InvalidProgramException($"The C Source Must Not Have TupleType {type}!"); -} \ No newline at end of file diff --git a/targets/Nncase.Targets.CSource/CodeGen/Interop.cs b/targets/Nncase.Targets.CSource/CodeGen/Interop.cs deleted file mode 100644 index 33d31af269..0000000000 --- a/targets/Nncase.Targets.CSource/CodeGen/Interop.cs +++ /dev/null @@ -1,140 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using System.Reflection.Emit; -using System.Runtime.InteropServices; -using Nncase.IR; - -namespace Nncase.CodeGen; - -/// -/// -/// -internal class DynamicAssemble -{ - /// - /// name - /// - AssemblyName assemblyName; - /// - /// asm builder for whole module - /// - AssemblyBuilder asmBuilder; - /// - /// module buidler - /// - ModuleBuilder modBuilder; - /// - /// save the func name <=> func delegate type - /// - readonly Dictionary delegateTypes = new(); - - /// - /// a DynamicAssemble instance, it's contains one rtmodule's all functions defination. - /// - /// asmble name - public DynamicAssemble(string Name) - { - assemblyName = new AssemblyName(Name); - asmBuilder = AssemblyBuilder.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.RunAndCollect); - modBuilder = asmBuilder.DefineDynamicModule(assemblyName.Name!); - - } - - /// - /// - /// - /// - /// func delegate type - public Type BuildDelegateType(Callable function) - { - Type deleType; - if (function.CheckedType is CallableType ctype) - { - deleType = CreateDelegateType(function.Name, ctype.ReturnType.ToType(), ctype.Parameters.Select(Interop.ToType).ToArray()); - } - else { throw new NotSupportedException(function.CheckedType?.ToString()); } - return deleType; - } - - /// - /// dynamic create delegate type for function. - /// - /// - /// - /// - /// - /// - public Type CreateDelegateType(string funcName, Type returnType, params Type[]? ParamTypes) - { - if (!delegateTypes.TryGetValue(funcName, out var ret)) - { - ParamTypes ??= new Type[] { }; - TypeBuilder tb = modBuilder.DefineType(funcName, TypeAttributes.Public | TypeAttributes.Sealed, typeof(MulticastDelegate)); - tb.DefineConstructor(MethodAttributes.Public | MethodAttributes.HideBySig | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, CallingConventions.Standard | CallingConventions.HasThis, new[] { typeof(object), typeof(IntPtr) }).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - tb.DefineMethod("Invoke", MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, CallingConventions.Standard | CallingConventions.HasThis, returnType, ParamTypes).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - tb.DefineMethod("BeginInvoke", MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, CallingConventions.Standard | CallingConventions.HasThis, typeof(IAsyncResult), ParamTypes.Concat(new[] { typeof(IAsyncResult), typeof(object) }).ToArray()).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - tb.DefineMethod("EndInvoke", MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, CallingConventions.Standard | CallingConventions.HasThis, returnType, new[] { typeof(IAsyncResult) }).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - ret = tb.CreateType(); - if (ret is null) { throw new InvalidProgramException($"Can't Create The Func {funcName}'s delegate Type!"); } - delegateTypes.Add(funcName, ret); - } - return ret; - } -} - -/// -/// Interop helper -/// -public static class Interop -{ - /// - /// collect the all dynamic asmbs - /// - private static readonly Dictionary _definedAsms = new(); - - /// - /// convert the ir type to the system type - /// - /// - /// - /// - public static Type ToType(this IRType iRType) => iRType switch - { - TensorType { IsScalar: true, DType: PrimType { } primType } => primType.CLRType, - TensorType { IsScalar: true, DType: PointerType { ElemType: PrimType primType } } => primType.CLRType.MakeArrayType(), - TupleType ttype => (ttype == TupleType.Void) switch - { - true => typeof(void), - false => throw new NotSupportedException($"Can't Support the {ttype}!") - }, - _ => throw new NotSupportedException($"IRType is {iRType}!") - }; - - - /// - /// convrt function to delegate type - /// - /// input function - /// the dynamic lib name - /// - /// - public static Type ToDelegateType(this Callable function, string libName) - { - if (!_definedAsms.TryGetValue(libName, out var dyasm)) - { - dyasm = new DynamicAssemble(libName); - _definedAsms.Add(libName, dyasm); - } - return dyasm.BuildDelegateType(function); ; - } - - /// - /// bind the delegate to funcptr. - /// - /// - /// - /// - public static Delegate BindDelegate(this IntPtr funcPtr, Type funcType) => Marshal.GetDelegateForFunctionPointer(funcPtr, funcType); -} diff --git a/targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj b/targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj deleted file mode 100644 index 79226f6c03..0000000000 --- a/targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj +++ /dev/null @@ -1,16 +0,0 @@ - - - - net6.0 - enable - enable - $(SolutionDir)/tools/StyleCopAnalyzers.ruleset - - - - - - - - - diff --git a/targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs b/targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs deleted file mode 100644 index b57cf2419d..0000000000 --- a/targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Nncase.IR; - -namespace Nncase.Schedule; - -public class CSourceScheduler : IScheduler -{ - - public CSourceScheduler(IR.IRModule main_module, ITarget target) - { - Module = main_module; - Target = target; - } - - public ITarget Target { get; set; } - public IRModule Module { get; set; } - - - IRModel IScheduler.Schedule(bool skip_buffer_alias) - { - return new IRModel(new[] { Module }); - } -} \ No newline at end of file diff --git a/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs b/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs index 7495791550..8e2aa7cfa2 100644 --- a/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs +++ b/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs @@ -85,7 +85,7 @@ select Parameter(Identifier(f.Name.ToLower())) // var x = name_params[0]; statements.Add(ParseStatement(@$"return new( new OpPattern<{cand.Op.ToDisplayString()}>(x => {condition}, {(name_params[0] != null ? "target_name" : "null")}), -new VArgsPattern (new[]{{ {inputs} }}, null), +new VArgsPattern (new Pattern[]{{ {inputs} }}, null), {(name_params[1] != null ? "call_name" : "null")});"). WithLeadingTrivia(ElasticTab). WithTrailingTrivia(ElasticLineFeed)); @@ -125,7 +125,7 @@ select Parameter(Identifier(f.Name.ToLower())) // 1.3 build method return statements.Add(ParseStatement(@$"return new( new OpPattern<{cand.Op.ToDisplayString()}>(condition, {(name_params[0] != null ? "target_name" : "null")}), -new VArgsPattern( new [] {{ {inputs} }}, null ), +new VArgsPattern( new Pattern[] {{ {inputs} }}, null ), {(name_params[1] != null ? "call_name" : "null")});"). WithLeadingTrivia(ElasticTab). WithTrailingTrivia(ElasticLineFeed)); From 49c9baac44ff75f25c025188f40b71041f04d80b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Fri, 21 Jul 2023 18:05:55 +0800 Subject: [PATCH 007/308] fix csource entry --- .../CodeGen/CSourceBuiltn.cs | 39 +++++++++++++++++-- .../CodeGen/CSourceConvertVisitor.cs | 6 +-- .../CodeGen/CSourceUtilities.cs | 2 +- .../CodeGen/LinkableFunction.cs | 8 ++-- .../CodeGen/LinkableModule.cs | 18 +++++++++ 5 files changed, 63 insertions(+), 10 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index eda007050f..adb7195e5f 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -9,14 +9,47 @@ public static class CSourceBuiltn public const string BufferType = "buffer_t"; public const string BufferStruct = @"typedef struct buffer { - void *ptr; + void *vaddr; + size_t paddr; int *shape; int *stride; int rank; } buffer_t;"; - public const string Include = @"#include"; + public const string MethodTable = @"typedef struct nncase_method_table { + float (*float_unary_asin)(float); +} nncase_mt_t;"; + + public const string Include = @"#include +#include +"; + + public const string FixedParameters = "nncase_mt_t* nncase_mt, void* data, void* rdata"; + + public const string MainPrologue = $@"void _start(char* name, buffer_t** buffers, {FixedParameters}) {{"; + + public const string MainEpilogue = @"}"; + + public static string Header => $@" +{Include} + +{MethodTable} + +{BufferStruct} + +int strcmp(const char* s1,const char* s2) {{ + while(*s1 && *s2) {{ + if(*s1 != *s2) {{ + break; + }} + s1++; + s2++; + }} + return (*s1 - *s2) || (*s1 - '\0') || (*s2 - '\0'); +}} + +static nncase_mt_t *nncase_mt; +"; - public static string Header => Include + "\n" + BufferStruct; } \ No newline at end of file diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 71b8d5ee65..7df0c005dc 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -123,7 +123,7 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) throw new NotSupportedException("The PrimFunction must return void!"); } - var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b).ToString()).ToArray())})"; + var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})"; using (var scope = new IndentScope(_implBuilder)) { @@ -169,10 +169,10 @@ protected override CSymbol VisitCall(Call expr) str = CSourceUtilities.ContertUnary(op, arguments); break; case Store: - str = $"((({arguments[2].Type} *){arguments[0].Name}->ptr)[{arguments[1].Name}] = {arguments[2].Name})"; + str = $"((({arguments[2].Type} *){arguments[0].Name}->vaddr)[{arguments[1].Name}] = {arguments[2].Name})"; break; case Load: - str = $"((({type} *){arguments[0].Name}->ptr)[{arguments[1].Name}])"; + str = $"((({type} *){arguments[0].Name}->vaddr)[{arguments[1].Name}])"; break; default: throw new NotSupportedException(); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs index fbcf4e372d..be8d32cabd 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs @@ -31,7 +31,7 @@ internal static string ContertUnary(Unary op, CSymbol[] arguments) str = ($"!{input}"); break; default: - str = ($"nncase_mt->{op.UnaryOp.ToString()}{input}"); + str = ($"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower()}_{op.UnaryOp.ToString().ToLower()}{input}"); break; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index bd6dc4c270..4ad32bd353 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -7,21 +7,23 @@ namespace Nncase.CodeGen.CPU; internal sealed class LinkableFunction : ILinkableFunction { - private readonly byte[] _desc; - public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, FunctionCSource funcCSource) { Id = id; SourceFunction = sourceFunction; + PrimFunction = sourceFunction; FunctionCSource = funcCSource; Text = Array.Empty(); - Sections = new LinkedSection[] { }; + var desc = System.Text.Encoding.ASCII.GetBytes(sourceFunction.Name); + Sections = new LinkedSection[] { new(desc, ".desc", 0, 8, (uint)desc.Length) }; } public uint Id { get; } public BaseFunction SourceFunction { get; } + public TIR.PrimFunction PrimFunction { get; } + public FunctionCSource FunctionCSource { get; } public byte[] Text { get; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index 8c8f5952a3..350460e5a2 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -33,6 +33,13 @@ public ILinkedModule Link(ILinkContext linkContext) if (DumpScope.Current.IsEnabled(DumpFlags.CodeGen)) { + using (var fs = DumpScope.Current.OpenFile("cpuModule.h")) + { + using (var writer = new StreamWriter(fs)) + { + writer.Write(CSourceBuiltn.Header); + } + } using (var fs = DumpScope.Current.OpenFile("cpuModule.c")) { File.Open(csourcePath, FileMode.Open, FileAccess.Read).CopyTo(fs); @@ -56,6 +63,7 @@ private string LinkCSources() using (var writer = new StreamWriter(fs)) { writer.WriteLine(CSourceBuiltn.Header); + foreach (var func in _functions) { writer.WriteLine(func.FunctionCSource.Declaration); @@ -65,6 +73,16 @@ private string LinkCSources() { writer.WriteLine(func.FunctionCSource.Implementation); } + + writer.WriteLine(CSourceBuiltn.MainPrologue); + foreach (var func in _functions) + { + writer.WriteLine($" if (strcmp(name,\"{func.SourceFunction.Name}\") == 0) {{"); + writer.WriteLine($" {func.SourceFunction.Name}({string.Join(",", Enumerable.Range(0, func.PrimFunction.Parameters.Length).Select(i => $"buffers[{i}]"))}, nncase_mt, data, rdata);"); + writer.WriteLine(" } else"); + } + writer.WriteLine(" { }"); + writer.WriteLine(CSourceBuiltn.MainEpilogue); } } From 33673656064d62b541ffbc6e5e834c33fe63b858 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Fri, 21 Jul 2023 19:29:40 +0800 Subject: [PATCH 008/308] add cpu target compile --- CMakeLists.txt | 2 +- modules/cpu/CMakeLists.txt | 56 +++++++++++++++ .../nncase_rt_modules_cpuConfig.cmake.in | 1 + .../nncase/runtime/cpu/compiler_defs.h | 28 ++++++++ .../nncase/runtime/cpu/runtime_module.h | 29 ++++++++ modules/cpu/src/runtime/CMakeLists.txt | 18 +++++ modules/cpu/src/runtime/runtime_function.cpp | 52 ++++++++++++++ modules/cpu/src/runtime/runtime_function.h | 37 ++++++++++ modules/cpu/src/runtime/runtime_module.cpp | 70 +++++++++++++++++++ modules/cpu/src/runtime/runtime_module.h | 38 ++++++++++ 10 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 modules/cpu/CMakeLists.txt create mode 100644 modules/cpu/cmake/nncase_rt_modules_cpuConfig.cmake.in create mode 100644 modules/cpu/include/nncase/runtime/cpu/compiler_defs.h create mode 100644 modules/cpu/include/nncase/runtime/cpu/runtime_module.h create mode 100644 modules/cpu/src/runtime/CMakeLists.txt create mode 100644 modules/cpu/src/runtime/runtime_function.cpp create mode 100644 modules/cpu/src/runtime/runtime_function.h create mode 100644 modules/cpu/src/runtime/runtime_module.cpp create mode 100644 modules/cpu/src/runtime/runtime_module.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 83e3f71db3..3ac570f4c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,5 +265,5 @@ if(BUILD_TESTING) endif() # Modules -#add_subdirectory(modules/k210) +add_subdirectory(modules/cpu) #add_subdirectory(modules/vulkan) diff --git a/modules/cpu/CMakeLists.txt b/modules/cpu/CMakeLists.txt new file mode 100644 index 0000000000..2d7a213559 --- /dev/null +++ b/modules/cpu/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required (VERSION 3.8) + +include_directories(include) +#add_subdirectory(src/kernels) +add_subdirectory(src/runtime) + +if (BUILDING_RUNTIME) + if (ENABLE_VULKAN_RUNTIME) + add_library(nncase_rt_modules_cpu STATIC ${SRCS}) + target_include_directories(nncase_rt_modules_cpu PRIVATE include) + target_link_libraries(nncase_rt_modules_cpu PRIVATE runtime_cpu) + set_target_properties(nncase_rt_modules_cpu PROPERTIES + OUTPUT_NAME "nncase.rt_modules.cpu") + + install(DIRECTORY include/nncase/kernels + DESTINATION include/nncase + COMPONENT nncase-headers + FILES_MATCHING + PATTERN "*.def" + PATTERN "*.h" + PATTERN "*.hpp" + PATTERN "*.td" + PATTERN "*.inc" + PATTERN "LICENSE.TXT" + ) + + install(DIRECTORY include/nncase/runtime + DESTINATION include/nncase + COMPONENT nncase-headers + FILES_MATCHING + PATTERN "*.def" + PATTERN "*.h" + PATTERN "*.hpp" + PATTERN "*.td" + PATTERN "*.inc" + PATTERN "LICENSE.TXT" + ) + + install(TARGETS nncase_rt_modules_cpu EXPORT nncaseruntimeTargets + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include + ) + + configure_file(${CMAKE_CURRENT_LIST_DIR}/cmake/nncase_rt_modules_cpuConfig.cmake.in nncase_rt_modules_cpuConfig.cmake @ONLY) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/nncase_rt_modules_cpuConfig.cmake DESTINATION lib/cmake/nncaseruntime) + endif() +else() + add_library(nncase_modules_cpu SHARED ${SRCS}) + target_include_directories(nncase_modules_cpu PUBLIC include) + target_link_libraries(nncase_modules_cpu PRIVATE simulator_cpu nncasebase) + set_target_properties(nncase_modules_cpu PROPERTIES OUTPUT_NAME "nncase.modules.cpu") + install(TARGETS nncase_modules_cpu + COMPONENT nncase-runtime) +endif() diff --git a/modules/cpu/cmake/nncase_rt_modules_cpuConfig.cmake.in b/modules/cpu/cmake/nncase_rt_modules_cpuConfig.cmake.in new file mode 100644 index 0000000000..497d8eb3b3 --- /dev/null +++ b/modules/cpu/cmake/nncase_rt_modules_cpuConfig.cmake.in @@ -0,0 +1 @@ +include(${CMAKE_CURRENT_LIST_DIR}/nncase_rt_modules_cpuTargets.cmake) \ No newline at end of file diff --git a/modules/cpu/include/nncase/runtime/cpu/compiler_defs.h b/modules/cpu/include/nncase/runtime/cpu/compiler_defs.h new file mode 100644 index 0000000000..fa7dcceef0 --- /dev/null +++ b/modules/cpu/include/nncase/runtime/cpu/compiler_defs.h @@ -0,0 +1,28 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + +#if defined(_MSC_VER) +#ifdef NNCASE_MODULES_CPU_DLL +#define NNCASE_MODULES_CPU_API __declspec(dllexport) +#elif NNCASE_SHARED_LIBS +#define NNCASE_MODULES_CPU_API __declspec(dllimport) +#else +#define NNCASE_MODULES_CPU_API +#endif +#else +#define NNCASE_MODULES_CPU_API __attribute__((visibility("default"))) +#endif diff --git a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h new file mode 100644 index 0000000000..5aecaa4aa7 --- /dev/null +++ b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h @@ -0,0 +1,29 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "compiler_defs.h" +#include + +BEGIN_NS_NNCASE_RT_MODULE(cpu) + +NNCASE_INLINE_VAR constexpr char DESCRIPTORS_SECTION_NAME[] = ".desc"; + +NNCASE_INLINE_VAR constexpr module_kind_t cpu_module_type = to_module_kind("cpu"); +NNCASE_INLINE_VAR constexpr uint32_t cpu_module_version = 0; + +NNCASE_MODULES_CPU_API result> +create_cpu_runtime_module(); + +END_NS_NNCASE_RT_MODULE diff --git a/modules/cpu/src/runtime/CMakeLists.txt b/modules/cpu/src/runtime/CMakeLists.txt new file mode 100644 index 0000000000..fe8fc9236a --- /dev/null +++ b/modules/cpu/src/runtime/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required (VERSION 3.13) + +set(SRCS runtime_module.cpp + runtime_function.cpp) + +if (BUILDING_RUNTIME) + if (ENABLE_CPU_RUNTIME) + add_library(runtime_cpu OBJECT ${SRCS}) + target_link_libraries(runtime_cpu PUBLIC nncaseruntime) + set_target_properties(runtime_cpu PROPERTIES POSITION_INDEPENDENT_CODE ON) + install(TARGETS runtime_cpu EXPORT nncaseruntimeTargets) + endif() +else() + add_library(simulator_cpu OBJECT ${SRCS}) + target_link_libraries(simulator_cpu PUBLIC nncasebase nncaseruntime) + target_compile_definitions(simulator_cpu PUBLIC -DNNCASE_MODULES_CPU_DLL -DNNCASE_SIMULATOR) + set_target_properties(simulator_cpu PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp new file mode 100644 index 0000000000..cdeec6a4ba --- /dev/null +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -0,0 +1,52 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime_function.h" +#include +#include +#include + +using namespace nncase; +using namespace nncase::runtime; +using namespace nncase::runtime::cpu; + +cpu_runtime_module &cpu_runtime_function::module() const noexcept { + return static_cast(runtime_function::module()); +} + +result cpu_runtime_function::initialize_core( + runtime_function_init_context &context) noexcept { + text_ = context.module_init_context().section(".text").subspan( + context.header().entrypoint, context.header().text_size); + + return ok(); +} + +result +cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, + value_t return_value) noexcept { + // try_(preprocess_inputs()); + + // vk::SubmitInfo si({}, {}, cmd_buffer_, {}); + // try_(vk::to_result(module().compute_queue().submit(si))); + // try_(vk::to_result(module().compute_queue().waitIdle())); + // try_(vk::to_result(module().device().waitIdle())); + + // assert(buffer_refs_.empty()); + // assert(buffer_copies_.empty()); + // assert(buffer_barriers_.empty()); + + // try_(postprocess_outputs()); + return ok(return_value); +} \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h new file mode 100644 index 0000000000..3504a46695 --- /dev/null +++ b/modules/cpu/src/runtime/runtime_function.h @@ -0,0 +1,37 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "runtime_module.h" +#include +#include + +BEGIN_NS_NNCASE_RT_MODULE(cpu) + +class cpu_runtime_function : public runtime_function { + public: + using runtime_function::runtime_function; + + cpu_runtime_module &module() const noexcept; + + protected: + result initialize_core(runtime_function_init_context &context) noexcept override; + result invoke_core(gsl::span parameters, + value_t return_value) noexcept override; + + private: + gsl::span text_; +}; + +END_NS_NNCASE_RT_MODULE diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp new file mode 100644 index 0000000000..8c30911df9 --- /dev/null +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -0,0 +1,70 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime_module.h" +#include "runtime_function.h" +#include +#include +#include + +using namespace nncase; +using namespace nncase::runtime; +using namespace nncase::runtime::cpu; + +result cpu_runtime_module::initialize_before_functions( + runtime_module_init_context &context) noexcept { + if (!context.is_section_pinned()) + return nncase::err(std::errc::bad_address); + // try_var(rdata_, + // context.get_or_read_section(".rdata", rdata_storage_, true)); + // try_var(text_, context.get_or_read_section(".text", text_storage_, true)); + + // auto descs = + // context.section(DESCRIPTORS_SECTION_NAME).as_span(); + // descriptor_sets_ = descs[0]; + // descriptors_ = descs[1]; + // shader_ = context.section(".shader"); + + // rdata_ = context.section(".rdata"); + return ok(); +} + +result> +cpu_runtime_module::create_function() noexcept { + std::unique_ptr mod(new (std::nothrow) + cpu_runtime_function(*this)); + if (mod) + return ok(std::move(mod)); + return err(std::errc::not_enough_memory); +} + +result> cpu::create_cpu_runtime_module() { + std::unique_ptr mod(new (std::nothrow) + cpu_runtime_module()); + if (mod) + return ok(std::move(mod)); + return err(std::errc::not_enough_memory); +} + +extern "C" { +NNCASE_MODULES_CPU_API void +RUNTIME_MODULE_ACTIVATOR_NAME(result> &result) { + result = create_cpu_runtime_module(); +} +} + +#ifndef NNCASE_SIMULATOR +runtime_registration nncase::runtime::builtin_runtimes[] = { + {cpu_module_type, RUNTIME_MODULE_ACTIVATOR_NAME}, {}}; +#endif diff --git a/modules/cpu/src/runtime/runtime_module.h b/modules/cpu/src/runtime/runtime_module.h new file mode 100644 index 0000000000..d179f9fd40 --- /dev/null +++ b/modules/cpu/src/runtime/runtime_module.h @@ -0,0 +1,38 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + +BEGIN_NS_NNCASE_RT_MODULE(cpu) + +class cpu_runtime_module : public runtime_module { + public: + gsl::span rdata_physical() noexcept { return rdata_; } + gsl::span data_physical() noexcept { return data_; } + gsl::span text_physical() noexcept { return text_; } + + protected: + result initialize_before_functions( + runtime_module_init_context &context) noexcept override; + result> + create_function() noexcept override; + + private: + gsl::span data_; + gsl::span text_; + gsl::span rdata_; +}; + +END_NS_NNCASE_RT_MODULE From c95026167a7a9e0ba2ee47355fff9308b387e774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Fri, 21 Jul 2023 19:38:20 +0800 Subject: [PATCH 009/308] fix build --- modules/cpu/src/runtime/runtime_function.cpp | 4 +--- modules/cpu/src/runtime/runtime_function.h | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index cdeec6a4ba..1ed0c0a3ed 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -26,9 +26,7 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { } result cpu_runtime_function::initialize_core( - runtime_function_init_context &context) noexcept { - text_ = context.module_init_context().section(".text").subspan( - context.header().entrypoint, context.header().text_size); + NNCASE_UNUSED runtime_function_init_context &context) noexcept { return ok(); } diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h index 3504a46695..b7554144b5 100644 --- a/modules/cpu/src/runtime/runtime_function.h +++ b/modules/cpu/src/runtime/runtime_function.h @@ -31,7 +31,7 @@ class cpu_runtime_function : public runtime_function { value_t return_value) noexcept override; private: - gsl::span text_; + std::string name_; }; END_NS_NNCASE_RT_MODULE From 138527d0342f7ede01a78762a4fb6fe981717a18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Fri, 21 Jul 2023 19:58:46 +0800 Subject: [PATCH 010/308] add cpu module --- .../CodeGen/LinkableFunction.cs | 4 ++-- modules/cpu/CMakeLists.txt | 4 ++-- .../include/nncase/runtime/cpu/runtime_module.h | 2 +- modules/cpu/src/runtime/runtime_function.cpp | 16 +++------------- modules/cpu/src/runtime/runtime_module.cpp | 13 +++---------- modules/cpu/src/runtime/runtime_module.h | 3 +++ .../Targets/UnitTestCPUTargetTiling.cs | 6 +++++- 7 files changed, 19 insertions(+), 29 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index 4ad32bd353..ae49c39b3c 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -14,8 +14,8 @@ public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, FunctionCSourc PrimFunction = sourceFunction; FunctionCSource = funcCSource; Text = Array.Empty(); - var desc = System.Text.Encoding.ASCII.GetBytes(sourceFunction.Name); - Sections = new LinkedSection[] { new(desc, ".desc", 0, 8, (uint)desc.Length) }; + var name = System.Text.Encoding.ASCII.GetBytes(sourceFunction.Name); + Sections = new LinkedSection[] { new(name, ".name", 0, 8, (uint)name.Length) }; } public uint Id { get; } diff --git a/modules/cpu/CMakeLists.txt b/modules/cpu/CMakeLists.txt index 2d7a213559..f030315513 100644 --- a/modules/cpu/CMakeLists.txt +++ b/modules/cpu/CMakeLists.txt @@ -5,7 +5,7 @@ include_directories(include) add_subdirectory(src/runtime) if (BUILDING_RUNTIME) - if (ENABLE_VULKAN_RUNTIME) + if (ENABLE_CPU_RUNTIME) add_library(nncase_rt_modules_cpu STATIC ${SRCS}) target_include_directories(nncase_rt_modules_cpu PRIVATE include) target_link_libraries(nncase_rt_modules_cpu PRIVATE runtime_cpu) @@ -50,7 +50,7 @@ else() add_library(nncase_modules_cpu SHARED ${SRCS}) target_include_directories(nncase_modules_cpu PUBLIC include) target_link_libraries(nncase_modules_cpu PRIVATE simulator_cpu nncasebase) - set_target_properties(nncase_modules_cpu PROPERTIES OUTPUT_NAME "nncase.modules.cpu") + set_target_properties(nncase_modules_cpu PROPERTIES OUTPUT_NAME "nncase.simulator.cpu") install(TARGETS nncase_modules_cpu COMPONENT nncase-runtime) endif() diff --git a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h index 5aecaa4aa7..4642409eb2 100644 --- a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h +++ b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h @@ -18,7 +18,7 @@ BEGIN_NS_NNCASE_RT_MODULE(cpu) -NNCASE_INLINE_VAR constexpr char DESCRIPTORS_SECTION_NAME[] = ".desc"; +NNCASE_INLINE_VAR constexpr char FUNCTION_NAME_SECTION_IDENTIFIER[] = ".name"; NNCASE_INLINE_VAR constexpr module_kind_t cpu_module_type = to_module_kind("cpu"); NNCASE_INLINE_VAR constexpr uint32_t cpu_module_version = 0; diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 1ed0c0a3ed..a268b68403 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -27,24 +27,14 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { - + try_var(name, context.section(FUNCTION_NAME_SECTION_IDENTIFIER)); + name_ = std::string(name.as_span().data()); return ok(); } result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, value_t return_value) noexcept { - // try_(preprocess_inputs()); - - // vk::SubmitInfo si({}, {}, cmd_buffer_, {}); - // try_(vk::to_result(module().compute_queue().submit(si))); - // try_(vk::to_result(module().compute_queue().waitIdle())); - // try_(vk::to_result(module().device().waitIdle())); - - // assert(buffer_refs_.empty()); - // assert(buffer_copies_.empty()); - // assert(buffer_barriers_.empty()); - - // try_(postprocess_outputs()); + std::cout << "call " << name_ << std::endl; return ok(return_value); } \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 8c30911df9..6e94bcf699 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -26,17 +26,10 @@ result cpu_runtime_module::initialize_before_functions( runtime_module_init_context &context) noexcept { if (!context.is_section_pinned()) return nncase::err(std::errc::bad_address); - // try_var(rdata_, - // context.get_or_read_section(".rdata", rdata_storage_, true)); - // try_var(text_, context.get_or_read_section(".text", text_storage_, true)); + try_var(data_, context.get_or_read_section(".rdata", data_storage_, false)); + try_var(rdata_,context.get_or_read_section(".rdata", rdata_storage_, true)); + try_var(text_, context.get_or_read_section(".text", text_storage_, true)); - // auto descs = - // context.section(DESCRIPTORS_SECTION_NAME).as_span(); - // descriptor_sets_ = descs[0]; - // descriptors_ = descs[1]; - // shader_ = context.section(".shader"); - - // rdata_ = context.section(".rdata"); return ok(); } diff --git a/modules/cpu/src/runtime/runtime_module.h b/modules/cpu/src/runtime/runtime_module.h index d179f9fd40..f672aeeab2 100644 --- a/modules/cpu/src/runtime/runtime_module.h +++ b/modules/cpu/src/runtime/runtime_module.h @@ -33,6 +33,9 @@ class cpu_runtime_module : public runtime_module { gsl::span data_; gsl::span text_; gsl::span rdata_; + host_buffer_t text_storage_; + host_buffer_t rdata_storage_; + host_buffer_t data_storage_; }; END_NS_NNCASE_RT_MODULE diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs index 28ff7ccf07..376376f888 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -43,9 +43,13 @@ public async Task TestCpuUnary() var compiler = CompileSession.Compiler; compiler.ImportIRModule(module); await compiler.CompileAsync(); - using (var fs = new MemoryStream()) + using (var fs = Dumpper.OpenFile("test.kmodel")) { compiler.Gencode(fs); } + using (var fs = Dumpper.OpenFile("input_0.bin")) + { + fs.Write(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor().BytesBuffer); + } } } From 941952f2235af8dbe66a934077f167ed70686ea1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Fri, 21 Jul 2023 20:08:24 +0800 Subject: [PATCH 011/308] fix cpu module init --- modules/cpu/src/runtime/runtime_module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 6e94bcf699..c8dad680fd 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -26,7 +26,7 @@ result cpu_runtime_module::initialize_before_functions( runtime_module_init_context &context) noexcept { if (!context.is_section_pinned()) return nncase::err(std::errc::bad_address); - try_var(data_, context.get_or_read_section(".rdata", data_storage_, false)); + try_var(data_, context.get_or_read_section(".data", data_storage_, false)); try_var(rdata_,context.get_or_read_section(".rdata", rdata_storage_, true)); try_var(text_, context.get_or_read_section(".text", text_storage_, true)); From 56efb3b6b599acb708d6008aa11c17e760aad375 Mon Sep 17 00:00:00 2001 From: xhuohai Date: Mon, 24 Jul 2023 02:20:46 +0000 Subject: [PATCH 012/308] Apply code-format changes --- .../CodeGen/CSourceBuiltn.cs | 7 +- .../CodeGen/CSourceConvertVisitor.cs | 56 +++++++------ .../CodeGen/CSourceExtensions.cs | 7 +- .../CodeGen/CSourceUtilities.cs | 12 ++- .../CodeGen/FunctionCSource.cs | 83 ++++++++++--------- .../CodeGen/LinkableModule.cs | 3 +- .../Passes/CPUFusionToTirPass.cs | 2 +- .../Passes/Tile/CPUFusionGroupMutator.cs | 1 - .../Passes/Tile/SingleCPUFusionConverter.cs | 14 ++-- .../nncase/runtime/cpu/runtime_module.h | 3 +- modules/cpu/src/runtime/runtime_function.h | 5 +- modules/cpu/src/runtime/runtime_module.cpp | 3 +- src/Nncase.Core/Schedule/ScheduleTypes.cs | 1 - src/Nncase.Core/TIR/MemSpan.cs | 13 +-- src/Nncase.Evaluator/TIR/Store.cs | 1 - src/Nncase.Tests/Core/UnitTestTIR.cs | 4 +- .../Targets/UnitTestCPUTargetTiling.cs | 1 + 17 files changed, 120 insertions(+), 96 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index adb7195e5f..45b5563f5f 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -1,11 +1,10 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. namespace Nncase.CodeGen.CPU; public static class CSourceBuiltn { - public const string BufferType = "buffer_t"; public const string BufferStruct = @"typedef struct buffer { @@ -50,6 +49,4 @@ int strcmp(const char* s1,const char* s2) {{ static nncase_mt_t *nncase_mt; "; - - -} \ No newline at end of file +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 7df0c005dc..ab4397910d 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; @@ -16,23 +16,6 @@ namespace Nncase.CodeGen.CPU; -/// -/// the c symbol define. -/// -internal sealed class CSymbol -{ - public CSymbol(string type, string name) - { - Type = type; - Name = name; - } - - public string Type { get; } - public string Name { get; } - - public override string ToString() => $"{Type} {Name}"; -} - internal struct IndentScope : IDisposable { private static readonly AsyncLocal _writer = new AsyncLocal(); @@ -71,12 +54,30 @@ public void Dispose() } } +/// +/// the c symbol define. +/// +internal sealed class CSymbol +{ + public CSymbol(string type, string name) + { + Type = type; + Name = name; + } + + public string Type { get; } + + public string Name { get; } + + public override string ToString() => $"{Type} {Name}"; +} internal sealed class IndentWriter : StringWriter { public int Indent; - public IndentWriter(StringBuilder sb, int indent = 0) : base(sb) + public IndentWriter(StringBuilder sb, int indent = 0) + : base(sb) { Indent = indent; } @@ -85,19 +86,20 @@ public void IndWrite(string? value) { for (int i = 0; i < Indent; i++) { - this.Write(' '); + Write(' '); } - this.Write(value); + + Write(value); } } /// -/// convert single prim function to c source +/// convert single prim function to c source. /// internal sealed class CSourceConvertVisitor : ExprFunctor { - private readonly StringBuilder _implBuilder; public readonly Dictionary ExprMemo; + private readonly StringBuilder _implBuilder; public CSourceConvertVisitor() { @@ -129,11 +131,13 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) { // 1. Function signature IndentScope.Writer.IndWrite($"{type} {{\n"); + // 2. Function body using (var _ = new IndentScope()) { Visit(expr.Body); } + // 3. Function closing IndentScope.Writer.IndWrite("}\n"); } @@ -154,9 +158,9 @@ protected override CSymbol VisitCall(Call expr) var arguments = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); string type = expr.CheckedType switch { - TupleType x when x == TupleType.Void => "", + TupleType x when x == TupleType.Void => string.Empty, TensorType { IsScalar: true } x => x.DType.ToC(), - _ => throw new NotSupportedException() + _ => throw new NotSupportedException(), }; string str; @@ -190,6 +194,7 @@ protected override CSymbol VisitConst(Const expr) { return symbol; } + string type; string str; if (expr is TensorConst { Value: Tensor { ElementType: PrimType ptype, Shape: { IsScalar: true } } scalar }) @@ -248,6 +253,7 @@ protected override CSymbol VisitFor(For expr) // 2. For Body Visit(expr.Body); } + // 3. For closing IndentScope.Writer.IndWrite("}\n"); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs index 63618df79e..e5b48ec442 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs @@ -1,7 +1,6 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. - namespace Nncase.CodeGen.CPU; /// @@ -40,6 +39,6 @@ public static string ToC(this PrimType primType) => BinaryOp.Sub => "-", BinaryOp.Mul => "*", BinaryOp.Div => "/", - _ => throw new NotSupportedException(binaryOp.ToString()) + _ => throw new NotSupportedException(binaryOp.ToString()), }; -} \ No newline at end of file +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs index be8d32cabd..7026825fa2 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs @@ -1,5 +1,9 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + using Nncase.Diagnostics; using Nncase.IR.Math; + namespace Nncase.CodeGen.CPU; internal static class CSourceUtilities @@ -12,7 +16,7 @@ public static string ContertBinary(Binary binary, CSymbol[] arguments) switch (binary.BinaryOp) { case BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div: - str = ($"({lhs} {binary.BinaryOp.ToC()} {rhs})"); + str = $"({lhs} {binary.BinaryOp.ToC()} {rhs})"; break; default: throw new NotSupportedException(); @@ -28,13 +32,13 @@ internal static string ContertUnary(Unary op, CSymbol[] arguments) switch (op.UnaryOp) { case UnaryOp.Neg: - str = ($"!{input}"); + str = $"!{input}"; break; default: - str = ($"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower()}_{op.UnaryOp.ToString().ToLower()}{input}"); + str = $"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower()}_{op.UnaryOp.ToString().ToLower()}{input}"; break; } return str; } -} \ No newline at end of file +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index c3a1e15066..b5f7ad2c86 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -1,7 +1,6 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. - using System; using System.Collections; using System.Collections.Generic; @@ -16,34 +15,40 @@ namespace Nncase.CodeGen; -internal sealed class FunctionCSource -{ - public FunctionCSource(string declaration, string implementation) - { - Declaration = declaration; - Implementation = implementation; - } - - public string Declaration { get; } - public string Implementation { get; } -} - - /// /// the csource code compiler. /// public class CSourceCompiler { /// - /// compiler exe name + /// compiler exe name. /// - string _exe = "", _arch = "", _ext = ""; + private string _exe = string.Empty; + /// + /// compiler exe name. + /// + private string _arch = string.Empty; + /// + /// compiler exe name. + /// + private string _ext = string.Empty; + + public CSourceCompiler() + { + PlatformSpecific(); + ArchSpecific(); + } + + protected string Exe + { + get => _exe; + } /// - /// select current pattern's exe + /// select current pattern's exe. /// /// - void PlatformSpecific() + private void PlatformSpecific() { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { @@ -62,7 +67,7 @@ void PlatformSpecific() } } - void ArchSpecific() + private void ArchSpecific() { _arch = RuntimeInformation.OSArchitecture switch { @@ -72,7 +77,7 @@ void ArchSpecific() }; } - string ArgumentsSpecific(string sourcePath, string outPath) + private string ArgumentsSpecific(string sourcePath, string outPath) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { @@ -88,12 +93,8 @@ string ArgumentsSpecific(string sourcePath, string outPath) var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; } - throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); - } - protected string Exe - { - get => _exe; + throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); } protected string Arch @@ -106,18 +107,12 @@ protected string Ext get => _ext; } - public CSourceCompiler() - { - PlatformSpecific(); - ArchSpecific(); - } - /// - /// compile the source txt, write to the out_path + /// compile the source txt, write to the out_path. /// - /// c source code - /// out .so path - /// outPath + /// c source code. + /// out .so path. + /// outPath. public string Compile(string sourcePath, string outPath) { var errMsg = new StringBuilder(); @@ -138,12 +133,26 @@ public string Compile(string sourcePath, string outPath) } } } + return outPath; } /// /// create the temp dll file and compile source - /// + /// . /// public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); } + +internal sealed class FunctionCSource +{ + public FunctionCSource(string declaration, string implementation) + { + Declaration = declaration; + Implementation = implementation; + } + + public string Declaration { get; } + + public string Implementation { get; } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index 350460e5a2..13ec1b0779 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -40,6 +40,7 @@ public ILinkedModule Link(ILinkContext linkContext) writer.Write(CSourceBuiltn.Header); } } + using (var fs = DumpScope.Current.OpenFile("cpuModule.c")) { File.Open(csourcePath, FileMode.Open, FileAccess.Read).CopyTo(fs); @@ -81,6 +82,7 @@ private string LinkCSources() writer.WriteLine($" {func.SourceFunction.Name}({string.Join(",", Enumerable.Range(0, func.PrimFunction.Parameters.Length).Select(i => $"buffers[{i}]"))}, nncase_mt, data, rdata);"); writer.WriteLine(" } else"); } + writer.WriteLine(" { }"); writer.WriteLine(CSourceBuiltn.MainEpilogue); } @@ -94,5 +96,4 @@ private string CompileCSource(string sourcePath) var compiler = new CSourceCompiler(); return compiler.Compile(sourcePath, Path.GetTempFileName()); } - } diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs index 4afc56cb6f..293fd6a627 100644 --- a/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs +++ b/modules/Nncase.Modules.CPU/Passes/CPUFusionToTirPass.cs @@ -43,7 +43,7 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o var post = (Function)rewriter.Rewrite( func, - new Mutators.IMergeRewriteRule[] { new CPUSameInputFusionMergeRule() }, + new Mutators.IMergeRewriteRule[] { new CPUSameInputFusionMergeRule() }, (rule, option) => new CPUFusionGroupMutator(fusionCheckCache, _tileOptions, rule, option), new() { AnalysisResults = analysis, MatchOptions = new Mutators.FusionGroupMutator.GroupedMatchOptions() }); diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs index d11af85bb5..f1f9daab07 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs @@ -55,7 +55,6 @@ public CPUFusionGroupMutator( /// public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet candidateFusions) { - // var checker = (IFusionChecker)Activator.CreateInstance(typeof(T), new object[] { _tileOptions })!; // var ret = checker.Check(mergedFusion, PassOptions); // if (ret) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index a2c932eb92..502b22c321 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -46,7 +46,7 @@ public ConvertVisitor(List mainBody) _mainBody = mainBody; } - public Fusion VisitRootFusion => (Fusion)(VisitRoot!); + public Fusion VisitRootFusion => (Fusion)VisitRoot!; public IEnumerable OutputBuffers => _buffersMap.Values.OfType().Where(b => b.MemLocation == MemoryLocation.Output); @@ -54,7 +54,7 @@ public ConvertVisitor(List mainBody) protected override Unit DefaultVisitLeaf(Expr expr) { - return new(); + return default(Unit); } protected override Unit VisitLeafCall(Call expr) @@ -71,15 +71,16 @@ protected override Unit VisitLeafCall(Call expr) default: throw new NotSupportedException(); } - return new(); + + return default(Unit); } private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer ret) { var input = arguments[Unary.Input.Index]; var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); - var input_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + input.Strides[i] * loops[i].Item2); - var output_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + ret.Strides[i] * loops[i].Item2); + var input_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * loops[i].loopVar)); + var output_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (ret.Strides[i] * loops[i].loopVar)); Expr stmt = T.Store(ret, output_index, IR.F.Math.Unary(unary.UnaryOp, T.Load(input, output_index))); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); @@ -101,6 +102,7 @@ private TIR.Buffer TryAllocateBuffer(Expr expr) { buffer = T.Buffer(c.CheckedDataType, MemoryLocation.Data, c.CheckedShape.ToValueArray().Select(i => (Expr)i).ToArray(), out _, name); } + break; case Var v: buffer = T.PhysicalBuffer(v.CheckedDataType, MemoryLocation.Input, v.CheckedShape.ToValueArray(), out _, name); @@ -111,8 +113,10 @@ private TIR.Buffer TryAllocateBuffer(Expr expr) default: throw new NotSupportedException(); } + _buffersMap.Add(expr, buffer); } + return buffer; } } diff --git a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h index 4642409eb2..9c071f41b6 100644 --- a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h +++ b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h @@ -20,7 +20,8 @@ BEGIN_NS_NNCASE_RT_MODULE(cpu) NNCASE_INLINE_VAR constexpr char FUNCTION_NAME_SECTION_IDENTIFIER[] = ".name"; -NNCASE_INLINE_VAR constexpr module_kind_t cpu_module_type = to_module_kind("cpu"); +NNCASE_INLINE_VAR constexpr module_kind_t cpu_module_type = + to_module_kind("cpu"); NNCASE_INLINE_VAR constexpr uint32_t cpu_module_version = 0; NNCASE_MODULES_CPU_API result> diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h index b7554144b5..e6a2580c7f 100644 --- a/modules/cpu/src/runtime/runtime_function.h +++ b/modules/cpu/src/runtime/runtime_function.h @@ -26,9 +26,10 @@ class cpu_runtime_function : public runtime_function { cpu_runtime_module &module() const noexcept; protected: - result initialize_core(runtime_function_init_context &context) noexcept override; + result + initialize_core(runtime_function_init_context &context) noexcept override; result invoke_core(gsl::span parameters, - value_t return_value) noexcept override; + value_t return_value) noexcept override; private: std::string name_; diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index c8dad680fd..bf4481d159 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -27,7 +27,8 @@ result cpu_runtime_module::initialize_before_functions( if (!context.is_section_pinned()) return nncase::err(std::errc::bad_address); try_var(data_, context.get_or_read_section(".data", data_storage_, false)); - try_var(rdata_,context.get_or_read_section(".rdata", rdata_storage_, true)); + try_var(rdata_, + context.get_or_read_section(".rdata", rdata_storage_, true)); try_var(text_, context.get_or_read_section(".text", text_storage_, true)); return ok(); diff --git a/src/Nncase.Core/Schedule/ScheduleTypes.cs b/src/Nncase.Core/Schedule/ScheduleTypes.cs index 9c2e47f8e2..f63e0467dc 100644 --- a/src/Nncase.Core/Schedule/ScheduleTypes.cs +++ b/src/Nncase.Core/Schedule/ScheduleTypes.cs @@ -10,7 +10,6 @@ namespace Nncase.Schedule; - /// /// the scheduler interface. /// diff --git a/src/Nncase.Core/TIR/MemSpan.cs b/src/Nncase.Core/TIR/MemSpan.cs index c360932c46..533e9fbb04 100644 --- a/src/Nncase.Core/TIR/MemSpan.cs +++ b/src/Nncase.Core/TIR/MemSpan.cs @@ -1,9 +1,11 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + using Nncase; using Nncase.IR; namespace Nncase.TIR; - /// /// the memory type. /// @@ -52,12 +54,14 @@ public enum MemoryLocation : byte public sealed class MemSpan : Expr { - public MemSpan(Expr size, MemoryLocation location) : base(new[] { None.Default, size }) + public MemSpan(Expr size, MemoryLocation location) + : base(new[] { None.Default, size }) { Location = location; } - public MemSpan(Expr start, Expr size, MemoryLocation location) : base(new[] { start, size }) + public MemSpan(Expr start, Expr size, MemoryLocation location) + : base(new[] { start, size }) { Location = location; } @@ -81,6 +85,5 @@ public MemSpan(Expr start, Expr size, MemoryLocation location) : base(new[] { st public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitMemSpan(this, context); - public MemSpan With(Expr? start = null, Expr? size = null, MemoryLocation? location = null) => new(start ?? Start, size ?? Size, location ?? Location); -} \ No newline at end of file +} diff --git a/src/Nncase.Evaluator/TIR/Store.cs b/src/Nncase.Evaluator/TIR/Store.cs index 573a5e8660..e0b6d5bcc7 100644 --- a/src/Nncase.Evaluator/TIR/Store.cs +++ b/src/Nncase.Evaluator/TIR/Store.cs @@ -33,7 +33,6 @@ public string Visit(IIRPrinterContext context, Store target, bool iLmode) private IRType Visit(Store target, TensorType handle, TensorType index, TensorType value) { - if (handle.DType != value.DType) { return new InvalidType($"You Can't Load The {value.DType} To {handle.DType}"); diff --git a/src/Nncase.Tests/Core/UnitTestTIR.cs b/src/Nncase.Tests/Core/UnitTestTIR.cs index 0fce4d672d..9f3b78cf12 100644 --- a/src/Nncase.Tests/Core/UnitTestTIR.cs +++ b/src/Nncase.Tests/Core/UnitTestTIR.cs @@ -49,8 +49,8 @@ public void TestBufferStore() { Expr value = 42; var physicalBuffer = new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0); - var indices = new Expr[] { 0, 1 }; - Call store = T.Store(physicalBuffer, 0, value); + _ = new Expr[] { 0, 1 }; + _ = T.Store(physicalBuffer, 0, value); } [Fact] diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs index 376376f888..7156572dbf 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -47,6 +47,7 @@ public async Task TestCpuUnary() { compiler.Gencode(fs); } + using (var fs = Dumpper.OpenFile("input_0.bin")) { fs.Write(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor().BytesBuffer); From 6ace500bc82a0012677d29e8144bd21fac7947ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Mon, 24 Jul 2023 11:28:09 +0800 Subject: [PATCH 013/308] call cpu func by id --- modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs | 13 +------------ .../Nncase.Modules.CPU/CodeGen/LinkableFunction.cs | 3 +-- .../Nncase.Modules.CPU/CodeGen/LinkableModule.cs | 13 ++++++++----- modules/cpu/src/runtime/runtime_function.cpp | 6 +++--- modules/cpu/src/runtime/runtime_function.h | 2 -- src/Native/include/nncase/runtime/interpreter.h | 1 + src/Native/include/nncase/runtime/runtime_module.h | 2 ++ src/Native/src/runtime/interpreter.cpp | 11 +++++++++++ src/Native/src/runtime/runtime_module.cpp | 13 +++++++++++++ 9 files changed, 40 insertions(+), 24 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index 45b5563f5f..74fbfae3af 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -25,7 +25,7 @@ public static class CSourceBuiltn public const string FixedParameters = "nncase_mt_t* nncase_mt, void* data, void* rdata"; - public const string MainPrologue = $@"void _start(char* name, buffer_t** buffers, {FixedParameters}) {{"; + public const string MainPrologue = $@"void _start(size_t func_id, buffer_t** buffers, {FixedParameters}) {{"; public const string MainEpilogue = @"}"; @@ -36,17 +36,6 @@ public static class CSourceBuiltn {BufferStruct} -int strcmp(const char* s1,const char* s2) {{ - while(*s1 && *s2) {{ - if(*s1 != *s2) {{ - break; - }} - s1++; - s2++; - }} - return (*s1 - *s2) || (*s1 - '\0') || (*s2 - '\0'); -}} - static nncase_mt_t *nncase_mt; "; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index ae49c39b3c..25f264df0c 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -14,8 +14,7 @@ public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, FunctionCSourc PrimFunction = sourceFunction; FunctionCSource = funcCSource; Text = Array.Empty(); - var name = System.Text.Encoding.ASCII.GetBytes(sourceFunction.Name); - Sections = new LinkedSection[] { new(name, ".name", 0, 8, (uint)name.Length) }; + Sections = new LinkedSection[] { }; } public uint Id { get; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index 13ec1b0779..89eb4d9b31 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -76,14 +76,17 @@ private string LinkCSources() } writer.WriteLine(CSourceBuiltn.MainPrologue); + writer.WriteLine($" switch (func_id) {{"); + foreach (var func in _functions) { - writer.WriteLine($" if (strcmp(name,\"{func.SourceFunction.Name}\") == 0) {{"); - writer.WriteLine($" {func.SourceFunction.Name}({string.Join(",", Enumerable.Range(0, func.PrimFunction.Parameters.Length).Select(i => $"buffers[{i}]"))}, nncase_mt, data, rdata);"); - writer.WriteLine(" } else"); + writer.WriteLine($" case {func.Id}:"); + writer.WriteLine($" {func.SourceFunction.Name}({string.Join(", ", Enumerable.Range(0, func.PrimFunction.Parameters.Length).Select(i => $"buffers[{i}]"))}, nncase_mt, data, rdata);"); + writer.WriteLine(" break;"); } - - writer.WriteLine(" { }"); + writer.WriteLine(" default: "); + writer.WriteLine(" break;"); + writer.WriteLine(" }"); writer.WriteLine(CSourceBuiltn.MainEpilogue); } } diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index a268b68403..8de1ee7917 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -27,14 +27,14 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { - try_var(name, context.section(FUNCTION_NAME_SECTION_IDENTIFIER)); - name_ = std::string(name.as_span().data()); return ok(); } result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, value_t return_value) noexcept { - std::cout << "call " << name_ << std::endl; + module().interp(); + try_var(id, module().find_id_by_function(this)); + std::cout << "call " << id << std::endl; return ok(return_value); } \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h index e6a2580c7f..05d8eb099a 100644 --- a/modules/cpu/src/runtime/runtime_function.h +++ b/modules/cpu/src/runtime/runtime_function.h @@ -31,8 +31,6 @@ class cpu_runtime_function : public runtime_function { result invoke_core(gsl::span parameters, value_t return_value) noexcept override; - private: - std::string name_; }; END_NS_NNCASE_RT_MODULE diff --git a/src/Native/include/nncase/runtime/interpreter.h b/src/Native/include/nncase/runtime/interpreter.h index fc91970657..a8f5814d34 100644 --- a/src/Native/include/nncase/runtime/interpreter.h +++ b/src/Native/include/nncase/runtime/interpreter.h @@ -73,6 +73,7 @@ class NNCASE_API interpreter { options_dict &options() noexcept; result find_module_by_id(size_t index) noexcept; + result find_id_by_module(runtime_module *module) noexcept; /* V1 APIs */ diff --git a/src/Native/include/nncase/runtime/runtime_module.h b/src/Native/include/nncase/runtime/runtime_module.h index 354d747cbf..444209016f 100644 --- a/src/Native/include/nncase/runtime/runtime_module.h +++ b/src/Native/include/nncase/runtime/runtime_module.h @@ -57,6 +57,8 @@ class NNCASE_API runtime_module { interpreter &interp() const noexcept { return *interp_; } result find_function_by_id(size_t index) noexcept; + + result find_id_by_function(runtime_function * function) noexcept; protected: virtual result diff --git a/src/Native/src/runtime/interpreter.cpp b/src/Native/src/runtime/interpreter.cpp index dc5839fb44..fe69b8a071 100644 --- a/src/Native/src/runtime/interpreter.cpp +++ b/src/Native/src/runtime/interpreter.cpp @@ -246,6 +246,17 @@ result interpreter::find_module_by_id(size_t index) noexcept { return ok(modules_[index].get()); } +result interpreter::find_id_by_module(runtime_module *module) noexcept { + auto it = std::find_if(modules_.begin(), modules_.end(), + [&module](const std::unique_ptr &p) { + return p.get() == module; + }); + if (it == modules_.end()) { + return err(std::errc::result_out_of_range); + } + return ok((it - modules_.begin())); +} + options_dict &interpreter::options() noexcept { return options_; } result interpreter::entry_function() noexcept { diff --git a/src/Native/src/runtime/runtime_module.cpp b/src/Native/src/runtime/runtime_module.cpp index 4b5d747bbd..66acaca61b 100644 --- a/src/Native/src/runtime/runtime_module.cpp +++ b/src/Native/src/runtime/runtime_module.cpp @@ -189,6 +189,19 @@ runtime_module::find_function_by_id(size_t index) noexcept { return ok(functions_[index].get()); } +result +runtime_module::find_id_by_function(runtime_function *function) noexcept { + auto it = + std::find_if(functions_.begin(), functions_.end(), + [&function](const std::unique_ptr &p) { + return p.get() == function; + }); + if (it == functions_.end()) { + return err(std::errc::result_out_of_range); + } + return ok((it - functions_.begin())); +} + result runtime_module::initialize_before_functions( NNCASE_UNUSED runtime_module_init_context &context) noexcept { return ok(); From 96fd3381c09a92b582ffcee318bfcd940c74e744 Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Mon, 24 Jul 2023 03:31:56 +0000 Subject: [PATCH 014/308] Apply code-format changes --- .../Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs | 2 +- .../Nncase.Modules.CPU/CodeGen/FunctionCSource.cs | 12 +++++++----- .../Nncase.Modules.CPU/CodeGen/LinkableFunction.cs | 2 +- modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs | 1 + .../Passes/Tile/SingleCPUFusionConverter.cs | 4 ++-- modules/cpu/src/runtime/runtime_function.cpp | 2 +- modules/cpu/src/runtime/runtime_function.h | 1 - src/Native/include/nncase/runtime/runtime_module.h | 4 ++-- src/Nncase.Core/TIR/MemSpan.cs | 2 +- 9 files changed, 16 insertions(+), 14 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs index 7026825fa2..e7e479017b 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.Diagnostics; diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index b5f7ad2c86..8b98277d30 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -24,10 +24,12 @@ public class CSourceCompiler /// compiler exe name. /// private string _exe = string.Empty; + /// /// compiler exe name. /// private string _arch = string.Empty; + /// /// compiler exe name. /// @@ -44,6 +46,11 @@ protected string Exe get => _exe; } + protected string Arch + { + get => _arch; + } + /// /// select current pattern's exe. /// @@ -97,11 +104,6 @@ private string ArgumentsSpecific(string sourcePath, string outPath) throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); } - protected string Arch - { - get => _arch; - } - protected string Ext { get => _ext; diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index 25f264df0c..37ca1f584f 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -14,7 +14,7 @@ public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, FunctionCSourc PrimFunction = sourceFunction; FunctionCSource = funcCSource; Text = Array.Empty(); - Sections = new LinkedSection[] { }; + Sections = Array.Empty(); } public uint Id { get; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index 89eb4d9b31..cfa27dc8b5 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -84,6 +84,7 @@ private string LinkCSources() writer.WriteLine($" {func.SourceFunction.Name}({string.Join(", ", Enumerable.Range(0, func.PrimFunction.Parameters.Length).Select(i => $"buffers[{i}]"))}, nncase_mt, data, rdata);"); writer.WriteLine(" break;"); } + writer.WriteLine(" default: "); writer.WriteLine(" break;"); writer.WriteLine(" }"); diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 502b22c321..1e54700651 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -54,7 +54,7 @@ public ConvertVisitor(List mainBody) protected override Unit DefaultVisitLeaf(Expr expr) { - return default(Unit); + return default; } protected override Unit VisitLeafCall(Call expr) @@ -72,7 +72,7 @@ protected override Unit VisitLeafCall(Call expr) throw new NotSupportedException(); } - return default(Unit); + return default; } private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer ret) diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 8de1ee7917..3f7cbe4156 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -33,7 +33,7 @@ result cpu_runtime_function::initialize_core( result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, value_t return_value) noexcept { - module().interp(); + module().interp(); try_var(id, module().find_id_by_function(this)); std::cout << "call " << id << std::endl; return ok(return_value); diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h index 05d8eb099a..79a524dfc1 100644 --- a/modules/cpu/src/runtime/runtime_function.h +++ b/modules/cpu/src/runtime/runtime_function.h @@ -30,7 +30,6 @@ class cpu_runtime_function : public runtime_function { initialize_core(runtime_function_init_context &context) noexcept override; result invoke_core(gsl::span parameters, value_t return_value) noexcept override; - }; END_NS_NNCASE_RT_MODULE diff --git a/src/Native/include/nncase/runtime/runtime_module.h b/src/Native/include/nncase/runtime/runtime_module.h index 444209016f..194dd3b1f7 100644 --- a/src/Native/include/nncase/runtime/runtime_module.h +++ b/src/Native/include/nncase/runtime/runtime_module.h @@ -57,8 +57,8 @@ class NNCASE_API runtime_module { interpreter &interp() const noexcept { return *interp_; } result find_function_by_id(size_t index) noexcept; - - result find_id_by_function(runtime_function * function) noexcept; + + result find_id_by_function(runtime_function *function) noexcept; protected: virtual result diff --git a/src/Nncase.Core/TIR/MemSpan.cs b/src/Nncase.Core/TIR/MemSpan.cs index 533e9fbb04..0ee5b4a215 100644 --- a/src/Nncase.Core/TIR/MemSpan.cs +++ b/src/Nncase.Core/TIR/MemSpan.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase; From dd6451e51f9f110c548cdce6d22d0d88e8ffb408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Mon, 24 Jul 2023 16:14:39 +0800 Subject: [PATCH 015/308] add desc header --- .../CodeGen/FunctionBuilder.cs | 88 ++++++++++++++++++- .../CodeGen/LinkableFunction.cs | 5 +- .../nncase/runtime/cpu/runtime_module.h | 2 - modules/cpu/src/runtime/runtime_function.cpp | 59 +++++++++++++ 4 files changed, 148 insertions(+), 6 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs index 969fbcec47..7cd4df4f49 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs @@ -19,6 +19,43 @@ internal class FunctionBuilder : IDisposable private readonly BinaryWriter _textWriter; private readonly BinaryWriter _rdataWriter; + /// + /// NOTE sync with the k230 runtime function. + /// + [StructLayout(LayoutKind.Sequential)] + private struct MemoryRange + { + public uint Start; + public uint Size; + } + + /// + /// NOTE sync with the k230 runtime function. + /// + [StructLayout(LayoutKind.Sequential)] + private unsafe struct DescHeader + { + /// + /// input pool size. + /// + public uint InputPoolSize; + + /// + /// output pool size. + /// + public uint OutputPoolSize; + + /// + /// input numbers. + /// + public uint Inputs; + + /// + /// output numbers. + /// + public uint Outputs; + } + public FunctionBuilder(uint id, BinaryWriter rdataWriter) { _id = id; @@ -33,7 +70,54 @@ public unsafe LinkableFunction Build(TIR.PrimFunction function) visitor.Visit(function); var functionCSource = visitor.GetFunctionCSource(); - // 2. write the rdata + // 2. write the desc + var descContent = new MemoryStream(); + using (var descWriter = new BinaryWriter(descContent, Encoding.UTF8)) + { + DescHeader header = new() { InputPoolSize = 0, OutputPoolSize = 0, Inputs = 0, Outputs = 0 }; + long headerStart = descWriter.Position(); + descWriter.Skip((ulong)sizeof(DescHeader)); + + foreach (var input in function.Parameters.AsValueEnumerable() + .Where(buf => buf.MemLocation == TIR.MemoryLocation.Input)) + { + header.Inputs++; + var rg = new MemoryRange { Start = checked((uint)input.Start), Size = checked((uint)input.Size) }; + descWriter.Write(ref rg); + header.InputPoolSize = Math.Max(header.InputPoolSize, rg.Start + rg.Size); + descWriter.Write((uint)input.FixedDimensions.Length); + foreach (var dim in input.FixedDimensions) + { + descWriter.Write((uint)dim); + } + foreach (var s in input.FixedStrides) + { + descWriter.Write((uint)s); + } + } + + foreach (var output in function.Parameters.AsValueEnumerable().Where(buf => buf.MemLocation == TIR.MemoryLocation.Output)) + { + header.Outputs++; + var rg = new MemoryRange { Start = checked((uint)output.Start), Size = checked((uint)output.Size) }; + descWriter.Write(ref rg); + header.OutputPoolSize = Math.Max(header.OutputPoolSize, rg.Start + rg.Size); + descWriter.Write((uint)output.FixedDimensions.Length); + foreach (var dim in output.FixedDimensions) + { + descWriter.Write((uint)dim); + } + foreach (var s in output.FixedStrides) + { + descWriter.Write((uint)s); + } + } + + descWriter.Position(headerStart); + descWriter.Write(ref header); + } + + // 3. write the rdata foreach (var buffer in function.SchedResult.Rdatas) { var bytes = buffer.Const!.Value.BytesBuffer; @@ -46,7 +130,7 @@ public unsafe LinkableFunction Build(TIR.PrimFunction function) _rdataWriter.Write(bytes); } - return new LinkableFunction(_id, function, functionCSource); + return new LinkableFunction(_id, descContent.ToArray(), function, functionCSource); } public void Dispose() diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index 37ca1f584f..8757cc2ccd 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -1,20 +1,21 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System.Runtime.InteropServices; using Nncase.IR; namespace Nncase.CodeGen.CPU; internal sealed class LinkableFunction : ILinkableFunction { - public LinkableFunction(uint id, TIR.PrimFunction sourceFunction, FunctionCSource funcCSource) + public LinkableFunction(uint id, byte[] descContents, TIR.PrimFunction sourceFunction, FunctionCSource funcCSource) { Id = id; SourceFunction = sourceFunction; PrimFunction = sourceFunction; FunctionCSource = funcCSource; Text = Array.Empty(); - Sections = Array.Empty(); + Sections = new ILinkedSection[] { new LinkedSection(descContents, ".desc", 0, 8, (uint)descContents.Length) }; } public uint Id { get; } diff --git a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h index 9c071f41b6..487908008d 100644 --- a/modules/cpu/include/nncase/runtime/cpu/runtime_module.h +++ b/modules/cpu/include/nncase/runtime/cpu/runtime_module.h @@ -18,8 +18,6 @@ BEGIN_NS_NNCASE_RT_MODULE(cpu) -NNCASE_INLINE_VAR constexpr char FUNCTION_NAME_SECTION_IDENTIFIER[] = ".name"; - NNCASE_INLINE_VAR constexpr module_kind_t cpu_module_type = to_module_kind("cpu"); NNCASE_INLINE_VAR constexpr uint32_t cpu_module_version = 0; diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 3f7cbe4156..9934a6985c 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -21,12 +21,71 @@ using namespace nncase; using namespace nncase::runtime; using namespace nncase::runtime::cpu; +namespace { +typedef struct memory_range { + uint32_t start; + uint32_t size; +} memory_range_t; + +typedef struct desc_header { + uint32_t input_pool_size; + + uint32_t output_pool_size; + + uint32_t inputs; + + uint32_t outputs; +} desc_header_t; + +} // namespace + cpu_runtime_module &cpu_runtime_function::module() const noexcept { return static_cast(runtime_function::module()); } result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { + + try_(context.read_section(".desc", [this](auto sr, size_t) -> result { + auto header = sr.template read(); + if (parameters_size() != header.inputs + header.outputs) + return nncase::err(std::errc::invalid_argument); + + for (uint32_t i = 0; i < header.inputs; i++) { + sr.template read(); + auto rank = sr.template read(); + std::cout << "shape: "; + for (uint32_t j = 0; j < rank; j++) { + std::cout << sr.template read() << ", "; + } + std::cout << std::endl; + + std::cout << "stride: "; + for (uint32_t j = 0; j < rank; j++) { + std::cout << sr.template read() << ", "; + } + std::cout << std::endl; + } + + for (uint32_t i = 0; i < header.outputs; i++) { + sr.template read(); + auto rank = sr.template read(); + std::cout << "shape: "; + for (uint32_t j = 0; j < rank; j++) { + std::cout << sr.template read() << ", "; + } + std::cout << std::endl; + + std::cout << "stride: "; + for (uint32_t j = 0; j < rank; j++) { + std::cout << sr.template read() << ", "; + } + std::cout << std::endl; + } + + return ok(); + })); + return ok(); } From e37b6a1c0372ab85ac71da8d958ed133dccdaccb Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Mon, 24 Jul 2023 08:18:12 +0000 Subject: [PATCH 016/308] Apply code-format changes --- modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index 8b98277d30..73bcd2f605 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -51,6 +51,11 @@ protected string Arch get => _arch; } + protected string Ext + { + get => _ext; + } + /// /// select current pattern's exe. /// @@ -104,11 +109,6 @@ private string ArgumentsSpecific(string sourcePath, string outPath) throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); } - protected string Ext - { - get => _ext; - } - /// /// compile the source txt, write to the out_path. /// From 834ab82250fe8c68fa642b56a6d629ff20c5c2de Mon Sep 17 00:00:00 2001 From: huochenghai Date: Mon, 24 Jul 2023 19:58:04 +0800 Subject: [PATCH 017/308] add elfloader --- .../CodeGen/CSourceBuiltn.cs | 28 +- .../CodeGen/FunctionCSource.cs | 2 +- .../CodeGen/LinkableModule.cs | 4 +- modules/cpu/src/runtime/CMakeLists.txt | 8 +- modules/cpu/src/runtime/cpu_common.h | 71 +++ modules/cpu/src/runtime/elf.h | 578 ++++++++++++++++++ modules/cpu/src/runtime/elfarch.h | 35 ++ modules/cpu/src/runtime/elfload.cpp | 269 ++++++++ modules/cpu/src/runtime/elfload.h | 99 +++ modules/cpu/src/runtime/elfloader.cpp | 40 ++ modules/cpu/src/runtime/elfloader.h | 62 ++ modules/cpu/src/runtime/elfreloc_aarch64.cpp | 67 ++ modules/cpu/src/runtime/elfreloc_amd64.cpp | 28 + modules/cpu/src/runtime/elfreloc_i386.cpp | 28 + modules/cpu/src/runtime/elfreloc_riscv64.cpp | 33 + modules/cpu/src/runtime/runtime_function.cpp | 62 +- modules/cpu/src/runtime/runtime_function.h | 8 + modules/cpu/src/runtime/runtime_module.cpp | 8 +- 18 files changed, 1414 insertions(+), 16 deletions(-) create mode 100644 modules/cpu/src/runtime/cpu_common.h create mode 100644 modules/cpu/src/runtime/elf.h create mode 100644 modules/cpu/src/runtime/elfarch.h create mode 100644 modules/cpu/src/runtime/elfload.cpp create mode 100644 modules/cpu/src/runtime/elfload.h create mode 100644 modules/cpu/src/runtime/elfloader.cpp create mode 100644 modules/cpu/src/runtime/elfloader.h create mode 100644 modules/cpu/src/runtime/elfreloc_aarch64.cpp create mode 100644 modules/cpu/src/runtime/elfreloc_amd64.cpp create mode 100644 modules/cpu/src/runtime/elfreloc_i386.cpp create mode 100644 modules/cpu/src/runtime/elfreloc_riscv64.cpp diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index 74fbfae3af..6eb0ae594d 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -10,13 +10,33 @@ public static class CSourceBuiltn public const string BufferStruct = @"typedef struct buffer { void *vaddr; size_t paddr; - int *shape; - int *stride; - int rank; + uint32_t *shape; + uint32_t *stride; + uint32_t rank; } buffer_t;"; public const string MethodTable = @"typedef struct nncase_method_table { - float (*float_unary_asin)(float); + float (*float_unary_abs)(float); + float (*float_unary_acos)(float); + float (*float_unary_acosh)(float); + float (*float_unary_asin)(float); + float (*float_unary_asinh)(float); + float (*float_unary_ceil)(float); + float (*float_unary_cos)(float); + float (*float_unary_cosh)(float); + float (*float_unary_exp)(float); + float (*float_unary_floor)(float); + float (*float_unary_log)(float); + float (*float_unary_logical_not)(float); + float (*float_unary_neg)(float); + float (*float_unary_round)(float); + float (*float_unary_rsqrt)(float); + float (*float_unary_sign)(float); + float (*float_unary_sin)(float); + float (*float_unary_sinh)(float); + float (*float_unary_sqrt)(float); + float (*float_unary_square)(float); + float (*float_unary_tanh)(float); } nncase_mt_t;"; public const string Include = @"#include diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index 73bcd2f605..cd5ff86d9d 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -93,7 +93,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { - return $"{sourcePath} -fPIC -shared -march={Arch} -o {outPath}"; + return $"{sourcePath} -nostdlib -static -no-pie -fPIC -march={Arch} -o {outPath}"; } else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index cfa27dc8b5..e2e53e6b19 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -58,7 +58,7 @@ public ILinkedModule Link(ILinkContext linkContext) private string LinkCSources() { - var path = Path.GetTempFileName(); + var path = Path.GetTempFileName() + ".c"; using (var fs = File.OpenWrite(path)) { using (var writer = new StreamWriter(fs)) @@ -98,6 +98,6 @@ private string LinkCSources() private string CompileCSource(string sourcePath) { var compiler = new CSourceCompiler(); - return compiler.Compile(sourcePath, Path.GetTempFileName()); + return compiler.Compile(sourcePath, Path.GetTempFileName() + ".elf"); } } diff --git a/modules/cpu/src/runtime/CMakeLists.txt b/modules/cpu/src/runtime/CMakeLists.txt index fe8fc9236a..ebba027d8d 100644 --- a/modules/cpu/src/runtime/CMakeLists.txt +++ b/modules/cpu/src/runtime/CMakeLists.txt @@ -1,7 +1,13 @@ cmake_minimum_required (VERSION 3.13) set(SRCS runtime_module.cpp - runtime_function.cpp) + runtime_function.cpp + elfload.cpp + elfloader.cpp + elfreloc_aarch64.cpp + elfreloc_amd64.cpp + elfreloc_i386.cpp + elfreloc_riscv64.cpp) if (BUILDING_RUNTIME) if (ENABLE_CPU_RUNTIME) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h new file mode 100644 index 0000000000..70c22065b9 --- /dev/null +++ b/modules/cpu/src/runtime/cpu_common.h @@ -0,0 +1,71 @@ +#pragma once +#include +#include +#include +#include +#include + +BEGIN_NS_NNCASE_RT_MODULE(cpu) + +typedef struct nncase_method_table { + float (*float_unary_abs)(float); + float (*float_unary_acos)(float); + float (*float_unary_acosh)(float); + float (*float_unary_asin)(float); + float (*float_unary_asinh)(float); + float (*float_unary_ceil)(float); + float (*float_unary_cos)(float); + float (*float_unary_cosh)(float); + float (*float_unary_exp)(float); + float (*float_unary_floor)(float); + float (*float_unary_log)(float); + float (*float_unary_logical_not)(float); + float (*float_unary_neg)(float); + float (*float_unary_round)(float); + float (*float_unary_rsqrt)(float); + float (*float_unary_sign)(float); + float (*float_unary_sin)(float); + float (*float_unary_sinh)(float); + float (*float_unary_sqrt)(float); + float (*float_unary_square)(float); + float (*float_unary_tanh)(float); +} nncase_mt_t; + +typedef struct buffer { + void *vaddr; + size_t paddr; + uint32_t *shape; + uint32_t *stride; + uint32_t rank; +} buffer_t; + +inline float float_unary_logical_not(float x) { return !x; } +inline float float_unary_neg(float x) { return std::negate()(x); } +inline float float_unary_rsqrt(float x) { return 1.f / sqrtf(x); } +inline float float_unary_sign(float x) { return (0.f < x) - (x < 0.f); } +inline float float_unary_square(float x) { return x * x; } + +[[maybe_unused]] static nncase_mt_t nncase_mt = { + .float_unary_abs = fabsf, + .float_unary_acos = acosf, + .float_unary_acosh = acoshf, + .float_unary_asin = asinf, + .float_unary_asinh = asinhf, + .float_unary_ceil = ceilf, + .float_unary_cos = cosf, + .float_unary_cosh = coshf, + .float_unary_exp = expf, + .float_unary_floor = floorf, + .float_unary_log = logf, + .float_unary_logical_not = &float_unary_logical_not, + .float_unary_neg = &float_unary_neg, + .float_unary_round = roundf, + .float_unary_rsqrt = &float_unary_rsqrt, + .float_unary_sign = &float_unary_sign, + .float_unary_sin = sinf, + .float_unary_sinh = sinhf, + .float_unary_sqrt = sqrtf, + .float_unary_square = &float_unary_square, + .float_unary_tanh = tanhf}; + +END_NS_NNCASE_RT_MODULE \ No newline at end of file diff --git a/modules/cpu/src/runtime/elf.h b/modules/cpu/src/runtime/elf.h new file mode 100644 index 0000000000..027f99ad0c --- /dev/null +++ b/modules/cpu/src/runtime/elf.h @@ -0,0 +1,578 @@ +/* $OpenBSD: exec_elf.h,v 1.53 2014/01/03 03:00:39 guenther Exp $ */ +/* + * Copyright (c) 1995, 1996 Erik Theisen. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the author may not be used to endorse or promote products + * derived from this software without specific prior written permission + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. + * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* imported sys/exec_elf.h from OpenBSD */ + +#ifndef ELF_H +#define ELF_H +#include + +typedef uint8_t Elf_Byte; + +typedef uint32_t Elf32_Addr; /* Unsigned program address */ +typedef uint32_t Elf32_Off; /* Unsigned file offset */ +typedef int32_t Elf32_Sword; /* Signed large integer */ +typedef uint32_t Elf32_Word; /* Unsigned large integer */ +typedef uint16_t Elf32_Half; /* Unsigned medium integer */ + +typedef uint64_t Elf64_Addr; +typedef uint64_t Elf64_Off; +typedef int32_t Elf64_Shalf; + +#ifdef __alpha__ +typedef int64_t Elf64_Sword; +typedef uint64_t Elf64_Word; +#else +typedef int32_t Elf64_Sword; +typedef uint32_t Elf64_Word; +#endif + +typedef int64_t Elf64_Sxword; +typedef uint64_t Elf64_Xword; + +typedef uint32_t Elf64_Half; +typedef uint16_t Elf64_Quarter; + +/* + * e_ident[] identification indexes + * See http://www.sco.com/developers/gabi/latest/ch4.eheader.html + */ +#define EI_MAG0 0 /* file ID */ +#define EI_MAG1 1 /* file ID */ +#define EI_MAG2 2 /* file ID */ +#define EI_MAG3 3 /* file ID */ +#define EI_CLASS 4 /* file class */ +#define EI_DATA 5 /* data encoding */ +#define EI_VERSION 6 /* ELF header version */ +#define EI_OSABI 7 /* OS/ABI ID */ +#define EI_ABIVERSION 8 /* ABI version */ +#define EI_PAD 9 /* start of pad bytes */ +#define EI_NIDENT 16 /* Size of e_ident[] */ + +/* e_ident[] magic number */ +#define ELFMAG0 0x7f /* e_ident[EI_MAG0] */ +#define ELFMAG1 'E' /* e_ident[EI_MAG1] */ +#define ELFMAG2 'L' /* e_ident[EI_MAG2] */ +#define ELFMAG3 'F' /* e_ident[EI_MAG3] */ +#define ELFMAG "\177ELF" /* magic */ +#define SELFMAG 4 /* size of magic */ + +/* e_ident[] file class */ +#define ELFCLASSNONE 0 /* invalid */ +#define ELFCLASS32 1 /* 32-bit objs */ +#define ELFCLASS64 2 /* 64-bit objs */ +#define ELFCLASSNUM 3 /* number of classes */ + +/* e_ident[] data encoding */ +#define ELFDATANONE 0 /* invalid */ +#define ELFDATA2LSB 1 /* Little-Endian */ +#define ELFDATA2MSB 2 /* Big-Endian */ +#define ELFDATANUM 3 /* number of data encode defines */ + +/* e_ident[] Operating System/ABI */ +#define ELFOSABI_SYSV 0 /* UNIX System V ABI */ +#define ELFOSABI_HPUX 1 /* HP-UX operating system */ +#define ELFOSABI_NETBSD 2 /* NetBSD */ +#define ELFOSABI_LINUX 3 /* GNU/Linux */ +#define ELFOSABI_HURD 4 /* GNU/Hurd */ +#define ELFOSABI_86OPEN 5 /* 86Open common IA32 ABI */ +#define ELFOSABI_SOLARIS 6 /* Solaris */ +#define ELFOSABI_MONTEREY 7 /* Monterey */ +#define ELFOSABI_IRIX 8 /* IRIX */ +#define ELFOSABI_FREEBSD 9 /* FreeBSD */ +#define ELFOSABI_TRU64 10 /* TRU64 UNIX */ +#define ELFOSABI_MODESTO 11 /* Novell Modesto */ +#define ELFOSABI_OPENBSD 12 /* OpenBSD */ +#define ELFOSABI_ARM 97 /* ARM */ +#define ELFOSABI_STANDALONE 255 /* Standalone (embedded) application */ + +/* e_ident */ +#define IS_ELF(ehdr) ((ehdr).e_ident[EI_MAG0] == ELFMAG0 && \ + (ehdr).e_ident[EI_MAG1] == ELFMAG1 && \ + (ehdr).e_ident[EI_MAG2] == ELFMAG2 && \ + (ehdr).e_ident[EI_MAG3] == ELFMAG3) + +/* ELF Header */ +typedef struct { + unsigned char e_ident[EI_NIDENT]; /* ELF Identification */ + Elf32_Half e_type; /* object file type */ + Elf32_Half e_machine; /* machine */ + Elf32_Word e_version; /* object file version */ + Elf32_Addr e_entry; /* virtual entry point */ + Elf32_Off e_phoff; /* program header table offset */ + Elf32_Off e_shoff; /* section header table offset */ + Elf32_Word e_flags; /* processor-specific flags */ + Elf32_Half e_ehsize; /* ELF header size */ + Elf32_Half e_phentsize; /* program header entry size */ + Elf32_Half e_phnum; /* number of program header entries */ + Elf32_Half e_shentsize; /* section header entry size */ + Elf32_Half e_shnum; /* number of section header entries */ + Elf32_Half e_shstrndx; /* section header table's "section + header string table" entry offset */ +} Elf32_Ehdr; + +typedef struct { + unsigned char e_ident[EI_NIDENT]; /* Id bytes */ + Elf64_Quarter e_type; /* file type */ + Elf64_Quarter e_machine; /* machine type */ + Elf64_Half e_version; /* version number */ + Elf64_Addr e_entry; /* entry point */ + Elf64_Off e_phoff; /* Program hdr offset */ + Elf64_Off e_shoff; /* Section hdr offset */ + Elf64_Half e_flags; /* Processor flags */ + Elf64_Quarter e_ehsize; /* sizeof ehdr */ + Elf64_Quarter e_phentsize; /* Program header entry size */ + Elf64_Quarter e_phnum; /* Number of program headers */ + Elf64_Quarter e_shentsize; /* Section header entry size */ + Elf64_Quarter e_shnum; /* Number of section headers */ + Elf64_Quarter e_shstrndx; /* String table index */ +} Elf64_Ehdr; + +/* e_type */ +#define ET_NONE 0 /* No file type */ +#define ET_REL 1 /* relocatable file */ +#define ET_EXEC 2 /* executable file */ +#define ET_DYN 3 /* shared object file */ +#define ET_CORE 4 /* core file */ +#define ET_NUM 5 /* number of types */ +#define ET_LOPROC 0xff00 /* reserved range for processor */ +#define ET_HIPROC 0xffff /* specific e_type */ + +/* e_machine */ +#define EM_NONE 0 /* No Machine */ +#define EM_M32 1 /* AT&T WE 32100 */ +#define EM_SPARC 2 /* SPARC */ +#define EM_386 3 /* Intel 80386 */ +#define EM_68K 4 /* Motorola 68000 */ +#define EM_88K 5 /* Motorola 88000 */ +#define EM_486 6 /* Intel 80486 - unused? */ +#define EM_860 7 /* Intel 80860 */ +#define EM_MIPS 8 /* MIPS R3000 Big-Endian only */ +/* + * Don't know if EM_MIPS_RS4_BE, + * EM_SPARC64, EM_PARISC, + * or EM_PPC are ABI compliant + */ +#define EM_MIPS_RS4_BE 10 /* MIPS R4000 Big-Endian */ +#define EM_SPARC64 11 /* SPARC v9 64-bit unofficial */ +#define EM_PARISC 15 /* HPPA */ +#define EM_SPARC32PLUS 18 /* Enhanced instruction set SPARC */ +#define EM_PPC 20 /* PowerPC */ +#define EM_ARM 40 /* ARM AArch32 */ +#define EM_ALPHA 41 /* DEC ALPHA */ +#define EM_SH 42 /* Hitachi/Renesas Super-H */ +#define EM_SPARCV9 43 /* SPARC version 9 */ +#define EM_IA_64 50 /* Intel IA-64 Processor */ +#define EM_AMD64 62 /* AMD64 architecture */ +#define EM_VAX 75 /* DEC VAX */ +#define EM_AARCH64 183 /* ARM AArch64 */ + +/* Non-standard */ +#define EM_ALPHA_EXP 0x9026 /* DEC ALPHA */ + +#define EM_RISCV 243 + + +/* Version */ +#define EV_NONE 0 /* Invalid */ +#define EV_CURRENT 1 /* Current */ +#define EV_NUM 2 /* number of versions */ + +/* Section Header */ +typedef struct { + Elf32_Word sh_name; /* name - index into section header + * string table section */ + Elf32_Word sh_type; /* type */ + Elf32_Word sh_flags; /* flags */ + Elf32_Addr sh_addr; /* address */ + Elf32_Off sh_offset; /* file offset */ + Elf32_Word sh_size; /* section size */ + Elf32_Word sh_link; /* section header table index link */ + Elf32_Word sh_info; /* extra information */ + Elf32_Word sh_addralign; /* address alignment */ + Elf32_Word sh_entsize; /* section entry size */ +} Elf32_Shdr; + +typedef struct { + Elf64_Half sh_name; /* section name */ + Elf64_Half sh_type; /* section type */ + Elf64_Xword sh_flags; /* section flags */ + Elf64_Addr sh_addr; /* virtual address */ + Elf64_Off sh_offset; /* file offset */ + Elf64_Xword sh_size; /* section size */ + Elf64_Half sh_link; /* link to another */ + Elf64_Half sh_info; /* misc info */ + Elf64_Xword sh_addralign; /* memory alignment */ + Elf64_Xword sh_entsize; /* table entry size */ +} Elf64_Shdr; + +/* Special Section Indexes */ +#define SHN_UNDEF 0 /* undefined */ +#define SHN_LORESERVE 0xff00 /* lower bounds of reserved indexes */ +#define SHN_LOPROC 0xff00 /* reserved range for processor */ +#define SHN_HIPROC 0xff1f /* specific section indexes */ +#define SHN_ABS 0xfff1 /* absolute value */ +#define SHN_COMMON 0xfff2 /* common symbol */ +#define SHN_HIRESERVE 0xffff /* upper bounds of reserved indexes */ + +/* sh_type */ +#define SHT_NULL 0 /* inactive */ +#define SHT_PROGBITS 1 /* program defined information */ +#define SHT_SYMTAB 2 /* symbol table section */ +#define SHT_STRTAB 3 /* string table section */ +#define SHT_RELA 4 /* relocation section with addends*/ +#define SHT_HASH 5 /* symbol hash table section */ +#define SHT_DYNAMIC 6 /* dynamic section */ +#define SHT_NOTE 7 /* note section */ +#define SHT_NOBITS 8 /* no space section */ +#define SHT_REL 9 /* relation section without addends */ +#define SHT_SHLIB 10 /* reserved - purpose unknown */ +#define SHT_DYNSYM 11 /* dynamic symbol table section */ +#define SHT_NUM 12 /* number of section types */ +#define SHT_LOPROC 0x70000000 /* reserved range for processor */ +#define SHT_HIPROC 0x7fffffff /* specific section header types */ +#define SHT_LOUSER 0x80000000 /* reserved range for application */ +#define SHT_HIUSER 0xffffffff /* specific indexes */ + +/* Section names */ +#define ELF_BSS ".bss" /* uninitialized data */ +#define ELF_DATA ".data" /* initialized data */ +#define ELF_DEBUG ".debug" /* debug */ +#define ELF_DYNAMIC ".dynamic" /* dynamic linking information */ +#define ELF_DYNSTR ".dynstr" /* dynamic string table */ +#define ELF_DYNSYM ".dynsym" /* dynamic symbol table */ +#define ELF_FINI ".fini" /* termination code */ +#define ELF_GOT ".got" /* global offset table */ +#define ELF_HASH ".hash" /* symbol hash table */ +#define ELF_INIT ".init" /* initialization code */ +#define ELF_REL_DATA ".rel.data" /* relocation data */ +#define ELF_REL_FINI ".rel.fini" /* relocation termination code */ +#define ELF_REL_INIT ".rel.init" /* relocation initialization code */ +#define ELF_REL_DYN ".rel.dyn" /* relocation dynamic link info */ +#define ELF_REL_RODATA ".rel.rodata" /* relocation read-only data */ +#define ELF_REL_TEXT ".rel.text" /* relocation code */ +#define ELF_RODATA ".rodata" /* read-only data */ +#define ELF_SHSTRTAB ".shstrtab" /* section header string table */ +#define ELF_STRTAB ".strtab" /* string table */ +#define ELF_SYMTAB ".symtab" /* symbol table */ +#define ELF_TEXT ".text" /* code */ + + +/* Section Attribute Flags - sh_flags */ +#define SHF_WRITE 0x1 /* Writable */ +#define SHF_ALLOC 0x2 /* occupies memory */ +#define SHF_EXECINSTR 0x4 /* executable */ +#define SHF_TLS 0x400 /* thread local storage */ +#define SHF_MASKPROC 0xf0000000 /* reserved bits for processor + * specific section attributes */ + +/* Symbol Table Entry */ +typedef struct elf32_sym { + Elf32_Word st_name; /* name - index into string table */ + Elf32_Addr st_value; /* symbol value */ + Elf32_Word st_size; /* symbol size */ + unsigned char st_info; /* type and binding */ + unsigned char st_other; /* 0 - no defined meaning */ + Elf32_Half st_shndx; /* section header index */ +} Elf32_Sym; + +typedef struct { + Elf64_Half st_name; /* Symbol name index in str table */ + Elf_Byte st_info; /* type / binding attrs */ + Elf_Byte st_other; /* unused */ + Elf64_Quarter st_shndx; /* section index of symbol */ + Elf64_Xword st_value; /* value of symbol */ + Elf64_Xword st_size; /* size of symbol */ +} Elf64_Sym; + +/* Symbol table index */ +#define STN_UNDEF 0 /* undefined */ + +/* Extract symbol info - st_info */ +#define ELF32_ST_BIND(x) ((x) >> 4) +#define ELF32_ST_TYPE(x) (((unsigned int) x) & 0xf) +#define ELF32_ST_INFO(b,t) (((b) << 4) + ((t) & 0xf)) + +#define ELF64_ST_BIND(x) ((x) >> 4) +#define ELF64_ST_TYPE(x) (((unsigned int) x) & 0xf) +#define ELF64_ST_INFO(b,t) (((b) << 4) + ((t) & 0xf)) + +/* Symbol Binding - ELF32_ST_BIND - st_info */ +#define STB_LOCAL 0 /* Local symbol */ +#define STB_GLOBAL 1 /* Global symbol */ +#define STB_WEAK 2 /* like global - lower precedence */ +#define STB_NUM 3 /* number of symbol bindings */ +#define STB_LOPROC 13 /* reserved range for processor */ +#define STB_HIPROC 15 /* specific symbol bindings */ + +/* Symbol type - ELF32_ST_TYPE - st_info */ +#define STT_NOTYPE 0 /* not specified */ +#define STT_OBJECT 1 /* data object */ +#define STT_FUNC 2 /* function */ +#define STT_SECTION 3 /* section */ +#define STT_FILE 4 /* file */ +#define STT_TLS 6 /* thread local storage */ +#define STT_LOPROC 13 /* reserved range for processor */ +#define STT_HIPROC 15 /* specific symbol types */ + +/* Relocation entry with implicit addend */ +typedef struct { + Elf32_Addr r_offset; /* offset of relocation */ + Elf32_Word r_info; /* symbol table index and type */ +} Elf32_Rel; + +/* Relocation entry with explicit addend */ +typedef struct { + Elf32_Addr r_offset; /* offset of relocation */ + Elf32_Word r_info; /* symbol table index and type */ + Elf32_Sword r_addend; +} Elf32_Rela; + +/* Extract relocation info - r_info */ +#define ELF32_R_SYM(i) ((i) >> 8) +#define ELF32_R_TYPE(i) ((unsigned char) (i)) +#define ELF32_R_INFO(s,t) (((s) << 8) + (unsigned char)(t)) + +typedef struct { + Elf64_Xword r_offset; /* where to do it */ + Elf64_Xword r_info; /* index & type of relocation */ +} Elf64_Rel; + +typedef struct { + Elf64_Xword r_offset; /* where to do it */ + Elf64_Xword r_info; /* index & type of relocation */ + Elf64_Sxword r_addend; /* adjustment value */ +} Elf64_Rela; + +#define ELF64_R_SYM(info) ((info) >> 32) +#define ELF64_R_TYPE(info) ((info) & 0xFFFFFFFF) +#define ELF64_R_INFO(s,t) (((s) << 32) + (__uint32_t)(t)) + +#if defined(__mips64__) && defined(__MIPSEL__) +/* + * The 64-bit MIPS ELF ABI uses a slightly different relocation format + * than the regular ELF ABI: the r_info field is split into several + * pieces (see gnu/usr.bin/binutils/include/elf/mips.h for details). + */ +#undef ELF64_R_SYM +#undef ELF64_R_TYPE +#undef ELF64_R_INFO +#define ELF64_R_TYPE(info) (swap32((info) >> 32)) +#define ELF64_R_SYM(info) ((info) & 0xFFFFFFFF) +#define ELF64_R_INFO(s,t) (((__uint64_t)swap32(t) << 32) + (__uint32_t)(s)) +#endif /* __mips64__ && __MIPSEL__ */ + +/* Program Header */ +typedef struct { + Elf32_Word p_type; /* segment type */ + Elf32_Off p_offset; /* segment offset */ + Elf32_Addr p_vaddr; /* virtual address of segment */ + Elf32_Addr p_paddr; /* physical address - ignored? */ + Elf32_Word p_filesz; /* number of bytes in file for seg. */ + Elf32_Word p_memsz; /* number of bytes in mem. for seg. */ + Elf32_Word p_flags; /* flags */ + Elf32_Word p_align; /* memory alignment */ +} Elf32_Phdr; + +typedef struct { + Elf64_Half p_type; /* entry type */ + Elf64_Half p_flags; /* flags */ + Elf64_Off p_offset; /* offset */ + Elf64_Addr p_vaddr; /* virtual address */ + Elf64_Addr p_paddr; /* physical address */ + Elf64_Xword p_filesz; /* file size */ + Elf64_Xword p_memsz; /* memory size */ + Elf64_Xword p_align; /* memory & file alignment */ +} Elf64_Phdr; + +/* Segment types - p_type */ +#define PT_NULL 0 /* unused */ +#define PT_LOAD 1 /* loadable segment */ +#define PT_DYNAMIC 2 /* dynamic linking section */ +#define PT_INTERP 3 /* the RTLD */ +#define PT_NOTE 4 /* auxiliary information */ +#define PT_SHLIB 5 /* reserved - purpose undefined */ +#define PT_PHDR 6 /* program header */ +#define PT_TLS 7 /* thread local storage */ +#define PT_LOOS 0x60000000 /* reserved range for OS */ +#define PT_HIOS 0x6fffffff /* specific segment types */ +#define PT_LOPROC 0x70000000 /* reserved range for processor */ +#define PT_HIPROC 0x7fffffff /* specific segment types */ + +#define PT_OPENBSD_RANDOMIZE 0x65a3dbe6 /* fill with random data */ +#define PT_GANDR_KERNEL 0x67646b6c /* gdkl */ + + +/* Segment flags - p_flags */ +#define PF_X 0x1 /* Executable */ +#define PF_W 0x2 /* Writable */ +#define PF_R 0x4 /* Readable */ +#define PF_MASKPROC 0xf0000000 /* reserved bits for processor */ + /* specific segment flags */ + +/* Dynamic structure */ +typedef struct { + Elf32_Sword d_tag; /* controls meaning of d_val */ + union { + Elf32_Word d_val; /* Multiple meanings - see d_tag */ + Elf32_Addr d_ptr; /* program virtual address */ + } d_un; +} Elf32_Dyn; + +typedef struct { + Elf64_Xword d_tag; /* controls meaning of d_val */ + union { + Elf64_Addr d_ptr; + Elf64_Xword d_val; + } d_un; +} Elf64_Dyn; + +/* Dynamic Array Tags - d_tag */ +#define DT_NULL 0 /* marks end of _DYNAMIC array */ +#define DT_NEEDED 1 /* string table offset of needed lib */ +#define DT_PLTRELSZ 2 /* size of relocation entries in PLT */ +#define DT_PLTGOT 3 /* address PLT/GOT */ +#define DT_HASH 4 /* address of symbol hash table */ +#define DT_STRTAB 5 /* address of string table */ +#define DT_SYMTAB 6 /* address of symbol table */ +#define DT_RELA 7 /* address of relocation table */ +#define DT_RELASZ 8 /* size of relocation table */ +#define DT_RELAENT 9 /* size of relocation entry */ +#define DT_STRSZ 10 /* size of string table */ +#define DT_SYMENT 11 /* size of symbol table entry */ +#define DT_INIT 12 /* address of initialization func. */ +#define DT_FINI 13 /* address of termination function */ +#define DT_SONAME 14 /* string table offset of shared obj */ +#define DT_RPATH 15 /* string table offset of library + * search path */ +#define DT_SYMBOLIC 16 /* start sym search in shared obj. */ +#define DT_REL 17 /* address of rel. tbl. w addends */ +#define DT_RELSZ 18 /* size of DT_REL relocation table */ +#define DT_RELENT 19 /* size of DT_REL relocation entry */ +#define DT_PLTREL 20 /* PLT referenced relocation entry */ +#define DT_DEBUG 21 /* bugger */ +#define DT_TEXTREL 22 /* Allow rel. mod. to unwritable seg */ +#define DT_JMPREL 23 /* add. of PLT's relocation entries */ +#define DT_BIND_NOW 24 /* Bind now regardless of env setting */ +#define DT_LOOS 0x6000000d /* reserved range for OS */ +#define DT_HIOS 0x6ffff000 /* specific dynamic array tags */ +#define DT_LOPROC 0x70000000 /* reserved range for processor */ +#define DT_HIPROC 0x7fffffff /* specific dynamic array tags */ + +/* some other useful tags */ +#define DT_RELACOUNT 0x6ffffff9 /* if present, number of RELATIVE */ +#define DT_RELCOUNT 0x6ffffffa /* relocs, which must come first */ +#define DT_FLAGS_1 0x6ffffffb + +/* Dynamic Flags - DT_FLAGS_1 .dynamic entry */ +#define DF_1_NOW 0x00000001 +#define DF_1_GLOBAL 0x00000002 +#define DF_1_GROUP 0x00000004 +#define DF_1_NODELETE 0x00000008 +#define DF_1_LOADFLTR 0x00000010 +#define DF_1_INITFIRST 0x00000020 +#define DF_1_NOOPEN 0x00000040 +#define DF_1_ORIGIN 0x00000080 +#define DF_1_DIRECT 0x00000100 +#define DF_1_TRANS 0x00000200 +#define DF_1_INTERPOSE 0x00000400 +#define DF_1_NODEFLIB 0x00000800 +#define DF_1_NODUMP 0x00001000 +#define DF_1_CONLFAT 0x00002000 + +/* ld.so: number of low tags that are used saved internally (0 .. DT_NUM-1) */ +#define DT_NUM (DT_JMPREL+1) + +/* + * Note Definitions + */ +typedef struct { + Elf32_Word namesz; + Elf32_Word descsz; + Elf32_Word type; +} Elf32_Note; + +typedef struct { + Elf64_Half namesz; + Elf64_Half descsz; + Elf64_Half type; +} Elf64_Note; + +#if defined(ELFSIZE) && (ELFSIZE == 32) +#define Elf_Ehdr Elf32_Ehdr +#define Elf_Phdr Elf32_Phdr +#define Elf_Shdr Elf32_Shdr +#define Elf_Sym Elf32_Sym +#define Elf_Rel Elf32_Rel +#define Elf_RelA Elf32_Rela +#define Elf_Dyn Elf32_Dyn +#define Elf_Half Elf32_Half +#define Elf_Word Elf32_Word +#define Elf_Sword Elf32_Sword +#define Elf_Addr Elf32_Addr +#define Elf_Off Elf32_Off +#define Elf_Nhdr Elf32_Nhdr +#define Elf_Note Elf32_Note + +#define ELF_R_SYM ELF32_R_SYM +#define ELF_R_TYPE ELF32_R_TYPE +#define ELF_R_INFO ELF32_R_INFO +#define ELFCLASS ELFCLASS32 + +#define ELF_ST_BIND ELF32_ST_BIND +#define ELF_ST_TYPE ELF32_ST_TYPE +#define ELF_ST_INFO ELF32_ST_INFO + +#elif defined(ELFSIZE) && (ELFSIZE == 64) + +#define Elf_Ehdr Elf64_Ehdr +#define Elf_Phdr Elf64_Phdr +#define Elf_Shdr Elf64_Shdr +#define Elf_Sym Elf64_Sym +#define Elf_Rel Elf64_Rel +#define Elf_RelA Elf64_Rela +#define Elf_Dyn Elf64_Dyn +#define Elf_Half Elf64_Half +#define Elf_Word Elf64_Word +#define Elf_Sword Elf64_Sword +#define Elf_Addr Elf64_Addr +#define Elf_Off Elf64_Off +#define Elf_Nhdr Elf64_Nhdr +#define Elf_Note Elf64_Note + +#define ELF_R_SYM ELF64_R_SYM +#define ELF_R_TYPE ELF64_R_TYPE +#define ELF_R_INFO ELF64_R_INFO +#define ELFCLASS ELFCLASS64 + +#define ELF_ST_BIND ELF64_ST_BIND +#define ELF_ST_TYPE ELF64_ST_TYPE +#define ELF_ST_INFO ELF64_ST_INFO + +#endif + +#endif \ No newline at end of file diff --git a/modules/cpu/src/runtime/elfarch.h b/modules/cpu/src/runtime/elfarch.h new file mode 100644 index 0000000000..11f666c984 --- /dev/null +++ b/modules/cpu/src/runtime/elfarch.h @@ -0,0 +1,35 @@ +#ifndef ELFARCH_H +#define ELFARCH_H + +#if defined(__i386__) +#define EM_THIS EM_386 +#define EL_ARCH_USES_REL +#elif defined(__amd64__) +#define EM_THIS EM_AMD64 +#define EL_ARCH_USES_RELA +#elif defined(__arm__) +#define EM_THIS EM_ARM +#elif defined(__aarch64__) +#define EM_THIS EM_AARCH64 +#define EL_ARCH_USES_RELA +#define EL_ARCH_USES_REL +#elif defined(__riscv) +#define EM_THIS EM_RISCV +#define EL_ARCH_USES_RELA +#else +#error specify your ELF architecture +#endif + +#if defined(__LP64__) || defined(__LLP64__) +#define ELFSIZE 64 +#else +#define ELFSIZE 32 +#endif + +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define ELFDATATHIS ELFDATA2LSB +#else +#define ELFDATATHIS ELFDATA2MSB +#endif + +#endif diff --git a/modules/cpu/src/runtime/elfload.cpp b/modules/cpu/src/runtime/elfload.cpp new file mode 100644 index 0000000000..b1670da065 --- /dev/null +++ b/modules/cpu/src/runtime/elfload.cpp @@ -0,0 +1,269 @@ +#include "elfload.h" +#include +#include + +el_status el_pread(el_ctx *ctx, void *def, size_t nb, size_t offset) +{ + return ctx->pread(ctx, def, nb, offset) ? EL_OK : EL_EIO; +} + +#define EL_PHOFF(ctx, num) (((ctx)->ehdr.e_phoff + (num) * (ctx)->ehdr.e_phentsize)) + +el_status el_findphdr(el_ctx *ctx, Elf_Phdr *phdr, uint32_t type, unsigned *i) +{ + el_status rv = EL_OK; + for (; *i < ctx->ehdr.e_phnum; (*i)++) { + if ((rv = el_pread(ctx, phdr, sizeof *phdr, EL_PHOFF(ctx, *i)))) + return rv; + + if (phdr->p_type == type) { + return rv; + } + } + + *i = -1; + return rv; +} + +el_status el_init(el_ctx *ctx) +{ + el_status rv = EL_OK; + if ((rv = el_pread(ctx, &ctx->ehdr, sizeof ctx->ehdr, 0))) + return rv; + + /* validate header */ + + if (!IS_ELF(ctx->ehdr)) + return EL_NOTELF; + + + if (ctx->ehdr.e_ident[EI_CLASS] != ELFCLASS) + return EL_WRONGBITS; + + if (ctx->ehdr.e_ident[EI_DATA] != ELFDATATHIS) + return EL_WRONGENDIAN; + + if (ctx->ehdr.e_ident[EI_VERSION] != EV_CURRENT) + return EL_NOTELF; + +#if 0 + /* gandr binaries use the STANDALONE ABI */ + if (ctx->ehdr.e_ident[EI_OSABI] != ELFOSABI_STANDALONE) + return EL_WRONGOS; + + /* G is for Gandr + if (ctx->ehdr.e_ident[EI_ABIVERSION] != 'G') + return EL_WRONGOS; */ +#endif + + if (ctx->ehdr.e_type != ET_EXEC && ctx->ehdr.e_type != ET_DYN) + return EL_NOTEXEC; + + if (ctx->ehdr.e_machine != EM_THIS) + return EL_WRONGARCH; + + if (ctx->ehdr.e_version != EV_CURRENT) + return EL_NOTELF; + + /* load phdrs */ + Elf_Phdr ph; + + /* iterate through, calculate extents */ + ctx->base_load_paddr = ctx->base_load_vaddr = 0; + ctx->align = 1; + ctx->memsz = 0; + + unsigned i = 0; + for(;;) { + if ((rv = el_findphdr(ctx, &ph, PT_LOAD, &i))) + return rv; + + if (i == (unsigned) -1) + break; + + Elf_Addr phend = ph.p_vaddr + ph.p_memsz; + if (phend > ctx->memsz) + ctx->memsz = phend; + + if (ph.p_align > ctx->align) + ctx->align = ph.p_align; + + i++; + } + + if (ctx->ehdr.e_type == ET_DYN) { + i = 0; + + if ((rv = el_findphdr(ctx, &ph, PT_DYNAMIC, &i))) + return rv; + + if (i == (unsigned) -1) + return EL_NODYN; + + ctx->dynoff = ph.p_offset; + ctx->dynsize = ph.p_filesz; + } else { + ctx->dynoff = 0; + ctx->dynsize = 0; + } + + return rv; +} + +/* +typedef void* (*el_alloc_cb)( + el_ctx *ctx, + Elf_Addr phys, + Elf_Addr virt, + Elf_Addr size); +*/ + +el_status el_load(el_ctx *ctx, el_alloc_cb alloc) +{ + el_status rv = EL_OK; + + /* address deltas */ + Elf_Addr pdelta = ctx->base_load_paddr; + Elf_Addr vdelta = ctx->base_load_vaddr; + + /* iterate paddrs */ + Elf_Phdr ph; + unsigned i = 0; + for(;;) { + if ((rv = el_findphdr(ctx, &ph, PT_LOAD, &i))) + return rv; + + if (i == (unsigned) -1) + break; + + Elf_Addr pload = ph.p_paddr + pdelta; + Elf_Addr vload = ph.p_vaddr + vdelta; + + /* allocate mem */ + char *dest = (char *)alloc(ctx, pload, vload, ph.p_memsz); + if (!dest) + return EL_ENOMEM; + + printf("Loading seg fileoff %lx, vaddr %lx to %lx\n", + ph.p_offset, ph.p_vaddr, (uintptr_t)dest); + + /* read loaded portion */ + if ((rv = el_pread(ctx, dest, ph.p_filesz, ph.p_offset))) + return rv; + + /* zero mem-only portion */ + memset(dest + ph.p_filesz, 0, ph.p_memsz - ph.p_filesz); + + i++; + } + + return rv; +} + +el_status el_finddyn(el_ctx *ctx, Elf_Dyn *dyn, uint32_t tag) +{ + el_status rv = EL_OK; + size_t ndyn = ctx->dynsize / sizeof(Elf_Dyn); + + for(unsigned i = 0; i < ndyn; i++) { + if ((rv = el_pread(ctx, dyn, sizeof *dyn, ctx->dynoff + i * sizeof *dyn))) + return rv; + + if (dyn->d_tag == tag) + return EL_OK; + } + + dyn->d_tag = DT_NULL; + return EL_OK; +} + +el_status el_findrelocs(el_ctx *ctx, el_relocinfo *ri, uint32_t type) +{ + el_status rv = EL_OK; + + Elf_Dyn rel, relsz, relent; + + if ((rv = el_finddyn(ctx, &rel, type))) + return rv; + + if ((rv = el_finddyn(ctx, &relsz, type + 1))) + return rv; + + if ((rv = el_finddyn(ctx, &relent, type + 2))) + return rv; + + if (rel.d_tag == DT_NULL + || relsz.d_tag == DT_NULL + || relent.d_tag == DT_NULL) { + ri->entrysize = 0; + ri->tablesize = 0; + ri->tableoff = 0; + } else { + ri->tableoff = rel.d_un.d_ptr; + ri->tablesize = relsz.d_un.d_val; + ri->entrysize = relent.d_un.d_val; + } + + return rv; +} + +extern el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel); +extern el_status el_applyrela(el_ctx *ctx, Elf_RelA *rela); + +el_status el_relocate(el_ctx *ctx) +{ + el_status rv = EL_OK; + + // not dynamic + if (ctx->ehdr.e_type != ET_DYN) + return EL_OK; + + char *base = (char *) ctx->base_load_paddr; + + el_relocinfo ri; +#ifdef EL_ARCH_USES_REL + if ((rv = el_findrelocs(ctx, &ri, DT_REL))) + return rv; + + if (ri.entrysize != sizeof(Elf_Rel) && ri.tablesize) { + EL_DEBUG("Relocation size %u doesn't match expected %u\n", + ri.entrysize, sizeof(Elf_Rel)); + return EL_BADREL; + } + + size_t relcnt = ri.tablesize / sizeof(Elf_Rel); + Elf_Rel *reltab = base + ri.tableoff; + for (size_t i = 0; i < relcnt; i++) { + if ((rv = el_applyrel(ctx, &reltab[i]))) + return rv; + } +#else + EL_DEBUG("Architecture doesn't use REL\n"); +#endif + +#ifdef EL_ARCH_USES_RELA + if ((rv = el_findrelocs(ctx, &ri, DT_RELA))) + return rv; + + if (ri.entrysize != sizeof(Elf_RelA) && ri.tablesize) { + EL_DEBUG("Relocation size %u doesn't match expected %u\n", + ri.entrysize, sizeof(Elf_RelA)); + return EL_BADREL; + } + + size_t relacnt = ri.tablesize / sizeof(Elf_RelA); + Elf_RelA *relatab = (Elf_RelA *)(base + ri.tableoff); + for (size_t i = 0; i < relacnt; i++) { + if ((rv = el_applyrela(ctx, &relatab[i]))) + return rv; + } +#else + EL_DEBUG("Architecture doesn't use RELA\n"); +#endif + +#if !defined(EL_ARCH_USES_REL) && !defined(EL_ARCH_USES_RELA) + #error No relocation type defined! +#endif + + return rv; +} diff --git a/modules/cpu/src/runtime/elfload.h b/modules/cpu/src/runtime/elfload.h new file mode 100644 index 0000000000..148b9b3a72 --- /dev/null +++ b/modules/cpu/src/runtime/elfload.h @@ -0,0 +1,99 @@ +#ifndef ELFLOAD_H +#define ELFLOAD_H +#include +#include +#include "elfarch.h" +#include "elf.h" + +#ifdef DEBUG +#include +#define EL_DEBUG(...) printf(__VA_ARGS__) +#else +#define EL_DEBUG(...) do {} while(0) +#endif + +typedef enum { + EL_OK = 0, + + EL_EIO, + EL_ENOMEM, + + EL_NOTELF, + EL_WRONGBITS, + EL_WRONGENDIAN, + EL_WRONGARCH, + EL_WRONGOS, + EL_NOTEXEC, + EL_NODYN, + EL_BADREL, + +} el_status; + +typedef struct el_ctx { + bool (*pread)(struct el_ctx *ctx, void *dest, size_t nb, size_t offset); + + /* base_load_* -> address we are actually going to load at + */ + Elf_Addr + base_load_paddr, + base_load_vaddr; + + /* size in memory of binary */ + Elf_Addr memsz; + + /* required alignment */ + Elf_Addr align; + + /* ELF header */ + Elf_Ehdr ehdr; + + /* Offset of dynamic table (0 if not ET_DYN) */ + Elf_Off dynoff; + /* Size of dynamic table (0 if not ET_DYN) */ + Elf_Addr dynsize; + + void *elf; +} el_ctx; + +el_status el_pread(el_ctx *ctx, void *def, size_t nb, size_t offset); + +el_status el_init(el_ctx *ctx); +typedef void* (*el_alloc_cb)( + el_ctx *ctx, + Elf_Addr phys, + Elf_Addr virt, + Elf_Addr size); + +el_status el_load(el_ctx *ctx, el_alloc_cb alloccb); + +/* find the next phdr of type \p type, starting at \p *i. + * On success, returns EL_OK with *i set to the phdr number, and the phdr loaded + * in *phdr. + * + * If the end of the phdrs table was reached, *i is set to -1 and the contents + * of *phdr are undefined + */ +el_status el_findphdr(el_ctx *ctx, Elf_Phdr *phdr, uint32_t type, unsigned *i); + +/* Relocate the loaded executable */ +el_status el_relocate(el_ctx *ctx); + +/* find a dynamic table entry + * returns the entry on success, dyn->d_tag = DT_NULL on failure + */ +el_status el_finddyn(el_ctx *ctx, Elf_Dyn *dyn, uint32_t type); + +typedef struct { + Elf_Off tableoff; + Elf_Addr tablesize; + Elf_Addr entrysize; +} el_relocinfo; + +/* find all information regarding relocations of a specific type. + * + * pass DT_REL or DT_RELA for type + * sets ri->entrysize = 0 if not found + */ +el_status el_findrelocs(el_ctx *ctx, el_relocinfo *ri, uint32_t type); + +#endif \ No newline at end of file diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp new file mode 100644 index 0000000000..8bced71e14 --- /dev/null +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -0,0 +1,40 @@ +#include "elfloader.h" + +using namespace nncase; +using namespace nncase::runtime; +using namespace nncase::runtime::cpu; + +int elfloader::invoke_elf(size_t id, buffer_t **buffers, + nncase_mt_t *nncase_mt, void *data, void *rdata) { + + check(el_init(&ctx_), "initialising"); + + // align to ctx.align + ptr_ = malloc(ctx_.memsz + ctx_.align); + buf_ = (void *)(((size_t)ptr_ + (ctx_.align - 1)) & ~(ctx_.align - 1)); + +#if defined(__linux__) + if (mprotect(buf_, ctx_.memsz, PROT_READ | PROT_WRITE | PROT_EXEC)) { + perror("mprotect"); + return 1; + } +#endif + + ctx_.base_load_vaddr = ctx_.base_load_paddr = (uintptr_t)buf_; + + check(el_load(&ctx_, alloccb), "loading"); + check(el_relocate(&ctx_), "relocating"); + + uintptr_t epaddr = ctx_.ehdr.e_entry + (uintptr_t)buf_; + + entrypoint_t ep = (entrypoint_t)epaddr; + + printf("Binary entrypoint is %" PRIxPTR "; invoking %p\n", + (uintptr_t)ctx_.ehdr.e_entry, (void *)epaddr); + + ep(id, buffers, nncase_mt, data, rdata); + + free(ptr_); + + return 0; +} diff --git a/modules/cpu/src/runtime/elfloader.h b/modules/cpu/src/runtime/elfloader.h new file mode 100644 index 0000000000..771a6b5f16 --- /dev/null +++ b/modules/cpu/src/runtime/elfloader.h @@ -0,0 +1,62 @@ +#pragma once +#include "cpu_common.h" +#include "elfload.h" +#include +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#endif + +BEGIN_NS_NNCASE_RT_MODULE(cpu) + +typedef void (*entrypoint_t)(size_t id, buffer_t **buffers, + nncase_mt_t *nncase_mt, void *data, void *rdata); + +class elfloader { + public: + elfloader(char *elf) : elf_(elf) { + ctx_.pread = bpread; + ctx_.elf = elf; + } + + // typedef void (*entrypoint_t)(float (*op_t)(float), float *, float *, + // int); + + static bool bpread(el_ctx *ctx, void *dest, size_t nb, size_t offset) { + (void)ctx; + + memcpy(dest, (char *)ctx->elf + offset, nb); + + return true; + } + + static void *alloccb(el_ctx *ctx, Elf_Addr phys, Elf_Addr virt, + Elf_Addr size) { + (void)ctx; + (void)phys; + (void)size; + return (void *)virt; + } + + static void check(el_status stat, const char *expln) { + if (stat) { + fprintf(stderr, "%s: error %d\n", expln, stat); + exit(1); + } + } + + int invoke_elf(size_t id, buffer_t **buffers, nncase_mt_t *nncase_mt, + void *data, void *rdata); + + private: + void *ptr_; + void *buf_; + char *elf_; + el_ctx ctx_; +}; + +END_NS_NNCASE_RT_MODULE \ No newline at end of file diff --git a/modules/cpu/src/runtime/elfreloc_aarch64.cpp b/modules/cpu/src/runtime/elfreloc_aarch64.cpp new file mode 100644 index 0000000000..e8fd3f8ed7 --- /dev/null +++ b/modules/cpu/src/runtime/elfreloc_aarch64.cpp @@ -0,0 +1,67 @@ +#include "elfload.h" + +#if defined(__aarch64__) + +#define R_AARCH64_NONE 0 +#define R_AARCH64_RELATIVE 1027 + +el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) +{ + uintptr_t *p = (uintptr_t*) (rel->r_offset + ctx->base_load_paddr); + uint32_t type = ELF_R_TYPE(rel->r_info); + uint32_t sym = ELF_R_SYM(rel->r_info); + + switch (type) { + case R_AARCH64_NONE: + EL_DEBUG("R_AARCH64_NONE\n"); + break; + case R_AARCH64_RELATIVE: + if (sym) { + EL_DEBUG("R_AARCH64_RELATIVE with symbol ref!\n"); + return EL_BADREL; + } + + EL_DEBUG("Applying R_AARCH64_RELATIVE reloc @%p\n", p); + *p = rel->r_addend + ctx->base_load_vaddr; + break; + + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; + + } + + return EL_OK; +} + +el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel) +{ + uintptr_t *p = (uintptr_t*) (rel->r_offset + ctx->base_load_paddr); + uint32_t type = ELF_R_TYPE(rel->r_info); + uint32_t sym = ELF_R_SYM(rel->r_info); + + switch (type) { + case R_AARCH64_NONE: + EL_DEBUG("R_AARCH64_NONE\n"); + break; + case R_AARCH64_RELATIVE: + if (sym) { + EL_DEBUG("R_AARCH64_RELATIVE with symbol ref!\n"); + return EL_BADREL; + } + + EL_DEBUG("Applying R_AARCH64_RELATIVE reloc @%p\n", p); + *p += ctx->base_load_vaddr; + break; + + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; + + } + + return EL_OK; +} + + +#endif diff --git a/modules/cpu/src/runtime/elfreloc_amd64.cpp b/modules/cpu/src/runtime/elfreloc_amd64.cpp new file mode 100644 index 0000000000..8b32082ff3 --- /dev/null +++ b/modules/cpu/src/runtime/elfreloc_amd64.cpp @@ -0,0 +1,28 @@ +#include "elfload.h" + +#if defined(__amd64__) + +#define R_AMD64_NONE 0 +#define R_AMD64_RELATIVE 8 + +el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) +{ + uint64_t *p = (uint64_t*) (rel->r_offset + ctx->base_load_vaddr); + uint32_t type = ELF_R_TYPE(rel->r_info); + + switch (type) { + case R_AMD64_NONE: break; + case R_AMD64_RELATIVE: + EL_DEBUG("Applying R_AMD64_RELATIVE reloc @%p\n", p); + *p = rel->r_addend + ctx->base_load_vaddr; + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; + + } + + return EL_OK; +} + +#endif diff --git a/modules/cpu/src/runtime/elfreloc_i386.cpp b/modules/cpu/src/runtime/elfreloc_i386.cpp new file mode 100644 index 0000000000..fdf66e3d3f --- /dev/null +++ b/modules/cpu/src/runtime/elfreloc_i386.cpp @@ -0,0 +1,28 @@ +#include "elfload.h" + +#if defined(__i386__) + +#define R_386_NONE 0 +#define R_386_RELATIVE 8 + +el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel) +{ + uint32_t *p = (uint32_t*) (rel->r_offset + ctx->base_load_vaddr); + uint32_t type = ELF_R_TYPE(rel->r_info); + uint32_t sym = ELF_R_SYM(rel->r_info); + + switch (type) { + case R_386_NONE: break; + case R_386_RELATIVE: + EL_DEBUG("Applying R_386_RELATIVE reloc @%p\n", p); + *p += ctx->base_load_vaddr; + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; + } + + return EL_OK; +} + +#endif diff --git a/modules/cpu/src/runtime/elfreloc_riscv64.cpp b/modules/cpu/src/runtime/elfreloc_riscv64.cpp new file mode 100644 index 0000000000..a4d99a64ef --- /dev/null +++ b/modules/cpu/src/runtime/elfreloc_riscv64.cpp @@ -0,0 +1,33 @@ +#include "elfload.h" + +#if defined(__riscv) + +#define R_riscv64_NONE 0 +#define R_riscv64_RELATIVE 3 +#define R_riscv64_JUMP_SLOT 5 + +el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) +{ + uint64_t *p = (uint64_t*) (rel->r_offset + ctx->base_load_vaddr); + uint32_t type = ELF_R_TYPE(rel->r_info); + EL_DEBUG("rv\n"); + + switch (type) { + case R_riscv64_NONE: break; + case R_riscv64_RELATIVE: + EL_DEBUG("Applying R_riscv64_RELATIVE reloc @%p\n", p); + *p = rel->r_addend + ctx->base_load_vaddr; + break; + case R_riscv64_JUMP_SLOT: + EL_DEBUG("Applying R_riscv64_JUMP_SLOT reloc @%p\n", p); + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; + + } + + return EL_OK; +} + +#endif diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 9934a6985c..2a6be1e344 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -13,9 +13,11 @@ * limitations under the License. */ #include "runtime_function.h" +#include "elfloader.h" #include #include #include +#include using namespace nncase; using namespace nncase::runtime; @@ -54,33 +56,49 @@ result cpu_runtime_function::initialize_core( for (uint32_t i = 0; i < header.inputs; i++) { sr.template read(); auto rank = sr.template read(); + std::vector shape(rank); std::cout << "shape: "; for (uint32_t j = 0; j < rank; j++) { - std::cout << sr.template read() << ", "; + shape[j] = sr.template read(); + std::cout << shape[j] << ", "; } std::cout << std::endl; + std::vector stride(rank); std::cout << "stride: "; for (uint32_t j = 0; j < rank; j++) { - std::cout << sr.template read() << ", "; + stride[j] = sr.template read(); + std::cout << stride[j] << ", "; } std::cout << std::endl; + + input_ranks_.emplace_back(rank); + input_shapes_.emplace_back(shape); + input_strides_.emplace_back(stride); } for (uint32_t i = 0; i < header.outputs; i++) { sr.template read(); auto rank = sr.template read(); + std::vector shape(rank); std::cout << "shape: "; for (uint32_t j = 0; j < rank; j++) { - std::cout << sr.template read() << ", "; + shape[j] = sr.template read(); + std::cout << shape[j] << ", "; } std::cout << std::endl; + std::vector stride(rank); std::cout << "stride: "; for (uint32_t j = 0; j < rank; j++) { - std::cout << sr.template read() << ", "; + stride[j] = sr.template read(); + std::cout << stride[j] << ", "; } std::cout << std::endl; + + output_ranks_.emplace_back(rank); + output_shapes_.emplace_back(shape); + output_strides_.emplace_back(stride); } return ok(); @@ -92,8 +110,42 @@ result cpu_runtime_function::initialize_core( result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, value_t return_value) noexcept { - module().interp(); try_var(id, module().find_id_by_function(this)); std::cout << "call " << id << std::endl; + + std::vector buffers(input_ranks_.size() + output_ranks_.size()); + + // input buffer + for (uint32_t i = 0; i < input_ranks_.size(); i++) { + auto input_tensor = parameters[i].as().expect( + "input " + std::to_string(i) + " is not a tensor"); + try_var(input_span, get_input_span(input_tensor)); + buffer_t *input_buffer = + new buffer_t(input_span.data(), 0, input_shapes_[i].data(), + input_strides_[i].data(), input_ranks_[i]); + buffers[i] = input_buffer; + } + + // output buffer + for (uint32_t i = 0; i < output_ranks_.size(); i++) { + auto output_tensor = parameters[i].as().expect( + "output " + std::to_string(i) + " is not a tensor"); + try_var(output_span, get_output_span(output_tensor)); + buffer_t *output_buffer = + new buffer_t(output_span.data(), 0, output_shapes_[i].data(), + output_strides_[i].data(), output_ranks_[i]); + buffers[input_ranks_.size() + i] = output_buffer; + } + + auto elfloader_ = elfloader{(char *)module().text_physical().data()}; + elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, nullptr); + for (int i = 0; i < 10; i++) { + printf("%f\n", ((float *)buffers[1]->vaddr)[i]); + } + + for (int i = 0; i < buffers.size(); i++) { + delete buffers[i]; + } + return ok(return_value); } \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h index 79a524dfc1..ba19300437 100644 --- a/modules/cpu/src/runtime/runtime_function.h +++ b/modules/cpu/src/runtime/runtime_function.h @@ -30,6 +30,14 @@ class cpu_runtime_function : public runtime_function { initialize_core(runtime_function_init_context &context) noexcept override; result invoke_core(gsl::span parameters, value_t return_value) noexcept override; + + private: + std::vector input_ranks_; + std::vector> input_shapes_; + std::vector> input_strides_; + std::vector output_ranks_; + std::vector> output_shapes_; + std::vector> output_strides_; }; END_NS_NNCASE_RT_MODULE diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index bf4481d159..810b51de4e 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -26,10 +26,12 @@ result cpu_runtime_module::initialize_before_functions( runtime_module_init_context &context) noexcept { if (!context.is_section_pinned()) return nncase::err(std::errc::bad_address); - try_var(data_, context.get_or_read_section(".data", data_storage_, false)); - try_var(rdata_, + try_var(data, context.get_or_read_section(".data", data_storage_, false)); + try_var(rdata, context.get_or_read_section(".rdata", rdata_storage_, true)); - try_var(text_, context.get_or_read_section(".text", text_storage_, true)); + try_var(text, context.get_or_read_section(".text", text_storage_, true)); + + text_ = text.as_span(); return ok(); } From 8a428c60c09ede898a0e399c30df60d9f9938ce7 Mon Sep 17 00:00:00 2001 From: xhuohai Date: Mon, 24 Jul 2023 12:00:55 +0000 Subject: [PATCH 018/308] Apply code-format changes --- .../CodeGen/FunctionCSource.cs | 60 +- modules/cpu/src/runtime/elf.h | 749 +++++++++--------- modules/cpu/src/runtime/elfload.cpp | 71 +- modules/cpu/src/runtime/elfload.h | 27 +- modules/cpu/src/runtime/elfloader.cpp | 4 +- modules/cpu/src/runtime/elfreloc_aarch64.cpp | 75 +- modules/cpu/src/runtime/elfreloc_amd64.cpp | 25 +- modules/cpu/src/runtime/elfreloc_i386.cpp | 26 +- modules/cpu/src/runtime/elfreloc_riscv64.cpp | 31 +- modules/cpu/src/runtime/runtime_module.cpp | 3 +- 10 files changed, 526 insertions(+), 545 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index cd5ff86d9d..e07152e4fc 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -56,6 +56,36 @@ protected string Ext get => _ext; } + /// + /// compile the source txt, write to the out_path. + /// + /// c source code. + /// out .so path. + /// outPath. + public string Compile(string sourcePath, string outPath) + { + var errMsg = new StringBuilder(); + using (var errWriter = new StringWriter(errMsg)) + { + using (var proc = new Process()) + { + proc.StartInfo.FileName = Exe; + proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); + proc.StartInfo.RedirectStandardError = true; + proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); + proc.Start(); + proc.BeginErrorReadLine(); + proc.WaitForExit(); + if (proc.ExitCode != 0) + { + throw new InvalidOperationException(errMsg.ToString()); + } + } + } + + return outPath; + } + /// /// select current pattern's exe. /// @@ -109,36 +139,6 @@ private string ArgumentsSpecific(string sourcePath, string outPath) throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); } - /// - /// compile the source txt, write to the out_path. - /// - /// c source code. - /// out .so path. - /// outPath. - public string Compile(string sourcePath, string outPath) - { - var errMsg = new StringBuilder(); - using (var errWriter = new StringWriter(errMsg)) - { - using (var proc = new Process()) - { - proc.StartInfo.FileName = Exe; - proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); - proc.StartInfo.RedirectStandardError = true; - proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); - proc.Start(); - proc.BeginErrorReadLine(); - proc.WaitForExit(); - if (proc.ExitCode != 0) - { - throw new InvalidOperationException(errMsg.ToString()); - } - } - } - - return outPath; - } - /// /// create the temp dll file and compile source /// . diff --git a/modules/cpu/src/runtime/elf.h b/modules/cpu/src/runtime/elf.h index 027f99ad0c..d7be196fa3 100644 --- a/modules/cpu/src/runtime/elf.h +++ b/modules/cpu/src/runtime/elf.h @@ -31,346 +31,345 @@ #define ELF_H #include -typedef uint8_t Elf_Byte; +typedef uint8_t Elf_Byte; -typedef uint32_t Elf32_Addr; /* Unsigned program address */ -typedef uint32_t Elf32_Off; /* Unsigned file offset */ -typedef int32_t Elf32_Sword; /* Signed large integer */ -typedef uint32_t Elf32_Word; /* Unsigned large integer */ -typedef uint16_t Elf32_Half; /* Unsigned medium integer */ +typedef uint32_t Elf32_Addr; /* Unsigned program address */ +typedef uint32_t Elf32_Off; /* Unsigned file offset */ +typedef int32_t Elf32_Sword; /* Signed large integer */ +typedef uint32_t Elf32_Word; /* Unsigned large integer */ +typedef uint16_t Elf32_Half; /* Unsigned medium integer */ -typedef uint64_t Elf64_Addr; -typedef uint64_t Elf64_Off; -typedef int32_t Elf64_Shalf; +typedef uint64_t Elf64_Addr; +typedef uint64_t Elf64_Off; +typedef int32_t Elf64_Shalf; #ifdef __alpha__ -typedef int64_t Elf64_Sword; -typedef uint64_t Elf64_Word; +typedef int64_t Elf64_Sword; +typedef uint64_t Elf64_Word; #else -typedef int32_t Elf64_Sword; -typedef uint32_t Elf64_Word; +typedef int32_t Elf64_Sword; +typedef uint32_t Elf64_Word; #endif -typedef int64_t Elf64_Sxword; -typedef uint64_t Elf64_Xword; +typedef int64_t Elf64_Sxword; +typedef uint64_t Elf64_Xword; -typedef uint32_t Elf64_Half; -typedef uint16_t Elf64_Quarter; +typedef uint32_t Elf64_Half; +typedef uint16_t Elf64_Quarter; /* * e_ident[] identification indexes * See http://www.sco.com/developers/gabi/latest/ch4.eheader.html */ -#define EI_MAG0 0 /* file ID */ -#define EI_MAG1 1 /* file ID */ -#define EI_MAG2 2 /* file ID */ -#define EI_MAG3 3 /* file ID */ -#define EI_CLASS 4 /* file class */ -#define EI_DATA 5 /* data encoding */ -#define EI_VERSION 6 /* ELF header version */ -#define EI_OSABI 7 /* OS/ABI ID */ -#define EI_ABIVERSION 8 /* ABI version */ -#define EI_PAD 9 /* start of pad bytes */ -#define EI_NIDENT 16 /* Size of e_ident[] */ +#define EI_MAG0 0 /* file ID */ +#define EI_MAG1 1 /* file ID */ +#define EI_MAG2 2 /* file ID */ +#define EI_MAG3 3 /* file ID */ +#define EI_CLASS 4 /* file class */ +#define EI_DATA 5 /* data encoding */ +#define EI_VERSION 6 /* ELF header version */ +#define EI_OSABI 7 /* OS/ABI ID */ +#define EI_ABIVERSION 8 /* ABI version */ +#define EI_PAD 9 /* start of pad bytes */ +#define EI_NIDENT 16 /* Size of e_ident[] */ /* e_ident[] magic number */ -#define ELFMAG0 0x7f /* e_ident[EI_MAG0] */ -#define ELFMAG1 'E' /* e_ident[EI_MAG1] */ -#define ELFMAG2 'L' /* e_ident[EI_MAG2] */ -#define ELFMAG3 'F' /* e_ident[EI_MAG3] */ -#define ELFMAG "\177ELF" /* magic */ -#define SELFMAG 4 /* size of magic */ +#define ELFMAG0 0x7f /* e_ident[EI_MAG0] */ +#define ELFMAG1 'E' /* e_ident[EI_MAG1] */ +#define ELFMAG2 'L' /* e_ident[EI_MAG2] */ +#define ELFMAG3 'F' /* e_ident[EI_MAG3] */ +#define ELFMAG "\177ELF" /* magic */ +#define SELFMAG 4 /* size of magic */ /* e_ident[] file class */ -#define ELFCLASSNONE 0 /* invalid */ -#define ELFCLASS32 1 /* 32-bit objs */ -#define ELFCLASS64 2 /* 64-bit objs */ -#define ELFCLASSNUM 3 /* number of classes */ +#define ELFCLASSNONE 0 /* invalid */ +#define ELFCLASS32 1 /* 32-bit objs */ +#define ELFCLASS64 2 /* 64-bit objs */ +#define ELFCLASSNUM 3 /* number of classes */ /* e_ident[] data encoding */ -#define ELFDATANONE 0 /* invalid */ -#define ELFDATA2LSB 1 /* Little-Endian */ -#define ELFDATA2MSB 2 /* Big-Endian */ -#define ELFDATANUM 3 /* number of data encode defines */ +#define ELFDATANONE 0 /* invalid */ +#define ELFDATA2LSB 1 /* Little-Endian */ +#define ELFDATA2MSB 2 /* Big-Endian */ +#define ELFDATANUM 3 /* number of data encode defines */ /* e_ident[] Operating System/ABI */ -#define ELFOSABI_SYSV 0 /* UNIX System V ABI */ -#define ELFOSABI_HPUX 1 /* HP-UX operating system */ -#define ELFOSABI_NETBSD 2 /* NetBSD */ -#define ELFOSABI_LINUX 3 /* GNU/Linux */ -#define ELFOSABI_HURD 4 /* GNU/Hurd */ -#define ELFOSABI_86OPEN 5 /* 86Open common IA32 ABI */ -#define ELFOSABI_SOLARIS 6 /* Solaris */ -#define ELFOSABI_MONTEREY 7 /* Monterey */ -#define ELFOSABI_IRIX 8 /* IRIX */ -#define ELFOSABI_FREEBSD 9 /* FreeBSD */ -#define ELFOSABI_TRU64 10 /* TRU64 UNIX */ -#define ELFOSABI_MODESTO 11 /* Novell Modesto */ -#define ELFOSABI_OPENBSD 12 /* OpenBSD */ -#define ELFOSABI_ARM 97 /* ARM */ -#define ELFOSABI_STANDALONE 255 /* Standalone (embedded) application */ +#define ELFOSABI_SYSV 0 /* UNIX System V ABI */ +#define ELFOSABI_HPUX 1 /* HP-UX operating system */ +#define ELFOSABI_NETBSD 2 /* NetBSD */ +#define ELFOSABI_LINUX 3 /* GNU/Linux */ +#define ELFOSABI_HURD 4 /* GNU/Hurd */ +#define ELFOSABI_86OPEN 5 /* 86Open common IA32 ABI */ +#define ELFOSABI_SOLARIS 6 /* Solaris */ +#define ELFOSABI_MONTEREY 7 /* Monterey */ +#define ELFOSABI_IRIX 8 /* IRIX */ +#define ELFOSABI_FREEBSD 9 /* FreeBSD */ +#define ELFOSABI_TRU64 10 /* TRU64 UNIX */ +#define ELFOSABI_MODESTO 11 /* Novell Modesto */ +#define ELFOSABI_OPENBSD 12 /* OpenBSD */ +#define ELFOSABI_ARM 97 /* ARM */ +#define ELFOSABI_STANDALONE 255 /* Standalone (embedded) application */ /* e_ident */ -#define IS_ELF(ehdr) ((ehdr).e_ident[EI_MAG0] == ELFMAG0 && \ - (ehdr).e_ident[EI_MAG1] == ELFMAG1 && \ - (ehdr).e_ident[EI_MAG2] == ELFMAG2 && \ - (ehdr).e_ident[EI_MAG3] == ELFMAG3) +#define IS_ELF(ehdr) \ + ((ehdr).e_ident[EI_MAG0] == ELFMAG0 && \ + (ehdr).e_ident[EI_MAG1] == ELFMAG1 && \ + (ehdr).e_ident[EI_MAG2] == ELFMAG2 && (ehdr).e_ident[EI_MAG3] == ELFMAG3) /* ELF Header */ typedef struct { - unsigned char e_ident[EI_NIDENT]; /* ELF Identification */ - Elf32_Half e_type; /* object file type */ - Elf32_Half e_machine; /* machine */ - Elf32_Word e_version; /* object file version */ - Elf32_Addr e_entry; /* virtual entry point */ - Elf32_Off e_phoff; /* program header table offset */ - Elf32_Off e_shoff; /* section header table offset */ - Elf32_Word e_flags; /* processor-specific flags */ - Elf32_Half e_ehsize; /* ELF header size */ - Elf32_Half e_phentsize; /* program header entry size */ - Elf32_Half e_phnum; /* number of program header entries */ - Elf32_Half e_shentsize; /* section header entry size */ - Elf32_Half e_shnum; /* number of section header entries */ - Elf32_Half e_shstrndx; /* section header table's "section - header string table" entry offset */ + unsigned char e_ident[EI_NIDENT]; /* ELF Identification */ + Elf32_Half e_type; /* object file type */ + Elf32_Half e_machine; /* machine */ + Elf32_Word e_version; /* object file version */ + Elf32_Addr e_entry; /* virtual entry point */ + Elf32_Off e_phoff; /* program header table offset */ + Elf32_Off e_shoff; /* section header table offset */ + Elf32_Word e_flags; /* processor-specific flags */ + Elf32_Half e_ehsize; /* ELF header size */ + Elf32_Half e_phentsize; /* program header entry size */ + Elf32_Half e_phnum; /* number of program header entries */ + Elf32_Half e_shentsize; /* section header entry size */ + Elf32_Half e_shnum; /* number of section header entries */ + Elf32_Half e_shstrndx; /* section header table's "section + header string table" entry offset */ } Elf32_Ehdr; typedef struct { - unsigned char e_ident[EI_NIDENT]; /* Id bytes */ - Elf64_Quarter e_type; /* file type */ - Elf64_Quarter e_machine; /* machine type */ - Elf64_Half e_version; /* version number */ - Elf64_Addr e_entry; /* entry point */ - Elf64_Off e_phoff; /* Program hdr offset */ - Elf64_Off e_shoff; /* Section hdr offset */ - Elf64_Half e_flags; /* Processor flags */ - Elf64_Quarter e_ehsize; /* sizeof ehdr */ - Elf64_Quarter e_phentsize; /* Program header entry size */ - Elf64_Quarter e_phnum; /* Number of program headers */ - Elf64_Quarter e_shentsize; /* Section header entry size */ - Elf64_Quarter e_shnum; /* Number of section headers */ - Elf64_Quarter e_shstrndx; /* String table index */ + unsigned char e_ident[EI_NIDENT]; /* Id bytes */ + Elf64_Quarter e_type; /* file type */ + Elf64_Quarter e_machine; /* machine type */ + Elf64_Half e_version; /* version number */ + Elf64_Addr e_entry; /* entry point */ + Elf64_Off e_phoff; /* Program hdr offset */ + Elf64_Off e_shoff; /* Section hdr offset */ + Elf64_Half e_flags; /* Processor flags */ + Elf64_Quarter e_ehsize; /* sizeof ehdr */ + Elf64_Quarter e_phentsize; /* Program header entry size */ + Elf64_Quarter e_phnum; /* Number of program headers */ + Elf64_Quarter e_shentsize; /* Section header entry size */ + Elf64_Quarter e_shnum; /* Number of section headers */ + Elf64_Quarter e_shstrndx; /* String table index */ } Elf64_Ehdr; /* e_type */ -#define ET_NONE 0 /* No file type */ -#define ET_REL 1 /* relocatable file */ -#define ET_EXEC 2 /* executable file */ -#define ET_DYN 3 /* shared object file */ -#define ET_CORE 4 /* core file */ -#define ET_NUM 5 /* number of types */ -#define ET_LOPROC 0xff00 /* reserved range for processor */ -#define ET_HIPROC 0xffff /* specific e_type */ +#define ET_NONE 0 /* No file type */ +#define ET_REL 1 /* relocatable file */ +#define ET_EXEC 2 /* executable file */ +#define ET_DYN 3 /* shared object file */ +#define ET_CORE 4 /* core file */ +#define ET_NUM 5 /* number of types */ +#define ET_LOPROC 0xff00 /* reserved range for processor */ +#define ET_HIPROC 0xffff /* specific e_type */ /* e_machine */ -#define EM_NONE 0 /* No Machine */ -#define EM_M32 1 /* AT&T WE 32100 */ -#define EM_SPARC 2 /* SPARC */ -#define EM_386 3 /* Intel 80386 */ -#define EM_68K 4 /* Motorola 68000 */ -#define EM_88K 5 /* Motorola 88000 */ -#define EM_486 6 /* Intel 80486 - unused? */ -#define EM_860 7 /* Intel 80860 */ -#define EM_MIPS 8 /* MIPS R3000 Big-Endian only */ +#define EM_NONE 0 /* No Machine */ +#define EM_M32 1 /* AT&T WE 32100 */ +#define EM_SPARC 2 /* SPARC */ +#define EM_386 3 /* Intel 80386 */ +#define EM_68K 4 /* Motorola 68000 */ +#define EM_88K 5 /* Motorola 88000 */ +#define EM_486 6 /* Intel 80486 - unused? */ +#define EM_860 7 /* Intel 80860 */ +#define EM_MIPS 8 /* MIPS R3000 Big-Endian only */ /* * Don't know if EM_MIPS_RS4_BE, * EM_SPARC64, EM_PARISC, * or EM_PPC are ABI compliant */ -#define EM_MIPS_RS4_BE 10 /* MIPS R4000 Big-Endian */ -#define EM_SPARC64 11 /* SPARC v9 64-bit unofficial */ -#define EM_PARISC 15 /* HPPA */ -#define EM_SPARC32PLUS 18 /* Enhanced instruction set SPARC */ -#define EM_PPC 20 /* PowerPC */ -#define EM_ARM 40 /* ARM AArch32 */ -#define EM_ALPHA 41 /* DEC ALPHA */ -#define EM_SH 42 /* Hitachi/Renesas Super-H */ -#define EM_SPARCV9 43 /* SPARC version 9 */ -#define EM_IA_64 50 /* Intel IA-64 Processor */ -#define EM_AMD64 62 /* AMD64 architecture */ -#define EM_VAX 75 /* DEC VAX */ -#define EM_AARCH64 183 /* ARM AArch64 */ +#define EM_MIPS_RS4_BE 10 /* MIPS R4000 Big-Endian */ +#define EM_SPARC64 11 /* SPARC v9 64-bit unofficial */ +#define EM_PARISC 15 /* HPPA */ +#define EM_SPARC32PLUS 18 /* Enhanced instruction set SPARC */ +#define EM_PPC 20 /* PowerPC */ +#define EM_ARM 40 /* ARM AArch32 */ +#define EM_ALPHA 41 /* DEC ALPHA */ +#define EM_SH 42 /* Hitachi/Renesas Super-H */ +#define EM_SPARCV9 43 /* SPARC version 9 */ +#define EM_IA_64 50 /* Intel IA-64 Processor */ +#define EM_AMD64 62 /* AMD64 architecture */ +#define EM_VAX 75 /* DEC VAX */ +#define EM_AARCH64 183 /* ARM AArch64 */ /* Non-standard */ #define EM_ALPHA_EXP 0x9026 /* DEC ALPHA */ #define EM_RISCV 243 - /* Version */ -#define EV_NONE 0 /* Invalid */ -#define EV_CURRENT 1 /* Current */ -#define EV_NUM 2 /* number of versions */ +#define EV_NONE 0 /* Invalid */ +#define EV_CURRENT 1 /* Current */ +#define EV_NUM 2 /* number of versions */ /* Section Header */ typedef struct { - Elf32_Word sh_name; /* name - index into section header - * string table section */ - Elf32_Word sh_type; /* type */ - Elf32_Word sh_flags; /* flags */ - Elf32_Addr sh_addr; /* address */ - Elf32_Off sh_offset; /* file offset */ - Elf32_Word sh_size; /* section size */ - Elf32_Word sh_link; /* section header table index link */ - Elf32_Word sh_info; /* extra information */ - Elf32_Word sh_addralign; /* address alignment */ - Elf32_Word sh_entsize; /* section entry size */ + Elf32_Word sh_name; /* name - index into section header + * string table section */ + Elf32_Word sh_type; /* type */ + Elf32_Word sh_flags; /* flags */ + Elf32_Addr sh_addr; /* address */ + Elf32_Off sh_offset; /* file offset */ + Elf32_Word sh_size; /* section size */ + Elf32_Word sh_link; /* section header table index link */ + Elf32_Word sh_info; /* extra information */ + Elf32_Word sh_addralign; /* address alignment */ + Elf32_Word sh_entsize; /* section entry size */ } Elf32_Shdr; typedef struct { - Elf64_Half sh_name; /* section name */ - Elf64_Half sh_type; /* section type */ - Elf64_Xword sh_flags; /* section flags */ - Elf64_Addr sh_addr; /* virtual address */ - Elf64_Off sh_offset; /* file offset */ - Elf64_Xword sh_size; /* section size */ - Elf64_Half sh_link; /* link to another */ - Elf64_Half sh_info; /* misc info */ - Elf64_Xword sh_addralign; /* memory alignment */ - Elf64_Xword sh_entsize; /* table entry size */ + Elf64_Half sh_name; /* section name */ + Elf64_Half sh_type; /* section type */ + Elf64_Xword sh_flags; /* section flags */ + Elf64_Addr sh_addr; /* virtual address */ + Elf64_Off sh_offset; /* file offset */ + Elf64_Xword sh_size; /* section size */ + Elf64_Half sh_link; /* link to another */ + Elf64_Half sh_info; /* misc info */ + Elf64_Xword sh_addralign; /* memory alignment */ + Elf64_Xword sh_entsize; /* table entry size */ } Elf64_Shdr; /* Special Section Indexes */ -#define SHN_UNDEF 0 /* undefined */ -#define SHN_LORESERVE 0xff00 /* lower bounds of reserved indexes */ -#define SHN_LOPROC 0xff00 /* reserved range for processor */ -#define SHN_HIPROC 0xff1f /* specific section indexes */ -#define SHN_ABS 0xfff1 /* absolute value */ -#define SHN_COMMON 0xfff2 /* common symbol */ -#define SHN_HIRESERVE 0xffff /* upper bounds of reserved indexes */ +#define SHN_UNDEF 0 /* undefined */ +#define SHN_LORESERVE 0xff00 /* lower bounds of reserved indexes */ +#define SHN_LOPROC 0xff00 /* reserved range for processor */ +#define SHN_HIPROC 0xff1f /* specific section indexes */ +#define SHN_ABS 0xfff1 /* absolute value */ +#define SHN_COMMON 0xfff2 /* common symbol */ +#define SHN_HIRESERVE 0xffff /* upper bounds of reserved indexes */ /* sh_type */ -#define SHT_NULL 0 /* inactive */ +#define SHT_NULL 0 /* inactive */ #define SHT_PROGBITS 1 /* program defined information */ -#define SHT_SYMTAB 2 /* symbol table section */ -#define SHT_STRTAB 3 /* string table section */ -#define SHT_RELA 4 /* relocation section with addends*/ -#define SHT_HASH 5 /* symbol hash table section */ -#define SHT_DYNAMIC 6 /* dynamic section */ -#define SHT_NOTE 7 /* note section */ -#define SHT_NOBITS 8 /* no space section */ -#define SHT_REL 9 /* relation section without addends */ -#define SHT_SHLIB 10 /* reserved - purpose unknown */ -#define SHT_DYNSYM 11 /* dynamic symbol table section */ -#define SHT_NUM 12 /* number of section types */ -#define SHT_LOPROC 0x70000000 /* reserved range for processor */ -#define SHT_HIPROC 0x7fffffff /* specific section header types */ -#define SHT_LOUSER 0x80000000 /* reserved range for application */ -#define SHT_HIUSER 0xffffffff /* specific indexes */ +#define SHT_SYMTAB 2 /* symbol table section */ +#define SHT_STRTAB 3 /* string table section */ +#define SHT_RELA 4 /* relocation section with addends*/ +#define SHT_HASH 5 /* symbol hash table section */ +#define SHT_DYNAMIC 6 /* dynamic section */ +#define SHT_NOTE 7 /* note section */ +#define SHT_NOBITS 8 /* no space section */ +#define SHT_REL 9 /* relation section without addends */ +#define SHT_SHLIB 10 /* reserved - purpose unknown */ +#define SHT_DYNSYM 11 /* dynamic symbol table section */ +#define SHT_NUM 12 /* number of section types */ +#define SHT_LOPROC 0x70000000 /* reserved range for processor */ +#define SHT_HIPROC 0x7fffffff /* specific section header types */ +#define SHT_LOUSER 0x80000000 /* reserved range for application */ +#define SHT_HIUSER 0xffffffff /* specific indexes */ /* Section names */ -#define ELF_BSS ".bss" /* uninitialized data */ -#define ELF_DATA ".data" /* initialized data */ -#define ELF_DEBUG ".debug" /* debug */ -#define ELF_DYNAMIC ".dynamic" /* dynamic linking information */ -#define ELF_DYNSTR ".dynstr" /* dynamic string table */ -#define ELF_DYNSYM ".dynsym" /* dynamic symbol table */ -#define ELF_FINI ".fini" /* termination code */ -#define ELF_GOT ".got" /* global offset table */ -#define ELF_HASH ".hash" /* symbol hash table */ -#define ELF_INIT ".init" /* initialization code */ -#define ELF_REL_DATA ".rel.data" /* relocation data */ -#define ELF_REL_FINI ".rel.fini" /* relocation termination code */ -#define ELF_REL_INIT ".rel.init" /* relocation initialization code */ -#define ELF_REL_DYN ".rel.dyn" /* relocation dynamic link info */ -#define ELF_REL_RODATA ".rel.rodata" /* relocation read-only data */ -#define ELF_REL_TEXT ".rel.text" /* relocation code */ -#define ELF_RODATA ".rodata" /* read-only data */ -#define ELF_SHSTRTAB ".shstrtab" /* section header string table */ -#define ELF_STRTAB ".strtab" /* string table */ -#define ELF_SYMTAB ".symtab" /* symbol table */ -#define ELF_TEXT ".text" /* code */ - +#define ELF_BSS ".bss" /* uninitialized data */ +#define ELF_DATA ".data" /* initialized data */ +#define ELF_DEBUG ".debug" /* debug */ +#define ELF_DYNAMIC ".dynamic" /* dynamic linking information */ +#define ELF_DYNSTR ".dynstr" /* dynamic string table */ +#define ELF_DYNSYM ".dynsym" /* dynamic symbol table */ +#define ELF_FINI ".fini" /* termination code */ +#define ELF_GOT ".got" /* global offset table */ +#define ELF_HASH ".hash" /* symbol hash table */ +#define ELF_INIT ".init" /* initialization code */ +#define ELF_REL_DATA ".rel.data" /* relocation data */ +#define ELF_REL_FINI ".rel.fini" /* relocation termination code */ +#define ELF_REL_INIT ".rel.init" /* relocation initialization code */ +#define ELF_REL_DYN ".rel.dyn" /* relocation dynamic link info */ +#define ELF_REL_RODATA ".rel.rodata" /* relocation read-only data */ +#define ELF_REL_TEXT ".rel.text" /* relocation code */ +#define ELF_RODATA ".rodata" /* read-only data */ +#define ELF_SHSTRTAB ".shstrtab" /* section header string table */ +#define ELF_STRTAB ".strtab" /* string table */ +#define ELF_SYMTAB ".symtab" /* symbol table */ +#define ELF_TEXT ".text" /* code */ /* Section Attribute Flags - sh_flags */ -#define SHF_WRITE 0x1 /* Writable */ -#define SHF_ALLOC 0x2 /* occupies memory */ -#define SHF_EXECINSTR 0x4 /* executable */ -#define SHF_TLS 0x400 /* thread local storage */ -#define SHF_MASKPROC 0xf0000000 /* reserved bits for processor - * specific section attributes */ +#define SHF_WRITE 0x1 /* Writable */ +#define SHF_ALLOC 0x2 /* occupies memory */ +#define SHF_EXECINSTR 0x4 /* executable */ +#define SHF_TLS 0x400 /* thread local storage */ +#define SHF_MASKPROC \ + 0xf0000000 /* reserved bits for processor \ + * specific section attributes */ /* Symbol Table Entry */ typedef struct elf32_sym { - Elf32_Word st_name; /* name - index into string table */ - Elf32_Addr st_value; /* symbol value */ - Elf32_Word st_size; /* symbol size */ - unsigned char st_info; /* type and binding */ - unsigned char st_other; /* 0 - no defined meaning */ - Elf32_Half st_shndx; /* section header index */ + Elf32_Word st_name; /* name - index into string table */ + Elf32_Addr st_value; /* symbol value */ + Elf32_Word st_size; /* symbol size */ + unsigned char st_info; /* type and binding */ + unsigned char st_other; /* 0 - no defined meaning */ + Elf32_Half st_shndx; /* section header index */ } Elf32_Sym; typedef struct { - Elf64_Half st_name; /* Symbol name index in str table */ - Elf_Byte st_info; /* type / binding attrs */ - Elf_Byte st_other; /* unused */ - Elf64_Quarter st_shndx; /* section index of symbol */ - Elf64_Xword st_value; /* value of symbol */ - Elf64_Xword st_size; /* size of symbol */ + Elf64_Half st_name; /* Symbol name index in str table */ + Elf_Byte st_info; /* type / binding attrs */ + Elf_Byte st_other; /* unused */ + Elf64_Quarter st_shndx; /* section index of symbol */ + Elf64_Xword st_value; /* value of symbol */ + Elf64_Xword st_size; /* size of symbol */ } Elf64_Sym; /* Symbol table index */ -#define STN_UNDEF 0 /* undefined */ +#define STN_UNDEF 0 /* undefined */ /* Extract symbol info - st_info */ -#define ELF32_ST_BIND(x) ((x) >> 4) -#define ELF32_ST_TYPE(x) (((unsigned int) x) & 0xf) -#define ELF32_ST_INFO(b,t) (((b) << 4) + ((t) & 0xf)) +#define ELF32_ST_BIND(x) ((x) >> 4) +#define ELF32_ST_TYPE(x) (((unsigned int)x) & 0xf) +#define ELF32_ST_INFO(b, t) (((b) << 4) + ((t)&0xf)) -#define ELF64_ST_BIND(x) ((x) >> 4) -#define ELF64_ST_TYPE(x) (((unsigned int) x) & 0xf) -#define ELF64_ST_INFO(b,t) (((b) << 4) + ((t) & 0xf)) +#define ELF64_ST_BIND(x) ((x) >> 4) +#define ELF64_ST_TYPE(x) (((unsigned int)x) & 0xf) +#define ELF64_ST_INFO(b, t) (((b) << 4) + ((t)&0xf)) /* Symbol Binding - ELF32_ST_BIND - st_info */ -#define STB_LOCAL 0 /* Local symbol */ -#define STB_GLOBAL 1 /* Global symbol */ -#define STB_WEAK 2 /* like global - lower precedence */ -#define STB_NUM 3 /* number of symbol bindings */ -#define STB_LOPROC 13 /* reserved range for processor */ -#define STB_HIPROC 15 /* specific symbol bindings */ +#define STB_LOCAL 0 /* Local symbol */ +#define STB_GLOBAL 1 /* Global symbol */ +#define STB_WEAK 2 /* like global - lower precedence */ +#define STB_NUM 3 /* number of symbol bindings */ +#define STB_LOPROC 13 /* reserved range for processor */ +#define STB_HIPROC 15 /* specific symbol bindings */ /* Symbol type - ELF32_ST_TYPE - st_info */ -#define STT_NOTYPE 0 /* not specified */ -#define STT_OBJECT 1 /* data object */ -#define STT_FUNC 2 /* function */ -#define STT_SECTION 3 /* section */ -#define STT_FILE 4 /* file */ -#define STT_TLS 6 /* thread local storage */ -#define STT_LOPROC 13 /* reserved range for processor */ -#define STT_HIPROC 15 /* specific symbol types */ +#define STT_NOTYPE 0 /* not specified */ +#define STT_OBJECT 1 /* data object */ +#define STT_FUNC 2 /* function */ +#define STT_SECTION 3 /* section */ +#define STT_FILE 4 /* file */ +#define STT_TLS 6 /* thread local storage */ +#define STT_LOPROC 13 /* reserved range for processor */ +#define STT_HIPROC 15 /* specific symbol types */ /* Relocation entry with implicit addend */ typedef struct { - Elf32_Addr r_offset; /* offset of relocation */ - Elf32_Word r_info; /* symbol table index and type */ + Elf32_Addr r_offset; /* offset of relocation */ + Elf32_Word r_info; /* symbol table index and type */ } Elf32_Rel; /* Relocation entry with explicit addend */ typedef struct { - Elf32_Addr r_offset; /* offset of relocation */ - Elf32_Word r_info; /* symbol table index and type */ - Elf32_Sword r_addend; + Elf32_Addr r_offset; /* offset of relocation */ + Elf32_Word r_info; /* symbol table index and type */ + Elf32_Sword r_addend; } Elf32_Rela; /* Extract relocation info - r_info */ -#define ELF32_R_SYM(i) ((i) >> 8) -#define ELF32_R_TYPE(i) ((unsigned char) (i)) -#define ELF32_R_INFO(s,t) (((s) << 8) + (unsigned char)(t)) +#define ELF32_R_SYM(i) ((i) >> 8) +#define ELF32_R_TYPE(i) ((unsigned char)(i)) +#define ELF32_R_INFO(s, t) (((s) << 8) + (unsigned char)(t)) typedef struct { - Elf64_Xword r_offset; /* where to do it */ - Elf64_Xword r_info; /* index & type of relocation */ + Elf64_Xword r_offset; /* where to do it */ + Elf64_Xword r_info; /* index & type of relocation */ } Elf64_Rel; typedef struct { - Elf64_Xword r_offset; /* where to do it */ - Elf64_Xword r_info; /* index & type of relocation */ - Elf64_Sxword r_addend; /* adjustment value */ + Elf64_Xword r_offset; /* where to do it */ + Elf64_Xword r_info; /* index & type of relocation */ + Elf64_Sxword r_addend; /* adjustment value */ } Elf64_Rela; -#define ELF64_R_SYM(info) ((info) >> 32) -#define ELF64_R_TYPE(info) ((info) & 0xFFFFFFFF) -#define ELF64_R_INFO(s,t) (((s) << 32) + (__uint32_t)(t)) +#define ELF64_R_SYM(info) ((info) >> 32) +#define ELF64_R_TYPE(info) ((info)&0xFFFFFFFF) +#define ELF64_R_INFO(s, t) (((s) << 32) + (__uint32_t)(t)) #if defined(__mips64__) && defined(__MIPSEL__) /* @@ -378,134 +377,134 @@ typedef struct { * than the regular ELF ABI: the r_info field is split into several * pieces (see gnu/usr.bin/binutils/include/elf/mips.h for details). */ -#undef ELF64_R_SYM -#undef ELF64_R_TYPE -#undef ELF64_R_INFO -#define ELF64_R_TYPE(info) (swap32((info) >> 32)) -#define ELF64_R_SYM(info) ((info) & 0xFFFFFFFF) -#define ELF64_R_INFO(s,t) (((__uint64_t)swap32(t) << 32) + (__uint32_t)(s)) +#undef ELF64_R_SYM +#undef ELF64_R_TYPE +#undef ELF64_R_INFO +#define ELF64_R_TYPE(info) (swap32((info) >> 32)) +#define ELF64_R_SYM(info) ((info)&0xFFFFFFFF) +#define ELF64_R_INFO(s, t) (((__uint64_t)swap32(t) << 32) + (__uint32_t)(s)) #endif /* __mips64__ && __MIPSEL__ */ /* Program Header */ typedef struct { - Elf32_Word p_type; /* segment type */ - Elf32_Off p_offset; /* segment offset */ - Elf32_Addr p_vaddr; /* virtual address of segment */ - Elf32_Addr p_paddr; /* physical address - ignored? */ - Elf32_Word p_filesz; /* number of bytes in file for seg. */ - Elf32_Word p_memsz; /* number of bytes in mem. for seg. */ - Elf32_Word p_flags; /* flags */ - Elf32_Word p_align; /* memory alignment */ + Elf32_Word p_type; /* segment type */ + Elf32_Off p_offset; /* segment offset */ + Elf32_Addr p_vaddr; /* virtual address of segment */ + Elf32_Addr p_paddr; /* physical address - ignored? */ + Elf32_Word p_filesz; /* number of bytes in file for seg. */ + Elf32_Word p_memsz; /* number of bytes in mem. for seg. */ + Elf32_Word p_flags; /* flags */ + Elf32_Word p_align; /* memory alignment */ } Elf32_Phdr; typedef struct { - Elf64_Half p_type; /* entry type */ - Elf64_Half p_flags; /* flags */ - Elf64_Off p_offset; /* offset */ - Elf64_Addr p_vaddr; /* virtual address */ - Elf64_Addr p_paddr; /* physical address */ - Elf64_Xword p_filesz; /* file size */ - Elf64_Xword p_memsz; /* memory size */ - Elf64_Xword p_align; /* memory & file alignment */ + Elf64_Half p_type; /* entry type */ + Elf64_Half p_flags; /* flags */ + Elf64_Off p_offset; /* offset */ + Elf64_Addr p_vaddr; /* virtual address */ + Elf64_Addr p_paddr; /* physical address */ + Elf64_Xword p_filesz; /* file size */ + Elf64_Xword p_memsz; /* memory size */ + Elf64_Xword p_align; /* memory & file alignment */ } Elf64_Phdr; /* Segment types - p_type */ -#define PT_NULL 0 /* unused */ -#define PT_LOAD 1 /* loadable segment */ -#define PT_DYNAMIC 2 /* dynamic linking section */ -#define PT_INTERP 3 /* the RTLD */ -#define PT_NOTE 4 /* auxiliary information */ -#define PT_SHLIB 5 /* reserved - purpose undefined */ -#define PT_PHDR 6 /* program header */ -#define PT_TLS 7 /* thread local storage */ -#define PT_LOOS 0x60000000 /* reserved range for OS */ -#define PT_HIOS 0x6fffffff /* specific segment types */ -#define PT_LOPROC 0x70000000 /* reserved range for processor */ -#define PT_HIPROC 0x7fffffff /* specific segment types */ +#define PT_NULL 0 /* unused */ +#define PT_LOAD 1 /* loadable segment */ +#define PT_DYNAMIC 2 /* dynamic linking section */ +#define PT_INTERP 3 /* the RTLD */ +#define PT_NOTE 4 /* auxiliary information */ +#define PT_SHLIB 5 /* reserved - purpose undefined */ +#define PT_PHDR 6 /* program header */ +#define PT_TLS 7 /* thread local storage */ +#define PT_LOOS 0x60000000 /* reserved range for OS */ +#define PT_HIOS 0x6fffffff /* specific segment types */ +#define PT_LOPROC 0x70000000 /* reserved range for processor */ +#define PT_HIPROC 0x7fffffff /* specific segment types */ #define PT_OPENBSD_RANDOMIZE 0x65a3dbe6 /* fill with random data */ -#define PT_GANDR_KERNEL 0x67646b6c /* gdkl */ - +#define PT_GANDR_KERNEL 0x67646b6c /* gdkl */ /* Segment flags - p_flags */ -#define PF_X 0x1 /* Executable */ -#define PF_W 0x2 /* Writable */ -#define PF_R 0x4 /* Readable */ -#define PF_MASKPROC 0xf0000000 /* reserved bits for processor */ - /* specific segment flags */ +#define PF_X 0x1 /* Executable */ +#define PF_W 0x2 /* Writable */ +#define PF_R 0x4 /* Readable */ +#define PF_MASKPROC 0xf0000000 /* reserved bits for processor */ + /* specific segment flags */ /* Dynamic structure */ typedef struct { - Elf32_Sword d_tag; /* controls meaning of d_val */ + Elf32_Sword d_tag; /* controls meaning of d_val */ union { - Elf32_Word d_val; /* Multiple meanings - see d_tag */ - Elf32_Addr d_ptr; /* program virtual address */ + Elf32_Word d_val; /* Multiple meanings - see d_tag */ + Elf32_Addr d_ptr; /* program virtual address */ } d_un; } Elf32_Dyn; typedef struct { - Elf64_Xword d_tag; /* controls meaning of d_val */ + Elf64_Xword d_tag; /* controls meaning of d_val */ union { - Elf64_Addr d_ptr; + Elf64_Addr d_ptr; Elf64_Xword d_val; } d_un; } Elf64_Dyn; /* Dynamic Array Tags - d_tag */ -#define DT_NULL 0 /* marks end of _DYNAMIC array */ -#define DT_NEEDED 1 /* string table offset of needed lib */ -#define DT_PLTRELSZ 2 /* size of relocation entries in PLT */ -#define DT_PLTGOT 3 /* address PLT/GOT */ -#define DT_HASH 4 /* address of symbol hash table */ -#define DT_STRTAB 5 /* address of string table */ -#define DT_SYMTAB 6 /* address of symbol table */ -#define DT_RELA 7 /* address of relocation table */ -#define DT_RELASZ 8 /* size of relocation table */ -#define DT_RELAENT 9 /* size of relocation entry */ -#define DT_STRSZ 10 /* size of string table */ -#define DT_SYMENT 11 /* size of symbol table entry */ -#define DT_INIT 12 /* address of initialization func. */ -#define DT_FINI 13 /* address of termination function */ -#define DT_SONAME 14 /* string table offset of shared obj */ -#define DT_RPATH 15 /* string table offset of library - * search path */ -#define DT_SYMBOLIC 16 /* start sym search in shared obj. */ -#define DT_REL 17 /* address of rel. tbl. w addends */ -#define DT_RELSZ 18 /* size of DT_REL relocation table */ -#define DT_RELENT 19 /* size of DT_REL relocation entry */ -#define DT_PLTREL 20 /* PLT referenced relocation entry */ -#define DT_DEBUG 21 /* bugger */ -#define DT_TEXTREL 22 /* Allow rel. mod. to unwritable seg */ -#define DT_JMPREL 23 /* add. of PLT's relocation entries */ -#define DT_BIND_NOW 24 /* Bind now regardless of env setting */ -#define DT_LOOS 0x6000000d /* reserved range for OS */ -#define DT_HIOS 0x6ffff000 /* specific dynamic array tags */ -#define DT_LOPROC 0x70000000 /* reserved range for processor */ -#define DT_HIPROC 0x7fffffff /* specific dynamic array tags */ +#define DT_NULL 0 /* marks end of _DYNAMIC array */ +#define DT_NEEDED 1 /* string table offset of needed lib */ +#define DT_PLTRELSZ 2 /* size of relocation entries in PLT */ +#define DT_PLTGOT 3 /* address PLT/GOT */ +#define DT_HASH 4 /* address of symbol hash table */ +#define DT_STRTAB 5 /* address of string table */ +#define DT_SYMTAB 6 /* address of symbol table */ +#define DT_RELA 7 /* address of relocation table */ +#define DT_RELASZ 8 /* size of relocation table */ +#define DT_RELAENT 9 /* size of relocation entry */ +#define DT_STRSZ 10 /* size of string table */ +#define DT_SYMENT 11 /* size of symbol table entry */ +#define DT_INIT 12 /* address of initialization func. */ +#define DT_FINI 13 /* address of termination function */ +#define DT_SONAME 14 /* string table offset of shared obj */ +#define DT_RPATH \ + 15 /* string table offset of library \ + * search path */ +#define DT_SYMBOLIC 16 /* start sym search in shared obj. */ +#define DT_REL 17 /* address of rel. tbl. w addends */ +#define DT_RELSZ 18 /* size of DT_REL relocation table */ +#define DT_RELENT 19 /* size of DT_REL relocation entry */ +#define DT_PLTREL 20 /* PLT referenced relocation entry */ +#define DT_DEBUG 21 /* bugger */ +#define DT_TEXTREL 22 /* Allow rel. mod. to unwritable seg */ +#define DT_JMPREL 23 /* add. of PLT's relocation entries */ +#define DT_BIND_NOW 24 /* Bind now regardless of env setting */ +#define DT_LOOS 0x6000000d /* reserved range for OS */ +#define DT_HIOS 0x6ffff000 /* specific dynamic array tags */ +#define DT_LOPROC 0x70000000 /* reserved range for processor */ +#define DT_HIPROC 0x7fffffff /* specific dynamic array tags */ /* some other useful tags */ -#define DT_RELACOUNT 0x6ffffff9 /* if present, number of RELATIVE */ -#define DT_RELCOUNT 0x6ffffffa /* relocs, which must come first */ -#define DT_FLAGS_1 0x6ffffffb +#define DT_RELACOUNT 0x6ffffff9 /* if present, number of RELATIVE */ +#define DT_RELCOUNT 0x6ffffffa /* relocs, which must come first */ +#define DT_FLAGS_1 0x6ffffffb /* Dynamic Flags - DT_FLAGS_1 .dynamic entry */ -#define DF_1_NOW 0x00000001 -#define DF_1_GLOBAL 0x00000002 -#define DF_1_GROUP 0x00000004 -#define DF_1_NODELETE 0x00000008 -#define DF_1_LOADFLTR 0x00000010 +#define DF_1_NOW 0x00000001 +#define DF_1_GLOBAL 0x00000002 +#define DF_1_GROUP 0x00000004 +#define DF_1_NODELETE 0x00000008 +#define DF_1_LOADFLTR 0x00000010 #define DF_1_INITFIRST 0x00000020 -#define DF_1_NOOPEN 0x00000040 -#define DF_1_ORIGIN 0x00000080 -#define DF_1_DIRECT 0x00000100 -#define DF_1_TRANS 0x00000200 +#define DF_1_NOOPEN 0x00000040 +#define DF_1_ORIGIN 0x00000080 +#define DF_1_DIRECT 0x00000100 +#define DF_1_TRANS 0x00000200 #define DF_1_INTERPOSE 0x00000400 -#define DF_1_NODEFLIB 0x00000800 -#define DF_1_NODUMP 0x00001000 -#define DF_1_CONLFAT 0x00002000 +#define DF_1_NODEFLIB 0x00000800 +#define DF_1_NODUMP 0x00001000 +#define DF_1_CONLFAT 0x00002000 /* ld.so: number of low tags that are used saved internally (0 .. DT_NUM-1) */ -#define DT_NUM (DT_JMPREL+1) +#define DT_NUM (DT_JMPREL + 1) /* * Note Definitions @@ -523,25 +522,25 @@ typedef struct { } Elf64_Note; #if defined(ELFSIZE) && (ELFSIZE == 32) -#define Elf_Ehdr Elf32_Ehdr -#define Elf_Phdr Elf32_Phdr -#define Elf_Shdr Elf32_Shdr -#define Elf_Sym Elf32_Sym -#define Elf_Rel Elf32_Rel -#define Elf_RelA Elf32_Rela -#define Elf_Dyn Elf32_Dyn -#define Elf_Half Elf32_Half -#define Elf_Word Elf32_Word -#define Elf_Sword Elf32_Sword -#define Elf_Addr Elf32_Addr -#define Elf_Off Elf32_Off -#define Elf_Nhdr Elf32_Nhdr -#define Elf_Note Elf32_Note - -#define ELF_R_SYM ELF32_R_SYM -#define ELF_R_TYPE ELF32_R_TYPE -#define ELF_R_INFO ELF32_R_INFO -#define ELFCLASS ELFCLASS32 +#define Elf_Ehdr Elf32_Ehdr +#define Elf_Phdr Elf32_Phdr +#define Elf_Shdr Elf32_Shdr +#define Elf_Sym Elf32_Sym +#define Elf_Rel Elf32_Rel +#define Elf_RelA Elf32_Rela +#define Elf_Dyn Elf32_Dyn +#define Elf_Half Elf32_Half +#define Elf_Word Elf32_Word +#define Elf_Sword Elf32_Sword +#define Elf_Addr Elf32_Addr +#define Elf_Off Elf32_Off +#define Elf_Nhdr Elf32_Nhdr +#define Elf_Note Elf32_Note + +#define ELF_R_SYM ELF32_R_SYM +#define ELF_R_TYPE ELF32_R_TYPE +#define ELF_R_INFO ELF32_R_INFO +#define ELFCLASS ELFCLASS32 #define ELF_ST_BIND ELF32_ST_BIND #define ELF_ST_TYPE ELF32_ST_TYPE @@ -549,25 +548,25 @@ typedef struct { #elif defined(ELFSIZE) && (ELFSIZE == 64) -#define Elf_Ehdr Elf64_Ehdr -#define Elf_Phdr Elf64_Phdr -#define Elf_Shdr Elf64_Shdr -#define Elf_Sym Elf64_Sym -#define Elf_Rel Elf64_Rel -#define Elf_RelA Elf64_Rela -#define Elf_Dyn Elf64_Dyn -#define Elf_Half Elf64_Half -#define Elf_Word Elf64_Word -#define Elf_Sword Elf64_Sword -#define Elf_Addr Elf64_Addr -#define Elf_Off Elf64_Off -#define Elf_Nhdr Elf64_Nhdr -#define Elf_Note Elf64_Note - -#define ELF_R_SYM ELF64_R_SYM -#define ELF_R_TYPE ELF64_R_TYPE -#define ELF_R_INFO ELF64_R_INFO -#define ELFCLASS ELFCLASS64 +#define Elf_Ehdr Elf64_Ehdr +#define Elf_Phdr Elf64_Phdr +#define Elf_Shdr Elf64_Shdr +#define Elf_Sym Elf64_Sym +#define Elf_Rel Elf64_Rel +#define Elf_RelA Elf64_Rela +#define Elf_Dyn Elf64_Dyn +#define Elf_Half Elf64_Half +#define Elf_Word Elf64_Word +#define Elf_Sword Elf64_Sword +#define Elf_Addr Elf64_Addr +#define Elf_Off Elf64_Off +#define Elf_Nhdr Elf64_Nhdr +#define Elf_Note Elf64_Note + +#define ELF_R_SYM ELF64_R_SYM +#define ELF_R_TYPE ELF64_R_TYPE +#define ELF_R_INFO ELF64_R_INFO +#define ELFCLASS ELFCLASS64 #define ELF_ST_BIND ELF64_ST_BIND #define ELF_ST_TYPE ELF64_ST_TYPE diff --git a/modules/cpu/src/runtime/elfload.cpp b/modules/cpu/src/runtime/elfload.cpp index b1670da065..cb712edba5 100644 --- a/modules/cpu/src/runtime/elfload.cpp +++ b/modules/cpu/src/runtime/elfload.cpp @@ -2,15 +2,14 @@ #include #include -el_status el_pread(el_ctx *ctx, void *def, size_t nb, size_t offset) -{ +el_status el_pread(el_ctx *ctx, void *def, size_t nb, size_t offset) { return ctx->pread(ctx, def, nb, offset) ? EL_OK : EL_EIO; } -#define EL_PHOFF(ctx, num) (((ctx)->ehdr.e_phoff + (num) * (ctx)->ehdr.e_phentsize)) +#define EL_PHOFF(ctx, num) \ + (((ctx)->ehdr.e_phoff + (num) * (ctx)->ehdr.e_phentsize)) -el_status el_findphdr(el_ctx *ctx, Elf_Phdr *phdr, uint32_t type, unsigned *i) -{ +el_status el_findphdr(el_ctx *ctx, Elf_Phdr *phdr, uint32_t type, unsigned *i) { el_status rv = EL_OK; for (; *i < ctx->ehdr.e_phnum; (*i)++) { if ((rv = el_pread(ctx, phdr, sizeof *phdr, EL_PHOFF(ctx, *i)))) @@ -25,8 +24,7 @@ el_status el_findphdr(el_ctx *ctx, Elf_Phdr *phdr, uint32_t type, unsigned *i) return rv; } -el_status el_init(el_ctx *ctx) -{ +el_status el_init(el_ctx *ctx) { el_status rv = EL_OK; if ((rv = el_pread(ctx, &ctx->ehdr, sizeof ctx->ehdr, 0))) return rv; @@ -36,7 +34,6 @@ el_status el_init(el_ctx *ctx) if (!IS_ELF(ctx->ehdr)) return EL_NOTELF; - if (ctx->ehdr.e_ident[EI_CLASS] != ELFCLASS) return EL_WRONGBITS; @@ -74,11 +71,11 @@ el_status el_init(el_ctx *ctx) ctx->memsz = 0; unsigned i = 0; - for(;;) { + for (;;) { if ((rv = el_findphdr(ctx, &ph, PT_LOAD, &i))) return rv; - if (i == (unsigned) -1) + if (i == (unsigned)-1) break; Elf_Addr phend = ph.p_vaddr + ph.p_memsz; @@ -97,13 +94,13 @@ el_status el_init(el_ctx *ctx) if ((rv = el_findphdr(ctx, &ph, PT_DYNAMIC, &i))) return rv; - if (i == (unsigned) -1) + if (i == (unsigned)-1) return EL_NODYN; - ctx->dynoff = ph.p_offset; + ctx->dynoff = ph.p_offset; ctx->dynsize = ph.p_filesz; } else { - ctx->dynoff = 0; + ctx->dynoff = 0; ctx->dynsize = 0; } @@ -118,8 +115,7 @@ typedef void* (*el_alloc_cb)( Elf_Addr size); */ -el_status el_load(el_ctx *ctx, el_alloc_cb alloc) -{ +el_status el_load(el_ctx *ctx, el_alloc_cb alloc) { el_status rv = EL_OK; /* address deltas */ @@ -129,11 +125,11 @@ el_status el_load(el_ctx *ctx, el_alloc_cb alloc) /* iterate paddrs */ Elf_Phdr ph; unsigned i = 0; - for(;;) { + for (;;) { if ((rv = el_findphdr(ctx, &ph, PT_LOAD, &i))) return rv; - if (i == (unsigned) -1) + if (i == (unsigned)-1) break; Elf_Addr pload = ph.p_paddr + pdelta; @@ -144,8 +140,8 @@ el_status el_load(el_ctx *ctx, el_alloc_cb alloc) if (!dest) return EL_ENOMEM; - printf("Loading seg fileoff %lx, vaddr %lx to %lx\n", - ph.p_offset, ph.p_vaddr, (uintptr_t)dest); + printf("Loading seg fileoff %lx, vaddr %lx to %lx\n", ph.p_offset, + ph.p_vaddr, (uintptr_t)dest); /* read loaded portion */ if ((rv = el_pread(ctx, dest, ph.p_filesz, ph.p_offset))) @@ -160,13 +156,13 @@ el_status el_load(el_ctx *ctx, el_alloc_cb alloc) return rv; } -el_status el_finddyn(el_ctx *ctx, Elf_Dyn *dyn, uint32_t tag) -{ +el_status el_finddyn(el_ctx *ctx, Elf_Dyn *dyn, uint32_t tag) { el_status rv = EL_OK; - size_t ndyn = ctx->dynsize / sizeof(Elf_Dyn); + size_t ndyn = ctx->dynsize / sizeof(Elf_Dyn); - for(unsigned i = 0; i < ndyn; i++) { - if ((rv = el_pread(ctx, dyn, sizeof *dyn, ctx->dynoff + i * sizeof *dyn))) + for (unsigned i = 0; i < ndyn; i++) { + if ((rv = el_pread(ctx, dyn, sizeof *dyn, + ctx->dynoff + i * sizeof *dyn))) return rv; if (dyn->d_tag == tag) @@ -177,8 +173,7 @@ el_status el_finddyn(el_ctx *ctx, Elf_Dyn *dyn, uint32_t tag) return EL_OK; } -el_status el_findrelocs(el_ctx *ctx, el_relocinfo *ri, uint32_t type) -{ +el_status el_findrelocs(el_ctx *ctx, el_relocinfo *ri, uint32_t type) { el_status rv = EL_OK; Elf_Dyn rel, relsz, relent; @@ -192,14 +187,13 @@ el_status el_findrelocs(el_ctx *ctx, el_relocinfo *ri, uint32_t type) if ((rv = el_finddyn(ctx, &relent, type + 2))) return rv; - if (rel.d_tag == DT_NULL - || relsz.d_tag == DT_NULL - || relent.d_tag == DT_NULL) { + if (rel.d_tag == DT_NULL || relsz.d_tag == DT_NULL || + relent.d_tag == DT_NULL) { ri->entrysize = 0; ri->tablesize = 0; - ri->tableoff = 0; + ri->tableoff = 0; } else { - ri->tableoff = rel.d_un.d_ptr; + ri->tableoff = rel.d_un.d_ptr; ri->tablesize = relsz.d_un.d_val; ri->entrysize = relent.d_un.d_val; } @@ -210,15 +204,14 @@ el_status el_findrelocs(el_ctx *ctx, el_relocinfo *ri, uint32_t type) extern el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel); extern el_status el_applyrela(el_ctx *ctx, Elf_RelA *rela); -el_status el_relocate(el_ctx *ctx) -{ +el_status el_relocate(el_ctx *ctx) { el_status rv = EL_OK; // not dynamic if (ctx->ehdr.e_type != ET_DYN) return EL_OK; - char *base = (char *) ctx->base_load_paddr; + char *base = (char *)ctx->base_load_paddr; el_relocinfo ri; #ifdef EL_ARCH_USES_REL @@ -226,8 +219,8 @@ el_status el_relocate(el_ctx *ctx) return rv; if (ri.entrysize != sizeof(Elf_Rel) && ri.tablesize) { - EL_DEBUG("Relocation size %u doesn't match expected %u\n", - ri.entrysize, sizeof(Elf_Rel)); + EL_DEBUG("Relocation size %u doesn't match expected %u\n", ri.entrysize, + sizeof(Elf_Rel)); return EL_BADREL; } @@ -246,8 +239,8 @@ el_status el_relocate(el_ctx *ctx) return rv; if (ri.entrysize != sizeof(Elf_RelA) && ri.tablesize) { - EL_DEBUG("Relocation size %u doesn't match expected %u\n", - ri.entrysize, sizeof(Elf_RelA)); + EL_DEBUG("Relocation size %u doesn't match expected %u\n", ri.entrysize, + sizeof(Elf_RelA)); return EL_BADREL; } @@ -262,7 +255,7 @@ el_status el_relocate(el_ctx *ctx) #endif #if !defined(EL_ARCH_USES_REL) && !defined(EL_ARCH_USES_RELA) - #error No relocation type defined! +#error No relocation type defined! #endif return rv; diff --git a/modules/cpu/src/runtime/elfload.h b/modules/cpu/src/runtime/elfload.h index 148b9b3a72..044fd5f967 100644 --- a/modules/cpu/src/runtime/elfload.h +++ b/modules/cpu/src/runtime/elfload.h @@ -1,19 +1,21 @@ #ifndef ELFLOAD_H #define ELFLOAD_H +#include "elf.h" +#include "elfarch.h" #include #include -#include "elfarch.h" -#include "elf.h" #ifdef DEBUG #include #define EL_DEBUG(...) printf(__VA_ARGS__) #else -#define EL_DEBUG(...) do {} while(0) +#define EL_DEBUG(...) \ + do { \ + } while (0) #endif typedef enum { - EL_OK = 0, + EL_OK = 0, EL_EIO, EL_ENOMEM, @@ -34,9 +36,7 @@ typedef struct el_ctx { /* base_load_* -> address we are actually going to load at */ - Elf_Addr - base_load_paddr, - base_load_vaddr; + Elf_Addr base_load_paddr, base_load_vaddr; /* size in memory of binary */ Elf_Addr memsz; @@ -45,10 +45,10 @@ typedef struct el_ctx { Elf_Addr align; /* ELF header */ - Elf_Ehdr ehdr; + Elf_Ehdr ehdr; /* Offset of dynamic table (0 if not ET_DYN) */ - Elf_Off dynoff; + Elf_Off dynoff; /* Size of dynamic table (0 if not ET_DYN) */ Elf_Addr dynsize; @@ -58,11 +58,8 @@ typedef struct el_ctx { el_status el_pread(el_ctx *ctx, void *def, size_t nb, size_t offset); el_status el_init(el_ctx *ctx); -typedef void* (*el_alloc_cb)( - el_ctx *ctx, - Elf_Addr phys, - Elf_Addr virt, - Elf_Addr size); +typedef void *(*el_alloc_cb)(el_ctx *ctx, Elf_Addr phys, Elf_Addr virt, + Elf_Addr size); el_status el_load(el_ctx *ctx, el_alloc_cb alloccb); @@ -84,7 +81,7 @@ el_status el_relocate(el_ctx *ctx); el_status el_finddyn(el_ctx *ctx, Elf_Dyn *dyn, uint32_t type); typedef struct { - Elf_Off tableoff; + Elf_Off tableoff; Elf_Addr tablesize; Elf_Addr entrysize; } el_relocinfo; diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp index 8bced71e14..f0c6571db1 100644 --- a/modules/cpu/src/runtime/elfloader.cpp +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -4,8 +4,8 @@ using namespace nncase; using namespace nncase::runtime; using namespace nncase::runtime::cpu; -int elfloader::invoke_elf(size_t id, buffer_t **buffers, - nncase_mt_t *nncase_mt, void *data, void *rdata) { +int elfloader::invoke_elf(size_t id, buffer_t **buffers, nncase_mt_t *nncase_mt, + void *data, void *rdata) { check(el_init(&ctx_), "initialising"); diff --git a/modules/cpu/src/runtime/elfreloc_aarch64.cpp b/modules/cpu/src/runtime/elfreloc_aarch64.cpp index e8fd3f8ed7..58a5ad1753 100644 --- a/modules/cpu/src/runtime/elfreloc_aarch64.cpp +++ b/modules/cpu/src/runtime/elfreloc_aarch64.cpp @@ -2,66 +2,61 @@ #if defined(__aarch64__) -#define R_AARCH64_NONE 0 +#define R_AARCH64_NONE 0 #define R_AARCH64_RELATIVE 1027 -el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) -{ - uintptr_t *p = (uintptr_t*) (rel->r_offset + ctx->base_load_paddr); +el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) { + uintptr_t *p = (uintptr_t *)(rel->r_offset + ctx->base_load_paddr); uint32_t type = ELF_R_TYPE(rel->r_info); - uint32_t sym = ELF_R_SYM(rel->r_info); + uint32_t sym = ELF_R_SYM(rel->r_info); switch (type) { - case R_AARCH64_NONE: - EL_DEBUG("R_AARCH64_NONE\n"); - break; - case R_AARCH64_RELATIVE: - if (sym) { - EL_DEBUG("R_AARCH64_RELATIVE with symbol ref!\n"); - return EL_BADREL; - } - - EL_DEBUG("Applying R_AARCH64_RELATIVE reloc @%p\n", p); - *p = rel->r_addend + ctx->base_load_vaddr; - break; - - default: - EL_DEBUG("Bad relocation %u\n", type); + case R_AARCH64_NONE: + EL_DEBUG("R_AARCH64_NONE\n"); + break; + case R_AARCH64_RELATIVE: + if (sym) { + EL_DEBUG("R_AARCH64_RELATIVE with symbol ref!\n"); return EL_BADREL; + } + EL_DEBUG("Applying R_AARCH64_RELATIVE reloc @%p\n", p); + *p = rel->r_addend + ctx->base_load_vaddr; + break; + + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; } return EL_OK; } -el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel) -{ - uintptr_t *p = (uintptr_t*) (rel->r_offset + ctx->base_load_paddr); +el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel) { + uintptr_t *p = (uintptr_t *)(rel->r_offset + ctx->base_load_paddr); uint32_t type = ELF_R_TYPE(rel->r_info); - uint32_t sym = ELF_R_SYM(rel->r_info); + uint32_t sym = ELF_R_SYM(rel->r_info); switch (type) { - case R_AARCH64_NONE: - EL_DEBUG("R_AARCH64_NONE\n"); - break; - case R_AARCH64_RELATIVE: - if (sym) { - EL_DEBUG("R_AARCH64_RELATIVE with symbol ref!\n"); - return EL_BADREL; - } - - EL_DEBUG("Applying R_AARCH64_RELATIVE reloc @%p\n", p); - *p += ctx->base_load_vaddr; - break; - - default: - EL_DEBUG("Bad relocation %u\n", type); + case R_AARCH64_NONE: + EL_DEBUG("R_AARCH64_NONE\n"); + break; + case R_AARCH64_RELATIVE: + if (sym) { + EL_DEBUG("R_AARCH64_RELATIVE with symbol ref!\n"); return EL_BADREL; + } + + EL_DEBUG("Applying R_AARCH64_RELATIVE reloc @%p\n", p); + *p += ctx->base_load_vaddr; + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; } return EL_OK; } - #endif diff --git a/modules/cpu/src/runtime/elfreloc_amd64.cpp b/modules/cpu/src/runtime/elfreloc_amd64.cpp index 8b32082ff3..98e638ae44 100644 --- a/modules/cpu/src/runtime/elfreloc_amd64.cpp +++ b/modules/cpu/src/runtime/elfreloc_amd64.cpp @@ -2,24 +2,23 @@ #if defined(__amd64__) -#define R_AMD64_NONE 0 +#define R_AMD64_NONE 0 #define R_AMD64_RELATIVE 8 -el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) -{ - uint64_t *p = (uint64_t*) (rel->r_offset + ctx->base_load_vaddr); +el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) { + uint64_t *p = (uint64_t *)(rel->r_offset + ctx->base_load_vaddr); uint32_t type = ELF_R_TYPE(rel->r_info); switch (type) { - case R_AMD64_NONE: break; - case R_AMD64_RELATIVE: - EL_DEBUG("Applying R_AMD64_RELATIVE reloc @%p\n", p); - *p = rel->r_addend + ctx->base_load_vaddr; - break; - default: - EL_DEBUG("Bad relocation %u\n", type); - return EL_BADREL; - + case R_AMD64_NONE: + break; + case R_AMD64_RELATIVE: + EL_DEBUG("Applying R_AMD64_RELATIVE reloc @%p\n", p); + *p = rel->r_addend + ctx->base_load_vaddr; + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; } return EL_OK; diff --git a/modules/cpu/src/runtime/elfreloc_i386.cpp b/modules/cpu/src/runtime/elfreloc_i386.cpp index fdf66e3d3f..2373ab036c 100644 --- a/modules/cpu/src/runtime/elfreloc_i386.cpp +++ b/modules/cpu/src/runtime/elfreloc_i386.cpp @@ -2,24 +2,24 @@ #if defined(__i386__) -#define R_386_NONE 0 +#define R_386_NONE 0 #define R_386_RELATIVE 8 -el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel) -{ - uint32_t *p = (uint32_t*) (rel->r_offset + ctx->base_load_vaddr); +el_status el_applyrel(el_ctx *ctx, Elf_Rel *rel) { + uint32_t *p = (uint32_t *)(rel->r_offset + ctx->base_load_vaddr); uint32_t type = ELF_R_TYPE(rel->r_info); - uint32_t sym = ELF_R_SYM(rel->r_info); + uint32_t sym = ELF_R_SYM(rel->r_info); switch (type) { - case R_386_NONE: break; - case R_386_RELATIVE: - EL_DEBUG("Applying R_386_RELATIVE reloc @%p\n", p); - *p += ctx->base_load_vaddr; - break; - default: - EL_DEBUG("Bad relocation %u\n", type); - return EL_BADREL; + case R_386_NONE: + break; + case R_386_RELATIVE: + EL_DEBUG("Applying R_386_RELATIVE reloc @%p\n", p); + *p += ctx->base_load_vaddr; + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; } return EL_OK; diff --git a/modules/cpu/src/runtime/elfreloc_riscv64.cpp b/modules/cpu/src/runtime/elfreloc_riscv64.cpp index a4d99a64ef..487361da2c 100644 --- a/modules/cpu/src/runtime/elfreloc_riscv64.cpp +++ b/modules/cpu/src/runtime/elfreloc_riscv64.cpp @@ -2,29 +2,28 @@ #if defined(__riscv) -#define R_riscv64_NONE 0 +#define R_riscv64_NONE 0 #define R_riscv64_RELATIVE 3 #define R_riscv64_JUMP_SLOT 5 -el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) -{ - uint64_t *p = (uint64_t*) (rel->r_offset + ctx->base_load_vaddr); +el_status el_applyrela(el_ctx *ctx, Elf_RelA *rel) { + uint64_t *p = (uint64_t *)(rel->r_offset + ctx->base_load_vaddr); uint32_t type = ELF_R_TYPE(rel->r_info); EL_DEBUG("rv\n"); switch (type) { - case R_riscv64_NONE: break; - case R_riscv64_RELATIVE: - EL_DEBUG("Applying R_riscv64_RELATIVE reloc @%p\n", p); - *p = rel->r_addend + ctx->base_load_vaddr; - break; - case R_riscv64_JUMP_SLOT: - EL_DEBUG("Applying R_riscv64_JUMP_SLOT reloc @%p\n", p); - break; - default: - EL_DEBUG("Bad relocation %u\n", type); - return EL_BADREL; - + case R_riscv64_NONE: + break; + case R_riscv64_RELATIVE: + EL_DEBUG("Applying R_riscv64_RELATIVE reloc @%p\n", p); + *p = rel->r_addend + ctx->base_load_vaddr; + break; + case R_riscv64_JUMP_SLOT: + EL_DEBUG("Applying R_riscv64_JUMP_SLOT reloc @%p\n", p); + break; + default: + EL_DEBUG("Bad relocation %u\n", type); + return EL_BADREL; } return EL_OK; diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 810b51de4e..1bdd121f9c 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -27,8 +27,7 @@ result cpu_runtime_module::initialize_before_functions( if (!context.is_section_pinned()) return nncase::err(std::errc::bad_address); try_var(data, context.get_or_read_section(".data", data_storage_, false)); - try_var(rdata, - context.get_or_read_section(".rdata", rdata_storage_, true)); + try_var(rdata, context.get_or_read_section(".rdata", rdata_storage_, true)); try_var(text, context.get_or_read_section(".text", text_storage_, true)); text_ = text.as_span(); From 092df591f9352a2394100a3d44f87e24be97d04d Mon Sep 17 00:00:00 2001 From: xhuohai Date: Tue, 25 Jul 2023 01:57:36 +0000 Subject: [PATCH 019/308] Apply code-format changes --- .../Nncase.Modules.CPU/CodeGen/FunctionCSource.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index e07152e4fc..347ff76d36 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -86,6 +86,12 @@ public string Compile(string sourcePath, string outPath) return outPath; } + /// + /// create the temp dll file and compile source + /// . + /// + public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); + /// /// select current pattern's exe. /// @@ -138,12 +144,6 @@ private string ArgumentsSpecific(string sourcePath, string outPath) throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); } - - /// - /// create the temp dll file and compile source - /// . - /// - public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); } internal sealed class FunctionCSource From 19546c076c87bdcb8f25f4577d69fb9d3a52e9c0 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 10:00:25 +0800 Subject: [PATCH 020/308] fix elfloader build --- modules/cpu/src/runtime/elf.h | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/cpu/src/runtime/elf.h b/modules/cpu/src/runtime/elf.h index d7be196fa3..04cbf49893 100644 --- a/modules/cpu/src/runtime/elf.h +++ b/modules/cpu/src/runtime/elf.h @@ -30,6 +30,7 @@ #ifndef ELF_H #define ELF_H #include +#include "elfarch.h" typedef uint8_t Elf_Byte; From e8956e6986eca9770c1ead538fd6fc831280924d Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 10:02:28 +0800 Subject: [PATCH 021/308] disable launch_debugger --- python/nncase/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/nncase/__init__.py b/python/nncase/__init__.py index f58942ce25..8653d3b301 100644 --- a/python/nncase/__init__.py +++ b/python/nncase/__init__.py @@ -44,7 +44,7 @@ def _initialize(): _initialize() -_nncase.launch_debugger() +# _nncase.launch_debugger() class ImportOptions: From 1db8052db303f429081c40ca57cf7c35803f14f9 Mon Sep 17 00:00:00 2001 From: xhuohai Date: Tue, 25 Jul 2023 02:06:13 +0000 Subject: [PATCH 022/308] Apply code-format changes --- modules/cpu/src/runtime/elf.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cpu/src/runtime/elf.h b/modules/cpu/src/runtime/elf.h index 04cbf49893..4db5998b3a 100644 --- a/modules/cpu/src/runtime/elf.h +++ b/modules/cpu/src/runtime/elf.h @@ -29,8 +29,8 @@ #ifndef ELF_H #define ELF_H -#include #include "elfarch.h" +#include typedef uint8_t Elf_Byte; From a726a93d25c5a7ea77214c566f1f89ae8326e314 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 10:53:46 +0800 Subject: [PATCH 023/308] update csource dump --- modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs | 12 ++++++++---- modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index e2e53e6b19..946cfd4afa 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using System.IO; using System.Collections.Generic; using System.Linq; using System.Text; @@ -18,11 +19,13 @@ internal sealed class LinkableModule : ILinkableModule private readonly byte[] _rdata; private readonly IReadOnlyList _functions; + private readonly CompileOptions _options; - public LinkableModule(byte[] rdata, IReadOnlyList functions) + public LinkableModule(byte[] rdata, IReadOnlyList functions, CompileOptions options) { _rdata = rdata; _functions = functions; + _options = options; } public ILinkedModule Link(ILinkContext linkContext) @@ -31,9 +34,10 @@ public ILinkedModule Link(ILinkContext linkContext) var elfPath = CompileCSource(csourcePath); var text = File.ReadAllBytes(elfPath); - if (DumpScope.Current.IsEnabled(DumpFlags.CodeGen)) + if (_options.DumpFlags.HasFlag(DumpFlags.CodeGen)) { - using (var fs = DumpScope.Current.OpenFile("cpuModule.h")) + var dumpPath = _options.DumpDir; + using (var fs = File.Open(Path.Join(dumpPath, "cpuModule.h"), FileMode.Create)) { using (var writer = new StreamWriter(fs)) { @@ -41,7 +45,7 @@ public ILinkedModule Link(ILinkContext linkContext) } } - using (var fs = DumpScope.Current.OpenFile("cpuModule.c")) + using (var fs = File.Open(Path.Join(dumpPath, "cpuModule.c"), FileMode.Create)) { File.Open(csourcePath, FileMode.Open, FileAccess.Read).CopyTo(fs); } diff --git a/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs index c68b7b16da..36dd625463 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/ModuleBuilder.cs @@ -32,7 +32,7 @@ public ILinkableModule Build(IReadOnlyList functions) var linkableFunctions = functions.OfType().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter).Build(f)).ToArray(); _rdataWriter.Flush(); - return new LinkableModule(_rdataContent.ToArray(), linkableFunctions); + return new LinkableModule(_rdataContent.ToArray(), linkableFunctions, CompileOptions); } public void Dispose() => ((IDisposable)_rdataContent).Dispose(); From 11f469196b1cbff2a22e5d7e70842d14b29ac7ae Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 12:04:24 +0800 Subject: [PATCH 024/308] fix cpu runtime function --- .../Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs | 3 --- modules/cpu/src/runtime/cpu_common.h | 10 +++++----- modules/cpu/src/runtime/runtime_function.cpp | 6 +++--- modules/cpu/src/runtime/runtime_module.cpp | 12 +++++------- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs index e7e479017b..7b2d6e09bc 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs @@ -31,9 +31,6 @@ internal static string ContertUnary(Unary op, CSymbol[] arguments) string str; switch (op.UnaryOp) { - case UnaryOp.Neg: - str = $"!{input}"; - break; default: str = $"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower()}_{op.UnaryOp.ToString().ToLower()}{input}"; break; diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 70c22065b9..72ace7cd89 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -57,15 +57,15 @@ inline float float_unary_square(float x) { return x * x; } .float_unary_exp = expf, .float_unary_floor = floorf, .float_unary_log = logf, - .float_unary_logical_not = &float_unary_logical_not, - .float_unary_neg = &float_unary_neg, + .float_unary_logical_not = float_unary_logical_not, + .float_unary_neg = float_unary_neg, .float_unary_round = roundf, - .float_unary_rsqrt = &float_unary_rsqrt, - .float_unary_sign = &float_unary_sign, + .float_unary_rsqrt = float_unary_rsqrt, + .float_unary_sign = float_unary_sign, .float_unary_sin = sinf, .float_unary_sinh = sinhf, .float_unary_sqrt = sqrtf, - .float_unary_square = &float_unary_square, + .float_unary_square = float_unary_square, .float_unary_tanh = tanhf}; END_NS_NNCASE_RT_MODULE \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 2a6be1e344..a128173364 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -109,7 +109,7 @@ result cpu_runtime_function::initialize_core( result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, - value_t return_value) noexcept { + NNCASE_UNUSED value_t return_value) noexcept { try_var(id, module().find_id_by_function(this)); std::cout << "call " << id << std::endl; @@ -128,7 +128,7 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, // output buffer for (uint32_t i = 0; i < output_ranks_.size(); i++) { - auto output_tensor = parameters[i].as().expect( + auto output_tensor = parameters[input_ranks_.size() + i].as().expect( "output " + std::to_string(i) + " is not a tensor"); try_var(output_span, get_output_span(output_tensor)); buffer_t *output_buffer = @@ -147,5 +147,5 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, delete buffers[i]; } - return ok(return_value); + return ok(tuple(std::in_place)); } \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 1bdd121f9c..4bfc2fff11 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -24,13 +24,11 @@ using namespace nncase::runtime::cpu; result cpu_runtime_module::initialize_before_functions( runtime_module_init_context &context) noexcept { - if (!context.is_section_pinned()) - return nncase::err(std::errc::bad_address); - try_var(data, context.get_or_read_section(".data", data_storage_, false)); - try_var(rdata, context.get_or_read_section(".rdata", rdata_storage_, true)); - try_var(text, context.get_or_read_section(".text", text_storage_, true)); - - text_ = text.as_span(); + // if (!context.is_section_pinned()) + // return nncase::err(std::errc::bad_address); + // try_var(data, context.get_or_read_section(".data", data_storage_, false)); + // try_var(rdata, context.get_or_read_section(".rdata", rdata_storage_, true)); + try_set(text_, context.get_or_read_section(".text", text_storage_, true)); return ok(); } From b438407bd62e59a79fa2e9e8a0a907837bd658eb Mon Sep 17 00:00:00 2001 From: xhuohai Date: Tue, 25 Jul 2023 04:07:23 +0000 Subject: [PATCH 025/308] Apply code-format changes --- modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs | 2 +- modules/cpu/src/runtime/runtime_function.cpp | 7 ++++--- modules/cpu/src/runtime/runtime_module.cpp | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index 946cfd4afa..f668f441b4 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -2,8 +2,8 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; -using System.IO; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Text; using System.Threading.Tasks; diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index a128173364..f786c92a20 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -109,7 +109,7 @@ result cpu_runtime_function::initialize_core( result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, - NNCASE_UNUSED value_t return_value) noexcept { + NNCASE_UNUSED value_t return_value) noexcept { try_var(id, module().find_id_by_function(this)); std::cout << "call " << id << std::endl; @@ -128,8 +128,9 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, // output buffer for (uint32_t i = 0; i < output_ranks_.size(); i++) { - auto output_tensor = parameters[input_ranks_.size() + i].as().expect( - "output " + std::to_string(i) + " is not a tensor"); + auto output_tensor = + parameters[input_ranks_.size() + i].as().expect( + "output " + std::to_string(i) + " is not a tensor"); try_var(output_span, get_output_span(output_tensor)); buffer_t *output_buffer = new buffer_t(output_span.data(), 0, output_shapes_[i].data(), diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 4bfc2fff11..0b1b2efbe8 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -26,8 +26,9 @@ result cpu_runtime_module::initialize_before_functions( runtime_module_init_context &context) noexcept { // if (!context.is_section_pinned()) // return nncase::err(std::errc::bad_address); - // try_var(data, context.get_or_read_section(".data", data_storage_, false)); - // try_var(rdata, context.get_or_read_section(".rdata", rdata_storage_, true)); + // try_var(data, context.get_or_read_section(".data", data_storage_, + // false)); try_var(rdata, context.get_or_read_section(".rdata", + // rdata_storage_, true)); try_set(text_, context.get_or_read_section(".text", text_storage_, true)); return ok(); From 893b73dd7c4f0166728e990f1d0b1d23b8127d2c Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 13:44:36 +0800 Subject: [PATCH 026/308] fix warnings --- .../CodeGen/CSourceConvertVisitor.cs | 48 +++++++++---------- .../CodeGen/CSourceUtilities.cs | 3 +- .../CodeGen/FunctionCSource.cs | 4 +- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index ab4397910d..5dc9cec026 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -74,14 +74,14 @@ public CSymbol(string type, string name) internal sealed class IndentWriter : StringWriter { - public int Indent; - public IndentWriter(StringBuilder sb, int indent = 0) : base(sb) { Indent = indent; } + public int Indent { get; set; } + public void IndWrite(string? value) { for (int i = 0; i < Indent; i++) @@ -98,24 +98,24 @@ public void IndWrite(string? value) /// internal sealed class CSourceConvertVisitor : ExprFunctor { - public readonly Dictionary ExprMemo; + private readonly Dictionary _exprMemo; private readonly StringBuilder _implBuilder; public CSourceConvertVisitor() { _implBuilder = new StringBuilder(); - ExprMemo = new(ReferenceEqualityComparer.Instance); + _exprMemo = new(ReferenceEqualityComparer.Instance); } public FunctionCSource GetFunctionCSource() { - return new(ExprMemo[VisitRoot!].Type + ";", _implBuilder.ToString()); + return new(_exprMemo[VisitRoot!].Type + ";", _implBuilder.ToString()); } /// protected override CSymbol VisitPrimFunction(PrimFunction expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } @@ -133,7 +133,7 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) IndentScope.Writer.IndWrite($"{type} {{\n"); // 2. Function body - using (var _ = new IndentScope()) + using (_ = new IndentScope()) { Visit(expr.Body); } @@ -143,14 +143,14 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) } symbol = new(type, new(expr.Name)); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } /// protected override CSymbol VisitCall(Call expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } @@ -183,14 +183,14 @@ protected override CSymbol VisitCall(Call expr) } symbol = new(type, str); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } /// protected override CSymbol VisitConst(Const expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } @@ -215,14 +215,14 @@ protected override CSymbol VisitConst(Const expr) } symbol = new(type, str); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } /// protected override CSymbol VisitVar(Var expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } @@ -233,14 +233,14 @@ protected override CSymbol VisitVar(Var expr) } symbol = new(ttype.DType.ToC(), new($"{expr.Name}_{expr.GlobalVarIndex}")); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } /// protected override CSymbol VisitFor(For expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } @@ -258,14 +258,14 @@ protected override CSymbol VisitFor(For expr) IndentScope.Writer.IndWrite("}\n"); symbol = new(string.Empty, string.Empty); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } /// protected override CSymbol VisitSequential(Sequential expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } @@ -284,26 +284,26 @@ protected override CSymbol VisitSequential(Sequential expr) } symbol = new(string.Empty, string.Empty); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } /// protected override CSymbol VisitIfThenElse(IfThenElse expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } IndentScope.Writer.IndWrite($"if({Visit(expr.Condition).Name}) {{\n"); - using (var _ = new IndentScope()) + using (_ = new IndentScope()) { Visit(expr.Then); } IndentScope.Writer.IndWrite("} else {\n"); - using (var _ = new IndentScope()) + using (_ = new IndentScope()) { Visit(expr.Else); } @@ -311,19 +311,19 @@ protected override CSymbol VisitIfThenElse(IfThenElse expr) IndentScope.Writer.IndWrite("}\n"); symbol = new(string.Empty, string.Empty); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr) { - if (ExprMemo.TryGetValue(expr, out var symbol)) + if (_exprMemo.TryGetValue(expr, out var symbol)) { return symbol; } symbol = new(CSourceBuiltn.BufferType + "*", expr.Name); - ExprMemo.Add(expr, symbol); + _exprMemo.Add(expr, symbol); return symbol; } } diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs index 7b2d6e09bc..359ba4c3e9 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceUtilities.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System.Globalization; using Nncase.Diagnostics; using Nncase.IR.Math; @@ -32,7 +33,7 @@ internal static string ContertUnary(Unary op, CSymbol[] arguments) switch (op.UnaryOp) { default: - str = $"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower()}_{op.UnaryOp.ToString().ToLower()}{input}"; + str = $"nncase_mt->{arguments[0].Type}_{nameof(Unary).ToLower(CultureInfo.CurrentCulture)}_{op.UnaryOp.ToString().ToLower(CultureInfo.CurrentCulture)}{input}"; break; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index 347ff76d36..456d23649d 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -95,7 +95,7 @@ public string Compile(string sourcePath, string outPath) /// /// select current pattern's exe. /// - /// + /// NotSupportedException. private void PlatformSpecific() { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) @@ -142,7 +142,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath) return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; } - throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); + throw new System.NotSupportedException("Only Support Linux/Osx/Windows"); } } From 25c4fc296fd53a94480de676b5ca14d1c7c2aa5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 25 Jul 2023 15:26:58 +0800 Subject: [PATCH 027/308] fix bug --- .../CodeGen/CSourceBuiltn.cs | 2 - .../CodeGen/CSourceCompiler.cs | 152 ++++++++++++++++++ .../CodeGen/FunctionCSource.cs | 131 --------------- .../CodeGen/LinkableModule.cs | 15 +- .../Passes/Rules/LowerUnary.cs | 2 +- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 2 +- modules/cpu/src/runtime/elfload.cpp | 6 +- modules/cpu/src/runtime/elfloader.h | 3 +- modules/cpu/src/runtime/runtime_function.cpp | 8 +- 9 files changed, 174 insertions(+), 147 deletions(-) create mode 100644 modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index 6eb0ae594d..13b55dbb34 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -55,7 +55,5 @@ public static class CSourceBuiltn {MethodTable} {BufferStruct} - -static nncase_mt_t *nncase_mt; "; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs new file mode 100644 index 0000000000..9d34d5763a --- /dev/null +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs @@ -0,0 +1,152 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using Nncase.IR; +using Nncase.Schedule; +using Nncase.TIR; + +namespace Nncase.CodeGen; + +/// +/// the csource code compiler. +/// +public class CSourceCompiler +{ + /// + /// compiler exe name. + /// + private string _exe = string.Empty; + + /// + /// compiler exe name. + /// + private string _arch = string.Empty; + + /// + /// compiler exe name. + /// + private string _ext = string.Empty; + + public CSourceCompiler() + { + PlatformSpecific(); + ArchSpecific(); + } + + protected string Exe + { + get => _exe; + } + + protected string Arch + { + get => _arch; + } + + protected string Ext + { + get => _ext; + } + + /// + /// compile the source txt, write to the out_path. + /// + /// c source code. + /// out .so path. + /// outPath. + public string Compile(string sourcePath, string outPath) + { + var errMsg = new StringBuilder(); + using (var errWriter = new StringWriter(errMsg)) + { + using (var proc = new Process()) + { + proc.StartInfo.FileName = Exe; + proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); + proc.StartInfo.RedirectStandardError = true; + proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); + proc.Start(); + proc.BeginErrorReadLine(); + proc.WaitForExit(); + if (proc.ExitCode != 0) + { + throw new InvalidOperationException(errMsg.ToString()); + } + } + } + + return outPath; + } + + /// + /// create the temp dll file and compile source + /// . + /// + public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); + + /// + /// select current pattern's exe. + /// + /// NotSupportedException. + private void PlatformSpecific() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + _exe = "gcc"; + _ext = "so"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + _exe = "clang"; + _ext = "dylib"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _exe = "cmd"; + _ext = "dll"; + } + + if (System.Environment.GetEnvironmentVariable("NNCASE_CPU_COMPILER") is string exe) + { + _exe = exe; + } + } + + private void ArchSpecific() + { + _arch = RuntimeInformation.OSArchitecture switch + { + Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", + Architecture.Arm64 => "arm64", + _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), + }; + } + + private string ArgumentsSpecific(string sourcePath, string outPath) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return $"{sourcePath} -nostdlib -static -no-pie -fPIC -march={Arch} -o {outPath}"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return $"{sourcePath} -nostdlib -static -nopie -fPIC -arch {Arch} -o {outPath} -e__start"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); + var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); + return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; + } + + throw new System.NotSupportedException("Only Support Linux/Osx/Windows"); + } +} diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs index 456d23649d..d08db57507 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionCSource.cs @@ -15,137 +15,6 @@ namespace Nncase.CodeGen; -/// -/// the csource code compiler. -/// -public class CSourceCompiler -{ - /// - /// compiler exe name. - /// - private string _exe = string.Empty; - - /// - /// compiler exe name. - /// - private string _arch = string.Empty; - - /// - /// compiler exe name. - /// - private string _ext = string.Empty; - - public CSourceCompiler() - { - PlatformSpecific(); - ArchSpecific(); - } - - protected string Exe - { - get => _exe; - } - - protected string Arch - { - get => _arch; - } - - protected string Ext - { - get => _ext; - } - - /// - /// compile the source txt, write to the out_path. - /// - /// c source code. - /// out .so path. - /// outPath. - public string Compile(string sourcePath, string outPath) - { - var errMsg = new StringBuilder(); - using (var errWriter = new StringWriter(errMsg)) - { - using (var proc = new Process()) - { - proc.StartInfo.FileName = Exe; - proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); - proc.StartInfo.RedirectStandardError = true; - proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); - proc.Start(); - proc.BeginErrorReadLine(); - proc.WaitForExit(); - if (proc.ExitCode != 0) - { - throw new InvalidOperationException(errMsg.ToString()); - } - } - } - - return outPath; - } - - /// - /// create the temp dll file and compile source - /// . - /// - public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); - - /// - /// select current pattern's exe. - /// - /// NotSupportedException. - private void PlatformSpecific() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - _exe = "gcc"; - _ext = "so"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - _exe = "clang"; - _ext = "dylib"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - _exe = "cmd"; - _ext = "dll"; - } - } - - private void ArchSpecific() - { - _arch = RuntimeInformation.OSArchitecture switch - { - Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", - Architecture.Arm64 => "arm64", - _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), - }; - } - - private string ArgumentsSpecific(string sourcePath, string outPath) - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - return $"{sourcePath} -nostdlib -static -no-pie -fPIC -march={Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return $"{sourcePath} -fPIC -shared -arch {Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); - var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); - return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; - } - - throw new System.NotSupportedException("Only Support Linux/Osx/Windows"); - } -} - internal sealed class FunctionCSource { public FunctionCSource(string declaration, string implementation) diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs index f668f441b4..ad46c59d0b 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableModule.cs @@ -31,9 +31,6 @@ public LinkableModule(byte[] rdata, IReadOnlyList functions, C public ILinkedModule Link(ILinkContext linkContext) { var csourcePath = LinkCSources(); - var elfPath = CompileCSource(csourcePath); - var text = File.ReadAllBytes(elfPath); - if (_options.DumpFlags.HasFlag(DumpFlags.CodeGen)) { var dumpPath = _options.DumpDir; @@ -51,6 +48,18 @@ public ILinkedModule Link(ILinkContext linkContext) } } + var elfPath = CompileCSource(csourcePath); + var text = File.ReadAllBytes(elfPath); + + if (_options.DumpFlags.HasFlag(DumpFlags.CodeGen)) + { + var dumpPath = _options.DumpDir; + using (var fs = File.Open(Path.Join(dumpPath, "cpuModule.elf"), FileMode.Create)) + { + fs.Write(text); + } + } + var linkedFunctions = new List(); foreach (var func in _functions) { diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs index f0cbd5a862..a6182d5c27 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerUnary.cs @@ -24,7 +24,7 @@ public partial class LowerUnary : RewriteRule public override Pattern Pattern { get; } = IsUnary( target_name: "unary", _ => true, - IsWildcard("input")); + IsWildcard("input") with { TypePattern = IsFloat() }); private Expr GetReplace(Unary unary, Expr input) { diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 3f9c2bc9cc..5978d15620 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -41,7 +41,7 @@ public void RegisterTargetInDependentPass(IPassManager passManager, CompileOptio /// public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options) { - passManager.AddWithName("LowerIR").Configure(p => + passManager.AddWithName("LowerIR").Configure(p => { p.Add(); }); diff --git a/modules/cpu/src/runtime/elfload.cpp b/modules/cpu/src/runtime/elfload.cpp index cb712edba5..8a03d6c93c 100644 --- a/modules/cpu/src/runtime/elfload.cpp +++ b/modules/cpu/src/runtime/elfload.cpp @@ -140,8 +140,8 @@ el_status el_load(el_ctx *ctx, el_alloc_cb alloc) { if (!dest) return EL_ENOMEM; - printf("Loading seg fileoff %lx, vaddr %lx to %lx\n", ph.p_offset, - ph.p_vaddr, (uintptr_t)dest); + // printf("Loading seg fileoff %lx, vaddr %lx to %lx\n", ph.p_offset, + // ph.p_vaddr, (uintptr_t)dest); /* read loaded portion */ if ((rv = el_pread(ctx, dest, ph.p_filesz, ph.p_offset))) @@ -225,7 +225,7 @@ el_status el_relocate(el_ctx *ctx) { } size_t relcnt = ri.tablesize / sizeof(Elf_Rel); - Elf_Rel *reltab = base + ri.tableoff; + Elf_Rel *reltab = (Elf_Rel *)(base + ri.tableoff); for (size_t i = 0; i < relcnt; i++) { if ((rv = el_applyrel(ctx, &reltab[i]))) return rv; diff --git a/modules/cpu/src/runtime/elfloader.h b/modules/cpu/src/runtime/elfloader.h index 771a6b5f16..bdc0c5039e 100644 --- a/modules/cpu/src/runtime/elfloader.h +++ b/modules/cpu/src/runtime/elfloader.h @@ -18,7 +18,7 @@ typedef void (*entrypoint_t)(size_t id, buffer_t **buffers, class elfloader { public: - elfloader(char *elf) : elf_(elf) { + elfloader(char *elf) { ctx_.pread = bpread; ctx_.elf = elf; } @@ -55,7 +55,6 @@ class elfloader { private: void *ptr_; void *buf_; - char *elf_; el_ctx ctx_; }; diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index f786c92a20..61afac5746 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -121,8 +121,8 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, "input " + std::to_string(i) + " is not a tensor"); try_var(input_span, get_input_span(input_tensor)); buffer_t *input_buffer = - new buffer_t(input_span.data(), 0, input_shapes_[i].data(), - input_strides_[i].data(), input_ranks_[i]); + new buffer_t{input_span.data(), 0, input_shapes_[i].data(), + input_strides_[i].data(), input_ranks_[i]}; buffers[i] = input_buffer; } @@ -133,8 +133,8 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, "output " + std::to_string(i) + " is not a tensor"); try_var(output_span, get_output_span(output_tensor)); buffer_t *output_buffer = - new buffer_t(output_span.data(), 0, output_shapes_[i].data(), - output_strides_[i].data(), output_ranks_[i]); + new buffer_t{output_span.data(), 0, output_shapes_[i].data(), + output_strides_[i].data(), output_ranks_[i]}; buffers[input_ranks_.size() + i] = output_buffer; } From 79e010a9b0f85ddcf83fd6e91efbc950a4f5ed09 Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Tue, 25 Jul 2023 07:29:57 +0000 Subject: [PATCH 028/308] Apply code-format changes --- modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs | 2 +- modules/cpu/src/runtime/elfload.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs index 9d34d5763a..6ab53936fc 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; diff --git a/modules/cpu/src/runtime/elfload.cpp b/modules/cpu/src/runtime/elfload.cpp index 8a03d6c93c..23e4b355e9 100644 --- a/modules/cpu/src/runtime/elfload.cpp +++ b/modules/cpu/src/runtime/elfload.cpp @@ -141,7 +141,7 @@ el_status el_load(el_ctx *ctx, el_alloc_cb alloc) { return EL_ENOMEM; // printf("Loading seg fileoff %lx, vaddr %lx to %lx\n", ph.p_offset, - // ph.p_vaddr, (uintptr_t)dest); + // ph.p_vaddr, (uintptr_t)dest); /* read loaded portion */ if ((rv = el_pread(ctx, dest, ph.p_filesz, ph.p_offset))) From ebe45fd4555762832b96f9b5d1956b01b272bade Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 15:52:50 +0800 Subject: [PATCH 029/308] add bianry method table --- .../CodeGen/CSourceBuiltn.cs | 38 ++++++- modules/cpu/src/runtime/cpu_common.h | 102 +++++++++++++++++- 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index 13b55dbb34..269e8e5492 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -16,6 +16,7 @@ public static class CSourceBuiltn } buffer_t;"; public const string MethodTable = @"typedef struct nncase_method_table { + // float unary float (*float_unary_abs)(float); float (*float_unary_acos)(float); float (*float_unary_acosh)(float); @@ -37,9 +38,44 @@ public static class CSourceBuiltn float (*float_unary_sqrt)(float); float (*float_unary_square)(float); float (*float_unary_tanh)(float); + // float bianry + float (*float_binary_add)(float, float); + float (*float_binary_sub)(float, float); + float (*float_binary_mul)(float, float); + float (*float_binary_div)(float, float); + float (*float_binary_min)(float, float); + float (*float_binary_max)(float, float); + float (*float_binary_pow)(float, float); + float (*float_binary_logical_and)(float, float); + float (*float_binary_mod)(float, float); + // int32 bianry + int32_t (*int32_binary_add)(int32_t, int32_t); + int32_t (*int32_binary_sub)(int32_t, int32_t); + int32_t (*int32_binary_mul)(int32_t, int32_t); + int32_t (*int32_binary_div)(int32_t, int32_t); + int32_t (*int32_binary_min)(int32_t, int32_t); + int32_t (*int32_binary_max)(int32_t, int32_t); + int32_t (*int32_binary_pow)(int32_t, int32_t); + int32_t (*int32_binary_logical_and)(int32_t, int32_t); + int32_t (*int32_binary_mod)(int32_t, int32_t); + // int64 bianry + int64_t (*int64_binary_add)(int64_t, int64_t); + int64_t (*int64_binary_sub)(int64_t, int64_t); + int64_t (*int64_binary_mul)(int64_t, int64_t); + int64_t (*int64_binary_div)(int64_t, int64_t); + int64_t (*int64_binary_min)(int64_t, int64_t); + int64_t (*int64_binary_max)(int64_t, int64_t); + int64_t (*int64_binary_pow)(int64_t, int64_t); + int64_t (*int64_binary_logical_and)(int64_t, int64_t); + int64_t (*int64_binary_mod)(int64_t, int64_t); + // bool binary + bool (*bool_binary_and)(bool, bool); + bool (*bool_binary_or)(bool, bool); + bool (*bool_binary_xor)(bool, bool); } nncase_mt_t;"; - public const string Include = @"#include + public const string Include = @"#include +#include #include "; diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 72ace7cd89..047ea7ce31 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -2,12 +2,14 @@ #include #include #include +#include #include #include BEGIN_NS_NNCASE_RT_MODULE(cpu) typedef struct nncase_method_table { + // float unary float (*float_unary_abs)(float); float (*float_unary_acos)(float); float (*float_unary_acosh)(float); @@ -29,6 +31,40 @@ typedef struct nncase_method_table { float (*float_unary_sqrt)(float); float (*float_unary_square)(float); float (*float_unary_tanh)(float); + // float bianry + float (*float_binary_add)(float, float); + float (*float_binary_sub)(float, float); + float (*float_binary_mul)(float, float); + float (*float_binary_div)(float, float); + float (*float_binary_min)(float, float); + float (*float_binary_max)(float, float); + float (*float_binary_pow)(float, float); + float (*float_binary_logical_and)(float, float); + float (*float_binary_mod)(float, float); + // int32 bianry + int32_t (*int32_binary_add)(int32_t, int32_t); + int32_t (*int32_binary_sub)(int32_t, int32_t); + int32_t (*int32_binary_mul)(int32_t, int32_t); + int32_t (*int32_binary_div)(int32_t, int32_t); + int32_t (*int32_binary_min)(int32_t, int32_t); + int32_t (*int32_binary_max)(int32_t, int32_t); + int32_t (*int32_binary_pow)(int32_t, int32_t); + int32_t (*int32_binary_logical_and)(int32_t, int32_t); + int32_t (*int32_binary_mod)(int32_t, int32_t); + // int64 bianry + int64_t (*int64_binary_add)(int64_t, int64_t); + int64_t (*int64_binary_sub)(int64_t, int64_t); + int64_t (*int64_binary_mul)(int64_t, int64_t); + int64_t (*int64_binary_div)(int64_t, int64_t); + int64_t (*int64_binary_min)(int64_t, int64_t); + int64_t (*int64_binary_max)(int64_t, int64_t); + int64_t (*int64_binary_pow)(int64_t, int64_t); + int64_t (*int64_binary_logical_and)(int64_t, int64_t); + int64_t (*int64_binary_mod)(int64_t, int64_t); + // bool binary + bool (*bool_binary_and)(bool, bool); + bool (*bool_binary_or)(bool, bool); + bool (*bool_binary_xor)(bool, bool); } nncase_mt_t; typedef struct buffer { @@ -45,6 +81,40 @@ inline float float_unary_rsqrt(float x) { return 1.f / sqrtf(x); } inline float float_unary_sign(float x) { return (0.f < x) - (x < 0.f); } inline float float_unary_square(float x) { return x * x; } +inline float float_binary_add(float x, float y) { return x + y; } +inline float float_binary_sub(float x, float y) { return x - y; } +inline float float_binary_mul(float x, float y) { return x * y; } +inline float float_binary_div(float x, float y) { return x / y; } +inline float float_binary_min(float x, float y) { return std::min(x, y); } +inline float float_binary_max(float x, float y) { return std::max(x, y); } +inline float float_binary_pow(float x, float y) { return powf(x, y); } +inline float float_binary_logical_and(float x, float y) { return x && y; } +inline float float_binary_mod(float x, float y) { return fmod(x, y); } + +inline int32_t int32_binary_add(int32_t x, int32_t y) { return x + y; } +inline int32_t int32_binary_sub(int32_t x, int32_t y) { return x - y; } +inline int32_t int32_binary_mul(int32_t x, int32_t y) { return x * y; } +inline int32_t int32_binary_div(int32_t x, int32_t y) { return x / y; } +inline int32_t int32_binary_min(int32_t x, int32_t y) { return std::min(x, y); } +inline int32_t int32_binary_max(int32_t x, int32_t y) { return std::max(x, y); } +inline int32_t int32_binary_pow(int32_t x, int32_t y) { return std::pow(x, y); } +inline int32_t int32_binary_logical_and(int32_t x, int32_t y) { return x && y; } +inline int32_t int32_binary_mod(int32_t x, int32_t y) { return x % y; } + +inline int64_t int64_binary_add(int64_t x, int64_t y) { return x + y; } +inline int64_t int64_binary_sub(int64_t x, int64_t y) { return x - y; } +inline int64_t int64_binary_mul(int64_t x, int64_t y) { return x * y; } +inline int64_t int64_binary_div(int64_t x, int64_t y) { return x / y; } +inline int64_t int64_binary_min(int64_t x, int64_t y) { return std::min(x, y); } +inline int64_t int64_binary_max(int64_t x, int64_t y) { return std::max(x, y); } +inline int64_t int64_binary_pow(int64_t x, int64_t y) { return std::pow(x, y); } +inline int64_t int64_binary_logical_and(int64_t x, int64_t y) { return x && y; } +inline int64_t int64_binary_mod(int64_t x, int64_t y) { return x % y; } + +inline bool bool_binary_logical_and(bool x, bool y) { return x && y; } +inline bool bool_binary_logical_or(bool x, bool y) { return x || y; } +inline bool bool_binary_logical_xor(bool x, bool y) { return x ^ y; } + [[maybe_unused]] static nncase_mt_t nncase_mt = { .float_unary_abs = fabsf, .float_unary_acos = acosf, @@ -66,6 +136,36 @@ inline float float_unary_square(float x) { return x * x; } .float_unary_sinh = sinhf, .float_unary_sqrt = sqrtf, .float_unary_square = float_unary_square, - .float_unary_tanh = tanhf}; + .float_unary_tanh = tanhf, + .float_binary_add = float_binary_add, + .float_binary_sub = float_binary_sub, + .float_binary_mul = float_binary_mul, + .float_binary_div = float_binary_div, + .float_binary_min = float_binary_min, + .float_binary_max = float_binary_max, + .float_binary_pow = float_binary_pow, + .float_binary_logical_and = float_binary_logical_and, + .float_binary_mod = float_binary_mod, + .int32_binary_add = int32_binary_add, + .int32_binary_sub = int32_binary_sub, + .int32_binary_mul = int32_binary_mul, + .int32_binary_div = int32_binary_div, + .int32_binary_min = int32_binary_min, + .int32_binary_max = int32_binary_max, + .int32_binary_pow = int32_binary_pow, + .int32_binary_logical_and = int32_binary_logical_and, + .int32_binary_mod = int32_binary_mod, + .int64_binary_add = int64_binary_add, + .int64_binary_sub = int64_binary_sub, + .int64_binary_mul = int64_binary_mul, + .int64_binary_div = int64_binary_div, + .int64_binary_min = int64_binary_min, + .int64_binary_max = int64_binary_max, + .int64_binary_pow = int64_binary_pow, + .int64_binary_logical_and = int64_binary_logical_and, + .int64_binary_mod = int64_binary_mod, + .bool_binary_and = bool_binary_logical_and, + .bool_binary_or = bool_binary_logical_or, + .bool_binary_xor = bool_binary_logical_xor}; END_NS_NNCASE_RT_MODULE \ No newline at end of file From 8b7b8ca3b03f02bd2a896c3ef42efd28ca8874a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 25 Jul 2023 16:41:20 +0800 Subject: [PATCH 030/308] change make fusion --- .../Passes/Rules/LowerMatMul.cs | 34 ++++++++++++++++++ .../Passes/Rules/MakeFusion.cs | 35 ++++++++++--------- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 1 + 3 files changed, 54 insertions(+), 16 deletions(-) create mode 100644 modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs new file mode 100644 index 0000000000..81b0c5a3c3 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs @@ -0,0 +1,34 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +using static Nncase.IR.F.CPU; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules; + +[RuleGenerator] +public partial class LowerMatMul : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsMatMul( + target_name: "matmul", + _ => true, + IsWildcard("inputA") with { TypePattern = IsFloat() }, + IsWildcard("inputB") with { TypePattern = IsFloat() }); + + private Expr GetReplace(MatMul matmul, Expr inputA, Expr inputB) + { + return CPUKernel(matmul, inputA, inputB); + } +} diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs index d7644bf2b5..29d10bd7ad 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs @@ -18,26 +18,29 @@ namespace Nncase.Passes.Rules; [RuleGenerator] -internal partial class CPUSingleInputFusion : FusionMaker - where T : Op +internal sealed partial class CPUFusion : FusionMaker { public override string ModuleKind { get; } = CPUTarget.Kind; - public override Pattern Pattern { get; } = IsCallWildcard( - "call", - IsOp("op"), - IsWildcard("input")); + public override Pattern Pattern => IsCallWildcard("call", IsOp("op")); - private Call? GetReplace(Call call, IReadOnlyList callParams, Op op, Expr input) + private Call? GetReplace(Call call, CPUKernelOp op, IReadOnlyList callParams) { - var newInput = new Var(input.CheckedType!); - var newCall = ReplaceCallParams(op, callParams, (input, newInput)); - var fusion = new Call(new Fusion(FullName, ModuleKind, newCall, new[] { newInput }), input); - return fusion; - } -} + var newInputs = new List(); + for (int i = 0; i < callParams.Count; i++) + { + if (callParams[i] is (Call or Var)) + { + newInputs.Add(new Var(callParams[i].CheckedType!)); + } + else + { + newInputs.Add(callParams[i]); + } + } -internal sealed class CPUFusion : CPUSingleInputFusion -{ - public override string Name => nameof(CPUFusion); + var newCall = new Call(op, newInputs.ToArray()); + var callFusion = new Call(new Fusion(FullName, ModuleKind, newCall, newInputs.OfType().ToArray()), newInputs.Select((e, i) => (e, i)).Where(p => p.e is Var).Select(p => callParams[p.i]).ToArray()); + return callFusion; + } } diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 5978d15620..01cf36f618 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -44,6 +44,7 @@ public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions passManager.AddWithName("LowerIR").Configure(p => { p.Add(); + p.Add(); }); } From a251e59287eaff78f969d41cea6d9f1fc889a5dc Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Tue, 25 Jul 2023 08:44:21 +0000 Subject: [PATCH 031/308] Apply code-format changes --- modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs | 2 +- modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs index 81b0c5a3c3..d23b542ff9 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs index 29d10bd7ad..d372d36037 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs @@ -29,7 +29,7 @@ internal sealed partial class CPUFusion : FusionMaker var newInputs = new List(); for (int i = 0; i < callParams.Count; i++) { - if (callParams[i] is (Call or Var)) + if (callParams[i] is Call or Var) { newInputs.Add(new Var(callParams[i].CheckedType!)); } From e6894e3faf05a46a0e4c0047f600496c58c2ae33 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 16:45:45 +0800 Subject: [PATCH 032/308] add cpu binary lower --- .../Passes/Rules/LowerBinary.cs | 40 +++++++++++++++++++ .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 1 + 2 files changed, 41 insertions(+) create mode 100644 modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs new file mode 100644 index 0000000000..ef1c431f99 --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs @@ -0,0 +1,40 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.PatternMatch; + +using static Nncase.IR.F.CPU; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules; + +[RuleGenerator] +public partial class LowerBinary : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsBinary( + target_name: "binary", + _ => true, + IsWildcard("lhs") with { TypePattern = IsFloat() }, + IsWildcard("rhs") with { TypePattern = IsFloat() } + ); + + private Expr? GetReplace(Binary binary, Expr lhs, Expr rhs) + { + if (lhs.CheckedShape.Rank != rhs.CheckedShape.Rank) + { + return null; + } + + return CPUKernel(binary, lhs, rhs); + } +} diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 01cf36f618..1b3504f305 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -43,6 +43,7 @@ public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions { passManager.AddWithName("LowerIR").Configure(p => { + p.Add(); p.Add(); p.Add(); }); From 1ce70232c830871a8119fa5bd0bcc1f02da1f24f Mon Sep 17 00:00:00 2001 From: xhuohai Date: Tue, 25 Jul 2023 08:48:38 +0000 Subject: [PATCH 033/308] Apply code-format changes --- modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs index ef1c431f99..ffef820490 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerBinary.cs @@ -25,16 +25,15 @@ public partial class LowerBinary : RewriteRule target_name: "binary", _ => true, IsWildcard("lhs") with { TypePattern = IsFloat() }, - IsWildcard("rhs") with { TypePattern = IsFloat() } - ); + IsWildcard("rhs") with { TypePattern = IsFloat() }); private Expr? GetReplace(Binary binary, Expr lhs, Expr rhs) { - if (lhs.CheckedShape.Rank != rhs.CheckedShape.Rank) - { + if (lhs.CheckedShape.Rank != rhs.CheckedShape.Rank) + { return null; } - + return CPUKernel(binary, lhs, rhs); } } From bf162feadf0fb02c461a19cc7410c4927f56d4b4 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 25 Jul 2023 16:51:57 +0800 Subject: [PATCH 034/308] draft binary to tir --- .../Passes/Tile/SingleCPUFusionConverter.cs | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 1e54700651..c7818713ec 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -68,6 +68,9 @@ protected override Unit VisitLeafCall(Call expr) case Unary unary: GenerateUnary(unary, arguments, ret); break; + case Binary binary: + GenerateBinary(binary, arguments, ret, expr); + break; default: throw new NotSupportedException(); } @@ -86,6 +89,30 @@ private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer r _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); } + private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffer ret, Call call) + { + var lhs = call[Binary.Lhs]; + var rhs = call[Binary.Rhs]; + var lhsBuffer = arguments[Binary.Lhs.Index]; + var rhsBuffer = arguments[Binary.Rhs.Index]; + + var outShape = call.CheckedShape.ToValueArray(); + var lhsShape = Enumerable.Repeat(1, outShape.Length).ToArray(); + Array.Copy(lhs.CheckedShape.ToValueArray(), 0, lhsShape, lhsShape.Length - lhs.CheckedShape.Rank, lhs.CheckedShape.Rank); + var rhsShape = Enumerable.Repeat(1, outShape.Length).ToArray(); + Array.Copy(rhs.CheckedShape.ToValueArray(), 0, rhsShape, rhsShape.Length - rhs.CheckedShape.Rank, rhs.CheckedShape.Rank); + + var lhsScale = outShape.Zip(lhsShape).Select(s => s.First / s.Second).ToArray(); + var rhsScale = outShape.Zip(rhsShape).Select(s => s.First / s.Second).ToArray(); + + var loops = Enumerable.Range(0, outShape.Length).Select(i => (T.ForLoop(out var loopVar, (0, outShape[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + var input_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * loops[i].loopVar)); + var output_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (ret.Strides[i] * loops[i].loopVar)); + Expr stmt = T.Store(ret, output_index, IR.F.Math.Unary(unary.UnaryOp, T.Load(input, output_index))); + var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); + _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); + } + private TIR.Buffer TryAllocateBuffer(Expr expr) { var name = $"buffer_{_buffersMap.Keys.Count}"; From 58af7d5f4b9f04632ef46087c9ed98d244f574b5 Mon Sep 17 00:00:00 2001 From: xhuohai Date: Tue, 25 Jul 2023 08:54:57 +0000 Subject: [PATCH 035/308] Apply code-format changes --- .../Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index c7818713ec..4c3aa9045e 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -70,7 +70,7 @@ protected override Unit VisitLeafCall(Call expr) break; case Binary binary: GenerateBinary(binary, arguments, ret, expr); - break; + break; default: throw new NotSupportedException(); } From 852c7ac5d698f9fdae08b840259018d78c6332de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 25 Jul 2023 19:13:00 +0800 Subject: [PATCH 036/308] remove buffer load/store --- src/Nncase.Core/IR/Buffers/BufferLoad.cs | 22 ++++++++ src/Nncase.Core/IR/Buffers/BufferStore.cs | 24 +++++++++ src/Nncase.Core/IR/ExprCloner.g.cs | 19 ------- src/Nncase.Core/IR/ExprFunctor.g.cs | 24 --------- src/Nncase.Core/IR/ExprRewriter.g.cs | 38 ------------- src/Nncase.Core/IR/ExprVisitor.g.cs | 54 ------------------- src/Nncase.Core/TIR/BufferLoad.cs | 40 -------------- src/Nncase.Core/TIR/BufferStore.cs | 47 ---------------- src/Nncase.Core/TIR/Script.cs | 16 ++++++ .../Diagnostics/ScriptPrintVisitor.cs | 28 ---------- src/Nncase.Evaluator/Buffer.cs | 9 ---- src/Nncase.Evaluator/Buffers/BufferLoad.cs | 41 ++++++++++++++ src/Nncase.Evaluator/Buffers/BufferStore.cs | 47 ++++++++++++++++ src/Nncase.Evaluator/TIR/Store.cs | 2 - src/Nncase.Evaluator/TypeInferenceVisitor.cs | 48 ----------------- 15 files changed, 150 insertions(+), 309 deletions(-) create mode 100644 src/Nncase.Core/IR/Buffers/BufferLoad.cs create mode 100644 src/Nncase.Core/IR/Buffers/BufferStore.cs delete mode 100644 src/Nncase.Core/TIR/BufferLoad.cs delete mode 100644 src/Nncase.Core/TIR/BufferStore.cs delete mode 100644 src/Nncase.Evaluator/Buffer.cs create mode 100644 src/Nncase.Evaluator/Buffers/BufferLoad.cs create mode 100644 src/Nncase.Evaluator/Buffers/BufferStore.cs diff --git a/src/Nncase.Core/IR/Buffers/BufferLoad.cs b/src/Nncase.Core/IR/Buffers/BufferLoad.cs new file mode 100644 index 0000000000..4885197c3f --- /dev/null +++ b/src/Nncase.Core/IR/Buffers/BufferLoad.cs @@ -0,0 +1,22 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.Buffers; + +/// +/// BufferIndexOf expression. +/// +[PatternFunctionalGenerator] +public sealed partial class BufferLoad : Op +{ + /// + /// Get the input parameter. + /// + public static readonly ParameterInfo Input = new(typeof(BufferLoad), 0, "input", IsTensor()); + + public static readonly ParameterInfo Indices = new(typeof(BufferLoad), 1, "indices", IsTuple()); +} diff --git a/src/Nncase.Core/IR/Buffers/BufferStore.cs b/src/Nncase.Core/IR/Buffers/BufferStore.cs new file mode 100644 index 0000000000..04f57d594b --- /dev/null +++ b/src/Nncase.Core/IR/Buffers/BufferStore.cs @@ -0,0 +1,24 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.Buffers; + +/// +/// BufferIndexOf expression. +/// +[PatternFunctionalGenerator] +public sealed partial class BufferStore : Op +{ + /// + /// Get the input parameter. + /// + public static readonly ParameterInfo Input = new(typeof(BufferStore), 0, "input", IsTensor()); + + public static readonly ParameterInfo Indices = new(typeof(BufferStore), 1, "indices", IsTuple()); + + public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsTensor()); +} diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs index 78d5d2bda4..7357f7f78b 100644 --- a/src/Nncase.Core/IR/ExprCloner.g.cs +++ b/src/Nncase.Core/IR/ExprCloner.g.cs @@ -157,15 +157,6 @@ protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContex ); } - /// - protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) - { - return expr.With( - buffer: Clone(expr.Buffer, context), - indices: CloneArray(expr.Indices, context) - ); - } - /// protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -175,16 +166,6 @@ protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext co ); } - /// - protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context) - { - return expr.With( - buffer: Clone(expr.Buffer, context), - indices: CloneArray(expr.Indices, context), - value: Clone(expr.Value, context) - ); - } - /// protected override Expr VisitLeafFor(TIR.For expr, TContext context) { diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs index 57e1ef86d5..f6ff8fd928 100644 --- a/src/Nncase.Core/IR/ExprFunctor.g.cs +++ b/src/Nncase.Core/IR/ExprFunctor.g.cs @@ -104,21 +104,11 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context); - /// /// Visit . /// @@ -289,13 +279,6 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); @@ -303,13 +286,6 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default); diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs index 2f0b6c1233..6695315197 100644 --- a/src/Nncase.Core/IR/ExprRewriter.g.cs +++ b/src/Nncase.Core/IR/ExprRewriter.g.cs @@ -122,24 +122,12 @@ protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, return RewriteLeafPhysicalBuffer(expr, context); } - /// - protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) - { - return RewriteLeafBufferLoad(expr, context); - } - /// protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) { return RewriteLeafBufferRegion(expr, context); } - /// - protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context) - { - return RewriteLeafBufferStore(expr, context); - } - /// protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context) { @@ -272,21 +260,11 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// /// Rewrite leaf . /// protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// /// Rewrite leaf . /// @@ -474,14 +452,6 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr); - - /// - protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr); - /// /// Rewrite leaf . /// @@ -490,14 +460,6 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr); - - /// - protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr); - /// /// Rewrite leaf . /// diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index 5e5a609a6a..c56e4f5aa7 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -131,13 +131,6 @@ protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer e return VisitLeafPhysicalBuffer(expr, context); } - /// - protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafBufferLoad(expr, context); - } - /// protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -145,13 +138,6 @@ protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, return VisitLeafBufferRegion(expr, context); } - /// - protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafBufferStore(expr, context); - } - /// protected internal override TExprResult VisitFor(TIR.For expr, TContext context) { @@ -302,21 +288,11 @@ protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext /// protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context); - /// /// Visit leaf . /// protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context); - /// /// Visit leaf . /// @@ -467,13 +443,6 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); @@ -481,13 +450,6 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default); @@ -688,14 +650,6 @@ public partial class ExprVisitor /// protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default); - - /// - protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr); - /// /// Visit leaf . /// @@ -704,14 +658,6 @@ public partial class ExprVisitor /// protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default); - - /// - protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr); - /// /// Visit leaf . /// diff --git a/src/Nncase.Core/TIR/BufferLoad.cs b/src/Nncase.Core/TIR/BufferLoad.cs deleted file mode 100644 index 86081624dd..0000000000 --- a/src/Nncase.Core/TIR/BufferLoad.cs +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Nncase.IR; -using Nncase.Utilities; - -namespace Nncase.TIR; - -/// -/// Buffer load node. -/// -public sealed class BufferLoad : Expr -{ - public BufferLoad(PhysicalBuffer buffer, ReadOnlySpan indices) - : base(ArrayUtility.Concat(buffer, indices)) - { - } - - /// - /// Gets the buffer to be loaded. - /// - public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0]; - - /// - /// Gets the buffer indices. - /// - public ReadOnlySpan Indices => Operands.Slice(1); - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitBufferLoad(this, context); - - public BufferLoad With(PhysicalBuffer? buffer = null, Expr[]? indices = null) - => new BufferLoad(buffer ?? Buffer, indices ?? Indices); -} diff --git a/src/Nncase.Core/TIR/BufferStore.cs b/src/Nncase.Core/TIR/BufferStore.cs deleted file mode 100644 index 56d6a9df4d..0000000000 --- a/src/Nncase.Core/TIR/BufferStore.cs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Nncase.IR; - -namespace Nncase.TIR; - -/// -/// Buffer store node. -/// -public sealed class BufferStore : Expr -{ - private readonly int _indicesCount; - - public BufferStore(PhysicalBuffer buffer, ReadOnlySpan indices, Expr value) - : base(new Expr[] { buffer }.Concat(indices.ToArray()).Append(value).ToArray()) - { - _indicesCount = indices.Length; - } - - /// - /// Gets the buffer. - /// - public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0]; - - /// - /// Gets the value we to be stored. - /// - public ReadOnlySpan Indices => Operands[1.._indicesCount]; - - /// - /// Gets the indices location to be stored. - /// - public Expr Value => Operands[_indicesCount + 1]; - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitBufferStore(this, context); - - public BufferStore With(PhysicalBuffer? buffer = null, Expr[]? indices = null, Expr? value = null) - => new BufferStore(buffer ?? Buffer, indices ?? Indices, value ?? Value); -} diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index 8f640a0419..fd89d8661a 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -313,4 +313,20 @@ public static Call Emit(out T value, Func creator) value = creator(); return Nop(); } + + /// + /// buffer load. + /// + /// buffer. + /// indices. + /// call bufferload. + public static Call BufferLoad(TIR.Buffer buffer, params Expr[] indices) => new Call(new IR.Buffers.BufferLoad(), buffer, new IR.Tuple(indices)); + + /// + /// buffer store. + /// + /// buffer. + /// indices and value. + /// buffer store. + public static Call BufferStore(TIR.Buffer buffer, params Expr[] indicesAndValue) => new Call(new IR.Buffers.BufferLoad(), buffer, new IR.Tuple(indicesAndValue[..^1]), indicesAndValue[^1]); } diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs index 21047b326d..9dae59bd85 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs @@ -481,34 +481,6 @@ protected override IPrintSymbol VisitBlock(Block expr) return doc; } - /// - protected override IPrintSymbol VisitBufferLoad(BufferLoad expr) - { - if (_exprMemo.TryGetValue(expr, out var doc)) - { - return doc; - } - - _scope.Push(); - _scope.Append($"{expr.Buffer.Name}[{string.Join(", ", expr.Indices.ToArray().Select(Visit))}]"); - doc = new(_scope.Pop()); - return doc; - } - - /// - protected override IPrintSymbol VisitBufferStore(BufferStore expr) - { - if (_exprMemo.TryGetValue(expr, out var doc)) - { - return doc; - } - - _scope.Push(); - _scope.Append($"{expr.Buffer.Name}[{string.Join(", ", expr.Indices.ToArray().Select(Visit))}] = {Visit(expr.Value)}"); - doc = new(_scope.Pop()); - return doc; - } - /// protected override IPrintSymbol VisitIterVar(IterVar expr) { diff --git a/src/Nncase.Evaluator/Buffer.cs b/src/Nncase.Evaluator/Buffer.cs deleted file mode 100644 index 079cdaecce..0000000000 --- a/src/Nncase.Evaluator/Buffer.cs +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -namespace Nncase.Evaluator.TIR -{ - public class Buffer - { - } -} diff --git a/src/Nncase.Evaluator/Buffers/BufferLoad.cs b/src/Nncase.Evaluator/Buffers/BufferLoad.cs new file mode 100644 index 0000000000..0069028c32 --- /dev/null +++ b/src/Nncase.Evaluator/Buffers/BufferLoad.cs @@ -0,0 +1,41 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.IR.Buffers; + +namespace Nncase.Evaluator.Buffers; + +/// +/// Evaluator for BufferOf. +/// +[TypeInferGenerator] +public partial class BufferLoadEvaluator : ITypeInferencer, IOpPrinter +{ + public string Visit(IIRPrinterContext context, BufferLoad target, bool iLmode) + { + if (iLmode) + { + throw new System.NotSupportedException(); + } + return $"{context.GetArgument(target, BufferLoad.Input)}[{context.GetArgument(target, BufferLoad.Indices)}]"; + } + + private IRType Visit(TensorType input, TupleType indices) + { + if (indices.Count != input.Shape.Rank) + { + return new InvalidType($"the input buffer rank {input.Shape.Rank} != indices.Count {indices.Count}"); + } + + foreach (var item in indices) + { + if (item is not TensorType { IsScalar: true, DType: var dtype } || dtype != DataTypes.Int32) + { + return new InvalidType("indices is not int32 type!"); + } + } + + return TensorType.Scalar(input.DType); + } +} diff --git a/src/Nncase.Evaluator/Buffers/BufferStore.cs b/src/Nncase.Evaluator/Buffers/BufferStore.cs new file mode 100644 index 0000000000..2d0020486b --- /dev/null +++ b/src/Nncase.Evaluator/Buffers/BufferStore.cs @@ -0,0 +1,47 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.IR.Buffers; + +namespace Nncase.Evaluator.Buffers; + +/// +/// Evaluator for BufferOf. +/// +[TypeInferGenerator] +public partial class BufferStoreEvaluator : ITypeInferencer, IOpPrinter +{ + public string Visit(IIRPrinterContext context, BufferStore target, bool iLmode) + { + if (iLmode) + { + throw new System.NotSupportedException(); + } + + return $"{context.GetArgument(target, BufferStore.Input)}[{context.GetArgument(target, BufferStore.Indices)}] = {context.GetArgument(target, BufferStore.Value)}"; + } + + private IRType Visit(TensorType input, TupleType indices, TensorType value) + { + if (indices.Count != input.Shape.Rank) + { + return new InvalidType($"the input buffer rank {input.Shape.Rank} != indices.Count {indices.Count}"); + } + + foreach (var item in indices) + { + if (item is not TensorType { IsScalar: true, DType: var dtype } || dtype != DataTypes.Int32) + { + return new InvalidType("indices is not int32 type!"); + } + } + + if (!value.IsScalar || input.DType != value.DType) + { + return new InvalidType("value can't store!"); + } + + return TupleType.Void; + } +} diff --git a/src/Nncase.Evaluator/TIR/Store.cs b/src/Nncase.Evaluator/TIR/Store.cs index e0b6d5bcc7..9a1d6d6cda 100644 --- a/src/Nncase.Evaluator/TIR/Store.cs +++ b/src/Nncase.Evaluator/TIR/Store.cs @@ -27,8 +27,6 @@ public string Visit(IIRPrinterContext context, Store target, bool iLmode) _ = context.GetArgument(target, Store.Value); var index = context.GetArgument(target, Store.Index); return $"{handle}[{index}] = {index}"; - - throw new System.NotImplementedException(); } private IRType Visit(Store target, TensorType handle, TensorType index, TensorType value) diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index 208d085058..c0f9f28f55 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -52,28 +52,6 @@ protected override IRType VisitLeafBlock(Block expr) return TupleType.Void; } - /// - protected override IRType VisitLeafBufferLoad(BufferLoad expr) - { - IRType type; - VerifySubField(expr, expr.Buffer, TypePatternUtility.IsPointer()); - for (int i = 0; i < expr.Indices.Length; i++) - { - VerifySubField(expr, expr.Indices[i], TypePatternUtility.IsIntegralScalar(), $"BufferLoad.Indices[{i}]"); - } - - if (expr.Buffer.CheckedType is TensorType { IsScalar: true, DType: PointerType { ElemType: PrimType pointedType } }) - { - type = TensorType.Scalar(pointedType); - } - else - { - type = new InvalidType($"Can't load from {expr.Buffer.CheckedType}"); - } - - return type; - } - /// protected override IRType VisitLeafBufferRegion(BufferRegion expr) { @@ -90,32 +68,6 @@ protected override IRType VisitLeafBufferRegion(BufferRegion expr) return type; } - /// - protected override IRType VisitLeafBufferStore(BufferStore expr) - { - VerifySubField(expr, expr.Buffer, TypePatternUtility.IsPointer()); - for (int i = 0; i < expr.Indices.Length; i++) - { - VerifySubField(expr, expr.Indices[i], TypePatternUtility.IsIntegralScalar(), $"BufferStore.Indices[{i}]"); - } - - VerifySubField(expr, expr.Value, TypePatternUtility.IsScalar()); - - IRType type; - if (expr.Value.CheckedType is TensorType { IsScalar: true, DType: PrimType valueType } && - expr.Buffer.CheckedType is TensorType { IsScalar: true, DType: PointerType { ElemType: PrimType pointedType } } - && valueType == pointedType) - { - type = TupleType.Void; - } - else - { - type = new InvalidType($"Can't store {expr.Value.CheckedType} to {expr.Buffer.CheckedType}"); - } - - return type; - } - /// protected override IRType VisitLeafCall(Call expr) { From 8dbe2279f0bbf9235bb990d665ecea739572f1f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 25 Jul 2023 19:53:17 +0800 Subject: [PATCH 037/308] update tiling --- .../Passes/Tile/SingleCPUFusionConverter.cs | 12 ++--- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 2 + src/Nncase.Core/IR/Buffers/BufferLoad.cs | 10 +++- src/Nncase.Core/IR/Buffers/BufferStore.cs | 13 ++++- .../Passes/Mutators/FlattenBuffer.cs | 35 ++++++------ .../Passes/Mutators/FoldMathCall.cs | 30 ----------- src/Nncase.Core/Passes/Mutators/Mutator.cs | 6 --- .../SubstituteVarAndCollectOpaqueBlock.cs | 54 ------------------- src/Nncase.Core/TIR/Scheduler.cs | 19 +++---- src/Nncase.Core/TIR/Script.cs | 7 +-- src/Nncase.Evaluator/Buffers/BufferLoad.cs | 1 + src/Nncase.Evaluator/Buffers/BufferModule.cs | 2 + src/Nncase.Tests/Core/UnitTestMutator.cs | 2 - 13 files changed, 61 insertions(+), 132 deletions(-) delete mode 100644 src/Nncase.Core/Passes/Mutators/FoldMathCall.cs delete mode 100644 src/Nncase.Core/Passes/Mutators/SubstituteVarAndCollectOpaqueBlock.cs diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 4c3aa9045e..5301794aaf 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -82,9 +82,8 @@ private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer r { var input = arguments[Unary.Input.Index]; var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); - var input_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * loops[i].loopVar)); - var output_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (ret.Strides[i] * loops[i].loopVar)); - Expr stmt = T.Store(ret, output_index, IR.F.Math.Unary(unary.UnaryOp, T.Load(input, output_index))); + var loopVars = loops.Select(f => f.loopVar).ToArray(); + Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Unary(unary.UnaryOp, T.BufferLoad(input, loopVars))); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); } @@ -106,11 +105,8 @@ private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffe var rhsScale = outShape.Zip(rhsShape).Select(s => s.First / s.Second).ToArray(); var loops = Enumerable.Range(0, outShape.Length).Select(i => (T.ForLoop(out var loopVar, (0, outShape[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); - var input_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * loops[i].loopVar)); - var output_index = Enumerable.Range(0, input.Rank).Aggregate((Expr)0, (acc, i) => acc + (ret.Strides[i] * loops[i].loopVar)); - Expr stmt = T.Store(ret, output_index, IR.F.Math.Unary(unary.UnaryOp, T.Load(input, output_index))); - var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); - _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); + // var ?final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); + // _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); } private TIR.Buffer TryAllocateBuffer(Expr expr) diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 1b3504f305..b3fe833906 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -87,7 +87,9 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp passManager.Add().Configure(p => { p.Add(); + p.Add(); p.Add(); + p.Add(); }); } diff --git a/src/Nncase.Core/IR/Buffers/BufferLoad.cs b/src/Nncase.Core/IR/Buffers/BufferLoad.cs index 4885197c3f..882282a073 100644 --- a/src/Nncase.Core/IR/Buffers/BufferLoad.cs +++ b/src/Nncase.Core/IR/Buffers/BufferLoad.cs @@ -8,7 +8,7 @@ namespace Nncase.IR.Buffers; /// -/// BufferIndexOf expression. +/// BufferLoad expression. /// [PatternFunctionalGenerator] public sealed partial class BufferLoad : Op @@ -17,6 +17,12 @@ public sealed partial class BufferLoad : Op /// Get the input parameter. /// public static readonly ParameterInfo Input = new(typeof(BufferLoad), 0, "input", IsTensor()); - + + /// + /// Get the indices. + /// public static readonly ParameterInfo Indices = new(typeof(BufferLoad), 1, "indices", IsTuple()); + + /// + public override bool CanFoldConstCall => false; } diff --git a/src/Nncase.Core/IR/Buffers/BufferStore.cs b/src/Nncase.Core/IR/Buffers/BufferStore.cs index 04f57d594b..b8402d10e3 100644 --- a/src/Nncase.Core/IR/Buffers/BufferStore.cs +++ b/src/Nncase.Core/IR/Buffers/BufferStore.cs @@ -8,7 +8,7 @@ namespace Nncase.IR.Buffers; /// -/// BufferIndexOf expression. +/// BufferStore op. /// [PatternFunctionalGenerator] public sealed partial class BufferStore : Op @@ -18,7 +18,16 @@ public sealed partial class BufferStore : Op /// public static readonly ParameterInfo Input = new(typeof(BufferStore), 0, "input", IsTensor()); + /// + /// Get the indices parameter. + /// public static readonly ParameterInfo Indices = new(typeof(BufferStore), 1, "indices", IsTuple()); - public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsTensor()); + /// + /// Get the value parameter. + /// + public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsScalar()); + + /// + public override bool CanFoldConstCall => false; } diff --git a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs index a0431cdd68..a941477030 100644 --- a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs +++ b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs @@ -12,35 +12,38 @@ namespace Nncase.Passes.Mutators; /// -/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. Also remove Block to ensure that the flattened TIR can not be scheduled again. +/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. /// public sealed class FlattenBuffer : ExprRewriter { /// protected override Expr RewriteLeafBlock(Block expr) { - if (!expr.IterVars.IsEmpty) + // TODO: put the unfold block into this. + if (expr.Predicate is TensorConst tc && tc.Value.ToScalar() == true) { - throw new InvalidOperationException("Non-opaque blocks are not allowed in FlattenBuffer. Please call pass ConvertBlocksToOpaque before."); + return expr.Body; } - // 1. Visit the body - var predicate = expr.Predicate; - if (predicate is TensorConst { Value: { Length: 1 } t } - && t.ToScalar()) + return T.Nop(); + } + + /// + protected override Expr RewriteLeafCall(Call expr) + { + if (expr.Target is IR.Buffers.BufferLoad) { - return expr.Body; + var indices = (IR.Tuple)expr[IR.Buffers.BufferLoad.Indices]; + var input = (TIR.Buffer)expr[IR.Buffers.BufferLoad.Input]; + return T.Load(input, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i]))); } - else + else if (expr.Target is IR.Buffers.BufferStore) { - return new IfThenElse(predicate, expr.Body); + var indices = (IR.Tuple)expr[IR.Buffers.BufferStore.Indices]; + var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input]; + return T.Store(input, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); } - // Step 3. Handle allocations in reverse order - // TODO add the alloc buffers. - // for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { - // const Buffer& buffer = new_block->alloc_buffers[i - 1]; - // body = MakeAllocStmt(buffer, std::move(body)); - // } + return expr; } } diff --git a/src/Nncase.Core/Passes/Mutators/FoldMathCall.cs b/src/Nncase.Core/Passes/Mutators/FoldMathCall.cs deleted file mode 100644 index af25604454..0000000000 --- a/src/Nncase.Core/Passes/Mutators/FoldMathCall.cs +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System.Reactive; -using NetFabric.Hyperlinq; -using Nncase.Evaluator; -using Nncase.IR; -using Nncase.Passes; - -namespace Nncase.Passes.Mutators; - -/// -/// fold math calc operator. -/// -public sealed class FoldMathCall : ExprRewriter -{ - /// - protected override Expr RewriteLeafCall(Call expr) - { - if (expr.Target is Op op && op.GetType().Namespace is string @namespace - && @namespace.StartsWith("Nncase.IR.Math")) - { - return expr.Arguments.AsValueEnumerable().All(x => x is Const) - ? Const.FromValue(CompilerServices.Evaluate(expr)) - : expr; - } - - return expr; - } -} diff --git a/src/Nncase.Core/Passes/Mutators/Mutator.cs b/src/Nncase.Core/Passes/Mutators/Mutator.cs index 1392ff7726..1f2587d5d0 100644 --- a/src/Nncase.Core/Passes/Mutators/Mutator.cs +++ b/src/Nncase.Core/Passes/Mutators/Mutator.cs @@ -50,10 +50,4 @@ public static class Mutator /// /// RemoveNop. public static Func RemoveNop() => () => new Mutators.RemoveNop(); - - /// - /// fold math calc operator. - /// - /// FoldMathCall. - public static Func FoldMathCall() => () => new Mutators.FoldMathCall(); } diff --git a/src/Nncase.Core/Passes/Mutators/SubstituteVarAndCollectOpaqueBlock.cs b/src/Nncase.Core/Passes/Mutators/SubstituteVarAndCollectOpaqueBlock.cs deleted file mode 100644 index da6daf3ed2..0000000000 --- a/src/Nncase.Core/Passes/Mutators/SubstituteVarAndCollectOpaqueBlock.cs +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Reactive; -using Nncase.IR; -using Nncase.TIR; - -namespace Nncase.Passes.Mutators; - -/// -/// Substitute vars and collect the reuse mapping of opaque blocks. -/// -public sealed class SubstituteVarAndCollectOpaqueBlock : ExprRewriter -{ - private readonly Func _varMapper; - private readonly Dictionary _opaqueBlocks; - - /// - /// Initializes a new instance of the class. - /// . - /// - public SubstituteVarAndCollectOpaqueBlock( - Func varMaper, - Dictionary opaque_blocks) - { - _varMapper = varMaper; - _opaqueBlocks = opaque_blocks; - } - - /// - protected override Expr RewriteLeafVar(Var expr) - { - if (_varMapper(expr) is Expr replace) - { - return replace; - } - - return expr; - } - - protected override Expr RewriteLeafBlock(Block expr) - { - var replace = (Block)base.RewriteLeafBlock(expr); - if (replace.IterVars.IsEmpty) - { - _opaqueBlocks.Add(expr, replace); - } - - return replace; - } -} diff --git a/src/Nncase.Core/TIR/Scheduler.cs b/src/Nncase.Core/TIR/Scheduler.cs index 214bd983f6..eb660d3dac 100644 --- a/src/Nncase.Core/TIR/Scheduler.cs +++ b/src/Nncase.Core/TIR/Scheduler.cs @@ -101,16 +101,17 @@ public For[] Split(For loop, params Expr[] factors) // Step 3. create new for loop. var nFor = new For[factors.Length]; - nbody = (Sequential)new Passes.Mutators.SubstituteVarAndCollectOpaqueBlock(v => v == loop.LoopVar ? substitute : v, opaque_block_reuse).Rewrite(nbody); - for (int i = factors.Length - 1; i >= 0; i--) - { - var @for = new For(newloopVars[i], (0, factors[i]), LoopMode.Serial, nbody); - nbody = T.Sequential(@for); - nFor[i] = @for; - } - // Setp 4. update the function - Entry = (Function)new Passes.Mutators.Substitutor(expr => object.ReferenceEquals(expr, loop) ? nFor[0] : null).Rewrite(Entry); + // nbody = (Sequential)new Passes.Mutators.SubstituteVarAndCollectOpaqueBlock(v => v == loop.LoopVar ? substitute : v, opaque_block_reuse).Rewrite(nbody); + // for (int i = factors.Length - 1; i >= 0; i--) + // { + // var @for = new For(newloopVars[i], (0, factors[i]), LoopMode.Serial, nbody); + // nbody = T.Sequential(@for); + // nFor[i] = @for; + // } + + // // Setp 4. update the function + // Entry = (Function)new Passes.Mutators.Substitutor(expr => object.ReferenceEquals(expr, loop) ? nFor[0] : null).Rewrite(Entry); return nFor; } diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index fd89d8661a..c9f4a77f72 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -326,7 +326,8 @@ public static Call Emit(out T value, Func creator) /// buffer store. /// /// buffer. - /// indices and value. - /// buffer store. - public static Call BufferStore(TIR.Buffer buffer, params Expr[] indicesAndValue) => new Call(new IR.Buffers.BufferLoad(), buffer, new IR.Tuple(indicesAndValue[..^1]), indicesAndValue[^1]); + /// indices. + /// value. + /// call bufferstore. + public static Call BufferStore(TIR.Buffer buffer, Expr[] indices, Expr value) => new Call(new IR.Buffers.BufferStore(), buffer, new IR.Tuple(indices), value); } diff --git a/src/Nncase.Evaluator/Buffers/BufferLoad.cs b/src/Nncase.Evaluator/Buffers/BufferLoad.cs index 0069028c32..210227c555 100644 --- a/src/Nncase.Evaluator/Buffers/BufferLoad.cs +++ b/src/Nncase.Evaluator/Buffers/BufferLoad.cs @@ -18,6 +18,7 @@ public string Visit(IIRPrinterContext context, BufferLoad target, bool iLmode) { throw new System.NotSupportedException(); } + return $"{context.GetArgument(target, BufferLoad.Input)}[{context.GetArgument(target, BufferLoad.Indices)}]"; } diff --git a/src/Nncase.Evaluator/Buffers/BufferModule.cs b/src/Nncase.Evaluator/Buffers/BufferModule.cs index 4547718379..0954206679 100644 --- a/src/Nncase.Evaluator/Buffers/BufferModule.cs +++ b/src/Nncase.Evaluator/Buffers/BufferModule.cs @@ -20,5 +20,7 @@ public void ConfigureServices(IRegistrator registrator) registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); } } diff --git a/src/Nncase.Tests/Core/UnitTestMutator.cs b/src/Nncase.Tests/Core/UnitTestMutator.cs index b08f089911..b8c7d7d587 100644 --- a/src/Nncase.Tests/Core/UnitTestMutator.cs +++ b/src/Nncase.Tests/Core/UnitTestMutator.cs @@ -29,7 +29,5 @@ public void TestMutator() var removeNop = Mutator.RemoveNop().Invoke(); Assert.Equal(new Passes.Mutators.RemoveNop().IsMutated, removeNop.IsMutated); - var foldMathCall = Mutator.FoldMathCall().Invoke(); - Assert.Equal(new Passes.Mutators.FoldMathCall().IsMutated, foldMathCall.IsMutated); } } From f664e3685264981bc2f720164dcd4f439cf8c103 Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Tue, 25 Jul 2023 11:56:00 +0000 Subject: [PATCH 038/308] Apply code-format changes --- .../Passes/Tile/SingleCPUFusionConverter.cs | 1 + src/Nncase.Core/IR/Buffers/BufferLoad.cs | 2 +- src/Nncase.Core/IR/Buffers/BufferStore.cs | 2 +- src/Nncase.Core/TIR/Scheduler.cs | 10 +++++++--- src/Nncase.Evaluator/Buffers/BufferLoad.cs | 2 +- src/Nncase.Evaluator/Buffers/BufferStore.cs | 2 +- src/Nncase.Tests/Core/UnitTestMutator.cs | 1 - 7 files changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 5301794aaf..aa7c827ded 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -105,6 +105,7 @@ private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffe var rhsScale = outShape.Zip(rhsShape).Select(s => s.First / s.Second).ToArray(); var loops = Enumerable.Range(0, outShape.Length).Select(i => (T.ForLoop(out var loopVar, (0, outShape[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + // var ?final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); // _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); } diff --git a/src/Nncase.Core/IR/Buffers/BufferLoad.cs b/src/Nncase.Core/IR/Buffers/BufferLoad.cs index 882282a073..dbf3427b6e 100644 --- a/src/Nncase.Core/IR/Buffers/BufferLoad.cs +++ b/src/Nncase.Core/IR/Buffers/BufferLoad.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.IR.Tensors; diff --git a/src/Nncase.Core/IR/Buffers/BufferStore.cs b/src/Nncase.Core/IR/Buffers/BufferStore.cs index b8402d10e3..2d8e86cad8 100644 --- a/src/Nncase.Core/IR/Buffers/BufferStore.cs +++ b/src/Nncase.Core/IR/Buffers/BufferStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.IR.Tensors; diff --git a/src/Nncase.Core/TIR/Scheduler.cs b/src/Nncase.Core/TIR/Scheduler.cs index eb660d3dac..30e9552c04 100644 --- a/src/Nncase.Core/TIR/Scheduler.cs +++ b/src/Nncase.Core/TIR/Scheduler.cs @@ -87,7 +87,10 @@ public For[] Split(For loop, params Expr[] factors) } // TODO add assert total == (loop.Dom.Max - loop.Dom.Min) // Step 2. Replace all occurrences of the original loop var with new variables - Expr total = 1, substitute = 0; + _ = 1; + + // Step 2. Replace all occurrences of the original loop var with new variables + Expr substitute = 0; var newloopVars = new Var[factors.Length]; foreach (var i in Enumerable.Range(0, factors.Length)) { @@ -96,8 +99,9 @@ public For[] Split(For loop, params Expr[] factors) newloopVars[i] = loopVar; } - Dictionary opaque_block_reuse = new(); // TODO the opaque_block_reuse for what? - Sequential nbody = loop.Body; + _ = new + Dictionary(); // TODO the opaque_block_reuse for what? + _ = loop.Body; // Step 3. create new for loop. var nFor = new For[factors.Length]; diff --git a/src/Nncase.Evaluator/Buffers/BufferLoad.cs b/src/Nncase.Evaluator/Buffers/BufferLoad.cs index 210227c555..78bab2e920 100644 --- a/src/Nncase.Evaluator/Buffers/BufferLoad.cs +++ b/src/Nncase.Evaluator/Buffers/BufferLoad.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.IR; diff --git a/src/Nncase.Evaluator/Buffers/BufferStore.cs b/src/Nncase.Evaluator/Buffers/BufferStore.cs index 2d0020486b..81a833f79e 100644 --- a/src/Nncase.Evaluator/Buffers/BufferStore.cs +++ b/src/Nncase.Evaluator/Buffers/BufferStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.IR; diff --git a/src/Nncase.Tests/Core/UnitTestMutator.cs b/src/Nncase.Tests/Core/UnitTestMutator.cs index b8c7d7d587..83d99fbb60 100644 --- a/src/Nncase.Tests/Core/UnitTestMutator.cs +++ b/src/Nncase.Tests/Core/UnitTestMutator.cs @@ -28,6 +28,5 @@ public void TestMutator() var removeNop = Mutator.RemoveNop().Invoke(); Assert.Equal(new Passes.Mutators.RemoveNop().IsMutated, removeNop.IsMutated); - } } From 240d195f063eb5ca6445d7f43ac201df934e2100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 25 Jul 2023 20:05:48 +0800 Subject: [PATCH 039/308] add matmul cpu kernel --- .../Passes/Tile/SingleCPUFusionConverter.cs | 22 +++++++++++++++++++ .../Targets/UnitTestCPUTargetTiling.cs | 22 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index aa7c827ded..e1bd589c40 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -71,6 +71,9 @@ protected override Unit VisitLeafCall(Call expr) case Binary binary: GenerateBinary(binary, arguments, ret, expr); break; + case MatMul matmul: + GenerateMatMul(arguments, ret, expr); + break; default: throw new NotSupportedException(); } @@ -78,6 +81,25 @@ protected override Unit VisitLeafCall(Call expr) return default; } + private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) + { + var lhs = arguments[0]; + var rhs = arguments[1]; + + // [m,k] @ [k, n] + var body = T.Block(nameof(MatMul)).Body( + T.Serial(out var m, (0, lhs.Dimensions[0])).Body( + T.Serial(out var n, (0, rhs.Dimensions[1])).Body( + T.Serial(out var k, (0, lhs.Dimensions[1])).Body( + T.BufferStore(ret, new[] { m, n }, T.BufferLoad(ret, m, n) + (T.BufferLoad(lhs, m, k) * T.BufferLoad(rhs, k, n))) + ) + ) + ) + ); + + _mainBody.Add(body.Build()); + } + private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer ret) { var input = arguments[Unary.Input.Index]; diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs index 7156572dbf..47a62e510b 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -53,4 +53,26 @@ public async Task TestCpuUnary() fs.Write(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor().BytesBuffer); } } + + [Fact] + public async Task TestCpuMatMul() + { + var lhs = new Var("lhs", new TensorType(DataTypes.Float32, new[] { 3, 4 })); + var rhs = new Var("rhs", new TensorType(DataTypes.Float32, new[] { 4, 6 })); + var main = new Function("main", IR.F.Tensors.MatMul(lhs, rhs), new[] { lhs, rhs }); + var module = new IR.IRModule(main); + + var compiler = CompileSession.Compiler; + compiler.ImportIRModule(module); + await compiler.CompileAsync(); + using (var fs = Dumpper.OpenFile("test.kmodel")) + { + compiler.Gencode(fs); + } + + using (var fs = Dumpper.OpenFile("input_0.bin")) + { + fs.Write(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor().BytesBuffer); + } + } } From 55cb1054c1b025d1a978368db98f71a74e174fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 25 Jul 2023 20:06:55 +0800 Subject: [PATCH 040/308] fix warning --- .../Passes/Tile/SingleCPUFusionConverter.cs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index e1bd589c40..7263c6a55f 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -91,11 +91,7 @@ private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) T.Serial(out var m, (0, lhs.Dimensions[0])).Body( T.Serial(out var n, (0, rhs.Dimensions[1])).Body( T.Serial(out var k, (0, lhs.Dimensions[1])).Body( - T.BufferStore(ret, new[] { m, n }, T.BufferLoad(ret, m, n) + (T.BufferLoad(lhs, m, k) * T.BufferLoad(rhs, k, n))) - ) - ) - ) - ); + T.BufferStore(ret, new[] { m, n }, T.BufferLoad(ret, m, n) + (T.BufferLoad(lhs, m, k) * T.BufferLoad(rhs, k, n))))))); _mainBody.Add(body.Build()); } From a71ad51a7428e02600a29fc8e4f9c5d3a42b5360 Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Wed, 26 Jul 2023 13:30:50 +0800 Subject: [PATCH 041/308] update matmul tiling --- .../Passes/Rules/LowerMatMul.cs | 9 +++++++-- .../Passes/Rules/MakeFusion.cs | 2 +- .../Passes/Tile/SingleCPUFusionConverter.cs | 16 ++++++++++------ src/Nncase.Passes/Rules/Neutral/FusionMaker.cs | 4 +++- .../Targets/UnitTestCPUTargetTiling.cs | 5 +++-- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs b/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs index d23b542ff9..a087945bbf 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/LowerMatMul.cs @@ -27,8 +27,13 @@ public partial class LowerMatMul : RewriteRule IsWildcard("inputA") with { TypePattern = IsFloat() }, IsWildcard("inputB") with { TypePattern = IsFloat() }); - private Expr GetReplace(MatMul matmul, Expr inputA, Expr inputB) + private Expr? GetReplace(MatMul matmul, Expr inputA, Expr inputB) { - return CPUKernel(matmul, inputA, inputB); + if (inputA.CheckedShape.Rank == inputB.CheckedShape.Rank && inputA.CheckedShape.Zip(inputB.CheckedShape).SkipLast(2).All(d => d.First == d.Second)) + { + return CPUKernel(matmul, inputA, inputB); + } + + return null; } } diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs index d372d36037..54e6eb725f 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/MakeFusion.cs @@ -40,7 +40,7 @@ internal sealed partial class CPUFusion : FusionMaker } var newCall = new Call(op, newInputs.ToArray()); - var callFusion = new Call(new Fusion(FullName, ModuleKind, newCall, newInputs.OfType().ToArray()), newInputs.Select((e, i) => (e, i)).Where(p => p.e is Var).Select(p => callParams[p.i]).ToArray()); + var callFusion = new Call(new Fusion($"{op.Target.GetType().Name}_{Count}", ModuleKind, newCall, newInputs.OfType().ToArray()), newInputs.Select((e, i) => (e, i)).Where(p => p.e is Var).Select(p => callParams[p.i]).ToArray()); return callFusion; } } diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 7263c6a55f..a399410337 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -86,13 +86,17 @@ private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) var lhs = arguments[0]; var rhs = arguments[1]; - // [m,k] @ [k, n] - var body = T.Block(nameof(MatMul)).Body( - T.Serial(out var m, (0, lhs.Dimensions[0])).Body( - T.Serial(out var n, (0, rhs.Dimensions[1])).Body( - T.Serial(out var k, (0, lhs.Dimensions[1])).Body( - T.BufferStore(ret, new[] { m, n }, T.BufferLoad(ret, m, n) + (T.BufferLoad(lhs, m, k) * T.BufferLoad(rhs, k, n))))))); + var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + var loopVars = loops.Select(f => f.loopVar).ToArray(); + var stmt = T.Serial(out var m, (0, lhs.Dimensions[0])).Body( + T.Serial(out var n, (0, rhs.Dimensions[1])).Body( + T.Serial(out var k, (0, lhs.Dimensions[1])).Body( + T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), T.BufferLoad(ret, loopVars.Concat(new[] { m, n }).ToArray()) + (T.BufferLoad(lhs, loopVars.Concat(new[] { m, k }).ToArray()) * T.BufferLoad(rhs, loopVars.Concat(new[] { k, n }).ToArray())))))). + Build(); + var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); + // [m,k] @ [k, n] + var body = T.Block(nameof(MatMul)).Body(final); _mainBody.Add(body.Build()); } diff --git a/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs b/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs index 8e371dded0..7d64ddd81f 100644 --- a/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs +++ b/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs @@ -24,11 +24,13 @@ public abstract class FusionMaker : RewriteRule { private int _count; + public int Count { get => _count++; } + public virtual string Name { get; } = "FusionMaker"; public virtual string ModuleKind { get; } = "StackVM"; - public string FullName => $"{Name}_{_count++}"; + public string FullName => $"{Name}_{Count}"; } /// diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs index 47a62e510b..b1cca37c84 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -58,8 +58,9 @@ public async Task TestCpuUnary() public async Task TestCpuMatMul() { var lhs = new Var("lhs", new TensorType(DataTypes.Float32, new[] { 3, 4 })); - var rhs = new Var("rhs", new TensorType(DataTypes.Float32, new[] { 4, 6 })); - var main = new Function("main", IR.F.Tensors.MatMul(lhs, rhs), new[] { lhs, rhs }); + var rhs = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 0, new[] { 4, 6 }); + //new Var("rhs", new TensorType(DataTypes.Float32, new[] { 4, 6 })); + var main = new Function("main", IR.F.Tensors.MatMul(lhs, rhs), new[] { lhs }); var module = new IR.IRModule(main); var compiler = CompileSession.Compiler; From 31c479130d98f005a5e66311b14f01517bf111c1 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Wed, 26 Jul 2023 14:11:46 +0800 Subject: [PATCH 042/308] cpu binary to tir --- .../Passes/Tile/SingleCPUFusionConverter.cs | 14 +++++++++----- modules/cpu/src/runtime/runtime_function.cpp | 3 --- .../Targets/UnitTestCPUTargetTiling.cs | 3 ++- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index a399410337..468db5dc19 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -90,6 +90,7 @@ private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) var loopVars = loops.Select(f => f.loopVar).ToArray(); var stmt = T.Serial(out var m, (0, lhs.Dimensions[0])).Body( T.Serial(out var n, (0, rhs.Dimensions[1])).Body( + T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), 0f), T.Serial(out var k, (0, lhs.Dimensions[1])).Body( T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), T.BufferLoad(ret, loopVars.Concat(new[] { m, n }).ToArray()) + (T.BufferLoad(lhs, loopVars.Concat(new[] { m, k }).ToArray()) * T.BufferLoad(rhs, loopVars.Concat(new[] { k, n }).ToArray())))))). Build(); @@ -112,8 +113,8 @@ private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer r private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffer ret, Call call) { - var lhs = call[Binary.Lhs]; - var rhs = call[Binary.Rhs]; + var lhs = call.Arguments[Binary.Lhs.Index]; + var rhs = call.Arguments[Binary.Rhs.Index]; var lhsBuffer = arguments[Binary.Lhs.Index]; var rhsBuffer = arguments[Binary.Rhs.Index]; @@ -127,9 +128,12 @@ private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffe var rhsScale = outShape.Zip(rhsShape).Select(s => s.First / s.Second).ToArray(); var loops = Enumerable.Range(0, outShape.Length).Select(i => (T.ForLoop(out var loopVar, (0, outShape[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); - - // var ?final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); - // _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); + var loopVars = loops.Select(f => f.loopVar).ToArray(); + var lhsLoopVars = loopVars.Zip(lhsScale).Select(v => v.First / v.Second).ToArray(); + var rhsLoopVars = loopVars.Zip(rhsScale).Select(v => v.First / v.Second).ToArray(); + Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Binary(binary.BinaryOp, T.BufferLoad(lhsBuffer, lhsLoopVars), T.BufferLoad(rhsBuffer, rhsLoopVars))); + var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); + _mainBody.Add(T.Block(nameof(Binary)).Body(final).Build()); } private TIR.Buffer TryAllocateBuffer(Expr expr) diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 61afac5746..74345772a7 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -140,9 +140,6 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, auto elfloader_ = elfloader{(char *)module().text_physical().data()}; elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, nullptr); - for (int i = 0; i < 10; i++) { - printf("%f\n", ((float *)buffers[1]->vaddr)[i]); - } for (int i = 0; i < buffers.size(); i++) { delete buffers[i]; diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs index b1cca37c84..f3a897f37d 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -59,7 +59,8 @@ public async Task TestCpuMatMul() { var lhs = new Var("lhs", new TensorType(DataTypes.Float32, new[] { 3, 4 })); var rhs = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 0, new[] { 4, 6 }); - //new Var("rhs", new TensorType(DataTypes.Float32, new[] { 4, 6 })); + + // new Var("rhs", new TensorType(DataTypes.Float32, new[] { 4, 6 })); var main = new Function("main", IR.F.Tensors.MatMul(lhs, rhs), new[] { lhs }); var module = new IR.IRModule(main); From 4d93b0e409731efb275696c187478d3b30da4bf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 26 Jul 2023 15:40:28 +0800 Subject: [PATCH 043/308] add buffer schedule --- .../CodeGen/CSourceConvertVisitor.cs | 14 +++++ .../Passes/Tile/SingleCPUFusionConverter.cs | 7 +-- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 8 +++ modules/cpu/src/runtime/cpu_common.h | 8 +++ modules/cpu/src/runtime/runtime_function.cpp | 3 +- modules/cpu/src/runtime/runtime_module.cpp | 4 +- src/Nncase.Core/IR/Buffers/MatchBuffer.cs | 21 ++++++++ .../Passes/Mutators/FoldBufferSlot.cs | 51 +++++++++++++++++++ src/Nncase.Core/TIR/Script.cs | 2 + src/Nncase.Evaluator/Buffers/BufferModule.cs | 1 + src/Nncase.Evaluator/Buffers/MatchBuffer.cs | 29 +++++++++++ 11 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 src/Nncase.Core/IR/Buffers/MatchBuffer.cs create mode 100644 src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs create mode 100644 src/Nncase.Evaluator/Buffers/MatchBuffer.cs diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 5dc9cec026..9b6b358525 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -178,6 +178,20 @@ protected override CSymbol VisitCall(Call expr) case Load: str = $"((({type} *){arguments[0].Name}->vaddr)[{arguments[1].Name}])"; break; + case IR.Buffers.MatchBuffer op: + var n = arguments[0].Name; + var pb = (TIR.PhysicalBuffer)expr[IR.Buffers.MatchBuffer.Input]; + var ind = new String(Enumerable.Repeat(' ', IndentScope.Writer.Indent).ToArray()); + str = $@"uint32_t _{n}_shape[] = {{ {string.Join(", ", pb.FixedDimensions.ToArray())} }}; +{ind}uint32_t _{n}_stride[] = {{ {string.Join(", ", pb.FixedStrides.ToArray())} }}; +{ind}buffer_t _{n} = {{ +{ind}{ind}.vaddr = ((uint8_t*) rdata + {pb.Start}), +{ind}{ind}.paddr = 0, +{ind}{ind}.shape = _{n}_shape, +{ind}{ind}.stride = _{n}_stride, +{ind}{ind}.rank = {pb.FixedDimensions.Length} }}; +{ind}buffer_t *{n} = &_{n}"; + break; default: throw new NotSupportedException(); } diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 468db5dc19..cc05e9396b 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -95,9 +95,10 @@ private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), T.BufferLoad(ret, loopVars.Concat(new[] { m, n }).ToArray()) + (T.BufferLoad(lhs, loopVars.Concat(new[] { m, k }).ToArray()) * T.BufferLoad(rhs, loopVars.Concat(new[] { k, n }).ToArray())))))). Build(); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); - // [m,k] @ [k, n] - var body = T.Block(nameof(MatMul)).Body(final); + var body = T.Block(nameof(MatMul)).Body( + T.Sequential(arguments.OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), + final); _mainBody.Add(body.Build()); } @@ -158,7 +159,7 @@ private TIR.Buffer TryAllocateBuffer(Expr expr) buffer = T.PhysicalBuffer(v.CheckedDataType, MemoryLocation.Input, v.CheckedShape.ToValueArray(), out _, name); break; case TensorConst c: - buffer = T.PhysicalBuffer(c.Value.ElementType, MemoryLocation.Rdata, c.Value.Dimensions, out _, name); + buffer = T.ConstBuffer(c, out _, name); break; default: throw new NotSupportedException(); diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index b3fe833906..1e2657ab3b 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -91,6 +91,14 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp p.Add(); p.Add(); }); + + passManager.AddWithName("DDrBufferSchdeule"); + + passManager.AddWithName("InstStage").Configure(p => + { + p.Add(); + p.Add(); // 折叠自定义op + }); } public void RegisterTargetDependentBeforeCodeGen(IPassManager passManager, CompileOptions options) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 047ea7ce31..d51e1775aa 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -97,7 +97,11 @@ inline int32_t int32_binary_mul(int32_t x, int32_t y) { return x * y; } inline int32_t int32_binary_div(int32_t x, int32_t y) { return x / y; } inline int32_t int32_binary_min(int32_t x, int32_t y) { return std::min(x, y); } inline int32_t int32_binary_max(int32_t x, int32_t y) { return std::max(x, y); } +#if defined (__arm64__) && defined (__APPLE__) +inline int32_t int32_binary_pow(int32_t x, int32_t y) { return (int32_t)pow(x, y); } +#else inline int32_t int32_binary_pow(int32_t x, int32_t y) { return std::pow(x, y); } +#endif inline int32_t int32_binary_logical_and(int32_t x, int32_t y) { return x && y; } inline int32_t int32_binary_mod(int32_t x, int32_t y) { return x % y; } @@ -107,7 +111,11 @@ inline int64_t int64_binary_mul(int64_t x, int64_t y) { return x * y; } inline int64_t int64_binary_div(int64_t x, int64_t y) { return x / y; } inline int64_t int64_binary_min(int64_t x, int64_t y) { return std::min(x, y); } inline int64_t int64_binary_max(int64_t x, int64_t y) { return std::max(x, y); } +#if defined (__arm64__) && defined (__APPLE__) +inline int64_t int64_binary_pow(int64_t x, int64_t y) { return (int64_t)pow(x, y); } +#else inline int64_t int64_binary_pow(int64_t x, int64_t y) { return std::pow(x, y); } +#endif inline int64_t int64_binary_logical_and(int64_t x, int64_t y) { return x && y; } inline int64_t int64_binary_mod(int64_t x, int64_t y) { return x % y; } diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 74345772a7..f5cc272b88 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -139,8 +139,7 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, } auto elfloader_ = elfloader{(char *)module().text_physical().data()}; - elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, nullptr); - + elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, (void *)module().rdata_physical().data()); for (int i = 0; i < buffers.size(); i++) { delete buffers[i]; } diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 0b1b2efbe8..6e44d1fd3b 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -27,8 +27,8 @@ result cpu_runtime_module::initialize_before_functions( // if (!context.is_section_pinned()) // return nncase::err(std::errc::bad_address); // try_var(data, context.get_or_read_section(".data", data_storage_, - // false)); try_var(rdata, context.get_or_read_section(".rdata", - // rdata_storage_, true)); + // false)); + try_set(rdata_, context.get_or_read_section(".rdata", rdata_storage_, true)); try_set(text_, context.get_or_read_section(".text", text_storage_, true)); return ok(); diff --git a/src/Nncase.Core/IR/Buffers/MatchBuffer.cs b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs new file mode 100644 index 0000000000..0904ed7267 --- /dev/null +++ b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs @@ -0,0 +1,21 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.Buffers; + +/// +/// MatchBuffer op. +/// todo maybe need united matchbuffer and allocatebuffer +/// +[PatternFunctionalGenerator] +public sealed partial class MatchBuffer : Op +{ + public static readonly ParameterInfo Input = new(typeof(MatchBuffer), 0, "input", IsTensor()); + + /// + public override bool CanFoldConstCall => false; +} diff --git a/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs new file mode 100644 index 0000000000..d96dfd968d --- /dev/null +++ b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs @@ -0,0 +1,51 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Reactive; +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.Passes; +using Nncase.TIR; + +namespace Nncase.Passes.Mutators; + +/// +/// remove buffer BaseMentOf/DDrOf/MmuOF. +/// +public sealed class FoldBufferSlot : ExprRewriter +{ + protected internal override Expr VisitPrimFunction(TIR.PrimFunction expr, Unit context) + { + if (expr.SchedResult.IsScheduled == true) + { + return base.VisitPrimFunction(expr, context); + } + + return expr; + } + + protected override Expr RewriteLeafCall(Call expr) + { + if (expr.Target is IR.Buffers.BaseMentOf) + { + var locate = ((TIR.PhysicalBuffer)expr.Arguments[0]).MemLocation; + return locate switch + { + MemoryLocation.Input => 0, + MemoryLocation.Output => 1, + MemoryLocation.Rdata => 2, + MemoryLocation.Data => 3, + _ => throw new ArgumentOutOfRangeException($"You Can't Assgin The BaseMent For {locate}!"), + }; + } + else if (expr.Target is IR.Buffers.DDrOf) + { + if (expr.Arguments[0] is TIR.PhysicalBuffer buf) + { + return buf.Start; + } + } + + return expr; + } +} diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index c9f4a77f72..ad1b71f0bf 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -330,4 +330,6 @@ public static Call Emit(out T value, Func creator) /// value. /// call bufferstore. public static Call BufferStore(TIR.Buffer buffer, Expr[] indices, Expr value) => new Call(new IR.Buffers.BufferStore(), buffer, new IR.Tuple(indices), value); + + public static Call MatchBuffer(TIR.Buffer buffer) => new Call(new IR.Buffers.MatchBuffer(), buffer); } diff --git a/src/Nncase.Evaluator/Buffers/BufferModule.cs b/src/Nncase.Evaluator/Buffers/BufferModule.cs index 0954206679..a2512b6f13 100644 --- a/src/Nncase.Evaluator/Buffers/BufferModule.cs +++ b/src/Nncase.Evaluator/Buffers/BufferModule.cs @@ -22,5 +22,6 @@ public void ConfigureServices(IRegistrator registrator) registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); } } diff --git a/src/Nncase.Evaluator/Buffers/MatchBuffer.cs b/src/Nncase.Evaluator/Buffers/MatchBuffer.cs new file mode 100644 index 0000000000..7a806fdd61 --- /dev/null +++ b/src/Nncase.Evaluator/Buffers/MatchBuffer.cs @@ -0,0 +1,29 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.IR.Buffers; + +namespace Nncase.Evaluator.Buffers; + +/// +/// Evaluator for BufferOf. +/// +[TypeInferGenerator] +public partial class MatchBufferEvaluator : ITypeInferencer, IOpPrinter +{ + public string Visit(IIRPrinterContext context, MatchBuffer target, bool iLmode) + { + if (iLmode) + { + throw new System.NotSupportedException(); + } + + return $"Matched {context.GetArgument(target, MatchBuffer.Input)}"; + } + + private IRType Visit() + { + return TupleType.Void; + } +} From 02fbf2fee6e1d34c19c4facd6d95eef4aff5287f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 26 Jul 2023 15:42:05 +0800 Subject: [PATCH 044/308] format --- modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs | 2 +- .../Passes/Tile/SingleCPUFusionConverter.cs | 1 + src/Nncase.Core/IR/Buffers/MatchBuffer.cs | 4 ++-- src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs | 2 +- src/Nncase.Evaluator/Buffers/MatchBuffer.cs | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 9b6b358525..d75f80cced 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -181,7 +181,7 @@ protected override CSymbol VisitCall(Call expr) case IR.Buffers.MatchBuffer op: var n = arguments[0].Name; var pb = (TIR.PhysicalBuffer)expr[IR.Buffers.MatchBuffer.Input]; - var ind = new String(Enumerable.Repeat(' ', IndentScope.Writer.Indent).ToArray()); + var ind = new string(Enumerable.Repeat(' ', IndentScope.Writer.Indent).ToArray()); str = $@"uint32_t _{n}_shape[] = {{ {string.Join(", ", pb.FixedDimensions.ToArray())} }}; {ind}uint32_t _{n}_stride[] = {{ {string.Join(", ", pb.FixedStrides.ToArray())} }}; {ind}buffer_t _{n} = {{ diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index cc05e9396b..267a5a484d 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -95,6 +95,7 @@ private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), T.BufferLoad(ret, loopVars.Concat(new[] { m, n }).ToArray()) + (T.BufferLoad(lhs, loopVars.Concat(new[] { m, k }).ToArray()) * T.BufferLoad(rhs, loopVars.Concat(new[] { k, n }).ToArray())))))). Build(); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); + // [m,k] @ [k, n] var body = T.Block(nameof(MatMul)).Body( T.Sequential(arguments.OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), diff --git a/src/Nncase.Core/IR/Buffers/MatchBuffer.cs b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs index 0904ed7267..3cafa7f595 100644 --- a/src/Nncase.Core/IR/Buffers/MatchBuffer.cs +++ b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.IR.Tensors; @@ -9,7 +9,7 @@ namespace Nncase.IR.Buffers; /// /// MatchBuffer op. -/// todo maybe need united matchbuffer and allocatebuffer +/// todo maybe need united matchbuffer and allocatebuffer. /// [PatternFunctionalGenerator] public sealed partial class MatchBuffer : Op diff --git a/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs index d96dfd968d..4bc48af9d1 100644 --- a/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs +++ b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System.Reactive; diff --git a/src/Nncase.Evaluator/Buffers/MatchBuffer.cs b/src/Nncase.Evaluator/Buffers/MatchBuffer.cs index 7a806fdd61..7a8122d2ae 100644 --- a/src/Nncase.Evaluator/Buffers/MatchBuffer.cs +++ b/src/Nncase.Evaluator/Buffers/MatchBuffer.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using Nncase.IR; From 9fa866bdd6d95aa5116649f401979106d70aadba Mon Sep 17 00:00:00 2001 From: huochenghai Date: Wed, 26 Jul 2023 16:25:22 +0800 Subject: [PATCH 045/308] update cpu tir --- .../Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs | 2 +- .../Nncase.Modules.CPU/CodeGen/LinkedModule.cs | 2 +- .../Passes/Tile/SingleCPUFusionConverter.cs | 15 +++++++++------ modules/cpu/src/runtime/elfloader.cpp | 2 +- modules/cpu/src/runtime/elfloader.h | 4 ++-- modules/cpu/src/runtime/runtime_function.cpp | 2 +- 6 files changed, 15 insertions(+), 12 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs index 6ab53936fc..d59e244f92 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs @@ -134,7 +134,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { - return $"{sourcePath} -nostdlib -static -no-pie -fPIC -march={Arch} -o {outPath}"; + return $"{sourcePath} -nostdlib -static -no-pie -fPIC -fno-stack-protector -march={Arch} -o {outPath}"; } else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs index 8a300e55bf..c32350936f 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkedModule.cs @@ -15,7 +15,7 @@ internal sealed class LinkedModule : ILinkedModule public LinkedModule(IReadOnlyList functions, byte[] text, byte[] rdata) { Functions = functions; - Sections = new[] { new LinkedSection(text, ".text", 0, 8, (uint)text.Length) }; + Sections = new[] { new LinkedSection(text, ".text", 0, 8, (uint)text.Length), new LinkedSection(rdata, ".rdata", 0, 8, (uint)rdata.Length) }; } public string ModuleKind => Targets.CPUTarget.Kind; diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 267a5a484d..2cf09245c1 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -81,24 +81,24 @@ protected override Unit VisitLeafCall(Call expr) return default; } - private void GenerateMatMul(Buffer[] arguments, Buffer ret, Call expr) + private void GenerateMatMul(ReadOnlySpan arguments, Buffer ret, Call expr) { var lhs = arguments[0]; var rhs = arguments[1]; var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); var loopVars = loops.Select(f => f.loopVar).ToArray(); - var stmt = T.Serial(out var m, (0, lhs.Dimensions[0])).Body( - T.Serial(out var n, (0, rhs.Dimensions[1])).Body( + var stmt = T.Serial(out var m, (0, lhs.Dimensions[^2])).Body( + T.Serial(out var n, (0, rhs.Dimensions[^1])).Body( T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), 0f), - T.Serial(out var k, (0, lhs.Dimensions[1])).Body( + T.Serial(out var k, (0, lhs.Dimensions[^1])).Body( T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), T.BufferLoad(ret, loopVars.Concat(new[] { m, n }).ToArray()) + (T.BufferLoad(lhs, loopVars.Concat(new[] { m, k }).ToArray()) * T.BufferLoad(rhs, loopVars.Concat(new[] { k, n }).ToArray())))))). Build(); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); // [m,k] @ [k, n] var body = T.Block(nameof(MatMul)).Body( - T.Sequential(arguments.OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), + T.Sequential(arguments.ToArray().OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), final); _mainBody.Add(body.Build()); } @@ -135,7 +135,10 @@ private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffe var rhsLoopVars = loopVars.Zip(rhsScale).Select(v => v.First / v.Second).ToArray(); Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Binary(binary.BinaryOp, T.BufferLoad(lhsBuffer, lhsLoopVars), T.BufferLoad(rhsBuffer, rhsLoopVars))); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); - _mainBody.Add(T.Block(nameof(Binary)).Body(final).Build()); + var body = T.Block(nameof(Binary)).Body( + T.Sequential(arguments.ToArray().OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), + final); + _mainBody.Add(body.Build()); } private TIR.Buffer TryAllocateBuffer(Expr expr) diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp index f0c6571db1..f20063b9d5 100644 --- a/modules/cpu/src/runtime/elfloader.cpp +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -5,7 +5,7 @@ using namespace nncase::runtime; using namespace nncase::runtime::cpu; int elfloader::invoke_elf(size_t id, buffer_t **buffers, nncase_mt_t *nncase_mt, - void *data, void *rdata) { + void *data, const void *rdata) { check(el_init(&ctx_), "initialising"); diff --git a/modules/cpu/src/runtime/elfloader.h b/modules/cpu/src/runtime/elfloader.h index bdc0c5039e..7a0aa36f00 100644 --- a/modules/cpu/src/runtime/elfloader.h +++ b/modules/cpu/src/runtime/elfloader.h @@ -14,7 +14,7 @@ BEGIN_NS_NNCASE_RT_MODULE(cpu) typedef void (*entrypoint_t)(size_t id, buffer_t **buffers, - nncase_mt_t *nncase_mt, void *data, void *rdata); + nncase_mt_t *nncase_mt, void *data, const void *rdata); class elfloader { public: @@ -50,7 +50,7 @@ class elfloader { } int invoke_elf(size_t id, buffer_t **buffers, nncase_mt_t *nncase_mt, - void *data, void *rdata); + void *data, const void *rdata); private: void *ptr_; diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index f5cc272b88..422fd23e59 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -139,7 +139,7 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, } auto elfloader_ = elfloader{(char *)module().text_physical().data()}; - elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, (void *)module().rdata_physical().data()); + elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, (const void *)module().rdata_physical().data()); for (int i = 0; i < buffers.size(); i++) { delete buffers[i]; } From 5429c6bcb1a92d71b4f8f0bbf89b945f9f063e02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Thu, 27 Jul 2023 13:37:13 +0800 Subject: [PATCH 046/308] refactor buffer --- src/Nncase.Core/TIR/Buffer.cs | 219 +++------------------------------- 1 file changed, 18 insertions(+), 201 deletions(-) diff --git a/src/Nncase.Core/TIR/Buffer.cs b/src/Nncase.Core/TIR/Buffer.cs index 9289d1afec..80f86c4f21 100644 --- a/src/Nncase.Core/TIR/Buffer.cs +++ b/src/Nncase.Core/TIR/Buffer.cs @@ -267,233 +267,50 @@ public SelectedRange Slice(Segment1D segment) /// /// buffer. /// -public abstract class Buffer : Expr +public sealed class Buffer : Expr { - public Buffer(string name, DataType elemType, MemoryLocation memoryLocation, Expr[] operands) - : base(operands.ToArray()) + private static int _globalVarIndex; + + public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions, Expr[] strides) + : base(new[] { memSpan }.Concat(dimensions).Concat(strides)) { Name = name; ElemType = elemType; - MemLocation = memoryLocation; + Rank = dimensions.Length; + GlobalVarIndex = Interlocked.Increment(ref _globalVarIndex); } public string Name { get; } public DataType ElemType { get; } - public MemoryLocation MemLocation { get; } - - /// - /// Gets if this buffer from the constant !. - /// - public TensorConst? Const { get; init; } - /// /// Gets rank of the tensor: number of dimensions. /// - public abstract int Rank { get; } + public int Rank { get; } /// - /// Gets the strides. - /// - /// This Strides is by elements not by bytes! - /// + /// Gets the global var index. /// - public abstract ReadOnlySpan Strides { get; } + public int GlobalVarIndex { get; } /// /// Gets the shape. /// - public abstract ReadOnlySpan Dimensions { get; } - - /// - public override bool Equals(object? obj) - { - if (obj is not Buffer other) - { - return false; - } - - if (Const is not null && !Const.Equals(other.Const)) - { - return false; - } - - return string.Equals(Name, other.Name, StringComparison.Ordinal) && - ElemType.Equals(other.ElemType) && - MemLocation.Equals(other.MemLocation) && - Rank.Equals(other.Rank) && - base.Equals(obj); - } -} - -/// -/// the logical buffer. -/// -public sealed class LogicalBuffer : Buffer -{ - /// - /// Initializes a new instance of the class. - /// create from the IRType. - /// - /// the name. - /// the location. - /// prim type. - /// the shape. - /// the strides. - public LogicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides) - : base(name, elemType, location, ArrayUtility.Concat(dimensions, strides)) - { - Rank = dimensions.Length; - } - - /// - /// Initializes a new instance of the class. - /// . - /// - public LogicalBuffer(string name, MemoryLocation location, TensorConst tensor) - : this(name, tensor.Value.ElementType, location, ArrayUtility.ToExprArray(tensor.Value.Dimensions), ArrayUtility.ToExprArray(tensor.Value.Strides)) - { - Const = tensor; - } - - /// - /// Initializes a new instance of the class. - /// - /// - public LogicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions) - : this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions)) - { - } - - /// - /// Gets get the total length. - /// - public Expr Length => TensorUtilities.GetProduct(Dimensions); + public MemSpan MemSpan => (MemSpan)Operands[0]; /// /// Gets the shape. /// - public override ReadOnlySpan Dimensions => Operands[0..Rank]; + public ReadOnlySpan Dimensions => Operands[1..(1 + Rank)]; /// /// Gets the strides. + /// + /// This Strides is by elements not by bytes! + /// /// - public override ReadOnlySpan Strides => Operands[Rank..]; - - /// - public override int Rank { get; } - - /// - public override string ToString() - { - return $"LogicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})"; - } - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitLogicalBuffer(this, context); - - public LogicalBuffer With(string? name = null, DataType? elemType = null, MemoryLocation? location = null, Expr[]? dimensions = null, Expr[]? strides = null) - => new LogicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? Dimensions, strides ?? Strides) { Const = Const }; -} - -/// -/// the physical buffer. -/// -public sealed class PhysicalBuffer : Buffer -{ - private readonly int[] _fixedDimensions; - private readonly int[] _fixedStrides; - - /// - /// Initializes a new instance of the class. - /// ctor for physical buffer. - /// - public PhysicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides, int start, int size) - : base(name, elemType, location, Array.Empty()) - { - Start = start; - Size = size; - _fixedDimensions = dimensions.ToArray(); - _fixedStrides = strides.ToArray(); - } - - /// - /// Initializes a new instance of the class. - /// . - /// - public PhysicalBuffer(string name, DataType elemType, MemoryLocation location, ReadOnlySpan dimensions, int start, int size) - : this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions), start, size) - { - } - - /// - /// Initializes a new instance of the class. - /// . - /// - public PhysicalBuffer(string name, MemoryLocation location, TensorConst tensor, int start, int size) - : this(name, tensor.Value.ElementType, location, tensor.Value.Dimensions, tensor.Value.Strides, start, size) - { - Const = tensor; - } - - /// - /// Gets fixed dimensions. - /// - public ReadOnlySpan FixedDimensions => _fixedDimensions; - - /// - /// Gets fixed strides. - /// - public ReadOnlySpan FixedStrides => _fixedStrides; - - /// - /// Gets or sets start. - /// - public int Start { get; set; } - - /// - /// Gets total size in bytes. - /// - public int Size { get; init; } - - /// - /// Gets dimensions. - /// - public override ReadOnlySpan Dimensions => ArrayUtility.ToExprArray(FixedDimensions); - - /// - /// Gets strides. - /// - public override ReadOnlySpan Strides => ArrayUtility.ToExprArray(FixedStrides); - - /// - /// Gets shape. - /// - public Shape Shape => new Shape(FixedDimensions); - - /// - public override int Rank => FixedDimensions.Length; - - /// - public override string ToString() - { - return $"PhysicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})"; - } - - /// - public override bool Equals(object? obj) - { - return base.Equals(obj) && obj is PhysicalBuffer other && - FixedDimensions.SequenceEqual(other.FixedDimensions) && - FixedStrides.SequenceEqual(other.FixedStrides); - } + public ReadOnlySpan Strides => Operands[(1 + Rank)..(1 + Rank + Rank)]; - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitPhysicalBuffer(this, context); - - public PhysicalBuffer With(string? name = null, DataType? elemType = null, MemoryLocation? location = null, int[]? dimensions = null, int[]? strides = null, int? start = null, int? size = null) - => new PhysicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? FixedDimensions, strides ?? FixedStrides, start ?? Start, size ?? Size) { Const = Const }; -} + public override TExprResult Accept(ExprFunctor functor, TContext context) => throw new NotImplementedException(); +} \ No newline at end of file From 4b5361388e77bcaa9866b11efba05c42b20348e3 Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Fri, 28 Jul 2023 16:32:41 +0800 Subject: [PATCH 047/308] fix build --- .../CodeGen/CSourceConvertVisitor.cs | 22 +-- .../CodeGen/CSourceExtensions.cs | 1 - .../CodeGen/FunctionBuilder.cs | 97 ++++----- .../Passes/Tile/CPUFusionGroupMutator.cs | 2 +- .../Passes/Tile/SingleCPUFusionConverter.cs | 25 ++- src/Nncase.Core/FunctionCollector.cs | 2 +- src/Nncase.Core/IR/ExprCloner.g.cs | 28 +-- src/Nncase.Core/IR/ExprFunctor.cs | 10 + src/Nncase.Core/IR/ExprFunctor.g.cs | 48 ++--- src/Nncase.Core/IR/ExprRewriter.g.cs | 69 ++----- src/Nncase.Core/IR/ExprVisitor.g.cs | 187 ++++++------------ src/Nncase.Core/IR/IRList.csv | 7 +- src/Nncase.Core/IR/TypeFunctor.cs | 8 + .../Passes/Mutators/FoldBufferSlot.cs | 4 +- .../Passes/Mutators/UnRollLoopSequential.cs | 5 +- src/Nncase.Core/Schedule/ScheduleTypes.cs | 4 +- src/Nncase.Core/TIR/Buffer.cs | 15 +- src/Nncase.Core/TIR/MemSpan.cs | 21 +- src/Nncase.Core/TIR/PrimFunction.cs | 12 +- src/Nncase.Core/TIR/Script.cs | 53 +++-- .../Diagnostics/ScriptPrintVisitor.cs | 24 ++- src/Nncase.Evaluator/TypeInferenceVisitor.cs | 43 ++-- src/Nncase.Passes/DDrBufferSchdeulePass.cs | 168 +++++++--------- .../Rules/Neutral/PrimFuncMergeRule.cs | 2 + src/Nncase.Tests/Core/UnitTestExpression.cs | 12 +- .../Core/UnitTestStringUtility.cs | 6 +- src/Nncase.Tests/Core/UnitTestTIR.cs | 34 +--- .../Diagnostics/UnitTestDumpper.cs | 2 +- .../Evaluator/UnitTestEvaluator.cs | 2 +- .../TIR/PrimFunc/IDataFlowPrimFuncCase.cs | 15 +- .../TIR/PrimFunc/UnitTestPrimFuncMerge.cs | 10 +- src/Nncase.Tests/TIR/UnitTestMutators.cs | 24 +-- .../Transform/UnitTestPassManager.cs | 12 +- .../Transform/UnitTestSubstitutor.cs | 12 +- 34 files changed, 421 insertions(+), 565 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index d75f80cced..9e71b76ea6 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -180,16 +180,16 @@ protected override CSymbol VisitCall(Call expr) break; case IR.Buffers.MatchBuffer op: var n = arguments[0].Name; - var pb = (TIR.PhysicalBuffer)expr[IR.Buffers.MatchBuffer.Input]; + var pb = (TIR.Buffer)expr[IR.Buffers.MatchBuffer.Input]; var ind = new string(Enumerable.Repeat(' ', IndentScope.Writer.Indent).ToArray()); - str = $@"uint32_t _{n}_shape[] = {{ {string.Join(", ", pb.FixedDimensions.ToArray())} }}; -{ind}uint32_t _{n}_stride[] = {{ {string.Join(", ", pb.FixedStrides.ToArray())} }}; + str = $@"uint32_t _{n}_shape[] = {{ {string.Join(", ", pb.Dimensions.AsValueEnumerable().Select(e => Visit(e).Name).ToArray())} }}; +{ind}uint32_t _{n}_stride[] = {{ {string.Join(", ", pb.Strides.AsValueEnumerable().Select(e => Visit(e).Name).ToArray())} }}; {ind}buffer_t _{n} = {{ -{ind}{ind}.vaddr = ((uint8_t*) rdata + {pb.Start}), +{ind}{ind}.vaddr = ((uint8_t*) rdata + {Visit(pb.MemSpan.Start).Name}), {ind}{ind}.paddr = 0, {ind}{ind}.shape = _{n}_shape, {ind}{ind}.stride = _{n}_stride, -{ind}{ind}.rank = {pb.FixedDimensions.Length} }}; +{ind}{ind}.rank = {pb.Dimensions.Length} }}; {ind}buffer_t *{n} = &_{n}"; break; default: @@ -328,16 +328,4 @@ protected override CSymbol VisitIfThenElse(IfThenElse expr) _exprMemo.Add(expr, symbol); return symbol; } - - protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr) - { - if (_exprMemo.TryGetValue(expr, out var symbol)) - { - return symbol; - } - - symbol = new(CSourceBuiltn.BufferType + "*", expr.Name); - _exprMemo.Add(expr, symbol); - return symbol; - } } diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs index e5b48ec442..9ba7ff5656 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs @@ -29,7 +29,6 @@ public static string ToC(this PrimType primType) => public static string ToC(this DataType dataType) => dataType switch { PrimType ptype => ptype.ToC(), - PointerType { ElemType: PrimType etype } => etype.ToC() + "*", _ => throw new NotSupportedException(dataType.ToString()), }; diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs index 7cd4df4f49..a571294a75 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs @@ -72,61 +72,62 @@ public unsafe LinkableFunction Build(TIR.PrimFunction function) // 2. write the desc var descContent = new MemoryStream(); - using (var descWriter = new BinaryWriter(descContent, Encoding.UTF8)) - { - DescHeader header = new() { InputPoolSize = 0, OutputPoolSize = 0, Inputs = 0, Outputs = 0 }; - long headerStart = descWriter.Position(); - descWriter.Skip((ulong)sizeof(DescHeader)); - - foreach (var input in function.Parameters.AsValueEnumerable() - .Where(buf => buf.MemLocation == TIR.MemoryLocation.Input)) - { - header.Inputs++; - var rg = new MemoryRange { Start = checked((uint)input.Start), Size = checked((uint)input.Size) }; - descWriter.Write(ref rg); - header.InputPoolSize = Math.Max(header.InputPoolSize, rg.Start + rg.Size); - descWriter.Write((uint)input.FixedDimensions.Length); - foreach (var dim in input.FixedDimensions) - { - descWriter.Write((uint)dim); - } - foreach (var s in input.FixedStrides) - { - descWriter.Write((uint)s); - } - } - - foreach (var output in function.Parameters.AsValueEnumerable().Where(buf => buf.MemLocation == TIR.MemoryLocation.Output)) - { - header.Outputs++; - var rg = new MemoryRange { Start = checked((uint)output.Start), Size = checked((uint)output.Size) }; - descWriter.Write(ref rg); - header.OutputPoolSize = Math.Max(header.OutputPoolSize, rg.Start + rg.Size); - descWriter.Write((uint)output.FixedDimensions.Length); - foreach (var dim in output.FixedDimensions) - { - descWriter.Write((uint)dim); - } - foreach (var s in output.FixedStrides) - { - descWriter.Write((uint)s); - } - } - - descWriter.Position(headerStart); - descWriter.Write(ref header); - } + // using (var descWriter = new BinaryWriter(descContent, Encoding.UTF8)) + // { + // DescHeader header = new() { InputPoolSize = 0, OutputPoolSize = 0, Inputs = 0, Outputs = 0 }; + // long headerStart = descWriter.Position(); + // descWriter.Skip((ulong)sizeof(DescHeader)); + + // foreach (var input in function.Parameters.AsValueEnumerable() + // .Where(buf => buf.MemLocation == TIR.MemoryLocation.Input)) + // { + // header.Inputs++; + // var rg = new MemoryRange { Start = checked((uint)input.Start), Size = checked((uint)input.Size) }; + // descWriter.Write(ref rg); + // header.InputPoolSize = Math.Max(header.InputPoolSize, rg.Start + rg.Size); + // descWriter.Write((uint)input.FixedDimensions.Length); + // foreach (var dim in input.FixedDimensions) + // { + // descWriter.Write((uint)dim); + // } + // foreach (var s in input.FixedStrides) + // { + // descWriter.Write((uint)s); + // } + // } + + // foreach (var output in function.Parameters.AsValueEnumerable().Where(buf => buf.MemLocation == TIR.MemoryLocation.Output)) + // { + // header.Outputs++; + // var rg = new MemoryRange { Start = checked((uint)output.Start), Size = checked((uint)output.Size) }; + // descWriter.Write(ref rg); + // header.OutputPoolSize = Math.Max(header.OutputPoolSize, rg.Start + rg.Size); + // descWriter.Write((uint)output.FixedDimensions.Length); + // foreach (var dim in output.FixedDimensions) + // { + // descWriter.Write((uint)dim); + // } + // foreach (var s in output.FixedStrides) + // { + // descWriter.Write((uint)s); + // } + // } + + // descWriter.Position(headerStart); + // descWriter.Write(ref header); + // } // 3. write the rdata - foreach (var buffer in function.SchedResult.Rdatas) + foreach (var (@const, range) in function.SchedResult.Rdatas) { - var bytes = buffer.Const!.Value.BytesBuffer; - if ((uint)bytes.Length != buffer.Size) + var bytes = ((TensorConst)@const).Value.BytesBuffer; + var size = range.End.Value - range.Start.Value; + if ((uint)bytes.Length != size) { throw new InvalidDataException("The Buffer Szie Not Equal!"); } - _rdataWriter.Position((uint)buffer.Start); + _rdataWriter.Position((uint)size); _rdataWriter.Write(bytes); } diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs index f1f9daab07..e0d361a6da 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/CPUFusionGroupMutator.cs @@ -129,7 +129,7 @@ protected override Expr RewriteLeafCall(Call expr) int param_count = 0; foreach (var b in prim_func.Parameters) { - if (b.MemLocation == TIR.MemoryLocation.Input) + if (b.MemSpan.Location == TIR.MemoryLocation.Input) { if (is_input) { diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 2cf09245c1..bd45b072ca 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -48,9 +48,9 @@ public ConvertVisitor(List mainBody) public Fusion VisitRootFusion => (Fusion)VisitRoot!; - public IEnumerable OutputBuffers => _buffersMap.Values.OfType().Where(b => b.MemLocation == MemoryLocation.Output); + public IEnumerable OutputBuffers => _buffersMap.Values.OfType().Where(b => b.MemSpan.Location.HasFlag(MemoryLocation.Output)); - public IEnumerable InputBuffers => _buffersMap.Values.OfType().Where(b => b.MemLocation == MemoryLocation.Input); + public IEnumerable InputBuffers => _buffersMap.Values.OfType().Where(b => b.MemSpan.Location.HasFlag(MemoryLocation.Input)); protected override Unit DefaultVisitLeaf(Expr expr) { @@ -98,7 +98,9 @@ private void GenerateMatMul(ReadOnlySpan arguments, Buffer ret, Call exp // [m,k] @ [k, n] var body = T.Block(nameof(MatMul)).Body( - T.Sequential(arguments.ToArray().OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), + T.MatchBuffer(arguments[0]), + T.MatchBuffer(arguments[1]), + T.MatchBuffer(ret), final); _mainBody.Add(body.Build()); } @@ -110,7 +112,10 @@ private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer r var loopVars = loops.Select(f => f.loopVar).ToArray(); Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Unary(unary.UnaryOp, T.BufferLoad(input, loopVars))); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); - _mainBody.Add(T.Block(nameof(Unary)).Body(final).Build()); + _mainBody.Add(T.Block(nameof(Unary)).Body( + T.MatchBuffer(arguments[0]), + T.MatchBuffer(ret), + final).Build()); } private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffer ret, Call call) @@ -136,7 +141,9 @@ private void GenerateBinary(Binary binary, ReadOnlySpan arguments, Buffe Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Binary(binary.BinaryOp, T.BufferLoad(lhsBuffer, lhsLoopVars), T.BufferLoad(rhsBuffer, rhsLoopVars))); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); var body = T.Block(nameof(Binary)).Body( - T.Sequential(arguments.ToArray().OfType().Where(p => p.Const != null).Select(b => T.MatchBuffer(b)).ToArray()), + T.MatchBuffer(arguments[0]), + T.MatchBuffer(arguments[1]), + T.MatchBuffer(ret), final); _mainBody.Add(body.Build()); } @@ -151,19 +158,19 @@ private TIR.Buffer TryAllocateBuffer(Expr expr) case Call c: if (ReferenceEquals(c, VisitRootFusion.Body)) { - buffer = T.PhysicalBuffer(c.CheckedDataType, MemoryLocation.Output, c.CheckedShape.ToValueArray(), out _, name); + buffer = T.AttachBuffer((TensorType)c.CheckedType, MemoryLocation.Output, out _, out _, name); } else { - buffer = T.Buffer(c.CheckedDataType, MemoryLocation.Data, c.CheckedShape.ToValueArray().Select(i => (Expr)i).ToArray(), out _, name); + buffer = T.CreateBuffer((TensorType)c.CheckedDataType, MemoryLocation.Data, out _, name); } break; case Var v: - buffer = T.PhysicalBuffer(v.CheckedDataType, MemoryLocation.Input, v.CheckedShape.ToValueArray(), out _, name); + buffer = T.AttachBuffer((TensorType)v.CheckedType, MemoryLocation.Input, out _, out _, name); break; case TensorConst c: - buffer = T.ConstBuffer(c, out _, name); + buffer = T.AttachBuffer(c, out _, name); break; default: throw new NotSupportedException(); diff --git a/src/Nncase.Core/FunctionCollector.cs b/src/Nncase.Core/FunctionCollector.cs index 9ed2a8bca7..12655d25a1 100644 --- a/src/Nncase.Core/FunctionCollector.cs +++ b/src/Nncase.Core/FunctionCollector.cs @@ -17,7 +17,7 @@ public FunctionCollector() public HashSet Functions => _functions; - protected override int VisitLeafFunction(Function expr, Unit context) + protected override int VisitLeafFunction(Function expr) { _functions.Add(expr); return 0; diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs index 7357f7f78b..855ff4e22e 100644 --- a/src/Nncase.Core/IR/ExprCloner.g.cs +++ b/src/Nncase.Core/IR/ExprCloner.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -57,8 +56,7 @@ protected override Expr VisitLeafIf(If expr, TContext context) return expr.With( condition: Clone(expr.Condition, context), then: Clone(expr.Then, context), - @else: Clone(expr.Else, context), - paramList: expr.ParamList.Select(p => Clone(p, context)).ToArray() + @else: Clone(expr.Else, context) ); } @@ -141,22 +139,6 @@ protected override Expr VisitLeafBlock(TIR.Block expr, TContext context) ); } - /// - protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) - { - return expr.With( - dimensions: CloneArray(expr.Dimensions, context), - strides: CloneArray(expr.Strides, context) - ); - } - - /// - protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) - { - return expr.With( - ); - } - /// protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -237,12 +219,4 @@ protected override Expr VisitLeafIterVar(TIR.IterVar expr, TContext context) ); } - /// - protected override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context) - { - return expr.With( - start: Clone(expr.Start, context), - size: Clone(expr.Size, context) - ); - } } diff --git a/src/Nncase.Core/IR/ExprFunctor.cs b/src/Nncase.Core/IR/ExprFunctor.cs index 4462f8d2cc..0fd8bc9ed9 100644 --- a/src/Nncase.Core/IR/ExprFunctor.cs +++ b/src/Nncase.Core/IR/ExprFunctor.cs @@ -102,6 +102,13 @@ public partial class ExprFunctor : ExprFunctorResult. public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default); + /// + /// Visit point type. + /// + /// pointer type. + /// Result. + public virtual TTypeResult VisitType(PointerType type) => base.VisitType(type, default); + /// /// Visit tuple type. /// @@ -135,6 +142,9 @@ public partial class ExprFunctor : ExprFunctor public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type); + /// + public sealed override TTypeResult VisitType(PointerType type, Unit context) => VisitType(type); + /// public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type); diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs index f6ff8fd928..188aad4659 100644 --- a/src/Nncase.Core/IR/ExprFunctor.g.cs +++ b/src/Nncase.Core/IR/ExprFunctor.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -79,6 +78,11 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context); + /// /// Visit . /// @@ -94,16 +98,6 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context); - - /// - /// Visit . - /// - internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context); - /// /// Visit . /// @@ -144,10 +138,6 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr, TContext context) => DefaultVisit(expr, context); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context); } public partial class ExprFunctor @@ -244,6 +234,13 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); + + /// + internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default); @@ -265,20 +262,6 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); @@ -334,11 +317,4 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); - - /// - internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); } diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs index 6695315197..b842c110f1 100644 --- a/src/Nncase.Core/IR/ExprRewriter.g.cs +++ b/src/Nncase.Core/IR/ExprRewriter.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -92,6 +91,12 @@ protected sealed override Expr VisitLeafTupleConst(TupleConst expr, TContext con return RewriteLeafTupleConst(expr, context); } + /// + protected sealed override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context) + { + return RewriteLeafMemSpan(expr, context); + } + /// protected sealed override Expr VisitLeafVar(Var expr, TContext context) { @@ -110,18 +115,6 @@ protected sealed override Expr VisitLeafBuffer(TIR.Buffer expr, TContext context return RewriteLeafBuffer(expr, context); } - /// - protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) - { - return RewriteLeafLogicalBuffer(expr, context); - } - - /// - protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) - { - return RewriteLeafPhysicalBuffer(expr, context); - } - /// protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -235,6 +228,11 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context); + /// /// Rewrite leaf . /// @@ -250,16 +248,6 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context); - - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context); - /// /// Rewrite leaf . /// @@ -300,10 +288,6 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafIterVar(TIR.IterVar expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context); } public partial class ExprRewriter @@ -412,6 +396,14 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr); + + /// + protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr); + /// /// Rewrite leaf . /// @@ -436,22 +428,6 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr); - - /// - protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr); - - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr); - - /// - protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr); - /// /// Rewrite leaf . /// @@ -516,11 +492,4 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafIterVar(TIR.IterVar expr, Unit context) => RewriteLeafIterVar(expr); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr); - - /// - protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr); } diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index c56e4f5aa7..75dc5ab13d 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -117,20 +116,6 @@ protected internal override TExprResult VisitBlock(TIR.Block expr, TContext cont return VisitLeafBlock(expr, context); } - /// - protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafLogicalBuffer(expr, context); - } - - /// - protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafPhysicalBuffer(expr, context); - } - /// protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -191,13 +176,6 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext return VisitLeafIterVar(expr, context); } - /// - protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafMemSpan(expr, context); - } - /// /// Visit leaf . /// @@ -263,6 +241,11 @@ protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext /// protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context); + /// /// Visit leaf . /// @@ -278,16 +261,6 @@ protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext /// protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context); - - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context); - /// /// Visit leaf . /// @@ -328,11 +301,6 @@ protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext /// protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr, TContext context) => DefaultVisitLeaf(expr, context); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context); - } public partial class ExprVisitor @@ -341,176 +309,154 @@ public partial class ExprVisitor /// Visit . /// internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default); - + /// internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr); /// /// Visit . /// internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default); - + /// internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr); /// /// Visit . /// internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default); - + /// internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr); /// /// Visit . /// internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default); - + /// internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr); /// /// Visit . /// internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default); - + /// internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr); /// /// Visit . /// internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default); - + /// internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr); /// /// Visit . /// internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default); - + /// internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr); /// /// Visit . /// internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default); - + /// internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr); /// /// Visit . /// internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default); - + /// internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); /// /// Visit . /// internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default); - + /// internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr); /// /// Visit . /// internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default); - + /// internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr); /// /// Visit . /// internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default); - + /// internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr); /// /// Visit . /// internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default); - + /// internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); - + /// internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr); /// /// Visit . /// internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default); - + /// internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr); /// /// Visit . /// internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default); - + /// internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr); /// /// Visit . /// internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default); - + /// internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr); /// /// Visit . /// internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default); - + /// internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr); /// /// Visit . /// internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default); - + /// internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr); /// /// Visit . /// internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default); - + /// internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr); /// /// Visit . /// internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default); - + /// internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); - - /// - internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); - /// /// Visit leaf . /// protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default); - + /// protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr); @@ -518,7 +464,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default); - + /// protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr); @@ -526,7 +472,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default); - + /// protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr); @@ -534,15 +480,15 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default); - + /// - protected override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr); + protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr); /// /// Visit leaf . /// protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default); - + /// protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr); @@ -550,7 +496,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default); - + /// protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr); @@ -558,7 +504,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default); - + /// protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr); @@ -566,7 +512,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default); - + /// protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr); @@ -574,7 +520,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default); - + /// protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr); @@ -582,7 +528,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default); - + /// protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr); @@ -590,7 +536,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default); - + /// protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr); @@ -598,7 +544,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default); - + /// protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr); @@ -606,15 +552,23 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default); - + /// protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default); + + /// + protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr); + /// /// Visit leaf . /// protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default); - + /// protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr); @@ -622,7 +576,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default); - + /// protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr); @@ -630,31 +584,15 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default); - + /// protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default); - - /// - protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr); - - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default); - - /// - protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr); - /// /// Visit leaf . /// protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default); - + /// protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr); @@ -662,7 +600,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default); - + /// protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr); @@ -670,7 +608,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default); - + /// protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr); @@ -678,7 +616,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default); - + /// protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr); @@ -686,7 +624,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default); - + /// protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr); @@ -694,7 +632,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default); - + /// protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr); @@ -702,7 +640,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default); - + /// protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr); @@ -710,15 +648,8 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default); - + /// protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default); - - /// - protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr); } diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv index 5ae3c89d18..8a1e2ed2f1 100644 --- a/src/Nncase.Core/IR/IRList.csv +++ b/src/Nncase.Core/IR/IRList.csv @@ -11,14 +11,11 @@ PrimFunctionWrapper,true,true,BaseFunction,,Target TensorConst,true,false,Const,, Tuple,true,false,Default,IR.,@Fields TupleConst,true,false,Const,, +MemSpan,false,false,Default,TIR.,Start;Size; Var,true,false,Default,, Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate -Buffer,false,false,Default,TIR., -LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides -PhysicalBuffer,true,false,Buffer,TIR., -BufferLoad,true,false,Default,TIR.,Buffer;@Indices +Buffer,false,false,Default,TIR.,MemSpan;@Dimensions;@Strides BufferRegion,true,false,Default,TIR.,Buffer;@Region -BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value For,true,false,Default,TIR.,LoopVar;Domain;Body IfThenElse,true,false,Default,TIR.,Condition;Then;Else Let,true,false,Default,TIR.,Var;Expression;Body diff --git a/src/Nncase.Core/IR/TypeFunctor.cs b/src/Nncase.Core/IR/TypeFunctor.cs index 453cfa257a..d63f56c101 100644 --- a/src/Nncase.Core/IR/TypeFunctor.cs +++ b/src/Nncase.Core/IR/TypeFunctor.cs @@ -68,6 +68,14 @@ public virtual TResult VisitType(IRType type, TContext context) /// Result. public virtual TResult VisitType(TensorType type, TContext context) => DefaultVisitType(type, context); + /// + /// Visit pointer type. + /// + /// Pointer type. + /// Context. + /// Result. + public virtual TResult VisitType(PointerType type, TContext context) => DefaultVisitType(type, context); + /// /// Visit tuple type. /// diff --git a/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs index 4bc48af9d1..018183c5d7 100644 --- a/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs +++ b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs @@ -28,7 +28,7 @@ protected override Expr RewriteLeafCall(Call expr) { if (expr.Target is IR.Buffers.BaseMentOf) { - var locate = ((TIR.PhysicalBuffer)expr.Arguments[0]).MemLocation; + var locate = ((TIR.MemSpan)expr.Arguments[0]).Location; return locate switch { MemoryLocation.Input => 0, @@ -40,7 +40,7 @@ protected override Expr RewriteLeafCall(Call expr) } else if (expr.Target is IR.Buffers.DDrOf) { - if (expr.Arguments[0] is TIR.PhysicalBuffer buf) + if (expr.Arguments[0] is TIR.MemSpan buf) { return buf.Start; } diff --git a/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs b/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs index e0a043cc64..0bb934f02a 100644 --- a/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs +++ b/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs @@ -144,8 +144,6 @@ public LoopBodyCloner(IReadOnlyDictionary vmap, Dictionary expr; - protected override Expr VisitLeafVar(Var expr, Unit context) { if (_vmap.TryGetValue(expr, out var result)) @@ -189,9 +187,10 @@ protected override Expr VisitLeafRange(TIR.Range expr, Unit context) return CSE(expr.With(start: Clone(expr.Start, context), stop: Clone(expr.Stop, context), step: Clone(expr.Step, context))); } - protected override Expr VisitLeafLogicalBuffer(LogicalBuffer expr, Unit context) + protected override Expr VisitLeafBuffer(TIR.Buffer expr, Unit context) { return expr.With( + memSpan: Clone(expr.MemSpan, context), dimensions: CloneArray(expr.Dimensions, context).Select(e => CSE(e)).ToArray(), strides: CloneArray(expr.Strides, context)); } diff --git a/src/Nncase.Core/Schedule/ScheduleTypes.cs b/src/Nncase.Core/Schedule/ScheduleTypes.cs index f63e0467dc..8661109f22 100644 --- a/src/Nncase.Core/Schedule/ScheduleTypes.cs +++ b/src/Nncase.Core/Schedule/ScheduleTypes.cs @@ -215,7 +215,7 @@ public SchedFunctionResult() /// /// Gets the buffer allocation. /// - public HashSet Rdatas { get; } + public Dictionary Rdatas { get; } /// /// Gets or sets the data section length. @@ -250,7 +250,7 @@ public override bool Equals(object? obj) return true; } - return EqualityComparer>.Default.Equals(Rdatas, result.Rdatas) && + return EqualityComparer>.Default.Equals(Rdatas, result.Rdatas) && EqualityComparer.Default.Equals(DataUsage, result.DataUsage); } diff --git a/src/Nncase.Core/TIR/Buffer.cs b/src/Nncase.Core/TIR/Buffer.cs index 80f86c4f21..cc5c5f1155 100644 --- a/src/Nncase.Core/TIR/Buffer.cs +++ b/src/Nncase.Core/TIR/Buffer.cs @@ -269,15 +269,12 @@ public SelectedRange Slice(Segment1D segment) /// public sealed class Buffer : Expr { - private static int _globalVarIndex; - public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions, Expr[] strides) : base(new[] { memSpan }.Concat(dimensions).Concat(strides)) { Name = name; ElemType = elemType; Rank = dimensions.Length; - GlobalVarIndex = Interlocked.Increment(ref _globalVarIndex); } public string Name { get; } @@ -289,11 +286,6 @@ public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions /// public int Rank { get; } - /// - /// Gets the global var index. - /// - public int GlobalVarIndex { get; } - /// /// Gets the shape. /// @@ -312,5 +304,8 @@ public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions /// public ReadOnlySpan Strides => Operands[(1 + Rank)..(1 + Rank + Rank)]; - public override TExprResult Accept(ExprFunctor functor, TContext context) => throw new NotImplementedException(); -} \ No newline at end of file + public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitBuffer(this, context); + + public Buffer With(MemSpan? memSpan = null, Expr[]? dimensions = null, Expr[]? strides = null) + => new Buffer(Name, ElemType, memSpan ?? MemSpan, dimensions ?? Dimensions.ToArray(), strides ?? Strides.ToArray()); +} diff --git a/src/Nncase.Core/TIR/MemSpan.cs b/src/Nncase.Core/TIR/MemSpan.cs index 0ee5b4a215..3610d198c3 100644 --- a/src/Nncase.Core/TIR/MemSpan.cs +++ b/src/Nncase.Core/TIR/MemSpan.cs @@ -9,47 +9,48 @@ namespace Nncase.TIR; /// /// the memory type. /// -public enum MemoryLocation : byte +[Flags] +public enum MemoryLocation { /// /// input. /// - Input = 0, + Input = 1 << 1, /// /// output. /// - Output = 1, + Output = 1 << 2, /// /// constant data. /// - Rdata = 2, + Rdata = 1 << 3, /// /// compute temp data. /// - Data = 3, + Data = 1 << 4, /// /// shared data. /// - SharedData = 4, + SharedData = 1 << 5, /// /// l2 data. /// - L2Data = 5, + L2Data = 1 << 6, /// /// L1 data. /// - L1Data = 6, + L1Data = 1 << 7, /// /// base addr. /// - PrivateBase = 64, + PrivateBase = 1 << 8, } public sealed class MemSpan : Expr @@ -81,6 +82,8 @@ public MemSpan(Expr start, Expr size, MemoryLocation location) /// public MemoryLocation Location { get; } + public MemSpan SubSpan(Expr offset, Expr size) => new MemSpan((Start is None ? IR.F.Buffer.DDrOf(this) : Start) + offset, size, Location); + /// public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitMemSpan(this, context); diff --git a/src/Nncase.Core/TIR/PrimFunction.cs b/src/Nncase.Core/TIR/PrimFunction.cs index ea208efb15..2bf94454eb 100644 --- a/src/Nncase.Core/TIR/PrimFunction.cs +++ b/src/Nncase.Core/TIR/PrimFunction.cs @@ -28,8 +28,8 @@ public sealed class PrimFunction : BaseFunction /// module kind. /// Arguments. /// Body. - public PrimFunction(string name, string moduleKind, Sequential body, ReadOnlySpan parameters) - : base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast(parameters))) + public PrimFunction(string name, string moduleKind, Sequential body, ReadOnlySpan parameters) + : base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast(parameters))) { } @@ -39,7 +39,7 @@ public PrimFunction(string name, string moduleKind, Sequential body, ReadOnlySpa /// module kind. /// Arguments. /// Body. - public PrimFunction(string moduleKind, Sequential body, ReadOnlySpan parameters) + public PrimFunction(string moduleKind, Sequential body, ReadOnlySpan parameters) : this($"primfunc_{_globalFuncIndex++}", moduleKind, body, parameters) { } @@ -48,7 +48,7 @@ public PrimFunction(string moduleKind, Sequential body, ReadOnlySpan class. /// build function. /// - public PrimFunction(string moduleKind, Sequential body, params PhysicalBuffer[] parameters) + public PrimFunction(string moduleKind, Sequential body, params Buffer[] parameters) : this($"primfunc_{_globalFuncIndex++}", moduleKind, body, new(parameters)) { } @@ -58,7 +58,7 @@ public PrimFunction(string moduleKind, Sequential body, params PhysicalBuffer[] /// public Sequential Body => (Sequential)Operands[0]; - public ReadOnlySpan Parameters => SpanUtility.UnsafeCast(Operands.Slice(1)); + public ReadOnlySpan Parameters => SpanUtility.UnsafeCast(Operands.Slice(1)); public override IEnumerable ParameterTypes => Parameters.AsValueEnumerable().Select(x => x.CheckedType).ToArray(); @@ -66,7 +66,7 @@ public PrimFunction(string moduleKind, Sequential body, params PhysicalBuffer[] public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitPrimFunction(this, context); - public PrimFunction With(string? name = null, string? moduleKind = null, Sequential? body = null, PhysicalBuffer[]? parameters = null, Schedule.SchedFunctionResult? sched = null) + public PrimFunction With(string? name = null, string? moduleKind = null, Sequential? body = null, Buffer[]? parameters = null, Schedule.SchedFunctionResult? sched = null) => new PrimFunction(name ?? Name, moduleKind ?? ModuleKind, body ?? Body, parameters ?? Parameters) { // note maybe add SchedResult into ctor. diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index ad1b71f0bf..453a336e7f 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -184,7 +184,7 @@ public static ISequentialBuilder Sequential() /// )); /// /// - public static ISequentialBuilder PrimFunc(string name, string module_kind, params PhysicalBuffer[] parameters) + public static ISequentialBuilder PrimFunc(string name, string module_kind, params Buffer[] parameters) { return new SequentialBuilder(body => new PrimFunction(name, module_kind, body, parameters)); } @@ -206,54 +206,73 @@ public static IIfThenElseBuilder If(Expr condition) } /// - /// create the memRef by tensortype. + /// create the buffer by tensortype. /// - public static LogicalBuffer Buffer(DataType elem_type, MemoryLocation location, ReadOnlySpan dimensions, out LogicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { name = name[4..]; } - buffer = new LogicalBuffer(name, elem_type, location, dimensions); + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + var memspan = new MemSpan(size, location); + buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } /// - /// ctor for physical buffer. + /// create buffer by const. /// - public static PhysicalBuffer PhysicalBuffer(DataType elem_type, MemoryLocation location, ReadOnlySpan dimensions, out PhysicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { name = name[4..]; } - buffer = new PhysicalBuffer(name, elem_type, location, dimensions, 0, (int)TensorUtilities.GetProduct(dimensions.ToArray()) * elem_type.SizeInBytes); + var dimensions = @const.ValueType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes; + var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata); + buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } /// - /// create buffer from const. + /// attach the buffer. /// - public static PhysicalBuffer ConstBuffer(Const expr, out PhysicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static Buffer AttachBuffer(Buffer originBuffer, Expr offset, TensorType tensorType, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { name = name[4..]; } - int size; - if (expr is TensorConst tc) - { - size = tc.Value.BytesBuffer.Length; - } - else + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + buffer = new Buffer(name, tensorType.DType, originBuffer.MemSpan.SubSpan(offset, size), dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); + return buffer; + } + + /// + /// attach the buffer. + /// + public static Buffer AttachBuffer(TensorType tensorType, MemoryLocation location, out Var @var, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) { - throw new NotSupportedException(); + name = name[4..]; } - buffer = new PhysicalBuffer(name, MemoryLocation.Rdata, (TensorConst)expr, 0, size); + @var = new Var(name, TensorType.Pointer(tensorType.DType)); + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + buffer = new Buffer(name, tensorType.DType, new MemSpan(@var, size, location), dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs index 9dae59bd85..aa5d9a4126 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs @@ -216,6 +216,22 @@ protected override IPrintSymbol VisitTuple(IR.Tuple expr) return doc; } + protected override IPrintSymbol VisitMemSpan(MemSpan expr) + { + if (_exprMemo.TryGetValue(expr, out var doc)) + { + return doc; + } + + var start = Visit(expr.Start); + var size = Visit(expr.Size); + _scope.Push(); + _scope.Append($"MemSpan({start}, {size})@{expr.Location}"); + doc = new(_scope.Pop()); + _exprMemo.Add(expr, doc); + return doc; + } + /// protected override IPrintSymbol VisitMarker(Marker expr) { @@ -542,12 +558,8 @@ protected override IPrintSymbol VisitBuffer(TIR.Buffer expr) } _scope.Push(); - _scope.Append($"T.Buffer({expr.Name}, {expr.MemLocation}, {VisitType(expr.ElemType)})"); - if (expr is TIR.PhysicalBuffer phy) - { - _scope.Append($"@({phy.Start}, {phy.Size})"); - } - + var memSpan = Visit(expr.MemSpan); + _scope.Append($"T.Buffer({expr.Name}, {VisitType(expr.ElemType)}, {memSpan.Span})"); doc = new(_scope.Pop(), expr.Name, true); _exprMemo.Add(expr, doc); return doc; diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index c0f9f28f55..da0ab274b0 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -68,6 +68,28 @@ protected override IRType VisitLeafBufferRegion(BufferRegion expr) return type; } + protected override IRType VisitLeafBuffer(Nncase.TIR.Buffer expr) + { + VerifySubField(expr, expr.MemSpan, TypePatternUtility.IsTuple()); + foreach (var r in expr.Dimensions) + { + VerifySubField(expr, r, TypePatternUtility.IsIntegralScalar()); + } + + foreach (var r in expr.Strides) + { + VerifySubField(expr, r, TypePatternUtility.IsIntegralScalar()); + } + + var type = new TensorType(expr.ElemType, expr.Dimensions.AsValueEnumerable().Select(e => e switch + { + TensorConst { Value: { Shape: { IsScalar: true } } t } => new Dimension(t.ToScalar()), + _ => Dimension.Unknown, + }).ToArray()); + + return type; + } + /// protected override IRType VisitLeafCall(Call expr) { @@ -174,13 +196,6 @@ protected override IRType VisitLeafLet(Let expr) return type; } - /// - protected override IRType VisitLeafLogicalBuffer(LogicalBuffer expr) - { - var type = new TensorType(expr.ElemType, Shape.Unknown(expr.Rank)); - return type; - } - /// protected override IRType VisitLeafMarker(Marker expr) { @@ -203,13 +218,6 @@ protected override IRType VisitLeafOp(Op expr) return type; } - /// - protected override IRType VisitLeafPhysicalBuffer(PhysicalBuffer expr) - { - var type = new TensorType(expr.ElemType, new(expr.FixedDimensions)); - return type; - } - /// protected override IRType VisitLeafPrimFunction(PrimFunction expr) { @@ -270,6 +278,13 @@ protected override IRType VisitLeafVar(Var expr) return type; } + protected override IRType VisitLeafMemSpan(MemSpan expr) + { + VerifySubField(expr, expr.Start, TypePatternUtility.IsNoneType() | TypePatternUtility.IsIntegralScalar() | TypePatternUtility.IsPointer()); + VerifySubField(expr, expr.Size, TypePatternUtility.IsIntegralScalar()); + return TupleType.Void; + } + /// protected override IRType VisitLet(Let expr) { diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index e26a62ec0c..e21f833763 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -23,9 +23,9 @@ namespace Nncase.Passes; /// public sealed class DDrBufferSchdeulePass : ModulePass { - private readonly Dictionary> _module_usage = new(); + private readonly Dictionary> _moduleUsage = new(); - private readonly Dictionary> _module_hashset = new(); + private readonly Dictionary> _moduleRdataMaps = new(); private readonly bool _enbaleMergeCall; @@ -40,6 +40,7 @@ public DDrBufferSchdeulePass(bool enableMergeCall = false) protected override async Task RunCoreAsync(IRModule module, RunPassContext options) { // 1. merge the all call prim func +#if false if (_enbaleMergeCall) { HashSet mergedFuncs = new(ReferenceEqualityComparer.Instance); @@ -78,6 +79,7 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon module.Remove(item); } } +#endif // 4. schedule the prim funcs. for (int i = 0; i < module.Functions.Count; i++) @@ -86,149 +88,115 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon { if (!prim_func.SchedResult.IsScheduled) { - var ddr_allocator = new DDrBufferAllocator(_module_usage, _module_hashset); - ddr_allocator.Visit(prim_func); // changed ddr buffer. - prim_func.SchedResult.DataUsage = ddr_allocator.DataUsage; - prim_func.SchedResult.IsScheduled = ddr_allocator.Changed; + var rewriter = new DDrBufferRewriter(_moduleUsage, _moduleRdataMaps); + var post = (TIR.PrimFunction)rewriter.Rewrite(prim_func); // changed ddr buffer. + if (rewriter.IsMutated) + { + post.SchedResult.DataUsage = rewriter.DataUsage; + post.SchedResult.IsScheduled = true; + } + + module.Replace(i, prim_func); } } } - _module_hashset.Clear(); - _module_usage.Clear(); + _moduleRdataMaps.Clear(); + _moduleUsage.Clear(); return await Task.FromResult(module); } } -/// -/// collect and assgin the PhysicalBuffer. -/// -internal sealed class DDrBufferAllocator : ExprVisitor +internal sealed class DDrBufferRewriter : ExprRewriter { private readonly Dictionary _functionUsage; - private readonly HashSet _functionHashset; + private readonly Dictionary _functionRdatas; - private PrimFunction? _entry; - - public DDrBufferAllocator(Dictionary> module_usage, Dictionary> module_hashset) + public DDrBufferRewriter(Dictionary> moduleUsage, Dictionary> moduleRdataMaps) { - ModuleUsage = module_usage; - ModuleHashSet = module_hashset; + ModuleUsage = moduleUsage; + ModuleRdataMaps = moduleRdataMaps; _functionUsage = new(); - _functionHashset = new(ReferenceEqualityComparer.Instance); + _functionRdatas = new(); Changed = false; } public Dictionary> ModuleUsage { get; } - public Dictionary> ModuleHashSet { get; } + public Dictionary> ModuleRdataMaps { get; } public bool Changed { get; private set; } public int DataUsage => _functionUsage.GetValueOrDefault(MemoryLocation.Data, 0); - /// - /// only visit one prim func. - /// - protected override bool VisitPrimFunction(PrimFunction primFunction) - { - _entry ??= primFunction; - if (object.ReferenceEquals(_entry, primFunction)) - { - foreach (var physical in primFunction.Parameters) - { - if (physical.MemLocation is MemoryLocation.Input or MemoryLocation.Output) - { - // avoid visit same buffer - if (!_functionHashset.Contains(physical)) - { - // input/output write into the FunctionUsage - if (!_functionUsage.TryGetValue(physical.MemLocation, out var start)) - { - start = 0; - } - - physical.Start = start; - _functionUsage[physical.MemLocation] = start + physical.Size; - _functionHashset.Add(physical); - Changed = true; - } - } - else - { - throw new NotSupportedException($"The prim function parameters mem location must be input/output but get {physical.MemLocation}!"); - } - } - - return base.VisitPrimFunction(_entry); - } - - return true; - } + public PrimFunction Entry => (PrimFunction)VisitRoot!; - protected override bool VisitLeafBuffer(TIR.Buffer buffer) + protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) { - if (buffer is not TIR.PhysicalBuffer physical) - { - return true; - } - - // rdata write into the moduleUsage - if (physical.MemLocation is MemoryLocation.Rdata) + if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const @const) { - if (!ModuleHashSet.TryGetValue(_entry!.ModuleKind, out var module_hashset)) + if (!ModuleRdataMaps.TryGetValue(Entry.ModuleKind, out var moduleRdataMap)) { - module_hashset = new(ReferenceEqualityComparer.Instance); - ModuleHashSet.Add(_entry!.ModuleKind, module_hashset); + moduleRdataMap = new(); + ModuleRdataMaps.Add(Entry.ModuleKind, moduleRdataMap); } - if (!ModuleUsage.TryGetValue(_entry!.ModuleKind, out var module_usage)) + if (!ModuleUsage.TryGetValue(Entry.ModuleKind, out var moduleUsage)) { - module_usage = new(); - ModuleUsage.Add(_entry!.ModuleKind, module_usage); + moduleUsage = new(); + ModuleUsage.Add(Entry.ModuleKind, moduleUsage); } - if (!module_hashset.Contains(physical)) + if (!moduleRdataMap.TryGetValue(@const, out var memRange)) { - if (!module_usage.TryGetValue(physical.MemLocation, out var start)) + if (!moduleUsage.TryGetValue(memSpan.Location, out var start)) { start = 0; } - physical.Start = start; - module_usage[physical.MemLocation] = start + physical.Size; - module_hashset.Add(physical); - _entry.SchedResult.Rdatas.Add(physical); - + _ = ComputeSize(@const); + moduleUsage[memSpan.Location] = start + ComputeSize(@const); + memRange = start..(start + ComputeSize(@const)); + moduleRdataMap.Add(@const, memRange); + Entry.SchedResult.Rdatas.Add(@const, memRange); Changed = true; } - } - else if (physical.MemLocation is MemoryLocation.Data) - { - // data write into the FunctionUsage - if (!_functionHashset.Contains(physical)) - { - if (!_functionUsage.TryGetValue(physical.MemLocation, out var start)) - { - start = 0; - } - physical.Start = start; - _functionUsage[physical.MemLocation] = start + physical.Size; - _functionHashset.Add(physical); - Changed = true; - } - } - else if (physical.MemLocation is MemoryLocation.SharedData) - { - throw new NotSupportedException("Current Not Support!"); + return memSpan.With(memRange.Start.Value, memRange.End.Value - memRange.Start.Value); } - return true; + // else if (memSpan.Location is MemoryLocation.Data) + // { + // data write into the FunctionUsage + // if (!_functionRdatas.Contains(physical)) + // { + // if (!_functionUsage.TryGetValue(physical.Location, out var start)) + // { + // start = 0; + // } + + // physical.Start = start; + // _functionUsage[physical.Location] = start + physical.Size; + // _functionRdatas.Add(physical); + // Changed = true; + // } + // } + // else if (memSpan.Location is MemoryLocation.SharedData) + // { + // throw new NotSupportedException("Current Not Support!"); + // } + return memSpan; } - protected override bool DefaultVisitLeaf(Expr expr) => true; + private int ComputeSize(IValue v) => v.AsTensors().Select(t => t.BytesBuffer.Length).Sum(); + + private int ComputeSize(Const @const) => @const switch + { + TensorConst { Value: Tensor tc } => tc.BytesBuffer.Length, + TupleConst tc => ComputeSize(tc.Value), + _ => throw new NotSupportedException(), + }; } internal sealed class ExternalFuncCollector : ExprWalker diff --git a/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs b/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs index 9e1c2bc22c..16e9b6727c 100644 --- a/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs +++ b/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs @@ -21,6 +21,7 @@ namespace Nncase.Passes.Rules.Neutral; +#if false [RuleGenerator] public sealed partial class PrimFuncMergeRule : RewriteRule { @@ -191,3 +192,4 @@ protected override Expr VisitVar(Var var, Unit context) } } } +#endif diff --git a/src/Nncase.Tests/Core/UnitTestExpression.cs b/src/Nncase.Tests/Core/UnitTestExpression.cs index 9cbe60f849..59741a8735 100644 --- a/src/Nncase.Tests/Core/UnitTestExpression.cs +++ b/src/Nncase.Tests/Core/UnitTestExpression.cs @@ -261,8 +261,8 @@ public void TestDenseTensorLength() public void TestConstBufferNotEqual() { var c = IR.F.Random.Normal(DataTypes.Float32, 1, 0, 0, new[] { 1, 16, 64, 400 }).Evaluate().AsTensor(); - var ddr_ld_input = new TIR.BufferRegion(Nncase.TIR.T.ConstBuffer(Const.FromTensor(c), out _, "ddr_ld_input"), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(TIR.T.AttachBuffer(Const.FromTensor(c), out _, "ddr_ld_input"), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_output = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.NotEqual(ddr_ld_input.Buffer, ddr_ld_output.Buffer); Assert.NotEqual(ddr_ld_input, ddr_ld_output); } @@ -270,8 +270,8 @@ public void TestConstBufferNotEqual() [Fact] public void TestBufferEqual() { - var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_output = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.Equal(ddr_ld_input.Buffer, ddr_ld_output.Buffer); Assert.Equal(ddr_ld_input, ddr_ld_output); } @@ -279,8 +279,8 @@ public void TestBufferEqual() [Fact] public void TestBufferNotEqual() { - var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var glb_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("glb_ld_output", DataTypes.BFloat16, MemoryLocation.Data, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var glb_ld_output = new TIR.BufferRegion(new TIR.Buffer("glb_ld_output", DataTypes.BFloat16, new MemSpan(0, 0, MemoryLocation.Data), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.False(ddr_ld_input.Buffer.Equals(glb_ld_output.Buffer)); Assert.False(ddr_ld_input.Equals(glb_ld_output)); } diff --git a/src/Nncase.Tests/Core/UnitTestStringUtility.cs b/src/Nncase.Tests/Core/UnitTestStringUtility.cs index c17f577ee2..5505371db5 100644 --- a/src/Nncase.Tests/Core/UnitTestStringUtility.cs +++ b/src/Nncase.Tests/Core/UnitTestStringUtility.cs @@ -16,14 +16,14 @@ namespace Nncase.Tests.CoreTest; public static class TestExtensions { - public static ArrayExtensions.SpanWhereEnumerable> InputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == MemoryLocation.Input); + public static ArrayExtensions.SpanWhereEnumerable> InputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemSpan.Location == MemoryLocation.Input); - public static ArrayExtensions.SpanWhereEnumerable> OutputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == MemoryLocation.Output); + public static ArrayExtensions.SpanWhereEnumerable> OutputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemSpan.Location == MemoryLocation.Output); } public sealed class UnitTestStringUtility { - private readonly TIR.PrimFunction _entry = new("test_module", new Sequential(1), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0)); + private readonly TIR.PrimFunction _entry = new("test_module", new Sequential(1), new TIR.Buffer("testInput", DataTypes.Float32, new MemSpan(0, 123, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new TIR.Buffer("testInput", DataTypes.Float32, new MemSpan(0, 123, MemoryLocation.Output), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 }))); [Fact] public void TestJoin() diff --git a/src/Nncase.Tests/Core/UnitTestTIR.cs b/src/Nncase.Tests/Core/UnitTestTIR.cs index 9f3b78cf12..c43b9d52b1 100644 --- a/src/Nncase.Tests/Core/UnitTestTIR.cs +++ b/src/Nncase.Tests/Core/UnitTestTIR.cs @@ -21,15 +21,6 @@ namespace Nncase.Tests.CoreTest; public sealed class UnitTestTIR { - [Fact] - public void TestLogicalBuffer() - { - var logicalBuffer1 = new LogicalBuffer("logicalBuffer", default, new TensorConst(new Tensor(new[] { 1 }))); - var logicalBuffer2 = new LogicalBuffer("logicalBuffer", DataTypes.Int32, default, new[] { (Expr)new Tensor(new[] { 1 }) }); - Assert.Equal(logicalBuffer2.Length.ToString(), logicalBuffer1.Length.ToString()); - Assert.Equal("LogicalBuffer(logicalBuffer, i32, MemLocation)", logicalBuffer1.ToString()); - } - [Fact] public void TestScheduler() { @@ -48,9 +39,9 @@ public void TestScheduler() public void TestBufferStore() { Expr value = 42; - var physicalBuffer = new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0); + TIR.T.CreateBuffer(new TensorType(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var testInput); _ = new Expr[] { 0, 1 }; - _ = T.Store(physicalBuffer, 0, value); + _ = T.Store(testInput, 0, value); } [Fact] @@ -95,15 +86,6 @@ public void TestSequential() Assert.Equal(expect2, actual2); } - [Fact] - public void TestBuffer() - { - var buffer = T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 16, 64, 400 }, out _); - Assert.Equal(DataTypes.Float32, buffer.ElemType); - var expect = new LogicalBuffer("_", DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 16, 64, 400 }); - Assert.Equal(expect, buffer); - } - [Fact] public void TestForSegment() { @@ -132,7 +114,7 @@ public void TestEmit() [Fact] public void TestBufferRegion() { - var buffer = T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 16, 64, 400 }, out _); + var buffer = T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out _); var region = new Range[] { new Range(1, 2, 2), new Range(-1, 3, 2) }; var bufferRegion = new BufferRegion(buffer, region); @@ -154,8 +136,8 @@ public void TestPrimFunction() { var primFunc = new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), new[] { - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), }); var primFuncParameters = primFunc.Parameters; @@ -167,8 +149,8 @@ public void TestPrimFunction() var newBody = new Sequential(new Expr[] { 3 }); var newParams = new[] { - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), }; var newPrimFunc = primFunc.With(moduleKind: newModuleKind, body: newBody, parameters: newParams); @@ -179,7 +161,7 @@ public void TestPrimFunction() Assert.Equal(newParams, newPrimFunc.Parameters.ToArray()); Assert.Equal(primFunc.Name, newPrimFunc.Name); // should not change the name - Assert.NotNull(new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), default(ReadOnlySpan))); + Assert.NotNull(new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), default(ReadOnlySpan))); } [Fact] diff --git a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs index bd9bcc3c5c..9e2f0f3a44 100644 --- a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs +++ b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs @@ -65,7 +65,7 @@ public void TestDumpFusion() [Fact] public void TestDumpScript() { - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out _)).Body(T.Nop()).Build(); Assert.True(CompilerServices.InferenceType(prim_func_1)); diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs index 0fdb6fe4da..037e40793b 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs @@ -97,7 +97,7 @@ public void TestOnnxResizeImage() public void TestLoadStore() { var loop_i = new Var(TensorType.Scalar(DataTypes.Int32)); - T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3 }, out var bf); + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3 }), MemoryLocation.Input, out var bf); var load = T.Load(bf, loop_i); CompilerServices.InferenceType(load); var store = T.Store(bf, loop_i, IR.F.Tensors.Cast(loop_i, DataTypes.Float32)); diff --git a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs index e514c3e09a..0721bfd29d 100644 --- a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs +++ b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs @@ -71,7 +71,7 @@ public static PrimFunctionWrapper MakeBinaryFunc(BinaryOp binaryOp, bool mask) public static PrimFunctionWrapper MakeMultiInputFunc(int length, bool mask) { var allocator = new Allocator(); - var fusion_inputs = new List(); + var fusion_inputs = new List(); for (int i = 0; i < length; i++) { var fusion_input_i = allocator.Allocate($"fusion_{_count}_input_{i}", TIR.MemoryLocation.Input); @@ -124,18 +124,19 @@ private static IEnumerable GetBinaryOp(int length) private sealed class Allocator { - private readonly Dictionary _useage = new() { + private readonly Dictionary _usage = new() { { TIR.MemoryLocation.Input, 0 }, { TIR.MemoryLocation.Output, 0 }, { TIR.MemoryLocation.L2Data, 0 }, }; - public TIR.PhysicalBuffer Allocate(string name, TIR.MemoryLocation location) + public TIR.Buffer Allocate(string name, TIR.MemoryLocation location) { - var strides = TensorUtilities.GetStrides(Dimensions); - var size = TensorUtilities.GetSize(Dimensions, strides, DataTypes.Float32.SizeInBytes); - var buffer = new TIR.PhysicalBuffer(name, DataTypes.Float32, location, Dimensions, strides, _useage[location], size); - _useage[location] += size; + var dims = Dimensions.Select(d => (Expr)d).ToArray(); + var strides = TensorUtilities.GetStrides(Dimensions).Select(s => (Expr)s).ToArray(); + var size = TensorUtilities.GetSize(Dimensions, TensorUtilities.GetStrides(Dimensions), DataTypes.Float32.SizeInBytes); + var buffer = new TIR.Buffer(name, DataTypes.Float32, new TIR.MemSpan(_usage[location], size, location), dims, strides); + _usage[location] += size; return buffer; } } diff --git a/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs b/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs index f5183e5127..1b5ce74461 100644 --- a/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs +++ b/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs @@ -139,8 +139,8 @@ public IValue Evaluate() // 1. copy input into input pool foreach (var (arg, param) in _args.Zip(_wrapper.Target.Parameters[.._wrapper.ParametersCount].ToArray())) { - Assert.Equal(param.Size, arg.AsTensor().BytesBuffer.Length); - arg.AsTensor().BytesBuffer.CopyTo(_poolMap[param.MemLocation].AsSpan(param.Start)); + Assert.Equal(param.MemSpan.Size.Evaluate().AsTensor().ToScalar(), arg.AsTensor().BytesBuffer.Length); + arg.AsTensor().BytesBuffer.CopyTo(_poolMap[param.MemSpan.Location].AsSpan(param.MemSpan.Start.Evaluate().AsTensor().ToScalar())); } // 2. start l2 computing @@ -153,7 +153,7 @@ public IValue Evaluate() var tensors = new List(); foreach (var outputParam in _wrapper.Target.Parameters[_wrapper.ParametersCount..]) { - tensors.Add(Tensor.FromBytes(outputParam.ElemType, GetBufferSpan(outputParam).ToArray(), outputParam.FixedDimensions.ToArray())); + tensors.Add(Tensor.FromBytes(outputParam.ElemType, GetBufferSpan(outputParam).ToArray(), outputParam.Dimensions.AsValueEnumerable().Select(e => e.Evaluate().AsTensor().ToScalar()).ToArray())); } return tensors.Count == 1 ? Value.FromTensor(tensors[0]) : Value.FromTensors(tensors.ToArray()); @@ -208,7 +208,7 @@ private void EvaluateStatement(Expr statement) private Span GetBufferSpan(Expr expr) { - var buffer = Assert.IsType(expr); - return _poolMap[buffer.MemLocation].AsSpan(buffer.Start, buffer.Size); + var buffer = Assert.IsType(expr); + return _poolMap[buffer.MemSpan.Location].AsSpan(buffer.MemSpan.Start.Evaluate().AsTensor().ToScalar(), buffer.MemSpan.Size.Evaluate().AsTensor().ToScalar()); } } diff --git a/src/Nncase.Tests/TIR/UnitTestMutators.cs b/src/Nncase.Tests/TIR/UnitTestMutators.cs index a20d5f7295..7317eb1176 100644 --- a/src/Nncase.Tests/TIR/UnitTestMutators.cs +++ b/src/Nncase.Tests/TIR/UnitTestMutators.cs @@ -30,9 +30,9 @@ public UnitTestMutators() [Fact] public async Task TestFoldConstCallWithTuple() { - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 48 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 9 }, out var glb_if_ping); - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 9 }, out var glb_if_pong); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 48 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 9 }), MemoryLocation.Data, out var glb_if_ping); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 9 }), MemoryLocation.Data, out var glb_if_pong); PrimFunction main; { main = T.PrimFunc("main", Callable.StackVMModuleKind, ddr_if).Body( @@ -118,8 +118,8 @@ public async Task TestUnRollLoopSequential() [Fact] public async Task TestUnRollLoopSequential2() { - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 10, 5, 9 }), MemoryLocation.Data, out var glb_if); PrimFunction main; { @@ -201,8 +201,8 @@ public async Task TestUnRollLoopSequential2() [Fact] public async Task TestUnRollLoopSequential3() { - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 10, 5, 9 }), MemoryLocation.Data, out var glb_if); PrimFunction main; { @@ -362,10 +362,10 @@ public async Task TestFoldLet2() [Fact] public async Task TestFoldBufferIndex() { - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Output, new[] { 3, 16, 24, 24 }, out var ddr_of); - T.PhysicalBuffer(DataTypes.BFloat16, MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); - var bufferIndexMap = new Dictionary() { + T.CreateBuffer(new(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Output, out var ddr_of); + T.CreateBuffer(new(DataTypes.BFloat16, new[] { 3, 10, 5, 9 }), MemoryLocation.Data, out var glb_if); + var bufferIndexMap = new Dictionary() { { ddr_if, 2 }, { ddr_of, 4 }, }; @@ -386,7 +386,7 @@ public async Task TestFoldBufferIndex() pass.Add(); pass.Add(Expr? (Expr e) => { - if (e is Call { } call && call.Arguments[0] is PhysicalBuffer physicalBuffer && bufferIndexMap.TryGetValue(physicalBuffer, out var index)) + if (e is Call { } call && call.Arguments[0] is Buffer physicalBuffer && bufferIndexMap.TryGetValue(physicalBuffer, out var index)) { return index; } diff --git a/src/Nncase.Tests/Transform/UnitTestPassManager.cs b/src/Nncase.Tests/Transform/UnitTestPassManager.cs index 85d4cf1442..cfc76bb4eb 100644 --- a/src/Nncase.Tests/Transform/UnitTestPassManager.cs +++ b/src/Nncase.Tests/Transform/UnitTestPassManager.cs @@ -22,7 +22,7 @@ public sealed class UnitTestPassManager : TestClassBase [Fact] public void TestPassMangerUpdateDependence() { - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out _)).Body(T.Nop()).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); @@ -30,7 +30,7 @@ public void TestPassMangerUpdateDependence() var main_func = new Function("main", new Call(prim_wrapper, input), input); // prim_func_2 for update - var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body( + var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out _)).Body( T.Nop(), T.Nop()).Build(); @@ -54,15 +54,15 @@ public void TestPassMangerUpdateDependence2() %3 = %func_3(%2): // f16[1,23,30,16] */ - var prim_func_0 = T.PrimFunc("prim_func_0", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 24, 32, 3 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_0 = T.PrimFunc("prim_func_0", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 24, 32, 3 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Output, out var _)).Body( T.Nop()).Build(); var func_0 = new PrimFunctionWrapper(prim_func_0, 1); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Output, out var _)).Body( T.Nop()).Build(); var func_1 = new PrimFunctionWrapper(prim_func_1, 1); - var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 23, 30, 16 }, out var _)).Body( + var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 23, 30, 16 }), MemoryLocation.Output, out var _)).Body( T.Nop()).Build(); var func_2 = new PrimFunctionWrapper(prim_func_2, 1); @@ -74,7 +74,7 @@ public void TestPassMangerUpdateDependence2() Assert.True(CompilerServices.InferenceType(main_func)); // prim_func_2 for update - var prim_func_1_update = T.PrimFunc("prim_func_1_update", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_1_update = T.PrimFunc("prim_func_1_update", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Output, out var _)).Body( T.Nop(), T.Nop()).Build(); diff --git a/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs b/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs index f6a1ed969e..0864bca6de 100644 --- a/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs +++ b/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs @@ -24,8 +24,8 @@ public sealed class UnitTestSubstitutor : TestClassBase public void TestSubstitutorFailed() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3, 4 }, out var hd); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var input_a), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out var input_b)).Body( T.Load(hd, loop_i)).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); @@ -49,8 +49,8 @@ public void TestSubstitutorFailed() public void TestSubstitutorTrue() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3, 4 }, out var hd); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var input_a), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out var input_b)).Body( T.Load(hd, loop_i)).Build(); Dictionary vmap = new() { { loop_i, 1 } }; @@ -67,8 +67,8 @@ public void TestSubstitutorTrue() public void TestSubstitutorTrue2() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 2, 3, 4 }, out var hd); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Int32, MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var input_a), T.CreateBuffer(new(DataTypes.Int32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out var input_b)).Body( T.Load(hd, loop_i)).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); From 8d01066b76f8b65a109c1beb6b2b71fe443a16fb Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Fri, 28 Jul 2023 19:26:03 +0800 Subject: [PATCH 048/308] refactor test --- .../CodeGen/CSourceBuiltn.cs | 4 +- .../CodeGen/CSourceConvertVisitor.cs | 6 +- .../CodeGen/CSourceExtensions.cs | 1 + .../CodeGen/LinkableFunction.cs | 4 +- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 4 +- modules/cpu/src/runtime/elfloader.cpp | 4 +- modules/cpu/src/runtime/elfloader.h | 8 +- modules/cpu/src/runtime/runtime_function.cpp | 144 ++++++++---------- modules/cpu/src/runtime/runtime_function.h | 12 +- src/Nncase.Core/IR/ExprVisitor.g.cs | 28 ++++ src/Nncase.Core/IR/IRList.csv | 4 +- .../Passes/Mutators/FlattenBuffer.cs | 9 +- src/Nncase.Core/TIR/Ops.cs | 4 +- src/Nncase.Core/TIR/Script.cs | 6 +- src/Nncase.Evaluator/TIR/Load.cs | 7 +- src/Nncase.Evaluator/TIR/Store.cs | 11 +- src/Nncase.Evaluator/TypeInferenceVisitor.cs | 1 + 17 files changed, 143 insertions(+), 114 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index 269e8e5492..b2e0a6fcb3 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -79,9 +79,9 @@ public static class CSourceBuiltn #include "; - public const string FixedParameters = "nncase_mt_t* nncase_mt, void* data, void* rdata"; + public const string FixedParameters = "nncase_mt_t* nncase_mt, uint8_t* data, const uint8_t* rdata"; - public const string MainPrologue = $@"void _start(size_t func_id, buffer_t** buffers, {FixedParameters}) {{"; + public const string MainPrologue = $@"void _start(size_t func_id, uint8_t** buffers, {FixedParameters}) {{"; public const string MainEpilogue = @"}"; diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 9e71b76ea6..64658b1b06 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -125,7 +125,7 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) throw new NotSupportedException("The PrimFunction must return void!"); } - var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})"; + var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})"; using (var scope = new IndentScope(_implBuilder)) { @@ -173,10 +173,10 @@ protected override CSymbol VisitCall(Call expr) str = CSourceUtilities.ContertUnary(op, arguments); break; case Store: - str = $"((({arguments[2].Type} *){arguments[0].Name}->vaddr)[{arguments[1].Name}] = {arguments[2].Name})"; + str = $"((({arguments[2].Type} *){arguments[0].Name})[{arguments[1].Name}] = {arguments[2].Name})"; break; case Load: - str = $"((({type} *){arguments[0].Name}->vaddr)[{arguments[1].Name}])"; + str = $"((({type} *){arguments[0].Name})[{arguments[1].Name}])"; break; case IR.Buffers.MatchBuffer op: var n = arguments[0].Name; diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs index 9ba7ff5656..c36037d603 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs @@ -29,6 +29,7 @@ public static string ToC(this PrimType primType) => public static string ToC(this DataType dataType) => dataType switch { PrimType ptype => ptype.ToC(), + PointerType ptype => "uint8_t *", _ => throw new NotSupportedException(dataType.ToString()), }; diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index 8757cc2ccd..bab60300b6 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -15,7 +15,9 @@ public LinkableFunction(uint id, byte[] descContents, TIR.PrimFunction sourceFun PrimFunction = sourceFunction; FunctionCSource = funcCSource; Text = Array.Empty(); - Sections = new ILinkedSection[] { new LinkedSection(descContents, ".desc", 0, 8, (uint)descContents.Length) }; + + // new LinkedSection(descContents, ".desc", 0, 8, (uint)descContents.Length) + Sections = new ILinkedSection[] { }; } public uint Id { get; } diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 1e2657ab3b..a354c86557 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -87,7 +87,6 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp passManager.Add().Configure(p => { p.Add(); - p.Add(); p.Add(); p.Add(); }); @@ -96,8 +95,9 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp passManager.AddWithName("InstStage").Configure(p => { + p.Add(); p.Add(); - p.Add(); // 折叠自定义op + p.Add(); }); } diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp index f20063b9d5..c41d7e11ab 100644 --- a/modules/cpu/src/runtime/elfloader.cpp +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -4,8 +4,8 @@ using namespace nncase; using namespace nncase::runtime; using namespace nncase::runtime::cpu; -int elfloader::invoke_elf(size_t id, buffer_t **buffers, nncase_mt_t *nncase_mt, - void *data, const void *rdata) { +int elfloader::invoke_elf(size_t id, uint8_t **buffers, nncase_mt_t *nncase_mt, + void *data, const uint8_t *rdata) { check(el_init(&ctx_), "initialising"); diff --git a/modules/cpu/src/runtime/elfloader.h b/modules/cpu/src/runtime/elfloader.h index 7a0aa36f00..5fc7ceaec3 100644 --- a/modules/cpu/src/runtime/elfloader.h +++ b/modules/cpu/src/runtime/elfloader.h @@ -13,8 +13,8 @@ BEGIN_NS_NNCASE_RT_MODULE(cpu) -typedef void (*entrypoint_t)(size_t id, buffer_t **buffers, - nncase_mt_t *nncase_mt, void *data, const void *rdata); +typedef void (*entrypoint_t)(size_t id, uint8_t **buffers, + nncase_mt_t *nncase_mt, void *data, const uint8_t *rdata); class elfloader { public: @@ -49,8 +49,8 @@ class elfloader { } } - int invoke_elf(size_t id, buffer_t **buffers, nncase_mt_t *nncase_mt, - void *data, const void *rdata); + int invoke_elf(size_t id, uint8_t **buffers, nncase_mt_t *nncase_mt, + void *data, const uint8_t *rdata); private: void *ptr_; diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 422fd23e59..776da74410 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -48,61 +48,61 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { - try_(context.read_section(".desc", [this](auto sr, size_t) -> result { - auto header = sr.template read(); - if (parameters_size() != header.inputs + header.outputs) - return nncase::err(std::errc::invalid_argument); - - for (uint32_t i = 0; i < header.inputs; i++) { - sr.template read(); - auto rank = sr.template read(); - std::vector shape(rank); - std::cout << "shape: "; - for (uint32_t j = 0; j < rank; j++) { - shape[j] = sr.template read(); - std::cout << shape[j] << ", "; - } - std::cout << std::endl; - - std::vector stride(rank); - std::cout << "stride: "; - for (uint32_t j = 0; j < rank; j++) { - stride[j] = sr.template read(); - std::cout << stride[j] << ", "; - } - std::cout << std::endl; - - input_ranks_.emplace_back(rank); - input_shapes_.emplace_back(shape); - input_strides_.emplace_back(stride); - } - - for (uint32_t i = 0; i < header.outputs; i++) { - sr.template read(); - auto rank = sr.template read(); - std::vector shape(rank); - std::cout << "shape: "; - for (uint32_t j = 0; j < rank; j++) { - shape[j] = sr.template read(); - std::cout << shape[j] << ", "; - } - std::cout << std::endl; - - std::vector stride(rank); - std::cout << "stride: "; - for (uint32_t j = 0; j < rank; j++) { - stride[j] = sr.template read(); - std::cout << stride[j] << ", "; - } - std::cout << std::endl; - - output_ranks_.emplace_back(rank); - output_shapes_.emplace_back(shape); - output_strides_.emplace_back(stride); - } - - return ok(); - })); + // try_(context.read_section(".desc", [this](auto sr, size_t) -> result { + // auto header = sr.template read(); + // if (parameters_size() != header.inputs + header.outputs) + // return nncase::err(std::errc::invalid_argument); + + // for (uint32_t i = 0; i < header.inputs; i++) { + // sr.template read(); + // auto rank = sr.template read(); + // std::vector shape(rank); + // std::cout << "shape: "; + // for (uint32_t j = 0; j < rank; j++) { + // shape[j] = sr.template read(); + // std::cout << shape[j] << ", "; + // } + // std::cout << std::endl; + + // std::vector stride(rank); + // std::cout << "stride: "; + // for (uint32_t j = 0; j < rank; j++) { + // stride[j] = sr.template read(); + // std::cout << stride[j] << ", "; + // } + // std::cout << std::endl; + + // input_ranks_.emplace_back(rank); + // input_shapes_.emplace_back(shape); + // input_strides_.emplace_back(stride); + // } + + // for (uint32_t i = 0; i < header.outputs; i++) { + // sr.template read(); + // auto rank = sr.template read(); + // std::vector shape(rank); + // std::cout << "shape: "; + // for (uint32_t j = 0; j < rank; j++) { + // shape[j] = sr.template read(); + // std::cout << shape[j] << ", "; + // } + // std::cout << std::endl; + + // std::vector stride(rank); + // std::cout << "stride: "; + // for (uint32_t j = 0; j < rank; j++) { + // stride[j] = sr.template read(); + // std::cout << stride[j] << ", "; + // } + // std::cout << std::endl; + + // output_ranks_.emplace_back(rank); + // output_shapes_.emplace_back(shape); + // output_strides_.emplace_back(stride); + // } + + // return ok(); + // })); return ok(); } @@ -111,38 +111,20 @@ result cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, NNCASE_UNUSED value_t return_value) noexcept { try_var(id, module().find_id_by_function(this)); - std::cout << "call " << id << std::endl; - - std::vector buffers(input_ranks_.size() + output_ranks_.size()); + uint8_t **buffers = new uint8_t*[parameters.size()]; // input buffer - for (uint32_t i = 0; i < input_ranks_.size(); i++) { - auto input_tensor = parameters[i].as().expect( - "input " + std::to_string(i) + " is not a tensor"); + for (size_t i = 0; i < parameters.size(); i++) { + try_var(input_tensor, parameters[i].as()); try_var(input_span, get_input_span(input_tensor)); - buffer_t *input_buffer = - new buffer_t{input_span.data(), 0, input_shapes_[i].data(), - input_strides_[i].data(), input_ranks_[i]}; - buffers[i] = input_buffer; - } - - // output buffer - for (uint32_t i = 0; i < output_ranks_.size(); i++) { - auto output_tensor = - parameters[input_ranks_.size() + i].as().expect( - "output " + std::to_string(i) + " is not a tensor"); - try_var(output_span, get_output_span(output_tensor)); - buffer_t *output_buffer = - new buffer_t{output_span.data(), 0, output_shapes_[i].data(), - output_strides_[i].data(), output_ranks_[i]}; - buffers[input_ranks_.size() + i] = output_buffer; + buffers[i] = (uint8_t *)(input_span.data()); } auto elfloader_ = elfloader{(char *)module().text_physical().data()}; - elfloader_.invoke_elf(id, buffers.data(), &nncase_mt, nullptr, (const void *)module().rdata_physical().data()); - for (int i = 0; i < buffers.size(); i++) { - delete buffers[i]; - } + elfloader_.invoke_elf(id, buffers, &nncase_mt, nullptr, + (const uint8_t *)module().rdata_physical().data()); + + delete[] buffers; return ok(tuple(std::in_place)); } \ No newline at end of file diff --git a/modules/cpu/src/runtime/runtime_function.h b/modules/cpu/src/runtime/runtime_function.h index ba19300437..74b24e0319 100644 --- a/modules/cpu/src/runtime/runtime_function.h +++ b/modules/cpu/src/runtime/runtime_function.h @@ -32,12 +32,12 @@ class cpu_runtime_function : public runtime_function { value_t return_value) noexcept override; private: - std::vector input_ranks_; - std::vector> input_shapes_; - std::vector> input_strides_; - std::vector output_ranks_; - std::vector> output_shapes_; - std::vector> output_strides_; + // std::vector input_ranks_; + // std::vector> input_shapes_; + // std::vector> input_strides_; + // std::vector output_ranks_; + // std::vector> output_shapes_; + // std::vector> output_strides_; }; END_NS_NNCASE_RT_MODULE diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index 75dc5ab13d..dd974ec60b 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -102,6 +102,13 @@ protected internal override TExprResult VisitTupleConst(TupleConst expr, TContex return VisitLeafTupleConst(expr, context); } + /// + protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) + { + VisitOperands(expr, context); + return VisitLeafMemSpan(expr, context); + } + /// protected internal override TExprResult VisitVar(Var expr, TContext context) { @@ -116,6 +123,13 @@ protected internal override TExprResult VisitBlock(TIR.Block expr, TContext cont return VisitLeafBlock(expr, context); } + /// + protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext context) + { + VisitOperands(expr, context); + return VisitLeafBuffer(expr, context); + } + /// protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -383,6 +397,13 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); + + /// + internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default); @@ -397,6 +418,13 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default); + + /// + internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv index 8a1e2ed2f1..ba9dd8033b 100644 --- a/src/Nncase.Core/IR/IRList.csv +++ b/src/Nncase.Core/IR/IRList.csv @@ -11,10 +11,10 @@ PrimFunctionWrapper,true,true,BaseFunction,,Target TensorConst,true,false,Const,, Tuple,true,false,Default,IR.,@Fields TupleConst,true,false,Const,, -MemSpan,false,false,Default,TIR.,Start;Size; +MemSpan,true,false,Default,TIR.,Start;Size; Var,true,false,Default,, Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate -Buffer,false,false,Default,TIR.,MemSpan;@Dimensions;@Strides +Buffer,true,false,Default,TIR.,MemSpan;@Dimensions;@Strides; BufferRegion,true,false,Default,TIR.,Buffer;@Region For,true,false,Default,TIR.,LoopVar;Domain;Body IfThenElse,true,false,Default,TIR.,Condition;Then;Else diff --git a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs index a941477030..28a619fdbb 100644 --- a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs +++ b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs @@ -35,13 +35,18 @@ protected override Expr RewriteLeafCall(Call expr) { var indices = (IR.Tuple)expr[IR.Buffers.BufferLoad.Indices]; var input = (TIR.Buffer)expr[IR.Buffers.BufferLoad.Input]; - return T.Load(input, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i]))); + return T.Load(input.MemSpan.Start, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i]))); } else if (expr.Target is IR.Buffers.BufferStore) { var indices = (IR.Tuple)expr[IR.Buffers.BufferStore.Indices]; var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input]; - return T.Store(input, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); + return T.Store(input.MemSpan.Start, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); + } + else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: (Const or Var) } }) + { + // remove the all fixed match operation. + return T.Nop(); } return expr; diff --git a/src/Nncase.Core/TIR/Ops.cs b/src/Nncase.Core/TIR/Ops.cs index cb6446fb76..3405cdc841 100644 --- a/src/Nncase.Core/TIR/Ops.cs +++ b/src/Nncase.Core/TIR/Ops.cs @@ -19,7 +19,7 @@ public sealed partial class Load : Op /// /// Gets handle. /// - public static readonly ParameterInfo Handle = new(typeof(Load), 0, "handle"); + public static readonly ParameterInfo Handle = new(typeof(Load), 0, "handle", IsPointer() | IsIntegralScalar()); /// /// Gets index. @@ -56,7 +56,7 @@ public sealed partial class Store : Op /// /// The buffer variable handle. /// - public static readonly ParameterInfo Handle = new(typeof(Store), 0, "handle"); + public static readonly ParameterInfo Handle = new(typeof(Store), 0, "handle", IsPointer() | IsIntegralScalar()); /// /// The index locations to be stored. diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index 453a336e7f..28740e43ab 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -52,7 +52,7 @@ public static class T /// /// The buffer handle variable in the load expression. /// The index in the load. - public static Call Load(TIR.Buffer handle, Expr index) => new Call(new Load(), handle, index); + public static Call Load(Expr handle, Expr index) => new Call(new Load(), handle, index); /// /// get the nop op. @@ -76,7 +76,7 @@ public static class T /// The buffer Variable. /// The index in the store expression. /// The value we want to store. - public static Call Store(TIR.Buffer handle, Expr index, Expr value) => new Call(new Store(), handle, index, value); + public static Call Store(Expr handle, Expr index, Expr value) => new Call(new Store(), handle, index, value); /// /// build for loop. @@ -268,7 +268,7 @@ public static Buffer AttachBuffer(TensorType tensorType, MemoryLocation location name = name[4..]; } - @var = new Var(name, TensorType.Pointer(tensorType.DType)); + @var = new Var(TensorType.Pointer(tensorType.DType)); var dimensions = tensorType.Shape.ToValueArray(); var strides = TensorUtilities.GetStrides(dimensions); var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; diff --git a/src/Nncase.Evaluator/TIR/Load.cs b/src/Nncase.Evaluator/TIR/Load.cs index 86885bcc48..5898e3d353 100644 --- a/src/Nncase.Evaluator/TIR/Load.cs +++ b/src/Nncase.Evaluator/TIR/Load.cs @@ -30,6 +30,11 @@ public string Visit(IIRPrinterContext context, Load target, bool iLmode) private IRType Visit(Load target, TensorType handle, TensorType index) { - return TensorType.Scalar(handle.DType); + if (handle is not TensorType { DType: PointerType { } p }) + { + return new InvalidType("handle must be pointer type!"); + } + + return TensorType.Scalar(p.ElemType); } } diff --git a/src/Nncase.Evaluator/TIR/Store.cs b/src/Nncase.Evaluator/TIR/Store.cs index 9a1d6d6cda..b46bf57f52 100644 --- a/src/Nncase.Evaluator/TIR/Store.cs +++ b/src/Nncase.Evaluator/TIR/Store.cs @@ -24,18 +24,23 @@ public IRType Visit(ITypeInferenceContext context, Store target) public string Visit(IIRPrinterContext context, Store target, bool iLmode) { var handle = context.GetArgument(target, Store.Handle); - _ = context.GetArgument(target, Store.Value); + var value = context.GetArgument(target, Store.Value); var index = context.GetArgument(target, Store.Index); - return $"{handle}[{index}] = {index}"; + return $"{handle}[{index}] = {value}"; } private IRType Visit(Store target, TensorType handle, TensorType index, TensorType value) { - if (handle.DType != value.DType) + if (handle.DType is not PointerType { ElemType: DataType elemType } || elemType != value.DType) { return new InvalidType($"You Can't Load The {value.DType} To {handle.DType}"); } + if (index.DType != DataTypes.Int32) + { + return new InvalidType($"store value type {index.DType} not supported"); + } + return TupleType.Void; } } diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index da0ab274b0..365528b6d5 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -68,6 +68,7 @@ protected override IRType VisitLeafBufferRegion(BufferRegion expr) return type; } + /// protected override IRType VisitLeafBuffer(Nncase.TIR.Buffer expr) { VerifySubField(expr, expr.MemSpan, TypePatternUtility.IsTuple()); From c2b23394479679bba47a28137015de9469d815f5 Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Mon, 31 Jul 2023 12:18:05 +0800 Subject: [PATCH 049/308] fix const buffer --- .../CodeGen/CSourceConvertVisitor.cs | 34 ++++++++++++++++++- src/Nncase.Core/DataType.cs | 5 +-- src/Nncase.Core/DataTypes.cs | 4 +-- src/Nncase.Core/IR/Buffers/DDrOf.cs | 3 ++ src/Nncase.Core/IR/IRType.cs | 2 +- src/Nncase.Core/IR/Var.cs | 2 +- .../Passes/Mutators/FlattenBuffer.cs | 4 +-- src/Nncase.Core/Tensor.cs | 6 ++-- .../Diagnostics/ScriptPrintVisitor.cs | 11 +++--- src/Nncase.Evaluator/Buffers/DDrOf.cs | 2 +- src/Nncase.Evaluator/TypeInferenceVisitor.cs | 4 +-- src/Nncase.Passes/DDrBufferSchdeulePass.cs | 4 +-- .../Core/IR/UnitTestTensorConst.cs | 2 +- src/Nncase.Tests/Core/UnitTestDataType.cs | 2 +- src/Nncase.Tests/Core/UnitTestDataTypes.cs | 7 ++-- src/Nncase.Tests/Core/UnitTestTensor.cs | 2 +- .../Targets/UnitTestCPUTargetTiling.cs | 10 ++++-- 17 files changed, 72 insertions(+), 32 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 64658b1b06..b71479da32 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -147,6 +147,33 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) return symbol; } + /// + protected override CSymbol VisitMemSpan(MemSpan expr) + { + if (_exprMemo.TryGetValue(expr, out var symbol)) + { + return symbol; + } + + var start = Visit(expr.Start); + var size = Visit(expr.Size); + string name = start.Name; + if (expr.Start is (TensorConst or Call)) + { + var loc = expr.Location switch + { + MemoryLocation.Rdata => "rdata", + MemoryLocation.Data => "data", + _ => throw new NotSupportedException(), + }; + name = $"({loc} + {start.Name})"; + } + + symbol = new(start.Type, name); + _exprMemo.Add(expr, symbol); + return symbol; + } + /// protected override CSymbol VisitCall(Call expr) { @@ -223,9 +250,14 @@ protected override CSymbol VisitConst(Const expr) type = ptype.ToC(); } + else if (expr is TensorConst { Value: Tensor { ElementType: PointerType { ElemType: PrimType etype }, Shape: { IsScalar: true } } pointer }) + { + str = pointer.ToScalar().ToString(); + type = "uint8_t *"; + } else { - throw new NotSupportedException($"Not Support {expr.CheckedType} Const"); + throw new NotSupportedException(); } symbol = new(type, str); diff --git a/src/Nncase.Core/DataType.cs b/src/Nncase.Core/DataType.cs index 04c3170301..4b86755d22 100644 --- a/src/Nncase.Core/DataType.cs +++ b/src/Nncase.Core/DataType.cs @@ -42,7 +42,8 @@ public enum PrimTypeAttributes /// /// /// the type the pointer points to. -public sealed record PointerType(DataType ElemType) : DataType +/// the shape of the pointer points to. +public sealed record PointerType(DataType ElemType, IR.Shape Shape) : DataType { /// public override Type CLRType { get; } = typeof(Pointer<>).MakeGenericType(ElemType.CLRType); @@ -81,7 +82,7 @@ public static DataType FromType(Type t) { if (t.GetGenericTypeDefinition() == typeof(Pointer<>)) { - return new PointerType(FromType(t.GenericTypeArguments[0])); + return new PointerType(FromType(t.GenericTypeArguments[0]), IR.Shape.Scalar); } throw new ArgumentException("Unsupported CLR type."); diff --git a/src/Nncase.Core/DataTypes.cs b/src/Nncase.Core/DataTypes.cs index d1a24255e7..3c8b65fd3d 100644 --- a/src/Nncase.Core/DataTypes.cs +++ b/src/Nncase.Core/DataTypes.cs @@ -114,7 +114,7 @@ public static bool IsPointer(this DataType srcType) => /// datatype name. public static string GetDisplayName(this DataType dataType) => dataType switch { - PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)}*)", + PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)}{(pointerType.Shape.IsScalar ? string.Empty : pointerType.Shape.ToString())} *)", PrimType primType => primType.ShortName, ValueType => dataType.ToString(), _ => throw new ArgumentOutOfRangeException(dataType.GetType().Name), @@ -128,7 +128,7 @@ public static bool IsPointer(this DataType srcType) => public static string GetCSharpName(this DataType dataType) => dataType switch { PrimType primType => $"DataTypes.{primType.FullName}", - PointerType pointerType => $"new PointerType({pointerType.ElemType.GetCSharpName()})", + PointerType pointerType => $"new PointerType({pointerType.ElemType.GetCSharpName()}, IR.Shape.Scalar)", ValueType valueType => $"new {valueType.GetType().Name}()", _ => throw new ArgumentOutOfRangeException(dataType.GetType().Name), }; diff --git a/src/Nncase.Core/IR/Buffers/DDrOf.cs b/src/Nncase.Core/IR/Buffers/DDrOf.cs index 8657d3f417..116e019d5f 100644 --- a/src/Nncase.Core/IR/Buffers/DDrOf.cs +++ b/src/Nncase.Core/IR/Buffers/DDrOf.cs @@ -17,4 +17,7 @@ public sealed partial class DDrOf : Op /// Get the input parameter. /// public static readonly ParameterInfo Input = new(typeof(DDrOf), 0, "input", IsTensor()); + + /// + public override bool CanFoldConstCall => false; } diff --git a/src/Nncase.Core/IR/IRType.cs b/src/Nncase.Core/IR/IRType.cs index b8aeda469f..ceb76319ad 100644 --- a/src/Nncase.Core/IR/IRType.cs +++ b/src/Nncase.Core/IR/IRType.cs @@ -138,7 +138,7 @@ public sealed record TensorType(DataType DType, Shape Shape) : IRType /// /// the Pointed Element Type. /// the pointer tensor type. - public static TensorType Pointer(DataType elemType) => new(new PointerType(elemType), Shape.Scalar); + public static TensorType Pointer(DataType elemType) => new(new PointerType(elemType, Shape.Scalar), Shape.Scalar); } /// diff --git a/src/Nncase.Core/IR/Var.cs b/src/Nncase.Core/IR/Var.cs index af1987e16b..a37d90eb80 100644 --- a/src/Nncase.Core/IR/Var.cs +++ b/src/Nncase.Core/IR/Var.cs @@ -92,7 +92,7 @@ public Var() /// get handle var. /// /// var. - public static Var Handle(string name, DataType dtype, string scope = "") => new Var(name, TensorType.Scalar(new PointerType(dtype))); + public static Var Handle(string name, DataType dtype, string scope = "") => new Var(name, TensorType.Scalar(new PointerType(dtype, Shape.Scalar))); /// /// get the size var. it can be used in tensor shape. like n>=0, m>=0. diff --git a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs index 28a619fdbb..1c15637eb6 100644 --- a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs +++ b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs @@ -35,13 +35,13 @@ protected override Expr RewriteLeafCall(Call expr) { var indices = (IR.Tuple)expr[IR.Buffers.BufferLoad.Indices]; var input = (TIR.Buffer)expr[IR.Buffers.BufferLoad.Input]; - return T.Load(input.MemSpan.Start, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i]))); + return T.Load(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i]))); } else if (expr.Target is IR.Buffers.BufferStore) { var indices = (IR.Tuple)expr[IR.Buffers.BufferStore.Indices]; var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input]; - return T.Store(input.MemSpan.Start, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); + return T.Store(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); } else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: (Const or Var) } }) { diff --git a/src/Nncase.Core/Tensor.cs b/src/Nncase.Core/Tensor.cs index 6747cddee8..a4875bf7c5 100644 --- a/src/Nncase.Core/Tensor.cs +++ b/src/Nncase.Core/Tensor.cs @@ -349,11 +349,11 @@ public static Tensor> FromPointer(ulong value) /// Create tensor from a ulong address. /// /// addr value. - /// Element type. + /// points type. /// Created tensor. - public static Tensor FromPointer(ulong value, DataType elemType) + public static Tensor FromPointer(ulong value, PointerType pointerType) { - return Tensor.FromBytes(TensorType.Scalar(new PointerType(elemType)), BitConverter.GetBytes(value)); + return Tensor.FromBytes(TensorType.Scalar(pointerType), BitConverter.GetBytes(value)); } /// diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs index aa5d9a4126..b6bc422147 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs @@ -295,18 +295,17 @@ protected override IPrintSymbol VisitTensorConst(TensorConst @const) { doc = new(new($"{@const}")); } - else - if (@const.Value.ElementType.IsFloat()) + else if (@const.Value.ElementType.IsFloat()) { - doc = new(new($"{string.Join(",", @const.Value.ToArray())}")); + doc = new(new(@const.Value.Length > 8 ? @const.CheckedShape.ToString() : $"{string.Join(",", @const.Value.ToArray())}")); } else if (@const.Value.ElementType.IsIntegral()) { - doc = new(new($"{string.Join(",", @const.Value.ToArray())}")); + doc = new(new(@const.Value.Length > 8 ? @const.CheckedShape.ToString() : $"{string.Join(",", @const.Value.ToArray())}")); } - else if (@const.Value.ElementType.IsPointer()) + else if (@const.Value.ElementType is PointerType p) { - doc = new(new($"{string.Join(",", @const.Value.ToArray().Select(i => "0x" + i.ToString("X")))}")); + doc = new(new($"*{p.ElemType.GetDisplayName()}@{@const.Value.Shape}")); } _exprMemo.Add(@const, doc!); diff --git a/src/Nncase.Evaluator/Buffers/DDrOf.cs b/src/Nncase.Evaluator/Buffers/DDrOf.cs index eb6de07acf..86c53e04b7 100644 --- a/src/Nncase.Evaluator/Buffers/DDrOf.cs +++ b/src/Nncase.Evaluator/Buffers/DDrOf.cs @@ -14,6 +14,6 @@ public partial class DDrOfEvaluator : ITypeInferencer { private IRType Visit(TensorType input) { - return new PointerType(input.DType); + return TensorType.Pointer(input.DType); } } diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index 365528b6d5..01ff4c48a6 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -71,7 +71,7 @@ protected override IRType VisitLeafBufferRegion(BufferRegion expr) /// protected override IRType VisitLeafBuffer(Nncase.TIR.Buffer expr) { - VerifySubField(expr, expr.MemSpan, TypePatternUtility.IsTuple()); + VerifySubField(expr, expr.MemSpan, TypePatternUtility.IsPointer()); foreach (var r in expr.Dimensions) { VerifySubField(expr, r, TypePatternUtility.IsIntegralScalar()); @@ -283,7 +283,7 @@ protected override IRType VisitLeafMemSpan(MemSpan expr) { VerifySubField(expr, expr.Start, TypePatternUtility.IsNoneType() | TypePatternUtility.IsIntegralScalar() | TypePatternUtility.IsPointer()); VerifySubField(expr, expr.Size, TypePatternUtility.IsIntegralScalar()); - return TupleType.Void; + return expr.Start.CheckedType; } /// diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index e21f833763..bf79e37d57 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -134,7 +134,7 @@ public DDrBufferRewriter(Dictionary> mod protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) { - if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const @const) + if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const { ValueType: TensorType constType } @const) { if (!ModuleRdataMaps.TryGetValue(Entry.ModuleKind, out var moduleRdataMap)) { @@ -163,7 +163,7 @@ protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) Changed = true; } - return memSpan.With(memRange.Start.Value, memRange.End.Value - memRange.Start.Value); + return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Start.Value, new PointerType(constType.DType, constType.Shape))), memRange.End.Value - memRange.Start.Value); } // else if (memSpan.Location is MemoryLocation.Data) diff --git a/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs b/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs index c444a0e45b..d98bf5ed13 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs +++ b/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs @@ -179,7 +179,7 @@ public void TestTensorType() var actual2 = TensorType.Invalid(DataTypes.Float32); Assert.Equal(expect2, actual2); - var expect3 = new TensorType(new PointerType(DataTypes.Float32), Shape.Scalar); + var expect3 = new TensorType(new PointerType(DataTypes.Float32, Shape.Scalar), Shape.Scalar); var actual3 = TensorType.Pointer(DataTypes.Float32); Assert.Equal(expect3, actual3); } diff --git a/src/Nncase.Tests/Core/UnitTestDataType.cs b/src/Nncase.Tests/Core/UnitTestDataType.cs index bb734e2671..410b6a13d3 100644 --- a/src/Nncase.Tests/Core/UnitTestDataType.cs +++ b/src/Nncase.Tests/Core/UnitTestDataType.cs @@ -15,7 +15,7 @@ public sealed class UnitTestDataType [Fact] public void TestPointerType() { - var pType = new PointerType(DataTypes.Float32); + var pType = new PointerType(DataTypes.Float32, IR.Shape.Scalar); Assert.Equal(8, pType.SizeInBytes); var t = DataType.FromType(typeof(Pointer)); diff --git a/src/Nncase.Tests/Core/UnitTestDataTypes.cs b/src/Nncase.Tests/Core/UnitTestDataTypes.cs index 322b989921..31a9249435 100644 --- a/src/Nncase.Tests/Core/UnitTestDataTypes.cs +++ b/src/Nncase.Tests/Core/UnitTestDataTypes.cs @@ -53,7 +53,7 @@ public void TestIsFloat() public void TestIsPointer() { Assert.False(DataTypes.IsPointer(DataTypes.Float32)); - Assert.True(DataTypes.IsPointer(new PointerType(DataTypes.Float32))); + Assert.True(DataTypes.IsPointer(new PointerType(DataTypes.Float32, IR.Shape.Scalar))); } [Fact] @@ -61,7 +61,7 @@ public void TestGetDisplayName() { var a = new QuantParamType(); Assert.Equal(a.ToString(), DataTypes.GetDisplayName(a)); - Assert.Equal("(f32*)", DataTypes.GetDisplayName(new PointerType(DataTypes.Float32))); + Assert.Equal("(f32*)", DataTypes.GetDisplayName(new PointerType(DataTypes.Float32, IR.Shape.Scalar))); Assert.Equal(DataTypes.Boolean.ShortName, DataTypes.GetDisplayName(DataTypes.Boolean)); Assert.Equal(DataTypes.Utf8Char.ShortName, DataTypes.GetDisplayName(DataTypes.Utf8Char)); Assert.Equal(DataTypes.Int8.ShortName, DataTypes.GetDisplayName(DataTypes.Int8)); @@ -82,7 +82,7 @@ public void TestGetDisplayName() public void TestCSharpName() { Assert.Equal("new QuantParamType()", DataTypes.GetCSharpName(new QuantParamType())); - Assert.Equal("new PointerType(DataTypes.Float32)", DataTypes.GetCSharpName(new PointerType(DataTypes.Float32))); + Assert.Equal("new PointerType(DataTypes.Float32, IR.Shape.Scalar)", DataTypes.GetCSharpName(new PointerType(DataTypes.Float32, IR.Shape.Scalar))); Assert.Equal("DataTypes.Boolean", DataTypes.GetCSharpName(DataTypes.Boolean)); Assert.Equal("DataTypes.Utf8Char", DataTypes.GetCSharpName(DataTypes.Utf8Char)); Assert.Equal("DataTypes.Int8", DataTypes.GetCSharpName(DataTypes.Int8)); @@ -103,7 +103,6 @@ public void TestCSharpName() public void TestBuiltInName() { Assert.Equal("QuantParam", DataTypes.GetBuiltInName(new QuantParamType())); - Assert.Throws(() => DataTypes.GetBuiltInName(new PointerType(DataTypes.Float32))); Assert.Equal("bool", DataTypes.GetBuiltInName(DataTypes.Boolean)); Assert.Equal("Utf8Char", DataTypes.GetBuiltInName(DataTypes.Utf8Char)); Assert.Equal("sbyte", DataTypes.GetBuiltInName(DataTypes.Int8)); diff --git a/src/Nncase.Tests/Core/UnitTestTensor.cs b/src/Nncase.Tests/Core/UnitTestTensor.cs index 28a4c31c6e..259de59424 100644 --- a/src/Nncase.Tests/Core/UnitTestTensor.cs +++ b/src/Nncase.Tests/Core/UnitTestTensor.cs @@ -111,7 +111,7 @@ public unsafe void TestFromPointerET() var p1 = new Pointer(addr1); var p2 = new Pointer(addr2); - var t = Tensor.FromPointer(addr1, DataTypes.Int32); + var t = Tensor.FromPointer(addr1, new PointerType(DataTypes.Int32, Shape.Scalar)); Assert.Equal(p1, t.ToScalar>()); Assert.Equal(addr1, t.ToScalar>().Value); Assert.NotEqual(p2, t.ToScalar>()); diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs index f3a897f37d..db5c9c5567 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUTargetTiling.cs @@ -48,10 +48,13 @@ public async Task TestCpuUnary() compiler.Gencode(fs); } + var input_tensor = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor(); using (var fs = Dumpper.OpenFile("input_0.bin")) { - fs.Write(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor().BytesBuffer); + fs.Write(input_tensor.BytesBuffer); } + + Testing.RunKModel(File.ReadAllBytes(Path.Join(Dumpper.Directory, "test.kmodel")), Dumpper.Directory, new[] { input_tensor }); } [Fact] @@ -72,9 +75,12 @@ public async Task TestCpuMatMul() compiler.Gencode(fs); } + var input_tensor = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 3, 4 }).Evaluate().AsTensor(); using (var fs = Dumpper.OpenFile("input_0.bin")) { - fs.Write(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 2, new[] { 1, 2, 3, 4, 5 }).Evaluate().AsTensor().BytesBuffer); + fs.Write(input_tensor.BytesBuffer); } + + Testing.RunKModel(File.ReadAllBytes(Path.Join(Dumpper.Directory, "test.kmodel")), Dumpper.Directory, new[] { input_tensor }); } } From 1e6da3411ecefb251297b905cf360f76114aab97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 1 Aug 2023 11:05:46 +0800 Subject: [PATCH 050/308] refactor buffer --- .../CodeGen/CSourceConvertVisitor.cs | 6 +- .../CodeGen/CSourceExtensions.cs | 2 +- .../CodeGen/LinkableFunction.cs | 2 +- src/Nncase.Core/IR/Callable.cs | 2 +- .../Passes/Mutators/FlattenBuffer.cs | 2 +- .../BufferSchedule/BufferScheduleTypes.cs | 67 ++++++++++ .../BufferSchedule/BufferScheduler.cs | 114 ++++++++++++++++++ .../BufferSchedule/LifeTimeCollector.cs | 103 ++++++++++++++++ src/Nncase.Passes/DDrBufferSchdeulePass.cs | 50 ++------ 9 files changed, 304 insertions(+), 44 deletions(-) create mode 100644 src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs create mode 100644 src/Nncase.Passes/BufferSchedule/BufferScheduler.cs create mode 100644 src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index b71479da32..bc715c4f2f 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -156,9 +156,9 @@ protected override CSymbol VisitMemSpan(MemSpan expr) } var start = Visit(expr.Start); - var size = Visit(expr.Size); + _ = Visit(expr.Size); string name = start.Name; - if (expr.Start is (TensorConst or Call)) + if (expr.Start is TensorConst or Call) { var loc = expr.Location switch { @@ -250,7 +250,7 @@ protected override CSymbol VisitConst(Const expr) type = ptype.ToC(); } - else if (expr is TensorConst { Value: Tensor { ElementType: PointerType { ElemType: PrimType etype }, Shape: { IsScalar: true } } pointer }) + else if (expr is TensorConst { Value: Tensor { ElementType: PointerType { ElemType: PrimType }, Shape: { IsScalar: true } } pointer }) { str = pointer.ToScalar().ToString(); type = "uint8_t *"; diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs index c36037d603..43fda30de3 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceExtensions.cs @@ -29,7 +29,7 @@ public static string ToC(this PrimType primType) => public static string ToC(this DataType dataType) => dataType switch { PrimType ptype => ptype.ToC(), - PointerType ptype => "uint8_t *", + PointerType => "uint8_t *", _ => throw new NotSupportedException(dataType.ToString()), }; diff --git a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs index bab60300b6..5d3e3aa0fc 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/LinkableFunction.cs @@ -17,7 +17,7 @@ public LinkableFunction(uint id, byte[] descContents, TIR.PrimFunction sourceFun Text = Array.Empty(); // new LinkedSection(descContents, ".desc", 0, 8, (uint)descContents.Length) - Sections = new ILinkedSection[] { }; + Sections = Array.Empty(); } public uint Id { get; } diff --git a/src/Nncase.Core/IR/Callable.cs b/src/Nncase.Core/IR/Callable.cs index edd6004539..16fc9ad5a7 100644 --- a/src/Nncase.Core/IR/Callable.cs +++ b/src/Nncase.Core/IR/Callable.cs @@ -17,7 +17,7 @@ public abstract class Callable : Expr /// /// StackVM module kind. /// - public static readonly string StackVMModuleKind = "stackvm"; + public const string StackVMModuleKind = "stackvm"; public Callable(string name, string moduleKind, Expr[] operands) : base(operands) diff --git a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs index 1c15637eb6..b0b41be0b6 100644 --- a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs +++ b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs @@ -43,7 +43,7 @@ protected override Expr RewriteLeafCall(Call expr) var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input]; return T.Store(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); } - else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: (Const or Var) } }) + else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: Const or Var } }) { // remove the all fixed match operation. return T.Nop(); diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs new file mode 100644 index 0000000000..af10e24e63 --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -0,0 +1,67 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class TimeInterval +{ + public TimeInterval(int start, int end) + { + Start = start; + End = end; + } + + public int Start { get; set; } + + public int End { get; set; } + + public override string ToString() + { + return $"TimeInterval({Start}, {End})"; + } +} + +internal sealed class MemSpan +{ + public MemSpan(int start, int end) + { + Start = start; + End = end; + } + + public int Start { get; set; } + + public int End { get; set; } + + public override string ToString() + { + return $"MemSpan({Start}, {End})"; + } +} + +internal class ScheduleBuffer +{ + public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] shape, int[] strides) + { + Name = name; + Interval = interval; + Span = span; + Shape = shape; + Strides = strides; + } + + public string Name { get; } + + public TimeInterval Interval { get; } + + public MemSpan Span { get; } + + public int[] Shape { get; } + + public int[] Strides { get; } + + public override string ToString() + { + return $"ScheduledBuffer('{Name}', {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}])"; + } +} diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs new file mode 100644 index 0000000000..eb9853d22c --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -0,0 +1,114 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reactive; +using Nncase; +using Nncase.IR; + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class BufferScheduler +{ + public List CollectLifeTime(Function func) + { + var c = new LifeTimeCollector(); + return c.Collect(func); + } + + public void DumpScheduled(string path, List buffers) + { + using (var fs = File.OpenWrite(path)) + { + using (var wr = new StreamWriter(fs)) + { + wr.Write(@"from bokeh.models import ColumnDataSource, HoverTool, FuncTickFormatter, SingleIntervalTicker, SaveTool, WheelZoomTool, WheelPanTool, ResetTool +from bokeh.palettes import Category20_20 as palette +from bokeh.plotting import figure, show, save +import itertools +from dataclasses import dataclass +from enum import Enum +from typing import List +@dataclass +class Interval(): + start: int + end: int + +@dataclass +class Location(): + depth_start: int + depth_size: int + def __str__(self) -> str: + return f'(start: {self.depth_start}, size {self.depth_size})' + +class ConstraintsMode(Enum): + No = 0 + Channel = 1 + +@dataclass +class ScheduledBuffer(): + name: str + interval: Interval + location: Location + constraints: ConstraintsMode + shape: List[int] + stride: List[int] + +colors = itertools.cycle(palette) + +buffers = [ +"); + foreach (var item in buffers) + { + wr.WriteLine(item.ToString()); + } + + wr.Write(@"] + +source = { + 'name': [], + 'x': [], + 'y': [], + 'width': [], + 'height': [], + 'color': [], + 'location': [], + 'shape': [], + 'stride': [], +} + +y_range_max = 0 +for buffer in buffers: + source['name'].append(buffer.name) + width = buffer.interval.end - buffer.interval.start + x = buffer.interval.start + (width // 2) + height = buffer.location.depth_size + y = buffer.location.depth_start + (height // 2) + y_range_max = max(y_range_max, y) + source['x'].append(x) + source['y'].append(y) + source['width'].append(width) + source['height'].append(height) + source['color'].append(next(colors)) + source['location'].append(str(buffer.location)) + source['shape'].append(','.join([str(s) for s in buffer.shape])) + source['stride'].append(','.join([str(s) for s in buffer.stride])) + +source = ColumnDataSource(source) +hover = HoverTool(tooltips=[('name', '@name'), ('location', '@location'), + ('shape', '@shape'), ('stride', '@stride')]) + +p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720, + y_range=(0, y_range_max * 2), + title='Local Buffer LifeTime (by Steps)') +p.rect(x='x', y='y', width='width', height='height', fill_color='color', source=source) +p.xaxis.axis_label = 'Time (steps)' +p.outline_line_color = None + +show(p)"); + } + } + } +} diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs new file mode 100644 index 0000000000..281c13891e --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -0,0 +1,103 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; +using System.Linq; +using System.Reactive; +using Nncase; +using Nncase.IR; + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class LifeTimeCollector : ExprVisitor +{ + public int TimeStamp { get; private set; } + + public Dictionary LifenessMap { get; } = new(ReferenceEqualityComparer.Instance); + + public List Collect(Function entry) + { + Visit(entry.Body); + Alias(); + + var l = new List(); + foreach (var (k, v) in LifenessMap) + { + var name = k switch + { + Call c => c.Target.GetType().Name, + Var va => va.Name, + _ => k.GetType().Name, + }; + + var shape = k.CheckedShape.ToValueArray(); + var stride = TensorUtilities.GetStrides(shape); + var size = TensorUtilities.GetSize(shape, stride, k.CheckedDataType.SizeInBytes); + + l.Add(new(name, v, new(0, size), shape, stride)); + } + + return l; + } + + protected override Unit DefaultVisitLeaf(Expr expr) => Unit.Default; + + protected override Unit VisitLeafCall(Call expr) + { + foreach (var arg in expr.Arguments) + { + Update(arg); + } + + Update(expr); + + TimeStamp += 1; + + return Unit.Default; + } + + private void Update(Expr expr) + { + if (expr is Const) + { + return; + } + + if (expr is IR.Tuple t) + { + foreach (var item in t.Fields) + { + Update(item); + } + + return; + } + + if (!LifenessMap.TryGetValue(expr, out var interval)) + { + interval = new(TimeStamp, TimeStamp + 1); + } + else + { + interval.End += 1; + } + + // advance the getitem buffer. + if (expr is Call { Target: IR.Tensors.GetItem, Arguments: var args } call && args[0] is Call { CheckedType: TupleType }) + { + interval.Start = LifenessMap[args[0]].Start; + } + + LifenessMap[expr] = interval; + } + + private void Alias() + { + // skip the call which output type is tuple. + var calls = LifenessMap.Select(kv => kv.Key is Call { CheckedType: TupleType }).ToArray(); + foreach (var c in calls) + { + LifenessMap.Remove(c); + } + } +} diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index bf79e37d57..cdd9fe3be1 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Reactive; using System.Text; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -40,46 +41,14 @@ public DDrBufferSchdeulePass(bool enableMergeCall = false) protected override async Task RunCoreAsync(IRModule module, RunPassContext options) { // 1. merge the all call prim func -#if false if (_enbaleMergeCall) { - HashSet mergedFuncs = new(ReferenceEqualityComparer.Instance); - HashSet stackvmFuncs = new(ReferenceEqualityComparer.Instance); - for (int i = 0; i < module.Functions.Count; i++) - { - if (module.Functions[i] is Function { ModuleKind: "stackvm" } func) - { - var analysis = new Dictionary - { - [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(func), - }; - _ = new HashSet(ReferenceEqualityComparer.Instance); - var mergePass = new DataflowPass(); - mergePass.Add(mergedFuncs); - var post = await mergePass.RunAsync(func, new() { AnalysisResults = analysis, RewriteOnce = true }); - module.Replace(i, post); - stackvmFuncs.Add(post); - } - } - - // 2. add the ext func into module. - foreach (var func in stackvmFuncs) - { - var collector = new ExternalFuncCollector(); - collector.Visit(func); - foreach (var ext_func in collector.GetExternalFuncs()) - { - module.Add(ext_func); - } - } - - // 3. remove the all merged funcs - foreach (var item in mergedFuncs) - { - module.Remove(item); - } + // if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType)) + // { + // var sorter = new TopSorter(); + // sorter.GetTimeLine(func); + // } } -#endif // 4. schedule the prim funcs. for (int i = 0; i < module.Functions.Count; i++) @@ -106,6 +75,13 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon return await Task.FromResult(module); } + + private bool IsFixedType(IRType type) => type switch + { + TensorType tensorType => tensorType.Shape.IsFixed, + TupleType tupleType => tupleType.Fields.All(IsFixedType), + _ => false, + }; } internal sealed class DDrBufferRewriter : ExprRewriter From 85d565290b3fd1c4568dc2aac1a3716c9f886209 Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Tue, 1 Aug 2023 03:09:10 +0000 Subject: [PATCH 051/308] Apply code-format changes --- modules/cpu/src/runtime/cpu_common.h | 12 ++++++++---- modules/cpu/src/runtime/elfloader.h | 3 ++- modules/cpu/src/runtime/runtime_function.cpp | 5 +++-- modules/cpu/src/runtime/runtime_module.cpp | 5 +++-- .../BufferSchedule/BufferScheduleTypes.cs | 2 +- src/Nncase.Passes/BufferSchedule/BufferScheduler.cs | 2 +- .../BufferSchedule/LifeTimeCollector.cs | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index d51e1775aa..8221fa015f 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -97,8 +97,10 @@ inline int32_t int32_binary_mul(int32_t x, int32_t y) { return x * y; } inline int32_t int32_binary_div(int32_t x, int32_t y) { return x / y; } inline int32_t int32_binary_min(int32_t x, int32_t y) { return std::min(x, y); } inline int32_t int32_binary_max(int32_t x, int32_t y) { return std::max(x, y); } -#if defined (__arm64__) && defined (__APPLE__) -inline int32_t int32_binary_pow(int32_t x, int32_t y) { return (int32_t)pow(x, y); } +#if defined(__arm64__) && defined(__APPLE__) +inline int32_t int32_binary_pow(int32_t x, int32_t y) { + return (int32_t)pow(x, y); +} #else inline int32_t int32_binary_pow(int32_t x, int32_t y) { return std::pow(x, y); } #endif @@ -111,8 +113,10 @@ inline int64_t int64_binary_mul(int64_t x, int64_t y) { return x * y; } inline int64_t int64_binary_div(int64_t x, int64_t y) { return x / y; } inline int64_t int64_binary_min(int64_t x, int64_t y) { return std::min(x, y); } inline int64_t int64_binary_max(int64_t x, int64_t y) { return std::max(x, y); } -#if defined (__arm64__) && defined (__APPLE__) -inline int64_t int64_binary_pow(int64_t x, int64_t y) { return (int64_t)pow(x, y); } +#if defined(__arm64__) && defined(__APPLE__) +inline int64_t int64_binary_pow(int64_t x, int64_t y) { + return (int64_t)pow(x, y); +} #else inline int64_t int64_binary_pow(int64_t x, int64_t y) { return std::pow(x, y); } #endif diff --git a/modules/cpu/src/runtime/elfloader.h b/modules/cpu/src/runtime/elfloader.h index 5fc7ceaec3..0f9d73e5f7 100644 --- a/modules/cpu/src/runtime/elfloader.h +++ b/modules/cpu/src/runtime/elfloader.h @@ -14,7 +14,8 @@ BEGIN_NS_NNCASE_RT_MODULE(cpu) typedef void (*entrypoint_t)(size_t id, uint8_t **buffers, - nncase_mt_t *nncase_mt, void *data, const uint8_t *rdata); + nncase_mt_t *nncase_mt, void *data, + const uint8_t *rdata); class elfloader { public: diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 776da74410..2b4d7e3a8f 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -48,7 +48,8 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { - // try_(context.read_section(".desc", [this](auto sr, size_t) -> result { + // try_(context.read_section(".desc", [this](auto sr, size_t) -> + // result { // auto header = sr.template read(); // if (parameters_size() != header.inputs + header.outputs) // return nncase::err(std::errc::invalid_argument); @@ -112,7 +113,7 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, NNCASE_UNUSED value_t return_value) noexcept { try_var(id, module().find_id_by_function(this)); - uint8_t **buffers = new uint8_t*[parameters.size()]; + uint8_t **buffers = new uint8_t *[parameters.size()]; // input buffer for (size_t i = 0; i < parameters.size(); i++) { try_var(input_tensor, parameters[i].as()); diff --git a/modules/cpu/src/runtime/runtime_module.cpp b/modules/cpu/src/runtime/runtime_module.cpp index 6e44d1fd3b..989ad42b8e 100644 --- a/modules/cpu/src/runtime/runtime_module.cpp +++ b/modules/cpu/src/runtime/runtime_module.cpp @@ -27,8 +27,9 @@ result cpu_runtime_module::initialize_before_functions( // if (!context.is_section_pinned()) // return nncase::err(std::errc::bad_address); // try_var(data, context.get_or_read_section(".data", data_storage_, - // false)); - try_set(rdata_, context.get_or_read_section(".rdata", rdata_storage_, true)); + // false)); + try_set(rdata_, + context.get_or_read_section(".rdata", rdata_storage_, true)); try_set(text_, context.get_or_read_section(".text", text_storage_, true)); return ok(); diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index af10e24e63..d68c5ca5e8 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. namespace Nncase.Passes.BufferSchedule; diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index eb9853d22c..496d45af2c 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System.Collections.Generic; diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index 281c13891e..454fd25608 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -1,4 +1,4 @@ -// Copyright (c) Canaan Inc. All rights reserved. +// Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System.Collections.Generic; From 66865183b4d0744355fb1458ae27bd3fe3e2696d Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 13:41:12 +0800 Subject: [PATCH 052/308] fix mac build --- modules/cpu/src/runtime/cpu_common.h | 4 +- tests/config.toml | 18 +- tests/generator.py | 4 +- tests/importer/onnx_/basic/test_binary.py | 12 +- .../onnx_/basic/test_binary_from_onnx.py | 100 ++++---- .../onnx_/basic/test_conv_transpose.py | 214 +++++++++++++----- tests/importer/onnx_/basic/test_unary.py | 28 +-- .../tflite_/basic/test_fully_connected.py | 19 +- tests/test_runner.py | 2 +- 9 files changed, 257 insertions(+), 144 deletions(-) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 8221fa015f..8fdca3c3aa 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -97,7 +97,7 @@ inline int32_t int32_binary_mul(int32_t x, int32_t y) { return x * y; } inline int32_t int32_binary_div(int32_t x, int32_t y) { return x / y; } inline int32_t int32_binary_min(int32_t x, int32_t y) { return std::min(x, y); } inline int32_t int32_binary_max(int32_t x, int32_t y) { return std::max(x, y); } -#if defined(__arm64__) && defined(__APPLE__) +#if defined(__APPLE__) inline int32_t int32_binary_pow(int32_t x, int32_t y) { return (int32_t)pow(x, y); } @@ -113,7 +113,7 @@ inline int64_t int64_binary_mul(int64_t x, int64_t y) { return x * y; } inline int64_t int64_binary_div(int64_t x, int64_t y) { return x / y; } inline int64_t int64_binary_min(int64_t x, int64_t y) { return std::min(x, y); } inline int64_t int64_binary_max(int64_t x, int64_t y) { return std::max(x, y); } -#if defined(__arm64__) && defined(__APPLE__) +#if defined(__APPLE__) inline int64_t int64_binary_pow(int64_t x, int64_t y) { return (int64_t)pow(x, y); } diff --git a/tests/config.toml b/tests/config.toml index 7d27973767..4a13208015 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -16,7 +16,7 @@ output_layout = 'NHWC' model_layout = 'NHWC' letterbox_value = 0 dump_asm = true -dump_ir = false +dump_ir = true [ptq_opt] use_mix_quant = false @@ -59,7 +59,7 @@ args = [] [generator.calibs] method = 'random' -number = 5 +number = 1 batch = 1 [generator.calibs.random] @@ -80,21 +80,21 @@ args = [] [target] [target.cpu] -eval = true +eval = false infer = true simarity_name = 'cosine' [target.cpu.mode.noptq] -enabled = false +enabled = true threshold = 0.999 [target.cpu.mode.ptq] -enabled = true +enabled = false threshold = 0.98 [target.k510] -eval = true -infer = true +eval = false +infer = false simarity_name = 'cosine' [target.k510.mode.noptq] @@ -106,8 +106,8 @@ enabled = true threshold = 0.98 [target.k230] -eval = true -infer = true +eval = false +infer = false simarity_name = 'cosine' [target.k230.mode.noptq] diff --git a/tests/generator.py b/tests/generator.py index 631883c757..74408b9e3d 100644 --- a/tests/generator.py +++ b/tests/generator.py @@ -14,9 +14,9 @@ def from_random(self, shape: List[int], dtype: np.dtype, abs: bool = False) -> n elif dtype == np.bool: data = np.random.rand(*shape) > 0.5 elif dtype == np.int32: - data = np.random.randint(1, 5, size=shape, dtype='int32') + data = np.random.randint(1, 2, size=shape, dtype='int32') elif dtype == np.int64: - data = np.random.randint(1, 5, size=shape, dtype='int64') + data = np.random.randint(1, 2, size=shape, dtype='int64') # data = np.random.randint(1, 128, size=shape, dtype='int64') else: data = np.random.rand(*shape) diff --git a/tests/importer/onnx_/basic/test_binary.py b/tests/importer/onnx_/basic/test_binary.py index c629691f04..abb928157e 100644 --- a/tests/importer/onnx_/basic/test_binary.py +++ b/tests/importer/onnx_/basic/test_binary.py @@ -23,13 +23,13 @@ def _make_module(v_shape): class BinaryModule(torch.nn.Module): def __init__(self): super(BinaryModule, self).__init__() - # self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32)) - self.v = torch.from_numpy(np.ones(v_shape).astype(np.float32)) + self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32)) + # self.v = torch.from_numpy(np.ones(v_shape).astype(np.float32)) def forward(self, x): outs = [] - outs.append(torch.add(x, self.v)) - # outs.append(torch.mul(x, self.v)) + # outs.append(torch.add(x, self.v)) + outs.append(torch.mul(x, self.v)) # outs.append(torch.sub(x, self.v)) # outs.append(torch.max(x, self.v)) # outs.append(torch.div(x, self.v)) @@ -45,7 +45,7 @@ def forward(self, x): # [64, 3], # [3, 64, 3], # [8, 3, 64, 3] - [1, 3, 24, 24] + [1, 560, 196] ] rhs_shapes = [ @@ -61,7 +61,7 @@ def forward(self, x): # [8, 3, 1, 3], # [8, 1, 64, 3], # [1, 3, 64, 1] - [1, 3, 24, 24] + [1, 560, 1] ] diff --git a/tests/importer/onnx_/basic/test_binary_from_onnx.py b/tests/importer/onnx_/basic/test_binary_from_onnx.py index f95bf61a80..0a3ba13c62 100644 --- a/tests/importer/onnx_/basic/test_binary_from_onnx.py +++ b/tests/importer/onnx_/basic/test_binary_from_onnx.py @@ -31,27 +31,29 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): # input1 input1 = helper.make_tensor_value_info('input1', in_type, in_shape_0) inputs.append('input1') + input2 = helper.make_tensor_value_info('input2', in_type, in_shape_1) + inputs.append('input2') # set input2 to avoid SIGFPE for div op. - if op != 'Pow': - tensor = helper.make_tensor( - 'input2', - in_type, - dims=in_shape_1, - vals=(np.random.rand(*in_shape_1) + - 2).astype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[in_type]).flatten().tolist() - ) - inputs.append('input2') - initializers.append(tensor) - else: - tensor = helper.make_tensor( - 'input2', - TensorProto.INT32, - dims=[1], - vals=[2] - ) - inputs.append('input2') - initializers.append(tensor) + # if op != 'Pow': + # tensor = helper.make_tensor( + # 'input2', + # in_type, + # dims=in_shape_1, + # vals=(np.random.rand(*in_shape_1) + + # 2).astype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[in_type]).flatten().tolist() + # ) + # inputs.append('input2') + # initializers.append(tensor) + # else: + # tensor = helper.make_tensor( + # 'input2', + # TensorProto.INT32, + # dims=[1], + # vals=[2] + # ) + # inputs.append('input2') + # initializers.append(tensor) # output x = np.random.randn(*in_shape_0) @@ -71,7 +73,7 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): graph_def = helper.make_graph( nodes, 'test-model', - [input1], + [input1, input2], [output], initializer=initializers) @@ -84,9 +86,9 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): 'Sub', 'Mul', 'Div', - 'Min', - 'Max', - 'Pow', + # 'Min', + # 'Max', + # 'Pow', ] in_types = [ @@ -94,45 +96,45 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): ] in_shapes = [ - [[1, 3, 16, 16], [1]], - [[1, 3, 16, 16], [16]], - [[1, 3, 16, 16], [1, 16]], - [[1, 3, 16, 16], [16, 16]], - [[1, 3, 16, 16], [1, 16, 16]], - [[1, 3, 16, 16], [3, 16, 16]], + # [[1, 3, 16, 16], [1]], + # [[1, 3, 16, 16], [16]], + # [[1, 3, 16, 16], [1, 16]], + # [[1, 3, 16, 16], [16, 16]], + # [[1, 3, 16, 16], [1, 16, 16]], + # [[1, 3, 16, 16], [3, 16, 16]], [[1, 3, 16, 16], [1, 3, 16, 16]], - [[3, 16, 16], [1]], - [[3, 16, 16], [16]], - [[3, 16, 16], [1, 16]], - [[3, 16, 16], [16, 16]], + # [[3, 16, 16], [1]], + # [[3, 16, 16], [16]], + # [[3, 16, 16], [1, 16]], + # [[3, 16, 16], [16, 16]], [[3, 16, 16], [1, 16, 16]], [[3, 16, 16], [3, 16, 16]], - [[3, 16, 16], [1, 3, 16, 16]], + # [[3, 16, 16], [1, 3, 16, 16]], - [[16, 16], [1]], - [[16, 16], [16]], + # [[16, 16], [1]], + # [[16, 16], [16]], [[16, 16], [1, 16]], [[16, 16], [16, 16]], - [[16, 16], [1, 16, 16]], - [[16, 16], [3, 16, 16]], - [[16, 16], [1, 3, 16, 16]], + # [[16, 16], [1, 16, 16]], + # [[16, 16], [3, 16, 16]], + # [[16, 16], [1, 3, 16, 16]], [[1], [1]], [[1], [16]], - [[1], [1, 16]], - [[1], [16, 16]], - [[1], [1, 16, 16]], - [[1], [3, 16, 16]], - [[1], [1, 3, 16, 16]], + # [[1], [1, 16]], + # [[1], [16, 16]], + # [[1], [1, 16, 16]], + # [[1], [3, 16, 16]], + # [[1], [1, 3, 16, 16]], [[16], [1]], [[16], [16]], - [[16], [1, 16]], - [[16], [16, 16]], - [[16], [1, 16, 16]], - [[16], [3, 16, 16]], - [[16], [1, 3, 16, 16]] + # [[16], [1, 16]], + # [[16], [16, 16]], + # [[16], [1, 16, 16]], + # [[16], [3, 16, 16]], + # [[16], [1, 3, 16, 16]] ] diff --git a/tests/importer/onnx_/basic/test_conv_transpose.py b/tests/importer/onnx_/basic/test_conv_transpose.py index c2cfc3c8c3..a675680e2b 100644 --- a/tests/importer/onnx_/basic/test_conv_transpose.py +++ b/tests/importer/onnx_/basic/test_conv_transpose.py @@ -16,88 +16,194 @@ import math import pytest import onnx -import torch from onnx import helper from onnx import AttributeProto, TensorProto, GraphProto from onnx_test_runner import OnnxTestRunner import numpy as np -def _make_module(in_channel, out_channel, kernel_size, stride, dilation, pad, group, bias): - class ConvTransposeModule(torch.nn.Module): - def __init__(self): - super(ConvTransposeModule, self).__init__() - self.conv_transpose = torch.nn.ConvTranspose2d( - in_channel, out_channel, kernel_size, stride, pad, [0, 0], group, bias, dilation) - - def forward(self, x): - return self.conv_transpose(x) - - return ConvTransposeModule() - - -in_sizes = [ - [16, 16], - [33, 65], +def _make_module(in_shape, kernel_output_channel, bias_shape, auto_pad_mode, dilation, group, kernel_shape, output_padding, pad, stride): + inputs = [] + initializers = [] + + # input + input = helper.make_tensor_value_info('input', TensorProto.FLOAT, in_shape) + inputs.append('input') + + group = 1 if group is None else group + + # weight + w_shape = [] + w_shape.append(in_shape[1]) + w_shape.append(kernel_output_channel // group) + w_shape.extend(kernel_shape) + weight = helper.make_tensor( + 'weight', + TensorProto.FLOAT, + dims=w_shape, + vals=np.random.rand(*w_shape).astype(np.float32).flatten().tolist() + ) + inputs.append('weight') + initializers.append(weight) + + # bias + if bias_shape is not None: + bias = helper.make_tensor( + 'bias', + TensorProto.FLOAT, + dims=bias_shape, + vals=np.random.rand(*bias_shape).astype(np.float32).flatten().tolist() + ) + inputs.append('bias') + initializers.append(bias) + + # dilation + d = [1, 1] if dilation is None else dilation + + # output_padding + out_padding = [0, 0] if output_padding is None else output_padding + + # stride + s = [1, 1] if stride is None else stride + + # output + out_shape = [] + out_shape.append(in_shape[0]) + out_shape.append(w_shape[1] * group) + + # pad + padding = [0, 0, 0, 0] + if auto_pad_mode in [None, 'NOTSET'] and pad is not None: + padding = pad + out_shape.append(s[0] * (in_shape[2] - 1) + out_padding[0] + + (w_shape[2] - 1) * d[0] + 1 - padding[0] - padding[2]) + out_shape.append(s[1] * (in_shape[3] - 1) + out_padding[1] + + (w_shape[3] - 1) * d[1] + 1 - padding[1] - padding[3]) + elif auto_pad_mode in ['SAME_UPPER', 'SAME_LOWER']: + out_shape.append(in_shape[2] * s[0]) + out_shape.append(in_shape[3] * s[1]) + else: + out_shape.append(in_shape[2] + (in_shape[2] - 1) * (s[0] - 1) - w_shape[2] + 1) + out_shape.append(in_shape[3] + (in_shape[3] - 1) * (s[1] - 1) - w_shape[3] + 1) + + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, out_shape) + + attributes_dict = {} + + if auto_pad_mode is not None: + attributes_dict['auto_pad'] = auto_pad_mode + + if dilation is not None: + attributes_dict['dilations'] = dilation + + if group is not None: + attributes_dict['group'] = group + + if kernel_shape is not None: + attributes_dict['kernel_shape'] = kernel_shape + + if output_padding is not None: + attributes_dict['output_padding'] = output_padding + + if pad is not None: + attributes_dict['pads'] = padding + + if stride is not None: + attributes_dict['strides'] = stride + + node = onnx.helper.make_node( + 'ConvTranspose', + inputs=inputs, + outputs=['output'], + **attributes_dict + ) + + nodes = [] + nodes.append(node) + + graph_def = helper.make_graph( + nodes, + 'test-model', + [input], + [output], + initializer=initializers) + + model_def = helper.make_model(graph_def, producer_name='kendryte') + + return model_def + + +in_shapes = [ + [1, 3, 10, 10] ] -in_channels = [ - 1, - 3, - 16 +kernel_output_channels = [ + 3 ] -out_channels = [ - 1, - 16 +bias_shapes = [ + None, ] - -kernel_sizes = [ - [1, 1], - [3, 3], +bias_shapes.extend(list([[x] for x in kernel_output_channels])) + +auto_pad_modes = [ + None, + # 'NOTSET', + # 'SAME_UPPER', + # 'SAME_LOWER', + # 'VALID' ] -strides = [ - 1, - [2, 2] +dilations = [ + None, ] -dilations = [ - 1 +groups = [ + None, + # 3 ] -pads = [ - 0, +kernel_shapes = [ [1, 1], ] -groups = [ - 1 +output_paddings = [ + None, + # [1, 1] +] + +pads = [ + # None, + [0, 0, 1, 1], ] -biases = [ - True, - False +strides = [ + None, + # [2, 3], + # [3, 2], + # [3, 3] ] -@pytest.mark.parametrize('in_size', in_sizes) -@pytest.mark.parametrize('in_channel', in_channels) -@pytest.mark.parametrize('out_channel', out_channels) -@pytest.mark.parametrize('kernel_size', kernel_sizes) -@pytest.mark.parametrize('stride', strides) +@pytest.mark.parametrize('in_shape', in_shapes) +@pytest.mark.parametrize('kernel_output_channel', kernel_output_channels) +@pytest.mark.parametrize('bias_shape', bias_shapes) +@pytest.mark.parametrize('auto_pad_mode', auto_pad_modes) @pytest.mark.parametrize('dilation', dilations) -@pytest.mark.parametrize('pad', pads) @pytest.mark.parametrize('group', groups) -@pytest.mark.parametrize('bias', biases) -def test_conv_transpose(in_size, in_channel, out_channel, kernel_size, stride, dilation, pad, group, bias, request): - model_file = _make_module(in_channel, out_channel, kernel_size, - stride, dilation, pad, group, bias) +@pytest.mark.parametrize('kernel_shape', kernel_shapes) +@pytest.mark.parametrize('output_padding', output_paddings) +@pytest.mark.parametrize('pad', pads) +@pytest.mark.parametrize('stride', strides) +def test_conv_transpose(in_shape, kernel_output_channel, bias_shape, auto_pad_mode, dilation, group, kernel_shape, output_padding, pad, stride, request): + if (bias_shape is None or (bias_shape is not None and bias_shape[0] == kernel_output_channel)) and ((auto_pad_mode in [None, 'NOTSET'] and pad is not None) or (auto_pad_mode in ['SAME_UPPER', 'SAME_LOWER', 'VALID'] and pad is None)) and (dilation is None or (auto_pad_modes in [None, 'NOTSET'])) and ((output_padding is None) or (output_padding is not None and stride is not None)): + model_def = _make_module(in_shape, kernel_output_channel, bias_shape, + auto_pad_mode, dilation, group, kernel_shape, output_padding, pad, stride) - runner = OnnxTestRunner(request.node.name) - model_file = runner.from_torch(model_file, [1, in_channel, *in_size]) - runner.run(model_file) + runner = OnnxTestRunner(request.node.name) + model_file = runner.from_onnx_helper(model_def) + runner.run(model_file) if __name__ == "__main__": - pytest.main(['-vv', __file__]) + pytest.main(['-vv', 'test_conv_transpose.py']) diff --git a/tests/importer/onnx_/basic/test_unary.py b/tests/importer/onnx_/basic/test_unary.py index 1137dd0a06..58353ff25a 100644 --- a/tests/importer/onnx_/basic/test_unary.py +++ b/tests/importer/onnx_/basic/test_unary.py @@ -26,19 +26,19 @@ def __init__(self): def forward(self, x): outs = [] - outs.append(torch.abs(-x)) - outs.append(torch.acos(x)) - outs.append(torch.asin(x)) - outs.append(torch.ceil(x)) - outs.append(torch.cos(x)) - outs.append(torch.exp(x)) + # outs.append(torch.abs(-x)) + # outs.append(torch.acos(x)) + # outs.append(torch.asin(x)) + # outs.append(torch.ceil(x)) + # outs.append(torch.cos(x)) + # outs.append(torch.exp(x)) outs.append(torch.floor(x * 10)) - outs.append(torch.log(x + 2)) - outs.append(torch.neg(x)) - outs.append(torch.round(x)) - outs.append(torch.sin(x)) - outs.append(torch.sqrt(x + 2)) - outs.append(torch.tanh(x)) + # outs.append(torch.log(x + 2)) + # outs.append(torch.neg(x)) + # outs.append(torch.round(x)) + # outs.append(torch.sin(x)) + # outs.append(torch.sqrt(x + 2)) + # outs.append(torch.tanh(x)) return outs return UnaryModule() @@ -46,7 +46,7 @@ def forward(self, x): in_shapes = [ [16], - [1, 3, 16, 16] + # [1, 3, 16, 16] ] @@ -60,4 +60,4 @@ def test_unary(in_shape, request): if __name__ == "__main__": - pytest.main(['-vv', 'test_unary.py']) + pytest.main(['-vv', __file__]) diff --git a/tests/importer/tflite_/basic/test_fully_connected.py b/tests/importer/tflite_/basic/test_fully_connected.py index 6cc3996d03..c81829f4ce 100644 --- a/tests/importer/tflite_/basic/test_fully_connected.py +++ b/tests/importer/tflite_/basic/test_fully_connected.py @@ -28,29 +28,34 @@ def __init__(self): @tf.function(input_signature=[tf.TensorSpec(input_shape, dtype=tf.float32)]) def __call__(self, x): - return self.out(x) + out = [] + x = self.out(x) + y = tf.reshape(x, [1, 1, 560, 80]) + out.append(x) + out.append(y) + return out return FullyConnectedModule() input_shapes = [ - [4, 6], - [3, 7] + [1, 560, 128], + # [3, 7] ] units = [ - 3, - 13 + 80, + # 13 ] activations = [ None, - 'relu', + # 'relu', ] use_biases = [ True, - False + # False ] diff --git a/tests/test_runner.py b/tests/test_runner.py index 335b22e2c9..f6b63b0501 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -255,7 +255,7 @@ def run(self, model_file: Union[List[str], str]): if not judge: if test_utils.in_ci(): self.clear(self.case_dir) - assert f"Fault result in {stage} + {result}" + assert (judge), f"Fault result in {stage} + {result}" if test_utils.in_ci(): self.clear(self.case_dir) From 42e7b550fe7c64af885b87b42a2e4f545e51a175 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 14:37:15 +0800 Subject: [PATCH 053/308] fix cpu rdata write --- modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs index a571294a75..3ba264dbec 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/FunctionBuilder.cs @@ -127,7 +127,7 @@ public unsafe LinkableFunction Build(TIR.PrimFunction function) throw new InvalidDataException("The Buffer Szie Not Equal!"); } - _rdataWriter.Position((uint)size); + _rdataWriter.Position(range.Start.Value); _rdataWriter.Write(bytes); } From 246ff2a494e41b0ba8926f882aa344415b4fd541 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 14:41:51 +0800 Subject: [PATCH 054/308] try to fix msvc build --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8cbe18e463..019393e9df 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,7 +92,7 @@ if (BUILDING_RUNTIME) if (MSVC) add_definitions(/D_CRT_SECURE_NO_WARNINGS /DNOMINMAX) - add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX /wd4297 -Wno-unused-function -Wno-unused-command-line-argument) + add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX /wd4297 -Wno-unused-function -Wno-unused-command-line-argument -Wno-int-to-void-pointer-cast) else() add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits) if (APPLE) @@ -184,7 +184,7 @@ else() if (MSVC) add_definitions(/D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS /D_CRT_SECURE_NO_WARNINGS /DNOMINMAX) - add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX -Wno-unused-function -Wno-unused-command-line-argument) + add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX -Wno-unused-function -Wno-unused-command-line-argument -Wno-int-to-void-pointer-cast) set(PYBIND11_CPP_STANDARD "/std:c++latest") else() add_compile_options(-fvisibility=hidden) From fa6611002fc4690b8b657defc30ef9ea6b85ba5d Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 14:57:17 +0800 Subject: [PATCH 055/308] Revert "fix mac build" This reverts commit 66865183b4d0744355fb1458ae27bd3fe3e2696d. --- modules/cpu/src/runtime/cpu_common.h | 4 +- tests/config.toml | 18 +- tests/generator.py | 4 +- tests/importer/onnx_/basic/test_binary.py | 12 +- .../onnx_/basic/test_binary_from_onnx.py | 100 ++++---- .../onnx_/basic/test_conv_transpose.py | 214 +++++------------- tests/importer/onnx_/basic/test_unary.py | 28 +-- .../tflite_/basic/test_fully_connected.py | 19 +- tests/test_runner.py | 2 +- 9 files changed, 144 insertions(+), 257 deletions(-) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 8fdca3c3aa..8221fa015f 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -97,7 +97,7 @@ inline int32_t int32_binary_mul(int32_t x, int32_t y) { return x * y; } inline int32_t int32_binary_div(int32_t x, int32_t y) { return x / y; } inline int32_t int32_binary_min(int32_t x, int32_t y) { return std::min(x, y); } inline int32_t int32_binary_max(int32_t x, int32_t y) { return std::max(x, y); } -#if defined(__APPLE__) +#if defined(__arm64__) && defined(__APPLE__) inline int32_t int32_binary_pow(int32_t x, int32_t y) { return (int32_t)pow(x, y); } @@ -113,7 +113,7 @@ inline int64_t int64_binary_mul(int64_t x, int64_t y) { return x * y; } inline int64_t int64_binary_div(int64_t x, int64_t y) { return x / y; } inline int64_t int64_binary_min(int64_t x, int64_t y) { return std::min(x, y); } inline int64_t int64_binary_max(int64_t x, int64_t y) { return std::max(x, y); } -#if defined(__APPLE__) +#if defined(__arm64__) && defined(__APPLE__) inline int64_t int64_binary_pow(int64_t x, int64_t y) { return (int64_t)pow(x, y); } diff --git a/tests/config.toml b/tests/config.toml index 4a13208015..7d27973767 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -16,7 +16,7 @@ output_layout = 'NHWC' model_layout = 'NHWC' letterbox_value = 0 dump_asm = true -dump_ir = true +dump_ir = false [ptq_opt] use_mix_quant = false @@ -59,7 +59,7 @@ args = [] [generator.calibs] method = 'random' -number = 1 +number = 5 batch = 1 [generator.calibs.random] @@ -80,21 +80,21 @@ args = [] [target] [target.cpu] -eval = false +eval = true infer = true simarity_name = 'cosine' [target.cpu.mode.noptq] -enabled = true +enabled = false threshold = 0.999 [target.cpu.mode.ptq] -enabled = false +enabled = true threshold = 0.98 [target.k510] -eval = false -infer = false +eval = true +infer = true simarity_name = 'cosine' [target.k510.mode.noptq] @@ -106,8 +106,8 @@ enabled = true threshold = 0.98 [target.k230] -eval = false -infer = false +eval = true +infer = true simarity_name = 'cosine' [target.k230.mode.noptq] diff --git a/tests/generator.py b/tests/generator.py index 74408b9e3d..631883c757 100644 --- a/tests/generator.py +++ b/tests/generator.py @@ -14,9 +14,9 @@ def from_random(self, shape: List[int], dtype: np.dtype, abs: bool = False) -> n elif dtype == np.bool: data = np.random.rand(*shape) > 0.5 elif dtype == np.int32: - data = np.random.randint(1, 2, size=shape, dtype='int32') + data = np.random.randint(1, 5, size=shape, dtype='int32') elif dtype == np.int64: - data = np.random.randint(1, 2, size=shape, dtype='int64') + data = np.random.randint(1, 5, size=shape, dtype='int64') # data = np.random.randint(1, 128, size=shape, dtype='int64') else: data = np.random.rand(*shape) diff --git a/tests/importer/onnx_/basic/test_binary.py b/tests/importer/onnx_/basic/test_binary.py index abb928157e..c629691f04 100644 --- a/tests/importer/onnx_/basic/test_binary.py +++ b/tests/importer/onnx_/basic/test_binary.py @@ -23,13 +23,13 @@ def _make_module(v_shape): class BinaryModule(torch.nn.Module): def __init__(self): super(BinaryModule, self).__init__() - self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32)) - # self.v = torch.from_numpy(np.ones(v_shape).astype(np.float32)) + # self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32)) + self.v = torch.from_numpy(np.ones(v_shape).astype(np.float32)) def forward(self, x): outs = [] - # outs.append(torch.add(x, self.v)) - outs.append(torch.mul(x, self.v)) + outs.append(torch.add(x, self.v)) + # outs.append(torch.mul(x, self.v)) # outs.append(torch.sub(x, self.v)) # outs.append(torch.max(x, self.v)) # outs.append(torch.div(x, self.v)) @@ -45,7 +45,7 @@ def forward(self, x): # [64, 3], # [3, 64, 3], # [8, 3, 64, 3] - [1, 560, 196] + [1, 3, 24, 24] ] rhs_shapes = [ @@ -61,7 +61,7 @@ def forward(self, x): # [8, 3, 1, 3], # [8, 1, 64, 3], # [1, 3, 64, 1] - [1, 560, 1] + [1, 3, 24, 24] ] diff --git a/tests/importer/onnx_/basic/test_binary_from_onnx.py b/tests/importer/onnx_/basic/test_binary_from_onnx.py index 0a3ba13c62..f95bf61a80 100644 --- a/tests/importer/onnx_/basic/test_binary_from_onnx.py +++ b/tests/importer/onnx_/basic/test_binary_from_onnx.py @@ -31,29 +31,27 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): # input1 input1 = helper.make_tensor_value_info('input1', in_type, in_shape_0) inputs.append('input1') - input2 = helper.make_tensor_value_info('input2', in_type, in_shape_1) - inputs.append('input2') # set input2 to avoid SIGFPE for div op. - # if op != 'Pow': - # tensor = helper.make_tensor( - # 'input2', - # in_type, - # dims=in_shape_1, - # vals=(np.random.rand(*in_shape_1) + - # 2).astype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[in_type]).flatten().tolist() - # ) - # inputs.append('input2') - # initializers.append(tensor) - # else: - # tensor = helper.make_tensor( - # 'input2', - # TensorProto.INT32, - # dims=[1], - # vals=[2] - # ) - # inputs.append('input2') - # initializers.append(tensor) + if op != 'Pow': + tensor = helper.make_tensor( + 'input2', + in_type, + dims=in_shape_1, + vals=(np.random.rand(*in_shape_1) + + 2).astype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[in_type]).flatten().tolist() + ) + inputs.append('input2') + initializers.append(tensor) + else: + tensor = helper.make_tensor( + 'input2', + TensorProto.INT32, + dims=[1], + vals=[2] + ) + inputs.append('input2') + initializers.append(tensor) # output x = np.random.randn(*in_shape_0) @@ -73,7 +71,7 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): graph_def = helper.make_graph( nodes, 'test-model', - [input1, input2], + [input1], [output], initializer=initializers) @@ -86,9 +84,9 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): 'Sub', 'Mul', 'Div', - # 'Min', - # 'Max', - # 'Pow', + 'Min', + 'Max', + 'Pow', ] in_types = [ @@ -96,45 +94,45 @@ def _make_module(op, in_type, in_shape_0, in_shape_1): ] in_shapes = [ - # [[1, 3, 16, 16], [1]], - # [[1, 3, 16, 16], [16]], - # [[1, 3, 16, 16], [1, 16]], - # [[1, 3, 16, 16], [16, 16]], - # [[1, 3, 16, 16], [1, 16, 16]], - # [[1, 3, 16, 16], [3, 16, 16]], + [[1, 3, 16, 16], [1]], + [[1, 3, 16, 16], [16]], + [[1, 3, 16, 16], [1, 16]], + [[1, 3, 16, 16], [16, 16]], + [[1, 3, 16, 16], [1, 16, 16]], + [[1, 3, 16, 16], [3, 16, 16]], [[1, 3, 16, 16], [1, 3, 16, 16]], - # [[3, 16, 16], [1]], - # [[3, 16, 16], [16]], - # [[3, 16, 16], [1, 16]], - # [[3, 16, 16], [16, 16]], + [[3, 16, 16], [1]], + [[3, 16, 16], [16]], + [[3, 16, 16], [1, 16]], + [[3, 16, 16], [16, 16]], [[3, 16, 16], [1, 16, 16]], [[3, 16, 16], [3, 16, 16]], - # [[3, 16, 16], [1, 3, 16, 16]], + [[3, 16, 16], [1, 3, 16, 16]], - # [[16, 16], [1]], - # [[16, 16], [16]], + [[16, 16], [1]], + [[16, 16], [16]], [[16, 16], [1, 16]], [[16, 16], [16, 16]], - # [[16, 16], [1, 16, 16]], - # [[16, 16], [3, 16, 16]], - # [[16, 16], [1, 3, 16, 16]], + [[16, 16], [1, 16, 16]], + [[16, 16], [3, 16, 16]], + [[16, 16], [1, 3, 16, 16]], [[1], [1]], [[1], [16]], - # [[1], [1, 16]], - # [[1], [16, 16]], - # [[1], [1, 16, 16]], - # [[1], [3, 16, 16]], - # [[1], [1, 3, 16, 16]], + [[1], [1, 16]], + [[1], [16, 16]], + [[1], [1, 16, 16]], + [[1], [3, 16, 16]], + [[1], [1, 3, 16, 16]], [[16], [1]], [[16], [16]], - # [[16], [1, 16]], - # [[16], [16, 16]], - # [[16], [1, 16, 16]], - # [[16], [3, 16, 16]], - # [[16], [1, 3, 16, 16]] + [[16], [1, 16]], + [[16], [16, 16]], + [[16], [1, 16, 16]], + [[16], [3, 16, 16]], + [[16], [1, 3, 16, 16]] ] diff --git a/tests/importer/onnx_/basic/test_conv_transpose.py b/tests/importer/onnx_/basic/test_conv_transpose.py index a675680e2b..c2cfc3c8c3 100644 --- a/tests/importer/onnx_/basic/test_conv_transpose.py +++ b/tests/importer/onnx_/basic/test_conv_transpose.py @@ -16,194 +16,88 @@ import math import pytest import onnx +import torch from onnx import helper from onnx import AttributeProto, TensorProto, GraphProto from onnx_test_runner import OnnxTestRunner import numpy as np -def _make_module(in_shape, kernel_output_channel, bias_shape, auto_pad_mode, dilation, group, kernel_shape, output_padding, pad, stride): - inputs = [] - initializers = [] - - # input - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, in_shape) - inputs.append('input') - - group = 1 if group is None else group - - # weight - w_shape = [] - w_shape.append(in_shape[1]) - w_shape.append(kernel_output_channel // group) - w_shape.extend(kernel_shape) - weight = helper.make_tensor( - 'weight', - TensorProto.FLOAT, - dims=w_shape, - vals=np.random.rand(*w_shape).astype(np.float32).flatten().tolist() - ) - inputs.append('weight') - initializers.append(weight) - - # bias - if bias_shape is not None: - bias = helper.make_tensor( - 'bias', - TensorProto.FLOAT, - dims=bias_shape, - vals=np.random.rand(*bias_shape).astype(np.float32).flatten().tolist() - ) - inputs.append('bias') - initializers.append(bias) - - # dilation - d = [1, 1] if dilation is None else dilation - - # output_padding - out_padding = [0, 0] if output_padding is None else output_padding - - # stride - s = [1, 1] if stride is None else stride - - # output - out_shape = [] - out_shape.append(in_shape[0]) - out_shape.append(w_shape[1] * group) - - # pad - padding = [0, 0, 0, 0] - if auto_pad_mode in [None, 'NOTSET'] and pad is not None: - padding = pad - out_shape.append(s[0] * (in_shape[2] - 1) + out_padding[0] + - (w_shape[2] - 1) * d[0] + 1 - padding[0] - padding[2]) - out_shape.append(s[1] * (in_shape[3] - 1) + out_padding[1] + - (w_shape[3] - 1) * d[1] + 1 - padding[1] - padding[3]) - elif auto_pad_mode in ['SAME_UPPER', 'SAME_LOWER']: - out_shape.append(in_shape[2] * s[0]) - out_shape.append(in_shape[3] * s[1]) - else: - out_shape.append(in_shape[2] + (in_shape[2] - 1) * (s[0] - 1) - w_shape[2] + 1) - out_shape.append(in_shape[3] + (in_shape[3] - 1) * (s[1] - 1) - w_shape[3] + 1) - - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, out_shape) - - attributes_dict = {} - - if auto_pad_mode is not None: - attributes_dict['auto_pad'] = auto_pad_mode - - if dilation is not None: - attributes_dict['dilations'] = dilation - - if group is not None: - attributes_dict['group'] = group - - if kernel_shape is not None: - attributes_dict['kernel_shape'] = kernel_shape - - if output_padding is not None: - attributes_dict['output_padding'] = output_padding - - if pad is not None: - attributes_dict['pads'] = padding - - if stride is not None: - attributes_dict['strides'] = stride - - node = onnx.helper.make_node( - 'ConvTranspose', - inputs=inputs, - outputs=['output'], - **attributes_dict - ) - - nodes = [] - nodes.append(node) - - graph_def = helper.make_graph( - nodes, - 'test-model', - [input], - [output], - initializer=initializers) - - model_def = helper.make_model(graph_def, producer_name='kendryte') - - return model_def - - -in_shapes = [ - [1, 3, 10, 10] -] +def _make_module(in_channel, out_channel, kernel_size, stride, dilation, pad, group, bias): + class ConvTransposeModule(torch.nn.Module): + def __init__(self): + super(ConvTransposeModule, self).__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channel, out_channel, kernel_size, stride, pad, [0, 0], group, bias, dilation) -kernel_output_channels = [ - 3 -] + def forward(self, x): + return self.conv_transpose(x) -bias_shapes = [ - None, -] -bias_shapes.extend(list([[x] for x in kernel_output_channels])) - -auto_pad_modes = [ - None, - # 'NOTSET', - # 'SAME_UPPER', - # 'SAME_LOWER', - # 'VALID' + return ConvTransposeModule() + + +in_sizes = [ + [16, 16], + [33, 65], ] -dilations = [ - None, +in_channels = [ + 1, + 3, + 16 ] -groups = [ - None, - # 3 +out_channels = [ + 1, + 16 ] -kernel_shapes = [ +kernel_sizes = [ [1, 1], + [3, 3], +] + +strides = [ + 1, + [2, 2] ] -output_paddings = [ - None, - # [1, 1] +dilations = [ + 1 ] pads = [ - # None, - [0, 0, 1, 1], + 0, + [1, 1], ] -strides = [ - None, - # [2, 3], - # [3, 2], - # [3, 3] +groups = [ + 1 +] + +biases = [ + True, + False ] -@pytest.mark.parametrize('in_shape', in_shapes) -@pytest.mark.parametrize('kernel_output_channel', kernel_output_channels) -@pytest.mark.parametrize('bias_shape', bias_shapes) -@pytest.mark.parametrize('auto_pad_mode', auto_pad_modes) +@pytest.mark.parametrize('in_size', in_sizes) +@pytest.mark.parametrize('in_channel', in_channels) +@pytest.mark.parametrize('out_channel', out_channels) +@pytest.mark.parametrize('kernel_size', kernel_sizes) +@pytest.mark.parametrize('stride', strides) @pytest.mark.parametrize('dilation', dilations) -@pytest.mark.parametrize('group', groups) -@pytest.mark.parametrize('kernel_shape', kernel_shapes) -@pytest.mark.parametrize('output_padding', output_paddings) @pytest.mark.parametrize('pad', pads) -@pytest.mark.parametrize('stride', strides) -def test_conv_transpose(in_shape, kernel_output_channel, bias_shape, auto_pad_mode, dilation, group, kernel_shape, output_padding, pad, stride, request): - if (bias_shape is None or (bias_shape is not None and bias_shape[0] == kernel_output_channel)) and ((auto_pad_mode in [None, 'NOTSET'] and pad is not None) or (auto_pad_mode in ['SAME_UPPER', 'SAME_LOWER', 'VALID'] and pad is None)) and (dilation is None or (auto_pad_modes in [None, 'NOTSET'])) and ((output_padding is None) or (output_padding is not None and stride is not None)): - model_def = _make_module(in_shape, kernel_output_channel, bias_shape, - auto_pad_mode, dilation, group, kernel_shape, output_padding, pad, stride) +@pytest.mark.parametrize('group', groups) +@pytest.mark.parametrize('bias', biases) +def test_conv_transpose(in_size, in_channel, out_channel, kernel_size, stride, dilation, pad, group, bias, request): + model_file = _make_module(in_channel, out_channel, kernel_size, + stride, dilation, pad, group, bias) - runner = OnnxTestRunner(request.node.name) - model_file = runner.from_onnx_helper(model_def) - runner.run(model_file) + runner = OnnxTestRunner(request.node.name) + model_file = runner.from_torch(model_file, [1, in_channel, *in_size]) + runner.run(model_file) if __name__ == "__main__": - pytest.main(['-vv', 'test_conv_transpose.py']) + pytest.main(['-vv', __file__]) diff --git a/tests/importer/onnx_/basic/test_unary.py b/tests/importer/onnx_/basic/test_unary.py index 58353ff25a..1137dd0a06 100644 --- a/tests/importer/onnx_/basic/test_unary.py +++ b/tests/importer/onnx_/basic/test_unary.py @@ -26,19 +26,19 @@ def __init__(self): def forward(self, x): outs = [] - # outs.append(torch.abs(-x)) - # outs.append(torch.acos(x)) - # outs.append(torch.asin(x)) - # outs.append(torch.ceil(x)) - # outs.append(torch.cos(x)) - # outs.append(torch.exp(x)) + outs.append(torch.abs(-x)) + outs.append(torch.acos(x)) + outs.append(torch.asin(x)) + outs.append(torch.ceil(x)) + outs.append(torch.cos(x)) + outs.append(torch.exp(x)) outs.append(torch.floor(x * 10)) - # outs.append(torch.log(x + 2)) - # outs.append(torch.neg(x)) - # outs.append(torch.round(x)) - # outs.append(torch.sin(x)) - # outs.append(torch.sqrt(x + 2)) - # outs.append(torch.tanh(x)) + outs.append(torch.log(x + 2)) + outs.append(torch.neg(x)) + outs.append(torch.round(x)) + outs.append(torch.sin(x)) + outs.append(torch.sqrt(x + 2)) + outs.append(torch.tanh(x)) return outs return UnaryModule() @@ -46,7 +46,7 @@ def forward(self, x): in_shapes = [ [16], - # [1, 3, 16, 16] + [1, 3, 16, 16] ] @@ -60,4 +60,4 @@ def test_unary(in_shape, request): if __name__ == "__main__": - pytest.main(['-vv', __file__]) + pytest.main(['-vv', 'test_unary.py']) diff --git a/tests/importer/tflite_/basic/test_fully_connected.py b/tests/importer/tflite_/basic/test_fully_connected.py index c81829f4ce..6cc3996d03 100644 --- a/tests/importer/tflite_/basic/test_fully_connected.py +++ b/tests/importer/tflite_/basic/test_fully_connected.py @@ -28,34 +28,29 @@ def __init__(self): @tf.function(input_signature=[tf.TensorSpec(input_shape, dtype=tf.float32)]) def __call__(self, x): - out = [] - x = self.out(x) - y = tf.reshape(x, [1, 1, 560, 80]) - out.append(x) - out.append(y) - return out + return self.out(x) return FullyConnectedModule() input_shapes = [ - [1, 560, 128], - # [3, 7] + [4, 6], + [3, 7] ] units = [ - 80, - # 13 + 3, + 13 ] activations = [ None, - # 'relu', + 'relu', ] use_biases = [ True, - # False + False ] diff --git a/tests/test_runner.py b/tests/test_runner.py index f6b63b0501..335b22e2c9 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -255,7 +255,7 @@ def run(self, model_file: Union[List[str], str]): if not judge: if test_utils.in_ci(): self.clear(self.case_dir) - assert (judge), f"Fault result in {stage} + {result}" + assert f"Fault result in {stage} + {result}" if test_utils.in_ci(): self.clear(self.case_dir) From 9c98a6a09b86355e8f01c871beaed41252efdaee Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 14:58:15 +0800 Subject: [PATCH 056/308] fix mac build --- modules/cpu/src/runtime/cpu_common.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 8221fa015f..8fdca3c3aa 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -97,7 +97,7 @@ inline int32_t int32_binary_mul(int32_t x, int32_t y) { return x * y; } inline int32_t int32_binary_div(int32_t x, int32_t y) { return x / y; } inline int32_t int32_binary_min(int32_t x, int32_t y) { return std::min(x, y); } inline int32_t int32_binary_max(int32_t x, int32_t y) { return std::max(x, y); } -#if defined(__arm64__) && defined(__APPLE__) +#if defined(__APPLE__) inline int32_t int32_binary_pow(int32_t x, int32_t y) { return (int32_t)pow(x, y); } @@ -113,7 +113,7 @@ inline int64_t int64_binary_mul(int64_t x, int64_t y) { return x * y; } inline int64_t int64_binary_div(int64_t x, int64_t y) { return x / y; } inline int64_t int64_binary_min(int64_t x, int64_t y) { return std::min(x, y); } inline int64_t int64_binary_max(int64_t x, int64_t y) { return std::max(x, y); } -#if defined(__arm64__) && defined(__APPLE__) +#if defined(__APPLE__) inline int64_t int64_binary_pow(int64_t x, int64_t y) { return (int64_t)pow(x, y); } From 0e7d98d89a3d13e6dae6c29f1ba1497a669df386 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 15:48:54 +0800 Subject: [PATCH 057/308] try to fix msvc build --- modules/cpu/src/runtime/runtime_function.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 2b4d7e3a8f..9e6b261bb9 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -48,6 +48,8 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { + printf("Initializing Core\n"); + // try_(context.read_section(".desc", [this](auto sr, size_t) -> // result { // auto header = sr.template read(); @@ -113,6 +115,8 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, NNCASE_UNUSED value_t return_value) noexcept { try_var(id, module().find_id_by_function(this)); + printf("InvokeCore\n"); + uint8_t **buffers = new uint8_t *[parameters.size()]; // input buffer for (size_t i = 0; i < parameters.size(); i++) { From 352b69d507d3ed28999df540207b44e5c79546e6 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 15:52:43 +0800 Subject: [PATCH 058/308] Revert "try to fix msvc build" This reverts commit 0e7d98d89a3d13e6dae6c29f1ba1497a669df386. --- modules/cpu/src/runtime/runtime_function.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modules/cpu/src/runtime/runtime_function.cpp b/modules/cpu/src/runtime/runtime_function.cpp index 9e6b261bb9..2b4d7e3a8f 100644 --- a/modules/cpu/src/runtime/runtime_function.cpp +++ b/modules/cpu/src/runtime/runtime_function.cpp @@ -48,8 +48,6 @@ cpu_runtime_module &cpu_runtime_function::module() const noexcept { result cpu_runtime_function::initialize_core( NNCASE_UNUSED runtime_function_init_context &context) noexcept { - printf("Initializing Core\n"); - // try_(context.read_section(".desc", [this](auto sr, size_t) -> // result { // auto header = sr.template read(); @@ -115,8 +113,6 @@ cpu_runtime_function::invoke_core(NNCASE_UNUSED gsl::span parameters, NNCASE_UNUSED value_t return_value) noexcept { try_var(id, module().find_id_by_function(this)); - printf("InvokeCore\n"); - uint8_t **buffers = new uint8_t *[parameters.size()]; // input buffer for (size_t i = 0; i < parameters.size(); i++) { From 4b64b4404d1873ec889140ab4134ee9b01b8d38e Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 1 Aug 2023 15:53:22 +0800 Subject: [PATCH 059/308] try to fix msvc build --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 019393e9df..6b746d454b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,7 +92,7 @@ if (BUILDING_RUNTIME) if (MSVC) add_definitions(/D_CRT_SECURE_NO_WARNINGS /DNOMINMAX) - add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX /wd4297 -Wno-unused-function -Wno-unused-command-line-argument -Wno-int-to-void-pointer-cast) + add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX /wd4297 -Wno-unused-function -Wno-unused-command-line-argument -Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast) else() add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits) if (APPLE) @@ -184,7 +184,7 @@ else() if (MSVC) add_definitions(/D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS /D_CRT_SECURE_NO_WARNINGS /DNOMINMAX) - add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX -Wno-unused-function -Wno-unused-command-line-argument -Wno-int-to-void-pointer-cast) + add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX -Wno-unused-function -Wno-unused-command-line-argument -Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast) set(PYBIND11_CPP_STANDARD "/std:c++latest") else() add_compile_options(-fvisibility=hidden) From 5e85bf33e2f71aaf2a260455554c02be2f590cfe Mon Sep 17 00:00:00 2001 From: huochenghai Date: Thu, 3 Aug 2023 11:12:31 +0800 Subject: [PATCH 060/308] update cpu runtime --- .../CodeGen/CSourceCompiler.cs | 4 +- modules/cpu/src/runtime/cpu_common.h | 103 +++++++++--------- modules/cpu/src/runtime/elfloader.cpp | 4 +- 3 files changed, 55 insertions(+), 56 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs index d59e244f92..449329e034 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs @@ -134,11 +134,11 @@ private string ArgumentsSpecific(string sourcePath, string outPath) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { - return $"{sourcePath} -nostdlib -static -no-pie -fPIC -fno-stack-protector -march={Arch} -o {outPath}"; + return $"{sourcePath} -nostdlib -static -no-pie -fPIC -fno-stack-protector -o {outPath}"; } else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - return $"{sourcePath} -nostdlib -static -nopie -fPIC -arch {Arch} -o {outPath} -e__start"; + return $"{sourcePath} -nostdlib -static -nopie -fPIC -o {outPath} -e__start"; } else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index 8fdca3c3aa..aefa9f32d1 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -127,57 +127,56 @@ inline bool bool_binary_logical_and(bool x, bool y) { return x && y; } inline bool bool_binary_logical_or(bool x, bool y) { return x || y; } inline bool bool_binary_logical_xor(bool x, bool y) { return x ^ y; } -[[maybe_unused]] static nncase_mt_t nncase_mt = { - .float_unary_abs = fabsf, - .float_unary_acos = acosf, - .float_unary_acosh = acoshf, - .float_unary_asin = asinf, - .float_unary_asinh = asinhf, - .float_unary_ceil = ceilf, - .float_unary_cos = cosf, - .float_unary_cosh = coshf, - .float_unary_exp = expf, - .float_unary_floor = floorf, - .float_unary_log = logf, - .float_unary_logical_not = float_unary_logical_not, - .float_unary_neg = float_unary_neg, - .float_unary_round = roundf, - .float_unary_rsqrt = float_unary_rsqrt, - .float_unary_sign = float_unary_sign, - .float_unary_sin = sinf, - .float_unary_sinh = sinhf, - .float_unary_sqrt = sqrtf, - .float_unary_square = float_unary_square, - .float_unary_tanh = tanhf, - .float_binary_add = float_binary_add, - .float_binary_sub = float_binary_sub, - .float_binary_mul = float_binary_mul, - .float_binary_div = float_binary_div, - .float_binary_min = float_binary_min, - .float_binary_max = float_binary_max, - .float_binary_pow = float_binary_pow, - .float_binary_logical_and = float_binary_logical_and, - .float_binary_mod = float_binary_mod, - .int32_binary_add = int32_binary_add, - .int32_binary_sub = int32_binary_sub, - .int32_binary_mul = int32_binary_mul, - .int32_binary_div = int32_binary_div, - .int32_binary_min = int32_binary_min, - .int32_binary_max = int32_binary_max, - .int32_binary_pow = int32_binary_pow, - .int32_binary_logical_and = int32_binary_logical_and, - .int32_binary_mod = int32_binary_mod, - .int64_binary_add = int64_binary_add, - .int64_binary_sub = int64_binary_sub, - .int64_binary_mul = int64_binary_mul, - .int64_binary_div = int64_binary_div, - .int64_binary_min = int64_binary_min, - .int64_binary_max = int64_binary_max, - .int64_binary_pow = int64_binary_pow, - .int64_binary_logical_and = int64_binary_logical_and, - .int64_binary_mod = int64_binary_mod, - .bool_binary_and = bool_binary_logical_and, - .bool_binary_or = bool_binary_logical_or, - .bool_binary_xor = bool_binary_logical_xor}; +[[maybe_unused]] static nncase_mt_t nncase_mt{fabsf, + acosf, + acoshf, + asinf, + asinhf, + ceilf, + cosf, + coshf, + expf, + floorf, + logf, + float_unary_logical_not, + float_unary_neg, + roundf, + float_unary_rsqrt, + float_unary_sign, + sinf, + sinhf, + sqrtf, + float_unary_square, + tanhf, + float_binary_add, + float_binary_sub, + float_binary_mul, + float_binary_div, + float_binary_min, + float_binary_max, + float_binary_pow, + float_binary_logical_and, + float_binary_mod, + int32_binary_add, + int32_binary_sub, + int32_binary_mul, + int32_binary_div, + int32_binary_min, + int32_binary_max, + int32_binary_pow, + int32_binary_logical_and, + int32_binary_mod, + int64_binary_add, + int64_binary_sub, + int64_binary_mul, + int64_binary_div, + int64_binary_min, + int64_binary_max, + int64_binary_pow, + int64_binary_logical_and, + int64_binary_mod, + bool_binary_logical_and, + bool_binary_logical_or, + bool_binary_logical_xor}; END_NS_NNCASE_RT_MODULE \ No newline at end of file diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp index c41d7e11ab..394ed6ed3d 100644 --- a/modules/cpu/src/runtime/elfloader.cpp +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -29,8 +29,8 @@ int elfloader::invoke_elf(size_t id, uint8_t **buffers, nncase_mt_t *nncase_mt, entrypoint_t ep = (entrypoint_t)epaddr; - printf("Binary entrypoint is %" PRIxPTR "; invoking %p\n", - (uintptr_t)ctx_.ehdr.e_entry, (void *)epaddr); + // printf("Binary entrypoint is %" PRIxPTR "; invoking %p\n", + // (uintptr_t)ctx_.ehdr.e_entry, (void *)epaddr); ep(id, buffers, nncase_mt, data, rdata); From 5531c7d6c45684c42858624b6a16dbcbac2615ad Mon Sep 17 00:00:00 2001 From: huochenghai Date: Thu, 3 Aug 2023 11:13:20 +0800 Subject: [PATCH 061/308] add cpu_test example --- modules/cpu/CMakeLists.txt | 4 + modules/cpu/examples/CMakeLists.txt | 4 + modules/cpu/examples/cpu_test/CMakeLists.txt | 5 + modules/cpu/examples/cpu_test/main.cpp | 191 +++++++++++++++++++ toolchains/k230_cpu.linux.toolchain.cmake | 33 ++++ 5 files changed, 237 insertions(+) create mode 100644 modules/cpu/examples/CMakeLists.txt create mode 100644 modules/cpu/examples/cpu_test/CMakeLists.txt create mode 100644 modules/cpu/examples/cpu_test/main.cpp create mode 100644 toolchains/k230_cpu.linux.toolchain.cmake diff --git a/modules/cpu/CMakeLists.txt b/modules/cpu/CMakeLists.txt index f030315513..42d6983640 100644 --- a/modules/cpu/CMakeLists.txt +++ b/modules/cpu/CMakeLists.txt @@ -45,6 +45,10 @@ if (BUILDING_RUNTIME) configure_file(${CMAKE_CURRENT_LIST_DIR}/cmake/nncase_rt_modules_cpuConfig.cmake.in nncase_rt_modules_cpuConfig.cmake @ONLY) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/nncase_rt_modules_cpuConfig.cmake DESTINATION lib/cmake/nncaseruntime) + + if (K230_LINUX_SDK_DIR) + add_subdirectory(examples) + endif() endif() else() add_library(nncase_modules_cpu SHARED ${SRCS}) diff --git a/modules/cpu/examples/CMakeLists.txt b/modules/cpu/examples/CMakeLists.txt new file mode 100644 index 0000000000..ea209b0f0f --- /dev/null +++ b/modules/cpu/examples/CMakeLists.txt @@ -0,0 +1,4 @@ +cmake_minimum_required(VERSION 3.13) +project(examples C CXX) + +add_subdirectory(cpu_test) \ No newline at end of file diff --git a/modules/cpu/examples/cpu_test/CMakeLists.txt b/modules/cpu/examples/cpu_test/CMakeLists.txt new file mode 100644 index 0000000000..92c0068dff --- /dev/null +++ b/modules/cpu/examples/cpu_test/CMakeLists.txt @@ -0,0 +1,5 @@ +set(SRC main.cpp) +set(bin cpu_test.elf) + +add_executable(${bin} ${SRC}) +target_link_libraries(${bin} runtime_cpu) \ No newline at end of file diff --git a/modules/cpu/examples/cpu_test/main.cpp b/modules/cpu/examples/cpu_test/main.cpp new file mode 100644 index 0000000000..95dd514129 --- /dev/null +++ b/modules/cpu/examples/cpu_test/main.cpp @@ -0,0 +1,191 @@ +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::runtime; +using namespace nncase::runtime::detail; + +template +std::vector read_binary_file(const char *file_name) +{ + std::ifstream ifs(file_name, std::ios::binary); + ifs.seekg(0, ifs.end); + size_t len = ifs.tellg(); + std::vector vec(len / sizeof(T), 0); + ifs.seekg(0, ifs.beg); + ifs.read(reinterpret_cast(vec.data()), len); + ifs.close(); + return vec; +} + +void read_binary_file(const char *file_name, char *buffer) +{ + std::ifstream ifs(file_name, std::ios::binary); + ifs.seekg(0, ifs.end); + size_t len = ifs.tellg(); + ifs.seekg(0, ifs.beg); + ifs.read(buffer, len); + ifs.close(); +} + +auto read_binary(const char *file_name, char *buffer, size_t begin, size_t count) +{ + std::ifstream ifs(file_name, std::ios::binary); + ifs.seekg(begin, ifs.beg); + ifs.read(buffer + begin, count); + ifs.close(); + std::cout << "read bin seg ok" << std::endl; +} + +size_t get_binary_file_size(const char *file_name) +{ + std::ifstream ifs(file_name, std::ios::binary); + ifs.seekg(0, ifs.end); + size_t len = ifs.tellg(); + ifs.close(); + return len; +} + +auto load_bin_to_kmodel(char *file_path, interpreter &interp) +{ + auto model_size = get_binary_file_size(file_path); + auto model_data = std::make_unique(model_size); + for (size_t i = 0; i < model_size;) + { + size_t count = 8000000; + if (count + i >= model_size) + count = model_size - i; + read_binary(file_path, model_data.get(), i, count); + i += 8000000; + } + // show the identification of kmodel. + // print_one_line_data("kmodel identification :", model_data.get(), 4); + + interp.load_model({ (const gsl::byte *)model_data.get(), model_size }).expect("cannot load kmodel."); + std::cout << "load kmodel success" << std::endl; + + return model_data; +} + +template +double dot(const T *v1, const T *v2, size_t size) +{ + double ret = 0.f; + for (size_t i = 0; i < size; i++) + { + ret += v1[i] * v2[i]; + } + + return ret; +} + +template +double cosine(const T *v1, const T *v2, size_t size) +{ + for (size_t i = 0; i < 10; i++) + { + std::cout << v1[i] << " " << v2[i] << std::endl; + } + return dot(v1, v2, size) / ((sqrt(dot(v1, v1, size)) * sqrt(dot(v2, v2, size)))); +} + +int main(int argc, char *argv[]) +{ + std::cout << "case " << argv[0] << " build " << __DATE__ << " " << __TIME__ << std::endl; + + if (argc < 3) + { + std::cerr << "Usage: " << std::endl; + std::cerr << argv[0] << " ... " << std::endl; + std::cerr << argv[0] << " ... ... " << std::endl; + return -1; + } + + interpreter interp; + + // 1. load model + std::ifstream ifs(argv[1], std::ios::binary); + interp.load_model(ifs).expect("Invalid kmodel"); + + // 2. set inputs + for (size_t i = 2, j = 0; i < 2 + interp.inputs_size(); i++, j++) + { + auto desc = interp.input_desc(j); + auto shape = interp.input_shape(j); + auto tensor = hrt::create(desc.datatype, shape, hrt::pool_shared).expect("cannot create input tensor"); + interp.input_tensor(j, tensor).expect("cannot set input tensor"); + + auto span = tensor.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_write).unwrap().buffer(); + read_binary_file(argv[i], reinterpret_cast(span.data())); + hrt::sync(tensor, sync_op_t::sync_write_back, true).expect("sync write_back failed"); + } + + // 3. set outputs + // for (size_t i = 0; i < interp.outputs_size(); i++) + // { + // auto desc = interp.output_desc(i); + // auto shape = interp.output_shape(i); + // auto tensor = hrt::create(desc.datatype, shape, hrt::pool_shared).expect("cannot create input tensor"); + // interp.output_tensor(i, tensor).expect("cannot set output tensor"); + // } + + // 4. run + interp.run().expect("error occurred in running model"); + auto start = std::chrono::steady_clock::now(); + interp.run().expect("error occurred in running model"); + auto stop = std::chrono::steady_clock::now(); + double duration = std::chrono::duration(stop - start).count(); + + // 5. get outputs + double cos = 0.f; + for (int i = 2 + interp.inputs_size(), j = 0; i < argc; i++, j++) + { + auto desc = interp.output_desc(j); + auto out = interp.output_tensor(j).expect("cannot get output tensor"); + auto mapped_buf = std::move(hrt::map(out, map_access_t::map_read).unwrap()); + auto vec = read_binary_file(argv[i]); + switch (desc.datatype) + { + case dt_boolean: + case dt_uint8: + { + cos = cosine((const uint8_t *)mapped_buf.buffer().data(), (const uint8_t *)vec.data(), vec.size() / sizeof(uint8_t)); + break; + } + case dt_int8: + { + cos = cosine((const int8_t *)mapped_buf.buffer().data(), (const int8_t *)vec.data(), vec.size() / sizeof(int8_t)); + break; + } + case dt_float32: + { + cos = cosine((const float *)mapped_buf.buffer().data(), (const float *)vec.data(), vec.size() / sizeof(float)); + break; + } + case dt_int32: + { + cos = cosine((const int32_t *)mapped_buf.buffer().data(), (const int32_t *)vec.data(), vec.size() / sizeof(int32_t)); + break; + } + case dt_int64: + { + cos = cosine((const int64_t *)mapped_buf.buffer().data(), (const int64_t *)vec.data(), vec.size() / sizeof(int64_t)); + break; + } + default: + { + std::cerr << "not supported data type: " << desc.datatype << std::endl; + std::abort(); + } + } + + std::cout << "output " << j << " cosine similarity: " << cos << std::endl; + } + + std::cout << "interp run: " << duration << " ms, fps = " << 1000 / duration << std::endl; + + return 0; +} \ No newline at end of file diff --git a/toolchains/k230_cpu.linux.toolchain.cmake b/toolchains/k230_cpu.linux.toolchain.cmake new file mode 100644 index 0000000000..4bd9aba679 --- /dev/null +++ b/toolchains/k230_cpu.linux.toolchain.cmake @@ -0,0 +1,33 @@ +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +if(DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) +endif() + +if(NOT RISCV_ROOT_PATH) + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") +endif() + +set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") +set(CMAKE_C_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-musl-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-musl-g++") +set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-musl") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(ENABLE_VULKAN_RUNTIME OFF) +set(ENABLE_OPENMP OFF) +set(ENABLE_HALIDE OFF) +set(DEFAULT_BUILTIN_RUNTIMES OFF) +set(DEFAULT_SHARED_RUNTIME_TENSOR_PLATFORM_IMPL ON) +set(BUILD_BENCHMARK OFF) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64imafdcv -mabi=lp64d -mcmodel=medany") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64imafdcv -mabi=lp64d -mcmodel=medany") +set(CMAKE_EXE_LINKER_FLAGS "-T ${K230_LINUX_SDK_DIR}/src/big/rt-smart/userapps/linker_scripts/riscv64/link.lds --static") + +set(BUILDING_RUNTIME ON) +set(ENABLE_CPU_RUNTIME ON) +set(BUILD_SHARED_LIBS OFF) \ No newline at end of file From 81b07a263c87e06b69fd1e22cb68ca508876dc80 Mon Sep 17 00:00:00 2001 From: xhuohai Date: Thu, 3 Aug 2023 03:16:13 +0000 Subject: [PATCH 062/308] Apply code-format changes --- modules/cpu/examples/cpu_test/main.cpp | 138 +++++++++++++------------ 1 file changed, 74 insertions(+), 64 deletions(-) diff --git a/modules/cpu/examples/cpu_test/main.cpp b/modules/cpu/examples/cpu_test/main.cpp index 95dd514129..3952270f0e 100644 --- a/modules/cpu/examples/cpu_test/main.cpp +++ b/modules/cpu/examples/cpu_test/main.cpp @@ -8,9 +8,7 @@ using namespace nncase; using namespace nncase::runtime; using namespace nncase::runtime::detail; -template -std::vector read_binary_file(const char *file_name) -{ +template std::vector read_binary_file(const char *file_name) { std::ifstream ifs(file_name, std::ios::binary); ifs.seekg(0, ifs.end); size_t len = ifs.tellg(); @@ -21,8 +19,7 @@ std::vector read_binary_file(const char *file_name) return vec; } -void read_binary_file(const char *file_name, char *buffer) -{ +void read_binary_file(const char *file_name, char *buffer) { std::ifstream ifs(file_name, std::ios::binary); ifs.seekg(0, ifs.end); size_t len = ifs.tellg(); @@ -31,8 +28,8 @@ void read_binary_file(const char *file_name, char *buffer) ifs.close(); } -auto read_binary(const char *file_name, char *buffer, size_t begin, size_t count) -{ +auto read_binary(const char *file_name, char *buffer, size_t begin, + size_t count) { std::ifstream ifs(file_name, std::ios::binary); ifs.seekg(begin, ifs.beg); ifs.read(buffer + begin, count); @@ -40,8 +37,7 @@ auto read_binary(const char *file_name, char *buffer, size_t begin, size_t count std::cout << "read bin seg ok" << std::endl; } -size_t get_binary_file_size(const char *file_name) -{ +size_t get_binary_file_size(const char *file_name) { std::ifstream ifs(file_name, std::ios::binary); ifs.seekg(0, ifs.end); size_t len = ifs.tellg(); @@ -49,12 +45,10 @@ size_t get_binary_file_size(const char *file_name) return len; } -auto load_bin_to_kmodel(char *file_path, interpreter &interp) -{ +auto load_bin_to_kmodel(char *file_path, interpreter &interp) { auto model_size = get_binary_file_size(file_path); auto model_data = std::make_unique(model_size); - for (size_t i = 0; i < model_size;) - { + for (size_t i = 0; i < model_size;) { size_t count = 8000000; if (count + i >= model_size) count = model_size - i; @@ -64,43 +58,43 @@ auto load_bin_to_kmodel(char *file_path, interpreter &interp) // show the identification of kmodel. // print_one_line_data("kmodel identification :", model_data.get(), 4); - interp.load_model({ (const gsl::byte *)model_data.get(), model_size }).expect("cannot load kmodel."); + interp.load_model({(const gsl::byte *)model_data.get(), model_size}) + .expect("cannot load kmodel."); std::cout << "load kmodel success" << std::endl; return model_data; } -template -double dot(const T *v1, const T *v2, size_t size) -{ +template double dot(const T *v1, const T *v2, size_t size) { double ret = 0.f; - for (size_t i = 0; i < size; i++) - { + for (size_t i = 0; i < size; i++) { ret += v1[i] * v2[i]; } return ret; } -template -double cosine(const T *v1, const T *v2, size_t size) -{ - for (size_t i = 0; i < 10; i++) - { +template double cosine(const T *v1, const T *v2, size_t size) { + for (size_t i = 0; i < 10; i++) { std::cout << v1[i] << " " << v2[i] << std::endl; } - return dot(v1, v2, size) / ((sqrt(dot(v1, v1, size)) * sqrt(dot(v2, v2, size)))); + return dot(v1, v2, size) / + ((sqrt(dot(v1, v1, size)) * sqrt(dot(v2, v2, size)))); } -int main(int argc, char *argv[]) -{ - std::cout << "case " << argv[0] << " build " << __DATE__ << " " << __TIME__ << std::endl; +int main(int argc, char *argv[]) { + std::cout << "case " << argv[0] << " build " << __DATE__ << " " << __TIME__ + << std::endl; - if (argc < 3) - { + if (argc < 3) { std::cerr << "Usage: " << std::endl; - std::cerr << argv[0] << " ... " << std::endl; - std::cerr << argv[0] << " ... ... " << std::endl; + std::cerr << argv[0] + << " ... " + << std::endl; + std::cerr << argv[0] + << " ... " + " ... " + << std::endl; return -1; } @@ -111,16 +105,25 @@ int main(int argc, char *argv[]) interp.load_model(ifs).expect("Invalid kmodel"); // 2. set inputs - for (size_t i = 2, j = 0; i < 2 + interp.inputs_size(); i++, j++) - { + for (size_t i = 2, j = 0; i < 2 + interp.inputs_size(); i++, j++) { auto desc = interp.input_desc(j); auto shape = interp.input_shape(j); - auto tensor = hrt::create(desc.datatype, shape, hrt::pool_shared).expect("cannot create input tensor"); + auto tensor = hrt::create(desc.datatype, shape, hrt::pool_shared) + .expect("cannot create input tensor"); interp.input_tensor(j, tensor).expect("cannot set input tensor"); - auto span = tensor.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_write).unwrap().buffer(); + auto span = tensor.impl() + ->to_host() + .unwrap() + ->buffer() + .as_host() + .unwrap() + .map(map_access_::map_write) + .unwrap() + .buffer(); read_binary_file(argv[i], reinterpret_cast(span.data())); - hrt::sync(tensor, sync_op_t::sync_write_back, true).expect("sync write_back failed"); + hrt::sync(tensor, sync_op_t::sync_write_back, true) + .expect("sync write_back failed"); } // 3. set outputs @@ -128,7 +131,8 @@ int main(int argc, char *argv[]) // { // auto desc = interp.output_desc(i); // auto shape = interp.output_shape(i); - // auto tensor = hrt::create(desc.datatype, shape, hrt::pool_shared).expect("cannot create input tensor"); + // auto tensor = hrt::create(desc.datatype, shape, + // hrt::pool_shared).expect("cannot create input tensor"); // interp.output_tensor(i, tensor).expect("cannot set output tensor"); // } @@ -137,55 +141,61 @@ int main(int argc, char *argv[]) auto start = std::chrono::steady_clock::now(); interp.run().expect("error occurred in running model"); auto stop = std::chrono::steady_clock::now(); - double duration = std::chrono::duration(stop - start).count(); + double duration = + std::chrono::duration(stop - start).count(); // 5. get outputs double cos = 0.f; - for (int i = 2 + interp.inputs_size(), j = 0; i < argc; i++, j++) - { + for (int i = 2 + interp.inputs_size(), j = 0; i < argc; i++, j++) { auto desc = interp.output_desc(j); auto out = interp.output_tensor(j).expect("cannot get output tensor"); - auto mapped_buf = std::move(hrt::map(out, map_access_t::map_read).unwrap()); + auto mapped_buf = + std::move(hrt::map(out, map_access_t::map_read).unwrap()); auto vec = read_binary_file(argv[i]); - switch (desc.datatype) - { + switch (desc.datatype) { case dt_boolean: - case dt_uint8: - { - cos = cosine((const uint8_t *)mapped_buf.buffer().data(), (const uint8_t *)vec.data(), vec.size() / sizeof(uint8_t)); + case dt_uint8: { + cos = cosine((const uint8_t *)mapped_buf.buffer().data(), + (const uint8_t *)vec.data(), + vec.size() / sizeof(uint8_t)); break; } - case dt_int8: - { - cos = cosine((const int8_t *)mapped_buf.buffer().data(), (const int8_t *)vec.data(), vec.size() / sizeof(int8_t)); + case dt_int8: { + cos = + cosine((const int8_t *)mapped_buf.buffer().data(), + (const int8_t *)vec.data(), vec.size() / sizeof(int8_t)); break; } - case dt_float32: - { - cos = cosine((const float *)mapped_buf.buffer().data(), (const float *)vec.data(), vec.size() / sizeof(float)); + case dt_float32: { + cos = cosine((const float *)mapped_buf.buffer().data(), + (const float *)vec.data(), vec.size() / sizeof(float)); break; } - case dt_int32: - { - cos = cosine((const int32_t *)mapped_buf.buffer().data(), (const int32_t *)vec.data(), vec.size() / sizeof(int32_t)); + case dt_int32: { + cos = cosine((const int32_t *)mapped_buf.buffer().data(), + (const int32_t *)vec.data(), + vec.size() / sizeof(int32_t)); break; } - case dt_int64: - { - cos = cosine((const int64_t *)mapped_buf.buffer().data(), (const int64_t *)vec.data(), vec.size() / sizeof(int64_t)); + case dt_int64: { + cos = cosine((const int64_t *)mapped_buf.buffer().data(), + (const int64_t *)vec.data(), + vec.size() / sizeof(int64_t)); break; } - default: - { - std::cerr << "not supported data type: " << desc.datatype << std::endl; + default: { + std::cerr << "not supported data type: " << desc.datatype + << std::endl; std::abort(); } } - std::cout << "output " << j << " cosine similarity: " << cos << std::endl; + std::cout << "output " << j << " cosine similarity: " << cos + << std::endl; } - std::cout << "interp run: " << duration << " ms, fps = " << 1000 / duration << std::endl; + std::cout << "interp run: " << duration << " ms, fps = " << 1000 / duration + << std::endl; return 0; } \ No newline at end of file From 6f1c48bab3a86b51cd1b1b9ab9124eeb4cbb6e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 2 Aug 2023 16:24:18 +0800 Subject: [PATCH 063/308] add test --- .../BufferScheduleExtensions.cs | 17 ++ .../BufferSchedule/BufferScheduleTypes.cs | 19 ++- .../BufferSchedule/BufferScheduler.cs | 145 +++++++++++++++--- .../BufferSchedule/LifeTimeCollector.cs | 65 +++++--- src/Nncase.Passes/DDrBufferSchdeulePass.cs | 15 +- src/Nncase.Tests/EGraph/UnitTestVrp.cs | 29 ++++ 6 files changed, 234 insertions(+), 56 deletions(-) create mode 100644 src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs new file mode 100644 index 0000000000..91cf5e8dd4 --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs @@ -0,0 +1,17 @@ +using System.Collections.Generic; +using System.Linq; +using Nncase.IR; + +namespace Nncase.Passes.BufferSchedule; + +public static class BufferScheduleExtensions +{ + + public static IEnumerable GetUsers(this Call call) + { + var hs = new HashSet(ReferenceEqualityComparer.Instance); + hs.UnionWith(call.Users.Where(e => e is not BaseFunction).ToArray().Select(e => e switch { IR.Tuple tp => tp.Fields.ToArray(), _ => new[] { e } }).SelectMany(i => i)); + return hs; + } + +} \ No newline at end of file diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index d68c5ca5e8..449102871f 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -7,17 +7,19 @@ internal sealed class TimeInterval { public TimeInterval(int start, int end) { - Start = start; - End = end; + Brith = start; + Death = end; } - public int Start { get; set; } + public int Brith { get; set; } - public int End { get; set; } + public int Death { get; set; } + + public int Size => Death - Brith; public override string ToString() { - return $"TimeInterval({Start}, {End})"; + return $"TimeInterval({Brith}, {Death})"; } } @@ -41,13 +43,14 @@ public override string ToString() internal class ScheduleBuffer { - public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] shape, int[] strides) + public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace) { Name = name; Interval = interval; Span = span; Shape = shape; Strides = strides; + Inplace = inplace; } public string Name { get; } @@ -60,8 +63,10 @@ public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] sh public int[] Strides { get; } + public bool Inplace { get; } + public override string ToString() { - return $"ScheduledBuffer('{Name}', {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}])"; + return $"ScheduledBuffer('{Name}', {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})"; } } diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index 496d45af2c..66ebe81736 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -1,10 +1,13 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Reactive; +using Google.OrTools.Sat; +using NetFabric.Hyperlinq; using Nncase; using Nncase.IR; @@ -12,36 +15,119 @@ namespace Nncase.Passes.BufferSchedule; internal sealed class BufferScheduler { - public List CollectLifeTime(Function func) + public IReadOnlyDictionary CollectLifeTime(Function func) { var c = new LifeTimeCollector(); return c.Collect(func); } - public void DumpScheduled(string path, List buffers) + public void Schedule(IReadOnlyDictionary bufferMap) { - using (var fs = File.OpenWrite(path)) + var model = new CpModel(); + var noOverlap = model.AddNoOverlap2D(); + var boxs = new Dictionary(ReferenceEqualityComparer.Instance); + var timeMap = new Dictionary>(); + var yStarts = new List(); + foreach (var (expr, item) in bufferMap) { - using (var wr = new StreamWriter(fs)) + var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + "_x"); + + var upbound = 2147483648 - item.Span.End; + if (upbound <= 0) + { + throw new System.NotSupportedException(); + } + + var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_y_start"); + var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_y"); + + if (!item.Inplace) + { + noOverlap.AddRectangle(xInterval, yInterval); + } + + yStarts.Add(memStartVar); + boxs.Add(expr, (xInterval, yInterval)); + + for (int time = item.Interval.Brith; time < item.Interval.Death; time++) + { + if (!timeMap.TryGetValue(time, out var timelist)) + { + timelist = new(); + timeMap.Add(time, timelist); + } + + timelist.Add(expr); + } + } + + foreach (var (expr, item) in bufferMap) + { + if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple) + { + // the concat inputs must contiguous + model.AddMinEquality(boxs[concatCall].Y.StartExpr(), tuple.Fields.ToArray().Select(arg => boxs[arg].Y.StartExpr())); + model.AddMaxEquality(boxs[concatCall].Y.EndExpr(), tuple.Fields.ToArray().Select(arg => boxs[arg].Y.EndExpr())); + } + else if (expr is Call { Target: IR.Tensors.Split } splitCall) + { + // the split must equal with input. + model.Add(boxs[splitCall].Y.StartExpr() == boxs[splitCall.Arguments[0]].Y.StartExpr()); + + // the split outputs must contiguous + var users = splitCall.GetUsers(); + model.AddMinEquality(boxs[splitCall].Y.StartExpr(), users.Select(e => boxs[e].Y.StartExpr())); + model.AddMaxEquality(boxs[splitCall].Y.EndExpr(), users.Select(e => boxs[e].Y.EndExpr())); + } + else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall) { - wr.Write(@"from bokeh.models import ColumnDataSource, HoverTool, FuncTickFormatter, SingleIntervalTicker, SaveTool, WheelZoomTool, WheelPanTool, ResetTool + // the reshape must equal with it's input. + model.Add(boxs[reshapCall].Y.StartExpr() == boxs[reshapCall.Arguments[0]].Y.StartExpr()); + } + } + + model.Minimize(LinearExpr.Sum(yStarts)); + + var solver = new CpSolver(); + solver.StringParameters = $"max_time_in_seconds:{60},num_workers:{8}"; + CpSolverStatus solve_status = solver.Solve(model); + if (solve_status != CpSolverStatus.Optimal && solve_status != CpSolverStatus.Feasible) + { + throw new System.NotSupportedException(); + } + + foreach (var (k, v) in bufferMap) + { + bufferMap[k].Span.Start = checked((int)solver.Value(boxs[k].Y.StartExpr())); + bufferMap[k].Span.End = checked((int)solver.Value(boxs[k].Y.EndExpr())); + } + } + + public void Dump(Stream fs, IReadOnlyDictionary buffers) + { + using (var wr = new StreamWriter(fs)) + { + wr.Write(@"from bokeh.models import ColumnDataSource, HoverTool, FuncTickFormatter, SingleIntervalTicker, SaveTool, WheelZoomTool, WheelPanTool, ResetTool from bokeh.palettes import Category20_20 as palette from bokeh.plotting import figure, show, save import itertools from dataclasses import dataclass from enum import Enum from typing import List + @dataclass -class Interval(): +class TimeInterval(): start: int end: int + def __str__(self) -> str: + return f'(start: {self.start}, end {self.end})' @dataclass -class Location(): +class MemSpan(): depth_start: int - depth_size: int + depth_end: int def __str__(self) -> str: - return f'(start: {self.depth_start}, size {self.depth_size})' + return f'(start: {self.depth_start}, size {self.depth_end - self.depth_start})' class ConstraintsMode(Enum): No = 0 @@ -50,22 +136,23 @@ class ConstraintsMode(Enum): @dataclass class ScheduledBuffer(): name: str - interval: Interval - location: Location + interval: TimeInterval + location: MemSpan constraints: ConstraintsMode shape: List[int] stride: List[int] + inplace: bool colors = itertools.cycle(palette) buffers = [ "); - foreach (var item in buffers) - { - wr.WriteLine(item.ToString()); - } + foreach (var (k, v) in buffers) + { + wr.WriteLine(v.ToString() + ","); + } - wr.Write(@"] + wr.Write(@"] source = { 'name': [], @@ -73,42 +160,52 @@ class ScheduledBuffer(): 'y': [], 'width': [], 'height': [], + 'alpha': [], 'color': [], 'location': [], + 'interval': [], 'shape': [], 'stride': [], } y_range_max = 0 +x_range_max = 0 +color_dict = {} for buffer in buffers: source['name'].append(buffer.name) width = buffer.interval.end - buffer.interval.start - x = buffer.interval.start + (width // 2) - height = buffer.location.depth_size - y = buffer.location.depth_start + (height // 2) + x = buffer.interval.start + (width / 2) + height = buffer.location.depth_end - buffer.location.depth_start + y = buffer.location.depth_start + (height / 2) y_range_max = max(y_range_max, y) + x_range_max = max(x_range_max, buffer.interval.end) source['x'].append(x) source['y'].append(y) source['width'].append(width) source['height'].append(height) - source['color'].append(next(colors)) + color = color_dict.get(buffer.name) + if color == None: + color = next(colors) + color_dict[buffer.name] = color + source['color'].append(color) + source['alpha'].append(0.2 if buffer.inplace else 1.0) + source['interval'].append(str(buffer.interval)) source['location'].append(str(buffer.location)) source['shape'].append(','.join([str(s) for s in buffer.shape])) source['stride'].append(','.join([str(s) for s in buffer.stride])) source = ColumnDataSource(source) -hover = HoverTool(tooltips=[('name', '@name'), ('location', '@location'), +hover = HoverTool(tooltips=[('name', '@name'), ('interval', '@interval'), ('location', '@location'), ('shape', '@shape'), ('stride', '@stride')]) p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720, - y_range=(0, y_range_max * 2), + y_range=(0, y_range_max * 1.2), x_range=(-1, x_range_max + 1), title='Local Buffer LifeTime (by Steps)') -p.rect(x='x', y='y', width='width', height='height', fill_color='color', source=source) +p.rect(x='x', y='y', width='width', height='height', fill_color='color', legend_field='name', fill_alpha='alpha', source=source) p.xaxis.axis_label = 'Time (steps)' p.outline_line_color = None show(p)"); - } } } } diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index 454fd25608..760083f1da 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System; using System.Collections.Generic; using System.Linq; using System.Reactive; @@ -15,12 +16,11 @@ internal sealed class LifeTimeCollector : ExprVisitor public Dictionary LifenessMap { get; } = new(ReferenceEqualityComparer.Instance); - public List Collect(Function entry) + public IReadOnlyDictionary Collect(Function entry) { Visit(entry.Body); - Alias(); - var l = new List(); + var d = new Dictionary(ReferenceEqualityComparer.Instance); foreach (var (k, v) in LifenessMap) { var name = k switch @@ -29,15 +29,12 @@ public List Collect(Function entry) Var va => va.Name, _ => k.GetType().Name, }; + var size = GetSize(k.CheckedType, out var shape, out var stride); - var shape = k.CheckedShape.ToValueArray(); - var stride = TensorUtilities.GetStrides(shape); - var size = TensorUtilities.GetSize(shape, stride, k.CheckedDataType.SizeInBytes); - - l.Add(new(name, v, new(0, size), shape, stride)); + d.Add(k, new(name, v, new(0, size), shape, stride, IsInPlace(k))); } - return l; + return d; } protected override Unit DefaultVisitLeaf(Expr expr) => Unit.Default; @@ -58,7 +55,7 @@ protected override Unit VisitLeafCall(Call expr) private void Update(Expr expr) { - if (expr is Const) + if (expr is (Const or None)) { return; } @@ -79,25 +76,53 @@ private void Update(Expr expr) } else { - interval.End += 1; + interval.Death = TimeStamp + 1; } - // advance the getitem buffer. - if (expr is Call { Target: IR.Tensors.GetItem, Arguments: var args } call && args[0] is Call { CheckedType: TupleType }) + LifenessMap[expr] = interval; + } + + private bool IsInPlace(Expr expr) + { + if (expr is Call { Target: IR.Tensors.Reshape } callReshape) { - interval.Start = LifenessMap[args[0]].Start; + return true; } - LifenessMap[expr] = interval; + if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple concatTuple) + { + return true; + } + + if (expr is Call { Target: IR.Tensors.Split } splitCall) + { + return true; + } + + return false; } - private void Alias() + private int GetSize(IRType type, out int[] shape, out int[] stride) { - // skip the call which output type is tuple. - var calls = LifenessMap.Select(kv => kv.Key is Call { CheckedType: TupleType }).ToArray(); - foreach (var c in calls) + shape = Array.Empty(); + stride = Array.Empty(); + var size = 0; + if (type is TensorType tensorType) { - LifenessMap.Remove(c); + shape = tensorType.Shape.ToValueArray(); + stride = TensorUtilities.GetStrides(shape); + size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes); } + else if (type is TupleType tupleType) + { + size = 0; + foreach (var item in tupleType) + { + size += GetSize(item, out _, out _); + } + } + + return size; } + } diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index cdd9fe3be1..015b9d3dd3 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -43,11 +43,16 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon // 1. merge the all call prim func if (_enbaleMergeCall) { - // if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType)) - // { - // var sorter = new TopSorter(); - // sorter.GetTimeLine(func); - // } + if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType)) + { + var sch = new BufferSchedule.BufferScheduler(); + var buffers = sch.CollectLifeTime(func); + sch.Schedule(buffers); + using (var fs = Diagnostics.DumpScope.Current.OpenFile("draw_buffers.py")) + { + sch.Dump(fs, buffers); + } + } } // 4. schedule the prim funcs. diff --git a/src/Nncase.Tests/EGraph/UnitTestVrp.cs b/src/Nncase.Tests/EGraph/UnitTestVrp.cs index 5a9f4b564e..f421022a7c 100644 --- a/src/Nncase.Tests/EGraph/UnitTestVrp.cs +++ b/src/Nncase.Tests/EGraph/UnitTestVrp.cs @@ -174,6 +174,35 @@ public void TestSimpleEgraphSat() } } + [Fact] + public void TestOverLap() + { + // note ortools no overlap not support 0 size. + var model = new CpModel(); + + var x0 = model.NewIntervalVar(model.NewConstant(0), model.NewConstant(2), model.NewConstant(2), "x0"); + var y0 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 10, "y0_start"), 7, "y0"); + + var x1 = model.NewIntervalVar(model.NewConstant(2), model.NewConstant(0), model.NewConstant(2), "x1"); + var y1 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 10, "y1_start"), 7, "y1"); + + var x2 = model.NewIntervalVar(model.NewConstant(2), model.NewConstant(1), model.NewConstant(3), "x2"); + var y2 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 10, "y2_start"), 7, "y2"); + + model.Add(y0.StartExpr() == y1.StartExpr()); + model.Add(y1.StartExpr() == y2.StartExpr()); + var nooverlap = model.AddNoOverlap2D(); + nooverlap.AddRectangle(x0, y0); + nooverlap.AddRectangle(x1, y1); + nooverlap.AddRectangle(x2, y2); + model.Minimize(y0.StartExpr() + y1.StartExpr() + y2.StartExpr()); + + var solver = new CpSolver(); + var status = solver.Solve(model); + + Assert.Equal(CpSolverStatus.Infeasible, status); + } + private static void PrintSolution(in IDataModel data, in RoutingModel routing, in RoutingIndexManager manager, in Assignment solution) { Console.WriteLine($"Objective {solution.ObjectiveValue()}:"); From 858e2af2fb4ba481928f3ec163fc38f692ccbb38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 2 Aug 2023 18:16:58 +0800 Subject: [PATCH 064/308] fix bug --- .../BufferScheduleExtensions.cs | 7 +++ .../BufferSchedule/BufferScheduleTypes.cs | 7 ++- .../BufferSchedule/BufferScheduler.cs | 14 ++--- .../BufferSchedule/LifeTimeCollector.cs | 62 +++++++++++++++---- 4 files changed, 67 insertions(+), 23 deletions(-) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs index 91cf5e8dd4..f10e080fc0 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs @@ -7,6 +7,13 @@ namespace Nncase.Passes.BufferSchedule; public static class BufferScheduleExtensions { + public static IEnumerable GetArguments(this Call call) + { + var hs = new HashSet(ReferenceEqualityComparer.Instance); + hs.UnionWith(call.Arguments.ToArray().Where(e => e is not (BaseFunction or Const)).ToArray().Select(e => e switch { IR.Tuple tp => tp.Fields.ToArray(), _ => new[] { e } }).SelectMany(i => i)); + return hs; + } + public static IEnumerable GetUsers(this Call call) { var hs = new HashSet(ReferenceEqualityComparer.Instance); diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index 449102871f..a7920aca48 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -43,9 +43,10 @@ public override string ToString() internal class ScheduleBuffer { - public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace) + public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace) { Name = name; + Number = number; Interval = interval; Span = span; Shape = shape; @@ -54,7 +55,7 @@ public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] sh } public string Name { get; } - + public int Number { get; } public TimeInterval Interval { get; } public MemSpan Span { get; } @@ -67,6 +68,6 @@ public ScheduleBuffer(string name, TimeInterval interval, MemSpan span, int[] sh public override string ToString() { - return $"ScheduledBuffer('{Name}', {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})"; + return $"ScheduledBuffer('{Name}', {Number}, {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})"; } } diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index 66ebe81736..88225cb76e 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -30,7 +30,7 @@ public void Schedule(IReadOnlyDictionary bufferMap) var yStarts = new List(); foreach (var (expr, item) in bufferMap) { - var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + "_x"); + var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + $"{item.Number}_x"); var upbound = 2147483648 - item.Span.End; if (upbound <= 0) @@ -38,14 +38,9 @@ public void Schedule(IReadOnlyDictionary bufferMap) throw new System.NotSupportedException(); } - var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_y_start"); - var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_y"); - - if (!item.Inplace) - { - noOverlap.AddRectangle(xInterval, yInterval); - } - + var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start"); + var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_{item.Number}_y"); + noOverlap.AddRectangle(xInterval, yInterval); yStarts.Add(memStartVar); boxs.Add(expr, (xInterval, yInterval)); @@ -136,6 +131,7 @@ class ConstraintsMode(Enum): @dataclass class ScheduledBuffer(): name: str + number: int interval: TimeInterval location: MemSpan constraints: ConstraintsMode diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index 760083f1da..7a7f04561b 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -19,8 +19,10 @@ internal sealed class LifeTimeCollector : ExprVisitor public IReadOnlyDictionary Collect(Function entry) { Visit(entry.Body); + Alias(); var d = new Dictionary(ReferenceEqualityComparer.Instance); + int count = 0; foreach (var (k, v) in LifenessMap) { var name = k switch @@ -31,7 +33,7 @@ public IReadOnlyDictionary Collect(Function entry) }; var size = GetSize(k.CheckedType, out var shape, out var stride); - d.Add(k, new(name, v, new(0, size), shape, stride, IsInPlace(k))); + d.Add(k, new(name, count++, v, new(0, size), shape, stride, false)); } return d; @@ -48,7 +50,12 @@ protected override Unit VisitLeafCall(Call expr) Update(expr); - TimeStamp += 1; + TimeStamp += 2; + + foreach (var item in expr.GetUsers()) + { + Update(item); + } return Unit.Default; } @@ -82,24 +89,57 @@ private void Update(Expr expr) LifenessMap[expr] = interval; } - private bool IsInPlace(Expr expr) + private void Alias() { - if (expr is Call { Target: IR.Tensors.Reshape } callReshape) + bool changed = true; + do { - return true; - } + changed = false; + foreach (var (expr, interval) in LifenessMap) + { + if (expr is Call { Target: IR.Tensors.Reshape } callReshape) + { + changed = AliasTime(callReshape, interval); + } + } + + foreach (var (expr, interval) in LifenessMap) + { + + if (expr is Call { Target: IR.Tensors.Concat } concatCall) + { + changed = AliasTime(concatCall, interval); + } + } + + foreach (var (expr, interval) in LifenessMap) + { + if (expr is Call { Target: IR.Tensors.Split } splitCall) + { + changed = AliasTime(splitCall, interval); + } + } + } while (!changed); + } + + private bool AliasTime(Call call, TimeInterval interval) + { + var brith = call.GetArguments().Select(arg => LifenessMap[arg].Death).Concat(new[] { interval.Brith }).Max(); + var death = call.GetUsers().Select(usr => LifenessMap[usr].Brith).Concat(new[] { interval.Death }).Min(); - if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple concatTuple) + if (brith == interval.Brith && death == interval.Death) { - return true; + return false; } - if (expr is Call { Target: IR.Tensors.Split } splitCall) + if (brith >= death) { - return true; + throw new InvalidOperationException(); } - return false; + interval.Brith = brith; + interval.Death = death; + return true; } private int GetSize(IRType type, out int[] shape, out int[] stride) From 6553c0cae2fa75539a3a94715c6c5041010f3bb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 2 Aug 2023 19:56:36 +0800 Subject: [PATCH 065/308] fix bug --- src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs | 3 +-- src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs index f10e080fc0..706fbac412 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs @@ -17,8 +17,7 @@ public static IEnumerable GetArguments(this Call call) public static IEnumerable GetUsers(this Call call) { var hs = new HashSet(ReferenceEqualityComparer.Instance); - hs.UnionWith(call.Users.Where(e => e is not BaseFunction).ToArray().Select(e => e switch { IR.Tuple tp => tp.Fields.ToArray(), _ => new[] { e } }).SelectMany(i => i)); + hs.UnionWith(call.Users.Where(e => e is not BaseFunction).ToArray().Select(e => e switch { IR.Tuple tp => tp.Users.ToArray(), _ => new[] { e } }).SelectMany(i => i)); return hs; } - } \ No newline at end of file diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index 7a7f04561b..bfa43f9707 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -19,6 +19,7 @@ internal sealed class LifeTimeCollector : ExprVisitor public IReadOnlyDictionary Collect(Function entry) { Visit(entry.Body); + Update(entry.Body); // avoid final call time interval size == 1. Alias(); var d = new Dictionary(ReferenceEqualityComparer.Instance); @@ -52,7 +53,8 @@ protected override Unit VisitLeafCall(Call expr) TimeStamp += 2; - foreach (var item in expr.GetUsers()) + // note we will update tuple field on the next call. + foreach (var item in expr.Users.Where(e => e is not (BaseFunction or IR.Tuple))) { Update(item); } @@ -105,7 +107,6 @@ private void Alias() foreach (var (expr, interval) in LifenessMap) { - if (expr is Call { Target: IR.Tensors.Concat } concatCall) { changed = AliasTime(concatCall, interval); From 191bd67b171e9e6783854403d20f47484896b4ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Thu, 3 Aug 2023 10:56:01 +0800 Subject: [PATCH 066/308] fix concat/split order --- .../BufferSchedule/BufferScheduleTypes.cs | 2 ++ .../BufferSchedule/BufferScheduler.cs | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index a7920aca48..f3b548cf1b 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -35,6 +35,8 @@ public MemSpan(int start, int end) public int End { get; set; } + public int Size => End - Start; + public override string ToString() { return $"MemSpan({Start}, {End})"; diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index 88225cb76e..cdad8d521d 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -61,8 +61,12 @@ public void Schedule(IReadOnlyDictionary bufferMap) if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple) { // the concat inputs must contiguous - model.AddMinEquality(boxs[concatCall].Y.StartExpr(), tuple.Fields.ToArray().Select(arg => boxs[arg].Y.StartExpr())); - model.AddMaxEquality(boxs[concatCall].Y.EndExpr(), tuple.Fields.ToArray().Select(arg => boxs[arg].Y.EndExpr())); + int offset = 0; + for (int i = 0; i < tuple.Fields.Length; i++) + { + model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr()); + offset += bufferMap[tuple.Fields[i]].Span.Size; + } } else if (expr is Call { Target: IR.Tensors.Split } splitCall) { @@ -71,8 +75,12 @@ public void Schedule(IReadOnlyDictionary bufferMap) // the split outputs must contiguous var users = splitCall.GetUsers(); - model.AddMinEquality(boxs[splitCall].Y.StartExpr(), users.Select(e => boxs[e].Y.StartExpr())); - model.AddMaxEquality(boxs[splitCall].Y.EndExpr(), users.Select(e => boxs[e].Y.EndExpr())); + int offset = 0; + foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar())) + { + model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr()); + offset += bufferMap[user].Span.Size; + } } else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall) { From 98d57a3a6e04b7412a80a218451389f7826e042b Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Thu, 3 Aug 2023 03:29:27 +0000 Subject: [PATCH 067/308] Apply code-format changes --- .../BufferSchedule/BufferScheduleExtensions.cs | 6 ++++-- src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs | 2 ++ src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs | 5 ++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs index 706fbac412..4a07e97a8b 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs @@ -1,3 +1,6 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + using System.Collections.Generic; using System.Linq; using Nncase.IR; @@ -6,7 +9,6 @@ namespace Nncase.Passes.BufferSchedule; public static class BufferScheduleExtensions { - public static IEnumerable GetArguments(this Call call) { var hs = new HashSet(ReferenceEqualityComparer.Instance); @@ -20,4 +22,4 @@ public static IEnumerable GetUsers(this Call call) hs.UnionWith(call.Users.Where(e => e is not BaseFunction).ToArray().Select(e => e switch { IR.Tuple tp => tp.Users.ToArray(), _ => new[] { e } }).SelectMany(i => i)); return hs; } -} \ No newline at end of file +} diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index f3b548cf1b..13e8ab86f0 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -57,7 +57,9 @@ public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan sp } public string Name { get; } + public int Number { get; } + public TimeInterval Interval { get; } public MemSpan Span { get; } diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index bfa43f9707..284a56b659 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -64,7 +64,7 @@ protected override Unit VisitLeafCall(Call expr) private void Update(Expr expr) { - if (expr is (Const or None)) + if (expr is Const or None) { return; } @@ -93,7 +93,7 @@ private void Update(Expr expr) private void Alias() { - bool changed = true; + bool changed; do { changed = false; @@ -165,5 +165,4 @@ private int GetSize(IRType type, out int[] shape, out int[] stride) return size; } - } From e6bb3fddbd97fd6765eebfc8016808e85f14c924 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Thu, 3 Aug 2023 12:01:01 +0800 Subject: [PATCH 068/308] format --- src/Nncase.Passes/BufferSchedule/BufferScheduler.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index cdad8d521d..25f7d6f5b8 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -151,7 +151,7 @@ class ScheduledBuffer(): buffers = [ "); - foreach (var (k, v) in buffers) + foreach (var (_, v) in buffers) { wr.WriteLine(v.ToString() + ","); } From 945799cc49d539bcc375271d95fc96ea43b7ec62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 8 Aug 2023 19:55:42 +0800 Subject: [PATCH 069/308] add unary parallel --- .../CodeGen/CSourceBuiltn.cs | 3 + .../CodeGen/CSourceConvertVisitor.cs | 88 +++++++++++++++++-- .../Passes/Tile/SingleCPUFusionConverter.cs | 4 +- 3 files changed, 87 insertions(+), 8 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index b2e0a6fcb3..a945b05958 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -72,6 +72,9 @@ public static class CSourceBuiltn bool (*bool_binary_and)(bool, bool); bool (*bool_binary_or)(bool, bool); bool (*bool_binary_xor)(bool, bool); + // paralell + void *thread_start(void *(*callable)(void *args), void *user); + void thread_end(); } nncase_mt_t;"; public const string Include = @"#include diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index bc715c4f2f..2803ebba94 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -27,8 +27,8 @@ internal struct IndentScope : IDisposable public IndentScope(StringBuilder sb) { _initialized = true; + _originalWriter = _writer.Value; _writer.Value = new IndentWriter(sb); - _originalWriter = null; } public IndentScope() @@ -70,6 +70,12 @@ public CSymbol(string type, string name) public string Name { get; } public override string ToString() => $"{Type} {Name}"; + + public static IReadOnlyList Builtns => new CSymbol[] { + new CSymbol("nncase_mt_t*", "nncase_mt"), + new CSymbol("uint8_t*", "data"), + new CSymbol("const uint8_t*", "rdata"), + }; } internal sealed class IndentWriter : StringWriter @@ -100,16 +106,22 @@ internal sealed class CSourceConvertVisitor : ExprFunctor { private readonly Dictionary _exprMemo; private readonly StringBuilder _implBuilder; + private readonly StringBuilder _declBuilder; + private readonly StringWriter _declWriter; public CSourceConvertVisitor() { _implBuilder = new StringBuilder(); + _declBuilder = new StringBuilder(); + _declWriter = new StringWriter(_declBuilder); _exprMemo = new(ReferenceEqualityComparer.Instance); } + public PrimFunction VisitEntry => (TIR.PrimFunction)VisitRoot!; + public FunctionCSource GetFunctionCSource() { - return new(_exprMemo[VisitRoot!].Type + ";", _implBuilder.ToString()); + return new(_declBuilder.ToString(), _implBuilder.ToString()); } /// @@ -127,6 +139,9 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) var type = $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})"; + _declWriter.WriteLine(type + ";"); + _declWriter.WriteLine(); + using (var scope = new IndentScope(_implBuilder)) { // 1. Function signature @@ -142,7 +157,8 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) IndentScope.Writer.IndWrite("}\n"); } - symbol = new(type, new(expr.Name)); + var ctype = $"void (*{expr.Name})({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(b => Visit(b.MemSpan.Start).ToString()).ToArray())}, {CSourceBuiltn.FixedParameters})"; + symbol = new(ctype, expr.Name); _exprMemo.Add(expr, symbol); return symbol; } @@ -294,10 +310,51 @@ protected override CSymbol VisitFor(For expr) // 1. For Loop signature var loopVar = Visit(expr.LoopVar); IndentScope.Writer.IndWrite($"for ({loopVar.Type} {loopVar.Name} = {Visit(expr.Domain.Start).Name}; {loopVar.Name} < {Visit(expr.Domain.Stop).Name}; {loopVar.Name}+={Visit(expr.Domain.Step).Name}) {{\n"); - using (_ = new IndentScope()) + + if (expr.Mode == LoopMode.Parallel) + { + // find the vars will be used and make new struct type. + var msg_fields = _exprMemo.Where(p => p.Key is MemSpan or TIR.Buffer or Var).Select(p => p.Value).Concat(CSymbol.Builtns); + var msg_type = DeclThreadMessageStruct(msg_fields); + + using (new IndentScope(_declBuilder)) + { + IndentScope.Writer.IndWrite($"void *{VisitEntry.Name}_inner(void *args) {{\n"); + using (new IndentScope()) + { + IndentScope.Writer.IndWrite($"{msg_type}* _message = ({msg_type}*)args;\n"); + foreach (var sym in msg_fields) + { + IndentScope.Writer.IndWrite($"{sym.Type} {sym.Name} = _message->{sym.Name};\n"); + } + + Visit(expr.Body); + } + + IndentScope.Writer.IndWrite(" return 0;\n"); + IndentScope.Writer.IndWrite("}\n"); + } + + using (new IndentScope()) + { + IndentScope.Writer.IndWrite($"{msg_type} _message = {{\n"); + foreach (var sym in msg_fields) + { + IndentScope.Writer.IndWrite($".{sym.Name} = {sym.Name},\n"); + } + + IndentScope.Writer.IndWrite("};\n"); + + IndentScope.Writer.IndWrite($"nncase_mt->thread_start({VisitEntry.Name}_inner, (void *)_message);\n"); + } + } + else { - // 2. For Body - Visit(expr.Body); + using (_ = new IndentScope()) + { + // 2. For Body + Visit(expr.Body); + } } // 3. For closing @@ -360,4 +417,23 @@ protected override CSymbol VisitIfThenElse(IfThenElse expr) _exprMemo.Add(expr, symbol); return symbol; } + + private string DeclThreadMessageStruct(IEnumerable keyValues) + { + var type = $"{VisitEntry.Name}_thread_message_t"; + _declWriter.WriteLine("typedef struct {"); + foreach (var sym in keyValues) + { + if (sym.Name == string.Empty) + { + throw new InvalidOperationException("empty name"); + } + + _declWriter.WriteLine(" " + sym.Type + " " + sym.Name + ";"); + } + + _declWriter.WriteLine($"}} {type};"); + _declWriter.WriteLine(); + return type; + } } diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index bd45b072ca..2dc177c607 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -86,7 +86,7 @@ private void GenerateMatMul(ReadOnlySpan arguments, Buffer ret, Call exp var lhs = arguments[0]; var rhs = arguments[1]; - var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); var loopVars = loops.Select(f => f.loopVar).ToArray(); var stmt = T.Serial(out var m, (0, lhs.Dimensions[^2])).Body( T.Serial(out var n, (0, rhs.Dimensions[^1])).Body( @@ -108,7 +108,7 @@ private void GenerateMatMul(ReadOnlySpan arguments, Buffer ret, Call exp private void GenerateUnary(Unary unary, ReadOnlySpan arguments, Buffer ret) { var input = arguments[Unary.Input.Index]; - var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + var loops = Enumerable.Range(0, input.Rank).Select(i => (T.ForLoop(out var loopVar, (0, input.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); var loopVars = loops.Select(f => f.loopVar).ToArray(); Expr stmt = T.BufferStore(ret, loopVars, IR.F.Math.Unary(unary.UnaryOp, T.BufferLoad(input, loopVars))); var final = loops.Reverse().Aggregate(stmt, (acc, p) => p.Item1.Body(acc).Build()); From c4464cac9cd0deb39a4d3ed456d851cf5d9a5e0b Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Tue, 8 Aug 2023 11:58:41 +0000 Subject: [PATCH 070/308] Apply code-format changes --- .../CodeGen/CSourceConvertVisitor.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 2803ebba94..636cf1e2ea 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -65,17 +65,17 @@ public CSymbol(string type, string name) Name = name; } - public string Type { get; } - - public string Name { get; } - - public override string ToString() => $"{Type} {Name}"; - public static IReadOnlyList Builtns => new CSymbol[] { new CSymbol("nncase_mt_t*", "nncase_mt"), new CSymbol("uint8_t*", "data"), new CSymbol("const uint8_t*", "rdata"), }; + + public string Type { get; } + + public string Name { get; } + + public override string ToString() => $"{Type} {Name}"; } internal sealed class IndentWriter : StringWriter From 9404c5fba4938fe71ebd9326a574993c40b21bcb Mon Sep 17 00:00:00 2001 From: huochenghai Date: Wed, 9 Aug 2023 19:15:43 +0800 Subject: [PATCH 071/308] support multi-thread --- modules/cpu/src/runtime/CMakeLists.txt | 5 ++- modules/cpu/src/runtime/cpu_common.h | 5 +++ modules/cpu/src/runtime/elfloader.cpp | 4 ++ modules/cpu/src/runtime/thread_pool.cpp | 53 +++++++++++++++++++++++++ modules/cpu/src/runtime/thread_pool.h | 28 +++++++++++++ 5 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 modules/cpu/src/runtime/thread_pool.cpp create mode 100644 modules/cpu/src/runtime/thread_pool.h diff --git a/modules/cpu/src/runtime/CMakeLists.txt b/modules/cpu/src/runtime/CMakeLists.txt index ebba027d8d..37da417626 100644 --- a/modules/cpu/src/runtime/CMakeLists.txt +++ b/modules/cpu/src/runtime/CMakeLists.txt @@ -2,6 +2,7 @@ set(SRCS runtime_module.cpp runtime_function.cpp + thread_pool.cpp elfload.cpp elfloader.cpp elfreloc_aarch64.cpp @@ -12,13 +13,13 @@ set(SRCS runtime_module.cpp if (BUILDING_RUNTIME) if (ENABLE_CPU_RUNTIME) add_library(runtime_cpu OBJECT ${SRCS}) - target_link_libraries(runtime_cpu PUBLIC nncaseruntime) + target_link_libraries(runtime_cpu PUBLIC nncaseruntime pthread) set_target_properties(runtime_cpu PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS runtime_cpu EXPORT nncaseruntimeTargets) endif() else() add_library(simulator_cpu OBJECT ${SRCS}) - target_link_libraries(simulator_cpu PUBLIC nncasebase nncaseruntime) + target_link_libraries(simulator_cpu PUBLIC nncasebase nncaseruntime pthread) target_compile_definitions(simulator_cpu PUBLIC -DNNCASE_MODULES_CPU_DLL -DNNCASE_SIMULATOR) set_target_properties(simulator_cpu PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index aefa9f32d1..d28d6ed69e 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -1,4 +1,5 @@ #pragma once +#include "thread_pool.h" #include #include #include @@ -65,6 +66,10 @@ typedef struct nncase_method_table { bool (*bool_binary_and)(bool, bool); bool (*bool_binary_or)(bool, bool); bool (*bool_binary_xor)(bool, bool); + + // multi-thread + void *(*thread_start)(void *(*callable)(void *), void *user, size_t user_size); + void *(*thread_end)(); } nncase_mt_t; typedef struct buffer { diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp index 394ed6ed3d..d9845fcf72 100644 --- a/modules/cpu/src/runtime/elfloader.cpp +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -1,4 +1,5 @@ #include "elfloader.h" +#include "thread_pool.h" using namespace nncase; using namespace nncase::runtime; @@ -32,6 +33,9 @@ int elfloader::invoke_elf(size_t id, uint8_t **buffers, nncase_mt_t *nncase_mt, // printf("Binary entrypoint is %" PRIxPTR "; invoking %p\n", // (uintptr_t)ctx_.ehdr.e_entry, (void *)epaddr); + thread_pool::paddr_offset = (uintptr_t)buf_; + nncase_mt->thread_start = thread_pool::thread_start; + nncase_mt->thread_end = thread_pool::thread_end; ep(id, buffers, nncase_mt, data, rdata); free(ptr_); diff --git a/modules/cpu/src/runtime/thread_pool.cpp b/modules/cpu/src/runtime/thread_pool.cpp new file mode 100644 index 0000000000..757aaad32a --- /dev/null +++ b/modules/cpu/src/runtime/thread_pool.cpp @@ -0,0 +1,53 @@ +#include "thread_pool.h" + +using namespace nncase::runtime::cpu::thread_pool; + +static int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); +static int threads_count; +static std::vector threads; +static std::vector users; + +static void *thread_start(thread_func callable, void *user, size_t user_size) { + auto user_ = malloc(user_size); + std::memcpy(user_, user, user_size); + thread_func new_call = thread_func((char *)callable + paddr_offset); + if (threads_size == 0) { + new_call(user_); + } else { + auto idx = threads_count % threads_size; + if (threads_count >= threads_size) { + pthread_join(threads[idx], NULL); + free(users[idx]); + } + pthread_t pt; + auto ret = pthread_create(&pt, NULL, new_call, user_); + if (ret != 0) { + throw std::runtime_error("thread creation failed\n"); + } + + if (threads_count == 0) { + threads.resize(threads_size); + users.resize(threads_size); + } + threads[idx] = pt; + users[idx] = user_; + threads_count++; + } + return nullptr; +} + +static void *thread_end() { + if (threads_size) { + for (int i = 0; i < std::min(threads_size, threads_count); i++) { + // if (threads[i].joinable()) { + pthread_join(threads[i], NULL); + free(users[i]); + // } + } + threads_count = 0; + threads.clear(); + users.clear(); + } + return nullptr; +} + diff --git a/modules/cpu/src/runtime/thread_pool.h b/modules/cpu/src/runtime/thread_pool.h new file mode 100644 index 0000000000..6cdfd64817 --- /dev/null +++ b/modules/cpu/src/runtime/thread_pool.h @@ -0,0 +1,28 @@ +#ifndef THREAD_POOL_ +#define THREAD_POOL_ +#include +#include +#include +#include +#include +#include +#include + +BEGIN_NS_NNCASE_RT_MODULE(cpu) +namespace thread_pool { + +using thread_func = void *(*)(void *); + +// static int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); +// static int threads_count; +// static std::vector threads; +// static std::vector users; +uintptr_t paddr_offset; + +void *thread_start(thread_func callable, void *user, size_t user_size); +void *thread_end(); + +} // namespace thread_pool +END_NS_NNCASE_RT_MODULE + +#endif From 427d290e68c22f29760fc17cc28baee0a20858b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 9 Aug 2023 19:18:37 +0800 Subject: [PATCH 072/308] fix build --- modules/cpu/src/runtime/thread_pool.cpp | 13 +++++++------ modules/cpu/src/runtime/thread_pool.h | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/modules/cpu/src/runtime/thread_pool.cpp b/modules/cpu/src/runtime/thread_pool.cpp index 757aaad32a..43846aed35 100644 --- a/modules/cpu/src/runtime/thread_pool.cpp +++ b/modules/cpu/src/runtime/thread_pool.cpp @@ -2,12 +2,13 @@ using namespace nncase::runtime::cpu::thread_pool; -static int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); -static int threads_count; -static std::vector threads; -static std::vector users; +int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); +int threads_count; +std::vector threads; +std::vector users; +uintptr_t nncase::runtime::cpu::thread_pool::paddr_offset; -static void *thread_start(thread_func callable, void *user, size_t user_size) { +void *nncase::runtime::cpu::thread_pool::thread_start(thread_func callable, void *user, size_t user_size) { auto user_ = malloc(user_size); std::memcpy(user_, user, user_size); thread_func new_call = thread_func((char *)callable + paddr_offset); @@ -36,7 +37,7 @@ static void *thread_start(thread_func callable, void *user, size_t user_size) { return nullptr; } -static void *thread_end() { +void *nncase::runtime::cpu::thread_pool::thread_end() { if (threads_size) { for (int i = 0; i < std::min(threads_size, threads_count); i++) { // if (threads[i].joinable()) { diff --git a/modules/cpu/src/runtime/thread_pool.h b/modules/cpu/src/runtime/thread_pool.h index 6cdfd64817..eb5819f6f8 100644 --- a/modules/cpu/src/runtime/thread_pool.h +++ b/modules/cpu/src/runtime/thread_pool.h @@ -17,7 +17,7 @@ using thread_func = void *(*)(void *); // static int threads_count; // static std::vector threads; // static std::vector users; -uintptr_t paddr_offset; +extern uintptr_t paddr_offset; void *thread_start(thread_func callable, void *user, size_t user_size); void *thread_end(); From 18fc13b1608b838026441d2f2d5cec436547d4b9 Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Wed, 9 Aug 2023 11:21:25 +0000 Subject: [PATCH 073/308] Apply code-format changes --- modules/cpu/src/runtime/cpu_common.h | 3 ++- modules/cpu/src/runtime/thread_pool.cpp | 8 +++++--- modules/cpu/src/runtime/thread_pool.h | 7 +++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/modules/cpu/src/runtime/cpu_common.h b/modules/cpu/src/runtime/cpu_common.h index d28d6ed69e..60b87e5ccc 100644 --- a/modules/cpu/src/runtime/cpu_common.h +++ b/modules/cpu/src/runtime/cpu_common.h @@ -68,7 +68,8 @@ typedef struct nncase_method_table { bool (*bool_binary_xor)(bool, bool); // multi-thread - void *(*thread_start)(void *(*callable)(void *), void *user, size_t user_size); + void *(*thread_start)(void *(*callable)(void *), void *user, + size_t user_size); void *(*thread_end)(); } nncase_mt_t; diff --git a/modules/cpu/src/runtime/thread_pool.cpp b/modules/cpu/src/runtime/thread_pool.cpp index 43846aed35..69de596cbd 100644 --- a/modules/cpu/src/runtime/thread_pool.cpp +++ b/modules/cpu/src/runtime/thread_pool.cpp @@ -2,13 +2,16 @@ using namespace nncase::runtime::cpu::thread_pool; -int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); +int threads_size = + atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); int threads_count; std::vector threads; std::vector users; uintptr_t nncase::runtime::cpu::thread_pool::paddr_offset; -void *nncase::runtime::cpu::thread_pool::thread_start(thread_func callable, void *user, size_t user_size) { +void *nncase::runtime::cpu::thread_pool::thread_start(thread_func callable, + void *user, + size_t user_size) { auto user_ = malloc(user_size); std::memcpy(user_, user, user_size); thread_func new_call = thread_func((char *)callable + paddr_offset); @@ -51,4 +54,3 @@ void *nncase::runtime::cpu::thread_pool::thread_end() { } return nullptr; } - diff --git a/modules/cpu/src/runtime/thread_pool.h b/modules/cpu/src/runtime/thread_pool.h index eb5819f6f8..f5baed41f7 100644 --- a/modules/cpu/src/runtime/thread_pool.h +++ b/modules/cpu/src/runtime/thread_pool.h @@ -13,10 +13,9 @@ namespace thread_pool { using thread_func = void *(*)(void *); -// static int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? getenv("NNCASE_MAX_THREADS") : "0"); -// static int threads_count; -// static std::vector threads; -// static std::vector users; +// static int threads_size = atoi(getenv("NNCASE_MAX_THREADS") ? +// getenv("NNCASE_MAX_THREADS") : "0"); static int threads_count; static +// std::vector threads; static std::vector users; extern uintptr_t paddr_offset; void *thread_start(thread_func callable, void *user, size_t user_size); From d926f7f2c40d6b2cc1955eaee0409d1fa84a025f Mon Sep 17 00:00:00 2001 From: huochenghai Date: Wed, 9 Aug 2023 19:23:15 +0800 Subject: [PATCH 074/308] fix runtime build --- modules/cpu/src/runtime/elfloader.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cpu/src/runtime/elfloader.cpp b/modules/cpu/src/runtime/elfloader.cpp index d9845fcf72..da4cc1f71f 100644 --- a/modules/cpu/src/runtime/elfloader.cpp +++ b/modules/cpu/src/runtime/elfloader.cpp @@ -33,7 +33,7 @@ int elfloader::invoke_elf(size_t id, uint8_t **buffers, nncase_mt_t *nncase_mt, // printf("Binary entrypoint is %" PRIxPTR "; invoking %p\n", // (uintptr_t)ctx_.ehdr.e_entry, (void *)epaddr); - thread_pool::paddr_offset = (uintptr_t)buf_; + nncase::runtime::cpu::thread_pool::paddr_offset = (uintptr_t)buf_; nncase_mt->thread_start = thread_pool::thread_start; nncase_mt->thread_end = thread_pool::thread_end; ep(id, buffers, nncase_mt, data, rdata); From 5cf33c7c116c3540efd6f3c5b3b20fd2d48cb967 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Wed, 9 Aug 2023 19:24:11 +0800 Subject: [PATCH 075/308] update thread related csouce gen --- .../CodeGen/CSourceBuiltn.cs | 6 +- .../CodeGen/CSourceCompiler.cs | 2 +- .../CodeGen/CSourceConvertVisitor.cs | 9 ++- .../Passes/Tile/SingleCPUFusionConverter.cs | 4 +- src/Native/src/test_cli.cpp | 65 ++++++++++++------- 5 files changed, 54 insertions(+), 32 deletions(-) diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs index a945b05958..81c495da47 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceBuiltn.cs @@ -72,9 +72,9 @@ public static class CSourceBuiltn bool (*bool_binary_and)(bool, bool); bool (*bool_binary_or)(bool, bool); bool (*bool_binary_xor)(bool, bool); - // paralell - void *thread_start(void *(*callable)(void *args), void *user); - void thread_end(); + // multi-thread + void *(*thread_start)(void *(*callable)(void *), void *user, size_t user_size); + void *(*thread_end)(); } nncase_mt_t;"; public const string Include = @"#include diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs index 449329e034..691fe6cd24 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceCompiler.cs @@ -138,7 +138,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath) } else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - return $"{sourcePath} -nostdlib -static -nopie -fPIC -o {outPath} -e__start"; + return $"{sourcePath} -nostartfiles -pie -fPIC -o {outPath} -static -e__start"; } else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { diff --git a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs index 636cf1e2ea..1b05cacccf 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CSourceConvertVisitor.cs @@ -314,7 +314,7 @@ protected override CSymbol VisitFor(For expr) if (expr.Mode == LoopMode.Parallel) { // find the vars will be used and make new struct type. - var msg_fields = _exprMemo.Where(p => p.Key is MemSpan or TIR.Buffer or Var).Select(p => p.Value).Concat(CSymbol.Builtns); + var msg_fields = _exprMemo.Where(p => p.Key is MemSpan or TIR.Buffer or Var).Select(p => p.Value).Concat(CSymbol.Builtns).ToArray(); var msg_type = DeclThreadMessageStruct(msg_fields); using (new IndentScope(_declBuilder)) @@ -345,7 +345,7 @@ protected override CSymbol VisitFor(For expr) IndentScope.Writer.IndWrite("};\n"); - IndentScope.Writer.IndWrite($"nncase_mt->thread_start({VisitEntry.Name}_inner, (void *)_message);\n"); + IndentScope.Writer.IndWrite($"nncase_mt->thread_start({VisitEntry.Name}_inner, (void *)&_message, sizeof ({msg_type}));\n"); } } else @@ -360,6 +360,11 @@ protected override CSymbol VisitFor(For expr) // 3. For closing IndentScope.Writer.IndWrite("}\n"); + if (expr.Mode == LoopMode.Parallel) + { + IndentScope.Writer.IndWrite("nncase_mt->thread_end();\n"); + } + symbol = new(string.Empty, string.Empty); _exprMemo.Add(expr, symbol); return symbol; diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs index 2dc177c607..99ba7688ab 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/SingleCPUFusionConverter.cs @@ -86,9 +86,9 @@ private void GenerateMatMul(ReadOnlySpan arguments, Buffer ret, Call exp var lhs = arguments[0]; var rhs = arguments[1]; - var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), i == 0 ? LoopMode.Parallel : LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); + var loops = Enumerable.Range(0, lhs.Rank - 2).Select(i => (T.ForLoop(out var loopVar, (0, lhs.Dimensions[i]), LoopMode.Serial, $"loop_{i}"), loopVar)).ToArray(); var loopVars = loops.Select(f => f.loopVar).ToArray(); - var stmt = T.Serial(out var m, (0, lhs.Dimensions[^2])).Body( + var stmt = T.ForLoop(out var m, (0, lhs.Dimensions[^2]), LoopMode.Parallel).Body( T.Serial(out var n, (0, rhs.Dimensions[^1])).Body( T.BufferStore(ret, loopVars.Concat(new[] { m, n }).ToArray(), 0f), T.Serial(out var k, (0, lhs.Dimensions[^1])).Body( diff --git a/src/Native/src/test_cli.cpp b/src/Native/src/test_cli.cpp index 7f703a216a..ea375f4f35 100644 --- a/src/Native/src/test_cli.cpp +++ b/src/Native/src/test_cli.cpp @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -19,6 +20,7 @@ using namespace nncase; using namespace nncase::runtime; +constexpr size_t loop_count = 10; #define TRY(x) \ if (x) \ @@ -34,8 +36,7 @@ result write_tensor_buffer(value_t value, std::ofstream &of) { } result run_core(const std::string &kmodel_path, - const std::vector &input_bins, - const std::string &output_bin) { + const std::vector &bins) { auto kmodel = read_file(kmodel_path); interpreter *interp = new interpreter(); // auto dump_path = @@ -47,16 +48,16 @@ result run_core(const std::string &kmodel_path, try_var(entry, interp->entry_function()); - if (entry->parameters_size() != input_bins.size()) + if (entry->parameters_size() > bins.size()) return err(std::errc::argument_list_too_long); /* create the input parameters tensor note the input tenosr must be contiguous */ std::vector parameters; - for (int i = 0; i < input_bins.size(); i++) { + for (int i = 0; i < entry->parameters_size(); i++) { try_var(type, entry->parameter_type(i)); try_var(ts_type, type.as()); - auto input_pool = read_file(input_bins[i]); + auto input_pool = read_file(bins[i]); gsl::span input_pool_span = { reinterpret_cast(input_pool.data()), input_pool.size()}; @@ -66,21 +67,38 @@ result run_core(const std::string &kmodel_path, parameters.push_back(_.impl()); } - try_var(ret, entry->invoke({parameters.data(), parameters.size()})); + double total_time = 0.0; + for (size_t i = 0; i < loop_count; i++) { + auto start_time = std::chrono::steady_clock::now(); + try_var(ret, entry->invoke({parameters.data(), parameters.size()})); + auto end_time = std::chrono::steady_clock::now(); + total_time += (std::chrono::duration_cast( + end_time - start_time) + .count() / + 1e6); - std::ofstream output_stream(output_bin, std::ios::binary); - - if (ret.is_a()) { - try_(write_tensor_buffer(ret, output_stream)); - } else if (ret.is_a()) { - try_var(tp, ret.as()); - for (auto &&ret_v : tp->fields()) { - try_(write_tensor_buffer(ret_v, output_stream)); + if (i == (loop_count - 1)) { + if (entry->parameters_size() == (bins.size() - 1)) { + auto output_bin = *bins.end(); + std::ofstream output_stream(output_bin, std::ios::binary); + if (ret.is_a()) { + try_(write_tensor_buffer(ret, output_stream)); + } else if (ret.is_a()) { + try_var(tp, ret.as()); + for (auto &&ret_v : tp->fields()) { + try_(write_tensor_buffer(ret_v, output_stream)); + } + } else { + return nncase::err(std::errc::bad_message); + } + output_stream.close(); + } } - } else { - return nncase::err(std::errc::bad_message); } - output_stream.close(); + + std::cout << "interp run: " << (total_time / loop_count) + << " ms, fps = " << 1000 / (total_time / loop_count) << std::endl; + return ok(); } @@ -92,13 +110,12 @@ result run_core(const std::string &kmodel_path, * @return int */ int main(NNCASE_UNUSED int argc, char **argv) { - assert(argc >= 4); - std::vector input_bins; - for (int i = 2; i < argc - 1; i++) { - input_bins.push_back(argv[i]); + assert(argc >= 3); + std::vector bins; + for (int i = 2; i < argc; i++) { + bins.push_back(argv[i]); } std::string kmodel_bin(argv[1]); - std::string output_bin(argv[argc - 1]); - run_core(kmodel_bin, input_bins, output_bin).unwrap_or_throw(); + run_core(kmodel_bin, bins).unwrap_or_throw(); return 0; -} +} \ No newline at end of file From 1e9ffd5ad9314f42b5e7cd97298789097f40d1c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 15 Aug 2023 15:54:03 +0800 Subject: [PATCH 076/308] add simulator --- modules/cpu/src/runtime/CMakeLists.txt | 2 + modules/cpu/src/runtime/cmodel/CMakeLists.txt | 21 +++ .../cpu/src/runtime/cmodel/include/apply.h | 102 +++++++++++++ .../runtime/cmodel/include/hardware_context.h | 24 +++ .../src/runtime/cmodel/include/hardware_def.h | 5 + .../runtime/cmodel/include/runtime_utils.h | 85 +++++++++++ modules/cpu/src/runtime/cmodel/include/tdma.h | 140 ++++++++++++++++++ .../cpu/src/runtime/cmodel/include/tensor.h | 66 +++++++++ .../runtime/cmodel/include/thread_context.h | 16 ++ .../cpu/src/runtime/cmodel/src/cpu_cmodel.cpp | 58 ++++++++ .../runtime/cmodel/src/hardware_context.cpp | 113 ++++++++++++++ .../src/runtime/cmodel/tests/demo1/kernel.h | 100 +++++++++++++ .../src/runtime/cmodel/tests/demo1/main.cpp | 76 ++++++++++ .../runtime/cmodel/tests/demo1/shared_def.h | 9 ++ modules/cpu/src/runtime/shared_memory.cpp | 125 ++++++++++++++++ modules/cpu/src/runtime/shared_memory.h | 51 +++++++ 16 files changed, 993 insertions(+) create mode 100644 modules/cpu/src/runtime/cmodel/CMakeLists.txt create mode 100644 modules/cpu/src/runtime/cmodel/include/apply.h create mode 100644 modules/cpu/src/runtime/cmodel/include/hardware_context.h create mode 100644 modules/cpu/src/runtime/cmodel/include/hardware_def.h create mode 100644 modules/cpu/src/runtime/cmodel/include/runtime_utils.h create mode 100644 modules/cpu/src/runtime/cmodel/include/tdma.h create mode 100644 modules/cpu/src/runtime/cmodel/include/tensor.h create mode 100644 modules/cpu/src/runtime/cmodel/include/thread_context.h create mode 100644 modules/cpu/src/runtime/cmodel/src/cpu_cmodel.cpp create mode 100644 modules/cpu/src/runtime/cmodel/src/hardware_context.cpp create mode 100644 modules/cpu/src/runtime/cmodel/tests/demo1/kernel.h create mode 100644 modules/cpu/src/runtime/cmodel/tests/demo1/main.cpp create mode 100644 modules/cpu/src/runtime/cmodel/tests/demo1/shared_def.h create mode 100644 modules/cpu/src/runtime/shared_memory.cpp create mode 100644 modules/cpu/src/runtime/shared_memory.h diff --git a/modules/cpu/src/runtime/CMakeLists.txt b/modules/cpu/src/runtime/CMakeLists.txt index 37da417626..a0d8e42910 100644 --- a/modules/cpu/src/runtime/CMakeLists.txt +++ b/modules/cpu/src/runtime/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required (VERSION 3.13) +add_subdirectory(cmodel) + set(SRCS runtime_module.cpp runtime_function.cpp thread_pool.cpp diff --git a/modules/cpu/src/runtime/cmodel/CMakeLists.txt b/modules/cpu/src/runtime/cmodel/CMakeLists.txt new file mode 100644 index 0000000000..0a071feced --- /dev/null +++ b/modules/cpu/src/runtime/cmodel/CMakeLists.txt @@ -0,0 +1,21 @@ + +add_library(cpu_cmodel STATIC + src/hardware_context.cpp +) + +target_include_directories(cpu_cmodel PUBLIC include) +set_target_properties(cpu_cmodel PROPERTIES POSITION_INDEPENDENT_CODE ON) + +add_executable(cpu_cmodel_cli src/cpu_cmodel.cpp ../shared_memory.cpp) +target_link_libraries(cpu_cmodel_cli PUBLIC cpu_cmodel) +set_target_properties(cpu_cmodel_cli PROPERTIES POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME "nncase.simulator.cpu.c") +install(TARGETS cpu_cmodel_cli COMPONENT nncase-runtime) + + +function(add_test name test_path) + add_executable(${name} "${test_path}/main.cpp") + target_include_directories(${name} PUBLIC include) +endfunction(add_test) + +add_test(demo1 tests/demo1) diff --git a/modules/cpu/src/runtime/cmodel/include/apply.h b/modules/cpu/src/runtime/cmodel/include/apply.h new file mode 100644 index 0000000000..6b519f8c48 --- /dev/null +++ b/modules/cpu/src/runtime/cmodel/include/apply.h @@ -0,0 +1,102 @@ +#pragma once +#include + +namespace detail { +#define APPLY_IMPL_FOR(i) for (index[i] = 0; index[i] < shape[i]; index[i]++) + +template +void apply_1(gsl::span shape, Callable &&callable) noexcept { + size_t index[1]; + APPLY_IMPL_FOR(0) + callable(gsl::span(index)); +} + +template +void apply_2(gsl::span shape, Callable &&callable) noexcept { + size_t index[2]; + APPLY_IMPL_FOR(0) + APPLY_IMPL_FOR(1) + callable(gsl::span(index)); +} + +template +void apply_3(gsl::span shape, Callable &&callable) noexcept { + size_t index[3]; + APPLY_IMPL_FOR(0) + APPLY_IMPL_FOR(1) + APPLY_IMPL_FOR(2) + callable(gsl::span(index)); +} + +template +void apply_4(gsl::span shape, Callable &&callable) noexcept { + size_t index[4]; + APPLY_IMPL_FOR(0) + APPLY_IMPL_FOR(1) + APPLY_IMPL_FOR(2) + APPLY_IMPL_FOR(3) + callable(gsl::span(index)); +} + +template +void apply_5(gsl::span shape, Callable &&callable) noexcept { + size_t index[5]; + APPLY_IMPL_FOR(0) + APPLY_IMPL_FOR(1) + APPLY_IMPL_FOR(2) + APPLY_IMPL_FOR(3) + APPLY_IMPL_FOR(4) + callable(gsl::span(index)); +} + +template +void apply_generic(gsl::span shape, + Callable &&callable) noexcept { + auto index_buffer = (size_t *) +#ifdef _WIN32 + _alloca +#else + __builtin_alloca +#endif + (sizeof(size_t) * shape.size()); + + gsl::span index(index_buffer, shape.size()); + std::fill(index.begin(), index.end(), 0); + auto last_dim_idx = (int32_t)shape.size() - 1; + while (true) { + int dim = last_dim_idx; + while (index[dim] == shape[dim]) { + if (dim == 0) { + } + + index[dim] = 0; + index[--dim]++; + } + + callable(index); + index[last_dim_idx]++; + } +} +} // namespace detail + +template +void apply(gsl::span shape, Callable &&callable) noexcept { + switch (shape.size()) { + case 0: + return callable(shape); + case 1: + return detail::apply_1(shape, std::forward(callable)); + case 2: + return detail::apply_2(shape, std::forward(callable)); + case 3: + return detail::apply_3(shape, std::forward(callable)); + case 4: + return detail::apply_4(shape, std::forward(callable)); + case 5: + return detail::apply_5(shape, std::forward(callable)); + default: + break; + } + + return detail::apply_generic(shape, std::forward(callable)); +} \ No newline at end of file diff --git a/modules/cpu/src/runtime/cmodel/include/hardware_context.h b/modules/cpu/src/runtime/cmodel/include/hardware_context.h new file mode 100644 index 0000000000..e01366c9f2 --- /dev/null +++ b/modules/cpu/src/runtime/cmodel/include/hardware_context.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +struct hardware_context_impl; + +class hardware_context { + public: + hardware_context(); + void lock_block(int bid); + int mark_block_visit(int bid, int tid); + void unlock_block(int bid); + void wait_block_sync(int bid, int visited); + void lock_all(); + int mark_all_visit(int bid, int tid); + void unlock_all(); + void wait_all_sync(int visited); + void *all_reduce_var = nullptr; + + private: + std::unique_ptr impl_; +}; + +extern std::unique_ptr global_hardware_ctx; \ No newline at end of file diff --git a/modules/cpu/src/runtime/cmodel/include/hardware_def.h b/modules/cpu/src/runtime/cmodel/include/hardware_def.h new file mode 100644 index 0000000000..9dfdcd946a --- /dev/null +++ b/modules/cpu/src/runtime/cmodel/include/hardware_def.h @@ -0,0 +1,5 @@ +#pragma once +#include + +constexpr size_t BLOCKS = 8; +constexpr size_t CORES = 4; \ No newline at end of file diff --git a/modules/cpu/src/runtime/cmodel/include/runtime_utils.h b/modules/cpu/src/runtime/cmodel/include/runtime_utils.h new file mode 100644 index 0000000000..68d944706d --- /dev/null +++ b/modules/cpu/src/runtime/cmodel/include/runtime_utils.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +using dims_t = itlib::small_vector; +using strides_t = itlib::small_vector; + +void print_vec(itlib::small_vector vec) { + for (const size_t v : vec) { + std::cout << std::to_string(v) << ", "; + } + std::cout << std::endl; +} + +template inline size_t compute_size(const TShape &shape) { + return std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); +} + +template +inline size_t compute_size(const TShape &shape, const TShape &strides) { + size_t max_stride = 1, max_shape = 1; + for (size_t i = 0; i < shape.size(); i++) { + if ((shape[i] == 1 ? 0 : strides[i]) >= max_stride) { + max_stride = strides[i]; + max_shape = shape[i]; + } + } + size_t size = max_stride * max_shape; + return size; +} + +template +inline std::size_t compute_strides(const shape_type &shape, + strides_type &strides) { + using strides_value_type = typename std::decay_t::value_type; + strides_value_type data_size = 1; + for (std::size_t i = shape.size(); i != 0; --i) { + strides[i - 1] = data_size; + data_size = + strides[i - 1] * static_cast(shape[i - 1]); + } + return static_cast(data_size); +} + +inline strides_t get_default_strides(dims_t shape) { + strides_t strides(shape.size()); + compute_strides(shape, strides); + return strides; +} + +template +inline offset_type element_offset(const S &strides, It first, + It last) noexcept { + using difference_type = typename std::iterator_traits::difference_type; + auto size = static_cast((std::min)( + static_cast(std::distance(first, last)), strides.size())); + return std::inner_product(last - size, last, strides.cend() - size, + offset_type(0)); +} + +inline size_t offset(gsl::span strides, + gsl::span index) { + // scalar + if (strides.size() == 0 || index.size() == 0) { + return 0; + } + assert(strides.size() == index.size()); + return element_offset(strides, index.begin(), index.end()); +} + +inline bool is_shape_equal(const dims_t &a, const dims_t &b) { + for (size_t i = 0; i < a.size(); i++) { + if (a[i] != b[i]) { + return false; + } + } + return true; +} diff --git a/modules/cpu/src/runtime/cmodel/include/tdma.h b/modules/cpu/src/runtime/cmodel/include/tdma.h new file mode 100644 index 0000000000..9a66bb403c --- /dev/null +++ b/modules/cpu/src/runtime/cmodel/include/tdma.h @@ -0,0 +1,140 @@ +#pragma once +#include +#include +#include +#include +#include + +template