Skip to content

Commit

Permalink
Merge pull request #9 from JulioAPeraza/add-plots
Browse files Browse the repository at this point in the history
Add `wordcloud` and `radar` plot
  • Loading branch information
JulioAPeraza authored Aug 28, 2023
2 parents 58b9986 + fdaaeb3 commit c2cae63
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 17 deletions.
6 changes: 5 additions & 1 deletion .readthedocs.yml → .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
# Required
version: 2

build:
os: "ubuntu-22.04"
tools:
python: "3.8"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
fail_on_warning: false

python:
version: 3.8
install:
- requirements: docs/requirements.txt
- method: pip
Expand Down
2 changes: 2 additions & 0 deletions gradec/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"source-vasa2018_desc-null1000_space-fsaverage_density-164k_spinsamples.npz": "e956f", # Updt
"gclda_neurosynth_model.pkl.gz": "bg8ef",
"gclda_neuroquery_model.pkl.gz": "vsm65",
"lda_neurosynth_model.pkl.gz": "3kgfe",
"lda_neuroquery_model.pkl.gz": "wevdn",
"hcp-s1200_gradients.npy": "t95gk",
"principal_gradient.npy": "5th7c",
"neuroquery_counts": "p39mg",
Expand Down
168 changes: 168 additions & 0 deletions gradec/plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,181 @@
"""Plot module for gradec."""
import math
import os

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from neuromaps.datasets import fetch_civet, fetch_fsaverage, fetch_fslr
from surfplot import Plot
from surfplot.utils import threshold
from wordcloud import WordCloud

from gradec.fetcher import _fetch_model
from gradec.utils import get_data_dir


def _get_twfrequencies(dset_nm, model_nm, n_top_terms, data_dir=None):
"""Get top word frequencies from a topic model."""
model_obj = _fetch_model(dset_nm, model_nm, data_dir=data_dir)

topic_word_weights = (
model_obj.p_word_g_topic_.T
if model_nm == "gclda"
else model_obj.distributions_["p_topic_g_word"]
)

n_topics = topic_word_weights.shape[0]
sorted_weights_idxs = np.argsort(-topic_word_weights, axis=1)
frequencies_lst = []
for topic_i in range(n_topics):
frequencies = topic_word_weights[topic_i, sorted_weights_idxs[topic_i, :]][
:n_top_terms
].tolist()
frequencies = [freq / np.max(frequencies) for freq in frequencies]
frequencies = np.round(frequencies, 3).tolist()
frequencies_lst.append(frequencies)

return frequencies_lst


def plot_radar(
corrs,
features,
model_nm,
cmap="YlOrRd",
n_top_terms=3,
fig=None,
ax=None,
out_fig=None,
):
"""Plot radar chart."""
n_rows = min(len(corrs), 10)
angle_zero = 0
fontsize = 36

# Sort features and correlations
corrs = np.array(corrs)
sorted_indices = np.argsort(-corrs)
corrs = corrs[sorted_indices]
features = np.array(features)[sorted_indices]

corrs = corrs[:n_rows]
features = features[:n_rows]
angles = [(angle_zero + (n / float(n_rows) * 2 * np.pi)) for n in range(n_rows)]
if model_nm in ["lda", "gclda"]:
features = ["\n".join(feature[:n_top_terms]).replace(" ", "\n") for feature in features]
else:
features = [feature.replace(" ", "\n") for feature in features]

roundup_corr = math.ceil(corrs.max() * 10) / 10

# Define color scheme
plt.rcParams["text.color"] = "#1f1f1f"
cmap_ = cm.get_cmap(cmap)
norm = plt.Normalize(vmin=corrs.min(), vmax=corrs.max())
colors = cmap_(norm(corrs))

# Plot radar
if fig is None and ax is None:
fig, ax = plt.subplots(figsize=(9, 9), subplot_kw={"projection": "polar"})

ax.set_theta_offset(0)
ax.set_ylim(-0.1, roundup_corr)

ax.bar(angles, corrs, color=colors, alpha=0.9, width=0.52, zorder=10)
ax.vlines(angles, 0, roundup_corr, color="grey", ls=(0, (4, 4)), zorder=11)

