Skip to content

Commit

Permalink
add fsspec to rdheader
Browse files (browse the repository at this point in the history)
  • Loading branch information
briangow committed Jan 6, 2025
1 parent c6d4fd9 commit fce4d62
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ dependencies = [
"soundfile >= 0.10.0",
"matplotlib >= 3.2.2",
"requests >= 2.8.1",
"fsspec >= 2023.10.0",
"aiohttp >= 3.11.11",
]
dynamic = ["version"]

Expand Down
14 changes: 11 additions & 3 deletions wfdb/io/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import posixpath

import fsspec
import numpy as np

from wfdb.io import _url
Expand All @@ -12,6 +13,9 @@
PN_INDEX_URL = "https://physionet.org/files/"
PN_CONTENT_URL = "https://physionet.org/content/"

# Cloud protocols
CLOUD_PROTOCOLS = ["az:", "azureml:", "s3:", "gs:"]


class Config(object):
"""
Expand Down Expand Up @@ -101,11 +105,15 @@ def _stream_header(file_name: str, pn_dir: str) -> str:
The text contained in the header file
"""
# Full url of header location
url = posixpath.join(config.db_index_url, pn_dir, file_name)
# Full cloud url
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
url = posixpath.join(pn_dir, file_name)
# Full physionet database url
else:
url = posixpath.join(config.db_index_url, pn_dir, file_name)

# Get the content of the remote file
with _url.openurl(url, "rb") as f:
with fsspec.open(url, "rb") as f:
content = f.read()

return content.decode("iso-8859-1")
Expand Down
10 changes: 7 additions & 3 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import re

import fsspec
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -1826,8 +1827,11 @@ def rdheader(record_name, pn_dir=None, rd_segments=False):
dir_name, base_record_name = os.path.split(record_name)
dir_name = os.path.abspath(dir_name)

# Construct the download path using the database version
if (pn_dir is not None) and ("." not in pn_dir):
# If this is a cloud path we leave it as is
if (pn_dir is not None) and any(pn_dir.startswith(proto) for proto in download.CLOUD_PROTOCOLS):
pass
# If it isn't a cloud path, construct the download path using the database version
elif (pn_dir is not None) and ("." not in pn_dir):
dir_list = pn_dir.split("/")
pn_dir = posixpath.join(
dir_list[0], download.get_version(dir_list[0]), *dir_list[1:]
Expand All @@ -1836,7 +1840,7 @@ def rdheader(record_name, pn_dir=None, rd_segments=False):
# Read the local or remote header file.
file_name = f"{base_record_name}.hea"
if pn_dir is None:
with open(
with fsspec.open(
os.path.join(dir_name, file_name),
"r",
encoding="ascii",
Expand Down

0 comments on commit fce4d62

Please sign in to comment.