
Commit

Adding Q-Learning agent class and refactoring for consistency.
Implemented a new QLearningAgent class in grid_world/methods/temporal_difference/q_learning.py that uses the Q-learning algorithm. The additions include action-value estimates and a single-transition Q-learning memory, alongside a better-defined environment interface.

Refactored the agents' memory methods (add_memory and reset_memory) in QLearningAgent, SarsaAgentBase, and TdAgent for better consistency and clarity.

Adjusted several return types in GridWorld and GridWorldRender.

The add_memory and reset_memory methods in q_learning.py and sarsa_agent.py now have more specific and complete docstrings.

Added a new function, run_q_learning_episode, to agent_episodes.py to run episodes with the Q-learning agent.

Introduced a test module for the QLearningAgent class under tests to verify the Q-learning implementation.

These changes extend the grid_world environment's reinforcement learning capabilities and make it easier to run experiments.
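
For context, a minimal end-to-end sketch of how the new pieces fit together; the reward map and episode count below are illustrative (they mirror the new test module rather than any recommended settings):

import numpy as np

from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.agent_episodes import (
    run_q_learning_episode,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
    QLearningAgent,
)

# Illustrative 3x4 reward map: +1 at the goal, -1 at a penalty cell, None marks a wall (as in the new tests).
reward_map = np.array(
    [[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
    dtype=np.float64,
)
env = GridWorld(reward_map=reward_map, goal_state=(0, 3), start_state=(2, 0))
agent = QLearningAgent(seed=0)

# Each episode resets the agent's single-transition memory and updates the
# action-value estimates after every step.
for _ in range(100):
    run_q_learning_episode(env=env, agent=agent)

# Read-only view of the learned action values, keyed by (state, action) pairs.
print(dict(agent.action_value))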
nakashima-hikaru committed Nov 26, 2023
1 parent 50f279d commit 2768567
Showing 10 changed files with 245 additions and 12 deletions.
@@ -31,7 +31,7 @@ def rng(self: Self) -> np.random.Generator:

@abstractmethod
def get_action(self: Self, *, state: State) -> Action:
"""Select an action based on policy `self.__b`.
"""Select an action.
Args:
----
@@ -56,7 +56,7 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: Actio

@abstractmethod
def reset_memory(self: Self) -> None:
"""Clear the memory of the reinforcement learning agent."""
"""Reset the agent's memory."""

@abstractmethod
def update(self: Self) -> None:
@@ -173,6 +173,11 @@ def agent_state(self: Self) -> State:
"""Return the current state of the agent."""
return self.__agent_state

@property
def wall_states(self: Self) -> frozenset[State]:
"""Return the wall states of the GridWorld."""
return self.__wall_states

@property
def height(self: Self) -> int:
"""Return the height of the grid in the GridWorld object.
@@ -47,7 +47,14 @@ def memories(self: Self) -> tuple[McMemory, ...]:

@final
def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
"""Add a new experience into the memory."""
"""Add a new experience into the memory.
Args:
----
state: The current state of the agent.
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if action is None or result is None:
if action is None and result is None:
message = "action or result must not be None"
@@ -61,7 +68,7 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: Actio

@final
def reset_memory(self: Self) -> None:
"""Clear the memory of the reinforcement learning agent."""
"""Reset the agent's memory."""
self.__memories.clear()


@@ -1,12 +1,15 @@
"""Episode runner."""
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
QLearningAgent,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.sarsa_agent import (
SarsaAgentBase,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.td_eval import TdAgent


def run_td_episode(env: GridWorld, agent: AgentBase) -> None:
def run_td_episode(env: GridWorld, agent: TdAgent) -> None:
"""Run an episode for a temporary difference agent in the environment.
Args:
@@ -59,3 +62,29 @@ def run_sarsa_episode(env: GridWorld, agent: SarsaAgentBase) -> None:
state = result.next_state
agent.add_memory(state=state, action=None, result=None)
agent.update()


def run_q_learning_episode(env: GridWorld, agent: QLearningAgent) -> None:
"""Run an episode for a temporary difference agent in the environment.
Args:
----
env: The GridWorld environment in which the agent will run.
agent: The QLearningAgent.
Returns:
-------
None
"""
env.reset_agent_state()
state = env.agent_state
agent.reset_memory()
while True:
action = agent.get_action(state=state)
result = env.step(action=action)
agent.add_memory(state=state, action=action, result=result)
agent.update()
if result.done:
break
state = result.next_state
@@ -0,0 +1,111 @@
"""A Q-learning agent.
The QLearningMemory class represents a single transition in the reinforcement learning environment. Each transition
consists of a current state, the action taken, the reward received, the next state, and a boolean flag indicating
whether the episode is done.
"""
from collections import defaultdict
from types import MappingProxyType
from typing import TYPE_CHECKING, Self, final

import numpy as np
from pydantic import StrictBool, StrictFloat
from pydantic.dataclasses import dataclass

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import (
Action,
ActionResult,
ReadOnlyActionValue,
State,
)

if TYPE_CHECKING:
from reinforcement_learning.markov_decision_process.grid_world.environment import ActionValue


@final
@dataclass(frozen=True)
class QLearningMemory:
"""Memory class represents a single transition in a reinforcement learning environment.
Attributes:
----------
state (State): The current state in the transition.
action (Action): The action taken in the transition.
reward (StrictFloat): The reward received in the transition.
next_state (State): The next state after the transition.
done (StrictBool): Indicates whether the episode is done after the transition.
"""

state: State
action: Action
reward: StrictFloat
next_state: State
done: StrictBool


class QLearningAgent(AgentBase):
"""An agent that uses the Q-learning algorithm to learn and make decisions in a grid world environment."""

def __init__(self: Self, *, seed: int | None):
"""Initialize the agent with the given seed."""
super().__init__(seed=seed)
self.__gamma: float = 0.9
self.__alpha: float = 0.8
self.__epsilon: float = 0.1
self.__action_value: ActionValue = defaultdict(lambda: 0.0)
self.__memory: QLearningMemory | None = None

@property
def action_value(self: Self) -> ReadOnlyActionValue:
"""Get the current value of the action-value function.
Returns:
-------
ReadOnlyActionValue: A read-only view of the agent's action-value estimates.
"""
return MappingProxyType(self.__action_value)

def get_action(self: Self, *, state: State) -> Action:
"""Select an action based on `self.rng`."""
if self.rng.random() < self.__epsilon:
return Action(self.rng.choice(list(Action)))
return Action(np.argmax([self.__action_value[state, action] for action in Action]).item())

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
"""Add a new experience into the memory.
Args:
----
state: The current state of the agent.
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if action is None or result is None:
if action is None and result is None:
message = "action or result must not be None"
elif action is None:
message = "action must not be None"
else:
message = "result must not be None"
raise InvalidMemoryError(message)
self.__memory = QLearningMemory(
state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done
)

def reset_memory(self: Self) -> None:
"""Reset the agent's memory."""
self.__memory = None

def update(self: Self) -> None:
"""Update the action-value estimates of the Q-learning agent."""
if self.__memory is None:
raise NotInitializedError(instance_name=str(self), attribute_name="__memory")
if self.__memory.done:
max_next_action_value = 0.0
else:
max_next_action_value = max(self.__action_value[self.__memory.next_state, action] for action in Action)
target = self.__gamma * max_next_action_value + self.__memory.reward
self.__action_value[self.__memory.state, self.__memory.action] += (
target - self.__action_value[self.__memory.state, self.__memory.action]
) * self.__alpha
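
For reference, a minimal sketch of a single Q-learning update with the default hyperparameters (alpha = 0.8, gamma = 0.9); the transition below is illustrative only:

from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
    QLearningAgent,
)

agent = QLearningAgent(seed=0)

# Store one transition: moving right from (0, 2) reaches the goal with reward 1.0.
agent.add_memory(
    state=(0, 2),
    action=Action.RIGHT,
    result=ActionResult(next_state=(0, 3), reward=1.0, done=True),
)
agent.update()

# The episode is done, so the bootstrap term vanishes and the target is the reward alone:
# Q[(0, 2), RIGHT] <- 0.0 + 0.8 * (1.0 + 0.9 * 0.0 - 0.0) = 0.8
assert agent.action_value[(0, 2), Action.RIGHT] == 0.8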
@@ -79,7 +79,14 @@ def memories(self: Self) -> tuple[SarsaMemory, ...]:
return tuple(self.__memories)

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
"""Add a new experience into the memory."""
"""Add a new experience into the memory.
Args:
----
state: The current state of the agent.
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
memory = SarsaMemory(
state=state,
action=action,
@@ -95,15 +95,22 @@ def v(self: Self) -> ReadOnlyStateValue:
return MappingProxyType(self.__v)

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: # noqa: ARG002
"""Add a new experience into the memory."""
"""Add a new experience into the memory.
Args:
----
state: The current state of the agent.
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if result is None:
message = "result must not be None"
raise InvalidMemoryError(message)
memory = TdMemory(state=state, reward=result.reward, next_state=result.next_state, done=result.done)
self.__memory = memory

def reset_memory(self: Self) -> None:
"""Clear the memory of the reinforcement learning agent."""
"""Reset the agent's memory."""
self.__memory = None

def update(self: Self) -> None:
@@ -14,9 +14,9 @@
from reinforcement_learning.errors import NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.environment import (
Action,
ActionValue,
Map,
Policy,
ReadOnlyActionValue,
State,
StateValue,
)
@@ -260,7 +260,7 @@ def render_policy(self: Self, policy: Policy, state: State) -> None:
offset = offsets[action]
self.ax.text(state[1] + 0.45 + offset[0], self.ys - state[0] - 0.5 + offset[1], arrow)

def generate_policy(self: Self, *, q: ActionValue) -> Policy:
def generate_policy(self: Self, *, q: ReadOnlyActionValue) -> Policy:
"""Generate a policy based on the state-action values.
Returns:
@@ -279,7 +279,7 @@ def generate_policy(self: Self, *, q: ActionValue) -> Policy:

return policy

def render_q(self: Self, *, q: ActionValue, show_greedy_policy: bool = True) -> None:
def render_q(self: Self, *, q: ReadOnlyActionValue, show_greedy_policy: bool = True) -> None:
"""Render the Q-values of the grid world environment.
Args:
@@ -0,0 +1,65 @@
import numpy as np
import pytest

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult, GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.agent_episodes import (
run_q_learning_episode,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
QLearningAgent,
)


def test_q_learning_add_memory() -> None:
agent = QLearningAgent(seed=0)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=None, result=None)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=Action.UP, result=None)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=None, result=ActionResult(next_state=(0, 1), reward=1.0, done=False))


def test_update_with_empty_memory() -> None:
agent = QLearningAgent(seed=0)
with pytest.raises(NotInitializedError):
agent.update()


def test_q_learning() -> None:
test_map = np.array(
[[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
dtype=np.float64,
)
env = GridWorld(reward_map=test_map, goal_state=(0, 3), start_state=(2, 0))
agent = QLearningAgent(seed=0)
for _ in range(2):
run_q_learning_episode(env=env, agent=agent)

assert agent.action_value == {
((2, 0), Action.UP): 0.0,
((2, 0), Action.DOWN): 0.0,
((2, 0), Action.LEFT): 0.0,
((2, 0), Action.RIGHT): 0.0,
((1, 0), Action.UP): 0.0,
((1, 0), Action.DOWN): 0.0,
((1, 0), Action.LEFT): 0.0,
((1, 0), Action.RIGHT): 0.0,
((0, 0), Action.UP): 0.0,
((0, 0), Action.DOWN): 0.0,
((0, 0), Action.LEFT): 0.0,
((0, 0), Action.RIGHT): 0.0,
((0, 1), Action.UP): 0.0,
((0, 1), Action.DOWN): 0.0,
((0, 1), Action.LEFT): 0.0,
((0, 1), Action.RIGHT): 0.5760000000000001,
((0, 2), Action.UP): 0.0,
((0, 2), Action.DOWN): 0.0,
((0, 2), Action.LEFT): 0.0,
((0, 2), Action.RIGHT): 0.96,
((1, 2), Action.UP): 0.0,
((1, 2), Action.DOWN): 0.0,
((1, 2), Action.LEFT): 0.0,
((1, 2), Action.RIGHT): 0.0,
}
@@ -28,6 +28,8 @@ def test_states(mock_env: GridWorld) -> None:
all_states = set(mock_env.states())
assert all_states == {(h, w) for h in range(mock_env.shape[0]) for w in range(mock_env.shape[1])}

assert mock_env.wall_states == {(1, 1)}


@pytest.fixture(params=Action)
def action(request: SubRequest) -> Action:

1 comment on commit 2768567

@github-actions

Coverage Report

File     Stmts    Miss    Cover
TOTAL    914      0       100%

Tests    Skipped    Failures    Errors    Time
40       0 💤       0 ❌        0 🔥      0.732s ⏱️
