Skip to content

Commit

Permalink
Better Handling of zero median values in Kernel Width (#160)
Browse files Browse the repository at this point in the history
* Filter zeros out of median computation
* removing IDE files from commit

---------

Signed-off-by: Nicholas Parente <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
nparent1 and pre-commit-ci[bot] authored Sep 12, 2024
1 parent 9d732dd commit f0affa5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ venv.bak/
.spyderproject
.spyproject

# PyCharm
.idea/
*.iml
*.iws

# Rope project settings
.ropeproject

Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Changelog
- |Feature| Add a suite of general categorical data CI tests, by `Adam Li`_ (:pr:`128`)
- |Feature| Implement CAM, SCORE, DAS, NoGAM algorithms in ``dodiscover.toporder`` submodule (:pr:`129`)
- |Feature| Add Psi-FCI and I-FCI algorithm for handling soft-interventional data, :class:`dodiscover.constraint.PsiFCI` by `Adam Li`_ (:pr:`111`)
- |Fix| Update the kernel_width method to filter out zeros before computing the median pairwise distance, by `Nick Parente`_ (:pr:`160`)

Code and Documentation Contributors
-----------------------------------
Expand Down
5 changes: 3 additions & 2 deletions dodiscover/toporder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ def kernel_width(X: NDArray):
Matrix of the data.
"""
X_diff = np.expand_dims(X, axis=1) - X # Gram matrix of the data
D = np.linalg.norm(X_diff, axis=2)
s = np.median(D.flatten())
D = np.linalg.norm(X_diff, axis=2).flatten()
D_nonzeros = D[D > 0] # Remove zeros
s = np.median(D_nonzeros) if np.any(D_nonzeros) else 1
return s


Expand Down
14 changes: 14 additions & 0 deletions tests/unit_tests/toporder/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np

from dodiscover.toporder.utils import kernel_width


def test_kernel_width_when_zero_median_pairwise_distances():
arr = np.zeros((100, 1), dtype=np.int64)
arr[1] = 1
assert kernel_width(arr) == 1


def test_kernel_width_when_all_zero_pairwise_distances():
arr = np.ones((100, 1), dtype=np.int64)
assert kernel_width(arr) == 1

0 comments on commit f0affa5

Please sign in to comment.