-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add callbacks #11
Add callbacks #11
Conversation
Hey Brett, great to hear from you! This looks pretty neat. It looks like the Python API uses the return values of the callbacks in |
96dd7ec
to
f66a4ba
Compare
Good idea! Let me know what you think, should be passing now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks for adding those classes. Added some comments inline. It looks like the new files need to be required.
require_relative "xgboost/callback_container"
require_relative "xgboost/training_callback"
Also, if you're seeing an error running the tests locally, try running:
bundle exec rake vendor:platform
lib/xgboost.rb
Outdated
booster = Booster.new(params: params) | ||
cb_container = CallbackContainer.new(callbacks) | ||
cb_container.before_training(model: booster) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should update booster
to match Python.
lib/xgboost.rb
Outdated
@@ -59,32 +63,36 @@ def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds | |||
end | |||
|
|||
num_boost_round.times do |iteration| | |||
cb_container.before_iteration(model: booster, epoch: iteration) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should break if the return value is falsy.
booster.update(dtrain, iteration) | ||
|
||
if evals.any? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please keep the existing code where possible to keep the changeset minimal / easier to review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, disabled rubocop
lib/xgboost.rb
Outdated
end | ||
cb_container.after_iteration(model: booster, epoch: iteration, res: res) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should break
for falsy values like before_iteration
.
lib/xgboost.rb
Outdated
end | ||
cb_container.after_training(model: booster) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should update booster
like before_training
.
lib/xgboost/callback_container.rb
Outdated
@history = {} | ||
|
||
callbacks.each do |callback| | ||
unless callback.class.ancestors.include?(TrainingCallback) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
callback.is_a?(TrainingCallback)
test/train_test.rb
Outdated
@@ -55,6 +55,88 @@ def test_feature_names_and_types | |||
assert_nil model.feature_types | |||
end | |||
|
|||
class MockCallback < XGBoost::TrainingCallback |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's create a separate file for the callback tests.
f66a4ba
to
80ff93c
Compare
Thanks Andrew, had to get rid of some aggressive Rubocop settings complicating the changes. Tests are all passing for me and everything's committed now 😅 Looks like you have to approve the Github workflow run here but I think we should have it this time. |
Co-authored-by: Brett Shollenberger <brett.shollenberger@gmail.com>
Thanks @brettshollenberger! Merged in the commit above with a few minor changes:
Going to spend a little time getting the overall code more in sync with Python, and then will push a new release. |
@ankane nice, thanks! With the params thing, it's not exactly 1:1, but I think an easier solution than the Python Wandb callback, which calls def before_training(self, model: Booster) -> Booster:
"""Run before training is finished."""
# Update W&B config
config = model.save_config()
wandb.config.update(json.loads(config))
return model def save_config(self) -> str:
"""Output internal parameter configuration of Booster as a JSON
string.
.. versionadded:: 1.0.0
"""
json_string = ctypes.c_char_p()
length = c_bst_ulong()
_check_call(
_LIB.XGBoosterSaveJsonConfig(
self.handle, ctypes.byref(length), ctypes.byref(json_string)
)
)
assert json_string.value is not None
result = json_string.value.decode() # pylint: disable=no-member
return result |
I think it's better to keep things synced for maintainability in most cases. Added |
Awesome, thank you! |
Great, just pushed 0.9.0. Let me know if you need anything else that's missing. |
Hey Andrew! Thanks for this library 😄
I wanted to add the callbacks API so I could support an integration with Wandb
I have the Ruby implementation of this callback over here as an example of the use case and have tested the integration in my Wandb console
Some dummy data just to show the integration:
I kept the API consistent w/ the Python implementation but let me know if there's anything else you'd want to see here