diff --git a/examples/rl/rl_example_11/rl_example_11.md b/examples/rl/rl_example_11/rl_example_11.md
index 6da5b76..2c480db 100644
--- a/examples/rl/rl_example_11/rl_example_11.md
+++ b/examples/rl/rl_example_11/rl_example_11.md
@@ -1,123 +1,315 @@
-# Example 11: An A2C Solver for _CartPole_
+# Example 11: An A2C Solver for CartPole
-In this example we will develop a policy gradient RL solver in order to
-use on the CartPole environment. In particular, we will develop an actor-critic method
-knonw as A2C.
+Example 13: REINFORCE algorithm on CartPole
+introduced the vanilla REINFORCE algorithm in order to solve the pole balancing problem.
+The REINFORCE algorithm is simple and easy to use; however, it exhibits a large variance in the reward signal.
+In this example, we introduce actor-critic methods
+and specifically the A2C algorithm. The main objective of actor-critic methods is to further reduce this high gradient variance.
+One step in this direction is to use the so-called reward-to-go term for a trajectory of length $T$:
+
+$$G_t = \sum_{k = t}^{T} R(s_{k}, a_{k})$$
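+
+As a concrete illustration, the reward-to-go terms can be accumulated with a single
+backward sweep over the rewards collected in an episode. The snippet below is a minimal,
+self-contained sketch that does not use any cubeai types; it also accepts a discount
+factor $\gamma$, and setting $\gamma = 1$ recovers the undiscounted sum above.
+
+```cpp
+#include <vector>
+
+// Compute the reward-to-go G_t for every time step of an episode.
+// rewards[t] is the reward received after taking action a_t in state s_t.
+std::vector<double> rewards_to_go(const std::vector<double>& rewards, double gamma){
+
+    std::vector<double> returns(rewards.size(), 0.0);
+    double running = 0.0;
+
+    // backward sweep: G_t = r_t + gamma * G_{t+1}
+    for(int t = static_cast<int>(rewards.size()) - 1; t >= 0; --t){
+        running = rewards[t] + gamma * running;
+        returns[t] = running;
+    }
+    return returns;
+}
+```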
For this example to work, you will need to have both PyTorch and ```rlenvs_cpp``` enabled.
+The workings of the A2C algorithm are handled by the A2CSolver
+class. The class accepts three template parameters:
+
+- The environment type
+- The policy type or the actor type
+- The critic type
+
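+A sketch of how these three parameters come together when the solver type is declared
+is shown below; the parameter order is an assumption here, and the actual declaration
+appears in the driver code further down.
+
+```cpp
+// environment, actor and critic are supplied as template parameters
+typedef A2CSolver<env_type, ActorNet, CriticNet> solver_type;
+```
+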
+The role of the actor is to select an action for the agent to take.
+The role of the critic is to tell us whether that action was good or bad.
+Of course, we could use the raw reward signal for this assessment,
+but this may not work for environments where the reward is sparse
+or where the reward signal is the same for most actions.
+
+Given this approach, the critic's estimate will appear as a term in the actor's loss function,
+whilst the critic itself will learn directly from the provided reward signals.
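+
+To make this relationship concrete, the sketch below shows how the two loss terms could
+be formed with libtorch. It is only an illustration of the idea, not the loss computation
+performed by the A2CSolver class, and the helper function name is hypothetical; the
+log-probabilities, values and reward-to-go returns are assumed to come from one rollout.
+
+```cpp
+#include <torch/torch.h>
+#include <utility>
+
+// Illustration only: form the actor and critic losses for one rollout.
+std::pair<torch::Tensor, torch::Tensor>
+a2c_losses(torch::Tensor log_probs, torch::Tensor values, torch::Tensor returns){
+
+    // advantage: how much better the taken actions were than the critic expected
+    auto advantages = returns - values;
+
+    // actor loss: the critic enters only through the (detached) advantage term
+    auto actor_loss = -(log_probs * advantages.detach()).mean();
+
+    // critic loss: the critic learns directly from the observed returns
+    auto critic_loss = torch::mse_loss(values, returns);
+
+    return {actor_loss, critic_loss};
+}
+```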
+
+In the code below, the critic and actor networks do not share any layers.
+This need not be the case, however. You can come up with an implementation
+where the two networks share most of the layers and differ only
+at the output layer, as sketched below.
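+
+For instance, a shared-trunk variant could look roughly as follows. The class is
+hypothetical and is not used in this example; it only illustrates the idea of a common
+trunk with separate policy and value heads.
+
+```cpp
+// Hypothetical shared-trunk actor-critic module (not used in this example)
+class ActorCriticNetImpl: public torch::nn::Module
+{
+public:
+
+    ActorCriticNetImpl(uint_t state_size, uint_t action_size)
+    :
+    torch::nn::Module(),
+    linear1_(nullptr), linear2_(nullptr), policy_head_(nullptr), value_head_(nullptr)
+    {
+        linear1_     = register_module("linear1_", torch::nn::Linear(state_size, 128));
+        linear2_     = register_module("linear2_", torch::nn::Linear(128, 256));
+        policy_head_ = register_module("policy_head_", torch::nn::Linear(256, action_size));
+        value_head_  = register_module("value_head_", torch::nn::Linear(256, 1));
+    }
+
+    // returns the action probabilities and the state value from the shared trunk
+    std::pair<torch_tensor_t, torch_tensor_t> forward(torch_tensor_t state){
+        auto hidden = torch::relu(linear1_(state));
+        hidden      = torch::relu(linear2_(hidden));
+        auto probs  = torch::softmax(policy_head_(hidden), -1);
+        auto value  = value_head_(hidden);
+        return {probs, value};
+    }
+
+private:
+
+    torch::nn::Linear linear1_;
+    torch::nn::Linear linear2_;
+    torch::nn::Linear policy_head_;
+    torch::nn::Linear value_head_;
+};
+```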
+
+
+### The Actor network class
+
+```cpp
+// the Actor network
+class ActorNetImpl: public torch::nn::Module
+{
+public:
+
+ // constructor
+ ActorNetImpl(uint_t state_size, uint_t action_size);
+
+ torch_tensor_t forward(torch_tensor_t state);
+ torch_tensor_t log_probabilities(torch_tensor_t actions);
+ torch_tensor_t sample();
+
+private:
+
+ torch::nn::Linear linear1_;
+ torch::nn::Linear linear2_;
+ torch::nn::Linear linear3_;
+
+ // the underlying distribution used to sample actions
+ TorchCategorical distribution_;
+};
+
+ActorNetImpl::ActorNetImpl(uint_t state_size, uint_t action_size)
+:
+torch::nn::Module(),
+linear1_(nullptr),
+linear2_(nullptr),
+linear3_(nullptr)
+{
+ linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128));
+ linear2_ = register_module("linea2_", torch::nn::Linear(128, 256));
+ linear3_ = register_module("linear3_", torch::nn::Linear(256, action_size));
+}
+
+
+torch_tensor_t
+ActorNetImpl::forward(torch_tensor_t state){
+ auto output = torch::nn::functional::relu(linear1_(state));
+ output = torch::nn::functional::relu(linear2_(output));
+ output = linear3_(output);
+ const torch_tensor_t probs = torch::nn::functional::softmax(output,-1);
+ distribution_.build_from_probabilities(probs);
+ return probs;
+}
+
+torch_tensor_t
+ActorNetImpl::sample(){
+ return distribution_.sample();
+}
+
+torch_tensor_t
+ActorNetImpl::log_probabilities(torch_tensor_t actions){
+ return distribution_.log_prob(actions);
+}
+```
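+
+Once the TORCH_MODULE wrapper for this class has been generated (this happens in the
+driver code below), the actor could be exercised roughly as in the following sketch.
+The state and action dimensions used here are the usual CartPole ones and are only
+shown for illustration.
+
+```cpp
+// CartPole: 4 state variables, 2 discrete actions
+ActorNet policy(4, 2);
+
+// a dummy, batched state observation
+auto state = torch::rand({1, 4});
+
+// forward builds the categorical distribution and returns the action probabilities
+auto probs = policy->forward(state);
+
+// sample an action and query its log-probability
+auto action   = policy->sample();
+auto log_prob = policy->log_probabilities(action);
+```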
+
+
## The driver code
The driver code for this tutorial is shown below.
```cpp
-/**
- * This example illustrates a simple example of Monte Carlo
- * iteration using the IterationCounter class
- *
- * */
+#include "cubeai/base/cubeai_config.h"
+
+#if defined(USE_PYTORCH) && defined(USE_RLENVS_CPP)
#include "cubeai/base/cubeai_types.h"
-#include "cubeai/utils/iteration_counter.h"
-#include "cubeai/geom_primitives/shapes/circle.h"
-#include "cubeai/extern/nlohmann/json/json.hpp"
+#include "cubeai/maths/statistics/distributions/torch_categorical.h"
+#include "cubeai/rl/trainers/rl_serial_agent_trainer.h"
+#include "cubeai/rl/algorithms/actor_critic/a2c.h"
+#include "cubeai/maths/optimization/optimizer_type.h"
+#include "cubeai/maths/optimization/pytorch_optimizer_factory.h"
+#include "rlenvs/envs/gymnasium/classic_control/cart_pole_env.h"
-#include
#include
-#include
-#include
+#include
+#include
-namespace intro_example_1
-{
+namespace rl_example_11{
+
+const std::string SERVER_URL = "http://0.0.0.0:8001/api";
using cubeai::real_t;
using cubeai::uint_t;
-using cubeai::utils::IterationCounter;
-using cubeai::geom_primitives::Circle;
+using cubeai::torch_tensor_t;
+using cubeai::maths::stats::TorchCategorical;
+using cubeai::rl::algos::ac::A2CConfig;
+using cubeai::rl::algos::ac::A2CSolver;
+using cubeai::rl::RLSerialAgentTrainer;
+using cubeai::rl::RLSerialTrainerConfig;
+using rlenvs_cpp::envs::gymnasium::CartPoleActionsEnum;
-using json = nlohmann::json;
-const std::string CONFIG = "config.json";
+// create the Action and the Critic networks
+class ActorNetImpl: public torch::nn::Module
+{
+public:
-// read the JSON file
-json
-load_config(const std::string& filename){
+ // constructor
+ ActorNetImpl(uint_t state_size, uint_t action_size);
- std::ifstream f(filename);
- json data = json::parse(f);
- return data;
+
+ torch_tensor_t forward(torch_tensor_t state);
+ torch_tensor_t log_probabilities(torch_tensor_t actions);
+ torch_tensor_t sample();
+
+
+private:
+
+ torch::nn::Linear linear1_;
+ torch::nn::Linear linear2_;
+ torch::nn::Linear linear3_;
+
+ // the underlying distribution used to sample actions
+ TorchCategorical distribution_;
+};
+
+ActorNetImpl::ActorNetImpl(uint_t state_size, uint_t action_size)
+:
+torch::nn::Module(),
+linear1_(nullptr),
+linear2_(nullptr),
+linear3_(nullptr)
+{
+ linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128));
+ linear2_ = register_module("linea2_", torch::nn::Linear(128, 256));
+ linear3_ = register_module("linear3_", torch::nn::Linear(256, action_size));
}
+torch_tensor_t
+ActorNetImpl::forward(torch_tensor_t state){
+
+ auto output = torch::nn::functional::relu(linear1_(state));
+ output = torch::nn::functional::relu(linear2_(output));
+ output = linear3_(output);
+ const torch_tensor_t probs = torch::nn::functional::softmax(output,-1);
+ distribution_.build_from_probabilities(probs);
+ return probs;
}
-int main() {
+torch_tensor_t
+ActorNetImpl::sample(){
+ return distribution_.sample();
+}
- using namespace intro_example_1;
+torch_tensor_t
+ActorNetImpl::log_probabilities(torch_tensor_t actions){
+ return distribution_.log_prob(actions);
+}
+
+
+class CriticNetImpl: public torch::nn::Module
+{
+public:
+
+ // constructor
+ CriticNetImpl(uint_t state_size);
+
+ torch_tensor_t forward(torch_tensor_t state);
+
+private:
+
+ torch::nn::Linear linear1_;
+ torch::nn::Linear linear2_;
+ torch::nn::Linear linear3_;
+
+};
+
+
+CriticNetImpl::CriticNetImpl(uint_t state_size)
+:
+torch::nn::Module(),
+linear1_(nullptr),
+linear2_(nullptr),
+linear3_(nullptr)
+{
+ linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128));
+ linear2_ = register_module("linea2_", torch::nn::Linear(128, 256));
+ linear3_ = register_module("linear3_", torch::nn::Linear(256, 1));
+}
+
+torch_tensor_t
+CriticNetImpl::forward(torch_tensor_t state){
+
+ auto output = torch::nn::functional::relu(linear1_(state));
+ output = torch::nn::functional::relu(linear2_(output));
+ output = linear3_(output);
+ return output;
+}
+
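+// generate the ActorNet and CriticNet module holders (shared_ptr-style wrappers)
+// for the Impl classes above, following the usual libtorch naming convention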
+TORCH_MODULE(ActorNet);
+TORCH_MODULE(CriticNet);
+
+
+typedef rlenvs_cpp::envs::gymnasium::CartPole env_type;
+
+
+}
+
+
+int main(){
+
+ using namespace rl_example_11;
try{
- BOOST_LOG_TRIVIAL(info)<<"Reading configuration file...";
-
- // load the json configuration
- auto data = load_config(CONFIG);
-
- // read properties from the configuration
- const auto R = data["R"].template get<real_t>();
- const auto N_POINTS = data["N_POINTS"].template get<uint_t>();
- const auto SEED = data["SEED"].template get<uint_t>();
- const auto X = data["X"].template get<real_t>();
- const auto Y = data["Y"].template get<real_t>();
-
- // create a circle
- Circle c(R, {X, Y});
-
- // simple object to control iterations
- IterationCounter counter(N_POINTS);
-
- // how many points we found in the Circle
- auto points_inside_circle = 0;
-
- // the box has side 2
- const real_t SQUARE_SIDE = R*2.0;
- std::uniform_real_distribution<real_t> dist(0.0,SQUARE_SIDE);
- std::mt19937 gen(SEED);
-
- BOOST_LOG_TRIVIAL(info)<<"Starting computation...";
- while(counter.continue_iterations()){
- auto x = dist(gen);
- auto y = dist(gen);
- if(c.is_inside(x,y, 1.0e-4)){
- points_inside_circle += 1;
- }
- }
-
- BOOST_LOG_TRIVIAL(info)<<"Finished computation...";
- auto area = (static_cast<real_t>(points_inside_circle) / static_cast<real_t>(N_POINTS)) * std::pow(SQUARE_SIDE, 2);
- BOOST_LOG_TRIVIAL(info)<<"Circle area calculated with:"<<area;
+
+ std::cout<<"Creating the environment..."<<std::endl;
+
+ // ... create the CartPole environment (env), the A2C configuration (a2c_config),
+ // the actor (policy) and critic (critic) networks and their optimizers
+ // (policy_optimizer, critic_optimizer) ...
+
+ typedef A2CSolver<env_type, ActorNet, CriticNet> solver_type;
+
+ solver_type solver(a2c_config, policy, critic,
+ policy_optimizer, critic_optimizer);
+
+ RLSerialTrainerConfig config;
+ RLSerialAgentTrainer trainer(config, solver);
+ trainer.train(env);
+
}
catch(std::exception& e){
- BOOST_LOG_TRIVIAL(error)<
- return 0;
+int main(){
+
+ std::cout<<"This example requires the flag USE_RLENVS_CPP to be true."<EXample 12: DQN algorithm on Gridworld
-and EXample 15: DQN algorithm on Gridworld with experience replay approximate a value function
+The DQN algorithm that we used in Example 12: DQN algorithm on Gridworld
+and Example 15: DQN algorithm on Gridworld with experience replay approximates a value function
and in particular the Q state-action value function.
However, in reinforcement learning we are more interested in policies since a policy dictates how
an agent behaves in a given state.
diff --git a/include/cubeai/rl/actions/action_space.h b/include/cubeai/rl/actions/action_space.h
deleted file mode 100644
index 949b5b9..0000000
--- a/include/cubeai/rl/actions/action_space.h
+++ /dev/null
@@ -1,53 +0,0 @@
-#ifndef ACTION_SPACE_H
-#define ACTION_SPACE_H
-
-#include "cubeai/base/cubeai_config.h"
-#include "cubeai/base/cubeai_types.h"
-
-#include
-#include
-
-namespace cubeai {
-namespace rl {
-namespace actions {
-
-struct ActionSpace
-{
- std::string type;
- std::vector shape;
-
- ///
- /// \brief ActionSpace
- ///
- ActionSpace(const std::string& t, std::vector& sh)
- :
- type(t),
- shape(sh)
- {}
-};
-
-#ifdef USE_PYTORCH
-struct TorchActionSpace
-{
- std::string type;
- std::vector shape;
-
- ///
- /// \brief TorchActionSpace
- ///
- TorchActionSpace(const std::string& t, std::vector& sh)
- :
- type(t),
- shape(sh)
- {}
-
-};
-#endif
-
-}
-
-}
-
-}
-
-#endif // ACTION_SPACE_H
diff --git a/include/cubeai/rl/worlds/discrete_world.h b/include/cubeai/rl/worlds/discrete_world.h
deleted file mode 100644
index 456f9d0..0000000
--- a/include/cubeai/rl/worlds/discrete_world.h
+++ /dev/null
@@ -1,79 +0,0 @@
-#ifndef DISCRETE_WORLD_H
-#define DISCRETE_WORLD_H
-
-#include "cubeai/base/cubeai_types.h"
-#include "cubeai/rl/worlds/world_base.h"
-
-#include
-#include
-#include
-
-namespace cubeai{
-namespace rl{
-namespace envs {
-
-
-///
-/// \brief The DiscreteWorldBase class
-///
-template
-class DiscreteWorldBase: public WorldBase
-{
-public:
-
- ///
- /// \brief state_t
- ///
- typedef typename WorldBase::state_type state_type;
-
- ///
- /// \brief action_t
- ///
- typedef typename WorldBase::action_type action_type;
-
- ///
- /// \brief time_step_t
- ///
- typedef typename WorldBase::time_step_type time_step_type;
-
- ///
- /// \brief ~DiscreteWorldBase. Destructor
- ///
- virtual ~DiscreteWorldBase() = default;
-
- ///
- /// \brief n_actions
- ///
- virtual uint_t n_actions()const = 0;
-
- ///
- /// \brief n_states
- ///
- virtual uint_t n_states()const = 0;
-
- ///
- /// \brief transition_dynamics
- ///
- virtual std::vector> transition_dynamics(uint_t s, uint_t aidx)const {};
-
-protected:
-
- ///
- /// \brief DiscreteWorldBase
- /// \param name
- ///
- DiscreteWorldBase(std::string name);
-
-};
-
-template
-DiscreteWorldBase::DiscreteWorldBase(std::string name)
- :
- WorldBase(name)
-{}
-
-}
-}
-}
-
-#endif // DISCRETE_WORLD_H
diff --git a/include/cubeai/rl/worlds/world_base.h b/include/cubeai/rl/worlds/world_base.h
deleted file mode 100644
index 48b8ee1..0000000
--- a/include/cubeai/rl/worlds/world_base.h
+++ /dev/null
@@ -1,130 +0,0 @@
-#ifndef WORLD_BASE_H
-#define WORLD_BASE_H
-
-#include "cubeai/base/cubeai_types.h"
-#include
-
-#include
-#include
-#include
-#include
-
-namespace cubeai {
-namespace rl {
-namespace envs{
-
-
-///
-/// \brief Base minimal class for RL environments
-///
-template
-class WorldBase: private boost::noncopyable
-{
-
-public:
-
- ///
- /// \brief state_t
- ///
- typedef StateTp state_type;
-
- ///
- /// \brief action_t
- ///
- typedef ActionTp action_type;
-
- ///
- /// \brief time_step_t
- ///
- typedef TimeStepTp time_step_type;
-
- ///
- /// \brief ~WorldBase. Destructor
- ///
- virtual ~WorldBase() = default;
-
- ///
- /// \brief Transition to a new state by
- /// performing the given action. It returns a tuple
- /// with the following information
- /// arg1: An observation of the environment.
- /// arg2: Amount of reward achieved by the previous action.
- /// arg3: Flag indicating whether it’s time to reset the environment again.
- /// Most (but not all) tasks are divided up into well-defined episodes,
- /// and done being True indicates the episode has terminated.
- /// (For example, perhaps the pole tipped too far, or you lost your last life.)
- /// arg4: The type depends on the subclass overriding this function
- /// diagnostic information useful for debugging. It can sometimes be useful for
- /// learning (for example, it might contain the raw probabilities behind the environment’s last state change).
- /// However, official evaluations of your agent are not allowed to use this for learning.
- ///
- virtual time_step_type step(const action_type&)=0;
-
- ///
- /// \brief restart. Restart the world and
- /// return the starting state
- ///
- virtual time_step_type reset()=0;
-
- ///
- /// \brief Build the world
- ///
- virtual void build(bool reset)=0;
-
- ///
- /// \brief n_copies Returns the number of copies of the environment
- ///
- virtual uint_t n_copies()const {return 1;}
-
- ///
- /// \brief name
- ///
- std::string name()const noexcept{return name_;}
-
- ///
- /// \brief is_built
- /// \return
- ///
- bool is_built()const noexcept{return is_built_;}
-
- ///
- /// \brief make_is_built
- ///
- void make_is_built()noexcept{is_built_ = true;}
-
-
-protected:
-
- ///
- /// \brief WorldBase
- /// \param name
- ///
- WorldBase(const std::string& name);
-
- ///
- /// \brief name_
- ///
- std::string name_;
-
- ///
- /// \brief is_built_
- ///
- bool is_built_{false};
-
-};
-
-template
-WorldBase::WorldBase(const std::string& name)
- :
- name_(name),
- is_built_(false)
-{}
-
-
-}
-
-}
-
-}
-
-#endif // WORLD_BASE_H
diff --git a/src/cubeai/rl/worlds/discrete_world.h b/src/cubeai/rl/worlds/discrete_world.h
deleted file mode 100644
index 34fcbb1..0000000
--- a/src/cubeai/rl/worlds/discrete_world.h
+++ /dev/null
@@ -1,80 +0,0 @@
-#ifndef DISCRETE_WORLD_H
-#define DISCRETE_WORLD_H
-
-#include "cubic_engine/base/cubic_engine_types.h"
-#include "cubic_engine/rl/worlds/world_base.h"
-
-#include
-#include
-#include
-
-namespace cengine{
-namespace rl{
-namespace envs {
-
-class TimeStep;
-
-///
-/// \brief The DiscreteWorldBase class
-///
-template
-class DiscreteWorldBase: public WorldBase
-{
-public:
-
- ///
- /// \brief state_t
- ///
- typedef typename WorldBase::state_t state_t;
-
- ///
- /// \brief action_t
- ///
- typedef typename WorldBase::action_t action_t;
-
- ///
- /// \brief time_step_t
- ///
- typedef typename WorldBase::time_step_t time_step_t;
-
- ///
- ///
- ///
- virtual ~DiscreteWorldBase() = default;
-
- ///
- /// \brief n_actions
- ///
- virtual uint_t n_actions()const = 0;
-
- ///
- /// \brief n_actions
- ///
- virtual uint_t n_states()const = 0;
-
- ///
- /// \brief transition_dynamics
- ///
- virtual std::vector> transition_dynamics(uint_t s, uint_t aidx)const = 0;
-
-protected:
-
- ///
- /// \brief DiscreteWorldBase
- /// \param name
- ///
- DiscreteWorldBase(std::string name);
-
-};
-
-template
-DiscreteWorldBase::DiscreteWorldBase(std::string name)
- :
- WorldBase(name)
-{}
-
-}
-}
-}
-
-#endif // DISCRETE_WORLD_H
diff --git a/src/cubeai/rl/worlds/world_base.h b/src/cubeai/rl/worlds/world_base.h
deleted file mode 100644
index 97cefbf..0000000
--- a/src/cubeai/rl/worlds/world_base.h
+++ /dev/null
@@ -1,130 +0,0 @@
-#ifndef WORLD_BASE_H
-#define WORLD_BASE_H
-
-#include "cubic_engine/base/cubic_engine_types.h"
-#include
-
-#include
-#include
-#include
-#include
-
-namespace cengine {
-namespace rl {
-namespace envs{
-
-
-///
-///
-///
-template
-class WorldBase: private boost::noncopyable
-{
-
-public:
-
- ///
- /// \brief state_t
- ///
- typedef StateTp state_t;
-
- ///
- /// \brief action_t
- ///
- typedef ActionTp action_t;
-
- ///
- /// \brief time_step_t
- ///
- typedef TimeStepTp time_step_t;
-
- ///
- /// \brief Destructor
- ///
- virtual ~WorldBase() = default;
-
- ///
- /// \brief Transition to a new state by
- /// performing the given action. It returns a tuple
- /// with the following information
- /// arg1: An observation of the environment.
- /// arg2: Amount of reward achieved by the previous action.
- /// arg3: Flag indicating whether it’s time to reset the environment again.
- /// Most (but not all) tasks are divided up into well-defined episodes,
- /// and done being True indicates the episode has terminated.
- /// (For example, perhaps the pole tipped too far, or you lost your last life.)
- /// arg4: The type depends on the subclass overriding this function
- /// diagnostic information useful for debugging. It can sometimes be useful for
- /// learning (for example, it might contain the raw probabilities behind the environment’s last state change).
- /// However, official evaluations of your agent are not allowed to use this for learning.
- ///
- virtual time_step_t on_episode(const action_t&)=0;
-
- ///
- /// \brief restart. Restart the world and
- /// return the starting state
- ///
- virtual time_step_t reset()=0;
-
- ///
- /// \brief Build the world
- ///
- virtual void build(bool reset)=0;
-
- ///
- /// \brief n_copies Returns the number of copies of the environment
- ///
- virtual uint_t n_copies()const = 0;
-
- ///
- /// \brief name
- ///
- std::string name()const{return name_;}
-
- ///
- /// \brief is_built
- /// \return
- ///
- bool is_built()const{return is_built_;}
-
- ///
- /// \brief make_is_built
- ///
- void make_is_built(){is_built_ = true;}
-
-
-protected:
-
- ///
- /// \brief WorldBase
- /// \param name
- ///
- WorldBase(const std::string& name);
-
- ///
- /// \brief name_
- ///
- std::string name_;
-
- ///
- ///
- ///
- bool is_built_{false};
-
-};
-
-template
-WorldBase::WorldBase(const std::string& name)
- :
- name_(name),
- is_built_(false)
-{}
-
-
-}
-
-}
-
-}
-
-#endif // WORLD_BASE_H