Skip to content

Commit

Permalink
no grad eval in armijo/goldstein linesearch
Browse files Browse the repository at this point in the history
  • Loading branch information
vroulet committed Sep 25, 2023
1 parent a86b275 commit c428455
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 142 deletions.
171 changes: 98 additions & 73 deletions jaxopt/_src/backtracking_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,34 @@
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src.cond import cond
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_conj
from jaxopt.tree_util import tree_scalar_mul
from jaxopt.tree_util import tree_vdot_real
from jaxopt.tree_util import tree_conj


class BacktrackingLineSearchState(NamedTuple):
"""Named tuple containing state information."""
iter_num: int
params: Any
value: float
grad: Any # either initial or final for armijo or glodstein
value_init: float
grad_init: Any
error: float
done: bool
params: Any
grad: Any
failed: bool
num_fun_eval: int
num_grad_eval: int
failed: bool
aux: Optional[Any] = None


Expand All @@ -56,7 +61,8 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
``*args`` and ``**kwargs`` are additional arguments.
value_and_grad: if ``False``, ``fun`` should return the function value only.
If ``True``, ``fun`` should return both the function value and the
gradient.
gradient. If it is a Callable, fun should return the value while value_and_grad
returns value and gradient of the objective.
has_aux: if ``False``, ``fun`` should return the function value only.
If ``True``, ``fun`` should return a pair ``(value, aux)`` where ``aux``
is a pytree of auxiliary values.
Expand All @@ -77,7 +83,7 @@ class BacktrackingLineSearch(base.IterativeLineSearch):
unroll: whether to unroll the optimization loop (default: "auto").
"""
fun: Callable
value_and_grad: bool = False
value_and_grad: Union[bool, Callable] = False
has_aux: bool = False

maxiter: int = 30
Expand Down Expand Up @@ -117,30 +123,30 @@ def init_state(
state
"""
del descent_direction # Not used.
del init_stepsize # Not used.

num_fun_eval = 0
num_grad_eval = 0
num_fun_eval = jnp.asarray(0, base.NUM_EVAL_DTYPE)
num_grad_eval = jnp.asarray(0, base.NUM_EVAL_DTYPE)
if value is None or grad is None:
if self.has_aux:
(value, _), grad = self._value_and_grad_fun(
params, *fun_args, **fun_kwargs
)
else:
value, grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
(value, _), grad = self._value_and_grad_fun_with_aux(
params, *fun_args, **fun_kwargs
)
num_fun_eval += 1
num_grad_eval += 1

return BacktrackingLineSearchState(iter_num=jnp.asarray(0),
params=params,
value=value,
grad=grad,
value_init=value,
grad_init=grad,
aux=None, # we do not need to have aux
# in the initial state
error=jnp.asarray(jnp.inf),
params=params,
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval,
done=jnp.asarray(False),
grad=grad,
failed=jnp.asarray(False))
failed=jnp.asarray(False),
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval)

