Skip to content

Commit

Permalink
Update pre-commit hooks
Browse files: browse the repository at this point in the history
  • Loading branch information
ffelten committed Oct 16, 2024
1 parent a4052e4 commit 2e8a18a
Show file tree
Hide file tree
Showing 30 changed files with 91 additions and 25 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -18,13 +18,13 @@ repos:
- id: detect-private-key
- id: debug-statements
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
rev: v2.3.0
hooks:
- id: codespell
args:
- --ignore-words-list=reacher,ure,referenc,wile,mor,ser,esr,nowe
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.1.1
hooks:
- id: flake8
args:
Expand All @@ -35,16 +35,16 @@ repos:
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.18.0
hooks:
- id: pyupgrade
args: ["--py37-plus"]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/python/black
rev: 23.1.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/pydocstyle
Expand Down
1 change: 0 additions & 1 deletion morl_baselines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""MORL-Baselines contains various MORL algorithms and utility functions."""


__version__ = "1.1.0"
1 change: 1 addition & 0 deletions morl_baselines/common/buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Replay buffer for multi-objective reinforcement learning."""

import numpy as np
import torch as th

Expand Down
8 changes: 6 additions & 2 deletions morl_baselines/common/diverse_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Diverse Experience Replay Buffer. Code extracted from https://github.com/axelabels/DynMORL."""

from dataclasses import dataclass

import numpy as np
Expand Down Expand Up @@ -154,7 +155,7 @@ def update(self, idx: int, p, tree_id=None):
Keyword Arguments:
tree_id {object} -- Tree to be updated (default: {None})
"""
if type(p) == dict:
if isinstance(p, dict):
for k in p:
self.update(idx, p[k], k)
return
Expand Down Expand Up @@ -476,7 +477,10 @@ def get_data(self, include_indices: bool = False):
Returns:
The data
"""
all_data = list(np.arange(self.capacity) + self.capacity - 1), list(self.tree.data)
all_data = (
list(np.arange(self.capacity) + self.capacity - 1),
list(self.tree.data),
)
indices = []
data = []
for i, d in zip(all_data[0], all_data[1]):
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities related to evaluation."""

import os
import random
from typing import List, Optional, Tuple
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/experiments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common experiment utilities."""

import argparse

from morl_baselines.multi_policy.capql.capql import CAPQL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Probabilistic ensemble of neural networks."""

import os

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/model_based/tabular_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tabular dynamics model S_{t+1}, R_t ~ m(.,.|s,a) ."""

import random

import numpy as np
Expand Down
53 changes: 44 additions & 9 deletions morl_baselines/common/model_based/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility functions for the model."""

from typing import Tuple

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -34,7 +35,7 @@ def termination_fn_dst(obs, act, next_obs):


def termination_fn_mountaincar(obs, act, next_obs):
"""Termination function of mountin car."""
"""Termination function of mountain car."""
assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
position = next_obs[:, 0]
velocity = next_obs[:, 1]
Expand Down Expand Up @@ -147,16 +148,29 @@ def step(
var_obs = var_obs[0]
var_rewards = var_rewards[0]

info = {"uncertainty": uncertainties, "var_obs": var_obs, "var_rewards": var_rewards}
info = {
"uncertainty": uncertainties,
"var_obs": var_obs,
"var_rewards": var_rewards,
}

# info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev}
return next_obs, rewards, terminals, info


def visualize_eval(
agent, env, model=None, w=None, horizon=10, init_obs=None, compound=True, deterministic=False, show=False, filename=None
agent,
env,
model=None,
w=None,
horizon=10,
init_obs=None,
compound=True,
deterministic=False,
show=False,
filename=None,
):
"""Generates a plot of the evolution of the state, reward and model predicitions ove time.
"""Generates a plot of the evolution of the state, reward and model predictions over time.
Args:
agent: agent to be evaluated
Expand Down Expand Up @@ -213,10 +227,16 @@ def visualize_eval(
acts = F.one_hot(acts, num_classes=env.action_space.n).squeeze(1)
for step in range(len(real_obs)):
if compound or step == 0:
obs, r, done, info = model_env.step(th.tensor(obs).to(agent.device), acts[step], deterministic=deterministic)
obs, r, done, info = model_env.step(
th.tensor(obs).to(agent.device),
acts[step],
deterministic=deterministic,
)
else:
obs, r, done, info = model_env.step(
th.tensor(real_obs[step - 1]).to(agent.device), acts[step], deterministic=deterministic
th.tensor(real_obs[step - 1]).to(agent.device),
acts[step],
deterministic=deterministic,
)
model_obs.append(obs.copy())
model_obs_stds.append(np.sqrt(info["var_obs"].copy()))
Expand All @@ -240,11 +260,26 @@ def visualize_eval(
axs[i].set_ylabel(f"Reward {i - obs_dim}")
axs[i].grid(alpha=0.25)
if w is not None:
axs[i].plot(x, [real_vec_rewards[step][i - obs_dim] for step in x], label="Environment", color="black")
axs[i].plot(
x,
[real_vec_rewards[step][i - obs_dim] for step in x],
label="Environment",
color="black",
)
else:
axs[i].plot(x, [real_rewards[step] for step in x], label="Environment", color="black")
axs[i].plot(
x,
[real_rewards[step] for step in x],
label="Environment",
color="black",
)
if model is not None:
axs[i].plot(x, [model_rewards[step][i - obs_dim] for step in x], label="Model", color="blue")
axs[i].plot(
x,
[model_rewards[step][i - obs_dim] for step in x],
label="Model",
color="blue",
)
axs[i].fill_between(
x,
[model_rewards[step][i - obs_dim] + model_rewards_stds[step][i - obs_dim] for step in x],
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/morl_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""MORL algorithm base classes."""

