diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py index b5bb8a2..bf8196f 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -1,8 +1,6 @@ """Playground for RL models of dynamic foraging """ -import inspect -from typing import _LiteralGenericAlias, get_type_hints import streamlit as st import streamlit_nested_layout @@ -12,6 +10,7 @@ from aind_dynamic_foraging_models import generative_model from aind_dynamic_foraging_models.generative_model import ForagerCollection from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols +from aind_dynamic_foraging_models.generative_model.params.util import get_params_options from aind_dynamic_foraging_basic_analysis import compute_foraging_efficiency try: @@ -37,54 +36,6 @@ "Random walk task": "RandomWalkTask", } -para_range_override = { - "choice_kernel_relative_weight": [0.0, 2.0], -} - -def _get_agent_args_options(agent_class): - type_hints = get_type_hints(agent_class.__init__) - signature = inspect.signature(agent_class.__init__) - - agent_args_options = {} - for arg, type_hint in type_hints.items(): - if isinstance(type_hint, _LiteralGenericAlias): # Check if the type hint is a Literal - # Get options - literal_values = type_hint.__args__ - default_value = signature.parameters[arg].default - - agent_args_options[arg] = {'options': literal_values, 'default': default_value} - return agent_args_options - - -def _get_params_options(param_model): - # Get the schema - params_schema = param_model.model_json_schema()["properties"] - - # Extract ge and le constraints - param_options = {} - - for para_name, para_field in params_schema.items(): - default = para_field.get("default", None) - para_desc = para_field.get("description", "") - - if para_name in para_range_override: - para_range = para_range_override[para_name] - else: # Get from pydantic schema - para_range = [-20, 20] # Default range - # Override the range if specified - if "minimum" in para_field: - para_range[0] = para_field["minimum"] - if "maximum" in para_field: - para_range[1] = para_field["maximum"] - para_range = [type(default)(x) for x in para_range] - - param_options[para_name] = dict( - para_range=para_range, - para_default=default, - para_symbol=ParamsSymbols[para_name], - para_desc=para_desc, - ) - return param_options def select_agent_args(agent_args_options): agent_args = {} @@ -98,7 +49,12 @@ def select_agent_args(agent_args_options): def select_params(forager): # Get params schema - params_options = _get_params_options(forager.ParamModel) + params_options = get_params_options( + forager.ParamModel, + default_range=[-20, 20], + para_range_override={ + "choice_kernel_relative_weight": [0.0, 2.0], + }) params = {} # Select agent parameters @@ -122,7 +78,7 @@ def select_params(forager): def select_forager(model_family, seed=42): # -- Select agent -- agent_class = getattr(generative_model, model_families[model_family]) - agent_args_options = _get_agent_args_options(agent_class) + agent_args_options = ForagerCollection()._get_agent_kwargs_options(agent_class) agent_args = select_agent_args(agent_args_options) # -- Select agent parameters -- @@ -289,5 +245,17 @@ def app(): with col0[1]: st.write(f"#### **Foraging efficiency**:") st.write(f"# {foraging_eff_random_seed:.3f}") + + # --- Show all foragers --- + st.markdown("---") + st.markdown(f"### All available foragers") + df = ForagerCollection().get_all_foragers() + df.drop(columns=["agent_kwargs", "forager"], inplace=True) + first_cols = ["agent_class_name", "preset_name", "params", "n_free_params"] + df = df[ + first_cols + [col for col in df.columns if col not in first_cols] + ] # Reorder + + st.markdown(df.to_markdown(index=True)) app() \ No newline at end of file