Numba parallel weights computation + dataloader #5
base: master
…_encoding function; minor linting, changed seq_name_to_sequence to a string instead of list of chars
…minor printouts to MSA_processing
passing z_dim into train_VAE script manually passing in args for VAE checkpoint reloading small typos in scripts
…to mixed batch joint training. Initialising the bias to mean(y_train) for much better convergence, still not great performance though. Moved parameter reading outside of main function so that we can override the z_dim size
saving vae checkpoints, checkpoint loading vs train from scratch, added sigmoid+bce loss, added 3 very long functions for mixed/alternating/frozen training modes to switch from command line, added linear model loss weight
… outputs are equal
…cord, will delete the bad ones
…gure out properly later
…apping file also added "identity" weights for completion
# Conflicts: # EVE/VAE_model.py # calc_weights.py # compute_evol_indices.py # data/mappings/example_mapping.csv # examples/Step0_optional_calc_weights.sh # examples/Step0_optional_calc_weights_slurm.sh # train_VAE.py # utils/data_utils.py # utils/weights.py
Added progress bar, weights-only calc mode
Fallback to normal mode also works well
…der, merged in changes from ProteinGym. Removed the aggregation methods for evol indices.
…default, tested with DLG4 (cherry picked from commit fcb7894)
Posting a few comments for now
@@ -11,7 +11,7 @@ EVE is a set of protein-specific models providing for any single amino acid muta
 The end to end process to compute EVE scores consists of three consecutive steps:
 1. Train the Bayesian VAE on a re-weighted multiple sequence alignment (MSA) for the protein of interest => train_VAE.py
 2. Compute the evolutionary indices for all single amino acid mutations => compute_evol_indices.py
-3. Train a GMM to cluster variants on the basis of the evol indices then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py
+3. Train a GMM to cluster variants on the basis of the qevol indices then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py
revert please
     def sample_latent(self, mu, log_var):
         """
         Samples a latent vector via reparametrization trick
         """
         eps = torch.randn_like(mu).to(self.device)
-        z = torch.exp(0.5*log_var) * eps + mu
+        z = torch.exp(0.5 * log_var) * eps + mu
it would be nice to keep linting/formatting to a separate PR so this one isn't as cluttered
parser = argparse.ArgumentParser(description='Evol indices')
parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored')
parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name')
parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file')
parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored')
parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting')
# parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored')
should these arguments be deprecated instead of removed entirely?
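To make the suggestion concrete, here is a minimal sketch of what deprecating (rather than deleting) the two options could look like. `DeprecatedStoreAction` and `build_parser` are hypothetical names used only for illustration, and the warning text is an assumption; the option names and types are taken from the diff above.

```python
import argparse
import warnings


class DeprecatedStoreAction(argparse.Action):
    """Hypothetical helper: store the value as usual, but warn that the option is deprecated."""

    def __call__(self, parser, namespace, values, option_string=None):
        warnings.warn(
            f"{option_string} is deprecated and may be removed in a future release",
            DeprecationWarning,
            stacklevel=2,
        )
        setattr(namespace, self.dest, values)


def build_parser():
    """Build a cut-down parser that still accepts the two legacy options."""
    parser = argparse.ArgumentParser(description='Evol indices')
    parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored')
    # Legacy options: still parsed so existing scripts keep working, but flagged on use.
    parser.add_argument('--MSA_weights_location', type=str, action=DeprecatedStoreAction,
                        help='(deprecated) Location of the MSA sequence weights')
    parser.add_argument('--theta_reweighting', type=float, action=DeprecatedStoreAction,
                        help='(deprecated) Parameter for MSA sequence re-weighting')
    return parser
```

With something like this, old shell scripts that still pass `--theta_reweighting` would keep running and only emit a `DeprecationWarning`, instead of failing with an "unrecognized arguments" error.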
parser.add_argument('--training_logs_location', type=str,
                    help='Location of VAE model parameters')
parser.add_argument("--seed", type=int, help="Random seed", default=42)
parser.add_argument('--z_dim', type=int, help='Specify a different latent dim than in the params file')
this can be done by editing the params config file
parser.add_argument('--force_load_weights', action='store_true',
                    help="Force loading of weights from MSA_weights_location (useful if you want to make sure you're using precalculated weights). Will fail if weight file doesn't exist.",
                    default=False)
parser.add_argument("--overwrite_weights",
is --overwrite_weights necessary?
                    action="store_true", default=False)
parser.add_argument("--batch_size", type=int,
                    help="Batch size for training", default=None)
parser.add_argument("--experimental_stream_data",
we could consider making this the default and removing the CLI option, if it's judged robust enough
print("Protein name: "+str(protein_name))
print("MSA file: "+str(msa_location))

if mapping_file["MSA_filename"].duplicated().any():
for the new mapping colnames "MSA_filename" and "MSA_theta", we should probably also accept the old column names "protein_name" and "theta" so that legacy mapping files still work
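A rough sketch of how that fallback could work, assuming the old and new names pair up one-to-one in the order given in the comment above ("protein_name" to "MSA_filename", "theta" to "MSA_theta"). `LEGACY_COLUMNS` and `legacy_rename_map` are hypothetical names, not part of the PR.

```python
# Assumed pairing of legacy column names to the new ones (not confirmed by the PR).
LEGACY_COLUMNS = {"protein_name": "MSA_filename", "theta": "MSA_theta"}


def legacy_rename_map(columns):
    """Return {old: new} for each legacy column whose new name is not already present."""
    present = set(columns)
    return {
        old: new
        for old, new in LEGACY_COLUMNS.items()
        if old in present and new not in present
    }


# Intended use with the pandas mapping file, before the duplicate check:
#   mapping_file = mapping_file.rename(columns=legacy_rename_map(mapping_file.columns))
```

A new-style mapping file yields an empty rename map and is left untouched, and a file that already has the new column is never overwritten by the legacy one.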
We should squash all the commits together when merging; the history contains a lot of random scripts from other projects (which I've now removed)