From eba821aa1c54a21d70331d7926dfc8b929f988f3 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 23 Jul 2024 16:07:58 +0000 Subject: [PATCH] Python bindings for registering check dialect (#2445) Check dialect is an auxiliary dialect used in StableHLO repository for validation of StableHLO program evaluation. Currently there is no cleaner way to parse a module containing check dialect operations or to create them. One way to go around this is to textually modify the check dialect ops to normalize them to generic MLIR text and allow unregistered dialect to pars that text. However, this PR provodes a cleaner way to process check dialect. This PR prepares for open sourcing some of the utilities, leveraged in https://github.com/openxla/stablehlo/pull/2404, for auto-generating testdata formatted test files. --- stablehlo/integrations/c/CMakeLists.txt | 8 ++++ stablehlo/integrations/c/CheckDialect.cpp | 19 ++++++++ stablehlo/integrations/c/CheckDialect.h | 28 ++++++++++++ stablehlo/integrations/python/CMakeLists.txt | 28 ++++++++++++ stablehlo/integrations/python/CheckModule.cpp | 36 +++++++++++++++ .../python/mlir/dialects/CheckOps.td | 21 +++++++++ .../python/mlir/dialects/check.py | 18 ++++++++ .../integrations/python/tests/CMakeLists.txt | 1 + stablehlo/integrations/python/tests/check.py | 44 +++++++++++++++++++ 9 files changed, 203 insertions(+) create mode 100644 stablehlo/integrations/c/CheckDialect.cpp create mode 100644 stablehlo/integrations/c/CheckDialect.h create mode 100644 stablehlo/integrations/python/CheckModule.cpp create mode 100644 stablehlo/integrations/python/mlir/dialects/CheckOps.td create mode 100644 stablehlo/integrations/python/mlir/dialects/check.py create mode 100644 stablehlo/integrations/python/tests/check.py diff --git a/stablehlo/integrations/c/CMakeLists.txt b/stablehlo/integrations/c/CMakeLists.txt index 88dc7c2b938..1e0bcfb4726 100644 --- a/stablehlo/integrations/c/CMakeLists.txt +++ b/stablehlo/integrations/c/CMakeLists.txt @@ -13,6 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_mlir_public_c_api_library(CheckCAPI + PARTIAL_SOURCES_INTENDED + CheckDialect.cpp + + LINK_LIBS PUBLIC + CheckOps +) + add_mlir_public_c_api_library(ChloCAPI PARTIAL_SOURCES_INTENDED ChloAttributes.cpp diff --git a/stablehlo/integrations/c/CheckDialect.cpp b/stablehlo/integrations/c/CheckDialect.cpp new file mode 100644 index 00000000000..b30f3626a8b --- /dev/null +++ b/stablehlo/integrations/c/CheckDialect.cpp @@ -0,0 +1,19 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/integrations/c/CheckDialect.h" + +#include "mlir/CAPI/Registration.h" +#include "stablehlo/tests/CheckOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Check, check, + mlir::stablehlo::check::CheckDialect) diff --git a/stablehlo/integrations/c/CheckDialect.h b/stablehlo/integrations/c/CheckDialect.h new file mode 100644 index 00000000000..b50303e6112 --- /dev/null +++ b/stablehlo/integrations/c/CheckDialect.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H +#define STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H + +#include "mlir-c/RegisterEverything.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Check, check); + +#ifdef __cplusplus +} +#endif + +#endif // STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H diff --git a/stablehlo/integrations/python/CMakeLists.txt b/stablehlo/integrations/python/CMakeLists.txt index 5187e4affd6..fac75773a16 100644 --- a/stablehlo/integrations/python/CMakeLists.txt +++ b/stablehlo/integrations/python/CMakeLists.txt @@ -22,6 +22,18 @@ include(AddMLIRPython) # putting .td and .py files under . instead of mlir/python will break things, # even if the build rules below are adjusted accordingly. +declare_mlir_python_sources(CheckPythonSources) +declare_mlir_python_sources(CheckPythonSources.Dialects + ADD_TO_PARENT CheckPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT CheckPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/CheckOps.td + SOURCES dialects/check.py + DIALECT_NAME check) + declare_mlir_python_sources(ChloPythonSources) declare_mlir_python_sources(ChloPythonSources.Dialects ADD_TO_PARENT ChloPythonSources @@ -69,6 +81,18 @@ declare_mlir_dialect_python_bindings( # Extensions ################################################################################ +declare_mlir_python_sources(CheckPythonExtensions) +declare_mlir_python_extension(CheckPythonExtensions.Main + MODULE_NAME _check + ADD_TO_PARENT CheckPythonExtensions + SOURCES + CheckModule.cpp + EMBED_CAPI_LINK_LIBS + CheckCAPI + PRIVATE_LINK_LIBS + LLVMSupport +) + declare_mlir_python_sources(ChloPythonExtensions) declare_mlir_python_extension(ChloPythonExtensions.Main MODULE_NAME _chlo @@ -127,6 +151,8 @@ add_mlir_python_common_capi_library(StablehloUnifiedPythonCAPI DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.RegisterEverything + CheckPythonSources + CheckPythonExtensions ChloPythonSources ChloPythonExtensions StablehloPythonSources @@ -141,6 +167,8 @@ add_mlir_python_modules(StablehloUnifiedPythonModules DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.RegisterEverything + CheckPythonSources + CheckPythonExtensions ChloPythonSources ChloPythonExtensions StablehloPythonSources diff --git a/stablehlo/integrations/python/CheckModule.cpp b/stablehlo/integrations/python/CheckModule.cpp new file mode 100644 index 00000000000..c67c7911a8e --- /dev/null +++ b/stablehlo/integrations/python/CheckModule.cpp @@ -0,0 +1,36 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "stablehlo/integrations/c/CheckDialect.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_check, m) { + m.doc() = "check main python extension"; + + // + // Dialects. + // + + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle dialect = mlirGetDialectHandle__check__(); + mlirDialectHandleRegisterDialect(dialect, context); + if (load) { + mlirDialectHandleLoadDialect(dialect, context); + } + }, + py::arg("context"), py::arg("load") = true); +} diff --git a/stablehlo/integrations/python/mlir/dialects/CheckOps.td b/stablehlo/integrations/python/mlir/dialects/CheckOps.td new file mode 100644 index 00000000000..9b606d33765 --- /dev/null +++ b/stablehlo/integrations/python/mlir/dialects/CheckOps.td @@ -0,0 +1,21 @@ +/* Copyright 2024 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS +#define STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS + +include "stablehlo/tests/CheckOps.td" + +#endif diff --git a/stablehlo/integrations/python/mlir/dialects/check.py b/stablehlo/integrations/python/mlir/dialects/check.py new file mode 100644 index 00000000000..19f9fc37635 --- /dev/null +++ b/stablehlo/integrations/python/mlir/dialects/check.py @@ -0,0 +1,18 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=wildcard-import,relative-beyond-top-level,g-import-not-at-top +from ._check_ops_gen import * +from .._mlir_libs._check import * diff --git a/stablehlo/integrations/python/tests/CMakeLists.txt b/stablehlo/integrations/python/tests/CMakeLists.txt index f83b7cd2140..8661662ae96 100644 --- a/stablehlo/integrations/python/tests/CMakeLists.txt +++ b/stablehlo/integrations/python/tests/CMakeLists.txt @@ -24,6 +24,7 @@ add_custom_target(${test_name} add_dependencies(check-stablehlo-python ${test_name}) endfunction() +add_stablehlo_python_test(stablehlo-python-check check.py) add_stablehlo_python_test(stablehlo-python-chlo chlo.py) add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py) add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py) diff --git a/stablehlo/integrations/python/tests/check.py b/stablehlo/integrations/python/tests/check.py new file mode 100644 index 00000000000..70c03adbc20 --- /dev/null +++ b/stablehlo/integrations/python/tests/check.py @@ -0,0 +1,44 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CHECK Python APIs.""" + +# pylint: disable=wildcard-import,undefined-variable + +from mlir import ir +from mlir.dialects import check as check_dialect +from mlir.dialects import stablehlo as stablehlo_dialect + + +def run(f): + with ir.Context() as context: + check_dialect.register_dialect(context) + stablehlo_dialect.register_dialect(context) + f() + return f + +@run +def test_parse(): + asm = """ + module { + func.func @main() { + %cst = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + check.expect_eq_const %0, dense<[4.0, 6.0]> : tensor<2xf32> + return + } + } + """ + ir.Module.parse(asm)