Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated potts_model.py #22

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions potts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import numpy as np

import utils
from slip import utils


def _get_shifted_weights(weight_matrix: np.ndarray,
Expand Down Expand Up @@ -333,14 +333,48 @@ def _potts_energy(self, sequences):

return linear_term + quadratic_term

def subset_landscape(sites, bias, couplings, wt_seq):
  """Restrict Potts model parameters to a chosen subset of sites.

  The inputs are expected in the layout used before the couplings are
  reshaped: `bias` indexed as [site, state] and `couplings` indexed as
  [site, state, site, state]. The alphabet size is taken from the input
  arrays rather than assumed, and the dtype of each input is preserved.

  Args:
    sites: Iterable of site numbers of interest, one-indexed to match the MSA.
      The order given is the order of the sites in the returned arrays.
    bias: Weight matrix with the site on axis 0.
    couplings: Pair-wise weight matrix with sites on axes 0 and 2.
    wt_seq: Wildtype sequence, one entry per site.

  Returns:
    A tuple `(bias, couplings, wt_seq)` containing only the requested sites.

  Raises:
    IndexError: If any site is smaller than 1 or larger than the number of
      sites in `bias`.
  """
  bias = np.asarray(bias)
  couplings = np.asarray(couplings)
  wt_seq = np.asarray(wt_seq)

  # Subtract 1 to convert the one-indexed MSA positions to array indices.
  site_idx = np.asarray(list(sites), dtype=np.intp) - 1

  # Reject out-of-range sites explicitly. Without this check, site 0 would
  # become index -1 and silently select the last position.
  num_sites = bias.shape[0]
  if site_idx.size and (site_idx.min() < 0 or site_idx.max() >= num_sites):
    raise IndexError(
        f'sites must be one-indexed and lie within [1, {num_sites}]; '
        f'got {[int(s) + 1 for s in site_idx]}')

  sub_wt_seq = wt_seq[site_idx]
  sub_bias = bias[site_idx, :]

  # Select rows first, then columns. Indexing one site axis at a time keeps
  # the [site, state, site, state] axis order intact.
  sub_couplings = couplings[site_idx][:, :, site_idx]

  return sub_bias, sub_couplings, sub_wt_seq


def load_from_mogwai_npz(filepath, sites = [], **init_kwargs):
"""Load a landscape from a Potts Model state dict dumped from Mogwai.

Args:
filepath: A path to a .npz file with the following fields: ['weight',
'bias', 'query_seq']. This file is assumed to be a saved state dict
from the package mogwai https://github.com/nickbhat/mogwai.
sites (optional): list of site numbers of interest, one-indexed to match MSA
**init_kwargs: Kwargs passed to the PottsModel constructor.

Returns:
Expand All @@ -354,6 +388,10 @@ def load_from_mogwai_npz(filepath, **init_kwargs):
bias = -1 * state_dict['bias']
wt_seq = state_dict['query_seq']

# subsect landscape
if len(sites) > 0:
bias, couplings, wt_seq = subset_landscape(sites, bias, couplings, wt_seq)

# Reshape the couplings from Mogwai. L, A, L, A -> L, L, A, A
couplings = np.moveaxis(couplings, [0, 1, 2, 3], [0, 2, 1, 3])

Expand Down