forked from qieaaa/Singal-CNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
130 lines (125 loc) · 5.16 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
import os,random
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from keras.utils import np_utils
import keras.models as models
from keras.layers.core import Reshape,Dense,Dropout,Activation,Flatten
from keras.layers.noise import GaussianNoise
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.regularizers import *
from keras.optimizers import adam
import matplotlib.pyplot as plt
import seaborn as sns
import pickle, random, sys, keras
import h5py
#%%
# Load the raw RadioML dictionary and flatten it into one sample array `X` plus a
# parallel label list `lbl`. (Can be run as a Spyder cell to inspect the variables.)
with open("/home/hzy/keras/RML2016.10a_dict.dat", 'rb') as fh:
    # NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
    # encoding='latin1' presumably because the file was pickled under Python 2; confirm.
    Xd = pickle.load(fh, encoding='latin1')
# Keys are (modulation, snr) tuples; collect the sorted distinct values of each field.
snrs = sorted({key[1] for key in Xd})
mods = sorted({key[0] for key in Xd})
X = []
lbl = []  # one (mod, snr) tuple per row of the stacked X
for mod in mods:
    for snr in snrs:
        samples = Xd[(mod, snr)]
        X.append(samples)
        # Tag every row of this chunk with the key it came from.
        lbl.extend([(mod, snr)] * samples.shape[0])
# Stack the per-key arrays along the first axis.
X = np.vstack(X)
#%%
# Split the prepared data into train/test sets for the network. The seed is fixed
# so the split is reproducible.
np.random.seed(2016)
n_examples = X.shape[0]
n_train = n_examples * 0.5  # half of the examples go to training
# Draw int(n_train) distinct row indices for training (sampling without replacement).
train_idx = np.random.choice(np.arange(n_examples), size=int(n_train), replace=False)
# Every row index not drawn for training becomes a test index.
test_idx = list(set(range(n_examples)) - set(train_idx))
X_train, X_test = X[train_idx], X[test_idx]
def to_onehot(yy, num_classes=None):
    """Return a float one-hot matrix for a sequence of integer class labels.

    Args:
        yy: sequence of non-negative integer class indices.
        num_classes: width of the output. When omitted, the width is
            ``max(yy) + 1`` (or 0 for empty input). Pass the true number of
            classes when a subset of the data might not contain the
            highest-numbered class; otherwise the matrix comes out too narrow.

    Returns:
        ``np.ndarray`` of shape ``(len(yy), num_classes)`` with a single 1.0 per
        row, at the column given by the label.

    Raises:
        IndexError: if a label is ``>= num_classes``.
    """
    # Integer dtype also keeps the fancy index valid when yy is empty.
    labels = np.asarray(yy, dtype=int)
    if num_classes is None:
        # max() on an empty sequence would raise, so fall back to zero columns.
        num_classes = int(labels.max()) + 1 if labels.size else 0
    onehot = np.zeros([len(labels), num_classes])
    onehot[np.arange(len(labels)), labels] = 1
    return onehot
# Integer class index of each example (position of its modulation in `mods`),
# then one-hot encoded for the categorical cross-entropy loss.
trainy = [mods.index(lbl[idx][0]) for idx in train_idx]
Y_train = to_onehot(trainy)
Y_test = to_onehot([mods.index(lbl[idx][0]) for idx in test_idx])
#%%
# Per-example input shape (all axes after the batch axis), kept as a list so the
# model can prepend a channel axis with [1] + in_shp.
in_shp = list(X_train.shape)[1:]
print(X_train.shape, in_shp)
classes = mods  # alias: the class names are the sorted modulation names
#%%
# Network definition with the Keras Sequential API.
# Conv layer reference: https://keras-cn.readthedocs.io/en/latest/layers/convolutional_layer/#conv2d
# Sequential model reference: https://keras-cn.readthedocs.io/en/latest/models/sequential/
dr = 0.5  # dropout rate (fraction of units dropped)
model = models.Sequential()
# Prepend a channel axis: in_shp -> [1] + in_shp, which the channels_first layers
# below read as (channels, rows, cols).
model.add(Reshape(([1] + in_shp), input_shape=in_shp))
# Every spatial layer states data_format="channels_first". ZeroPadding2D previously
# relied on the global default (channels_last unless keras.json overrides it). It
# therefore padded the rows axis of this channels-first tensor instead of the
# columns axis that the (0, 2) padding targets.
model.add(ZeroPadding2D((0, 2), data_format="channels_first"))
model.add(Conv2D(256, (1, 3), padding='valid', activation="relu", name="conv1",
                 kernel_initializer='glorot_uniform', data_format="channels_first"))
