FunMC: Add a better implementation of SMC to FunMC.
The old implementation (i.e. AIS) will be rewritten to use this version at some future date.

PiperOrigin-RevId: 698476095
SiegeLordEx authored and tensorflower-gardener committed Nov 20, 2024
1 parent 5536afd commit fadd403
Showing 7 changed files with 1,508 additions and 8 deletions.
1 change: 0 additions & 1 deletion spinoffs/fun_mc/fun_mc/BUILD
@@ -226,7 +226,6 @@ pytype_strict_contrib_test(
":types",
# absl/testing:parameterized dep,
# jax dep,
# jaxtyping dep,
# mock dep,
# tensorflow dep,
# tensorflow_probability/python/internal:test_util dep,
1 change: 1 addition & 0 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD
@@ -33,6 +33,7 @@ py_library(
srcs = ["util.py"],
deps = [
# jaxtyping dep,
# numpy dep,
],
)

34 changes: 33 additions & 1 deletion spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
@@ -14,24 +14,29 @@
# ============================================================================
"""FunMC utilities implemented via JAX."""

import dataclasses
import functools
from typing import TypeVar, dataclass_transform

import jax
from jax import lax
from jax import random
from jax import tree_util
import jax.numpy as jnp
import jaxtyping
import numpy as np

__all__ = [
'Array',
'assert_same_shallow_tree',
'block_until_ready',
'convert_to_tensor',
'dataclass',
'diff',
'DType',
'flatten_tree',
'get_shallow_tree',
'get_static_value',
'inverse_fn',
'make_tensor_seed',
'map_tree',
@@ -406,7 +411,7 @@ def diff(x, prepend=None):
return jnp.diff(x, prepend=prepend)


def repeat(x, repeats, total_repeat_length=None):
def repeat(x, repeats, total_repeat_length):
"""Like jnp.repeat."""
return jnp.repeat(x, repeats, total_repeat_length=total_repeat_length)

@@ -436,3 +441,30 @@ def convert_to_tensor(x):
if x is None:
return x
return jnp.asarray(x)


T = TypeVar('T')


@dataclass_transform()
def dataclass(cls: T) -> T:
"""Create a tree-compatible dataclass."""
cls = dataclasses.dataclass(frozen=True)(cls)
fields = [f.name for f in dataclasses.fields(cls)]
jax.tree_util.register_dataclass(cls, fields, [])

def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)

cls.replace = replace

return cls


def get_static_value(x):
"""Returns the static value of x, or None if x is dynamic."""
try:
return np.array(x)
except TypeError:
return None
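
As a usage sketch of the two helpers added above (the ParticleState class and its field values are hypothetical, not part of this commit): the decorator registers a frozen dataclass as a JAX pytree, so instances flow through jax.tree_util transformations, replace gives a functional field update, and get_static_value distinguishes concrete from traced values.

import jax
import jax.numpy as jnp

# Assumes the `dataclass` and `get_static_value` helpers defined above are in
# scope (e.g. imported from the JAX backend's util module).


@dataclass
class ParticleState:  # Hypothetical example type.
  positions: jax.Array
  log_weights: jax.Array


state = ParticleState(
    positions=jnp.zeros([4, 2]),
    log_weights=jnp.full([4], -jnp.log(4.0)),
)

# Registered as a pytree: tree_map operates directly on the leaf arrays.
scaled = jax.tree_util.tree_map(lambda x: 2.0 * x, state)

# `replace` returns a new frozen instance with selected fields updated.
shifted = state.replace(log_weights=state.log_weights - 1.0)

# `get_static_value` returns a concrete NumPy value for non-traced inputs.
assert get_static_value(jnp.arange(3)) is not None


@jax.jit
def check_dynamic(x):
  # At trace time `x` is a tracer, so no static value is available.
  assert get_static_value(x) is None
  return x


check_dynamic(jnp.arange(3))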
8 changes: 7 additions & 1 deletion spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py
@@ -24,7 +24,13 @@
tnp = tf.experimental.numpy

_lax = types.ModuleType('lax')
_lax.cond = tf.cond


def cond(pred, true_fn, false_fn, *args):
return tf.cond(pred, lambda: true_fn(*args), lambda: false_fn(*args))


_lax.cond = cond
_lax.stop_gradient = tf.stop_gradient

_nn = types.ModuleType('nn')
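
A minimal sketch of the calling convention the wrapper above provides (the tensors are illustrative): operands are passed explicitly, as with jax.lax.cond, rather than closed over by the zero-argument callables that tf.cond expects.

import tensorflow as tf

# Assumes the `cond` wrapper defined above is in scope.
x = tf.constant(3.0)
y = tf.constant(4.0)

# Branches receive the trailing operands positionally; the wrapper adapts this
# to `tf.cond`'s zero-argument branch functions.
result = cond(x < y, lambda a, b: a + b, lambda a, b: a - b, x, y)  # 7.0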
66 changes: 62 additions & 4 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py
@@ -14,7 +14,9 @@
# ============================================================================
"""FunMC utilities implemented via TensorFlow."""

import dataclasses
import functools
from typing import TypeVar, dataclass_transform

import numpy as np
import six
@@ -29,10 +31,12 @@
'assert_same_shallow_tree',
'block_until_ready',
'convert_to_tensor',
'dataclass',
'diff',
'DType',
'flatten_tree',
'get_shallow_tree',
'get_static_value',
'inverse_fn',
'make_tensor_seed',
'map_tree',
@@ -419,10 +423,23 @@ def diff(x, prepend=None):

def repeat(x, repeats, total_repeat_length):
"""Like jnp.repeat."""
res = tf.repeat(x, repeats)
if total_repeat_length is not None:
res.set_shape([total_repeat_length] + [None] * (len(res.shape) - 1))
return res
# Implementation based on JAX, with some adjustments due to TF's stricter
# indexing validation.
exclusive_repeats = tf.concat([[0], repeats[:-1]], axis=0)
scatter_indices = tf.cumsum(exclusive_repeats)
scatter_indices = tf.where(
scatter_indices < total_repeat_length,
scatter_indices,
total_repeat_length,
)
block_split_indicators = tf.zeros([total_repeat_length + 1], tf.int32)
block_split_indicators = tf.tensor_scatter_nd_add(
block_split_indicators,
scatter_indices[..., tf.newaxis],
tf.ones_like(scatter_indices),
)
gather_indices = tf.cumsum(block_split_indicators[:-1]) - 1
return tf.gather(x, gather_indices)
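
For readers following the construction above, a NumPy walk-through on a small made-up input (not part of the commit) shows how the scatter/cumsum/gather steps reproduce a fixed-length repeat:

import numpy as np

x = np.array([10, 20, 30])
repeats = np.array([2, 0, 3])
total_repeat_length = 5

# Start position of each element's block in the output.
exclusive_repeats = np.concatenate([[0], repeats[:-1]])  # [0, 2, 0]
scatter_indices = np.minimum(
    np.cumsum(exclusive_repeats), total_repeat_length
)  # [0, 2, 2]

# Mark block boundaries; repeated indices accumulate (here index 2 gets 2).
block_split_indicators = np.zeros([total_repeat_length + 1], np.int64)
np.add.at(block_split_indicators, scatter_indices, 1)  # [1, 0, 2, 0, 0, 0]

# Running count of boundaries, minus one, gives the source index per slot.
gather_indices = np.cumsum(block_split_indicators[:-1]) - 1  # [0, 0, 2, 2, 2]
print(x[gather_indices])  # [10 10 30 30 30]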


def new_dynamic_array(shape, dtype, size):
@@ -454,3 +471,44 @@ def convert_to_tensor(x):
if isinstance(x, tf.TensorArray):
return x
return tf.convert_to_tensor(x)


T = TypeVar('T')


@dataclass_transform()
def dataclass(cls: T) -> T:
"""Create a tree-compatible dataclass."""
cls = dataclasses.dataclass(frozen=True)(cls)

def __tf_flatten__(self): # pylint: disable=invalid-name
metadata = ()
fields = dataclasses.fields(self)
components = tuple(getattr(self, f.name) for f in fields)
return metadata, components

@classmethod
def __tf_unflatten__(cls, metadata, leaves): # pylint: disable=invalid-name
del metadata
return cls(*leaves)

def __len__(self): # pylint: disable=invalid-name
# This is to work around a bug in TF's tree-prefix matching.
return len(dataclasses.fields(self))

cls.__tf_flatten__ = __tf_flatten__
cls.__tf_unflatten__ = __tf_unflatten__
cls.__len__ = __len__

def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)

cls.replace = replace

return cls


def get_static_value(x):
"""Returns the static value of x, or None if x is dynamic."""
return tf.get_static_value(x)
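
Mirroring the JAX sketch earlier, a brief illustration of the TensorFlow-side helpers added above (the ParticleState class and tensor values are hypothetical): the flatten/unflatten pair round-trips an instance through its component tensors, and replace performs a functional update.

import tensorflow as tf

# Assumes the `dataclass` and `get_static_value` helpers defined above are in
# scope (e.g. imported from the TensorFlow backend's util module).


@dataclass
class ParticleState:  # Hypothetical example type.
  positions: tf.Tensor
  log_weights: tf.Tensor


state = ParticleState(
    positions=tf.zeros([4, 2]),
    log_weights=tf.fill([4], -tf.math.log(4.0)),
)

# The __tf_flatten__/__tf_unflatten__ pair defined above round-trips the
# instance through its component tensors.
metadata, components = state.__tf_flatten__()
restored = ParticleState.__tf_unflatten__(metadata, components)

# `replace` returns a new frozen instance with selected fields updated.
shifted = state.replace(log_weights=state.log_weights - 1.0)

# Static values are available for eagerly-known tensors, None otherwise.
assert get_static_value(tf.constant([1, 2, 3])) is not None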
