Update reinforcement learning code to use Q-network
The reinforcement learning part of the codebase was upgraded to use a more sophisticated Q-network. mypy.ini and pyproject.toml were adjusted to reflect the changes, and dedicated unit tests were added to cover the updated Q-learning algorithm. Test runs were made reproducible with a pytest fixture that seeds PyTorch and Python's random module.

PyTorch was added to the project dependencies to provide neural network functionality. The network serves as a function approximator for the action values inside the Q-learning algorithm, giving a more flexible value representation and allowing the agent to learn more complex behaviors. A Qnet class and a QLearningAgent were introduced to accommodate the Q-network.
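As a rough illustration of that change, here is a minimal sketch mirroring the Qnet and one-hot state encoding introduced in the diff below. The 12-dimensional input and the four actions are assumptions taken from the example grid world, and TinyQnet is a hypothetical name used only for this sketch:

```python
import torch
from torch import Tensor, nn


class TinyQnet(nn.Module):
    """Illustrative two-layer Q-network: a one-hot state in, one value per action out."""

    def __init__(self, n_states: int = 12, n_actions: int = 4) -> None:
        super().__init__()
        self.l1 = nn.Linear(n_states, 128)
        self.l2 = nn.Linear(128, n_actions)

    def forward(self, x: Tensor) -> Tensor:
        return self.l2(torch.relu(self.l1(x)))


# One-hot encode a single state and read off its estimated action values.
state_vec = torch.zeros(1, 12)
state_vec[0, 6] = 1.0                      # e.g. the cell at row 1, column 2 of a 3x4 grid
q_values = TinyQnet()(state_vec)           # shape (1, 4): one estimate per action
greedy_action = int(torch.argmax(q_values, dim=1).item())
```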

Additionally, the episode runners were consolidated: 'run_q_learning_episode' (and the SARSA variant) were folded into 'run_td_episode', which now takes an 'add_goal_state_to_memory' flag controlling whether the goal state is appended to the agent's memory at the end of an episode (see the training-loop sketch after the episode-runner diff below). This keeps a single runner for all temporal-difference agents while preserving each agent's previous memory handling.

Overall, the codebase now implements a more advanced Q-learning algorithm built on a Q-network, with unit tests ensuring its proper operation.
nakashima-hikaru committed Nov 26, 2023
1 parent 8522e57 commit e4cb949
Showing 16 changed files with 696 additions and 80 deletions.
2 changes: 1 addition & 1 deletion examples/temporal_difference_eval.py
@@ -21,7 +21,7 @@ def main() -> None:
agent = TdAgent(gamma=0.9, alpha=0.01)
n_episodes: int = 1000
for _ in range(n_episodes):
run_td_episode(env=env, agent=agent)
run_td_episode(env=env, agent=agent, add_goal_state_to_memory=False)
logging.info(agent.v)


4 changes: 2 additions & 2 deletions mypy.ini
@@ -20,7 +20,7 @@ warn_no_return = True
warn_return_any = True
warn_unused_configs = True
mypy_path = ./stubs
exclude = venv/
exclude = venv/, reinforcement_learning/markov_decision_process/grid_world/methods/neural_network
plugins = numpy.typing.mypy_plugin, pydantic.mypy

[mypy-matplotlib]
@@ -33,4 +33,4 @@ ignore_missing_imports = True
[pydantic-mypy]
init_forbid_extra = True
init_typed = True
warn_required_dynamic_aliases = True
warn_required_dynamic_aliases = True
381 changes: 380 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ matplotlib = "^3.6.2"
tqdm = "^4.64.1"
numpy = "^1.26.2"
pydantic = "^2.5.1"
torch = "^2.1.1"


[tool.poetry.group.dev.dependencies]
reinforcement_learning/markov_decision_process/grid_world/environment.py
@@ -22,8 +22,10 @@

import numpy as np
import numpy.typing as npt
import torch
from pydantic import StrictBool, StrictFloat
from pydantic.dataclasses import dataclass
from torch import Tensor

from reinforcement_learning.errors import NumpyDimError

@@ -207,6 +209,12 @@ def shape(self: Self) -> tuple[int, int]:
"""
return cast(tuple[int, int], self.__reward_map.shape)

def state(self: Self) -> Iterator[State]:
"""Iterate over the state of the GridWorld."""
for h in range(self.height):
for w in range(self.width):
yield h, w

def states(self: Self) -> Iterator[State]:
"""Execute and yield all possible states in the two-dimensional grid."""
for h in range(self.height):
@@ -235,7 +243,7 @@ def next_state(self: Self, state: State, action: Action) -> State:
next_state = state
return next_state

def reward(self: Self, next_state: State) -> float:
def reward(self: Self, *, next_state: State) -> float:
"""Compute the reward for a given state transition.
Args:
@@ -248,7 +256,7 @@ def reward(self: Self, next_state: State) -> float:
"""
return cast(float, self.__reward_map[next_state])

def step(self: Self, action: Action) -> ActionResult:
def step(self: Self, *, action: Action) -> ActionResult:
"""Perform an environment step based on the provided action.
Args:
@@ -260,11 +268,26 @@ def step(self: Self, action: Action) -> ActionResult:
tuple(State, float, bool): The next state, reward from the current action and whether the goal state is reached.
"""
next_state = self.next_state(state=self.__agent_state, action=action)
reward = self.reward(next_state)
reward = self.reward(next_state=next_state)
done = next_state == self.__goal_state
self.__agent_state = next_state
return ActionResult(next_state=next_state, reward=reward, done=done)

def reset_agent_state(self: Self) -> None:
"""Reset the agent's state to the start state."""
self.__agent_state = self.__start_state

def convert_to_one_hot(self: Self, *, state: State) -> Tensor:
"""Convert the given state into a one-hot encoded tensor.
Args:
state: The state to be converted into a one-hot encoded tensor.
Returns:
A one-hot encoded tensor representing the given state.
"""
vec = torch.zeros(self.height * self.width)
y, x = state
idx = self.width * y + x
vec[idx] = 1.0
return cast(Tensor, vec.unsqueeze(dim=0))
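As a quick check of the index arithmetic used by convert_to_one_hot above (idx = width * y + x), a small sketch; the 3x4 grid size is assumed for illustration only and matches Qnet's 12 input features:

```python
import torch

height, width = 3, 4              # assumed grid size (illustrative only)
y, x = 1, 2                       # an arbitrary state (row, column)

vec = torch.zeros(height * width)
vec[width * y + x] = 1.0          # idx = 4 * 1 + 2 = 6
one_hot = vec.unsqueeze(dim=0)    # add a batch dimension -> shape (1, 12)

assert one_hot.shape == (1, 12)
assert one_hot[0, 6].item() == 1.0
```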
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ def eval_one_step(
new_v: float = 0.0
for action, action_prob in action_probs.items():
next_state = env.next_state(state=state, action=action)
reward = env.reward(next_state)
reward = env.reward(next_state=next_state)
new_v += action_prob * (reward + gamma * v[next_state])
v[state] = new_v
return v
reinforcement_learning/markov_decision_process/grid_world/methods/neural_network/q_learning.py
@@ -0,0 +1,168 @@
"""Classes and methods for reinforcement learning in a grid world environment."""
from collections import defaultdict
from types import MappingProxyType
from typing import Self, cast

import torch
import torch.nn.functional as f
from torch import Tensor, nn, optim

from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import (
Action,
ActionResult,
GridWorld,
ReadOnlyActionValue,
State,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
QLearningMemory,
)


class Qnet(nn.Module):
"""A Q-network for reinforcement learning in a grid world environment."""

def __init__(self: Self) -> None:
"""Initialize a Qnet object."""
super().__init__()
self.l1: nn.Linear = nn.Linear(in_features=12, out_features=128)
self.l2: nn.Linear = nn.Linear(in_features=128, out_features=len(Action))

def forward(self: Self, x: Tensor) -> Tensor:
"""Perform the forward pass of the Qnet neural network.
Args:
x: A torch.Tensor representing the input to the network.
"""
return cast(Tensor, self.l2(f.relu(self.l1(x))))


class QLearningAgent(AgentBase):
"""A Q-learning algorithm for reinforcement learning in a GridWorld environment.
Attributes:
----------
- __gamma (float): Discount factor for future rewards.
- __lr (float): Learning rate for optimization.
- __epsilon (float): Exploration rate for epsilon-greedy policy.
- __q_net (nn.Module): Q-network, a neural network that estimates action values.
- __optimizer (optim.Optimizer): Optimizer for updating the Q-network.
- __env (GridWorld): GridWorld environment.
- __memory (QLearningMemory | None): Memory to store agent's experiences.
- __total_loss (float): Total loss accumulated during training.
- __count (int): Count of training steps performed.
Methods:
-------
- average_loss(self: Self) -> float:
Returns the average loss per training step.
- action_value(self: Self) -> ReadOnlyActionValue:
Returns a read-only mapping of state-action pairs to their estimated action values.
- get_action(self: Self, *, state: State) -> Action:
Selects an action to take given the current state based on the epsilon-greedy policy.
- add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
Adds a new experience to the agent's memory.
- reset_memory(self: Self) -> None:
Resets the agent's memory.
- update(self: Self) -> None:
Performs a single update step of the Q-learning algorithm.
"""

def __init__(self: Self, *, seed: int | None, env: GridWorld):
"""Initialize an agent.
Args:
seed: An integer specifying the seed value for random number generation. If None, no seed is set.
env: A GridWorld instance representing the environment in which the agent will operate.
"""
super().__init__(seed=seed)
self.__gamma: float = 0.9
self.__lr: float = 0.01
self.__epsilon: float = 0.1
self.__q_net: nn.Module = Qnet()
self.__optimizer: optim.Optimizer = optim.SGD(lr=self.__lr, params=self.__q_net.parameters())
self.__env: GridWorld = env
self.__memory: QLearningMemory | None = None
self.__total_loss: float = 0.0
self.__count: int = 0

@property
def average_loss(self: Self) -> float:
"""Calculate the average loss of the QLearningAgent."""
return self.__total_loss / self.__count

@property
def action_value(self: Self) -> ReadOnlyActionValue:
"""Return a readonly action value map for the agent."""
ret: defaultdict[tuple[State, Action], float] = defaultdict()
with torch.set_grad_enabled(mode=False):
for state in self.__env.state():
for action in Action:
ret[state, action] = float(
self.__q_net(self.__env.convert_to_one_hot(state=state))[:, action.value]
)
return MappingProxyType(ret)

def get_action(self: Self, *, state: State) -> Action:
"""Select an action based on `self.rng`."""
if self.rng.random() < self.__epsilon:
return Action(self.rng.choice(list(Action)))

return Action(torch.argmax(self.__q_net(self.__env.convert_to_one_hot(state=state)), dim=1)[0].item())

def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None:
"""Add a new experience into the memory.
Args:
----
state: The current state of the agent.
action: The action taken by the agent.
result: The result of the action taken by the agent.
"""
if action is None or result is None:
if action is None and result is None:
message = "action or result must not be None"
elif action is None:
message = "action must not be None"
else:
message = "result must not be None"
raise InvalidMemoryError(message)
self.__memory = QLearningMemory(
state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done
)

def reset_memory(self: Self) -> None:
"""Reset the agent's memory."""
self.__memory = None
self.__total_loss = 0.0
self.__count = 0

def update(self: Self) -> None:
"""Updates the Q-values in the Q-learning agent based on the current memory."""
if self.__memory is None:
raise NotInitializedError(instance_name=str(self), attribute_name="__memory")
if self.__memory.done:
next_action_value = torch.zeros(size=[1])
else:
next_action_value = torch.max( # noqa: PD011
self.__q_net(self.__env.convert_to_one_hot(state=self.__memory.next_state)),
dim=1,
).values
target = self.__gamma * next_action_value + self.__memory.reward
current_action_values = self.__q_net(self.__env.convert_to_one_hot(state=self.__memory.state))
current_action_value = current_action_values[:, self.__memory.action.value]
loss = nn.MSELoss()
output = loss(target, current_action_value)
self.__total_loss += float(output)
self.__count += 1
self.__q_net.zero_grad()
output.backward()
self.__optimizer.step()
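For reference, update() above regresses Q(state, action) toward the one-step Q-learning target r + gamma * max_a Q(next_state, a), or just r when the episode is done. A standalone sketch of that target with made-up numbers:

```python
import torch

gamma = 0.9
reward = 1.0
done = False

# Stand-in for q_net(one_hot(next_state)): one estimated value per action.
q_next = torch.tensor([[0.2, 0.5, -0.1, 0.0]])

next_value = torch.zeros(1) if done else torch.max(q_next, dim=1).values
target = gamma * next_value + reward      # tensor([1.45])

# update() then takes the MSE loss between this target and the current
# Q(state, action) estimate, backpropagates, and applies one SGD step.
```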
Original file line number Diff line number Diff line change
@@ -1,90 +1,32 @@
"""Episode runner."""
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase
from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import (
QLearningAgent,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.sarsa_agent import (
SarsaAgentBase,
)
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.td_eval import TdAgent


def run_td_episode(env: GridWorld, agent: TdAgent) -> None:
def run_td_episode(*, env: GridWorld, agent: AgentBase, add_goal_state_to_memory: bool) -> None:
"""Run an episode for a temporary difference agent in the environment.
Args:
----
env: The GridWorld environment in which the agent will run.
agent: The agent to run; any AgentBase implementation.
add_goal_state_to_memory: Whether to add goal state to the agent's memory at the end of an episode.
Returns:
-------
None
"""
env.reset_agent_state()
state = env.agent_state
agent.reset_memory()
while True:
state = env.agent_state
action = agent.get_action(state=state)
result = env.step(action=action)
agent.add_memory(state=state, action=action, result=result)
agent.update()
if result.done:
break
state = result.next_state


def run_sarsa_episode(env: GridWorld, agent: SarsaAgentBase) -> None:
"""Run an episode for a SARSA agent in the environment.
Args:
----
env: The GridWorld environment in which the agent will run.
agent: The SARSA agent.
Returns:
-------
None
"""
env.reset_agent_state()
state = env.agent_state
agent.reset_memory()
while True:
action = agent.get_action(state=state)
result = env.step(action=action)
agent.add_memory(state=state, action=action, result=result)
agent.update()

if result.done:
break
state = result.next_state
agent.add_memory(state=state, action=None, result=None)
agent.update()


def run_q_learning_episode(env: GridWorld, agent: QLearningAgent) -> None:
"""Run an episode for a temporary difference agent in the environment.
Args:
----
env: The GridWorld environment in which the agent will run.
agent: The TdAgent.
Returns:
-------
None
"""
env.reset_agent_state()
state = env.agent_state
agent.reset_memory()
while True:
action = agent.get_action(state=state)
result = env.step(action=action)
agent.add_memory(state=state, action=action, result=result)
if add_goal_state_to_memory:
agent.add_memory(state=state, action=None, result=None)
agent.update()
if result.done:
break
state = result.next_state
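Putting the new pieces together, a hedged end-to-end training sketch. Only GridWorld, QLearningAgent, run_td_episode and average_loss appear in this diff; the GridWorld constructor arguments, the reward map, and the episode runner's module path are assumptions made for illustration:

```python
import logging

import numpy as np

from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
from reinforcement_learning.markov_decision_process.grid_world.methods.neural_network.q_learning import (
    QLearningAgent,
)

# Assumption: the episode runner's module path is not visible in this diff.
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.episode_runner import (
    run_td_episode,
)


def main() -> None:
    # Assumed constructor: a NumPy reward map plus explicit start and goal states.
    reward_map = np.array(
        [
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0, -1.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
    )
    env = GridWorld(reward_map=reward_map, start_state=(2, 0), goal_state=(0, 3))
    agent = QLearningAgent(seed=42, env=env)

    for episode in range(1000):
        # QLearningAgent rejects (None, None) memories, so the goal state is not stored.
        run_td_episode(env=env, agent=agent, add_goal_state_to_memory=False)
        if episode % 100 == 0:
            logging.info("episode %d: average loss %.4f", episode, agent.average_loss)


if __name__ == "__main__":
    main()
```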
15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,15 @@
import random

import pytest
import torch


@pytest.fixture(autouse=True)
def _torch_fix_seed(seed: int = 42) -> None:
# Python random
random.seed(seed)
# Pytorch
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(mode=True)
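A small illustrative test of what the autouse fixture buys: because it reseeds before every test, reseeding with the same value inside a test reproduces identical draws. (When CUDA kernels are involved, torch.use_deterministic_algorithms(True) may additionally require the CUBLAS_WORKSPACE_CONFIG environment variable to be set.)

```python
import torch


def test_seeding_is_reproducible() -> None:
    # The autouse fixture has already called torch.manual_seed(42) before this test runs.
    first = torch.rand(3)
    torch.manual_seed(42)   # reseed with the fixture's seed value and draw again
    second = torch.rand(3)
    assert torch.equal(first, second)
```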
Empty file.

1 comment on commit e4cb949

@github-actions

Coverage

Coverage Report
File                                                                              Stmts   Miss   Cover   Missing
reinforcement_learning/markov_decision_process/grid_world
   environment.py                                                                   123      3     98%   214–216
reinforcement_learning/markov_decision_process/grid_world/methods/neural_network
   q_learning.py                                                                     73      6     92%   105–112
tests/markov_dicision_process/grid_world/methods/neural_network
   test_q_learning.py                                                                27      1     96%   39
TOTAL                                                                              1010     10     99%

Tests   Skipped   Failures   Errors   Time
43      0 💤      1 ❌       0 🔥     1.613s ⏱️
