Skip to content

Commit

Permalink
[skip-ci] WIP All post/bug in estimp. weights
Browse files Browse the repository at this point in the history
  • Loading branch information
fradav committed Feb 15, 2023
1 parent 265898a commit a9a23f2
Show file tree
Hide file tree
Showing 11 changed files with 289 additions and 83 deletions.
2 changes: 2 additions & 0 deletions abcranger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ int main(int argc, char* argv[]) {
("plsmaxvar","Percentage of maximum explained Y-variance for retaining pls axis",cxxopts::value<double>()->default_value("0.9"))
("chosenscen","Chosen scenario (mandatory for parameter estimation)", cxxopts::value<size_t>())
("noob","number of oob testing samples (mandatory for parameter estimation)",cxxopts::value<size_t>())
("allpost","calculate all posteriors per model not just the selected one")
("parameter","name of the parameter of interest (mandatory for parameter estimation)",cxxopts::value<std::string>())
("g,groups","Groups of models",cxxopts::value<std::string>())
("save","save forest in ranger format")
("help", "Print help")
;
auto opts = options.parse(argc,argv);
Expand Down
14 changes: 8 additions & 6 deletions src/EstimParam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,
size_t nref, ntree, nthreads, noisecols, seed, minnodesize, ntest;
std::string outfile, parameter_of_interest;
double chosenscen, plsmaxvar;
bool plsok, seeded;
bool plsok, seeded, forest_save;

ntree = opts["t"].as<size_t>();
nthreads = opts["j"].as<size_t>();
Expand All @@ -60,7 +60,8 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,
chosenscen = static_cast<double>(opts["chosenscen"].as<size_t>());
parameter_of_interest = opts["parameter"].as<std::string>();
plsok = opts.count("nolinear") == 0;

forest_save = opts.count("save") != 0;

outfile = (opts.count("output") == 0) ? "estimparam_out" : opts["o"].as<std::string>();

double p_threshold_PLS = 0.99;
Expand Down Expand Up @@ -236,8 +237,9 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,
DEFAULT_NUM_RANDOM_SPLITS, // num_random_splits
false, //order_snps
DEFAULT_MAXDEPTH,
ntest); // max_depth
if (!quiet)
ntest,
forest_save);

forestreg.verbose_out = &std::cout;
forestreg.run(!quiet, true);
auto preds = forestreg.getPredictions();
Expand Down Expand Up @@ -274,7 +276,7 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,

if (weights)
res.oob_weights = MatrixXd(ntest, nref);
res.oob_map = forestreg.oob_subset;
res.oob_map = forestreg.getOobMapSubset();

std::vector<double> expectation(num_samples, 0.0);
std::vector<double> variance(num_samples, 0.0);
Expand Down Expand Up @@ -358,7 +360,7 @@ EstimParamResults EstimParam_fun(Reftable<MatrixType> &myread,
// mutex_quant.unlock();
if (i < ntest)
{
auto p = *(std::next(forestreg.oob_subset.begin(), i));
auto p = *(std::next(res.oob_map.begin(), i));
if (weights)
for (auto i = 0; i < nref; i++)
res.oob_weights(p.second, i) = preds[5][p.second][i];
Expand Down
41 changes: 21 additions & 20 deletions src/ForestOnline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ForestOnline::ForestOnline() :
true), memory_saving_splitting(false), splitrule(DEFAULT_SPLITRULE), predict_all(false), keep_inbag(false), sample_fraction(
{ 1 }), holdout(false), prediction_type(DEFAULT_PREDICTIONTYPE), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS), max_depth(
DEFAULT_MAXDEPTH), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), num_threads(DEFAULT_NUM_THREADS), data { }, overall_prediction_error(
NAN), importance_mode(DEFAULT_IMPORTANCE_MODE), progress(0) {
NAN), importance_mode(DEFAULT_IMPORTANCE_MODE), progress(0), forest_save(false) {
}


