Skip to content

Commit

Permalink
Prevent accident overwritten in ctypes.Structure
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Oct 30, 2024
1 parent 6472403 commit 57245ac
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pyscf/pbc/dft/multigrid/multigrid_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class GridLevel_Info(ctypes.Structure):
'''
Info about the grid levels.
'''
__slots__ = []
_fields_ = [("nlevels", ctypes.c_int), # number of grid levels
("rel_cutoff", ctypes.c_double),
("cutoff", ctypes.POINTER(ctypes.c_double)),
Expand All @@ -92,6 +93,7 @@ class RS_Grid(ctypes.Structure):
'''
Values on real space multigrid.
'''
__slots__ = []
_fields_ = [("nlevels", ctypes.c_int),
("gridlevel_info", ctypes.POINTER(GridLevel_Info)),
("comp", ctypes.c_int),
Expand All @@ -102,6 +104,7 @@ class PGFPair(ctypes.Structure):
'''
A primitive Gaussian function pair.
'''
__slots__ = []
_fields_ = [("ish", ctypes.c_int),
("ipgf", ctypes.c_int),
("jsh", ctypes.c_int),
Expand All @@ -114,6 +117,7 @@ class Task(ctypes.Structure):
'''
A single task.
'''
__slots__ = []
_fields_ = [("buf_size", ctypes.c_size_t),
("ntasks", ctypes.c_size_t),
("pgfpairs", ctypes.POINTER(ctypes.POINTER(PGFPair))),
Expand All @@ -124,6 +128,7 @@ class TaskList(ctypes.Structure):
'''
A task list.
'''
__slots__ = []
_fields_ = [("nlevels", ctypes.c_int),
("hermi", ctypes.c_int),
("gridlevel_info", ctypes.POINTER(GridLevel_Info)),
Expand Down
1 change: 1 addition & 0 deletions pyscf/pbc/gto/_pbcintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __del__(self):
pass

class _CPBCOpt(ctypes.Structure):
__slots__ = []
_fields_ = [('rrcut', ctypes.c_void_p),
('rcut', ctypes.c_void_p),
('fprescreen', ctypes.c_void_p)]
3 changes: 3 additions & 0 deletions pyscf/pbc/gto/neighborlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,23 @@
libpbc = lib.load_library('libpbc')

class _CNeighborPair(ctypes.Structure):
__slots__ = []
_fields_ = [("nimgs", ctypes.c_int),
("Ls_list", ctypes.POINTER(ctypes.c_int)),
("q_cond", ctypes.POINTER(ctypes.c_double)),
("center", ctypes.POINTER(ctypes.c_double))]


class _CNeighborList(ctypes.Structure):
__slots__ = []
_fields_ = [("nish", ctypes.c_int),
("njsh", ctypes.c_int),
("nimgs", ctypes.c_int),
("pairs", ctypes.POINTER(ctypes.POINTER(_CNeighborPair)))]


class _CNeighborListOpt(ctypes.Structure):
__slots__ = []
_fields_ = [("nl", ctypes.POINTER(_CNeighborList)),
('fprescreen', ctypes.c_void_p)]

Expand Down
1 change: 1 addition & 0 deletions pyscf/scf/_vhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def set_dm(self, dm, atm, bas, env):


class _CVHFOpt(ctypes.Structure):
__slots__ = []
_fields_ = [('nbas', ctypes.c_int),
('ngrids', ctypes.c_int),
('direct_scf_tol', ctypes.c_double),
Expand Down

0 comments on commit 57245ac

Please sign in to comment.