Skip to content

Commit

Permalink
add ordering support for composed estimands
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Nov 24, 2023
1 parent e5b52da commit 320f428
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 96 deletions.
2 changes: 1 addition & 1 deletion src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export brute_force_ordering, groups_ordering

include("utils.jl")
include("scm.jl")
include("adjustment.jl")
include("estimands.jl")
include("estimators.jl")
include("estimates.jl")
Expand All @@ -56,7 +57,6 @@ include("counterfactual_mean_based/fluctuation.jl")
include("counterfactual_mean_based/estimators.jl")
include("counterfactual_mean_based/clever_covariate.jl")
include("counterfactual_mean_based/gradient.jl")
include("counterfactual_mean_based/adjustment.jl")

include("configuration.jl")

Expand Down
28 changes: 28 additions & 0 deletions src/adjustment.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#####################################################################
### Identification Methods ###
#####################################################################

# Supertype for causal-to-statistical identification strategies (e.g. `BackdoorAdjustment`).
# Concrete subtypes are dispatched on by `identify(method, estimand, scm)`.
abstract type AdjustmentMethod end

"""
    BackdoorAdjustment(;outcome_extra_covariates=())

Backdoor adjustment identification method. `outcome_extra_covariates` are extra
variables adjusted for in the outcome model (deduplicated and sorted on construction).
"""
struct BackdoorAdjustment <: AdjustmentMethod
    outcome_extra_covariates::Tuple{Vararg{Symbol}}
    function BackdoorAdjustment(outcome_extra_covariates)
        # Normalize to a canonical (unique, sorted) tuple of Symbols.
        return new(unique_sorted_tuple(outcome_extra_covariates))
    end
end

BackdoorAdjustment(;outcome_extra_covariates=()) = BackdoorAdjustment(outcome_extra_covariates)

"""
    statistical_type_from_causal_type(T)

Map a causal estimand type (e.g. `TMLE.CausalATE`) to its statistical counterpart
(e.g. `ATE`) by stripping the `TMLE.Causal` prefix from the type name.

Both types are assumed to live in the same module as `T`
(NOTE(review): grounded by the `"TMLE.Causal"` prefix being stripped — confirm if
types ever move modules).
"""
function statistical_type_from_causal_type(T)
    typestring = string(Base.typename(T).wrapper)
    new_typestring = replace(typestring, "TMLE.Causal" => "")
    # Resolve the name in T's defining module instead of `eval`: avoids world-age
    # issues and does not depend on which module the caller evaluates in.
    return getfield(parentmodule(Base.typename(T).wrapper), Symbol(new_typestring))
end

"""
    identify(estimand, scm::SCM; method=BackdoorAdjustment())

Identify `estimand` against the structural causal model `scm` using the
provided adjustment `method` (defaults to `BackdoorAdjustment()`).
"""
function identify(estimand, scm::SCM; method=BackdoorAdjustment())
    return identify(method, estimand, scm)
end


"""
    to_dict(adjustment::BackdoorAdjustment)

Serialize `adjustment` to a `Dict` representation (covariates as a `Vector`).
"""
function to_dict(adjustment::BackdoorAdjustment)
    return Dict(
        :type => "BackdoorAdjustment",
        :outcome_extra_covariates => collect(adjustment.outcome_extra_covariates)
    )
end
46 changes: 0 additions & 46 deletions src/counterfactual_mean_based/adjustment.jl

This file was deleted.

33 changes: 30 additions & 3 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,24 @@ function indicator_fns(Ψ::StatisticalIATE)
return Dict(key_vals...)
end

expected_value::StatisticalCMCompositeEstimand) = ExpectedValue.outcome, Tuple(union.outcome_extra_covariates, keys.treatment_confounders), (Ψ.treatment_confounders)...)))
outcome_mean::StatisticalCMCompositeEstimand) = ExpectedValue.outcome, Tuple(union.outcome_extra_covariates, keys.treatment_confounders), (Ψ.treatment_confounders)...)))

outcome_mean_key::StatisticalCMCompositeEstimand) = variables(outcome_mean(Ψ))

propensity_score::StatisticalCMCompositeEstimand) = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ))

propensity_score_key::StatisticalCMCompositeEstimand) = Tuple(variables(x) for x propensity_score(Ψ))

