diff --git a/mcdevol/byol_model.py b/mcdevol/byol_model.py
index 689edcd..59f297d 100644
--- a/mcdevol/byol_model.py
+++ b/mcdevol/byol_model.py
@@ -417,6 +417,8 @@ def process_batches(self,
                 self.update_moving_average()
             epoch_loss += loss.detach().data.item()

+        if epoch == 100 or epoch == 200 or epoch == 300:
+            np.save(self.outdir+f'latent_epoch{epoch}',np.vstack(latent_space))
         if training:
             self.scheduler.step() # type: ignore
         epoch_losses.extend([epoch_loss])
@@ -740,6 +742,7 @@ def run(abundance_matrix, outdir, contig_length, contig_names, multi_split, ncpu
         kmerdata_tmp = np.load(arg_name, allow_pickle=True).astype(np.float32)
         kmer_data[keyname+name] = kmerdata_tmp[nonzeroindices]

+    # np.save(os.path.join(outdir, 'abundance_matrix.npy'), abundance_matrix)
     byol = BYOLmodel(abundance_matrix, kmer_data, contig_length, outdir, logger, multi_split, ncpus)
     byol.trainmodel()
     latent = byol.getlatent()
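A minimal sketch of the periodic-save idea introduced in the first hunk, for reference. Only the epoch checkpoints (100, 200, 300) and the `np.vstack(latent_space)` call come from the hunk above; the `save_latent_checkpoint` helper, its `checkpoints` parameter, and the use of `os.path.join` are illustrative (note that `np.save` appends a `.npy` suffix when the target path lacks one):

```python
import os
import numpy as np

def save_latent_checkpoint(latent_space, outdir, epoch, checkpoints=(100, 200, 300)):
    """Stack per-batch latent vectors and save them at selected epochs.

    Illustrative helper only; np.save adds the '.npy' extension itself
    when the given path does not already end with it.
    """
    if epoch in checkpoints:
        np.save(os.path.join(outdir, f'latent_epoch{epoch}'), np.vstack(latent_space))
```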
diff --git a/mcdevol/dataparsing.py b/mcdevol/dataparsing.py
index 9aa9b9f..cff2363 100644
--- a/mcdevol/dataparsing.py
+++ b/mcdevol/dataparsing.py
@@ -126,17 +126,19 @@ def compute_kmerembeddings(
     kmermodel = Model([Input(shape=(136,)),*kmer_inputs], x)
     kmermodel.compile()
     # Load genomeface model weights
-    path_weight = os.path.join(current_path, "genomeface_weights","general_t2eval.m")
+    path_weight = os.path.join(current_path, "genomeface_weights", "general_t2eval.m")
     kmermodel.load_weights(path_weight)

     for counter, infile in enumerate(file_ids):
         aaq = kmer_counter.find_nMer_distributions(infile, min_length)
         contig_names = np.asarray(aaq[-1])
-        contig_length = np.asarray(aaq[0])
-        contig_length = contig_length[contig_length >= min_length]
+        if counter == 0:
+            contig_length = np.asarray(aaq[0])
+            contig_length = contig_length[contig_length >= min_length]
         assert len(contig_length) == len(contig_names)
         numcontigs = len(contig_names)
-        inpts = [np.reshape(aaq[i], (-1, size)).astype(np.float32) for i, size in enumerate([512, 136, 32, 10, 2, 528, 256, 136, 64, 36], start=1)]
+        inpts = [np.reshape(aaq[i], (-1, size)).astype(np.float32) \
+            for i, size in enumerate([512, 136, 32, 10, 2, 528, 256, 136, 64, 36], start=1)]

         # generate numcontigs x 136 array filled with zeros
         model_data_in = [np.zeros((inpts[0].shape[0], 136), dtype=np.float32)]
@@ -156,9 +158,9 @@ def compute_kmerembeddings(
             y20_cat /= np.linalg.norm(y20_cat, axis=1, keepdims=True)

         if counter == 0:
-            np.save(os.path.join(outdir, f'kmer_embedding{file_counter[counter]}.npy'),y20_cat)
+            np.save(os.path.join(outdir, f'kmer_embedding{file_counter[counter]}.npy'), y20_cat)
         else:
-            np.save(os.path.join(outdir, f'kmer_embedding_augment{file_counter[counter]}.npy'),y20_cat)
+            np.save(os.path.join(outdir, f'kmer_embedding_augment{file_counter[counter]}.npy'), y20_cat)

         tf.keras.backend.clear_session()
         gc.collect()
@@ -188,56 +190,84 @@ def load_abundance(
     Returns:
         np.ndarray: A numpy array containing the abundance data ordered by `contig_names`.
     """
+    try:
+        # Attempt to read the file with tab separator
+        pd.read_csv(abundance_file, sep='\t')
+    except Exception as e:
+        raise ValueError(f"Failed to parse the file as tab-separated: {e}")
+
     abundance_header = pd.read_csv(abundance_file, sep='\t', nrows=0)
+
+    if len(abundance_header.columns) == 0:
+        raise ValueError("abundance header is empty. Check your input abundance file!")
+
     names_dict: Dict[str, int] = {name: index for index, name in enumerate(contig_names)}
     s = time.time()

     if abundformat == 'std':
         num_samples = len(abundance_header.columns) - 1
-        arr = np.zeros((numcontigs, num_samples),dtype='float')
+
+        arr = np.zeros((numcontigs, num_samples),dtype='float')
         logger.info(f'Loading abundance file with {numcontigs} contigs and {num_samples} samples')
         used = 0
         abundance_names = []
-        with pd.read_csv(abundance_file,sep='\t', lineterminator='\n', engine='c', chunksize = 10 ** 6) as reader:
-            for chunk in tqdm.tqdm(reader):
-                # TODO: this condition may not be needed as input need not have length column.
-                # This should be handled well by ordered_indices selection
-                # chunk_part = chunk[chunk['contigLen'] >= min_length]
-                abundance_names.extend(chunk['contigName'])
-                arr[used:used+len(chunk)] = chunk.iloc[:,1:len(chunk.columns)].to_numpy()
-                used += len(chunk)
-
+        reader = pd.read_csv(abundance_file, sep='\t',\
+            lineterminator='\n', engine='c', chunksize=10**6)
+        for chunk in tqdm.tqdm(reader):
+            # TODO: this condition may not be needed as input need not have length column.
+            # This should be handled well by ordered_indices selection
+            # chunk_part = chunk[chunk['contigLen'] >= min_length]
+            abundance_names.extend(chunk['contigName'].str.strip())
+            arr[used:used+len(chunk)] = chunk.iloc[:,1:len(chunk.columns)].to_numpy()
+            used += len(chunk)
+
         # Remove data for contigs shorter than min_length.
         # It would be present if abundance file is created from aligner2counts as it uses contigs length to save.
         # If not it should be processed here
         abundance_names = np.array(abundance_names)
+
         if len(abundance_names) > len(contig_names):
             indices = np.where(np.isin(abundance_names, contig_names))[0]
             arr = arr[indices]
             abundance_names = abundance_names[indices]

         # reorder abundance as per contigs order in sequence
-        ordered_indices = [names_dict[name] for name in abundance_names if name in names_dict]
+        abundance_names_dict = {name: index for index, name in enumerate(abundance_names)}
+        # ordered_indices = [names_dict[name] for name in abundance_names if name in names_dict]
+        ordered_indices = [abundance_names_dict[name] for name in contig_names if name in abundance_names]
         arr = arr[ordered_indices]
         gc.collect()
         logger.info(f'Loaded abundance file in {time.time()-s:.2f} seconds')
+
         return arr

     if abundformat == 'metabat':
         num_columns = len(abundance_header.columns)
-        num_samples = (num_columns - 3) // 2
-        arr = np.zeros((numcontigs,num_samples),dtype='float')
+        num_samples = int((num_columns - 3) // 2)
         logger.info(f'Loading abundance file with {numcontigs} contigs and {num_samples} samples')
         used = 0
+        arr = [] # np.zeros((numcontigs,num_samples),dtype='float')
         abundance_names = []
-        with pd.read_csv(abundance_file,sep='\t',lineterminator='\n',engine='c',chunksize = 10 ** 6) as reader:
-            for chunk in tqdm.tqdm(reader):
-                chunk_part = chunk[chunk['contigLen'] >= min_length]
-                abundance_names.extend(chunk_part['contigName'])
-                arr[used:used+len(chunk_part)] = chunk_part.iloc[:,range(3,num_columns,2)].to_numpy()
-                used += len(chunk_part)
+
+        reader = pd.read_csv(abundance_file, sep='\t', \
+            lineterminator='\n', engine='c', chunksize=10**6)
+        for chunk in tqdm.tqdm(reader):
+            chunk_part = chunk[chunk['contigLen'] >= min_length]
+            abundance_names.extend(chunk_part['contigName'].str.strip())
+            arr.append(chunk_part.iloc[:,range(3,num_columns,2)].to_numpy())
+            used += len(chunk_part)
+
+        if arr:
+            arr = np.vstack(arr)
+        else:
+            arr = np.array([])
+
         # reorder abundance as per contigs order in sequence
-        ordered_indices = [names_dict[name] for name in abundance_names if name in names_dict]
+        contigs_names_filtered = [name for name in contig_names if name in abundance_names]
+
+        abundance_names_dict = {name: index for index, name in enumerate(abundance_names)}
+        ordered_indices = [abundance_names_dict[name] for name in contigs_names_filtered if name in abundance_names]
+
         arr = arr[ordered_indices]
         gc.collect()
         logger.info(f'Loaded abundance file in {time.time()-s:.2f} seconds')
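To make the new reordering logic in `load_abundance` concrete, here is a small self-contained sketch of the same idea: rows read from the abundance file are re-indexed so that they follow the contig order of the assembly. The toy names and values are illustrative; only the dictionary-plus-list-comprehension pattern mirrors the hunks above.

```python
import numpy as np

# order in which contigs appear in the abundance file, one row of values each
abundance_names = np.array(['contig2', 'contig3', 'contig1'])
arr = np.array([[4.0, 5.0], [7.0, 8.0], [1.0, 2.0]])

# order in which contigs appear in the assembly
contig_names = ['contig1', 'contig2', 'contig3']

# map every abundance-file contig name to its row index
abundance_names_dict = {name: index for index, name in enumerate(abundance_names)}

# walk the assembly order and pick the matching abundance rows
ordered_indices = [abundance_names_dict[name] for name in contig_names if name in abundance_names_dict]
print(arr[ordered_indices])  # rows now follow contig1, contig2, contig3
```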
diff --git a/tests/test_dataparsing.py b/tests/test_dataparsing.py
new file mode 100644
index 0000000..83e05e1
--- /dev/null
+++ b/tests/test_dataparsing.py
@@ -0,0 +1,263 @@
+import os
+import sys
+
+parent_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../'))
+sys.path.insert(0, parent_path)
+sys.path.insert(0, os.path.join(parent_path, 'mcdevol'))
+
+import unittest
+import numpy as np
+import pandas as pd # type: ignore
+import logging
+import tempfile
+import os
+from io import StringIO
+from Bio import SeqIO
+from Bio.Seq import Seq
+from Bio.SeqRecord import SeqRecord
+
+# Functions under test are implemented in mcdevol/dataparsing.py
+from dataparsing import load_abundance, compute_kmerembeddings
+
+
+class TestComputeKmerEmbeddings(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory
+        self.test_dir = tempfile.mkdtemp()
+
+        # Create a mock FASTA file
+        self.fasta_file = os.path.join(self.test_dir, "test.fasta")
+        self.create_mock_fasta()
+
+        # Set up logger
+        self.logger = logging.getLogger("test_logger")
+        self.logger.setLevel(logging.INFO)
+
+    def create_mock_fasta(self):
+        # Create mock DNA sequences
+        sequences = [
+            ("contig1", "ATCGATCGATCGATCGATCG"), # 20 bp
+            ("contig2", "GCTAGCTAGCTAGCTAGCTAGCTAGA"), # 26 bp
+            ("contig3", "TATATATATATATATA") # 16 bp
+        ]
+
+        # Write sequences to FASTA file
+        with open(self.fasta_file, "w") as handle:
+            for seq_id, seq in sequences:
+                record = SeqRecord(Seq(seq), id=seq_id, description="")
+                SeqIO.write(record, handle, "fasta")
+
+    def test_compute_kmerembeddings(self):
+        min_length = 20
+        n_fragments = 2
+
+        numcontigs, contig_length, contig_names = compute_kmerembeddings(
+            self.test_dir,
+            self.fasta_file,
+            min_length,
+            self.logger,
+            n_fragments
+        )
+
+        # Check the number of contigs
+        self.assertEqual(numcontigs, 2) # Only 2 contigs should meet the min_length requirement
+
+        # Check contig lengths
+        np.testing.assert_array_equal(contig_length, np.array([20, 26]))
+
+        # Check contig names
+        np.testing.assert_array_equal(contig_names, np.array(["contig1", "contig2"]))
+
+        # Check if embedding files are created
+        self.assertTrue(os.path.exists(os.path.join(self.test_dir, "kmer_embedding.npy")))
+        self.assertTrue(os.path.exists(os.path.join(self.test_dir, "kmer_embedding_augment1.npy")))
+        self.assertTrue(os.path.exists(os.path.join(self.test_dir, "kmer_embedding_augment2.npy")))
+
+    def tearDown(self):
+        # Remove the temporary directory and its contents
+        for file in os.listdir(self.test_dir):
+            os.remove(os.path.join(self.test_dir, file))
+        os.rmdir(self.test_dir)
+
+
+class TestLoadAbundance(unittest.TestCase):
+    def setUp(self):
+        self.logger = logging.getLogger('test_logger')
+        self.logger.setLevel(logging.INFO)
+
+        # Sample data for testing
+        self.sample_data = """contigName\tsample1\tsample2\tsample3
+contig1\t1.0\t2.0\t3.0
+contig2\t4.0\t5.0\t6.0
+contig3\t7.0\t8.0\t9.0
+"""
+
+        # Create a temporary file
+        self.temp_file = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+        self.temp_file.write(self.sample_data)
+        self.temp_file.flush()
+        self.temp_file.close()
+
+        self.contig_names = np.array(['contig1', 'contig2', 'contig3'])
+
+    def tearDown(self):
+        os.unlink(self.temp_file.name)
+
+    def test_load_abundance_standard_format(self):
+
+        # with open(self.temp_file.name, 'r') as f:
+        #     print(f"File content:\n{f.read()}")
+
+        result = load_abundance(
+            self.temp_file.name,
+            numcontigs=3,
+            contig_names=self.contig_names,
+            min_length=0,
+            logger=self.logger
+        )
+
+        expected = np.array([
+            [1.0, 2.0, 3.0],
+            [4.0, 5.0, 6.0],
+            [7.0, 8.0, 9.0]])
+
+        np.testing.assert_array_equal(result, expected)
+
+    def test_load_abundance_file_not_found(self):
+        with self.assertRaises(ValueError):
+            load_abundance(
+                'non_existent_file.tsv',
+                numcontigs=3,
+                contig_names=self.contig_names,
+                min_length=0,
+                logger=self.logger
+            )
+
+    def test_load_abundance_empty_file(self):
+        empty_file = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+        empty_file.close()
+
+        with self.assertRaises(ValueError):
+            load_abundance(
+                empty_file.name,
+                numcontigs=3,
+                contig_names=self.contig_names,
+                min_length=0,
+                logger=self.logger
+            )
+
+        os.unlink(empty_file.name)
+
+    def test_load_abundance_mismatched_contigs(self):
+        mismatched_contig_names = np.array(['contig2', 'contig1', 'contig4'])
+
+        result = load_abundance(
+            self.temp_file.name,
+            numcontigs=3,
+            contig_names=mismatched_contig_names,
+            min_length=0,
+            logger=self.logger
+        )
+
+        # Expect only data for contig1 and contig2 to be present
+        expected = np.array([[4.0, 5.0, 6.0],
+                             [1.0, 2.0, 3.0]])
+
+        np.testing.assert_array_equal(result, expected)
+
+
+class TestLoadAbundanceMetaBAT(unittest.TestCase):
+    def setUp(self):
+        self.logger = logging.getLogger('test_logger')
+        self.logger.setLevel(logging.INFO)
+
+        # Sample data for testing MetaBAT format
+        self.sample_data = \
+"""contigName\tcontigLen\ttotalAvgDepth\tcov1\tvar1\tcov2\tvar2\tcov3\tvar3
+contig1\t1000\t6.0\t1.0\t0.1\t2.0\t0.2\t3.0\t0.3
+contig2\t2000\t15.0\t4.0\t0.2\t5.0\t0.3\t6.0\t0.4
+contig3\t500\t24.0\t7.0\t0.3\t8.0\t0.4\t9.0\t0.5
+contig4\t1500\t33.0\t10.0\t0.4\t11.0\t0.5\t12.0\t0.6
+"""
+
+        # Create a temporary file
+        self.temp_file = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+        self.temp_file.write(self.sample_data)
+        self.temp_file.flush()
+        self.temp_file.close()
+
+        self.contig_names = np.array(['contig1', 'contig2', 'contig3', 'contig4'])
+
+    def tearDown(self):
+        os.unlink(self.temp_file.name)
+
+    def test_load_abundance_metabat_format(self):
+        result = load_abundance(
+            self.temp_file.name,
+            numcontigs=4,
+            contig_names=self.contig_names,
+            min_length=1000, # This should exclude contig3
+            logger=self.logger,
+            abundformat='metabat'
+        )
+
+        # Expected result: only contigs >= 1000 bp, and only depth columns
+        expected = np.array([
+            [1.0, 2.0, 3.0],
+            [4.0, 5.0, 6.0],
+            [10.0, 11.0, 12.0]
+        ])
+
+        np.testing.assert_array_almost_equal(result, expected)
+        self.assertEqual(result.shape, (3, 3)) # 3 contigs (excluding contig3), 3 samples
+
+    def test_load_abundance_metabat_all_contigs(self):
+        result = load_abundance(
+            self.temp_file.name,
+            numcontigs=4,
+            contig_names=self.contig_names,
+            min_length=0, # This should include all contigs
+            logger=self.logger,
+            abundformat='metabat'
+        )
+
+        # Expected result: all contigs, only depth columns
+        expected = np.array([
+            [1.0, 2.0, 3.0],
+            [4.0, 5.0, 6.0],
+            [7.0, 8.0, 9.0],
+            [10.0, 11.0, 12.0]
+        ])
+
+        np.testing.assert_array_almost_equal(result, expected)
+        self.assertEqual(result.shape, (4, 3)) # 4 contigs, 3 samples
+
+    def test_load_abundance_metabat_reordering(self):
+        # Test with a different order of contig_names
+        reordered_contig_names = np.array(['contig2', 'contig4', 'contig1', 'contig3'])
+
+        result = load_abundance(
+            self.temp_file.name,
+            numcontigs=4,
+            contig_names=reordered_contig_names,
+            min_length=0,
+            logger=self.logger,
+            abundformat='metabat'
+        )
+
+        # Expected result: reordered according to reordered_contig_names
+        expected = np.array([
+            [4.0, 5.0, 6.0],
+            [10.0, 11.0, 12.0],
+            [1.0, 2.0, 3.0],
+            [7.0, 8.0, 9.0]
+        ])
+
+        np.testing.assert_array_almost_equal(result, expected)
+        self.assertEqual(result.shape, (4, 3)) # 4 contigs, 3 samples
+
+
+if __name__ == '__main__':
+    unittest.main()
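Assuming the test dependencies (numpy, pandas, Biopython) and the genomeface weights used by `compute_kmerembeddings` are available, the new suite can be run either directly with `python tests/test_dataparsing.py`, which goes through the `unittest.main()` entry point above, or via the standard runner, e.g. `python -m unittest tests/test_dataparsing.py -v` from the repository root.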