model.add(Dropout(dr))
model.add(ZeroPadding2D((0, 2), data_format="channels_first"))
model.add(Conv2D(80, (2, 3), padding="valid", activation="relu", name="conv2",
                 kernel_initializer='glorot_uniform', data_format="channels_first"))
model.add(Dropout(dr))
model.add(Flatten())
# Keras 2 spells the initializer argument `kernel_initializer` (Keras 1 used `init`).
model.add(Dense(256, activation='relu', kernel_initializer='he_normal', name="dense1"))
model.add(Dropout(dr))
model.add(Dense(len(classes), kernel_initializer='he_normal', name="dense2"))
model.add(Activation('softmax'))
# Identity reshape: the softmax output already has shape (batch, len(classes)).
model.add(Reshape([len(classes)]))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()
#%%
# Set up some params
epochs = 100  # upper bound on training epochs (EarlyStopping may stop sooner)
batch_size = 1024  # mini-batch size (default 1024)
#%%
# HDF5 file for the best checkpoint. It is a bare file name, so it is resolved
# against the current working directory.
filepath = "convmodrecnets_CNN2_0.5.wts.h5"
# Callbacks (reference: https://keras-cn.readthedocs.io/en/latest/other/callbacks/):
#  - ModelCheckpoint overwrites the file only when val_loss improves.
#  - EarlyStopping halts training after 5 epochs without val_loss improvement;
#    remove it from the list to always run every epoch.
train_callbacks = [
    keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0,
                                    save_best_only=True, mode='auto'),
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=0,
                                  mode='auto'),
]
history = model.fit(X_train, Y_train,
                    batch_size=batch_size, epochs=epochs, verbose=2,
                    validation_data=(X_test, Y_test),
                    callbacks=train_callbacks)
# Reload the best (lowest val_loss) checkpoint written during training.
model.load_weights(filepath)
#%%
# Loss on the test split with the reloaded weights. NOTE(review): the same split was
# used as validation data for checkpointing/early stopping, so this is optimistic.
score = model.evaluate(X_test, Y_test, batch_size=batch_size, verbose=0)
print(score)
#%%
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    """Draw a confusion matrix as a heat map on the current matplotlib figure.

    Args:
        cm: 2-D array; row i is the true class, column j the predicted class.
        title: plot title.
        cmap: matplotlib colormap for the heat map.
        labels: tick labels for both axes, in class order. ``None`` (the default)
            draws no tick labels.
    """
    # None instead of a shared mutable [] default.
    if labels is None:
        labels = []
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out last so the axis labels added above are included in the layout.
    plt.tight_layout()
#%%
# Plot a confusion matrix for each SNR level and record the per-SNR accuracy in `acc`.
acc = {}
# SNR of each test example, aligned with the rows of X_test / Y_test. It does not
# depend on the loop variable, so compute it (and its array form) once.
test_SNRs = list(map(lambda x: lbl[x][1], test_idx))
test_snr_arr = np.array(test_SNRs)
n_cls = len(classes)
for snr in snrs:
    # extract the test examples recorded at this SNR
    at_snr = test_snr_arr == snr
    test_X_i = X_test[at_snr]
    test_Y_i = Y_test[at_snr]
    # estimate classes (use the script's batch size rather than Keras' default)
    test_Y_i_hat = model.predict(test_X_i, batch_size=batch_size)
    # conf[true, predicted] counts; np.add.at accumulates repeated index pairs.
    true_cls = np.argmax(test_Y_i, axis=1)
    pred_cls = np.argmax(test_Y_i_hat, axis=1)
    conf = np.zeros([n_cls, n_cls])
    np.add.at(conf, (true_cls, pred_cls), 1)
    # Row-normalise. A class with no samples at this SNR gets an all-zero row
    # instead of NaNs from 0/0.
    row_totals = conf.sum(axis=1, keepdims=True)
    confnorm = np.divide(conf, row_totals, out=np.zeros_like(conf), where=row_totals > 0)
    plt.figure()
    plot_confusion_matrix(confnorm, labels=classes, title="ConvNet Confusion Matrix (SNR=%d)"%(snr))
    cor = np.sum(np.diag(conf))
    ncor = np.sum(conf) - cor
    print("Overall Accuracy: ", cor / (cor + ncor))
    acc[snr] = 1.0 * cor / (cor + ncor)
#%%
# Plot accuracy curve: classification accuracy at each SNR level, in sorted SNR order.
accuracies = [acc[level] for level in snrs]
plt.plot(snrs, accuracies)
plt.xlabel("Signal to Noise Ratio")
plt.ylabel("Classification Accuracy")
plt.title("CNN2 Classification Accuracy on RadioML 2016.10 Alpha")