Skip to content

Commit

Permalink
add fit failed error for fluctuation and include model in error
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Dec 13, 2023
1 parent 485e04c commit e7f648f
Show file tree
Hide file tree
Showing 3 changed files with 564 additions and 17 deletions.
41 changes: 28 additions & 13 deletions src/counterfactual_mean_based/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#####################################################################
# Exception raised when fitting (or fluctuating) a nuisance-function model fails.
# It wraps the underlying exception together with enough context to identify
# what was being fitted and with which model.
struct FitFailedError <: Exception
# The estimand (factor) whose estimation failed.
estimand::Estimand
# The model that was being fitted when the failure occurred.
model::MLJBase.Model
# Human-readable message; this is the only field printed by `Base.showerror`.
msg::String
# The original exception caught at the failing call site.
origin::Exception
end
Expand All @@ -13,6 +14,11 @@ outcome_mean_fit_error_msg(factor) = string(
string_repr(factor),
".\n Hint: don't forget to use `with_encoder` to encode categorical variables.")

# Build the error message reported when the fluctuation of the outcome mean
# `factor` fails; the factor is rendered with `string_repr`.
function outcome_mean_fluctuation_fit_error_msg(factor)
    factor_repr = string_repr(factor)
    return string("Could not fluctuate the following Outcome mean: ", factor_repr, ".")
end

# Display a `FitFailedError` by writing only its stored `msg` to `io`.
function Base.showerror(io::IO, e::FitFailedError)
    return print(io, e.msg)
end

struct CMRelevantFactorsEstimator <: Estimator
Expand Down Expand Up @@ -67,27 +73,30 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict()
# Fit propensity score
propensity_score_estimate = map(estimand.propensity_score) do factor
try
ConditionalDistributionEstimator(train_validation_indices, acquire_model(models, factor.outcome, dataset, true))(
ps_estimator = ConditionalDistributionEstimator(train_validation_indices, acquire_model(models, factor.outcome, dataset, true))
ps_estimator(
factor,
dataset;
cache=cache,
verbosity=verbosity,
machine_cache=machine_cache
)
catch e
throw(FitFailedError(factor, propensity_score_fit_error_msg(factor), e))
model = acquire_model(models, factor.outcome, dataset, true)
throw(FitFailedError(factor, model, propensity_score_fit_error_msg(factor), e))
end
end
# Fit outcome mean
outcome_mean = estimand.outcome_mean
model = acquire_model(models, outcome_mean.outcome, dataset, false)
outcome_mean_estimator = ConditionalDistributionEstimator(
train_validation_indices,
model
)
outcome_mean_estimate = try
ConditionalDistributionEstimator(
train_validation_indices,
model
)(outcome_mean, dataset; cache=cache, verbosity=verbosity, machine_cache=machine_cache)
outcome_mean_estimator(outcome_mean, dataset; cache=cache, verbosity=verbosity, machine_cache=machine_cache)
catch e
throw(FitFailedError(outcome_mean, outcome_mean_fit_error_msg(outcome_mean), e))
throw(FitFailedError(outcome_mean, model, outcome_mean_fit_error_msg(outcome_mean), e))
end
# Build estimate
estimate = MLCMRelevantFactors(estimand, outcome_mean_estimate, propensity_score_estimate)
Expand Down Expand Up @@ -115,13 +124,19 @@ TargetedCMRelevantFactorsEstimator(Ψ, initial_factors_estimate; tol=nothing, ps

function (estimator::TargetedCMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1, machine_cache=false)
model = estimator.model
outcome_mean = model.initial_factors.outcome_mean.estimand
# Fluctuate outcome model
fluctuated_outcome_mean = MLConditionalDistributionEstimator(model)(
model.initial_factors.outcome_mean.estimand,
dataset,
verbosity=verbosity,
machine_cache=machine_cache
)
fluctuated_estimator = MLConditionalDistributionEstimator(model)
fluctuated_outcome_mean = try
fluctuated_estimator(
outcome_mean,
dataset,
verbosity=verbosity,
machine_cache=machine_cache
)
catch e
throw(FitFailedError(outcome_mean, model, outcome_mean_fluctuation_fit_error_msg(outcome_mean), e))
end
# Do not fluctuate propensity score
fluctuated_propensity_score = model.initial_factors.propensity_score
# Build estimate
Expand Down
39 changes: 35 additions & 4 deletions test/counterfactual_mean_based/estimators_and_estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ using MLJLinearModels
using CategoricalArrays
using LogExpFunctions
using MLJBase
using CSV

DATADIR = joinpath(pkgdir(TMLE), "test", "data")

function make_dataset()
n = 100
Expand Down Expand Up @@ -55,7 +58,6 @@ end
@test TMLE.key(η, new_η̂) == TMLE.key(η, η̂)
full_reuse_log = (:info, TMLE.reuse_string(η))
@test_logs full_reuse_log new_η̂(η, dataset; cache=cache, verbosity=1)

# Changing one model, only the other one is refitted
new_models = (
Y = with_encoder(LinearRegressor()),
Expand Down Expand Up @@ -85,7 +87,7 @@ end
Q = TMLE.ConditionalDistribution(:Y, [:T₁, :W])
G = (TMLE.ConditionalDistribution(:T₁, [:W]),)
η = TMLE.CMRelevantFactors(outcome_mean=Q, propensity_score=G)
# Estimator
# Propensity score model is ill-defined
models = (
Y = with_encoder(LinearRegressor()),
T₁ = LinearRegressor()
Expand All @@ -96,15 +98,44 @@ end
@test true === false
catch e
@test e isa TMLE.FitFailedError
@test e.model isa LinearRegressor
@test e.msg == TMLE.propensity_score_fit_error_msg(G[1])
end

# Outcome Mean model is ill-defined
models = (
Y = LogisticClassifier(),
T₁ = LogisticClassifier(fit_intercept=false)
)
η̂ = TMLE.CMRelevantFactorsEstimator(models=models)
@test_throws TMLE.FitFailedError η̂(η, dataset; verbosity=0)
try
η̂(η, dataset; verbosity=0)
@test true === false
catch e
@test e isa TMLE.FitFailedError
@test e.model isa LogisticClassifier
@test e.msg == TMLE.outcome_mean_fit_error_msg(Q)
end
# Fluctuation Pos Def Exception
pos_def_error_dataset = CSV.read(joinpath(DATADIR, "posdef_error_dataset.csv"), DataFrame)
outcome = Symbol("G25 Other extrapyramidal and movement disorders")
treatment = Symbol("2:14983:G:A")
pos_def_error_dataset[!, treatment] = categorical(pos_def_error_dataset[!, treatment])
pos_def_error_dataset[!, outcome] = categorical(pos_def_error_dataset[!, outcome])
Ψ = ATE(
outcome=outcome,
treatment_values = NamedTuple{(treatment,)}([(case = "GG", control = "AG")]),
treatment_confounders = (:PC1, :PC2, :PC3, :PC4, :PC5, :PC6)
)
Q = TMLE.ConditionalDistribution(outcome, [treatment, :PC1, :PC2, :PC3, :PC4, :PC5, :PC6])
tmle = TMLEE(models=TMLE.default_models(Q_binary=LogisticClassifier(), G = LogisticClassifier()))
try
tmle(Ψ, pos_def_error_dataset)
@test true === false
catch e
@test e isa TMLE.FitFailedError
@test e.model isa TMLE.Fluctuation
@test e.msg == TMLE.outcome_mean_fluctuation_fit_error_msg(Q)
end
end

@testset "Test structs are concrete types" begin
Expand Down
Loading

0 comments on commit e7f648f

Please sign in to comment.