-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckpoints.lua
74 lines (59 loc) · 2.09 KB
/
checkpoints.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
--
-- Adapted from https://github.com/facebook/fb.resnet.torch/blob/master/checkpoints.lua
--
-- Checkpoint loading and saving
--
local M = {}

--- Strip gradient buffers from every module of `model`, in place, so that a
-- serialized checkpoint does not carry them.
-- A module loses `gradWeight` only when it also has a `weight`, and loses
-- `gradBias` only when it also has a `bias`.
-- Technique from https://github.com/karpathy/neuraltalk2/blob/master/misc/net_utils.lua
-- @param model network exposing `listModules()` (e.g. an nn container)
local function sanitizeGradients(model)
   -- each entry: { parameter field, matching gradient field }
   local fieldPairs = { { 'weight', 'gradWeight' }, { 'bias', 'gradBias' } }
   for _, layer in ipairs(model:listModules()) do
      for _, fields in ipairs(fieldPairs) do
         local param, grad = fields[1], fields[2]
         if layer[param] and layer[grad] then
            layer[grad] = nil
         end
      end
   end
end
--- Recreate gradient buffers that sanitizeGradients removed, in place.
-- For every module with a `weight` but no `gradWeight`, a zero-filled clone
-- of the weight becomes its `gradWeight`; the same is done for `bias` and
-- `gradBias`. Gradient buffers that already exist are left untouched.
-- @param model network exposing `listModules()` (e.g. an nn container)
local function unsanitizeGradients(model)
   -- a fresh tensor shaped like `param`, filled with zeros
   local function zerosLike(param)
      return param:clone():zero()
   end

   for _, layer in ipairs(model:listModules()) do
      local weight, bias = layer.weight, layer.bias
      if weight and not layer.gradWeight then
         layer.gradWeight = zerosLike(weight)
      end
      if bias and not layer.gradBias then
         layer.gradBias = zerosLike(bias)
      end
   end
end
--- Load the most recent checkpoint written by M.save, if there is one.
-- Reads `<opt.resume>/latest.t7`, then the model and optimizer-state files it
-- names. The model is converted with `:cuda()` and its gradient buffers are
-- recreated, because M.save strips them before writing.
-- @param opt options table; `opt.resume` is the checkpoint directory, or
--   'none' (or nil) to disable resuming
-- @return checkpoint table, model, optimState; or nil when resuming is
--   disabled or no latest.t7 exists in the directory
function M.latest(opt)
   local resumeDir = opt.resume
   if resumeDir == nil or resumeDir == 'none' then
      return nil
   end
   local latestPath = paths.concat(resumeDir, 'latest.t7')
   if not paths.filep(latestPath) then
      return nil
   end
   print('=> Loading checkpoint ' .. latestPath)
   local latest = torch.load(latestPath)
   -- Report a truncated or foreign latest.t7 clearly instead of failing
   -- inside paths.concat with a nil argument.
   assert(latest.modelFile and latest.optimFile,
      'Checkpoint is missing modelFile/optimFile entries: ' .. latestPath)
   local modelPath = paths.concat(resumeDir, latest.modelFile)
   assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
   -- Check the optimizer state up front too, so a partial checkpoint is
   -- reported before the model is loaded and moved to the GPU.
   local optimPath = paths.concat(resumeDir, latest.optimFile)
   assert(paths.filep(optimPath), 'Saved optim state not found: ' .. optimPath)
   print('=> Resuming model from ' .. modelPath)
   local model = torch.load(modelPath):cuda()
   unsanitizeGradients(model)
   local optimState = torch.load(optimPath)
   return latest, model, optimState
end
--- Write a checkpoint into the `opt.save` directory.
-- Produces `model_<epoch>.t7`, `optimState_<epoch>.t7` and `latest.t7`.
-- `latest.t7` holds the `checkpoint` table, extended with the two file names
-- so that M.latest can locate them. When `checkpoint.isBestModel` is set, the
-- model is also written to `model_best.t7`.
-- NOTE(review): gradient buffers are stripped from `lightModel` in place and
-- are not restored here; presumably callers pass a lightweight copy rather
-- than the network being trained — confirm at the call site.
-- @param lightModel network to serialize (must support `listModules()`)
-- @param optimState optimizer state to serialize
-- @param checkpoint table with at least `epoch`; mutated to record file names
-- @param opt options table; `opt.save` is the output directory
function M.save(lightModel, optimState, checkpoint, opt)
   sanitizeGradients(lightModel)

   local saveDir = opt.save
   local epoch = checkpoint.epoch
   local modelFile = 'model_' .. epoch .. '.t7'
   local optimFile = 'optimState_' .. epoch .. '.t7'

   torch.save(paths.concat(saveDir, modelFile), lightModel)
   torch.save(paths.concat(saveDir, optimFile), optimState)

   -- latest.t7 is written after the files it refers to
   checkpoint.modelFile = modelFile
   checkpoint.optimFile = optimFile
   torch.save(paths.concat(saveDir, 'latest.t7'), checkpoint)

   if checkpoint.isBestModel then
      torch.save(paths.concat(saveDir, 'model_best.t7'), lightModel)
   end
end

return M