diff --git a/R/sparse_dummy.R b/R/sparse_dummy.R index a69538b..bf248f3 100644 --- a/R/sparse_dummy.R +++ b/R/sparse_dummy.R @@ -49,7 +49,7 @@ sparse_dummy <- function(x, one_hot = TRUE) { n_lvls <- length(lvls) - if (n_lvls == 1) { + if (n_lvls == 1 && one_hot) { res <- list(rep(1L, length(x))) names(res) <- lvls return(res) diff --git a/tests/testthat/test-sparse_dummy.R b/tests/testthat/test-sparse_dummy.R index a977c47..f094acf 100644 --- a/tests/testthat/test-sparse_dummy.R +++ b/tests/testthat/test-sparse_dummy.R @@ -127,6 +127,22 @@ test_that("sparse_dummy(one_hot = FALSE) works with single level", { ) }) +test_that("sparse_dummy(one_hot = FALSE) works with two levels", { + x <- factor(c("a", "b", "a")) + exp <- list( + b = c(0L, 1L, 0L) + ) + + res <- sparse_dummy(x, one_hot = FALSE) + expect_identical( + res, + exp + ) + + expect_true(is.integer(res$b)) + expect_true(is_sparse_vector(res$b)) +}) + test_that("sparse_dummy(one_hot = TRUE) works zero length input", { x <- factor(character()) exp <- structure(list(), names = character(0))