diff --git a/ProteinDT/model_ColdDiffusionDecoder.py b/ProteinDT/model_ColdDiffusionDecoder.py deleted file mode 100644 index 7b0a42c..0000000 --- a/ProteinDT/model_ColdDiffusionDecoder.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import BertConfig -from ProteinDT.models.model_SDE import VESDE, VPSDE -from ProteinDT.models.model_Sampler import ReverseDiffusionPredictor, LangevinCorrector -from ProteinDT.models.score_networks import ToyScoreNetwork, RNNScoreNetwork, BertScoreNetwork - -EPS = 1e-6 - - -class ColdDiffusionDecoder(nn.Module): - def __init__( - self, hidden_dim, condition_dim, beta_min, beta_max, num_diffusion_timesteps, num_classes, score_network_type - ): - super().__init__() - self.hidden_dim = hidden_dim - self.condition_dim = condition_dim - self.beta_min = beta_min - self.beta_max = beta_max - self.num_diffusion_timesteps = num_diffusion_timesteps - self.num_classes = num_classes - - self.SDE_func = VPSDE(beta_min=self.beta_min, beta_max=self.beta_max, N=self.num_diffusion_timesteps) - - output_dim = hidden_dim - if score_network_type == "Toy": - word_embedding_dim = self.hidden_dim - self.score_network = ToyScoreNetwork(hidden_dim=hidden_dim, output_dim=output_dim) - - elif score_network_type == "RNN": - word_embedding_dim = self.hidden_dim - self.score_network = RNNScoreNetwork(hidden_dim=hidden_dim, output_dim=output_dim) - - elif score_network_type == "BertBase": - config = BertConfig.from_pretrained( - "bert-base-uncased", - cache_dir="../data/temp_Bert_base", - vocab_size=self.num_classes, - hidden_size=hidden_dim, - num_attention_heads=8 - ) - word_embedding_dim = self.hidden_dim - self.score_network = BertScoreNetwork(config=config, output_dim=output_dim) - - self.word_embedding_dim = word_embedding_dim - self.embedding_layer = nn.Linear(self.num_classes, self.word_embedding_dim, bias=False) - self.decoder_layer = nn.Linear(word_embedding_dim, self.num_classes) - self.condition_proj_layer = nn.Linear(self.condition_dim, self.word_embedding_dim) - - self.CE_criterion = nn.CrossEntropyLoss(reduction='none') - return - - def forward(self, protein_seq_input_ids, protein_seq_attention_mask, condition): - B = protein_seq_input_ids.size()[0] - device = protein_seq_input_ids.device - - # TODO: need double-check range of timesteps - timesteps = torch.rand(B, device=device) * (1 - EPS) + EPS # (B) - - condition = self.condition_proj_layer(condition) # (B, max_seq_len, condition_dim) ---> (B, max_seq_len, hidden_dim) - - protein_seq_onehot = F.one_hot(protein_seq_input_ids, num_classes=self.num_classes) # (B, max_seq_len, num_class) - protein_seq_onehot = protein_seq_onehot.float() - - #### cold diffusion can add noise either in the one-hot level - epsilon = torch.randn_like(protein_seq_onehot.float()) # (B, max_seq_len, num_class) - mean_noise, std_noise = self.SDE_func.marginal_prob(protein_seq_onehot, timesteps) # (B, max_seq_len, num_classes), (B) - protein_seq_onehot_noise = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, num_classes) - protein_seq_repr_noise = self.embedding_layer(protein_seq_onehot_noise) # (B, max_seq_len, hidden_dim) - - # ##### TODO: or in the embedding level??? 
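Both the cold-diffusion and Gaussian-SDE decoders in this patch perturb the clean input as `mean + std * epsilon`, with `mean` and `std` coming from `SDE_func.marginal_prob(x0, t)`. A minimal sketch of that perturbation kernel under the standard VP-SDE parameterization (the `beta_min`/`beta_max` defaults below are illustrative, and the repository's `VPSDE.marginal_prob` is assumed to follow this closed form):

```
import torch

def vpsde_marginal_prob(x0, t, beta_min=0.1, beta_max=20.0):
    # Closed-form mean/std of p(x_t | x_0) for the VP-SDE:
    # log mean coefficient = -1/4 * t^2 * (beta_max - beta_min) - 1/2 * t * beta_min
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = torch.exp(log_mean_coeff)[:, None, None] * x0        # (B, L, D)
    std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))     # (B,)
    return mean, std

# Perturb a clean one-hot/embedding tensor the way the forward passes here do.
x0 = torch.randn(4, 16, 32)               # (B, max_seq_len, D)
t = torch.rand(4) * (1 - 1e-6) + 1e-6     # timesteps in (0, 1], as sampled above
mean, std = vpsde_marginal_prob(x0, t)
x_t = mean + std[:, None, None] * torch.randn_like(x0)
```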
- # protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - # epsilon = torch.randn_like(protein_seq_repr.float()) # (B, max_seq_len, hidden_dim) - # protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - # mean_noise, std_noise = self.SDE_func.marginal_prob(protein_seq_repr, timesteps) # (B, max_seq_len, hidden_dim), (B) - # protein_seq_repr_noise = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, hidden_dim) - - score = self.score_network(protein_seq_repr=protein_seq_repr_noise, protein_seq_attention_mask=protein_seq_attention_mask, condition=condition) # (B, max_seq_len, hidden_dim) or (B, max_seq_len, num_class) - score = self.decoder_layer(score) # (B*max_sequence_len, num_class) - - flattened_logits = score.view(-1, score.size(-1)) # (B*max_sequence_len, num_class) - flattened_ids = protein_seq_input_ids.view(-1) # (B*max_sequence_len) - flattened_mask = protein_seq_attention_mask.view(-1) # (B*max_sequence_len) - total_SDE_loss = self.CE_criterion(flattened_logits, flattened_ids) # (B*max_sequence_len) - masked_SDE_loss = total_SDE_loss * flattened_mask # (B*max_sequence_len) - total_SDE_loss = torch.mean(total_SDE_loss) - masked_SDE_loss = total_SDE_loss.sum() / flattened_mask.sum() - - SDE_loss = total_SDE_loss + masked_SDE_loss - decoding_loss = 0 - - return SDE_loss, decoding_loss - - @torch.no_grad() - def inference(self, condition, max_seq_len, protein_seq_attention_mask): - B = condition.size()[0] - device = condition.device - - shape = (B, max_seq_len, self.num_classes) - - X_one_hot = self.SDE_func.prior_sampling(shape).to(device) # (B, max_seq_len, word_embedding_dim) - - EPSILON = 1e-5 - - timesteps = torch.linspace(self.SDE_func.T, EPSILON, self.num_diffusion_timesteps, device=device) # (num_diffusion_timesteps) - - condition = condition.float() - condition = self.condition_proj_layer(condition) # (B, max_seq_len, condition_dim) ---> (B, max_seq_len, hidden_dim) - - x_one_hot_t = X_one_hot - for i in range(0, self.num_diffusion_timesteps-1): - x_repr_t = self.embedding_layer(x_one_hot_t) # (B, max_seq_len, hidden_dim) - score = self.score_network(protein_seq_repr=x_repr_t, protein_seq_attention_mask=protein_seq_attention_mask, condition=condition) # (B, max_seq_len, hidden_dim) - hat_x_one_hot_0 = self.decoder_layer(score) # (B, max_sequence_len, num_class) - - t = timesteps[i] - vec_t = torch.ones(shape[0], device=device) * t # (B) - t_1 = timesteps[i+1] - vec_t_1 = torch.ones(shape[0], device=device) * t_1 # (B) - epsilon = torch.randn_like(hat_x_one_hot_0) # (B, max_seq_len, num_class) - - mean_noise, std_noise = self.SDE_func.marginal_prob(hat_x_one_hot_0, vec_t) # (B, max_seq_len, num_classes), (B) - x_one_hot_t = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, num_classes) - mean_noise, std_noise = self.SDE_func.marginal_prob(hat_x_one_hot_0, vec_t_1) # (B, max_seq_len, num_classes), (B) - x_one_hot_t_1 = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, num_classes) - - x_one_hot_t = x_one_hot_t - x_one_hot_t - x_one_hot_t_1 # (B, max_sequence_len, num_class) - x = x_one_hot_t - - return x diff --git a/ProteinDT/model_GaussianSDEDecoder.py b/ProteinDT/model_GaussianSDEDecoder.py deleted file mode 100644 index 83ed0dc..0000000 --- a/ProteinDT/model_GaussianSDEDecoder.py +++ /dev/null @@ -1,238 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import BertConfig -from ProteinDT.models.model_SDE 
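For reference, the sampling loop in the deleted `ColdDiffusionDecoder.inference` is built around the cold-diffusion idea of re-degrading the current estimate of the clean sample at two adjacent noise levels. The commonly cited update is x_{t-1} = x_t - D(x_hat_0, t) + D(x_hat_0, t-1); a minimal sketch with hypothetical `restore` and `degrade` callables standing in for the score network plus decoder and the `marginal_prob`-based renoising:

```
def cold_diffusion_sample(x_T, timesteps, restore, degrade):
    # restore(x_t, t) -> x0_hat; degrade(x0_hat, t) -> renoised sample at level t.
    x_t = x_T
    for i in range(len(timesteps) - 1):
        t, t_prev = timesteps[i], timesteps[i + 1]
        x0_hat = restore(x_t, t)
        # Improved cold-diffusion update: x_{t-1} = x_t - D(x0_hat, t) + D(x0_hat, t-1)
        x_t = x_t - degrade(x0_hat, t) + degrade(x0_hat, t_prev)
    return x_t
```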
import VESDE, VPSDE -from ProteinDT.models.model_Sampler import ReverseDiffusionPredictor, LangevinCorrector -from ProteinDT.models.score_networks import ToyScoreNetwork, RNNScoreNetwork, BertScoreNetwork - -EPS = 1e-6 - - -def get_score_fn(SDE_func, score_network, train=True, continuous=True): - if not train: - score_network.eval() - - if isinstance(SDE_func, VPSDE): - def score_fn(x, x_mask, condition, t): - if continuous: - score = score_network(protein_seq_repr=x, protein_seq_attention_mask=x_mask, condition=condition) - std = SDE_func.marginal_prob(x, t)[1] - else: - raise NotImplementedError(f"Discrete not supported") - score = -score / std[:, None, None] - return score - - elif isinstance(SDE_func, VESDE): - def score_fn(x, x_mask, condition, t): - if continuous: - score = score_network(protein_seq_repr=x, protein_seq_attention_mask=x_mask, condition=condition) - else: - raise NotImplementedError(f"Discrete not supported") - score = -score - return score - - else: - raise NotImplementedError(f"SDE class {SDE_func.__class__.__name__} not supported.") - - return score_fn - - -class GaussianSDEDecoderModel(nn.Module): - def __init__( - self, hidden_dim, condition_dim, beta_min, beta_max, num_diffusion_timesteps, SDE_mode, num_classes, - score_network_type - ): - super().__init__() - self.hidden_dim = hidden_dim - self.condition_dim = condition_dim - self.beta_min = beta_min - self.beta_max = beta_max - self.num_diffusion_timesteps = num_diffusion_timesteps - self.SDE_mode = SDE_mode - self.num_classes = num_classes - - if self.SDE_mode in ["VE", "ColdVE"]: - self.SDE_func = VESDE(sigma_min=self.beta_min, sigma_max=self.beta_max, N=self.num_diffusion_timesteps) - elif self.SDE_mode in ["VP", "ColdVP"]: - self.SDE_func = VPSDE(beta_min=self.beta_min, beta_max=self.beta_max, N=self.num_diffusion_timesteps) - - if score_network_type == "Toy": - word_embedding_dim = self.hidden_dim - if self.SDE_mode in ["ColdVE", "ColdVP"]: - output_dim = self.num_classes - else: - output_dim = word_embedding_dim - self.score_network = ToyScoreNetwork(hidden_dim=hidden_dim, output_dim=output_dim) - - elif score_network_type == "RNN": - word_embedding_dim = self.hidden_dim - if self.SDE_mode in ["ColdVE", "ColdVP"]: - output_dim = self.num_classes - else: - output_dim = word_embedding_dim - self.score_network = RNNScoreNetwork(hidden_dim=hidden_dim, output_dim=output_dim) - - elif score_network_type == "BertProtBFD": - word_embedding_dim = 1024 - if self.SDE_mode in ["ColdVE", "ColdVP"]: - output_dim = self.num_classes - else: - output_dim = word_embedding_dim - self.score_network = BertScoreNetwork.from_pretrained( - "Rostlab/prot_bert_bfd", - cache_dir="../data/temp_pretrained_ProtBert_BFD", - ignore_mismatched_sizes=True, - ) - - elif score_network_type == "BertBase": - config = BertConfig.from_pretrained( - "bert-base-uncased", - cache_dir="../data/temp_Bert_base", - vocab_size=self.num_classes, - hidden_size=hidden_dim, - num_attention_heads=8 - ) - word_embedding_dim = self.hidden_dim - if self.SDE_mode in ["ColdVE", "ColdVP"]: - output_dim = self.num_classes - else: - output_dim = word_embedding_dim - self.score_network = BertScoreNetwork(config=config, output_dim=output_dim) - - self.word_embedding_dim = word_embedding_dim - self.embedding_layer = nn.Linear(self.num_classes, self.word_embedding_dim, bias=False) - self.decoder_layer = nn.Linear(word_embedding_dim, self.num_classes) - self.condition_proj_layer = nn.Linear(self.condition_dim, self.word_embedding_dim) - - self.score_fn = 
get_score_fn(self.SDE_func, self.score_network, train=True, continuous=True) - self.predictor = ReverseDiffusionPredictor(self.SDE_func, self.score_fn) - self.corrector = LangevinCorrector(self.SDE_func, self.score_fn, snr=0.1, n_steps=100) - self.CE_criterion = nn.CrossEntropyLoss(reduction='none') - return - - def forward(self, protein_seq_input_ids, protein_seq_attention_mask, condition): - B = protein_seq_input_ids.size()[0] - device = protein_seq_input_ids.device - - # TODO: need double-check range of timesteps - timesteps = torch.rand(B, device=device) * (1 - EPS) + EPS # (B) - - condition = self.condition_proj_layer(condition) # (B, max_seq_len, condition_dim) ---> (B, max_seq_len, hidden_dim) - - protein_seq_onehot = F.one_hot(protein_seq_input_ids, num_classes=self.num_classes) # (B, max_seq_len, num_class) - protein_seq_onehot = protein_seq_onehot.float() - - if self.SDE_mode in ["ColdVE", "ColdVP"]: - # epsilon = torch.randn_like(protein_seq_onehot.float()) # (B, max_seq_len, num_class) - # mean_noise, std_noise = self.SDE_func.marginal_prob(protein_seq_onehot, timesteps) # (B, max_seq_len, num_classes), (B) - # protein_seq_onehot_noise = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, num_classes) - # protein_seq_repr_noise = self.embedding_layer(protein_seq_onehot_noise) # (B, max_seq_len, hidden_dim) - - protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - epsilon = torch.randn_like(protein_seq_repr.float()) # (B, max_seq_len, hidden_dim) - protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - mean_noise, std_noise = self.SDE_func.marginal_prob(protein_seq_repr, timesteps) # (B, max_seq_len, hidden_dim), (B) - protein_seq_repr_noise = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, hidden_dim) - else: - protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - epsilon = torch.randn_like(protein_seq_repr.float()) # (B, max_seq_len, hidden_dim) - protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - mean_noise, std_noise = self.SDE_func.marginal_prob(protein_seq_repr, timesteps) # (B, max_seq_len, hidden_dim), (B) - protein_seq_repr_noise = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, hidden_dim) - - score = self.score_fn(protein_seq_repr_noise, protein_seq_attention_mask, condition, timesteps) # (B, max_seq_len, hidden_dim) or (B, max_seq_len, num_class) - - if self.SDE_mode in ["ColdVE", "ColdVP"]: - flattened_logits = score.view(-1, score.size(-1)) # (B*max_sequence_len, num_class) - flattened_ids = protein_seq_input_ids.view(-1) # (B*max_sequence_len) - flattened_mask = protein_seq_attention_mask.view(-1) # (B*max_sequence_len) - total_SDE_loss = self.CE_criterion(flattened_logits, flattened_ids) # (B*max_sequence_len) - masked_SDE_loss = total_SDE_loss * flattened_mask # (B*max_sequence_len) - total_SDE_loss = torch.mean(total_SDE_loss) - masked_SDE_loss = total_SDE_loss.sum() / flattened_mask.sum() - - SDE_loss = total_SDE_loss + masked_SDE_loss - decoding_loss = 0 - - else: - total_SDE_loss = torch.square(score * std_noise[:, None, None] + epsilon) # (B, max_seq_len, hidden_dim) - masked_SDE_loss = total_SDE_loss * protein_seq_attention_mask.unsqueeze(2) # (B, max_seq_len, hidden_dim) - total_SDE_loss = torch.mean(total_SDE_loss) - masked_SDE_loss = masked_SDE_loss.sum() / protein_seq_attention_mask.sum() - SDE_loss = total_SDE_loss + masked_SDE_loss - - # 
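The squared residual `score * std_noise + epsilon` in the non-cold branch above is the usual denoising score-matching objective: the conditional score of p(x_t | x_0) = N(mean, std^2 I) is -epsilon / std, so the loss is minimized when the predicted score matches it. A small sketch of the masked version, assuming the same (B, L, D) tensors and (B, L) attention mask used here:

```
def dsm_loss(score, std, eps, mask=None):
    # || std * s_theta(x_t, t) + eps ||^2, optionally averaged over unmasked positions.
    loss = (score * std[:, None, None] + eps) ** 2        # (B, L, D)
    if mask is not None:
        return (loss * mask.unsqueeze(-1)).sum() / mask.sum()
    return loss.mean()
```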
regenerate protein_seq_repr - protein_seq_ids_pred_logit = self.decoder_layer(protein_seq_repr) # (B, max_seq_len, num_class) - flattened_logits = protein_seq_ids_pred_logit.view(-1, protein_seq_ids_pred_logit.size(-1)) # (B*max_sequence_len, num_class) - flattened_ids = protein_seq_input_ids.view(-1) # (B*max_sequence_len) - flattened_mask = protein_seq_attention_mask.view(-1) # (B*max_sequence_len) - total_decoding_loss = self.CE_criterion(flattened_logits, flattened_ids) # (B*max_sequence_len) - masked_decoding_loss = total_decoding_loss * flattened_mask # (B*max_sequence_len) - total_decoding_loss = torch.mean(total_decoding_loss) - masked_decoding_loss = masked_decoding_loss.sum() / flattened_mask.sum() - decoding_loss = total_decoding_loss + masked_decoding_loss - return SDE_loss, decoding_loss - - @torch.no_grad() - def inference(self, condition, max_seq_len, protein_seq_attention_mask): - B = condition.size()[0] - device = condition.device - - if self.SDE_mode in ["ColdVE", "ColdVP"]: - shape = (B, max_seq_len, self.num_classes) - else: - shape = (B, max_seq_len, self.word_embedding_dim) - - X_T = self.SDE_func.prior_sampling(shape).to(device) # (B, max_seq_len, word_embedding_dim) - - EPSILON = 1e-5 - - timesteps = torch.linspace(self.SDE_func.T, EPSILON, self.num_diffusion_timesteps, device=device) # (num_diffusion_timesteps) - - condition = condition.float() - condition = self.condition_proj_layer(condition) # (B, max_seq_len, condition_dim) ---> (B, max_seq_len, hidden_dim) - - if self.SDE_mode in ["ColdVE", "ColdVP"]: - # TODO: this is wrong - onehot_x = X_T - for i in range(0, self.num_diffusion_timesteps-1): - repr_x = self.embedding_layer(onehot_x) - - t = timesteps[i] - vec_t = torch.ones(shape[0], device=device) * t - std_t = self.SDE_func.marginal_prob(torch.ones_like(vec_t), vec_t)[1] - score_t = self.score_fn(repr_x, protein_seq_attention_mask, condition, vec_t) - # score_t = -score_t / std_t[:, None, None] - - t_1 = timesteps[i+1] - vec_t_1 = torch.ones(shape[0], device=device) * t_1 - std_t_1 = self.SDE_func.marginal_prob(torch.ones_like(vec_t_1), vec_t_1)[1] - score_t_1 = self.score_fn(repr_x, protein_seq_attention_mask, condition, vec_t_1) - # score_t_1 = -score_t_1 / std_t_1[:, None, None] - - onehot_x = onehot_x + score_t - score_t_1 - x = onehot_x - - else: - x = X_T - for i in range(0, self.num_diffusion_timesteps): - t = timesteps[i] - vec_t = torch.ones(shape[0], device=device) * t - - x, x_mean = self.corrector.update_fn(x, protein_seq_attention_mask, condition, vec_t) - - x, x_mean = self.predictor.update_fn(x, protein_seq_attention_mask, condition, vec_t) # (B, max_seq_len, num_class), (B, max_seq_len, num_class) - - x = self.decoder_layer(x) - return x - - -if __name__ == "__main__": - B = 10 - timesteps = torch.rand(B) * (1 - EPS) + EPS - print(timesteps) - - EPS = 1e-3 - timesteps = torch.linspace(1, EPS, 1000) - print(timesteps) diff --git a/ProteinDT/model_LSTMDecoder.py b/ProteinDT/model_LSTMDecoder.py deleted file mode 100644 index 42782aa..0000000 --- a/ProteinDT/model_LSTMDecoder.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F -import random - -# from utils import * - - -class LSTMDecoder(nn.Module): - def __init__(self, hidden_dim, n_layer, embedding_dim, epsilon, num_classes, tokenizer): - super(LSTMDecoder, self).__init__() - self.hidden_dim = hidden_dim - self.n_layer = n_layer - self.embedding_dim = embedding_dim - self.num_classes = num_classes - self.epsilon = 
epsilon - - self.embedding_layer = nn.Embedding(num_embeddings=self.num_classes, embedding_dim=self.embedding_dim) - - self.lstm = nn.LSTM(input_size=self.hidden_dim+self.embedding_dim, hidden_size=self.hidden_dim, num_layers=self.n_layer, batch_first=True) - self.fc = nn.Linear(self.hidden_dim, self.num_classes) - self.CE_criterion = nn.CrossEntropyLoss() - self.tokenizer = tokenizer - return - - def forward(self, protein_seq_input_ids, protein_seq_attention_mask, condition): - input_seq_emb = self.embedding_layer(protein_seq_input_ids) # [B, max_seq_len, embedding_dim] - - batch_size, n_seq, _ = input_seq_emb.size() - - h_0 = torch.zeros(self.n_layer, batch_size, self.hidden_dim).detach() - c_0 = torch.zeros(self.n_layer, batch_size, self.hidden_dim).detach() - g_hidden = (h_0.cuda(), c_0.cuda()) - - condition = condition.unsqueeze(1) # [B, 1, hidde_dim] - prev_word = protein_seq_input_ids[:, 0:1] # [B, 1] - - output = [] - - for j in range(n_seq-1): - if random.random() < self.epsilon: - current_word_emb = self.embedding_layer(prev_word) # [B, 1, embedding_dim] - else: - current_word_emb = input_seq_emb[:, j:(j+1), :] # [B, 1, embedding_dim] - - x = torch.cat([condition, current_word_emb], dim=-1) # [B, 1, hidden_dim+embedding_dim] - logits, g_hidden = self.lstm(x, g_hidden) # [B, 1, hidden_dim], [[n_layer, B, hidden_dim], [n_layer, B, hidden_dim]] - logits = self.fc(logits) # [B, 1, num_classes] - prev_word = torch.argmax(logits, dim = -1) # [B, 1] - output.append(logits) - - output = torch.cat(output, dim=1) # [B, max_seq_len-1, num_classes] - target_protein_seq_input_ids = protein_seq_input_ids[:, 1:].contiguous() # [B, max_seq_len-1] - target_protein_seq_attention_mask = protein_seq_attention_mask[:, 1:].contiguous() # [B, max_seq_len-1] - flattened_logits = output.view(-1, output.size(-1)) # [B * (max_sequence_len-1), num_class] - flattened_ids = target_protein_seq_input_ids.view(-1) # [B * (max_sequence_len-1)] - flattened_mask = target_protein_seq_attention_mask.view(-1) # [B * (max_sequence_len-1)] - total_loss = self.CE_criterion(flattened_logits, flattened_ids) # [B * (max_sequence_len-1)] - masked_loss = total_loss * flattened_mask # [B * (max_sequence_len-1)] - total_loss = torch.mean(total_loss) - masked_loss = masked_loss.sum() / flattened_mask.sum() - - loss = total_loss + masked_loss - decoding_loss = 0 - - return loss, decoding_loss - - def inference(self, condition, protein_seq_attention_mask, max_seq_len, temperature=1, use_sample=False): - - device = condition.device - condition = condition.unsqueeze(1) # [B, 1, hidde_dim] - batch_size = condition.size()[0] - prev_word = torch.ones([batch_size]).long().to(device) * self.tokenizer.cls_token_id # [B] - prev_word = prev_word.unsqueeze(1) # [B, 1] - - h_0 = torch.zeros(self.n_layer, batch_size, self.hidden_dim).detach() - c_0 = torch.zeros(self.n_layer, batch_size, self.hidden_dim).detach() - g_hidden = (h_0.cuda(), c_0.cuda()) - - output = [] - for _ in range(max_seq_len): - current_word_emb = self.embedding_layer(prev_word) # [B, 1, embedding_dim] - x = torch.cat([condition, current_word_emb], dim=-1) # [B, 1, hidden_dim+embedding_dim] - logits, g_hidden = self.lstm(x, g_hidden) # [B, 1, hidden_dim], [[n_layer, B, hidden_dim], [n_layer, B, hidden_dim]] - logits = self.fc(logits) # [B, 1, num_classes] - - if use_sample: - probs = torch.softmax(logits / temperature, dim=-1) - prediction = [] - for data_idx in range(batch_size): - prediction_temp = torch.multinomial(probs[data_idx], num_samples=1) - 
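The `use_sample` branch of `LSTMDecoder.inference` above draws the next token with a temperature-scaled multinomial, one sequence at a time; the same step can be written batched, roughly as in this sketch (not a helper from the repository):

```
import torch

def sample_next_token(logits, temperature=1.0, greedy=False):
    # logits: (B, 1, num_classes) from the LSTM + linear head.
    if greedy:
        return torch.argmax(logits, dim=-1)                     # (B, 1)
    probs = torch.softmax(logits / temperature, dim=-1)         # (B, 1, num_classes)
    return torch.multinomial(probs.squeeze(1), num_samples=1)   # (B, 1)
```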
prediction.append(prediction_temp) - prediction = torch.cat(prediction) # [B, 1] - prev_word = prediction # [B, 1] - else: - prev_word = torch.argmax(logits, dim=-1) # [B, 1] - - output.append(prev_word) - - output = torch.cat(output, dim=1) # [B, max_seq_len] - return output diff --git a/ProteinDT/model_LatentDiffusionDecoder.py b/ProteinDT/model_LatentDiffusionDecoder.py deleted file mode 100644 index 6076e41..0000000 --- a/ProteinDT/model_LatentDiffusionDecoder.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import BertConfig -from ProteinDT.models.model_SDE import VESDE, VPSDE -from ProteinDT.models.model_Sampler import ReverseDiffusionPredictor, LangevinCorrector -from ProteinDT.models.score_networks import ToyScoreNetwork, RNNScoreNetwork, BertScoreNetwork - -EPS = 1e-6 - - -def get_score_fn(SDE_func, score_network, train=True, continuous=True): - if not train: - score_network.eval() - - if isinstance(SDE_func, VPSDE): - def score_fn(x, x_mask, condition, t): - if continuous: - score = score_network(protein_seq_repr=x, protein_seq_attention_mask=x_mask, condition=condition) - std = SDE_func.marginal_prob(x, t)[1] - else: - raise NotImplementedError(f"Discrete not supported") - score = -score / std[:, None, None] - return score - - elif isinstance(SDE_func, VESDE): - def score_fn(x, x_mask, condition, t): - if continuous: - score = score_network(protein_seq_repr=x, protein_seq_attention_mask=x_mask, condition=condition) - else: - raise NotImplementedError(f"Discrete not supported") - score = -score - return score - - else: - raise NotImplementedError(f"SDE class {SDE_func.__class__.__name__} not supported.") - - return score_fn - - -class LatentDiffusionDecoder(nn.Module): - def __init__( - self, hidden_dim, condition_dim, beta_min, beta_max, num_diffusion_timesteps, num_classes, score_network_type - ): - super().__init__() - self.hidden_dim = hidden_dim - self.condition_dim = condition_dim - self.beta_min = beta_min - self.beta_max = beta_max - self.num_diffusion_timesteps = num_diffusion_timesteps - self.num_classes = num_classes - - self.SDE_func = VPSDE(beta_min=self.beta_min, beta_max=self.beta_max, N=self.num_diffusion_timesteps) - - word_embedding_dim = self.hidden_dim - output_dim = word_embedding_dim - - if score_network_type == "Toy": - self.score_network = ToyScoreNetwork(hidden_dim=hidden_dim, output_dim=output_dim) - - elif score_network_type == "RNN": - self.score_network = RNNScoreNetwork(hidden_dim=hidden_dim, output_dim=output_dim) - - elif score_network_type == "BertBase": - config = BertConfig.from_pretrained( - "bert-base-uncased", - cache_dir="../data/temp_Bert_base", - vocab_size=self.num_classes, - hidden_size=hidden_dim, - num_attention_heads=8 - ) - self.score_network = BertScoreNetwork(config=config, output_dim=output_dim) - - self.word_embedding_dim = word_embedding_dim - self.embedding_layer = nn.Linear(self.num_classes, self.word_embedding_dim, bias=False) - self.decoder_layer = nn.Linear(word_embedding_dim, self.num_classes) - self.condition_proj_layer = nn.Linear(self.condition_dim, self.word_embedding_dim) - - self.score_fn = get_score_fn(self.SDE_func, self.score_network, train=True, continuous=True) - self.predictor = ReverseDiffusionPredictor(self.SDE_func, self.score_fn) - self.corrector = LangevinCorrector(self.SDE_func, self.score_fn, snr=0.1, n_steps=10) - self.CE_criterion = nn.CrossEntropyLoss(reduction='none') - return - - def forward(self, protein_seq_input_ids, 
protein_seq_attention_mask, condition): - B = protein_seq_input_ids.size()[0] - device = protein_seq_input_ids.device - - # TODO: need double-check range of timesteps - timesteps = torch.rand(B, device=device) * (1 - EPS) + EPS # (B) - - condition = self.condition_proj_layer(condition) # (B, max_seq_len, condition_dim) ---> (B, max_seq_len, hidden_dim) - - protein_seq_onehot = F.one_hot(protein_seq_input_ids, num_classes=self.num_classes) # (B, max_seq_len, num_class) - protein_seq_onehot = protein_seq_onehot.float() - - protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - epsilon = torch.randn_like(protein_seq_repr.float()) # (B, max_seq_len, hidden_dim) - protein_seq_repr = self.embedding_layer(protein_seq_onehot) # (B, max_seq_len, hidden_dim) - mean_noise, std_noise = self.SDE_func.marginal_prob(protein_seq_repr, timesteps) # (B, max_seq_len, hidden_dim), (B) - protein_seq_repr_noise = mean_noise + std_noise[:, None, None] * epsilon # (B, max_seq_len, hidden_dim) - - score = self.score_network(protein_seq_repr=protein_seq_repr_noise, protein_seq_attention_mask=protein_seq_attention_mask, condition=condition) # (B, max_seq_len, hidden_dim) - score = -score / std_noise[:, None, None] - - total_SDE_loss = torch.square(score * std_noise[:, None, None] + epsilon) # (B, max_seq_len, hidden_dim) - masked_SDE_loss = total_SDE_loss * protein_seq_attention_mask.unsqueeze(2) # (B, max_seq_len, hidden_dim) - total_SDE_loss = torch.mean(total_SDE_loss) - masked_SDE_loss = masked_SDE_loss.sum() / protein_seq_attention_mask.sum() - SDE_loss = total_SDE_loss + masked_SDE_loss - - # regenerate protein_seq_repr - protein_seq_ids_pred_logit = self.decoder_layer(protein_seq_repr) # (B, max_seq_len, num_class) - flattened_logits = protein_seq_ids_pred_logit.view(-1, protein_seq_ids_pred_logit.size(-1)) # (B*max_sequence_len, num_class) - flattened_ids = protein_seq_input_ids.view(-1) # (B*max_sequence_len) - flattened_mask = protein_seq_attention_mask.view(-1) # (B*max_sequence_len) - total_decoding_loss = self.CE_criterion(flattened_logits, flattened_ids) # (B*max_sequence_len) - masked_decoding_loss = total_decoding_loss * flattened_mask # (B*max_sequence_len) - total_decoding_loss = torch.mean(total_decoding_loss) - masked_decoding_loss = masked_decoding_loss.sum() / flattened_mask.sum() - decoding_loss = total_decoding_loss + masked_decoding_loss - - return SDE_loss, decoding_loss - - @torch.no_grad() - def inference(self, condition, max_seq_len, protein_seq_attention_mask): - B = condition.size()[0] - device = condition.device - - shape = (B, max_seq_len, self.word_embedding_dim) - - X_T = self.SDE_func.prior_sampling(shape).to(device) # (B, max_seq_len, word_embedding_dim) - - EPSILON = 1e-5 - - timesteps = torch.linspace(self.SDE_func.T, EPSILON, self.num_diffusion_timesteps, device=device) # (num_diffusion_timesteps) - - condition = condition.float() - condition = self.condition_proj_layer(condition) # (B, max_seq_len, condition_dim) ---> (B, max_seq_len, hidden_dim) - - x = X_T - for i in range(0, self.num_diffusion_timesteps): - t = timesteps[i] - vec_t = torch.ones(shape[0], device=device) * t - - x, x_mean = self.corrector.update_fn(x, protein_seq_attention_mask, condition, vec_t) - - x, x_mean = self.predictor.update_fn(x, protein_seq_attention_mask, condition, vec_t) # (B, max_seq_len, num_class), (B, max_seq_len, num_class) - - x = self.decoder_layer(x) - - return x diff --git a/ProteinDT/model_Sampler.py b/ProteinDT/model_Sampler.py deleted file mode 
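The inference loops of the Gaussian-SDE and latent-diffusion decoders above follow predictor-corrector sampling: one Langevin corrector step followed by one reverse-diffusion predictor step per timestep, using the `update_fn(x, x_mask, condition, t)` interface from `model_Sampler.py` below. A condensed sketch of that loop:

```
import torch

def pc_sample(x_T, timesteps, mask, condition, predictor, corrector):
    x = x_T
    for t in timesteps:
        vec_t = torch.ones(x.size(0), device=x.device) * t
        x, x_mean = corrector.update_fn(x, mask, condition, vec_t)   # Langevin correction
        x, x_mean = predictor.update_fn(x, mask, condition, vec_t)   # reverse-diffusion step
    return x   # the decoders above map this back to logits with decoder_layer
```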
100644 index d38a807..0000000 --- a/ProteinDT/model_Sampler.py +++ /dev/null @@ -1,89 +0,0 @@ -import abc -import torch -from .model_SDE import VPSDE, VESDE - - -class Predictor(abc.ABC): - """The abstract class for a predictor algorithm.""" - - def __init__(self, sde, score_fn, probability_flow=False): - super().__init__() - self.sde = sde - # Compute the reverse SDE/ODE - self.rsde = sde.reverse(score_fn, probability_flow) - self.score_fn = score_fn - - @abc.abstractmethod - def update_fn(self, x, t): - """One update of the predictor. - Args: - x: A PyTorch tensor representing the current state - t: A Pytorch tensor representing the current time step. - Returns: - x: A PyTorch tensor of the next state. - x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. - """ - pass - - -class Corrector(abc.ABC): - """The abstract class for a corrector algorithm.""" - - def __init__(self, sde, score_fn, snr, n_steps): - super().__init__() - self.sde = sde - self.score_fn = score_fn - self.snr = snr - self.n_steps = n_steps - - @abc.abstractmethod - def update_fn(self, x, t): - """One update of the corrector. - Args: - x: A PyTorch tensor representing the current state - t: A PyTorch tensor representing the current time step. - Returns: - x: A PyTorch tensor of the next state. - x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. - """ - pass - - -class ReverseDiffusionPredictor(Predictor): - def __init__(self, sde, score_fn, probability_flow=False): - super().__init__(sde, score_fn, probability_flow) - - def update_fn(self, x, x_mask, condition, t): - f, G = self.rsde.discretize(x, x_mask, condition, t) # (B, max_seq_len, num_class), (B) - z = torch.randn_like(x) # (B, max_seq_len, num_class) - x_mean = x - f # (B, max_seq_len, num_class) - x = x_mean + G[:, None, None] * z # (B, max_seq_len, num_class) - return x, x_mean - - -class LangevinCorrector(Corrector): - def __init__(self, sde, score_fn, snr, n_steps): - super().__init__(sde, score_fn, snr, n_steps) - if not isinstance(sde, VPSDE) and not isinstance(sde, VESDE): - raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") - - def update_fn(self, x, x_mask, condition, t): - sde = self.sde - score_fn = self.score_fn - n_steps = self.n_steps - target_snr = self.snr - if isinstance(sde, VPSDE): - timestep = (t * (sde.N - 1) / sde.T).long() - alpha = sde.alphas.to(t.device)[timestep] - else: - alpha = torch.ones_like(t) - - for i in range(n_steps): - grad = score_fn(x, x_mask, condition, t) # (B, max_seq_len, num_class) - noise = torch.randn_like(x) # (B, max_seq_len, num_class) - grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() # 1 - noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() # 1 - step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha # (B) - x_mean = x + step_size[:, None, None] * grad # (B, max_seq_len, num_class) - x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise # (B, max_seq_len, num_class) - return x, x_mean diff --git a/ProteinDT/model_SDE.py b/ProteinDT/models/model_SDE.py similarity index 100% rename from ProteinDT/model_SDE.py rename to ProteinDT/models/model_SDE.py diff --git a/README.md b/README.md index a86f777..de03f1e 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ pip install fair-esm[esmfold]==2.0.0 --no-dependencies # Override deepspeed==0. 
pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git' pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307' -conda install mdtraj biopython -c conda-forge -yq +conda install -c conda-forge -yq mdtraj # for ProteinDT pip install . diff --git a/README_checkpoints.md b/README_checkpoints.md deleted file mode 100644 index 72558d6..0000000 --- a/README_checkpoints.md +++ /dev/null @@ -1 +0,0 @@ -# Checkpoints for ProteinDT \ No newline at end of file diff --git a/README_data.md b/README_data.md deleted file mode 100644 index 3ff9f1b..0000000 --- a/README_data.md +++ /dev/null @@ -1 +0,0 @@ -# Data for ProteinDT \ No newline at end of file diff --git a/examples/downstream_Editing/README.md b/examples/downstream_Editing/README.md index b7ca1e3..e299382 100644 --- a/examples/downstream_Editing/README.md +++ b/examples/downstream_Editing/README.md @@ -54,10 +54,22 @@ cp ../../output/downstream_TAPE/stability/ProtBERT_BFD/3-3e-5-5-2-16-0.08/pytorc datasets_and_checkpoints/stability/oracle/pytorch_model_stability.bin ``` -## 3 Peptide Editing +## 3 Region Editing ### 3.1 Dataset +Refer to this [paper](https://www.sciencedirect.com/science/article/pii/S0959440X22000513#appsec1), the excel data in the SI. + +Download the SI excel data to `datasets_and_checkpoints/region`, then: + +``` +python prepare_01_region.py +``` + +## 4 Peptide Editing + +### 4.1 Dataset + Refer to this [GitHub link](https://github.com/t7morgen/misato-dataset). ``` @@ -81,7 +93,7 @@ Back to this folder, and do the following: python prepare_01_peptide_editing_raw_and_processed_data.py ``` -### 3.2 Oracle Evaluator +### 4.2 Oracle Evaluator We will have to train an oracle model by ourselves + docking. diff --git a/examples/downstream_Editing/step_01_editing_Galactica.py b/examples/downstream_Editing/step_01_editing_Galactica.py index 538b802..f2fd167 100644 --- a/examples/downstream_Editing/step_01_editing_Galactica.py +++ b/examples/downstream_Editing/step_01_editing_Galactica.py @@ -32,14 +32,14 @@ def parse_Galatica_result(text_prompt, result): @torch.no_grad() -def inference_Galactica(dataloader, mutation_number): +def inference_Galactica(dataloader): if args.verbose: L = tqdm(dataloader) else: L = dataloader - galactica_tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b") - galactica_model = OPTForCausalLM.from_pretrained("facebook/galactica-1.3b", device_map="auto") + galactica_tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b", cache_dir="../../data/temp_Galactica") + galactica_model = OPTForCausalLM.from_pretrained("facebook/galactica-1.3b", device_map="auto", cache_dir="../../data/temp_Galactica") input_protein_sequence_list, edited_protein_sequence_list = [], [] for batch_idx, batch in enumerate(L): @@ -83,7 +83,6 @@ def inference_Galactica(dataloader, mutation_number): parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--num_workers", type=int, default=8) - parser.add_argument("--mutation_number", type=int, default=1) parser.add_argument("--editing_task", type=str, default="Villin") parser.add_argument("--dataset_size", type=int, default=None) @@ -114,9 +113,9 @@ def inference_Galactica(dataloader, mutation_number): ##### Load pretrained protein model if args.protein_backbone_model == "ProtBERT": - CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False) + CLAP_protein_tokenizer = 
BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert") elif args.protein_backbone_model == "ProtBERT_BFD": - CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False) + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert_BFD") protein_dim = 1024 ##### load protein sequence @@ -128,7 +127,7 @@ def inference_Galactica(dataloader, mutation_number): protein_max_sequence_len=args.protein_max_sequence_len) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) - input_protein_sequence_list, edited_protein_sequence_list = inference_Galactica(dataloader, mutation_number=args.mutation_number) + input_protein_sequence_list, edited_protein_sequence_list = inference_Galactica(dataloader) if args.output_folder is None: exit() @@ -147,7 +146,7 @@ def inference_Galactica(dataloader, mutation_number): output_dataset = ProteinSeqDataset(edited_protein_sequence_list, eval_protein_tokenizer, args.protein_max_sequence_len) output_dataloader = DataLoader(output_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) output_eval_list = evaluate(output_dataloader, eval_prediction_model, device, args) - file_path = os.path.join(args.output_folder, "{editing_task}_output_{mutation_number}.png".format(editing_task=args.editing_task, mutation_number=args.mutation_number)) + file_path = os.path.join(args.output_folder, "{editing_task}_output_1.png".format(editing_task=args.editing_task)) analyze(output_eval_list, args, file_path) @@ -170,13 +169,13 @@ def inference_Galactica(dataloader, mutation_number): if output_eval < input_eval: hit += 1 - elif args.editing_task in ["Villin", "Pin1", "hYAP65"]: - if args.text_prompt_id in [101, 102]: - if output_eval < input_eval: - hit += 1 - elif args.text_prompt_id in [201, 202]: + elif args.editing_task in ["Villin", "Pin1"]: + if args.text_prompt_id in [101]: if output_eval > input_eval: hit += 1 + elif args.text_prompt_id in [201]: + if output_eval < input_eval: + hit += 1 hit_ratio = 100. 
* hit / total print("hit: {}".format(hit)) diff --git a/examples/downstream_Editing/step_01_editing_latent_interpolation.py b/examples/downstream_Editing/step_01_editing_latent_interpolation.py index f6634bc..439cc42 100644 --- a/examples/downstream_Editing/step_01_editing_latent_interpolation.py +++ b/examples/downstream_Editing/step_01_editing_latent_interpolation.py @@ -271,12 +271,7 @@ def inference(dataloader, theta, text_condition_repr, device): if args.decoder_distribution in ["T5Decoder"]: protein_decoder_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False, chache_dir="../../data/temp_pretrained_t5_base") else: - protein_decoder_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, chache_dir="../../data/temp_pretrained_PotBert") - # print("protein_decoder_tokenizer pad_token_id", protein_decoder_tokenizer.pad_token_id) - # print("protein_decoder_tokenizer sep_token_id", protein_decoder_tokenizer.sep_token_id) - # print("protein_decoder_tokenizer eos_token_id", protein_decoder_tokenizer.eos_token_id) - # print(CLAP_protein_tokenizer.get_vocab()) - # print(protein_decoder_tokenizer.get_vocab()) + protein_decoder_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, chache_dir="../../data/temp_pretrained_ProtBert") ##### Load pretrained decoder model if args.decoder_distribution == "MultinomialDiffusion": @@ -372,13 +367,13 @@ def inference(dataloader, theta, text_condition_repr, device): if output_eval < input_eval: hit += 1 - elif args.editing_task in ["Villin", "Pin1", "hYAP65"]: - if args.text_prompt_id in [101, 102]: - if output_eval < input_eval: - hit += 1 - elif args.text_prompt_id in [201, 202]: + elif args.editing_task in ["Villin", "Pin1"]: + if args.text_prompt_id in [101]: if output_eval > input_eval: hit += 1 + elif args.text_prompt_id in [201]: + if output_eval < input_eval: + hit += 1 hit_ratio = 100. * hit / total print("hit: {}".format(hit)) diff --git a/examples/downstream_Editing/step_01_editing_latent_optimization.py b/examples/downstream_Editing/step_01_editing_latent_optimization.py index 291e62b..370acb0 100644 --- a/examples/downstream_Editing/step_01_editing_latent_optimization.py +++ b/examples/downstream_Editing/step_01_editing_latent_optimization.py @@ -417,13 +417,13 @@ def inference(dataloader, text_prompt_CLAP_repr, device): if output_eval < input_eval: hit += 1 - elif args.editing_task in ["Villin", "Pin1", "hYAP65"]: - if args.text_prompt_id in [101, 102]: - if output_eval < input_eval: - hit += 1 - elif args.text_prompt_id in [201, 202]: + elif args.editing_task in ["Villin", "Pin1"]: + if args.text_prompt_id in [101]: if output_eval > input_eval: hit += 1 + elif args.text_prompt_id in [201]: + if output_eval < input_eval: + hit += 1 hit_ratio = 100. 
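Across these editing scripts the hit ratio simply counts how often the oracle score moves in the direction named by the text prompt (for the Villin/Pin1 prompts above, id 101 asks for an increase and id 201 for a decrease). A compact restatement of that bookkeeping, as a hypothetical helper:

```
def hit_ratio(input_scores, output_scores, prompt_id):
    # prompt 101: edited score should go up; prompt 201: it should go down.
    hits = sum(
        (o > i) if prompt_id == 101 else (o < i)
        for i, o in zip(input_scores, output_scores)
    )
    return 100.0 * hits / len(input_scores)
```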
* hit / total print("hit: {}".format(hit)) diff --git a/examples/downstream_Editing/step_01_evaluate_region.py b/examples/downstream_Editing/step_01_evaluate_region.py new file mode 100644 index 0000000..e545a3a --- /dev/null +++ b/examples/downstream_Editing/step_01_evaluate_region.py @@ -0,0 +1,205 @@ +import os +import random +import argparse +import numpy as np +from tqdm import tqdm +import string +import re + +import torch +import torch.nn as nn + +from transformers import BertModel, BertTokenizer +from torch.utils.data import DataLoader + +from utils import ProteinDataset, ProteinSeqDataset, text_prompt_dict, load_oracle_evaluator, evaluate, analyze + + +from transformers import AutoTokenizer, EsmForProteinFolding +from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein +from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 + + +@torch.no_grad() +def save_PDB_list(PDB_list, idx_list, PDB_output_folder): + for PDB, idx in zip(PDB_list, idx_list): + file_name = os.path.join(PDB_output_folder, "{}.txt".format(idx)) + f = open(file_name, "w") + f.write("".join(PDB)) + return + + +def convert_outputs_to_pdb(outputs): + final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs) + outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()} + final_atom_positions = final_atom_positions.cpu().numpy() + final_atom_mask = outputs["atom37_atom_exists"] + pdbs = [] + for i in range(outputs["aatype"].shape[0]): + aa = outputs["aatype"][i] + pred_pos = final_atom_positions[i] + mask = final_atom_mask[i] + resid = outputs["residue_index"][i] + 1 + pred = OFProtein( + aatype=aa, + atom_positions=pred_pos, + atom_mask=mask, + residue_index=resid, + b_factors=outputs["plddt"][i], + chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None, + ) + pdb = to_pdb(pred) + pdbs.append(pdb) + return pdbs + + +@torch.no_grad() +def evaluate_folding(protein_sequence_list): + PDB_data_list, idx_list, plddt_value_list = [], [], [] + for idx, protein_sequence in enumerate(protein_sequence_list): + print("protein_sequence", protein_sequence) + + tokenized_input = folding_tokenizer(protein_sequence, return_tensors="pt", add_special_tokens=False)['input_ids'] + tokenized_input = tokenized_input.to(device) + + output = folding_model(tokenized_input) + plddt_value = output["plddt"].squeeze(0) + tokenized_input = tokenized_input.squeeze(0) + + plddt_value_total = 0 + L = plddt_value.shape[0] + for i in range(L): + plddt_value_total += plddt_value[i][tokenized_input[i]] + plddt_value_mean = (plddt_value_total / L).item() + + PDB_list = convert_outputs_to_pdb(output) + + PDB_data_list.extend(PDB_list) + idx_list.append(idx) + plddt_value_list.append(plddt_value_mean) + + return PDB_data_list, idx_list, plddt_value_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--mutation_number", type=int, default=1) + + parser.add_argument("--editing_task", type=str, default="region") + parser.add_argument("--dataset_size", type=int, default=None) + parser.add_argument("--text_prompt_id", type=int, default=101) + + parser.add_argument("--output_folder", type=str, default=None) + parser.add_argument("--output_text_file_path", type=str, default=None) + + 
parser.add_argument("--verbose", dest="verbose", action="store_true") + parser.set_defaults(verbose=False) + + parser.add_argument("--protein_backbone_model", type=str, default="ProtBERT_BFD", choices=["ProtBERT", "ProtBERT_BFD"]) + parser.add_argument("--protein_max_sequence_len", type=int, default=512) + + args = parser.parse_args() + print("arguments", args) + + assert args.editing_task in ["region"] + + random.seed(args.seed) + os.environ['PYTHONHASHSEED'] = str(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") + + folding_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1", cache_dir="../../data/temp_pretrained_ESMFold") + folding_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", cache_dir="../../data/temp_pretrained_ESMFold").to(device) + + ##### Load pretrained protein model + if args.protein_backbone_model == "ProtBERT": + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert") + elif args.protein_backbone_model == "ProtBERT_BFD": + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert_BFD") + protein_dim = 1024 + + ##### load protein sequence + dataset_file_path = os.path.join(text_prompt_dict[args.editing_task]["data_folder"], "preprocessed_data.csv") + dataset = ProteinDataset( + dataset_file_path=dataset_file_path, + dataset_size=args.dataset_size, + protein_tokenizer=CLAP_protein_tokenizer, + protein_max_sequence_len=args.protein_max_sequence_len) + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + if args.output_text_file_path is None: + args.output_text_file_path = os.path.join(args.output_folder, "step_01_editing.txt") + + f = open(args.output_text_file_path, "r") + input_protein_sequence_list, edited_protein_sequence_list = [], [] + for line in f.readlines(): + if line.startswith("input"): + line = line.strip().split(",") + input_protein_sequence_list.append(line[1]) + value = line[2].replace("[", "").replace("]", "") + elif line.startswith("output"): + line = line.strip().split(",") + edited_protein_sequence = line[1] + edited_protein_sequence = re.sub(r"[UZOB]", "X", edited_protein_sequence) + edited_protein_sequence_list.append(edited_protein_sequence) + value = line[2].replace("[", "").replace("]", "") + + neo_input_protein_sequence_list, neo_edited_protein_sequence_list = [], [] + for a,c in zip(input_protein_sequence_list, edited_protein_sequence_list): + if len(c) == 0: + continue + neo_input_protein_sequence_list.append(a) + neo_edited_protein_sequence_list.append(c) + input_protein_sequence_list, edited_protein_sequence_list = neo_input_protein_sequence_list, neo_edited_protein_sequence_list + + input_PDB_list, idx_list, input_plddt_list = evaluate_folding(input_protein_sequence_list) + PDB_output_folder = os.path.join(args.output_folder, "input_PDB") + os.makedirs(PDB_output_folder, exist_ok = True) + save_PDB_list(input_PDB_list, idx_list, PDB_output_folder) + + output_PDB_list, idx_list, output_plddt_list = evaluate_folding(edited_protein_sequence_list) + PDB_output_folder = os.path.join(args.output_folder, 
"output_PDB") + os.makedirs(PDB_output_folder, exist_ok = True) + save_PDB_list(output_PDB_list, idx_list, PDB_output_folder) + + ##### compare + evaluation_output_file_path = os.path.join(args.output_folder, "step_01_evaluate.txt") + f = open(evaluation_output_file_path, 'w') + plddt_hit, total = 0, 0 + for input_protein, input_plddt, edited_protein, output_plddt in zip(input_protein_sequence_list, input_plddt_list, edited_protein_sequence_list, output_plddt_list): + print('input,{},{}'.format(input_protein, input_plddt), file=f) + print('output,{},{}'.format(edited_protein, output_plddt), file=f) + + total += 1 + + if args.text_prompt_id in [101]: + if output_plddt > input_plddt: + plddt_hit += 1 + elif args.text_prompt_id in [201]: + if output_plddt < input_plddt: + plddt_hit += 1 + else: + raise ValueError("No valid prompt id {}".format(args.text_prompt_id)) + + if total > 0: + plddt_hit_ratio = 100. * plddt_hit / total + print("#1 pLDDT hit: {}".format(plddt_hit)) + print("#1 pLDDT total: {}".format(total)) + print("#1 pLDDT hit ratio: {}".format(plddt_hit_ratio)) + + total = len(dataset) + + plddt_hit_ratio = 100. * plddt_hit / total + print("pLDDT hit: {}".format(plddt_hit)) + print("pLDDT total: {}".format(total)) + print("pLDDT hit ratio: {}".format(plddt_hit_ratio)) diff --git a/examples/downstream_Editing/step_01_evaluate_stability.py b/examples/downstream_Editing/step_01_evaluate_stability.py index a4f2f5d..645bead 100644 --- a/examples/downstream_Editing/step_01_evaluate_stability.py +++ b/examples/downstream_Editing/step_01_evaluate_stability.py @@ -60,7 +60,7 @@ def evaluate_folding(protein_sequence_list): print("protein_sequence", protein_sequence) tokenized_input = folding_tokenizer(protein_sequence, return_tensors="pt", add_special_tokens=False)['input_ids'] - tokenized_input = tokenized_input.cuda() + tokenized_input = tokenized_input.to(device) output = folding_model(tokenized_input) plddt_value = output["plddt"].squeeze(0) @@ -118,14 +118,14 @@ def evaluate_folding(protein_sequence_list): device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") - folding_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") - folding_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True).to(device) + folding_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1", cache_dir="../../data/temp_pretrained_ESMFold") + folding_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", cache_dir="../../data/temp_pretrained_ESMFold").to(device) ##### Load pretrained protein model if args.protein_backbone_model == "ProtBERT": - CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False) + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert") elif args.protein_backbone_model == "ProtBERT_BFD": - CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False) + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert_BFD") protein_dim = 1024 ##### load protein sequence @@ -186,16 +186,18 @@ def evaluate_folding(protein_sequence_list): total += 1 - if args.text_prompt_id in [101, 102]: + if args.text_prompt_id in [101]: if output_eval > input_eval: eval_hit += 1 if output_plddt > input_plddt: plddt_hit += 1 - elif args.text_prompt_id 
in [201, 202]: + elif args.text_prompt_id in [201]: if output_eval < input_eval: eval_hit += 1 if output_plddt < input_plddt: plddt_hit += 1 + else: + raise ValueError("No valid prompt id {}".format(args.text_prompt_id)) if total > 0: eval_hit_ratio = 100. * eval_hit / total diff --git a/examples/downstream_Editing/step_01_evaluate_structure.py b/examples/downstream_Editing/step_01_evaluate_structure.py index b9da823..4cc56f2 100644 --- a/examples/downstream_Editing/step_01_evaluate_structure.py +++ b/examples/downstream_Editing/step_01_evaluate_structure.py @@ -79,7 +79,7 @@ def evaluate_folding(protein_sequence_list): for idx, protein_sequence in enumerate(tqdm(protein_sequence_list)): tokenized_input = folding_tokenizer(protein_sequence, return_tensors="pt", add_special_tokens=False)['input_ids'] - tokenized_input = tokenized_input.cuda() + tokenized_input = tokenized_input.to(device) output = folding_model(tokenized_input) @@ -128,15 +128,15 @@ def evaluate_folding(protein_sequence_list): device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") - folding_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") - folding_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True).to(device) + folding_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1", cache_dir="../../data/temp_pretrained_ESMFold") + folding_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True, cache_dir="../../data/temp_pretrained_ESMFold").to(device) folding_model.trunk.set_chunk_size(512) ##### Load pretrained protein model if args.protein_backbone_model == "ProtBERT": - CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False) + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert") elif args.protein_backbone_model == "ProtBERT_BFD": - CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False) + CLAP_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False, cache_dir="../../data/temp_pretrained_ProtBert_BFD") protein_dim = 1024 ##### load protein sequence diff --git a/examples/downstream_Editing/test.py b/examples/downstream_Editing/test.py deleted file mode 100644 index 2c28f54..0000000 --- a/examples/downstream_Editing/test.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import random -import argparse -import numpy as np -from tqdm import tqdm -import string -import mdtraj as md - -import torch -import torch.nn as nn - -from transformers import BertModel, BertTokenizer -from torch.utils.data import DataLoader - -from utils import ProteinDataset, ProteinSeqDataset, text_prompt_dict, load_oracle_evaluator, evaluate, analyze - -from transformers import AutoTokenizer, EsmForProteinFolding -from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein -from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 - - - -# traj = md.load_pdb("0.pdb") -traj = md.load_pdb("../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5/step_06_AE_1e-3_3/downstream_Editing_latent_optimization/alpha_prompt_101_lambda_0.1_num_repeat_16_oracle_text_T_2/input_PDB/0.pdb") -print(traj) - -pdb_ss = md.compute_dssp(traj, simplified=True)[0] # (L, ) -print(pdb_ss) diff --git a/examples/downstream_Editing/utils.py 
b/examples/downstream_Editing/utils.py index dfa15b1..08eeb92 100644 --- a/examples/downstream_Editing/utils.py +++ b/examples/downstream_Editing/utils.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch.utils.data import Dataset from tqdm import tqdm -from transformers import BertTokenizer +from transformers import BertTokenizerFast from ProteinDT.TAPE_benchmark.models import BertForTokenClassification2, BertForSequenceClassification2 import matplotlib import matplotlib.pyplot as plt @@ -27,27 +27,18 @@ }, "Villin": { 101: "modify the amino acid sequence to have higher stability", - 102: "modify the amino acid sequence to have fewer intrinsically disordered regions", 201: "modify the amino acid sequence to have lower stability", - 202: "modify the amino acid sequence to have more intrinsically disordered regions", "data_folder": "datasets_and_checkpoints/stability/Villin", }, "Pin1": { 101: "modify the amino acid sequence to have higher stability", - 102: "modify the amino acid sequence to have fewer intrinsically disordered regions", 201: "modify the amino acid sequence to have lower stability", - 202: "modify the amino acid sequence to have more intrinsically disordered regions", "data_folder": "datasets_and_checkpoints/stability/Pin1", }, - "hYAP65": { - 101: "modify the amino acid sequence to have higher stability", - 102: "modify the amino acid sequence to have fewer intrinsically disordered regions", - 201: "modify the amino acid sequence to have lower stability", - 202: "modify the amino acid sequence to have more intrinsically disordered regions", - "data_folder": "datasets_and_checkpoints/stability/hYAP65", - }, - "stability": { - "data_folder": "datasets_and_checkpoints/stability_test" + "region": { + 101: "modify the amino acid sequence to have more ordered regions", + 201: "modify the amino acid sequence to have more disordered regions", + "data_folder": "datasets_and_checkpoints/region", }, "peptide_binding": { 101: "modify the peptide amino acid sequence to have higher binding affinity with the target protein. The target protein satisfies the following property. 
{}", @@ -126,12 +117,15 @@ def __len__(self): def load_oracle_evaluator(editing_task, device, input_model_path=None): + cache_dir = "../../data/temp_pretrained_ProtBert_BFD" + if editing_task in ["alpha", "beta"]: num_labels = 3 eval_prediction_model = BertForTokenClassification2.from_pretrained( "Rostlab/prot_bert_bfd", mean_output=True, num_labels=num_labels, + cache_dir=cache_dir, ) if input_model_path is None: input_model_path = os.path.join("datasets_and_checkpoints/structure/oracle/pytorch_model_ss3.bin") @@ -142,6 +136,7 @@ def load_oracle_evaluator(editing_task, device, input_model_path=None): "Rostlab/prot_bert_bfd", mean_output=True, num_labels=num_labels, + cache_dir=cache_dir, ) if input_model_path is None: input_model_path = os.path.join("datasets_and_checkpoints/stability/oracle/pytorch_model_stability.bin") @@ -153,15 +148,20 @@ def load_oracle_evaluator(editing_task, device, input_model_path=None): "Rostlab/prot_bert_bfd", mean_output=True, num_labels=num_labels, + cache_dir=cache_dir, ) if input_model_path is None: input_model_path = os.path.join("datasets_and_checkpoints/stability/oracle/pytorch_model_stability.bin") + elif editing_task == "region": + eval_prediction_model = None + print("Loading protein model from {}...".format(input_model_path)) - state_dict = torch.load(input_model_path, map_location='cpu') - eval_prediction_model.load_state_dict(state_dict) - eval_prediction_model = eval_prediction_model.to(device) - eval_protein_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False) + if eval_prediction_model is not None: + state_dict = torch.load(input_model_path, map_location='cpu') + eval_prediction_model.load_state_dict(state_dict) + eval_prediction_model = eval_prediction_model.to(device) + eval_protein_tokenizer = BertTokenizerFast.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False, cache_dir=cache_dir) return eval_prediction_model, eval_protein_tokenizer @@ -200,7 +200,8 @@ def load_editing_dataset_and_loader(args, eval_protein_tokenizer): @torch.no_grad() def evaluate(dataloader, eval_prediction_model, device, args): - eval_prediction_model.eval() + if eval_prediction_model is not None: + eval_prediction_model.eval() L = tqdm(dataloader) result_list = [] @@ -209,18 +210,20 @@ def evaluate(dataloader, eval_prediction_model, device, args): protein_sequence_input_ids = batch["protein_sequence_input_ids"].to(device) protein_sequence_attention_mask = batch["protein_sequence_attention_mask"].to(device) - output = eval_prediction_model(protein_sequence_input_ids, protein_sequence_attention_mask) - logits = output.logits - - if args.editing_task in ["alpha", "beta"]: - pred = logits.argmax(dim=-1, keepdim=False) - pred = torch.where(protein_sequence_attention_mask==1, pred, -1) - pred = (pred == text_prompt_dict[args.editing_task]["target_label"]).sum(-1) + if eval_prediction_model is not None: + output = eval_prediction_model(protein_sequence_input_ids, protein_sequence_attention_mask) + logits = output.logits + + if args.editing_task in ["alpha", "beta"]: + pred = logits.argmax(dim=-1, keepdim=False) + pred = torch.where(protein_sequence_attention_mask==1, pred, -1) + pred = (pred == text_prompt_dict[args.editing_task]["target_label"]).sum(-1) + else: + pred = logits + result_list.append(pred.detach().cpu().numpy()) else: - pred = logits - - result_list.append(pred.detach().cpu().numpy()) + result_list.append(np.array([0 for _ in range(len(protein_sequence_input_ids))])) result_list = np.concatenate(result_list) return 
@@ -250,19 +253,3 @@
     so = torch.sin(omega)
     res = (torch.sin((1.0-theta)*omega)/so).unsqueeze(1) * start + (torch.sin(theta*omega)/so).unsqueeze(1) * end
     return res
-
-
-if __name__ == "__main__":
-    start = torch.Tensor([[0, 0, 0], [1, 1, 1]])
-    end = torch.Tensor([[1, 1, 1], [2, 2, 2]])
-
-    start = torch.Tensor([[1, 1, 1], [2, 2, 2]])
-    end = torch.Tensor([[3, 3, 3], [6, 6, 6]])
-
-    start = torch.Tensor([[1, 1, 1], [1, 2, 2]])
-    end = torch.Tensor([[2, 2, 2], [6, 6, 6]])
-
-    theta_list = [0, 0.1, 0.5, 0.9, 1]
-    for theta in theta_list:
-        interpolation = slerp(theta, start, end)
-        print(theta, interpolation)
\ No newline at end of file
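slerp performs spherical linear interpolation over a batch of vectors: theta=0 reproduces start, theta=1 reproduces end, and intermediate values sweep along the arc between them. A minimal check mirroring the quick test removed above (illustrative only, assuming it runs from examples/downstream_Editing):

import torch
from utils import slerp  # assumes cwd == examples/downstream_Editing

start = torch.Tensor([[1, 1, 1], [2, 2, 2]])
end = torch.Tensor([[3, 3, 3], [6, 6, 6]])
for theta in [0, 0.1, 0.5, 0.9, 1]:
    print(theta, slerp(theta, start, end))  # theta=0 -> start, theta=1 -> end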
"{:.5f}".format(value) - - return value - - -task2hyper = { - "ss3": [ - "3-3e-5-5-2-8-0.08", - ], - "ss8": [ - "3-3e-5-5-2-8-0.08", - # "3-3e-5-5-2-16-0.08", - ], - "contact": [ - # "3-3e-5-10-1-1-0.08", - "3-3e-5-10-1-2-0.08", - ], - "remote_homology": [ - # "3-3e-5-10-1-64-0.08", - "3-3e-5-10-8-8-0.08", - ], - "fluorescence": [ - "3-3e-5-25-4-16-0.0-True", - ], - "stability": [ - # "3-3e-5-1-2-16-0.08", - # "3-3e-5-3-2-16-0.08", - "3-3e-5-5-2-16-0.08", - ], -} - -if __name__ == "__main__": - task_list = [ - "ss3", "ss8", "contact", "remote_homology", "fluorescence", "stability", - ] - pretrained_model_list=[ - "ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5", - "ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-InfoNCE-0.1-batch-9-gpu-8-epoch-5", - # "ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10", - ] - - for pretrained_mode in pretrained_model_list: - row = "ProteinCLAP" - - for task in task_list: - value_list = [] - for hyper in task2hyper[task]: - file_path = os.path.join("../output", pretrained_mode, "downstream_TAPE", task, hyper, "result.txt") - try: - value = extract(file_path, task) - value_list.append(value) - except: - print("% missing {}".format(file_path)) - - if len(value_list) > 0: - optimal_value = max(value_list) - print("task", task, value_list) - row = "{} & {}".format(row, optimal_value) - else: - row = "{} & {}".format(row, "--") - - print("%", pretrained_mode) - row += "\\\\" - print(row) - print() - print() - print() diff --git a/examples/downstream_Text2Protein/step_03_analyze.py b/examples/downstream_Text2Protein/step_03_analyze.py deleted file mode 100644 index 16d8084..0000000 --- a/examples/downstream_Text2Protein/step_03_analyze.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import itertools - - -def extract(file_path): - evaluation_result_list = [] - f = open(file_path, 'r') - for line in f.readlines(): - # print(line) - if not line.startswith("evaluation_T:"): - continue - line = line.replace("accuracy:", ",").replace(":", ",").split(",") - evaluation_T = int(line[1]) - accuracy = float(line[2]) - evaluation_result_list.append([evaluation_T, accuracy]) - - row = "" - for evaluation_result in evaluation_result_list: - row = "{} & {:.2f}".format(row, evaluation_result[1]) - - if "MultinomialDiffusion_RNN" in file_path: - head = "\\ProteinSDE{}-RNN" - elif "MultinomialDiffusion_BertBase" in file_path: - head = "\\ProteinSDE{}-BERT" - else: - head = "AR" - - if "no_use_facilitator" in file_path: - facilitator = "no_facilitator" - else: - facilitator = "facilitator" - - row = head + " & {}".format(facilitator) + row + "\\\\" - print(row) - print() - return - - -def extract_results_Diffusion(pretrained_mode, decoder_distribution, score_network_type): - for lr, hidden_dim, epochs, prob_unconditional in itertools.product(*[lr_list, hidden_dim_list, epochs_list, prob_unconditional_list]): - - step_04_folder = "../../output/{}/step_04_{}_{}_lr_{lr}_hidden_{hidden_dim}_e_{epochs}_unconditional_{prob_unconditional}".format( - pretrained_mode, decoder_distribution, score_network_type, lr=lr, hidden_dim=hidden_dim, epochs=epochs, prob_unconditional=prob_unconditional) - # print("step_04_folder", step_04_folder) - print("% step_04_{}_{}_lr_{lr}_hidden_{hidden_dim}_e_{epochs}_unconditional_{prob_unconditional}".format(decoder_distribution, score_network_type, lr=lr, hidden_dim=hidden_dim, epochs=epochs, prob_unconditional=prob_unconditional)) - - for num_repeat, facilitator, 
diff --git a/examples/downstream_TAPE_analysis.py b/examples/downstream_TAPE_analysis.py
deleted file mode 100644
index 6bb9135..0000000
--- a/examples/downstream_TAPE_analysis.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import os
-
-
-def extract(file_path, task):
-    f_ = open(file_path, "r")
-
-    if task in ["ss3", "ss8"]:
-        for line in f_.readlines():
-            line = line.strip()
-            if line.startswith("cb513"):
-                line = line.replace("{", "").replace("}", "").replace(":", ",")
-                line = line.split(",")
-                value = float(line[4])
-                value = "{:.5f}".format(value)
-
-    elif task == "contact":
-        for line in f_.readlines():
-            line = line.strip()
-            if line.startswith("metrics"):
-                line = line.replace("{", "").replace("}", "").replace(":", ",")
-                line = line.split(",")
-                value = float(line[3])
-                value = "{:.5f}".format(value)
-
-    elif task == "remote_homology":
-        for line in f_.readlines():
-            line = line.strip()
-            if line.startswith("metrics_fold"):
-                line = line.replace("{", "").replace("}", "").replace(":", ",")
-                line = line.split(",")
-                value = float(line[4])
-                value = "{:.5f}".format(value)
-
-    elif task == "fluorescence":
-        for line in f_.readlines():
-            line = line.strip()
-            if line.startswith("metrics"):
-                line = line.replace("{", "").replace("}", "").replace(":", ",")
-                line = line.split(",")
-                value = float(line[3])
-                value = "{:.5f}".format(value)
-
-    elif task == "stability":
-        for line in f_.readlines():
-            line = line.strip()
-            if line.startswith("metrics"):
-                line = line.replace("{", "").replace("}", "").replace(":", ",")
-                line = line.split(",")
-                value = float(line[3])
-                value = "{:.5f}".format(value)
-
-    return value
-
-
-task2hyper = {
-    "ss3": [
-        "3-3e-5-5-2-8-0.08",
-    ],
-    "ss8": [
-        "3-3e-5-5-2-8-0.08",
-        # "3-3e-5-5-2-16-0.08",
-    ],
-    "contact": [
-        # "3-3e-5-10-1-1-0.08",
-        "3-3e-5-10-1-2-0.08",
-    ],
-    "remote_homology": [
-        # "3-3e-5-10-1-64-0.08",
-        "3-3e-5-10-8-8-0.08",
-    ],
-    "fluorescence": [
-        "3-3e-5-25-4-16-0.0-True",
-    ],
-    "stability": [
-        # "3-3e-5-1-2-16-0.08",
-        # "3-3e-5-3-2-16-0.08",
-        "3-3e-5-5-2-16-0.08",
-    ],
-}
-
-if __name__ == "__main__":
-    task_list = [
-        "ss3", "ss8", "contact", "remote_homology", "fluorescence", "stability",
-    ]
-    pretrained_model_list=[
-        "ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5",
-        "ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-InfoNCE-0.1-batch-9-gpu-8-epoch-5",
-        # "ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10",
-    ]
-
-    for pretrained_mode in pretrained_model_list:
-        row = "ProteinCLAP"
-
-        for task in task_list:
-            value_list = []
-            for hyper in task2hyper[task]:
-                file_path = os.path.join("../output", pretrained_mode, "downstream_TAPE", task, hyper, "result.txt")
-                try:
-                    value = extract(file_path, task)
-                    value_list.append(value)
-                except:
-                    print("% missing {}".format(file_path))
-
-            if len(value_list) > 0:
-                optimal_value = max(value_list)
-                print("task", task, value_list)
-                row = "{} & {}".format(row, optimal_value)
-            else:
-                row = "{} & {}".format(row, "--")
-
-        print("%", pretrained_mode)
-        row += "\\\\"
-        print(row)
-        print()
-        print()
-        print()
pretrained_mode in pretrained_mode_list: - extract_results_AR(pretrained_mode, decoder_distribution, score_network_type) - print("\n\n\n") - print("\n\n\n") diff --git a/examples/downstream_Text2Protein/step_04_gather.py b/examples/downstream_Text2Protein/step_04_gather.py deleted file mode 100644 index 61e27b1..0000000 --- a/examples/downstream_Text2Protein/step_04_gather.py +++ /dev/null @@ -1,100 +0,0 @@ -# lines = """ - -# Galactica & -- & --\\ - -# ChatGPT & -- & --\\ - -# % step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1 -# AR & BERT & 49.25 & 27.14 & 97.38 & 91.62\\ - -# % step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0 -# \ProteinSDE{} & RNN & 23.08 & 9.89 & 38.07 & 17.26\\ - -# % step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0 -# \ProteinSDE{} & BERT & 45.26 & 24.21 & 46.94 & 29.59\\ -# """ - -lines = """ - -Galactica & 51.5 & 29.0 & 19.0\\ - -ChatGPT & 38.5 & 23.0 & 15.5\\ - -% step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1 -% num_repeat: 16 01 -% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1/downstream_Retrieval/num_repeat_16_use_prior_inference_01/step_02_inference.out -AR & prior & 97.00 & 91.00 & 83.50\\ - -% num_repeat: 16 01 -% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1/downstream_Retrieval/num_repeat_16_no_use_prior_inference_01/step_02_inference.out -AR & no_prior & 49.00 & 27.00 & 20.00\\ - - - - -% step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0 -% num_repeat: 16 -% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5/step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_use_prior_simplified/step_02_inference.out -\ProteinSDE{}-RNN & prior & 40.50 & 21.50 & 15.00\\ - -% num_repeat: 16 -% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5/step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_no_use_prior_simplified/step_02_inference.out -\ProteinSDE{}-RNN & no_prior & 24.00 & 10.50 & 5.50\\ - - - - - -% step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0 -% num_repeat: 16 -% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_use_prior_simplified/step_02_inference.out -\ProteinSDE{}-BERT & prior & 51.50 & 25.00 & 13.50\\ - -% num_repeat: 16 -% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_no_use_prior_simplified/step_02_inference.out -\ProteinSDE{}-BERT & no_prior & 35.50 & 17.50 & 9.50\\ -""" - - -if __name__ == "__main__": - baseline_results_list = [] - proteinDT_results_list_without_facilitator = [] - proteinDT_results_list_with_facilitator = [] - - for line in lines.split("\n"): - line = line.strip() - if line == "": - continue - - line = line.replace("\\", "") - line = line.split("&") - - if line[0].startswith("Galactica"): - # print(line) - baseline_results_list.append([line[1], 
diff --git a/examples/downstream_Text2Protein/step_04_gather.py b/examples/downstream_Text2Protein/step_04_gather.py
deleted file mode 100644
index 61e27b1..0000000
--- a/examples/downstream_Text2Protein/step_04_gather.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# lines = """
-
-# Galactica & -- & --\\
-
-# ChatGPT & -- & --\\
-
-# % step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1
-# AR & BERT & 49.25 & 27.14 & 97.38 & 91.62\\
-
-# % step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0
-# \ProteinSDE{} & RNN & 23.08 & 9.89 & 38.07 & 17.26\\
-
-# % step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0
-# \ProteinSDE{} & BERT & 45.26 & 24.21 & 46.94 & 29.59\\
-# """
-
-lines = """
-
-Galactica & 51.5 & 29.0 & 19.0\\
-
-ChatGPT & 38.5 & 23.0 & 15.5\\
-
-% step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1
-% num_repeat: 16 01
-% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1/downstream_Retrieval/num_repeat_16_use_prior_inference_01/step_02_inference.out
-AR & prior & 97.00 & 91.00 & 83.50\\
-
-% num_repeat: 16 01
-% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_T5Decoder_T5Base_lr_1e-4_hidden_16_e_10_unconditional_0.1/downstream_Retrieval/num_repeat_16_no_use_prior_inference_01/step_02_inference.out
-AR & no_prior & 49.00 & 27.00 & 20.00\\
-
-
-
-
-% step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0
-% num_repeat: 16
-% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5/step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_use_prior_simplified/step_02_inference.out
-\ProteinSDE{}-RNN & prior & 40.50 & 21.50 & 15.00\\
-
-% num_repeat: 16
-% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1e-1-text-512-1e-5-1e-1-EBM_NCE-0.1-batch-9-gpu-8-epoch-5/step_04_MultinomialDiffusion_RNN_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_no_use_prior_simplified/step_02_inference.out
-\ProteinSDE{}-RNN & no_prior & 24.00 & 10.50 & 5.50\\
-
-
-
-
-
-% step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0
-% num_repeat: 16
-% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_use_prior_simplified/step_02_inference.out
-\ProteinSDE{}-BERT & prior & 51.50 & 25.00 & 13.50\\
-
-% num_repeat: 16
-% ../../output/ProteinDT/ProtBERT_BFD-512-1e-5-1-text-512-1e-5-1-EBM_NCE-1-batch-9-gpu-8-epoch-10/step_04_MultinomialDiffusion_BertBase_lr_1e-4_hidden_32_e_10_unconditional_0/downstream_Retrieval/num_repeat_16_no_use_prior_simplified/step_02_inference.out
-\ProteinSDE{}-BERT & no_prior & 35.50 & 17.50 & 9.50\\
-"""
-
-
-if __name__ == "__main__":
-    baseline_results_list = []
-    proteinDT_results_list_without_facilitator = []
-    proteinDT_results_list_with_facilitator = []
-
-    for line in lines.split("\n"):
-        line = line.strip()
-        if line == "":
-            continue
-
-        line = line.replace("\\", "")
-        line = line.split("&")
-
-        if line[0].startswith("Galactica"):
-            # print(line)
-            baseline_results_list.append([line[1], line[2], line[3]])
-
-        elif line[0].startswith("ChatGPT"):
-            # print(line)
-            baseline_results_list.append([line[1], line[2], line[3]])
-
-        elif "AR" in line[0] or "ProteinSDE" in line[0]:
-            if "no_prior" in line[1]:
-                proteinDT_results_list_without_facilitator.append([line[2], line[3], line[4]])
-            else:
-                proteinDT_results_list_with_facilitator.append([line[2], line[3], line[4]])
-
-    T_list = [4, 10, 20]
-    for T_idx, T in enumerate(T_list):
-        row = "T = {}".format(T)
-
-        for baseline_results in baseline_results_list:
-            row = "{} & {}".format(row, baseline_results[T_idx])
-
-        for proteinDT_results in proteinDT_results_list_without_facilitator:
-            row = "{} & {}".format(row, proteinDT_results[T_idx])
-
-        for proteinDT_results in proteinDT_results_list_with_facilitator:
-            row = "{} & {}".format(row, proteinDT_results[T_idx])
-
-        row = "{} \\\\".format(row)
-        print(row)
\ No newline at end of file