Skip to content

Commit

Permalink
gene_nmf: added max gene sets to use for compute option for speed
Browse files Browse the repository at this point in the history
  • Loading branch information
marcdubybroad committed Aug 29, 2024
1 parent 219991f commit 16915c5
Show file tree
Hide file tree
Showing 3 changed files with 822 additions and 54 deletions.
136 changes: 86 additions & 50 deletions app/novelty/gene_nmf/dcc/compute_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from numpy.random import exponential
from sklearn.decomposition import NMF
import json
import time

import dcc.dcc_utils as dutils
import dcc.matrix_utils as mutils
Expand Down Expand Up @@ -75,7 +76,8 @@ def __init__(self, message):
super().__init__(self.message)

# methods
def calculate_factors(matrix_gene_sets_gene_original, list_gene, list_system_genes, map_gene_index, map_gene_set_index, mean_shifts, scale_factors, p_value=0.05, log=False):
def calculate_factors(matrix_gene_sets_gene_original, list_gene, list_system_genes, map_gene_index, map_gene_set_index, mean_shifts, scale_factors,
p_value=0.05, max_num_gene_sets=100, log=False):
'''
will produce the gene set factors and gene factors
'''
Expand All @@ -86,34 +88,41 @@ def calculate_factors(matrix_gene_sets_gene_original, list_gene, list_system_gen
gene_factor = None
gene_set_factor = None
map_lowest_factor_per_gene = {}
logs_process = []

# start time counter
start = time.time()

# step 1/2: get the gene vector from the gene list
if log:
logger.info("step 0: got input gene list from user of size: {}".format(len(list_gene)))
vector_gene, list_input_gene_indices = mutils.generate_gene_vector_from_list(list_gene=list_gene, map_gene_index=map_gene_index)

# log
if log:
print("step 1: got gene set matrix of shape: {}".format(matrix_gene_sets_gene_original.shape))
print("step 1: got mean_shifts of shape: {}".format(mean_shifts.shape))
print("step 1: got scale_factors of shape: {}".format(scale_factors.shape))
print("step 2: got gene vector of shape: {}".format(vector_gene.shape))
logger.info("step 1: got gene set matrix of shape: {}".format(matrix_gene_sets_gene_original.shape))
logger.info("step 1: got mean_shifts of shape: {}".format(mean_shifts.shape))
logger.info("step 1: got scale_factors of shape: {}".format(scale_factors.shape))
logger.info("step 2: got one hot gene vector of shape: {}".format(vector_gene.shape))
logger.info("step 2: got resulting found gene indices list of size: {}".format(len(list_input_gene_indices)))

# step 3: get the p_values by gene set
vector_gene_set_pvalues = compute_beta_tildes(X=matrix_gene_sets_gene_original, Y=vector_gene, scale_factors=scale_factors, mean_shifts=mean_shifts)

if log:
print("step 3: got p values vector of shape: {}".format(vector_gene_set_pvalues.shape))
print("step 3: filtering gene sets using p_value: {}".format(p_value))
logger.info("step 3: got p values vector of shape: {}".format(vector_gene_set_pvalues.shape))
logger.info("step 3: filtering gene sets using p_value: {}".format(p_value))

# step 4: filter the gene set columns based on computed pvalue for each gene set
matrix_gene_set_filtered_by_pvalues, selected_gene_set_indices = filter_matrix_columns(matrix_input=matrix_gene_sets_gene_original, vector_input=vector_gene_set_pvalues,
cutoff_input=p_value, log=log)
cutoff_input=p_value, max_num_gene_sets=max_num_gene_sets, log=log)
# matrix_gene_set_filtered_by_pvalues, selected_gene_set_indices = filter_matrix_columns(matrix_input=matrix_gene_sets_gene_original, vector_input=vector_gene_set_pvalues,
# cutoff_input=0.5, log=log)

if log:
print("step 4: got gene set filtered (col) matrix of shape: {}".format(matrix_gene_set_filtered_by_pvalues.shape))
print("step 4: got gene set filtered indices of length: {}".format(len(selected_gene_set_indices)))
print("step 4: got gene set filtered indices: {}".format(selected_gene_set_indices))
logger.info("step 4: got gene set filtered (col) matrix of shape: {}".format(matrix_gene_set_filtered_by_pvalues.shape))
logger.info("step 4: got gene set filtered indices of length: {}".format(len(selected_gene_set_indices)))
logger.info("step 4: got gene set filtered indices: {}".format(selected_gene_set_indices))

# step 5: filter gene rows by only the genes that are part of the remaining gene sets from the filtered gene set matrix
matrix_gene_filtered_by_remaining_gene_sets, selected_gene_indices = filter_matrix_rows_by_sum_cutoff(matrix_to_filter=matrix_gene_set_filtered_by_pvalues,
Expand All @@ -122,23 +131,23 @@ def calculate_factors(matrix_gene_sets_gene_original, list_gene, list_system_gen
list_input_genes_filtered_out_indices = [item for item in list_input_gene_indices if item not in selected_gene_indices.tolist()]

if log:
print("step 5: ===> got input gene filtered out of length: {}".format(len(list_input_genes_filtered_out_indices)))
print("step 5: got gene filtered indices of length: {}".format(len(selected_gene_indices)))
print("step 5: ===> got gene filtered (rows) matrix of shape: {} to start bayes NMF".format(matrix_gene_filtered_by_remaining_gene_sets.shape))
logger.info("step 5: ===> got input gene filtered out of length: {}".format(len(list_input_genes_filtered_out_indices)))
logger.info("step 5: got gene filtered indices of length: {}".format(len(selected_gene_indices)))
logger.info("step 5: ===> got gene filtered (rows) matrix of shape: {} to start bayes NMF".format(matrix_gene_filtered_by_remaining_gene_sets.shape))
# print("step 5: got gene filtered indices of length: {}".format(selected_gene_indices.shape))

if not all(dim > 0 for dim in matrix_gene_filtered_by_remaining_gene_sets.shape):
print("step 6: ===> skipping due to pre bayes NMF matrix of shape".format(matrix_gene_filtered_by_remaining_gene_sets.shape))
logger.info("step 6: ===> skipping due to pre bayes NMF matrix of shape".format(matrix_gene_filtered_by_remaining_gene_sets.shape))

else:
# step 6: from this double filtered matrix, compute the factors
gene_factor, gene_set_factor, _, _, exp_lambda, _ = _bayes_nmf_l2(V0=matrix_gene_filtered_by_remaining_gene_sets)
# gene_factor, gene_set_factor = run_nmf(matrix_input=matrix_gene_filtered_by_remaining_gene_sets, log=log)

if log:
print("step 6: got gene factor matrix of shape: {}".format(gene_factor.shape))
print("step 6: got gene set factor matrix of shape: {}".format(gene_set_factor.shape))
print("step 6: got lambda matrix of shape: {} with data: {}".format(exp_lambda.shape, exp_lambda))
logger.info("step 6: got gene factor matrix of shape: {}".format(gene_factor.shape))
logger.info("step 6: got gene set factor matrix of shape: {}".format(gene_set_factor.shape))
logger.info("step 6: got lambda matrix of shape: {} with data: {}".format(exp_lambda.shape, exp_lambda))

# step 7: find and rank the gene and gene set groups
list_factor, list_factor_genes, list_factor_gene_sets, updated_gene_factors = rank_gene_and_gene_sets(X=None, Y=None, exp_lambdak=exp_lambda, exp_gene_factors=gene_factor, exp_gene_set_factors=gene_set_factor.T,
Expand All @@ -150,17 +159,27 @@ def calculate_factors(matrix_gene_sets_gene_original, list_gene, list_system_gen
# print(json.dumps(map_lowest_factor_per_gene, indent=2))

if log:
print("step 7: got factor list: {}".format(list_factor))
print("step 7: got gene list:")
logger.info("step 7: got factor list: {}".format(list_factor))
logger.info("step 7: got gene list:")
for row in list_factor_genes:
print (row)
print("step 7: got gene set list:")
logger.info (row)
logger.info("step 7: got gene set list:")
for row in list_factor_gene_sets:
print (row)
logger.info (row)

# end time counter
end = time.time()
str_message = "compute process time is: {}s".format(end-start)
logs_process.append(str_message)
logs_process.append("used p_value: {}".format(p_value))
logs_process.append("used max number of gene sets: {}".format(max_num_gene_sets))

# log
for row in logs_process:
logger.info(row)

# only return the gene factors and gene set factors
return list_factor, list_factor_genes, list_factor_gene_sets, gene_factor, gene_set_factor, map_lowest_factor_per_gene
return list_factor, list_factor_genes, list_factor_gene_sets, gene_factor, gene_set_factor, map_lowest_factor_per_gene, logs_process


def group_factor_results(list_factor, list_factor_genes, list_factor_gene_sets, log=False):
Expand Down Expand Up @@ -372,17 +391,17 @@ def rank_gene_and_gene_sets(X, Y, exp_lambdak, exp_gene_factors, exp_gene_set_fa

# log
if log:
print("got lambda of shape: {}".format(exp_lambdak.shape))
print("got gene factor of shape: {}".format(exp_gene_factors.shape))
print("got gene set factor of shape: {}".format(exp_gene_set_factors.shape))
logger.info("got lambda of shape: {}".format(exp_lambdak.shape))
logger.info("got gene factor of shape: {}".format(exp_gene_factors.shape))
logger.info("got gene set factor of shape: {}".format(exp_gene_set_factors.shape))

# subset_down
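# GUESS: filter and keep if exp_lambdak > 0 and at least one non zero factor for a gene and gene set; then filter by cutoff
# NOTE(review): in Python `&` binds tighter than `!=`, so the expression below parses as exp_lambdak != (0 & (...) & (...)), which reduces to exp_lambdak != 0 and ignores both sum conditions; wrap the comparison as (exp_lambdak != 0) to get the mask described above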
factor_mask = exp_lambdak != 0 & (np.sum(exp_gene_factors, axis=0) > 0) & (np.sum(exp_gene_set_factors, axis=0) > 0)
factor_mask = factor_mask & (np.max(exp_gene_set_factors, axis=0) > cutoff * np.max(exp_gene_set_factors))

if log:
print("end up with factor mask of shape: {} and true count: {}".format(factor_mask.shape, np.sum(factor_mask)))
logger.info("end up with factor mask of shape: {} and true count: {}".format(factor_mask.shape, np.sum(factor_mask)))

# TODO - QUESTION
# filter by factors; why invert factor_mask?
Expand All @@ -394,9 +413,9 @@ def rank_gene_and_gene_sets(X, Y, exp_lambdak, exp_gene_factors, exp_gene_set_fa
# gene_set_values = self.betas_uncorrected

if log:
print("got NEW shrunk lambda of shape: {}".format(exp_lambdak.shape))
print("got NEW shrunk gene factor of shape: {}".format(exp_gene_factors.shape))
print("got NEW shrunk gene set factor of shape: {}".format(exp_gene_set_factors.shape))
logger.info("got NEW shrunk lambda of shape: {}".format(exp_lambdak.shape))
logger.info("got NEW shrunk gene factor of shape: {}".format(exp_gene_factors.shape))
logger.info("got NEW shrunk gene set factor of shape: {}".format(exp_gene_set_factors.shape))

# gene_values = None
# if self.combined_prior_Ys is not None:
Expand Down Expand Up @@ -448,9 +467,9 @@ def rank_gene_and_gene_sets(X, Y, exp_lambdak, exp_gene_factors, exp_gene_set_fa

# log
if log:
print("looping through factor gene set scores of size: {} and data: \n{}".format(len(factor_gene_set_scores), factor_gene_set_scores))
print("got top pathway ids type: {} and data: {}".format(type(top_gene_set_inds), top_gene_set_inds))
print("got top gene ids: {}".format(top_gene_inds))
logger.info("looping through factor gene set scores of size: {} and data: \n{}".format(len(factor_gene_set_scores), factor_gene_set_scores))
logger.info("got top pathway ids type: {} and data: {}".format(type(top_gene_set_inds), top_gene_set_inds))
logger.info("got top gene ids: {}".format(top_gene_inds))

for i in range(len(factor_gene_set_scores)):
# original for reference
Expand Down Expand Up @@ -503,16 +522,16 @@ def get_lowest_gene_factor_by_gene(exp_gene_factors, list_system_genes, list_gen
if all(dim > 0 for dim in exp_gene_factors.shape):
# log
if log:
print("lowest factor - got gene factor of shape: {}".format(exp_gene_factors.shape))
logger.info("lowest factor - got gene factor of shape: {}".format(exp_gene_factors.shape))
# print("lowest factor - got filtered gene mask of size: {} and data: \n{}".format(len(list_gene_mask), list_gene_mask))

# get the lowest value per row
min_per_row = np.min(exp_gene_factors, axis=1)

if log:
print("lowest factor - got gene factor MINIMUM of shape: {} and type: {}".format(min_per_row.shape, type(min_per_row)))
logger.info("lowest factor - got gene factor MINIMUM of shape: {} and type: {}".format(min_per_row.shape, type(min_per_row)))
for index in range(len(list_gene_mask)):
print("lowest factor - for gene: {} get factor : {}".format(list_system_genes[list_gene_mask[index]], exp_gene_factors[index]))
logger.info("lowest factor - for gene: {} get factor : {}".format(list_system_genes[list_gene_mask[index]], exp_gene_factors[index]))

# build the map
if min_per_row is not None:
Expand Down Expand Up @@ -542,8 +561,8 @@ def get_referenced_list_elements(list_referenced, list_index, log=False):

# log
if log:
print("ref list: {}".format(list_referenced))
print("index list: {}".format(list_index))
logger.info("ref list: {}".format(list_referenced))
logger.info("index list: {}".format(list_index))

# get the elements
list_result = [list_referenced[i] for i in list_index]
Expand Down Expand Up @@ -660,9 +679,10 @@ def _get_num_X_blocks(X_orig, batch_size=None):
return int(np.ceil(X_orig.shape[1] / batch_size))


def filter_matrix_columns(matrix_input, vector_input, cutoff_input=0.05, log=False):
def filter_matrix_columns(matrix_input, vector_input, cutoff_input, max_num_gene_sets, log=False):
'''
will filter the matrix based on the vector and cutoff
the columns are gene sets in this instance
'''

# REFERENCE
Expand All @@ -675,17 +695,33 @@ def filter_matrix_columns(matrix_input, vector_input, cutoff_input=0.05, log=Fal

# log
if log:
print("got matrix to filter of shape: {} and type: {}".format(matrix_input.shape, type(matrix_input)))
print("got filter vector of shape: {} and type: {}".format(vector_input.shape, type(vector_input)))
logger.info("got matrix to filter of shape: {} and type: {}".format(matrix_input.shape, type(matrix_input)))
logger.info("got filter vector of shape: {} and type: {}".format(vector_input.shape, type(vector_input)))
# logger.info("passing vector value: {}".format(vector_input[0,51864]))

# select the columns that pass the p_value cutoff
selected_column_indices = np.where(np.any(vector_input < cutoff_input, axis=0))[0]

# CHECK - if more columns pass the cutoff than max_num_gene_sets allows, keep only the max_num_gene_sets columns with the lowest values in vector_input
if len(selected_column_indices) > max_num_gene_sets:
# log
if log:
logger.info("filtered gene sets of size: {} is larger than the max: {}, so taking top {}".format(len(selected_column_indices), max_num_gene_sets, max_num_gene_sets))

# Get the indices of the n lowest values
min_values = np.min(vector_input, axis=0)
selected_column_indices = np.argsort(min_values)[:max_num_gene_sets]

# filter the reference gene/gene sets matrix down
matrix_result = matrix_input[:, selected_column_indices]

# log
if log:
print("got filtered column list of length: {}".format(len(selected_column_indices)))
print("got resulting shape from column filters from: {} to {}".format(matrix_input.shape, matrix_result.shape))
# print("example filtered: {}".format(matrix_result[11205]))
logger.info("vector values that passed {} filter or are top {} gene sets: {}".format(cutoff_input, max_num_gene_sets, vector_input[0, selected_column_indices]))
logger.info("got filtered column list of length: {}".format(len(selected_column_indices)))
logger.info("got filtered column list of: {}".format(selected_column_indices))
logger.info("got resulting shape of column filters from: {} to {}".format(matrix_input.shape, matrix_result.shape))
# logger.info("filtered matrix: {}".format(matrix_result))

# return
return matrix_result, selected_column_indices
Expand All @@ -702,8 +738,8 @@ def filter_matrix_rows_by_sum_cutoff(matrix_to_filter, matrix_to_sum, cutoff_inp
# # matrix_result = matrix_to_filter[mask, :]

if log:
print("got matrix to filter of shape: {} and type: {}".format(matrix_to_filter.shape, type(matrix_to_filter)))
print("got matrix to sum of shape: {} and type: {}".format(matrix_to_sum.shape, type(matrix_to_sum)))
logger.info("got matrix to filter of shape: {} and type: {}".format(matrix_to_filter.shape, type(matrix_to_filter)))
logger.info("got matrix to sum of shape: {} and type: {}".format(matrix_to_sum.shape, type(matrix_to_sum)))

mask = matrix_to_sum.sum(axis=1) > cutoff_input
# selected_indices = np.where(mask)[0]
Expand All @@ -713,7 +749,7 @@ def filter_matrix_rows_by_sum_cutoff(matrix_to_filter, matrix_to_sum, cutoff_inp

# log
if log:
print("got resulting shape from row sum filters from: {} to {}".format(matrix_to_filter.shape, matrix_result.shape))
logger.info("got resulting shape from row sum filters from: {} to {}".format(matrix_to_filter.shape, matrix_result.shape))
# print("got filter rows indices: {}".format(selected_indices))
# print("example matrix to sum: {}".format(matrix_to_sum.toarray()[2]))

Expand All @@ -734,8 +770,8 @@ def run_nmf(matrix_input, num_components=15, log=False):

# log
if log:
print("for gene factor of shape: {}".format(W.shape))
print("for gene set factor of shape: {}".format(H.shape))
logger.info("for gene factor of shape: {}".format(W.shape))
logger.info("for gene set factor of shape: {}".format(H.shape))

# return
return W, H
Expand Down
Loading

0 comments on commit 16915c5

Please sign in to comment.