diff --git a/src/koza/cli_utils.py b/src/koza/cli_utils.py
index e62b7e9..7616b51 100644
--- a/src/koza/cli_utils.py
+++ b/src/koza/cli_utils.py
@@ -127,14 +127,16 @@ def _check_row_count(type: Literal["node", "edge"]):
     _check_row_count("edge")
 
 
-def split_file(file: str,
-    fields: str,
-    format: OutputFormat = OutputFormat.tsv,
-    remove_prefixes: bool = False,
-    output_dir: str = "./output"):
+def split_file(
+    file: str,
+    fields: str,
+    format: OutputFormat = OutputFormat.tsv,
+    remove_prefixes: bool = False,
+    output_dir: str = "./output",
+):
     db = duckdb.connect(":memory:")
 
-    #todo: validate that each of the fields is actually a column in the file
+    # todo: validate that each of the fields is actually a column in the file
     if format == OutputFormat.tsv:
         read_file = f"read_csv('{file}')"
     elif format == OutputFormat.json:
@@ -179,14 +181,18 @@ def get_filename_suffix(name):
     for row in list_of_value_dicts:
         # export to a tsv file named with the values of the pivot fields
         where_clause = ' AND '.join([f"{k} = '{row[k]}'" for k in keys])
-        file_name = output_dir + "/" + get_filename_prefix(file) + generate_filename_from_row(row) + get_filename_suffix(file)
+        file_name = (
+            output_dir + "/" + get_filename_prefix(file) + generate_filename_from_row(row) + get_filename_suffix(file)
+        )
         print(f"writing {file_name}")
-        db.execute(f"""
+        db.execute(
+            f"""
             COPY (
                 SELECT * FROM {read_file}
                 WHERE {where_clause}
             ) TO '{file_name}' (HEADER, DELIMITER '\t');
-        """)
+        """
+        )
 
 
 def validate_file(
diff --git a/src/koza/main.py b/src/koza/main.py
index fa4a27b..2dee4b9 100755
--- a/src/koza/main.py
+++ b/src/koza/main.py
@@ -65,15 +65,20 @@ def validate(
     """Validate a source file"""
     validate_file(file, format, delimiter, header_delimiter, skip_blank_lines)
 
+
 @typer_app.command()
 def split(
     file: str = typer.Argument(..., help="Path to the source kgx file to be split"),
     fields: str = typer.Argument(..., help="Comma separated list of fields to split on"),
-    remove_prefixes: bool = typer.Option(False, help="Remove prefixes from the file names for values from the specified fields. (e.g, NCBIGene:9606 becomes 9606"),
+    remove_prefixes: bool = typer.Option(
+        False,
+        help="Remove prefixes from the file names for values from the specified fields. (e.g, NCBIGene:9606 becomes 9606",
+    ),
     output_dir: str = typer.Option(default="output", help="Path to output directory"),
 ):
     """Split a file by fields"""
-    split_file(file, fields,remove_prefixes=remove_prefixes, output_dir=output_dir)
+    split_file(file, fields, remove_prefixes=remove_prefixes, output_dir=output_dir)
+
 
 if __name__ == "__main__":
     typer_app()
diff --git a/src/koza/model/source.py b/src/koza/model/source.py
index 79e6565..7d44e32 100644
--- a/src/koza/model/source.py
+++ b/src/koza/model/source.py
@@ -87,13 +87,13 @@ def __next__(self) -> Dict[str, Any]:
         return row
 
     def _get_row(self):
-        #If we built a filter for this source, run extra code to validate each row for inclusion in the final output.
+        # If we built a filter for this source, run extra code to validate each row for inclusion in the final output.
         if self._filter:
            row = next(self._reader)
            reject_current_row = not self._filter.include_row(row)
-            #If the filter says we shouldn't include the current row; we filter it out and move onto the next row.
-            #We'll only break out of the following loop if "reject_current_row" is false (i.e. include_row is True/we
-            #have a valid row to return) or we hit a StopIteration exception from self._reader. 
+            # If the filter says we shouldn't include the current row; we filter it out and move onto the next row.
+            # We'll only break out of the following loop if "reject_current_row" is false (i.e. include_row is True/we
+            # have a valid row to return) or we hit a StopIteration exception from self._reader.
             while reject_current_row:
                 row = next(self._reader)
                 reject_current_row = not self._filter.include_row(row)
diff --git a/src/koza/utils/row_filter.py b/src/koza/utils/row_filter.py
index 61a15f0..daa0633 100644
--- a/src/koza/utils/row_filter.py
+++ b/src/koza/utils/row_filter.py
@@ -54,22 +54,22 @@ def include_row(self, row) -> bool:
         return include_row
 
     def inlist(self, column_value, filter_values):
-        #Check if the passed in column is exactly matched against
-        #For a filter_list of ['abc','def','ghi']; this will be true
-        #for column_value 'abc' but not 'abcde.'
+        # Check if the passed in column is exactly matched against
+        # For a filter_list of ['abc','def','ghi']; this will be true
+        # for column_value 'abc' but not 'abcde.'
         col_exact_match = column_value in filter_values
-        #The following iterates through all filters and will return true if
-        #the text of the filter is found within the column_value.
-        #So for the above example this boolean will return True, because :'abc' in 'abcde': returns True.
-        if(type(column_value)==str):
+        # The following iterates through all filters and will return true if
+        # the text of the filter is found within the column_value.
+        # So for the above example this boolean will return True, because :'abc' in 'abcde': returns True.
+        if type(column_value) == str:
             col_inexact_match = any([filter_value in column_value for filter_value in filter_values])
         else:
             col_inexact_match = False
         return col_exact_match or col_inexact_match
-    
+
     def inlist_exact(self, column_value, filter_values):
-        #Check if the passed in column is exactly matched against
-        #For a filter_list of ['abc','def','ghi']; this will be true
-        #for column_value 'abc' but not 'abcde.'
+        # Check if the passed in column is exactly matched against
+        # For a filter_list of ['abc','def','ghi']; this will be true
+        # for column_value 'abc' but not 'abcde.'
         col_exact_match = column_value in filter_values
         return col_exact_match
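
Note on the cli_utils.py hunks: the reformatted block is the core of split_file, which writes one output file per distinct combination of the split fields by wrapping a SELECT in a DuckDB COPY statement. The sketch below shows that pattern in isolation; the toy input, the sample row, and the output file name are hypothetical, and the real code derives file names via get_filename_prefix, generate_filename_from_row, and get_filename_suffix, which are not shown in this diff.

# Sketch only, not part of the diff: one per-value COPY as assembled by split_file.
import duckdb

# Create a tiny hypothetical KGX-style node file so the sketch is self-contained.
with open("toy_nodes.tsv", "w") as f:
    f.write("id\tcategory\n")
    f.write("NCBIGene:9606\tbiolink:Gene\n")
    f.write("MONDO:0005148\tbiolink:Disease\n")

db = duckdb.connect(":memory:")
read_file = "read_csv('toy_nodes.tsv')"
row = {"category": "biolink:Gene"}  # one DISTINCT combination of the split fields
where_clause = " AND ".join([f"{k} = '{row[k]}'" for k in row])  # "category = 'biolink:Gene'"
file_name = "toy_nodes_biolink_Gene.tsv"  # hypothetical; the real name comes from the helper functions

# Same COPY shape as in split_file: filter the source rows and write them as TSV.
db.execute(
    f"""
    COPY (
        SELECT * FROM {read_file}
        WHERE {where_clause}
    ) TO '{file_name}' (HEADER, DELIMITER '\t');
    """
)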
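
Note on the row_filter.py hunk: the comments being reformatted describe two different matching rules, and the distinction is easy to miss. Below is a standalone sketch of the same logic, with the method bodies lifted out of RowFilter as plain functions (minus self), using the 'abc'/'abcde' example from those comments.

# Sketch only, not part of the diff: the two matching rules from row_filter.py.
def inlist(column_value, filter_values):
    # True on an exact match, or (for strings) when any filter value occurs
    # inside column_value as a substring.
    col_exact_match = column_value in filter_values
    if type(column_value) == str:
        col_inexact_match = any(filter_value in column_value for filter_value in filter_values)
    else:
        col_inexact_match = False
    return col_exact_match or col_inexact_match


def inlist_exact(column_value, filter_values):
    # True only on an exact match.
    return column_value in filter_values


filters = ['abc', 'def', 'ghi']
print(inlist('abc', filters))          # True: exact match
print(inlist('abcde', filters))        # True: 'abc' occurs inside 'abcde'
print(inlist_exact('abc', filters))    # True: exact match
print(inlist_exact('abcde', filters))  # False: substring matches are ignored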