Skip to content

Commit

Permalink
Merge pull request #79 from AllenNeuralDynamics/han_improve_autotrain
Browse files Browse the repository at this point in the history
feat: Han improve autotrain
  • Loading branch information
hanhou authored Aug 11, 2024
2 parents a08a95e + 9ce6b69 commit 3fea99b
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 39 deletions.
8 changes: 5 additions & 3 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""

__ver__ = 'v2.3.2'
__ver__ = 'v2.4.0'

import pandas as pd
import streamlit as st
Expand Down Expand Up @@ -364,6 +364,8 @@ def _get_data_source(rig):
room = '347'
elif '447' in rig:
room = '447'
elif '446' in rig:
room = '446'
elif '323' in rig:
room = '323'
elif rig_type == 'ephys':
Expand Down Expand Up @@ -542,10 +544,10 @@ def app():
st.rerun()

chosen_id = stx.tab_bar(data=[
stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"),
stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"),
stx.TabBarItemData(id="tab_session_inspector", title="👀 Session Inspector", description="Select sessions from the table and show plots"),
stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"),
stx.TabBarItemData(id="tab_pygwalker", title="📊 PyGWalker (Tableau)", description="Interactive dataframe explorer"),
stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"),
stx.TabBarItemData(id="tab_auto_train_curriculum", title="📚 Automatic Training Curriculums", description="Collection of curriculums"),
# stx.TabBarItemData(id="tab_mouse_inspector", title="🐭 Mouse Inspector", description="Mouse-level summary"),
], default=st.query_params['tab_id'] if 'tab_id' in st.query_params
Expand Down
12 changes: 4 additions & 8 deletions code/util/aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pandas as pd
import streamlit as st

from .settings import draw_type_mapper_session_level
from .settings import (
draw_type_layout_definition, draw_type_mapper_session_level, draw_types_quick_preview
)

# --------------------------------------
data_sources = ['bonsai', 'bpod']
Expand Down Expand Up @@ -38,12 +40,6 @@ def load_data(tables=['sessions'], data_source = 'bonsai'):

def draw_session_plots_quick_preview(df_to_draw_session):

# Setting up layout for each session
layout_definition = [[1], # columns in the first row
[1, 1],
]
draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)']

container_session_all_in_one = st.container()

key = df_to_draw_session.to_dict(orient='records')[0]
Expand All @@ -59,7 +55,7 @@ def draw_session_plots_quick_preview(df_to_draw_session):
unsafe_allow_html=True)

rows = []
for row, column_setting in enumerate(layout_definition):
for row, column_setting in enumerate(draw_type_layout_definition):
rows.append(st.columns(column_setting))

for draw_type in draw_types_quick_preview:
Expand Down
60 changes: 44 additions & 16 deletions code/util/plot_autotrain_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from datetime import datetime

import streamlit as st
import numpy as np
import plotly.graph_objects as go
import pandas as pd

from aind_auto_train.schema.curriculum import TrainingStage
from aind_auto_train.plot.curriculum import get_stage_color_mapper

# Get some additional metadata from the master table
df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']]
df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str)

def plot_manager_all_progress(manager: 'AutoTrainManager',
x_axis: ['session', 'date',
Expand All @@ -15,6 +19,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
'last_date', 'progress_to_graduated'] = 'subject_id',
sort_order: ['ascending',
'descending'] = 'descending',
recent_days: int=None,
marker_size=10,
marker_edge_width=2,
highlight_subjects=[],
Expand All @@ -34,11 +39,15 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
if sort_by == 'subject_id':
subject_ids = df_manager.subject_id.unique()
elif sort_by == 'first_date':
subject_ids = df_manager.groupby('subject_id').session_date.min().sort_values(
ascending=sort_order == 'ascending').index
subject_ids = pd.DataFrame(df_manager.groupby('subject_id').session_date.min()
).reset_index().sort_values(
by=['session_date', 'subject_id'],
ascending=sort_order == 'ascending').subject_id
elif sort_by == 'last_date':
subject_ids = df_manager.groupby('subject_id').session_date.max().sort_values(
ascending=sort_order == 'ascending').index
subject_ids = pd.DataFrame(df_manager.groupby('subject_id').session_date.max()
).reset_index().sort_values(
by=['session_date', 'subject_id'],
ascending=sort_order == 'ascending').subject_id
elif sort_by == 'progress_to_graduated':
manager.compute_stats()
df_stats = manager.df_manager_stats
Expand Down Expand Up @@ -72,6 +81,10 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0]
else:
h2o = None

