-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconditional_GP.py
114 lines (74 loc) · 3.86 KB
/
conditional_GP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import tensorflow as tf
from kernels import *
# Floating-point dtype requested from tf.get_variable when the conditional
# helpers below fetch the 'q_mu' and 'q_sqrt_real' variables.
DTYPE=tf.float32
def conditional(Xnew, X, kernel_type, dim_input, position_interaction, type_var,
                white=True, full_cov=False):
    """Build the GP posterior mean and (co)variance at ``Xnew``.

    Helper implementing the posterior GP equations, partially inspired by the
    same-named function in GPflow. It evaluates ``kernel`` on the pairs
    (X, X), (X, Xnew) and (Xnew, Xnew), passes the (X, X) matrix through
    ``condition``, fetches the variational parameters from the ``'model'``
    variable scope and delegates the algebra to ``base_conditional``.

    ``kernel`` and ``condition`` are not defined in this file (presumably
    they come from ``from kernels import *``); ``condition`` presumably
    improves the conditioning of Kmm before the Cholesky -- TODO confirm.

    Args:
        Xnew: Inputs at which the posterior is evaluated.
        X: Inputs the posterior is conditioned on (presumably the M inducing
            points -- TODO confirm against the caller).
        kernel_type: Forwarded unchanged to ``kernel``.
        dim_input: Forwarded unchanged to ``kernel``.
        position_interaction: Forwarded unchanged to ``kernel``.
        type_var: ``'full'`` keeps the lower triangle of ``'q_sqrt_real'``;
            any other value uses its element-wise square.
        white: Forwarded to ``base_conditional``.
        full_cov: Forwarded to ``base_conditional`` and used as the third
            positional argument of ``kernel`` for the (Xnew, Xnew) term.

    Returns:
        The ``(fmean, fvar)`` pair produced by ``base_conditional``.

    Raises:
        ValueError: From TensorFlow if ``'q_sqrt_real'`` or ``'q_mu'`` do not
            already exist in the ``'model'`` scope (the scope is opened with
            ``reuse=True``, so nothing is created here).
    """
    # The X-vs-X and X-vs-Xnew terms always pass True as the third argument;
    # only the Xnew-vs-Xnew term follows the caller's full_cov choice.
    Kmm = kernel(X, X, True, kernel_type, dim_input, position_interaction)
    Kmm = condition(Kmm)
    Kmn = kernel(X, Xnew, True, kernel_type, dim_input, position_interaction)
    Knn = kernel(Xnew, Xnew, full_cov, kernel_type, dim_input, position_interaction)

    # reuse=True: look up the existing variables instead of creating new ones.
    with tf.variable_scope('model', reuse=True):
        q_sqrt_real = tf.get_variable('q_sqrt_real', dtype=DTYPE)
        q_mu = tf.get_variable('q_mu', dtype=DTYPE)

    if type_var == 'full':
        # Keep the lower triangle (diagonal included), zero everything above.
        q_sqrt = tf.matrix_band_part(q_sqrt_real, -1, 0)
    else:
        # NOTE(review): the shape of 'q_sqrt_real' is not visible here;
        # confirm that base_conditional's matmul-based covariance term is
        # what is intended for this squared parameterisation.
        q_sqrt = tf.square(q_sqrt_real)

    return base_conditional(Kmn, Kmm, Knn, q_mu,
                            full_cov=full_cov, q_sqrt=q_sqrt, white=white)