import os
import time
from abc import ABC, abstractmethod
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/pareto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pareto utilities."""

from copy import deepcopy
from typing import List, Union

Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/performance_indicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
We mostly rely on pymoo for the computation of axiomatic indicators (HV and IGD), but some are customly made.
"""

from copy import deepcopy
from typing import Callable, List

Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/prioritized_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Code adapted from https://github.com/sfujim/LAP-PAL
"""

import numpy as np
import torch as th

Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/scalarization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Scalarization functions relying on numpy."""

import numpy as np
from pymoo.decomposition.tchebicheff import Tchebicheff

Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""General utils for the MORL baselines."""

import math
import os
from typing import Callable, List
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/multi_policy/capql/capql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""CAPQL algorithm."""

import os
import random
from itertools import chain
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/multi_policy/envelope/envelope.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Envelope Q-Learning implementation."""

import os
from typing import List, Optional, Union
from typing_extensions import override
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/multi_policy/gpi_pd/gpi_pd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""GPI-PD algorithm."""

import os
import random
from itertools import chain
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""GPI-PD algorithm with continuous actions."""

import os
import random
from itertools import chain
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Linear Support implementation."""

import random
from copy import deepcopy
from typing import List, Optional
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/multi_policy/morld/morld.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
See Felten, Talbi & Danoy (2024): https://arxiv.org/abs/2311.12495.
"""

import math
import time
from typing import Callable, List, Optional, Tuple, Union
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outer-loop MOQ-learning algorithm (uses multiple weights)."""

import time
from copy import deepcopy
from typing import List, Optional
Expand Down
17 changes: 11 additions & 6 deletions morl_baselines/multi_policy/pareto_q_learning/pql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pareto Q-Learning."""

import numbers
from typing import Callable, List, Optional

Expand Down Expand Up @@ -60,19 +61,19 @@ def __init__(
# Algorithm setup
self.ref_point = ref_point

if type(self.env.action_space) == gym.spaces.Discrete:
if isinstance(self.env.action_space, gym.spaces.Discrete):
self.num_actions = self.env.action_space.n
elif type(self.env.action_space) == gym.spaces.MultiDiscrete:
elif isinstance(self.env.action_space, gym.spaces.MultiDiscrete):
self.num_actions = np.prod(self.env.action_space.nvec)
else:
raise Exception("PQL only supports (multi)discrete action spaces.")

if type(self.env.observation_space) == gym.spaces.Discrete:
if isinstance(self.env.observation_space, gym.spaces.Discrete):
self.env_shape = (self.env.observation_space.n,)
elif type(self.env.observation_space) == gym.spaces.MultiDiscrete:
elif isinstance(self.env.observation_space, gym.spaces.MultiDiscrete):
self.env_shape = self.env.observation_space.nvec
elif (
type(self.env.observation_space) == gym.spaces.Box
isinstance(self.env.observation_space, gym.spaces.Box)
and self.env.observation_space.is_bounded(manner="both")
and issubclass(self.env.observation_space.dtype.type, numbers.Integral)
):
Expand All @@ -96,7 +97,11 @@ def __init__(
self.log = log

if self.log:
self.setup_wandb(project_name=self.project_name, experiment_name=self.experiment_name, entity=wandb_entity)
self.setup_wandb(
project_name=self.project_name,
experiment_name=self.experiment_name,
entity=wandb_entity,
)

def get_config(self) -> dict:
"""Get the configuration dictionary.
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/multi_policy/pcn/pcn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pareto Conditioned Network. Code adapted from https://github.com/mathieu-reymond/pareto-conditioned-networks ."""

import heapq
import os
from abc import ABC
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/multi_policy/pgmorl/pgmorl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
(!) Limited to 2 objectives for now.
(!) The post-processing phase has not been implemented yet.
"""

import time
from copy import deepcopy
from typing import List, Optional, Tuple, Union
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/single_policy/esr/eupg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""EUPG is an ESR algorithm based on Policy Gradient (REINFORCE like)."""

import time
from copy import deepcopy
from typing import Callable, List, Optional, Union
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/single_policy/ser/mo_ppo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Multi-Objective PPO Algorithm."""

import time
from copy import deepcopy
from typing import List, Optional, Union
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/single_policy/ser/mo_q_learning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Scalarized Q-learning for single policy multi-objective reinforcement learning."""

import time
from typing import Optional
from typing_extensions import override
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ morl_baselines = ["*.json", "assets/*"]

# Linting, testing, ... ########################################################
[tool.black]
safe = true
line-length = 127
target-version = ['py38', 'py39', 'py310']
include = '\.pyi?$'
Expand Down
1 change: 1 addition & 0 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Mostly tests to make sure the algorithms are able to run."""

import time

import mo_gymnasium as mo_gym
Expand Down

0 comments on commit 2e8a18a

Please sign in to comment.