Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callbacks #11

Closed

Conversation

brettshollenberger
Copy link
Contributor

Hey Andrew! Thanks for this library 😄

I wanted to add the callbacks API so I could support an integration with Wandb

I have the Ruby implementation of this callback over here as an example of the use case and have tested the integration in my Wandb console

Some dummy data just to show the integration:
Screenshot 2024-10-11 at 2 36 17 PM
Screenshot 2024-10-11 at 2 35 55 PM

I kept the API consistent w/ the Python implementation but let me know if there's anything else you'd want to see here

@ankane
Copy link
Owner

ankane commented Oct 13, 2024

Hey Brett, great to hear from you! This looks pretty neat. It looks like the Python API uses the return values of the callbacks in train. I think the best way to do that would be to add TrainingCallback and CallbackContainer classes like Python (code), but I can do that in a follow-up commit if you just want to get the tests passing.

@brettshollenberger brettshollenberger force-pushed the callbacks branch 3 times, most recently from 96dd7ec to f66a4ba Compare October 14, 2024 21:21
@brettshollenberger
Copy link
Contributor Author

Good idea! Let me know what you think, should be passing now

Copy link
Owner

@ankane ankane left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for adding those classes. Added some comments inline. It looks like the new files need to be required.

require_relative "xgboost/callback_container"
require_relative "xgboost/training_callback"

Also, if you're seeing an error running the tests locally, try running:

bundle exec rake vendor:platform

lib/xgboost.rb Outdated
booster = Booster.new(params: params)
cb_container = CallbackContainer.new(callbacks)
cb_container.before_training(model: booster)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should update booster to match Python.

lib/xgboost.rb Outdated
@@ -59,32 +63,36 @@ def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds
end

num_boost_round.times do |iteration|
cb_container.before_iteration(model: booster, epoch: iteration)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should break if the return value is falsy.

booster.update(dtrain, iteration)

if evals.any?
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please keep the existing code where possible to keep the changeset minimal / easier to review.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, disabled rubocop

lib/xgboost.rb Outdated
end
cb_container.after_iteration(model: booster, epoch: iteration, res: res)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should break for falsy values like before_iteration.

lib/xgboost.rb Outdated
end
cb_container.after_training(model: booster)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should update booster like before_training.

@history = {}

callbacks.each do |callback|
unless callback.class.ancestors.include?(TrainingCallback)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callback.is_a?(TrainingCallback)

@@ -55,6 +55,88 @@ def test_feature_names_and_types
assert_nil model.feature_types
end

class MockCallback < XGBoost::TrainingCallback
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create a separate file for the callback tests.

@brettshollenberger
Copy link
Contributor Author

Thanks Andrew, had to get rid of some aggressive Rubocop settings complicating the changes. Tests are all passing for me and everything's committed now 😅 Looks like you have to approve the Github workflow run here but I think we should have it this time.

ankane added a commit that referenced this pull request Oct 16, 2024
Co-authored-by: Brett Shollenberger <brett.shollenberger@gmail.com>
@ankane
Copy link
Owner

ankane commented Oct 16, 2024

Thanks @brettshollenberger! Merged in the commit above with a few minor changes:

  1. I had the logic backwards for before/after_iteration - it should stop if it returns a truthy value
  2. Changed callbacks to use positional arguments instead of keyword
  3. Left out the params change, as it's not present in the Python library (from what I can tell)

Going to spend a little time getting the overall code more in sync with Python, and then will push a new release.

@ankane ankane closed this Oct 16, 2024
@brettshollenberger
Copy link
Contributor Author

brettshollenberger commented Oct 16, 2024

@ankane nice, thanks!

With the params thing, it's not exactly 1:1, but I think an easier solution than the Python Wandb callback, which calls xgboost/core#save_config.

https://github.com/wandb/wandb/blob/8698af5862e44baf31af5411b81bea546e069257/wandb/integration/xgboost/xgboost.py#L117

    def before_training(self, model: Booster) -> Booster:
        """Run before training is finished."""
        # Update W&B config
        config = model.save_config()
        wandb.config.update(json.loads(config))

        return model

https://github.com/dmlc/xgboost/blob/3f9bfaf86e6db6a4f54734aa7d164df55aa69ef6/python-package/xgboost/core.py#L2008

    def save_config(self) -> str:
        """Output internal parameter configuration of Booster as a JSON
        string.

        .. versionadded:: 1.0.0

        """
        json_string = ctypes.c_char_p()
        length = c_bst_ulong()
        _check_call(
            _LIB.XGBoosterSaveJsonConfig(
                self.handle, ctypes.byref(length), ctypes.byref(json_string)
            )
        )
        assert json_string.value is not None
        result = json_string.value.decode()  # pylint: disable=no-member
        return result

ankane added a commit that referenced this pull request Oct 16, 2024
@ankane
Copy link
Owner

ankane commented Oct 16, 2024

I think it's better to keep things synced for maintainability in most cases. Added save_config in the commit above.

@brettshollenberger
Copy link
Contributor Author

Awesome, thank you!

@ankane
Copy link
Owner

ankane commented Oct 17, 2024

Great, just pushed 0.9.0. Let me know if you need anything else that's missing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants