diff --git a/.github/workflows/compiler-build.yml b/.github/workflows/compiler-build.yml
index ab9877f341..4806a15baf 100644
--- a/.github/workflows/compiler-build.yml
+++ b/.github/workflows/compiler-build.yml
@@ -172,7 +172,7 @@ jobs:
- {name: x86_64-windows, os: windows-latest, shell: bash}
env:
- VULKANSDK_VER: 1.2.182.0
+ VULKANSDK_VER: 1.3.268.0
steps:
- uses: actions/checkout@v3
@@ -211,8 +211,8 @@ jobs:
- name: Set up test environment (Linux)
run: |
- wget https://sdk.lunarg.com/sdk/download/${VULKANSDK_VER}/linux/vulkansdk-linux-x86_64-${VULKANSDK_VER}.tar.gz -O vulkansdk.tar.gz
- tar xf vulkansdk.tar.gz
+ wget https://sdk.lunarg.com/sdk/download/${VULKANSDK_VER}/linux/vulkansdk-linux-x86_64-${VULKANSDK_VER}.tar.xz -O vulkansdk.tar.xz
+ tar xf vulkansdk.tar.xz
sudo cp -P ${VULKANSDK_VER}/x86_64/lib/libvulkan.so* /usr/local/lib/
wget https://github.com/sunnycase/swiftshader/releases/download/v1.0/swiftshader-ubuntu-18.04-x86_64.zip -O swiftshader.zip
unzip swiftshader.zip
@@ -225,8 +225,8 @@ jobs:
- name: Set up test environment (Windows)
shell: pwsh
run: |
- Invoke-WebRequest -Uri https://sdk.lunarg.com/sdk/download/${env:VULKANSDK_VER}/windows/VulkanSDK-${env:VULKANSDK_VER}-Installer.exe -O VulkanSDK-Installer.exe
- .\VulkanSDK-Installer.exe /S
+ # Invoke-WebRequest -Uri https://sdk.lunarg.com/sdk/download/${env:VULKANSDK_VER}/windows/VulkanSDK-${env:VULKANSDK_VER}-Installer.exe -O VulkanSDK-Installer.exe
+ # .\VulkanSDK-Installer.exe /S
Invoke-WebRequest -Uri https://github.com/sunnycase/swiftshader/releases/download/v1.0/swiftshader-windows-2019-x86_64.zip -OutFile swiftshader.zip
Expand-Archive swiftshader.zip
Copy-Item swiftshader\lib\vk_swiftshader_icd.json swiftshader\bin\
diff --git a/.github/workflows/compiler-python-release.yml b/.github/workflows/compiler-python-release.yml
index 9176290bd7..5e0db927a0 100644
--- a/.github/workflows/compiler-python-release.yml
+++ b/.github/workflows/compiler-python-release.yml
@@ -58,7 +58,7 @@ jobs:
- {name: x86_64-windows, os: windows-latest, arch: x64}
env:
- VULKANSDK_VER: 1.2.182.0
+ VULKANSDK_VER: 1.3.268.0
steps:
- uses: actions/checkout@v3
diff --git a/.gitignore b/.gitignore
index 2dfd485fbe..5b1e72c18f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -68,7 +68,7 @@ artifacts/
*.pidb
*.svclog
*.scc
-
+*.bin
# Chutzpah Test files
_Chutzpah*
@@ -306,4 +306,4 @@ cmake-build-*
*gmodel_dump_dir*
*.ipynb_checkpoints*
# Auto generated files
-# generated/
\ No newline at end of file
+# generated/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b97e01e62c..7ac7539a47 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,7 +12,6 @@ endif()
if(NOT DEFINED NNCASE_VERSION_SUFFIX)
find_package (Git)
-
execute_process(
COMMAND ${GIT_EXECUTABLE} describe --always --dirty --tag
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
@@ -274,5 +273,5 @@ if(BUILD_TESTING)
endif()
# Modules
-#add_subdirectory(modules/k210)
+
#add_subdirectory(modules/vulkan)
diff --git a/Directory.Packages.props b/Directory.Packages.props
index 3f2166020f..8e5ef00da3 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -57,8 +57,9 @@
-
-
+
+
+
diff --git a/docs/MixQuant.md b/docs/MixQuant.md
index c70de011f0..76197cc40d 100644
--- a/docs/MixQuant.md
+++ b/docs/MixQuant.md
@@ -16,11 +16,13 @@ compiler.use_ptq(ptq_options)
```python
ptq_options.quant_scheme = ""
+ptq_options.quant_scheme_strict_mode = False
ptq_options.export_quant_scheme = False
ptq_options.export_weight_range_by_channel = False
```
* **quant_scheme:导入量化参数配置文件的路径**
+* **quant_scheme_strict_mode:是否严格按照quant_scheme执行量化**
* **export_quant_scheme:是否导出量化参数配置文件**
* **export_weight_range_by_channel:是否导出** `bychannel`形式的weights量化参数,为了保证量化效果,该参数建议设置为 `True`
@@ -36,6 +38,7 @@ compile_options.dump_ir = True
```python
ptq_options.quant_scheme = ""
+ptq_options.quant_scheme_strict_mode = False
ptq_options.export_quant_scheme = True
ptq_options.export_weight_range_by_channel = True
```
@@ -108,6 +111,7 @@ ptq_options.export_weight_range_by_channel = True
```python
ptq_options.quant_scheme = "./QuantScheme.json" # path to your 'QuantScheme.json'
+ptq_options.quant_scheme_strict_mode = False # Whether to strictly follow quant_scheme for quantification
ptq_options.export_quant_scheme = False
ptq_options.export_weight_range_by_channel = False # whatever
```
diff --git a/docs/USAGE_v2.md b/docs/USAGE_v2.md
index 68944d7bbf..a7aa4d3809 100644
--- a/docs/USAGE_v2.md
+++ b/docs/USAGE_v2.md
@@ -228,6 +228,7 @@ PTQTensorOptions类, 用于配置nncase PTQ选项,各属性说明如下
| dump_quant_error | bool | 否 | 是否生成量化损失,默认为False。在 `dump_ir=True`时生效 |
| dump_quant_error_symmetric_for_signed | bool | 否 | 是否生成使用范围对称的量化损失,默认为True。在 `dump_ir=True`时生效 |
| quant_scheme | string | 否 | 量化配置文件路径,默认为“ ”。在 `dump_ir=True`时生效 |
+| quant_scheme_strict_mode | bool | 否 | 是否严格按照quant_scheme执行量化,默认为False。在 `quant_scheme`不为空时生效 |
| export_quant_scheme | bool | 否 | 是否导出量化配置文件,默认为False。在 `dump_ir=True`时生效 |
| export_weight_range_by_channel | bool | 否 | 导出量化配置文件时,是否按照channel统计权重的范围,默认为False。在 `dump_ir=True`时生效 |
diff --git a/docs/USAGE_v2_EN.md b/docs/USAGE_v2_EN.md
index 5800ccfc7e..8a19714c32 100644
--- a/docs/USAGE_v2_EN.md
+++ b/docs/USAGE_v2_EN.md
@@ -226,6 +226,7 @@ PTQTensorOptions is used to configure PTQ options. The details of all attributes
| dump_quant_error | bool | N | Specify whether dump quantification error, False by default. The parameters following worked when `dump_ir=True`. |
| dump_quant_error_symmetric_for_signed | bool | N | Specify whether dump quantification error by symmetric for signed number,True by default. |
| quant_scheme | string | N | specify the path of quantification scheme file,"" by default. |
+| quant_scheme_strict_mode | bool | N | Specify whether strictly follow quant_scheme for quantification, False by default. |
| export_quant_scheme | bool | N | Specify whether export quantification scheme, False by default. |
| export_weight_range_by_channel | bool | N | Specify whether export weights range by channel, False by default. |
diff --git a/examples/audio/tts_1.wav b/examples/audio/tts_1.wav
new file mode 100755
index 0000000000..029d3c4f8d
Binary files /dev/null and b/examples/audio/tts_1.wav differ
diff --git a/examples/audio/tts_2.wav b/examples/audio/tts_2.wav
new file mode 100755
index 0000000000..864fe17bbb
Binary files /dev/null and b/examples/audio/tts_2.wav differ
diff --git a/examples/audio/tts_3.wav b/examples/audio/tts_3.wav
new file mode 100755
index 0000000000..deb31e9a59
Binary files /dev/null and b/examples/audio/tts_3.wav differ
diff --git a/examples/user_guide/k230_simulate-EN.ipynb b/examples/user_guide/k230_simulate-EN.ipynb
index a8630394df..b4c5aa2b91 100644
--- a/examples/user_guide/k230_simulate-EN.ipynb
+++ b/examples/user_guide/k230_simulate-EN.ipynb
@@ -150,6 +150,7 @@
" # mix quantize options\n",
" # more details in docs/MixQuant.md\n",
" ptq_options.quant_scheme = \"\"\n",
+ " ptq_options.quant_scheme_strict_mode = False\n",
" ptq_options.export_quant_scheme = False\n",
" ptq_options.export_weight_range_by_channel = False\n",
" ############################################\n",
diff --git a/examples/user_guide/k230_simulate-ZH.ipynb b/examples/user_guide/k230_simulate-ZH.ipynb
index e9b9dc5329..3a099a1ea3 100644
--- a/examples/user_guide/k230_simulate-ZH.ipynb
+++ b/examples/user_guide/k230_simulate-ZH.ipynb
@@ -150,6 +150,7 @@
" # mix quantize options\n",
" # more details in docs/MixQuant.md\n",
" ptq_options.quant_scheme = \"\"\n",
+ " ptq_options.quant_scheme_strict_mode = False\n",
" ptq_options.export_quant_scheme = False\n",
" ptq_options.export_weight_range_by_channel = False\n",
" ############################################\n",
diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json
deleted file mode 100644
index 5dd73e5bda..0000000000
--- a/modules/Nncase.Modules.CPU/packages.lock.json
+++ /dev/null
@@ -1,295 +0,0 @@
-{
- "version": 2,
- "dependencies": {
- "net7.0": {
- "StyleCop.Analyzers": {
- "type": "Direct",
- "requested": "[1.2.0-beta.435, )",
- "resolved": "1.2.0-beta.435",
- "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
- "dependencies": {
- "StyleCop.Analyzers.Unstable": "1.2.0.435"
- }
- },
- "Google.OrTools.runtime.linux-arm64": {
- "type": "Transitive",
- "resolved": "9.4.1874",
- "contentHash": "Z46ndZcZa2Lt5b76xU9kxVYbPLg/LfuMufhUVsu3Qo3L7Bibf7WXd9j7RRldjnuv8RIHWTqb0b+2FwwMxs0c5A=="
- },
- "Google.OrTools.runtime.linux-x64": {
- "type": "Transitive",
- "resolved": "9.4.1874",
- "contentHash": "zGeDb8FuvP9HXjrsU7krVXtSDFpR+DUGNEsH51k94jL9tzf2vWYI8+WUBRHZ/cGe50dpLr+vIjfcNo3gFyOpkQ=="
- },
- "Google.OrTools.runtime.osx-arm64": {
- "type": "Transitive",
- "resolved": "9.4.1874",
- "contentHash": "Wo0ZfDaH6DhiQw0jZm4HWJm/oPGPpWNwOLUz+EYaoH3MLtocSxItHGQj/Ta3HyhXnYNOv+TliAH8L+8RCXu/2w=="
- },
- "Google.OrTools.runtime.osx-x64": {
- "type": "Transitive",
- "resolved": "9.4.1874",
- "contentHash": "IAfGgKR1og6vU87axK1d37Ak/4jy8B4NMoElovG/KZc/2UY+cJEAQDA709UMegtI4lBhuxTWFNUiHQYmRIB9yQ=="
- },
- "Google.OrTools.runtime.win-x64": {
- "type": "Transitive",
- "resolved": "9.4.1874",
- "contentHash": "fUs5qDnZA6itygolcX6nPuachQkY9CVvQbakIzIiRAWKcaj8umQAbFdGwbkyzp3qp34BKW5mtPVsmMyfQBBjOQ=="
- },
- "libortki": {
- "type": "Transitive",
- "resolved": "0.0.2",
- "contentHash": "svfuG5mxGY/QC/5DVheHOCELmdSP90RtxQ73j23KarPXZ9ZXW+7v1l5J77hGDyQbEh1BGrnGgKBlyn76RauGHg==",
- "dependencies": {
- "libortki-linux": "0.0.2",
- "libortki-osx": "0.0.2",
- "libortki-osx-arm64": "0.0.2",
- "libortki-win": "0.0.2"
- }
- },
- "libortki-linux": {
- "type": "Transitive",
- "resolved": "0.0.2",
- "contentHash": "b04LWD4lgGy60tys3hPFhnUpgWDM6dN5r1PI7GOcPj8VupXCaI70LKNQ5/5twbDE6rkowOGanVTw0S2wBGBqBQ=="
- },
- "libortki-osx": {
- "type": "Transitive",
- "resolved": "0.0.2",
- "contentHash": "O6Q9GLULkDkZEPAZJVKLPH0ROXGVOE7BxuddgOcHNK2oiTEM7wIRnzp2OIlYgLpaOLyxJMisbGOhtWgdzt2Wng=="
- },
- "libortki-osx-arm64": {
- "type": "Transitive",
- "resolved": "0.0.2",
- "contentHash": "4Qn2dirJmRicnUG945oWpq7HVGwgqCKKxYPMISv/MRvmpZBbXrZ1cVvRaF8WwTu4XXgfKTa1sLv+i8zLifUMeQ=="
- },
- "libortki-win": {
- "type": "Transitive",
- "resolved": "0.0.2",
- "contentHash": "HAoROgAKn8XBun11X43HZuspKlo5JGy8/OYw5IUPo7FVh5TCaPrLjGmyGYYZ2dqLlv31yv/b6s254PIRGn95cA=="
- },
- "Microsoft.Extensions.Configuration.Abstractions": {
- "type": "Transitive",
- "resolved": "6.0.0",
- "contentHash": "qWzV9o+ZRWq+pGm+1dF+R7qTgTYoXvbyowRoBxQJGfqTpqDun2eteerjRQhq5PQ/14S+lqto3Ft4gYaRyl4rdQ==",
- "dependencies": {
- "Microsoft.Extensions.Primitives": "6.0.0"
- }
- },
- "Microsoft.Extensions.DependencyInjection.Abstractions": {
- "type": "Transitive",
- "resolved": "6.0.0",
- "contentHash": "xlzi2IYREJH3/m6+lUrQlujzX8wDitm4QGnUu6kUXTQAWPuZY8i+ticFJbzfqaetLA6KR/rO6Ew/HuYD+bxifg=="
- },
- "Microsoft.Extensions.FileProviders.Abstractions": {
- "type": "Transitive",
- "resolved": "6.0.0",
- "contentHash": "0pd4/fho0gC12rQswaGQxbU34jOS1TPS8lZPpkFCH68ppQjHNHYle9iRuHeev1LhrJ94YPvzcRd8UmIuFk23Qw==",
- "dependencies": {
- "Microsoft.Extensions.Primitives": "6.0.0"
- }
- },
- "Microsoft.Extensions.Primitives": {
- "type": "Transitive",
- "resolved": "6.0.0",
- "contentHash": "9+PnzmQFfEFNR9J2aDTfJGGupShHjOuGw4VUv+JB044biSHrnmCIMD+mJHmb2H7YryrfBEXDurxQ47gJZdCKNQ==",
- "dependencies": {
- "System.Runtime.CompilerServices.Unsafe": "6.0.0"
- }
- },
- "NetFabric.Hyperlinq.Abstractions": {
- "type": "Transitive",
- "resolved": "1.3.0",
- "contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA=="
- },
- "StyleCop.Analyzers.Unstable": {
- "type": "Transitive",
- "resolved": "1.2.0.435",
- "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
- },
- "System.Buffers": {
- "type": "Transitive",
- "resolved": "4.5.1",
- "contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg=="
- },
- "System.Runtime.CompilerServices.Unsafe": {
- "type": "Transitive",
- "resolved": "6.0.0",
- "contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg=="
- },
- "nncase.codegen": {
- "type": "Project",
- "dependencies": {
- "Extension.Mathematics": "[1.2.12, )",
- "Nncase.Core": "[1.0.0, )",
- "Nncase.IO": "[1.0.0, )"
- }
- },
- "nncase.core": {
- "type": "Project",
- "dependencies": {
- "DryIoc.dll": "[5.3.1, )",
- "GiGraph.Dot": "[2.0.0, )",
- "Microsoft.Extensions.Hosting.Abstractions": "[6.0.0, )",
- "Microsoft.Extensions.Logging.Abstractions": "[6.0.0, )",
- "Microsoft.Extensions.Options": "[6.0.0, )",
- "Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
- "NetFabric.Hyperlinq": "[3.0.0-beta48, )",
- "System.Reactive": "[5.0.0, )"
- }
- },
- "nncase.diagnostics": {
- "type": "Project",
- "dependencies": {
- "Nncase.Core": "[1.0.0, )"
- }
- },
- "nncase.egraph": {
- "type": "Project",
- "dependencies": {
- "GiGraph.Dot": "[2.0.0, )",
- "Google.OrTools": "[9.4.1874, )",
- "NetFabric.Hyperlinq": "[3.0.0-beta48, )",
- "Nncase.Core": "[1.0.0, )",
- "Nncase.Evaluator": "[1.0.0, )",
- "Singulink.Collections.Weak": "[1.0.2, )"
- }
- },
- "nncase.evaluator": {
- "type": "Project",
- "dependencies": {
- "Nncase.Core": "[1.0.0, )",
- "OrtKISharp": "[0.0.2, )"
- }
- },
- "nncase.graph": {
- "type": "Project",
- "dependencies": {
- "Nncase.Core": "[1.0.0, )",
- "Nncase.Evaluator": "[1.0.0, )"
- }
- },
- "nncase.io": {
- "type": "Project"
- },
- "nncase.modules.stackvm": {
- "type": "Project",
- "dependencies": {
- "Nncase.CodeGen": "[1.0.0, )",
- "Nncase.Passes": "[1.0.0, )"
- }
- },
- "nncase.passes": {
- "type": "Project",
- "dependencies": {
- "Nncase.Core": "[1.0.0, )",
- "Nncase.EGraph": "[1.0.0, )",
- "Nncase.Evaluator": "[1.0.0, )",
- "Nncase.Graph": "[1.0.0, )"
- }
- },
- "DryIoc.dll": {
- "type": "CentralTransitive",
- "requested": "[5.3.1, )",
- "resolved": "5.3.1",
- "contentHash": "E3zclUh2CIBks1t2uBD1k18pyGFJ1YSKCrbCDbB7qCdl2RAB+k68AyDpjeplhF1ot2XPV82AgyCWBXMf0ggL1g=="
- },
- "Extension.Mathematics": {
- "type": "CentralTransitive",
- "requested": "[1.2.12, )",
- "resolved": "1.2.12",
- "contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg=="
- },
- "GiGraph.Dot": {
- "type": "CentralTransitive",
- "requested": "[2.0.0, )",
- "resolved": "2.0.0",
- "contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ=="
- },
- "Google.OrTools": {
- "type": "CentralTransitive",
- "requested": "[9.4.1874, )",
- "resolved": "9.4.1874",
- "contentHash": "jqRoI+pYlym+fhoU25u+13oti5h+772bllQ9zDitTVMclDXVTiG6pxzvmYO74wnADBMdpb2SQlgiNQxoNk5dlA==",
- "dependencies": {
- "Google.OrTools.runtime.linux-arm64": "9.4.1874",
- "Google.OrTools.runtime.linux-x64": "9.4.1874",
- "Google.OrTools.runtime.osx-arm64": "9.4.1874",
- "Google.OrTools.runtime.osx-x64": "9.4.1874",
- "Google.OrTools.runtime.win-x64": "9.4.1874",
- "Google.Protobuf": "3.19.4"
- }
- },
- "Google.Protobuf": {
- "type": "CentralTransitive",
- "requested": "[3.19.4, )",
- "resolved": "3.19.4",
- "contentHash": "fd07/ykL4O4FhqrZIELm5lmiyOHfdPg9+o+hWr6tcfRdS7tHXnImg/2wtogLzlW2eEmr0J7j6ZrZvaWOLiJbxQ=="
- },
- "Microsoft.Extensions.Hosting.Abstractions": {
- "type": "CentralTransitive",
- "requested": "[6.0.0, )",
- "resolved": "6.0.0",
- "contentHash": "GcT5l2CYXL6Sa27KCSh0TixsRfADUgth+ojQSD5EkzisZxmGFh7CwzkcYuGwvmXLjr27uWRNrJ2vuuEjMhU05Q==",
- "dependencies": {
- "Microsoft.Extensions.Configuration.Abstractions": "6.0.0",
- "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0",
- "Microsoft.Extensions.FileProviders.Abstractions": "6.0.0"
- }
- },
- "Microsoft.Extensions.Logging.Abstractions": {
- "type": "CentralTransitive",
- "requested": "[6.0.0, )",
- "resolved": "6.0.0",
- "contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
- },
- "Microsoft.Extensions.Options": {
- "type": "CentralTransitive",
- "requested": "[6.0.0, )",
- "resolved": "6.0.0",
- "contentHash": "dzXN0+V1AyjOe2xcJ86Qbo233KHuLEY0njf/P2Kw8SfJU+d45HNS2ctJdnEnrWbM9Ye2eFgaC5Mj9otRMU6IsQ==",
- "dependencies": {
- "Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0",
- "Microsoft.Extensions.Primitives": "6.0.0"
- }
- },
- "Microsoft.Toolkit.HighPerformance": {
- "type": "CentralTransitive",
- "requested": "[7.1.1, )",
- "resolved": "7.1.1",
- "contentHash": "TRnvDpZPXO30hTOtjfLw6Y9BtTKtTpzk9lefeh4RMCaUihWrVKQR454nYH4/mMJAh+LXqfAPyk0kfkJs0Amopw=="
- },
- "NetFabric.Hyperlinq": {
- "type": "CentralTransitive",
- "requested": "[3.0.0-beta48, )",
- "resolved": "3.0.0-beta48",
- "contentHash": "oYUhXvxNS8bBJWqNkvx5g8y0P/0LtyqS2pN0w4OWjVDNWEpLbdbvPy9w/9z1n2PrqIjX3jxUsEnoCmxxGnI3gw==",
- "dependencies": {
- "NetFabric.Hyperlinq.Abstractions": "1.3.0",
- "System.Buffers": "4.5.1",
- "System.Runtime.CompilerServices.Unsafe": "5.0.0"
- }
- },
- "OrtKISharp": {
- "type": "CentralTransitive",
- "requested": "[0.0.2, )",
- "resolved": "0.0.2",
- "contentHash": "q8j0yR5836Zhv9WB9BFkQt1UaEFyibq8bqJcTiULlILF6/sz8z7Wy2N8sgYdDKsdW25zncIz7j6IDbKM5ynePg==",
- "dependencies": {
- "libortki": "0.0.2"
- }
- },
- "Singulink.Collections.Weak": {
- "type": "CentralTransitive",
- "requested": "[1.0.2, )",
- "resolved": "1.0.2",
- "contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
- },
- "System.Reactive": {
- "type": "CentralTransitive",
- "requested": "[5.0.0, )",
- "resolved": "5.0.0",
- "contentHash": "erBZjkQHWL9jpasCE/0qKAryzVBJFxGHVBAvgRN1bzM0q2s1S4oYREEEL0Vb+1kA/6BKb5FjUZMp5VXmy+gzkQ=="
- }
- }
- }
-}
\ No newline at end of file
diff --git a/modules/Nncase.Modules.K210/CodeGen/K210/KPULinkedModule.cs b/modules/Nncase.Modules.K210/CodeGen/K210/KPULinkedModule.cs
index 7018e7c7bd..4c4aeb20da 100644
--- a/modules/Nncase.Modules.K210/CodeGen/K210/KPULinkedModule.cs
+++ b/modules/Nncase.Modules.K210/CodeGen/K210/KPULinkedModule.cs
@@ -17,8 +17,8 @@ public KPULinkedModule(IReadOnlyList functions, byte[] text, byt
Functions = functions;
Sections = new[]
{
- new LinkedSection(text, ".text", 0, 8, (uint)text.Length),
- new LinkedSection(rdata, ".rdata", 0, 8, (uint)(rdata?.Length ?? 0)),
+ new LinkedSection(text, ".text", 0, 8, (ulong)text.Length),
+ new LinkedSection(rdata, ".rdata", 0, 8, (ulong)(rdata?.Length ?? 0)),
};
}
diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs
index 04cbd48a7b..314930a8a9 100644
--- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs
+++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs
@@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/18 下午5:04:31 +08:00. */
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */
using System;
using System.Collections.Generic;
@@ -59,7 +59,7 @@ private void EmitTensorCall(Op op)
Emitter.T.L2Normalization();
break;
case IR.NN.LayerNorm top:
- Emitter.T.LayerNorm(top.Axis, top.Epsilon);
+ Emitter.T.LayerNorm(top.Axis, top.Epsilon, top.UseMean);
break;
case IR.NN.LeakyRelu top:
Emitter.T.LeakyRelu();
@@ -176,7 +176,7 @@ private void EmitTensorCall(Op op)
Emitter.T.Cast(top.NewType, top.CastMode);
break;
case IR.Tensors.Concat top:
- Emitter.T.Concat();
+ Emitter.T.Concat(top.Axis);
break;
case IR.Tensors.ConstantOfShape top:
Emitter.T.ConstantOfShape();
@@ -191,7 +191,7 @@ private void EmitTensorCall(Op op)
Emitter.T.Flatten();
break;
case IR.Tensors.Gather top:
- Emitter.T.Gather();
+ Emitter.T.Gather(top.Axis);
break;
case IR.Tensors.GatherElements top:
Emitter.T.GatherElements();
@@ -205,9 +205,6 @@ private void EmitTensorCall(Op op)
case IR.Tensors.IndexOf top:
Emitter.T.IndexOf();
break;
- case IR.Tensors.LSTM top:
- Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
- break;
case IR.Tensors.Prod top:
Emitter.T.Prod();
break;
@@ -289,6 +286,9 @@ private void EmitTensorCall(Op op)
case IR.ShapeExpr.UnsqueezeShape top:
Emitter.T.UnsqueezeShape();
break;
+ case IR.RNN.LSTM top:
+ Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
+ break;
case IR.Random.Normal top:
Emitter.T.Normal(top.Type);
break;
diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs
index 6e2184c5ea..b6512a344c 100644
--- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs
+++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs
@@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */
using System;
using System.Collections.Generic;
@@ -723,10 +723,11 @@ public void Compare(CompareOp compareOp)
}
///.
- public void Concat()
+ public void Concat(int axis)
{
_emitter.Write((byte)100);
_emitter.Write((ushort)11);
+ _emitter.Write(axis);
}
///.
@@ -841,10 +842,11 @@ public void Flatten()
}
///.
- public void Gather()
+ public void Gather(int axis)
{
_emitter.Write((byte)100);
_emitter.Write((ushort)27);
+ _emitter.Write(axis);
}
///.
@@ -925,12 +927,13 @@ public void L2Normalization()
}
///.
- public void LayerNorm(int axis, float epsilon)
+ public void LayerNorm(int axis, float epsilon, bool useMean)
{
_emitter.Write((byte)100);
_emitter.Write((ushort)39);
_emitter.Write(axis);
_emitter.Write(epsilon);
+ _emitter.Write(useMean);
}
///.
diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMLinkedModule.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMLinkedModule.cs
index 3bfb985fd8..93c8d9c6b6 100644
--- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMLinkedModule.cs
+++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMLinkedModule.cs
@@ -17,9 +17,9 @@ public StackVMLinkedModule(IReadOnlyList functions, Stream text,
Functions = functions;
Sections = new[]
{
- new LinkedSection(text, ".text", 0, 8, (uint)text.Length),
- new LinkedSection(rdata, ".rdata", 0, 8, (uint)(rdata?.Length ?? 0)),
- new LinkedSection(custom_calls, ".custom_calls", 0, 8, (uint)(custom_calls?.Length ?? 0)),
+ new LinkedSection(text, ".text", 0, 8, (ulong)text.Length),
+ new LinkedSection(rdata, ".rdata", 0, 8, (ulong)(rdata?.Length ?? 0)),
+ new LinkedSection(custom_calls, ".custom_calls", 0, 8, (ulong)(custom_calls?.Length ?? 0)),
};
}
diff --git a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs b/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs
index ba77ae58f5..2f63e02be9 100644
--- a/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs
+++ b/modules/Nncase.Modules.StackVM/Targets/CPUTarget.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
+using System.CommandLine.Invocation;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
@@ -21,13 +22,15 @@ namespace Nncase.Targets;
///
public class CPUTarget : ITarget
{
- ///
- /// Gets kind.
- ///
- public static readonly string Kind = "cpu";
+ public const string Kind = "cpu";
string ITarget.Kind => Kind;
+ public (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser()
+ {
+ return (new System.CommandLine.Command(Kind), (_, _) => DefaultTargetCompileOptions.Instance);
+ }
+
///
public void ParseTargetDependentOptions(IConfigurationSection configure)
{
diff --git a/modules/Nncase.Modules.StackVM/packages.lock.json b/modules/Nncase.Modules.StackVM/packages.lock.json
index b69bdc40f8..4830d0bbf6 100644
--- a/modules/Nncase.Modules.StackVM/packages.lock.json
+++ b/modules/Nncase.Modules.StackVM/packages.lock.json
@@ -4,11 +4,11 @@
"net7.0": {
"StyleCop.Analyzers": {
"type": "Direct",
- "requested": "[1.2.0-beta.507, )",
- "resolved": "1.2.0-beta.507",
- "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
+ "requested": "[1.2.0-beta.435, )",
+ "resolved": "1.2.0-beta.435",
+ "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
- "StyleCop.Analyzers.Unstable": "1.2.0.507"
+ "StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Google.OrTools.runtime.linux-arm64": {
@@ -103,8 +103,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
- "resolved": "1.2.0.507",
- "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
+ "resolved": "1.2.0.435",
+ "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@@ -134,6 +134,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
+ "System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@@ -271,6 +272,12 @@
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
+ "System.CommandLine": {
+ "type": "CentralTransitive",
+ "requested": "[2.0.0-beta4.22272.1, )",
+ "resolved": "2.0.0-beta4.22272.1",
+ "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
+ },
"System.Reactive": {
"type": "CentralTransitive",
"requested": "[5.0.0, )",
diff --git a/nncase.sln b/nncase.sln
index b79fd74a04..77baf826c2 100644
--- a/nncase.sln
+++ b/nncase.sln
@@ -44,8 +44,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.IO", "src\Nncase.IO\
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Schedule", "src\Nncase.Schedule\Nncase.Schedule.csproj", "{8E0E0672-0F96-4EF1-BDCD-D31F96A3DF73}"
EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "targets", "targets", "{A2590531-71C5-4326-88DD-6A9DB2EF0A2B}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Targets", "src\Nncase.Targets\Nncase.Targets.csproj", "{56283378-06E3-4C6E-A8BF-7BD85C92D42C}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Simulator", "src\Nncase.Simulator\Nncase.Simulator.csproj", "{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}"
diff --git a/pyproject.toml b/pyproject.toml
index 04a03089f8..a8dc5783a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,9 +48,9 @@ before-all = [
"pip install conan==1.59",
"conan profile new default --detect",
"conan profile update settings.compiler.libcxx=libstdc++11 default",
- "curl -L https://sdk.lunarg.com/sdk/download/1.2.182.0/linux/vulkansdk-linux-x86_64-1.2.182.0.tar.gz --output vulkansdk.tar.gz",
- "tar xf vulkansdk.tar.gz",
- "cp -P 1.2.182.0/x86_64/lib/libvulkan.so* /usr/local/lib/"
+ "curl -L https://sdk.lunarg.com/sdk/download/1.3.268.0/linux/vulkansdk-linux-x86_64-1.3.268.0.tar.xz --output vulkansdk.tar.xz",
+ "tar xf vulkansdk.tar.xz",
+ "cp -P 1.3.268.0/x86_64/lib/libvulkan.so* /usr/local/lib/"
]
before-build = "pip install auditwheel"
repair-wheel-command = "LD_LIBRARY_PATH=/usr/lib64 auditwheel repair -w {dest_dir} {wheel} --exclude libvulkan.so.1,libgomp.so.1"
diff --git a/python/common/pystreambuf.h b/python/common/pystreambuf.h
index 178a041e0c..27581c1d75 100644
--- a/python/common/pystreambuf.h
+++ b/python/common/pystreambuf.h
@@ -1,6 +1,7 @@
// https://gist.github.com/asford/544323a5da7dddad2c9174490eb5ed06
#pragma once
+#include
#include
#include
diff --git a/python/nncase/__init__.py b/python/nncase/__init__.py
index 0663332e10..cef1b150bb 100644
--- a/python/nncase/__init__.py
+++ b/python/nncase/__init__.py
@@ -66,6 +66,7 @@ class PTQTensorOptions:
input_mean: float
input_std: float
quant_scheme: str
+ quant_scheme_strict_mode: bool
samples_count: int
cali_data: List[RuntimeTensor]
@@ -83,6 +84,7 @@ def __init__(self) -> None:
self.input_mean: float = 0.5
self.input_std: float = 0.5
self.quant_scheme: str = ""
+ self.quant_scheme_strict_mode: bool = False
self.samples_count: int = 5
self.cali_data: List[RuntimeTensor] = []
@@ -244,6 +246,7 @@ def use_ptq(self, ptq_dataset_options: PTQTensorOptions) -> None:
self._quantize_options.use_mix_quant = ptq_dataset_options.use_mix_quant
self._quantize_options.quant_scheme = ptq_dataset_options.quant_scheme
+ self._quantize_options.quant_scheme_strict_mode = ptq_dataset_options.quant_scheme_strict_mode
self._quantize_options.export_quant_scheme = ptq_dataset_options.export_quant_scheme
self._quantize_options.export_weight_range_by_channel = ptq_dataset_options.export_weight_range_by_channel
self._quantize_options.dump_quant_error = ptq_dataset_options.dump_quant_error
@@ -295,7 +298,7 @@ def _import_ncnn_module(self, model_param: bytes | io.RawIOBase, model_bin: byte
def check_target(target: str):
def test_target(target: str):
- return target in ["cpu", "k510", "k230"]
+ return target in ["cpu", "k510", "k230", "xpu"]
def target_exists(target: str):
return _nncase.Target.exists(target)
diff --git a/python/nncase/native/ffi.cpp b/python/nncase/native/ffi.cpp
index 8bd6bd6ba9..b8c99ed966 100644
--- a/python/nncase/native/ffi.cpp
+++ b/python/nncase/native/ffi.cpp
@@ -185,6 +185,11 @@ PYBIND11_MODULE(_nncase, m) {
py::overload_cast<>(&quantize_options::quant_scheme),
py::overload_cast(
&quantize_options::quant_scheme))
+ .def_property(
+ "quant_scheme_strict_mode",
+ py::overload_cast<>(&quantize_options::quant_scheme_strict_mode),
+ py::overload_cast(
+ &quantize_options::quant_scheme_strict_mode))
.def_property(
"export_quant_scheme",
py::overload_cast<>(&quantize_options::export_quant_scheme),
diff --git a/src/Native/include/nncase/compiler.h b/src/Native/include/nncase/compiler.h
index 6b7b33ef92..7339d0b546 100644
--- a/src/Native/include/nncase/compiler.h
+++ b/src/Native/include/nncase/compiler.h
@@ -199,6 +199,8 @@ typedef struct {
void (*quantize_options_set_quant_scheme)(
clr_object_handle_t quantize_options, const char *quant_scheme,
size_t quant_scheme_length);
+ void (*quantize_options_set_quant_scheme_strict_mode)(
+ clr_object_handle_t quantize_options, bool quant_scheme_strict_mode);
void (*quantize_options_set_export_quant_scheme)(
clr_object_handle_t quantize_options, bool export_quant_scheme);
void (*quantize_options_set_export_weight_range_by_channel)(
@@ -401,6 +403,12 @@ class quantize_options : public clr_object_base {
obj_.get(), value.data(), value.length());
}
+ bool quant_scheme_strict_mode() { return false; }
+ void quant_scheme_strict_mode(bool value) {
+ nncase_clr_api()->quantize_options_set_quant_scheme_strict_mode(
+ obj_.get(), value);
+ }
+
bool export_quant_scheme() { return false; }
void export_quant_scheme(bool value) {
nncase_clr_api()->quantize_options_set_export_quant_scheme(obj_.get(),
diff --git a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h
index e918f22f25..84cae4e387 100644
--- a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h
+++ b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h
@@ -1,5 +1,5 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
- * +08:00.
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
+ * +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -78,7 +78,7 @@ compare(runtime::stackvm::compare_op_t compare_op, value_t lhs, value_t rhs,
kernel_context &context = default_kernel_context());
NNCASE_API result
-concat(value_t input, value_t axis, value_t output = nullptr,
+concat(int32_t axis, value_t input, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result
@@ -157,7 +157,7 @@ flatten(value_t input, value_t axis, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result
-gather(value_t input, value_t axis, value_t index, value_t output = nullptr,
+gather(int32_t axis, value_t input, value_t index, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result
@@ -211,8 +211,8 @@ l2_normalization(value_t input, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result
-layer_norm(int32_t axis, float epsilon, value_t input, value_t scale,
- value_t bias, value_t output = nullptr,
+layer_norm(int32_t axis, float epsilon, bool use_mean, value_t input,
+ value_t scale, value_t bias, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result
diff --git a/src/Native/include/nncase/runtime/interpreter.h b/src/Native/include/nncase/runtime/interpreter.h
index fc91970657..a8f5814d34 100644
--- a/src/Native/include/nncase/runtime/interpreter.h
+++ b/src/Native/include/nncase/runtime/interpreter.h
@@ -73,6 +73,7 @@ class NNCASE_API interpreter {
options_dict &options() noexcept;
result find_module_by_id(size_t index) noexcept;
+ result find_id_by_module(runtime_module *module) noexcept;
/* V1 APIs */
diff --git a/src/Native/include/nncase/runtime/runtime_module.h b/src/Native/include/nncase/runtime/runtime_module.h
index 354d747cbf..194dd3b1f7 100644
--- a/src/Native/include/nncase/runtime/runtime_module.h
+++ b/src/Native/include/nncase/runtime/runtime_module.h
@@ -58,6 +58,8 @@ class NNCASE_API runtime_module {
result find_function_by_id(size_t index) noexcept;
+ result find_id_by_function(runtime_function *function) noexcept;
+
protected:
virtual result
initialize_before_functions(runtime_module_init_context &context) noexcept;
diff --git a/src/Native/include/nncase/runtime/stackvm/op_reader.h b/src/Native/include/nncase/runtime/stackvm/op_reader.h
index 80372463e4..ffde6669bd 100644
--- a/src/Native/include/nncase/runtime/stackvm/op_reader.h
+++ b/src/Native/include/nncase/runtime/stackvm/op_reader.h
@@ -1,5 +1,5 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
- * +08:00.
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
+ * +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -837,6 +837,7 @@ template <> struct tensor_op_reader {
template <> struct tensor_op_reader {
tensor_concat_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_concat_op_t op;
+ op.axis = reader.read_unaligned();
return op;
}
};
@@ -964,6 +965,7 @@ template <> struct tensor_op_reader {
template <> struct tensor_op_reader {
tensor_gather_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_gather_op_t op;
+ op.axis = reader.read_unaligned();
return op;
}
};
@@ -1055,6 +1057,7 @@ template <> struct tensor_op_reader {
tensor_layer_norm_op_t op;
op.axis = reader.read_unaligned();
op.epsilon = reader.read_unaligned();
+ op.use_mean = reader.read_unaligned();
return op;
}
};
diff --git a/src/Native/include/nncase/runtime/stackvm/opcode.h b/src/Native/include/nncase/runtime/stackvm/opcode.h
index 5c17c82894..8a5225b54e 100644
--- a/src/Native/include/nncase/runtime/stackvm/opcode.h
+++ b/src/Native/include/nncase/runtime/stackvm/opcode.h
@@ -1,5 +1,5 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
- * +08:00.
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
+ * +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -190,7 +190,6 @@ enum class tensor_function_t : uint16_t {
gather_nd = 29,
get_item = 31,
index_of = 36,
- lstm = 44,
prod = 52,
range = 55,
rank = 57,
@@ -218,6 +217,7 @@ enum class tensor_function_t : uint16_t {
squeeze_shape = 81,
transpose_shape = 87,
unsqueeze_shape = 93,
+ lstm = 44,
normal = 47,
normal_like = 48,
uniform = 90,
@@ -614,7 +614,9 @@ struct tensor_compare_op_t {
compare_op_t compare_op;
};
-struct tensor_concat_op_t {};
+struct tensor_concat_op_t {
+ int32_t axis;
+};
struct tensor_condition_op_t {
bool can_fold_const_call;
@@ -658,7 +660,9 @@ struct tensor_fix_shape_op_t {};
struct tensor_flatten_op_t {};
-struct tensor_gather_op_t {};
+struct tensor_gather_op_t {
+ int32_t axis;
+};
struct tensor_gather_elements_op_t {};
@@ -685,6 +689,7 @@ struct tensor_l2_normalization_op_t {};
struct tensor_layer_norm_op_t {
int32_t axis;
float epsilon;
+ bool use_mean;
};
struct tensor_leaky_relu_op_t {};
@@ -964,8 +969,6 @@ inline std::string to_string(tensor_function_t tensor_funct) {
return "get_item";
case tensor_function_t::index_of:
return "index_of";
- case tensor_function_t::lstm:
- return "lstm";
case tensor_function_t::prod:
return "prod";
case tensor_function_t::range:
@@ -1020,6 +1023,8 @@ inline std::string to_string(tensor_function_t tensor_funct) {
return "transpose_shape";
case tensor_function_t::unsqueeze_shape:
return "unsqueeze_shape";
+ case tensor_function_t::lstm:
+ return "lstm";
case tensor_function_t::normal:
return "normal";
case tensor_function_t::normal_like:
diff --git a/src/Native/src/kernels/stackvm/optimized/slice.cpp b/src/Native/src/kernels/stackvm/optimized/slice.cpp
index b56a16fad6..7899c426b3 100644
--- a/src/Native/src/kernels/stackvm/optimized/slice.cpp
+++ b/src/Native/src/kernels/stackvm/optimized/slice.cpp
@@ -85,6 +85,9 @@ result slice_contiguous_impl(
} else if (dims == 3) {
_slice_contiguous_dim_copy<3>(begins, ends, line_copy, in_index,
std::true_type{});
+ } else if (dims == 4) {
+ _slice_contiguous_dim_copy<4>(begins, ends, line_copy, in_index,
+ std::true_type{});
} else {
assert(false);
}
diff --git a/src/Native/src/kernels/stackvm/tensor_ops.cpp b/src/Native/src/kernels/stackvm/tensor_ops.cpp
index b9ea85d40b..2cb452aa3d 100644
--- a/src/Native/src/kernels/stackvm/tensor_ops.cpp
+++ b/src/Native/src/kernels/stackvm/tensor_ops.cpp
@@ -47,8 +47,9 @@ result nncase::kernels::stackvm::batch_normalization(
}
result nncase::kernels::stackvm::layer_norm(
- int32_t axis, float epsilon, value_t input, value_t scale, value_t bias,
- value_t output, [[maybe_unused]] kernel_context &context) {
+ int32_t axis, float epsilon, [[maybe_unused]] bool use_mean, value_t input,
+ value_t scale, value_t bias, value_t output,
+ [[maybe_unused]] kernel_context &context) {
try_input(input_mem, input);
try_input(scale_mem, scale);
try_input(bias_mem, bias);
@@ -124,7 +125,7 @@ nncase::kernels::stackvm::clamp(value_t input, value_t min, value_t max,
KERNEL_FINISH;
}
-result nncase::kernels::stackvm::concat(value_t input, value_t axis,
+result nncase::kernels::stackvm::concat(int32_t axis, value_t input,
value_t output,
kernel_context &context) {
try_tuple_input(inputs_mem, input);
@@ -132,7 +133,7 @@ result nncase::kernels::stackvm::concat(value_t input, value_t axis,
try_var(strides, get_strides(input_tuple));
try_tuple_field0(input0, input_tuple);
auto dtype = input0->dtype();
- try_positive_axis_with_rank(axis_value, axis, input0->shape().size());
+ auto axis_value = positive_index(axis, input0->shape().size());
auto out_shape = concat_infer_shape(shapes, axis_value);
try_output(out_mem, output, dtype, out_shape);
auto concat_dims = dims_t();
@@ -293,14 +294,15 @@ nncase::kernels::stackvm::flatten(value_t input, value_t axis, value_t output,
KERNEL_FINISH;
}
-result nncase::kernels::stackvm::gather(value_t input, value_t axis,
+result nncase::kernels::stackvm::gather(int32_t axis, value_t input,
value_t index, value_t output,
kernel_context &context) {
try_input(input_mem, input);
try_input(index_mem, index);
auto dtype = input_tensor->dtype();
try_var(typecode, to_typecode(dtype));
- try_positive_axis(axis_value, axis, input_tensor);
+ // try_positive_axis(axis_value, axis, input_tensor);
+ auto axis_value = positive_index(axis, input_tensor->shape().size());
auto out_shape = gather_infer_shape(input_tensor->shape(),
index_tensor->shape(), axis_value);
try_output(out_mem, output, dtype, out_shape);
diff --git a/src/Native/src/runtime/CMakeLists.txt b/src/Native/src/runtime/CMakeLists.txt
index b892beded3..f92450b6a0 100644
--- a/src/Native/src/runtime/CMakeLists.txt
+++ b/src/Native/src/runtime/CMakeLists.txt
@@ -54,6 +54,7 @@ else()
add_library(simulator OBJECT ${SRCS})
target_include_directories(simulator PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(simulator PUBLIC gsl::gsl-lite)
+ target_link_libraries(simulator PUBLIC fmt::fmt)
target_link_libraries(simulator PRIVATE kernels)
target_compile_definitions(simulator PUBLIC -DNNCASE_DLL -DNNCASE_SIMULATOR)
if (DEFAULT_BUILTIN_RUNTIMES)
diff --git a/src/Native/src/runtime/interpreter.cpp b/src/Native/src/runtime/interpreter.cpp
index dc5839fb44..fe69b8a071 100644
--- a/src/Native/src/runtime/interpreter.cpp
+++ b/src/Native/src/runtime/interpreter.cpp
@@ -246,6 +246,17 @@ result interpreter::find_module_by_id(size_t index) noexcept {
return ok(modules_[index].get());
}
+result interpreter::find_id_by_module(runtime_module *module) noexcept {
+ auto it = std::find_if(modules_.begin(), modules_.end(),
+ [&module](const std::unique_ptr &p) {
+ return p.get() == module;
+ });
+ if (it == modules_.end()) {
+ return err(std::errc::result_out_of_range);
+ }
+ return ok((it - modules_.begin()));
+}
+
options_dict &interpreter::options() noexcept { return options_; }
result interpreter::entry_function() noexcept {
diff --git a/src/Native/src/runtime/runtime_module.cpp b/src/Native/src/runtime/runtime_module.cpp
index 4b5d747bbd..66acaca61b 100644
--- a/src/Native/src/runtime/runtime_module.cpp
+++ b/src/Native/src/runtime/runtime_module.cpp
@@ -189,6 +189,19 @@ runtime_module::find_function_by_id(size_t index) noexcept {
return ok(functions_[index].get());
}
+result
+runtime_module::find_id_by_function(runtime_function *function) noexcept {
+ auto it =
+ std::find_if(functions_.begin(), functions_.end(),
+ [&function](const std::unique_ptr &p) {
+ return p.get() == function;
+ });
+ if (it == functions_.end()) {
+ return err(std::errc::result_out_of_range);
+ }
+ return ok((it - functions_.begin()));
+}
+
result runtime_module::initialize_before_functions(
NNCASE_UNUSED runtime_module_init_context &context) noexcept {
return ok();
diff --git a/src/Native/src/runtime/stackvm/op_reader.cpp b/src/Native/src/runtime/stackvm/op_reader.cpp
index 901f0f6125..4cb7f0e36a 100644
--- a/src/Native/src/runtime/stackvm/op_reader.cpp
+++ b/src/Native/src/runtime/stackvm/op_reader.cpp
@@ -1,5 +1,5 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
- * +08:00.
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
+ * +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
diff --git a/src/Native/src/runtime/stackvm/ops/tensor.cpp b/src/Native/src/runtime/stackvm/ops/tensor.cpp
index 6f09a7084c..3172fd0f90 100644
--- a/src/Native/src/runtime/stackvm/ops/tensor.cpp
+++ b/src/Native/src/runtime/stackvm/ops/tensor.cpp
@@ -1,5 +1,5 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
- * +08:00.
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM
+ * +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -207,9 +207,7 @@ result stackvm_runtime_function::visit(
dump_op("concat");
try_var(input, pop_value());
dump_input(input);
- try_var(axis, pop_value());
- dump_input(axis);
- try_var(output, kernels::stackvm::concat(input, axis, nullptr,
+ try_var(output, kernels::stackvm::concat(op.axis, input, nullptr,
module().kernel_context()));
dump_output(output);
stack_.push(std::move(output));
@@ -491,11 +489,9 @@ result stackvm_runtime_function::visit(
dump_op("gather");
try_var(input, pop_value());
dump_input(input);
- try_var(axis, pop_value());
- dump_input(axis);
try_var(index, pop_value());
dump_input(index);
- try_var(output, kernels::stackvm::gather(input, axis, index, nullptr,
+ try_var(output, kernels::stackvm::gather(op.axis, input, index, nullptr,
module().kernel_context()));
dump_output(output);
stack_.push(std::move(output));
@@ -683,9 +679,9 @@ result stackvm_runtime_function::visit(
dump_input(scale);
try_var(bias, pop_value());
dump_input(bias);
- try_var(output, kernels::stackvm::layer_norm(op.axis, op.epsilon, input,
- scale, bias, nullptr,
- module().kernel_context()));
+ try_var(output, kernels::stackvm::layer_norm(
+ op.axis, op.epsilon, op.use_mean, input, scale, bias,
+ nullptr, module().kernel_context()));
dump_output(output);
stack_.push(std::move(output));
return ok();
diff --git a/src/Native/src/runtime/stackvm/runtime_function_ops.h b/src/Native/src/runtime/stackvm/runtime_function_ops.h
index ae6944ef59..351b758b88 100644
--- a/src/Native/src/runtime/stackvm/runtime_function_ops.h
+++ b/src/Native/src/runtime/stackvm/runtime_function_ops.h
@@ -1,5 +1,5 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
- * +08:00.
+/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
+ * +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
diff --git a/src/Native/src/test_cli.cpp b/src/Native/src/test_cli.cpp
index 7f703a216a..3e95be5a64 100644
--- a/src/Native/src/test_cli.cpp
+++ b/src/Native/src/test_cli.cpp
@@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#include
#include
#include
#include
@@ -19,6 +20,8 @@
using namespace nncase;
using namespace nncase::runtime;
+// constexpr size_t loop_count = 10;
+constexpr size_t loop_count = 1;
#define TRY(x) \
if (x) \
@@ -34,8 +37,7 @@ result write_tensor_buffer(value_t value, std::ofstream &of) {
}
result run_core(const std::string &kmodel_path,
- const std::vector &input_bins,
- const std::string &output_bin) {
+ const std::vector &bins) {
auto kmodel = read_file(kmodel_path);
interpreter *interp = new interpreter();
// auto dump_path =
@@ -47,16 +49,16 @@ result run_core(const std::string &kmodel_path,
try_var(entry, interp->entry_function());
- if (entry->parameters_size() != input_bins.size())
+ if (entry->parameters_size() > bins.size())
return err(std::errc::argument_list_too_long);
/* create the input parameters tensor
note the input tenosr must be contiguous
*/
std::vector parameters;
- for (int i = 0; i < input_bins.size(); i++) {
+ for (int i = 0; i < entry->parameters_size(); i++) {
try_var(type, entry->parameter_type(i));
try_var(ts_type, type.as());
- auto input_pool = read_file(input_bins[i]);
+ auto input_pool = read_file(bins[i]);
gsl::span input_pool_span = {
reinterpret_cast(input_pool.data()),
input_pool.size()};
@@ -66,21 +68,40 @@ result run_core(const std::string &kmodel_path,
parameters.push_back(_.impl());
}
- try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
+ double total_time = 0.0;
+ for (size_t i = 0; i < loop_count; i++) {
+ auto start_time = std::chrono::steady_clock::now();
+ try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
+ auto end_time = std::chrono::steady_clock::now();
+ total_time += (std::chrono::duration_cast(
+ end_time - start_time)
+ .count() /
+ 1e6);
- std::ofstream output_stream(output_bin, std::ios::binary);
-
- if (ret.is_a()) {
- try_(write_tensor_buffer(ret, output_stream));
- } else if (ret.is_a()) {
- try_var(tp, ret.as());
- for (auto &&ret_v : tp->fields()) {
- try_(write_tensor_buffer(ret_v, output_stream));
+ if (i == (loop_count - 1) && (entry->parameters_size() < bins.size())) {
+ if (ret.is_a()) {
+ auto output_bin = bins.back();
+ std::ofstream output_stream(output_bin, std::ios::binary);
+ try_(write_tensor_buffer(ret, output_stream));
+ output_stream.close();
+ } else if (ret.is_a()) {
+ try_var(tp, ret.as());
+ auto o = 0;
+ for (auto &&ret_v : tp->fields()) {
+ auto output_bin = bins[entry->parameters_size() + (o++)];
+ std::ofstream output_stream(output_bin, std::ios::binary);
+ try_(write_tensor_buffer(ret_v, output_stream));
+ output_stream.close();
+ }
+ } else {
+ return nncase::err(std::errc::bad_message);
+ }
}
- } else {
- return nncase::err(std::errc::bad_message);
}
- output_stream.close();
+
+ std::cout << "interp run: " << (total_time / loop_count)
+ << " ms, fps = " << 1000 / (total_time / loop_count) << std::endl;
+
return ok();
}
@@ -92,13 +113,12 @@ result run_core(const std::string &kmodel_path,
* @return int
*/
int main(NNCASE_UNUSED int argc, char **argv) {
- assert(argc >= 4);
- std::vector input_bins;
- for (int i = 2; i < argc - 1; i++) {
- input_bins.push_back(argv[i]);
+ assert(argc >= 3);
+ std::vector bins;
+ for (int i = 2; i < argc; i++) {
+ bins.push_back(argv[i]);
}
std::string kmodel_bin(argv[1]);
- std::string output_bin(argv[argc - 1]);
- run_core(kmodel_bin, input_bins, output_bin).unwrap_or_throw();
+ run_core(kmodel_bin, bins).unwrap_or_throw();
return 0;
-}
+}
\ No newline at end of file
diff --git a/src/Nncase.Cli/Commands/Compile.cs b/src/Nncase.Cli/Commands/Compile.cs
deleted file mode 100644
index 69a8cf0c54..0000000000
--- a/src/Nncase.Cli/Commands/Compile.cs
+++ /dev/null
@@ -1,214 +0,0 @@
-// Copyright (c) Canaan Inc. All rights reserved.
-// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-
-using System;
-using System.Collections.Generic;
-using System.CommandLine;
-using System.CommandLine.Invocation;
-using System.IO;
-using System.Linq;
-using System.Threading.Tasks;
-using Microsoft.Extensions.DependencyInjection;
-using Microsoft.Extensions.Hosting;
-using Nncase.CodeGen;
-using Nncase.Compiler;
-using Nncase.Diagnostics;
-using Nncase.IR;
-using Nncase.Passes;
-using Nncase.Quantization;
-
-namespace Nncase.Cli.Commands;
-
-internal enum QuantType
-{
- UInt8,
- Int8,
- Int16,
-}
-
-internal enum DatasetFormat
-{
- Image,
- Raw,
- Pytest,
- Random,
-}
-
-///
-/// Compile command.
-///
-public sealed class Compile : Command
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public Compile()
- : base("compile")
- {
- AddArgument(new Argument("input-file"));
- AddArgument(new Argument("output-file"));
- AddOption(new Option(
- aliases: new string[] { "-t", "--target" },
- description: "target architecture, e.g. cpu, k210"));
- AddOption(new Option(
- aliases: new[] { "-i", "--input-format" },
- description: "input format, e.g. tflite",
- getDefaultValue: () => "tflite"));
- AddOption(new Option(
- alias: "--dump-level",
- description: $"dump ir to .il, default is {0}",
- getDefaultValue: () => 0));
- AddOption(new Option(
- alias: "--dump-dir",
- description: "dump to directory, default is .",
- getDefaultValue: () => "."));
- AddOption(new Option(
- alias: "--quant-type",
- description: $"quant type, default is {QuantType.UInt8}",
- getDefaultValue: () => QuantType.UInt8));
- AddOption(new Option(
- alias: "--wquant-type",
- description: $"wquant type, default is {QuantType.UInt8}",
- getDefaultValue: () => QuantType.UInt8));
- AddOption(new Option(
- alias: "--dataset",
- description: $"calibration dataset, used in post quantization, default is empty",
- getDefaultValue: () => string.Empty));
- AddOption(new Option(
- alias: "--dataset-format",
- description: $"datset format: e.g. Image|Raw|Pytest",
- getDefaultValue: () => DatasetFormat.Raw));
- AddOption(new Option(
- alias: "--model-quant-mode",
- description: $"model quant mode, default is {Quantization.ModelQuantMode.NoQuant}",
- getDefaultValue: () => Quantization.ModelQuantMode.NoQuant));
- AddOption(new Option(
- alias: "--calib-method",
- description: $"model quant options, default is {Quantization.CalibMethod.Kld}",
- getDefaultValue: () => Quantization.CalibMethod.Kld));
- AddOption(new Option(
- alias: "--benchmark-only",
- description: $"benchmark only",
- getDefaultValue: () => false));
-
- Handler = CommandHandler.Create(RunAsync);
- }
-
- private static DumpFlags DumpLevelToFlags(int dumpLevel)
- {
- return dumpLevel switch
- {
- 0 => DumpFlags.None,
- 1 => DumpLevelToFlags(0) | DumpFlags.Compile,
- 2 => DumpLevelToFlags(1) | DumpFlags.PassIR,
- 3 => DumpLevelToFlags(2) | DumpFlags.Rewrite,
- 4 => DumpLevelToFlags(3) | DumpFlags.EGraphCost,
- 5 => DumpLevelToFlags(4) | DumpFlags.Evaluator,
- 6 => DumpLevelToFlags(5) | DumpFlags.Calibration,
- 7 => DumpLevelToFlags(6) | DumpFlags.Tiling,
- 8 => DumpLevelToFlags(7) | DumpFlags.Schedule,
- >= 9 => DumpLevelToFlags(8) | DumpFlags.CodeGen,
- _ => throw new ArgumentOutOfRangeException(nameof(dumpLevel)),
- };
- }
-
- private async Task RunAsync(CliCompileOptions cliOptions, IHost host)
- {
- CompilerServices.Configure(host.Services);
-
- // 1. setup the options
- var compileOptions = new CompileOptions
- {
- InputFile = cliOptions.InputFile,
- InputFormat = cliOptions.InputFormat,
- DumpFlags = DumpLevelToFlags(cliOptions.DumpLevel),
- DumpDir = cliOptions.DumpDir,
- QuantizeOptions = new()
- {
- CalibrationMethod = cliOptions.CalibMethod,
- QuantType = cliOptions.QuantType switch
- {
- QuantType.UInt8 => DataTypes.UInt8,
- QuantType.Int8 => DataTypes.Int8,
- QuantType.Int16 => DataTypes.Int16,
- _ => throw new ArgumentException("Invalid quant type"),
- },
- WQuantType = cliOptions.WQuantType switch
- {
- QuantType.UInt8 => DataTypes.UInt8,
- QuantType.Int8 => DataTypes.Int8,
- QuantType.Int16 => DataTypes.Int16,
- _ => throw new ArgumentException("Invalid weights quant type"),
- },
- ModelQuantMode = cliOptions.ModelQuantMode,
- },
- IsBenchmarkOnly = cliOptions.BenchmarkOnly,
- };
-
- // 2. import the model
- var target = CompilerServices.GetTarget(cliOptions.Target);
- using var compileSession = CompileSession.Create(target, compileOptions);
- var compiler = compileSession.Compiler;
- var module = await compiler.ImportModuleAsync(compileOptions.InputFormat, compileOptions.InputFile, compileOptions.IsBenchmarkOnly);
-
- // 3. create the calib dataset
- if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
- {
- if (cliOptions.DatasetFormat == DatasetFormat.Random)
- {
- compileOptions.QuantizeOptions.CalibrationDataset = new RandomCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), 5);
- }
- else if (cliOptions.DatasetFormat == DatasetFormat.Pytest)
- {
- compileOptions.QuantizeOptions.CalibrationDataset = new PytestCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), cliOptions.Dataset);
- }
- else
- {
- throw new NotSupportedException(cliOptions.DatasetFormat.ToString());
- }
- }
-
- // 4. compile
- await compiler.CompileAsync();
-
- // 5. code gen
- using (var os = File.OpenWrite(cliOptions.OutputFile))
- {
- compiler.Gencode(os);
- }
- }
-}
-
-// Validate null in command line parser.
-#pragma warning disable CS8618
-
-internal sealed class CliCompileOptions
-{
- public string InputFile { get; set; }
-
- public string InputFormat { get; set; }
-
- public string Target { get; set; }
-
- public int DumpLevel { get; set; }
-
- public string DumpDir { get; set; }
-
- public QuantType QuantType { get; set; }
-
- public QuantType WQuantType { get; set; }
-
- public string OutputFile { get; set; }
-
- public ModelQuantMode ModelQuantMode { get; set; }
-
- public CalibMethod CalibMethod { get; set; }
-
- public string Dataset { get; set; }
-
- public DatasetFormat DatasetFormat { get; set; }
-
- public bool BenchmarkOnly { get; set; }
-}
-
-#pragma warning restore CS8618
diff --git a/src/Nncase.Cli/Compile.cs b/src/Nncase.Cli/Compile.cs
new file mode 100644
index 0000000000..e7f65a5303
--- /dev/null
+++ b/src/Nncase.Cli/Compile.cs
@@ -0,0 +1,217 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.CommandLine;
+using System.Linq;
+using Nncase.Diagnostics;
+using Nncase.Quantization;
+
+namespace Nncase.Cli;
+
+internal enum QuantType
+{
+ UInt8,
+ Int8,
+ Int16,
+}
+
+internal enum DatasetFormat
+{
+ Image,
+ Raw,
+ Pytest,
+ Random,
+}
+
+///
+/// Compile command.
+///
+internal sealed class CompileCommand : Command
+{
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public CompileCommand()
+ : base("compile")
+ {
+ InputFile = new Argument("input-file");
+ OutputFile = new Argument("output-file");
+ InputFormat = new Option(
+ aliases: new[] { "-i", "--input-format" },
+ description: "input format, e.g. tflite",
+ getDefaultValue: () => "tflite");
+ DumpFlags = new Option>(
+ name: "--dump-flags",
+ description: "dump ir flags. \navailable value: None,ImportOps,PassIR,EGraphCost,Rewrite,Calibration,Evaluator,Compile,Tiling,Schedule,CodeGen.")
+ {
+ AllowMultipleArgumentsPerToken = true,
+ };
+ DumpDir = new Option(
+ name: "--dump-dir",
+ description: "dump to directory.",
+ getDefaultValue: () => ".");
+ QuantType = new Option(
+ name: "--quant-type",
+ description: $"quant type",
+ getDefaultValue: () => Nncase.Cli.QuantType.UInt8);
+ WQuantType = new Option(
+ name: "--wquant-type",
+ description: $"wquant type",
+ getDefaultValue: () => Nncase.Cli.QuantType.UInt8);
+ Dataset = new Option(
+ name: "--dataset",
+ description: $"calibration dataset, used in post quantization",
+ getDefaultValue: () => string.Empty);
+ DatasetFormat = new Option(
+ name: "--dataset-format",
+ description: $"datset format.",
+ getDefaultValue: () => Nncase.Cli.DatasetFormat.Raw);
+ ModelQuantMode = new Option(
+ name: "--model-quant-mode",
+ description: $"model quant mode",
+ getDefaultValue: () => Quantization.ModelQuantMode.NoQuant);
+ CalibMethod = new Option(
+ name: "--calib-method",
+ description: $"model quant options",
+ getDefaultValue: () => Quantization.CalibMethod.Kld);
+ FixedVars = new Option>(
+ name: "--fixed-vars",
+ description: $"dynamic shape fixed vars, default is empty. \nset by `n:123`",
+ parseArgument: result =>
+ {
+ return result.Tokens.
+ Select(tk => tk.Value.Split(":").ToArray()).
+ Select(tp => (tp[0].Trim(), int.Parse(tp[1].Trim())));
+ })
+ {
+ AllowMultipleArgumentsPerToken = true,
+ };
+ PreProcess = new Option(
+ name: "--pre-process",
+ description: "whether enable pre process",
+ getDefaultValue: () => false);
+ InputLayout = new Option(
+ name: "--input-layout",
+ description: "the model input data layout",
+ getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
+ OutputLayout = new Option(
+ name: "--output-layout",
+ description: "the model output data layout.",
+ getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
+ InputType = new Option(
+ name: "--input-type",
+ description: "the model input data value type, default is Float32",
+ getDefaultValue: () => Nncase.InputType.Float32);
+ InputShape = new Option>(
+ name: "--input-shape",
+ description: "the model input data shape. eg. `--input-shape 1 2 3 4`",
+ getDefaultValue: Array.Empty)
+ {
+ AllowMultipleArgumentsPerToken = true,
+ };
+ InputRange = new Option>(
+ name: "--input-range",
+ description: "the model input data value range. eg `--input-range -100.3 200.4`",
+ getDefaultValue: Array.Empty)
+ {
+ AllowMultipleArgumentsPerToken = true,
+ };
+ SwapRB = new Option(
+ name: "--swap-rb",
+ description: "whether swap the model input data channel, like cv2.BGRtoRGB(im)",
+ getDefaultValue: () => false);
+ LetterBoxValue = new Option(
+ name: "--letter-box-value",
+ description: "letterbox fill value",
+ getDefaultValue: () => 0.0f);
+ Mean = new Option>(
+ name: "--mean",
+ description: "the model input data mean, default []",
+ getDefaultValue: Array.Empty)
+ {
+ AllowMultipleArgumentsPerToken = true,
+ };
+ Std = new Option>(
+ name: "--std",
+ description: "the model input data std, default []",
+ getDefaultValue: Array.Empty)
+ {
+ AllowMultipleArgumentsPerToken = true,
+ };
+ ModelLayout = new Option(
+ name: "--model-layout",
+ description: "the model's input layout.",
+ getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
+ AddArgument(InputFile);
+ AddArgument(OutputFile);
+ AddGlobalOption(InputFormat);
+ AddGlobalOption(DumpFlags);
+ AddGlobalOption(DumpDir);
+ AddGlobalOption(QuantType);
+ AddGlobalOption(WQuantType);
+ AddGlobalOption(Dataset);
+ AddGlobalOption(DatasetFormat);
+ AddGlobalOption(ModelQuantMode);
+ AddGlobalOption(CalibMethod);
+ AddGlobalOption(FixedVars);
+ AddGlobalOption(PreProcess);
+ AddGlobalOption(InputLayout);
+ AddGlobalOption(OutputLayout);
+ AddGlobalOption(InputType);
+ AddGlobalOption(InputShape);
+ AddGlobalOption(InputRange);
+ AddGlobalOption(SwapRB);
+ AddGlobalOption(LetterBoxValue);
+ AddGlobalOption(Mean);
+ AddGlobalOption(Std);
+ AddGlobalOption(ModelLayout);
+ }
+
+ public Argument InputFile { get; }
+
+ public Argument OutputFile { get; }
+
+ public Option InputFormat { get; }
+
+ public Option> DumpFlags { get; }
+
+ public Option DumpDir { get; }
+
+ public Option QuantType { get; }
+
+ public Option WQuantType { get; }
+
+ public Option Dataset { get; }
+
+ public Option DatasetFormat { get; }
+
+ public Option ModelQuantMode { get; }
+
+ public Option CalibMethod { get; }
+
+ public Option> FixedVars { get; }
+
+ public Option PreProcess { get; }
+
+ public Option InputLayout { get; }
+
+ public Option OutputLayout { get; }
+
+ public Option InputType { get; }
+
+ public Option> InputShape { get; }
+
+ public Option> InputRange { get; }
+
+ public Option SwapRB { get; }
+
+ public Option LetterBoxValue { get; }
+
+ public Option> Mean { get; }
+
+ public Option> Std { get; }
+
+ public Option ModelLayout { get; }
+}
diff --git a/src/Nncase.Cli/Nncase.Cli.csproj b/src/Nncase.Cli/Nncase.Cli.csproj
index 3070806ba5..17e370d4ca 100644
--- a/src/Nncase.Cli/Nncase.Cli.csproj
+++ b/src/Nncase.Cli/Nncase.Cli.csproj
@@ -26,4 +26,8 @@
PreserveNewest
+
+
+
+
diff --git a/src/Nncase.Cli/Program.CommandLine.cs b/src/Nncase.Cli/Program.CommandLine.cs
deleted file mode 100644
index e88f691647..0000000000
--- a/src/Nncase.Cli/Program.CommandLine.cs
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) Canaan Inc. All rights reserved.
-// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-
-using System;
-using System.CommandLine;
-using System.CommandLine.Builder;
-using System.Linq;
-
-namespace Nncase.Cli;
-
-internal partial class Program
-{
- private static CommandLineBuilder BuildCommandLine()
- {
- var commands = from t in typeof(Program).Assembly.ExportedTypes
- where t.Namespace == "Nncase.Cli.Commands" && t.IsAssignableTo(typeof(Command))
- select (Command)Activator.CreateInstance(t)!;
- var root = new RootCommand();
- foreach (var command in commands)
- {
- root.AddCommand(command);
- }
-
- return new CommandLineBuilder(root);
- }
-}
diff --git a/src/Nncase.Cli/Program.cs b/src/Nncase.Cli/Program.cs
index 2d3e40c803..0ef25c0565 100644
--- a/src/Nncase.Cli/Program.cs
+++ b/src/Nncase.Cli/Program.cs
@@ -1,10 +1,14 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using System;
+using System.Collections.Generic;
+using System.CommandLine;
using System.CommandLine.Builder;
using System.CommandLine.Hosting;
using System.CommandLine.Parsing;
using System.IO;
+using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Hosting;
@@ -16,12 +20,155 @@ internal partial class Program
{
public static async Task Main(string[] args)
{
- return await BuildCommandLine()
+ return await ConfigureCommandLine()
.UseHost(ConfigureHost)
.UseDefaults()
.Build().InvokeAsync(args);
}
+ private static async Task RunAsync(string targetKind, CompileOptions compileOptions, DatasetFormat datasetFormat, string dataset, string outputFile, IHost host)
+ {
+ CompilerServices.Configure(host.Services);
+
+ // 2. import the model
+ var target = CompilerServices.GetTarget(targetKind);
+ using var compileSession = CompileSession.Create(target, compileOptions);
+ var compiler = compileSession.Compiler;
+ IR.IRModule module = await compiler.ImportModuleAsync(Path.GetExtension(compileOptions.InputFile).Trim('.'), compileOptions.InputFile);
+
+ // 3. create the calib dataset
+ if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
+ {
+ if (datasetFormat == DatasetFormat.Random)
+ {
+ compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.RandomCalibrationDatasetProvider(((Nncase.IR.Function)module.Entry!).Parameters.ToArray(), 5);
+ }
+ else if (datasetFormat == DatasetFormat.Pytest)
+ {
+ compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.PytestCalibrationDatasetProvider(((IR.Function)module.Entry!).Parameters.ToArray(), dataset);
+ }
+ else
+ {
+ throw new NotSupportedException(datasetFormat.ToString());
+ }
+ }
+
+ // 4. compile
+ await compiler.CompileAsync();
+
+ // 5. code gen
+ using (var os = File.OpenWrite(outputFile))
+ {
+ compiler.Gencode(os);
+ }
+ }
+
+ private static CommandLineBuilder ConfigureCommandLine()
+ {
+ var compile = new CompileCommand();
+ foreach (var target in LoadTargets())
+ {
+ var (targetCmd, targetParser) = target.RegisterCommandAndParser();
+ Action targetHandler = async (System.CommandLine.Invocation.InvocationContext context) =>
+ {
+ var options = ParseCompileOptions(context, compile);
+ options.TargetCompileOptions = targetParser(context, targetCmd);
+ await RunAsync(targetCmd.Name, options, context.ParseResult.GetValueForOption(compile.DatasetFormat), context.ParseResult.GetValueForOption(compile.Dataset)!, context.ParseResult.GetValueForArgument(compile.OutputFile), context.GetHost());
+ };
+ targetCmd.SetHandler(targetHandler);
+ compile.AddCommand(targetCmd);
+ }
+
+ return new CommandLineBuilder(new RootCommand() { compile });
+ }
+
+ private static CompileOptions ParseCompileOptions(System.CommandLine.Invocation.InvocationContext context, CompileCommand compilecmd)
+ {
+ // 1. setup the options
+ var compileOptions = new CompileOptions
+ {
+ InputFile = context.ParseResult.GetValueForArgument(compilecmd.InputFile),
+ InputFormat = context.ParseResult.GetValueForOption(compilecmd.InputFormat)!,
+ DumpFlags = context.ParseResult.GetValueForOption(compilecmd.DumpFlags)!.Aggregate(Diagnostics.DumpFlags.None, (a, b) => a | b),
+ DumpDir = context.ParseResult.GetValueForOption(compilecmd.DumpDir)!,
+ PreProcess = context.ParseResult.GetValueForOption(compilecmd.PreProcess)!,
+ InputLayout = context.ParseResult.GetValueForOption(compilecmd.InputLayout)!,
+ OutputLayout = context.ParseResult.GetValueForOption(compilecmd.OutputLayout)!,
+ InputType = context.ParseResult.GetValueForOption(compilecmd.InputType)!,
+ InputShape = context.ParseResult.GetValueForOption(compilecmd.InputShape)!.ToArray(),
+ InputRange = context.ParseResult.GetValueForOption(compilecmd.InputRange)!.ToArray(),
+ SwapRB = context.ParseResult.GetValueForOption(compilecmd.SwapRB)!,
+ LetterBoxValue = context.ParseResult.GetValueForOption(compilecmd.LetterBoxValue)!,
+ Mean = context.ParseResult.GetValueForOption(compilecmd.Mean)!.ToArray(),
+ Std = context.ParseResult.GetValueForOption(compilecmd.Std)!.ToArray(),
+ ModelLayout = context.ParseResult.GetValueForOption(compilecmd.ModelLayout)!,
+ QuantizeOptions = new()
+ {
+ CalibrationMethod = context.ParseResult.GetValueForOption(compilecmd.CalibMethod),
+ QuantType = context.ParseResult.GetValueForOption(compilecmd.QuantType) switch
+ {
+ QuantType.UInt8 => DataTypes.UInt8,
+ QuantType.Int8 => DataTypes.Int8,
+ QuantType.Int16 => DataTypes.Int16,
+ _ => throw new ArgumentException("Invalid quant type"),
+ },
+ WQuantType = context.ParseResult.GetValueForOption(compilecmd.WQuantType) switch
+ {
+ QuantType.UInt8 => DataTypes.UInt8,
+ QuantType.Int8 => DataTypes.Int8,
+ QuantType.Int16 => DataTypes.Int16,
+ _ => throw new ArgumentException("Invalid weights quant type"),
+ },
+ ModelQuantMode = context.ParseResult.GetValueForOption(compilecmd.ModelQuantMode),
+ },
+ };
+
+ foreach (var item in context.ParseResult.GetValueForOption(compilecmd.FixedVars)!)
+ {
+ compileOptions.ShapeBucketOptions.FixVarMap.Add(item.Name, item.Value);
+ }
+
+ return compileOptions;
+ }
+
+ private static IReadOnlyList LoadTargets()
+ {
+ var loadContext = System.Runtime.Loader.AssemblyLoadContext.Default;
+ var pluginAsms = PluginLoader.GetPluginsSearchDirectories(PluginLoader.PluginPathEnvName, null).
+ Select(PluginLoader.GetPluginAssemblies).
+ SelectMany(x => x).
+ DistinctBy(Path.GetFileName).
+ Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)).
+ Distinct().
+ ToList();
+ pluginAsms.AddRange(new[] { Path.GetDirectoryName(typeof(Program).Assembly.Location)! }.
+ Select(basePath =>
+ {
+ if (Directory.Exists(basePath))
+ {
+ return (from filePath in Directory.GetFiles(basePath, PluginLoader.ModulesDllPattern, SearchOption.AllDirectories)
+ where PluginLoader.IsLoadableAssembly(filePath)
+ select filePath).Distinct();
+ }
+ else
+ {
+ return Array.Empty();
+ }
+ }).
+ SelectMany(x => x).
+ DistinctBy(Path.GetFileName).
+ Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)).
+ Distinct());
+ var targets = (from asm in pluginAsms
+ from t in asm.ExportedTypes
+ where t.IsClass
+ && t.IsAssignableTo(typeof(ITarget))
+ let ctor = t.GetConstructor(Type.EmptyTypes)
+ where ctor != null
+ select (ITarget)ctor.Invoke(null)).ToList();
+ return targets;
+ }
+
private static void ConfigureHost(IHostBuilder hostBuilder)
{
hostBuilder.ConfigureAppConfiguration(ConfigureAppConfiguration)
diff --git a/src/Nncase.Cli/packages.lock.json b/src/Nncase.Cli/packages.lock.json
index b438ba9cc6..56eaafb112 100644
--- a/src/Nncase.Cli/packages.lock.json
+++ b/src/Nncase.Cli/packages.lock.json
@@ -33,21 +33,22 @@
},
"StyleCop.Analyzers": {
"type": "Direct",
- "requested": "[1.2.0-beta.507, )",
- "resolved": "1.2.0-beta.507",
- "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
+ "requested": "[1.2.0-beta.435, )",
+ "resolved": "1.2.0-beta.435",
+ "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
- "StyleCop.Analyzers.Unstable": "1.2.0.507"
+ "StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"System.CommandLine.Hosting": {
"type": "Direct",
- "requested": "[0.3.0-alpha.21216.1, )",
- "resolved": "0.3.0-alpha.21216.1",
- "contentHash": "zP8QEUH8dSUYUHdGk6k71kOJy8uFgEPZG2RfhA0cMjDH3/Jov5AjUNaxOvpSNHh+ewu8eIUCYgV8+fEkCPyNlw==",
+ "requested": "[0.4.0-alpha.22272.1, )",
+ "resolved": "0.4.0-alpha.22272.1",
+ "contentHash": "x9JhHxBLxlKyCIZADFYC8q16L9yGHdTakrLFjHabwR7Tk0761aTexiGgMTIS744HGuhc8pk9MoLUzsr/TlRfMQ==",
"dependencies": {
- "Microsoft.Extensions.Hosting": "3.1.5",
- "System.CommandLine": "2.0.0-beta1.21216.1"
+ "Microsoft.Extensions.Hosting": "6.0.0",
+ "System.CommandLine": "2.0.0-beta4.22272.1",
+ "System.CommandLine.NamingConventionBinder": "2.0.0-beta4.22272.1"
}
},
"Google.OrTools.runtime.linux-arm64": {
@@ -344,8 +345,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
- "resolved": "1.2.0.507",
- "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
+ "resolved": "1.2.0.435",
+ "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@@ -362,13 +363,12 @@
"System.Runtime": "4.3.0"
}
},
- "System.CommandLine": {
+ "System.CommandLine.NamingConventionBinder": {
"type": "Transitive",
- "resolved": "2.0.0-beta1.21216.1",
- "contentHash": "Nbv/tW8sbOKN5T+4SSVBMdk4ADSIpJpY4UHMsj3VkcNtOckIT4iyzagjF+W5FEh2YBRvmvVQijOTIZbUJ1+1aA==",
+ "resolved": "2.0.0-beta4.22272.1",
+ "contentHash": "ux2eUA/syF+JtlpMDc/Lsd6PBIBuwjH3AvHnestoh5uD0WKT5b+wkQxDWVCqp9qgVjMBTLNhX19ZYFtenunt9A==",
"dependencies": {
- "Microsoft.CSharp": "4.4.1",
- "system.memory": "4.5.4"
+ "System.CommandLine": "2.0.0-beta4.22272.1"
}
},
"System.Diagnostics.Contracts": {
@@ -696,6 +696,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
+ "System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@@ -937,6 +938,12 @@
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
+ "System.CommandLine": {
+ "type": "CentralTransitive",
+ "requested": "[2.0.0-beta4.22272.1, )",
+ "resolved": "2.0.0-beta4.22272.1",
+ "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
+ },
"System.Linq.Async": {
"type": "CentralTransitive",
"requested": "[6.0.1, )",
diff --git a/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs b/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs
index b3dbe692e0..bd4d97cafc 100644
--- a/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs
+++ b/src/Nncase.CodeGen/CodeGen/LinkedFunction.cs
@@ -15,7 +15,6 @@ public class LinkedFunction : ILinkedFunction
public LinkedFunction(uint id, Callable sourceFunction, ulong textBegin, ulong textLength, IReadOnlyList sections)
{
Id = id;
- CompilerServices.InferenceType(sourceFunction);
ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray();
ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType;
TextBegin = textBegin;
diff --git a/src/Nncase.CodeGen/packages.lock.json b/src/Nncase.CodeGen/packages.lock.json
index b618b5504c..fd39ebd2fc 100644
--- a/src/Nncase.CodeGen/packages.lock.json
+++ b/src/Nncase.CodeGen/packages.lock.json
@@ -10,11 +10,11 @@
},
"StyleCop.Analyzers": {
"type": "Direct",
- "requested": "[1.2.0-beta.507, )",
- "resolved": "1.2.0-beta.507",
- "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
+ "requested": "[1.2.0-beta.435, )",
+ "resolved": "1.2.0-beta.435",
+ "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
- "StyleCop.Analyzers.Unstable": "1.2.0.507"
+ "StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Microsoft.Extensions.Configuration.Abstractions": {
@@ -53,8 +53,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
- "resolved": "1.2.0.507",
- "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
+ "resolved": "1.2.0.435",
+ "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@@ -76,6 +76,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
+ "System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@@ -138,6 +139,12 @@
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
}
},
+ "System.CommandLine": {
+ "type": "CentralTransitive",
+ "requested": "[2.0.0-beta4.22272.1, )",
+ "resolved": "2.0.0-beta4.22272.1",
+ "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
+ },
"System.Reactive": {
"type": "CentralTransitive",
"requested": "[5.0.0, )",
diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs
index 4a7ffbe777..5d2b49488f 100644
--- a/src/Nncase.Compiler/Compiler.cs
+++ b/src/Nncase.Compiler/Compiler.cs
@@ -88,13 +88,15 @@ public void AddPreAndPostProcess(IPassManager passManager)
public void TargetIndependentPass(IPassManager passManager)
{
- passManager.AddWithName("ReshapeMatMul").Configure(p =>
+ passManager.AddWithName("NormAxisAndShape").Configure(p =>
{
p.Add();
- });
-
- passManager.AddWithName("SqueezeShape").Configure(p =>
- {
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
p.Add();
p.Add();
p.Add();
@@ -102,6 +104,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add();
p.Add();
p.Add();
+ p.Add();
p.Add();
p.Add();
p.Add();
@@ -157,6 +160,8 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add();
p.Add();
p.Add();
+ p.Add();
+ p.Add();
p.Add();
p.Add();
p.Add();
@@ -168,6 +173,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add();
p.Add();
p.Add();
+ p.Add();
p.Add();
p.Add();
p.Add();
diff --git a/src/Nncase.Compiler/Hosting/PluginLoader.cs b/src/Nncase.Compiler/Hosting/PluginLoader.cs
index 73f10f367a..014ab29fb1 100644
--- a/src/Nncase.Compiler/Hosting/PluginLoader.cs
+++ b/src/Nncase.Compiler/Hosting/PluginLoader.cs
@@ -19,12 +19,14 @@ namespace Nncase.Hosting;
///
public sealed class PluginLoader
{
- private const string _modulesDllPattern = "Nncase.Modules.*.dll";
- private const string _pluginPathEnvName = "NNCASE_PLUGIN_PATH";
+ public const string PluginPathEnvName = "NNCASE_PLUGIN_PATH";
+
+ public const string ModulesDllPattern = "Nncase.Modules.*.dll";
private static readonly string[] _builtinModules = new[]
{
"Nncase.Modules.StackVM.dll",
+ "Nncase.Modules.CPU.dll",
"Nncase.Modules.K210.dll",
};
@@ -42,67 +44,16 @@ public PluginLoader(ILogger logger)
?? AssemblyLoadContext.Default;
}
- ///
- /// Load plugins.
- ///
- /// Plugins.
- public IReadOnlyList LoadPlugins()
- {
- var pluginAsms = GetPluginsSearchDirectories().Select(GetPluginAssemblies).SelectMany(x => x)
- .DistinctBy(Path.GetFileName).Select(LoadPluginAssembly).Distinct().ToList();
- var plugins = (from asm in pluginAsms
- from t in asm.ExportedTypes
- where t.IsClass
- && t.IsAssignableTo(typeof(IPlugin))
- let ctor = t.GetConstructor(Type.EmptyTypes)
- where ctor != null
- select (IPlugin)ctor.Invoke(null)).ToList();
-
- return plugins;
- }
-
- private static bool IsLoadableAssembly(string filePath)
- {
- using var fs = File.OpenRead(filePath);
- using var peReader = new PEReader(fs);
-
- if (!peReader.HasMetadata)
- {
- return false;
- }
-
- var metaReader = peReader.GetMetadataReader();
- if (!metaReader.IsAssembly)
- {
- return false;
- }
-
- // Is reference assembly
- if ((from cah in metaReader.CustomAttributes
- let ca = metaReader.GetCustomAttribute(cah)
- where ca.Constructor.Kind == HandleKind.MemberReference
- let ctor = metaReader.GetMemberReference((MemberReferenceHandle)ca.Constructor)
- let attrType = metaReader.GetTypeReference((TypeReferenceHandle)ctor.Parent)
- where metaReader.GetString(attrType.Namespace) == nameof(System.Runtime.CompilerServices)
- && metaReader.GetString(attrType.Name) == nameof(ReferenceAssemblyAttribute)
- select cah).Any())
- {
- return false;
- }
-
- return true;
- }
-
- private Assembly LoadPluginAssembly(string assemblyFile)
+ public static Assembly LoadPluginAssembly(string assemblyFile, AssemblyLoadContext loadContext)
{
- return _loadContext.LoadFromAssemblyPath(assemblyFile);
+ return loadContext.LoadFromAssemblyPath(assemblyFile);
}
- private IEnumerable GetPluginAssemblies(string basePath)
+ public static IEnumerable GetPluginAssemblies(string basePath)
{
if (Directory.Exists(basePath))
{
- return (from filePath in Directory.GetFiles(basePath, _modulesDllPattern, SearchOption.AllDirectories)
+ return (from filePath in Directory.GetFiles(basePath, ModulesDllPattern, SearchOption.AllDirectories)
where !_builtinModules.Contains(Path.GetFileName(filePath))
&& IsLoadableAssembly(filePath)
select filePath).Distinct();
@@ -113,19 +64,22 @@ private IEnumerable GetPluginAssemblies(string basePath)
}
}
- private IEnumerable GetPluginsSearchDirectories()
+ public static IEnumerable GetPluginsSearchDirectories(string pluginPathEnvName, ILogger? logger)
{
var directories = new List();
// 1. Environment variable
- var targetPathEnv = Environment.GetEnvironmentVariable(_pluginPathEnvName);
+ var targetPathEnv = Environment.GetEnvironmentVariable(pluginPathEnvName);
if (string.IsNullOrWhiteSpace(targetPathEnv))
{
- _logger.LogWarning($"{_pluginPathEnvName} is not set.");
+ if (logger is not null)
+ {
+ logger.LogWarning($"{pluginPathEnvName} is not set.");
+ }
}
else
{
- var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
+ var targetPaths = from path in targetPathEnv!.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
select Environment.ExpandEnvironmentVariables(path);
directories.AddRange(targetPaths);
}
@@ -135,11 +89,62 @@ private IEnumerable GetPluginsSearchDirectories()
var modulesPath = Path.Combine(rootPath, "modules");
directories.Add(modulesPath);
- if (_logger.IsEnabled(LogLevel.Trace))
+ if (logger is not null && logger.IsEnabled(LogLevel.Trace))
{
- _logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
+ logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
}
return directories.Distinct();
}
+
+ public static bool IsLoadableAssembly(string filePath)
+ {
+ using var fs = File.OpenRead(filePath);
+ using var peReader = new PEReader(fs);
+
+ if (!peReader.HasMetadata)
+ {
+ return false;
+ }
+
+ var metaReader = peReader.GetMetadataReader();
+ if (!metaReader.IsAssembly)
+ {
+ return false;
+ }
+
+ // Is reference assembly
+ if ((from cah in metaReader.CustomAttributes
+ let ca = metaReader.GetCustomAttribute(cah)
+ where ca.Constructor.Kind == HandleKind.MemberReference
+ let ctor = metaReader.GetMemberReference((MemberReferenceHandle)ca.Constructor)
+ let attrType = metaReader.GetTypeReference((TypeReferenceHandle)ctor.Parent)
+ where metaReader.GetString(attrType.Namespace) == nameof(System.Runtime.CompilerServices)
+ && metaReader.GetString(attrType.Name) == nameof(ReferenceAssemblyAttribute)
+ select cah).Any())
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ ///
+ /// Load plugins.
+ ///
+ /// Plugins.
+ public IReadOnlyList LoadPlugins()
+ {
+ var pluginAsms = GetPluginsSearchDirectories(PluginPathEnvName, _logger).Select(GetPluginAssemblies).SelectMany(x => x)
+ .DistinctBy(Path.GetFileName).Select(x => LoadPluginAssembly(x, _loadContext)).Distinct().ToList();
+ var plugins = (from asm in pluginAsms
+ from t in asm.ExportedTypes
+ where t.IsClass
+ && t.IsAssignableTo(typeof(IPlugin))
+ let ctor = t.GetConstructor(Type.EmptyTypes)
+ where ctor != null
+ select (IPlugin)ctor.Invoke(null)).ToList();
+
+ return plugins;
+ }
}
diff --git a/src/Nncase.Compiler/Interop/CApi.cs b/src/Nncase.Compiler/Interop/CApi.cs
index 3ce2d0e289..69bacc8397 100644
--- a/src/Nncase.Compiler/Interop/CApi.cs
+++ b/src/Nncase.Compiler/Interop/CApi.cs
@@ -84,6 +84,7 @@ public unsafe struct CApiMT
public delegate* unmanaged QuantOptionsSetFineTuneWeightsMethodPtr;
public delegate* unmanaged QuantOptionsSetUseMixQuantPtr;
public delegate* unmanaged QuantOptionsSetQuantSchemePtr;
+ public delegate* unmanaged QuantOptionsSetQuantSchemeStrictModePtr;
public delegate* unmanaged QuantOptionsSetExportQuantSchemePtr;
public delegate* unmanaged QuantOptionsSetExportWeightRangeByChannelPtr;
public delegate* unmanaged QuantOptionsSetDumpQuantErrorPtr;
@@ -154,6 +155,7 @@ public static void Initialize(CApiMT* mt)
mt->QuantOptionsSetFineTuneWeightsMethodPtr = &QuantizeOptionsSetFineTuneWeightsMethod;
mt->QuantOptionsSetUseMixQuantPtr = &QuantOptionsSetUseMixQuant;
mt->QuantOptionsSetQuantSchemePtr = &QuantizeOptionsSetQuantScheme;
+ mt->QuantOptionsSetQuantSchemeStrictModePtr = &QuantizeOptionsSetQuantSchemeStrictMode;
mt->QuantOptionsSetExportQuantSchemePtr = &QuantizeOptionsSetExportQuantScheme;
mt->QuantOptionsSetExportWeightRangeByChannelPtr = &QuantizeOptionsSetExportWeightRangeByChannel;
mt->QuantOptionsSetDumpQuantErrorPtr = &QuantizeOptionsSetDumpQuantError;
@@ -603,6 +605,22 @@ private static void QuantizeOptionsSetQuantScheme(IntPtr quantizeOptionsHandle,
Get(quantizeOptionsHandle).QuantScheme = ToString(quantSchemePtr, quantSchemeLength);
}
+ [UnmanagedCallersOnly]
+ private static void QuantizeOptionsSetQuantSchemeStrictMode(IntPtr quantizeOptionsHandle, byte quantSchemeStrictMode)
+ {
+ switch (quantSchemeStrictMode)
+ {
+ case 0:
+ Get(quantizeOptionsHandle).QuantSchemeStrictMode = false;
+ break;
+ case 1:
+ Get(quantizeOptionsHandle).QuantSchemeStrictMode = true;
+ break;
+ default:
+ throw new ArgumentException("Invalid QuantSchemeStrictMode Flag");
+ }
+ }
+
[UnmanagedCallersOnly]
private static void QuantizeOptionsSetExportQuantScheme(IntPtr quantizeOptionsHandle, byte exportQuantScheme)
{
diff --git a/src/Nncase.Compiler/packages.lock.json b/src/Nncase.Compiler/packages.lock.json
index 639bb6a9bc..f22a606140 100644
--- a/src/Nncase.Compiler/packages.lock.json
+++ b/src/Nncase.Compiler/packages.lock.json
@@ -49,11 +49,11 @@
},
"StyleCop.Analyzers": {
"type": "Direct",
- "requested": "[1.2.0-beta.507, )",
- "resolved": "1.2.0-beta.507",
- "contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
+ "requested": "[1.2.0-beta.435, )",
+ "resolved": "1.2.0-beta.435",
+ "contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
- "StyleCop.Analyzers.Unstable": "1.2.0.507"
+ "StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Google.OrTools.runtime.linux-arm64": {
@@ -350,8 +350,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
- "resolved": "1.2.0.507",
- "contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
+ "resolved": "1.2.0.435",
+ "contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@@ -674,6 +674,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
+ "System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@@ -885,6 +886,12 @@
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
+ "System.CommandLine": {
+ "type": "CentralTransitive",
+ "requested": "[2.0.0-beta4.22272.1, )",
+ "resolved": "2.0.0-beta4.22272.1",
+ "contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
+ },
"System.Linq.Async": {
"type": "CentralTransitive",
"requested": "[6.0.1, )",
diff --git a/src/Nncase.Core/CompileOptions.cs b/src/Nncase.Core/CompileOptions.cs
index 5d7b3cb058..e5d1b2874c 100644
--- a/src/Nncase.Core/CompileOptions.cs
+++ b/src/Nncase.Core/CompileOptions.cs
@@ -119,4 +119,9 @@ public sealed record CompileOptions
/// Gets or sets a value indicating whether is benchmark only.
///
public bool IsBenchmarkOnly { get; set; }
+
+ ///
+ /// Gets or sets the target compile options.
+ ///
+ public ITargetCompileOptions TargetCompileOptions { get; set; } = null!;
}
diff --git a/src/Nncase.Core/CompilerServices.cs b/src/Nncase.Core/CompilerServices.cs
index 5d1bd6cb83..9c666bcd21 100644
--- a/src/Nncase.Core/CompilerServices.cs
+++ b/src/Nncase.Core/CompilerServices.cs
@@ -73,6 +73,14 @@ public interface ICompilerServicesProvider
/// false for save const into bin.
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst);
+ ///
+ /// dump the expr as csharp code.
+ ///
+ /// expression.
+ /// file prefix.
+ /// file dump ir.
+ public void DumpPatternIR(Expr expr, string prefix, string dumpDir);
+
///
/// print ir type.
///
@@ -468,6 +476,15 @@ public static void DumpDotIR(Expr expr, string prefix, string dumpPath, bool dis
public static void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst = true) =>
Provider.DumpCSharpIR(expr, prefix, dumpDir, randConst);
+ ///
+ /// dump the expr as csharp code.
+ ///
+ /// expression.
+ /// file prefix.
+ /// file dump ir.
+ public static void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>
+ Provider.DumpPatternIR(expr, prefix, dumpDir);
+
public static string Print(IRType type) => Provider.Print(type);
public static string Print(Expr expr, bool useScript = false) => Provider.Print(expr, useScript);
@@ -583,6 +600,10 @@ public void DumpDotIR(Expr expr, string prefix, string dumpPath, bool display_ca
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst) =>
_irprinterProvider.DumpCSharpIR(expr, prefix, dumpDir, randConst);
+ ///
+ public void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>
+ _irprinterProvider.DumpPatternIR(expr, prefix, dumpDir);
+
///
public string Print(IRType type) => _irprinterProvider.Print(type);
diff --git a/src/Nncase.Core/Converters/ConvertersModule.cs b/src/Nncase.Core/Converters/ConvertersModule.cs
index c7e5d4a9bc..3b406a8c88 100644
--- a/src/Nncase.Core/Converters/ConvertersModule.cs
+++ b/src/Nncase.Core/Converters/ConvertersModule.cs
@@ -28,5 +28,6 @@ public void ConfigureServices(IRegistrator registrator)
registrator.RegisterManyInterface(reuse: Reuse.Singleton);
registrator.RegisterManyInterface(reuse: Reuse.Singleton);
registrator.RegisterManyInterface(reuse: Reuse.Singleton);
+ registrator.RegisterManyInterface(reuse: Reuse.Singleton);
}
}
diff --git a/src/Nncase.Core/Converters/PointerConverters.cs b/src/Nncase.Core/Converters/PointerConverters.cs
index af274fab5e..134c773609 100644
--- a/src/Nncase.Core/Converters/PointerConverters.cs
+++ b/src/Nncase.Core/Converters/PointerConverters.cs
@@ -30,3 +30,25 @@ public void ConvertTo(ReadOnlySpan> source, Span dest, Cast
}
}
}
+
+internal class PointerIntConverters : IPointerSpanConverter
+{
+ public void ConvertTo(ReadOnlySpan> source, Span dest, CastMode castMode)
+ where T : unmanaged, IEquatable
+ {
+ if (castMode != CastMode.KDefault)
+ {
+ throw new InvalidCastException();
+ }
+
+ if (dest.Length < source.Length)
+ {
+ throw new ArgumentException("Dest buffer is not sufficient.");
+ }
+
+ for (int i = 0; i < source.Length; i++)
+ {
+ dest[i] = checked((int)source[i].Value);
+ }
+ }
+}
diff --git a/src/Nncase.Core/CostModel/Cost.cs b/src/Nncase.Core/CostModel/Cost.cs
index b989507457..e60414a613 100644
--- a/src/Nncase.Core/CostModel/Cost.cs
+++ b/src/Nncase.Core/CostModel/Cost.cs
@@ -204,6 +204,7 @@ public static UInt128 GetMemoryAccess(IRType type)
{
TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * t.DType.SizeInBytes),
TupleType t => t.Fields.Sum(GetMemoryAccess),
+ DistributedType t => GetMemoryAccess(Utilities.DistributedUtility.GetDividedTensorType(t)),
_ => 0,
};
}
@@ -229,6 +230,7 @@ public static UInt128 GetCPUCycles(IRType type, double cyclesPerElement = 1)
{
TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * cyclesPerElement),
TupleType t => t.Fields.Sum(GetMemoryAccess),
+ DistributedType t => GetCPUCycles(Utilities.DistributedUtility.GetDividedTensorType(t)),
_ => 0,
};
}
@@ -328,7 +330,7 @@ public static Cost GetActivationCost(TensorType ret, uint macPerElement)
}
// cost for op similar to broadcast
- public static Cost GetBroadcastCost(TensorType input, TensorType ret)
+ public static Cost GetBroadcastCost(IRType input, IRType ret)
{
return new()
{
diff --git a/src/Nncase.Core/DataTypes.cs b/src/Nncase.Core/DataTypes.cs
index d1a24255e7..0b59a5696e 100644
--- a/src/Nncase.Core/DataTypes.cs
+++ b/src/Nncase.Core/DataTypes.cs
@@ -114,7 +114,7 @@ public static bool IsPointer(this DataType srcType) =>
/// datatype name.
public static string GetDisplayName(this DataType dataType) => dataType switch
{
- PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)}*)",
+ PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)} *)",
PrimType primType => primType.ShortName,
ValueType => dataType.ToString(),
_ => throw new ArgumentOutOfRangeException(dataType.GetType().Name),
diff --git a/src/Nncase.Core/Diagnostics/IDumpper.cs b/src/Nncase.Core/Diagnostics/IDumpper.cs
index 0e07109232..cc84bf46cb 100644
--- a/src/Nncase.Core/Diagnostics/IDumpper.cs
+++ b/src/Nncase.Core/Diagnostics/IDumpper.cs
@@ -42,6 +42,8 @@ public interface IDumpper
void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null);
+ void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null);
+
void DumpModule(IRModule module, string? reletivePath = null);
Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create);
diff --git a/src/Nncase.Core/Diagnostics/NullDumpper.cs b/src/Nncase.Core/Diagnostics/NullDumpper.cs
index 7212fc7686..3120fa25e5 100644
--- a/src/Nncase.Core/Diagnostics/NullDumpper.cs
+++ b/src/Nncase.Core/Diagnostics/NullDumpper.cs
@@ -46,6 +46,11 @@ public void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null)
{
}
+ ///
+ public void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null)
+ {
+ }
+
///
public bool IsEnabled(DumpFlags dumpFlags) => false;
diff --git a/src/Nncase.Core/DistributedType.cs b/src/Nncase.Core/DistributedType.cs
new file mode 100644
index 0000000000..efe52395da
--- /dev/null
+++ b/src/Nncase.Core/DistributedType.cs
@@ -0,0 +1,68 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Globalization;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using DryIoc.ImTools;
+
+namespace Nncase.IR;
+
+public abstract record SBP
+{
+ public static SBPPartialSum P => SBPPartialSum.Instance;
+
+ public static SBPBroadCast B => SBPBroadCast.Instance;
+
+ public static SBPSplit S(int axis) => new SBPSplit(axis);
+}
+
+public sealed record SBPSplit(int Axis) : SBP
+{
+ public override string ToString() => $"S({Axis})";
+}
+
+public sealed record SBPPartialSum : SBP
+{
+ public static readonly SBPPartialSum Instance = new SBPPartialSum();
+
+ private SBPPartialSum()
+ {
+ }
+
+ public override string ToString() => "P";
+}
+
+public sealed record SBPBroadCast : SBP
+{
+ public static readonly SBPBroadCast Instance = new SBPBroadCast();
+
+ private SBPBroadCast()
+ {
+ }
+
+ public override string ToString() => "B";
+}
+
+// public sealed record Placement(Placement.DeviceKind Kind, IRArray Hierarchy, string Name)
+public sealed record Placement(IRArray Hierarchy, string Name)
+{
+ // public enum DeviceKind : uint
+ // {
+ // CPU = 0,
+ // }
+ public int Rank => Hierarchy.Count;
+
+ // public override string ToString() => $"@{Kind} [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]";
+ public override string ToString() => $"@ [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]";
+}
+
+public sealed record DistributedType(TensorType TensorType, IRArray NdSBP, Placement Placement) : IRType
+{
+ public override string ToString() => $"{TensorType}, ({string.Join(',', NdSBP)}), {Placement}";
+}
diff --git a/src/Nncase.Core/FunctionCollector.cs b/src/Nncase.Core/FunctionCollector.cs
index 9ed2a8bca7..12655d25a1 100644
--- a/src/Nncase.Core/FunctionCollector.cs
+++ b/src/Nncase.Core/FunctionCollector.cs
@@ -17,7 +17,7 @@ public FunctionCollector()
public HashSet Functions => _functions;
- protected override int VisitLeafFunction(Function expr, Unit context)
+ protected override int VisitLeafFunction(Function expr)
{
_functions.Add(expr);
return 0;
diff --git a/src/Nncase.Core/IR/Buffers/BufferLoad.cs b/src/Nncase.Core/IR/Buffers/BufferLoad.cs
new file mode 100644
index 0000000000..dbf3427b6e
--- /dev/null
+++ b/src/Nncase.Core/IR/Buffers/BufferLoad.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR.Tensors;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.Buffers;
+
+///
+/// BufferLoad expression.
+///
+[PatternFunctionalGenerator]
+public sealed partial class BufferLoad : Op
+{
+ ///
+ /// Get the input parameter.
+ ///
+ public static readonly ParameterInfo Input = new(typeof(BufferLoad), 0, "input", IsTensor());
+
+ ///
+ /// Get the indices.
+ ///
+ public static readonly ParameterInfo Indices = new(typeof(BufferLoad), 1, "indices", IsTuple());
+
+ ///
+ public override bool CanFoldConstCall => false;
+}
diff --git a/src/Nncase.Core/IR/Buffers/BufferOf.cs b/src/Nncase.Core/IR/Buffers/BufferOf.cs
index a3bb033275..47a2541c1b 100644
--- a/src/Nncase.Core/IR/Buffers/BufferOf.cs
+++ b/src/Nncase.Core/IR/Buffers/BufferOf.cs
@@ -16,7 +16,7 @@ public sealed partial class BufferOf : Op
///
public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor());
- public Schedule.MemoryLocation MemoryLocation { get; }
+ public TIR.MemoryLocation MemoryLocation { get; }
///
public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}";
diff --git a/src/Nncase.Core/IR/Buffers/BufferStore.cs b/src/Nncase.Core/IR/Buffers/BufferStore.cs
new file mode 100644
index 0000000000..2d8e86cad8
--- /dev/null
+++ b/src/Nncase.Core/IR/Buffers/BufferStore.cs
@@ -0,0 +1,33 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR.Tensors;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.Buffers;
+
+///
+/// BufferStore op.
+///
+[PatternFunctionalGenerator]
+public sealed partial class BufferStore : Op
+{
+ ///
+ /// Get the input parameter.
+ ///
+ public static readonly ParameterInfo Input = new(typeof(BufferStore), 0, "input", IsTensor());
+
+ ///
+ /// Get the indices parameter.
+ ///
+ public static readonly ParameterInfo Indices = new(typeof(BufferStore), 1, "indices", IsTuple());
+
+ ///
+ /// Get the value parameter.
+ ///
+ public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsScalar());
+
+ ///
+ public override bool CanFoldConstCall => false;
+}
diff --git a/src/Nncase.Core/IR/Buffers/DDrOf.cs b/src/Nncase.Core/IR/Buffers/DDrOf.cs
index 8657d3f417..116e019d5f 100644
--- a/src/Nncase.Core/IR/Buffers/DDrOf.cs
+++ b/src/Nncase.Core/IR/Buffers/DDrOf.cs
@@ -17,4 +17,7 @@ public sealed partial class DDrOf : Op
/// Get the input parameter.
///
public static readonly ParameterInfo Input = new(typeof(DDrOf), 0, "input", IsTensor());
+
+ ///
+ public override bool CanFoldConstCall => false;
}
diff --git a/src/Nncase.Core/IR/Buffers/Functional.cs b/src/Nncase.Core/IR/Buffers/Functional.cs
index 54cd9f59cf..a2e3507a5f 100644
--- a/src/Nncase.Core/IR/Buffers/Functional.cs
+++ b/src/Nncase.Core/IR/Buffers/Functional.cs
@@ -41,5 +41,5 @@ public static Call BaseMentOf(Expr input) =>
///
/// create the uninitialized buffer.
///
- public static Call Uninitialized(DataType dataType, Schedule.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
+ public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
}
diff --git a/src/Nncase.Core/IR/Buffers/MatchBuffer.cs b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs
new file mode 100644
index 0000000000..3cafa7f595
--- /dev/null
+++ b/src/Nncase.Core/IR/Buffers/MatchBuffer.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR.Tensors;
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.Buffers;
+
+///
+/// MatchBuffer op.
+/// todo maybe need united matchbuffer and allocatebuffer.
+///
+[PatternFunctionalGenerator]
+public sealed partial class MatchBuffer : Op
+{
+ public static readonly ParameterInfo Input = new(typeof(MatchBuffer), 0, "input", IsTensor());
+
+ ///
+ public override bool CanFoldConstCall => false;
+}
diff --git a/src/Nncase.Core/IR/Buffers/Uninitialized.cs b/src/Nncase.Core/IR/Buffers/Uninitialized.cs
index 2f529638b9..42564bbc4c 100644
--- a/src/Nncase.Core/IR/Buffers/Uninitialized.cs
+++ b/src/Nncase.Core/IR/Buffers/Uninitialized.cs
@@ -19,11 +19,11 @@ public sealed partial class Uninitialized : Op
public DataType DType { get; }
- public Schedule.MemoryLocation MemoryLocation { get; }
+ public TIR.MemoryLocation MemoryLocation { get; }
///
public override bool CanFoldConstCall => false;
///
- public override string DisplayProperty() => $"{DType.GetCSharpName()}, Schedule.MemoryLocation.{MemoryLocation}";
+ public override string DisplayProperty() => $"{DType.GetCSharpName()}, MemoryLocation.{MemoryLocation}";
}
diff --git a/src/Nncase.Core/IR/Callable.cs b/src/Nncase.Core/IR/Callable.cs
index edd6004539..16fc9ad5a7 100644
--- a/src/Nncase.Core/IR/Callable.cs
+++ b/src/Nncase.Core/IR/Callable.cs
@@ -17,7 +17,7 @@ public abstract class Callable : Expr
///
/// StackVM module kind.
///
- public static readonly string StackVMModuleKind = "stackvm";
+ public const string StackVMModuleKind = "stackvm";
public Callable(string name, string moduleKind, Expr[] operands)
: base(operands)
diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs
index 214605c69e..855ff4e22e 100644
--- a/src/Nncase.Core/IR/ExprCloner.g.cs
+++ b/src/Nncase.Core/IR/ExprCloner.g.cs
@@ -1,4 +1,3 @@
-
//---------------------------------------------------------------------------------------------------
//
// This code was generated by T4 template.
@@ -57,8 +56,7 @@ protected override Expr VisitLeafIf(If expr, TContext context)
return expr.With(
condition: Clone(expr.Condition, context),
then: Clone(expr.Then, context),
- @else: Clone(expr.Else, context),
- paramList: expr.ParamList.Select(p => Clone(p, context)).ToArray()
+ @else: Clone(expr.Else, context)
);
}
@@ -141,31 +139,6 @@ protected override Expr VisitLeafBlock(TIR.Block expr, TContext context)
);
}
- ///
- protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
- {
- return expr.With(
- dimensions: CloneArray(expr.Dimensions, context),
- strides: CloneArray(expr.Strides, context)
- );
- }
-
- ///
- protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
- {
- return expr.With(
- );
- }
-
- ///
- protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
- {
- return expr.With(
- buffer: Clone(expr.Buffer, context),
- indices: CloneArray(expr.Indices, context)
- );
- }
-
///
protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
{
@@ -175,16 +148,6 @@ protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext co
);
}
- ///
- protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
- {
- return expr.With(
- buffer: Clone(expr.Buffer, context),
- indices: CloneArray(expr.Indices, context),
- value: Clone(expr.Value, context)
- );
- }
-
///
protected override Expr VisitLeafFor(TIR.For expr, TContext context)
{
diff --git a/src/Nncase.Core/IR/ExprFunctor.cs b/src/Nncase.Core/IR/ExprFunctor.cs
index 4462f8d2cc..2d19dbc1b3 100644
--- a/src/Nncase.Core/IR/ExprFunctor.cs
+++ b/src/Nncase.Core/IR/ExprFunctor.cs
@@ -102,6 +102,13 @@ public partial class ExprFunctor : ExprFunctorResult.
public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default);
+ ///
+ /// Visit point type.
+ ///
+ /// pointer type.
+ /// Result.
+ public virtual TTypeResult VisitType(PointerType type) => base.VisitType(type, default);
+
///
/// Visit tuple type.
///
@@ -116,6 +123,13 @@ public partial class ExprFunctor : ExprFunctorResult.
public virtual TTypeResult VisitType(CallableType type) => base.VisitType(type, default);
+ ///
+ /// Visit callable type.
+ ///
+ /// Callable type.
+ /// Result.
+ public virtual TTypeResult VisitType(DistributedType type) => base.VisitType(type, default);
+
///
/// Default visit routine.
///
@@ -135,12 +149,18 @@ public partial class ExprFunctor : ExprFunctor
public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type);
+ ///
+ public sealed override TTypeResult VisitType(PointerType type, Unit context) => VisitType(type);
+
///
public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type);
///
public sealed override TTypeResult VisitType(CallableType type, Unit context) => VisitType(type);
+ ///
+ public sealed override TTypeResult VisitType(DistributedType type, Unit context) => VisitType(type);
+
///
public sealed override TTypeResult DefaultVisitType(IRType type, Unit context) => DefaultVisitType(type);
diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs
index 642b2709e4..188aad4659 100644
--- a/src/Nncase.Core/IR/ExprFunctor.g.cs
+++ b/src/Nncase.Core/IR/ExprFunctor.g.cs
@@ -1,4 +1,3 @@
-
//---------------------------------------------------------------------------------------------------
//
// This code was generated by T4 template.
@@ -79,6 +78,11 @@ public partial class ExprFunctor
///
internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context);
+ ///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context);
+
///
/// Visit .
///
@@ -94,31 +98,11 @@ public partial class ExprFunctor
///
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context);
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context);
-
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context);
-
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context);
-
///
/// Visit .
///
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context);
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context);
-
///
/// Visit .
///
@@ -250,6 +234,13 @@ public partial class ExprFunctor
///
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default);
+
+ ///
+ internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr);
+ ///
/// Visit .
///
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
@@ -271,27 +262,6 @@ public partial class ExprFunctor
///
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
- ///
/// Visit .
///
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
@@ -299,13 +269,6 @@ public partial class ExprFunctor
///
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
- ///
/// Visit .
///
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs
index 4c8cece3f2..b842c110f1 100644
--- a/src/Nncase.Core/IR/ExprRewriter.g.cs
+++ b/src/Nncase.Core/IR/ExprRewriter.g.cs
@@ -1,4 +1,3 @@
-
//---------------------------------------------------------------------------------------------------
//
// This code was generated by T4 template.
@@ -92,6 +91,12 @@ protected sealed override Expr VisitLeafTupleConst(TupleConst expr, TContext con
return RewriteLeafTupleConst(expr, context);
}
+ ///
+ protected sealed override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context)
+ {
+ return RewriteLeafMemSpan(expr, context);
+ }
+
///
protected sealed override Expr VisitLeafVar(Var expr, TContext context)
{
@@ -110,36 +115,12 @@ protected sealed override Expr VisitLeafBuffer(TIR.Buffer expr, TContext context
return RewriteLeafBuffer(expr, context);
}
- ///
- protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
- {
- return RewriteLeafLogicalBuffer(expr, context);
- }
-
- ///
- protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
- {
- return RewriteLeafPhysicalBuffer(expr, context);
- }
-
- ///
- protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
- {
- return RewriteLeafBufferLoad(expr, context);
- }
-
///
protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
{
return RewriteLeafBufferRegion(expr, context);
}
- ///
- protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
- {
- return RewriteLeafBufferStore(expr, context);
- }
-
///
protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context)
{
@@ -247,6 +228,11 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte
///
protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context);
+ ///
+ /// Rewrite leaf .
+ ///
+ protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context);
+
///
/// Rewrite leaf .
///
@@ -262,31 +248,11 @@ protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext conte
///
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context);
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
-
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
-
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context);
-
///
/// Rewrite leaf .
///
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context);
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context);
-
///
/// Rewrite leaf .
///
@@ -430,6 +396,14 @@ public partial class ExprRewriter
///
protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr);
+ ///
+ /// Rewrite leaf .
+ ///
+ protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr);
+
+ ///
+ protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr);
+
///
/// Rewrite leaf .
///
@@ -454,30 +428,6 @@ public partial class ExprRewriter
///
protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr);
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr);
-
- ///
- protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr);
-
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr);
-
- ///
- protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr);
-
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr);
-
- ///
- protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr);
-
///
/// Rewrite leaf .
///
@@ -486,14 +436,6 @@ public partial class ExprRewriter
///
protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr);
- ///
- /// Rewrite leaf .
- ///
- protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr);
-
- ///
- protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr);
-
///
/// Rewrite leaf .
///
diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs
index c296f5f7e0..dd974ec60b 100644
--- a/src/Nncase.Core/IR/ExprVisitor.g.cs
+++ b/src/Nncase.Core/IR/ExprVisitor.g.cs
@@ -1,4 +1,3 @@
-
//---------------------------------------------------------------------------------------------------
//
// This code was generated by T4 template.
@@ -104,38 +103,31 @@ protected internal override TExprResult VisitTupleConst(TupleConst expr, TContex
}
///
- protected internal override TExprResult VisitVar(Var expr, TContext context)
+ protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context)
{
VisitOperands(expr, context);
- return VisitLeafVar(expr, context);
+ return VisitLeafMemSpan(expr, context);
}
///
- protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context)
- {
- VisitOperands(expr, context);
- return VisitLeafBlock(expr, context);
- }
-
- ///
- protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
+ protected internal override TExprResult VisitVar(Var expr, TContext context)
{
VisitOperands(expr, context);
- return VisitLeafLogicalBuffer(expr, context);
+ return VisitLeafVar(expr, context);
}
///
- protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
+ protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context)
{
VisitOperands(expr, context);
- return VisitLeafPhysicalBuffer(expr, context);
+ return VisitLeafBlock(expr, context);
}
///
- protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context)
+ protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext context)
{
VisitOperands(expr, context);
- return VisitLeafBufferLoad(expr, context);
+ return VisitLeafBuffer(expr, context);
}
///
@@ -145,13 +137,6 @@ protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr,
return VisitLeafBufferRegion(expr, context);
}
- ///
- protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context)
- {
- VisitOperands(expr, context);
- return VisitLeafBufferStore(expr, context);
- }
-
///
protected internal override TExprResult VisitFor(TIR.For expr, TContext context)
{
@@ -270,6 +255,11 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext
///
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context);
+ ///
+ /// Visit leaf .
+ ///
+ protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context);
+
///
/// Visit leaf .
///
@@ -285,31 +275,11 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext
///
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context);
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
-
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
-
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context);
-
///
/// Visit leaf .
///
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context);
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context);
-
///
/// Visit leaf .
///
@@ -353,182 +323,168 @@ public partial class ExprVisitor
/// Visit .
///
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
-
+
///
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
-
+
///
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
-
+
///
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
-
+
///
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
-
+
///
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
-
+
///
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
-
+
///
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
-
+
///
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
-
+
///
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
-
+
///
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
-
+
///
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
///
+ /// Visit .
+ ///
+ internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default);
+
+ ///
+ internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr);
+ ///
/// Visit .
///
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
-
+
///
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
-
+
///
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
- ///
- /// Visit .
+ /// Visit .
///
- internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
-
+ internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default);
+
///
- internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
- ///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
+ internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
-
+
///
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
///
- /// Visit .
- ///
- internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
-
- ///
- internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
- ///
/// Visit .
///
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
-
+
///
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
-
+
///
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
-
+
///
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
-
+
///
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
-
+
///
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
-
+
///
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
///
/// Visit .
///
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
-
+
///
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
///
/// Visit leaf .
///
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr);
@@ -536,7 +492,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr);
@@ -544,7 +500,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr);
@@ -552,15 +508,15 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default);
-
+
///
- protected override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
+ protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
///
/// Visit leaf .
///
protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr);
@@ -568,7 +524,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr);
@@ -576,7 +532,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr);
@@ -584,7 +540,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr);
@@ -592,7 +548,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr);
@@ -600,7 +556,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr);
@@ -608,7 +564,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);
@@ -616,7 +572,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr);
@@ -624,15 +580,23 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr);
+ ///
+ /// Visit leaf .
+ ///
+ protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default);
+
+ ///
+ protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr);
+
///
/// Visit leaf .
///
protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr);
@@ -640,7 +604,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr);
@@ -648,55 +612,23 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr);
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default);
-
- ///
- protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr);
-
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default);
-
- ///
- protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr);
-
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default);
-
- ///
- protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr);
-
///
/// Visit leaf .
///
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr);
- ///
- /// Visit leaf .
- ///
- protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default);
-
- ///
- protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr);
-
///
/// Visit leaf .
///
protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr);
@@ -704,7 +636,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr);
@@ -712,7 +644,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr);
@@ -720,7 +652,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr);
@@ -728,7 +660,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr);
@@ -736,7 +668,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr);
@@ -744,7 +676,7 @@ public partial class ExprVisitor
/// Visit leaf .
///
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default);
-
+
///
protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr);
diff --git a/src/Nncase.Core/IR/IIRPrinterProvider.cs b/src/Nncase.Core/IR/IIRPrinterProvider.cs
index 3c267771df..f411167f2a 100644
--- a/src/Nncase.Core/IR/IIRPrinterProvider.cs
+++ b/src/Nncase.Core/IR/IIRPrinterProvider.cs
@@ -73,6 +73,14 @@ public interface IIRPrinterProvider
/// randConst = false will save the const into bin.
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst);
+ ///
+ /// dump the expr as csharp code.
+ ///
+ /// expression.
+ /// file prefix.
+ /// file dump ir.
+ public void DumpPatternIR(Expr expr, string prefix, string dumpDir);
+
///
/// print ir type.
///
diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv
index 5ae3c89d18..ba9dd8033b 100644
--- a/src/Nncase.Core/IR/IRList.csv
+++ b/src/Nncase.Core/IR/IRList.csv
@@ -11,14 +11,11 @@ PrimFunctionWrapper,true,true,BaseFunction,,Target
TensorConst,true,false,Const,,
Tuple,true,false,Default,IR.,@Fields
TupleConst,true,false,Const,,
+MemSpan,true,false,Default,TIR.,Start;Size;
Var,true,false,Default,,
Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
-Buffer,false,false,Default,TIR.,
-LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides
-PhysicalBuffer,true,false,Buffer,TIR.,
-BufferLoad,true,false,Default,TIR.,Buffer;@Indices
+Buffer,true,false,Default,TIR.,MemSpan;@Dimensions;@Strides;
BufferRegion,true,false,Default,TIR.,Buffer;@Region
-BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value
For,true,false,Default,TIR.,LoopVar;Domain;Body
IfThenElse,true,false,Default,TIR.,Condition;Then;Else
Let,true,false,Default,TIR.,Var;Expression;Body
diff --git a/src/Nncase.Core/IR/IRType.cs b/src/Nncase.Core/IR/IRType.cs
index b8aeda469f..a4311aae13 100644
--- a/src/Nncase.Core/IR/IRType.cs
+++ b/src/Nncase.Core/IR/IRType.cs
@@ -139,6 +139,15 @@ public sealed record TensorType(DataType DType, Shape Shape) : IRType
/// the Pointed Element Type.
/// the pointer tensor type.
public static TensorType Pointer(DataType elemType) => new(new PointerType(elemType), Shape.Scalar);
+
+ ///
+ public override string ToString() => DType switch
+ {
+ PrimType ptype => ptype.GetDisplayName() + (Shape.IsScalar ? string.Empty : Shape.ToString()),
+ PointerType { ElemType: PrimType etype } => $"*{etype.GetDisplayName()}",
+ ValueType => $"{DType}",
+ _ => throw new NotSupportedException(DType.GetType().Name),
+ };
}
///
diff --git a/src/Nncase.Core/IR/Imaging/ResizeImage.cs b/src/Nncase.Core/IR/Imaging/ResizeImage.cs
index 088651b511..ae48d48831 100644
--- a/src/Nncase.Core/IR/Imaging/ResizeImage.cs
+++ b/src/Nncase.Core/IR/Imaging/ResizeImage.cs
@@ -21,7 +21,7 @@ public sealed partial class ResizeImage : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
+ public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"), ParameterKind.Input);
///
/// Gets roi.
diff --git a/src/Nncase.Core/IR/Math/Binary.cs b/src/Nncase.Core/IR/Math/Binary.cs
index ead10ff710..f61f8a8704 100644
--- a/src/Nncase.Core/IR/Math/Binary.cs
+++ b/src/Nncase.Core/IR/Math/Binary.cs
@@ -20,12 +20,12 @@ public sealed partial class Binary : Op
///
/// Gets lhs.
///
- public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs");
+ public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs", ParameterKind.Input);
///
/// Gets rhs.
///
- public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs");
+ public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs", ParameterKind.Input);
public BinaryOp BinaryOp { get; }
diff --git a/src/Nncase.Core/IR/Math/Clamp.cs b/src/Nncase.Core/IR/Math/Clamp.cs
index 9f14cf287d..b8409f375c 100644
--- a/src/Nncase.Core/IR/Math/Clamp.cs
+++ b/src/Nncase.Core/IR/Math/Clamp.cs
@@ -21,7 +21,7 @@ public sealed partial class Clamp : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input", ParameterKind.Input);
///
/// Gets min.
diff --git a/src/Nncase.Core/IR/Math/MatMul.cs b/src/Nncase.Core/IR/Math/MatMul.cs
index 51d5615f1f..fc74e211e1 100644
--- a/src/Nncase.Core/IR/Math/MatMul.cs
+++ b/src/Nncase.Core/IR/Math/MatMul.cs
@@ -20,10 +20,10 @@ public sealed partial class MatMul : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs");
+ public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs", ParameterKind.Input);
///
/// Gets Other.
///
- public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs");
+ public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs", ParameterKind.Input);
}
diff --git a/src/Nncase.Core/IR/Math/ReduceArg.cs b/src/Nncase.Core/IR/Math/ReduceArg.cs
index ecad8e95e2..2afd43010c 100644
--- a/src/Nncase.Core/IR/Math/ReduceArg.cs
+++ b/src/Nncase.Core/IR/Math/ReduceArg.cs
@@ -21,7 +21,7 @@ public sealed partial class ReduceArg : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input", ParameterKind.Input);
///
/// Gets Axis.
@@ -42,8 +42,8 @@ public sealed partial class ReduceArg : Op
public ReduceArgOp ReduceArgOp { get; }
- public DataType DestType { get; }
+ public PrimType DestType { get; }
///
- public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}";
+ public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}, DestType: {DestType}";
}
diff --git a/src/Nncase.Core/IR/Math/Unary.cs b/src/Nncase.Core/IR/Math/Unary.cs
index 820572437e..20d6b3fb03 100644
--- a/src/Nncase.Core/IR/Math/Unary.cs
+++ b/src/Nncase.Core/IR/Math/Unary.cs
@@ -20,7 +20,7 @@ public sealed partial class Unary : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input", ParameterKind.Input);
public UnaryOp UnaryOp { get; }
diff --git a/src/Nncase.Core/IR/NN/Activations.cs b/src/Nncase.Core/IR/NN/Activations.cs
index 46df70d241..1866e65220 100644
--- a/src/Nncase.Core/IR/NN/Activations.cs
+++ b/src/Nncase.Core/IR/NN/Activations.cs
@@ -154,7 +154,12 @@ public sealed partial class Swish : ActivationOp
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input", ParameterKind.Input);
+
+ ///
+ /// Gets beta.
+ ///
+ public static readonly ParameterInfo Beta = new(typeof(Swish), 1, "beta", IsFloatScalar());
}
///
diff --git a/src/Nncase.Core/IR/NN/Conv2D.cs b/src/Nncase.Core/IR/NN/Conv2D.cs
index 43c1b2fced..0607870481 100644
--- a/src/Nncase.Core/IR/NN/Conv2D.cs
+++ b/src/Nncase.Core/IR/NN/Conv2D.cs
@@ -21,17 +21,17 @@ public sealed partial class Conv2D : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input", ParameterKind.Input);
///
/// Gets Weights.
///
- public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
+ public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4), ParameterKind.Input);
///
/// Gets Bias.
///
- public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
+ public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1), ParameterKind.Input);
///
/// Gets Stride.
diff --git a/src/Nncase.Core/IR/NN/Functional.cs b/src/Nncase.Core/IR/NN/Functional.cs
index 30b4005388..e8a44d8a38 100644
--- a/src/Nncase.Core/IR/NN/Functional.cs
+++ b/src/Nncase.Core/IR/NN/Functional.cs
@@ -34,7 +34,7 @@ public static class NN
public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum);
- public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias) => new Call(new LayerNorm(axis, epsilon), input, scale, bias);
+ public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias);
public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops);
@@ -103,5 +103,10 @@ public static Call ReduceWindow2D(ReduceOp reduceOp, Expr input, Expr initValue,
///
/// create Swish call.
///
- public static Call Swish(Expr input) => new Call(new Swish(), input);
+ public static Call Swish(Expr input) => new Call(new Swish(), input, 1f);
+
+ ///
+ /// create Swish call.
+ ///
+ public static Call Swish(Expr input, Expr beta) => new Call(new Swish(), input, beta);
}
diff --git a/src/Nncase.Core/IR/NN/LayerNorm.cs b/src/Nncase.Core/IR/NN/LayerNorm.cs
index 2dff32f440..2474f44fc2 100644
--- a/src/Nncase.Core/IR/NN/LayerNorm.cs
+++ b/src/Nncase.Core/IR/NN/LayerNorm.cs
@@ -21,19 +21,23 @@ public sealed partial class LayerNorm : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input", ParameterKind.Input);
///
/// Gets scale.
///
- public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale");
+ public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale", ParameterKind.Input);
///
/// Gets bias.
///
- public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias");
+ public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias", ParameterKind.Input);
public int Axis { get; }
public float Epsilon { get; }
+
+ public bool UseMean { get; }
+
+ public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}";
}
diff --git a/src/Nncase.Core/IR/NN/Normalization.cs b/src/Nncase.Core/IR/NN/Normalization.cs
index ba91f13f83..2b7b74168c 100644
--- a/src/Nncase.Core/IR/NN/Normalization.cs
+++ b/src/Nncase.Core/IR/NN/Normalization.cs
@@ -61,17 +61,17 @@ public sealed partial class InstanceNormalization : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input", ParameterKind.Input);
///
/// Gets input.
///
- public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale");
+ public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale", ParameterKind.Input);
///
/// Gets input.
///
- public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias");
+ public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias", ParameterKind.Input);
///
/// Gets Epsilon.
diff --git a/src/Nncase.Core/IR/NN/SoftMax.cs b/src/Nncase.Core/IR/NN/SoftMax.cs
index 686696f032..919f8ac76c 100644
--- a/src/Nncase.Core/IR/NN/SoftMax.cs
+++ b/src/Nncase.Core/IR/NN/SoftMax.cs
@@ -33,7 +33,7 @@ public sealed partial class Softmax : Op
///
/// Gets input.
///
- public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input");
+ public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input", ParameterKind.Input);
///
/// Gets axis.
diff --git a/src/Nncase.Core/IR/Op.cs b/src/Nncase.Core/IR/Op.cs
index 07d4bbcabb..2af2e39fd3 100644
--- a/src/Nncase.Core/IR/Op.cs
+++ b/src/Nncase.Core/IR/Op.cs
@@ -12,6 +12,12 @@
namespace Nncase.IR;
+public enum ParameterKind : int
+{
+ Input,
+ Attribute,
+}
+
///
/// Parameter information.
///
@@ -24,11 +30,13 @@ public sealed class ParameterInfo
/// this op type.
/// param index.
/// param name.
- public ParameterInfo(Type ownerType, int index, string name)
+ /// kind.
+ public ParameterInfo(Type ownerType, int index, string name, ParameterKind parameterKind = ParameterKind.Attribute)
{
OwnerType = ownerType;
Index = index;
Name = name;
+ ParameterKind = parameterKind;
}
///
@@ -39,8 +47,9 @@ public ParameterInfo(Type ownerType, int index, string name)
/// param index.
/// param name.
/// the param condition.
- public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern)
- : this(ownerType, index, name)
+ /// kind.
+ public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern, ParameterKind parameterKind = ParameterKind.Attribute)
+ : this(ownerType, index, name, parameterKind)
{
Pattern = pattern;
}
@@ -60,6 +69,11 @@ public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern
///
public string Name { get; }
+ ///
+ /// Gets parameter kind.
+ ///
+ public ParameterKind ParameterKind { get; }
+
///
/// Gets this paramter's type condition.
///
@@ -90,7 +104,7 @@ public Op()
///
/// Gets get the parameters.
///
- public IEnumerable Parameters =>
+ public virtual IEnumerable Parameters =>
_parameters ??= (from p in GetType().GetFields(BindingFlags.Public | BindingFlags.Static)
where p.FieldType == typeof(ParameterInfo)
let param = (ParameterInfo)(p.GetValue(null) ?? throw new InvalidOperationException())
diff --git a/src/Nncase.Core/IR/RNN/Functional.cs b/src/Nncase.Core/IR/RNN/Functional.cs
index 07ebf7f8d6..571bb2141f 100644
--- a/src/Nncase.Core/IR/RNN/Functional.cs
+++ b/src/Nncase.Core/IR/RNN/Functional.cs
@@ -19,5 +19,5 @@ namespace Nncase.IR.F;
public static class RNN
{
public static Call LSTM(LSTMDirection direction, LSTMLayout layout, string[] acts, Expr x, Expr w, Expr r, Expr b, Expr seqLens, Expr initH, Expr initC, Expr p, Expr actAlpha, Expr actBeta, Expr clip, Expr hiddenSize, Expr inputForget, Expr outputSize) =>
- new Call(new IR.Tensors.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize);
+ new Call(new IR.RNN.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize);
}
diff --git a/src/Nncase.Core/IR/RNN/LSTM.cs b/src/Nncase.Core/IR/RNN/LSTM.cs
index ec5b9802b3..4bdd60f223 100644
--- a/src/Nncase.Core/IR/RNN/LSTM.cs
+++ b/src/Nncase.Core/IR/RNN/LSTM.cs
@@ -5,7 +5,7 @@
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
-namespace Nncase.IR.Tensors;
+namespace Nncase.IR.RNN;
///
/// LSTM expression.
diff --git a/src/Nncase.Core/IR/TensorConst.cs b/src/Nncase.Core/IR/TensorConst.cs
index 64b1fd6442..9e651978ed 100644
--- a/src/Nncase.Core/IR/TensorConst.cs
+++ b/src/Nncase.Core/IR/TensorConst.cs
@@ -146,7 +146,7 @@ public override TExprResult Accept