Skip to content

Commit

Permalink
Merge branch 'main' into han_dim_reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Apr 5, 2024
2 parents 8d03db2 + 8060b27 commit 0b418f2
Show file tree
Hide file tree
Showing 4 changed files with 536 additions and 295 deletions.
202 changes: 133 additions & 69 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@

from util.streamlit import (filter_dataframe, aggrid_interactive_table_session,
aggrid_interactive_table_curriculum, add_session_filter, data_selector,
_sync_widget_with_query, add_xy_selector, add_xy_setting, add_auto_train_manager,
add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper,
_plot_population_x_y)
from util.url_query_helper import (
sync_widget_with_query, slider_wrapper_for_url_query,
)

import extra_streamlit_components as stx

from aind_auto_train.curriculum_manager import CurriculumManager
Expand All @@ -59,9 +63,11 @@
'filter_foraging_eff': [0.0, None],
'filter_task': ['all'],

'table_height': 300,

'tab_id': 'tab_session_x_y',
'x_y_plot_xname': 'session',
'x_y_plot_yname': 'foraging_eff',
'x_y_plot_yname': 'foraging_performance_random_seed',
'x_y_plot_group_by': 'h2o',
'x_y_plot_if_show_dots': True,
'x_y_plot_if_aggr_each_group': True,
Expand All @@ -73,9 +79,17 @@
'x_y_plot_q_quantiles_group': 20,
'x_y_plot_if_use_x_quantile_all': False,
'x_y_plot_q_quantiles_all': 20,
'x_y_plot_if_show_diagonal': False,
'x_y_plot_dot_size': 10,
'x_y_plot_dot_opacity': 0.5,
'x_y_plot_dot_opacity': 0.3,
'x_y_plot_line_width': 2.0,
'x_y_plot_figure_width': 1300,
'x_y_plot_figure_height': 900,
'x_y_plot_font_size_scale': 1.0,

'x_y_plot_size_mapper': 'finished_trials',
'x_y_plot_size_mapper_gamma': 1.0,
'x_y_plot_size_mapper_range': [3, 20],

'session_plot_mode': 'sessions selected from table or plot',

Expand Down Expand Up @@ -282,7 +296,42 @@ def draw_session_plots(df_to_draw_session):

my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100))

def draw_session_plots_quick_preview(df_to_draw_session):

# Setting up layout for each session
layout_definition = [[1], # columns in the first row
[1, 1],
]
draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)']

container_session_all_in_one = st.container()

key = df_to_draw_session.to_dict(orient='records')[0]

with container_session_all_in_one:
try:
date_str = key["session_date"].strftime('%Y-%m-%d')
except:
date_str = key["session_date"].split("T")[0]

st.markdown(f'''<h4 style='text-align: center; color: orange;'>{key["h2o"]}, Session {int(key["session"])}, {date_str}''',
unsafe_allow_html=True)

rows = []
for row, column_setting in enumerate(layout_definition):
rows.append(st.columns(column_setting))

for draw_type in draw_types_quick_preview:
if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level
prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
show_session_level_img_by_key_and_prefix(key,
column=this_col,
prefix=prefix,
**setting)




def draw_mice_plots(df_to_draw_mice):

Expand Down Expand Up @@ -335,7 +384,7 @@ def draw_mice_plots(df_to_draw_mice):

def session_plot_settings(need_click=True):
st.markdown('##### Show plots for individual sessions ')
cols = st.columns([2, 1])
cols = st.columns([2, 1, 6])

