Skip to content

Commit

Permalink
adding save=TRUE argument where necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Apr 16, 2024
1 parent 006b7df commit 4a4b472
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 21 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ abind <- function(x, y) {
.Call(`_BayesMallows_abind`, x, y)
}

all_topological_sorts <- function(prefs, n_items, maxit = 1000L) {
.Call(`_BayesMallows_all_topological_sorts`, prefs, n_items, maxit)
all_topological_sorts <- function(prefs, n_items, maxit, save) {
.Call(`_BayesMallows_all_topological_sorts`, prefs, n_items, maxit, save)
}

#' Asymptotic Approximation of Partition Function
Expand Down
2 changes: 1 addition & 1 deletion R/generate_initial_ranking.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ generate_initial_ranking.BayesMallowsIntransitive <- function(
}

create_ranks <- function(mat, n_items, max_topological_sorts) {
ret <- all_topological_sorts(mat, n_items, max_topological_sorts)
ret <- all_topological_sorts(mat, n_items, max_topological_sorts, TRUE)
u <- sample(min(max_topological_sorts, nrow(ret)), 1)
ret <- ret[u, ]
all_items <- seq(from = 1, to = n_items, by = 1)
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ BEGIN_RCPP
END_RCPP
}
// all_topological_sorts
arma::imat all_topological_sorts(arma::imat prefs, int n_items, int maxit);
RcppExport SEXP _BayesMallows_all_topological_sorts(SEXP prefsSEXP, SEXP n_itemsSEXP, SEXP maxitSEXP) {
arma::imat all_topological_sorts(arma::imat prefs, int n_items, int maxit, bool save);
RcppExport SEXP _BayesMallows_all_topological_sorts(SEXP prefsSEXP, SEXP n_itemsSEXP, SEXP maxitSEXP, SEXP saveSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::imat >::type prefs(prefsSEXP);
Rcpp::traits::input_parameter< int >::type n_items(n_itemsSEXP);
Rcpp::traits::input_parameter< int >::type maxit(maxitSEXP);
rcpp_result_gen = Rcpp::wrap(all_topological_sorts(prefs, n_items, maxit));
Rcpp::traits::input_parameter< bool >::type save(saveSEXP);
rcpp_result_gen = Rcpp::wrap(all_topological_sorts(prefs, n_items, maxit, save));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -164,7 +165,7 @@ END_RCPP

static const R_CallMethodDef CallEntries[] = {
{"_BayesMallows_abind", (DL_FUNC) &_BayesMallows_abind, 2},
{"_BayesMallows_all_topological_sorts", (DL_FUNC) &_BayesMallows_all_topological_sorts, 3},
{"_BayesMallows_all_topological_sorts", (DL_FUNC) &_BayesMallows_all_topological_sorts, 4},
{"_BayesMallows_asymptotic_partition_function", (DL_FUNC) &_BayesMallows_asymptotic_partition_function, 6},
{"_BayesMallows_get_rank_distance", (DL_FUNC) &_BayesMallows_get_rank_distance, 3},
{"_BayesMallows_compute_importance_sampling_estimate", (DL_FUNC) &_BayesMallows_compute_importance_sampling_estimate, 4},
Expand Down
35 changes: 22 additions & 13 deletions src/all_topological_sorts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ class Graph {
vector<int> indegree;
void alltopologicalSortUtil(vector<int>& res, vector<bool>& visited);
int maxit;
int iter{};
bool save;

public:
Graph(int n_items, int maxit);
Graph(int n_items, int maxit, bool save);
void addEdge(int v, int w);
void alltopologicalSort();
vector<vector<int>> m;
int iter{};
};

Graph::Graph(int n_items, int maxit) : n_items { n_items },
maxit { maxit } {
Graph::Graph(int n_items, int maxit, bool save) : n_items { n_items },
maxit { maxit }, save { save } {
adj = new list<int>[n_items];
for (int i = 0; i < n_items; i++) indegree.push_back(0);
}
Expand Down Expand Up @@ -64,7 +65,9 @@ void Graph::alltopologicalSortUtil(vector<int>& res, vector<bool>& visited) {

if (!flag){
iter++;
m.push_back(res);
if(save) {
m.push_back(res);
}
}
}

Expand All @@ -77,19 +80,25 @@ void Graph::alltopologicalSort() {
}

// [[Rcpp::export]]
arma::imat all_topological_sorts(arma::imat prefs, int n_items, int maxit = 1000) {
Graph g(n_items, maxit);
arma::imat all_topological_sorts(arma::imat prefs, int n_items, int maxit,
bool save) {
Graph g(n_items, maxit, save);
for(size_t i{}; i < prefs.n_rows; i++) {
g.addEdge(prefs.at(i, 1) - 1, prefs.at(i, 0) - 1);
}
g.alltopologicalSort();

arma::imat m(g.m.size(), n_items);
for(size_t i{}; i < m.n_rows; i++) {
for(size_t j{}; j < m.n_cols; j++) {
m(i, j) = g.m[i][j] + 1;
if(save) {
arma::imat m(g.m.size(), n_items);
for(size_t i{}; i < m.n_rows; i++) {
for(size_t j{}; j < m.n_cols; j++) {
m(i, j) = g.m[i][j] + 1;
}
}
return m;
} else {
arma::imat m(1, 1);
m(0, 0) = g.iter;
return m;
}

return m;
}
3 changes: 2 additions & 1 deletion src/all_topological_sorts.h
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
#pragma once
arma::imat all_topological_sorts(arma::imat prefs, int n_items, int maxit);
arma::imat all_topological_sorts(arma::imat prefs, int n_items, int maxit = 1000,
bool save = true);

0 comments on commit 4a4b472

Please sign in to comment.