From 6e28cc38eadb3fbd99b7099aa01bba809f1756a2 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 10 Aug 2024 20:06:46 +0000 Subject: [PATCH 01/11] build: update CO docker --- environment/Dockerfile | 12 ++++++------ environment/postInstall | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/environment/Dockerfile b/environment/Dockerfile index 5f1fd11..082b153 100644 --- a/environment/Dockerfile +++ b/environment/Dockerfile @@ -90,14 +90,14 @@ RUN pip install -U --no-cache-dir \ pynwb --ignore-installed ruamel.yaml\ git+https://github.com/AllenNeuralDynamics/aind-foraging-behavior-bonsai-automatic-training.git@main\ pygwalker \ - scikit-learn==1.4.1 + scikit-learn==1.4 + +ADD "https://github.com/coder/code-server/releases/download/v4.21.1/code-server-4.21.1-linux-amd64.tar.gz" /.code-server/code-server.tar.gz -ADD "https://github.com/coder/code-server/releases/download/v4.9.0/code-server-4.9.0-linux-amd64.tar.gz" /.code-server/code-server.tar.gz - RUN cd /.code-server \ - && tar -xvf code-server.tar.gz \ - && rm code-server.tar.gz \ - && ln -s /.code-server/code-server-4.9.0-linux-amd64/bin/code-server /usr/bin/code-server + && tar -xvf code-server.tar.gz \ + && rm code-server.tar.gz \ + && ln -s /.code-server/code-server-4.21.1-linux-amd64/bin/code-server /usr/bin/code-server COPY postInstall / diff --git a/environment/postInstall b/environment/postInstall index 07c7644..6df0669 100755 --- a/environment/postInstall +++ b/environment/postInstall @@ -13,6 +13,7 @@ if [ ! -d "/.vscode/extensions" ] code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension njpwerner.autodocstring code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension KevinRose.vsc-python-indent code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension sugatoray.vscode-git-extension-pack + code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension mhutchie.git-graph else echo "code-server not found" From 96c95af724c4d97e3810c66ca02a9a0497dec217 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 10 Aug 2024 20:19:06 +0000 Subject: [PATCH 02/11] fix: bugs in quick_preview --- code/util/aws_s3.py | 12 ++++-------- code/util/settings.py | 6 ++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py index 62c71c3..3728d94 100644 --- a/code/util/aws_s3.py +++ b/code/util/aws_s3.py @@ -5,7 +5,9 @@ import pandas as pd import streamlit as st -from .settings import draw_type_mapper_session_level +from .settings import ( + draw_type_layout_definition, draw_type_mapper_session_level, draw_types_quick_preview +) # -------------------------------------- data_sources = ['bonsai', 'bpod'] @@ -38,12 +40,6 @@ def load_data(tables=['sessions'], data_source = 'bonsai'): def draw_session_plots_quick_preview(df_to_draw_session): - # Setting up layout for each session - layout_definition = [[1], # columns in the first row - [1, 1], - ] - draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)'] - container_session_all_in_one = st.container() key = df_to_draw_session.to_dict(orient='records')[0] @@ -59,7 +55,7 @@ def draw_session_plots_quick_preview(df_to_draw_session): unsafe_allow_html=True) rows = [] - for row, column_setting in enumerate(layout_definition): + for row, column_setting in enumerate(draw_type_layout_definition): rows.append(st.columns(column_setting)) for draw_type in draw_types_quick_preview: diff --git a/code/util/settings.py b/code/util/settings.py index dac93c7..44b05ab 100644 --- a/code/util/settings.py +++ b/code/util/settings.py @@ -7,6 +7,7 @@ ] logistic_regression_models = ["Su2022", "Bari2019", "Hattori2019", "Miller2021"] + draw_type_mapper_session_level = { "1. Choice history": ( "choice_history", # prefix @@ -27,3 +28,8 @@ for n, model in enumerate(logistic_regression_models) }, } + +# For quick preview +draw_types_quick_preview = [ + '1. Choice history', + '3. Logistic regression (Su2022)'] From f16abd137d5cc0c5712546bd8cfe6b4b013eb304 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 10 Aug 2024 20:23:24 +0000 Subject: [PATCH 03/11] feat: reorder AutoTrain --- code/Home.py | 4 ++-- code/util/url_query_helper.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/code/Home.py b/code/Home.py index 841f229..e963ad8 100644 --- a/code/Home.py +++ b/code/Home.py @@ -542,10 +542,10 @@ def app(): st.rerun() chosen_id = stx.tab_bar(data=[ - stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"), + stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"), stx.TabBarItemData(id="tab_session_inspector", title="👀 Session Inspector", description="Select sessions from the table and show plots"), + stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"), stx.TabBarItemData(id="tab_pygwalker", title="📊 PyGWalker (Tableau)", description="Interactive dataframe explorer"), - stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"), stx.TabBarItemData(id="tab_auto_train_curriculum", title="📚 Automatic Training Curriculums", description="Collection of curriculums"), # stx.TabBarItemData(id="tab_mouse_inspector", title="🐭 Mouse Inspector", description="Mouse-level summary"), ], default=st.query_params['tab_id'] if 'tab_id' in st.query_params diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index 140fe0d..fc53a3a 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -24,7 +24,7 @@ 'table_height': 300, - 'tab_id': 'tab_session_x_y', + 'tab_id': 'tab_auto_train_history', 'x_y_plot_xname': 'session', 'x_y_plot_yname': 'foraging_performance_random_seed', 'x_y_plot_group_by': 'h2o', @@ -55,8 +55,8 @@ 'session_plot_selected_draw_types': list(draw_type_mapper_session_level.keys()), 'session_plot_number_cols': 3, - 'auto_training_history_x_axis': 'session', - 'auto_training_history_sort_by': 'subject_id', + 'auto_training_history_x_axis': 'date', + 'auto_training_history_sort_by': 'last_date', 'auto_training_history_sort_order': 'descending', 'auto_training_curriculum_name': 'Uncoupled Baiting', 'auto_training_curriculum_version': '1.0', From 733a1b7879200319870fb2dbb079da0fc1fbabf3 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 10 Aug 2024 21:32:59 +0000 Subject: [PATCH 04/11] feat: autotrain only show recent 30 days --- code/util/plot_autotrain_manager.py | 10 +++++++++- code/util/streamlit.py | 3 +++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py index 48bdf78..d761a4e 100644 --- a/code/util/plot_autotrain_manager.py +++ b/code/util/plot_autotrain_manager.py @@ -15,6 +15,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', 'last_date', 'progress_to_graduated'] = 'subject_id', sort_order: ['ascending', 'descending'] = 'descending', + recent_days: int=None, marker_size=10, marker_edge_width=2, highlight_subjects=[], @@ -174,9 +175,16 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', autorange='reversed', zeroline=False, title='' - ) + ), + yaxis_range=[-0.5, n + 0.5], ) + # Limit x range to recent days if x is "date" + if x_axis == 'date' and recent_days is not None: + xrange_min = datetime.now() - pd.Timedelta(days=recent_days) + xrange_max = datetime.now() + fig.update_xaxes(range=[xrange_min, xrange_max]) + # Highight the selected subject for n, subject_id in enumerate(subject_ids): if subject_id in highlight_subjects: diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 4d412ac..e7e4ff6 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -734,6 +734,7 @@ def add_auto_train_manager(): fig_auto_train = plot_manager_all_progress( st.session_state.auto_train_manager, x_axis=x_axis, + recent_days=30, sort_by=sort_by, sort_order=sort_order, marker_size=marker_size, @@ -748,6 +749,8 @@ def add_auto_train_manager(): ), font=dict(size=18), height=30 * len(df_training_manager.subject_id.unique()), + xaxis_side='top', + title='', ) cols = st.columns([2, 1]) From 4c7273074b47a672be09b8214a451d2acd2a2767 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 10 Aug 2024 21:40:50 +0000 Subject: [PATCH 05/11] feat: add slider_recent_weeks --- code/util/streamlit.py | 14 ++++++++++++-- code/util/url_query_helper.py | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index e7e4ff6..c0b1324 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -691,7 +691,7 @@ def add_auto_train_manager(): df_training_manager = st.session_state.auto_train_manager.df_manager # -- Show plotly chart -- - cols = st.columns([1, 1, 1, 0.7, 0.7, 3]) + cols = st.columns([1, 1, 1, 0.7, 0.7, 1, 2]) options = ["session", "date", "relative_date"] x_axis = selectbox_wrapper_for_url_query( st_prefix=cols[0], @@ -721,6 +721,16 @@ def add_auto_train_manager(): marker_size = cols[3].number_input('Marker size', value=15, step=1) marker_edge_width = cols[4].number_input('Marker edge width', value=3, step=1) + + recent_weeks = slider_wrapper_for_url_query(cols[5], + label="only recent weeks", + min_value=1, + max_value=26, + step=1, + key='auto_training_history_recent_weeks', + default=8, + disabled=x_axis != 'date', + ) # Get highlighted subjects if ('filter_subject_id' in st.session_state and st.session_state['filter_subject_id']) or\ @@ -734,7 +744,7 @@ def add_auto_train_manager(): fig_auto_train = plot_manager_all_progress( st.session_state.auto_train_manager, x_axis=x_axis, - recent_days=30, + recent_days=recent_weeks*7, sort_by=sort_by, sort_order=sort_order, marker_size=marker_size, diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index fc53a3a..674c4f1 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -61,6 +61,7 @@ 'auto_training_curriculum_name': 'Uncoupled Baiting', 'auto_training_curriculum_version': '1.0', 'auto_training_curriculum_schema_version': '1.0', + 'auto_training_history_recent_weeks': 4, } def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs): From f890c6b2de26d9a2d3382699b6cf824872c84141 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 11 Aug 2024 01:12:34 +0000 Subject: [PATCH 06/11] feat: add 446 room --- code/Home.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/code/Home.py b/code/Home.py index e963ad8..9867111 100644 --- a/code/Home.py +++ b/code/Home.py @@ -364,6 +364,8 @@ def _get_data_source(rig): room = '347' elif '447' in rig: room = '447' + elif '446' in rig: + room = '446' elif '323' in rig: room = '323' elif rig_type == 'ephys': From dcd6da70348da23eb745a42bc6c2ba940f0a125e Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 11 Aug 2024 02:02:42 +0000 Subject: [PATCH 07/11] feat: improve hover label for autotrain plot --- code/util/plot_autotrain_manager.py | 27 ++++++++++++++++++++------- code/util/url_query_helper.py | 2 +- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py index d761a4e..7537a03 100644 --- a/code/util/plot_autotrain_manager.py +++ b/code/util/plot_autotrain_manager.py @@ -1,5 +1,6 @@ from datetime import datetime +import streamlit as st import numpy as np import plotly.graph_objects as go import pandas as pd @@ -7,6 +8,9 @@ from aind_auto_train.schema.curriculum import TrainingStage from aind_auto_train.plot.curriculum import get_stage_color_mapper +# Get some additional metadata from the master table +df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']] +df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str) def plot_manager_all_progress(manager: 'AutoTrainManager', x_axis: ['session', 'date', @@ -73,6 +77,10 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0] else: h2o = None + + df_subject = df_subject.merge( + df_tmp_rig_user_name, + on=['subject_id', 'session_date'], how='left') # Handle open loop sessions open_loop_ids = df_subject.if_closed_loop == False @@ -100,7 +108,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', traces.append(go.Scattergl( x=x, - y=[n] * len(df_subject), + y=len(subject_ids) - np.array([n] * len(df_subject)), mode='markers', marker=dict( size=marker_size, @@ -115,12 +123,15 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', name=f'Mouse {subject_id}', hovertemplate=(f"Subject {subject_id} ({h2o})" "
Session %{customdata[0]}, %{customdata[1]}" - "
Curriculum: %{customdata[2]}_v%{customdata[3]}" + "
%{customdata[12]} @ %{customdata[11]}" + "
%{customdata[2]}_v%{customdata[3]}" "
Suggested: %{customdata[4]}" "
Actual: %{customdata[5]}" + f"
{'-'*10}" "
Session task: %{customdata[6]}" "
foraging_eff = %{customdata[7]}" "
finished_trials = %{customdata[8]}" + f"
{'-'*10}" "
Decision = %{customdata[9]}" "
Next suggested: %{customdata[10]}" ""), @@ -130,21 +141,23 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', df_subject.curriculum_name, df_subject.curriculum_version, df_subject.current_stage_suggested, - stage_actual, + stage_actual, # 5 df_subject.task, np.round(df_subject.foraging_efficiency, 3), df_subject.finished_trials, df_subject.decision, - df_subject.next_stage_suggested, + df_subject.next_stage_suggested, # 10 + df_subject.rig, # 11 + df_subject.user_name, # 12 ), axis=-1), showlegend=False ) ) - + # Add "x" for open loop sessions traces.append(go.Scattergl( x=x[open_loop_ids], - y=[n] * len(x[open_loop_ids]), + y=len(subject_ids) - np.array([n] * len(x[open_loop_ids])), mode='markers', marker=dict( size=marker_size*0.8, @@ -172,7 +185,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', tickmode='array', tickvals=np.arange(0, n + 1), # Original y-axis values ticktext=subject_ids, # New labels - autorange='reversed', + # autorange='reversed', zeroline=False, title='' ), diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index 674c4f1..12d2b76 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -56,7 +56,7 @@ 'session_plot_number_cols': 3, 'auto_training_history_x_axis': 'date', - 'auto_training_history_sort_by': 'last_date', + 'auto_training_history_sort_by': 'first_date', 'auto_training_history_sort_order': 'descending', 'auto_training_curriculum_name': 'Uncoupled Baiting', 'auto_training_curriculum_version': '1.0', From 1d6acca5c9ee9d3bc69b321954e787f6aecacfc5 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 11 Aug 2024 02:24:33 +0000 Subject: [PATCH 08/11] feat: also sort subject_id when the dates are the same --- code/util/plot_autotrain_manager.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py index 7537a03..c80b8d8 100644 --- a/code/util/plot_autotrain_manager.py +++ b/code/util/plot_autotrain_manager.py @@ -39,11 +39,15 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', if sort_by == 'subject_id': subject_ids = df_manager.subject_id.unique() elif sort_by == 'first_date': - subject_ids = df_manager.groupby('subject_id').session_date.min().sort_values( - ascending=sort_order == 'ascending').index + subject_ids = pd.DataFrame(df_manager.groupby('subject_id').session_date.min() + ).reset_index().sort_values( + by=['session_date', 'subject_id'], + ascending=sort_order == 'ascending').subject_id elif sort_by == 'last_date': - subject_ids = df_manager.groupby('subject_id').session_date.max().sort_values( - ascending=sort_order == 'ascending').index + subject_ids = pd.DataFrame(df_manager.groupby('subject_id').session_date.max() + ).reset_index().sort_values( + by=['session_date', 'subject_id'], + ascending=sort_order == 'ascending').subject_id elif sort_by == 'progress_to_graduated': manager.compute_stats() df_stats = manager.df_manager_stats @@ -183,13 +187,13 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', hovermode='closest', yaxis=dict( tickmode='array', - tickvals=np.arange(0, n + 1), # Original y-axis values - ticktext=subject_ids, # New labels + tickvals=np.arange(n+1, 0, -1), + ticktext=subject_ids, # autorange='reversed', zeroline=False, title='' ), - yaxis_range=[-0.5, n + 0.5], + yaxis_range=[-0.5, n + 2], ) # Limit x range to recent days if x is "date" From ee12dc2d8625d10d5477edc40ca3c9ae157deeba Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 11 Aug 2024 02:30:29 +0000 Subject: [PATCH 09/11] fix: highlight subject --- code/util/plot_autotrain_manager.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py index c80b8d8..f43c782 100644 --- a/code/util/plot_autotrain_manager.py +++ b/code/util/plot_autotrain_manager.py @@ -109,10 +109,12 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', # Cache x range xrange_min = x.min() if n == 0 else min(x.min(), xrange_min) xrange_max = x.max() if n == 0 else max(x.max(), xrange_max) + + y = len(subject_ids) - n # Y axis traces.append(go.Scattergl( x=x, - y=len(subject_ids) - np.array([n] * len(df_subject)), + y=[y] * len(df_subject), mode='markers', marker=dict( size=marker_size, @@ -161,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', # Add "x" for open loop sessions traces.append(go.Scattergl( x=x[open_loop_ids], - y=len(subject_ids) - np.array([n] * len(x[open_loop_ids])), + y=[y] * len(df_subject), mode='markers', marker=dict( size=marker_size*0.8, @@ -187,13 +189,13 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', hovermode='closest', yaxis=dict( tickmode='array', - tickvals=np.arange(n+1, 0, -1), + tickvals=np.arange(len(subject_ids), 0, -1), ticktext=subject_ids, - # autorange='reversed', + # autorange='reversed', # This will lead to a weird space at the top zeroline=False, title='' ), - yaxis_range=[-0.5, n + 2], + yaxis_range=[-0.5, len(subject_ids) + 1], ) # Limit x range to recent days if x is "date" @@ -204,11 +206,12 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', # Highight the selected subject for n, subject_id in enumerate(subject_ids): + y = len(subject_ids) - n # Y axis if subject_id in highlight_subjects: fig.add_shape( type="rect", - y0=n-0.5, - y1=n+0.5, + y0=y-0.5, + y1=y+0.5, x0=xrange_min - (1 if x_axis != 'date' else pd.Timedelta(days=1)), x1=xrange_max + (1 if x_axis != 'date' else pd.Timedelta(days=1)), line=dict( From 20c9c5f2e40ab98a2ef1a64ede1a4a3fe3ed9258 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 11 Aug 2024 02:33:01 +0000 Subject: [PATCH 10/11] feat: minor reorder --- code/util/streamlit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index c0b1324..acac0af 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -692,7 +692,7 @@ def add_auto_train_manager(): # -- Show plotly chart -- cols = st.columns([1, 1, 1, 0.7, 0.7, 1, 2]) - options = ["session", "date", "relative_date"] + options = ["date", "session", "relative_date"] x_axis = selectbox_wrapper_for_url_query( st_prefix=cols[0], label="X axis", @@ -701,7 +701,7 @@ def add_auto_train_manager(): key="auto_training_history_x_axis", ) - options = ["subject_id", "first_date", "last_date", "progress_to_graduated"] + options = ["first_date", "last_date", "subject_id", "progress_to_graduated"] sort_by = selectbox_wrapper_for_url_query( st_prefix=cols[1], label="Sort by", From 9ce6b696c3140aa7e585557e8af107348cfd4a3e Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 11 Aug 2024 02:34:20 +0000 Subject: [PATCH 11/11] ci: bump version --- code/Home.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/Home.py b/code/Home.py index 9867111..4fc3b6f 100644 --- a/code/Home.py +++ b/code/Home.py @@ -12,7 +12,7 @@ """ -__ver__ = 'v2.3.2' +__ver__ = 'v2.4.0' import pandas as pd import streamlit as st