Releases: google-deepmind/tf2jax

tf2jax 0.3.7

08 Jan 18:16

What's Changed

  • Handle bad_indices_policy for GatherNd and ScatterNd. #195, #206
  • Support more numpy ops. #197
  • Support Python 3.11. #201
  • Drop support for Python 3.9 to align with JAX. #200
  • Add the MulNoNan op. #210
  • Allow optionally disabling the type-check assert in TensorListGetItem. #211
  • Update tf2jax to use new gather/scatter batching dims. #217, #221
  • Fix a bug in Slice for size==-1. #220
  • Fix handling of integer types during gradient computation with mhlo modules. #225
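
For context on the Slice entry (#220): in TensorFlow's Slice op, a size of -1 for an axis means "take every remaining element from the begin index onward". The sketch below only illustrates that convention in plain Python. `resolve_slice_sizes` is a hypothetical helper, not a tf2jax function, and it does not reproduce the original bug or its fix.

```python
def resolve_slice_sizes(shape, begin, size):
    """Return concrete per-axis slice sizes, expanding -1 entries.

    Hypothetical helper for illustration only (not part of tf2jax).
    """
    resolved = []
    for dim, start, length in zip(shape, begin, size):
        if length == -1:
            # -1 means: all remaining elements along this axis.
            length = dim - start
        if start < 0 or length < 0 or start + length > dim:
            raise ValueError(
                f"slice starting at {start} with size {length} "
                f"does not fit in an axis of length {dim}"
            )
        resolved.append(length)
    return resolved


sizes = resolve_slice_sizes(shape=(4, 6), begin=(1, 2), size=(2, -1))
print(sizes)  # [2, 4]
```

Resolving -1 up front is useful under JAX because slice sizes have to be static Python integers at trace time.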

Full Changelog: v0.3.6...v0.3.7

tf2jax 0.3.6

26 Mar 10:37

What's Changed

  • Add Cumprod to tf2jax ops. #149
  • Fix squeeze in tf2jax for empty tuple axis arguments. #150
  • Update expected attrs of VarHandleOp -- add debug_name. #154
  • Only enable shape assertions during refinement if all input shapes are static. #155
  • Refactor is_poly_dim to avoid depending on hard-coded module path. #156
  • Avoid reuse of polymorphic variable names in tf2jax internal. #158
  • Update deepmind -> google-deepmind. #159
  • Change conversion of library functions and gradient functions to be lazy. #162
  • Fix gradient function lookup error when higher order gradient functions are referenced in a jax2tf model but not serialized. #164
  • Enforce autograph=False on gradient function. #172
  • Update chex and haiku versions. #175
  • Fix shape inference for jax2tf vjp functions. #180
  • Support native serialization with multiple target platforms. #183

Full Changelog: v0.3.5...v0.3.6

tf2jax 0.3.5

28 Jul 13:23

What's Changed

  • Better error message for missing inputs. #131
  • Support sharding annotation. #132, #133
  • Support shape refinement with XlaCallModule. #140, #144, #145
  • Add LowerBound and UpperBound operations to support searchsorted. #141
  • Dropped support for Python 3.8. #140
  • Handle shape check asserts. #146
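
As background for the LowerBound and UpperBound entry: these ops perform a row-wise sorted search, with LowerBound matching a left-side insertion point and UpperBound a right-side one. A minimal sketch of those semantics using only the standard library follows. The function names `lower_bound` and `upper_bound` are hypothetical stand-ins, and this is not how tf2jax implements the ops.

```python
from bisect import bisect_left, bisect_right


def lower_bound(sorted_rows, value_rows):
    """For each row, index of the first element >= each query value."""
    return [
        [bisect_left(row, v) for v in values]
        for row, values in zip(sorted_rows, value_rows)
    ]


def upper_bound(sorted_rows, value_rows):
    """For each row, index of the first element > each query value."""
    return [
        [bisect_right(row, v) for v in values]
        for row, values in zip(sorted_rows, value_rows)
    ]


sorted_rows = [[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]]
value_rows = [[2, 4, 9], [0, 2, 6]]

print(lower_bound(sorted_rows, value_rows))  # [[1, 2, 2], [0, 1, 5]]
print(upper_bound(sorted_rows, value_rows))  # [[1, 2, 4], [0, 2, 5]]
```

The two results differ only where a query value is already present in the row, such as 9 in the first row and 2 in the second.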

Full Changelog: v0.3.4...v0.3.5

tf2jax 0.3.4

21 Apr 09:07

What's Changed

  • Add documentation regarding jax2tf native serialization and XlaCallModule. #123
  • Improving support for XlaCallModule. #97, #122, #118, #100
  • Support FFT, FFT2D, FFT3D, IFFT, IFFT2D, IFFT3D, RFFT2D, RFFT3D, IRFFT2D, IRFFT3D. #115
  • Support Bessel0e, Bessel1e and PopulationCount. #114
  • Support LeakyReluOp. #120
  • Support BatchToSpaceND, SpaceToBatchND. #126
  • Support MatrixSetDiagV3. #111
  • Handle VarHandleOp. #105
  • Add placeholder for XlaSharding. #103
  • Better error message for missing parameters. #110
  • Better handling of static arguments for ScatterNd. #112
  • Improved handling of non-tensor outputs. #98

Full Changelog: v0.3.3...v0.3.4

tf2jax 0.3.3

10 Feb 20:21

What's Changed

  • Make TF2JAX run on M1 Macs. #81
  • Use input_signature when available. #83
  • Preemptive fix for upcoming JAX changes. #85, #86
  • Preemptive fix for polymorphic shape support. #88
  • Re-enable native lowering test for nested custom gradient. #91
  • Support XlaReducePrecision. #93

Full Changelog: v0.3.2...v0.3.3

tf2jax 0.3.2

23 Jan 11:56

What's Changed

  • Support XlaOptimizationBarrier. #63
  • Support TensorListFromTensor, TensorListGetItem, TensorListSetItem, TensorListReserve, TensorListStack. #78, #79
  • Support DivNoNan, ResourceGather, ResizeNearestNeighbor, RFFT, IRFFT, Angle, StatelessRandomGetAlg, EnsureShape, CheckNumerics, Roll. #60, #61, #62, #64, #67, #68, #74
  • Support experimental XlaCallModule used by jax2tf. #70, #71
  • Add heuristic for inferring cumulative reductions and avoid potentially expensive reduce_window calls. #55
  • Fix an edge case in StridedSlice. #80
  • Dropped support for Python 3.7 (to match JAX). #73
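
For readers unfamiliar with DivNoNan from the list above: it is a "safe divide" that yields 0 wherever the divisor is 0, even when the numerator is NaN or infinite. The scalar sketch below illustrates only that rule. `div_no_nan` is a hypothetical name used for illustration, not a tf2jax API.

```python
import math


def div_no_nan(x, y):
    """Divide x by y, returning 0.0 whenever the divisor is zero."""
    # The zero check comes first, so a NaN or infinite numerator
    # still produces 0.0 when y == 0.
    return 0.0 if y == 0 else x / y


print(div_no_nan(6.0, 3.0))       # 2.0
print(div_no_nan(1.0, 0.0))       # 0.0
print(div_no_nan(math.nan, 0.0))  # 0.0
```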

Full Changelog: v0.3.1...v0.3.2

tf2jax 0.3.1

02 Aug 14:23

What's Changed

  • Improved documentation. #49, #50
  • Improved handling of input specs. #48
  • Support Cholesky, Eig, SelfAdjointEigV2, SVD, QR, MatrixTriangularSolve. #45, #46, #47, #51
  • Support UnsortedSegmentMax, UnsortedSegmentMin, UnsortedSegmentProd, UnsortedSegmentSum and TensorScatterUpdate. #52

Full Changelog: v0.3.0...v0.3.1

tf2jax 0.3.0

21 Jun 13:47

What's Changed

  • Returns a callable AnnotatedFunction instead of a plain function. #35
  • Add missing dtype mapping for tf.bfloat16. #27
  • Improve handling of captured arguments for StatelessWhile. #31
  • Minor fix to tf_hub documentation. #34
  • Allow PreventGradient to log a warning instead of raising an error. #36
  • Fix handling tf.Module with non-unique variable names. #40
  • Support VariableV2 as placeholder. #30
  • Support FusedBatchNorm. #32
  • Support Rank. #39

Full Changelog: v0.2.1...v0.3.0

tf2jax 0.2.1

06 May 15:36

What's Changed

  • Fix broadcasting in BiasAdd. by @copybara-service in #24

Full Changelog: v0.2.0...v0.2.1

tf2jax 0.2.0

03 May 14:18

What's Changed

  • Changed required tensorflow version from nightly to 2.8.0. by @copybara-service in #8
  • Enable custom_gradient support by default. by @copybara-service in #9
  • Support NCHW in convolution, pooling and BiasAdd operations. by @copybara-service in #13 and #14
  • Fix a bug with shrink axis in StridedSlice. by @copybara-service in #17
  • Support bool in MatrixDiagV3. by @copybara-service in #16
  • Support TopKV2. by @copybara-service in #15
  • Support XlaRngBitGenerator. by @copybara-service in #19

Full Changelog: v0.1.1...v0.2.0