Commit

improve download and datasets

eroell committed Nov 14, 2024
1 parent e1a6abe commit df3827a
Showing 3 changed files with 68 additions and 98 deletions.
117 changes: 47 additions & 70 deletions src/ehrdata/dt/dataloader.py
@@ -5,9 +5,6 @@
import shutil
import tempfile
from pathlib import Path
from random import choice
from string import ascii_lowercase
from typing import Literal

import requests
from filelock import FileLock
@@ -18,57 +15,59 @@

def download(
url: str,
archive_format: Literal["zip", "tar", "tar.gz", "tgz"] = None,
output_file_name: str = None,
output_path: str | Path = None,
saving_path: Path | str,
block_size: int = 1024,
overwrite: bool = False,
) -> None: # pragma: no cover
"""Downloads a file irrespective of format.
Args:
url: URL to download.
archive_format: The format of the archive file.
output_file_name: Name of the downloaded file.
output_path: Path to download/extract the files to. Defaults to 'OS tmpdir' if not specified.
block_size: Block size for downloads in bytes.
overwrite: Whether to overwrite existing files.
saving_path: Where the data should be downloaded to.
"""
if output_file_name is None:
letters = ascii_lowercase
output_file_name = f"ehrapy_tmp_{''.join(choice(letters) for _ in range(10))}"

if output_path is None:
output_path = tempfile.gettempdir()

def _sanitize_file_name(file_name):
if os.name == "nt":
file_name = file_name.replace("?", "_").replace("*", "_")
return file_name

download_to_path = Path(
_sanitize_file_name(
f"{output_path}{output_file_name}"
if str(output_path).endswith("/")
else f"{output_path}/{output_file_name}"
)
)

Path(output_path).mkdir(parents=True, exist_ok=True)
lock_path = f"{download_to_path}.lock"
# note: tar.gz has to be before gz for the _remove_archive_extension function to remove the entire extension
compression_formats = ["tar.gz", "zip", "tar", "gz", "bz", "xz"]
raw_formats = ["csv", "txt", "parquet"]

saving_path = Path(saving_path)
# urls can end with "?download"
file_name = os.path.basename(url).split("?")[0]
suffix = file_name.split(".")[-1]

def _remove_archive_extension(file_path: str) -> str:
for ext in compression_formats:
# if the file path ends with extension, remove the extension and the dot before it (hence the -1)
if file_path.endswith(ext):
return file_path[: -len(ext) - 1]
return file_path
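
As the comment above notes, extension matching is order-sensitive. A minimal standalone sketch (hypothetical file name) of why "tar.gz" must be listed before "gz":

# The first matching extension wins, so the longer compound extension
# has to be tried first.
name = "dataset.tar.gz"
for ext in ["tar.gz", "zip", "tar", "gz", "bz", "xz"]:
    if name.endswith(ext):
        print(name[: -len(ext) - 1])  # -> "dataset"
        break
# With "gz" listed first, this would print "dataset.tar", and the
# existence check below would then look for the wrong extracted path.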

if suffix in raw_formats:
raw_data_saving_path = saving_path / file_name
path_to_check = raw_data_saving_path
elif suffix in compression_formats:
tmpdir = tempfile.mkdtemp()
raw_data_saving_path = Path(tmpdir) / file_name
path_to_check = saving_path / _remove_archive_extension(file_name)
else:
raise RuntimeError(f"Unknown file format: {suffix}")

if path_to_check.exists():
info = f"File {path_to_check} already exists!"
if not overwrite:
logging.info(f"{info} Use downloaded dataset...")
return
else:
logging.info(f"{info} Overwriting...")

logging.info(f"Downloading {file_name} from {url} to {raw_data_saving_path}")

lock_path = f"{raw_data_saving_path}.lock"
with FileLock(lock_path):
if _remove_archive_extension(download_to_path).exists():
warning = f"File {_remove_archive_extension(download_to_path)} already exists!"
if not overwrite:
logging.info(warning)
return
else:
logging.info(f"{warning} Overwriting...")

response = requests.get(url, stream=True)
total = int(response.headers.get("content-length", 0))

temp_file_name = f"{download_to_path}.part"
temp_file_name = f"{raw_data_saving_path}.part"

with Progress(refresh_per_second=1500) as progress:
task = progress.add_task("[red]Downloading...", total=total)
@@ -80,34 +79,12 @@ def _sanitize_file_name(file_name):
# force the progress bar to 100% at the end
progress.update(task, completed=total, refresh=True)

Path(temp_file_name).replace(download_to_path)
Path(temp_file_name).replace(raw_data_saving_path)

if archive_format:
output_path = output_path or tempfile.gettempdir()
shutil.unpack_archive(download_to_path, output_path, format=archive_format)
download_to_path.unlink()
list_of_paths = [path for path in Path(output_path).resolve().glob("*/") if not path.name.startswith(".")]
latest_path = max(list_of_paths, key=lambda path: path.stat().st_ctime)
shutil.move(
latest_path,
latest_path.parent / _remove_archive_extension(output_file_name),
) # type: ignore
if suffix in compression_formats:
shutil.unpack_archive(raw_data_saving_path, saving_path)
logging.info(
f"Extracted archive {file_name} from {raw_data_saving_path} to {saving_path / _remove_archive_extension(file_name)}"
)
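
The extraction call above no longer passes a format: shutil.unpack_archive infers it from the file extension, which is what made the old archive_format parameter redundant. A small sketch (placeholder file names):

import shutil

# Format is inferred from the extension when `format` is omitted:
shutil.unpack_archive("data.zip", "ehrapy_data/")     # inferred: "zip"
shutil.unpack_archive("data.tar.gz", "ehrapy_data/")  # inferred: "gztar"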

Path(lock_path).unlink(missing_ok=True)
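
The lock file guards the download against concurrent callers (e.g., parallel test workers fetching the same dataset). A minimal sketch of the filelock pattern used here, with a placeholder path:

from filelock import FileLock

with FileLock("/tmp/dataset.zip.lock"):  # blocks until other holders release it
    pass  # download and extract inside the critical section
# Afterwards the stale .lock file is removed, as done above with unlink(missing_ok=True).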


def _remove_archive_extension(file_path):
path = Path(file_path)
for ext in [
".tar.gz",
".tgz",
".tar.bz2",
".tbz2",
".tar.xz",
".txz",
".zip",
".tar",
]:
if str(path).endswith(ext):
return Path(str(path)[: -len(ext)])
return Path(path)
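
Taken together, the reworked downloader needs only a URL and a target directory. A hedged usage sketch (URLs and paths are placeholders, assuming the module path shown in this diff):

from pathlib import Path

from ehrdata.dt.dataloader import download

# Raw file: stored directly as saving_path / <file name>.
download("https://example.org/person.csv", saving_path=Path("ehrapy_data"))

# Archive: streamed to a temp dir, then extracted under saving_path; the call
# is a no-op if the extracted folder already exists and overwrite=False.
download("https://example.org/GiBleed_5.3.zip", saving_path=Path("ehrapy_data"))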
33 changes: 15 additions & 18 deletions src/ehrdata/dt/datasets.py
@@ -63,22 +63,21 @@ def _set_up_duckdb(path: Path, backend_handle: DuckDBPyConnection, prefix: str =


def _setup_eunomia_datasets(
data_url: str,
backend_handle: DuckDBPyConnection,
data_path: Path | None = None,
data_url: str = None,
nested_omop_table_path: str = "",
nested_omop_tables_folder: str = None,
dataset_prefix: str = "",
) -> None:
"""Loads the Eunomia datasets in the OMOP Common Data model."""
download(
data_url,
archive_format="zip",
output_file_name=DOWNLOAD_VERIFICATION_TAG,
output_path=data_path,
saving_path=data_path,
)

for file_path in (data_path / DOWNLOAD_VERIFICATION_TAG / nested_omop_table_path).glob("*.csv"):
shutil.move(file_path, data_path)
if nested_omop_tables_folder:
for file_path in (data_path / nested_omop_tables_folder).glob("*.csv"):
shutil.move(file_path, data_path)

_set_up_duckdb(
data_path,
@@ -120,10 +119,10 @@ def mimic_iv_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = N
data_path = Path("ehrapy_data/mimic-iv-demo-data-in-the-omop-common-data-model-0.9")

_setup_eunomia_datasets(
data_url=data_url,
backend_handle=backend_handle,
data_path=data_path,
data_url=data_url,
nested_omop_table_path="1_omop_data_csv",
nested_omop_tables_folder="mimic-iv-demo-data-in-the-omop-common-data-model-0.9/1_omop_data_csv",
dataset_prefix="2b_",
)
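
The dataset loaders are driven by an open DuckDB connection. A minimal usage sketch mirroring the tests below:

import duckdb

import ehrdata as ed

con = duckdb.connect()
# Downloads the demo dataset if needed and registers its OMOP tables;
# data_path defaults to ehrapy_data/mimic-iv-demo-data-in-the-omop-common-data-model-0.9.
ed.dt.mimic_iv_omop(backend_handle=con)
print(con.execute("SHOW TABLES").df())
con.close()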

@@ -157,12 +156,13 @@ def gibleed_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = No
data_url = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/GiBleed/GiBleed_5.3.zip"

if data_path is None:
data_path = Path("ehrapy_data/GiBleed")
data_path = Path("ehrapy_data/GiBleed_5.3")

_setup_eunomia_datasets(
data_url=data_url,
backend_handle=backend_handle,
data_path=data_path,
data_url=data_url,
nested_omop_tables_folder="GiBleed_5.3",
)


@@ -195,12 +195,12 @@ def synthea27nj_omop(backend_handle: DuckDBPyConnection, data_path: Path | None
data_url = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/Synthea27Nj/Synthea27Nj_5.4.zip"

if data_path is None:
data_path = Path("ehrapy_data/Synthea27Nj")
data_path = Path("ehrapy_data/Synthea27Nj_5.4")

_setup_eunomia_datasets(
data_url=data_url,
backend_handle=backend_handle,
data_path=data_path,
data_url=data_url,
)


@@ -289,16 +289,13 @@ def physionet2012(
for file_name in temp_data_set_names:
download(
url=f"https://physionet.org/files/challenge-2012/1.0.0/{file_name}.tar.gz?download",
output_path=data_path,
output_file_name=file_name + ".tar.gz",
archive_format="gztar",
saving_path=data_path,
)

for file_name in outcome_file_names:
download(
url=f"https://physionet.org/files/challenge-2012/1.0.0/{file_name}?download",
output_path=data_path,
output_file_name=file_name,
saving_path=data_path,
)
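
The PhysioNet URLs end in "?download"; the reworked download derives the on-disk file name by stripping that query suffix, as this sketch (hypothetical file name) shows:

import os

url = "https://physionet.org/files/challenge-2012/1.0.0/example.tar.gz?download"
file_name = os.path.basename(url).split("?")[0]  # "example.tar.gz"
suffix = file_name.split(".")[-1]                # "gz", a compression format -> extracted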

static_features = ["Age", "Gender", "ICUType", "Height"]
16 changes: 6 additions & 10 deletions tests/test_dt/test_dt.py
@@ -1,13 +1,9 @@
from pathlib import Path

import duckdb
import numpy as np
import pytest

import ehrdata as ed

TEST_DATA_DIR = Path(__file__).parent / "ehrapy_data2"


@pytest.fixture(scope="function")
def duckdb_connection():
@@ -17,27 +13,27 @@ def duckdb_connection():
con.close()


def test_mimic_iv_omop():
def test_mimic_iv_omop(tmp_path):
duckdb_connection = duckdb.connect()
ed.dt.mimic_iv_omop(data_path=TEST_DATA_DIR, backend_handle=duckdb_connection)
ed.dt.mimic_iv_omop(data_path=tmp_path, backend_handle=duckdb_connection)
assert len(duckdb_connection.execute("SHOW TABLES").df()) == 30
# sanity check of one table
assert duckdb_connection.execute("SELECT * FROM person").df().shape == (100, 18)
duckdb_connection.close()


def test_gibleed_omop():
def test_gibleed_omop(tmp_path):
duckdb_connection = duckdb.connect()
ed.dt.gibleed_omop(data_path=TEST_DATA_DIR, backend_handle=duckdb_connection)
ed.dt.gibleed_omop(data_path=tmp_path, backend_handle=duckdb_connection)
assert len(duckdb_connection.execute("SHOW TABLES").df()) == 36
# sanity check of one table
assert duckdb_connection.execute("SELECT * FROM person").df().shape == (2694, 18)
duckdb_connection.close()


def test_synthea27nj_omop():
def test_synthea27nj_omop(tmp_path):
duckdb_connection = duckdb.connect()
ed.dt.synthea27nj_omop(data_path=TEST_DATA_DIR, backend_handle=duckdb_connection)
ed.dt.synthea27nj_omop(data_path=tmp_path, backend_handle=duckdb_connection)
assert len(duckdb_connection.execute("SHOW TABLES").df()) == 37
# sanity check of one table
assert duckdb_connection.execute("SELECT * FROM person").df().shape == (28, 18)
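
The tests now rely on pytest's built-in tmp_path fixture instead of a shared module-level directory, so each test downloads into its own fresh location. A minimal sketch of the fixture's behavior:

# tmp_path is a per-test pathlib.Path created (and cleaned up) by pytest.
def test_isolated_dir(tmp_path):
    target = tmp_path / "ehrapy_data"
    target.mkdir()
    assert target.exists() and not any(target.iterdir())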
