Skip to content

Commit

Permalink
[pypi] WIP all class posteriors
Browse files (browse the repository at this point in the history)
  • Loading branch information
fradav committed Mar 20, 2023
1 parent a9a23f2 commit c1daeef
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 123 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

setup(
name="pyabcranger",
version="0.0.67",
version="0.0.69",
author="François-David Collin",
author_email="fradav@gmail.com",
description="ABC random forests for model choice and parameter estimation, python wrapper",
Expand Down
5 changes: 3 additions & 2 deletions src/EstimParam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,
size_t nref, ntree, nthreads, noisecols, seed, minnodesize, ntest;
std::string outfile, parameter_of_interest;
double chosenscen, plsmaxvar;
bool plsok, seeded, forest_save;
bool plsok, seeded, forest_save, weights_keep;

ntree = opts["t"].as<size_t>();
nthreads = opts["j"].as<size_t>();
Expand Down Expand Up @@ -239,7 +239,8 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,
DEFAULT_MAXDEPTH,
ntest,
forest_save);


if (!quiet)
forestreg.verbose_out = &std::cout;
forestreg.run(!quiet, true);
auto preds = forestreg.getPredictions();
Expand Down
4 changes: 0 additions & 4 deletions src/ForestOnline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,10 +724,6 @@ void ForestOnline::loadFromFile(std::string filename) {
// #nocov end



thread_local std::vector<size_t> samples_terminalnodes;


void ForestOnline::setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights) {

// Size should be 1 x num_independent_variables or num_trees x num_independent_variables
Expand Down
2 changes: 1 addition & 1 deletion src/ForestOnline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class ForestOnline {

// OOB (out-of-bag) counts
std::vector<size_t> samples_oob_count;
std::vector<size_t> samples_terminalnodes;

// Show progress every few seconds
#ifdef OLD_WIN_R_BUILD
void showProgress(std::string operation, clock_t start_time, clock_t& lap_time);
Expand Down
78 changes: 24 additions & 54 deletions src/ForestOnlineClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,51 +124,29 @@ namespace ranger
}
}

thread_local std::vector<size_t> classsamples_terminalnodes;

void ForestOnlineClassification::predictInternal(size_t tree_idx)
{
// if (predict_all || prediction_type == TERMINALNODES)
// {
// // Get all tree predictions
// for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx)
// {
// if (prediction_type == TERMINALNODES)
// {
// predictions[0][sample_idx][tree_idx] = getTreePredictionTerminalNodeID(tree_idx, sample_idx);
// }
// else
// {
// predictions[0][sample_idx][tree_idx] = getTreePrediction(tree_idx, sample_idx);
// }
// }
// }
// else
// {
// // Count classes over trees and save class with maximum count
// std::unordered_map<double, size_t> class_count;
// for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx)
// {
// ++class_count[getTreePrediction(tree_idx, sample_idx)];
// }
// predictions[0][0][sample_idx] = mostFrequentValue(class_count, random_number_generator);
// }

for (size_t sample_idx = 0; sample_idx < predict_data->getNumRows(); ++sample_idx)
{
if (predict_all || prediction_type == TERMINALNODES)
{
if (prediction_type == TERMINALNODES)
{
predictions[1][sample_idx][tree_idx] = getTreePredictionTerminalNodeID(tree_idx, sample_idx);
}
else
{
predictions[1][sample_idx][tree_idx] = getTreePrediction(tree_idx, sample_idx);
auto sample_node = getTreePredictionTerminalNodeID(tree_idx, sample_idx);
auto sample_pred = getTreePrediction(tree_idx, sample_idx);
for(auto i = 0; i < num_samples; i++)
if (sample_node == classsamples_terminalnodes[i]) {
mutex_samples[i].lock();
predictions[4][sample_idx][i]++;
mutex_samples[i].unlock();
}
if (predict_all || prediction_type == TERMINALNODES) {
predictions[1][sample_idx][tree_idx] = sample_pred;
if (prediction_type == TERMINALNODES)
predictions[3][sample_idx][tree_idx] = sample_node;
}
else
{
mutex_post.lock();
++class_count[sample_idx][getTreePrediction(tree_idx, sample_idx)];
++class_count[sample_idx][sample_pred];
mutex_post.unlock();
}
}
Expand All @@ -178,35 +156,27 @@ namespace ranger
{
// For each tree, loop over the OOB samples and count classes
double to_add = 0.0;
auto numOOB = trees[tree_idx]->getNumSamplesOob();
if (samples_terminalnodes.empty()) samples_terminalnodes.resize(num_samples);

if (classsamples_terminalnodes.empty()) classsamples_terminalnodes.reserve(num_samples);
for (size_t i = 0; i < num_samples; ++i) classsamples_terminalnodes[i] = 0;
auto numOOB = trees[tree_idx]->getNumSamplesOob();
auto mapOOB = trees[tree_idx]->getOobSampleIDs();
for (size_t sample_oob_idx = 0; sample_oob_idx < numOOB; ++sample_oob_idx)
{
size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_oob_idx];
auto sample_node_oob = static_cast<size_t>(getTreePredictionTerminalNodeID(tree_idx, sample_oob_idx));
auto res = static_cast<size_t>(getTreePrediction(tree_idx, sample_oob_idx));
auto sampleID = mapOOB[sample_oob_idx];
auto sample_node = getTreePredictionTerminalNodeID(tree_idx, sample_oob_idx);
classsamples_terminalnodes[sampleID] = sample_node;
// classsamples_terminalnodes[sampleID+1] = getTreePredictionTerminalNodeID(tree_idx, sample_oob_idx);
auto res = getTreePrediction(tree_idx, sample_oob_idx);
mutex_post.lock();
++class_counts[sampleID][res];
if (!class_counts[sampleID].empty())
to_add += (mostFrequentValue(class_counts[sampleID], random_number_generator) == data->get(sampleID, dependent_varID)) ? 0.0 : 1.0;
for (size_t sample_idx = 0; sample_idx < predict_data->getNumRows(); ++sample_idx) {
auto sample_node = static_cast<size_t>(getTreePredictionTerminalNodeID(tree_idx, sample_idx));
if (sample_node == sample_node_oob) predictions[4][sample_idx][sampleID]++;
}

mutex_post.unlock();
}
predictions[2][0][tree_idx] += to_add / static_cast<double>(numOOB);
// for (size_t sample_idx = 0; sample_idx < num_samples; sample_idx++) {
// if (!class_counts[sample_idx].empty())
// predictions[2][0][tree_idx] += (mostFrequentValue(class_counts[sample_idx], random_number_generator) == data->get(sample_idx,dependent_varID)) ? 0.0 : 1.0;
// }

// else
// for (size_t sample_idx = 0; sample_idx < num_samples; ++sample_idx)
// {
// ++class_counts_internal[sample_idx][getTreePrediction(tree_idx,sample_idx)];
// };
}

void ForestOnlineClassification::computePredictionErrorInternal()
Expand Down
3 changes: 2 additions & 1 deletion src/ForestOnlineRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ void ForestOnlineRegression::initInternal(std::string status_variable_name)
{

keep_inbag = true;

// If mtry not set, use floored square root of number of independent variables
if (mtry == 0)
{
Expand Down Expand Up @@ -53,7 +54,7 @@ void ForestOnlineRegression::growInternal()
samples_oob_count.resize(num_samples, 0);
}


thread_local std::vector<size_t> samples_terminalnodes;

void ForestOnlineRegression::allocatePredictMemory()
{
Expand Down
Loading

0 comments on commit c1daeef

Please sign in to comment.