def base_conditional(Kmn, Kmm, Knn, f, full_cov, q_sqrt=None, white=False):
    """Compute the GP posterior mean and (co)variance from kernel matrices.

    Helper implementing the posterior GP equations, partially inspired by the
    same-named function in GPflow.

    With ``Lm = cholesky(Kmm)`` and ``A = Lm^{-1} Kmn`` the function builds

    * ``fvar = Knn - A^T A`` (``full_cov``) or ``Knn`` minus the column-wise
      sum of squares of ``A`` returned as a column (otherwise);
    * if ``white`` is false, ``A`` is replaced by ``Lm^{-T} A``;
    * ``fmean = A^T f``;
    * if ``q_sqrt`` is given, with ``LTA = q_sqrt^T A``: ``fvar += LTA^T LTA``
      (``full_cov``) or the column-wise sum of squares of ``LTA``.

    Args:
        Kmn: Cross-covariance matrix; right-hand side of the triangular solve
            against the Cholesky factor of ``Kmm``.
        Kmm: Square matrix that is Cholesky-factorised.
        Knn: Prior (co)variance at the new points. When ``full_cov`` is false
            it is combined with a column vector, so it is presumably a column
            of marginal variances -- TODO confirm against the kernel helpers.
        f: Right operand of ``A^T f`` (the variational mean in the callers
            visible in this file).
        full_cov: Whether to produce a full covariance matrix.
        q_sqrt: Optional square-root factor of the variational covariance.
            When ``None`` the variational term is skipped and only the prior
            conditional variance is returned.
        white: If true, ``f`` and ``q_sqrt`` are used with the whitened ``A``;
            otherwise an extra back-substitution with ``Lm^T`` is applied.

    Returns:
        Tuple ``(fmean, fvar)`` of tensors.
    """
    Lm = tf.cholesky(Kmm)
    A = tf.matrix_triangular_solve(Lm, Kmn, lower=True)

    # Prior conditional variance: subtract the part explained by conditioning.
    if full_cov:
        fvar = Knn - tf.matmul(A, A, transpose_a=True)
    else:
        fvar = Knn - tf.transpose(tf.reduce_sum(tf.square(A), 0, keep_dims=True))

    # Second back-substitution in the unwhitened case: A <- Lm^{-T} A.
    if not white:
        A = tf.matrix_triangular_solve(tf.transpose(Lm), A, lower=False)

    fmean = tf.matmul(A, f, transpose_a=True)

    # The signature advertises q_sqrt=None, so honour it instead of failing
    # inside tf.transpose: without a variational covariance there is nothing
    # to add back.
    if q_sqrt is None:
        return fmean, fvar

    # Identical in both covariance modes, so compute it once.
    LTA = tf.matmul(tf.transpose(q_sqrt), A)
    if full_cov:
        fvar = fvar + tf.matmul(LTA, LTA, transpose_a=True)
    else:
        fvar = fvar + tf.transpose(tf.reduce_sum(tf.square(LTA), 0, keep_dims=True))

    return fmean, fvar
def conditional_interaction(Xnew, X, type_var, kernel_type, plot_effect,
                            position_interaction, white=True, full_cov=False):
    """Build the GP posterior at ``Xnew`` using ``kernel_interaction``.

    Evaluates ``kernel_interaction`` on the pairs (X, X), (X, Xnew) and
    (Xnew, Xnew), passes the (X, X) matrix through ``condition``, fetches the
    variational parameters from the ``'model'`` variable scope and delegates
    the algebra to ``base_conditional``.

    ``kernel_interaction`` and ``condition`` are not defined in this file
    (presumably they come from ``from kernels import *``).

    Args:
        Xnew: Inputs at which the posterior is evaluated.
        X: Inputs the posterior is conditioned on (presumably the inducing
            points -- TODO confirm against the caller).
        type_var: ``'full'`` keeps the lower triangle of ``'q_sqrt_real'``;
            any other value uses its element-wise square.
        kernel_type: Forwarded unchanged to ``kernel_interaction``.
        plot_effect: Forwarded unchanged to ``kernel_interaction``.
        position_interaction: Forwarded unchanged to ``kernel_interaction``.
        white: Forwarded to ``base_conditional``.
        full_cov: Forwarded to ``base_conditional`` and to the (Xnew, Xnew)
            kernel evaluation.

    Returns:
        The ``(fmean, fvar)`` pair produced by ``base_conditional``.

    Raises:
        ValueError: From TensorFlow if ``'q_sqrt_real'`` or ``'q_mu'`` do not
            already exist in the ``'model'`` scope (opened with
            ``reuse=True``).
    """
    # Arguments shared by all three kernel evaluations.
    shared = dict(kernel_type=kernel_type,
                  plot_effect=plot_effect,
                  position_interaction=position_interaction)

    # The X-vs-X and X-vs-Xnew terms are always full matrices; only the
    # Xnew-vs-Xnew term follows the caller's full_cov choice.
    Kmm = kernel_interaction(X1=X, X2=X, full_cov=True, **shared)
    Kmm = condition(Kmm)
    Kmn = kernel_interaction(X1=X, X2=Xnew, full_cov=True, **shared)
    Knn = kernel_interaction(X1=Xnew, X2=Xnew, full_cov=full_cov, **shared)

    # reuse=True: look up the existing variables instead of creating new ones.
    with tf.variable_scope('model', reuse=True):
        q_sqrt_real = tf.get_variable('q_sqrt_real', dtype=DTYPE)
        q_mu = tf.get_variable('q_mu', dtype=DTYPE)

    if type_var == 'full':
        # Keep the lower triangle (diagonal included), zero everything above.
        q_sqrt = tf.matrix_band_part(q_sqrt_real, -1, 0)
    else:
        # NOTE(review): the shape of 'q_sqrt_real' is not visible here;
        # confirm that base_conditional's matmul-based covariance term is
        # what is intended for this squared parameterisation.
        q_sqrt = tf.square(q_sqrt_real)

    return base_conditional(Kmn, Kmm, Knn, q_mu,
                            full_cov=full_cov, q_sqrt=q_sqrt, white=white)
