Skip to content

Commit

Permalink
Merge pull request #928 from IanCa/develop
Browse files Browse the repository at this point in the history
Update dataframe loading/saving to allow passing a folder name
  • Loading branch information
VisLab authored May 17, 2024
2 parents c61c7af + 40964c2 commit 7fb7317
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 10 deletions.
6 changes: 5 additions & 1 deletion hed/schema/hed_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,11 @@ def save_as_xml(self, filename, save_merged=True):
opened_file.write(xml_string)

def save_as_dataframes(self, base_filename, save_merged=False):
""" Save as mediawiki to a file.
""" Save as dataframes to a folder of files.
If base_filename has a .tsv suffix, save directly to the indicated location.
If base_filename is a directory (does NOT have a .tsv suffix), save the contents into a directory of that name.
The subfiles are named the same. e.g. HED8.3.0/HED8.3.0_Tag.tsv
base_filename: str
save filename. A suffix will be added to most, e.g. _Tag
Expand Down
2 changes: 1 addition & 1 deletion hed/schema/hed_schema_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_schema(hed_path, schema_namespace=None, schema=None, name=None):
hed_schema = SchemaLoaderXML.load(hed_path, schema=schema, name=name)
elif hed_path.lower().endswith(".mediawiki"):
hed_schema = SchemaLoaderWiki.load(hed_path, schema=schema, name=name)
elif hed_path.lower().endswith(".tsv"):
elif hed_path.lower().endswith(".tsv") or os.path.isdir(hed_path):
if schema is not None:
raise HedFileError(HedExceptions.INVALID_HED_FORMAT,
"Cannot pass a schema to merge into spreadsheet loading currently.", filename=name)
Expand Down
11 changes: 9 additions & 2 deletions hed/schema/schema_io/df2schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,19 @@ def convert_filenames_to_dict(filenames):
Parameters:
filenames(str or None or list or dict): The list to convert to a dict
If a string with a .tsv suffix: Save to that location, adding the suffix to each .tsv file
If a string with no .tsv suffix: Save to that folder, with the contents being the separate .tsv files.
Returns:
filename_dict(str: str): The required suffix to filename mapping"""
result_filenames = {}
if isinstance(filenames, str):
base, base_ext = os.path.splitext(filenames)
if filenames.endswith(".tsv"):
base, base_ext = os.path.splitext(filenames)
else:
# Load as foldername/foldername_suffix.tsv
base_dir = filenames
base_filename = os.path.split(base_dir)[1]
base = os.path.join(base_dir, base_filename)
for suffix in constants.DF_SUFFIXES:
filename = f"{base}_{suffix}.tsv"
result_filenames[suffix] = filename
Expand Down
14 changes: 13 additions & 1 deletion hed/schema/schema_io/ontology_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,24 @@ def save_dataframes(base_filename, dataframe_dict):
Does not validate contents or suffixes.
If base_filename has a .tsv suffix, save directly to the indicated location.
If base_filename is a directory (does NOT have a .tsv suffix), save the contents into a directory of that name.
The subfiles are named the same. e.g. HED8.3.0/HED8.3.0_Tag.tsv
Parameters:
base_filename(str): The base filename to use. Output is {base_filename}_{suffix}.tsv
See DF_SUFFIXES for all expected names.
dataframe_dict(dict of str: df.DataFrame): The list of files to save out. No validation is done.
"""
base, base_ext = os.path.splitext(base_filename)
if base_filename.lower().endswith(".tsv"):
base, base_ext = os.path.splitext(base_filename)
base_dir, base_name = os.path.split(base)
else:
# Assumed as a directory name
base_dir = base_filename
base_filename = os.path.split(base_dir)[1]
base = os.path.join(base_dir, base_filename)
os.makedirs(base_dir, exist_ok=True)
for suffix, dataframe in dataframe_dict.items():
filename = f"{base}_{suffix}.tsv"
with open(filename, mode='w', encoding='utf-8') as opened_file:
Expand Down
7 changes: 5 additions & 2 deletions hed/schema/schema_io/schema2df.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, get_as_ids=False):
"""
super().__init__()
self._get_as_ids = get_as_ids
self._tag_rows = []

def _get_object_name_and_id(self, object_name, include_prefix=False):
""" Get the adjusted name and ID for the given object type.
Expand Down Expand Up @@ -67,6 +68,7 @@ def _initialize_output(self):
constants.OBJECT_KEY: pd.DataFrame(columns=constants.property_columns, dtype=str),
constants.ATTRIBUTE_PROPERTY_KEY: pd.DataFrame(columns=constants.property_columns_reduced, dtype=str),
}
self._tag_rows = []

