Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Phlya committed Sep 13, 2024
1 parent faeb633 commit dacbf45
Showing 1 changed file with 35 additions and 26 deletions.
61 changes: 35 additions & 26 deletions adjustText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,23 +247,24 @@ def force_into_bbox(coords, bbox):
return apply_shifts(coords, -dx, -dy)


def random_shifts(coords, max_move, only_move="xy"):
def random_shifts(coords, only_move="xy"):
# logger.debug(f"Random shifts with max_move: {max_move}")
mids = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
]
)
if max_move is None:
max_move = 1
# if max_move is None:
# max_move = 1
unq, count = np.unique(mids, axis=0, return_counts=True)
repeated_groups = unq[count > 1]

for repeated_group in repeated_groups:
repeated_idx = np.argwhere(np.all(mids == repeated_group, axis=1)).flatten()
logger.debug(f"Repeating group: {repeated_group}, idx: {repeated_idx}")
for idx in repeated_idx:
shifts = (np.random.rand(2) - 0.5) * 2 * max_move
shifts = (np.random.rand(2) - 0.5) * 2
if "x" not in only_move:
shifts[0] = 0
elif "x+" in only_move:
Expand All @@ -276,6 +277,7 @@ def random_shifts(coords, max_move, only_move="xy"):
shifts[1] = np.abs(shifts[1])
elif "y-" in only_move:
shifts[1] = -np.abs(shifts[1])
print(idx, shifts)
coords[idx] += np.asarray([shifts[0], shifts[0], shifts[1], shifts[1]])
return coords

Expand Down Expand Up @@ -323,7 +325,7 @@ def iterate(
bbox_to_contain=False,
only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
):
coords = random_shifts(coords, max_move, only_move.get("explode", "xy"))
coords = random_shifts(coords, only_move.get("explode", "xy"))
text_shifts_x, text_shifts_y = get_shifts_texts(
expand_coords(coords, expand[0], expand[1])
)
Expand Down Expand Up @@ -429,6 +431,26 @@ def warn_once(msg: str):
logger.warning(msg)


def remove_crossings(coords, target_coords, step):
connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_coords,
]
)
for i, seg1 in enumerate(connections):
for j, seg2 in enumerate(connections):
if i <= j:
continue
inter = intersect(seg1, seg2)
if inter:
logger.debug(f"Removing crossing at step {step}: {i} and {j}")
logger.debug(f"Segments: {seg1} and {seg2}")
coords[i], coords[j] = coords[j].copy(), coords[i].copy()
return coords


def adjust_text(
texts,
x=None,
Expand Down Expand Up @@ -635,21 +657,21 @@ def adjust_text(
obj_coords[:, [1, 3]] = transform.transform(obj_coords[:, [1, 3]])
static_coords = np.vstack([point_coords[:, [0, 0, 1, 1]], obj_coords])

if isinstance(max_move, int):
if isinstance(max_move, float) or isinstance(max_move, int):
max_move = (max_move, max_move)
elif max_move is None:
max_move = (np.inf, np.inf)

if isinstance(force_explode, float):
if isinstance(force_explode, float) or isinstance(force_explode, int):
force_explode = (force_explode, force_explode)

if isinstance(force_text, float):
if isinstance(force_text, float) or isinstance(force_text, int):
force_text = (force_text, force_text)

if isinstance(force_static, float):
if isinstance(force_static, float) or isinstance(force_static, int):
force_static = (force_static, force_static)

if isinstance(force_pull, float):
if isinstance(force_pull, float) or isinstance(force_pull, int):
force_pull = (force_pull, force_pull)

if explode_radius == "auto":
Expand Down Expand Up @@ -698,6 +720,7 @@ def adjust_text(
step = 0
while error > 0:
# expand = expands[min(i, expand_steps-1)]
logger.debug(step)
coords, error = iterate(
coords,
target_xy_disp_coord,
Expand All @@ -712,26 +735,12 @@ def adjust_text(
only_move=only_move,
)
if prevent_crossings:
connections = np.hstack(
[
np.mean(coords[:, :2], axis=1)[:, np.newaxis],
np.mean(coords[:, 2:], axis=1)[:, np.newaxis],
target_xy_disp_coord,
]
)
for i, seg1 in enumerate(connections):
for j, seg2 in enumerate(connections):
if i >= j:
continue
inter = intersect(seg1, seg2)
if inter:
logger.debug(f"Removing crossing at step {step}: {i} and {j}")
coords[i], coords[j] = coords[j].copy(), coords[i].copy()
coords = remove_crossings(coords, target_xy_disp_coord, step)

step += 1
if time_lim is not None and timer() - start_time > time_lim:
break
if iter_lim is not None and i == iter_lim:
if iter_lim is not None and step == iter_lim:
break

logger.debug(f"Adjustment took {step} iterations")
Expand Down

0 comments on commit dacbf45

Please sign in to comment.