-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
26 lines (22 loc) · 979 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
'''
Misc Utility functions
'''
from collections import OrderedDict
import os
import numpy as np
def convert_state_dict(state_dict):
    """Strip the ``module.`` prefix that ``nn.DataParallel`` prepends to
    parameter names, so the weights can be loaded into a plain module.

    Models saved via ``nn.DataParallel`` store every parameter under the
    wrapper's ``module`` attribute (``module.layer.weight``). Loading such a
    checkpoint into an unwrapped model fails on the key mismatch; this builds
    a new dict with the prefix removed.

    :param state_dict: mapping of parameter name -> tensor, as loaded from a
        checkpoint (typically an ``OrderedDict``).
    :return: a new ``OrderedDict`` with ``module.`` stripped from each key
        where present; values are shared with the input, not copied.
    """
    prefix = "module."
    state_dict_new = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the prefix is actually there; the original code
        # blindly took k[7:], which mangled keys of models saved WITHOUT
        # DataParallel. This makes the conversion a safe no-op in that case.
        if key.startswith(prefix):
            key = key[len(prefix):]
        state_dict_new[key] = value
    return state_dict_new