diff --git a/.ci/test-r-package-windows.ps1 b/.ci/test-r-package-windows.ps1 index a3f524b60be7..857b2789cbbc 100644 --- a/.ci/test-r-package-windows.ps1 +++ b/.ci/test-r-package-windows.ps1 @@ -177,7 +177,7 @@ Write-Output "Done installing CMake" Write-Output "Installing dependencies" $packages = -join @( - "c('data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'testthat'), ", + "c('data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'testthat'), ", "dependencies = c('Imports', 'Depends', 'LinkingTo')" ) $params = -join @( diff --git a/.ci/test-r-package.sh b/.ci/test-r-package.sh index 2e414ec0d282..55d37e6dff03 100755 --- a/.ci/test-r-package.sh +++ b/.ci/test-r-package.sh @@ -114,7 +114,7 @@ Rscript --vanilla -e "install.packages('https://cran.r-project.org/src/contrib/A # Manually install Depends and Imports libraries + 'knitr', 'markdown', 'RhpcBLASctl', 'testthat' # to avoid a CI-time dependency on devtools (for devtools::install_deps()) -packages="c('data.table', 'jsonlite', 'knitr', 'markdown', 'R6', 'RhpcBLASctl', 'testthat')" +packages="c('data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'R6', 'RhpcBLASctl', 'testthat')" compile_from_source="both" if [[ $OS_NAME == "macos" ]]; then packages+=", type = 'binary'" diff --git a/.github/workflows/r_package.yml b/.github/workflows/r_package.yml index 66e05a18ba1f..c8506a414215 100644 --- a/.github/workflows/r_package.yml +++ b/.github/workflows/r_package.yml @@ -230,7 +230,7 @@ jobs: - name: Install packages shell: bash run: | - RDscript${{ matrix.r_customization }} -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" + RDscript${{ matrix.r_customization }} -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" sh build-cran-package.sh --r-executable=RD${{ matrix.r_customization }} RD${{ matrix.r_customization }} CMD INSTALL lightgbm_*.tar.gz || exit 1 - name: Run tests with sanitizers @@ -295,7 +295,7 @@ jobs: - name: Install packages and run tests shell: bash run: | - Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" + Rscript -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" sh build-cran-package.sh # 'rchk' isn't run through 'R CMD check', use the approach documented at diff --git a/.github/workflows/static_analysis.yml b/.github/workflows/static_analysis.yml index 34e573e0eea6..872ef9dbac14 100644 --- a/.github/workflows/static_analysis.yml +++ b/.github/workflows/static_analysis.yml @@ -64,7 +64,7 @@ jobs: - name: Install packages shell: bash run: | - Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'roxygen2', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" + Rscript -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'roxygen2', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" sh build-cran-package.sh || exit 1 R CMD INSTALL --with-keep.source lightgbm_*.tar.gz || exit 1 - name: Test documentation diff --git a/.vsts-ci.yml b/.vsts-ci.yml index 40424840c82d..6f99e37189cf 100644 --- a/.vsts-ci.yml +++ b/.vsts-ci.yml @@ -392,7 +392,7 @@ jobs: R_LIB_PATH=~/Rlib export R_LIBS=${R_LIB_PATH} mkdir -p ${R_LIB_PATH} - RDscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), lib = '${R_LIB_PATH}', dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" || exit 1 + RDscript -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), lib = '${R_LIB_PATH}', dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" || exit 1 sh build-cran-package.sh --r-executable=RD || exit 1 mv lightgbm_${LGB_VER}.tar.gz $(Build.ArtifactStagingDirectory)/lightgbm-${LGB_VER}-r-cran.tar.gz displayName: 'Build CRAN R-package' diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index c9344ceebab7..4f3730b25593 100755 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -45,6 +45,7 @@ NeedsCompilation: yes Biarch: true VignetteBuilder: knitr Suggests: + DiagrammeR, knitr, markdown, processx, diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 49ef2b5cb8fc..4f5c308ac3df 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -29,6 +29,7 @@ export(lgb.make_serializable) export(lgb.model.dt.tree) export(lgb.plot.importance) export(lgb.plot.interpretation) +export(lgb.plot.tree) export(lgb.restore_handle) export(lgb.save) export(lgb.slice.Dataset) diff --git a/R-package/R/lgb.plot.tree.R b/R-package/R/lgb.plot.tree.R new file mode 100644 index 000000000000..5df73f2de17e --- /dev/null +++ b/R-package/R/lgb.plot.tree.R @@ -0,0 +1,204 @@ +#' @name lgb.plot.tree +#' @title Plot a single LightGBM tree. +#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. +#' @param model a \code{lgb.Booster} object. +#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation). +#' @param rules a list of rules to replace the split values with feature levels. +#' +#' @return +#' The \code{lgb.plot.tree} function creates a DiagrammeR plot. +#' +#' @details +#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value. +#' +#' @examples +#' \donttest{ +#' # EXAMPLE: use the LightGBM example dataset to build a model with a single tree +#' data(agaricus.train, package = "lightgbm") +#' train <- agaricus.train +#' dtrain <- lgb.Dataset(train$data, label = train$label) +#' data(agaricus.test, package = "lightgbm") +#' test <- agaricus.test +#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label) +#' # define model parameters and build a single tree +#' params <- list( +#' objective = "regression", +#' min_data = 1L, +#' ) +#' valids <- list(test = dtest) +#' model <- lgb.train( +#' params = params, +#' data = dtrain, +#' nrounds = 1L, +#' valids = valids, +#' early_stopping_rounds = 1L +#' ) +#' # plot the tree and compare to the tree table +#' # trees start from 0 in lgb.model.dt.tree +#' tree_table <- lgb.model.dt.tree(model) +#' lgb.plot.tree(model, 0) +#' } +#' +#' @export +lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) { + # check model is lgb.Booster + if (!.is_Booster(x = model)) { + stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster")) + } + # check DiagrammeR is available + if (!requireNamespace("DiagrammeR", quietly = TRUE)) { + stop("lgb.plot.tree: DiagrammeR package is required", + call. = FALSE + ) + } + # tree must be numeric + if (!inherits(tree, "numeric")) { + stop("lgb.plot.tree: Has to be an integer numeric") + } + # tree must be integer + if (tree %% 1 != 0) { + stop("lgb.plot.tree: Has to be an integer numeric") + } + # extract data.table model structure + modelDT <- lgb.model.dt.tree(model) + # check that tree is less than or equal to the maximum tree index in the model + if (tree > max(modelDT$tree_index) || tree < 1) { + warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".") + stop("lgb.plot.tree: Invalid tree number") + } + # filter modelDT to just the rows for the selected tree + modelDT <- modelDT[tree_index == tree, ] + # change the column names to shorter more diagram friendly versions + data.table::setnames(modelDT + , old = c("tree_index", "split_feature", "threshold", "split_gain") + , new = c("Tree", "Feature", "Split", "Gain")) + # assign leaf_value to the Value column in modelDT + modelDT[, Value := leaf_value] + # assign new values if NA + modelDT[is.na(Value), Value := internal_value] + modelDT[is.na(Gain), Gain := leaf_value] + modelDT[is.na(Feature), Feature := "Leaf"] + # assign internal_count to Cover, and if Feature is "Leaf", assign leaf_count to Cover + modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count] + # remove unnecessary columns + modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL] + # assign split_index to Node + modelDT[, Node := split_index] + # find the maximum value of Node, if Node is NA, assign max_node + leaf_index + 1 to Node + max_node <- max(modelDT[["Node"]], na.rm = TRUE) + modelDT[is.na(Node), Node := max_node + leaf_index + 1] + # adding ID column + modelDT[, ID := paste(Tree, Node, sep = "-")] + # remove unnecessary columns + modelDT[, c("depth", "leaf_index") := NULL] + modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent] + modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL] + # assign the IDs of the matching parent nodes to Yes and No + modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]] + modelDT <- modelDT[nrow(modelDT):1, ] + modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]] + # which way do the NA's go (this path will get a thicker arrow) + # for categorical features, NA gets put into the zero group + modelDT[default_left == TRUE, Missing := Yes] + modelDT[default_left == FALSE, Missing := No] + modelDT[.zero_present(Split), Missing := Yes] + # create the label text + modelDT[, label := paste0( + Feature + , "\nCover: " + , Cover + , ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf" + , "" + , round(Gain, 4)) + , "\nValue: " + , round(Value, 4) + )] + # style the nodes - same format as xgboost + modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)] + modelDT[, shape := "rectangle"][Feature == "Leaf", shape := "oval"] + modelDT[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"] + # in order to draw the first tree on top: + modelDT <- modelDT[order(-Tree)] + nodes <- DiagrammeR::create_node_df( + n = nrow(modelDT) + , ID = modelDT$ID + , label = modelDT$label + , fillcolor = modelDT$filledcolor + , shape = modelDT$shape + , data = modelDT$Feature + , fontcolor = "black" + ) + # round the edge labels to 4 s.f. if they are numeric + # as otherwise get too many decimal places and the diagram looks bad + # would rather not use suppressWarnings + numeric_idx <- suppressWarnings(!is.na(as.numeric(modelDT[["Split"]]))) + modelDT[numeric_idx, Split := round(as.numeric(Split), 4)] + # replace indices with feature levels if rules supplied + + if (!is.null(rules)) { + for (f in names(rules)) { + modelDT[Feature == f & decision_type == "==", Split := .levels.to.names(Split, f, rules)] + } + } + # replace long split names with a message + modelDT[nchar(Split) > 500, Split := "Split too long to render"] + # create the edge labels + edges <- DiagrammeR::create_edge_df( + from = match(modelDT[Feature != "Leaf", c(ID)] %>% rep(2), modelDT$ID), + to = match(modelDT[Feature != "Leaf", c(Yes, No)], modelDT$ID), + label = modelDT[Feature != "Leaf", paste(decision_type, Split)] %>% + c(rep("", nrow(modelDT[Feature != "Leaf"]))), + style = modelDT[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>% + c(modelDT[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]), + rel = "leading_to" + ) + # create the graph + graph <- DiagrammeR::create_graph( + nodes_df = nodes + , edges_df = edges + , attr_theme = NULL + ) + graph <- DiagrammeR::add_global_graph_attrs( + graph = graph + , attr_type = "graph" + , attr = c("layout", "rankdir") + , value = c("dot", "LR") + ) + graph <- DiagrammeR::add_global_graph_attrs( + graph = graph + , attr_type = "node" + , attr = c("color", "style", "fontname") + , value = c("DimGray", "filled", "Helvetica") + ) + graph <- DiagrammeR::add_global_graph_attrs( + graph = graph + , attr_type = "edge" + , attr = c("color", "arrowsize", "arrowhead", "fontname") + , value = c("DimGray", "1.5", "vee", "Helvetica") + ) + # render the graph + DiagrammeR::render_graph(graph) + return(invisible(NULL)) +} + +.zero_present <- function(x) { + sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) { + any(el == "0") + }) + return(invisible(NULL)) +} + +.levels.to.names <- function(x, feature_name, rules) { + lvls <- sort(rules[[feature_name]]) + result <- strsplit(x, "||", fixed = TRUE) + result <- lapply(result, as.numeric) + result <- lapply(result, .levels_to_names) + result <- lapply(result, paste, collapse = "\n") + result <- as.character(result) + return(invisible(NULL)) +} + +.levels_to_names <- function(x) { + names(lvls)[as.numeric(x)] + return(invisible(NULL)) +} \ No newline at end of file diff --git a/R-package/README.md b/R-package/README.md index f1821f5cc6be..8900f5c5ccec 100644 --- a/R-package/README.md +++ b/R-package/README.md @@ -428,7 +428,7 @@ docker run \ # install dependencies RDscript${R_CUSTOMIZATION} \ - -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" + -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" # install lightgbm sh build-cran-package.sh --r-executable=RD${R_CUSTOMIZATION} @@ -459,7 +459,7 @@ docker run \ -it \ wch1/r-debug -RDscriptvalgrind -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" +RDscriptvalgrind -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" sh build-cran-package.sh \ --r-executable=RDvalgrind diff --git a/R-package/man/lgb.plot.tree.Rd b/R-package/man/lgb.plot.tree.Rd new file mode 100644 index 000000000000..e48cfe420265 --- /dev/null +++ b/R-package/man/lgb.plot.tree.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/lgb.plot.tree.R +\name{lgb.plot.tree} +\alias{lgb.plot.tree} +\title{Plot a single LightGBM tree using DiagrammeR.} +\usage{ +lgb.plot.tree(model = NULL, tree = NULL, rules = NULL) +} +\arguments{ +\item{model}{a \code{lgb.Booster} object.} + +\item{tree}{an integer specifying the tree to plot.} + +\item{rules}{a list of rules to replace the split values with feature levels.} +} +\value{ +The \code{lgb.plot.tree} function creates a DiagrammeR plot. +} +\description{ +The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. +} +\details{ +The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value. The nodes are styled with a rectangle shape and filled with a beige colour. Leaf nodes are styled with an oval shape and filled with a khaki colour. The graph is rendered using the dot layout with a left-to-right rank direction. The nodes are coloured dim gray with a filled style and a Helvetica font. The edges are coloured dim gray with a solid style, a 1.5 arrow size, a vee arrowhead and a Helvetica font. +} +\examples{ +\donttest{ +# EXAMPLE: use the LightGBM example dataset to build a model with a single tree +data(agaricus.train, package = "lightgbm") +train <- agaricus.train +dtrain <- lgb.Dataset(train$data, label = train$label) +data(agaricus.test, package = "lightgbm") +test <- agaricus.test +dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label) +# define model parameters and build a single tree +params <- list( + objective = "regression", + metric = "l2", + min_data = 1L, + learning_rate = 1.0 +) +valids <- list(test = dtest) +model <- lgb.train( + params = params, + data = dtrain, + nrounds = 1L, + valids = valids, + early_stopping_rounds = 1L +) +# plot the tree and compare to the tree table +# trees start from 0 in lgb.model.dt.tree +tree_table <- lgb.model.dt.tree(model) +lgb.plot.tree(model, 0) +} + +} diff --git a/R-package/pkgdown/_pkgdown.yml b/R-package/pkgdown/_pkgdown.yml index c2d6718a2926..e2a6d7e6c7ac 100644 --- a/R-package/pkgdown/_pkgdown.yml +++ b/R-package/pkgdown/_pkgdown.yml @@ -97,6 +97,7 @@ reference: - '`lgb.interprete`' - '`lgb.plot.importance`' - '`lgb.plot.interpretation`' + - '`lgb.plot.tree`' - '`print.lgb.Booster`' - '`summary.lgb.Booster`' - title: Multithreading Control diff --git a/R-package/tests/testthat/test_lgb.plot.tree.R b/R-package/tests/testthat/test_lgb.plot.tree.R new file mode 100644 index 000000000000..857b61030544 --- /dev/null +++ b/R-package/tests/testthat/test_lgb.plot.tree.R @@ -0,0 +1,95 @@ +NROUNDS <- 10L +MAX_DEPTH <- 3L +N <- nrow(iris) +X <- data.matrix(iris[2L:4L]) +FEAT <- colnames(X) +NCLASS <- nlevels(iris[, 5L]) + +model_reg <- lgb.train( + params = list( + objective = "regression" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + ) + , data = lgb.Dataset(X, label = iris[, 1L]) + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +model_binary <- lgb.train( + params = list( + objective = "binary" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + ) + , data = lgb.Dataset(X, label = iris[, 5L] == "setosa") + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +model_multiclass <- lgb.train( + params = list( + objective = "multiclass" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + , num_classes = NCLASS + ) + , data = lgb.Dataset(X, label = as.integer(iris[, 5L]) - 1L) + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +model_rank <- lgb.train( + params = list( + objective = "lambdarank" + , num_threads = .LGB_MAX_THREADS + , max.depth = MAX_DEPTH + , lambdarank_truncation_level = 3L + ) + , data = lgb.Dataset( + X + , label = as.integer(iris[, 1L] > 5.8) + , group = rep(10L, times = 15L) + ) + , verbose = .LGB_VERBOSITY + , nrounds = NROUNDS +) + +models <- list( + reg = model_reg + , bin = model_binary + , multi = model_multiclass + , rank = model_rank +) + +for (model_name in names(models)){ + model <- models[[model_name]] + modelDT <- lgb.model.dt.tree(model) + + test_that("lgb.plot.tree fails when a non existing tree is selected", { + expect_error({ + lgb.plot.tree(model, 0) + }, regexp = paste0("lgb.plot.tree: Invalid tree number")) + }) + test_that("lgb.plot.tree fails when a non existing tree is selected", { + expect_error({ + lgb.plot.tree(model, 999) + }, regexp = paste0("lgb.plot.tree: Invalid tree number")) + }) + test_that("lgb.plot.tree fails when a non numeric tree is selected", { + expect_error({ + lgb.plot.tree(model, "a") + }, regexp = "lgb.plot.tree: Has to be an integer numeric") + }) + test_that("lgb.plot.tree fails when a non integer tree is selected", { + expect_error({ + lgb.plot.tree(model, 1.5) + }, regexp = "lgb.plot.tree: Has to be an integer numeric") + }) + test_that("lgb.plot.tree fails when a non lgb.Booster model is passed", { + expect_error({ + lgb.plot.tree(1, 0) + }, regexp = paste0("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))) + }) +} +