diff --git a/hed/schema/hed_schema.py b/hed/schema/hed_schema.py
index bbb16fa3..1a2ea941 100644
--- a/hed/schema/hed_schema.py
+++ b/hed/schema/hed_schema.py
@@ -266,9 +266,8 @@ def get_as_xml_string(self, save_merged=True):
     def get_as_dataframes(self, save_merged=False):
         """ Get a dict of dataframes representing this file

-        save_merged: bool
-            If True, this will save the schema as a merged schema if it is a "withStandard" schema.
-            If it is not a "withStandard" schema, this setting has no effect.
+        Parameters:
+            save_merged (bool): If True, returns the dataframes as if merged with the standard schema.

         Returns:
             dataframes(dict): a dict of dataframes you can load as a schema
diff --git a/hed/schema/hed_schema_df_constants.py b/hed/schema/hed_schema_df_constants.py
index cdee9429..4c26f220 100644
--- a/hed/schema/hed_schema_df_constants.py
+++ b/hed/schema/hed_schema_df_constants.py
@@ -1,8 +1,9 @@
 from hed.schema.hed_schema_constants import HedSectionKey
 from hed.schema import hed_schema_constants

-# Known tsv format suffixes
+KEY_COLUMN_NAME = 'rdfs.label'

+# Known tsv format suffixes
 STRUCT_KEY = "Structure"
 TAG_KEY = "Tag"
 UNIT_KEY = "Unit"
diff --git a/hed/schema/schema_io/df_util.py b/hed/schema/schema_io/df_util.py
index 0f4927a6..1cb45e9f 100644
--- a/hed/schema/schema_io/df_util.py
+++ b/hed/schema/schema_io/df_util.py
@@ -12,6 +12,47 @@
 UNKNOWN_LIBRARY_VALUE = 0


+def merge_dataframe_dicts(df_dict1, df_dict2, key_column=constants.KEY_COLUMN_NAME):
+    """ Create a new dictionary of DataFrames where dict2 is merged into dict1.
+
+    Does not validate contents or suffixes.
+
+    Parameters:
+        df_dict1(dict of str: df.DataFrame): dataframes to use as the merge destination.
+        df_dict2(dict of str: df.DataFrame): dataframes to merge into the destination.
+        key_column(str): name of the column that is treated as the key when dataframes are merged
+    """
+
+    result_dict = {}
+    all_keys = set(df_dict1.keys()).union(set(df_dict2.keys()))
+
+    for key in all_keys:
+        if key in df_dict1 and key in df_dict2:
+            result_dict[key] = _merge_dataframes(df_dict1[key], df_dict2[key], key_column)
+        elif key in df_dict1:
+            result_dict[key] = df_dict1[key]
+        else:
+            result_dict[key] = df_dict2[key]
+
+    return result_dict
+
+
+def _merge_dataframes(df1, df2, key_column):
+    # Add columns from df2 that are not in df1, only for rows that are in df1
+
+    if df1.empty or df2.empty or key_column not in df1.columns or key_column not in df2.columns:
+        raise HedFileError(HedExceptions.BAD_COLUMN_NAMES,
+                           f"Both dataframes to be merged must be non-empty and have a '{key_column}' column", "")
+    df1 = df1.copy()
+    for col in df2.columns:
+        if col not in df1.columns and col != key_column:
+            df1 = df1.merge(df2[[key_column, col]], on=key_column, how='left')
+
+    # Fill missing values with ''
+    df1.fillna('', inplace=True)
+
+    return df1
+
 def save_dataframes(base_filename, dataframe_dict):
     """ Writes out the dataframes using the provided suffixes.
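Note: below is a minimal usage sketch of the new df_util helpers above (not part of the patch). The dict key "Tag" matches one of the known suffixes, but the non-key column names here are made up for illustration; the key column defaults to the new KEY_COLUMN_NAME constant ('rdfs.label').

import pandas as pd
from hed.schema.schema_io.df_util import merge_dataframe_dicts

# Destination frames: one "Tag" sheet keyed on the 'rdfs.label' column.
base = {"Tag": pd.DataFrame({"rdfs.label": ["Event", "Action"],
                             "Attributes": ["", "extensionAllowed"]})}

# Frames to merge in: same key column, one extra column, fewer rows.
extra = {"Tag": pd.DataFrame({"rdfs.label": ["Event"],
                              "Description": ["Something that happens"]})}

merged = merge_dataframe_dicts(base, extra)  # key_column defaults to 'rdfs.label'
# Only rows already in 'base' are kept; the 'Event' row gains the Description value,
# and the 'Action' row has the missing Description filled with ''.
print(merged["Tag"])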
diff --git a/hed/schema/schema_io/schema2base.py b/hed/schema/schema_io/schema2base.py index fba2adbf..8ddc9d4b 100644 --- a/hed/schema/schema_io/schema2base.py +++ b/hed/schema/schema_io/schema2base.py @@ -1,208 +1,202 @@ -"""Baseclass for mediawiki/xml writers""" -from hed.schema.hed_schema_constants import HedSectionKey, HedKey -from hed.errors.exceptions import HedFileError, HedExceptions - - -class Schema2Base: - def __init__(self): - # Placeholder output variable - self.output = None - self._save_lib = False - self._save_base = False - self._save_merged = False - self._strip_out_in_library = False - self._schema = None - - def process_schema(self, hed_schema, save_merged=False): - """ - Takes a HedSchema object and returns it in the inherited form(mediawiki, xml, etc) - - Parameters - ---------- - hed_schema : HedSchema - save_merged: bool - If True, this will save the schema as a merged schema if it is a "withStandard" schema. - If it is not a "withStandard" schema, this setting has no effect. - - Returns - ------- - converted_output: Any - Varies based on inherited class - - """ - if not hed_schema.can_save(): - raise HedFileError(HedExceptions.SCHEMA_LIBRARY_INVALID, - "Cannot save a schema merged from multiple library schemas", - hed_schema.filename) - - self._initialize_output() - self._save_lib = False - self._save_base = False - self._strip_out_in_library = True - self._schema = hed_schema # This is needed to save attributes in dataframes for now - if hed_schema.with_standard: - self._save_lib = True - if save_merged: - self._save_base = True - self._strip_out_in_library = False - else: - # Saving a standard schema or a library schema without a standard schema - save_merged = True - self._save_lib = True - self._save_base = True - - self._save_merged = save_merged - - self._output_header(hed_schema.get_save_header_attributes(self._save_merged), hed_schema.prologue) - self._output_tags(hed_schema.tags) - self._output_units(hed_schema.unit_classes) - self._output_section(hed_schema, HedSectionKey.UnitModifiers) - self._output_section(hed_schema, HedSectionKey.ValueClasses) - self._output_section(hed_schema, HedSectionKey.Attributes) - self._output_section(hed_schema, HedSectionKey.Properties) - self._output_footer(hed_schema.epilogue) - - return self.output - - def _initialize_output(self): - raise NotImplementedError("This needs to be defined in the subclass") - - def _output_header(self, attributes, prologue): - raise NotImplementedError("This needs to be defined in the subclass") - - def _output_footer(self, epilogue): - raise NotImplementedError("This needs to be defined in the subclass") - - def _start_section(self, key_class): - raise NotImplementedError("This needs to be defined in the subclass") - - def _end_tag_section(self): - raise NotImplementedError("This needs to be defined in the subclass") - - def _write_tag_entry(self, tag_entry, parent=None, level=0): - raise NotImplementedError("This needs to be defined in the subclass") - - def _write_entry(self, entry, parent_node, include_props=True): - raise NotImplementedError("This needs to be defined in the subclass") - - def _output_tags(self, tags): - schema_node = self._start_section(HedSectionKey.Tags) - - # This assumes .all_entries is sorted in a reasonable way for output. - level_adj = 0 - all_nodes = {} # List of all nodes we've written out. 
- for tag_entry in tags.all_entries: - if self._should_skip(tag_entry): - continue - tag = tag_entry.name - level = tag.count("/") - - # Don't adjust if we're a top level tag(if this is a rooted tag, it will be re-adjusted below) - if not tag_entry.parent_name: - level_adj = 0 - if level == 0: - root_tag = self._write_tag_entry(tag_entry, schema_node, level) - all_nodes[tag_entry.name] = root_tag - else: - # Only output the rooted parent nodes if they have a parent(for duplicates that don't) - if tag_entry.has_attribute(HedKey.InLibrary) and tag_entry.parent and \ - not tag_entry.parent.has_attribute(HedKey.InLibrary) and not self._save_merged: - if tag_entry.parent.name not in all_nodes: - level_adj = level - - parent_node = all_nodes.get(tag_entry.parent_name, schema_node) - child_node = self._write_tag_entry(tag_entry, parent_node, level - level_adj) - all_nodes[tag_entry.name] = child_node - - self._end_tag_section() - - def _output_units(self, unit_classes): - section_node = self._start_section(HedSectionKey.UnitClasses) - - for unit_class_entry in unit_classes.values(): - has_lib_unit = False - if self._should_skip(unit_class_entry): - has_lib_unit = any(unit.attributes.get(HedKey.InLibrary) for unit in unit_class_entry.units.values()) - if not self._save_lib or not has_lib_unit: - continue - - unit_class_node = self._write_entry(unit_class_entry, section_node, not has_lib_unit) - - unit_types = unit_class_entry.units - for unit_entry in unit_types.values(): - if self._should_skip(unit_entry): - continue - - self._write_entry(unit_entry, unit_class_node) - - def _output_section(self, hed_schema, key_class): - parent_node = self._start_section(key_class) - for entry in hed_schema[key_class].values(): - if self._should_skip(entry): - continue - self._write_entry(entry, parent_node) - - def _should_skip(self, entry): - has_lib_attr = entry.has_attribute(HedKey.InLibrary) - if not self._save_base and not has_lib_attr: - return True - if not self._save_lib and has_lib_attr: - return True - return False - - def _attribute_disallowed(self, attribute): - return self._strip_out_in_library and attribute == HedKey.InLibrary - - def _format_tag_attributes(self, attributes): - """ - Takes a dictionary of tag attributes and returns a string with the .mediawiki representation - - Parameters - ---------- - attributes : {str:str} - {attribute_name : attribute_value} - Returns - ------- - str: - The formatted string that should be output to the file. - """ - prop_string = "" - final_props = [] - for prop, value in attributes.items(): - # Never save InLibrary if saving merged. - if self._attribute_disallowed(prop): - continue - if value is True: - final_props.append(prop) - else: - if "," in value: - split_values = value.split(",") - for split_value in split_values: - final_props.append(f"{prop}={split_value}") - else: - final_props.append(f"{prop}={value}") - - if final_props: - interior = ", ".join(final_props) - prop_string = f"{interior}" - - return prop_string - - @staticmethod - def _get_attribs_string_from_schema(header_attributes, sep=" "): - """ - Gets the schema attributes and converts it to a string. 
- - Parameters - ---------- - header_attributes : dict - Attributes to format attributes from - - Returns - ------- - str: - A string of the attributes that can be written to a .mediawiki formatted file - """ - attrib_values = [f"{attr}=\"{value}\"" for attr, value in header_attributes.items()] - final_attrib_string = sep.join(attrib_values) - return final_attrib_string +"""Baseclass for mediawiki/xml writers""" +from hed.schema.hed_schema_constants import HedSectionKey, HedKey +from hed.errors.exceptions import HedFileError, HedExceptions + + +class Schema2Base: + def __init__(self): + # Placeholder output variable + self.output = None + self._save_lib = False + self._save_base = False + self._save_merged = False + self._strip_out_in_library = False + self._schema = None + + def process_schema(self, hed_schema, save_merged=False): + """ Takes a HedSchema object and returns it in the inherited form(mediawiki, xml, etc) + + Parameters: + hed_schema (HedSchema): The schema to be processed. + save_merged (bool): If True, save as merged schema if has "withStandard". + + Returns: + Any: Varies based on inherited class + + """ + if not hed_schema.can_save(): + raise HedFileError(HedExceptions.SCHEMA_LIBRARY_INVALID, + "Cannot save a schema merged from multiple library schemas", + hed_schema.filename) + + self._initialize_output() + self._save_lib = False + self._save_base = False + self._strip_out_in_library = True + self._schema = hed_schema # This is needed to save attributes in dataframes for now + if hed_schema.with_standard: + self._save_lib = True + if save_merged: + self._save_base = True + self._strip_out_in_library = False + else: + # Saving a standard schema or a library schema without a standard schema + save_merged = True + self._save_lib = True + self._save_base = True + + self._save_merged = save_merged + + self._output_header(hed_schema.get_save_header_attributes(self._save_merged), hed_schema.prologue) + self._output_tags(hed_schema.tags) + self._output_units(hed_schema.unit_classes) + self._output_section(hed_schema, HedSectionKey.UnitModifiers) + self._output_section(hed_schema, HedSectionKey.ValueClasses) + self._output_section(hed_schema, HedSectionKey.Attributes) + self._output_section(hed_schema, HedSectionKey.Properties) + self._output_footer(hed_schema.epilogue) + + return self.output + + def _initialize_output(self): + raise NotImplementedError("This needs to be defined in the subclass") + + def _output_header(self, attributes, prologue): + raise NotImplementedError("This needs to be defined in the subclass") + + def _output_footer(self, epilogue): + raise NotImplementedError("This needs to be defined in the subclass") + + def _start_section(self, key_class): + raise NotImplementedError("This needs to be defined in the subclass") + + def _end_tag_section(self): + raise NotImplementedError("This needs to be defined in the subclass") + + def _write_tag_entry(self, tag_entry, parent=None, level=0): + raise NotImplementedError("This needs to be defined in the subclass") + + def _write_entry(self, entry, parent_node, include_props=True): + raise NotImplementedError("This needs to be defined in the subclass") + + def _output_tags(self, tags): + schema_node = self._start_section(HedSectionKey.Tags) + + # This assumes .all_entries is sorted in a reasonable way for output. + level_adj = 0 + all_nodes = {} # List of all nodes we've written out. 
+ for tag_entry in tags.all_entries: + if self._should_skip(tag_entry): + continue + tag = tag_entry.name + level = tag.count("/") + + # Don't adjust if we're a top level tag(if this is a rooted tag, it will be re-adjusted below) + if not tag_entry.parent_name: + level_adj = 0 + if level == 0: + root_tag = self._write_tag_entry(tag_entry, schema_node, level) + all_nodes[tag_entry.name] = root_tag + else: + # Only output the rooted parent nodes if they have a parent(for duplicates that don't) + if tag_entry.has_attribute(HedKey.InLibrary) and tag_entry.parent and \ + not tag_entry.parent.has_attribute(HedKey.InLibrary) and not self._save_merged: + if tag_entry.parent.name not in all_nodes: + level_adj = level + + parent_node = all_nodes.get(tag_entry.parent_name, schema_node) + child_node = self._write_tag_entry(tag_entry, parent_node, level - level_adj) + all_nodes[tag_entry.name] = child_node + + self._end_tag_section() + + def _output_units(self, unit_classes): + section_node = self._start_section(HedSectionKey.UnitClasses) + + for unit_class_entry in unit_classes.values(): + has_lib_unit = False + if self._should_skip(unit_class_entry): + has_lib_unit = any(unit.attributes.get(HedKey.InLibrary) for unit in unit_class_entry.units.values()) + if not self._save_lib or not has_lib_unit: + continue + + unit_class_node = self._write_entry(unit_class_entry, section_node, not has_lib_unit) + + unit_types = unit_class_entry.units + for unit_entry in unit_types.values(): + if self._should_skip(unit_entry): + continue + + self._write_entry(unit_entry, unit_class_node) + + def _output_section(self, hed_schema, key_class): + parent_node = self._start_section(key_class) + for entry in hed_schema[key_class].values(): + if self._should_skip(entry): + continue + self._write_entry(entry, parent_node) + + def _should_skip(self, entry): + has_lib_attr = entry.has_attribute(HedKey.InLibrary) + if not self._save_base and not has_lib_attr: + return True + if not self._save_lib and has_lib_attr: + return True + return False + + def _attribute_disallowed(self, attribute): + return self._strip_out_in_library and attribute == HedKey.InLibrary + + def _format_tag_attributes(self, attributes): + """ + Takes a dictionary of tag attributes and returns a string with the .mediawiki representation + + Parameters + ---------- + attributes : {str:str} + {attribute_name : attribute_value} + Returns + ------- + str: + The formatted string that should be output to the file. + """ + prop_string = "" + final_props = [] + for prop, value in attributes.items(): + # Never save InLibrary if saving merged. + if self._attribute_disallowed(prop): + continue + if value is True: + final_props.append(prop) + else: + if "," in value: + split_values = value.split(",") + for split_value in split_values: + final_props.append(f"{prop}={split_value}") + else: + final_props.append(f"{prop}={value}") + + if final_props: + interior = ", ".join(final_props) + prop_string = f"{interior}" + + return prop_string + + @staticmethod + def _get_attribs_string_from_schema(header_attributes, sep=" "): + """ + Gets the schema attributes and converts it to a string. 
+ + Parameters + ---------- + header_attributes : dict + Attributes to format attributes from + + Returns + ------- + str: + A string of the attributes that can be written to a .mediawiki formatted file + """ + attrib_values = [f"{attr}=\"{value}\"" for attr, value in header_attributes.items()] + final_attrib_string = sep.join(attrib_values) + return final_attrib_string diff --git a/tests/schema/test_hed_schema_io_util_df.py b/tests/schema/test_hed_schema_io_util_df.py new file mode 100644 index 00000000..4a350376 --- /dev/null +++ b/tests/schema/test_hed_schema_io_util_df.py @@ -0,0 +1,129 @@ +import unittest +import pandas as pd +from hed.schema.schema_io.df_util import _merge_dataframes, merge_dataframe_dicts +from hed import HedFileError + + +class TestMergeDataFrames(unittest.TestCase): + def setUp(self): + # Sample DataFrames for testing + self.df1 = pd.DataFrame({ + 'label': [1, 2, 3], + 'A_col1': ['A1', 'A2', 'A3'], + 'A_col2': [10, 20, 30] + }) + + self.df2 = pd.DataFrame({ + 'label': [2, 3, 4], + 'B_col1': ['B2', 'B3', 'B4'], + 'A_col2': [200, 300, 400] + }) + + self.df3 = pd.DataFrame({ + 'A_col1': ['A1', 'A2', 'A3'], + 'label': [2, 3, 4], + 'B_col1': ['B2', 'B3', 'B4'], + 'A_col2': [200, 300, 400], + 'B_col2': [3, 4, 5] + }) + + def test_merge_all_columns_present(self): + # Test that all columns from both DataFrames are present in the result + result = _merge_dataframes(self.df1, self.df2, 'label') + expected_columns = ['label', 'A_col1', 'A_col2', 'B_col1'] + self.assertListEqual(list(result.columns), expected_columns) + + + def test_merge_all_columns_present_different_order(self): + # Test that all columns from both DataFrames are present in the result + result = _merge_dataframes(self.df1, self.df3, 'label') + expected_columns = ['label', 'A_col1', 'A_col2', 'B_col1', 'B_col2'] + self.assertListEqual(list(result.columns), expected_columns) + + def test_merge_rows_from_df1(self): + # Test that only rows from df1 are present in the result + result = _merge_dataframes(self.df1, self.df2, 'label') + expected_labels = [1, 2, 3] # Only labels present in df1 + self.assertListEqual(list(result['label']), expected_labels) + + def test_merge_add_columns_from_df2(self): + # Test that columns from df2 are added to df1 + result = _merge_dataframes(self.df1, self.df2, 'label') + self.assertIn('B_col1', result.columns) + self.assertEqual(result.loc[result['label'] == 2, 'B_col1'].values[0], 'B2') + self.assertEqual(result.loc[result['label'] == 3, 'B_col1'].values[0], 'B3') + + def test_fill_missing_values(self): + # Test that missing values are filled with '' + result = _merge_dataframes(self.df1, self.df2, 'label') + self.assertEqual(result.loc[result['label'] == 1, 'B_col1'].values[0], '') + + def test_reset_index(self): + # Test that the index is reset correctly + result = _merge_dataframes(self.df1, self.df2, 'label') + expected_index = [0, 1, 2] + self.assertListEqual(list(result.index), expected_index) + + def test_missing_label_column_raises_error(self): + # Test that if one of the DataFrames does not have 'label' column, a HedFileError is raised + df_no_label = pd.DataFrame({ + 'A_col1': ['A1', 'A2', 'A3'], + 'A_col2': [10, 20, 30] + }) + with self.assertRaises(HedFileError): + _merge_dataframes(self.df1, df_no_label, 'label') + with self.assertRaises(HedFileError): + _merge_dataframes(df_no_label, self.df2, 'label') + + def test_merge_source_empty(self): + # Test that throws an exception if one frame is empty + with self.assertRaises(HedFileError): + 
_merge_dataframes(pd.DataFrame(), self.df1, 'label') + with self.assertRaises(HedFileError): + _merge_dataframes(self.df1, pd.DataFrame(), 'label') + + +class TestMergeDataFrameDicts(unittest.TestCase): + def setUp(self): + # Sample DataFrames for testing + self.df1 = pd.DataFrame({ + 'label': [1, 2, 3], + 'A_col1': ['A1', 'A2', 'A3'], + 'A_col2': [10, 20, 30] + }) + + self.df2 = pd.DataFrame({ + 'label': [2, 3, 4], + 'B_col1': ['B2', 'B3', 'B4'], + 'A_col2': [200, 300, 400] + }) + + self.dict1 = {'df1': self.df1} + self.dict2 = {'df1': self.df2, 'df2': self.df2} + + def test_merge_common_keys(self): + # Test that common keys are merged using _merge_dataframes + result = merge_dataframe_dicts(self.dict1, self.dict2, 'label') + expected_columns = ['label', 'A_col1', 'A_col2', 'B_col1'] + self.assertIn('df1', result) + self.assertListEqual(list(result['df1'].columns), expected_columns) + + def test_merge_unique_keys(self): + # Test that unique keys are preserved in the result dictionary + result = merge_dataframe_dicts(self.dict1, self.dict2, 'label') + self.assertIn('df2', result) + self.assertTrue(result['df2'].equals(self.df2)) + + def test_merge_no_common_keys(self): + # Test merging dictionaries with no common keys + dict1 = {'df1': self.df1} + dict2 = {'df2': self.df2} + result = merge_dataframe_dicts(dict1, dict2, 'label') + self.assertIn('df1', result) + self.assertIn('df2', result) + self.assertTrue(result['df1'].equals(self.df1)) + self.assertTrue(result['df2'].equals(self.df2)) + + +if __name__ == '__main__': + unittest.main()
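For reviewers exercising the rewritten Schema2Base by hand, here is a minimal sketch (not part of the patch) of how the internal _format_tag_attributes helper formats attributes: boolean attributes are written bare, and comma-separated values expand into repeated name=value pairs. The _Dummy subclass exists only to stub the abstract writer hooks for this demonstration; the attribute names are illustrative.

from hed.schema.schema_io.schema2base import Schema2Base

class _Dummy(Schema2Base):
    # Stub out the abstract writer hooks; none of them are called here.
    def _initialize_output(self): pass
    def _output_header(self, attributes, prologue): pass
    def _output_footer(self, epilogue): pass
    def _start_section(self, key_class): pass
    def _end_tag_section(self): pass
    def _write_tag_entry(self, tag_entry, parent=None, level=0): pass
    def _write_entry(self, entry, parent_node, include_props=True): pass

writer = _Dummy()
# A True value is written as a bare attribute name; a comma-separated value is split.
print(writer._format_tag_attributes({"suggestedTag": "Event,Action", "requireChild": True}))
# -> suggestedTag=Event, suggestedTag=Action, requireChild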