session_plot_modes = [f'sessions selected from table or plot', f'all sessions filtered from sidebar']
st.session_state.selected_draw_sessions = cols[0].selectbox(f'Which session(s) to draw?',
Expand All @@ -351,7 +400,7 @@ def session_plot_settings(need_click=True):
n_session_to_draw = len(st.session_state.df_selected_from_plotly) \
if 'selected from table or plot' in st.session_state.selected_draw_sessions \
else len(st.session_state.df_session_filtered)
st.markdown(f'{n_session_to_draw} sessions to draw')
cols[0].markdown(f'{n_session_to_draw} sessions to draw')

st.session_state.num_cols = cols[1].number_input('number of columns', 1, 10,
3 if 'num_cols' not in st.session_state else st.session_state.num_cols)
Expand All @@ -365,61 +414,38 @@ def session_plot_settings(need_click=True):
</style>""",
unsafe_allow_html=True,
)
st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?',
st.session_state.selected_draw_types = cols[2].multiselect('Which plot(s) to draw?',
st.session_state.draw_type_mapper_session_level.keys(),
default=st.session_state.draw_type_mapper_session_level.keys()
if 'selected_draw_types' not in st.session_state else
st.session_state.selected_draw_types)
if need_click:
draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True)
else:
draw_it = True
return draw_it

def mouse_plot_settings(need_click=True):
st.markdown('##### Show plots for individual mice ')
cols = st.columns([2, 1])
st.session_state.selected_draw_mice = cols[0].selectbox('Which mice to draw?',
[f'selected from table/plot ({len(st.session_state.df_selected_from_plotly.h2o.unique())} mice)',
f'filtered from sidebar ({len(st.session_state.df_session_filtered.h2o.unique())} mice)'],
index=0
)
st.session_state.num_cols_mice = cols[1].number_input('Number of columns', 1, 10,
3 if 'num_cols_mice' not in st.session_state else st.session_state.num_cols_mice)
st.markdown(
"""
<style>
.stMultiSelect [data-baseweb=select] span{
max-width: 1000px;
}
</style>""",
unsafe_allow_html=True,
)
st.session_state.selected_draw_types_mice = st.multiselect('Which plot(s) to draw?',
st.session_state.draw_type_mapper_mouse_level.keys(),
default=st.session_state.draw_type_mapper_mouse_level.keys()
if 'selected_draw_types_mice' not in st.session_state else
st.session_state.selected_draw_types_mice)
if need_click:
draw_it = st.button('Show me all mice!', use_container_width=True)
draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True, type="primary")
draw_it_now_override = cols[1].checkbox('Auto draw')
else:
draw_it = True
return draw_it
draw_it_now_override = True
return draw_it | draw_it_now_override


def plot_x_y_session():

cols = st.columns([4, 10])
cols = st.columns([1, 1, 1])

with cols[0]:

x_name, y_name, group_by = add_xy_selector(if_bonsai=True)

with cols[1]:
(if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group,
if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor,
dot_size, dot_opacity, line_width) = add_xy_setting()


if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal,
dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting()

if st.session_state.x_y_plot_if_show_dots:
with cols[2]:
size_mapper, size_mapper_range, size_mapper_gamma = add_dot_property_mapper()
else:
size_mapper = 'None'
size_mapper_range, size_mapper_gamma = None, None

# If no sessions are selected, use all filtered entries
# df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered
Expand All @@ -432,33 +458,55 @@ def plot_x_y_session():
df_selected_from_plotly = pd.DataFrame()
# for i, (title, (x_name, y_name)) in enumerate(names.items()):
# with cols[i]:
with cols[1]:
fig = _plot_population_x_y(df=df_x_y_session,
x_name=x_name, y_name=y_name,
group_by=group_by,
smooth_factor=smooth_factor,
if_show_dots=if_show_dots,
if_aggr_each_group=if_aggr_each_group,
if_aggr_all=if_aggr_all,
aggr_method_group=aggr_method_group,
aggr_method_all=aggr_method_all,
if_use_x_quantile_group=if_use_x_quantile_group,
q_quantiles_group=q_quantiles_group,
if_use_x_quantile_all=if_use_x_quantile_all,
q_quantiles_all=q_quantiles_all,
title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
states = st.session_state.df_selected_from_plotly,
dot_size=dot_size,
dot_opacity=dot_opacity,
line_width=line_width)

if hasattr(st.session_state, 'x_y_plot_figure_width'):
_x_y_plot_scale = st.session_state.x_y_plot_figure_width / 1300
cols = st.columns([1 * _x_y_plot_scale, 0.7])
else:
cols = st.columns([1, 0.7])
with cols[0]:
fig = _plot_population_x_y(df=df_x_y_session.copy(),
x_name=x_name, y_name=y_name,
group_by=group_by,
smooth_factor=smooth_factor,
if_show_dots=if_show_dots,
if_aggr_each_group=if_aggr_each_group,
if_aggr_all=if_aggr_all,
aggr_method_group=aggr_method_group,
aggr_method_all=aggr_method_all,
if_use_x_quantile_group=if_use_x_quantile_group,
q_quantiles_group=q_quantiles_group,
if_use_x_quantile_all=if_use_x_quantile_all,
q_quantiles_all=q_quantiles_all,
title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
states = st.session_state.df_selected_from_plotly,
if_show_diagonal=if_show_diagonal,
dot_size_base=dot_size,
dot_size_mapping_name=size_mapper,
dot_size_mapping_range=size_mapper_range,
dot_size_mapping_gamma=size_mapper_gamma,
dot_opacity=dot_opacity,
line_width=line_width,
x_y_plot_figure_width=x_y_plot_figure_width,
x_y_plot_figure_height=x_y_plot_figure_height,
font_size_scale=font_size_scale,
)

# st.plotly_chart(fig)
selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True,
override_height=fig.layout.height * 1.1, override_width=fig.layout.width)

with cols[1]:
st.markdown('#### 👀 Quick preview')
st.markdown('###### Click on one session to preview here, or Box/Lasso select multiple sessions to draw them in the section below')
st.markdown('(sometimes you have to click twice...)')

if len(selected):
df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1),
on=[x_name, y_name], how='inner')
if len(st.session_state.df_selected_from_plotly) == 1:
with cols[1]:
draw_session_plots_quick_preview(st.session_state.df_selected_from_plotly)

return df_selected_from_plotly, cols

Expand All @@ -477,7 +525,7 @@ def init():

# Set session state from URL
for key, default in to_sync_with_url_query.items():
_sync_widget_with_query(key, default)
sync_widget_with_query(key, default)

df = load_data(['sessions',
])
Expand Down Expand Up @@ -573,6 +621,13 @@ def init():
# map user_name
st.session_state.df['sessions_bonsai']['user_name'] = st.session_state.df['sessions_bonsai']['user_name'].apply(_user_name_mapper)

# fill nan for autotrain fields
filled_values = {'curriculum_name': 'None',
'curriculum_version': 'None',
'curriculum_schema_version': 'None',
'current_stage_actual': 'None'}
st.session_state.df['sessions_bonsai'].fillna(filled_values, inplace=True)

# foraging performance = foraing_eff * finished_rate
if 'foraging_performance' not in st.session_state.df['sessions_bonsai'].columns:
st.session_state.df['sessions_bonsai']['foraging_performance'] = \
Expand Down Expand Up @@ -634,8 +689,15 @@ def app():
init()
st.rerun()

table_height = cols[3].slider('Table height', 100, 2000, 400, 50, key='table_height')

table_height = slider_wrapper_for_url_query(st_prefix=cols[3],
label='Table height',
min_value=0,
max_value=2000,
default=300,
step=50,
key='table_height',
)

# aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units'])
# st.session_state.df_session_filtered = aggrid_outputs['data']

Expand Down Expand Up @@ -673,8 +735,8 @@ def app():
with placeholder:
df_selected_from_plotly, x_y_cols = plot_x_y_session()

with x_y_cols[0]:
for i in range(7): st.write('\n')
# Add session_plot_setting
with st.columns([1, 0.5])[0]:
st.markdown("***")
if_draw_all_sessions = session_plot_settings()

Expand Down Expand Up @@ -722,7 +784,7 @@ def app():

elif chosen_id == "tab_session_inspector":
with placeholder:
cols = st.columns([6, 3, 7])
cols = st.columns([1, 0.5])
with cols[0]:
if_draw_all_sessions = session_plot_settings(need_click=False)
df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered
Expand Down Expand Up @@ -832,6 +894,8 @@ def app():

# Add debug info
if chosen_id != "tab_auto_train_curriculum":
for _ in range(10): st.write('\n')
st.markdown('---\n##### Debug zone')
with st.expander('CO processing NWB errors', expanded=False):
error_file = cache_folder + 'error_files.json'
if fs.exists(error_file):
Expand Down
14 changes: 8 additions & 6 deletions code/pages/1_Old mice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from streamlit_plotly_events import plotly_events

from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector,
add_xy_selector, _sync_widget_with_query, add_xy_setting, add_auto_train_manager,
add_xy_selector, add_xy_setting, add_auto_train_manager,
_plot_population_x_y)
from util.url_query_helper import sync_widget_with_query

import extra_streamlit_components as stx

from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
Expand All @@ -41,7 +43,7 @@

'tab_id': 'tab_session_x_y',
'x_y_plot_xname': 'session',
'x_y_plot_yname': 'foraging_eff',
'x_y_plot_yname': 'foraging_performance',
'x_y_plot_group_by': 'h2o',
'x_y_plot_if_show_dots': True,
'x_y_plot_if_aggr_each_group': True,
Expand Down Expand Up @@ -354,8 +356,8 @@ def plot_x_y_session():
x_name, y_name, group_by = add_xy_selector(if_bonsai=False)

(if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group,
if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor,
dot_size, dot_opacity, line_width) = add_xy_setting()
if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal,
dot_size, dot_opacity, line_width, _, _, _) = add_xy_setting()


# If no sessions are selected, use all filtered entries
Expand Down Expand Up @@ -385,7 +387,7 @@ def plot_x_y_session():
q_quantiles_all=q_quantiles_all,
title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
states = st.session_state.df_selected_from_plotly,
dot_size=dot_size,
dot_size_base=dot_size,
dot_opacity=dot_opacity,
line_width=line_width)

Expand All @@ -410,7 +412,7 @@ def init():

# Set session state from URL
for key, default in to_sync_with_url_query.items():
_sync_widget_with_query(key, default)
sync_widget_with_query(key, default)

df = load_data(['sessions',
'logistic_regression_hattori',
Expand Down
Loading

0 comments on commit 0b418f2

Please sign in to comment.