-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathema.py
31 lines (25 loc) · 965 Bytes
/
ema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch import nn
class EMA():
    """Exponential moving average of a model's trainable parameters.

    After each call to :meth:`step`, every tracked entry holds
    ``mu * previous_average + (1 - mu) * current_parameter``.
    """

    def __init__(self, opt, shared):
        """Create an empty tracker.

        Args:
            opt: Options object; only ``opt.mu`` is read. ``mu`` is the weight
                given to the previous average on every update.
            shared: Not used by this class; accepted so the constructor
                signature stays the same for existing callers.
        """
        self.mu = opt.mu
        # Parameter name -> running average tensor (a separate copy, never
        # an alias of the model's own storage).
        self.avg = {}

    def step(self, m: nn.Module) -> None:
        """Fold the current parameters of ``m`` into the running averages.

        Only parameters with ``requires_grad=True`` are tracked. A parameter
        seen for the first time has its average initialized to the parameter's
        current value before the update is applied. The averages are only
        stored in ``self.avg``; nothing is written back into ``m``.

        Args:
            m: Module whose ``named_parameters()`` are averaged.
        """
        mu = self.mu
        with torch.no_grad():
            for name, param in m.named_parameters():
                if not param.requires_grad:
                    continue

                current = param.detach()
                running = self.avg.get(name)
                if running is None:
                    # Initialize with the model's own value; clone() yields a
                    # tensor of the same dtype/device/shape that does not
                    # share storage with the parameter.
                    running = current.clone()
                    self.avg[name] = running

                # In-place update keeps the same tensor object in self.avg
                # (so references handed out earlier stay live) and avoids
                # allocating a temporary per parameter per step.
                running.mul_(mu).add_(current, alpha=1.0 - mu)

    def get_param_dict(self) -> dict:
        """Return a new dict mapping parameter names to their averages.

        The dict is a shallow copy: the tensors are the tracker's own
        buffers, not clones.
        """
        return dict(self.avg)