diff --git a/adjustText/__init__.py b/adjustText/__init__.py index 7c18a4d..b282ada 100644 --- a/adjustText/__init__.py +++ b/adjustText/__init__.py @@ -122,11 +122,7 @@ def get_bboxes(objs, r=None, expand=(1, 1), ax=None, transform=None): return get_bboxes_pathcollection(objs, ax) -def get_2d_coordinates(objs): - try: - ax = objs[0].axes - except: - ax = objs.axes +def get_2d_coordinates(objs, ax): bboxes = get_bboxes(objs, get_renderer(ax.get_figure()), (1.0, 1.0), ax) xs = [ (ax.convert_xunits(bbox.xmin), ax.convert_yunits(bbox.xmax)) for bbox in bboxes @@ -173,10 +169,16 @@ def get_shifts_extra(coords, extra_coords): if len(overlaps) == 0: return np.zeros((coords.shape[0])), np.zeros((coords.shape[0])) - diff = coords[overlaps[:, 0]] - extra_coords[overlaps[:, 1]] + diff_x = coords[overlaps[:, 0], :2] - extra_coords[overlaps[:, 1], -3::-1] + diff_y = coords[overlaps[:, 0], 2:] - extra_coords[overlaps[:, 1], -1:-3:-1] + + xshifts = np.where( + np.abs(diff_x[:, 0]) < np.abs(diff_x[:, 1]), diff_x[:, 0], diff_x[:, 1] + ) + yshifts = np.where( + np.abs(diff_y[:, 0]) < np.abs(diff_y[:, 1]), diff_y[:, 0], diff_y[:, 1] + ) - xshifts = np.where(np.abs(diff[:, 0]) < np.abs(diff[:, 1]), diff[:, 0], diff[:, 1]) - yshifts = np.where(np.abs(diff[:, 2]) < np.abs(diff[:, 3]), diff[:, 2], diff[:, 3]) xshifts = np.bincount(overlaps[:, 0], xshifts, minlength=N) yshifts = np.bincount(overlaps[:, 0], yshifts, minlength=N) return xshifts, yshifts @@ -495,13 +497,13 @@ def adjust_text( elif time_lim is not None and iter_lim is not None: logging.warn("Both time_lim and iter_lim are set, faster will be used") start_time = timer() - coords = get_2d_coordinates(texts) + coords = get_2d_coordinates(texts, ax) if expand_axes: expand_axes_to_fit(coords, ax, transform) force_draw(ax) transform = texts[0].get_transform() - coords = get_2d_coordinates(texts) + coords = get_2d_coordinates(texts, ax) original_coords = [text.get_unitless_position() for text in texts] original_coords_disp_coord = transform.transform(original_coords) @@ -531,7 +533,9 @@ def adjust_text( if objects is None: obj_coords = np.empty((0, 4)) else: - obj_coords = get_2d_coordinates(objects) + obj_coords = get_2d_coordinates(objects, ax) + obj_coords[:, [0, 2]] = transform.transform(obj_coords[:, [0, 2]]) + obj_coords[:, [1, 3]] = transform.transform(obj_coords[:, [1, 3]]) static_coords = np.vstack([point_coords[:, [0, 0, 1, 1]], obj_coords]) if explode_radius == "auto": explode_radius = max(