Skip to content

Commit

Permalink
finalizing tailors and workflows with tailors (#973)
Browse files Browse the repository at this point in the history
* changes for #972

* testing needs probably

* requires dev dials

* Apply suggestions from code review

Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com>

* updates from reviewer feedback

* tidy format 🙄

---------

Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com>
  • Loading branch information
topepo and simonpcouch authored Dec 3, 2024
1 parent f6772f4 commit d62199a
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 4 deletions.
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Depends:
R (>= 4.0)
Imports:
cli (>= 3.3.0),
dials (>= 1.3.0),
dials (>= 1.3.0.9000),
doFuture (>= 1.0.0),
dplyr (>= 1.1.0),
foreach,
Expand Down Expand Up @@ -50,6 +50,7 @@ Suggests:
kknn,
knitr,
modeldata,
probably,
scales,
spelling,
splines2,
Expand All @@ -62,7 +63,8 @@ Remotes:
tidymodels/recipes,
tidymodels/rsample,
tidymodels/tailor,
tidymodels/workflows
tidymodels/workflows,
tidymodels/dials
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
tidyverse/tidytemplate
Config/testthat/edition: 3
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ export(extract_workflow)
export(filter_parameters)
export(finalize_model)
export(finalize_recipe)
export(finalize_tailor)
export(finalize_workflow)
export(finalize_workflow_preprocessor)
export(first_eval_time)
Expand Down
37 changes: 37 additions & 0 deletions R/finalize.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,43 @@ finalize_workflow <- function(x, parameters) {
x <- set_workflow_recipe(x, rec)
}

if (has_postprocessor(x)) {
tailor <- extract_postprocessor(x)
tailor <- finalize_tailor(tailor, parameters)
x <- set_workflow_tailor(x, tailor)
}

x
}

#' @export
#' @rdname finalize_model
finalize_tailor <- function(x, parameters) {
if (!inherits(x, "tailor")) {
cli::cli_abort("{.arg x} should be a tailor, not {.obj_type_friendly {x}}.")
}
check_final_param(parameters)
pset <-
hardhat::extract_parameter_set_dials(x) %>%
dplyr::filter(id %in% names(parameters) & source == "tailor")

if (tibble::is_tibble(parameters)) {
parameters <- as.list(parameters)
}

parameters <- parameters[names(parameters) %in% pset$id]
parameters <- parameters[pset$id]

for (i in seq_along(x$adjustments)) {
adj <- x$adjustments[[i]]
adj_comps <- purrr::map_lgl(pset$component, ~ inherits(adj, .x))
if (any(adj_comps)) {
adj_ids <- pset$id[adj_comps]
adj_prms <- parameters[names(parameters) %in% adj_ids]
adj$arguments <- purrr::list_modify(adj$arguments, !!!adj_prms)
x$adjustments[[i]] <- adj
}
}
x
}

Expand Down
9 changes: 7 additions & 2 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
msg <-
c(
msg,
i =
i =
"Consider using {.code skip = TRUE} on any recipe steps that
remove rows to avoid calling them on the assessment set."

)
} else {
msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?")
Expand Down Expand Up @@ -464,3 +464,8 @@ set_workflow_recipe <- function(workflow, recipe) {
workflow$pre$actions$recipe$recipe <- recipe
workflow
}

set_workflow_tailor <- function(workflow, tailor) {
workflow$post$actions$tailor$tailor <- tailor
workflow
}
3 changes: 3 additions & 0 deletions man/finalize_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/_snaps/finalization.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@
! Some model parameters require finalization but there are recipe parameters that require tuning.
i Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument.

# finalize tailors

Code
finalize_tailor(linear_reg(), tibble())
Condition
Error in `finalize_tailor()`:
! `x` should be a tailor, not a <linear_reg> object.

108 changes: 108 additions & 0 deletions tests/testthat/test-finalization.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,111 @@ test_that("finalize recipe step with multiple tune parameters", {
expect_equal(finalize_recipe(rec, best)$steps[[1]]$degree, 1)
expect_equal(finalize_recipe(rec, best)$steps[[1]]$deg_free, 2)
})

# ------------------------------------------------------------------------------
# post-processing

test_that("finalize tailors", {
skip_if_not_installed("probably")
skip_if_not_installed("dials", "1.3.0.9000")
library(tailor)

adjust_rng <-
tailor() %>%
adjust_numeric_range(lower_limit = tune(), upper_limit = tune())

adj_1 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2))
expect_equal(adj_1$adjustments[[1]]$arguments$lower_limit, 2)
expect_equal(adj_1$adjustments[[1]]$arguments$upper_limit, tune())

adj_2 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3))
expect_equal(adj_2$adjustments[[1]]$arguments$lower_limit, 2)
expect_equal(adj_2$adjustments[[1]]$arguments$upper_limit, 3)

adj_3 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3, a = 2))
expect_equal(adj_3$adjustments[[1]]$arguments$lower_limit, 2)
expect_equal(adj_3$adjustments[[1]]$arguments$upper_limit, 3)

adj_4 <- finalize_tailor(adjust_rng, tibble())
expect_equal(adj_4, adjust_rng)

expect_snapshot(
finalize_tailor(linear_reg(), tibble()),
error = TRUE
)
})

test_that("finalize workflows with tailors", {
skip_if_not_installed("probably")
skip_if_not_installed("dials", "1.3.0.9000")
library(tailor)
library(purrr)

adjust_rng <-
tailor() %>%
adjust_numeric_range(lower_limit = tune(), upper_limit = tune())
wflow <- workflow(y ~ ., linear_reg(), adjust_rng)

wflow_1 <- finalize_workflow(wflow, tibble(lower_limit = 2))
expect_equal(
wflow_1 %>%
extract_postprocessor() %>%
pluck("adjustments") %>%
pluck(1) %>%
pluck("arguments") %>%
pluck("lower_limit"),
2
)
expect_equal(
wflow_1 %>%
extract_postprocessor() %>%
pluck("adjustments") %>%
pluck(1) %>%
pluck("arguments") %>%
pluck("upper_limit"),
tune()
)

wflow_2 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3))
expect_equal(
wflow_2 %>%
extract_postprocessor() %>%
pluck("adjustments") %>%
pluck(1) %>%
pluck("arguments") %>%
pluck("lower_limit"),
2
)
expect_equal(
wflow_2 %>%
extract_postprocessor() %>%
pluck("adjustments") %>%
pluck(1) %>%
pluck("arguments") %>%
pluck("upper_limit"),
3
)

wflow_3 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3, a = 2))
expect_equal(
wflow_3 %>%
extract_postprocessor() %>%
pluck("adjustments") %>%
pluck(1) %>%
pluck("arguments") %>%
pluck("lower_limit"),
2
)
expect_equal(
wflow_3 %>%
extract_postprocessor() %>%
pluck("adjustments") %>%
pluck(1) %>%
pluck("arguments") %>%
pluck("upper_limit"),
3
)

wflow_4 <- finalize_workflow(wflow, tibble())
expect_equal(wflow_4, wflow)
})

0 comments on commit d62199a

Please sign in to comment.