def update(
self,
Expand All @@ -159,8 +165,8 @@ def update(
stepsize: current estimate of the step size.
state: named tuple containing the line search state.
params: current parameters.
value: current function value (recomputed if None).
grad: current gradient (recomputed if None).
value: current function value (computed at initialization, unused here).
grad: current gradient (computed at initialization, unused here).
descent_direction: descent direction (negative gradient if None).
fun_args: additional positional arguments to be passed to ``fun``.
fun_kwargs: additional keyword arguments to be passed to ``fun``.
Expand All @@ -172,43 +178,45 @@ def update(
num_fun_eval = state.num_fun_eval
num_grad_eval = state.num_grad_eval

if value is None or grad is None:
if self.has_aux:
(value, _), grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
else:
value, grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
num_fun_eval += 1
num_grad_eval += 1
# Grab value and grad from initialization and avoid recomputing them
del value
del grad
value = state.value_init
grad = state.grad_init

if descent_direction is None:
descent_direction = tree_scalar_mul(-1, tree_conj(grad))

gd_vdot = tree_vdot_real(tree_conj(grad), descent_direction)
slope = tree_vdot_real(tree_conj(grad), descent_direction)

# For backtracking linesearches, we want to compute the next point
# from the basepoint. i.e. x_i = x_0 + s_i * p
new_params = tree_add_scalar_mul(params, stepsize, descent_direction)
if self.has_aux:
(new_value, new_aux), new_grad = self._value_and_grad_fun(
# Every condition requires the new function value, but not every one
# requires the new gradient value (we'll assume that this code is called
# under `jit`).
num_fun_eval += 1
if self.condition in ["armijo", "goldstein"]:
new_value, new_aux = self._fun_with_aux(
new_params, *fun_args, **fun_kwargs
)
# For those conditions, no need to compute a new grad
# We recompute a new_grad only once we have found the right stepsize,
# see below
new_grad, new_slope = grad, slope
else:
new_value, new_grad = self._value_and_grad_fun(
(new_value, new_aux), new_grad = self._value_and_grad_fun_with_aux(
new_params, *fun_args, **fun_kwargs
)
new_aux = None
new_gd_vdot = tree_vdot_real(tree_conj(new_grad), descent_direction)

# Every condition requires the new function value, but not every one
# requires the new gradient value (we'll assume that this code is called
# under `jit`).
num_fun_eval += 1
new_slope = tree_vdot_real(tree_conj(new_grad), descent_direction)
num_grad_eval += 1

# Armijo condition (upper bound on admissible step size).
# cond1 = new_value <= value + self.c1 * stepsize * gd_vdot
# cond1 = new_value <= value + self.c1 * stepsize * slope
# See equation (3.6a), Numerical Optimization, Second edition.
diff_cond1 = new_value - (value + self.c1 * stepsize * gd_vdot)
error_cond1 = jnp.maximum(diff_cond1, 0.0)

diff_cond1 = new_value - (value + self.c1 * stepsize * slope)
error_cond1 = jnp.where(jnp.isnan(diff_cond1), jnp.inf, diff_cond1)
error_cond1 = jnp.maximum(error_cond1, 0.0)
error = error_cond1

if self.condition == "armijo":
Expand All @@ -217,54 +225,71 @@ def update(
pass

elif self.condition == "strong-wolfe":
# cond2 = abs(new_gd_vdot) <= c2 * abs(gd_vdot)
# cond2 = abs(new_slope) <= c2 * abs(slope)
# See equation (3.7b), Numerical Optimization, Second edition.
diff_cond2 = jnp.abs(new_gd_vdot) - self.c2 * jnp.abs(gd_vdot)
error_cond2 = jnp.maximum(diff_cond2, 0.0)
diff_cond2 = jnp.abs(new_slope) - self.c2 * jnp.abs(slope)
error_cond2 = jnp.where(jnp.isnan(diff_cond2), jnp.inf, diff_cond2)
error_cond2 = jnp.maximum(error_cond2, 0.0)
error = jnp.maximum(error_cond1, error_cond2)
num_grad_eval += 1

elif self.condition == "wolfe":
# cond2 = new_gd_vdot >= c2 * gd_vdot
# cond2 = new_slope >= c2 * slope
# See equation (3.6b), Numerical Optimization, Second edition.
diff_cond2 = self.c2 * gd_vdot - new_gd_vdot
error_cond2 = jnp.maximum(diff_cond2, 0.0)
diff_cond2 = self.c2 * slope - new_slope
error_cond2 = jnp.where(jnp.isnan(diff_cond2), jnp.inf, diff_cond2)
error_cond2 = jnp.maximum(error_cond2, 0.0)
error = jnp.maximum(error_cond1, error_cond2)
num_grad_eval += 1

elif self.condition == "goldstein":
# cond2 = new_value >= value + (1 - self.c1) * stepsize * gd_vdot
diff_cond2 = value + (1 - self.c1) * stepsize * gd_vdot - new_value
error_cond2 = jnp.maximum(diff_cond2, 0.0)
# cond2 = new_value >= value + (1 - self.c1) * stepsize * slope
diff_cond2 = value + (1 - self.c1) * stepsize * slope - new_value
error_cond2 = jnp.where(jnp.isnan(diff_cond2), jnp.inf, diff_cond2)
error_cond2 = jnp.maximum(error_cond2, 0.0)
error = jnp.maximum(error_cond1, error_cond2)

else:
raise ValueError("condition should be one of "
"'armijo', 'goldstein', 'strong-wolfe' or 'wolfe'.")

new_stepsize = jnp.where(error <= self.tol,
stepsize,
stepsize * self.decrease_factor)
done = state.done | (error <= self.tol)
failed = state.failed | ((state.iter_num + 1 == self.maxiter) & ~done)

new_state = BacktrackingLineSearchState(iter_num=state.iter_num + 1,
value=new_value,
aux=new_aux,
grad=new_grad,
params=new_params,
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval,
done=done,
error=error,
failed=failed)
new_stepsize = jnp.where(done | failed,
stepsize,
stepsize * self.decrease_factor)

if self.condition in ["armijo", "goldstein"]:
# If we are done for the armijo or the goldstein conditions,
# we compute the final gradient (we had not computed it before since
# these conditions did not require it)
new_grad = cond(done | failed,
self._compute_final_grad,
lambda *_: grad,
new_params, fun_args, fun_kwargs,
jit=self.jit)
maybe_additional_eval = jnp.asarray(done | failed, dtype=base.NUM_EVAL_DTYPE)
num_grad_eval = num_grad_eval + maybe_additional_eval
# We a priori always access the function value when computing the gradient
num_fun_eval = num_fun_eval + maybe_additional_eval

new_state = state._replace(iter_num=state.iter_num + 1,
params=new_params,
value=new_value,
grad=new_grad,
aux=new_aux,
done=done,
error=error,
failed=failed,
num_fun_eval=num_fun_eval,
num_grad_eval=num_grad_eval)

return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

def _compute_final_grad(self, params, fun_args, fun_kwargs):
return self._grad_with_aux(params, *fun_args, **fun_kwargs)[0]

def __post_init__(self):
if self.value_and_grad:
self._value_and_grad_fun = self.fun
else:
self._value_and_grad_fun = jax.value_and_grad(
self.fun, has_aux=self.has_aux
)
self._fun_with_aux, self._grad_with_aux, self._value_and_grad_fun_with_aux = \
base._make_funs_with_aux(fun=self.fun,
value_and_grad=self.value_and_grad,
has_aux=self.has_aux)
2 changes: 1 addition & 1 deletion jaxopt/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _cond_fun(self, inputs):
_, state = inputs[0]
if self.verbose:
name = self.__class__.__name__
jax.debug.print("Sovler: %s, Error: {error}" % name, error=state.error)
jax.debug.print("Solver: %s, Error: {error}" % name, error=state.error)
return state.error > self.tol

def _body_fun(self, inputs):
Expand Down
6 changes: 3 additions & 3 deletions jaxopt/_src/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _value_and_grad_fun(self, params, *args, **kwargs):
def __post_init__(self):
super().__post_init__()

_, _, self._value_and_grad_with_aux = \
_fun_with_aux, _, self._value_and_grad_with_aux = \
base._make_funs_with_aux(fun=self.fun,
value_and_grad=self.value_and_grad,
has_aux=self.has_aux)
Expand All @@ -280,8 +280,8 @@ def __post_init__(self):
unroll = self._get_unroll_option()
self.linesearch_solver = _setup_linesearch(
linesearch=self.linesearch,
fun=self._value_and_grad_with_aux,
value_and_grad=True,
fun=_fun_with_aux,
value_and_grad=self._value_and_grad_with_aux,
has_aux=True,
maxlsiter=self.maxls,
max_stepsize=self.max_stepsize,
Expand Down
Loading

0 comments on commit c428455

Please sign in to comment.