Skip to content

Commit

Permalink
Python bindings for registering check dialect (#2445)
Browse files Browse the repository at this point in the history
Check dialect is an auxiliary dialect used in StableHLO repository for
validation of StableHLO program evaluation. Currently there is no
cleaner way to parse a module containing check dialect operations or to
create them. One way to go around this is to textually modify the check
dialect ops to normalize them to generic MLIR text and allow
unregistered dialect to pars that text. However, this PR provodes a
cleaner way to process check dialect.

This PR prepares for open sourcing some of the utilities, leveraged in
#2404, for auto-generating
testdata formatted test files.
  • Loading branch information
sdasgup3 authored Jul 23, 2024
1 parent edef22a commit eba821a
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 0 deletions.
8 changes: 8 additions & 0 deletions stablehlo/integrations/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

add_mlir_public_c_api_library(CheckCAPI
PARTIAL_SOURCES_INTENDED
CheckDialect.cpp

LINK_LIBS PUBLIC
CheckOps
)

add_mlir_public_c_api_library(ChloCAPI
PARTIAL_SOURCES_INTENDED
ChloAttributes.cpp
Expand Down
19 changes: 19 additions & 0 deletions stablehlo/integrations/c/CheckDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright 2024 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "stablehlo/integrations/c/CheckDialect.h"

#include "mlir/CAPI/Registration.h"
#include "stablehlo/tests/CheckOps.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Check, check,
mlir::stablehlo::check::CheckDialect)
28 changes: 28 additions & 0 deletions stablehlo/integrations/c/CheckDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2024 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H
#define STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H

#include "mlir-c/RegisterEverything.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Check, check);

#ifdef __cplusplus
}
#endif

#endif // STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H
28 changes: 28 additions & 0 deletions stablehlo/integrations/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ include(AddMLIRPython)
# putting .td and .py files under . instead of mlir/python will break things,
# even if the build rules below are adjusted accordingly.

declare_mlir_python_sources(CheckPythonSources)
declare_mlir_python_sources(CheckPythonSources.Dialects
ADD_TO_PARENT CheckPythonSources
)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT CheckPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/CheckOps.td
SOURCES dialects/check.py
DIALECT_NAME check)

declare_mlir_python_sources(ChloPythonSources)
declare_mlir_python_sources(ChloPythonSources.Dialects
ADD_TO_PARENT ChloPythonSources
Expand Down Expand Up @@ -69,6 +81,18 @@ declare_mlir_dialect_python_bindings(
# Extensions
################################################################################

declare_mlir_python_sources(CheckPythonExtensions)
declare_mlir_python_extension(CheckPythonExtensions.Main
MODULE_NAME _check
ADD_TO_PARENT CheckPythonExtensions
SOURCES
CheckModule.cpp
EMBED_CAPI_LINK_LIBS
CheckCAPI
PRIVATE_LINK_LIBS
LLVMSupport
)

declare_mlir_python_sources(ChloPythonExtensions)
declare_mlir_python_extension(ChloPythonExtensions.Main
MODULE_NAME _chlo
Expand Down Expand Up @@ -127,6 +151,8 @@ add_mlir_python_common_capi_library(StablehloUnifiedPythonCAPI
DECLARED_SOURCES
MLIRPythonSources
MLIRPythonExtension.RegisterEverything
CheckPythonSources
CheckPythonExtensions
ChloPythonSources
ChloPythonExtensions
StablehloPythonSources
Expand All @@ -141,6 +167,8 @@ add_mlir_python_modules(StablehloUnifiedPythonModules
DECLARED_SOURCES
MLIRPythonSources
MLIRPythonExtension.RegisterEverything
CheckPythonSources
CheckPythonExtensions
ChloPythonSources
ChloPythonExtensions
StablehloPythonSources
Expand Down
36 changes: 36 additions & 0 deletions stablehlo/integrations/python/CheckModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright 2024 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "stablehlo/integrations/c/CheckDialect.h"

namespace py = pybind11;

PYBIND11_MODULE(_check, m) {
m.doc() = "check main python extension";

//
// Dialects.
//

m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle dialect = mlirGetDialectHandle__check__();
mlirDialectHandleRegisterDialect(dialect, context);
if (load) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
}
21 changes: 21 additions & 0 deletions stablehlo/integrations/python/mlir/dialects/CheckOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/* Copyright 2024 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS
#define STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS

include "stablehlo/tests/CheckOps.td"

#endif
18 changes: 18 additions & 0 deletions stablehlo/integrations/python/mlir/dialects/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=wildcard-import,relative-beyond-top-level,g-import-not-at-top
from ._check_ops_gen import *
from .._mlir_libs._check import *
1 change: 1 addition & 0 deletions stablehlo/integrations/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_custom_target(${test_name}
add_dependencies(check-stablehlo-python ${test_name})
endfunction()

add_stablehlo_python_test(stablehlo-python-check check.py)
add_stablehlo_python_test(stablehlo-python-chlo chlo.py)
add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py)
add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py)
Expand Down
44 changes: 44 additions & 0 deletions stablehlo/integrations/python/tests/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CHECK Python APIs."""

# pylint: disable=wildcard-import,undefined-variable

from mlir import ir
from mlir.dialects import check as check_dialect
from mlir.dialects import stablehlo as stablehlo_dialect


def run(f):
with ir.Context() as context:
check_dialect.register_dialect(context)
stablehlo_dialect.register_dialect(context)
f()
return f

@run
def test_parse():
asm = """
module {
func.func @main() {
%cst = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
%cst_0 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
%0 = stablehlo.add %cst, %cst_0 : tensor<2xf32>
check.expect_eq_const %0, dense<[4.0, 6.0]> : tensor<2xf32>
return
}
}
"""
ir.Module.parse(asm)

0 comments on commit eba821a

Please sign in to comment.