Skip to content

Commit

Permalink
Linting plot
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Sep 10, 2024
1 parent 476229a commit 618dfe6
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions afqinsight/plot.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Plot bundle profiles."""

from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from collections import OrderedDict
from groupyr.transform import GroupExtractor
from tqdm.auto import tqdm

from .utils import BUNDLE_MAT_2_PYTHON
from .datasets import AFQDataset

from .utils import BUNDLE_MAT_2_PYTHON

POSITIONS = OrderedDict(
{
Expand Down Expand Up @@ -73,7 +72,8 @@ def plot_tract_profiles(
Parameters
----------
X : numpy.ndarray or AFQDataset class instance
If array, this is a matrix of tractometry features with shape (n_subjects, n_features).
If array, this is a matrix of tractometry features
with shape (n_subjects, n_features).
groups : list of numpy.ndarray, optional
feature indices for each feature group of ``X``.
Expand Down Expand Up @@ -149,7 +149,9 @@ def plot_tract_profiles(
if isinstance(X, AFQDataset):
if groups is not None or group_names is not None:
raise ValueError(
"You provided an AFQDataset class instance as `X` input and also a `groups` or `group_names` input, but these are mutually exclusive."
"You provided an AFQDataset class instance as `X` input and "
"also a `groups` or `group_names` input, but these are "
"mutually exclusive."
)
# Allocate the variables needed below based on the input dataset:
group_names = X.group_names
Expand All @@ -159,19 +161,23 @@ def plot_tract_profiles(
else:
if groups is None or group_names is None:
raise ValueError(
"You provided an array input as `X` but did not provide both a `groups` and a `group_names` input. You must provide both of these for array input. "
"You provided an array input as `X` but did not provide both "
"a `groups` and a `group_names` input. You must provide both "
"of these for array input."
)

plt_positions = subplot_positions if subplot_positions is not None else POSITIONS

if bins is not None and quantiles is not None:
raise ValueError(
"You specified both bins and quantiles. These parameters are mutually exclusive."
"You specified both bins and quantiles. These parameters are "
"mutually exclusive."
)

if (bins is not None or quantiles is not None) and group_by_name is None:
raise ValueError(
"You must supply a group_by_name when binning using either the bins or quantiles parameter."
"You must supply a group_by_name when binning using either the "
"bins or quantiles parameter."
)

if group_by is None:
Expand Down Expand Up @@ -204,7 +210,7 @@ def plot_tract_profiles(
X_select = GroupExtractor(
select=tid, groups=groups_metric, group_names=group_names_metric
).fit_transform(X_metric)
columns = [idx for idx in range(X_select.shape[1])]
columns = list(range(X_select.shape[1]))
df = pd.concat(
[
pd.DataFrame(X_select, columns=columns, dtype=np.float64),
Expand Down Expand Up @@ -240,8 +246,10 @@ def plot_tract_profiles(
if nrows is None:
nrows = 5
cc_bundles = ["PostParietal", "SupFrontal", "SupParietal", "Temporal"]
if any([tid in tract_stats.keys() for tid in cc_bundles]):
nrows = 6
for tid in cc_bundles:
if tid in tract_stats.keys():
nrows = 6
break

ncols = ncols if ncols is not None else 4

Expand Down Expand Up @@ -303,12 +311,12 @@ def plot_tract_profiles(
for b in zip(_bins[:-1], _bins[1:])
]
if group_by_name is not None:
figlegend_kwargs = dict(
facecolor="whitesmoke",
bbox_to_anchor=(0.5, 0.02),
loc="upper center",
ncol=6,
)
figlegend_kwargs = {
"facecolor": "whitesmoke",
"bbox_to_anchor": (0.5, 0.02),
"loc": "upper center",
"ncol": 6,
}

if legend_kwargs is not None:
figlegend_kwargs.update(legend_kwargs)
Expand All @@ -324,7 +332,7 @@ def plot_tract_profiles(
_ = legobj.set_linewidth(3.0)

if fig_tight_layout_kws is None:
fig_tight_layout_kws = dict(h_pad=0.5, w_pad=-0.5)
fig_tight_layout_kws = {"h_pad": 0.5, "w_pad": -0.5}

fig.tight_layout(**fig_tight_layout_kws)

Expand Down

0 comments on commit 618dfe6

Please sign in to comment.