Skip to content

Commit

Permalink
remove test classes
Browse files Browse the repository at this point in the history
  • Loading branch information
wxicu committed Jun 10, 2024
1 parent 2c8127c commit 57b14c3
Showing 1 changed file with 49 additions and 46 deletions.
95 changes: 49 additions & 46 deletions tests/tools/test_metrics_3g.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,58 @@
import pertpy as pt
import numpy as np
import pertpy as pt
import pytest
from pertpy.tools._metrics_3g import (
compare_de,
compare_class,
compare_de,
compare_dist,
compare_knn,
)


class TestMetrics3G:
@pytest.fixture
def test_data(self):
rng = np.random.default_rng()
X = rng.normal(size=(100, 10))
Y = rng.normal(size=(100, 10))
C = rng.normal(size=(100, 10))
return X, Y, C

def test_compare_de(self, test_data):
X, Y, C = test_data
result = compare_de(X, Y, C, shared_top=5)
assert isinstance(result, dict)
required_keys = {
"shared_top_genes",
"scores_corr",
"pvals_adj_corr",
"scores_ranks_corr",
}
assert all(key in result for key in required_keys)

def test_compare_class(self, test_data):
X, Y, C = test_data
result = compare_class(X, Y, C)
assert result <= 1

def test_compare_knn(self, test_data):
X, Y, C = test_data
result = compare_knn(X, Y, C)
assert isinstance(result, dict)
assert "comp" in result
assert isinstance(result["comp"], float)

result_no_ctrl = compare_knn(X, Y)
assert isinstance(result_no_ctrl, dict)

def test_compare_dist(self, test_data):
X, Y, C = test_data
res_simple = compare_dist(X, Y, C, mode="simple")
assert isinstance(res_simple, float)
res_scaled = compare_dist(X, Y, C, mode="scaled")
assert isinstance(res_scaled, float)
with pytest.raises(ValueError):
compare_dist(X, Y, C, mode="new_mode")
@pytest.fixture
def test_data():
rng = np.random.default_rng()
X = rng.normal(size=(100, 10))
Y = rng.normal(size=(100, 10))
C = rng.normal(size=(100, 10))
return X, Y, C


def test_compare_de(test_data):
X, Y, C = test_data
result = compare_de(X, Y, C, shared_top=5)
assert isinstance(result, dict)
required_keys = {
"shared_top_genes",
"scores_corr",
"pvals_adj_corr",
"scores_ranks_corr",
}
assert all(key in result for key in required_keys)


def test_compare_class(test_data):
X, Y, C = test_data
result = compare_class(X, Y, C)
assert result <= 1


def test_compare_knn(test_data):
X, Y, C = test_data
result = compare_knn(X, Y, C)
assert isinstance(result, dict)
assert "comp" in result
assert isinstance(result["comp"], float)

result_no_ctrl = compare_knn(X, Y)
assert isinstance(result_no_ctrl, dict)


def test_compare_dist(test_data):
X, Y, C = test_data
res_simple = compare_dist(X, Y, C, mode="simple")
assert isinstance(res_simple, float)
res_scaled = compare_dist(X, Y, C, mode="scaled")
assert isinstance(res_scaled, float)
with pytest.raises(ValueError):
compare_dist(X, Y, C, mode="new_mode")

0 comments on commit 57b14c3

Please sign in to comment.