Skip to content

Commit

Permalink
Fix objects!
Browse files Browse the repository at this point in the history
  • Loading branch information
Phlya committed Mar 18, 2024
1 parent 5d2f7c9 commit f55b83c
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions adjustText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,7 @@ def get_bboxes(objs, r=None, expand=(1, 1), ax=None, transform=None):
return get_bboxes_pathcollection(objs, ax)


def get_2d_coordinates(objs):
try:
ax = objs[0].axes
except:
ax = objs.axes
def get_2d_coordinates(objs, ax):
bboxes = get_bboxes(objs, get_renderer(ax.get_figure()), (1.0, 1.0), ax)
xs = [
(ax.convert_xunits(bbox.xmin), ax.convert_yunits(bbox.xmax)) for bbox in bboxes
Expand Down Expand Up @@ -173,10 +169,16 @@ def get_shifts_extra(coords, extra_coords):
if len(overlaps) == 0:
return np.zeros((coords.shape[0])), np.zeros((coords.shape[0]))

diff = coords[overlaps[:, 0]] - extra_coords[overlaps[:, 1]]
diff_x = coords[overlaps[:, 0], :2] - extra_coords[overlaps[:, 1], -3::-1]
diff_y = coords[overlaps[:, 0], 2:] - extra_coords[overlaps[:, 1], -1:-3:-1]

xshifts = np.where(
np.abs(diff_x[:, 0]) < np.abs(diff_x[:, 1]), diff_x[:, 0], diff_x[:, 1]
)
yshifts = np.where(
np.abs(diff_y[:, 0]) < np.abs(diff_y[:, 1]), diff_y[:, 0], diff_y[:, 1]
)

xshifts = np.where(np.abs(diff[:, 0]) < np.abs(diff[:, 1]), diff[:, 0], diff[:, 1])
yshifts = np.where(np.abs(diff[:, 2]) < np.abs(diff[:, 3]), diff[:, 2], diff[:, 3])
xshifts = np.bincount(overlaps[:, 0], xshifts, minlength=N)
yshifts = np.bincount(overlaps[:, 0], yshifts, minlength=N)
return xshifts, yshifts
Expand Down Expand Up @@ -495,13 +497,13 @@ def adjust_text(
elif time_lim is not None and iter_lim is not None:
logging.warn("Both time_lim and iter_lim are set, faster will be used")
start_time = timer()
coords = get_2d_coordinates(texts)
coords = get_2d_coordinates(texts, ax)

if expand_axes:
expand_axes_to_fit(coords, ax, transform)
force_draw(ax)
transform = texts[0].get_transform()
coords = get_2d_coordinates(texts)
coords = get_2d_coordinates(texts, ax)

original_coords = [text.get_unitless_position() for text in texts]
original_coords_disp_coord = transform.transform(original_coords)
Expand Down Expand Up @@ -531,7 +533,9 @@ def adjust_text(
if objects is None:
obj_coords = np.empty((0, 4))
else:
obj_coords = get_2d_coordinates(objects)
obj_coords = get_2d_coordinates(objects, ax)
obj_coords[:, [0, 2]] = transform.transform(obj_coords[:, [0, 2]])
obj_coords[:, [1, 3]] = transform.transform(obj_coords[:, [1, 3]])
static_coords = np.vstack([point_coords[:, [0, 0, 1, 1]], obj_coords])
if explode_radius == "auto":
explode_radius = max(
Expand Down

0 comments on commit f55b83c

Please sign in to comment.