From c8ae95bff5854d2cce2052c0fb50117006cff28b Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Tue, 24 Dec 2024 04:02:23 +0000 Subject: [PATCH] enable cache for autotrain manager --- code/util/plot_autotrain_manager.py | 61 +++++++++++++++-------------- code/util/streamlit.py | 3 +- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py index addf0b9..f5e0209 100644 --- a/code/util/plot_autotrain_manager.py +++ b/code/util/plot_autotrain_manager.py @@ -8,29 +8,33 @@ from aind_auto_train.schema.curriculum import TrainingStage -def plot_manager_all_progress(manager: 'AutoTrainManager', - x_axis: ['session', 'date', - 'relative_date'] = 'session', # type: ignore - sort_by: ['subject_id', 'first_date', - 'last_date', 'progress_to_graduated'] = 'subject_id', - sort_order: ['ascending', - 'descending'] = 'descending', - recent_days: int=None, - marker_size=10, - marker_edge_width=2, - highlight_subjects=[], - if_show_fig=True - ): - - +@st.cache_data(ttl=3600 * 24) +def plot_manager_all_progress( + x_axis: ["session", "date", "relative_date"] = "session", # type: ignore + sort_by: [ + "subject_id", + "first_date", + "last_date", + "progress_to_graduated", + ] = "subject_id", + sort_order: ["ascending", "descending"] = "descending", + recent_days: int = None, + marker_size=10, + marker_edge_width=2, + highlight_subjects=[], + if_show_fig=True, +): + + manager = st.session_state.auto_train_manager + # %% # Set default order df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'], ascending=[sort_order == 'ascending', False]) - + if not len(df_manager): return None - + # Get some additional metadata from the master table df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']] df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str) @@ -51,7 +55,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', elif sort_by == 'progress_to_graduated': manager.compute_stats() df_stats = manager.df_manager_stats - + # Sort by 'first_entry' of GRADUATED subject_ids = df_stats.reset_index().set_index( 'subject_id' @@ -59,10 +63,10 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', f'current_stage_actual == "GRADUATED"' )['first_entry'].sort_values( ascending=sort_order != 'ascending').index.to_list() - + # Append subjects that have not graduated subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids] - + else: raise ValueError( f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}') @@ -71,17 +75,17 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', traces = [] for n, subject_id in enumerate(subject_ids): df_subject = df_manager[df_manager['subject_id'] == subject_id] - + # Get stage_color_mapper stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__)) - + # Get h2o if available if 'h2o' in manager.df_behavior: h2o = manager.df_behavior[ manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0] else: h2o = None - + df_subject = df_subject.merge( df_tmp_rig_user_name, on=['subject_id', 'session_date'], how='left') @@ -105,11 +109,11 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', else: raise ValueError( f"x_axis can only be in ['session', 'date', 'relative_date']") - + # Cache x range xrange_min = x.min() if n == 0 else min(x.min(), xrange_min) xrange_max = x.max() if n == 0 else max(x.max(), xrange_max) - + y = len(subject_ids) - n # Y axis traces.append(go.Scattergl( @@ -159,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', showlegend=False ) ) - + # Add "x" for open loop sessions traces.append(go.Scattergl( x=x[open_loop_ids], @@ -197,14 +201,14 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', ), yaxis_range=[-0.5, len(subject_ids) + 1], ) - + # Limit x range to recent days if x is "date" if x_axis == 'date' and recent_days is not None: # xrange_max = pd.Timestamp.today() # For unknown reasons, using this line will break both plotly_events and new st.plotly_chart callback... xrange_max = pd.to_datetime(df_manager.session_date).max() + pd.Timedelta(days=1) xrange_min = xrange_max - pd.Timedelta(days=recent_days) fig.update_layout(xaxis_range=[xrange_min, xrange_max]) - + # Highight the selected subject for n, subject_id in enumerate(subject_ids): y = len(subject_ids) - n # Y axis @@ -222,7 +226,6 @@ def plot_manager_all_progress(manager: 'AutoTrainManager', opacity=0.3, layer="below" ) - # Show the plot if if_show_fig: diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 5ef2b06..d1a3bec 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -791,7 +791,7 @@ def add_auto_train_manager(): recent_weeks = slider_wrapper_for_url_query(cols[5], label="only recent weeks", min_value=1, - max_value=26, + max_value=52, step=1, key='auto_training_history_recent_weeks', default=8, @@ -808,7 +808,6 @@ def add_auto_train_manager(): highlight_subjects = [] fig_auto_train = plot_manager_all_progress( - st.session_state.auto_train_manager, x_axis=x_axis, recent_days=recent_weeks*7, sort_by=sort_by,