From e5a85c653cbfc502b56feb2545b818056024ba5b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Jul 2024 21:07:23 +0200 Subject: [PATCH 1/3] Switch to DifferentiationInterface --- Project.toml | 6 +++--- src/TMLE.jl | 2 +- src/estimators.jl | 11 ++++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index f2910362..6c8af64d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,12 @@ name = "TMLE" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] -version = "0.16.1" +version = "0.16.2" [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -35,9 +35,9 @@ JSONExt = "JSON" YAMLExt = "YAML" [compat] -AbstractDifferentiation = "0.6.0" CategoricalArrays = "0.10" Combinatorics = "1.0.2" +DifferentiationInterface = "0.5" Distributions = "0.25" GLM = "1.8.2" Graphs = "1.8" diff --git a/src/TMLE.jl b/src/TMLE.jl index 5510cb5f..d34e8139 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -15,7 +15,7 @@ using Zygote using LogExpFunctions using PrecompileTools using Random -import AbstractDifferentiation as AD +import DifferentiationInterface as DI using Graphs using MetaGraphsNext using Combinatorics diff --git a/src/estimators.jl b/src/estimators.jl index a7fb1cbb..b5e225ef 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -95,7 +95,7 @@ ConditionalDistributionEstimator(train_validation_indices, model) = Estimates all components of Ψ and then Ψ itself. """ -function (estimator::Estimator)(Ψ::ComposedEstimand, dataset; cache=Dict(), verbosity=1, backend=AD.ZygoteBackend()) +function (estimator::Estimator)(Ψ::ComposedEstimand, dataset; cache=Dict(), verbosity=1, backend=DI.AutoZygote()) estimates = map(Ψ.args) do estimand estimate, _ = estimator(estimand, dataset; cache=cache, verbosity=verbosity) estimate @@ -157,16 +157,17 @@ f(x, y) = [x^2 - y, y - 3x] compose(f, res₁, res₂) ``` """ -function compose(f, estimates...; backend=AD.ZygoteBackend()) +function compose(f, estimates...; backend=DI.AutoZygote()) f₀, σ₀, n = _compose(f, estimates...; backend=backend) estimand = ComposedEstimand(f, Tuple(e.estimand for e in estimates)) return ComposedEstimate(estimand, estimates, f₀, σ₀, n) end -function _compose(f, estimates...; backend=AD.ZygoteBackend()) +function _compose(f, estimates...; backend=DI.AutoZygote()) Σ = covariance_matrix(estimates...) point_estimates = [r.estimate for r in estimates] - f₀, Js = AD.value_and_jacobian(backend, f, point_estimates...) + f₀_and_Js = DI.value_and_jacobian.(Ref(f), Ref(backend), point_estimates) + f₀, Js = first.(f₀_and_Js), last.(f₀_and_Js) J = hcat(Js...) n = size(first(estimates).IC, 1) σ₀ = J * Σ * J' @@ -176,4 +177,4 @@ end function covariance_matrix(estimates...) X = hcat([r.IC for r in estimates]...) return cov(X, dims=1, corrected=true) -end \ No newline at end of file +end From d3379e682411ebf848162ed8b624e25f80f81106 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 14:56:38 +0200 Subject: [PATCH 2/3] Fix splatting --- src/estimators.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/estimators.jl b/src/estimators.jl index b5e225ef..98645480 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -166,9 +166,7 @@ end function _compose(f, estimates...; backend=DI.AutoZygote()) Σ = covariance_matrix(estimates...) point_estimates = [r.estimate for r in estimates] - f₀_and_Js = DI.value_and_jacobian.(Ref(f), Ref(backend), point_estimates) - f₀, Js = first.(f₀_and_Js), last.(f₀_and_Js) - J = hcat(Js...) + f₀, J = DI.value_and_jacobian(Base.splat(f), backend, point_estimates) n = size(first(estimates).IC, 1) σ₀ = J * Σ * J' return collect(f₀), σ₀, n From a708a418affd8a0c8bf1867ee9fae7465da02941 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 2 Aug 2024 07:36:39 +0200 Subject: [PATCH 3/3] Always vector --- src/estimators.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/estimators.jl b/src/estimators.jl index 98645480..7da2719c 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -163,13 +163,16 @@ function compose(f, estimates...; backend=DI.AutoZygote()) return ComposedEstimate(estimand, estimates, f₀, σ₀, n) end +_make_vec(x::Number) = [x] +_make_vec(x::AbstractVector) = x + function _compose(f, estimates...; backend=DI.AutoZygote()) Σ = covariance_matrix(estimates...) point_estimates = [r.estimate for r in estimates] - f₀, J = DI.value_and_jacobian(Base.splat(f), backend, point_estimates) + f₀, J = DI.value_and_jacobian(_make_vec ∘ Base.splat(f), backend, point_estimates) n = size(first(estimates).IC, 1) σ₀ = J * Σ * J' - return collect(f₀), σ₀, n + return f₀, σ₀, n end function covariance_matrix(estimates...)