Skip to content

Commit

Permalink
Merge pull request #117 from feixie/update_numpy
Browse files Browse the repository at this point in the history
update numpy and fix numpy records
  • Loading branch information
nikhilwoodruff authored Oct 4, 2023
2 parents ffcf4a5 + ddc09af commit be82476
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: minor
changes:
fixed:
- Bump numpy version
- Fix numpy record compatibility issue
4 changes: 3 additions & 1 deletion policyengine_core/parameters/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


def contains_nan(vector):
if numpy.issubdtype(vector.dtype, numpy.record):
if numpy.issubdtype(vector.dtype, numpy.record) or numpy.issubdtype(
vector.dtype, numpy.void
):
return any([contains_nan(vector[name]) for name in vector.dtype.names])
else:
return numpy.isnan(vector).any()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def build_from_node(
node: "ParameterNode",
) -> "VectorialParameterNodeAtInstant":
VectorialParameterNodeAtInstant.check_node_vectorisable(node)
subnodes_name = node._children.keys()
subnodes_name = sorted(node._children.keys())
# Recursively vectorize the children of the node
vectorial_subnodes = tuple(
[
Expand Down Expand Up @@ -227,7 +227,9 @@ def __getitem__(self, key: str) -> Any:
)

# If the result is not a leaf, wrap the result in a vectorial node.
if numpy.issubdtype(result.dtype, numpy.record):
if numpy.issubdtype(
result.dtype, numpy.record
) or numpy.issubdtype(result.dtype, numpy.void):
return VectorialParameterNodeAtInstant(
self._name, result.view(numpy.recarray), self._instant_str
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

general_requirements = [
"pytest>=7,<8",
"numpy>=1.21,<1.22",
"numpy >=1.24.2, <1.25",
"black",
"linecheck<1",
"yaml-changelog<1",
Expand Down

0 comments on commit be82476

Please sign in to comment.