From 730b5a6e60e4ae72924d56764bfe851d2274a566 Mon Sep 17 00:00:00 2001
From: Mathieu Blondel
Date: Thu, 9 Feb 2023 16:02:32 +0100
Subject: [PATCH] Release v0.6.

---
 docs/api.rst                |  2 +-
 docs/changelog.rst          | 38 ++++++++++++++++++++++++++++++++++++++
 docs/line_search.rst        |  1 +
 docs/objective_and_loss.rst | 10 ++++++++--
 jaxopt/_src/loss.py         | 34 ++++++++++++++++++----------------
 jaxopt/version.py           |  2 +-
 6 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/docs/api.rst b/docs/api.rst
index bc8adf55..c5de5d18 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -135,7 +135,7 @@ Line search
    :toctree: _autosummary
 
    jaxopt.BacktrackingLineSearch
-
+   jaxopt.HagerZhangLineSearch
 
 Perturbed optimizers
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 1b3472ab..42143e87 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -1,6 +1,44 @@
 Changelog
 =========
 
+Version 0.6
+-----------
+
+New features
+~~~~~~~~~~~~
+
+- Added new Hager-Zhang linesearch in LBFGS, by Srinivas Vasudevan (code review by Emily Fertig).
+- Added perceptron and hinge losses, by Quentin Berthet.
+- Added binary sparsemax loss, sparse_plus and sparse_sigmoid, by Vincent Roulet.
+- Added isotonic regression, by Michael Sander.
+
+Bug fixes and enhancements
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- Added TPU support to notebooks, by Ayush Shridhar.
+- Allowed users to restart from a previous optimizer state in LBFGS, by Zaccharie Ramzi.
+- Added faster error computation in gradient descent algorithm, by Zaccharie Ramzi.
+- Got rid of extra function call in BFGS and LBFGS, by Zaccharie Ramzi.
+- Improved dtype consistency between input and output of update method, by Mathieu Blondel.
+- Added perturbed optimizers notebook and narrative documentation, by Quentin Berthet and Fabian Pedregosa.
+- Enabled auxiliary value returned by linesearch methods, by Zaccharie Ramzi.
+- Added distributed examples to the website, by Fabian Pedregosa.
+- Added custom loop pjit example, by Felipe Llinares.
+- Fixed wrong LaTeX in maml.ipynb, by Fabian Pedregosa.
+- Fixed bug in backtracking line search, by Srinivas Vasudevan (code review by Emily Fertig).
+- Added pylintrc to top level directory, by Fabian Pedregosa.
+- Corrected the condition function in LBFGS, by Zaccharie Ramzi.
+- Added custom loop pmap example, by Felipe Llinares.
+- Fixed pytree support in IterativeRefinement, by Louis Béthune.
+- Fixed has_aux support in ArmijoSGD, by Louis Béthune.
+- Documentation improvements, by Fabian Pedregosa and Mathieu Blondel.
+
+Contributors
+~~~~~~~~~~~~
+
+Ayush Shridhar, Fabian Pedregosa, Felipe Llinares, Louis Béthune,
+Mathieu Blondel, Michael Sander, Quentin Berthet, Srinivas Vasudevan, Vincent Roulet, Zaccharie Ramzi.
+
 Version 0.5.5
 -------------
 
diff --git a/docs/line_search.rst b/docs/line_search.rst
index 44a3a7b1..ef322292 100644
--- a/docs/line_search.rst
+++ b/docs/line_search.rst
@@ -51,6 +51,7 @@ Algorithms
    :toctree: _autosummary
 
    jaxopt.BacktrackingLineSearch
+   jaxopt.HagerZhangLineSearch
 
 The :class:`BacktrackingLineSearch <jaxopt.BacktrackingLineSearch>` algorithm
 iteratively reduces the step size by some decrease factor until the conditions
diff --git a/docs/objective_and_loss.rst b/docs/objective_and_loss.rst
index 76d55cf3..7cbe1f8f 100644
--- a/docs/objective_and_loss.rst
+++ b/docs/objective_and_loss.rst
@@ -29,6 +29,14 @@ Binary classification
 
 Binary classification losses are of the form ``loss(int: label, float: score) -> float``,
 where ``label`` is the ground-truth (``0`` or ``1``) and ``score`` is the model's output.
 
+The following utility functions are useful for the binary sparsemax loss.
+
+.. autosummary::
+   :toctree: _autosummary
+
+   jaxopt.loss.sparse_plus
+   jaxopt.loss.sparse_sigmoid
+
 Multiclass classification
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -79,5 +87,3 @@ Other functions
    jaxopt.objective.multiclass_logreg_with_intercept
    jaxopt.objective.l2_multiclass_logreg
    jaxopt.objective.l2_multiclass_logreg_with_intercept
-   jaxopt.loss.sparse_plus
-   jaxopt.loss.sparse_sigmoid
diff --git a/jaxopt/_src/loss.py b/jaxopt/_src/loss.py
index eea660e3..f1fb9422 100644
--- a/jaxopt/_src/loss.py
+++ b/jaxopt/_src/loss.py
@@ -74,58 +74,60 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
     loss value
 
   References:
-    Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins, 
+    Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
     Vlad Niculae. JMLR 2020. (Sec. 4.4)
   """
   return sparse_plus(jnp.where(label, -logit, logit))
 
 
 def sparse_plus(x: float) -> float:
-  """Sparse plus function.
+  r"""Sparse plus function.
 
   Computes the function:
 
-  .. math:
-    \mathrm{sparseplus}(x) = \begin{cases}
+  .. math::
+
+    \mathrm{sparse\_plus}(x) = \begin{cases}
     0, & x \leq -1\\
-    \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+    \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
     x, & 1 \leq x
     \end{cases}
 
-  This is the twin function of the softplus activation ensuring a zero output 
-  for inputs less than -1 and a linear output for inputs greater than 1, 
-  while remaining smooth, convex, monotonic by an adequate definition between 
+  This is the twin function of the softplus activation ensuring a zero output
+  for inputs less than -1 and a linear output for inputs greater than 1,
+  while remaining smooth, convex and monotonic thanks to a quadratic definition between
   -1 and 1.
 
   Args:
     x: input (float)
   Returns:
-    sparseplus(x) as defined above
+    sparse_plus(x) as defined above
   """
   return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
 
 
 def sparse_sigmoid(x: float) -> float:
-  """Sparse sigmoid function.
+  r"""Sparse sigmoid function.
+
+  Computes the function:
 
-  Computes the function:
+  .. math::
 
-  .. math:
-    \mathrm{sparsesigmoid}(x) = \begin{cases}
+    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
     0, & x \leq -1\\
-    \frac{1}{2}(x+1), & -1 < x < 1 \\
+    \frac{1}{2}(x+1), & -1 < x < 1 \\
     1, & 1 \leq x
     \end{cases}
 
   This is the twin function of the sigmoid activation ensuring a zero output
   for inputs less than -1, a 1 output for inputs greater than 1, and a linear
-  output for inputs between -1 and 1. This is the derivative of the sparse 
+  output for inputs between -1 and 1. This is the derivative of the sparse
   plus function.
 
   Args:
     x: input (float)
   Returns:
-    sparsesigmoid(x) as defined above
+    sparse_sigmoid(x) as defined above
   """
   return 0.5 * projection_hypercube(x + 1.0, 2.0)
diff --git a/jaxopt/version.py b/jaxopt/version.py
index 37e5adb3..ec230c26 100644
--- a/jaxopt/version.py
+++ b/jaxopt/version.py
@@ -14,4 +14,4 @@
 
 """JAXopt version."""
 
-__version__ = "0.5.5"
+__version__ = "0.6"
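
As a usage sketch (not part of the patch), the new loss utilities can be exercised as follows. It assumes jaxopt 0.6 is installed and that binary_sparsemax_loss is exported from jaxopt.loss alongside sparse_plus and sparse_sigmoid, as the documentation changes above suggest; the labels and logits are illustrative values only.

# Usage sketch: sparse_plus / sparse_sigmoid and the binary sparsemax loss
# added in this release. Assumes jaxopt >= 0.6 and that binary_sparsemax_loss
# is available under jaxopt.loss.
import jax
import jax.numpy as jnp
from jaxopt import loss

x = jnp.linspace(-2.0, 2.0, 5)

# sparse_plus: 0 for x <= -1, (x + 1)^2 / 4 on (-1, 1), x for x >= 1.
print(loss.sparse_plus(x))

# sparse_sigmoid is the derivative of sparse_plus: 0, (x + 1) / 2, then 1.
grads = jax.vmap(jax.grad(loss.sparse_plus))(x)
print(jnp.allclose(grads, loss.sparse_sigmoid(x)))  # expected: True

# Per-example binary sparsemax loss, vectorized over a small batch.
labels = jnp.array([0, 1, 1, 0])
logits = jnp.array([-1.5, 0.3, 2.0, 0.7])
print(jax.vmap(loss.binary_sparsemax_loss)(labels, logits))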
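
The changelog also adds the Hager-Zhang line search to LBFGS. Below is a minimal sketch of selecting it on a small least-squares problem; the option string "hager-zhang" and the example problem are assumptions, so check the jaxopt.LBFGS docstring of the installed version for the exact accepted values.

# Sketch only: least squares solved with LBFGS using the Hager-Zhang line
# search added in v0.6. linesearch="hager-zhang" is an assumption based on
# the changelog; jaxopt.HagerZhangLineSearch can also be used directly.
import jax.numpy as jnp
from jaxopt import LBFGS

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 2.0, 3.0])

def least_squares(w):
  residual = A @ w - b
  return jnp.sum(residual ** 2)

solver = LBFGS(fun=least_squares, linesearch="hager-zhang", maxiter=100)
params, state = solver.run(jnp.zeros(2))
print(params)  # approximately (0.0, 0.5), the least-squares solution of A w = b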