Skip to content

Commit

Permalink
Allow mathematical expression to be passed for random_seed
Browse files Browse the repository at this point in the history
  • Loading branch information
hmenke committed Nov 8, 2023
1 parent 28d41b3 commit 7bb6272
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/solid_dmft/dmft_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def _dmft_step(sum_k, solvers, it, general_params,
mpi.report('\nSolving the impurity problem for shell {} ...'.format(icrsh))
mpi.barrier()
start_time = timer()
solvers[icrsh].solve()
solvers[icrsh].solve(it=it)
mpi.barrier()
mpi.report('Actual time for solver: {:.2f} s'.format(timer() - start_time))

Expand Down
56 changes: 56 additions & 0 deletions python/solid_dmft/dmft_tools/matheval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# https://stackoverflow.com/a/30516254

import ast
import math


class MathExpr(object):
allowed_nodes = (
ast.Module,
ast.Expr,
ast.Load,
ast.Expression,
ast.Add,
ast.Sub,
ast.UnaryOp,
ast.Num,
ast.BinOp,
ast.Mult,
ast.Div,
ast.Pow,
ast.BitOr,
ast.BitAnd,
ast.BitXor,
ast.USub,
ast.UAdd,
ast.FloorDiv,
ast.Mod,
ast.LShift,
ast.RShift,
ast.Invert,
ast.Call,
ast.Name,
)

functions = {
"abs": abs,
"complex": complex,
"min": min,
"max": max,
"pow": pow,
"round": round,
} | {key: value for (key, value) in vars(math).items() if not key.startswith("_")}

def __init__(self, expr):
if any(elem in expr for elem in "\n#"):
raise ValueError(expr)

node = ast.parse(expr.strip(), mode="eval")
for curr in ast.walk(node):
if not isinstance(curr, self.allowed_nodes):
raise ValueError(curr)

self.code = compile(node, "<string>", "eval")

def __call__(self, **kwargs):
return eval(self.code, {"__builtins__": None}, self.functions | kwargs)
22 changes: 16 additions & 6 deletions python/solid_dmft/dmft_tools/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from h5 import HDFArchive

from . import legendre_filter
from .matheval import MathExpr

def get_n_orbitals(sum_k):
"""
Expand Down Expand Up @@ -108,7 +109,7 @@ class SolverStructure:
Methods
-------
solve(self)
solve(self, **kwargs)
solve impurity problem
'''

Expand Down Expand Up @@ -140,6 +141,10 @@ def __init__(self, general_params, solver_params, advanced_params, sum_k, icrsh,
self.h_int = h_int
self.iteration_offset = iteration_offset
self.solver_struct_ftps = solver_struct_ftps
if solver_params.get("random_seed") is None:
self.random_seed_generator = None
else:
self.random_seed_generator = MathExpr(solver_params["random_seed"])

# initialize solver object, options are cthyb
if self.general_params['solver_type'] == 'cthyb':
Expand Down Expand Up @@ -336,11 +341,16 @@ def _init_ReFreq_hartree(self):
# solver-specific solve() command
# ********************************************************************

def solve(self):
def solve(self, **kwargs):
r'''
solve impurity problem with current solver
'''

if self.random_seed_generator is None:
random_seed = {}
else:
random_seed = { "random_seed": int(self.random_seed_generator(it=kwargs["it"], rank=mpi.rank)) }

if self.general_params['solver_type'] == 'cthyb':

if self.general_params['cthyb_delta_interface']:
Expand Down Expand Up @@ -387,7 +397,7 @@ def solve(self):

# Solve the impurity problem for icrsh shell
# *************************************
self.triqs_solver.solve(h_int=self.h_int, **self.solver_params)
self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed ))
# *************************************

# call postprocessing
Expand All @@ -403,7 +413,7 @@ def solve(self):

# Solve the impurity problem for icrsh shell
# *************************************
self.triqs_solver.solve(h_int=self.h_int, **self.solver_params)
self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed ))
# *************************************

# call postprocessing
Expand Down Expand Up @@ -577,7 +587,7 @@ def make_positive_definite(G):

# Solve the impurity problem for icrsh shell
# *************************************
self.triqs_solver.solve(h_int=self.h_int, **self.solver_params)
self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed ))
# *************************************

# call postprocessing
Expand All @@ -594,7 +604,7 @@ def make_positive_definite(G):

# Solve the impurity problem for icrsh shell
# *************************************
self.triqs_solver.solve(h_int=self.h_int, **self.solver_params)
self.triqs_solver.solve(h_int=self.h_int, **(self.solver_params | random_seed ))
# *************************************

# call postprocessing
Expand Down
4 changes: 2 additions & 2 deletions python/solid_dmft/read_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
start matsubara frequency to start with
fit_max_w : float, optional
highest matsubara frequency to fit
random_seed : int, optional default by triqs
random_seed : str, optional default by triqs
if specified the int will be used for random seeds! Careful, this will give the same random
numbers on all mpi ranks
legendre_fit : bool, optional default= False
Expand Down Expand Up @@ -733,7 +733,7 @@
'move_shift': {'converter': BOOL_PARSER, 'default': False,
'used': lambda params: params['general']['solver_type'] in ['cthyb']},

'random_seed': {'converter': int, 'default': None,
'random_seed': {'converter': str, 'default': None,
'used': lambda params: params['general']['solver_type'] in ['cthyb', 'ctint', 'ctseg']},

'perform_tail_fit': {'converter': BOOL_PARSER,
Expand Down
1 change: 1 addition & 0 deletions test/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ endforeach()
# all other tests
set(all_tests
test_convergence
test_matheval
test_plot_correlated_bands
test_respack_sfo
)
Expand Down
39 changes: 39 additions & 0 deletions test/python/test_matheval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2018-2022 Simons Foundation
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You may obtain a copy of the License at
# https:#www.gnu.org/licenses/gpl-3.0.txt
#
# Authors: Alexander Hampel

from solid_dmft.dmft_tools.matheval import MathExpr
import unittest


class test_mathexpr(unittest.TestCase):
def test_simple(self):
expr = MathExpr("1 + 1")
result = expr()
self.assertEqual(result, 2)

def test_variables(self):
expr = MathExpr("34788 * it + 928374 * rank")
result = expr(it=5, rank=9)
self.assertEqual(result, 34788 * 5 + 928374 * 9)

def test_breakout(self):
with self.assertRaises(ValueError):
expr = MathExpr("(1).__class__")


if __name__ == "__main__":
unittest.main()

0 comments on commit 7bb6272

Please sign in to comment.