From d62199a91337d003fbc31265aa87047ad3888b7f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 3 Dec 2024 15:08:13 -0500 Subject: [PATCH] finalizing tailors and workflows with tailors (#973) * changes for #972 * testing needs probably * requires dev dials * Apply suggestions from code review Co-authored-by: Simon P. Couch * updates from reviewer feedback * tidy format :roll_eyes: --------- Co-authored-by: Simon P. Couch --- DESCRIPTION | 6 +- NAMESPACE | 1 + R/finalize.R | 37 +++++++++ R/grid_helpers.R | 9 ++- man/finalize_model.Rd | 3 + tests/testthat/_snaps/finalization.md | 8 ++ tests/testthat/test-finalization.R | 108 ++++++++++++++++++++++++++ 7 files changed, 168 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index f9f0d905e..5f9684838 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,7 +18,7 @@ Depends: R (>= 4.0) Imports: cli (>= 3.3.0), - dials (>= 1.3.0), + dials (>= 1.3.0.9000), doFuture (>= 1.0.0), dplyr (>= 1.1.0), foreach, @@ -50,6 +50,7 @@ Suggests: kknn, knitr, modeldata, + probably, scales, spelling, splines2, @@ -62,7 +63,8 @@ Remotes: tidymodels/recipes, tidymodels/rsample, tidymodels/tailor, - tidymodels/workflows + tidymodels/workflows, + tidymodels/dials Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture, tidyverse/tidytemplate Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index 2adf3835f..2bb03b083 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -193,6 +193,7 @@ export(extract_workflow) export(filter_parameters) export(finalize_model) export(finalize_recipe) +export(finalize_tailor) export(finalize_workflow) export(finalize_workflow_preprocessor) export(first_eval_time) diff --git a/R/finalize.R b/R/finalize.R index a0c365b41..79f82cd5b 100644 --- a/R/finalize.R +++ b/R/finalize.R @@ -95,6 +95,43 @@ finalize_workflow <- function(x, parameters) { x <- set_workflow_recipe(x, rec) } + if (has_postprocessor(x)) { + tailor <- extract_postprocessor(x) + tailor <- finalize_tailor(tailor, parameters) + x <- set_workflow_tailor(x, tailor) + } + + x +} + +#' @export +#' @rdname finalize_model +finalize_tailor <- function(x, parameters) { + if (!inherits(x, "tailor")) { + cli::cli_abort("{.arg x} should be a tailor, not {.obj_type_friendly {x}}.") + } + check_final_param(parameters) + pset <- + hardhat::extract_parameter_set_dials(x) %>% + dplyr::filter(id %in% names(parameters) & source == "tailor") + + if (tibble::is_tibble(parameters)) { + parameters <- as.list(parameters) + } + + parameters <- parameters[names(parameters) %in% pset$id] + parameters <- parameters[pset$id] + + for (i in seq_along(x$adjustments)) { + adj <- x$adjustments[[i]] + adj_comps <- purrr::map_lgl(pset$component, ~ inherits(adj, .x)) + if (any(adj_comps)) { + adj_ids <- pset$id[adj_comps] + adj_prms <- parameters[names(parameters) %in% adj_ids] + adj$arguments <- purrr::list_modify(adj$arguments, !!!adj_prms) + x$adjustments[[i]] <- adj + } + } x } diff --git a/R/grid_helpers.R b/R/grid_helpers.R index 91b6131e1..3f6ad66d2 100644 --- a/R/grid_helpers.R +++ b/R/grid_helpers.R @@ -21,10 +21,10 @@ predict_model <- function(new_data, orig_rows, workflow, grid, metrics, msg <- c( msg, - i = + i = "Consider using {.code skip = TRUE} on any recipe steps that remove rows to avoid calling them on the assessment set." - + ) } else { msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?") @@ -464,3 +464,8 @@ set_workflow_recipe <- function(workflow, recipe) { workflow$pre$actions$recipe$recipe <- recipe workflow } + +set_workflow_tailor <- function(workflow, tailor) { + workflow$post$actions$tailor$tailor <- tailor + workflow +} diff --git a/man/finalize_model.Rd b/man/finalize_model.Rd index caf391633..2e5f1b1b5 100644 --- a/man/finalize_model.Rd +++ b/man/finalize_model.Rd @@ -4,6 +4,7 @@ \alias{finalize_model} \alias{finalize_recipe} \alias{finalize_workflow} +\alias{finalize_tailor} \title{Splice final parameters into objects} \usage{ finalize_model(x, parameters) @@ -11,6 +12,8 @@ finalize_model(x, parameters) finalize_recipe(x, parameters) finalize_workflow(x, parameters) + +finalize_tailor(x, parameters) } \arguments{ \item{x}{A recipe, \code{parsnip} model specification, or workflow.} diff --git a/tests/testthat/_snaps/finalization.md b/tests/testthat/_snaps/finalization.md index 38b240ac6..780f74bac 100644 --- a/tests/testthat/_snaps/finalization.md +++ b/tests/testthat/_snaps/finalization.md @@ -7,3 +7,11 @@ ! Some model parameters require finalization but there are recipe parameters that require tuning. i Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument. +# finalize tailors + + Code + finalize_tailor(linear_reg(), tibble()) + Condition + Error in `finalize_tailor()`: + ! `x` should be a tailor, not a object. + diff --git a/tests/testthat/test-finalization.R b/tests/testthat/test-finalization.R index 36f0ca5e2..4df6394ff 100644 --- a/tests/testthat/test-finalization.R +++ b/tests/testthat/test-finalization.R @@ -73,3 +73,111 @@ test_that("finalize recipe step with multiple tune parameters", { expect_equal(finalize_recipe(rec, best)$steps[[1]]$degree, 1) expect_equal(finalize_recipe(rec, best)$steps[[1]]$deg_free, 2) }) + +# ------------------------------------------------------------------------------ +# post-processing + +test_that("finalize tailors", { + skip_if_not_installed("probably") + skip_if_not_installed("dials", "1.3.0.9000") + library(tailor) + + adjust_rng <- + tailor() %>% + adjust_numeric_range(lower_limit = tune(), upper_limit = tune()) + + adj_1 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2)) + expect_equal(adj_1$adjustments[[1]]$arguments$lower_limit, 2) + expect_equal(adj_1$adjustments[[1]]$arguments$upper_limit, tune()) + + adj_2 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3)) + expect_equal(adj_2$adjustments[[1]]$arguments$lower_limit, 2) + expect_equal(adj_2$adjustments[[1]]$arguments$upper_limit, 3) + + adj_3 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3, a = 2)) + expect_equal(adj_3$adjustments[[1]]$arguments$lower_limit, 2) + expect_equal(adj_3$adjustments[[1]]$arguments$upper_limit, 3) + + adj_4 <- finalize_tailor(adjust_rng, tibble()) + expect_equal(adj_4, adjust_rng) + + expect_snapshot( + finalize_tailor(linear_reg(), tibble()), + error = TRUE + ) +}) + +test_that("finalize workflows with tailors", { + skip_if_not_installed("probably") + skip_if_not_installed("dials", "1.3.0.9000") + library(tailor) + library(purrr) + + adjust_rng <- + tailor() %>% + adjust_numeric_range(lower_limit = tune(), upper_limit = tune()) + wflow <- workflow(y ~ ., linear_reg(), adjust_rng) + + wflow_1 <- finalize_workflow(wflow, tibble(lower_limit = 2)) + expect_equal( + wflow_1 %>% + extract_postprocessor() %>% + pluck("adjustments") %>% + pluck(1) %>% + pluck("arguments") %>% + pluck("lower_limit"), + 2 + ) + expect_equal( + wflow_1 %>% + extract_postprocessor() %>% + pluck("adjustments") %>% + pluck(1) %>% + pluck("arguments") %>% + pluck("upper_limit"), + tune() + ) + + wflow_2 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3)) + expect_equal( + wflow_2 %>% + extract_postprocessor() %>% + pluck("adjustments") %>% + pluck(1) %>% + pluck("arguments") %>% + pluck("lower_limit"), + 2 + ) + expect_equal( + wflow_2 %>% + extract_postprocessor() %>% + pluck("adjustments") %>% + pluck(1) %>% + pluck("arguments") %>% + pluck("upper_limit"), + 3 + ) + + wflow_3 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3, a = 2)) + expect_equal( + wflow_3 %>% + extract_postprocessor() %>% + pluck("adjustments") %>% + pluck(1) %>% + pluck("arguments") %>% + pluck("lower_limit"), + 2 + ) + expect_equal( + wflow_3 %>% + extract_postprocessor() %>% + pluck("adjustments") %>% + pluck(1) %>% + pluck("arguments") %>% + pluck("upper_limit"), + 3 + ) + + wflow_4 <- finalize_workflow(wflow, tibble()) + expect_equal(wflow_4, wflow) +})