Skip to content

Commit

Permalink
Fix issues with mf_leg_to_new and add tests (#1640)
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 authored Apr 16, 2024
1 parent bab6289 commit 0dbe7af
Showing 1 changed file with 123 additions and 15 deletions.
138 changes: 123 additions & 15 deletions tools/mf_leg_to_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
J. Wilkins 8-9-2023
"""

import sys
import doctest
import re
import sys
import unittest
from typing import Iterator


ARG_TYPE = ("fun", "pin", "free", "bind")
OLD_OPT_TO_NEW = {'fit': 'fit_control_parameters',
'list': 'listing',
Expand Down Expand Up @@ -66,7 +67,18 @@ def get_args(string: str) -> Iterator[str]:


def convert_legacy_multifit(string: str) -> str:
""" Convert a string from multifit_legacy style to modern mfclass syntax """
""" Convert a MATLAB multifit line from multifit_legacy style to modern mfclass syntax
>>> print(convert_legacy_multifit("[wfit, fitdata] = multifit_sqw(my_new_cut, @sr122_xsec, pars, pfree, pbind, 'list', 1)"))
kk = multifit(my_new_cut);
kk = kk.set_fun(@sr122_xsec);
kk = kk.set_pin(pars);
kk = kk.set_free(pfree);
kk = kk.set_bind(pbind);
kk = kk.set_options('listing', 1);
[wfit, fitdata] = kk.fit();
<BLANKLINE>
"""

# In case of continuation
string = " ".join(string.split("...\n"))
Expand All @@ -79,20 +91,10 @@ def convert_legacy_multifit(string: str) -> str:

args = string[string.index('(')+1:string.rindex(')')]

strargs = get_args(args)
args = list(get_args(args))

typ = re.search("fit((?:_sqw){0,2})", string)

# Temporary substitution with format-spec to avoid commas
strlist = []
for i, strarg in enumerate(strargs):
args = args.replace(strarg, f"{{strlist[{i}]}}", 1)
strlist.append(strarg)

# Restore args to their correct places and remove whitespace
args = map(lambda x: x.strip(), args.split(","))
args = [arg.format(strlist=strlist) for arg in args]

# Find where key points are
fhandles_loc = [i for i, arg in enumerate(args) if arg.startswith("@")]
kwargs_loc = [i for i, arg in enumerate(args) if arg.startswith("'") or arg.startswith('"')]
Expand Down Expand Up @@ -143,8 +145,114 @@ def convert_legacy_multifit(string: str) -> str:
return outstr


class TestConversion(unittest.TestCase):

def test_a(self):
ans = convert_legacy_multifit("[wfit, fitdata] = multifit_sqw"
"(my_new_cut, @sr122_xsec, pars,"
"pfree, pbind, 'list', 1)")

self.assertEqual(ans, """
kk = multifit(my_new_cut);
kk = kk.set_fun(@sr122_xsec);
kk = kk.set_pin(pars);
kk = kk.set_free(pfree);
kk = kk.set_bind(pbind);
kk = kk.set_options('listing', 1);
[wfit, fitdata] = kk.fit();
""".lstrip())

def test_b(self):
ans = convert_legacy_multifit("[wfit,fitdata]=multifit_sqw"
"(my_new_cut,@sr122_xsec,pars,"
"pfree,@bfunc,bpars,bfree,bpbind, 'list',1)")
self.assertEqual(ans, """
kk = multifit(my_new_cut);
kk = kk.set_fun(@sr122_xsec);
kk = kk.set_pin(pars);
kk = kk.set_free(pfree);
kk = kk.set_bfun(@bfunc);
kk = kk.set_bpin(bpars);
kk = kk.set_bfree(bfree);
kk = kk.set_bbind(bpbind);
kk = kk.set_options('listing', 1);
[wfit,fitdata] = kk.fit();
""".lstrip())

def test_c(self):
ans = convert_legacy_multifit("multifit(mynewcut, @x, p, pf,"
"pb, @bf, bp, [true false false],"
"{1, 2, 3}, 'list', n,"
"'fit', [1 2 3], 'global_foreground',"
"'local_background', 'evaluate' 'chisqr',"
"'foreground', 'list', n, 'keep',"
"[1 2 3 4], 'mask', [1 1 1 0 0 0],"
"'ranges', 'select', false, 'average')")

self.assertEqual(ans, """
kk = multifit(mynewcut);
kk = kk.set_fun(@x);
kk = kk.set_pin(p);
kk = kk.set_free(pf);
kk = kk.set_bind(pb);
kk = kk.set_bfun(@bf);
kk = kk.set_bpin(bp);
kk = kk.set_bfree([true false false]);
kk = kk.set_bbind({1, 2, 3});
kk = kk.set_options('listing', n);
kk = kk.set_options('fit_control_parameters', [1 2 3]);
kk.global_foreground = true;
kk.local_background = true;
kk = kk.set_options('listing', n);
kk = kk.set_mask(~[1 2 3 4]);
kk = kk.set_mask([1 1 1 0 0 0]);
warning('HORACE:impossible_auto_conversion', 'Cannot convert keyword ranges')
kk = kk.set_options('selected', false);
[wfit, fit_data] = kk.fit();
""".lstrip())

def test_d(self):
ans = convert_legacy_multifit("multifit(x, y, e, @x, p, pf, pb, @bf, bp,"
"[true false false], {1, 2, 3}, "
"'fit', [1 2 3],"
"'global_foreground', 'local_background',"
" 'evaluate' 'chisqr', 'foreground', "
"'list', n, 'keep', [1 2 3 4], "
"'mask', [1 1 1 0 0 0], 'ranges', "
"'select', false, 'average')")

self.assertEqual(ans, """
kk = multifit(x, y, e);
kk = kk.set_fun(@x);
kk = kk.set_pin(p);
kk = kk.set_free(pf);
kk = kk.set_bind(pb);
kk = kk.set_bfun(@bf);
kk = kk.set_bpin(bp);
kk = kk.set_bfree([true false false]);
kk = kk.set_bbind({1, 2, 3});
kk = kk.set_options('fit_control_parameters', [1 2 3]);
kk.global_foreground = true;
kk.local_background = true;
kk = kk.set_options('listing', n);
kk = kk.set_mask(~[1 2 3 4]);
kk = kk.set_mask([1 1 1 0 0 0]);
warning('HORACE:impossible_auto_conversion', 'Cannot convert keyword ranges')
kk = kk.set_options('selected', false);
[wfit, fit_data] = kk.fit();
""".lstrip())


def load_tests(loader, tests, ignore):
tests.addTests(doctest.DocTestSuite())
return tests


if __name__ == '__main__':
if len(sys.argv) > 1:
if "test" == sys.argv[1]:
del sys.argv[1]
unittest.main()
elif len(sys.argv) > 1:
print(convert_legacy_multifit(" ".join(sys.argv[1:])))
else:
print(USAGE)

0 comments on commit 0dbe7af

Please sign in to comment.