Commit

Merge branch 'develop' into issue_725_sos
sammlapp committed Sep 18, 2024
2 parents 491ff72 + 866cbd8 commit 8ba62c5
Showing 7 changed files with 1,340 additions and 198 deletions.
54 changes: 48 additions & 6 deletions docs/tutorials/customize_cnn_training.ipynb
@@ -56,7 +56,7 @@
},
{
"cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
"id": "eb11d417-c381-4950-a78c-65e8a512a786",
"metadata": {},
"outputs": [],
@@ -409,7 +409,7 @@
},
{
"cell_type": "code",
-"execution_count": 10,
+"execution_count": 5,
"id": "cd0d0636-bea4-42a4-9f79-a4ed340e30c5",
"metadata": {},
"outputs": [
@@ -419,7 +419,7 @@
"Linear(in_features=512, out_features=1, bias=True)"
]
},
-"execution_count": 10,
+"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -434,7 +434,47 @@
"id": "00453342-0830-45eb-84ac-93cd1a3c7337",
"metadata": {},
"source": [
-"It is also possible to replace a model's architecture entirely by setting `model.architecture` to a new architecture, but this is not recommended. It will completely remove anything the model has \"learned,\" since the learned weights are part of the architecture."
+"It is also possible to replace a model's architecture entirely by setting `model.architecture` to a new architecture, but this is not generally recommended unless you know what you're doing. It will completely remove anything the model has \"learned,\" since the learned weights are part of the architecture."
]
},
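The warning above holds because the trained weights live inside the architecture object itself, so replacing the object discards them. A minimal PyTorch sketch of the effect, using a bare `nn.Linear` as a stand-in rather than OpenSoundscape's actual `model.architecture`:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
with torch.no_grad():
    layer.weight.fill_(1.0)  # pretend these values were "learned"

# replacing the module re-initializes the weights from scratch
layer = nn.Linear(4, 2)
print(torch.all(layer.weight == 1.0).item())  # False: the "learned" values are gone
```

The same logic applies to assigning a new architecture to a model: the old parameters are garbage-collected along with the old module.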
{
"cell_type": "markdown",
"id": "fbd0d364",
"metadata": {},
"source": [
"## Freezing the feature extractor\n",
"\n",
"Sometimes, we only wish to train the final layer or layers of a CNN, known as the \"classification head\" or simply \"classifier\", rather than training all of the layers. This technique makes it possible to fine-tune a pre-trained network using limited training data, without ruining the generalizability of the \"feature extractor\" (the term for all of the layers before the \"classification head\"). \n",
"\n",
"If you're using one of the built-in CNN architectures in OpenSoundscape, you can easily \"freeze\" the feature extractor (i.e., tell PyTorch not to update any of its weights while training the classification head) with a one-liner, then proceed with training as usual (`model.train(...)`)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3a7b6775",
"metadata": {},
"outputs": [],
"source": [
"model.freeze_feature_extractor()"
]
},
{
"cell_type": "markdown",
"id": "a2fdf566",
"metadata": {},
"source": [
"If you are using a custom architecture not native to OpenSoundscape, you can still freeze all but one layer with a one-liner; you just need to specify which layer or layers should remain \"trainable\" (\"unfrozen\"). For a ResNet architecture, we can point to the `.fc` (\"fully connected\") layer as the classification layer to train while freezing all others. Note that other PyTorch architectures may not call the classification layer `.fc`. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c3eafbc9",
"metadata": {},
"outputs": [],
"source": [
"model.freeze_layers_except(model.network.fc)"
]
},
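Under the hood, both freezing helpers boil down to toggling `requires_grad` on parameters. A minimal sketch of that PyTorch mechanism, using a toy `nn.Sequential` network rather than OpenSoundscape's models (the standalone `freeze_layers_except` function here mirrors the method above but is illustrative, not the library's implementation):

```python
import torch.nn as nn

def freeze_layers_except(network: nn.Module, trainable: nn.Module) -> None:
    """Disable gradients for every parameter in `network`, then
    re-enable them only for the parameters belonging to `trainable`."""
    for p in network.parameters():
        p.requires_grad = False
    for p in trainable.parameters():
        p.requires_grad = True

# toy stand-in: the first two layers act as the "feature extractor",
# the final Linear as the "classification head"
net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
freeze_layers_except(net, net[2])

trainable_names = [n for n, p in net.named_parameters() if p.requires_grad]
print(trainable_names)  # only the head's parameters: ['2.weight', '2.bias']
```

Training then proceeds as normal: the optimizer skips parameters that never receive gradients, so only the unfrozen head is updated.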
{
@@ -792,13 +832,15 @@
},
{
"cell_type": "code",
-"execution_count": 23,
+"execution_count": 8,
"id": "5963bca9-e8c6-4390-84ca-175b33914ced",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
-"shutil.rmtree('./woodcock_labeled_data')\n",
+"# shutil.rmtree('./woodcock_labeled_data')\n",
+"\n",
+"Path('./my_pre.json').unlink(missing_ok=True)\n",
"\n",
"for p in Path('.').glob('*.model'):\n",
" p.unlink()"
