[WIP] refactor: parser using Symbolics #877

Draft
wants to merge 2 commits into master
Conversation

sathvikbhagavan
Member

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

Add any other context about the problem here.

@ChrisRackauckas
Member

@sathvikbhagavan @avik-pal @xtalax how do we finally finish this?

@sathvikbhagavan
Member Author

Hey, apologies for the late response (still getting settled into grad school).

I think we need to rethink how we are doing the parsing. @xtalax and I started out using stateless apply for the parameters, but I feel we don't need it. The goal is to parse the equations so that each dependent variable is replaced by the forward pass of its neural network, with the code generated by Symbolics.jl. With stateless apply, we have to pass not just the parameters but also their types, plus a bunch of other things we don't actually need.

So, to give an MWE, I was thinking along these lines:

using ModelingToolkit, Symbolics
using Symbolics: unwrap
using Lux
using Random

rng = Random.default_rng()
Random.seed!(rng, 0)

@parameters x, y, z
@variables u(..), v(..), h(..), p(..)
Dz = Differential(z)
expr1 = u(x, y, z) + v(x, y) - h(x) - x - y - z ~ 0

chain_u = Lux.Chain(Lux.Dense(3, 12, Lux.tanh), Lux.Dense(12, 12, Lux.tanh),
                    Lux.Dense(12, 1))
ps, st = Lux.setup(rng, chain_u)
dvs = unwrap.([u(x, y, z), v(x, y), h(x)])
phi_symbols = Symbol.("phi_" .* string.(operation.(dvs)))
theta_symbols = Symbol.("theta_" .* string.(operation.(dvs)))

dvs_phi_map = Dict(map(
    i -> operation(dvs[i]) => let phi = phi_symbols[i]
        first(@variables $phi(..))
        end, 1:length(dvs)
))

dvs_theta_map = Dict(map(
    i -> operation(dvs[i]) => let theta = theta_symbols[i]
        first(@variables $theta)
        end, 1:length(dvs)
))
rule = [d => dvs_phi_map[operation(d)](arguments(d), dvs_theta_map[operation(d)]) for d in dvs]

expr2 = substitute(expr1.lhs, rule)
# This gives: -x - y - z - phi_h(Any[x], theta_h) + phi_u(Any[x, y, z], theta_u) + phi_v(Any[x, y], theta_v)

## Goal: use MTK to generate a function like the following, where the 1st argument is the independent variables, the 2nd is the neural networks for the dependent variables, and the 3rd is the parameters of those networks:
x = args_1[1]
y = args_1[2]
z = args_1[3]
phi_u = args_2[1]
phi_v = args_2[2]
phi_h = args_2[3]
theta_u = args_3[1]
theta_v = args_3[2]
theta_h = args_3[3]
phi_u([x, y, z], theta_u) + phi_v([x, y], theta_v) - phi_h([x], theta_h) - x - y - z
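
Concretely, each phi_* would just be the forward pass of the corresponding Lux chain closed over its state. A minimal sketch (names are illustrative, reusing chain_u, ps, st from the setup above; not the final API):

# Sketch: phi_u wraps the Lux forward pass; chain_u(x, ps, st) returns (output, state),
# and the output is a length-1 vector, so take its single element as the scalar value.
phi_u = (coords, theta) -> only(first(chain_u(coords, theta, st)))
phi_u([0.1, 0.2, 0.3], ps)  # example call matching the pattern in the generated code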

I feel this approach is simple, since in the end all we have to do is replace the dependent-variable terms with the forward passes of the neural networks to generate the loss functions. Let me know what you think.
