Skip to content

Commit

Permalink
adapt to current pydantic version
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 28, 2024
1 parent ab58e94 commit 3beaaab
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/cryo_challenge/data/_validation/config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
import os
from pydantic import BaseModel, validator, root_validator
from pydantic import BaseModel, field_validator, model_validator
from typing import Optional, List


Expand Down Expand Up @@ -259,21 +259,21 @@ class SVDNormalizeParams(BaseModel):
bfactor: float = None
box_size_ds: Optional[int] = None

@validator("mask_path")
@field_validator("mask_path")
def check_mask_path_exists(cls, value):
if value is not None:
if not os.path.exists(value):
raise ValueError(f"Mask file {value} does not exist.")
return value

@validator("bfactor")
@field_validator("bfactor")
def check_bfactor(cls, value):
if value is not None:
if value < 0:
raise ValueError("B-factor must be non-negative.")
return value

@validator("box_size_ds")
@field_validator("box_size_ds")
def check_box_size_ds(cls, value):
if value is not None:
if value < 0:
Expand All @@ -285,7 +285,7 @@ class SVDGtParams(BaseModel):
gt_vols_file: str
skip_vols: int = 1

@validator("gt_vols_file")
@field_validator("gt_vols_file")
def check_mask_path_exists(cls, value):
if not os.path.exists(value):
raise ValueError(f"Could not find file {value}.")
Expand All @@ -300,7 +300,7 @@ def check_mask_path_exists(cls, value):
)
return value

@validator("skip_vols")
@field_validator("skip_vols")
def check_skip_vols(cls, value):
if value is not None:
if value < 0:
Expand All @@ -327,10 +327,10 @@ class SVDConfig(BaseModel):
gt_params: Optional[SVDGtParams] = None
output_params: SVDOutputParams

@root_validator
def check_path_to_submissions(cls, values):
path_to_submissions = values.get("path_to_submissions")
excluded_submissions = values.get("excluded_submissions")
@model_validator(mode="after")
def check_path_to_submissions(self):
path_to_submissions = self.path_to_submissions
excluded_submissions = self.excluded_submissions

if not os.path.exists(path_to_submissions):
raise ValueError(f"Could not find path {path_to_submissions}.")
Expand All @@ -354,21 +354,21 @@ def check_path_to_submissions(cls, values):
f"No submission files found after excluding {excluded_submissions}."
)

return values
return self

@validator("dtype")
@field_validator("dtype")
def check_dtype(cls, value):
if value not in ["float32", "float64"]:
raise ValueError(f"Invalid dtype {value}.")
return value

@validator("svd_max_rank")
@field_validator("svd_max_rank")
def check_svd_max_rank(cls, value):
if value < 1 and value is not None:
raise ValueError("Max rank must be at least 1.")
return value

@validator("voxel_size")
@field_validator("voxel_size")
def check_voxel_size(cls, value):
if value <= 0:
raise ValueError("Voxel size must be positive.")
Expand Down

0 comments on commit 3beaaab

Please sign in to comment.