diff --git a/csep/core/forecasts.py b/csep/core/forecasts.py index 4b2f863f..b62e7bfd 100644 --- a/csep/core/forecasts.py +++ b/csep/core/forecasts.py @@ -1,13 +1,16 @@ -import itertools import time import os import datetime +from typing import Optional # third-party imports import numpy +import xml.etree.ElementTree as eTree +import h5py +import pandas from csep.utils.log import LoggingMixin -from csep.core.regions import CartesianGrid2D, create_space_magnitude_region +from csep.core.regions import CartesianGrid2D, create_space_magnitude_region, QuadtreeGrid2D from csep.models import Polygon from csep.utils.calc import bin1d_vec from csep.utils.time_utils import decimal_year, datetime_to_utc_epoch @@ -753,4 +756,250 @@ def load_ascii(cls, fname, **kwargs): Returns: :class:`csep.core.forecasts.CatalogForecast """ - raise NotImplementedError("load_ascii is not implemented!") \ No newline at end of file + raise NotImplementedError("load_ascii is not implemented!") + + +class GriddedForecastFactory: + + @staticmethod + def from_dat(filename: str, + swap_latlon: bool = False, + name: Optional[str] = None, + start_date: Optional[datetime.datetime] = None, + end_date: Optional[datetime.datetime] = None, + **kwargs) -> GriddedForecast: + """ Creates a :class:`GriddedCatalog` from a.dat file.""" + + data = numpy.loadtxt(filename) + all_polys = data[:, :4] + all_poly_mask = data[:, -1] + sorted_idx = numpy.sort( + numpy.unique(all_polys, return_index=True, axis=0)[1], kind="stable" + ) + unique_poly = all_polys[sorted_idx] + poly_mask = all_poly_mask[sorted_idx] + all_mws = data[:, -4] + sorted_idx = numpy.sort(numpy.unique(all_mws, return_index=True)[1], kind="stable") + + magnitudes = all_mws[sorted_idx] + if swap_latlon: + bboxes = [((i[2], i[0]), (i[3], i[0]), + (i[3], i[1]), (i[2], i[1])) for i in unique_poly] + else: + bboxes = [((i[0], i[2]), (i[0], i[3]), + (i[1], i[3]), (i[1], i[2])) for i in unique_poly] + + dh = float(unique_poly[0, 3] - unique_poly[0, 2]) + + n_mag_bins = len(magnitudes) + rates = data[:, -2].reshape(len(bboxes), n_mag_bins) + + region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask) + + forecast = GriddedForecast( + name=f"{name}", + data=rates, + region=region, + magnitudes=magnitudes, + start_time=start_date, + end_time=end_date, + **kwargs + ) + + return forecast + + @staticmethod + def from_xml(filename: str, + **kwargs): + tree = eTree.parse(filename) + root = tree.getroot() + metadata = {} + data_ijm = [] + m_bins = [] + cells = [] + cell_dim = {} + for k, children in enumerate(list(root[0])): + if "modelName" in children.tag: + name_xml = children.text + metadata["name"] = name_xml + elif "author" in children.tag: + author_xml = children.text + metadata["author"] = author_xml + elif "forecastStartDate" in children.tag: + start_date = children.text.replace("Z", "") + metadata["forecastStartDate"] = start_date + elif "forecastEndDate" in children.tag: + end_date = children.text.replace("Z", "") + metadata["forecastEndDate"] = end_date + elif "defaultMagBinDimension" in children.tag: + m_bin_width = float(children.text) + metadata["defaultMagBinDimension"] = m_bin_width + elif "lastMagBinOpen" in children.tag: + lastmbin = float(children.text) + metadata["lastMagBinOpen"] = lastmbin + elif "defaultCellDimension" in children.tag: + cell_dim = {i[0]: float(i[1]) for i in children.attrib.items()} + metadata["defaultCellDimension"] = cell_dim + elif "depthLayer" in children.tag: + depth = {i[0]: float(i[1]) for i in root[0][k].attrib.items()} + cells = root[0][k] + metadata["depthLayer"] = depth + + for cell in cells: + cell_data = [] + m_cell_bins = [] + for i, m in enumerate(cell.iter()): + if i == 0: + cell_data.extend([float(m.attrib["lon"]), float(m.attrib["lat"])]) + else: + cell_data.append(float(m.text)) + m_cell_bins.append(float(m.attrib["m"])) + data_ijm.append(cell_data) + m_bins.append(m_cell_bins) + try: + data_ijm = numpy.array(data_ijm) + m_bins = numpy.array(m_bins) + except (TypeError, ValueError): + raise Exception("Data is not square") + + magnitudes = m_bins[0, :] + rates = data_ijm[:, -len(magnitudes) :] + all_polys = numpy.vstack( + ( + data_ijm[:, 0] - cell_dim["lonRange"] / 2.0, + data_ijm[:, 0] + cell_dim["lonRange"] / 2.0, + data_ijm[:, 1] - cell_dim["latRange"] / 2.0, + data_ijm[:, 1] + cell_dim["latRange"] / 2.0, + ) + ).T + bboxes = [((i[0], i[2]), (i[0], i[3]), (i[1], i[3]), (i[1], i[2])) for i in all_polys] + dh = float(all_polys[0, 3] - all_polys[0, 2]) + poly_mask = numpy.ones(len(bboxes)) + region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask) + + forecast = GriddedForecast( + name=f"{metadata['name']}", + data=rates, + region=region, + magnitudes=magnitudes, + start_time=datetime.datetime.fromisoformat(metadata["forecastStartDate"]), + end_time=datetime.datetime.fromisoformat(metadata["forecastEndDate"]), + **kwargs + ) + return forecast + + @staticmethod + def from_quadtree(filename: str, + name: Optional[str] = None, + start_date: Optional[datetime.datetime] = None, + end_date: Optional[datetime.datetime] = None, + **kwargs) -> GriddedForecast: + + with open(filename, "r") as file_: + qt_header = file_.readline().split(",") + formats = [str] + for i in range(len(qt_header) - 1): + formats.append(float) + + qt_formats = {i: j for i, j in zip(qt_header, formats)} + data = pandas.read_csv(filename, header=0, dtype=qt_formats) + + quadkeys = numpy.array([i.encode("ascii", "ignore") for i in data.tile]) + magnitudes = numpy.array(data.keys()[3:]).astype(float) + rates = data[magnitudes.astype(str)].to_numpy() + + region = QuadtreeGrid2D.from_quadkeys([str(i) for i in quadkeys], magnitudes=magnitudes) + region.get_cell_area() + + forecast = GriddedForecast( + name=f"{name}", + data=rates, + region=region, + magnitudes=magnitudes, + start_time=start_date, + end_time=end_date, + **kwargs + ) + + return forecast + + @staticmethod + def from_csv(filename): + def is_mag(num): + try: + m = float(num) + if -1 < m < 12.0: + return True + else: + return False + except ValueError: + return False + + with open(filename, "r") as file_: + line = file_.readline() + if len(line.split(",")) > 3: + sep = "," + else: + sep = " " + + data = pandas.read_csv( + filename, header=0, sep=sep, escapechar="#", skipinitialspace=True + ) + + data.columns = [i.strip() for i in data.columns] + magnitudes = numpy.array([float(i) for i in data.columns if is_mag(i)]) + rates = data[[i for i in data.columns if is_mag(i)]].to_numpy() + all_polys = data[["lon_min", "lon_max", "lat_min", "lat_max"]].to_numpy() + bboxes = [((i[0], i[2]), (i[0], i[3]), (i[1], i[3]), (i[1], i[2])) for i in all_polys] + dh = float(all_polys[0, 3] - all_polys[0, 2]) + + try: + poly_mask = data["mask"] + except KeyError: + poly_mask = numpy.ones(len(bboxes)) + + region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask) + + return rates, region, magnitudes + + @staticmethod + def from_hdf5(filename: str, + group: str = "", + name: Optional[str] = None, + start_date: Optional[datetime.datetime] = None, + end_date: Optional[datetime.datetime] = None, + **kwargs) -> GriddedForecast: + """ + Load a gridded forecast from an HDF5 file. + + Arguments: + filename: The name of the HDF5 file. + group: The HDF5 group to load the forecast from. Usually represents the forecast + time + name: The name of the gridded forecast. + start_date: The start date of the forecast. + end_date: The end date of the forecast. + **kwargs: Additional keyword arguments passed to `read`. + + """ + + with h5py.File(filename, "r") as db: + rates = db[f"{group}/rates"][:] + magnitudes = db[f"{group}/magnitudes"][:] + + dh = db[f"{group}/dh"][:][0] + bboxes = db[f"{group}/bboxes"][:] + poly_mask = db[f"{group}/poly_mask"][:] + region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask) + + forecast = GriddedForecast( + name=f"{name}", + data=rates, + region=region, + magnitudes=magnitudes, + start_time=start_date, + end_time=end_date, + **kwargs + ) + + return forecast diff --git a/requirements.txt b/requirements.txt index df3845be..79b5517d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ scipy pandas matplotlib cartopy +h5py obspy pyproj python-dateutil diff --git a/requirements.yml b/requirements.yml index be2fd220..3a424759 100644 --- a/requirements.yml +++ b/requirements.yml @@ -7,6 +7,7 @@ dependencies: - numpy - pandas - scipy + - h5py - matplotlib - pyproj - obspy diff --git a/requirements_dev.txt b/requirements_dev.txt index 6a95af5b..a497fead 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -4,6 +4,7 @@ pandas matplotlib cartopy obspy +h5py pyproj python-dateutil pytest diff --git a/setup.py b/setup.py index 5390de9f..4b7f088f 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ def get_version(): 'numpy', 'scipy', 'pandas', + 'h5py', 'matplotlib', 'cartopy', 'obspy',