From c6a9a8d62f81782e0d2c2ec838d2d41570101eb0 Mon Sep 17 00:00:00 2001 From: "Malte S. Kurz" Date: Tue, 26 Oct 2021 13:57:12 +0200 Subject: [PATCH] started implementing an exception handling for the predictions (finite / not missing), see #136 --- R/double_ml_iivm.R | 5 +++++ R/double_ml_irm.R | 3 +++ R/double_ml_pliv.R | 8 ++++++++ R/double_ml_plr.R | 2 ++ R/helper.R | 15 +++++++++++++++ 5 files changed, 33 insertions(+) diff --git a/R/double_ml_iivm.R b/R/double_ml_iivm.R index 5655b7c8..cfe19e5f 100644 --- a/R/double_ml_iivm.R +++ b/R/double_ml_iivm.R @@ -297,6 +297,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", return_train_preds = FALSE, learner_class = private$learner_class$ml_m, fold_specific_params = private$fold_specific_params) + check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls) g0_hat = dml_cv_predict(self$learner$ml_g, c(self$data$x_cols, self$data$other_treat_cols), @@ -308,6 +309,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", return_train_preds = FALSE, learner_class = private$learner_class$ml_g, fold_specific_params = private$fold_specific_params) + check_finite_predictions(g0_hat, self$learner$ml_g$id, "ml_g0", smpls) g1_hat = dml_cv_predict(self$learner$ml_g, c(self$data$x_cols, self$data$other_treat_cols), @@ -319,6 +321,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", return_train_preds = FALSE, learner_class = private$learner_class$ml_g, fold_specific_params = private$fold_specific_params) + check_finite_predictions(g1_hat, self$learner$ml_g$id, "ml_g1", smpls) if (self$subgroups$always_takers == FALSE) { r0_hat = rep(0, self$data$n_obs) @@ -333,6 +336,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", return_train_preds = FALSE, learner_class = private$learner_class$ml_r, fold_specific_params = private$fold_specific_params) + check_finite_predictions(r0_hat, self$learner$ml_r$id, "ml_r0", smpls) } if (self$subgroups$never_takers == FALSE) { @@ -348,6 +352,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", return_train_preds = FALSE, learner_class = 
private$learner_class$ml_r, fold_specific_params = private$fold_specific_params) + check_finite_predictions(r1_hat, self$learner$ml_r$id, "ml_r1", smpls) } # compute residuals diff --git a/R/double_ml_irm.R b/R/double_ml_irm.R index 9da66a9a..8a2ad057 100644 --- a/R/double_ml_irm.R +++ b/R/double_ml_irm.R @@ -229,6 +229,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM", return_train_preds = FALSE, learner_class = private$learner_class$ml_m, fold_specific_params = private$fold_specific_params) + check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls) g0_hat = dml_cv_predict(self$learner$ml_g, c(self$data$x_cols, self$data$other_treat_cols), @@ -240,6 +241,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM", return_train_preds = FALSE, learner_class = private$learner_class$ml_g, fold_specific_params = private$fold_specific_params) + check_finite_predictions(g0_hat, self$learner$ml_g$id, "ml_g0", smpls) g1_hat = NULL if ((is.character(self$score) && self$score == "ATE") || is.function(self$score)) { @@ -253,6 +255,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM", return_train_preds = FALSE, learner_class = private$learner_class$ml_g, fold_specific_params = private$fold_specific_params) + check_finite_predictions(g1_hat, self$learner$ml_g$id, "ml_g1", smpls) } d = self$data$data_model[[self$data$treat_col]] diff --git a/R/double_ml_pliv.R b/R/double_ml_pliv.R index 0fc09e86..8e951a6e 100644 --- a/R/double_ml_pliv.R +++ b/R/double_ml_pliv.R @@ -270,6 +270,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return_train_preds = FALSE, learner_class = private$learner_class$ml_g, fold_specific_params = private$fold_specific_params) + check_finite_predictions(g_hat, self$learner$ml_g$id, "ml_g", smpls) r_hat = dml_cv_predict(self$learner$ml_r, c(self$data$x_cols, self$data$other_treat_cols), @@ -281,6 +282,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return_train_preds = FALSE, learner_class = private$learner_class$ml_r, fold_specific_params = private$fold_specific_params) + 
check_finite_predictions(r_hat, self$learner$ml_r$id, "ml_r", smpls) if (self$data$n_instr == 1) { m_hat = dml_cv_predict(self$learner$ml_m, @@ -293,6 +295,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return_train_preds = FALSE, learner_class = private$learner_class$ml_m, fold_specific_params = private$fold_specific_params) + check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls) } else { m_hat = do.call( cbind, @@ -310,6 +313,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", learner_class = private$learner_class$ml_m, fold_specific_params = private$fold_specific_params) })) + check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m_", smpls) } d = self$data$data_model[[self$data$treat_col]] @@ -384,6 +388,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return_train_preds = FALSE, learner_class = private$learner_class$ml_g, fold_specific_params = private$fold_specific_params) + check_finite_predictions(g_hat, self$learner$ml_g$id, "ml_g", smpls) m_hat_list = dml_cv_predict(self$learner$ml_m, c( @@ -399,6 +404,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", learner_class = private$learner_class$ml_m, fold_specific_params = private$fold_specific_params) m_hat = m_hat_list$preds + check_finite_predictions(m_hat, self$learner$ml_m$id, "ml_m", smpls) data_aux_list = lapply(m_hat_list$train_preds, function(x) { setnafill(data.table(self$data$data_model, "m_hat_on_train" = x), fill = -9999.99) # mlr3 does not allow NA's (values are not used) @@ -416,6 +422,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return_train_preds = FALSE, learner_class = private$learner_class$ml_r, fold_specific_params = private$fold_specific_params) + check_finite_predictions(m_hat_tilde, self$learner$ml_r$id, "ml_r", smpls) d = self$data$data_model[[self$data$treat_col]] y = self$data$data_model[[self$data$y_col]] @@ -461,6 +468,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return_train_preds = FALSE, learner_class = private$learner_class$ml_r, fold_specific_params = private$fold_specific_params) + 
# Validate the out-of-sample predictions of a nuisance learner.
#
# For every fold in `smpls$test_ids`, the predictions at that fold's test
# indices must be numeric, finite and free of missing values.
#
# Arguments:
#   preds        Predictions indexed by observation. Either a vector or a
#                matrix; a matrix is treated as one row per observation
#                (e.g. per-instrument predictions combined with cbind()).
#   learner      Learner identifier; only used in the error message.
#   learner_name Name of the nuisance part (e.g. "ml_g"); only used in the
#                error message.
#   smpls        List with an element `test_ids`, which is a list holding one
#                vector of test indices per fold.
#
# Returns TRUE when all checked predictions pass; otherwise stops with an
# error naming the learner and the nuisance part.
check_finite_predictions = function(preds, learner, learner_name, smpls) {
  for (test_indices in smpls$test_ids) {
    if (is.matrix(preds)) {
      # Plain `preds[test_indices]` would index the matrix linearly and only
      # look at its first column; select the rows and flatten instead.
      fold_preds = as.vector(preds[test_indices, , drop = FALSE])
    } else {
      fold_preds = preds[test_indices]
    }
    # check_numeric() returns TRUE on success and a character message on
    # failure, so the result must be compared via isTRUE(); negating the
    # message directly raises "invalid argument type".
    is_finite = isTRUE(check_numeric(fold_preds,
      finite = TRUE,
      any.missing = FALSE))
    if (!is_finite) {
      stop(paste0(
        "Predictions from learner ",
        learner, " for ", learner_name,
        " are not finite."))
    }
  }
  return(TRUE)
}