Skip to content

Commit

Permalink
Attach underlying metric functions to metric_set() output
Browse files Browse the repository at this point in the history
  • Loading branch information
DavisVaughan committed Aug 26, 2019
1 parent 733cc88 commit 880b61b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

## Other improvements

* `metric_set()` output now includes a `metrics` attribute which contains a list of the original metric functions used to generate the metric set.

* Each metric function now has a `direction` attribute attached to it, specifying whether to minimize or maximize the metric.

* Classification metrics that can potentially have a `0` value denominator now throw an informative warning when this case occurs. These include `recall()`, `precision()`, `sens()`, and `spec()` (#98).
Expand Down
20 changes: 13 additions & 7 deletions R/metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ metrics.data.frame <- function(data, truth, estimate, ...,
#' probability columns) as bare column names or `tidyselect` selectors to `...`.
#'
#' @examples
#'
#' library(dplyr)
#'
#' # Multiple regression metrics
Expand All @@ -189,7 +188,8 @@ metrics.data.frame <- function(data, truth, estimate, ...,
#' # If you need to set options for certain metrics,
#' # do so by wrapping the metric and setting the options inside the wrapper,
#' # passing along truth and estimate as quoted arguments.
#' # Then add on the function class of the underlying wrapped function.
#' # Then add on the function class of the underlying wrapped function,
#' # and the direction of optimization.
#' ccc_with_bias <- function(data, truth, estimate, na_rm = TRUE, ...) {
#' ccc(
#' data = data,
Expand All @@ -202,8 +202,10 @@ metrics.data.frame <- function(data, truth, estimate, ...,
#' )
#' }
#'
#' # Add on the underlying function class (here, "numeric_metric")
#' # Add on the underlying function class (here, "numeric_metric"), and the
#' # direction to optimize the metric
#' class(ccc_with_bias) <- class(ccc)
#' attr(ccc_with_bias, "direction") <- attr(ccc, "direction")
#'
#' multi_metric2 <- metric_set(rmse, rsq, ccc_with_bias)
#'
Expand Down Expand Up @@ -349,11 +351,13 @@ make_prob_class_metric_function <- function(fns) {
class(metric_function)
)

attr(metric_function, "metrics") <- fns

metric_function
}

make_numeric_metric_function <- function(fns) {
numeric_metric_function <- function(data, truth, estimate, na_rm = TRUE, ...) {
metric_function <- function(data, truth, estimate, na_rm = TRUE, ...) {

# Construct common argument set for each metric call
# Doing this dynamically inside the generated function means
Expand Down Expand Up @@ -381,12 +385,14 @@ make_numeric_metric_function <- function(fns) {
bind_rows(metric_list)
}

class(numeric_metric_function) <- c(
class(metric_function) <- c(
"numeric_metric_set",
class(numeric_metric_function)
class(metric_function)
)

numeric_metric_function
attr(metric_function, "metrics") <- fns

metric_function
}

validate_not_empty <- function(x) {
Expand Down
8 changes: 5 additions & 3 deletions man/metric_set.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions tests/testthat/test_metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,51 @@ test_that('metric set functions are classed', {
"numeric_metric_set"
)
})

test_that('metric set functions retain class/prob metric functions', {
fns <- attr(metric_set(accuracy, roc_auc), "metrics")

expect_equal(
names(fns),
c("accuracy", "roc_auc")
)

expect_equal(
class(fns[[1]]),
c("class_metric", "function")
)

expect_equal(
class(fns[[2]]),
c("prob_metric", "function")
)

expect_equal(
vapply(fns, function(fn) attr(fn, "direction"), character(1)),
c(accuracy = "maximize", roc_auc = "maximize")
)
})

test_that('metric set functions retain numeric metric functions', {
fns <- attr(metric_set(mae, rmse), "metrics")

expect_equal(
names(fns),
c("mae", "rmse")
)

expect_equal(
class(fns[[1]]),
c("numeric_metric", "function")
)

expect_equal(
class(fns[[2]]),
c("numeric_metric", "function")
)

expect_equal(
vapply(fns, function(fn) attr(fn, "direction"), character(1)),
c(mae = "minimize", rmse = "minimize")
)
})

0 comments on commit 880b61b

Please sign in to comment.