Extracting intermediate function values/losses from the solve #52

Open
itk22 opened this issue Mar 22, 2024 · 6 comments
Labels
feature New feature

Comments

@itk22

itk22 commented Mar 22, 2024

Dear optimistix team,

First of all, thank you for your effort in developing optimistix. I have recently transitioned from JAXOpt, and I love it!

I was wondering whether it is possible to extract the loss/function value history from an optimistix solve. In the code example below, it is easy to evaluate the intermediate losses with the multi_step_solve approach, but it is much less efficient than the single_step_solve approach. Replacing the Python for loop with jax.lax.scan would certainly improve performance, but I was wondering if there is a simpler way to extract this information in optimistix.

```python
import jax.numpy as jnp
import optimistix as optx


def rastrigin(x, args):
    A = 10.0
    y = A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)
    return y


# How can we extract the losses for a single_step_solve?
def single_step_solve(solver, y0):
    sol = optx.minimise(rastrigin, solver, max_steps=2_000, y0=y0, throw=False)
    return sol.value


def multi_step_solve(solver, y0):
    # This is much less efficient, but it's easy to extract losses
    current_sol = y0
    losses = []
    for _ in range(2_000):
        current_sol = optx.minimise(rastrigin, solver, max_steps=1, y0=current_sol, throw=False).value
        losses.append(rastrigin(current_sol, None))
    return current_sol, losses
```

@patrick-kidger
Owner

Hmm. I don't think we offer that at the moment!

FWIW if this is just for debugging purposes then you could add jax.debug.print statements to the input or output of your function.

If you really want to interact with the history programmatically then (a) I'm quite curious what the use-case is, but also (b) we could probably add an additional optx.minimise(..., saveat=...) argument without too much difficulty.

@itk22
Author

itk22 commented Mar 23, 2024

Hi,
Thanks for the quick response. I am currently trying to use optimistix to implement Model-Agnostic Meta-Learning and its implicit version (https://arxiv.org/abs/1909.04630). I was considering using the single_step_solve approach from above for the outer meta-learning loop because it makes training very quick. However, I need to be able to monitor the meta-losses, which is why I was looking into the less efficient multi_step_solve. The saveat option would be very useful!

On a side note, I had a look at the interactive stepping example, and I thought that it could be useful for solvers to have an update method for performing a single optimisation step, similar to JAXOpt. What do you think?

@patrick-kidger
Owner

patrick-kidger commented Mar 23, 2024

Makes sense. I'll mark this as a feature request for a saveat option. (And I'm sure we'd be happy to take a PR on this.)

For performing a single optimisation step, I think we already have this: the step method on the solver?

@patrick-kidger patrick-kidger added the feature New feature label Mar 23, 2024
@itk22
Author

itk22 commented Mar 26, 2024

Thanks Patrick, I will try to make a PR on this :)

@vadmbertr

Hi @itk22,

Did you start a PR for the saveat option?

@itk22
Author

itk22 commented Jan 24, 2025

Hi @vadmbertr, I ended up using optax in my project and have not worked on the PR.
