diff --git a/pyproject.toml b/pyproject.toml
index 91a59cf..e6892fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "ModDotPlot"
-version = "0.8.4"
+version = "0.8.5"
 requires-python = ">= 3.7"
 dependencies = [
     "pysam",
@@ -16,7 +16,9 @@ dependencies = [
     "mmh3",
     "tk",
     "setproctitle",
-    "numpy"
+    "numpy",
+    "pillow",
+    "patchworklib"
 ]
 authors = [
     {name = "Alex Sweeten", email = "alex.sweeten@nih.gov"},
diff --git a/src/moddotplot/const.py b/src/moddotplot/const.py
index f283d4a..7c971f2 100644
--- a/src/moddotplot/const.py
+++ b/src/moddotplot/const.py
@@ -1,4 +1,4 @@
-VERSION = "0.8.4"
+VERSION = "0.8.5"
 COLS = [
     "#query_name",
     "query_start",
diff --git a/src/moddotplot/moddotplot.py b/src/moddotplot/moddotplot.py
index dfeb282..505cf56 100644
--- a/src/moddotplot/moddotplot.py
+++ b/src/moddotplot/moddotplot.py
@@ -336,6 +336,12 @@ def get_parser():
         help="Preserve diagonal when handling strings of ambiguous homopolymers (eg. long runs of N's).",
     )
 
+    static_parser.add_argument(
+        "--grid",
+        action="store_true",
+        help="Plot comparative plots in an NxN grid like format.",
+    )
+
     # TODO: Implement static mode logging options
 
     return parser
@@ -932,6 +938,9 @@ def main():
             seq_sparsity = 2 ** (int(math.log2(seq_sparsity - 1)) + 1)
         expectation = round(win / seq_sparsity)
 
+        if args.grid:
+            grid_vals = []
+
         for i in range(len(sequences)):
             larger_seq = sequences[i][1]
 
diff --git a/src/moddotplot/parse_fasta.py b/src/moddotplot/parse_fasta.py
index f271862..bf12e34 100644
--- a/src/moddotplot/parse_fasta.py
+++ b/src/moddotplot/parse_fasta.py
@@ -31,8 +31,8 @@ def generateKmersFromFasta(seq: Sequence[str], k: int, quiet: bool) -> Iterable[
                 suffix="Completed",
                 length=40,
             )
-
-        kmer = seq[i : i + k]
+        # Remove case sensitivity
+        kmer = seq[i : i + k].upper()
         fh = mmh3.hash(kmer)
 
         # Calculate reverse complement hash directly without the need for translation
diff --git a/src/moddotplot/static_plots.py b/src/moddotplot/static_plots.py
index ee4615c..43eeaec 100644
--- a/src/moddotplot/static_plots.py
+++ b/src/moddotplot/static_plots.py
@@ -24,7 +24,11 @@
 )
 import pandas as pd
 import numpy as np
+from PIL import Image
+import patchworklib as pw
 import math
+import os
+
 from moddotplot.const import (
     DIVERGING_PALETTES,
     QUALITATIVE_PALETTES,
@@ -67,6 +71,50 @@ def make_scale(vals: list) -> list:
     return make_m(scaled)
 
 
+def overlap_axis(rotated_plot, filename, prefix):
+    """Overlay a rotated heatmap on an axis-only image and save the result.
+
+    The rotated plot is shrunk by roughly sqrt(2), pasted over the image read
+    from `filename`, cropped, and written to `{prefix}_TRI.png` and
+    `{prefix}_TRI.pdf`. The file at `filename` is deleted afterwards.
+    """
+    scale_factor = math.sqrt(2) + 0.04
+    new_width = int(rotated_plot.width / scale_factor)
+    new_height = int(rotated_plot.height / scale_factor)
+    resized_rotated_plot = rotated_plot.resize((new_width, new_height), Image.LANCZOS)
+
+    # Step 3: Overlay the resized rotated heatmap onto the original axes
+
+    # Open the original heatmap with axes
+    image_with_axes = Image.open(filename)
+
+    # Create a blank image with the same size as the original
+    final_image = Image.new("RGBA", image_with_axes.size)
+
+    # Calculate the position to center the resized rotated image within the original plot area
+    x_offset = (final_image.width - resized_rotated_plot.width) // 2
+    y_offset = (final_image.height - resized_rotated_plot.height) // 2
+    # NOTE(review): hand-tuned pixel nudges; presumably tied to the 9x9in, 600 dpi renders.
+    y_offset += 2400
+    x_offset += 30
+
+    # Paste the original image with axes onto the final image
+    final_image.paste(image_with_axes, (0, 0))
+
+    # Paste the resized rotated plot onto the final image
+    final_image.paste(resized_rotated_plot, (x_offset, y_offset), resized_rotated_plot)
+    width, height = final_image.size
+    cropped_image = final_image.crop((0, int(height / 2.6), width, height))
+
+    # Save or show the final image
+    cropped_image.save(f"{prefix}_TRI.png")
+    cropped_image.save(f"{prefix}_TRI.pdf", "PDF", resolution=100.0)
+
+    # Remove temp files
+    if os.path.exists(filename):
+        os.remove(filename)
+
+
 def get_colors(sdf, ncolors, is_freq, custom_breakpoints):
     assert ncolors > 2 and ncolors < 12
     bot = math.floor(min(sdf["perID_by_events"]))
@@ -316,7 +364,94 @@ def make_tri(sdf, title_name, palette, palette_orientation, colors, breaks, xlim
         + scale_x_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
         + scale_y_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
         + coord_fixed(ratio=1)
-        + facet_grid("r ~ q")
+        + labs(x="", y="", title=title_name)
+    )
+
+    # Adjust x-axis label size
+    p += theme(axis_title_x=element_text())
+
+    return p
+
+
+def make_tri2(sdf, title_name, palette, palette_orientation, colors, breaks, xlim):
+    """Build an axis-only companion plot for the triangle heatmap.
+
+    Tiles are drawn with alpha=0 so only the x-axis line, ticks and labels
+    are visible; the y-axis and the title are blanked.
+    """
+    if not breaks:
+        breaks = True
+    else:
+        breaks = [float(number) for number in breaks]
+    if not xlim:
+        xlim = 0
+    hexcodes = []
+    new_hexcodes = []
+    if palette in DIVERGING_PALETTES:
+        function_name = getattr(diverging, palette)
+        hexcodes = function_name.hex_colors
+        if palette_orientation == "+":
+            palette_orientation = "-"
+        else:
+            palette_orientation = "+"
+    elif palette in QUALITATIVE_PALETTES:
+        function_name = getattr(qualitative, palette)
+        hexcodes = function_name.hex_colors
+    elif palette in SEQUENTIAL_PALETTES:
+        function_name = getattr(sequential, palette)
+        hexcodes = function_name.hex_colors
+    else:
+        function_name = getattr(sequential, "Spectral_11")
+        palette_orientation = "-"
+        hexcodes = function_name.hex_colors
+
+    if palette_orientation == "-":
+        new_hexcodes = hexcodes[::-1]
+    else:
+        new_hexcodes = hexcodes
+    if colors:
+        new_hexcodes = colors
+    max_val = max(sdf["q_en"].max(), sdf["r_en"].max(), xlim)
+    window = max(sdf["q_en"] - sdf["q_st"])
+    if max_val < 100000:
+        x_label = "Genomic Position (Kbp)"
+    elif max_val < 100000000:
+        x_label = "Genomic Position (Mbp)"
+    else:
+        x_label = "Genomic Position (Gbp)"
+    p = (
+        ggplot(sdf)
+        + geom_tile(
+            aes(x="q_st", y="r_st", fill="discrete", height=window, width=window),
+            alpha=0,
+        )
+        + scale_color_discrete(guide=False)
+        + scale_fill_manual(
+            values=new_hexcodes,
+            guide=False,
+        )
+        + theme(
+            legend_position="none",
+            panel_grid_major=element_blank(),
+            panel_grid_minor=element_blank(),
+            plot_background=element_blank(),
+            panel_background=element_blank(),
+            axis_line=element_line(color="black"),  # Adjust axis line size
+            axis_text=element_text(
+                family=["DejaVu Sans"]
+            ),  # Change axis text font and size
+            axis_ticks_major=element_line(),
+            axis_line_x=element_line(),  # Keep the x-axis line
+            axis_line_y=element_blank(),  # Remove the y-axis line
+            axis_ticks_major_x=element_line(),  # Keep x-axis ticks
+            axis_ticks_major_y=element_blank(),  # Remove y-axis ticks
+            axis_text_x=element_text(),  # Keep x-axis text
+            axis_text_y=element_blank(),
+            plot_title=element_blank(),
+        )
+        + scale_x_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
+        + scale_y_continuous(labels=make_scale, limits=[0, max_val], breaks=breaks)
+        + coord_fixed(ratio=1)
         + labs(x="", y="", title=title_name)
     )
 
@@ -376,6 +511,32 @@ def make_hist(sdf, palette, palette_orientation, custom_colors, custom_breakpoin
     return p
 
 
+def create_grid(
+    singles,
+    doubles,
+    directory,
+    name_x,
+    name_y,
+    palette,
+    palette_orientation,
+    no_hist,
+    width,
+    dpi,
+    is_freq,
+    xlim,
+    custom_colors,
+    custom_breakpoints,
+    from_file,
+    is_pairwise,
+    axes_label,
+):
+    """Stub: currently only prints `singles`.
+
+    NOTE(review): presumably the future entry point for `--grid`; unimplemented.
+    """
+    print(singles)
+
+
 def create_plots(
     sdf,
     directory,
@@ -394,7 +555,6 @@ def create_plots(
     is_pairwise,
     axes_labels,
 ):
-    # TODO: Implement xlim
     df = read_df(
         sdf,
         palette,
@@ -477,6 +637,15 @@ def create_plots(
             axes_labels,
             xlim,
         )
+        tri_plot_axis_only = make_tri2(
+            sdf,
+            plot_filename,
+            palette,
+            palette_orientation,
+            custom_colors,
+            axes_labels,
+            xlim,
+        )
         full_plot = make_dot(
             check_st_en_equality(sdf),
             plot_filename,
@@ -486,27 +655,51 @@ def create_plots(
             axes_labels,
             xlim,
         )
-        print(f"Creating plots and saving to {plot_filename}...\n")
+        triplot_no_axis = tri_plot + theme(
+            axis_text_x=element_blank(),
+            axis_text_y=element_blank(),
+            axis_title_x=element_blank(),
+            axis_title_y=element_blank(),
+            axis_line_x=element_blank(),
+            axis_line_y=element_blank(),
+            axis_ticks_major=element_blank(),
+            axis_ticks_minor=element_blank(),
+            panel_background=element_blank(),
+            panel_grid_major=element_blank(),
+            panel_grid_minor=element_blank(),
+            plot_title=element_blank(),
+        )
 
         ggsave(
-            tri_plot,
+            triplot_no_axis,
             width=9,
             height=9,
-            dpi=dpi,
-            format="pdf",
-            filename=f"{plot_filename}_TRI.pdf",
+            dpi=600,
+            format="png",
+            filename=f"{plot_filename}_TRI_NOAXIS.png",
             verbose=False,
         )
         ggsave(
-            tri_plot,
+            tri_plot_axis_only,
             width=9,
             height=9,
-            dpi=dpi,
+            dpi=600,
             format="png",
-            filename=f"{plot_filename}_TRI.png",
+            filename=f"{plot_filename}_AXIS.png",
             verbose=False,
         )
 
+        png_no_axes = Image.open(f"{plot_filename}_TRI_NOAXIS.png")
+        rotated_png = png_no_axes.rotate(315, expand=True)
+
+        rotated_png.save(f"{plot_filename}_ROTATED_TRI_NOAXIS.png")
+        overlap_axis(rotated_png, f"{plot_filename}_AXIS.png", plot_filename)
+
+        if os.path.exists(f"{plot_filename}_ROTATED_TRI_NOAXIS.png"):
+            os.remove(f"{plot_filename}_ROTATED_TRI_NOAXIS.png")
+        if os.path.exists(f"{plot_filename}_TRI_NOAXIS.png"):
+            os.remove(f"{plot_filename}_TRI_NOAXIS.png")
+
         ggsave(
             full_plot,
             width=9,
@@ -525,10 +718,9 @@ def create_plots(
             filename=f"{plot_filename}_FULL.png",
             verbose=False,
         )
-
         if no_hist:
             print(
-                f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png and {plot_filename}_FULL.png saved sucessfully. \n"
+                f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png and {plot_filename}_FULL.pdf saved successfully. \n"
             )
         else:
             ggsave(
@@ -550,5 +742,5 @@ def create_plots(
                 verbose=False,
             )
             print(
-                f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png, {plot_filename}_FULL.png, {plot_filename}_HIST.png and {plot_filename}_HIST.pdf, saved sucessfully. \n"
+                f"{plot_filename}_TRI.png, {plot_filename}_TRI.pdf, {plot_filename}_FULL.png, {plot_filename}_FULL.pdf, {plot_filename}_HIST.png and {plot_filename}_HIST.pdf, saved successfully. \n"
             )