df_subject = df_subject.merge(
df_tmp_rig_user_name,
on=['subject_id', 'session_date'], how='left')

# Handle open loop sessions
open_loop_ids = df_subject.if_closed_loop == False
Expand All @@ -96,10 +109,12 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
# Cache x range
xrange_min = x.min() if n == 0 else min(x.min(), xrange_min)
xrange_max = x.max() if n == 0 else max(x.max(), xrange_max)

y = len(subject_ids) - n # Y axis

traces.append(go.Scattergl(
x=x,
y=[n] * len(df_subject),
y=[y] * len(df_subject),
mode='markers',
marker=dict(
size=marker_size,
Expand All @@ -114,12 +129,15 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
name=f'Mouse {subject_id}',
hovertemplate=(f"<b>Subject {subject_id} ({h2o})</b>"
"<br><b>Session %{customdata[0]}, %{customdata[1]}</b>"
"<br>Curriculum: <b>%{customdata[2]}_v%{customdata[3]}</b>"
"<br><b>%{customdata[12]} @ %{customdata[11]}</b>"
"<br><b>%{customdata[2]}_v%{customdata[3]}</b>"
"<br>Suggested: <b>%{customdata[4]}</b>"
"<br>Actual: <b>%{customdata[5]}</b>"
f"<br>{'-'*10}"
"<br>Session task: <b>%{customdata[6]}</b>"
"<br>foraging_eff = %{customdata[7]}"
"<br>finished_trials = %{customdata[8]}"
f"<br>{'-'*10}"
"<br>Decision = <b>%{customdata[9]}</b>"
"<br>Next suggested: <b>%{customdata[10]}</b>"
"<extra></extra>"),
Expand All @@ -129,21 +147,23 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
df_subject.curriculum_name,
df_subject.curriculum_version,
df_subject.current_stage_suggested,
stage_actual,
stage_actual, # 5
df_subject.task,
np.round(df_subject.foraging_efficiency, 3),
df_subject.finished_trials,
df_subject.decision,
df_subject.next_stage_suggested,
df_subject.next_stage_suggested, # 10
df_subject.rig, # 11
df_subject.user_name, # 12
), axis=-1),
showlegend=False
)
)

