Skip to content

Commit

Permalink
Improve example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
chrislemke committed Nov 10, 2022
1 parent 61a72f8 commit cada8d3
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 90 deletions.
17 changes: 8 additions & 9 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.manifold import TSNE\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from autoembedder import Autoembedder, dataloader, fit\n",
"import plotly.offline as py\n",
Expand Down Expand Up @@ -91,10 +91,7 @@
"outputs": [],
"source": [
"def plot_scatter(X, y):\n",
" X = TSNE(\n",
" n_components=2, random_state=42, learning_rate=\"auto\", init=\"random\"\n",
" ).fit_transform(X)\n",
"\n",
" X = PCA(n_components=2, random_state=42).fit_transform(X)\n",
" traces = [\n",
" go.Scatter(\n",
" x=X[y == 0, 0],\n",
Expand All @@ -111,7 +108,6 @@
" name=\"Fraud (1)\",\n",
" ),\n",
" ]\n",
"\n",
" py.iplot(go.Figure(data=traces))"
]
},
Expand Down Expand Up @@ -153,11 +149,14 @@
"outputs": [],
"source": [
"df = (\n",
" pd.concat([df.loc[df[\"Class\"] == 1], df.loc[df[\"Class\"] == 0].sample(5000)])\n",
" pd.concat([df.loc[df[\"Class\"] == 1], df.loc[df[\"Class\"] == 0]])\n",
" .sample(frac=1)\n",
" .reset_index(drop=True)\n",
")\n",
"y = df.pop(\"Class\")"
"y = df.pop(\"Class\")\n",
"\n",
"scaler = MinMaxScaler()\n",
"df = scaler.fit_transform(df)"
]
},
{
Expand Down Expand Up @@ -281,7 +280,7 @@
"outputs": [],
"source": [
"parameters = {\n",
" \"hidden_layers\": [[25, 20], [20, 10]],\n",
" \"hidden_layers\": [[30, 25], [25, 20], [20, 15], [15, 8]],\n",
" \"epochs\": 10,\n",
" \"lr\": 0.0001,\n",
" \"verbose\": 1,\n",
Expand Down
79 changes: 38 additions & 41 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "Autoembedder"
version = "0.1.14"
version = "0.1.15"
description = "PyTorch autoencoder with additional embeddings layer for categorical data."
authors = ["Christopher Lemke <chris@syhbl.mozmail.com>"]
license = "MIT"
Expand Down Expand Up @@ -31,9 +31,9 @@ pyarrow = "10.0.0"
pandas = "1.5.1"
torch = "1.12.1"
torchinfo = "1.7.1"
einops = "0.5.0"
einops = "0.6.0"
tqdm = "4.64.1"
tensorboard = "2.10.1"
tensorboard = "2.11.0"
pytorch-ignite = "0.4.10"
numpy = "1.23.4"

Expand Down
Loading

0 comments on commit cada8d3

Please sign in to comment.