Skip to content

Commit

Permalink
feat: add foraging eff
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Aug 31, 2024
1 parent 93381f3 commit d2f7a53
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions code/pages/4_RL model playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aind_dynamic_foraging_models import generative_model
from aind_dynamic_foraging_models.generative_model import ForagerCollection
from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols
from aind_dynamic_foraging_basic_analysis import compute_foraging_efficiency

try:
st.set_page_config(layout="wide",
Expand Down Expand Up @@ -251,7 +252,14 @@ def app():
# -- Run the model --
forager.perform(task)

if_plot_latent = st.checkbox("Plot latent variables", value=False)
# Evaluate the foraging efficiency
foraging_eff, foraging_eff_random_seed = compute_foraging_efficiency(
baited=task.reward_baiting,
choice_history=forager.get_choice_history(),
reward_history=forager.get_reward_history(),
p_reward=forager.get_p_reward(),
random_number=task.random_numbers.T,
)

# Capture the results
# ground_truth_params = forager.params.model_dump()
Expand All @@ -262,8 +270,11 @@ def app():
# reward_history = forager.get_reward_history()

# Plot the session results
if_plot_latent = st.checkbox("Plot latent variables", value=False)
fig, axes = forager.plot_session(if_plot_latent=if_plot_latent)
with st.columns([1, 0.5])[0]:

col0 = st.columns([1, 0.5])
with col0[0]:
st.pyplot(fig)

# Plot block logic
Expand All @@ -274,5 +285,9 @@ def app():
ax[0].legend()
fig.suptitle("Reward schedule")
st.pyplot(fig)

with col0[1]:
st.write(f"#### **Foraging efficiency**:")
st.write(f"# {foraging_eff_random_seed:.3f}")

app()

0 comments on commit d2f7a53

Please sign in to comment.