forked from IMSY-DKFZ/htc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDatasetSettings.py
156 lines (126 loc) · 6.09 KB
/
DatasetSettings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT
import json
from pathlib import Path
from typing import Any, Union
import numpy as np
from htc.utils.type_from_string import type_from_string
class DatasetSettings:
    """
    Lazily loaded, read-only dict-like access to the settings of a dataset (shapes etc.).

    The settings are either passed directly as a dict or read from a JSON file on first access. After loading, `shape` is stored as a tuple and, if `shape_names` is available, the derived key `spatial_shape` (height, width) is added.
    """

    def __init__(self, path_or_data: Union[str, Path, dict]):
        """
        Settings of the dataset (shapes etc.) which can be accessed by the DataPaths as path.dataset_settings. The data is not loaded when constructing this object but only when the settings data is actually accessed.

        It is also possible to load the settings explicitly based on the path to the data directory:
        >>> from htc.settings import settings
        >>> dsettings = DatasetSettings(settings.data_dirs.semantic)
        >>> dsettings["shape"]
        (480, 640, 100)

        Args:
            path_or_data: Path (or string) to the JSON file containing the dataset settings or path to the data directory which contains the JSON file (in which case the name of the file must be dataset_settings.json). Alternatively, you can pass your settings directly as a dict. Note that a passed dict is not copied: the conversions (e.g. shape as tuple) are applied in-place.
        """
        if isinstance(path_or_data, str):
            path_or_data = Path(path_or_data)

        # isinstance (instead of an exact type comparison) so that dict subclasses are treated as data and not as a path
        if isinstance(path_or_data, dict):
            self._data = path_or_data
            self._data_conversions()
            self._path = None
        else:
            # Loading is deferred until the data property is accessed for the first time
            self._data = None
            self._path = path_or_data

    def __repr__(self) -> str:
        # Resolve the path only once since every access checks the file system
        settings_path = self.settings_path
        dataset_name = settings_path.parent.name if settings_path is not None else "(no path available)"

        res = f"Settings for the dataset {dataset_name}\n"
        res += "The following settings are available:\n"
        res += f"{list(self.data.keys())}"
        return res

    def __eq__(self, other: object) -> bool:
        """
        Two settings objects are equal if they point to the same settings file (as long as neither of them has loaded its data yet) or if their data is equal.

        Returns: Result of the comparison or NotImplemented if other is not a DatasetSettings object.
        """
        if not isinstance(other, DatasetSettings):
            # Let Python fall back to its default handling instead of failing on a missing _data attribute
            return NotImplemented

        if self._data is None and other._data is None:
            # Nothing loaded yet on both sides: comparing the paths avoids reading the files
            return self.settings_path == other.settings_path

        return self.data == other.data

    def __getitem__(self, key: str) -> Any:
        """
        Returns: The value stored for the key. The key must exist in the settings (use get() for optional keys).
        """
        assert key in self.data, f"{self.settings_path = }\n{self.data = }"
        return self.data[key]

    def get(self, key: str, default: Any = None) -> Any:
        """
        Returns: The value stored for the key or the default value if the key is not part of the settings.
        """
        return self.data.get(key, default)

    def __contains__(self, key: str) -> bool:
        return key in self.data

    @property
    def settings_path(self) -> Union[None, Path]:
        """
        Returns: The Path to the dataset_settings.json file if it exists or None if not.
        """
        if self._path is None or not self._path.exists():
            return None

        # A directory is expected to contain the settings file with the default name
        candidate = self._path / "dataset_settings.json" if self._path.is_dir() else self._path
        return candidate if candidate.exists() else None

    @property
    def data(self) -> dict:
        """
        Returns: The settings as dict. On first access, the settings are read from the JSON file (an empty dict is used if no settings file is available) and cached for subsequent calls.
        """
        if self._data is None:
            # Resolve the path only once since every access checks the file system
            settings_path = self.settings_path
            if settings_path is None:
                self._data = {}
            else:
                with settings_path.open(encoding="utf-8") as f:
                    self._data = json.load(f)

                self._data_conversions()

        return self._data

    def data_path_class(self) -> Union[type, None]:
        """
        Tries to infer the appropriate data path class for the current dataset. Ideally, this is defined in the dataset_settings.json file with a key data_path_class referring to a valid data path class (e.g. htc.tivita.DataPathMultiorgan>DataPathMultiorgan). If this is not the case, it tries to infer the data path class based on the dataset name or based on the files in the folder.

        Returns: Data path type or None if no match could be found.
        """
        # The data path classes are imported locally, only in the branch which needs them
        if "data_path_class" in self:
            DataPathClass = type_from_string(self["data_path_class"])
        elif "multiorgan" in self.get("dataset_name", ""):
            from htc.tivita.DataPathMultiorgan import DataPathMultiorgan

            DataPathClass = DataPathMultiorgan
        elif "sepsis" in self.get("dataset_name", ""):
            from htc.tivita.DataPathSepsis import DataPathSepsis

            DataPathClass = DataPathSepsis
        elif self._path is not None:
            # Try to infer the data path class from the files in the directory
            if self._path.is_file() or not self._path.exists():
                # The path refers to a (possibly missing) settings file, so the dataset is the parent directory
                dataset_dir = self._path.parent
            else:
                dataset_dir = self._path
            assert dataset_dir.exists() and dataset_dir.is_dir(), f"The dataset directory {dataset_dir} does not exist"

            files = sorted(dataset_dir.iterdir())
            if any(f.name.endswith("subjects") for f in files):
                from htc.tivita.DataPathMultiorgan import DataPathMultiorgan

                DataPathClass = DataPathMultiorgan
            elif any(f.name == "sepsis_study" for f in files):
                from htc.tivita.DataPathSepsis import DataPathSepsis

                DataPathClass = DataPathSepsis
            elif any(f.stem == "image_references" for f in files):
                from htc.tivita.DataPathReference import DataPathReference

                DataPathClass = DataPathReference
            else:
                DataPathClass = None
        else:
            DataPathClass = None

        return DataPathClass

    def pixels_image(self) -> int:
        """
        Returns: Number of pixels of one image in the dataset.
        """
        # Check the key which is actually read: spatial_shape is only derived if shape_names is available, so the presence of shape alone is not sufficient
        assert "spatial_shape" in self.data, "No shape information available in the dataset settings"
        return int(np.prod(self.data["spatial_shape"]))

    def _data_conversions(self) -> None:
        """
        Normalizes the loaded settings in-place: shape is converted to a tuple (JSON only knows lists) and spatial_shape (height, width) is derived from shape and shape_names.
        """
        if "shape" in self._data:
            self._data["shape"] = tuple(self._data["shape"])

        if "shape_names" in self._data:
            names = self._data["shape_names"]
            assert (
                "height" in names and "width" in names
            ), f"shape_names must at least include height and width (got: {names})"
            self._data["spatial_shape"] = (
                self._data["shape"][names.index("height")],
                self._data["shape"][names.index("width")],
            )