Skip to content

Commit

Permalink
Merge pull request #59 from AllenNeuralDynamics/han_dim_reduction
Browse files Browse the repository at this point in the history
feat: dim reduction
  • Loading branch information
hanhou authored Apr 8, 2024
2 parents eb460fb + 1b55d28 commit ecc429c
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 1 deletion.
143 changes: 143 additions & 0 deletions code/pages/3_Playground.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import s3fs
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from streamlit_plotly_events import plotly_events

from util.streamlit import add_session_filter, data_selector

# Shorthand for Streamlit's per-user session state (populated by the Home page).
ss = st.session_state

# S3 filesystem handle; anon=False means AWS credentials must be available in the env.
fs = s3fs.S3FileSystem(anon=False)
# Root of the processed bonsai behavior data. Used as an S3 prefix when
# st.session_state.use_s3 is set, otherwise as a local path (see load_data).
cache_folder = 'aind-behavior-data/foraging_nwb_bonsai_processed/'

@st.cache_data(ttl=24*3600)
def load_data(tables=('sessions',)):
    """Load pickled dataframes from S3 (or a local path) into a dict.

    Parameters
    ----------
    tables : iterable of str
        Table names; each is read from ``{cache_folder}df_{table}.pkl``.
        A tuple default is used instead of a mutable list so the default
        cannot be accidentally mutated and hashes cleanly for the cache.

    Returns
    -------
    dict
        Maps ``f'{table}_bonsai'`` to the unpickled ``pd.DataFrame``.

    Results are cached by Streamlit for 24 hours.
    """
    df = {}
    for table in tables:
        file_name = cache_folder + f'df_{table}.pkl'
        if st.session_state.use_s3:
            # Stream the pickle directly from S3.
            with fs.open(file_name) as f:
                df[table + '_bonsai'] = pd.read_pickle(f)
        else:
            # Fall back to reading the same relative path from local disk.
            df[table + '_bonsai'] = pd.read_pickle(file_name)
    return df

def app():
    """Page entry point: build the sidebar, then run PCA on the filtered sessions."""

    with st.sidebar:
        add_session_filter(if_bonsai=True)
        data_selector()

    # Guard clause: data is loaded on the Home page first.
    if not hasattr(ss, 'df'):
        st.write('##### Data not loaded yet, start from Home:')
        st.page_link('Home.py', label='Home', icon="🏠")
        return

    df = load_data()['sessions_bonsai']

    # -- get cols --
    # Metadata columns, excluding housekeeping fields.
    # NOTE: a local name other than `ss` is used in the comprehensions so the
    # module-level session-state alias is not shadowed.
    excluded_words = ('lickspout', 'weight', 'water', 'time', 'rig',
                      'user_name', 'experiment', 'task', 'notes')
    col_task = [col for col in df.metadata.columns
                if all(word not in col for word in excluded_words)]

    col_perf = [col for col in df.session_stats.columns
                if all(word not in col for word in ('performance',))]

    do_pca(ss.df_session_filtered.loc[:, ['subject_id', 'session'] + col_perf], 'performance')
    do_pca(ss.df_session_filtered.loc[:, ['subject_id', 'session'] + col_task], 'task')


def do_pca(df, name):
    """Run PCA on the numeric columns of `df` and render three Plotly figures.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain 'subject_id' and 'session' columns; every other column
        is treated as a numeric feature.
    name : str
        Title for the 3-D trajectory figure.

    Renders (via st.plotly_chart):
      1. a 3-D per-mouse trajectory in the first three PCs,
      2. the cumulative explained-variance curve,
      3. bar charts of the feature weights of the first three PCs.
    """
    # PCA cannot handle NaNs: drop any session with a missing feature value.
    df = df.dropna(axis=0, how='any')

    # Standardize the features (zero mean, unit variance) so columns with
    # large numeric ranges do not dominate the decomposition.
    x = StandardScaler().fit_transform(df.drop(columns=['subject_id', 'session']))

    # Apply PCA, keeping up to 10 components. Clamp to the data size:
    # sklearn raises if n_components > min(n_samples, n_features).
    n_components = min(10, x.shape[0], x.shape[1])
    pca = PCA(n_components=n_components)
    principalComponents = pca.fit_transform(x)

    # Create a new DataFrame with the principal components, keyed by
    # (subject_id, session) so rows stay aligned with the input.
    principalDf = pd.DataFrame(data=principalComponents)
    principalDf.index = df.set_index(['subject_id', 'session']).index

    principalDf.reset_index(inplace=True)

    # -- trajectory: one lines+markers trace per mouse in PC1-3 space --
    fig = go.Figure()

    for mouse_id in principalDf['subject_id'].unique():
        subset = principalDf[principalDf['subject_id'] == mouse_id]

        # Add a 3D scatter plot for the current group
        fig.add_trace(go.Scatter3d(
            x=subset[0],
            y=subset[1],
            z=subset[2],
            mode='lines+markers',
            # Marker size grows linearly with session number (5 at session 0,
            # +15 per 20 sessions) so training progress is visible.
            marker=dict(size=subset['session'].apply(
                lambda x: 5 + 15*(x/20))),
            name=f'{mouse_id}',  # Name the trace for the legend
        ))

    fig.update_layout(title=name,
                      scene=dict(
                          xaxis_title='Dim1',
                          yaxis_title='Dim2',
                          zaxis_title='Dim3'
                      ),
                      width=1300,
                      height=1000,
                      font_size=15,
                      )
    st.plotly_chart(fig)

    # -- variance explained (cumulative) --
    var_explained = pca.explained_variance_ratio_
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=np.arange(1, len(var_explained)+1),
        y=np.cumsum(var_explained),
    )
    )
    fig.update_layout(title='Variance Explained',
                      yaxis=dict(range=[0, 1]),
                      width=300,
                      height=400,
                      font_size=15,
                      )
    st.plotly_chart(fig)

    # -- pca components (feature weights) --
    pca_components = pd.DataFrame(pca.components_,
                                  columns=df.drop(columns=['subject_id', 'session']).columns)
    pca_components  # Streamlit "magic": bare expression renders the dataframe
    fig = make_subplots(rows=3, cols=1)

    # In vertical subplots, each subplot shows the feature weights of one
    # principal component.
    for i in range(3):
        fig.add_trace(go.Bar(
            x=pca_components.columns,
            y=pca_components.loc[i],
            name=f'PC{i+1}',
        ), row=i+1, col=1)

        # Only the bottom subplot keeps its x tick labels.
        fig.update_xaxes(showticklabels=i == 2, row=i+1, col=1)

    fig.update_layout(title='PCA weights',
                      width=1000,
                      height=800,
                      font_size=20,
                      )
    st.plotly_chart(fig)

app()
3 changes: 2 additions & 1 deletion environment/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ RUN pip install -U --no-cache-dir \
seaborn \
pynwb --ignore-installed ruamel.yaml\
git+https://github.com/AllenNeuralDynamics/aind-foraging-behavior-bonsai-automatic-training.git@main\
pygwalker
pygwalker \
scikit-learn==1.4.1

ADD "https://github.com/coder/code-server/releases/download/v4.9.0/code-server-4.9.0-linux-amd64.tar.gz" /.code-server/code-server.tar.gz

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.7
s3fs==2022.11.0
scipy==1.10.0
scikit-learn==1.3.2
seaborn==0.11.2
semver==2.13.0
six==1.16.0
Expand Down

0 comments on commit ecc429c

Please sign in to comment.