Skip to content

Commit

Permalink
add argument for preventing crossings
Browse files Browse the repository at this point in the history
  • Loading branch information
Phlya committed Aug 16, 2024
1 parent 6ea4e6d commit 8e13a22
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions adjustText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
matplot_get_renderer = None


# From https://gist.github.com/kylemcdonald/6132fc1c29fd3767691442ba4bc84018
# Modified from https://gist.github.com/kylemcdonald/6132fc1c29fd3767691442ba4bc84018
def intersect(seg1, seg2):
x1, y1, x2, y2 = seg1
x3, y3, x4, y4 = seg2
Expand Down Expand Up @@ -437,6 +437,7 @@ def adjust_text(
target_x=None,
target_y=None,
avoid_self=True,
prevent_crossings=True,
force_text: tuple[float, float] | float = (0.1, 0.2),
force_static: tuple[float, float] | float = (0.1, 0.2),
force_pull: tuple[float, float] | float = (0.01, 0.01),
Expand Down Expand Up @@ -498,6 +499,8 @@ def adjust_text(
Should be the same length as texts and in the same order, or None.
avoid_self : bool, default True
whether to repel texts from its original positions.
prevent_crossings : bool, default True
whether to prevent arrows from crossing each other [NEW, EXPERIMENTAL]
force_text : tuple[float, float] | float, default (0.1, 0.2)
the repel force from texts is multiplied by this value
force_static : tuple[float, float] | float, default (0.1, 0.2)
Expand Down Expand Up @@ -708,28 +711,30 @@ def adjust_text(
bbox_to_contain=ax_bbox,
only_move=only_move,
)
connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
)
for i, seg1 in enumerate(connections):
for j, seg2 in enumerate(connections):
if i >= j:
continue
inter = intersect(seg1, seg2)
if inter:
coords[i], coords[j] = coords[j].copy(), coords[i].copy()
if prevent_crossings:
connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
)
for i, seg1 in enumerate(connections):
for j, seg2 in enumerate(connections):
if i >= j:
continue
inter = intersect(seg1, seg2)
if inter:
logger.debug(f"Removing crossing at step {step}: {i} and {j}")
coords[i], coords[j] = coords[j].copy(), coords[i].copy()

step += 1
if time_lim is not None and timer() - start_time > time_lim:
break
if iter_lim is not None and i == iter_lim:
break

logger.debug(f"Adjustment took {i} iterations")
logger.debug(f"Adjustment took {step} iterations")
logger.debug(f"Time: {timer() - start_time}")
logger.debug(f"Error: {error}")

Expand All @@ -742,7 +747,13 @@ def adjust_text(
axis=1,
)
display_dists = np.max(np.vstack([xdists, ydists]), axis=0)

connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
)
transformed_connections = np.empty_like(connections)
transformed_connections[:, :2] = transform.inverted().transform(connections[:, :2])
transformed_connections[:, 2:] = transform.inverted().transform(connections[:, 2:])
Expand Down

0 comments on commit 8e13a22

Please sign in to comment.