-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcdl_fda_fit.py
executable file
·60 lines (43 loc) · 1.29 KB
/
cdl_fda_fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python
"""CDL encoding FDA fitter
CREATED:2013-05-08 16:15:55 by Brian McFee <brm2132@columbia.edu>
Usage:
./cdl_fda_fit.py output_fda_model.pickle /path/to/octarines/glob label_type
"""
import sys
import glob
import cPickle as pickle
import numpy as np
import FDA
def vectorize(A):
    '''Flatten every axis after the first into a single axis.

    The result is a 2-D array with `A.shape[0]` rows; singleton axes are
    squeezed out before the reshape.
    '''
    n_rows = A.shape[0]
    squeezed = A.squeeze()
    return squeezed.reshape((n_rows, -1))
def load_labels(infile, label_type):
    '''Load the label array that accompanies an encoded data file.

    The path of the label file is derived from `infile` by cutting it at the
    first occurrence of '-encoded.npy' and appending '-<label_type>.npy'.
    A ValueError is raised (by str.index) when `infile` does not contain
    that marker.
    '''
    cut = infile.index('-encoded.npy')
    stem = infile[:cut]
    label_file = '%s-%s.npy' % (stem, label_type)
    return np.load(label_file)
def learn_fda(inpath, label_type):
    '''Fit an FDA model on every data file matching a glob pattern.

    Each matched file is loaded with np.load and flattened by `vectorize`;
    its labels come from `load_labels`.  Files are processed in sorted
    filename order, the per-file arrays are stacked row-wise (labels are
    concatenated with np.hstack), and the result is passed to
    `FDA.FDA().fit`.

    Arguments:
        inpath      -- glob pattern selecting the encoded .npy files
        label_type  -- label name used to locate each file's label array

    Returns:
        the fitted FDA.FDA object

    Raises:
        ValueError if no file matches `inpath`
    '''
    files = sorted(glob.glob(inpath))
    if not files:
        # Without this guard the function would fail later with an
        # UnboundLocalError, because no labels were ever assigned.
        raise ValueError('No input files match %r' % (inpath,))

    print('Loading data...')
    features = []
    labels = []
    for f in files:
        features.append(vectorize(np.load(f)))
        labels.append(load_labels(f, label_type))

    # Stack once at the end: re-stacking inside the loop copies the whole
    # accumulated array for every additional file.
    A = np.vstack(features)
    Y = np.hstack(labels)

    print('Building FDA model...')
    transformer = FDA.FDA()
    transformer.fit(A, Y)
    print('done.')
    return transformer
if __name__ == '__main__':
    # Command-line entry point.
    # argv: [script, output pickle path, input glob pattern, label type]
    if len(sys.argv) < 4:
        # Show the module usage text instead of an IndexError traceback.
        sys.exit(__doc__)

    outpath = sys.argv[1]
    inpath = sys.argv[2]
    label_type = sys.argv[3]

    transform = learn_fda(inpath, label_type)

    # protocol=-1 selects the highest pickle protocol, which is a binary
    # format, so the output file has to be opened in binary mode.
    with open(outpath, 'wb') as f:
        pickle.dump(transform, f, protocol=-1)