Skip to content

Commit

Permalink
fixing #421
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Sep 5, 2024
1 parent 771fa30 commit 770033c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: BayesMallows
Type: Package
Title: Bayesian Preference Learning with the Mallows Rank Model
Version: 2.2.2
Version: 2.2.2.9000
Authors@R: c(person("Oystein", "Sorensen",
email = "oystein.sorensen.1985@gmail.com",
role = c("aut", "cre"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# BayesMallows (development versions)

* A bug in the pseudolikelihood proposal for latent rankings has been corrected.

# BayesMallows 2.2.2

* An error in compute_mallows_loglik when the number of clusters is more than
Expand Down
8 changes: 5 additions & 3 deletions src/rank_proposal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,11 @@ std::pair<arma::vec, double> PartialPseudoProposal::propose_pseudo(
vec sample_probs = normalise(exp(log_numerator), 1);

if(forward) {
ivec ans(sample_probs.size());
R::rmultinom(1, sample_probs.begin(), sample_probs.size(), ans.begin());
proposal(span(item_to_rank)) = available_rankings(find(ans == 1));
Rcpp::IntegerVector rank_index = Rcpp::sample(
sample_probs.size(), 1, true,
Rcpp::as<Rcpp::NumericVector>(Rcpp::wrap(sample_probs)),
false);
proposal(span(item_to_rank)) = available_rankings(rank_index(0));
}

int ranking_chosen = as_scalar(find(proposal(item_to_rank) == available_rankings));
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-smc_update_correctness.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ test_that("update_mallows is correct for new partial rankings", {

expect_equal(
mean(mod_smc_next$alpha$value),
ifelse(aug == "uniform", 10.9348917143992, 10.9279783083453)
ifelse(aug == "uniform", 10.9348917143992, 10.910),
tolerance = .01
)

expect_equal(
Expand Down

0 comments on commit 770033c

Please sign in to comment.