Skip to content

Commit

Permalink
add emulator ipynb example
Browse files Browse the repository at this point in the history
  • Loading branch information
pvankatwyk committed Apr 2, 2024
1 parent 1035c9e commit f22bb08
Showing 1 changed file with 259 additions and 0 deletions.
259 changes: 259 additions & 0 deletions examples/hybrid_emulator_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a4a3f852",
"metadata": {},
"source": [
"# Ice Sheet Emulator Demo"
]
},
{
"cell_type": "markdown",
"id": "e7ebd4e4",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"To install, run this is your terminal. This will create an environment, download ISE, and activate a jupyter notebook instance. Open up this notebook to run the model. \n",
"\n",
"```conda create -n ise -y``` \n",
"```conda activate ise``` \n",
"```conda install nb_conda ipykernel -y``` \n",
"```pip install git+https://github.com/Brown-SciML/ise``` \n",
"```jupyter notebook``` "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a13a05eb",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from ise.models.hybrid import HybridEmulator, DeepEnsemble, NormalizingFlow\n",
"from sklearn.preprocessing import StandardScaler"
]
},
{
"cell_type": "markdown",
"id": "0882b6e8",
"metadata": {},
"source": [
"### Synthetic Data and Scaling"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a814d0bd",
"metadata": {},
"outputs": [],
"source": [
"X = np.array([[1,2,3], [4,5,6], [7,8,9]], dtype=int)\n",
"y = np.array([[1], [2], [3],], dtype=int)\n",
"\n",
"scaler_y = StandardScaler().fit(y)"
]
},
{
"cell_type": "markdown",
"id": "8847c8d1",
"metadata": {},
"source": [
"### Create emulator Model\n",
"Create the model with the Deep Ensemble as the Predictor and the Normalizing Flow as the uncertainty quantifier."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0beb86de",
"metadata": {},
"outputs": [],
"source": [
"input_shape = X.shape[1]\n",
"output_shape = y.shape[1]\n",
"num_ensemble_members = 2\n",
"\n",
"predictor = DeepEnsemble(num_predictors=num_ensemble_members, forcing_size=input_shape, sle_size=output_shape)\n",
"uncertainty_quantifier = NormalizingFlow(forcing_size=input_shape, sle_size=output_shape)\n",
"emulator = HybridEmulator(deep_ensemble=predictor, normalizing_flow=uncertainty_quantifier)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bf7a5dc3",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Training Normalizing Flow (10 epochs):\n",
"Epoch 1, Loss: 4.753732204437256\n",
"Epoch 2, Loss: 4.68987512588501\n",
"Epoch 3, Loss: 4.626246929168701\n",
"Epoch 4, Loss: 4.562777042388916\n",
"Epoch 5, Loss: 4.499433517456055\n",
"Epoch 6, Loss: 4.436193943023682\n",
"Epoch 7, Loss: 4.373051166534424\n",
"Epoch 8, Loss: 4.310001850128174\n",
"Epoch 9, Loss: 4.247047424316406\n",
"Epoch 10, Loss: 4.1841888427734375\n",
"\n",
"Training Deep Ensemble (15 epochs):\n",
"Training Weak Predictor 1 of 2:\n",
"Epoch 1, Average Batch Loss: 4.4524664878845215\n",
"Epoch 2, Average Batch Loss: 4.2489094734191895\n",
"Epoch 3, Average Batch Loss: 4.11221170425415\n",
"Epoch 4, Average Batch Loss: 3.971240282058716\n",
"Epoch 5, Average Batch Loss: 3.7946555614471436\n",
"Epoch 6, Average Batch Loss: 3.5632998943328857\n",
"Epoch 7, Average Batch Loss: 3.2641429901123047\n",
"Epoch 8, Average Batch Loss: 2.8817074298858643\n",
"Epoch 9, Average Batch Loss: 2.406568765640259\n",
"Epoch 10, Average Batch Loss: 1.8449459075927734\n",
"Epoch 11, Average Batch Loss: 1.2372466325759888\n",
"Epoch 12, Average Batch Loss: 0.6458789706230164\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pvank\\miniconda3\\envs\\ise\\Lib\\site-packages\\ise\\data\\dataclasses.py:33: UserWarning: Full projections of 86 timesteps are not present in the dataset. This may lead to unexpected behavior.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13, Average Batch Loss: 0.20140044391155243\n",
"Epoch 14, Average Batch Loss: 0.067240871489048\n",
"Epoch 15, Average Batch Loss: 0.31725022196769714\n",
"\n",
"Training Weak Predictor 2 of 2:\n",
"Epoch 1, Average Batch Loss: 1.4070473909378052\n",
"Epoch 2, Average Batch Loss: 1.3644334077835083\n",
"Epoch 3, Average Batch Loss: 1.327694296836853\n",
"Epoch 4, Average Batch Loss: 1.2840057611465454\n",
"Epoch 5, Average Batch Loss: 1.2289267778396606\n",
"Epoch 6, Average Batch Loss: 1.1562410593032837\n",
"Epoch 7, Average Batch Loss: 1.0581227540969849\n",
"Epoch 8, Average Batch Loss: 0.9263723492622375\n",
"Epoch 9, Average Batch Loss: 0.7507588863372803\n",
"Epoch 10, Average Batch Loss: 0.5248854756355286\n",
"Epoch 11, Average Batch Loss: 0.26091358065605164\n",
"Epoch 12, Average Batch Loss: 0.05477694049477577\n",
"Epoch 13, Average Batch Loss: 0.053589772433042526\n",
"Epoch 14, Average Batch Loss: 0.23185868561267853\n",
"Epoch 15, Average Batch Loss: 0.30263322591781616\n",
"\n"
]
}
],
"source": [
"emulator.fit(X, y, nf_epochs=10, de_epochs=15)"
]
},
{
"cell_type": "markdown",
"id": "b70b8513",
"metadata": {},
"source": [
"### Predict on unseen data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a525e483",
"metadata": {},
"outputs": [],
"source": [
"predictions, uncertainties = emulator.predict(np.array([[10, 11, 12]]), output_scaler=scaler_y)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2318cea9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.0753965]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "64d26fc6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'total': array([[135.88687]], dtype=float32),\n",
" 'epistemic': array([[0.06036377]], dtype=float32),\n",
" 'aleatoric': array([[135.8265]], dtype=float32)}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"uncertainties"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f62b427f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit f22bb08

Please sign in to comment.