diff --git a/examples/rl/rl_example_11/rl_example_11.md b/examples/rl/rl_example_11/rl_example_11.md index 6da5b76..2c480db 100644 --- a/examples/rl/rl_example_11/rl_example_11.md +++ b/examples/rl/rl_example_11/rl_example_11.md @@ -1,123 +1,315 @@ -# Example 11: An A2C Solver for _CartPole_ +# Example 11: An A2C Solver for CartPole -In this example we will develop a policy gradient RL solver in order to -use on the CartPole environment. In particular, we will develop an actor-critic method -knonw as A2C. +Example 13: REINFORCE algorithm on CartPole +introduced the vanilla REINFORCE algorithm in order to solve the pole balancing problem. +The REINFORCE algorithm is simple and easy to use; however, it exhibits a large variance in its gradient estimates. +In this example, we will introduce actor-critic methods +and specifically the A2C algorithm. The main objective of actor-critic methods is to further reduce this high gradient variance. +One step in this direction is to use the so-called reward-to-go term, which for a trajectory of length $T$ is defined at time step $t$ as + +$$G_t = \sum_{k = t}^{T} R(s_{k}, a_{k})$$ For this example to work, you will need to have both PyTorch and ```rlenvs_cpp``` enabled. +The workings of the A2C algorithm are handled by the A2CSolver +class. The class accepts three template parameters: + +- The environment type +- The policy type, i.e. the actor type +- The critic type + +The role of the actor is to select an action for the agent to take. +The role of the critic is to tell us whether that action was good or bad. +We could of course use the raw reward signal for this assessment, but this may not work +for environments where the reward is sparse, or where the reward signal is the same for most actions. + +Given this approach, the critic's value estimate will be a term in the actor's loss function, +whilst the critic itself will learn directly from the provided reward signals. +In A2C the critic's estimate $V(s_t)$ is typically used as a baseline, so the actor is updated with the advantage $A_t = G_t - V(s_t)$ rather than with the raw return. + +In the code below, the critic and actor networks do not share any layers. +This need not be the case however. You can come up with an implementation +where the two networks share most of their layers and differ only +at the output layer, as sketched below.
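+
+For illustration only, here is a minimal sketch of what such a shared-layer network could look like. The class and layer names are hypothetical and the sketch is not part of the example code.
+
+```cpp
+// Sketch (assumption, not part of the example): an actor-critic module that
+// shares its hidden layers and differs only at the output heads.
+#include <torch/torch.h>
+#include <utility>
+
+class SharedActorCriticImpl: public torch::nn::Module
+{
+public:
+
+    SharedActorCriticImpl(int64_t state_size, int64_t action_size)
+    :
+    torch::nn::Module(),
+    linear1_(nullptr),
+    linear2_(nullptr),
+    actor_head_(nullptr),
+    critic_head_(nullptr)
+    {
+        // shared trunk
+        linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128));
+        linear2_ = register_module("linear2_", torch::nn::Linear(128, 256));
+
+        // the two heads are where the actor and the critic differ
+        actor_head_  = register_module("actor_head_", torch::nn::Linear(256, action_size));
+        critic_head_ = register_module("critic_head_", torch::nn::Linear(256, 1));
+    }
+
+    // returns the action probabilities and the state-value estimate
+    std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor state){
+        auto hidden = torch::relu(linear1_(state));
+        hidden = torch::relu(linear2_(hidden));
+        auto probs = torch::softmax(actor_head_(hidden), /*dim=*/-1);
+        auto value = critic_head_(hidden);
+        return {probs, value};
+    }
+
+private:
+
+    torch::nn::Linear linear1_;
+    torch::nn::Linear linear2_;
+    torch::nn::Linear actor_head_;
+    torch::nn::Linear critic_head_;
+};
+
+TORCH_MODULE(SharedActorCritic);
+```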
+ + +### The Actor network class + +``` +// create the Action and the Critic networks +class ActorNetImpl: public torch::nn::Module +{ +public: + + // constructor + ActorNetImpl(uint_t state_size, uint_t action_size); + + torch_tensor_t forward(torch_tensor_t state); + torch_tensor_t log_probabilities(torch_tensor_t actions); + torch_tensor_t sample(); + +private: + + torch::nn::Linear linear1_; + torch::nn::Linear linear2_; + torch::nn::Linear linear3_; + + // the underlying distribution used to sample actions + TorchCategorical distribution_; +}; + +ActorNetImpl::ActorNetImpl(uint_t state_size, uint_t action_size) +: +torch::nn::Module(), +linear1_(nullptr), +linear2_(nullptr), +linear3_(nullptr) +{ + linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128)); + linear2_ = register_module("linea2_", torch::nn::Linear(128, 256)); + linear3_ = register_module("linear3_", torch::nn::Linear(256, action_size)); +} + + +torch_tensor_t +ActorNetImpl::forward(torch_tensor_t state){ + auto output = torch::nn::functional::relu(linear1_(state)); + output = torch::nn::functional::relu(linear2_(output)); + output = linear3_(output); + const torch_tensor_t probs = torch::nn::functional::softmax(output,-1); + distribution_.build_from_probabilities(probs); + return probs; +} + +torch_tensor_t +ActorNetImpl::sample(){ + return distribution_.sample(); +} + +torch_tensor_t +ActorNetImpl::log_probabilities(torch_tensor_t actions){ + return distribution_.log_prob(actions); +} +``` + + ## The driver code The driver code for this tutorial is shown below. ```cpp -/** - * This example illustrates a simple example of Monte Carlo - * iteration using the IterationCounter class - * - * */ +#include "cubeai/base/cubeai_config.h" + +#if defined(USE_PYTORCH) && defined(USE_RLENVS_CPP) #include "cubeai/base/cubeai_types.h" -#include "cubeai/utils/iteration_counter.h" -#include "cubeai/geom_primitives/shapes/circle.h" -#include "cubeai/extern/nlohmann/json/json.hpp" +#include "cubeai/maths/statistics/distributions/torch_categorical.h" +#include "cubeai/rl/trainers/rl_serial_agent_trainer.h" +#include "cubeai/rl/algorithms/actor_critic/a2c.h" +#include "cubeai/maths/optimization/optimizer_type.h" +#include "cubeai/maths/optimization/pytorch_optimizer_factory.h" +#include "rlenvs/envs/gymnasium/classic_control/cart_pole_env.h" -#include #include -#include -#include +#include +#include -namespace intro_example_1 -{ +namespace rl_example_11{ + +const std::string SERVER_URL = "http://0.0.0.0:8001/api"; using cubeai::real_t; using cubeai::uint_t; -using cubeai::utils::IterationCounter; -using cubeai::geom_primitives::Circle; +using cubeai::torch_tensor_t; +using cubeai::maths::stats::TorchCategorical; +using cubeai::rl::algos::ac::A2CConfig; +using cubeai::rl::algos::ac::A2CSolver; +using cubeai::rl::RLSerialAgentTrainer; +using cubeai::rl::RLSerialTrainerConfig; +using rlenvs_cpp::envs::gymnasium::CartPoleActionsEnum; -using json = nlohmann::json; -const std::string CONFIG = "config.json"; +// create the Action and the Critic networks +class ActorNetImpl: public torch::nn::Module +{ +public: -// read the JSON file -json -load_config(const std::string& filename){ + // constructor + ActorNetImpl(uint_t state_size, uint_t action_size); - std::ifstream f(filename); - json data = json::parse(f); - return data; + + torch_tensor_t forward(torch_tensor_t state); + torch_tensor_t log_probabilities(torch_tensor_t actions); + torch_tensor_t sample(); + + +private: + + torch::nn::Linear linear1_; + torch::nn::Linear 
linear2_; + torch::nn::Linear linear3_; + + // the underlying distribution used to sample actions + TorchCategorical distribution_; +}; + +ActorNetImpl::ActorNetImpl(uint_t state_size, uint_t action_size) +: +torch::nn::Module(), +linear1_(nullptr), +linear2_(nullptr), +linear3_(nullptr) +{ + linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128)); + linear2_ = register_module("linea2_", torch::nn::Linear(128, 256)); + linear3_ = register_module("linear3_", torch::nn::Linear(256, action_size)); } +torch_tensor_t +ActorNetImpl::forward(torch_tensor_t state){ + + auto output = torch::nn::functional::relu(linear1_(state)); + output = torch::nn::functional::relu(linear2_(output)); + output = linear3_(output); + const torch_tensor_t probs = torch::nn::functional::softmax(output,-1); + distribution_.build_from_probabilities(probs); + return probs; } -int main() { +torch_tensor_t +ActorNetImpl::sample(){ + return distribution_.sample(); +} - using namespace intro_example_1; +torch_tensor_t +ActorNetImpl::log_probabilities(torch_tensor_t actions){ + return distribution_.log_prob(actions); +} + + +class CriticNetImpl: public torch::nn::Module +{ +public: + + // constructor + CriticNetImpl(uint_t state_size); + + torch_tensor_t forward(torch_tensor_t state); + +private: + + torch::nn::Linear linear1_; + torch::nn::Linear linear2_; + torch::nn::Linear linear3_; + +}; + + +CriticNetImpl::CriticNetImpl(uint_t state_size) +: +torch::nn::Module(), +linear1_(nullptr), +linear2_(nullptr), +linear3_(nullptr) +{ + linear1_ = register_module("linear1_", torch::nn::Linear(state_size, 128)); + linear2_ = register_module("linea2_", torch::nn::Linear(128, 256)); + linear3_ = register_module("linear3_", torch::nn::Linear(256, 1)); +} + +torch_tensor_t +CriticNetImpl::forward(torch_tensor_t state){ + + auto output = torch::nn::functional::relu(linear1_(state)); + output = torch::nn::functional::relu(linear2_(output)); + output = linear3_(output); + return output; +} + +TORCH_MODULE(ActorNet); +TORCH_MODULE(CriticNet); + + +typedef rlenvs_cpp::envs::gymnasium::CartPole env_type; + + +} + + +int main(){ + + using namespace rl_example_11; try{ - BOOST_LOG_TRIVIAL(info)<<"Reading configuration file..."; - - // load the json configuration - auto data = load_config(CONFIG); - - // read properties from the configuration - const auto R = data["R"].template get(); - const auto N_POINTS = data["N_POINTS"].template get(); - const auto SEED = data["SEED"].template get(); - const auto X = data["X"].template get(); - const auto Y = data["Y"].template get(); - - // create a circle - Circle c(R, {X, Y}); - - // simple object to control iterations - IterationCounter counter(N_POINTS); - - // how many points we found in the Circle - auto points_inside_circle = 0; - - // the box has side 2 - const real_t SQUARE_SIDE = R*2.0; - std::uniform_real_distribution dist(0.0,SQUARE_SIDE); - std::mt19937 gen(SEED); - - BOOST_LOG_TRIVIAL(info)<<"Starting computation..."; - while(counter.continue_iterations()){ - auto x = dist(gen); - auto y = dist(gen); - if(c.is_inside(x,y, 1.0e-4)){ - points_inside_circle += 1; - } - } - - BOOST_LOG_TRIVIAL(info)<<"Finished computation..."; - auto area = (static_cast(points_inside_circle) / static_cast(N_POINTS)) * std::pow(SQUARE_SIDE, 2); - BOOST_LOG_TRIVIAL(info)<<"Circle area calculated with:" < options; + + std::cout<<"Creating the environment..."< solver_type; + + solver_type solver(a2c_config, policy, critic, + policy_optimizer, critic_optimizer); + + RLSerialTrainerConfig config; + 
RLSerialAgentTrainer trainer(config, solver); + trainer.train(env); + } catch(std::exception& e){ - BOOST_LOG_TRIVIAL(error)< - return 0; +int main(){ + + std::cout<<"This example requires the flag USE_RLENVS_CPP to be true."<EXample 12: DQN algorithm on Gridworld -and EXample 15: DQN algorithm on Gridworld with experience replay approximate a value function +The DQN algorithm we used in examples Example 12: DQN algorithm on Gridworld +and Example 15: DQN algorithm on Gridworld with experience replay approximate a value function and in particular the Q state-action value function. However, in reinforcement learning we are more interested in policies since a policy dictates how an agent behaves in a given state. diff --git a/include/cubeai/rl/actions/action_space.h b/include/cubeai/rl/actions/action_space.h deleted file mode 100644 index 949b5b9..0000000 --- a/include/cubeai/rl/actions/action_space.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef ACTION_SPACE_H -#define ACTION_SPACE_H - -#include "cubeai/base/cubeai_config.h" -#include "cubeai/base/cubeai_types.h" - -#include -#include - -namespace cubeai { -namespace rl { -namespace actions { - -struct ActionSpace -{ - std::string type; - std::vector shape; - - /// - /// \brief ActionSpace - /// - ActionSpace(const std::string& t, std::vector& sh) - : - type(t), - shape(sh) - {} -}; - -#ifdef USE_PYTORCH -struct TorchActionSpace -{ - std::string type; - std::vector shape; - - /// - /// \brief TorchActionSpace - /// - TorchActionSpace(const std::string& t, std::vector& sh) - : - type(t), - shape(sh) - {} - -}; -#endif - -} - -} - -} - -#endif // ACTION_SPACE_H diff --git a/include/cubeai/rl/worlds/discrete_world.h b/include/cubeai/rl/worlds/discrete_world.h deleted file mode 100644 index 456f9d0..0000000 --- a/include/cubeai/rl/worlds/discrete_world.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef DISCRETE_WORLD_H -#define DISCRETE_WORLD_H - -#include "cubeai/base/cubeai_types.h" -#include "cubeai/rl/worlds/world_base.h" - -#include -#include -#include - -namespace cubeai{ -namespace rl{ -namespace envs { - - -/// -/// \brief The DiscreteWorldBase class -/// -template -class DiscreteWorldBase: public WorldBase -{ -public: - - /// - /// \brief state_t - /// - typedef typename WorldBase::state_type state_type; - - /// - /// \brief action_t - /// - typedef typename WorldBase::action_type action_type; - - /// - /// \brief time_step_t - /// - typedef typename WorldBase::time_step_type time_step_type; - - /// - /// \brief ~DiscreteWorldBase. 
Destructor - /// - virtual ~DiscreteWorldBase() = default; - - /// - /// \brief n_actions - /// - virtual uint_t n_actions()const = 0; - - /// - /// \brief n_states - /// - virtual uint_t n_states()const = 0; - - /// - /// \brief transition_dynamics - /// - virtual std::vector> transition_dynamics(uint_t s, uint_t aidx)const {}; - -protected: - - /// - /// \brief DiscreteWorldBase - /// \param name - /// - DiscreteWorldBase(std::string name); - -}; - -template -DiscreteWorldBase::DiscreteWorldBase(std::string name) - : - WorldBase(name) -{} - -} -} -} - -#endif // DISCRETE_WORLD_H diff --git a/include/cubeai/rl/worlds/world_base.h b/include/cubeai/rl/worlds/world_base.h deleted file mode 100644 index 48b8ee1..0000000 --- a/include/cubeai/rl/worlds/world_base.h +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef WORLD_BASE_H -#define WORLD_BASE_H - -#include "cubeai/base/cubeai_types.h" -#include - -#include -#include -#include -#include - -namespace cubeai { -namespace rl { -namespace envs{ - - -/// -/// \brief Base minimal class for RL environments -/// -template -class WorldBase: private boost::noncopyable -{ - -public: - - /// - /// \brief state_t - /// - typedef StateTp state_type; - - /// - /// \brief action_t - /// - typedef ActionTp action_type; - - /// - /// \brief time_step_t - /// - typedef TimeStepTp time_step_type; - - /// - /// \brief ~WorldBase. Destructor - /// - virtual ~WorldBase() = default; - - /// - /// \brief Transition to a new state by - /// performing the given action. It returns a tuple - /// with the following information - /// arg1: An observation of the environment. - /// arg2: Amount of reward achieved by the previous action. - /// arg3: Flag indicating whether it’s time to reset the environment again. - /// Most (but not all) tasks are divided up into well-defined episodes, - /// and done being True indicates the episode has terminated. - /// (For example, perhaps the pole tipped too far, or you lost your last life.) - /// arg4: The type depends on the subclass overriding this function - /// diagnostic information useful for debugging. It can sometimes be useful for - /// learning (for example, it might contain the raw probabilities behind the environment’s last state change). - /// However, official evaluations of your agent are not allowed to use this for learning. - /// - virtual time_step_type step(const action_type&)=0; - - /// - /// \brief restart. 
Restart the world and - /// return the starting state - /// - virtual time_step_type reset()=0; - - /// - /// \brief Build the world - /// - virtual void build(bool reset)=0; - - /// - /// \brief n_copies Returns the number of copies of the environment - /// - virtual uint_t n_copies()const {return 1;} - - /// - /// \brief name - /// - std::string name()const noexcept{return name_;} - - /// - /// \brief is_built - /// \return - /// - bool is_built()const noexcept{return is_built_;} - - /// - /// \brief make_is_built - /// - void make_is_built()noexcept{is_built_ = true;} - - -protected: - - /// - /// \brief WorldBase - /// \param name - /// - WorldBase(const std::string& name); - - /// - /// \brief name_ - /// - std::string name_; - - /// - /// \brief is_built_ - /// - bool is_built_{false}; - -}; - -template -WorldBase::WorldBase(const std::string& name) - : - name_(name), - is_built_(false) -{} - - -} - -} - -} - -#endif // WORLD_BASE_H diff --git a/src/cubeai/rl/worlds/discrete_world.h b/src/cubeai/rl/worlds/discrete_world.h deleted file mode 100644 index 34fcbb1..0000000 --- a/src/cubeai/rl/worlds/discrete_world.h +++ /dev/null @@ -1,80 +0,0 @@ -#ifndef DISCRETE_WORLD_H -#define DISCRETE_WORLD_H - -#include "cubic_engine/base/cubic_engine_types.h" -#include "cubic_engine/rl/worlds/world_base.h" - -#include -#include -#include - -namespace cengine{ -namespace rl{ -namespace envs { - -class TimeStep; - -/// -/// \brief The DiscreteWorldBase class -/// -template -class DiscreteWorldBase: public WorldBase -{ -public: - - /// - /// \brief state_t - /// - typedef typename WorldBase::state_t state_t; - - /// - /// \brief action_t - /// - typedef typename WorldBase::action_t action_t; - - /// - /// \brief time_step_t - /// - typedef typename WorldBase::time_step_t time_step_t; - - /// - /// - /// - virtual ~DiscreteWorldBase() = default; - - /// - /// \brief n_actions - /// - virtual uint_t n_actions()const = 0; - - /// - /// \brief n_actions - /// - virtual uint_t n_states()const = 0; - - /// - /// \brief transition_dynamics - /// - virtual std::vector> transition_dynamics(uint_t s, uint_t aidx)const = 0; - -protected: - - /// - /// \brief DiscreteWorldBase - /// \param name - /// - DiscreteWorldBase(std::string name); - -}; - -template -DiscreteWorldBase::DiscreteWorldBase(std::string name) - : - WorldBase(name) -{} - -} -} -} - -#endif // DISCRETE_WORLD_H diff --git a/src/cubeai/rl/worlds/world_base.h b/src/cubeai/rl/worlds/world_base.h deleted file mode 100644 index 97cefbf..0000000 --- a/src/cubeai/rl/worlds/world_base.h +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef WORLD_BASE_H -#define WORLD_BASE_H - -#include "cubic_engine/base/cubic_engine_types.h" -#include - -#include -#include -#include -#include - -namespace cengine { -namespace rl { -namespace envs{ - - -/// -/// -/// -template -class WorldBase: private boost::noncopyable -{ - -public: - - /// - /// \brief state_t - /// - typedef StateTp state_t; - - /// - /// \brief action_t - /// - typedef ActionTp action_t; - - /// - /// \brief time_step_t - /// - typedef TimeStepTp time_step_t; - - /// - /// \brief Destructor - /// - virtual ~WorldBase() = default; - - /// - /// \brief Transition to a new state by - /// performing the given action. It returns a tuple - /// with the following information - /// arg1: An observation of the environment. - /// arg2: Amount of reward achieved by the previous action. - /// arg3: Flag indicating whether it’s time to reset the environment again. 
- /// Most (but not all) tasks are divided up into well-defined episodes, - /// and done being True indicates the episode has terminated. - /// (For example, perhaps the pole tipped too far, or you lost your last life.) - /// arg4: The type depends on the subclass overriding this function - /// diagnostic information useful for debugging. It can sometimes be useful for - /// learning (for example, it might contain the raw probabilities behind the environment’s last state change). - /// However, official evaluations of your agent are not allowed to use this for learning. - /// - virtual time_step_t on_episode(const action_t&)=0; - - /// - /// \brief restart. Restart the world and - /// return the starting state - /// - virtual time_step_t reset()=0; - - /// - /// \brief Build the world - /// - virtual void build(bool reset)=0; - - /// - /// \brief n_copies Returns the number of copies of the environment - /// - virtual uint_t n_copies()const = 0; - - /// - /// \brief name - /// - std::string name()const{return name_;} - - /// - /// \brief is_built - /// \return - /// - bool is_built()const{return is_built_;} - - /// - /// \brief make_is_built - /// - void make_is_built(){is_built_ = true;} - - -protected: - - /// - /// \brief WorldBase - /// \param name - /// - WorldBase(const std::string& name); - - /// - /// \brief name_ - /// - std::string name_; - - /// - /// - /// - bool is_built_{false}; - -}; - -template -WorldBase::WorldBase(const std::string& name) - : - name_(name), - is_built_(false) -{} - - -} - -} - -} - -#endif // WORLD_BASE_H