How to save the trained model #21

Open
luokuang2001 opened this issue Sep 18, 2024 · 1 comment

@luokuang2001

I have trained a symbolic regression model; it has the following form:
[image]

How can I save the trained model so that next time I can use it directly without retraining?
Something like "torch.save" in PyTorch, which saves a "pth" file.

@foolnotion
Member

Hi, you have some options:

  1. Save the string and parse it back into a tree later
  2. Save the tree model directly using pickle

Here is some example code illustrating both options. pyoperon provides Python bindings to the Operon library, which make this possible.

import pyoperon as op # the operon bindings
import os
import re
import pickle

Load some data and get the actual variables (described by hashes):

dataset = op.Dataset('./data/Poly-10.csv', True)

variable_hashes = [v.Hash for v in dataset.Variables]
variable_hashes

Now let's say you have an infix expression in string form. First, we extract the variable names from the expression (these will always be named $X_i$, since pyoperon doesn't need a header). Then we map every $X_i$ to the corresponding variable in the dataset:

expr = '((-0.000030386898288270458579) + ((-0.400405615568161010742188) * (((2.294569730758666992187500 * X5) * ((-1.088928818702697753906250) * X6)) - (((((-0.937077045440673828125000) * X7) * (((-2.449583292007446289062500) * X9) * (0.209568977355957031250000 * X1))) + (((((-0.039391547441482543945312) * X4) * ((-0.039391547441482543945312) * X4)) * (((2.302585124969482421875000 * X2) * (1.030569076538085937500000 * X1)) - ((((-0.039391547441482543945312) * X4) * (1.036243557929992675781250 * X1)) + ((2.718281745910644531250000 * X3) * (1.777204394340515136718750 * X4))))) + ((((-0.929394364356994628906250) * X7) * (((-2.058582305908203125000000) * X9) * (1.052866220474243164062500 * X1))) + ((2.407963514328002929687500 * X2) * (1.036243557929992675781250 * X1))))) + ((1.414213538169860839843750 * X3) * ((1.770285606384277343750000 * X4) - (((-1.811097383499145507812500) * X10) * (0.974092006683349609375000 * X6))))))))'

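# collect every variable reference (X1, X2, ...) that appears in the expression
m = re.findall(r'X\d+', expr)

variables = {}

# map each distinct name to the hash of the matching dataset variable (X_i is the i-th one, 1-based)
for v in m:
    if v in variables:
        continue
    i = int(v[1:])
    variables[v] = variable_hashes[i-1]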

variables
{'X5': 16075665569742270374,
 'X6': 9134146818458426180,
 'X7': 18044635619207560834,
 'X9': 2652961248133790663,
 'X1': 4295753595843180382,
 'X4': 17733306235974623085,
 'X2': 18188060951833565637,
 'X3': 4397419642548150523,
 'X10': 17424446509373167524}

Now that we have the correct variable mapping, we can parse the expression:

tree = op.InfixParser.Parse(expr, variables)

# print it out again
decimal_precision = 3 # how many decimals to use when formatting floating point values
print(op.InfixFormatter.Format(tree, dataset, decimal_precision))
((-0.000) + ((-0.400) * (((2.295 * (1.000 * X5)) * ((-1.089) * (1.000 * X6))) - (((((-0.937) * (1.000 * X7)) * (((-2.450) * (1.000 * X9)) * (0.210 * (1.000 * X1)))) + (((((-0.039) * (1.000 * X4)) * ((-0.039) * (1.000 * X4))) * (((2.303 * (1.000 * X2)) * (1.031 * (1.000 * X1))) - ((((-0.039) * (1.000 * X4)) * (1.036 * (1.000 * X1))) + ((2.718 * (1.000 * X3)) * (1.777 * (1.000 * X4)))))) + ((((-0.929) * (1.000 * X7)) * (((-2.059) * (1.000 * X9)) * (1.053 * (1.000 * X1)))) + ((2.408 * (1.000 * X2)) * (1.036 * (1.000 * X1)))))) + ((1.414 * (1.000 * X3)) * ((1.770 * (1.000 * X4)) - (((-1.811) * (1.000 * X10)) * (0.974 * (1.000 * X6)))))))))
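
For option 1, only the expression string has to reach disk. Here is a minimal sketch, assuming hypothetical helper names `save_expression` and `load_expression` and a hypothetical file name `model.txt` (none of these are part of pyoperon). Save the full-precision string rather than the 3-decimal output printed above, since the rounded coefficients no longer describe the same model.

```python
import os
import tempfile
from pathlib import Path

def save_expression(expr: str, path: str) -> None:
    """Write the infix expression string to a UTF-8 text file."""
    Path(path).write_text(expr, encoding="utf-8")

def load_expression(path: str) -> str:
    """Read the infix expression string back, dropping surrounding whitespace."""
    return Path(path).read_text(encoding="utf-8").strip()

# short stand-in expression; in practice pass the full-precision model string
demo_expr = "((1.5 * X1) + ((-0.25) * X2))"
demo_path = os.path.join(tempfile.gettempdir(), "model.txt")  # hypothetical location

save_expression(demo_expr, demo_path)
restored = load_expression(demo_path)
print(restored == demo_expr)
```

The restored string can then be parsed with `op.InfixParser.Parse` and the variable mapping, exactly as in the code above.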

We can evaluate the parsed tree using op.Evaluate:

values = op.Evaluate(tree, dataset, op.Range(0, 10))
values
array([ 0.4543666 ,  0.27158856, -0.11406795, -0.4064015 , -0.10081271,
        0.17754017, -1.0105664 ,  0.4164615 ,  0.44278234,  0.0433833 ],
      dtype=float32)

The tree can be pickled:

path = os.path.join('pickled', 'tree.pkl')
os.makedirs(os.path.dirname(path), exist_ok=True) # make sure the target directory exists

with open(path, 'wb') as f:
    pickle.dump(tree, f)

Then it can be loaded again:

with open(path, 'rb') as f:
    tree_unpickled = pickle.load(f)
    print(op.InfixFormatter.Format(tree_unpickled, dataset, decimal_precision))
((-0.000) + ((-0.400) * (((2.295 * (1.000 * X5)) * ((-1.089) * (1.000 * X6))) - (((((-0.937) * (1.000 * X7)) * (((-2.450) * (1.000 * X9)) * (0.210 * (1.000 * X1)))) + (((((-0.039) * (1.000 * X4)) * ((-0.039) * (1.000 * X4))) * (((2.303 * (1.000 * X2)) * (1.031 * (1.000 * X1))) - ((((-0.039) * (1.000 * X4)) * (1.036 * (1.000 * X1))) + ((2.718 * (1.000 * X3)) * (1.777 * (1.000 * X4)))))) + ((((-0.929) * (1.000 * X7)) * (((-2.059) * (1.000 * X9)) * (1.053 * (1.000 * X1)))) + ((2.408 * (1.000 * X2)) * (1.036 * (1.000 * X1)))))) + ((1.414 * (1.000 * X3)) * ((1.770 * (1.000 * X4)) - (((-1.811) * (1.000 * X10)) * (0.974 * (1.000 * X6)))))))))

Check that it returns the same values:

values = op.Evaluate(tree_unpickled, dataset, op.Range(0, 10))
values
array([ 0.4543666 ,  0.27158856, -0.11406795, -0.4064015 , -0.10081271,
        0.17754017, -1.0105664 ,  0.4164615 ,  0.44278234,  0.0433833 ],
      dtype=float32)
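
Comparing the two outputs by eye gets tedious for longer ranges. Here is a minimal sketch of a programmatic check, assuming a hypothetical helper `same_values` and tolerances chosen arbitrarily for float32 output; the lists are stand-ins copied from the first values printed above.

```python
import math

def same_values(a, b, rel_tol=1e-6, abs_tol=1e-9):
    """Return True if both sequences have equal length and element-wise close values."""
    a, b = list(a), list(b)
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b)
    )

# stand-in numbers; in practice pass the evaluation of the original tree
# and the evaluation of the unpickled tree
before = [0.4543666, 0.27158856, -0.11406795]
after = [0.4543666, 0.27158856, -0.11406795]
print(same_values(before, after))
```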

If you don't want to use the bindings, you can simply parse and evaluate the expression string using SymPy:

from sympy import lambdify, parse_expr, Symbol
import pandas as pd

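def evaluate_expression(sexpr, data):
    # one symbol per input column; the last column is skipped (assumed here to hold the target)
    symbols = [Symbol(x) for x in data.columns.values[:-1]]
    # compile the parsed expression into a function and call it with one array per input column
    return lambdify(symbols, parse_expr(sexpr))(*data.values[:,:-1].T)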

# read the data as a pandas dataframe
df = pd.read_csv('./data/Poly-10.csv')

# SymPy follows Python syntax, where exponentiation is written ** rather than ^
sexpr = expr.replace('^', '**')
values = evaluate_expression(sexpr, df.iloc[0:10])
values
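
The same SymPy route also works without a CSV file. Here is a minimal sketch, assuming a hypothetical helper `make_predictor` and a short stand-in expression (not the model from this thread), that builds a reusable callable once and then evaluates it on scalar inputs:

```python
from sympy import Symbol, lambdify, parse_expr

def make_predictor(expr, names):
    """Turn an infix string into a callable taking one argument per variable name."""
    sexpr = expr.replace('^', '**')  # Python/SymPy write exponentiation as **
    symbols = [Symbol(n) for n in names]
    return lambdify(symbols, parse_expr(sexpr))

# stand-in expression with two inputs, X1 and X2
predict = make_predictor('((X1 ^ 2) + (0.5 * X2))', ['X1', 'X2'])
result = predict(3.0, 2.0)  # computes 3.0**2 + 0.5*2.0
print(result)
```

Keeping the argument order explicit in `names` avoids depending on the column order of a particular data file.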

Hope this helps, feel free to ask for details.
