Skip to content

Commit

Permalink
Merge pull request prody#2007 from jamesmkrieger/showProj
Browse files Browse the repository at this point in the history
restore one label and markersize to showProjection
  • Loading branch information
jamesmkrieger authored Nov 30, 2024
2 parents ffedac4 + 1323485 commit 7d4ab54
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions prody/dynamics/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
weights = kwargs.pop('weights', None)
weights = None

markersize = kwargs.pop('markersize', None)

num = projection.shape[0]

use_labels = kwargs.pop('use_labels', True)
Expand All @@ -283,8 +285,14 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
labels = modes.getModel()._labels.tolist()
LOGGER.info('using labels from LDA model')

if labels is not None and len(labels) != num:
raise ValueError('label should have the same length as ensemble')
one_label = False
if labels is not None:
if len(labels) == 1 or np.isscalar(labels):
one_label = True
kwargs['label'] = labels

elif len(labels) != num:
raise ValueError('label should have the same length as ensemble')

c = kwargs.pop('c', 'b')
colors = kwargs.pop('color', c)
Expand All @@ -297,14 +305,14 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
if len(colors) != num:
raise ValueError('length of color must be {0}'.format(num))
elif isinstance(colors, dict):
if labels is None:
if labels is None or one_label:
raise TypeError('color must be a string or a list unless labels are provided')
colors_dict = colors
colors = [colors_dict[label] for label in labels]
else:
raise TypeError('color must be a string or a list or a dict if labels are provided')

if labels is not None and len(colors_dict) == 0:
if labels is not None and not one_label and len(colors_dict) == 0:
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for i, label in enumerate(set(labels)):
colors_dict[label] = cycle_colors[i % len(cycle_colors)]
Expand All @@ -318,6 +326,8 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
show = plt.plot(range(len(projection)), projection.flatten(), *args, **kwargs)
if use_weights:
kwargs['s'] = weights
elif markersize is not None:
kwargs['s'] = markersize
if labels is not None and use_labels:
for label in set(labels):
kwargs['c'] = colors_dict[label]
Expand Down Expand Up @@ -444,6 +454,8 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
kwargs['c'] = color
if weights is not None and use_weights:
kwargs['s'] = weights
elif markersize is not None:
kwargs['s'] = markersize
plot(*(list(projection[indices].T) + args), **kwargs)
else:
kwargs['color'] = color
Expand Down

0 comments on commit 7d4ab54

Please sign in to comment.