Add `balance_weights` to weight balanced batches
#1588
Draft
+3,047
−91
The `balance` option of the segmentation tasks accepts a list of `ProtocolFile` fields, e.g. `['database', 'foo']`. When batches are sampled, it looks at all existing combinations of values for these fields in the task protocol.

For example, if the files come from databases `aishell` and `ami`, and their `foo` field is either `a` or `b`, we compute the cartesian product `[('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')]`. Batches are then created by randomly selecting one of these tuples and picking a sample from a matching file.

This PR makes it possible to weight the random choice from the cartesian product. For example with
we will sample from the cartesian product using `random.choices` with these weights:
That is, for each tuple of the cartesian product, we find the longest matching (tuple) prefix in `balance_weights` and use its weight.

I'm not sure this approach is flexible/clean enough to be PR-ready, and it's hard to make the docstring concise, but I think it could be really useful :)
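To make the longest-prefix lookup concrete, here is a rough standalone sketch of the idea. It is not the code from this PR: the weight values, the helper name `weight_for`, and the fallback weight of `1.0` for tuples with no matching prefix are all assumptions made for illustration.

```python
import itertools
import random

# Field values taken from the example in the description.
databases = ["aishell", "ami"]
foo_values = ["a", "b"]

# Cartesian product of the values of the balanced fields.
combinations = list(itertools.product(databases, foo_values))

# Hypothetical weights, keyed by tuple prefixes of the combinations.
balance_weights = {
    ("aishell",): 2.0,   # applies to every tuple starting with 'aishell'
    ("ami", "b"): 0.5,   # applies only to ('ami', 'b')
}


def weight_for(combo, weights, default=1.0):
    """Return the weight of the longest prefix of `combo` present in `weights`.

    `default` (an assumption here) is used when no prefix matches.
    """
    for length in range(len(combo), 0, -1):
        prefix = combo[:length]
        if prefix in weights:
            return weights[prefix]
    return default


# One weight per tuple of the cartesian product, in the same order.
weights = [weight_for(combo, balance_weights) for combo in combinations]

# Weighted random choice of one tuple, as with random.choices.
rng = random.Random(0)
picked = rng.choices(combinations, weights=weights, k=1)[0]
```

With these made-up weights, both `aishell` tuples get `2.0` through the one-element prefix, `('ami', 'b')` gets `0.5` through its exact match, and `('ami', 'a')` falls back to the assumed default.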