Commit 88c6f11: Support for dbt 1.0.0

jamesweakley committed Dec 8, 2021
1 parent ceeb800 commit 88c6f11
Showing 14 changed files with 72 additions and 29 deletions.
11 changes: 8 additions & 3 deletions README.md
@@ -10,14 +10,14 @@ The macros are:

| scikit-learn function | macro name | Snowflake | BigQuery | Redshift | MSSQL | PostgreSQL | Example |
| --- | --- | --- | --- | --- | --- | --- | --- |
-| [KBinsDiscretizer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.KBinsDiscretizer.html#sklearn.preprocessing.KBinsDiscretizer)| k_bins_discretizer | Y | Y | Y | N | Y | ![example](images/k_bins.gif) |
+| [KBinsDiscretizer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.KBinsDiscretizer.html#sklearn.preprocessing.KBinsDiscretizer)| k_bins_discretizer | Y | Y | Y | Y | Y | ![example](images/k_bins.gif) |
| [LabelEncoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html#sklearn.preprocessing.LabelEncoder)| label_encoder | Y | Y | Y | Y | Y | ![example](images/label_encoder.gif) |
| [MaxAbsScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html#sklearn.preprocessing.MaxAbsScaler) | max_abs_scaler | Y | Y | Y | Y | Y | [![example](images/max_abs_scaler.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#maxabsscaler) |
-| [MinMaxScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html#sklearn.preprocessing.MinMaxScaler) | min_max_scaler | Y | Y | Y | N | Y | [![example](images/min_max_scaler.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#minmaxscaler) |
+| [MinMaxScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html#sklearn.preprocessing.MinMaxScaler) | min_max_scaler | Y | Y | Y | Y | Y | [![example](images/min_max_scaler.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#minmaxscaler) |
| [Normalizer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.Normalizer.html#sklearn.preprocessing.Normalizer) | normalizer | Y | Y | Y | Y | Y | [![example](images/normalizer.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#normalizer) |
| [OneHotEncoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html#sklearn.preprocessing.OneHotEncoder) | one_hot_encoder | Y | Y | Y | Y | Y | ![example](images/one_hot_encoder.gif) |
| [QuantileTransformer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.QuantileTransformer.html#sklearn.preprocessing.QuantileTransformer) | quantile_transformer | Y | Y | N | N | Y | [![example](images/quantile_transformer.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#quantiletransformer-uniform-output) |
-| [RobustScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html#sklearn.preprocessing.RobustScaler) | robust_scaler | Y | Y | Y | N | Y | [![example](images/robust_scaler.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#robustscaler) |
+| [RobustScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html#sklearn.preprocessing.RobustScaler) | robust_scaler | Y | Y | Y | Y | Y | [![example](images/robust_scaler.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#robustscaler) |
| [StandardScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html#sklearn.preprocessing.StandardScaler) | standard_scaler | Y | Y | Y | N | Y | [![example](images/standard_scaler.png)](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#standardscaler) |

_\* 2D charts taken from [scikit-learn.org](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html), GIFs are my own_
@@ -33,6 +33,11 @@ _(replace the revision number with the latest)_
Then run:
```dbt deps``` to import the package.

### dbt 1.0.0 compatibility
dbt-ml-preprocessing version 1.2.0 is the first version to support (and require) dbt 1.0.0.

If you are not ready to upgrade to dbt 1.0.0, please use dbt-ml-preprocessing version 1.0.2.
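Staying on the older release might look like the following in `packages.yml` (a sketch only: the repository URL is assumed from this project's GitHub location, and the revision should match an actual `1.0.2` tag in the repo):

```yaml
packages:
  - git: "https://github.com/omnata-labs/dbt-ml-preprocessing.git"
    revision: 1.0.2
```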

## Usage
To read the macro documentation and see examples, simply [generate your docs](https://docs.getdbt.com/reference/commands/cmd-docs/), and you'll see macro documentation in the Projects tree under ```dbt_ml_preprocessing```:

6 changes: 3 additions & 3 deletions dbt_project.yml
@@ -1,13 +1,13 @@
name: 'dbt_ml_preprocessing'
-version: '1.0.2'
+version: '1.1.0'

-require-dbt-version: ">=0.15.1"
+require-dbt-version: ">=1.0.0"

config-version: 2

profile: "integration_tests"

-source-paths: ["models"]
+model-paths: ["models"]
target-path: "target"
clean-targets: ["target", "dbt_modules"]
macro-paths: ["macros"]
1 change: 1 addition & 0 deletions integration_tests/dbt_packages/dbt_ml_preprocessing
1 change: 1 addition & 0 deletions integration_tests/dbt_packages/dbt_utils
Submodule dbt_utils added at 68b4b4
4 changes: 2 additions & 2 deletions integration_tests/dbt_project.yml
@@ -8,10 +8,10 @@ profile: 'integration_tests'

config-version: 2

-source-paths: ["models"]
+model-paths: ["models"]
analysis-paths: ["analysis"]
test-paths: ["tests"]
-data-paths: ["data"]
+seed-paths: ["data"]
macro-paths: ["macros"]

target-path: "target" # directory which will store compiled SQL files
50 changes: 40 additions & 10 deletions integration_tests/macros/equality_with_numeric_tolerance.sql
@@ -23,10 +23,10 @@ information schema — this allows the model to be an ephemeral model
{% set target_numeric_column_name = kwargs.get('target_numeric_column_name', kwargs.get('arg')) %}
{% set percentage_tolerance = kwargs.get('percentage_tolerance', kwargs.get('arg')) %}

-{{ return(adapter.dispatch('test_equality_with_numeric_tolerance')(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance)) }}
+{{ return(adapter.dispatch('test_equality_with_numeric_tolerance')(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,True)) }}
{% endmacro %}

-{% macro default__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False) %}
+{% macro default__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows) %}
{% set compare_cols_csv = compare_columns | join(', ') %}
with a as (
select * from {{ model }}
@@ -56,16 +56,12 @@ from joined
where percent_difference > {{ percentage_tolerance }}
{% endmacro %}

-{% macro sqlserver__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False) %}
-{% do return( redshift__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False)) %}
-{% endmacro %}

-{% macro postgres__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False) %}
-{% do return( redshift__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False)) %}
+{% macro postgres__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows) %}
+{% do return( redshift__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows)) %}
{% endmacro %}


-{% macro snowflake__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False) %}
+{% macro snowflake__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows) %}
{% set compare_cols_csv = compare_columns | join(', ') %}
with a as (
select * from {{ model }}
@@ -92,7 +88,7 @@ from joined
where percent_difference > {{ percentage_tolerance }}
{% endmacro %}

-{% macro redshift__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows=False) %}
+{% macro redshift__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows) %}
{% set compare_cols_csv = compare_columns | join(', ') %}
with a as (
select * from {{ model }}
@@ -123,4 +119,38 @@ from joined
-- The reason we tolerate tiny differences here is because of the floating point arithmetic,
-- the values do not end up exactly the same as those output from python
where percent_difference > {{ percentage_tolerance }}
{% endmacro %}

{% macro sqlserver__test_equality_with_numeric_tolerance(model,compare_model,source_join_column,target_join_column,source_numeric_column_name,target_numeric_column_name,percentage_tolerance,output_all_rows) %}
{% set compare_cols_csv = compare_columns | join(', ') %}
with a as (
select * from {{ model }}
),
b as (
select * from {{ compare_model }}
),
joined as(
select round(a.{{ source_numeric_column_name }},6) as actual,
round(b.{{ target_numeric_column_name }},6) as expected,
b.{{ target_numeric_column_name }} as actual_value
from a
join b on a.{{ source_join_column }}=b.{{ target_join_column }}
),
joined_calced as(
select
abs(actual-expected) as difference,
iif(abs(actual-expected)>0,
abs(actual-expected)/actual_value,
0)*100 as percent_difference
from joined
)
select {% if output_all_rows %}
*
{% else %}
count(*)
{% endif %}
from joined_calced
-- The reason we tolerate tiny differences here is because of the floating point arithmetic,
-- the values do not end up exactly the same as those output from python
where percent_difference > {{ percentage_tolerance }}
{% endmacro %}
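The shared logic of these adapter-specific test macros can be sketched in Python (illustrative names; the SQL rounds to 6 decimal places, and this sketch divides by the expected value, where the SQL above uses the compare model's column):

```python
def percent_difference(actual, expected, decimals=6):
    """Percentage gap between the macro output and the expected value.

    Mirrors the SQL: both sides are rounded before comparing, and a zero
    difference short-circuits to avoid division by zero.
    """
    a, e = round(actual, decimals), round(expected, decimals)
    diff = abs(a - e)
    return diff / e * 100 if diff > 0 else 0.0

def rows_outside_tolerance(pairs, percentage_tolerance):
    """Count (actual, expected) pairs exceeding the tolerance.

    The test passes when this count is zero.
    """
    return sum(1 for a, e in pairs
               if percent_difference(a, e) > percentage_tolerance)
```

As the comments in the macros note, tiny differences are tolerated because floating-point arithmetic keeps the SQL results from matching the Python reference values exactly.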
@@ -14,7 +14,11 @@ select * from data

-- macro not supported in other databases
{% macro default__quantile_transformer_model_macro() %}
-select 1 from (select 1) where 1=2 -- empty result set so that test passes
+select 1 as one from (select 1) where 1=2 -- empty result set so that test passes
{% endmacro %}

{% macro redshift__quantile_transformer_model_macro() %}
select 1 as one from (select 1) where 1=2 -- empty result set so that test passes
{% endmacro %}

-- macro not supported in sqlserver
@@ -12,7 +12,7 @@

-- testing macro only works on Snowflake
{% macro default__test_quantile_transformer_result_with_tolerance() %}
-select 1 from (select 1) where 1=2 -- empty result set so that test passes
+select 1 as one from (select 1) where 1=2 -- empty result set so that test passes
{% endmacro %}

-- testing macro not supported in sqlserver
2 changes: 1 addition & 1 deletion integration_tests/packages.yml
@@ -3,4 +3,4 @@ packages:
- local: ../

- git: "https://github.com/fishtown-analytics/dbt-utils.git"
-revision: 0.6.3
+revision: 0.8.0
4 changes: 2 additions & 2 deletions macros/k_bins_discretizer.sql
@@ -33,7 +33,7 @@ with
)
{% if not loop.last %}, {% endif %}
{% endfor %}
-{{ adapter.dispatch('k_bins_discretizer',packages=['dbt_ml_preprocessing'])(source_table,source_columns,include_columns,n_bins,encode,strategy) }}
+{{ adapter.dispatch('k_bins_discretizer','dbt_ml_preprocessing')(source_table,source_columns,include_columns,n_bins,encode,strategy) }}
{% endmacro %}


@@ -80,7 +80,7 @@ source_table.{{ column }},
case when
floor(
cast({{ source_column }} - {{ source_column }}_aggregates.min_value as decimal)/ cast( {{ source_column }}_aggregates.max_value - {{ source_column }}_aggregates.min_value as decimal ) * {{ n_bins }}
-) > {{ n_bins - 1 }}
+) < {{ n_bins - 1 }}
then floor(
cast({{ source_column }} - {{ source_column }}_aggregates.min_value as decimal)/ cast( {{ source_column }}_aggregates.max_value - {{ source_column }}_aggregates.min_value as decimal ) * {{ n_bins }}
)
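The `>` to `<` swap above fixes the clamp on the top bin: the column maximum produces a raw index equal to `n_bins`, which must fold back into the last bin, `n_bins - 1`. A minimal Python sketch of the uniform strategy (the function name is illustrative, not part of the package):

```python
import math

def k_bins_uniform(values, n_bins):
    """Uniform-strategy binning: floor((x - min) / (max - min) * n_bins),
    with the top edge clamped so the maximum lands in bin n_bins - 1."""
    lo, hi = min(values), max(values)
    bins = []
    for x in values:
        raw = math.floor((x - lo) / (hi - lo) * n_bins)
        # the corrected condition: keep raw only while it is below n_bins - 1
        bins.append(raw if raw < n_bins - 1 else n_bins - 1)
    return bins
```

For example, binning `[0, 1, 2, 3, 4]` into two bins yields `[0, 0, 1, 1, 1]`; without the clamp, the maximum value 4 would land in a non-existent bin 2.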
2 changes: 1 addition & 1 deletion macros/label_encoder.sql
@@ -1,5 +1,5 @@
{% macro label_encoder(source_table,source_column, include_columns='*') %}
-{{ adapter.dispatch('label_encoder',packages=['dbt_ml_preprocessing'])(source_table,source_column,include_columns) }}
+{{ adapter.dispatch('label_encoder','dbt_ml_preprocessing')(source_table,source_column,include_columns) }}
{% endmacro %}

{% macro default__label_encoder(source_table,source_column,include_columns) %}
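For context on what the dispatched macro computes, label encoding can be sketched in Python (illustrative name; the macro's exact ordering semantics live in the SQL body, which is elided here):

```python
def label_encode(values):
    """Map each distinct value to an integer by sort order, the same idea as
    scikit-learn's LabelEncoder."""
    mapping = {v: i for i, v in enumerate(sorted(set(values)))}
    return [mapping[v] for v in values]
```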
2 changes: 1 addition & 1 deletion macros/one_hot_encoder.sql
@@ -67,7 +67,7 @@
{%- endfor -%}
{%- endif -%}

-{{ adapter.dispatch('one_hot_encoder',packages=['dbt_ml_preprocessing'])(source_table, source_column, category_values, handle_unknown, col_list) }}
+{{ adapter.dispatch('one_hot_encoder','dbt_ml_preprocessing')(source_table, source_column, category_values, handle_unknown, col_list) }}
{%- endmacro %}

{% macro default__one_hot_encoder(source_table, source_column, category_values, handle_unknown, col_list) %}
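The macro generates one indicator column per category value. A rough Python equivalent (illustrative names; the `case when` columns the macro emits are elided from this diff):

```python
def one_hot_encode(values, categories=None):
    """One indicator column per category; values outside the category list
    encode as all zeros, roughly handle_unknown='ignore' in scikit-learn."""
    cats = sorted(set(values)) if categories is None else list(categories)
    return [[1 if v == c else 0 for c in cats] for v in values]
```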
4 changes: 3 additions & 1 deletion macros/quantile_transformer.sql
@@ -3,7 +3,7 @@
{%- set all_source_columns = adapter.get_columns_in_relation(source_table) | map(attribute='quoted') -%}
{% set include_columns = all_source_columns | join(', ') %}
{%- endif -%}
-{{ adapter.dispatch('quantile_transformer',packages=['dbt_ml_preprocessing'])(source_table,source_column,n_quantiles,output_distribution,subsample,include_columns) }}
+{{ adapter.dispatch('quantile_transformer','dbt_ml_preprocessing')(source_table,source_column,n_quantiles,output_distribution,subsample,include_columns) }}
{% endmacro %}

{% macro default__quantile_transformer(source_table,source_column,n_quantiles,output_distribution,subsample,include_columns) %}
@@ -66,10 +66,12 @@ from linear_interpolation_variables
{% endmacro %}

{% macro redshift__quantile_transformer(source_table,source_column,n_quantiles,output_distribution,subsample,include_columns) %}
{% if execute %}
{% set error_message %}
The `quantile_transformer` macro is only supported on Snowflake and BigQuery at this time. It should work on other databases, but requires some rework.
{% endset %}
{%- do exceptions.raise_compiler_error(error_message) -%}
{% endif %}
{% endmacro %}
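For reference, the uniform-output transformation that the supported adapters implement with window functions and linear interpolation can be approximated in plain Python (a sketch of scikit-learn's QuantileTransformer idea, not the macro's exact SQL; names are illustrative):

```python
def quantile_transform_uniform(values, n_quantiles):
    """Map each value onto [0, 1] by linear interpolation between
    empirical quantiles of the column."""
    s = sorted(values)
    refs = [i / (n_quantiles - 1) for i in range(n_quantiles)]

    def empirical_quantile(q):
        # linear-interpolation quantile over the sorted column
        idx = (len(s) - 1) * q
        lo = int(idx)
        hi = min(lo + 1, len(s) - 1)
        return s[lo] + (s[hi] - s[lo]) * (idx - lo)

    quantiles = [empirical_quantile(q) for q in refs]
    out = []
    for x in values:
        if x <= quantiles[0]:
            out.append(0.0)
        elif x >= quantiles[-1]:
            out.append(1.0)
        else:
            # find the bracketing quantiles and interpolate between them
            for i in range(1, len(quantiles)):
                if x <= quantiles[i]:
                    span = quantiles[i] - quantiles[i - 1]
                    frac = 0.0 if span == 0 else (x - quantiles[i - 1]) / span
                    out.append(refs[i - 1] + frac * (refs[i] - refs[i - 1]))
                    break
    return out
```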

{% macro postgre__quantile_transformer(source_table,source_column,n_quantiles,output_distribution,subsample,include_columns) %}
6 changes: 3 additions & 3 deletions macros/robust_scaler.sql
@@ -22,7 +22,7 @@ The `source_columns` parameter must contain a list of column names.
{%- set all_source_columns = adapter.get_columns_in_relation(source_table) | map(attribute='quoted') -%}
{% set include_columns = all_source_columns %}
{%- endif -%}
-{{ adapter.dispatch('robust_scaler',packages=['dbt_ml_preprocessing'])(source_table,source_columns,include_columns,with_centering,quantile_range) }}
+{{ adapter.dispatch('robust_scaler','dbt_ml_preprocessing')(source_table,source_columns,include_columns,with_centering,quantile_range) }}
{% endmacro %}

{% macro default__robust_scaler(source_table,source_columns,include_columns,with_centering,quantile_range) %}
@@ -110,8 +110,8 @@ with
{% for source_column in source_columns %}
{{ source_column }}_quartiles as(
select
-percentile_cont({{ quantile_range[0] / 100 }}) within group (order by {{ source_column }}) OVER(PARTITION BY {{ source_column }}) as first_quartile,
-percentile_cont({{ quantile_range[1] / 100 }}) within group (order by {{ source_column }}) OVER(PARTITION BY {{ source_column }}) as third_quartile
+percentile_cont({{ quantile_range[0] / 100 }}) within group (order by {{ source_column }}) OVER() as first_quartile,
+percentile_cont({{ quantile_range[1] / 100 }}) within group (order by {{ source_column }}) OVER() as third_quartile
from {{ source_table }}
)
{% if not loop.last %}, {% endif %}
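The `OVER(PARTITION BY ...)` to `OVER()` change matters because partitioning by the scaled column itself computes each row's quartiles over identical values only; the quartiles must come from the whole column. A rough Python equivalent of the corrected behavior (illustrative names; linear-interpolation quantiles assumed):

```python
import statistics

def robust_scale(values, quantile_range=(25.0, 75.0), with_centering=True):
    """Scale by the interquartile range and optionally center on the median,
    as in scikit-learn's RobustScaler."""
    s = sorted(values)

    def quantile(q):
        # linear-interpolation quantile over the whole (sorted) column
        idx = (len(s) - 1) * q / 100.0
        lo = int(idx)
        hi = min(lo + 1, len(s) - 1)
        return s[lo] + (s[hi] - s[lo]) * (idx - lo)

    q1, q3 = quantile(quantile_range[0]), quantile(quantile_range[1])
    center = statistics.median(values) if with_centering else 0.0
    return [(x - center) / (q3 - q1) for x in values]
```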
