-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathautoencoder_cnn.py
42 lines (32 loc) · 1.25 KB
/
autoencoder_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, UpSampling2D, MaxPooling2D
from keras.models import Model, Sequential
from keras.datasets import mnist
from keras.callbacks import Callback
import numpy as np
from util import Images
import wandb
from wandb.keras import WandbCallback
# Start a Weights & Biases run and record the epoch count on its config.
run = wandb.init()
config = run.config
config.epochs = 30

# Only the image arrays are kept; the label arrays returned by MNIST are
# discarded because the network is trained to reproduce its own input.
(X_train, _), (X_test, _) = mnist.load_data()

# Cast both splits to float32 and divide by 255
# (presumably rescaling 8-bit pixel values into [0, 1] -- confirm).
X_train, X_test = (
    pixels.astype('float32') / 255. for pixels in (X_train, X_test)
)
# Convolutional autoencoder over 28x28 inputs, declared as one layer list.
model = Sequential([
    # Append a channel axis so the Conv2D layers can consume the images.
    Reshape((28, 28, 1), input_shape=(28, 28)),

    # Encoder: conv + max-pool stages, ending in a single-channel code.
    Conv2D(8, (3, 3), activation='relu', padding='same'),
    MaxPooling2D(2, 2),
    Conv2D(4, (3, 3), activation='relu', padding='same'),
    MaxPooling2D(2, 2),
    Conv2D(1, (3, 3), activation='relu', padding='same'),

    # Decoder: upsample + conv stages, ending in a single-channel image.
    UpSampling2D((2, 2)),
    Conv2D(12, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(1, (3, 3), activation='relu', padding='same'),

    # Drop the channel axis so the output has the same shape as the targets.
    Reshape((28, 28)),
])

# Mean squared error between input and reconstruction, optimised with Adam.
model.compile(optimizer='adam', loss='mse')
# Train the network to reproduce its input: the same array is passed as both
# input and target, for the training split and for the validation split.
model.fit(
    x=X_train,
    y=X_train,
    epochs=config.epochs,
    validation_data=(X_test, X_test),
    # The W&B callback is created with save_model=False; the model is
    # written out explicitly by the save() call below.
    callbacks=[Images(), WandbCallback(save_model=False)],
)

# Write the trained model to 'auto-cnn.h5'.
model.save('auto-cnn.h5')