def _create_and_add_object_row(self, base_object, attributes="", description=""):
name, full_hed_id = self._get_object_name_and_id(base_object)
Expand Down Expand Up @@ -95,7 +97,7 @@ def _start_section(self, key_class):
pass

def _end_tag_section(self):
pass
self.output[constants.TAG_KEY] = pd.DataFrame(self._tag_rows, columns=constants.tag_columns, dtype=str)

def _write_tag_entry(self, tag_entry, parent_node=None, level=0):
tag_id = tag_entry.attributes.get(HedKey.HedID, "")
Expand All @@ -108,7 +110,8 @@ def _write_tag_entry(self, tag_entry, parent_node=None, level=0):
constants.description: tag_entry.description,
constants.equivalent_to: self._get_tag_equivalent_to(tag_entry),
}
self.output[constants.TAG_KEY].loc[len(self.output[constants.TAG_KEY])] = new_row
# Todo: do other sections like this as well for efficiency
self._tag_rows.append(new_row)

def _write_entry(self, entry, parent_node, include_props=True):
df_key = section_key_to_df.get(entry.section_key)
Expand Down
2 changes: 1 addition & 1 deletion hed/tools/remodeling/operations/summarize_hed_tags_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def save_visualizations(self, save_dir, file_formats=['.svg'], individual_summar
specifics = overall_summary.get("Specifics", {})
word_dict = self.summary_to_dict(specifics, scale_adjustment=wc["scale_adjustment"])

tag_wc = tag_word_cloud.tag_word_cloud.create_wordcloud(word_dict, mask_path=wc["mask_path"],
tag_wc = tag_word_cloud.create_wordcloud(word_dict, mask_path=wc["mask_path"],
width=wc["width"], height=wc["height"],
prefer_horizontal=wc["prefer_horizontal"], background_color=wc["background_color"],
min_font_size=wc["min_font_size"], max_font_size=wc["max_font_size"],
Expand Down
4 changes: 2 additions & 2 deletions hed/tools/visualization/word_cloud_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from PIL import Image, ImageFilter
from matplotlib import cm
import matplotlib as mp1
import wordcloud as wcloud


Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(self, colormap='nipy_spectral', color_range=(0.0, 0.5), color_step_
This is the speed at which it goes through the range chosen.
.25 means it will go through 1/4 of the range each pick.
"""
self.colormap = cm.get_cmap(colormap)
self.colormap = mp1.colormaps[colormap]
self.color_range = color_range
self.color_step_range = color_step_range
self.current_fraction = random.uniform(0, 1) # Start at a random point
Expand Down
23 changes: 23 additions & 0 deletions tests/schema/test_hed_schema_io_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,26 @@ def test_from_dataframes(self):
reloaded_schema = from_dataframes(dfs)
self.assertEqual(schema, reloaded_schema)

def test_save_load_location(self):
    """Round-trip a schema through a directory-style (no .tsv suffix) dataframe save."""
    original = load_schema_version("8.3.0")
    base_name = "test_output"
    target_dir = f"{self.output_folder}{base_name}"
    original.save_as_dataframes(target_dir)

    # Saving to a directory should place <name>_Tag.tsv (and siblings) inside it.
    tag_file = os.path.join(target_dir, base_name + "_Tag.tsv")
    self.assertTrue(os.path.exists(tag_file))

    # Loading the directory back should reproduce the original schema exactly.
    round_tripped = load_schema(target_dir)

    self.assertEqual(original, round_tripped)

def test_save_load_location2(self):
    """Round-trip a schema through a single-base-filename (.tsv suffix) dataframe save."""
    original = load_schema_version("8.3.0")
    base_name = "test_output"
    tsv_target = f"{self.output_folder}{base_name}.tsv"
    original.save_as_dataframes(tsv_target)

    # A .tsv base saves sibling files next to it, each with a per-dataframe suffix.
    tag_file = f"{self.output_folder}{base_name}_Tag.tsv"
    self.assertTrue(os.path.exists(tag_file))

    # Loading the .tsv base path back should reproduce the original schema exactly.
    round_tripped = load_schema(tsv_target)

    self.assertEqual(original, round_tripped)

0 comments on commit 7fb7317

Please sign in to comment.