Update reinforcement learning code to use Q-network
The reinforcement learning code now uses a Q-network to approximate action values, giving a more flexible approximation than the previous approach and allowing the model to learn more complex behaviors. A Qnet class and a QLearningAgent built around it were introduced, and the neural network is used inside the Q-learning algorithm. The 'run_q_learning_episode' function was renamed to 'run_td_episode', which now also takes an option controlling whether the goal state is added to the agent's memory. PyTorch was added to the project dependencies to enable the neural network functionality, and mypy.ini and pyproject.toml were adjusted to reflect the changes. Dedicated unit tests cover the updated Q-learning algorithm, and an autouse pytest fixture seeds PyTorch and Python's random module so the tests are deterministic.
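To put the new pieces together, here is a minimal training-loop sketch, not taken from the commit: the import paths are assumed from the file locations in the diff, and make_grid_world() is a hypothetical stand-in for however a GridWorld is constructed elsewhere in the package. The neural-network QLearningAgent rejects None memories, so the goal state is not added to memory here.

from reinforcement_learning.markov_decision_process.grid_world.methods.neural_network.q_learning import QLearningAgent
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.agent_episodes import (
    run_td_episode,
)

env = make_grid_world()  # hypothetical helper; GridWorld construction lives elsewhere in the package
agent = QLearningAgent(seed=42, env=env)
for _ in range(1000):
    # One episode of epsilon-greedy interaction with online Q-network updates.
    run_td_episode(env=env, agent=agent, add_goal_state_to_memory=False)
print(agent.average_loss)  # mean MSE loss over the steps of the last episode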
1 parent 8522e57 · commit e4cb949
Showing 16 changed files with 696 additions and 80 deletions.
168 changes: 168 additions & 0 deletions
...orcement_learning/markov_decision_process/grid_world/methods/neural_network/q_learning.py
@@ -0,0 +1,168 @@
"""Classes and methods for reinforcement learning in a grid world environment.""" | ||
from collections import defaultdict | ||
from types import MappingProxyType | ||
from typing import Self, cast | ||
|
||
import torch | ||
import torch.nn.functional as f | ||
from torch import Tensor, nn, optim | ||
|
||
from reinforcement_learning.errors import InvalidMemoryError, NotInitializedError | ||
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase | ||
from reinforcement_learning.markov_decision_process.grid_world.environment import ( | ||
Action, | ||
ActionResult, | ||
GridWorld, | ||
ReadOnlyActionValue, | ||
State, | ||
) | ||
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import ( | ||
QLearningMemory, | ||
) | ||
|
||
|
||
class Qnet(nn.Module): | ||
"""A Q-network for reinforcement learning in a grid world environment.""" | ||
|
||
def __init__(self: Self) -> None: | ||
"""Initialize a Qnet object.""" | ||
super().__init__() | ||
self.l1: nn.Linear = nn.Linear(in_features=12, out_features=128) | ||
self.l2: nn.Linear = nn.Linear(in_features=128, out_features=len(Action)) | ||
|
||
def forward(self: Self, x: Tensor) -> Tensor: | ||
"""Perform the forward pass of the Qnet neural network. | ||
Args: | ||
x: A torch.Tensor representing the input to the network. | ||
""" | ||
return cast(Tensor, self.l2(f.relu(self.l1(x)))) | ||
|
||
|
||
class QLearningAgent(AgentBase): | ||
"""A Q-learning algorithm for reinforcement learning in a GridWorld environment. | ||
Attributes: | ||
---------- | ||
- __gamma (float): Discount factor for future rewards. | ||
- __lr (float): Learning rate for optimization. | ||
- __epsilon (float): Exploration rate for epsilon-greedy policy. | ||
- __q_net (nn.Module): Q-network, a neural network that estimates action values. | ||
- __optimizer (optim.Optimizer): Optimizer for updating the Q-network. | ||
- __env (GridWorld): GridWorld environment. | ||
- __memory (QLearningMemory | None): Memory to store agent's experiences. | ||
- __total_loss (float): Total loss accumulated during training. | ||
- __count (int): Count of training steps performed. | ||
Methods: | ||
------- | ||
- average_loss(self: Self) -> float: | ||
Returns the average loss per training step. | ||
- action_value(self: Self) -> ReadOnlyActionValue: | ||
Returns a read-only mapping of state-action pairs to their estimated action values. | ||
- get_action(self: Self, *, state: State) -> Action: | ||
Selects an action to take given the current state based on the epsilon-greedy policy. | ||
- add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: | ||
Adds a new experience to the agent's memory. | ||
- reset_memory(self: Self) -> None: | ||
Resets the agent's memory. | ||
- update(self: Self) -> None: | ||
Performs a single update step of the Q-learning algorithm. | ||
""" | ||
|
||
def __init__(self: Self, *, seed: int | None, env: GridWorld): | ||
"""Initialize an agent. | ||
Args: | ||
seed: An integer specifying the seed value for random number generation. If None, no seed is set. | ||
env: A GridWorld instance representing the environment in which the agent will operate. | ||
""" | ||
super().__init__(seed=seed) | ||
self.__gamma: float = 0.9 | ||
self.__lr: float = 0.01 | ||
self.__epsilon: float = 0.1 | ||
self.__q_net: nn.Module = Qnet() | ||
self.__optimizer: optim.Optimizer = optim.SGD(lr=self.__lr, params=self.__q_net.parameters()) | ||
self.__env: GridWorld = env | ||
self.__memory: QLearningMemory | None = None | ||
self.__total_loss: float = 0.0 | ||
self.__count: int = 0 | ||
|
||
@property | ||
def average_loss(self: Self) -> float: | ||
"""Calculate the average loss of the QLearningAgent.""" | ||
return self.__total_loss / self.__count | ||
|
||
@property | ||
def action_value(self: Self) -> ReadOnlyActionValue: | ||
"""Return a readonly action value map for the agent.""" | ||
ret: defaultdict[tuple[State, Action], float] = defaultdict() | ||
with torch.set_grad_enabled(mode=False): | ||
for state in self.__env.state(): | ||
for action in Action: | ||
ret[state, action] = float( | ||
self.__q_net(self.__env.convert_to_one_hot(state=state))[:, action.value] | ||
) | ||
return MappingProxyType(ret) | ||
|
||
def get_action(self: Self, *, state: State) -> Action: | ||
"""Select an action based on `self.rng`.""" | ||
if self.rng.random() < self.__epsilon: | ||
return Action(self.rng.choice(list(Action))) | ||
|
||
return Action(torch.argmax(self.__q_net(self.__env.convert_to_one_hot(state=state)), dim=1)[0].item()) | ||
|
||
def add_memory(self: Self, *, state: State, action: Action | None, result: ActionResult | None) -> None: | ||
"""Add a new experience into the memory. | ||
Args: | ||
---- | ||
state: The current state of the agent. | ||
action: The action taken by the agent. | ||
result: The result of the action taken by the agent. | ||
""" | ||
if action is None or result is None: | ||
if action is None and result is None: | ||
message = "action or result must not be None" | ||
elif action is None: | ||
message = "action must not be None" | ||
else: | ||
message = "result must not be None" | ||
raise InvalidMemoryError(message) | ||
self.__memory = QLearningMemory( | ||
state=state, action=action, reward=result.reward, next_state=result.next_state, done=result.done | ||
) | ||
|
||
def reset_memory(self: Self) -> None: | ||
"""Reset the agent's memory.""" | ||
self.__memory = None | ||
self.__total_loss = 0.0 | ||
self.__count = 0 | ||
|
||
def update(self: Self) -> None: | ||
"""Updates the Q-values in the Q-learning agent based on the current memory.""" | ||
if self.__memory is None: | ||
raise NotInitializedError(instance_name=str(self), attribute_name="__memory") | ||
if self.__memory.done: | ||
next_action_value = torch.zeros(size=[1]) | ||
else: | ||
next_action_value = torch.max( # noqa: PD011 | ||
self.__q_net(self.__env.convert_to_one_hot(state=self.__memory.next_state)), | ||
dim=1, | ||
).values | ||
target = self.__gamma * next_action_value + self.__memory.reward | ||
current_action_values = self.__q_net(self.__env.convert_to_one_hot(state=self.__memory.state)) | ||
current_action_value = current_action_values[:, self.__memory.action.value] | ||
loss = nn.MSELoss() | ||
output = loss(target, current_action_value) | ||
self.__total_loss += float(output) | ||
self.__count += 1 | ||
self.__q_net.zero_grad() | ||
output.backward() | ||
self.__optimizer.step() |
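For reference, the target in update() follows the standard Q-learning rule: the reward plus gamma times the best next-state action value, or just the reward when the episode is done. A small illustration with made-up numbers, only to show the shapes and arithmetic involved:

import torch

gamma = 0.9
reward = 1.0
next_q = torch.tensor([[0.2, 0.5, -0.1, 0.3]])  # pretend Q-network output for the next state, shape (1, n_actions)
next_action_value = torch.max(next_q, dim=1).values  # tensor([0.5000])
target = gamma * next_action_value + reward  # tensor([1.4500])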
70 changes: 6 additions & 64 deletions
...learning/markov_decision_process/grid_world/methods/temporal_difference/agent_episodes.py
@@ -1,90 +1,32 @@
"""Episode runner.""" | ||
from reinforcement_learning.markov_decision_process.grid_world.agent_base import AgentBase | ||
from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld | ||
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.q_learning import ( | ||
QLearningAgent, | ||
) | ||
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.sarsa_agent import ( | ||
SarsaAgentBase, | ||
) | ||
from reinforcement_learning.markov_decision_process.grid_world.methods.temporal_difference.td_eval import TdAgent | ||
|
||
|
||
def run_td_episode(env: GridWorld, agent: TdAgent) -> None: | ||
def run_td_episode(*, env: GridWorld, agent: AgentBase, add_goal_state_to_memory: bool) -> None: | ||
"""Run an episode for a temporary difference agent in the environment. | ||
Args: | ||
---- | ||
env: The GridWorld environment in which the agent will run. | ||
agent: The TdAgent. | ||
add_goal_state_to_memory: Whether to add goal state to the agent's memory at the end of an episode. | ||
Returns: | ||
------- | ||
None | ||
""" | ||
env.reset_agent_state() | ||
state = env.agent_state | ||
agent.reset_memory() | ||
while True: | ||
state = env.agent_state | ||
action = agent.get_action(state=state) | ||
result = env.step(action=action) | ||
agent.add_memory(state=state, action=action, result=result) | ||
agent.update() | ||
if result.done: | ||
break | ||
state = result.next_state | ||
|
||
|
||
def run_sarsa_episode(env: GridWorld, agent: SarsaAgentBase) -> None: | ||
"""Run an episode for a SARSA agent in the environment. | ||
Args: | ||
---- | ||
env: The GridWorld environment in which the agent will run. | ||
agent: The SARSA agent. | ||
Returns: | ||
------- | ||
None | ||
""" | ||
env.reset_agent_state() | ||
state = env.agent_state | ||
agent.reset_memory() | ||
while True: | ||
action = agent.get_action(state=state) | ||
result = env.step(action=action) | ||
agent.add_memory(state=state, action=action, result=result) | ||
agent.update() | ||
|
||
if result.done: | ||
break | ||
state = result.next_state | ||
agent.add_memory(state=state, action=None, result=None) | ||
agent.update() | ||
|
||
|
||
def run_q_learning_episode(env: GridWorld, agent: QLearningAgent) -> None: | ||
"""Run an episode for a temporary difference agent in the environment. | ||
Args: | ||
---- | ||
env: The GridWorld environment in which the agent will run. | ||
agent: The TdAgent. | ||
Returns: | ||
------- | ||
None | ||
""" | ||
env.reset_agent_state() | ||
state = env.agent_state | ||
agent.reset_memory() | ||
while True: | ||
action = agent.get_action(state=state) | ||
result = env.step(action=action) | ||
agent.add_memory(state=state, action=action, result=result) | ||
if add_goal_state_to_memory: | ||
agent.add_memory(state=state, action=None, result=None) | ||
agent.update() | ||
if result.done: | ||
break | ||
state = result.next_state |
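As a usage note, not part of the diff: the add_goal_state_to_memory flag of the unified run_td_episode reproduces the behavior that was split across the old runners, since run_sarsa_episode appended a final (state, None, None) memory entry while the TD and Q-learning runners did not. Assuming an environment and concrete agents built elsewhere:

run_td_episode(env=env, agent=td_or_q_learning_agent, add_goal_state_to_memory=False)
run_td_episode(env=env, agent=sarsa_agent, add_goal_state_to_memory=True)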
@@ -0,0 +1,15 @@
import random

import pytest
import torch


@pytest.fixture(autouse=True)
def _torch_fix_seed(seed: int = 42) -> None:
    # Python random
    random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(mode=True)
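A sketch of a test that would run under this autouse fixture; the import paths are assumed from the file locations in the diff, and the assertion relies only on the layer sizes declared in Qnet.

import torch

from reinforcement_learning.markov_decision_process.grid_world.environment import Action
from reinforcement_learning.markov_decision_process.grid_world.methods.neural_network.q_learning import Qnet


def test_qnet_output_shape() -> None:
    # The seed fixture makes the randomly initialized weights reproducible across runs.
    net = Qnet()
    out = net(torch.zeros(1, 12))  # Qnet.l1 expects 12 input features
    assert out.shape == (1, len(Action))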