forked from jhlau/topic_interpretability
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathComputeObservedCoherence.py
146 lines (125 loc) · 4.55 KB
/
ComputeObservedCoherence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Author: Jey Han Lau
Date: May 2013
"""
import argparse
import sys
import operator
import math
#command-line interface: three required positional arguments
desc = "Computes the observed coherence for a given topic and word-count file."
parser = argparse.ArgumentParser(description=desc)
parser.add_argument(
    "topic_file",
    help="file that contains the topics")
parser.add_argument(
    "metric",
    choices=["pmi","npmi","lcp"],
    help="type of evaluation metric")
parser.add_argument(
    "wordcount_file",
    help="file that contains the word counts")
args = parser.parse_args()

#settings
colloc_sep = "_" #character that joins the parts of a collocation in the topic file
topN = 10 #only the first topN words of each topic line are scored
debug = True #when True, print topic words next to each score plus the average

#reserved key under which the word count file stores the total number of windows
WTOTALKEY = "!!<TOTAL_WINDOWS>!!"

#input handles (read by the main section further down)
topic_file = open(args.topic_file)
wc_file = open(args.wordcount_file)

#state filled in by the main section
window_total = 0 #total number of windows
wordcount = dict() #counts for single words and for "word1|word2" pairs
wordpos = dict() #a dictionary of pos distribution
###########
#functions#
###########
#compute the association between two words
def calc_assoc(word1, word2):
    """Return the association score between two words.

    The score is computed from the module-level ``wordcount`` table and
    ``window_total`` according to ``args.metric``; all logarithms are
    base 10:

    * "pmi":  log(count(w1,w2) * window_total / (count(w1) * count(w2)))
    * "npmi": the pmi value divided by -log(count(w1,w2) / window_total)
    * "lcp":  log(count(w1,w2) / count(w1)); when the pair was never
      counted, log(count(w2) / window_total) is used instead, or
      log(1 / window_total) if word2 has no count either.

    Args:
        word1: first word (the conditioning word for "lcp").
        word2: second word.

    Returns:
        The score as a float.  For "pmi"/"npmi" the score is 0.0 when
        either word or the pair has no count.

    Raises:
        ValueError: if ``args.metric`` is not one of the three metrics
            (also raised by ``math.log`` when ``window_total`` is 0 and a
            pmi/npmi score has to be computed).
        ZeroDivisionError: for "lcp" when ``window_total`` is 0 in the
            back-off case, or when the pair has a count but word1 has none.
    """
    #the pair may be stored under either word order
    combined_count = wordcount.get(word1 + "|" + word2)
    if combined_count is None:
        combined_count = wordcount.get(word2 + "|" + word1, 0)
    w1_count = wordcount.get(word1, 0)
    w2_count = wordcount.get(word2, 0)

    metric = args.metric
    if metric == "pmi" or metric == "npmi":
        if w1_count == 0 or w2_count == 0 or combined_count == 0:
            return 0.0
        result = math.log((float(combined_count)*float(window_total))/ \
            float(w1_count*w2_count), 10)
        if metric == "npmi":
            norm = -1.0*math.log(float(combined_count)/(window_total), 10)
            if norm == 0.0:
                #the pair occurs in every window, so the normaliser is
                #log(1) = 0; return the limiting NPMI value for complete
                #co-occurrence instead of dividing by zero
                return 1.0
            result = result / norm
        return result

    if metric == "lcp":
        if combined_count == 0:
            #unseen pair: back off to word2's own count, or to a count of 1
            if w2_count != 0:
                return math.log(float(w2_count)/window_total, 10)
            return math.log(float(1.0)/window_total, 10)
        return math.log((float(combined_count))/(float(w1_count)), 10)

    raise ValueError("unsupported metric: %r" % (metric,))
#compute topic coherence given a list of topic words
def calc_topic_coherence(topic_words):
    """Return the mean pairwise association of a topic's words.

    Every unordered pair of distinct entries in ``topic_words`` is scored
    with ``calc_assoc`` and the scores are averaged.  Pairs whose two
    entries are the same string are skipped.

    Args:
        topic_words: list of topic words; the parts of a collocation are
            joined by ``colloc_sep`` (e.g. "new_york").

    Returns:
        The average score as a float, or 0.0 when there is no pair to
        score (fewer than two words, or all words identical).
    """
    topic_assoc = []
    for w1_id, target_word in enumerate(topic_words):
        #collocations are looked up with spaces instead of the separator
        w1 = target_word.replace(colloc_sep, " ")
        for topic_word in topic_words[w1_id+1:]:
            if target_word != topic_word:
                w2 = topic_word.replace(colloc_sep, " ")
                topic_assoc.append(calc_assoc(w1, w2))
    if not topic_assoc:
        #nothing was scored; avoid dividing by zero
        return 0.0
    return float(sum(topic_assoc))/len(topic_assoc)
######
#main#
######
#process the word count file(s)
#each line is either "word|count" or "word1|word2|count"
with wc_file: #close the handle once the counts have been loaded
    for line in wc_file:
        line = line.strip()
        data = line.split("|")
        if len(data) == 2:
            wordcount[data[0]] = int(data[1])
        elif len(data) == 3:
            #store each pair once, with its two words in sorted order
            first, second = sorted(data[:2])
            wordcount[first + "|" + second] = int(data[2])
        else:
            #sys.exit with a message writes it to stderr and exits with status 1
            sys.exit("ERROR: wordcount format incorrect. Line = " + line)

#get the total number of windows
if WTOTALKEY in wordcount:
    window_total = wordcount[WTOTALKEY]
if window_total <= 0:
    #the metrics divide by / take the log of this value
    sys.stderr.write("WARNING: no positive total window count (key " + WTOTALKEY + \
        ") found in the word count file\n")

#read the topic file and compute the observed coherence
topic_coherence = {} # {topicid: tc}
topic_tw = {} #{topicid: topN_topicwords}
all_topic_words = set([])
topic_id = 0
with topic_file: #close the handle once all topics have been read
    for line in topic_file:
        topic_list = line.split()[:topN]
        if not topic_list:
            #blank line: there is no topic to score
            continue
        topic_tw[topic_id] = " ".join(topic_list)
        topic_coherence[topic_id] = calc_topic_coherence(topic_list)
        all_topic_words.update(topic_list)
        topic_id += 1

#print the topic coherence scores in order of topic id
for tid, score in sorted(topic_coherence.items()):
    if debug:
        print("[%.2f] %s" % (score, topic_tw[tid]))
    else:
        print(score)

#print the overall topic coherence for all topics
if debug:
    print("==========================================================================")
    if topic_coherence:
        average = sum(topic_coherence.values())/len(topic_coherence)
        print("Average Topic Coherence = %s" % (average,))
    else:
        #no topic was read, so there is nothing to average
        sys.stderr.write("WARNING: no topics found in the topic file\n")