MultiTaskGaussianProcessRegressionModel issue with tf.function #1981

Open

mengsiong opened this issue Dec 13, 2024 · 6 comments
mengsiong commented Dec 13, 2024

Hi, I have encountered an error when using the tf.function decorator on a function that calls MultiTaskGaussianProcessRegressionModel from the experimental package. The error occurs specifically when an input_signature with an unknown dimension is passed to tf.function. Details below. Thanks!

Code

import tensorflow as tf
import tensorflow_probability as tfp

tfde = tfp.experimental.distributions
tfk = tfp.math.psd_kernels
tfke = tfp.experimental.psd_kernels

# Two-task independent multi-task kernel over 1-D index points, with two observed points.
base_kernel = tfk.ExponentiatedQuadratic(
    amplitude=tf.convert_to_tensor(0.6, tf.float64),
    length_scale=tf.convert_to_tensor(0.5, tf.float64),
)
kernel = tfke.Independent(num_tasks=2, base_kernel=base_kernel)
observations = tf.constant([[0., 1.],[-0.5, -1.0]], tf.float64)
observation_index_points = tf.constant([[0],[1.0]], tf.float64)

# The unknown leading dimension (None) in the input_signature is what triggers the error.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.float64)])
def predict(index_points):
    gp = tfde.MultiTaskGaussianProcessRegressionModel(
        kernel=kernel,
        observations=observations,
        observation_index_points=observation_index_points,
        index_points=index_points,
    )
    return gp.mean()

predict(tf.constant([[0.2],[0.4],[0.6]], tf.float64))

The code above works:

  • if we remove tf.function, or
  • if we remove the input_signature argument from tf.function (a sketch of both workarounds follows).
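
For reference, a sketch of the working variants. The first two follow directly from the observations above; the third, pinning the leading dimension to a concrete value, is an untested assumption that a fully static shape avoids the failing code path:

# 1. Eager: simply drop the tf.function decorator (works per the report).
# 2. Traced without an explicit signature (works per the report):
@tf.function
def predict_traced(index_points):
    gp = tfde.MultiTaskGaussianProcessRegressionModel(
        kernel=kernel,
        observations=observations,
        observation_index_points=observation_index_points,
        index_points=index_points,
    )
    return gp.mean()

# 3. (Assumption, untested) Keep input_signature but pin the leading dimension:
@tf.function(input_signature=[tf.TensorSpec(shape=[3, 1], dtype=tf.float64)])
def predict_fixed(index_points):
    gp = tfde.MultiTaskGaussianProcessRegressionModel(
        kernel=kernel,
        observations=observations,
        observation_index_points=observation_index_points,
        index_points=index_points,
    )
    return gp.mean()

predict_traced(tf.constant([[0.2], [0.4], [0.6]], tf.float64))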

We get the following error if we set shape=[None, 1] inside tf.TensorSpec:

UnboundLocalError: in user code:

    File "/tmp/ipykernel_299029/1064227357.py", line 24, in predict  *
        return gp.mean()
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1536, in mean  **
        return self._mean(**kwargs)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 852, in _mean
        self._get_flattened_marginal_distribution(
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 831, in _get_flattened_marginal_distribution
        covariance = self._compute_flattened_covariance(index_points)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 811, in _compute_flattened_covariance
        cholinv_kzx = observation_scale.solve(kxz, adjoint_arg=True)

    UnboundLocalError: cannot access local variable 'dim' where it is not associated with a value
csuter (Member) commented Dec 13, 2024

It looks like TF is masking some of the call stack. Can you get hold of the full stack trace?

mengsiong (Author) commented Dec 13, 2024

@csuter Here you go. Thanks!

---------------------------------------------------------------------------
UnboundLocalError                         Traceback (most recent call last)
Cell In[1], line 26
     18     gp = tfde.MultiTaskGaussianProcessRegressionModel(
     19         kernel=kernel,
     20         observations=observations,
     21         observation_index_points=observation_index_points,
     22         index_points=index_points,
     23     )
     24     return gp.mean()
---> 26 predict(tf.constant([[0.2],[0.4],[0.6]], tf.float64))

File ~/.local/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File /tmp/__autograph_generated_fileviciofx3.py:13, in outer_factory.<locals>.inner_factory.<locals>.tf__predict(index_points)
     11 try:
     12     do_return = True
---> 13     retval_ = ag__.converted_call(ag__.ld(gp).mean, (), None, fscope)
     14 except:
     15     do_return = False

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/distributions/distribution.py:1536, in Distribution.mean(self, name, **kwargs)
   1534 """Mean."""
   1535 with self._name_and_control_scope(name):
-> 1536   return self._mean(**kwargs)

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py:856, in MultiTaskGaussianProcessRegressionModel._mean(self, index_points)
    852 def _mean(self, index_points=None):
    853   # The mean is of shape B1 + [E, N], where E is the number of index points,
    854   # and N is the number of tasks.
    855   return _unvec(
--> 856       self._get_flattened_marginal_distribution(
    857           index_points=index_points).mean(), [-1, self.kernel.num_tasks])

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py:835, in MultiTaskGaussianProcessRegressionModel._get_flattened_marginal_distribution(self, index_points)
    833 with self._name_and_control_scope('get_flattened_marginal_distribution'):
    834   index_points = self._get_index_points(index_points)
--> 835   covariance = self._compute_flattened_covariance(index_points)
    836   loc = self._flattened_conditional_mean_fn(index_points)
    837   scale = tf.linalg.LinearOperatorLowerTriangular(
    838       self._cholesky_fn(covariance),
    839       is_non_singular=True,
    840       name='GaussianProcessScaleLinearOperator')

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py:815, in MultiTaskGaussianProcessRegressionModel._compute_flattened_covariance(self, index_points)
    807 else:
    808   observation_scale = _compute_observation_scale(
    809       self.kernel,
    810       self.observation_index_points,
    811       self.cholesky_fn,
    812       observation_noise_variance=self.observation_noise_variance,
    813       observations_is_missing=self.observations_is_missing)
--> 815 cholinv_kzx = observation_scale.solve(kxz, adjoint_arg=True)
    816 kxz_kzzinv_kzx = tf.linalg.matmul(
    817     cholinv_kzx, cholinv_kzx, transpose_a=True)
    819 flattened_covariance = kxx.to_dense() - kxz_kzzinv_kzx

UnboundLocalError: in user code:

    File "/tmp/ipykernel_1177/539756717.py", line 24, in predict  *
        return gp.mean()
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1536, in mean  **
        return self._mean(**kwargs)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 856, in _mean
        self._get_flattened_marginal_distribution(
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 835, in _get_flattened_marginal_distribution
        covariance = self._compute_flattened_covariance(index_points)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 815, in _compute_flattened_covariance
        cholinv_kzx = observation_scale.solve(kxz, adjoint_arg=True)

    UnboundLocalError: cannot access local variable 'dim' where it is not associated with a value

csuter (Member) commented Dec 17, 2024

This is still not showing the hidden stack frames from within TF code. Try this: https://www.tensorflow.org/api_docs/python/tf/debugging/disable_traceback_filtering
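
For example (a minimal snippet; call it once before re-running the failing cell):

import tensorflow as tf

# Show full, unfiltered tracebacks, including frames inside TF itself.
tf.debugging.disable_traceback_filtering()

# Then re-run the failing call from above to capture the complete stack trace.
predict(tf.constant([[0.2], [0.4], [0.6]], tf.float64))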

mengsiong (Author) commented

@csuter Please see below. Thanks!

---------------------------------------------------------------------------
UnboundLocalError                         Traceback (most recent call last)
Cell In[1], line 27
     19     gp = tfde.MultiTaskGaussianProcessRegressionModel(
     20         kernel=kernel,
     21         observations=observations,
     22         observation_index_points=observation_index_points,
     23         index_points=index_points,
     24     )
     25     return gp.mean()
---> 27 predict(tf.constant([[0.2],[0.4],[0.6]], tf.float64))

File ~/.local/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py:146, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    141     return fn(*args, **kwargs)
    142 except NameError:
    143   # In some very rare cases,
    144   # `is_traceback_filtering_enabled` (from the outer scope) may not be
    145   # accessible from inside this function
--> 146   return fn(*args, **kwargs)
    148 filtered_tb = None
    149 try:

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:833, in Function.__call__(self, *args, **kwds)
    830 compiler = "xla" if self._jit_compile else "nonXla"
    832 with OptionalXlaContext(self._jit_compile):
--> 833   result = self._call(*args, **kwds)
    835 new_tracing_count = self.experimental_get_tracing_count()
    836 without_tracing = (tracing_count == new_tracing_count)

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:878, in Function._call(self, *args, **kwds)
    875 self._lock.release()
    876 # In this case we have not created variables on the first call. So we can
    877 # run the first trace but we should fail if variables are created.
--> 878 results = tracing_compilation.call_function(
    879     args, kwds, self._variable_creation_config
    880 )
    881 if self._created_variables:
    882   raise ValueError("Creating variables on a non-first call to a function"
    883                    " decorated with tf.function.")

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:132, in call_function(args, kwargs, tracing_options)
    130 args = args if args else ()
    131 kwargs = kwargs if kwargs else {}
--> 132 function = trace_function(
    133     args=args, kwargs=kwargs, tracing_options=tracing_options
    134 )
    136 # Bind it ourselves to skip unnecessary canonicalization of default call.
    137 bound_args = function.function_type.bind(*args, **kwargs)

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178, in trace_function(args, kwargs, tracing_options)
    175     args = tracing_options.input_signature
    176     kwargs = {}
--> 178   concrete_function = _maybe_define_function(
    179       args, kwargs, tracing_options
    180   )
    182 if not tracing_options.bind_graph_to_function:
    183   concrete_function._garbage_collector.release()  # pylint: disable=protected-access

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:283, in _maybe_define_function(args, kwargs, tracing_options)
    281 else:
    282   target_func_type = lookup_func_type
--> 283 concrete_function = _create_concrete_function(
    284     target_func_type, lookup_func_context, func_graph, tracing_options
    285 )
    287 if tracing_options.function_cache is not None:
    288   tracing_options.function_cache.add(
    289       concrete_function, current_func_context
    290   )

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:310, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
    303   placeholder_bound_args = function_type.placeholder_arguments(
    304       placeholder_context
    305   )
    307 disable_acd = tracing_options.attributes and tracing_options.attributes.get(
    308     attributes_lib.DISABLE_ACD, False
    309 )
--> 310 traced_func_graph = func_graph_module.func_graph_from_py_func(
    311     tracing_options.name,
    312     tracing_options.python_function,
    313     placeholder_bound_args.args,
    314     placeholder_bound_args.kwargs,
    315     None,
    316     func_graph=func_graph,
    317     add_control_dependencies=not disable_acd,
    318     arg_names=function_type_utils.to_arg_names(function_type),
    319     create_placeholders=False,
    320 )
    322 transform.apply_func_graph_transforms(traced_func_graph)
    324 graph_capture_container = traced_func_graph.function_captures

File ~/.local/lib/python3.12/site-packages/tensorflow/python/framework/func_graph.py:1059, in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, create_placeholders)
   1056   return x
   1058 _, original_func = tf_decorator.unwrap(python_func)
-> 1059 func_outputs = python_func(*func_args, **func_kwargs)
   1061 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
   1062 # TensorArrays and `None`s.
   1063 func_outputs = variable_utils.convert_variables_to_tensors(func_outputs)

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:599, in Function._generate_scoped_tracing_options.<locals>.wrapped_fn(*args, **kwds)
    595 with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
    596   # __wrapped__ allows AutoGraph to swap in a converted function. We give
    597   # the function a weak reference to itself to avoid a reference cycle.
    598   with OptionalXlaContext(compile_with_xla):
--> 599     out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    600   return out

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:52, in py_func_from_autograph.<locals>.autograph_handler(*args, **kwargs)
     50 except Exception as e:  # pylint:disable=broad-except
     51   if hasattr(e, "ag_error_metadata"):
---> 52     raise e.ag_error_metadata.to_exception(e)
     53   else:
     54     raise

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41, in py_func_from_autograph.<locals>.autograph_handler(*args, **kwargs)
     39 """Calls a converted version of original_func."""
     40 try:
---> 41   return api.converted_call(
     42       original_func,
     43       args,
     44       kwargs,
     45       options=converter.ConversionOptions(
     46           recursive=True,
     47           optional_features=autograph_options,
     48           user_requested=True,
     49       ))
     50 except Exception as e:  # pylint:disable=broad-except
     51   if hasattr(e, "ag_error_metadata"):

File ~/.local/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py:439, in converted_call(f, args, kwargs, caller_fn_scope, options)
    437 try:
    438   if kwargs is not None:
--> 439     result = converted_f(*effective_args, **kwargs)
    440   else:
    441     result = converted_f(*effective_args)

File /tmp/__autograph_generated_fileaxs5xf4h.py:13, in outer_factory.<locals>.inner_factory.<locals>.tf__predict(index_points)
     11 try:
     12     do_return = True
---> 13     retval_ = ag__.converted_call(ag__.ld(gp).mean, (), None, fscope)
     14 except:
     15     do_return = False

File ~/.local/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py:331, in converted_call(f, args, kwargs, caller_fn_scope, options)
    329 if conversion.is_in_allowlist_cache(f, options):
    330   logging.log(2, 'Allowlisted %s: from cache', f)
--> 331   return _call_unconverted(f, args, kwargs, options, False)
    333 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    334   logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)

File ~/.local/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/distributions/distribution.py:1536, in Distribution.mean(self, name, **kwargs)
   1534 """Mean."""
   1535 with self._name_and_control_scope(name):
-> 1536   return self._mean(**kwargs)

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py:856, in MultiTaskGaussianProcessRegressionModel._mean(self, index_points)
    852 def _mean(self, index_points=None):
    853   # The mean is of shape B1 + [E, N], where E is the number of index points,
    854   # and N is the number of tasks.
    855   return _unvec(
--> 856       self._get_flattened_marginal_distribution(
    857           index_points=index_points).mean(), [-1, self.kernel.num_tasks])

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py:835, in MultiTaskGaussianProcessRegressionModel._get_flattened_marginal_distribution(self, index_points)
    833 with self._name_and_control_scope('get_flattened_marginal_distribution'):
    834   index_points = self._get_index_points(index_points)
--> 835   covariance = self._compute_flattened_covariance(index_points)
    836   loc = self._flattened_conditional_mean_fn(index_points)
    837   scale = tf.linalg.LinearOperatorLowerTriangular(
    838       self._cholesky_fn(covariance),
    839       is_non_singular=True,
    840       name='GaussianProcessScaleLinearOperator')

File ~/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py:815, in MultiTaskGaussianProcessRegressionModel._compute_flattened_covariance(self, index_points)
    807 else:
    808   observation_scale = _compute_observation_scale(
    809       self.kernel,
    810       self.observation_index_points,
    811       self.cholesky_fn,
    812       observation_noise_variance=self.observation_noise_variance,
    813       observations_is_missing=self.observations_is_missing)
--> 815 cholinv_kzx = observation_scale.solve(kxz, adjoint_arg=True)
    816 kxz_kzzinv_kzx = tf.linalg.matmul(
    817     cholinv_kzx, cholinv_kzx, transpose_a=True)
    819 flattened_covariance = kxx.to_dense() - kxz_kzzinv_kzx

File ~/.local/lib/python3.12/site-packages/tensorflow/python/ops/linalg/linear_operator.py:979, in LinearOperator.solve(self, rhs, adjoint, adjoint_arg, name)
    974 arg_dim = -1 if adjoint_arg else -2
    975 tensor_shape.dimension_at_index(
    976     self.shape, self_dim).assert_is_compatible_with(
    977         rhs.shape[arg_dim])
--> 979 return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

File ~/.local/lib/python3.12/site-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py:408, in LinearOperatorKronecker._solve(self, rhs, adjoint, adjoint_arg)
    406 def solve_fn(o, rhs, adjoint, adjoint_arg):
    407   return o.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
--> 408 return self._solve_matmul_internal(
    409     x=rhs,
    410     solve_matmul_fn=solve_fn,
    411     adjoint=adjoint,
    412     adjoint_arg=adjoint_arg)

File ~/.local/lib/python3.12/site-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py:351, in LinearOperatorKronecker._solve_matmul_internal(self, x, solve_matmul_fn, adjoint, adjoint_arg)
    345 else:
    346   dim = math_ops.cast(
    347       output_shape[-2] * output_shape[-1] // operator_dimension,
    348       dtype=dtypes.int32)
    350 output_shape = _prefer_static_concat_shape(
--> 351     output_shape[:-2], [dim, operator_dimension])
    352 output = array_ops.reshape(output, shape=output_shape)
    354 # Conjugate because we are trying to compute A @ B^T, but
    355 # `LinearOperator` only supports `adjoint_arg`.

UnboundLocalError: in user code:

    File "/tmp/ipykernel_84868/3942600544.py", line 25, in predict  *
        return gp.mean()
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1536, in mean  **
        return self._mean(**kwargs)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 856, in _mean
        self._get_flattened_marginal_distribution(
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 835, in _get_flattened_marginal_distribution
        covariance = self._compute_flattened_covariance(index_points)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py", line 815, in _compute_flattened_covariance
        cholinv_kzx = observation_scale.solve(kxz, adjoint_arg=True)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow/python/ops/linalg/linear_operator.py", line 979, in solve
        return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    File "/usr/.local/lib/python3.12/site-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py", line 408, in _solve
        return self._solve_matmul_internal(
    File "/usr/.local/lib/python3.12/site-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py", line 351, in _solve_matmul_internal
        output_shape[:-2], [dim, operator_dimension])

    UnboundLocalError: cannot access local variable 'dim' where it is not associated with a value

csuter (Member) commented Dec 18, 2024

Looks like a legit bug in the Kronecker linear operator code. There's a code path that doesn't set the dim variable at all.
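
A simplified, hypothetical illustration of the pattern (not the actual TF source; the real code is in LinearOperatorKronecker._solve_matmul_internal, around lines 340-351 per the traceback):

# Hypothetical sketch: `dim` is assigned inside a nested condition on one branch,
# so a code path exists where the unconditional use below sees it unbound.
def reshape_dims(branch_a, nested_condition, rows, cols, operator_dimension):
    if branch_a:
        if nested_condition:
            dim = rows * cols // operator_dimension  # only bound when the nested condition holds
    else:
        dim = rows * cols // operator_dimension      # this branch always binds `dim`
    return [dim, operator_dimension]                 # UnboundLocalError if neither assignment ran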

mengsiong (Author) commented

@csuter It seems like a quick fix would be to remove the extra indentation marked in red below at lines 343 and 344 of linear_operator_kronecker.py:

[screenshot: linear_operator_kronecker.py, lines 343-344, with the suggested de-indentation highlighted in red]
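
In sketch form, the de-indentation would make the outer branch bind dim unconditionally (hypothetical illustration, not the real source; presumably the two branches compute dim from static vs. dynamic shape information):

def reshape_dims_fixed(branch_a, rows, cols, operator_dimension):
    # After de-indenting, both branches bind `dim`, so the use below is always safe.
    if branch_a:
        dim = rows * cols // operator_dimension
    else:
        dim = rows * cols // operator_dimension
    return [dim, operator_dimension]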