function get_relevant_factors::StatisticalCMCompositeEstimand)
outcome_model = expected_value(Ψ)
outcome_model = outcome_mean(Ψ)
treatment_factors = propensity_score(Ψ)
return CMRelevantFactors(outcome_model, treatment_factors)
end

n_uniques_nuisance_functions::StatisticalCMCompositeEstimand) = length(propensity_score(Ψ)) + 1

nuisance_functions_iterator::StatisticalCMCompositeEstimand) =
(propensity_score(Ψ)..., expected_value(Ψ))
(propensity_score(Ψ)..., outcome_mean(Ψ))

function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCompositeEstimand
param_string = string(

Check warning on line 142 in src/counterfactual_mean_based/estimands.jl

View check run for this annotation

Codecov / codecov/patch

src/counterfactual_mean_based/estimands.jl#L141-L142

Added lines #L141 - L142 were not covered by tests
Expand Down Expand Up @@ -193,4 +200,24 @@ function to_dict(Ψ::T) where T <: StatisticalCMCompositeEstimand
:treatment_confounders => treatment_confounders_to_dict.treatment_confounders),
:outcome_extra_covariates => collect.outcome_extra_covariates)
)
end

# Statistical estimands are already identified: identification is a no-op for them.
identify(method::AdjustmentMethod, Ψ::StatisticalCMCompositeEstimand, scm::SCM) = Ψ

Check warning on line 205 in src/counterfactual_mean_based/estimands.jl

View check run for this annotation

Codecov / codecov/patch

src/counterfactual_mean_based/estimands.jl#L205

Added line #L205 was not covered by tests

"""
    identify(method::BackdoorAdjustment, causal_estimand, scm::SCM)

Identify a causal estimand via backdoor adjustment: each treatment's confounders
are its parent vertices in the SCM graph. Returns the corresponding statistical
estimand with `method.outcome_extra_covariates` carried over.
"""
function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands
    # Treatment confounders
    treatment_names = keys(causal_estimand.treatment_values)
    treatment_codes = [code_for(scm.graph, treatment) for treatment in treatment_names]
    # Backward adjacency list: the in-neighbours (parents) of each treatment vertex.
    confounders_codes = scm.graph.graph.badjlist[treatment_codes]
    treatment_confounders = NamedTuple{treatment_names}(
        [[scm.graph.vertex_labels[w] for w in confounders_codes[i]]
        for i in eachindex(confounders_codes)]
    )

    return statistical_type_from_causal_type(T)(;
        outcome=causal_estimand.outcome,
        treatment_values = causal_estimand.treatment_values,
        treatment_confounders = treatment_confounders,
        outcome_extra_covariates = method.outcome_extra_covariates
    )
end
84 changes: 42 additions & 42 deletions src/estimand_ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,22 @@ function evaluate_proxy_costs(estimands, η_counts; verbosity=0)
maxmem = 0
compcost = 0
for (estimand_index, Ψ) enumerate(estimands)
η = get_relevant_factors(Ψ)
models =.propensity_score..., η.outcome_mean)
# Append cache
for model in models
if model cache
ηs = collect(nuisance_functions_iterator(Ψ))
for η ηs
if η cache
compcost += 1
push!(cache, model)
push!(cache, η)
end
end
# Update maxmem
maxmem = max(maxmem, length(cache))
verbosity > 0 && @info string("Cache size after estimand $estimand_index: ", length(cache))
# Free cache from models that are not useful anymore
for model in models
η_counts[model] -= 1
if η_counts[model] <= 0
pop!(cache, model)
for η in ηs
η_counts[η] -= 1
if η_counts[η] <= 0
pop!(cache, η)
end
end
end
Expand All @@ -66,8 +65,7 @@ on the cache size. It can be computed in a single pass, i.e. in O(N).
function get_min_maxmem_lowerbound(estimands)
min_maxmem_lowerbound = 0
for Ψ in estimands
η = get_relevant_factors(Ψ)
candidate_min = length((η.propensity_score..., η.outcome_mean))
candidate_min = n_uniques_nuisance_functions(Ψ)
if candidate_min > min_maxmem_lowerbound
min_maxmem_lowerbound = candidate_min
end
Expand All @@ -77,25 +75,9 @@ end

estimands_permutation_generator(estimands) = Combinatorics.permutations(estimands)

