From dacbf45317cd80851bab2d3f3364c4f6381581bf Mon Sep 17 00:00:00 2001 From: Phlya Date: Fri, 13 Sep 2024 11:30:18 +0200 Subject: [PATCH] small fixes --- adjustText/__init__.py | 61 ++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/adjustText/__init__.py b/adjustText/__init__.py index 10bfaab..33b8196 100644 --- a/adjustText/__init__.py +++ b/adjustText/__init__.py @@ -247,15 +247,16 @@ def force_into_bbox(coords, bbox): return apply_shifts(coords, -dx, -dy) -def random_shifts(coords, max_move, only_move="xy"): +def random_shifts(coords, only_move="xy"): + # logger.debug(f"Random shifts with max_move: {max_move}") mids = np.hstack( [ np.mean(coords[:, :2], axis=1)[:, np.newaxis], np.mean(coords[:, 2:], axis=1)[:, np.newaxis], ] ) - if max_move is None: - max_move = 1 + # if max_move is None: + # max_move = 1 unq, count = np.unique(mids, axis=0, return_counts=True) repeated_groups = unq[count > 1] @@ -263,7 +264,7 @@ def random_shifts(coords, max_move, only_move="xy"): repeated_idx = np.argwhere(np.all(mids == repeated_group, axis=1)).flatten() logger.debug(f"Repeating group: {repeated_group}, idx: {repeated_idx}") for idx in repeated_idx: - shifts = (np.random.rand(2) - 0.5) * 2 * max_move + shifts = (np.random.rand(2) - 0.5) * 2 if "x" not in only_move: shifts[0] = 0 elif "x+" in only_move: @@ -276,6 +277,7 @@ def random_shifts(coords, max_move, only_move="xy"): shifts[1] = np.abs(shifts[1]) elif "y-" in only_move: shifts[1] = -np.abs(shifts[1]) + print(idx, shifts) coords[idx] += np.asarray([shifts[0], shifts[0], shifts[1], shifts[1]]) return coords @@ -323,7 +325,7 @@ def iterate( bbox_to_contain=False, only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"}, ): - coords = random_shifts(coords, max_move, only_move.get("explode", "xy")) + coords = random_shifts(coords, only_move.get("explode", "xy")) text_shifts_x, text_shifts_y = get_shifts_texts( expand_coords(coords, expand[0], expand[1]) ) @@ -429,6 +431,26 @@ def warn_once(msg: str): logger.warning(msg) +def remove_crossings(coords, target_coords, step): + connections = np.hstack( + [ + np.mean(coords[:, :2], axis=1)[:, np.newaxis], + np.mean(coords[:, 2:], axis=1)[:, np.newaxis], + target_coords, + ] + ) + for i, seg1 in enumerate(connections): + for j, seg2 in enumerate(connections): + if i <= j: + continue + inter = intersect(seg1, seg2) + if inter: + logger.debug(f"Removing crossing at step {step}: {i} and {j}") + logger.debug(f"Segments: {seg1} and {seg2}") + coords[i], coords[j] = coords[j].copy(), coords[i].copy() + return coords + + def adjust_text( texts, x=None, @@ -635,21 +657,21 @@ def adjust_text( obj_coords[:, [1, 3]] = transform.transform(obj_coords[:, [1, 3]]) static_coords = np.vstack([point_coords[:, [0, 0, 1, 1]], obj_coords]) - if isinstance(max_move, int): + if isinstance(max_move, float) or isinstance(max_move, int): max_move = (max_move, max_move) elif max_move is None: max_move = (np.inf, np.inf) - if isinstance(force_explode, float): + if isinstance(force_explode, float) or isinstance(force_explode, int): force_explode = (force_explode, force_explode) - if isinstance(force_text, float): + if isinstance(force_text, float) or isinstance(force_text, int): force_text = (force_text, force_text) - if isinstance(force_static, float): + if isinstance(force_static, float) or isinstance(force_static, int): force_static = (force_static, force_static) - if isinstance(force_pull, float): + if isinstance(force_pull, float) or isinstance(force_pull, int): force_pull = (force_pull, force_pull) if explode_radius == "auto": @@ -698,6 +720,7 @@ def adjust_text( step = 0 while error > 0: # expand = expands[min(i, expand_steps-1)] + logger.debug(step) coords, error = iterate( coords, target_xy_disp_coord, @@ -712,26 +735,12 @@ def adjust_text( only_move=only_move, ) if prevent_crossings: - connections = np.hstack( - [ - np.mean(coords[:, :2], axis=1)[:, np.newaxis], - np.mean(coords[:, 2:], axis=1)[:, np.newaxis], - target_xy_disp_coord, - ] - ) - for i, seg1 in enumerate(connections): - for j, seg2 in enumerate(connections): - if i >= j: - continue - inter = intersect(seg1, seg2) - if inter: - logger.debug(f"Removing crossing at step {step}: {i} and {j}") - coords[i], coords[j] = coords[j].copy(), coords[i].copy() + coords = remove_crossings(coords, target_xy_disp_coord, step) step += 1 if time_lim is not None and timer() - start_time > time_lim: break - if iter_lim is not None and i == iter_lim: + if iter_lim is not None and step == iter_lim: break logger.debug(f"Adjustment took {step} iterations")