Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multigrid #17

Open
wants to merge 6 commits into
base: struct
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lyncs_quda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .dirac import *
from .solver import *
from .evenodd import *
from .multigrid import *
40 changes: 39 additions & 1 deletion lyncs_quda/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from .clover_field import CloverField
from .spinor_field import spinor
from .lib import lib
from .structs import QudaGaugeParam
from .enums import (
QudaDiracType,
QudaMatPCType,
QudaDagType,
QudaParity,
QudaDslashType,
)


Expand Down Expand Up @@ -66,6 +68,18 @@ def type(self):
return "CLOVER" + PC
return "TWISTED_CLOVER" + PC

@property
@QudaDslashType
def dslash_type(self):
if "coarse" in self.type:
return "INVALID"
dslash_type = str(self.type).replace("pc","")
dslash_type = dslash_type.replace("gauge_","")
if "clover" == dslash_type: dslash_type += "_wilson"
if "mobius" in dslash_type: dslash_type = dslash_type.replace("domain_wall","dwf")

return dslash_type

@property
@QudaMatPCType
def matPCtype(self):
Expand Down Expand Up @@ -345,7 +359,31 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params):

return out


#TODO: needs to be more automatic
def setGaugeParam(self, **gauge_options):
g_param = QudaGaugeParam()

#TODO: prepare default params for other type of dirac op
if "wilson" in self.dslash_type or "clover" in self.dslash_type or "twisted" in self.dslash_type:
lib.setWilsonGaugeParam(g_param.quda)
elif "staggered" in self.type:
lib.setStaggeredGaugeParam(g_param.quda)
else:
lib.setGaugeParam(g_param.quda)

g_param.location = int(self.gauge.location)
g_param.X = self.gauge.local_lattice
g_param.anisotropy = self.gauge.quda_field.Anisotropy()
g_param.tadpole_coeff = self.gauge.quda_field.Tadpole()
g_param.type = int(self.gauge.link_type)
g_param.gauge_order = int(self.gauge.order)
g_param.t_boundary = int(self.gauge.t_boundary)
g_param.cpu_prec = int(self.gauge.precision)
g_param.cuda_prec = int(self.gauge.precision)
g_param.update(gauge_options)
lib.loadGaugeQuda(self.gauge.quda_field.Gauge_p(), g_param.quda)


GaugeField.Dirac = wraps(Dirac)(lambda *args, **kwargs: Dirac(*args, **kwargs))


Expand Down
1 change: 0 additions & 1 deletion lyncs_quda/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def items(cls):
return cls._values.items()

def clean(cls, rep):
# should turn everything into upper for consistency
"Strips away prefix and suffix from key"
"See enums.py to find what is prefix and suffix for a given enum value"
if isinstance(rep, EnumValue):
Expand Down
28 changes: 26 additions & 2 deletions lyncs_quda/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
]

import atexit
import cppyy
from os import environ
from pathlib import Path
from array import array
Expand Down Expand Up @@ -223,6 +224,24 @@ def copy_struct(self):
)
return self.lyncs_quda_copy_struct

@property
def set_mg_eig_param(self):
try:
return self.lyncs_quda_set_mg_eig_param
except AttributeError:
cppdef(
"""
template<typename T, int n>
void lyncs_quda_set_mg_eig_param(T** ptr_array, T param, int i, bool is_null=false) {
if ( i < n) {
if (is_null) ptr_array[i] = nullptr;
else ptr_array[i] = &param;
}
}
"""
)
return self.lyncs_quda_set_mg_eig_param

def save_tuning(self):
if self.tune_enabled:
self.saveTuneCache()
Expand Down Expand Up @@ -276,17 +295,22 @@ def __del__(self):
"array.h",
"momentum.h",
"tune_quda.h",
"utils/host_utils.h",
"utils/command_line_params.h",
]


lib = QudaLib(
path=PATHS,
header=headers,
library=["libquda.so"] + libs,
library=["libquda.so", "libquda_test.so"] + libs,
namespace=["quda", "lyncs_quda"],
defined={"QUDA_PRECISION": QUDA_PRECISION, "QUDA_RECONSTRUCT": QUDA_RECONSTRUCT},
)
lib.MPI = MPI
# TODO: need to change "load" function of Lib from lyncs_cppyy to avoid the line below
# NOTE: This assumes: Python3.x & not runned from Jupyter notebook
# alternative: os.path.dirname(os.path.abspath(__file__))
cppyy.add_include_path(str(Path(__file__).parent.absolute()) + "/include/externals")

