unittest clustering
yazhinia committed Dec 18, 2024
1 parent 0865a48 commit ec835d7
Showing 3 changed files with 52 additions and 88 deletions.
10 changes: 3 additions & 7 deletions mcdevol/clustering.py
@@ -73,7 +73,6 @@ def run_leiden(latent_norm, ncpus, resolution_param = 1.0, max_edges = 100):
     weights = weights[index]
     edgelist = list(zip(sources, targets))
     g = ig.Graph(num_elements, edgelist)
-    print(resolution_param, 'resolution parameter')
     rbconf = leidenalg.RBConfigurationVertexPartition(g, weights=weights,resolution_parameter=resolution_param)
     optimiser = leidenalg.Optimiser()
     optimiser.optimise_partition(rbconf, n_iterations=-1)
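Note: the Leiden call pattern in this hunk (graph construction, RBConfigurationVertexPartition, Optimiser) is standard leidenalg usage; a minimal self-contained sketch, with a toy chain graph standing in for the kNN graph built over latent_norm:

import igraph as ig
import leidenalg

# Toy 5-node chain graph in place of the real kNN graph
g = ig.Graph(5, [(0, 1), (1, 2), (2, 3), (3, 4)])
partition = leidenalg.RBConfigurationVertexPartition(
    g, weights=[1.0] * g.ecount(), resolution_parameter=1.0
)
optimiser = leidenalg.Optimiser()
optimiser.optimise_partition(partition, n_iterations=-1)  # -1: iterate until no further improvement
print(partition.membership)  # community id per vertex, e.g. [0, 0, 1, 1, 1]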
@@ -116,7 +115,6 @@ def cluster(
     latent_norm = latent / np.linalg.norm(latent, axis=1, keepdims=True)

     community_assignments = run_leiden(latent_norm, ncpus, max_edges=max_edges)
-
     cluster_ids = pd.DataFrame({
         "contig_name": contig_names,
         "cluster_id": community_assignments
@@ -131,7 +129,6 @@ def cluster(
     logger.info(f'Filtered bins by 200kb size: {len(cluster_selected.index)}')
     file_name = 'bins_filtered.tsv'
     cluster_selected.to_csv(os.path.join(outdir, file_name), header=None, sep=',', index=False)
-
     if multi_split:
         clusters = cluster_selected.groupby("cluster_id")["contig_name"].apply(list).tolist()
         cluster_counter = 0
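The groupby/apply(list) idiom above turns the per-contig cluster table into one contig list per cluster; a small self-contained example:

import pandas as pd

df = pd.DataFrame({"contig_name": ["c1", "c2", "c3"], "cluster_id": [0, 0, 1]})
clusters = df.groupby("cluster_id")["contig_name"].apply(list).tolist()
# clusters == [['c1', 'c2'], ['c3']]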
@@ -160,8 +157,7 @@ def cluster(
         os.makedirs(bindirectory, exist_ok=True)
         for inds in sampleindices:
             # sample order can differ in pandas grouping and hence explicitly get sample id from contig name
-            sample_id = contig_names[inds[0]].split('_')[0].split('C')[0].replace('S','')
-            print(sample_id, 'sample_id')
+            sample_id = contig_names[inds[0]].split('C')[0] # .split('_')[0].
             latent_sample = latent_norm[inds]
             contig_length_sample = contig_length[inds]
             names_subset = contig_names[inds]
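The replacement keeps the full sample prefix instead of stripping the 'S' and re-adding it when writing output names. Assuming multi-sample contig names of the form S&lt;sample&gt;C&lt;contig&gt; (as in the test data in tests/test_clustering.py below; the name here is hypothetical), the before/after behaviour is:

name = "S1C141_23"  # hypothetical contig name: sample S1, contig 141_23
old_id = name.split('_')[0].split('C')[0].replace('S', '')  # -> '1'; files were then written as f'S{old_id}_...'
new_id = name.split('C')[0]                                 # -> 'S1'; used verbatim for file and directory names
print(old_id, new_id)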
@@ -174,9 +170,9 @@ def cluster(
             binsize = pd.DataFrame(bin_ids.groupby("cluster_id")["contig_length"].sum().reset_index(drop=True))
             binids_selected = binsize[binsize>=200000].index
             bins_selected = bin_ids[bin_ids["cluster_id"].isin(binids_selected)][["contig_name","cluster_id"]]
-            file_name = f'S{sample_id}_bins_filtered'
+            file_name = f'{sample_id}_bins_filtered'
             bins_selected.to_csv(os.path.join(outdir, file_name), header=None, sep=',', index=False)
-            samplebin_directory = os.path.join(bindirectory,"S"+str(sample_id))
+            samplebin_directory = os.path.join(bindirectory, str(sample_id))

             subprocess.run(f"{util_path}/get_sequence_bybin {outdir} {file_name} {fasta_file} bin {samplebin_directory}", shell=True)
             logger.info(f'Splitting clusters by sample: {len(cluster_selected.index)}')
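One subtlety in the bin-size filter above: because binsize is wrapped in pd.DataFrame, the mask binsize[binsize >= 200000] keeps every row label (non-qualifying rows become NaN rather than being dropped), so .index evaluates to all bins. Filtering on the bare Series gives the presumably intended >= 200 kb selection; a minimal sketch with hypothetical data:

import pandas as pd

# Hypothetical bin table: cluster 0 totals 250 kb, cluster 1 only 50 kb
bin_ids = pd.DataFrame({"cluster_id": [0, 0, 1],
                        "contig_length": [150_000, 100_000, 50_000]})
binsize = bin_ids.groupby("cluster_id")["contig_length"].sum()  # Series indexed by cluster_id
binids_selected = binsize[binsize >= 200_000].index             # Index([0]): only bins totalling >= 200 kb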
110 changes: 44 additions & 66 deletions tests/test_clustering.py
@@ -11,8 +11,12 @@
 import pandas as pd # type: ignore
 import os
 import io
+import sys
 import tempfile
 import logging
+import shutil
+import igraph as ig
+from io import StringIO
 from unittest.mock import patch, MagicMock, call
 import clustering
 from clustering import cluster, run_leiden
@@ -22,28 +26,19 @@ def setUp(self):
         # 100 contigs, 32-dimensional latent space
         self.latent = np.random.rand(100, 32)
         self.contig_length = np.random.randint(1000, 10000, 100)
-        self.contig_names = np.array([f"contig_{i}" for i in range(100)])
+        self.contig_names = np.array([f"k141_{i}" for i in range(100)])
         self.fasta_file = "dummy.fasta"
         self.outdir = tempfile.mkdtemp()
         self.ncpus = 2
         self.logger = logging.getLogger("test_logger")

     @patch('clustering.run_leiden')
     @patch('subprocess.run')
     # @patch('sys.stdout', new_callable=io.StringIO)
     def test_cluster(self, mock_subprocess_run, mock_run_leiden):
         # Mock the Leiden clustering result
-        # mock_run_leiden.return_value = np.random.randint(0, 20, 100)
-        num_elements = 100
-        mock_edgelist = [(i, i + 1) for i in range(num_elements - 1)]  # Simple chain graph
-        mock_g = ig.Graph(num_elements, mock_edgelist)
         mock_run_leiden.return_value = (
-            np.random.randint(0, 20, 100),  # Mocked community_assignments
-            100,  # Mocked num_elements
-            50,  # Mocked max_edges
-            np.random.rand(100, 50),  # Mocked ann_distances
-            np.random.randint(0, 100, (100, 50)),  # Mocked ann_neighbor_indices
-            mock_g
+            np.random.randint(0, 20, 100)
         )

         cluster(self.latent, self.contig_length, self.contig_names,
@@ -53,74 +48,57 @@ def test_cluster(self, mock_subprocess_run, mock_run_leiden):
         self.assertTrue(os.path.exists(os.path.join(self.outdir, 'allbins.tsv')))
         self.assertTrue(os.path.exists(os.path.join(self.outdir, 'bins_filtered.tsv')))

         # Check if subprocess was called
         mock_subprocess_run.assert_called_once()

     def tearDown(self):
         # Clean up temporary directory
         for file in os.listdir(self.outdir):
             os.remove(os.path.join(self.outdir, file))
         os.rmdir(self.outdir)

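A note on the mock arguments in these tests: stacked @patch decorators are applied bottom-up, so the mocks arrive in reverse decorator order, which is why mock_subprocess_run precedes mock_run_leiden in the signatures above. A minimal runnable sketch:

import subprocess
import unittest
from unittest.mock import patch

class DecoratorOrderDemo(unittest.TestCase):
    @patch('os.makedirs')     # outermost decorator -> last mock argument
    @patch('subprocess.run')  # innermost decorator -> first mock argument
    def test_order(self, mock_run, mock_makedirs):
        subprocess.run("echo hi", shell=True)  # hits the mock, not a real shell
        mock_run.assert_called_once()
        mock_makedirs.assert_not_called()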
-class TestClusterFunction(unittest.TestCase):
-    @patch('subprocess.run')
-    @patch('os.makedirs')
-    @patch('os.path.exists')
-    @patch('builtins.print')
-    @patch('pandas.DataFrame.to_csv')
-    @patch('clustering.run_leiden') # Replace 'clustering' with the actual module name
-    def test_cluster_with_multi_split(
-        self, mock_run_leiden, mock_to_csv, mock_print, mock_exists, mock_makedirs, mock_subprocess_run
-    ):
-        # Mock inputs
-        latent = np.random.rand(100, 10)
-        contig_length = np.random.randint(1000, 5000, size=100)
-        contig_names = np.array([f"S1C{i}" for i in range(50)] + [f"S2C{i}" for i in range(50, 100)])
-        fasta_file = 'test.fasta'
-        outdir = 'test_output'
-        ncpus = 4
-        logger = logging.getLogger('test_logger')
-        multi_split = True
-        separator = 'C'
+def dynamic_run_leiden(latent_subset, ncpus, resolution_param=1.0, max_edges=100):
+    num_elements = latent_subset.shape[0]
+    community_assignments = np.random.randint(0, 20, size=num_elements)  # Simulate cluster IDs
+    return community_assignments

-        mock_exists.return_value = False
-        mock_to_csv.return_value = None
+class TestClusterFunctionMultiSplit(unittest.TestCase):

-        def dynamic_run_leiden(latent_sample, *args, **kwargs):
-            num_elements = len(latent_sample)
-            return (
-                np.random.randint(0, 10, size=num_elements),
-                num_elements,
-                100,
-                [np.random.rand(5) for _ in range(num_elements)],
-                [np.random.randint(0, num_elements, size=5) for _ in range(num_elements)],
-                MagicMock(vcount=lambda: num_elements)
-            )
+    def setUp(self):
+        self.latent = np.random.rand(100, 32)
+        self.contig_length = np.random.randint(1000, 10000, 100)
+        self.contig_names = np.array([f"S1Ck141_{i}" for i in range(50)] + [f"S2C{i}" for i in range(50, 100)])
+        self.fasta_file = "dummy.fasta"
+        self.outdir = tempfile.mkdtemp()
+        self.ncpus = 2
+        self.logger = logging.getLogger("test_logger")

+    @patch('clustering.run_leiden')
+    @patch('subprocess.run')
+    def test_cluster(self, mock_subprocess_run, mock_run_leiden):
         mock_run_leiden.side_effect = dynamic_run_leiden

-        # Call the function
-        from clustering import cluster # Replace 'clustering' with your actual module name
-        cluster(latent, contig_length, contig_names, fasta_file, outdir, ncpus, logger, multi_split, separator=separator)

-        # Verify that run_leiden was called with appropriate subsets
-        calls = mock_run_leiden.call_args_list
-        self.assertGreater(len(calls), 0, "Expected multiple calls to run_leiden for sample-wise clustering.")

+        cluster(self.latent, self.contig_length, self.contig_names,
+                self.fasta_file, self.outdir, self.ncpus, self.logger, True)

-        # Check calls for critical operations
-        self.assertTrue(mock_makedirs.called)
-        self.assertTrue(mock_subprocess_run.called)
-        self.assertTrue(mock_to_csv.called)

-        # Verify that cluster splitting was performed
-        split_calls = [call for call in mock_to_csv.call_args_list if 'cluster_split_allsamplewisebins' in str(call)]
-        self.assertGreater(len(split_calls), 0, "Expected 'cluster_split_allsamplewisebins' to be saved.")

-        # Verify sub-clustering logic
-        bin_calls = [call for call in mock_subprocess_run.call_args_list if "get_sequence_bybin" in str(call)]
-        self.assertGreater(len(bin_calls), 0, "Expected 'get_sequence_bybin' to be called for sample bins.")
+        self.assertEqual(mock_run_leiden.call_count, 3, "Expected run_leiden to be called once for the entire dataset and twice for two samples.")

+        # Verify run_leiden was called for each subset
+        calls = mock_run_leiden.call_args_list
+        for i, call in enumerate(mock_run_leiden.call_args_list):
+            latent_subset = call[0][0]  # the `latent_norm` argument of the call
+            sample_size = latent_subset.shape[0]
+            print(f"Call {i}: sample size = {sample_size}")
+            self.assertTrue(sample_size in [100, 50], "Each latent_subset should have size 100 (full dataset) or 50 (one per sample).")

+        self.assertTrue(os.path.exists(os.path.join(self.outdir, 'cluster_split_allsamplewisebins')))

+    def tearDown(self):
+        for file in os.listdir(self.outdir):
+            file_path = os.path.join(self.outdir, file)
+            if os.path.isfile(file_path):  # Delete files
+                os.remove(file_path)
+            elif os.path.isdir(file_path):  # Delete directories
+                shutil.rmtree(file_path)
+        os.rmdir(self.outdir)

 if __name__ == '__main__':
     unittest.main()
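The multi-split test relies on side_effect rather than return_value: side_effect is invoked with the same arguments the production code passes, so the stubbed run_leiden can size its community assignments to each latent subset, and it takes precedence when both are set. A minimal illustration:

import numpy as np
from unittest.mock import MagicMock

mock = MagicMock()
mock.return_value = np.zeros(100, dtype=int)  # same object for every call
mock.side_effect = lambda latent, *a, **kw: np.zeros(len(latent), dtype=int)  # computed per call
print(len(mock(np.empty((50, 32)))))  # 50 -- side_effect wins over return_value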
20 changes: 5 additions & 15 deletions tests/test_dataparsing.py
@@ -16,36 +16,27 @@
 from Bio.Seq import Seq
 from Bio.SeqRecord import SeqRecord

-# Assuming the function is in a module named 'abundance_loader'
 from dataparsing import load_abundance, compute_kmerembeddings



 class TestComputeKmerEmbeddings(unittest.TestCase):
     def setUp(self):
-        # Create a temporary directory
         self.test_dir = tempfile.mkdtemp()
-
-        # Create a mock FASTA file
         self.fasta_file = os.path.join(self.test_dir, "test.fasta")
         self.create_mock_fasta()
         parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         weight_path = os.path.join(parent_path, "mcdevol", "genomeface_weights", "general_t2eval.m.index")
         self.assertTrue(os.path.exists(weight_path), f"Weight file not found at {weight_path}")
-
-        # Set up logger
         self.logger = logging.getLogger("test_logger")
         self.logger.setLevel(logging.INFO)

     def create_mock_fasta(self):
-        # Create mock DNA sequences
         sequences = [
             ("contig1", "ATCGATCGATCGATCGATCG"), # 20 bp
             ("contig2", "GCTAGCTAGCTAGCTAGCTAGCTAGA"), # 26 bp
             ("contig3", "TATATATATATATATA") # 16 bp
         ]
-
-        # Write sequences to FASTA file
         with open(self.fasta_file, "w") as handle:
             for seq_id, seq in sequences:
                 record = SeqRecord(Seq(seq), id=seq_id, description="")
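The helper builds Bio.SeqRecord objects; the write step (collapsed in this hunk) is presumably done with Bio.SeqIO. A self-contained sketch of that pattern:

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

records = [
    SeqRecord(Seq("ATCGATCGATCGATCGATCG"), id="contig1", description=""),  # 20 bp
    SeqRecord(Seq("TATATATATATATATA"), id="contig3", description=""),      # 16 bp
]
with open("test.fasta", "w") as handle:
    SeqIO.write(records, handle, "fasta")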
@@ -64,7 +55,7 @@ def test_compute_kmerembeddings(self):
         )

         # Check the number of contigs
-        self.assertEqual(numcontigs, 2) # Only 2 contigs should meet the min_length requirement
+        self.assertEqual(numcontigs, 2)

         # Check contig lengths
         np.testing.assert_array_equal(contig_length, np.array([20, 26]))
@@ -213,14 +204,14 @@ def test_load_abundance_metabat_format(self):
         ])

         np.testing.assert_array_almost_equal(result, expected)
-        self.assertEqual(result.shape, (3, 3)) # 3 contigs (excluding contig3), 3 samples
+        self.assertEqual(result.shape, (3, 3))

     def test_load_abundance_metabat_all_contigs(self):
         result = load_abundance(
             self.temp_file.name,
             numcontigs=4,
             contig_names=self.contig_names,
-            min_length=0, # This should include all contigs
+            min_length=0,
             logger=self.logger,
             abundformat='metabat'
         )
@@ -234,10 +225,9 @@ def test_load_abundance_metabat_all_contigs(self):
         ])

         np.testing.assert_array_almost_equal(result, expected)
-        self.assertEqual(result.shape, (4, 3)) # 4 contigs, 3 samples
+        self.assertEqual(result.shape, (4, 3))

     def test_load_abundance_metabat_reordering(self):
-        # Test with a different order of contig_names
         reordered_contig_names = np.array(['contig2', 'contig4', 'contig1', 'contig3'])

         result = load_abundance(
@@ -258,7 +248,7 @@ def test_load_abundance_metabat_reordering(self):
         ])

         np.testing.assert_array_almost_equal(result, expected)
-        self.assertEqual(result.shape, (4, 3)) # 3 contigs, 3 samples
+        self.assertEqual(result.shape, (4, 3))
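These tests feed load_abundance a table with abundformat='metabat'. Assuming that means the standard jgi_summarize_bam_contig_depths layout (contigName, contigLen, totalAvgDepth, then one depth and one variance column per sample; the column names below are illustrative), a mock input for 4 contigs and 3 samples might be built like this:

import pandas as pd

depths = pd.DataFrame({
    "contigName": ["contig1", "contig2", "contig3", "contig4"],
    "contigLen": [2500, 3000, 1500, 4000],
    "totalAvgDepth": [30.0, 24.0, 18.0, 36.0],
    "s1.bam": [10.0, 8.0, 6.0, 12.0], "s1.bam-var": [1.0, 0.9, 0.7, 1.2],
    "s2.bam": [10.0, 8.0, 6.0, 12.0], "s2.bam-var": [1.1, 0.8, 0.6, 1.3],
    "s3.bam": [10.0, 8.0, 6.0, 12.0], "s3.bam-var": [0.9, 1.0, 0.8, 1.1],
})
depths.to_csv("depths.tsv", sep="\t", index=False)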
