Skip to content

Commit

Permalink
2nd release (#9)
Browse files Browse the repository at this point in the history
- Improved the way splitting works
- Added SLURM scripts to simplify generating releases in the future
- Modified a bunch of scripts to simplify making new releases
- Moved the mmCIF generation code to rna3db/parser.py
- Added pre-commits
- Modified CITATION.cff to reflect JMB article
- Add READMEs for scripts
- Fix setup.py
  • Loading branch information
marcellszi authored May 16, 2024
1 parent cb82659 commit 10f7e07
Show file tree
Hide file tree
Showing 15 changed files with 767 additions and 256 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ data/
[._]ss[a-gi-z]
[._]sw[a-p]

**/.DS_Store
**/.DS_Store
19 changes: 19 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
repos:
- repo: local
hooks:
- id: unittests
name: run unit tests
entry: python -m unittest
language: system
pass_filenames: false
args: ["discover"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.3.0
hooks:
- id: black
22 changes: 22 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "RNA3DB: A dataset for training and benchmarking deep learning models for RNA structure prediction"
version: 1.1
authors:
- given-names: "Marcell"
family-names: "Szikszai"
Expand All @@ -15,3 +16,24 @@ authors:
- given-names: "Elena"
  family-names: "Rivas"
url: "https://github.com/marcellszi/rna3db"
doi: "10.1016/j.jmb.2024.168552"
date-released: 2024-04-26
preferred-citation:
type: article
authors:
- given-names: "Marcell"
family-names: "Szikszai"
- given-names: "Marcin"
family-names: Magnus
- given-names: "Siddhant"
family-names: "Sanghi"
- given-names: "Sachin"
family-names: "Kadyan"
- given-names: "Nazim"
family-names: "Bouatta"
- given-names: "Elena"
family-names: "Rivas"
doi: "10.1016/j.jmb.2024.168552"
journal: "Journal of Molecular Biology"
title: "RNA3DB: A structurally-dissimilar dataset split for training and benchmarking deep learning models for RNA structure prediction"
year: 2024
23 changes: 19 additions & 4 deletions rna3db/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,16 @@ def main(args):
args.input, args.output, args.tbl_dir, args.structural_e_value_cutoff
)
elif args.command == "split":
split(args.input, args.output, args.train_percentage, args.force_zero_test)
split(
args.input,
args.output,
splits=[
args.train_ratio,
args.valid_ratio,
1 - args.train_ratio - args.valid_ratio,
],
force_zero_last=args.force_zero_test,
)
else:
raise ValueError

Expand Down Expand Up @@ -246,10 +255,16 @@ def main(args):
split_parser.add_argument("input", type=Path, help="Input JSON file")
split_parser.add_argument("output", type=Path, help="Output JSON file")
split_parser.add_argument(
"--train_percentage",
"--train_ratio",
type=float,
default=0.3,
help="Percentage of data for the train set",
default=0.7,
help="Ratio of data to use for the training set",
)
split_parser.add_argument(
"--valid_ratio",
type=float,
default=0.0,
help="Ratio of the data to use for the validation set",
)
split_parser.add_argument(
"--force_zero_test",
Expand Down
180 changes: 180 additions & 0 deletions rna3db/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def __getitem__(self, idx):
def __len__(self):
    # Length of a chain == number of residues it holds (including residues
    # flagged as missing — only `has_atoms` distinguishes resolved ones).
    return len(self.residues)

@property
def has_atoms(self):
    # True if at least one residue in the chain is not marked missing,
    # i.e. the chain carries some resolved atom coordinates.
    return any([not res.is_missing for res in self])

def add_residue(self, res: Residue):
"""Add a residue to the chain.
Expand Down Expand Up @@ -341,6 +345,182 @@ def __repr__(self):
f"resolution={self.resolution}, release_date={self.release_date}, structure_method={self.structure_method})"
)

@staticmethod
def _gen_mmcif_loop_str(name: str, headers: Sequence[str], values: Sequence[tuple]):
    """Render one mmCIF ``loop_`` section as a string.

    Produces a ``#`` separator, the ``loop_`` keyword, one ``_<name>.<header>``
    line per column, then one row per tuple in `values`, with every cell
    left-padded to the widest entry of its column (plus a trailing space).

    Args:
        name: category name; an underscore is prepended automatically.
        headers: column (item) names, in output order.
        values: row tuples, each aligned positionally with `headers`.

    Returns:
        The complete loop section, terminated by a newline.
    """
    lines = ["#", "loop_"]
    lines += [f"_{name}.{col}" for col in headers]

    # One pass to find the widest cell per column so rows align.
    widths = {col: 0 for col in headers}
    for row in values:
        for col, cell in zip(headers, row):
            widths[col] = max(widths[col], len(str(cell)))

    # Second pass: emit each row, every cell left-justified to its column width.
    for row in values:
        cells = (f"{str(cell):<{widths[col]}} " for col, cell in zip(headers, row))
        lines.append("".join(cells))

    return "\n".join(lines) + "\n"

def write_mmcif_chain(self, output_path, author_id):
if not self[author_id].has_atoms:
raise ValueError(
f"Did not find any atoms for chain {author_id}. Did you set `include_atoms=True`?"
)
# extract needed info
entity_poly_seq_data = []
atom_site_data = []
for i, res in enumerate(self[author_id]):
entity_poly_seq_data.append((1, res.index + 1, res.code, "n"))
for idx, (atom_name, atom_coords) in enumerate(res.atoms.items()):
x, y, z = atom_coords
atom_site_data.append(
(
"ATOM",
idx + 1,
atom_name[0],
atom_name,
".",
res.code,
author_id,
"?",
i + 1,
"?",
x,
y,
z,
1.0,
0.0,
"?",
i + 1,
res.code,
author_id,
atom_name,
1,
)
)

# build required strings
header_str = (
f"# generated by rna3db\n"
f"#\n"
f"data_{self.pdb_id}_{author_id}\n"
f"_entry.id {self.pdb_id}_{author_id}\n"
f"_pdbx_database_status.recvd_initial_deposition_date {self.release_date}\n"
f"_exptl.method '{self.structure_method.upper()}'\n"
f"_reflns.d_resolution_high {self.resolution}\n"
f"_entity_poly.pdbx_seq_one_letter_code_can {self[author_id].sequence}\n"
)
struct_asym_str = StructureFile._gen_mmcif_loop_str(
"_struct_asym",
[
"id",
"pdbx_blank_PDB_chainid_flag",
"pdbx_modified",
"entity_id",
"details",
],
[("A", "N", "N", 1, "?")],
)
chem_comp_str = StructureFile._gen_mmcif_loop_str(
"_chem_comp",
[
"id",
"type",
"mon_nstd_flag",
"pdbx_synonyms",
"formula",
"formula_weight",
],
[
(
"A",
"'RNA linking'",
"y",
'"ADENOSINE-5\'-MONOPHOSPHATE"',
"?",
"'C10 H14 N5 O7 P'",
347.221,
),
(
"C",
"'RNA linking'",
"y",
'"CYTIDINE-5\'-MONOPHOSPHATE"',
"?",
"'C9 H14 N3 O8 P'",
323.197,
),
(
"G",
"'RNA linking'",
"y",
'"GUANOSINE-5\'-MONOPHOSPHATE"',
"?",
"'C9 H13 N2 O9 P'",
363.221,
),
(
"U",
"'RNA linking'",
"y",
'"URIDINE-5\'-MONOPHOSPHATE"',
"?",
"'C9 H13 N2 O9 P'",
324.181,
),
("T", "'RNA linking'", "y", '"T"', "?", "''", 0),
("N", "'RNA linking'", "y", '"N"', "?", "''", 0),
],
)
entity_poly_seq_str = StructureFile._gen_mmcif_loop_str(
"entity_poly_seq",
[
"entity_id",
"num",
"mon_id",
"heter",
],
entity_poly_seq_data,
)
atom_site_str = StructureFile._gen_mmcif_loop_str(
"atom_site",
[
"group_PDB",
"id",
"type_symbol",
"label_atom_id",
"label_alt_id",
"label_comp_id",
"label_asym_id",
"label_entity_id",
"label_seq_id",
"pdbx_PDB_ins_code",
"Cartn_x",
"Cartn_y",
"Cartn_z",
"occupancy",
"B_iso_or_equiv",
"pdbx_formal_charge",
"auth_seq_id",
"auth_comp_id",
"auth_asym_id",
"auth_atom_id",
"pdbx_PDB_model_num",
],
atom_site_data,
)

# write to file
with open(output_path, "w") as f:
f.write(header_str)
f.write(struct_asym_str)
f.write(chem_comp_str)
f.write(entity_poly_seq_str)
f.write(atom_site_str)


class mmCIFParser:
def __init__(
Expand Down
79 changes: 51 additions & 28 deletions rna3db/split.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
import random

from typing import Sequence

from rna3db.utils import PathLike, read_json, write_json


def find_optimal_components(lengths_dict, capacity):
    """Pick the subset of components whose total length best fills `capacity`.

    This is a 0/1-knapsack DP where each component's weight and value are
    both its length: it maximizes the total length packed without exceeding
    `capacity`, and returns the names of the chosen components.

    Args:
        lengths_dict: mapping of component name -> integer length.
        capacity: maximum total length the selection may reach.

    Returns:
        Set of component names achieving the best total length <= capacity.
    """
    names = list(lengths_dict.keys())
    sizes = list(lengths_dict.values())

    # best[c]  = largest total length achievable within capacity c
    # picks[c] = the component names realizing best[c]
    best = [0] * (capacity + 1)
    picks = [[] for _ in range(capacity + 1)]

    for name, size in zip(names, sizes):
        # Iterate capacities downward so each component is used at most once.
        for cap in range(capacity, size - 1, -1):
            candidate = best[cap - size] + size
            if candidate > best[cap]:
                best[cap] = candidate
                picks[cap] = picks[cap - size] + [name]

    return set(picks[capacity])


def split(
input_path: PathLike,
output_path: PathLike,
train_size: float = 0.7,
force_zero_test: bool = True,
splits: Sequence[float] = [0.7, 0.0, 0.3],
split_names: Sequence[str] = ["train_set", "valid_set", "test_set"],
shuffle: bool = False,
force_zero_last: bool = False,
):
"""A function that splits a JSON of components into a train/test set.
Expand All @@ -16,35 +37,37 @@ def split(
Args:
input_path (PathLike): path to JSON containing components
output_path (PathLike): path to output JSON
train_size (float): percentage of data to use as training set
force_zero_test (bool): whether to force component_0 into the test set
"""
if sum(splits) != 1.0:
raise ValueError("Sum of splits must equal 1.0.")

# read json
cluster_json = read_json(input_path)

# count number of repr sequences
total_repr_clusters = sum(len(v) for v in cluster_json.values())

# figure out which components need to go into training set
train_components = set()
train_set_length = 0
i = 1 if force_zero_test else 0
while train_set_length / total_repr_clusters < train_size:
# skip if it's not a real component (should only happen with 0)
if f"component_{i}" not in cluster_json:
i += 1
continue
train_components.add(f"component_{i}")
train_set_length += len(cluster_json[f"component_{i}"].keys())
i += 1

# test_components are just total-train_components
test_components = set(cluster_json.keys()) - train_components

# actually build JSON
output = {"train_set": {}, "test_set": {}}
for k in sorted(train_components):
output["train_set"][k] = cluster_json[k]
for k in sorted(test_components):
output["test_set"][k] = cluster_json[k]
lengths = {k: len(v) for k, v in cluster_json.items()}
total_repr_clusters = sum(lengths.values())

# shuffle if we want to add randomness
if shuffle:
L = list(zip(component_name, lengths))
random.shuffle(L)
component_name, lengths = zip(*L)
component_name, lengths = list(component_name), list(lengths)

output = {k: {} for k in split_names}

if force_zero_last:
output[split_names[-1]]["component_0"] = cluster_json["component_0"]
lengths.pop("component_0")

capacities = [round(total_repr_clusters * ratio) for ratio in splits]
for name, capacity in zip(split_names, capacities):
components = find_optimal_components(lengths, capacity)
for k in sorted(components):
lengths.pop(k)
output[name][k] = cluster_json[k]

assert len(lengths) == 0

write_json(output, output_path)
Loading

0 comments on commit 10f7e07

Please sign in to comment.