Skip to content

Commit

Permalink
feat: refactor and add all forager list
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Sep 5, 2024
1 parent d2f7a53 commit 62b9b6a
Showing 1 changed file with 20 additions and 52 deletions.
72 changes: 20 additions & 52 deletions code/pages/4_RL model playground.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Playground for RL models of dynamic foraging
"""

import inspect
from typing import _LiteralGenericAlias, get_type_hints

import streamlit as st
import streamlit_nested_layout
Expand All @@ -12,6 +10,7 @@
from aind_dynamic_foraging_models import generative_model
from aind_dynamic_foraging_models.generative_model import ForagerCollection
from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols
from aind_dynamic_foraging_models.generative_model.params.util import get_params_options
from aind_dynamic_foraging_basic_analysis import compute_foraging_efficiency

try:
Expand All @@ -37,54 +36,6 @@
"Random walk task": "RandomWalkTask",
}

para_range_override = {
"choice_kernel_relative_weight": [0.0, 2.0],
}

def _get_agent_args_options(agent_class):
type_hints = get_type_hints(agent_class.__init__)
signature = inspect.signature(agent_class.__init__)

agent_args_options = {}
for arg, type_hint in type_hints.items():
if isinstance(type_hint, _LiteralGenericAlias): # Check if the type hint is a Literal
# Get options
literal_values = type_hint.__args__
default_value = signature.parameters[arg].default

agent_args_options[arg] = {'options': literal_values, 'default': default_value}
return agent_args_options


def _get_params_options(param_model):
# Get the schema
params_schema = param_model.model_json_schema()["properties"]

# Extract ge and le constraints
param_options = {}

for para_name, para_field in params_schema.items():
default = para_field.get("default", None)
para_desc = para_field.get("description", "")

if para_name in para_range_override:
para_range = para_range_override[para_name]
else: # Get from pydantic schema
para_range = [-20, 20] # Default range
# Override the range if specified
if "minimum" in para_field:
para_range[0] = para_field["minimum"]
if "maximum" in para_field:
para_range[1] = para_field["maximum"]
para_range = [type(default)(x) for x in para_range]

param_options[para_name] = dict(
para_range=para_range,
para_default=default,
para_symbol=ParamsSymbols[para_name],
para_desc=para_desc,
)
return param_options

def select_agent_args(agent_args_options):
agent_args = {}
Expand All @@ -98,7 +49,12 @@ def select_agent_args(agent_args_options):

def select_params(forager):
# Get params schema
params_options = _get_params_options(forager.ParamModel)
params_options = get_params_options(
forager.ParamModel,
default_range=[-20, 20],
para_range_override={
"choice_kernel_relative_weight": [0.0, 2.0],
})

params = {}
# Select agent parameters
Expand All @@ -122,7 +78,7 @@ def select_params(forager):
def select_forager(model_family, seed=42):
# -- Select agent --
agent_class = getattr(generative_model, model_families[model_family])
agent_args_options = _get_agent_args_options(agent_class)
agent_args_options = ForagerCollection()._get_agent_kwargs_options(agent_class)
agent_args = select_agent_args(agent_args_options)

# -- Select agent parameters --
Expand Down Expand Up @@ -289,5 +245,17 @@ def app():
with col0[1]:
st.write(f"#### **Foraging efficiency**:")
st.write(f"# {foraging_eff_random_seed:.3f}")

# --- Show all foragers ---
st.markdown("---")
st.markdown(f"### All available foragers")
df = ForagerCollection().get_all_foragers()
df.drop(columns=["agent_kwargs", "forager"], inplace=True)
first_cols = ["agent_class_name", "preset_name", "params", "n_free_params"]
df = df[
first_cols + [col for col in df.columns if col not in first_cols]
] # Reorder

st.markdown(df.to_markdown(index=True))

app()

0 comments on commit 62b9b6a

Please sign in to comment.