diff --git a/brainwidemap/meta/bwm_figs.py b/brainwidemap/meta/bwm_figs.py index b3a8cbe..c1d227f 100644 --- a/brainwidemap/meta/bwm_figs.py +++ b/brainwidemap/meta/bwm_figs.py @@ -30,11 +30,13 @@ from matplotlib.ticker import (MultipleLocator, AutoMinorLocator) import matplotlib.ticker as tck import matplotlib.colors as mcolors +import matplotlib.patches as mpatches from brainwidemap import download_aggregate_tables, bwm_units from brainwidemap.encoding.design import generate_design from brainwidemap.encoding.glm_predict import GLMPredictor, predict -from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids +from brainwidemap.encoding.utils import (load_regressors, + single_cluster_raster, find_trial_ids) from brainbox.plot import peri_event_time_histogram from reproducible_ephys_functions import labs @@ -1759,6 +1761,9 @@ def perf_scatter(rerun=False): ##### ''' +dec_d = {'stim': 'stimside', 'choice': 'choice', + 'fback': 'feedback'} + def group_into_regions(): ''' @@ -1811,9 +1816,6 @@ def significance_by_region(group): result['sig_combined'] = result['pval_combined'] < ALPHA_LEVEL return result - dec_d = {'stim': 'stimside', 'choice': 'choice', - 'fback': 'feedback'} - # indicate in file name constraint exx = '' if MIN_TRIALS == 0 else ('_' + str(MIN_TRIALS)) @@ -2218,8 +2220,7 @@ def swansons_SI_dec(vari): fig, ax = plt.subplots(figsize=(5.67, 3.18)) - dec_d = {'stim': 'stimside', 'choice': 'choice', - 'fback': 'feedback'} + pqt_file = os.path.join(dec_pth,f"{dec_d[vari]}_stage2.pqt") df1 = pd.read_parquet(pqt_file) @@ -2283,6 +2284,103 @@ def swansons_SI_dec(vari): +def plot_female_male_repro(): + + ''' + For the 5 repro regions, and all variables (stimulus, choice, feedback), + plot results split by female and male mice with two strips of dots + (blue for male, red for female) for each region. + ''' + repro_regs = ["VISa/am", "CA1", "DG", "LP", "PO"] + + + # subject = one.get_details(eid)['subject'] + # info = one.alyx.rest('subjects', 'read', id=subject) + # sex = info['sex'] # 'M', 'F', or 'U' + + # Load the male/female subject dictionary + ds = np.load(os.path.join(dec_pth, 'male_female.npy'), allow_pickle=True).flat[0] + + # Set up the plot structure (3 rows for each variable, 1 column) + fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8.26, 4.37), + sharey=True, sharex=True) + + for row_idx, vari in enumerate(variables): + pqt_file = os.path.join(dec_pth, f"{dec_d[vari]}_stage2.pqt") + df = pd.read_parquet(pqt_file) + + # Combine 'VISa' and 'VISam' into 'VISa/am' + df['region'] = df['region'].replace({'VISa': 'VISa/am', 'VISam': 'VISa/am'}) + + # Restrict the DataFrame to the 5 repro regions + df = df[df['region'].isin(repro_regs)] + + # Map the sex information to the DataFrame + df['sex'] = df['subject'].map(ds) + + # Separate by sex + female_df = df[df['sex'] == 'F'] + male_df = df[df['sex'] == 'M'] + + # Combine male and female data into a new DataFrame for better comparison in the plot + df_combined = pd.concat([female_df.assign(gender='Female'), male_df.assign(gender='Male')]) + + # Plot in the corresponding row + ax = axes[row_idx] + + # Strip plot for both male (blue) and female (red) side by side for each region + sns.stripplot(data=df_combined, x='region', y='score', hue='gender', + palette={'Male': 'blue', 'Female': 'red'}, + dodge=True, jitter=True, size=3, ax=ax) + + # Add mean null score (white dot) for each gender and region + # Ensure the x-alignment is consistent (no dodge for null scores) + for region in repro_regs: + # Get the x position for this region (for both male and female) + x_female = repro_regs.index(region) - 0.2 # Slight left shift for female (consistent with dodge) + x_male = repro_regs.index(region) + 0.2 # Slight right shift for male (consistent with dodge) + + # Plot null score for females (red) + female_null = female_df[female_df['region'] == region]['median-null'].mean() + ax.scatter(x_female, female_null, color='white', + edgecolor='red', zorder=3, s=30) + + # Plot null score for males (blue) + male_null = male_df[male_df['region'] == region]['median-null'].mean() + ax.scatter(x_male, male_null, color='white', + edgecolor='blue', zorder=3, s=30) + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Set labels and titles for each row + ax.set_xlabel("Brain Regions") + ax.set_ylabel(f"{variverb[vari]} \n decoding accuracy") + + # Remove legend from individual plots to keep the layout clean + ax.get_legend().remove() + for i in range(1, len(repro_regs)): + ax.axvline(x=i-0.5, color='grey', linestyle='--', linewidth=1) + + legend_handles = [ + plt.Line2D([0], [0], marker='o', color='white', label='Male', + markerfacecolor='blue', markeredgecolor='blue', markersize=5), + plt.Line2D([0], [0], marker='o', color='white', label='Female', + markerfacecolor='red', markeredgecolor='red', markersize=5), + plt.Line2D([0], [0], marker='o', color='white', label='Null (Male)', + markerfacecolor='white', markeredgecolor='blue', markersize=5), + plt.Line2D([0], [0], marker='o', color='white', label='Null (Female)', + markerfacecolor='white', markeredgecolor='red', markersize=5) + ] + + fig.legend(handles=legend_handles, loc='upper center', + frameon=False, ncols=2).set_draggable(True) + + # Adjust layout + plt.tight_layout() + fig.savefig(Path(imgs_pth, 'si', + f'n6_supp_dec_repro_male_female.pdf'), dpi=150) + ''' ########## @@ -2475,9 +2573,9 @@ def get_example_results(): "stimOn_times", # Alignset key ), "choice": ( - "671c7ea7-6726-4fbe-adeb-f89c2c8e489b", - "04c9890f-2276-4c20-854f-305ff5c9b6cf", - 143, # was 143 + "a7763417-e0d6-4f2a-aa55-e382fd9b5fb8",#"671c7ea7-6726-4fbe-adeb-f89c2c8e489b" + "57c5856a-c7bd-4d0f-87c6-37005b1484aa",#"04c9890f-2276-4c20-854f-305ff5c9b6cf" + 74, # was 143 "GRN", 0.000992895, # drsq "firstMovement_times", @@ -2500,7 +2598,7 @@ def get_example_results(): return targetunits, alignsets, sortlookup -def ecoding_raster_lines(variable, clu_id0=None, axs=None, +def encoding_raster_lines(variable, clu_id0=None, axs=None, frac_tr=3): ''' @@ -3028,8 +3126,6 @@ def plot_curves_scatter(variable, ga_pcs=False, curve='euc', color=palette[reg], label=f"{reg} {d[reg]['nclus']}") - axs[k].set_xscale('log') - # put region labels y = yy[-1] x = xx[-1] @@ -3282,12 +3378,12 @@ def ax_str(x): # encoding panels if not save_pans: - ecoding_raster_lines(variable,clu_id0= clu_id0, + encoding_raster_lines(variable,clu_id0= clu_id0, axs=[ax_str(x) for x in ['ras', 'enc0', 'enc1']]) else: - ecoding_raster_lines(variable, clu_id0=clu_id0) + encoding_raster_lines(variable, clu_id0=clu_id0) '''