diff --git a/doc/protocols_pipelines.rst b/doc/protocols_pipelines.rst
index bd910fb..bc71a2b 100644
--- a/doc/protocols_pipelines.rst
+++ b/doc/protocols_pipelines.rst
@@ -9,7 +9,7 @@ Typical Hi-C Workflow
 ----------------------
 
 A typical pairtools workflow for processing standard Hi-C data is outlined below.
-Please, note that this is a shorter version; you can find a more detailed and reproducible example in chapter :ref:`examples/pairtools_walkthrough`.
+Please note that this is a shortened version. For a detailed, reproducible example, check the Jupyter notebook "Pairtools Walkthrough".
 
 1. Align sequences to the reference genome with ``bwa mem``:
 
@@ -103,6 +103,7 @@ Technical tips
     bwa mem -SP index input.R1.fastq input.R2.fastq | \
     pairtools parse -c chromsizes.txt | \
     pairtools sort | \
+    pairtools dedup \
         --output output.nodups.pairs.gz \
         --output-dups output.dups.pairs.gz \
         --output-unmapped output.unmapped.pairs.gz
 
@@ -116,8 +117,9 @@ Technical tips
 Each pairtool has the CLI flags --nproc-in and --nproc-out to control the number of cores
 dedicated to input decompression and output compression. Additionally, `pairtools sort`
 parallelizes sorting with `--nproc`.
 
-Example Workflows
+Advanced Workflows
 ------------------
+
 For more advanced workflows, please check the following projects:
 
 - `Distiller-nf `_ is a feature-rich Open2C Hi-C processing pipeline for the Nextflow workflow manager.
diff --git a/doc/stats.rst b/doc/stats.rst
index 49b216a..4895767 100644
--- a/doc/stats.rst
+++ b/doc/stats.rst
@@ -14,7 +14,7 @@ output file.
 - **Global statistics** include:
 
   - number of pairs (total, unmapped, single-side mapped, etc.),
-  - total number of different pair types (UU, NN, NU, and others, see ` Pair types in pairtools docs `_),
+  - total number of different pair types (UU, NN, NU, and others, see `Pair types in pairtools docs `_),
   - number of contacts between all chromosome pairs
 
 - **Summary statistics** include:
 
@@ -59,17 +59,23 @@ replacement from a finite pool of fragments in DNA library [1]_ [2]_. With each
 new sequenced molecule, the expected number of observed unique molecules
 increases according to a simple equation:
 
-$$ U(N+1) = U(N) + (1 - {U(N) \\over C}), $$
+.. math::
 
-where $N$ is the number of sequenced molecules, $U(N)$ is the expected number
-of observed unique molecules after sequencing $N$ molecules, and C is the library complexity.
+    U(N+1) = U(N) + \left(1 - \frac{U(N)}{C} \right),
+
+where :math:`N` is the number of sequenced molecules, :math:`U(N)` is the expected number
+of observed unique molecules after sequencing :math:`N` molecules, and :math:`C` is the library complexity.
 
-This differential equation yields [1, 2]:
+This differential equation yields [1]_ [2]_:
 
-$$ {U(N) \\over C} = 1 - exp( - {N \\over C}), $$
+.. math::
+
+    \frac{U(N)}{C} = 1 - \exp\left( - \frac{N}{C} \right),
 
 which can be solved as
 
-$$ C = \Re(lambert W( - { \exp( - {1 \\over u} ) \\over u} ) ) + {1 \\over u} $$
+.. math::
+
+    \frac{N}{C} = \Re \left( W_{Lambert} \left( - \frac{ \exp\left( - \frac{1}{u} \right) }{u} \right) \right) + \frac{1}{u},
+
+where :math:`u = U(N)/N` is the observed fraction of unique molecules.
 
-Library complexity can guide in the choice of sequencing depth of the library and provide an estimate of library quality.
+Library complexity can guide the choice of sequencing depth of the library and provides an estimate of library quality.
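+
+The same estimate can be reproduced numerically. A minimal sketch (assuming
+``numpy`` and ``scipy`` are available; this is illustrative, not part of the
+pairtools API):
+
+.. code-block:: python
+
+    import numpy as np
+    from scipy.special import lambertw
+
+    def estimate_complexity(n_sequenced, n_unique):
+        # u: observed fraction of unique molecules
+        u = n_unique / n_sequenced
+        # real branch of the Lambert W function, as in the formula above
+        n_over_c = np.real(lambertw(-np.exp(-1.0 / u) / u)) + 1.0 / u
+        return n_sequenced / n_over_c
+
+    # e.g., 100M sequenced pairs of which 80M are unique -> C is roughly 2.1e8
+    print(estimate_complexity(100e6, 80e6))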
Supported for tsv stats with single filter.", ) +@click.option( + "--n-dist-bins-decade", + type=int, + default=PairCounter.N_DIST_BINS_DECADE_DEFAULT, + show_default=True, + required=False, + help="Number of bins to split the distance range in log10-space, specified per a factor of 10 difference.", +) @click.option( "--with-chromsizes/--no-chromsizes", is_flag=True, @@ -107,7 +115,7 @@ ) @common_io_options def stats( - input_path, output, merge, bytile_dups, output_bytile_stats, filter, **kwargs + input_path, output, merge, n_dist_bins_decade, bytile_dups, output_bytile_stats, filter, **kwargs ): """Calculate pairs statistics. @@ -123,6 +131,7 @@ def stats( input_path, output, merge, + n_dist_bins_decade, bytile_dups, output_bytile_stats, filter, @@ -131,10 +140,10 @@ def stats( def stats_py( - input_path, output, merge, bytile_dups, output_bytile_stats, filter, **kwargs + input_path, output, merge, n_dist_bins_decade, bytile_dups, output_bytile_stats, filter, **kwargs ): if merge: - do_merge(output, input_path, **kwargs) + do_merge(output, input_path, n_dist_bins_decade=n_dist_bins_decade, **kwargs) return if len(input_path) == 0: @@ -181,6 +190,7 @@ def stats_py( filter = None stats = PairCounter( + n_dist_bins_decade=n_dist_bins_decade, bytile_dups=bytile_dups, filters=filter, startup_code=kwargs.get("startup_code", ""), # for evaluation of filters diff --git a/pairtools/lib/stats.py b/pairtools/lib/stats.py index 16ed11a..e008314 100644 --- a/pairtools/lib/stats.py +++ b/pairtools/lib/stats.py @@ -12,6 +12,183 @@ logger = get_logger() +def parse_number(s): + if s.isdigit(): + return int(s) + elif s.replace(".", "", 1).isdigit(): + return float(s) + else: + return s + + +def flat_dict_to_nested(input_dict, sep='/'): + output_dict = {} + + for key, value in input_dict.items(): + if type(key) == tuple: + key_parts = key + elif type(key) == str: + key_parts = key.split(sep) + else: + raise ValueError(f"Key type can be either str or tuple. Found key {key} of type {type(key)}.") + + current_dict = output_dict + for key_part in key_parts[:-1]: + current_dict = current_dict.setdefault(key_part, {}) + current_dict[key_parts[-1]] = value + + return output_dict + + +def nested_dict_to_flat(d, tuple_keys=False, sep='/'): + """Flatten a nested dictionary to a flat dictionary. + + Parameters + ---------- + d: dict + A nested dictionary to flatten. + tuple_keys: bool + If True, keys will be joined into tuples. Otherwise, they will be joined into strings. + sep: str + The separator to use between the parent key and the key if tuple_keys==False. + Returns + ------- + dict + A flat dictionary. + """ + + if tuple_keys: + join_keys = lambda k1,k2: (k1,) + k2 + else: + join_keys = lambda k1,k2: (k1+sep+k2) if k2 else k1 + + out = {} + + for k1, v1 in d.items(): + if isinstance(v1, dict): + out.update({ + join_keys(k1,k2): v2 + for k2, v2 in nested_dict_to_flat(v1, tuple_keys, sep).items() + }) + else: + if tuple_keys: + out[(k1,)] = v1 + else: + out[k1] = v1 + + return out + +def is_nested_dict(d): + """Check if a dictionary is nested. + + Parameters + ---------- + d: dict + A dictionary to check. + Returns + ------- + bool + True if the dictionary is nested, False otherwise. + """ + + if not isinstance(d, dict): + return False + + for v in d.values(): + if isinstance(v, dict): + return True + + return False + +def is_tuple_keyed_dict(d): + """Check if a dictionary is tuple-keyed. + + Parameters + ---------- + d: dict + A dictionary to check. 
+
+
+def is_tuple_keyed_dict(d):
+    """Check if a dictionary is tuple-keyed.
+
+    Parameters
+    ----------
+    d: dict
+        A dictionary to check.
+    Returns
+    -------
+    bool
+        True if the dictionary is tuple-keyed, False otherwise.
+    """
+
+    if not isinstance(d, dict):
+        return False
+
+    for k, v in d.items():
+        if not isinstance(k, tuple):
+            return False
+        if isinstance(v, dict):
+            return False
+
+    return True
+
+
+def is_str_keyed_dict(d):
+    """Check if a dictionary is string-keyed.
+
+    Parameters
+    ----------
+    d: dict
+        A dictionary to check.
+    Returns
+    -------
+    bool
+        True if the dictionary is string-keyed (and flat), False otherwise.
+    """
+
+    if not isinstance(d, dict):
+        return False
+
+    for k, v in d.items():
+        if not isinstance(k, str):
+            return False
+        if isinstance(v, dict):
+            return False
+
+    return True
+
+
+def swap_levels_nested_dict(nested_dict, level1, level2, sep="/"):
+    """Swap the order of two levels in a nested dictionary.
+
+    Parameters
+    ----------
+    nested_dict: dict
+        A nested dictionary.
+    level1: int
+        The index of the first level to swap.
+    level2: int
+        The index of the second level to swap.
+    sep: str
+        The separator used in compound string keys.
+    Returns
+    -------
+    dict
+        A nested dictionary with the levels swapped.
+    """
+
+    if is_tuple_keyed_dict(nested_dict):
+        out = {}
+        for k1, v1 in nested_dict.items():
+            k1_list = list(k1)
+            k1_list[level1], k1_list[level2] = k1_list[level2], k1_list[level1]
+            out[tuple(k1_list)] = v1
+        return out
+
+    elif is_nested_dict(nested_dict):
+        out = nested_dict_to_flat(nested_dict, tuple_keys=True)
+        out = swap_levels_nested_dict(out, level1, level2)
+        out = flat_dict_to_nested(out)
+        return out
+
+    elif is_str_keyed_dict(nested_dict):
+        # split compound string keys into tuples before swapping:
+        out = {tuple(k.split(sep)): v for k, v in nested_dict.items()}
+        out = swap_levels_nested_dict(out, level1, level2)
+        out = {sep.join(k): v for k, v in out.items()}
+        return out
+
+    else:
+        raise ValueError("Input dictionary must be either nested, string-keyed or tuple-keyed")
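+
+
+# Example (illustrative): swapping {distance: {strands: count}} into
+# {strands: {distance: count}}:
+#
+#   >>> swap_levels_nested_dict({1: {"++": 5}, 10: {"++": 7}}, 0, 1)
+#   {'++': {1: 5, 10: 7}}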
+
+
 class PairCounter(Mapping):
     """
     A Counter for Hi-C pairs that accumulates various statistics.
@@ -25,12 +202,16 @@ class PairCounter(Mapping):
 
     _SEP = "\t"
     _KEY_SEP = "/"
+    DIST_FREQ_REL_DIFF_THRESHOLD = 0.05
+    N_DIST_BINS_DECADE_DEFAULT = 8
+    MIN_LOG10_DIST_DEFAULT = 0
+    MAX_LOG10_DIST_DEFAULT = 9
 
     def __init__(
         self,
-        min_log10_dist=0,
-        max_log10_dist=9,
-        log10_dist_bin_step=0.25,
+        min_log10_dist=MIN_LOG10_DIST_DEFAULT,
+        max_log10_dist=MAX_LOG10_DIST_DEFAULT,
+        n_dist_bins_decade=N_DIST_BINS_DECADE_DEFAULT,
         bytile_dups=False,
         filters=None,
         **kwargs,
@@ -51,15 +232,18 @@ def __init__(
         # some variables used for initialization:
-        # genomic distance bining for the ++/--/-+/+- distribution
+        # genomic distance binning for the ++/--/-+/+- distribution
-        self._dist_bins = np.r_[
-            0,
-            np.round(
-                10
-                ** np.arange(
-                    min_log10_dist, max_log10_dist + 0.001, log10_dist_bin_step
-                )
-            ).astype(np.int_),
-        ]
+        log10_dist_bin_step = 1.0 / n_dist_bins_decade
+        self._dist_bins = np.unique(
+            np.r_[
+                0,
+                np.round(
+                    10
+                    ** np.arange(
+                        min_log10_dist, max_log10_dist + 0.001, log10_dist_bin_step
+                    )
+                ).astype(np.int_),
+            ]
+        )
 
         # establish structure of an empty _stat:
         for key in self.filters:
@@ -77,8 +261,9 @@ def __init__(
             self._stat[key]["cis"] = 0
             self._stat[key]["trans"] = 0
             self._stat[key]["pair_types"] = {}
-            self._stat[key]["dedup"] = {}
+            # to be removed:
+            # self._stat[key]["dedup"] = {}
 
             self._stat[key]["cis_1kb+"] = 0
             self._stat[key]["cis_2kb+"] = 0
@@ -122,6 +307,7 @@ def __init__(
         )
         self._summaries_calculated = False
 
+
     def __getitem__(self, key, filter="no_filter"):
         if isinstance(key, str):
             # let's strip any unintentional '/'
@@ -192,24 +378,113 @@ def __getitem__(self, key, filter="no_filter"):
         else:
             raise ValueError("{} is not a valid key".format(k))
 
+
     def __iter__(self):
         return iter(self._stat)
 
+
     def __len__(self):
         return len(self._stat)
 
+
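+    # Example (illustrative): the distance-bin edges follow a
+    # 10**(1/n_dist_bins_decade) geometric progression, rounded to integers
+    # and deduplicated; e.g., with n_dist_bins_decade=1:
+    #
+    #   >>> PairCounter(n_dist_bins_decade=1)._dist_bins
+    #   array([0, 1, 10, 100, 1000, 10000, 100000, 1000000, 10000000,
+    #          100000000, 1000000000])
+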
+    def find_dist_freq_convergence_distance(self, rel_threshold):
+        """Find the largest distance at which the frequency of pairs of reads
+        with different strand combinations deviates from their average by more
+        than the specified relative threshold."""
+
+        out = {}
+        all_strands = ["++", "--", "-+", "+-"]
+
+        for filter in self.filters:
+            out[filter] = {}
+
+            dist_freqs_by_strands = {
+                strands: np.array(list(self._stat[filter]["dist_freq"][strands].values()))
+                for strands in all_strands}
+
+            # Calculate the average frequency of pairs across strand combinations:
+            avg_freq_all_strands = np.mean(
+                np.vstack(list(dist_freqs_by_strands.values())), axis=0)
+
+            # For each strand combination, find the largest distance bin where its
+            # frequency deviates from the average by more than the threshold:
+            rel_deviations = {strands: np.nan_to_num(
+                np.abs(dist_freqs_by_strands[strands] - avg_freq_all_strands)
+                / avg_freq_all_strands)
+                for strands in all_strands}
+
+            idx_maxs = {strands: 0 for strands in all_strands}
+            for strands in all_strands:
+                bin_exceeds = rel_deviations[strands] > rel_threshold
+                if np.any(bin_exceeds):
+                    idx_maxs[strands] = np.max(np.nonzero(bin_exceeds))
+
+            # Pick the strand combination with the largest such distance:
+            convergence_bin_idx = 0
+            convergence_strands = "??"
+
+            for strands in all_strands:
+                if idx_maxs[strands] > convergence_bin_idx:
+                    convergence_bin_idx = idx_maxs[strands]
+                    convergence_strands = strands
+
+            if convergence_bin_idx + 1 < len(self._dist_bins):
+                convergence_dist = self._dist_bins[convergence_bin_idx + 1]
+            else:
+                convergence_dist = np.iinfo(np.int64).max
+
+            out[filter]["convergence_dist"] = convergence_dist
+            out[filter]["strands_w_max_convergence_dist"] = convergence_strands
+            out[filter]["convergence_rel_diff_threshold"] = rel_threshold
+
+            out[filter]["n_cis_pairs_below_convergence_dist"] = {
+                strands: dist_freqs_by_strands[strands][: convergence_bin_idx + 1].sum()
+                for strands in all_strands
+            }
+
+            out[filter]["n_cis_pairs_below_convergence_dist_all_strands"] = sum(
+                list(out[filter]["n_cis_pairs_below_convergence_dist"].values()))
+
+            n_cis_pairs_above_convergence_dist = {
+                strands: dist_freqs_by_strands[strands][convergence_bin_idx + 1:].sum()
+                for strands in all_strands
+            }
+
+            out[filter]["n_cis_pairs_above_convergence_dist_all_strands"] = sum(
+                list(n_cis_pairs_above_convergence_dist.values()))
+
+            norms = dict(
+                cis=self._stat[filter]["cis"],
+                total_mapped=self._stat[filter]["total_mapped"]
+            )
+
+            if "total_nodups" in self._stat[filter]:
+                norms["total_nodups"] = self._stat[filter]["total_nodups"]
+
+            for key, norm_factor in norms.items():
+                out[filter][f"frac_{key}_in_cis_below_convergence_dist"] = {
+                    strands: n_cis_pairs / norm_factor
+                    for strands, n_cis_pairs in out[filter]["n_cis_pairs_below_convergence_dist"].items()
+                }
+
+                out[filter][f"frac_{key}_in_cis_below_convergence_dist_all_strands"] = sum(
+                    list(out[filter][f"frac_{key}_in_cis_below_convergence_dist"].values()))
+
+                out[filter][f"frac_{key}_in_cis_above_convergence_dist_all_strands"] = (
+                    sum(list(n_cis_pairs_above_convergence_dist.values())) / norm_factor)
+
+        return out
+
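+    # Worked example (illustrative): with rel_threshold=0.1, if the per-bin
+    # counts for "++" pairs are [80, 80, 91, 95] while the average across all
+    # four strand combinations is [100, 100, 100, 100], the relative deviations
+    # are [0.2, 0.2, 0.09, 0.05]; the last bin above the threshold is bin 1, so
+    # the convergence distance is that bin's right edge (cf. tests/test_stats.py).
+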
     def calculate_summaries(self):
         """calculate summary statistics (fraction of cis pairs at different
         cutoffs, complexity estimate) based on accumulated counts.
-        Results are saved into self._stat["filter_name"]['summary"]
+        Results are saved into self._stat[<filter_name>]["summary"]
         """
-        for key in self.filters.keys():
-            self._stat[key]["summary"]["frac_dups"] = (
-                (self._stat[key]["total_dups"] / self._stat[key]["total_mapped"])
-                if self._stat[key]["total_mapped"] > 0
-                else 0
-            )
+        convergence_stats = self.find_dist_freq_convergence_distance(
+            self.DIST_FREQ_REL_DIFF_THRESHOLD)
 
+        for filter_name in self.filters.keys():
             for cis_count in (
                 "cis",
                 "cis_1kb+",
@@ -219,40 +494,50 @@ def calculate_summaries(self):
                 "cis_20kb+",
                 "cis_40kb+",
             ):
-                self._stat[key]["summary"][f"frac_{cis_count}"] = (
-                    (self._stat[key][cis_count] / self._stat[key]["total_nodups"])
-                    if self._stat[key]["total_nodups"] > 0
+                self._stat[filter_name]["summary"][f"frac_{cis_count}"] = (
+                    (self._stat[filter_name][cis_count] / self._stat[filter_name]["total_nodups"])
+                    if self._stat[filter_name]["total_nodups"] > 0
                     else 0
                 )
 
-            self._stat[key]["summary"][
+            self._stat[filter_name]["summary"]["dist_freq_convergence"] = convergence_stats[filter_name]
+
+            self._stat[filter_name]["summary"]["frac_dups"] = (
+                (self._stat[filter_name]["total_dups"] / self._stat[filter_name]["total_mapped"])
+                if self._stat[filter_name]["total_mapped"] > 0
+                else 0
+            )
+
+            self._stat[filter_name]["summary"][
                 "complexity_naive"
             ] = estimate_library_complexity(
-                self._stat[key]["total_mapped"], self._stat[key]["total_dups"], 0
+                self._stat[filter_name]["total_mapped"], self._stat[filter_name]["total_dups"], 0
             )
-            if key == "no_filter" and self._save_bytile_dups:
+
+            if filter_name == "no_filter" and self._save_bytile_dups:
                 # Estimate library complexity with information by tile, if provided:
                 if self._bytile_dups.shape[0] > 0:
-                    self._stat[key]["dups_by_tile_median"] = int(
+                    self._stat[filter_name]["dups_by_tile_median"] = int(
                         round(
                             self._bytile_dups["dup_count"].median()
                             * self._bytile_dups.shape[0]
                         )
                     )
-                if "dups_by_tile_median" in self._stat[key].keys():
-                    self._stat[key]["summary"][
+                if "dups_by_tile_median" in self._stat[filter_name].keys():
+                    self._stat[filter_name]["summary"][
                         "complexity_dups_by_tile_median"
                     ] = estimate_library_complexity(
-                        self._stat[key]["total_mapped"],
-                        self._stat[key]["total_dups"],
-                        self._stat[key]["total_dups"]
-                        - self._stat[key]["dups_by_tile_median"],
+                        self._stat[filter_name]["total_mapped"],
+                        self._stat[filter_name]["total_dups"],
+                        self._stat[filter_name]["total_dups"]
+                        - self._stat[filter_name]["dups_by_tile_median"],
                     )
 
         self._summaries_calculated = True
 
+
     @classmethod
-    def from_file(cls, file_handle):
+    def from_file(cls, file_handle, n_dist_bins_decade=N_DIST_BINS_DECADE_DEFAULT):
         """create instance of PairCounter from file
         Parameters
         ----------
@@ -264,101 +549,45 @@ def from_file(cls, file_handle):
         """
         # fill in from file - file_handle:
         default_filter = "no_filter"
-        stat_from_file = cls()
+        stat_from_file = cls(n_dist_bins_decade=n_dist_bins_decade)
+        raw_stat = {}
         for l in file_handle:
-            fields = l.strip().split(cls._SEP)
-            if len(fields) == 0:
+            key_val_pair = l.strip().split(cls._SEP)
+            if len(key_val_pair) == 0:
                 # skip empty lines:
                 continue
-            if len(fields) != 2:
+            if len(key_val_pair) != 2:
                 # expect two _SEP separated values per line:
                 raise fileio.ParseError(
                     "{} is not a valid stats file".format(file_handle.name)
                 )
-            # extract key and value, then split the key:
-            putative_key, putative_val = fields[0], fields[1]
-            key_fields = putative_key.split(cls._KEY_SEP)
-            # we should impose a rigid structure of .stats or redo it:
-            if len(key_fields) == 1:
-                key = key_fields[0]
-                if key in stat_from_file._stat[default_filter]:
-                    stat_from_file._stat[default_filter][key] = int(fields[1])
-                else:
-                    raise fileio.ParseError(
-                        "{} is not a valid stats file: unknown field {} detected".format(
-                            file_handle.name, key
-                        )
-                    )
-            else:
-                # in this case key must be in ['pair_types','chrom_freq','dist_freq','dedup', 'summary']
-                # get the first 'key' and keep the remainders in 'key_fields'
-                key = key_fields.pop(0)
-                if key in ["pair_types", "dedup", "summary", "chromsizes"]:
-                    # assert there is only one element in key_fields left:
-                    # 'pair_types', 'dedup', 'summary' and 'chromsizes' treated the same
-                    if len(key_fields) == 1:
-                        try:
-                            stat_from_file._stat[default_filter][key][
-                                key_fields[0]
-                            ] = int(fields[1])
-                        except ValueError:
-                            stat_from_file._stat[default_filter][key][
-                                key_fields[0]
-                            ] = float(fields[1])
-                    else:
-                        raise fileio.ParseError(
-                            "{} is not a valid stats file: {} section implies 1 identifier".format(
-                                file_handle.name, key
-                            )
-                        )
-                elif key == "chrom_freq":
-                    # assert remaining key_fields == [chr1, chr2]:
-                    if len(key_fields) == 2:
-                        stat_from_file._stat[default_filter][key][
-                            tuple(key_fields)
-                        ] = int(fields[1])
-                    else:
-                        raise fileio.ParseError(
-                            "{} is not a valid stats file: {} section implies 2 identifiers".format(
-                                file_handle.name, key
-                            )
-                        )
-                elif key == "dist_freq":
-                    # assert that last element of key_fields is the 'directions'
-                    if len(key_fields) == 2:
-                        # assert 'dirs' in ['++','--','+-','-+']
-                        dirs = key_fields.pop()
-                        # there is only genomic distance range of the bin that's left:
-                        (bin_range,) = key_fields
-                        # extract left border of the bin "1000000+" or "1500-6000":
-                        dist_bin_left = int(
-                            bin_range.strip("+")
-                            if bin_range.endswith("+")
-                            else bin_range.split("-")[0]
-                        )
-                        # store corresponding value:
-                        stat_from_file._stat[default_filter][key][dirs][dist_bin_left] = int(
-                            fields[1]
-                        )
-                    else:
-                        raise fileio.ParseError(
-                            "{} is not a valid stats file: {} section implies 2 identifiers".format(
-                                file_handle.name, key
-                            )
-                        )
-                else:
-                    raise fileio.ParseError(
-                        "{} is not a valid stats file: unknown field {} detected".format(
-                            file_handle.name, key
-                        )
-                    )
-        # return PairCounter from a non-empty dict:
+            raw_stat[key_val_pair[0]] = parse_number(key_val_pair[1])
+
+        # TODO: check that raw_stat does not contain any unknown keys
+
+        # Convert the flat dict into a nested dict:
+        stat_from_file._stat[default_filter].update(
+            flat_dict_to_nested(raw_stat, sep=cls._KEY_SEP))
+
+        stat_from_file._stat[default_filter]["chrom_freq"] = nested_dict_to_flat(
+            stat_from_file._stat[default_filter]["chrom_freq"], tuple_keys=True)
+
+        # Extract the left border of each bin, "1000000+" or "1500-6000":
+        bin_to_left_val = lambda bin_label: int(
+            bin_label.rstrip("+") if ("+" in bin_label) else bin_label.split("-")[0])
+
+        stat_from_file._stat[default_filter]["dist_freq"] = {
+            bin_to_left_val(k): v
+            for k, v in stat_from_file._stat[default_filter]["dist_freq"].items()
+        }
+
+        # Reorder dist_freq from {distance: {strands: count}} to {strands: {distance: count}}:
+        stat_from_file._stat[default_filter]["dist_freq"] = swap_levels_nested_dict(
+            stat_from_file._stat[default_filter]["dist_freq"], 0, 1
+        )
+
         return stat_from_file
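+
+    # Illustrative lines of the flat, tab-separated .stats format parsed above
+    # (compound keys are joined with "/"; values are parsed with parse_number):
+    #
+    #   total                   3000
+    #   pair_types/UU           2700
+    #   chrom_freq/chr1/chr2    150
+    #   dist_freq/1-2/++        5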
 
     @classmethod
-    def from_yaml(cls, file_handle):
+    def from_yaml(cls, file_handle, n_dist_bins_decade=N_DIST_BINS_DECADE_DEFAULT):
         """create instance of PairCounter from file
         Parameters
         ----------
@@ -371,6 +600,7 @@ def from_yaml(cls, file_handle):
         # fill in from file - file_handle:
         stat = yaml.safe_load(file_handle)
         stat_from_file = cls(
+            n_dist_bins_decade=n_dist_bins_decade,
             filters={key: val.get("filter_expression", "") for key, val in stat.items()}
         )
@@ -384,6 +614,7 @@ def from_yaml(cls, file_handle):
         stat_from_file._stat = stat
         return stat_from_file
 
+
     def add_pair(
         self,
         chrom1,
@@ -444,24 +675,18 @@ def add_pair(
                 dist_bin = self._dist_bins[
                     np.searchsorted(self._dist_bins, dist, "right") - 1
                 ]
                 self._stat[filter]["dist_freq"][strand1 + strand2][dist_bin] += 1
-                if dist >= 1000:
-                    self._stat[filter]["cis_1kb+"] += 1
-                if dist >= 2000:
-                    self._stat[filter]["cis_2kb+"] += 1
-                if dist >= 4000:
-                    self._stat[filter]["cis_4kb+"] += 1
-                if dist >= 10000:
-                    self._stat[filter]["cis_10kb+"] += 1
-                if dist >= 20000:
-                    self._stat[filter]["cis_20kb+"] += 1
-                if dist >= 40000:
-                    self._stat[filter]["cis_40kb+"] += 1
+
+                for dist_kb in [1, 2, 4, 10, 20, 40]:
+                    if dist >= dist_kb * 1000:
+                        self._stat[filter][f"cis_{dist_kb}kb+"] += 1
+
             else:
                 self._stat[filter]["trans"] += 1
         else:
             self._stat[filter]["total_single_sided_mapped"] += 1
 
+
     def add_pairs_from_dataframe(self, df, unmapped_chrom="!"):
         """Gather statistics for Hi-C pairs in a dataframe and add to the
         PairCounter.
 
@@ -541,17 +766,21 @@ def add_pairs_from_dataframe(self, df, unmapped_chrom="!"):
             self._stat[key]["cis"] += df_cis.shape[0]
             self._stat[key]["trans"] += df_nodups.shape[0] - df_cis.shape[0]
+
+            # Count cis distance frequencies:
             dist = np.abs(df_cis["pos2"].values - df_cis["pos1"].values)
 
             df_cis.loc[:, "bin_idx"] = (
                 np.searchsorted(self._dist_bins, dist, "right") - 1
             )
+
             for (strand1, strand2, bin_id), strand_bin_count in (
                 df_cis[["strand1", "strand2", "bin_idx"]].value_counts().items()
             ):
                 self._stat[key]["dist_freq"][strand1 + strand2][
                     self._dist_bins[bin_id].item()
                 ] += strand_bin_count
+
             self._stat[key]["cis_1kb+"] += int(np.sum(dist >= 1000))
             self._stat[key]["cis_2kb+"] += int(np.sum(dist >= 2000))
             self._stat[key]["cis_4kb+"] += int(np.sum(dist >= 4000))
@@ -592,21 +821,23 @@ def __add__(self, other, filter="no_filter"):
         # use the empty PairCounter to iterate over:
         for k, v in sum_stat._stat[filter].items():
             if k != "chromsizes" and (
-                k not in self._stat[filter] or k not in other._stat[filter]
+                (k not in self._stat[filter]) or (k not in other._stat[filter])
             ):
                 # Skip any missing fields and warn
                 logger.warning(
                     f"{k} not found in at least one of the input stats, skipping"
                 )
                 continue
+
             # not nested fields are summed trivially:
             if isinstance(v, int):
                 sum_stat._stat[filter][k] = (
                     self._stat[filter][k] + other._stat[filter][k]
                 )
+
-            # sum nested dicts/arrays in a context dependet manner:
+            # sum nested dicts/arrays in a context dependent manner:
             else:
-                if k in ["pair_types", "dedup", "summary"]:
+                if k in ["pair_types", "dedup"]:
                     # handy function for summation of a pair of dicts:
                     # https://stackoverflow.com/questions/10461531/merge-and-sum-of-two-dictionaries
                     sum_dicts = lambda dict_x, dict_y: {
@@ -617,6 +848,7 @@ def __add__(self, other, filter="no_filter"):
                     sum_stat._stat[filter][k] = sum_dicts(
                         self._stat[filter][k], other._stat[filter][k]
                     )
+
                 elif k == "chrom_freq":
                     # union list of keys (chr1,chr2) with potential duplicates:
                     union_keys_with_dups = list(self._stat[filter][k].keys()) + list(
                         other._stat[filter][k].keys()
                     )
@@ -631,6 +863,7 @@ def __add__(self, other, filter="no_filter"):
                         sum_stat._stat[filter][k][union_key] = self._stat[filter][
                             k
                         ].get(union_key, 0) + other._stat[filter][k].get(union_key, 0)
+
                 elif k == "dist_freq":
                     for dirs in sum_stat[k]:
                         from functools import reduce
@@ -646,6 +879,7 @@ def reducer(accumulator, element):
                             {},
                         )
                         # sum_stat[k][dirs] = self._stat[filter][k][dirs] + other._stat[filter][k][dirs]
+
                 elif k == "chromsizes":
                     if k in self._stat[filter] and k in other._stat[filter]:
                         if self._stat[filter][k] == other._stat[filter][k]:
@@ -671,6 +905,8 @@ def __add__(self, other, filter="no_filter"):
                         "One or both stats don't have chromsizes recorded"
                     )
chromsizes recorded" ) + sum_stat.calculate_summaries() + return sum_stat # we need this to be able to sum(list_of_PairCounters) @@ -690,13 +926,14 @@ def flatten(self, filter="no_filter"): for k, v in self._stat[filter].items(): if isinstance(v, int): flat_stat[k] = v - # store nested dicts/arrays in a context dependet manner: + # store nested dicts/arrays in a context dependent manner: # nested categories are stored only if they are non-trivial else: if (k == "dist_freq") and v: for i in range(len(self._dist_bins)): for dirs, freqs in v.items(): dist = self._dist_bins[i] + # last bin is treated differently: "100000+" vs "1200-3000": if i < len(self._dist_bins) - 1: dist_next = self._dist_bins[i + 1] @@ -709,17 +946,23 @@ def flatten(self, filter="no_filter"): ).format(k, dist, dirs) else: raise ValueError("There is a mismatch between dist_freq bins in the instance") + # store key,value pair: - flat_stat[formatted_key] = freqs[dist] - elif (k in ["pair_types", "dedup", "chromsizes"]) and v: + try: + flat_stat[formatted_key] = freqs[dist] + except: + # in some previous versions of stats, last bin was not reported, so we need to skip it now: + if (dist not in freqs) and (i == len(self._dist_bins) - 1): + flat_stat[formatted_key] = 0 + else: + raise ValueError(f"Error in {k} {dirs} {dist} {dist_next} {freqs}: source and destination bins do not match") + + elif (k in ["pair_types", "dedup", "chromsizes", 'summary']) and v: # 'pair_types' and 'dedup' are simple dicts inside, # treat them the exact same way: - for k_item, freq in v.items(): - formatted_key = self._KEY_SEP.join(["{}", "{}"]).format( - k, k_item - ) - # store key,value pair: - flat_stat[formatted_key] = freq + flat_stat.update( + {k+self._KEY_SEP+k2 : v2 for k2,v2 in nested_dict_to_flat(v, sep=self._KEY_SEP).items()}) + elif (k == "chrom_freq") and v: for (chrom1, chrom2), freq in v.items(): formatted_key = self._KEY_SEP.join(["{}", "{}", "{}"]).format( @@ -727,11 +970,6 @@ def flatten(self, filter="no_filter"): ) # store key,value pair: flat_stat[formatted_key] = freq - elif (k == "summary") and v: - for key, frac in v.items(): - formatted_key = self._KEY_SEP.join(["{}", "{}"]).format(k, key) - # store key,value pair: - flat_stat[formatted_key] = frac # return flattened dict return flat_stat @@ -742,30 +980,23 @@ def format_yaml(self, filter="no_filter"): from copy import deepcopy - formatted_stat = {key: {} for key in self.filters.keys()} + formatted_stat = {filter_name: {} for filter_name in self.filters.keys()} # Storing statistics for each filter - for key in self.filters.keys(): - for k, v in self._stat[key].items(): - if isinstance(v, int): - formatted_stat[key][k] = v - # store nested dicts/arrays in a context dependet manner: - # nested categories are stored only if they are non-trivial - else: - if (k != "chrom_freq") and v: - # simple dicts inside - # treat them the exact same way: - formatted_stat[key][k] = deepcopy(v) - elif (k == "chrom_freq") and v: - # need to convert tuples of chromosome names to str - freqs = {} - for (chrom1, chrom2), freq in sorted(v.items()): - freqs[ - self._KEY_SEP.join(["{}", "{}"]).format(chrom1, chrom2) - ] = freq - # store key,value pair: - formatted_stat[key][k] = deepcopy(freqs) + for filter_name in self.filters.keys(): + for k, v in self._stat[filter_name].items(): + if (k == "chrom_freq"): + v = {self._KEY_SEP.join(k2):v2 for k2, v2 in v.items()} + if v: + formatted_stat[filter_name][k] = deepcopy(v) # return formatted dict + formatted_stat = nested_dict_to_flat(formatted_stat, 
+        # convert numpy scalars to plain Python values for clean YAML output:
+        for k in formatted_stat:
+            v = formatted_stat[k]
+            if isinstance(v, np.generic):
+                formatted_stat[k] = v.item()
+        formatted_stat = flat_dict_to_nested(formatted_stat)
+
         return formatted_stat
 
     def save(self, outstream, yaml=False, filter="no_filter"):
@@ -800,6 +1031,7 @@ def save(self, outstream, yaml=False, filter="no_filter"):
             for k, v in data.items():
                 outstream.write("{}{}{}\n".format(k, self._SEP, v))
 
+
     def save_bytile_dups(self, outstream):
         """save bytile duplication counts to a tab-delimited text file.
         Parameters
@@ -831,9 +1063,9 @@ def do_merge(output, files_to_merge, **kwargs):
         )
-        # use a factory method to instanciate PairCounter
+        # use a factory method to instantiate PairCounter
        if kwargs.get("yaml", False):
-            stat = PairCounter.from_yaml(f)
+            stat = PairCounter.from_yaml(f, n_dist_bins_decade=kwargs.get('n_dist_bins_decade', PairCounter.N_DIST_BINS_DECADE_DEFAULT))
        else:
-            stat = PairCounter.from_file(f)
+            stat = PairCounter.from_file(f, n_dist_bins_decade=kwargs.get('n_dist_bins_decade', PairCounter.N_DIST_BINS_DECADE_DEFAULT))
        stats.append(stat)
        f.close()
diff --git a/tests/test_stats.py b/tests/test_stats.py
index 0a108e0..8ead642 100644
--- a/tests/test_stats.py
+++ b/tests/test_stats.py
@@ -5,6 +5,9 @@
 import numpy as np
 
 import yaml
+import pytest
+
+
 
 testdir = os.path.dirname(os.path.realpath(__file__))
 
@@ -45,7 +48,7 @@ def test_mock_pairsam():
     for orientation in ("++", "+-", "-+", "--"):
         s = stats["no_filter"]["dist_freq"][orientation]
         for k, val in s.items():
-            if orientation == "++" and k in [1, 2, 32]:
+            if orientation == "++" and k in [1, 2, 42]:
                 assert s[k] == 1
             else:
                 assert s[k] == 0
@@ -131,3 +134,59 @@ def test_merge_stats():
         print(e.output)
         print(sys.exc_info())
         raise e
+
+
+from pairtools.lib.stats import PairCounter
+
+
+@pytest.fixture
+def pair_counter():
+    counter = PairCounter(filters={"f1": "filter1", "f2": "filter2"})
+    counter._dist_bins = np.array([1, 1000, 10000, 100000, 1000000])
+    # Populate the counter with some sample data
+    counter._stat["f1"]["dist_freq"] = {
+        "++": {1: 80, 1000: 80, 10000: 91, 100000: 95},
+        "--": {1: 100, 1000: 100, 10000: 100, 100000: 100},
+        "-+": {1: 100, 1000: 100, 10000: 100, 100000: 100},
+        "+-": {1: 120, 1000: 120, 10000: 109, 100000: 105},
+    }
+
+    counter._stat["f2"]["dist_freq"] = {
+        "++": {1: 200, 1000: 180, 10000: 160, 100000: 140},
+        "--": {1: 220, 1000: 190, 10000: 170, 100000: 150},
+        "-+": {1: 210, 1000: 185, 10000: 165, 100000: 145},
+        "+-": {1: 230, 1000: 195, 10000: 175, 100000: 155},
+    }
+
+    return counter
+
+
+def test_find_dist_freq_convergence_distance(pair_counter):
+    result = pair_counter.find_dist_freq_convergence_distance(0.1)
+
+    assert "f1" in result
+    assert "f2" in result
+
+    f1_result = result["f1"]
+    assert "convergence_dist" in f1_result
+    assert "strands_w_max_convergence_dist" in f1_result
+    assert "convergence_rel_diff_threshold" in f1_result
+    assert "n_cis_pairs_below_convergence_dist" in f1_result
+    assert "n_cis_pairs_below_convergence_dist_all_strands" in f1_result
+    assert "n_cis_pairs_above_convergence_dist_all_strands" in f1_result
+    assert "frac_cis_in_cis_below_convergence_dist" in f1_result
+    assert "frac_cis_in_cis_below_convergence_dist_all_strands" in f1_result
+    assert "frac_cis_in_cis_above_convergence_dist_all_strands" in f1_result
+    assert "frac_total_mapped_in_cis_below_convergence_dist" in f1_result
+    assert "frac_total_mapped_in_cis_below_convergence_dist_all_strands" in f1_result
+    assert "frac_total_mapped_in_cis_above_convergence_dist_all_strands" in f1_result
+
f1_result["convergence_rel_diff_threshold"] == 0.1 + assert f1_result["convergence_dist"] == 10000 + assert f1_result["strands_w_max_convergence_dist"] == "++" + + + # f2_result = result["f2"] + # assert "convergence_dist" in f2_result + # assert "strands_w_max_convergence_dist" in f2_result + # assert "convergence_rel_diff_threshold" in f2_result + # Add more assertions for f2_result as needed \ No newline at end of file