Skip to content

Commit

Permalink
Added a CI-test that uses d-sepration and m-separation graphical crit…
Browse files Browse the repository at this point in the history
…eria.
  • Loading branch information
raanan-rohekar committed Dec 10, 2023
1 parent 348b631 commit 307d891
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion causal_discovery_utils/cond_indep_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from causal_discovery_utils.data_utils import calc_stats
from causal_discovery_utils.data_utils import get_var_size
from graphical_models import DAG, UndirectedGraph
from graphical_models import DAG, UndirectedGraph, PAG
from scipy import stats


Expand Down Expand Up @@ -134,6 +134,70 @@ def cond_indep(self, x, y, zz):
return res


class GraphCondIndep:
"""
GraphCondIndep: a CI test that derive its result from a given graph.
Depending on the graph type, an appropriate criterion is used:
DAG type: d-separation criterion
PAG type: m-separation criterion
"""
def __init__(self, reference_graph, static_conditioning=None, count_tests=False, use_cache=False, verbose=False):
"""
Initialize GraphCondIndep, a CI test that derive its result from a given graph.
:param reference_graph: a graph from which independence relations are inferred. Only DAG and PAG are supported.
:param static_conditioning: a set of nodes that will always be included in the conditioning set.
:param count_tests: if True, count the number of CI test queries (default: False). Mainly for debug
:param use_cache: if True, cache CI tests' results (default: False). Used for avoiding redundant CI tests.
:param verbose: Verbose flag (default: False). Mainly for debug
"""
self.reference_graph = reference_graph
self.verbose = verbose

if type(reference_graph) == DAG:
self.ci_criterion = reference_graph.dsep
elif type(reference_graph) == PAG:
self.ci_criterion = reference_graph.is_m_separated
else:
raise TypeError('Unsupported graph type.')

if static_conditioning is None or type(static_conditioning) == tuple:
self.static_conditioning = static_conditioning
else:
raise TypeError('Static conditioning, if defined, should be a tuple.')

num_nodes = len(reference_graph.nodes_set)
self.count_tests = count_tests
if count_tests:
self.test_counter = [0 for _ in range(num_nodes - 1)]
else:
self.test_counter = None

self.is_cache = use_cache
if use_cache:
self.cache_ci = CacheCI(num_nodes)
else:
self.cache_ci = CacheCI(None)

def cond_indep(self, x, y, zz_conditioning):
if self.static_conditioning is None:
zz = zz_conditioning
else:
zz = tuple(set(zz_conditioning + self.static_conditioning))

res = self.cache_ci.get_cache_result(x, y, zz)

if res is None:
res = self.ci_criterion(x, y, zz)
if self.verbose:
print(self.ci_criterion.__name__, '(', x, ',', y, '|', zz, ')', '=', res)
if self.is_cache:
self.cache_ci.set_cache_result(x, y, zz, res)
if self.count_tests:
self.test_counter[len(zz)] += 1 # update counter only if the test was not previously cached
return res


class StatCondIndep:
def __init__(self,
dataset, threshold, database_type, weights=None,
Expand Down

0 comments on commit 307d891

Please sign in to comment.