Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding faster sampler for SEIRD connected
Browse files Browse the repository at this point in the history
gvegayon committed Apr 25, 2024

Verified

This commit was signed with the committer’s verified signature. The key has expired.
manveru Michael Fellinger
1 parent 869b75e commit 937c2d3
Showing 3 changed files with 70 additions and 28 deletions.
1 change: 0 additions & 1 deletion include/epiworld/models/seirconnected.hpp
Original file line number Diff line number Diff line change
@@ -6,7 +6,6 @@ class ModelSEIRCONN : public epiworld::Model<TSeq>
{
private:
std::vector< epiworld::Agent<TSeq> * > infected;
double effective_contact_rate;
void update_infected();

public:
96 changes: 70 additions & 26 deletions include/epiworld/models/seirdconnected.hpp
Original file line number Diff line number Diff line change
@@ -4,6 +4,10 @@
template<typename TSeq = EPI_DEFAULT_TSEQ>
class ModelSEIRDCONN : public epiworld::Model<TSeq>
{
private:
std::vector< epiworld::Agent<TSeq> * > infected;
void update_infected();

public:

static const int SUSCEPTIBLE = 0;
@@ -12,7 +16,6 @@ class ModelSEIRDCONN : public epiworld::Model<TSeq>
static const int REMOVED = 3;
static const int DECEASED = 4;


ModelSEIRDCONN() {};

ModelSEIRDCONN(
@@ -58,8 +61,36 @@ class ModelSEIRDCONN : public epiworld::Model<TSeq>
std::vector< int > queue_ = {}
);

size_t get_n_infected() const
{
return infected.size();
}

};

template<typename TSeq>
inline void ModelSEIRDCONN<TSeq>::update_infected()
{
infected.clear();
infected.reserve(this->size());

for (auto & p : this->get_agents())
{
if (p.get_state() == ModelSEIRDCONN<TSeq>::INFECTED)
{
infected.push_back(&p);
}
}

Model<TSeq>::set_rand_binom(
this->get_n_infected(),
static_cast<double>(Model<TSeq>::par("Contact rate"))/
static_cast<double>(Model<TSeq>::size())
);

return;
}

template<typename TSeq>
inline ModelSEIRDCONN<TSeq> & ModelSEIRDCONN<TSeq>::run(
epiworld_fast_uint ndays,
@@ -139,13 +170,19 @@ inline ModelSEIRDCONN<TSeq>::ModelSEIRDCONN(
if (ndraw == 0)
return;

ModelSEIRDCONN<TSeq> * model = dynamic_cast<ModelSEIRDCONN<TSeq> *>(
m
);

size_t ninfected = model->get_n_infected();

// Drawing from the set
int nviruses_tmp = 0;
for (int i = 0; i < ndraw; ++i)
{
// Now selecting who is transmitting the disease
int which = static_cast<int>(
std::floor(m->size() * m->runif())
std::floor(ninfected * m->runif())
);

/* There is a bug in which runif() returns 1.0. It is rare, but
@@ -155,36 +192,31 @@ inline ModelSEIRDCONN<TSeq>::ModelSEIRDCONN(
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176
*
*/
if (which == static_cast<int>(m->size()))
if (which == static_cast<int>(ninfected))
--which;

epiworld::Agent<TSeq> & neighbor = *model->infected[which];

// Can't sample itself
if (which == static_cast<int>(p->get_id()))
if (neighbor.get_id() == p->get_id())
continue;

// If the neighbor is infected, then proceed
auto & neighbor = m->get_agents()[which];
if (neighbor.get_state() == ModelSEIRDCONN<TSeq>::INFECTED)
{

const auto & v = neighbor.get_virus();


#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
throw std::logic_error("Trying to add an extra element to a temporal array outside of the range.");
#endif

/* And it is a function of susceptibility_reduction as well */
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);
// All neighbors in this set are infected by construction
const auto & v = neighbor.get_virus();

#ifdef EPI_DEBUG
if (nviruses_tmp >= static_cast<int>(m->array_virus_tmp.size()))
throw std::logic_error("Trying to add an extra element to a temporal array outside of the range.");
#endif

}
/* And it is a function of susceptibility_reduction as well */
m->array_double_tmp[nviruses_tmp] =
(1.0 - p->get_susceptibility_reduction(v, m)) *
v->get_prob_infecting(m) *
(1.0 - neighbor.get_transmission_reduction(v, m))
;

m->array_virus_tmp[nviruses_tmp++] = &(*v);
}

// No virus to compute
@@ -301,6 +333,18 @@ inline ModelSEIRDCONN<TSeq>::ModelSEIRDCONN(
model.add_state("Deceased");


// Adding update function
epiworld::GlobalFun<TSeq> update = [](epiworld::Model<TSeq> * m) -> void
{
ModelSEIRDCONN<TSeq> * model = dynamic_cast<ModelSEIRDCONN<TSeq> *>(m);
model->update_infected();

return;
};

model.add_globalevent(update, "Update infected individuals");


// Preparing the virus -------------------------------------------
epiworld::Virus<TSeq> virus(vname);
virus.set_state(
1 change: 0 additions & 1 deletion include/epiworld/models/sirconnected.hpp
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@ class ModelSIRCONN : public epiworld::Model<TSeq>
private:

std::vector< epiworld::Agent<TSeq> * > infected;
double effective_contact_rate;
void update_infected();

public:

0 comments on commit 937c2d3

Please sign in to comment.