From 27685672a0bd92557fe47c4dce7abf049e6d3677 Mon Sep 17 00:00:00 2001 From: hnakashima Date: Sun, 26 Nov 2023 11:52:02 +0900 Subject: [PATCH] Adding Q-Learning agent class and refactoring for consistency. Implemented a new QLearningAgent class in grid_world/methods/temporal_difference/q_learning.py that uses the Q-learning algorithm. The additions include action-value estimates and Q-learning memory alongside better defined environments. Refactored the entirety of agent's memory methods (add_memory and reset_memory) in QLearningAgent, SarsaAgentBase, and TdAgent for better consistency and clarity. Adjusted several return type in GridWorld and GridWorldRender. The methods add_memory and reset_memory in q_learning.py and sarsa_agent.py now have specific and improved docstrings. Included a new method, run_q_learning_episode, in agent_episodes.py to accommodate the Q-learning agent. Introduced a test module for the QLearningAgent class in tests to ensure accurate implementation of the Q-learning algorithm. These changes enhance the functionality of the grid_world environment, including reinforcement learning capabilities which can facilitate experimental research. --- .../grid_world/agent_base.py | 4 +- .../grid_world/environment.py | 5 + .../methods/monte_carlo/mc_agent.py | 11 +- .../temporal_difference/agent_episodes.py | 33 +++++- .../methods/temporal_difference/q_learning.py | 111 ++++++++++++++++++ .../temporal_difference/sarsa_agent.py | 9 +- .../methods/temporal_difference/td_eval.py | 11 +- .../grid_world/render.py | 6 +- .../temporal_difference/test_q_learning.py | 65 ++++++++++ .../{methods => }/test_environment.py | 2 + 10 files changed, 245 insertions(+), 12 deletions(-) create mode 100644 reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/q_learning.py create mode 100644 tests/markov_dicision_process/grid_world/methods/temporal_difference/test_q_learning.py rename tests/markov_dicision_process/grid_world/{methods => }/test_environment.py (96%) diff --git a/reinforcement_learning/markov_decision_process/grid_world/agent_base.py b/reinforcement_learning/markov_decision_process/grid_world/agent_base.py index 73c6148..cd77bf1 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/agent_base.py +++ b/reinforcement_learning/markov_decision_process/grid_world/agent_base.py @@ -31,7 +31,7 @@ def rng(self: Self) -> np.random.Generator: @abstractmethod def get_action(self: Self, *, state: State) -> Action: - """Select an action based on policy `self.__b`. + """Select an action. Args: ---- @@ -56,7 +56,7 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: Actio @abstractmethod def reset_memory(self: Self) -> None: - """Clear the memory of the reinforcement learning agent.""" + """Reset the agent's memory.""" @abstractmethod def update(self: Self) -> None: diff --git a/reinforcement_learning/markov_decision_process/grid_world/environment.py b/reinforcement_learning/markov_decision_process/grid_world/environment.py index 995eb47..d862901 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/environment.py +++ b/reinforcement_learning/markov_decision_process/grid_world/environment.py @@ -173,6 +173,11 @@ def agent_state(self: Self) -> State: """Return the current state of the agent.""" return self.__agent_state + @property + def wall_states(self: Self) -> frozenset[State]: + """Return the wall states of the GridWorld.""" + return self.__wall_states + @property def height(self: Self) -> int: """Return the height of the grid in the GridWorld object. diff --git a/reinforcement_learning/markov_decision_process/grid_world/methods/monte_carlo/mc_agent.py b/reinforcement_learning/markov_decision_process/grid_world/methods/monte_carlo/mc_agent.py index 32da446..49a9965 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/methods/monte_carlo/mc_agent.py +++ b/reinforcement_learning/markov_decision_process/grid_world/methods/monte_carlo/mc_agent.py @@ -47,7 +47,14 @@ def memories(self: Self) -> tuple[McMemory, ...]: @final def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: - """Add a new experience into the memory.""" + """Add a new experience into the memory. + + Args: + ---- + state: The current state of the agent. + action: The action taken by the agent. + result: The result of the action taken by the agent. + """ if action is None or result is None: if action is None and result is None: message = "action or result must not be None" @@ -61,7 +68,7 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: Actio @final def reset_memory(self: Self) -> None: - """Clear the memory of the reinforcement learning agent.""" + """Reset the agent's memory.""" self.__memories.clear() diff --git a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/agent_episodes.py b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/agent_episodes.py index dd6a909..a770842 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/agent_episodes.py +++ b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/agent_episodes.py @@ -1,12 +1,15 @@ """Episode runner.""" -from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld +from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import ( + QLearningAgent, +) from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.sarsa_agent import ( SarsaAgentBase, ) +from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.td_eval import TdAgent -def run_td_episode(env: GridWorld, agent: AgentBase) -> None: +def run_td_episode(env: GridWorld, agent: TdAgent) -> None: """Run an episode for a temporary difference agent in the environment. Args: @@ -59,3 +62,29 @@ def run_sarsa_episode(env: GridWorld, agent: SarsaAgentBase) -> None: state = result.next_state agent.add_memory(state=state, action=None, result=None) agent.update() + + +def run_q_learning_episode(env: GridWorld, agent: QLearningAgent) -> None: + """Run an episode for a temporary difference agent in the environment. + + Args: + ---- + env: The GridWorld environment in which the agent will run. + agent: The TdAgent. + + Returns: + ------- + None + + """ + env.reset_agent_state() + state = env.agent_state + agent.reset_memory() + while True: + action = agent.get_action(state=state) + result = env.step(action=action) + agent.add_memory(state=state, action=action, result=result) + agent.update() + if result.done: + break + state = result.next_state diff --git a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/q_learning.py b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/q_learning.py new file mode 100644 index 0000000..472513d --- /dev/null +++ b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/q_learning.py @@ -0,0 +1,111 @@ +"""A Q-learning agent. + +The QLearningMemory class represents a single transition in the reinforcement learning environment. Each transition +consists of a current state, reward received, next state and a boolean flag indicating whether the episode is done. +""" +from collections import defaultdict +from types import MappingProxyType +from typing import TYPE_CHECKING, Self, final + +import numpy as np +from pydantic import StrictBool, StrictFloat +from pydantic.dataclasses import dataclass + +from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError +from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase +from reinforcement_learning.markov_decision_process.grid_world.environment import ( + Action, + ActionResult, + ReadOnlyActionValue, + State, +) + +if TYPE_CHECKING: + from reinforcement_learning.markov_decision_process.grid_world.environment import ActionValue + + +@final +@dataclass(frozen=True) +class QLearningMemory: + """Memory class represents a single transition in a reinforcement learning environment. + + Attributes: + ---------- + state (State): The current state in the transition. + reward (StrictFloat): The reward received in the transition. + next_state (State): The next state after the transition. + done (StrictBool): Indicates whether the episode is done after the transition. + """ + + state: State + action: Action + reward: StrictFloat + next_state: State + done: StrictBool + + +class QLearningAgent(AgentBase): + """An agent that uses the Q-learning algorithm to learn and make decisions in a grid world environment.""" + + def __init__(self: Self, *, seed: int | None): + """Initialize the agent with the given seed.""" + super().__init__(seed=seed) + self.__gamma: float = 0.9 + self.__alpha: float = 0.8 + self.__epsilon: float = 0.1 + self.__action_value: ActionValue = defaultdict(lambda: 0.0) + self.__memory: QLearningMemory | None = None + + @property + def action_value(self: Self) -> ReadOnlyActionValue: + """Get the current value of the action-value function. + + Returns: + ------- + ActionValue: The instance's internal action-value function. + """ + return MappingProxyType(self.__action_value) + + def get_action(self: Self, *, state: State) -> Action: + """Select an action based on `self.rng`.""" + if self.rng.random() < self.__epsilon: + return Action(self.rng.choice(list(Action))) + return Action(np.argmax([self.__action_value[state, action] for action in Action]).item()) + + def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: + """Add a new experience into the memory. + + Args: + ---- + state: The current state of the agent. + action: The action taken by the agent. + result: The result of the action taken by the agent. + """ + if action is None or result is None: + if action is None and result is None: + message = "action or result must not be None" + elif action is None: + message = "action must not be None" + else: + message = "result must not be None" + raise InvalidMemoryError(message) + self.__memory = QLearningMemory( + state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done + ) + + def reset_memory(self: Self) -> None: + """Reset the agent's memory.""" + self.__memory = None + + def update(self: Self) -> None: + """Update the action-value estimates of the Q-learning agent.""" + if self.__memory is None: + raise NotInitializedError(instance_name=str(self), attribute_name="__memory") + if self.__memory.done: + max_next_action_value = 0.0 + else: + max_next_action_value = max(self.__action_value[self.__memory.next_state, action] for action in Action) + target = self.__gamma * max_next_action_value + self.__memory.reward + self.__action_value[self.__memory.state, self.__memory.action] += ( + target - self.__action_value[self.__memory.state, self.__memory.action] + ) * self.__alpha diff --git a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/sarsa_agent.py b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/sarsa_agent.py index e34dd51..350eb37 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/sarsa_agent.py +++ b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/sarsa_agent.py @@ -79,7 +79,14 @@ def memories(self: Self) -> tuple[SarsaMemory, ...]: return tuple(self.__memories) def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: - """Add a new experience into the memory.""" + """Add a new experience into the memory. + + Args: + ---- + state: The current state of the agent. + action: The action taken by the agent. + result: The result of the action taken by the agent. + """ memory = SarsaMemory( state=state, action=action, diff --git a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/td_eval.py b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/td_eval.py index e88993a..0630f00 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/td_eval.py +++ b/reinforcement_learning/markov_decision_process/grid_world/methods/temporal_difference/td_eval.py @@ -95,7 +95,14 @@ def v(self: Self) -> ReadOnlyStateValue: return MappingProxyType(self.__v) def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: # noqa: ARG002 - """Add a new experience into the memory.""" + """Add a new experience into the memory. + + Args: + ---- + state: The current state of the agent. + action: The action taken by the agent. + result: The result of the action taken by the agent. + """ if result is None: message = "result must not be None" raise InvalidMemoryError(message) @@ -103,7 +110,7 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: Actio self.__memory = memory def reset_memory(self: Self) -> None: - """Clear the memory of the reinforcement learning agent.""" + """Reset the agent's memory.""" self.__memory = None def update(self: Self) -> None: diff --git a/reinforcement_learning/markov_decision_process/grid_world/render.py b/reinforcement_learning/markov_decision_process/grid_world/render.py index ac1a18a..7dc6fe5 100644 --- a/reinforcement_learning/markov_decision_process/grid_world/render.py +++ b/reinforcement_learning/markov_decision_process/grid_world/render.py @@ -14,9 +14,9 @@ from reinforcement_learning.errors import NotInitializedError from reinforcement_learning.markov_decision_process.grid_world.environment import ( Action, - ActionValue, Map, Policy, + ReadOnlyActionValue, State, StateValue, ) @@ -260,7 +260,7 @@ def render_policy(self: Self, policy: Policy, state: State) -> None: offset = offsets[action] self.ax.text(state[1] + 0.45 + offset[0], self.ys - state[0] - 0.5 + offset[1], arrow) - def generate_policy(self: Self, *, q: ActionValue) -> Policy: + def generate_policy(self: Self, *, q: ReadOnlyActionValue) -> Policy: """Generate a policy based on the state-action values. Returns: @@ -279,7 +279,7 @@ def generate_policy(self: Self, *, q: ActionValue) -> Policy: return policy - def render_q(self: Self, *, q: ActionValue, show_greedy_policy: bool = True) -> None: + def render_q(self: Self, *, q: ReadOnlyActionValue, show_greedy_policy: bool = True) -> None: """Render the Q-values of the grid world environment. Args: diff --git a/tests/markov_dicision_process/grid_world/methods/temporal_difference/test_q_learning.py b/tests/markov_dicision_process/grid_world/methods/temporal_difference/test_q_learning.py new file mode 100644 index 0000000..d805b18 --- /dev/null +++ b/tests/markov_dicision_process/grid_world/methods/temporal_difference/test_q_learning.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest + +from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError +from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult, GridWorld +from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.agent_episodes import ( + run_q_learning_episode, +) +from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import ( + QLearningAgent, +) + + +def test_q_learning_add_memory() -> None: + agent = QLearningAgent(seed=0) + with pytest.raises(InvalidMemoryError): + agent.add_memory(state=(0, 0), action=None, result=None) + with pytest.raises(InvalidMemoryError): + agent.add_memory(state=(0, 0), action=Action.UP, result=None) + with pytest.raises(InvalidMemoryError): + agent.add_memory(state=(0, 0), action=None, result=ActionResult(next_state=(0, 1), reward=1.0, done=False)) + + +def test_update_with_empty_memory() -> None: + agent = QLearningAgent(seed=0) + with pytest.raises(NotInitializedError): + agent.update() + + +def test_q_learning() -> None: + test_map = np.array( + [[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]], + dtype=np.float64, + ) + env = GridWorld(reward_map=test_map, goal_state=(0, 3), start_state=(2, 0)) + agent = QLearningAgent(seed=0) + for _ in range(2): + run_q_learning_episode(env=env, agent=agent) + + assert agent.action_value == { + ((2, 0), Action.UP): 0.0, + ((2, 0), Action.DOWN): 0.0, + ((2, 0), Action.LEFT): 0.0, + ((2, 0), Action.RIGHT): 0.0, + ((1, 0), Action.UP): 0.0, + ((1, 0), Action.DOWN): 0.0, + ((1, 0), Action.LEFT): 0.0, + ((1, 0), Action.RIGHT): 0.0, + ((0, 0), Action.UP): 0.0, + ((0, 0), Action.DOWN): 0.0, + ((0, 0), Action.LEFT): 0.0, + ((0, 0), Action.RIGHT): 0.0, + ((0, 1), Action.UP): 0.0, + ((0, 1), Action.DOWN): 0.0, + ((0, 1), Action.LEFT): 0.0, + ((0, 1), Action.RIGHT): 0.5760000000000001, + ((0, 2), Action.UP): 0.0, + ((0, 2), Action.DOWN): 0.0, + ((0, 2), Action.LEFT): 0.0, + ((0, 2), Action.RIGHT): 0.96, + ((1, 2), Action.UP): 0.0, + ((1, 2), Action.DOWN): 0.0, + ((1, 2), Action.LEFT): 0.0, + ((1, 2), Action.RIGHT): 0.0, + } diff --git a/tests/markov_dicision_process/grid_world/methods/test_environment.py b/tests/markov_dicision_process/grid_world/test_environment.py similarity index 96% rename from tests/markov_dicision_process/grid_world/methods/test_environment.py rename to tests/markov_dicision_process/grid_world/test_environment.py index b7fb80f..9c2dc03 100644 --- a/tests/markov_dicision_process/grid_world/methods/test_environment.py +++ b/tests/markov_dicision_process/grid_world/test_environment.py @@ -28,6 +28,8 @@ def test_states(mock_env: GridWorld) -> None: all_states = set(mock_env.states()) assert all_states == {(h, w) for h in range(mock_env.shape[0]) for w in range(mock_env.shape[1])} + assert mock_env.wall_states == {(1, 1)} + @pytest.fixture(params=Action) def action(request: SubRequest) -> Action: