Skip to content

Commit

Permalink
Merge pull request #23 from analysiscenter/fix_loader
Browse files Browse the repository at this point in the history
Fix loader
  • Loading branch information
alexanderkuvaev authored Mar 5, 2019
2 parents 822f7c8 + ba70d19 commit e36c1df
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
2 changes: 1 addition & 1 deletion cardio/batchflow
23 changes: 11 additions & 12 deletions cardio/core/ecg_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class EcgBatch(bf.Batch):
``batch.resample_signals(fs)``.
"""

components = "signal", "annotation", "meta", "target"

def __init__(self, index, preloaded=None, unique_labels=None):
super().__init__(index, preloaded)
self.signal = self.array_of_nones
Expand All @@ -142,11 +144,6 @@ def __init__(self, index, preloaded=None, unique_labels=None):
self._label_binarizer = None
self.unique_labels = unique_labels

@property
def components(self):
"""tuple of str: Data components names."""
return "signal", "annotation", "meta", "target"

@property
def array_of_nones(self):
"""1-D ndarray: ``NumPy`` array with ``None`` values."""
Expand Down Expand Up @@ -249,12 +246,16 @@ def load(self, src=None, fmt=None, components=None, ann_ext=None, *args, **kwarg
"""
if components is None:
components = self.components
components = np.asarray(components).ravel()
if (fmt == "csv" or fmt is None and isinstance(src, pd.Series)) and np.all(components == "target"):
components = np.unique(components).ravel().tolist()

if (fmt == "csv" or fmt is None and isinstance(src, pd.Series)) and components == ['target']:
return self._load_labels(src)
if fmt in ["wfdb", "dicom", "edf", "wav", "xml"]:
unexpected_components = set(components) - set(self.components)
if unexpected_components:
raise ValueError('Unexpected components: ', unexpected_components)
return self._load_data(src=src, fmt=fmt, components=components, ann_ext=ann_ext, *args, **kwargs)
return super().load(src, fmt, components, *args, **kwargs)
return super().load(src=src, fmt=fmt, components=components, **kwargs)

@bf.inbatch_parallel(init="indices", post="_assemble_load", target="threads")
def _load_data(self, index, src=None, fmt=None, components=None, *args, **kwargs):
Expand Down Expand Up @@ -1088,10 +1089,8 @@ def _get_segmentation_arg(arg, arg_name, target):
arg = arg.get(target)
if arg is None:
raise KeyError("Undefined {} for target {}".format(arg_name, target))
else:
return arg
else:
raise ValueError("Unsupported {} type".format(arg_name))
return arg
raise ValueError("Unsupported {} type".format(arg_name))

@staticmethod
def _check_segmentation_args(signal, target, length, arg, arg_name):
Expand Down

0 comments on commit e36c1df

Please sign in to comment.