From 3fe7020346c7edfb2deb3a8b2132767960b29889 Mon Sep 17 00:00:00 2001 From: NikoOinonen <42408893+NikoOinonen@users.noreply.github.com> Date: Mon, 27 Nov 2023 18:21:12 +0200 Subject: [PATCH] Add tests (#35) --- mlspm/cli.py | 8 +- mlspm/data_loading.py | 6 +- mlspm/graph/_molecule_graph.py | 4 +- mlspm/graph/_utils.py | 1 + mlspm/logging.py | 4 - .../training/fit_posnet.py | 9 +- tests/integration_tests/test_train_posnet.py | 263 ++++++++++++++++++ tests/test_cli.py | 15 + tests/test_data_loading.py | 166 +++++++++++ tests/test_graph.py | 229 ++++++++++++++- tests/test_logging.py | 34 +++ tests/test_losses.py | 1 - tests/test_models.py | 8 +- tests/test_utils.py | 52 ++-- tests/test_visualization.py | 88 ++++++ 15 files changed, 838 insertions(+), 50 deletions(-) create mode 100644 tests/integration_tests/test_train_posnet.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_data_loading.py create mode 100644 tests/test_logging.py create mode 100755 tests/test_visualization.py diff --git a/mlspm/cli.py b/mlspm/cli.py index 5770f66..409287f 100644 --- a/mlspm/cli.py +++ b/mlspm/cli.py @@ -1,4 +1,5 @@ import argparse +from typing import Optional def _bool_type(value): @@ -9,10 +10,13 @@ def _bool_type(value): raise KeyError(f"`{value}` can't be interpreted as a boolean.") -def parse_args() -> dict: +def parse_args(argv: Optional[list[str]] = None) -> dict: """ Parse some useful CLI arguments for use in training scripts. + Arguments: + argv: List of argument values. Defaults to ``sys.argv``. + Returns: A dictionary of the argument values. """ @@ -68,5 +72,5 @@ def parse_args() -> dict: parser.add_argument( "--avg_best_epochs", type=int, default=3, help="Number of epochs to average the best validation loss over. Default = 3." ) - args = parser.parse_args() + args = parser.parse_args(argv) return vars(args) diff --git a/mlspm/data_loading.py b/mlspm/data_loading.py index a5524da..4044e03 100644 --- a/mlspm/data_loading.py +++ b/mlspm/data_loading.py @@ -136,7 +136,7 @@ def decode_xyz(key: str, data: Any) -> Tuple[np.ndarray, np.ndarray] | Tuple[Non sw = get_scan_window_from_comment(comment) xyz = [] while line := data.readline().decode("utf-8"): - e, x, y, z, _ = line.strip().split() + e, x, y, z = line.strip().split()[:4] try: e = int(e) except ValueError: @@ -184,7 +184,7 @@ def get_scan_window_from_comment(comment: str) -> np.ndarray: return sw -def _rotate_and_stack(src: Iterable[dict], reverse: bool = True) -> Generator[dict, None, None]: +def _rotate_and_stack(src: Iterable[dict], reverse: bool = False) -> Generator[dict, None, None]: """ Take a sample in dict format and update it with fields containing an image stack, xyz coordinates and scan window. Rotate the images to be xy-indexing convention and stack them into a single array. @@ -194,7 +194,7 @@ def _rotate_and_stack(src: Iterable[dict], reverse: bool = True) -> Generator[di Arguments: src: Iterable of dicts with the fields: - - ``'{000..0xx}.jpg'`` - :class:`PIL.Image.Image` of one slice of the simulation. + - ``'{000..0xx}.{jpg,png}'`` - :class:`PIL.Image.Image` of one slice of the simulation. - ``'xyz'`` - Tuple(:class:`np.ndarray`, :class:`np.ndarray`) of the xyz data and the scan window. reverse: Whether the order of the image stack is reversed. diff --git a/mlspm/graph/_molecule_graph.py b/mlspm/graph/_molecule_graph.py index 60a90fb..7183843 100755 --- a/mlspm/graph/_molecule_graph.py +++ b/mlspm/graph/_molecule_graph.py @@ -276,8 +276,10 @@ def transform_xy( """ Transform atom positions in the xy plane. + Transformations are perfomed in the order: shift, rotate, flip x, flip y + Arguments: - shift: Shift atom positions in xy plane. Performed before rotation and flip. + shift: Shift atom positions in xy plane. rot_xy: Rotate atoms in xy plane by rot_xy degrees around center point. flip_x: Mirror atom positions in x direction with respect to the center point. flip_y: Mirror atom positions in y direction with respect to the center point. diff --git a/mlspm/graph/_utils.py b/mlspm/graph/_utils.py index e0b8e80..b3f3d10 100644 --- a/mlspm/graph/_utils.py +++ b/mlspm/graph/_utils.py @@ -419,6 +419,7 @@ def save_graphs_to_xyzs( Arguments: molecules: Molecule graphs to save. classes: Chemical elements for atom classification. Either atomic numbers of chemical symbols. + The element for each atom in the graph is the first element in the corresponding class. outfile_format: Formatting string for saved files. Sample index is available in variable ``ind``. start_ind: Index where file numbering starts. verbose: Whether to print output information. diff --git a/mlspm/logging.py b/mlspm/logging.py index 5678c50..92cb14a 100644 --- a/mlspm/logging.py +++ b/mlspm/logging.py @@ -204,10 +204,6 @@ def __init__( self._synced_losses = {"train": SyncedLoss(len(self.loss_labels)), "val": SyncedLoss(len(self.loss_labels))} self._init_log(init_epoch) - def __del__(self): - if self.stream is not sys.stdout: - self.stream.close() - def _init_log(self, init_epoch: Optional[int]): log_exists = os.path.isfile(self.log_path) if self.world_size > 1: diff --git a/papers/ice_structure_discovery/training/fit_posnet.py b/papers/ice_structure_discovery/training/fit_posnet.py index f3bd8db..dbae516 100644 --- a/papers/ice_structure_discovery/training/fit_posnet.py +++ b/papers/ice_structure_discovery/training/fit_posnet.py @@ -240,14 +240,7 @@ def run(cfg): lr_decay.step() # Log losses - try: - loss_logger.add_train_loss(loss) - except ValueError as e: - torch.save(model.module.state_dict(), save_path := os.path.join(cfg['run_dir'], 'debug_model.pth')) - with open('debug_data', 'wb') as f: - pickle.dump((X.cpu().numpy(), ref.cpu().numpy()), f) - print(f'Save debug data on rank {cfg["global_rank"]}') - raise e + loss_logger.add_train_loss(loss) if cfg['timings'] and cfg['global_rank'] == 0: torch.cuda.synchronize() diff --git a/tests/integration_tests/test_train_posnet.py b/tests/integration_tests/test_train_posnet.py new file mode 100644 index 0000000..e0c2f97 --- /dev/null +++ b/tests/integration_tests/test_train_posnet.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 + +import os +import random +import shutil +import tarfile +from functools import partial +from pathlib import Path + +import numpy as np +import torch +import webdataset as wds +from torch import nn, optim + +import mlspm.data_loading as dl +import mlspm.preprocessing as pp +from mlspm import graph, utils +from mlspm.cli import parse_args +from mlspm.logging import LossLogPlot +from mlspm.models import PosNet + +from PIL import Image + + +def make_model(device, cfg): + outsize = round((cfg["z_lims"][1] - cfg["z_lims"][0]) / cfg["box_res"][2]) + 1 + model = PosNet( + encode_block_channels=[2, 4, 8, 16], + encode_block_depth=2, + decode_block_channels=[16, 8, 4], + decode_block_depth=1, + decode_block_channels2=[16, 8, 4], + decode_block_depth2=1, + attention_channels=[16, 16, 16], + res_connections=True, + activation="relu", + padding_mode="zeros", + pool_type="avg", + decoder_z_sizes=[5, 10, outsize], + z_outs=[3, 3, 5, 8], + peak_std=cfg["peak_std"], + device=device + ) + criterion = nn.MSELoss(reduction="mean") + optimizer = optim.Adam(model.parameters(), lr=cfg["lr"]) + lr_decay_rate = 1e-5 + lr_decay = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + lr_decay_rate * b)) + return model, criterion, optimizer, lr_decay + + +def make_test_data(cfg): + out_dir = Path(cfg["data_dir"]) + out_dir.mkdir(exist_ok=True) + urls = wds.shardlists.expand_urls(cfg["urls_train"]) + i_sample = 0 + for url in urls: + temp_dir = Path(f"temp_{url}") + temp_dir.mkdir(exist_ok=True) + os.chdir(temp_dir) + with tarfile.open(url, "w") as f: + for _ in range(10): + afm = np.random.randint(0, 255, (64, 64, 8), dtype=np.uint8) + for i in range(afm.shape[-1]): + img_path = f"{i_sample}.{i}.png" + Image.fromarray(afm[:, ::-1, i].T).save(img_path) + f.add(img_path) + xyz = np.random.rand(8, 3) + xyz[:, :2] *= 8 + atoms = np.concatenate([xyz, np.random.randint(1, 10, (8, 1))], axis=1) + xyz_path = f"{i_sample}.xyz" + utils.write_to_xyz(atoms, outfile=xyz_path, comment_str="Scan window: [[0.0 0.0 0.0], [8.0 8.0 1.0]]", verbose=0) + f.add(xyz_path) + i_sample += 1 + os.chdir("..") + (temp_dir / url).rename(out_dir / url) + shutil.rmtree(temp_dir) + + +def apply_preprocessing(batch, cfg): + box_res = cfg["box_res"] + z_lims = cfg["z_lims"] + zmin = cfg["zmin"] + peak_std = cfg["peak_std"] + + X, atoms, scan_windows = [batch[k] for k in ["X", "xyz", "sw"]] + + nz_max = X[0].shape[-1] + nz = random.choice(range(1, nz_max + 1)) + z0 = random.choice(range(0, min(5, nz_max + 1 - nz))) + X = [x[:, :, :, -nz:] for x in X] if z0 == 0 else [x[:, :, :, -(nz + z0) : -z0] for x in X] + + atoms = [a[a[:, -1] != 29] for a in atoms] + pp.top_atom_to_zero(atoms) + xyz = atoms.copy() + mols = [graph.MoleculeGraph(a, []) for a in atoms] + mols, sw = graph.shift_mols_window(mols, scan_windows[0]) + + pp.rand_shift_xy_trend(X, max_layer_shift=0.02, max_total_shift=0.04) + box_borders = graph.make_box_borders(X[0].shape[1:3], res=box_res[:2], z_range=z_lims) + X, mols, box_borders = graph.add_rotation_reflection_graph( + X, mols, box_borders, num_rotations=1, reflections=True, crop=(32, 32), per_batch_item=True + ) + pp.add_norm(X) + pp.add_gradient(X, c=0.3) + pp.add_noise(X, c=0.1, randomize_amplitude=True, normal_amplitude=True) + pp.add_cutout(X, n_holes=5) + + mols = graph.threshold_atoms_bonds(mols, zmin) + ref = graph.make_position_distribution(mols, box_borders, box_res=box_res, std=peak_std) + + return X, [ref], xyz, box_borders + + +def make_webDataloader(cfg): + shard_list = dl.ShardList( + cfg[f"urls_train"], + base_path=cfg["data_dir"], + substitute_param=True, + log=Path(cfg["run_dir"]) / "shards.log", + ) + + dataset = wds.WebDataset(shard_list) + dataset.pipeline.pop() + dataset.append(wds.tariterators.tarfile_to_samples()) + dataset.append(wds.split_by_worker) + dataset.append(wds.decode("pill", dl.decode_xyz)) + dataset.append(dl.rotate_and_stack()) + dataset.append(dl.batched(cfg["batch_size"])) + dataset = dataset.map(partial(apply_preprocessing, cfg=cfg)) + + dataloader = wds.WebLoader( + dataset, + num_workers=cfg["num_workers"], + batch_size=None, + pin_memory=True, + collate_fn=dl.default_collate, + persistent_workers=False, + ) + + return dataset, dataloader + + +def batch_to_device(batch, device): + X, ref, *rest = batch + X = X[0].to(device) + ref = ref[0].to(device) + return X, ref, *rest + + +def run(cfg): + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Create run directory + if not os.path.exists(cfg["run_dir"]): + os.makedirs(cfg["run_dir"]) + + # Define model, optimizer, and loss + model, criterion, optimizer, lr_decay = make_model(device, cfg) + + # Setup checkpointing and load a checkpoint if available + checkpointer = utils.Checkpointer( + model, + optimizer, + additional_data={"lr_params": lr_decay}, + checkpoint_dir=os.path.join(cfg["run_dir"], "Checkpoints/"), + keep_last_epoch=True, + ) + init_epoch = checkpointer.epoch + + # Setup logging + log_file = open(os.path.join(cfg["run_dir"], "batches.log"), "a") + loss_logger = LossLogPlot( + log_path=os.path.join(cfg["run_dir"], "loss_log.csv"), + plot_path=os.path.join(cfg["run_dir"], "loss_history.png"), + loss_labels=cfg["loss_labels"], + loss_weights=cfg["loss_weights"], + print_interval=cfg["print_interval"], + init_epoch=init_epoch, + stream=log_file, + ) + + for epoch in range(cfg["epochs"]): + # Create datasets and dataloaders + _, train_loader = make_webDataloader(cfg) + val_loader = train_loader + + print(f"\n === Epoch {epoch}") + + model.train() + for ib, batch in enumerate(train_loader): + # Transfer batch to device + X, ref, _, _ = batch_to_device(batch, device) + + # Forward + pred, _ = model(X) + loss = criterion(pred, ref) + + # Backward + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + lr_decay.step() + + # Log losses + loss_logger.add_train_loss(loss) + + print(f"Train batch {ib}") + + # Validate + + model.eval() + with torch.no_grad(): + for ib, batch in enumerate(val_loader): + # Transfer batch to device + X, ref, _, _ = batch_to_device(batch, device) + + # Forward + pred, _ = model(X) + loss = criterion(pred, ref) + + loss_logger.add_val_loss(loss) + + print(f"Val batch {ib}") + + # Write average losses to log and report to terminal + loss_logger.next_epoch() + + # Save checkpoint + checkpointer.next_epoch(loss_logger.val_losses[-1][0]) + + # Return to best epoch, and save model weights + checkpointer.revert_to_best_epoch() + print(f"Best validation loss on epoch {checkpointer.best_epoch}: {checkpointer.best_loss}") + + log_file.close() + shutil.rmtree(cfg["run_dir"]) + shutil.rmtree(cfg["data_dir"]) + + +def test_train_posnet(): + # fmt:off + cfg = parse_args( + [ + "--run_dir", "test_train", + "--epochs", "2", + "--batch_size", "4", + "--z_lims", "-1.0", "0.5", + "--zmin", "-1.0", + "--data_dir", "./test_data", + "--urls_train", "train-K-{0..1}_{0..1}.tar", + "--box_res", "0.125", "0.125", "0.100", + "--peak_std", "0.20", + "--lr", "1e-4" + ] + ) + # fmt:on + + make_test_data(cfg) + run(cfg) + + +if __name__ == "__main__": + test_train_posnet() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..389e51b --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,15 @@ + +import pytest + +def test_parse_args(): + from mlspm.cli import parse_args + + args = parse_args(["--train", "false", "--predict", "False", '--test', "true", "--classes", "1,2,3", "4,5,6"]) + + assert args["train"] == False + assert args["predict"] == False + assert args["test"] == True + assert args["classes"] == [[1, 2, 3], [4, 5, 6]] + + with pytest.raises(KeyError): + parse_args(["--train", "fals"]) diff --git a/tests/test_data_loading.py b/tests/test_data_loading.py new file mode 100644 index 0000000..529a917 --- /dev/null +++ b/tests/test_data_loading.py @@ -0,0 +1,166 @@ +import os +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image + + +def test_shardlist(): + from mlspm.data_loading import ShardList + + urls = "test_K-{0..1}_{0..1}" + base_path = "./base/" + log_path = Path("shards.log") + + # Test with subtitution + shard_list = ShardList(urls=urls, base_path=base_path, world_size=1, rank=0, substitute_param=True, log=log_path) + + assert len(list(shard_list)) == 2 + + # Test without subtitution + shard_list = ShardList(urls=urls, base_path=base_path, world_size=1, rank=0, substitute_param=False, log=log_path) + + shards = list(shard_list) + shards_expected = [ + {"url": "./base/test_K-0_0"}, + {"url": "./base/test_K-0_1"}, + {"url": "./base/test_K-1_0"}, + {"url": "./base/test_K-1_1"}, + ] + assert shards == shards_expected + + # Test splitting over ranks + shard_list_rank0 = ShardList(urls=urls, base_path=base_path, world_size=2, rank=0, substitute_param=False, log=log_path) + shard_list_rank1 = ShardList(urls=urls, base_path=base_path, world_size=2, rank=1, substitute_param=False, log=log_path) + + shards = list(shard_list_rank0) + shards_expected = [ + {"url": "./base/test_K-0_0"}, + {"url": "./base/test_K-1_0"}, + ] + assert shards == shards_expected + + shards = list(shard_list_rank1) + shards_expected = [ + {"url": "./base/test_K-0_1"}, + {"url": "./base/test_K-1_1"}, + ] + assert shards == shards_expected + + # Test url list filling (list size not divisible by world_size) + shard_list_rank0 = ShardList(urls=urls, base_path=base_path, world_size=3, rank=0, substitute_param=False, log=log_path) + shard_list_rank1 = ShardList(urls=urls, base_path=base_path, world_size=3, rank=1, substitute_param=False, log=log_path) + shard_list_rank2 = ShardList(urls=urls, base_path=base_path, world_size=3, rank=2, substitute_param=False, log=log_path) + + assert len(list(shard_list_rank0)) == 2 + assert len(list(shard_list_rank1)) == 2 + assert len(list(shard_list_rank2)) == 2 + + # Test url list filling (world size at least twice as big as the list size) + shard_list = ShardList(urls=urls, base_path=base_path, world_size=9, rank=8, substitute_param=False, log=log_path) + assert len(list(shard_list)) == 1 + + with pytest.raises(ValueError): + # With substitute_param=True, requires "K_{num}" in urls + shard_list = ShardList(urls="test_{0..1}", substitute_param=True, log=log_path) + list(shard_list) + + if log_path.exists(): + os.remove(log_path) + + +def test_decode_xyz(): + from mlspm.data_loading import decode_xyz + from mlspm.utils import read_xyzs, write_to_xyz + + xyz_path = "./test_decode.xyx" + # fmt: off + xyz = np.array([ + [ 0.0, 0.1, 0.2, 1], # (x, y, z, element) + [ 0.0, -0.1, -0.2, 6], + [-0.2, 1.5, -2.5, 14]] + ) + # fmt: on + + write_to_xyz(xyz, xyz_path, comment_str="Scan window: [[-1.0 1.0 .2], [10.0 20 15.0]]`") + with open(xyz_path, "rb") as f: + xyz_bytes = f.read() + + os.remove(xyz_path) + + xyz_decoded, scan_window = decode_xyz(".xyz", data=xyz_bytes) + + assert np.allclose(xyz, xyz_decoded) + assert np.allclose(scan_window, np.array([[-1.0, 1.0, 0.2], [10.0, 20.0, 15.0]])) + + a, b = decode_xyz(".png", data=xyz_bytes) + + assert a is None + assert b is None + + +def test_rotate_and_stack(): + from mlspm.data_loading import _rotate_and_stack + + img0 = np.random.randint(0, 255, (10, 10), dtype=np.uint8) + img1 = np.random.randint(0, 255, (10, 10), dtype=np.uint8) + img0_pil = Image.fromarray(img0) + img1_pil = Image.fromarray(img1) + + # fmt: off + xyz = np.array([ + [ 0.0, 0.1, 0.2, 1], # (x, y, z, element) + [ 0.0, -0.1, -0.2, 6], + [-0.2, 1.5, -2.5, 14]] + ) + # fmt: on + sw = np.array([[-1.0, 1.0, 0.2], [10.0, 20.0, 15.0]]) + + src = [{"000.png": img0_pil, "001.png": img1_pil, "xyz": (xyz, sw)}] + + samples = list(_rotate_and_stack(src, reverse=False)) + + sample = samples[0] + X = sample["X"] + xyz_out = sample["xyz"] + sw_out = sample["sw"] + + assert X.shape == (1, 10, 10, 2) + assert sw_out.shape == (1, 2, 3) + assert xyz_out.shape == (3, 4) + + assert np.allclose(X[0, :, ::-1, 0].T, img0) + assert np.allclose(X[0, :, ::-1, 1].T, img1) + assert np.allclose(xyz, xyz_out) + assert np.allclose(sw, sw_out) + +def test_collate_batch(): + + from mlspm.data_loading import _collate_batch + + batch = [ + { + 'X': np.random.randint(0, 255, (1, 10, 10, 5), dtype=np.uint8), + 'Y': np.random.randint(0, 255, (1, 10, 10), dtype=np.uint8), + 'xyz': np.random.rand(3, 4), + 'sw': np.random.rand(1, 2, 3) + } + for _ in range(3) + ] + + batch = _collate_batch(batch) + + X = batch['X'] + Y = batch['Y'] + xyz = batch['xyz'] + sw = batch['sw'] + + assert len(X) == len(Y) == len(sw) == 1 + assert X[0].shape == (3, 10, 10, 5) + assert Y[0].shape == (3, 10, 10) + assert sw[0].shape == (3, 2, 3) + assert len(xyz) == 3 + for a in xyz: + assert a.shape == (3, 4) diff --git a/tests/test_graph.py b/tests/test_graph.py index 14fe15b..bd580a0 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,11 +1,16 @@ -import torch -import pytest +import os +import shutil +from pathlib import Path + import numpy as np +import pytest +import torch +from scipy.spatial.distance import cdist def test_collate_graph(): - from mlspm.graph import MoleculeGraph from mlspm.data_loading import collate_graph + from mlspm.graph import MoleculeGraph # fmt: off @@ -215,7 +220,7 @@ def test_molecule_graph_array(): def test_molecule_graph_remove_atoms(): - from mlspm.graph import MoleculeGraph, Atom + from mlspm.graph import Atom, MoleculeGraph # fmt: off @@ -243,7 +248,7 @@ def test_molecule_graph_remove_atoms(): assert removed == [] new_molecule, removed = molecule.remove_atoms([1]) - removed_expected = [(Atom(np.array([1.0, 1.0, 0.0]), 2), [0, 1, 0])] + removed_expected = [(Atom(np.array([1.0, 1.0, 0.0]), 'H'), [0, 1, 0])] atoms_expected = np.array([ [0.0, 0.0, 0.0, 1], [1.0, 0.0, 0.0, 3], @@ -331,10 +336,80 @@ def test_molecule_graph_add_atom(): assert a == b -def test_GraphSeqStats(): - from mlspm.graph import GraphStats +def test_molecule_graph_transform_xy(): + from mlspm.graph import MoleculeGraph + + # fmt:off + atoms = np.array([ + [0.0, 0.0, 0.0, 1], + [1.0, 1.0, 1.0, 2], + [1.0, 0.0, 0.0, 3], + [2.0, 0.0, -1.0, 4] + ]) + # fmt:on + bonds = [(0, 2), (1, 2), (2, 3)] + molecule = MoleculeGraph(atoms, bonds) + + molecule_transformed = molecule.transform_xy(shift=(1, 1), rot_xy=90, flip_x=True, flip_y=True, center=(1, 1)) + + xyz_transformed = molecule_transformed.array(xyz=True) + # fmt:off + xyz_expected = np.array( + [ + [1.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, -1.0, -1.0] + ] + ) + # fmt:on + + assert np.allclose(xyz_transformed, xyz_expected) + + +def test_molecule_graph_crop_atoms(): + from mlspm.graph import MoleculeGraph + + # fmt:off + atoms = np.array([ + [0.0, 0.0, 0.0, 1], + [1.0, 1.0, 1.0, 2], + [1.0, 0.0, 0.0, 3], + [2.0, 0.0, -1.0, 4] + ]) + # fmt:on + bonds = [(0, 2), (1, 2), (2, 3)] + molecule = MoleculeGraph(atoms, bonds) + box_borders = np.array([[-0.5, -0.5, -0.5], [1.5, 1.5, 0.5]]) + + molecule_cropped = molecule.crop_atoms(box_borders) + + xyz_cropped = molecule_cropped.array(xyz=True) + # fmt:off + xyz_expected = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + ] + ) + # fmt:on + + assert np.allclose(xyz_cropped, xyz_expected) + assert molecule_cropped.bonds == [(0, 1)] + + +def test_molecule_graph_randomize_positions(): from mlspm.graph import MoleculeGraph + molecule = MoleculeGraph(np.zeros((3, 4)), []) + molecule_randomized = molecule.randomize_positions() + + assert not np.allclose(molecule_randomized.array(xyz=True), 0.0) + + +def test_GraphSeqStats(): + from mlspm.graph import GraphStats, MoleculeGraph + classes = [[0], [1], [2]] # fmt:off @@ -434,7 +509,143 @@ def test_GraphSeqStats(): seq_stats = GraphStats(classes=classes, bin_size=1) seq_stats.add_batch(pred, ref) - # seq_stats.plot('./seq_stats') - # seq_stats.report('./seq_stats') + outdir = Path('./test_stats') + seq_stats.plot(outdir) + seq_stats.report(outdir) + shutil.rmtree(outdir) # fmt:on + + +def test_peaks(): + from mlspm.graph import MoleculeGraph, find_gaussian_peaks, make_position_distribution + + # fmt:off + atoms = np.array([ + [2.0, 2.0, 1.0, 0], + [2.0, 6.0, 2.0, 0], + [6.0, 4.0, 2.0, 0], + [4.0, 2.0, 2.0, 0] + ]) + # fmt:on + box_borders = np.array([[0.0, 0.0, 0.0], [8.0, 8.0, 3.0]]) + mols = [MoleculeGraph(atoms, bonds=[])] + std = 0.2 + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + + dist = make_position_distribution(mols, box_borders=box_borders, std=std) + + for device in devices: + dist_d = torch.from_numpy(dist).to(device) + + for method in ["zncc", "mad", "msd", "mad_norm", "msd_norm"]: + if method == "zncc" and device == "cuda": + with pytest.raises(NotImplementedError): + find_gaussian_peaks(dist_d, box_borders, method=method) + continue + + peaks, _, _ = find_gaussian_peaks(dist_d, match_threshold=0.5, box_borders=box_borders, std=std, method=method) + peaks = peaks[0].cpu().numpy() + + # The matched position could be in different order than the original (and undeterministic order on cuda), + # so we test that the minimum matching distance is close to zero. + d = cdist(peaks, atoms[:, :3]).min(axis=1) + assert np.allclose(d, 0.0) + + +def test_shift_mols_window(): + from mlspm.graph import shift_mols_window, MoleculeGraph + + atoms = [np.array([[1.0, 1.0, 0.0, 0], [3.0, 3.0, 1.0, 0]]), np.array([[2.0, 4.0, 1.0, 0], [1.0, 5.0, 1.0, 0]])] + + scan_windows = np.array( + [ + [[0.0, 0.0, 0.0], [4.0, 4.0, 1.0]], + [[1.0, 3.0, 0.0], [5.0, 7.0, 1.0]], + ] + ) + molecules = [MoleculeGraph(atoms[0], []), MoleculeGraph(atoms[1], [])] + + new_molecules, new_scan_window = shift_mols_window(molecules, scan_windows, start=(2, 2)) + + assert np.allclose(new_molecules[0].array(xyz=True), np.array([[3.0, 3.0, 0.0], [5.0, 5.0, 1.0]])) + assert np.allclose(new_molecules[1].array(xyz=True), np.array([[3.0, 3.0, 1.0], [2.0, 4.0, 1.0]])) + assert np.allclose(new_scan_window, np.array([[2.0, 2.0], [6.0, 6.0]])) + + +def test_crop_graph(): + from mlspm.graph import crop_graph, MoleculeGraph + + afm = [np.arange(32).reshape(1, 4, 4, 2)] + print(afm[0][0, :, :, 0]) + print(afm[0][0, :, :, 1]) + mols = [ + MoleculeGraph( + np.array( + [ + [0.0, 0.0, 0.0, 0], + [0.5, 0.5, 0.5, 0], + [1.1, 1.1, 1.0, 0], + ] + ), + [], + ) + ] + start = (1, 1) + size = (2, 2) + box_borders = [[0.0, 0.0, 0.0], [1.5, 1.5, 1.0]] + new_start = (1.0, 1.0) + + afm_cropped, mols_cropped, box_borders_cropped = crop_graph( + afm, mols, start=start, size=size, box_borders=box_borders, new_start=new_start + ) + + assert np.allclose(afm_cropped[0][0, :, :, 0], np.array([[10, 12], [18, 20]])) + assert np.allclose(afm_cropped[0][0, :, :, 1], np.array([[11, 13], [19, 21]])) + assert np.allclose(mols_cropped[0].array(xyz=True), np.array([[1.0, 1.0, 0.5]])) + assert np.allclose(box_borders_cropped, np.array([[1.0, 1.0, 0.0], [1.5, 1.5, 1.0]])) + + +def test_save_graphs_to_xyzs(): + from mlspm.graph import save_graphs_to_xyzs, MoleculeGraph + from mlspm.utils import read_xyzs + + classes = [[1], [6, 8]] + mols = [ + MoleculeGraph( + np.array( + [ + [0.0, 0.0, 0.0, 1], + [0.5, 0.5, 0.5, 6], + [1.0, 1.0, 1.0, 8], + ] + ), + [], + classes=classes, + ) + ] + + save_graphs_to_xyzs(mols, classes=classes) + xyz = read_xyzs(["0_graph.xyz"])[0] + + os.remove("0_graph.xyz") + + assert np.allclose( + xyz, + np.array( + [ + [0.0, 0.0, 0.0, 1], + [0.5, 0.5, 0.5, 6], + [1.0, 1.0, 1.0, 6], + ] + ), + ) + +def test_make_box_borders(): + + from mlspm.graph import make_box_borders + + box_borders = make_box_borders(shape=(100, 100), res=(0.1, 0.2), z_range=(0.0, 1.0)) + assert np.allclose(box_borders, np.array([[0.0, 0.0, 0.0], [9.9, 19.8, 1.0]])) diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..74ad942 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,34 @@ + +import os +import numpy as np + +def test_loss_log_plot(): + from mlspm.logging import LossLogPlot + + loss_log_path = "loss_log.csv" + plot_path = "test_plot.png" + log_path = "test_log.txt" + + info_log = open(log_path, 'w') + + if os.path.exists(loss_log_path): + os.remove(loss_log_path) + loss_log = LossLogPlot(loss_log_path, plot_path, loss_labels=["1", "2"], loss_weights=["", ""], stream=info_log) + + losses = [[[0, 1, 2], [1, 2, 3]], [[1.5, 2.0, 3.0], [0.1, 0.4, 0.7]]] + for loss in losses: + loss_log.add_train_loss(loss[0]) + loss_log.add_val_loss(loss[1]) + loss_log.next_epoch() + + new_log = LossLogPlot(loss_log_path, "plot.png", loss_labels=["1", "2"], loss_weights=["", ""], stream=info_log) + + info_log.close() + + os.remove(loss_log_path) + os.remove(plot_path) + os.remove(log_path) + + assert new_log.epoch == 2 + assert np.allclose(new_log.train_losses, np.array([[0.75, 1.5, 2.5]])), new_log.train_losses + assert np.allclose(new_log.val_losses, np.array([[0.55, 1.2, 1.85]])), new_log.val_losses diff --git a/tests/test_losses.py b/tests/test_losses.py index 757042d..9fd19dd 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -1,5 +1,4 @@ import torch -import numpy as np def test_EqGraphLoss(): diff --git a/tests/test_models.py b/tests/test_models.py index 03e785f..916c888 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -165,13 +165,13 @@ def test_GraphImgNet(): # Test translational invariance of the MPNN pos += torch.Tensor([1.0, 1.0, 1.0]).to(device) node_features_shifted, edge_features_shifted = model.mpnn(pos, x_afm, edges_combined) - assert torch.allclose(node_features, node_features_shifted) - assert torch.allclose(edge_features, edge_features_shifted) + assert torch.allclose(node_features, node_features_shifted, rtol=1e-4, atol=1e-6) + assert torch.allclose(edge_features, edge_features_shifted, rtol=1e-4, atol=1e-6) # Test that the edges are not directional node_features_reverse, edge_features_reverse = model.mpnn(pos, x_afm, edges_combined[[1, 0]]) - assert torch.allclose(node_features, node_features_reverse) - assert torch.allclose(edge_features, edge_features_reverse) + assert torch.allclose(node_features, node_features_reverse, rtol=1e-4, atol=1e-6) + assert torch.allclose(edge_features, edge_features_reverse, rtol=1e-4, atol=1e-6) # Test whole model model.afm_cutoff = 0.8 diff --git a/tests/test_utils.py b/tests/test_utils.py index 03fdd4a..18b14d1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,10 @@ import os +from pathlib import Path +import shutil + import numpy as np +from torch import nn, optim +import torch def test_xyz_read_write(): @@ -22,27 +27,38 @@ def test_xyz_read_write(): assert np.allclose(xyz, xyz_read) -def test_loss_log_plot(): - from mlspm.logging import LossLogPlot +def test_checkpoints(): + from mlspm.utils import load_checkpoint, save_checkpoint + + save_dir = Path("test_checkpoints") + + model = nn.Linear(10, 10) + optimizer = optim.Adam(model.parameters()) + lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + 1e-5 * b)) + additional_data = {"test_data": 3, "lr_scheduler": lr_scheduler} - log_path = "loss_log.csv" - plot_path = "test_plot.png" + x, y = np.random.rand(2, 1, 10, 10) + x = torch.from_numpy(x).float() + y = torch.from_numpy(y).float() + pred = model(x) + loss = ((y - pred) ** 2).mean() + loss.backward() + optimizer.step() + lr_scheduler.step() + print(loss) - if os.path.exists(log_path): - os.remove(log_path) - loss_log = LossLogPlot(log_path, plot_path, loss_labels=["1", "2"], loss_weights=["", ""]) + save_checkpoint(model, optimizer, epoch=1, save_dir=save_dir, additional_data=additional_data) - losses = [[[0, 1, 2], [1, 2, 3]], [[1.5, 2.0, 3.0], [0.1, 0.4, 0.7]]] - for loss in losses: - loss_log.add_train_loss(loss[0]) - loss_log.add_val_loss(loss[1]) - loss_log.next_epoch() + model_new = nn.Linear(10, 10) + optimizer_new = optim.Adam(model.parameters()) + lr_scheduler_new = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + 1e-5 * b)) + additional_data = {"test_data": 0, "lr_scheduler": lr_scheduler_new} - new_log = LossLogPlot(log_path, "plot.png", loss_labels=["1", "2"], loss_weights=["", ""]) + load_checkpoint(model_new, optimizer_new, file_name=save_dir / "model_1.pth", additional_data=additional_data) - os.remove(log_path) - os.remove(plot_path) + assert np.allclose(model.state_dict()["weight"], model_new.state_dict()["weight"]) + assert np.allclose(optimizer.state_dict()["state"][0]["exp_avg"], optimizer_new.state_dict()["state"][0]["exp_avg"]) + assert np.allclose(lr_scheduler.state_dict()["_last_lr"], lr_scheduler_new.state_dict()["_last_lr"]) + assert np.allclose(additional_data["test_data"], 3) - assert new_log.epoch == 2 - assert np.allclose(new_log.train_losses, np.array([[0.75, 1.5, 2.5]])), new_log.train_losses - assert np.allclose(new_log.val_losses, np.array([[0.55, 1.2, 1.85]])), new_log.val_losses + shutil.rmtree(save_dir) diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100755 index 0000000..3a60189 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,88 @@ +import shutil +from pathlib import Path + +import numpy as np + + +def test_make_input_plots(): + + from mlspm.visualization import make_input_plots + + save_dir = Path('test_input_plots') + + X = [np.random.rand(2, 20, 20, 2), np.random.rand(2, 20, 20, 2)] + make_input_plots(X, outdir=save_dir) + + assert len(list(save_dir.glob('*.png'))) == 4 + + shutil.rmtree(save_dir) + +def test_plot_graphs(): + + from mlspm.graph import MoleculeGraph + from mlspm.visualization import plot_graphs + + save_dir = Path('test_graph_plots') + + # fmt:off + atoms = [ + np.array([ + [0.1, 0.2, 0.3, 1], + [0.5, 0.6, 0.7, 1], + [0.9, 1.0, 1.1, 6], + [1.3, 1.4, 1.5, 6] + ]), + np.array([ + [1.7, 1.8, 1.9, 1], + [2.1, 2.2, 2.3, 6] + ]), + np.empty((0, 4)), + np.array([ + [2.5, 2.6, 2.7, 1], + [2.9, 3.0, 3.1, 6], + [3.3, 3.4, 3.5, 6] + ]), + ] + # fmt:on + bonds = [ + [(0,1), (0,2), (1,3), (2,3)], + [(0,1)], + [], + [(0,1), (1,2)] + ] + classes = [[1], [6]] + mols = [MoleculeGraph(a, b, classes) for a, b in zip(atoms, bonds)] + box_borders = np.array([[0.0, 0.0, 0.0], [4.0, 4.0, 4.0]]) + + plot_graphs(mols, mols, box_borders=box_borders, classes=classes, outdir=save_dir) + + assert len(list(save_dir.glob('*.png'))) == 4 + + shutil.rmtree(save_dir) + +def test_plot_distribution_grid(): + + from mlspm.graph import make_position_distribution, MoleculeGraph + from mlspm.visualization import plot_distribution_grid + + save_dir = Path('test_distribution_plots') + + # fmt:off + atoms = np.array([ + [2.0, 2.0, 1.0, 0], + [2.0, 6.0, 2.0, 0], + [6.0, 4.0, 2.0, 0], + [4.0, 2.0, 2.0, 0] + ]) + # fmt:on + box_borders = np.array([[0.0, 0.0, 0.0], [8.0, 8.0, 3.0]]) + mols = [MoleculeGraph(atoms, bonds=[])] + std = 0.2 + + dist = make_position_distribution(mols, box_borders=box_borders, std=std) + + plot_distribution_grid(dist, dist, box_borders=box_borders, outdir=save_dir) + + assert len(list(save_dir.glob('*.png'))) == 2 + + shutil.rmtree(save_dir)