Skip to content

Commit

Permalink
changed from np.float/np.int to float/int.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shami Nisimov authored and Shami Nisimov committed Apr 30, 2023
1 parent d7903a2 commit 299ac4a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 40 deletions.
6 changes: 3 additions & 3 deletions causal_discovery_utils/cond_indep_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __init__(self, threshold, dataset, weights=None, retained_edges=None, count_
num_records=None, num_vars=None):
if weights is not None:
raise Exception('weighted Partial-correlation is not supported. Please avoid using weights.')
super().__init__(dataset, threshold, database_type=np.float, weights=weights, retained_edges=retained_edges,
super().__init__(dataset, threshold, database_type=float, weights=weights, retained_edges=retained_edges,
count_tests=count_tests, use_cache=use_cache, num_records=num_records, num_vars=num_vars)

self.correlation_matrix = None
Expand Down Expand Up @@ -270,15 +270,15 @@ def calc_statistic(self, x, y, zz):

class CondIndepCMI(StatCondIndep):
def __init__(self, dataset, threshold, weights=None, retained_edges=None, count_tests=False, use_cache=False):
self.weight_data_type = np.float
self.weight_data_type = float
if weights is not None:
weights = np.array(weights, dtype=self.weight_data_type)
# if np.min(weights) < 0:
# raise Exception('Negative sample weights are not allowed')
# if np.abs(np.sum(weights) - 1.0) > np.finfo(self.weight_data_type).eps:
# raise Exception('Sample weights do not sum to 1.0')
# weights *= dataset.shape[0]
super().__init__(dataset, threshold, database_type=np.int, weights=weights, retained_edges=retained_edges,
super().__init__(dataset, threshold, database_type=int, weights=weights, retained_edges=retained_edges,
count_tests=count_tests, use_cache=use_cache)

def cond_indep(self, x, y, zz):
Expand Down
2 changes: 1 addition & 1 deletion graphical_models/basic_equivalance_class_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def get_skeleton_mat(self):
:return:
"""
num_nodes = len(self.nodes_set)
adj_mat = np.zeros((num_nodes, num_nodes), dtype=np.int)
adj_mat = np.zeros((num_nodes, num_nodes), dtype=int)
node_index_map = {node: i for i, node in enumerate(sorted(list(self.nodes_set)))}

for node in self._graph:
Expand Down
Loading

0 comments on commit 299ac4a

Please sign in to comment.