Commit a727b6b

unittest dataparsing

yazhinia committed Dec 17, 2024
1 parent ec76763 commit a727b6b
Showing 3 changed files with 322 additions and 26 deletions.
3 changes: 3 additions & 0 deletions mcdevol/byol_model.py

@@ -417,6 +417,8 @@ def process_batches(self,
                 self.update_moving_average()

             epoch_loss += loss.detach().data.item()
+        if epoch == 100 or epoch == 200 or epoch == 300:
+            np.save(self.outdir+f'latent_epoch{epoch}',np.vstack(latent_space))
         if training:
             self.scheduler.step() # type: ignore
         epoch_losses.extend([epoch_loss])
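Note on the new checkpointing: latent_space accumulates one array per batch, and np.vstack merges them into a single (n_contigs, latent_dim) matrix before saving. A minimal standalone sketch of the same pattern, with invented shapes and an explicit os.path.join (the committed line concatenates self.outdir directly, so outdir presumably ends with a path separator):

    import os
    import numpy as np

    outdir = "/tmp/mcdevol_demo"  # hypothetical output directory
    os.makedirs(outdir, exist_ok=True)

    # One (batch_size, latent_dim) array per processed batch.
    latent_space = [np.random.rand(32, 64).astype(np.float32) for _ in range(5)]

    epoch = 100
    if epoch in (100, 200, 300):
        # Stack per-batch arrays row-wise and save one matrix for this epoch.
        np.save(os.path.join(outdir, f"latent_epoch{epoch}.npy"), np.vstack(latent_space))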
@@ -740,6 +742,7 @@ def run(abundance_matrix, outdir, contig_length, contig_names, multi_split, ncpu
         kmerdata_tmp = np.load(arg_name, allow_pickle=True).astype(np.float32)
         kmer_data[keyname+name] = kmerdata_tmp[nonzeroindices]

+    # np.save(os.path.join(outdir, 'abundance_matrix.npy'), abundance_matrix)
     byol = BYOLmodel(abundance_matrix, kmer_data, contig_length, outdir, logger, multi_split, ncpus)
     byol.trainmodel()
     latent = byol.getlatent()
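For context, this part of run() loads each precomputed k-mer embedding from disk and keeps only the rows selected by nonzeroindices before handing everything to BYOLmodel. A hedged sketch of that row selection (nonzeroindices is assumed here to mark contigs with nonzero total abundance; the real criterion may differ):

    import numpy as np

    # Hypothetical stand-ins for values computed earlier in run().
    kmerdata_tmp = np.random.rand(6, 136).astype(np.float32)  # one embedding row per contig
    abundance_matrix = np.array([[0.0], [1.2], [0.0], [3.4], [0.7], [0.0]])

    # Keep only rows for contigs that pass the filter.
    nonzeroindices = np.where(abundance_matrix.sum(axis=1) > 0)[0]
    print(kmerdata_tmp[nonzeroindices].shape)  # (3, 136)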
82 changes: 56 additions & 26 deletions mcdevol/dataparsing.py
@@ -126,17 +126,19 @@ def compute_kmerembeddings(
     kmermodel = Model([Input(shape=(136,)),*kmer_inputs], x)
     kmermodel.compile()
     # Load genomeface model weights
-    path_weight = os.path.join(current_path, "genomeface_weights","general_t2eval.m")
+    path_weight = os.path.join(current_path, "genomeface_weights", "general_t2eval.m")
     kmermodel.load_weights(path_weight)

     for counter, infile in enumerate(file_ids):
         aaq = kmer_counter.find_nMer_distributions(infile, min_length)
         contig_names = np.asarray(aaq[-1])
-        contig_length = np.asarray(aaq[0])
-        contig_length = contig_length[contig_length >= min_length]
+        if counter == 0:
+            contig_length = np.asarray(aaq[0])
+            contig_length = contig_length[contig_length >= min_length]
         assert len(contig_length) == len(contig_names)
         numcontigs = len(contig_names)
-        inpts = [np.reshape(aaq[i], (-1, size)).astype(np.float32) for i, size in enumerate([512, 136, 32, 10, 2, 528, 256, 136, 64, 36], start=1)]
+        inpts = [np.reshape(aaq[i], (-1, size)).astype(np.float32) \
+                 for i, size in enumerate([512, 136, 32, 10, 2, 528, 256, 136, 64, 36], start=1)]

         # generate numcontigs x 136 array filled with zeros
         model_data_in = [np.zeros((inpts[0].shape[0], 136), dtype=np.float32)]
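The inpts comprehension above splits each flat count vector returned by kmer_counter.find_nMer_distributions into a (numcontigs, size) matrix, one per k-mer feature group, using the fixed group widths 512, 136, 32, 10, 2, 528, 256, 136, 64, 36. A toy illustration of the reshape with invented data:

    import numpy as np

    numcontigs = 4
    sizes = [512, 136, 32, 10, 2, 528, 256, 136, 64, 36]

    # Each feature group arrives as one flat array of numcontigs * size values.
    flat_groups = [np.arange(numcontigs * size, dtype=np.float32) for size in sizes]

    inpts = [np.reshape(flat, (-1, size)) for flat, size in zip(flat_groups, sizes)]
    print([a.shape for a in inpts[:3]])  # [(4, 512), (4, 136), (4, 32)]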
@@ -156,9 +158,9 @@

     y20_cat /= np.linalg.norm(y20_cat, axis=1, keepdims=True)
     if counter == 0:
-        np.save(os.path.join(outdir, f'kmer_embedding{file_counter[counter]}.npy'),y20_cat)
+        np.save(os.path.join(outdir, f'kmer_embedding{file_counter[counter]}.npy'), y20_cat)
     else:
-        np.save(os.path.join(outdir, f'kmer_embedding_augment{file_counter[counter]}.npy'),y20_cat)
+        np.save(os.path.join(outdir, f'kmer_embedding_augment{file_counter[counter]}.npy'), y20_cat)
     tf.keras.backend.clear_session()
     gc.collect()
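The division by np.linalg.norm(..., axis=1, keepdims=True) L2-normalizes each embedding row before it is written out, which makes later cosine comparisons equivalent to plain dot products. A quick self-contained check:

    import numpy as np

    y20_cat = np.random.rand(5, 136).astype(np.float32)
    y20_cat /= np.linalg.norm(y20_cat, axis=1, keepdims=True)

    # Every row now has unit Euclidean length.
    print(np.allclose(np.linalg.norm(y20_cat, axis=1), 1.0))  # True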

@@ -188,56 +190,84 @@ def load_abundance(
     Returns:
         np.ndarray: A numpy array containing the abundance data ordered by `contig_names`.
     """
+    try:
+        # Attempt to read the file with tab separator
+        pd.read_csv(abundance_file, sep='\t')
+    except Exception as e:
+        raise ValueError(f"Failed to parse the file as tab-separated: {e}")
+
     abundance_header = pd.read_csv(abundance_file, sep='\t', nrows=0)

+    if len(abundance_header.columns) == 0:
+        raise ValueError(f"abundance header is empty. Check your input abundance file!")
+
     names_dict: Dict[str, int] = {name: index for index, name in enumerate(contig_names)}
     s = time.time()
     if abundformat == 'std':
         num_samples = len(abundance_header.columns) - 1
         arr = np.zeros((numcontigs, num_samples),dtype='float')

         logger.info(f'Loading abundance file with {numcontigs} contigs and {num_samples} samples')
         used = 0
         abundance_names = []
-        with pd.read_csv(abundance_file,sep='\t', lineterminator='\n', engine='c', chunksize = 10 ** 6) as reader:
-            for chunk in tqdm.tqdm(reader):
-                # TODO: this condition may not be needed as input need not have length column.
-                # This should be handled well by ordered_indices selection
-                # chunk_part = chunk[chunk['contigLen'] >= min_length]
-                abundance_names.extend(chunk['contigName'])
-                arr[used:used+len(chunk)] = chunk.iloc[:,1:len(chunk.columns)].to_numpy()
-                used += len(chunk)
+        reader = pd.read_csv(abundance_file, sep='\t',\
+            lineterminator='\n', engine='c', chunksize=10**6)
+        for chunk in tqdm.tqdm(reader):
+            # TODO: this condition may not be needed as input need not have length column.
+            # This should be handled well by ordered_indices selection
+            # chunk_part = chunk[chunk['contigLen'] >= min_length]
+            abundance_names.extend(chunk['contigName'].str.strip())
+            arr[used:used+len(chunk)] = chunk.iloc[:,1:len(chunk.columns)].to_numpy()
+            used += len(chunk)

         # Remove data for contigs shorter than min_length.
         # It would be present if abundance file is created from aligner2counts as it uses contigs length to save.
         # If not it should be processed here
         abundance_names = np.array(abundance_names)

         if len(abundance_names) > len(contig_names):
             indices = np.where(np.isin(abundance_names, contig_names))[0]
             arr = arr[indices]
             abundance_names = abundance_names[indices]

         # reorder abundance as per contigs order in sequence
-        ordered_indices = [names_dict[name] for name in abundance_names if name in names_dict]
+        abundance_names_dict = {name: index for index, name in enumerate(abundance_names)}
+        # ordered_indices = [names_dict[name] for name in abundance_names if name in names_dict]
+        ordered_indices = [abundance_names_dict[name] for name in contig_names if name in abundance_names]

         arr = arr[ordered_indices]
         gc.collect()
         logger.info(f'Loaded abundance file in {time.time()-s:.2f} seconds')

         return arr

     if abundformat == 'metabat':
         num_columns = len(abundance_header.columns)
-        num_samples = (num_columns - 3) // 2
-        arr = np.zeros((numcontigs,num_samples),dtype='float')
+        num_samples = int((num_columns - 3) // 2)
         logger.info(f'Loading abundance file with {numcontigs} contigs and {num_samples} samples')
         used = 0
+        arr = [] # np.zeros((numcontigs,num_samples),dtype='float')
         abundance_names = []
-        with pd.read_csv(abundance_file,sep='\t',lineterminator='\n',engine='c',chunksize = 10 ** 6) as reader:
-            for chunk in tqdm.tqdm(reader):
-                chunk_part = chunk[chunk['contigLen'] >= min_length]
-                abundance_names.extend(chunk_part['contigName'])
-                arr[used:used+len(chunk_part)] = chunk_part.iloc[:,range(3,num_columns,2)].to_numpy()
-                used += len(chunk_part)
+        reader = pd.read_csv(abundance_file, sep='\t', \
+            lineterminator='\n', engine='c', chunksize=10**6)
+        for chunk in tqdm.tqdm(reader):
+            chunk_part = chunk[chunk['contigLen'] >= min_length]
+            abundance_names.extend(chunk_part['contigName'].str.strip())
+            arr.append(chunk_part.iloc[:,range(3,num_columns,2)].to_numpy())
+            used += len(chunk_part)
+
+        if arr:
+            arr = np.vstack(arr)
+        else:
+            arr = np.array([])

         # reorder abundance as per contigs order in sequence
-        ordered_indices = [names_dict[name] for name in abundance_names if name in names_dict]
+        contigs_names_filtered = [name for name in contig_names if name in abundance_names]
+
+        abundance_names_dict = {name: index for index, name in enumerate(abundance_names)}
+        ordered_indices = [abundance_names_dict[name] for name in contigs_names_filtered if name in abundance_names]

         arr = arr[ordered_indices]
         gc.collect()
         logger.info(f'Loaded abundance file in {time.time()-s:.2f} seconds')
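The substantive fix in both branches is the reordering step: the old ordered_indices mapped abundance-file names through names_dict, yielding rows in the abundance file's order, while the new abundance_names_dict is indexed while iterating contig_names, so the returned rows follow the contig order of the assembly. A toy demonstration with invented names and values:

    import numpy as np

    contig_names = ["c1", "c2", "c3", "c4"]     # order in the assembly
    abundance_names = ["c3", "c1", "c4"]        # order in the abundance file (c2 absent)
    arr = np.array([[3.0], [1.0], [4.0]])       # rows follow abundance_names

    abundance_names_dict = {name: i for i, name in enumerate(abundance_names)}
    ordered_indices = [abundance_names_dict[n] for n in contig_names if n in abundance_names_dict]

    # Rows now follow contig_names order: c1, c3, c4.
    print(arr[ordered_indices].ravel())  # [1. 3. 4.]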
(Third changed file did not load in this view; per the totals above it carries the remaining 263 additions — presumably the new dataparsing unit tests.)
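Purely as a sketch of what one such test might look like — the real test file did not render, and the load_abundance call below is inferred from the diff (argument names are visible in the shown code, but their order and any extras, such as a required logger, may differ):

    import numpy as np
    import pandas as pd

    from mcdevol.dataparsing import load_abundance  # module path assumed from the diff

    def test_load_abundance_std_orders_rows_by_contig_names(tmp_path):
        # Hypothetical std-format table: contigName plus one coverage column per sample.
        df = pd.DataFrame({"contigName": ["c2", "c1"], "s1": [2.0, 1.0]})
        abundance_file = tmp_path / "abundance.tsv"
        df.to_csv(abundance_file, sep="\t", index=False)

        # Argument names taken from the diff; the real signature may differ.
        arr = load_abundance(abundance_file, numcontigs=2,
                             contig_names=np.array(["c1", "c2"]),
                             abundformat="std", min_length=0)

        # Rows should come back in contig_names order, not file order.
        assert np.allclose(arr.ravel(), [1.0, 2.0])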
