Releases: google-deepmind/tf2jax
tf2jax 0.3.7
What's Changed
- Handle bad_indices_policy for GatherNd and ScatterNd. #195
- Support more numpy ops. #197
- Support Python 3.11. #201
- Drop support for Python 3.9 to align with JAX. #200
- Handle bad_indices_policy for GatherNd and ScatterNd. #206
- Add the MulNoNan op. #210
- Allow optionally disabling the type-check assert in TensorListGetItem. #211
- Update tf2jax to use new gather/scatter batching dims. #217, #221
- Fix a bug in Slice for size==-1. #220
- Fix handling of integer types during gradient computation with mhlo modules. #225
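
For context on the Slice fix above: in TensorFlow's Slice op, a size of -1 in a dimension means "all remaining elements from begin onward". The following is a minimal pure-Python sketch of that rule, not the tf2jax implementation; the helper names `resolve_slice_sizes` and `slice_1d` are hypothetical and used only for illustration.

```python
def resolve_slice_sizes(shape, begin, size):
    """Replace each -1 in `size` with the number of remaining elements."""
    resolved = []
    for dim, start, length in zip(shape, begin, size):
        if length == -1:
            # -1 means: take everything from `start` to the end of the dimension.
            length = dim - start
        if start < 0 or length < 0 or start + length > dim:
            raise ValueError(
                f"slice starting at {start} with size {length} "
                f"is out of bounds for a dimension of size {dim}"
            )
        resolved.append(length)
    return resolved


def slice_1d(values, begin, size):
    """Slice a flat list using TensorFlow-style (begin, size) arguments."""
    (length,) = resolve_slice_sizes([len(values)], [begin], [size])
    return values[begin:begin + length]
```

For example, `resolve_slice_sizes([4, 6], [1, 2], [-1, 3])` resolves the first size to 3, because three elements remain after index 1 in a dimension of size 4.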
Full Changelog: v0.3.6...v0.3.7
tf2jax 0.3.6
What's Changed
- Add Cumprod to tf2jax ops. #149
- Fix squeeze in tf2jax for empty tuple axis arguments. #150
- Update expected attrs of VarHandleOp -- add debug_name. #154
- Only enable shape assertions during refinement if all input shapes are static. #155
- Refactor is_poly_dim to avoid depending on hard-coded module path. #156
- Avoid reuse of polymorphic variable names in tf2jax internal. #158
- Update deepmind -> google-deepmind. #159
- Change conversion of library functions and gradient functions to be lazy. #162
- Fix gradient function lookup error when higher order gradient functions are referenced in a jax2tf model but not serialized. #164
- Enforce autograph=False on gradient function. #172
- Update chex and haiku versions. #175
- Fix shape inference for jax2tf vjp functions. #180
- Support native serialization with multiple target platforms. #183
Full Changelog: v0.3.5...v0.3.6
tf2jax 0.3.5
What's Changed
- Better error message for missing inputs. #131
- Support sharding annotation. #132, #133
- Support shape refinement with XlaCallModule. #140, #144, #145
- Add LowerBound and UpperBound operations to tf2jax to support searchsorted. #141
- Dropped support for Python 3.8. #140
- Handle shape check asserts. #146
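
As background for the LowerBound and UpperBound entry above: the two ops correspond to the "left" and "right" sides of a sorted search. The sketch below illustrates the expected row-wise semantics using Python's standard `bisect` module. It is an illustration under that assumption, not the code tf2jax uses, and the function names `lower_bound` and `upper_bound` are chosen here for readability.

```python
from bisect import bisect_left, bisect_right


def lower_bound(sorted_rows, value_rows):
    """For each value, index of the first element in its row that is >= value."""
    return [
        [bisect_left(row, v) for v in values]
        for row, values in zip(sorted_rows, value_rows)
    ]


def upper_bound(sorted_rows, value_rows):
    """For each value, index of the first element in its row that is > value."""
    return [
        [bisect_right(row, v) for v in values]
        for row, values in zip(sorted_rows, value_rows)
    ]
```

The two results differ only when a searched value is already present in the row: `lower_bound` returns the position of its first occurrence, while `upper_bound` returns the position just past its last occurrence.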
Full Changelog: v0.3.4...v0.3.5
tf2jax 0.3.4
What's Changed
- Add documentation regarding jax2tf native serialization and XlaCallModule. #123
- Improving support for XlaCallModule. #97, #122, #118, #100
- Support FFT, FFT2D, FFT3D, IFFT, IFFT2D, IFFT3D, RFFT2D, RFFT3D, IRFFT2D, IRFFT3D. #115
- Support Bessel0e, Bessel1e and PopulationCount. #114
- Support LeakyReluOp. #120
- Support BatchToSpaceND, SpaceToBatchND. #126
- Support MatrixSetDiagV3. #111
- Handle VarHandleOp. #105
- Add placeholder for XlaSharding. #103
- Better error message for missing parameters. #110
- Better handling of static arguments for ScatterND. #112
- Improved handling of non-tensor outputs. #98
Full Changelog: v0.3.3...v0.3.4
tf2jax 0.3.3
What's Changed
- Make TF2JAX run on M1 macs. #81
- Use input_signature when available. #83
- Preemptive fix for upcoming JAX changes. #85, #86
- Preemptive fix for polymorphic shape support. #88
- Re-enable native lowering test for nested custom gradient. #91
- Support XlaReducePrecision. #93
Full Changelog: v0.3.2...v0.3.3
tf2jax 0.3.2
What's Changed
- Support XlaOptimizationBarrier. #63
- Support TensorListFromTensor, TensorListGetItem, TensorListSetItem, TensorListReserve, TensorListStack. #78, #79
- Support DivNoNan, ResourceGather, ResizeNearestNeighbor, RFFT, IRFFT, Angle, StatelessRandomGetAlg, EnsureShape, CheckNumerics, Roll. #60, #61, #62, #64, #67, #68, #74
- Support experimental XlaCallModule used by jax2tf. #70, #71
- Add heuristic for inferring cumulative reductions and avoid potentially expensive reduce_window calls. #55
- Fix an edge case in StridedSlice. #80
- Dropped support for Python 3.7 (to match JAX). #73
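
To illustrate one of the ops listed above, Roll shifts elements along an axis and wraps the ones that fall off the end back to the start. The following one-dimensional pure-Python sketch shows that behaviour; the helper name `roll_1d` is hypothetical and the code is not taken from tf2jax.

```python
def roll_1d(values, shift):
    """Shift elements towards higher indices by `shift`, wrapping around."""
    values = list(values)
    n = len(values)
    if n == 0:
        return []
    k = shift % n  # normalise negative and oversized shifts
    if k == 0:
        return values
    # The last k elements wrap around to the front.
    return values[-k:] + values[:-k]
```

A negative shift moves elements towards lower indices, so `roll_1d([0, 1, 2, 3, 4], -1)` moves the first element to the end.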
Full Changelog: v0.3.1...v0.3.2
tf2jax 0.3.1
What's Changed
- Improved documentation. #49, #50
- Improved handling of input specs. #48
- Support Cholesky, Eig, SelfAdjointEigV2, SVD, QR, MatrixTriangularSolve. #45, #46, #47, #51
- Support UnsortedSegmentMax, UnsortedSegmentMin, UnsortedSegmentProd, UnsortedSegmentSum and TensorScatterUpdate. #52
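
For readers unfamiliar with the UnsortedSegment ops above: each input element is assigned to an output slot by its segment id, and elements sharing an id are combined with the reduction. The sketch below shows the sum and product variants in pure Python, assuming the TensorFlow convention that negative segment ids are dropped. The function names mirror the op names for readability, and `_unsorted_segment_reduce` is a hypothetical helper; none of this is the tf2jax implementation.

```python
import operator


def _unsorted_segment_reduce(data, segment_ids, num_segments, op, init):
    """Combine data[j] into out[segment_ids[j]] using the binary function `op`."""
    out = [init] * num_segments
    for value, seg in zip(data, segment_ids):
        if seg < 0:
            continue  # assumption: negative ids are dropped, as in TensorFlow
        out[seg] = op(out[seg], value)  # ids >= num_segments raise IndexError
    return out


def unsorted_segment_sum(data, segment_ids, num_segments):
    """Empty segments keep the additive identity 0."""
    return _unsorted_segment_reduce(data, segment_ids, num_segments, operator.add, 0)


def unsorted_segment_prod(data, segment_ids, num_segments):
    """Empty segments keep the multiplicative identity 1."""
    return _unsorted_segment_reduce(data, segment_ids, num_segments, operator.mul, 1)
```

Unlike the sorted segment ops, the ids do not need to be in increasing order, and a segment that receives no elements keeps the identity value of the reduction.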
Full Changelog: v0.3.0...v0.3.1
tf2jax 0.3.0
What's Changed
- Returns a callable AnnotatedFunction instead of a plain function. #35
- Add missing dtype mapping for tf.bfloat16. #27
- Improve handling of captured arguments for StatelessWhile. #31
- Minor fix to tf_hub documentation. #34
- Allow PreventGradient to log a warning instead of raising an error. #36
- Fix handling of tf.Module with non-unique variable names. #40
- Support VariableV2 as placeholder. #30
- Support FusedBatchNorm. #32
- Support Rank. #39
Full Changelog: v0.2.1...v0.3.0
tf2jax 0.2.1
What's Changed
- Fix broadcasting in BiasAdd. by @copybara-service in #24
Full Changelog: v0.2.0...v0.2.1
tf2jax 0.2.0
What's Changed
- Changed required tensorflow version from nightly to 2.8.0. by @copybara-service in #8
- Enable custom_gradient support by default. by @copybara-service in #9
- Support NCHW in convolution, pooling and BiasAdd operations. by @copybara-service in #13 and #14
- Fix a bug with shrink axis in StridedSlice. by @copybara-service in #17
- Support bool in MatrixDiagV3. by @copybara-service in #16
- Support TopKV2. by @copybara-service in #15
- Support XlaRngBitGenerator. by @copybara-service in #19
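
As a reference for the TopKV2 entry above: the op returns the k largest values together with their indices, in descending order, with ties resolved in favour of the lower index. The one-dimensional sketch below illustrates those semantics for numeric inputs; the helper name `top_k_1d` is hypothetical and the code is not the tf2jax implementation.

```python
def top_k_1d(values, k):
    """Return (top_values, top_indices) for the k largest numeric entries."""
    # Sort indices by descending value; equal values keep the lower index first.
    order = sorted(range(len(values)), key=lambda i: (-values[i], i))[:k]
    return [values[i] for i in order], order
```

For instance, with the input `[1, 5, 3, 5]` and `k=2`, both returned values are 5 and the indices come back as `[1, 3]`.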
Full Changelog: v0.1.1...v0.2.0