Skip to content

Commit

Permalink
Prevent arrow crossing?
Browse files Browse the repository at this point in the history
  • Loading branch information
Phlya committed Aug 15, 2024
1 parent 9d49fd2 commit 6ea4e6d
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 31 deletions.
92 changes: 80 additions & 12 deletions adjustText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@
matplot_get_renderer = None


# From https://gist.github.com/kylemcdonald/6132fc1c29fd3767691442ba4bc84018
def intersect(seg1, seg2):
x1, y1, x2, y2 = seg1
x3, y3, x4, y4 = seg2
denom = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
if denom == 0: # parallel
return False
ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / denom
if ua < 0 or ua > 1: # out of range
return False
ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / denom
if ub < 0 or ub > 1: # out of range
return False
return True


def get_renderer(fig):
# If the backend support get_renderer() or renderer, use that.
if hasattr(fig.canvas, "get_renderer"):
Expand Down Expand Up @@ -231,6 +247,39 @@ def force_into_bbox(coords, bbox):
return apply_shifts(coords, -dx, -dy)


def random_shifts(coords, max_move, only_move="xy"):
mids = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
]
)
if max_move is None:
max_move = 1
unq, count = np.unique(mids, axis=0, return_counts=True)
repeated_groups = unq[count > 1]

for repeated_group in repeated_groups:
repeated_idx = np.argwhere(np.all(mids == repeated_group, axis=1)).flatten()
logger.debug(f"Repeating group: {repeated_group}, idx: {repeated_idx}")
for idx in repeated_idx:
shifts = (np.random.rand(2) - 0.5) * 2 * max_move
if "x" not in only_move:
shifts[0] = 0
elif "x+" in only_move:
shifts[0] = np.abs(shifts[0])
elif "x-" in only_move:
shifts[0] = -np.abs(shifts[0])
if "y" not in only_move:
shifts[1] = 0
elif "y+" in only_move:
shifts[1] = np.abs(shifts[1])
elif "y-" in only_move:
shifts[1] = -np.abs(shifts[1])
coords[idx] += np.asarray([shifts[0], shifts[0], shifts[1], shifts[1]])
return coords


def pull_back(coords, targets):
dx = np.max(np.subtract(targets[:, 0][:, np.newaxis], coords[:, :2]), axis=1)
dy = np.max(np.subtract(targets[:, 1][:, np.newaxis], coords[:, 2:]), axis=1)
Expand Down Expand Up @@ -274,6 +323,7 @@ def iterate(
bbox_to_contain=False,
only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
):
coords = random_shifts(coords, max_move, only_move.get("explode", "xy"))
text_shifts_x, text_shifts_y = get_shifts_texts(
expand_coords(coords, expand[0], expand[1])
)
Expand Down Expand Up @@ -393,7 +443,7 @@ def adjust_text(
force_explode: tuple[float, float] | float = (0.1, 0.5),
pull_threshold: float = 10,
expand: tuple[float, float] = (1.05, 1.2),
max_move: tuple[int, int] | int | None = (20, 20),
max_move: tuple[int, int] | int | None = (5, 5),
explode_radius: str | float = "auto",
ensure_inside_axes: bool = True,
expand_axes: bool = False,
Expand Down Expand Up @@ -479,6 +529,9 @@ def adjust_text(
a dict to restrict movement of texts to only certain axes for certain
types of overlaps.
Valid keys are 'text', 'static', 'explode' and 'pull'.
'explode' is the initial explosion of texts to avoid overlaps, and this value is
also used for random shifts of perfectly overlapping texts to ensure they don't
stay in the same place.
Can contain 'x', 'y', 'x+', 'x-', 'y+', 'y-', or combinations of one 'x?' and
one 'y?'. 'x' and 'y' mean that the text can move in that direction, 'x+' and
'x-' mean that the text can move in the positive or negative direction along
Expand Down Expand Up @@ -553,6 +606,15 @@ def adjust_text(
"explode": only_move,
"pull": only_move,
}
elif isinstance(only_move, dict):
if "text" not in only_move:
only_move["text"] = "xy"
if "static" not in only_move:
only_move["static"] = "xy"
if "explode" not in only_move:
only_move["explode"] = "xy"
if "pull" not in only_move:
only_move["pull"] = "xy"

# coords += np.random.rand(*coords.shape)*1e-6
if x is not None and y is not None:
Expand Down Expand Up @@ -630,7 +692,7 @@ def adjust_text(
else:
ax_bbox = False

i = 0
step = 0
while error > 0:
# expand = expands[min(i, expand_steps-1)]
coords, error = iterate(
Expand All @@ -646,8 +708,22 @@ def adjust_text(
bbox_to_contain=ax_bbox,
only_move=only_move,
)

i += 1
connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
)
for i, seg1 in enumerate(connections):
for j, seg2 in enumerate(connections):
if i >= j:
continue
inter = intersect(seg1, seg2)
if inter:
coords[i], coords[j] = coords[j].copy(), coords[i].copy()

step += 1
if time_lim is not None and timer() - start_time > time_lim:
break
if iter_lim is not None and i == iter_lim:
Expand All @@ -667,14 +743,6 @@ def adjust_text(
)
display_dists = np.max(np.vstack([xdists, ydists]), axis=0)

connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
) # For the future to move into the loop and resolve crossing connections

transformed_connections = np.empty_like(connections)
transformed_connections[:, :2] = transform.inverted().transform(connections[:, :2])
transformed_connections[:, 2:] = transform.inverted().transform(connections[:, 2:])
Expand Down
Loading

0 comments on commit 6ea4e6d

Please sign in to comment.