Skip to content

Commit

Permalink
first steps to include omge conf into framework
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jul 3, 2024
1 parent a60c974 commit 9dd59bd
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 18 deletions.
56 changes: 56 additions & 0 deletions src/cryo_sbi/utils/configurations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import dataclasses
from typing import List, Union


@dataclasses.dataclass
class SimulatorConfig:
model_file: str
multiple_models_per_cv: bool
num_pixels: int
pixel_size: float
sigma: float
shift: float
defocus: List[float]
snr: List[float]
amp: float
b_factor: List[float]


@dataclasses.dataclass
class NeuralNetworkConfig:
embedding_net: str = "RESNET18"
out_dim: int = 256
num_transforms: int = 5
num_flow_layers: int = 10
num_flow_nodes: int = 256
flow_model: str = "NSF"
theta_scale: float = 1.0
theta_shift: float = 0.0


@dataclasses.dataclass
class OptimizerConfig:
optimizer: str = "Adamw"
lr: float = 0.0003
weight_decay: float = 0.01


@dataclasses.dataclass
class TrainingConfig:
start_from_checkpoint: bool = False
model_checkpoint: Union[str, None] = None
batch_size: int = 256
num_epochs: int = 300
log_interval: int = 10
save_interval: int = 50
clip_grad: float = 5.0
optimizer: OptimizerConfig = OptimizerConfig()
use_misspecification_loss: bool = False
misspecification_loss_weight: float = 0.0


@dataclasses.dataclass
class Config:
simulator: SimulatorConfig
neural_network: NeuralNetworkConfig = NeuralNetworkConfig()
training: TrainingConfig = TrainingConfig()
34 changes: 17 additions & 17 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union, Callable
import json
from omegaconf import OmegaConf
import numpy as np
import torch

Expand All @@ -9,6 +9,7 @@
from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
from cryo_sbi.inference.priors import get_image_priors
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
from cryo_sbi.utils.configurations import SimulatorConfig, Config


def cryo_em_simulator(
Expand Down Expand Up @@ -59,9 +60,9 @@ def cryo_em_simulator(


class CryoEmSimulator:
def __init__(self, config_fname: str, device: str = "cpu"):
self._device = device
self._load_params(config_fname)
def __init__(self, config_file: str, device: str = "cpu"):
self.device = device
self._load_params(config_file)
self._load_models()
self._priors = get_image_priors(self.max_index, self._config, device=device)
self._num_pixels = torch.tensor(
Expand All @@ -71,7 +72,7 @@ def __init__(self, config_fname: str, device: str = "cpu"):
self._config["PIXEL_SIZE"], dtype=torch.float32, device=device
)

def _load_params(self, config_fname: str) -> None:
def _load_params(self, config_file: str) -> None:
"""
Loads the parameters from the config file into a dictionary.
Expand All @@ -81,10 +82,9 @@ def _load_params(self, config_fname: str) -> None:
Returns:
None
"""

config = json.load(open(config_fname))
check_image_params(config)
self._config = config
conf_from_yaml = OmegaConf.load()
validated_config = OmegaConf.merge(SimulatorConfig, conf_from_yaml)
self.config = validated_config

def _load_models(self) -> None:
"""
Expand All @@ -94,17 +94,17 @@ def _load_models(self) -> None:
None
"""
if self._config["MODEL_FILE"].endswith("npy"):
if self.config.model_file.endswith("npy"):
models = (
torch.from_numpy(
np.load(self._config["MODEL_FILE"]),
np.load(self.config.model_file),
)
.to(self._device)
.to(torch.float32)
)
elif self._config["MODEL_FILE"].endswith("pt"):
elif self.config.model_file.endswith("pt"):
models = (
torch.load(self._config["MODEL_FILE"])
torch.load(self.config.model_file)
.to(self._device)
.to(torch.float32)
)
Expand All @@ -114,10 +114,10 @@ def _load_models(self) -> None:
"Model file format not supported. Please use .npy or .pt."
)

self._models = models
self.models = models

assert self._models.ndim == 3, "Models are not of shape (models, 3, atoms)."
assert self._models.shape[1] == 3, "Models are not of shape (models, 3, atoms)."
assert self.models.ndim == 3, "Models are not of shape (models, 3, atoms)."
assert self.models.shape[1] == 3, "Models are not of shape (models, 3, atoms)."

@property
def max_index(self) -> int:
Expand All @@ -127,7 +127,7 @@ def max_index(self) -> int:
Returns:
int: Maximum index of the model file.
"""
return len(self._models) - 1
return len(self.models) - 1

def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/config_files/image_params_testing.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
{
"N_PIXELS": 64,
"PIXEL_SIZE": 2.06,
"SIGMA": [0.5, 5.0],
Expand Down

0 comments on commit 9dd59bd

Please sign in to comment.