Skip to content

Commit

Permalink
Adding entities back
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 11, 2024
1 parent 22ecc64 commit 9a09073
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 55 deletions.
7 changes: 6 additions & 1 deletion include/epiworld/model-bones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class Model {
std::vector< ToolToAgentFun<TSeq> > tools_dist_funs = {};

std::vector< Entity<TSeq> > entities = {};
std::vector< epiworld_double > prevalence_entity = {};
std::vector< bool > prevalence_entity_as_proportion = {};
std::vector< EntityToAgentFun<TSeq> > entities_dist_funs = {};
std::vector< Entity<TSeq> > entities_backup = {};

std::mt19937 engine;
Expand Down Expand Up @@ -183,7 +186,7 @@ class Model {

void dist_tools();
void dist_virus();
// void dist_entities();
void dist_entities();

std::chrono::time_point<std::chrono::steady_clock> time_start;
std::chrono::time_point<std::chrono::steady_clock> time_end;
Expand Down Expand Up @@ -344,6 +347,8 @@ class Model {
void add_tool_n(Tool<TSeq> & t, epiworld_fast_uint preval);
void add_tool_fun(Tool<TSeq> & t, ToolToAgentFun<TSeq> fun);
void add_entity(Entity<TSeq> e);
void add_entity_n(Entity<TSeq> e, epiworld_fast_uint preval);
void add_entity_fun(Entity<TSeq> e, EntityToAgentFun<TSeq> fun);
void rm_virus(size_t virus_pos);
void rm_tool(size_t tool_pos);
void rm_entity(size_t entity_pos);
Expand Down
138 changes: 84 additions & 54 deletions include/epiworld/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ inline Model<TSeq>::Model(const Model<TSeq> & model) :
prevalence_tool_as_proportion(model.prevalence_tool_as_proportion),
tools_dist_funs(model.tools_dist_funs),
entities(model.entities),
prevalence_entity(model.prevalence_entity),
prevalence_entity_as_proportion(model.prevalence_entity_as_proportion),
entities_dist_funs(model.entities_dist_funs),
entities_backup(model.entities_backup),
// prevalence_entity(model.prevalence_entity),
// prevalence_entity_as_proportion(model.prevalence_entity_as_proportion),
// entities_dist_funs(model.entities_dist_funs),
rewire_fun(model.rewire_fun),
rewire_prop(model.rewire_prop),
parameters(model.parameters),
Expand Down Expand Up @@ -464,10 +464,10 @@ inline Model<TSeq>::Model(Model<TSeq> && model) :
tools_dist_funs(std::move(model.tools_dist_funs)),
// Entities
entities(std::move(model.entities)),
prevalence_entity(std::move(model.prevalence_entity)),
prevalence_entity_as_proportion(std::move(model.prevalence_entity_as_proportion)),
entities_dist_funs(std::move(model.entities_dist_funs)),
entities_backup(std::move(model.entities_backup)),
// prevalence_entity(std::move(model.prevalence_entity)),
// prevalence_entity_as_proportion(std::move(model.prevalence_entity_as_proportion)),
// entities_dist_funs(std::move(model.entities_dist_funs)),
// Pseudo-RNG
engine(std::move(model.engine)),
runifd(std::move(model.runifd)),
Expand Down Expand Up @@ -542,10 +542,10 @@ inline Model<TSeq> & Model<TSeq>::operator=(const Model<TSeq> & m)
tools_dist_funs = m.tools_dist_funs;

entities = m.entities;
prevalence_entity = m.prevalence_entity;
prevalence_entity_as_proportion = m.prevalence_entity_as_proportion;
entities_dist_funs = m.entities_dist_funs;
entities_backup = m.entities_backup;
// prevalence_entity = m.prevalence_entity;
// prevalence_entity_as_proportion = m.prevalence_entity_as_proportion;
// entities_dist_funs = m.entities_dist_funs;

rewire_fun = m.rewire_fun;
rewire_prop = m.rewire_prop;
Expand Down Expand Up @@ -865,62 +865,62 @@ inline void Model<TSeq>::dist_tools()

}

// template<typename TSeq>
// inline void Model<TSeq>::dist_entities()
// {
template<typename TSeq>
inline void Model<TSeq>::dist_entities()
{

// Starting first infection
int n = size();
std::vector< size_t > idx(n);
for (epiworld_fast_uint e = 0; e < entities.size(); ++e)
{

if (entities_dist_funs[e])
{

entities_dist_funs[e](entities[e], this);

} else {

// Picking how many
int nsampled;
if (prevalence_entity_as_proportion[e])
{
nsampled = static_cast<int>(std::floor(prevalence_entity[e] * size()));
}
else
{
nsampled = static_cast<int>(prevalence_entity[e]);
}

// // Starting first infection
// int n = size();
// std::vector< size_t > idx(n);
// for (epiworld_fast_uint e = 0; e < entities.size(); ++e)
// {

// if (entities_dist_funs[e])
// {

// entities_dist_funs[e](entities[e], this);

// } else {

// // Picking how many
// int nsampled;
// if (prevalence_entity_as_proportion[e])
// {
// nsampled = static_cast<int>(std::floor(prevalence_entity[e] * size()));
// }
// else
// {
// nsampled = static_cast<int>(prevalence_entity[e]);
// }

// if (nsampled > static_cast<int>(size()))
// throw std::range_error("There are only " + std::to_string(size()) +
// " individuals in the population. Cannot add the entity to " + std::to_string(nsampled));
if (nsampled > static_cast<int>(size()))
throw std::range_error("There are only " + std::to_string(size()) +
" individuals in the population. Cannot add the entity to " + std::to_string(nsampled));

// Entity<TSeq> & entity = entities[e];
Entity<TSeq> & entity = entities[e];

// int n_left = n;
// std::iota(idx.begin(), idx.end(), 0);
// while (nsampled > 0)
// {
// int loc = static_cast<epiworld_fast_uint>(floor(runif() * n_left--));
int n_left = n;
std::iota(idx.begin(), idx.end(), 0);
while (nsampled > 0)
{
int loc = static_cast<epiworld_fast_uint>(floor(runif() * n_left--));

// population[idx[loc]].add_entity(entity, this, entity.state_init, entity.queue_init);
population[idx[loc]].add_entity(entity, this, entity.state_init, entity.queue_init);

// nsampled--;
nsampled--;

// std::swap(idx[loc], idx[n_left]);
std::swap(idx[loc], idx[n_left]);

// }
}

// }
}

// // Apply the events
// events_run();
// Apply the events
events_run();

// }
}

// }
}

template<typename TSeq>
inline void Model<TSeq>::chrono_start() {
Expand Down Expand Up @@ -1201,6 +1201,32 @@ inline void Model<TSeq>::add_entity(Entity<TSeq> e)

}

template<typename TSeq>
inline void Model<TSeq>::add_entity_n(Entity<TSeq> e, epiworld_fast_uint preval)
{

e.model = this;
e.id = entities.size();
entities.push_back(e);
prevalence_entity.push_back(preval);
prevalence_entity_as_proportion.push_back(false);
entities_dist_funs.push_back(nullptr);

}

template<typename TSeq>
inline void Model<TSeq>::add_entity_fun(Entity<TSeq> e, EntityToAgentFun<TSeq> fun)
{

e.model = this;
e.id = entities.size();
entities.push_back(e);
prevalence_entity.push_back(0.0);
prevalence_entity_as_proportion.push_back(false);
entities_dist_funs.push_back(fun);

}

template<typename TSeq>
inline void Model<TSeq>::rm_virus(size_t virus_pos)
{
Expand Down Expand Up @@ -1578,7 +1604,11 @@ inline void Model<TSeq>::run_multiple(
std::function<void(size_t,Model<TSeq>*)> fun,
bool reset,
bool verbose,
#ifdef _OPENMP
int nthreads
#else
int
#endif
)
{

Expand Down

0 comments on commit 9a09073

Please sign in to comment.