Commit cb83bc9: make default_models public

olivierlabayle committed Apr 1, 2024
1 parent e4c3e40
Showing 6 changed files with 31 additions and 37 deletions.
3 changes: 1 addition & 2 deletions docs/src/index.md
@@ -67,8 +67,7 @@ The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as
 3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE).

 ```@example quick-start
-models = (Y=with_encoder(LinearRegressor()), T = LogisticClassifier())
-tmle = TMLEE(models=models)
+tmle = TMLEE()
 result, _ = tmle(Ψ, dataset, verbosity=0);
 result
 ```
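The simplified call above works because, with this commit, the estimator no longer requires an explicit `models` argument. A sketch of the presumed equivalence, assuming `TMLEE`'s `models` keyword now defaults to the newly exported `default_models()` (the doc change suggests this, but the constructor's default is not shown in this diff):

```julia
using TMLE

# Presumed explicit form of the bare constructor used in the docs above.
tmle = TMLEE(models=default_models())
```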
7 changes: 1 addition & 6 deletions docs/src/user_guide/estimation.md
@@ -66,12 +66,7 @@ Drawing from the example dataset and `SCM` from the Walk Through section, we can
     treatment_confounders=(T₁=[:W₁₁, :W₁₂],),
     outcome_extra_covariates=[:C]
 )
-models = (
-    Y=with_encoder(LinearRegressor()),
-    T₁=LogisticClassifier(),
-    T₂=LogisticClassifier(),
-)
-tmle = TMLEE(models=models)
+tmle = TMLEE()
 result₁, cache = tmle(Ψ₁, dataset);
 result₁
 nothing # hide
19 changes: 0 additions & 19 deletions docs/src/user_guide/misc.md
@@ -12,25 +12,6 @@ The adjustment set consists of all the treatment variable's parents. Additional
 BackdoorAdjustment(;outcome_extra_covariates=[:C])
 ```

-## Treatment Transformer
-
-To account for the fact that treatment variables are categorical variables we provide a MLJ compliant transformer that will either:
-
-- Retrieve the floating point representation of a treatment if it has a natural ordering
-- One hot encode it otherwise
-
-Such transformer can be created with:
-
-```julia
-TreatmentTransformer(;encoder=encoder())
-```
-
-where `encoder` is a [OneHotEncoder](https://alan-turing-institute.github.io/MLJ.jl/dev/models/OneHotEncoder_MLJModels/#OneHotEncoder_MLJModels).
-
-The `with_encoder(model; encoder=TreatmentTransformer())` provides a shorthand to combine a `TreatmentTransformer` with another MLJ model in a pipeline.
-
-Of course you are also free to define your own strategy!
-
 ## Serialization

 Many objects from TMLE.jl can be serialized to various file formats. This is achieved by transforming these structures to dictionaries that can then be serialized to classic JSON or YAML format. For that purpose you can use the `TMLE.read_json`, `TMLE.write_json`, `TMLE.read_yaml` and `TMLE.write_yaml` functions.
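A sketch of such a round trip using the `write_yaml`/`read_yaml` functions named in the paragraph above; the `(path, object)` argument order and the `configuration` object (a `Configuration`, which the package exports) are assumptions, not shown in this diff:

```julia
using TMLE

# Hypothetical round trip: serialize an estimands configuration to YAML,
# then load it back from disk.
TMLE.write_yaml("estimands.yaml", configuration)
configuration = TMLE.read_yaml("estimands.yaml")
```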
7 changes: 1 addition & 6 deletions docs/src/walk_through.md
@@ -135,12 +135,7 @@ Alternatively, you can also directly define the statistical parameters (see [Est
 Then each parameter can be estimated by building an estimator (which is simply a function) and evaluating it on data. For illustration, we will keep the models simple. We define a Targeted Maximum Likelihood Estimator:

 ```@example walk-through
-models = (
-    Y = with_encoder(LinearRegressor()),
-    T₁ = LogisticClassifier(),
-    T₂ = LogisticClassifier()
-)
-tmle = TMLEE(models=models)
+tmle = TMLEE()
 ```

 Because we haven't identified the `cm` causal estimand yet, we need to provide the `scm` as well to the estimator:
2 changes: 1 addition & 1 deletion src/TMLE.jl
@@ -34,7 +34,7 @@ export ComposedEstimand
 export var, estimate, pvalue, confint, emptyIC
 export significance_test, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test
 export compose
-export TreatmentTransformer, with_encoder, encoder
+export default_models, TreatmentTransformer, with_encoder, encoder
 export BackdoorAdjustment, identify
 export last_fluctuation_epsilon
 export Configuration
30 changes: 27 additions & 3 deletions src/utils.jl
@@ -68,9 +68,33 @@ function last_fluctuation_epsilon(cache)
     return fp.coef
 end

-default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier(), encoder=encoder()) = (
-    Q_binary_default = with_encoder(Q_binary, encoder=encoder),
-    Q_continuous_default = with_encoder(Q_continuous, encoder=encoder),
+"""
+    default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier())
+
+Create a NamedTuple containing default models to be used by downstream estimators.
+Each provided model is prepended (in a `MLJ.Pipeline`) with an `MLJ.ContinuousEncoder`.
+
+By default:
+
+- Q_binary is a LinearBinaryClassifier
+- Q_continuous is a LinearRegressor
+- G is a LinearBinaryClassifier
+
+# Example
+
+The following changes the default `Q_binary` to a `LogisticClassifier` and provides a `RidgeRegressor` for `special_y`:
+
+```julia
+using MLJLinearModels
+
+models = (
+    special_y = RidgeRegressor(),
+    default_models(Q_binary=LogisticClassifier())...
+)
+```
+"""
+default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier()) = (
+    Q_binary_default = with_encoder(Q_binary),
+    Q_continuous_default = with_encoder(Q_continuous),
     G_default = with_encoder(G)
 )
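With `default_models` exported, callers can splice overrides into the defaults (as the new docstring's example shows) and hand the result to an estimator. A sketch combining both ideas, assuming the `TMLEE(models=...)` keyword that appears in the documentation lines removed by this commit:

```julia
using TMLE
using MLJLinearModels

# Override the continuous-outcome learner and add a model for a specific
# outcome :Y; all other entries keep their package defaults.
models = (
    Y = with_encoder(LogisticClassifier()),
    default_models(Q_continuous=RidgeRegressor())...
)

tmle = TMLEE(models=models)
```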
