Commit
Adding Q-Learning agent class and refactoring for consistency.
Implemented a new QLearningAgent class in grid_world/methods/temporal_difference/q_learning.py that uses the Q-learning algorithm. The additions include action-value estimates and a Q-learning memory alongside better-defined environments. Refactored the agents' memory methods (add_memory and reset_memory) in QLearningAgent, SarsaAgentBase, and TdAgent for better consistency and clarity. Adjusted several return types in GridWorld and GridWorldRender. The add_memory and reset_memory methods in q_learning.py and sarsa_agent.py now have more specific docstrings. Added a new method, run_q_learning_episode, in agent_episodes.py to accommodate the Q-learning agent. Introduced a test module for the QLearningAgent class in tests to verify the Q-learning implementation. These changes extend the grid_world environment's reinforcement learning capabilities and facilitate experimental research.
1 parent 50f279d, commit 2768567
Showing 10 changed files with 245 additions and 12 deletions.
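The new run_q_learning_episode helper in agent_episodes.py is called by the tests below, but its body is not among the files shown here. A minimal sketch of how it presumably drives the new agent, based on the QLearningAgent interface added in this commit; the env.reset() and env.step() calls are assumptions about the GridWorld API, not confirmed by this diff:

from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
    QLearningAgent,
)


def run_q_learning_episode(env: GridWorld, agent: QLearningAgent) -> None:
    # Sketch only: reset the episode, then loop state -> action -> result,
    # storing each transition and applying a one-step Q-learning update.
    agent.reset_memory()
    state = env.reset()  # assumed API: returns the start state
    while True:
        action = agent.get_action(state=state)
        result = env.step(action=action)  # assumed API: returns an ActionResult
        agent.add_memory(state=state, action=action, result=result)
        agent.update()
        if result.done:
            break
        state = result.next_state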
...ent_learning/markov_decision_process/grid_world/methods/temporal_difference/q_learning.py (111 additions, 0 deletions)
@@ -0,0 +1,111 @@
"""A Q-learning agent.
The QLearningMemory class represents a single transition in the reinforcement learning environment. Each transition
consists of a current state, the action taken, the reward received, the next state, and a boolean flag indicating
whether the episode is done.
"""
from collections import defaultdict
from types import MappingProxyType
from typing import TYPE_CHECKING, Self, final

import numpy as np
from pydantic import StrictBool, StrictFloat
from pydantic.dataclasses import dataclass

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import (
    Action,
    ActionResult,
    ReadOnlyActionValue,
    State,
)

if TYPE_CHECKING:
    from reinforcement_learning.markov_decision_process.grid_world.environment import ActionValue


@final
@dataclass(frozen=True)
class QLearningMemory:
    """Memory class represents a single transition in a reinforcement learning environment.
    Attributes:
    ----------
        state (State): The current state in the transition.
        action (Action): The action taken in the transition.
        reward (StrictFloat): The reward received in the transition.
        next_state (State): The next state after the transition.
        done (StrictBool): Indicates whether the episode is done after the transition.
    """

    state: State
    action: Action
    reward: StrictFloat
    next_state: State
    done: StrictBool


class QLearningAgent(AgentBase):
    """An agent that uses the Q-learning algorithm to learn and make decisions in a grid world environment."""

    def __init__(self: Self, *, seed: int | None):
        """Initialize the agent with the given seed."""
        super().__init__(seed=seed)
        self.__gamma: float = 0.9
        self.__alpha: float = 0.8
        self.__epsilon: float = 0.1
        self.__action_value: ActionValue = defaultdict(lambda: 0.0)
        self.__memory: QLearningMemory | None = None

    @property
    def action_value(self: Self) -> ReadOnlyActionValue:
        """Get the current value of the action-value function.
        Returns:
        -------
            ReadOnlyActionValue: A read-only view of the instance's internal action-value function.
        """
        return MappingProxyType(self.__action_value)

    def get_action(self: Self, *, state: State) -> Action:
        """Select an action with an epsilon-greedy policy based on `self.rng`."""
        if self.rng.random() < self.__epsilon:
            return Action(self.rng.choice(list(Action)))
        return Action(np.argmax([self.__action_value[state, action] for action in Action]).item())

    def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
        """Add a new experience into the memory.
        Args:
        ----
            state: The current state of the agent.
            action: The action taken by the agent.
            result: The result of the action taken by the agent.
        """
        if action is None or result is None:
            if action is None and result is None:
                message = "action and result must not be None"
            elif action is None:
                message = "action must not be None"
            else:
                message = "result must not be None"
            raise InvalidMemoryError(message)
        self.__memory = QLearningMemory(
            state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done
        )

    def reset_memory(self: Self) -> None:
        """Reset the agent's memory."""
        self.__memory = None

    def update(self: Self) -> None:
        """Update the action-value estimates of the Q-learning agent."""
        if self.__memory is None:
            raise NotInitializedError(instance_name=str(self), attribute_name="__memory")
        if self.__memory.done:
            max_next_action_value = 0.0
        else:
            max_next_action_value = max(self.__action_value[self.__memory.next_state, action] for action in Action)
        target = self.__gamma * max_next_action_value + self.__memory.reward
        self.__action_value[self.__memory.state, self.__memory.action] += (
            target - self.__action_value[self.__memory.state, self.__memory.action]
        ) * self.__alpha
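For reference, get_action follows an epsilon-greedy policy (epsilon = 0.1), and update applies the standard one-step Q-learning rule with alpha = 0.8 and gamma = 0.9, treating the max term as 0 on terminal transitions:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left( r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t) \right)$$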
tests/markov_dicision_process/grid_world/methods/temporal_difference/test_q_learning.py (65 additions, 0 deletions)
@@ -0,0 +1,65 @@
import numpy as np
import pytest

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult, GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.agent_episodes import (
    run_q_learning_episode,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
    QLearningAgent,
)


def test_q_learning_add_memory() -> None:
    agent = QLearningAgent(seed=0)
    with pytest.raises(InvalidMemoryError):
        agent.add_memory(state=(0, 0), action=None, result=None)
    with pytest.raises(InvalidMemoryError):
        agent.add_memory(state=(0, 0), action=Action.UP, result=None)
    with pytest.raises(InvalidMemoryError):
        agent.add_memory(state=(0, 0), action=None, result=ActionResult(next_state=(0, 1), reward=1.0, done=False))


def test_update_with_empty_memory() -> None:
    agent = QLearningAgent(seed=0)
    with pytest.raises(NotInitializedError):
        agent.update()


def test_q_learning() -> None:
    test_map = np.array(
        [[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
        dtype=np.float64,
    )
    env = GridWorld(reward_map=test_map, goal_state=(0, 3), start_state=(2, 0))
    agent = QLearningAgent(seed=0)
    for _ in range(2):
        run_q_learning_episode(env=env, agent=agent)

    assert agent.action_value == {
        ((2, 0), Action.UP): 0.0,
        ((2, 0), Action.DOWN): 0.0,
        ((2, 0), Action.LEFT): 0.0,
        ((2, 0), Action.RIGHT): 0.0,
        ((1, 0), Action.UP): 0.0,
        ((1, 0), Action.DOWN): 0.0,
        ((1, 0), Action.LEFT): 0.0,
        ((1, 0), Action.RIGHT): 0.0,
        ((0, 0), Action.UP): 0.0,
        ((0, 0), Action.DOWN): 0.0,
        ((0, 0), Action.LEFT): 0.0,
        ((0, 0), Action.RIGHT): 0.0,
        ((0, 1), Action.UP): 0.0,
        ((0, 1), Action.DOWN): 0.0,
        ((0, 1), Action.LEFT): 0.0,
        ((0, 1), Action.RIGHT): 0.5760000000000001,
        ((0, 2), Action.UP): 0.0,
        ((0, 2), Action.DOWN): 0.0,
        ((0, 2), Action.LEFT): 0.0,
        ((0, 2), Action.RIGHT): 0.96,
        ((1, 2), Action.UP): 0.0,
        ((1, 2), Action.DOWN): 0.0,
        ((1, 2), Action.LEFT): 0.0,
        ((1, 2), Action.RIGHT): 0.0,
    }
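The two non-zero expected values can be reproduced by hand from the agent's constants. A minimal sketch, assuming the greedy rollout reaches the goal from (0, 2) in both episodes and bootstraps once from (0, 1) on a non-zero estimate (the exact trajectories depend on the seeded RNG):

# alpha = 0.8, gamma = 0.9, goal reward = 1.0, step reward = 0.0
alpha, gamma = 0.8, 0.9

q_02_right = 0.0
q_02_right += alpha * (1.0 + gamma * 0.0 - q_02_right)  # episode 1, goal transition -> 0.8
q_01_right = alpha * (0.0 + gamma * q_02_right - 0.0)   # episode 2, bootstraps on 0.8 -> 0.576...
q_02_right += alpha * (1.0 + gamma * 0.0 - q_02_right)  # episode 2, goal transition -> 0.96

print(q_01_right, q_02_right)  # 0.5760000000000001 0.96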