diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 208fab0f..9b9da1d7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,11 +24,8 @@ jobs: python -m pip install black flake8 - name: Code Style (Black/Flake8) run: | - # Black code style black --check --diff pytorch_widedeep tests examples setup.py - # Stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E901,E999,F821,F822,F823 --ignore=E266 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --ignore=E203,E266,E501,E721,E722,F401,F403,F405,F811,W503,C901 --statistics test: @@ -47,41 +44,16 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest-cov codecov faker + python -m pip install pytest pytest-cov codecov faker if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Test with pytest + - name: Test with pytest and generate coverage run: | - pytest --doctest-modules pytorch_widedeep --cov-report xml --cov-report term --disable-pytest-warnings --cov=pytorch_widedeep tests/ - - name: Upload coverage - uses: actions/upload-artifact@v4 - with: - name: coverage${{ matrix.python-version }} - path: .coverage - - finish: - needs: test - runs-on: ubuntu-latest - if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.10 - uses: actions/setup-python@v5 - with: - python-version: "3.10" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install coverage - - name: Download all artifacts - # Downloads coverage1, coverage2, etc. - uses: actions/download-artifact@v4 - - name: Convert coverage - run: | - coverage combine coverage*/.coverage* - coverage report --fail-under=90 - coverage xml - - name: upload coverage to Codecov + pytest --doctest-modules pytorch_widedeep --cov=pytorch_widedeep --cov-report=xml --cov-report=term --disable-pytest-warnings tests + - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true + file: ./coverage.xml + flags: unittests + name: codecov-${{ matrix.python-version }} + fail_ci_if_error: true \ No newline at end of file diff --git a/.isort.cfg b/.isort.cfg index d81d7de1..023dbdb4 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,4 +1,4 @@ [settings] +profile=black multi_line_output=3 -include_trailing_comma=True length_sort=1 \ No newline at end of file diff --git a/README.md b/README.md index ddf11304..2db68794 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ The content of this document is organized as follows: - [Introduction](#introduction) - [Architectures](#architectures) - [The ``deeptabular`` component](#the-deeptabular-component) + - [The ``rec`` module](#the-rec-module) - [Text and Images](#text-and-images) - [Installation](#installation) - [Developer Install](#developer-install) @@ -795,6 +796,28 @@ encoder-decoder method and constrastive-denoising method. Please, see the documentation and the examples for details on this functionality, and all other options in the library. +### The ``rec`` module + +This module was introduced as an extension to the existing components in the +library, addressing questions and issues related to recommendation systems. 
+While still under active development, it currently includes a select number +of powerful recommendation models. + +It's worth noting that this library already supported the implementation of +various recommendation algorithms using existing components. For example, +models like Wide and Deep, Two-Tower, or Neural Collaborative Filtering could +be constructed using the library's core functionalities. + +The recommendation algorithms in the `rec` module are: + +1. [DeepFM: A Factorization-Machine based Neural Network for CTR Prediction](https://arxiv.org/abs/1703.04247) +2. (Deep) Field Aware Factorization Machine (FFM): a Deep Learning version of the algorithm presented in [Field-aware Factorization Machines in a Real-world Online Advertising System](https://arxiv.org/abs/1701.04099) +3. [xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170) +4. [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/abs/1706.06978) + +These can all be used as the `deeptabular` component in the `WideDeep` model. +See the examples for more details. + ### Text and Images For the text component, `deeptext`, the library offers the following models: diff --git a/VERSION b/VERSION index 266146b8..9edc58bb 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.6.3 +1.6.4 diff --git a/examples/scripts/adult_census.py b/examples/scripts/adult_census.py index e76e38cd..17173c41 100644 --- a/examples/scripts/adult_census.py +++ b/examples/scripts/adult_census.py @@ -3,19 +3,10 @@ import pandas as pd from pytorch_widedeep import Trainer -from pytorch_widedeep.models import ( # noqa: F401 - Wide, - TabMlp, - WideDeep, - TabResnet, -) +from pytorch_widedeep.models import Wide, TabMlp, WideDeep, TabResnet # noqa: F401 from pytorch_widedeep.metrics import Accuracy, Precision from pytorch_widedeep.datasets import load_adult -from pytorch_widedeep.callbacks import ( - LRHistory, - EarlyStopping, - ModelCheckpoint, -) +from pytorch_widedeep.callbacks import LRHistory, EarlyStopping, ModelCheckpoint from pytorch_widedeep.initializers import XavierNormal, KaimingNormal from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor diff --git a/examples/scripts/adult_census_attention_mlp.py b/examples/scripts/adult_census_attention_mlp.py index 2fb37b3f..85562f63 100644 --- a/examples/scripts/adult_census_attention_mlp.py +++ b/examples/scripts/adult_census_attention_mlp.py @@ -3,11 +3,7 @@ import pandas as pd from pytorch_widedeep import Trainer -from pytorch_widedeep.models import ( - WideDeep, - SelfAttentionMLP, - ContextAttentionMLP, -) +from pytorch_widedeep.models import WideDeep, SelfAttentionMLP, ContextAttentionMLP from pytorch_widedeep.metrics import Accuracy from pytorch_widedeep.datasets import load_adult from pytorch_widedeep.preprocessing import TabPreprocessor diff --git a/examples/scripts/adult_census_cont_den_full_example.py b/examples/scripts/adult_census_cont_den_full_example.py index d282cdbf..a661d011 100644 --- a/examples/scripts/adult_census_cont_den_full_example.py +++ b/examples/scripts/adult_census_cont_den_full_example.py @@ -8,9 +8,7 @@ from pytorch_widedeep.metrics import Accuracy from pytorch_widedeep.datasets import load_adult from pytorch_widedeep.preprocessing import TabPreprocessor -from pytorch_widedeep.self_supervised_training import ( - ContrastiveDenoisingTrainer, -) +from pytorch_widedeep.self_supervised_training import ContrastiveDenoisingTrainer use_cuda = torch.cuda.is_available() diff --git 
a/examples/scripts/adult_census_cont_den_run_all_models.py b/examples/scripts/adult_census_cont_den_run_all_models.py index ecbaa7cc..b69174bb 100644 --- a/examples/scripts/adult_census_cont_den_run_all_models.py +++ b/examples/scripts/adult_census_cont_den_run_all_models.py @@ -14,9 +14,7 @@ ) from pytorch_widedeep.datasets import load_adult from pytorch_widedeep.preprocessing import TabPreprocessor -from pytorch_widedeep.self_supervised_training import ( - ContrastiveDenoisingTrainer, -) +from pytorch_widedeep.self_supervised_training import ContrastiveDenoisingTrainer use_cuda = torch.cuda.is_available() diff --git a/examples/scripts/adult_census_enc_dec_run_all_models.py b/examples/scripts/adult_census_enc_dec_run_all_models.py index 5a5bf36e..2c77cffc 100644 --- a/examples/scripts/adult_census_enc_dec_run_all_models.py +++ b/examples/scripts/adult_census_enc_dec_run_all_models.py @@ -5,11 +5,7 @@ from pytorch_widedeep.models import TabMlp as TabMlpEncoder from pytorch_widedeep.models import TabNet as TabNetEncoder from pytorch_widedeep.models import TabResnet as TabResnetEncoder -from pytorch_widedeep.models import ( - TabMlpDecoder, - TabNetDecoder, - TabResnetDecoder, -) +from pytorch_widedeep.models import TabMlpDecoder, TabNetDecoder, TabResnetDecoder from pytorch_widedeep.datasets import load_adult from pytorch_widedeep.preprocessing import TabPreprocessor from pytorch_widedeep.self_supervised_training import EncoderDecoderTrainer diff --git a/examples/scripts/adult_census_tabnet.py b/examples/scripts/adult_census_tabnet.py index c523e54b..e3f2cfe7 100644 --- a/examples/scripts/adult_census_tabnet.py +++ b/examples/scripts/adult_census_tabnet.py @@ -6,11 +6,7 @@ from pytorch_widedeep.models import TabNet, WideDeep from pytorch_widedeep.metrics import Accuracy, Precision from pytorch_widedeep.datasets import load_adult -from pytorch_widedeep.callbacks import ( - LRHistory, - EarlyStopping, - ModelCheckpoint, -) +from pytorch_widedeep.callbacks import LRHistory, EarlyStopping, ModelCheckpoint from pytorch_widedeep.preprocessing import TabPreprocessor use_cuda = torch.cuda.is_available() diff --git a/examples/scripts/adult_census_transformers.py b/examples/scripts/adult_census_transformers.py index 08e2fbe8..1718f0d0 100644 --- a/examples/scripts/adult_census_transformers.py +++ b/examples/scripts/adult_census_transformers.py @@ -14,11 +14,7 @@ ) from pytorch_widedeep.metrics import Accuracy from pytorch_widedeep.datasets import load_adult -from pytorch_widedeep.callbacks import ( - LRHistory, - EarlyStopping, - ModelCheckpoint, -) +from pytorch_widedeep.callbacks import LRHistory, EarlyStopping, ModelCheckpoint from pytorch_widedeep.initializers import XavierNormal, KaimingNormal from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor diff --git a/examples/scripts/movielens_din.py b/examples/scripts/movielens_din.py new file mode 100644 index 00000000..551c5dec --- /dev/null +++ b/examples/scripts/movielens_din.py @@ -0,0 +1,257 @@ +# DIN is a "special" model and the data needs a very particular preparation +# process. Therefore, the library does not have a dedicated preprocessor for +# this specific model. 
Below is a detailed explanation, and if someone
+# wants to use this algorithm I would suggest wrapping it all in an object
+# prior to passing the data to the model
+
+import re
+from functools import partial
+
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import LabelEncoder
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import WideDeep
+from pytorch_widedeep.metrics import Accuracy
+from pytorch_widedeep.datasets import load_movielens100k
+from pytorch_widedeep.preprocessing import TabPreprocessor
+from pytorch_widedeep.models.rec.din import DeepInterestNetwork
+
+
+def clean_genre_list(genre_list):
+    return "_".join(
+        sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list])
+    )
+
+
+def label_encode_column(df, column_name):
+    le = LabelEncoder()
+    df[column_name] = le.fit_transform(df[column_name])
+    return df, le
+
+
+def create_sequences(group, seq_len=5):
+    movies = group["movie_id"].tolist()
+    genres = group["genre_list"].tolist()
+    ratings = group["rating"].tolist()
+
+    sequences = []
+    for i in range(len(movies) - seq_len):
+        user_movies_sequence = movies[i : i + seq_len]
+        genres_sequence = genres[i : i + seq_len]
+        ratings_sequence = ratings[i : i + seq_len]
+        target_item = movies[i + seq_len]
+        target_item_rating = ratings[i + seq_len]
+
+        sequences.append(
+            {
+                "user_id": group.name,
+                "user_movies_sequence": user_movies_sequence,
+                "genres_sequence": genres_sequence,
+                "ratings_sequence": ratings_sequence,
+                "target_item": target_item,
+                "target_item_rating": target_item_rating,
+            }
+        )
+
+    seq_df = pd.DataFrame(sequences)
+    non_seq_cols = group.drop_duplicates(["user_id"]).drop(
+        ["movie_id", "genre_list", "rating", "timestamp"], axis=1
+    )
+
+    return pd.merge(seq_df, non_seq_cols, on="user_id")
+
+
+def preprocess_movie_data(df, seq_len=5):
+    df_sorted = df.sort_values(["user_id", "timestamp"])
+
+    partial_create_sequences = partial(create_sequences, seq_len=seq_len)
+
+    result_df = (
+        df_sorted.groupby("user_id")
+        .apply(partial_create_sequences)
+        .reset_index(drop=True)
+    )
+
+    return result_df
+
+
+if __name__ == "__main__":
+
+    data, users, items = load_movielens100k(as_frame=True)
+
+    list_of_genres = [
+        "unknown",
+        "Action",
+        "Adventure",
+        "Animation",
+        "Children's",
+        "Comedy",
+        "Crime",
+        "Documentary",
+        "Drama",
+        "Fantasy",
+        "Film-Noir",
+        "Horror",
+        "Musical",
+        "Mystery",
+        "Romance",
+        "Sci-Fi",
+        "Thriller",
+        "War",
+        "Western",
+    ]
+
+    # assertion only to help mypy with the types returned by load_movielens100k
+    assert (
+        isinstance(items, pd.DataFrame)
+        and isinstance(data, pd.DataFrame)
+        and isinstance(users, pd.DataFrame)
+    )
+    items["genre_list"] = items[list_of_genres].apply(
+        lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1
+    )
+
+    items["genre_list"] = items["genre_list"].apply(clean_genre_list)
+
+    df = pd.merge(data, items[["movie_id", "genre_list"]], on="movie_id")
+    df = pd.merge(
+        df,
+        users[["user_id", "age", "gender", "occupation"]],
+        on="user_id",
+    )
+
+    df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0)
+
+    # Up until here, everything is quite standard. Now, some columns will be
+    # treated as sequences and, therefore, they need to be tokenized
+    # and "numericalised"/label encoded
+    df, user_le = label_encode_column(df, "user_id")
+    df, item_le = label_encode_column(df, "movie_id")
+    df, genre_le = label_encode_column(df, "genre_list")
+
+    # Internally, all models for tabular data in this library use padding
+    # idx = 0 for unseen values, while sklearn's LabelEncoder starts at 0.
+    # Therefore we need to add 1 to the encoded values to leave 0 for
+    # unknown/unseen values
+    df["movie_id"] = df["movie_id"] + 1
+    df["genre_list"] = df["genre_list"] + 1
+
+    # The explanation as to why we do this with the ratings will come later
+    df["rating"] = df["rating"] + 1
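+
+    # As a made-up illustration of the shift: if LabelEncoder mapped the
+    # movies to {"Heat": 0, "Alien": 1, ...}, after adding 1 they become
+    # {"Heat": 1, "Alien": 2, ...}, and index 0 is left free to act as the
+    # padding/unseen-value index in the embedding layers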
+
+    # we build sequences of 5 movies. Our goal will be predicting whether the
+    # next movie will be reviewed positively or negatively
+    df = df.sort_values(by=["timestamp"]).reset_index(drop=True)
+    seq_df = preprocess_movie_data(df, seq_len=5)
+    # target back to 0/1
+    seq_df["target_item_rating"] = seq_df["target_item_rating"] - 1
+
+    X_target_item = np.array(seq_df.target_item.tolist()).reshape(-1, 1)
+
+    # in reality, all users here have more than 5 reviews, so we have complete
+    # sequences, but there is a chance that this does not happen in a given
+    # dataset, so one would have to pad with the padding idx (0 in this case)
+    seq_len = 5
+    X_user_behaviour = np.array(
+        [
+            lst + [0] * (seq_len - len(lst))
+            for lst in seq_df.user_movies_sequence.tolist()
+        ]
+    )
+    X_ratings = np.array(
+        [lst + [0] * (seq_len - len(lst)) for lst in seq_df.ratings_sequence.tolist()]
+    )
+    X_genres = np.array(
+        [lst + [0] * (seq_len - len(lst)) for lst in seq_df.genres_sequence.tolist()]
+    )
+
+    # At this point we have the target item as an array of shape (N obs, 1),
+    # and all columns that are going to be treated as sequences stored in
+    # arrays of shape (N obs, seq_len). The rest of the columns are going to
+    # be treated as ANY other "standard" tabular dataset
+    other_cols = ["user_id", "age", "gender", "occupation"]
+    df_other_feat = seq_df[other_cols]
+    tab_preprocessor = TabPreprocessor(cat_embed_cols=other_cols)
+    X_other_feats = tab_preprocessor.fit_transform(df_other_feat)
+
+    X_all = np.concatenate(
+        [X_other_feats, X_target_item, X_user_behaviour, X_ratings, X_genres], axis=1
+    )
+
+    # Now, all the model components in this library take one (and only one)
+    # tensor as input. Therefore, if we want to treat some data differently,
+    # we need to slice the tensor internally. For this to happen we need to
+    # "tell" the algorithm which column is which. DIN has data of two
+    # different natures: sequences and standard tabular (i.e. everything
+    # else). For the sequences, let's simply name the columns with what they
+    # are plus an index (but you can define them with whichever string you
+    # want)
+    user_behaviour_cols = [f"item_{i+1}" for i in range(5)]
+    genres_cols = [f"genre_{i+1}" for i in range(5)]
+    ratings_cols = [f"rating_{i+1}" for i in range(5)]
+
+    # Then all columns in the dataset are, in order of appearance in X_all:
+    all_cols = (
+        [el[0] for el in tab_preprocessor.cat_embed_input]  # tabular cols
+        + ["target_item"]  # target item
+        + user_behaviour_cols  # user behaviour seq cols
+        + ratings_cols  # ratings seq cols
+        + genres_cols  # genres seq cols
+    )
+    column_idx = {k: i for i, k in enumerate(all_cols)}
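+
+    # For illustration, column_idx now looks roughly like this (the exact
+    # order of the tabular columns comes from the TabPreprocessor):
+    # {"user_id": 0, "age": 1, "gender": 2, "occupation": 3,
+    #  "target_item": 4, "item_1": 5, ..., "rating_1": 10, ...,
+    #  "genre_1": 15, ..., "genre_5": 19}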
+
+    # Now we need to define the so-called "configs". For the sequence
+    # columns these will consist of:
+    # - the column names
+    # - the maximum value in the column (to define the number of embeddings)
+    # - the embedding dim
+    user_behavior_confiq = (
+        user_behaviour_cols,
+        X_user_behaviour.max(),
+        32,
+    )
+
+    # Again, the explanation for this will come when we instantiate the model
+    # (also, please, read the docs)
+    rating_seq_config = (ratings_cols, 2)
+
+    # all the other sequence columns that are not user behaviour, or an
+    # action related to the items that define the user behaviour, will be
+    # referred to as "other sequence columns" and will be passed as elements
+    # of a list
+    other_seq_cols_confiq = [(genres_cols, X_genres.max(), 16)]
+
+    # And finally, the config for the remaining, tabular columns
+    other_cols_config = tab_preprocessor.cat_embed_input
+
+    # Now, one of the params of the DeepInterestNetwork is action_seq_config.
+    # This 'action' can be, for example, a rating or purchased/not-purchased.
+    # The way this 'action' will be used is the following: the action will
+    # **always** be learned as a 1d embedding and will be combined with the
+    # user behaviour. For example, imagine that the action is
+    # purchased/not-purchased. Then, per item in the user behaviour sequence,
+    # there will be a binary action to learn, 0/1. Such an action will be
+    # represented by a float number (so 3 floats will be learned: one for
+    # purchased, one for not-purchased and one for padding) that will
+    # multiply the corresponding item embedding in the user behaviour
+    # sequence.
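+
+    # A made-up numeric illustration: if the learned action floats were
+    # [w_pad, w_neg, w_pos] = [0.0, 0.7, 1.3] and a user's ratings sequence
+    # were [2, 1, 2, 2, 1] (1 = negative, 2 = positive, 0 = padding), the
+    # five item embeddings in that user's behaviour sequence would be scaled
+    # by 1.3, 0.7, 1.3, 1.3 and 0.7 respectively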
+    din = DeepInterestNetwork(
+        column_idx=column_idx,
+        target_item_col="target_item",
+        user_behavior_confiq=user_behavior_confiq,
+        action_seq_config=rating_seq_config,
+        other_seq_cols_confiq=other_seq_cols_confiq,
+        cat_embed_input=other_cols_config,
+        mlp_hidden_dims=[128, 64],
+    )
+
+    # And from here on, everything is standard
+    model = WideDeep(deeptabular=din)
+
+    trainer = Trainer(model=model, objective="binary", metrics=[Accuracy()])
+
+    # in the real world you would have to split the data into train, val and test
+    trainer.fit(
+        X_tab=X_all,
+        target=seq_df.target_item_rating.values,
+        n_epochs=5,
+        batch_size=512,
+    )
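+
+    # Once fitted, inference follows the standard Trainer API; a minimal
+    # sketch (here simply re-using the training matrix for illustration):
+    # preds = trainer.predict(X_tab=X_all)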
diff --git a/examples/scripts/movielens_fm.py b/examples/scripts/movielens_fm.py
new file mode 100644
index 00000000..77db3423
--- /dev/null
+++ b/examples/scripts/movielens_fm.py
@@ -0,0 +1,139 @@
+import re
+from typing import List, Tuple
+
+import pandas as pd
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import Wide, WideDeep
+from pytorch_widedeep.metrics import Accuracy
+from pytorch_widedeep.datasets import load_movielens100k
+from pytorch_widedeep.models.rec import (
+    DeepFactorizationMachine,
+    ExtremeDeepFactorizationMachine,
+    DeepFieldAwareFactorizationMachine,
+)
+from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
+
+if __name__ == "__main__":
+
+    data, users, items = load_movielens100k(as_frame=True)
+
+    list_of_genres = [
+        "unknown",
+        "Action",
+        "Adventure",
+        "Animation",
+        "Children's",
+        "Comedy",
+        "Crime",
+        "Documentary",
+        "Drama",
+        "Fantasy",
+        "Film-Noir",
+        "Horror",
+        "Musical",
+        "Mystery",
+        "Romance",
+        "Sci-Fi",
+        "Thriller",
+        "War",
+        "Western",
+    ]
+
+    # assertion only to avoid mypy warnings
+    assert (
+        isinstance(items, pd.DataFrame)
+        and isinstance(data, pd.DataFrame)
+        and isinstance(users, pd.DataFrame)
+    )
+    items["genre_list"] = items[list_of_genres].apply(
+        lambda x: [genre for genre in list_of_genres if x[genre] == 1], axis=1
+    )
+
+    # for each element in genre_list: lower-case every genre, remove
+    # non-alphanumeric characters, sort and join with an underscore
+    def clean_genre_list(genre_list):
+        return "_".join(
+            sorted([re.sub(r"[^a-z0-9]", "", genre.lower()) for genre in genre_list])
+        )
+
+    items["genre_list"] = items["genre_list"].apply(clean_genre_list)
+
+    df = pd.merge(data, users[["user_id", "age", "gender", "occupation"]], on="user_id")
+    df = pd.merge(df, items[["movie_id", "genre_list"]], on="movie_id")
+
+    # binarize the ratings
+    df["rating"] = df["rating"].apply(lambda x: 1 if x >= 4 else 0)
+
+    # sort by timestamp, group by user, and keep the penultimate interaction
+    # per user for validation and the last one for test
+    df = df.sort_values(by=["timestamp"])
+    train_df = df.groupby("user_id").apply(lambda x: x.iloc[:-2]).reset_index(drop=True)
+    val_df = df.groupby("user_id").apply(lambda x: x.iloc[-2]).reset_index(drop=True)
+    test_df = df.groupby("user_id").apply(lambda x: x.iloc[-1]).reset_index(drop=True)
+    assert len(df) == len(train_df) + len(val_df) + len(test_df)
+
+    cat_cols = [
+        "user_id",
+        "movie_id",
+        "age",
+        "gender",
+        "occupation",
+        "genre_list",
+    ]
+
+    tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_cols, for_mf=True)
+    X_tab_tr = tab_preprocessor.fit_transform(train_df)
+    X_tab_val = tab_preprocessor.transform(val_df)
+    X_tab_te = tab_preprocessor.transform(test_df)
+
+    wide_preprocessor = WidePreprocessor(wide_cols=cat_cols)
+    X_wide_tr = wide_preprocessor.fit_transform(train_df)
+    X_wide_val = wide_preprocessor.transform(val_df)
+    X_wide_te = wide_preprocessor.transform(test_df)
+
+    cat_embed_input: List[Tuple[str, int]] = tab_preprocessor.cat_embed_input
+    dfm = DeepFactorizationMachine(
+        column_idx=tab_preprocessor.column_idx,
+        num_factors=8,
+        cat_embed_input=cat_embed_input,
+        mlp_hidden_dims=[64, 32],
+    )
+
+    dffm = DeepFieldAwareFactorizationMachine(
+        column_idx=tab_preprocessor.column_idx,
+        num_factors=8,
+        cat_embed_input=cat_embed_input,
+        mlp_hidden_dims=[64, 32],
+    )
+
+    xdfm = ExtremeDeepFactorizationMachine(
+        column_idx=tab_preprocessor.column_idx,
+        input_dim=16,
+        cat_embed_input=cat_embed_input,
+        cin_layer_dims=[32, 16],
+        mlp_hidden_dims=[64, 32],
+    )
+
+    # the wide component is fed X_wide, so its input dim must be based on the
+    # wide-preprocessed data
+    wide = Wide(input_dim=X_wide_tr.max(), pred_dim=1)
+
+    for fm_model in [dfm, dffm, xdfm]:
+
+        model = WideDeep(wide=wide, deeptabular=fm_model)
+
+        trainer = Trainer(model, objective="binary", metrics=[Accuracy()])
+
+        X_train = {
+            "X_wide": X_wide_tr,
+            "X_tab": X_tab_tr,
+            "target": train_df["rating"].values,
+        }
+        X_val = {
+            "X_wide": X_wide_val,
+            "X_tab": X_tab_val,
+            "target": val_df["rating"].values,
+        }
+        X_test = {"X_wide": X_wide_te, "X_tab": X_tab_te}
+
+        trainer.fit(X_train=X_train, X_val=X_val, n_epochs=2)
+        trainer.predict(X_test=X_test)
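+
+        # a quick sanity check one could add here (sketch): compare the test
+        # predictions against the held-out ratings
+        # preds = trainer.predict(X_test=X_test)
+        # print((preds == test_df["rating"].values).mean())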
diff --git a/examples/scripts/readme_snippets.py b/examples/scripts/readme_snippets.py
index 9e230e8c..bf430923 100644
--- a/examples/scripts/readme_snippets.py
+++ b/examples/scripts/readme_snippets.py
@@ -12,14 +12,7 @@
 from faker import Faker
 
 from pytorch_widedeep import Trainer
-from pytorch_widedeep.models import (
-    Wide,
-    TabMlp,
-    Vision,
-    BasicRNN,
-    WideDeep,
-    ModelFuser,
-)
+from pytorch_widedeep.models import Wide, TabMlp, Vision, BasicRNN, WideDeep, ModelFuser
 from pytorch_widedeep.preprocessing import (
     TabPreprocessor,
     TextPreprocessor,
@@ -27,9 +20,7 @@
     ImagePreprocessor,
 )
 from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
-from pytorch_widedeep.models._base_wd_model_component import (
-    BaseWDModelComponent,
-)
+from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
 
 
 def create_and_save_random_image(image_number, size=(32, 32)):
diff --git a/mkdocs/mkdocs.yml b/mkdocs/mkdocs.yml
index 3f8347cf..7cdd8ee8 100644
--- a/mkdocs/mkdocs.yml
+++ b/mkdocs/mkdocs.yml
@@ -28,6 +28,7 @@ nav:
   - Preprocessing: pytorch-widedeep/preprocessing.md
   - Load From Folder: pytorch-widedeep/load_from_folder.md
   - Model Components: pytorch-widedeep/model_components.md
+  - The Rec Module: pytorch-widedeep/the_rec_module.md
   - Bayesian models: pytorch-widedeep/bayesian_models.md
   - Losses: pytorch-widedeep/losses.md
   - Metrics: pytorch-widedeep/metrics.md
[... diffs of the regenerated HTML under mkdocs/site (404.html, the docs index page, etc.) omitted: they only update hashed asset links and re-render the documentation changes in this PR ...]
[hunks from the documentation sources follow; their file headers were garbled in extraction. The docs gain a two-tower example and the same ``rec`` module section as the README:]

+**7. A two-tower model**
+
+This is a popular model in the context of recommendation systems. Let's say we
+have a tabular dataset formed by triples (user features, item features,
+target). We can create a two-tower model where the user and item features are
+passed through two separate models and then "fused" via a dot product.
+
+```python
+import numpy as np
+import pandas as pd
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.preprocessing import TabPreprocessor
+from pytorch_widedeep.models import TabMlp, WideDeep, ModelFuser
+
+# Let's create the interaction dataset
+# user_features dataframe
+np.random.seed(42)
+user_ids = np.arange(1, 101)
+ages = np.random.randint(18, 60, size=100)
+genders = np.random.choice(["male", "female"], size=100)
+locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100)
+user_features = pd.DataFrame(
+    {"id": user_ids, "age": ages, "gender": genders, "location": locations}
+)
+
+# item_features dataframe
+item_ids = np.arange(1, 101)
+prices = np.random.uniform(10, 500, size=100).round(2)
+colors = np.random.choice(["red", "blue", "green", "black"], size=100)
+categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100)
+
+item_features = pd.DataFrame(
+    {"id": item_ids, "price": prices, "color": colors, "category": categories}
+)
+
+# Interactions dataframe
+interaction_user_ids = np.random.choice(user_ids, size=1000)
+interaction_item_ids = np.random.choice(item_ids, size=1000)
+purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3])
+interactions = pd.DataFrame(
+    {
+        "user_id": interaction_user_ids,
+        "item_id": interaction_item_ids,
+        "purchased": purchased,
+    }
+)
+user_item_purchased = interactions.merge(
+    user_features, left_on="user_id", right_on="id"
+).merge(item_features, left_on="item_id", right_on="id")
+
+# Users
+tab_preprocessor_user = TabPreprocessor(
+    cat_embed_cols=["gender", "location"],
+    continuous_cols=["age"],
+)
+X_user = tab_preprocessor_user.fit_transform(user_item_purchased)
+tab_mlp_user = TabMlp(
+    column_idx=tab_preprocessor_user.column_idx,
+    cat_embed_input=tab_preprocessor_user.cat_embed_input,
+    continuous_cols=["age"],
+    mlp_hidden_dims=[16, 8],
+    mlp_dropout=[0.2, 0.2],
+)
+
+# Items
+tab_preprocessor_item = TabPreprocessor(
+    cat_embed_cols=["color", "category"],
+    continuous_cols=["price"],
+)
+X_item = tab_preprocessor_item.fit_transform(user_item_purchased)
+tab_mlp_item = TabMlp(
+    column_idx=tab_preprocessor_item.column_idx,
+    cat_embed_input=tab_preprocessor_item.cat_embed_input,
+    continuous_cols=["price"],
+    mlp_hidden_dims=[16, 8],
+    mlp_dropout=[0.2, 0.2],
+)
+
+two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot")
+
+model = WideDeep(deeptabular=two_tower_model)
+
+trainer = Trainer(model, objective="binary")
+
+trainer.fit(
+    X_tab=[X_user, X_item],
+    target=interactions.purchased.values,
+    n_epochs=1,
+    batch_size=32,
+)
+```
+
-**7. Tabular with a multi-target loss**
+**8. Tabular with a multi-target loss**
 
 This one is "a bonus" to illustrate the use of multi-target losses, more than
 actually a different architecture.
@@ -692,6 +791,28 @@
 encoder-decoder method and constrastive-denoising method. Please, see the
 documentation and the examples for details on this functionality, and all
 other options in the library.
 
+### The ``rec`` module
+
+This module was introduced as an extension to the existing components in the
+library, addressing questions and issues related to recommendation systems.
+While still under active development, it currently includes a select number
+of powerful recommendation models.
+
+It's worth noting that this library already supported the implementation of
+various recommendation algorithms using existing components. For example,
+models like Wide and Deep, Two-Tower, or Neural Collaborative Filtering could
+be constructed using the library's core functionalities.
+
+The recommendation algorithms in the `rec` module are:
+
+1. [DeepFM: A Factorization-Machine based Neural Network for CTR Prediction](https://arxiv.org/abs/1703.04247)
+2. (Deep) Field Aware Factorization Machine (FFM): a Deep Learning version of the algorithm presented in [Field-aware Factorization Machines in a Real-world Online Advertising System](https://arxiv.org/abs/1701.04099)
+3. [xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170)
+4. [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/abs/1706.06978)
+
+These can all be used as the `deeptabular` component in the `WideDeep` model.
+See the examples for more details.
+
 ### Text and Images
 
 For the text component, `deeptext`, the library offers the following models:
diff --git a/mkdocs/site/installation.html b/mkdocs/site/installation.html
index 0cc41b2a..f321b41e 100644
[... asset-hash churn in the regenerated installation.html omitted ...]
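For a quick sense of how the new `rec` models plug in, here is a minimal sketch distilled from the `examples/scripts/movielens_fm.py` script added above (the toy dataframe and its column names are made up for illustration):

```python
import pandas as pd

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.models.rec import DeepFactorizationMachine
from pytorch_widedeep.preprocessing import TabPreprocessor

# toy interaction data: two categorical features and a binary target
df = pd.DataFrame(
    {
        "user_id": [1, 2, 3, 4] * 25,
        "item_id": [10, 20, 30, 40] * 25,
        "purchased": [0, 1, 0, 1] * 25,
    }
)

tab_preprocessor = TabPreprocessor(cat_embed_cols=["user_id", "item_id"], for_mf=True)
X_tab = tab_preprocessor.fit_transform(df)

dfm = DeepFactorizationMachine(
    column_idx=tab_preprocessor.column_idx,
    num_factors=8,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    mlp_hidden_dims=[16, 8],
)

# any of the rec models can be used as the deeptabular component of WideDeep
model = WideDeep(deeptabular=dfm)
trainer = Trainer(model, objective="binary")
trainer.fit(X_tab=X_tab, target=df["purchased"].values, n_epochs=1, batch_size=32)
```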
model. This is a linear model where the
+
Defines a Wide
model. This is a linear model where the
non-linearlities are captured via crossed-columns
>>> import torch
>>> from pytorch_widedeep.bayesian_models import BayesianWide
>>> X = torch.empty(4, 4).random_(6)
->>> wide = BayesianWide(input_dim=X.unique().size(0), pred_dim=1)
+>>> wide = BayesianWide(input_dim=int(X.max().item()), pred_dim=1)
>>> out = wide(X)
pytorch_widedeep/bayesian_models/tabular/bayesian_linear/bayesian_wide.py