Skip to content

Commit

Permalink
work going forward
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Apr 17, 2024
1 parent 008cf9c commit 8cf93c4
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 57 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ run_smc <- function(data, new_data, model_options, smc_options, compute_options,
.Call(`_BayesMallows_run_smc`, data, new_data, model_options, smc_options, compute_options, priors, initial_values, pfun_values, pfun_estimate)
}

run_sushi_smc2 <- function(rankings) {
.Call(`_BayesMallows_run_sushi_smc2`, rankings)
run_sushi_smc2 <- function(rankings, pfun_values) {
.Call(`_BayesMallows_run_sushi_smc2`, rankings, pfun_values)
}

9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,14 @@ BEGIN_RCPP
END_RCPP
}
// run_sushi_smc2
Rcpp::List run_sushi_smc2(arma::mat rankings);
RcppExport SEXP _BayesMallows_run_sushi_smc2(SEXP rankingsSEXP) {
Rcpp::List run_sushi_smc2(arma::mat rankings, Rcpp::Nullable<arma::mat> pfun_values);
RcppExport SEXP _BayesMallows_run_sushi_smc2(SEXP rankingsSEXP, SEXP pfun_valuesSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat >::type rankings(rankingsSEXP);
rcpp_result_gen = Rcpp::wrap(run_sushi_smc2(rankings));
Rcpp::traits::input_parameter< Rcpp::Nullable<arma::mat> >::type pfun_values(pfun_valuesSEXP);
rcpp_result_gen = Rcpp::wrap(run_sushi_smc2(rankings, pfun_values));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -185,7 +186,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_BayesMallows_rmallows", (DL_FUNC) &_BayesMallows_rmallows, 7},
{"_BayesMallows_run_mcmc", (DL_FUNC) &_BayesMallows_run_mcmc, 8},
{"_BayesMallows_run_smc", (DL_FUNC) &_BayesMallows_run_smc, 9},
{"_BayesMallows_run_sushi_smc2", (DL_FUNC) &_BayesMallows_run_sushi_smc2, 1},
{"_BayesMallows_run_sushi_smc2", (DL_FUNC) &_BayesMallows_run_sushi_smc2, 2},
{NULL, NULL, 0}
};

Expand Down
95 changes: 71 additions & 24 deletions src/sushi_complete.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
#include <RcppArmadillo.h>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include "partition_functions.h"
#include "distances.h"

using namespace arma;

double log_sum_exp(const vec& x) {
double max_val = x.max();
double sum_exp = accu(exp(x - max_val));
return max_val + log(sum_exp);
}

struct PriorParameters {
PriorParameters(double shape, double rate, double psi, int num_clusters,
int num_items) :
Expand Down Expand Up @@ -42,8 +53,8 @@ struct PriorParameters {
struct LatentParticle {
LatentParticle() {}

vec cluster_indicators{};
double incremental_weight{};
int cluster_indicator{};
double log_weight{0};

};

Expand All @@ -52,45 +63,81 @@ struct StaticParticle {
alpha { prior.initialize_alpha() },
tau { prior.initialize_tau() },
rho { prior.initialize_rho() },
latent_particles { std::vector<LatentParticle>(num_latent_particles) }
latent_particles { std::vector<LatentParticle>(num_latent_particles) },
num_clusters { prior.num_clusters },
particle_filter_weights { ones<vec>(num_latent_particles) }
{}

vec alpha;
vec tau;
mat rho;
std::vector<LatentParticle> latent_particles;
};

double particle_filter(const std::vector<LatentParticle>& latent_particles, int t) {
if(t > 0) {

}



return 0;

}
int t{1};
double ess{};
int num_clusters;
vec particle_filter_weights;
double marginal_likelihood{};

std::vector<StaticParticle> iterated_batch_importance_sampling(int T) {
void run_particle_filter(
const vec& latent_ranking,
const std::unique_ptr<PartitionFunction>& pfun,
const std::unique_ptr<Distance>& distfun
) {

for(size_t t{}; t < T; t++) {
if(t > 1) {
ess = pow(norm(particle_filter_weights, 2), -2);
}

for(size_t i{}; i < latent_particles.size(); i++) {
vec log_cluster_probs = ones(num_clusters);
for(size_t cluster{}; cluster < num_clusters; cluster++) {
double distance = distfun->d(latent_ranking, rho.col(cluster));
log_cluster_probs(cluster) =
log(tau(cluster)) - alpha(cluster) * distance - pfun->logz(alpha(cluster));
}
double lse = log_sum_exp(log_cluster_probs);
Rcpp::NumericVector cluster_probs = Rcpp::as<Rcpp::NumericVector>(
Rcpp::wrap(exp(log_cluster_probs - lse)));

Rcpp::IntegerVector tmp = Rcpp::sample(num_clusters, 1, false, cluster_probs);
latent_particles[i].cluster_indicator = tmp(0);
latent_particles[i].log_weight += lse;
particle_filter_weights(i) = exp(latent_particles[i].log_weight);
}
marginal_likelihood = mean(particle_filter_weights);
Rcpp::Rcout << marginal_likelihood << std::endl;
particle_filter_weights = normalise(particle_filter_weights);
t++;
}
};


return std::vector<StaticParticle>(1);
}

// [[Rcpp::export]]
Rcpp::List run_sushi_smc2(arma::mat rankings) {
Rcpp::List run_sushi_smc2(
arma::mat rankings,
Rcpp::Nullable<arma::mat> pfun_values) {
int T = rankings.n_cols;
int num_items = rankings.n_rows;
std::string metric = "footrule";
int S{2};
int R{1};
PriorParameters prior{1, .001, 10, 5, 10};
int R{3};
PriorParameters prior{1, .1, 10, 5, num_items};
auto pfun = choose_partition_function(num_items, metric, pfun_values, R_NilValue);
auto distfun = choose_distance_function(metric);
std::vector<StaticParticle> static_particles(S, StaticParticle(prior, R));
vec static_particle_weights = ones(S);



for(int t{}; t < T; t++) {
vec current_observation = rankings.col(t);
for(size_t i{}; i < static_particles.size(); i++) {
static_particles[i].run_particle_filter(current_observation, pfun, distfun);
static_particle_weights(i) *= static_particles[i].marginal_likelihood;
}
static_particle_weights = normalise(static_particle_weights);
double ess = pow(norm(static_particle_weights), -2);
Rcpp::Rcout << ess << std::endl;
}


return Rcpp::List::create(
Expand Down
2 changes: 1 addition & 1 deletion work-docs/paper-scripts/sushi_complete.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ devtools::load_all()

pfun_values <- extract_pfun_values("footrule", 10, NULL)

run_sushi_smc2(t(sushi_rankings[1:13, ]))
run_sushi_smc2(t(sushi_rankings[1:13, ]), pfun_values)
34 changes: 8 additions & 26 deletions work-docs/test.cpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,16 @@
#include <Rcpp.h>
using namespace Rcpp;
#include <RcppArmadillo.h>

// This is a simple example of exporting a C++ function to R. You can
// source this function into an R session using the Rcpp::sourceCpp
// function (or via the Source button on the editor toolbar). Learn
// more about Rcpp at:
//
// http://www.rcpp.org/
// http://adv-r.had.co.nz/Rcpp.html
// http://gallery.rcpp.org/
//

// [[Rcpp::export]]
NumericVector test(int n, double shape, double rate) {
NumericVector result(n);
for(int i = 0; i < n; i++) {
result(i) = R::rgamma(shape, 1 / rate);
}
return result;
double log_sum_exp(const arma::vec& x) {
double max_val = x.max();
double sum_exp = arma::accu(exp(x - max_val));
return max_val + log(sum_exp);
}


// You can include R code blocks in C++ files processed with sourceCpp
// (useful for testing and development). The R code will be automatically
// run after the compilation.
//

/*** R
mean(rgamma(1e5, shape = 2, rate = 2))
sd(rgamma(1e5, shape = 2, rate = 2))
mean(test(1e5, shape = 2, rate = 2))
sd(test(1e5, shape = 2, rate = 2))
x <- 1:3
sum(exp(x - log_sum_exp(x)))
*/

0 comments on commit 8cf93c4

Please sign in to comment.