diff --git a/stablehlo/integrations/c/CMakeLists.txt b/stablehlo/integrations/c/CMakeLists.txt index 88dc7c2b938..1e0bcfb4726 100644 --- a/stablehlo/integrations/c/CMakeLists.txt +++ b/stablehlo/integrations/c/CMakeLists.txt @@ -13,6 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_mlir_public_c_api_library(CheckCAPI + PARTIAL_SOURCES_INTENDED + CheckDialect.cpp + + LINK_LIBS PUBLIC + CheckOps +) + add_mlir_public_c_api_library(ChloCAPI PARTIAL_SOURCES_INTENDED ChloAttributes.cpp diff --git a/stablehlo/integrations/c/CheckDialect.cpp b/stablehlo/integrations/c/CheckDialect.cpp new file mode 100644 index 00000000000..b30f3626a8b --- /dev/null +++ b/stablehlo/integrations/c/CheckDialect.cpp @@ -0,0 +1,19 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/integrations/c/CheckDialect.h" + +#include "mlir/CAPI/Registration.h" +#include "stablehlo/tests/CheckOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Check, check, + mlir::stablehlo::check::CheckDialect) diff --git a/stablehlo/integrations/c/CheckDialect.h b/stablehlo/integrations/c/CheckDialect.h new file mode 100644 index 00000000000..b50303e6112 --- /dev/null +++ b/stablehlo/integrations/c/CheckDialect.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H +#define STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H + +#include "mlir-c/RegisterEverything.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Check, check); + +#ifdef __cplusplus +} +#endif + +#endif // STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H diff --git a/stablehlo/integrations/python/CMakeLists.txt b/stablehlo/integrations/python/CMakeLists.txt index 5187e4affd6..fac75773a16 100644 --- a/stablehlo/integrations/python/CMakeLists.txt +++ b/stablehlo/integrations/python/CMakeLists.txt @@ -22,6 +22,18 @@ include(AddMLIRPython) # putting .td and .py files under . instead of mlir/python will break things, # even if the build rules below are adjusted accordingly. +declare_mlir_python_sources(CheckPythonSources) +declare_mlir_python_sources(CheckPythonSources.Dialects + ADD_TO_PARENT CheckPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT CheckPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/CheckOps.td + SOURCES dialects/check.py + DIALECT_NAME check) + declare_mlir_python_sources(ChloPythonSources) declare_mlir_python_sources(ChloPythonSources.Dialects ADD_TO_PARENT ChloPythonSources @@ -69,6 +81,18 @@ declare_mlir_dialect_python_bindings( # Extensions ################################################################################ +declare_mlir_python_sources(CheckPythonExtensions) +declare_mlir_python_extension(CheckPythonExtensions.Main + MODULE_NAME _check + ADD_TO_PARENT CheckPythonExtensions + SOURCES + CheckModule.cpp + EMBED_CAPI_LINK_LIBS + CheckCAPI + PRIVATE_LINK_LIBS + LLVMSupport +) + declare_mlir_python_sources(ChloPythonExtensions) declare_mlir_python_extension(ChloPythonExtensions.Main MODULE_NAME _chlo @@ -127,6 +151,8 @@ add_mlir_python_common_capi_library(StablehloUnifiedPythonCAPI DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.RegisterEverything + CheckPythonSources + CheckPythonExtensions ChloPythonSources ChloPythonExtensions StablehloPythonSources @@ -141,6 +167,8 @@ add_mlir_python_modules(StablehloUnifiedPythonModules DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.RegisterEverything + CheckPythonSources + CheckPythonExtensions ChloPythonSources ChloPythonExtensions StablehloPythonSources diff --git a/stablehlo/integrations/python/CheckModule.cpp b/stablehlo/integrations/python/CheckModule.cpp new file mode 100644 index 00000000000..c67c7911a8e --- /dev/null +++ b/stablehlo/integrations/python/CheckModule.cpp @@ -0,0 +1,36 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "stablehlo/integrations/c/CheckDialect.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_check, m) { + m.doc() = "check main python extension"; + + // + // Dialects. + // + + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle dialect = mlirGetDialectHandle__check__(); + mlirDialectHandleRegisterDialect(dialect, context); + if (load) { + mlirDialectHandleLoadDialect(dialect, context); + } + }, + py::arg("context"), py::arg("load") = true); +} diff --git a/stablehlo/integrations/python/mlir/dialects/CheckOps.td b/stablehlo/integrations/python/mlir/dialects/CheckOps.td new file mode 100644 index 00000000000..9b606d33765 --- /dev/null +++ b/stablehlo/integrations/python/mlir/dialects/CheckOps.td @@ -0,0 +1,21 @@ +/* Copyright 2024 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS +#define STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS + +include "stablehlo/tests/CheckOps.td" + +#endif diff --git a/stablehlo/integrations/python/mlir/dialects/check.py b/stablehlo/integrations/python/mlir/dialects/check.py new file mode 100644 index 00000000000..19f9fc37635 --- /dev/null +++ b/stablehlo/integrations/python/mlir/dialects/check.py @@ -0,0 +1,18 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=wildcard-import,relative-beyond-top-level,g-import-not-at-top +from ._check_ops_gen import * +from .._mlir_libs._check import * diff --git a/stablehlo/integrations/python/tests/CMakeLists.txt b/stablehlo/integrations/python/tests/CMakeLists.txt index f83b7cd2140..8661662ae96 100644 --- a/stablehlo/integrations/python/tests/CMakeLists.txt +++ b/stablehlo/integrations/python/tests/CMakeLists.txt @@ -24,6 +24,7 @@ add_custom_target(${test_name} add_dependencies(check-stablehlo-python ${test_name}) endfunction() +add_stablehlo_python_test(stablehlo-python-check check.py) add_stablehlo_python_test(stablehlo-python-chlo chlo.py) add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py) add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py) diff --git a/stablehlo/integrations/python/tests/check.py b/stablehlo/integrations/python/tests/check.py new file mode 100644 index 00000000000..70c03adbc20 --- /dev/null +++ b/stablehlo/integrations/python/tests/check.py @@ -0,0 +1,44 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CHECK Python APIs.""" + +# pylint: disable=wildcard-import,undefined-variable + +from mlir import ir +from mlir.dialects import check as check_dialect +from mlir.dialects import stablehlo as stablehlo_dialect + + +def run(f): + with ir.Context() as context: + check_dialect.register_dialect(context) + stablehlo_dialect.register_dialect(context) + f() + return f + +@run +def test_parse(): + asm = """ + module { + func.func @main() { + %cst = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + check.expect_eq_const %0, dense<[4.0, 6.0]> : tensor<2xf32> + return + } + } + """ + ir.Module.parse(asm)