Skip to content

Commit

Permalink
Merge pull request #54 from jrzaurin/jrzaurin/fix_bug_issue53
Browse files Browse the repository at this point in the history
Fixed issue #53 related to the use of some transformer models without categorical columns
  • Loading branch information
jrzaurin authored Oct 7, 2021
2 parents 6540cd3 + fadede2 commit 5f60170
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)

# pytorch-widedeep

Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0.9
1.0.10
1 change: 1 addition & 0 deletions pypi_README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)


# pytorch-widedeep
Expand Down
7 changes: 1 addition & 6 deletions pytorch_widedeep/models/transformers/ft_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class FTTransformer(nn.Module):
def __init__(
self,
column_idx: Dict[str, int],
embed_input: List[Tuple[str, int]],
embed_input: Optional[List[Tuple[str, int]]] = None,
embed_dropout: float = 0.1,
full_embed_dropout: bool = False,
shared_embed: bool = False,
Expand Down Expand Up @@ -194,11 +194,6 @@ def __init__(
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
self.n_feats = self.n_cat + self.n_cont

if self.n_cont and not self.n_cat and not self.embed_continuous:
raise ValueError(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)

self.cat_and_cont_embed = CatAndContEmbeddings(
input_dim,
column_idx,
Expand Down
7 changes: 1 addition & 6 deletions pytorch_widedeep/models/transformers/saint.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class SAINT(nn.Module):
def __init__(
self,
column_idx: Dict[str, int],
embed_input: List[Tuple[str, int]],
embed_input: Optional[List[Tuple[str, int]]] = None,
embed_dropout: float = 0.1,
full_embed_dropout: bool = False,
shared_embed: bool = False,
Expand Down Expand Up @@ -173,11 +173,6 @@ def __init__(
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
self.n_feats = self.n_cat + self.n_cont

if self.n_cont and not self.n_cat and not self.embed_continuous:
raise ValueError(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)

self.cat_and_cont_embed = CatAndContEmbeddings(
input_dim,
column_idx,
Expand Down
5 changes: 0 additions & 5 deletions pytorch_widedeep/models/transformers/tab_fastformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,6 @@ def __init__(
self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
self.n_feats = self.n_cat + self.n_cont

if self.n_cont and not self.n_cat and not self.embed_continuous:
raise ValueError(
"If only continuous features are used 'embed_continuous' must be set to 'True'"
)

self.cat_and_cont_embed = CatAndContEmbeddings(
input_dim,
column_idx,
Expand Down
2 changes: 1 addition & 1 deletion pytorch_widedeep/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.9"
__version__ = "1.0.10"
31 changes: 31 additions & 0 deletions tests/test_model_components/test_mc_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,34 @@ def test_ft_transformer_mlp(mlp_first_h, shoud_work):
else:
with pytest.raises(AssertionError):
model = _build_model("fttransformer", params) # noqa: F841


###############################################################################
# Test transformers with only continuous cols
###############################################################################


# Fixture: a 10-row x 4-column batch of random continuous features
# (no categorical columns), used by the continuous-only regression tests.
X_tab_only_cont = torch.from_numpy(
    np.vstack([np.random.rand(10) for _ in range(4)]).T
)
# Column names "a".."d", one per continuous feature column.
colnames_only_cont = list(string.ascii_lowercase[:4])


@pytest.mark.parametrize(
    "model_name",
    [
        "fttransformer",
        "saint",
        "tabfastformer",
    ],
)
def test_transformers_only_cont(model_name):
    # Regression test for issue #53: transformer models must build and run
    # a forward pass when the input has continuous columns only (i.e. no
    # categorical embeddings are configured).
    params = {
        "column_idx": {name: idx for idx, name in enumerate(colnames_only_cont)},
        "continuous_cols": colnames_only_cont,
    }

    model = _build_model(model_name, params)
    out = model(X_tab_only_cont)

    n_rows, n_cols = out.size(0), out.size(1)
    assert n_rows == 10 and n_cols == model.output_dim

0 comments on commit 5f60170

Please sign in to comment.