Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Struct #16

Open
wants to merge 3 commits into
base: cleanup2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lyncs_quda/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class Enum(metaclass=EnumMeta):

def __init__(self, fnc, lpath=None, default=None, callback=None):
# fnc is supposed to return either a stripped key name or value of
# the corresponding QUDA enum type
# the corresponding QUDA enum type so as to decorate a property obj
self.fnc = fnc
self.lpath = lpath
self.default = default
Expand Down
4 changes: 2 additions & 2 deletions lyncs_quda/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def copy_struct(self):
"""
)
return self.lyncs_quda_copy_struct

def save_tuning(self):
if self.tune_enabled:
self.saveTuneCache()
Expand Down Expand Up @@ -261,8 +261,8 @@ def __del__(self):
PATHS = list(__path__)

headers = [
"comm_quda.h",
"quda.h",
"comm_quda.h",
"gauge_field.h",
"gauge_tools.h",
"gauge_path_quda.h",
Expand Down
150 changes: 131 additions & 19 deletions lyncs_quda/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"to_code",
]

from lyncs_cppyy import nullptr
from lyncs_utils import isiterable, setitems
import numpy as np
from lyncs_cppyy import nullptr, to_pointer, addressof
from lyncs_utils import isiterable
from .lib import lib
from . import enums

Expand All @@ -17,11 +18,13 @@ def to_human(val, typ=None):
if typ is None:
return val
if typ in dir(enums):
return getattr(enums, typ)[val]
return str(getattr(enums, typ)[val])
if isinstance(val, (int, float)):
return val
if "*" in typ and val == nullptr:
return 0
if "char" in typ:
return "".join(list(val))
return val


Expand All @@ -32,7 +35,7 @@ def to_code(val, typ=None):
if typ in dir(enums):
if isinstance(val, int):
return val
return getattr(enums, typ)[val]
return int(getattr(enums, typ)[val])
if typ in ["int", "float", "double"]:
return val
if "char" in typ:
Expand All @@ -45,14 +48,69 @@ def to_code(val, typ=None):
return val


def get_dtype(typ):
if "*" in typ:
return np.dtype(object)
typ_dict = {"complex":"c", "unsigned":"u", "float":"single", "char":"byte", "comlex_double":"double"}
typ_list = typ.split()
dtype = ""
for w in typ_list:
if w in ("complex", "unsigned"):
dtype = typ_dict[w] + dtype
dtype += typ_dict.get(w, w)
if typ in ("bool", "long"):
dtype += "_"
if "int" in typ:
dtype += "c"
return np.dtype(dtype)


def setitems(arr, vals, shape=None, is_string=False):
"Sets items of an iterable object"
shape = shape if shape is not None else arr.shape
size = shape[0] #len(arr)
if not is_string and type(vals) == str:
# sometimes, vals is turned into str
vals = eval(vals)
if hasattr(vals, "__len__") and type(vals) != bytes:
if len(vals) > size:
raise ValueError(
f"Values size ({len(vals)}) larger than array size ({size})"
)
else:
vals = (vals,) * size
for i, val in enumerate(vals):
if len(shape)>1 and hasattr(arr[i], "__len__"):
is_string = len(shape[1:]) == 1 and type(vals[0]) == str
setitems(arr[i], val, shape = shape[1:], is_string=is_string)
else:
arr[i] = val


class Struct:
"Struct base class"
_types = {}

def __init__(self, *args, **kwargs):
# ? better to simply store (key, val) pair into an instance's own __dict__, if key is in _types.keys()
self._params = getattr(lib, type(self).__name__)() # ? recursive?

#? is *args necessary? when provided, it causes error in update
self._quda_params = getattr(lib, "new"+type(self).__name__)()

# some fields are not set by QUDA's new* function
default_params = getattr(lib, type(self).__name__)()
for key in self.keys():
# to avoid Enum error due to unexpected key-value pair
if self._types[key] in dir(enums) and not key in kwargs:
enm = getattr(enums, self._types[key])
if not getattr(self._quda_params, key) in enm.values():
val = list(enm.values())[-1]
self._assign(key, val)

# temporal fix: newQudaMultigridParam does not assign a default value to n_level
if "Multigrid" in type(self).__name__:
n = getattr(self._quda_params, "n_level")
n = lib.QUDA_MAX_MG_LEVEL if n < 0 or n > lib.QUDA_MAX_MG_LEVEL else n
setattr(self._quda_params, "n_level", n)

for arg in args:
self.update(arg)
self.update(kwargs)
Expand All @@ -67,28 +125,67 @@ def items(self):

def update(self, params):
"Updates values of the structure"
if not hasattr(
params, "items"
): # ? in __init__, it takes *args, which is a tuple. expect a tuple of dict's?
if not hasattr(params, "items"):
raise TypeError(f"Unsopported type for params: {type(params)}")
for key, val in params.items():
setattr(self, key, val)

def _assign(self, key, val):
typ = self._types[key]
val = to_code(val, typ)
cur = getattr(self._params, key)

if hasattr(cur, "shape"): # ? what is this?
setitems(cur, val) # ? what is this?
val = to_code(val, typ)
cur = getattr(self._quda_params, key)

if "[" in self._types[key] and not hasattr(cur, "shape"):# not sure if this is needed for cppyy3.0.0
# safeguard against hectic behavior of cppyy
raise RuntimeError("cppyy is not happy for now. Try again!")


if typ.count("[") > 1:
# cppyy<=3.0.0 cannot handle subviews properly
# Trying to manipulate the sub-array either results in error or segfault
# => array of arrays is set using glb.memcpy
# Alternative:
# use ctypes (C = ctypes, arr = LowlevelView of array of arrays)
# ptr = C.cast(cppyy.ll.addressof(arr), C.POINTER(C.c_int))
# narr = np.ctypeslib.as_array(ptr, shape=arr.shape)
# This allows to access sub-indicies properly, i.e., narr[2][3] = 9 works
assert hasattr(cur, "shape")
if "file" in key:
#? array = np.zeros(cur.shape, dtype="S1") and remove setitems(array, b"\0"); is this ok?
#? not sure of this as "" is not b"\0"
array = np.chararray(cur.shape)
setitems(array, b"\0")
setitems(array, val)
size = 1
else:
dtype = get_dtype(typ[:typ.index("[")].strip())
array = np.asarray(val, dtype=dtype)
size = dtype.itemsize
lib.memcpy(to_pointer(addressof(cur)), to_pointer(array.__array_interface__["data"][0]), int(np.prod(cur.shape))*size)
elif typ.count("[") == 1:
assert hasattr(cur, "shape")
shape = tuple([getattr(lib, macro) for macro in typ.split(" ") if "QUDA_" in macro or macro.isnumeric()]) #not necessary for cppyy3.0.0?
cur.reshape(shape) #? not necessary for cppyy3.0.0?
if "*" in typ:
for i in range(shape[0]):
val = to_pointer(addressof(val), ctype = typ[:-typ.index("[")].strip())
is_string = True if "char" in typ else False
if is_string:
setitems(cur, b"\0") # for printing
setitems(cur, val, is_string=is_string)
else:
setattr(self._params, key, val)
if "*" in typ:
# cannot set nullptr to void *, int *, etc; works for classes such as Enum classes with bind_object
if val == nullptr:
raise ValueError("Cannot cast nullptr to a valid pointer")
val = to_pointer(addressof(val), ctype = typ)
setattr(self._quda_params, key, val)

def __dir__(self):
return list(set(list(super().__dir__()) + list(self._params.keys())))
return list(set(list(super().__dir__()) + list(self._quda_params.keys())))

def __getattr__(self, key):
return to_human(getattr(self._params, key), self._types[key])
return to_human(getattr(self._quda_params, key), self._types[key])

def __setattr__(self, key, val):
if key in self.keys():
Expand All @@ -98,8 +195,23 @@ def __setattr__(self, key, val):
raise TypeError(
f"Cannot assign '{val}' to '{key}' of type '{self._types[key]}'"
)
else:
else: #should we allow this?
super().__setattr__(key, val)

def __str__(self):
return str(dict(self.items()))

@property
def quda(self):
return self._quda_params

@property
def address(self):
return addressof(self.quda)

@property
def ptr(self):
return to_pointer(addressof(self.quda), ctype = type(self).__name__ + " *")

def printf(self):
getattr(lib, "print"+type(self).__name__)(self._quda_params)
34 changes: 29 additions & 5 deletions test/test_structs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
from lyncs_quda import structs # This is also importing Enum
from lyncs_quda.enum import Enum
from lyncs_quda.testing import fixlib as lib


def test_assign_zero(lib):
for struct in dir(structs):
if struct.startswith("_") or struct == "Struct":
continue

if struct.startswith("_") or struct == "Struct" or struct == "Enum" or issubclass(getattr(structs, struct), Enum):
continue

params = getattr(structs, struct)()

for key in params.keys():
setattr(params, key, 0)
typ = getattr(structs, struct)._types[key]
obj = getattr(structs, typ) if typ in dir(structs) else None
val = 0
if obj != Enum and issubclass(obj, Enum):
val = list(obj.values())[0]
elif "*" in typ: # cannot set a pointer field to nullptr via cppyy
continue
print("tst",struct,key,typ, obj,val)

setattr(params, key, val)

def test_assign_something(lib):
mp = structs.QudaMultigridParam()
ip = structs.QudaInvertParam()
ep = structs.QudaEigParam()

# ptr to strct class works
mp.n_level = 3 # This is supposed to be set explicitly
mp.invert_param = ip.quda
ip.split_grid = list(range(lib.QUDA_MAX_DIM))
ip.madwf_param_infile = "hi I'm here!"
mp.geo_block_size = [[i+j+1 for j in range(lib.QUDA_MAX_DIM)] for i in range(lib.QUDA_MAX_MG_LEVEL)]
mp.vec_infile = ["infile" + str(i) for i in range(lib.QUDA_MAX_MG_LEVEL)]
mp.printf()
print(ip.madwf_param_infile)