best way to write a wrapper function for tree_at #784

Open
sede-fa opened this issue Jul 15, 2024 · 1 comment

sede-fa commented Jul 15, 2024

Hi,

I'm trying to rewrite my process flowsheet simulator in Equinox. I can solve individual components fine by defining my classes as follows:

class Unit1(eqx.Module):
    F0: interfaces.Flange
    F1: interfaces.Flange
    # [other params]

    def __init__(self, F0):  # plus [other params]
        self.F0 = F0
        self.F1 = F0  # just for initialising

    def forward(self):
        # residuals of the balance equations
        Eq = jnp.array([
            self.F0.n - self.F1.n,
            # other balance equations here
        ])
        return Eq

    @eqx.filter_jit
    def f_sol(self, x):
        self = eqx.tree_at(lambda m: m, self, self.update(x))
        return self.forward()

    def update(self, x):
        # tree_at needs the pytree to modify as its second argument
        self = eqx.tree_at(lambda m: m.F1.n, self, x[0])
        # update remaining fields of Flange F1
        return self

    def run_sol(self):
        # run solver here with initial guesses etc.
        self = eqx.tree_at(lambda m: m, self, self.update(solver.x))
        return self

However, when I connect components and update them sequentially, the only way I've found is repeated use of eqx.tree_at:

class SystemModel(eqx.Module):
    unit1: Unit1
    unit2: Unit2

    def __init__(self, [params]):
        F0 = interfaces.Flange(n=1, ...)
        self.unit1 = Unit1(F0, ...)
        self.unit2 = Unit2(self.unit1.F1, ...)

    def forward(self, F0):
        self = eqx.tree_at(lambda m: m.unit1, self, F0)   # make connection to source
        self = eqx.tree_at(lambda m: m.unit1, self, self.unit1.run_sol())  # run solver to update outlet flange in unit1
        self = eqx.tree_at(lambda m: m.unit2, self, self.unit1.F0)  # sequentially update inlet of unit2
        self = eqx.tree_at(lambda m: m.unit2, self, self.unit2.run_sol())  # run solver to update unit2

        return self

and this can be executed as follows:

sysModel = SystemModel(...)
F0 = interfaces.Flange(n=1, ...)

sysModel = sysModel.forward(F0)

This all executes fine. However, I'm wondering whether it would be possible to do this through a wrapper method like connect(Flange1, Flange2) that handles the connections and would also let me add extra features, such as the Mermaid graph generation I have in the NumPy version of my framework.

I know this is out of scope for Equinox and a long shot, but any help would be much appreciated :)

@patrick-kidger
Owner

Is there any reason you need to modify them after-the-fact? Can this all just happen at initialization?

At any rate, the construction you have here looks a little odd. Assuming everything is functionally pure (I sure hope so!), code like

self = eqx.tree_at(lambda m: m.unit1, self, F0)
self = eqx.tree_at(lambda m: m.unit1, self, self.unit1.run_sol())

should be equivalent to

self = eqx.tree_at(lambda m: m.unit1, self, F0.run_sol())
