Skip to content

Commit

Permalink
enable cache for autotrain manager
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Dec 24, 2024
1 parent c21efa0 commit c8ae95b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
61 changes: 32 additions & 29 deletions code/util/plot_autotrain_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,33 @@
from aind_auto_train.schema.curriculum import TrainingStage


def plot_manager_all_progress(manager: 'AutoTrainManager',
x_axis: ['session', 'date',
'relative_date'] = 'session', # type: ignore
sort_by: ['subject_id', 'first_date',
'last_date', 'progress_to_graduated'] = 'subject_id',
sort_order: ['ascending',
'descending'] = 'descending',
recent_days: int=None,
marker_size=10,
marker_edge_width=2,
highlight_subjects=[],
if_show_fig=True
):


@st.cache_data(ttl=3600 * 24)
def plot_manager_all_progress(
x_axis: ["session", "date", "relative_date"] = "session", # type: ignore
sort_by: [
"subject_id",
"first_date",
"last_date",
"progress_to_graduated",
] = "subject_id",
sort_order: ["ascending", "descending"] = "descending",
recent_days: int = None,
marker_size=10,
marker_edge_width=2,
highlight_subjects=[],
if_show_fig=True,
):

manager = st.session_state.auto_train_manager

# %%
# Set default order
df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'],
ascending=[sort_order == 'ascending', False])

if not len(df_manager):
return None

# Get some additional metadata from the master table
df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']]
df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str)
Expand All @@ -51,18 +55,18 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
elif sort_by == 'progress_to_graduated':
manager.compute_stats()
df_stats = manager.df_manager_stats

# Sort by 'first_entry' of GRADUATED
subject_ids = df_stats.reset_index().set_index(
'subject_id'
).query(
f'current_stage_actual == "GRADUATED"'
)['first_entry'].sort_values(
ascending=sort_order != 'ascending').index.to_list()

# Append subjects that have not graduated
subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids]

else:
raise ValueError(
f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}')
Expand All @@ -71,17 +75,17 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
traces = []
for n, subject_id in enumerate(subject_ids):
df_subject = df_manager[df_manager['subject_id'] == subject_id]

# Get stage_color_mapper
stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__))

# Get h2o if available
if 'h2o' in manager.df_behavior:
h2o = manager.df_behavior[
manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0]
else:
h2o = None

df_subject = df_subject.merge(
df_tmp_rig_user_name,
on=['subject_id', 'session_date'], how='left')
Expand All @@ -105,11 +109,11 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
else:
raise ValueError(
f"x_axis can only be in ['session', 'date', 'relative_date']")

# Cache x range
xrange_min = x.min() if n == 0 else min(x.min(), xrange_min)
xrange_max = x.max() if n == 0 else max(x.max(), xrange_max)

y = len(subject_ids) - n # Y axis

traces.append(go.Scattergl(
Expand Down Expand Up @@ -159,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
showlegend=False
)
)

# Add "x" for open loop sessions
traces.append(go.Scattergl(
x=x[open_loop_ids],
Expand Down Expand Up @@ -197,14 +201,14 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
),
yaxis_range=[-0.5, len(subject_ids) + 1],
)

# Limit x range to recent days if x is "date"
if x_axis == 'date' and recent_days is not None:
# xrange_max = pd.Timestamp.today() # For unknown reasons, using this line will break both plotly_events and new st.plotly_chart callback...
xrange_max = pd.to_datetime(df_manager.session_date).max() + pd.Timedelta(days=1)
xrange_min = xrange_max - pd.Timedelta(days=recent_days)
fig.update_layout(xaxis_range=[xrange_min, xrange_max])

# Highlight the selected subject
for n, subject_id in enumerate(subject_ids):
y = len(subject_ids) - n # Y axis
Expand All @@ -222,7 +226,6 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
opacity=0.3,
layer="below"
)


# Show the plot
if if_show_fig:
Expand Down
3 changes: 1 addition & 2 deletions code/util/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def add_auto_train_manager():
recent_weeks = slider_wrapper_for_url_query(cols[5],
label="only recent weeks",
min_value=1,
max_value=26,
max_value=52,
step=1,
key='auto_training_history_recent_weeks',
default=8,
Expand All @@ -808,7 +808,6 @@ def add_auto_train_manager():
highlight_subjects = []

fig_auto_train = plot_manager_all_progress(
st.session_state.auto_train_manager,
x_axis=x_axis,
recent_days=recent_weeks*7,
sort_by=sort_by,
Expand Down

0 comments on commit c8ae95b

Please sign in to comment.