ax.set_xticks(angles)
ax.set_xticklabels(features, size=fontsize, zorder=13)

ax.xaxis.grid(False)

step = 0.1 + 1e-09
yticks = np.round(np.arange(0, roundup_corr + step, step), 1)
ax.set_yticklabels([])
ax.set_yticks(yticks)

ax.spines["start"].set_color("none")
ax.spines["polar"].set_color("none")

xticks = ax.xaxis.get_major_ticks()
[xtick.set_pad(90) for xtick in xticks]

sep = 0.06
[
ax.text(
np.pi / 2,
ytick - sep,
f"{ytick}",
ha="center",
size=fontsize - 2,
color="grey",
zorder=12,
)
for ytick in yticks
]

if out_fig is None:
return fig

fig.savefig(out_fig, bbox_inches="tight")
plt.close()


def plot_cloud(
corrs,
features,
dset_nm,
model_nm,
cmap="YlOrRd",
n_top_terms=3,
dpi=100,
fig=None,
ax=None,
out_fig=None,
data_dir=None,
):
"""Plot word cloud."""
features = [feature[:n_top_terms] for feature in features]
frequencies = _get_twfrequencies(dset_nm, model_nm, n_top_terms, data_dir=data_dir)

frequencies_dict = {}
if model_nm in ["lda", "gclda"]:
for corr, features, frequency in zip(corrs, features, frequencies):
for word, freq in zip(features, frequency):
if word not in frequencies_dict:
frequencies_dict[word] = freq * corr
else:
for word, corr in zip(features, corrs):
if word not in frequencies_dict:
frequencies_dict[word] = corr

with_, hight_ = 9, 5
if fig is None and ax is None:
fig, ax = plt.subplots(figsize=(with_, hight_))

wc = WordCloud(
width=with_ * dpi,
height=hight_ * dpi,
background_color="white",
random_state=0,
colormap=cmap,
)
wc.generate_from_frequencies(frequencies=frequencies_dict)
ax.imshow(wc)

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
for spine in ax.spines.values():
spine.set_visible(False)

if out_fig is None:
return fig

fig.savefig(out_fig, bbox_inches="tight", dpi=dpi)
plt.close()


def plot_surf_maps(
lh_grad,
rh_grad,
Expand Down
23 changes: 7 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "gradec"
description = "Meta-analytic gradient decoding"
license = {file = "LICENSE"}
license = { file = "LICENSE" }
authors = [{ name = "Gradec developers" }]
maintainers = [{ name = "Julio A Peraza", email = "[email protected]" }]
readme = "README.md"
Expand All @@ -25,9 +25,10 @@ classifiers = [
dependencies = [
"nimare",
"neuromaps",
"scipy<1.8.0", # temporary fix for netneurotools issue
"scipy<1.8.0", # temporary fix for netneurotools issue
"netneurotools",
"surfplot",
"wordcloud", # for plotting
]
dynamic = ["version"]

Expand All @@ -47,16 +48,9 @@ doc = [
"sphinx_rtd_theme>=0.5.2",
"sphinxcontrib-bibtex",
]
dev = [
"black",
"pre-commit",
"isort",
"flake8-pyproject",
]
dev = ["black", "pre-commit", "isort", "flake8-pyproject"]
# For testing of oldest usable versions of dependencies.
min = [
"nimare==0.1.0",
]
min = ["nimare==0.1.0"]
test = [
"coverage",
"coveralls",
Expand Down Expand Up @@ -87,10 +81,7 @@ version-file = "gradec/_version.py"

[tool.flake8]
max-line-length = 99
exclude = [
"*build/",
"gradec/_version.py",
]
exclude = ["*build/", "gradec/_version.py"]
ignore = ["E203", "E402", "E722", "W503"]
per-file-ignores = """
*/__init__.py:D401
Expand Down Expand Up @@ -121,4 +112,4 @@ exclude = '''
| versioneer.py
| gradec/_version.py
)
'''
'''

0 comments on commit c2cae63

Please sign in to comment.