Skip to content

Commit

Permalink
add male/female S! fig
Browse files Browse the repository at this point in the history
  • Loading branch information
mschart authored Oct 17, 2024
1 parent 4aa3516 commit c852637
Showing 1 changed file with 110 additions and 14 deletions.
124 changes: 110 additions & 14 deletions brainwidemap/meta/bwm_figs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import matplotlib.ticker as tck
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

from brainwidemap import download_aggregate_tables, bwm_units
from brainwidemap.encoding.design import generate_design
from brainwidemap.encoding.glm_predict import GLMPredictor, predict
from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids
from brainwidemap.encoding.utils import (load_regressors,
single_cluster_raster, find_trial_ids)
from brainbox.plot import peri_event_time_histogram
from reproducible_ephys_functions import labs

Expand Down Expand Up @@ -1759,6 +1761,9 @@ def perf_scatter(rerun=False):
#####
'''

dec_d = {'stim': 'stimside', 'choice': 'choice',
'fback': 'feedback'}

def group_into_regions():

'''
Expand Down Expand Up @@ -1811,9 +1816,6 @@ def significance_by_region(group):
result['sig_combined'] = result['pval_combined'] < ALPHA_LEVEL
return result

dec_d = {'stim': 'stimside', 'choice': 'choice',
'fback': 'feedback'}

# indicate in file name constraint
exx = '' if MIN_TRIALS == 0 else ('_' + str(MIN_TRIALS))

Expand Down Expand Up @@ -2218,8 +2220,7 @@ def swansons_SI_dec(vari):

fig, ax = plt.subplots(figsize=(5.67, 3.18))

dec_d = {'stim': 'stimside', 'choice': 'choice',
'fback': 'feedback'}


pqt_file = os.path.join(dec_pth,f"{dec_d[vari]}_stage2.pqt")
df1 = pd.read_parquet(pqt_file)
Expand Down Expand Up @@ -2283,6 +2284,103 @@ def swansons_SI_dec(vari):



def plot_female_male_repro():

'''
For the 5 repro regions, and all variables (stimulus, choice, feedback),
plot results split by female and male mice with two strips of dots
(blue for male, red for female) for each region.
'''
repro_regs = ["VISa/am", "CA1", "DG", "LP", "PO"]


# subject = one.get_details(eid)['subject']
# info = one.alyx.rest('subjects', 'read', id=subject)
# sex = info['sex'] # 'M', 'F', or 'U'

# Load the male/female subject dictionary
ds = np.load(os.path.join(dec_pth, 'male_female.npy'), allow_pickle=True).flat[0]

# Set up the plot structure (3 rows for each variable, 1 column)
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8.26, 4.37),
sharey=True, sharex=True)

for row_idx, vari in enumerate(variables):
pqt_file = os.path.join(dec_pth, f"{dec_d[vari]}_stage2.pqt")
df = pd.read_parquet(pqt_file)

# Combine 'VISa' and 'VISam' into 'VISa/am'
df['region'] = df['region'].replace({'VISa': 'VISa/am', 'VISam': 'VISa/am'})

# Restrict the DataFrame to the 5 repro regions
df = df[df['region'].isin(repro_regs)]

# Map the sex information to the DataFrame
df['sex'] = df['subject'].map(ds)

# Separate by sex
female_df = df[df['sex'] == 'F']
male_df = df[df['sex'] == 'M']

# Combine male and female data into a new DataFrame for better comparison in the plot
df_combined = pd.concat([female_df.assign(gender='Female'), male_df.assign(gender='Male')])

# Plot in the corresponding row
ax = axes[row_idx]

# Strip plot for both male (blue) and female (red) side by side for each region
sns.stripplot(data=df_combined, x='region', y='score', hue='gender',
palette={'Male': 'blue', 'Female': 'red'},
dodge=True, jitter=True, size=3, ax=ax)