def conditional_additive(Xnew, X, dim_input, type_var, white=True, full_cov=False):
    """Build the GP posterior at ``Xnew`` using ``kernel_additive``.

    Evaluates ``kernel_additive`` on the pairs (X, X), (X, Xnew) and
    (Xnew, Xnew), passes the (X, X) matrix through ``condition``, fetches the
    variational parameters from the ``'model'`` variable scope and delegates
    the algebra to ``base_conditional``.

    ``kernel_additive`` and ``condition`` are not defined in this file
    (presumably they come from ``from kernels import *``).

    Args:
        Xnew: Inputs at which the posterior is evaluated.
        X: Inputs the posterior is conditioned on (presumably the inducing
            points -- TODO confirm against the caller).
        dim_input: Forwarded unchanged to ``kernel_additive``.
        type_var: ``'full'`` keeps the lower triangle of ``'q_sqrt_real'``;
            any other value uses its element-wise square.
        white: Forwarded to ``base_conditional``.
        full_cov: Forwarded to ``base_conditional`` and to the (Xnew, Xnew)
            kernel evaluation.

    Returns:
        The ``(fmean, fvar)`` pair produced by ``base_conditional``.

    Raises:
        ValueError: From TensorFlow if ``'q_sqrt_real'`` or ``'q_mu'`` do not
            already exist in the ``'model'`` scope (opened with
            ``reuse=True``).
    """
    # The X-vs-X and X-vs-Xnew terms are always full matrices; only the
    # Xnew-vs-Xnew term follows the caller's full_cov choice.
    Kmm = kernel_additive(X1=X, X2=X, full_cov=True, dim_input=dim_input)
    Kmm = condition(Kmm)
    Kmn = kernel_additive(X1=X, X2=Xnew, full_cov=True, dim_input=dim_input)
    Knn = kernel_additive(X1=Xnew, X2=Xnew, full_cov=full_cov, dim_input=dim_input)

    # reuse=True: look up the existing variables instead of creating new ones.
    with tf.variable_scope('model', reuse=True):
        q_sqrt_real = tf.get_variable('q_sqrt_real', dtype=DTYPE)
        q_mu = tf.get_variable('q_mu', dtype=DTYPE)

    if type_var == 'full':
        # Keep the lower triangle (diagonal included), zero everything above.
        q_sqrt = tf.matrix_band_part(q_sqrt_real, -1, 0)
    else:
        # NOTE(review): the shape of 'q_sqrt_real' is not visible here;
        # confirm that base_conditional's matmul-based covariance term is
        # what is intended for this squared parameterisation.
        q_sqrt = tf.square(q_sqrt_real)

    return base_conditional(Kmn, Kmm, Knn, q_mu,
                            full_cov=full_cov, q_sqrt=q_sqrt, white=white)