diff --git a/setup.py b/setup.py index 171ef74..1b1a9ab 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ "bioframe==0.3.3", "sparse==0.13.0", "multiprocess>=0.70.13", - "numba>=0.57.0" + "numba>=0.57.0", ] test_requirements = [ diff --git a/spoc/cli.py b/spoc/cli.py index 475cb79..dcad312 100644 --- a/spoc/cli.py +++ b/spoc/cli.py @@ -61,9 +61,7 @@ def bin_contacts( file_manager = FileManager(use_dask=True) contacts = file_manager.load_contacts(contact_path) # binning - binner = GenomicBinner( - bin_size=bin_size - ) + binner = GenomicBinner(bin_size=bin_size) pixels = binner.bin_contacts(contacts, same_chromosome=same_chromosome) # persisting file_manager.write_pixels(pixel_path, pixels) @@ -81,9 +79,7 @@ def merge_contacts(contact_paths, output): """Functionality to merge annotated fragments""" file_manager = FileManager(use_dask=True) manipulator = ContactManipulator() - contact_files = [ - file_manager.load_contacts(path) for path in contact_paths - ] + contact_files = [file_manager.load_contacts(path) for path in contact_paths] merged = manipulator.merge_contacts(contact_files) file_manager.write_multiway_contacts(output, merged) diff --git a/spoc/contacts.py b/spoc/contacts.py index 62705ac..3141d5b 100644 --- a/spoc/contacts.py +++ b/spoc/contacts.py @@ -1,6 +1,6 @@ """Managing multi-way contacts.""" -from __future__ import annotations # needed for self reference in type hints +from __future__ import annotations # needed for self reference in type hints from typing import List, Union, Optional import pandas as pd import dask.dataframe as dd @@ -23,13 +23,16 @@ def __init__( binary_labels_equal: bool = False, symmetry_flipped: bool = False, ) -> None: - self.contains_metadata = "metadata_1" in contact_frame.columns # All contacts contain at least one fragment + self.contains_metadata = ( + "metadata_1" in contact_frame.columns + ) # All contacts contain at least one fragment if number_fragments is None: self.number_fragments = self._guess_number_fragments(contact_frame) else: self.number_fragments = number_fragments self._schema = ContactSchema( - number_fragments=self.number_fragments, contains_metadata=self.contains_metadata + number_fragments=self.number_fragments, + contains_metadata=self.contains_metadata, ) if isinstance(contact_frame, pd.DataFrame): self.is_dask = False @@ -79,7 +82,6 @@ def get_chromosome_values(self) -> List[str]: output.update(self.data[f"chrom_{i+1}"].unique()) return output - @property def data(self): return self._data @@ -88,7 +90,6 @@ def data(self): def data(self, contact_frame): self._data = self._schema.validate(contact_frame) - def __repr__(self) -> str: return f"" @@ -117,7 +118,15 @@ def merge_contacts(self, merge_list: List[Contacts]) -> Contacts: @staticmethod def _generate_rename_columns(order, start_index=1): - columns = ["chrom", "start", "end", "mapping_quality", "align_score", "align_base_qscore", "metadata"] + columns = [ + "chrom", + "start", + "end", + "mapping_quality", + "align_score", + "align_base_qscore", + "metadata", + ] rename_columns = {} for i in range(len(order)): for column in columns: @@ -129,18 +138,25 @@ def _generate_rename_columns(order, start_index=1): @staticmethod def _get_label_combinations(labels, order): sorted_labels = sorted(labels) - combinations = set(tuple(sorted(i)) for i in product(sorted_labels, repeat=order)) + combinations = set( + tuple(sorted(i)) for i in product(sorted_labels, repeat=order) + ) return combinations @staticmethod def _get_combination_splits(combination): splits = [] - for index,(i, j) in enumerate(zip(combination[:-1], combination[1:])): + for index, (i, j) in enumerate(zip(combination[:-1], combination[1:])): if i != j: splits.append(index + 2) return [1] + splits + [len(combination) + 1] - def _flip_unlabelled_contacts(self, df: DataFrame, start_index:Optional[int]=None, end_index:Optional[int]=None) -> DataFrame: + def _flip_unlabelled_contacts( + self, + df: DataFrame, + start_index: Optional[int] = None, + end_index: Optional[int] = None, + ) -> DataFrame: """Flips contacts""" fragment_order = max(int(i.split("_")[1]) for i in df.columns if "start" in i) if start_index is None: @@ -150,20 +166,29 @@ def _flip_unlabelled_contacts(self, df: DataFrame, start_index:Optional[int]=Non subsets = [] for perm in permutations(range(start_index, end_index)): query = "<=".join([f"start_{i}" for i in perm]) - subsets.append(df.query(query).rename(columns=self._generate_rename_columns(perm, start_index))) + subsets.append( + df.query(query).rename( + columns=self._generate_rename_columns(perm, start_index) + ) + ) # determine which method to use for concatenation if isinstance(df, pd.DataFrame): result = pd.concat(subsets).sort_index() # this is needed if there are reads with equal start positions - result = result.loc[~result.index.duplicated(keep='first')] + result = result.loc[~result.index.duplicated(keep="first")] else: - result = dd.concat(subsets).reset_index()\ - .sort_values("index")\ - .drop_duplicates(subset=['index'])\ - .set_index("index") + result = ( + dd.concat(subsets) + .reset_index() + .sort_values("index") + .drop_duplicates(subset=["index"]) + .set_index("index") + ) return result - def _flip_labelled_contacts(self, df: DataFrame, label_values: List[str]) -> DataFrame: + def _flip_labelled_contacts( + self, df: DataFrame, label_values: List[str] + ) -> DataFrame: """Flips labelled contacts""" fragment_order = max(int(i.split("_")[1]) for i in df.columns if "start" in i) label_combinations = self._get_label_combinations(label_values, fragment_order) @@ -171,20 +196,30 @@ def _flip_labelled_contacts(self, df: DataFrame, label_values: List[str]) -> Dat for combination in label_combinations: splits = self._get_combination_splits(combination) # separate out name constanc_columns - query = " and ".join([f"metadata_{i} == '{j}'" for i, j in enumerate(combination, 1)]) + query = " and ".join( + [f"metadata_{i} == '{j}'" for i, j in enumerate(combination, 1)] + ) candidate_frame = df.query(query) if len(candidate_frame) == 0: continue - constant_df, variable_df = candidate_frame[['read_name', 'read_length']], candidate_frame.drop(['read_name', 'read_length'], axis=1) + constant_df, variable_df = candidate_frame[ + ["read_name", "read_length"] + ], candidate_frame.drop(["read_name", "read_length"], axis=1) split_frames = [constant_df] for start, end in zip(splits, splits[1:]): - # get all columns wiht nubmer between start and - subset_columns = [i for i in variable_df.columns if start <= int(i.split("_")[-1]) < end] + # get all columns wiht nubmer between start and + subset_columns = [ + i + for i in variable_df.columns + if start <= int(i.split("_")[-1]) < end + ] # if only columns is present, no need for flipping if start + 1 == end: split_frame = variable_df[subset_columns] else: - split_frame = self._flip_unlabelled_contacts(variable_df[subset_columns], start, end) + split_frame = self._flip_unlabelled_contacts( + variable_df[subset_columns], start, end + ) split_frames.append(split_frame) # concatenate split frames if isinstance(df, pd.DataFrame): @@ -196,74 +231,103 @@ def _flip_labelled_contacts(self, df: DataFrame, label_values: List[str]) -> Dat if isinstance(df, pd.DataFrame): result = pd.concat(subsets).sort_index() # this is needed if there are reads with equal start positions - result = result.loc[~result.index.duplicated(keep='first')] + result = result.loc[~result.index.duplicated(keep="first")] else: - result = dd.concat(subsets).reset_index()\ - .sort_values("index")\ - .drop_duplicates(subset=['index'])\ - .set_index("index") + result = ( + dd.concat(subsets) + .reset_index() + .sort_values("index") + .drop_duplicates(subset=["index"]) + .set_index("index") + ) return result - def sort_labels(self, contacts:Contacts) -> Contacts: + def sort_labels(self, contacts: Contacts) -> Contacts: """Sorts labels in ascending, alphabetical order""" if not contacts.contains_metadata: - raise ValueError("Sorting labels for unlabelled contacts is not implemented.") + raise ValueError( + "Sorting labels for unlabelled contacts is not implemented." + ) # get label values. label_values = contacts.get_label_values() # iterate over all permutations of label values subsets = [] for perm in product(label_values, repeat=contacts.number_fragments): - query = " and ".join([f"metadata_{i+1} == '{j}'" for i, j in enumerate(perm)]) + query = " and ".join( + [f"metadata_{i+1} == '{j}'" for i, j in enumerate(perm)] + ) desired_order = [i + 1 for i in np.argsort(perm)] - subsets.append(contacts.data.query(query).rename(columns=self._generate_rename_columns(desired_order))) + subsets.append( + contacts.data.query(query).rename( + columns=self._generate_rename_columns(desired_order) + ) + ) # determine which method to use for concatenation if contacts.is_dask: # this is a bit of a hack to get the index sorted. Dask does not support index sorting - result = dd.concat(subsets).reset_index()\ - .sort_values("index")\ - .set_index("index") + result = ( + dd.concat(subsets).reset_index().sort_values("index").set_index("index") + ) else: result = pd.concat(subsets).sort_index() - return Contacts(result, number_fragments=contacts.number_fragments, label_sorted=True) + return Contacts( + result, number_fragments=contacts.number_fragments, label_sorted=True + ) - def _sort_chromosomes(self, df:DataFrame, number_fragments:int) -> DataFrame: + def _sort_chromosomes(self, df: DataFrame, number_fragments: int) -> DataFrame: """Sorts chromosomes in ascending, alphabetical order""" # iterate over all permutations of chromosomes that exist subsets = [] if isinstance(df, dd.DataFrame): - chromosome_conbinations = df[[f"chrom_{i}" for i in range(1, number_fragments + 1)]].drop_duplicates().compute().values.tolist() + chromosome_conbinations = ( + df[[f"chrom_{i}" for i in range(1, number_fragments + 1)]] + .drop_duplicates() + .compute() + .values.tolist() + ) else: - chromosome_conbinations = df[[f"chrom_{i}" for i in range(1, number_fragments + 1)]].drop_duplicates().values.tolist() + chromosome_conbinations = ( + df[[f"chrom_{i}" for i in range(1, number_fragments + 1)]] + .drop_duplicates() + .values.tolist() + ) for perm in chromosome_conbinations: query = " and ".join([f"chrom_{i+1} == '{j}'" for i, j in enumerate(perm)]) desired_order = [i + 1 for i in np.argsort(perm, kind="stable")] - sorted_frame = df.query(query).rename(columns=self._generate_rename_columns(desired_order)) + sorted_frame = df.query(query).rename( + columns=self._generate_rename_columns(desired_order) + ) # ensure correct column order subsets.append(sorted_frame) # determine which method to use for concatenation if isinstance(df, dd.DataFrame): # this is a bit of a hack to get the index sorted. Dask does not support index sorting - result = dd.concat(subsets).reset_index()\ - .sort_values("index")\ - .set_index("index") + result = ( + dd.concat(subsets).reset_index().sort_values("index").set_index("index") + ) else: result = pd.concat(subsets).sort_index() - return result - + return result - def _generate_binary_label_mapping(self, label_values:List[str], number_fragments: int) -> Dict[str, str]: + def _generate_binary_label_mapping( + self, label_values: List[str], number_fragments: int + ) -> Dict[str, str]: sorted_labels = sorted(label_values) mapping = {} for i in range(number_fragments + 1): - target = [sorted_labels[0]]*(number_fragments - i) + [sorted_labels[-1]]*(i) - source = [sorted_labels[0]]*(i) + [sorted_labels[-1]]*(number_fragments - i) + target = [sorted_labels[0]] * (number_fragments - i) + [ + sorted_labels[-1] + ] * (i) + source = [sorted_labels[0]] * (i) + [sorted_labels[-1]] * ( + number_fragments - i + ) if i <= (number_fragments // 2): mapping[tuple(source)] = tuple(target) else: mapping[tuple(source)] = () return mapping - def equate_binary_labels(self, contacts:Contacts) -> Contacts: + def equate_binary_labels(self, contacts: Contacts) -> Contacts: """Binary labels often only carry information about whether they happen between the same or different fragments. This method equates these labels be replacing all equivalent binary labels with @@ -276,12 +340,18 @@ def equate_binary_labels(self, contacts:Contacts) -> Contacts: contacts = self.sort_labels(contacts) # get label values label_values = contacts.get_label_values() - assert len(label_values) == 2, "Equate binary labels only works for binary labels!" + assert ( + len(label_values) == 2 + ), "Equate binary labels only works for binary labels!" # generate mapping diectionary - mapping = self._generate_binary_label_mapping(label_values, contacts.number_fragments) + mapping = self._generate_binary_label_mapping( + label_values, contacts.number_fragments + ) subsets = [] for source, target in mapping.items(): - query = " and ".join([f"metadata_{i+1} == '{j}'" for i, j in enumerate(source)]) + query = " and ".join( + [f"metadata_{i+1} == '{j}'" for i, j in enumerate(source)] + ) subset = contacts.data.query(query) # assign target labels to dataframe for i, j in enumerate(target): @@ -290,36 +360,51 @@ def equate_binary_labels(self, contacts:Contacts) -> Contacts: # determine which method to use for concatenation if contacts.is_dask: # this is a bit of a hack to get the index sorted. Dask does not support index sorting - result = dd.concat(subsets).reset_index()\ - .sort_values("index")\ - .set_index("index") + result = ( + dd.concat(subsets).reset_index().sort_values("index").set_index("index") + ) else: result = pd.concat(subsets).sort_index() - return Contacts(result, number_fragments=contacts.number_fragments, label_sorted=True, - binary_labels_equal=True) - + return Contacts( + result, + number_fragments=contacts.number_fragments, + label_sorted=True, + binary_labels_equal=True, + ) - def subset_on_metadata(self, contacts:Contacts, metadata_combi: List[str]) -> Contacts: + def subset_on_metadata( + self, contacts: Contacts, metadata_combi: List[str] + ) -> Contacts: """Subset contacts based on metadata""" # check if metadata is present assert contacts.contains_metadata, "Contacts do not contain metadata!" # check if metadata_combi has the correct length - assert len(metadata_combi) == contacts.number_fragments, "Metadata combination does not match number of fragments!" + assert ( + len(metadata_combi) == contacts.number_fragments + ), "Metadata combination does not match number of fragments!" # get label values label_values = contacts.get_label_values() # check if metadata_combi is compatible with label values - assert all([i in label_values for i in metadata_combi]), "Metadata combination is not compatible with label values!" + assert all( + [i in label_values for i in metadata_combi] + ), "Metadata combination is not compatible with label values!" # subset contacts - query = " and ".join([f"metadata_{i+1} == '{j}'" for i, j in enumerate(metadata_combi)]) + query = " and ".join( + [f"metadata_{i+1} == '{j}'" for i, j in enumerate(metadata_combi)] + ) result = contacts.data.query(query) - return Contacts(result, number_fragments=contacts.number_fragments, - metadata_combi=metadata_combi, - label_sorted=contacts.label_sorted, - binary_labels_equal=contacts.binary_labels_equal, - symmetry_flipped=contacts.symmetry_flipped) - + return Contacts( + result, + number_fragments=contacts.number_fragments, + metadata_combi=metadata_combi, + label_sorted=contacts.label_sorted, + binary_labels_equal=contacts.binary_labels_equal, + symmetry_flipped=contacts.symmetry_flipped, + ) - def flip_symmetric_contacts(self, contacts: Contacts, sort_chromosomes: bool = False) -> Contacts: + def flip_symmetric_contacts( + self, contacts: Contacts, sort_chromosomes: bool = False + ) -> Contacts: """Flips contacts based on inherent symmetry""" if contacts.contains_metadata: if not contacts.label_sorted: @@ -328,12 +413,19 @@ def flip_symmetric_contacts(self, contacts: Contacts, sort_chromosomes: bool = F result = self._flip_labelled_contacts(contacts.data, label_values) if sort_chromosomes: result = self._sort_chromosomes(result, contacts.number_fragments) - return Contacts(result, number_fragments=contacts.number_fragments, label_sorted=True, - binary_labels_equal=contacts.binary_labels_equal, - symmetry_flipped=True - ) + return Contacts( + result, + number_fragments=contacts.number_fragments, + label_sorted=True, + binary_labels_equal=contacts.binary_labels_equal, + symmetry_flipped=True, + ) else: result = self._flip_unlabelled_contacts(contacts.data) if sort_chromosomes: result = self._sort_chromosomes(result, contacts.number_fragments) - return Contacts(result, number_fragments=contacts.number_fragments, symmetry_flipped=True) \ No newline at end of file + return Contacts( + result, + number_fragments=contacts.number_fragments, + symmetry_flipped=True, + ) diff --git a/spoc/dataframe_models.py b/spoc/dataframe_models.py index ec04d24..e9d63e0 100644 --- a/spoc/dataframe_models.py +++ b/spoc/dataframe_models.py @@ -50,13 +50,12 @@ class ContactSchema: "mapping_quality": pa.Column(int), "align_score": pa.Column(int), "align_base_qscore": pa.Column(int), - "metadata": pa.Column( - str, - required=False - ), + "metadata": pa.Column(str, required=False), } - def __init__(self, number_fragments: int = 3, contains_metadata: bool = True) -> None: + def __init__( + self, number_fragments: int = 3, contains_metadata: bool = True + ) -> None: self._number_fragments = number_fragments self._schema = pa.DataFrameSchema( dict( @@ -98,9 +97,7 @@ def validate_header(self, data_frame: DataFrame) -> None: self._schema, data_frame, "Header is invalid!" ) - def validate( - self, data_frame: DataFrame - ) -> DataFrame: + def validate(self, data_frame: DataFrame) -> DataFrame: """Validate multiway contact dataframe""" self.validate_header(data_frame) return self._schema.validate(data_frame) @@ -126,23 +123,20 @@ def _get_contact_fields(self): "start": pa.Column(int), } else: - return { - "chrom": pa.Column(str), - "start": pa.Column(int) - } + return {"chrom": pa.Column(str), "start": pa.Column(int)} def _get_constant_fields(self): if self._same_chromosome: - return { - "chrom": pa.Column(str), - "count": pa.Column(int), - "corrected_count": pa.Column(float, required=False), - } + return { + "chrom": pa.Column(str), + "count": pa.Column(int), + "corrected_count": pa.Column(float, required=False), + } else: - return { - "count": pa.Column(int), - "corrected_count": pa.Column(float, required=False), - } + return { + "count": pa.Column(int), + "corrected_count": pa.Column(float, required=False), + } def _expand_contact_fields(self, expansions: Iterable = (1, 2, 3)) -> dict: """adds suffixes to fields""" @@ -161,8 +155,6 @@ def validate_header(self, data_frame: DataFrame) -> None: self._schema, data_frame, "Header is invalid!" ) - def validate( - self, data_frame: DataFrame - ) -> DataFrame: + def validate(self, data_frame: DataFrame) -> DataFrame: """Validate multiway contact dataframe""" return self._schema.validate(data_frame) diff --git a/spoc/file_parameter_models.py b/spoc/file_parameter_models.py index 08a6dd9..28f47bc 100644 --- a/spoc/file_parameter_models.py +++ b/spoc/file_parameter_models.py @@ -3,8 +3,10 @@ from pydantic import BaseModel from typing import Optional, List + class ContactsParameters(BaseModel): """Parameters for multiway contacts""" + number_fragments: Optional[int] = None metadata_combi: Optional[List[str]] = None label_sorted: bool = False @@ -14,10 +16,11 @@ class ContactsParameters(BaseModel): class PixelParameters(BaseModel): """Parameters for genomic pixels""" + number_fragments: Optional[int] = None binsize: Optional[int] = None metadata_combi: Optional[List[str]] = None label_sorted: bool = False binary_labels_equal: bool = False symmetry_flipped: bool = False - same_chromosome: bool = True \ No newline at end of file + same_chromosome: bool = True diff --git a/spoc/fragments.py b/spoc/fragments.py index b609401..d71e2a3 100644 --- a/spoc/fragments.py +++ b/spoc/fragments.py @@ -16,7 +16,9 @@ class Fragments: def __init__(self, fragment_frame: DataFrame) -> None: self._data = FragmentSchema.validate(fragment_frame) - self._contains_metadata = True if "metadata" in fragment_frame.columns else False + self._contains_metadata = ( + True if "metadata" in fragment_frame.columns else False + ) @property def data(self): @@ -25,7 +27,7 @@ def data(self): @property def contains_metadata(self): return self._contains_metadata - + @property def is_dask(self): return isinstance(self._data, dd.DataFrame) @@ -78,7 +80,6 @@ def annotate_fragments(self, fragments: Fragments) -> Fragments: ) - class FragmentExpander: """Expands n-way fragments over sequencing reads to yield contacts.""" @@ -89,7 +90,7 @@ def __init__(self, number_fragments: int, contains_metadata: bool = True) -> Non self._schema = ContactSchema(number_fragments, contains_metadata) @staticmethod - def _add_suffix(row, suffix:int, contains_metadata:bool) -> Dict: + def _add_suffix(row, suffix: int, contains_metadata: bool) -> Dict: """expands contact fields""" output = {} for key in ContactSchema.get_contact_fields(contains_metadata): @@ -98,9 +99,13 @@ def _add_suffix(row, suffix:int, contains_metadata:bool) -> Dict: def _get_expansion_output_structure(self) -> pd.DataFrame: """returns expansion output dataframe structure for dask""" - return pd.DataFrame(columns=list(self._schema._schema.columns.keys()) + ['level_2']).set_index(["read_name", 'read_length', 'level_2']) + return pd.DataFrame( + columns=list(self._schema._schema.columns.keys()) + ["level_2"] + ).set_index(["read_name", "read_length", "level_2"]) - def _expand_single_read(self, read_df: pd.DataFrame, contains_metadata:bool) -> pd.DataFrame: + def _expand_single_read( + self, read_df: pd.DataFrame, contains_metadata: bool + ) -> pd.DataFrame: """Expands a single read""" if len(read_df) < self._number_fragments: return pd.DataFrame() @@ -116,9 +121,7 @@ def _expand_single_read(self, read_df: pd.DataFrame, contains_metadata:bool) -> contact = {} # add reads for index, align in enumerate(alignments, start=1): - contact.update( - self._add_suffix(align, index, contains_metadata) - ) + contact.update(self._add_suffix(align, index, contains_metadata)) result.append(contact) return pd.DataFrame(result) @@ -130,15 +133,15 @@ def expand(self, fragments: Fragments) -> Contacts: else: kwargs = dict() # expand - contact_df = fragments.data\ - .groupby(["read_name", "read_length"])\ - .apply(self._expand_single_read, - contains_metadata=fragments.contains_metadata, - **kwargs)\ - .reset_index()\ - .drop("level_2", axis=1) - #return contact_df - return Contacts( - contact_df, - number_fragments=self._number_fragments + contact_df = ( + fragments.data.groupby(["read_name", "read_length"]) + .apply( + self._expand_single_read, + contains_metadata=fragments.contains_metadata, + **kwargs, + ) + .reset_index() + .drop("level_2", axis=1) ) + # return contact_df + return Contacts(contact_df, number_fragments=self._number_fragments) diff --git a/spoc/io.py b/spoc/io.py index 1f021fd..cd57ebf 100644 --- a/spoc/io.py +++ b/spoc/io.py @@ -21,26 +21,24 @@ class FileManager: """Is responsible for loading and writing files""" - def __init__( - self, use_dask: bool = False - ) -> None: + def __init__(self, use_dask: bool = False) -> None: if use_dask: self._parquet_reader_func = dd.read_parquet else: self._parquet_reader_func = pd.read_parquet - def _write_parquet_dask(self, path: str, df: dd.DataFrame, global_parameters: BaseModel) -> None: + def _write_parquet_dask( + self, path: str, df: dd.DataFrame, global_parameters: BaseModel + ) -> None: """Write parquet file using dask""" custom_meta_data = { - 'spoc'.encode(): json.dumps(global_parameters.dict()).encode() + "spoc".encode(): json.dumps(global_parameters.dict()).encode() } - dd.to_parquet( - df, - path, - custom_metadata=custom_meta_data - ) - - def _write_parquet_pandas(self, path: str, df: pd.DataFrame ,global_parameters: BaseModel) -> None: + dd.to_parquet(df, path, custom_metadata=custom_meta_data) + + def _write_parquet_pandas( + self, path: str, df: pd.DataFrame, global_parameters: BaseModel + ) -> None: """Write parquet file using pandas. Pyarrow is needed because the pandas .to_parquet method does not support writing custom metadata.""" table = pa.Table.from_pandas(df) @@ -48,8 +46,8 @@ def _write_parquet_pandas(self, path: str, df: pd.DataFrame ,global_parameters: custom_meta_key = "spoc" existing_meta = table.schema.metadata combined_meta = { - custom_meta_key.encode() : json.dumps(global_parameters.dict()).encode(), - **existing_meta + custom_meta_key.encode(): json.dumps(global_parameters.dict()).encode(), + **existing_meta, } table = table.replace_schema_metadata(combined_meta) # write table @@ -60,7 +58,7 @@ def _load_parquet_global_parameters(path: str) -> BaseModel: """Load global parameters from parquet file""" # check if path is a directory, if so, we need to read the schema from one of the partitioned files if os.path.isdir(path): - path = path + '/' + os.listdir(path)[0] + path = path + "/" + os.listdir(path)[0] global_parameters = pa.parquet.read_schema(path).metadata.get("spoc".encode()) if global_parameters is not None: global_parameters = json.loads(global_parameters.decode()) @@ -91,25 +89,28 @@ def load_fragments(self, path: str): def write_fragments(path: str, fragments: Fragments) -> None: """Write annotated fragments""" # Write fragments - fragments.data.to_parquet(path, row_group_size=1024*1024) + fragments.data.to_parquet(path, row_group_size=1024 * 1024) def write_multiway_contacts(self, path: str, contacts: Contacts) -> None: """Write multiway contacts""" if contacts.is_dask: - self._write_parquet_dask(path, contacts.data, contacts.get_global_parameters()) + self._write_parquet_dask( + path, contacts.data, contacts.get_global_parameters() + ) else: - self._write_parquet_pandas(path, contacts.data, contacts.get_global_parameters()) - + self._write_parquet_pandas( + path, contacts.data, contacts.get_global_parameters() + ) - def load_contacts(self, path: str, global_parameters: Optional[ContactsParameters] = None) -> Contacts: + def load_contacts( + self, path: str, global_parameters: Optional[ContactsParameters] = None + ) -> Contacts: """Load multiway contacts""" if global_parameters is None: global_parameters = self._load_parquet_global_parameters(path) else: global_parameters = global_parameters.dict() - return Contacts( - self._parquet_reader_func(path), **global_parameters - ) + return Contacts(self._parquet_reader_func(path), **global_parameters) @staticmethod def load_chromosome_sizes(path: str): @@ -134,23 +135,22 @@ def _load_pixel_metadata(path: str): else: raise ValueError(f"Metadata file not found at {metadata_path}") return metadata - + @staticmethod def list_pixels(path: str): """List available pixels""" # read metadata.json metadata = FileManager._load_pixel_metadata(path) # instantiate pixel parameters - pixels = [ - PixelParameters(**params) for params in metadata.values() - ] + pixels = [PixelParameters(**params) for params in metadata.values()] return pixels - - def load_pixels(self, path: str, global_parameters: PixelParameters, load_dataframe:bool = True) -> Pixels: + def load_pixels( + self, path: str, global_parameters: PixelParameters, load_dataframe: bool = True + ) -> Pixels: """Loads specific pixels instance based on global parameters. load_dataframe specifies whether the dataframe should be loaded, or whether pixels - should be instantiated based on the path alone. """ + should be instantiated based on the path alone.""" metadata = self._load_pixel_metadata(path) # find matching pixels for pixel_path, value in metadata.items(): @@ -187,9 +187,11 @@ def write_pixels(self, path: str, pixels: Pixels) -> None: write_path = Path(path) / self._get_pixel_hash_path(path, pixels) # write pixels if pixels.data is None: - raise ValueError("Writing pixels only suppported for pixels hodling dataframes!") - pixels.data.to_parquet(write_path, row_group_size=1024*1024) + raise ValueError( + "Writing pixels only suppported for pixels hodling dataframes!" + ) + pixels.data.to_parquet(write_path, row_group_size=1024 * 1024) # write metadata current_metadata[write_path.name] = pixels.get_global_parameters().dict() with open(metadata_path, "w") as f: - json.dump(current_metadata, f) \ No newline at end of file + json.dump(current_metadata, f) diff --git a/spoc/pixels.py b/spoc/pixels.py index e58b6e7..9604e52 100644 --- a/spoc/pixels.py +++ b/spoc/pixels.py @@ -19,7 +19,7 @@ class Pixels: - Order - Metadata combination (Whether the pixels represent a certain combination of metadata) - Whether binary labels are equal (e.g. whether AB pixles also represent BA pixels) - + Pixels can contain different data sources such as: - pandas dataframe - dask dataframe @@ -34,13 +34,15 @@ def __init__( metadata_combi: Optional[List[str]] = None, label_sorted: bool = False, binary_labels_equal: bool = False, - symmetry_flipped:bool = False, - same_chromosome: bool = True + symmetry_flipped: bool = False, + same_chromosome: bool = True, ): """Constructor for genomic pixels. pixel_source can be a pandas or dask dataframe or a path. Caveate is that if pixels are a path, source data is not validated.""" - self._schema = PixelSchema(number_fragments=number_fragments, same_chromosome=same_chromosome) + self._schema = PixelSchema( + number_fragments=number_fragments, same_chromosome=same_chromosome + ) self._same_chromosome = same_chromosome self._number_fragments = number_fragments self._binsize = binsize @@ -73,8 +75,17 @@ def from_uri(uri, mode="path"): """ # import here to avoid circular imports from spoc.io import FileManager + # Define uir parameters - PARAMETERS = ['number_fragments', 'binsize', 'metadata_combi', 'binary_labels_equal', 'symmetry_flipped', 'label_sorted', 'same_chromosome'] + PARAMETERS = [ + "number_fragments", + "binsize", + "metadata_combi", + "binary_labels_equal", + "symmetry_flipped", + "label_sorted", + "same_chromosome", + ] # parse uri uri = uri.split("::") # validate uri @@ -82,12 +93,10 @@ def from_uri(uri, mode="path"): raise ValueError( f"Uri: {uri} is not valid. Must contain at least Path, number_fragments and binsize" ) - params = { - key:value for key, value in zip(PARAMETERS, uri[1:]) - } + params = {key: value for key, value in zip(PARAMETERS, uri[1:])} # rewrite metadata_combi parameter - if 'metadata_combi' in params.keys() and params['metadata_combi'] != 'None': - params['metadata_combi'] = str(list(params['metadata_combi'])) + if "metadata_combi" in params.keys() and params["metadata_combi"] != "None": + params["metadata_combi"] = str(list(params["metadata_combi"])) # read mode if mode == "path": load_dataframe = False @@ -102,21 +111,18 @@ def from_uri(uri, mode="path"): available_pixels = FileManager().list_pixels(uri[0]) # filter pixels matched_pixels = [ - pixel for pixel in available_pixels - if all( params[key] == str(pixel.dict()[key]) for key in params.keys()) + pixel + for pixel in available_pixels + if all(params[key] == str(pixel.dict()[key]) for key in params.keys()) ] # check whether there is a unique match if len(matched_pixels) == 0: - raise ValueError( - f"No pixels found for uri: {uri}" - ) + raise ValueError(f"No pixels found for uri: {uri}") elif len(matched_pixels) > 1: - raise ValueError( - f"Multiple pixels found for uri: {uri}" - ) - return FileManager(use_dask=use_dask).load_pixels(uri[0], matched_pixels[0], load_dataframe=load_dataframe) - - + raise ValueError(f"Multiple pixels found for uri: {uri}") + return FileManager(use_dask=use_dask).load_pixels( + uri[0], matched_pixels[0], load_dataframe=load_dataframe + ) def get_global_parameters(self): """Returns global parameters of pixels""" @@ -127,7 +133,7 @@ def get_global_parameters(self): label_sorted=self._label_sorted, binary_labels_equal=self._binary_labels_equal, symmetry_flipped=self._symmetry_flipped, - same_chromosome=self._same_chromosome + same_chromosome=self._same_chromosome, ) @property @@ -149,7 +155,7 @@ def binsize(self): @property def binary_labels_equal(self): return self._binary_labels_equal - + @property def symmetry_flipped(self): return self._symmetry_flipped @@ -161,7 +167,6 @@ def metadata_combi(self): @property def same_chromosome(self): return self._same_chromosome - class GenomicBinner: @@ -169,22 +174,14 @@ class GenomicBinner: Is capable of sorting genomic bins along columns based on sister chromatid identity""" - def __init__( - self, - bin_size: int - ) -> None: + def __init__(self, bin_size: int) -> None: self._bin_size = bin_size self._contact_order = None def _get_assigned_bin_output_structure(self): - columns = ( - [ - f'chrom_{index}' for index in range(1 , self._contact_order + 1) - ] + - [ - f'start_{index}' for index in range(1 , self._contact_order + 1) - ] - ) + columns = [f"chrom_{index}" for index in range(1, self._contact_order + 1)] + [ + f"start_{index}" for index in range(1, self._contact_order + 1) + ] return pd.DataFrame(columns=columns).astype(int) def _assign_bins(self, data_frame: pd.DataFrame) -> pd.DataFrame: @@ -193,46 +190,47 @@ def _assign_bins(self, data_frame: pd.DataFrame) -> pd.DataFrame: return self._get_assigned_bin_output_structure() return data_frame.assign( **{ - f"start_{index}": (data_frame[f"pos_{index}"] // self._bin_size) * self._bin_size - for index in range(1, self._contact_order + 1) + f"start_{index}": (data_frame[f"pos_{index}"] // self._bin_size) + * self._bin_size + for index in range(1, self._contact_order + 1) } - ).filter(regex='(chrom|start)') + ).filter(regex="(chrom|start)") def _assign_midpoints(self, contacts: dd.DataFrame) -> dd.DataFrame: """Collapses start-end to a middle position""" - return ( - contacts - .assign( - **{ - f"pos_{index}": (contacts[f'start_{index}'] + contacts[f'end_{index}'])//2 - for index in range(1, self._contact_order + 1) - } - ) - .drop( - [ - c for index in range(1, self._contact_order + 1) for c in [f'start_{index}', f'end_{index}'] - ], axis=1 - ) + return contacts.assign( + **{ + f"pos_{index}": (contacts[f"start_{index}"] + contacts[f"end_{index}"]) + // 2 + for index in range(1, self._contact_order + 1) + } + ).drop( + [ + c + for index in range(1, self._contact_order + 1) + for c in [f"start_{index}", f"end_{index}"] + ], + axis=1, ) - def bin_contacts(self, contacts: Contacts, same_chromosome: bool =True) -> dd.DataFrame: + def bin_contacts( + self, contacts: Contacts, same_chromosome: bool = True + ) -> dd.DataFrame: """Bins genomic contacts""" self._contact_order = contacts.number_fragments - contacts_w_midpoints = self._assign_midpoints( - contacts.data - ) + contacts_w_midpoints = self._assign_midpoints(contacts.data) if contacts.is_dask: contact_bins = contacts_w_midpoints.map_partitions( - self._assign_bins, - meta=self._get_assigned_bin_output_structure() + self._assign_bins, meta=self._get_assigned_bin_output_structure() ) else: contact_bins = self._assign_bins(contacts_w_midpoints) pixels = ( contact_bins.groupby( [ - c for index in range(1, self._contact_order + 1) - for c in [f'chrom_{index}', f'start_{index}'] + c + for index in range(1, self._contact_order + 1) + for c in [f"chrom_{index}", f"start_{index}"] ], observed=True, ) @@ -247,17 +245,21 @@ def bin_contacts(self, contacts: Contacts, same_chromosome: bool =True) -> dd.Da (pixels.chrom_1.astype(str) == pixels.chrom_2.astype(str)) & (pixels.chrom_2.astype(str) == pixels.chrom_3.astype(str)) ] - .drop([f'chrom_{index}' for index in range(2, self._contact_order + 1)], axis=1) + .drop( + [f"chrom_{index}" for index in range(2, self._contact_order + 1)], + axis=1, + ) .rename(columns={"chrom_1": "chrom"}) ) # sort pixels pixels_sorted = pixels.sort_values( - ['chrom'] + [f'start_{index}' for index in range(1, self._contact_order + 1)] + ["chrom"] + + [f"start_{index}" for index in range(1, self._contact_order + 1)] ).reset_index(drop=True) else: pixels_sorted = pixels.sort_values( - [f'chrom_{index}' for index in range(1, self._contact_order + 1)] + - [f'start_{index}' for index in range(1, self._contact_order + 1)] + [f"chrom_{index}" for index in range(1, self._contact_order + 1)] + + [f"start_{index}" for index in range(1, self._contact_order + 1)] ).reset_index(drop=True) # construct pixels and return return Pixels( @@ -267,7 +269,7 @@ def bin_contacts(self, contacts: Contacts, same_chromosome: bool =True) -> dd.Da binsize=self._bin_size, binary_labels_equal=contacts.binary_labels_equal, symmetry_flipped=contacts.symmetry_flipped, - metadata_combi=contacts.metadata_combi + metadata_combi=contacts.metadata_combi, ) diff --git a/tests/fixtures/symmetry.py b/tests/fixtures/symmetry.py index c9946f8..8a0619d 100644 --- a/tests/fixtures/symmetry.py +++ b/tests/fixtures/symmetry.py @@ -3,6 +3,7 @@ import pandas as pd import dask.dataframe as dd + @pytest.fixture def unlabelled_contacts_2d(): return pd.DataFrame( @@ -20,14 +21,16 @@ def unlabelled_contacts_2d(): "end_2": [2000, 300, 400], "mapping_quality_2": [10, 10, 10], "align_score_2": [10, 10, 10], - "align_base_qscore_2": [10, 10, 10] + "align_base_qscore_2": [10, 10, 10], } ) + @pytest.fixture def unlabelled_contacts_2d_dask(unlabelled_contacts_2d): return dd.from_pandas(unlabelled_contacts_2d, npartitions=2) + @pytest.fixture def unlabelled_contacts_3d(): return pd.DataFrame( @@ -51,14 +54,16 @@ def unlabelled_contacts_3d(): "end_3": [300, 500, 200], "mapping_quality_3": [10, 10, 5], "align_score_3": [10, 10, 5], - "align_base_qscore_3": [10, 10, 5] + "align_base_qscore_3": [10, 10, 5], } ) + @pytest.fixture def unlabelled_contacts_3d_dask(unlabelled_contacts_3d): return dd.from_pandas(unlabelled_contacts_3d, npartitions=2) + @pytest.fixture def unlabelled_contacts_2d_flipped(): return pd.DataFrame( @@ -76,7 +81,7 @@ def unlabelled_contacts_2d_flipped(): "end_2": [2000, 3000, 4000], "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], - "align_base_qscore_2": [10, 10, 15] + "align_base_qscore_2": [10, 10, 15], } ) @@ -104,10 +109,11 @@ def unlabelled_contacts_3d_flipped(): "end_3": [2000, 3000, 4000], "mapping_quality_3": [10, 10, 15], "align_score_3": [10, 10, 15], - "align_base_qscore_3": [10, 10, 15] + "align_base_qscore_3": [10, 10, 15], } ) + @pytest.fixture def labelled_binary_contacts_2d(): return pd.DataFrame( @@ -127,10 +133,11 @@ def labelled_binary_contacts_2d(): "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["B", "A", "A"] + "metadata_2": ["B", "A", "A"], } ) + @pytest.fixture def labelled_binary_contacts_2d_sorted(): return pd.DataFrame( @@ -150,7 +157,7 @@ def labelled_binary_contacts_2d_sorted(): "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["B", "B", "A"] + "metadata_2": ["B", "B", "A"], } ) @@ -181,10 +188,11 @@ def labelled_binary_contacts_3d(): "mapping_quality_3": [10, 10, 15], "align_score_3": [10, 10, 15], "align_base_qscore_3": [10, 10, 15], - "metadata_3": ["B", "A", "A"] + "metadata_3": ["B", "A", "A"], } ) + @pytest.fixture def labelled_binary_contacts_3d_sorted(): return pd.DataFrame( @@ -211,32 +219,34 @@ def labelled_binary_contacts_3d_sorted(): "mapping_quality_3": [10, 10, 15], "align_score_3": [10, 10, 15], "align_base_qscore_3": [10, 10, 15], - "metadata_3": ["B", "B", "A"] + "metadata_3": ["B", "B", "A"], } ) + @pytest.fixture def binary_contacts_not_equated_2d(): return pd.DataFrame( - { - "read_name": ["read1", "read2", "read3"], - "read_length": [100, 100, 100], - "chrom_1": ["chr1", "chr1", "chr1"], - "start_1": [100, 2000, 300], - "end_1": [200, 3000, 400], - "mapping_quality_1": [10, 10, 10], - "align_score_1": [10, 10, 10], - "align_base_qscore_1": [10, 10, 10], - "metadata_1": ["B", "A", "A"], - "chrom_2": ["chr1", "chr1", "chr1"], - "start_2": [1000, 200, 3000], - "end_2": [2000, 300, 4000], - "mapping_quality_2": [10, 10, 15], - "align_score_2": [10, 10, 15], - "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["B", "B", "A"] - } - ) + { + "read_name": ["read1", "read2", "read3"], + "read_length": [100, 100, 100], + "chrom_1": ["chr1", "chr1", "chr1"], + "start_1": [100, 2000, 300], + "end_1": [200, 3000, 400], + "mapping_quality_1": [10, 10, 10], + "align_score_1": [10, 10, 10], + "align_base_qscore_1": [10, 10, 10], + "metadata_1": ["B", "A", "A"], + "chrom_2": ["chr1", "chr1", "chr1"], + "start_2": [1000, 200, 3000], + "end_2": [2000, 300, 4000], + "mapping_quality_2": [10, 10, 15], + "align_score_2": [10, 10, 15], + "align_base_qscore_2": [10, 10, 15], + "metadata_2": ["B", "B", "A"], + } + ) + @pytest.fixture def binary_contacts_not_equated_3d(): @@ -264,10 +274,11 @@ def binary_contacts_not_equated_3d(): "mapping_quality_3": [10, 10, 15], "align_score_3": [10, 10, 15], "align_base_qscore_3": [10, 10, 15], - "metadata_3": ["B", "B", "B"] + "metadata_3": ["B", "B", "B"], } ) + @pytest.fixture def binary_contacts_not_equated_4d(): return pd.DataFrame( @@ -301,10 +312,11 @@ def binary_contacts_not_equated_4d(): "mapping_quality_4": [10, 10, 15], "align_score_4": [10, 10, 15], "align_base_qscore_4": [10, 10, 15], - "metadata_4": ["B", "B", "B"] + "metadata_4": ["B", "B", "B"], } ) + @pytest.fixture def binary_contacts_equated_2d(): return pd.DataFrame( @@ -324,10 +336,11 @@ def binary_contacts_equated_2d(): "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["A", "B", "A"] + "metadata_2": ["A", "B", "A"], } ) + @pytest.fixture def binary_contacts_equated_3d(): return pd.DataFrame( @@ -354,10 +367,11 @@ def binary_contacts_equated_3d(): "mapping_quality_3": [10, 10, 15], "align_score_3": [10, 10, 15], "align_base_qscore_3": [10, 10, 15], - "metadata_3": ["B", "B", "A"] + "metadata_3": ["B", "B", "A"], } ) + @pytest.fixture def binary_contacts_equated_4d(): return pd.DataFrame( @@ -391,10 +405,11 @@ def binary_contacts_equated_4d(): "mapping_quality_4": [10, 10, 15], "align_score_4": [10, 10, 15], "align_base_qscore_4": [10, 10, 15], - "metadata_4": ["B", "B", "A"] + "metadata_4": ["B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_2d_unflipped(): return pd.DataFrame( @@ -414,10 +429,11 @@ def labelled_binary_contacts_2d_unflipped(): "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["B", "B", "A"] + "metadata_2": ["B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_3d_unflipped(): return pd.DataFrame( @@ -444,10 +460,11 @@ def labelled_binary_contacts_3d_unflipped(): "mapping_quality_3": [10, 10, 15, 14], "align_score_3": [10, 10, 15, 14], "align_base_qscore_3": [10, 10, 15, 14], - "metadata_3": ["B", "B", "B", "A"] + "metadata_3": ["B", "B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_3d_unflipped_example2(): return pd.DataFrame( @@ -474,13 +491,11 @@ def labelled_binary_contacts_3d_unflipped_example2(): "mapping_quality_3": [10, 10, 14], "align_score_3": [10, 10, 14], "align_base_qscore_3": [10, 10, 14], - "metadata_3": ["B", "B", "A"] + "metadata_3": ["B", "B", "A"], } ) - - @pytest.fixture def labelled_binary_contacts_2d_flipped(): return pd.DataFrame( @@ -500,10 +515,11 @@ def labelled_binary_contacts_2d_flipped(): "mapping_quality_2": [10, 10, 10], "align_score_2": [10, 10, 10], "align_base_qscore_2": [10, 10, 10], - "metadata_2": ["B", "B", "A"] + "metadata_2": ["B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_3d_flipped(): return pd.DataFrame( @@ -530,10 +546,11 @@ def labelled_binary_contacts_3d_flipped(): "mapping_quality_3": [10, 10, 15, 20], "align_score_3": [10, 10, 15, 20], "align_base_qscore_3": [10, 10, 15, 20], - "metadata_3": ["B", "B", "B", "A"] + "metadata_3": ["B", "B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_3d_flipped_example2(): return pd.DataFrame( @@ -560,10 +577,11 @@ def labelled_binary_contacts_3d_flipped_example2(): "mapping_quality_3": [10, 10, 14], "align_score_3": [10, 10, 14], "align_base_qscore_3": [10, 10, 14], - "metadata_3": ["B", "B", "A"] + "metadata_3": ["B", "B", "A"], } ) + @pytest.fixture def unlabelled_contacts_diff_chrom_2d(): return pd.DataFrame( @@ -609,7 +627,7 @@ def unlabelled_contacts_diff_chrom_3d(): "end_3": [300, 500, 600], "mapping_quality_3": [10, 10, 5], "align_score_3": [10, 10, 5], - "align_base_qscore_3": [10, 10, 5] + "align_base_qscore_3": [10, 10, 5], } ) @@ -643,10 +661,11 @@ def unlabelled_contacts_diff_chrom_4d(): "end_4": [350, 300, 600], "mapping_quality_4": [10, 10, 10], "align_score_4": [10, 10, 10], - "align_base_qscore_4": [10, 10, 10] + "align_base_qscore_4": [10, 10, 10], } ) + @pytest.fixture def unlabelled_contacts_diff_chrom_4d_flipped(): return pd.DataFrame( @@ -676,12 +695,11 @@ def unlabelled_contacts_diff_chrom_4d_flipped(): "end_4": [4000, 500, 600], "mapping_quality_4": [10, 10, 10], "align_score_4": [10, 10, 10], - "align_base_qscore_4": [10, 10, 10] + "align_base_qscore_4": [10, 10, 10], } ) - @pytest.fixture def unlabelled_contacts_diff_chrom_3d_flipped(): return pd.DataFrame( @@ -705,10 +723,11 @@ def unlabelled_contacts_diff_chrom_3d_flipped(): "end_3": [2000, 3000, 200], "mapping_quality_3": [10, 10, 15], "align_score_3": [10, 10, 15], - "align_base_qscore_3": [10, 10, 15] + "align_base_qscore_3": [10, 10, 15], } ) + @pytest.fixture def unlabelled_contacts_diff_chrom_2d_flipped(): return pd.DataFrame( @@ -730,6 +749,7 @@ def unlabelled_contacts_diff_chrom_2d_flipped(): } ) + @pytest.fixture def labelled_binary_contacts_diff_chrom_2d(): return pd.DataFrame( @@ -749,10 +769,11 @@ def labelled_binary_contacts_diff_chrom_2d(): "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["B", "B", "A"] + "metadata_2": ["B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_diff_chrom_2d_flipped(): return pd.DataFrame( @@ -772,10 +793,11 @@ def labelled_binary_contacts_diff_chrom_2d_flipped(): "mapping_quality_2": [10, 10, 15], "align_score_2": [10, 10, 15], "align_base_qscore_2": [10, 10, 15], - "metadata_2": ["B", "B", "A"] + "metadata_2": ["B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_diff_chrom_3d(): return pd.DataFrame( @@ -802,10 +824,11 @@ def labelled_binary_contacts_diff_chrom_3d(): "mapping_quality_3": [10, 10, 15, 14], "align_score_3": [10, 10, 15, 14], "align_base_qscore_3": [10, 10, 15, 14], - "metadata_3": ["B", "B", "B", "A"] + "metadata_3": ["B", "B", "B", "A"], } ) + @pytest.fixture def labelled_binary_contacts_diff_chrom_3d_flipped(): return pd.DataFrame( @@ -832,6 +855,6 @@ def labelled_binary_contacts_diff_chrom_3d_flipped(): "mapping_quality_3": [10, 10, 15, 20], "align_score_3": [10, 10, 15, 20], "align_base_qscore_3": [10, 10, 15, 20], - "metadata_3": ["B", "B", "B", "A"] + "metadata_3": ["B", "B", "B", "A"], } - ) \ No newline at end of file + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 1690a78..3d96e31 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -10,6 +10,7 @@ from spoc import cli, dataframe_models + def _create_tmp_dir(): # check if tmp dir exists if not os.path.exists("tmp"): @@ -19,6 +20,7 @@ def _create_tmp_dir(): shutil.rmtree("tmp") os.mkdir("tmp") + @pytest.fixture def good_annotated_porec_file(): # setup @@ -53,6 +55,7 @@ def good_triplet_file_for_pixels(): # teardown shutil.rmtree("tmp") + @pytest.fixture def good_porec_file(): # setup @@ -137,10 +140,7 @@ def test_merge_contacts_works(good_triplet_files): assert_frame_equal(first_half, second_half) - -def test_bin_contacts( - good_triplet_file_for_pixels, expected_pixels -): +def test_bin_contacts(good_triplet_file_for_pixels, expected_pixels): """happy path for binning contacts without sister sorting""" runner = CliRunner() output_path = "tmp/test_output5.parquet" diff --git a/tests/test_contacts.py b/tests/test_contacts.py index 3500a53..d9e0521 100644 --- a/tests/test_contacts.py +++ b/tests/test_contacts.py @@ -7,7 +7,7 @@ from spoc import contacts, dataframe_models, fragments from .fixtures.symmetry import ( unlabelled_contacts_2d, - labelled_binary_contacts_2d_sorted + labelled_binary_contacts_2d_sorted, ) @@ -16,11 +16,13 @@ def triplet_expander(): """expander for triplets""" return fragments.FragmentExpander(number_fragments=3, contains_metadata=False) + @pytest.fixture def triplet_expander_labelled(): """expander for triplets""" return fragments.FragmentExpander(number_fragments=3, contains_metadata=True) + @pytest.fixture def contact_manipulator(): """manipulator for triplest""" @@ -84,6 +86,7 @@ def unlabelled_df(): def labelled_fragments(labelled_df): return fragments.Fragments(labelled_df) + @pytest.fixture def labelled_fragments_dask(labelled_df): return fragments.Fragments(dd.from_pandas(labelled_df, npartitions=1)) @@ -93,16 +96,22 @@ def labelled_fragments_dask(labelled_df): def unlabelled_fragments(unlabelled_df): return fragments.Fragments(unlabelled_df) + @pytest.fixture def unlabelled_fragments_dask(unlabelled_df): return fragments.Fragments(dd.from_pandas(unlabelled_df, npartitions=1)) -@pytest.mark.parametrize("fragments, expander", - [("labelled_fragments", "triplet_expander_labelled"), ("labelled_fragments_dask", "triplet_expander_labelled"), - ("unlabelled_fragments", "triplet_expander"), ("unlabelled_fragments_dask", "triplet_expander")]) -def test_expander_drops_reads_w_too_little_fragments( - expander, fragments, request -): + +@pytest.mark.parametrize( + "fragments, expander", + [ + ("labelled_fragments", "triplet_expander_labelled"), + ("labelled_fragments_dask", "triplet_expander_labelled"), + ("unlabelled_fragments", "triplet_expander"), + ("unlabelled_fragments_dask", "triplet_expander"), + ], +) +def test_expander_drops_reads_w_too_little_fragments(expander, fragments, request): triplet_expander = request.getfixturevalue(expander) result = triplet_expander.expand(request.getfixturevalue(fragments)).data if isinstance(result, dd.DataFrame): @@ -111,16 +120,21 @@ def test_expander_drops_reads_w_too_little_fragments( assert result.read_name[0] == "dummy" -@pytest.mark.parametrize("fragments, expander", - [("labelled_fragments", "triplet_expander_labelled"), ("labelled_fragments_dask", "triplet_expander_labelled"), - ("unlabelled_fragments", "triplet_expander"), ("unlabelled_fragments_dask", "triplet_expander")]) -def test_expander_returns_correct_number_of_contacts( - expander, fragments, request -): +@pytest.mark.parametrize( + "fragments, expander", + [ + ("labelled_fragments", "triplet_expander_labelled"), + ("labelled_fragments_dask", "triplet_expander_labelled"), + ("unlabelled_fragments", "triplet_expander"), + ("unlabelled_fragments_dask", "triplet_expander"), + ], +) +def test_expander_returns_correct_number_of_contacts(expander, fragments, request): triplet_expander = request.getfixturevalue(expander) result = triplet_expander.expand(request.getfixturevalue(fragments)).data assert len(result) == 4 + @pytest.mark.parametrize("fragments", ["labelled_fragments", "labelled_fragments_dask"]) def test_expander_returns_correct_contacts_labelled( triplet_expander_labelled, fragments, request @@ -148,7 +162,10 @@ def test_expander_returns_correct_contacts_labelled( np.array(["SisterA", "SisterB", "SisterB", "SisterB"]), ) -@pytest.mark.parametrize("fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"]) + +@pytest.mark.parametrize( + "fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"] +) def test_expander_returns_correct_contacts_unlabelled( triplet_expander, fragments, request ): @@ -203,33 +220,40 @@ def test_merge_fails_for_pandas_dask_mixed( ) contact_manipulator.merge_contacts([contacts_pandas, contacts_dask]) + def test_subset_metadata_fails_if_not_labelled(unlabelled_contacts_2d): contact_manipulator = contacts.ContactManipulator() unlab_contacts = contacts.Contacts(unlabelled_contacts_2d) with pytest.raises(AssertionError): - contact_manipulator.subset_on_metadata(unlab_contacts, ['A', 'B']) + contact_manipulator.subset_on_metadata(unlab_contacts, ["A", "B"]) -def test_subset_metadata_fails_if_pattern_longer_than_number_fragments(labelled_binary_contacts_2d_sorted): +def test_subset_metadata_fails_if_pattern_longer_than_number_fragments( + labelled_binary_contacts_2d_sorted, +): contact_manipulator = contacts.ContactManipulator() lab_contacts = contacts.Contacts(labelled_binary_contacts_2d_sorted) with pytest.raises(AssertionError): - contact_manipulator.subset_on_metadata(lab_contacts, ['A', 'B', 'A']) + contact_manipulator.subset_on_metadata(lab_contacts, ["A", "B", "A"]) + -def test_subset_metadata_fails_if_pattern_contains_strings_not_in_metadata(labelled_binary_contacts_2d_sorted): +def test_subset_metadata_fails_if_pattern_contains_strings_not_in_metadata( + labelled_binary_contacts_2d_sorted, +): contact_manipulator = contacts.ContactManipulator() lab_contacts = contacts.Contacts(labelled_binary_contacts_2d_sorted) with pytest.raises(AssertionError): - contact_manipulator.subset_on_metadata(lab_contacts, ['A', 'C']) + contact_manipulator.subset_on_metadata(lab_contacts, ["A", "C"]) + def test_subset_metadata_creates_correct_subset(labelled_binary_contacts_2d_sorted): contact_manipulator = contacts.ContactManipulator() lab_contacts = contacts.Contacts(labelled_binary_contacts_2d_sorted) - result = contact_manipulator.subset_on_metadata(lab_contacts, ['A', 'B']) + result = contact_manipulator.subset_on_metadata(lab_contacts, ["A", "B"]) assert len(result.data) == 2 - assert result.data['metadata_1'].unique() == ['A'] - assert result.data['metadata_2'].unique() == ['B'] - assert result.metadata_combi == ['A', 'B'] + assert result.data["metadata_1"].unique() == ["A"] + assert result.data["metadata_2"].unique() == ["B"] + assert result.metadata_combi == ["A", "B"] # TODO: merge rejects labelled and unlabelled contacts diff --git a/tests/test_io.py b/tests/test_io.py index 05ef9a3..1fd87d3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,18 +12,17 @@ from spoc.file_parameter_models import PixelParameters from spoc.pixels import Pixels import dask.dataframe as dd -from .fixtures.symmetry import ( - unlabelled_contacts_2d -) +from .fixtures.symmetry import unlabelled_contacts_2d CONTACT_PARAMETERS = ( [2], - [['A', 'B'], ['B', 'C']], + [["A", "B"], ["B", "C"]], [True, False], [True, False], [True, False], ) + def _create_tmp_dir(): # check if tmp dir exists if not os.path.exists("tmp"): @@ -33,6 +32,7 @@ def _create_tmp_dir(): shutil.rmtree("tmp") os.mkdir("tmp") + @pytest.fixture def df_order_2(): return pd.DataFrame( @@ -44,6 +44,7 @@ def df_order_2(): } ) + @pytest.fixture def df_order_3(): return pd.DataFrame( @@ -56,6 +57,7 @@ def df_order_3(): } ) + @pytest.fixture def example_pixels_w_metadata(df_order_2, df_order_3): # setup @@ -65,20 +67,22 @@ def example_pixels_w_metadata(df_order_2, df_order_3): os.mkdir(pixels_dir) expected_parameters = [ PixelParameters(number_fragments=2, binsize=1000), - PixelParameters(number_fragments=3, binsize=10_000, metadata_combi=['A', 'B', 'B'], - label_sorted=True, binary_labels_equal=True, symmetry_flipped=True), - PixelParameters(number_fragments=2, binsize=100_000) + PixelParameters( + number_fragments=3, + binsize=10_000, + metadata_combi=["A", "B", "B"], + label_sorted=True, + binary_labels_equal=True, + symmetry_flipped=True, + ), + PixelParameters(number_fragments=2, binsize=100_000), ] paths = [ Path("tmp/pixels_test.parquet/test1.parquet"), Path("tmp/pixels_test.parquet/test2.parquet"), Path("tmp/pixels_test.parquet/test3.parquet"), ] - dataframes = [ - df_order_2, - df_order_3, - df_order_2 - ] + dataframes = [df_order_2, df_order_3, df_order_2] # create pixels files for path, df in zip(paths, dataframes): df.to_parquet(path) @@ -88,23 +92,20 @@ def example_pixels_w_metadata(df_order_2, df_order_3): "test2.parquet": expected_parameters[1].dict(), "test3.parquet": expected_parameters[2].dict(), } - with open(pixels_dir + '/metadata.json', 'w') as f: + with open(pixels_dir + "/metadata.json", "w") as f: json.dump(metadata, f) yield pixels_dir, expected_parameters, paths, dataframes # teardown shutil.rmtree("tmp") - - - -@pytest.mark.parametrize('params', - product(*CONTACT_PARAMETERS) - ) -def test_write_read_contacts_global_parameters_w_metadata_pandas(unlabelled_contacts_2d, params): +@pytest.mark.parametrize("params", product(*CONTACT_PARAMETERS)) +def test_write_read_contacts_global_parameters_w_metadata_pandas( + unlabelled_contacts_2d, params +): """Test writing and reading contacts metadata with pandas""" with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + '_'.join([str(x) for x in params]) + '.parquet' + file_name = tmpdirname + "/" + "_".join([str(x) for x in params]) + ".parquet" contacts = Contacts(unlabelled_contacts_2d, *params) FileManager().write_multiway_contacts(file_name, contacts) # read contacts @@ -112,27 +113,31 @@ def test_write_read_contacts_global_parameters_w_metadata_pandas(unlabelled_cont # check whether parameters are equal assert contacts.get_global_parameters() == contacts_read.get_global_parameters() -@pytest.mark.parametrize('params', - product(*CONTACT_PARAMETERS) - ) -def test_write_read_contacts_global_parameters_w_metadata_dask(unlabelled_contacts_2d, params): - """Test writing and reading contacts metadata """ + +@pytest.mark.parametrize("params", product(*CONTACT_PARAMETERS)) +def test_write_read_contacts_global_parameters_w_metadata_dask( + unlabelled_contacts_2d, params +): + """Test writing and reading contacts metadata""" with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + '_'.join([str(x) for x in params]) + '.parquet' - contacts = Contacts(dd.from_pandas(unlabelled_contacts_2d, npartitions=2), *params) + file_name = tmpdirname + "/" + "_".join([str(x) for x in params]) + ".parquet" + contacts = Contacts( + dd.from_pandas(unlabelled_contacts_2d, npartitions=2), *params + ) FileManager().write_multiway_contacts(file_name, contacts) # read contacts contacts_read = FileManager(use_dask=True).load_contacts(file_name) # check whether parameters are equal assert contacts.get_global_parameters() == contacts_read.get_global_parameters() -@pytest.mark.parametrize('params', - product(*CONTACT_PARAMETERS) - ) -def test_write_read_contacts_global_parameters_w_metadata_pandas_to_dask(unlabelled_contacts_2d, params): + +@pytest.mark.parametrize("params", product(*CONTACT_PARAMETERS)) +def test_write_read_contacts_global_parameters_w_metadata_pandas_to_dask( + unlabelled_contacts_2d, params +): """Test writing with pandas and reading with dask""" with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + '_'.join([str(x) for x in params]) + '.parquet' + file_name = tmpdirname + "/" + "_".join([str(x) for x in params]) + ".parquet" contacts = Contacts(unlabelled_contacts_2d, *params) FileManager().write_multiway_contacts(file_name, contacts) # read contacts @@ -140,14 +145,17 @@ def test_write_read_contacts_global_parameters_w_metadata_pandas_to_dask(unlabel # check whether parameters are equal assert contacts.get_global_parameters() == contacts_read.get_global_parameters() -@pytest.mark.parametrize('params', - product(*CONTACT_PARAMETERS) - ) -def test_write_read_contacts_global_parameters_w_metadata_dask_to_pandas(unlabelled_contacts_2d, params): + +@pytest.mark.parametrize("params", product(*CONTACT_PARAMETERS)) +def test_write_read_contacts_global_parameters_w_metadata_dask_to_pandas( + unlabelled_contacts_2d, params +): """Test writing with pandas and reading with dask""" with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + '_'.join([str(x) for x in params]) + '.parquet' - contacts = Contacts(dd.from_pandas(unlabelled_contacts_2d, npartitions=2), *params) + file_name = tmpdirname + "/" + "_".join([str(x) for x in params]) + ".parquet" + contacts = Contacts( + dd.from_pandas(unlabelled_contacts_2d, npartitions=2), *params + ) FileManager().write_multiway_contacts(file_name, contacts) # read contacts contacts_read = FileManager(use_dask=False).load_contacts(file_name) @@ -162,7 +170,11 @@ def test_read_pixels_metadata_json(example_pixels_w_metadata): available_pixels = FileManager().list_pixels(pixels_dir) # check whether parameters are equal assert len(available_pixels) == len(expected_parameters) - assert all(actual == expected for actual, expected in zip(available_pixels, expected_parameters)) + assert all( + actual == expected + for actual, expected in zip(available_pixels, expected_parameters) + ) + def test_read_pixels_metadata_json_fails_gracefully(): """Test reading pixels metadata json file""" @@ -171,6 +183,7 @@ def test_read_pixels_metadata_json_fails_gracefully(): FileManager().list_pixels("bad_path") assert e.value == "Metadata file not found at bad_path/metadata.json" + def test_read_pixels_as_path(example_pixels_w_metadata): """Test reading pixels metadata json file""" pixels_dir, expected_parameters, paths, _ = example_pixels_w_metadata @@ -180,36 +193,49 @@ def test_read_pixels_as_path(example_pixels_w_metadata): assert pixels.path == path assert pixels.get_global_parameters() == expected + def test_read_pixels_as_pandas_df(example_pixels_w_metadata): """Test reading pixels metadata json file""" pixels_dir, expected_parameters, paths, dataframes = example_pixels_w_metadata # read metadata for path, expected, df in zip(paths, expected_parameters, dataframes): - pixels = FileManager(use_dask=False).load_pixels(pixels_dir, expected, load_dataframe=True) + pixels = FileManager(use_dask=False).load_pixels( + pixels_dir, expected, load_dataframe=True + ) assert pixels.get_global_parameters() == expected assert pixels.data.equals(df) + def test_read_pixels_as_dask_df(example_pixels_w_metadata): """Test reading pixels metadata json file""" pixels_dir, expected_parameters, paths, dataframes = example_pixels_w_metadata # read metadata for path, expected, df in zip(paths, expected_parameters, dataframes): - pixels = FileManager(use_dask=True).load_pixels(pixels_dir, expected, load_dataframe=True) + pixels = FileManager(use_dask=True).load_pixels( + pixels_dir, expected, load_dataframe=True + ) assert pixels.get_global_parameters() == expected assert pixels.data.compute().equals(df) -@pytest.mark.parametrize('df, params', - [ - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000)), - ('df_order_3', PixelParameters(number_fragments=3, binsize=10_000, metadata_combi=['A', 'B', 'B'])), - ('df_order_2', PixelParameters(number_fragments=2, binsize=100_000)) - ] - ) + +@pytest.mark.parametrize( + "df, params", + [ + ("df_order_2", PixelParameters(number_fragments=2, binsize=1000)), + ( + "df_order_3", + PixelParameters( + number_fragments=3, binsize=10_000, metadata_combi=["A", "B", "B"] + ), + ), + ("df_order_2", PixelParameters(number_fragments=2, binsize=100_000)), + ], +) def test_write_pandas_pixels_to_new_file(df, params, request): df = request.getfixturevalue(df) pixels = Pixels(df, **params.dict()) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels) # check metadata metadata = FileManager().list_pixels(file_name) @@ -221,19 +247,26 @@ def test_write_pandas_pixels_to_new_file(df, params, request): assert pixels.get_global_parameters() == pixels_read.get_global_parameters() assert pixels.data.equals(pixels_read.data) -@pytest.mark.parametrize('df, params', - [ - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000)), - ('df_order_3', PixelParameters(number_fragments=3, binsize=10_000, metadata_combi=['A', 'B', 'B'])), - ('df_order_2', PixelParameters(number_fragments=2, binsize=100_000)) - ] - ) + +@pytest.mark.parametrize( + "df, params", + [ + ("df_order_2", PixelParameters(number_fragments=2, binsize=1000)), + ( + "df_order_3", + PixelParameters( + number_fragments=3, binsize=10_000, metadata_combi=["A", "B", "B"] + ), + ), + ("df_order_2", PixelParameters(number_fragments=2, binsize=100_000)), + ], +) def test_write_dask_pixels_to_new_file(df, params, request): df = request.getfixturevalue(df) dask_df = dd.from_pandas(df, npartitions=2) pixels = Pixels(dask_df, **params.dict()) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels) # check metadata metadata = FileManager().list_pixels(file_name) @@ -244,23 +277,47 @@ def test_write_dask_pixels_to_new_file(df, params, request): assert pixels.get_global_parameters() == pixels_read.get_global_parameters() assert pixels.data.compute().equals(pixels_read.data) -@pytest.mark.parametrize('df1,df2,params', - [ - ('df_order_2','df_order_3', - [PixelParameters(number_fragments=2, binsize=1000), PixelParameters(number_fragments=3, binsize=10_000)]), - ('df_order_3','df_order_2', - [PixelParameters(number_fragments=3, binsize=10_000, metadata_combi=['A', 'B', 'B']), PixelParameters(number_fragments=2, binsize=100)]), - ('df_order_2','df_order_3', - [PixelParameters(number_fragments=2, binsize=100_000), PixelParameters(number_fragments=3, binsize=10_000, metadata_combi=['A', 'B', 'B'])]) - ] - ) + +@pytest.mark.parametrize( + "df1,df2,params", + [ + ( + "df_order_2", + "df_order_3", + [ + PixelParameters(number_fragments=2, binsize=1000), + PixelParameters(number_fragments=3, binsize=10_000), + ], + ), + ( + "df_order_3", + "df_order_2", + [ + PixelParameters( + number_fragments=3, binsize=10_000, metadata_combi=["A", "B", "B"] + ), + PixelParameters(number_fragments=2, binsize=100), + ], + ), + ( + "df_order_2", + "df_order_3", + [ + PixelParameters(number_fragments=2, binsize=100_000), + PixelParameters( + number_fragments=3, binsize=10_000, metadata_combi=["A", "B", "B"] + ), + ], + ), + ], +) def test_add_pandas_pixels_to_existing_file(df1, df2, params, request): df1, df2 = request.getfixturevalue(df1), request.getfixturevalue(df2) params_1, params_2 = params pixels1 = Pixels(df1, **params_1.dict()) pixels2 = Pixels(df2, **params_2.dict()) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels1) FileManager().write_pixels(file_name, pixels2) # check metadata @@ -268,109 +325,163 @@ def test_add_pandas_pixels_to_existing_file(df1, df2, params, request): assert len(metadata) == 2 # read pixels for pixels in [pixels1, pixels2]: - pixels_read = FileManager().load_pixels(file_name, pixels.get_global_parameters()) + pixels_read = FileManager().load_pixels( + file_name, pixels.get_global_parameters() + ) # check whether parameters are equal assert pixels.get_global_parameters() == pixels_read.get_global_parameters() assert pixels.data.equals(pixels_read.data) -@pytest.mark.parametrize('df, params', - [ - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000)), - ] - ) +@pytest.mark.parametrize( + "df, params", + [ + ("df_order_2", PixelParameters(number_fragments=2, binsize=1000)), + ], +) def test_load_pixels_from_uri_fails_without_required_parameters(df, params, request): """Test loading pixels from uri fails without required parameters""" df = request.getfixturevalue(df) pixels = Pixels(df, **params.dict()) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels) # try loading without required parameters with pytest.raises(ValueError) as e: Pixels.from_uri(file_name) -@pytest.mark.parametrize('df, params', - [ - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000)), - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000, metadata_combi=['A', 'B'])), - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000, metadata_combi=['A', 'B'], label_sorted=True)), - ] - ) + +@pytest.mark.parametrize( + "df, params", + [ + ("df_order_2", PixelParameters(number_fragments=2, binsize=1000)), + ( + "df_order_2", + PixelParameters( + number_fragments=2, binsize=1000, metadata_combi=["A", "B"] + ), + ), + ( + "df_order_2", + PixelParameters( + number_fragments=2, + binsize=1000, + metadata_combi=["A", "B"], + label_sorted=True, + ), + ), + ], +) def test_load_pixels_from_uri_succeeds_exact_match(df, params, request): """Test loading pixels from uri succeeds with all required parameters""" df = request.getfixturevalue(df) pixels = Pixels(df, **params.dict()) # get meata data parameter - if params.dict()['metadata_combi'] is None: - params.metadata_combi = 'None' + if params.dict()["metadata_combi"] is None: + params.metadata_combi = "None" else: - params.metadata_combi = str("".join(params.dict()['metadata_combi'])) + params.metadata_combi = str("".join(params.dict()["metadata_combi"])) uri = ( - str(params.dict()['number_fragments']) + '::' + - str(params.dict()['binsize']) + '::' + - str(params.dict()['metadata_combi']) + '::' + - str(params.dict()['binary_labels_equal']) + '::' + - str(params.dict()['symmetry_flipped']) + '::' + - str(params.dict()['label_sorted']) + '::' + str(params.dict()['same_chromosome']) + str(params.dict()["number_fragments"]) + + "::" + + str(params.dict()["binsize"]) + + "::" + + str(params.dict()["metadata_combi"]) + + "::" + + str(params.dict()["binary_labels_equal"]) + + "::" + + str(params.dict()["symmetry_flipped"]) + + "::" + + str(params.dict()["label_sorted"]) + + "::" + + str(params.dict()["same_chromosome"]) ) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels) # load pixels pixels_read = Pixels.from_uri(file_name + "::" + uri) assert pixels.get_global_parameters() == pixels_read.get_global_parameters() -@pytest.mark.parametrize('df, params', - [ - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000)), - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000, metadata_combi=['A', 'B'])), - ('df_order_2', PixelParameters(number_fragments=2, binsize=1000, metadata_combi=['A', 'B'], label_sorted=True)), - ] - ) +@pytest.mark.parametrize( + "df, params", + [ + ("df_order_2", PixelParameters(number_fragments=2, binsize=1000)), + ( + "df_order_2", + PixelParameters( + number_fragments=2, binsize=1000, metadata_combi=["A", "B"] + ), + ), + ( + "df_order_2", + PixelParameters( + number_fragments=2, + binsize=1000, + metadata_combi=["A", "B"], + label_sorted=True, + ), + ), + ], +) def test_load_pixels_from_uri_succeeds_partial_match(df, params, request): """Test loading pixels from uri succeeds with sufficient required parameters""" df = request.getfixturevalue(df) pixels = Pixels(df, **params.dict()) # get meata data parameter - if params.dict()['metadata_combi'] is None: - params.metadata_combi = 'None' + if params.dict()["metadata_combi"] is None: + params.metadata_combi = "None" else: - params.metadata_combi = str("".join(params.dict()['metadata_combi'])) + params.metadata_combi = str("".join(params.dict()["metadata_combi"])) uri = ( - str(params.dict()['number_fragments']) + '::' + - str(params.dict()['binsize']) + '::' + - str(params.dict()['metadata_combi']) + str(params.dict()["number_fragments"]) + + "::" + + str(params.dict()["binsize"]) + + "::" + + str(params.dict()["metadata_combi"]) ) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels) # load pixels pixels_read = Pixels.from_uri(file_name + "::" + uri) assert pixels.get_global_parameters() == pixels_read.get_global_parameters() -@pytest.mark.parametrize('df, params', - [ - ('df_order_2', [PixelParameters(number_fragments=2, binsize=1000, metadata_combi=['A', 'B']), - PixelParameters(number_fragments=2, binsize=1000, metadata_combi=['A', 'B'], label_sorted=True)] - ) - ] - ) +@pytest.mark.parametrize( + "df, params", + [ + ( + "df_order_2", + [ + PixelParameters( + number_fragments=2, binsize=1000, metadata_combi=["A", "B"] + ), + PixelParameters( + number_fragments=2, + binsize=1000, + metadata_combi=["A", "B"], + label_sorted=True, + ), + ], + ) + ], +) def test_load_pixels_from_uri_fails_with_ambiguous_specification(df, params, request): """Test loading pixels from uri fails with uri is ambiguous""" df = request.getfixturevalue(df) pixels = Pixels(df, **params[0].dict()) pixels2 = Pixels(df, **params[1].dict()) uri = ( - str(params[0].dict()['number_fragments']) + '::' + - str(params[0].dict()['binsize']) + str(params[0].dict()["number_fragments"]) + + "::" + + str(params[0].dict()["binsize"]) ) with tempfile.TemporaryDirectory() as tmpdirname: - file_name = tmpdirname + '/' + 'test.parquet' + file_name = tmpdirname + "/" + "test.parquet" FileManager().write_pixels(file_name, pixels) FileManager().write_pixels(file_name, pixels2) # load pixels with pytest.raises(ValueError) as e: - Pixels.from_uri(file_name + "::" + uri) \ No newline at end of file + Pixels.from_uri(file_name + "::" + uri) diff --git a/tests/test_labels.py b/tests/test_labels.py index 8a1168e..9fa499a 100644 --- a/tests/test_labels.py +++ b/tests/test_labels.py @@ -28,6 +28,7 @@ def annotator_with_entries(): def bad_df(): return pd.DataFrame({"be": ["bop"]}) + @pytest.fixture def unlabelled_df(): return pd.DataFrame( @@ -47,6 +48,7 @@ def unlabelled_df(): } ) + @pytest.fixture def unlabelled_fragments(unlabelled_df): return fragments.Fragments(unlabelled_df) @@ -56,6 +58,7 @@ def unlabelled_fragments(unlabelled_df): def unlabelled_fragments_dask(unlabelled_df): return fragments.Fragments(dd.from_pandas(unlabelled_df, npartitions=1)) + @pytest.fixture def labelled_df(): return pd.DataFrame( @@ -91,46 +94,54 @@ def test_fragments_constructor_accepts_labelled_fragments(labelled_df): frag = fragments.Fragments(labelled_df) assert frag.contains_metadata -@pytest.mark.parametrize("fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"]) -def test_annotator_drops_unknown_fragments( - annotator_with_entries, fragments, request -): + +@pytest.mark.parametrize( + "fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"] +) +def test_annotator_drops_unknown_fragments(annotator_with_entries, fragments, request): labelled_fragments = annotator_with_entries.annotate_fragments( request.getfixturevalue(fragments) ) # check length assert len(labelled_fragments.data) == 2 -@pytest.mark.parametrize("fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"]) -def test_annotator_produces_correct_schema( - annotator_with_entries, fragments, request -): + +@pytest.mark.parametrize( + "fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"] +) +def test_annotator_produces_correct_schema(annotator_with_entries, fragments, request): labelled_fragments = annotator_with_entries.annotate_fragments( request.getfixturevalue(fragments) ) # check schema (fragment constructor checks it) assert labelled_fragments.contains_metadata -@pytest.mark.parametrize("fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"]) -def test_annotator_calls_sisters_correctly( - annotator_with_entries, fragments, request -): + +@pytest.mark.parametrize( + "fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"] +) +def test_annotator_calls_sisters_correctly(annotator_with_entries, fragments, request): labelled_fragments = annotator_with_entries.annotate_fragments( request.getfixturevalue(fragments) ) # check values expected = pd.Series(["SisterB", "SisterA"]) if isinstance(request.getfixturevalue(fragments), dd.DataFrame): - assert np.array_equal(labelled_fragments.data.metadata.values.compute(), expected.values) + assert np.array_equal( + labelled_fragments.data.metadata.values.compute(), expected.values + ) else: assert np.array_equal(labelled_fragments.data.metadata.values, expected) -@pytest.mark.parametrize("fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"]) -def test_annotator_maintains_dataframe_type( - annotator_with_entries, fragments, request -): + +@pytest.mark.parametrize( + "fragments", ["unlabelled_fragments", "unlabelled_fragments_dask"] +) +def test_annotator_maintains_dataframe_type(annotator_with_entries, fragments, request): labelled_fragments = annotator_with_entries.annotate_fragments( request.getfixturevalue(fragments) ) # check values - assert isinstance(labelled_fragments.data, type(request.getfixturevalue(fragments).data)) \ No newline at end of file + assert isinstance( + labelled_fragments.data, type(request.getfixturevalue(fragments).data) + ) diff --git a/tests/test_pixels.py b/tests/test_pixels.py index 301f937..542f98b 100644 --- a/tests/test_pixels.py +++ b/tests/test_pixels.py @@ -21,9 +21,7 @@ def chromosome_sizes(): @pytest.fixture def genomic_binner(chromosome_sizes): """genomic binner for pixels""" - return pixels.GenomicBinner( - bin_size=100_000 - ) + return pixels.GenomicBinner(bin_size=100_000) @pytest.fixture @@ -57,7 +55,6 @@ def contacts_df(): ) - @pytest.fixture def expected_pixels(): return pd.DataFrame( @@ -70,17 +67,18 @@ def expected_pixels(): } ) + @pytest.fixture def expected_pixels_different_chromosomes(): return pd.DataFrame( { "chrom_1": ["chr1"] * 3, - "start_1": [100_000,10_000_000, 5_000_000], + "start_1": [100_000, 10_000_000, 5_000_000], "chrom_2": ["chr1"] * 3, - "start_2": [500_000,25_000_000, 7_000_000], + "start_2": [500_000, 25_000_000, 7_000_000], "chrom_3": ["chr1", "chr1", "chr4"], - "start_3": [600_000,6_000_000, 2_000_000], - "contact_count": [1, 2,1], + "start_3": [600_000, 6_000_000, 2_000_000], + "contact_count": [1, 2, 1], } ) @@ -97,12 +95,15 @@ def test_genomic_binner_bins_correctly_same_chromosome_pandas( assert result.symmetry_flipped == False assert result.metadata_combi is None + def test_genomic_binner_bins_correctly_w_different_chromosomes_pandas( genomic_binner, contacts_df, expected_pixels_different_chromosomes ): contacts = Contacts(contacts_df) result = genomic_binner.bin_contacts(contacts, same_chromosome=False) - assert np.array_equal(result.data.values, expected_pixels_different_chromosomes.values) + assert np.array_equal( + result.data.values, expected_pixels_different_chromosomes.values + ) assert result.number_fragments == 3 assert result.binsize == 100_000 assert result.binary_labels_equal == False @@ -122,14 +123,17 @@ def test_genomic_binner_bins_correctly_same_chromosome_dask( assert result.symmetry_flipped == False assert result.metadata_combi is None + def test_genomic_binner_bins_correctly_w_different_chromosome_dask( genomic_binner, contacts_df, expected_pixels_different_chromosomes ): contacts = Contacts(dd.from_pandas(contacts_df, chunksize=1000)) result = genomic_binner.bin_contacts(contacts, same_chromosome=False) - assert np.array_equal(result.data.compute().values, expected_pixels_different_chromosomes.values) + assert np.array_equal( + result.data.compute().values, expected_pixels_different_chromosomes.values + ) assert result.number_fragments == 3 assert result.binsize == 100_000 assert result.binary_labels_equal == False assert result.symmetry_flipped == False - assert result.metadata_combi is None \ No newline at end of file + assert result.metadata_combi is None diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py index 2d6c7be..8ad4456 100644 --- a/tests/test_symmetry.py +++ b/tests/test_symmetry.py @@ -7,132 +7,239 @@ import dask.dataframe as dd from spoc.contacts import Contacts, ContactManipulator from .fixtures.symmetry import ( - unlabelled_contacts_2d, - unlabelled_contacts_2d_flipped, - unlabelled_contacts_3d, - unlabelled_contacts_3d_flipped, - labelled_binary_contacts_2d, - labelled_binary_contacts_2d_sorted, - labelled_binary_contacts_3d, - labelled_binary_contacts_3d_sorted, - binary_contacts_not_equated_2d, - binary_contacts_not_equated_3d, - binary_contacts_not_equated_4d, - binary_contacts_equated_2d, - binary_contacts_equated_3d, - binary_contacts_equated_4d, - labelled_binary_contacts_2d_unflipped, - labelled_binary_contacts_2d_flipped, - labelled_binary_contacts_3d_unflipped, - labelled_binary_contacts_3d_unflipped_example2, - labelled_binary_contacts_3d_flipped, - labelled_binary_contacts_3d_flipped_example2, - unlabelled_contacts_diff_chrom_2d, - unlabelled_contacts_diff_chrom_3d, - unlabelled_contacts_diff_chrom_4d, - unlabelled_contacts_diff_chrom_3d_flipped, - unlabelled_contacts_diff_chrom_2d_flipped, - unlabelled_contacts_diff_chrom_4d_flipped, - labelled_binary_contacts_diff_chrom_2d, - labelled_binary_contacts_diff_chrom_2d_flipped, - labelled_binary_contacts_diff_chrom_3d, - labelled_binary_contacts_diff_chrom_3d_flipped + unlabelled_contacts_2d, + unlabelled_contacts_2d_flipped, + unlabelled_contacts_3d, + unlabelled_contacts_3d_flipped, + labelled_binary_contacts_2d, + labelled_binary_contacts_2d_sorted, + labelled_binary_contacts_3d, + labelled_binary_contacts_3d_sorted, + binary_contacts_not_equated_2d, + binary_contacts_not_equated_3d, + binary_contacts_not_equated_4d, + binary_contacts_equated_2d, + binary_contacts_equated_3d, + binary_contacts_equated_4d, + labelled_binary_contacts_2d_unflipped, + labelled_binary_contacts_2d_flipped, + labelled_binary_contacts_3d_unflipped, + labelled_binary_contacts_3d_unflipped_example2, + labelled_binary_contacts_3d_flipped, + labelled_binary_contacts_3d_flipped_example2, + unlabelled_contacts_diff_chrom_2d, + unlabelled_contacts_diff_chrom_3d, + unlabelled_contacts_diff_chrom_4d, + unlabelled_contacts_diff_chrom_3d_flipped, + unlabelled_contacts_diff_chrom_2d_flipped, + unlabelled_contacts_diff_chrom_4d_flipped, + labelled_binary_contacts_diff_chrom_2d, + labelled_binary_contacts_diff_chrom_2d_flipped, + labelled_binary_contacts_diff_chrom_3d, + labelled_binary_contacts_diff_chrom_3d_flipped, ) - -@pytest.mark.parametrize("unflipped, flipped", - [('unlabelled_contacts_2d', 'unlabelled_contacts_2d_flipped'), - ('unlabelled_contacts_3d', 'unlabelled_contacts_3d_flipped')]) +@pytest.mark.parametrize( + "unflipped, flipped", + [ + ("unlabelled_contacts_2d", "unlabelled_contacts_2d_flipped"), + ("unlabelled_contacts_3d", "unlabelled_contacts_3d_flipped"), + ], +) def test_unlabelled_contacts_flipped_correctly(unflipped, flipped, request): - unflipped, flipped = request.getfixturevalue(unflipped), request.getfixturevalue(flipped) + unflipped, flipped = request.getfixturevalue(unflipped), request.getfixturevalue( + flipped + ) contacts = Contacts(unflipped) flipped_contacts = ContactManipulator().flip_symmetric_contacts(contacts) pd.testing.assert_frame_equal(flipped_contacts.data, flipped) -@pytest.mark.parametrize("unflipped, flipped", - [('unlabelled_contacts_2d', 'unlabelled_contacts_2d_flipped'), - ('unlabelled_contacts_3d', 'unlabelled_contacts_3d_flipped')]) + +@pytest.mark.parametrize( + "unflipped, flipped", + [ + ("unlabelled_contacts_2d", "unlabelled_contacts_2d_flipped"), + ("unlabelled_contacts_3d", "unlabelled_contacts_3d_flipped"), + ], +) def test_unlabelled_contacts_flipped_correctly_dask(unflipped, flipped, request): - unflipped, flipped = dd.from_pandas(request.getfixturevalue(unflipped), npartitions=1), request.getfixturevalue(flipped) + unflipped, flipped = dd.from_pandas( + request.getfixturevalue(unflipped), npartitions=1 + ), request.getfixturevalue(flipped) contacts = Contacts(unflipped) flipped_contacts = ContactManipulator().flip_symmetric_contacts(contacts) - pd.testing.assert_frame_equal(flipped_contacts.data.compute().reset_index(drop=True), flipped.reset_index(drop=True)) - - -@pytest.mark.parametrize("unsorted, sorted_contacts", - [('labelled_binary_contacts_2d', 'labelled_binary_contacts_2d_sorted'), - ('labelled_binary_contacts_3d', 'labelled_binary_contacts_3d_sorted')]) + pd.testing.assert_frame_equal( + flipped_contacts.data.compute().reset_index(drop=True), + flipped.reset_index(drop=True), + ) + + +@pytest.mark.parametrize( + "unsorted, sorted_contacts", + [ + ("labelled_binary_contacts_2d", "labelled_binary_contacts_2d_sorted"), + ("labelled_binary_contacts_3d", "labelled_binary_contacts_3d_sorted"), + ], +) def test_labelled_contacts_are_sorted_correctly(unsorted, sorted_contacts, request): - unsorted, sorted_contacts = request.getfixturevalue(unsorted), request.getfixturevalue(sorted_contacts) + unsorted, sorted_contacts = request.getfixturevalue( + unsorted + ), request.getfixturevalue(sorted_contacts) contacts = Contacts(unsorted) result = ContactManipulator().sort_labels(contacts) pd.testing.assert_frame_equal(result.data, sorted_contacts) assert result.label_sorted -@pytest.mark.parametrize("unsorted, sorted_contacts", - [('labelled_binary_contacts_2d', 'labelled_binary_contacts_2d_sorted'), - ('labelled_binary_contacts_3d', 'labelled_binary_contacts_3d_sorted')]) -def test_labelled_contacts_are_sorted_correctly_dask(unsorted, sorted_contacts, request): - unsorted, sorted_contacts = dd.from_pandas(request.getfixturevalue(unsorted), npartitions=1), request.getfixturevalue(sorted_contacts) + +@pytest.mark.parametrize( + "unsorted, sorted_contacts", + [ + ("labelled_binary_contacts_2d", "labelled_binary_contacts_2d_sorted"), + ("labelled_binary_contacts_3d", "labelled_binary_contacts_3d_sorted"), + ], +) +def test_labelled_contacts_are_sorted_correctly_dask( + unsorted, sorted_contacts, request +): + unsorted, sorted_contacts = dd.from_pandas( + request.getfixturevalue(unsorted), npartitions=1 + ), request.getfixturevalue(sorted_contacts) contacts = Contacts(unsorted) result = ContactManipulator().sort_labels(contacts) - pd.testing.assert_frame_equal(result.data.compute().reset_index(drop=True), sorted_contacts.reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.data.compute().reset_index(drop=True), + sorted_contacts.reset_index(drop=True), + ) assert result.label_sorted -@pytest.mark.parametrize("unequated, equated", - [('binary_contacts_not_equated_2d', 'binary_contacts_equated_2d'), - ('binary_contacts_not_equated_3d', 'binary_contacts_equated_3d'), - ('binary_contacts_not_equated_4d', 'binary_contacts_equated_4d')]) +@pytest.mark.parametrize( + "unequated, equated", + [ + ("binary_contacts_not_equated_2d", "binary_contacts_equated_2d"), + ("binary_contacts_not_equated_3d", "binary_contacts_equated_3d"), + ("binary_contacts_not_equated_4d", "binary_contacts_equated_4d"), + ], +) def test_equate_binary_labels(unequated, equated, request): - unequated, equated = request.getfixturevalue(unequated), request.getfixturevalue(equated) + unequated, equated = request.getfixturevalue(unequated), request.getfixturevalue( + equated + ) contacts = Contacts(unequated, label_sorted=True) result = ContactManipulator().equate_binary_labels(contacts) pd.testing.assert_frame_equal(result.data, equated) -@pytest.mark.parametrize("unequated, equated", - [('binary_contacts_not_equated_2d', 'binary_contacts_equated_2d'), - ('binary_contacts_not_equated_3d', 'binary_contacts_equated_3d'), - ('binary_contacts_not_equated_4d', 'binary_contacts_equated_4d')]) + +@pytest.mark.parametrize( + "unequated, equated", + [ + ("binary_contacts_not_equated_2d", "binary_contacts_equated_2d"), + ("binary_contacts_not_equated_3d", "binary_contacts_equated_3d"), + ("binary_contacts_not_equated_4d", "binary_contacts_equated_4d"), + ], +) def test_equate_binary_labels_dask(unequated, equated, request): - unequated, equated = dd.from_pandas(request.getfixturevalue(unequated), npartitions=1), request.getfixturevalue(equated) + unequated, equated = dd.from_pandas( + request.getfixturevalue(unequated), npartitions=1 + ), request.getfixturevalue(equated) contacts = Contacts(unequated, label_sorted=True) result = ContactManipulator().equate_binary_labels(contacts) - pd.testing.assert_frame_equal(result.data.compute().reset_index(drop=True), equated.reset_index(drop=True)) - - -@pytest.mark.parametrize("unflipped, flipped", - [('labelled_binary_contacts_2d_unflipped', 'labelled_binary_contacts_2d_flipped'), - ('labelled_binary_contacts_3d_unflipped_example2', 'labelled_binary_contacts_3d_flipped_example2'), - ('labelled_binary_contacts_3d_unflipped', 'labelled_binary_contacts_3d_flipped')]) + pd.testing.assert_frame_equal( + result.data.compute().reset_index(drop=True), equated.reset_index(drop=True) + ) + + +@pytest.mark.parametrize( + "unflipped, flipped", + [ + ( + "labelled_binary_contacts_2d_unflipped", + "labelled_binary_contacts_2d_flipped", + ), + ( + "labelled_binary_contacts_3d_unflipped_example2", + "labelled_binary_contacts_3d_flipped_example2", + ), + ( + "labelled_binary_contacts_3d_unflipped", + "labelled_binary_contacts_3d_flipped", + ), + ], +) def test_flip_labelled_contacts(unflipped, flipped, request): - unflipped, flipped = request.getfixturevalue(unflipped), request.getfixturevalue(flipped) + unflipped, flipped = request.getfixturevalue(unflipped), request.getfixturevalue( + flipped + ) contacts = Contacts(unflipped, label_sorted=True) result = ContactManipulator().flip_symmetric_contacts(contacts) - pd.testing.assert_frame_equal(result.data.reset_index(drop=True), flipped.reset_index(drop=True)) - - -@pytest.mark.parametrize("unflipped, flipped", - [('labelled_binary_contacts_2d_unflipped', 'labelled_binary_contacts_2d_flipped'), - ('labelled_binary_contacts_3d_unflipped_example2', 'labelled_binary_contacts_3d_flipped_example2'), - ('labelled_binary_contacts_3d_unflipped', 'labelled_binary_contacts_3d_flipped')]) + pd.testing.assert_frame_equal( + result.data.reset_index(drop=True), flipped.reset_index(drop=True) + ) + + +@pytest.mark.parametrize( + "unflipped, flipped", + [ + ( + "labelled_binary_contacts_2d_unflipped", + "labelled_binary_contacts_2d_flipped", + ), + ( + "labelled_binary_contacts_3d_unflipped_example2", + "labelled_binary_contacts_3d_flipped_example2", + ), + ( + "labelled_binary_contacts_3d_unflipped", + "labelled_binary_contacts_3d_flipped", + ), + ], +) def test_flip_labelled_contacts_dask(unflipped, flipped, request): - unflipped, flipped = dd.from_pandas(request.getfixturevalue(unflipped), npartitions=1), request.getfixturevalue(flipped) + unflipped, flipped = dd.from_pandas( + request.getfixturevalue(unflipped), npartitions=1 + ), request.getfixturevalue(flipped) contacts = Contacts(unflipped, label_sorted=True) result = ContactManipulator().flip_symmetric_contacts(contacts) - pd.testing.assert_frame_equal(result.data.compute().reset_index(drop=True), flipped.reset_index(drop=True)) - -@pytest.mark.parametrize("unflipped, flipped", - [('unlabelled_contacts_diff_chrom_3d', 'unlabelled_contacts_diff_chrom_3d_flipped'), - ('unlabelled_contacts_diff_chrom_2d', 'unlabelled_contacts_diff_chrom_2d_flipped'), - ('unlabelled_contacts_diff_chrom_4d', 'unlabelled_contacts_diff_chrom_4d_flipped'), - ('labelled_binary_contacts_diff_chrom_2d', 'labelled_binary_contacts_diff_chrom_2d_flipped'), - ('labelled_binary_contacts_diff_chrom_3d', 'labelled_binary_contacts_diff_chrom_3d_flipped'), - ]) + pd.testing.assert_frame_equal( + result.data.compute().reset_index(drop=True), flipped.reset_index(drop=True) + ) + + +@pytest.mark.parametrize( + "unflipped, flipped", + [ + ( + "unlabelled_contacts_diff_chrom_3d", + "unlabelled_contacts_diff_chrom_3d_flipped", + ), + ( + "unlabelled_contacts_diff_chrom_2d", + "unlabelled_contacts_diff_chrom_2d_flipped", + ), + ( + "unlabelled_contacts_diff_chrom_4d", + "unlabelled_contacts_diff_chrom_4d_flipped", + ), + ( + "labelled_binary_contacts_diff_chrom_2d", + "labelled_binary_contacts_diff_chrom_2d_flipped", + ), + ( + "labelled_binary_contacts_diff_chrom_3d", + "labelled_binary_contacts_diff_chrom_3d_flipped", + ), + ], +) def test_flip_unlabelled_contacts_different_chromosomes(unflipped, flipped, request): - unflipped, flipped = request.getfixturevalue(unflipped), request.getfixturevalue(flipped) + unflipped, flipped = request.getfixturevalue(unflipped), request.getfixturevalue( + flipped + ) contacts = Contacts(unflipped) - result = ContactManipulator().flip_symmetric_contacts(contacts, sort_chromosomes=True) - pd.testing.assert_frame_equal(result.data.reset_index(drop=True).sort_index(axis=1), - flipped.reset_index(drop=True).sort_index(axis=1)) \ No newline at end of file + result = ContactManipulator().flip_symmetric_contacts( + contacts, sort_chromosomes=True + ) + pd.testing.assert_frame_equal( + result.data.reset_index(drop=True).sort_index(axis=1), + flipped.reset_index(drop=True).sort_index(axis=1), + )