# Add mean null score (white dot) for each gender and region
# Ensure the x-alignment is consistent (no dodge for null scores)
for region in repro_regs:
# Get the x position for this region (for both male and female)
x_female = repro_regs.index(region) - 0.2 # Slight left shift for female (consistent with dodge)
x_male = repro_regs.index(region) + 0.2 # Slight right shift for male (consistent with dodge)

# Plot null score for females (red)
female_null = female_df[female_df['region'] == region]['median-null'].mean()
ax.scatter(x_female, female_null, color='white',
edgecolor='red', zorder=3, s=30)

# Plot null score for males (blue)
male_null = male_df[male_df['region'] == region]['median-null'].mean()
ax.scatter(x_male, male_null, color='white',
edgecolor='blue', zorder=3, s=30)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Set labels and titles for each row
ax.set_xlabel("Brain Regions")
ax.set_ylabel(f"{variverb[vari]} \n decoding accuracy")

# Remove legend from individual plots to keep the layout clean
ax.get_legend().remove()
for i in range(1, len(repro_regs)):
ax.axvline(x=i-0.5, color='grey', linestyle='--', linewidth=1)

legend_handles = [
plt.Line2D([0], [0], marker='o', color='white', label='Male',
markerfacecolor='blue', markeredgecolor='blue', markersize=5),
plt.Line2D([0], [0], marker='o', color='white', label='Female',
markerfacecolor='red', markeredgecolor='red', markersize=5),
plt.Line2D([0], [0], marker='o', color='white', label='Null (Male)',
markerfacecolor='white', markeredgecolor='blue', markersize=5),
plt.Line2D([0], [0], marker='o', color='white', label='Null (Female)',
markerfacecolor='white', markeredgecolor='red', markersize=5)
]

fig.legend(handles=legend_handles, loc='upper center',
frameon=False, ncols=2).set_draggable(True)

# Adjust layout
plt.tight_layout()
fig.savefig(Path(imgs_pth, 'si',
f'n6_supp_dec_repro_male_female.pdf'), dpi=150)


'''
##########
Expand Down Expand Up @@ -2475,9 +2573,9 @@ def get_example_results():
"stimOn_times", # Alignset key
),
"choice": (
"671c7ea7-6726-4fbe-adeb-f89c2c8e489b",
"04c9890f-2276-4c20-854f-305ff5c9b6cf",
143, # was 143
"a7763417-e0d6-4f2a-aa55-e382fd9b5fb8",#"671c7ea7-6726-4fbe-adeb-f89c2c8e489b"
"57c5856a-c7bd-4d0f-87c6-37005b1484aa",#"04c9890f-2276-4c20-854f-305ff5c9b6cf"
74, # was 143
"GRN",
0.000992895, # drsq
"firstMovement_times",
Expand All @@ -2500,7 +2598,7 @@ def get_example_results():
return targetunits, alignsets, sortlookup


def ecoding_raster_lines(variable, clu_id0=None, axs=None,
def encoding_raster_lines(variable, clu_id0=None, axs=None,
frac_tr=3):

'''
Expand Down Expand Up @@ -3028,8 +3126,6 @@ def plot_curves_scatter(variable, ga_pcs=False, curve='euc',
color=palette[reg],
label=f"{reg} {d[reg]['nclus']}")

axs[k].set_xscale('log')

# put region labels
y = yy[-1]
x = xx[-1]
Expand Down Expand Up @@ -3282,12 +3378,12 @@ def ax_str(x):

# encoding panels
if not save_pans:
ecoding_raster_lines(variable,clu_id0= clu_id0,
encoding_raster_lines(variable,clu_id0= clu_id0,
axs=[ax_str(x) for x in
['ras', 'enc0', 'enc1']])

else:
ecoding_raster_lines(variable, clu_id0=clu_id0)
encoding_raster_lines(variable, clu_id0=clu_id0)


'''
Expand Down

0 comments on commit c852637

Please sign in to comment.