LoRA: Increased volatility of train loss #583
Comments
How did you measure that? I added the sorting before batching because it avoids a lot of wasted computation on padded sequences, which makes training faster. This might give the training loss slightly higher variance, but I did not expect it to affect the final validation loss. It would be good to make sure we are measuring that correctly. Roughly how large is the dataset you are training on? |
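To make the padding argument concrete, here is a minimal sketch that counts how many tokens are processed when every batch is padded to its longest sequence, with and without sorting by length. The function name `padded_tokens` and the randomly generated sequence lengths are assumptions for illustration; this is not the code in `trainer.py`.

```python
import random


def padded_tokens(lengths, batch_size):
    """Total tokens processed when each batch is padded to its longest sequence."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += max(batch) * len(batch)
    return total


# Made-up sequence lengths, only to illustrate the effect.
random.seed(0)
lengths = [random.randint(10, 500) for _ in range(1000)]

real_tokens = sum(lengths)                                   # tokens that carry data
unsorted_cost = padded_tokens(lengths, batch_size=4)         # dataset order
sorted_cost = padded_tokens(sorted(lengths), batch_size=4)   # sorted by length

print("real tokens:      ", real_tokens)
print("padded (sorted):  ", sorted_cost)
print("padded (unsorted):", unsorted_cost)
```

The trade-off discussed in this thread is that each sorted batch contains examples of similar length, so consecutive batches are less alike than randomly mixed ones.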
I trained on a private dataset of about 10k examples, split 8:1:1. In my scenario the model produces a stable output (a structured URL instruction), so it is easy for me to calculate its accuracy. Accuracy before changes: 51.98% (551/1060) |
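For readers who want to reproduce this kind of measurement, one simple option is exact-match accuracy over the generated strings. The sketch below assumes that is how the accuracy was computed; the thread does not say so, and `exact_match_accuracy` is a hypothetical helper, not part of MLX LM.

```python
def exact_match_accuracy(predictions, references):
    """Fraction of predictions that equal their reference after trimming whitespace."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    correct = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return correct / len(references)


# Arithmetic check of the figure quoted above: 551 correct out of 1060.
reported = 551 / 1060
print(f"{reported:.2%}")
```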
Yikes, that’s pretty bad. So, a couple of ideas:
I will play around with some options. If you have time to explore a bit that would also be helpful! |
I'd be happy to take a stab at this and test any changes you come up with.
In addition, I'd like to know what you think of the dataset formatting change in #548, and whether in your opinion it is something that could be accepted (or whether it's simple). |
Thanks!! I will take a look at #548 shortly, sorry for the delay! |
@madroidmaq did you have any time to investigate this? I am hoping to come back to it and figure out a better batching strategy. |
@awni Sharing some of my recent attempts: so far the existing sorting logic has been the least effective. Besides that, I've tried random ordering and sorting by label (my data has the corresponding label information). Here is some data from my tests: the final accuracy and the loss curve.
I don't have any other ideas to validate on my side. I can first submit my local logic for the sort flag as a PR, and then tweak it if better approaches come up in the future. |
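A sort flag of the kind described here could look roughly like the following. This is a minimal sketch under assumptions: `make_batches`, `sort_by_length`, and `seed` are hypothetical names, and the code is not taken from the PR or from `trainer.py`.

```python
import random


def make_batches(lengths, batch_size, sort_by_length=True, seed=0):
    """Group example indices into full batches, optionally sorting by length first."""
    indices = list(range(len(lengths)))
    if sort_by_length:
        # Similar-length examples end up together, which reduces padding.
        indices.sort(key=lambda i: lengths[i])
    # Keep only full batches; a trailing partial batch is dropped.
    batches = [
        indices[start:start + batch_size]
        for start in range(0, len(indices) - batch_size + 1, batch_size)
    ]
    # Shuffle the order in which batches are visited, not their contents.
    random.Random(seed).shuffle(batches)
    return batches


lengths = [12, 480, 35, 300, 18, 450, 40, 310]  # made-up token counts
for flag in (True, False):
    batches = make_batches(lengths, batch_size=2, sort_by_length=flag)
    print(flag, [[lengths[i] for i in batch] for batch in batches])
```

With the flag off, batches follow the dataset order, so long and short examples can share a batch at the cost of more padding.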
Thanks for the update.. it might be best to simply disable sorting and compile in LoRA for now :(. It is a modest but nice speed improvement so it's a shame, but it clearly requires some more work to get right. |
Following up with some bad news: I rebased onto the latest code and tested it. The loss curve changes in a similar way to before, and the flag I added is indeed working. However, when I tested the accuracy, I found a significant drop in both runs, and the difference between the two became very small. I'm not sure what is causing this at the moment, and I'll check whether recent code commits are having an effect. The dashed line shows the training results of the code before the rebase, and the solid lines show the results of passing True and False via the flag. |
That doesn't look so good. What version of MLX are you using and what commit for MLX LM? |
@madroidmaq make sure you are using the latest MLX (0.8 or building from source). That's pretty important otherwise you will go through a bad path for RMS Norm (it won't accumulate in the right precision). |
@madroidmaq a couple of weeks ago we changed the default dropout from 0.05 to 0. Could this be the issue? I am trying to reconcile the lower training and validation loss with the worse accuracy... Perhaps resetting it to 0.05 in your training config would be something to check. |
@awni I've tried this with both versions 0.6 and 0.8 of MLX and the results are about the same; neither is good. I'm using the most current commit, fbed720. @angeloskath I'll adjust the dropout to 0.05 and try again. |
After adjusting the dropout to a value other than 0, the checkpoints file cannot be saved properly, and the following error message is reported:
I don't know much about this part. How should I go about solving this problem? Could you suggest some ideas that I can try? |
Oof sorry that’s the compile, you can just remove it for now as in #608 |
Hm that is interesting because I just trained with |
After I followed the tweaks in #608, it works fine. I tried the dropout parameter and it did improve, but not to the original accuracy. Here are the numbers with the dropout adjusted to 0.05 and 0.1 respectively: In addition, I rolled back the code to before #528 was committed and the accuracy was also around 30%. So it's possible that there's something wrong with my local code, and I'll look into it further. Currently my code is somewhat coupled with the code in the project. I'll spend some time decoupling them, and will probably also follow up with tweaks to the APIs in the current project so that refactoring can go smoothly. But this shouldn't affect commit #611 on the sort flag. |
When using the latest LoRA training code, the training loss became more volatile. When I analyzed the cause further, I suspected that it was introduced by commit #528, which reorders the dataset internally.
mlx-examples/llms/mlx_lm/tuner/trainer.py
Lines 78 to 86 in e2205be
I reverted this commit and retrained, and the loss curve became stable again. The following shows what I observed during training. This fluctuation reduces the accuracy of the final model by about 10%.
I'm not sure what the reason for this change is, or how I can skip the dataset sorting logic in my case. If this logic is really needed, my suggestion would be to move it into the Dataset definition and allow users to add their own datasets.