diff --git a/python/solid_dmft/dmft_cycle.py b/python/solid_dmft/dmft_cycle.py index 7a8ccdc7..f8c11ad0 100755 --- a/python/solid_dmft/dmft_cycle.py +++ b/python/solid_dmft/dmft_cycle.py @@ -660,7 +660,7 @@ def _dmft_step(sum_k, solvers, it, general_params, mpi.report('\nSolving the impurity problem for shell {} ...'.format(icrsh)) mpi.barrier() start_time = timer() - solvers[icrsh].solve() + solvers[icrsh].solve(it=it) mpi.barrier() mpi.report('Actual time for solver: {:.2f} s'.format(timer() - start_time)) diff --git a/python/solid_dmft/dmft_tools/matheval.py b/python/solid_dmft/dmft_tools/matheval.py new file mode 100644 index 00000000..47ea6267 --- /dev/null +++ b/python/solid_dmft/dmft_tools/matheval.py @@ -0,0 +1,56 @@ +# https://stackoverflow.com/a/30516254 + +import ast +import math + + +class MathExpr(object): + allowed_nodes = ( + ast.Module, + ast.Expr, + ast.Load, + ast.Expression, + ast.Add, + ast.Sub, + ast.UnaryOp, + ast.Num, + ast.BinOp, + ast.Mult, + ast.Div, + ast.Pow, + ast.BitOr, + ast.BitAnd, + ast.BitXor, + ast.USub, + ast.UAdd, + ast.FloorDiv, + ast.Mod, + ast.LShift, + ast.RShift, + ast.Invert, + ast.Call, + ast.Name, + ) + + functions = { + "abs": abs, + "complex": complex, + "min": min, + "max": max, + "pow": pow, + "round": round, + } | {key: value for (key, value) in vars(math).items() if not key.startswith("_")} + + def __init__(self, expr): + if any(elem in expr for elem in "\n#"): + raise ValueError(expr) + + node = ast.parse(expr.strip(), mode="eval") + for curr in ast.walk(node): + if not isinstance(curr, self.allowed_nodes): + raise ValueError(curr) + + self.code = compile(node, "", "eval") + + def __call__(self, **kwargs): + return eval(self.code, {"__builtins__": None}, self.functions | kwargs) diff --git a/python/solid_dmft/dmft_tools/solver.py b/python/solid_dmft/dmft_tools/solver.py index 2cf272d2..951e031e 100755 --- a/python/solid_dmft/dmft_tools/solver.py +++ b/python/solid_dmft/dmft_tools/solver.py @@ -32,6 +32,7 @@ from h5 import HDFArchive from . import legendre_filter +from .matheval import MathExpr def get_n_orbitals(sum_k): """ @@ -108,7 +109,7 @@ class SolverStructure: Methods ------- - solve(self) + solve(self, **kwargs) solve impurity problem ''' @@ -140,6 +141,10 @@ def __init__(self, general_params, solver_params, advanced_params, sum_k, icrsh, self.h_int = h_int self.iteration_offset = iteration_offset self.solver_struct_ftps = solver_struct_ftps + if solver_params.get("random_seed") is None: + self.random_seed_generator = None + else: + self.random_seed_generator = MathExpr(solver_params["random_seed"]) # initialize solver object, options are cthyb if self.general_params['solver_type'] == 'cthyb': @@ -336,11 +341,16 @@ def _init_ReFreq_hartree(self): # solver-specific solve() command # ******************************************************************** - def solve(self): + def solve(self, **kwargs): r''' solve impurity problem with current solver ''' + if self.random_seed_generator is None: + random_seed = {} + else: + random_seed = { "random_seed": int(self.random_seed_generator(it=kwargs["it"], rank=mpi.rank)) } + if self.general_params['solver_type'] == 'cthyb': if self.general_params['cthyb_delta_interface']: @@ -387,7 +397,7 @@ def solve(self): # Solve the impurity problem for icrsh shell # ************************************* - self.triqs_solver.solve(h_int=self.h_int, **self.solver_params) + self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed )) # ************************************* # call postprocessing @@ -403,7 +413,7 @@ def solve(self): # Solve the impurity problem for icrsh shell # ************************************* - self.triqs_solver.solve(h_int=self.h_int, **self.solver_params) + self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed )) # ************************************* # call postprocessing @@ -577,7 +587,7 @@ def make_positive_definite(G): # Solve the impurity problem for icrsh shell # ************************************* - self.triqs_solver.solve(h_int=self.h_int, **self.solver_params) + self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed )) # ************************************* # call postprocessing @@ -594,7 +604,7 @@ def make_positive_definite(G): # Solve the impurity problem for icrsh shell # ************************************* - self.triqs_solver.solve(h_int=self.h_int, **self.solver_params) + self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed )) # ************************************* # call postprocessing diff --git a/python/solid_dmft/read_config.py b/python/solid_dmft/read_config.py index e29a8937..11b44068 100755 --- a/python/solid_dmft/read_config.py +++ b/python/solid_dmft/read_config.py @@ -281,7 +281,7 @@ start matsubara frequency to start with fit_max_w : float, optional highest matsubara frequency to fit -random_seed : int, optional default by triqs +random_seed : str, optional default by triqs if specified the int will be used for random seeds! Careful, this will give the same random numbers on all mpi ranks legendre_fit : bool, optional default= False @@ -733,7 +733,7 @@ 'move_shift': {'converter': BOOL_PARSER, 'default': False, 'used': lambda params: params['general']['solver_type'] in ['cthyb']}, - 'random_seed': {'converter': int, 'default': None, + 'random_seed': {'converter': str, 'default': None, 'used': lambda params: params['general']['solver_type'] in ['cthyb', 'ctint', 'ctseg']}, 'perform_tail_fit': {'converter': BOOL_PARSER, diff --git a/test/python/CMakeLists.txt b/test/python/CMakeLists.txt index a6804be7..6b06542f 100644 --- a/test/python/CMakeLists.txt +++ b/test/python/CMakeLists.txt @@ -25,6 +25,7 @@ endforeach() # all other tests set(all_tests test_convergence + test_matheval test_plot_correlated_bands test_respack_sfo ) diff --git a/test/python/test_matheval.py b/test/python/test_matheval.py new file mode 100644 index 00000000..332391cd --- /dev/null +++ b/test/python/test_matheval.py @@ -0,0 +1,39 @@ +# Copyright (c) 2018-2022 Simons Foundation +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You may obtain a copy of the License at +# https:#www.gnu.org/licenses/gpl-3.0.txt +# +# Authors: Alexander Hampel + +from solid_dmft.dmft_tools.matheval import MathExpr +import unittest + + +class test_mathexpr(unittest.TestCase): + def test_simple(self): + expr = MathExpr("1 + 1") + result = expr() + self.assertEqual(result, 2) + + def test_variables(self): + expr = MathExpr("34788 * it + 928374 * rank") + result = expr(it=5, rank=9) + self.assertEqual(result, 34788 * 5 + 928374 * 9) + + def test_breakout(self): + with self.assertRaises(ValueError): + expr = MathExpr("(1).__class__") + + +if __name__ == "__main__": + unittest.main()