-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnodevectors_node2vec.py
182 lines (173 loc) · 7.01 KB
/
nodevectors_node2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import numba
import numpy as np
import pandas as pd
import time
import warnings
# Gensim triggers automatic useless warnings for windows users...
warnings.simplefilter("ignore", category=UserWarning)
import gensim
warnings.simplefilter("default", category=UserWarning)
import csrgraph as cg
from nodevectors.embedders import BaseNodeEmbedder
class Node2Vec(BaseNodeEmbedder):
    """
    Node2Vec embedder.

    Generates random walks over a csrgraph graph and trains a gensim
    Word2Vec model on them, treating every walk as a sentence and every
    node name (converted to ``str``) as a word.
    """

    # Word2Vec settings used when the caller does not pass ``w2vparams``.
    # "iter" is the gensim < 4 spelling; it is translated to the installed
    # gensim's spelling when the model is trained.
    _DEFAULT_W2VPARAMS = {"window": 10, "negative": 5, "iter": 10,
                          "batch_words": 128}

    def __init__(
        self,
        n_components=32,
        walklen=30,
        epochs=20,
        return_weight=1.,
        neighbor_weight=1.,
        threads=0,
        keep_walks=False,
        verbose=True,
        w2vparams=None):
        """
        Parameters
        ----------
        walklen : int
            length of the random walks
        epochs : int
            number of times to start a walk from each nodes
        threads : int
            number of threads to use. 0 is full use
        n_components : int
            number of resulting dimensions for the embedding
            This should be set here rather than in the w2vparams arguments
        return_weight : float in (0, inf]
            Weight on the probability of returning to node coming from
            Having this higher tends the walks to be
            more like a Breadth-First Search.
            Having this very high  (> 2) makes search very local.
            Equal to the inverse of p in the Node2Vec paper.
        neighbor_weight : float in (0, inf]
            Weight on the probability of visiting a neighbor node
            to the one we're coming from in the random walk
            Having this higher tends the walks to be
            more like a Depth-First Search.
            Having this very high makes search more outward.
            Having this very low makes search very local.
            Equal to the inverse of q in the Node2Vec paper.
        keep_walks : bool
            Whether to save the random walks in the model object after training
        verbose : bool
            Whether to print progress and timing information during fit
        w2vparams : dict or None
            dictionary of parameters to pass to gensim's word2vec
            Don't set the embedding dimensions through arguments here.
            None (the default) uses window=10, negative=5, iter=10,
            batch_words=128. The dict is stored as given and never
            modified; the worker count and embedding size are added
            when the model is trained.

        Raises
        ------
        ValueError
            If threads is not an int, or walklen/epochs is smaller than 1
        AttributeError
            If w2vparams tries to set the embedding dimensions
        """
        if type(threads) is not int:
            raise ValueError("Threads argument must be an int!")
        if walklen < 1 or epochs < 1:
            raise ValueError("Walklen and epochs arguments must be >= 1")
        if w2vparams is None:
            # Fresh copy per instance: a dict literal in the signature
            # would be one object shared by every default-built model.
            w2vparams = dict(self._DEFAULT_W2VPARAMS)
        # 'size' is the gensim < 4 name and 'vector_size' the gensim >= 4
        # name of the same setting; both would clash with n_components.
        if 'size' in w2vparams or 'vector_size' in w2vparams:
            raise AttributeError("Embedding dimensions should not be set "
                + "through w2v parameters, but through n_components")
        self.n_components = n_components
        self.walklen = walklen
        self.epochs = epochs
        self.keep_walks = keep_walks
        self.w2vparams = w2vparams
        self.return_weight = return_weight
        self.neighbor_weight = neighbor_weight
        if threads == 0:
            threads = numba.config.NUMBA_DEFAULT_NUM_THREADS
        self.threads = threads
        self.verbose = verbose

    def _word2vec_kwargs(self):
        """
        Build the keyword arguments for gensim's Word2Vec constructor.

        Works on a copy of ``self.w2vparams`` so neither the caller's dict
        nor the stored parameters are modified. Sets ``workers`` from
        ``self.threads`` and the embedding size from ``self.n_components``,
        using the keyword spelling of the installed gensim major version
        (gensim 4 renamed ``size`` -> ``vector_size`` and
        ``iter`` -> ``epochs``).
        """
        kwargs = dict(self.w2vparams)
        kwargs['workers'] = self.threads
        if int(gensim.__version__.split(".")[0]) >= 4:
            size_key, foreign_key, native_key = "vector_size", "iter", "epochs"
        else:
            size_key, foreign_key, native_key = "size", "epochs", "iter"
        if foreign_key in kwargs:
            # Translate the other version's spelling; an explicitly given
            # native spelling takes precedence.
            value = kwargs.pop(foreign_key)
            kwargs.setdefault(native_key, value)
        kwargs[size_key] = self.n_components
        return kwargs

    def _keyed_vectors(self):
        """
        Return the object holding the trained vectors.

        After ``fit`` the model is a Word2Vec whose vectors live in ``.wv``;
        after ``load_vectors`` the model is the loaded vectors object itself.
        """
        return getattr(self.model, "wv", self.model)

    def fit(self, G):
        """
        Train the embedding on a graph.

        NOTE: Currently only support str or int as node name for graph

        Parameters
        ----------
        G : graph data
            Graph to embed
            Can be any graph type that's supported by csrgraph library
            (NetworkX, numpy 2d array, scipy CSR matrix, CSR matrix components)

        Raises
        ------
        ValueError
            If the first node name is neither an int nor a str
        """
        if not isinstance(G, cg.csrgraph):
            G = cg.csrgraph(G, threads=self.threads)
        if G.threads != self.threads:
            G.set_threads(self.threads)
        node_names = G.names
        # Only the first name is inspected; names are assumed homogeneous.
        if not isinstance(node_names[0], (int, str, np.integer)):
            raise ValueError("Graph node names must be int or str!")
        # Random walks over the graph
        walks_t = time.time()
        if self.verbose:
            print("Making walks...", end=" ", flush=True)
        walks = G.random_walks(walklen=self.walklen,
                               epochs=self.epochs,
                               return_weight=self.return_weight,
                               neighbor_weight=self.neighbor_weight)
        if self.verbose:
            print(f"Done, T={time.time() - walks_t:.2f}")
            print("Mapping Walk Names...", end=" ", flush=True)
        map_t = time.time()
        walks = pd.DataFrame(walks)
        # Map nodeId -> node name
        node_dict = dict(enumerate(node_names))
        for col in walks.columns:
            walks[col] = walks[col].map(node_dict).astype(str)
        # Somehow gensim only trains on this list iterator
        # it silently mistrains on array input
        self.walks = [list(x) for x in walks.itertuples(False, None)]
        if self.verbose:
            print(f"Done, T={time.time() - map_t:.2f}")
            print("Training W2V...", end=" ", flush=True)
            if gensim.models.word2vec.FAST_VERSION < 1:
                print("WARNING: gensim word2vec version is unoptimized. "
                    "Try version 3.6 if on windows, versions 3.7 "
                    "and 3.8 have had issues")
        w2v_t = time.time()
        # Train gensim word2vec model on random walks
        self.model = gensim.models.Word2Vec(
            sentences=self.walks,
            **self._word2vec_kwargs())
        if not self.keep_walks:
            del self.walks
        if self.verbose:
            print(f"Done, T={time.time() - w2v_t:.2f}")

    def fit_transform(self, G):
        """
        Train the embedding on a graph and return the node vectors.

        NOTE: Currently only support str or int as node name for graph

        Parameters
        ----------
        G : graph data
            Graph to embed
            Can be any graph type that's supported by csrgraph library
            (NetworkX, numpy 2d array, scipy CSR matrix, CSR matrix components)

        Returns
        -------
        2d numpy array with one row per entry of the graph's ``names``,
        in that order.
        """
        if not isinstance(G, cg.csrgraph):
            G = cg.csrgraph(G, threads=self.threads)
        self.fit(G)
        # The model vocabulary is keyed by node *name* (see fit), so look
        # vectors up by name rather than by positional node id.
        return np.array([self.predict(name) for name in G.names])

    def predict(self, node_name):
        """
        Return vector associated with node

        node_name : str or int
            either the node ID or node name depending on graph format
        """
        # current hack to work around word2vec problem
        # ints need to be str -_-
        if not isinstance(node_name, str):
            node_name = str(node_name)
        return self._keyed_vectors()[node_name]

    def save_vectors(self, out_file):
        """
        Save as embeddings in gensim.models.KeyedVectors format
        """
        self._keyed_vectors().save_word2vec_format(out_file)

    def load_vectors(self, out_file):
        """
        Load embeddings from gensim.models.KeyedVectors format

        The loaded vectors replace ``self.model``.
        """
        self.model = gensim.models.KeyedVectors.load_word2vec_format(out_file)