STS components with batching on the initial states break #1975

Open
jeffpollock9 opened this issue Nov 28, 2024 · 0 comments

I am trying to run many time series models with different initial states via batching. This seems to be supported in some parts of the code, although it is not explicit in the documentation, so I am wondering whether this is a bug or simply not supported. Either way, I think it would be useful to have it working.

The following example shows the unexpected results you can currently get: the joint_distribution log prob seems to give the wrong answer (it works as expected for batch_shape=[]), apparently summing 3*3 log likelihoods instead of 3:

import tensorflow as tf
import tensorflow_probability as tfp

print(tf.__version__)
# 2.18.0

print(tfp.__version__)
# 0.25.0

sts = tfp.sts
tfd = tfp.distributions

batch_shape = [3]
num_timesteps = 10
param_vals = [1.0, 2.0]
observations = tf.ones(batch_shape + [num_timesteps, 1])

local_level = sts.LocalLevel(
    initial_level_prior=tfd.Normal(loc=tf.zeros(batch_shape), scale=1.0)
)

print(local_level.batch_shape)
# ()
print(local_level.initial_state_prior.batch_shape)
# (3,)

model = sts.Sum(components=[local_level])

# log prob of the parameters via the joint distribution over parameters and observations
joint_dist = model.joint_distribution(observed_time_series=observations)
joint_dist_log_prob = joint_dist.log_prob(param_vals)

print(joint_dist_log_prob)
# tf.Tensor(-167.13658, shape=(), dtype=float32)

# the same quantity computed manually from the state space model and the parameter priors
ssm = model.make_state_space_model(num_timesteps=num_timesteps, param_vals=param_vals)
log_likelihood = sum(
    ssm.forward_filter(observations, final_step_only=True).log_likelihoods
)
prior = sum(p.prior.log_prob(x) for p, x in zip(model.parameters, param_vals))

print(log_likelihood + prior)
# tf.Tensor(-60.865345, shape=(), dtype=float32)

print(log_likelihood * 3 + prior)
# tf.Tensor(-167.13658, shape=(), dtype=float32)

At a guess, this is because the components only consider their parameters, and not the initial state prior, when computing their batch shape (see the sketch below). I am happy to work on a fix or feature addition if that would be helpful. Thanks!
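
For illustration only, here is a minimal sketch of the kind of batch-shape merge I mean, i.e. broadcasting the initial_state_prior batch shape together with the parameter prior batch shapes. Note component_batch_shape is a hypothetical helper written for this issue, not anything in the TFP API:

import tensorflow as tf

def component_batch_shape(component):
    # Hypothetical helper: start from the initial_state_prior batch shape and
    # broadcast it against every parameter prior's batch shape, instead of
    # considering the parameter priors alone.
    shape = component.initial_state_prior.batch_shape_tensor()
    for param in component.parameters:
        shape = tf.broadcast_dynamic_shape(shape, param.prior.batch_shape_tensor())
    return shape

print(component_batch_shape(local_level))
# expected to give [3] for the example above, rather than the () reported by
# local_level.batch_shape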
