-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1035c9e
commit f22bb08
Showing
1 changed file
with
259 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a4a3f852", | ||
"metadata": {}, | ||
"source": [ | ||
"# Ice Sheet Emulator Demo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e7ebd4e4", | ||
"metadata": {}, | ||
"source": [ | ||
"### Installation\n", | ||
"\n", | ||
"To install, run this is your terminal. This will create an environment, download ISE, and activate a jupyter notebook instance. Open up this notebook to run the model. \n", | ||
"\n", | ||
"```conda create -n ise -y``` \n", | ||
"```conda activate ise``` \n", | ||
"```conda install nb_conda ipykernel -y``` \n", | ||
"```pip install git+https://github.com/Brown-SciML/ise``` \n", | ||
"```jupyter notebook``` " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "a13a05eb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"from ise.models.hybrid import HybridEmulator, DeepEnsemble, NormalizingFlow\n", | ||
"from sklearn.preprocessing import StandardScaler" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0882b6e8", | ||
"metadata": {}, | ||
"source": [ | ||
"### Synthetic Data and Scaling" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "a814d0bd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"X = np.array([[1,2,3], [4,5,6], [7,8,9]], dtype=int)\n", | ||
"y = np.array([[1], [2], [3],], dtype=int)\n", | ||
"\n", | ||
"scaler_y = StandardScaler().fit(y)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8847c8d1", | ||
"metadata": {}, | ||
"source": [ | ||
"### Create emulator Model\n", | ||
"Create the model with the Deep Ensemble as the Predictor and the Normalizing Flow as the uncertainty quantifier." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "0beb86de", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"input_shape = X.shape[1]\n", | ||
"output_shape = y.shape[1]\n", | ||
"num_ensemble_members = 2\n", | ||
"\n", | ||
"predictor = DeepEnsemble(num_predictors=num_ensemble_members, forcing_size=input_shape, sle_size=output_shape)\n", | ||
"uncertainty_quantifier = NormalizingFlow(forcing_size=input_shape, sle_size=output_shape)\n", | ||
"emulator = HybridEmulator(deep_ensemble=predictor, normalizing_flow=uncertainty_quantifier)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "bf7a5dc3", | ||
"metadata": { | ||
"scrolled": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"Training Normalizing Flow (10 epochs):\n", | ||
"Epoch 1, Loss: 4.753732204437256\n", | ||
"Epoch 2, Loss: 4.68987512588501\n", | ||
"Epoch 3, Loss: 4.626246929168701\n", | ||
"Epoch 4, Loss: 4.562777042388916\n", | ||
"Epoch 5, Loss: 4.499433517456055\n", | ||
"Epoch 6, Loss: 4.436193943023682\n", | ||
"Epoch 7, Loss: 4.373051166534424\n", | ||
"Epoch 8, Loss: 4.310001850128174\n", | ||
"Epoch 9, Loss: 4.247047424316406\n", | ||
"Epoch 10, Loss: 4.1841888427734375\n", | ||
"\n", | ||
"Training Deep Ensemble (15 epochs):\n", | ||
"Training Weak Predictor 1 of 2:\n", | ||
"Epoch 1, Average Batch Loss: 4.4524664878845215\n", | ||
"Epoch 2, Average Batch Loss: 4.2489094734191895\n", | ||
"Epoch 3, Average Batch Loss: 4.11221170425415\n", | ||
"Epoch 4, Average Batch Loss: 3.971240282058716\n", | ||
"Epoch 5, Average Batch Loss: 3.7946555614471436\n", | ||
"Epoch 6, Average Batch Loss: 3.5632998943328857\n", | ||
"Epoch 7, Average Batch Loss: 3.2641429901123047\n", | ||
"Epoch 8, Average Batch Loss: 2.8817074298858643\n", | ||
"Epoch 9, Average Batch Loss: 2.406568765640259\n", | ||
"Epoch 10, Average Batch Loss: 1.8449459075927734\n", | ||
"Epoch 11, Average Batch Loss: 1.2372466325759888\n", | ||
"Epoch 12, Average Batch Loss: 0.6458789706230164\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"C:\\Users\\pvank\\miniconda3\\envs\\ise\\Lib\\site-packages\\ise\\data\\dataclasses.py:33: UserWarning: Full projections of 86 timesteps are not present in the dataset. This may lead to unexpected behavior.\n", | ||
" warnings.warn(\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Epoch 13, Average Batch Loss: 0.20140044391155243\n", | ||
"Epoch 14, Average Batch Loss: 0.067240871489048\n", | ||
"Epoch 15, Average Batch Loss: 0.31725022196769714\n", | ||
"\n", | ||
"Training Weak Predictor 2 of 2:\n", | ||
"Epoch 1, Average Batch Loss: 1.4070473909378052\n", | ||
"Epoch 2, Average Batch Loss: 1.3644334077835083\n", | ||
"Epoch 3, Average Batch Loss: 1.327694296836853\n", | ||
"Epoch 4, Average Batch Loss: 1.2840057611465454\n", | ||
"Epoch 5, Average Batch Loss: 1.2289267778396606\n", | ||
"Epoch 6, Average Batch Loss: 1.1562410593032837\n", | ||
"Epoch 7, Average Batch Loss: 1.0581227540969849\n", | ||
"Epoch 8, Average Batch Loss: 0.9263723492622375\n", | ||
"Epoch 9, Average Batch Loss: 0.7507588863372803\n", | ||
"Epoch 10, Average Batch Loss: 0.5248854756355286\n", | ||
"Epoch 11, Average Batch Loss: 0.26091358065605164\n", | ||
"Epoch 12, Average Batch Loss: 0.05477694049477577\n", | ||
"Epoch 13, Average Batch Loss: 0.053589772433042526\n", | ||
"Epoch 14, Average Batch Loss: 0.23185868561267853\n", | ||
"Epoch 15, Average Batch Loss: 0.30263322591781616\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"emulator.fit(X, y, nf_epochs=10, de_epochs=15)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b70b8513", | ||
"metadata": {}, | ||
"source": [ | ||
"### Predict on unseen data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "a525e483", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"predictions, uncertainties = emulator.predict(np.array([[10, 11, 12]]), output_scaler=scaler_y)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "2318cea9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([[4.0753965]], dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"predictions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "64d26fc6", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'total': array([[135.88687]], dtype=float32),\n", | ||
" 'epistemic': array([[0.06036377]], dtype=float32),\n", | ||
" 'aleatoric': array([[135.8265]], dtype=float32)}" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"uncertainties" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f62b427f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |