Skip to content

Commit

Permalink
use cov matrix with kernel width for sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
krzyzinskim committed Nov 6, 2023
1 parent a6292ac commit 6682a31
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions R/surv_lime.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ surv_lime <- function(explainer, new_observation,
N,
categorical_variables,
sampling_method,
sample_around_instance
sample_around_instance,
kernel_width
)


Expand Down Expand Up @@ -122,7 +123,8 @@ generate_neighbourhood <- function(data_org,
n_samples = 100,
categorical_variables = NULL,
sampling_method = "gaussian",
sample_around_instance = TRUE) {
sample_around_instance = TRUE,
kernel_width = NULL) {

# change categorical_variables to column names
if (is.numeric(categorical_variables)) categorical_variables <- colnames(data_org)[categorical_variables]
Expand All @@ -131,6 +133,13 @@ generate_neighbourhood <- function(data_org,
categorical_variables <- unique(c(additional_categorical_variables, factor_variables))
data_row <- data_row[colnames(data_org)]

if (is.character(kernel_width) && kernel_width == "silverman"){
p <- ncol(data_org)
b <- (4/(n_samples*(p+2)))^(1/(p+4))
} else {
b <- 1
}

feature_frequencies <- list(length(categorical_variables))
scaled_data <- scale(data_org[, !colnames(data_org) %in% categorical_variables])

Expand All @@ -154,9 +163,9 @@ generate_neighbourhood <- function(data_org,

if (sample_around_instance) {
to_add <- data_row[, !colnames(data_row) %in% categorical_variables]
data <- data %*% diag(sc) + to_add[col(data)]
data <- data %*% (b * diag(sc)) + to_add[col(data)]
} else {
data <- data %*% diag(sc) + me[col(data)]
data <- data %*% (b * diag(sc)) + me[col(data)]
}

data <- as.data.frame(data)
Expand Down

0 comments on commit 6682a31

Please sign in to comment.