diff --git a/.gitignore b/.gitignore index 2dfd485fbe..5b1e72c18f 100644 --- a/.gitignore +++ b/.gitignore @@ -68,7 +68,7 @@ artifacts/ *.pidb *.svclog *.scc - +*.bin # Chutzpah Test files _Chutzpah* @@ -306,4 +306,4 @@ cmake-build-* *gmodel_dump_dir* *.ipynb_checkpoints* # Auto generated files -# generated/ \ No newline at end of file +# generated/ diff --git a/CMakeLists.txt b/CMakeLists.txt index b97e01e62c..7ac7539a47 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,6 @@ endif() if(NOT DEFINED NNCASE_VERSION_SUFFIX) find_package (Git) - execute_process( COMMAND ${GIT_EXECUTABLE} describe --always --dirty --tag WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} @@ -274,5 +273,5 @@ if(BUILD_TESTING) endif() # Modules -#add_subdirectory(modules/k210) + #add_subdirectory(modules/vulkan) diff --git a/Directory.Packages.props b/Directory.Packages.props index 3b3a93c2df..bcc1830198 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -49,8 +49,9 @@ - - + + + diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json deleted file mode 100644 index 5dd73e5bda..0000000000 --- a/modules/Nncase.Modules.CPU/packages.lock.json +++ /dev/null @@ -1,295 +0,0 @@ -{ - "version": 2, - "dependencies": { - "net7.0": { - "StyleCop.Analyzers": { - "type": "Direct", - "requested": "[1.2.0-beta.435, )", - "resolved": "1.2.0-beta.435", - "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", - "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.435" - } - }, - "Google.OrTools.runtime.linux-arm64": { - "type": "Transitive", - "resolved": "9.4.1874", - "contentHash": "Z46ndZcZa2Lt5b76xU9kxVYbPLg/LfuMufhUVsu3Qo3L7Bibf7WXd9j7RRldjnuv8RIHWTqb0b+2FwwMxs0c5A==" - }, - "Google.OrTools.runtime.linux-x64": { - "type": "Transitive", - "resolved": "9.4.1874", - "contentHash": "zGeDb8FuvP9HXjrsU7krVXtSDFpR+DUGNEsH51k94jL9tzf2vWYI8+WUBRHZ/cGe50dpLr+vIjfcNo3gFyOpkQ==" - }, - "Google.OrTools.runtime.osx-arm64": { - "type": "Transitive", - "resolved": "9.4.1874", - "contentHash": "Wo0ZfDaH6DhiQw0jZm4HWJm/oPGPpWNwOLUz+EYaoH3MLtocSxItHGQj/Ta3HyhXnYNOv+TliAH8L+8RCXu/2w==" - }, - "Google.OrTools.runtime.osx-x64": { - "type": "Transitive", - "resolved": "9.4.1874", - "contentHash": "IAfGgKR1og6vU87axK1d37Ak/4jy8B4NMoElovG/KZc/2UY+cJEAQDA709UMegtI4lBhuxTWFNUiHQYmRIB9yQ==" - }, - "Google.OrTools.runtime.win-x64": { - "type": "Transitive", - "resolved": "9.4.1874", - "contentHash": "fUs5qDnZA6itygolcX6nPuachQkY9CVvQbakIzIiRAWKcaj8umQAbFdGwbkyzp3qp34BKW5mtPVsmMyfQBBjOQ==" - }, - "libortki": { - "type": "Transitive", - "resolved": "0.0.2", - "contentHash": "svfuG5mxGY/QC/5DVheHOCELmdSP90RtxQ73j23KarPXZ9ZXW+7v1l5J77hGDyQbEh1BGrnGgKBlyn76RauGHg==", - "dependencies": { - "libortki-linux": "0.0.2", - "libortki-osx": "0.0.2", - "libortki-osx-arm64": "0.0.2", - "libortki-win": "0.0.2" - } - }, - "libortki-linux": { - "type": "Transitive", - "resolved": "0.0.2", - "contentHash": "b04LWD4lgGy60tys3hPFhnUpgWDM6dN5r1PI7GOcPj8VupXCaI70LKNQ5/5twbDE6rkowOGanVTw0S2wBGBqBQ==" - }, - "libortki-osx": { - "type": "Transitive", - "resolved": "0.0.2", - "contentHash": "O6Q9GLULkDkZEPAZJVKLPH0ROXGVOE7BxuddgOcHNK2oiTEM7wIRnzp2OIlYgLpaOLyxJMisbGOhtWgdzt2Wng==" - }, - "libortki-osx-arm64": { - "type": "Transitive", - "resolved": "0.0.2", - "contentHash": "4Qn2dirJmRicnUG945oWpq7HVGwgqCKKxYPMISv/MRvmpZBbXrZ1cVvRaF8WwTu4XXgfKTa1sLv+i8zLifUMeQ==" - }, - "libortki-win": { - "type": "Transitive", - "resolved": "0.0.2", - "contentHash": "HAoROgAKn8XBun11X43HZuspKlo5JGy8/OYw5IUPo7FVh5TCaPrLjGmyGYYZ2dqLlv31yv/b6s254PIRGn95cA==" - }, - "Microsoft.Extensions.Configuration.Abstractions": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "qWzV9o+ZRWq+pGm+1dF+R7qTgTYoXvbyowRoBxQJGfqTpqDun2eteerjRQhq5PQ/14S+lqto3Ft4gYaRyl4rdQ==", - "dependencies": { - "Microsoft.Extensions.Primitives": "6.0.0" - } - }, - "Microsoft.Extensions.DependencyInjection.Abstractions": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "xlzi2IYREJH3/m6+lUrQlujzX8wDitm4QGnUu6kUXTQAWPuZY8i+ticFJbzfqaetLA6KR/rO6Ew/HuYD+bxifg==" - }, - "Microsoft.Extensions.FileProviders.Abstractions": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "0pd4/fho0gC12rQswaGQxbU34jOS1TPS8lZPpkFCH68ppQjHNHYle9iRuHeev1LhrJ94YPvzcRd8UmIuFk23Qw==", - "dependencies": { - "Microsoft.Extensions.Primitives": "6.0.0" - } - }, - "Microsoft.Extensions.Primitives": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "9+PnzmQFfEFNR9J2aDTfJGGupShHjOuGw4VUv+JB044biSHrnmCIMD+mJHmb2H7YryrfBEXDurxQ47gJZdCKNQ==", - "dependencies": { - "System.Runtime.CompilerServices.Unsafe": "6.0.0" - } - }, - "NetFabric.Hyperlinq.Abstractions": { - "type": "Transitive", - "resolved": "1.3.0", - "contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA==" - }, - "StyleCop.Analyzers.Unstable": { - "type": "Transitive", - "resolved": "1.2.0.435", - "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" - }, - "System.Buffers": { - "type": "Transitive", - "resolved": "4.5.1", - "contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg==" - }, - "System.Runtime.CompilerServices.Unsafe": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg==" - }, - "nncase.codegen": { - "type": "Project", - "dependencies": { - "Extension.Mathematics": "[1.2.12, )", - "Nncase.Core": "[1.0.0, )", - "Nncase.IO": "[1.0.0, )" - } - }, - "nncase.core": { - "type": "Project", - "dependencies": { - "DryIoc.dll": "[5.3.1, )", - "GiGraph.Dot": "[2.0.0, )", - "Microsoft.Extensions.Hosting.Abstractions": "[6.0.0, )", - "Microsoft.Extensions.Logging.Abstractions": "[6.0.0, )", - "Microsoft.Extensions.Options": "[6.0.0, )", - "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", - "NetFabric.Hyperlinq": "[3.0.0-beta48, )", - "System.Reactive": "[5.0.0, )" - } - }, - "nncase.diagnostics": { - "type": "Project", - "dependencies": { - "Nncase.Core": "[1.0.0, )" - } - }, - "nncase.egraph": { - "type": "Project", - "dependencies": { - "GiGraph.Dot": "[2.0.0, )", - "Google.OrTools": "[9.4.1874, )", - "NetFabric.Hyperlinq": "[3.0.0-beta48, )", - "Nncase.Core": "[1.0.0, )", - "Nncase.Evaluator": "[1.0.0, )", - "Singulink.Collections.Weak": "[1.0.2, )" - } - }, - "nncase.evaluator": { - "type": "Project", - "dependencies": { - "Nncase.Core": "[1.0.0, )", - "OrtKISharp": "[0.0.2, )" - } - }, - "nncase.graph": { - "type": "Project", - "dependencies": { - "Nncase.Core": "[1.0.0, )", - "Nncase.Evaluator": "[1.0.0, )" - } - }, - "nncase.io": { - "type": "Project" - }, - "nncase.modules.stackvm": { - "type": "Project", - "dependencies": { - "Nncase.CodeGen": "[1.0.0, )", - "Nncase.Passes": "[1.0.0, )" - } - }, - "nncase.passes": { - "type": "Project", - "dependencies": { - "Nncase.Core": "[1.0.0, )", - "Nncase.EGraph": "[1.0.0, )", - "Nncase.Evaluator": "[1.0.0, )", - "Nncase.Graph": "[1.0.0, )" - } - }, - "DryIoc.dll": { - "type": "CentralTransitive", - "requested": "[5.3.1, )", - "resolved": "5.3.1", - "contentHash": "E3zclUh2CIBks1t2uBD1k18pyGFJ1YSKCrbCDbB7qCdl2RAB+k68AyDpjeplhF1ot2XPV82AgyCWBXMf0ggL1g==" - }, - "Extension.Mathematics": { - "type": "CentralTransitive", - "requested": "[1.2.12, )", - "resolved": "1.2.12", - "contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg==" - }, - "GiGraph.Dot": { - "type": "CentralTransitive", - "requested": "[2.0.0, )", - "resolved": "2.0.0", - "contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ==" - }, - "Google.OrTools": { - "type": "CentralTransitive", - "requested": "[9.4.1874, )", - "resolved": "9.4.1874", - "contentHash": "jqRoI+pYlym+fhoU25u+13oti5h+772bllQ9zDitTVMclDXVTiG6pxzvmYO74wnADBMdpb2SQlgiNQxoNk5dlA==", - "dependencies": { - "Google.OrTools.runtime.linux-arm64": "9.4.1874", - "Google.OrTools.runtime.linux-x64": "9.4.1874", - "Google.OrTools.runtime.osx-arm64": "9.4.1874", - "Google.OrTools.runtime.osx-x64": "9.4.1874", - "Google.OrTools.runtime.win-x64": "9.4.1874", - "Google.Protobuf": "3.19.4" - } - }, - "Google.Protobuf": { - "type": "CentralTransitive", - "requested": "[3.19.4, )", - "resolved": "3.19.4", - "contentHash": "fd07/ykL4O4FhqrZIELm5lmiyOHfdPg9+o+hWr6tcfRdS7tHXnImg/2wtogLzlW2eEmr0J7j6ZrZvaWOLiJbxQ==" - }, - "Microsoft.Extensions.Hosting.Abstractions": { - "type": "CentralTransitive", - "requested": "[6.0.0, )", - "resolved": "6.0.0", - "contentHash": "GcT5l2CYXL6Sa27KCSh0TixsRfADUgth+ojQSD5EkzisZxmGFh7CwzkcYuGwvmXLjr27uWRNrJ2vuuEjMhU05Q==", - "dependencies": { - "Microsoft.Extensions.Configuration.Abstractions": "6.0.0", - "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0", - "Microsoft.Extensions.FileProviders.Abstractions": "6.0.0" - } - }, - "Microsoft.Extensions.Logging.Abstractions": { - "type": "CentralTransitive", - "requested": "[6.0.0, )", - "resolved": "6.0.0", - "contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA==" - }, - "Microsoft.Extensions.Options": { - "type": "CentralTransitive", - "requested": "[6.0.0, )", - "resolved": "6.0.0", - "contentHash": "dzXN0+V1AyjOe2xcJ86Qbo233KHuLEY0njf/P2Kw8SfJU+d45HNS2ctJdnEnrWbM9Ye2eFgaC5Mj9otRMU6IsQ==", - "dependencies": { - "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0", - "Microsoft.Extensions.Primitives": "6.0.0" - } - }, - "Microsoft.Toolkit.HighPerformance": { - "type": "CentralTransitive", - "requested": "[7.1.1, )", - "resolved": "7.1.1", - "contentHash": "TRnvDpZPXO30hTOtjfLw6Y9BtTKtTpzk9lefeh4RMCaUihWrVKQR454nYH4/mMJAh+LXqfAPyk0kfkJs0Amopw==" - }, - "NetFabric.Hyperlinq": { - "type": "CentralTransitive", - "requested": "[3.0.0-beta48, )", - "resolved": "3.0.0-beta48", - "contentHash": "oYUhXvxNS8bBJWqNkvx5g8y0P/0LtyqS2pN0w4OWjVDNWEpLbdbvPy9w/9z1n2PrqIjX3jxUsEnoCmxxGnI3gw==", - "dependencies": { - "NetFabric.Hyperlinq.Abstractions": "1.3.0", - "System.Buffers": "4.5.1", - "System.Runtime.CompilerServices.Unsafe": "5.0.0" - } - }, - "OrtKISharp": { - "type": "CentralTransitive", - "requested": "[0.0.2, )", - "resolved": "0.0.2", - "contentHash": "q8j0yR5836Zhv9WB9BFkQt1UaEFyibq8bqJcTiULlILF6/sz8z7Wy2N8sgYdDKsdW25zncIz7j6IDbKM5ynePg==", - "dependencies": { - "libortki": "0.0.2" - } - }, - "Singulink.Collections.Weak": { - "type": "CentralTransitive", - "requested": "[1.0.2, )", - "resolved": "1.0.2", - "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" - }, - "System.Reactive": { - "type": "CentralTransitive", - "requested": "[5.0.0, )", - "resolved": "5.0.0", - "contentHash": "erBZjkQHWL9jpasCE/0qKAryzVBJFxGHVBAvgRN1bzM0q2s1S4oYREEEL0Vb+1kA/6BKb5FjUZMp5VXmy+gzkQ==" - } - } - } -} \ No newline at end of file diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs index 04cbd48a7b..314930a8a9 100644 --- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs +++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/18 下午5:04:31 +08:00. */ +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */ using System; using System.Collections.Generic; @@ -59,7 +59,7 @@ private void EmitTensorCall(Op op) Emitter.T.L2Normalization(); break; case IR.NN.LayerNorm top: - Emitter.T.LayerNorm(top.Axis, top.Epsilon); + Emitter.T.LayerNorm(top.Axis, top.Epsilon, top.UseMean); break; case IR.NN.LeakyRelu top: Emitter.T.LeakyRelu(); @@ -176,7 +176,7 @@ private void EmitTensorCall(Op op) Emitter.T.Cast(top.NewType, top.CastMode); break; case IR.Tensors.Concat top: - Emitter.T.Concat(); + Emitter.T.Concat(top.Axis); break; case IR.Tensors.ConstantOfShape top: Emitter.T.ConstantOfShape(); @@ -191,7 +191,7 @@ private void EmitTensorCall(Op op) Emitter.T.Flatten(); break; case IR.Tensors.Gather top: - Emitter.T.Gather(); + Emitter.T.Gather(top.Axis); break; case IR.Tensors.GatherElements top: Emitter.T.GatherElements(); @@ -205,9 +205,6 @@ private void EmitTensorCall(Op op) case IR.Tensors.IndexOf top: Emitter.T.IndexOf(); break; - case IR.Tensors.LSTM top: - Emitter.T.LSTM(top.Direction, top.Layout, top.Activations); - break; case IR.Tensors.Prod top: Emitter.T.Prod(); break; @@ -289,6 +286,9 @@ private void EmitTensorCall(Op op) case IR.ShapeExpr.UnsqueezeShape top: Emitter.T.UnsqueezeShape(); break; + case IR.RNN.LSTM top: + Emitter.T.LSTM(top.Direction, top.Layout, top.Activations); + break; case IR.Random.Normal top: Emitter.T.Normal(top.Type); break; diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs index 6e2184c5ea..b6512a344c 100644 --- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs +++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */ +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */ using System; using System.Collections.Generic; @@ -723,10 +723,11 @@ public void Compare(CompareOp compareOp) } ///. - public void Concat() + public void Concat(int axis) { _emitter.Write((byte)100); _emitter.Write((ushort)11); + _emitter.Write(axis); } ///. @@ -841,10 +842,11 @@ public void Flatten() } ///. - public void Gather() + public void Gather(int axis) { _emitter.Write((byte)100); _emitter.Write((ushort)27); + _emitter.Write(axis); } ///. @@ -925,12 +927,13 @@ public void L2Normalization() } ///. - public void LayerNorm(int axis, float epsilon) + public void LayerNorm(int axis, float epsilon, bool useMean) { _emitter.Write((byte)100); _emitter.Write((ushort)39); _emitter.Write(axis); _emitter.Write(epsilon); + _emitter.Write(useMean); } ///. diff --git a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs b/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs index ba77ae58f5..2f63e02be9 100644 --- a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.CommandLine.Invocation; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -21,13 +22,15 @@ namespace Nncase.Targets; /// public class CPUTarget : ITarget { - /// - /// Gets kind. - /// - public static readonly string Kind = "cpu"; + public const string Kind = "cpu"; string ITarget.Kind => Kind; + public (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser() + { + return (new System.CommandLine.Command(Kind), (_, _) => DefaultTargetCompileOptions.Instance); + } + /// public void ParseTargetDependentOptions(IConfigurationSection configure) { diff --git a/modules/Nncase.Modules.StackVM/packages.lock.json b/modules/Nncase.Modules.StackVM/packages.lock.json index b69bdc40f8..4830d0bbf6 100644 --- a/modules/Nncase.Modules.StackVM/packages.lock.json +++ b/modules/Nncase.Modules.StackVM/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Google.OrTools.runtime.linux-arm64": { @@ -103,8 +103,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -134,6 +134,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -271,6 +272,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/nncase.sln b/nncase.sln index afde606dc0..065feb6959 100644 --- a/nncase.sln +++ b/nncase.sln @@ -44,8 +44,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.IO", "src\Nncase.IO\ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Schedule", "src\Nncase.Schedule\Nncase.Schedule.csproj", "{8E0E0672-0F96-4EF1-BDCD-D31F96A3DF73}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "targets", "targets", "{A2590531-71C5-4326-88DD-6A9DB2EF0A2B}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Targets", "src\Nncase.Targets\Nncase.Targets.csproj", "{56283378-06E3-4C6E-A8BF-7BD85C92D42C}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Simulator", "src\Nncase.Simulator\Nncase.Simulator.csproj", "{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}" diff --git a/python/common/pystreambuf.h b/python/common/pystreambuf.h index 178a041e0c..27581c1d75 100644 --- a/python/common/pystreambuf.h +++ b/python/common/pystreambuf.h @@ -1,6 +1,7 @@ // https://gist.github.com/asford/544323a5da7dddad2c9174490eb5ed06 #pragma once +#include #include #include diff --git a/python/nncase/__init__.py b/python/nncase/__init__.py index 3c6aa2b69f..cef1b150bb 100644 --- a/python/nncase/__init__.py +++ b/python/nncase/__init__.py @@ -298,7 +298,7 @@ def _import_ncnn_module(self, model_param: bytes | io.RawIOBase, model_bin: byte def check_target(target: str): def test_target(target: str): - return target in ["cpu", "k510", "k230"] + return target in ["cpu", "k510", "k230", "xpu"] def target_exists(target: str): return _nncase.Target.exists(target) diff --git a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h index e918f22f25..84cae4e387 100644 --- a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h +++ b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h @@ -1,5 +1,5 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 - * +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM + * +00:00. * * Copyright 2019-2021 Canaan Inc. * @@ -78,7 +78,7 @@ compare(runtime::stackvm::compare_op_t compare_op, value_t lhs, value_t rhs, kernel_context &context = default_kernel_context()); NNCASE_API result -concat(value_t input, value_t axis, value_t output = nullptr, +concat(int32_t axis, value_t input, value_t output = nullptr, kernel_context &context = default_kernel_context()); NNCASE_API result @@ -157,7 +157,7 @@ flatten(value_t input, value_t axis, value_t output = nullptr, kernel_context &context = default_kernel_context()); NNCASE_API result -gather(value_t input, value_t axis, value_t index, value_t output = nullptr, +gather(int32_t axis, value_t input, value_t index, value_t output = nullptr, kernel_context &context = default_kernel_context()); NNCASE_API result @@ -211,8 +211,8 @@ l2_normalization(value_t input, value_t output = nullptr, kernel_context &context = default_kernel_context()); NNCASE_API result -layer_norm(int32_t axis, float epsilon, value_t input, value_t scale, - value_t bias, value_t output = nullptr, +layer_norm(int32_t axis, float epsilon, bool use_mean, value_t input, + value_t scale, value_t bias, value_t output = nullptr, kernel_context &context = default_kernel_context()); NNCASE_API result diff --git a/src/Native/include/nncase/runtime/interpreter.h b/src/Native/include/nncase/runtime/interpreter.h index fc91970657..a8f5814d34 100644 --- a/src/Native/include/nncase/runtime/interpreter.h +++ b/src/Native/include/nncase/runtime/interpreter.h @@ -73,6 +73,7 @@ class NNCASE_API interpreter { options_dict &options() noexcept; result find_module_by_id(size_t index) noexcept; + result find_id_by_module(runtime_module *module) noexcept; /* V1 APIs */ diff --git a/src/Native/include/nncase/runtime/runtime_module.h b/src/Native/include/nncase/runtime/runtime_module.h index 354d747cbf..194dd3b1f7 100644 --- a/src/Native/include/nncase/runtime/runtime_module.h +++ b/src/Native/include/nncase/runtime/runtime_module.h @@ -58,6 +58,8 @@ class NNCASE_API runtime_module { result find_function_by_id(size_t index) noexcept; + result find_id_by_function(runtime_function *function) noexcept; + protected: virtual result initialize_before_functions(runtime_module_init_context &context) noexcept; diff --git a/src/Native/include/nncase/runtime/stackvm/op_reader.h b/src/Native/include/nncase/runtime/stackvm/op_reader.h index 80372463e4..ffde6669bd 100644 --- a/src/Native/include/nncase/runtime/stackvm/op_reader.h +++ b/src/Native/include/nncase/runtime/stackvm/op_reader.h @@ -1,5 +1,5 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 - * +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM + * +00:00. * * Copyright 2019-2021 Canaan Inc. * @@ -837,6 +837,7 @@ template <> struct tensor_op_reader { template <> struct tensor_op_reader { tensor_concat_op_t operator()(NNCASE_UNUSED span_reader &reader) const { tensor_concat_op_t op; + op.axis = reader.read_unaligned(); return op; } }; @@ -964,6 +965,7 @@ template <> struct tensor_op_reader { template <> struct tensor_op_reader { tensor_gather_op_t operator()(NNCASE_UNUSED span_reader &reader) const { tensor_gather_op_t op; + op.axis = reader.read_unaligned(); return op; } }; @@ -1055,6 +1057,7 @@ template <> struct tensor_op_reader { tensor_layer_norm_op_t op; op.axis = reader.read_unaligned(); op.epsilon = reader.read_unaligned(); + op.use_mean = reader.read_unaligned(); return op; } }; diff --git a/src/Native/include/nncase/runtime/stackvm/opcode.h b/src/Native/include/nncase/runtime/stackvm/opcode.h index 5c17c82894..8a5225b54e 100644 --- a/src/Native/include/nncase/runtime/stackvm/opcode.h +++ b/src/Native/include/nncase/runtime/stackvm/opcode.h @@ -1,5 +1,5 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 - * +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM + * +00:00. * * Copyright 2019-2021 Canaan Inc. * @@ -190,7 +190,6 @@ enum class tensor_function_t : uint16_t { gather_nd = 29, get_item = 31, index_of = 36, - lstm = 44, prod = 52, range = 55, rank = 57, @@ -218,6 +217,7 @@ enum class tensor_function_t : uint16_t { squeeze_shape = 81, transpose_shape = 87, unsqueeze_shape = 93, + lstm = 44, normal = 47, normal_like = 48, uniform = 90, @@ -614,7 +614,9 @@ struct tensor_compare_op_t { compare_op_t compare_op; }; -struct tensor_concat_op_t {}; +struct tensor_concat_op_t { + int32_t axis; +}; struct tensor_condition_op_t { bool can_fold_const_call; @@ -658,7 +660,9 @@ struct tensor_fix_shape_op_t {}; struct tensor_flatten_op_t {}; -struct tensor_gather_op_t {}; +struct tensor_gather_op_t { + int32_t axis; +}; struct tensor_gather_elements_op_t {}; @@ -685,6 +689,7 @@ struct tensor_l2_normalization_op_t {}; struct tensor_layer_norm_op_t { int32_t axis; float epsilon; + bool use_mean; }; struct tensor_leaky_relu_op_t {}; @@ -964,8 +969,6 @@ inline std::string to_string(tensor_function_t tensor_funct) { return "get_item"; case tensor_function_t::index_of: return "index_of"; - case tensor_function_t::lstm: - return "lstm"; case tensor_function_t::prod: return "prod"; case tensor_function_t::range: @@ -1020,6 +1023,8 @@ inline std::string to_string(tensor_function_t tensor_funct) { return "transpose_shape"; case tensor_function_t::unsqueeze_shape: return "unsqueeze_shape"; + case tensor_function_t::lstm: + return "lstm"; case tensor_function_t::normal: return "normal"; case tensor_function_t::normal_like: diff --git a/src/Native/src/kernels/stackvm/tensor_ops.cpp b/src/Native/src/kernels/stackvm/tensor_ops.cpp index b9ea85d40b..2cb452aa3d 100644 --- a/src/Native/src/kernels/stackvm/tensor_ops.cpp +++ b/src/Native/src/kernels/stackvm/tensor_ops.cpp @@ -47,8 +47,9 @@ result nncase::kernels::stackvm::batch_normalization( } result nncase::kernels::stackvm::layer_norm( - int32_t axis, float epsilon, value_t input, value_t scale, value_t bias, - value_t output, [[maybe_unused]] kernel_context &context) { + int32_t axis, float epsilon, [[maybe_unused]] bool use_mean, value_t input, + value_t scale, value_t bias, value_t output, + [[maybe_unused]] kernel_context &context) { try_input(input_mem, input); try_input(scale_mem, scale); try_input(bias_mem, bias); @@ -124,7 +125,7 @@ nncase::kernels::stackvm::clamp(value_t input, value_t min, value_t max, KERNEL_FINISH; } -result nncase::kernels::stackvm::concat(value_t input, value_t axis, +result nncase::kernels::stackvm::concat(int32_t axis, value_t input, value_t output, kernel_context &context) { try_tuple_input(inputs_mem, input); @@ -132,7 +133,7 @@ result nncase::kernels::stackvm::concat(value_t input, value_t axis, try_var(strides, get_strides(input_tuple)); try_tuple_field0(input0, input_tuple); auto dtype = input0->dtype(); - try_positive_axis_with_rank(axis_value, axis, input0->shape().size()); + auto axis_value = positive_index(axis, input0->shape().size()); auto out_shape = concat_infer_shape(shapes, axis_value); try_output(out_mem, output, dtype, out_shape); auto concat_dims = dims_t(); @@ -293,14 +294,15 @@ nncase::kernels::stackvm::flatten(value_t input, value_t axis, value_t output, KERNEL_FINISH; } -result nncase::kernels::stackvm::gather(value_t input, value_t axis, +result nncase::kernels::stackvm::gather(int32_t axis, value_t input, value_t index, value_t output, kernel_context &context) { try_input(input_mem, input); try_input(index_mem, index); auto dtype = input_tensor->dtype(); try_var(typecode, to_typecode(dtype)); - try_positive_axis(axis_value, axis, input_tensor); + // try_positive_axis(axis_value, axis, input_tensor); + auto axis_value = positive_index(axis, input_tensor->shape().size()); auto out_shape = gather_infer_shape(input_tensor->shape(), index_tensor->shape(), axis_value); try_output(out_mem, output, dtype, out_shape); diff --git a/src/Native/src/runtime/CMakeLists.txt b/src/Native/src/runtime/CMakeLists.txt index b892beded3..f92450b6a0 100644 --- a/src/Native/src/runtime/CMakeLists.txt +++ b/src/Native/src/runtime/CMakeLists.txt @@ -54,6 +54,7 @@ else() add_library(simulator OBJECT ${SRCS}) target_include_directories(simulator PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(simulator PUBLIC gsl::gsl-lite) + target_link_libraries(simulator PUBLIC fmt::fmt) target_link_libraries(simulator PRIVATE kernels) target_compile_definitions(simulator PUBLIC -DNNCASE_DLL -DNNCASE_SIMULATOR) if (DEFAULT_BUILTIN_RUNTIMES) diff --git a/src/Native/src/runtime/interpreter.cpp b/src/Native/src/runtime/interpreter.cpp index dc5839fb44..fe69b8a071 100644 --- a/src/Native/src/runtime/interpreter.cpp +++ b/src/Native/src/runtime/interpreter.cpp @@ -246,6 +246,17 @@ result interpreter::find_module_by_id(size_t index) noexcept { return ok(modules_[index].get()); } +result interpreter::find_id_by_module(runtime_module *module) noexcept { + auto it = std::find_if(modules_.begin(), modules_.end(), + [&module](const std::unique_ptr &p) { + return p.get() == module; + }); + if (it == modules_.end()) { + return err(std::errc::result_out_of_range); + } + return ok((it - modules_.begin())); +} + options_dict &interpreter::options() noexcept { return options_; } result interpreter::entry_function() noexcept { diff --git a/src/Native/src/runtime/runtime_module.cpp b/src/Native/src/runtime/runtime_module.cpp index 4b5d747bbd..66acaca61b 100644 --- a/src/Native/src/runtime/runtime_module.cpp +++ b/src/Native/src/runtime/runtime_module.cpp @@ -189,6 +189,19 @@ runtime_module::find_function_by_id(size_t index) noexcept { return ok(functions_[index].get()); } +result +runtime_module::find_id_by_function(runtime_function *function) noexcept { + auto it = + std::find_if(functions_.begin(), functions_.end(), + [&function](const std::unique_ptr &p) { + return p.get() == function; + }); + if (it == functions_.end()) { + return err(std::errc::result_out_of_range); + } + return ok((it - functions_.begin())); +} + result runtime_module::initialize_before_functions( NNCASE_UNUSED runtime_module_init_context &context) noexcept { return ok(); diff --git a/src/Native/src/runtime/stackvm/op_reader.cpp b/src/Native/src/runtime/stackvm/op_reader.cpp index 901f0f6125..4cb7f0e36a 100644 --- a/src/Native/src/runtime/stackvm/op_reader.cpp +++ b/src/Native/src/runtime/stackvm/op_reader.cpp @@ -1,5 +1,5 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 - * +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM + * +00:00. * * Copyright 2019-2021 Canaan Inc. * diff --git a/src/Native/src/runtime/stackvm/ops/tensor.cpp b/src/Native/src/runtime/stackvm/ops/tensor.cpp index 6f09a7084c..3172fd0f90 100644 --- a/src/Native/src/runtime/stackvm/ops/tensor.cpp +++ b/src/Native/src/runtime/stackvm/ops/tensor.cpp @@ -1,5 +1,5 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 - * +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM + * +00:00. * * Copyright 2019-2021 Canaan Inc. * @@ -207,9 +207,7 @@ result stackvm_runtime_function::visit( dump_op("concat"); try_var(input, pop_value()); dump_input(input); - try_var(axis, pop_value()); - dump_input(axis); - try_var(output, kernels::stackvm::concat(input, axis, nullptr, + try_var(output, kernels::stackvm::concat(op.axis, input, nullptr, module().kernel_context())); dump_output(output); stack_.push(std::move(output)); @@ -491,11 +489,9 @@ result stackvm_runtime_function::visit( dump_op("gather"); try_var(input, pop_value()); dump_input(input); - try_var(axis, pop_value()); - dump_input(axis); try_var(index, pop_value()); dump_input(index); - try_var(output, kernels::stackvm::gather(input, axis, index, nullptr, + try_var(output, kernels::stackvm::gather(op.axis, input, index, nullptr, module().kernel_context())); dump_output(output); stack_.push(std::move(output)); @@ -683,9 +679,9 @@ result stackvm_runtime_function::visit( dump_input(scale); try_var(bias, pop_value()); dump_input(bias); - try_var(output, kernels::stackvm::layer_norm(op.axis, op.epsilon, input, - scale, bias, nullptr, - module().kernel_context())); + try_var(output, kernels::stackvm::layer_norm( + op.axis, op.epsilon, op.use_mean, input, scale, bias, + nullptr, module().kernel_context())); dump_output(output); stack_.push(std::move(output)); return ok(); diff --git a/src/Native/src/runtime/stackvm/runtime_function_ops.h b/src/Native/src/runtime/stackvm/runtime_function_ops.h index ae6944ef59..351b758b88 100644 --- a/src/Native/src/runtime/stackvm/runtime_function_ops.h +++ b/src/Native/src/runtime/stackvm/runtime_function_ops.h @@ -1,5 +1,5 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 - * +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM + * +00:00. * * Copyright 2019-2021 Canaan Inc. * diff --git a/src/Native/src/test_cli.cpp b/src/Native/src/test_cli.cpp index 7f703a216a..3e95be5a64 100644 --- a/src/Native/src/test_cli.cpp +++ b/src/Native/src/test_cli.cpp @@ -12,6 +12,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -19,6 +20,8 @@ using namespace nncase; using namespace nncase::runtime; +// constexpr size_t loop_count = 10; +constexpr size_t loop_count = 1; #define TRY(x) \ if (x) \ @@ -34,8 +37,7 @@ result write_tensor_buffer(value_t value, std::ofstream &of) { } result run_core(const std::string &kmodel_path, - const std::vector &input_bins, - const std::string &output_bin) { + const std::vector &bins) { auto kmodel = read_file(kmodel_path); interpreter *interp = new interpreter(); // auto dump_path = @@ -47,16 +49,16 @@ result run_core(const std::string &kmodel_path, try_var(entry, interp->entry_function()); - if (entry->parameters_size() != input_bins.size()) + if (entry->parameters_size() > bins.size()) return err(std::errc::argument_list_too_long); /* create the input parameters tensor note the input tenosr must be contiguous */ std::vector parameters; - for (int i = 0; i < input_bins.size(); i++) { + for (int i = 0; i < entry->parameters_size(); i++) { try_var(type, entry->parameter_type(i)); try_var(ts_type, type.as()); - auto input_pool = read_file(input_bins[i]); + auto input_pool = read_file(bins[i]); gsl::span input_pool_span = { reinterpret_cast(input_pool.data()), input_pool.size()}; @@ -66,21 +68,40 @@ result run_core(const std::string &kmodel_path, parameters.push_back(_.impl()); } - try_var(ret, entry->invoke({parameters.data(), parameters.size()})); + double total_time = 0.0; + for (size_t i = 0; i < loop_count; i++) { + auto start_time = std::chrono::steady_clock::now(); + try_var(ret, entry->invoke({parameters.data(), parameters.size()})); + auto end_time = std::chrono::steady_clock::now(); + total_time += (std::chrono::duration_cast( + end_time - start_time) + .count() / + 1e6); - std::ofstream output_stream(output_bin, std::ios::binary); - - if (ret.is_a()) { - try_(write_tensor_buffer(ret, output_stream)); - } else if (ret.is_a()) { - try_var(tp, ret.as()); - for (auto &&ret_v : tp->fields()) { - try_(write_tensor_buffer(ret_v, output_stream)); + if (i == (loop_count - 1) && (entry->parameters_size() < bins.size())) { + if (ret.is_a()) { + auto output_bin = bins.back(); + std::ofstream output_stream(output_bin, std::ios::binary); + try_(write_tensor_buffer(ret, output_stream)); + output_stream.close(); + } else if (ret.is_a()) { + try_var(tp, ret.as()); + auto o = 0; + for (auto &&ret_v : tp->fields()) { + auto output_bin = bins[entry->parameters_size() + (o++)]; + std::ofstream output_stream(output_bin, std::ios::binary); + try_(write_tensor_buffer(ret_v, output_stream)); + output_stream.close(); + } + } else { + return nncase::err(std::errc::bad_message); + } } - } else { - return nncase::err(std::errc::bad_message); } - output_stream.close(); + + std::cout << "interp run: " << (total_time / loop_count) + << " ms, fps = " << 1000 / (total_time / loop_count) << std::endl; + return ok(); } @@ -92,13 +113,12 @@ result run_core(const std::string &kmodel_path, * @return int */ int main(NNCASE_UNUSED int argc, char **argv) { - assert(argc >= 4); - std::vector input_bins; - for (int i = 2; i < argc - 1; i++) { - input_bins.push_back(argv[i]); + assert(argc >= 3); + std::vector bins; + for (int i = 2; i < argc; i++) { + bins.push_back(argv[i]); } std::string kmodel_bin(argv[1]); - std::string output_bin(argv[argc - 1]); - run_core(kmodel_bin, input_bins, output_bin).unwrap_or_throw(); + run_core(kmodel_bin, bins).unwrap_or_throw(); return 0; -} +} \ No newline at end of file diff --git a/src/Nncase.Cli/Commands/Compile.cs b/src/Nncase.Cli/Commands/Compile.cs deleted file mode 100644 index edb64f2b3e..0000000000 --- a/src/Nncase.Cli/Commands/Compile.cs +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.CommandLine; -using System.CommandLine.Invocation; -using System.IO; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; -using Nncase.CodeGen; -using Nncase.Compiler; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.Passes; -using Nncase.Quantization; - -namespace Nncase.Cli.Commands; - -internal enum QuantType -{ - UInt8, - Int8, - Int16, -} - -internal enum DatasetFormat -{ - Image, - Raw, - Pytest, - Random, -} - -/// -/// Compile command. -/// -public sealed class Compile : Command -{ - /// - /// Initializes a new instance of the class. - /// - public Compile() - : base("compile") - { - AddArgument(new Argument("input-file")); - AddArgument(new Argument("output-file")); - AddOption(new Option( - aliases: new string[] { "-t", "--target" }, - description: "target architecture, e.g. cpu, k210")); - AddOption(new Option( - aliases: new[] { "-i", "--input-format" }, - description: "input format, e.g. tflite", - getDefaultValue: () => "tflite")); - AddOption(new Option( - alias: "--dump-level", - description: $"dump ir to .il, default is {0}", - getDefaultValue: () => 0)); - AddOption(new Option( - alias: "--dump-dir", - description: "dump to directory, default is .", - getDefaultValue: () => ".")); - AddOption(new Option( - alias: "--quant-type", - description: $"quant type, default is {QuantType.UInt8}", - getDefaultValue: () => QuantType.UInt8)); - AddOption(new Option( - alias: "--wquant-type", - description: $"wquant type, default is {QuantType.UInt8}", - getDefaultValue: () => QuantType.UInt8)); - AddOption(new Option( - alias: "--dataset", - description: $"calibration dataset, used in post quantization, default is empty", - getDefaultValue: () => string.Empty)); - AddOption(new Option( - alias: "--dataset-format", - description: $"datset format: e.g. Image|Raw|Pytest", - getDefaultValue: () => DatasetFormat.Raw)); - AddOption(new Option( - alias: "--model-quant-mode", - description: $"model quant mode, default is {Quantization.ModelQuantMode.NoQuant}", - getDefaultValue: () => Quantization.ModelQuantMode.NoQuant)); - AddOption(new Option( - alias: "--calib-method", - description: $"model quant options, default is {Quantization.CalibMethod.Kld}", - getDefaultValue: () => Quantization.CalibMethod.Kld)); - AddOption(new Option( - alias: "--pre-process", - description: "whether enable pre process, default is False", - getDefaultValue: () => false)); - AddOption(new Option( - alias: "--input-layout", - description: "the model input data layout, default is empty. eg. NCHW/NHWC", - getDefaultValue: () => string.Empty)); - AddOption(new Option( - alias: "--output-layout", - description: "the model output data layout, default is empty. eg. NCHW/NHWC", - getDefaultValue: () => string.Empty)); - AddOption(new Option( - alias: "--input-type", - description: "the model input data value type, default is Float32", - getDefaultValue: () => InputType.Float32)); - AddOption(new Option>( - alias: "--input-shape", - description: "the model input data shape, default is []. eg. `--input-shape 1 2 3 4`", - getDefaultValue: () => Array.Empty())); - AddOption(new Option>( - alias: "--input-range", - description: "the model input data value range, default is []. eg `--input-range -100.3 200.4`", - getDefaultValue: () => Array.Empty())); - AddOption(new Option( - alias: "--swap-rb", - description: "whether swap the model input data channel R and B", - getDefaultValue: () => false)); - AddOption(new Option( - alias: "--letter-box-value", - description: "letterbox value, default 0.0", - getDefaultValue: () => 0.0f)); - AddOption(new Option>( - alias: "--mean", - description: "the model input data mean, default []", - getDefaultValue: () => Array.Empty())); - AddOption(new Option>( - alias: "--std", - description: "the model input data std, default []", - getDefaultValue: () => Array.Empty())); - AddOption(new Option( - alias: "--model-layout", - description: "the model's input layout, default is empty. eg. NCHW/NHWC", - getDefaultValue: () => string.Empty)); - AddOption(new Option( - alias: "--benchmark-only", - description: $"benchmark only", - getDefaultValue: () => false)); - - Handler = CommandHandler.Create(RunAsync); - } - - private static DumpFlags DumpLevelToFlags(int dumpLevel) - { - return dumpLevel switch - { - 0 => DumpFlags.None, - 1 => DumpLevelToFlags(0) | DumpFlags.Compile, - 2 => DumpLevelToFlags(1) | DumpFlags.PassIR, - 3 => DumpLevelToFlags(2) | DumpFlags.Rewrite, - 4 => DumpLevelToFlags(3) | DumpFlags.EGraphCost, - 5 => DumpLevelToFlags(4) | DumpFlags.Evaluator, - 6 => DumpLevelToFlags(5) | DumpFlags.Calibration, - 7 => DumpLevelToFlags(6) | DumpFlags.Tiling, - 8 => DumpLevelToFlags(7) | DumpFlags.Schedule, - >= 9 => DumpLevelToFlags(8) | DumpFlags.CodeGen, - _ => throw new ArgumentOutOfRangeException(nameof(dumpLevel)), - }; - } - - private async Task RunAsync(CliCompileOptions cliOptions, IHost host) - { - CompilerServices.Configure(host.Services); - - // 1. setup the options - var compileOptions = new CompileOptions - { - InputFile = cliOptions.InputFile, - InputFormat = cliOptions.InputFormat, - DumpFlags = DumpLevelToFlags(cliOptions.DumpLevel), - DumpDir = cliOptions.DumpDir, - QuantizeOptions = new() - { - CalibrationMethod = cliOptions.CalibMethod, - QuantType = cliOptions.QuantType switch - { - QuantType.UInt8 => DataTypes.UInt8, - QuantType.Int8 => DataTypes.Int8, - QuantType.Int16 => DataTypes.Int16, - _ => throw new ArgumentException("Invalid quant type"), - }, - WQuantType = cliOptions.WQuantType switch - { - QuantType.UInt8 => DataTypes.UInt8, - QuantType.Int8 => DataTypes.Int8, - QuantType.Int16 => DataTypes.Int16, - _ => throw new ArgumentException("Invalid weights quant type"), - }, - ModelQuantMode = cliOptions.ModelQuantMode, - }, - PreProcess = cliOptions.PreProcess, - InputLayout = cliOptions.InputLayout, - OutputLayout = cliOptions.OutputLayout, - InputType = cliOptions.InputType, - InputShape = cliOptions.InputShape.ToArray(), - InputRange = cliOptions.InputRange.ToArray(), - SwapRB = cliOptions.SwapRB, - LetterBoxValue = cliOptions.LetterBoxValue, - Mean = cliOptions.Mean.ToArray(), - Std = cliOptions.Std.ToArray(), - ModelLayout = cliOptions.ModelLayout, - IsBenchmarkOnly = cliOptions.BenchmarkOnly, - }; - - // 2. import the model - var target = CompilerServices.GetTarget(cliOptions.Target); - using var compileSession = CompileSession.Create(target, compileOptions); - var compiler = compileSession.Compiler; - var module = await compiler.ImportModuleAsync(compileOptions.InputFormat, compileOptions.InputFile, compileOptions.IsBenchmarkOnly); - - // 3. create the calib dataset - if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ) - { - if (cliOptions.DatasetFormat == DatasetFormat.Random) - { - compileOptions.QuantizeOptions.CalibrationDataset = new RandomCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), 5); - } - else if (cliOptions.DatasetFormat == DatasetFormat.Pytest) - { - compileOptions.QuantizeOptions.CalibrationDataset = new PytestCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), cliOptions.Dataset); - } - else - { - throw new NotSupportedException(cliOptions.DatasetFormat.ToString()); - } - } - - // 4. compile - await compiler.CompileAsync(); - - // 5. code gen - using (var os = File.OpenWrite(cliOptions.OutputFile)) - { - compiler.Gencode(os); - } - } -} - -// Validate null in command line parser. -#pragma warning disable CS8618 - -internal sealed class CliCompileOptions -{ - public string InputFile { get; set; } - - public string InputFormat { get; set; } - - public string Target { get; set; } - - public int DumpLevel { get; set; } - - public string DumpDir { get; set; } - - public QuantType QuantType { get; set; } - - public QuantType WQuantType { get; set; } - - public string OutputFile { get; set; } - - public ModelQuantMode ModelQuantMode { get; set; } - - public CalibMethod CalibMethod { get; set; } - - public string Dataset { get; set; } - - public DatasetFormat DatasetFormat { get; set; } - - public bool BenchmarkOnly { get; set; } - - public bool PreProcess { get; set; } - - public string InputLayout { get; set; } - - public string OutputLayout { get; set; } - - public InputType InputType { get; set; } - - public List InputShape { get; set; } - - public List InputRange { get; set; } - - public bool SwapRB { get; set; } - - public float LetterBoxValue { get; set; } - - public List Mean { get; set; } - - public List Std { get; set; } - - public string ModelLayout { get; set; } -} - -#pragma warning restore CS8618 diff --git a/src/Nncase.Cli/Compile.cs b/src/Nncase.Cli/Compile.cs new file mode 100644 index 0000000000..e7f65a5303 --- /dev/null +++ b/src/Nncase.Cli/Compile.cs @@ -0,0 +1,217 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.CommandLine; +using System.Linq; +using Nncase.Diagnostics; +using Nncase.Quantization; + +namespace Nncase.Cli; + +internal enum QuantType +{ + UInt8, + Int8, + Int16, +} + +internal enum DatasetFormat +{ + Image, + Raw, + Pytest, + Random, +} + +/// +/// Compile command. +/// +internal sealed class CompileCommand : Command +{ + /// + /// Initializes a new instance of the class. + /// + public CompileCommand() + : base("compile") + { + InputFile = new Argument("input-file"); + OutputFile = new Argument("output-file"); + InputFormat = new Option( + aliases: new[] { "-i", "--input-format" }, + description: "input format, e.g. tflite", + getDefaultValue: () => "tflite"); + DumpFlags = new Option>( + name: "--dump-flags", + description: "dump ir flags. \navailable value: None,ImportOps,PassIR,EGraphCost,Rewrite,Calibration,Evaluator,Compile,Tiling,Schedule,CodeGen.") + { + AllowMultipleArgumentsPerToken = true, + }; + DumpDir = new Option( + name: "--dump-dir", + description: "dump to directory.", + getDefaultValue: () => "."); + QuantType = new Option( + name: "--quant-type", + description: $"quant type", + getDefaultValue: () => Nncase.Cli.QuantType.UInt8); + WQuantType = new Option( + name: "--wquant-type", + description: $"wquant type", + getDefaultValue: () => Nncase.Cli.QuantType.UInt8); + Dataset = new Option( + name: "--dataset", + description: $"calibration dataset, used in post quantization", + getDefaultValue: () => string.Empty); + DatasetFormat = new Option( + name: "--dataset-format", + description: $"datset format.", + getDefaultValue: () => Nncase.Cli.DatasetFormat.Raw); + ModelQuantMode = new Option( + name: "--model-quant-mode", + description: $"model quant mode", + getDefaultValue: () => Quantization.ModelQuantMode.NoQuant); + CalibMethod = new Option( + name: "--calib-method", + description: $"model quant options", + getDefaultValue: () => Quantization.CalibMethod.Kld); + FixedVars = new Option>( + name: "--fixed-vars", + description: $"dynamic shape fixed vars, default is empty. \nset by `n:123`", + parseArgument: result => + { + return result.Tokens. + Select(tk => tk.Value.Split(":").ToArray()). + Select(tp => (tp[0].Trim(), int.Parse(tp[1].Trim()))); + }) + { + AllowMultipleArgumentsPerToken = true, + }; + PreProcess = new Option( + name: "--pre-process", + description: "whether enable pre process", + getDefaultValue: () => false); + InputLayout = new Option( + name: "--input-layout", + description: "the model input data layout", + getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC"); + OutputLayout = new Option( + name: "--output-layout", + description: "the model output data layout.", + getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC"); + InputType = new Option( + name: "--input-type", + description: "the model input data value type, default is Float32", + getDefaultValue: () => Nncase.InputType.Float32); + InputShape = new Option>( + name: "--input-shape", + description: "the model input data shape. eg. `--input-shape 1 2 3 4`", + getDefaultValue: Array.Empty) + { + AllowMultipleArgumentsPerToken = true, + }; + InputRange = new Option>( + name: "--input-range", + description: "the model input data value range. eg `--input-range -100.3 200.4`", + getDefaultValue: Array.Empty) + { + AllowMultipleArgumentsPerToken = true, + }; + SwapRB = new Option( + name: "--swap-rb", + description: "whether swap the model input data channel, like cv2.BGRtoRGB(im)", + getDefaultValue: () => false); + LetterBoxValue = new Option( + name: "--letter-box-value", + description: "letterbox fill value", + getDefaultValue: () => 0.0f); + Mean = new Option>( + name: "--mean", + description: "the model input data mean, default []", + getDefaultValue: Array.Empty) + { + AllowMultipleArgumentsPerToken = true, + }; + Std = new Option>( + name: "--std", + description: "the model input data std, default []", + getDefaultValue: Array.Empty) + { + AllowMultipleArgumentsPerToken = true, + }; + ModelLayout = new Option( + name: "--model-layout", + description: "the model's input layout.", + getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC"); + AddArgument(InputFile); + AddArgument(OutputFile); + AddGlobalOption(InputFormat); + AddGlobalOption(DumpFlags); + AddGlobalOption(DumpDir); + AddGlobalOption(QuantType); + AddGlobalOption(WQuantType); + AddGlobalOption(Dataset); + AddGlobalOption(DatasetFormat); + AddGlobalOption(ModelQuantMode); + AddGlobalOption(CalibMethod); + AddGlobalOption(FixedVars); + AddGlobalOption(PreProcess); + AddGlobalOption(InputLayout); + AddGlobalOption(OutputLayout); + AddGlobalOption(InputType); + AddGlobalOption(InputShape); + AddGlobalOption(InputRange); + AddGlobalOption(SwapRB); + AddGlobalOption(LetterBoxValue); + AddGlobalOption(Mean); + AddGlobalOption(Std); + AddGlobalOption(ModelLayout); + } + + public Argument InputFile { get; } + + public Argument OutputFile { get; } + + public Option InputFormat { get; } + + public Option> DumpFlags { get; } + + public Option DumpDir { get; } + + public Option QuantType { get; } + + public Option WQuantType { get; } + + public Option Dataset { get; } + + public Option DatasetFormat { get; } + + public Option ModelQuantMode { get; } + + public Option CalibMethod { get; } + + public Option> FixedVars { get; } + + public Option PreProcess { get; } + + public Option InputLayout { get; } + + public Option OutputLayout { get; } + + public Option InputType { get; } + + public Option> InputShape { get; } + + public Option> InputRange { get; } + + public Option SwapRB { get; } + + public Option LetterBoxValue { get; } + + public Option> Mean { get; } + + public Option> Std { get; } + + public Option ModelLayout { get; } +} diff --git a/src/Nncase.Cli/Nncase.Cli.csproj b/src/Nncase.Cli/Nncase.Cli.csproj index 3070806ba5..17e370d4ca 100644 --- a/src/Nncase.Cli/Nncase.Cli.csproj +++ b/src/Nncase.Cli/Nncase.Cli.csproj @@ -26,4 +26,8 @@ PreserveNewest + + + + diff --git a/src/Nncase.Cli/Program.CommandLine.cs b/src/Nncase.Cli/Program.CommandLine.cs deleted file mode 100644 index e88f691647..0000000000 --- a/src/Nncase.Cli/Program.CommandLine.cs +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.CommandLine; -using System.CommandLine.Builder; -using System.Linq; - -namespace Nncase.Cli; - -internal partial class Program -{ - private static CommandLineBuilder BuildCommandLine() - { - var commands = from t in typeof(Program).Assembly.ExportedTypes - where t.Namespace == "Nncase.Cli.Commands" && t.IsAssignableTo(typeof(Command)) - select (Command)Activator.CreateInstance(t)!; - var root = new RootCommand(); - foreach (var command in commands) - { - root.AddCommand(command); - } - - return new CommandLineBuilder(root); - } -} diff --git a/src/Nncase.Cli/Program.cs b/src/Nncase.Cli/Program.cs index 2d3e40c803..0ef25c0565 100644 --- a/src/Nncase.Cli/Program.cs +++ b/src/Nncase.Cli/Program.cs @@ -1,10 +1,14 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System; +using System.Collections.Generic; +using System.CommandLine; using System.CommandLine.Builder; using System.CommandLine.Hosting; using System.CommandLine.Parsing; using System.IO; +using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; @@ -16,12 +20,155 @@ internal partial class Program { public static async Task Main(string[] args) { - return await BuildCommandLine() + return await ConfigureCommandLine() .UseHost(ConfigureHost) .UseDefaults() .Build().InvokeAsync(args); } + private static async Task RunAsync(string targetKind, CompileOptions compileOptions, DatasetFormat datasetFormat, string dataset, string outputFile, IHost host) + { + CompilerServices.Configure(host.Services); + + // 2. import the model + var target = CompilerServices.GetTarget(targetKind); + using var compileSession = CompileSession.Create(target, compileOptions); + var compiler = compileSession.Compiler; + IR.IRModule module = await compiler.ImportModuleAsync(Path.GetExtension(compileOptions.InputFile).Trim('.'), compileOptions.InputFile); + + // 3. create the calib dataset + if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ) + { + if (datasetFormat == DatasetFormat.Random) + { + compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.RandomCalibrationDatasetProvider(((Nncase.IR.Function)module.Entry!).Parameters.ToArray(), 5); + } + else if (datasetFormat == DatasetFormat.Pytest) + { + compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.PytestCalibrationDatasetProvider(((IR.Function)module.Entry!).Parameters.ToArray(), dataset); + } + else + { + throw new NotSupportedException(datasetFormat.ToString()); + } + } + + // 4. compile + await compiler.CompileAsync(); + + // 5. code gen + using (var os = File.OpenWrite(outputFile)) + { + compiler.Gencode(os); + } + } + + private static CommandLineBuilder ConfigureCommandLine() + { + var compile = new CompileCommand(); + foreach (var target in LoadTargets()) + { + var (targetCmd, targetParser) = target.RegisterCommandAndParser(); + Action targetHandler = async (System.CommandLine.Invocation.InvocationContext context) => + { + var options = ParseCompileOptions(context, compile); + options.TargetCompileOptions = targetParser(context, targetCmd); + await RunAsync(targetCmd.Name, options, context.ParseResult.GetValueForOption(compile.DatasetFormat), context.ParseResult.GetValueForOption(compile.Dataset)!, context.ParseResult.GetValueForArgument(compile.OutputFile), context.GetHost()); + }; + targetCmd.SetHandler(targetHandler); + compile.AddCommand(targetCmd); + } + + return new CommandLineBuilder(new RootCommand() { compile }); + } + + private static CompileOptions ParseCompileOptions(System.CommandLine.Invocation.InvocationContext context, CompileCommand compilecmd) + { + // 1. setup the options + var compileOptions = new CompileOptions + { + InputFile = context.ParseResult.GetValueForArgument(compilecmd.InputFile), + InputFormat = context.ParseResult.GetValueForOption(compilecmd.InputFormat)!, + DumpFlags = context.ParseResult.GetValueForOption(compilecmd.DumpFlags)!.Aggregate(Diagnostics.DumpFlags.None, (a, b) => a | b), + DumpDir = context.ParseResult.GetValueForOption(compilecmd.DumpDir)!, + PreProcess = context.ParseResult.GetValueForOption(compilecmd.PreProcess)!, + InputLayout = context.ParseResult.GetValueForOption(compilecmd.InputLayout)!, + OutputLayout = context.ParseResult.GetValueForOption(compilecmd.OutputLayout)!, + InputType = context.ParseResult.GetValueForOption(compilecmd.InputType)!, + InputShape = context.ParseResult.GetValueForOption(compilecmd.InputShape)!.ToArray(), + InputRange = context.ParseResult.GetValueForOption(compilecmd.InputRange)!.ToArray(), + SwapRB = context.ParseResult.GetValueForOption(compilecmd.SwapRB)!, + LetterBoxValue = context.ParseResult.GetValueForOption(compilecmd.LetterBoxValue)!, + Mean = context.ParseResult.GetValueForOption(compilecmd.Mean)!.ToArray(), + Std = context.ParseResult.GetValueForOption(compilecmd.Std)!.ToArray(), + ModelLayout = context.ParseResult.GetValueForOption(compilecmd.ModelLayout)!, + QuantizeOptions = new() + { + CalibrationMethod = context.ParseResult.GetValueForOption(compilecmd.CalibMethod), + QuantType = context.ParseResult.GetValueForOption(compilecmd.QuantType) switch + { + QuantType.UInt8 => DataTypes.UInt8, + QuantType.Int8 => DataTypes.Int8, + QuantType.Int16 => DataTypes.Int16, + _ => throw new ArgumentException("Invalid quant type"), + }, + WQuantType = context.ParseResult.GetValueForOption(compilecmd.WQuantType) switch + { + QuantType.UInt8 => DataTypes.UInt8, + QuantType.Int8 => DataTypes.Int8, + QuantType.Int16 => DataTypes.Int16, + _ => throw new ArgumentException("Invalid weights quant type"), + }, + ModelQuantMode = context.ParseResult.GetValueForOption(compilecmd.ModelQuantMode), + }, + }; + + foreach (var item in context.ParseResult.GetValueForOption(compilecmd.FixedVars)!) + { + compileOptions.ShapeBucketOptions.FixVarMap.Add(item.Name, item.Value); + } + + return compileOptions; + } + + private static IReadOnlyList LoadTargets() + { + var loadContext = System.Runtime.Loader.AssemblyLoadContext.Default; + var pluginAsms = PluginLoader.GetPluginsSearchDirectories(PluginLoader.PluginPathEnvName, null). + Select(PluginLoader.GetPluginAssemblies). + SelectMany(x => x). + DistinctBy(Path.GetFileName). + Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)). + Distinct(). + ToList(); + pluginAsms.AddRange(new[] { Path.GetDirectoryName(typeof(Program).Assembly.Location)! }. + Select(basePath => + { + if (Directory.Exists(basePath)) + { + return (from filePath in Directory.GetFiles(basePath, PluginLoader.ModulesDllPattern, SearchOption.AllDirectories) + where PluginLoader.IsLoadableAssembly(filePath) + select filePath).Distinct(); + } + else + { + return Array.Empty(); + } + }). + SelectMany(x => x). + DistinctBy(Path.GetFileName). + Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)). + Distinct()); + var targets = (from asm in pluginAsms + from t in asm.ExportedTypes + where t.IsClass + && t.IsAssignableTo(typeof(ITarget)) + let ctor = t.GetConstructor(Type.EmptyTypes) + where ctor != null + select (ITarget)ctor.Invoke(null)).ToList(); + return targets; + } + private static void ConfigureHost(IHostBuilder hostBuilder) { hostBuilder.ConfigureAppConfiguration(ConfigureAppConfiguration) diff --git a/src/Nncase.Cli/packages.lock.json b/src/Nncase.Cli/packages.lock.json index b438ba9cc6..56eaafb112 100644 --- a/src/Nncase.Cli/packages.lock.json +++ b/src/Nncase.Cli/packages.lock.json @@ -33,21 +33,22 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "System.CommandLine.Hosting": { "type": "Direct", - "requested": "[0.3.0-alpha.21216.1, )", - "resolved": "0.3.0-alpha.21216.1", - "contentHash": "zP8QEUH8dSUYUHdGk6k71kOJy8uFgEPZG2RfhA0cMjDH3/Jov5AjUNaxOvpSNHh+ewu8eIUCYgV8+fEkCPyNlw==", + "requested": "[0.4.0-alpha.22272.1, )", + "resolved": "0.4.0-alpha.22272.1", + "contentHash": "x9JhHxBLxlKyCIZADFYC8q16L9yGHdTakrLFjHabwR7Tk0761aTexiGgMTIS744HGuhc8pk9MoLUzsr/TlRfMQ==", "dependencies": { - "Microsoft.Extensions.Hosting": "3.1.5", - "System.CommandLine": "2.0.0-beta1.21216.1" + "Microsoft.Extensions.Hosting": "6.0.0", + "System.CommandLine": "2.0.0-beta4.22272.1", + "System.CommandLine.NamingConventionBinder": "2.0.0-beta4.22272.1" } }, "Google.OrTools.runtime.linux-arm64": { @@ -344,8 +345,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -362,13 +363,12 @@ "System.Runtime": "4.3.0" } }, - "System.CommandLine": { + "System.CommandLine.NamingConventionBinder": { "type": "Transitive", - "resolved": "2.0.0-beta1.21216.1", - "contentHash": "Nbv/tW8sbOKN5T+4SSVBMdk4ADSIpJpY4UHMsj3VkcNtOckIT4iyzagjF+W5FEh2YBRvmvVQijOTIZbUJ1+1aA==", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "ux2eUA/syF+JtlpMDc/Lsd6PBIBuwjH3AvHnestoh5uD0WKT5b+wkQxDWVCqp9qgVjMBTLNhX19ZYFtenunt9A==", "dependencies": { - "Microsoft.CSharp": "4.4.1", - "system.memory": "4.5.4" + "System.CommandLine": "2.0.0-beta4.22272.1" } }, "System.Diagnostics.Contracts": { @@ -696,6 +696,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -937,6 +938,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Linq.Async": { "type": "CentralTransitive", "requested": "[6.0.1, )", diff --git a/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs b/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs index b3dbe692e0..bd4d97cafc 100644 --- a/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs +++ b/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs @@ -15,7 +15,6 @@ public class LinkedFunction : ILinkedFunction public LinkedFunction(uint id, Callable sourceFunction, ulong textBegin, ulong textLength, IReadOnlyList sections) { Id = id; - CompilerServices.InferenceType(sourceFunction); ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray(); ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType; TextBegin = textBegin; diff --git a/src/Nncase.CodeGen/packages.lock.json b/src/Nncase.CodeGen/packages.lock.json index b618b5504c..fd39ebd2fc 100644 --- a/src/Nncase.CodeGen/packages.lock.json +++ b/src/Nncase.CodeGen/packages.lock.json @@ -10,11 +10,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.Extensions.Configuration.Abstractions": { @@ -53,8 +53,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -76,6 +76,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -138,6 +139,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs index 9895f6e0b5..afd0607878 100644 --- a/src/Nncase.Compiler/Compiler.cs +++ b/src/Nncase.Compiler/Compiler.cs @@ -88,13 +88,15 @@ public void AddPreAndPostProcess(IPassManager passManager) public void TargetIndependentPass(IPassManager passManager) { - passManager.AddWithName("ReshapeMatMul").Configure(p => + passManager.AddWithName("NormAxisAndShape").Configure(p => { p.Add(); - }); - - passManager.AddWithName("SqueezeShape").Configure(p => - { + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); p.Add(); p.Add(); p.Add(); @@ -102,6 +104,7 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); p.Add(); p.Add(); p.Add(); @@ -124,6 +127,7 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); p.Add(); p.Add(); p.Add(); @@ -157,6 +161,8 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); + p.Add(); p.Add(); p.Add(); p.Add(); @@ -168,6 +174,7 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); p.Add(); p.Add(); p.Add(); diff --git a/src/Nncase.Compiler/Hosting/PluginLoader.cs b/src/Nncase.Compiler/Hosting/PluginLoader.cs index 73f10f367a..014ab29fb1 100644 --- a/src/Nncase.Compiler/Hosting/PluginLoader.cs +++ b/src/Nncase.Compiler/Hosting/PluginLoader.cs @@ -19,12 +19,14 @@ namespace Nncase.Hosting; /// public sealed class PluginLoader { - private const string _modulesDllPattern = "Nncase.Modules.*.dll"; - private const string _pluginPathEnvName = "NNCASE_PLUGIN_PATH"; + public const string PluginPathEnvName = "NNCASE_PLUGIN_PATH"; + + public const string ModulesDllPattern = "Nncase.Modules.*.dll"; private static readonly string[] _builtinModules = new[] { "Nncase.Modules.StackVM.dll", + "Nncase.Modules.CPU.dll", "Nncase.Modules.K210.dll", }; @@ -42,67 +44,16 @@ public PluginLoader(ILogger logger) ?? AssemblyLoadContext.Default; } - /// - /// Load plugins. - /// - /// Plugins. - public IReadOnlyList LoadPlugins() - { - var pluginAsms = GetPluginsSearchDirectories().Select(GetPluginAssemblies).SelectMany(x => x) - .DistinctBy(Path.GetFileName).Select(LoadPluginAssembly).Distinct().ToList(); - var plugins = (from asm in pluginAsms - from t in asm.ExportedTypes - where t.IsClass - && t.IsAssignableTo(typeof(IPlugin)) - let ctor = t.GetConstructor(Type.EmptyTypes) - where ctor != null - select (IPlugin)ctor.Invoke(null)).ToList(); - - return plugins; - } - - private static bool IsLoadableAssembly(string filePath) - { - using var fs = File.OpenRead(filePath); - using var peReader = new PEReader(fs); - - if (!peReader.HasMetadata) - { - return false; - } - - var metaReader = peReader.GetMetadataReader(); - if (!metaReader.IsAssembly) - { - return false; - } - - // Is reference assembly - if ((from cah in metaReader.CustomAttributes - let ca = metaReader.GetCustomAttribute(cah) - where ca.Constructor.Kind == HandleKind.MemberReference - let ctor = metaReader.GetMemberReference((MemberReferenceHandle)ca.Constructor) - let attrType = metaReader.GetTypeReference((TypeReferenceHandle)ctor.Parent) - where metaReader.GetString(attrType.Namespace) == nameof(System.Runtime.CompilerServices) - && metaReader.GetString(attrType.Name) == nameof(ReferenceAssemblyAttribute) - select cah).Any()) - { - return false; - } - - return true; - } - - private Assembly LoadPluginAssembly(string assemblyFile) + public static Assembly LoadPluginAssembly(string assemblyFile, AssemblyLoadContext loadContext) { - return _loadContext.LoadFromAssemblyPath(assemblyFile); + return loadContext.LoadFromAssemblyPath(assemblyFile); } - private IEnumerable GetPluginAssemblies(string basePath) + public static IEnumerable GetPluginAssemblies(string basePath) { if (Directory.Exists(basePath)) { - return (from filePath in Directory.GetFiles(basePath, _modulesDllPattern, SearchOption.AllDirectories) + return (from filePath in Directory.GetFiles(basePath, ModulesDllPattern, SearchOption.AllDirectories) where !_builtinModules.Contains(Path.GetFileName(filePath)) && IsLoadableAssembly(filePath) select filePath).Distinct(); @@ -113,19 +64,22 @@ private IEnumerable GetPluginAssemblies(string basePath) } } - private IEnumerable GetPluginsSearchDirectories() + public static IEnumerable GetPluginsSearchDirectories(string pluginPathEnvName, ILogger? logger) { var directories = new List(); // 1. Environment variable - var targetPathEnv = Environment.GetEnvironmentVariable(_pluginPathEnvName); + var targetPathEnv = Environment.GetEnvironmentVariable(pluginPathEnvName); if (string.IsNullOrWhiteSpace(targetPathEnv)) { - _logger.LogWarning($"{_pluginPathEnvName} is not set."); + if (logger is not null) + { + logger.LogWarning($"{pluginPathEnvName} is not set."); + } } else { - var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries) + var targetPaths = from path in targetPathEnv!.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries) select Environment.ExpandEnvironmentVariables(path); directories.AddRange(targetPaths); } @@ -135,11 +89,62 @@ private IEnumerable GetPluginsSearchDirectories() var modulesPath = Path.Combine(rootPath, "modules"); directories.Add(modulesPath); - if (_logger.IsEnabled(LogLevel.Trace)) + if (logger is not null && logger.IsEnabled(LogLevel.Trace)) { - _logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}."); + logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}."); } return directories.Distinct(); } + + public static bool IsLoadableAssembly(string filePath) + { + using var fs = File.OpenRead(filePath); + using var peReader = new PEReader(fs); + + if (!peReader.HasMetadata) + { + return false; + } + + var metaReader = peReader.GetMetadataReader(); + if (!metaReader.IsAssembly) + { + return false; + } + + // Is reference assembly + if ((from cah in metaReader.CustomAttributes + let ca = metaReader.GetCustomAttribute(cah) + where ca.Constructor.Kind == HandleKind.MemberReference + let ctor = metaReader.GetMemberReference((MemberReferenceHandle)ca.Constructor) + let attrType = metaReader.GetTypeReference((TypeReferenceHandle)ctor.Parent) + where metaReader.GetString(attrType.Namespace) == nameof(System.Runtime.CompilerServices) + && metaReader.GetString(attrType.Name) == nameof(ReferenceAssemblyAttribute) + select cah).Any()) + { + return false; + } + + return true; + } + + /// + /// Load plugins. + /// + /// Plugins. + public IReadOnlyList LoadPlugins() + { + var pluginAsms = GetPluginsSearchDirectories(PluginPathEnvName, _logger).Select(GetPluginAssemblies).SelectMany(x => x) + .DistinctBy(Path.GetFileName).Select(x => LoadPluginAssembly(x, _loadContext)).Distinct().ToList(); + var plugins = (from asm in pluginAsms + from t in asm.ExportedTypes + where t.IsClass + && t.IsAssignableTo(typeof(IPlugin)) + let ctor = t.GetConstructor(Type.EmptyTypes) + where ctor != null + select (IPlugin)ctor.Invoke(null)).ToList(); + + return plugins; + } } diff --git a/src/Nncase.Compiler/packages.lock.json b/src/Nncase.Compiler/packages.lock.json index 639bb6a9bc..f22a606140 100644 --- a/src/Nncase.Compiler/packages.lock.json +++ b/src/Nncase.Compiler/packages.lock.json @@ -49,11 +49,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Google.OrTools.runtime.linux-arm64": { @@ -350,8 +350,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -674,6 +674,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -885,6 +886,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Linq.Async": { "type": "CentralTransitive", "requested": "[6.0.1, )", diff --git a/src/Nncase.Core/CompileOptions.cs b/src/Nncase.Core/CompileOptions.cs index 5d7b3cb058..e5d1b2874c 100644 --- a/src/Nncase.Core/CompileOptions.cs +++ b/src/Nncase.Core/CompileOptions.cs @@ -119,4 +119,9 @@ public sealed record CompileOptions /// Gets or sets a value indicating whether is benchmark only. /// public bool IsBenchmarkOnly { get; set; } + + /// + /// Gets or sets the target compile options. + /// + public ITargetCompileOptions TargetCompileOptions { get; set; } = null!; } diff --git a/src/Nncase.Core/CompilerServices.cs b/src/Nncase.Core/CompilerServices.cs index 5d1bd6cb83..9c666bcd21 100644 --- a/src/Nncase.Core/CompilerServices.cs +++ b/src/Nncase.Core/CompilerServices.cs @@ -73,6 +73,14 @@ public interface ICompilerServicesProvider /// false for save const into bin. public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst); + /// + /// dump the expr as csharp code. + /// + /// expression. + /// file prefix. + /// file dump ir. + public void DumpPatternIR(Expr expr, string prefix, string dumpDir); + /// /// print ir type. /// @@ -468,6 +476,15 @@ public static void DumpDotIR(Expr expr, string prefix, string dumpPath, bool dis public static void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst = true) => Provider.DumpCSharpIR(expr, prefix, dumpDir, randConst); + /// + /// dump the expr as csharp code. + /// + /// expression. + /// file prefix. + /// file dump ir. + public static void DumpPatternIR(Expr expr, string prefix, string dumpDir) => + Provider.DumpPatternIR(expr, prefix, dumpDir); + public static string Print(IRType type) => Provider.Print(type); public static string Print(Expr expr, bool useScript = false) => Provider.Print(expr, useScript); @@ -583,6 +600,10 @@ public void DumpDotIR(Expr expr, string prefix, string dumpPath, bool display_ca public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst) => _irprinterProvider.DumpCSharpIR(expr, prefix, dumpDir, randConst); + /// + public void DumpPatternIR(Expr expr, string prefix, string dumpDir) => + _irprinterProvider.DumpPatternIR(expr, prefix, dumpDir); + /// public string Print(IRType type) => _irprinterProvider.Print(type); diff --git a/src/Nncase.Core/Converters/ConvertersModule.cs b/src/Nncase.Core/Converters/ConvertersModule.cs index c7e5d4a9bc..3b406a8c88 100644 --- a/src/Nncase.Core/Converters/ConvertersModule.cs +++ b/src/Nncase.Core/Converters/ConvertersModule.cs @@ -28,5 +28,6 @@ public void ConfigureServices(IRegistrator registrator) registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); } } diff --git a/src/Nncase.Core/Converters/PointerConverters.cs b/src/Nncase.Core/Converters/PointerConverters.cs index af274fab5e..134c773609 100644 --- a/src/Nncase.Core/Converters/PointerConverters.cs +++ b/src/Nncase.Core/Converters/PointerConverters.cs @@ -30,3 +30,25 @@ public void ConvertTo(ReadOnlySpan> source, Span dest, Cast } } } + +internal class PointerIntConverters : IPointerSpanConverter +{ + public void ConvertTo(ReadOnlySpan> source, Span dest, CastMode castMode) + where T : unmanaged, IEquatable + { + if (castMode != CastMode.KDefault) + { + throw new InvalidCastException(); + } + + if (dest.Length < source.Length) + { + throw new ArgumentException("Dest buffer is not sufficient."); + } + + for (int i = 0; i < source.Length; i++) + { + dest[i] = checked((int)source[i].Value); + } + } +} diff --git a/src/Nncase.Core/CostModel/Cost.cs b/src/Nncase.Core/CostModel/Cost.cs index b989507457..e60414a613 100644 --- a/src/Nncase.Core/CostModel/Cost.cs +++ b/src/Nncase.Core/CostModel/Cost.cs @@ -204,6 +204,7 @@ public static UInt128 GetMemoryAccess(IRType type) { TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * t.DType.SizeInBytes), TupleType t => t.Fields.Sum(GetMemoryAccess), + DistributedType t => GetMemoryAccess(Utilities.DistributedUtility.GetDividedTensorType(t)), _ => 0, }; } @@ -229,6 +230,7 @@ public static UInt128 GetCPUCycles(IRType type, double cyclesPerElement = 1) { TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * cyclesPerElement), TupleType t => t.Fields.Sum(GetMemoryAccess), + DistributedType t => GetCPUCycles(Utilities.DistributedUtility.GetDividedTensorType(t)), _ => 0, }; } @@ -328,7 +330,7 @@ public static Cost GetActivationCost(TensorType ret, uint macPerElement) } // cost for op similar to broadcast - public static Cost GetBroadcastCost(TensorType input, TensorType ret) + public static Cost GetBroadcastCost(IRType input, IRType ret) { return new() { diff --git a/src/Nncase.Core/DataTypes.cs b/src/Nncase.Core/DataTypes.cs index d1a24255e7..0b59a5696e 100644 --- a/src/Nncase.Core/DataTypes.cs +++ b/src/Nncase.Core/DataTypes.cs @@ -114,7 +114,7 @@ public static bool IsPointer(this DataType srcType) => /// datatype name. public static string GetDisplayName(this DataType dataType) => dataType switch { - PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)}*)", + PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)} *)", PrimType primType => primType.ShortName, ValueType => dataType.ToString(), _ => throw new ArgumentOutOfRangeException(dataType.GetType().Name), diff --git a/src/Nncase.Core/Diagnostics/IDumpper.cs b/src/Nncase.Core/Diagnostics/IDumpper.cs index 0e07109232..cc84bf46cb 100644 --- a/src/Nncase.Core/Diagnostics/IDumpper.cs +++ b/src/Nncase.Core/Diagnostics/IDumpper.cs @@ -42,6 +42,8 @@ public interface IDumpper void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null); + void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null); + void DumpModule(IRModule module, string? reletivePath = null); Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create); diff --git a/src/Nncase.Core/Diagnostics/NullDumpper.cs b/src/Nncase.Core/Diagnostics/NullDumpper.cs index 7212fc7686..3120fa25e5 100644 --- a/src/Nncase.Core/Diagnostics/NullDumpper.cs +++ b/src/Nncase.Core/Diagnostics/NullDumpper.cs @@ -46,6 +46,11 @@ public void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null) { } + /// + public void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null) + { + } + /// public bool IsEnabled(DumpFlags dumpFlags) => false; diff --git a/src/Nncase.Core/DistributedType.cs b/src/Nncase.Core/DistributedType.cs new file mode 100644 index 0000000000..efe52395da --- /dev/null +++ b/src/Nncase.Core/DistributedType.cs @@ -0,0 +1,68 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using DryIoc.ImTools; + +namespace Nncase.IR; + +public abstract record SBP +{ + public static SBPPartialSum P => SBPPartialSum.Instance; + + public static SBPBroadCast B => SBPBroadCast.Instance; + + public static SBPSplit S(int axis) => new SBPSplit(axis); +} + +public sealed record SBPSplit(int Axis) : SBP +{ + public override string ToString() => $"S({Axis})"; +} + +public sealed record SBPPartialSum : SBP +{ + public static readonly SBPPartialSum Instance = new SBPPartialSum(); + + private SBPPartialSum() + { + } + + public override string ToString() => "P"; +} + +public sealed record SBPBroadCast : SBP +{ + public static readonly SBPBroadCast Instance = new SBPBroadCast(); + + private SBPBroadCast() + { + } + + public override string ToString() => "B"; +} + +// public sealed record Placement(Placement.DeviceKind Kind, IRArray Hierarchy, string Name) +public sealed record Placement(IRArray Hierarchy, string Name) +{ + // public enum DeviceKind : uint + // { + // CPU = 0, + // } + public int Rank => Hierarchy.Count; + + // public override string ToString() => $"@{Kind} [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]"; + public override string ToString() => $"@ [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]"; +} + +public sealed record DistributedType(TensorType TensorType, IRArray NdSBP, Placement Placement) : IRType +{ + public override string ToString() => $"{TensorType}, ({string.Join(',', NdSBP)}), {Placement}"; +} diff --git a/src/Nncase.Core/FunctionCollector.cs b/src/Nncase.Core/FunctionCollector.cs index 9ed2a8bca7..12655d25a1 100644 --- a/src/Nncase.Core/FunctionCollector.cs +++ b/src/Nncase.Core/FunctionCollector.cs @@ -17,7 +17,7 @@ public FunctionCollector() public HashSet Functions => _functions; - protected override int VisitLeafFunction(Function expr, Unit context) + protected override int VisitLeafFunction(Function expr) { _functions.Add(expr); return 0; diff --git a/src/Nncase.Core/IR/Buffers/BufferLoad.cs b/src/Nncase.Core/IR/Buffers/BufferLoad.cs new file mode 100644 index 0000000000..dbf3427b6e --- /dev/null +++ b/src/Nncase.Core/IR/Buffers/BufferLoad.cs @@ -0,0 +1,28 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.Buffers; + +/// +/// BufferLoad expression. +/// +[PatternFunctionalGenerator] +public sealed partial class BufferLoad : Op +{ + /// + /// Get the input parameter. + /// + public static readonly ParameterInfo Input = new(typeof(BufferLoad), 0, "input", IsTensor()); + + /// + /// Get the indices. + /// + public static readonly ParameterInfo Indices = new(typeof(BufferLoad), 1, "indices", IsTuple()); + + /// + public override bool CanFoldConstCall => false; +} diff --git a/src/Nncase.Core/IR/Buffers/BufferOf.cs b/src/Nncase.Core/IR/Buffers/BufferOf.cs index a3bb033275..47a2541c1b 100644 --- a/src/Nncase.Core/IR/Buffers/BufferOf.cs +++ b/src/Nncase.Core/IR/Buffers/BufferOf.cs @@ -16,7 +16,7 @@ public sealed partial class BufferOf : Op /// public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor()); - public Schedule.MemoryLocation MemoryLocation { get; } + public TIR.MemoryLocation MemoryLocation { get; } /// public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}"; diff --git a/src/Nncase.Core/IR/Buffers/BufferStore.cs b/src/Nncase.Core/IR/Buffers/BufferStore.cs new file mode 100644 index 0000000000..2d8e86cad8 --- /dev/null +++ b/src/Nncase.Core/IR/Buffers/BufferStore.cs @@ -0,0 +1,33 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.Buffers; + +/// +/// BufferStore op. +/// +[PatternFunctionalGenerator] +public sealed partial class BufferStore : Op +{ + /// + /// Get the input parameter. + /// + public static readonly ParameterInfo Input = new(typeof(BufferStore), 0, "input", IsTensor()); + + /// + /// Get the indices parameter. + /// + public static readonly ParameterInfo Indices = new(typeof(BufferStore), 1, "indices", IsTuple()); + + /// + /// Get the value parameter. + /// + public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsScalar()); + + /// + public override bool CanFoldConstCall => false; +} diff --git a/src/Nncase.Core/IR/Buffers/DDrOf.cs b/src/Nncase.Core/IR/Buffers/DDrOf.cs index 8657d3f417..116e019d5f 100644 --- a/src/Nncase.Core/IR/Buffers/DDrOf.cs +++ b/src/Nncase.Core/IR/Buffers/DDrOf.cs @@ -17,4 +17,7 @@ public sealed partial class DDrOf : Op /// Get the input parameter. /// public static readonly ParameterInfo Input = new(typeof(DDrOf), 0, "input", IsTensor()); + + /// + public override bool CanFoldConstCall => false; } diff --git a/src/Nncase.Core/IR/Buffers/Functional.cs b/src/Nncase.Core/IR/Buffers/Functional.cs index 54cd9f59cf..a2e3507a5f 100644 --- a/src/Nncase.Core/IR/Buffers/Functional.cs +++ b/src/Nncase.Core/IR/Buffers/Functional.cs @@ -41,5 +41,5 @@ public static Call BaseMentOf(Expr input) => /// /// create the uninitialized buffer. /// - public static Call Uninitialized(DataType dataType, Schedule.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape); + public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape); } diff --git a/src/Nncase.Core/IR/Buffers/MatchBuffer.cs b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs new file mode 100644 index 0000000000..3cafa7f595 --- /dev/null +++ b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs @@ -0,0 +1,21 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.Buffers; + +/// +/// MatchBuffer op. +/// todo maybe need united matchbuffer and allocatebuffer. +/// +[PatternFunctionalGenerator] +public sealed partial class MatchBuffer : Op +{ + public static readonly ParameterInfo Input = new(typeof(MatchBuffer), 0, "input", IsTensor()); + + /// + public override bool CanFoldConstCall => false; +} diff --git a/src/Nncase.Core/IR/Buffers/Uninitialized.cs b/src/Nncase.Core/IR/Buffers/Uninitialized.cs index 2f529638b9..42564bbc4c 100644 --- a/src/Nncase.Core/IR/Buffers/Uninitialized.cs +++ b/src/Nncase.Core/IR/Buffers/Uninitialized.cs @@ -19,11 +19,11 @@ public sealed partial class Uninitialized : Op public DataType DType { get; } - public Schedule.MemoryLocation MemoryLocation { get; } + public TIR.MemoryLocation MemoryLocation { get; } /// public override bool CanFoldConstCall => false; /// - public override string DisplayProperty() => $"{DType.GetCSharpName()}, Schedule.MemoryLocation.{MemoryLocation}"; + public override string DisplayProperty() => $"{DType.GetCSharpName()}, MemoryLocation.{MemoryLocation}"; } diff --git a/src/Nncase.Core/IR/Callable.cs b/src/Nncase.Core/IR/Callable.cs index edd6004539..16fc9ad5a7 100644 --- a/src/Nncase.Core/IR/Callable.cs +++ b/src/Nncase.Core/IR/Callable.cs @@ -17,7 +17,7 @@ public abstract class Callable : Expr /// /// StackVM module kind. /// - public static readonly string StackVMModuleKind = "stackvm"; + public const string StackVMModuleKind = "stackvm"; public Callable(string name, string moduleKind, Expr[] operands) : base(operands) diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs index 214605c69e..855ff4e22e 100644 --- a/src/Nncase.Core/IR/ExprCloner.g.cs +++ b/src/Nncase.Core/IR/ExprCloner.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -57,8 +56,7 @@ protected override Expr VisitLeafIf(If expr, TContext context) return expr.With( condition: Clone(expr.Condition, context), then: Clone(expr.Then, context), - @else: Clone(expr.Else, context), - paramList: expr.ParamList.Select(p => Clone(p, context)).ToArray() + @else: Clone(expr.Else, context) ); } @@ -141,31 +139,6 @@ protected override Expr VisitLeafBlock(TIR.Block expr, TContext context) ); } - /// - protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) - { - return expr.With( - dimensions: CloneArray(expr.Dimensions, context), - strides: CloneArray(expr.Strides, context) - ); - } - - /// - protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) - { - return expr.With( - ); - } - - /// - protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) - { - return expr.With( - buffer: Clone(expr.Buffer, context), - indices: CloneArray(expr.Indices, context) - ); - } - /// protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) { @@ -175,16 +148,6 @@ protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext co ); } - /// - protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context) - { - return expr.With( - buffer: Clone(expr.Buffer, context), - indices: CloneArray(expr.Indices, context), - value: Clone(expr.Value, context) - ); - } - /// protected override Expr VisitLeafFor(TIR.For expr, TContext context) { diff --git a/src/Nncase.Core/IR/ExprFunctor.cs b/src/Nncase.Core/IR/ExprFunctor.cs index 4462f8d2cc..2d19dbc1b3 100644 --- a/src/Nncase.Core/IR/ExprFunctor.cs +++ b/src/Nncase.Core/IR/ExprFunctor.cs @@ -102,6 +102,13 @@ public partial class ExprFunctor : ExprFunctorResult. public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default); + /// + /// Visit point type. + /// + /// pointer type. + /// Result. + public virtual TTypeResult VisitType(PointerType type) => base.VisitType(type, default); + /// /// Visit tuple type. /// @@ -116,6 +123,13 @@ public partial class ExprFunctor : ExprFunctorResult. public virtual TTypeResult VisitType(CallableType type) => base.VisitType(type, default); + /// + /// Visit callable type. + /// + /// Callable type. + /// Result. + public virtual TTypeResult VisitType(DistributedType type) => base.VisitType(type, default); + /// /// Default visit routine. /// @@ -135,12 +149,18 @@ public partial class ExprFunctor : ExprFunctor public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type); + /// + public sealed override TTypeResult VisitType(PointerType type, Unit context) => VisitType(type); + /// public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type); /// public sealed override TTypeResult VisitType(CallableType type, Unit context) => VisitType(type); + /// + public sealed override TTypeResult VisitType(DistributedType type, Unit context) => VisitType(type); + /// public sealed override TTypeResult DefaultVisitType(IRType type, Unit context) => DefaultVisitType(type); diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs index 642b2709e4..188aad4659 100644 --- a/src/Nncase.Core/IR/ExprFunctor.g.cs +++ b/src/Nncase.Core/IR/ExprFunctor.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -79,6 +78,11 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context); + /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context); + /// /// Visit . /// @@ -94,31 +98,11 @@ public partial class ExprFunctor /// internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context); - - /// - /// Visit . - /// - internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context); - - /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context); - /// /// Visit . /// @@ -250,6 +234,13 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); + + /// + internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default); @@ -271,27 +262,6 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); @@ -299,13 +269,6 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default); diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs index 4c8cece3f2..b842c110f1 100644 --- a/src/Nncase.Core/IR/ExprRewriter.g.cs +++ b/src/Nncase.Core/IR/ExprRewriter.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -92,6 +91,12 @@ protected sealed override Expr VisitLeafTupleConst(TupleConst expr, TContext con return RewriteLeafTupleConst(expr, context); } + /// + protected sealed override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context) + { + return RewriteLeafMemSpan(expr, context); + } + /// protected sealed override Expr VisitLeafVar(Var expr, TContext context) { @@ -110,36 +115,12 @@ protected sealed override Expr VisitLeafBuffer(TIR.Buffer expr, TContext context return RewriteLeafBuffer(expr, context); } - /// - protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) - { - return RewriteLeafLogicalBuffer(expr, context); - } - - /// - protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) - { - return RewriteLeafPhysicalBuffer(expr, context); - } - - /// - protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) - { - return RewriteLeafBufferLoad(expr, context); - } - /// protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) { return RewriteLeafBufferRegion(expr, context); } - /// - protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context) - { - return RewriteLeafBufferStore(expr, context); - } - /// protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context) { @@ -247,6 +228,11 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context); + /// /// Rewrite leaf . /// @@ -262,31 +248,11 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte /// protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context); - - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context); - - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// /// Rewrite leaf . /// protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context); - /// /// Rewrite leaf . /// @@ -430,6 +396,14 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr); + /// + /// Rewrite leaf . + /// + protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr); + + /// + protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr); + /// /// Rewrite leaf . /// @@ -454,30 +428,6 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr); - - /// - protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr); - - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr); - - /// - protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr); - - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr); - - /// - protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr); - /// /// Rewrite leaf . /// @@ -486,14 +436,6 @@ public partial class ExprRewriter /// protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr); - /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr); - - /// - protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr); - /// /// Rewrite leaf . /// diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index c296f5f7e0..dd974ec60b 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -1,4 +1,3 @@ - //--------------------------------------------------------------------------------------------------- // // This code was generated by T4 template. @@ -104,38 +103,31 @@ protected internal override TExprResult VisitTupleConst(TupleConst expr, TContex } /// - protected internal override TExprResult VisitVar(Var expr, TContext context) + protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) { VisitOperands(expr, context); - return VisitLeafVar(expr, context); + return VisitLeafMemSpan(expr, context); } /// - protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafBlock(expr, context); - } - - /// - protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) + protected internal override TExprResult VisitVar(Var expr, TContext context) { VisitOperands(expr, context); - return VisitLeafLogicalBuffer(expr, context); + return VisitLeafVar(expr, context); } /// - protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) + protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context) { VisitOperands(expr, context); - return VisitLeafPhysicalBuffer(expr, context); + return VisitLeafBlock(expr, context); } /// - protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) + protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext context) { VisitOperands(expr, context); - return VisitLeafBufferLoad(expr, context); + return VisitLeafBuffer(expr, context); } /// @@ -145,13 +137,6 @@ protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, return VisitLeafBufferRegion(expr, context); } - /// - protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafBufferStore(expr, context); - } - /// protected internal override TExprResult VisitFor(TIR.For expr, TContext context) { @@ -270,6 +255,11 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext /// protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context); + /// /// Visit leaf . /// @@ -285,31 +275,11 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext /// protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context); - - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context); - - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context); - /// /// Visit leaf . /// protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context); - /// /// Visit leaf . /// @@ -353,182 +323,168 @@ public partial class ExprVisitor /// Visit . /// internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default); - + /// internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr); /// /// Visit . /// internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default); - + /// internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr); /// /// Visit . /// internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default); - + /// internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr); /// /// Visit . /// internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default); - + /// internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr); /// /// Visit . /// internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default); - + /// internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr); /// /// Visit . /// internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default); - + /// internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr); /// /// Visit . /// internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default); - + /// internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr); /// /// Visit . /// internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default); - + /// internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr); /// /// Visit . /// internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default); - + /// internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); /// /// Visit . /// internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default); - + /// internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr); /// /// Visit . /// internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default); - + /// internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr); /// + /// Visit . + /// + internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default); + + /// + internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr); + /// /// Visit . /// internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default); - + /// internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr); /// /// Visit . /// internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default); - + /// internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default); - - /// - internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr); - /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default); - + internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default); + /// - internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr); + internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr); /// /// Visit . /// internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default); - + /// internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default); - - /// - internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr); - /// /// Visit . /// internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default); - + /// internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr); /// /// Visit . /// internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default); - + /// internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr); /// /// Visit . /// internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default); - + /// internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr); /// /// Visit . /// internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default); - + /// internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr); /// /// Visit . /// internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default); - + /// internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr); /// /// Visit . /// internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default); - + /// internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr); /// /// Visit . /// internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default); - + /// internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr); /// /// Visit leaf . /// protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default); - + /// protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr); @@ -536,7 +492,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default); - + /// protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr); @@ -544,7 +500,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default); - + /// protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr); @@ -552,15 +508,15 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default); - + /// - protected override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr); + protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr); /// /// Visit leaf . /// protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default); - + /// protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr); @@ -568,7 +524,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default); - + /// protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr); @@ -576,7 +532,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default); - + /// protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr); @@ -584,7 +540,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default); - + /// protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr); @@ -592,7 +548,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default); - + /// protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr); @@ -600,7 +556,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default); - + /// protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr); @@ -608,7 +564,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default); - + /// protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr); @@ -616,7 +572,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default); - + /// protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr); @@ -624,15 +580,23 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default); - + /// protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr); + /// + /// Visit leaf . + /// + protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default); + + /// + protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr); + /// /// Visit leaf . /// protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default); - + /// protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr); @@ -640,7 +604,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default); - + /// protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr); @@ -648,55 +612,23 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default); - + /// protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default); - - /// - protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr); - - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default); - - /// - protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr); - - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default); - - /// - protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr); - /// /// Visit leaf . /// protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default); - + /// protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr); - /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default); - - /// - protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr); - /// /// Visit leaf . /// protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default); - + /// protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr); @@ -704,7 +636,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default); - + /// protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr); @@ -712,7 +644,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default); - + /// protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr); @@ -720,7 +652,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default); - + /// protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr); @@ -728,7 +660,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default); - + /// protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr); @@ -736,7 +668,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default); - + /// protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr); @@ -744,7 +676,7 @@ public partial class ExprVisitor /// Visit leaf . /// protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default); - + /// protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr); diff --git a/src/Nncase.Core/IR/IIRPrinterProvider.cs b/src/Nncase.Core/IR/IIRPrinterProvider.cs index 3c267771df..f411167f2a 100644 --- a/src/Nncase.Core/IR/IIRPrinterProvider.cs +++ b/src/Nncase.Core/IR/IIRPrinterProvider.cs @@ -73,6 +73,14 @@ public interface IIRPrinterProvider /// randConst = false will save the const into bin. public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst); + /// + /// dump the expr as csharp code. + /// + /// expression. + /// file prefix. + /// file dump ir. + public void DumpPatternIR(Expr expr, string prefix, string dumpDir); + /// /// print ir type. /// diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv index 5ae3c89d18..ba9dd8033b 100644 --- a/src/Nncase.Core/IR/IRList.csv +++ b/src/Nncase.Core/IR/IRList.csv @@ -11,14 +11,11 @@ PrimFunctionWrapper,true,true,BaseFunction,,Target TensorConst,true,false,Const,, Tuple,true,false,Default,IR.,@Fields TupleConst,true,false,Const,, +MemSpan,true,false,Default,TIR.,Start;Size; Var,true,false,Default,, Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate -Buffer,false,false,Default,TIR., -LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides -PhysicalBuffer,true,false,Buffer,TIR., -BufferLoad,true,false,Default,TIR.,Buffer;@Indices +Buffer,true,false,Default,TIR.,MemSpan;@Dimensions;@Strides; BufferRegion,true,false,Default,TIR.,Buffer;@Region -BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value For,true,false,Default,TIR.,LoopVar;Domain;Body IfThenElse,true,false,Default,TIR.,Condition;Then;Else Let,true,false,Default,TIR.,Var;Expression;Body diff --git a/src/Nncase.Core/IR/IRType.cs b/src/Nncase.Core/IR/IRType.cs index b8aeda469f..a4311aae13 100644 --- a/src/Nncase.Core/IR/IRType.cs +++ b/src/Nncase.Core/IR/IRType.cs @@ -139,6 +139,15 @@ public sealed record TensorType(DataType DType, Shape Shape) : IRType /// the Pointed Element Type. /// the pointer tensor type. public static TensorType Pointer(DataType elemType) => new(new PointerType(elemType), Shape.Scalar); + + /// + public override string ToString() => DType switch + { + PrimType ptype => ptype.GetDisplayName() + (Shape.IsScalar ? string.Empty : Shape.ToString()), + PointerType { ElemType: PrimType etype } => $"*{etype.GetDisplayName()}", + ValueType => $"{DType}", + _ => throw new NotSupportedException(DType.GetType().Name), + }; } /// diff --git a/src/Nncase.Core/IR/Imaging/ResizeImage.cs b/src/Nncase.Core/IR/Imaging/ResizeImage.cs index 088651b511..ae48d48831 100644 --- a/src/Nncase.Core/IR/Imaging/ResizeImage.cs +++ b/src/Nncase.Core/IR/Imaging/ResizeImage.cs @@ -21,7 +21,7 @@ public sealed partial class ResizeImage : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2")); + public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"), ParameterKind.Input); /// /// Gets roi. diff --git a/src/Nncase.Core/IR/Math/Binary.cs b/src/Nncase.Core/IR/Math/Binary.cs index ead10ff710..f61f8a8704 100644 --- a/src/Nncase.Core/IR/Math/Binary.cs +++ b/src/Nncase.Core/IR/Math/Binary.cs @@ -20,12 +20,12 @@ public sealed partial class Binary : Op /// /// Gets lhs. /// - public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs"); + public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs", ParameterKind.Input); /// /// Gets rhs. /// - public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs"); + public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs", ParameterKind.Input); public BinaryOp BinaryOp { get; } diff --git a/src/Nncase.Core/IR/Math/Clamp.cs b/src/Nncase.Core/IR/Math/Clamp.cs index 9f14cf287d..b8409f375c 100644 --- a/src/Nncase.Core/IR/Math/Clamp.cs +++ b/src/Nncase.Core/IR/Math/Clamp.cs @@ -21,7 +21,7 @@ public sealed partial class Clamp : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input", ParameterKind.Input); /// /// Gets min. diff --git a/src/Nncase.Core/IR/Math/MatMul.cs b/src/Nncase.Core/IR/Math/MatMul.cs index 51d5615f1f..fc74e211e1 100644 --- a/src/Nncase.Core/IR/Math/MatMul.cs +++ b/src/Nncase.Core/IR/Math/MatMul.cs @@ -20,10 +20,10 @@ public sealed partial class MatMul : Op /// /// Gets input. /// - public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs"); + public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs", ParameterKind.Input); /// /// Gets Other. /// - public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs"); + public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs", ParameterKind.Input); } diff --git a/src/Nncase.Core/IR/Math/ReduceArg.cs b/src/Nncase.Core/IR/Math/ReduceArg.cs index ecad8e95e2..2afd43010c 100644 --- a/src/Nncase.Core/IR/Math/ReduceArg.cs +++ b/src/Nncase.Core/IR/Math/ReduceArg.cs @@ -21,7 +21,7 @@ public sealed partial class ReduceArg : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input", ParameterKind.Input); /// /// Gets Axis. @@ -42,8 +42,8 @@ public sealed partial class ReduceArg : Op public ReduceArgOp ReduceArgOp { get; } - public DataType DestType { get; } + public PrimType DestType { get; } /// - public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}"; + public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}, DestType: {DestType}"; } diff --git a/src/Nncase.Core/IR/Math/Unary.cs b/src/Nncase.Core/IR/Math/Unary.cs index 820572437e..20d6b3fb03 100644 --- a/src/Nncase.Core/IR/Math/Unary.cs +++ b/src/Nncase.Core/IR/Math/Unary.cs @@ -20,7 +20,7 @@ public sealed partial class Unary : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input", ParameterKind.Input); public UnaryOp UnaryOp { get; } diff --git a/src/Nncase.Core/IR/NN/Activations.cs b/src/Nncase.Core/IR/NN/Activations.cs index 46df70d241..1866e65220 100644 --- a/src/Nncase.Core/IR/NN/Activations.cs +++ b/src/Nncase.Core/IR/NN/Activations.cs @@ -154,7 +154,12 @@ public sealed partial class Swish : ActivationOp /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input", ParameterKind.Input); + + /// + /// Gets beta. + /// + public static readonly ParameterInfo Beta = new(typeof(Swish), 1, "beta", IsFloatScalar()); } /// diff --git a/src/Nncase.Core/IR/NN/Conv2D.cs b/src/Nncase.Core/IR/NN/Conv2D.cs index 43c1b2fced..0607870481 100644 --- a/src/Nncase.Core/IR/NN/Conv2D.cs +++ b/src/Nncase.Core/IR/NN/Conv2D.cs @@ -21,17 +21,17 @@ public sealed partial class Conv2D : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input", ParameterKind.Input); /// /// Gets Weights. /// - public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4)); + public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4), ParameterKind.Input); /// /// Gets Bias. /// - public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1)); + public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1), ParameterKind.Input); /// /// Gets Stride. diff --git a/src/Nncase.Core/IR/NN/Functional.cs b/src/Nncase.Core/IR/NN/Functional.cs index 30b4005388..e8a44d8a38 100644 --- a/src/Nncase.Core/IR/NN/Functional.cs +++ b/src/Nncase.Core/IR/NN/Functional.cs @@ -34,7 +34,7 @@ public static class NN public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum); - public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias) => new Call(new LayerNorm(axis, epsilon), input, scale, bias); + public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias); public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops); @@ -103,5 +103,10 @@ public static Call ReduceWindow2D(ReduceOp reduceOp, Expr input, Expr initValue, /// /// create Swish call. /// - public static Call Swish(Expr input) => new Call(new Swish(), input); + public static Call Swish(Expr input) => new Call(new Swish(), input, 1f); + + /// + /// create Swish call. + /// + public static Call Swish(Expr input, Expr beta) => new Call(new Swish(), input, beta); } diff --git a/src/Nncase.Core/IR/NN/LayerNorm.cs b/src/Nncase.Core/IR/NN/LayerNorm.cs index 2dff32f440..2474f44fc2 100644 --- a/src/Nncase.Core/IR/NN/LayerNorm.cs +++ b/src/Nncase.Core/IR/NN/LayerNorm.cs @@ -21,19 +21,23 @@ public sealed partial class LayerNorm : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input", ParameterKind.Input); /// /// Gets scale. /// - public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale"); + public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale", ParameterKind.Input); /// /// Gets bias. /// - public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias"); + public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias", ParameterKind.Input); public int Axis { get; } public float Epsilon { get; } + + public bool UseMean { get; } + + public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}"; } diff --git a/src/Nncase.Core/IR/NN/Normalization.cs b/src/Nncase.Core/IR/NN/Normalization.cs index ba91f13f83..2b7b74168c 100644 --- a/src/Nncase.Core/IR/NN/Normalization.cs +++ b/src/Nncase.Core/IR/NN/Normalization.cs @@ -61,17 +61,17 @@ public sealed partial class InstanceNormalization : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input", ParameterKind.Input); /// /// Gets input. /// - public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale"); + public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale", ParameterKind.Input); /// /// Gets input. /// - public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias"); + public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias", ParameterKind.Input); /// /// Gets Epsilon. diff --git a/src/Nncase.Core/IR/NN/SoftMax.cs b/src/Nncase.Core/IR/NN/SoftMax.cs index 686696f032..919f8ac76c 100644 --- a/src/Nncase.Core/IR/NN/SoftMax.cs +++ b/src/Nncase.Core/IR/NN/SoftMax.cs @@ -33,7 +33,7 @@ public sealed partial class Softmax : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input", ParameterKind.Input); /// /// Gets axis. diff --git a/src/Nncase.Core/IR/Op.cs b/src/Nncase.Core/IR/Op.cs index 07d4bbcabb..2af2e39fd3 100644 --- a/src/Nncase.Core/IR/Op.cs +++ b/src/Nncase.Core/IR/Op.cs @@ -12,6 +12,12 @@ namespace Nncase.IR; +public enum ParameterKind : int +{ + Input, + Attribute, +} + /// /// Parameter information. /// @@ -24,11 +30,13 @@ public sealed class ParameterInfo /// this op type. /// param index. /// param name. - public ParameterInfo(Type ownerType, int index, string name) + /// kind. + public ParameterInfo(Type ownerType, int index, string name, ParameterKind parameterKind = ParameterKind.Attribute) { OwnerType = ownerType; Index = index; Name = name; + ParameterKind = parameterKind; } /// @@ -39,8 +47,9 @@ public ParameterInfo(Type ownerType, int index, string name) /// param index. /// param name. /// the param condition. - public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern) - : this(ownerType, index, name) + /// kind. + public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern, ParameterKind parameterKind = ParameterKind.Attribute) + : this(ownerType, index, name, parameterKind) { Pattern = pattern; } @@ -60,6 +69,11 @@ public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern /// public string Name { get; } + /// + /// Gets parameter kind. + /// + public ParameterKind ParameterKind { get; } + /// /// Gets this paramter's type condition. /// @@ -90,7 +104,7 @@ public Op() /// /// Gets get the parameters. /// - public IEnumerable Parameters => + public virtual IEnumerable Parameters => _parameters ??= (from p in GetType().GetFields(BindingFlags.Public | BindingFlags.Static) where p.FieldType == typeof(ParameterInfo) let param = (ParameterInfo)(p.GetValue(null) ?? throw new InvalidOperationException()) diff --git a/src/Nncase.Core/IR/RNN/Functional.cs b/src/Nncase.Core/IR/RNN/Functional.cs index 07ebf7f8d6..571bb2141f 100644 --- a/src/Nncase.Core/IR/RNN/Functional.cs +++ b/src/Nncase.Core/IR/RNN/Functional.cs @@ -19,5 +19,5 @@ namespace Nncase.IR.F; public static class RNN { public static Call LSTM(LSTMDirection direction, LSTMLayout layout, string[] acts, Expr x, Expr w, Expr r, Expr b, Expr seqLens, Expr initH, Expr initC, Expr p, Expr actAlpha, Expr actBeta, Expr clip, Expr hiddenSize, Expr inputForget, Expr outputSize) => - new Call(new IR.Tensors.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize); + new Call(new IR.RNN.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize); } diff --git a/src/Nncase.Core/IR/RNN/LSTM.cs b/src/Nncase.Core/IR/RNN/LSTM.cs index ec5b9802b3..4bdd60f223 100644 --- a/src/Nncase.Core/IR/RNN/LSTM.cs +++ b/src/Nncase.Core/IR/RNN/LSTM.cs @@ -5,7 +5,7 @@ using Nncase.PatternMatch; using static Nncase.IR.TypePatternUtility; -namespace Nncase.IR.Tensors; +namespace Nncase.IR.RNN; /// /// LSTM expression. diff --git a/src/Nncase.Core/IR/TensorConst.cs b/src/Nncase.Core/IR/TensorConst.cs index 64b1fd6442..9e651978ed 100644 --- a/src/Nncase.Core/IR/TensorConst.cs +++ b/src/Nncase.Core/IR/TensorConst.cs @@ -146,7 +146,7 @@ public override TExprResult Accept(ExprFunct public override bool Equals(object? obj) => Equals(obj as TensorConst); /// - public bool Equals(TensorConst? other) => other is not null && base.Equals(other) && EqualityComparer.Default.Equals(Value, other.Value); + public bool Equals(TensorConst? other) => other is not null && (ReferenceEquals(this, other) || GetHashCode() == other.GetHashCode()) && EqualityComparer.Default.Equals(Value, other.Value); /// protected override int GetHashCodeCore() => HashCode.Combine(Value); diff --git a/src/Nncase.Core/IR/Tensors/Cast.cs b/src/Nncase.Core/IR/Tensors/Cast.cs index 5345cac153..1bc618f786 100644 --- a/src/Nncase.Core/IR/Tensors/Cast.cs +++ b/src/Nncase.Core/IR/Tensors/Cast.cs @@ -20,7 +20,7 @@ public sealed partial class Cast : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Cast), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Cast), 0, "input", ParameterKind.Input); public DataType NewType { get; } diff --git a/src/Nncase.Core/IR/Tensors/Concat.cs b/src/Nncase.Core/IR/Tensors/Concat.cs index 88a22e4376..cbe4861bc3 100644 --- a/src/Nncase.Core/IR/Tensors/Concat.cs +++ b/src/Nncase.Core/IR/Tensors/Concat.cs @@ -20,10 +20,13 @@ public sealed partial class Concat : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Concat), 0, "inputs"); + public static readonly ParameterInfo Input = new(typeof(Concat), 0, "inputs", ParameterKind.Input); /// /// Gets axis. /// - public static readonly ParameterInfo Axis = new(typeof(Concat), 1, "axis"); + public int Axis { get; } + + /// + public override string DisplayProperty() => $"Axis: {Axis}"; } diff --git a/src/Nncase.Core/IR/Tensors/Expand.cs b/src/Nncase.Core/IR/Tensors/Expand.cs index 3b74de6740..91a0d53e26 100644 --- a/src/Nncase.Core/IR/Tensors/Expand.cs +++ b/src/Nncase.Core/IR/Tensors/Expand.cs @@ -21,7 +21,7 @@ public sealed partial class Expand : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Expand), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Expand), 0, "input", ParameterKind.Input); /// /// Gets shape. diff --git a/src/Nncase.Core/IR/Tensors/Functional.cs b/src/Nncase.Core/IR/Tensors/Functional.cs index 71c0bfc51c..d20f69cbf5 100644 --- a/src/Nncase.Core/IR/Tensors/Functional.cs +++ b/src/Nncase.Core/IR/Tensors/Functional.cs @@ -70,7 +70,7 @@ public static Call Bitcast(PrimType type, Expr input, PrimType newType, Expr sha public static Call Cast(Expr input, DataType newType, CastMode castMode = CastMode.KDefault) => new Call(new Cast(newType, castMode), input); - public static Call Concat(Expr input, Expr axis) => new Call(new Concat(), input, axis); + public static Call Concat(Expr input, int axis) => new Call(new Concat(axis), input); public static Call ConstantOfShape(Expr shape, Expr value) => new Call(new ConstantOfShape(), shape, value); @@ -89,7 +89,7 @@ public static Call Expand(Expr input, Expr shape) public static Call Flatten(Expr input, Expr axis) => new Call(new Flatten(), input, axis); - public static Call Gather(Expr input, Expr axis, Expr index) => new Call(new Gather(), input, axis, index); + public static Call Gather(Expr input, int axis, Expr index) => new Call(new Gather(axis), input, index); public static Call GatherElements(Expr input, Expr axis, Expr indices) => new Call(new GatherElements(), input, axis, indices); diff --git a/src/Nncase.Core/IR/Tensors/Gather.cs b/src/Nncase.Core/IR/Tensors/Gather.cs index a498d38984..012a0de053 100644 --- a/src/Nncase.Core/IR/Tensors/Gather.cs +++ b/src/Nncase.Core/IR/Tensors/Gather.cs @@ -22,15 +22,18 @@ public sealed partial class Gather : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input", ParameterKind.Input); /// - /// Gets axis. + /// Gets index. /// - public static readonly ParameterInfo Axis = new(typeof(Gather), 1, "axis", IsIntegralScalar()); + public static readonly ParameterInfo Index = new(typeof(Gather), 1, "index", IsIntegral(), ParameterKind.Input); /// - /// Gets index. + /// Gets axis. /// - public static readonly ParameterInfo Index = new(typeof(Gather), 2, "index", IsIntegral()); + public int Axis { get; } + + /// + public override string DisplayProperty() => $"Axis: {Axis}"; } diff --git a/src/Nncase.Core/IR/Tensors/Reshape.cs b/src/Nncase.Core/IR/Tensors/Reshape.cs index 2db6d16b89..571fa457e6 100644 --- a/src/Nncase.Core/IR/Tensors/Reshape.cs +++ b/src/Nncase.Core/IR/Tensors/Reshape.cs @@ -22,7 +22,7 @@ public sealed partial class Reshape : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input", ParameterKind.Input); /// /// Gets shape. diff --git a/src/Nncase.Core/IR/Tensors/Slice.cs b/src/Nncase.Core/IR/Tensors/Slice.cs index bc58e51ee8..05963ad7f7 100644 --- a/src/Nncase.Core/IR/Tensors/Slice.cs +++ b/src/Nncase.Core/IR/Tensors/Slice.cs @@ -21,7 +21,7 @@ public sealed partial class Slice : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input", ParameterKind.Input); /// /// Gets begins. diff --git a/src/Nncase.Core/IR/Tensors/Transpose.cs b/src/Nncase.Core/IR/Tensors/Transpose.cs index 896d279c29..211dda9a54 100644 --- a/src/Nncase.Core/IR/Tensors/Transpose.cs +++ b/src/Nncase.Core/IR/Tensors/Transpose.cs @@ -15,7 +15,7 @@ public sealed partial class Transpose : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input", ParameterKind.Input); /// /// Gets perm. diff --git a/src/Nncase.Core/IR/Tensors/UnSqueeze.cs b/src/Nncase.Core/IR/Tensors/UnSqueeze.cs index cbc2574fc3..6ce9247d24 100644 --- a/src/Nncase.Core/IR/Tensors/UnSqueeze.cs +++ b/src/Nncase.Core/IR/Tensors/UnSqueeze.cs @@ -23,7 +23,7 @@ public sealed partial class Unsqueeze : Op /// /// Gets input. /// - public static readonly ParameterInfo Input = new(typeof(Unsqueeze), 0, "input"); + public static readonly ParameterInfo Input = new(typeof(Unsqueeze), 0, "input", ParameterKind.Input); /// /// Gets dimension. diff --git a/src/Nncase.Core/IR/TypeFunctor.cs b/src/Nncase.Core/IR/TypeFunctor.cs index 453cfa257a..f11ec869a4 100644 --- a/src/Nncase.Core/IR/TypeFunctor.cs +++ b/src/Nncase.Core/IR/TypeFunctor.cs @@ -32,6 +32,7 @@ public virtual TResult VisitType(IRType type, TContext context) TensorType t => VisitType(t, context), TupleType t => VisitType(t, context), CallableType t => VisitType(t, context), + DistributedType t => VisitType(t, context), _ => DefaultVisitType(type, context), }; } @@ -68,6 +69,14 @@ public virtual TResult VisitType(IRType type, TContext context) /// Result. public virtual TResult VisitType(TensorType type, TContext context) => DefaultVisitType(type, context); + /// + /// Visit pointer type. + /// + /// Pointer type. + /// Context. + /// Result. + public virtual TResult VisitType(PointerType type, TContext context) => DefaultVisitType(type, context); + /// /// Visit tuple type. /// @@ -84,6 +93,14 @@ public virtual TResult VisitType(IRType type, TContext context) /// Result. public virtual TResult VisitType(CallableType type, TContext context) => DefaultVisitType(type, context); + /// + /// Visit dist tensor type. + /// + /// dist tensor type. + /// Context. + /// Result. + public virtual TResult VisitType(DistributedType type, TContext context) => DefaultVisitType(type, context); + /// /// Default visit routine. /// diff --git a/src/Nncase.Core/IR/TypePattern.cs b/src/Nncase.Core/IR/TypePattern.cs index 6eea450f5c..183ae381f7 100644 --- a/src/Nncase.Core/IR/TypePattern.cs +++ b/src/Nncase.Core/IR/TypePattern.cs @@ -57,12 +57,12 @@ public TypePattern(CallableType valueType) public T Check(T valueType, string fieldName) where T : IRType { - if (valueType is TensorType tensorValueType && tensorValueType.Shape.IsUnranked) + if (valueType is TensorType { Shape: { IsUnranked: true } } || valueType is DistributedType { TensorType: { Shape: { IsUnranked: true } } }) { return valueType; } - if (valueType == null || !MatchLeaf(valueType)) + if (valueType == null || (valueType is TensorType t && !MatchLeaf(t)) || (valueType is DistributedType d && !MatchLeaf(d.TensorType))) { var cur = valueType is null ? "None" : CompilerServices.Print(valueType); throw new InvalidOperationException($"{fieldName} Requrie <{Reason}>, But {cur}!"); @@ -187,6 +187,7 @@ public static TypePattern HasRank(Func cond, string reason) => HasSha x => x switch { TensorType ttype => DataTypes.IsIntegral(ttype.DType), + DistributedType distributedType => DataTypes.IsIntegral(distributedType.TensorType.DType), _ => false, }, "IsIntegral"); diff --git a/src/Nncase.Core/ITarget.cs b/src/Nncase.Core/ITarget.cs index 7ecc6fa840..a72c0c5f9b 100644 --- a/src/Nncase.Core/ITarget.cs +++ b/src/Nncase.Core/ITarget.cs @@ -13,6 +13,13 @@ namespace Nncase; +/// +/// The targets own compile options. +/// +public interface ITargetCompileOptions +{ +} + /// /// Target. /// @@ -23,6 +30,12 @@ public interface ITarget /// string Kind { get; } + /// + /// create the current target's command and parser. + /// + /// command. + (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser(); + /// /// Bind Quant Method And Quant Cosine With IR. /// @@ -91,3 +104,12 @@ public interface ITarget /// Module builder. IModuleBuilder CreateModuleBuilder(string moduleKind, CompileOptions options); } + +public sealed class DefaultTargetCompileOptions : ITargetCompileOptions +{ + public static readonly DefaultTargetCompileOptions Instance = new(); + + private DefaultTargetCompileOptions() + { + } +} diff --git a/src/Nncase.Core/LinqExtensions.cs b/src/Nncase.Core/LinqExtensions.cs index a4245953e8..b40f14a8d5 100644 --- a/src/Nncase.Core/LinqExtensions.cs +++ b/src/Nncase.Core/LinqExtensions.cs @@ -14,6 +14,21 @@ namespace Nncase; /// public static class LinqExtensions { + /// + /// Get the ranges from range desc. + /// + /// stride. + /// start. + /// stop. + /// Ranges. + public static IEnumerable Ranges(this int stride, int start, int stop) + { + for (int i = start; i < stop; i += stride) + { + yield return new Range(i, Math.Min(stop, i + stride)); + } + } + /// /// Get cartesian product. /// @@ -31,6 +46,23 @@ from item in sequence select accseq.Concat(new[] { item })); } + /// + /// Get the permutation of the source. + /// + /// Element type. + /// Source sequences. + /// Permutated sequences. + public static IEnumerable Permutate(this IEnumerable source) + { + return Permutation(source, Enumerable.Empty()); + + IEnumerable Permutation(IEnumerable reminder, IEnumerable prefix) => + !reminder.Any() ? new[] { prefix.ToArray() } : + reminder.SelectMany((c, i) => Permutation( + reminder.Take(i).Concat(reminder.Skip(i + 1)).ToArray(), + prefix.Append(c))); + } + /// /// take or default. /// diff --git a/src/Nncase.Core/Nncase.Core.csproj b/src/Nncase.Core/Nncase.Core.csproj index 716f4e3dba..1ea4c4f534 100644 --- a/src/Nncase.Core/Nncase.Core.csproj +++ b/src/Nncase.Core/Nncase.Core.csproj @@ -4,7 +4,7 @@ enable enable true - true + true True @@ -21,6 +21,7 @@ + diff --git a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs index a0431cdd68..b0b41be0b6 100644 --- a/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs +++ b/src/Nncase.Core/Passes/Mutators/FlattenBuffer.cs @@ -12,35 +12,43 @@ namespace Nncase.Passes.Mutators; /// -/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. Also remove Block to ensure that the flattened TIR can not be scheduled again. +/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. /// public sealed class FlattenBuffer : ExprRewriter { /// protected override Expr RewriteLeafBlock(Block expr) { - if (!expr.IterVars.IsEmpty) + // TODO: put the unfold block into this. + if (expr.Predicate is TensorConst tc && tc.Value.ToScalar() == true) { - throw new InvalidOperationException("Non-opaque blocks are not allowed in FlattenBuffer. Please call pass ConvertBlocksToOpaque before."); + return expr.Body; } - // 1. Visit the body - var predicate = expr.Predicate; - if (predicate is TensorConst { Value: { Length: 1 } t } - && t.ToScalar()) + return T.Nop(); + } + + /// + protected override Expr RewriteLeafCall(Call expr) + { + if (expr.Target is IR.Buffers.BufferLoad) { - return expr.Body; + var indices = (IR.Tuple)expr[IR.Buffers.BufferLoad.Indices]; + var input = (TIR.Buffer)expr[IR.Buffers.BufferLoad.Input]; + return T.Load(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i]))); + } + else if (expr.Target is IR.Buffers.BufferStore) + { + var indices = (IR.Tuple)expr[IR.Buffers.BufferStore.Indices]; + var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input]; + return T.Store(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]); } - else + else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: Const or Var } }) { - return new IfThenElse(predicate, expr.Body); + // remove the all fixed match operation. + return T.Nop(); } - // Step 3. Handle allocations in reverse order - // TODO add the alloc buffers. - // for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { - // const Buffer& buffer = new_block->alloc_buffers[i - 1]; - // body = MakeAllocStmt(buffer, std::move(body)); - // } + return expr; } } diff --git a/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs new file mode 100644 index 0000000000..018183c5d7 --- /dev/null +++ b/src/Nncase.Core/Passes/Mutators/FoldBufferSlot.cs @@ -0,0 +1,51 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Reactive; +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.Passes; +using Nncase.TIR; + +namespace Nncase.Passes.Mutators; + +/// +/// remove buffer BaseMentOf/DDrOf/MmuOF. +/// +public sealed class FoldBufferSlot : ExprRewriter +{ + protected internal override Expr VisitPrimFunction(TIR.PrimFunction expr, Unit context) + { + if (expr.SchedResult.IsScheduled == true) + { + return base.VisitPrimFunction(expr, context); + } + + return expr; + } + + protected override Expr RewriteLeafCall(Call expr) + { + if (expr.Target is IR.Buffers.BaseMentOf) + { + var locate = ((TIR.MemSpan)expr.Arguments[0]).Location; + return locate switch + { + MemoryLocation.Input => 0, + MemoryLocation.Output => 1, + MemoryLocation.Rdata => 2, + MemoryLocation.Data => 3, + _ => throw new ArgumentOutOfRangeException($"You Can't Assgin The BaseMent For {locate}!"), + }; + } + else if (expr.Target is IR.Buffers.DDrOf) + { + if (expr.Arguments[0] is TIR.MemSpan buf) + { + return buf.Start; + } + } + + return expr; + } +} diff --git a/src/Nncase.Core/Passes/Mutators/FoldMathCall.cs b/src/Nncase.Core/Passes/Mutators/FoldMathCall.cs deleted file mode 100644 index af25604454..0000000000 --- a/src/Nncase.Core/Passes/Mutators/FoldMathCall.cs +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System.Reactive; -using NetFabric.Hyperlinq; -using Nncase.Evaluator; -using Nncase.IR; -using Nncase.Passes; - -namespace Nncase.Passes.Mutators; - -/// -/// fold math calc operator. -/// -public sealed class FoldMathCall : ExprRewriter -{ - /// - protected override Expr RewriteLeafCall(Call expr) - { - if (expr.Target is Op op && op.GetType().Namespace is string @namespace - && @namespace.StartsWith("Nncase.IR.Math")) - { - return expr.Arguments.AsValueEnumerable().All(x => x is Const) - ? Const.FromValue(CompilerServices.Evaluate(expr)) - : expr; - } - - return expr; - } -} diff --git a/src/Nncase.Core/Passes/Mutators/Mutator.cs b/src/Nncase.Core/Passes/Mutators/Mutator.cs index 1392ff7726..1f2587d5d0 100644 --- a/src/Nncase.Core/Passes/Mutators/Mutator.cs +++ b/src/Nncase.Core/Passes/Mutators/Mutator.cs @@ -50,10 +50,4 @@ public static class Mutator /// /// RemoveNop. public static Func RemoveNop() => () => new Mutators.RemoveNop(); - - /// - /// fold math calc operator. - /// - /// FoldMathCall. - public static Func FoldMathCall() => () => new Mutators.FoldMathCall(); } diff --git a/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs b/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs index e0a043cc64..241899e1c5 100644 --- a/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs +++ b/src/Nncase.Core/Passes/Mutators/UnRollLoopSequential.cs @@ -144,7 +144,10 @@ public LoopBodyCloner(IReadOnlyDictionary vmap, Dictionary expr; + protected override Expr VisitLeafMemSpan(MemSpan expr, Unit context) + { + return expr.With(Clone(expr.Start, context), Clone(expr.Size, context)); + } protected override Expr VisitLeafVar(Var expr, Unit context) { @@ -189,9 +192,10 @@ protected override Expr VisitLeafRange(TIR.Range expr, Unit context) return CSE(expr.With(start: Clone(expr.Start, context), stop: Clone(expr.Stop, context), step: Clone(expr.Step, context))); } - protected override Expr VisitLeafLogicalBuffer(LogicalBuffer expr, Unit context) + protected override Expr VisitLeafBuffer(TIR.Buffer expr, Unit context) { return expr.With( + memSpan: Clone(expr.MemSpan, context), dimensions: CloneArray(expr.Dimensions, context).Select(e => CSE(e)).ToArray(), strides: CloneArray(expr.Strides, context)); } diff --git a/src/Nncase.Core/Schedule/ScheduleTypes.cs b/src/Nncase.Core/Schedule/ScheduleTypes.cs index 55b89ce8a0..6800f79fad 100644 --- a/src/Nncase.Core/Schedule/ScheduleTypes.cs +++ b/src/Nncase.Core/Schedule/ScheduleTypes.cs @@ -10,52 +10,6 @@ namespace Nncase.Schedule; -/// -/// the memory type. -/// -public enum MemoryLocation : byte -{ - /// - /// input. - /// - Input = 0, - - /// - /// output. - /// - Output = 1, - - /// - /// constant data. - /// - Rdata = 2, - - /// - /// compute temp data. - /// - Data = 3, - - /// - /// shared data. - /// - SharedData = 4, - - /// - /// l2 data. - /// - L2Data = 5, - - /// - /// L1 data. - /// - L1Data = 6, - - /// - /// base addr. - /// - PrivateBase = 64, -} - /// /// the scheduler interface. /// @@ -261,12 +215,12 @@ public SchedFunctionResult() /// /// Gets the buffer allocation. /// - public HashSet Rdatas { get; } + public Dictionary> Rdatas { get; } /// /// Gets or sets the data section length. /// - public int DataUsage { get; set; } + public long DataUsage { get; set; } /// /// Gets or sets a value indicating whether the Scheduled status. @@ -296,8 +250,8 @@ public override bool Equals(object? obj) return true; } - return EqualityComparer>.Default.Equals(Rdatas, result.Rdatas) && - EqualityComparer.Default.Equals(DataUsage, result.DataUsage); + return EqualityComparer>>.Default.Equals(Rdatas, result.Rdatas) && + EqualityComparer.Default.Equals(DataUsage, result.DataUsage); } /// diff --git a/src/Nncase.Core/TIR/Buffer.cs b/src/Nncase.Core/TIR/Buffer.cs index a3a35aac52..570a3d43d8 100644 --- a/src/Nncase.Core/TIR/Buffer.cs +++ b/src/Nncase.Core/TIR/Buffer.cs @@ -267,233 +267,58 @@ public SelectedRange Slice(Segment1D segment) /// /// buffer. /// -public abstract class Buffer : Expr +public sealed class Buffer : Expr { - public Buffer(string name, DataType elemType, Schedule.MemoryLocation memoryLocation, Expr[] operands) - : base(operands.ToArray()) + public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions, Expr[] strides) + : base(new[] { memSpan }.Concat(dimensions).Concat(strides)) { Name = name; ElemType = elemType; - MemLocation = memoryLocation; + Rank = dimensions.Length; } public string Name { get; } public DataType ElemType { get; } - public Schedule.MemoryLocation MemLocation { get; } - - /// - /// Gets if this buffer from the constant !. - /// - public TensorConst? Const { get; init; } - /// /// Gets rank of the tensor: number of dimensions. /// - public abstract int Rank { get; } - - /// - /// Gets the strides. - /// - /// This Strides is by elements not by bytes! - /// - /// - public abstract ReadOnlySpan Strides { get; } + public int Rank { get; } /// /// Gets the shape. /// - public abstract ReadOnlySpan Dimensions { get; } - - /// - public override bool Equals(object? obj) - { - if (obj is not Buffer other) - { - return false; - } - - if (Const is not null && !Const.Equals(other.Const)) - { - return false; - } - - return string.Equals(Name, other.Name, StringComparison.Ordinal) && - ElemType.Equals(other.ElemType) && - MemLocation.Equals(other.MemLocation) && - Rank.Equals(other.Rank) && - base.Equals(obj); - } -} - -/// -/// the logical buffer. -/// -public sealed class LogicalBuffer : Buffer -{ - /// - /// Initializes a new instance of the class. - /// create from the IRType. - /// - /// the name. - /// the location. - /// prim type. - /// the shape. - /// the strides. - public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides) - : base(name, elemType, location, ArrayUtility.Concat(dimensions, strides)) - { - Rank = dimensions.Length; - } - - /// - /// Initializes a new instance of the class. - /// . - /// - public LogicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor) - : this(name, tensor.Value.ElementType, location, ArrayUtility.ToExprArray(tensor.Value.Dimensions), ArrayUtility.ToExprArray(tensor.Value.Strides)) - { - Const = tensor; - } - - /// - /// Initializes a new instance of the class. - /// - /// - public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions) - : this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions)) - { - } - - /// - /// Gets get the total length. - /// - public Expr Length => TensorUtilities.GetProduct(Dimensions); + public MemSpan MemSpan => (MemSpan)Operands[0]; /// /// Gets the shape. /// - public override ReadOnlySpan Dimensions => Operands[0..Rank]; + public ReadOnlySpan Dimensions => Operands[1..(1 + Rank)]; /// /// Gets the strides. + /// + /// This Strides is by elements not by bytes! + /// /// - public override ReadOnlySpan Strides => Operands[Rank..]; - - /// - public override int Rank { get; } - - /// - public override string ToString() - { - return $"LogicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})"; - } - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitLogicalBuffer(this, context); - - public LogicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, Expr[]? dimensions = null, Expr[]? strides = null) - => new LogicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? Dimensions, strides ?? Strides) { Const = Const }; -} - -/// -/// the physical buffer. -/// -public sealed class PhysicalBuffer : Buffer -{ - private readonly int[] _fixedDimensions; - private readonly int[] _fixedStrides; - - /// - /// Initializes a new instance of the class. - /// ctor for physical buffer. - /// - public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions, ReadOnlySpan strides, int start, int size) - : base(name, elemType, location, Array.Empty()) - { - Start = start; - Size = size; - _fixedDimensions = dimensions.ToArray(); - _fixedStrides = strides.ToArray(); - } - - /// - /// Initializes a new instance of the class. - /// . - /// - public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan dimensions, int start, int size) - : this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions), start, size) - { - } - - /// - /// Initializes a new instance of the class. - /// . - /// - public PhysicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor, int start, int size) - : this(name, tensor.Value.ElementType, location, tensor.Value.Dimensions, tensor.Value.Strides, start, size) - { - Const = tensor; - } - - /// - /// Gets fixed dimensions. - /// - public ReadOnlySpan FixedDimensions => _fixedDimensions; - - /// - /// Gets fixed strides. - /// - public ReadOnlySpan FixedStrides => _fixedStrides; - - /// - /// Gets or sets start. - /// - public int Start { get; set; } - - /// - /// Gets total size in bytes. - /// - public int Size { get; init; } - - /// - /// Gets dimensions. - /// - public override ReadOnlySpan Dimensions => ArrayUtility.ToExprArray(FixedDimensions); - - /// - /// Gets strides. - /// - public override ReadOnlySpan Strides => ArrayUtility.ToExprArray(FixedStrides); - - /// - /// Gets shape. - /// - public Shape Shape => new Shape(FixedDimensions); + public ReadOnlySpan Strides => Operands[(1 + Rank)..(1 + Rank + Rank)]; - /// - public override int Rank => FixedDimensions.Length; + public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitBuffer(this, context); - /// - public override string ToString() - { - return $"PhysicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})"; - } + public Buffer With(MemSpan? memSpan = null, Expr[]? dimensions = null, Expr[]? strides = null) + => new Buffer(Name, ElemType, memSpan ?? MemSpan, dimensions ?? Dimensions.ToArray(), strides ?? Strides.ToArray()); /// public override bool Equals(object? obj) { - return base.Equals(obj) && obj is PhysicalBuffer other && - FixedDimensions.SequenceEqual(other.FixedDimensions) && - FixedStrides.SequenceEqual(other.FixedStrides); - } + if (ReferenceEquals(this, obj)) + { + return true; + } - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitPhysicalBuffer(this, context); + return obj is TIR.Buffer other && GetHashCode() == other.GetHashCode() && Name == other.Name && ElemType == other.ElemType && Rank == other.Rank && Operands.SequenceEqual(other.Operands); + } - public PhysicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, int[]? dimensions = null, int[]? strides = null, int? start = null, int? size = null) - => new PhysicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? FixedDimensions, strides ?? FixedStrides, start ?? Start, size ?? Size) { Const = Const }; + protected override int GetHashCodeCore() => HashCode.Combine(Name, ElemType, Rank, base.GetHashCodeCore()); } diff --git a/src/Nncase.Core/TIR/BufferLoad.cs b/src/Nncase.Core/TIR/BufferLoad.cs deleted file mode 100644 index 86081624dd..0000000000 --- a/src/Nncase.Core/TIR/BufferLoad.cs +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Nncase.IR; -using Nncase.Utilities; - -namespace Nncase.TIR; - -/// -/// Buffer load node. -/// -public sealed class BufferLoad : Expr -{ - public BufferLoad(PhysicalBuffer buffer, ReadOnlySpan indices) - : base(ArrayUtility.Concat(buffer, indices)) - { - } - - /// - /// Gets the buffer to be loaded. - /// - public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0]; - - /// - /// Gets the buffer indices. - /// - public ReadOnlySpan Indices => Operands.Slice(1); - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitBufferLoad(this, context); - - public BufferLoad With(PhysicalBuffer? buffer = null, Expr[]? indices = null) - => new BufferLoad(buffer ?? Buffer, indices ?? Indices); -} diff --git a/src/Nncase.Core/TIR/BufferStore.cs b/src/Nncase.Core/TIR/BufferStore.cs deleted file mode 100644 index 56d6a9df4d..0000000000 --- a/src/Nncase.Core/TIR/BufferStore.cs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Nncase.IR; - -namespace Nncase.TIR; - -/// -/// Buffer store node. -/// -public sealed class BufferStore : Expr -{ - private readonly int _indicesCount; - - public BufferStore(PhysicalBuffer buffer, ReadOnlySpan indices, Expr value) - : base(new Expr[] { buffer }.Concat(indices.ToArray()).Append(value).ToArray()) - { - _indicesCount = indices.Length; - } - - /// - /// Gets the buffer. - /// - public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0]; - - /// - /// Gets the value we to be stored. - /// - public ReadOnlySpan Indices => Operands[1.._indicesCount]; - - /// - /// Gets the indices location to be stored. - /// - public Expr Value => Operands[_indicesCount + 1]; - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitBufferStore(this, context); - - public BufferStore With(PhysicalBuffer? buffer = null, Expr[]? indices = null, Expr? value = null) - => new BufferStore(buffer ?? Buffer, indices ?? Indices, value ?? Value); -} diff --git a/src/Nncase.Core/TIR/MemSpan.cs b/src/Nncase.Core/TIR/MemSpan.cs new file mode 100644 index 0000000000..f8e537d549 --- /dev/null +++ b/src/Nncase.Core/TIR/MemSpan.cs @@ -0,0 +1,105 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase; +using Nncase.IR; + +namespace Nncase.TIR; + +/// +/// the memory type. +/// +[Flags] +public enum MemoryLocation +{ + /// + /// input. + /// + Input = 1 << 1, + + /// + /// output. + /// + Output = 1 << 2, + + /// + /// constant data. + /// + Rdata = 1 << 3, + + /// + /// compute temp data. + /// + Data = 1 << 4, + + /// + /// shared data. + /// + SharedData = 1 << 5, + + /// + /// l2 data. + /// + L2Data = 1 << 6, + + /// + /// L1 data. + /// + L1Data = 1 << 7, + + /// + /// base addr. + /// + PrivateBase = 1 << 8, +} + +public sealed class MemSpan : Expr +{ + public MemSpan(Expr size, MemoryLocation location) + : base(new[] { None.Default, size }) + { + Location = location; + } + + public MemSpan(Expr start, Expr size, MemoryLocation location) + : base(new[] { start, size }) + { + Location = location; + } + + /// + /// Gets the start. + /// + public Expr Start => Operands[0]; + + /// + /// Gets the size of bytes. + /// + public Expr Size => Operands[1]; + + /// + /// Gets the memory location. + /// + public MemoryLocation Location { get; } + + public MemSpan SubSpan(Expr offset, Expr size) => new MemSpan((Start is None ? IR.F.Buffer.DDrOf(this) : Start) + offset, size, Location); + + /// + public override TExprResult Accept(ExprFunctor functor, TContext context) + => functor.VisitMemSpan(this, context); + + public MemSpan With(Expr? start = null, Expr? size = null, MemoryLocation? location = null) => new(start ?? Start, size ?? Size, location ?? Location); + + /// + public override bool Equals(object? obj) + { + if (ReferenceEquals(this, obj)) + { + return true; + } + + return obj is MemSpan other && GetHashCode() == other.GetHashCode() && Location == other.Location && Operands.SequenceEqual(other.Operands); + } + + protected override int GetHashCodeCore() => HashCode.Combine(Location, base.GetHashCodeCore()); +} diff --git a/src/Nncase.Core/TIR/Ops.cs b/src/Nncase.Core/TIR/Ops.cs index 76f9e395b6..3405cdc841 100644 --- a/src/Nncase.Core/TIR/Ops.cs +++ b/src/Nncase.Core/TIR/Ops.cs @@ -12,19 +12,22 @@ namespace Nncase.TIR; /// -/// . +/// Load op. /// public sealed partial class Load : Op { /// /// Gets handle. /// - public static readonly ParameterInfo Handle = new(typeof(Load), 0, "handle"); + public static readonly ParameterInfo Handle = new(typeof(Load), 0, "handle", IsPointer() | IsIntegralScalar()); /// /// Gets index. /// - public static readonly ParameterInfo Index = new(typeof(Load), 1, "index", HasDataType(DataTypes.Int32) & (IsScalar() | HasRank(1))); + public static readonly ParameterInfo Index = new(typeof(Load), 1, "index", IsIntegralScalar()); + + /// + public override bool CanFoldConstCall => false; } /// @@ -53,17 +56,20 @@ public sealed partial class Store : Op /// /// The buffer variable handle. /// - public static readonly ParameterInfo Handle = new(typeof(Store), 0, "handle", IsPointer()); + public static readonly ParameterInfo Handle = new(typeof(Store), 0, "handle", IsPointer() | IsIntegralScalar()); /// /// The index locations to be stored. /// - public static readonly ParameterInfo Index = new(typeof(Store), 1, "index", HasDataType(DataTypes.Int32)); + public static readonly ParameterInfo Index = new(typeof(Store), 1, "index", IsIntegralScalar()); /// /// The value to be stored. /// - public static readonly ParameterInfo Value = new(typeof(Store), 2, "value"); + public static readonly ParameterInfo Value = new(typeof(Store), 2, "value", IsScalar()); + + /// + public override bool CanFoldConstCall => false; } /// diff --git a/src/Nncase.Core/TIR/PrimFunction.cs b/src/Nncase.Core/TIR/PrimFunction.cs index ea208efb15..2bf94454eb 100644 --- a/src/Nncase.Core/TIR/PrimFunction.cs +++ b/src/Nncase.Core/TIR/PrimFunction.cs @@ -28,8 +28,8 @@ public sealed class PrimFunction : BaseFunction /// module kind. /// Arguments. /// Body. - public PrimFunction(string name, string moduleKind, Sequential body, ReadOnlySpan parameters) - : base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast(parameters))) + public PrimFunction(string name, string moduleKind, Sequential body, ReadOnlySpan parameters) + : base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast(parameters))) { } @@ -39,7 +39,7 @@ public PrimFunction(string name, string moduleKind, Sequential body, ReadOnlySpa /// module kind. /// Arguments. /// Body. - public PrimFunction(string moduleKind, Sequential body, ReadOnlySpan parameters) + public PrimFunction(string moduleKind, Sequential body, ReadOnlySpan parameters) : this($"primfunc_{_globalFuncIndex++}", moduleKind, body, parameters) { } @@ -48,7 +48,7 @@ public PrimFunction(string moduleKind, Sequential body, ReadOnlySpan class. /// build function. /// - public PrimFunction(string moduleKind, Sequential body, params PhysicalBuffer[] parameters) + public PrimFunction(string moduleKind, Sequential body, params Buffer[] parameters) : this($"primfunc_{_globalFuncIndex++}", moduleKind, body, new(parameters)) { } @@ -58,7 +58,7 @@ public PrimFunction(string moduleKind, Sequential body, params PhysicalBuffer[] /// public Sequential Body => (Sequential)Operands[0]; - public ReadOnlySpan Parameters => SpanUtility.UnsafeCast(Operands.Slice(1)); + public ReadOnlySpan Parameters => SpanUtility.UnsafeCast(Operands.Slice(1)); public override IEnumerable ParameterTypes => Parameters.AsValueEnumerable().Select(x => x.CheckedType).ToArray(); @@ -66,7 +66,7 @@ public PrimFunction(string moduleKind, Sequential body, params PhysicalBuffer[] public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitPrimFunction(this, context); - public PrimFunction With(string? name = null, string? moduleKind = null, Sequential? body = null, PhysicalBuffer[]? parameters = null, Schedule.SchedFunctionResult? sched = null) + public PrimFunction With(string? name = null, string? moduleKind = null, Sequential? body = null, Buffer[]? parameters = null, Schedule.SchedFunctionResult? sched = null) => new PrimFunction(name ?? Name, moduleKind ?? ModuleKind, body ?? Body, parameters ?? Parameters) { // note maybe add SchedResult into ctor. diff --git a/src/Nncase.Core/TIR/Scheduler.cs b/src/Nncase.Core/TIR/Scheduler.cs index 214bd983f6..30e9552c04 100644 --- a/src/Nncase.Core/TIR/Scheduler.cs +++ b/src/Nncase.Core/TIR/Scheduler.cs @@ -87,7 +87,10 @@ public For[] Split(For loop, params Expr[] factors) } // TODO add assert total == (loop.Dom.Max - loop.Dom.Min) // Step 2. Replace all occurrences of the original loop var with new variables - Expr total = 1, substitute = 0; + _ = 1; + + // Step 2. Replace all occurrences of the original loop var with new variables + Expr substitute = 0; var newloopVars = new Var[factors.Length]; foreach (var i in Enumerable.Range(0, factors.Length)) { @@ -96,21 +99,23 @@ public For[] Split(For loop, params Expr[] factors) newloopVars[i] = loopVar; } - Dictionary opaque_block_reuse = new(); // TODO the opaque_block_reuse for what? - Sequential nbody = loop.Body; + _ = new + Dictionary(); // TODO the opaque_block_reuse for what? + _ = loop.Body; // Step 3. create new for loop. var nFor = new For[factors.Length]; - nbody = (Sequential)new Passes.Mutators.SubstituteVarAndCollectOpaqueBlock(v => v == loop.LoopVar ? substitute : v, opaque_block_reuse).Rewrite(nbody); - for (int i = factors.Length - 1; i >= 0; i--) - { - var @for = new For(newloopVars[i], (0, factors[i]), LoopMode.Serial, nbody); - nbody = T.Sequential(@for); - nFor[i] = @for; - } - // Setp 4. update the function - Entry = (Function)new Passes.Mutators.Substitutor(expr => object.ReferenceEquals(expr, loop) ? nFor[0] : null).Rewrite(Entry); + // nbody = (Sequential)new Passes.Mutators.SubstituteVarAndCollectOpaqueBlock(v => v == loop.LoopVar ? substitute : v, opaque_block_reuse).Rewrite(nbody); + // for (int i = factors.Length - 1; i >= 0; i--) + // { + // var @for = new For(newloopVars[i], (0, factors[i]), LoopMode.Serial, nbody); + // nbody = T.Sequential(@for); + // nFor[i] = @for; + // } + + // // Setp 4. update the function + // Entry = (Function)new Passes.Mutators.Substitutor(expr => object.ReferenceEquals(expr, loop) ? nFor[0] : null).Rewrite(Entry); return nFor; } diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index 9d9a212e46..28740e43ab 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -52,7 +52,7 @@ public static class T /// /// The buffer handle variable in the load expression. /// The index in the load. - public static Call Load(Var handle, Expr index) => new Call(new Load(), handle, index); + public static Call Load(Expr handle, Expr index) => new Call(new Load(), handle, index); /// /// get the nop op. @@ -76,25 +76,7 @@ public static class T /// The buffer Variable. /// The index in the store expression. /// The value we want to store. - public static Call Store(Var handle, Expr index, Expr value) => new Call(new Store(), handle, index, value); - - /// - /// If the op is BufferLoad, it will return BufferStore - /// If the op is Load, it will return Store. - /// - /// the op call. - /// update value. - /// new store call. - public static Expr Store(Expr op, Expr value) => op switch - { - Call load => load.Target switch - { - TIR.Load => T.Store((Var)load[TIR.Load.Handle], load[TIR.Load.Index], value), - _ => throw new InvalidOperationException("Only Can build Store Op from Load!"), - }, - TIR.BufferLoad bufload => new BufferStore(bufload.Buffer, bufload.Indices, value), - _ => throw new InvalidOperationException("Only Can build Store Op from Load!"), - }; + public static Call Store(Expr handle, Expr index, Expr value) => new Call(new Store(), handle, index, value); /// /// build for loop. @@ -202,7 +184,7 @@ public static ISequentialBuilder Sequential() /// )); /// /// - public static ISequentialBuilder PrimFunc(string name, string module_kind, params PhysicalBuffer[] parameters) + public static ISequentialBuilder PrimFunc(string name, string module_kind, params Buffer[] parameters) { return new SequentialBuilder(body => new PrimFunction(name, module_kind, body, parameters)); } @@ -224,54 +206,73 @@ public static IIfThenElseBuilder If(Expr condition) } /// - /// create the memRef by tensortype. + /// create the buffer by tensortype. /// - public static LogicalBuffer Buffer(DataType elem_type, Schedule.MemoryLocation location, ReadOnlySpan dimensions, out LogicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { name = name[4..]; } - buffer = new LogicalBuffer(name, elem_type, location, dimensions); + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + var memspan = new MemSpan(size, location); + buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } /// - /// ctor for physical buffer. + /// create buffer by const. /// - public static PhysicalBuffer PhysicalBuffer(DataType elem_type, Schedule.MemoryLocation location, ReadOnlySpan dimensions, out PhysicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { name = name[4..]; } - buffer = new PhysicalBuffer(name, elem_type, location, dimensions, 0, (int)TensorUtilities.GetProduct(dimensions.ToArray()) * elem_type.SizeInBytes); + var dimensions = @const.ValueType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes; + var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata); + buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } /// - /// create buffer from const. + /// attach the buffer. /// - public static PhysicalBuffer ConstBuffer(Const expr, out PhysicalBuffer buffer, [CallerArgumentExpression("buffer")] string name = "") + public static Buffer AttachBuffer(Buffer originBuffer, Expr offset, TensorType tensorType, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") { if (name.StartsWith("var ")) { name = name[4..]; } - int size; - if (expr is TensorConst tc) - { - size = tc.Value.BytesBuffer.Length; - } - else + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + buffer = new Buffer(name, tensorType.DType, originBuffer.MemSpan.SubSpan(offset, size), dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); + return buffer; + } + + /// + /// attach the buffer. + /// + public static Buffer AttachBuffer(TensorType tensorType, MemoryLocation location, out Var @var, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) { - throw new NotSupportedException(); + name = name[4..]; } - buffer = new PhysicalBuffer(name, Schedule.MemoryLocation.Rdata, (TensorConst)expr, 0, size); + @var = new Var(TensorType.Pointer(tensorType.DType)); + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + buffer = new Buffer(name, tensorType.DType, new MemSpan(@var, size, location), dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } @@ -294,7 +295,7 @@ public static Expr MayBeConst(Const? expr, out Buffer? buffer, [CallerArgumentEx { name = name[4..]; } - buffer = new Buffer(name, Schedule.MemoryLocation.Rdata, (TensorType)expr.ValueType) + buffer = new Buffer(name, MemoryLocation.Rdata, (TensorType)expr.ValueType) { Const = expr, }; @@ -331,4 +332,23 @@ public static Call Emit(out T value, Func creator) value = creator(); return Nop(); } + + /// + /// buffer load. + /// + /// buffer. + /// indices. + /// call bufferload. + public static Call BufferLoad(TIR.Buffer buffer, params Expr[] indices) => new Call(new IR.Buffers.BufferLoad(), buffer, new IR.Tuple(indices)); + + /// + /// buffer store. + /// + /// buffer. + /// indices. + /// value. + /// call bufferstore. + public static Call BufferStore(TIR.Buffer buffer, Expr[] indices, Expr value) => new Call(new IR.Buffers.BufferStore(), buffer, new IR.Tuple(indices), value); + + public static Call MatchBuffer(TIR.Buffer buffer) => new Call(new IR.Buffers.MatchBuffer(), buffer); } diff --git a/src/Nncase.Core/Tensor.cs b/src/Nncase.Core/Tensor.cs index 6747cddee8..78f3e5e999 100644 --- a/src/Nncase.Core/Tensor.cs +++ b/src/Nncase.Core/Tensor.cs @@ -342,18 +342,18 @@ public static unsafe Tensor FromArray(Array array) public static Tensor> FromPointer(ulong value) where T : unmanaged, IEquatable { - return Tensor.FromScalar>(new Pointer(value)); + return FromScalar>(new Pointer(value)); } /// /// Create tensor from a ulong address. /// /// addr value. - /// Element type. + /// pointed type. /// Created tensor. public static Tensor FromPointer(ulong value, DataType elemType) { - return Tensor.FromBytes(TensorType.Scalar(new PointerType(elemType)), BitConverter.GetBytes(value)); + return FromBytes(TensorType.Scalar(new PointerType(elemType)), BitConverter.GetBytes(value)); } /// diff --git a/src/Nncase.Core/TensorUtilities.cs b/src/Nncase.Core/TensorUtilities.cs index 717274e0d3..79f658aefa 100644 --- a/src/Nncase.Core/TensorUtilities.cs +++ b/src/Nncase.Core/TensorUtilities.cs @@ -323,10 +323,11 @@ public static bool IsContiguous(ReadOnlySpan dimensions, ReadOnlySpan /// /// check the dimensions selected range is contiguous. /// - public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices) + public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices, out int contiguousStart) { if (dimensions.Length != slices.Length) { + contiguousStart = slices.Length - 1; return false; } @@ -366,13 +367,17 @@ public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan< }; if (status == SliceStatus.IsInvalid) { + contiguousStart = i + 1; return false; } } + contiguousStart = 0; return true; } + public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices) => IsContiguousSlice(dimensions, slices, out _); + public static long[] ToLongs(this ReadOnlySpan ints) { var longs = new long[ints.Length]; diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs new file mode 100644 index 0000000000..13b2870bfb --- /dev/null +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -0,0 +1,144 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using Nncase.IR; + +namespace Nncase.Utilities; + +public static class DistributedUtility +{ + public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tensorType, Placement placement) + { + var ndsbps = new List>(); + for (int i = 0; i < placement.Rank; i++) + { + var ndsbp = new List(); + for (int axis = 0; axis < tensorType.Shape.Rank; axis++) + { + if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i])) + { + ndsbp.Add(SBP.S(axis)); + } + } + + ndsbp.Add(SBP.B); + ndsbps.Add(ndsbp); + } + + return ndsbps.CartesianProduct(). + Select(ndsbp => ndsbp.ToArray()). + Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)). + Select(ndsbp => new IRArray(ndsbp)). + ToArray(); + } + + public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedType distributedType) + { + IRArray ndsbp = distributedType.NdSBP; + TensorType tensorType = distributedType.TensorType; + Placement placement = distributedType.Placement; + if (!ndsbp.Any(sbp => sbp is SBPPartialSum)) + { + return Array.Empty>(); + } + + var candidateNdsbps = new List[placement.Rank]; + for (int i = 0; i < placement.Rank; i++) + { + candidateNdsbps[i] = new List(); + var innerSplitedAxes = distributedType.NdSBP.Skip(i + 1).OfType().Select(sbp => sbp.Axis).ToList(); + if (ndsbp[i] is SBPPartialSum) + { + candidateNdsbps[i].Add(SBP.B); + for (int axis = 0; axis < tensorType.Shape.Rank; axis++) + { + if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis)) + { + candidateNdsbps[i].Add(SBP.S(axis)); + } + } + } + else + { + candidateNdsbps[i].Add(ndsbp[i]); + } + } + + return candidateNdsbps.CartesianProduct(). + Select(ndsbp => ndsbp.ToArray()). + Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)). + Select(ndsbp => new IRArray(ndsbp)). + ToArray(); + } + + public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement, [MaybeNullWhen(false)] out TensorType distType) + { + distType = null; + if (!tensorType.Shape.IsFixed) + { + return false; + } + + var shape = tensorType.Shape.ToValueArray(); + for (int i = 0; i < ndsbp.Length; i++) + { + if (ndsbp[i] is SBPSplit { Axis: int axis }) + { + if (!IsDivisible(shape[axis], placement.Hierarchy[i])) + { + return false; + } + + shape[axis] /= placement.Hierarchy[i]; + } + } + + distType = tensorType with { Shape = shape }; + return true; + } + + public static bool IsDivisible(int input, int divisor) + { + if (input >= divisor && input % divisor == 0) + { + return true; + } + + return false; + } + + public static float GetDividedTensorEfficiency(DistributedType distributedType, int burstLength) + { + var (tiles, shape) = GetDividedTile(distributedType); + return Enumerable.Range(0, tiles.Count). + Select(i => tiles[i].Ranges(0, shape[i])). + CartesianProduct(). + Select(rgs => + { + var slice = rgs.ToArray(); + var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart); + var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; + var (div, rem) = Math.DivRem(size, burstLength); + return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1); + }).Average(); + } + + public static TensorType GetDividedTensorType(DistributedType distributedType) + { + var (tiles, _) = GetDividedTile(distributedType); + return distributedType.TensorType with { Shape = new Shape(tiles) }; + } + + private static (IReadOnlyList Tile, IReadOnlyList Shape) GetDividedTile(DistributedType distributedType) + { + var shape = distributedType.TensorType.Shape.ToValueArray(); + var tiles = distributedType.TensorType.Shape.ToValueArray(); + foreach (var (s, i) in distributedType.NdSBP.Select((s, i) => (s, i)).Where(t => t.s is SBPSplit).Select(t => ((SBPSplit)t.s, t.i))) + { + tiles[s.Axis] /= distributedType.Placement.Hierarchy[i]; + } + + return (tiles, shape); + } +} diff --git a/src/Nncase.Core/packages.lock.json b/src/Nncase.Core/packages.lock.json index b2543377b5..b846e4b245 100644 --- a/src/Nncase.Core/packages.lock.json +++ b/src/Nncase.Core/packages.lock.json @@ -60,13 +60,19 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, + "System.CommandLine": { + "type": "Direct", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "Direct", "requested": "[5.0.0, )", @@ -109,8 +115,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", diff --git a/src/Nncase.Diagnostics/Diagnostics/Dumpper.cs b/src/Nncase.Diagnostics/Diagnostics/Dumpper.cs index 399ba10b28..8bd3794074 100644 --- a/src/Nncase.Diagnostics/Diagnostics/Dumpper.cs +++ b/src/Nncase.Diagnostics/Diagnostics/Dumpper.cs @@ -52,6 +52,12 @@ public void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null) CompilerServices.DumpCSharpIR(expr, prefix, EnsureWritable(path)); } + public void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null) + { + var path = Path.Join(_dumpDirectory, reletivePath); + CompilerServices.DumpPatternIR(expr, prefix, EnsureWritable(path)); + } + public void DumpModule(IRModule module, string? reletivePath = null) { foreach (var func in module.Functions) diff --git a/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs index 409a4846af..c20883fc05 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs @@ -363,13 +363,13 @@ protected override ILDotOption VisitCall(Call expr) _ => throw new NotSupportedException($"Target type {expr.Target.GetType()} is not supported."), })) { - if (child is Const or None) + if (child is None) { continue; } var portName = $"P{count++}"; - row.AddCell(arg_name, cell => cell.PortName = portName); + row.AddCell(child switch { Const c => c.CheckedType.ToString(), _ => arg_name }, cell => cell.PortName = portName); connect_list.Add((child, portName)); } }); @@ -385,7 +385,7 @@ protected override ILDotOption VisitCall(Call expr) // 4. connect edge. foreach (var (child, port_name) in connect_list) { - if (child is BaseFunction) + if (child is BaseFunction or Const) { continue; } diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index 93fa794679..4a447073b8 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -274,6 +274,33 @@ public override string VisitType(CallableType type) => public override string VisitType(TupleType type) => $"({string.Join(", ", type.Fields.Select(VisitType))})"; + /// + public override string VisitType(DistributedType type) + { + var shape = type.TensorType.Shape.ToArray(); + foreach (var (s, r) in type.NdSBP.Select((s, r) => (s, r))) + { + if (s is SBPSplit split) + { + if (shape[split.Axis].IsFixed) + { + shape[split.Axis] = shape[split.Axis] / type.Placement.Hierarchy[r]; + } + } + } + + var sshape = shape.Select(s => s.ToString()).ToArray(); + foreach (var (s, r) in type.NdSBP.Select((s, r) => (s, r))) + { + if (s is SBPSplit split) + { + sshape[split.Axis] += $"@{type.Placement.Name[r]}"; + } + } + + return $"{{{VisitType(type.TensorType)}, ({string.Join(',', type.NdSBP)}), [{string.Join(',', sshape)}]}}"; + } + /// protected override string VisitCall(Call expr) { @@ -449,13 +476,7 @@ protected override string VisitPrimFunctionWrapper(PrimFunctionWrapper expr) /// protected override string VisitOp(Op expr) { - return expr switch - { - Unary op => op.UnaryOp.ToString(), - Binary op => op.BinaryOp.ToString(), - Compare op => op.CompareOp.ToString(), - _ => expr.GetType().Name, - }; + return expr.GetType().Name; } /// diff --git a/src/Nncase.Diagnostics/Diagnostics/IRPrinterProvider.cs b/src/Nncase.Diagnostics/Diagnostics/IRPrinterProvider.cs index d396e9de74..9d2b7c87a9 100644 --- a/src/Nncase.Diagnostics/Diagnostics/IRPrinterProvider.cs +++ b/src/Nncase.Diagnostics/Diagnostics/IRPrinterProvider.cs @@ -102,6 +102,25 @@ public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randCons } } + /// + public void DumpPatternIR(Expr expr, string prefix, string dumpDir) + { + var nprefix = prefix.Any() ? prefix + "_" : prefix; + string ext = "cs"; + string name = expr is Callable c ? c.Name : expr.GetType().Name; + string file_path = Path.Combine(dumpDir, $"{nprefix}{name}.{ext}"); + if (string.IsNullOrEmpty(dumpDir)) + { + throw new ArgumentException("The dumpDir Is Empty!"); + } + + Directory.CreateDirectory(dumpDir); + + using var dumpFile = File.Open(file_path, FileMode.Create); + using var dumpWriter = new StreamWriter(dumpFile); + new PatternPrintVisitor(dumpWriter, 0).Visit(expr); + } + /// public string Print(IRType type) { diff --git a/src/Nncase.Diagnostics/Diagnostics/PatternPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/PatternPrintVisitor.cs new file mode 100644 index 0000000000..3b8d5df531 --- /dev/null +++ b/src/Nncase.Diagnostics/Diagnostics/PatternPrintVisitor.cs @@ -0,0 +1,245 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using NetFabric.Hyperlinq; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.TIR; +using Nncase.Utilities; + +namespace Nncase.Diagnostics; + +internal sealed class PatternPrintVisitor : ExprFunctor +{ + private readonly ScopeWriter _scope; + private readonly Dictionary _names = new Dictionary(ReferenceEqualityComparer.Instance); + private int _localId; + + public PatternPrintVisitor(TextWriter textWriter, int indentLevel) + { + _scope = new(textWriter, indentLevel); + } + + /// + public override string VisitType(AnyType type) => "any"; + + /// + public override string VisitType(CallableType type) => + $"({string.Join(", ", type.Parameters.Select(VisitType))}) -> {VisitType(type.ReturnType)}"; + + /// + public override string VisitType(InvalidType type) => $"invalid:{type.Reason}"; + + /// + public override string VisitType(NoneType type) => $""; + + /// + public override string VisitType(TensorType type) => type.DType switch + { + PrimType ptype => ptype.GetDisplayName() + (type.Shape.IsScalar ? string.Empty : type.Shape.ToString()), + PointerType { ElemType: PrimType etype } => $"*{etype.GetDisplayName()}", + ValueType => $"{type.DType.ToString()}", + _ => throw new NotSupportedException(type.DType.GetType().Name), + }; + + /// + public override string VisitType(TupleType type) => + $"({string.Join(", ", type.Fields.Select(VisitType))})"; + + /// + protected override string VisitCall(Call expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + var target = Visit(expr.Target); + var args = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); + name = AllocateTempVar(expr); + _scope.IndWrite($"var {name} = IsCall(\"{name}\", IsOp<{expr.Target.GetType().Name}>(), IsVArgs({string.Join(",", args)}));\n"); + + // AppendCheckedType(expr.CheckedType); + return name; + } + + /// + protected override string VisitConst(Const expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + name = AllocateTempVar(expr); + _scope.IndWrite($"var {name} = IsTensorConst(\"{name}\");\n"); + return name; + } + + /// + protected override string VisitFunction(Function expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + name = AllocateTempVar(expr); + _scope.Push(); + + // 1. functionv var + _scope.IndWrite($"Function {name}"); + AppendCheckedType(expr.CheckedType); + + // 2. Function body + _scope.IndWriteLine("{"); + using (_scope.IndentUp()) + { + var body = Visit(expr.Body); + + // _scope.IndWriteLine($"{name} = new Function(\"{expr.Name}\", {body}, new Var[] {{{StringUtility.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit))}}});"); + } + + // 3. Function signature + _scope.IndWriteLine("}"); + _scope.Append(_scope.Pop()); + return name; + } + + protected override string VisitFusion(Fusion expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + name = AllocateTempVar(expr); + + _scope.IndWrite($"Fusion {name}"); + AppendCheckedType(expr.CheckedType); + _scope.Push(); + _scope.IndWriteLine("{"); + using (_scope.IndentUp()) + { + var body_builder = new StringBuilder(); + string body; + using (var body_writer = new StringWriter(body_builder)) + { + var visitor = new PatternPrintVisitor(body_writer, _scope.IndentLevel) { _localId = _localId }; + body = visitor.Visit(expr.Body); + _scope.Append(body_writer.ToString()); + } + + _scope.IndWriteLine($"{name} = new Fusion(\"{expr.Name}\", \"{expr.ModuleKind}\", {body}, new Var[] {{{StringUtility.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit))}}});"); + } + + _scope.IndWriteLine("}"); + _scope.Append(_scope.Pop()); + return name; + } + + /// + protected override string VisitOp(Op expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + name = $"new {expr.GetType().Name}({expr.DisplayProperty()})"; + _names.Add(expr, name); + return name; + } + + /// + protected override string VisitTuple(IR.Tuple expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + var fields = expr.Fields.AsValueEnumerable().Select(Visit).ToArray(); + name = AllocateTempVar(expr); + _scope.IndWrite($"var {name} = IsTuple(\"{name}\", IsVArgs({string.Join(",", fields)}));\n"); + + // AppendCheckedType(expr.CheckedType); + _scope.IndWriteLine(); + return name; + } + + /// + protected override string VisitVar(Var expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + name = AllocateTempVar(expr); + _scope.IndWriteLine($"var {name} = IsWildcard(\"{expr.Name}\");\n"); + return name; + } + + /// + protected override string VisitNone(None expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + name = $"None.Default"; + _names.Add(expr, name); + return name; + } + + /// + protected override string VisitMarker(Marker expr) + { + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + var target = Visit(expr.Target); + var attr = Visit(expr.Attribute); + name = AllocateTempVar(expr); + _scope.IndWrite($"var {name} = new Marker(\"{expr.Name}\",{target},{attr})"); + AppendCheckedType(expr.CheckedType); + return name; + } + + private string AllocateTempVar(Expr expr) + { + var name = $"v{_localId++}"; + _names.Add(expr, name); + return name; + } + + private void AppendCheckedType(IRType? type, string end = "", bool hasNewLine = true) + { + if (type is not null) + { + if (hasNewLine) + { + _scope.AppendLine($"; // {VisitType(type)}{end}"); + } + else + { + _scope.Append($"; // {VisitType(type)}{end}"); + } + } + else + { + _scope.Append(";\n"); + } + } +} diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs index 62dac6f69c..ac4b0e3a50 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs @@ -216,6 +216,22 @@ protected override IPrintSymbol VisitTuple(IR.Tuple expr) return doc; } + protected override IPrintSymbol VisitMemSpan(MemSpan expr) + { + if (_exprMemo.TryGetValue(expr, out var doc)) + { + return doc; + } + + var start = Visit(expr.Start); + var size = Visit(expr.Size); + _scope.Push(); + _scope.Append($"MemSpan({start}, {size})@{expr.Location}"); + doc = new(_scope.Pop()); + _exprMemo.Add(expr, doc); + return doc; + } + /// protected override IPrintSymbol VisitMarker(Marker expr) { @@ -279,18 +295,17 @@ protected override IPrintSymbol VisitTensorConst(TensorConst @const) { doc = new(new($"{@const}")); } - else - if (@const.Value.ElementType.IsFloat()) + else if (@const.Value.ElementType.IsFloat()) { - doc = new(new($"{string.Join(",", @const.Value.ToArray())}")); + doc = new(new(@const.Value.Length > 8 ? @const.CheckedShape.ToString() : $"{string.Join(",", @const.Value.ToArray())}")); } else if (@const.Value.ElementType.IsIntegral()) { - doc = new(new($"{string.Join(",", @const.Value.ToArray())}")); + doc = new(new(@const.Value.Length > 8 ? @const.CheckedShape.ToString() : $"{string.Join(",", @const.Value.ToArray())}")); } - else if (@const.Value.ElementType.IsPointer()) + else if (@const.Value.ElementType is PointerType p) { - doc = new(new($"{string.Join(",", @const.Value.ToArray().Select(i => "0x" + i.ToString("X")))}")); + doc = new(new($"*{p.ElemType.GetDisplayName()}@{@const.Value.Shape}")); } _exprMemo.Add(@const, doc!); @@ -481,34 +496,6 @@ protected override IPrintSymbol VisitBlock(Block expr) return doc; } - /// - protected override IPrintSymbol VisitBufferLoad(BufferLoad expr) - { - if (_exprMemo.TryGetValue(expr, out var doc)) - { - return doc; - } - - _scope.Push(); - _scope.Append($"{expr.Buffer.Name}[{string.Join(", ", expr.Indices.ToArray().Select(Visit))}]"); - doc = new(_scope.Pop()); - return doc; - } - - /// - protected override IPrintSymbol VisitBufferStore(BufferStore expr) - { - if (_exprMemo.TryGetValue(expr, out var doc)) - { - return doc; - } - - _scope.Push(); - _scope.Append($"{expr.Buffer.Name}[{string.Join(", ", expr.Indices.ToArray().Select(Visit))}] = {Visit(expr.Value)}"); - doc = new(_scope.Pop()); - return doc; - } - /// protected override IPrintSymbol VisitIterVar(IterVar expr) { @@ -570,12 +557,8 @@ protected override IPrintSymbol VisitBuffer(TIR.Buffer expr) } _scope.Push(); - _scope.Append($"T.Buffer({expr.Name}, {expr.MemLocation}, {VisitType(expr.ElemType)})"); - if (expr is TIR.PhysicalBuffer phy) - { - _scope.Append($"@({phy.Start}, {phy.Size})"); - } - + var memSpan = Visit(expr.MemSpan); + _scope.Append($"T.Buffer({expr.Name}, {VisitType(expr.ElemType)}, {memSpan.Span}, [{string.Join(',', expr.Dimensions.AsValueEnumerable().Select(Visit).Select(e => e.Span.ToString()).ToArray())}], [{string.Join(',', expr.Strides.AsValueEnumerable().Select(Visit).Select(e => e.Span.ToString()).ToArray())}])"); doc = new(_scope.Pop(), expr.Name, true); _exprMemo.Add(expr, doc); return doc; diff --git a/src/Nncase.Diagnostics/packages.lock.json b/src/Nncase.Diagnostics/packages.lock.json index 93fabe1e48..1b9a6c1dd5 100644 --- a/src/Nncase.Diagnostics/packages.lock.json +++ b/src/Nncase.Diagnostics/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.Extensions.Configuration.Abstractions": { @@ -47,8 +47,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -70,6 +70,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -129,6 +130,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs b/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs index bbafcac29c..7c0cfbdc26 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs @@ -26,10 +26,11 @@ public static class EGraphExtractExtensions /// eGraph. /// Root eclass. /// base func cost evaluator. + /// the picks. /// Extracted root expression. - public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator) + public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, out IReadOnlyDictionary picks) { - // 1. set the all expr checked shape + // 1. set enode expr with more accuracy type. foreach (var eclass in eGraph.Classes) { foreach (var nodes in eclass.Nodes) @@ -50,7 +51,7 @@ public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFunc // EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), fs); // } // return new EGraphExtractor(costModel).Extract(root.Find(), eGraph); - return new EGraphExtractors.SatExtractor(costModel).Extract(root.Find(), eGraph); + return new EGraphExtractors.SatExtractor(costModel).Extract(root.Find(), eGraph, out picks); } /// diff --git a/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs index c15dc50d83..fcf2abd729 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs @@ -17,7 +17,7 @@ namespace Nncase.Passes.EGraphExtractors; internal interface IExtractor { - Expr Extract(EClass root, IEGraph eGraph); + Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary picks); } internal class Extractor : IExtractor @@ -25,6 +25,7 @@ internal class Extractor : IExtractor private readonly EGraphCostModel _costModel; private readonly Dictionary _eclassMemo = new(); private readonly Dictionary _markerEclassMemo = new(); + private readonly Dictionary _picks = new(); private StreamWriter? _dumpWriter; public Extractor(EGraphCostModel costModel) @@ -32,7 +33,7 @@ public Extractor(EGraphCostModel costModel) _costModel = costModel; } - public Expr Extract(EClass root, IEGraph eGraph) + public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary picks) { _dumpWriter = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost) ? new StreamWriter(DumpScope.Current.OpenFile($"{nameof(Extractor)}_Class_{root.Id}.txt")) @@ -46,6 +47,15 @@ public Expr Extract(EClass root, IEGraph eGraph) _dumpWriter?.Dispose(); } + foreach (var enode in eGraph.Nodes) + { + if (!_picks.ContainsKey(enode)) + { + _picks[enode] = false; + } + } + + picks = _picks; return _eclassMemo[root]; } @@ -132,6 +142,7 @@ private void Visit(EClass eclass) _eclassMemo.Add(eclass, expr); } + _picks[minCostEnode] = true; stack.Pop(); } } diff --git a/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs index 36cf26a0d3..038873173b 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs @@ -22,7 +22,7 @@ public SatExtractor(EGraphCostModel costModel) _costModel = costModel; } - public Expr Extract(EClass root, IEGraph eGraph) + public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary picks) { var cpmodel = new CpModel(); @@ -108,13 +108,13 @@ public Expr Extract(EClass root, IEGraph eGraph) throw new InvalidProgramException("SatExtract Failed!"); } - var pick = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e])); + picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e])); using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null) { - EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, pick, root.Find(), dumpStream); + EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream); } - return new SatExprBuildVisitor(pick).Visit(root); + return new SatExprBuildVisitor(picks).Visit(root); } private void EliminateAllCycles(EClass root, LinkedList<(EClass Class, ENode Node)> path, Dictionary> pathMemo, Dictionary visited, CpModel cpModel, Dictionary vars) diff --git a/src/Nncase.EGraph/Passes/RewriteProvider.cs b/src/Nncase.EGraph/Passes/RewriteProvider.cs index 7642a5395e..07d3416edf 100644 --- a/src/Nncase.EGraph/Passes/RewriteProvider.cs +++ b/src/Nncase.EGraph/Passes/RewriteProvider.cs @@ -36,7 +36,7 @@ public Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext var graph = new EGraph(expr); ERewrite(graph, rules, options); - var post = graph.Extract(graph.Root!, null); + var post = graph.Extract(graph.Root!, null, out _); return post; } diff --git a/src/Nncase.EGraph/packages.lock.json b/src/Nncase.EGraph/packages.lock.json index 83470253e5..2fb3aca77d 100644 --- a/src/Nncase.EGraph/packages.lock.json +++ b/src/Nncase.EGraph/packages.lock.json @@ -41,11 +41,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Google.OrTools.runtime.linux-arm64": { @@ -140,8 +140,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -163,6 +163,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -227,6 +228,12 @@ "libortki": "0.0.2" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Evaluator/Buffer.cs b/src/Nncase.Evaluator/Buffer.cs deleted file mode 100644 index 079cdaecce..0000000000 --- a/src/Nncase.Evaluator/Buffer.cs +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -namespace Nncase.Evaluator.TIR -{ - public class Buffer - { - } -} diff --git a/src/Nncase.Evaluator/Buffers/BufferLoad.cs b/src/Nncase.Evaluator/Buffers/BufferLoad.cs new file mode 100644 index 0000000000..78bab2e920 --- /dev/null +++ b/src/Nncase.Evaluator/Buffers/BufferLoad.cs @@ -0,0 +1,42 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.IR.Buffers; + +namespace Nncase.Evaluator.Buffers; + +/// +/// Evaluator for BufferOf. +/// +[TypeInferGenerator] +public partial class BufferLoadEvaluator : ITypeInferencer, IOpPrinter +{ + public string Visit(IIRPrinterContext context, BufferLoad target, bool iLmode) + { + if (iLmode) + { + throw new System.NotSupportedException(); + } + + return $"{context.GetArgument(target, BufferLoad.Input)}[{context.GetArgument(target, BufferLoad.Indices)}]"; + } + + private IRType Visit(TensorType input, TupleType indices) + { + if (indices.Count != input.Shape.Rank) + { + return new InvalidType($"the input buffer rank {input.Shape.Rank} != indices.Count {indices.Count}"); + } + + foreach (var item in indices) + { + if (item is not TensorType { IsScalar: true, DType: var dtype } || dtype != DataTypes.Int32) + { + return new InvalidType("indices is not int32 type!"); + } + } + + return TensorType.Scalar(input.DType); + } +} diff --git a/src/Nncase.Evaluator/Buffers/BufferModule.cs b/src/Nncase.Evaluator/Buffers/BufferModule.cs index 4547718379..a2512b6f13 100644 --- a/src/Nncase.Evaluator/Buffers/BufferModule.cs +++ b/src/Nncase.Evaluator/Buffers/BufferModule.cs @@ -20,5 +20,8 @@ public void ConfigureServices(IRegistrator registrator) registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); } } diff --git a/src/Nncase.Evaluator/Buffers/BufferStore.cs b/src/Nncase.Evaluator/Buffers/BufferStore.cs new file mode 100644 index 0000000000..81a833f79e --- /dev/null +++ b/src/Nncase.Evaluator/Buffers/BufferStore.cs @@ -0,0 +1,47 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.IR.Buffers; + +namespace Nncase.Evaluator.Buffers; + +/// +/// Evaluator for BufferOf. +/// +[TypeInferGenerator] +public partial class BufferStoreEvaluator : ITypeInferencer, IOpPrinter +{ + public string Visit(IIRPrinterContext context, BufferStore target, bool iLmode) + { + if (iLmode) + { + throw new System.NotSupportedException(); + } + + return $"{context.GetArgument(target, BufferStore.Input)}[{context.GetArgument(target, BufferStore.Indices)}] = {context.GetArgument(target, BufferStore.Value)}"; + } + + private IRType Visit(TensorType input, TupleType indices, TensorType value) + { + if (indices.Count != input.Shape.Rank) + { + return new InvalidType($"the input buffer rank {input.Shape.Rank} != indices.Count {indices.Count}"); + } + + foreach (var item in indices) + { + if (item is not TensorType { IsScalar: true, DType: var dtype } || dtype != DataTypes.Int32) + { + return new InvalidType("indices is not int32 type!"); + } + } + + if (!value.IsScalar || input.DType != value.DType) + { + return new InvalidType("value can't store!"); + } + + return TupleType.Void; + } +} diff --git a/src/Nncase.Evaluator/Buffers/DDrOf.cs b/src/Nncase.Evaluator/Buffers/DDrOf.cs index eb6de07acf..86c53e04b7 100644 --- a/src/Nncase.Evaluator/Buffers/DDrOf.cs +++ b/src/Nncase.Evaluator/Buffers/DDrOf.cs @@ -14,6 +14,6 @@ public partial class DDrOfEvaluator : ITypeInferencer { private IRType Visit(TensorType input) { - return new PointerType(input.DType); + return TensorType.Pointer(input.DType); } } diff --git a/src/Nncase.Evaluator/Buffers/MatchBuffer.cs b/src/Nncase.Evaluator/Buffers/MatchBuffer.cs new file mode 100644 index 0000000000..7a8122d2ae --- /dev/null +++ b/src/Nncase.Evaluator/Buffers/MatchBuffer.cs @@ -0,0 +1,29 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.IR; +using Nncase.IR.Buffers; + +namespace Nncase.Evaluator.Buffers; + +/// +/// Evaluator for BufferOf. +/// +[TypeInferGenerator] +public partial class MatchBufferEvaluator : ITypeInferencer, IOpPrinter +{ + public string Visit(IIRPrinterContext context, MatchBuffer target, bool iLmode) + { + if (iLmode) + { + throw new System.NotSupportedException(); + } + + return $"Matched {context.GetArgument(target, MatchBuffer.Input)}"; + } + + private IRType Visit() + { + return TupleType.Void; + } +} diff --git a/src/Nncase.Evaluator/Imaging/ResizeImage.cs b/src/Nncase.Evaluator/Imaging/ResizeImage.cs index eab837a460..e25db7b8c0 100644 --- a/src/Nncase.Evaluator/Imaging/ResizeImage.cs +++ b/src/Nncase.Evaluator/Imaging/ResizeImage.cs @@ -110,16 +110,55 @@ public IValue OnnxResize(IEvaluateContext context, ResizeImage target) /// public IRType Visit(ITypeInferenceContext context, ResizeImage target) { - var input = context.CheckArgumentType(target, ResizeImage.Input); + var input = context.CheckArgumentType(target, ResizeImage.Input); var newSize = context.GetArgument(target, ResizeImage.NewSize); + + return input switch + { + TensorType t => Visit(t, newSize), + DistributedType d => Visit(d, newSize), + _ => new InvalidType(input.GetType().ToString()), + }; + } + + public IRType Visit(TensorType input, Expr newSize) + { return TypeInference.ResizeType(input, newSize, null); } + public IRType Visit(DistributedType input, Expr newSize) + { + if (Visit(input.TensorType, newSize) is not TensorType tensorType) + { + return new InvalidType(string.Empty); + } + + var ndsbp = new SBP[input.Placement.Rank]; + + var invalid = new InvalidType($"{input}, not support"); + for (int i = 0; i < input.Placement.Rank; i++) + { + switch (input.NdSBP[i]) + { + case SBPSplit { Axis: int ix } when ix < 2: + ndsbp[i] = SBP.S(ix); + break; + case SBPBroadCast: + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(tensorType, ndsbp, input.Placement); + } + /// public Cost Visit(ICostEvaluateContext context, ResizeImage target) { - var inputType = context.GetArgumentType(target, ResizeImage.Input); - var returnType = context.GetReturnType(); + var inputType = context.GetArgumentType(target, ResizeImage.Input); + var returnType = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs index f5f9d3c65b..1ac424b0be 100755 --- a/src/Nncase.Evaluator/Math/Binary.cs +++ b/src/Nncase.Evaluator/Math/Binary.cs @@ -2,9 +2,11 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using DryIoc; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.Math; +using Nncase.IR.Tensors; using Nncase.Utilities; using OrtKISharp; @@ -42,6 +44,14 @@ public IValue Visit(IEvaluateContext context, Binary binary) { return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); } + else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64)) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } else { return Ort_compute(binary, lhs, rhs); @@ -54,17 +64,24 @@ public IValue Visit(IEvaluateContext context, Binary binary) /// public IRType Visit(ITypeInferenceContext context, Binary target) { - var lhs = context.CheckArgumentType(target, Binary.Lhs); - var rhs = context.CheckArgumentType(target, Binary.Rhs); - return Visit(target, lhs, rhs); + var lhs = context.CheckArgumentType(target, Binary.Lhs); + var rhs = context.CheckArgumentType(target, Binary.Rhs); + return (lhs, rhs) switch + { + (TensorType a, TensorType b) => Visit(target, a, b), + (DistributedType a, DistributedType b) => Visit(target, a, b), + (AnyType, _) => AnyType.Default, + (_, AnyType) => AnyType.Default, + _ => new InvalidType($"{lhs} {rhs}"), + }; } /// public Cost Visit(ICostEvaluateContext context, Binary target) { - var lhsType = context.GetArgumentType(target, Binary.Lhs); - var rhsType = context.GetArgumentType(target, Binary.Rhs); - var outputType = context.GetReturnType(); + var lhsType = context.GetArgumentType(target, Binary.Lhs); + var rhsType = context.GetArgumentType(target, Binary.Rhs); + var outputType = context.GetReturnType(); return new() { @@ -121,6 +138,76 @@ public Expr Visit(IShapeEvaluateContext context, Binary target) return ShapeExprUtility.BroadcastShape(lhs, rhs); } + private IRType Visit(Binary target, DistributedType a, DistributedType b) + { + if (a.Placement != b.Placement) + { + return new InvalidType("lhs rhs have different placement"); + } + + var rType = Visit(target, a.TensorType, b.TensorType); + if (rType is not TensorType tensorType) + { + return rType; + } + + // assume broadcast shapes are left algin + var padA = tensorType.Shape.Rank - a.TensorType.Shape.Rank; + var padB = tensorType.Shape.Rank - b.TensorType.Shape.Rank; + var ndsbp = new SBP[a.Placement.Rank]; + for (int i = 0; i < a.Placement.Rank; i++) + { + switch (a.NdSBP[i], b.NdSBP[i]) + { + case (SBPSplit sa, SBPSplit sb): + if ((padA + sa.Axis) != (padB + sb.Axis)) + { + return new InvalidType($"lhs rhs sbp at {i} not equal"); + } + + ndsbp[i] = SBP.S(padA + sa.Axis); + break; + case (SBPSplit s1, SBPBroadCast): + // invalid (S, B) if B is not broacast + if (s1.Axis + padA - padB >= 0 && b.TensorType.Shape[s1.Axis + padA - padB] != 1) + { + return new InvalidType($"lhs rhs sbp at {i} not broadcast"); + } + + ndsbp[i] = SBP.S(padA + s1.Axis); + break; + case (SBPBroadCast, SBPSplit s2): + // invalid (B, S) if A is not broacast + if (s2.Axis + padB - padA >= 0 && a.TensorType.Shape[s2.Axis + padB - padA] != 1) + { + return new InvalidType($"lhs rhs sbp at {i} not broadcast"); + } + + ndsbp[i] = SBP.S(padB + s2.Axis); + break; + case (SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + case (SBPPartialSum, SBPPartialSum): + if (target.BinaryOp == BinaryOp.Add) + { + ndsbp[i] = SBP.P; + } + else + { + return new InvalidType("lhs rhs all partialsum only can be added."); + } + + break; + case (SBPPartialSum, _): + case (_, SBPPartialSum): + return new InvalidType("not support lhs or rhs partial."); + } + } + + return new DistributedType(tensorType, ndsbp, a.Placement); + } + private int Compute(BinaryOp op, int a, int b) => op switch { BinaryOp.Add => a + b, @@ -149,6 +236,18 @@ public Expr Visit(IShapeEvaluateContext context, Binary target) _ => throw new ArgumentOutOfRangeException(nameof(op)), }; + private ulong Compute(BinaryOp op, ulong a, ulong b) => op switch + { + BinaryOp.Add => a + b, + BinaryOp.Sub => a - b, + BinaryOp.Mul => a * b, + BinaryOp.Div => a / b, + BinaryOp.Mod => a % b, + BinaryOp.Min => System.Math.Min(a, b), + BinaryOp.Max => System.Math.Max(a, b), + _ => throw new ArgumentOutOfRangeException(nameof(op)), + }; + private bool Compute(BinaryOp op, bool a, bool b) => op switch { BinaryOp.LogicalAnd => a & b, @@ -228,26 +327,24 @@ private IRType Visit(Binary target, TensorType lhs, TensorType rhs) return new InvalidType("The Binary Logical Only Accept The Boolean Datatype."); } - if (lhs is { DType: PointerType { ElemType: var letype } } && rhs is { DType: PointerType { ElemType: var retype } }) + if (lhs is { DType: PointerType { ElemType: var letype } }) { - if (letype == retype) + if ((rhs is { DType: PointerType { ElemType: var other } } && letype == other) || rhs.DType == DataTypes.UInt64 || rhs.DType == DataTypes.UInt32) { return TensorType.Pointer(letype); } - else - { - return new InvalidType($"The Binary Lhs {CompilerServices.Print(lhs)} != Rhs {CompilerServices.Print(rhs)}"); - } - } - if (lhs is { DType: PointerType { ElemType: var lt } } && rhs.DType == DataTypes.Int32) - { - return TensorType.Pointer(lt); + return new InvalidType($"The Binary Lhs {CompilerServices.Print(lhs)} != Rhs {CompilerServices.Print(rhs)}"); } - if (lhs.DType == DataTypes.Int32 && rhs is { DType: PointerType { ElemType: var rt } }) + if (rhs is { DType: PointerType { ElemType: var retype } }) { - return TensorType.Pointer(rt); + if ((lhs is { DType: PointerType { ElemType: var other } } && retype == other) || lhs.DType == DataTypes.UInt64 || lhs.DType == DataTypes.UInt32) + { + return TensorType.Pointer(retype); + } + + return new InvalidType($"The Binary Lhs {CompilerServices.Print(lhs)} != Rhs {CompilerServices.Print(rhs)}"); } return TypeInference.BroadcastType(lhs, rhs); diff --git a/src/Nncase.Evaluator/Math/Clamp.cs b/src/Nncase.Evaluator/Math/Clamp.cs index c2da8bb94c..383e2dd509 100644 --- a/src/Nncase.Evaluator/Math/Clamp.cs +++ b/src/Nncase.Evaluator/Math/Clamp.cs @@ -29,25 +29,25 @@ public IValue Visit(IEvaluateContext context, Clamp clamp) /// public IRType Visit(ITypeInferenceContext context, Clamp target) { - var input = context.CheckArgumentType(target, Clamp.Input); + var input = context.CheckArgumentType(target, Clamp.Input); var min = context.CheckArgumentType(target, Clamp.Min); var max = context.CheckArgumentType(target, Clamp.Max); - if (input.DType != min.DType || input.DType != max.DType || min.DType != max.DType) - { - return new InvalidType( - $"clamp type is not equal, input:{input.DType}, min:${min.DType}, max:${max.DType}"); - } - return Visit(input, min, max); + return input switch + { + TensorType t => Visit(t, min, max), + DistributedType d => Visit(d, min, max), + _ => new InvalidType("Wrong Clamp Type!"), + }; } /// public Cost Visit(ICostEvaluateContext context, Clamp target) { - var inputType = context.GetArgumentType(target, Clamp.Input); + var inputType = context.GetArgumentType(target, Clamp.Input); var minType = context.GetArgumentType(target, Clamp.Min); var maxType = context.GetArgumentType(target, Clamp.Max); - var outputType = context.GetReturnType(); + var outputType = context.GetReturnType(); return new() { @@ -71,6 +71,12 @@ public Metric Visit(IMetricEvaluateContext context, Clamp target) private IRType Visit(TensorType input, TensorType min, TensorType max) { + if (input.DType != min.DType || input.DType != max.DType || min.DType != max.DType) + { + return new InvalidType( + $"clamp type is not equal, input:{input.DType}, min:${min.DType}, max:${max.DType}"); + } + if (TypeInference.BroadcastType(input, min) is InvalidType invalidMin) { return invalidMin; @@ -88,4 +94,9 @@ private IRType Visit(TensorType input, TensorType min, TensorType max) return input; } + + private IRType Visit(DistributedType input, TensorType min, TensorType max) + { + return input; + } } diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 4785e1e1c1..1f19b64388 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -19,62 +19,93 @@ namespace Nncase.Evaluator.Math; /// public class MatMulEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IShapeEvaluator, IMetricEvaluator { - /// - public IValue Visit(IEvaluateContext context, MatMul matMul) + public static IRType VisitDistributedType(DistributedType a, DistributedType b) { - var input = context.GetOrtArgumentValue(matMul, MatMul.Lhs); - var other = context.GetOrtArgumentValue(matMul, MatMul.Rhs); - return OrtKI.MatMul(input, other).ToValue(); - } + if (VisitTensorType(a.TensorType, b.TensorType) is not TensorType outType) + { + return new InvalidType(string.Empty); + } - /// - public IRType Visit(ITypeInferenceContext context, MatMul target) - { - var lhs = context.CheckArgumentType(target, MatMul.Lhs); - var rhs = context.CheckArgumentType(target, MatMul.Rhs); - return Visit(lhs, rhs); - } + if (a.Placement != b.Placement) + { + return new InvalidType("placement not equal"); + } - /// - public Cost Visit(ICostEvaluateContext context, MatMul target) - { - var lhs = context.GetArgumentType(target, MatMul.Lhs); - var rhs = context.GetArgumentType(target, MatMul.Rhs); - var outputType = context.GetReturnType(); + var aRank = a.TensorType.Shape.Rank; + var bRank = b.TensorType.Shape.Rank; + var oRank = outType.Shape.Rank; + var aPad = oRank - aRank; + var bPad = oRank - bRank; - uint macPerElement = lhs.Shape[^1].IsFixed ? (uint)lhs.Shape[^1].FixedValue : 1U; - return new() + var ndsbp = new SBP[a.Placement.Rank]; + for (int i = 0; i < a.Placement.Rank; i++) { - [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(lhs) + CostUtility.GetMemoryAccess(rhs), - [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), - [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, macPerElement), - }; - } + var invalid = new InvalidType($"({a.NdSBP[i]}, {b.NdSBP[i]}) not support"); + switch (a.NdSBP[i], b.NdSBP[i]) + { + // split on k + case (SBPSplit { Axis: int ax }, SBPSplit { Axis: int bx }): + if (ax == (aRank - 1) && bx == (bRank - 2)) + { + ndsbp[i] = SBP.P; + } + else if ((ax == (aRank - 1) && bx != (bRank - 2)) || (ax != (aRank - 1) && bx == (bRank - 2))) + { + return invalid; + } + else + { + if ((ax + aPad) == (bx + bPad)) + { + ndsbp[i] = SBP.S(ax + aPad); + } + else + { + return invalid; + } + } - public Metric Visit(IMetricEvaluateContext context, MatMul target) - { - var lhs = context.GetArgumentType(target, MatMul.Lhs); - var rhs = context.GetArgumentType(target, MatMul.Rhs); - var outputType = context.GetReturnType(); - var k = (UInt128)lhs.Shape[^1].FixedValue; - var m = MetricUtility.GetFLOPs(lhs) / k; - var n = MetricUtility.GetFLOPs(rhs) / k; - return new() - { - [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(lhs) + CostUtility.GetMemoryAccess(rhs) + CostUtility.GetMemoryAccess(outputType), - [MetricFactorNames.FLOPs] = m * n * ((2 * k) - 1), - [MetricFactorNames.Parallel] = 4, - }; - } + break; + case (SBPSplit { Axis: int ax }, SBPBroadCast): + if (ax == aRank - 1) + { + return invalid; + } - public Expr Visit(IShapeEvaluateContext context, MatMul target) - { - var lhs = context.GetArgumentShape(target, MatMul.Lhs); - var rhs = context.GetArgumentShape(target, MatMul.Rhs); - return IR.F.ShapeExpr.MatMulShape(lhs, rhs); + // invalid (S, B) if B is not broacast matmul + if (ax < aRank - 2 && !(bRank <= 2 || (ax + aPad - bPad >= 0 && b.TensorType.Shape[ax + aPad - bPad] == 1))) + { + return invalid; + } + + ndsbp[i] = SBP.S(ax + aPad); + break; + case (SBPBroadCast, SBPSplit { Axis: int bx }): + if (bx == bRank - 2) + { + return invalid; + } + + // invalid (B, S) if A is not broacast matmul + if (bx < bRank - 2 && !(aRank <= 2 || (bx + bPad - aPad >= 0 && a.TensorType.Shape[bx + bPad - aPad] == 1))) + { + return invalid; + } + + ndsbp[i] = SBP.S(bx + bPad); + break; + case (SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(outType, ndsbp, a.Placement); } - private IRType Visit(TensorType lhs, TensorType rhs) + public static IRType VisitTensorType(TensorType lhs, TensorType rhs) { if (lhs.Shape.IsUnranked || rhs.Shape.IsUnranked) { @@ -113,4 +144,74 @@ private IRType Visit(TensorType lhs, TensorType rhs) var end = new[] { lhs.Shape[^2], rhs.Shape[^1] }; return new TensorType(lhs.DType, front.Concat(end).ToArray()); } + + /// + public IValue Visit(IEvaluateContext context, MatMul matMul) + { + var input = context.GetOrtArgumentValue(matMul, MatMul.Lhs); + var other = context.GetOrtArgumentValue(matMul, MatMul.Rhs); + return OrtKI.MatMul(input, other).ToValue(); + } + + /// + public IRType Visit(ITypeInferenceContext context, MatMul target) + { + var lhs = context.CheckArgumentType(target, MatMul.Lhs); + var rhs = context.CheckArgumentType(target, MatMul.Rhs); + return (lhs, rhs) switch + { + (DistributedType a, DistributedType b) => VisitDistributedType(a, b), + (TensorType a, TensorType b) => VisitTensorType(a, b), + _ => new InvalidType(string.Empty), + }; + } + + /// + public Cost Visit(ICostEvaluateContext context, MatMul target) + { + var lhs = context.GetArgumentType(target, MatMul.Lhs); + var rhs = context.GetArgumentType(target, MatMul.Rhs); + var outputType = context.GetReturnType(); + + uint macPerElement = 1; + if (lhs is TensorType { Shape: Shape lhsShape }) + { + macPerElement = lhsShape[^1].IsFixed ? (uint)lhsShape[^1].FixedValue : 1U; + } + else if (lhs is DistributedType distributedType) + { + var lhsType = DistributedUtility.GetDividedTensorType(distributedType); + macPerElement = lhsType.Shape[^1].IsFixed ? (uint)lhsType.Shape[^1].FixedValue : 1U; + } + + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(lhs) + CostUtility.GetMemoryAccess(rhs), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, macPerElement), + }; + } + + public Metric Visit(IMetricEvaluateContext context, MatMul target) + { + var lhs = context.GetArgumentType(target, MatMul.Lhs); + var rhs = context.GetArgumentType(target, MatMul.Rhs); + var outputType = context.GetReturnType(); + var k = (UInt128)lhs.Shape[^1].FixedValue; + var m = MetricUtility.GetFLOPs(lhs) / k; + var n = MetricUtility.GetFLOPs(rhs) / k; + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(lhs) + CostUtility.GetMemoryAccess(rhs) + CostUtility.GetMemoryAccess(outputType), + [MetricFactorNames.FLOPs] = m * n * ((2 * k) - 1), + [MetricFactorNames.Parallel] = 4, + }; + } + + public Expr Visit(IShapeEvaluateContext context, MatMul target) + { + var lhs = context.GetArgumentShape(target, MatMul.Lhs); + var rhs = context.GetArgumentShape(target, MatMul.Rhs); + return Cast(IR.F.ShapeExpr.MatMulShape(lhs, rhs), DataTypes.Int32); + } } diff --git a/src/Nncase.Evaluator/Math/ReduceArg.cs b/src/Nncase.Evaluator/Math/ReduceArg.cs index 5ded2c065f..ad865c0e51 100644 --- a/src/Nncase.Evaluator/Math/ReduceArg.cs +++ b/src/Nncase.Evaluator/Math/ReduceArg.cs @@ -40,16 +40,23 @@ public IValue Visit(IEvaluateContext context, ReduceArg reduceArg) /// public IRType Visit(ITypeInferenceContext context, ReduceArg target) { - var input = context.CheckArgumentType(target, ReduceArg.Input); - return Visit(context, target, input); + var input = context.CheckArgumentType(target, ReduceArg.Input); + return input switch + { + TensorType tensorType => Visit(context, target, tensorType), + DistributedType distributedType => Visit(context, target, distributedType), + _ => new InvalidType(string.Empty), + }; } public Cost Visit(ICostEvaluateContext context, ReduceArg target) { - var input = context.GetArgumentType(target, ReduceArg.Input); - var ret = context.GetReturnType(); - uint input_elem = input.Shape.Aggregate(1U, (acc, d) => acc * (d.IsFixed ? (uint)d.FixedValue : 1U)); - uint ret_elem = ret.Shape.Aggregate(1U, (acc, d) => acc * (d.IsFixed ? (uint)d.FixedValue : 1U)); + var input = context.GetArgumentType(target, ReduceArg.Input); + var ret = context.GetReturnType(); + var inShape = input switch { TensorType t => t.Shape, DistributedType d => d.TensorType.Shape, _ => throw new NotImplementedException() }; + var rShape = ret switch { TensorType t => t.Shape, DistributedType d => d.TensorType.Shape, _ => throw new NotImplementedException() }; + uint input_elem = inShape.Aggregate(1U, (acc, d) => acc * (d.IsFixed ? (uint)d.FixedValue : 1U)); + uint ret_elem = rShape.Aggregate(1U, (acc, d) => acc * (d.IsFixed ? (uint)d.FixedValue : 1U)); uint macPerElement = input_elem / ret_elem; return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(input), [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(ret), [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(ret, macPerElement), }; } @@ -93,4 +100,51 @@ private IRType Visit(ITypeInferenceContext context, ReduceArg target, TensorType return new InvalidType("ReduceArg axis and keepDims are not const"); } } + + private IRType Visit(ITypeInferenceContext context, ReduceArg target, DistributedType distributedType) + { + var rType = Visit(context, target, distributedType.TensorType); + if (rType is not TensorType tensorType) + { + return rType; + } + + var inshape = distributedType.TensorType.Shape; + if (context.GetArgument(target, ReduceArg.Axis) is TensorConst axisValue && + context.GetArgument(target, ReduceArg.KeepDims) is TensorConst keepDimsValue) + { + var axis = axisValue.Value.ToScalar(); + axis = axis >= 0 ? axis : inshape.Rank + axis; + var keepdim = keepDimsValue.Value.ToScalar(); + var ndsbp = new SBP[distributedType.Placement.Rank]; + for (int i = 0; i < ndsbp.Length; i++) + { + switch (distributedType.NdSBP[i]) + { + case SBPSplit { Axis: int saxis }: + if (saxis == axis) + { + return new InvalidType("can't split on reduce axis."); + } + + ndsbp[i] = keepdim ? SBP.S(saxis) : SBP.S(saxis > axis ? saxis - 1 : saxis); + break; + case SBPPartialSum: + return new InvalidType("not support partial sum."); + case SBPBroadCast: + ndsbp[i] = SBP.B; + break; + } + } + + return distributedType with { NdSBP = new(ndsbp), TensorType = tensorType }; + } + + if (!distributedType.NdSBP.All(sbp => sbp is SBPBroadCast)) + { + return new InvalidType(string.Empty); + } + + return distributedType with { TensorType = tensorType }; + } } diff --git a/src/Nncase.Evaluator/Math/Unary.cs b/src/Nncase.Evaluator/Math/Unary.cs index 64c3bfbfbf..95824a5bb8 100644 --- a/src/Nncase.Evaluator/Math/Unary.cs +++ b/src/Nncase.Evaluator/Math/Unary.cs @@ -64,21 +64,27 @@ public IValue Visit(IEvaluateContext context, Unary unary) /// public IRType Visit(ITypeInferenceContext context, Unary target) { - var input = context.CheckArgumentType(target, Unary.Input); - return Visit(input); + var inputType = context.GetArgumentType(target, Unary.Input); + + return inputType switch + { + TensorType tensorType => Visit(tensorType), + DistributedType distTensorType => Visit(distTensorType, target.UnaryOp), + AnyType => AnyType.Default, + _ => new InvalidType($"Not support {inputType.GetType().Name}"), + }; } /// public Cost Visit(ICostEvaluateContext context, Unary target) { - var inputType = context.GetArgumentType(target, Unary.Input); - var outputType = context.GetReturnType(); - - return new() + var inputType = context.GetArgumentType(target, Unary.Input); + var outputType = context.GetReturnType(); + return (inputType, outputType) switch { - [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), - [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), - [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, CostUtility.GetCPUCyclesOfUnary(target.UnaryOp)), + (TensorType tensorType, TensorType tensorType1) => Visit(tensorType, tensorType1, target), + (DistributedType distTensorType, DistributedType distTensorType1) => Visit(distTensorType, distTensorType1, target), + _ => throw new NotSupportedException(string.Empty), }; } @@ -117,6 +123,23 @@ public Expr Visit(IShapeEvaluateContext context, Unary target) return context.GetArgumentShape(target, Unary.Input); } + private IRType Visit(DistributedType inType, UnaryOp unaryOp) + { + var invalid = new InvalidType(inType.ToString()); + var ndsbp = new SBP[inType.Placement.Rank]; + for (int i = 0; i < inType.Placement.Rank; i++) + { + if (inType.NdSBP[i] is SBPPartialSum && unaryOp != UnaryOp.Neg) + { + return invalid; + } + + ndsbp[i] = inType.NdSBP[i]; + } + + return new DistributedType(inType.TensorType, ndsbp, inType.Placement); + } + private int Compute_int(int input, UnaryOp op) => op switch { UnaryOp.Ceil => input, @@ -156,4 +179,26 @@ private IRType Visit(TensorType input) { return input; } + + private Cost Visit(DistributedType inType, DistributedType outType, Unary target) + { + var inPartType = Utilities.DistributedUtility.GetDividedTensorType(inType); + var outPartType = Utilities.DistributedUtility.GetDividedTensorType(outType); + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inPartType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outPartType, CostUtility.GetCPUCyclesOfUnary(target.UnaryOp)), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outPartType), + }; + } + + private Cost Visit(TensorType inputType, TensorType outputType, Unary target) + { + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(outputType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(outputType, CostUtility.GetCPUCyclesOfUnary(target.UnaryOp)), + }; + } } diff --git a/src/Nncase.Evaluator/NN/Activations.cs b/src/Nncase.Evaluator/NN/Activations.cs index 002314c9fa..aef4860dbd 100644 --- a/src/Nncase.Evaluator/NN/Activations.cs +++ b/src/Nncase.Evaluator/NN/Activations.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System.Linq; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.NN; @@ -526,20 +527,21 @@ public class SwishEvaluator : IEvaluator, ITypeInferencer, ICostEv public IValue Visit(IEvaluateContext context, Swish swish) { var input = context.GetOrtArgumentValue(swish, Swish.Input); - return OrtKI.Mul(OrtKI.Sigmoid(input), input).ToValue(); + var beta = context.GetOrtArgumentValue(swish, Swish.Beta); + return OrtKI.Mul(OrtKI.Sigmoid(input * beta), input).ToValue(); } /// public IRType Visit(ITypeInferenceContext context, Swish target) { - var input = context.CheckArgumentType(target, Swish.Input); + var input = context.CheckArgumentType(target, Swish.Input); return Visit(input); } /// public Cost Visit(ICostEvaluateContext context, Swish target) { - var outputType = context.GetReturnType(); + var outputType = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(outputType), @@ -558,8 +560,13 @@ public Metric Visit(IMetricEvaluateContext context, Swish target) }; } - private IRType Visit(TensorType input) + private IRType Visit(IRType input) { + if (input is DistributedType d && d.NdSBP.Any(s => s is SBPPartialSum)) + { + return new InvalidType("swish with partial sum is not supported"); + } + return input; } } @@ -582,14 +589,14 @@ public IValue Visit(IEvaluateContext context, Gelu gelu) /// public IRType Visit(ITypeInferenceContext context, Gelu target) { - var input = context.CheckArgumentType(target, Gelu.Input); + var input = context.CheckArgumentType(target, Gelu.Input); return Visit(input); } /// public Cost Visit(ICostEvaluateContext context, Gelu target) { - var outputType = context.GetReturnType(); + var outputType = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(outputType), @@ -610,8 +617,13 @@ public Metric Visit(IMetricEvaluateContext context, Gelu target) public Expr Visit(IShapeEvaluateContext context, Gelu target) => context.GetArgumentShape(target, Gelu.Input); - private IRType Visit(TensorType input) + private IRType Visit(IRType input) { + if (input is DistributedType d && d.NdSBP.Any(s => s is SBPPartialSum)) + { + return new InvalidType("gelu with partial sum is not supported"); + } + return input; } } diff --git a/src/Nncase.Evaluator/NN/Conv2D.cs b/src/Nncase.Evaluator/NN/Conv2D.cs index c26e219821..9ff14e5256 100644 --- a/src/Nncase.Evaluator/NN/Conv2D.cs +++ b/src/Nncase.Evaluator/NN/Conv2D.cs @@ -48,20 +48,28 @@ public IValue Visit(IEvaluateContext context, Conv2D conv) /// public IRType Visit(ITypeInferenceContext context, Conv2D target) { - var input = context.CheckArgumentType(target, Conv2D.Input); - var weights = context.CheckArgumentType(target, Conv2D.Weights); - return Visit(context, target, input, weights); + var input = context.GetArgumentType(target, Conv2D.Input); + var weights = context.GetArgumentType(target, Conv2D.Weights); + var bias = context.GetArgumentType(target, Conv2D.Bias); + return (input, weights) switch + { + (DistributedType a, DistributedType b) => Visit(context, target, a, b, (DistributedType)bias), + (TensorType a, TensorType b) => Visit(context, target, a, b), + (AnyType, _) => AnyType.Default, + (_, AnyType) => AnyType.Default, + _ => new InvalidType(string.Empty), + }; } /// public Cost Visit(ICostEvaluateContext context, Conv2D target) { - var inputType = context.GetArgumentType(target, Conv2D.Input); - var weightsType = context.GetArgumentType(target, Conv2D.Weights); - var biasType = context.GetArgumentType(target, Conv2D.Bias); - var outputType = context.GetReturnType(); + var inputType = context.GetArgumentType(target, Conv2D.Input); + var weightsType = context.GetArgumentType(target, Conv2D.Weights); + var biasType = context.GetArgumentType(target, Conv2D.Bias); + var outputType = context.GetReturnType(); - var weightsShape = weightsType.Shape; + var weightsShape = weightsType is TensorType ? ((TensorType)weightsType).Shape : ((DistributedType)weightsType).TensorType.Shape; var macPerElement = (2 * weightsShape[1] * weightsShape[2] * weightsShape[3]) - 1; return new() { @@ -104,4 +112,90 @@ private IRType Visit(ITypeInferenceContext context, Conv2D target, TensorType in var args = context.GetArguments(target, Conv2D.Stride, Conv2D.Padding, Conv2D.Dilation, Conv2D.Groups); return TypeInference.Conv2DType(input, weights, args[0], args[1], args[2], args[3]); } + + private IRType Visit(ITypeInferenceContext context, Conv2D target, DistributedType input, DistributedType weights, DistributedType bias) + { + if (Visit(context, target, input.TensorType, weights.TensorType) is not TensorType outType) + { + return new InvalidType(string.Empty); + } + + var args = context.GetArguments(target, Conv2D.Stride, Conv2D.Padding, Conv2D.Dilation, Conv2D.Groups); + + // Not support split on h/w/r/s + if (input.NdSBP.Any(sbp => sbp is SBPSplit s && s.Axis >= 2) || weights.NdSBP.Any(sbp => sbp is SBPSplit s && s.Axis >= 2)) + { + return new InvalidType(string.Empty); + } + + if (input.Placement != weights.Placement) + { + return new InvalidType("placement not equal"); + } + + var ndsbp = new SBP[input.Placement.Rank]; + for (int i = 0; i < input.Placement.Rank; i++) + { + var invalid = new InvalidType($"({input.NdSBP[i]}, {weights.NdSBP[i]}) not support"); + switch (input.NdSBP[i], weights.NdSBP[i]) + { + case (SBPSplit { Axis: int ax }, SBPSplit { Axis: int bx }): + // split on ic + if (ax == 1 && bx == 1) + { + if (bias.NdSBP[i] is SBPBroadCast) + { + ndsbp[i] = SBP.P; + } + else + { + return invalid; + } + } + else + { + return invalid; + } + + break; + case (SBPSplit { Axis: int ax }, SBPBroadCast): + if (ax == 0 && bias.NdSBP[i] is SBPBroadCast) + { + ndsbp[i] = SBP.S(ax); + } + else + { + return invalid; + } + + break; + case (SBPBroadCast, SBPSplit { Axis: int bx }): + if (bx == 0 && bias.NdSBP[i] is SBPSplit s && s.Axis == bx) + { + ndsbp[i] = SBP.S(bx + 1); + } + else + { + return invalid; + } + + break; + case (SBPBroadCast, SBPBroadCast): + if (bias.NdSBP[i] is SBPBroadCast) + { + ndsbp[i] = SBP.B; + } + else + { + return invalid; + } + + break; + default: + return invalid; + } + } + + return new DistributedType(outType, ndsbp, input.Placement); + } } diff --git a/src/Nncase.Evaluator/NN/LayerNorm.cs b/src/Nncase.Evaluator/NN/LayerNorm.cs index 4088d8f9e4..b76300efbf 100644 --- a/src/Nncase.Evaluator/NN/LayerNorm.cs +++ b/src/Nncase.Evaluator/NN/LayerNorm.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using System.Linq; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.NN; @@ -23,26 +24,52 @@ public IValue Visit(IEvaluateContext context, LayerNorm layerNorm) var bias = context.GetOrtArgumentValue(layerNorm, LayerNorm.Bias); // return Value.FromTensor(OrtKI.LayerNormalization(input, scale, bias, layerNorm.Axis, layerNorm.Epsilon, 1)); - return Value.FromTensor(LayerNormImpl(input.ToTensor(), scale.ToTensor(), bias.ToTensor(), layerNorm.Axis, layerNorm.Epsilon)); + return Value.FromTensor(LayerNormImpl(input.ToTensor(), scale.ToTensor(), bias.ToTensor(), layerNorm.Axis, layerNorm.Epsilon, layerNorm.UseMean)); } /// public IRType Visit(ITypeInferenceContext context, LayerNorm target) { - var input = context.CheckArgumentType(target, LayerNorm.Input); - return Visit(input); + var input = context.CheckArgumentType(target, LayerNorm.Input); + var scale = context.CheckArgumentType(target, LayerNorm.Scale); + var bias = context.CheckArgumentType(target, LayerNorm.Bias); + + return (input, scale, bias) switch + { + (DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target.Axis), + (TensorType a, TensorType, TensorType) => Visit(a), + _ => new InvalidType(input.GetType().ToString()), + }; } /// public Cost Visit(ICostEvaluateContext context, LayerNorm target) { - var inputType = context.GetArgumentType(target, LayerNorm.Input); - var returnType = context.GetReturnType(); - return new() + var inputType = context.GetArgumentType(target, LayerNorm.Input); + var returnType = context.GetReturnType(); + switch (inputType, returnType) { - [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), - [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType), - }; + case (TensorType, TensorType): + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType), + }; + + case (DistributedType inputDistributedType, DistributedType): + var scaleType = context.GetArgumentType(target, LayerNorm.Scale); + var biasType = context.GetArgumentType(target, LayerNorm.Bias); + var ring = GetRingReduceCommunicate(scaleType, new[] { 0, 1 }) + GetRingReduceCommunicate(biasType, new[] { 0, 1 }); + var reCompute = inputDistributedType.NdSBP.Select((sbp, i) => sbp is SBPSplit ? 1 : inputDistributedType.Placement.Hierarchy[i]).ToArray().Aggregate(1, (acc, rep) => acc * rep); + return new() + { + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType) + ring, + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(inputType, 1) * (UInt128)reCompute, + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType) + ring, + }; + default: + throw new NotSupportedException(); + } } public Metric Visit(IMetricEvaluateContext context, LayerNorm target) @@ -70,8 +97,53 @@ private IRType Visit(TensorType input) return input; } + private IRType Visit(DistributedType input, DistributedType scale, DistributedType bias, int raxis) + { + var invalid = new InvalidType($"{input}, {scale}, {bias} not support"); + if (input.Placement != scale.Placement || scale.Placement != bias.Placement) + { + return invalid; + } + + var ndsbp = new SBP[input.Placement.Rank]; + + for (int i = 0; i < input.Placement.Rank; i++) + { + switch (input.NdSBP[i], scale.NdSBP[i], bias.NdSBP[i]) + { + case (SBPSplit { Axis: int ix }, SBPSplit { Axis: int sx }, SBPSplit { Axis: int bx }) when ix >= raxis && sx == (ix - raxis) && bx == sx: + ndsbp[i] = SBP.S(ix); + break; + case (SBPSplit { Axis: int ix }, SBPBroadCast, SBPBroadCast) when ix < raxis: + ndsbp[i] = SBP.S(ix); + break; + case (SBPBroadCast, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(input.TensorType, ndsbp, input.Placement); + } + + private UInt128 GetRingReduceCommunicate(DistributedType distributedType, int[] axes) + { + var ttype = Utilities.DistributedUtility.GetDividedTensorType(distributedType); + var splits = axes.Where(i => distributedType.NdSBP[i] is SBPSplit); + if (!splits.Any()) + { + return 0; + } + + var p = (UInt128)splits.Select(i => distributedType.Placement.Hierarchy[i]).Aggregate(1, (acc, i) => acc * i); + var v = CostUtility.GetMemoryAccess(distributedType.TensorType); + return (p - 1) * (v / p); + } + #if true - private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, float epsilon) + private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, float epsilon, bool useMean = true) { int outputSize = 1; int innerSize = 1; @@ -96,9 +168,12 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, for (int batch = 0; batch < outputSize; batch++) { float mean1 = 0f; - for (int i = 0; i < innerSize; i++) + if (useMean) { - mean1 += inputArray[(i + (batch * innerSize)) % inputArray.Length] / innerSize; + for (int i = 0; i < innerSize; i++) + { + mean1 += inputArray[(i + (batch * innerSize)) % inputArray.Length] / innerSize; + } } float[] sub = new float[innerSize]; diff --git a/src/Nncase.Evaluator/NN/Normalization.cs b/src/Nncase.Evaluator/NN/Normalization.cs index 1bc2c75bd2..8ad56df403 100644 --- a/src/Nncase.Evaluator/NN/Normalization.cs +++ b/src/Nncase.Evaluator/NN/Normalization.cs @@ -153,15 +153,22 @@ public IValue Visit(IEvaluateContext context, InstanceNormalization i) /// public IRType Visit(ITypeInferenceContext context, InstanceNormalization target) { - var input = context.CheckArgumentType(target, InstanceNormalization.Input); - return Visit(input); + var input = context.CheckArgumentType(target, InstanceNormalization.Input); + var scale = context.CheckArgumentType(target, InstanceNormalization.Scale); + var bias = context.CheckArgumentType(target, InstanceNormalization.Bias); + return (input, scale, bias) switch + { + (DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c), + (TensorType a, TensorType, TensorType) => Visit(a), + _ => new InvalidType(input.GetType().ToString()), + }; } /// public Cost Visit(ICostEvaluateContext context, InstanceNormalization target) { - var inputType = context.GetArgumentType(target, InstanceNormalization.Input); - var returnType = context.GetReturnType(); + var inputType = context.GetArgumentType(target, InstanceNormalization.Input); + var returnType = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), @@ -183,6 +190,40 @@ private IRType Visit(TensorType input) { return input; } + + private IRType Visit(DistributedType input, DistributedType scale, DistributedType bias) + { + var invalid = new InvalidType($"{input}, {scale}, {bias} not support"); + if (input.Placement != scale.Placement || scale.Placement != bias.Placement) + { + return invalid; + } + + var ndsbp = new SBP[input.Placement.Rank]; + + // scale & bias always on Channel + const int rAxis = 1; + + for (int i = 0; i < input.Placement.Rank; i++) + { + switch (input.NdSBP[i], scale.NdSBP[i], bias.NdSBP[i]) + { + case (SBPSplit { Axis: int ix }, SBPSplit { Axis: int sx }, SBPSplit { Axis: int bx }) when ix == rAxis && sx == (ix - rAxis) && bx == sx: + ndsbp[i] = SBP.S(ix); + break; + case (SBPSplit { Axis: int ix }, SBPBroadCast, SBPBroadCast) when ix != rAxis: + ndsbp[i] = SBP.S(ix); + break; + case (SBPBroadCast, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(input.TensorType, ndsbp, input.Placement); + } } /// diff --git a/src/Nncase.Evaluator/NN/Softmax.cs b/src/Nncase.Evaluator/NN/Softmax.cs index ef5afafbe9..c7640de82f 100644 --- a/src/Nncase.Evaluator/NN/Softmax.cs +++ b/src/Nncase.Evaluator/NN/Softmax.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System.Linq; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.NN; @@ -81,14 +82,20 @@ public IValue Visit(IEvaluateContext context, Softmax softMax) /// public IRType Visit(ITypeInferenceContext context, Softmax target) { - var input = context.CheckArgumentType(target, Softmax.Input); - return Visit(input); + var input = context.CheckArgumentType(target, Softmax.Input); + var axis = context.GetArgument(target, Softmax.Axis); + return input switch + { + TensorType t => Visit(t), + DistributedType d => Visit(d, axis), + _ => new InvalidType(input.GetType().Name), + }; } /// public Cost Visit(ICostEvaluateContext context, Softmax target) { - var ret = context.GetReturnType(); + var ret = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(ret), @@ -118,6 +125,17 @@ private IRType Visit(TensorType input) { return input; } + + private IRType Visit(DistributedType input, Expr axisExpr) + { + var axis = ((TensorConst)axisExpr).Value.ToScalar(); + if (input.NdSBP.Any(sbp => sbp is SBPSplit s && s.Axis == axis)) + { + return new InvalidType("Not support split on Axis for Softmax now."); + } + + return input; + } } /// diff --git a/src/Nncase.Evaluator/RNN/LSTM.cs b/src/Nncase.Evaluator/RNN/LSTM.cs index 9abd5964e4..fc1d1f3daa 100644 --- a/src/Nncase.Evaluator/RNN/LSTM.cs +++ b/src/Nncase.Evaluator/RNN/LSTM.cs @@ -6,12 +6,11 @@ using Nncase.CostModel; using Nncase.IR; -// using Nncase.IR.NN; -using Nncase.IR.Tensors; +using Nncase.IR.RNN; using OrtKISharp; using static Nncase.LSTMHelper; -namespace Nncase.Evaluator.NN; +namespace Nncase.Evaluator.RNN; /// /// Evaluator for . diff --git a/src/Nncase.Evaluator/TIR/Load.cs b/src/Nncase.Evaluator/TIR/Load.cs index 6ea6faddff..5898e3d353 100644 --- a/src/Nncase.Evaluator/TIR/Load.cs +++ b/src/Nncase.Evaluator/TIR/Load.cs @@ -30,12 +30,11 @@ public string Visit(IIRPrinterContext context, Load target, bool iLmode) private IRType Visit(Load target, TensorType handle, TensorType index) { - if (!handle.IsScalar && handle.DType is not PointerType) + if (handle is not TensorType { DType: PointerType { } p }) { - throw new NotSupportedException(handle.DType.ToString()); + return new InvalidType("handle must be pointer type!"); } - _ = index.IsScalar ? 1 : index.Shape[0].FixedValue; - return TensorType.Scalar(((PointerType)handle.DType).ElemType); + return TensorType.Scalar(p.ElemType); } } diff --git a/src/Nncase.Evaluator/TIR/Store.cs b/src/Nncase.Evaluator/TIR/Store.cs index b29459bfe2..b46bf57f52 100644 --- a/src/Nncase.Evaluator/TIR/Store.cs +++ b/src/Nncase.Evaluator/TIR/Store.cs @@ -24,21 +24,21 @@ public IRType Visit(ITypeInferenceContext context, Store target) public string Visit(IIRPrinterContext context, Store target, bool iLmode) { var handle = context.GetArgument(target, Store.Handle); - _ = context.GetArgument(target, Store.Value); + var value = context.GetArgument(target, Store.Value); var index = context.GetArgument(target, Store.Index); - return $"{handle}[{index}] = {index}"; - - throw new System.NotImplementedException(); + return $"{handle}[{index}] = {value}"; } private IRType Visit(Store target, TensorType handle, TensorType index, TensorType value) { - _ = index.IsScalar ? 1 : index.Shape[0].FixedValue; + if (handle.DType is not PointerType { ElemType: DataType elemType } || elemType != value.DType) + { + return new InvalidType($"You Can't Load The {value.DType} To {handle.DType}"); + } - var elemType = ((PointerType)handle.DType).ElemType; - if (elemType != value.DType) + if (index.DType != DataTypes.Int32) { - return new InvalidType($"You Can't Load The {value.DType} To {elemType}"); + return new InvalidType($"store value type {index.DType} not supported"); } return TupleType.Void; diff --git a/src/Nncase.Evaluator/Tensors/Cast.cs b/src/Nncase.Evaluator/Tensors/Cast.cs index c7e285d060..6f32939c3c 100644 --- a/src/Nncase.Evaluator/Tensors/Cast.cs +++ b/src/Nncase.Evaluator/Tensors/Cast.cs @@ -24,8 +24,13 @@ public IValue Visit(IEvaluateContext context, Cast cast) /// public IRType Visit(ITypeInferenceContext context, Cast target) { - var input = context.CheckArgumentType(target, Cast.Input); - return Visit(target, input); + var input = context.CheckArgumentType(target, Cast.Input); + return input switch + { + TensorType t => Visit(target, t), + DistributedType d => Visit(target, d), + _ => new InvalidType(input.GetType().ToString()), + }; } /// @@ -37,10 +42,10 @@ public string Visit(IIRPrinterContext context, Cast target, bool iLmode) /// public Cost Visit(ICostEvaluateContext context, Cast target) { - var input = context.GetArgumentType(target, Cast.Input); + var input = context.GetArgumentType(target, Cast.Input); return new() { - [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(input.DType), + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(input), [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(target.NewType), [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(target.NewType, 1), }; @@ -61,4 +66,21 @@ private IRType Visit(Cast target, TensorType input) { return new TensorType(target.NewType, input.Shape); } + + private IRType Visit(Cast target, DistributedType inType) + { + var invalid = new InvalidType(inType.ToString()); + var ndsbp = new SBP[inType.Placement.Rank]; + for (int i = 0; i < inType.Placement.Rank; i++) + { + if (inType.NdSBP[i] is SBPPartialSum) + { + return invalid; + } + + ndsbp[i] = inType.NdSBP[i]; + } + + return new DistributedType(new TensorType(target.NewType, inType.TensorType.Shape), ndsbp, inType.Placement); + } } diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs index d1098ccf7b..6cce7d2b66 100644 --- a/src/Nncase.Evaluator/Tensors/Concat.cs +++ b/src/Nncase.Evaluator/Tensors/Concat.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using System.Collections.Generic; using System.Linq; using NetFabric.Hyperlinq; using Nncase.CostModel; @@ -25,7 +26,7 @@ public class ConcatEvaluator : IEvaluator, ITypeInferencer, ICos public IValue Visit(IEvaluateContext context, Concat cat) { var inputs = context.GetArgumentValueAsTensors(cat, Concat.Input); - var axis = context.GetArgumentValueAsScalar(cat, Concat.Axis); + var axis = cat.Axis; return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue(); } @@ -33,14 +34,13 @@ public IValue Visit(IEvaluateContext context, Concat cat) public IRType Visit(ITypeInferenceContext context, Concat target) { var inputs = context.CheckArgumentType(target, Concat.Input); - var axis = context.CheckArgumentType(target, Concat.Axis); - return Visit(context, target, inputs, axis); + return Visit(inputs, target.Axis); } /// public Cost Visit(ICostEvaluateContext context, Concat target) { - var ret = context.GetReturnType(); + var ret = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(ret), @@ -52,8 +52,7 @@ public Cost Visit(ICostEvaluateContext context, Concat target) public Expr Visit(IShapeEvaluateContext context, Concat target) { var inShape = context.GetArgumentShape(target, Concat.Input); - var axis = context.GetArgument(target, Concat.Axis); - var axisV = ShapeExprUtility.Positive(axis, inShape[0]); + var axisV = ShapeExprUtility.Positive(target.Axis, inShape[0]); var inShapes = ((IR.Tuple)inShape).Fields.ToArray().Select(x => Cast(x, DataTypes.Int64)).ToArray(); var dim = inShapes.ToArray().Aggregate((Expr)0L, (sum, shape) => sum + shape[axisV]); var outShape = ShapeExprUtility.Replace(inShapes[0], axisV, dim); @@ -68,17 +67,18 @@ public Expr Visit(IShapeEvaluateContext context, Concat target) DataType? allDType = null; foreach (var (i, input) in Enumerable.Range(0, inputs.Count).Select(i => (i, inputs[i]))) { - var type = input as TensorType; - if (type is null) + TensorType type; + if (input is TensorType a) { - if (input is InvalidType) - { - return input; - } - else - { - return new InvalidType($"The ConCat Item[{i}] Must Be TensorType But Get {input.GetType().Name}"); - } + type = a; + } + else if (input is DistributedType { TensorType: TensorType b }) + { + type = b; + } + else + { + return new InvalidType($"The ConCat Item[{i}] Must Have TensorType But Get {input}"); } if (type.Shape.IsUnranked) @@ -103,7 +103,14 @@ public Expr Visit(IShapeEvaluateContext context, Concat target) return null; } - private IRType Visit(ITypeInferenceContext context, Concat target, TupleType inputs, TensorType axis) + private TensorType GetTensorType(IRType input) => input switch + { + TensorType t => t, + DistributedType d => d.TensorType, + _ => throw new InvalidCastException(), + }; + + private IRType Visit(TupleType inputs, int axis) { var result = CheckType(inputs); if (result != null) @@ -111,15 +118,15 @@ private IRType Visit(ITypeInferenceContext context, Concat target, TupleType inp return result; } - var sameRank = inputs.All(input => ((TensorType)input).Shape.Rank == ((TensorType)inputs[0]).Shape.Rank); + var sameRank = inputs.All(input => GetTensorType(input).Shape.Rank == GetTensorType(inputs[0]).Shape.Rank); if (!sameRank) { return new InvalidType("Inputs of concat should be same rank"); } - var input0 = (TensorType)inputs[0]; + var input0 = GetTensorType(inputs[0]); InvalidType? invalidType = null; - var axisV = ((TensorConst)context.GetArgument(target, Concat.Axis)).Value.ToScalar(); + var axisV = axis; var axisValue = Util.PositiveIndex(axisV, input0.Shape.Rank); var shapeValue = Enumerable.Range(0, input0.Shape.Rank).Select(i => { @@ -134,18 +141,18 @@ private IRType Visit(ITypeInferenceContext context, Concat target, TupleType inp var allAxisDimIsSame = true; foreach (var inType in inputs.Fields) { - if (((TensorType)inType).Shape.IsUnranked) + if (GetTensorType(inType).Shape.IsUnranked) { continue; } - var d = ((TensorType)inType).Shape[i]; + var d = GetTensorType(inType).Shape[i]; if (d.IsUnknown) { return Dimension.Unknown; } - if (d.FixedValue != ((TensorType)inputs[0]).Shape[i]) + if (d.FixedValue != GetTensorType(inputs[0]).Shape[i]) { allAxisDimIsSame = false; } @@ -153,7 +160,7 @@ private IRType Visit(ITypeInferenceContext context, Concat target, TupleType inp if (allAxisDimIsSame) { - return ((TensorType)inputs[0]).Shape[i]; + return GetTensorType(inputs[0]).Shape[i]; } else { @@ -163,7 +170,56 @@ private IRType Visit(ITypeInferenceContext context, Concat target, TupleType inp } }); var shape = new Shape(shapeValue); - return (invalidType as IRType) ?? new TensorType(input0.DType, shape); + if (invalidType is InvalidType invalid) + { + return invalid; + } + + var tensorType = new TensorType(input0.DType, shape); + + if (inputs[0] is not DistributedType distributedType) + { + return tensorType; + } + + if (inputs.OfType().Select(d => d.Placement).ToHashSet().Count != 1) + { + return new InvalidType("the inputs have different placement"); + } + + var ndsbp = new SBP[distributedType.Placement.Rank]; + + for (int i = 0; i < distributedType.Placement.Rank; i++) + { + var sbps = inputs.OfType().Select(d => d.NdSBP[i]).ToArray(); + if (sbps.Any(sbp => sbp is SBPSplit { Axis: int x } && x == axis)) + { + return new InvalidType("not support distribute on concat axis"); + } + + if (sbps.Any(sbp => sbp is SBPPartialSum)) + { + return new InvalidType("not support distribute with partialsum"); + } + + if (sbps.OfType().ToHashSet() is HashSet setSplit && + sbps.OfType().ToHashSet() is HashSet setBroadcast) + { + switch (setSplit.Count) + { + case 0: + ndsbp[i] = SBP.B; + break; + case 1 when setBroadcast.Count == 0: + ndsbp[i] = setSplit.First(); + break; + default: + return new InvalidType("not support distribute with different axis"); + } + } + } + + return new DistributedType(tensorType, ndsbp, distributedType.Placement); } // axis: if one of inputs shape[axis] is unknown @@ -173,12 +229,12 @@ private Dimension AxisDim(TupleType inputs, int axisValue) { var allAxisDimIsFixed = inputs.Fields.Aggregate( true, - (prod, next) => prod && ((TensorType)next).Shape[axisValue].IsFixed); + (prod, next) => prod && (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue].IsFixed); if (allAxisDimIsFixed) { return inputs.Fields.Aggregate( 0, - (prod, next) => prod + ((TensorType)next).Shape[axisValue].FixedValue); + (prod, next) => prod + (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue].FixedValue); } else { diff --git a/src/Nncase.Evaluator/Tensors/Expand.cs b/src/Nncase.Evaluator/Tensors/Expand.cs index cea0603bfd..e06241f90c 100644 --- a/src/Nncase.Evaluator/Tensors/Expand.cs +++ b/src/Nncase.Evaluator/Tensors/Expand.cs @@ -30,8 +30,8 @@ public IValue Visit(IEvaluateContext context, Expand expand) public Cost Visit(ICostEvaluateContext context, Expand target) { - var input = context.GetArgumentType(target, Expand.Input); - var ret = context.GetReturnType(); + var input = context.GetArgumentType(target, Expand.Input); + var ret = context.GetReturnType(); return CostUtility.GetBroadcastCost(input, ret); } @@ -53,6 +53,18 @@ public Metric Visit(IMetricEvaluateContext context, Expand target) }; } + public IRType Visit(ITypeInferenceContext context, Expand target) + { + var input = context.CheckArgumentType(target, Expand.Input); + var shape = context.CheckArgumentType(target, Expand.Shape); + return input switch + { + TensorType t => Visit(context, target, t, shape), + DistributedType d => Visit(context, target, d, shape), + _ => new InvalidType(input.GetType().ToString()), + }; + } + private IRType Visit(ITypeInferenceContext context, Expand target, TensorType input, TensorType shape) { var shape_expr = context.GetArgument(target, Expand.Shape); @@ -65,4 +77,28 @@ private IRType Visit(ITypeInferenceContext context, Expand target, TensorType in return input with { Shape = TypeInference.ReshapeTo(shape) }; } } + + private IRType Visit(ITypeInferenceContext context, Expand target, DistributedType input, TensorType shape) + { + var invalid = new InvalidType(input.ToString()); + var shape_expr = context.GetArgument(target, Expand.Shape); + if (shape_expr is TensorConst constShape) + { + var newShape = constShape.Value.ToArray(); + var ndsbp = new SBP[input.Placement.Rank]; + for (int i = 0; i < input.Placement.Rank; i++) + { + if (input.NdSBP[i] is SBPSplit sbp && newShape[sbp.Axis] != input.TensorType.Shape[sbp.Axis]) + { + return invalid; + } + + ndsbp[i] = input.NdSBP[i]; + } + + return new DistributedType(new TensorType(input.TensorType.DType, new Shape(newShape)), ndsbp, input.Placement); + } + + return invalid; + } } diff --git a/src/Nncase.Evaluator/Tensors/Gather.cs b/src/Nncase.Evaluator/Tensors/Gather.cs index acfd7ec9b5..2f4bd25d46 100644 --- a/src/Nncase.Evaluator/Tensors/Gather.cs +++ b/src/Nncase.Evaluator/Tensors/Gather.cs @@ -23,7 +23,7 @@ public class GatherEvaluator : IEvaluator, ITypeInferencer, ICos public IValue Visit(IEvaluateContext context, Gather gather) { var input = context.GetOrtArgumentValue(gather, Gather.Input); - var axis = context.GetArgumentValueAsScalar(gather, Gather.Axis); + var axis = gather.Axis; var index = context.GetOrtArgumentValue(gather, Gather.Index); return OrtKI.Gather(input, index, axis).ToValue(); } @@ -31,29 +31,35 @@ public IValue Visit(IEvaluateContext context, Gather gather) /// public IRType Visit(ITypeInferenceContext context, Gather target) { - var input = context.CheckArgumentType(target, Gather.Input); - var axis = context.CheckArgumentType(target, Gather.Axis); - var index = context.CheckArgumentType(target, Gather.Index); - return Visit(context, target, input, axis, index); + var input = context.CheckArgumentType(target, Gather.Input); + var index = context.CheckArgumentType(target, Gather.Index); + + return (input, index) switch + { + (TensorType a, TensorType b) => Visit(a, target.Axis, b), + (DistributedType a, DistributedType b) => Visit(a, target.Axis, b), + _ => new InvalidType($"{input}, {index}"), + }; } /// public Cost Visit(ICostEvaluateContext context, Gather target) { - var ret_type = context.GetReturnType(); + var inputType = context.GetArgumentType(target, Gather.Input); + var indexType = context.GetArgumentType(target, Gather.Index); + var retType = context.GetReturnType(); return new() { - [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(ret_type.DType), - [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(ret_type.DType), - [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(ret_type.DType, 1), + [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType) + CostUtility.GetMemoryAccess(indexType), + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(retType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(retType), }; } public Expr Visit(IShapeEvaluateContext context, Gather target) { - var axis = context.GetArgument(target, Gather.Axis); var inShape = context.GetArgumentShape(target, Gather.Input); - axis = ShapeExprUtility.Positive(Cast(axis, DataTypes.Int32), inShape); + var axis = ShapeExprUtility.Positive(target.Axis, inShape); var indexShape = context.GetArgumentShape(target, Gather.Index); var outShape = ShapeExprUtility.ReplaceList(inShape, axis, indexShape); return outShape; @@ -68,26 +74,56 @@ public Metric Visit(IMetricEvaluateContext context, Gather target) }; } - private IRType Visit(ITypeInferenceContext context, Gather target, TensorType input, TensorType axis, TensorType index) + private IRType Visit(TensorType input, int axis, TensorType index) { if (input.Shape.IsUnranked) { return input; } - if (context.GetArgument(target, Gather.Axis) is TensorConst axisValue) + axis = axis < 0 ? axis + input.Shape.Rank : axis; + + // input_shape[:axis] + index_shape + input_shape[axis + 1:] + var inShape = input.Shape.ToArray(); + var newShape = inShape[..axis].Concat(index.Shape).Concat(inShape[(axis + 1)..]).ToArray(); + return new TensorType(input.DType, newShape); + } + + private IRType Visit(DistributedType input, int axis, DistributedType index) + { + var invalid = new InvalidType(input.ToString() + " " + index.ToString()); + if (Visit(input.TensorType, axis, index.TensorType) is not TensorType tensorType) { - var axisV = axisValue.Value.ToScalar(); - axisV = axisV < 0 ? axisV + input.Shape.Rank : axisV; + return invalid; + } - // input_shape[:axis] + index_shape + input_shape[axis + 1:] - var inShape = input.Shape.ToArray(); - var newShape = inShape[..axisV].Concat(index.Shape).Concat(inShape[(axisV + 1)..]).ToArray(); - return new TensorType(input.DType, newShape); + if (input.Placement != index.Placement) + { + return invalid; } - else + + var ndsbp = new SBP[input.Placement.Rank]; + + for (int i = 0; i < input.Placement.Rank; i++) { - return new InvalidType("Gather axis must be constant"); + switch (input.NdSBP[i], index.NdSBP[i]) + { + case (SBPSplit { Axis: int ix }, _) when ix == axis: + return new InvalidType($"the input can't split on {axis}"); + case (SBPBroadCast, SBPSplit { Axis: int ix }): + ndsbp[i] = SBP.S(ix); + break; + case (SBPSplit { Axis: int ix }, SBPBroadCast): + ndsbp[i] = SBP.S(ix - axis + index.TensorType.Shape.Rank - 1); + break; + case (SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } } + + return new DistributedType(tensorType, ndsbp, input.Placement); } } diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index 38c4d150ee..7488739f1e 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -2,8 +2,8 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using System.Collections.Generic; using System.Linq; -using DryIoc.ImTools; using NetFabric.Hyperlinq; using Nncase.CostModel; using Nncase.IR; @@ -34,8 +34,14 @@ public IValue Visit(IEvaluateContext context, Reshape reshape) /// public IRType Visit(ITypeInferenceContext context, Reshape target) { - var input = context.CheckArgumentType(target, Reshape.Input); - return Visit(context, target, input); + var input = context.CheckArgumentType(target, Reshape.Input); + return input switch + { + TensorType tensorType => Visit(context, target, tensorType), + DistributedType distributedType => Visit(context, target, distributedType), + AnyType => AnyType.Default, + _ => throw new NotImplementedException(), + }; } public Cost Visit(ICostEvaluateContext context, Reshape target) @@ -121,4 +127,85 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i var outShape = ReshapeTo(targetType); return input with { Shape = outShape }; } + + private IRType Visit(ITypeInferenceContext context, Reshape target, DistributedType inputType) + { + var outType = Visit(context, target, inputType.TensorType); + if (outType is not TensorType outTensorType) + { + return outType; + } + + var invalid = new InvalidType(inputType.ToString()); + if (outTensorType.Shape.IsUnranked) + { + return invalid; + } + + var newShape = outTensorType.Shape.ToValueArray(); + var oldShape = inputType.TensorType.Shape.ToValueArray(); + + // check is unsequeeze/sequeeze + if (Enumerable.SequenceEqual(oldShape.Where(i => i != 1).ToArray(), newShape.Where(i => i != 1).ToArray())) + { + if (oldShape.Length < newShape.Length) + { + var axis = 0; + var axisMap = new Dictionary(); + for (var n = 0; n < newShape.Length; n++) + { + if (newShape[n] == oldShape[axis]) + { + axisMap.Add(axis++, n); + if (axis >= oldShape.Length) + { + break; + } + } + } + + var ndsbp = new SBP[inputType.Placement.Rank]; + for (int i = 0; i < inputType.Placement.Rank; i++) + { + ndsbp[i] = inputType.NdSBP[i] switch + { + SBPSplit { Axis: int sx } => SBPSplit.S(axisMap[sx]), + SBP sbp => sbp, + }; + } + + return inputType with { TensorType = outTensorType, NdSBP = new(ndsbp) }; + } + else if (oldShape.Length > newShape.Length) + { + var axis = 0; + var axisMap = new Dictionary(); + for (var o = 0; o < oldShape.Length; o++) + { + if (oldShape[o] == newShape[axis]) + { + axisMap.Add(o, axis++); + if (axis >= newShape.Length) + { + break; + } + } + } + + var ndsbp = new SBP[inputType.Placement.Rank]; + for (int i = 0; i < inputType.Placement.Rank; i++) + { + ndsbp[i] = inputType.NdSBP[i] switch + { + SBPSplit { Axis: int sx } => SBPSplit.S(axisMap[sx]), + SBP sbp => sbp, + }; + } + + return inputType with { TensorType = outTensorType, NdSBP = new(ndsbp) }; + } + } + + return invalid; + } } diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs index eada657d01..08de735500 100644 --- a/src/Nncase.Evaluator/Tensors/Slice.cs +++ b/src/Nncase.Evaluator/Tensors/Slice.cs @@ -41,18 +41,24 @@ public IValue Visit(IEvaluateContext context, Slice sl) /// public IRType Visit(ITypeInferenceContext context, Slice target) { - var input = context.CheckArgumentType(target, Slice.Input); + var input = context.CheckArgumentType(target, Slice.Input); context.CheckArgumentType(target, Slice.Begins); context.CheckArgumentType(target, Slice.Ends); context.CheckArgumentType(target, Slice.Axes); context.CheckArgumentType(target, Slice.Strides); - return Visit(context, target, input); + return input switch + { + TensorType t => Visit(context, target, t), + DistributedType d => Visit(context, target, d), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().Name), + }; } /// public Cost Visit(ICostEvaluateContext context, Slice target) { - var outputType = context.GetReturnType(); + var outputType = context.GetReturnType(); return new() { @@ -227,4 +233,21 @@ end is TensorConst ends_con && return input with { Shape = outShape }; } + + private IRType Visit(ITypeInferenceContext context, Slice target, DistributedType input) + { + var outType = Visit(context, target, input.TensorType); + if (outType is not TensorType tensorType) + { + return new InvalidType("not support input tensor type infer"); + } + + var axes = ((TensorConst)context.GetArgument(target, Slice.Axes)).Value.ToArray(); + if (input.NdSBP.Any(sbp => sbp is SBPSplit s && axes.Contains(s.Axis))) + { + return new InvalidType("not support input tensor type infer"); + } + + return new DistributedType((TensorType)outType, input.NdSBP, input.Placement); + } } diff --git a/src/Nncase.Evaluator/Tensors/Transpose.cs b/src/Nncase.Evaluator/Tensors/Transpose.cs index 77e74d07f1..d643370aa9 100644 --- a/src/Nncase.Evaluator/Tensors/Transpose.cs +++ b/src/Nncase.Evaluator/Tensors/Transpose.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using DryIoc.ImTools; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.Tensors; @@ -64,15 +65,22 @@ public IValue Visit(IEvaluateContext context, Transpose tr) /// public IRType Visit(ITypeInferenceContext context, Transpose target) { - var input = context.CheckArgumentType(target, Transpose.Input); - return Visit(context, target, input); + var input = context.CheckArgumentType(target, Transpose.Input); + + return input switch + { + DistributedType d => Visit(context, target, d), + TensorType t => Visit(context, target, t), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().ToString()), + }; } /// public Cost Visit(ICostEvaluateContext context, Transpose target) { - var inputType = context.GetArgumentType(target, Transpose.Input); - var outputType = context.GetReturnType(); + var inputType = context.GetArgumentType(target, Transpose.Input); + var outputType = context.GetReturnType(); return new() { @@ -102,4 +110,36 @@ private IRType Visit(ITypeInferenceContext context, Transpose target, TensorType var permExpr = context.GetArgument(target, Transpose.Perm); return TypeInference.TransposeType(input, permExpr); } + + private IRType Visit(ITypeInferenceContext context, Transpose target, DistributedType input) + { + if (Visit(context, target, input.TensorType) is not TensorType tensorType) + { + throw new InvalidOperationException(); + } + + var permExpr = context.GetArgument(target, Transpose.Perm); + if (permExpr is TensorConst permValue) + { + var perm = permValue.Value.ToArray(); + var ndsbp = new SBP[input.Placement.Rank]; + + for (int i = 0; i < input.Placement.Rank; i++) + { + switch (input.NdSBP[i]) + { + case SBPSplit { Axis: int ix }: + ndsbp[i] = SBP.S(perm.IndexOf(ix)); + break; + default: + ndsbp[i] = input.NdSBP[i]; + break; + } + } + + return new DistributedType(tensorType, ndsbp, input.Placement); + } + + return new InvalidType(input.ToString()); + } } diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs index bd86940fee..23bed23403 100644 --- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs +++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs @@ -27,9 +27,18 @@ public IValue Visit(IEvaluateContext context, Unsqueeze unSqueeze) /// public IRType Visit(ITypeInferenceContext context, Unsqueeze target) { - var input = context.CheckArgumentType(target, Unsqueeze.Input); + var input = context.CheckArgumentType(target, Unsqueeze.Input); _ = context.CheckArgumentType(target, Unsqueeze.Dim); - return Visit(context, target, input); + if (input is TensorType tensorType) + { + return Visit(context, target, tensorType); + } + else if (input is DistributedType distributedType) + { + return Visit(context, target, distributedType); + } + + return new InvalidType(input.GetType().Name); } /// @@ -81,4 +90,26 @@ private IRType Visit(ITypeInferenceContext context, Unsqueeze target, TensorType return input with { Shape = new Shape(Enumerable.Repeat(Dimension.Unknown, input.Shape.Rank + 1)) }; } + + private IRType Visit(ITypeInferenceContext context, Unsqueeze target, DistributedType input) + { + var tensorType = (TensorType)Visit(context, target, input.TensorType); + + var ndsbp = new SBP[input.NdSBP.Count]; + + if (context.GetArgument(target, Unsqueeze.Dim) is TensorConst tdims) + { + var dimsValue = tdims.Value.Cast(); + for (int i = 0; i < input.NdSBP.Count; i++) + { + ndsbp[i] = input.NdSBP[i] switch + { + SBPSplit { Axis: int axis } => SBP.S(axis + dimsValue.Select(i => i <= axis).Count(b => b)), + SBP sbp => sbp, + }; + } + } + + return new DistributedType(tensorType, ndsbp, input.Placement); + } } diff --git a/src/Nncase.Evaluator/TypeInference.cs b/src/Nncase.Evaluator/TypeInference.cs index a0749e8fdd..bbe74d1739 100644 --- a/src/Nncase.Evaluator/TypeInference.cs +++ b/src/Nncase.Evaluator/TypeInference.cs @@ -11,6 +11,7 @@ using Microsoft.Extensions.DependencyInjection; using NetFabric.Hyperlinq; using Nncase.IR; +using Nncase.TIR; using static Nncase.IR.TypePatternUtility; namespace Nncase.Evaluator; diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index 8ffd28f715..36528fd045 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -52,28 +52,6 @@ protected override IRType VisitLeafBlock(Block expr) return TupleType.Void; } - /// - protected override IRType VisitLeafBufferLoad(BufferLoad expr) - { - IRType type; - VerifySubField(expr, expr.Buffer, TypePatternUtility.IsPointer()); - for (int i = 0; i < expr.Indices.Length; i++) - { - VerifySubField(expr, expr.Indices[i], TypePatternUtility.IsIntegralScalar(), $"BufferLoad.Indices[{i}]"); - } - - if (expr.Buffer.CheckedType is TensorType { IsScalar: true, DType: PointerType { ElemType: PrimType pointedType } }) - { - type = TensorType.Scalar(pointedType); - } - else - { - type = new InvalidType($"Can't load from {expr.Buffer.CheckedType}"); - } - - return type; - } - /// protected override IRType VisitLeafBufferRegion(BufferRegion expr) { @@ -91,27 +69,24 @@ protected override IRType VisitLeafBufferRegion(BufferRegion expr) } /// - protected override IRType VisitLeafBufferStore(BufferStore expr) + protected override IRType VisitLeafBuffer(Nncase.TIR.Buffer expr) { - VerifySubField(expr, expr.Buffer, TypePatternUtility.IsPointer()); - for (int i = 0; i < expr.Indices.Length; i++) + VerifySubField(expr, expr.MemSpan, TypePatternUtility.IsPointer() | TypePatternUtility.IsNoneType()); + foreach (var r in expr.Dimensions) { - VerifySubField(expr, expr.Indices[i], TypePatternUtility.IsIntegralScalar(), $"BufferStore.Indices[{i}]"); + VerifySubField(expr, r, TypePatternUtility.IsIntegralScalar()); } - VerifySubField(expr, expr.Value, TypePatternUtility.IsScalar()); - - IRType type; - if (expr.Value.CheckedType is TensorType { IsScalar: true, DType: PrimType valueType } && - expr.Buffer.CheckedType is TensorType { IsScalar: true, DType: PointerType { ElemType: PrimType pointedType } } - && valueType == pointedType) + foreach (var r in expr.Strides) { - type = TupleType.Void; + VerifySubField(expr, r, TypePatternUtility.IsIntegralScalar()); } - else + + var type = new TensorType(expr.ElemType, expr.Dimensions.AsValueEnumerable().Select(e => e switch { - type = new InvalidType($"Can't store {expr.Value.CheckedType} to {expr.Buffer.CheckedType}"); - } + TensorConst { Value: { Shape: { IsScalar: true } } t } => new Dimension(t.ToScalar()), + _ => Dimension.Unknown, + }).ToArray()); return type; } @@ -222,13 +197,6 @@ protected override IRType VisitLeafLet(Let expr) return type; } - /// - protected override IRType VisitLeafLogicalBuffer(LogicalBuffer expr) - { - var type = new TensorType(expr.ElemType, Shape.Unknown(expr.Rank)); - return type; - } - /// protected override IRType VisitLeafMarker(Marker expr) { @@ -251,13 +219,6 @@ protected override IRType VisitLeafOp(Op expr) return type; } - /// - protected override IRType VisitLeafPhysicalBuffer(PhysicalBuffer expr) - { - var type = new TensorType(expr.ElemType, new(expr.FixedDimensions)); - return type; - } - /// protected override IRType VisitLeafPrimFunction(PrimFunction expr) { @@ -318,6 +279,13 @@ protected override IRType VisitLeafVar(Var expr) return type; } + protected override IRType VisitLeafMemSpan(MemSpan expr) + { + VerifySubField(expr, expr.Start, TypePatternUtility.IsNoneType() | TypePatternUtility.IsIntegralScalar() | TypePatternUtility.IsPointer()); + VerifySubField(expr, expr.Size, TypePatternUtility.IsIntegralScalar()); + return expr.Start.CheckedType; + } + /// protected override IRType VisitLet(Let expr) { diff --git a/src/Nncase.Evaluator/packages.lock.json b/src/Nncase.Evaluator/packages.lock.json index 78fc35c9da..cf9c399201 100644 --- a/src/Nncase.Evaluator/packages.lock.json +++ b/src/Nncase.Evaluator/packages.lock.json @@ -13,11 +13,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "libortki": { @@ -87,8 +87,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -110,6 +110,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -169,6 +170,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Graph/packages.lock.json b/src/Nncase.Graph/packages.lock.json index a439ce21fd..ab3e724693 100644 --- a/src/Nncase.Graph/packages.lock.json +++ b/src/Nncase.Graph/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "libortki": { @@ -78,8 +78,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -101,6 +101,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -176,6 +177,12 @@ "libortki": "0.0.2" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.IO/packages.lock.json b/src/Nncase.IO/packages.lock.json index eb0c4a8b7c..ef24cbccbb 100644 --- a/src/Nncase.IO/packages.lock.json +++ b/src/Nncase.IO/packages.lock.json @@ -4,17 +4,17 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" } } } diff --git a/src/Nncase.Importer/Ncnn/NcnnModel.cs b/src/Nncase.Importer/Ncnn/NcnnModel.cs index 488e191ad4..ec67287e97 100644 --- a/src/Nncase.Importer/Ncnn/NcnnModel.cs +++ b/src/Nncase.Importer/Ncnn/NcnnModel.cs @@ -65,7 +65,7 @@ public static NcnnModel ParseFromStream(Stream stream) throw new InvalidDataException("parse magic failed"); } - if (reader.ReadLine()?.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) is not [var layerCountStr, var blobCountStr]) + if (reader.ReadLine()?.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) is not[var layerCountStr, var blobCountStr]) { throw new InvalidDataException("parse layer_count or blob_count failed"); } diff --git a/src/Nncase.Importer/Ncnn/ParamDict.cs b/src/Nncase.Importer/Ncnn/ParamDict.cs index 954525ea48..bc5c77e6d7 100644 --- a/src/Nncase.Importer/Ncnn/ParamDict.cs +++ b/src/Nncase.Importer/Ncnn/ParamDict.cs @@ -44,7 +44,7 @@ public void LoadFrom(ReadOnlySpan fields) { foreach (var field in fields) { - if (field.Split('=', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) is not [var idStr, var valueStr]) + if (field.Split('=', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) is not[var idStr, var valueStr]) { break; } diff --git a/src/Nncase.Importer/Onnx/Concat.cs b/src/Nncase.Importer/Onnx/Concat.cs index 5f0e471652..433036c7be 100644 --- a/src/Nncase.Importer/Onnx/Concat.cs +++ b/src/Nncase.Importer/Onnx/Concat.cs @@ -14,7 +14,7 @@ private Expr VisitConcat(NodeProto op) { var inputs = Enumerable.Range(0, op.Input.Count).Select(x => GetInputExpr(op, x)).ToArray(); var axis = GetIntAttribute(op, "axis"); - return F.Tensors.Concat(new Tuple(inputs), axis); + return F.Tensors.Concat(new Tuple(inputs), (int)axis); } } } diff --git a/src/Nncase.Importer/Onnx/DataGatter.cs b/src/Nncase.Importer/Onnx/DataGatter.cs index 8001655a67..0cd5da981c 100644 --- a/src/Nncase.Importer/Onnx/DataGatter.cs +++ b/src/Nncase.Importer/Onnx/DataGatter.cs @@ -105,7 +105,7 @@ private Tensor GetTensor(TensorProto tensor) var externalDataCount = tensor.ExternalData.Count; if (externalDataCount != 0) { - if (externalDataCount < 3 && externalDataCount > 5) + if (externalDataCount < 1 || externalDataCount > 5) { throw new NotSupportedException("NotSupport ExternalData format, only support location, offset, length, checksum"); } @@ -113,9 +113,9 @@ private Tensor GetTensor(TensorProto tensor) var parent = Directory.GetParent(CompileSession.CompileOptions.InputFile)?.FullName; var externalData = tensor.ExternalData; var location = Path.Join(parent, externalData[0].Value); - var offset = long.Parse(externalData[1].Value); - var length = int.Parse(externalData[2].Value); + var offset = externalDataCount > 1L ? long.Parse(externalData[1].Value) : 0; using var br = new BinaryReader(new FileStream(location, FileMode.Open)); + var length = externalDataCount > 1 ? int.Parse(externalData[2].Value) : (int)br.BaseStream.Length; br.BaseStream.Seek(offset, SeekOrigin.Begin); var buffer = br.ReadBytes(length); return Tensor.FromBytes(type, buffer, shape); diff --git a/src/Nncase.Importer/Onnx/Gather.cs b/src/Nncase.Importer/Onnx/Gather.cs index eb03f835f6..ae47bb396d 100644 --- a/src/Nncase.Importer/Onnx/Gather.cs +++ b/src/Nncase.Importer/Onnx/Gather.cs @@ -14,7 +14,7 @@ private Expr VisitGather(in NodeProto op) { var (input, indices) = GetInputExprs(op, 0, 1); var axis = GetIntAttribute(op, "axis", 0); - return F.Tensors.Gather(input, axis, indices); + return F.Tensors.Gather(input, (int)axis, indices); } } } diff --git a/src/Nncase.Importer/Onnx/OnnxImporter.cs b/src/Nncase.Importer/Onnx/OnnxImporter.cs index 3004d86822..b87178604a 100644 --- a/src/Nncase.Importer/Onnx/OnnxImporter.cs +++ b/src/Nncase.Importer/Onnx/OnnxImporter.cs @@ -52,7 +52,6 @@ protected override (IEnumerable Inputs, Dictionary VarMap) Cre { var bucketOptions = CompileSession.CompileOptions.ShapeBucketOptions; _fixVarMap = bucketOptions.FixVarMap; - _constTensors = _graph.Initializer .ToDictionary(tensor => tensor.Name, tensor => tensor); diff --git a/src/Nncase.Importer/Onnx/Softmax.cs b/src/Nncase.Importer/Onnx/Softmax.cs index cde4fbb3e6..acd65b9606 100644 --- a/src/Nncase.Importer/Onnx/Softmax.cs +++ b/src/Nncase.Importer/Onnx/Softmax.cs @@ -43,7 +43,7 @@ private Expr SoftmaxV13Process(in NodeProto op, Func f) { var input = GetSingleInputExpr(op); var axis = GetIntAttribute(op, "axis", -1); - return f(input, axis); + return f(input, IR.F.Math.Select(axis < 0, (Rank(input) + axis)[0], axis)); } private Expr SoftmaxV1(in NodeProto op) diff --git a/src/Nncase.Importer/packages.lock.json b/src/Nncase.Importer/packages.lock.json index 845d535c0e..3a6d65fc28 100644 --- a/src/Nncase.Importer/packages.lock.json +++ b/src/Nncase.Importer/packages.lock.json @@ -22,11 +22,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.Bcl.AsyncInterfaces": { @@ -85,8 +85,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -371,6 +371,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -454,6 +455,12 @@ "resolved": "2.0.0", "contentHash": "ir3uek0+7Y8SwOUGUR8y94sgpVDWLAjKGBm9z7cLe/38GyPxWbIYHPnHZHksNTExTsx3Ie9GtwagkgR/jm64hA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs new file mode 100644 index 0000000000..4a07e97a8b --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleExtensions.cs @@ -0,0 +1,25 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; +using System.Linq; +using Nncase.IR; + +namespace Nncase.Passes.BufferSchedule; + +public static class BufferScheduleExtensions +{ + public static IEnumerable GetArguments(this Call call) + { + var hs = new HashSet(ReferenceEqualityComparer.Instance); + hs.UnionWith(call.Arguments.ToArray().Where(e => e is not (BaseFunction or Const)).ToArray().Select(e => e switch { IR.Tuple tp => tp.Fields.ToArray(), _ => new[] { e } }).SelectMany(i => i)); + return hs; + } + + public static IEnumerable GetUsers(this Call call) + { + var hs = new HashSet(ReferenceEqualityComparer.Instance); + hs.UnionWith(call.Users.Where(e => e is not BaseFunction).ToArray().Select(e => e switch { IR.Tuple tp => tp.Users.ToArray(), _ => new[] { e } }).SelectMany(i => i)); + return hs; + } +} diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs new file mode 100644 index 0000000000..13e8ab86f0 --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -0,0 +1,77 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class TimeInterval +{ + public TimeInterval(int start, int end) + { + Brith = start; + Death = end; + } + + public int Brith { get; set; } + + public int Death { get; set; } + + public int Size => Death - Brith; + + public override string ToString() + { + return $"TimeInterval({Brith}, {Death})"; + } +} + +internal sealed class MemSpan +{ + public MemSpan(int start, int end) + { + Start = start; + End = end; + } + + public int Start { get; set; } + + public int End { get; set; } + + public int Size => End - Start; + + public override string ToString() + { + return $"MemSpan({Start}, {End})"; + } +} + +internal class ScheduleBuffer +{ + public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace) + { + Name = name; + Number = number; + Interval = interval; + Span = span; + Shape = shape; + Strides = strides; + Inplace = inplace; + } + + public string Name { get; } + + public int Number { get; } + + public TimeInterval Interval { get; } + + public MemSpan Span { get; } + + public int[] Shape { get; } + + public int[] Strides { get; } + + public bool Inplace { get; } + + public override string ToString() + { + return $"ScheduledBuffer('{Name}', {Number}, {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})"; + } +} diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs new file mode 100644 index 0000000000..25f7d6f5b8 --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -0,0 +1,215 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reactive; +using Google.OrTools.Sat; +using NetFabric.Hyperlinq; +using Nncase; +using Nncase.IR; + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class BufferScheduler +{ + public IReadOnlyDictionary CollectLifeTime(Function func) + { + var c = new LifeTimeCollector(); + return c.Collect(func); + } + + public void Schedule(IReadOnlyDictionary bufferMap) + { + var model = new CpModel(); + var noOverlap = model.AddNoOverlap2D(); + var boxs = new Dictionary(ReferenceEqualityComparer.Instance); + var timeMap = new Dictionary>(); + var yStarts = new List(); + foreach (var (expr, item) in bufferMap) + { + var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + $"{item.Number}_x"); + + var upbound = 2147483648 - item.Span.End; + if (upbound <= 0) + { + throw new System.NotSupportedException(); + } + + var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start"); + var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_{item.Number}_y"); + noOverlap.AddRectangle(xInterval, yInterval); + yStarts.Add(memStartVar); + boxs.Add(expr, (xInterval, yInterval)); + + for (int time = item.Interval.Brith; time < item.Interval.Death; time++) + { + if (!timeMap.TryGetValue(time, out var timelist)) + { + timelist = new(); + timeMap.Add(time, timelist); + } + + timelist.Add(expr); + } + } + + foreach (var (expr, item) in bufferMap) + { + if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple) + { + // the concat inputs must contiguous + int offset = 0; + for (int i = 0; i < tuple.Fields.Length; i++) + { + model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr()); + offset += bufferMap[tuple.Fields[i]].Span.Size; + } + } + else if (expr is Call { Target: IR.Tensors.Split } splitCall) + { + // the split must equal with input. + model.Add(boxs[splitCall].Y.StartExpr() == boxs[splitCall.Arguments[0]].Y.StartExpr()); + + // the split outputs must contiguous + var users = splitCall.GetUsers(); + int offset = 0; + foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar())) + { + model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr()); + offset += bufferMap[user].Span.Size; + } + } + else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall) + { + // the reshape must equal with it's input. + model.Add(boxs[reshapCall].Y.StartExpr() == boxs[reshapCall.Arguments[0]].Y.StartExpr()); + } + } + + model.Minimize(LinearExpr.Sum(yStarts)); + + var solver = new CpSolver(); + solver.StringParameters = $"max_time_in_seconds:{60},num_workers:{8}"; + CpSolverStatus solve_status = solver.Solve(model); + if (solve_status != CpSolverStatus.Optimal && solve_status != CpSolverStatus.Feasible) + { + throw new System.NotSupportedException(); + } + + foreach (var (k, v) in bufferMap) + { + bufferMap[k].Span.Start = checked((int)solver.Value(boxs[k].Y.StartExpr())); + bufferMap[k].Span.End = checked((int)solver.Value(boxs[k].Y.EndExpr())); + } + } + + public void Dump(Stream fs, IReadOnlyDictionary buffers) + { + using (var wr = new StreamWriter(fs)) + { + wr.Write(@"from bokeh.models import ColumnDataSource, HoverTool, FuncTickFormatter, SingleIntervalTicker, SaveTool, WheelZoomTool, WheelPanTool, ResetTool +from bokeh.palettes import Category20_20 as palette +from bokeh.plotting import figure, show, save +import itertools +from dataclasses import dataclass +from enum import Enum +from typing import List + +@dataclass +class TimeInterval(): + start: int + end: int + def __str__(self) -> str: + return f'(start: {self.start}, end {self.end})' + +@dataclass +class MemSpan(): + depth_start: int + depth_end: int + def __str__(self) -> str: + return f'(start: {self.depth_start}, size {self.depth_end - self.depth_start})' + +class ConstraintsMode(Enum): + No = 0 + Channel = 1 + +@dataclass +class ScheduledBuffer(): + name: str + number: int + interval: TimeInterval + location: MemSpan + constraints: ConstraintsMode + shape: List[int] + stride: List[int] + inplace: bool + +colors = itertools.cycle(palette) + +buffers = [ +"); + foreach (var (_, v) in buffers) + { + wr.WriteLine(v.ToString() + ","); + } + + wr.Write(@"] + +source = { + 'name': [], + 'x': [], + 'y': [], + 'width': [], + 'height': [], + 'alpha': [], + 'color': [], + 'location': [], + 'interval': [], + 'shape': [], + 'stride': [], +} + +y_range_max = 0 +x_range_max = 0 +color_dict = {} +for buffer in buffers: + source['name'].append(buffer.name) + width = buffer.interval.end - buffer.interval.start + x = buffer.interval.start + (width / 2) + height = buffer.location.depth_end - buffer.location.depth_start + y = buffer.location.depth_start + (height / 2) + y_range_max = max(y_range_max, y) + x_range_max = max(x_range_max, buffer.interval.end) + source['x'].append(x) + source['y'].append(y) + source['width'].append(width) + source['height'].append(height) + color = color_dict.get(buffer.name) + if color == None: + color = next(colors) + color_dict[buffer.name] = color + source['color'].append(color) + source['alpha'].append(0.2 if buffer.inplace else 1.0) + source['interval'].append(str(buffer.interval)) + source['location'].append(str(buffer.location)) + source['shape'].append(','.join([str(s) for s in buffer.shape])) + source['stride'].append(','.join([str(s) for s in buffer.stride])) + +source = ColumnDataSource(source) +hover = HoverTool(tooltips=[('name', '@name'), ('interval', '@interval'), ('location', '@location'), + ('shape', '@shape'), ('stride', '@stride')]) + +p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720, + y_range=(0, y_range_max * 1.2), x_range=(-1, x_range_max + 1), + title='Local Buffer LifeTime (by Steps)') +p.rect(x='x', y='y', width='width', height='height', fill_color='color', legend_field='name', fill_alpha='alpha', source=source) +p.xaxis.axis_label = 'Time (steps)' +p.outline_line_color = None + +show(p)"); + } + } +} diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs new file mode 100644 index 0000000000..1edcc263dd --- /dev/null +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -0,0 +1,168 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reactive; +using Nncase; +using Nncase.IR; + +namespace Nncase.Passes.BufferSchedule; + +internal sealed class LifeTimeCollector : ExprVisitor +{ + public int TimeStamp { get; private set; } + + public Dictionary LifenessMap { get; } = new(ReferenceEqualityComparer.Instance); + + public IReadOnlyDictionary Collect(Function entry) + { + Visit(entry.Body); + Update(entry.Body); // avoid final call time interval size == 1. + Alias(); + + var d = new Dictionary(ReferenceEqualityComparer.Instance); + int count = 0; + foreach (var (k, v) in LifenessMap) + { + var name = k switch + { + Call c => c.Target.GetType().Name, + Var va => va.Name, + _ => k.GetType().Name, + }; + var size = GetSize(k.CheckedType, out var shape, out var stride); + + d.Add(k, new(name, count++, v, new(0, size), shape, stride, false)); + } + + return d; + } + + protected override Unit DefaultVisitLeaf(Expr expr) => Unit.Default; + + protected override Unit VisitLeafCall(Call expr) + { + foreach (var arg in expr.Arguments) + { + Update(arg); + } + + Update(expr); + + TimeStamp += 2; + + // note we will update tuple field on the next call. + foreach (var item in expr.Users.Where(e => e is not (BaseFunction or IR.Tuple))) + { + Update(item); + } + + return Unit.Default; + } + + private void Update(Expr expr) + { + if (expr is Const or None) + { + return; + } + + if (expr is IR.Tuple t) + { + foreach (var item in t.Fields) + { + Update(item); + } + + return; + } + + if (!LifenessMap.TryGetValue(expr, out var interval)) + { + interval = new(TimeStamp, TimeStamp + 1); + } + else + { + interval.Death = TimeStamp + 1; + } + + LifenessMap[expr] = interval; + } + + private void Alias() + { + bool changed; + do + { + changed = false; + foreach (var (expr, interval) in LifenessMap) + { + if (expr is Call { Target: IR.Tensors.Reshape } callReshape) + { + changed = AliasTime(callReshape, interval); + } + } + + foreach (var (expr, interval) in LifenessMap) + { + if (expr is Call { Target: IR.Tensors.Concat } concatCall) + { + changed = AliasTime(concatCall, interval); + } + } + + foreach (var (expr, interval) in LifenessMap) + { + if (expr is Call { Target: IR.Tensors.Split } splitCall) + { + changed = AliasTime(splitCall, interval); + } + } + } while (changed); + } + + private bool AliasTime(Call call, TimeInterval interval) + { + var brith = call.GetArguments().Select(arg => LifenessMap[arg].Death).Concat(new[] { interval.Brith }).Max(); + var death = call.GetUsers().Select(usr => LifenessMap[usr].Brith).Concat(new[] { interval.Death }).Min(); + + if (brith == interval.Brith && death == interval.Death) + { + return false; + } + + if (brith >= death) + { + throw new InvalidOperationException(); + } + + interval.Brith = brith; + interval.Death = death; + return true; + } + + private int GetSize(IRType type, out int[] shape, out int[] stride) + { + shape = Array.Empty(); + stride = Array.Empty(); + var size = 0; + if (type is TensorType tensorType) + { + shape = tensorType.Shape.ToValueArray(); + stride = TensorUtilities.GetStrides(shape); + size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes); + } + else if (type is TupleType tupleType) + { + size = 0; + foreach (var item in tupleType) + { + size += GetSize(item, out _, out _); + } + } + + return size; + } +} diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index 8afdb3c5e0..80aebda267 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Reactive; using System.Text; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -23,9 +24,9 @@ namespace Nncase.Passes; /// public sealed class DDrBufferSchdeulePass : ModulePass { - private readonly Dictionary> _module_usage = new(); + private readonly Dictionary> _moduleUsage = new(); - private readonly Dictionary> _module_hashset = new(); + private readonly Dictionary>> _moduleRdataMaps = new(); private readonly bool _enbaleMergeCall; @@ -42,41 +43,16 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon // 1. merge the all call prim func if (_enbaleMergeCall) { - HashSet mergedFuncs = new(ReferenceEqualityComparer.Instance); - HashSet stackvmFuncs = new(ReferenceEqualityComparer.Instance); - for (int i = 0; i < module.Functions.Count; i++) + if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType)) { - if (module.Functions[i] is Function { ModuleKind: "stackvm" } func) + var sch = new BufferSchedule.BufferScheduler(); + var buffers = sch.CollectLifeTime(func); + sch.Schedule(buffers); + using (var fs = Diagnostics.DumpScope.Current.OpenFile("draw_buffers.py")) { - var analysis = new Dictionary - { - [typeof(IExprUserAnalysisResult)] = AnalyzerManager.GetAnaylsis(func), - }; - _ = new HashSet(ReferenceEqualityComparer.Instance); - var mergePass = new DataflowPass(); - mergePass.Add(mergedFuncs); - var post = await mergePass.RunAsync(func, new() { AnalysisResults = analysis, RewriteOnce = true }); - module.Replace(i, post); - stackvmFuncs.Add(post); - } - } - - // 2. add the ext func into module. - foreach (var func in stackvmFuncs) - { - var collector = new ExternalFuncCollector(); - collector.Visit(func); - foreach (var ext_func in collector.GetExternalFuncs()) - { - module.Add(ext_func); + sch.Dump(fs, buffers); } } - - // 3. remove the all merged funcs - foreach (var item in mergedFuncs) - { - module.Remove(item); - } } // 4. schedule the prim funcs. @@ -86,149 +62,121 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon { if (!prim_func.SchedResult.IsScheduled) { - var ddr_allocator = new DDrBufferAllocator(_module_usage, _module_hashset); - ddr_allocator.Visit(prim_func); // changed ddr buffer. - prim_func.SchedResult.DataUsage = ddr_allocator.DataUsage; - prim_func.SchedResult.IsScheduled = ddr_allocator.Changed; + var rewriter = new DDrBufferRewriter(_moduleUsage, _moduleRdataMaps); + var post = (TIR.PrimFunction)rewriter.Rewrite(prim_func); // changed ddr buffer. + if (rewriter.IsMutated) + { + post.SchedResult.DataUsage = rewriter.DataUsage; + post.SchedResult.IsScheduled = true; + } + + module.Replace(i, prim_func); } } } - _module_hashset.Clear(); - _module_usage.Clear(); + _moduleRdataMaps.Clear(); + _moduleUsage.Clear(); return await Task.FromResult(module); } + + private bool IsFixedType(IRType type) => type switch + { + TensorType tensorType => tensorType.Shape.IsFixed, + TupleType tupleType => tupleType.Fields.All(IsFixedType), + _ => false, + }; } -/// -/// collect and assgin the PhysicalBuffer. -/// -internal sealed class DDrBufferAllocator : ExprVisitor +internal sealed class DDrBufferRewriter : ExprRewriter { - private readonly Dictionary _functionUsage; - private readonly HashSet _functionHashset; + private readonly Dictionary _functionUsage; + private readonly Dictionary> _functionRdatas; - private PrimFunction? _entry; - - public DDrBufferAllocator(Dictionary> module_usage, Dictionary> module_hashset) + public DDrBufferRewriter(Dictionary> moduleUsage, Dictionary>> moduleRdataMaps) { - ModuleUsage = module_usage; - ModuleHashSet = module_hashset; + ModuleUsage = moduleUsage; + ModuleRdataMaps = moduleRdataMaps; _functionUsage = new(); - _functionHashset = new(ReferenceEqualityComparer.Instance); + _functionRdatas = new(); Changed = false; } - public Dictionary> ModuleUsage { get; } + public Dictionary> ModuleUsage { get; } - public Dictionary> ModuleHashSet { get; } + public Dictionary>> ModuleRdataMaps { get; } public bool Changed { get; private set; } - public int DataUsage => _functionUsage.GetValueOrDefault(Schedule.MemoryLocation.Data, 0); + public long DataUsage => _functionUsage.GetValueOrDefault(MemoryLocation.Data, 0); + + public PrimFunction Entry => (PrimFunction)VisitRoot!; - /// - /// only visit one prim func. - /// - protected override bool VisitPrimFunction(PrimFunction primFunction) + protected override Expr RewriteLeafBuffer(TIR.Buffer expr) { - _entry ??= primFunction; - if (object.ReferenceEquals(_entry, primFunction)) + if (expr.MemSpan is { Location: TIR.MemoryLocation.Input or TIR.MemoryLocation.Output, Start: None, Size: TensorConst size } memSpan) { - foreach (var physical in primFunction.Parameters) + // input/output write into the FunctionUsage + if (!_functionUsage.TryGetValue(memSpan.Location, out var start)) { - if (physical.MemLocation is Schedule.MemoryLocation.Input or Schedule.MemoryLocation.Output) - { - // avoid visit same buffer - if (!_functionHashset.Contains(physical)) - { - // input/output write into the FunctionUsage - if (!_functionUsage.TryGetValue(physical.MemLocation, out var start)) - { - start = 0; - } - - physical.Start = start; - _functionUsage[physical.MemLocation] = start + physical.Size; - _functionHashset.Add(physical); - Changed = true; - } - } - else - { - throw new NotSupportedException($"The prim function parameters mem location must be input/output but get {physical.MemLocation}!"); - } + start = 0; } - return base.VisitPrimFunction(_entry); + _functionUsage[memSpan.Location] = start + size.Value.ToScalar(); + Changed = true; + + return expr.With(memSpan: memSpan.With(start: Tensor.FromPointer((ulong)start, expr.ElemType))); } - return true; + return expr; } - protected override bool VisitLeafBuffer(TIR.Buffer buffer) + protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) { - if (buffer is not TIR.PhysicalBuffer physical) - { - return true; - } - - // rdata write into the moduleUsage - if (physical.MemLocation is Schedule.MemoryLocation.Rdata) + if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const { ValueType: TensorType constType } @const) { - if (!ModuleHashSet.TryGetValue(_entry!.ModuleKind, out var module_hashset)) + if (!ModuleRdataMaps.TryGetValue(Entry.ModuleKind, out var moduleRdataMap)) { - module_hashset = new(ReferenceEqualityComparer.Instance); - ModuleHashSet.Add(_entry!.ModuleKind, module_hashset); + moduleRdataMap = new(); + ModuleRdataMaps.Add(Entry.ModuleKind, moduleRdataMap); } - if (!ModuleUsage.TryGetValue(_entry!.ModuleKind, out var module_usage)) + if (!ModuleUsage.TryGetValue(Entry.ModuleKind, out var moduleUsage)) { - module_usage = new(); - ModuleUsage.Add(_entry!.ModuleKind, module_usage); + moduleUsage = new(); + ModuleUsage.Add(Entry.ModuleKind, moduleUsage); } - if (!module_hashset.Contains(physical)) + if (!moduleRdataMap.TryGetValue(@const, out var memRange)) { - if (!module_usage.TryGetValue(physical.MemLocation, out var start)) + if (!moduleUsage.TryGetValue(memSpan.Location, out var start)) { start = 0; } - physical.Start = start; - module_usage[physical.MemLocation] = start + physical.Size; - module_hashset.Add(physical); - _entry.SchedResult.Rdatas.Add(physical); - + _ = ComputeSize(@const); + moduleUsage[memSpan.Location] = start + ComputeSize(@const); + memRange = new(start, start + ComputeSize(@const)); + moduleRdataMap.Add(@const, memRange); + Entry.SchedResult.Rdatas.Add(@const, memRange); Changed = true; } - } - else if (physical.MemLocation is Schedule.MemoryLocation.Data) - { - // data write into the FunctionUsage - if (!_functionHashset.Contains(physical)) - { - if (!_functionUsage.TryGetValue(physical.MemLocation, out var start)) - { - start = 0; - } - physical.Start = start; - _functionUsage[physical.MemLocation] = start + physical.Size; - _functionHashset.Add(physical); - Changed = true; - } - } - else if (physical.MemLocation is Schedule.MemoryLocation.SharedData) - { - throw new NotSupportedException("Current Not Support!"); + return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, constType.DType)), memRange.Max - memRange.Min); } - return true; + return memSpan; } - protected override bool DefaultVisitLeaf(Expr expr) => true; + private long ComputeSize(IValue v) => v.AsTensors().Select(t => t.BytesBuffer.Length).Sum(); + + private long ComputeSize(Const @const) => @const switch + { + TensorConst { Value: Tensor tc } => tc.BytesBuffer.Length, + TupleConst tc => ComputeSize(tc.Value), + _ => throw new NotSupportedException(), + }; } internal sealed class ExternalFuncCollector : ExprWalker diff --git a/src/Nncase.Passes/EGraphExtractPass.cs b/src/Nncase.Passes/EGraphExtractPass.cs index d4ebe21ebc..2c2baa12b8 100644 --- a/src/Nncase.Passes/EGraphExtractPass.cs +++ b/src/Nncase.Passes/EGraphExtractPass.cs @@ -24,7 +24,7 @@ public EGraphExtractPass(IBaseFuncCostEvaluator? costEvaluator = null) protected override Task RunCoreAsync(IEGraph input, RunPassContext context) { - var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator); + var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, out _); IRHelpers.DCE(post); return Task.FromResult(post); } diff --git a/src/Nncase.Passes/Mutators/IFusionMergeRule.cs b/src/Nncase.Passes/Mutators/IFusionMergeRule.cs index 15878bcc71..9f72015746 100644 --- a/src/Nncase.Passes/Mutators/IFusionMergeRule.cs +++ b/src/Nncase.Passes/Mutators/IFusionMergeRule.cs @@ -668,14 +668,8 @@ private bool ProcessFusionMerge(Func mergedFusionRewriteCallBack, Fu { if (caller_inputs[i] is Call { Target: Fusion }) { - Fusion callee_fusion; - try + if (result.GetValueOrDefault($"callee_fusion_{i}") is not Fusion callee_fusion) { - callee_fusion = (Fusion)result[$"callee_fusion_{i}"]; - } - catch (KeyNotFoundException) - { - // when matched fusion(fusion(x,y)), the input => fusion(x,y) return false; } diff --git a/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs b/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs index ac7cb9feb6..e75489c638 100644 --- a/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs +++ b/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs @@ -10,6 +10,7 @@ using Nncase.IR.Imaging; using Nncase.IR.Math; using Nncase.IR.NN; +using Nncase.IR.RNN; using Nncase.IR.Tensors; using Nncase.PatternMatch; using static Nncase.IR.TypePatternUtility; diff --git a/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs b/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs index e046301045..9519b011c7 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs @@ -34,11 +34,12 @@ public sealed partial class CombineQuantizeConcat : RewriteRule "quantize", _ => true, IsConcat( - IsTuple(IsVArgsRepeat("tupleInputs", () => IsWildcard())), - IsWildcard("axis")), + "concat", + _ => true, + IsTuple(IsVArgsRepeat("tupleInputs", () => IsWildcard()))), IsWildcard("quantParam")); - private Expr? GetReplace(Quantize quantize, IReadOnlyList tupleInputs, Expr axis, Expr quantParam, RunPassContext options) + private Expr? GetReplace(Quantize quantize, IReadOnlyList tupleInputs, IR.Tensors.Concat concat, Expr quantParam, RunPassContext options) { if (options.Driver is DataflowPass) { @@ -54,7 +55,7 @@ public sealed partial class CombineQuantizeConcat : RewriteRule } } - return Concat(new IR.Tuple(tupleInputs.Select(e => IR.F.Math.Quantize(e, quantParam, quantize.TargetType)).ToArray()), axis); + return Concat(new IR.Tuple(tupleInputs.Select(e => IR.F.Math.Quantize(e, quantParam, quantize.TargetType)).ToArray()), concat.Axis); } } diff --git a/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs b/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs index 3f0c2be75d..8c36b0b537 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs @@ -201,3 +201,61 @@ public sealed partial class CombineReshapePad : IRewriteRule return null; } } + +/// +/// combine reshape transpose +/// e.g. : +/// %5 // f32[1,77,768] +/// %6 = Reshape(%5, const(i64[4] : {1L,77L,12L,64L})): // f32[1,77,12,64] +/// %7 = Transpose(%6, const(i64[4] : {0L,2L,1L,3L})): // f32[1,12,77,64] +/// %8 = Reshape(%7, const(i32[3] : {12,77,64})): // f32[12,77,64]. +/// after combine : +/// %5 // f32[1,77,768] +/// %6 = Reshape(%5, const(i64[4] : {1L,77L,12L,64L})): // f32[1,77,12,64] +/// %7 = Reshape(%6, const(i64[3] : {77L,12L,64L})): // f32[77L,12L,64L]. +/// %8 = Transpose(%7, const(i64[4] : {1L,0L,2L})): // f32[12,77,64]. +/// then use foldreshape. +/// +[RuleGenerator] +public sealed partial class CombineReshapeTranspose : IRewriteRule +{ + /// + public IPattern Pattern { get; } = IsReshape( + IsTranspose( + null, + "trans", + IsWildcard("input") with { TypePattern = HasFixedShape() }, + IsTensorConst("perm")) with + { TypePattern = HasFixedShape() }, + IsTensorConst("newShape")); + + private Expr? GetReplace(Expr input, Call trans, int[] newShape, int[] perm) + { + var transShape = trans.CheckedShape.ToValueArray(); + + if (transShape.Length == newShape.Length + 1) + { + // check reshape is sequeeze + var viewAxis = RulesUtility.FindSqueezeAxis(transShape, newShape); + if (viewAxis == -1) + { + return null; + } + + var inv = perm.Select((p, i) => (p, i)).OrderBy(tp => tp.p).ToArray(); + var invViewAxis = inv.Where(tp => tp.i == viewAxis).First().p; + var invPerm = perm.ToList(); + var invNewShape = input.CheckedShape.ToValueList(); + invNewShape.RemoveAt(invViewAxis); + invPerm.Remove(invViewAxis); + return IR.F.Tensors.Transpose(IR.F.Tensors.Reshape(input, invNewShape.ToArray()), invPerm.Select(i => i > invViewAxis ? i - 1 : i).ToArray()); + } + else if (transShape.Length == newShape.Length - 1) + { + // check rehsape is unsequeeze + return null; + } + + return null; + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs b/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs index dbe323c1ee..4d979c8a56 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs @@ -194,6 +194,7 @@ public sealed partial class CombineTransposeConcat : IRewriteRule public IPattern Pattern { get; } = IsConcat( "concat", "concatCall", + _ => true, PatternMatch.Utility.IsTuple(null, IsVArgsRepeat("tupleInputs", exprs => { var patterns = new Pattern[exprs.Length]; @@ -203,11 +204,11 @@ public sealed partial class CombineTransposeConcat : IRewriteRule } return patterns; - })), - IsTensorConst("axis")); + }))); - private Expr? GetReplace(Expr concat, Call concatCall, IReadOnlyList tupleInputs, int axis, IMatchResult matchResult) + private Expr? GetReplace(IR.Tensors.Concat concat, Call concatCall, IReadOnlyList tupleInputs, IMatchResult matchResult) { + int axis = concat.Axis; var inputs = Enumerable.Range(0, tupleInputs.Count).Select(i => (Expr)matchResult[$"input_{i}"]); var perms = new HashSet>(Enumerable.Range(0, tupleInputs.Count).Select(i => ((TensorConst)matchResult[$"perm_{i}"]).Value.Cast(CastMode.KDefault))); @@ -343,6 +344,50 @@ public sealed partial class CombineTransposeReduce : IRewriteRule } } +/// +/// x // [12, 77, 64] +/// transpose(reshape(x, [1, 12, 77, 64]), [0, 2, 1, 3]) => reshape(transpose(x, [1, 0, 2]), [1, 77, 12, 64]). +/// +[RuleGenerator] +public sealed partial class CombineTransposeReshape : IRewriteRule +{ + /// + public IPattern Pattern { get; } = IsTranspose( + null, + "trans", + IsReshape( + IsWildcard("input") with { TypePattern = HasFixedShape() }, + IsTensorConst("newShape")) with + { TypePattern = HasFixedShape() }, + IsTensorConst("perm")); + + private Expr? GetReplace(Call trans, Expr input, int[] newShape, int[] perm) + { + var inShape = input.CheckedShape.ToValueArray(); + var outShape = trans.CheckedShape.ToValueArray(); + if (!(newShape.Length == inShape.Length + 1)) + { + return null; + } + + // check reshape is sequeeze + var axis = RulesUtility.FindSqueezeAxis(newShape, inShape); + if (axis == -1) + { + return null; + } + + var newPerm = perm.ToList(); + newPerm.Remove(axis); + newPerm = newPerm.Select(i => i > axis ? i - 1 : i).ToList(); + + var inv = perm.Select((p, i) => (p, i)).OrderBy(tp => tp.p).ToArray(); + var invNewShape = newPerm.Select(i => inShape[i]).ToList(); + invNewShape.Insert(perm.ToList().IndexOf(axis), 1); + return Reshape(Transpose(input, newPerm.ToArray()), invNewShape.ToArray()); + } +} + /// /// Combine Transpose with Unary /// reduce(transpose(x,p), a) => transpose(reduce(x, invtranspose(a, p)), p). diff --git a/src/Nncase.Passes/Rules/Neutral/FocusFull.cs b/src/Nncase.Passes/Rules/Neutral/FocusFull.cs index 3aa0e6b227..5c8b749140 100644 --- a/src/Nncase.Passes/Rules/Neutral/FocusFull.cs +++ b/src/Nncase.Passes/Rules/Neutral/FocusFull.cs @@ -20,18 +20,19 @@ public sealed partial class FocusFull : RewriteRule /// public override Pattern Pattern { get; } = IsConcat( - null, + "concat", "concatCall", + _ => true, PatternMatch.Utility.IsTuple("tp", new[] { IsSlice(Input, IsTensorConst("begin0"), IsTensorConst("end0"), IsTensorConst("axes0"), IsTensorConst("stride0")), IsSlice(Input, IsTensorConst("begin1"), IsTensorConst("end1"), IsTensorConst("axes1"), IsTensorConst("stride1")), IsSlice(Input, IsTensorConst("begin2"), IsTensorConst("end2"), IsTensorConst("axes2"), IsTensorConst("stride2")), IsSlice(Input, IsTensorConst("begin3"), IsTensorConst("end3"), IsTensorConst("axes3"), IsTensorConst("stride3")), - }), - IsTensorConst("axis")); + })); - private Expr? GetReplace(Call concatCall, Expr input, int[] begin0, long[] end0, int[] axes0, int[] stride0, int[] begin1, long[] end1, int[] axes1, int[] stride1, int[] begin2, long[] end2, int[] axes2, int[] stride2, int[] begin3, long[] end3, int[] axes3, int[] stride3, int axis) + private Expr? GetReplace(IR.Tensors.Concat concat, Call concatCall, Expr input, int[] begin0, long[] end0, int[] axes0, int[] stride0, int[] begin1, long[] end1, int[] axes1, int[] stride1, int[] begin2, long[] end2, int[] axes2, int[] stride2, int[] begin3, long[] end3, int[] axes3, int[] stride3) { + int axis = concat.Axis; var inputShape = input.CheckedShape.ToValueArray(); if (inputShape[0] != 1) { diff --git a/src/Nncase.Passes/Rules/Neutral/FoldGatherReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldGatherReshape.cs index f0b23530f0..4500a41674 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldGatherReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldGatherReshape.cs @@ -15,10 +15,14 @@ public sealed partial class FoldGatherReshape : RewriteRule { // Reshape(Gather(Shape, 0, 0), new[] { 0 }) -> GetItem(Shape, 0) public override Pattern Pattern => IsGather( - IsReshape(IsWildcard("input"), IsTensorConst("newShape")), IsTensorConst("axis"), IsTensorConst("index")); + "gather", + _ => true, + IsReshape(IsWildcard("input"), IsTensorConst("newShape")), + IsTensorConst("index")); - private Expr? GetReplace(Expr input, int[] newShape, int axis, int index) + private Expr? GetReplace(Expr input, int[] newShape, IR.Tensors.Gather gather, int index) { + int axis = gather.Axis; if (newShape.SequenceEqual(new[] { 1 }) && axis == 1) { return input[index]; diff --git a/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs b/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs index 41b3a34979..c8df9e9e62 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs @@ -279,3 +279,57 @@ public sealed partial class FoldLayerNormPattern4 : RewriteRule return null; } } + +// pattern from llama without mean and beta +[RuleGenerator] +public sealed partial class FoldLayerNormPattern5 : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = + IsBinary( + "mulGamma", + "mulGammaCall", + BinaryOp.Mul, + IsTensorConst("gamma"), + IsBinary( + "mulX", + "mulXCall", + BinaryOp.Mul, + IsWildcard("input"), + IsBinary( + "rsqrt", + "rsqrtCall", + BinaryOp.Div, + IsTensorConst("one"), + IsUnary( + "sqrt", + "sqrtCall", + UnaryOp.Sqrt, + IsBinary( + "addEps", + "addEpsCall", + BinaryOp.Add, + IsReduce( + "rdVar", + "rdVarCall", + ReduceOp.Mean, + IsBinary( + "pow2", + "pow2Call", + BinaryOp.Pow, + IsWildcard(), + IsTensorConst("two"))), + IsTensorConst("eps")))))); + + private Expr? GetReplace(Call pow2Call, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two) + { + if (input == pow2Call[Binary.Lhs] && one.Value.Cast()[0] == 1f && two.Value.Cast()[0] == 2f) + { + var axis = pow2Call.CheckedShape.Count - gamma.CheckedShape.Count; + var beta = Tensor.FromScalar(0f, gamma.CheckedShape); + return LayerNorm(axis, eps.Value.Cast()[0], input, gamma, beta, hasMean: false); + } + + return null; + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs b/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs new file mode 100644 index 0000000000..83edfe5e5e --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs @@ -0,0 +1,38 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Nncase.IR; +using Nncase.IR.NN; +using Nncase.PatternMatch; +using static Nncase.IR.F.NN; +using static Nncase.IR.F.Tensors; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.NN; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +/// +/// Fold nop . +/// +[RuleGenerator] +public sealed partial class FoldPrePostReshapeSoftmax : IRewriteRule +{ + /// + public IPattern Pattern { get; } = IsReshape( + "reshape", + "reshapeCall", + _ => true, + IsSoftmax("softmax", IsReshape("rehsape2", "reshapeCall2", _ => true, IsWildcard("input"), IsTensorConst("shape2"))), + IsTensorConst("shape1")); + + private Expr? GetReplace(Expr input) + { + return Softmax(input, 3); + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs index 2d12883101..013d76f86e 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs @@ -62,6 +62,29 @@ public sealed partial class FoldTwoReshapes : IRewriteRule } } +/// +/// Fold sequeeze reshape(binary(unsequeeze reshape(x), const)). +/// +[RuleGenerator] +public sealed partial class FoldReshapeBinaryConstReshape : IRewriteRule +{ + /// + public IPattern Pattern { get; } = + IsReshape(IsSwappableBinary("binary", null, b => b.BinaryOp is BinaryOp.Add or BinaryOp.Mul, IsReshape(IsWildcard("input") with { TypePattern = HasFixedShape() }, IsTensorConst("unsqShape")), IsTensorConst("binaryConst")), IsTensorConst("sqShape")); + + private Expr? GetReplace(Expr input, Binary binary, int[] unsqShape, TensorConst binaryConst, int[] sqShape) + { + var inShape = input.CheckedShape.ToValueArray(); + if (!(sqShape.SequenceEqual(inShape) && RulesUtility.FindSqueezeAxis(unsqShape, sqShape) is int axis && axis != -1 && ( + (binaryConst.Value.Shape.Rank == unsqShape.Length && binaryConst.Value.Shape[axis].Value == 1) || (Evaluator.TypeInference.BroadcastType((TensorType)input.CheckedType, (TensorType)binaryConst.CheckedType) is TensorType outType && outType.Shape.ToValueArray().SequenceEqual(inShape))))) + { + return null; + } + + return IR.F.Math.Binary(binary.BinaryOp, input, (binaryConst.Value.Shape.Rank == unsqShape.Length && binaryConst.Value.Shape[axis].Value == 1) ? IR.F.Tensors.Squeeze(binaryConst, new[] { axis }) : binaryConst); + } +} + /// /// Fold nop . /// diff --git a/src/Nncase.Passes/Rules/Neutral/FoldSwish.cs b/src/Nncase.Passes/Rules/Neutral/FoldSwish.cs index 0d53cca146..b8ff58e72e 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldSwish.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldSwish.cs @@ -4,6 +4,7 @@ using Nncase.IR; using Nncase.IR.Math; using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; using static Nncase.PatternMatch.F.Math; using static Nncase.PatternMatch.F.NN; using static Nncase.PatternMatch.Utility; @@ -11,37 +12,37 @@ namespace Nncase.Passes.Rules.Neutral; [RuleGenerator] -public sealed partial class FoldSwishPattern1 : RewriteRule +public sealed partial class FoldSwishPattern1 : RewriteRule { + public FoldSwishPattern1() + { + var input = IsWildcard("input"); + Pattern = IsSwappableBinary(null!, null, b => b.BinaryOp == BinaryOp.Mul, IsSigmoid(input), input); + } + /// - public override CallPattern Pattern { get; } = - IsBinary(null, "binaryCall", BinaryOp.Mul, IsSigmoid(null, "sigmoidCall", IsWildcard("input"))); + public override Pattern Pattern { get; } - private Expr? GetReplace(Call binaryCall, Call sigmoidCall, Expr input) + private Expr? GetReplace(Expr input) { - if (binaryCall[Binary.Rhs] == input) - { - return IR.F.NN.Swish(input); - } - - return null; + return IR.F.NN.Swish(input); } } [RuleGenerator] -public sealed partial class FoldSwishPattern2 : RewriteRule +public sealed partial class FoldSwishPattern2 : RewriteRule { + public FoldSwishPattern2() + { + var input = IsWildcard("input"); + Pattern = IsSwappableBinary(null!, null, b => b.BinaryOp == BinaryOp.Mul, IsSigmoid(IsSwappableBinary(null!, null, b => b.BinaryOp == BinaryOp.Mul, input, IsTensorConst("beta", IsFloatScalar()))), input); + } + /// - public override CallPattern Pattern { get; } = - IsBinary(null, "binaryCall", BinaryOp.Mul, IsWildcard(), IsSigmoid(null, "sigmoidCall", IsWildcard("input"))); + public override Pattern Pattern { get; } - private Expr? GetReplace(Call binaryCall, Call sigmoidCall, Expr input) + private Expr? GetReplace(Expr input, TensorConst beta) { - if (binaryCall[Binary.Lhs] == input) - { - return IR.F.NN.Swish(input); - } - - return null; + return IR.F.NN.Swish(input, beta); } } diff --git a/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs b/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs index 8e371dded0..9b77082ecf 100644 --- a/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs +++ b/src/Nncase.Passes/Rules/Neutral/FusionMaker.cs @@ -22,13 +22,13 @@ namespace Nncase.Passes.Rules.Neutral; public abstract class FusionMaker : RewriteRule { - private int _count; + public int Count { get; set; } public virtual string Name { get; } = "FusionMaker"; public virtual string ModuleKind { get; } = "StackVM"; - public string FullName => $"{Name}_{_count++}"; + public string FullName => $"{Name}_{Count}"; } /// diff --git a/src/Nncase.Passes/Rules/Neutral/NormAxis.cs b/src/Nncase.Passes/Rules/Neutral/NormAxis.cs new file mode 100644 index 0000000000..13e3c0d30b --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/NormAxis.cs @@ -0,0 +1,115 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; +using System.Linq; +using Nncase.IR; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.NN; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public sealed partial class NormAxisGather : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = IsGather("gather", g => g.Axis < 0, IsWildcard("input") with { TypePattern = HasRank() }, IsWildcard("index") with { TypePattern = HasRank() }); + + private Expr? GetReplace(IR.Tensors.Gather gather, Expr input, Expr index) + { + return IR.F.Tensors.Gather(input, gather.Axis + input.CheckedShape.Rank, index); + } +} + +[RuleGenerator] +public sealed partial class NormAxisConcat : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = IsConcat("concat", op => op.Axis < 0, IsTuple(IsVArgsRepeat("inputs", inputs => + { + var ps = new Pattern[inputs.Length]; + for (int i = 0; i < inputs.Length; i++) + { + ps[i] = IsWildcard(i.ToString()) with { TypePattern = HasRank() }; + } + + return ps; + }))); + + private Expr? GetReplace(IR.Tensors.Concat concat, IReadOnlyList inputs) + { + return IR.F.Tensors.Concat(new IR.Tuple(inputs.ToArray()), concat.Axis + inputs[0].CheckedShape.Rank); + } +} + +[RuleGenerator] +public sealed partial class NormAxisReduce : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = IsReduce("reduce", "call", _ => true, IsWildcard("input") with { TypePattern = HasRank() }, IsTensorConst("axes"), IsWildcard("initValue"), IsWildcard("keepDims")); + + private Expr? GetReplace(IR.Math.Reduce reduce, Call call, Expr input, int[] axes, Expr initValue, Expr keepDims) + { + if (axes.Any(axis => axis < 0)) + { + return IR.F.Tensors.Reduce(reduce.ReduceOp, input, axes.Select(axis => axis < 0 ? axis + input.CheckedShape.Rank : axis).ToArray(), initValue, keepDims); + } + + return call; + } +} + +[RuleGenerator] +public sealed partial class NormAxisReduceArg : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = IsReduceArg("reduce", "call", _ => true, IsWildcard("input") with { TypePattern = HasRank() }, IsTensorConst("axis"), IsWildcard("keepDims"), IsWildcard("selectLastIndex")); + + private Expr? GetReplace(IR.Math.ReduceArg reduce, Call call, Expr input, int axis, Expr keepDims, Expr selectLastIndex) + { + if (axis < 0) + { + return IR.F.Tensors.ReduceArg(reduce.ReduceArgOp, reduce.DestType, input, axis + input.CheckedShape.Rank, keepDims, selectLastIndex); + } + + return call; + } +} + +[RuleGenerator] +public sealed partial class NormAxisReshape : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = IsReshape("reshape", "call", IsWildcard("input") with { TypePattern = HasFixedShape() }, IsTensorConst("newshape")) with { TypePattern = HasFixedShape() }; + + private Expr? GetReplace(Call call, Expr input, int[] newshape) + { + if (newshape.Any(dim => dim < 0)) + { + return IR.F.Tensors.Reshape(input, call.CheckedShape.ToValueArray()); + } + + return null; + } +} + +[RuleGenerator] +public sealed partial class NormAxisSlice : RewriteRule +{ + /// + public override CallPattern Pattern { get; } = IsSlice("slice", "call", IsWildcard("input") with { TypePattern = HasFixedShape() }, IsTensorConst("begins"), IsTensorConst("ends"), IsTensorConst("axes"), IsTensorConst("strides")) with { TypePattern = HasFixedShape() }; + + private Expr? GetReplace(Call call, Expr input, Expr begins, Expr ends, int[] axes, Expr strides) + { + if (axes.Any(dim => dim < 0)) + { + return IR.F.Tensors.Slice(input, begins, ends, axes.Select(dim => dim < 0 ? dim + input.CheckedShape.Rank : dim).ToArray(), strides); + } + + return null; + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs b/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs index 30cd890f12..16e9b6727c 100644 --- a/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs +++ b/src/Nncase.Passes/Rules/Neutral/PrimFuncMergeRule.cs @@ -21,6 +21,7 @@ namespace Nncase.Passes.Rules.Neutral; +#if false [RuleGenerator] public sealed partial class PrimFuncMergeRule : RewriteRule { @@ -98,7 +99,7 @@ public PrimFuncMergeRule(HashSet mergedFuncs) } // 2. chack and create the data buffer - if (calleeFunc.Parameters.ToArray().Count(b => b.MemLocation == Schedule.MemoryLocation.Output) != 1) + if (calleeFunc.Parameters.ToArray().Count(b => b.MemLocation == MemoryLocation.Output) != 1) { // the direct call mean the callee function only have one output. return null; @@ -128,7 +129,7 @@ public PrimFuncMergeRule(HashSet mergedFuncs) // 5. build the new call. var nameWrapper = callerWrapper.Name; // + '_' + calleeWrapper.Name; - var newWrapper = new PrimFunctionWrapper(nameWrapper, newFunc, newFuncParams.Count(b => b.MemLocation == Schedule.MemoryLocation.Input)); + var newWrapper = new PrimFunctionWrapper(nameWrapper, newFunc, newFuncParams.Count(b => b.MemLocation == MemoryLocation.Input)); var newCallParams = new List(); newCallParams.AddRange(callerParams.Take(calleeBufferIndexs[0])); @@ -151,10 +152,10 @@ private bool BufferCanMerge(TIR.PhysicalBuffer retBuffer, TIR.PhysicalBuffer inB retBuffer.FixedStrides.SequenceEqual(inBuffer.FixedStrides) && retBuffer.ElemType == inBuffer.ElemType && retBuffer.Size == inBuffer.Size && - retBuffer.MemLocation == Schedule.MemoryLocation.Output && - inBuffer.MemLocation == Schedule.MemoryLocation.Input) + retBuffer.MemLocation == MemoryLocation.Output && + inBuffer.MemLocation == MemoryLocation.Input) { - dataBuffer = new TIR.PhysicalBuffer(inBuffer.Name, inBuffer.ElemType, Schedule.MemoryLocation.Data, inBuffer.FixedDimensions, inBuffer.FixedStrides, inBuffer.Start, inBuffer.Size); + dataBuffer = new TIR.PhysicalBuffer(inBuffer.Name, inBuffer.ElemType, MemoryLocation.Data, inBuffer.FixedDimensions, inBuffer.FixedStrides, inBuffer.Start, inBuffer.Size); return true; } @@ -191,3 +192,4 @@ protected override Expr VisitVar(Var var, Unit context) } } } +#endif diff --git a/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs b/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs index 2a1453f3c9..ec379be4a4 100644 --- a/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs +++ b/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs @@ -14,10 +14,9 @@ namespace Nncase.Passes.Rules.ShapeExpr; public sealed partial class GatherToGetItem : RewriteRule { // (Gather(input, 0, 0) -> GetItem(input) - public override Pattern Pattern => IsGather( - IsWildcard("input"), IsTensorConst("axis"), IsTensorConst("index") with { TypePattern = IsScalar() }); + public override Pattern Pattern => IsGather("gather", 0, IsWildcard("input"), IsTensorConst("index") with { TypePattern = IsScalar() }); - private Expr? GetReplace(Expr input, int axis, int index) + private Expr? GetReplace(Expr input, int index) { return input[index]; } diff --git a/src/Nncase.Passes/RulesUtility.cs b/src/Nncase.Passes/RulesUtility.cs new file mode 100644 index 0000000000..29afea1468 --- /dev/null +++ b/src/Nncase.Passes/RulesUtility.cs @@ -0,0 +1,38 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Linq; + +namespace Nncase.Passes; + +public static class RulesUtility +{ + /// + /// find sequeezed axis index. + /// + /// old shape. + /// new shape. + /// axis, if not found return -1. + public static int FindSqueezeAxis(int[] oldShape, int[] newShape) + { + if (oldShape.Length <= newShape.Length) + { + return -1; + } + + var indices = Enumerable.Range(0, oldShape.Length).ToList(); + foreach (var dim in newShape) + { + for (int i = 0; i < oldShape.Length; i++) + { + if (oldShape[i] == dim && indices.IndexOf(i) != -1) + { + indices.Remove(i); + } + } + } + + var oneindex = (indices.Count == 1) ? indices[0] : -1; + return oneindex; + } +} diff --git a/src/Nncase.Passes/packages.lock.json b/src/Nncase.Passes/packages.lock.json index a910d30fa5..1c39f25003 100644 --- a/src/Nncase.Passes/packages.lock.json +++ b/src/Nncase.Passes/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Google.OrTools.runtime.linux-arm64": { @@ -103,8 +103,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -126,6 +126,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -245,6 +246,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Quantization/packages.lock.json b/src/Nncase.Quantization/packages.lock.json index ccc991e111..59323bcaea 100644 --- a/src/Nncase.Quantization/packages.lock.json +++ b/src/Nncase.Quantization/packages.lock.json @@ -19,11 +19,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "System.Linq.Async": { @@ -132,8 +132,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -155,6 +155,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -274,6 +275,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Schedule/packages.lock.json b/src/Nncase.Schedule/packages.lock.json index 93fabe1e48..1b9a6c1dd5 100644 --- a/src/Nncase.Schedule/packages.lock.json +++ b/src/Nncase.Schedule/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.Extensions.Configuration.Abstractions": { @@ -47,8 +47,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -70,6 +70,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -129,6 +130,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Simulator/packages.lock.json b/src/Nncase.Simulator/packages.lock.json index 93fabe1e48..1b9a6c1dd5 100644 --- a/src/Nncase.Simulator/packages.lock.json +++ b/src/Nncase.Simulator/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.Extensions.Configuration.Abstractions": { @@ -47,8 +47,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -70,6 +70,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -129,6 +130,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Targets/packages.lock.json b/src/Nncase.Targets/packages.lock.json index 4a17a6364e..a434ae4251 100644 --- a/src/Nncase.Targets/packages.lock.json +++ b/src/Nncase.Targets/packages.lock.json @@ -4,11 +4,11 @@ "net7.0": { "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.Extensions.Configuration.Abstractions": { @@ -47,8 +47,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -78,6 +78,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -152,6 +153,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Tests.TestFixture/packages.lock.json b/src/Nncase.Tests.TestFixture/packages.lock.json index a7d3cc2c05..b70a73730e 100644 --- a/src/Nncase.Tests.TestFixture/packages.lock.json +++ b/src/Nncase.Tests.TestFixture/packages.lock.json @@ -28,11 +28,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "System.Linq.Async": { @@ -376,8 +376,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -750,6 +750,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -1006,6 +1007,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/src/Nncase.Tests/CodeGen/CSourceHostCases.cs b/src/Nncase.Tests/CodeGen/CSourceHostCases.cs index d6be29ed3b..97101d3c15 100644 --- a/src/Nncase.Tests/CodeGen/CSourceHostCases.cs +++ b/src/Nncase.Tests/CodeGen/CSourceHostCases.cs @@ -27,8 +27,8 @@ public class SubCase : ICodeGenCase public override PrimFunction GetEntry() { var func = T.PrimFunc("sub", - T.Buffer(TensorType.Scalar(DataTypes.Float32), Schedule.MemoryLocation.Input, out var x), - T.Buffer(TensorType.Scalar(DataTypes.Float32), Schedule.MemoryLocation.Input, out var y)).Body( + T.Buffer(TensorType.Scalar(DataTypes.Float32), MemoryLocation.Input, out var x), + T.Buffer(TensorType.Scalar(DataTypes.Float32), MemoryLocation.Input, out var y)).Body( x - y ); return func; @@ -71,8 +71,8 @@ public override void CompareEqual(IRTModel rtmod) public override PrimFunction GetEntry() { return T.PrimFunc("for_loop", - T.Buffer(new(DataTypes.Int32, new[] { 100 }), Schedule.MemoryLocation.Input, out var A), - T.Buffer(TensorType.Scalar(DataTypes.Int32), Schedule.MemoryLocation.Input, out var n) + T.Buffer(new(DataTypes.Int32, new[] { 100 }), MemoryLocation.Input, out var A), + T.Buffer(TensorType.Scalar(DataTypes.Int32), MemoryLocation.Input, out var n) ).Body( T.Serial(out var i, n).Body( T.Store(A[i], A[i] + 1), diff --git a/src/Nncase.Tests/CodeGen/UnitTestStackVMEmitter.cs b/src/Nncase.Tests/CodeGen/UnitTestStackVMEmitter.cs index 2e1027f3c2..b21d10d10f 100644 --- a/src/Nncase.Tests/CodeGen/UnitTestStackVMEmitter.cs +++ b/src/Nncase.Tests/CodeGen/UnitTestStackVMEmitter.cs @@ -1002,9 +1002,9 @@ public void TestStackVMEmitterGConcat() var memoryStream = new MemoryStream(); var stackVmEmitter = new StackVMEmitter(new BinaryWriter(memoryStream, Encoding.UTF8, true)); var tensorEmitter = new StackVMEmitter.TensorEmitter(stackVmEmitter); - tensorEmitter.Concat(); + tensorEmitter.Concat(0); var actual = memoryStream.ToArray(); - Assert.Equal(new byte[] { 100, actual[1], 0 }, actual); + Assert.Equal(new byte[] { 100, actual[1], 0, 0, 0, 0, 0 }, actual); } [Fact] @@ -1156,9 +1156,9 @@ public void TestStackVMEmitterGGather() var memoryStream = new MemoryStream(); var stackVmEmitter = new StackVMEmitter(new BinaryWriter(memoryStream, Encoding.UTF8, true)); var tensorEmitter = new StackVMEmitter.TensorEmitter(stackVmEmitter); - tensorEmitter.Gather(); + tensorEmitter.Gather(0); var actual = memoryStream.ToArray(); - Assert.Equal(new byte[] { 100, actual[1], 0 }, actual); + Assert.Equal(new byte[] { 100, actual[1], 0, 0, 0, 0, 0 }, actual); } [Fact] @@ -1244,9 +1244,9 @@ public void TestStackVMEmitterGLayerNorm() var memoryStream = new MemoryStream(); var stackVmEmitter = new StackVMEmitter(new BinaryWriter(memoryStream, Encoding.UTF8, true)); var tensorEmitter = new StackVMEmitter.TensorEmitter(stackVmEmitter); - tensorEmitter.LayerNorm(-1, 0f); + tensorEmitter.LayerNorm(-1, 0f, false); var actual = memoryStream.ToArray(); - Assert.Equal(new byte[] { 100, actual[1], 0, 255, 255, 255, 255, 0, 0, 0, 0 }, actual); + Assert.Equal(new byte[] { 100, actual[1], 0, 255, 255, 255, 255, 0, 0, 0, 0, 0 }, actual); } [Fact] diff --git a/src/Nncase.Tests/Core/UnitTestDataTypes.cs b/src/Nncase.Tests/Core/UnitTestDataTypes.cs index 322b989921..183f65a140 100644 --- a/src/Nncase.Tests/Core/UnitTestDataTypes.cs +++ b/src/Nncase.Tests/Core/UnitTestDataTypes.cs @@ -61,7 +61,7 @@ public void TestGetDisplayName() { var a = new QuantParamType(); Assert.Equal(a.ToString(), DataTypes.GetDisplayName(a)); - Assert.Equal("(f32*)", DataTypes.GetDisplayName(new PointerType(DataTypes.Float32))); + Assert.Equal("(f32 *)", DataTypes.GetDisplayName(new PointerType(DataTypes.Float32))); Assert.Equal(DataTypes.Boolean.ShortName, DataTypes.GetDisplayName(DataTypes.Boolean)); Assert.Equal(DataTypes.Utf8Char.ShortName, DataTypes.GetDisplayName(DataTypes.Utf8Char)); Assert.Equal(DataTypes.Int8.ShortName, DataTypes.GetDisplayName(DataTypes.Int8)); @@ -103,7 +103,6 @@ public void TestCSharpName() public void TestBuiltInName() { Assert.Equal("QuantParam", DataTypes.GetBuiltInName(new QuantParamType())); - Assert.Throws(() => DataTypes.GetBuiltInName(new PointerType(DataTypes.Float32))); Assert.Equal("bool", DataTypes.GetBuiltInName(DataTypes.Boolean)); Assert.Equal("Utf8Char", DataTypes.GetBuiltInName(DataTypes.Utf8Char)); Assert.Equal("sbyte", DataTypes.GetBuiltInName(DataTypes.Int8)); diff --git a/src/Nncase.Tests/Core/UnitTestExpression.cs b/src/Nncase.Tests/Core/UnitTestExpression.cs index 8f29fbdab5..dc6ceeb702 100644 --- a/src/Nncase.Tests/Core/UnitTestExpression.cs +++ b/src/Nncase.Tests/Core/UnitTestExpression.cs @@ -261,8 +261,8 @@ public void TestDenseTensorLength() public void TestConstBufferNotEqual() { var c = IR.F.Random.Normal(DataTypes.Float32, 1, 0, 0, new[] { 1, 16, 64, 400 }).Evaluate().AsTensor(); - var ddr_ld_input = new TIR.BufferRegion(Nncase.TIR.T.ConstBuffer(Const.FromTensor(c), out _, "ddr_ld_input"), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(TIR.T.AttachBuffer(Const.FromTensor(c), out _, "ddr_ld_input"), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_output = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.NotEqual(ddr_ld_input.Buffer, ddr_ld_output.Buffer); Assert.NotEqual(ddr_ld_input, ddr_ld_output); } @@ -270,8 +270,8 @@ public void TestConstBufferNotEqual() [Fact] public void TestBufferEqual() { - var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var ddr_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_output = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.Equal(ddr_ld_input.Buffer, ddr_ld_output.Buffer); Assert.Equal(ddr_ld_input, ddr_ld_output); } @@ -279,8 +279,8 @@ public void TestBufferEqual() [Fact] public void TestBufferNotEqual() { - var ddr_ld_input = new TIR.BufferRegion(new TIR.PhysicalBuffer("ddr_ld_input", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); - var glb_ld_output = new TIR.BufferRegion(new TIR.PhysicalBuffer("glb_ld_output", DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var ddr_ld_input = new TIR.BufferRegion(new TIR.Buffer("ddr_ld_input", DataTypes.Float32, new MemSpan(0, 0, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); + var glb_ld_output = new TIR.BufferRegion(new TIR.Buffer("glb_ld_output", DataTypes.BFloat16, new MemSpan(0, 0, MemoryLocation.Data), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new(new TIR.Range[] { 0..1, 0..16, 0..31, 0..400 })); Assert.False(ddr_ld_input.Buffer.Equals(glb_ld_output.Buffer)); Assert.False(ddr_ld_input.Equals(glb_ld_output)); } @@ -468,7 +468,7 @@ public void TestPrintExpr() CompilerServices.InferenceType(x); Assert.Equal("const(i32[4] : {1,2,3,4})", CompilerServices.Print(x)); Assert.Equal("None", CompilerServices.Print(None.Default)); - Assert.Equal("Add", CompilerServices.Print(new Nncase.IR.Math.Binary(BinaryOp.Add))); + Assert.Equal("Binary", CompilerServices.Print(new Nncase.IR.Math.Binary(BinaryOp.Add))); var y = new Var("y"); CompilerServices.InferenceType(y); Assert.Equal("%y: any", CompilerServices.Print(y)); diff --git a/src/Nncase.Tests/Core/UnitTestMutator.cs b/src/Nncase.Tests/Core/UnitTestMutator.cs index b08f089911..83d99fbb60 100644 --- a/src/Nncase.Tests/Core/UnitTestMutator.cs +++ b/src/Nncase.Tests/Core/UnitTestMutator.cs @@ -28,8 +28,5 @@ public void TestMutator() var removeNop = Mutator.RemoveNop().Invoke(); Assert.Equal(new Passes.Mutators.RemoveNop().IsMutated, removeNop.IsMutated); - - var foldMathCall = Mutator.FoldMathCall().Invoke(); - Assert.Equal(new Passes.Mutators.FoldMathCall().IsMutated, foldMathCall.IsMutated); } } diff --git a/src/Nncase.Tests/Core/UnitTestStringUtility.cs b/src/Nncase.Tests/Core/UnitTestStringUtility.cs index 0b01ae0fd6..1efe33a6c8 100644 --- a/src/Nncase.Tests/Core/UnitTestStringUtility.cs +++ b/src/Nncase.Tests/Core/UnitTestStringUtility.cs @@ -16,22 +16,22 @@ namespace Nncase.Tests.CoreTest; public static class TestExtensions { - public static ArrayExtensions.SpanWhereEnumerable> InputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == Schedule.MemoryLocation.Input); + public static ArrayExtensions.SpanWhereEnumerable> InputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemSpan.Location == MemoryLocation.Input); - public static ArrayExtensions.SpanWhereEnumerable> OutputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemLocation == Schedule.MemoryLocation.Output); + public static ArrayExtensions.SpanWhereEnumerable> OutputOf(this ReadOnlySpan arr) => arr.AsValueEnumerable().Where(b => b.MemSpan.Location == MemoryLocation.Output); } public sealed class UnitTestStringUtility { - private readonly TIR.PrimFunction _entry = new("test_module", new Sequential(1), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0)); + private readonly TIR.PrimFunction _entry = new("test_module", new Sequential(1), new TIR.Buffer("testInput", DataTypes.Float32, new MemSpan(0, 123, MemoryLocation.Input), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 })), new TIR.Buffer("testInput", DataTypes.Float32, new MemSpan(0, 123, MemoryLocation.Output), new Expr[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new Expr[] { 1, 16, 64, 400 }))); [Fact] public void TestJoin() { var result = StringUtility.Join(",", _entry.Parameters.InputOf().Select(b => b)); - Assert.Equal("PhysicalBuffer(testInput, f32, MemLocation),PhysicalBuffer(testInput, f32, MemLocation)", result); + Assert.Equal("Nncase.TIR.Buffer", result); var result1 = StringUtility.Join(",", _entry.Parameters.OutputOf().Select(b => b)); - Assert.Equal(string.Empty, result1); + Assert.Equal("Nncase.TIR.Buffer", result1); } } diff --git a/src/Nncase.Tests/Core/UnitTestTIR.cs b/src/Nncase.Tests/Core/UnitTestTIR.cs index 54d504f58c..f0c40be178 100644 --- a/src/Nncase.Tests/Core/UnitTestTIR.cs +++ b/src/Nncase.Tests/Core/UnitTestTIR.cs @@ -21,15 +21,6 @@ namespace Nncase.Tests.CoreTest; public sealed class UnitTestTIR { - [Fact] - public void TestLogicalBuffer() - { - var logicalBuffer1 = new LogicalBuffer("logicalBuffer", default, new TensorConst(new Tensor(new[] { 1 }))); - var logicalBuffer2 = new LogicalBuffer("logicalBuffer", DataTypes.Int32, default, new[] { (Expr)new Tensor(new[] { 1 }) }); - Assert.Equal(logicalBuffer2.Length.ToString(), logicalBuffer1.Length.ToString()); - Assert.Equal("LogicalBuffer(logicalBuffer, i32, MemLocation)", logicalBuffer1.ToString()); - } - [Fact] public void TestScheduler() { @@ -47,21 +38,10 @@ public void TestScheduler() [Fact] public void TestBufferStore() { - Assert.Throws(() => T.Store(null!, null!)); - - var variable = new Var("x", DataTypes.Int32); - int index = 0; - Expr loadOp = T.Load(variable, index); Expr value = 42; - _ = T.Store(loadOp, value); - - var physicalBuffer = new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0); - var indices = new Expr[] { 0, 1 }; - Expr storeOp = T.Store(new BufferLoad(physicalBuffer, indices), value); - var store = (BufferStore)storeOp; - Assert.Equal(physicalBuffer, store.Buffer); - Assert.Equal(value, store.Value); - Assert.Equal(new Expr[] { 0 }, store.Indices.ToArray()); + TIR.T.CreateBuffer(new TensorType(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var testInput); + _ = new Expr[] { 0, 1 }; + _ = T.Store(testInput, 0, value); } [Fact] @@ -106,15 +86,6 @@ public void TestSequential() Assert.Equal(expect2, actual2); } - [Fact] - public void TestBuffer() - { - var buffer = T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 16, 64, 400 }, out _); - Assert.Equal(DataTypes.Float32, buffer.ElemType); - var expect = new LogicalBuffer("_", DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 16, 64, 400 }); - Assert.Equal(expect, buffer); - } - [Fact] public void TestForSegment() { @@ -143,7 +114,7 @@ public void TestEmit() [Fact] public void TestBufferRegion() { - var buffer = T.Buffer(DataTypes.Float32, MemoryLocation.Input, new Expr[] { 1, 16, 64, 400 }, out _); + var buffer = T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out _); var region = new Range[] { new Range(1, 2, 2), new Range(-1, 3, 2) }; var bufferRegion = new BufferRegion(buffer, region); @@ -165,8 +136,8 @@ public void TestPrimFunction() { var primFunc = new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), new[] { - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), }); var primFuncParameters = primFunc.Parameters; @@ -178,8 +149,8 @@ public void TestPrimFunction() var newBody = new Sequential(new Expr[] { 3 }); var newParams = new[] { - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), - new TIR.PhysicalBuffer("testInput", DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 16, 64, 400 }, TensorUtilities.GetStrides(new[] { 1, 16, 64, 400 }), 0, 0), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), + TIR.T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 16, 64, 400 }), MemoryLocation.Input, out var _), }; var newPrimFunc = primFunc.With(moduleKind: newModuleKind, body: newBody, parameters: newParams); @@ -190,7 +161,7 @@ public void TestPrimFunction() Assert.Equal(newParams, newPrimFunc.Parameters.ToArray()); Assert.Equal(primFunc.Name, newPrimFunc.Name); // should not change the name - Assert.NotNull(new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), default(ReadOnlySpan))); + Assert.NotNull(new PrimFunction("test_module", new Sequential(new Expr[] { 1 }), default(ReadOnlySpan))); } [Fact] diff --git a/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs b/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs index d07f20bd78..7e8a70c2bb 100644 --- a/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs +++ b/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs @@ -54,48 +54,68 @@ public sealed class UnitTestTensorUtilities public void TestIsContiguousSlice() { var dim1 = new[] { 1, 512, 14, 14 }; - + int start; Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..14, 0..14 })); + new[] { 0..1, 0..512, 0..14, 0..14 }, + out start)); + Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..1, 0..1, 0..14 })); + new[] { 0..1, 0..1, 0..1, 0..14 }, + out start)); + Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..1, 0..1, 7..14 })); + new[] { 0..1, 0..1, 0..1, 7..14 }, + out start)); + Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..1, 7..14, 0..14 })); + new[] { 0..1, 0..1, 7..14, 0..14 }, + out start)); + Assert.Equal(0, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..7, 0..14 })); + new[] { 0..1, 0..512, 0..7, 0..14 }, + out start)); + Assert.Equal(2, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..7, 0..14, 0..1 })); + new[] { 0..1, 0..512, 0..7, 0..14, 0..1 }, + out start)); + Assert.Equal(4, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 10..512, 0..1, 0..1 })); + new[] { 0..1, 10..512, 0..1, 0..1 }, + out start)); + Assert.Equal(2, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..7, 0..1 })); + new[] { 0..1, 0..512, 0..7, 0..1 }, + out start)); + Assert.Equal(3, start); var dim2 = new[] { 1, 512, 1, 196 }; Assert.True(TensorUtilities.IsContiguousSlice( dim2, - new[] { 0..1, 0..128, 0..1, 0..196 })); + new[] { 0..1, 0..128, 0..1, 0..196 }, + out start)); + Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim2, - new[] { 0..1, 0..1, 0..1, 10..15 })); + new[] { 0..1, 0..1, 0..1, 10..15 }, + out start)); + Assert.Equal(0, start); } // long GetProduct(ReadOnlySpan dimensions, int startIndex = 0) diff --git a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs index 31c6d763ae..dcd94f1303 100644 --- a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs +++ b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs @@ -65,7 +65,7 @@ public void TestDumpFusion() [Fact] public void TestDumpScript() { - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out _)).Body(T.Nop()).Build(); Assert.True(CompilerServices.InferenceType(prim_func_1)); @@ -194,6 +194,17 @@ public void TestDumperCSharpIRFunction() CompilerServices.DumpCSharpIR(main, string.Empty, Dumpper.Directory); } + [Fact] + public void TestDumperPatternIRFunction() + { + var x = IR.F.Math.Quantize(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 0, new[] { 1, 2, 2, 2 }), new QuantParam(1, 2.0f), DataTypes.UInt8); + var y = new Var("y", new TensorType(DataTypes.UInt8, new int[] { 1, 2, 2, 2 })); + var z = IR.F.Random.Normal(DataTypes.UInt8, 0, 1, 0, new[] { 1, 2, 2, 2 }); + var m = IR.F.Random.Normal(DataTypes.UInt8, 0, 1, 0, new[] { 1, 20, 2, 2 }); + var main = new Function("main", IR.F.Tensors.Concat(new IR.Tuple(new Expr[] { x, y, z, m }), 1), new[] { y }); + CompilerServices.DumpPatternIR(main, string.Empty, Dumpper.Directory); + } + [Fact] public void TestDumperCSharpIRFusion() { @@ -214,9 +225,9 @@ public void TestDumperCSharpIRFusion() [Fact] public void TestDumpTIRFusion() { - var lhs = new Var("lhs"); - var main = T.PrimFunc("main", Callable.StackVMModuleKind).Body( - new Call(new TIRTest.MeshNet(), new Fusion("MeshFunc", lhs + 100, lhs), IR.F.Random.Normal(DataTypes.Float32, 0, 1, 123, new[] { 100 }))).Build(); + var lhs = new Var("lhs", TensorType.Scalar(DataTypes.Float32)); + var main = T.PrimFunc("main", DefaultTargetName).Body( + new Call(new TIRTest.MeshNet(), new Fusion("MeshFunc", lhs + 100.0f, lhs), IR.F.Random.Normal(DataTypes.Float32, 0, 1, 123, new[] { 100 }))).Build(); Assert.True(CompilerServices.InferenceType(main)); CompilerServices.DumpIR(main, string.Empty, Dumpper.Directory); } diff --git a/src/Nncase.Tests/EGraph/UnitTestVrp.cs b/src/Nncase.Tests/EGraph/UnitTestVrp.cs index 5a9f4b564e..f421022a7c 100644 --- a/src/Nncase.Tests/EGraph/UnitTestVrp.cs +++ b/src/Nncase.Tests/EGraph/UnitTestVrp.cs @@ -174,6 +174,35 @@ public void TestSimpleEgraphSat() } } + [Fact] + public void TestOverLap() + { + // note ortools no overlap not support 0 size. + var model = new CpModel(); + + var x0 = model.NewIntervalVar(model.NewConstant(0), model.NewConstant(2), model.NewConstant(2), "x0"); + var y0 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 10, "y0_start"), 7, "y0"); + + var x1 = model.NewIntervalVar(model.NewConstant(2), model.NewConstant(0), model.NewConstant(2), "x1"); + var y1 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 10, "y1_start"), 7, "y1"); + + var x2 = model.NewIntervalVar(model.NewConstant(2), model.NewConstant(1), model.NewConstant(3), "x2"); + var y2 = model.NewFixedSizeIntervalVar(model.NewIntVar(0, 10, "y2_start"), 7, "y2"); + + model.Add(y0.StartExpr() == y1.StartExpr()); + model.Add(y1.StartExpr() == y2.StartExpr()); + var nooverlap = model.AddNoOverlap2D(); + nooverlap.AddRectangle(x0, y0); + nooverlap.AddRectangle(x1, y1); + nooverlap.AddRectangle(x2, y2); + model.Minimize(y0.StartExpr() + y1.StartExpr() + y2.StartExpr()); + + var solver = new CpSolver(); + var status = solver.Solve(model); + + Assert.Equal(CpSolverStatus.Infeasible, status); + } + private static void PrintSolution(in IDataModel data, in RoutingModel routing, in RoutingIndexManager manager, in Assignment solution) { Console.WriteLine($"Objective {solution.ObjectiveValue()}:"); diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs index 3f230b49c5..037e40793b 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs @@ -97,10 +97,10 @@ public void TestOnnxResizeImage() public void TestLoadStore() { var loop_i = new Var(TensorType.Scalar(DataTypes.Int32)); - var load = T.Load(T.Handle("hd", DataTypes.Float32), loop_i); + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3 }), MemoryLocation.Input, out var bf); + var load = T.Load(bf, loop_i); CompilerServices.InferenceType(load); - - var store = T.Store((Var)load[TIR.Load.Handle], load[TIR.Load.Index], loop_i); + var store = T.Store(bf, loop_i, IR.F.Tensors.Cast(loop_i, DataTypes.Float32)); CompilerServices.InferenceType(store); } diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs index 04c7a43c44..bed1177d0d 100644 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorBuffers.cs @@ -29,7 +29,7 @@ public class UnitTestEvaluatorBuffers : TestClassBase public void TestUninitialized() { var shape = new[] { 1 }; - var expr = IR.F.Buffer.Uninitialized(DataTypes.Float32, MemoryLocation.Input, shape); + var expr = IR.F.Buffer.Uninitialized(DataTypes.Float32, TIR.MemoryLocation.Input, shape); CompilerServices.InferenceType(expr); Assert.Equal(Value.None, expr.Evaluate()); } diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs index 63cb7507a4..4d56cde01b 100644 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs @@ -112,7 +112,7 @@ public void TestConcat3() for (long i = 0; i < shape.Length; i++) { var expect = OrtKI.Concat(new OrtKISharp.Tensor[] { inputA, inputB }, i); - var expr = IR.F.Tensors.Concat(new Tuple(inputA.ToTensor(), inputB.ToTensor()), i); + var expr = IR.F.Tensors.Concat(new Tuple(inputA.ToTensor(), inputB.ToTensor()), (int)i); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); } @@ -624,7 +624,7 @@ public void TestGather() long batchDims = 0L; var expect = OrtKI.Gather(input.ToOrtTensor(), indices.ToOrtTensor(), batchDims); - var expr = IR.F.Tensors.Gather(input, batchDims, indices); + var expr = IR.F.Tensors.Gather(input, (int)batchDims, indices); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); } diff --git a/src/Nncase.Tests/Match/UnitTestEGraphMatch.cs b/src/Nncase.Tests/Match/UnitTestEGraphMatch.cs index af32681a6d..f9e1691387 100644 --- a/src/Nncase.Tests/Match/UnitTestEGraphMatch.cs +++ b/src/Nncase.Tests/Match/UnitTestEGraphMatch.cs @@ -113,7 +113,7 @@ public void TestMatchVArgs() Expr expr = Concat(tuple, 0); CompilerServices.InferenceType(expr); - var vpat = IsConcat(IsTuple("tp"), IsConst(0)); + var vpat = IsConcat(0, IsTuple("tp")); Assert.True(CompilerServices.TryEMatchRoot(expr, vpat, out var eMatches)); Assert.Single(eMatches); @@ -122,13 +122,11 @@ public void TestMatchVArgs() [Fact] public void TestMatchVArgsTwice() { - ConstPattern wcaxis = IsConst(); - var tuple_lhs = new IR.Tuple(1, new Var(), 3); var tuple_rhs = new IR.Tuple(4, 5, 6); Expr expr = Concat(tuple_lhs, 0) + Concat(tuple_rhs, 1); - var vpat = IsConcat(IsTuple("tp"), wcaxis); + var vpat = IsConcat(_ => true, IsTuple("tp")); Assert.True(CompilerServices.TryEMatchRoot(expr, vpat, out var eMatches)); Assert.Equal(2, eMatches.Count); @@ -151,9 +149,8 @@ public void TestMatchVArgsRecursion() var wc = IsWildcard("wc"); var wcperm = IsWildcard("perm"); - var wcaxis = IsWildcard("axis"); - var pattern = IsConcat(IsTuple(IsVArgsRepeat("wcvargs", () => IsTranspose(IsWildcard(), wcperm))), wcaxis); + var pattern = IsConcat(_ => true, IsTuple(IsVArgsRepeat("wcvargs", () => IsTranspose(IsWildcard(), wcperm)))); Assert.True(CompilerServices.TryEMatchRoot(expr, pattern, out var results)); Assert.Single(results); @@ -163,7 +160,6 @@ public void TestMatchVArgsRecursion() Assert.Equal(((Call)wcvargs[1]).Arguments[0], y); Assert.Equal(((Call)wcvargs[2]).Arguments[0], z); Assert.Equal(result[wcperm], perm); - Assert.Equal(result[wcaxis], (Const)0); } [Fact] diff --git a/src/Nncase.Tests/Properties/launchSettings.json b/src/Nncase.Tests/Properties/launchSettings.json index d081379b07..3109588c45 100644 --- a/src/Nncase.Tests/Properties/launchSettings.json +++ b/src/Nncase.Tests/Properties/launchSettings.json @@ -2,7 +2,7 @@ "profiles": { "Nncase.Tests": { "commandName": "Project", - "nativeDebugging": true + "nativeDebugging": false } } } \ No newline at end of file diff --git a/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs b/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs index 6d2f29ce71..65e694e129 100644 --- a/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs +++ b/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs @@ -33,89 +33,37 @@ public async Task TestPytestCalibrationDatasetProvider1() { var vars = Setup(); var dataset = "./public/test1"; - foreach (var t in vars) + var actuals = DumpTensors(dataset, vars, 2); + var provider = new PytestCalibrationDatasetProvider(vars, dataset); + Assert.Equal(2, provider.Count); + var samples = provider.Samples; + var count = 0; + await foreach (var sample in samples) { - var actual = IR.F.Random.Uniform(t.CheckedDataType, 1.0f, -1.0f, 0, t.CheckedShape).Evaluate().AsTensor(); - DumpTensors(new[] { actual }, dataset, 2); - var provider = new PytestCalibrationDatasetProvider(new[] { t }, dataset); - Assert.Equal(2, provider.Count); - var samples = provider.Samples; - await foreach (var sample in samples) - { - Assert.Equal(sample[t].AsTensor(), actual); - } - } - } - - [Fact] - public async Task TestPytestCalibrationDatasetProvider2() - { - var vars = Setup(); - var dataset = "./public/test2"; - foreach (var t in vars) - { - var actual = IR.F.Random.Uniform(t.CheckedDataType, 1.0f, -1.0f, 0, t.CheckedShape).Evaluate().AsTensor(); - DumpTensors(new[] { actual }, dataset); - var provider = new PytestCalibrationDatasetProvider(new[] { t }, dataset); - Assert.Equal(1, provider.Count); - var samples = provider.Samples; - await foreach (var sample in samples) - { - Assert.Equal(sample[t].AsTensor(), actual); - } - } - } - - [Fact] - public async Task TestPytestCalibrationDatasetProvider3() - { - var vars1 = Setup(); - var dataset = "./public/test3"; - var actual1 = IR.F.Random.Uniform(vars1[0].CheckedDataType, 1.0f, -1.0f, 0, vars1[0].CheckedShape).Evaluate().AsTensor(); - var actual2 = IR.F.Random.Uniform(vars1[1].CheckedDataType, 1.0f, -1.0f, 0, vars1[1].CheckedShape).Evaluate().AsTensor(); - DumpTensors(new[] { actual1, actual2 }, dataset); - var provider1 = new PytestCalibrationDatasetProvider(vars1, dataset); - Assert.Equal(1, provider1.Count); - var samples1 = provider1.Samples; - await foreach (var sample in samples1) - { - Assert.Equal(sample[vars1[0]].AsTensor(), actual1); - Assert.Equal(sample[vars1[1]].AsTensor(), actual2); - } - } - - [Fact] - public async Task TestPytestCalibrationDatasetProvider4() - { - var vars1 = Setup(); - var dataset = "./public/test4"; - var actual1 = IR.F.Random.Uniform(vars1[0].CheckedDataType, 1.0f, -1.0f, 0, vars1[0].CheckedShape).Evaluate().AsTensor(); - var actual2 = IR.F.Random.Uniform(vars1[1].CheckedDataType, 1.0f, -1.0f, 0, vars1[1].CheckedShape).Evaluate().AsTensor(); - DumpTensors(new[] { actual1, actual2 }, dataset, 2); - var provider1 = new PytestCalibrationDatasetProvider(vars1, dataset); - Assert.Equal(2, provider1.Count); - var samples1 = provider1.Samples; - await foreach (var sample in samples1) - { - Assert.Equal(sample[vars1[0]].AsTensor(), actual1); - Assert.Equal(sample[vars1[1]].AsTensor(), actual2); + Assert.Equal(sample[vars[0]].AsTensor(), actuals[count, 0]); + Assert.Equal(sample[vars[1]].AsTensor(), actuals[count, 1]); + count++; } } - private static void DumpTensors(Tensor[] tensorValue, string dir, int sample = 1) + private static Tensor[,] DumpTensors(string dir, Var[] inputs, int sample) { Directory.CreateDirectory(dir); + var outputs = new Tensor[sample, inputs.Length]; for (var s = 0; s < sample; s++) { - for (var t = 0; t < tensorValue.Length; t++) + for (var t = 0; t < inputs.Length; t++) { - var value = tensorValue[t]; + var value = IR.F.Random.Uniform(inputs[t].CheckedDataType, 1.0f, -1.0f, s + t, inputs[t].CheckedShape).Evaluate().AsTensor(); var sr1 = new StreamWriter(Path.Join(dir, $"input_{t}_{s}.txt")); DumpTxt(value, sr1); var sr2 = Path.Join(dir, $"input_{t}_{s}.bin"); DumpBin(value, sr2); + outputs[s, t] = value; } } + + return outputs; } private static void DumpTxt(Tensor tensorValue, StreamWriter writer) diff --git a/src/Nncase.Tests/Rewrite/Fusion/UnitTestFusionMaker.cs b/src/Nncase.Tests/Rewrite/Fusion/UnitTestFusionMaker.cs index 70259ef6eb..c3e4bac4c7 100644 --- a/src/Nncase.Tests/Rewrite/Fusion/UnitTestFusionMaker.cs +++ b/src/Nncase.Tests/Rewrite/Fusion/UnitTestFusionMaker.cs @@ -9,7 +9,7 @@ using NetFabric.Hyperlinq; using Nncase.IR; using Nncase.IR.Math; -using Nncase.IR.Tensors; +using Nncase.IR.RNN; using Nncase.Passes; using Nncase.Passes.Analysis; using Nncase.Passes.Mutators; @@ -20,9 +20,8 @@ using Xunit; using Xunit.Abstractions; using static Nncase.IR.F.Math; +using static Nncase.IR.F.RNN; using static Nncase.IR.F.Tensors; -using static Nncase.IR.TypePatternUtility; -using static Nncase.PatternMatch.F.Math; using static Nncase.PatternMatch.Utility; using Transpose = Nncase.IR.Tensors.Transpose; using Tuple = Nncase.IR.Tuple; @@ -330,9 +329,9 @@ IR.Tuple WrapOutput(Call call) var newVar2 = newVars[2]; var pairs = new[] { - (LSTM.X, (Expr)WrapInput(newVar0)), - (LSTM.InitialC, WrapInput(newVar1)), - (LSTM.InitialH, WrapInput(newVar2)), + (IR.RNN.LSTM.X, (Expr)WrapInput(newVar0)), + (IR.RNN.LSTM.InitialC, WrapInput(newVar1)), + (IR.RNN.LSTM.InitialH, WrapInput(newVar2)), }; var expectLSTM = ReplaceUtility.ReplaceCallParams(lstm.Target, lstm.Arguments.ToArray(), pairs); var expectBody = WrapOutput(expectLSTM); @@ -363,7 +362,7 @@ internal sealed class TestTransposeComplexFusion : ComplexFusion { public override (ParameterInfo, CallPattern)[] InputPatterns { get; } = - GenerateInputPatterns(LSTM.X, LSTM.InitialC, LSTM.InitialH); + GenerateInputPatterns(IR.RNN.LSTM.X, IR.RNN.LSTM.InitialC, IR.RNN.LSTM.InitialH); } } diff --git a/src/Nncase.Tests/Rewrite/RewriteBase.cs b/src/Nncase.Tests/Rewrite/RewriteBase.cs index 33dc381532..c68a1eba45 100644 --- a/src/Nncase.Tests/Rewrite/RewriteBase.cs +++ b/src/Nncase.Tests/Rewrite/RewriteBase.cs @@ -2072,7 +2072,7 @@ public Function PreExpr var input = new Tensor(new[] { 0, 1, 2, 3 }, shape); var indices = new Tensor(new[] { 0L, 0L, 1L, 1L }, shape); long batchDims = 0L; - var expr = IR.F.Tensors.Gather(input, batchDims, indices); + var expr = IR.F.Tensors.Gather(input, (int)batchDims, indices); return new Function(expr, new Var[] { _input }); } } @@ -2906,3 +2906,53 @@ public FoldReshapeWithBranch() public Dictionary FeedDict { get; } } + +public sealed class ReshapeTransposeReshapeCase : IRewriteCase +{ + public ReshapeTransposeReshapeCase() + { + var input = new Var("input", new TensorType(DataTypes.Float32, new[] { 1, 77, 768 })); + { + var v0 = Reshape(input, new[] { 1, 77, 12, 64 }); + var v2 = Transpose(v0, new[] { 0, 2, 1, 3 }); + var v3 = Reshape(v2, new[] { 12, 77, 64 }); + PreExpr = new Function(v3, new[] { input }); + } + + FeedDict = new() { { input, IR.F.Random.Normal(new[] { 1, 77, 768 }).Evaluate() } }; + } + + public Function PreExpr { get; } + + public IEnumerable Rules => new[] { + typeof(CombineReshapeTranspose), + typeof(FoldTwoReshapes), + }; + + public Dictionary FeedDict { get; } +} + +public sealed class ReshapeBinaryConstReshapeCase : IRewriteCase +{ + public ReshapeBinaryConstReshapeCase() + { + var v9 = new Var("v9", new TensorType(DataTypes.Float32, new[] { 12, 77, 77 })); + { + var v10 = Reshape(v9, new[] { 1, 12, 77, 77 }); // f32[1,12,77,77] + var v11 = IR.F.Math.Add(v10, IR.F.Random.Normal(new[] { 1, 1, 77, 77 }).Evaluate().AsTensor()); // f32[1,12,77,77] + var v12 = Reshape(v11, new[] { 12, 77, 77 }); // f32[12,77,77] + + PreExpr = new Function(v12, new[] { v9 }); + } + + FeedDict = new() { { v9, IR.F.Random.Normal(new[] { 12, 77, 77 }).Evaluate() } }; + } + + public Function PreExpr { get; } + + public IEnumerable Rules => new[] { + typeof(FoldReshapeBinaryConstReshape), + }; + + public Dictionary FeedDict { get; } +} diff --git a/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewriteFactory.cs b/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewriteFactory.cs index f9080652a7..06dcc60c9b 100644 --- a/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewriteFactory.cs +++ b/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewriteFactory.cs @@ -15,7 +15,7 @@ public class UnitTestDataFlowRewriteFactory : TestClassBase { public static TheoryData DataOne => new() { - new CombineClampAddMul(), + new ReshapeBinaryConstReshapeCase(), }; public static TheoryData DataAll => new() @@ -31,6 +31,7 @@ public class UnitTestDataFlowRewriteFactory : TestClassBase new Conv2DPadsCase(), new ReduceWindow2DPadsCase(), new MobileNetV1TransposeCase(), + new CombineClampAddMul(), }; [Theory] diff --git a/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs b/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs index 9a1d8faee7..6ec5e4721b 100644 --- a/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs +++ b/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs @@ -28,7 +28,7 @@ public UnitTestEGraphRewriteFactory() public static TheoryData DataOne => new() { - new PReluTransposeCase(), + new ReshapeTransposeReshapeCase(), }; public static TheoryData DataAll => new() diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestAddMarker.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestAddMarker.cs index 74f6188c93..c705e5dac2 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestAddMarker.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestAddMarker.cs @@ -96,7 +96,7 @@ public async Task TestAddMarkerWithLstm() var module = new IRModule(main); await TestAddMarkerPasses(module); Assert.True(((Function)module.Entry!).Body is Tuple t - && CompilerServices.TryMatchRoot(t, IsWrappedLSTM(PatternMatch.F.Tensors.IsLSTM("lstm", "lstmCall", _ => true), (x, _) => IsRangeOfMarker(x, IsWildcard())), out var result) + && CompilerServices.TryMatchRoot(t, IsWrappedLSTM(PatternMatch.F.RNN.IsLSTM("lstm", "lstmCall", _ => true), (x, _) => IsRangeOfMarker(x, IsWildcard())), out var result) && result["lstmCall"] is Call call && new[] { 0, 1, 2, 5, 6 }.All(i => call.Arguments[i] is Marker)); } @@ -126,7 +126,7 @@ public async Task TestAddMarkerWithLstmInitHEqualsInitC() var module = new IRModule(main); await TestAddMarkerPasses(module); Assert.True(((Function)module.Entry!).Body is Tuple t - && CompilerServices.TryMatchRoot(t, IsWrappedLSTM(PatternMatch.F.Tensors.IsLSTM("lstm", "lstmCall", _ => true), (x, _) => IsRangeOfMarker(x, IsWildcard())), out var result) + && CompilerServices.TryMatchRoot(t, IsWrappedLSTM(PatternMatch.F.RNN.IsLSTM("lstm", "lstmCall", _ => true), (x, _) => IsRangeOfMarker(x, IsWildcard())), out var result) && result["lstmCall"] is Call call && new[] { 0, 1, 2, 5, 6 }.All(i => call.Arguments[i] is Marker)); } diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs index 9295d4e159..918b685d23 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs @@ -35,6 +35,13 @@ public class UnitTestCombineReshape : TransformTestBase { BinaryOp.Sub, new[] { 1 }, new[] { 1, 32, 32, 64, }, new[] { 1, 1024, 64, 1 }, true }, }; + public static readonly TheoryData TestCombineReshapeTransposeNegativeData = + new() + { + { new[] { 1, 77, 1, 64 }, new[] { 2, 1, 3, 0 }, new[] { 77, 64, 1 } }, + { new[] { 1, 77, 12, 64 }, new[] { 1, 0, 2, 3 }, new[] { 1, 77, 768 } }, + }; + public static IEnumerable CombineBinaryReshapePositiveData => new[] { @@ -197,4 +204,48 @@ public void TestCombineReshapePadNegative(int[] inShape, int[] shape, int[] pads var rootPre = Tensors.Reshape(NN.Pad(a, Tensor.From(pads, new[] { pads.Length / 2, 2 }), PadMode.Constant, 0f), shape); TestNotMatch(rootPre); } + + [Theory] + [ClassData(typeof(CombineReshapeTransposePostiveData))] + public void TestCombineReshapeTransposePostive(int[] inShape, int[] perm, int[] newshape) + { + var input = new Var("input", new TensorType(DataTypes.Float32, inShape)); + var feed_dict = new Dictionary + { + { input, Random.Normal(DataTypes.Float32, 0, 1, 0, inShape).Evaluate() }, + }; + var rootPre = Tensors.Reshape(Tensors.Transpose(input, perm), newshape); + TestMatched(rootPre, feed_dict); + } + + [Theory] + [MemberData(nameof(TestCombineReshapeTransposeNegativeData))] + public void TestCombineReshapeTransposeNegative(int[] inShape, int[] perm, int[] newshape) + { + var input = new Var("input", new TensorType(DataTypes.Float32, inShape)); + var rootPre = Tensors.Reshape(Tensors.Transpose(input, perm), newshape); + TestNotMatch(rootPre); + } + + private sealed class CombineReshapeTransposePostiveData : TheoryData + { + public CombineReshapeTransposePostiveData() + { + var inshapes = new[] { + new[] { 1, 77, 12, 64 }, + new[] { 77, 1, 12, 64 }, + new[] { 77, 12, 1, 64 }, + new[] { 77, 12, 64, 1 }, + }; + + var perms = new[] { 0, 1, 2, 3 }.Permutate().ToArray(); + + foreach (var (inshape, perm) in new[] { inshapes, perms }.CartesianProduct().Select(i => i.ToArray()).Select(i => (i[0], i[1]))) + { + var newshape = perm.Select(i => inshape[i]).ToList(); + newshape.Remove(1); + Add(inshape, perm, newshape.ToArray()); + } + } + } } diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs index 440a53e551..dcde48a8b7 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs @@ -39,14 +39,14 @@ public class UnitTestCombineTranspose : TransformTestBase }; public static IEnumerable CombineBinaryTransposePositiveData => - new[] - { + new[] + { new object[] { new[] { 5, 4 }, new[] { 5, 4 }, new[] { 1, 0 } }, new object[] { new[] { 4, 4 }, new[] { 4, 4 }, new[] { 1, 0 } }, new object[] { new[] { 4 }, new[] { 4 }, new[] { 0 } }, new object[] { new[] { 1, 3, 4 }, new[] { 1, 3, 4 }, new[] { 0, 2, 1 } }, new object[] { new[] { 1, 3, 2, 4 }, new[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 } }, - }; + }; public static IEnumerable CombineConstBinaryTransposeNotMatchData => new[] @@ -359,4 +359,39 @@ public void TestCombineTransposeUnaryPositive(UnaryOp opType, int[] inShape, int var rootPre = IR.F.Math.Unary(opType, Tensors.Transpose(a, perm)); TestMatched(rootPre, normal); } + + [Theory] + [ClassData(typeof(CombineTransposeReshapePostiveData))] + public void TestCombineTransposeReshapePostive(int[] inShape, int[] newShape, int[] perm) + { + var a = new Var(new TensorType(DataTypes.Float32, inShape)); + var feed_dict = new Dictionary + { + { a, Random.Normal(DataTypes.Float32, 0, 1, 0, inShape).Evaluate() }, + }; + var rootPre = Tensors.Transpose(Tensors.Reshape(a, newShape), perm); + TestMatched(rootPre, feed_dict); + } + + private sealed class CombineTransposeReshapePostiveData : TheoryData + { + public CombineTransposeReshapePostiveData() + { + var inshapes = new[] { new[] { 12, 77, 64 } }; + + var newShapes = new[] { + new[] { 1, 12, 77, 64 }, + new[] { 12, 1, 77, 64 }, + new[] { 12, 77, 1, 64 }, + new[] { 12, 77, 64, 1 }, + }; + + var perms = new[] { 0, 1, 2, 3 }.Permutate().ToArray(); + + foreach (var (a, b, c) in new[] { inshapes, newShapes, perms }.CartesianProduct().Select(i => i.ToArray()).Select(i => (i[0], i[1], i[2]))) + { + Add(a, b, c); + } + } + } } diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs index b4af452d8e..ac8bbef641 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs @@ -21,6 +21,12 @@ namespace Nncase.Tests.Rules.NeutralTest; [AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldReshape : TransformTestBase { + public static TheoryData TestReshapeBinaryConstReshapePositiveData => new() + { + { new[] { 12, 77, 77 }, new[] { 1, 12, 77, 77 }, new[] { 1, 1, 77, 77 }, new[] { 12, 77, 77 } }, + { new[] { 12, 77, 77 }, new[] { 1, 12, 77, 77 }, new[] { 77 }, new[] { 12, 77, 77 } }, + }; + public static IEnumerable TestFoldNopReshapePositiveData => new[] { @@ -101,4 +107,16 @@ public void TestReshapeToTransposeNegative(int[] shape, int[] newShape) var rootPre = Tensors.Reshape(a, newShape); TestNotMatch(rootPre); } + + [Theory] + [MemberData(nameof(TestReshapeBinaryConstReshapePositiveData))] + public void TestReshapeBinaryConstReshapePositive(int[] inShape, int[] unsqShape, int[] constShape, int[] sqShape) + { + var a = Random.Normal(DataTypes.Float32, 0, 1, 0, inShape); + var v0 = Tensors.Reshape(a, unsqShape); + var v1 = Math.Binary(BinaryOp.Add, v0, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 0, constShape).Evaluate().AsTensor()); + var v2 = Tensors.Reshape(v1, sqShape); + var rootPre = v2; + TestMatched(rootPre); + } } diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSwish.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSwish.cs index 897d1b04cb..95c2c010a4 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSwish.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSwish.cs @@ -74,7 +74,8 @@ public void TestFoldSwishPattern2Positive2(int[] shape) Expr rootPre; { var v0 = input; - var v1 = IR.F.NN.Sigmoid(v0); + var v0_2 = IR.F.Math.Binary(BinaryOp.Mul, v0, 2.0f); + var v1 = IR.F.NN.Sigmoid(v0_2); var v2 = IR.F.Math.Binary(BinaryOp.Mul, v0, v1); rootPre = v2; } diff --git a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs index a65c520bf1..1e0ae120fd 100644 --- a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs +++ b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs @@ -368,7 +368,7 @@ public void TestMatMulReshape() var lhs = MakeVar(input); var add = Add(lhs, new[] { 1f }); var rhs = Reshape(add, Concat( - new IR.Tuple(Reshape(Gather(ShapeOf(add), 0L, 0L), new[] { 1L }), new[] { 3L }, new[] { 24L }, new[] { 24L }), 0)); + new IR.Tuple(Reshape(Gather(ShapeOf(add), 0, 0L), new[] { 1L }), new[] { 3L }, new[] { 24L }, new[] { 24L }), 0)); var lhsVar = new Var("lhs", new TensorType(input.ElementType, input.Shape)); var rhsVar = new Var("rhs", new TensorType(input.ElementType, input.Shape)); diff --git a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs index 9b791049d0..97dbd1915f 100644 --- a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs +++ b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs @@ -33,11 +33,11 @@ internal static class PrimFuncBuilder public static PrimFunctionWrapper MakeLoadStoreFunc(bool mask) { var allocator = new Allocator(); - var fusion_input = allocator.Allocate($"fusion_{_count}_input", Schedule.MemoryLocation.Input); + var fusion_input = allocator.Allocate($"fusion_{_count}_input", TIR.MemoryLocation.Input); - var glb = allocator.Allocate($"fusion_{_count}_glb", Schedule.MemoryLocation.L2Data); + var glb = allocator.Allocate($"fusion_{_count}_glb", TIR.MemoryLocation.L2Data); - var fusion_output = allocator.Allocate($"fusion_{_count}_output", Schedule.MemoryLocation.Output); + var fusion_output = allocator.Allocate($"fusion_{_count}_output", TIR.MemoryLocation.Output); var fusion_1 = TIR.T.PrimFunc($"fusion_{_count}_{mask}", Callable.StackVMModuleKind, fusion_input, fusion_output).Body( new Call(new TIRTest.LoadT(), fusion_input, glb), @@ -50,12 +50,12 @@ public static PrimFunctionWrapper MakeLoadStoreFunc(bool mask) public static PrimFunctionWrapper MakeBinaryFunc(BinaryOp binaryOp, bool mask) { var allocator = new Allocator(); - var fusion_input_lhs = allocator.Allocate($"fusion_{_count}_input_lhs", Schedule.MemoryLocation.Input); - var fusion_input_rhs = allocator.Allocate($"fusion_{_count}_input_rhs", Schedule.MemoryLocation.Input); - var glb_lhs = allocator.Allocate($"fusion_{_count}_glb_lhs", Schedule.MemoryLocation.L2Data); - var glb_rhs = allocator.Allocate($"fusion_{_count}_glb_rhs", Schedule.MemoryLocation.L2Data); - var glb_output = allocator.Allocate($"fusion_{_count}_glb_output", Schedule.MemoryLocation.L2Data); - var fusion_output = allocator.Allocate($"fusion_{_count}_output", Schedule.MemoryLocation.Output); + var fusion_input_lhs = allocator.Allocate($"fusion_{_count}_input_lhs", TIR.MemoryLocation.Input); + var fusion_input_rhs = allocator.Allocate($"fusion_{_count}_input_rhs", TIR.MemoryLocation.Input); + var glb_lhs = allocator.Allocate($"fusion_{_count}_glb_lhs", TIR.MemoryLocation.L2Data); + var glb_rhs = allocator.Allocate($"fusion_{_count}_glb_rhs", TIR.MemoryLocation.L2Data); + var glb_output = allocator.Allocate($"fusion_{_count}_glb_output", TIR.MemoryLocation.L2Data); + var fusion_output = allocator.Allocate($"fusion_{_count}_output", TIR.MemoryLocation.Output); var fusion = TIR.T.PrimFunc($"fusion_{_count}_{mask}", Callable.StackVMModuleKind, fusion_input_lhs, fusion_input_rhs, fusion_output).Body( new Call(new TIRTest.LoadT(), fusion_input_lhs, glb_lhs), @@ -71,16 +71,16 @@ public static PrimFunctionWrapper MakeBinaryFunc(BinaryOp binaryOp, bool mask) public static PrimFunctionWrapper MakeMultiInputFunc(int length, bool mask) { var allocator = new Allocator(); - var fusion_inputs = new List(); + var fusion_inputs = new List(); for (int i = 0; i < length; i++) { - var fusion_input_i = allocator.Allocate($"fusion_{_count}_input_{i}", Schedule.MemoryLocation.Input); + var fusion_input_i = allocator.Allocate($"fusion_{_count}_input_{i}", TIR.MemoryLocation.Input); fusion_inputs.Add(fusion_input_i); } - var glb1 = allocator.Allocate($"fusion_{_count}_glb1", Schedule.MemoryLocation.L2Data); - var glb2 = allocator.Allocate($"fusion_{_count}_glb2", Schedule.MemoryLocation.L2Data); - var fusion_output = allocator.Allocate($"fusion_{_count}_output", Schedule.MemoryLocation.Output); + var glb1 = allocator.Allocate($"fusion_{_count}_glb1", TIR.MemoryLocation.L2Data); + var glb2 = allocator.Allocate($"fusion_{_count}_glb2", TIR.MemoryLocation.L2Data); + var fusion_output = allocator.Allocate($"fusion_{_count}_output", TIR.MemoryLocation.Output); var fusion = TIR.T.PrimFunc($"multi_fusion_{_count}_{mask}", Callable.StackVMModuleKind, fusion_inputs.Concat(new[] { fusion_output }).ToArray()); @@ -124,18 +124,20 @@ private static IEnumerable GetBinaryOp(int length) private sealed class Allocator { - private readonly Dictionary _useage = new() { - { Schedule.MemoryLocation.Input, 0 }, - { Schedule.MemoryLocation.Output, 0 }, - { Schedule.MemoryLocation.L2Data, 0 }, + private readonly Dictionary _usage = new() { + { TIR.MemoryLocation.Input, 0 }, + { TIR.MemoryLocation.Output, 0 }, + { TIR.MemoryLocation.L2Data, 0 }, }; - public TIR.PhysicalBuffer Allocate(string name, Schedule.MemoryLocation location) + public TIR.Buffer Allocate(string name, TIR.MemoryLocation location) { - var strides = TensorUtilities.GetStrides(Dimensions); - var size = TensorUtilities.GetSize(Dimensions, strides, DataTypes.Float32.SizeInBytes); - var buffer = new TIR.PhysicalBuffer(name, DataTypes.Float32, location, Dimensions, strides, _useage[location], size); - _useage[location] += size; + var dims = Dimensions.Select(d => (Expr)d).ToArray(); + var strides = TensorUtilities.GetStrides(Dimensions).Select(s => (Expr)s).ToArray(); + var size = TensorUtilities.GetSize(Dimensions, TensorUtilities.GetStrides(Dimensions), DataTypes.Float32.SizeInBytes); + + var buffer = new TIR.Buffer(name, DataTypes.Float32, new TIR.MemSpan(Tensor.FromPointer(_usage[location]), size, location), dims, strides); + _usage[location] += (ulong)size; return buffer; } } diff --git a/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs b/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs index 9cd1eb7139..96c6bb415d 100644 --- a/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs +++ b/src/Nncase.Tests/TIR/PrimFunc/UnitTestPrimFuncMerge.cs @@ -44,16 +44,17 @@ public class UnitTestPrimFuncMerge : TestClassBase public IAnalyzerManager AnalyzerMananger => CompileSession.GetRequiredService(); - [Theory] + [Theory(Skip = "Disable")] [MemberData(nameof(Datas))] private async void RunCore(IDataFlowPrimFuncCase fusionCase, int count) { + var dumper = Diagnostics.DumpScope.Current.CreateSubDummper($"case_{count}"); var inputVar = new Var("input", new TensorType(DataTypes.Float32, PrimFuncBuilder.Dimensions)); var main = new Function(fusionCase.BuildBody(inputVar), inputVar); CompilerServices.InferenceType(main); #if DEBUG - Dumpper.DumpDotIR(main, $"{count}_pre"); + Diagnostics.DumpScope.Current.DumpDotIR(main, $"{count}_pre"); #endif var feedDict = new Dictionary(ReferenceEqualityComparer.Instance) { { inputVar, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 12, PrimFuncBuilder.Dimensions).Evaluate() }, @@ -69,7 +70,7 @@ private async void RunCore(IDataFlowPrimFuncCase fusionCase, int count) var post = (Function)module.Entry!; #if DEBUG - Dumpper.DumpDotIR(post, $"{count}_post"); + Diagnostics.DumpScope.Current.DumpDotIR(post, $"{count}_post"); #endif var visitor = new TestVisitor(); @@ -121,11 +122,11 @@ internal sealed class PrimFuncEvaluateVisitor private static readonly int _pool_size = 1 * 4 * 8 * 9 * 4 * 30; private readonly PrimFunctionWrapper _wrapper; private readonly IValue[] _args; - private readonly Dictionary _poolMap = new() { - { Schedule.MemoryLocation.Input, new byte[_pool_size] }, - { Schedule.MemoryLocation.L2Data, new byte[_pool_size] }, - { Schedule.MemoryLocation.Data, new byte[_pool_size] }, - { Schedule.MemoryLocation.Output, new byte[_pool_size] }, + private readonly Dictionary _poolMap = new() { + { TIR.MemoryLocation.Input, new byte[_pool_size] }, + { TIR.MemoryLocation.L2Data, new byte[_pool_size] }, + { TIR.MemoryLocation.Data, new byte[_pool_size] }, + { TIR.MemoryLocation.Output, new byte[_pool_size] }, }; public PrimFuncEvaluateVisitor(PrimFunctionWrapper wrapper, params IValue[] args) @@ -139,8 +140,8 @@ public IValue Evaluate() // 1. copy input into input pool foreach (var (arg, param) in _args.Zip(_wrapper.Target.Parameters[.._wrapper.ParametersCount].ToArray())) { - Assert.Equal(param.Size, arg.AsTensor().BytesBuffer.Length); - arg.AsTensor().BytesBuffer.CopyTo(_poolMap[param.MemLocation].AsSpan(param.Start)); + Assert.Equal(param.MemSpan.Size.Evaluate().AsTensor().ToScalar(), arg.AsTensor().BytesBuffer.Length); + arg.AsTensor().BytesBuffer.CopyTo(_poolMap[param.MemSpan.Location].AsSpan(param.MemSpan.Start.Evaluate().AsTensor().ToScalar())); } // 2. start l2 computing @@ -153,7 +154,7 @@ public IValue Evaluate() var tensors = new List(); foreach (var outputParam in _wrapper.Target.Parameters[_wrapper.ParametersCount..]) { - tensors.Add(Tensor.FromBytes(outputParam.ElemType, GetBufferSpan(outputParam).ToArray(), outputParam.FixedDimensions.ToArray())); + tensors.Add(Tensor.FromBytes(outputParam.ElemType, GetBufferSpan(outputParam).ToArray(), outputParam.Dimensions.AsValueEnumerable().Select(e => e.Evaluate().AsTensor().ToScalar()).ToArray())); } return tensors.Count == 1 ? Value.FromTensor(tensors[0]) : Value.FromTensors(tensors.ToArray()); @@ -208,7 +209,7 @@ private void EvaluateStatement(Expr statement) private Span GetBufferSpan(Expr expr) { - var buffer = Assert.IsType(expr); - return _poolMap[buffer.MemLocation].AsSpan(buffer.Start, buffer.Size); + var buffer = Assert.IsType(expr); + return _poolMap[buffer.MemSpan.Location].AsSpan(buffer.MemSpan.Start.Evaluate().AsTensor().ToScalar(), buffer.MemSpan.Size.Evaluate().AsTensor().ToScalar()); } } diff --git a/src/Nncase.Tests/TIR/UnitTestMutators.cs b/src/Nncase.Tests/TIR/UnitTestMutators.cs index a90571e8a0..e115f9eeb6 100644 --- a/src/Nncase.Tests/TIR/UnitTestMutators.cs +++ b/src/Nncase.Tests/TIR/UnitTestMutators.cs @@ -30,9 +30,9 @@ public UnitTestMutators() [Fact] public async Task TestFoldConstCallWithTuple() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 48 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 9 }, out var glb_if_ping); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 9 }, out var glb_if_pong); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 48 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 9 }), MemoryLocation.Data, out var glb_if_ping); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 9 }), MemoryLocation.Data, out var glb_if_pong); PrimFunction main; { main = T.PrimFunc("main", Callable.StackVMModuleKind, ddr_if).Body( @@ -76,7 +76,7 @@ public async Task TestFoldConstCallWithTuple() int count = 0; for (int w = 0; w < 48; w += 9) { - Assert.True(object.ReferenceEquals(getBuffer(count, LoadT.DdrPp), post.Parameters[0])); + // Assert.True(object.ReferenceEquals(getBuffer(count, LoadT.DdrPp), post.Parameters[0])); var name = getBuffer(count++, LoadT.GlbPp).Name[^4..]; // System.Console.WriteLine($"{w} {name}"); @@ -118,8 +118,8 @@ public async Task TestUnRollLoopSequential() [Fact] public async Task TestUnRollLoopSequential2() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 10, 5, 9 }), MemoryLocation.Data, out var glb_if); PrimFunction main; { @@ -201,8 +201,8 @@ public async Task TestUnRollLoopSequential2() [Fact] public async Task TestUnRollLoopSequential3() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new TensorType(DataTypes.BFloat16, new[] { 3, 10, 5, 9 }), MemoryLocation.Data, out var glb_if); PrimFunction main; { @@ -362,10 +362,10 @@ public async Task TestFoldLet2() [Fact] public async Task TestFoldBufferIndex() { - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Input, new[] { 3, 16, 24, 24 }, out var ddr_if); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Output, new[] { 3, 16, 24, 24 }, out var ddr_of); - T.PhysicalBuffer(DataTypes.BFloat16, Schedule.MemoryLocation.Data, new[] { 3, 10, 5, 9 }, out var glb_if); - var bufferIndexMap = new Dictionary() { + T.CreateBuffer(new(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Input, out var ddr_if); + T.CreateBuffer(new(DataTypes.BFloat16, new[] { 3, 16, 24, 24 }), MemoryLocation.Output, out var ddr_of); + T.CreateBuffer(new(DataTypes.BFloat16, new[] { 3, 10, 5, 9 }), MemoryLocation.Data, out var glb_if); + var bufferIndexMap = new Dictionary() { { ddr_if, 2 }, { ddr_of, 4 }, }; @@ -386,7 +386,7 @@ public async Task TestFoldBufferIndex() pass.Add(); pass.Add(Expr? (Expr e) => { - if (e is Call { } call && call.Arguments[0] is PhysicalBuffer physicalBuffer && bufferIndexMap.TryGetValue(physicalBuffer, out var index)) + if (e is Call { } call && call.Arguments[0] is Buffer physicalBuffer && bufferIndexMap.TryGetValue(physicalBuffer, out var index)) { return index; } diff --git a/src/Nncase.Tests/Transform/UnitTestPassManager.cs b/src/Nncase.Tests/Transform/UnitTestPassManager.cs index bc7cb98896..cfc76bb4eb 100644 --- a/src/Nncase.Tests/Transform/UnitTestPassManager.cs +++ b/src/Nncase.Tests/Transform/UnitTestPassManager.cs @@ -22,7 +22,7 @@ public sealed class UnitTestPassManager : TestClassBase [Fact] public void TestPassMangerUpdateDependence() { - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body(T.Nop()).Build(); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out _)).Body(T.Nop()).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); @@ -30,7 +30,7 @@ public void TestPassMangerUpdateDependence() var main_func = new Function("main", new Call(prim_wrapper, input), input); // prim_func_2 for update - var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out _)).Body( + var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out _)).Body( T.Nop(), T.Nop()).Build(); @@ -54,15 +54,15 @@ public void TestPassMangerUpdateDependence2() %3 = %func_3(%2): // f16[1,23,30,16] */ - var prim_func_0 = T.PrimFunc("prim_func_0", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 24, 32, 3 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_0 = T.PrimFunc("prim_func_0", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 24, 32, 3 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Output, out var _)).Body( T.Nop()).Build(); var func_0 = new PrimFunctionWrapper(prim_func_0, 1); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Output, out var _)).Body( T.Nop()).Build(); var func_1 = new PrimFunctionWrapper(prim_func_1, 1); - var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 23, 30, 16 }, out var _)).Body( + var prim_func_2 = T.PrimFunc("prim_func_2", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 23, 30, 16 }), MemoryLocation.Output, out var _)).Body( T.Nop()).Build(); var func_2 = new PrimFunctionWrapper(prim_func_2, 1); @@ -74,7 +74,7 @@ public void TestPassMangerUpdateDependence2() Assert.True(CompilerServices.InferenceType(main_func)); // prim_func_2 for update - var prim_func_1_update = T.PrimFunc("prim_func_1_update", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 3, 24, 32 }, out var _), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 3, 24, 32 }, out var _)).Body( + var prim_func_1_update = T.PrimFunc("prim_func_1_update", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Input, out var _), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 3, 24, 32 }), MemoryLocation.Output, out var _)).Body( T.Nop(), T.Nop()).Build(); diff --git a/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs b/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs index 9313303252..0864bca6de 100644 --- a/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs +++ b/src/Nncase.Tests/Transform/UnitTestSubstitutor.cs @@ -24,8 +24,9 @@ public sealed class UnitTestSubstitutor : TestClassBase public void TestSubstitutorFailed() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( - T.Load(T.Handle("hd", DataTypes.Float32), loop_i)).Build(); + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var input_a), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out var input_b)).Body( + T.Load(hd, loop_i)).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); @@ -48,8 +49,9 @@ public void TestSubstitutorFailed() public void TestSubstitutorTrue() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( - T.Load(T.Handle("hd", DataTypes.Float32), loop_i)).Build(); + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var input_a), T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out var input_b)).Body( + T.Load(hd, loop_i)).Build(); Dictionary vmap = new() { { loop_i, 1 } }; var substitutor = Mutator.Substitute(e => vmap.TryGetValue(e, out var res) ? res : null)(); @@ -65,8 +67,9 @@ public void TestSubstitutorTrue() public void TestSubstitutorTrue2() { var loop_i = new Var("loop_i", TensorType.Scalar(DataTypes.Int32)); - var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.PhysicalBuffer(DataTypes.Float32, Schedule.MemoryLocation.Input, new[] { 1, 2, 3, 4 }, out var input_a), T.PhysicalBuffer(DataTypes.Int32, Schedule.MemoryLocation.Output, new[] { 1, 2, 3, 4 }, out var input_b)).Body( - T.Load(T.Handle("hd", DataTypes.Float32), loop_i)).Build(); + T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var hd); + var prim_func_1 = T.PrimFunc("prim_func_1", "k?", T.CreateBuffer(new(DataTypes.Float32, new[] { 1, 2, 3, 4 }), MemoryLocation.Input, out var input_a), T.CreateBuffer(new(DataTypes.Int32, new[] { 1, 2, 3, 4 }), MemoryLocation.Output, out var input_b)).Body( + T.Load(hd, loop_i)).Build(); var prim_wrapper = new PrimFunctionWrapper(prim_func_1, 1); diff --git a/src/Nncase.Tests/packages.lock.json b/src/Nncase.Tests/packages.lock.json index c722181308..beff5e02b5 100644 --- a/src/Nncase.Tests/packages.lock.json +++ b/src/Nncase.Tests/packages.lock.json @@ -80,11 +80,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "System.Linq.Async": { @@ -490,8 +490,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -871,6 +871,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -1099,6 +1100,12 @@ "resolved": "1.0.2", "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA==" }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )", diff --git a/targets/Nncase.Targets.CSource/CSourceTarget.cs b/targets/Nncase.Targets.CSource/CSourceTarget.cs deleted file mode 100644 index e6f3e19a8f..0000000000 --- a/targets/Nncase.Targets.CSource/CSourceTarget.cs +++ /dev/null @@ -1,34 +0,0 @@ - -namespace Nncase.Targets; - -public class CSourceTarget : ITarget -{ - /// - public string Kind { get => "CSource"; set { } } - /// - public Dictionary Options { get; set; } = new(); - /// - public Dictionary Attrs { get; set; } = new(); - /// - public void ConfigOptions() { } - /// - public void ConfigAttrs() { } - - /// - public Schedule.IScheduler CreateScheduler(IR.IRModule main_module) - { - return new Schedule.CSourceScheduler(main_module, this); - } - - /// - public CodeGen.IRTModel CreateRTModel(IR.IRModel model) - { - return new CodeGen.CSourceRTModel(model, this); - } - - /// - public CodeGen.IRTModule CreateRTModule(IR.IRModel model, IR.IRModule module) - { - throw new NotImplementedException("The CSource Target Only Have Runtime Model!"); - } -} diff --git a/targets/Nncase.Targets.CSource/CodeGen/CSource.cs b/targets/Nncase.Targets.CSource/CodeGen/CSource.cs deleted file mode 100644 index afc0de3b68..0000000000 --- a/targets/Nncase.Targets.CSource/CodeGen/CSource.cs +++ /dev/null @@ -1,278 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Nncase.IR; -using Nncase.Schedule; -using Nncase.TIR; - -namespace Nncase.CodeGen; - -/// -/// the c source runtime function. -/// -/// -/// -public record CSourceRTFunction(string name, Delegate handle) : IRTFunction -{ - public string Name { get => name; set { } } - public Delegate Handle { get => handle; set { } } -} - -public class CSourceSerializeResult : ISerializeResult -{ - -} - -/// -/// c runtime module impl -/// -public class CSourceRTModel : IRTModule, IRTModel -{ - /// - public ModuleType ModuleType { get => CodeGen.ModuleType.Create("CSource"); set { } } - - /// - public ITarget Target { get; set; } - - /// - public IReadOnlyList Modules => throw new NotImplementedException(); - - /// - public string SourcePath { get; private set; } - - public IRModel Model { get; set; } - IRTFunction? _entry = null; - - /// - public bool IsSerialized { get; private set; } - - readonly List _functions = new(); - - /// - /// - /// - public CSourceRTModel(IRModel model, ITarget target) - { - SourcePath = CodeGenUtil.GetTempFileName("c"); - Model = model; - Target = target; - } - - /// - public byte[] Source { get => File.ReadAllBytes(SourcePath); set { } } - - /// - public string SourceExt { get => "c"; set { } } - - /// - public IRTFunction? Entry => _entry; - - /// - public IReadOnlyList Functions => _functions; - - /// - string _dllPath = ""; - - /// - /// write the c source code into source path. - /// - /// - void BuildCode() - { - if (File.Exists(SourcePath)) - File.Delete(SourcePath); - using (var writer = new StreamWriter(SourcePath, false, Encoding.UTF8)) - { - var visior = new CSourceHostBuildVisior(writer); - if (Model.Entry is null) { throw new InvalidProgramException("The Model Entry Is Null!"); } - if (Model.Entry.CheckedType is null && Model.Entry.InferenceType() == false) { throw new InvalidProgramException("The Model Entry Can't Inference Type!"); } - visior.Visit(Model.Entry); - } - } - - public void CompileCode() - { - if (!File.Exists(SourcePath)) - throw new InvalidProgramException("The Source Code Path Is Invalid!"); - var compiler = new CSourceCompiler(); - _dllPath = compiler.Compile(SourcePath); - } - - /// - /// bind each IR.Funtion with C function - /// - /// - public void ExportCode() - { - if (!File.Exists(_dllPath)) - throw new InvalidProgramException("The DLL Path Is Invalid!"); - var dllPtr = NativeLibrary.Load(_dllPath); - foreach (var module in Model.Modules) - { - foreach (var f in module.Callables) - { - var funcType = f.ToDelegateType(Path.GetFileName(_dllPath)); - var funPtr = NativeLibrary.GetExport(dllPtr, f.Name); - _functions.Add(new CSourceRTFunction(f.Name, funPtr.BindDelegate(funcType))); - if (f == Model.Entry) { _entry = _functions.Last(); } - } - } - } - - /// - public ISerializeResult Serialize() - { - if (IsSerialized) { return new CSourceSerializeResult(); } - BuildCode(); - CompileCode(); - ExportCode(); - return new CSourceSerializeResult(); - } - - /// - /// invoke the module entry - /// - /// input args - /// results - /// - public object? Invoke(params object?[]? args) - { - if (Entry is null) - throw new InvalidOperationException("This RTModule Have No Entry Function!"); - return Entry.Handle.DynamicInvoke(args); - } - - public string Dump(string name, string DumpDirPath) - { - var dump_path = $"{DumpDirPath}/{name}.{SourceExt}"; - using var file = File.Open(dump_path, FileMode.OpenOrCreate, FileAccess.Write); - using var writer = new StreamWriter(file); - writer.Write(Source); - return dump_path; - } - -} - -/// -/// the csource code compiler. -/// -public class CSourceCompiler -{ - /// - /// compiler exe name - /// - string _exe = "", _arch = "", _ext = ""; - - /// - /// select current pattern's exe - /// - /// - void PlatformSpecific() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - _exe = "gcc"; - _ext = "so"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - _exe = "clang"; - _ext = "dylib"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - _exe = "cmd"; - _ext = "dll"; - } - } - - void ArchSpecific() - { - _arch = RuntimeInformation.OSArchitecture switch - { - Architecture.X64 => RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "x86-64" : "x86_64", - Architecture.Arm64 => "arm64", - _ => throw new NotSupportedException(RuntimeInformation.OSArchitecture.ToString()), - }; - } - - string ArgumentsSpecific(string sourcePath, string outPath) - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - return $"{sourcePath} -fPIC -shared -march={Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return $"{sourcePath} -fPIC -shared -arch {Arch} -o {outPath}"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var vsdir = Environment.GetEnvironmentVariable("VSAPPIDDIR") ?? throw new InvalidOperationException("Cannot find vs"); - var vcvardir = Path.Combine(vsdir, "..\\..\\VC\\Auxiliary\\Build\\vcvarsall.bat"); - return $"/C (\"{vcvardir}\" x64) && (cl /D_USRDLL /D_WINDLL \"{sourcePath}\" /MT /link /DLL /OUT:\"{outPath}\")"; - } - throw new System.ArgumentOutOfRangeException("Only Support Linux/Osx/Windows"); - } - - protected string Exe - { - get => _exe; - } - - protected string Arch - { - get => _arch; - } - - protected string Ext - { - get => _ext; - } - - public CSourceCompiler() - { - PlatformSpecific(); - ArchSpecific(); - } - - /// - /// compile the source txt, write to the out_path - /// - /// c source code - /// out .so path - /// outPath - public string Compile(string sourcePath, string outPath) - { - var errMsg = new StringBuilder(); - using (var errWriter = new StringWriter(errMsg)) - { - using (var proc = new Process()) - { - proc.StartInfo.FileName = Exe; - proc.StartInfo.Arguments = ArgumentsSpecific(sourcePath, outPath); - proc.StartInfo.RedirectStandardError = true; - proc.ErrorDataReceived += (sender, e) => errWriter.WriteLine(e.Data); - proc.Start(); - proc.BeginErrorReadLine(); - proc.WaitForExit(); - if (proc.ExitCode != 0) - { - throw new InvalidOperationException(errMsg.ToString()); - } - } - } - return outPath; - } - - /// - /// create the temp dll file and compile source - /// - /// - public string Compile(string sourcePath) => Compile(sourcePath, CodeGenUtil.GetTempFileName(Ext)); -} \ No newline at end of file diff --git a/targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs b/targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs deleted file mode 100644 index 352b9dc6ea..0000000000 --- a/targets/Nncase.Targets.CSource/CodeGen/CSourceVisitor.cs +++ /dev/null @@ -1,317 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Nncase.IR; -using Nncase.Runtime; -using Nncase.TIR; - -namespace Nncase.CodeGen; - -/// -/// convert the type/op to c name -/// -internal static class NameConverter -{ - private static readonly Dictionary _primTypeToC = new() - { - { DataTypes.Boolean, "bool" }, - { DataTypes.Int8, "int8_t" }, - { DataTypes.Int16, "int16_t" }, - { DataTypes.Int32, "int32_t" }, - { DataTypes.Int64, "int64_t" }, - { DataTypes.UInt8, "uint8_t" }, - { DataTypes.UInt16, "uint16_t" }, - { DataTypes.UInt32, "uint32_t" }, - { DataTypes.UInt64, "uint64_t" }, - { DataTypes.Float32, "float" }, - { DataTypes.Float64, "double" }, - }; - - public static string toC(this PrimType primType) => - _primTypeToC[primType]; - - public static string toC(this DataType dataType) => dataType switch - { - PrimType ptype => ptype.toC(), - PointerType { ElemType: PrimType etype } => etype.toC() + "*", - _ => throw new NotSupportedException(dataType.ToString()) - }; -} - -/// -/// the c symbol define -/// -internal struct CSymbol -{ - public string Type; - public StringBuilder Doc; - public CSymbol(string type, StringBuilder doc) - { - Type = type; - Doc = doc; - } - public override string ToString() => $"{Type} {Doc}"; -} - -/// -/// collect the csymbol's parameter -/// -internal class CSymbolParamList : IParameterList, IEnumerable -{ - CSymbol[] Symbols; - public CSymbolParamList(CSymbol[] symbols) - { - Symbols = symbols; - } - - public CSymbol this[ParameterInfo parameter] => Symbols[parameter.Index]; - public CSymbol this[int index] => Symbols[index]; - - public IEnumerator GetEnumerator() - { - return ((IEnumerable)Symbols).GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return Symbols.GetEnumerator(); - } -} - - -/// -/// visitor for the build c source code, the expr vistor return (type string , name string) -/// -internal class CSourceHostBuildVisior : ExprFunctor -{ - - /// - /// source writer . - /// TODO we need the decl writer - /// - readonly ScopeWriter Scope; - - /// - /// symbols name memo - /// - readonly Dictionary Symbols = new(ReferenceEqualityComparer.Instance); - - /// - /// - /// - /// - public CSourceHostBuildVisior(TextWriter textWriter) - { - Scope = new ScopeWriter(textWriter); - // insert some declare - Scope.IndWriteLine(@" -#ifdef _WIN32 -#define EXPORT_API __declspec(dllexport) -#else -#define EXPORT_API -#endif"); - Scope.IndWriteLine("#include "); - } - - /// - public override CSymbol Visit(Call expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - var target = Visit(expr.Target); - var args = new CSymbolParamList(expr.Parameters.Select(Visit).ToArray()); - var type = VisitType(expr.CheckedType!); - Scope.Push(); - switch (expr.Target) - { - case IR.Math.Binary: - Scope.Append($"({args[0].Doc} {target.Doc} {args[1].Doc})"); - break; - case Store: - Scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}] = {args[Store.Value].Doc}"); - break; - case Load: - Scope.Append($"{args[Store.Handle].Doc}[{args[Store.Index].Doc}]"); - break; - case IR.Tensors.Cast: - Scope.Append($"(({type}){args[IR.Tensors.Cast.Input].Doc})"); - break; - default: - Scope.Append($"{target.Doc}({string.Join(", ", args.Select(x => x.Doc))})"); - break; - } - symbol = new(type, Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Const expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - if (expr.CheckedType is TensorType ttype && ttype.IsScalar) - { - var literal = $"{expr}" switch - { - "True" => "1", - "False" => "0", - var x => x - }; - symbol = new(VisitType(ttype), new(literal)); - } - else - { - throw new NotSupportedException($"Not Support {expr.CheckedType} Const"); - } - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Function expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - var retType = VisitType(((CallableType)expr.CheckedType!).ReturnType); - Scope.Push(); - // 1. Function signature - Scope.IndWrite($"EXPORT_API {retType} {expr.Name}({string.Join(", ", expr.Parameters.Select(Visit))}) {{"); - // 2. Function body - using (Scope.IndentUp()) - { - Scope.Append(Visit(expr.Body).Doc); - } - // 3. Function closing - Scope.IndWrite("}"); - symbol = new(CallableTypeToPtr((CallableType)expr.CheckedType!, expr.Name), Scope.Pop()); - // 4. write whole code - Scope.IndWrite(symbol.Doc); - return symbol; - } - - /// - public override CSymbol Visit(Op expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - symbol = new("Invalid Op", new(expr switch - { - IR.Math.Binary op => op.BinaryOp switch - { - BinaryOp.Add => "+", - BinaryOp.Sub => "-", - BinaryOp.Mul => "*", - BinaryOp.Div => "/", - BinaryOp.Mod => "%", - _ => throw new ArgumentOutOfRangeException(op.BinaryOp.ToString()) - }, - TIR.Store op => "Store", - TIR.Load op => "Load", - IR.Tensors.Cast op => op.NewType.toC(), - _ => throw new NotSupportedException($"{expr.GetType().Name}") - })); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Var expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - var isymbol = Scope.GetUniqueVarSymbol(expr); - symbol = new(VisitType(expr.CheckedType!), isymbol.Span); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(For expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - Scope.Push(); - // 1. For Loop signature - var loopVar = Visit(expr.LoopVar); - Scope.Append($"for ({loopVar} = {Visit(expr.Dom.Start).Doc}; {loopVar.Doc} < {Visit(expr.Dom.Stop).Doc}; {loopVar.Doc}+={expr.Dom.Step}) {{"); - // 2. For Body - Scope.Append(Visit(expr.Body).Doc); - // 3. For closing - Scope.IndWrite("}"); - symbol = new(VisitType(expr.CheckedType!), Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(Sequential expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - Scope.Push(); - Scope.AppendLine(""); - using (Scope.IndentUp()) - { - foreach (var i in Enumerable.Range(0, expr.Fields.Count)) - { - if (i == expr.Fields.Count - 1 && - expr.Fields[i].CheckedType is TensorType) - { - Scope.IndWrite("return "); - } - else - { - Scope.IndWrite(string.Empty); - } - Scope.Append(Visit(expr.Fields[i]).Doc); - if (expr.Fields[i] is Call) - { - Scope.AppendLine(";"); - } - else - { - Scope.AppendLine(string.Empty); - } - } - } - symbol = new(VisitType(expr.CheckedType!), Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - public override CSymbol Visit(IfThenElse expr) - { - if (Symbols.TryGetValue(expr, out var symbol)) { return symbol; } - Scope.Push(); - Scope.Append($"if({Visit(expr.Condition).Doc}) {{"); - Scope.Append(Visit(expr.Then).Doc); - Scope.IndWrite("} else {"); - Scope.Append(Visit(expr.Else).Doc); - Scope.IndWrite("}"); - symbol = new(VisitType(expr.CheckedType!), Scope.Pop()); - Symbols.Add(expr, symbol); - return symbol; - } - - /// - /// - /// void (*fun_ptr)(int) - /// - public string CallableTypeToPtr(CallableType type, string name) => $"{VisitType(type.ReturnType)} (*{name}_ptr)({string.Join(",", type.Parameters.Select(VisitType))})"; - - - /// - public override string VisitType(TensorType type) - { - if (!type.IsScalar) - { - throw new NotSupportedException($"{type}"); - } - return type.DType.toC(); - } - - /// - public override string VisitType(TupleType type) => type == TupleType.Void ? - "void" : - throw new InvalidProgramException($"The C Source Must Not Have TupleType {type}!"); -} \ No newline at end of file diff --git a/targets/Nncase.Targets.CSource/CodeGen/Interop.cs b/targets/Nncase.Targets.CSource/CodeGen/Interop.cs deleted file mode 100644 index 33d31af269..0000000000 --- a/targets/Nncase.Targets.CSource/CodeGen/Interop.cs +++ /dev/null @@ -1,140 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using System.Reflection.Emit; -using System.Runtime.InteropServices; -using Nncase.IR; - -namespace Nncase.CodeGen; - -/// -/// -/// -internal class DynamicAssemble -{ - /// - /// name - /// - AssemblyName assemblyName; - /// - /// asm builder for whole module - /// - AssemblyBuilder asmBuilder; - /// - /// module buidler - /// - ModuleBuilder modBuilder; - /// - /// save the func name <=> func delegate type - /// - readonly Dictionary delegateTypes = new(); - - /// - /// a DynamicAssemble instance, it's contains one rtmodule's all functions defination. - /// - /// asmble name - public DynamicAssemble(string Name) - { - assemblyName = new AssemblyName(Name); - asmBuilder = AssemblyBuilder.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.RunAndCollect); - modBuilder = asmBuilder.DefineDynamicModule(assemblyName.Name!); - - } - - /// - /// - /// - /// - /// func delegate type - public Type BuildDelegateType(Callable function) - { - Type deleType; - if (function.CheckedType is CallableType ctype) - { - deleType = CreateDelegateType(function.Name, ctype.ReturnType.ToType(), ctype.Parameters.Select(Interop.ToType).ToArray()); - } - else { throw new NotSupportedException(function.CheckedType?.ToString()); } - return deleType; - } - - /// - /// dynamic create delegate type for function. - /// - /// - /// - /// - /// - /// - public Type CreateDelegateType(string funcName, Type returnType, params Type[]? ParamTypes) - { - if (!delegateTypes.TryGetValue(funcName, out var ret)) - { - ParamTypes ??= new Type[] { }; - TypeBuilder tb = modBuilder.DefineType(funcName, TypeAttributes.Public | TypeAttributes.Sealed, typeof(MulticastDelegate)); - tb.DefineConstructor(MethodAttributes.Public | MethodAttributes.HideBySig | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, CallingConventions.Standard | CallingConventions.HasThis, new[] { typeof(object), typeof(IntPtr) }).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - tb.DefineMethod("Invoke", MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, CallingConventions.Standard | CallingConventions.HasThis, returnType, ParamTypes).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - tb.DefineMethod("BeginInvoke", MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, CallingConventions.Standard | CallingConventions.HasThis, typeof(IAsyncResult), ParamTypes.Concat(new[] { typeof(IAsyncResult), typeof(object) }).ToArray()).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - tb.DefineMethod("EndInvoke", MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, CallingConventions.Standard | CallingConventions.HasThis, returnType, new[] { typeof(IAsyncResult) }).SetImplementationFlags(MethodImplAttributes.Runtime | MethodImplAttributes.Managed); - ret = tb.CreateType(); - if (ret is null) { throw new InvalidProgramException($"Can't Create The Func {funcName}'s delegate Type!"); } - delegateTypes.Add(funcName, ret); - } - return ret; - } -} - -/// -/// Interop helper -/// -public static class Interop -{ - /// - /// collect the all dynamic asmbs - /// - private static readonly Dictionary _definedAsms = new(); - - /// - /// convert the ir type to the system type - /// - /// - /// - /// - public static Type ToType(this IRType iRType) => iRType switch - { - TensorType { IsScalar: true, DType: PrimType { } primType } => primType.CLRType, - TensorType { IsScalar: true, DType: PointerType { ElemType: PrimType primType } } => primType.CLRType.MakeArrayType(), - TupleType ttype => (ttype == TupleType.Void) switch - { - true => typeof(void), - false => throw new NotSupportedException($"Can't Support the {ttype}!") - }, - _ => throw new NotSupportedException($"IRType is {iRType}!") - }; - - - /// - /// convrt function to delegate type - /// - /// input function - /// the dynamic lib name - /// - /// - public static Type ToDelegateType(this Callable function, string libName) - { - if (!_definedAsms.TryGetValue(libName, out var dyasm)) - { - dyasm = new DynamicAssemble(libName); - _definedAsms.Add(libName, dyasm); - } - return dyasm.BuildDelegateType(function); ; - } - - /// - /// bind the delegate to funcptr. - /// - /// - /// - /// - public static Delegate BindDelegate(this IntPtr funcPtr, Type funcType) => Marshal.GetDelegateForFunctionPointer(funcPtr, funcType); -} diff --git a/targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj b/targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj deleted file mode 100644 index 79226f6c03..0000000000 --- a/targets/Nncase.Targets.CSource/Nncase.Targets.CSource.csproj +++ /dev/null @@ -1,16 +0,0 @@ - - - - net6.0 - enable - enable - $(SolutionDir)/tools/StyleCopAnalyzers.ruleset - - - - - - - - - diff --git a/targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs b/targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs deleted file mode 100644 index b57cf2419d..0000000000 --- a/targets/Nncase.Targets.CSource/Schedule/CSourceScheduler.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Nncase.IR; - -namespace Nncase.Schedule; - -public class CSourceScheduler : IScheduler -{ - - public CSourceScheduler(IR.IRModule main_module, ITarget target) - { - Module = main_module; - Target = target; - } - - public ITarget Target { get; set; } - public IRModule Module { get; set; } - - - IRModel IScheduler.Schedule(bool skip_buffer_alias) - { - return new IRModel(new[] { Module }); - } -} \ No newline at end of file diff --git a/tests/importer/onnx_/model/test_llama.py b/tests/importer/onnx_/model/test_llama.py new file mode 100644 index 0000000000..c6beee1ff6 --- /dev/null +++ b/tests/importer/onnx_/model/test_llama.py @@ -0,0 +1,127 @@ +# Copyright 2019-2021 Canaan Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""System test: test demo""" +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel + +# from lzma import MODE_FAST +# from xml.parsers.expat import model +import pytest +from onnx_test_runner import OnnxTestRunner + + +def test_demo(request): + runner = OnnxTestRunner("demo1", "/root/Workspace/config/llama_config.toml") + # model_file = r'/data/huochenghai/onnx_model/shufflenet-9.onnx' + # model_file = '/compiler/huochenghai/GNNE/nncase_demo/examples/release_isp_object_detect_nncase/data/yolov5sFocus_320x3.onnx' + # model_file = '/data/huochenghai/GNNE/nncase_demo/examples/release_isp_retinaface_mb_320_nncase/data/retinaface_mobile0.25_320.onnx' + # model_file = '/data/huochenghai/GNNE/nncase_demo/examples/release_isp_face_landmarks106_nncase/data/retinaface_mobile0.25_320.onnx' + # model_file = '/data/huochenghai/GNNE/nncase_demo/examples/release_isp_face_landmarks106_nncase/data/v3.onnx' + # model_file = '/data/huochenghai/fixed_input.onnx' + # model_file = '/data/huochenghai/GNNE/nncase_demo/examples/release_isp_face_alignment_from_box_nncase/data/mb1_120x120.onnx' + # model_file = '/data/huochenghai/GNNE/nncase_demo/examples/release_isp_face_recog_mbface_nncase/data/mbface.onnx' + # model_file = '/data/huochenghai/GNNE/k510-gnne-compiler-tests/zhoumeng-model/resnet50v1/model_f32.onnx' + # model_file = '/data/huochenghai/deploy_modify.onnx' + # model_file = '/data/huochenghai/nanodet_mobilenetv2_416.onnx' + # model_file = '/data/huochenghai/yolov5_face_n0.5_256x256.onnx' + # model_file = '/data/huochenghai/yolov5s_0.5_640_dropact.onnx' + # model_file = '/data/huochenghai/GNNE/nncase_demo/examples/release_isp_object_detect_nncase/data/yolov5sFocus_320x3.onnx' + # model_file = '/data/huochenghai/nanodet_yolov5s_0.5_head_nospp_640.onnx' + # model_file = '/data/huochenghai/dw_21x21_model.onnx' + # model_file = '/compiler/huochenghai/GNNE/nncase/tests_output/test_decoder_part/simplified.onnx' + # model_file = '/data/huochenghai/onnx_model/yolop_self.onnx' + # model_file = '/data/huochenghai/yolov5s_640x640_sigmoid_weights.onnx' + # model_file = '/data/huochenghai/models/yolov5s_640_sigmoid.onnx' + # model_file = '/data/huochenghai/best_batchsize16_300' + # model_file = '/data/huochenghai/candy-9.onnx' + # model_file = '/data/huochenghai/glint360k_cosface_r18_fp16_0.1.onnx' + # model_file = '/data/huochenghai/cls_fixed2.onnx' + # model_file = '/data/huochenghai/stereo_ranpara.onnx' + # model_file = '/data/huochenghai/stereoNet.onnx' + # model_file = '/data/huochenghai/deploy_modify.onnx' + # model_file = '/data/huochenghai/model.onnx' + # model_file = "/data/huochenghai/onnx_model/lite-transformer-encoder.onnx" + # model_file = '/data/huochenghai/onnx_model/lite-transformer-decoder.onnx' + # model_file = '/data/huochenghai/pose_vgg_half_030.onnx' + # model_file = '/data/huochenghai/pose1040.onnx' + # model_file = '/data/huochenghai/net.onnx' + # model_file = '/data/huochenghai/face_expression.onnx' + # model_file = '/data/huochenghai/model_fixed_input_size.onnx' + # model_file = '/data/huochenghai/model_none_lstm.onnx' + # model_file = '/data/huochenghai/squeezenet1_1.onnx' + # model_file = '/data/huochenghai/resnet_tom.onnx' + # model_file = '/data/huochenghai/Ultralight-Nano-SimplePose.onnx' + # model_file = "/data/huochenghai/yolov5sface_640x640_6output.onnx" + # model_file = "/data/huochenghai/model-y1.onnx" + # model_file = "/data/huochenghai/sim_5.onnx" + # model_file = "/data/huochenghai/person_yolov5s_0.5_nospp_640_nncase.onnx" + # model_file = "/data/huochenghai/rec_2_layer_lstm.onnx" + # model_file = "/compiler/huochenghai/east_128_640.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/CRNN/ocr_rec_model_32-608.onnx" + # model_file = "/data/huochenghai/scrfd_person_2.5g_fixed_input_size_simplify.onnx" + # model_file = "/data/huochenghai/models/model_128-640-11.onnx" + # model_file = "/data/huochenghai/GNNE/nncase/tests_output/simplified.onnx" + # model_file = "/data/huochenghai/dw_deconv.onnx" + # model_file = "/data/huochenghai/GNNE/nncase/tests_output/test_exchannel_rhs_shape0-lhs_shape0_/simplified.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/lite-transformer/lite_transformer_encoder_L10.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/lite-transformer/lite_transformer_decoder_L10.onnx" + # model_file = "/data/huochenghai/lite_transformer_decoder_L10.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/yolov5s/yolov5s_640_sigmoid.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/efficientnet/efficientnet_b0_224x224.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/mobile-facenet/mbface_sim_224.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/mobile-retinaface/retinaface_mobile0.25_320_simplified.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/mobilenet-v1-ssd/ssd_mobilenetv1_300x300.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/mobilenetv2-yolov3/yolov3_mobilenetv2_no_postprocess.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/mobilenet-v2-ssd/ssd_mobilenetv2_300x300.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/yolov5m/yolov5_m_320x320_with_sigmoid.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/yolov5s_face/yolov5sface_640x640_6output.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/yolox/yolox_s.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/Ultralight-SimplePose/Ultralight-SimplePose.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/reid/osnet_x1_0.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/yolov7/0_yolov7-tiny-silu_320x320.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/reid/osnet_ain_x1_0.onnx" + # model_file = "/data/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/reid/osnet_ibn_x1_0.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/wzm/wzm_stereo6g.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/wzm/wzm_stereo.onnx" + # model_file = "/data/huochenghai/GNNE/nncase/tests_output/test_matmul_constant-in_shape0_/simplified.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/lite-transformer/youdaonmt/encoder_model.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/lite-transformer/youdaonmt/decoder_model.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/resnetv1_50/onnx/resnet50v1.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-ccompiler-tests/benchmark-test/lite-transformer/lite_transformer_encoder_L10.onnx" + # model_file = "/compiler/huochenghai/can3_10.0s_20221011084724.onnx" + # model_file = "/compiler/huochenghai/lstm_256.onnx" + # model_file = "/compiler/huochenghai/weilai/simplified_det.onnx" + # model_file = "/compiler/huochenghai/models/daniu_nmt_enc.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/centersnap/CenterSnap.onnx" + # model_file = "/compiler/huochenghai/GNNE/nncase/tests_output/daniu_enc.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/yolop/yolop_self.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/daniu/e2z/dec.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/daniu/z2e/enc.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/daniu/TTS/zho/fix.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/CRNN/ocr_rec_model_32-608.onnx" + # model_file = "/compiler/huochenghai/GNNE/nncase/tests_output/crnn_part.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/benchmark-test/mobile-facenet/mbface_sim_224.onnx" + # model_file = "/compiler/huochenghai/GNNE/k230-gnne-compiler-tests/FasterTransformer/LongFormer/longformer-base-4096.onnx" + # model_file = '/data/huochenghai/GNNE/k230-gnne-compiler-tests/StableDiffusion/onnx-stable-diffusion-v1-5/vae_decoder/model.onnx' + # model_file = "/data/huochenghai/llama_scrach/65B/decoder-merge-0.onnx" + # model_file = "/root/Downloads/decoder-merge-0.onnx" + model_file = "/root/Downloads/64B-4-layers/decoder-merge-all.onnx" + + # runner.set_shape_var({"batch_size": 1, "num_channels_latent": 4, "height_latent": 64, "width_latent": 64}) + runner.set_shape_var({"N": 384}) + runner.run(model_file) + + +if __name__ == "__main__": + pytest.main( + ['-vvs', __file__]) diff --git a/tests/importer/onnx_/model/test_text_encoder.py b/tests/importer/onnx_/model/test_text_encoder.py new file mode 100644 index 0000000000..8642743bcc --- /dev/null +++ b/tests/importer/onnx_/model/test_text_encoder.py @@ -0,0 +1,39 @@ +# Copyright 2019-2021 Canaan Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""System test: test demo""" +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel + +# from lzma import MODE_FAST +# from xml.parsers.expat import model +import pytest +from onnx_test_runner import OnnxTestRunner + + +def test_demo(request): + # runner = OnnxTestRunner(request.node.name, "/root/Workspace/nncase/tests/importer/onnx_/model/llama_config.toml") + runner = OnnxTestRunner("text_encoder", "/root/Workspace/config/text_config.toml") + # runner = OnnxTestRunner("text_encoder") + # + model_file = "/root/Downloads/Models/text_encoder_model.onnx" + + # runner.set_shape_var({"batch_size": 1, "num_channels_latent": 4, "height_latent": 64, "width_latent": 64}) + # runner.set_shape_var({"N": 384}) + runner.set_shape_var({"batch_size": 1, "sequence_length": 77}) + # runner.set_shape_var({"batch_size:1", "sequence_length:77"}) + runner.run(model_file) + + +if __name__ == "__main__": + pytest.main( + ['-vvs', __file__]) diff --git a/tests/importer/onnx_/model/test_unet.py b/tests/importer/onnx_/model/test_unet.py new file mode 100644 index 0000000000..b504723bef --- /dev/null +++ b/tests/importer/onnx_/model/test_unet.py @@ -0,0 +1,40 @@ +# Copyright 2019-2021 Canaan Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""System test: test demo""" +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel + +# from lzma import MODE_FAST +# from xml.parsers.expat import model +import pytest +from onnx_test_runner import OnnxTestRunner + + +def test_demo(request): + # runner = OnnxTestRunner(request.node.name, "/root/Workspace/nncase/tests/importer/onnx_/model/llama_config.toml") + runner = OnnxTestRunner("unet", "/root/Workspace/config/unet_config.toml") + # runner = OnnxTestRunner("unet") + # + model_file = "/root/Downloads/Models/unet/model.onnx" + + # runner.set_shape_var({"batch_size": 1, "num_channels_latent": 4, "height_latent": 64, "width_latent": 64}) + # runner.set_shape_var({"N": 384}) + runner.set_shape_var({"batch_size": 2, "num_channels": 4, "height": 64, + "width": 64, "steps": 2, "sequence_length": 77}) + # runner.set_shape_var({"batch_size:1", "sequence_length:77"}) + runner.run(model_file) + + +if __name__ == "__main__": + pytest.main( + ['-vvs', __file__]) diff --git a/tests/importer/onnx_/model/test_vae_decoder.py b/tests/importer/onnx_/model/test_vae_decoder.py new file mode 100644 index 0000000000..0c3eee5818 --- /dev/null +++ b/tests/importer/onnx_/model/test_vae_decoder.py @@ -0,0 +1,40 @@ +# Copyright 2019-2021 Canaan Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""System test: test demo""" +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel + +# from lzma import MODE_FAST +# from xml.parsers.expat import model +import pytest +from onnx_test_runner import OnnxTestRunner + + +def test_demo(request): + runner = OnnxTestRunner("test_vae_decoder", + "/root/Workspace/config/vae_config.toml") + # runner = OnnxTestRunner("test_vae_decoder") + model_file = "/root/Downloads/Models/vae_decoder.onnx" + # model_file = "/root/Downloads/Models/modified_modified_vae_decoder.onnx" + # model_file = "/root/Downloads/Models/modified_vae_decoder.onnx" + # model_file = "/root/Downloads/Models/model_sim_huo.onnx" + + runner.set_shape_var({"batch_size": 1, "num_channels_latent": 4, + "height_latent": 64, "width_latent": 64}) + # runner.set_shape_var({"N": 384}) + runner.run(model_file) + + +if __name__ == "__main__": + pytest.main( + ['-vvs', __file__]) diff --git a/tests/kernels/test_concat.cpp b/tests/kernels/test_concat.cpp index c7da18b3bf..9aedb892f8 100644 --- a/tests/kernels/test_concat.cpp +++ b/tests/kernels/test_concat.cpp @@ -86,14 +86,7 @@ TEST_P(ConcatTest, Concat) { fields.push_back(field2); auto output_tuple = tuple(std::in_place, std::move(fields)); - int64_t axis_ptr[] = {axis_value}; - auto axis = - hrt::create(dt_int64, {1}, - {reinterpret_cast(axis_ptr), sizeof(axis_ptr)}, - true, host_runtime_tensor::pool_cpu_only) - .expect("create tensor failed"); - - auto output = kernels::stackvm::concat(output_tuple, axis.impl()) + auto output = kernels::stackvm::concat((int)axis_value, output_tuple) .expect("concat failed"); runtime_tensor actual(output.as().expect("as tensor failed")); diff --git a/tests/kernels/test_gather.cpp b/tests/kernels/test_gather.cpp index 5910d17cc1..65d5ca45c8 100644 --- a/tests/kernels/test_gather.cpp +++ b/tests/kernels/test_gather.cpp @@ -37,7 +37,7 @@ class GatherTest : public KernelTest, auto shape = GetShapeArray("lhs_shape"); auto indices_shape = GetShapeArray("indices_shape"); auto indices_value = GetDataArray("indices_value"); - auto value = GetNumber("axis"); + auto axis = GetNumber("axis"); auto typecode = GetDataType("lhs_type"); input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only) @@ -61,17 +61,9 @@ class GatherTest : public KernelTest, true, host_runtime_tensor::pool_cpu_only) .expect("create tensor failed"); - batchDims_value = value >= 0 - ? (size_t)value >= shape.size() ? -1 : value - : -(size_t)value > shape.size() ? -1 - : value; - - int64_t batchDims_array[1] = {batchDims_value}; - batchDims = hrt::create(dt_int64, dims_t{1}, - {reinterpret_cast(batchDims_array), - sizeof(batchDims_array)}, - true, host_runtime_tensor::pool_cpu_only) - .expect("create tensor failed"); + batchDims_value = axis >= 0 ? (size_t)axis >= shape.size() ? -1 : axis + : -(size_t)axis > shape.size() ? -1 + : axis; } void TearDown() override { CLEAR_SUBCASE() } @@ -79,7 +71,6 @@ class GatherTest : public KernelTest, protected: runtime_tensor input; runtime_tensor indices; - runtime_tensor batchDims; int64_t batchDims_value; }; @@ -103,7 +94,7 @@ TEST_P(GatherTest, gather) { // actual auto output = - kernels::stackvm::gather(input.impl(), batchDims.impl(), indices.impl()) + kernels::stackvm::gather(batchDims_value, input.impl(), indices.impl()) .expect("gather failed"); runtime_tensor actual(output.as().expect("as tensor failed")); diff --git a/tests/kernels/test_layer_norm.cpp b/tests/kernels/test_layer_norm.cpp index cc8a2696bd..5591bfe57d 100644 --- a/tests/kernels/test_layer_norm.cpp +++ b/tests/kernels/test_layer_norm.cpp @@ -106,8 +106,8 @@ TEST_P(LayerNormTest, layer_norm) { // actual auto output = - kernels::stackvm::layer_norm((int32_t)axis_value, eps, input.impl(), - scale.impl(), b.impl()) + kernels::stackvm::layer_norm((int32_t)axis_value, eps, false, + input.impl(), scale.impl(), b.impl()) .expect("layer_norm failed"); runtime_tensor actual(output.as().expect("as tensor failed")); diff --git a/tests/onnx_test_runner.py b/tests/onnx_test_runner.py index d1299bb234..dbde85f780 100644 --- a/tests/onnx_test_runner.py +++ b/tests/onnx_test_runner.py @@ -13,7 +13,7 @@ # limitations under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel -from onnx import version_converter, helper +from onnx import version_converter, helper, external_data_helper import onnxsim import onnxruntime as ort import onnx @@ -61,8 +61,16 @@ def run(self, model_file): if self.case_dir != os.path.dirname(model_file): new_file = os.path.join(self.case_dir, 'test.onnx') shutil.copy(model_file, new_file) - if os.path.exists(model_file + "_data"): - shutil.copy(model_file + "_data", self.case_dir) + for tensor in external_data_helper._get_all_tensors(onnx.load(model_file, load_external_data=False)): + if external_data_helper.uses_external_data(tensor): + info = external_data_helper.ExternalDataInfo(tensor) + file_location = external_data_helper._sanitize_path(info.location) + external_data_src_path = os.path.join( + os.path.dirname(model_file), file_location) + external_data_dst_path = os.path.join( + self.case_dir, file_location) + if not os.path.exists(external_data_dst_path): + os.symlink(external_data_src_path, external_data_dst_path) model_file = new_file if not self.inputs: @@ -176,7 +184,7 @@ def is_dynamic(output): outputs = onnx_model.graph.output self.dynamic = any(is_dynamic(output) for output in outputs) # make a static model for infer output - if self.dynamic: + if self.dynamic and onnx_model.ByteSize() < 2147483648: input_shapes = list(map(lambda input: {input['name']: input['shape']}, self.inputs)) input_shapes = dict(ChainMap(*input_shapes)) (onnx_model, _) = onnxsim.simplify(onnx_model, input_shapes=input_shapes) diff --git a/third_party/onnx/packages.lock.json b/third_party/onnx/packages.lock.json index 0bdbf312bd..207b93b556 100644 --- a/third_party/onnx/packages.lock.json +++ b/third_party/onnx/packages.lock.json @@ -16,17 +16,17 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" } } } diff --git a/third_party/tflite/packages.lock.json b/third_party/tflite/packages.lock.json index 73d7544eab..325adfa712 100644 --- a/third_party/tflite/packages.lock.json +++ b/third_party/tflite/packages.lock.json @@ -10,17 +10,17 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" } } } diff --git a/toolchains/k230_cpu.linux.toolchain.cmake b/toolchains/k230_cpu.linux.toolchain.cmake new file mode 100644 index 0000000000..4bd9aba679 --- /dev/null +++ b/toolchains/k230_cpu.linux.toolchain.cmake @@ -0,0 +1,33 @@ +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +if(DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) +endif() + +if(NOT RISCV_ROOT_PATH) + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") +endif() + +set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") +set(CMAKE_C_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-musl-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-musl-g++") +set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-musl") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(ENABLE_VULKAN_RUNTIME OFF) +set(ENABLE_OPENMP OFF) +set(ENABLE_HALIDE OFF) +set(DEFAULT_BUILTIN_RUNTIMES OFF) +set(DEFAULT_SHARED_RUNTIME_TENSOR_PLATFORM_IMPL ON) +set(BUILD_BENCHMARK OFF) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64imafdcv -mabi=lp64d -mcmodel=medany") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64imafdcv -mabi=lp64d -mcmodel=medany") +set(CMAKE_EXE_LINKER_FLAGS "-T ${K230_LINUX_SDK_DIR}/src/big/rt-smart/userapps/linker_scripts/riscv64/link.lds --static") + +set(BUILDING_RUNTIME ON) +set(ENABLE_CPU_RUNTIME ON) +set(BUILD_SHARED_LIBS OFF) \ No newline at end of file diff --git a/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs b/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs index 7495791550..8e2aa7cfa2 100644 --- a/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs +++ b/tools/Nncase.SourceGenerator/Pattern/PatternGenerator.cs @@ -85,7 +85,7 @@ select Parameter(Identifier(f.Name.ToLower())) // var x = name_params[0]; statements.Add(ParseStatement(@$"return new( new OpPattern<{cand.Op.ToDisplayString()}>(x => {condition}, {(name_params[0] != null ? "target_name" : "null")}), -new VArgsPattern (new[]{{ {inputs} }}, null), +new VArgsPattern (new Pattern[]{{ {inputs} }}, null), {(name_params[1] != null ? "call_name" : "null")});"). WithLeadingTrivia(ElasticTab). WithTrailingTrivia(ElasticLineFeed)); @@ -125,7 +125,7 @@ select Parameter(Identifier(f.Name.ToLower())) // 1.3 build method return statements.Add(ParseStatement(@$"return new( new OpPattern<{cand.Op.ToDisplayString()}>(condition, {(name_params[0] != null ? "target_name" : "null")}), -new VArgsPattern( new [] {{ {inputs} }}, null ), +new VArgsPattern( new Pattern[] {{ {inputs} }}, null ), {(name_params[1] != null ? "call_name" : "null")});"). WithLeadingTrivia(ElasticTab). WithTrailingTrivia(ElasticLineFeed)); diff --git a/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs b/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs index 57cf983aba..f2af3c5a2d 100644 --- a/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs +++ b/tools/Nncase.SourceGenerator/Rule/RuleGenerator.cs @@ -213,6 +213,7 @@ private void Execute(SourceProductionContext context, ImmutableArray(method)) .WithAttributeLists(new SyntaxList() { }) diff --git a/tools/Nncase.SourceGenerator/packages.lock.json b/tools/Nncase.SourceGenerator/packages.lock.json index 50bdccb9fa..0430a3e081 100644 --- a/tools/Nncase.SourceGenerator/packages.lock.json +++ b/tools/Nncase.SourceGenerator/packages.lock.json @@ -34,11 +34,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.CodeAnalysis.Common": { @@ -62,8 +62,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", diff --git a/tools/stackvm_gen/IsaGen/packages.lock.json b/tools/stackvm_gen/IsaGen/packages.lock.json index 6440c55973..fd04d9883e 100644 --- a/tools/stackvm_gen/IsaGen/packages.lock.json +++ b/tools/stackvm_gen/IsaGen/packages.lock.json @@ -27,11 +27,11 @@ }, "StyleCop.Analyzers": { "type": "Direct", - "requested": "[1.2.0-beta.507, )", - "resolved": "1.2.0-beta.507", - "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==", + "requested": "[1.2.0-beta.435, )", + "resolved": "1.2.0-beta.435", + "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==", "dependencies": { - "StyleCop.Analyzers.Unstable": "1.2.0.507" + "StyleCop.Analyzers.Unstable": "1.2.0.435" } }, "Microsoft.AspNetCore.Mvc.Razor.Extensions": { @@ -169,8 +169,8 @@ }, "StyleCop.Analyzers.Unstable": { "type": "Transitive", - "resolved": "1.2.0.507", - "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw==" + "resolved": "1.2.0.435", + "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg==" }, "System.Buffers": { "type": "Transitive", @@ -238,6 +238,7 @@ "Microsoft.Extensions.Options": "[6.0.0, )", "Microsoft.Toolkit.HighPerformance": "[7.1.1, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", + "System.CommandLine": "[2.0.0-beta4.22272.1, )", "System.Reactive": "[5.0.0, )" } }, @@ -312,6 +313,12 @@ "System.Runtime.CompilerServices.Unsafe": "5.0.0" } }, + "System.CommandLine": { + "type": "CentralTransitive", + "requested": "[2.0.0-beta4.22272.1, )", + "resolved": "2.0.0-beta4.22272.1", + "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg==" + }, "System.Reactive": { "type": "CentralTransitive", "requested": "[5.0.0, )",