diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..5807a6d9 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,21 @@ +indent = 2 +margin = 120 +always_for_in = true +whitespace_typedefs = false +whitespace_ops_in_indices = true +remove_extra_newlines = false +import_to_using = false +pipe_to_function_call = false +short_to_long_function_def = false +always_use_return = false +whitespace_in_kwargs = true +annotate_untyped_fields_with_any = false +format_docstrings = false +align_struct_field = true +align_conditional = true +align_assignment = true +align_pair_arrow = true +conditional_to_if = false +normalize_line_endings = "auto" +align_matrix = false +trailing_comma = true diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml new file mode 100644 index 00000000..0fe6c374 --- /dev/null +++ b/.github/workflows/CompatHelper.yml @@ -0,0 +1,26 @@ +name: CompatHelper + +on: + schedule: + - cron: '00 * * * *' + issues: + types: [opened, reopened] + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: [1] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/.gitignore b/.gitignore index dbd706c2..41a96567 100644 --- a/.gitignore +++ b/.gitignore @@ -13,11 +13,12 @@ deps/downloads/ deps/usr/ deps/src/ -unreleased/ +wip/ # Build artifacts for creating documentation generated by the Documenter package docs/build/ docs/site/ +docs/src/tutorials/* # File generated by Pkg, the package manager, based on a corresponding Project.toml # It records a fixed state of all packages used by the project. As such, it should not be diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 8152df1a..41210707 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -3,14 +3,13 @@ ## Patterns Module The `Patterns.jl` file contains type definitions for pattern matching building blocks -called `Pattern`s, shared between pattern matching backends. +called `AbstractPat`s, shared between pattern matching backends. This module provides the type hierarchy required to build patterns, the left hand side of rules. ## Rules -The `Rules` folder contains -- `rules.jl`: definitions for rule types used in various rewriting backends. +- `Rules.jl`: definitions for rule types used in various rewriting backends. - `matchers.jl`: Classical rewriting pattern matcher. # `Syntax.jl` @@ -19,8 +18,7 @@ Contains the frontend to Rules and Patterns (`@rule` macro and `Pattern` functio # EGraphs Module Contains code for the e-graphs rewriting backend. See [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) for an high level overview. -- `egraphs.jl`: Definition of `ENode`, `EClass` and `EGraph` types, EClass unioning, metadata access, defintion of EGraphs, adding, merging, rebuilding. -- `ematch.jl`: E-Graph Pattern matching virtual machine interpreter. +- `egraph.jl`: Definition of `ENode`, `EClass` and `EGraph` types, EClass unioning, metadata access, definition of EGraphs, adding, merging, rebuilding. - `analysis.jl`: Core algorithms for analyzing egraphs and extracting terms from egraphs. - `saturation.jl`: Core algorithm for equality saturation, rewriting on e-graphs, e-graphs search. Search phase of equality saturation. Uses multiple-dispatch on rules, Write phase of equality saturation. Application and instantiation of `Patterns` from matching/search results. Definition of `SaturationParams` type, parameters for equality saturation, Definition of equality saturation execution reports. Utility functions and macros to check equality of terms in egraphs. - `Schedulers.jl`: Module containing definition of Schedulers for equality saturation. diff --git a/CITATION.bib b/CITATION.bib index 44f1076d..2037bf46 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -10,3 +10,12 @@ @article{Cheli2021 title = {Metatheory.jl: Fast and Elegant Algebraic Computation in Julia with Extensible Equality Saturation}, journal = {Journal of Open Source Software} } + +@misc{cheli2021automated, + title={Automated Code Optimization with E-Graphs}, + author={Alessandro Cheli}, + year={2021}, + eprint={2112.14714}, + archivePrefix={arXiv}, + primaryClass={cs.PL} +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 78c30ab0..71a569aa 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,7 @@ A pre-requisite for using Metatheory.jl is to know at least a little about Julia ## Learning Metatheory.jl -Our [main documentaion](https://github.com/JuliaSymbolics/Metatheory.jl/) provides an overview and some examples of using Metatheory.jl. +Our [main documentation](https://github.com/JuliaSymbolics/Metatheory.jl/) provides an overview and some examples of using Metatheory.jl. The core package is hosted at [Metatheory.jl](https://github.com/JuliaSymbolics/Metatheory.jl/). ## Before filing an issue diff --git a/NEWS.md b/NEWS.md index a3694f50..35609225 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,11 @@ +# 2.0 +- No longer dispatch against types, but instead dispatch against objects. +- Faster E-Graph Analysis +- Better library macros +- Updated TermInterface to 0.3.3 +- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression` +- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses. +- Remove duplicates in E-Graph analyses data. ## 1.2 - Fixes when printing patterns - Can pass custom `similarterm` to `SaturationParams` by using `SaturationParams.simterm`. diff --git a/Project.toml b/Project.toml index 5d6f7ae6..22e8b857 100644 --- a/Project.toml +++ b/Project.toml @@ -1,35 +1,30 @@ name = "Metatheory" uuid = "e9d8d322-4543-424a-9be4-0cc815abe26c" authors = ["Alessandro Cheli - 0x0f0f0f "] -version = "1.3.5" +version = "2.0.0" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" -ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [compat] AutoHashEquals = "0.2.0" DataStructures = "0.18.9" DocStringExtensions = "0.8, 0.9" -Parameters = "0.12" Reexport = "0.2, 1" -TermInterface = "0.2.3" -ThreadsX = "0.1.7" +TermInterface = "0.3.3" TimerOutputs = "0.5" -julia = "1" +julia = "1.8" [extras] -Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Calculus", "Catlab", "SafeTestsets"] +test = ["Test", "Documenter", "SafeTestsets", "Literate"] diff --git a/README.md b/README.md index d4c42d52..21ce51b0 100644 --- a/README.md +++ b/README.md @@ -27,20 +27,31 @@ Intuitively, Metatheory.jl transforms Julia expressions in other Julia expressions and can achieve such at both compile and run time. This allows Metatheory.jl users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages. Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia. -## 1.0 is out! +## 2.0 is out! -The first stable version of Metatheory.jl is out! The goal of this release is to unify the symbolic manipulation ecosystem of Julia packages. Many features have been ported from SymbolicUtils.jl. Now, Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. SymbolicUtils.jl can now completely leverage on the generic stack of rewriting features provided by Metatheory.jl, highly decoupled from the symbolic term representation thanks to [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). Read more in [NEWS.md](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/NEWS.md). +Second stable version is out: -The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. It’s been a few months we’ve been talking about the integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. +- New e-graph pattern matching system, relies on functional programming and closures, and is much more extensible than 1.0's virtual machine. +- No longer dispatch against types, but instead dispatch against objects. +- Faster E-Graph Analysis +- Better library macros +- Updated TermInterface to 0.3.3 +- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression` +- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses. +- Remove duplicates in E-Graph analyses data. + + +Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. ## Recommended Readings - Selected Publications - The [Metatheory.jl manual](https://juliasymbolics.github.io/Metatheory.jl/stable/) -- The [Metatheory.jl introductory paper](https://joss.theoj.org/papers/10.21105/joss.03078#) gives a brief high level overview on the library and its functionalities. +- **OUT OF DATE**: The [Metatheory.jl introductory paper](https://joss.theoj.org/papers/10.21105/joss.03078#) gives a brief high level overview on the library and its functionalities. - The Julia Manual [metaprogramming section](https://docs.julialang.org/en/v1/manual/metaprogramming/) is fundamental to understand what homoiconic expression manipulation is and how it happens in Julia. - An [introductory blog post on SIGPLAN](https://blog.sigplan.org/2021/04/06/equality-saturation-with-egg/) about `egg` and e-graphs rewriting. - [egg: Fast and Extensible Equality Saturation](https://dl.acm.org/doi/pdf/10.1145/3434304) contains the definition of *E-Graphs* on which Metatheory.jl's equality saturation rewriting backend is based. This is a strongly recommended reading. - [High-performance symbolic-numerics via multiple dispatch](https://arxiv.org/abs/2105.03949): a paper about how we used Metatheory.jl to optimize code generation in [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) +- [Automated Code Optimization with E-Graphs](https://arxiv.org/abs/2112.14714). Alessandro Cheli's Thesis on Metatheory.jl ## Contributing diff --git a/STYLEGUIDE.md b/STYLEGUIDE.md new file mode 100644 index 00000000..bafe491e --- /dev/null +++ b/STYLEGUIDE.md @@ -0,0 +1,206 @@ +# Style guide +### IDE + +It is recommended to use VSCode when programming in Julia. Its Julia extension +exclusively has shortcuts for evaluating Julia code, can display results inline +and has some support for working with system images, among others, which +typically make it better suited than other editors (unless you spend some effort +customizing another editor to your workflow). For autocompletions, linting and +navigation, it uses the Language Server Protocol (LSP) which you can reuse in +other text editors that support it. + +#### Recommended VSCode extensions + +- Julia: the official Julia extension. +- GitLens: lets you see inline which +commit recently affected the selected line. It is excellent to know who was +working on a piece of code, such that you can easily ask for explanations or +help in case of trouble. + +### Reduce latency with system images + +We can put package dependencies into a system image (kind of like a snapshot of +a Julia session, abbreviated as sysimage) to speed up their loading. + +### Logging + +To turn on debug logging for a given module, set the environment variable +`JULIA_DEBUG` to the name of the module. For example, to enable debugging from +module Foo, just do + +```bash +JULIA_DEBUG=Foo julia --project test/runtests.jl +``` + +Or from REPL +```julia +ENV["JULIA_DEBUG"] = Foo +``` + +## Collaboration + +Once you have developed a piece of code and want to share it with the team, you +can create a merge request. If the changes are not final and will require +further work before considering a merge, then please mark the merge request as a +draft. + +Merge requests marked as drafts may not be reviewed. If you seek a review from +someone, you should explicitly state it in the merge request and tag the person +in question. + +When you are confident in your changes and want to consider a merge, you can +mark the merge request as ready. It will then be reviewed, and when review +comments are addressed, an automatic merge will be issued. + +## Style + +Code style is different from [[#Formatting]]. While the latter can be easily +assisted with by automatic tools, the former cannot. + +### Comments + +Comments and error messages should form proper sentences unless they are titles. + +Get something done later, but only if someone looks at this code again. For +larger things make an issue. + +``` +# TODO: ... +``` + +Sometimes a piece of code is written in a certain way to work around an existing +issue in a dependency. If this code should be cleaned up after that issue is +fixed then the following line with link to issue should be added. + +``` +# ISSUE: https:// +``` + +Probabilistic tests can sometimes fail in CI. If that is the case they should be marked with [`@test_skip`](https://docs.julialang.org/en/v1/stdlib/Test/#Test.@test_skip), which indicates that the test may intermittently fail (it will be reported in the test summary as `Broken`). This is equivalent to `@test (...) skip=true` but requires at least Julia v1.7. A comment before the relevant line is useful so that they can be debugged and made more reliable. + +``` +# FLAKY +@test_skip some_probabilistic_test() +``` + +For packages that do not have to be used as libraries, it is sometimes +convenient to extend external methods on external types - this is referred to as +"type piracy" in Julia style guide. Generally it should be avoided, but for the +cases where it is very convenient it should be tagged. + +``` +# PIRACY +``` + +### Code + +Generally follow the [Julia Style Guide](https://docs.julialang.org/en/v1/manual/style-guide/) with some caveats: +- [Avoid elaborate container types](https://docs.julialang.org/en/v1/manual/style-guide/#Avoid-elaborate-container-types): if explicitly typing a complex container helps with safety then you should do it. But, if a container type is not concrete (abstract type or unparametrized parametric type), nesting it inside another container probably won't do what you intend (Julia types are invariant). For example: + ```julia + # Don't + const NestedContainer = AbstractDict{Symbol,Vector{Array}} + Dict{Float64, AbstractDict{Symbol,NestedContainer}} + + # Do + Dict{Float64, <:AbstractDict} + Dict{Float64, Vector{Int}} + + const Bytes = Vector{UInt8} + struct BytesCollections + collections::Vector{Bytes} + end + AbstractDict{Symbol, BytesCollections} + ``` +- [Avoid type piracy](https://docs.julialang.org/en/v1/manual/style-guide/#Avoid-type-piracy): this is more important for libraries, but in a self-contained project this may be a nice feature. +- Prefer `Foo[]` and `Pair{Symbol,Foo}[]` over `Vector{Foo}()` and `Vector{Pair{Symbol,Foo}}()` for better readability. +- Avoid explicit use of the `return` keyword if it is pointless, e.g. when a function has a unique point of return. + +Otherwise follow this: + +```julia +"Module definition first." +module ExampleModule + +# `using` of external modules. +using Distributions: Normal +# `using` of symbols from internal modules, always explicitly name them. +using ..SomeNeighbourModule: nicefn + +# `import` of symbols, usually to be extended, with the exception of those from `Base` (see below). +import StatsBase: mean + +# --------------------- +# # First main section. + +# Above begins a section of code which is readable with [Literate.jl](https://fredrikekre.github.io/Literate.jl/v2/fileformat/). + +"Function docs as usual. Write proper sentences." +f(x) = x^2 + +# ----------------------- +# ## Title of subsection. + +"Some code in subsection." +g(x) = log(x) + +# ---------------------- +# # Second main section. + +struct A + id::Int64 +end + +"Keep constructors close to datastructure definitions." +A() = A(rand(1:10)) + +""" +Do not use explicit type parameters if not needed. + +Use multi-line strings for longer docstrings. +""" +h(x::Vector{<:Real})::String = "Real vector." +h(x::Vector) = nothing +""" +Use output type annotations when the return type is not clear from context. +This facilitates readability by not requiring the reader to look for the lastly executed statement(s). +""" +function h(x)::Float64 + compute_something(x) +end +h(::Nothing) = 2 + +"Here the type parameter is used twice - it was needed." +i(x::Vector{T})::T where T<:Real = sum(x) + +# Extend symbols defined in `Base` prepending the module's name. +Base.convert(::Type{Expr}, ::Type{Union{}}) = :(Union{}) + +end +``` + +Concerning unit testing, it is a good practice to use [SafeTestsets.j](https://github.com/YingboMa/SafeTestsets.jl), since it makes every single test script an independently runnable file. In turn, this implies that imports need to be manually added in each file. Moreover, we prefer to use explicit imports since that helps to keep tests targeted at what they should be testing. Hence, we suggest the following guidelines in test scripts (which should be included using `@safetestset`): + +```julia +# load modules (eventually, also package itself) +using Test, MacroTools +# load specific names from external dependencies +using MeasureTheory: Dirac +# load specific names from MyPackage submodules (sorted alphabetically) +using MyPackage.SomeModule: Foo, bar, Baz, ⊕ + + +@testset "Descriptive name" begin + # ... +end +``` + +## Formatting + +Use [JuliaFormatter.jl](https://github.com/domluna/JuliaFormatter.jl) to ensure that all code is formatted consistently. There should be a CI job that automatically checks for formatting. However, everyone is encouraged to use the formatter locally before pushing, see usage details below. + +Notable settings: +- Use two spaces for indentation: by default the Julia guide recommends four, but that tends to push code too much to the right. + +### VS Code +If you are using VS code and the Julia Extension, you can also trigger the formatter via [various shortcuts](https://www.julia-vscode.org/docs/stable/userguide/formatter/). + diff --git a/benchmark/tune.json b/benchmark/tune.json new file mode 100644 index 00000000..d3b1dfcc --- /dev/null +++ b/benchmark/tune.json @@ -0,0 +1 @@ +[{"Julia":"1.6.5","BenchmarkTools":"1.0.0"},[["BenchmarkGroup",{"data":{"egraph":["BenchmarkGroup",{"data":{"creation":["BenchmarkGroup",{"data":{"expr":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"empty":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":195,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}],"full_examples":["BenchmarkGroup",{"data":{"logic":["Parameters",{"gctrial":true,"time_tolerance":0.05,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":[]}]},"tags":[]}]},"tags":[]}]]] \ No newline at end of file diff --git a/benchmarks/egg_logic.jl b/benchmarks/egg_logic.jl deleted file mode 100644 index c165c107..00000000 --- a/benchmarks/egg_logic.jl +++ /dev/null @@ -1,87 +0,0 @@ -include("eggify.jl") -using Metatheory.Library -using Metatheory.EGraphs.Schedulers - -or_alg = @theory begin - ((p ∨ q) ∨ r) == (p ∨ (q ∨ r)) - (p ∨ q) == (q ∨ p) - (p ∨ p) => p - (p ∨ true) => true - (p ∨ false) => p -end - -and_alg = @theory begin - ((p ∧ q) ∧ r) == (p ∧ (q ∧ r)) - (p ∧ q) == (q ∧ p) - (p ∧ p) => p - (p ∧ true) => p - (p ∧ false) => false -end - -comb = @theory begin - # DeMorgan - ¬(p ∨ q) == (¬p ∧ ¬q) - ¬(p ∧ q) == (¬p ∨ ¬q) - # distrib - (p ∧ (q ∨ r)) == ((p ∧ q) ∨ (p ∧ r)) - (p ∨ (q ∧ r)) == ((p ∨ q) ∧ (p ∨ r)) - # absorb - (p ∧ (p ∨ q)) => p - (p ∨ (p ∧ q)) => p - # complement - (p ∧ (¬p ∨ q)) => p ∧ q - (p ∨ (¬p ∧ q)) => p ∨ q -end - -negt = @theory begin - (p ∧ ¬p) => false - (p ∨ ¬(p)) => true - ¬(¬p) == p -end - -impl = @theory begin - (p == ¬p) => false - (p == p) => true - (p == q) => (¬p ∨ q) ∧ (¬q ∨ p) - (p => q) => (¬p ∨ q) -end - -fold = @theory begin - (true == false) => false - (false == true) => false - (true == true) => true - (false == false) => true - (true ∨ false) => true - (false ∨ true) => true - (true ∨ true) => true - (false ∨ false) => false - (true ∧ true) => true - (false ∧ true) => false - (true ∧ false) => false - (false ∧ false) => false - ¬(true) => false - ¬(false) => true -end - -theory = or_alg ∪ and_alg ∪ comb ∪ negt ∪ impl ∪ fold - - -query = :(¬(((¬p ∨ q) ∧ (¬r ∨ s)) ∧ (p ∨ r)) ∨ (q ∨ s)) - -########################################### - -params = SaturationParams(timeout=22, eclasslimit=3051, - scheduler=ScoredScheduler)#, schedulerparams=(1000,5, Schedulers.exprsize)) - -for i ∈ 1:2 - G = EGraph( query ) - report = saturate!(G, theory, params) - ex = extract!(G, astsize) - println( "Best found: $ex") - println(report) -end - - -open("src/main.rs", "w") do f - write(f, rust_code(theory, query, params)) -end diff --git a/benchmarks/egg_maths.jl b/benchmarks/egg_maths.jl deleted file mode 100644 index b8a97270..00000000 --- a/benchmarks/egg_maths.jl +++ /dev/null @@ -1,88 +0,0 @@ -include("eggify.jl") -using Metatheory.Library -using Metatheory.EGraphs.Schedulers - -mult_t = commutative_monoid(:(*), 1) -plus_t = commutative_monoid(:(+), 0) - -minus_t = @theory begin - a - a => 0 - a + (-b) => a - b -end - -mulplus_t = @theory begin - 0 * a => 0 - a * 0 => 0 - a * (b + c) == ((a*b) + (a*c)) - a + (b * a) => ((b+1)*a) -end - -pow_t = @theory begin - (y^n) * y => y^(n+1) - x^n * x^m == x^(n+m) - (x * y)^z == x^z * y^z - (x^p)^q == x^(p*q) - x^0 => 1 - 0^x => 0 - 1^x => 1 - x^1 => x - inv(x) == x^(-1) -end - -function customlt(x,y) - if typeof(x) == Expr && Expr == typeof(y) - false - elseif typeof(x) == typeof(y) - isless(x,y) - elseif x isa Symbol && y isa Number - false - else - true - end -end - -canonical_t = @theory begin - # restore n-arity - (x + (+)(ys...)) => +(x,ys...) - ((+)(xs...) + y) => +(xs..., y) - (x * (*)(ys...)) => *(x,ys...) - ((*)(xs...) * y) => *(xs..., y) - - (*)(xs...) |> Expr(:call, :*, sort!(xs; lt=customlt)...) - (+)(xs...) |> Expr(:call, :+, sort!(xs; lt=customlt)...) -end - - -cas = mult_t ∪ plus_t ∪ minus_t ∪ mulplus_t ∪ pow_t -theory = cas - -query = cleanast(:(a + b + (0*c) + d)) - - -function simplify(ex) - g = EGraph(ex) - params = SaturationParams( - scheduler=BackoffScheduler, - timeout=20, - schedulerparams=(1000,5) # fuel and bantime - ) - report = saturate!(g, cas, params) - println(report) - res = extract!(g, astsize) - res = rewrite(res, canonical_t; clean=false, m=@__MODULE__) # this just orders symbols and restores n-ary plus and mult - res -end - -########################################### - -params = SaturationParams(timeout=20, schedulerparams=(1000,5)) - -for i ∈ 1:2 - ex = simplify( :( a + b + (0*c) + d) ) - println( "Best found: $ex") -end - - -open("src/main.rs", "w") do f - write(f, rust_code(theory, query)) -end diff --git a/benchmarks/eggify.jl b/benchmarks/eggify.jl deleted file mode 100644 index 7ea5368b..00000000 --- a/benchmarks/eggify.jl +++ /dev/null @@ -1,54 +0,0 @@ -using Metatheory -using Metatheory.EGraphs - -to_sexpr_pattern(p::PatLiteral) = "$(p.val)" -to_sexpr_pattern(p::PatVar) = "?$(p.name)" -function to_sexpr_pattern(p::PatTerm) - e1 = join([p.head ; to_sexpr_pattern.(p.args)], ' ') - "($e1)" -end - -to_sexpr(e::Symbol) = e -to_sexpr(e::Int64) = e -to_sexpr(e::Expr) = "($(join(to_sexpr.(e.args),' ')))" - -function eggify(rules) - egg_rules = [] - for rule in rules - l = to_sexpr_pattern(rule.left) - r = to_sexpr_pattern(rule.right) - if rule isa SymbolicRule - push!(egg_rules,"\tvec![rw!( \"$(rule.left) => $(rule.right)\" ; \"$l\" => \"$r\" )]") - elseif rule isa EqualityRule - push!(egg_rules,"\trw!( \"$(rule.left) == $(rule.right)\" ; \"$l\" <=> \"$r\" )") - else - println("Unsupported Rewrite Mode") - @assert false - end - - end - return join(egg_rules, ",\n") -end - -function rust_code(theory, query, params=SaturationParams()) - """ - use egg::{*, rewrite as rw}; - //use std::time::Duration; - fn main() { - let rules : &[Rewrite] = &vec![ - $(eggify(theory)) - ].concat(); - - let start = "$(to_sexpr(cleanast(query)))".parse().unwrap(); - let runner = Runner::default().with_expr(&start) - // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html - .with_iter_limit($(params.timeout)) - .with_node_limit($(params.enodelimit)) - .run(rules); - runner.print_report(); - let mut extractor = Extractor::new(&runner.egraph, AstSize); - let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); - println!("best cost: {}, best expr {}", best_cost, best_expr); - } - """ -end \ No newline at end of file diff --git a/benchmarks/src/main.rs b/benchmarks/src/main.rs deleted file mode 100644 index 3dea5016..00000000 --- a/benchmarks/src/main.rs +++ /dev/null @@ -1,56 +0,0 @@ -use egg::{*, rewrite as rw}; -//use std::time::Duration; -fn main() { - let rules : &[Rewrite] = &vec![ - vec![rw!( "p ∨ q ∨ r => p ∨ q ∨ r" ; "(∨ (∨ ?p ?q) ?r)" => "(∨ ?p (∨ ?q ?r))" )], - vec![rw!( "p ∨ q => q ∨ p" ; "(∨ ?p ?q)" => "(∨ ?q ?p)" )], - vec![rw!( "p ∨ p => p" ; "(∨ ?p ?p)" => "?p" )], - vec![rw!( "p ∨ true => true" ; "(∨ ?p true)" => "true" )], - vec![rw!( "p ∨ false => p" ; "(∨ ?p false)" => "?p" )], - vec![rw!( "p ∧ q ∧ r => p ∧ q ∧ r" ; "(∧ (∧ ?p ?q) ?r)" => "(∧ ?p (∧ ?q ?r))" )], - vec![rw!( "p ∧ q => q ∧ p" ; "(∧ ?p ?q)" => "(∧ ?q ?p)" )], - vec![rw!( "p ∧ p => p" ; "(∧ ?p ?p)" => "?p" )], - vec![rw!( "p ∧ true => p" ; "(∧ ?p true)" => "?p" )], - vec![rw!( "p ∧ false => false" ; "(∧ ?p false)" => "false" )], - vec![rw!( "¬p ∨ q => ¬p ∧ ¬q" ; "(¬ (∨ ?p ?q))" => "(∧ (¬ ?p) (¬ ?q))" )], - vec![rw!( "¬p ∧ q => ¬p ∨ ¬q" ; "(¬ (∧ ?p ?q))" => "(∨ (¬ ?p) (¬ ?q))" )], - vec![rw!( "p ∧ q ∨ r => p ∧ q ∨ p ∧ r" ; "(∧ ?p (∨ ?q ?r))" => "(∨ (∧ ?p ?q) (∧ ?p ?r))" )], - vec![rw!( "p ∨ q ∧ r => p ∨ q ∧ p ∨ r" ; "(∨ ?p (∧ ?q ?r))" => "(∧ (∨ ?p ?q) (∨ ?p ?r))" )], - vec![rw!( "p ∧ p ∨ q => p" ; "(∧ ?p (∨ ?p ?q))" => "?p" )], - vec![rw!( "p ∨ p ∧ q => p" ; "(∨ ?p (∧ ?p ?q))" => "?p" )], - vec![rw!( "p ∧ ¬p ∨ q => p ∧ q" ; "(∧ ?p (∨ (¬ ?p) ?q))" => "(∧ ?p ?q)" )], - vec![rw!( "p ∨ ¬p ∧ q => p ∨ q" ; "(∨ ?p (∧ (¬ ?p) ?q))" => "(∨ ?p ?q)" )], - vec![rw!( "p ∧ ¬p => false" ; "(∧ ?p (¬ ?p))" => "false" )], - vec![rw!( "p ∨ ¬p => true" ; "(∨ ?p (¬ ?p))" => "true" )], - vec![rw!( "¬¬p => p" ; "(¬ (¬ ?p))" => "?p" )], - vec![rw!( "p == ¬p => false" ; "(== ?p (¬ ?p))" => "false" )], - vec![rw!( "p == p => true" ; "(== ?p ?p)" => "true" )], - vec![rw!( "p == q => ¬p ∨ q ∧ ¬q ∨ p" ; "(== ?p ?q)" => "(∧ (∨ (¬ ?p) ?q) (∨ (¬ ?q) ?p))" )], - vec![rw!( "p => q => ¬p ∨ q" ; "(=> ?p ?q)" => "(∨ (¬ ?p) ?q)" )], - vec![rw!( "true == false => false" ; "(== true false)" => "false" )], - vec![rw!( "false == true => false" ; "(== false true)" => "false" )], - vec![rw!( "true == true => true" ; "(== true true)" => "true" )], - vec![rw!( "false == false => true" ; "(== false false)" => "true" )], - vec![rw!( "true ∨ false => true" ; "(∨ true false)" => "true" )], - vec![rw!( "false ∨ true => true" ; "(∨ false true)" => "true" )], - vec![rw!( "true ∨ true => true" ; "(∨ true true)" => "true" )], - vec![rw!( "false ∨ false => false" ; "(∨ false false)" => "false" )], - vec![rw!( "true ∧ true => true" ; "(∧ true true)" => "true" )], - vec![rw!( "false ∧ true => false" ; "(∧ false true)" => "false" )], - vec![rw!( "true ∧ false => false" ; "(∧ true false)" => "false" )], - vec![rw!( "false ∧ false => false" ; "(∧ false false)" => "false" )], - vec![rw!( "¬true => false" ; "(¬ true)" => "false" )], - vec![rw!( "¬false => true" ; "(¬ false)" => "true" )] - ].concat(); - - let start = "(∨ (¬ (∧ (∧ (∨ (¬ p) q) (∨ (¬ r) s)) (∨ p r))) (∨ q s))".parse().unwrap(); - let runner = Runner::default().with_expr(&start) - // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html - .with_iter_limit(22) - .with_node_limit(15000) - .run(rules); - runner.print_report(); - let mut extractor = Extractor::new(&runner.egraph, AstSize); - let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); - println!("best cost: {}, best expr {}", best_cost, best_expr); -} diff --git a/benchmarks/times.txt b/benchmarks/times.txt deleted file mode 100644 index f168d61f..00000000 --- a/benchmarks/times.txt +++ /dev/null @@ -1,55 +0,0 @@ -times - -vector sub, dict matches - -69.203038 seconds (218.14 M allocations: 9.312 GiB, 3.87% gc time, 21.12% compilation time) -elapsed time (ns): 69203038466 -gc time (ns): 2679282281 -bytes allocated: 9998487280 -pool allocs: 217958967 -non-pool GC allocs:74887 -malloc() calls: 37390 -realloc() calls: 67862 -free() calls: 35082 -GC pauses: 210 -full collections: 8 - -both dict -68.880892 seconds (224.14 M allocations: 9.464 GiB, 3.93% gc time, 20.62% compilation time) -elapsed time (ns): 68880892497 -gc time (ns): 2707393124 -bytes allocated: 10161446360 -pool allocs: 223961498 -non-pool GC allocs:76606 -malloc() calls: 37540 -realloc() calls: 67854 -free() calls: 35082 -GC pauses: 216 -full collections: 8 - -vector sub vector matches -64.250042 seconds (191.13 M allocations: 8.489 GiB, 3.70% gc time, 21.76% compilation time) -elapsed time (ns): 64250042213 -gc time (ns): 2376234107 -bytes allocated: 9114714927 -pool allocs: 190960525 -non-pool GC allocs:67038 -malloc() calls: 36493 -realloc() calls: 67821 -free() calls: 35082 -GC pauses: 189 -full collections: 8 - - -before optimizations -89.061844 seconds (181.57 M allocations: 8.602 GiB, 3.54% gc time, 7.70% compilation time) -elapsed time (ns): 89061844200 -gc time (ns): 3155740685 -bytes allocated: 9236789175 -pool allocs: 181391374 -non-pool GC allocs:72310 -malloc() calls: 35915 -realloc() calls: 67814 -free() calls: 35081 -GC pauses: 198 -full collections: 6 diff --git a/docs/Project.toml b/docs/Project.toml index 8bffb4c9..4866d7b6 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [compat] Documenter = "~0.26" diff --git a/docs/make.jl b/docs/make.jl index f0c84e57..75e2deb6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,26 +1,36 @@ using Documenter using Metatheory +using Literate using Metatheory.EGraphs using Metatheory.Library +TUTORIALSDIR = joinpath(dirname(pathof(Metatheory)), "../test/tutorials/") +OUTDIR = abspath(joinpath(@__DIR__, "src", "tutorials")) -for m ∈ [Metatheory] - for i ∈ propertynames(m) - xxx = getproperty(m, i) - println(xxx) - end - end +for f in readdir(TUTORIALSDIR) + if endswith(f, ".jl") + input = abspath(joinpath(TUTORIALSDIR, f)) + name = basename(input) + Literate.markdown(input, OUTDIR) + elseif f != "README.md" + @info "Copying $f" + cp(joinpath(TUTORIALSDIR, input), joinpath(OUTDIR, f); force=true) + end +end + +tutorials = [joinpath("tutorials", f[1:end-3]) * ".md" for f in readdir(TUTORIALSDIR) if endswith(f, ".jl")] makedocs( - modules = [Metatheory, Metatheory.EGraphs], - sitename = "Metatheory.jl", - pages = [ - "index.md" - "rewrite.md" - "egraphs.md" - "interface.md" - "api.md" - ]) + modules = [Metatheory, Metatheory.EGraphs], + sitename = "Metatheory.jl", + pages = [ + "index.md" + "rewrite.md" + "egraphs.md" + "api.md" + "Tutorials" => tutorials + ], +) -deploydocs(repo = "github.com/JuliaSymbolics/Metatheory.jl.git") +#deploydocs(repo = "github.com/JuliaSymbolics/Metatheory.jl.git") diff --git a/docs/src/egraphs.md b/docs/src/egraphs.md index fbeec07c..b9a458cd 100644 --- a/docs/src/egraphs.md +++ b/docs/src/egraphs.md @@ -10,8 +10,18 @@ the [egg](https://egraphs-good.github.io/) library for Rust. You can read more about the design of the EGraph data structure and equality saturation algorithm in the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304). -See [Alessandro Cheli](https://0x0f0f0f.github.io/) and [Philip Zucker](https://www.philipzucker.com/)'s -[talk at JuliaCon 2021](https://www.youtube.com/watch?v=tdXfsTliRJk) for an overview of the concepts introduced in this chapter of the manual (**NOTE**: Syntax in the talk slideshow is out of date). +Let's load Metatheory and the rule library +```julia +using Metatheory +using Metatheory.Library +``` + +```@meta +DocTestSetup = quote + using Metatheory + using Metatheory.Library +end +``` ## What can I do with EGraphs in Metatheory.jl? @@ -52,10 +62,17 @@ are **EGraph Analyses**. They allow you to annotate expressions and equivalence The `Metatheory.Library` module contains utility functions and macros for creating rules and theories from commonly used algebraic structures and properties, to be used with the e-graph backend. -```julia -using Metatheory.Library - +```jldoctest comm_monoid = @commutative_monoid (*) 1 + +# output + +4-element Vector{RewriteRule}: + ~a * ~b --> ~b * ~a + (~a * ~b) * ~c --> ~a * (~b * ~c) + ~a * (~b * ~c) --> (~a * ~b) * ~c + 1 * ~a --> ~a + ``` @@ -66,54 +83,35 @@ commutativity and distributivity**, rules that are otherwise known of causing loops and require extensive user reasoning in classical rewriting. -```julia +```jldoctest t = @theory a b c begin a * b == b * a a * 1 == a a * (b * c) == (a * b) * c end + +# output + +3-element Vector{EqualityRule}: + ~a * ~b == ~b * ~a + ~a * 1 == ~a + ~a * (~b * ~c) == (~a * ~b) * ~c + ``` ## Equality Saturation -We can programmatically build and saturate an EGraph. -The function `saturate!` takes an `EGraph` and a theory, and executes -equality saturation. Returns a report -of the equality saturation process. -`saturate!` is configurable, customizable parameters include -a `timeout` on the number of iterations, a `eclasslimit` on the number of e-classes in the EGraph, a `stopwhen` functions that stops saturation when it evaluates to true. -```julia -g = EGraph(:((a * b) * (1 * (b + c)))); -report = saturate!(G, t); -# access the saturated EGraph -report.egraph - -# show some fancy stats -report -``` +We can programmatically build and saturate an EGraph. The function `saturate!` +takes an `EGraph` and a theory, and executes equality saturation. Returns a +report of the equality saturation process. `saturate!` is configurable, +customizable parameters include a `timeout` on the number of iterations, a +`eclasslimit` on the number of e-classes in the EGraph, a `stopwhen` functions +that stops saturation when it evaluates to true. -``` -Equality Saturation Report -================= - Stop Reason: saturated - Iterations: 1 - EGraph Size: 9 eclasses, 51 nodes - ─────────────────────────────────────────────────────────────────────────────────────── - Time Allocations - ────────────────────── ─────────────────────── - Tot / % measured: 1.18s / 0.45% 955KiB / 68.1% - - Section ncalls time %tot avg alloc %tot avg - ─────────────────────────────────────────────────────────────────────────────────────── - Apply 1 4.63ms 87.5% 4.63ms 512KiB 78.7% 512KiB - Search 1 656μs 12.4% 656μs 139KiB 21.3% 139KiB - a * (b * c) == (a * b) * c 1 242μs 4.58% 242μs 79.2KiB 12.2% 79.2KiB - a * b == b * a 1 153μs 2.89% 153μs 34.2KiB 5.26% 34.2KiB - a * 1 == a 1 115μs 2.17% 115μs 14.4KiB 2.21% 14.4KiB - appending matches 3 4.06μs 0.08% 1.35μs 544B 0.08% 181B - Rebuild 1 3.75μs 0.07% 3.75μs 0.00B 0.00% 0.00B - ─────────────────────────────────────────────────────────────────────────────────────── +```@example +g = EGraph(:((a * b) * (1 * (b + c)))); +report = saturate!(g, t); ``` With the EGraph equality saturation backend, Metatheory.jl can prove **simple** @@ -226,14 +224,14 @@ which *e-node* will be extracted from an *e-class*. It must return a positive, non-complex number value and, must accept 3 arguments. 1) The current [ENode](@ref) `n` that is being inspected. 2) The current [EGraph](@ref) `g`. -3) The current analysis type `an`. +3) The current analysis name `an::Symbol`. From those 3 parameters, one can access all the data needed to compute the cost of an e-node recursively. * One can use [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) methods to access the operation and child arguments of an e-node: `operation(n)`, `arity(n)` and `arguments(n)` * Since e-node children always point to e-classes in the same e-graph, one can retrieve the [EClass](@ref) object for each child of the currently visited enode with `g[id] for id in arguments(n)` -* One can inspect the analysis data for a given eclass and a given analysis type `an`, by using [hasdata](@ref) and [getdata](@ref). +* One can inspect the analysis data for a given eclass and a given analysis name `an`, by using [hasdata](@ref) and [getdata](@ref). * Extraction analyses always associate a tuple of 2 values to a single e-class: which e-node is the one that minimizes the cost and its cost. More details can be found in the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) in the *Analyses* section. @@ -243,7 +241,7 @@ Here's an example: # This is a cost function that behaves like `astsize` but increments the cost # of nodes containing the `^` operation. This results in a tendency to avoid # extraction of expressions containing '^'. -function cost_function(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis}) +function cost_function(n::ENodeTerm, g::EGraph) cost = 1 + arity(n) operation(n) == :^ && (cost += 2) @@ -251,34 +249,36 @@ function cost_function(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis}) for id in arguments(n) eclass = g[id] # if the child e-class has not yet been analyzed, return +Inf - !hasdata(eclass, an) && (cost += Inf; break) - cost += last(getdata(eclass, an)) + !hasdata(eclass, cost_function) && (cost += Inf; break) + cost += last(getdata(eclass, cost_function)) end return cost end # All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1 -cost_function(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = 1 +cost_function(n::ENodeLiteral, g::EGraph) = 1 ``` ## EGraph Analyses An *EGraph Analysis* is an efficient and automated way of analyzing all the possible terms contained in an e-graph. Metatheory.jl provides a toolkit to ease and -automate the process of EGraph Analysis. An *EGraph Analysis* defines a domain -of values and associates a value from the domain to each [EClass](@ref) in the graph. -Theoretically, the domain should form a [join semilattice](https://en.wikipedia.org/wiki/Semilattice). -Rewrites can cooperate with e-class analyses by depending on analysis facts and adding -equivalences that in turn establish additional facts. - -In Metatheory.jl, EGraph Analyses are identified by a *type* that is subtype of `AbstractAnalysis`. -An [`EGraph`](@ref) can only contain one analysis per type. -The following functions define an interface for analyses based on multiple dispatch -on `AbstractAnalysis` types: -* [islazy](@ref) should return true if the analysis should NOT be computed on-the-fly during egraphs operation, only when required. -* [make](@ref) should take an ENode and return a value from the analysis domain. -* [join](@ref) should return the semilattice join of two values in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?*) -* [modify!](@ref) Can be optionally implemented. Can be used modify an EClass on-the-fly given its analysis value. +automate the process of EGraph Analysis. + +An *EGraph Analysis* defines a domain of values and associates a value from the domain to each [EClass](@ref) in the graph. Theoretically, the domain should form a [join semilattice](https://en.wikipedia.org/wiki/Semilattice). Rewrites can cooperate with e-class analyses by depending on analysis facts and adding equivalences that in turn establish additional facts. + +In Metatheory.jl, **EGraph Analyses are uniquely identified** by either + +* An unique name of type `Symbol`. +* A function object `f`, used for cost function analysis. This will use built-in definitions of `make` and `join`. + +If you are specifying a custom analysis by its `Symbol` name, +the following functions define an interface for analyses based on multiple dispatch +on `Val{analysis_name::Symbol}`: +* [islazy(an)](@ref) should return true if the analysis name `an` should NOT be computed on-the-fly during egraphs operation, but only when inspected. +* [make(an, egraph, n)](@ref) should take an ENode `n` and return a value from the analysis domain. +* [join(an, x,y)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?*). If `an` is a `Function`, it is treated as a cost function analysis, it is automatically defined to be the minimum analysis value between `x` and `y`. Typically, the domain value of cost functions are real numbers, but if you really do want to have your own cost type, make sure that `Base.isless` is defined. +* [modify!(an, egraph, eclassid)](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclassid]` on-the-fly during an e-graph saturation iteration, given its analysis value. ### Defining a custom analysis @@ -292,14 +292,11 @@ the actual numeric result of the expressions in the EGraph, but we only care to the symbolic expressions that will result in an even or an odd number. Defining an EGraph Analysis is similar to the process [Mathematical Induction](https://en.wikipedia.org/wiki/Mathematical_induction). -To define a custom EGraph Analysis, one should start by defining a type that -subtypes `AbstractAnalysis` that will be used to identify this specific analysis and -to dispatch against the required methods. +To define a custom EGraph Analysis, one should start by defining a name of type `Symbol` that will be used to identify this specific analysis and to dispatch against the required methods. ```julia using Metatheory using Metatheory.EGraphs -abstract type OddEvenAnalysis <: AbstractAnalysis end ``` The next step, the base case of induction, is to define a method for @@ -308,7 +305,7 @@ associate an analysis value only to the *literals* contained in the EGraph. To d take advantage of multiple dispatch against `ENodeLiteral`. ```julia -function EGraphs.make(an::Type{OddEvenAnalysis}, g::EGraph, n::ENodeLiteral) +function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeLiteral) if n.value isa Integer return iseven(n.value) ? :even : :odd else @@ -336,7 +333,7 @@ From the definition of an [ENode](@ref), we know that children of ENodes are alw to EClasses in the EGraph. ```julia -function EGraphs.make(an::Type{OddEvenAnalysis}, g::EGraph, n::ENodeTerm) +function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeTerm) # Let's consider only binary function call terms. if exprhead(n) == :call && arity(n) == 2 op = operation(n) @@ -347,19 +344,25 @@ function EGraphs.make(an::Type{OddEvenAnalysis}, g::EGraph, n::ENodeTerm) # Get the corresponding OddEvenAnalysis value of the children # defaulting to nothing - ldata = getdata(l, an, nothing) - rdata = getdata(r, an, nothing) + ldata = getdata(l, :OddEvenAnalysis, nothing) + rdata = getdata(r, :OddEvenAnalysis, nothing) if ldata isa Symbol && rdata isa Symbol if op == :* - return (ldata == :even || rdata == :even) ? :even : :odd + if ldata == rdata + ldata + elseif (ldata == :even || rdata == :even) + :even + else + nothing + end elseif op == :+ - return (ldata == rdata) ? :even : :odd + (ldata == rdata) ? :even : :odd end elseif isnothing(ldata) && rdata isa Symbol && op == :* - return rdata + rdata elseif ldata isa Symbol && isnothing(rdata) && op == :* - return ldata + ldata end end @@ -375,7 +378,7 @@ how to extract a single value out of the many analyses values contained in an EG We do this by defining a method for [join](@ref). ```julia -function EGraphs.join(an::Type{OddEvenAnalysis}, a, b) +function EGraphs.join(::Val{:OddEvenAnalysis}, a, b) if a == b return a else @@ -406,7 +409,8 @@ function custom_analysis(expr) return getdata(g[g.root], OddEvenAnalysis) end -custom_analysis(:(3*a)) # :odd +custom_analysis(:(2*a)) # :even +custom_analysis(:(3*3)) # :odd custom_analysis(:(3*(2+a)*2)) # :even custom_analysis(:(3y * (2x*y))) # :even ``` diff --git a/docs/src/index.md b/docs/src/index.md index a388afe4..8ddf9009 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,4 +1,4 @@ -# Metatheory.jl 1.0 +# Metatheory.jl 2.0 ```@raw html

@@ -29,9 +29,21 @@ Intuitively, Metatheory.jl transforms Julia expressions in other Julia expressions and can achieve such at both compile and run time. This allows Metatheory.jl users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages. Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia. -## 1.0 is out! +## 2.0 is out! -The first stable version of Metatheory.jl is out! The goal of this release is to unify the symbolic manipulation ecosystem of Julia packages. Many features have been ported from SymbolicUtils.jl. Now, Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. SymbolicUtils.jl can now completely leverage on the generic stack of rewriting features provided by Metatheory.jl, highly decoupled from the symbolic term representation thanks to [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). Read more in [NEWS.md](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/NEWS.md). +Second stable version is out: + +- New e-graph pattern matching system, relies on functional programming and closures, and is much more extensible than 1.0's virtual machine. +- No longer dispatch against types, but instead dispatch against objects. +- Faster E-Graph Analysis +- Better library macros +- Updated TermInterface to 0.3.3 +- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression` +- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses. +- Remove duplicates in E-Graph analyses data. + + +Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. ## Recommended Readings - Selected Publications diff --git a/docs/src/interface.md b/docs/src/interface.md deleted file mode 100644 index 4bb43943..00000000 --- a/docs/src/interface.md +++ /dev/null @@ -1,12 +0,0 @@ -# Interfacing with Metatheory.jl - -This section is for Julia package developers who may want to use the rule rewriting systems on their own expression types. - -## Defining the interface - -Metatheory.jl matchers can match any Julia object that implements an interface to traverse it as a tree. The interface in question, is defined in the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) package. Its purpose is to provide a shared interface between various symbolic programming Julia packages. - -In particular, you should define methods from TermInterface.jl for an expression tree type `T` with symbol types `S` to work -with SymbolicUtils.jl - -You can read the documentation of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) on the [Github repository](https://github.com/JuliaSymbolics/TermInterface.jl). diff --git a/docs/src/rewrite.md b/docs/src/rewrite.md index 9289578d..a8923b60 100644 --- a/docs/src/rewrite.md +++ b/docs/src/rewrite.md @@ -86,7 +86,7 @@ Matcher pattern may contain slot variables with attached predicates, written as - A function that takes a matched expression and returns a boolean value. Such a slot will be considered a match only if `p` returns true. - A Julia type. Will be considered a match if and only if the value matching against `x` has a type that is a subtype of `p` (`typeof(x) <: p`) -Similarly `~x::g...` is a way of attaching a predicate `g` to a segment variable. In the case of segment variables `g` gets a vector of 0 or more expressions and must return a boolean value. If the same slot or segment variable appears twice in the matcher pattern, then at most one of the occurance should have a predicate. +Similarly `~x::g...` is a way of attaching a predicate `g` to a segment variable. In the case of segment variables `g` gets a vector of 0 or more expressions and must return a boolean value. If the same slot or segment variable appears twice in the matcher pattern, then at most one of the occurrence should have a predicate. For example, @@ -174,7 +174,7 @@ t = comm_monoid ∪ comm_group ∪ distrib ## Composing rewriters Rules may be *chained together* into more -sophisticated rewirters to avoid manual application of the rules. A rewriter is +sophisticated rewriters to avoid manual application of the rules. A rewriter is any callable object which takes an expression and returns an expression or `nothing`. If `nothing` is returned that means there was no changes applicable to the input expression. The Rules we created above are rewriters. @@ -189,7 +189,7 @@ rewriters. - `RestartedChain(itr)` like `Chain(itr)` but restarts from the first rewriter once on the first successful application of one of the chained rewriters. - `IfElse(cond, rw1, rw2)` runs the `cond` function on the input, applies `rw1` if cond - returns true, `rw2` if it retuns false + returns true, `rw2` if it returns false - `If(cond, rw)` is the same as `IfElse(cond, rw, Empty())` - `Prewalk(rw; threaded=false, thread_cutoff=100)` returns a rewriter which does a pre-order (*from top to bottom and from left to right*) traversal of a given expression and applies diff --git a/format.jl b/format.jl new file mode 100644 index 00000000..29a6e7db --- /dev/null +++ b/format.jl @@ -0,0 +1,6 @@ +using JuliaFormatter + +format(file; kwargs...) = JuliaFormatter.format(joinpath(@__DIR__, file); kwargs...) + +format("src"; verbose = true) +format("test"; verbose = true) diff --git a/benchmarks/Cargo.toml b/scratch/Cargo.toml similarity index 100% rename from benchmarks/Cargo.toml rename to scratch/Cargo.toml diff --git a/benchmarks/Project.toml b/scratch/Project.toml similarity index 100% rename from benchmarks/Project.toml rename to scratch/Project.toml diff --git a/test/logic/benchmark_logic.jl b/scratch/benchmark_logic.jl similarity index 57% rename from test/logic/benchmark_logic.jl rename to scratch/benchmark_logic.jl index 1846da78..5746b608 100644 --- a/test/logic/benchmark_logic.jl +++ b/scratch/benchmark_logic.jl @@ -1,7 +1,6 @@ include("prop_logic_theory.jl") include("prover.jl") -ex = rewrite(:(((p => q) ∧ (r => s) ∧ (p ∨ r)) => (q ∨ s)), impl) +ex = rewrite(:(((p => q) && (r => s) && (p || r)) => (q || s)), impl) prove(t, ex, 1, 25) @profview prove(t, ex, 2, 7) - diff --git a/scratch/egg_logic.jl b/scratch/egg_logic.jl new file mode 100644 index 00000000..c26e98fb --- /dev/null +++ b/scratch/egg_logic.jl @@ -0,0 +1,86 @@ +include("eggify.jl") +using Metatheory.Library +using Metatheory.EGraphs.Schedulers + +or_alg = @theory begin + ((p || q) || r) == (p || (q || r)) + (p || q) == (q || p) + (p || p) => p + (p || true) => true + (p || false) => p +end + +and_alg = @theory begin + ((p && q) && r) == (p && (q && r)) + (p && q) == (q && p) + (p && p) => p + (p && true) => p + (p && false) => false +end + +comb = @theory begin + # DeMorgan + !(p || q) == (!p && !q) + !(p && q) == (!p || !q) + # distrib + (p && (q || r)) == ((p && q) || (p && r)) + (p || (q && r)) == ((p || q) && (p || r)) + # absorb + (p && (p || q)) => p + (p || (p && q)) => p + # complement + (p && (!p || q)) => p && q + (p || (!p && q)) => p || q +end + +negt = @theory begin + (p && !p) => false + (p || !(p)) => true + !(!p) == p +end + +impl = @theory begin + (p == !p) => false + (p == p) => true + (p == q) => (!p || q) && (!q || p) + (p => q) => (!p || q) +end + +fold = @theory begin + (true == false) => false + (false == true) => false + (true == true) => true + (false == false) => true + (true || false) => true + (false || true) => true + (true || true) => true + (false || false) => false + (true && true) => true + (false && true) => false + (true && false) => false + (false && false) => false + !(true) => false + !(false) => true +end + +theory = or_alg ∪ and_alg ∪ comb ∪ negt ∪ impl ∪ fold + + +query = :(!(((!p || q) && (!r || s)) && (p || r)) || (q || s)) + +########################################### + +params = SaturationParams(timeout = 22, eclasslimit = 3051, scheduler = ScoredScheduler)#, schedulerparams=(1000,5, Schedulers.exprsize)) + +for i in 1:2 + G = EGraph(query) + report = saturate!(G, theory, params) + ex = extract!(G, astsize) + println("Best found: $ex") + println(report) +end + + +open("src/main.rs", "w") do f + write(f, rust_code(theory, query, params)) +end diff --git a/scratch/egg_maths.jl b/scratch/egg_maths.jl new file mode 100644 index 00000000..0ee1c72c --- /dev/null +++ b/scratch/egg_maths.jl @@ -0,0 +1,88 @@ +include("eggify.jl") +using Metatheory.Library +using Metatheory.EGraphs.Schedulers + +mult_t = commutative_monoid(:(*), 1) +plus_t = commutative_monoid(:(+), 0) + +minus_t = @theory begin + a - a => 0 + a + (-b) => a - b +end + +mulplus_t = @theory begin + 0 * a => 0 + a * 0 => 0 + a * (b + c) == ((a * b) + (a * c)) + a + (b * a) => ((b + 1) * a) +end + +pow_t = @theory begin + (y^n) * y => y^(n + 1) + x^n * x^m == x^(n + m) + (x * y)^z == x^z * y^z + (x^p)^q == x^(p * q) + x^0 => 1 + 0^x => 0 + 1^x => 1 + x^1 => x + inv(x) == x^(-1) +end + +function customlt(x, y) + if typeof(x) == Expr && Expr == typeof(y) + false + elseif typeof(x) == typeof(y) + isless(x, y) + elseif x isa Symbol && y isa Number + false + else + true + end +end + +canonical_t = @theory begin + # restore n-arity + (x + (+)(ys...)) => +(x, ys...) + ((+)(xs...) + y) => +(xs..., y) + (x * (*)(ys...)) => *(x, ys...) + ((*)(xs...) * y) => *(xs..., y) + + (*)(xs...) |> Expr(:call, :*, sort!(xs; lt = customlt)...) + (+)(xs...) |> Expr(:call, :+, sort!(xs; lt = customlt)...) +end + + +cas = mult_t ∪ plus_t ∪ minus_t ∪ mulplus_t ∪ pow_t +theory = cas + +query = cleanast(:(a + b + (0 * c) + d)) + + +function simplify(ex) + g = EGraph(ex) + params = SaturationParams( + scheduler = BackoffScheduler, + timeout = 20, + schedulerparams = (1000, 5), # fuel and bantime + ) + report = saturate!(g, cas, params) + println(report) + res = extract!(g, astsize) + res = rewrite(res, canonical_t; clean = false, m = @__MODULE__) # this just orders symbols and restores n-ary plus and mult + res +end + +########################################### + +params = SaturationParams(timeout = 20, schedulerparams = (1000, 5)) + +for i in 1:2 + ex = simplify(:(a + b + (0 * c) + d)) + println("Best found: $ex") +end + + +open("src/main.rs", "w") do f + write(f, rust_code(theory, query)) +end diff --git a/scratch/eggify.jl b/scratch/eggify.jl new file mode 100644 index 00000000..04e82b2c --- /dev/null +++ b/scratch/eggify.jl @@ -0,0 +1,54 @@ +using Metatheory +using Metatheory.EGraphs + +to_sexpr_pattern(p::PatLiteral) = "$(p.val)" +to_sexpr_pattern(p::PatVar) = "?$(p.name)" +function to_sexpr_pattern(p::PatTerm) + e1 = join([p.head; to_sexpr_pattern.(p.args)], ' ') + "($e1)" +end + +to_sexpr(e::Symbol) = e +to_sexpr(e::Int64) = e +to_sexpr(e::Expr) = "($(join(to_sexpr.(e.args),' ')))" + +function eggify(rules) + egg_rules = [] + for rule in rules + l = to_sexpr_pattern(rule.left) + r = to_sexpr_pattern(rule.right) + if rule isa SymbolicRule + push!(egg_rules, "\tvec![rw!( \"$(rule.left) => $(rule.right)\" ; \"$l\" => \"$r\" )]") + elseif rule isa EqualityRule + push!(egg_rules, "\trw!( \"$(rule.left) == $(rule.right)\" ; \"$l\" <=> \"$r\" )") + else + println("Unsupported Rewrite Mode") + @assert false + end + + end + return join(egg_rules, ",\n") +end + +function rust_code(theory, query, params = SaturationParams()) + """ + use egg::{*, rewrite as rw}; + //use std::time::Duration; + fn main() { + let rules : &[Rewrite] = &vec![ + $(eggify(theory)) + ].concat(); + + let start = "$(to_sexpr(cleanast(query)))".parse().unwrap(); + let runner = Runner::default().with_expr(&start) + // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html + .with_iter_limit($(params.timeout)) + .with_node_limit($(params.enodelimit)) + .run(rules); + runner.print_report(); + let mut extractor = Extractor::new(&runner.egraph, AstSize); + let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); + println!("best cost: {}, best expr {}", best_cost, best_expr); + } + """ +end diff --git a/benchmarks/figures/fib.pdf b/scratch/figures/fib.pdf similarity index 100% rename from benchmarks/figures/fib.pdf rename to scratch/figures/fib.pdf diff --git a/benchmarks/gen_egg_instructions.md b/scratch/gen_egg_instructions.md similarity index 100% rename from benchmarks/gen_egg_instructions.md rename to scratch/gen_egg_instructions.md diff --git a/scratch/src/main.rs b/scratch/src/main.rs new file mode 100644 index 00000000..a885fae3 --- /dev/null +++ b/scratch/src/main.rs @@ -0,0 +1,56 @@ +use egg::{*, rewrite as rw}; +//use std::time::Duration; +fn main() { + let rules : &[Rewrite] = &vec![ + vec![rw!( "p || q || r => p || q || r" ; "(|| (|| ?p ?q) ?r)" => "(|| ?p (|| ?q ?r))" )], + vec![rw!( "p || q => q || p" ; "(|| ?p ?q)" => "(|| ?q ?p)" )], + vec![rw!( "p || p => p" ; "(|| ?p ?p)" => "?p" )], + vec![rw!( "p || true => true" ; "(|| ?p true)" => "true" )], + vec![rw!( "p || false => p" ; "(|| ?p false)" => "?p" )], + vec![rw!( "p && q && r => p && q && r" ; "(&& (&& ?p ?q) ?r)" => "(&& ?p (&& ?q ?r))" )], + vec![rw!( "p && q => q && p" ; "(&& ?p ?q)" => "(&& ?q ?p)" )], + vec![rw!( "p && p => p" ; "(&& ?p ?p)" => "?p" )], + vec![rw!( "p && true => p" ; "(&& ?p true)" => "?p" )], + vec![rw!( "p && false => false" ; "(&& ?p false)" => "false" )], + vec![rw!( "!p || q => !p && !q" ; "(! (|| ?p ?q))" => "(&& (! ?p) (! ?q))" )], + vec![rw!( "!p && q => !p || !q" ; "(! (&& ?p ?q))" => "(|| (! ?p) (! ?q))" )], + vec![rw!( "p && q || r => p && q || p && r" ; "(&& ?p (|| ?q ?r))" => "(|| (&& ?p ?q) (&& ?p ?r))" )], + vec![rw!( "p || q && r => p || q && p || r" ; "(|| ?p (&& ?q ?r))" => "(&& (|| ?p ?q) (|| ?p ?r))" )], + vec![rw!( "p && p || q => p" ; "(&& ?p (|| ?p ?q))" => "?p" )], + vec![rw!( "p || p && q => p" ; "(|| ?p (&& ?p ?q))" => "?p" )], + vec![rw!( "p && !p || q => p && q" ; "(&& ?p (|| (! ?p) ?q))" => "(&& ?p ?q)" )], + vec![rw!( "p || !p && q => p || q" ; "(|| ?p (&& (! ?p) ?q))" => "(|| ?p ?q)" )], + vec![rw!( "p && !p => false" ; "(&& ?p (! ?p))" => "false" )], + vec![rw!( "p || !p => true" ; "(|| ?p (! ?p))" => "true" )], + vec![rw!( "!!p => p" ; "(! (! ?p))" => "?p" )], + vec![rw!( "p == !p => false" ; "(== ?p (! ?p))" => "false" )], + vec![rw!( "p == p => true" ; "(== ?p ?p)" => "true" )], + vec![rw!( "p == q => !p || q && !q || p" ; "(== ?p ?q)" => "(&& (|| (! ?p) ?q) (|| (! ?q) ?p))" )], + vec![rw!( "p => q => !p || q" ; "(=> ?p ?q)" => "(|| (! ?p) ?q)" )], + vec![rw!( "true == false => false" ; "(== true false)" => "false" )], + vec![rw!( "false == true => false" ; "(== false true)" => "false" )], + vec![rw!( "true == true => true" ; "(== true true)" => "true" )], + vec![rw!( "false == false => true" ; "(== false false)" => "true" )], + vec![rw!( "true || false => true" ; "(|| true false)" => "true" )], + vec![rw!( "false || true => true" ; "(|| false true)" => "true" )], + vec![rw!( "true || true => true" ; "(|| true true)" => "true" )], + vec![rw!( "false || false => false" ; "(|| false false)" => "false" )], + vec![rw!( "true && true => true" ; "(&& true true)" => "true" )], + vec![rw!( "false && true => false" ; "(&& false true)" => "false" )], + vec![rw!( "true && false => false" ; "(&& true false)" => "false" )], + vec![rw!( "false && false => false" ; "(&& false false)" => "false" )], + vec![rw!( "!true => false" ; "(! true)" => "false" )], + vec![rw!( "!false => true" ; "(! false)" => "true" )] + ].concat(); + + let start = "(|| (! (&& (&& (|| (! p) q) (|| (! r) s)) (|| p r))) (|| q s))".parse().unwrap(); + let runner = Runner::default().with_expr(&start) + // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html + .with_iter_limit(22) + .with_node_limit(15000) + .run(rules); + runner.print_report(); + let mut extractor = Extractor::new(&runner.egraph, AstSize); + let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); + println!("best cost: {}, best expr {}", best_cost, best_expr); +} diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index 76acf3e8..d418c3db 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -5,14 +5,11 @@ include("../docstrings.jl") using DataStructures using TermInterface using TimerOutputs -using Parameters -using Metatheory: alwaystrue, cleanast, binarize, @log +using Metatheory: + alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings using Metatheory.Patterns using Metatheory.Rules using Metatheory.EMatchCompiler -using Dates - -import ThreadsX include("intdisjointmap.jl") export IntDisjointSet @@ -43,11 +40,8 @@ export analyze! export extract! export astsize export astsize_inv -export AbstractAnalysis -export MetadataAnalysis export getcost! -include("ematch.jl") export Sub include("Schedulers.jl") diff --git a/src/EGraphs/Schedulers.jl b/src/EGraphs/Schedulers.jl index 68abf94b..6ca3d36b 100644 --- a/src/EGraphs/Schedulers.jl +++ b/src/EGraphs/Schedulers.jl @@ -2,7 +2,7 @@ module Schedulers include("../docstrings.jl") -using Metatheory.Rules +using Metatheory.Rules using Metatheory.EGraphs using Metatheory.Patterns using DocStringExtensions @@ -28,7 +28,7 @@ Should return `true` if the e-graph can be said to be saturated cansaturate(s::AbstractScheduler) ``` """ -function cansaturate end +function cansaturate end """ Should return `false` if the rule `r` should be skipped @@ -63,7 +63,7 @@ struct SimpleScheduler <: AbstractScheduler end cansaturate(s::SimpleScheduler) = true cansearch(s::SimpleScheduler, r::AbstractRule) = true function SimpleScheduler(G::EGraph, theory::Vector{<:AbstractRule}) - SimpleScheduler() + SimpleScheduler() end inform!(s::SimpleScheduler, r, n_matches) = true setiter!(s::SimpleScheduler, iteration) = nothing @@ -74,10 +74,10 @@ setiter!(s::SimpleScheduler, iteration) = nothing # =========================================================================== mutable struct BackoffSchedulerEntry - match_limit::Int - ban_length::Int - times_banned::Int - banned_until::Int + match_limit::Int + ban_length::Int + times_banned::Int + banned_until::Int end """ @@ -91,29 +91,29 @@ This seems effective at preventing explosive rules like associativity from taking an unfair amount of resources. """ mutable struct BackoffScheduler <: AbstractScheduler - data::IdDict{AbstractRule, BackoffSchedulerEntry} - G::EGraph - theory::Vector{<:AbstractRule} - curr_iter::Int + data::IdDict{AbstractRule,BackoffSchedulerEntry} + G::EGraph + theory::Vector{<:AbstractRule} + curr_iter::Int end cansearch(s::BackoffScheduler, r::AbstractRule)::Bool = s.curr_iter > s.data[r].banned_until function BackoffScheduler(g::EGraph, theory::Vector{<:AbstractRule}) - # BackoffScheduler(g, theory, 128, 4) - BackoffScheduler(g, theory, 1000, 5) + # BackoffScheduler(g, theory, 128, 4) + BackoffScheduler(g, theory, 1000, 5) end function BackoffScheduler(G::EGraph, theory::Vector{<:AbstractRule}, match_limit::Int, ban_length::Int) - gsize = length(G.uf) - data = IdDict{AbstractRule, BackoffSchedulerEntry}() + gsize = length(G.uf) + data = IdDict{AbstractRule,BackoffSchedulerEntry}() - for rule ∈ theory - data[rule] = BackoffSchedulerEntry(match_limit, ban_length, 0, 0) - end + for rule in theory + data[rule] = BackoffSchedulerEntry(match_limit, ban_length, 0, 0) + end - return BackoffScheduler(data, G, theory, 1) + return BackoffScheduler(data, G, theory, 1) end # can saturate if there's no banned rule @@ -121,22 +121,19 @@ cansaturate(s::BackoffScheduler)::Bool = all(kv -> s.curr_iter > last(kv).banned function inform!(s::BackoffScheduler, rule::AbstractRule, n_matches) - # println(s.data[rule]) - - rd = s.data[rule] - treshold = rd.match_limit << rd.times_banned - if n_matches > treshold - ban_length = rd.ban_length << rd.times_banned - rd.times_banned += 1 - rd.banned_until = s.curr_iter + ban_length - # @info "banning rule $rule until $(rd.banned_until)!" - return false - end - return true + rd = s.data[rule] + treshold = rd.match_limit << rd.times_banned + if n_matches > treshold + ban_length = rd.ban_length << rd.times_banned + rd.times_banned += 1 + rd.banned_until = s.curr_iter + ban_length + return false + end + return true end function setiter!(s::BackoffScheduler, curr_iter) - s.curr_iter = curr_iter + s.curr_iter = curr_iter end # =========================================================================== @@ -145,11 +142,11 @@ end mutable struct ScoredSchedulerEntry - match_limit::Int - ban_length::Int - times_banned::Int - banned_until::Int - weight::Int + match_limit::Int + ban_length::Int + times_banned::Int + banned_until::Int + weight::Int end """ @@ -163,67 +160,71 @@ This seems effective at preventing explosive rules like associativity from taking an unfair amount of resources. """ mutable struct ScoredScheduler <: AbstractScheduler - data::IdDict{AbstractRule, ScoredSchedulerEntry} - G::EGraph - theory::Vector{<:AbstractRule} - curr_iter::Int + data::IdDict{AbstractRule,ScoredSchedulerEntry} + G::EGraph + theory::Vector{<:AbstractRule} + curr_iter::Int end cansearch(s::ScoredScheduler, r::AbstractRule)::Bool = s.curr_iter > s.data[r].banned_until exprsize(a) = 1 -function exprsize(e::PatTerm) - c = 1 + length(e.args) - for a ∈ e.args - c += exprsize(a) - end - return c +function exprsize(e::PatTerm) + c = 1 + length(e.args) + for a in e.args + c += exprsize(a) + end + return c end function exprsize(e::Expr) - start = Meta.isexpr(e, :call) ? 2 : 1 + start = Meta.isexpr(e, :call) ? 2 : 1 - c = 1 + length(e.args[start:end]) - for a ∈ e.args[start:end] - c += exprsize(a) - end + c = 1 + length(e.args[start:end]) + for a in e.args[start:end] + c += exprsize(a) + end - return c + return c end function ScoredScheduler(g::EGraph, theory::Vector{<:AbstractRule}) - # BackoffScheduler(g, theory, 128, 4) - ScoredScheduler(g, theory, 1000, 5, exprsize) -end - -function ScoredScheduler(G::EGraph, theory::Vector{<:AbstractRule}, match_limit::Int, ban_length::Int, complexity::Function) - gsize = length(G.uf) - data = IdDict{AbstractRule, ScoredSchedulerEntry}() - - for rule ∈ theory - if rule isa DynamicRule - w = 2 - data[rule] = ScoredSchedulerEntry(match_limit, ban_length, 0, 0, w) - continue - end - (l, r) = rule.left, rule.right - - cl = complexity(l) - cr = complexity(r) - # println("$rule HAS SCORE $((cl, cr))") - if cl > cr - w = 1 # reduces complexity - elseif cr > cl - w = 3 # augments complexity - else - w = 2 # complexity is equal - end - # println(w) - data[rule] = ScoredSchedulerEntry(match_limit, ban_length, 0, 0, w) + # BackoffScheduler(g, theory, 128, 4) + ScoredScheduler(g, theory, 1000, 5, exprsize) +end + +function ScoredScheduler( + G::EGraph, + theory::Vector{<:AbstractRule}, + match_limit::Int, + ban_length::Int, + complexity::Function, +) + gsize = length(G.uf) + data = IdDict{AbstractRule,ScoredSchedulerEntry}() + + for rule in theory + if rule isa DynamicRule + w = 2 + data[rule] = ScoredSchedulerEntry(match_limit, ban_length, 0, 0, w) + continue + end + (l, r) = rule.left, rule.right + + cl = complexity(l) + cr = complexity(r) + if cl > cr + w = 1 # reduces complexity + elseif cr > cl + w = 3 # augments complexity + else + w = 2 # complexity is equal end + data[rule] = ScoredSchedulerEntry(match_limit, ban_length, 0, 0, w) + end - return ScoredScheduler(data, G, theory, 1) + return ScoredScheduler(data, G, theory, 1) end # can saturate if there's no banned rule @@ -231,22 +232,20 @@ cansaturate(s::ScoredScheduler)::Bool = all(kv -> s.curr_iter > last(kv).banned_ function inform!(s::ScoredScheduler, rule::AbstractRule, n_matches) - # println(s.data[rule]) - - rd = s.data[rule] - treshold = rd.match_limit * (rd.weight^rd.times_banned) - if n_matches > treshold - ban_length = rd.ban_length * (rd.weight^rd.times_banned) - rd.times_banned += 1 - rd.banned_until = s.curr_iter + ban_length - # @info "banning rule $rule until $(rd.banned_until)!" - return false - end - return true + rd = s.data[rule] + treshold = rd.match_limit * (rd.weight^rd.times_banned) + if n_matches > treshold + ban_length = rd.ban_length * (rd.weight^rd.times_banned) + rd.times_banned += 1 + rd.banned_until = s.curr_iter + ban_length + # @info "banning rule $rule until $(rd.banned_until)!" + return false + end + return true end function setiter!(s::ScoredScheduler, curr_iter) - s.curr_iter = curr_iter + s.curr_iter = curr_iter end diff --git a/src/EGraphs/analysis.jl b/src/EGraphs/analysis.jl index abb86f3d..2510cd62 100644 --- a/src/EGraphs/analysis.jl +++ b/src/EGraphs/analysis.jl @@ -1,5 +1,9 @@ +analysis_reference(x::Symbol) = Val(x) +analysis_reference(x::Function) = x +analysis_reference(x) = error("$x is not a valid analysis reference") + """ - islazy(an::Type{<:AbstractAnalysis}) + islazy(::Val{analysis_name}) Should return `true` if the EGraph Analysis `an` is lazy and false otherwise. A *lazy* EGraph Analysis is computed @@ -7,213 +11,199 @@ only when [analyze!](@ref) is called. *Non-lazy* analyses are instead computed on-the-fly every time ENodes are added to the EGraph or EClasses are merged. """ -islazy(an::Type{<:AbstractAnalysis})::Bool = false +islazy(::Val{analysis_name}) where {analysis_name} = false +islazy(analysis_name) = islazy(analysis_reference(analysis_name)) """ - modify!(an::Type{<:AbstractAnalysis}, g, id) + modify!(::Val{analysis_name}, g, id) The `modify!` function for EGraph Analysis can optionally modify the eclass `g[id]` after it has been analyzed, typically by adding an ENode. It should be **idempotent** if no other changes occur to the EClass. (See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)). """ -modify!(analysis::Type{<:AbstractAnalysis}, g, id) = nothing +modify!(::Val{analysis_name}, g, id) where {analysis_name} = nothing +modify!(an, g, id) = modify!(analysis_reference(an), g, id) """ - join(an::Type{<:AbstractAnalysis}, a, b) + join(::Val{analysis_name}, a, b) Joins two analyses values into a single one, used by [analyze!](@ref) when two eclasses are being merged or the analysis is being constructed. """ -join(analysis::Type{<:AbstractAnalysis}, a, b) = - error("Analysis does not implement join") +join(analysis::Val{analysis_name}, a, b) where {analysis_name} = + error("Analysis $analysis_name does not implement join") +join(an, a, b) = join(analysis_reference(an), a, b) """ - make(an::Type{<:AbstractAnalysis}, g, n) + make(::Val{analysis_name}, g, n) Given an ENode `n`, `make` should return the corresponding analysis value. """ -make(analysis::Type{<:AbstractAnalysis}, g, n) = - error("Analysis does not implement make") - +make(::Val{analysis_name}, g, n) where {analysis_name} = error("Analysis $analysis_name does not implement make") +make(an, g, n) = make(analysis_reference(an), g, n) -# TODO default analysis for metadata here -abstract type MetadataAnalysis <: AbstractAnalysis end - -analyze!(g::EGraph, an::Type{<:AbstractAnalysis}, id::EClassId) = analyze!(g, an, reachable(g, id)) -analyze!(g::EGraph, an::Type{<:AbstractAnalysis}) = analyze!(g, an, collect(keys(g.classes))) +analyze!(g::EGraph, analysis_ref, id::EClassId) = analyze!(g, analysis_ref, reachable(g, id)) +analyze!(g::EGraph, analysis_ref) = analyze!(g, analysis_ref, collect(keys(g.classes))) """ - analyze!(egraph, analysis, [ECLASS_IDS]) + analyze!(egraph, analysis_name, [ECLASS_IDS]) -Given an [EGraph](@ref) and an `analysis` of type `<:AbstractAnalysis`, +Given an [EGraph](@ref) and an `analysis` identified by name `analysis_name`, do an automated bottom up trasversal of the EGraph, associating a value from the -domain of `analysis` to each ENode in the egraph by the [make](@ref) function. +domain of analysis to each ENode in the egraph by the [make](@ref) function. Then, for each [EClass](@ref), compute the [join](@ref) of the children ENodes analyses values. After `analyze!` is called, an analysis value will be associated to each EClass in the EGraph. -One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref). -Note that an [EGraph](@ref) can only contain one analysis of type `an`. -""" -function analyze!(g::EGraph, an::Type{<:AbstractAnalysis}, ids::Vector{EClassId}) - push!(g.analyses, an) - ids = sort(ids) - # @assert isempty(g.dirty) - - did_something = true - while did_something - did_something = false - - for id ∈ ids - eclass = g[id] - id = eclass.id - pass = mapreduce(x -> make(an, g, x), (x, y) -> join(an, x, y), eclass) - # pass = make_pass(G, analysis, find(G,id)) - - # if pass !== missing - if !isequal(pass, getdata(eclass, an, missing)) - setdata!(eclass, an, pass) - did_something = true - push!(g.dirty, id) - end - end +One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref). +""" +function analyze!(g::EGraph, analysis_ref, ids::Vector{EClassId}) + addanalysis!(g, analysis_ref) + ids = sort(ids) + # @assert isempty(g.dirty) + + did_something = true + while did_something + did_something = false + + for id in ids + eclass = g[id] + id = eclass.id + pass = mapreduce(x -> make(analysis_ref, g, x), (x, y) -> join(analysis_ref, x, y), eclass) + + if !isequal(pass, getdata(eclass, analysis_ref, missing)) + setdata!(eclass, analysis_ref, pass) + did_something = true + push!(g.dirty, id) + end end + end - for id ∈ ids - eclass = g[id] - id = eclass.id - if !hasdata(eclass, an) - error("failed to compute analysis for eclass ", id) - end + for id in ids + eclass = g[id] + id = eclass.id + if !hasdata(eclass, analysis_ref) + error("failed to compute analysis for eclass ", id) end + end - return true + return true end """ A basic cost function, where the computed cost is the size (number of children) of the current expression. """ -function astsize(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis}) - cost = 1 + arity(n) - for id ∈ arguments(n) - eclass = g[id] - !hasdata(eclass, an) && (cost += Inf; break) - cost += last(getdata(eclass, an)) - end - return cost +function astsize(n::ENodeTerm, g::EGraph) + cost = 1 + arity(n) + for id in arguments(n) + eclass = g[id] + !hasdata(eclass, astsize) && (cost += Inf; break) + cost += last(getdata(eclass, astsize)) + end + return cost end -astsize(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = 1 +astsize(n::ENodeLiteral, g::EGraph) = 1 """ A basic cost function, where the computed cost is the size (number of children) of the current expression, times -1. Strives to get the largest expression """ -function astsize_inv(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis}) - cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize - for id ∈ arguments(n) - eclass = g[id] - !hasdata(eclass, an) && (cost += Inf; break) - cost += last(getdata(eclass, an)) - end - return cost +function astsize_inv(n::ENodeTerm, g::EGraph) + cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize + for id in arguments(n) + eclass = g[id] + !hasdata(eclass, astsize_inv) && (cost += Inf; break) + cost += last(getdata(eclass, astsize_inv)) + end + return cost end -astsize_inv(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = -1 +astsize_inv(n::ENodeLiteral, g::EGraph) = -1 """ -An [`AbstractAnalysis`](@ref) that computes the cost of expression nodes -and chooses the node with the smallest cost for each E-Class. -This abstract type is parametrised by a function F. -This is useful for the analysis storage in [`EClass`](@ref) +When passing a function to analysis functions it is considered as a cost function """ -abstract type ExtractionAnalysis{F} <: AbstractAnalysis end +make(f::Function, g::EGraph, n::AbstractENode) = (n, f(n, g)) -make(a::Type{ExtractionAnalysis{F}}, g::EGraph, n::AbstractENode) where F = (n, F(n, g, a)) +join(f::Function, from, to) = last(from) <= last(to) ? from : to -join(a::Type{<:ExtractionAnalysis}, from, to) = last(from) <= last(to) ? from : to +islazy(::Function) = true +modify!(::Function, g, id) = nothing -islazy(a::Type{<:ExtractionAnalysis}) = true +function rec_extract(g::EGraph, costfun, id::EClassId; cse_env = nothing) + eclass = g[id] + if !isnothing(cse_env) && haskey(cse_env, id) + (sym, _) = cse_env[id] + return sym + end + (n, ck) = getdata(eclass, costfun, (nothing, Inf)) + ck == Inf && error("Infinite cost when extracting enode") -function rec_extract(g::EGraph, an, id::EClassId; simterm=similarterm, cse_env=nothing) - eclass = g[id] - if !isnothing(cse_env) && haskey(cse_env, id) - (sym, _) = cse_env[id] - return sym - end - anval = getdata(eclass, an, (nothing, Inf)) - (n, ck) = anval - ck == Inf && error("Infinite cost when extracting enode") - - if n isa ENodeLiteral - return n.value - elseif n isa ENodeTerm - children = map(child -> rec_extract(g, an, child; simterm=simterm, cse_env=cse_env), arguments(n)) - meta = getdata(eclass, MetadataAnalysis, nothing) - T = termtype(n) - simterm(T, operation(n), children; metadata=meta, exprhead=exprhead(n)); - else - error("Unknown ENode Type $(typeof(cn))") - end + if n isa ENodeLiteral + return n.value + elseif n isa ENodeTerm + children = map(arg -> rec_extract(g, costfun, arg; cse_env = cse_env), n.args) + meta = getdata(eclass, :metadata_analysis, nothing) + T = symtype(n) + egraph_reconstruct_expression(T, operation(n), collect(children); metadata = meta, exprhead = exprhead(n)) + else + error("Unknown ENode Type $(typeof(n))") + end end """ Given a cost function, extract the expression with the smallest computed cost from an [`EGraph`](@ref) """ -function extract!(g::EGraph, costfun::Function; root=-1, simterm=similarterm, cse=false) - a = ExtractionAnalysis{costfun} - if root == -1 - root = g.root - end - analyze!(g, a, root) - if cse - # TODO make sure there is no assignments/stateful code!! - cse_env = OrderedDict{EClassId, Tuple{Symbol, Any}}() # - collect_cse!(g, a, root, cse_env, Set{EClassId}(); simterm=simterm) - # @show root - # @show cse_env - - body = rec_extract(g, a, root; simterm=simterm, cse_env=cse_env) - - assignments = [Expr(:(=), name, val) for (id, (name, val)) in cse_env] - # return body - Expr(:let, Expr(:block, assignments...), body) - else - return rec_extract(g, a, root; simterm=simterm) - end +function extract!(g::EGraph, costfun::Function; root = -1, cse = false) + if root == -1 + root = g.root + end + analyze!(g, costfun, root) + if cse + # TODO make sure there is no assignments/stateful code!! + cse_env = OrderedDict{EClassId,Tuple{Symbol,Any}}() # + collect_cse!(g, costfun, root, cse_env, Set{EClassId}()) + + body = rec_extract(g, costfun, root; cse_env = cse_env) + + assignments = [Expr(:(=), name, val) for (id, (name, val)) in cse_env] + # return body + Expr(:let, Expr(:block, assignments...), body) + else + return rec_extract(g, costfun, root) + end end # Builds a dict e-class id => (symbol, extracted term) of common subexpressions in an e-graph -function collect_cse!(g::EGraph, an, id, cse_env, seen; simterm=similarterm) - eclass = g[id] - anval = getdata(eclass, an, (nothing, Inf)) - (cn, ck) = anval - ck == Inf && error("Error when computing CSE") - if cn isa ENodeTerm - if id in seen - cse_env[id] = (gensym(), rec_extract(g, an, id; simterm=simterm))#, cse_env=cse_env)) # todo generalize symbol? - return - end - for child_id in arguments(cn) - collect_cse!(g, an, child_id, cse_env, seen; simterm=simterm) - end - push!(seen, id) +function collect_cse!(g::EGraph, costfun, id, cse_env, seen) + eclass = g[id] + (cn, ck) = getdata(eclass, costfun, (nothing, Inf)) + ck == Inf && error("Error when computing CSE") + if cn isa ENodeTerm + if id in seen + cse_env[id] = (gensym(), rec_extract(g, costfun, id))#, cse_env=cse_env)) # todo generalize symbol? + return end + for child_id in arguments(cn) + collect_cse!(g, costfun, child_id, cse_env, seen) + end + push!(seen, id) + end end -getcost!(g::EGraph, costfun::Function; root=-1) = getcost!(g, ExtractionAnalysis{costfun}; root=root) -function getcost!(g::EGraph, analysis::Type{ExtractionAnalysis{F}}; root=-1) where {F} - if root == -1 - root = g.root - end - analyze!(g, analysis, root) - bestnode, cost = getdata(g[root], analysis) - return cost +function getcost!(g::EGraph, costfun; root = -1) + if root == -1 + root = g.root + end + analyze!(g, costfun, root) + bestnode, cost = getdata(g[root], costfun) + return cost end diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 0e1b6925..f3ffa2ac 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -2,129 +2,100 @@ # https://dl.acm.org/doi/10.1145/3434304 -""" -Abstract type representing an [`EGraph`](@ref) analysis, -attaching values from a join semi-lattice domain to -an EGraph -""" -abstract type AbstractAnalysis end -abstract type AbstractENode{T} end +abstract type AbstractENode end -const AnalysisData = Base.ImmutableDict{Type{<:AbstractAnalysis}, Any} +const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple{Vararg{<:Ref}}} const EClassId = Int64 -const HashCons = Dict{AbstractENode,EClassId} -const Analyses = Set{Type{<:AbstractAnalysis}} -const SymbolCache = Dict{Any, Set{EClassId}} -const TermTypes = Dict{Tuple{Any, EClassId}, Type} +const TermTypes = Dict{Tuple{Any,Int},Type} -mutable struct ENodeLiteral{T} <: AbstractENode{T} - value::T - hash::Ref{UInt} +mutable struct ENodeLiteral <: AbstractENode + value + hash::Ref{UInt} + ENodeLiteral(a) = new(a, Ref{UInt}(0)) end -mutable struct ENodeTerm{T} <: AbstractENode{T} - exprhead::Union{Symbol, Nothing} - operation::Any - args::Vector{EClassId} - hash::Ref{UInt} # hash cache -end +Base.:(==)(a::ENodeLiteral, b::ENodeLiteral) = isequal(hash(a), hash(b)) -# parametrize metadata by M -mutable struct EClass - g # EGraph - id::EClassId - nodes::Vector{AbstractENode} - parents::Vector{Pair{AbstractENode, EClassId}} - data::Union{Nothing, AnalysisData} - # data::M -end +TermInterface.istree(n::ENodeLiteral) = false +TermInterface.exprhead(n::ENodeLiteral) = nothing +TermInterface.operation(n::ENodeLiteral) = n.value +TermInterface.arity(n::ENodeLiteral) = 0 -const ClassMem = Dict{EClassId,EClass} +function Base.hash(t::ENodeLiteral, salt::UInt) + !iszero(salt) && return hash(hash(t, zero(UInt)), salt) + h = t.hash[] + !iszero(h) && return h + h′ = hash(t.value, salt) + t.hash[] = h′ + return h′ +end -function ENodeTerm{T}(exprhead, operation, c_ids) where {T} - ENodeTerm{T}(exprhead, operation, c_ids, Ref{UInt}(0)) +mutable struct ENodeTerm <: AbstractENode + exprhead::Union{Symbol,Nothing} + operation::Any + symtype::Type + args::Vector{EClassId} + hash::Ref{UInt} # hash cache + ENodeTerm(exprhead, operation, symtype, c_ids) = new(exprhead, operation, symtype, c_ids, Ref{UInt}(0)) end -function Base.isequal(a::ENodeTerm, b::ENodeTerm) - isequal(a.args, b.args) && - isequal(a.exprhead, b.exprhead) && isequal(a.operation, b.operation) +function Base.:(==)(a::ENodeTerm, b::ENodeTerm) + hash(a) == hash(b) && a.operation == b.operation end TermInterface.istree(n::ENodeTerm) = true -TermInterface.istree(t::Type{<:ENodeTerm}) = true +TermInterface.symtype(n::ENodeTerm) = n.symtype TermInterface.exprhead(n::ENodeTerm) = n.exprhead -TermInterface.operation(n::ENodeTerm) = n.operation -TermInterface.arguments(n::ENodeTerm) = n.args +TermInterface.operation(n::ENodeTerm) = n.operation +TermInterface.arguments(n::ENodeTerm) = n.args TermInterface.arity(n::ENodeTerm) = length(n.args) # This optimization comes from SymbolicUtils # The hash of an enode is cached to avoid recomputing it. # Shaves off a lot of time in accessing dictionaries with ENodes as keys. -function Base.hash(t::ENodeTerm{T}, salt::UInt) where {T} - !iszero(salt) && return hash(hash(t, zero(UInt)), salt) - h = t.hash[] - !iszero(h) && return h - h′ = hash(t.args, hash(t.exprhead, hash(t.operation, hash(T, salt)))) - t.hash[] = h′ - return h′ +function Base.hash(t::ENodeTerm, salt::UInt) + !iszero(salt) && return hash(hash(t, zero(UInt)), salt) + h = t.hash[] + !iszero(h) && return h + h′ = hash(t.args, hash(t.exprhead, hash(t.operation, salt))) + t.hash[] = h′ + return h′ end -function toexpr(n::ENodeTerm) - eh = exprhead(n) - if isnothing(eh) - return operation(n) # n is a constant enode - end - similarterm(Expr, operation(n), map(i -> Symbol(i, "ₑ"), arguments(n)); exprhead=exprhead(n)) +# parametrize metadata by M +mutable struct EClass + g # EGraph + id::EClassId + nodes::Vector{AbstractENode} + parents::Vector{Pair{AbstractENode,EClassId}} + data::AnalysisData end - -# ================================================== -# ENode Literal -# ================================================== - -TermInterface.istree(n::ENodeLiteral) = false -TermInterface.istree(t::Type{<:ENodeLiteral}) = false -TermInterface.exprhead(n::ENodeLiteral) = nothing -TermInterface.operation(n::ENodeLiteral) = n.value -TermInterface.arity(n::ENodeLiteral) = 0 - -ENodeLiteral(a::T) where{T} = ENodeLiteral{T}(a, Ref{UInt}(0)) - -Base.:(==)(a::ENodeLiteral, b::ENodeLiteral) = isequal(a.value, b.value) - - -function Base.hash(t::ENodeLiteral{T}, salt::UInt) where {T} - !iszero(salt) && return hash(hash(t, zero(UInt)), salt) - h = t.hash[] - !iszero(h) && return h - h′ = hash(t.value, hash(T, salt)) - t.hash[] = h′ - return h′ +function toexpr(n::ENodeTerm) + Expr(:call, :ENode, exprhead(n), operation(n), symtype(n), arguments(n)) end - -termtype(x::AbstractENode{T}) where T = T +function Base.show(io::IO, x::ENodeTerm) + print(io, toexpr(x)) +end toexpr(n::ENodeLiteral) = operation(n) -function Base.show(io::IO, x::ENodeTerm{T}) where {T} - print(io, "ENode{$T}(", toexpr(x), ")") -end - Base.show(io::IO, x::ENodeLiteral) = print(io, toexpr(x)) -EClass(g, id) = EClass(g, id, AbstractENode[], Pair{AbstractENode, EClassId}[], nothing) -EClass(g, id, nodes, parents) = EClass(g, id, nodes, parents, nothing) +EClass(g, id) = EClass(g, id, AbstractENode[], Pair{AbstractENode,EClassId}[], nothing) +EClass(g, id, nodes, parents) = EClass(g, id, nodes, parents, NamedTuple()) # Interface for indexing EClass Base.getindex(a::EClass, i) = a.nodes[i] Base.setindex!(a::EClass, v, i) = setindex!(a.nodes, v, i) Base.firstindex(a::EClass) = firstindex(a.nodes) Base.lastindex(a::EClass) = lastindex(a.nodes) +Base.length(a::EClass) = length(a.nodes) # Interface for iterating EClass Base.iterate(a::EClass) = iterate(a.nodes) @@ -132,80 +103,68 @@ Base.iterate(a::EClass, state) = iterate(a.nodes, state) # Showing function Base.show(io::IO, a::EClass) - print(io, "EClass $(a.id) (") - - print(io, "[", Base.join(a.nodes, ", "), "]") - if a.data === nothing - print(io, ")") - return - end - print(io, ", analysis = {") - for (k, v) ∈ a.data - print(io, "$k => $v, ") - end - print(io, "})") + print(io, "EClass $(a.id) (") + + print(io, "[", Base.join(a.nodes, ", "), "], ") + print(io, a.data) + print(io, ")") end function addparent!(a::EClass, n::AbstractENode, id::EClassId) - push!(a.parents, (n => id)) + push!(a.parents, (n => id)) end function Base.union!(to::EClass, from::EClass) - append!(to.nodes, from.nodes) - append!(to.parents, from.parents) - if to.data !== nothing && from.data !== nothing - # merge!(to.data, from.data) - # to.data = join_analysis_data(to.data, from.data) - to.data = join_analysis_data(to.data, from.data) - elseif to.data === nothing - to.data = from.data + # TODO revisit + append!(to.nodes, from.nodes) + append!(to.parents, from.parents) + if !isnothing(to.data) && !isnothing(from.data) + to.data = join_analysis_data!(to.g, something(to.data), something(from.data)) + elseif to.data === nothing + to.data = from.data + end + return to +end + +function join_analysis_data!(g, dst::AnalysisData, src::AnalysisData) + new_dst = merge(dst, src) + for analysis_name in keys(src) + analysis_ref = g.analyses[analysis_name] + if hasproperty(dst, analysis_name) + ref = getproperty(new_dst, analysis_name) + ref[] = join(analysis_ref, ref[], getproperty(src, analysis_name)[]) end - return to -end - -function join_analysis_data(d::AnalysisData, dsrc::AnalysisData) - for (an, val_b) in dsrc - if haskey(d, an) - val_a = d[an] - nv = join(an, val_a, val_b) - # d[an] = nv - # WARNING immutable version - d = Base.ImmutableDict(d,an=>nv) - end - end - return d + end + new_dst end # Thanks to Shashi Gowda -function hasdata(a::EClass, x::Type{<:AbstractAnalysis}) - a.data === nothing && (return false) - return haskey(a.data, x) -end - -function getdata(a::EClass, x::Type{<:AbstractAnalysis}) - !hasdata(a, x) && error("EClass $a does not contain analysis data for $x") - return a.data[x] -end +hasdata(a::EClass, analysis_name::Symbol) = hasproperty(a.data, analysis_name) +hasdata(a::EClass, f::Function) = hasproperty(a.data, nameof(f)) +getdata(a::EClass, analysis_name::Symbol) = getproperty(a.data, analysis_name)[] +getdata(a::EClass, f::Function) = getproperty(a.data, nameof(f))[] +getdata(a::EClass, analysis_ref::Union{Symbol,Function}, default) = + hasdata(a, analysis_ref) ? getdata(a, analysis_ref) : default -function getdata(a::EClass, x::Type{<:AbstractAnalysis}, default) - hasdata(a, x) ? a.data[x] : default -end -function setdata!(a::EClass, x::Type{<:AbstractAnalysis}, value) - # lazy allocation - a.data === nothing && (a.data = AnalysisData()) - # a.data[x] = value - a.data = AnalysisData(a.data, x, value) +setdata!(a::EClass, f::Function, value) = setdata!(a, nameof(f), value) +function setdata!(a::EClass, analysis_name::Symbol, value) + if hasdata(a, analysis_name) + ref = getproperty(a.data, analysis_name) + ref[] = value + else + a.data = merge(a.data, NamedTuple{(analysis_name,)}((Ref{Any}(value),))) + end end function funs(a::EClass) - map(operation, a.nodes) + map(operation, a.nodes) end function funs_arity(a::EClass) - map(a.nodes) do x - (operation(x), arity(x)) - end + map(a.nodes) do x + (operation(x), arity(x)) + end end """ @@ -214,100 +173,87 @@ See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) for implementation details. """ mutable struct EGraph - """stores the equality relations over e-class ids""" - # uf::IntDisjointSets{EClassId} - uf::IntDisjointSet - """map from eclass id to eclasses""" - classes::ClassMem - memo::HashCons # memo - """worklist for ammortized upwards merging""" - dirty::Vector{EClassId} - root::EClassId - """A vector of analyses associated to the EGraph""" - analyses::Analyses - # """ - # a cache mapping function symbols to e-classes that - # contain e-nodes with that function symbol. - # """ - # symcache::SymbolCache - default_termtype::Type - termtypes::TermTypes - numclasses::Int - numnodes::Int - # number of rules that have been applied - # age::Int + "stores the equality relations over e-class ids" + uf::IntDisjointSet + "map from eclass id to eclasses" + classes::Dict{EClassId,EClass} + "hashcons" + memo::Dict{AbstractENode,EClassId} # memo + "worklist for ammortized upwards merging" + dirty::Vector{EClassId} + root::EClassId + "A vector of analyses associated to the EGraph" + analyses::Dict{Union{Symbol,Function},Union{Symbol,Function}} + "a cache mapping function symbols to e-classes that contain e-nodes with that function symbol." + symcache::Dict{Any,Vector{EClassId}} + default_termtype::Type + termtypes::TermTypes + numclasses::Int + numnodes::Int end + """ EGraph(expr) Construct an EGraph from a starting symbolic expression `expr`. """ function EGraph() - EGraph( - IntDisjointSet{EClassId}(), - # IntDisjointSets{EClassId}(0), - ClassMem(), - HashCons(), - # ParentMem(), - EClassId[], - -1, - Analyses(), - # SymbolCache(), - Expr, - TermTypes(), - 0, - 0, - # 0 - ) -end - -function EGraph(e; keepmeta=false) - g = EGraph() - if keepmeta - push!(g.analyses, MetadataAnalysis) - end - - rootclass, rootnode = addexpr!(g, e; keepmeta=keepmeta) - g.root = rootclass.id - g + EGraph( + IntDisjointSet(), + Dict{EClassId,EClass}(), + Dict{AbstractENode,EClassId}(), + EClassId[], + -1, + Dict{Union{Symbol,Function},Union{Symbol,Function}}(), + Dict{Any,Vector{EClassId}}(), + Expr, + TermTypes(), + 0, + 0, + # 0 + ) +end + +function EGraph(e; keepmeta = false) + g = EGraph() + keepmeta && addanalysis!(g, :metadata_analysis) + g.root = addexpr!(g, e; keepmeta = keepmeta) + g +end + +function addanalysis!(g::EGraph, costfun::Function) + g.analyses[nameof(costfun)] = costfun + g.analyses[costfun] = costfun +end + +function addanalysis!(g::EGraph, analysis_name::Symbol) + g.analyses[analysis_name] = analysis_name end function settermtype!(g::EGraph, f, ar, T) - g.termtypes[(f,ar)] = T + g.termtypes[(f, ar)] = T end function settermtype!(g::EGraph, T) - g.default_termtype = T + g.default_termtype = T end function gettermtype(g::EGraph, f, ar) - if haskey(g.termtypes, (f,ar)) - g.termtypes[(f,ar)] - else - g.default_termtype - end + if haskey(g.termtypes, (f, ar)) + g.termtypes[(f, ar)] + else + g.default_termtype + end end """ Returns the canonical e-class id for a given e-class. """ -# function find(g::EGraph, a::EClassId)::EClassId -# find_root_if_normal(g.uf, a) -# end -function find(g::EGraph, a::EClassId)::EClassId - find_root(g.uf, a) -end +find(g::EGraph, a::EClassId)::EClassId = find_root(g.uf, a) find(g::EGraph, a::EClass)::EClassId = find(g, a.id) -function Base.getindex(g::EGraph, i::EClassId) - id = find(g, i) - ec = g.classes[id] - # @show ec.id id a - # @assert ec.id == id - # ec.id = id - ec -end +Base.getindex(g::EGraph, i::EClassId) = g.classes[find(g, i)] ### Definition 2.3: canonicalization iscanonical(g::EGraph, n::ENodeTerm) = n == canonicalize(g, n) @@ -316,72 +262,70 @@ iscanonical(g::EGraph, e::EClass) = find(g, e.id) == e.id canonicalize(g::EGraph, n::ENodeLiteral) = n -function canonicalize(g::EGraph, n::ENodeTerm{T}) where {T} - if arity(n) > 0 - new_args = map(x -> find(g, x), arguments(n)) - return ENodeTerm{T}(exprhead(n), operation(n), new_args) - end - return n +function canonicalize(g::EGraph, n::ENodeTerm) + if arity(n) > 0 + new_args = map(x -> find(g, x), n.args) + return ENodeTerm(exprhead(n), operation(n), symtype(n), new_args) + end + return n end function canonicalize!(g::EGraph, n::ENodeTerm) - args = arguments(n) - for i ∈ 1:arity(n) - args[i] = find(g, args[i]) - end - n.hash[] = UInt(0) - return n + for (i, arg) in enumerate(n.args) + n.args[i] = find(g, arg) + end + n.hash[] = UInt(0) + return n end canonicalize!(g::EGraph, n::ENodeLiteral) = n function canonicalize!(g::EGraph, e::EClass) - e.id = find(g, e.id) + e.id = find(g, e.id) end -function lookup(g::EGraph, n::AbstractENode) - cc = canonicalize(g, n) - if !haskey(g.memo, cc) - return nothing - end - return find(g, g.memo[cc]) +function lookup(g::EGraph, n::AbstractENode)::EClassId + cc = canonicalize(g, n) + haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1 end """ Inserts an e-node in an [`EGraph`](@ref) """ -function add!(g::EGraph, n::AbstractENode)::EClass - @debug("adding ", n) +function add!(g::EGraph, n::AbstractENode)::EClassId + @debug("adding ", n) - n = canonicalize(g, n) - if haskey(g.memo, n) - eclass = g[g.memo[n]] - return eclass - end - @debug(n, " not found in memo") + n = canonicalize(g, n) + haskey(g.memo, n) && return g.memo[n] - id = push!(g.uf) # create new singleton eclass + id = push!(g.uf) # create new singleton eclass - if n isa ENodeTerm - for c_id ∈ arguments(n) - addparent!(g.classes[c_id], n, id) - end + if n isa ENodeTerm + for c_id in arguments(n) + addparent!(g.classes[c_id], n, id) end + end - g.memo[n] = id + g.memo[n] = id - classdata = EClass(g, id, AbstractENode[n], Pair{AbstractENode, EClassId}[]) - g.classes[id] = classdata - g.numclasses += 1 + if haskey(g.symcache, operation(n)) + push!(g.symcache[operation(n)], id) + else + g.symcache[operation(n)] = [id] + end - for an ∈ g.analyses - if !islazy(an) && an !== MetadataAnalysis - setdata!(classdata, an, make(an, g, n)) - modify!(an, g, id) - end + classdata = EClass(g, id, AbstractENode[n], Pair{AbstractENode,EClassId}[]) + g.classes[id] = classdata + g.numclasses += 1 + + for an in values(g.analyses) + if !islazy(an) && an !== :metadata_analysis + setdata!(classdata, an, make(an, g, n)) + modify!(an, g, id) end - return classdata + end + return id end @@ -391,8 +335,8 @@ preprocessing of a symbolic term before adding it to an EGraph. Most common preprocessing techniques are binarization of n-ary terms and metadata stripping. """ -function preprocess(e::Expr) - cleanast(e) +function preprocess(e::Expr) + cleanast(e) end preprocess(x) = x @@ -401,74 +345,60 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int [`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly insert the literal into the [`EGraph`](@ref). """ -addexpr!(g::EGraph, se::EClass; keepmeta=false) = (se, se[1]) - -function addexpr!(g::EGraph, se; keepmeta=false)::Tuple{EClass, AbstractENode} - # println("========== $e ===========") - e = preprocess(se) - T = typeof(e) - node = nothing - - if istree(T) - exhead = exprhead(e) - op = operation(e) - args = arguments(e) - - n = length(args) - - class_ids = EClassId[ - first(addexpr!(g, child; keepmeta=keepmeta)).id - for child in args] - - node = ENodeTerm{typeof(e)}(exhead, op, class_ids) - else - # constant enode - node = ENodeLiteral(e) - end +function addexpr!(g::EGraph, se; keepmeta = false)::EClassId + e = preprocess(se) - ec = add!(g, node) - if keepmeta - # TODO check if eclass already has metadata? - meta = TermInterface.metadata(e) - setdata!(ec, MetadataAnalysis, meta) - end - return (ec, node) + id = add!(g, if istree(se) + class_ids::Vector{EClassId} = [addexpr!(g, arg; keepmeta = keepmeta) for arg in arguments(e)] + ENodeTerm(exprhead(e), operation(e), symtype(e), class_ids) + else + # constant enode + ENodeLiteral(e) + end) + if keepmeta + meta = TermInterface.metadata(e) + !isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta) + end + return id end - +function addexpr!(g::EGraph, ec::EClass; keepmeta = false) + @assert g == ec.g + find(g, ec.id) +end """ Given an [`EGraph`](@ref) and two e-class ids, set the two e-classes as equal. """ function Base.merge!(g::EGraph, a::EClassId, b::EClassId)::EClassId - id_a = find(g, a) - id_b = find(g, b) + id_a = find(g, a) + id_b = find(g, b) - - id_a == id_b && return id_a - to = union!(g.uf, id_a, id_b) - @debug "merging" id_a id_b + id_a == id_b && return id_a + to = union!(g.uf, id_a, id_b) - from = (to == id_a) ? id_b : id_a + @debug "merging" id_a id_b - push!(g.dirty, to) + from = (to == id_a) ? id_b : id_a - from_class = g.classes[from] - to_class = g.classes[to] - to_class.id = to + push!(g.dirty, to) - # I (was) the troublesome line! - g.classes[to] = union!(to_class, from_class) - delete!(g.classes, from) - g.numclasses -= 1 + from_class = g.classes[from] + to_class = g.classes[to] + to_class.id = to - return to + # I (was) the troublesome line! + g.classes[to] = union!(to_class, from_class) + delete!(g.classes, from) + g.numclasses -= 1 + + return to end function in_same_class(g::EGraph, a, b) - find(g, a) == find(g, b) + find(g, a) == find(g, b) end @@ -480,105 +410,75 @@ the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) for more details. """ function rebuild!(g::EGraph) - # normalize!(g.uf) - - while !isempty(g.dirty) - # todo = unique([find(egraph, id) for id ∈ egraph.dirty]) - todo = unique(g.dirty) - empty!(g.dirty) - for x ∈ todo - repair!(g, x) - end + # normalize!(g.uf) + + while !isempty(g.dirty) + # todo = unique([find(egraph, id) for id ∈ egraph.dirty]) + todo = unique(g.dirty) + empty!(g.dirty) + for x in todo + repair!(g, x) end - - if g.root != -1 - g.root = find(g, g.root) - end - - normalize!(g.uf) - - # for i ∈ 1:length(egraph.uf) - # find_root!(egraph.uf, i) - # end - # INVARIANTS ASSERTIONS - # for (id, c) ∈ egraph.classes - # # ecdata.nodes = map(n -> canonicalize(egraph.uf, n), ecdata.nodes) - # println(id, "=>", c.id) - # @assert(id == c.id) - # # for an ∈ egraph.analyses - # # if haskey(an, id) - # # @assert an[id] == mapreduce(x -> make(an, x), (x, y) -> join(an, x, y), c.nodes) - # # end - # # end - - # for n ∈ c - # println(n) - # println("canon = ", canonicalize(egraph, n)) - # hr = egraph.memo[canonicalize(egraph, n)] - # println(hr) - # @assert hr == find(egraph, id) - # end - # end - # display(egraph.classes); println() - # @show egraph.dirty + end + + if g.root != -1 + g.root = find(g, g.root) + end + normalize!(g.uf) end function repair!(g::EGraph, id::EClassId) - id = find(g, id) - ecdata = g[id] - ecdata.id = id - @debug "repairing " id - - # for (p_enode, p_eclass) ∈ ecdata.parents - # clean_enode!(g, p_enode, find(g, p_eclass)) - # end - - new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}() - - for (p_enode, p_eclass) ∈ ecdata.parents - p_enode = canonicalize!(g, p_enode) - # deduplicate parents - if haskey(new_parents, p_enode) - @debug "merging classes" p_eclass (new_parents[p_enode]) - merge!(g, p_eclass, new_parents[p_enode]) - end - n_id = find(g, p_eclass) - g.memo[p_enode] = n_id - new_parents[p_enode] = n_id + id = find(g, id) + ecdata = g[id] + ecdata.id = id + @debug "repairing " id + + new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}() + + for (p_enode, p_eclass) in ecdata.parents + p_enode = canonicalize!(g, p_enode) + # deduplicate parents + if haskey(new_parents, p_enode) + @debug "merging classes" p_eclass (new_parents[p_enode]) + merge!(g, p_eclass, new_parents[p_enode]) end - - ecdata.parents = collect(new_parents) - @debug "updated parents " id g.parents[id] - - # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes) - - # Analysis invariant maintenance - for an ∈ g.analyses - hasdata(ecdata, an) && modify!(an, g, id) - # modify!(an, id) - # id = find(g, id) - for (p_enode, p_id) ∈ ecdata.parents - # p_eclass = find(g, p_eclass) - p_eclass = g[p_id] - if !islazy(an) && !hasdata(p_eclass, an) - setdata!(p_eclass, an, make(an, g, p_enode)) - end - if hasdata(p_eclass, an) - p_data = getdata(p_eclass, an) - - new_data = join(an, p_data, make(an, g, p_enode)) - if new_data != p_data - setdata!(p_eclass, an, new_data) - push!(g.dirty, p_id) - end - end + n_id = find(g, p_eclass) + g.memo[p_enode] = n_id + new_parents[p_enode] = n_id + end + + ecdata.parents = collect(new_parents) + @debug "updated parents " id g.parents[id] + + # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes) + + # Analysis invariant maintenance + for an in values(g.analyses) + hasdata(ecdata, an) && modify!(an, g, id) + for (p_enode, p_id) in ecdata.parents + # p_eclass = find(g, p_eclass) + p_eclass = g[p_id] + if !islazy(an) && !hasdata(p_eclass, an) + setdata!(p_eclass, an, make(an, g, p_enode)) + end + if hasdata(p_eclass, an) + p_data = getdata(p_eclass, an) + + if an !== :metadata_analysis + new_data = join(an, p_data, make(an, g, p_enode)) + if new_data != p_data + setdata!(p_eclass, an, new_data) + push!(g.dirty, p_id) + end end + end end + end - unique!(ecdata.nodes) + unique!(ecdata.nodes) - # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes) + # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes) end @@ -588,29 +488,67 @@ Recursive function that traverses an [`EGraph`](@ref) and returns a vector of all reachable e-classes from a given e-class id. """ function reachable(g::EGraph, id::EClassId) - id = find(g, id) - hist = EClassId[id] - todo = EClassId[id] - - - function reachable_node(xn::ENodeTerm) - x = canonicalize(g, xn) - for c_id in arguments(x) - if c_id ∉ hist - push!(hist, c_id) - push!(todo, c_id) - end - end - end - function reachable_node(x::ENodeLiteral) + id = find(g, id) + hist = EClassId[id] + todo = EClassId[id] + + + function reachable_node(xn::ENodeTerm) + x = canonicalize(g, xn) + for c_id in arguments(x) + if c_id ∉ hist + push!(hist, c_id) + push!(todo, c_id) + end end + end + function reachable_node(x::ENodeLiteral) end - while !isempty(todo) - curr = find(g, pop!(todo)) - for n ∈ g.classes[curr] - reachable_node(n) - end + while !isempty(todo) + curr = find(g, pop!(todo)) + for n in g.classes[curr] + reachable_node(n) end + end - return hist + return hist end + + +""" +When extracting symbolic expressions from an e-graph, we need +to instruct the e-graph how to rebuild expressions of a certain type. +This function must be extended by the user to add new types of expressions that can be manipulated by e-graphs. +""" +function egraph_reconstruct_expression(T::Type{Expr}, op, args; metadata = nothing, exprhead = :call) + similarterm(Expr(:call, :_), op, args; metadata = metadata, exprhead = exprhead) +end + +# Thanks to Max Willsey and Yihong Zhang + +import Metatheory: lookup_pat + +function lookup_pat(g::EGraph, p::PatTerm)::EClassId + @assert isground(p) + + eh = exprhead(p) + op = operation(p) + args = arguments(p) + ar = arity(p) + + T = gettermtype(g, op, ar) + + ids = map(x -> lookup_pat(g, x), args) + !all((>)(0), ids) && return -1 + + if T == Expr && op isa Union{Function,DataType} + id = lookup(g, ENodeTerm(eh, op, T, ids)) + id < 0 && return lookup(g, ENodeTerm(eh, nameof(op), T, ids)) + return id + else + return lookup(g, ENodeTerm(eh, op, T, ids)) + end +end + +lookup_pat(g::EGraph, p::Any) = lookup(g, ENodeLiteral(p)) +lookup_pat(g::EGraph, p::AbstractPat) = throw(UnsupportedPatternException(p)) diff --git a/src/EGraphs/ematch.jl b/src/EGraphs/ematch.jl deleted file mode 100644 index 6ea8ba8d..00000000 --- a/src/EGraphs/ematch.jl +++ /dev/null @@ -1,271 +0,0 @@ - -# ============================================================= -# ================== INTERPRETER ============================== -# ============================================================= - -struct Sub - # sourcenode::Union{Nothing, AbstractENode} - ids::Vector{EClassId} - nodes::Vector{Union{Nothing,ENodeLiteral}} -end - -haseclassid(sub::Sub, p::PatVar) = sub.ids[p.idx] >= 0 -geteclassid(sub::Sub, p::PatVar) = sub.ids[p.idx] - -hasliteral(sub::Sub, p::PatVar) = sub.nodes[p.idx] !== nothing -getliteral(sub::Sub, p::PatVar) = sub.nodes[p.idx] - -## ====================== Instantiation ======================= - -function instantiate(g::EGraph, pat::PatVar, sub::Sub, rule::AbstractRule; kws...) - if haseclassid(sub, pat) - ec = g[geteclassid(sub, pat)] - if hasliteral(sub, pat) - node = getliteral(sub, pat) - return node.value - end - return ec - else - error("unbound pattern variable $pat in rule $rule") - end -end - -instantiate(g::EGraph, pat::Any, sub::Sub, rule::AbstractRule; kws...) = pat -instantiate(g::EGraph, pat::AbstractPat, sub::Sub, rule::AbstractRule; kws...) = - throw(UnsupportedPatternException(pat)) - -# FIXME instantiate function object as operation instead of symbol if present!! -# This needs a redesign of this pattern matcher -function instantiate(g::EGraph, pat::PatTerm, sub::Sub, rule::AbstractRule; simterm=TermInterface.similarterm) - eh = exprhead(pat) - op = operation(pat) - ar = arity(pat) - T = gettermtype(g, op, ar) - children = map(x -> instantiate(g, x, sub, rule; simterm=simterm), arguments(pat)) - simterm(T, op, children; exprhead=eh) -end - -## ====================== EMatching Machine ======================= - -mutable struct Machine - g::EGraph - program::Program - # eclass register memory - σ::Vector{EClassId} - # literals - n::Vector{Union{Nothing,ENodeLiteral}} - # output buffer - buf::Vector{Sub} -end - -const DEFAULT_MEM_SIZE = 1024 -function Machine() - m = Machine( - EGraph(), # egraph - Program(), # program - fill(-1, DEFAULT_MEM_SIZE), # memory - fill(nothing, DEFAULT_MEM_SIZE), # memory - Sub[] - ) - return m -end - -function reset(m::Machine, g, program, id) - m.g = g - m.program = program - - if program.memsize > DEFAULT_MEM_SIZE - error("E-Matching Virtual Machine Memory Overflow") - end - - fill!(m.σ, -1) - fill!(m.n, nothing) - m.σ[program.first_nonground] = id - - empty!(m.buf) - - return m -end - - -function (m::Machine)() - m(m.program[1], 1) - return m.buf -end - -function next(m::Machine, pc) - m(m.program[pc + 1], pc + 1) -end - -function (m::Machine)(instr::Yield, pc) - # @show instr - # sourcenode = m.n[m.program.first_nonground] - ecs = [m.σ[reg] for reg in instr.yields] - nodes = [m.n[reg] for reg in instr.yields] - # push!(m.buf, Sub(sourcenode, ecs, nodes)) - push!(m.buf, Sub(ecs, nodes)) - - return nothing -end - -function (m::Machine)(instr::CheckClassEq, pc) - # @show instr - l = m.σ[instr.left] - r = m.σ[instr.right] - # println("checking eq $l == $r") - if l == r - next(m, pc) - end - return nothing -end - -function (m::Machine)(instr::CheckType, pc) - # @show instr - id = m.σ[instr.reg] - eclass = m.g[id] - - for n in eclass - if checktype(n, instr.type) - m.σ[instr.reg] = id - m.n[instr.reg] = n - next(m, pc) - end - end - - return nothing -end - -checktype(n, t) = false -checktype(n::ENodeLiteral{<:T}, ::Type{T}) where {T} = true - - -function (m::Machine)(instr::CheckPredicate, pc) - # @show instr - id = m.σ[instr.reg] - eclass = m.g[id] - - if instr.predicate(m.g, eclass) - m.σ[instr.reg] = id - for n in eclass.nodes - if n isa ENodeLiteral - m.n[instr.reg] = n - break - end - end - next(m, pc) - end - - return nothing -end - - -function (m::Machine)(instr::Filter, pc) - # @show instr - id, _ = m.σ[instr.reg] - eclass = m.g[id] - - if operation(instr) ∈ funs(eclass) - next(m, pc + 1) - end - return nothing -end - -# Thanks to Max Willsey and Yihong Zhang - -function lookup_pat(g::EGraph, p::PatTerm) - # println("looking up $p") - @assert isground(p) - - eh = exprhead(p) - op = operation(p) - args = arguments(p) - ar = arity(p) - - T = gettermtype(g, op, ar) - - ids = [lookup_pat(g, pp) for pp in args] - if all(i -> i isa EClassId, ids) - # println(ids) - n = ENodeTerm{T}(eh, op, ids) - ec = lookup(g, n) - return ec - else - return nothing - end -end - -lookup_pat(g::EGraph, p::Any) = lookup(g, ENodeLiteral(p)) -lookup_pat(g::EGraph, p::AbstractPat) = throw(UnsupportedPatternException(p)) - -function (m::Machine)(instr::Lookup, pc) - # @show instr - ecid = lookup_pat(m.g, instr.p) - if ecid isa EClassId - # println("found $(instr.p) in $ecid") - m.σ[instr.reg] = ecid - next(m, pc) - end - return nothing -end - -function (m::Machine)(instr::Bind, pc) - # @show instr - ecid = m.σ[instr.reg] - eclass = m.g[ecid] - pat = instr.enodepat - reg = instr.reg - - for n in eclass.nodes - # @show n - # @show exprhead(n) exprhead(instr.enodepat) - # @show operation(n) operation(instr.enodepat) - # dump(operation(n)) - # dump(operation(instr.enodepat)) - # @show arity(n) arity(instr.enodepat) - # @show arguments(n) arguments(instr.enodepat) - - # @show exprhead(n) == exprhead(instr.enodepat) - # @show operation(n) == operation(instr.enodepat) - # @show arity(n) == arity(instr.enodepat) - if canbind(n, pat) - # m.n[reg] = n - for (j, v) in enumerate(arguments(pat)) - m.σ[v] = arguments(n)[j] - end - next(m, pc) - end - end - return nothing -end - -function canbind(n::ENodeTerm, pat::ENodePat) - exprhead(n) == exprhead(pat) && - pat.checkop(operation(n)) && - arity(n) == arity(pat) -end - -canbind(n::ENodeLiteral, pat::ENodePat) = false - -# use const to help the compiler see the type. -# each machine has a corresponding lock to ensure thread-safety in case -# tasks migrate between threads. -const MACHINES = Tuple{Machine,ReentrantLock}[] - -function __init__() - empty!(MACHINES) - for _ in 1:Threads.nthreads() - push!(MACHINES, (Machine(), ReentrantLock())) - end -end - -function ematch(g::EGraph, program::Program, id::EClassId) - # @show program - tid = Threads.threadid() - m, mlock = MACHINES[tid] - buf = lock(mlock) do - reset(m, g, program, id) - m() - end - # @show buf - buf -end \ No newline at end of file diff --git a/src/EGraphs/intdisjointmap.jl b/src/EGraphs/intdisjointmap.jl index 65830910..2f475458 100644 --- a/src/EGraphs/intdisjointmap.jl +++ b/src/EGraphs/intdisjointmap.jl @@ -1,73 +1,73 @@ -struct IntDisjointSet{T<:Integer} - parents::Vector{T} - normalized::Ref{Bool} +struct IntDisjointSet + parents::Vector{Int} + normalized::Ref{Bool} end -IntDisjointSet{T}() where {T<:Integer} = IntDisjointSet{T}(Vector{T}[], Ref(true)) +IntDisjointSet() = IntDisjointSet(Int[], Ref(true)) Base.length(x::IntDisjointSet) = length(x.parents) -function Base.push!(x::IntDisjointSet{T}) where {T} - push!(x.parents, convert(T, -1)) - convert(T, length(x)) +function Base.push!(x::IntDisjointSet)::Int + push!(x.parents, -1) + length(x) end -function find_root(x::IntDisjointSet{T}, i::T) where {T} - while x.parents[i] >= 0 - i = x.parents[i] - end - return convert(T, i) +function find_root(x::IntDisjointSet, i::Int)::Int + while x.parents[i] >= 0 + i = x.parents[i] + end + return i end -function in_same_set(x::IntDisjointSet{T}, a::T, b::T) where {T} - find_root(x, a) == find_root(x, b) +function in_same_set(x::IntDisjointSet, a::Int, b::Int) + find_root(x, a) == find_root(x, b) end -function Base.union!(x::IntDisjointSet{T}, i::T, j::T) where {T} - pi = find_root(x, i) - pj = find_root(x, j) - if pi != pj - x.normalized[] = false - isize = -x.parents[pi] - jsize = -x.parents[pj] - if isize > jsize # swap to make size of i less than j - pi, pj = pj, pi - isize, jsize = jsize, isize - end - x.parents[pj] -= isize # increase new size of pj - x.parents[pi] = pj # set parent of pi to pj +function Base.union!(x::IntDisjointSet, i::Int, j::Int) + pi = find_root(x, i) + pj = find_root(x, j) + if pi != pj + x.normalized[] = false + isize = -x.parents[pi] + jsize = -x.parents[pj] + if isize > jsize # swap to make size of i less than j + pi, pj = pj, pi + isize, jsize = jsize, isize end - return convert(T, pj) + x.parents[pj] -= isize # increase new size of pj + x.parents[pi] = pj # set parent of pi to pj + end + return pj end -function normalize!(x::IntDisjointSet{T}) where {T} - for i in convert(T, length(x)) - pi = find_root(x, i) - if pi != i - x.parents[i] = convert(T, pi) - end +function normalize!(x::IntDisjointSet) + for i in 1:length(x) + p_i = find_root(x, i) + if p_i != i + x.parents[i] = p_i end - x.normalized[] = true + end + x.normalized[] = true end # If normalized we don't even need a loop here. -function _find_root_normal(x::IntDisjointSet{T}, i::T) where {T} - pi = x.parents[i] - if pi < 0 # Is `i` a root? - return i - else - return pi - end - # return pi +function _find_root_normal(x::IntDisjointSet, i::Int) + p_i = x.parents[i] + if p_i < 0 # Is `i` a root? + return i + else + return p_i + end + # return pi end function _in_same_set_normal(x::IntDisjointSet, a::Int64, b::Int64) - _find_root_normal(x, a) == _find_root_normal(x, b) + _find_root_normal(x, a) == _find_root_normal(x, b) end function find_root_if_normal(x::IntDisjointSet, i::Int64) - if x.normalized[] - _find_root_normal(x, i) - else - find_root(x, i) - end -end \ No newline at end of file + if x.normalized[] + _find_root_normal(x, i) + else + find_root(x, i) + end +end diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index c9964dfc..21b3bc84 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -8,16 +8,16 @@ This goal is reached when the `exprs` list of expressions are in the same equivalence class. """ struct EqualityGoal <: SaturationGoal - exprs::Vector{Any} - ids::Vector{EClassId} - function EqualityGoal(exprs, eclasses) - @assert length(exprs) == length(eclasses) && length(exprs) != 0 - new(exprs, eclasses) - end + exprs::Vector{Any} + ids::Vector{EClassId} + function EqualityGoal(exprs, eclasses) + @assert length(exprs) == length(eclasses) && length(exprs) != 0 + new(exprs, eclasses) + end end function reached(g::EGraph, goal::EqualityGoal) - all(x -> in_same_class(g, goal.ids[1], x), @view goal.ids[2:end]) + all(x -> in_same_class(g, goal.ids[1], x), @view goal.ids[2:end]) end """ @@ -25,393 +25,346 @@ Boolean valued function as an arbitrary saturation goal. User supplied function must take an [`EGraph`](@ref) as the only parameter. """ struct FunctionGoal <: SaturationGoal - fun::Function + fun::Function end function reached(g::EGraph, goal::FunctionGoal)::Bool - fun(g) + goal.fun(g) end -mutable struct Report - reason::Union{Symbol, Nothing} - egraph::EGraph - iterations::Int - to::TimerOutput +mutable struct SaturationReport + reason::Union{Symbol,Nothing} + egraph::EGraph + iterations::Int + to::TimerOutput end -Report() = Report(nothing, EGraph(), 0, TimerOutput()) -Report(g::EGraph) = Report(nothing, g, 0, TimerOutput()) +SaturationReport() = SaturationReport(nothing, EGraph(), 0, TimerOutput()) +SaturationReport(g::EGraph) = SaturationReport(nothing, g, 0, TimerOutput()) # string representation of timedata -function Base.show(io::IO, x::Report) - g = x.egraph - println(io, "Equality Saturation Report") - println(io, "=================") - println(io, "\tStop Reason: $(x.reason)") - println(io, "\tIterations: $(x.iterations)") - # println(io, "\tRules applied: $(g.age)") - println(io, "\tEGraph Size: $(g.numclasses) eclasses, $(length(g.memo)) nodes") - print_timer(io, x.to) +function Base.show(io::IO, x::SaturationReport) + g = x.egraph + println(io, "SaturationReport") + println(io, "=================") + println(io, "\tStop Reason: $(x.reason)") + println(io, "\tIterations: $(x.iterations)") + println(io, "\tEGraph Size: $(g.numclasses) eclasses, $(length(g.memo)) nodes") + print_timer(io, x.to) end """ Configurable Parameters for the equality saturation process. """ -@with_kw mutable struct SaturationParams - timeout::Int = 8 - timelimit::Period = Second(-1) - # default sizeout. TODO make this bytes - # sizeout::Int = 2^14 - matchlimit::Int = 5000 - eclasslimit::Int = 5000 - enodelimit::Int = 15000 - goal::Union{Nothing, SaturationGoal} = nothing - stopwhen::Function = ()->false - scheduler::Type{<:AbstractScheduler} = BackoffScheduler - schedulerparams::Tuple=() - threaded::Bool = false - timer::Bool = true - printiter::Bool = false - simterm::Function = similarterm -end - -struct Match - rule::AbstractRule - # the rhs pattern to instantiate - pat_to_inst - # the substitution - sub::Sub - # the id the matched the lhs - id::EClassId +Base.@kwdef mutable struct SaturationParams + timeout::Int = 8 + "Timeout in nanoseconds" + timelimit::UInt64 = 0 + "Maximum number of eclasses allowed" + eclasslimit::Int = 5000 + enodelimit::Int = 15000 + goal::Union{Nothing,SaturationGoal} = nothing + stopwhen::Function = () -> false + scheduler::Type{<:AbstractScheduler} = BackoffScheduler + schedulerparams::Tuple = () + threaded::Bool = false + timer::Bool = true + printiter::Bool = false end -const MatchesBuf = Vector{Match} +# function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64} +# if isground(p) +# id = lookup_pat(g, p) +# !isnothing(id) && return [id] +# else +# return keys(g.classes) +# end +# return [] +# end -function cached_ids(g::EGraph, p::AbstractPat)# ::Vector{Int64} - if isground(p) - id = lookup_pat(g, p) - !isnothing(id) && return [id] - else - return collect(keys(g.classes)) - end - return [] +function cached_ids(g::EGraph, p::AbstractPattern) # p is a literal + @warn "Pattern matching against the whole e-graph" + return keys(g.classes) end function cached_ids(g::EGraph, p) # p is a literal - id = lookup(g, ENodeLiteral(p)) - !isnothing(id) && return [id] - return [] + id = lookup(g, ENodeLiteral(p)) + id > 0 && return [id] + return [] end -# FIXME -function cached_ids(g::EGraph, p::PatTerm) - # println("pattern $p, $(p.head)") - # println("all ids") - # keys(g.classes) |> println - # println("cached symbols") - # cached = get(g.symcache, p.head, Set{Int64}()) - # println("symbols where $(p.head) appears") - # appears = Set{Int64}() - # for (id, class) ∈ g.classes - # for n ∈ class - # if n.head == p.head - # push!(appears, id) - # end - # end - # end - # # println(appears) - # if !(cached == appears) - # @show cached - # @show appears - # end - - collect(keys(g.classes)) - # cached - # get(g.symcache, p.head, []) -end -# function cached_ids(g::EGraph, p::PatLiteral) -# get(g.symcache, p.val, []) +# function cached_ids(g::EGraph, p::PatTerm) +# arr = get(g.symcache, operation(p), EClassId[]) +# if operation(p) isa Union{Function,DataType} +# append!(arr, get(g.symcache, nameof(operation(p)), EClassId[])) +# end +# arr # end -function (r::SymbolicRule)(g::EGraph, id::EClassId) - ematch(g, r.ematch_program, id) .|> sub -> Match(r, r.right, sub, id) -end - -function (r::DynamicRule)(g::EGraph, id::EClassId) - ematch(g, r.ematch_program, id) .|> sub -> Match(r, nothing, sub, id) -end - -function (r::BidirRule)(g::EGraph, id::EClassId) - vcat(ematch(g, r.ematch_program_l, id) .|> sub -> Match(r, r.right, sub, id), - ematch(g, r.ematch_program_r, id) .|> sub -> Match(r, r.left, sub, id)) +function cached_ids(g::EGraph, p::PatTerm) + keys(g.classes) end """ Returns an iterator of `Match`es. """ -function eqsat_search!(egraph::EGraph, theory::Vector{<:AbstractRule}, - scheduler::AbstractScheduler, report; threaded=false) - match_groups = Vector{Match}[] - function pmap(f, xs) - # const propagation should be able to optimze one of the branch away - if threaded - # # try to divide the work evenly between threads without adding much overhead - # chunks = Threads.nthreads() * 10 - # basesize = max(length(xs) ÷ chunks, 1) - # ThreadsX.mapi(f, xs; basesize=basesize) - ThreadsX.map(f, xs) - else - map(f, xs) - end +function eqsat_search!( + g::EGraph, + theory::Vector{<:AbstractRule}, + scheduler::AbstractScheduler, + report::SaturationReport, +)::Int + n_matches = 0 + + lock(BUFFER_LOCK) do + empty!(BUFFER[]) + end + + for (rule_idx, rule) in enumerate(theory) + @timeit report.to string(rule_idx) begin + # don't apply banned rules + if !cansearch(scheduler, rule) + continue + end + ids = cached_ids(g, rule.left) + rule isa BidirRule && (ids = ids ∪ cached_ids(g, rule.right)) + for i in ids + n_matches += rule.ematcher!(g, rule_idx, i) + end + inform!(scheduler, rule, n_matches) end + end - inequalities = filter(Base.Fix2(isa, UnequalRule), theory) - # never skip contradiction checks - append_time = TimerOutput() - for rule ∈ inequalities - @timeit report.to repr(rule) begin - ids = cached_ids(egraph, rule.left) - rule_matches = pmap(i -> rule(egraph, i), ids) - @timeit append_time "appending matches" begin - append!(match_groups, rule_matches) - end - end - end - other_rules = filter(theory) do rule - !(rule isa UnequalRule) - end - for rule ∈ other_rules - @timeit report.to repr(rule) begin - # don't apply banned rules - if !cansearch(scheduler, rule) - # println("skipping banned rule $rule") - continue - end - ids = cached_ids(egraph, rule.left) - rule_matches = pmap(i -> rule(egraph, i), ids) - - n_matches = isempty(rule_matches) ? 0 : sum(length, rule_matches) - # @show (rule, n_matches) - can_yield = inform!(scheduler, rule, n_matches) - if can_yield - @timeit append_time "appending matches" begin - append!(match_groups, rule_matches) - end - end - end - end + return n_matches +end - # @timeit append_time "appending matches" begin - # result = reduce(vcat, match_groups) # this should be more efficient than multiple appends - # end - merge!(report.to, append_time, tree_point=["Search"]) - return Iterators.flatten(match_groups) - # return result +function drop_n!(D::CircularDeque, nn) + D.n -= nn + tmp = D.first + nn + D.first = tmp > D.capacity ? 1 : tmp end - -function (rule::UnequalRule)(g::EGraph, match::Match; simterm=similarterm) - lc = match.id - rinst = instantiate(g, match.pat_to_inst, match.sub, rule; simterm=simterm) - rc, node = addexpr!(g, rinst) +instantiate_enode!(bindings::Bindings, g::EGraph, p::Any)::EClassId = add!(g, ENodeLiteral(p)) +instantiate_enode!(bindings::Bindings, g::EGraph, p::PatVar)::EClassId = bindings[p.idx][1] +function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId + eh = exprhead(p) + op = operation(p) + ar = arity(p) + args = arguments(p) + T = gettermtype(g, op, ar) + # TODO add predicate check `quotes_operation` + new_op = T == Expr && op isa Union{Function,DataType} ? nameof(op) : op + add!(g, ENodeTerm(eh, new_op, T, map(arg -> instantiate_enode!(bindings, g, arg), args))) +end - if find(g, lc) == find(g, rc) - @log "Contradiction!" rule - return :contradiction - end - return nothing +function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) + push!(MERGES_BUF[], (id, instantiate_enode!(buf, g, rule.right))) + nothing +end + +function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int) + pat_to_inst = direction == 1 ? rule.right : rule.left + push!(MERGES_BUF[], (id, instantiate_enode!(bindings, g, pat_to_inst))) + nothing end -function (rule::SymbolicRule)(g::EGraph, match::Match; simterm=similarterm) - rinst = instantiate(g, match.pat_to_inst, match.sub, rule; simterm=simterm) - rc, node = addexpr!(g, rinst) - merge!(g, match.id, rc.id) - return nothing + +function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClassId, direction::Int) + pat_to_inst = direction == 1 ? rule.right : rule.left + other_id = instantiate_enode!(bindings, g, pat_to_inst) + + if find(g, id) == find(g, other_id) + @log "Contradiction!" rule + return :contradiction + end + nothing end +""" +Instantiate argument for dynamic rule application in e-graph +""" +function instantiate_actual_param!(bindings::Bindings, g::EGraph, i) + ecid, literal_position = bindings[i] + ecid <= 0 && error("unbound pattern variable $pat in rule $rule") + if literal_position > 0 + eclass = g[ecid] + @assert eclass[literal_position] isa ENodeLiteral + return eclass[literal_position].value + end + return eclass +end -function (rule::DynamicRule)(g::EGraph, match::Match; simterm=similarterm) - f = rule.rhs_fun - actual_params = [instantiate(g, PatVar(v, i, alwaystrue), match.sub, rule) for (i, v) in enumerate(rule.patvars)] - r = f(g[match.id], match.sub, g, actual_params...) - isnothing(r) && return nothing - rc, node = addexpr!(g, r) - merge!(g, match.id, rc.id) - return nothing +function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClassId, direction::Int) + f = rule.rhs_fun + r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) + isnothing(r) && return nothing + rcid = addexpr!(g, r) + push!(MERGES_BUF[], (id, rcid)) + return nothing end -function eqsat_apply!(g::EGraph, matches, rep::Report, params::SaturationParams) - i = 0 - # println.(matches) - for match ∈ matches - i += 1 - # if params.eclasslimit > 0 && g.numclasses > params.eclasslimit - # @log "E-GRAPH SIZEOUT" - # rep.reason = :eclasslimit - # return - # end +function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams) + i = 0 + @assert isempty(MERGES_BUF[]) - if reached(g, params.goal) - @log "Goal reached" - rep.reason = :goalreached - return - end + lock(BUFFER_LOCK) do + while !isempty(BUFFER[]) + if reached(g, params.goal) + @log "Goal reached" + rep.reason = :goalreached + return + end + bindings = popfirst!(BUFFER[]) + rule_idx, id = bindings[0] + direction = sign(rule_idx) + rule_idx = abs(rule_idx) + rule = theory[rule_idx] - rule = match.rule - # println("applying $rule") - halt_reason = rule(g, match; simterm=params.simterm) - if (halt_reason !== nothing) - rep.reason = halt_reason - return - end + halt_reason = lock(MERGES_BUF_LOCK) do + apply_rule!(bindings, g, rule, id, direction) + end - # println(rule) - # println(sub) - # println(l); println(r) - # display(egraph.classes); println() + if !isnothing(halt_reason) + rep.reason = halt_reason + return + end + end + end + lock(MERGES_BUF_LOCK) do + while !isempty(MERGES_BUF[]) + (l, r) = popfirst!(MERGES_BUF[]) + merge!(g, l, r) end + end end + + import ..@log """ Core algorithm of the library: the equality saturation step. """ -function eqsat_step!(g::EGraph, theory::Vector{<:AbstractRule}, curr_iter, - scheduler::AbstractScheduler, match_hist::MatchesBuf, - params::SaturationParams, report) - - instcache = Dict{AbstractRule, Dict{Sub, EClassId}}() +function eqsat_step!( + g::EGraph, + theory::Vector{<:AbstractRule}, + curr_iter, + scheduler::AbstractScheduler, + params::SaturationParams, + report, +) - setiter!(scheduler, curr_iter) + setiter!(scheduler, curr_iter) - matches = @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report; threaded=params.threaded) + @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report) - # matches = setdiff!(matches, match_hist) + @timeit report.to "Apply" eqsat_apply!(g, theory, report, params) - @timeit report.to "Apply" eqsat_apply!(g, matches, report, params) - + if report.reason === nothing && cansaturate(scheduler) && isempty(g.dirty) + report.reason = :saturated + end + @timeit report.to "Rebuild" rebuild!(g) - # union!(match_hist, matches) - - if report.reason === nothing && cansaturate(scheduler) && isempty(g.dirty) - report.reason = :saturated - end - @timeit report.to "Rebuild" rebuild!(g) - - return report, g + return report end """ Given an [`EGraph`](@ref) and a collection of rewrite rules, execute the equality saturation algorithm. """ -function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params=SaturationParams()) - curr_iter = 0 +function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = SaturationParams()) + curr_iter = 0 - sched = params.scheduler(g, theory, params.schedulerparams...) - match_hist = MatchesBuf() - report = Report(g) + sched = params.scheduler(g, theory, params.schedulerparams...) + report = SaturationReport(g) - start_time = Dates.now().instant + start_time = time_ns() - !params.timer && disable_timer!(report.to) - timelimit = params.timelimit > Second(0) - + !params.timer && disable_timer!(report.to) + timelimit = params.timelimit > 0 - while true - curr_iter+=1 + while true + curr_iter += 1 - params.printiter && @info("iteration ", curr_iter) + params.printiter && @info("iteration ", curr_iter) - report, egraph = eqsat_step!(g, theory, curr_iter, sched, match_hist, params, report) + report = eqsat_step!(g, theory, curr_iter, sched, params, report) - elapsed = Dates.now().instant - start_time + elapsed = time_ns() - start_time - if timelimit && params.timelimit <= elapsed - report.reason = :timelimit - break - end + if timelimit && params.timelimit <= elapsed + report.reason = :timelimit + break + end - # report.reason == :matchlimit && break - if !(report.reason isa Nothing) - break - end + if !(report.reason isa Nothing) + break + end - if curr_iter >= params.timeout - report.reason = :timeout - break - end + if curr_iter >= params.timeout + report.reason = :timeout + break + end - if params.eclasslimit > 0 && g.numclasses > params.eclasslimit - # println(params.eclasslimit) - report.reason = :eclasslimit - break - end + if params.eclasslimit > 0 && g.numclasses > params.eclasslimit + report.reason = :eclasslimit + break + end - if reached(g, params.goal) - report.reason = :goalreached - break - end + if reached(g, params.goal) + report.reason = :goalreached + break end - report.iterations = curr_iter - @log report + end + report.iterations = curr_iter + @log report - return report + return report end -function areequal(theory::Vector, exprs...; params=SaturationParams()) - g = EGraph(exprs[1]) - areequal(g, theory, exprs...; params=params) +function areequal(theory::Vector, exprs...; params = SaturationParams()) + g = EGraph(exprs[1]) + areequal(g, theory, exprs...; params = params) end -function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params=SaturationParams()) - @log "Checking equality for " exprs - if length(exprs) == 1; return true end - # rebuild!(G) - - @log "starting saturation" +function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams()) + @log "Checking equality for " exprs + if length(exprs) == 1 + return true + end + # rebuild!(G) - n = length(exprs) - ids = Vector{EClassId}(undef, n) - nodes = Vector{AbstractENode}(undef, n) - for i ∈ 1:n - ec, node = addexpr!(g, exprs[i]) - ids[i] = ec.id - nodes[i] = node - end + @log "starting saturation" - goal = EqualityGoal(collect(exprs), ids) - - # alleq = () -> (all(x -> in_same_set(G.uf, ids[1], x), ids[2:end])) + n = length(exprs) + ids = map(Base.Fix1(addexpr!, g), collect(exprs)) + goal = EqualityGoal(collect(exprs), ids) - params.goal = goal - # params.stopwhen = alleq + params.goal = goal - report = saturate!(g, t, params) + report = saturate!(g, t, params) - # display(g.classes); println() - if !(report.reason === :saturated) && !reached(g, goal) - return missing # failed to prove - end - return reached(g, goal) + if !(report.reason === :saturated) && !reached(g, goal) + return missing # failed to prove + end + return reached(g, goal) end macro areequal(theory, exprs...) - esc(:(areequal($theory, $exprs...))) + esc(:(areequal($theory, $exprs...))) end macro areequalg(G, theory, exprs...) - esc(:(areequal($G, $theory, $exprs...))) + esc(:(areequal($G, $theory, $exprs...))) end diff --git a/src/Library.jl b/src/Library.jl index 9a3dce83..6a3f7f18 100644 --- a/src/Library.jl +++ b/src/Library.jl @@ -9,133 +9,91 @@ module Library using Metatheory.Patterns using Metatheory.Rules -macro associativity(op) - quote - [ - (@left_associative $op), - (@right_associative $op) - ] - end -end -macro monoid(op, id) - quote - [ - (@left_associative(op)), - (@right_associative(op)), - (@identity_left(op, id)), - (@identity_right(op, id)) - ] - end +macro commutativity(op) + RewriteRule(PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatTerm(:call, op, [PatVar(:b), PatVar(:a)])) end -macro commutative_monoid(op, id) - quote - [ - (@commutativity $op), - (@left_associative $op), - (@right_associative $op), - (@identity_left $op $id) - ] - end +macro right_associative(op) + RewriteRule( + PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]), + PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]), + ) end - -# constructs a semantic theory about a an abelian group -# The definition of a group does not require that a ⋅ b = b ⋅ a -# for all elements a and b in G. If this additional condition holds, -# then the operation is said to be commutative, and the group is called an abelian group. -macro commutative_group(op, id, invop) - # @assert Base.isbinaryoperator(op) - # @assert Base.isunaryoperator(invop) - quote - (@commutative_monoid $op $id) ∪ [@inverse_right $op $id $invop] - end +macro left_associative(op) + RewriteRule( + PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]), + PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]), + ) end -macro distrib(outop, inop) - quote - [ - (@distrib_left $outop $inop), - (@distrib_right $outop $inop), - ] - end -end +macro identity_left(op, id) + RewriteRule(PatTerm(:call, op, [id, PatVar(:a)]), PatVar(:a)) +end -macro commutativity(op) - RewriteRule( - PatTerm(:call, op, [PatVar(:a), PatVar(:b)], __module__), - PatTerm(:call, op, [PatVar(:b), PatVar(:a)], __module__)) -end +macro identity_right(op, id) + RewriteRule(PatTerm(:call, op, [PatVar(:a), id]), PatVar(:a)) +end -macro right_associative(op) - RewriteRule( - PatTerm(:call, op, [PatVar(:a), - PatTerm(:call, op, [PatVar(:b), PatVar(:c)], __module__)], __module__), - PatTerm(:call, op, [ - PatTerm(:call, op, [PatVar(:a), PatVar(:b)], __module__), - PatVar(:c), - ], __module__)) +macro inverse_left(op, id, invop) + RewriteRule(PatTerm(:call, op, [PatTerm(:call, invop, [PatVar(:a)]), PatVar(:a)]), id) end -macro left_associative(op) - RewriteRule( - PatTerm(:call, op, [ - PatTerm(:call, op, [PatVar(:a), PatVar(:b)], __module__), - PatVar(:c), - ], __module__), - PatTerm(:call, op, [PatVar(:a), - PatTerm(:call, op, [PatVar(:b), PatVar(:c)], __module__)], __module__)) +macro inverse_right(op, id, invop) + RewriteRule(PatTerm(:call, op, [PatVar(:a), PatTerm(:call, invop, [PatVar(:a)])]), id) end -macro identity_left(op, id) - RewriteRule(PatTerm(:call, op, [id, PatVar(:a)], __module__), PatVar(:a)) +macro associativity(op) + esc(quote + [(@left_associative $op), (@right_associative $op)] + end) end -macro identity_right(op, id) - RewriteRule(PatTerm(:call, op, [PatVar(:a), id], __module__), PatVar(:a)) +macro monoid(op, id) + esc(quote + [(@left_associative($op)), (@right_associative($op)), (@identity_left($op, $id)), (@identity_right($op, $id))] + end) end -macro inverse_left(op, id, invop) - RewriteRule(PatTerm(:call, op, [ - PatTerm(:call, invop, [PatVar(:a)], __module__), PatVar(:a)], __module__), id) +macro commutative_monoid(op, id) + esc(quote + [(@commutativity $op), (@left_associative $op), (@right_associative $op), (@identity_left $op $id)] + end) end -macro inverse_right(op, id, invop) - RewriteRule(PatTerm(:call, op, [ - PatVar(:a), - PatTerm(:call, invop, [PatVar(:a)], __module__)], __module__), id) + +# constructs a semantic theory about a an abelian group +# The definition of a group does not require that a ⋅ b = b ⋅ a +# for all elements a and b in G. If this additional condition holds, +# then the operation is said to be commutative, and the group is called an abelian group. +macro commutative_group(op, id, invop) + # @assert Base.isbinaryoperator(op) + # @assert Base.isunaryoperator(invop) + esc(quote + (@commutative_monoid $op $id) ∪ [@inverse_right $op $id $invop] + end) +end + +macro distrib(outop, inop) + esc(quote + [(@distrib_left $outop $inop), (@distrib_right $outop $inop)] + end) end + # distributivity of two operations # example: `@distrib (⋅) (⊕)` macro distrib_left(outop, inop) - EqualityRule( - # left - PatTerm(:call, outop, [ - PatVar(:a), - PatTerm(:call, inop, [PatVar(:b), PatVar(:c)], __module__) - ], __module__), - # right - PatTerm(:call, inop, [ - PatTerm(:call, outop, [PatVar(:a), PatVar(:b)], __module__), - PatTerm(:call, outop, [PatVar(:a), PatVar(:c)], __module__), - ], __module__)) - + esc(quote + @rule a b c ($outop)(a, $(inop)(b, c)) == $(inop)($(outop)(a, b), $(outop)(a, c)) + end) end macro distrib_right(outop, inop) - EqualityRule( - # left - PatTerm(:call, outop, [ - PatTerm(:call, inop, [PatVar(:a), PatVar(:b)], __module__), - PatVar(:c) - ], __module__), - # right - PatTerm(:call, inop, [ - PatTerm(:call, outop, [PatVar(:a), PatVar(:c)], __module__), - PatTerm(:call, outop, [PatVar(:b), PatVar(:c)], __module__), - ], __module__)) + esc(quote + @rule a b c ($outop)($(inop)(a, b), c) == $(inop)($(outop)(a, c), $(outop)(b, c)) + end) end @@ -150,5 +108,9 @@ export @distrib export @monoid export @commutative_monoid export @commutative_group +export @left_associative +export @right_associative +export @inverse_left +export @inverse_right end diff --git a/src/Metatheory.jl b/src/Metatheory.jl index afaaed25..c3a9ad33 100644 --- a/src/Metatheory.jl +++ b/src/Metatheory.jl @@ -1,24 +1,49 @@ module Metatheory +using DataStructures + +import Base.ImmutableDict + +const Bindings = ImmutableDict{Int,Tuple{Int,Int}} +const DEFAULT_BUFFER_SIZE = 1048576 +const BUFFER = Ref(CircularDeque{Bindings}(DEFAULT_BUFFER_SIZE)) +const BUFFER_LOCK = ReentrantLock() +const MERGES_BUF = Ref(CircularDeque{Tuple{Int,Int}}(DEFAULT_BUFFER_SIZE)) +const MERGES_BUF_LOCK = ReentrantLock() + +function resetbuffers!(bufsize) + BUFFER[] = CircularDeque{Bindings}(bufsize) + MERGES_BUF[] = CircularDeque{Tuple{Int,Int}}(bufsize) +end + +function __init__() + println(Threads.nthreads()) + resetbuffers!(DEFAULT_BUFFER_SIZE) +end + using Base.Meta using Reexport using TermInterface macro log(args...) - quote haskey(ENV, "MT_DEBUG") && @info($(args...)) end |> esc + quote + haskey(ENV, "MT_DEBUG") && @info($(args...)) + end |> esc end @inline alwaystrue(x) = true +function lookup_pat end + include("docstrings.jl") include("utils.jl") -export @timer +export @timer export @iftimer export @timerewrite export @matchable include("Patterns.jl") -@reexport using .Patterns +@reexport using .Patterns include("ematch_compiler.jl") @reexport using .EMatchCompiler @@ -39,12 +64,12 @@ include("Rewriters.jl") using .Rewriters export Rewriters -function rewrite(expr, theory; order=:outer) - if order == :inner - Fixpoint(Prewalk(Fixpoint(Chain(theory))))(expr) - elseif order == :outer - Fixpoint(Postwalk(Fixpoint(Chain(theory))))(expr) - end +function rewrite(expr, theory; order = :outer) + if order == :inner + Fixpoint(Prewalk(Fixpoint(Chain(theory))))(expr) + elseif order == :outer + Fixpoint(Postwalk(Fixpoint(Chain(theory))))(expr) + end end export rewrite diff --git a/src/Patterns.jl b/src/Patterns.jl index 2af0075f..d2d653e7 100644 --- a/src/Patterns.jl +++ b/src/Patterns.jl @@ -1,4 +1,4 @@ -module Patterns +module Patterns using Metatheory: binarize, cleanast, alwaystrue using AutoHashEquals @@ -12,11 +12,10 @@ abstract type AbstractPat end struct UnsupportedPatternException <: Exception - p::AbstractPat + p::AbstractPat end -Base.showerror(io::IO, e::UnsupportedPatternException) = - print(io, "Pattern ", e.p, " is unsupported in this context") +Base.showerror(io::IO, e::UnsupportedPatternException) = print(io, "Pattern ", e.p, " is unsupported in this context") Base.isequal(a::AbstractPat, b::AbstractPat) = false @@ -43,17 +42,18 @@ boolean value. Such a slot will be considered a match only if `f` returns true. type assertion. Type assertions on a `PatVar`, will match if and only if the type of the matched term for the pattern variable is a subtype of `T`. """ -mutable struct PatVar{P} <: AbstractPat - name::Symbol - idx::Int - predicate::P +mutable struct PatVar{P} <: AbstractPat + name::Symbol + idx::Int + predicate::P + predicate_code end function Base.isequal(a::PatVar, b::PatVar) - # (a.name == b.name) - a.idx == b.idx + # (a.name == b.name) + a.idx == b.idx end -PatVar(var) = PatVar(var, -1, alwaystrue) -PatVar(var, i) = PatVar(var, i, alwaystrue) +PatVar(var) = PatVar(var, -1, alwaystrue, nothing) +PatVar(var, i) = PatVar(var, i, alwaystrue, nothing) """ If you want to match a variable number of subexpressions at once, you will need @@ -63,58 +63,38 @@ You can attach a predicate `g` to a segment variable. In the case of segment var expressions and must return a boolean value. """ mutable struct PatSegment{P} <: AbstractPat - name::Symbol - idx::Int - predicate::P - # hash::Ref{UInt} + name::Symbol + idx::Int + predicate::P + predicate_code end -PatSegment(v) = PatSegment(v, -1, alwaystrue) -PatSegment(v, i) = PatSegment(v, i, alwaystrues) -# PatSegment(v, i, p) = PatSegment{typeof(p)}(v, i, p), Ref{UInt}(0)) +PatSegment(v) = PatSegment(v, -1, alwaystrue, nothing) +PatSegment(v, i) = PatSegment(v, i, alwaystrue, nothing) -# function Base.hash(t::PatSegment, salt::UInt) -# !iszero(salt) && return hash(hash(t, zero(UInt)), salt) -# h = t.hash[] -# !iszero(h) && return h -# h′ = hash(t.name, hash(t.predicate, salt)) -# t.hash[] = h′ -# return h′ -# end - """ Term patterns will match on terms of the same `arity` and with the same function symbol `operation` and expression head `exprhead`. """ struct PatTerm <: AbstractPat - exprhead::Any - operation::Any - args::Vector - mod::Module # useful to match against function head symbols and function objs at the same time - PatTerm(eh, op, args, mod) = new(eh, op, args, mod) #Ref{UInt}(0)) + exprhead::Any + operation::Any + args::Vector + PatTerm(eh, op, args) = new(eh, op, args) #Ref{UInt}(0)) end -TermInterface.istree(::Type{PatTerm}) = true +TermInterface.istree(::PatTerm) = true TermInterface.exprhead(e::PatTerm) = e.exprhead TermInterface.operation(p::PatTerm) = p.operation TermInterface.arguments(p::PatTerm) = p.args TermInterface.arity(p::PatTerm) = length(arguments(p)) -TermInterface.metadata(p::PatTerm) = p.mod +TermInterface.metadata(p::PatTerm) = nothing -function TermInterface.similarterm(x::Type{PatTerm}, head, args, symtype=nothing; metadata=@__MODULE__, exprhead=:call) - PatTerm(exprhead, head, args, metadata) +function TermInterface.similarterm(x::PatTerm, head, args, symtype = nothing; metadata = nothing, exprhead = :call) + PatTerm(exprhead, head, args) end -# function Base.hash(t::PatTerm, salt::UInt) -# !iszero(salt) && return hash(hash(t, zero(UInt)), salt) -# h = t.hash[] -# !iszero(h) && return h -# h′ = hash(t.exprhead, hash(t.operation, hash(t.args, salt))) -# t.hash[] = h′ -# return h′ -# end - isground(p::PatTerm) = all(isground, p.args) @@ -127,9 +107,8 @@ Collects pattern variables appearing in a pattern into a vector of symbols """ patvars(p::PatVar, s) = push!(s, p.name) patvars(p::PatSegment, s) = push!(s, p.name) -patvars(p::PatTerm, s) = (patvars(operation(p), s); foreach(x -> patvars(x, s), arguments(p)) ; s) +patvars(p::PatTerm, s) = (patvars(operation(p), s); foreach(x -> patvars(x, s), arguments(p)); s) patvars(x, s) = s - patvars(p) = unique!(patvars(p, Symbol[])) @@ -138,48 +117,26 @@ patvars(p) = unique!(patvars(p, Symbol[])) # ============================================== function setdebrujin!(p::Union{PatVar,PatSegment}, pvars) - p.idx = findfirst((==)(p.name), pvars) + p.idx = findfirst((==)(p.name), pvars) end # literal case setdebrujin!(p, pvars) = nothing -function setdebrujin!(p::PatTerm, pvars) - setdebrujin!(operation(p), pvars) - foreach(x -> setdebrujin!(x, pvars), p.args) +function setdebrujin!(p::PatTerm, pvars) + setdebrujin!(operation(p), pvars) + foreach(x -> setdebrujin!(x, pvars), p.args) end -#TODO ADD ORIGINAL CODE OF PREDICATE TO PATVAR ? -function to_expr(x::PatVar) - if x.predicate == alwaystrue - Expr(:call, :~, x.name) - else - Expr(:call, :~, Expr(:(::), x.name, x.predicate)) - end -end - -to_expr(x::Any) = x -function to_expr(x::PatSegment) - Expr(:..., x.predicate == alwaystrue ? Expr(:call, :~, x.name) : - Expr(:call, :~, Expr(:(::), x.name, x.predicate)) - ) -end - -to_expr(x::PatSegment{typeof(alwaystrue)}) = - Expr(:..., Expr(:call, :~, x.name)) - -to_expr(x::PatSegment{T}) where {T <: Function} = - Expr(:..., Expr(:call, :~, Expr(:(::), x.name, nameof(T)))) - -to_expr(x::PatSegment{<:Type{T}}) where T = - Expr(:..., Expr(:call, :~, Expr(:(::), x.name, T))) - -function to_expr(x::PatTerm) - pl = operation(x) - similarterm(Expr, pl, map(to_expr, arguments(x)); exprhead=exprhead(x)) -end +to_expr(x) = x +to_expr(x::PatVar{T}) where {T} = Expr(:call, :~, Expr(:(::), x.name, x.predicate_code)) +to_expr(x::PatSegment{T}) where {T<:Function} = Expr(:..., Expr(:call, :~, Expr(:(::), x.name, x.predicate_code))) +to_expr(x::PatVar{typeof(alwaystrue)}) = Expr(:call, :~, x.name) +to_expr(x::PatSegment{typeof(alwaystrue)}) = Expr(:..., Expr(:call, :~, x.name)) +to_expr(x::PatTerm) = similarterm(Expr(:call, :x), operation(x), map(to_expr, arguments(x)); exprhead = exprhead(x)) +Base.show(io::IO, pat::AbstractPat) = print(io, to_expr(pat)) # include("rules/patterns.jl") diff --git a/src/Rewriters.jl b/src/Rewriters.jl index 1c811326..94d1ab38 100644 --- a/src/Rewriters.jl +++ b/src/Rewriters.jl @@ -37,7 +37,7 @@ export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, Pa # Cache of printed rules to speed up @timer const repr_cache = IdDict() -cached_repr(x) = Base.get!(()->repr(x), repr_cache, x) +cached_repr(x) = Base.get!(() -> repr(x), repr_cache, x) struct Empty end @@ -46,82 +46,82 @@ struct Empty end instrument(x, f) = f(x) instrument(x::Empty, f) = x -struct IfElse{F, A, B} - cond::F - yes::A - no::B +struct IfElse{F,A,B} + cond::F + yes::A + no::B end instrument(x::IfElse, f) = IfElse(x.cond, instrument(x.yes, f), instrument(x.no, f)) function (rw::IfElse)(x) - rw.cond(x) ? rw.yes(x) : rw.no(x) + rw.cond(x) ? rw.yes(x) : rw.no(x) end If(f, x) = IfElse(f, x, Empty()) struct Chain - rws + rws end function (rw::Chain)(x) - for f in rw.rws - y = @timer cached_repr(f) f(x) - if y !== nothing - x = y - end + for f in rw.rws + y = @timer cached_repr(f) f(x) + if y !== nothing + x = y end - return x + end + return x end -instrument(c::Chain, f) = Chain(map(x->instrument(x,f), c.rws)) +instrument(c::Chain, f) = Chain(map(x -> instrument(x, f), c.rws)) struct RestartedChain{Cs} - rws::Cs + rws::Cs end -instrument(c::RestartedChain, f) = RestartedChain(map(x->instrument(x,f), c.rws)) +instrument(c::RestartedChain, f) = RestartedChain(map(x -> instrument(x, f), c.rws)) function (rw::RestartedChain)(x) - for f in rw.rws - y = @timer cached_repr(f) f(x) - if y !== nothing - return Chain(rw.rws)(y) - end + for f in rw.rws + y = @timer cached_repr(f) f(x) + if y !== nothing + return Chain(rw.rws)(y) end - return x + end + return x end -@generated function (rw::RestartedChain{<:NTuple{N,Any}})(x) where N - quote - Base.@nexprs $N i->begin - let f = rw.rws[i] - y = @timer cached_repr(repr(f)) f(x) - if y !== nothing - return Chain(rw.rws)(y) - end - end +@generated function (rw::RestartedChain{<:NTuple{N,Any}})(x) where {N} + quote + Base.@nexprs $N i -> begin + let f = rw.rws[i] + y = @timer cached_repr(repr(f)) f(x) + if y !== nothing + return Chain(rw.rws)(y) end - return x + end end + return x + end end struct Fixpoint{C} - rw::C + rw::C end instrument(x::Fixpoint, f) = Fixpoint(instrument(x.rw, f)) function (rw::Fixpoint)(x) - f = rw.rw + f = rw.rw + y = @timer cached_repr(f) f(x) + while x !== y && !isequal(x, y) + y === nothing && return x + x = y y = @timer cached_repr(f) f(x) - while x !== y && !isequal(x, y) - y === nothing && return x - x = y - y = @timer cached_repr(f) f(x) - end - return x + end + return x end """ @@ -134,112 +134,108 @@ if the repeated application of `rw` produces results `a, b, c, d, b` in order, `FixpointNoCycle` stops because `b` has been already produced. """ struct FixpointNoCycle{C} - rw::C - hist::Vector{UInt64} # vector of hashes for history + rw::C + hist::Vector{UInt64} # vector of hashes for history end instrument(x::FixpointNoCycle, f) = Fixpoint(instrument(x.rw, f)) function (rw::FixpointNoCycle)(x) - f = rw.rw - push!(rw.hist, hash(x)) - y = @timer cached_repr(f) f(x) - while x !== y && hash(x) ∉ hist - if y === nothing - empty!(rw.hist) - return x - end - push!(rw.hist, y) - x = y - y = @timer cached_repr(f) f(x) + f = rw.rw + push!(rw.hist, hash(x)) + y = @timer cached_repr(f) f(x) + while x !== y && hash(x) ∉ rw.hist + if y === nothing + empty!(rw.hist) + return x end - empty!(rw.hist) - return x + push!(rw.hist, y) + x = y + y = @timer cached_repr(f) f(x) + end + empty!(rw.hist) + return x end -struct Walk{ord, C, F, threaded} - rw::C - thread_cutoff::Int - similarterm::F +struct Walk{ord,C,F,threaded} + rw::C + thread_cutoff::Int + similarterm::F end -function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded} - irw = instrument(x.rw, f) - Walk{ord, typeof(irw), typeof(x.similarterm), threaded}(irw, - x.thread_cutoff, - x.similarterm) +function instrument(x::Walk{ord,C,F,threaded}, f) where {ord,C,F,threaded} + irw = instrument(x.rw, f) + Walk{ord,typeof(irw),typeof(x.similarterm),threaded}(irw, x.thread_cutoff, x.similarterm) end using .Threads -function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm) - Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm) +function Postwalk(rw; threaded::Bool = false, thread_cutoff = 100, similarterm = similarterm) + Walk{:post,typeof(rw),typeof(similarterm),threaded}(rw, thread_cutoff, similarterm) end -function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm) - Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm) +function Prewalk(rw; threaded::Bool = false, thread_cutoff = 100, similarterm = similarterm) + Walk{:pre,typeof(rw),typeof(similarterm),threaded}(rw, thread_cutoff, similarterm) end struct PassThrough{C} - rw::C + rw::C end instrument(x::PassThrough, f) = PassThrough(instrument(x.rw, f)) -(p::PassThrough)(x) = (y=p.rw(x); y === nothing ? x : y) +(p::PassThrough)(x) = (y = p.rw(x); y === nothing ? x : y) passthrough(x, default) = x === nothing ? default : x -function (p::Walk{ord, C, F, false})(x) where {ord, C, F} - @assert ord === :pre || ord === :post +function (p::Walk{ord,C,F,false})(x) where {ord,C,F} + @assert ord === :pre || ord === :post + if istree(x) + if ord === :pre + x = p.rw(x) + end if istree(x) - if ord === :pre - x = p.rw(x) - end - if istree(x) - x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)); exprhead=exprhead(x)) - end - return ord === :post ? p.rw(x) : x - else - return p.rw(x) + x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)); exprhead = exprhead(x)) end + return ord === :post ? p.rw(x) : x + else + return p.rw(x) + end end -function (p::Walk{ord, C, F, true})(x) where {ord, C, F} - @assert ord === :pre || ord === :post +function (p::Walk{ord,C,F,true})(x) where {ord,C,F} + @assert ord === :pre || ord === :post + if istree(x) + if ord === :pre + x = p.rw(x) + end if istree(x) - if ord === :pre - x = p.rw(x) - end - if istree(x) - _args = map(arguments(x)) do arg - if node_count(arg) > p.thread_cutoff - Threads.@spawn p(arg) - else - p(arg) - end - end - args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.similarterm(x, operation(x), args) + _args = map(arguments(x)) do arg + if node_count(arg) > p.thread_cutoff + Threads.@spawn p(arg) + else + p(arg) end - return ord === :post ? p.rw(t) : t - else - return p.rw(x) + end + args = map((t, a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) + t = p.similarterm(x, operation(x), args; exprhead = exprhead(x)) end + return ord === :post ? p.rw(t) : t + else + return p.rw(x) + end end function instrument_io(x) - function io_instrumenter(r) - function (args...) - println("Rule: ", r) - println("Input: ", args) - res = r(args...) - println("Output: ", res) - res - end + function io_instrumenter(r) + function (args...) + println("Rule: ", r) + println("Input: ", args) + res = r(args...) + println("Output: ", res) + res end + end - instrument(x, io_instrumenter) + instrument(x, io_instrumenter) end end # end module - - diff --git a/src/Rules.jl b/src/Rules.jl index 07022889..ddb6cdd1 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -1,14 +1,13 @@ module Rules using TermInterface -using Parameters using AutoHashEquals using Metatheory.EMatchCompiler using Metatheory.Patterns using Metatheory.Patterns: to_expr using Metatheory: cleanast, binarize, matcher, instantiate -const EMPTY_DICT = Base.ImmutableDict{Int, Any}() +const EMPTY_DICT = Base.ImmutableDict{Int,Any}() abstract type AbstractRule end # Must override @@ -19,19 +18,19 @@ abstract type SymbolicRule <: AbstractRule end abstract type BidirRule <: SymbolicRule end struct RuleRewriteError - rule - expr + rule + expr end getdepth(::Any) = typemax(Int) -showraw(io, t) = Base.show(IOContext(io, :simplify=>false), t) +showraw(io, t) = Base.show(IOContext(io, :simplify => false), t) showraw(t) = showraw(stdout, t) @noinline function Base.showerror(io::IO, err::RuleRewriteError) - msg = "Failed to apply rule $(err.rule) on expression " - msg *= sprint(io->showraw(io, err.expr)) - print(io, msg) + msg = "Failed to apply rule $(err.rule) on expression " + msg *= sprint(io -> showraw(io, err.expr)) + print(io, msg) end @@ -50,42 +49,36 @@ variables. @rule ~a * ~b --> ~b * ~a ``` """ -@auto_hash_equals struct RewriteRule <: SymbolicRule - expr # rule pattern stored for pretty printing - left - right - matcher - patvars::Vector{Symbol} - ematch_program::Program +@auto_hash_equals struct RewriteRule <: SymbolicRule + left + right + matcher + patvars::Vector{Symbol} + ematcher! end Base.isequal(a::RewriteRule, b::RewriteRule) = (a.left == b.left) && (a.right == b.right) function RewriteRule(l, r) - ex = :($(to_expr(l)) --> $(to_expr(r))) - RewriteRule(ex, l, r) + pvars = patvars(l) ∪ patvars(r) + # sort!(pvars) + setdebrujin!(l, pvars) + setdebrujin!(r, pvars) + RewriteRule(l, r, matcher(l), pvars, ematcher_yield(l, length(pvars))) end -function RewriteRule(ex, l, r) - pvars = patvars(l) ∪ patvars(r) - # sort!(pvars) - setdebrujin!(l, pvars) - setdebrujin!(r, pvars) - RewriteRule(ex, l, r, matcher(l), pvars, compile_pat(l)) -end - -Base.show(io::IO, r::RewriteRule) = print(io, r.expr) +Base.show(io::IO, r::RewriteRule) = print(io, :($(r.left) --> $(r.right))) function (r::RewriteRule)(term) - # n == 1 means that exactly one term of the input (term,) was matched - success(bindings, n) = n == 1 ? instantiate(term, r.right, bindings) : nothing - - try - r.matcher(success, (term,), EMPTY_DICT) - catch err - throw(RuleRewriteError(r, term)) - end + # n == 1 means that exactly one term of the input (term,) was matched + success(bindings, n) = n == 1 ? instantiate(term, r.right, bindings) : nothing + + try + r.matcher(success, (term,), EMPTY_DICT) + catch err + throw(RuleRewriteError(r, term)) + end end # ============================================================ @@ -101,37 +94,30 @@ with the EGraphs backend. @rule ~a * ~b == ~b * ~a ``` """ -@auto_hash_equals struct EqualityRule <: BidirRule - expr # rule pattern stored for pretty printing - left - right - patvars::Vector{Symbol} - ematch_program_l::Program - ematch_program_r::Program -end - -function EqualityRule(ex, l, r) - pvars = patvars(l) ∪ patvars(r) - extravars = setdiff(pvars, patvars(l) ∩ patvars(r)) - if !isempty(extravars) - error("unbound pattern variables $extravars when creating bidirectional rule") - end - setdebrujin!(l, pvars) - setdebrujin!(r, pvars) - progl = compile_pat(l) - progr = compile_pat(r) - EqualityRule(ex, l, r, pvars, progl, progr) +@auto_hash_equals struct EqualityRule <: BidirRule + left + right + patvars::Vector{Symbol} + ematcher! end function EqualityRule(l, r) - ex = :($(to_expr(l)) --> $(to_expr(r))) - EqualityRule(ex, l, r) + pvars = patvars(l) ∪ patvars(r) + extravars = setdiff(pvars, patvars(l) ∩ patvars(r)) + if !isempty(extravars) + error("unbound pattern variables $extravars when creating bidirectional rule") + end + setdebrujin!(l, pvars) + setdebrujin!(r, pvars) + + EqualityRule(l, r, pvars, ematcher_yield_bidir(l,r, length(pvars))) end -Base.show(io::IO, r::EqualityRule) = print(io, r.expr) + +Base.show(io::IO, r::EqualityRule) = print(io, :($(r.left) == $(r.right))) function (r::EqualityRule)(x) - throw(RuleRewriteError(r, x)) + throw(RuleRewriteError(r, x)) end @@ -145,40 +131,30 @@ backend. If two terms, corresponding to the left and right hand side of an *anti-rule* are found in an [`EGraph`], saturation is halted immediately. ```julia -¬a ≠ a +!a ≠ a ``` """ -@auto_hash_equals struct UnequalRule <: BidirRule - expr # rule pattern stored for pretty printing - left - right - patvars::Vector{Symbol} - ematch_program_l::Program - ematch_program_r::Program +@auto_hash_equals struct UnequalRule <: BidirRule + left + right + patvars::Vector{Symbol} + ematcher! end - function UnequalRule(l, r) - ex = :($(to_expr(l)) --> $(to_expr(r))) - UnequalRule(ex, l, r) -end - -function UnequalRule(ex, l, r) - pvars = patvars(l) ∪ patvars(r) - extravars = setdiff(pvars, patvars(l) ∩ patvars(r)) - if !isempty(extravars) - error("unbound pattern variables $extravars when creating bidirectional rule") - end -# sort!(pvars) - setdebrujin!(l, pvars) - setdebrujin!(r, pvars) - progl = compile_pat(l) - progr = compile_pat(r) - UnequalRule(ex, l, r, pvars, progl, progr) + pvars = patvars(l) ∪ patvars(r) + extravars = setdiff(pvars, patvars(l) ∩ patvars(r)) + if !isempty(extravars) + error("unbound pattern variables $extravars when creating bidirectional rule") + end + # sort!(pvars) + setdebrujin!(l, pvars) + setdebrujin!(r, pvars) + UnequalRule(l, r, pvars, ematcher_yield_bidir(l,r, length(pvars))) end -Base.show(io::IO, r::UnequalRule) = print(io, r.expr) +Base.show(io::IO, r::UnequalRule) = print(io, :($(r.left) ≠ $(r.right))) # ============================================================ # DynamicRule @@ -198,41 +174,39 @@ Dynamic rule ``` """ @auto_hash_equals struct DynamicRule <: AbstractRule - expr # rule pattern stored for pretty printing - left - rhs_fun::Function - matcher - patvars::Vector{Symbol} # useful set of pattern variables - ematch_program::Program + left + rhs_fun::Function + rhs_code + matcher + patvars::Vector{Symbol} # useful set of pattern variables + ematcher! end -function DynamicRule(l, r) - ex = :($(to_expr(l)) => $(to_expr(r))) - DynamicRule(ex, l, r) -end - -function DynamicRule(ex, l, r::Function) - pvars = patvars(l) - setdebrujin!(l, pvars) +function DynamicRule(l, r::Function, rhs_code = nothing) + pvars = patvars(l) + setdebrujin!(l, pvars) + isnothing(rhs_code) && (rhs_code = repr(rhs_code)) - DynamicRule(ex, l, r, matcher(l), pvars, compile_pat(l)) + DynamicRule(l, r, rhs_code, matcher(l), pvars, ematcher_yield(l, length(pvars))) end -Base.show(io::IO, r::DynamicRule) = print(io, r.expr) +Base.show(io::IO, r::DynamicRule) = print(io, :($(r.left) => $(r.rhs_code))) -function (r::DynamicRule)(term) - # n == 1 means that exactly one term of the input (term,) was matched - success(bindings, n) = if n == 1 - bvals = [bindings[i] for i in 1:length(r.patvars)] - return r.rhs_fun(term, bindings, nothing, bvals...) +function (r::DynamicRule)(term) + # n == 1 means that exactly one term of the input (term,) was matched + success(bindings, n) = + if n == 1 + bvals = [bindings[i] for i in 1:length(r.patvars)] + return r.rhs_fun(term, nothing, bvals...) end - try - return r.matcher(success, (term,), EMPTY_DICT) - catch err - throw(RuleRewriteError(r, term)) - end + try + return r.matcher(success, (term,), EMPTY_DICT) + catch err + rethrow(err) + throw(RuleRewriteError(r, term)) + end end export SymbolicRule @@ -243,4 +217,4 @@ export UnequalRule export DynamicRule export AbstractRule -end \ No newline at end of file +end diff --git a/src/Syntax.jl b/src/Syntax.jl index da7c438e..f3852de9 100644 --- a/src/Syntax.jl +++ b/src/Syntax.jl @@ -3,12 +3,12 @@ using Metatheory.Patterns using Metatheory.Rules using TermInterface -using Metatheory:alwaystrue, cleanast, binarize +using Metatheory: alwaystrue, cleanast, binarize export @rule export @theory export @slots -export @capture +export @capture # FIXME this thing eats up macro calls! @@ -20,117 +20,118 @@ rmlines(a) = a function makesegment(s::Expr, pvars) - if !(exprhead(s) == :(::)) - error("Syntax for specifying a segment is ~~x::\$predicate, where predicate is a boolean function or a type") - end + if !(exprhead(s) == :(::)) + error("Syntax for specifying a segment is ~~x::\$predicate, where predicate is a boolean function or a type") + end - name = arguments(s)[1] - name ∉ pvars && push!(pvars, name) - return :($PatSegment($(QuoteNode(name)), -1, $(arguments(s)[2]))) + name, predicate = arguments(s) + name ∉ pvars && push!(pvars, name) + return :($PatSegment($(QuoteNode(name)), -1, $predicate, $(QuoteNode(predicate)))) end -function makesegment(name::Symbol, pvars) - name ∉ pvars && push!(pvars, name) - PatSegment(name) + +function makesegment(name::Symbol, pvars) + name ∉ pvars && push!(pvars, name) + PatSegment(name) end + function makevar(s::Expr, pvars) - if !(exprhead(s) == :(::)) - error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function or a type") - end + if !(exprhead(s) == :(::)) + error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function or a type") + end - name = arguments(s)[1] - name ∉ pvars && push!(pvars, name) - return :($PatVar($(QuoteNode(name)), -1, $(arguments(s)[2]))) + name, predicate = arguments(s) + name ∉ pvars && push!(pvars, name) + return :($PatVar($(QuoteNode(name)), -1, $predicate, $(QuoteNode(predicate)))) end -function makevar(name::Symbol, pvars) - name ∉ pvars && push!(pvars, name) - PatVar(name) + +function makevar(name::Symbol, pvars) + name ∉ pvars && push!(pvars, name) + PatVar(name) end # Make a dynamic rule right hand side function makeconsequent(expr::Expr) - head = exprhead(expr) - args = arguments(expr) - op = operation(expr) - if head === :call - if op === :(~) - if args[1] isa Symbol - return args[1] - elseif args[1] isa Expr && operation(args[1]) == :(~) - n = arguments(args[1])[1] - @assert n isa Symbol - return n - else - error("Error when parsing right hand side") - end - else - return Expr(head, makeconsequent(op), - map(makeconsequent, args)...) - end + head = exprhead(expr) + args = arguments(expr) + op = operation(expr) + if head === :call + if op === :(~) + if args[1] isa Symbol + return args[1] + elseif args[1] isa Expr && operation(args[1]) == :(~) + n = arguments(args[1])[1] + @assert n isa Symbol + return n + else + error("Error when parsing right hand side") + end else - return Expr(head, map(makeconsequent, args)...) + return Expr(head, makeconsequent(op), map(makeconsequent, args)...) end + else + return Expr(head, map(makeconsequent, args)...) + end end makeconsequent(x) = x # treat as a literal -function makepattern(x, pvars, slots, mod=@__MODULE__, splat=false) - if splat - x in slots ? makesegment(x, pvars) : x - else - x in slots ? makevar(x, pvars) : x - end +function makepattern(x, pvars, slots, mod = @__MODULE__, splat = false) + x in slots ? (splat ? makesegment(x, pvars) : makevar(x, pvars)) : x end -function makepattern(ex::Expr, pvars, slots, mod=@__MODULE__, splat=false) - head = exprhead(ex) - op = operation(ex) - args = arguments(ex) - istree(op) && (op = makepattern(op, pvars, slots, mod)) - op = op isa Symbol ? QuoteNode(op) : op - #throw(Meta.ParseError("Unsupported pattern syntax $ex")) - - - if head === :call - if operation(ex) === :(~) # is a variable or segment - if args[1] isa Expr && operation(args[1]) == :(~) - # matches ~~x::predicate or ~~x::predicate... - return makesegment(arguments(args[1])[1], pvars) - elseif splat - # matches ~x::predicate... - return makesegment(args[1], pvars) - else - return makevar(args[1], pvars) - end - else # is a term - patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse - return :($PatTerm(:call, $op, [$(patargs...)], $mod)) - end - elseif head === :... - makepattern(args[1], pvars, slots, mod, true) - elseif head == :(::) && args[1] in slots - return splat ? makesegment(ex, pvars) : makevar(ex, pvars) - elseif head === :ref - # getindex - patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse - return :($PatTerm(:ref, getindex, [$(patargs...)], $mod)) - elseif head === :$ - return args[1] - else - patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse - return :($PatTerm($(head isa Symbol ? QuoteNode(head) : head), $(op isa Symbol ? QuoteNode(op) : op), [$(patargs...)], $mod)) - # throw(Meta.ParseError("Unsupported pattern syntax $ex")) +function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false) + head = exprhead(ex) + op = operation(ex) + # Retrieve the function object if available + args = arguments(ex) + istree(op) && (op = makepattern(op, pvars, slots, mod)) + + if head === :call + if operation(ex) === :(~) # is a variable or segment + if args[1] isa Expr && operation(args[1]) == :(~) + # matches ~~x::predicate or ~~x::predicate... + return makesegment(arguments(args[1])[1], pvars) + elseif splat + # matches ~x::predicate... + return makesegment(args[1], pvars) + else + return makevar(args[1], pvars) + end + else # is a term + patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse + return :($PatTerm(:call, $op, [$(patargs...)])) end + elseif head === :... + makepattern(args[1], pvars, slots, mod, true) + elseif head == :(::) && args[1] in slots + return splat ? makesegment(ex, pvars) : makevar(ex, pvars) + elseif head === :ref + # getindex + patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse + return :($PatTerm(:ref, getindex, [$(patargs...)])) + elseif head === :$ + return args[1] + else + patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse + return :($PatTerm($(QuoteNode(head)), $(op isa Symbol ? QuoteNode(op) : op), [$(patargs...)])) + # throw(Meta.ParseError("Unsupported pattern syntax $ex")) + end end function rule_sym_map(ex::Expr) - h = operation(ex) - if h == :(-->) || h == :(→) RewriteRule - elseif h == :(=>) DynamicRule - elseif h == :(==) EqualityRule - elseif h == :(!=) || h == :(≠) UnequalRule - else error("Cannot parse rule with operator '$h'") - end + h = operation(ex) + if h == :(-->) || h == :(→) + RewriteRule + elseif h == :(=>) + DynamicRule + elseif h == :(==) + EqualityRule + elseif h == :(!=) || h == :(≠) + UnequalRule + else + error("Cannot parse rule with operator '$h'") + end end rule_sym_map(ex) = error("Cannot parse rule from $ex") @@ -141,27 +142,26 @@ Rewrite the `expr` by dealing with `:where` if necessary. The `:where` is rewritten from, for example, `~x where f(~x)` to `f(~x) ? ~x : nothing`. """ function rewrite_rhs(ex::Expr) - if exprhead(ex) == :where - args = arguments(ex) - rhs = args[1] - predicate = args[2] - ex = :($predicate ? $rhs : nothing) - end - return ex + if exprhead(ex) == :where + rhs, predicate = arguments(ex) + return :($predicate ? $rhs : nothing) + end + ex end rewrite_rhs(x) = x function addslots(expr, slots) - if expr isa Expr - if expr.head === :macrocall && expr.args[1] in [Symbol("@rule"), Symbol("@capture"), Symbol("@slots"), Symbol("@theory")] - Expr(:macrocall, expr.args[1:2]..., slots..., expr.args[3:end]...) - else - Expr(expr.head, addslots.(expr.args, (slots,))...) - end + if expr isa Expr + if expr.head === :macrocall && + expr.args[1] in [Symbol("@rule"), Symbol("@capture"), Symbol("@slots"), Symbol("@theory")] + Expr(:macrocall, expr.args[1:2]..., slots..., expr.args[3:end]...) else - expr + Expr(expr.head, addslots.(expr.args, (slots,))...) end + else + expr + end end @@ -179,11 +179,11 @@ julia> @slots x y z a b c Chain([ See also: [`@rule`](@ref), [`@capture`](@ref) """ macro slots(args...) - length(args) >= 1 || ArgumentError("@slots requires at least one argument") - slots = args[1:end-1] - expr = args[end] + length(args) >= 1 || ArgumentError("@slots requires at least one argument") + slots = args[1:(end - 1)] + expr = args[end] - return esc(addslots(expr, slots)) + return esc(addslots(expr, slots)) end @@ -322,31 +322,35 @@ Segment variables may still be written as (`~~x`), and slot (`~x`) and segment ( See also: [`@capture`](@ref), [`@slots`](@ref) """ macro rule(args...) - length(args) >= 1 || ArgumentError("@rule requires at least one argument") - slots = args[1:end-1] - expr = args[end] - - e = macroexpand(__module__, expr) - e = rmlines(e) - op = operation(e) - RuleType = rule_sym_map(e) - - l, r = arguments(e) - pvars = Symbol[] - lhs = makepattern(l, pvars, slots, __module__) - rhs = RuleType <: SymbolicRule ? makepattern(r, [], slots, __module__) : r - - if RuleType == DynamicRule - rhs = rewrite_rhs(r) - rhs = makeconsequent(rhs) - params = Expr(:tuple, :_lhs_expr, :_subst, :_egraph, pvars...) - rhs = :($(esc(params)) -> $(esc(rhs))) - end - + length(args) >= 1 || ArgumentError("@rule requires at least one argument") + slots = args[1:(end - 1)] + expr = args[end] + + e = macroexpand(__module__, expr) + e = rmlines(e) + op = operation(e) + RuleType = rule_sym_map(e) + + l, r = arguments(e) + pvars = Symbol[] + lhs = makepattern(l, pvars, slots, __module__) + rhs = RuleType <: SymbolicRule ? esc(makepattern(r, [], slots, __module__)) : r + + if RuleType == DynamicRule + rhs_rewritten = rewrite_rhs(r) + rhs_consequent = makeconsequent(rhs_rewritten) + params = Expr(:tuple, :_lhs_expr, :_egraph, pvars...) + rhs = :($(esc(params)) -> $(esc(rhs_consequent))) return quote - $(__source__) - ($RuleType)($(QuoteNode(expr)), $(esc(lhs)), $rhs) + $(__source__) + DynamicRule($(esc(lhs)), $rhs, $(QuoteNode(rhs_consequent))) end + end + + quote + $(__source__) + ($RuleType)($(esc(lhs)), $rhs) + end end @@ -376,20 +380,20 @@ julia> v = [ ``` """ macro theory(args...) - length(args) >= 1 || ArgumentError("@rule requires at least one argument") - slots = args[1:end-1] - expr = args[end] - - e = macroexpand(__module__, expr) - e = rmlines(e) - # e = interp_dollar(e, __module__) - - if exprhead(e) == :block - ee = Expr(:vect, map(x -> addslots(:(@rule($x)), slots), arguments(e))...) - esc(ee) - else - error("theory is not in form begin a => b; ... end") - end + length(args) >= 1 || ArgumentError("@rule requires at least one argument") + slots = args[1:(end - 1)] + expr = args[end] + + e = macroexpand(__module__, expr) + e = rmlines(e) + # e = interp_dollar(e, __module__) + + if exprhead(e) == :block + ee = Expr(:vect, map(x -> addslots(:(@rule($x)), slots), arguments(e))...) + esc(ee) + else + error("theory is not in form begin a => b; ... end") + end end @@ -412,32 +416,31 @@ x = a See also: [`@rule`](@ref) """ macro capture(args...) - length(args) >= 2 || ArgumentError("@capture requires at least two arguments") - slots = args[1:end-2] - ex = args[end-1] - lhs = args[end] - lhs = macroexpand(__module__, lhs) - lhs = rmlines(lhs) - - pvars = Symbol[] - lhs_term = makepattern(lhs, pvars, slots, __module__) - bind = Expr(:block, map(key-> :($(esc(key)) = getindex(__MATCHES__, findfirst((==)($(QuoteNode(key))), $pvars))), pvars)...) - quote - $(__source__) - lhs_pattern = $(esc(lhs_term)) - println(lhs_pattern) - dump(lhs_pattern) - __MATCHES__ = DynamicRule($(QuoteNode(lhs)), - lhs_pattern, (_lhs_expr, _subst, _egraph, pvars...) -> pvars)($(esc(ex))) - if __MATCHES__ !== nothing - $bind - true - else - false - end + length(args) >= 2 || ArgumentError("@capture requires at least two arguments") + slots = args[1:(end - 2)] + ex = args[end - 1] + lhs = args[end] + lhs = macroexpand(__module__, lhs) + lhs = rmlines(lhs) + + pvars = Symbol[] + lhs_term = makepattern(lhs, pvars, slots, __module__) + bind = Expr( + :block, + map(key -> :($(esc(key)) = getindex(__MATCHES__, findfirst((==)($(QuoteNode(key))), $pvars))), pvars)..., + ) + quote + $(__source__) + lhs_pattern = $(esc(lhs_term)) + __MATCHES__ = DynamicRule(lhs_pattern, (_lhs_expr, _egraph, pvars...) -> pvars, nothing)($(esc(ex))) + if __MATCHES__ !== nothing + $bind + true + else + false end + end end end - diff --git a/src/docstrings.jl b/src/docstrings.jl index e9d35a45..c7332464 100644 --- a/src/docstrings.jl +++ b/src/docstrings.jl @@ -2,33 +2,30 @@ using DocStringExtensions -@template (FUNCTIONS, METHODS, MACROS) = -""" -$(DOCSTRING) +@template (FUNCTIONS, METHODS, MACROS) = """ + $(DOCSTRING) ---- -# Signatures -$(TYPEDSIGNATURES) ---- -## Methods -$(METHODLIST) -""" + --- + # Signatures + $(TYPEDSIGNATURES) + --- + ## Methods + $(METHODLIST) + """ -@template (TYPES) = -""" -$(TYPEDEF) -$(DOCSTRING) +@template (TYPES) = """ + $(TYPEDEF) + $(DOCSTRING) ---- -## Fields -$(TYPEDFIELDS) -""" + --- + ## Fields + $(TYPEDFIELDS) + """ -@template MODULES = -""" -$(DOCSTRING) +@template MODULES = """ + $(DOCSTRING) ---- -## Imports -$(IMPORTS) -""" + --- + ## Imports + $(IMPORTS) + """ diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 28040582..5aed17dc 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -1,236 +1,166 @@ -# TODO make it yield enodes only? ping pavel and marisa -# TODO STAGE IT! FASTER! module EMatchCompiler -using AutoHashEquals using TermInterface -using Metatheory: alwaystrue, binarize -using Metatheory.Patterns - -abstract type Instruction end -export Instruction - -const Register = Int32 - -mutable struct Program - instructions::Vector{Instruction} - first_nonground::Int - memsize::Int - regs::Vector{Register} - ground_terms::Dict{Any,Register} -end -export Program - - -function Program() - Program(Instruction[], 0, 0, [], Dict{AbstractPat,Register}()) -end - -hasregister(prog::Program, i) = (prog.regs[i] != -1) -getregister(prog::Program, i) = prog.regs[i] -setregister(prog::Program, i, v) = (prog.regs[i] = v) -increment(prog::Program, i) = (prog.memsize += i) -memsize(prog::Program) = prog.memsize - -Base.getindex(p::Program, i) = p.instructions[i] -Base.length(p::Program) = length(p.instructions) - -@auto_hash_equals struct ENodePat - exprhead::Union{Symbol,Nothing} - operation::Any - # args::Vector{Register} - args::UnitRange{Register} - checkop::Function # function that checks both symbol or func. object as op -end -export ENodePat - -TermInterface.operation(p::ENodePat) = p.operation -TermInterface.exprhead(p::ENodePat) = p.exprhead -TermInterface.arguments(p::ENodePat) = p.args -TermInterface.arity(p::ENodePat) = length(p.args) - -@auto_hash_equals struct Bind <: Instruction - reg::Register - enodepat::ENodePat -end -export Bind - -@auto_hash_equals struct CheckClassEq <: Instruction - left::Register - right::Register -end -export CheckClassEq - -@auto_hash_equals struct CheckType <: Instruction - reg::Register - type::Any -end -export CheckType - -@auto_hash_equals struct CheckPredicate <: Instruction - reg::Register - predicate::Function -end -export CheckPredicate - -@auto_hash_equals struct Yield <: Instruction - yields::Vector{Register} -end -export Yield - -@auto_hash_equals struct Filter <: Instruction - reg::Register - operation::Any - arity::Int -end -TermInterface.operation(x::Filter) = x.operation -export Filter - -@auto_hash_equals struct Lookup <: Instruction - reg::Register - p::Any # pattern -end -export Lookup - -@auto_hash_equals struct Fail <: Instruction - err::Exception +using ..Patterns +using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL + +function ematcher(p::Any) + function literal_ematcher(next, g, data, bindings) + !islist(data) && return + ecid = lookup_pat(g, p) + if ecid > 0 && ecid == car(data) + next(bindings, 1) + end + end end -export Fail -# ============================================= -# ========= GROUND patterns ================ -# ============================================= +checktype(n, T) = istree(n) ? symtype(n) <: T : false - -function compile_ground!(reg, p::PatTerm, prog) - p = binarize(p) - - if haskey(prog.ground_terms, p) - # push!(prog.instructions, CheckClassEq(reg, prog.ground_terms[p])) - return nothing +function predicate_ematcher(p::PatVar, pred::Type) + function type_ematcher(next, g, data, bindings) + !islist(data) && return + id = car(data) + eclass = g[id] + for (enode_idx, n) in enumerate(eclass) + if !istree(n) && operation(n) isa pred + next(assoc(bindings, p.idx, (id, enode_idx)), 1) + end end - - if isground(p) - prog.ground_terms[p] = reg - push!(prog.instructions, Lookup(reg, p)) - increment(prog, 1) - else - for p2 in p.args - compile_ground!(prog.memsize, p2, prog) + end +end + +function predicate_ematcher(p::PatVar, pred) + function predicate_ematcher(next, g, data, bindings) + !islist(data) && return + id::Int = car(data) + eclass = g[id] + if pred(eclass) + enode_idx = 0 + # Is this for cycle needed? + for (j, n) in enumerate(eclass) + # Find first literal if available + if !istree(n) + enode_idx = j + break end + end + next(assoc(bindings, p.idx, (id, enode_idx)), 1) + end + end +end + +function ematcher(p::PatVar) + pred_matcher = predicate_ematcher(p, p.predicate) + + function var_ematcher(next, g, data, bindings) + id = car(data) + ecid = get(bindings, p.idx, 0)[1] + if ecid > 0 + ecid == id ? next(bindings, 1) : nothing + else + # Variable is not bound, check predicate and bind + pred_matcher(next, g, data, bindings) end + end end +Base.@pure @inline checkop(x::Union{Function,DataType}, op) = isequal(x, op) || isequal(nameof(x), op) +Base.@pure @inline checkop(x, op) = isequal(x, op) -function compile_ground!(reg, p::PatVar, prog) - nothing +function canbind(p::PatTerm) + eh = exprhead(p) + op = operation(p) + ar = arity(p) + function canbind(n) + istree(n) && exprhead(n) == eh && checkop(op, operation(n)) && arity(n) == ar + end end -function compile_ground!(reg, p::AbstractPat, prog) - push!(prog.instructions, Fail(UnsupportedPatternException(p))) -end +function ematcher(p::PatTerm) + ematchers = map(ematcher, arguments(p)) -# A literal that is not a pattern -function compile_ground!(reg, p::Any, prog) - if haskey(prog.ground_terms, p) - return nothing + if isground(p) + return function ground_term_ematcher(next, g, data, bindings) + !islist(data) && return + ecid = lookup_pat(g, p) + if ecid > 0 && ecid == car(data) + next(bindings, 1) + end end - prog.ground_terms[p] = reg - push!(prog.instructions, Lookup(reg, p)) - increment(prog, 1) -end - -# ============================================= -# ========= NONGROUND patterns ================ -# ============================================= - -function compile_pat!(reg, p::PatTerm, prog) - p = binarize(p) - - if haskey(prog.ground_terms, p) - push!(prog.instructions, CheckClassEq(reg, prog.ground_terms[p])) + end + + canbindtop = canbind(p) + function term_ematcher(success, g, data, bindings) + !islist(data) && return nothing + + function loop(children_eclass_ids, bindings′, ematchers′) + if !islist(ematchers′) + # term is empty + if !islist(children_eclass_ids) + # we have correctly matched the term + return success(bindings′, 1) + end return nothing - end - # a = [gensym() for i in 1:length(p.args)] - c = memsize(prog) - nargs = arity(p) - # registers unit range - regrange = c:(c + nargs - 1) - - exhead = exprhead(p) - op = operation(p) - checkop = x -> isequal(x, op) - - if op isa Symbol - checkop = try - fobj = getproperty(p.mod, op) - (x) -> (isequal(x, op) || isequal(x, fobj)) - catch e - e isa UndefVarError ? checkop : rethrow(e) - end + end + car(ematchers′)(g, children_eclass_ids, bindings′) do b, n_of_matched # next + # recursion case: + # take the first matcher, on success, + # keep looping by matching the rest + # by removing the first n matched elements + # from the term, with the bindings, + loop(drop_n(children_eclass_ids, n_of_matched), b, cdr(ematchers′)) + end end - increment(prog, nargs) - - push!(prog.instructions, Bind(reg, ENodePat(exhead, op, regrange, checkop))) - for (reg, p2) in zip(regrange, arguments(p)) - compile_pat!(reg, p2, prog) + for n in g[car(data)] + if canbindtop(n) + loop(LL(arguments(n),1), bindings, ematchers) + end end -end - - -function compile_pat!(reg, p::PatVar, prog) - if hasregister(prog, p.idx) - push!(prog.instructions, CheckClassEq(reg, getregister(prog, p.idx))) - else # Variable is new - setregister(prog, p.idx, reg) - if p.predicate isa Function && p.predicate != alwaystrue - push!(prog.instructions, CheckPredicate(reg, p.predicate)) - elseif p.predicate isa Type - push!(prog.instructions, CheckType(reg, p.predicate)) + end +end + + +const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int, Int}}() + +""" +Substitutions are efficiently represented in memory as vector of tuples of two integers. +This should allow for static allocation of matches and use of LoopVectorization.jl +The buffer has to be fairly big when e-matching. +The size of the buffer should double when there's too many matches. +The format is as follows +* The first pair denotes the index of the rule in the theory and the e-class id + of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional + the direction is right-to-left. +* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable +* The end of a substitution is delimited by (0,0) +""" +function ematcher_yield(p, npvars::Int, direction::Int) + em = ematcher(p) + function ematcher_yield(g, rule_idx, id)::Int + n_matches = 0 + em(g, (id,), EMPTY_ECLASS_DICT) do b,n + lock(BUFFER_LOCK) do + push!(BUFFER[], assoc(b, 0, (rule_idx * direction, id))) + n_matches+=1 + end end + n_matches end end -function compile_pat!(reg, p::AbstractPat, prog) - push!(prog.instructions, Fail(UnsupportedPatternException(p))) -end +ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1) -# Literal values -function compile_pat!(reg, p::Any, prog) - if haskey(prog.ground_terms, p) - push!(prog.instructions, CheckClassEq(reg, prog.ground_terms[p])) - return nothing +function ematcher_yield_bidir(l, r, npvars::Int) + eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1) + function ematcher_yield_bidir(g, rule_idx, id)::Int + eml(g,rule_idx,id) + emr(g,rule_idx,id) end - @error "This shouldn't be printed. Report an issue for ematching literals" end +ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p") -#= ====================================================================================== =# - -# EXPECTS INDEXES OF PATTERN VARIABLES TO BE ALREADY POPULATED -function compile_pat(p) - p = binarize(p) - pvars = patvars(p) - nvars = length(pvars) - - # The program will try to match against ground terms first - prog = Program(Instruction[], 1, 1, fill(-1, nvars), Dict{AbstractPat,Register}()) - # println("compiling pattern $p") - compile_ground!(1, p, prog) - # println("compiled ground pattern \n $prog") - prog.first_nonground = prog.memsize - prog.memsize += 1 - - # And then try to match against other patterns - compile_pat!(prog.first_nonground, p, prog) - push!(prog.instructions, Yield(prog.regs)) - # println("compiled pattern $p to \n $prog") - # @show prog - return prog -end - -export compile_pat +export ematcher_yield, ematcher_yield_bidir end \ No newline at end of file diff --git a/src/matchers.jl b/src/matchers.jl index 6af8706b..e93dbd14 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -9,152 +9,154 @@ using Metatheory: islist, car, cdr, assoc, drop_n, take_n function matcher(val::Any) - function literal_matcher(next, data, bindings) - islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing - end + function literal_matcher(next, data, bindings) + islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing + end end function matcher(slot::PatVar) - pred = slot.predicate - if slot.predicate isa Type - pred = x -> typeof(x) <: slot.predicate - end - function slot_matcher(next, data, bindings) - !islist(data) && return - val = get(bindings, slot.idx, nothing) - if val !== nothing - if isequal(val, car(data)) - return next(bindings, 1) - end - else - # Variable is not bound, first time it is found - # check the predicate - if pred(car(data)) - next(assoc(bindings, slot.idx, car(data)), 1) - end - end + pred = slot.predicate + if slot.predicate isa Type + pred = x -> typeof(x) <: slot.predicate + end + function slot_matcher(next, data, bindings) + !islist(data) && return + val = get(bindings, slot.idx, nothing) + if val !== nothing + if isequal(val, car(data)) + return next(bindings, 1) + end + else + # Variable is not bound, first time it is found + # check the predicate + if pred(car(data)) + next(assoc(bindings, slot.idx, car(data)), 1) + end end + end end # returns n == offset, 0 if failed function trymatchexpr(data, value, n) - if !islist(value) - return n - elseif islist(value) && islist(data) - if !islist(data) - # didn't fully match - return nothing - end - - while isequal(car(value), car(data)) - n += 1 - value = cdr(value) - data = cdr(data) + if !islist(value) + return n + elseif islist(value) && islist(data) + if !islist(data) + # didn't fully match + return nothing + end - if !islist(value) - return n - elseif !islist(data) - return nothing - end - end + while isequal(car(value), car(data)) + n += 1 + value = cdr(value) + data = cdr(data) - return !islist(value) ? n : nothing - elseif isequal(value, data) - return n + 1 + if !islist(value) + return n + elseif !islist(data) + return nothing + end end + + return !islist(value) ? n : nothing + elseif isequal(value, data) + return n + 1 + end end function matcher(segment::PatSegment) - function segment_matcher(success, data, bindings) - val = get(bindings, segment.idx, nothing) - if val !== nothing - n = trymatchexpr(data, val, 0) - if !isnothing(n) - success(bindings, n) - end - else - res = nothing - - for i = length(data):-1:0 - subexpr = take_n(data, i) - - if segment.predicate(subexpr) - res = success(assoc(bindings, segment.idx, subexpr), i) - !isnothing(res) && break - end - end - - return res + function segment_matcher(success, data, bindings) + val = get(bindings, segment.idx, nothing) + if val !== nothing + n = trymatchexpr(data, val, 0) + if !isnothing(n) + success(bindings, n) + end + else + res = nothing + + for i in length(data):-1:0 + subexpr = take_n(data, i) + + if segment.predicate(subexpr) + res = success(assoc(bindings, segment.idx, subexpr), i) + !isnothing(res) && break end + end + + return res end + end end -# TODO REVIEWME # Try to match both against a function symbol or a function object at the same time. -# Slows things down a bit but lets this matcher work at the same time on both purely symbolic Expr-like object +# Slows compile time down a bit but lets this matcher work at the same time on both purely symbolic Expr-like object. +# Execution time should not be affected. # and SymbolicUtils-like objects that store function references as operations. -function head_matcher(f::Symbol, mod) - checkhead = try - fobj = getproperty(mod, f) - (x) -> (isequal(x, f) || isequal(x, fobj)) - catch e - if e isa UndefVarError - (x) -> isequal(x, f) - else - rethrow(e) - end - end - - function head_matcher(next, data, bindings) - h = car(data) - if islist(data) && checkhead(h) - next(bindings, 1) - else - nothing - end +function head_matcher(f::Union{Function,DataType,UnionAll}) + checkhead(x) = isequal(x, f) || isequal(x, nameof(f)) + function head_matcher(next, data, bindings) + h = car(data) + if islist(data) && checkhead(h) + next(bindings, 1) + else + nothing end + end end -head_matcher(x, mod) = matcher(x) +head_matcher(x) = matcher(x) function matcher(term::PatTerm) - op = operation(term) - matchers = (head_matcher(op, term.mod), map(matcher, arguments(term))...,) - function term_matcher(success, data, bindings) - !islist(data) && return nothing - !istree(car(data)) && return nothing - - function loop(term, bindings′, matchers′) # Get it to compile faster - # Base case, no more matchers - if !islist(matchers′) - # term is empty - if !islist(term) - # we have correctly matched the term - return success(bindings′, 1) - end - return nothing - end - car(matchers′)(term, bindings′) do b, n - # recursion case: - # take the first matcher, on success, - # keep looping by matching the rest - # by removing the first n matched elements - # from the term, with the bindings, - loop(drop_n(term, n), b, cdr(matchers′)) - end + op = operation(term) + matchers = (head_matcher(op), map(matcher, arguments(term))...) + function term_matcher(success, data, bindings) + !islist(data) && return nothing + !istree(car(data)) && return nothing + + function loop(term, bindings′, matchers′) # Get it to compile faster + # Base case, no more matchers + if !islist(matchers′) + # term is empty + if !islist(term) + # we have correctly matched the term + return success(bindings′, 1) end - - loop(car(data), bindings, matchers) # Try to eat exactly one term + return nothing + end + car(matchers′)(term, bindings′) do b, n + # recursion case: + # take the first matcher, on success, + # keep looping by matching the rest + # by removing the first n matched elements + # from the term, with the bindings, + loop(drop_n(term, n), b, cdr(matchers′)) + end end + + loop(car(data), bindings, matchers) # Try to eat exactly one term + end end +function TermInterface.similarterm( + x::Expr, + head::Union{Function,DataType}, + args, + symtype = nothing; + metadata = nothing, + exprhead = exprhead(x), +) + similarterm(x, nameof(head), args, symtype; metadata, exprhead) +end -# TODO REVIEWME function instantiate(left, pat::PatTerm, mem) - ar = arguments(pat) - args = [ instantiate(left, p, mem) for p in ar] - T = istree(typeof(left)) ? typeof(left) : Expr - similarterm(T, operation(pat), args; exprhead=exprhead(pat)) + args = [] + for parg in arguments(pat) + enqueue = parg isa PatSegment ? append! : push! + enqueue(args, instantiate(left, parg, mem)) + end + reference = istree(left) ? left : Expr(:call, :_) + similarterm(reference, operation(pat), args; exprhead = exprhead(pat)) end instantiate(left, pat::Any, mem) = pat @@ -162,10 +164,9 @@ instantiate(left, pat::Any, mem) = pat instantiate(left, pat::AbstractPat, mem) = error("Unsupported pattern ", pat) function instantiate(left, pat::PatVar, mem) - mem[pat.idx] + mem[pat.idx] end function instantiate(left, pat::PatSegment, mem) - mem[pat.idx] + mem[pat.idx] end - diff --git a/src/utils.jl b/src/utils.jl index d41d25b4..12cc9837 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,43 +1,62 @@ -using Base:ImmutableDict +using Base: ImmutableDict function binarize(e::T) where {T} - !istree(e) && return e - head = exprhead(e) - if head == :call - op = operation(e) - args = arguments(e) - meta = metadata(e) - if op ∈ binarize_ops && arity(e) > 2 - return foldl((x,y) -> similarterm(T, op, [x,y], symtype(e); metadata=meta, exprhead=head), args) - end + !istree(e) && return e + head = exprhead(e) + if head == :call + op = operation(e) + args = arguments(e) + meta = metadata(e) + if op ∈ binarize_ops && arity(e) > 2 + return foldl((x, y) -> similarterm(e, op, [x, y], symtype(e); metadata = meta, exprhead = head), args) end - return e -end + end + return e +end + +""" +Recursive version of binarize +""" +function binarize_rec(e::T) where {T} + !istree(e) && return e + head = exprhead(e) + op = operation(e) + args = map(binarize_rec, arguments(e)) + meta = metadata(e) + if head == :call + if op ∈ binarize_ops && arity(e) > 2 + return foldl((x, y) -> similarterm(e, op, [x, y], symtype(e); metadata = meta, exprhead = head), args) + end + end + return similarterm(e, op, args, symtype(e); metadata = meta, exprhead = head) +end + + const binarize_ops = [:(+), :(*), (+), (*)] function cleanast(e::Expr) - # TODO better line removal - if isexpr(e, :block) - return Expr(e.head, filter(x -> !(x isa LineNumberNode), e.args)...) - end - - # Binarize - if isexpr(e, :call) - op = e.args[1] - if op ∈ binarize_ops && length(e.args) > 3 - return foldl((x,y) -> Expr(:call, op, x, y), @view e.args[2:end]) - end + # TODO better line removal + if isexpr(e, :block) + return Expr(e.head, filter(x -> !(x isa LineNumberNode), e.args)...) + end + + # Binarize + if isexpr(e, :call) + op = e.args[1] + if op ∈ binarize_ops && length(e.args) > 3 + return foldl((x, y) -> Expr(:call, op, x, y), @view e.args[2:end]) end - return e + end + return e end # Linked List interface @inline assoc(d::ImmutableDict, k, v) = ImmutableDict(d, k => v) struct LL{V} - v::V - i::Int + v::V + i::Int end islist(x) = istree(x) || !isempty(x) @@ -56,107 +75,98 @@ Base.length(l::LL) = length(l.v) - l.i + 1 @inline car(v) = istree(v) ? operation(v) : first(v) @inline function cdr(v) - if istree(v) - arguments(v) - else - islist(v) ? LL(v, 2) : error("asked cdr of empty") - end + if istree(v) + arguments(v) + else + islist(v) ? LL(v, 2) : error("asked cdr of empty") + end end -@inline take_n(ll::LL, n) = isempty(ll) || n == 0 ? empty(ll) : @views ll.v[ll.i:n + ll.i - 1] # @views handles Tuple +@inline take_n(ll::LL, n) = isempty(ll) || n == 0 ? empty(ll) : @views ll.v[(ll.i):(n + ll.i - 1)] # @views handles Tuple @inline take_n(ll, n) = @views ll[1:n] @inline function drop_n(ll, n) - if n === 0 - return ll - else - istree(ll) ? drop_n(arguments(ll), n - 1) : drop_n(cdr(ll), n - 1) - end + if n === 0 + return ll + else + istree(ll) ? drop_n(arguments(ll), n - 1) : drop_n(cdr(ll), n - 1) + end end @inline drop_n(ll::Union{Tuple,AbstractArray}, n) = drop_n(LL(ll, 1), n) @inline drop_n(ll::LL, n) = LL(ll.v, ll.i + n) - + isliteral(::Type{T}) where {T} = x -> x isa T is_literal_number(x) = isliteral(Number)(x) -# checking the type directly is faster than dynamic dispatch in type unstable code -_iszero(x) = x isa Number && iszero(x) -_isone(x) = x isa Number && isone(x) -_isinteger(x) = (x isa Number && isinteger(x)) || (x isa Symbolic && symtype(x) <: Integer) -_isreal(x) = (x isa Number && isreal(x)) || (x isa Symbolic && symtype(x) <: Real) - -issortedₑ(args) = issorted(args, lt=<ₑ) -needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(arguments(x)) - # are there nested ⋆ terms? function isnotflat(⋆) - function (x) + function (x) args = arguments(x) - for t in args - if istree(t) && operation(t) === (⋆) + for t in args + if istree(t) && operation(t) === (⋆) return true - end - end -return false + end end + return false + end end function hasrepeats(x) - length(x) <= 1 && return false - for i = 1:length(x) - 1 - if isequal(x[i], x[i + 1]) - return true - end - end - return false + length(x) <= 1 && return false + for i in 1:(length(x) - 1) + if isequal(x[i], x[i + 1]) + return true + end + end + return false end function merge_repeats(merge, xs) - length(xs) <= 1 && return false - merged = Any[] - i = 1 - - while i <= length(xs) - l = 1 - for j = i + 1:length(xs) - if isequal(xs[i], xs[j]) + length(xs) <= 1 && return false + merged = Any[] + i = 1 + + while i <= length(xs) + l = 1 + for j in (i + 1):length(xs) + if isequal(xs[i], xs[j]) l += 1 - else - break -end -end - if l > 1 - push!(merged, merge(xs[i], l)) -else - push!(merged, xs[i]) - end - i += l + else + break + end + end + if l > 1 + push!(merged, merge(xs[i], l)) + else + push!(merged, xs[i]) end - return merged + i += l + end + return merged end # Take a struct definition and make it be able to match in `@rule` macro matchable(expr) - @assert expr.head == :struct - name = expr.args[2] - if name isa Expr && name.head === :curly - name = name.args[1] - end - fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args) - get_name(s::Symbol) = s - get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1]) - fields = map(get_name, fields) - quote - $expr - TermInterface.istree(::$name) = true - TermInterface.istree(::Type{<:$name}) = true - TermInterface.operation(::$name) = $name - TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) - TermInterface.arity(x::$name) = $(length(fields)) - Base.length(x::$name) = $(length(fields) + 1) - end |> esc + @assert expr.head == :struct + name = expr.args[2] + if name isa Expr + name.head === :(<:) && (name = name.args[1]) + name isa Expr && name.head === :curly && (name = name.args[1]) + end + fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args) + get_name(s::Symbol) = s + get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1]) + fields = map(get_name, fields) + quote + $expr + TermInterface.istree(::$name) = true + TermInterface.operation(::$name) = $name + TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) + TermInterface.arity(x::$name) = $(length(fields)) + Base.length(x::$name) = $(length(fields) + 1) + end |> esc end @@ -165,25 +175,27 @@ using TimerOutputs const being_timed = Ref{Bool}(false) macro timer(name, expr) - :(if being_timed[] - @timeit $(esc(name)) $(esc(expr)) - else - $(esc(expr)) - end) + :( + if being_timed[] + @timeit $(esc(name)) $(esc(expr)) + else + $(esc(expr)) + end + ) end macro iftimer(expr) - esc(expr) + esc(expr) end function timerewrite(f) - reset_timer!() - being_timed[] = true - x = f() - being_timed[] = false - print_timer() - println() - x + reset_timer!() + being_timed[] = true + x = f() + being_timed[] = false + print_timer() + println() + x end """ @@ -221,5 +233,5 @@ julia> @timerewrite simplify(expr) ``` """ macro timerewrite(expr) - :(timerewrite(()->$(esc(expr)))) + :(timerewrite(() -> $(esc(expr)))) end diff --git a/test/EGraphs/analysis.jl b/test/EGraphs/analysis.jl new file mode 100644 index 00000000..c54f6cfe --- /dev/null +++ b/test/EGraphs/analysis.jl @@ -0,0 +1,361 @@ +# example assuming * operation is always binary + +# ENV["JULIA_DEBUG"] = Metatheory + +using Metatheory +using Metatheory.Library +using TermInterface + +EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENodeLiteral) = n.value + + +# This should be auto-generated by a macro +function EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENodeTerm) + if exprhead(n) == :call && arity(n) == 2 + op = operation(n) + args = arguments(n) + l = g[args[1]] + r = g[args[2]] + ldata = getdata(l, :numberfold, nothing) + rdata = getdata(r, :numberfold, nothing) + + # @show ldata rdata + + if ldata isa Number && rdata isa Number + if op == :* + return ldata * rdata + elseif op == :+ + return ldata + rdata + end + end + end + + return nothing +end + +function EGraphs.join(an::Val{:numberfold}, from, to) + if from isa Number + if to isa Number + @assert from == to + else + return from + end + end + return to +end + +function EGraphs.modify!(::Val{:numberfold}, g::EGraph, id::Int64) + eclass = g.classes[id] + d = getdata(eclass, :numberfold, nothing) + if d isa Number + merge!(g, addexpr!(g, d), id) + end +end + +EGraphs.islazy(::Val{:numberfold}) = false + + +comm_monoid = @theory begin + ~a * ~b --> ~b * ~a + ~a * 1 --> ~a + ~a * (~b * ~c) --> (~a * ~b) * ~c +end + +G = EGraph(:(3 * 4)) +analyze!(G, :numberfold) + +# exit(0) + +@testset "Basic Constant Folding Example - Commutative Monoid" begin + @test (true == @areequalg G comm_monoid 3 * 4 12) + + @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2) +end + +@testset "Basic Constant Folding Example 2 - Commutative Monoid" begin + ex = :(a * 3 * b * 4) + G = EGraph(ex) + analyze!(G, :numberfold) + addexpr!(G, :(12 * a)) + @test (true == @areequalg G comm_monoid (12 * a) * b ((6 * 2) * b) * a) + @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) +end + +@testset "Basic Constant Folding Example - Adding analysis after saturation" begin + G = EGraph(:(3 * 4)) + # addexpr!(G, 12) + saturate!(G, comm_monoid) + addexpr!(G, :(a * 2)) + analyze!(G, :numberfold) + saturate!(G, comm_monoid) + + @test (true == areequal(G, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2))) + + ex = :(a * 3 * b * 4) + G = EGraph(ex) + analyze!(G, :numberfold) + params = SaturationParams(timeout = 15) + @test areequal(G, comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); params = params) +end + +@testset "Infinite Loops analysis" begin + boson = @theory begin + 1 * ~x --> ~x + end + + + G = EGraph(:(1 * x)) + params = SaturationParams(timeout = 100) + saturate!(G, boson, params) + ex = extract!(G, astsize) + + + boson = @theory begin + (:c * :cdag) --> :cdag * :c + 1 + ~a * (~b + ~c) --> (~a * ~b) + (~a * ~c) + (~b + ~c) * ~a --> (~b * ~a) + (~c * ~a) + # 1 * x => x + (~a * ~b) * ~c --> ~a * (~b * ~c) + ~a * (~b * ~c) --> (~a * ~b) * ~c + end + + g = EGraph(:(c * c * cdag * cdag)) + saturate!(g, boson) + ex = extract!(g, astsize_inv) + +end + +@testset "Extraction" begin + comm_monoid = @commutative_monoid (*) 1 + + fold_mul = @theory begin + ~a::Number * ~b::Number => ~a * ~b + end + + t = comm_monoid ∪ fold_mul + + + @testset "Extraction 1 - Commutative Monoid" begin + G = EGraph(:(3 * 4)) + saturate!(G, t) + @test (12 == extract!(G, astsize)) + + ex = :(a * 3 * b * 4) + G = EGraph(ex) + params = SaturationParams(timeout = 15) + saturate!(G, t, params) + extr = extract!(G, astsize) + @test extr == :((12 * a) * b) || + extr == :(12 * (a * b)) || + extr == :(a * (b * 12)) || + extr == :((a * b) * 12) || + extr == :((12a) * b) || + extr == :(a * (12b)) || + extr == :((b * (12a))) || + extr == :((b * 12) * a) || + extr == :((b * a) * 12) || + extr == :(b * (a * 12)) || + extr == :((12b) * a) + end + + fold_add = @theory begin + ~a::Number + ~b::Number => ~a + ~b + end + + @testset "Extraction 2" begin + comm_group = @commutative_group (+) 0 inv + + + t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add + + # for i ∈ 1:20 + # sleep(0.3) + ex = :((x * (a + b)) + (y * (a + b))) + G = EGraph(ex) + saturate!(G, t) + # end + + extract!(G, astsize) == :((y + x) * (b + a)) + end + + @testset "Extraction - Adding analysis after saturation" begin + G = EGraph(:(3 * 4)) + addexpr!(G, 12) + saturate!(G, t) + addexpr!(G, :(a * 2)) + saturate!(G, t) + + saturate!(G, t) + + @test (12 == extract!(G, astsize)) + + # for i ∈ 1:100 + ex = :(a * 3 * b * 4) + G = EGraph(ex) + analyze!(G, :numberfold) + params = SaturationParams(timeout = 15) + saturate!(G, comm_monoid, params) + + extr = extract!(G, astsize) + + @test extr == :((12 * a) * b) || + extr == :(12 * (a * b)) || + extr == :(a * (b * 12)) || + extr == :((a * b) * 12) || + extr == :((12a) * b) || + extr == :(a * (12b)) || + extr == :((b * (12a))) || + extr == :((b * 12) * a) || + extr == :((b * a) * 12) || + extr == :(b * (a * 12)) + end + + + comm_monoid = @commutative_monoid (*) 1 + + comm_group = @commutative_group (+) 0 inv + + powers = @theory begin + ~a * ~a → (~a)^2 + ~a → (~a)^1 + (~a)^~n * (~a)^~m → (~a)^(~n + ~m) + end + logids = @theory begin + log((~a)^~n) --> ~n * log(~a) + log(~x * ~y) --> log(~x) + log(~y) + log(1) --> 0 + log(:e) --> 1 + :e^(log(~x)) --> ~x + end + + G = EGraph(:(log(e))) + params = SaturationParams(timeout = 9) + saturate!(G, logids, params) + @test extract!(G, astsize) == 1 + + + t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add + + @testset "Complex Extraction" begin + G = EGraph(:(log(e) * log(e))) + params = SaturationParams(timeout = 9) + saturate!(G, t, params) + @test extract!(G, astsize) == 1 + + G = EGraph(:(log(e) * (log(e) * e^(log(3))))) + params = SaturationParams(timeout = 7) + saturate!(G, t, params) + @test extract!(G, astsize) == 3 + + @time begin + G = EGraph(:(a^3 * a^2)) + saturate!(G, t) + ex = extract!(G, astsize) + end + @test ex == :(a^5) + + @time begin + G = EGraph(:(a^3 * a^2)) + saturate!(G, t) + ex = extract!(G, astsize) + end + @test ex == :(a^5) + + function cust_astsize(n::ENodeTerm, g::EGraph) + cost = 1 + arity(n) + + if operation(n) == :^ + cost += 2 + end + + for id in arguments(n) + eclass = g[id] + !hasdata(eclass, cust_astsize) && (cost += Inf; break) + cost += last(getdata(eclass, cust_astsize)) + end + return cost + end + + + cust_astsize(n::ENodeLiteral, g::EGraph) = 1 + + @time begin + G = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2)))) + saturate!(G, t) + @show getcost!(G, cust_astsize) + ex = extract!(G, cust_astsize) + end + @show ex + @test ex == :(5 * log(a)) || ex == :(log(a) * 5) + end + + function costfun(n::ENodeTerm, g::EGraph) + arity(n) != 2 && (return 1) + left = arguments(n)[1] + left_class = g[left] + ENodeLiteral(:a) ∈ left_class.nodes ? 1 : 100 + end + + costfun(n::ENodeLiteral, g::EGraph) = 1 + + + moveright = @theory begin + (:b * (:a * ~c)) --> (:a * (:b * ~c)) + end + + expr = :(a * (a * (b * (a * b)))) + res = rewrite(expr, moveright) + + g = EGraph(expr) + saturate!(g, moveright) + resg = extract!(g, costfun) + + @testset "Symbols in Right hand" begin + @test resg == res == :(a * (a * (a * (b * b)))) + end + + function ⋅ end + co = @theory begin + sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x) + end + @testset "Consistency with classical backend" begin + ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo)) + g = EGraph(ex) + saturate!(g, co) + + res = extract!(g, astsize) + + resclassic = rewrite(ex, co) + + @test res == resclassic + end + + + @testset "No arguments" begin + ex = :(f()) + g = EGraph(ex) + @test :(f()) == extract!(g, astsize) + + ex = :(sin() + cos()) + + t = @theory begin + sin() + cos() --> tan() + end + + gg = EGraph(ex) + saturate!(gg, t) + @show getcost!(gg, astsize) + res = extract!(gg, astsize) + + @test res == :(tan()) + end + + + @testset "Symbol or function object operators in expressions in EGraphs" begin + ex = :(($+)(x, y)) + t = [@rule a b a + b => 2] + g = EGraph(ex) + saturate!(g, t) + @test extract!(g, astsize) == 2 + end +end diff --git a/test/EGraphs/egraphs.jl b/test/EGraphs/egraphs.jl new file mode 100644 index 00000000..d58ad0bf --- /dev/null +++ b/test/EGraphs/egraphs.jl @@ -0,0 +1,73 @@ + +# ENV["JULIA_DEBUG"] = Metatheory +using Metatheory +using Metatheory.EGraphs +using Metatheory.EGraphs: in_same_set, find_root + +@testset "Merging" begin + testexpr = :((a * 2) / 2) + testmatch = :(a << 1) + G = EGraph(testexpr) + t2 = addexpr!(G, testmatch) + merge!(G, t2, EClassId(3)) + @test in_same_set(G.uf, t2, EClassId(3)) == true + # DOES NOT UPWARD MERGE +end + +# testexpr = :(42a + b * (foo($(Dict(:x => 2)), 42))) + +@testset "Simple congruence - rebuilding" begin + G = EGraph() + ec1 = addexpr!(G, :(f(a, b))) + ec2 = addexpr!(G, :(f(a, c))) + + testexpr = :(f(a, b) + f(a, c)) + + testec = addexpr!(G, testexpr) + + t1 = addexpr!(G, :b) + t2 = addexpr!(G, :c) + + c_id = merge!(G, t2, t1) + @test in_same_set(G.uf, c_id, t1) + @test in_same_set(G.uf, t2, t1) + rebuild!(G) + @test in_same_set(G.uf, ec1, ec2) +end + + +@testset "Simple nested congruence" begin + apply(n, f, x) = n == 0 ? x : apply(n - 1, f, f(x)) + f(x) = Expr(:call, :f, x) + + G = EGraph(:a) + + t1 = addexpr!(G, apply(6, f, :a)) + t2 = addexpr!(G, apply(9, f, :a)) + + c_id = merge!(G, t1, EClassId(1)) # a == apply(6,f,a) + c2_id = merge!(G, t2, EClassId(1)) # a == apply(9,f,a) + + + rebuild!(G) + + + t3 = addexpr!(G, apply(3, f, :a)) + t4 = addexpr!(G, apply(7, f, :a)) + + # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a + @test in_same_set(G.uf, t1, EClassId(1)) == true + @test in_same_set(G.uf, t2, EClassId(1)) == true + @test in_same_set(G.uf, t3, EClassId(1)) == true + @test in_same_set(G.uf, t4, EClassId(1)) == false + + # if m or n is prime, f(a) = a + t5 = addexpr!(G, apply(11, f, :a)) + t6 = addexpr!(G, apply(1, f, :a)) + c5_id = merge!(G, t5, EClassId(1)) # a == apply(11,f,a) + + rebuild!(G) + + @test in_same_set(G.uf, t5, EClassId(1)) == true + @test in_same_set(G.uf, t6, EClassId(1)) == true +end diff --git a/test/EGraphs/ematch.jl b/test/EGraphs/ematch.jl new file mode 100644 index 00000000..72a6e58f --- /dev/null +++ b/test/EGraphs/ematch.jl @@ -0,0 +1,177 @@ +using Metatheory +using Test +using Metatheory.Library + +falseormissing(x) = x === missing || !x + +r = @theory begin + max(~x, ~y) → 2 * ~x % ~y + max(~x, ~y) → sin(~x) + sin(~x) → max(~x, ~x) +end +@testset "Basic Equalities 1" begin + @test (@areequal r max(b, c) max(d, d)) == false +end + + +r = @theory begin + ~a * 1 → :foo + ~a * 2 → :bar + 1 * ~a → :baz + 2 * ~a → :mag +end + +@testset "Matching Literals" begin + g = EGraph(:(a * 1)) + addexpr!(g, :foo) + saturate!(g, r) + + @test (@areequal r a * 1 foo) == true + @test (@areequal r a * 2 foo) == false + @test (@areequal r a * 1 bar) == false + @test (@areequal r a * 2 bar) == true + + @test (@areequal r 1 * a baz) == true + @test (@areequal r 2 * a baz) == false + @test (@areequal r 1 * a mag) == false + @test (@areequal r 2 * a mag) == true +end + + +comm_monoid = @commutative_monoid (*) 1 +@testset "Basic Equalities - Commutative Monoid" begin + @test true == (@areequal comm_monoid a * (c * (1 * d)) c * (1 * (d * a))) + @test true == (@areequal comm_monoid x * y y * x) + @test true == (@areequal comm_monoid (x * x) * (x * 1) x * (x * x)) +end + + +comm_group = @commutative_group (+) 0 inv +t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) + + +@testset "Basic Equalities - Comm. Monoid, Abelian Group, Distributivity" begin + @test true == (@areequal t (a * b) + (a * c) a * (b + c)) + @test true == (@areequal t a * (c * (1 * d)) c * (1 * (d * a))) + @test true == (@areequal t a + (b * (c * d)) ((d * c) * b) + a) + @test true == (@areequal t (x + y) * (a + b) ((a * (x + y)) + b * (x + y)) ((x * (a + b)) + y * (a + b))) + @test true == (@areequal t (((x * a + x * b) + y * a) + y * b) (x + y) * (a + b)) + @test true == (@areequal t a + (b * (c * d)) ((d * c) * b) + a) + @test true == (@areequal t a + inv(a) 0 (x * y) + inv(x * y) 1 * 0) +end + + +@testset "Basic Equalities - False statements" begin + @test falseormissing(@areequal t (a * b) + (a * c) a * (b + a)) + @test falseormissing(@areequal t (a * c) + (a * c) a * (b + c)) + @test falseormissing(@areequal t a * (c * c) c * (1 * (d * a))) + @test falseormissing(@areequal t c + (b * (c * d)) ((d * c) * b) + a) + @test falseormissing(@areequal t (x + y) * (a + c) ((a * (x + y)) + b * (x + y))) + @test falseormissing(@areequal t ((x * (a + b)) + y * (a + b)) (x + y) * (a + c)) + @test falseormissing(@areequal t (((x * a + x * b) + y * a) + y * b) (x + y) * (a + x)) + @test falseormissing(@areequal t a + (b * (c * a)) ((d * c) * b) + a) + @test falseormissing(@areequal t a + inv(a) a) + @test falseormissing(@areequal t (x * y) + inv(x * y) 1) +end + +# Issue 21 +simp_theory = @theory begin + sin() => :foo +end +g = EGraph(:(sin())) +saturate!(g, simp_theory) +@test extract!(g, astsize) == :foo + +module Bar +foo = 42 +export foo +using Metatheory + +t = @theory begin + :woo => foo +end +export t +end + +module Foo +foo = 12 +using Metatheory + +t = @theory begin + :woo => foo +end +export t +end + + +g = EGraph(:woo); +saturate!(g, Bar.t); +saturate!(g, Foo.t); +foo = 12 + +@testset "Different modules" begin + @test @areequalg g t 42 12 +end + + +comm_monoid = @theory begin + ~a * ~b --> ~b * ~a + ~a * 1 --> ~a + ~a * (~b * ~c) --> (~a * ~b) * ~c + ~a::Number * ~b::Number => ~a * ~b +end + +G = EGraph(:(3 * 4)) +@testset "Basic Constant Folding Example - Commutative Monoid" begin + @test (true == @areequalg G comm_monoid 3 * 4 12) + @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2) +end + + +@testset "Basic Constant Folding Example 2 - Commutative Monoid" begin + ex = :(a * 3 * b * 4) + G = EGraph(ex) + @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) +end + +@testset "Type Assertions in Ematcher" begin + some_theory = @theory begin + ~a * ~b --> ~b * ~a + ~a::Number * ~b::Number --> sin(~a, ~b) + ~a::Int64 * ~b::Int64 --> cos(~a, ~b) + ~a * (~b * ~c) --> (~a * ~b) * ~c + end + + g = EGraph(:(2 * 3)) + saturate!(g, some_theory) + + @test true == areequal(g, some_theory, :(2 * 3), :(sin(2, 3))) + @test true == areequal(g, some_theory, :(sin(2, 3)), :(cos(3, 2))) +end + +Base.iszero(ec::EClass) = ENodeLiteral(0) ∈ ec + +@testset "Predicates in Ematcher" begin + some_theory = @theory begin + ~a::iszero * ~b --> 0 + ~a * ~b --> ~b * ~a + end + + g = EGraph(:(2 * 3)) + saturate!(g, some_theory) + + @test true == areequal(g, some_theory, :(a * b * 0), 0) +end + +@testset "Inequalities" begin + failme = @theory p begin + p ≠ !p + :foo == !:foo + :foo --> :bazoo + :bazoo --> :wazoo + end + + g = EGraph(:foo) + report = saturate!(g, failme) + @test report.reason === :contradiction +end diff --git a/test/cas/cas_infer.jl b/test/cas/cas_infer.jl deleted file mode 100644 index 35e2e748..00000000 --- a/test/cas/cas_infer.jl +++ /dev/null @@ -1,59 +0,0 @@ -using Metatheory -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers -using TermInterface - -abstract type TypeAnalysis <: AbstractAnalysis end - -function EGraphs.make(an::Type{TypeAnalysis}, g::EGraph, n::ENodeTerm) - Any -end - -function EGraphs.make(an::Type{TypeAnalysis}, g::EGraph, n::ENodeLiteral) - v = n.value - if v == :im - typeof(im) - else - typeof(v) - end -end - -function EGraphs.make(an::Type{TypeAnalysis}, g::EGraph, n::ENodeTerm{Expr}) - if exprhead(n) != :call - # println("$n is not a call") - t = Any - # println("analyzed type of $n is $t") - return t - end - sym = operation(n) - if !(sym isa Symbol) - # println("head $sym is not a symbol") - t = Any - # println("analyzed type of $n is $t") - return t - end - - symval = getfield(@__MODULE__, sym) - child_classes = map(x -> g[x], arguments(n)) - child_types = Tuple(map(x -> getdata(x, an, Any), child_classes)) - - # t = t_arr[1] - t = Core.Compiler.return_type(symval, child_types) - - if t == Union{} - throw(MethodError(symval, child_types)) - end - # println("analyzed type of $n is $t") - return t -end - -EGraphs.join(an::Type{TypeAnalysis}, from, to) = typejoin(from, to) - -EGraphs.islazy(x::Type{TypeAnalysis}) = true - -function infer(e) - g = EGraph(e) - analyze!(g, TypeAnalysis) - getdata(g[g.root], TypeAnalysis) -end diff --git a/test/cas/cas_simplify.jl b/test/cas/cas_simplify.jl deleted file mode 100644 index e391436a..00000000 --- a/test/cas/cas_simplify.jl +++ /dev/null @@ -1,85 +0,0 @@ -using Metatheory -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers -using TermInterface - -function customlt(x,y) - if typeof(x) == Expr && typeof(y) == Expr - false - elseif typeof(x) == typeof(y) - isless(x,y) - elseif x isa Symbol && y isa Number - false - elseif x isa Expr && y isa Number - false - elseif x isa Expr && y isa Symbol - false - else true end -end - -canonical_t = @theory begin - # restore n-arity - (x * x) => x^2 - (x^n::Number * x) => x^(n+1) - (x * x^n::Number) => x^(n+1) - (x + (+)(ys...)) => +(x,ys...) - ((+)(xs...) + y) => +(xs..., y) - (x * (*)(ys...)) => *(x,ys...) - ((*)(xs...) * y) => *(xs..., y) - - (*)(xs...) |> Expr(:call, :*, sort!(xs; lt=customlt)...) - (+)(xs...) |> Expr(:call, :+, sort!(xs; lt=customlt)...) -end - - -function simplcost(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis}) - cost = 0 + arity(n) - if operation(n) == :∂ - cost += 20 - end - for id ∈ arguments(n) - eclass = g[id] - !hasdata(eclass, an) && (cost += Inf; break) - cost += last(getdata(eclass, an)) - end - return cost -end - -simplcost(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) = 0 - -function simplify(ex; steps=4) - params = SaturationParams( - scheduler=ScoredScheduler, - eclasslimit=5000, - timeout=7, - schedulerparams=(1000,5, Schedulers.exprsize), - #stopwhen=stopwhen, - ) - hist = UInt64[] - push!(hist, hash(ex)) - for i ∈ 1:steps - g = EGraph(ex) - saturate!(g, cas, params) - ex = extract!(g, simplcost) - ex = rewrite(ex, canonical_t) - if !TermInterface.istree(typeof(ex)) - return ex - end - if hash(ex) ∈ hist - println("loop detected $ex") - return ex - end - println(ex) - push!(hist, hash(ex)) - end - # println(res) - # for (id, ec) ∈ g.classes - # println(id, " => ", collect(ec.nodes)) - # println("\t\t", getdata(ec, ExtractionAnalysis{astsize})) - # end - -end -macro simplify(ex) - Meta.quot(simplify(ex)) -end \ No newline at end of file diff --git a/test/cas/cas_theory.jl b/test/cas/cas_theory.jl deleted file mode 100644 index ecc6f3e8..00000000 --- a/test/cas/cas_theory.jl +++ /dev/null @@ -1,86 +0,0 @@ -## Theory for CAS -using Metatheory -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers -using TermInterface - -mult_t = @commutative_monoid (*) 1 -plus_t = @commutative_monoid (+) 0 - -minus_t = @theory begin - # TODO Jacques Carette's post in zulip chat - a - a => 0 - a - b => a + (-1*b) - -a => -1 * a - a + (-b) => a + (-1*b) -end - - -mulplus_t = @theory begin - # TODO FIXME this rules improves performance and avoids commutative - # explosion of the egraph - a + a => 2 * a - 0 * a => 0 - a * 0 => 0 - a * (b + c) == ((a*b) + (a*c)) - a + (b * a) => ((b+1)*a) -end - -pow_t = @theory begin - (y^n) * y => y^(n+1) - x^n * x^m == x^(n+m) - (x * y)^z == x^z * y^z - (x^p)^q == x^(p*q) - x^0 => 1 - 0^x => 0 - 1^x => 1 - x^1 => x - x * x => x^2 - inv(x) == x^(-1) -end - -div_t = @theory begin - x / 1 => x - # x / x => 1 TODO SIGN ANALYSIS - x / (x / y) => y - x * (y / x) => y - x * (y / z) == (x * y) / z - x^(-1) == 1 / x -end - -trig_t = @theory begin - sin(θ)^2 + cos(θ)^2 => 1 - sin(θ)^2 - 1 => cos(θ)^2 - cos(θ)^2 - 1 => sin(θ)^2 - tan(θ)^2 - sec(θ)^2 => 1 - tan(θ)^2 + 1 => sec(θ)^2 - sec(θ)^2 - 1 => tan(θ)^2 - - cot(θ)^2 - csc(θ)^2 => 1 - cot(θ)^2 + 1 => csc(θ)^2 - csc(θ)^2 - 1 => cot(θ)^2 -end - -# Dynamic rules -fold_t = @theory begin - -(a::Number) |> -a - a::Number + b::Number |> a + b - a::Number * b::Number |> a * b - a::Number ^ b::Number |> begin b < 0 && a isa Int && (a = float(a)) ; a^b end - a::Number / b::Number |> a/b -end - -using Calculus: differentiate -diff_t = @theory begin - ∂(y, x::Symbol) |> begin - z = extract!(_egraph, simplcost; root=y.id) - @show z - zd = differentiate(z, x) - @show zd - zd - end -end - -cas = fold_t ∪ mult_t ∪ plus_t ∪ minus_t ∪ - mulplus_t ∪ pow_t ∪ div_t ∪ trig_t ∪ diff_t diff --git a/test/cas/test_cas.jl b/test/cas/test_cas.jl deleted file mode 100644 index 3ffea9fb..00000000 --- a/test/cas/test_cas.jl +++ /dev/null @@ -1,59 +0,0 @@ -using Test -include("cas_theory.jl") -include("cas_simplify.jl") - -@test :(4a) == @simplify 2a + a + a -@test :(a*b*c) == @simplify a * c * b -@test :(2x) == @simplify 1 * x * 2 -@test :((a*b)^2) == @simplify (a*b)^2 -@test :((a*b)^6) == @simplify (a^2*b^2)^3 -@test :(a+b+d) == @simplify a + b + (0*c) + d -@test :(a+b) == @simplify a + b + (c*0) + d - d -@test :(a) == @simplify (a + d) - d -@test :(a + b + d) == @simplify a + b * c^0 + d -@test :(a * b * x ^ (d+y)) == @simplify a * x^y * b * x^d -@test :(a * b * x ^ 74103) == @simplify a * x^(12 + 3) * b * x^(42^3) - -@test 1 == @simplify (x+y)^(a*0) / (y+x)^0 -@test 2 == @simplify cos(x)^2 + 1 + sin(x)^2 -@test 2 == @simplify cos(y)^2 + 1 + sin(y)^2 -@test 2 == @simplify sin(y)^2 + cos(y)^2 + 1 - -@test :(y + sec(x)^2 ) == @simplify 1 + y + tan(x)^2 -@test :(y + csc(x)^2 ) == @simplify 1 + y + cot(x)^2 - - - -# @simplify ∂(x^2, x) - -@time @simplify ∂(x^(cos(x)), x) - -@test :(2x^3) == @simplify x * ∂(x^2, x) * x - -# @simplify ∂(y^3, y) * ∂(x^2 + 2, x) / y * x - -# @simplify (6 * x * x * y) - -# @simplify ∂(y^3, y) / y - -# # ex = :( ∂(x^(cos(x)), x) ) -# ex = :( (6 * x * x * y) ) -# g = EGraph(ex) -# saturate!(g, cas) -# g.classes -# extract!(g, simplcost; root=g.root) - -# params = SaturationParams( -# scheduler=BackoffScheduler, -# eclasslimit=5000, -# timeout=7, -# schedulerparams=(1000,5), -# #stopwhen=stopwhen, -# ) - -# ex = :((x+y)^(a*0) / (y+x)^0) -# g = EGraph(ex) -# @profview println(saturate!(g, cas, params)) - -# ex = extract!(g, simplcost) -# ex = rewrite(ex, canonical_t; clean=false) diff --git a/test/cas/test_infer.jl b/test/cas/test_infer.jl deleted file mode 100644 index 1e5bdc19..00000000 --- a/test/cas/test_infer.jl +++ /dev/null @@ -1,14 +0,0 @@ -using Test - -include("cas_infer.jl") - -ex1 = :(cos(1 + 3.0) + 4 + (4-4im)) -ex2 = :("ciao" * 2) -ex3 = :("ciao" * " mondo") - -@test ComplexF64 == infer(ex1) -@test_throws MethodError infer(ex2) -@test String == infer(ex3) - - - diff --git a/test/category/catlab.jl b/test/category/catlab.jl deleted file mode 100644 index cbc9364c..00000000 --- a/test/category/catlab.jl +++ /dev/null @@ -1,327 +0,0 @@ -using Catlab -using Catlab.Theories -using Catlab.Syntax - -using Metatheory, Metatheory.EGraphs -using TermInterface - - -abstract type CatType end -struct ObType <: CatType - ob - mod -end -struct HomType <: CatType - dom - codom - mod -end - -# Custom type APIs for the GATExpr -using Metatheory.TermInterface -TermInterface.operation(t::ObExpr) = :call -TermInterface.arguments(t::ObExpr) = [head(t), t.args...] -TermInterface.operation(t::HomExpr) = :call -TermInterface.arguments(t::HomExpr) = [head(t), t.args...] - -# Type information will be stored in the metadata -function TermInterface.metadata(t::HomExpr) - return HomType(t.type_args[1], t.type_args[2], typeof(t).name.module) -end -TermInterface.metadata(t::ObExpr) = ObType(t, typeof(t).name.module) -TermInterface.istree(t::GATExpr) = true -TermInterface.arity(t::GATExpr) = length(arguments(t)) - -struct CatlabAnalysis <: AbstractAnalysis end -function EGraphs.make(an::Type{CatlabAnalysis}, g::EGraph, n::ENode{T}) where T - !(T <: GATExpr) && return T - return metadata(n) -end -EGraphs.join(an::Type{CatlabAnalysis}, from, to) = from -EGraphs.islazy(x::Type{CatlabAnalysis}) = false - -function infer(t::GATExpr) - g = EGraph(t) - analyze!(g, CatlabAnalysis) - getdata(g[g.root], CatlabAnalysis) -end - -function EGraphs.extractnode(g::EGraph, n::ENode{T}, extractor::Function) where {T <: ObExpr} - @assert n.head == :call - return metadata(n).ob -end - -function EGraphs.extractnode(g::EGraph, n::ENode{T}, extractor::Function) where {T <: HomExpr} - @assert n.head == :call - nargs = extractor.(arguments(n)) - nmeta = metadata(n) - return nmeta.mod.Hom{nargs[1]}(nargs[2:end], GATExpr[nmeta.dom, nmeta.codom]) -end - -# ============================================================================== - -using MatchCore - -datasym(x::Symbol) = Symbol(String(x) * "_data") -extrsym(x::Symbol) = Symbol(String(x) * "_extr") - -function build_rhs(x::Expr, pvars, mod) - if Meta.isexpr(x, :call) - if x.args[1] == :munit && length(x.args) == 1 - mod.munit(mod.Ob) - else - Expr(x.head, getfield(mod, x.args[1]), map(y -> build_rhs(y, pvars, mod), x.args[2:end])...) - end - else - Expr(x.head, map(y -> build_rhs(y, pvars, mod), x.args)...) - end -end -function build_rhs(x, pvars, mod) - if x ∈ pvars - extrsym(x) - else - x - end -end -function gen_rule(axiom::Catlab.GAT.AxiomConstructor, mod; righttoleft=false) - # left to right - @assert axiom.name == :(==) - - ax_left = axiom.left - ax_right = axiom.right - if righttoleft - ax_left = axiom.right - ax_right = axiom.left - end - - lhs = Pattern(ax_left, mod) - l_pvars = patvars(lhs) - - rhs = build_rhs(ax_right, l_pvars, mod) - # println(rhs) - - lines = [] - - eq_ctx = Dict{Symbol, Vector{Any}}() - for patvar in l_pvars - # retrieve the catlab data - data_var = datasym(patvar) - data_expr = :($data_var = getdata($patvar, CatlabAnalysis)) - push!(lines, data_expr) - # push!(lines, :(println($data_var))) - - - extr_var = extrsym(patvar) - extr_expr = :($extr_var = extract!(_egraph, astsize; root=($patvar).id)) - push!(lines, extr_expr) - # push!(lines, :(println($extr_var))) - - - ctxval = axiom.context[patvar] - # TODO use GATTheory.types - @smatch ctxval begin - :Ob => begin - check_type_line = :(!($data_var isa ObType) && (return _lhs_expr)) - aset = get!(()->[], eq_ctx, patvar) - push!(lines, check_type_line) - push!(aset, :($(data_var).ob)) - end - :(Hom($(A::Symbol), $(B::Symbol))) => begin - aset = get!(()->[], eq_ctx, A) - bset = get!(()->[], eq_ctx, B) - push!(aset, :($(data_var).dom)) - push!(bset, :($(data_var).codom)) - check_type_line = :(!($data_var isa HomType) && (return _lhs_expr)) - push!(lines, check_type_line) - end - _ => error("unrecognized GAT type context $patvar => $ctxval") - end - end - - for (ctxvar, eqset) in eq_ctx - if ctxvar ∉ l_pvars - push!(lines, :($ctxvar = $(eqset[1]))) - end - end - - # conjunction of equalities needed - conjunction = [] - - for (ctxvar, eqset) in eq_ctx - unique!(eqset) - c = [] - if length(eqset) < 2 - continue - end - fst = first(eqset) - for other in eqset[2:end] - push!(c, :($fst == $other)) - end - append!(conjunction, c) - end - - - - if !isempty(conjunction) - conj_expr = foldl((x,y) -> :($x && $y), conjunction) - - the_big_if = :(if $conj_expr - # WORKAROUND FOR RuntimeGeneratedFunctions.jl `id` bug - # return $(evalmod).eval($(Meta.quot(ax_right))) - return $rhs - else - return _lhs_expr end) |> Metatheory.rmlines - push!(lines, the_big_if) - else - push!(lines, :(return $rhs)) - end - - block = Expr(:block, lines...) - - DynamicRule(lhs, block) -end - -# test -tt = theory(SymmetricMonoidalCategory) -ax = tt.axioms[2] - -gen_rule(tt.axioms[2], @__MODULE__) - -# Generate a theory from a syntax system -function gen_theory(m::Module) - gat_theory = theory(m.theory()) - mt_theory = Rule[] - for axiom in gat_theory.axioms - push!(mt_theory, gen_rule(axiom, m)) - push!(mt_theory, gen_rule(axiom, m, righttoleft=true)) - end - mt_theory -end - - -# ==================================================== -# TEST - -# WE HAVE TO REDEFINE THE SYNTAX TO AVOID ASSOCIATIVITY AND N-ARY FUNCTIONS -import Catlab.Theories: id, compose, otimes, ⋅, braid, σ, ⊗, Ob, Hom -@syntax SMC{ObExpr,HomExpr} SymmetricMonoidalCategory begin -end - -function simplify(ex, syntax) - t = gen_theory(syntax) - g = EGraph(ex) - analyze!(g, CatlabAnalysis) - params=SaturationParams(timeout=3) - saturate!(g, t, params) - extract!(g, astsize) -end - -A, B, C = Ob(SMC, :A, :B, :C) -f = Hom(:f, A, B) - -ex = f ⋅ id(B) -simplify(ex, SMC) == f - -ex = id(A) ⋅ id(A) ⋅ f ⋅ id(B) -simplify(ex, SMC) == f - -ex = σ(A,B) ⋅ σ(B,A) -simplify(ex, SMC) == id(A ⊗ B) - - -# ====================================================== -# another test - -using Catlab.Graphics - -l = (σ(C,B) ⊗ id(A)) ⋅ (id(B) ⊗ σ(C,A)) ⋅ (σ(B,A) ⊗ id(C)) -r = (id(C) ⊗ σ(B,A)) ⋅ (σ(C,A) ⊗ id(B)) ⋅ (id(A) ⊗ σ(C,B)) - -to_graphviz(l) -to_graphviz(r) - -g = EGraph() -analyze!(g, CatlabAnalysis) -l_ec, _ = addexpr!(g, l) -r_ec, _ = addexpr!(g, r) - -in_same_class(g, l_ec, r_ec) - -saturate!(g, gen_theory(SMC), SaturationParams(timeout=1, eclasslimit=6000)) - -ll = extract!(g, astsize; root=l_ec.id) -rr = extract!(g, astsize; root=r_ec.id) - -# ====================================================== -# another test - -# WE HAVE TO REDEFINE THE SYNTAX TO AVOID ASSOCIATIVITY AND N-ARY FUNCTIONS -import Catlab.Theories: id, compose, otimes, ⋅, braid, σ, ⊗, Ob, Hom, pair, proj1, proj2 -@syntax BPC{ObExpr,HomExpr} BiproductCategory begin -end -A, B, C = Ob(BPC, :A, :B, :C) -f = Hom(:f, A, B) -k = Hom(:k, B, C) - - -l = Δ(A) ⋅ (delete(A) ⊗ id(A)) -r = id(A) - - -g = EGraph(l) -analyze!(g, CatlabAnalysis) - - -saturate!(g, gen_theory(BPC), SaturationParams(timeout=1, eclasslimit=6000)) - -extract!(g, astsize) - -# ====================================================== -# another test - -l = σ(A, B ⊗ C) -# r = σ(B,A) ⊗ id(C) -r = (σ(A,B) ⊗ id(C)) ⋅ (id(B) ⊗ σ(A,C)) -# r = σ(B ⊗ C, A) - -to_composejl(l) -to_composejl(r) - -g = EGraph(ex) -analyze!(g, CatlabAnalysis) -l_ec, _ = addexpr!(g, l) -r_ec, _ = addexpr!(g, r) - - -saturate!(g, gen_theory(SMC), SaturationParams(timeout=1, eclasslimit=6000)) - -extract!(g, astsize; root=l_ec.id) - -extract!(g, astsize; root=r_ec.id) - - - - -# ==================================================== -# TEST -cc = gen_theory(FreeCartesianCategory) - -A, B, C = Ob(FreeCartesianCategory, :A, :B, :C) -f = Hom(:f, A, B) - -g = EGraph() -analyze!(g, CatlabAnalysis) -ex = id(A) ⊗ id(B) -to_composejl(ex) - -l_ec, _ = addexpr!(g, ex) -saturate!(g, cc, SaturationParams(timeout=2)) -extract!(g, astsize; root=l_ec.id) - - -ex = pair(proj1(A, B), proj2(A, B)) -to_composejl(ex) -r_ec, _ = addexpr!(g, ex) -saturate!(g, cc) -extract!(g, astsize; root=r_ec.id) - diff --git a/test/category/catlab_simpler.jl b/test/category/catlab_simpler.jl deleted file mode 100644 index e3fb6795..00000000 --- a/test/category/catlab_simpler.jl +++ /dev/null @@ -1,224 +0,0 @@ -using Catlab -using Catlab.Theories -using Catlab.Syntax -using Metatheory, Metatheory.EGraphs - -# ============================================================ - -# GATExpr => normal Expr in MT -function gat_to_expr(ex::ObExpr{:generator}) - @assert length(ex.args) == 1 - return ex.args[1] -end -function gat_to_expr(ex::ObExpr{H}) where {H} - return Expr(:call, head(ex), map(gat_to_expr, ex.args)...) -end -function gat_to_expr(ex::HomExpr{H}) where {H} - @assert length(ex.type_args) == 2 - expr = Expr(:call, head(ex), map(gat_to_expr, ex.args)...) - type_ex = Expr(:call, :Hom, map(gat_to_expr, ex.type_args)...) - return Expr(:call, :~, expr, type_ex) -end -function gat_to_expr(ex::HomExpr{:generator}) - f = ex.args[1] - type_ex = Expr(:call, :Hom, map(gat_to_expr, ex.type_args)...) - return Expr(:call, :~, f, type_ex) -end - - -const Code = Union{Symbol, Expr} -const TTags = Dict{Code, Tuple{Code, Symbol}} - -# ============================================================ - -# infer type of morphisms and objects -# a morphism f: A → B will be typed as f ~ Hom(A, B) -# an object A will be typed as Ob(A) -function get_concrete_type_expr(theory, x::Symbol, ctx, type_tags = TTags()) - t = ctx[x] - @show(t) - # t === :Ob && (t = Expr(:call, :Ob, x)) - if t === :Ob - type_tags[x] = (x, t) - return (x, t) - else - @assert t.args[1] == :Hom - type_tags[x] = (t, t.args[1]) - return (t, t.args[1]) - end -end - -function get_concrete_type_expr(theory, x::Expr, ctx, type_tags = TTags()) - @assert exprhead(x) == :call - f = x.args[1] - rest = x.args[2:end] - # recursion case - inductive step (?) - for a in rest - (t, sort) = get_concrete_type_expr(theory, a, ctx, type_tags) - type_tags[a] = (t, sort) - println("$a ~ $t") - end - # get the corresponding TermConstructor from theory.terms - # for each arg in `rest`, instantiate the term.params with term.context - # instantiate term.typ - - (t, sort) = gat_type_inference(theory, f, [type_tags[a] for a in rest]) - type_tags[x] = (t, sort) - # println("$x ~ $(type_tags[x])") - return (t, sort) -end - -function is_context_match(t, head, args) - # t isa TermConstructor - # println(repeat("=", 30)) - # println("is_context_match") - # @show t - # @show head - # @show args - # println(repeat("=", 30)) - - # TODO fixme! - - t.name !== head && return false - n = length(t.params) - n != length(args) && return false - for i = 1:n - arg, sort = args[i] - - if t.context[t.params[i]] === :Ob - if sort !== :Ob - return false - end - else - if sort === :Ob - return false - end - end - end - return true -end - -function gat_type_inference(theory, head, args) - for t in theory.terms - if is_context_match(t, head, args) - # @show t, head, args - return gat_type_inference(t, head, args) - end - end - # @show theory, head, args - @error "can not find $(Expr(:call, head, args...)) in the theory" -end - -function gat_type_inference(t::GAT.TermConstructor, head, args) - @assert length(t.params) == length(args) && t.name === head - bindings = Dict() - - println(args) - texprs = map(first, args) - sorts = map(last, args) - - - for i = 1:length(args) - template = t.context[t.params[i]] - template === :Ob && (template = t.params[i]) - # @show template - update_bindings!(bindings, template, texprs[i]) - end - # @show bindings - r = GAT.replace_types(bindings, t) - if r.typ == :Ob - return Expr(:call, head, texprs...), r.typ - # # return Expr(:call, :Ob, Expr(:call, head, args...)) - # Expr(:call, head, args...) - else - @show(r.typ) - return r.typ, r.typ.args[1] - end - # end -end -function update_bindings!(bindings, template::Expr, target::Expr) - for i = 1:length(template.args) - update_bindings!(bindings, template.args[i], target.args[i]) - end -end -function update_bindings!(bindings, template, target) - bindings[template] = target -end - - -function tag_expr(x::Expr, axiom, theory) - texpr, sort = get_concrete_type_expr(theory, x, axiom.context) - start = exprhead(x) == :call ? 2 : 1 - - nargs = Any[tag_expr(y, axiom, theory) for y in x.args[start:end]] - - if start == 2 - pushfirst!(nargs, x.args[1]) - end - - z = Expr(exprhead(x), nargs...) - - (sort === :Ob) && (return z) - :($z ~ $texpr) -end - -function tag_expr(x::Symbol, axiom, theory) - (texpr, sort) = get_concrete_type_expr(theory, x, axiom.context) - (sort === :Ob) && (return x) - # return (t == x ? x : :($x ~ $t)) - return :($x ~ $texpr) -end - -# ============================================================ -# Convert Catlab Axioms to rules -# ============================================================ - -function axiom_to_rule(theory, axiom) - op = axiom.name - @assert op == :(==) - lhs = tag_expr(axiom.left, axiom, tt) |> Pattern - rhs = tag_expr(axiom.right, axiom, tt) |> Pattern - - pvars = patvars(lhs) ∪ patvars(rhs) - extravars = setdiff(pvars, patvars(lhs) ∩ patvars(rhs)) - if !isempty(extravars) - if extravars ⊆ patvars(lhs) - println(lhs) - println(rhs) - return RewriteRule(lhs, rhs) - else - return RewriteRule(rhs, lhs) - end - end - # println("$lhs == $rhs") - EqualityRule(lhs, rhs) -end - - -function gen_theory(t::Catlab.GAT.Theory) - [axiom_to_rule(t, ax) for ax in t.axioms] -end - - - -# ========================================================= -# Utility Functions -# ========================================================= - -function Base.show(io::IO, a::Catlab.GAT.AxiomConstructor) - print(io, a.left) - print(io, ' ', a.name, ' ' ) - print(io, a.right) - print(io, " where ") - n = length(a.context) - ctx = collect(a.context) - for i in 1:n - (k,v) = ctx[i] - print(io, "$k => $v") - if i !== n - print(io, ", ") - end - end -end - -ax \ No newline at end of file diff --git a/test/category/catlab_simpler_test.jl b/test/category/catlab_simpler_test.jl deleted file mode 100644 index 4dde9960..00000000 --- a/test/category/catlab_simpler_test.jl +++ /dev/null @@ -1,224 +0,0 @@ -include("catlab_simpler.jl") -using Catlab -using Catlab.Theories -using Catlab.Syntax -using Metatheory, Metatheory.EGraphs - -using Test - -# ============================================================ -# GATExpr TO TAGGED EXPR -# ============================================================ - - -# WE HAVE TO REDEFINE THE SYNTAX TO AVOID ASSOCIATIVITY AND N-ARY FUNCTIONS -import Catlab.Theories: id, compose, otimes, ⋅, braid, σ, ⊗, Ob, Hom -@syntax SMC{ObExpr,HomExpr} SymmetricMonoidalCategory begin -end - -A, B, C, D = Ob(SMC, :A, :B, :C, :D) -X, Y, Z = Ob(SMC, :X, :Y, :Z) - -f = Hom(:f, A, B) -g = Hom(:g, B, C) -h = Hom(:h, C, D) - -gat_to_expr(x) = x - -gat_to_expr(A) - -A isa ObExpr{H} where {H} - -gat_to_expr(id(Z)) == :(id(Z)~(Hom(Z,Z))) - -gat_to_expr(id(Z) ⋅ f) - -gat_to_expr(id(A ⊗ B)) - -gat_to_expr(id(A) ⊗ id(B)) - -gat_to_expr(compose(compose(f, g), h)) - -gat_to_expr(f) - -gat_to_expr(A ⊗ B) - -# BUG -gat_to_expr(otimes(f, g)) - - - -# ============================================================ -# Type tagging axioms -# ============================================================ - -A, B, C, D = Ob(SMC, :A, :B, :C, :D) -X, Y, Z = Ob(SMC, :X, :Y, :Z) - -tt = theory(SymmetricMonoidalCategory) ; -ax = tt.axioms[10] ; -get_concrete_type_expr(tt, ax.left, ax.context) -tag_expr(ax.left, ax, tt) -ax = tt.axioms[4] -get_concrete_type_expr(tt, ax.left, ax.context) -tag_expr(ax.left, ax, tt) - -tt = theory(Category) - -tag_expr(tt.axioms[1].left, tt.axioms[1], tt) == gat_to_expr(compose(compose(f, g), h)) - - -# ==================================================== - - -tt = theory(Category) - - -A, B, C, D = Ob(SMC, :A, :B, :C, :D) -X, Y, Z = Ob(SMC, :X, :Y, :Z) - -f = Hom(:f, A, B) -g = Hom(:g, B, C) -h = Hom(:h, C, D) - -rules = gen_theory(tt) -expr = gat_to_expr(id(A) ⋅ id(A) ⋅ f ⋅ id(B)) -G = EGraph(expr) -saturate!(G, rules) -@test extract!(G, astsize) == :(f ~ Hom(A,B)) - -tt = theory(SymmetricMonoidalCategory) - -rules = Rule[axiom_to_rule(tt, ax) for ax in tt.axioms] - -# push!(rules, EqualityRule( @pat(otimes(Hom(A, B), Hom(X, Y))), @pat(Hom(otimes(A, X), otimes(B, Y))) )) - -gats = [ - σ(A,B⊗C), - (σ(A,B) ⊗ id(C)) ⋅ (id(B) ⊗ σ(A,C)) -] - -exprs = [gat_to_expr(i) for i in gats] - -# push!(rules, RewriteRule(Pattern(l), Pattern(r))) -G = EGraph() - -ecs = [addexpr!(G, i) for i in exprs] - - -saturate!(G, rules) -extract!(G, astsize; root=ecs[2].id) - -@test in_same_class(G, ecs[1], ecs[2]) - - -# YANG BAXTER EQUATION - -gats = [ - (σ(A,B) ⊗ id(C)) ⋅ (id(B) ⊗ σ(C,A)) ⋅ (σ(B,C) ⊗ id(A)), - σ(A, B ⊗ C) ⋅ (σ(B,C) ⊗ id(A)), - (id(A) ⊗ σ(B,C)) ⋅ σ(A, C⊗B), - (id(A) ⊗ σ(B,C)) ⋅ (σ(A,C) ⊗ id(B)) ⋅ (id(C) ⊗ σ(A,B)) -] - -exprs = [gat_to_expr(i) for i in gats] - -# push!(rules, RewriteRule(Pattern(l), Pattern(r))) -G = EGraph() - -ecs = [addexpr!(G, i) for i in exprs] - -saturate!(G, rules, SaturationParams(timeout=1)) -extract!(G, astsize; root=ecs[2].id) - -[ in_same_class(G, ecs[i], ecs[i+1]) for i in 1:length(gats)-1 ] - - - -# ======================================================================================== - -tt = theory(CartesianCategory) -A, B, C, D = Ob(FreeCartesianCategory, :A, :B, :C, :D) -f = Hom(:f, A, B) -g = Hom(:g, B, C) -h = Hom(:h, C, D) - - -l = pair(proj1(A, B), proj2(A, B)) -r = id(A ⊗ B) - -rules = [axiom_to_rule(tt, ax) for ax in tt.axioms] - -l = gat_to_expr(l) -r = gat_to_expr(r) - -G = EGraph(l) -rc, _ = addexpr!(G, r) - -# TODO identify the rules where there are more patvars on the lhs than the rhs -# and use regular rewrite rules instead of (==) rules -saturate!(G, rules) -extract!(G, astsize) -extract!(G, astsize; root=rc.id) - -l = f ⋅ delete(B) - -G = EGraph(gat_to_expr(l)) -saturate!(G, rules) -extract!(G, astsize) -#TODO expr to gat - -# ==================================================== -# TEST -mu = FreeCartesianCategory.munit(FreeCartesianCategory.Ob) - -l = σ(A, mu) -r = id(A) - -rules = [axiom_to_rule(tt, ax) for ax in tt.axioms] - -l = gat_to_expr(l) -r = gat_to_expr(r) - -G = EGraph(l) -rc, _ = addexpr!(G, r) - -# TODO identify the rules where there are more patvars on the lhs than the rhs -# and use regular rewrite rules instead of (==) rules -saturate!(G, rules) -extract!(G, astsize; root=rc.id) -extract!(G, astsize) - -l = σ(A, B) ⋅ σ(B, A) -r = id(A ⊗ B) - -rules = [axiom_to_rule(tt, ax) for ax in tt.axioms] - -l = gat_to_expr(l) -r = gat_to_expr(r) - -G = EGraph(l) -rc, _ = addexpr!(G, r) - -# TODO identify the rules where there are more patvars on the lhs than the rhs -# and use regular rewrite rules instead of (==) rules -saturate!(G, rules) -extract!(G, astsize; root=rc.id) == extract!(G, astsize) - -l = σ(A, mu) -r = id(A) - -rules = [axiom_to_rule(tt, ax) for ax in tt.axioms] - -l = gat_to_expr(l) -r = gat_to_expr(r) - -G = EGraph(l) -rc, _ = addexpr!(G, r) - -# TODO identify the rules where there are more patvars on the lhs than the rhs -# and use regular rewrite rules instead of (==) rules -saturate!(G, rules) -# extract!(G, astsize; root=rc.id) == -extract!(G, astsize) - diff --git a/test/category/test_cat.jl b/test/category/test_cat.jl deleted file mode 100644 index d8483da0..00000000 --- a/test/category/test_cat.jl +++ /dev/null @@ -1,88 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -using Metatheory.Library -using Rewriters - -using Test - -Cat = @theory begin - id(A).(A→A) ⋅ f.(A→B) == f.(A→B) - f.(A→B)⋅id(B).(B→B) == f.(A→B) - (f.(A→B)⋅g.(B→C))⋅h.(C→D) == f.(A→B)⋅(g.(B→C)⋅h.(C→D)) -end - -# DOES NOT FIXPOINT! -tag_matcher_t = @theory begin - sigma(A, B) => σ(A, B).(A⊗B→B⊗A) - 1(A) => id(A).(A→A) - :f |> :(f.(A → B)) - :g |> :(g.(B → C)) - :h |> :(h.(B → C)) - :j |> :(j.(B → B)) - :k |> :(j.(C → D)) -end - -tag_matcher(x) = Chain(tag_matcher_t)(x) - -# FIXME TAGGING skip . operator -function tag(x) - r = Postwalk(PassThrough(If(x -> !istree(x) || operation(x) != :., tag_matcher)))(x) - println("tagged $x to $(r)") - r -end - -macro areequal_tag(t, exprs...) - exprs = map(tag, exprs) - - :(areequal($t, $(exprs)...)) -end - - -@testset "Cat" begin - @test @areequal_tag Cat 1(B) id(B).(B→B) - @test @areequal_tag Cat (1(A) ⋅ f) f - @test @areequal_tag Cat (1(A)⋅f) f - @test @areequal_tag Cat (f⋅1(B)) f - @test @areequal_tag Cat (f⋅j)⋅j f⋅(j⋅j) - @test @areequal_tag Cat (f⋅j)⋅h f⋅(j⋅h) -end - -MonCat = Cat ∪ monoid(:(⊗), :(:munit)) ∪ @theory begin - f.(A→B) ⊗ g.(C→D) == (f⊗g).(A⊗C→B⊗D) - id(A⊗B).(A⊗B→A⊗B) == id(A).(A→A)⊗id(B).(B→B) -end - -@testset "MonCat" begin - @test @areequal_tag MonCat 1(A)⊗1(B) 1(A⊗B) - @test @areequal_tag MonCat (f ⊗ g) ((f⊗g).(A⊗B→B⊗C)) - @test @areequal_tag MonCat 1(A⊗B)⋅(f⊗g) (f⊗g) - @test @areequal_tag MonCat ((f ⊗ g)⋅1(B⊗C)) (f⊗g) - @test @areequal_tag MonCat 1(A⊗B)⋅(f⊗g) (f⊗g).(A⊗B→B⊗C) -end - -SymMonCat = MonCat ∪ @theory begin - σ(A, B).(A⊗B→B⊗A) ⋅ σ(B, A).(B⊗A→A⊗B) == id(A⊗B).(A⊗B→A⊗B) - - σ(A, B⊗C).(A⊗(B⊗C)→(B⊗C)⊗A) == σ(A⊗B, C).((A⊗B)⊗C→C⊗(A⊗B)) - (f.(A→B)⊗g.(C→D))⋅σ(B,D).(B⊗D→D⊗B) == σ(A,C).(A⊗C→C⊗A) ⋅ (g.(C→D)⊗f.(A→B)) - σ(A,C).(A⊗C→C⊗A) ⋅ (g.(C→D)⊗f.(A→B)) == (f.(A→B)⊗g.(C→D))⋅σ(B,D).(B⊗D→D⊗B) -end - -@testset "SymMonCat" begin - @test @areequal_tag SymMonCat 1(A)⊗1(B) 1(A⊗B) - @test @areequal_tag SymMonCat (f ⊗ g) (f ⊗ g) - @test @areequal_tag SymMonCat 1(A⊗B) ⋅(f ⊗ g) (f ⊗ g) - @test @areequal_tag SymMonCat (f ⊗ g) ⋅ 1(B⊗C) (f ⊗ g) - @test @areequal_tag SymMonCat 1(A⊗B)⋅(f⊗g) (f⊗g).(A⊗B→B⊗C) - - # println("==========================================") - - @test falseormissing(@areequal_tag SymMonCat sigma(A,B)⋅sigma(A,B) 1(A⊗B)) - @test @areequal_tag SymMonCat sigma(A,B) σ(A,B).(A⊗B→B⊗A) - @test @areequal_tag SymMonCat sigma(A,B)⋅sigma(B,A) 1(A⊗B) - @test @areequal_tag SymMonCat sigma(B,A)⋅sigma(A,B) 1(B⊗A) - @test @areequal_tag SymMonCat (f ⊗ k)⋅sigma(B,D) sigma(A,C)⋅(k ⊗ f) - @test @areequal_tag SymMonCat sigma(A,A)⋅(f.(A→A)⊗g.(A→A))⋅sigma(A,A) g.(A→A)⊗f.(A→A) - @test @areequal_tag SymMonCat sigma(B,A)⋅(f.(A→A)⊗g.(B→A))⋅sigma(A,A) g.(B→A)⊗f.(A→A) - @test @areequal_tag SymMonCat sigma(B,A)⋅(f.(A→C)⊗g.(B→D))⋅sigma(C,D) g.(B→D)⊗f.(A→C) -end diff --git a/test/category/test_cat_guard.jl b/test/category/test_cat_guard.jl deleted file mode 100644 index 35f635d8..00000000 --- a/test/category/test_cat_guard.jl +++ /dev/null @@ -1,214 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -# Description here: -# https://www.philipzucker.com/metatheory-progress/ -# https://github.com/AlgebraicJulia/Catlab.jl/blob/ce2fde9c63a8aab65cf2a7697f43cd24e5e00b3a/src/theories/Monoidal.jl#L127 - -cat_rules = @theory begin - f ⋅ id(b) => f - id(a) ⋅ f => f - f == f ⋅ id(cod(type(f))) - f == id(dom(type(f))) ⋅ f - - a ⊗ₒ munit() == a - munit() ⊗ₒ a == a - - f ⋅ (g ⋅ h) == (f ⋅ g) ⋅ h -end - -monoidal_rules = @theory begin - id(munit()) ⊗ₘ f => f - f ⊗ₘ id(munit()) => f - a ⊗ₒ (b ⊗ₒ c) == (a ⊗ₒ b) ⊗ₒ c - f ⊗ₘ (h ⊗ₘ j) == (f ⊗ₘ h) ⊗ₘ j - id(a ⊗ₒ b) == id(a) ⊗ₘ id(b) - (f ⋅ g) ⊗ₘ (p ⋅ q) => (f ⊗ₘ p) ⋅ (g ⊗ₘ q) -end - -push!(monoidal_rules, - MultiPatRewriteRule(@pat((f ⊗ₘ p) ⋅ (g ⊗ₘ q)), @pat((f ⋅ g) ⊗ₘ (p ⋅ q)), - [PatEquiv(@pat(cod(type(f))), @pat(dom(type(g)))), PatEquiv(@pat(cod(type(p))), @pat(dom(type(q))))]) -) - - -sym_rules = @theory begin - σ(a, b) ⋅ σ(b, a) == id(a ⊗ₒ b) - (σ(a, b) ⊗ₘ id(c)) ⋅ (id(b) ⊗ₘ σ(a, c)) == σ(a, b ⊗ₒ c) - (id(a) ⊗ₘ σ(b, c)) ⋅ (σ(a, c) ⊗ₘ id(b)) == σ(a ⊗ₒ b, c) - - # these rules arer not catlab - σ(a, munit()) == id(a) - σ(munit(), a) == id(a) - σ(munit(), munit()) => id(munit() ⊗ₒ munit()) - -end - -push!(sym_rules, - MultiPatRewriteRule(@pat((f ⊗ₘ h) ⋅ σ(a, b)), @pat(σ(dom(type(f)), dom(type(h))) ⋅ (h ⊗ₘ f)), - [PatEquiv(@pat(cod(type(f))), @pat(a)), PatEquiv(@pat(cod(type(h))), @pat(b))]), - - - MultiPatRewriteRule(@pat(σ(c, d) ⋅ (h ⊗ₘ f)), @pat((f ⊗ₘ h) ⋅ σ(cod(type(f)), cod(type(h)))), - [PatEquiv(@pat(dom(type(f))), PatVar(:c)), PatEquiv(@pat(dom(type(f))), PatVar(:d))]) -) - - -diag_rules = @theory begin - Δ(a) ⋅ (⋄(a) ⊗ₘ id(a)) == id(a) - Δ(a) ⋅ (id(a) ⊗ₘ ⋄(a)) == id(a) - Δ(a) ⋅ σ(a, a) == Δ(a) - - (Δ(a) ⊗ₘ Δ(b)) ⋅ (id(a) ⊗ₘ σ(a, b) ⊗ₘ id(b)) == Δ(a ⊗ₒ b) - - Δ(a) ⋅ (Δ(a) ⊗ₘ id(a)) == Δ(a) ⋅ (id(a) ⊗ₘ Δ(a)) - ⋄(a ⊗ₒ b) == ⋄(a) ⊗ₘ ⋄(b) - - Δ(munit()) == id(munit()) - ⋄(munit()) == id(munit()) -end - - -cart_rules = @theory begin - - pair(f, k) == Δ(dom(type(f))) ⋅ (f ⊗ₘ k) - proj1(a, b) == id(a) ⊗ₘ ⋄(b) - proj2(a, b) == ⋄(a) ⊗ₘ id(b) - f ⋅ ⋄(b) => ⋄(dom(type(f))) - # Has to invent f. Hard to fix. - # ⋄(b) => f ⋅ ⋄(b) - - f ⋅ Δ(b) => Δ(dom(type(f))) ⋅ (f ⊗ₘ f) - Δ(a) ⋅ (f ⊗ₘ f) => f ⋅ Δ(cod(type(f))) -end - -push!(cart_rules, -MultiPatRewriteRule(@pat(Δ(a) ⋅ (f ⊗ₘ k)), @pat(pair(f,k)), -[PatEquiv(@pat(dom(type(f))), @pat(dom(type(k))))]) -) - - -typing_rules = @theory begin - dom(hom(a, b)) => a - cod(hom(a, b)) => b - type(id(a)) => hom(a, a) - type(f ⋅ g) => hom(dom(type(f)), cod(type(g))) - type(f ⊗ₘ g) => hom(dom(type(f)) ⊗ₒ dom(type(g)), cod(type(f)) ⊗ₒ cod(type(g))) - type(a ⊗ₒ b) => :ob - type(munit()) => :ob - type(σ(a, b)) => hom(a ⊗ₒ b, b ⊗ₒ a) - type(⋄(a)) => hom(a, munit()) - type(Δ(a)) => hom(a, a ⊗ₒ a) - type(pair(f, g)) => hom(dom(type(f)), cod(type(f)) ⊗ₒ cod(type(g))) - type(proj1(a, b)) => hom(a ⊗ₒ b, a) - type(proj2(a, b)) => hom(a ⊗ₒ b, b) -end - - -rules = typing_rules ∪ cat_rules ∪ monoidal_rules ∪ sym_rules ∪ diag_rules ∪ cart_rules ∪ typing_rules - - -# A goofy little helper macro -# Taking inspiration from Lean/Dafny/Agda -using Metatheory.EGraphs.Schedulers -macro calc(e...) - theory = eval(e[1]) - e = rmlines(e[2]) - @assert e.head == :block - - trues = Bool[] - - for (a, b) in zip(e.args[1:end-1], e.args[2:end]) - # println(a, " =? ", b) - params = SaturationParams( - timeout=12, - eclasslimit=12000, - scheduler=BackoffScheduler - ) - g = EGraph() - ta, _ = addexpr!(g, :(type(a))) - tao, _ = addexpr!(g, :(:ob)) - merge!(g, ta.id, tao.id) - - eq = @time areequal(g, theory, a, b; params=params) - push!(trues, (eq !== missing) && eq) - println(eq) - # WOULD WORK IF COST FUNCTION IS SIMILARITY TO OTHER FUN - # if !eq - # i = 0 - # while !eq && i < 4 - # ga = EGraph(a); gb = EGraph(b) - # ga_extr = addanalysis!(ga, ExtractionAnalysis, astsize) - # gb_extr = addanalysis!(gb, ExtractionAnalysis, astsize) - # @time saturate!(ga, theory; timeout = 9) - # @time saturate!(gb, theory; timeout = 9) - # - # new_a = extract!(ga, ga_extr) - # new_b = extract!(gb, gb_extr) - # println("i = $i \nnew a = $new_a \nnew b = $new_b") - # eq = @time areequal(theory, new_a, new_b; timeout = 9) - # i += 1 - # end - # - # if !eq && i == 4 - # return false - # end - # end - end - all(trues) -end - -@calc rules begin - id(a ⊗ₒ b) - id(a) ⊗ₘ id(b) - (Δ(a) ⋅ (id(a) ⊗ₘ ⋄(a))) ⊗ₘ id(b) - (Δ(a) ⋅ (id(a) ⊗ₘ ⋄(a))) ⊗ₘ (Δ(b) ⋅ (⋄(b) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ ⋄(a)) ⊗ₘ (⋄(b) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ (id(a) ⊗ₘ (⋄(a) ⊗ₘ ⋄(b)) ⊗ₘ id(b)) - (Δ(a) ⊗ₘ Δ(b)) ⋅ (id(a) ⊗ₘ ((⋄(a) ⊗ₘ ⋄(b)) ⋅ σ(munit(), munit())) ⊗ₘ id(b)) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⋅ id(a)) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b)) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ σ(a, b) ⊗ₘ id(b)) ⋅ (id(a) ⊗ₘ (⋄(b) ⊗ₘ ⋄(a)) ⊗ₘ id(b))) - Δ(a ⊗ₒ b) ⋅ (id(a) ⊗ₘ (⋄(b) ⊗ₘ ⋄(a)) ⊗ₘ id(b)) - Δ(a ⊗ₒ b) ⋅ (id(a) ⊗ₘ (⋄(b) ⊗ₘ ⋄(a)) ⊗ₘ id(b)) - Δ(a ⊗ₒ b) ⋅ (proj1(a, b) ⊗ₘ proj2(a, b)) - pair(proj1(a, b), proj2(a, b)) -end - -# shorter proof also accepted -@calc rules begin - id(a ⊗ₒ b) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b))) - pair(proj1(a, b), proj2(a, b)) -end - -# shorter proof not quite there -@calc rules begin - id(a ⊗ₒ b) - pair(proj1(a, b), proj2(a, b)) -end - -@calc rules begin - id(a ⊗ₒ b) - id(a) ⊗ₘ id(b) - (Δ(a) ⋅ (id(a) ⊗ₘ ⋄(a))) ⊗ₘ id(b) - (Δ(a) ⋅ (id(a) ⊗ₘ ⋄(a))) ⊗ₘ (Δ(b) ⋅ (⋄(b) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ ⋄(a)) ⊗ₘ (⋄(b) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ (id(a) ⊗ₘ (⋄(a) ⊗ₘ ⋄(b)) ⊗ₘ id(b)) - (Δ(a) ⊗ₘ Δ(b)) ⋅ (id(a) ⊗ₘ ((⋄(a) ⊗ₘ ⋄(b)) ⋅ σ(munit(), munit())) ⊗ₘ id(b)) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b))) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⋅ id(a)) ⊗ₘ (σ(a, b) ⋅ (⋄(b) ⊗ₘ ⋄(a))) ⊗ₘ id(b)) - (Δ(a) ⊗ₘ Δ(b)) ⋅ ((id(a) ⊗ₘ σ(a, b) ⊗ₘ id(b)) ⋅ (id(a) ⊗ₘ (⋄(b) ⊗ₘ ⋄(a)) ⊗ₘ id(b))) - Δ(a ⊗ₒ b) ⋅ (id(a) ⊗ₘ (⋄(b) ⊗ₘ ⋄(a)) ⊗ₘ id(b)) - Δ(a ⊗ₒ b) ⋅ (id(a) ⊗ₘ (⋄(b) ⊗ₘ ⋄(a)) ⊗ₘ id(b)) - Δ(a ⊗ₒ b) ⋅ (proj1(a, b) ⊗ₘ proj2(a, b)) - pair(proj1(a, b), proj2(a, b)) -end - -G = EGraph( :(pair(proj1(a, b), proj2(a, b)))) -params = SaturationParams(timeout=10) -saturate!(G, rules, params ) -ex = extract!(G, astsize) - -G.classes \ No newline at end of file diff --git a/test/category/test_cat_zx.jl b/test/category/test_cat_zx.jl deleted file mode 100644 index 68635dab..00000000 --- a/test/category/test_cat_zx.jl +++ /dev/null @@ -1,171 +0,0 @@ -using Catlab -using Catlab.Theories - -@signature ZXCategory{Ob,Hom} <: DaggerCompactCategory{Ob,Hom} begin - # Argument α is the phase, usually <: Real - zphase(A::Ob, α)::(A → A) - zcopy(A::Ob, α)::(A → (A⊗A)) - zdelete(A::Ob, α)::(A → munit()) - zmerge(A::Ob, α)::((A⊗A) → A) - zcreate(A::Ob, α)::(munit() → A) - - xphase(A::Ob, α)::(A → A) - xcopy(A::Ob, α)::(A → (A⊗A)) - xdelete(A::Ob, α)::(A → munit()) - xmerge(A::Ob, α)::((A⊗A) → A) - xcreate(A::Ob, α)::(munit() → A) - - hadamard(A::Ob)::(A → A) -end - -# Convenience methods for phaseless spiders. -zcopy(A) = zcopy(A,0) -zdelete(A) = zdelete(A,0) -zmerge(A) = zmerge(A,0) -zcreate(A) = zcreate(A,0) - -xcopy(A) = xcopy(A,0) -xdelete(A) = xdelete(A,0) -xmerge(A) = xmerge(A,0) -xcreate(A) = xcreate(A,0); - -import Catlab.Theories.Ob -@syntax ZXCalculus{ObExpr,HomExpr} ZXCategory begin - # otimes(A::Ob, B::Ob) = associate_unit(new(A,B), munit) - # otimes(f::Hom, g::Hom) = associate(new(f,g)) - # compose(f::Hom, g::Hom) = associate(new(f,g; strict=true)) -end - -using Metatheory, Metatheory.EGraphs - -# Custom type APIs for the GATExpr -using TermInterface -TermInterface.operation(t::ObExpr) = :call -TermInterface.arguments(t::ObExpr) = [head(t), t.args...] -TermInterface.operation(t::HomExpr) = :call -TermInterface.arguments(t::HomExpr) = [head(t), t.args...] - -abstract type CatType end -struct ObType <: CatType - ob - mod -end -struct HomType <: CatType - dom - codom - mod -end - -# Type information will be stored in the metadata -function TermInterface.metadata(t::HomExpr) - return HomType(t.type_args[1], t.type_args[2], typeof(t).name.module) -end -TermInterface.metadata(t::ObExpr) = ObType(t, typeof(t).name.module) -TermInterface.istree(t::GATExpr) = true -TermInterface.arity(t::GATExpr) = length(arguments(t)) - -struct CatlabAnalysis <: AbstractAnalysis end -function EGraphs.make(an::Type{CatlabAnalysis}, g::EGraph, n::ENode{T}) where T - !(T <: GATExpr) && return t - return metadata(n) -end -EGraphs.join(an::Type{CatlabAnalysis}, from, to) = from -EGraphs.islazy(x::Type{CatlabAnalysis}) = false - -function infer(t::GATExpr) - g = EGraph(t) - analyze!(g, CatlabAnalysis) - getdata(g[g.root], CatlabAnalysis) -end - -function EGraphs.extractnode(g::EGraph, n::ENode{T}, extractor::Function) where {T <: ObExpr} - @assert n.head == :call - return metadata(n).ob -end - -function EGraphs.extractnode(g::EGraph, n::ENode{T}, extractor::Function) where {T <: HomExpr} - @assert n.head == :call - nargs = extractor.(n.args) - nmeta = metadata(n) - return nmeta.mod.Hom{nargs[1]}(nargs[2:end], GATExpr[nmeta.dom, nmeta.codom]) -end - -# function EGraphs.instantiateterm(g::EGraph, pat::PatTerm, T::Type{H{K}}, sub::Sub, rule::Rule) where {H <: GATExpr, K} -# # TODO -# end - -t = Metatheory.@theory begin - compose(hadamard(A), hadamard(A)) |> - begin - d = getdata(A, Main.CatlabAnalysis) - return d.mod.id(d.ob) - end - compose(f, id(B)) |> - begin - bd = getdata(B, CatlabAnalysis) - fd = getdata(f, CatlabAnalysis) - if bd.ob == fd.codom - return f - else - error("TYPE ERROR!") - return _lhs_expr - end - end - compose(id(A), f) |> - begin - ad = getdata(A, CatlabAnalysis) - fd = getdata(f, CatlabAnalysis) - if ad.ob == fd.dom - return f - else - error("TYPE ERROR!") - return _lhs_expr - end - end -end - -t[2] - -A = Ob(ZXCalculus.Ob, :A) -B = Ob(ZXCalculus.Ob, :B) -f = Hom(:f, A, B) -h = hadamard(A) -c = h ⋅ h -G = EGraph(c) -infer(zdelete(A)).codom == A - -analyze!(G, CatlabAnalysis) -saturate!(G, t) -ex = extract!(G, astsize) -ex == id(A) - - -G = EGraph(f ⋅ id(B)) -analyze!(G, CatlabAnalysis) -saturate!(G, t) -ex = extract!(G, astsize) -ex == f - -x = id(A) ⋅ f ⋅ id(B) -G = EGraph(x) -analyze!(G, CatlabAnalysis) -saturate!(G, t) -ex = extract!(G, astsize) -ex == f - -using Catlab, Catlab.Theories -using Catlab.WiringDiagrams, Catlab.Graphics -using Catlab.Syntax - -A, B, C, D, E = Ob(FreeBiproductCategory, :A, :B, :C, :D, :E) -f = Hom(:f, A, B) -g = Hom(:g, B, C) -h = Hom(:h, B, A) -k = Hom(:k, C, B) -x = id(A) ⋅ f ⋅ id(B) - -z = x ⊗ f ⊗ ((f ⊗ g) ⋅ braid(B,C) ⋅ (k ⊗ h) ⋅ (delete(B) ⊗ f)) -to_composejl(z; orientation=LeftToRight) - -drop = munit(FreeCompactClosedCategory.Ob) -delete() \ No newline at end of file diff --git a/test/fib/symbolic_fib_comparison.jl b/test/fib/symbolic_fib_comparison.jl deleted file mode 100644 index 12938488..00000000 --- a/test/fib/symbolic_fib_comparison.jl +++ /dev/null @@ -1,71 +0,0 @@ -# Thanks to Mason Protter for this benchmark - -module SUFib -using SymbolicUtils -using Rewriters - -@syms fib(x::Int)::Int - -const rset = [ - @rule fib(0) => 0 - @rule fib(1) => 1 - @rule fib(~n) => fib(~n - 1) + fib(~n - 2) -] |> Chain |> Postwalk |> Fixpoint - -compute_fib(n) = rset(fib(n)) - -end - - -module MTFib - -using Metatheory -using Metatheory.EGraphs - -const fibo = @theory begin - x::Int + y::Int |> x+y - fib(n::Int) |> (n < 2 ? n : :(fib($(n-1)) + fib($(n-2)))) -end; - -function compute_fib(n) - params = SaturationParams(timeout = 7000, - scheduler=Schedulers.SimpleScheduler) - g = EGraph(:(fib($n))) - saturate!(g, fibo, params) - extract!(g, astsize) -end - -end - -using BenchmarkTools - -ns = 1:2:22 - -SU_ts = map(ns) do n - println(n) - @assert SUFib.compute_fib(n) isa Number - b = @benchmarkable SUFib.compute_fib($n) seconds=0.2 - mean(run(b)).time / 1e9 -end - -MT_ts = map(ns) do n - println(n) - @assert MTFib.compute_fib(n) isa Number - b = @benchmarkable MTFib.compute_fib($n) seconds=0.2 - mean(run(b)).time / 1e9 -end - - -using Plots -pyplot() - -font = "DejaVu Math TeX Gyre" -# default(titlefont=font, legendfont=font, fontfamily=font) -default(fontfamily=font, markerstrokewidth=0) - - -plot(ns, SU_ts, label="SymbolicUtils.jl", title="fib(n)", ylabel="Time (s)", xlabel="n", - color = :black, legend = :topleft, line=:dot, # m=(:cross, :blue), - size=(320,220), legendfontsize = 9, titlefontsize=12) -plot!(ns, MT_ts, label="Metatheory.jl", color = :black) # m = (:circle, :orange) ) -savefig("benchmarks/figures/fib.pdf") diff --git a/test/integration/cas.jl b/test/integration/cas.jl new file mode 100644 index 00000000..87dc66a2 --- /dev/null +++ b/test/integration/cas.jl @@ -0,0 +1,281 @@ +using Test +using Metatheory +using Metatheory.Library +using Metatheory.Schedulers +using TermInterface + +mult_t = @commutative_monoid (*) 1 +plus_t = @commutative_monoid (+) 0 + +minus_t = @theory a b begin + # TODO Jacques Carette's post in zulip chat + a - a --> 0 + a - b --> a + (-1 * b) + -a --> -1 * a + a + (-b) --> a + (-1 * b) +end + + +mulplus_t = @theory a b c begin + # TODO FIXME these rules improves performance and avoids commutative + # explosion of the egraph + a + a --> 2 * a + 0 * a --> 0 + a * 0 --> 0 + a * (b + c) == ((a * b) + (a * c)) + a + (b * a) --> ((b + 1) * a) +end + +pow_t = @theory x y z n m p q begin + (y^n) * y --> y^(n + 1) + x^n * x^m == x^(n + m) + (x * y)^z == x^z * y^z + (x^p)^q == x^(p * q) + x^0 --> 1 + 0^x --> 0 + 1^x --> 1 + x^1 --> x + x * x --> x^2 + inv(x) == x^(-1) +end + +div_t = @theory x y z begin + x / 1 --> x + # x / x => 1 TODO SIGN ANALYSIS + x / (x / y) --> y + x * (y / x) --> y + x * (y / z) == (x * y) / z + x^(-1) == 1 / x +end + +trig_t = @theory θ begin + sin(θ)^2 + cos(θ)^2 --> 1 + sin(θ)^2 - 1 --> cos(θ)^2 + cos(θ)^2 - 1 --> sin(θ)^2 + tan(θ)^2 - sec(θ)^2 --> 1 + tan(θ)^2 + 1 --> sec(θ)^2 + sec(θ)^2 - 1 --> tan(θ)^2 + cot(θ)^2 - csc(θ)^2 --> 1 + cot(θ)^2 + 1 --> csc(θ)^2 + csc(θ)^2 - 1 --> cot(θ)^2 +end + +# Dynamic rules +fold_t = @theory a b begin + -(a::Number) => -a + a::Number + b::Number => a + b + a::Number * b::Number => a * b + a::Number^b::Number => begin + b < 0 && a isa Int && (a = float(a)) + a^b + end + a::Number / b::Number => a / b +end + +using Calculus: differentiate +function ∂ end + +diff_t = @theory x y begin + ∂(y, x::Symbol) => begin + z = extract!(_egraph, simplcost; root = y.id) + @show z + zd = differentiate(z, x) + @show zd + zd + end +end + +cas = fold_t ∪ mult_t ∪ plus_t ∪ minus_t ∪ mulplus_t ∪ pow_t ∪ div_t ∪ trig_t ∪ diff_t + + +function customlt(x, y) + if typeof(x) == Expr && typeof(y) == Expr + false + elseif typeof(x) == typeof(y) + isless(x, y) + elseif x isa Symbol && y isa Number + false + elseif x isa Expr && y isa Number + false + elseif x isa Expr && y isa Symbol + false + else + true + end +end + +canonical_t = @theory x y n xs ys begin + # restore n-arity + (x * x) --> x^2 + (x^n::Number * x) --> x^(n + 1) + (x * x^n::Number) --> x^(n + 1) + (x + (+)(ys...)) --> +(x, ys...) + ((+)(xs...) + y) --> +(xs..., y) + (x * (*)(ys...)) --> *(x, ys...) + ((*)(xs...) * y) --> *(xs..., y) + + (*)(xs...) => Expr(:call, :*, sort!(xs; lt = customlt)...) + (+)(xs...) => Expr(:call, :+, sort!(xs; lt = customlt)...) +end + + +function simplcost(n::ENodeTerm, g::EGraph) + cost = 0 + arity(n) + if operation(n) == :∂ + cost += 20 + end + for id in arguments(n) + eclass = g[id] + !hasdata(eclass, simplcost) && (cost += Inf; break) + cost += last(getdata(eclass, simplcost)) + end + return cost +end + +simplcost(n::ENodeLiteral, g::EGraph) = 0 + +function simplify(ex; steps = 4) + params = SaturationParams( + scheduler = ScoredScheduler, + eclasslimit = 5000, + timeout = 7, + schedulerparams = (1000, 5, Schedulers.exprsize), + #stopwhen=stopwhen, + ) + hist = UInt64[] + push!(hist, hash(ex)) + for i in 1:steps + g = EGraph(ex) + @profview_allocs saturate!(g, cas, params) + ex = extract!(g, simplcost) + ex = rewrite(ex, canonical_t) + if !TermInterface.istree(ex) + return ex + end + if hash(ex) ∈ hist + println("loop detected $ex") + return ex + end + println(ex) + push!(hist, hash(ex)) + end + +end + +@test :(4a) == simplify(:(2a + a + a)) +@test :(a * b * c) == simplify(:(a * c * b)) +@test :(2x) == simplify(:(1 * x * 2)) +@test :((a * b)^2) == simplify(:((a * b)^2)) +@test :((a * b)^6) == simplify(:((a^2 * b^2)^3)) +@test :(a + b + d) == simplify(:(a + b + (0 * c) + d)) +@test :(a + b) == simplify(:(a + b + (c * 0) + d - d)) +@test :(a) == simplify(:((a + d) - d)) +@test :(a + b + d) == simplify(:(a + b * c^0 + d)) +@test :(a * b * x^(d + y)) == simplify(:(a * x^y * b * x^d)) +@test :(a * b * x^74103) == simplify(:(a * x^(12 + 3) * b * x^(42^3))) + +@test 1 == simplify(:((x + y)^(a * 0) / (y + x)^0)) +@test 2 == simplify(:(cos(x)^2 + 1 + sin(x)^2)) +@test 2 == simplify(:(cos(y)^2 + 1 + sin(y)^2)) +@test 2 == simplify(:(sin(y)^2 + cos(y)^2 + 1)) + +@test :(y + sec(x)^2) == simplify(:(1 + y + tan(x)^2)) +@test :(y + csc(x)^2) == simplify(:(1 + y + cot(x)^2)) + + + +# simplify(:( ∂(x^2, x))) + +@time simplify(:(∂(x^(cos(x)), x))) + +@test :(2x^3) == simplify(:(x * ∂(x^2, x) * x)) + +# @simplify ∂(y^3, y) * ∂(x^2 + 2, x) / y * x + +# @simplify (6 * x * x * y) + +# @simplify ∂(y^3, y) / y + +# # ex = :( ∂(x^(cos(x)), x) ) +# ex = :( (6 * x * x * y) ) +# g = EGraph(ex) +# saturate!(g, cas) +# g.classes +# extract!(g, simplcost; root=g.root) + +# params = SaturationParams( +# scheduler=BackoffScheduler, +# eclasslimit=5000, +# timeout=7, +# schedulerparams=(1000,5), +# #stopwhen=stopwhen, +# ) + +# ex = :((x+y)^(a*0) / (y+x)^0) +# g = EGraph(ex) +# @profview println(saturate!(g, cas, params)) + +# ex = extract!(g, simplcost) +# ex = rewrite(ex, canonical_t; clean=false) + + +# FIXME this is a hack to get the test to work. +if VERSION < v"1.9.0-DEV" + function EGraphs.make(::Val{:type_analysis}, g::EGraph, n::ENodeLiteral) + v = n.value + if v == :im + typeof(im) + else + typeof(v) + end + end + + function EGraphs.make(::Val{:type_analysis}, g::EGraph, n::ENodeTerm) + symtype(n) !== Expr && return Any + if exprhead(n) != :call + # println("$n is not a call") + t = Any + # println("analyzed type of $n is $t") + return t + end + sym = operation(n) + if !(sym isa Symbol) + # println("head $sym is not a symbol") + t = Any + # println("analyzed type of $n is $t") + return t + end + + symval = getfield(@__MODULE__, sym) + child_classes = map(x -> g[x], arguments(n)) + child_types = Tuple(map(x -> getdata(x, :type_analysis, Any), child_classes)) + + # t = t_arr[1] + t = Core.Compiler.return_type(symval, child_types) + + if t == Union{} + throw(MethodError(symval, child_types)) + end + # println("analyzed type of $n is $t") + return t + end + + EGraphs.join(::Val{:type_analysis}, from, to) = typejoin(from, to) + + EGraphs.islazy(::Val{:type_analysis}) = true + + function infer(e) + g = EGraph(e) + analyze!(g, :type_analysis) + getdata(g[g.root], :type_analysis) + end + + + ex1 = :(cos(1 + 3.0) + 4 + (4 - 4im)) + ex2 = :("ciao" * 2) + ex3 = :("ciao" * " mondo") + + @test ComplexF64 == infer(ex1) + @test_throws MethodError infer(ex2) + @test String == infer(ex3) +end diff --git a/test/fib/test_fibonacci.jl b/test/integration/fibonacci.jl similarity index 55% rename from test/fib/test_fibonacci.jl rename to test/integration/fibonacci.jl index 97e8fb23..c5266e66 100644 --- a/test/fib/test_fibonacci.jl +++ b/test/integration/fibonacci.jl @@ -1,12 +1,14 @@ # ENV["JULIA_DEBUG"] = Metatheory using Metatheory +function fib end + fibo = @theory x y n begin - x::Int + y::Int => x + y - fib(n::Int) => (n < 2 ? n : :(fib($(n - 1)) + fib($(n - 2)))) + x::Int + y::Int => x + y + fib(n::Int) => (n < 2 ? n : :(fib($(n - 1)) + fib($(n - 2)))) end -params = SaturationParams(timeout=60) +params = SaturationParams(timeout = 60) g = EGraph(:(fib(10))) @time saturate!(g, fibo, params) @@ -14,5 +16,5 @@ z = EGraph(:(fib(10))) @time saturate!(z, fibo, params) @testset "Fibonacci" begin - @test 55 == extract!(g, astsize) + @test 55 == extract!(g, astsize) end diff --git a/test/integration/kb_benchmark.jl b/test/integration/kb_benchmark.jl new file mode 100644 index 00000000..f759b45e --- /dev/null +++ b/test/integration/kb_benchmark.jl @@ -0,0 +1,71 @@ +using Test +using Metatheory +using Metatheory.Library +using Metatheory.EGraphs +using Metatheory.Rules +using Metatheory.EGraphs.Schedulers + +function rep(x, op, n::Int) + foldl((x, y) -> :(($op)($x, $y)), repeat([x], n)) +end + +macro rep(x, op, n::Int) + expr = rep(x, op, n) + esc(expr) +end + +rep(:a, :*, 3) + +@rule (@rep :a (*) 3) => :b + +Mid = @theory a begin + a * :ε --> a + :ε * a --> a +end + +Massoc = @theory a b c begin + a * (b * c) --> (a * b) * c + (a * b) * c --> a * (b * c) +end + + +T = [ + @rule :b * :B --> :ε + @rule :a * :a --> :ε + @rule :b * :b * :b --> :ε + @rule :B * :B --> :B + @rule (@rep (:a * :b) (*) 7) --> :ε + @rule (@rep (:a * :b * :a * :B) (*) 7) --> :ε +] + +G = Mid ∪ Massoc ∪ T + + +another_expr = :(b * B) +g = EGraph(another_expr) +saturate!(g, G) +ex = extract!(g, astsize) +@test ex == :ε + +another_expr = :(a * a * a * a) +g = EGraph(another_expr) +some_eclass = addexpr!(g, another_expr) +saturate!(g, G) +ex = extract!(g, astsize; root = some_eclass) +@test ex == :ε + +another_expr = :(((((((a * b) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) +g = EGraph(another_expr) +some_eclass = addexpr!(g, another_expr) +saturate!(g, G) +ex = extract!(g, astsize; root = some_eclass) +@test ex == :ε + + +expr = :(a * b * a * a * a * b * b * b * a * B * B * B * B * a) +g = EGraph(expr) +params = SaturationParams(timeout = 9, scheduler = BackoffScheduler)# , schedulerparams=(128,4))#, scheduler=SimpleScheduler) +@timev saturate!(g, G, params) +ex = extract!(g, astsize) +@test_broken ex == :ε + diff --git a/test/integration/lambda_theory.jl b/test/integration/lambda_theory.jl new file mode 100644 index 00000000..5ec117e5 --- /dev/null +++ b/test/integration/lambda_theory.jl @@ -0,0 +1,148 @@ +using Metatheory +using Metatheory.EGraphs +using Metatheory.Library +using TermInterface +using Test + +abstract type LambdaExpr end + +@matchable struct IfThenElse <: LambdaExpr + guard + then + otherwise +end + +@matchable struct Variable <: LambdaExpr + x::Symbol +end + +@matchable struct Fix <: LambdaExpr + variable + expression +end + +@matchable struct Let <: LambdaExpr + variable + value + body +end +@matchable struct λ <: LambdaExpr + x::Symbol + body +end + +@matchable struct Apply <: LambdaExpr + lambda + value +end + +@matchable struct Add <: LambdaExpr + x + y +end + +TermInterface.exprhead(::LambdaExpr) = :call + +function EGraphs.egraph_reconstruct_expression(::Type{<:LambdaExpr}, op, args; metadata=nothing, exprhead=:call) + op(args...) +end + +#%% +EGraphs.make(::Val{:freevar}, ::EGraph, n::ENodeLiteral) = Set{Int64}() + +function EGraphs.make(::Val{:freevar}, g::EGraph, n::ENodeTerm) + free = Set{Int64}() + if exprhead(n) == :call + op = operation(n) + args = arguments(n) + + if op == Variable + push!(free, args[1]) + elseif op == Let + v, a, b = args[1:3] + adata = getdata(g[a], :freevar, Set{Int64}()) + bdata = getdata(g[a], :freevar, Set{Int64}()) + union!(free, adata) + delete!(free, v) + union!(free, bdata) + elseif op == λ + v, b = args[1:2] + bdata = getdata(g[b], :freevar, Set{Int64}()) + union!(free, bdata) + delete!(free, v) + end + end + + return free +end + +EGraphs.join(::Val{:freevar}, from, to) = union(from, to) + +islazy(::Val{:freevar}) = false + +open_term = @theory x e then alt a b c begin + # if-true + IfThenElse(true, then, alt) --> then + IfThenElse(false, then, alt) --> alt + # if-elim + IfThenElse(Variable(x) == e, then, alt) => + if addexpr!(_egraph, Let(x, e, then)) == addexpr!(_egraph, Let(x, e, alt)) + alt + else + _lhs_expr + end + Add(a, b) == Add(b, a) + Add(a, Add(b,c)) == Add(Add(a,b),c) + # (a == b) == (b == a) +end + +subst_intro = @theory v body e begin + Fix(v, e) --> Let(v, Fix(v, e), e) + # beta reduction + Apply(λ(v, body), e) --> Let(v, e, body) +end + +subst_prop = @theory v e a b then alt guard begin + # let-Apply + Let(v, e, Apply(a, b)) --> Apply(Let(v, e, a), Let(v, e, b)) + # let-add + Let(v, e, a + b) --> Let(v, e, a) + Let(v, e, b) + # let-eq + # Let(v, e, a == b) --> Let(v, e, a) == Let(v, e, b) + # let-IfThenElse (let-if) + Let(v, e, IfThenElse(guard, then, alt)) --> IfThenElse(Let(v, e, guard), Let(v, e, then), Let(v, e, alt)) +end + + +subst_elim = @theory v e c v1 v2 body begin + # let-const + Let(v, e, c::Any) --> c + # let-Variable-same + Let(v1, e, Variable(v1)) --> e + # TODO fancy let-Variable-diff + Let(v1, e, Variable(v2)) => if addexpr!(_egraph, v1) != addexpr!(_egraph, v2) + :(Variable($v2)) + else + _lhs_expr + end + # let-lam-same + Let(v1, e, λ(v1, body)) --> λ(v1, body) + # let-lam-diff #TODO captureavoid + Let(v1, e, λ(v2, body)) => if v2.id ∈ getdata(e, :freevar, Set()) # is free + :(λ($fresh, Let($v1, $e, Let($v2, Variable($fresh), $body)))) + else + :(λ($v2, Let($v1, $e, $body))) + end +end + +λT = open_term ∪ subst_intro ∪ subst_prop ∪ subst_elim + +ex = λ(:x, Add(4, Apply(λ(:y, Variable(:y)), 4))) +g = EGraph(ex) + +settermtype!(g, LambdaExpr) +saturate!(g, λT) +@test λ(:x, Add(4, 4)) == extract!(g, astsize) # expected: :(λ(x, 4 + 4)) + +#%% +@test @areequal λT 2 Apply(λ(x, Variable(x)), 2) \ No newline at end of file diff --git a/test/integration/logic.jl b/test/integration/logic.jl new file mode 100644 index 00000000..0cb38f5e --- /dev/null +++ b/test/integration/logic.jl @@ -0,0 +1,189 @@ +using Test +using Metatheory +using TermInterface + +function prove(t, ex, steps = 1, timeout = 10, eclasslimit = 5000) + params = SaturationParams( + timeout = timeout, + eclasslimit = eclasslimit, + # scheduler=Schedulers.ScoredScheduler, schedulerparams=(1000,5, Schedulers.exprsize)) + scheduler = Schedulers.BackoffScheduler, + schedulerparams = (6000, 5), + ) + + hist = UInt64[] + push!(hist, hash(ex)) + for i in 1:steps + g = EGraph(ex) + + exprs = [true, g[g.root]] + ids = [addexpr!(g, e) for e in exprs] + + goal = EqualityGoal(exprs, ids) + params.goal = goal + rep = saturate!(g, t, params) + @show rep + ex = extract!(g, astsize) + if !TermInterface.istree(ex) + return ex + end + if hash(ex) ∈ hist + return ex + end + push!(hist, hash(ex)) + end + return ex +end + +function ⟹ end + +fold = @theory p q begin + (p::Bool == q::Bool) => (p == q) + (p::Bool || q::Bool) => (p || q) + (p::Bool ⟹ q::Bool) => ((p || q) == q) + (p::Bool && q::Bool) => (p && q) + !(p::Bool) => (!p) +end + + +@testset "Prop logic" begin + or_alg = @theory p q r begin + ((p || q) || r) == (p || (q || r)) + (p || q) == (q || p) + (p || p) --> p + (p || true) --> true + (p || false) --> p + end + + and_alg = @theory p q r begin + ((p && q) && r) == (p && (q && r)) + (p && q) == (q && p) + (p && p) --> p + (p && true) --> p + (p && false) --> false + end + + comb = @theory p q r begin + # DeMorgan + !(p || q) == (!p && !q) + !(p && q) == (!p || !q) + # distrib + (p && (q || r)) == ((p && q) || (p && r)) + (p || (q && r)) == ((p || q) && (p || r)) + # absorb + (p && (p || q)) --> p + (p || (p && q)) --> p + # complement + (p && (!p || q)) --> p && q + (p || (!p && q)) --> p || q + end + + negt = @theory p begin + (p && !p) --> false + (p || !(p)) --> true + !(!p) == p + end + + impl = @theory p q begin + (p == !p) --> false + (p == p) --> true + (p == q) --> (!p || q) && (!q || p) + (p ⟹ q) --> (!p || q) + end + + + t = or_alg ∪ and_alg ∪ comb ∪ negt ∪ impl ∪ fold + + ex = rewrite(:(((p ⟹ q) && (r ⟹ s) && (p || r)) ⟹ (q || s)), impl) + @test prove(t, ex, 5, 10, 5000) + + + @test @areequal t true ((!p == p) == false) + @test @areequal t true ((!p == !p) == true) + @test @areequal t true ((!p || !p) == !p) (!p || p) !(!p && p) + @test @areequal t p (p || p) + @test @areequal t true ((p ⟹ (p || p))) + @test @areequal t true ((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true + + # Frege's theorem + @test @areequal t true (p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r)) + + # Demorgan's + @test @areequal t true (!(p || q) == (!p && !q)) + + # Consensus theorem + # @test_broken @areequal t true ((x && y) || (!x && z) || (y && z)) ((x && y) || (!x && z)) +end + +# https://www.cs.cornell.edu/gries/Logic/Axioms.html +# The axioms of calculational propositional logic C are listed in the order in +# which they are usually presented and taught. Note that equivalence comes +# first. Note also that, after the first axiom, we take advantage of +# associativity of equivalence and write sequences of equivalences without +# parentheses. We use == for equivalence, | for disjunction, & for conjunction, + +# Golden rule: p & q == p == q == p | q +# +# Implication: p ⟹ q == p | q == q +# Consequence: p ⟸q == q ⟹ p + +# Definition of false: false == !true +@testset "Calculational Logic" begin + calc = @theory p q r begin + # Associativity of ==: + ((p == q) == r) == (p == (q == r)) + # Symmetry of ==: + (p == q) == (q == p) + # Identity of ==: + (q == q) --> true + # Excluded middle + # Distributivity of !: + !(p == q) == (!(p) == q) + # Definition of !=: + (p != q) == !(p == q) + #Associativity of ||: + ((p || q) || r) == (p || (q || r)) + # Symmetry of ||: + (p || q) == (q || p) + # Idempotency of ||: + (p || p) --> p + # Distributivity of ||: + (p || (q == r)) == (p || q == p || r) + # Excluded Middle: + (p || !(p)) --> true + + # DeMorgan + !(p || q) == (!p && !q) + !(p && q) == (!p || !q) + + (p && q) == ((p == q) == p || q) + + (p ⟹ q) == ((p || q) == q) + end + + # t = or_alg ∪ and_alg ∪ neg_alg ∪ demorgan ∪ and_or_distrib ∪ + # absorption ∪ calc + + t = calc ∪ fold + + g = EGraph(:(((!p == p) == false))) + saturate!(g, t) + extract!(g, astsize) + + @test @areequal t true ((!p == p) == false) + @test @areequal t true ((!p == !p) == true) + @test @areequal t true ((!p || !p) == !p) (!p || p) !(!p && p) + @test @areequal t true ((p ⟹ (p || p)) == true) + params = SaturationParams(timeout = 12, eclasslimit = 10000, schedulerparams = (1000, 5)) + + @test areequal(t, true, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true); params = params) + + # Frege's theorem + @test areequal(t, true, :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))); params = params) + + # Demorgan's + @test @areequal t true (!(p || q) == (!p && !q)) + + # Consensus theorem + areequal(t, :((x && y) || (!x && z) || (y && z)), :((x && y) || (!x && z)); params = params) +end diff --git a/test/integration/stream_fusion.jl b/test/integration/stream_fusion.jl new file mode 100644 index 00000000..14f6874b --- /dev/null +++ b/test/integration/stream_fusion.jl @@ -0,0 +1,114 @@ +using Metatheory +using Metatheory.Rewriters +using Test +using TermInterface +# using SymbolicUtils + +apply(f, x) = f(x) +fand(f, g) = x -> f(x) && g(x) + +array_theory = @theory x y f g M N begin + #map(f,x)[n:m] = map(f,x[n:m]) # but does NOT commute with filter + map(f, fill(x, N)) == fill(apply(f, x), N) # hmm + # cumsum(fill(x,N)) == collect(x:x:(N*x)) + fill(x, N)[y] --> x + length(fill(x, N)) --> N + reverse(reverse(x)) --> x + sum(fill(x, N)) --> x * N + map(f, reverse(x)) == reverse(map(f, x)) + filter(f, reverse(x)) == reverse(filter(f, x)) + reverse(fill(x, N)) == fill(x, N) + filter(f, fill(x, N)) == ( + if apply(f, x) + fill(x, N) + else + fill(x, 0) + end + ) + filter(f, filter(g, x)) == filter(fand(f, g), x) # using functional && + cat(fill(x, N), fill(x, M)) == fill(x, N + M) + cat(map(f, x), map(f, y)) == map(f, cat(x, y)) + map(f, cat(x, y)) == cat(map(f, x), map(f, y)) + map(f, map(g, x)) == map(f ∘ g, x) + reverse(cat(x, y)) == cat(reverse(y), reverse(x)) + map(f, x)[y] == apply(f, x[y]) + apply(f ∘ g, x) == apply(f, apply(g, x)) + + reduce(g, map(f, x)) == mapreduce(f, g, x) + foldl(g, map(f, x)) == mapfoldl(f, g, x) + foldr(g, map(f, x)) == mapfoldr(f, g, x) +end + +asymptot_t = @theory x y z n m f g begin + (length(filter(f, x)) <= length(x)) => true + length(cat(x, y)) --> length(x) + length(y) + length(map(f, x)) => length(map) + length(x::UnitRange) => length(x) +end + +fold_theory = @theory x y z begin + x::Number * y::Number => x * y + x::Number + y::Number => x + y + x::Number / y::Number => x / y + x::Number - y::Number => x / y + # etc... +end + +# Simplify expressions like :(d->3:size(A,d)-3) given an explicit value for d +import Base.Cartesian: inlineanonymous + + +tryinlineanonymous(x) = nothing +function tryinlineanonymous(ex::Expr) + exprhead(ex) != :call && return nothing + f = operation(ex) + (!(f isa Expr) || exprhead(f) !== :->) && return nothing + arg = arguments(ex)[1] + try + return inlineanonymous(f, arg) + catch e + return nothing + end +end + +normalize_theory = @theory x y z f g begin + fand(f, g) => Expr(:->, :x, :(($f)(x) && ($g)(x))) + apply(f, x) => Expr(:call, f, x) +end + +params = SaturationParams() + +function stream_optimize(ex) + g = EGraph(ex) + rep = saturate!(g, array_theory, params) + @info rep + ex = extract!(g, astsize) # TODO cost fun with asymptotic complexity + ex = Fixpoint(Postwalk(Chain([tryinlineanonymous, normalize_theory..., fold_theory...])))(ex) + return ex +end + +build_fun(ex) = eval(:(() -> $ex)) + + +@testset "Stream Fusion" begin + ex = :(map(x -> 7 * x, fill(3, 4))) + opt = stream_optimize(ex) + @test opt == :(fill(21, 4)) + + ex = :(map(x -> 7 * x, fill(3, 4))[1]) + opt = stream_optimize(ex) + @test opt == 21 +end + +# ['a','1','2','3','4'] +ex = :(filter(ispow2, filter(iseven, reverse(reverse(fill(4, 100)))))) +opt = stream_optimize(ex) + + +ex = :(map(x -> 7 * x, reverse(reverse(fill(13, 40))))) +opt = stream_optimize(ex) +opt = stream_optimize(opt) + +macro stream_optimize(ex) + stream_optimize(ex) +end diff --git a/test/integration/taylor.jl b/test/integration/taylor.jl new file mode 100644 index 00000000..ff7e703f --- /dev/null +++ b/test/integration/taylor.jl @@ -0,0 +1,36 @@ +using Metatheory + +struct Σ end + +taylor = @theory x a b begin + exp(x) --> Σ(x^:n / factorial(big(:n))) + cos(x) --> Σ((-1)^:n * x^2(:n) / factorial(big(2 * :n))) + Σ(a) + Σ(b) --> Σ(a + b) +end + +macro expand(iters) + quote + @rule a Σ(a) --> sum((:n -> a), $(0:iters)) + end +end + +a = rewrite(:(exp(x) + cos(x)), taylor) + +r = @expand(5000) +# r = expand(5000) +bexpr = rewrite(a, [r]) + +# you may want to do algebraic simplification +# with egraphs here + +x = big(42) + +b = eval(bexpr) +# 1.739274941520501044994695988622883932193276720547806372656638132701531037200611e+18 + +exp(x) + cos(x) +# 1.739274941520501046994695988622883932193276720547806372656638132701531037200651e+18 + +@testset "Infinite Series Approximation" begin + @test b ≈ (exp(x) + cos(x)) +end diff --git a/test/integration/while_superinterpreter.jl b/test/integration/while_superinterpreter.jl new file mode 100644 index 00000000..d083ffc0 --- /dev/null +++ b/test/integration/while_superinterpreter.jl @@ -0,0 +1,192 @@ + +## Turing Complete Interpreter +### A Very Tiny Turing Complete Programming Language defined with denotational semantics + +# semantica dalle dispense degano +using Metatheory, Test + +import Base.ImmutableDict +Mem = Dict{Symbol,Union{Bool,Int}} + +read_mem = @theory v σ begin + (v::Symbol, σ::Mem) => if v == :skip + σ + else + σ[v] + end +end + +@testset "Reading Memory" begin + ex = :((x), $(Mem(:x => 2))) + @test true == areequal(read_mem, ex, 2) +end + +arithm_rules = @theory a b σ begin + (a + b, σ::Mem) --> (a, σ) + (b, σ) + (a * b, σ::Mem) --> (a, σ) * (b, σ) + (a - b, σ::Mem) --> (a, σ) - (b, σ) + (a::Int, σ::Mem) --> a + (a::Int + b::Int) => a + b + (a::Int * b::Int) => a * b + (a::Int - b::Int) => a - b +end + + +@testset "Arithmetic" begin + @test areequal(read_mem ∪ arithm_rules, :((2 + 3), $(Mem())), 5) +end + +# don't need to access memory +bool_rules = @theory a b σ begin + (a < b, σ::Mem) --> (a, σ) < (b, σ) + (a || b, σ::Mem) --> (a, σ) || (b, σ) + (a && b, σ::Mem) --> (a, σ) && (b, σ) + (!(a), σ::Mem) --> !((a, σ)) + + (a::Bool, σ::Mem) => a + (!a::Bool) => !a + (a::Bool || b::Bool) => (a || b) + (a::Bool && b::Bool) => (a && b) + (a::Int < b::Int) => (a < b) +end + +t = read_mem ∪ arithm_rules ∪ bool_rules + +@testset "Booleans" begin + @test areequal(t, :((false || false), $(Mem())), false) + + exx = :((false || false) || !(false || false), $(Mem(:x => 2))) + g = EGraph(exx) + saturate!(g, t) + ex = extract!(g, astsize) + @test ex == true + params = SaturationParams(timeout = 12) + @test areequal(t, exx, true; params = params) + + @test areequal(t, :((2 < 3) && (3 < 4), $(Mem(:x => 2))), true) + @test areequal(t, :((2 < x) || !(3 < 4), $(Mem(:x => 2))), false) + @test areequal(t, :((2 < x) || !(3 < 4), $(Mem(:x => 4))), true) +end + +if_rules = @theory guard t f σ begin + ( + if guard + t + end + ) --> ( + if guard + t + else + :skip + end + ) + (if guard + t + else + f + end, σ::Mem) --> (if (guard, σ) + t + else + f + end, σ) + (if true + t + else + f + end, σ::Mem) --> (t, σ) + (if false + t + else + f + end, σ::Mem) --> (f, σ) +end + +if_language = read_mem ∪ arithm_rules ∪ bool_rules ∪ if_rules + + +@testset "If Semantics" begin + @test areequal(if_language, 2, :(if true + x + else + 0 + end, $(Mem(:x => 2)))) + @test areequal(if_language, 0, :(if false + x + else + 0 + end, $(Mem(:x => 2)))) + @test areequal(if_language, 2, :(if !(false) + x + else + 0 + end, $(Mem(:x => 2)))) + params = SaturationParams(timeout = 10) + @test areequal(if_language, 0, :(if !(2 < x) + x + else + 0 + end, $(Mem(:x => 3))); params = params) +end + + +while_rules = @theory a b σ begin + (:skip, σ::Mem) --> σ + ((a; b), σ::Mem) --> ((a, σ); b) + (a::Int; b) --> b + (a::Bool; b) --> b + (σ::Mem; b) --> (b, σ) + (while a + b + end, σ::Mem) --> (if a + (b; + while a + b + end) + else + :skip + end, σ) +end + + +write_mem = @theory sym val σ begin + (sym::Symbol = val, σ::Mem) --> (sym = (val, σ), σ) + (sym::Symbol = val::Int, σ::Mem) => merge(σ, Dict(sym => val)) +end + +while_language = if_language ∪ write_mem ∪ while_rules; + +@testset "While Semantics" begin + exx = :((x = 3), $(Mem(:x => 2))) + g = EGraph(exx) + saturate!(g, while_language) + ex = extract!(g, astsize) + + @test areequal(while_language, Mem(:x => 3), exx) + + exx = :((x = 4; x = x + 1), $(Mem(:x => 3))) + g = EGraph(exx) + saturate!(g, while_language) + ex = extract!(g, astsize) + + params = SaturationParams(timeout = 10) + @test areequal(while_language, Mem(:x => 5), exx; params = params) + + params = SaturationParams(timeout = 14, timer=false) + exx = :(( + if x < 10 + x = x + 1 + else + skip + end + ), $(Mem(:x => 3))) + @test areequal(while_language, Mem(:x => 4), exx; params = params) + + exx = :((while x < 10 + x = x + 1 + end; + x), $(Mem(:x => 3))) + g = EGraph(exx) + params = SaturationParams(timeout = 100) + saturate!(g, while_language, params) + @test 10 == extract!(g, astsize) +end diff --git a/test/lambda/lambda_theory.jl b/test/lambda/lambda_theory.jl deleted file mode 100644 index ed119997..00000000 --- a/test/lambda/lambda_theory.jl +++ /dev/null @@ -1,98 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -using Test - -open_term = @theory begin - # if-true - cond(true, then, alt) => then - cond(false, then, alt) => alt - # if-elim - cond(var(x) == e, then, alt) |> - if addexpr!(_egraph, :(llet($x,$e,$then))) == - addexpr!(_egraph, :(llet($x,$e,$alt))) - alt - else _lhs_expr end - a + b => b + a - a + (b + c) => (a + b) + c - (a == b) => (b == a) -end - -subst_intro = @theory begin - fix(v, e) => llet(v, fix(v,e), e) - # beta reduction - app(λ(v, body), e) => llet(v, e, body) -end - -subst_prop = @theory begin - # let-app - llet(v, e, app(a, b)) => app(llet(v,e,a), llet(v,e,b)) - # let-add - llet(v, e, a + b) => llet(v,e,a) + llet(v,e,b) - # let-eq - llet(v, e, a == b) => llet(v,e,a) == llet(v,e,b) - # let-cond (let-if) - llet(v, e, cond(guard, then, alt)) => - cond(llet(v,e,guard), llet(v,e,then), llet(v,e,alt)) -end - -subst_elim = @theory begin - # let-const - llet(v, e, c::Any) => c - # let-var-same - llet(v1, e, var(v1)) => e - # TODO fancy let-var-diff - llet(v1, e, var(v2)) |> - if find(_egraph, v1) != find(_egraph, v2) - :(var($v2)) - else _lhs_expr end - # let-lam-same - llet(v1, e, λ(v1, body)) => λ(v1, body) - # let-lam-diff #TODO captureavoid - llet(v1, e, λ(v2, body)) |> - if v2.id ∈ getdata(e, FreeVarAnalysis, Set()) # is free - :(λ($fresh, llet($v1, $e, llet($v2, var($fresh), $body)))) - else - :(λ($v2, llet($v1, $e, $body))) - end -end - -λT = open_term ∪ subst_intro ∪ subst_prop ∪ subst_elim - -ex = :(λ(x, 4 + app(λ(y, var(y)), 4))) -g = EGraph(ex) -# analyze!(g, FreeVarAnalysis) -saturate!(g, λT) -display(g.classes); println() -extract!(g, astsize) - - -@test @areequal λT 2 app(λ(x, var(x)), 2) - -abstract type FreeVarAnalysis <: AbstractAnalysis end - -function EGraphs.make(an::Type{FreeVarAnalysis}, g::EGraph, n::ENode) - free = Set{Int64}() - if n.head == :var - push!(free, n.args[1]) - elseif n.head == :llet - v,a,b = n.args[1:3] - adata = getdata(g[a], an, Set{Int64}()) - bdata = getdata(g[a], an, Set{Int64}()) - union!(free, adata) - delete!(free, v) - union!(free, bdata) - elseif n.head == :λ - v,b = n.args[1:2] - bdata = getdata(g[b], an, Set{Int64}()) - union!(free, bdata) - delete!(free, v) - end - - return free -end - -function EGraphs.join(an::Type{FreeVarAnalysis}, from, to) - union(from, to) -end - -islazy(an::Type{FreeVarAnalysis}) = false \ No newline at end of file diff --git a/test/logic/prop_logic_theory.jl b/test/logic/prop_logic_theory.jl deleted file mode 100644 index 13617b2b..00000000 --- a/test/logic/prop_logic_theory.jl +++ /dev/null @@ -1,58 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -using Test - -or_alg = @theory p q r begin - ((p ∨ q) ∨ r) == (p ∨ (q ∨ r)) - (p ∨ q) == (q ∨ p) - (p ∨ p) --> p - (p ∨ true) --> true - (p ∨ false) --> p -end - -and_alg = @theory p q r begin - ((p ∧ q) ∧ r) == (p ∧ (q ∧ r)) - (p ∧ q) == (q ∧ p) - (p ∧ p) --> p - (p ∧ true) --> p - (p ∧ false) --> false -end - -comb = @theory p q r begin - # DeMorgan - ¬(p ∨ q) == (¬p ∧ ¬q) - ¬(p ∧ q) == (¬p ∨ ¬q) - # distrib - (p ∧ (q ∨ r)) == ((p ∧ q) ∨ (p ∧ r)) - (p ∨ (q ∧ r)) == ((p ∨ q) ∧ (p ∨ r)) - # absorb - (p ∧ (p ∨ q)) --> p - (p ∨ (p ∧ q)) --> p - # complement - (p ∧ (¬p ∨ q)) --> p ∧ q - (p ∨ (¬p ∧ q)) --> p ∨ q -end - -negt = @theory p begin - (p ∧ ¬p) --> false - (p ∨ ¬(p)) --> true - ¬(¬p) == p -end - -impl = @theory p q begin - (p == ¬p) --> false - (p == p) --> true - (p == q) --> (¬p ∨ q) ∧ (¬q ∨ p) - (p => q) --> (¬p ∨ q) -end - -fold = @theory p q begin - (p::Bool == q::Bool) => (p == q) - (p::Bool ∨ q::Bool) => (p || q) - (p::Bool => q::Bool) => ((p || q) == q) - (p::Bool ∧ q::Bool) => (p && q) - ¬(p::Bool) => (!p) -end - -t = or_alg ∪ and_alg ∪ comb ∪ negt ∪ impl ∪ fold - \ No newline at end of file diff --git a/test/logic/prover.jl b/test/logic/prover.jl deleted file mode 100644 index 85b234fb..00000000 --- a/test/logic/prover.jl +++ /dev/null @@ -1,34 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -using TermInterface - -function prove(t, ex, steps=1, timeout=10, eclasslimit=5000) - params = SaturationParams(timeout=timeout, eclasslimit=eclasslimit, - # scheduler=Schedulers.ScoredScheduler, schedulerparams=(1000,5, Schedulers.exprsize)) - scheduler=Schedulers.BackoffScheduler, schedulerparams=(6000,5)) - - hist = UInt64[] - push!(hist, hash(ex)) - for i ∈ 1:steps - g = EGraph(ex) - - exprs = [true, g[g.root]] - ids = [addexpr!(g, e)[1].id for e in exprs] - - goal=EqualityGoal(exprs, ids) - params.goal = goal - saturate!(g, t, params) - ex = extract!(g, astsize) - println(ex) - if !TermInterface.istree(typeof(ex)) - return ex - end - if hash(ex) ∈ hist - println("loop detected") - return ex - end - push!(hist, hash(ex)) - end - return ex -end - diff --git a/test/logic/test_calculational_logic.jl b/test/logic/test_calculational_logic.jl deleted file mode 100644 index 65b83abe..00000000 --- a/test/logic/test_calculational_logic.jl +++ /dev/null @@ -1,92 +0,0 @@ -# https://www.cs.cornell.edu/gries/Logic/Axioms.html -# The axioms of calculational propositional logic C are listed in the order in -# which they are usually presented and taught. Note that equivalence comes -# first. Note also that, after the first axiom, we take advantage of -# associativity of equivalence and write sequences of equivalences without -# parentheses. We use == for equivalence, | for disjunction, & for conjunction, -# ~ for negation (not), => for implication, and <= for consequence. -# -# Associativity of ==: ((p == q) == r) == (p == (q == r)) -# Symmetry of ==: p == q == q == p -# Identity of ==: true == q == q -# -# Definition of false: false == ~true -# Distributivity of not: ~(p == q) == ~p == q -# Definition of =/=: (p =/= q) == ~(p == q) -# -# Associativity of |: (p | q) & r == p | (q | r) -# Symmetry of |: p | q == q | p -# Idempotency of |: p | p == p -# Distributivity of |: p |(q == r) == p | q == p | r -# Excluded Middle: p | ~p -# -# Golden rule: p & q == p == q == p | q -# -# Implication: p => q == p | q == q -# Consequence: p <= q == q => p - - -using Metatheory - -calc = @theory p q r begin - ((p == q) == r) == (p == (q == r)) - (p == q) == (q == p) - (q == q) --> true - - ¬(p == q) == (¬(p) == q) - (p != q) == ¬(p == q) - - ((p ∨ q) ∨ r) == (p ∨ (q ∨ r)) - (p ∨ q) == (q ∨ p) - (p ∨ p) --> p - (p ∨ (q == r)) == (p ∨ q == p ∨ r) - (p ∨ ¬(p)) --> true - - # DeMorgan - ¬(p ∨ q) == (¬p ∧ ¬q) - ¬(p ∧ q) == (¬p ∨ ¬q) - - (p ∧ q) == ((p == q) == p ∨ q) - - (p => q) == ((p ∨ q) == q) - # (p => q) == (¬p ∨ q) - # (p <= q) => (q => p) -end - -fold = @theory p q begin - (p::Bool == q::Bool) => (p == q) - (p::Bool ∨ q::Bool) => (p || q) - (p::Bool => q::Bool) => ((p || q) == q) - (p::Bool ∧ q::Bool) => (p && q) - ¬(p::Bool) => (!p) -end - -# t = or_alg ∪ and_alg ∪ neg_alg ∪ demorgan ∪ and_or_distrib ∪ -# absorption ∪ calc - -t = calc ∪ fold - -g = EGraph(:(((¬p == p) == false))) -saturate!(g, t) -extract!(g, astsize) - -@test @areequal t true ((¬p == p) == false) -@test @areequal t true ((¬p == ¬p) == true) -@test @areequal t true ((¬p ∨ ¬p) == ¬p) (¬p ∨ p) ¬(¬p ∧ p) -@test @areequal t true ((p => (p ∨ p)) == true) -params = SaturationParams(timeout=12, eclasslimit=10000, schedulerparams=(1000, 5)) - -@test areequal(t, true, :(((p => (p ∨ p)) == ((¬(p) ∧ q) => q)) == true); params=params) - -# Frege's theorem -# params = SaturationParams(timeout=12, eclasslimit=15000, scheduler=Schedulers.ScoredScheduler) -# params = SaturationParams(timeout=12, eclasslimit=15000, schedulerparams=(500, 2)) -@test_skip areequal(t, true, :((p => (q => r)) => ((p => q) => (p => r))); params=params) - -# Demorgan's -@test @areequal t true (¬(p ∨ q) == (¬p ∧ ¬q)) - -# Consensus theorem -# @test_skip -areequal(t, :((x ∧ y) ∨ (¬x ∧ z) ∨ (y ∧ z)), :((x ∧ y) ∨ (¬x ∧ z)); params=params) - diff --git a/test/logic/test_logic.jl b/test/logic/test_logic.jl deleted file mode 100644 index ec339204..00000000 --- a/test/logic/test_logic.jl +++ /dev/null @@ -1,24 +0,0 @@ -include("prop_logic_theory.jl") -include("prover.jl") - -using Test - -ex = rewrite(:(((p => q) ∧ (r => s) ∧ (p ∨ r)) => (q ∨ s)), impl) -@test prove(t, ex, 3, 10, 5000) - - -@test @areequal t true ((¬p == p) == false) -@test @areequal t true ((¬p == ¬p) == true) -@test @areequal t true ((¬p ∨ ¬p) == ¬p) (¬p ∨ p) ¬(¬p ∧ p) -@test @areequal t true ((p => (p ∨ p))) -@test @areequal t true ((p => (p ∨ p)) == ((¬(p) ∧ q) => q)) == true - -# Frege's theorem -@test @areequal t true (p => (q => r)) => ((p => q) => (p => r)) - -# Demorgan's -@test @areequal t true (¬(p ∨ q) == (¬p ∧ ¬q)) - -# Consensus theorem -@test @areequal t ((x ∧ y) ∨ (¬x ∧ z) ∨ (y ∧ z)) ((x ∧ y) ∨ (¬x ∧ z)) - diff --git a/test/numberfold.jl b/test/numberfold.jl deleted file mode 100644 index 41694087..00000000 --- a/test/numberfold.jl +++ /dev/null @@ -1,54 +0,0 @@ -abstract type NumberFold <: AbstractAnalysis end - -using TermInterface - -function EGraphs.make(an::Type{NumberFold}, g::EGraph, n::ENodeLiteral) - n.value -end - -# This should be auto-generated by a macro -function EGraphs.make(an::Type{NumberFold}, g::EGraph, n::ENodeTerm) - if exprhead(n) == :call && arity(n) == 2 - op = operation(n) - args = arguments(n) - l = g[args[1]] - r = g[args[2]] - ldata = getdata(l, an, nothing) - rdata = getdata(r, an, nothing) - - # @show ldata rdata - - if ldata isa Number && rdata isa Number - if op == :* - return ldata * rdata - elseif op == :+ - return ldata + rdata - end - end - end - - return nothing -end - -function EGraphs.join(an::Type{NumberFold}, from, to) - # println("joining!") - if from isa Number - if to isa Number - @assert from == to - else return from - end - end - return to -end - -function EGraphs.modify!(an::Type{NumberFold}, g::EGraph, id::Int64) - # !haskey(an, id) && return nothing - eclass = g.classes[id] - d = getdata(eclass, an, nothing) - if d isa Number - newclass, _ = addexpr!(g, d) - merge!(g, newclass.id, id) - end -end - -EGraphs.islazy(x::Type{NumberFold}) = false diff --git a/test/proof/test_proof.jl b/test/proof/test_proof.jl deleted file mode 100644 index c88b5361..00000000 --- a/test/proof/test_proof.jl +++ /dev/null @@ -1,189 +0,0 @@ -using Metatheory, Metatheory.EGraphs -using Test - -dbgproof(n::ENode) = println("$(n.proof_src) ⩜ $(n.proof_trg)") - -function prove(g::EGraph, t::Vector{<:Rule}, exprs...; - params=SaturationParams()) - # @info "Checking equality for " exprs - n = length(exprs) - if n == 1; return true end - # rebuild!(G) - - ids = Vector{EClassId}(undef, n) - nodes = Vector{ENode}(undef, n) - for i ∈ 1:n - ec, node = addexpr!(g, exprs[i]) - ids[i] = ec.id - nodes[i] = node - end - - goal = EqualityGoal(collect(exprs), ids) - - # params.goal = goal - report = saturate!(g, t, params; mod=mod) - - # display(g.classes); println() - if !(report.reason === :saturated) && !reached(g, goal) - return missing # failed to prove - end - - # @show reached(g, goal) - - for (id, ec) in g.classes - for n in ec - # println(id => n) - # dbgproof(n) - end - end - - for i in 1:n - node = nodes[i] - if haskey(g.memo, node) - # TODO really override the proof step here? - eclass = g[g.memo[node]] - for nn in eclass - if node == nn - # dbgproof(node) - # dbgproof(nn) - # println("$node == $nn") - nodes[i] = nn - end - end - end - end - # println("========================================") - # for i in 1:n - # nn = nodes[i] - # dbgproof(nn) - # end - # println("========================================") - @show reached(g, goal) - proof_bfs(g, nodes[1], nodes[2]) -end - -mutable struct ProofNode - state::ENode - why::Union{Nothing,Rule} - when::Int - cost::Number - parent::Union{Nothing, ProofNode} -end - -struct Proof - g::EGraph - head::ProofNode -end - -""" -A closured cost function that considers the age of the enode and the ast size. -""" -function oldestatage(age::Int) - return (n::ENode, g::EGraph, an::Type{<:AbstractAnalysis}) -> begin - # cost = 0 - # println("current age is $age") - # println("enode $n age is $(n.age)") - # cost = n.age - age - cost = n.age - age - # println("cost is $cost") - return cost - end -end - - - -function oldest_node_extract(g::EGraph, n::ENode, age::Int) - costfun = oldestatage(age) - ex = EGraphs.extractnode(g, n, costfun) - # println("extracted $ex aged $(n.age) at age $age") - return ex -end - -function Base.show(io::IO, mime::MIME"text/plain", proof::Proof) - lines = [] - curr = proof.head - while !isnothing(curr.parent) - ex = oldest_node_extract(proof.g, curr.state, curr.when) - pushfirst!(lines, "$ex") - pushfirst!(lines, "from $(repr("text/plain", curr.why))") - curr = curr.parent - end - ex = oldest_node_extract(proof.g, curr.state, curr.when) - pushfirst!(lines, "given $ex") - - for line in lines - println(io, line) - end -end - -# TODO go through each path in the proof. do a BFS? -using DataStructures -function proof_bfs(g::EGraph, src::ENode, trg::ENode) - root = ProofNode(src, nothing, src.age, 0, nothing) - if src == trg - return Proof(g, root) - end - frontier = ProofNode[] - explored = Set{ENode}() - push!(frontier, root) - while !isempty(frontier) - node = popfirst!(frontier) - # println("exploring $node") - push!(explored, node.state) - # todo take rules in account - for (rule, child_enode, age) in unique(node.state.proof_trg) #∪ node.state.proof_src - child = ProofNode(child_enode, rule, age, node.cost+1, node) - if child_enode ∉ explored && child ∉ frontier - # goal test - if child_enode == trg - return Proof(g, child) - end - push!(frontier, child) - end - end - end - error("proof not found!") -end - - -t = @theory begin - a * b == b * a - a * 2 == a + a -end - - -# g = EGraph() -# addexpr!(g, :(2 * a)) -# addexpr!(g, :(a + a)) -# saturate!(g, t) -# proof = prove(g, t, :(2 * a), :(a + a)) - -# Base.print(proof) - -prove(EGraph(), t, :(2 * x), :(x + x)) - -include("../logic/prop_logic_theory.jl") - -ex = :((x => (y => z)) => ((x => y) => (x => z))) -proof = prove(EGraph(), t, ex, true) - - -proof = prove(EGraph(), t, :((x => (x ∨ x))), :((¬(x) ∧ y) => y)) - -# TODO introduce a mechanism of enode age and egraph age to be able to extract -# precisely which enode was there at that moment -# to be printed, if rule was applied at time `t`, then extract the enodes that was -# applied earliest but before (or after???) `t` -ex = :(((x ∨ y) ∨ ¬(z ∧ a)) ∨ a) -proof = prove(EGraph(), t, ex, true) - -# chat with oliver -# introduce a new type of enode at the proof generation stage -# set of unknown variables - - -# keep proof at the expr level -# do not store src and trg in enode -# store holes -# -# keep another unionfind for the same holes \ No newline at end of file diff --git a/test/reductions.jl b/test/reductions.jl new file mode 100644 index 00000000..13abd670 --- /dev/null +++ b/test/reductions.jl @@ -0,0 +1,223 @@ +using Metatheory + +@testset "Reduction Basics" begin + t = @theory begin + ~a + ~a --> 2 * (~a) + ~x / ~x --> 1 + ~x * 1 --> ~x + end + + # basic theory to check that everything works + @test rewrite(:(a + a), t) == :(2a) + @test rewrite(:(a + (x * 1)), t) == :(a + x) + @test rewrite(:(a + (a * 1)), t; order = :inner) == :(2a) +end + + +## Free Monoid + +@testset "Free Monoid - Overriding identity" begin + # support symbol literals + function ⋅ end + symbol_monoid = @theory begin + ~a ⋅ :ε --> ~a + :ε ⋅ ~a --> ~a + ~a::Symbol --> ~a + ~a::Symbol ⋅ ~b::Symbol => Symbol(String(a) * String(b)) + # i |> error("unsupported ", i) + end + + @test rewrite(:(ε ⋅ a ⋅ ε ⋅ b ⋅ c ⋅ (ε ⋅ ε ⋅ d) ⋅ e), symbol_monoid; order = :inner) == :abcde +end + +## Interpolation should be possible at runtime + + +@testset "Calculator" begin + function ⊗ end + function ⊕ end + function ⊖ end + calculator = @theory begin + ~x::Number ⊕ ~y::Number => ~x + ~y + ~x::Number ⊗ ~y::Number => ~x * ~y + ~x::Number ⊖ ~y::Number => ~x ÷ ~y + ~x::Symbol --> ~x + ~x::Number --> ~x + end + a = 10 + + @test rewrite(:(3 ⊕ 1 ⊕ $a), calculator; order = :inner) == 14 +end + + +## Direct rules +@testset "Direct Rules" begin + t = @theory begin + # maps + ~a * ~b => ((~a isa Number && ~b isa Number) ? ~a * ~b : _lhs_expr) + end + @test rewrite(:(3 * 1), t) == 3 + + t = @theory begin + # maps + ~a::Number * ~b::Number => ~a * ~b + end + @test rewrite(:(3 * 1), t) == 3 +end + + + +## Take advantage of subtyping. +# Subtyping in Julia has been formalized in this paper +# [Julia Subtyping: A Rational Reconstruction](https://benchung.github.io/papers/jlsub.pdf) + +abstract type Vehicle end +abstract type GroundVehicle <: Vehicle end +abstract type AirVehicle <: Vehicle end +struct Airplane <: AirVehicle end +struct Car <: GroundVehicle end + +airpl = Airplane() +car = Car() + +t = @theory begin + ~a::AirVehicle * ~b => "flies" + ~a::GroundVehicle * ~b => "doesnt_fly" +end + +@testset "Subtyping" begin + + sf = rewrite(:($airpl * c), t) + df = rewrite(:($car * c), t) + + @test sf == "flies" + @test df == "doesnt_fly" +end + + +@testset "Interpolation" begin + airpl = Airplane() + car = Car() + t = @theory begin + airpl * ~b => "flies" + car * ~b => "doesnt_fly" + end + + sf = rewrite(:($airpl * c), t) + df = rewrite(:($car * c), t) + + @test sf == "flies" + @test df == "doesnt_fly" +end + +@testset "Segment Variables" begin + function f end + function ok end + t = @theory begin + f(~x, ~~y) => Expr(:call, :ok, (~~y)...) + end + sf = rewrite(:(f(1, 2, 3, 4)), t) + @test sf == :(ok(2, 3, 4)) + + t = @theory x y begin + f(x, y...) => Expr(:call, :ok, y...) + end + sf = rewrite(:(f(1, 2, 3, 4)), t) + @test sf == :(ok(2, 3, 4)) + + t = @theory x y begin + f(x, y...) --> ok(y...) + end + sf = rewrite(:(f(1, 2, 3, 4)), t) + @test sf == :(ok(2, 3, 4)) +end + + +module NonCall +using Metatheory +function ok end +t = [@rule a b (a, b) --> ok(a, b)] + +test() = rewrite(:(x, y), t) +end + +@testset "Non-Call expressions" begin + @test NonCall.test() == :(ok(x, y)) +end + + +@testset "Pattern matcher can match on both function object references and name symbols" begin + ex = :($(+)($(sin)(x)^2, $(cos)(x)^2)) + r = @rule(sin(~x)^2 + cos(~x)^2 --> 1) + + @test r(ex) == 1 +end + + + +@testset "Pattern variable as pattern term head" begin + foo(x) = x + 2 + ex = :(($foo)(bar, 2, pazz)) + r = @rule ((~f)(~x, 2, ~y) => (~f)(2)) + + @test r(ex) == 4 +end + +using TermInterface + +using Metatheory.Syntax: @capture +@testset "Capture form" begin + ex = :(a^a) + + #note that @test inserts a soft local scope (try-catch) that would gobble + #the matches from assignment statements in @capture macro, so we call it + #outside the test macro + ret = @capture ex (~x)^(~x) + @test ret + @test @isdefined x + @test x === :a + + ex = :(b^a) + ret = @capture ex (~y)^(~y) + @test !ret + @test !(@isdefined y) + + ret = @capture :(a + b) (+)(~~z) + @test ret + @test @isdefined z + @test all(z .=== arguments(:(a + b))) + + #a more typical way to use the @capture macro + + f(x) = + if @capture x (~w)^(~w) + w + end + + @test f(:(b^b)) == :b + @test isnothing(f(:(b + b))) + + x = 1 + r = (@capture x x) + @test r == true +end + +using TermInterface +@testset "Matchable struct" begin + struct qux + args + qux(args...) = new(args) + end + TermInterface.operation(::qux) = qux + TermInterface.istree(::qux) = true + TermInterface.arguments(x::qux) = [x.args...] + + @capture qux(1, 2) qux(1, 2) + + @test (@rule qux(1, 2) => "hello")(qux(1, 2)) == "hello" + @test (@rule qux(1, 2) => "hello")(1) === nothing + @test (@rule 1 => "hello")(1) == "hello" + @test (@rule 1 => "hello")(qux(1, 2)) === nothing + @test (@capture qux(1, 2) qux(1, 2)) + @test false == (@capture qux(1, 2) qux(3, 4)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 85d6788d..e8ffd337 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,43 +1,42 @@ using SafeTestsets +using Documenter using Metatheory using Test -@timev begin - @testset "All Tests" begin - @safetestset "Classical Rewriting" begin include("test_reductions.jl") end - @safetestset "EGraphs Basics" begin include("test_egraphs.jl") end - @safetestset "EMatch" begin include("test_ematch.jl") end - @safetestset "EGraph Analysis" begin include("test_analysis.jl") end - @safetestset "EGraph Extraction" begin include("test_extraction.jl") end - @safetestset "Mu Puzzle" begin include("test_mu.jl") end - @safetestset "While Interpreter" begin include("test_while_interpreter.jl") end - @safetestset "Taylor Series" begin include("test_taylor.jl") end - @safetestset "While Superinterpreter" begin include("test_while_superinterpreter.jl") end - @safetestset "EGraphs Inequalities" begin include("test_inequality.jl") end - @safetestset "Custom Types" begin include("test_custom_types.jl") end - @safetestset "Fibonacci" begin include("fib/test_fibonacci.jl") end - @safetestset "Calculational Logic" begin include("logic/test_calculational_logic.jl") end - @safetestset "PROP Logic" begin include("logic/test_logic.jl") end - @safetestset "CAS Infer" begin include("cas/test_infer.jl") end - @safetestset "Knuth Bendix Alternative Hurwitz Groups" begin include("test_kb_benchmark.jl") end - @safetestset "Stream Fusion" begin include("test_stream_fusion.jl") end +doctest(Metatheory) + +function test(file::String) + @info file + @eval @time @safetestset $file begin + include(joinpath(@__DIR__, $file)) + end +end + +const TEST_FILES = ["reductions.jl", "EGraphs/egraphs.jl", "EGraphs/ematch.jl", "EGraphs/analysis.jl"] +const INTEGRATION_TEST_FILES = map( + x -> joinpath(@__DIR__, "integration", x), + [ + "fibonacci.jl", + "kb_benchmark.jl", + "logic.jl", + "stream_fusion.jl", + "taylor.jl", + "while_superinterpreter.jl", + "lambda_theory.jl" + ], +) - - # @safetestset "Boson" begin include("test_boson.jl") end - # @safetestset "PatEquiv" begin include("test_patequiv.jl") end - # @testset "EGraphs Multipattern" begin include("test_multipat.jl") end - # @testset "PatAllTerm" begin include("test_patallterm.jl") end - # use cases - # @testset "Proofs" begin include("proof/test_proof.jl") end - # @testset "CAS" begin include("cas/test_cas.jl") end - # @safetestset "Categories" begin include("category/test_cat.jl") end - # TODO n-ary splatvar - end +const TUTORIALS = [joinpath(@__DIR__, "tutorials", x) for x in readdir("tutorials/") if endswith(x, ".jl")] + +@timev begin + @timev map(test, TEST_FILES) + @timev map(test, INTEGRATION_TEST_FILES) + @timev map(test, TUTORIALS) end - + # exported consistency test -for m ∈ [Metatheory, Metatheory.EGraphs, Metatheory.EGraphs.Schedulers] - for i ∈ propertynames(m) - xxx = getproperty(m, i) - end -end \ No newline at end of file +for m in [Metatheory, Metatheory.EGraphs, Metatheory.EGraphs.Schedulers] + for i in propertynames(m) + !hasproperty(m, i) && error("Module $m exports undefined symbol $i") + end +end diff --git a/test/test_analysis.jl b/test/test_analysis.jl deleted file mode 100644 index 438507f7..00000000 --- a/test/test_analysis.jl +++ /dev/null @@ -1,87 +0,0 @@ -# example assuming * operation is always binary - -# ENV["JULIA_DEBUG"] = Metatheory - -using Metatheory - -include("numberfold.jl") - -comm_monoid = @theory begin - ~a * ~b --> ~b * ~a - ~a * 1 --> ~a - ~a * (~b * ~c) --> (~a * ~b) * ~c -end - -G = EGraph(:(3 * 4)) -analyze!(G, NumberFold) - -# exit(0) - -@testset "Basic Constant Folding Example - Commutative Monoid" begin - @test (true == @areequalg G comm_monoid 3 * 4 12) - - @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2) -end - -@testset "Basic Constant Folding Example 2 - Commutative Monoid" begin - ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, NumberFold) - addexpr!(G, :(12 * a)) - println(saturate!(G, comm_monoid)) - display(G.classes); println() - @test (true == @areequalg G comm_monoid (12 * a) * b ((6 * 2) * b) * a) - @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) -end - -@testset "Basic Constant Folding Example - Adding analysis after saturation" begin - G = EGraph(:(3 * 4)) - # addexpr!(G, 12) - saturate!(G, comm_monoid) - addexpr!(G, :(a * 2)) - analyze!(G, NumberFold) - saturate!(G, comm_monoid) - - # display(G.classes); println() - # println(G.root) - # display(G.analyses[1].data); println() - - @test (true == areequal(G, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2))) - - ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, NumberFold) - params = SaturationParams(timeout=15) - @test areequal(G, comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), - :(((6 * 2) * b) * a); params=params) -end - -@testset "Infinite Loops analysis" begin - boson = @theory begin - 1 * ~x --> ~x - end - - - G = EGraph(:(1 * x)) - params = SaturationParams(timeout=100) - saturate!(G, boson, params) - ex = extract!(G, astsize) - - # println(ex) - - using Metatheory.EGraphs - boson = @theory begin - (:c * :cdag) --> :cdag * :c + 1 - ~a * (~b + ~c) --> (~a * ~b) + (~a * ~c) - (~b + ~c) * ~a --> (~b * ~a) + (~c * ~a) - # 1 * x => x - (~a * ~b) * ~c --> ~a * (~b * ~c) - ~a * (~b * ~c) --> (~a * ~b) * ~c - end - - g = EGraph(:(c * c * cdag * cdag)) - saturate!(g, boson) - ex = extract!(g, astsize_inv) - - # println(ex) -end diff --git a/test/test_custom_types.jl b/test/test_custom_types.jl deleted file mode 100644 index 1c6b038b..00000000 --- a/test/test_custom_types.jl +++ /dev/null @@ -1,81 +0,0 @@ -using Metatheory -using Metatheory.EGraphs -using TermInterface -using Test - -struct MyExpr - head::Any - # NOTE! this will not work, when replacing - # with z in the theory defined below, the arg type - # will be EGraphs.EClass! Additional manipulation - # is needed for custom term types with stricter arg types - # args::Vector{Union{Int, MyExpr}} - args::Vector{Any} - # additional metadata - foo::String - bar::Vector{Complex} - baz::Set{Int} -end - -import Base.(==) -(==)(a::MyExpr, b::MyExpr) = a.head == b.head && a.args == b.args && - a.foo == b.foo && a.bar == b.bar && a.baz == b.baz - -MyExpr(head, args) = MyExpr(head, args, "", Complex[], Set{Int}()) -MyExpr(head) = MyExpr(head, []) - -# Methods needed by `src/TermInterface.jl` -TermInterface.exprhead(e::MyExpr) = :call -TermInterface.operation(e::MyExpr) = e.head -TermInterface.arguments(e::MyExpr) = e.args -TermInterface.istree(e::Type{MyExpr}) = true -# NamedTuple -TermInterface.metadata(e::MyExpr) = (foo = e.foo, bar = e.bar, baz = e.baz) -EGraphs.preprocess(e::MyExpr) = MyExpr(e.head, e.args, uppercase(e.foo), e.bar, e.baz) - -# f(g(2), h(4)) with some metadata in h -hcall = MyExpr(:h, [4], "hello", [2 + 3im, 4 + 2im], Set{Int}([4,5,6])) -ex = MyExpr(:f, [MyExpr(:g, [2]), hcall]) - - -function TermInterface.similarterm(x::Type{MyExpr}, head, args; - metadata=("", Complex[], Set{Int64}()), exprhead=:call) - MyExpr(head, args, metadata...) -end - -# let's create an egraph -g = EGraph(ex; keepmeta=true) - - -# ========== !!! ============= !!! =============== -# ========== !!! ============= !!! =============== -# ========== !!! ============= !!! =============== - -settermtype!(g, :f, 2, MyExpr) -settermtype!(g, :f, 1, MyExpr) -settermtype!(g, :g, 1, MyExpr) - -# ========== !!! ============= !!! =============== -# ========== !!! ============= !!! =============== -# ========== !!! ============= !!! =============== - -# let's create an example theory -t = @theory a begin - # this way, z will be a regular expr - # f(g(2), a) => z(a) - # we can use dynamic rules to construct values of type MyExpr - # f(g(2), a) |> MyExpr(:z, [a]) - - # terms in the RHS inherit the type of terms in the lhs - f(g(2), a) --> f(a) -end - -saturate!(g, t) - -# display(g.classes) - -expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO", Complex[2 + 3im, 4 + 2im], Set([5, 4, 6]))], "", Complex[], Set{Int64}()) - -extracted = extract!(g, astsize) - -@test expected == extracted diff --git a/test/test_egraphs.jl b/test/test_egraphs.jl deleted file mode 100644 index 7980bb2b..00000000 --- a/test/test_egraphs.jl +++ /dev/null @@ -1,88 +0,0 @@ - -# ENV["JULIA_DEBUG"] = Metatheory -using Metatheory -using Metatheory.EGraphs -using Metatheory.EGraphs: in_same_set, find_root - -@testset "Merging" begin - testexpr = :((a * 2) / 2) - testmatch = :(a << 1) - G = EGraph(testexpr) - t2, _ = addexpr!(G, testmatch) - merge!(G, t2.id, EClassId(3)) - @test in_same_set(G.uf, t2.id, EClassId(3)) == true - # DOES NOT UPWARD MERGE -end - -# testexpr = :(42a + b * (foo($(Dict(:x => 2)), 42))) - -@testset "Simple congruence - rebuilding" begin - G = EGraph() - ec1, _ = addexpr!(G, :(f(a, b))) - ec2, _ = addexpr!(G, :(f(a, c))) - - testexpr = :(f(a, b) + f(a, c)) - - testec, _ = addexpr!(G, testexpr) - - t1, _ = addexpr!(G, :b) - t2, _ = addexpr!(G, :c) - display(G.classes); println() - - c_id = merge!(G, t2.id, t1.id) - # display(G.classes); println() - @test in_same_set(G.uf, c_id, t1.id) - @test in_same_set(G.uf, t2.id, t1.id) - # println(find_root!(G.uf, t2.id)) - # @test find_root(G.uf, t2.id) == 4 - rebuild!(G) - # f(a,b) = f(a,c) - # display(G.classes); println() - # display(G.memo); println() - - # for (id, ec) ∈ G.classes - # println(id) - # dump.(ec.nodes) - # end - - @test in_same_set(G.uf, ec1.id, ec2.id) -end - - -@testset "Simple nested congruence" begin - apply(n, f, x) = n == 0 ? x : apply(n - 1, f, f(x)) - f(x) = Expr(:call, :f, x) - - G = EGraph(:a) - - t1, _ = addexpr!(G, apply(6, f, :a)) - t2, _ = addexpr!(G, apply(9, f, :a)) - - c_id = merge!(G, t1.id, EClassId(1)) # a == apply(6,f,a) - c2_id = merge!(G, t2.id, EClassId(1)) # a == apply(9,f,a) - - # display(G.classes); println() - - rebuild!(G) - - # display(G.classes); println() - - t3, _ = addexpr!(G, apply(3, f, :a)) - t4, _ = addexpr!(G, apply(7, f, :a)) - - # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a - @test in_same_set(G.uf, t1.id, EClassId(1)) == true - @test in_same_set(G.uf, t2.id, EClassId(1)) == true - @test in_same_set(G.uf, t3.id, EClassId(1)) == true - @test in_same_set(G.uf, t4.id, EClassId(1)) == false - - # if m or n is prime, f(a) = a - t5, _ = addexpr!(G, apply(11, f, :a)) - t6, _ = addexpr!(G, apply(1, f, :a)) - c5_id = merge!(G, t5.id, EClassId(1)) # a == apply(11,f,a) - - rebuild!(G) - - @test in_same_set(G.uf, t5.id, EClassId(1)) == true - @test in_same_set(G.uf, t6.id, EClassId(1)) == true -end diff --git a/test/test_ematch.jl b/test/test_ematch.jl deleted file mode 100644 index 2f77f926..00000000 --- a/test/test_ematch.jl +++ /dev/null @@ -1,171 +0,0 @@ -using Metatheory -using Metatheory.Library - -falseormissing(x) = - x === missing || !x - -r = @theory begin - foo(~x, ~y) → 2 * ~x % ~y - foo(~x, ~y) → sin(~x) - sin(~x) → foo(~x, ~x) -end -@testset "Basic Equalities 1" begin - @test (@areequal r foo(b, c) foo(d, d)) == false -end - - -r = @theory begin - ~a * 1 → :foo - ~a * 2 → :bar - 1 * ~a → :baz - 2 * ~a → :mag -end - -@testset "Matching Literals" begin - g = EGraph(:(a * 1)) - addexpr!(g, :foo) - saturate!(g, r) - display(g.classes); println() - - @test (@areequal r a * 1 foo) == true - @test (@areequal r a * 2 foo) == false - @test (@areequal r a * 1 bar) == false - @test (@areequal r a * 2 bar) == true - - @test (@areequal r 1 * a baz) == true - @test (@areequal r 2 * a baz) == false - @test (@areequal r 1 * a mag) == false - @test (@areequal r 2 * a mag) == true - end - - -comm_monoid = @commutative_monoid (*) 1 -@testset "Basic Equalities - Commutative Monoid" begin - @test true == (@areequal comm_monoid a * (c * (1 * d)) c * (1 * (d * a)) ) - @test true == (@areequal comm_monoid x * y y * x ) - @test true == (@areequal comm_monoid (x * x) * (x * 1) x * (x * x) ) -end - - -comm_group = @commutative_group (+) 0 inv -t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) - - -@testset "Basic Equalities - Comm. Monoid, Abelian Group, Distributivity" begin - @test true == (@areequal t (a * b) + (a * c) a * (b + c) ) - @test true == (@areequal t a * (c * (1 * d)) c * (1 * (d * a)) ) - @test true == (@areequal t a + (b * (c * d)) ((d * c) * b) + a ) - @test true == (@areequal t (x + y) * (a + b) ((a * (x + y)) + b * (x + y)) ((x * (a + b)) + y * (a + b)) ) - @test true == (@areequal t (((x * a + x * b) + y * a) + y * b) (x + y) * (a + b) ) - @test true == (@areequal t a + (b * (c * d)) ((d * c) * b) + a ) - @test true == (@areequal t a + inv(a) 0 (x * y) + inv(x * y) 1 * 0 ) -end - - -@testset "Basic Equalities - False statements" begin - @test falseormissing(@areequal t (a * b) + (a * c) a * (b + a)) - @test falseormissing(@areequal t (a * c) + (a * c) a * (b + c)) - @test falseormissing(@areequal t a * (c * c) c * (1 * (d * a))) - @test falseormissing(@areequal t c + (b * (c * d)) ((d * c) * b) + a) - @test falseormissing(@areequal t (x + y) * (a + c) ((a * (x + y)) + b * (x + y))) - @test falseormissing(@areequal t ((x * (a + b)) + y * (a + b)) (x + y) * (a + c)) - @test falseormissing(@areequal t (((x * a + x * b) + y * a) + y * b) (x + y) * (a + x)) - @test falseormissing(@areequal t a + (b * (c * a)) ((d * c) * b) + a) - @test falseormissing(@areequal t a + inv(a) a) - @test falseormissing(@areequal t (x * y) + inv(x * y) 1) -end - -# Issue 21 -simp_theory = @theory begin - munit() => :foo -end -G = EGraph(:(munit())) -params = SaturationParams(timeout=1) -saturate!(G, simp_theory, params) - - -module Bar -foo = 42 -export foo - using Metatheory - - t = @theory begin - :woo => foo - end - export t -end - -module Foo - foo = 12 - using Metatheory - - t = @theory begin - :woo => foo - end - export t -end - - -g = EGraph(:woo); -saturate!(g, Bar.t); -saturate!(g, Foo.t); -foo = 12 - -@testset "Different modules" begin - @test @areequalg g t 42 12 -end - - -comm_monoid = @theory begin - ~a * ~b --> ~b * ~a - ~a * 1 --> ~a - ~a * (~b * ~c) --> (~a * ~b) * ~c - ~a::Number * ~b::Number => ~a * ~b -end - -G = EGraph(:(3 * 4)) -@testset "Basic Constant Folding Example - Commutative Monoid" begin - @test (true == @areequalg G comm_monoid 3 * 4 12) - @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2) -end - - -@testset "Basic Constant Folding Example 2 - Commutative Monoid" begin - ex = :(a * 3 * b * 4) - G = EGraph(ex) - @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) -end - -@testset "Type Assertions in Ematcher" begin - some_theory = @theory begin - ~a * ~b --> ~b * ~a - ~a::Number * ~b::Number --> matched(~a, ~b) - ~a::Int64 * ~b::Int64 --> specific(~a, ~b) - ~a * (~b * ~c) --> (~a * ~b) * ~c - end - - g = EGraph(:(2 * 3)) - saturate!(g, some_theory) - # display(g.classes) - - @test true == areequal(g, some_theory, :(2 * 3), :(matched(2, 3))) - @test true == areequal(g, some_theory, :(matched(2, 3)), :(specific(3, 2))) -end - -function Base.iszero(g::EGraph, ec::EClass) - n = ENodeLiteral(0) - return n ∈ ec -end - -@testset "Predicates in Ematcher" begin - some_theory = @theory begin - ~a::iszero * ~b --> 0 - ~a * ~b --> ~b * ~a - end - - g = EGraph(:(2 * 3)) - saturate!(g, some_theory) - # display(g.classes) - - @test true == areequal(g, some_theory, :(a * b * 0), 0) -end diff --git a/test/test_extraction.jl b/test/test_extraction.jl deleted file mode 100644 index 4c27bd1d..00000000 --- a/test/test_extraction.jl +++ /dev/null @@ -1,236 +0,0 @@ -using Metatheory -using Metatheory.Library - -include("numberfold.jl") - -comm_monoid = @commutative_monoid (*) 1 - -fold_mul = @theory begin - ~a::Number * ~b::Number => ~a * ~b -end - -t = comm_monoid ∪ fold_mul - - -@testset "Extraction 1 - Commutative Monoid" begin - G = EGraph(:(3 * 4)) - saturate!(G, t) - @test (12 == extract!(G, astsize)) - - ex = :(a * 3 * b * 4) - G = EGraph(ex) - params = SaturationParams(timeout=15) - saturate!(G, t, params) - extr = extract!(G, astsize) - println(extr) - @test extr == :((12 * a) * b) || extr == :(12 * (a * b)) || extr == :(a * (b * 12)) || - extr == :((a * b) * 12) || extr == :((12a) * b) || extr == :(a * (12b)) || - extr == :((b * (12a))) || extr == :((b * 12) * a) || extr == :((b * a) * 12) || - extr == :(b * (a * 12)) || extr == :((12b) * a) -end - -fold_add = @theory begin - ~a::Number + ~b::Number => ~a + ~b -end - -@testset "Extraction 2" begin - comm_group = @commutative_group (+) 0 inv - - - t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add - - # for i ∈ 1:20 - # sleep(0.3) - ex = :((x * (a + b)) + (y * (a + b))) - G = EGraph(ex) - saturate!(G, t) - # end - - extract!(G, astsize) == :((x + y) * (b + a)) -end - -@testset "Lazy Extraction 2" begin - comm_group = @commutative_group (+) 0 inv - - t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add - - # for i ∈ 1:20 - # sleep(0.3) - ex = :((x * (a + b)) + (y * (a + b))) - G = EGraph(ex) - saturate!(G, t) - # end - - extract!(G, astsize) == :((x + y) * (b + a)) -end - -@testset "Extraction - Adding analysis after saturation" begin - G = EGraph(:(3 * 4)) - addexpr!(G, 12) - saturate!(G, t) - addexpr!(G, :(a * 2)) - saturate!(G, t) - - saturate!(G, t) - - @test (12 == extract!(G, astsize)) - - # for i ∈ 1:100 - ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, NumberFold) - params = SaturationParams(timeout=15) - saturate!(G, comm_monoid, params) - - extr = extract!(G, astsize) - - @test extr == :((12 * a) * b) || extr == :(12 * (a * b)) || extr == :(a * (b * 12)) || - extr == :((a * b) * 12) || extr == :((12a) * b) || extr == :(a * (12b)) || - extr == :((b * (12a))) || extr == :((b * 12) * a) || extr == :((b * a) * 12) || - extr == :(b * (a * 12)) -end - - -comm_monoid = @commutative_monoid (*) 1 - -comm_group = @commutative_group (+) 0 inv - -powers = @theory begin - ~a * ~a → (~a)^2 - ~a → (~a)^1 - (~a)^~n * (~a)^~m → (~a)^(~n + ~m) -end -logids = @theory begin - log((~a)^~n) --> ~n * log(~a) - log(~x * ~y) --> log(~x) + log(~y) - log(1) --> 0 - log(:e) --> 1 - :e^(log(~x)) --> ~x -end - -t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add - -@testset "Complex Extraction" begin - G = EGraph(:(log(e) * log(e))) - params = SaturationParams(timeout=8) - saturate!(G, t, params) - # display(G.classes);println() - @test extract!(G, astsize) == 1 - - G = EGraph(:(log(e) * (log(e) * e^(log(3))) )) - params = SaturationParams(timeout=7) - saturate!(G, t, params) - @test extract!(G, astsize) == 3 - - @time begin - G = EGraph(:(a^3 * a^2)) - saturate!(G, t) - ex = extract!(G, astsize) - end - @test ex == :(a^5) - - @time begin - G = EGraph(:(a^3 * a^2)) - saturate!(G, t) - ex = extract!(G, astsize) - end - @test ex == :(a^5) - - function cust_astsize(n::ENodeTerm, g::EGraph, an::Type{<:AbstractAnalysis}) - cost = 1 + arity(n) - - if operation(n) == :^ - cost += 2 - end - - for id ∈ arguments(n) - eclass = g[id] - !hasdata(eclass, an) && (cost += Inf; break) - cost += last(getdata(eclass, an)) - end - return cost - end - - - function cust_astsize(n::ENodeLiteral, g::EGraph, an::Type{<:AbstractAnalysis}) - 1 - end - - - @time begin - G = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2)))) - saturate!(G, t) - @show getcost!(G, cust_astsize) - ex = extract!(G, cust_astsize) - end - @test ex == :(5 * log(a)) || ex == :(log(a) * 5) -end - -function costfun(n::ENodeTerm, g::EGraph, an) - arity(n) != 2 && (return 1) - left = arguments(n)[1] - left_class = g[left] - ENodeLiteral(:a) ∈ left_class.nodes ? 1 : 100 -end - -costfun(n::ENodeLiteral, g::EGraph, an) = 1 - - -moveright = @theory begin - (:b * (:a * ~c)) --> (:a * (:b * ~c)) -end - -expr = :(a * (a * (b * (a * b)))) -res = rewrite(expr, moveright) - -g = EGraph(expr) -saturate!(g, moveright) -resg = extract!(g, costfun) - -@testset "Symbols in Right hand" begin - @test resg == res == :(a * (a * (a * (b * b)))) -end - -co = @theory begin - foo(~x ⋅ :bazoo ⋅ :woo) --> Σ(:n * ~x) -end -@testset "Consistency with classical backend" begin - ex = :(foo(wa(rio) ⋅ bazoo ⋅ woo)) - g = EGraph(ex) - saturate!(g, co) - - res = extract!(g, astsize) - - resclassic = rewrite(ex, co) - - @test res == resclassic -end - - -@testset "No arguments" begin - ex = :(f()); - g = EGraph(ex); - @test :(f()) == extract!(g, astsize) - - ex = :(f() + g()) - - t = @theory begin - f() + g() --> h() - end; - - gg = EGraph(ex) - saturate!(gg, t) - @show getcost!(gg, astsize) - res = extract!(gg, astsize); - - @test res == :(h()) -end - - -@testset "Symbol or function object operators in expressions in EGraphs" begin - ex = :(($+)(x, y)) - t = [@rule a b a + b => 2] - g = EGraph(ex) - saturate!(g, t) - @test extract!(g, astsize) == 2 -end diff --git a/test/test_inequality.jl b/test/test_inequality.jl deleted file mode 100644 index 3d8104cc..00000000 --- a/test/test_inequality.jl +++ /dev/null @@ -1,14 +0,0 @@ -using Metatheory - -failme = @theory p begin - p ≠ ¬p - :foo == ¬:foo - :foo --> :bazoo - :bazoo --> :wazoo -end - -g = EGraph(:foo) -report = saturate!(g, failme) -println(report) -@test report.reason === :contradiction -# @test !(@areequal failme foo wazoo) diff --git a/test/test_kb_benchmark.jl b/test/test_kb_benchmark.jl deleted file mode 100644 index c7248364..00000000 --- a/test/test_kb_benchmark.jl +++ /dev/null @@ -1,61 +0,0 @@ -using Test -using Metatheory -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.Rules -using Metatheory.EGraphs.Schedulers - -function rep(x, op, n::Int) - foldl((x, y) -> :(($op)($x, $y)), repeat([x], n)) -end - -macro rep(x, op, n::Int) - expr = rep(x, op, n) - esc(expr) -end - -rep(:a, :*, 3) - -@rule (@rep :a (*) 3) => :b - -Mid = @theory a begin - a * :ε --> a - :ε * a --> a -end - -Massoc = @theory a b c begin - a * (b * c) --> (a * b) * c - (a * b) * c --> a * (b * c) -end - - -T = [ - @rule :b * :B --> :ε - @rule :a * :a --> :ε - @rule :b * :b * :b --> :ε - @rule :B * :B --> :B - @rule (@rep (:a * :b) (*) 7) --> :ε - @rule (@rep (:a * :b * :a * :B) (*) 7) --> :ε - # RewriteRule(makepattern(rep(:(:a * :b), :*, 7)), :ε) - # RewriteRule(makepattern(rep(:(:a * :b * :a * :B), :*, 12)), :ε) -] - -G = Mid ∪ Massoc ∪ T -expr = :(a * b * a * a * a * b * b * b * a * B * B * B * B * a) - -ex = expr -g = EGraph(expr) -params = SaturationParams(timeout=8, scheduler=BackoffScheduler)# , schedulerparams=(128,4))#, scheduler=SimpleScheduler) -@timev saturate!(g, G, params) -ex = extract!(g, astsize) -@test ex == :ε - -another_expr = :(b * B) -g = EGraph(another_expr) -saturate!(g, G, params) - -another_expr = :(a * a * a * a) -some_eclass, _ = addexpr!(g, another_expr) -saturate!(g, G, params) -ex = extract!(g, astsize; root=some_eclass.id) -@test ex == :ε \ No newline at end of file diff --git a/test/test_mu.jl b/test/test_mu.jl deleted file mode 100644 index df5d6937..00000000 --- a/test/test_mu.jl +++ /dev/null @@ -1,27 +0,0 @@ -# https://en.wikipedia.org/wiki/MU_puzzle#Solution - -using Metatheory - -miu = @theory x y z begin - # Composition of the string monoid is associative - x ⋅ (y ⋅ z) --> (x ⋅ y) ⋅ z - # Add a uf to the end of any string ending in I - x ⋅ :I ⋅ :END --> x ⋅ :I ⋅ :U ⋅ :END - # Double the string after the M - :M ⋅ x ⋅ :END --> :M ⋅ x ⋅ x ⋅ :END - # Replace any III with a U - :I ⋅ :I ⋅ :I --> :U - # Remove any UU - x ⋅ :U ⋅ :U ⋅ y --> x ⋅ y -end - - -@testset "MU puzzle" begin - # no matter the timeout we set here, - # MU is not a theorem of the MIU system - params = SaturationParams(timeout=12, eclasslimit=8000) - start = :(M ⋅ I ⋅ END) - g = EGraph(start) - saturate!(g, miu) - @test false == areequal(g, miu, start, :(M ⋅ U ⋅ END); params=params) -end \ No newline at end of file diff --git a/test/test_multipat.jl b/test/test_multipat.jl deleted file mode 100644 index 8a801452..00000000 --- a/test/test_multipat.jl +++ /dev/null @@ -1,67 +0,0 @@ -using Metatheory -using Test -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers - - -# ===================================================== - -# Zen Lineage Chart Example from Julog.jl https://github.com/ztangent/Julog.jl -# clauses = @julog [ -# ancestor(sakyamuni, bodhidharma) <<= true, -# teacher(bodhidharma, huike) <<= true, -# teacher(huike, sengcan) <<= true, -# teacher(sengcan, daoxin) <<= true, -# teacher(daoxin, hongren) <<= true, -# teacher(hongren, huineng) <<= true, -# ancestor(A, B) <<= teacher(A, B), -# ancestor(A, C) <<= teacher(B, C) & ancestor(A, B), -# grandteacher(A, C) <<= teacher(A, B) & teacher(B, C) -# ] - -facts = [ - :(ancestor(sakyamuni, bodhidharma)), - :(teacher(bodhidharma, huike)), - :(teacher(huike, sengcan)), - :(teacher(sengcan, daoxin)), - :(teacher(daoxin, hongren)), - :(teacher(hongren, huineng)), -] - -function addfacts!(g::EGraph, facts) - for fact ∈ facts - fc, _ = addexpr!(g, fact) - tc, _ = addexpr!(g, true) - merge!(g, fc.id, tc.id) - end -end - -clauses = @theory begin - teacher(a, b) => ancestor(a, b) - # grandteacher(A, C) <<= teacher(A, B) & teacher(B, C) -end - -# TODO syntax for MultiPatRewriteRule and PatEquiv -# ancestor(A, C) <<= teacher(B, C) & ancestor(A, B), -lhs = Pattern(:(teacher(b, c)) -rhs = Pattern(:(ancestor(a,c))) -pat1 = PatEquiv(Pattern(:(ancestor(a, b))), Pattern(:(teacher(b,c))) -q = MultiPatRewriteRule(lhs, rhs, [pat1]) - -push!(clauses, q) - -# goals to prove: ancestor(sakyamuni, huineng) -g = EGraph() -addfacts!(g, facts) - -query = :(ancestor(sakyamuni, huineng)) -addexpr!(g, query) - -params = SaturationParams(timeout=14) -saturate!(g, clauses, params) - -# display(g.classes); println() - -emptyt = @theory begin end -@test areequal(g, emptyt, true, query) diff --git a/test/test_patallterm.jl b/test/test_patallterm.jl deleted file mode 100644 index c602c815..00000000 --- a/test/test_patallterm.jl +++ /dev/null @@ -1,21 +0,0 @@ -using Metatheory -using Test -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers - -t = [ - RewriteRule(PatTerm(:call, PatVar(:f), [PatVar(:a), PatVar(:b)], @__MODULE__), PatLiteral(:matched)) -] - -g = EGraph(:(foo(bar))) -saturate!(g, t) - -@test !areequal(g, RewriteRule[], :(foo(bar)), :matched) - -addexpr!(g, :(foo(bar, baz))) -saturate!(g, t) - -display(g.classes); println() - -@test areequal(g, RewriteRule[], :(foo(bar,baz)), :matched) \ No newline at end of file diff --git a/test/test_patequiv.jl b/test/test_patequiv.jl deleted file mode 100644 index 4c96bec2..00000000 --- a/test/test_patequiv.jl +++ /dev/null @@ -1,32 +0,0 @@ -using Metatheory -using Test -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers - -# ================= TEST PATEQUIV ===================== - -# :foo => :zoo ⟺ :foo in same class as :bar - -lhs = PatEquiv(Pattern(:foo), Pattern(:bar)) -rhs = Pattern(:zoo) -q = RewriteRule(lhs, rhs) - - -g = EGraph() -fooclass, _ = addexpr!(g, :foo) -barclass, _ = addexpr!(g, :bar) -zooclass, _ = addexpr!(g, :zoo) - -# display(g.classes); println() -@test !(in_same_class(g, fooclass, zooclass)) - -saturate!(g, [q]) - -# display(g.classes); println() -@test !(in_same_class(g, fooclass, zooclass)) - -merge!(g, fooclass.id, zooclass.id) - -# display(g.classes); println() -@test in_same_class(g, fooclass, zooclass) diff --git a/test/test_patsplatvar.jl b/test/test_patsplatvar.jl deleted file mode 100644 index dc9d79c4..00000000 --- a/test/test_patsplatvar.jl +++ /dev/null @@ -1,14 +0,0 @@ -using Metatheory -using Test -using Metatheory.Library -using Metatheory.EGraphs -using Metatheory.EGraphs.Schedulers - -t = @theory begin - f(a...) |> (println.(a); 42) -end - -dump(t) - -g = EGraph(:(f(1,2,3))) -saturate!(g, t) \ No newline at end of file diff --git a/test/test_reductions.jl b/test/test_reductions.jl deleted file mode 100644 index 16b32c60..00000000 --- a/test/test_reductions.jl +++ /dev/null @@ -1,227 +0,0 @@ -using Metatheory - -@testset "Reduction Basics" begin - t = @theory begin - ~a + ~a --> 2*(~a) - ~x / ~x --> 1 - ~x * 1 --> ~x - end - - # basic theory to check that everything works - @test rewrite(:(a + a), t) == :(2a) - @test rewrite(:(a + (x * 1)), t) == :(a + x) - @test rewrite(:(a + (a * 1)), t; order=:inner) == :(2a) -end - - -import Base.(+) -@testset "Extending Algebra Operators" begin - t = @theory begin - ~a + ~a --> 2(~a) - end - - # Let's extend an operator from base, for sake of example - function +(x::Symbol, y) - rewrite(:($x + $y), t) - end - - @test (:x + :x) == :(2x) -end - -## Free Monoid - -@testset "Free Monoid - Overriding identity" begin - # support symbol literals - symbol_monoid = @theory begin - ~a ⋅ :ε --> ~a - :ε ⋅ ~a --> ~a - ~a::Symbol --> ~a - ~a::Symbol ⋅ ~b::Symbol => Symbol(String(a) * String(b)) - # i |> error("unsupported ", i) - end; - - @test rewrite(:(ε ⋅ a ⋅ ε ⋅ b ⋅ c ⋅ (ε ⋅ ε ⋅ d) ⋅ e), symbol_monoid; order=:inner) == :abcde -end - -## Interpolation should be possible at runtime - - -@testset "Calculator" begin - calculator = @theory begin - ~x::Number ⊕ ~y::Number => ~x + ~y - ~x::Number ⊗ ~y::Number => ~x * ~y - ~x::Number ⊖ ~y::Number => ~x ÷ ~y - ~x::Symbol --> ~x - ~x::Number --> ~x - end; - a = 10 - - @test rewrite(:(3 ⊕ 1 ⊕ $a), calculator; order=:inner) == 14 -end - - -## Direct rules -@testset "Direct Rules" begin - t = @theory begin - # maps - ~a * ~b => ((~a isa Number && ~b isa Number) ? ~a * ~b : _lhs_expr) - end - @test rewrite(:(3 * 1), t) == 3 - - t = @theory begin - # maps - ~a::Number * ~b::Number => ~a * ~b - end - @test rewrite(:(3 * 1), t) == 3 -end - - - -## Take advantage of subtyping. -# Subtyping in Julia has been formalized in this paper -# [Julia Subtyping: A Rational Reconstruction](https://benchung.github.io/papers/jlsub.pdf) - -abstract type Vehicle end -abstract type GroundVehicle <: Vehicle end -abstract type AirVehicle <: Vehicle end -struct Airplane <: AirVehicle end -struct Car <: GroundVehicle end - -airpl = Airplane() -car = Car() - -t = @theory begin - ~a::AirVehicle * ~b => "flies" - ~a::GroundVehicle * ~b => "doesnt_fly" -end - -@testset "Subtyping" begin - - sf = rewrite(:($airpl * c), t) - df = rewrite(:($car * c), t) - - @test sf == "flies" - @test df == "doesnt_fly" -end - - -@testset "Interpolation" begin - airpl = Airplane() - car = Car() - t = @theory begin - airpl * ~b => "flies" - car * ~b => "doesnt_fly" - end - - sf = rewrite(:($airpl * c), t) - df = rewrite(:($car * c), t) - - @test sf == "flies" - @test df == "doesnt_fly" -end - -@testset "Segment Variables" begin - t = @theory begin - f(~x, ~~y) => Expr(:call, :ok, (~~y)...) - end - - sf = rewrite(:(f(1,2,3,4)), t) - - @test sf == :(ok(2,3,4)) - - t = @theory x y begin - f(x, y...) => Expr(:call, :ok, y...) - end - - sf = rewrite(:(f(1,2,3,4)), t) - - @test sf == :(ok(2,3,4)) -end - - -module NonCall -using Metatheory -t = [@rule a b (a, b) --> ok(a,b)] - -test() = rewrite(:(x,y), t) -end - -@testset "Non-Call expressions" begin - @test NonCall.test() == :(ok(x,y)) -end - - -@testset "Pattern matcher can match on both function object references and name symbols" begin - ex = :($(+)($(sin)(x)^2, $(cos)(x)^2)) - r = @rule(sin(~x)^2 + cos(~x)^2 --> 1) - - @test r(ex) == 1 -end - - - -@testset "Pattern variable as pattern term head" begin - foo(x) = x+2 - ex = :(($foo)(bar, 2, pazz)) - r = @rule ((~f)(~x, 2, ~y) => (~f)(2)) - - @test r(ex) == 4 -end - -using TermInterface - -using Metatheory.Syntax: @capture -@testset "Capture form" begin - ex = :(a^a) - - #note that @test inserts a soft local scope (try-catch) that would gobble - #the matches from assignment statements in @capture macro, so we call it - #outside the test macro - ret = @capture ex (~x)^(~x) - @test ret - @test @isdefined x - @test x === :a - - ex = :(b^a) - ret = @capture ex (~y)^(~y) - @test !ret - @test !(@isdefined y) - - ret = @capture :(a + b) (+)(~~z) - @test ret - @test @isdefined z - @test all(z .=== arguments(:(a + b))) - - #a more typical way to use the @capture macro - - f(x) = if @capture x (~w)^(~w) - w - end - - @test f(:(b^b)) == :b - @test isnothing(f(:(b+b))) - - x = 1 - r = (@capture x x) - @test r == true -end - -using TermInterface -@testset "Matchable struct" begin - struct qux - args - qux(args...) = new(args) - end - TermInterface.operation(::qux) = qux - TermInterface.istree(::Type{qux}) = true - TermInterface.arguments(x::qux) = [x.args...] - - @capture qux(1, 2) qux(1, 2) - - @test (@rule qux(1, 2)=>"hello")(qux(1, 2)) == "hello" - @test (@rule qux(1, 2)=>"hello")(1) === nothing - @test (@rule 1=>"hello")(1) == "hello" - @test (@rule 1=>"hello")(qux(1, 2)) === nothing - @test (@capture qux(1, 2) qux(1, 2)) - @test false == (@capture qux(1,2) qux(3,4)) -end diff --git a/test/test_stream_fusion.jl b/test/test_stream_fusion.jl deleted file mode 100644 index a96503ac..00000000 --- a/test/test_stream_fusion.jl +++ /dev/null @@ -1,106 +0,0 @@ -using Metatheory -using Metatheory.Rewriters -using Test -using TermInterface -# using SymbolicUtils - -array_theory = @theory x y f g M N begin - #map(f,x)[n:m] = map(f,x[n:m]) # but does NOT commute with filter - map(f,fill(x,N)) == fill(apply(f,x), N) # hmm - # cumsum(fill(x,N)) == collect(x:x:(N*x)) - fill(x,N)[y] --> x - length(fill(x,N)) --> N - reverse(reverse(x)) --> x - sum(fill(x,N)) --> x * N - map(f,reverse(x)) == reverse(map(f, x)) - filter(f,reverse(x)) == reverse(filter(f,x)) - reverse(fill(x,N)) == fill(x,N) - filter(f, fill(x,N)) == (if apply(f, x); fill(x,N) else fill(x,0) end) - filter(f, filter(g, x)) == filter(fand(f,g), x) # using functional && - cat(fill(x,N),fill(x,M)) == fill(x,N + M) - cat(map(f,x), map(f,y)) == map(f, cat(x,y)) - map(f, cat(x,y)) == cat(map(f,x), map(f,y)) - map(f,map(g,x)) == map(f ∘ g, x) - reverse( cat(x,y) ) == cat(reverse(y), reverse(x)) - map(f,x)[y] == apply(f,x[y]) - apply(f ∘ g, x) == apply(f, apply(g, x)) - - reduce(g, map(f, x)) == mapreduce(f, g, x) - foldl(g, map(f, x)) == mapfoldl(f, g, x) - foldr(g, map(f, x)) == mapfoldr(f, g, x) -end - -asymptot_t = @theory x y z n m f g begin - (length(filter(f, x)) <= length(x)) => true - length(cat(x, y)) --> length(x) + length(y) - length(map(f, x)) => length(map) - length(x::UnitRange) => length(x) -end - -fold_theory = @theory x y z begin - x::Number * y::Number => x*y - x::Number + y::Number => x+y - x::Number / y::Number => x/y - x::Number - y::Number => x/y - # etc... -end - -# Simplify expressions like :(d->3:size(A,d)-3) given an explicit value for d -import Base.Cartesian: inlineanonymous - - -tryinlineanonymous(x) = nothing -function tryinlineanonymous(ex::Expr) - exprhead(ex) != :call && return nothing - f = operation(ex) - (!(f isa Expr) || exprhead(f) !== :->) && return nothing - arg = arguments(ex)[1] - println(arg) - try - return inlineanonymous(f, arg) - catch e - return nothing - end -end - -normalize_theory = @theory x y z f g begin - fand(f, g) => Expr(:->, :x, :(($f)(x) && ($g)(x))) - apply(f, x) => Expr(:call, f, x) -end - -params = SaturationParams() - -function stream_optimize(ex) - g = EGraph(ex) - rep = saturate!(g, array_theory, params) - @info rep - ex = extract!(g, astsize) # TODO cost fun with asymptotic complexity - ex = Fixpoint(Postwalk(Chain([tryinlineanonymous, normalize_theory..., fold_theory...])))(ex) - return ex -end - -build_fun(ex) = eval(:(()->$ex)) - - -@testset "Stream Fusion" begin - ex = :( map(x -> 7 * x, fill(3,4))) - opt = stream_optimize(ex) - @test opt == :(fill(21, 4)) - - ex = :( map(x -> 7 * x, fill(3,4) )[1]) - opt = stream_optimize(ex) - @test opt == 21 -end - -# ['a','1','2','3','4'] -ex = :(filter(ispow2, filter(iseven, reverse(reverse(fill(4, 100)))))) -opt = stream_optimize(ex) - - -ex = :( map(x -> 7 * x, reverse(reverse(fill(13,40))) )) -opt = stream_optimize(ex) -opt = stream_optimize(opt) - -macro stream_optimize(ex) - stream_optimize(ex) -end \ No newline at end of file diff --git a/test/test_taylor.jl b/test/test_taylor.jl deleted file mode 100644 index 2d9facd4..00000000 --- a/test/test_taylor.jl +++ /dev/null @@ -1,32 +0,0 @@ -using Metatheory - -taylor = @theory x a b begin - exp(x) --> Σ(x^:n / factorial(big(:n))) - cos(x) --> Σ((-1)^:n * x^2(:n) / factorial(big(2 * :n))) - Σ(a) + Σ(b) --> Σ(a + b) -end - -function expand(iters) - RewriteRule(PatTerm(:call, :Σ, [PatVar(:a)], @__MODULE__), - PatTerm(:call, :sum, [PatTerm(:(->), :(->), [:n, PatVar(:a)], @__MODULE__), 0:iters], @__MODULE__)) -end - -a = rewrite(:(exp(x) + cos(x)), taylor) - -r = expand(5000) -bexpr = rewrite(a, [r]) - -# you may want to do algebraic simplification -# with egraphs here - -x = big(42) - -b = eval(bexpr) -# 1.739274941520501044994695988622883932193276720547806372656638132701531037200611e+18 - -exp(x) + cos(x) -# 1.739274941520501046994695988622883932193276720547806372656638132701531037200651e+18 - -@testset "Infinite Series Approximation" begin - @test b ≈ (exp(x) + cos(x)) -end diff --git a/test/test_while_interpreter.jl b/test/test_while_interpreter.jl deleted file mode 100644 index cea40c73..00000000 --- a/test/test_while_interpreter.jl +++ /dev/null @@ -1,110 +0,0 @@ - -## Turing Complete Interpreter -### A Very Tiny Turing Complete Programming Language defined with denotational semantics - -# semantica dalle dispense degano - -using Metatheory -using Metatheory.Rewriters - -Mem = Dict{Symbol, Union{Bool, Int}} - -read_mem = @theory v σ begin - (v::Symbol, σ) => σ[v] -end - -@testset "Reading Memory" begin - @test 2 == rewrite(:((x), $(Mem(:x => 2))), read_mem; order=:inner) -end - -arithm_rules = @theory a b n σ begin - (a + b, σ) --> (a, σ) + (b, σ) - (a * b, σ) --> (a, σ) * (b, σ) - (a - b, σ) --> (a, σ) - (b, σ) - (a::Int + b::Int) => a + b - (a::Int * b::Int) => a * b - (a::Int - b::Int) => a - b - (n::Int, σ) => n -end - -strategy = (Fixpoint ∘ Postwalk ∘ Fixpoint ∘ Chain) - -eval_arithm(ex, mem) = - strategy(read_mem ∪ arithm_rules)(:($ex, $mem)) - - -@testset "Arithmetic" begin - @test 5 == eval_arithm(:(2 + 3), Mem()) - @test 4 == eval_arithm(:(2 + x), Mem(:x => 2)) -end - -# don't need to access memory -bool_rules = @theory a b σ begin - (a::Bool ∨ b::Bool) => (a || b) - (a::Bool ∧ b::Bool) => (a && b) - (a::Int < b::Int) => (a < b) - ¬a::Bool => !a - (a::Bool, σ) => a - (a < b, σ) => (eval_arithm(a, σ) < eval_arithm(b, σ)) - (¬b, σ) => !eval_bool(b, σ) - (a ∨ b, σ) --> (a, σ) ∨ (b, σ) - (a ∧ b, σ) --> (a, σ) ∧ (b, σ) -end - -eval_bool(ex, mem) = - strategy(bool_rules)(:($ex, $mem)) - -@testset "Booleans" begin - @test false == eval_bool(:(false ∨ false), Mem()) - @test true == eval_bool(:((false ∨ false) ∨ ¬(false ∨ false)), Mem(:x => 2)) - @test true == eval_bool(:((2 < 3) ∧ (3 < 4)), Mem(:x => 2)) - @test false == eval_bool(:((2 < x) ∨ ¬(3 < 4)), Mem(:x => 2)) - @test true == eval_bool(:((2 < x) ∨ ¬(3 < 4)), Mem(:x => 4)) -end - -if_rules = @theory guard t f σ begin - (if guard; t end, σ) --> (if guard; t else :skip end, σ) - (if guard; t else f end, σ) => - (eval_bool(guard, σ) ? :($t, $σ) : :($f, $σ)) -end - -eval_if(ex::Expr, mem::Mem) = - strategy(read_mem ∪ arithm_rules ∪ if_rules)(:($ex, $mem)) - -@testset "If Semantics" begin - @test 2 == eval_if(:(if true x else 0 end), Mem(:x => 2)) - @test 0 == eval_if(:(if false x else 0 end), Mem(:x => 2)) - @test 2 == eval_if(:(if ¬(false) x else 0 end), Mem(:x => 2)) - @test 0 == eval_if(:(if ¬(2 < x) x else 0 end), Mem(:x => 3)) -end - -while_rules = @theory guard a b σ begin - (:skip, σ) --> σ - ((:skip; b), σ) --> (b, σ) - ((a; b), σ) => begin - r = eval_while(a, σ); - (r isa Mem) ? :($b, $r) : :($b, $σ) - end - (while guard a end, σ) --> - (if guard; (a; while guard a end) else :skip end, σ) -end - - -write_mem = @theory sym val σ begin - (sym::Symbol = val, σ) => - (σ[sym] = eval_arithm(val, σ); σ) - # (println("BEFORE $memory"); memory[sym] = eval_arithm(val, memory); println("AFTER $memory"); memory) -end - -while_language = write_mem ∪ read_mem ∪ arithm_rules ∪ if_rules ∪ while_rules; - -eval_while(ex, mem) = - strategy(while_language)(:($ex, $mem)) - -@testset "While Semantics" begin - # @test Mem(:x => 3) == eval_while(:((x = 3)), Mem(:x => 2)) - @test Mem(:x => 5) == eval_while( :(x = 4; x = x + 1) , Mem(:x => 3)) - @test Mem(:x => 4) == eval_while( :( if x < 10; x = x + 1 end ) , Mem(:x => 3)) - # @test 10 == eval_while( :( while x < 10; x = x + 1 end ; x ) , Mem(:x => 3)) - # @test 50 == eval_while( :( while x < y; (x = x + 1; y = y - 1) end ; x ) , Mem(:x => 0, :y => 100)) -end diff --git a/test/test_while_superinterpreter.jl b/test/test_while_superinterpreter.jl deleted file mode 100644 index 702e3e9c..00000000 --- a/test/test_while_superinterpreter.jl +++ /dev/null @@ -1,132 +0,0 @@ - -## Turing Complete Interpreter -### A Very Tiny Turing Complete Programming Language defined with denotational semantics - -# semantica dalle dispense degano -using Metatheory - -import Base.ImmutableDict -Mem = Dict{Symbol,Union{Bool,Int}} - -read_mem = @theory v σ begin - (v::Symbol, σ::Mem) => if v == :skip σ else σ[v] end -end - -@testset "Reading Memory" begin - ex = :((x), $(Mem(:x => 2))) - @test true == areequal(read_mem, ex, 2) -end - -arithm_rules = @theory a b σ begin - (a + b, σ::Mem) --> (a, σ) + (b, σ) - (a * b, σ::Mem) --> (a, σ) * (b, σ) - (a - b, σ::Mem) --> (a, σ) - (b, σ) - (a::Int, σ::Mem) --> a - (a::Int + b::Int) => a + b - (a::Int * b::Int) => a * b - (a::Int - b::Int) => a - b -end - - -@testset "Arithmetic" begin - @test areequal(read_mem ∪ arithm_rules, :((2 + 3), $(Mem())), 5) -end - -# don't need to access memory -bool_rules = @theory a b σ begin - (a < b, σ::Mem) --> (a, σ) < (b, σ) - (a ∨ b, σ::Mem) --> (a, σ) ∨ (b, σ) - (a ∧ b, σ::Mem) --> (a, σ) ∧ (b, σ) - (¬(a), σ::Mem) --> ¬((a, σ)) - - (a::Bool, σ::Mem) => a - (¬a::Bool) => !a - (a::Bool ∨ b::Bool) => (a || b) - (a::Bool ∧ b::Bool) => (a && b) - (a::Int < b::Int) => (a < b) -end - -t = read_mem ∪ arithm_rules ∪ bool_rules - -@testset "Booleans" begin - @test areequal(t, :((false ∨ false), $(Mem())), false) - - exx = :((false ∨ false) ∨ ¬(false ∨ false), $(Mem(:x => 2))) - g = EGraph(exx) - saturate!(g, t) - ex = extract!(g, astsize) - @test ex == true - params = SaturationParams(timeout=12) - @test areequal(t, exx, true; params=params) - - @test areequal(t, :((2 < 3) ∧ (3 < 4), $(Mem(:x => 2))), true) - @test areequal(t, :((2 < x) ∨ ¬(3 < 4), $(Mem(:x => 2))), false) - @test areequal(t, :((2 < x) ∨ ¬(3 < 4), $(Mem(:x => 4))), true) -end - -if_rules = @theory guard t f σ begin - (if guard; t end) --> (if guard; t else :skip end) - (if guard; t else f end, σ::Mem) --> (if (guard, σ); t else f end, σ) - (if true; t else f end, σ::Mem) --> (t, σ) - (if false; t else f end, σ::Mem) --> (f, σ) -end - -if_language = read_mem ∪ arithm_rules ∪ bool_rules ∪ if_rules - - -@testset "If Semantics" begin - @test areequal(if_language, 2, :(if true x else 0 end, $(Mem(:x => 2)))) - @test areequal(if_language, 0, :(if false x else 0 end, $(Mem(:x => 2)))) - @test areequal(if_language, 2, :(if ¬(false) x else 0 end, $(Mem(:x => 2)))) - params = SaturationParams(timeout=10) - @test areequal(if_language, 0, :(if ¬(2 < x) x else 0 end, $(Mem(:x => 3))); params=params) -end - - -while_rules = @theory a b σ begin - (:skip, σ::Mem) --> σ - ((a; b), σ::Mem) --> ((a, σ); b) - (a::Int; b) --> b - (a::Bool; b) --> b - (σ::Mem; b) --> (b, σ) - (while a b end, σ::Mem) --> - (if a; (b; while a b end) else :skip end, σ) -end - - -write_mem = @theory sym val σ begin - (sym::Symbol = val, σ::Mem) --> (sym = (val, σ), σ) - (sym::Symbol = val::Int, σ::Mem) => merge(σ, Dict(sym => val)) -end - -while_language = if_language ∪ write_mem ∪ while_rules; - -@testset "While Semantics" begin - exx = :((x = 3), $(Mem(:x => 2))) - g = EGraph(exx) - saturate!(g, while_language) - ex = extract!(g, astsize) - - @test areequal(while_language, Mem(:x => 3), exx) - - exx = :((x = 4; x = x + 1), $(Mem(:x => 3))) - g = EGraph(exx) - saturate!(g, while_language) - ex = extract!(g, astsize) - - params = SaturationParams(timeout=10) - @test areequal(while_language, Mem(:x => 5), exx; params=params) - - params = SaturationParams(timeout=14) - exx = :((if x < 10 x = x + 1 else skip end), $(Mem(:x => 3))) - @test areequal(while_language, Mem(:x => 4), exx; params=params) - - exx = :( - (while x < 10 - x = x + 1 - end; x), $(Mem(:x => 3))) - g = EGraph(exx) - params = SaturationParams(timeout=100, scheduler=Schedulers.SimpleScheduler) - saturate!(g, while_language, params) - @test 10 == extract!(g, astsize) -end diff --git a/test/thesis_example.jl b/test/thesis_example.jl index 2bd8534f..6279bb76 100644 --- a/test/thesis_example.jl +++ b/test/thesis_example.jl @@ -2,26 +2,26 @@ using Metatheory using Metatheory.EGraphs using TermInterface -abstract type SignAnalysis <: AbstractAnalysis end +# TODO update -function EGraphs.make(an::Type{SignAnalysis}, g::EGraph, n::ENodeLiteral{<:Real}) - if n.value == Inf - return Inf - elseif n.value == -Inf +function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeLiteral{<:Real}) + if n.value == Inf + return Inf + elseif n.value == -Inf return -Inf elseif n.value isa Real # in Julia NaN is a Real return sign(n.value) - else + else return nothing end end -function EGraphs.make(an::Type{SignAnalysis}, g::EGraph, n::ENodeTerm) +function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeTerm) # Let's consider only binary function call terms. if exprhead(n) == :call && arity(n) == 2 # get the symbol name of the operation - op = operation(n) - op = op isa Function ? nameof(op) : op + op = operation(n) + op = op isa Function ? nameof(op) : op # Get the left and right child eclasses child_eclasses = arguments(n) @@ -33,7 +33,7 @@ function EGraphs.make(an::Type{SignAnalysis}, g::EGraph, n::ENodeTerm) lsign = getdata(l, an, nothing) rsign = getdata(r, an, nothing) - (lsign == nothing || rsign == nothing ) && return nothing + (lsign == nothing || rsign == nothing) && return nothing if op == :* return lsign * rsign @@ -54,21 +54,21 @@ function EGraphs.make(an::Type{SignAnalysis}, g::EGraph, n::ENodeTerm) return nothing end -function EGraphs.join(an::Type{SignAnalysis}, a, b) +function EGraphs.join(::Val{:sign_analysis}, a, b) return a == b ? a : nothing end -function EGraphs.make(an::Type{SignAnalysis}, g::EGraph, n::ENodeLiteral{Symbol}) - s = n.value +function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeLiteral{Symbol}) + s = n.value s == :x && return 1 - s == :y && return -1 + s == :y && return -1 s == :z && return 0 - s == :k && return Inf + s == :k && return Inf return nothing end # we are cautious, so we return false by default -isnotzero(g::EGraph, x::EClass) = getdata(x, SignAnalysis, false) +isnotzero(g::EGraph, x::EClass) = getdata(x, :sign_analysis, false) # t = @theory a b c begin # a * (b * c) == (a * b) * c @@ -83,20 +83,20 @@ isnotzero(g::EGraph, x::EClass) = getdata(x, SignAnalysis, false) function custom_analysis(expr) g = EGraph(expr) # saturate!(g, t) - analyze!(g, SignAnalysis) - return getdata(g[g.root], SignAnalysis) + analyze!(g, :sign_analysis) + return getdata(g[g.root], :sign_analysis) end -custom_analysis(:(3*x)) # :odd -custom_analysis(:(3*(2+a)*2)) # :even -custom_analysis(:(-3y * (2x*y))) # :even -custom_analysis(:(k/k)) # :even +custom_analysis(:(3 * x)) # :odd +custom_analysis(:(3 * (2 + a) * 2)) # :even +custom_analysis(:(-3y * (2x * y))) # :even +custom_analysis(:(k / k)) # :even #===========================================================================================# # pattern variables can be specified before the block of rules -comm_monoid = @theory a b c begin +comm_monoid = @theory a b c begin a * b == b * a # commutativity a * 1 --> a # identity a * (b * c) == (a * b) * c # associativity @@ -120,12 +120,12 @@ end; div_sim = @theory a b c begin (a * b) / c == a * (b / c) - a::isnotzero / a::isnotzero --> 1 + a::isnotzero / a::isnotzero --> 1 end; -t = vcat(comm_monoid, comm_group, folder, div_sim) ; +t = vcat(comm_monoid, comm_group, folder, div_sim); -g = EGraph(:(a * (2 * 3) / 6)) ; -saturate!(g, t) +g = EGraph(:(a * (2 * 3) / 6)); +saturate!(g, t) ex = extract!(g, astsize) -# :a \ No newline at end of file +# :a diff --git a/test/tutorials/README.md b/test/tutorials/README.md new file mode 100644 index 00000000..ded81be9 --- /dev/null +++ b/test/tutorials/README.md @@ -0,0 +1,4 @@ +# Literate tests + +This folder contains Julia scripts in the [Literate.jl](https://fredrikekre.github.io/Literate.jl/v2/) format. +Such scripts are executed by tests, and are also included in the generated documentation. \ No newline at end of file diff --git a/test/tutorials/custom_types.jl b/test/tutorials/custom_types.jl new file mode 100644 index 00000000..04be18e0 --- /dev/null +++ b/test/tutorials/custom_types.jl @@ -0,0 +1,113 @@ +# # Interfacing with Metatheory.jl +# This section is for Julia package developers who may want to use the rule +# rewriting systems on their own expression types. +# ## Defining the interface +# +# Metatheory.jl matchers can match any Julia object that implements an interface +# to traverse it as a tree. The interface in question, is defined in the +# [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) package. +# Its purpose is to provide a shared interface between various symbolic +# programming Julia packages. +# In particular, you should define methods from TermInterface.jl for an expression +# tree type `T` with symbol types `S` to work with SymbolicUtils.jl +# You can read the documentation of +# [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) on the +# [Github repository](https://github.com/JuliaSymbolics/TermInterface.jl). + +# ## Concrete example + +using Metatheory, TermInterface, Test +using Metatheory.EGraphs + +# We first define our custom expression type in `MyExpr`: +# It behaves like `Expr`, but it adds some extra fields. +struct MyExpr + head::Any + args::Vector{Any} + foo::String # additional metadata +end +MyExpr(head, args) = MyExpr(head, args, "") +MyExpr(head) = MyExpr(head, []) + +# We also need to define equality for our expression. +function Base.:(==)(a::MyExpr, b::MyExpr) + a.head == b.head && a.args == b.args && a.foo == b.foo +end + +# ## Overriding `TermInterface`` methods + +# First, we need to discern when an expression is a leaf or a tree node. +# We can do it by overriding `istree`. +TermInterface.istree(::MyExpr) = true + +# The `operation` function tells us what's the node's represented operation. +TermInterface.operation(e::MyExpr) = e.head +# `arguments` tells the system how to extract the children nodes. +TermInterface.arguments(e::MyExpr) = e.args + +# A particular function is `exprhead`. It is used to bridge our custom `MyExpr` +# type, together with the `Expr` functionality that is used in Metatheory rule syntax. +# In this example we say that all expressions of type `MyExpr`, can be represented (and matched against) by +# a pattern that is represented by a `:call` Expr. +TermInterface.exprhead(::MyExpr) = :call + +# While for common usage you will always define `exprhead` it to be `:call`, +# there are some cases where you would like to match your expression types +# against more complex patterns, for example, to match an expression `x` against an `a[b]` kind of pattern, +# you would need to inform the system that `exprhead(x)` is `:ref`, because +dump(:(a[b])) + + +# `metadata` should return the extra metadata. If you have many fields, i suggest using a `NamedTuple`. +TermInterface.metadata(e::MyExpr) = e.foo + +# Additionally, you can override `EGraphs.preprocess` on your custom expression +# to pre-process any expression before insertion in the E-Graph. +# In this example, we always `uppercase` the `foo::String` field of `MyExpr`. +EGraphs.preprocess(e::MyExpr) = MyExpr(e.head, e.args, uppercase(e.foo)) + + +# `TermInterface` provides a very important function called `similarterm`. +# It is used to create a term that is in the same closure of types of `x`. +# Given an existing term `x`, it is used to instruct Metatheory how to recompose +# a similar expression, given a `head` (the result of `operation`), some children (given by `arguments`) +# and additionally, `metadata` and `exprehead`, in case you are recomposing an `Expr`. +function TermInterface.similarterm(x::MyExpr, head, args; metadata = nothing, exprhead = :call) + MyExpr(head, args, isnothing(metadata) ? "" : metadata) +end + +# Since `similarterm` works by making a new term similar to an existing term `x`, +# in the e-graphs system, there won't be enough information such as a 'reference' object. +# Only the type of the object is known. This extra function adds a bit of verbosity, due to compatibility +# with SymbolicUtils.jl +function EGraphs.egraph_reconstruct_expression(::Type{MyExpr}, op, args; metadata = nothing, exprhead = nothing) + MyExpr(op, args, (isnothing(metadata) ? () : metadata)) +end + +# ## Theory Example + +# Note that terms in the RHS will inherit the type of terms in the LHS. + +function f end +function h end +function z end +t = @theory a begin + f(z(2), a) --> f(a) +end + +# Let's create an example expression and e-graph +hcall = MyExpr(h, [4], "hello") +ex = MyExpr(f, [MyExpr(z, [2]), hcall]) +g = EGraph(ex; keepmeta = true) + +# We use `settermtype!` on an existing e-graph to inform the system about +# the *default* type of expressions that we want newly added expressions to have. +settermtype!(g, MyExpr) + +# Now let's test that it works. +saturate!(g, t) +expected = MyExpr(f, [MyExpr(h, [4], "HELLO")], "") +extracted = extract!(g, astsize) +@test expected == extracted + + diff --git a/test/tutorials/mu.jl b/test/tutorials/mu.jl new file mode 100644 index 00000000..0869e5c1 --- /dev/null +++ b/test/tutorials/mu.jl @@ -0,0 +1,33 @@ +# # The MU Puzzle +# The puzzle cannot be solved: it is impossible to change the string MI into MU +# by repeatedly applying the given rules. In other words, MU is not a theorem of +# the MIU formal system. To prove this, one must step "outside" the formal system +# itself. +# https://en.wikipedia.org/wiki/MU_puzzle#Solution + +using Metatheory, Test + +# Here are the axioms of MU: +# * Composition of the string monoid is associative +# * Add a uf to the end of any string ending in I +# * Double the string after the M +# * Replace any III with a U +# * Remove any UU +function ⋅ end +miu = @theory x y z begin + x ⋅ (y ⋅ z) --> (x ⋅ y) ⋅ z + x ⋅ :I ⋅ :END --> x ⋅ :I ⋅ :U ⋅ :END + :M ⋅ x ⋅ :END --> :M ⋅ x ⋅ x ⋅ :END + :I ⋅ :I ⋅ :I --> :U + x ⋅ :U ⋅ :U ⋅ y --> x ⋅ y +end + + +# No matter the timeout we set here, +# MU is not a theorem of the MIU system +params = SaturationParams(timeout = 12, eclasslimit = 8000) +start = :(M ⋅ I ⋅ END) +g = EGraph(start) +saturate!(g, miu) +@test false == areequal(g, miu, start, :(M ⋅ U ⋅ END); params = params) + diff --git a/test/tutorials/while_interpreter.jl b/test/tutorials/while_interpreter.jl new file mode 100644 index 00000000..27b7b877 --- /dev/null +++ b/test/tutorials/while_interpreter.jl @@ -0,0 +1,241 @@ + +#= +# Write a very tiny Turing Complete language in Julia. + +WHILE is a very tiny Turing Complete Programming Language defined by denotational semantics. +Semantics come from the excellent +[course notes](http://pages.di.unipi.it/degano/ECC-uno.pdf) in *"Elements of computability and +complexity"* by prof. [Pierpaolo Degano](http://pages.di.unipi.it/degano/). + +It is a toy C-like language used to explain the core concepts of computability and Turing-completeness. +The name WHILE, comes from the fact that the most complicated construct in the language is a WHILE loop. +The language supports: +* A variable-value memory that can be pre-defined for program input. +* Integer arithmetics. +* Boolean logic. +* Conditional if-then-else statement called `cond`. +* Running a command after another with `seq(c1,c2)`. +* Repeatedly applying a command `c` while a condition `g` holds with `loop(g,c)`. + +This is enough to be Turing-complete! + +We are going to implement this tiny imperative language with classical rewriting rules in [Metatheory.jl](https://github.com/JuliaSymbolics/Metatheory.jl/). +WHILE is implemented in around 55 readable lines of code, and reaches around 80 lines with tests. + +The goal of this tutorial is to show an implementation of a programming language interpreter that is very, very very close to the +simple theory used to describe it in a textbook. Each denotational semantics rule in the course notes is a Metatheory.jl rewrite rule, with a few extras and minor naming changes. +The idea, is that Julia is a really valid didactical programming language! + +=# + +# Let's load the Metatheory and Test packages. +using Test, Metatheory + +# ## Memory +# The first thing that our programming language needs, is a model of the *computer memory*, +# that is going to hold the state of the programs. We define the type of +# WHILE's memory as a map from variables (Julia `Symbol`s) to actual values. +# We want to keep things simple so in our toy programming language we are just going to use boolean or integer values. Surprisingly, we can still achieve turing completeness without having to introduce strings or any other complex data type. +# We are going to use the letter `σ` (sigma) to denote an actual value of type `Mem`, in simple words the state of a program in a given moment. +# For example, if a `σ::Mem` holds the value `σ[:a] = 2`, this means that at that given moment, in our program +# the variable `a` holds the value 2. + +Mem = Dict{Symbol,Union{Bool,Int}} + +# We are now ready to define our first rewrite rule. +# In WHILE, un-evaluated expressions are represented by a tuple of `(program, state)`. +# This simple rule tells us that, if at a given memory state `σ` we want to know the value of a variable `v`, we +# can simply read it from the memory and return the value. +read_mem = @theory v σ begin + (v::Symbol, σ::Mem) => σ[v] +end + +# Let's test this behavior. We first create a `Mem`, holding the variable `x` with value 2. +σ₁ = Mem(:x => 2) + +# Then, we define a program. Julia helps us avoid unneeded complications. +# Generally, to create an interpreted programming language, one would have to design a syntax for it, and then engineer components such as +# a lexer or a [parser](https://en.wikipedia.org/wiki/Parsing) in order to turn the input string into a manipulable, structured program. +# The Julia developers were really smart. We can directly re-use the whole Julia syntax, because Julia +# allows us to treat programs as values. You can try this by prefixing any expression you type in the REPL inside of `:( ... )` or `quote ... end`. +# If you type this in the Julia REPL: +2 + 2 + +# You get the obvious result out, but if you wrap it in `quote` or `:(...)`, you can see that the program will not be executed, but instead stored as an `Expr`. +some_expr = :(2 + 2) +dump(some_expr) + +# We can use the `$` unary operator to interpolate and insert values inside of quoted code. +:(2 + $(1 + 1)) + +# These code-manipulation utilities can be very useful, because we can completely skip the burden of having to write a new syntax for our educational programming language, and just +# re-use Julia's syntax. It hints us that Julia is very powerful, because you can define new semantics and customize the language's behaviour without +# having to leave the comfort of the Julia terminal. This is also how julia `@macros` work. +# The practice of manipulating programs in the language itself is called **Metaprogramming**, +# and you can read more about metaprogramming in Julia [in the official docs](https://docs.julialang.org/en/v1/manual/metaprogramming/). + + +# Let's test that our first, simple rule is working. +program = :(x, $σ₁) +@test rewrite(program, read_mem) == 2 + +# ## Arithmetics +# How can our programming language be turing complete if we do not include basic arithmetics? +# If we have an integer and a memory state, we can just keep the integer +# The following rules are the first cases of recursion. +# Given two expressions `a,b`, to know what's `a + b` in state `σ`, +# we need to know first what `a` and `b` are in state σ +# The last dynamic rules let us directly evaluate arithmetic operations. + +arithm_rules = @theory a b n σ begin + (n::Int, σ::Mem) --> n + (a + b, σ::Mem) --> (a, σ) + (b, σ) + (a * b, σ::Mem) --> (a, σ) * (b, σ) + (a - b, σ::Mem) --> (a, σ) - (b, σ) + (a::Int + b::Int) => a + b + (a::Int * b::Int) => a * b + (a::Int - b::Int) => a - b +end + + +# ## Evaluation strategy +# We now have some nice denotational semantic rules for arithmetics, but in what order should we apply them? +# Metatheory.jl provides a flexible rewriter combinator library. You can read more in the [Rewriters](@ref) module docs. +# +# Given a set of rules, we can define a rewriter strategy by functionally composing rewriters. +# First, we want to use `Chain` to combine together the many rules in the theory, and to try to apply them one-by-one on our expressions. +# +# But should we first evaluate the outermost operations in the expression, or the innermost? +# Intuitively, if we have the program `(1 + 2) - 3`, it can hint us that we do want to first evaluate the innermost expressions. +# To do so, we then pass the result to the [Postwalk](@ref) rewriter, which recursively walks the input expression tree, and applies the rewriter first on +# the inner expressions, and then, on the outer, rewritten expression. (Hence the name `Post`-walk. Can you guess what [Prewalk](@ref) does?). +# +# The last component of our strategy is the [Fixpoint](@ref) combinator. This combinator repeatedly applies the rewriter on the input expression, +# and it does stop looping only when the output expression is the unchanged input expression. + +using Metatheory.Rewriters +strategy = (Fixpoint ∘ Postwalk ∘ Chain) + +# In Metatheory.jl, rewrite theories are just vectors of [Rules](@ref). It means we can compose them by concatenating the vectors, or elegantly using the +# built-in set operations provided by the Julia language. +arithm_lang = read_mem ∪ arithm_rules + +# We can define a convenience function that takes an expression, a memory state and calls our strategy. +eval_arithm(ex, mem) = strategy(arithm_lang)(:($ex, $mem)) + + +# Does it work? +@test eval_arithm(:(2 + 3), Mem()) == 5 + +# Yay! Let's say that before the program started, the computer memory already held a variable `x` with value 2. +@test eval_arithm(:(2 + x), Mem(:x => 2)) == 4 + + +# ## Boolean Logic +# To be Turing-complete, our tiny WHILE language requires boolean logic support. +# There's nothing special or different from other programming languages. These rules +# define boolean operations to work just as you would expect, and in the same way we defined arithmetic rules for integers. +# +# We need to bridge together the world of integer arithmetics and boolean logic to achieve something useful. +# The last two rules in the theory. + +bool_rules = @theory a b σ begin + (a::Bool || b::Bool) => (a || b) + (a::Bool && b::Bool) => (a && b) + !a::Bool => !a + (a::Bool, σ::Mem) => a + (!b, σ::Mem) => !eval_bool(b, σ) + (a || b, σ::Mem) --> (a, σ) || (b, σ) + (a && b, σ::Mem) --> (a, σ) && (b, σ) + (a < b, σ::Mem) => (eval_arithm(a, σ) < eval_arithm(b, σ)) # This rule bridges together ints and bools + (a::Int < b::Int) => (a < b) +end + +eval_bool(ex, mem) = strategy(bool_rules)(:($ex, $mem)) + +# Let's run a few tests. +@test all( + [ + eval_bool(:(false || false), Mem()) == false + eval_bool(:((false || false) || !(false || false)), Mem(:x => 2)) == true + eval_bool(:((2 < 3) && (3 < 4)), Mem(:x => 2)) == true + eval_bool(:((2 < x) || !(3 < 4)), Mem(:x => 2)) == false + eval_bool(:((2 < x) || !(3 < 4)), Mem(:x => 4)) == true + ], +) + +# ## Conditionals: If-then-else + +# Conditional expressions in our language take the form of +# `cond(guard, thenbranch)` or `cond(guard, branch, elsebranch)` +# It means that our program at this point will: +# 1. Evaluate the `guard` expressions +# 2. If `guard` evaluates to `true`, then evaluate `thenbranch` +# 3. If `guard` evaluates to `false`, then evaluate `elsebranch` + +# The first rule here is simple. If there's no `elsebranch` in the +# `cond` statement, we add an empty one with the `skip` command. +# Otherwise, we piggyback on the existing Julia if-then-else ternary operator. +# To do so, we need to evaluate the boolean expression in the guard by +# using the `eval_bool` function we defined above. +function cond end +if_rules = @theory guard t f σ begin + (cond(guard, t), σ::Mem) --> (cond(guard, t, :skip), σ) + (cond(guard, t, f), σ::Mem) => (eval_bool(guard, σ) ? :($t, $σ) : :($f, $σ)) +end + +eval_if(ex, mem::Mem) = strategy(read_mem ∪ arithm_rules ∪ if_rules)(:($ex, $mem)) + +# And here is our working conditional + +@testset "If Semantics" begin + @test 2 == eval_if(:(cond(true, x, 0)), Mem(:x => 2)) + @test 0 == eval_if(:(cond(false, x, 0)), Mem(:x => 2)) + @test 2 == eval_if(:(cond(!(false), x, 0)), Mem(:x => 2)) + @test 0 == eval_if(:(cond(!(2 < x), x, 0)), Mem(:x => 3)) +end + + +# ## Writing memory + +# Our language then needs a mechanism to write in memory. +# We define the behavior of the `store` construct, which +# behaves like the `=` assignment operator in other programming languages. +# `store(a, 5)` will store the value 5 in the `a` variable inside the program's memory. + +function store end +write_mem = @theory sym val σ begin + (store(sym::Symbol, val), σ) => (σ[sym] = eval_if(val, σ); + σ) +end + +# ## While loops and sequential computation. + +function seq end +function loop end +while_rules = @theory guard a b σ begin + (:skip, σ::Mem) --> σ + ((:skip; b), σ::Mem) --> (b, σ) + (seq(a, b), σ::Mem) --> (b, merge((a, σ), σ)) + merge(a::Mem, σ::Mem) => merge(σ, a) + merge(a::Union{Bool,Int}, σ::Mem) --> σ + (loop(guard, a), σ::Mem) --> (cond(guard, seq(a, loop(guard, a)), :skip), σ) +end + + +# ## Completing the language. + +while_language = write_mem ∪ read_mem ∪ arithm_rules ∪ if_rules ∪ while_rules; + +using Metatheory.Syntax: rmlines +eval_while(ex, mem) = strategy(while_language)(:($(rmlines(ex)), $mem)) + +# Final steps + +@testset "While Semantics" begin + @test Mem(:x => 3) == eval_while(:((store(x, 3))), Mem(:x => 2)) + @test Mem(:x => 5) == eval_while(:(seq(store(x, 4), store(x, x + 1))), Mem(:x => 3)) + @test Mem(:x => 4) == eval_while(:(cond(x < 10, store(x, x + 1))), Mem(:x => 3)) + @test 10 == eval_while(:(seq(loop(x < 10, store(x, x + 1)), x)), Mem(:x => 3)) + @test 50 == eval_while(:(seq(loop(x < y, seq(store(x, x + 1), store(y, y - 1))), x)), Mem(:x => 0, :y => 100)) +end