Skip to content

Commit

Permalink
started implementing an exception handling for the predictions (finit…
Browse files Browse the repository at this point in the history
…e / not missing), see #136
  • Loading branch information
MalteKurz committed Oct 26, 2021
1 parent 52301d2 commit c6a9a8d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 0 deletions.
5 changes: 5 additions & 0 deletions R/double_ml_iivm.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_m,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls)

g0_hat = dml_cv_predict(self$learner$ml_g,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -308,6 +309,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g0_hat, self$learner$ml_g$id, "ml_g0", smpls)

g1_hat = dml_cv_predict(self$learner$ml_g,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -319,6 +321,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g1_hat, self$learner$ml_g$id, "ml_g1", smpls)

if (self$subgroups$always_takers == FALSE) {
r0_hat = rep(0, self$data$n_obs)
Expand All @@ -333,6 +336,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(r0_hat, self$learner$ml_r$id, "ml_r0", smpls)
}

if (self$subgroups$never_takers == FALSE) {
Expand All @@ -348,6 +352,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(r1_hat, self$learner$ml_r$id, "ml_r1", smpls)
}

# compute residuals
Expand Down
3 changes: 3 additions & 0 deletions R/double_ml_irm.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_m,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls)

g0_hat = dml_cv_predict(self$learner$ml_g,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -240,6 +241,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g0_hat, self$learner$ml_g$id, "ml_g0", smpls)

g1_hat = NULL
if ((is.character(self$score) && self$score == "ATE") || is.function(self$score)) {
Expand All @@ -253,6 +255,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g1_hat, self$learner$ml_g$id, "ml_g1", smpls)
}

d = self$data$data_model[[self$data$treat_col]]
Expand Down
8 changes: 8 additions & 0 deletions R/double_ml_pliv.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g_hat, self$learner$ml_g$id, "ml_g", smpls)

r_hat = dml_cv_predict(self$learner$ml_r,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -281,6 +282,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(r_hat, self$learner$ml_r$id, "ml_r", smpls)

if (self$data$n_instr == 1) {
m_hat = dml_cv_predict(self$learner$ml_m,
Expand All @@ -293,6 +295,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_m,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls)
} else {
m_hat = do.call(
cbind,
Expand All @@ -310,6 +313,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
learner_class = private$learner_class$ml_m,
fold_specific_params = private$fold_specific_params)
}))
check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m_", smpls)
}

d = self$data$data_model[[self$data$treat_col]]
Expand Down Expand Up @@ -384,6 +388,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g_hat, self$learner$ml_g$id, "ml_g", smpls)

m_hat_list = dml_cv_predict(self$learner$ml_m,
c(
Expand All @@ -399,6 +404,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
learner_class = private$learner_class$ml_m,
fold_specific_params = private$fold_specific_params)
m_hat = m_hat_list$preds
check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls)
data_aux_list = lapply(m_hat_list$train_preds, function(x) {
setnafill(data.table(self$data$data_model, "m_hat_on_train" = x),
fill = -9999.99) # mlr3 does not allow NA's (values are not used)
Expand All @@ -416,6 +422,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(m_hat_tilde, self$learner$ml_r$id, "ml_r", smpls)

d = self$data$data_model[[self$data$treat_col]]
y = self$data$data_model[[self$data$y_col]]
Expand Down Expand Up @@ -461,6 +468,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(r_hat, self$learner$ml_r$id, "ml_r", smpls)

d = self$data$data_model[[self$data$treat_col]]
y = self$data$data_model[[self$data$y_col]]
Expand Down
2 changes: 2 additions & 0 deletions R/double_ml_plr.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(g_hat, self$learner$ml_g$id, "ml_g", smpls)

m_hat = dml_cv_predict(self$learner$ml_m,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -194,6 +195,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR",
return_train_preds = FALSE,
learner_class = private$learner_class$ml_m,
fold_specific_params = private$fold_specific_params)
check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls)

d = self$data$data_model[[self$data$treat_col]]
y = self$data$data_model[[self$data$y_col]]
Expand Down
15 changes: 15 additions & 0 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,18 @@ check_smpl_split = function(smpl, n_obs, check_intersect = FALSE) {
}
return(TRUE)
}

# Check that the out-of-sample predictions of a nuisance learner are usable.
#
# Only entries addressed by the test indices in `smpls$test_ids` are inspected;
# entries outside the test folds are ignored.
#
# Arguments:
#   preds         Numeric vector of predictions, or a numeric matrix with one
#                 row per observation (e.g. several columns bound together via
#                 cbind()).
#   learner       Identifier of the learner; only used in the error message.
#   learner_name  Name of the nuisance part (e.g. "ml_g"); only used in the
#                 error message.
#   smpls         List with an element `test_ids`, itself a list holding one
#                 vector of test indices per fold.
#
# Returns TRUE if all inspected predictions are numeric, finite and
# non-missing; otherwise an error is raised.
check_finite_predictions = function(preds, learner, learner_name, smpls) {
  # All folds are validated with the same rule, so check them in one pass.
  test_indices = unlist(smpls$test_ids, use.names = FALSE)
  if (length(test_indices) == 0) {
    # No test folds -> nothing to validate.
    return(TRUE)
  }

  if (is.matrix(preds)) {
    # Linear indexing into a matrix would only reach the first column, so
    # select complete rows to validate every column.
    test_preds = preds[test_indices, , drop = FALSE]
  } else {
    test_preds = preds[test_indices]
  }

  # checkmate's check_*() functions return TRUE or an error *string*, hence
  # negating their result fails with "invalid argument type" instead of
  # reaching stop(). A plain logical test avoids this; is.finite() is FALSE
  # for NA, NaN, Inf and -Inf.
  all_finite = is.numeric(test_preds) && all(is.finite(test_preds))
  if (!all_finite) {
    stop(paste0(
      "Predictions from learner ",
      learner, " for ", learner_name,
      " are not finite."))
  }
  return(TRUE)
}

0 comments on commit c6a9a8d

Please sign in to comment.