-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvae_layers.py
79 lines (51 loc) · 1.93 KB
/
vae_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def kl_gaussians(mu_1, sigma_1, mu_2, sigma_2):
    """Elementwise KL divergence KL(N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2)).

    Uses the closed form
        log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2 sigma_2^2) - 1/2.

    Args:
        mu_1, sigma_1: mean and standard deviation of the first Gaussian.
        mu_2, sigma_2: mean and standard deviation of the second Gaussian.
            All four are combined with elementwise TF arithmetic, so they
            must be broadcast-compatible. sigma values are assumed strictly
            positive (they go through a log and a division).

    Returns:
        Tensor of per-element KL values (no reduction is applied).
    """
    log_ratio = tf.math.log(sigma_2 / sigma_1)
    mean_gap_sq = (mu_1 - mu_2) ** 2
    quad_term = (sigma_1 ** 2 + mean_gap_sq) / (2 * sigma_2 ** 2)
    return log_ratio + quad_term - 0.5
def kl_std_gaussian(mu, sigma):
    """Elementwise KL divergence KL(N(mu, sigma^2) || N(0, 1)).

    Closed form: 0.5 * (sigma^2 + mu^2) - 0.5 - log(sigma).

    Args:
        mu: mean tensor.
        sigma: standard-deviation tensor, broadcast-compatible with ``mu``;
            assumed strictly positive because of the log.

    Returns:
        Tensor of per-element KL values (no reduction is applied).
    """
    second_moment = sigma ** 2 + mu ** 2
    return 0.5 * second_moment - 0.5 - tf.math.log(sigma)
class SamplingLoss(layers.Layer):
    """Pass-through layer that registers a KL-to-standard-normal loss.

    ``inputs`` is unpacked as ``(x, mean_q, std_q)``. The layer adds
    KL(N(mean_q, std_q^2) || N(0, 1)) as a loss via ``add_loss`` and returns
    ``inputs`` unchanged.
    """

    def call(self, inputs):
        # The first element is unused here; it is only carried through.
        _, mean_q, std_q = inputs
        kl_elementwise = kl_std_gaussian(mean_q, std_q)
        # Average over axis 0 (presumably the batch axis — confirm with
        # callers), then sum the remaining per-dimension terms.
        per_dim = tf.reduce_mean(kl_elementwise, axis=0)
        self.add_loss(tf.reduce_sum(per_dim))
        return inputs
class ConditionalSamplingLoss(layers.Layer):
    """Pass-through layer that registers KL(q || p) between two Gaussians.

    ``inputs`` is unpacked as ``(_, mean_q, std_q, mean_p, std_p)``. The
    layer adds KL(N(mean_q, std_q^2) || N(mean_p, std_p^2)) as a loss via
    ``add_loss`` and returns ``inputs`` unchanged.
    """

    def call(self, inputs):
        _, q_mean, q_std, p_mean, p_std = inputs
        kl_elementwise = kl_gaussians(q_mean, q_std, p_mean, p_std)
        # Average over axis 0 (presumably the batch axis — confirm with
        # callers), then sum the remaining per-dimension terms.
        per_dim = tf.reduce_mean(kl_elementwise, axis=0)
        self.add_loss(tf.reduce_sum(per_dim))
        return inputs
class Sampling(layers.Layer):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1) (reparameterization).

    ``inputs`` is unpacked as ``(mu, sigma)``. Fresh standard-normal noise
    with the same shape as ``mu`` is drawn on every call.
    """

    def call(self, inputs):
        mu, sigma = inputs
        # tf.random.normal defaults to float32. TF does not implicitly cast
        # between float dtypes, so draw the noise in mu's dtype. This keeps
        # the layer working when mu is float16/bfloat16 (Keras mixed
        # precision) or float64; float32 behaviour is unchanged.
        eps = tf.random.normal(tf.shape(mu), dtype=mu.dtype)
        return eps * sigma + mu
class Reparameterize(layers.Layer):
    """Deterministic reparameterization with externally supplied noise.

    ``inputs`` is unpacked as ``(eps, mu, sigma)``. The layer returns
    ``eps * sigma + mu`` and draws no randomness of its own.
    """

    def call(self, inputs):
        noise, loc, scale = inputs
        return noise * scale + loc
class XELoss(layers.Layer):
    """Register a binary cross-entropy loss and pass the prediction through.

    ``inputs`` is unpacked as ``(y_true, y_pred)``. Per-element BCE comes from
    ``keras.backend.binary_crossentropy`` with its default arguments
    (presumably ``from_logits=False``, i.e. y_pred holds probabilities —
    confirm). It is averaged over axis 0, summed over the remaining axes,
    and added via ``add_loss``. Returns ``y_pred`` unchanged.
    """

    def call(self, inputs):
        target, prediction = inputs
        per_element = keras.backend.binary_crossentropy(target, prediction)
        # Mean over axis 0 (presumably batch), then sum everything else.
        batch_mean = tf.reduce_mean(per_element, axis=0)
        total = tf.reduce_sum(batch_mean)
        self.add_loss(total)
        return prediction
class L1Loss(layers.Layer):
    """Register an absolute-error (L1) loss and pass the prediction through.

    ``inputs`` is unpacked as ``(y_true, y_pred)``. ``|y_true - y_pred|`` is
    averaged over axis 0, summed over the remaining axes, and added via
    ``add_loss``. Returns ``y_pred`` unchanged.
    """

    def call(self, inputs):
        target, prediction = inputs
        abs_err = tf.abs(target - prediction)
        # Mean over axis 0 (presumably batch), then sum everything else.
        batch_mean = tf.reduce_mean(abs_err, axis=0)
        l1_total = tf.reduce_sum(batch_mean)
        self.add_loss(l1_total)
        return prediction