From 937c2d323de860a7ae14d41c80eff7a3fcbb62c0 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Thu, 25 Apr 2024 09:44:47 -0600 Subject: [PATCH] Adding faster sampler for SEIRD connected --- include/epiworld/models/seirconnected.hpp | 1 - include/epiworld/models/seirdconnected.hpp | 96 ++++++++++++++++------ include/epiworld/models/sirconnected.hpp | 1 - 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/include/epiworld/models/seirconnected.hpp b/include/epiworld/models/seirconnected.hpp index b27c7e0d2..7295dbac2 100644 --- a/include/epiworld/models/seirconnected.hpp +++ b/include/epiworld/models/seirconnected.hpp @@ -6,7 +6,6 @@ class ModelSEIRCONN : public epiworld::Model { private: std::vector< epiworld::Agent * > infected; - double effective_contact_rate; void update_infected(); public: diff --git a/include/epiworld/models/seirdconnected.hpp b/include/epiworld/models/seirdconnected.hpp index 148801b8c..21966a73b 100644 --- a/include/epiworld/models/seirdconnected.hpp +++ b/include/epiworld/models/seirdconnected.hpp @@ -4,6 +4,10 @@ template class ModelSEIRDCONN : public epiworld::Model { +private: + std::vector< epiworld::Agent * > infected; + void update_infected(); + public: static const int SUSCEPTIBLE = 0; @@ -12,7 +16,6 @@ class ModelSEIRDCONN : public epiworld::Model static const int REMOVED = 3; static const int DECEASED = 4; - ModelSEIRDCONN() {}; ModelSEIRDCONN( @@ -58,8 +61,36 @@ class ModelSEIRDCONN : public epiworld::Model std::vector< int > queue_ = {} ); + size_t get_n_infected() const + { + return infected.size(); + } + }; +template +inline void ModelSEIRDCONN::update_infected() +{ + infected.clear(); + infected.reserve(this->size()); + + for (auto & p : this->get_agents()) + { + if (p.get_state() == ModelSEIRDCONN::INFECTED) + { + infected.push_back(&p); + } + } + + Model::set_rand_binom( + this->get_n_infected(), + static_cast(Model::par("Contact rate"))/ + static_cast(Model::size()) + ); + + return; +} + template inline ModelSEIRDCONN & ModelSEIRDCONN::run( epiworld_fast_uint ndays, @@ -139,13 +170,19 @@ inline ModelSEIRDCONN::ModelSEIRDCONN( if (ndraw == 0) return; + ModelSEIRDCONN * model = dynamic_cast *>( + m + ); + + size_t ninfected = model->get_n_infected(); + // Drawing from the set int nviruses_tmp = 0; for (int i = 0; i < ndraw; ++i) { // Now selecting who is transmitting the disease int which = static_cast( - std::floor(m->size() * m->runif()) + std::floor(ninfected * m->runif()) ); /* There is a bug in which runif() returns 1.0. It is rare, but @@ -155,36 +192,31 @@ inline ModelSEIRDCONN::ModelSEIRDCONN( * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 * */ - if (which == static_cast(m->size())) + if (which == static_cast(ninfected)) --which; + epiworld::Agent & neighbor = *model->infected[which]; + // Can't sample itself - if (which == static_cast(p->get_id())) + if (neighbor.get_id() == p->get_id()) continue; - // If the neighbor is infected, then proceed - auto & neighbor = m->get_agents()[which]; - if (neighbor.get_state() == ModelSEIRDCONN::INFECTED) - { - - const auto & v = neighbor.get_virus(); - - - #ifdef EPI_DEBUG - if (nviruses_tmp >= static_cast(m->array_virus_tmp.size())) - throw std::logic_error("Trying to add an extra element to a temporal array outside of the range."); - #endif - - /* And it is a function of susceptibility_reduction as well */ - m->array_double_tmp[nviruses_tmp] = - (1.0 - p->get_susceptibility_reduction(v, m)) * - v->get_prob_infecting(m) * - (1.0 - neighbor.get_transmission_reduction(v, m)) - ; - - m->array_virus_tmp[nviruses_tmp++] = &(*v); + // All neighbors in this set are infected by construction + const auto & v = neighbor.get_virus(); + + #ifdef EPI_DEBUG + if (nviruses_tmp >= static_cast(m->array_virus_tmp.size())) + throw std::logic_error("Trying to add an extra element to a temporal array outside of the range."); + #endif - } + /* And it is a function of susceptibility_reduction as well */ + m->array_double_tmp[nviruses_tmp] = + (1.0 - p->get_susceptibility_reduction(v, m)) * + v->get_prob_infecting(m) * + (1.0 - neighbor.get_transmission_reduction(v, m)) + ; + + m->array_virus_tmp[nviruses_tmp++] = &(*v); } // No virus to compute @@ -301,6 +333,18 @@ inline ModelSEIRDCONN::ModelSEIRDCONN( model.add_state("Deceased"); + // Adding update function + epiworld::GlobalFun update = [](epiworld::Model * m) -> void + { + ModelSEIRDCONN * model = dynamic_cast *>(m); + model->update_infected(); + + return; + }; + + model.add_globalevent(update, "Update infected individuals"); + + // Preparing the virus ------------------------------------------- epiworld::Virus virus(vname); virus.set_state( diff --git a/include/epiworld/models/sirconnected.hpp b/include/epiworld/models/sirconnected.hpp index 1f6d30d57..705a309ff 100644 --- a/include/epiworld/models/sirconnected.hpp +++ b/include/epiworld/models/sirconnected.hpp @@ -8,7 +8,6 @@ class ModelSIRCONN : public epiworld::Model private: std::vector< epiworld::Agent * > infected; - double effective_contact_rate; void update_infected(); public: