Skip to content

Commit

Permalink
fix pyporject.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Mar 16, 2023
1 parent 77a4dca commit f2ffedd
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 28 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
"torch",
"numpy",
"matplotlib",
"scipy"
"scipy",
"torchvision"
]
#dynamic = ["version"]

Expand Down
2 changes: 1 addition & 1 deletion src/cryo_sbi/inference/NPE_train_without_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,5 @@ def npe_train_no_saving(
loss_file=args.loss_file,
train_from_checkpoint=args.train_from_checkpoint,
state_dict_file=args.state_dict_file,
n_workers=args.n_workers
n_workers=args.n_workers,
)
5 changes: 2 additions & 3 deletions src/cryo_sbi/inference/NRE_train_without_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def nre_train_no_saving(
loss_file,
train_from_checkpoint=False,
model_state_dict=None,
n_workers=1
n_workers=1,
):
cryo_simulator = CryoEmSimulator(image_config)

Expand Down Expand Up @@ -110,6 +110,5 @@ def nre_train_no_saving(
loss_file=args.loss_file,
train_from_checkpoint=args.train_from_checkpoint,
state_dict_file=args.state_dict_file,
n_workers=args.n_workers
n_workers=args.n_workers,
)

1 change: 0 additions & 1 deletion src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def max_index(self):
return len(self.models) - 1

def simulator(self, index, seed=None):

# if seed is not None:
# torch.manual_seed(seed)

Expand Down
9 changes: 4 additions & 5 deletions src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def circular_mask(n_pixels, radius):


def add_noise(image, image_params, seed=None):

if seed is not None:
torch.manual_seed(seed)

Expand All @@ -27,12 +26,12 @@ def add_noise(image, image_params, seed=None):
)

else:
raise ValueError(
"SNR should be a single value or a list of [min_snr, max_snr]"
)
raise ValueError("SNR should be a single value or a list of [min_snr, max_snr]")

noise_power = signal_power / np.sqrt(snr)
image_noise = image + torch.distributions.normal.Normal(0, noise_power).sample(image.shape)
image_noise = image + torch.distributions.normal.Normal(0, noise_power).sample(
image.shape
)

return image_noise

Expand Down
1 change: 0 additions & 1 deletion src/cryo_sbi/wpa_simulator/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


def pad_image(image, image_params):

pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1
padder = ConstantPad2d(pad_width, 0.0)
padded_image = padder(image)
Expand Down
5 changes: 3 additions & 2 deletions src/cryo_sbi/wpa_simulator/shift.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import torch
import numpy as np


def apply_random_shift(padded_image, image_params, seed=None):
if seed is not None:
torch.manual_seed(seed)

max_shift = int(np.round(image_params["N_PIXELS"] * 0.1))
shift_x = int(torch.randint(low=-max_shift, high=max_shift+1, size=(1,)))
shift_y = int(torch.randint(low=-max_shift, high=max_shift+1, size=(1,)))
shift_x = int(torch.randint(low=-max_shift, high=max_shift + 1, size=(1,)))
shift_y = int(torch.randint(low=-max_shift, high=max_shift + 1, size=(1,)))

pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1

Expand Down
37 changes: 23 additions & 14 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,75 @@
from cryo_sbi.wpa_simulator.validate_image_config import check_params
from cryo_sbi import CryoEmSimulator


def _get_config():
config = json.load(open("tests/image_params_testing.json"))
check_params(config)

return config


def test_padding():
image_params = _get_config()
pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1
image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
padded_image = pad_image(image, image_params)

for size in padded_image.shape:
assert size == pad_width*2 + image_params["N_PIXELS"]
for size in padded_image.shape:
assert size == pad_width * 2 + image_params["N_PIXELS"]
return


def test_shift_size():
image_params = _get_config()
image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
padded_image = pad_image(image, image_params)
shifted_image = apply_random_shift(padded_image, image_params)

for size in shifted_image.shape:
for size in shifted_image.shape:
assert size == image_params["N_PIXELS"]
return


def test_shift_bias():
image_params = _get_config()
x_0 = image_params["N_PIXELS"]//2
y_0 = image_params["N_PIXELS"]//2

x_0 = image_params["N_PIXELS"] // 2
y_0 = image_params["N_PIXELS"] // 2

image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
image[x_0, y_0] = 1
image[x_0-1, y_0] = 1
image[x_0, y_0-1] = 1
image[x_0-1, y_0-1] = 1
image[x_0 - 1, y_0] = 1
image[x_0, y_0 - 1] = 1
image[x_0 - 1, y_0 - 1] = 1

padded_image = pad_image(image, image_params)
shifted_image = torch.zeros_like(image)

for _ in range(10000):
shifted_image = shifted_image + apply_random_shift(padded_image, image_params)

indices_x, indices_y = np.where(shifted_image >= 1)

assert np.mean(indices_x) == image_params["N_PIXELS"] / 2 - 0.5
assert np.mean(indices_y) == image_params["N_PIXELS"] / 2 - 0.5

return


def test_no_shift():
image_params = _get_config()
image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
padded_image = pad_image(image, image_params)
shifted_image = apply_no_shift(padded_image, image_params)

for size in shifted_image.shape:
for size in shifted_image.shape:
assert size == image_params["N_PIXELS"]

assert torch.allclose(image, shifted_image)
return


def test_normalization():
image_params = _get_config()
img_shape = (image_params["N_PIXELS"], image_params["N_PIXELS"])
Expand All @@ -85,6 +91,7 @@ def test_normalization():
assert torch.allclose(torch.std(gnormed_image), torch.tensor(1.0), atol=1e-3)
return


def test_noise():
image_params = _get_config()
N = 10000
Expand All @@ -99,14 +106,16 @@ def test_noise():
assert torch.allclose(torch.mean(stds), torch.tensor(1.0), atol=1e-3)
return

#def test_ctf():

# def test_ctf():


def test_simulation():
simul = CryoEmSimulator("tests/image_params_testing.json")
image_sim = simul.simulator(index=torch.tensor(0), seed=0)

image_params = _get_config()
model = np.load(image_params["MODEL_FILE"])[0,0]
model = np.load(image_params["MODEL_FILE"])[0, 0]
image = gen_img(model, image_params)
image = pad_image(image, image_params)
ctf = calc_ctf(image_params)
Expand Down

0 comments on commit f2ffedd

Please sign in to comment.