-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlda_image.py
143 lines (97 loc) · 3.78 KB
/
lda_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
import lda
from step1_image import multiModal_features
import os
import cPickle
import cv2
def euclidean(x, y):
    """Return the Euclidean (L2) distance between *x* and *y*.

    The two inputs are subtracted element-wise (so NumPy broadcasting applies)
    and the squared differences are summed over every element before the
    square root is taken.
    """
    diff = x - y
    squared = np.square(diff)
    return np.sqrt(squared.sum())
class lda_image(object):
    """LDA topic-model image retrieval over visual bag-of-words histograms.

    The fitted ``lda.LDA`` model is cached in ``cache/lda.pkl``. Image paths
    used to show/return results are read from ``cache/paths.pkl``; they are
    presumably aligned row-for-row with ``doc_topics`` -- confirm against the
    script that writes that cache.
    """

    def __init__(self, data_path):
        """Load the cached LDA model, or fit and cache a new one.

        Args:
            data_path: dataset directory handed to ``multiModal_features``
                when the visual-word histogram is actually needed (a cache
                miss in ``get_lda`` or a call to ``query``).
        """
        self.data_path = data_path
        self.num_topics = 100
        # Built lazily by the ``histogram`` property, because it is only
        # needed to fit the model or to re-infer a query's topics.
        self._histogram = None
        self.lda_model = self.get_lda()
        # Topic mixture of every training document, one row per document.
        self.doc_topics = self.lda_model.doc_topic_

    @property
    def histogram(self):
        """Visual-word count matrix (one row per image), computed on first use."""
        if getattr(self, '_histogram', None) is None:
            features = multiModal_features(self.data_path)
            # Cast to integer counts before handing the matrix to lda.LDA.
            self._histogram = features.vbow.astype(np.int64)
        return self._histogram

    @staticmethod
    def _ensure_dir(path):
        """Create directory *path* (and parents) if it does not exist yet."""
        if path and not os.path.isdir(path):
            os.makedirs(path)

    def _load_paths(self, verbose=False):
        """Load the cached list of image paths from ``cache/paths.pkl``.

        Raises:
            IOError: if the cache file does not exist.
        """
        ipath_cache_file = os.path.join('cache', 'paths.pkl')
        if not os.path.isfile(ipath_cache_file):
            raise IOError('Image path cache not found: ' + ipath_cache_file)
        if verbose:
            print('Loading image paths from : ' + ipath_cache_file)
        # NOTE: unpickling can execute arbitrary code; only load cache files
        # this project produced itself.
        with open(ipath_cache_file, 'rb') as f:
            ipath = cPickle.load(f)
        if verbose:
            print('Done!')
        return ipath

    def _distances(self, query_topics):
        """Euclidean distance from *query_topics* to every row of ``doc_topics``.

        *query_topics* may be shaped ``(num_topics,)`` or ``(1, num_topics)``;
        it is reshaped to a single row and broadcast against all documents.
        """
        diff = np.asarray(self.doc_topics) - np.reshape(query_topics, (1, -1))
        return np.sqrt(np.sum(np.square(diff), axis=1))

    def get_lda(self):
        """Return the LDA model, loading it from ``cache/lda.pkl`` when present.

        On a cache miss a model with ``self.num_topics`` topics is fitted on
        ``self.histogram`` (``n_iter=500``, ``random_state=1``) and pickled to
        that file so later runs can reuse it.
        """
        lda_cache_file = os.path.join('cache', 'lda.pkl')
        if os.path.isfile(lda_cache_file):
            print('Loading lda object from : ' + lda_cache_file)
            # NOTE: unpickling can execute arbitrary code; only load cache
            # files this project produced itself.
            with open(lda_cache_file, 'rb') as f:
                model = cPickle.load(f)
            print('Done!')
            return model

        model = lda.LDA(self.num_topics, n_iter=500, random_state=1)
        print('fitting lda...')
        model.fit(self.histogram)
        print('Done!')
        # The cache directory may not exist on a fresh checkout.
        self._ensure_dir(os.path.dirname(lda_cache_file))
        print('Saving lda object to: ' + lda_cache_file)
        with open(lda_cache_file, 'wb') as f:
            cPickle.dump(model, f)
        print('Done!')
        return model

    def query(self, idx, k=4, out_dir='query_demo'):
        """Find the *k* images whose topic mixtures are closest to image *idx*.

        The query image is written to ``<out_dir>/query.jpg`` and the matches
        to ``<out_dir>/0.jpg`` .. ``<out_dir>/<k-1>.jpg``, nearest first. Each
        ``(distance, index)`` pair is printed.

        Returns:
            List of ``(distance, index)`` tuples sorted by distance, with ties
            broken by the lower index.

        Raises:
            IOError: if ``cache/paths.pkl`` does not exist.
        """
        ipath = self._load_paths(verbose=True)
        self._ensure_dir(out_dir)
        img = cv2.imread(ipath[idx])
        cv2.imwrite(os.path.join(out_dir, 'query.jpg'), img)
        # Re-infer the query's topic mixture from its raw word counts.
        query_topics = self.lda_model.transform(
            np.expand_dims(self.histogram[idx], axis=0), 500)
        dists = self._distances(query_topics)
        # A stable sort keeps the lower index first on equal distances,
        # i.e. the k lexicographically smallest (distance, index) pairs.
        order = np.argsort(dists, kind='mergesort')[:k]
        result = [(dists[i], int(i)) for i in order]
        for rank, (dis, i) in enumerate(result):
            img = cv2.imread(ipath[i])
            cv2.imwrite(os.path.join(out_dir, str(rank) + '.jpg'), img)
            print((dis, i))
        return result

    def query2(self, idx=None):
        """Rank every image by topic-space distance to one query image.

        Args:
            idx: row of ``doc_topics`` used as the query; a random row is
                chosen when omitted.

        Returns:
            ``(ranking, query_name)``. ``ranking`` is a list of
            ``(distance, name)`` tuples sorted ascending; the query itself is
            included at distance 0. Names are image paths with the file
            extension removed.

        Raises:
            IOError: if ``cache/paths.pkl`` does not exist.
        """
        ipath = self._load_paths()
        if idx is None:
            idx = np.random.randint(0, len(self.doc_topics))
        dists = self._distances(self.doc_topics[idx])
        result = sorted(
            (dis, os.path.splitext(ipath[i])[0]) for i, dis in enumerate(dists)
        )
        return result, os.path.splitext(ipath[idx])[0]
if __name__ == '__main__':
    # Demo entry point: load (or fit) the LDA model for the ImageCLEFmed 2009
    # training split, then print the full distance ranking for one randomly
    # chosen image.
    data_path = os.path.join('dataset', 'ImageCLEFmed2009_train.02')
    model = lda_image(data_path)
    ranking_and_query = model.query2()
    print(ranking_and_query)
    # For a nearest-neighbour demo that writes images to query_demo/, call
    # model.query(idx) with an image index instead.