Expand All @@ -42,7 +42,7 @@ void ForestOnline::init(std::string dependent_variable_name, MemoryMode memory_m
uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout,
PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth, size_t oob_samples_num) {
PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth, size_t oob_samples_num, bool forest_save) {

// Initialize data with memmode
this->data = std::move(input_data);
Expand Down Expand Up @@ -137,11 +137,11 @@ void ForestOnline::init(std::string dependent_variable_name, MemoryMode memory_m
}

tree_order = std::vector<size_t>(num_trees);

this->forest_save = forest_save;
}

void ForestOnline::run(bool verbose, bool compute_oob_error) {

if (forest_save) saveToFileBegin();
if (prediction_mode) {
if (verbose && verbose_out) {
*verbose_out << "Predicting .." << std::endl;
Expand Down Expand Up @@ -260,36 +260,27 @@ void ForestOnline::writeImportanceFile() {
// *verbose_out << "Saved variable importance to file " << filename << "." << std::endl;
}

void ForestOnline::saveToFile() {
void ForestOnline::saveToFileBegin() {

// Open file for writing
std::string filename = output_prefix + ".ForestOnline";
std::ofstream outfile;
outfile.open(filename, std::ios::binary);
if (!outfile.good()) {
forestoutfile.open(filename, std::ios::binary);
if (!forestoutfile.good()) {
throw std::runtime_error("Could not write to output file: " + filename + ".");
}

// Write dependent_varID
outfile.write((char*) &dependent_varID, sizeof(dependent_varID));
forestoutfile.write((char*) &dependent_varID, sizeof(dependent_varID));

// Write num_trees
outfile.write((char*) &num_trees, sizeof(num_trees));
forestoutfile.write((char*) &num_trees, sizeof(num_trees));

// Write is_ordered_variable
saveVector1D(data->getIsOrderedVariable(), outfile);
saveVector1D(data->getIsOrderedVariable(), forestoutfile);

saveToFileInternal(outfile);
saveToFileInternal(forestoutfile);

// Write tree data for each tree
for (auto& tree : trees) {
tree->appendToFile(outfile);
}

// Close file
outfile.close();
if (verbose_out)
*verbose_out << "Saved ForestOnline to file " << filename << "." << std::endl;
}
// #nocov end

Expand Down Expand Up @@ -617,6 +608,9 @@ void ForestOnline::growTreesInThread(uint thread_idx, std::vector<double>* varia
*verbose_out << "computed_" << !predict_all << " " << progress << "/" << num_trees << std::endl;
#endif
}
if (forest_save) {
trees[i]->appendToFile(forestoutfile);
}
trees[i].reset(nullptr);
mutex.unlock();
// condition_variable.notify_one();
Expand Down Expand Up @@ -729,6 +723,11 @@ void ForestOnline::loadFromFile(std::string filename) {
}
// #nocov end



thread_local std::vector<size_t> samples_terminalnodes;


void ForestOnline::setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights) {

// Size should be 1 x num_independent_variables or num_trees x num_independent_variables
Expand Down Expand Up @@ -907,4 +906,6 @@ void ForestOnline::showProgress(std::string operation, size_t max_progress) {

#endif



} // namespace ranger
24 changes: 22 additions & 2 deletions src/ForestOnline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <random>
#include <ctime>
#include <memory>
#include <fstream>
#include <map>
#ifndef OLD_WIN_R_BUILD
#include <thread>
#include <chrono>
Expand Down Expand Up @@ -34,7 +36,7 @@ class ForestOnline {
uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout,
PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth, size_t oob_weights = 0);
PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth, size_t oob_weights = 0, bool forest_save = false);
virtual void initInternal(std::string status_variable_name) = 0;

// Grow or predict
Expand All @@ -49,7 +51,7 @@ class ForestOnline {
std::vector<std::pair<std::string,double>> getImportance();

// Save ForestOnline to file
void saveToFile();
void saveToFileBegin();
virtual void saveToFileInternal(std::ofstream& outfile) = 0;

std::unique_ptr<Data> releaseData() {
Expand Down Expand Up @@ -122,6 +124,12 @@ class ForestOnline {
// Verbose output stream, cout if verbose==true, logfile if not
std::ostream* verbose_out;

// Get Oob map subset
std::map<size_t,size_t>& getOobMapSubset() {
return oob_subset;
}


protected:
void grow();
virtual void growInternal() = 0;
Expand Down Expand Up @@ -152,6 +160,11 @@ class ForestOnline {
void setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights);
void setAlwaysSplitVariables(const std::vector<std::string>& always_split_variable_names);

std::vector<std::mutex> mutex_samples;

// OOb counts
std::vector<size_t> samples_oob_count;
std::vector<size_t> samples_terminalnodes;
// Show progress every few seconds
#ifdef OLD_WIN_R_BUILD
void showProgress(std::string operation, clock_t start_time, clock_t& lap_time);
Expand Down Expand Up @@ -224,6 +237,13 @@ class ForestOnline {
// Computation progress (finished trees)
size_t progress;
tqdm bar;

// subset of oob for predictionsS
std::map<size_t,size_t> oob_subset;

// Forest saved to file
bool forest_save;
std::ofstream forestoutfile;
#ifdef R_BUILD
size_t aborted_threads;
bool aborted;
Expand Down
68 changes: 61 additions & 7 deletions src/ForestOnlineClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
#include <stdexcept>
#include <cmath>
#include <string>
#include <range/v3/all.hpp>

#include "utility.h"
#include "ForestOnlineClassification.hpp"
#include "TreeClassification.h"
#include "Data.h"

using namespace ranges;
namespace ranger
{

void ForestOnlineClassification::initInternal(std::string status_variable_name)
{

keep_inbag = false;
// If mtry not set, use floored square root of number of independent variables.
if (mtry == 0)
{
Expand Down Expand Up @@ -95,22 +98,30 @@ namespace ranger
{
size_t num_prediction_samples = predict_data->getNumRows();
class_count = std::vector<std::unordered_map<double, size_t>>(num_prediction_samples);
predictions = std::vector<std::vector<std::vector<double>>>(3);
predictions = std::vector<std::vector<std::vector<double>>>(7);
// OOB Votes
oob_votes = std::vector<std::unordered_map<double, size_t>>(num_samples);
// Predictions on the OOB set
predictions[0] = std::vector<std::vector<double>>(1, std::vector<double>(num_samples));
// OOB Error classifications on n-trees (non-cumulative)
predictions[2] = std::vector<std::vector<double>>(1, std::vector<double>(num_trees, 0.0));
// predictions[2] = std::vector<std::vector<double>>(1, std::vector<double>(num_samples));
if (predict_all || prediction_type == TERMINALNODES)
{

mutex_samples = std::vector<std::mutex>(num_samples);
if (predict_all) {
// Predictions on the provided samples by each tree
predictions[1] = std::vector<std::vector<double>>(num_prediction_samples, std::vector<double>(num_trees));
predictions[1] = std::vector<std::vector<double>>(num_prediction_samples, std::vector<double>(num_trees));
predictions[4] = std::vector<std::vector<double>>(num_prediction_samples, std::vector<double>(num_samples,0.0));
}
else
{
// Predictions on the provided samples
predictions[1] = std::vector<std::vector<double>>(1, std::vector<double>(num_prediction_samples));
}
if (prediction_type == TERMINALNODES)
{
predictions[3] = std::vector<std::vector<double>>(num_prediction_samples, std::vector<double>(num_trees));
}
}

void ForestOnlineClassification::predictInternal(size_t tree_idx)
Expand Down Expand Up @@ -140,6 +151,7 @@ namespace ranger
// }
// predictions[0][0][sample_idx] = mostFrequentValue(class_count, random_number_generator);
// }

for (size_t sample_idx = 0; sample_idx < predict_data->getNumRows(); ++sample_idx)
{
if (predict_all || prediction_type == TERMINALNODES)
Expand Down Expand Up @@ -167,14 +179,21 @@ namespace ranger
// For each tree loop over OOB samples and count classes
double to_add = 0.0;
auto numOOB = trees[tree_idx]->getNumSamplesOob();
for (size_t sample_idx = 0; sample_idx < numOOB; ++sample_idx)
if (samples_terminalnodes.empty()) samples_terminalnodes.resize(num_samples);

for (size_t sample_oob_idx = 0; sample_oob_idx < numOOB; ++sample_oob_idx)
{
size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_idx];
auto res = static_cast<size_t>(getTreePrediction(tree_idx, sample_idx));
size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_oob_idx];
auto sample_node_oob = static_cast<size_t>(getTreePredictionTerminalNodeID(tree_idx, sample_oob_idx));
auto res = static_cast<size_t>(getTreePrediction(tree_idx, sample_oob_idx));
mutex_post.lock();
++class_counts[sampleID][res];
if (!class_counts[sampleID].empty())
to_add += (mostFrequentValue(class_counts[sampleID], random_number_generator) == data->get(sampleID, dependent_varID)) ? 0.0 : 1.0;
for (size_t sample_idx = 0; sample_idx < predict_data->getNumRows(); ++sample_idx) {
auto sample_node = static_cast<size_t>(getTreePredictionTerminalNodeID(tree_idx, sample_idx));
if (sample_node == sample_node_oob) predictions[4][sample_idx][sampleID]++;
}
mutex_post.unlock();
}
predictions[2][0][tree_idx] += to_add / static_cast<double>(numOOB);
Expand Down Expand Up @@ -207,6 +226,7 @@ namespace ranger
{
if (!class_counts[i].empty())
{
oob_votes[i] = class_counts[i];
predictions[0][0][i] = mostFrequentValue(class_counts[i], random_number_generator);
}
else
Expand Down Expand Up @@ -240,6 +260,7 @@ namespace ranger
{
predictions[1][0][sample_idx] = mostFrequentValue(class_count[sample_idx], random_number_generator);
}

std::vector<double> sort_oob_trees(num_trees);
for (auto i = 0; i < num_trees; i++)
sort_oob_trees[i] = predictions[2][0][tree_order[i]];
Expand Down Expand Up @@ -273,6 +294,34 @@ namespace ranger
}
}

std::vector<std::pair<double,std::vector<double>>> ForestOnlineClassification::getWeights() {
std::vector<std::pair<double,std::vector<double>>> res;
auto num_targets = predict_data->getNumRows();
for (auto sample_idx = 0; sample_idx < num_samples; sample_idx++) {
std::vector<double> vec_targets(num_targets);
for(auto i = 0; i < num_targets; i++) vec_targets[i] = predictions[4][i][sample_idx];
res.push_back(std::make_pair(data->get(sample_idx,dependent_varID),vec_targets));
}
return res;
}

void ForestOnlineClassification::writeWeightsFile() {
// Open confusion file for writing
std::string filename = output_prefix + ".predweights";
std::ofstream outfile;
outfile.open(filename, std::ios::out);
if (!outfile.good())
{
throw std::runtime_error("Could not write to predweights file: " + filename + ".");
}
outfile << "value,weight" << std::endl;
for(auto& kv: getWeights()) {
outfile << kv.first;
for(auto& p : kv.second) outfile << "," << p;
outfile << std::endl;
}
}

std::vector<std::vector<size_t>> ForestOnlineClassification::getConfusion()
{
std::vector<std::vector<size_t>> res(class_values.size(), std::vector<size_t>(class_values.size()));
Expand Down Expand Up @@ -482,6 +531,11 @@ namespace ranger
this->verbose_out = verbose_output;
}

const std::vector<size_t>& ForestOnlineClassification::getInbagCounts(size_t tree_idx)
{
const auto &tree = dynamic_cast<const TreeClassification &>(*trees[tree_idx]);
return tree.getInbagCounts();
}
// #nocov end

} // namespace ranger
Loading

0 comments on commit a9a23f2

Please sign in to comment.