---
title: "PyTorch 04: Autograd"
jupyter: python3
---
## Introduction
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/trgardos/ml-549-fa24/blob/main/14-pytorch-04-autograd.ipynb)
Based on PyTorch [Autograd Tutorial](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html).
## Gradient Descent and Back Propagation
When training neural networks, the most frequently used algorithm to minimize
the training loss is **stochastic gradient descent** (SGD).
The derivative of the loss function with respect to the parameters gives us the
direction of steepest **increase** evaluated at a set of parameter values.
So to decrease the loss, we need to move in the opposite direction of the
gradient, which is the direction of steepest **decrease**.
During training, we update the parameters, $\theta$, by the negative of the gradient, scaled
by some learning rate $\eta$:
$$
\theta \leftarrow \theta - \eta \nabla_\theta J(\theta)
$$
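As a minimal sketch of this update rule, the loop below minimizes a hypothetical one-parameter loss $J(\theta) = (\theta - 3)^2$. The loss, the starting value and the learning rate are illustrative assumptions, and the autograd calls it relies on (`backward`, `torch.no_grad`) are explained later in this chapter.

```python
import torch

# Hypothetical loss J(theta) = (theta - 3)^2, which is minimized at theta = 3
theta = torch.tensor(0.0, requires_grad=True)
eta = 0.1  # learning rate (illustrative value)

for step in range(100):
    J = (theta - 3.0) ** 2
    J.backward()                    # fills theta.grad with dJ/dtheta
    with torch.no_grad():           # the update itself should not be tracked
        theta -= eta * theta.grad   # theta <- theta - eta * gradient
    theta.grad.zero_()              # reset, because gradients accumulate

print(theta.item())                 # close to 3.0
```

In practice an optimizer such as `torch.optim.SGD` performs this update for every parameter of the model.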
The gradient can be calculated efficiently using the chain rule of calculus,
where the derivative of a composition of functions is the product of the
derivatives of each function.
Say we have a function $h(x)$ that is a composition of two functions,
$h(x) = f(g(x))$. Then the derivative of $h$ with respect to $x$ is given by the
chain rule:
$$
\frac{\partial h}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial x}
$$
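To make the chain rule concrete, here is a small check that assumes an inner function $g(x) = x^2$ and an outer function $\sin(\cdot)$. Both functions and the evaluation point are chosen only for illustration. The product of the two derivatives computed by hand should match what autograd reports.

```python
import torch

x = torch.tensor(1.5, requires_grad=True)

g = x ** 2            # inner function g(x) = x^2
out = torch.sin(g)    # outer function applied to g
out.backward()        # autograd applies the chain rule for us

# Chain rule by hand: cos(g) * dg/dx = cos(x^2) * 2x
x0 = x.detach()
manual = torch.cos(x0 ** 2) * 2 * x0

print(x.grad, manual)
```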
A neural network is a composition of many functions, e.g. linear layer, activation function,
linear layer, etc.
We can compute every step of the gradient with **back propagation**.
## Automatic Differentiation with `torch.autograd`
To compute those gradients, PyTorch has a built-in differentiation engine called
`torch.autograd`.
It supports automatic computation of gradients for any **computational graph**.
Consider the simplest one-layer neural network, with input `x`,
parameters `w` and `b`, and some loss function. It can be defined in
PyTorch in the following manner:
```{python code-line-numbers="5-6"}
#| collapsed: false
import torch
x = torch.ones(5) # input tensor
y = torch.zeros(3) # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
```
Here's the calculation of $z$ in matrix form:
```{python}
#| echo: false
import sys
if 'google.colab' in sys.modules:
    !wget https://raw.githubusercontent.com/trgardos/ml-549-fa24/refs/heads/main/utils.py
    !pip install torchviz
```
```{python}
#| echo: false
from utils import tensor_to_latex
from IPython.display import display, Math
latex_x = tensor_to_latex(x)
latex_w = tensor_to_latex(w.T)
latex_b = tensor_to_latex(b)
latex_z = tensor_to_latex(z)
display(Math(f"\\mathrm{{z}} = \\mathrm{{w}}\\mathrm{{x}}+\\mathrm{{b}} = {latex_w} {latex_x} + {latex_b}"))
```
## Tensors, Functions and Computational Graph
The above code defines the following **computational graph**:
::: {.content-visible when-profile="web"}
```{dot}
//| label: fig-comp-graph
//| fig-cap: Computational graph for the simple neural network
digraph G {
rankdir=LR;
x [label="x", shape=box];
matmul1 [label="*"];
add1 [label="+"];
y [label="y", shape=box];
w [label="w", shape=box, style=rounded];
b [label="b", shape=box, style=rounded];
z [label="z", shape=box];
loss [label="loss"];
x -> matmul1;
{ rank=same; w -> matmul1;}
{ rank=same; b -> add1;}
matmul1 -> add1;
add1 -> z;
z -> CE;
{ rank=same; y -> CE;}
CE -> loss;
// This is supposed to draw a box around the parameters
subgraph cluster_0 {
label = "Parameters";
style = "dashed";
w;
b;
}
}
```
:::
::: {.content-hidden when-profile="web"}
Since the above won't render in Colab notebooks, we also render it in
Python.
```{python}
#| code-fold: true
import sys
if 'google.colab' in sys.modules:
    import graphviz
    from IPython.display import display
    # Define the graph using Graphviz
    dot = graphviz.Digraph(comment='Computational Graph')
    # Set graph attributes
    dot.attr(rankdir='LR')
    # Add nodes
    dot.node('x', label='x', shape='box')
    dot.node('matmul1', label='*')
    dot.node('add1', label='+')
    dot.node('y', label='y', shape='box')
    dot.node('w', label='w', shape='box', style='rounded')
    dot.node('b', label='b', shape='box', style='rounded')
    dot.node('z', label='z', shape='box')
    dot.node('loss', label='loss')
    # Add edges
    dot.edge('x', 'matmul1')
    dot.edge('w', 'matmul1', rank='same')
    dot.edge('matmul1', 'add1')
    dot.edge('b', 'add1', rank='same')
    dot.edge('add1', 'z')
    dot.edge('z', 'CE')
    dot.edge('y', 'CE', rank='same')
    dot.edge('CE', 'loss')
    # Add subgraph for parameters
    with dot.subgraph(name='cluster_0') as c:
        c.attr(label='Parameters', style='dashed')
        c.node('w')
        c.node('b')
    # Render the graph
    display(dot)
```
:::
In this network, `w` and `b` are **parameters**, which we need to
optimize.
Thus, we need to be able to compute the gradients of the loss
function with respect to those variables.
### Tensors with `requires_grad`
In order to do that, we set the `requires_grad` property of those tensors to `True`.
::: {.callout-note}
You can set the value of `requires_grad` when creating a tensor, or
later by using the `x.requires_grad_(True)` method.
:::
By default, new tensors are created with `requires_grad=False`.
```{python}
print(f"x: {x.requires_grad}, y: {y.requires_grad}, w: {w.requires_grad}, b: {b.requires_grad}")
```
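The short sketch below illustrates both ways of setting the flag, and how it propagates: a result requires gradients as soon as any of its inputs does. The tensor names are hypothetical and independent of the network defined above.

```python
import torch

a = torch.ones(3)                       # requires_grad defaults to False
p = torch.ones(3, requires_grad=True)   # flag set at creation time

a.requires_grad_(True)                  # flag switched on in place afterwards
print(a.requires_grad)                  # True

q = torch.zeros(3)                      # plain data tensor, not tracked
print((q * 2).requires_grad)            # False: no input requires gradients
print((q * p).requires_grad)            # True: one input (p) requires gradients
```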
### Functions Support Back Propagation
Every operation we apply to tensors to build the computational graph is
represented by an object of class `Function`.
The `Function` object knows
* how to compute the function in the *forward* direction, via its `forward` method, and
* how to compute its derivative during the *backward propagation* step, via its `backward` method.
A reference to the backward propagation function is stored in the `grad_fn` property
of a tensor.
You can find more information about `Function` in the
[documentation](https://pytorch.org/docs/stable/autograd.html#function).
```{python}
print(f"Gradient function for loss = {loss.grad_fn}")
print(f"Gradient function for z = {z.grad_fn}")
```
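To see the two methods side by side, here is a sketch of a user-defined operation built on `torch.autograd.Function`. The `Square` class is a hypothetical example written for illustration and is not part of PyTorch. Built-in operations follow the same forward/backward pattern internally.

```python
import torch

class Square(torch.autograd.Function):
    """Hypothetical example operation computing y = x^2."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)        # keep x around for the backward pass
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x      # upstream gradient times dy/dx

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = Square.apply(t)                   # custom Functions are called via .apply
print(out.grad_fn)                      # backward node created for Square

out.sum().backward()
print(t.grad)                           # tensor([2., 4., 6.])
```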
## Computing Gradients
To optimize the parameters of the neural network, we need to
compute the derivatives of our loss function with respect to those parameters.
We need $\frac{\partial loss}{\partial w}$ and
$\frac{\partial loss}{\partial b}$ under some fixed values of `x` and
`y`.
To compute those derivatives, we call `loss.backward()`, which steps back through
the computational graph, computing the gradients of the loss with respect to the
leaf nodes, which in this case are the parameters `w` and `b`.
We can retrieve the values from `w.grad` and `b.grad`:
```{python}
loss.backward()
print(w.grad)
print(b.grad)
```
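As a sanity check, these gradients can also be derived by hand. The sketch below rebuilds the same one-layer network, with a fixed seed added only for reproducibility, and assumes the default mean reduction of `binary_cross_entropy_with_logits`. Under that assumption $\partial loss / \partial z = (\sigma(z) - y)/N$, where $N$ is the number of outputs.

```python
import torch

torch.manual_seed(0)                    # seed chosen only for reproducibility
x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

# By hand: d loss / d z = (sigmoid(z) - y) / N for the mean-reduced loss
dz = (torch.sigmoid(z.detach()) - y) / y.numel()
db = dz                                 # z = xw + b, so dz/db is the identity
dw = torch.outer(x, dz)                 # dz_j / dw_ij = x_i

print(torch.allclose(w.grad, dw, atol=1e-6))
print(torch.allclose(b.grad, db, atol=1e-6))
```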
There's a handy package called [`torchviz`](https://github.com/szagoruyko/pytorchviz)
that can also render the computational graph.
```{python}
from torchviz import make_dot
from graphviz import Source
from IPython.display import display
# Visualize the computational graph
dot = make_dot(loss, params={"w": w, "b": b})
# Render and display the graph inline
graph = Source(dot.source)
display(graph)
```
Of course, the graph quickly becomes too complex to read for any network of realistic size.
::: {.callout-warning}
We can only obtain the `grad` properties for the leaf nodes of the computational
graph that have the `requires_grad` property set to `True`. For all other nodes in
our graph, gradients will not be available. We can only perform gradient
calculations using `backward` once on a given graph, for performance reasons. If
we need to do several `backward` calls on the same graph, we need to pass
`retain_graph=True` to the `backward` call.
:::
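The following sketch, using small hypothetical tensors, illustrates the points in the warning above. It also shows a related behavior worth knowing: repeated `backward` calls add to `.grad` rather than overwriting it, which is why training loops reset gradients between steps.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)   # leaf tensor
h = w * 3                                           # intermediate (non-leaf) tensor
loss = h.sum()
print(w.is_leaf, h.is_leaf)          # True False

loss.backward(retain_graph=True)     # keep the graph for a second pass
first = w.grad.clone()
print(first)                         # tensor([3., 3.])

loss.backward()                      # second pass over the same graph
second = w.grad.clone()
print(second)                        # tensor([6., 6.]): gradients accumulated

w.grad.zero_()                       # reset before the next computation
```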
## Disabling Gradient Tracking
By default, all tensors with `requires_grad=True` are tracking their
computational history and support gradient computation.
However, there are some cases when we do not need to do that, for example, when
we have trained the model and just want to apply it to some input data, i.e. we
only want to do *forward* computations through the network.
We can stop tracking computations by surrounding our computation code with a
`torch.no_grad()` block:
```{python}
z = torch.matmul(x, w)+b
print(z.requires_grad)
with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)
```
Another way to achieve the same result is to use the `detach()` method
on the tensor:
```{python}
z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)
```
There are reasons you might want to disable gradient tracking:
- To mark some parameters in your neural network as **frozen parameters**.
- To **speed up computations** when you are only doing a forward pass, because
computations on tensors that do not track gradients are more efficient.
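As a sketch of the first use case, the example below freezes the weights of a small `torch.nn.Linear` layer while leaving its bias trainable. The layer sizes and the input are arbitrary illustrative choices.

```python
import torch

layer = torch.nn.Linear(4, 2)            # small illustrative layer
layer.weight.requires_grad_(False)       # freeze the weights, keep the bias trainable

out = layer(torch.ones(1, 4)).sum()
out.backward()

print(layer.weight.grad)                 # None: no gradient for the frozen weights
print(layer.bias.grad)                   # tensor([1., 1.])
```

Frozen parameters keep their current values because they never receive a gradient.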
## Recap: Back Propagation and Computational Graphs
Conceptually, autograd keeps a record of data (tensors) and all executed
operations (along with the resulting new tensors) in a _directed acyclic
graph_ (DAG) consisting of
[`Function`](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
objects.
In this DAG, leaves are the input tensors, roots are the output tensors. By
tracing this graph from roots to leaves, you can automatically compute the
gradients using the chain rule.
In a forward pass, `autograd` does two things simultaneously:
- run the requested operation to compute a resulting tensor
- maintain the operation's *gradient function* in the DAG.
The backward pass kicks off when `.backward()` is called on the DAG
root.
`autograd` then:
- computes the gradients from each `.grad_fn`,
- accumulates them in the respective tensor's `.grad` attribute
- using the chain rule, propagates all the way to the leaf tensors.
::: {.callout-important}
An important thing to note is that the graph is _recreated from scratch_ on
every forward pass; it is not reused across iterations.
* after each `.backward()` call, `autograd` starts populating a new graph. This
is exactly what allows you to use control flow statements in your model
* you can change the shape, size and operations at every iteration if needed.
:::
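The sketch below shows what this dynamic behavior looks like in practice. The `forward` function is a hypothetical model whose Python `if` statement picks a different computation depending on the data, so each call records a different graph and yields a different gradient.

```python
import torch

def forward(x, w):
    # Hypothetical model whose structure depends on the data
    h = x * w
    if h.sum() > 0:                 # ordinary Python branch, evaluated on every call
        return (h ** 2).sum()
    return h.sum()

w = torch.tensor([1.0, -2.0], requires_grad=True)

# h = [3, -2], h.sum() = 1 > 0, so the squared branch is recorded
forward(torch.tensor([3.0, 1.0]), w).backward()
g1 = w.grad.clone()                 # 2 * h * x = [18., -4.]
w.grad.zero_()

# h = [1, -2], h.sum() = -1, so the linear branch is recorded
forward(torch.tensor([1.0, 1.0]), w).backward()
g2 = w.grad.clone()                 # x = [1., 1.]

print(g1, g2)
```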
## Further Reading
- [Autograd Mechanics](https://pytorch.org/docs/stable/notes/autograd.html)