# used?
try:
Expand Down
114 changes: 114 additions & 0 deletions lyncs_quda/multigrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Interface to multigrid_solver
"""

__all__ = ["MultigridPreconditioner"]

from cppyy import bind_object
from lyncs_cppyy import nullptr
from lyncs_utils import isiterable
from .lib import lib
from .enums import QudaInverterType, QudaPrecision, QudaSolveType
from .structs import QudaInvertParam, QudaMultigridParam, QudaEigParam

class MultigridPreconditioner:
__slots__ = ["_quda", "mg_param", "inv_param"]

def __init__(self, D, inv_options={}, mg_options={}, eig_options={}, is_eig=False):
self._quda = None
self.mg_param, self.inv_param = self.prepareParams(D, inv_options=inv_options, mg_options=mg_options, eig_options=eig_options, is_eig=is_eig)

@property
@QudaInverterType
def inv_type_precondition(self):
return "MG_INVERTER"

#TODO: absorb updateMG_solver into this property and delete the function
# This will reqiure detecting the change of mg and inv param structs from
# the last update or creation of QUDA multigrid_solver object
@property
def quda(self):
if self._quda is None:
self._quda = lib.newMultigridQuda(self.mg_param.quda)
elif self.mg_param.updated or self.inv_param.updated:
lib.updateMultigridQuda(self._quda, self.mg_param.quda)
self.mg_param.updated = False
self.inv_param.updated = False
return self._quda

# TODO: can also accept structs?
def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_options={}, is_eig=False):
# INPUT: D is a Dirac instance
# is_eig is a list of bools indicating whether eigsolver is used to generate
# near null-vectors at each level
inv_param = QudaInvertParam()
mg_param = QudaMultigridParam()
mg_param.invert_param = inv_param.quda

# set* are defined in set_params.cpp, setting params to vals according to the ones defined globally
# <- command_line_params.cpp: contains some default values for those global vars, some set to invalid
# <- host_utils.h provides funcs to set global vars to some meaningful vals, according to vals in command_line...
# <- misc.h implemented in misc.cpp

# Set internal global vars to their default vals
dslash_type = D.dslash_type
solve_type = QudaSolveType["direct"] if D.full else QudaSolveType["direct_pc"]
lib.dslash_type = int(dslash_type)
lib.solve_type = int(solve_type)
lib.setQudaPrecisions()
lib.setQudaDefaultMgTestParams()
lib.setQudaMgSolveTypes()

# Set param vals to the default vals and update according to the user's specification
D.setGaugeParam(gauge_options=g_options)
lib.setMultigridParam(mg_param.quda)
if not D.full: inv_param.matpc_type = int(D.matPCtype)
inv_param.dagger = int(D.dagger)
inv_param.cpu_prec = int(D.precision) # quda.h says this is supposed to be the prec of input fermion field
inv_param.cuda_prec = int(D.precision)
if "clover" in D.type:
inv_param.compute_clover = False
inv_param.clover_cpu_prec = int(D.clover.precision)
inv_param.clover_cuda_prec = int(D.clover.precision)
inv_param.clover_order = int(D.clover.order)
inv_param.clover_location = int(D.clover.location)
inv_param.clover_csw = D.clover.csw
inv_param.clover_coeff = D.clover.coeff
inv_param.clover_rho = D.clover.rho
inv_param.compute_clover = False
inv_param.compute_clover_inverse = False
inv_param.return_clover = False
inv_param.return_clover_inverse = False
inv_param.update(inv_options)
mg_param.update(mg_options)
if "clover" in D.type:
D.clover.clover_field
D.clover.inverse_field
lib.loadCloverQuda(D.clover.quda_field.V(), D.clover.quda_field.V(True), inv_param.quda)
mg_param.invert_param = inv_param.quda #not sure if this is necessary?

# Only these fermions are supported with MG
if dslash_type != "WILSON" and dslash_type != "CLOVER_WILSON" and dslash_type != "TWISTED_MASS" and dslash_type != "TWISTED_CLOVER":
raise ValueError(f"dslash_type {dslash_type} not supported for MG")
# Only these solve types are supported with MG
if solve_type != "DIRECT" and solve_type != "DIRECT_PC":
raise ValueError(f"Solve_type {solve_type} not supported with MG. Please use QUDA_DIRECT_SOLVE or QUDA_DIRECT_PC_SOLVE")
if not isiterable(is_eig):
is_eig = [is_eig]*mg_param.n_level
for i, eig in enumerate(is_eig):
eig_param = QudaEigParam()
if eig:
lib.setMultigridEigParam(eig_param.quda)
eig_param.update(eig_options)
lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i)
else:
lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i, is_null=True)

return mg_param, inv_param

def __del__(self):
if self._quda is not None:
lib.destroyMultigridQuda(self._quda)
self._quda = None


33 changes: 23 additions & 10 deletions lyncs_quda/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@

from functools import wraps
from warnings import warn
from cppyy import bind_object
from lyncs_cppyy import nullptr, make_shared
from .dirac import Dirac, DiracMatrix
from .enums import QudaInverterType, QudaPrecision, QudaResidualType, QudaBoolean
from .enums import QudaInverterType, QudaPrecision, QudaResidualType, QudaBoolean, QudaSolutionType
from .lib import lib
from .spinor_field import spinor
from .time_profile import default_profiler, TimeProfile


def solve(mat, rhs, out=None, **kwargs):
return Solver(mat)(rhs, out, **kwargs)
def solve(mat, rhs, out=None, precon=None, **kwargs):
return Solver(mat, precon=precon)(rhs, out, **kwargs)


class Solver:
Expand Down Expand Up @@ -79,11 +80,11 @@ class Solver:
def _init_params():
return lib.SolverParam()

def __init__(self, mat, **kwargs):
def __init__(self, mat, precon=None, **kwargs):
self._params = self._init_params()
self._solver = None
self._profiler = None
self._precon = None
self.preconditioner = precon
self.mat = mat

params = type(self).default_params.copy()
Expand Down Expand Up @@ -180,7 +181,10 @@ def preconditioner(self, value):
self._params.inv_type_precondition = int(QudaInverterType["INVALID"])
self._params.preconditioner = nullptr
else:
raise NotImplementedError
self._precon = value
self._params.inv_type_precondition = int(self._precon.inv_type_precondition)
self._params.preconditioner = self._precon.quda


def _update_return_residual(self, old, new):
assert self._params.return_residual == new
Expand Down Expand Up @@ -223,17 +227,26 @@ def swap(self, **params):
del params[key]
return params

def __call__(self, rhs, out=None, warning=True, **kwargs):
def __call__(self, rhs, out=None, warning=True, solution_typ=None, **kwargs):
rhs = spinor(rhs)
out = rhs.prepare_out(out)
kwargs = self.swap(**kwargs)
# ASSUME: QUDA_FULL_SITE_SUBSET
if self.mat.dirac.full:
self.quda(out.quda_field, rhs.quda_field)
elif self.mat.dirac.even:
self.quda(out.quda_field.Even(), rhs.quda_field.Even())
elif solution_typ is not None:
# Computes the full inverse based on the e-o preconditioned matrix
in_, out_ = bind_object(nullptr, "quda::ColorSpinorField"), bind_object(nullptr, "quda::ColorSpinorField")
styp = int(QudaSolutionType[solution_typ])
self.mat.dirac.quda_dirac.prepare(in_, out_, out.quda_field, rhs.quda_field, styp)
self.quda(out_, in_)
self.mat.dirac.quda_dirac.reconstruct(out.quda_field, rhs.quda_field, styp)
else:
self.quda(out.quda_field.Odd(), rhs.quda_field.Odd())
# Computes the inverse of the Schur complement of the matpc type
if self.mat.dirac.even:
self.quda(out.quda_field.Even(), rhs.quda_field.Even())
else:
self.quda(out.quda_field.Odd(), rhs.quda_field.Odd())
self.swap(**kwargs)

if self.true_res > self.tol:
Expand Down
16 changes: 11 additions & 5 deletions lyncs_quda/spinor_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from lyncs_cppyy.ll import to_pointer
from .lib import lib
from .lattice_field import LatticeField
from .enum import EnumValue
from .enums import (
QudaGammaBasis,
QudaFieldOrder,
Expand Down Expand Up @@ -64,6 +65,11 @@ def __init__(self, *args, gamma_basis=None, site_order="EO", **kwargs):
self.gamma_basis = gamma_basis
self.site_order = site_order

def _prepare(self, field, **kwargs):
kwargs.setdefault("gamma_basis", self.gamma_basis)
kwargs.setdefault("site_order", self.site_order)
return super()._prepare(field, **kwargs)

@property
def ncolor(self):
"Number of colors of the field"
Expand Down Expand Up @@ -93,11 +99,10 @@ def gamma_basis(self, value):
if value is None:
value = "UKQCD"
values = f"Possible values are {SpinorField.gammas}"
if not isinstance(value, str):
raise TypeError("Expected a string. " + values)
if not value.upper() in values:
value = str(QudaGammaBasis[value]).upper()
if not value in values:
raise ValueError("Invalid gamma. " + values)
self._gamma_basis = value.upper()
self._gamma_basis = value

@property
@QudaFieldOrder
Expand Down Expand Up @@ -125,9 +130,10 @@ def site_order(self, value):
if value is None:
value = "NONE"
values = "Possible values are NONE, EVEN_ODD, ODD_EVEN"
if isinstance(value, EnumValue):
value = str(value).upper()
if not isinstance(value, str):
raise TypeError("Expected a string. " + values)
value = value.upper()
if value in ["NONE", "LEX", "LEXICOGRAPHIC"]:
value = "LEXICOGRAPHIC"
elif value in ["EO", "EVEN_ODD"]:
Expand Down
Loading