From e7ac26d745915482c069d0b396752ba7aa9e8c9e Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 12:35:45 -0700 Subject: [PATCH 1/9] chore: remove unused func --- code/Home.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/code/Home.py b/code/Home.py index 33d3637..5db8982 100644 --- a/code/Home.py +++ b/code/Home.py @@ -125,19 +125,7 @@ if 'selected_points' not in st.session_state: st.session_state['selected_points'] = [] - -def _get_urls(): - cache_folder = 'aind-behavior-data/Han/ephys/report/st_cache/' - cache_session_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_sessions/' - cache_mouse_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_subjects/' - - fs = s3fs.S3FileSystem(anon=False) - - with fs.open('aind-behavior-data/Han/streamlit_CO_url.json', 'r') as f: - data = json.load(f) - - return data['behavior'], data['ephys'] @st.cache_data(ttl=24*3600) def load_data(tables=['sessions']): From 3e7c826b11b9dd5a7d01e6bfcd5a77cb8a0a402f Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 13:09:43 -0700 Subject: [PATCH 2/9] feat: merge df_session from bpod! --- code/Home.py | 47 +++++++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/code/Home.py b/code/Home.py index 5db8982..37db1d6 100644 --- a/code/Home.py +++ b/code/Home.py @@ -43,7 +43,7 @@ add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper, _plot_population_x_y) from util.url_query_helper import ( - sync_widget_with_query, slider_wrapper_for_url_query, + sync_widget_with_query, slider_wrapper_for_url_query, checkbox_wrapper_for_url_query ) import extra_streamlit_components as stx @@ -57,6 +57,8 @@ # dict of "key": default pairs # Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL to_sync_with_url_query = { + 'if_load_bpod_sessions': True, + 'filter_subject_id': '', 'filter_session': [0.0, None], 'filter_finished_trials': [0.0, None], @@ -102,10 +104,10 @@ } -raw_nwb_folder = 'aind-behavior-data/foraging_nwb_bonsai/' -cache_folder = 'aind-behavior-data/foraging_nwb_bonsai_processed/' -# cache_session_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_sessions/' -# cache_mouse_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_subjects/' +data_sources = ['bonsai', 'bpod'] + +s3_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}/' for data in data_sources} +s3_processed_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}_processed/' for data in data_sources} fs = s3fs.S3FileSystem(anon=False) st.session_state.use_s3 = True @@ -124,14 +126,12 @@ if 'selected_points' not in st.session_state: st.session_state['selected_points'] = [] - - @st.cache_data(ttl=24*3600) -def load_data(tables=['sessions']): +def load_data(tables=['sessions'], data_source = 'bonsai'): df = {} for table in tables: - file_name = cache_folder + f'df_{table}.pkl' + file_name = s3_processed_nwb_folder[data_source] + f'df_{table}.pkl' if st.session_state.use_s3: with fs.open(file_name) as f: df[table + '_bonsai'] = pd.read_pickle(f) @@ -139,7 +139,7 @@ def load_data(tables=['sessions']): df[table + '_bonsai'] = pd.read_pickle(file_name) return df -def _fetch_img(glob_patterns, crop=None): +def _fetch_img(glob_patterns, crop=None): # Fetch the img that first matches the patterns for pattern in glob_patterns: file = fs.glob(pattern) if st.session_state.use_s3 else glob.glob(pattern) @@ -187,7 +187,7 @@ def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_pat # Convert session_date to 2024-04-01 format subject_session_date_str = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0] - glob_patterns = [cache_folder + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"] + glob_patterns = [s3_processed_nwb_folder['bonsai'] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"] img, f_name = _fetch_img(glob_patterns, crop) @@ -515,8 +515,12 @@ def init(): for key, default in to_sync_with_url_query.items(): sync_widget_with_query(key, default) - df = load_data(['sessions', - ]) + df = load_data(['sessions'], data_source='bonsai') + + if st.session_state.if_load_bpod_sessions: + df_bpod = load_data(['sessions'], data_source='bpod') + # For historial reason, the suffix of df['sessions_bonsai'] just mean the data of the Home.py page + df['sessions_bonsai'] = pd.concat([df['sessions_bonsai'], df_bpod['sessions_bonsai']], axis=0) st.session_state.df = df st.session_state.df_selected_from_plotly = pd.DataFrame(columns=['h2o', 'session']) @@ -638,7 +642,7 @@ def app(): cols = st.columns([1, 1.2]) with cols[0]: - st.markdown('## 🌳🪴 Foraging sessions from Bonsai 🌳🪴') + st.markdown('## 🌳🪴 Dynamic Foraging Sessions 🌳🪴') with st.sidebar: @@ -670,8 +674,15 @@ def app(): cols[0].markdown(f'### Filter the sessions on the sidebar\n' f'##### {len(st.session_state.df_session_filtered)} sessions, ' f'{len(st.session_state.df_session_filtered.h2o.unique())} mice filtered') + + if_load_bpod_sessions = checkbox_wrapper_for_url_query( + st_prefix=cols[1], + label='Include old Bpod sessions', + key='if_load_bpod_sessions', + default=True, + ) + with cols[1]: - st.markdown('# ') if st.button(' Reload data ', type='primary'): st.cache_data.clear() init() @@ -885,7 +896,7 @@ def app(): for _ in range(10): st.write('\n') st.markdown('---\n##### Debug zone') with st.expander('CO processing NWB errors', expanded=False): - error_file = cache_folder + 'error_files.json' + error_file = s3_processed_nwb_folder['bonsai'] + 'error_files.json' if fs.exists(error_file): with fs.open(error_file) as file: st.json(json.load(file)) @@ -893,13 +904,13 @@ def app(): st.write('No NWB error files') with st.expander('CO Pipeline log', expanded=False): - with fs.open(cache_folder + 'pipeline.log') as file: + with fs.open(s3_processed_nwb_folder['bonsai'] + 'pipeline.log') as file: log_content = file.read().decode('utf-8') log_content = log_content.replace('\\n', '\n') st.text(log_content) with st.expander('NWB convertion and upload log', expanded=False): - with fs.open(raw_nwb_folder + 'bonsai_pipeline.log') as file: + with fs.open(s3_nwb_folder['bonsai'] + 'bonsai_pipeline.log') as file: log_content = file.read().decode('utf-8') st.text(log_content) From c17b07fd33d0fd9da05482052a5ec4b938dba98a Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 13:35:55 -0700 Subject: [PATCH 3/9] feat: handle h2o --- code/Home.py | 59 ++++++++++++++++++++++++++---------------- code/util/streamlit.py | 2 +- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/code/Home.py b/code/Home.py index 37db1d6..3e9f5ac 100644 --- a/code/Home.py +++ b/code/Home.py @@ -517,8 +517,10 @@ def init(): df = load_data(['sessions'], data_source='bonsai') + # --- Perform any data source-dependent preprocessing here --- if st.session_state.if_load_bpod_sessions: df_bpod = load_data(['sessions'], data_source='bpod') + # For historial reason, the suffix of df['sessions_bonsai'] just mean the data of the Home.py page df['sessions_bonsai'] = pd.concat([df['sessions_bonsai'], df_bpod['sessions_bonsai']], axis=0) @@ -588,51 +590,62 @@ def init(): # Some ad-hoc modifications on df_sessions - st.session_state.df['sessions_bonsai'].columns = st.session_state.df['sessions_bonsai'].columns.get_level_values(1) - st.session_state.df['sessions_bonsai'].sort_values(['session_end_time'], ascending=False, inplace=True) - st.session_state.df['sessions_bonsai'] = st.session_state.df['sessions_bonsai'].reset_index().query('subject_id != "0"') - st.session_state.df['sessions_bonsai']['h2o'] = st.session_state.df['sessions_bonsai']['subject_id'] - st.session_state.df['sessions_bonsai'].dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now) - st.session_state.df['sessions_bonsai'].drop(st.session_state.df['sessions_bonsai'].query('session < 1').index, inplace=True) + _df = st.session_state.df['sessions_bonsai'] # temporary df alias + + _df.columns = _df.columns.get_level_values(1) + _df.sort_values(['session_start_time'], ascending=False, inplace=True) + _df = _df.reset_index().query('subject_id != "0"') + + # Fill in h2o (water restriction number, or mouse alias) + if 'bpod_backup_h2o' in _df.columns: + _df['h2o'] = np.where(_df['bpod_backup_h2o'].notnull(), _df['bpod_backup_h2o'], _df['subject_id']) + else: + _df['h2o'] = _df['subject_id'] + + # Handle session number + _df.dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now) + _df.drop(_df.query('session < 1').index, inplace=True) # # add something else # add abs(bais) to all terms that have 'bias' in name - for col in st.session_state.df['sessions_bonsai'].columns: + for col in _df.columns: if 'bias' in col: - st.session_state.df['sessions_bonsai'][f'abs({col})'] = np.abs(st.session_state.df['sessions_bonsai'][col]) + _df[f'abs({col})'] = np.abs(_df[col]) # # delta weight - # diff_relative_weight_next_day = st.session_state.df['sessions_bonsai'].set_index( + # diff_relative_weight_next_day = _df.set_index( # ['session']).sort_values('session', ascending=True).groupby('h2o').apply( # lambda x: - x.relative_weight.diff(periods=-1)).rename("diff_relative_weight_next_day") # weekday - st.session_state.df['sessions_bonsai'].session_date = pd.to_datetime(st.session_state.df['sessions_bonsai'].session_date) - st.session_state.df['sessions_bonsai']['weekday'] = st.session_state.df['sessions_bonsai'].session_date.dt.dayofweek + 1 + _df.session_date = pd.to_datetime(_df.session_date) + _df['weekday'] = _df.session_date.dt.dayofweek + 1 # map user_name - st.session_state.df['sessions_bonsai']['user_name'] = st.session_state.df['sessions_bonsai']['user_name'].apply(_user_name_mapper) + _df['user_name'] = _df['user_name'].apply(_user_name_mapper) # fill nan for autotrain fields filled_values = {'curriculum_name': 'None', 'curriculum_version': 'None', 'curriculum_schema_version': 'None', 'current_stage_actual': 'None'} - st.session_state.df['sessions_bonsai'].fillna(filled_values, inplace=True) + _df.fillna(filled_values, inplace=True) # foraging performance = foraing_eff * finished_rate - if 'foraging_performance' not in st.session_state.df['sessions_bonsai'].columns: - st.session_state.df['sessions_bonsai']['foraging_performance'] = \ - st.session_state.df['sessions_bonsai']['foraging_eff'] \ - * st.session_state.df['sessions_bonsai']['finished_rate'] - st.session_state.df['sessions_bonsai']['foraging_performance_random_seed'] = \ - st.session_state.df['sessions_bonsai']['foraging_eff_random_seed'] \ - * st.session_state.df['sessions_bonsai']['finished_rate'] - - # st.session_state.df['sessions_bonsai'] = st.session_state.df['sessions_bonsai'].merge( + if 'foraging_performance' not in _df.columns: + _df['foraging_performance'] = \ + _df['foraging_eff'] \ + * _df['finished_rate'] + _df['foraging_performance_random_seed'] = \ + _df['foraging_eff_random_seed'] \ + * _df['finished_rate'] + + # _df = _df.merge( # diff_relative_weight_next_day, how='left', on=['h2o', 'session']) - st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions_bonsai'].keys()] + st.session_state.df['sessions_bonsai'] = _df # Somehow _df loses the reference to the original dataframe + + st.session_state.session_stats_names = [keys for keys in _df.keys()] # Establish communication between pygwalker and streamlit init_streamlit_comm() diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 662bdf3..f36785b 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -58,7 +58,7 @@ def aggrid_interactive_table_session(df: pd.DataFrame, table_height: int = 400): options.configure_side_bar() if 'session_end_time' in df.columns: - df = df.sort_values('session_end_time', ascending=False) + df = df.sort_values('session_start_time', ascending=False) else: df = df.sort_values('session_date', ascending=False) From 77620bf7250904775d55a6f4d0849585d0751dc3 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 13:41:26 -0700 Subject: [PATCH 4/9] feat: better handle user_name --- code/Home.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index 3e9f5ac..1832c36 100644 --- a/code/Home.py +++ b/code/Home.py @@ -596,9 +596,10 @@ def init(): _df.sort_values(['session_start_time'], ascending=False, inplace=True) _df = _df.reset_index().query('subject_id != "0"') - # Fill in h2o (water restriction number, or mouse alias) + # Handle mouse and user name if 'bpod_backup_h2o' in _df.columns: _df['h2o'] = np.where(_df['bpod_backup_h2o'].notnull(), _df['bpod_backup_h2o'], _df['subject_id']) + _df['user_name'] = np.where(_df['bpod_backup_user_name'].notnull(), _df['bpod_backup_user_name'], _df['user_name']) else: _df['h2o'] = _df['subject_id'] @@ -628,7 +629,9 @@ def init(): filled_values = {'curriculum_name': 'None', 'curriculum_version': 'None', 'curriculum_schema_version': 'None', - 'current_stage_actual': 'None'} + 'current_stage_actual': 'None', + 'has_video': False, + 'has_ephys': False,} _df.fillna(filled_values, inplace=True) # foraging performance = foraing_eff * finished_rate From a9efe95a1f13891a5987d2b7732f18179a5d80a6 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 13:53:23 -0700 Subject: [PATCH 5/9] feat: show hover when dots are not shown --- code/util/streamlit.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index f36785b..2ee4ab8 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -769,7 +769,7 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' font_size_scale=1.0, **kwarg): - def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, **kwarg): + def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, hoverinfo='skip', **kwarg): x = df_this.sort_values(x_name)[x_name].astype(float) y = df_this.sort_values(x_name)[y_name].astype(float) @@ -788,7 +788,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line_width=line_width, opacity=1, hoveron='points+fills', # Scattergl doesn't support this - hoverinfo='skip', + hoverinfo=hoverinfo, **kwarg, )) @@ -806,7 +806,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=col, opacity=1, hoveron='points+fills', # Scattergl doesn't support this - hoverinfo='skip', + hoverinfo=hoverinfo, **kwarg, )) @@ -853,7 +853,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=col, opacity=1, hoveron='points+fills', # Scattergl doesn't support this - hoverinfo='skip', + hoverinfo=hoverinfo, **kwarg, )) @@ -867,7 +867,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line=dict(width=0), legendgroup=f'group_{group}', showlegend=False, - hoverinfo='skip', + hoverinfo=hoverinfo, )) fig.add_trace(go.Scatter( # name='Upper Bound', @@ -880,7 +880,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q fillcolor=f'rgba({plotly.colors.convert_colors_to_same_type(col)[0][0].split("(")[-1][:-1]}, 0.2)', legendgroup=f'group_{group}', showlegend=False, - hoverinfo='skip' + hoverinfo=hoverinfo )) elif aggr_method == 'linear fit': @@ -897,7 +897,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line=dict(dash='dot' if p_value > 0.05 else 'solid', width=line_width if p_value > 0.05 else line_width*1.5), legendgroup=f'group_{group}', - hoverinfo='skip' + hoverinfo=hoverinfo ) ) except: @@ -982,10 +982,13 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q )) if if_aggr_each_group: - _add_agg(this_session, x_name, y_name, group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, col, line_width=line_width) + _add_agg(this_session, x_name, y_name, group, aggr_method_group, + if_use_x_quantile_group, q_quantiles_group, col, line_width=line_width, + hoverinfo='all' if not if_show_dots else 'skip') if if_aggr_all: - _add_agg(df, x_name, y_name, 'all', aggr_method_all, if_use_x_quantile_all, q_quantiles_all, 'rgb(0, 0, 0)', line_width=line_width*1.5) + _add_agg(df, x_name, y_name, 'all', aggr_method_all, if_use_x_quantile_all, q_quantiles_all, 'rgb(0, 0, 0)', line_width=line_width*1.5, + hoverinfo='all' if not if_show_dots else 'skip') n_mice = len(df['h2o'].unique()) n_sessions = len(df.groupby(['h2o', 'session']).count()) From 760428ff0a8ab603b8a445876ace8a25022da5f7 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 14:37:59 -0700 Subject: [PATCH 6/9] feat: add data_source --- code/Home.py | 31 ++++++++++++++++++++++++++++++- code/util/streamlit.py | 2 +- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index 1832c36..e80f047 100644 --- a/code/Home.py +++ b/code/Home.py @@ -602,6 +602,34 @@ def init(): _df['user_name'] = np.where(_df['bpod_backup_user_name'].notnull(), _df['bpod_backup_user_name'], _df['user_name']) else: _df['h2o'] = _df['subject_id'] + + + def _get_data_source(rig): + """From rig string, return "{institute}_{rig_type}_{room}_{hardware}" + """ + institute = 'Janelia' if ('bpod' in rig) and not ('AIND' in rig) else 'AIND' + hardware = 'bpod' if ('bpod' in rig) else 'bonsai' + rig_type = 'ephys' if ('ephys' in rig.lower()) else 'training' + + # This is a mess... + if institute == 'Janelia': + room = 'NA' + elif 'Ephys-Han' in rig: + room = '321' + elif hardware == 'bpod': + room = '347' + elif '447' in rig: + room = '447' + elif '323' in rig: + room = '323' + elif rig_type == 'ephys': + room = '323' + else: + room = '447' + return institute, rig_type, room, hardware, '_'.join([institute, rig_type, room, hardware]) + + # Add data source (Room + Hardware etc) + _df[['institute', 'rig_type', 'room', 'hardware', 'data_source']] = _df['rig'].apply(lambda x: pd.Series(_get_data_source(x))) # Handle session number _df.dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now) @@ -631,7 +659,8 @@ def init(): 'curriculum_schema_version': 'None', 'current_stage_actual': 'None', 'has_video': False, - 'has_ephys': False,} + 'has_ephys': False, + } _df.fillna(filled_values, inplace=True) # foraging performance = foraing_eff * finished_rate diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 2ee4ab8..cdd6e64 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -396,7 +396,7 @@ def add_session_filter(if_bonsai=False, url_query={}): @st.cache_data(ttl=3600*24) def _get_grouped_by_fields(if_bonsai): if if_bonsai: - options = ['h2o', 'task', 'user_name', 'rig', 'weekday'] + options = ['h2o', 'task', 'user_name', 'rig', 'weekday', 'data_source'] options += [col for col in st.session_state.df_session_filtered.columns if is_categorical_dtype(st.session_state.df_session_filtered[col]) From ad76247bf65cb16572092944d052b19e194479ee Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 14:39:22 -0700 Subject: [PATCH 7/9] minor --- code/util/streamlit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index cdd6e64..3c0a1ed 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -396,7 +396,7 @@ def add_session_filter(if_bonsai=False, url_query={}): @st.cache_data(ttl=3600*24) def _get_grouped_by_fields(if_bonsai): if if_bonsai: - options = ['h2o', 'task', 'user_name', 'rig', 'weekday', 'data_source'] + options = ['h2o', 'task', 'user_name', 'rig', 'data_source', 'weekday'] options += [col for col in st.session_state.df_session_filtered.columns if is_categorical_dtype(st.session_state.df_session_filtered[col]) From 1b3a9c219e74599ca05ed319d2f2e516e83a3235 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 14:57:22 -0700 Subject: [PATCH 8/9] feat: connect to bpod processed folder! --- code/Home.py | 79 ++++++++++++++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/code/Home.py b/code/Home.py index e80f047..51c2875 100644 --- a/code/Home.py +++ b/code/Home.py @@ -12,7 +12,7 @@ """ -#%% +# %% import pandas as pd import streamlit as st from pathlib import Path @@ -57,7 +57,7 @@ # dict of "key": default pairs # Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL to_sync_with_url_query = { - 'if_load_bpod_sessions': True, + 'if_load_bpod_sessions': False, 'filter_subject_id': '', 'filter_session': [0.0, None], @@ -126,7 +126,7 @@ if 'selected_points' not in st.session_state: st.session_state['selected_points'] = [] - + @st.cache_data(ttl=24*3600) def load_data(tables=['sessions'], data_source = 'bonsai'): df = {} @@ -179,7 +179,7 @@ def _user_name_mapper(user_name): return user_name # @st.cache_data(ttl=24*3600, max_entries=20) -def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs): +def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, data_source='bonsai', **kwargs): try: date_str = key["session_date"].strftime(r'%Y-%m-%d') except: @@ -187,7 +187,7 @@ def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_pat # Convert session_date to 2024-04-01 format subject_session_date_str = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0] - glob_patterns = [s3_processed_nwb_folder['bonsai'] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"] + glob_patterns = [s3_processed_nwb_folder[data_source] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"] img, f_name = _fetch_img(glob_patterns, crop) @@ -231,7 +231,7 @@ def show_mouse_level_img_by_key_and_prefix(key, prefix, column=None, other_patte def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer": return StreamlitRenderer(df, spec=spec, debug=False, **kwargs) - + def draw_session_plots(df_to_draw_session): # Setting up layout for each session @@ -264,8 +264,9 @@ def draw_session_plots(df_to_draw_session): except: date_str = key["session_date"].split("T")[0] - st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str}''', - unsafe_allow_html=True) + st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str} ''' + f'''({key["user_name"]}@{key["data_source"]})''', + unsafe_allow_html=True) if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout for row, column_setting in enumerate(layout_definition): rows.append(this_major_col.columns(column_setting)) @@ -278,33 +279,35 @@ def draw_session_plots(df_to_draw_session): prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] show_session_level_img_by_key_and_prefix(key, - column=this_col, - prefix=prefix, - **setting) + column=this_col, + prefix=prefix, + data_source=key['hardware'], + **setting) my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100)) - + def draw_session_plots_quick_preview(df_to_draw_session): - + # Setting up layout for each session layout_definition = [[1], # columns in the first row [1, 1], ] draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)'] - + container_session_all_in_one = st.container() - + key = df_to_draw_session.to_dict(orient='records')[0] - + with container_session_all_in_one: try: date_str = key["session_date"].strftime('%Y-%m-%d') except: date_str = key["session_date"].split("T")[0] - - st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str}''', + + st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str} ''' + f'''({key["user_name"]}@{key["data_source"]})''', unsafe_allow_html=True) - + rows = [] for row, column_setting in enumerate(layout_definition): rows.append(st.columns(column_setting)) @@ -313,14 +316,15 @@ def draw_session_plots_quick_preview(df_to_draw_session): if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] - show_session_level_img_by_key_and_prefix(key, - column=this_col, - prefix=prefix, - **setting) - - - - + show_session_level_img_by_key_and_prefix( + key, + column=this_col, + prefix=prefix, + data_source=key["hardware"], + **setting, + ) + + def draw_mice_plots(df_to_draw_mice): # Setting up layout for each session @@ -364,12 +368,8 @@ def draw_mice_plots(df_to_draw_mice): **setting) my_bar.progress(int((i + 1) / len(df_to_draw_mice) * 100)) - - - - def session_plot_settings(need_click=True): st.markdown('##### Show plots for individual sessions ') cols = st.columns([2, 1, 6]) @@ -499,11 +499,10 @@ def plot_x_y_session(): return df_selected_from_plotly, cols - def show_curriculums(): pass -# ------- Layout starts here -------- # +# ------- Layout starts here -------- # def init(): # Clear specific session state and all filters @@ -672,6 +671,9 @@ def _get_data_source(rig): _df['foraging_eff_random_seed'] \ * _df['finished_rate'] + # drop 'bpod_backup_' columns + _df.drop([col for col in _df.columns if 'bpod_backup_' in col], axis=1, inplace=True) + # _df = _df.merge( # diff_relative_weight_next_day, how='left', on=['h2o', 'session']) @@ -681,7 +683,7 @@ def _get_data_source(rig): # Establish communication between pygwalker and streamlit init_streamlit_comm() - + def app(): @@ -715,16 +717,16 @@ def app(): # with col1: # -- 1. unit dataframe -- - cols = st.columns([2, 1, 4, 1]) + cols = st.columns([2, 2, 4, 1]) cols[0].markdown(f'### Filter the sessions on the sidebar\n' f'##### {len(st.session_state.df_session_filtered)} sessions, ' f'{len(st.session_state.df_session_filtered.h2o.unique())} mice filtered') if_load_bpod_sessions = checkbox_wrapper_for_url_query( st_prefix=cols[1], - label='Include old Bpod sessions', + label='Include old Bpod sessions (reload after change)', key='if_load_bpod_sessions', - default=True, + default=False, ) with cols[1]: @@ -972,6 +974,5 @@ def app(): if 'df' not in st.session_state or 'sessions_bonsai' not in st.session_state.df.keys(): init() - -app() +app() From 900e4d62e07f7613952a81e77528b7668fc9a503 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 15:07:14 -0700 Subject: [PATCH 9/9] feat: show data_source in hovertop --- code/util/streamlit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 3c0a1ed..7b88bd2 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -953,7 +953,8 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=this_session['colors'], opacity=dot_opacity, hovertemplate = '%{customdata[0]}, %{customdata[1]}, Session %{customdata[2]}' - '
%{customdata[3]}, %{customdata[4]}' + '
%{customdata[4]} @ %{customdata[9]}' + '
Rig: %{customdata[3]}' '
Task: %{customdata[5]}' '
AutoTrain: %{customdata[7]} @ %{customdata[6]}
' f'
{"-"*10}
X: {x_name} = %{{x}}' @@ -977,6 +978,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q this_session[dot_size_mapping_name] if dot_size_mapping_name !='None' else [np.nan] * len(this_session.h2o), # 8 + this_session.data_source if 'data_source' in this_session else '', # 9 ), axis=-1), unselected=dict(marker_color='lightgrey') ))