Commit

Merge branch 'han_simulation_and_fitting' into main
hanhou committed Aug 31, 2024
2 parents 150bc6d + b9d7387 commit a5dc31e
Showing 15 changed files with 296 additions and 363 deletions.
53 changes: 24 additions & 29 deletions code/Home.py
@@ -14,40 +14,35 @@
 
 __ver__ = 'v2.5.3'
 
-import pandas as pd
-import streamlit as st
-import numpy as np
 import os
 
+import extra_streamlit_components as stx
+import numpy as np
+import pandas as pd
+import streamlit as st
 import streamlit_nested_layout
-from streamlit_plotly_events import plotly_events
+from aind_auto_train import __version__ as auto_train_version
+from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
+from aind_auto_train.curriculum_manager import CurriculumManager
 from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm
-import extra_streamlit_components as stx
-
-from util.settings import draw_type_mapper_session_level, draw_type_layout_definition
-
-from util.streamlit import (aggrid_interactive_table_session,
-                            aggrid_interactive_table_curriculum, add_session_filter, data_selector,
-                            add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper,
-                            _plot_population_x_y)
-from util.aws_s3 import (
-    load_data,
-    draw_session_plots_quick_preview,
-    show_session_level_img_by_key_and_prefix,
-    show_debug_info,
-)
-from util.url_query_helper import (
-    sync_URL_to_session_state, sync_session_state_to_URL,
-    slider_wrapper_for_url_query, checkbox_wrapper_for_url_query,
-    multiselect_wrapper_for_url_query, number_input_wrapper_for_url_query,
-)
-
+from streamlit_plotly_events import plotly_events
+from util.aws_s3 import (draw_session_plots_quick_preview, load_data,
+                         show_debug_info,
+                         show_session_level_img_by_key_and_prefix)
 from util.fetch_data_docDB import load_data_from_docDB
-
-from aind_auto_train.curriculum_manager import CurriculumManager
-from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
-from aind_auto_train import __version__ as auto_train_version
-
+from util.settings import (draw_type_layout_definition,
+                           draw_type_mapper_session_level)
+from util.streamlit import (_plot_population_x_y, add_auto_train_manager,
+                            add_dot_property_mapper, add_session_filter,
+                            add_xy_selector, add_xy_setting,
+                            aggrid_interactive_table_curriculum,
+                            aggrid_interactive_table_session, data_selector)
+from util.url_query_helper import (checkbox_wrapper_for_url_query,
+                                   multiselect_wrapper_for_url_query,
+                                   number_input_wrapper_for_url_query,
+                                   slider_wrapper_for_url_query,
+                                   sync_session_state_to_URL,
+                                   sync_URL_to_session_state)
 
 try:
     st.set_page_config(layout="wide",

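Across the hunks in this commit, the import statements themselves are unchanged; only their ordering is: imports end up alphabetized, with plain `import x` statements ahead of `from x import y` statements and the standard library split from the rest by a blank line. That is the behavior of an import sorter such as isort run with its defaults. The sketch below assumes isort was the tool used (the commit itself does not say which tool, if any, produced the reordering).

```python
# Minimal sketch: reproducing this kind of reordering with isort's Python API.
# Assumption: the commit used isort (or an equivalent sorter); this is not stated in the diff.
import isort

messy = (
    "import pandas as pd\n"
    "import streamlit as st\n"
    "import numpy as np\n"
    "import os\n"
)

# isort groups the standard library ahead of third-party packages and
# alphabetizes each group.
print(isort.code(messy))
# Expected output:
# import os
#
# import numpy as np
# import pandas as pd
# import streamlit as st
```
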
11 changes: 5 additions & 6 deletions code/pages/1_Learning trajectory.py
@@ -1,15 +1,14 @@
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
 import s3fs
 import streamlit as st
+from plotly.subplots import make_subplots
 from sklearn.decomposition import PCA
 from sklearn.preprocessing import StandardScaler
-import pandas as pd
-import numpy as np
-import plotly.graph_objects as go
-from plotly.subplots import make_subplots
 from streamlit_plotly_events import plotly_events
-
-from util.streamlit import add_session_filter, data_selector
 from util.aws_s3 import load_data
+from util.streamlit import add_session_filter, data_selector
 
 ss = st.session_state
 

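The `StandardScaler` and `PCA` imports kept in this page point at the usual standardize-then-project workflow for plotting learning trajectories. The snippet below is only a generic sketch of that pattern with made-up data, not the page's actual analysis:

```python
# Generic standardize-then-PCA pattern implied by the imports above
# (illustrative only; the data here are random placeholders).
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 8))        # e.g., one row per session, one column per metric

X_scaled = StandardScaler().fit_transform(X)   # z-score each metric
pca = PCA(n_components=2)
coords = pca.fit_transform(X_scaled)           # 2-D coordinates for plotting
print(pca.explained_variance_ratio_)           # variance captured by each component
```
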
6 changes: 3 additions & 3 deletions code/pages/2_HMM-GLM.py
@@ -4,12 +4,12 @@
 """
 import os
 import re
-import numpy as np
-from PIL import Image
 
-import streamlit as st
+import numpy as np
 import s3fs
+import streamlit as st
 import streamlit_nested_layout
+from PIL import Image
 
 try:
     st.set_page_config(layout="wide",

2 changes: 1 addition & 1 deletion code/pages/3_AIND data access playground.py
@@ -2,9 +2,9 @@
 '''
 
 import logging
+
 import streamlit as st
 from streamlit_dynamic_filters import DynamicFilters
-
 from util.fetch_data_docDB import load_data_from_docDB
 
 try:

12 changes: 6 additions & 6 deletions code/pages/4_RL model playground.py
@@ -1,14 +1,14 @@
 """Playground for RL models of dynamic foraging
 """
 
-import streamlit as st
-import streamlit_nested_layout
-from typing import get_type_hints, _LiteralGenericAlias
 import inspect
+from typing import _LiteralGenericAlias, get_type_hints
 
-from aind_behavior_gym.dynamic_foraging.task import (
-    CoupledBlockTask, UncoupledBlockTask, RandomWalkTask
-)
+import streamlit as st
+import streamlit_nested_layout
+from aind_behavior_gym.dynamic_foraging.task import (CoupledBlockTask,
+                                                     RandomWalkTask,
+                                                     UncoupledBlockTask)
 from aind_dynamic_foraging_models import generative_model
 from aind_dynamic_foraging_models.generative_model import ForagerCollection
 from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols

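The reordered imports in this page keep `inspect`, `get_type_hints`, and the private `typing._LiteralGenericAlias`, which suggests the playground builds its parameter widgets by introspecting type hints on the task and model classes. A hedged sketch of that general pattern follows; `DummyTask` is an invented stand-in, not the aind_behavior_gym API, and the `_LiteralGenericAlias` check requires a Python version where that private class exists (as the page's own import does):

```python
# Hypothetical illustration of introspecting Literal type hints.
# DummyTask is made up; the real page presumably applies a similar idea to its task classes.
from typing import Literal, _LiteralGenericAlias, get_type_hints


class DummyTask:
    def __init__(self, block_structure: Literal["coupled", "uncoupled", "random_walk"] = "coupled"):
        self.block_structure = block_structure


for name, hint in get_type_hints(DummyTask.__init__).items():
    if isinstance(hint, _LiteralGenericAlias):
        # Literal[...] exposes its allowed values via __args__, which a UI layer
        # (e.g., a Streamlit selectbox) can present as dropdown options.
        print(name, "->", hint.__args__)
```
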
10 changes: 5 additions & 5 deletions code/util/aws_s3.py
@@ -1,13 +1,13 @@
-from PIL import Image
 import json
 
-import s3fs
 import pandas as pd
+import s3fs
 import streamlit as st
+from PIL import Image
 
-from .settings import (
-    draw_type_layout_definition, draw_type_mapper_session_level, draw_types_quick_preview
-)
+from .settings import (draw_type_layout_definition,
+                       draw_type_mapper_session_level,
+                       draw_types_quick_preview)
 
 # --------------------------------------
 data_sources = ['bonsai', 'bpod']

5 changes: 4 additions & 1 deletion code/util/fetch_data_docDB.py
@@ -1,15 +1,18 @@
 """Code to fetch data from docDB by David Feng
 """
 
-import pandas as pd
 import logging
+import time
+
+import pandas as pd
+import semver
 import streamlit as st
 
 logger = logging.getLogger(__name__)
 
 from aind_data_access_api.document_db import MetadataDbClient
 
 
 @st.cache_data(ttl=3600*12) # Cache the df_docDB up to 12 hours
 def load_data_from_docDB():
     client = load_client()

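The hunk above also shows the caching strategy for the docDB query: `@st.cache_data(ttl=3600*12)` memoizes the returned data for up to 12 hours, so the metadata database is not re-queried on every Streamlit rerun. A minimal, self-contained sketch of that pattern (the loader body below is a placeholder, not the repository's actual query code):

```python
# Sketch of the @st.cache_data(ttl=...) pattern used above; the body is a placeholder.
import time

import pandas as pd
import streamlit as st


@st.cache_data(ttl=3600 * 12)  # keep the result for up to 12 hours
def load_slow_table() -> pd.DataFrame:
    time.sleep(2)  # stand-in for the expensive docDB query
    return pd.DataFrame({"subject_id": ["123456"], "session": [1]})


df = load_slow_table()   # first call runs the function
df = load_slow_table()   # later calls within the TTL return the cached result
```
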
2 changes: 1 addition & 1 deletion code/util/foraging_plotly.py
@@ -1,6 +1,6 @@
 import numpy as np
-import plotly.graph_objs as go
 import plotly.express as px
+import plotly.graph_objs as go
 
 
 def moving_average(a, n=3) :

7 changes: 3 additions & 4 deletions code/util/plot_autotrain_manager.py
@@ -1,12 +1,11 @@
 from datetime import datetime
 
-import streamlit as st
 import numpy as np
-import plotly.graph_objects as go
 import pandas as pd
-
-from aind_auto_train.schema.curriculum import TrainingStage
+import plotly.graph_objects as go
+import streamlit as st
 from aind_auto_train.plot.curriculum import get_stage_color_mapper
+from aind_auto_train.schema.curriculum import TrainingStage
 
 
 def plot_manager_all_progress(manager: 'AutoTrainManager',

6 changes: 2 additions & 4 deletions code/util/population.py
@@ -1,13 +1,11 @@
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
-import seaborn as sns
 import scipy
-
+import seaborn as sns
 from statannotations.Annotator import Annotator
-
 
 
 def _draw_variable_trial_back(df, beta_name, trials_back, ax=None):
     if ax is None:
         _, ax = plt.subplots(1, 1, figsize=(7, 3), constrained_layout=True)

38 changes: 16 additions & 22 deletions code/util/streamlit.py
@@ -1,36 +1,30 @@
+import json
 from collections import OrderedDict
-import pandas as pd
-import streamlit as st
 from datetime import datetime
-from st_aggrid import AgGrid, GridOptionsBuilder
-from st_aggrid.shared import GridUpdateMode, ColumnsAutoSizeMode, DataReturnMode
-from pandas.api.types import (
-    is_categorical_dtype,
-    is_numeric_dtype,
-    is_string_dtype,
-)
-import json
-import streamlit.components.v1 as components
-from streamlit_plotly_events import plotly_events
 
 import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
 import plotly
 import plotly.express as px
-import numpy as np
 import plotly.graph_objects as go
 import statsmodels.api as sm
+import streamlit as st
+import streamlit.components.v1 as components
+from pandas.api.types import (is_categorical_dtype, is_numeric_dtype,
+                              is_string_dtype)
 from scipy.stats import linregress
-
-from .url_query_helper import (
-    checkbox_wrapper_for_url_query,
-    selectbox_wrapper_for_url_query,
-    slider_wrapper_for_url_query,
-    multiselect_wrapper_for_url_query,
-    get_filter_type,
-)
-from .plot_autotrain_manager import plot_manager_all_progress
+from st_aggrid import AgGrid, GridOptionsBuilder
+from st_aggrid.shared import (ColumnsAutoSizeMode, DataReturnMode,
+                              GridUpdateMode)
+from streamlit_plotly_events import plotly_events
 
 from .aws_s3 import draw_session_plots_quick_preview
+from .plot_autotrain_manager import plot_manager_all_progress
+from .url_query_helper import (checkbox_wrapper_for_url_query, get_filter_type,
+                               multiselect_wrapper_for_url_query,
+                               selectbox_wrapper_for_url_query,
+                               slider_wrapper_for_url_query)
 
 custom_css = {
     ".ag-root.ag-unselectable.ag-layout-normal": {"font-size": "15px !important",

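The `pandas.api.types` predicates kept in `util/streamlit.py` (`is_categorical_dtype`, `is_numeric_dtype`, `is_string_dtype`) are the usual building blocks for choosing a filter widget per dataframe column. The repository's actual `data_selector` is not shown in this hunk; the following is only a generic sketch of that pattern:

```python
# Generic sketch of dtype-driven filter widgets (not the repo's data_selector implementation).
import pandas as pd
import streamlit as st
from pandas.api.types import is_numeric_dtype, is_string_dtype


def add_column_filters(df: pd.DataFrame) -> pd.DataFrame:
    for col in df.columns:
        if is_numeric_dtype(df[col]):
            lo, hi = float(df[col].min()), float(df[col].max())
            sel = st.slider(col, lo, hi, (lo, hi))                # numeric -> range slider
            df = df[df[col].between(*sel)]
        elif is_string_dtype(df[col]):
            options = sorted(df[col].dropna().unique())
            sel = st.multiselect(col, options, default=options)   # strings -> multiselect
            df = df[df[col].isin(sel)]
    return df
```
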
8 changes: 2 additions & 6 deletions code/util/url_query_helper.py
@@ -1,13 +1,9 @@
 import streamlit as st
+from pandas.api.types import (is_categorical_dtype, is_datetime64_any_dtype,
+                              is_numeric_dtype)
 
 from .settings import draw_type_mapper_session_level
 
-from pandas.api.types import (
-    is_categorical_dtype,
-    is_datetime64_any_dtype,
-    is_numeric_dtype,
-)
-
 # Sync widgets with URL query params
 # https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/
 # dict of "key": default pairs

