-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathHyperParameterSpace.py
54 lines (47 loc) · 1.22 KB
/
HyperParameterSpace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class HyperParameterSpace:
    """Grid of hyper-parameter candidate values.

    ``hp`` maps each hyper-parameter name to an indexable sequence of
    candidate values. Parameters are always handled in sorted-name order,
    so an "index vector" ``idx`` holds one position per parameter in that
    order: ``idx[k]`` selects ``hp[params[k]][idx[k]]``.
    """

    def __init__(self, hp):
        """Store the hyper-parameter definition.

        Args:
            hp: mapping of parameter name -> sequence of candidate values.
        """
        self.hp = hp
        # Sorted so index vectors and generated names are stable across runs.
        self.params = sorted(self.hp)

    def iterateAllCombinations(self):
        """Yield every combination of hyper-parameter values (full grid).

        Combinations come in lexicographic index order: the index of the
        last (sorted) parameter changes fastest.

        Yields:
            ``(name, pairs)`` where ``name`` is the index vector joined with
            underscores (as produced by ``idx2Str``) and ``pairs`` is a list
            of ``(param, value)`` tuples in sorted parameter order.

        If any parameter has no candidate values the grid is empty and
        nothing is yielded. With no parameters at all, a single empty
        combination ``('', [])`` is yielded.
        """
        import itertools

        maxIdx = [len(self.hp[param]) for param in self.params]
        # product() varies the last position fastest (same odometer order as
        # addOneIdx) and yields nothing when any range is empty -- the old
        # hand-rolled loop raised IndexError in that case.
        for idx_tuple in itertools.product(*(range(n) for n in maxIdx)):
            idx = list(idx_tuple)
            yield self.idx2Str(idx), list(zip(self.params, self.idx2Val(idx)))

    def isZeroIdx(self, idx):
        """Return True if no entry of ``idx`` is positive (i.e. all zeros)."""
        return all(i <= 0 for i in idx)

    def addOneIdx(self, idx, maxIdx):
        """Increment the index vector ``idx`` by one, in place (odometer style).

        The last position is incremented. When a position reaches its bound
        in ``maxIdx`` it wraps to 0 and carries into the position to its
        left. Incrementing the final combination wraps ``idx`` back to all
        zeros.

        Args:
            idx: mutable list of current indices; modified in place.
            maxIdx: exclusive upper bound for each position of ``idx``.

        Raises:
            ValueError: if ``idx`` and ``maxIdx`` differ in length.
        """
        if len(idx) != len(maxIdx):
            # Previously an ``assert``, which is stripped under ``python -O``.
            raise ValueError(
                "idx and maxIdx must have the same length (%d != %d)"
                % (len(idx), len(maxIdx))
            )
        for i in reversed(range(len(idx))):
            idx[i] += 1
            if idx[i] != maxIdx[i]:
                return
            idx[i] = 0  # wrap and carry into the next position to the left

    def idx2Str(self, idx):
        """Convert an index vector to a name such as ``'1_0_2'``."""
        return '_'.join(str(num) for num in idx)

    def idx2Val(self, idx):
        """Return the candidate values selected by ``idx``, in sorted parameter order."""
        return [self.hp[param][i] for param, i in zip(self.params, idx)]

    def getParamsName(self):
        """Return the parameter names in sorted order."""
        return self.params

    def getParamsType(self):
        """Return the type of each parameter's first candidate value, in sorted order.

        Raises ``IndexError`` if a parameter has no candidate values.
        """
        return [type(self.hp[param][0]) for param in self.params]