diff --git a/docs/changelog.rst b/docs/changelog.rst index 516fec9e..f245e0fe 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,20 @@ Changelog ========= +Version 0.4.2 +------------- + +Bug fixes and enhancements +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Fix issue with positional arguments in :class:`jaxopt.LBFGS` and :class:`jaxopt.NonlinearCG`, + by Mathieu Blondel. + +Contributors +~~~~~~~~~~~~ + +Mathieu Blondel. + Version 0.4.1 ------------- diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index 301d7cf9..10fc3c81 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -228,7 +228,7 @@ def init_state(self, """ return LbfgsState(iter_num=jnp.asarray(0), value=jnp.asarray(jnp.inf), - stepsize=jnp.asarray(1.0), + stepsize=jnp.asarray(self.max_stepsize), error=jnp.asarray(jnp.inf), s_history=init_history(init_params, self.history_size), y_history=init_history(init_params, self.history_size), @@ -280,9 +280,9 @@ def update(self, self.max_stepsize, # Otherwise, we increase a bit the previous one. state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize, - params=params, value=value, grad=grad, - descent_direction=descent_direction, + new_stepsize, ls_state = ls.run(init_stepsize, + params, value, grad, + descent_direction, *args, **kwargs) new_value = ls_state.value new_params = ls_state.params diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 15e63668..11eed03e 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -63,6 +63,8 @@ class NonlinearCG(base.IterativeSolver): (default: 0.8). increase_factor: factor by which to increase the stepsize during line search (default: 1.2). + max_stepsize: upper bound on stepsize. + min_stepsize: lower bound on stepsize. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. @@ -87,6 +89,10 @@ class NonlinearCG(base.IterativeSolver): maxls: int = 15 decrease_factor: float = 0.8 increase_factor: float = 1.2 + max_stepsize: float = 1.0 + # FIXME: should depend on whether float32 or float64 is used. + min_stepsize: float = 1e-6 + implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -110,7 +116,7 @@ def init_state(self, value, grad = self._value_and_grad_fun(init_params, *args, **kwargs) return NonlinearCGState(iter_num=jnp.asarray(0), - stepsize=jnp.asarray(1.0), + stepsize=jnp.asarray(self.max_stepsize), error=jnp.asarray(jnp.inf), value=value, grad=grad, @@ -133,16 +139,24 @@ def update(self, eps = 1e-6 value, grad, descent_direction = state.value, state.grad, state.descent_direction - init_stepsize = state.stepsize * self.increase_factor ls = BacktrackingLineSearch(fun=self._value_and_grad_fun, value_and_grad=True, maxiter=self.maxls, decrease_factor=self.decrease_factor, - condition=self.condition) - new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize, - params=params, - value=value, - grad=grad, + condition=self.condition, + max_stepsize=self.max_stepsize) + + init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, + # If stepsize became too small, we restart it. + self.max_stepsize, + # Otherwise, we increase a bit the previous one. + state.stepsize * self.increase_factor) + + new_stepsize, ls_state = ls.run(init_stepsize, + params, + value, + grad, + None, # descent_direction *args, **kwargs) new_params = tree_add_scalar_mul(params, new_stepsize, descent_direction) diff --git a/jaxopt/version.py b/jaxopt/version.py index c9cfab72..e0b5b397 100644 --- a/jaxopt/version.py +++ b/jaxopt/version.py @@ -14,4 +14,4 @@ """JAXopt version.""" -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/tests/lbfgs_test.py b/tests/lbfgs_test.py index 0d64a2d2..55ee85a4 100644 --- a/tests/lbfgs_test.py +++ b/tests/lbfgs_test.py @@ -213,6 +213,7 @@ def test_binary_logreg(self, use_gamma): w_init = jnp.zeros(X.shape[1]) lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma) + # Test with keyword argument. w_fit, info = lbfgs.run(w_init, data=data) # Check optimality conditions. @@ -236,7 +237,8 @@ def test_multiclass_logreg(self, use_gamma): pytree_init = (W_init, b_init) lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma) - pytree_fit, info = lbfgs.run(pytree_init, data=data) + # Test with positional argument. + pytree_fit, info = lbfgs.run(pytree_init, data) # Check optimality conditions. self.assertLessEqual(info.error, 1e-2) diff --git a/tests/nonlinear_cg_test.py b/tests/nonlinear_cg_test.py index adf02c0b..84bbe2e6 100644 --- a/tests/nonlinear_cg_test.py +++ b/tests/nonlinear_cg_test.py @@ -81,7 +81,8 @@ def test_binary_logreg(self, method): w_init = jnp.zeros(X.shape[1]) cg_model = NonlinearCG(fun=fun, tol=1e-3, maxiter=100, method=method) - w_fit, info = cg_model.run(w_init, data=data) + # Test with positional argument. + w_fit, info = cg_model.run(w_init, data) # Check optimality conditions. self.assertLessEqual(info.error, 5e-2)