diff --git a/R/surv_lime.R b/R/surv_lime.R index 6b0d67e..31b2cc3 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -43,7 +43,8 @@ surv_lime <- function(explainer, new_observation, N, categorical_variables, sampling_method, - sample_around_instance + sample_around_instance, + kernel_width ) @@ -122,7 +123,8 @@ generate_neighbourhood <- function(data_org, n_samples = 100, categorical_variables = NULL, sampling_method = "gaussian", - sample_around_instance = TRUE) { + sample_around_instance = TRUE, + kernel_width = NULL) { # change categorical_variables to column names if (is.numeric(categorical_variables)) categorical_variables <- colnames(data_org)[categorical_variables] @@ -131,6 +133,13 @@ generate_neighbourhood <- function(data_org, categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) data_row <- data_row[colnames(data_org)] + if (is.character(kernel_width) && kernel_width == "silverman"){ + p <- ncol(data_org) + b <- (4/(n_samples*(p+2)))^(1/(p+4)) + } else { + b <- 1 + } + feature_frequencies <- list(length(categorical_variables)) scaled_data <- scale(data_org[, !colnames(data_org) %in% categorical_variables]) @@ -154,9 +163,9 @@ generate_neighbourhood <- function(data_org, if (sample_around_instance) { to_add <- data_row[, !colnames(data_row) %in% categorical_variables] - data <- data %*% diag(sc) + to_add[col(data)] + data <- data %*% (b * diag(sc)) + to_add[col(data)] } else { - data <- data %*% diag(sc) + me[col(data)] + data <- data %*% (b * diag(sc)) + me[col(data)] } data <- as.data.frame(data)