Update test assertions to use pytest.approx for float comparison
Float comparisons in test assertions have been updated to use pytest's approx function. Exact comparison of floats can lead to failures caused by slight variations in the least significant digits that are irrelevant to the test outcome. Using pytest's approx utility avoids these pitfalls and makes the tests more robust and reliable. It also standardizes the test assertions across all test cases in the project. Evaluation and assertion code for both the Markov Decision Process and the Reinforcement Learning algorithms has been updated accordingly.
nakashima-hikaru committed Nov 26, 2023
1 parent ed9c5fd commit 4b4e8a4
Showing 7 changed files with 83 additions and 80 deletions.
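For context on the change, here is a minimal sketch, not part of the commit and with a hypothetical test name, of why exact float equality fails where pytest.approx passes:

import pytest


def test_float_accumulation() -> None:
    # Summing 0.1 ten times yields 0.9999999999999999, not exactly 1.0,
    # because 0.1 has no exact binary floating-point representation.
    total = sum([0.1] * 10)
    assert total != 1.0  # an exact comparison fails
    # pytest.approx compares within a relative tolerance (1e-6 by default),
    # so noise in the least significant digits no longer breaks the assertion.
    assert total == pytest.approx(1.0)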
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -32,7 +32,7 @@ jobs:
         run: poetry run task mypy
       - name: Run unittests with coverage check
         run: |
-          poetry run pytest --cov --junitxml=pytest.xml --cov-report=term-missing:skip-covered | tee pytest-coverage.txt
+          poetry run pytest --cov --junitxml=pytest.xml --cov-report=term-missing:skip-covered && tee pytest-coverage.txt
       - name: Create Coverage Comment
         id: coverageComment
         uses: MishaKav/pytest-coverage-comment@main
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest

 from reinforcement_learning.markov_decision_process.grid_world.environment import Action, GridWorld
 from reinforcement_learning.markov_decision_process.grid_world.methods.monte_carlo.mc_agent import (
@@ -20,11 +21,11 @@ def test_mc_control_off_policy() -> None:
     run_monte_carlo_episode(env=env, agent=agent)

     assert agent.action_value == {
-        ((0, 2), Action.RIGHT): 0.19,
-        ((0, 2), Action.UP): 0.09000000000000001,
+        ((0, 2), Action.RIGHT): pytest.approx(0.19),
+        ((0, 2), Action.UP): pytest.approx(0.09000000000000001),
         ((0, 2), Action.DOWN): 0.0,
         ((0, 2), Action.LEFT): 0.0,
-        ((0, 1), Action.RIGHT): 0.09729729729729729,
+        ((0, 1), Action.RIGHT): pytest.approx(0.09729729729729729),
         ((0, 1), Action.UP): 0.0,
         ((0, 1), Action.DOWN): 0.0,
         ((0, 1), Action.LEFT): 0.0,
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest

 from reinforcement_learning.markov_decision_process.grid_world.environment import Action, GridWorld
 from reinforcement_learning.markov_decision_process.grid_world.methods.monte_carlo.mc_agent import (
@@ -20,24 +21,24 @@ def test_mc_control_on_policy() -> None:
     run_monte_carlo_episode(env=env, agent=agent)

     assert agent.action_value == {
-        ((0, 2), Action.RIGHT): 0.19,
-        ((0, 2), Action.UP): 0.09000000000000001,
+        ((0, 2), Action.RIGHT): pytest.approx(0.19),
+        ((0, 2), Action.UP): pytest.approx(0.09000000000000001),
         ((0, 2), Action.DOWN): 0.0,
-        ((0, 2), Action.LEFT): 0.05314410000000002,
-        ((0, 1), Action.RIGHT): 0.19865672100000004,
+        ((0, 2), Action.LEFT): pytest.approx(0.05314410000000002),
+        ((0, 1), Action.RIGHT): pytest.approx(0.19865672100000004),
         ((0, 1), Action.UP): 0.0,
         ((0, 1), Action.DOWN): 0.0,
-        ((0, 1), Action.LEFT): 0.05904900000000002,
-        ((0, 0), Action.RIGHT): 0.17879104890000003,
+        ((0, 1), Action.LEFT): pytest.approx(0.05904900000000002),
+        ((0, 0), Action.RIGHT): pytest.approx(0.17879104890000003),
         ((0, 0), Action.UP): 0.0,
         ((0, 0), Action.DOWN): 0.0,
-        ((0, 0), Action.LEFT): 0.06561000000000002,
-        ((1, 0), Action.UP): 0.10776784401000003,
+        ((0, 0), Action.LEFT): pytest.approx(0.06561000000000002),
+        ((1, 0), Action.UP): pytest.approx(0.10776784401000003),
         ((1, 0), Action.DOWN): 0.0,
         ((1, 0), Action.LEFT): 0.0,
         ((1, 0), Action.RIGHT): 0.0,
-        ((2, 0), Action.UP): 0.09699105960900004,
-        ((2, 0), Action.DOWN): 0.03138105960900002,
-        ((2, 0), Action.LEFT): 0.02824295364810002,
+        ((2, 0), Action.UP): pytest.approx(0.09699105960900004),
+        ((2, 0), Action.DOWN): pytest.approx(0.03138105960900002),
+        ((2, 0), Action.LEFT): pytest.approx(0.02824295364810002),
         ((2, 0), Action.RIGHT): 0.0,
     }
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest

 from reinforcement_learning.markov_decision_process.grid_world.environment import GridWorld
 from reinforcement_learning.markov_decision_process.grid_world.methods.monte_carlo.mc_agent import (
@@ -18,14 +19,14 @@ def test_mc_eval() -> None:
     run_monte_carlo_episode(env=env, agent=agent)

     assert agent.state_value == {
-        (0, 2): 0.22042382796141025,
-        (0, 1): 0.05544759280164322,
-        (0, 0): 0.11230028389108984,
-        (1, 0): 0.018681242818946488,
-        (2, 0): 0.07457318402670388,
-        (1, 3): 0.44999999999999996,
-        (1, 2): -0.5866904341957879,
-        (2, 2): -0.38389590313967503,
-        (2, 1): -0.1588166194308249,
-        (2, 3): -0.47275257189285014,
+        (0, 2): pytest.approx(0.22042382796141025),
+        (0, 1): pytest.approx(0.05544759280164322),
+        (0, 0): pytest.approx(0.11230028389108984),
+        (1, 0): pytest.approx(0.018681242818946488),
+        (2, 0): pytest.approx(0.07457318402670388),
+        (1, 3): pytest.approx(0.44999999999999996),
+        (1, 2): pytest.approx(-0.5866904341957879),
+        (2, 2): pytest.approx(-0.38389590313967503),
+        (2, 1): pytest.approx(-0.1588166194308249),
+        (2, 3): pytest.approx(-0.47275257189285014),
     }
@@ -27,52 +27,52 @@ def test_q_learning() -> None:
     expected = 0.1085729799230003
     assert agent.average_loss == expected
     assert agent.action_value == {
-        ((0, 0), Action.UP): 0.012391064316034317,
-        ((0, 0), Action.DOWN): 0.03795982897281647,
-        ((0, 0), Action.LEFT): -0.024705907329916954,
-        ((0, 0), Action.RIGHT): 0.122371606528759,
-        ((0, 1), Action.UP): 0.02835841476917267,
-        ((0, 1), Action.DOWN): 0.03410963714122772,
-        ((0, 1), Action.LEFT): -0.06401852518320084,
-        ((0, 1), Action.RIGHT): 0.08941204845905304,
-        ((0, 2), Action.UP): 0.13887614011764526,
-        ((0, 2), Action.DOWN): -0.12020474672317505,
-        ((0, 2), Action.LEFT): -0.09385768324136734,
-        ((0, 2), Action.RIGHT): 0.14876116812229156,
-        ((0, 3), Action.UP): 0.08330991864204407,
-        ((0, 3), Action.DOWN): -0.05502215027809143,
-        ((0, 3), Action.LEFT): -0.024705795571208,
-        ((0, 3), Action.RIGHT): 0.07155784964561462,
-        ((1, 0), Action.UP): 0.08537131547927856,
-        ((1, 0), Action.DOWN): 0.030280638486146927,
-        ((1, 0), Action.LEFT): -0.09365365654230118,
-        ((1, 0), Action.RIGHT): 0.055440232157707214,
-        ((1, 1), Action.UP): 0.12489913403987885,
-        ((1, 1), Action.DOWN): -0.06309452652931213,
-        ((1, 1), Action.LEFT): -0.06900961697101593,
-        ((1, 1), Action.RIGHT): 0.17078934609889984,
-        ((1, 2), Action.UP): 0.03831055760383606,
-        ((1, 2), Action.DOWN): -0.01612631231546402,
-        ((1, 2), Action.LEFT): -0.06555116921663284,
-        ((1, 2), Action.RIGHT): 0.028176531195640564,
-        ((1, 3), Action.UP): 0.10358510911464691,
-        ((1, 3), Action.DOWN): -0.08015541732311249,
-        ((1, 3), Action.LEFT): -0.005136379972100258,
-        ((1, 3), Action.RIGHT): 0.16100265085697174,
-        ((2, 0), Action.UP): 0.08069709688425064,
-        ((2, 0), Action.DOWN): 0.005144223570823669,
-        ((2, 0), Action.LEFT): -0.12853993475437164,
-        ((2, 0), Action.RIGHT): 0.038303472101688385,
-        ((2, 1), Action.UP): 0.16602256894111633,
-        ((2, 1), Action.DOWN): -0.016644027084112167,
-        ((2, 1), Action.LEFT): -0.058554477989673615,
-        ((2, 1), Action.RIGHT): 0.15737693011760712,
-        ((2, 2), Action.UP): 0.06638707220554352,
-        ((2, 2), Action.DOWN): 0.05980466678738594,
-        ((2, 2), Action.LEFT): 0.05435100942850113,
-        ((2, 2), Action.RIGHT): 0.1403900384902954,
-        ((2, 3), Action.UP): 0.048222266137599945,
-        ((2, 3), Action.DOWN): -0.024945009499788284,
-        ((2, 3), Action.LEFT): -0.166785329580307,
-        ((2, 3), Action.RIGHT): 0.13613100349903107,
+        ((0, 0), Action.UP): pytest.approx(0.012391064316034317),
+        ((0, 0), Action.DOWN): pytest.approx(0.03795982897281647),
+        ((0, 0), Action.LEFT): pytest.approx(-0.024705907329916954),
+        ((0, 0), Action.RIGHT): pytest.approx(0.122371606528759),
+        ((0, 1), Action.UP): pytest.approx(0.02835841476917267),
+        ((0, 1), Action.DOWN): pytest.approx(0.03410963714122772),
+        ((0, 1), Action.LEFT): pytest.approx(-0.06401852518320084),
+        ((0, 1), Action.RIGHT): pytest.approx(0.08941204845905304),
+        ((0, 2), Action.UP): pytest.approx(0.13887614011764526),
+        ((0, 2), Action.DOWN): pytest.approx(-0.12020474672317505),
+        ((0, 2), Action.LEFT): pytest.approx(-0.09385768324136734),
+        ((0, 2), Action.RIGHT): pytest.approx(0.14876116812229156),
+        ((0, 3), Action.UP): pytest.approx(0.08330991864204407),
+        ((0, 3), Action.DOWN): pytest.approx(-0.05502215027809143),
+        ((0, 3), Action.LEFT): pytest.approx(-0.024705795571208),
+        ((0, 3), Action.RIGHT): pytest.approx(0.07155784964561462),
+        ((1, 0), Action.UP): pytest.approx(0.08537131547927856),
+        ((1, 0), Action.DOWN): pytest.approx(0.030280638486146927),
+        ((1, 0), Action.LEFT): pytest.approx(-0.09365365654230118),
+        ((1, 0), Action.RIGHT): pytest.approx(0.055440232157707214),
+        ((1, 1), Action.UP): pytest.approx(0.12489913403987885),
+        ((1, 1), Action.DOWN): pytest.approx(-0.06309452652931213),
+        ((1, 1), Action.LEFT): pytest.approx(-0.06900961697101593),
+        ((1, 1), Action.RIGHT): pytest.approx(0.17078934609889984),
+        ((1, 2), Action.UP): pytest.approx(0.03831055760383606),
+        ((1, 2), Action.DOWN): pytest.approx(-0.01612631231546402),
+        ((1, 2), Action.LEFT): pytest.approx(-0.06555116921663284),
+        ((1, 2), Action.RIGHT): pytest.approx(0.028176531195640564),
+        ((1, 3), Action.UP): pytest.approx(0.10358510911464691),
+        ((1, 3), Action.DOWN): pytest.approx(-0.08015541732311249),
+        ((1, 3), Action.LEFT): pytest.approx(-0.005136379972100258),
+        ((1, 3), Action.RIGHT): pytest.approx(0.16100265085697174),
+        ((2, 0), Action.UP): pytest.approx(0.08069709688425064),
+        ((2, 0), Action.DOWN): pytest.approx(0.005144223570823669),
+        ((2, 0), Action.LEFT): pytest.approx(-0.12853993475437164),
+        ((2, 0), Action.RIGHT): pytest.approx(0.038303472101688385),
+        ((2, 1), Action.UP): pytest.approx(0.16602256894111633),
+        ((2, 1), Action.DOWN): pytest.approx(-0.016644027084112167),
+        ((2, 1), Action.LEFT): pytest.approx(-0.058554477989673615),
+        ((2, 1), Action.RIGHT): pytest.approx(0.15737693011760712),
+        ((2, 2), Action.UP): pytest.approx(0.06638707220554352),
+        ((2, 2), Action.DOWN): pytest.approx(0.05980466678738594),
+        ((2, 2), Action.LEFT): pytest.approx(0.05435100942850113),
+        ((2, 2), Action.RIGHT): pytest.approx(0.1403900384902954),
+        ((2, 3), Action.UP): pytest.approx(0.048222266137599945),
+        ((2, 3), Action.DOWN): pytest.approx(-0.024945009499788284),
+        ((2, 3), Action.LEFT): pytest.approx(-0.166785329580307),
+        ((2, 3), Action.RIGHT): pytest.approx(0.13613100349903107),
     }
@@ -43,11 +43,11 @@ def test_q_learning() -> None:
         ((0, 1), Action.UP): 0.0,
         ((0, 1), Action.DOWN): 0.0,
         ((0, 1), Action.LEFT): 0.0,
-        ((0, 1), Action.RIGHT): 0.5760000000000001,
+        ((0, 1), Action.RIGHT): pytest.approx(0.5760000000000001),
         ((0, 2), Action.UP): 0.0,
         ((0, 2), Action.DOWN): 0.0,
         ((0, 2), Action.LEFT): 0.0,
-        ((0, 2), Action.RIGHT): 0.96,
+        ((0, 2), Action.RIGHT): pytest.approx(0.96),
         ((1, 2), Action.UP): 0.0,
         ((1, 2), Action.DOWN): 0.0,
         ((1, 2), Action.LEFT): 0.0,
@@ -29,10 +29,10 @@ def test_td_evaluation() -> None:
     run_td_episode(env=env, agent=agent)

     assert agent.v == {
-        (2, 1): -8.100000000000001e-07,
+        (2, 1): pytest.approx(-8.100000000000001e-07),
         (2, 0): 0.0,
-        (2, 2): -8.901090000000001e-05,
-        (1, 2): -0.0199891,
-        (1, 3): 0.01019701,
-        (2, 3): -0.009998209000000001,
+        (2, 2): pytest.approx(-8.901090000000001e-05),
+        (1, 2): pytest.approx(-0.0199891),
+        (1, 3): pytest.approx(0.01019701),
+        (2, 3): pytest.approx(-0.009998209000000001),
     }
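As an aside, pytest.approx can also wrap an entire dictionary of numeric values and compare every entry within tolerance at once; the assertions above wrap each value individually instead. A minimal sketch with made-up numbers, not taken from these tests:

import pytest

# Both sides must have the same keys; each value is compared within tolerance.
observed = {(2, 1): -8.1000000001e-07, (2, 0): 0.0}
assert observed == pytest.approx({(2, 1): -8.1e-07, (2, 0): 0.0})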
