
Commit

Add comprehensive tests for markov_decision_process, refine error handling and memory addition

This commit covers three main changes:

- Added tests for the markov_decision_process methods, covering episodes run with a SARSA agent and with a TD agent, to verify that both agents behave correctly and robustly in different scenarios.
- Refactored error handling in the `add_memory` methods. Previously, passing `None` for `action` or `result` raised an `InvalidMemoryError`; that handling now lives at the agent implementation level rather than the base class level, and the `InvalidMemoryError` message names the specific condition that caused the error.
- Added the ability to store only a state as a memory when required, through changes to the `add_memory` and `update` methods in sarsa_agent.py and its subclasses. This gives more flexibility when adding to the agent's memory; a usage sketch of the updated runner API follows this list.
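For orientation, here is a minimal usage sketch of the updated episode-runner API. It is not part of the commit: the `GridWorld` constructor arguments and the import paths for `TdAgent` and `run_td_episode` are assumptions (only the `GridWorld` import path appears in the diff below), while the hyperparameters mirror examples/temporal_difference_eval.py. SARSA agents are driven the same way via `run_sarsa_episode(env=..., agent=...)`, and that runner now stores the terminal state itself through `add_state_as_memory`.

```python
import numpy as np

from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld

# Assumed import paths: the modules defining TdAgent and run_td_episode are not
# named in the loaded portion of this diff.
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.td_eval import TdAgent
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.util import run_td_episode


def main() -> None:
    # Assumed constructor arguments; the GridWorld signature is not shown in this commit.
    reward_map = np.array(
        [[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
    )
    env = GridWorld(reward_map=reward_map, goal_state=(0, 3), start_state=(2, 0))
    agent = TdAgent(gamma=0.9, alpha=0.01)
    n_episodes = 1000
    for _ in range(n_episodes):
        # The add_goal_state_to_memory flag was removed; the runner now stops
        # once the agent reaches the goal state.
        run_td_episode(env=env, agent=agent)


if __name__ == "__main__":
    main()
```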
nakashima-hikaru committed Nov 26, 2023
1 parent e4cb949 commit ed9c5fd
Showing 18 changed files with 144 additions and 151 deletions.
2 changes: 1 addition & 1 deletion examples/temporal_difference_eval.py
@@ -21,7 +21,7 @@ def main() -> None:
agent = TdAgent(gamma=0.9, alpha=0.01)
n_episodes: int = 1000
for _ in range(n_episodes):
run_td_episode(env=env, agent=agent, add_goal_state_to_memory=False)
run_td_episode(env=env, agent=agent)
logging.info(agent.v)


@@ -44,7 +44,7 @@ def get_action(self: Self, *, state: State) -> Action:
"""

@abstractmethod
def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
def add_memory(self: Self, *, state: State, action: Action, result: ActionResult) -> None:
"""Add a new experience into the memory.
Args:
@@ -5,7 +5,6 @@
from pydantic import StrictFloat
from pydantic.dataclasses import dataclass

from reinforcement_learning.errors import InvalidMemoryError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import DistributionModelAgent
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult, GridWorld, State

@@ -46,7 +45,7 @@ def memories(self: Self) -> tuple[McMemory, ...]:
return tuple(self.__memories)

@final
def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
def add_memory(self: Self, *, state: State, action: Action, result: ActionResult) -> None:
"""Add a new experience into the memory.
Args:
@@ -55,14 +54,6 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if action is None or result is None:
if action is None and result is None:
message = "action or result must not be None"
elif action is None:
message = "action must not be None"
else:
message = "result must not be None"
raise InvalidMemoryError(message)
memory = McMemory(state=state, action=action, reward=result.reward)
self.__memories.append(memory)

@@ -7,7 +7,7 @@
import torch.nn.functional as f
from torch import Tensor, nn, optim

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.errors import NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import (
Action,
@@ -118,7 +118,7 @@ def get_action(self: Self, *, state: State) -> Action:

return Action(torch.argmax(self.__q_net(self.__env.convert_to_one_hot(state=state)), dim=1)[0].item())

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
def add_memory(self: Self, *, state: State, action: Action, result: ActionResult) -> None:
"""Add a new experience into the memory.
Args:
@@ -127,14 +127,6 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if action is None or result is None:
if action is None and result is None:
message = "action or result must not be None"
elif action is None:
message = "action must not be None"
else:
message = "result must not be None"
raise InvalidMemoryError(message)
self.__memory = QLearningMemory(
state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done
)
@@ -1,16 +1,18 @@
"""Episode runner."""
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.sarsa_agent import (
SarsaAgentBase,
)


def run_td_episode(*, env: GridWorld, agent: AgentBase, add_goal_state_to_memory: bool) -> None:
def run_sarsa_episode(*, env: GridWorld, agent: SarsaAgentBase) -> None:
"""Run an episode for a temporary difference agent in the environment.
Args:
----
env: The GridWorld environment in which the agent will run.
agent: The SARSA agent.
add_goal_state_to_memory: Whether to add goal state to the agent's memory at the end of an episode.
Returns:
-------
@@ -19,14 +21,38 @@ def run_td_episode(*, env: GridWorld, agent: AgentBase, add_goal_state_to_memory: bool) -> None:
"""
env.reset_agent_state()
agent.reset_memory()
while True:
while env.agent_state != env.goal_state:
state = env.agent_state
action = agent.get_action(state=state)
result = env.step(action=action)
agent.add_memory(state=state, action=action, result=result)
agent.update()
if result.done:
agent.add_state_as_memory(state=result.next_state)
agent.update()
break
if add_goal_state_to_memory:
agent.add_memory(state=state, action=None, result=None)


def run_td_episode(*, env: GridWorld, agent: AgentBase) -> None:
"""Run an episode for a temporary difference agent in the environment.
Args:
----
env: The GridWorld environment in which the agent will run.
agent: The TdAgent.
Returns:
-------
None
"""
env.reset_agent_state()
agent.reset_memory()
while env.agent_state != env.goal_state:
state = env.agent_state
action = agent.get_action(state=state)
result = env.step(action=action)
agent.add_memory(state=state, action=action, result=result)
agent.update()
if result.done:
break
@@ -11,7 +11,7 @@
from pydantic import StrictBool, StrictFloat
from pydantic.dataclasses import dataclass

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.errors import NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import (
Action,
@@ -72,7 +72,7 @@ def get_action(self: Self, *, state: State) -> Action:
return Action(self.rng.choice(list(Action)))
return Action(np.argmax([self.__action_value[state, action] for action in Action]).item())

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
def add_memory(self: Self, *, state: State, action: Action, result: ActionResult) -> None:
"""Add a new experience into the memory.
Args:
@@ -81,14 +81,6 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if action is None or result is None:
if action is None and result is None:
message = "action or result must not be None"
elif action is None:
message = "action must not be None"
else:
message = "result must not be None"
raise InvalidMemoryError(message)
self.__memory = QLearningMemory(
state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done
)
@@ -28,9 +28,9 @@ class SarsaMemory:
"""

state: State
action: Action | None
reward: StrictFloat | None
done: StrictBool | None
action: Action
reward: StrictFloat
done: StrictBool


class SarsaAgentBase(DistributionModelAgent, ABC):
@@ -44,7 +44,7 @@ def __init__(self: Self, *, seed: int | None) -> None:
self.__gamma: float = 0.9
self.__alpha: float = 0.8
self.__epsilon: float = 0.1
self.__memories: deque[SarsaMemory] = deque(maxlen=SarsaAgentBase.max_memory_length)
self.__memories: deque[SarsaMemory | State] = deque(maxlen=SarsaAgentBase.max_memory_length)

@property
def gamma(self: Self) -> float:
@@ -74,11 +74,11 @@ def epsilon(self: Self) -> float:
return self.__epsilon

@property
def memories(self: Self) -> tuple[SarsaMemory, ...]:
def memories(self: Self) -> tuple[SarsaMemory | State, ...]:
"""Return a tuple of memories."""
return tuple(self.__memories)

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
def add_memory(self: Self, *, state: State, action: Action, result: ActionResult) -> None:
"""Add a new experience into the memory.
Args:
@@ -90,11 +90,19 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
memory = SarsaMemory(
state=state,
action=action,
reward=result.reward if result is not None else None,
done=result.done if result is not None else None,
reward=result.reward,
done=result.done,
)
self.__memories.append(memory)

def add_state_as_memory(self: Self, *, state: State) -> None:
"""Add a state to the agent's memory.
Args:
state: The state to be added to the agent's memory.
"""
self.__memories.append(state)

def reset_memory(self: Self) -> None:
"""Reset the agent's memory."""
self.__memories.clear()
@@ -8,7 +8,7 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Self, final

from reinforcement_learning.errors import NotInitializedError
from reinforcement_learning.errors import InvalidMemoryError
from reinforcement_learning.markov_decision_process.grid_world.environment import (
RANDOM_ACTIONS,
ReadOnlyActionValue,
@@ -63,22 +63,22 @@ def update(self: Self) -> None:
if len(self.memories) < SarsaAgentBase.max_memory_length:
return
current_memory = self.memories[0]
if current_memory.action is None:
raise NotInitializedError(instance_name=str(current_memory), attribute_name="action")
next_memory = self.memories[1]
if isinstance(current_memory, tuple):
message = "Memory must be cleared after state-only memory added to memories"
raise InvalidMemoryError(message)
if current_memory.done:
next_q = 0.0
rho = 1.0
else:
if next_memory.action is None:
raise NotInitializedError(instance_name=str(next_memory), attribute_name="action")
next_memory = self.memories[1]
if isinstance(next_memory, tuple):
message = "State-only memory must be added after an episode is done"
raise InvalidMemoryError(message)
next_q = self.__action_value[next_memory.state, next_memory.action]
rho = (
self.__evaluation_policy[current_memory.state][current_memory.action]
/ self.behavior_policy[current_memory.state][current_memory.action]
)
if current_memory.reward is None:
raise NotInitializedError(instance_name=str(current_memory), attribute_name="reward")
target = rho * (current_memory.reward + self.gamma * next_q)
key = current_memory.state, current_memory.action
self.__action_value[key] += (target - self.__action_value[key]) * self.alpha
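For reference, the update above implements the off-policy SARSA rule Q(s, a) ← Q(s, a) + α·[ρ·(r + γ·Q(s′, a′)) − Q(s, a)], where ρ = π_eval(a | s) / π_behavior(a | s) is the importance-sampling ratio; when the first stored memory is marked done, the code takes Q(s′, a′) = 0 and ρ = 1.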
@@ -14,7 +14,7 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Self, final

from reinforcement_learning.errors import NotInitializedError
from reinforcement_learning.errors import InvalidMemoryError
from reinforcement_learning.markov_decision_process.grid_world.environment import (
RANDOM_ACTIONS,
ReadOnlyActionValue,
@@ -66,17 +66,17 @@ def update(self: Self) -> None:
if len(self.memories) < SarsaAgentBase.max_memory_length:
return
current_memory = self.memories[0]
if current_memory.action is None:
raise NotInitializedError(instance_name=str(current_memory), attribute_name="action")
next_memory = self.memories[1]
if isinstance(current_memory, tuple):
message = "Memory must be cleared after state-only memory added to memories"
raise InvalidMemoryError(message)
if current_memory.done:
next_q = 0.0
else:
if next_memory.action is None:
raise NotInitializedError(instance_name=str(next_memory), attribute_name="action")
next_memory = self.memories[1]
if isinstance(next_memory, tuple):
message = "State-only memory must be added after an episode is done"
raise InvalidMemoryError(message)
next_q = self.__action_value[next_memory.state, next_memory.action]
if current_memory.reward is None:
raise NotInitializedError(instance_name=str(current_memory), attribute_name="reward")
target = current_memory.reward + self.gamma * next_q
key = current_memory.state, current_memory.action
self.__action_value[key] += (target - self.__action_value[key]) * self.alpha
@@ -27,7 +27,7 @@

from pydantic import StrictBool, StrictFloat

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.errors import NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import DistributionModelAgent
from reinforcement_learning.markov_decision_process.grid_world.environment import (
RANDOM_ACTIONS,
@@ -94,7 +94,7 @@ def v(self: Self) -> ReadOnlyStateValue:
"""Return the state value."""
return MappingProxyType(self.__v)

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: # noqa: ARG002
def add_memory(self: Self, *, state: State, action: Action, result: ActionResult) -> None: # noqa: ARG002
"""Add a new experience into the memory.
Args:
@@ -103,9 +103,6 @@ def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if result is None:
message = "result must not be None"
raise InvalidMemoryError(message)
memory = TdMemory(state=state, reward=result.reward, next_state=result.next_state, done=result.done)
self.__memory = memory

@@ -1,8 +1,6 @@
import numpy as np
import pytest

from reinforcement_learning.errors import InvalidMemoryError
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult, GridWorld
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.monte_carlo.mc_agent import (
run_monte_carlo_episode,
)
@@ -11,16 +9,6 @@
)


def test_mc_add_memory() -> None:
agent = McOffPolicyAgent(gamma=0.9, epsilon=0.1, alpha=0.1, seed=0)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=None, result=None)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=Action.UP, result=None)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=None, result=ActionResult(next_state=(0, 1), reward=1.0, done=False))


def test_mc_control_off_policy() -> None:
test_map = np.array(
[[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
@@ -1,8 +1,6 @@
import numpy as np
import pytest

from reinforcement_learning.errors import InvalidMemoryError
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, ActionResult, GridWorld
from reinforcement_learning.markov_decision_process.grid_world.environment import Action, GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.monte_carlo.mc_agent import (
run_monte_carlo_episode,
)
@@ -11,16 +9,6 @@
)


def test_mc_add_memory() -> None:
agent = McOnPolicyAgent(gamma=0.9, epsilon=0.1, alpha=0.1, seed=0)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=None, result=None)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=Action.UP, result=None)
with pytest.raises(InvalidMemoryError):
agent.add_memory(state=(0, 0), action=None, result=ActionResult(next_state=(0, 1), reward=1.0, done=False))


def test_mc_control_on_policy() -> None:
test_map = np.array(
[[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
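The new SARSA and TD episode tests mentioned in the commit message sit in files that are not loaded above. Purely to illustrate their shape, here is a hedged sketch: the agent class name `SarsaOnPolicyAgent`, its constructor, the import paths, and the `GridWorld` constructor arguments are assumptions not confirmed by this diff.

```python
import numpy as np

from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld

# Hypothetical class and module names; only SarsaAgentBase and run_sarsa_episode
# are visible in the loaded portion of this diff.
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.sarsa_on_policy import (
    SarsaOnPolicyAgent,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.util import (
    run_sarsa_episode,
)


def test_sarsa_episode() -> None:
    test_map = np.array(
        [[0.0, 0.0, 0.0, 1.0], [0.0, None, 0.0, -1.0], [0.0, 0.0, 0.0, 0.0]],
    )
    # Assumed constructor arguments; the GridWorld signature is not shown in this commit.
    env = GridWorld(reward_map=test_map, goal_state=(0, 3), start_state=(2, 0))
    agent = SarsaOnPolicyAgent(seed=0)
    for _ in range(2):
        run_sarsa_episode(env=env, agent=agent)
    # run_sarsa_episode ends each episode by calling add_state_as_memory with the
    # terminal next_state, so the most recent memory entry is a plain State
    # (a tuple) rather than a SarsaMemory record.
    assert isinstance(agent.memories[-1], tuple)
```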
1 comment on commit ed9c5fd

@github-actions

Coverage

Coverage Report
File                     Stmts    Miss    Cover    Missing
reinforcement_learning/markov_decision_process/grid_world
   environment.py        123      3       98%      214–216
reinforcement_learning/markov_decision_process/grid_world/methods/neural_network
   q_learning.py         66       6       91%      105–112
tests/markov_dicision_process/grid_world/methods/neural_network
   test_q_learning.py    19       1       95%      29
TOTAL                    975      10      99%

Tests    Skipped    Failures    Errors    Time
40       0 💤       1 ❌        0 🔥      1.586s ⏱️