# Add "x" for open loop sessions
traces.append(go.Scattergl(
x=x[open_loop_ids],
y=[n] * len(x[open_loop_ids]),
y=[y] * len(df_subject),
mode='markers',
marker=dict(
size=marker_size*0.8,
Expand All @@ -169,21 +189,29 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
hovermode='closest',
yaxis=dict(
tickmode='array',
tickvals=np.arange(0, n + 1), # Original y-axis values
ticktext=subject_ids, # New labels
autorange='reversed',
tickvals=np.arange(len(subject_ids), 0, -1),
ticktext=subject_ids,
# autorange='reversed', # This will lead to a weird space at the top
zeroline=False,
title=''
)
),
yaxis_range=[-0.5, len(subject_ids) + 1],
)

# Limit x range to recent days if x is "date"
if x_axis == 'date' and recent_days is not None:
xrange_min = datetime.now() - pd.Timedelta(days=recent_days)
xrange_max = datetime.now()
fig.update_xaxes(range=[xrange_min, xrange_max])

# Highlight the selected subject
for n, subject_id in enumerate(subject_ids):
y = len(subject_ids) - n # Y axis
if subject_id in highlight_subjects:
fig.add_shape(
type="rect",
y0=n-0.5,
y1=n+0.5,
y0=y-0.5,
y1=y+0.5,
x0=xrange_min - (1 if x_axis != 'date' else pd.Timedelta(days=1)),
x1=xrange_max + (1 if x_axis != 'date' else pd.Timedelta(days=1)),
line=dict(
Expand Down
6 changes: 6 additions & 0 deletions code/util/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
]

logistic_regression_models = ["Su2022", "Bari2019", "Hattori2019", "Miller2021"]

draw_type_mapper_session_level = {
"1. Choice history": (
"choice_history", # prefix
Expand All @@ -27,3 +28,8 @@
for n, model in enumerate(logistic_regression_models)
},
}

# For quick preview
draw_types_quick_preview = [
'1. Choice history',
'3. Logistic regression (Su2022)']
19 changes: 16 additions & 3 deletions code/util/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,8 +691,8 @@ def add_auto_train_manager():
df_training_manager = st.session_state.auto_train_manager.df_manager

# -- Show plotly chart --
cols = st.columns([1, 1, 1, 0.7, 0.7, 3])
options = ["session", "date", "relative_date"]
cols = st.columns([1, 1, 1, 0.7, 0.7, 1, 2])
options = ["date", "session", "relative_date"]
x_axis = selectbox_wrapper_for_url_query(
st_prefix=cols[0],
label="X axis",
Expand All @@ -701,7 +701,7 @@ def add_auto_train_manager():
key="auto_training_history_x_axis",
)

options = ["subject_id", "first_date", "last_date", "progress_to_graduated"]
options = ["first_date", "last_date", "subject_id", "progress_to_graduated"]
sort_by = selectbox_wrapper_for_url_query(
st_prefix=cols[1],
label="Sort by",
Expand All @@ -721,6 +721,16 @@ def add_auto_train_manager():

marker_size = cols[3].number_input('Marker size', value=15, step=1)
marker_edge_width = cols[4].number_input('Marker edge width', value=3, step=1)

recent_weeks = slider_wrapper_for_url_query(cols[5],
label="only recent weeks",
min_value=1,
max_value=26,
step=1,
key='auto_training_history_recent_weeks',
default=8,
disabled=x_axis != 'date',
)

# Get highlighted subjects
if ('filter_subject_id' in st.session_state and st.session_state['filter_subject_id']) or\
Expand All @@ -734,6 +744,7 @@ def add_auto_train_manager():
fig_auto_train = plot_manager_all_progress(
st.session_state.auto_train_manager,
x_axis=x_axis,
recent_days=recent_weeks*7,
sort_by=sort_by,
sort_order=sort_order,
marker_size=marker_size,
Expand All @@ -748,6 +759,8 @@ def add_auto_train_manager():
),
font=dict(size=18),
height=30 * len(df_training_manager.subject_id.unique()),
xaxis_side='top',
title='',
)

cols = st.columns([2, 1])
Expand Down
7 changes: 4 additions & 3 deletions code/util/url_query_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

'table_height': 300,

'tab_id': 'tab_session_x_y',
'tab_id': 'tab_auto_train_history',
'x_y_plot_xname': 'session',
'x_y_plot_yname': 'foraging_performance_random_seed',
'x_y_plot_group_by': 'h2o',
Expand Down Expand Up @@ -55,12 +55,13 @@
'session_plot_selected_draw_types': list(draw_type_mapper_session_level.keys()),
'session_plot_number_cols': 3,

'auto_training_history_x_axis': 'session',
'auto_training_history_sort_by': 'subject_id',
'auto_training_history_x_axis': 'date',
'auto_training_history_sort_by': 'first_date',
'auto_training_history_sort_order': 'descending',
'auto_training_curriculum_name': 'Uncoupled Baiting',
'auto_training_curriculum_version': '1.0',
'auto_training_curriculum_schema_version': '1.0',
'auto_training_history_recent_weeks': 4,
}

def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs):
Expand Down
12 changes: 6 additions & 6 deletions environment/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ RUN pip install -U --no-cache-dir \
pynwb --ignore-installed ruamel.yaml\
git+https://github.com/AllenNeuralDynamics/aind-foraging-behavior-bonsai-automatic-training.git@main\
pygwalker \
scikit-learn==1.4.1
scikit-learn==1.4

ADD "https://github.com/coder/code-server/releases/download/v4.21.1/code-server-4.21.1-linux-amd64.tar.gz" /.code-server/code-server.tar.gz

ADD "https://github.com/coder/code-server/releases/download/v4.9.0/code-server-4.9.0-linux-amd64.tar.gz" /.code-server/code-server.tar.gz

RUN cd /.code-server \
&& tar -xvf code-server.tar.gz \
&& rm code-server.tar.gz \
&& ln -s /.code-server/code-server-4.9.0-linux-amd64/bin/code-server /usr/bin/code-server
&& tar -xvf code-server.tar.gz \
&& rm code-server.tar.gz \
&& ln -s /.code-server/code-server-4.21.1-linux-amd64/bin/code-server /usr/bin/code-server


COPY postInstall /
Expand Down
1 change: 1 addition & 0 deletions environment/postInstall
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ if [ ! -d "/.vscode/extensions" ]
code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension njpwerner.autodocstring
code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension KevinRose.vsc-python-indent
code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension sugatoray.vscode-git-extension-pack
code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension mhutchie.git-graph

else
echo "code-server not found"
Expand Down

0 comments on commit 3fea99b

Please sign in to comment.