function get_propensity_score_groups(estimands_and_nuisances)
ps_groups = [[]]
current_ps = estimands_and_nuisances[1][2]
for (index, Ψ_and_ηs) enumerate(estimands_and_nuisances)
new_ps = Ψ_and_ηs[2]
if new_ps == current_ps
push!(ps_groups[end], index)
else
current_ps = new_ps
push!(ps_groups, [index])
end
end
return ps_groups
end

function propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances)
ps_groups = get_propensity_score_groups(estimands_and_nuisances)
group_permutations = Combinatorics.permutations(collect(1:length(ps_groups)))
return (vcat((estimands[ps_groups[index]] for index in group_perm)...) for group_perm in group_permutations)
"""
    propensity_score_group_based_permutation_generator(groups)

Lazily generate candidate orderings of the estimands: each ordering permutes the
propensity-score groups while keeping every group's members contiguous.
"""
function propensity_score_group_based_permutation_generator(groups)
    group_keys = collect(keys(groups))
    return (
        vcat((groups[key] for key in ordering)...)
        for ordering in Combinatorics.permutations(group_keys)
    )
end

"""
Expand Down Expand Up @@ -133,6 +115,32 @@ function brute_force_ordering(estimands;
return optimal_ordering

Check warning on line 115 in src/estimand_ordering.jl

View check run for this annotation

Codecov / codecov/patch

src/estimand_ordering.jl#L115

Added line #L115 was not covered by tests
end

"""
groupby_by_propensity_score(estimands)
Group parameters per propensity score and order each group by outcome_mean.
"""
function groupby_by_propensity_score(estimands)
groups = Dict()
for Ψ in estimands
propensity_score_key_ = propensity_score_key(Ψ)
outcome_mean_key_ = outcome_mean_key(Ψ)
if haskey(groups, propensity_score_key_)
propensity_score_group = groups[propensity_score_key_]
if haskey(propensity_score_group, outcome_mean_key_)
push!(propensity_score_group[outcome_mean_key_], Ψ)
else
propensity_score_group[outcome_mean_key_] = Any[Ψ]
end
else
groups[propensity_score_key_] = Dict()
groups[propensity_score_key_][outcome_mean_key_] = Any[Ψ]
end

end
return Dict(key => vcat(values(groups[key])...) for key in keys(groups))
end

"""
groups_ordering(estimands)
Expand All @@ -142,26 +150,18 @@ work reasonably well in practice. It could be optimized further by:
- Brute forcing the ordering of these groups to find an optimal one.
"""
function groups_ordering(estimands; brute_force=false, do_shuffle=true, rng=Random.default_rng(), verbosity=0)
# Sort estimands based on propensity_score first and outcome_mean second
estimands_and_nuisances = []
for Ψ in estimands
η = TMLE.get_relevant_factors(Ψ)
push!(estimands_and_nuisances, (Ψ, η.propensity_score, η.outcome_mean))
end
sort!(estimands_and_nuisances, by = x -> (Tuple(TMLE.variables(ps) for ps in x[2]), TMLE.variables(x[3])))

# Sorted estimands only
estimands = [x[1] for x in estimands_and_nuisances]
# Group estimands based on propensity_score first and outcome_mean second
groups = groupby_by_propensity_score(estimands)

# Brute force on the propensity score groups
if brute_force
return brute_force_ordering(estimands;
permutation_generator = propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances),
permutation_generator = propensity_score_group_based_permutation_generator(groups),
do_shuffle=do_shuffle,
rng=rng,
verbosity=verbosity
)
else
return estimands
return vcat(values(groups)...)
end
end
8 changes: 8 additions & 0 deletions src/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,13 @@ to_dict(Ψ::ComposedEstimand) = Dict(
:args => [to_dict(x) for x in Ψ.args]
)

# Nuisance-function keys of a composed estimand: union of its arguments' keys,
# deduplicated while preserving first-seen order.
propensity_score_key(Ψ::ComposedEstimand) = Tuple(unique(Iterators.flatten(propensity_score_key(arg) for arg in Ψ.args)))
outcome_mean_key(Ψ::ComposedEstimand) = Tuple(unique(outcome_mean_key(arg) for arg in Ψ.args))

n_uniques_nuisance_functions(Ψ::ComposedEstimand) = length(propensity_score_key(Ψ)) + length(outcome_mean_key(Ψ))

# Lazily iterate the nuisance functions of every component estimand.
nuisance_functions_iterator(Ψ::ComposedEstimand) =
    Iterators.flatten(nuisance_functions_iterator(arg) for arg in Ψ.args)

# Identification distributes over composition: identify each argument, keep `f`.
identify(method::AdjustmentMethod, Ψ::ComposedEstimand, scm::SCM) =
    ComposedEstimand(Ψ.f, Tuple(identify(method, arg, scm) for arg in Ψ.args))
File renamed without changes.
48 changes: 46 additions & 2 deletions test/estimand_ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ causal_estimands = [
]
statistical_estimands = [identify(x, scm) for x in causal_estimands]

@testset "Test misc" begin
# Test groupby_by_propensity_score
groups = TMLE.groupby_by_propensity_score(statistical_estimands)
@test groups[((:T₂, :W₂),)] == [statistical_estimands[4]]
@test groups[((:T₃, :W₂, :W₃),)] == [statistical_estimands[7]]
@test groups[((:T₁, :W₁, :W₂), (:T₃, :W₂, :W₃))] == [statistical_estimands[end]]
@test Set(groups[((:T₁, :W₁, :W₂),)]) == Set([statistical_estimands[1], statistical_estimands[2], statistical_estimands[6]])
@test Set(groups[((:T₁, :W₁, :W₂), (:T₂, :W₂))]) == Set([statistical_estimands[3], statistical_estimands[5]])
@test size(vcat(values(groups)...)) == size(statistical_estimands)

# Test PS groups permutations
factorial(5)
permutations = TMLE.propensity_score_group_based_permutation_generator(groups)
@test length(permutations) == factorial(length(groups))
for permutation in permutations
@test Set(permutation) == Set(statistical_estimands)
end
end
@testset "Test ordering strategies" begin
# Estimand ID || Required models
# 1 || (T₁, Y₁|T₁)
Expand All @@ -72,22 +90,48 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands]
)
@test TMLE.evaluate_proxy_costs(statistical_estimands, η_counts) == (4, 9)
@test TMLE.get_min_maxmem_lowerbound(statistical_estimands) == 3

# The brute force solution returns the optimal solution
optimal_ordering = @test_logs (:info, "Lower bound reached, stopping.") brute_force_ordering(statistical_estimands, verbosity=1, rng=StableRNG(123))
@test TMLE.evaluate_proxy_costs(optimal_ordering, η_counts) == (3, 9)
# Creating a bad ordering
bad_ordering = statistical_estimands[[1, 7, 3, 6, 2, 5, 8, 4]]
@test TMLE.evaluate_proxy_costs(bad_ordering, η_counts) == (6, 9)
# Without the brute force on groups, the solution is not necessarily optimal
# but still widely improved
# but still improved
ordering_from_groups = groups_ordering(bad_ordering)
@test TMLE.evaluate_proxy_costs(ordering_from_groups, η_counts) == (4, 9)
# Adding a layer of brute forcing results in an optimal ordering

ordering_from_groups_with_brute_force = groups_ordering(bad_ordering, brute_force=true)
@test TMLE.evaluate_proxy_costs(ordering_from_groups_with_brute_force, η_counts) == (3, 9)
end

@testset "Test ordering strategies with Composed Estimands" begin
ATE₁ = ATE(
outcome=:Y₁,
treatment_values=(T₁=(case=1, control=0),)
)
ATE₂ = ATE(
outcome=:Y₁,
treatment_values=(T₁=(case=2, control=1),)
)
diff = ComposedEstimand(-, (ATE₁, ATE₂))
ATE₃ = ATE(
outcome=:Y₁,
treatment_values=(T₂=(case=1, control=0),)
)
estimands = [identify(x, scm) for x in [ATE₁, ATE₃, diff, ATE₂]]
η_counts = TMLE.nuisance_function_counts(estimands)
@test TMLE.get_min_maxmem_lowerbound(estimands) == 2
@test TMLE.evaluate_proxy_costs(estimands, η_counts) == (4, 4)
# Brute Force
optimal_ordering = brute_force_ordering(estimands)
@test TMLE.evaluate_proxy_costs(optimal_ordering, η_counts) == (2, 4)
# PS Group
grouped_ordering = groups_ordering(estimands, brute_force=true)
@test TMLE.evaluate_proxy_costs(grouped_ordering, η_counts) == (2, 4)

end

end

Expand Down
Loading

0 comments on commit 320f428

Please sign in to comment.