-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathlstm-gender-predictor.py
74 lines (53 loc) · 1.91 KB
/
lstm-gender-predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
import wandb
from wandb.keras import WandbCallback
import numpy as np
# Start a Weights & Biases run and record the training hyperparameters on the
# run's config object (attribute assignment, epochs first, then batch size).
wandb.init()
config = wandb.config
config.epochs, config.batch_size = 10, 16
def load_names(male_path="male.txt", female_path="female.txt"):
    """Load male and female first names, dropping names that occur in both lists.

    Each line of an input file is one name. Names are stripped of surrounding
    whitespace (including the trailing newline) and lower-cased *before* the
    overlap check, so "Alex" in one file and "alex\\n" in the other count as
    the same name. Blank lines are skipped. File order, and any repeats within
    a single file, are preserved.

    Args:
        male_path: path to the newline-separated list of male names.
        female_path: path to the newline-separated list of female names.

    Returns:
        A ``(m_names, f_names)`` tuple of lower-cased name lists, with every
        name that appears in both files removed from both lists.
    """
    def _read(path):
        # Normalise before comparing: readlines() keeps "\n" (except possibly
        # on the last line), which would defeat the overlap check and leak a
        # newline character into the model's character vocabulary.
        with open(path, encoding="utf-8") as fh:
            return [line.strip().lower() for line in fh if line.strip()]

    m_names = _read(male_path)
    f_names = _read(female_path)
    # Names listed under both genders are ambiguous labels for the classifier;
    # a set makes each membership test O(1) instead of scanning a list.
    ambiguous = set(m_names) & set(f_names)
    m_names = [name for name in m_names if name not in ambiguous]
    f_names = [name for name in f_names if name not in ambiguous]
    return m_names, f_names
# Build one-hot training tensors from the name lists.
#   X[i, t, c] == 1  <=>  character c sits at position t of name i
#   y[i] is [1, 0] for a male name and [0, 1] for a female name
# Rows hold every male name first, then every female name.
m_names, f_names = load_names()
totalEntries = len(m_names) + len(f_names)
maxlen = 20
# Sorted so the char -> index mapping is identical on every run; iterating a
# plain set depends on string-hash randomisation and would reshuffle the
# input encoding between runs, making a saved model unusable.
chars = sorted(set("".join(m_names) + "".join(f_names)))
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}

X = np.zeros((totalEntries, maxlen, len(chars)), dtype=np.float32)
y = np.zeros((totalEntries, 2), dtype=np.float32)

# Column 0 of y marks male, column 1 marks female.
labelled = [(name, 0) for name in m_names] + [(name, 1) for name in f_names]
for i, (name, label) in enumerate(labelled):
    # Truncate so names longer than maxlen cannot index past the time axis.
    for t, char in enumerate(name[:maxlen]):
        X[i, t, char_indices[char]] = 1
    y[i, label] = 1
def vec2c(vec):
    """Decode a one-hot character vector back into its character.

    Scans ``vec`` left to right and maps the position of the first truthy
    entry through the module-level ``indices_char`` table. Returns ``""`` when
    no entry is truthy, e.g. an all-zero padding step of ``X``.
    """
    first_hot = next((pos for pos, flag in enumerate(vec) if flag), None)
    if first_hot is None:
        return ""
    return indices_char[first_hot]
# Two stacked 512-unit LSTM layers (with dropout) read the one-hot character
# sequence; a 2-unit softmax outputs [male, female] probabilities.
model = Sequential()
model.add(LSTM(512, return_sequences=True, input_shape=(maxlen, len(chars))))
model.add(Dropout(0.2))
model.add(LSTM(512, return_sequences=False))
model.add(Dropout(0.2))
model.add(Dense(2, activation='softmax'))
# One-hot targets with a softmax output pair with categorical cross-entropy.
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# X/y are filled with all male rows before all female rows, and Keras takes
# the validation_split samples from the *end* of the arrays before any
# shuffling. Without this permutation the validation set would be drawn only
# from female names.
order = np.random.permutation(len(X))
model.fit(
    X[order],
    y[order],
    # Use the hyperparameters recorded on the wandb config; previously they
    # were logged but ignored (Keras defaults were used instead).
    epochs=config.epochs,
    batch_size=config.batch_size,
    validation_split=0.2,
    callbacks=[WandbCallback()],
)