From 802d39498c706e4853e5c3389c3e9d7ea209388d Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 30 Aug 2024 21:56:08 +0000 Subject: [PATCH 1/7] feat: add RL agent and gym task --- ...Playground.py => 1_Learning trajectory.py} | 0 code/pages/4_RL model playground.py | 50 +++++++++++++++++++ environment/Dockerfile | 4 +- requirements.txt | 4 +- 4 files changed, 56 insertions(+), 2 deletions(-) rename code/pages/{1_Playground.py => 1_Learning trajectory.py} (100%) create mode 100644 code/pages/4_RL model playground.py diff --git a/code/pages/1_Playground.py b/code/pages/1_Learning trajectory.py similarity index 100% rename from code/pages/1_Playground.py rename to code/pages/1_Learning trajectory.py diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py new file mode 100644 index 0000000..1008218 --- /dev/null +++ b/code/pages/4_RL model playground.py @@ -0,0 +1,50 @@ +"""Playground for RL models of dynamic foraging +""" + +import streamlit as st +import streamlit_nested_layout + +from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask +from aind_dynamic_foraging_models.generative_model import ( + ForagerQLearning, ForagerLossCounting, ForagerCollection +) + +try: + st.set_page_config(layout="wide", + page_title='Foraging behavior browser', + page_icon=':mouse2:', + menu_items={ + 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues", + 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/" + } + ) +except: + pass + +def app(): + # Initialize the model + forager = ForagerCollection().get_preset_forager("Hattori2019", seed=42) + forager.set_params( + softmax_inverse_temperature=5, + biasL=0, + ) + + # Create the task environment + task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=42) + + # Run the model + forager.perform(task) + + # Capture the results + ground_truth_params = forager.params.model_dump() + ground_truth_choice_prob = forager.choice_prob + ground_truth_q_value = forager.q_value + # Get the history + choice_history = forager.get_choice_history() + reward_history = forager.get_reward_history() + + # Plot the session results + fig, axes = forager.plot_session(if_plot_latent=True) + st.pyplot(fig) + +app() \ No newline at end of file diff --git a/environment/Dockerfile b/environment/Dockerfile index 4834dd3..4905d32 100644 --- a/environment/Dockerfile +++ b/environment/Dockerfile @@ -107,7 +107,9 @@ RUN pip install -U --no-cache-dir \ pygwalker==0.4.7 \ aind-data-access-api[docdb]==0.13.0 \ streamlit-dynamic-filters==0.1.9 \ - semver==3.0.2 + semver==3.0.2 \ + aind_dynamic_foraging_models \ + aind_behavior_gym ADD "https://github.com/coder/code-server/releases/download/v4.21.1/code-server-4.21.1-linux-amd64.tar.gz" /.code-server/code-server.tar.gz diff --git a/requirements.txt b/requirements.txt index 2c37526..005f005 100644 --- a/requirements.txt +++ b/requirements.txt @@ -96,4 +96,6 @@ git+https://github.com/AllenNeuralDynamics/aind-foraging-behavior-bonsai-automat pygwalker==0.4.7 aind-data-access-api[docdb]==0.13.0 streamlit-dynamic-filters==0.1.9 -semver==3.0.2 \ No newline at end of file +semver==3.0.2 +aind_dynamic_foraging_models +aind_behavior_gym \ No newline at end of file From 4142d801ac53116faec2cc6bef10000335877c09 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 30 Aug 2024 22:46:29 +0000 Subject: [PATCH 2/7] feat: get agent args --- code/pages/4_RL model playground.py | 51 ++++++++++++++++++++++++++--- environment/Dockerfile | 2 +- 
requirements.txt | 2 +- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py index 1008218..a15fa2f 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -3,11 +3,12 @@ import streamlit as st import streamlit_nested_layout +from typing import get_type_hints, _LiteralGenericAlias +import inspect from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask -from aind_dynamic_foraging_models.generative_model import ( - ForagerQLearning, ForagerLossCounting, ForagerCollection -) +from aind_dynamic_foraging_models import generative_model +from aind_dynamic_foraging_models.generative_model import ForagerCollection try: st.set_page_config(layout="wide", @@ -21,9 +22,51 @@ except: pass +model_families = { + "Q-learning": "ForagerQLearning", + "Loss counting": "ForagerLossCounting" +} + +def _get_agent_args(agent_class): + type_hints = get_type_hints(agent_class.__init__) + signature = inspect.signature(agent_class.__init__) + + agent_args_options = {} + for arg, type_hint in type_hints.items(): + if isinstance(type_hint, _LiteralGenericAlias): # Check if the type hint is a Literal + # Get options + literal_values = type_hint.__args__ + default_value = signature.parameters[arg].default + + agent_args_options[arg] = {'options': literal_values, 'default': default_value} + return agent_args_options + +def get_arg_params(agent_args_options): + agent_args = {} + # Select agent parameters + for n, arg_name in enumerate(agent_args_options.keys()): + agent_args[arg_name] = st.selectbox( + arg_name, + agent_args_options[arg_name]['options'] + ) + return agent_args + + def app(): + + # A dropdown to select model family + model_family = st.selectbox("Select model family", list(model_families.keys())) + + agent_class = getattr(generative_model, model_families[model_family]) + agent_args_options = _get_agent_args(agent_class) + agent_args = get_arg_params(agent_args_options) + # Initialize the model - forager = ForagerCollection().get_preset_forager("Hattori2019", seed=42) + forager_collection = ForagerCollection() + all_presets = forager_collection.FORAGER_PRESETS.keys() + + forager = forager_collection.get_preset_forager("Hattori2019", seed=42) + forager.set_params( softmax_inverse_temperature=5, biasL=0, diff --git a/environment/Dockerfile b/environment/Dockerfile index 4905d32..1d46388 100644 --- a/environment/Dockerfile +++ b/environment/Dockerfile @@ -108,7 +108,7 @@ RUN pip install -U --no-cache-dir \ aind-data-access-api[docdb]==0.13.0 \ streamlit-dynamic-filters==0.1.9 \ semver==3.0.2 \ - aind_dynamic_foraging_models \ + git+https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models.git@develop \ aind_behavior_gym diff --git a/requirements.txt b/requirements.txt index 005f005..58ef5fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -97,5 +97,5 @@ pygwalker==0.4.7 aind-data-access-api[docdb]==0.13.0 streamlit-dynamic-filters==0.1.9 semver==3.0.2 -aind_dynamic_foraging_models +git+https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models.git@develop aind_behavior_gym \ No newline at end of file From 540e03043bd775d52225213e64da5889d993721e Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 31 Aug 2024 01:27:01 +0000 Subject: [PATCH 3/7] feat: add params selectors --- code/pages/4_RL model playground.py | 81 +++++++++++++++++++++++++---- 1 file changed, 71 insertions(+), 10 deletions(-) diff --git a/code/pages/4_RL model playground.py 
b/code/pages/4_RL model playground.py index a15fa2f..780c319 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -9,6 +9,7 @@ from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask from aind_dynamic_foraging_models import generative_model from aind_dynamic_foraging_models.generative_model import ForagerCollection +from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols try: st.set_page_config(layout="wide", @@ -27,7 +28,12 @@ "Loss counting": "ForagerLossCounting" } -def _get_agent_args(agent_class): +para_range_override = { + "biasL": [-5.0, 5.0], + "choice_kernel_relative_weight": [0.0, 2.0], +} + +def _get_agent_args_options(agent_class): type_hints = get_type_hints(agent_class.__init__) signature = inspect.signature(agent_class.__init__) @@ -41,7 +47,38 @@ def _get_agent_args(agent_class): agent_args_options[arg] = {'options': literal_values, 'default': default_value} return agent_args_options -def get_arg_params(agent_args_options): + +def _get_params_options(param_model): + # Get the schema + params_schema = param_model.model_json_schema()["properties"] + + # Extract ge and le constraints + param_options = {} + + for para_name, para_field in params_schema.items(): + default = para_field.get("default", None) + para_desc = para_field.get("description", "") + + if para_name in para_range_override: + para_range = para_range_override[para_name] + else: # Get from pydantic schema + para_range = [-20, 20] # Default range + # Override the range if specified + if "minimum" in para_field: + para_range[0] = para_field["minimum"] + if "maximum" in para_field: + para_range[1] = para_field["maximum"] + para_range = [type(default)(x) for x in para_range] + + param_options[para_name] = dict( + para_range=para_range, + para_default=default, + para_symbol=ParamsSymbols[para_name], + para_desc=para_desc, + ) + return param_options + +def select_agent_args(agent_args_options): agent_args = {} # Select agent parameters for n, arg_name in enumerate(agent_args_options.keys()): @@ -51,27 +88,51 @@ def get_arg_params(agent_args_options): ) return agent_args +def select_params(params_options): + params = {} + # Select agent parameters + for n, para_name in enumerate(params_options.keys()): + para_range = params_options[para_name]['para_range'] + para_default = params_options[para_name]['para_default'] + para_symbol = params_options[para_name]['para_symbol'] + para_desc = params_options[para_name].get('para_desc', '') + + if para_range[0] == para_range[1]: # Fixed parameter + params[para_name] = para_range[0] + st.markdown(f"{para_symbol} ({para_name}) is fixed at {para_range[0]}") + else: + params[para_name] = st.slider( + f"{para_symbol} ({para_name})", + para_range[0], para_range[1], + para_default + ) + return params def app(): - # A dropdown to select model family + # -- Select agent family -- model_family = st.selectbox("Select model family", list(model_families.keys())) + # -- Select agent -- agent_class = getattr(generative_model, model_families[model_family]) - agent_args_options = _get_agent_args(agent_class) - agent_args = get_arg_params(agent_args_options) + agent_args_options = _get_agent_args_options(agent_class) + agent_args = select_agent_args(agent_args_options) - # Initialize the model - forager_collection = ForagerCollection() - all_presets = forager_collection.FORAGER_PRESETS.keys() - - forager = forager_collection.get_preset_forager("Hattori2019", seed=42) + # -- Select agent parameters -- + # Initialize the 
agent + forager = agent_class(**agent_args) + # Get params options + params_options = _get_params_options(forager.ParamModel) + params = select_params(params_options) forager.set_params( softmax_inverse_temperature=5, biasL=0, ) + # forager_collection = ForagerCollection() + # all_presets = forager_collection.FORAGER_PRESETS.keys() + # Create the task environment task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=42) From 0ee18f84336b79fd57ee3487b4bac7b3e6e9b638 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 31 Aug 2024 01:34:24 +0000 Subject: [PATCH 4/7] feat: add seed --- code/pages/4_RL model playground.py | 34 ++++++++++++++++++----------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py index 780c319..ccbdfb4 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -108,11 +108,7 @@ def select_params(params_options): ) return params -def app(): - - # -- Select agent family -- - model_family = st.selectbox("Select model family", list(model_families.keys())) - +def set_forager(model_family, seed=42): # -- Select agent -- agent_class = getattr(generative_model, model_families[model_family]) agent_args_options = _get_agent_args_options(agent_class) @@ -120,24 +116,36 @@ def app(): # -- Select agent parameters -- # Initialize the agent - forager = agent_class(**agent_args) + forager = agent_class(**agent_args, seed=seed) # Get params options params_options = _get_params_options(forager.ParamModel) params = select_params(params_options) + # Set the parameters + forager.set_params(**params) + return forager + +def app(): - forager.set_params( - softmax_inverse_temperature=5, - biasL=0, - ) + with st.sidebar: + seed = st.number_input("Random seed", value=42) + + + # -- Select forager family -- + agent_family = st.selectbox("Select model family", list(model_families.keys())) + + # -- Select forager -- + forager = set_forager(agent_family, seed=seed) # forager_collection = ForagerCollection() # all_presets = forager_collection.FORAGER_PRESETS.keys() # Create the task environment - task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=42) + task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=seed) - # Run the model + # -- Run the model -- forager.perform(task) + + if_plot_latent = st.checkbox("Plot latent variables", value=False) # Capture the results ground_truth_params = forager.params.model_dump() @@ -148,7 +156,7 @@ def app(): reward_history = forager.get_reward_history() # Plot the session results - fig, axes = forager.plot_session(if_plot_latent=True) + fig, axes = forager.plot_session(if_plot_latent=if_plot_latent) st.pyplot(fig) app() \ No newline at end of file From db53f2ad793f726a4d6a907d28130ea0f2edf0bb Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 31 Aug 2024 02:35:03 +0000 Subject: [PATCH 5/7] feat: task selector! 
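Expose task parameters as widgets for each task family. For the coupled
block task, every selected reward contrast "p1:p2" is normalized so that
the two reward probabilities sum to p_reward_sum. A worked example of the
normalization done in select_task below (the variable names here are only
for illustration):

    # contrast "1:3" with the default p_reward_sum = 0.45
    p1, p2 = map(int, "1:3".split(":"))
    pair = [p1 * 0.45 / (p1 + p2),   # 0.45 * 1/4 = 0.1125
            p2 * 0.45 / (p1 + p2)]   # 0.45 * 3/4 = 0.3375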
---
 code/pages/4_RL model playground.py | 105 ++++++++++++++++++++++++----
 1 file changed, 93 insertions(+), 12 deletions(-)

diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py
index ccbdfb4..adf334b 100644
--- a/code/pages/4_RL model playground.py
+++ b/code/pages/4_RL model playground.py
@@ -6,7 +6,9 @@
 from typing import get_type_hints, _LiteralGenericAlias
 import inspect
 
-from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask
+from aind_behavior_gym.dynamic_foraging.task import (
+    CoupledBlockTask, UncoupledBlockTask, RandomWalkTask
+    )
 from aind_dynamic_foraging_models import generative_model
 from aind_dynamic_foraging_models.generative_model import ForagerCollection
 from aind_dynamic_foraging_models.generative_model.params import ParamsSymbols
@@ -28,6 +30,12 @@
     "Loss counting": "ForagerLossCounting"
 }
 
+task_families = {
+    "Coupled block task": "CoupledBlockTask",
+    "Uncoupled block task": "UncoupledBlockTask",
+    "Random walk task": "RandomWalkTask",
+}
+
 para_range_override = {
     "biasL": [-5.0, 5.0],
     "choice_kernel_relative_weight": [0.0, 2.0],
@@ -108,7 +116,7 @@ def select_params(params_options):
         )
     return params
 
-def set_forager(model_family, seed=42):
+def select_forager(model_family, seed=42):
     # -- Select agent --
     agent_class = getattr(generative_model, model_families[model_family])
     agent_args_options = _get_agent_args_options(agent_class)
@@ -124,23 +132,96 @@ def set_forager(model_family, seed=42):
     forager.set_params(**params)
     return forager
 
+def select_task(task_family, reward_baiting, n_trials, seed):
+    # Task parameters (hard coded for now)
+    if task_family == "Coupled block task":
+        block_min, block_max = st.slider("Block length range", 0, 200, [40, 80])
+        block_beta = st.slider("Block beta", 0, 100, 20)
+        p_reward_contrast = st.multiselect(
+            "Reward contrasts",
+            options=["1:1", "1:3", "1:6", "1:8"],
+            default=["1:1", "1:3", "1:6", "1:8"],)
+        p_reward_sum = st.slider("p_reward sum", 0.0, 1.0, 0.45)
+
+        p_reward_pairs = []
+        for contrast in p_reward_contrast:
+            p1, p2 = contrast.split(":")
+            p1, p2 = int(p1), int(p2)
+            p_reward_pairs.append(
+                [p1 * p_reward_sum / (p1 + p2),
+                 p2 * p_reward_sum / (p1 + p2)]
+            )
+
+        # Create task
+        return CoupledBlockTask(
+            block_min=block_min,
+            block_max=block_max,
+            p_reward_pairs=p_reward_pairs,
+            block_beta=block_beta,
+            reward_baiting=reward_baiting,
+            num_trials=n_trials,
+            seed=seed)
+
+    if task_family == "Uncoupled block task":
+        block_min, block_max = st.slider(
+            "Block length range on each side", 0, 200, [20, 35]
+        )
+        rwd_prob_array = st.selectbox(
+            "Reward probabilities",
+            options=[[0.1, 0.5, 0.9],
+                     [0.1, 0.4, 0.7],
+                     [0.1, 0.3, 0.5],
+                     ],
+            index=0,
+        )
+        return UncoupledBlockTask(
+            rwd_prob_array=rwd_prob_array,
+            block_min=block_min,
+            block_max=block_max,
+            persev_add=True,
+            perseverative_limit=4,
+            max_block_tally=4,
+            num_trials=n_trials,
+            seed=seed,
+        )
+
+    if task_family == "Random walk task":
+        p_min, p_max = st.slider("p_reward range", 0.0, 1.0, [0.0, 1.0])
+        sigma = st.slider(r"Random walk $\sigma$", 0.0, 1.0, 0.15)
+        return RandomWalkTask(
+            p_min=[p_min, p_min],
+            p_max=[p_max, p_max],
+            sigma=[sigma, sigma],
+            mean=[0, 0],
+            num_trials=n_trials,
+            seed=seed,
+        )
+
 def app():
 
     with st.sidebar:
         seed = st.number_input("Random seed", value=42)
 
-
     # -- Select forager family --
     agent_family = st.selectbox("Select model family", list(model_families.keys()))
 
     # -- Select forager --
-    forager = set_forager(agent_family, seed=seed)
+    forager =
select_forager(agent_family, seed=seed) # forager_collection = ForagerCollection() # all_presets = forager_collection.FORAGER_PRESETS.keys() - # Create the task environment - task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=seed) + # -- Select task family -- + task_family = st.selectbox( + "Select task family", + list(task_families.keys()), + index=0, + ) + reward_baiting = st.checkbox("Reward baiting", value=True) + n_trials = st.slider("Number of trials", 100, 5000, 1000) + + # -- Select task -- + task = select_task(task_family, reward_baiting, n_trials, seed) # -- Run the model -- forager.perform(task) @@ -148,12 +229,12 @@ def app(): if_plot_latent = st.checkbox("Plot latent variables", value=False) # Capture the results - ground_truth_params = forager.params.model_dump() - ground_truth_choice_prob = forager.choice_prob - ground_truth_q_value = forager.q_value - # Get the history - choice_history = forager.get_choice_history() - reward_history = forager.get_reward_history() + # ground_truth_params = forager.params.model_dump() + # ground_truth_choice_prob = forager.choice_prob + # ground_truth_q_value = forager.q_value + # # Get the history + # choice_history = forager.get_choice_history() + # reward_history = forager.get_reward_history() # Plot the session results fig, axes = forager.plot_session(if_plot_latent=if_plot_latent) From 1585de20bf4ecd65e173a6565309297eeb47d181 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 31 Aug 2024 02:35:32 +0000 Subject: [PATCH 6/7] build: develop branch of gym --- environment/Dockerfile | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment/Dockerfile b/environment/Dockerfile index 1d46388..33ac566 100644 --- a/environment/Dockerfile +++ b/environment/Dockerfile @@ -109,7 +109,7 @@ RUN pip install -U --no-cache-dir \ streamlit-dynamic-filters==0.1.9 \ semver==3.0.2 \ git+https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models.git@develop \ - aind_behavior_gym + git+https://github.com/AllenNeuralDynamics/aind-behavior-gym.git@develop ADD "https://github.com/coder/code-server/releases/download/v4.21.1/code-server-4.21.1-linux-amd64.tar.gz" /.code-server/code-server.tar.gz diff --git a/requirements.txt b/requirements.txt index 58ef5fd..6573052 100644 --- a/requirements.txt +++ b/requirements.txt @@ -98,4 +98,4 @@ aind-data-access-api[docdb]==0.13.0 streamlit-dynamic-filters==0.1.9 semver==3.0.2 git+https://github.com/AllenNeuralDynamics/aind-dynamic-foraging-models.git@develop -aind_behavior_gym \ No newline at end of file +git+https://github.com/AllenNeuralDynamics/aind-behavior-gym.git@develop From 4b72779990514866f02dc8df767690a4ef85e8dd Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 31 Aug 2024 02:55:59 +0000 Subject: [PATCH 7/7] feat: widget organization --- code/pages/4_RL model playground.py | 72 ++++++++++++++++++----------- 1 file changed, 45 insertions(+), 27 deletions(-) diff --git a/code/pages/4_RL model playground.py b/code/pages/4_RL model playground.py index adf334b..1d70557 100644 --- a/code/pages/4_RL model playground.py +++ b/code/pages/4_RL model playground.py @@ -96,7 +96,10 @@ def select_agent_args(agent_args_options): ) return agent_args -def select_params(params_options): +def select_params(forager): + # Get params schema + params_options = _get_params_options(forager.ParamModel) + params = {} # Select agent parameters for n, para_name in enumerate(params_options.keys()): @@ -123,13 +126,8 @@ def select_forager(model_family, 
seed=42):
     # -- Select agent --
     agent_class = getattr(generative_model, model_families[model_family])
     agent_args_options = _get_agent_args_options(agent_class)
     agent_args = select_agent_args(agent_args_options)
 
     # -- Select agent parameters --
     # Initialize the agent
     forager = agent_class(**agent_args, seed=seed)
-    # Get params options
-    params_options = _get_params_options(forager.ParamModel)
-    params = select_params(params_options)
-    # Set the parameters
-    forager.set_params(**params)
     return forager
 
 def select_task(task_family, reward_baiting, n_trials, seed):
@@ -202,26 +200,45 @@ def app():
 
     with st.sidebar:
         seed = st.number_input("Random seed", value=42)
 
-    # -- Select forager family --
-    agent_family = st.selectbox("Select model family", list(model_families.keys()))
+    st.title("RL model playground")
 
-    # -- Select forager --
-    forager = select_forager(agent_family, seed=seed)
-
-    # forager_collection = ForagerCollection()
-    # all_presets = forager_collection.FORAGER_PRESETS.keys()
-
-    # -- Select task family --
-    task_family = st.selectbox(
-        "Select task family",
-        list(task_families.keys()),
-        index=0,
-    )
-    reward_baiting = st.checkbox("Reward baiting", value=True)
-    n_trials = st.slider("Number of trials", 100, 5000, 1000)
+    col0 = st.columns([1, 1])
+
+    with col0[0]:
+        with st.expander("Agent", expanded=True):
+            col1 = st.columns([1, 2])
+            with col1[0]:
+                # -- Select forager family --
+                agent_family = st.selectbox(
+                    "Select agent family",
+                    list(model_families.keys())
+                )
+                # -- Select forager --
+                forager = select_forager(agent_family, seed=seed)
+            with col1[1]:
+                # -- Select forager parameters --
+                params = select_params(forager)
+                # Set the parameters
+                forager.set_params(**params)
+
+    # forager_collection = ForagerCollection()
+    # all_presets = forager_collection.FORAGER_PRESETS.keys()
 
-    # -- Select task --
-    task = select_task(task_family, reward_baiting, n_trials, seed)
+    with col0[1]:
+        with st.expander("Task", expanded=True):
+            col1 = st.columns([1, 2])
+            with col1[0]:
+                # -- Select task family --
+                task_family = st.selectbox(
+                    "Select task family",
+                    list(task_families.keys()),
+                    index=0,
+                )
+                reward_baiting = st.checkbox("Reward baiting", value=True)
+                n_trials = st.slider("Number of trials", 100, 5000, 1000)
+            with col1[1]:
+                # -- Select task --
+                task = select_task(task_family, reward_baiting, n_trials, seed)
 
     # -- Run the model --
     forager.perform(task)
@@ -237,7 +254,8 @@ def app():
 
     if_plot_latent = st.checkbox("Plot latent variables", value=False)
 
     # Capture the results
     # ground_truth_params = forager.params.model_dump()
     # ground_truth_choice_prob = forager.choice_prob
     # ground_truth_q_value = forager.q_value
     # # Get the history
     # choice_history = forager.get_choice_history()
     # reward_history = forager.get_reward_history()
 
     # Plot the session results
-    fig, axes = forager.plot_session(if_plot_latent=if_plot_latent)
-    st.pyplot(fig)
+    fig, axes = forager.plot_session(if_plot_latent=if_plot_latent)
+    with st.columns([1, 0.5])[0]:
+        st.pyplot(fig)
 
 app()
\ No newline at end of file
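
A minimal end-to-end sketch of the series' final state without the
Streamlit UI. It assumes the develop branches of aind-behavior-gym and
aind-dynamic-foraging-models pinned above, and treats plot_session as
returning a matplotlib figure (which its use with st.pyplot implies);
the preset name and parameter values are taken from patch 1:

    from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask
    from aind_dynamic_foraging_models.generative_model import ForagerCollection

    # Preset "Hattori2019" forager; fixed seeds make the run reproducible
    forager = ForagerCollection().get_preset_forager("Hattori2019", seed=42)
    forager.set_params(softmax_inverse_temperature=5, biasL=0)

    # Coupled block task with baiting, matching the playground defaults
    task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=42)
    forager.perform(task)

    # Save the session figure to disk instead of rendering via st.pyplot
    fig, axes = forager.plot_session(if_plot_latent=True)
    fig.savefig("rl_playground_session.png")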