Skip to content

Commit

Permalink
Merge pull request #541 from vroulet:improved_lbfgs_zoom
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570802016
  • Loading branch information
JAXopt authors committed Oct 4, 2023
2 parents b23c7a5 + 614dc7b commit 0c032fd
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 100 deletions.
2 changes: 1 addition & 1 deletion jaxopt/_src/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class BFGS(base.IterativeSolver):
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
min_stepsize: lower bound on stepsize guess at start of the linesearch run.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/broyden.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class Broyden(base.IterativeSolver):
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
min_stepsize: lower bound on stepsize guess at start of the linesearch run.
history_size: size of the memory to use.
gamma: the initialization of the inverse Jacobian is going to be gamma * I.
Expand Down
5 changes: 3 additions & 2 deletions jaxopt/_src/hager_zhang_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
c1: constant used by the Wolfe and Approximate Wolfe condition.
c2: constant strictly less than 1 used by the Wolfe and Approximate Wolfe
condition.
max_stepsize: upper bound on stepsize.
max_stepsize: upper bound on stepsize (unused).
maxiter: maximum number of line search iterations.
tol: tolerance of the stopping criterion.
Expand All @@ -104,7 +104,8 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
expansion_factor: float = 5.0
shrinkage_factor: float = 0.66
approximate_wolfe_threshold = 1e-6
max_stepsize: float = 1.0
# TODO(vroulet): remove max_stepsize argument as it is not used
max_stepsize: float = 1.0

verbose: int = 0
jit: base.AutoOrBoolean = "auto"
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class LBFGS(base.IterativeSolver):
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
min_stepsize: lower bound on stepsize guess at start of each linesearch run.
history_size: size of the memory to use.
use_gamma: whether to initialize the inverse Hessian approximation with
gamma * I, where gamma is chosen following equation (7.20) of 'Numerical
Expand Down
4 changes: 2 additions & 2 deletions jaxopt/_src/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class LBFGSB(base.IterativeSolver):
increase_factor: factor by which to increase the stepsize during line search
(default: 1.5).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
min_stepsize: lower bound on stepsize guess at start of each linesearch run.
history_size: size of the memory to use.
use_gamma: whether to initialize the Hessian approximation with gamma *
theta, where gamma is chosen following equation (7.20) of 'Numerical
Expand All @@ -289,7 +289,7 @@ class LBFGSB(base.IterativeSolver):
linesearch_init: str = "increase"
stop_if_linesearch_fails: bool = False
condition: Any = None # deprecated in v0.8
maxls: int = 20
maxls: int = 30
decrease_factor: Any = None # deprecated in v0.8
increase_factor: float = 1.5
max_stepsize: float = 1.0
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/linesearch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def _setup_linesearch(
verbose=verbose,
)
elif linesearch == "hager-zhang":
# NOTE(vroulet): max_stepsize has no effect in HZ
linesearch_solver = HagerZhangLineSearch(
fun=fun,
value_and_grad=value_and_grad,
has_aux=has_aux,
maxiter=maxlsiter,
max_stepsize=max_stepsize,
jit=jit,
unroll=unroll,
verbose=verbose,
Expand Down
5 changes: 2 additions & 3 deletions jaxopt/_src/nonlinear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ class NonlinearCG(base.IterativeSolver):
increase_factor: factor by which to increase the stepsize during line search
(default: 1.2).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
min_stepsize: lower bound on stepsize guess at start of each linesearch run.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
Expand Down Expand Up @@ -123,7 +122,7 @@ class NonlinearCG(base.IterativeSolver):
linesearch: str = "zoom"
linesearch_init: str = "increase"
condition: Any = None # deprecated in v0.8
maxls: int = 15
maxls: int = 30
decrease_factor: Any = None # deprecated in v0.8
increase_factor: float = 1.2
max_stepsize: float = 1.0
Expand Down
Loading

0 comments on commit 0c032fd

Please sign in to comment.