diff --git a/potts_model.py b/potts_model.py index 39836be..8f52ad3 100644 --- a/potts_model.py +++ b/potts_model.py @@ -20,7 +20,7 @@ import numpy as np -import utils +from slip import utils def _get_shifted_weights(weight_matrix: np.ndarray, @@ -333,14 +333,48 @@ def _potts_energy(self, sequences): return linear_term + quadratic_term +def subset_landscape(sites, bias, couplings, wt_seq): + """Create subset of Potts Model state dict dumped from Mogwai + that only contains specified sites of inteterest. -def load_from_mogwai_npz(filepath, **init_kwargs): + Args: + bias: weight matrix + couplings: pair-wise weight matrix + wt_seq: wildtype sequence + sites: list of site numbers of interest, one-indexed to match MSA + + Returns: + subset of Potts Model matrices, with specified sites only. + """ + # subtract 1 for zero indexing + sites = [p-1 for p in sites] + + # subsection of wt_seq + wt_seq = wt_seq[sites] + + # subsection of bias + bias = bias[sites, :] + + # subsection of couplings + temp = np.empty((len(sites), 20, len(sites), 20)) + + for i, row_pos in enumerate(sites): + for j, col_pos in enumerate(sites): + temp[i, :, j, :] = couplings[row_pos, :, col_pos, :] + + couplings = temp + + return bias, couplings, wt_seq + + +def load_from_mogwai_npz(filepath, sites = [], **init_kwargs): """Load a landscape from a Potts Model state dict dumped from Mogwai. Args: filepath: A path to a .npz file with the following fields: ['weight', 'bias', 'query_seq']. This file is assumed to be a saved state dict from the package mogwai https://github.com/nickbhat/mogwai. + sites (optional): list of site numbers of interest, one-indexed to match MSA **init_kwargs: Kwargs passed to the PottsModel constructor. Returns: @@ -354,6 +388,10 @@ def load_from_mogwai_npz(filepath, **init_kwargs): bias = -1 * state_dict['bias'] wt_seq = state_dict['query_seq'] + # subsect landscape + if len(sites) > 0: + bias, couplings, wt_seq = subset_landscape(sites, bias, couplings, wt_seq) + # Reshape the couplings from Mogwai. L, A, L, A -> L, L, A, A couplings = np.moveaxis(couplings, [0, 1, 2, 3], [0, 2, 1, 3])