Skip to content

Commit

Permalink
ft: added CatalogRepository class, to handle all the catalog operatio…
Browse files Browse the repository at this point in the history
…ns (retrieve, store, filter, save inputcat to models, write/load testing cats, etc.)
  • Loading branch information
pabloitu committed Aug 10, 2024
1 parent 16683ba commit f6fff2a
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 191 deletions.
36 changes: 18 additions & 18 deletions floatcsep/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import datetime
import json
import os
from typing import Dict, Callable, Union, Sequence, List

import numpy
from csep.core.catalogs import CSEPCatalog
from csep.core.forecasts import GriddedForecast
from csep.models import EvaluationResult
from matplotlib import pyplot

from floatcsep.model import Model
from floatcsep.registry import ExperimentRegistry
from floatcsep.utils import parse_csep_func, timewindow2str
from floatcsep.utils import parse_csep_func


class Evaluation:
Expand Down Expand Up @@ -76,7 +73,8 @@ def __init__(
self.markdown = markdown
self.type = Evaluation._TYPES.get(self.func.__name__)

self.repository = None
self.results_repo = None
self.catalog_repo = None

@property
def type(self):
Expand Down Expand Up @@ -125,7 +123,6 @@ def parse_plots(self, plot_func, plot_args, plot_kwargs):
def prepare_args(
self,
timewindow: Union[str, list],
catpath: Union[str, list],
model: Union[Model, Sequence[Model]],
ref_model: Union[Model, Sequence] = None,
region=None,
Expand Down Expand Up @@ -155,7 +152,7 @@ def prepare_args(
# Prepare argument tuple

forecast = model.get_forecast(timewindow, region)
catalog = self.get_catalog(catpath, forecast)
catalog = self.get_catalog(timewindow, forecast)

if isinstance(ref_model, Model):
# Args: (Fc, RFc, Cat)
Expand All @@ -171,29 +168,32 @@ def prepare_args(

return test_args

@staticmethod
def get_catalog(
catalog_path: Union[str, Sequence[str]],
self,
timewindow: Union[str, Sequence[str]],
forecast: Union[GriddedForecast, Sequence[GriddedForecast]],
) -> Union[CSEPCatalog, List[CSEPCatalog]]:
"""
Retrieves the test catalog(s) for the given time window(s) from the catalog repository
and assigns each catalog the region of its corresponding forecast.
Args:
catalog_path (str, list(str)): Path to the existing catalog
timewindow (str): Time window of the testing catalog
forecast (:class:`~csep.core.forecasts.GriddedForecast`): Forecast
object (or sequence of forecasts) against which the catalog will be tested.
Returns:
"""
if isinstance(catalog_path, str):
eval_cat = CSEPCatalog.load_json(catalog_path)

if isinstance(timewindow, str):
# eval_cat = CSEPCatalog.load_json(catalog_path)
eval_cat = self.catalog_repo.get_test_cat(timewindow)
eval_cat.region = getattr(forecast, "region")

else:
eval_cat = [CSEPCatalog.load_json(i) for i in catalog_path]
eval_cat = [self.catalog_repo.get_test_cat(i) for i in timewindow]
if (len(forecast) != len(eval_cat)) or (not isinstance(forecast, Sequence)):
raise IndexError("Amount of passed catalogs and forecats must " "be the same")
raise IndexError("Amount of passed catalogs and forecasts must " "be the same")

Check warning on line 196 in floatcsep/evaluation.py

View check run for this annotation

Codecov / codecov/patch

floatcsep/evaluation.py#L196

Added line #L196 was not covered by tests
for cat, fc in zip(eval_cat, forecast):
cat.region = getattr(fc, "region", None)

Expand Down Expand Up @@ -222,15 +222,15 @@ def compute(
Returns:
"""
test_args = self.prepare_args(
timewindow, catpath=catalog, model=model, ref_model=ref_model, region=region
timewindow, model=model, ref_model=ref_model, region=region
)

evaluation_result = self.func(*test_args, **self.func_kwargs)

if self.type in ["sequential", "sequential_comparative"]:
self.repository.write_result(evaluation_result, self, model, timewindow[-1])
self.results_repo.write_result(evaluation_result, self, model, timewindow[-1])
else:
self.repository.write_result(evaluation_result, self, model, timewindow)
self.results_repo.write_result(evaluation_result, self, model, timewindow)

def read_results(
self, window: Union[str, Sequence[datetime.datetime]], models: List[Model]
Expand All @@ -240,7 +240,7 @@ def read_results(
all tested models.
"""

test_results = self.repository.load_results(self, window, models)
test_results = self.results_repo.load_results(self, window, models)

return test_results

Expand Down
151 changes: 25 additions & 126 deletions floatcsep/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from floatcsep.logger import add_fhandler
from floatcsep.model import Model, TimeDependentModel
from floatcsep.registry import ExperimentRegistry
from floatcsep.repository import ResultsRepository
from floatcsep.repository import ResultsRepository, CatalogRepository
from floatcsep.utils import (
NoAliasLoader,
parse_csep_func,
Expand Down Expand Up @@ -146,7 +146,8 @@ def __init__(

self.name = name if name else "floatingExp"
self.registry = ExperimentRegistry(workdir, rundir)
self.repository = ResultsRepository(self.registry)
self.results_repo = ResultsRepository(self.registry)
self.catalog_repo = CatalogRepository(self.registry)

self.config_file = kwargs.get("config_file", None)
self.original_config = kwargs.get("original_config", None)
Expand Down Expand Up @@ -183,7 +184,8 @@ def __init__(
self.postproc_config = postproc_config if postproc_config else {}
self.default_test_kwargs = default_test_kwargs

self.catalog = catalog
self.catalog_repo.set_catalog(catalog, self.time_config, self.region_config)

self.models = self.set_models(
models or kwargs.get("model_config"), kwargs.get("order", None)
)
Expand Down Expand Up @@ -327,114 +329,28 @@ def set_tests(self, test_config: Union[str, Dict, List]) -> list:
tests = []

if isinstance(test_config, str):

with open(self.registry.abs(test_config), "r") as config:
config_dict = yaml.load(config, NoAliasLoader)

for eval_dict in config_dict:
eval_i = Evaluation.from_dict(eval_dict)
eval_i.repository = self.repository
eval_i.results_repo = self.results_repo
eval_i.catalog_repo = self.catalog_repo
tests.append(eval_i)

elif isinstance(test_config, (dict, list)):

for eval_dict in test_config:
eval_i = Evaluation.from_dict(eval_dict)
eval_i.repository = self.repository
eval_i.results_repo = self.results_repo
eval_i.catalog_repo = self.catalog_repo
tests.append(eval_i)

log.info(f"\tEvaluations: {[i.name for i in tests]}")

return tests

@property
def catalog(self) -> CSEPCatalog:
"""
Returns a CSEP catalog loaded from the given query function or a stored file if it
exists.
"""
cat_path = self.registry.abs(self._catpath)

if callable(self._catalog):
if isfile(self._catpath):
return CSEPCatalog.load_json(self._catpath)
bounds = {
"start_time": min([item for sublist in self.timewindows for item in sublist]),
"end_time": max([item for sublist in self.timewindows for item in sublist]),
"min_magnitude": self.magnitudes.min(),
"max_depth": self.depths.max(),
}
if self.region:
bounds.update(
{
i: j
for i, j in zip(
["min_longitude", "max_longitude", "min_latitude", "max_latitude"],
self.region.get_bbox(),
)
}
)

catalog = self._catalog(catalog_id="catalog", **bounds)

if self.region:
catalog.filter_spatial(region=self.region, in_place=True)
catalog.region = None
catalog.write_json(self._catpath)

return catalog

elif isfile(cat_path):
try:
return CSEPCatalog.load_json(cat_path)
except json.JSONDecodeError:
return csep.load_catalog(cat_path)

@catalog.setter
def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None:

if cat is None:
self._catalog = None
self._catpath = None

elif isfile(self.registry.abs(cat)):
log.info(f"\tCatalog: '{cat}'")
self._catalog = self.registry.rel(cat)
self._catpath = self.registry.rel(cat)

else:
# catalog can be a function
self._catalog = parse_csep_func(cat)
self._catpath = self.registry.abs("catalog.json")
if isfile(self._catpath):
log.info(f"\tCatalog: stored " f"'{self._catpath}' " f"from '{cat}'")
else:
log.info(f"\tCatalog: '{cat}'")

def get_test_cat(self, tstring: str = None) -> CSEPCatalog:
"""
Filters the complete experiment catalog to a test sub-catalog bounded by the test
time-window (or by the experiment's start and end dates if none is given) and returns it.
Args:
tstring (str): Time window string
"""

if tstring:
start, end = str2timewindow(tstring)
else:
start = self.start_date
end = self.end_date
sub_cat = self.catalog.filter(
[
f"origin_time < {end.timestamp() * 1000}",
f"origin_time >= {start.timestamp() * 1000}",
f"magnitude >= {self.mag_min}",
f"magnitude < {self.mag_max}",
],
in_place=False,
)
if self.region:
sub_cat.filter_spatial(region=self.region, in_place=True)

return sub_cat

def set_test_cat(self, tstring: str) -> None:
"""
Filters the complete experiment catalog to a test sub-catalog bounded by the test
Expand All @@ -444,42 +360,22 @@ def set_test_cat(self, tstring: str) -> None:
tstring (str): Time window string
"""

testcat_name = self.registry.get(tstring, "catalog")
if not exists(testcat_name):
log.debug(
f"Filtering catalog to testing sub-catalog and saving to " f"{testcat_name}"
)
start, end = str2timewindow(tstring)
sub_cat = self.catalog.filter(
[
f"origin_time < {end.timestamp() * 1000}",
f"origin_time >= {start.timestamp() * 1000}",
f"magnitude >= {self.mag_min}",
f"magnitude < {self.mag_max}",
],
in_place=False,
)
if self.region:
sub_cat.filter_spatial(region=self.region, in_place=True)
sub_cat.write_json(filename=testcat_name)
else:
log.debug(f"Using stored test sub-catalog from {testcat_name}")
self.catalog_repo.set_test_cat(tstring)

def set_input_cat(self, tstring: str, model: Model) -> None:
"""
Filters the complete experiment catalog to an input sub-catalog filtered
to the beginning of thetest time-window. Writes it to filepath defined
to the beginning of the test time-window. Writes it to filepath defined
in :attr:`Model.tree.catalog`
Args:
tstring (str): Time window string
model (:class:`~floatcsep.model.Model`): Model to give the input
catalog
"""
start, end = str2timewindow(tstring)
sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"])
sub_cat.write_ascii(filename=model.registry.get("input_cat"))

self.catalog_repo.set_input_cat(tstring, model)

def set_tasks(self):
"""
Expand Down Expand Up @@ -696,7 +592,7 @@ def plot_catalog(self, dpi: int = 300, show: bool = False) -> None:
"legend": True,
}
plot_args.update(self.postproc_config.get("plot_catalog", {}))
catalog = self.get_test_cat()
catalog = self.catalog_repo.get_test_cat()

Check warning on line 595 in floatcsep/experiment.py

View check run for this annotation

Codecov / codecov/patch

floatcsep/experiment.py#L595

Added line #L595 was not covered by tests
if catalog.get_number_of_events() != 0:
ax = catalog.plot(plot_args=plot_args, show=show)
ax.get_figure().tight_layout()
Expand Down Expand Up @@ -829,9 +725,11 @@ def make_repr(self):
self.region_config["region"] = self.registry.rel(new_path)

Check warning on line 725 in floatcsep/experiment.py

View check run for this annotation

Codecov / codecov/patch

floatcsep/experiment.py#L725

Added line #L725 was not covered by tests

# Dropping catalog to results folder
target_cat = join(self.registry.workdir, self.registry.rundir, split(self._catpath)[-1])
target_cat = join(
self.registry.workdir, self.registry.rundir, split(self.catalog_repo._catpath)[-1]
)
if not exists(target_cat):
shutil.copy2(self.registry.abs(self._catpath), target_cat)
shutil.copy2(self.registry.abs(self.catalog_repo._catpath), target_cat)
self._catpath = self.registry.rel(target_cat)

relative_path = os.path.relpath(
Expand All @@ -851,7 +749,8 @@ def as_dict(
"tasks",
"models",
"tests",
"repository",
"results_repo",
"catalog_repo",
),
extended: bool = False,
) -> dict:
Expand All @@ -868,7 +767,7 @@ def as_dict(
"""

listwalk = [(i, j) for i, j in self.__dict__.items() if not i.startswith("_") and j]
listwalk.insert(6, ("catalog", self._catpath))
listwalk.insert(6, ("catalog", self.catalog_repo._catpath))

dictwalk = {i: j for i, j in listwalk}
dictwalk["path"] = dictwalk.pop("registry").workdir
Expand Down
2 changes: 1 addition & 1 deletion floatcsep/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def generate_report(experiment, timewindow=-1):
report.add_heading("Authoritative Data", level=2)

# Generate catalog plot
if experiment.catalog is not None:
if experiment.catalog_repo.catalog is not None:

Check warning on line 42 in floatcsep/report.py

View check run for this annotation

Codecov / codecov/patch

floatcsep/report.py#L42

Added line #L42 was not covered by tests
experiment.plot_catalog()
report.add_figure(
f"Input catalog",
Expand Down
Loading

0 comments on commit f6fff2a

Please sign in to comment.