-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinit_variables.py
110 lines (69 loc) · 3.69 KB
/
init_variables.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
import tensorflow as tf
# Floating-point dtype passed to tf.get_variable / tf.constant by init_variables.
DTYPE = tf.float32
def _create_kernel_hyperparameters(variance_name, lengthscales_name,
                                   num_lengthscales, init_log_value):
    """Create one kernel's scalar log-variance and log-lengthscale vector.

    Both variables are created with ``tf.get_variable`` in the caller's
    current variable scope, use dtype ``DTYPE`` and start at
    ``init_log_value``. The variance is created before the lengthscales.

    Args:
        variance_name: name of the scalar log-variance variable.
        lengthscales_name: name of the log-lengthscale vector variable.
        num_lengthscales: length of the log-lengthscale vector.
        init_log_value: initial value written to every created entry.
    """
    tf.get_variable(initializer=tf.constant(init_log_value, dtype=DTYPE),
                    dtype=DTYPE, name=variance_name)
    tf.get_variable(
        initializer=tf.constant([init_log_value] * num_lengthscales, dtype=DTYPE),
        dtype=DTYPE, name=lengthscales_name)


def _create_additive_terms(dim_input, init_log_value):
    """Create one single-lengthscale kernel per input dimension, suffixed by its index."""
    for dim in range(dim_input):
        _create_kernel_hyperparameters('log_variance_kernel_{}'.format(dim),
                                       'log_lengthscales_{}'.format(dim),
                                       1, init_log_value)


def init_variables(num_data, type_var, kernel_type, dim_input, init_log_value=0.301):
    """Create the model's TF1 variables inside the ``'model'`` variable scope.

    Variables are only registered in the graph; nothing is returned.
    NOTE(review): callers presumably fetch them later by name (e.g. with
    variable-scope reuse) — confirm against the caller.

    Always created:
      * ``q_mu``: zeros, shape ``(num_data, 1)``.
      * ``q_sqrt_real``: identity matrix of size ``num_data x num_data``.

    Kernel hyperparameters created, depending on ``kernel_type``:
      * ``'interaction'``: ``log_variance_kernel`` with ``log_lengthscales``
        (length ``dim_input``); three two-way terms
        ``log_variance_kernel_two_way_{i}`` with
        ``log_lengthscales_two_way_{i}`` (length 2); and the additive terms
        below.
      * ``'mixed-additive-interaction'``: ``log_variance_kernel`` with
        ``log_lengthscales`` (length 2), plus the additive terms.
      * ``'additive'``: only the additive terms, ``log_variance_kernel_{d}``
        and ``log_lengthscales_{d}`` (length 1) for each ``d`` in
        ``range(dim_input)``.
      * Any other value: no kernel hyperparameters are created.

    Args:
        num_data: size used for ``q_mu`` and ``q_sqrt_real``.
        type_var: currently does not affect the created variables (see below).
        kernel_type: selects which kernel hyperparameters to create.
        dim_input: number of input dimensions.
        init_log_value: initial value of every kernel hyperparameter entry.
            Defaults to the previously hard-coded 0.301.

    Calling this twice in the same graph without variable reuse will likely
    fail, because TF1 ``tf.get_variable`` refuses to recreate an existing name.
    """
    with tf.variable_scope('model'):
        tf.get_variable(initializer=tf.zeros_initializer(), shape=(num_data, 1),
                        dtype=DTYPE, name='q_mu')

        # type_var does not change this initializer: every value gets an
        # identity matrix. NOTE(review): a different (e.g. diagonal)
        # parameterization for non-'full' values may have been intended — confirm.
        tf.get_variable(initializer=tf.eye(num_data, dtype=DTYPE),
                        dtype=DTYPE, name='q_sqrt_real')

        if kernel_type == 'interaction':
            # Three-way interaction term over all input dimensions.
            _create_kernel_hyperparameters('log_variance_kernel', 'log_lengthscales',
                                           dim_input, init_log_value)
            # Two-way interaction terms. A named index (not `_`) keeps the name
            # suffix correct. NOTE(review): 3 terms equals the number of input
            # pairs only when dim_input == 3 — confirm.
            for term in range(3):
                _create_kernel_hyperparameters(
                    'log_variance_kernel_two_way_{}'.format(term),
                    'log_lengthscales_two_way_{}'.format(term),
                    2, init_log_value)
            _create_additive_terms(dim_input, init_log_value)
        elif kernel_type == 'mixed-additive-interaction':
            # Single two-way interaction term.
            _create_kernel_hyperparameters('log_variance_kernel', 'log_lengthscales',
                                           2, init_log_value)
            _create_additive_terms(dim_input, init_log_value)
        elif kernel_type == 'additive':
            _create_additive_terms(dim_input, init_log_value)
        # Any other kernel_type: deliberately create no kernel hyperparameters.
        # NOTE(review): consider raising ValueError for unrecognized values.