Skip to content

Commit

Permalink
SSL test
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Sep 28, 2023
1 parent caa22ee commit 5e6be48
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions tests/tests_ssl/test_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Tests for the direct.ssl.ssl module."""

import numpy as np
import pytest
import torch

from direct.ssl.ssl import *


def create_sample(shape, **kwargs):
sample = dict()
sample["kspace"] = torch.rand(*shape).float()
sample["filename"] = ["filename" + str(_) for _ in np.random.randint(100, 10000, size=shape[0])]
sample["slice_no"] = [_ for _ in np.random.randint(0, 1000, size=shape[0])]

sample["sampling_mask"] = torch.rand(shape[0], 1, *shape[2:-1], 1).round().bool()
sample["sampling_mask"][
:, :, shape[2] // 2 - 16 : shape[2] // 2 + 16, shape[3] // 2 - 16 : shape[3] // 2 + 16
] = True

sample["acs_mask"] = torch.zeros(shape[0], 1, *shape[2:-1], 1).bool()
sample["acs_mask"][:, :, shape[2] // 2 - 16 : shape[2] // 2 + 16, shape[3] // 2 - 16 : shape[3] // 2 + 16] = True

for k, v in locals()["kwargs"].items():
sample[k] = v
return sample


@pytest.mark.parametrize(
"shape",
[(1, 3, 40, 60)],
)
@pytest.mark.parametrize(
"ratio",
[0.2, 0.5],
)
@pytest.mark.parametrize(
"acs_region",
[[4, 4], [0, 0], [14, 13]],
)
@pytest.mark.parametrize(
"keep_acs",
[True, False],
)
@pytest.mark.parametrize(
"use_seed",
[True, False],
)
@pytest.mark.parametrize(
"std_scale",
[2.0, 4.0],
)
def test_gaussian_mask_splitter(shape, ratio, acs_region, keep_acs, use_seed, std_scale):
sample = create_sample(shape + (2,))
splitter = GaussianMaskSplitterModule(
ratio=ratio,
acs_region=acs_region,
keep_acs=keep_acs,
use_seed=use_seed,
kspace_key="kspace",
std_scale=std_scale,
)
sample = splitter(sample)
if not keep_acs:
assert torch.allclose(
sample["theta_sampling_mask"] & sample["lambda_sampling_mask"], torch.zeros_like(sample["sampling_mask"])
)
assert torch.allclose(sample["theta_sampling_mask"] | sample["lambda_sampling_mask"], sample["sampling_mask"])
assert torch.allclose(
torch.where(
sample["theta_sampling_mask"], sample["theta_kspace"], sample["lambda_sampling_mask"] * sample["kspace"]
),
(sample["theta_sampling_mask"] | sample["lambda_sampling_mask"]) * sample["kspace"],
)


@pytest.mark.parametrize(
"shape",
[(1, 3, 40, 60)],
)
@pytest.mark.parametrize(
"ratio",
[0.2, 0.5],
)
@pytest.mark.parametrize(
"acs_region",
[[4, 4], [0, 0], [14, 13]],
)
@pytest.mark.parametrize(
"keep_acs",
[True, False],
)
@pytest.mark.parametrize(
"use_seed",
[True, False],
)
def test_uniform_mask_splitter(shape, ratio, acs_region, keep_acs, use_seed):
sample = create_sample(shape + (2,))
splitter = UniformMaskSplitterModule(
ratio=ratio, acs_region=acs_region, keep_acs=keep_acs, use_seed=use_seed, kspace_key="kspace"
)
sample = splitter(sample)
if not keep_acs:
assert torch.allclose(
sample["theta_sampling_mask"] & sample["lambda_sampling_mask"], torch.zeros_like(sample["sampling_mask"])
)
assert torch.allclose(sample["theta_sampling_mask"] | sample["lambda_sampling_mask"], sample["sampling_mask"])
assert torch.allclose(
torch.where(
sample["theta_sampling_mask"], sample["theta_kspace"], sample["lambda_sampling_mask"] * sample["kspace"]
),
(sample["theta_sampling_mask"] | sample["lambda_sampling_mask"]) * sample["kspace"],
)

0 comments on commit 5e6be48

Please sign in to comment.