
Some improvements to LoRA #528

Merged: 7 commits merged into main on Mar 13, 2024
Conversation

awni (Member) commented Mar 4, 2024

  • Sort the dataset prior to batching for more consistent sequence lengths
  • Compile non-MoE models
  • Add gradient checkpointing as an option (--grad-checkpoint)

Compile Benchmarks

Decent gain from both sorting and compile. We need a larger average to do a proper comparison, but you can see that sorting helps overall, as does compile:

Original (no sorting, no compile)

Iter 10: Train loss 2.053, Learning Rate 1.000e-05, It/sec 1.312, Tokens/sec 521.974, Trained Tokens 3978
Iter 20: Train loss 1.450, Learning Rate 1.000e-05, It/sec 1.234, Tokens/sec 494.780, Trained Tokens 7989
Iter 30: Train loss 1.319, Learning Rate 1.000e-05, It/sec 1.196, Tokens/sec 487.183, Trained Tokens 12063
Iter 40: Train loss 1.240, Learning Rate 1.000e-05, It/sec 1.279, Tokens/sec 490.877, Trained Tokens 15900
Iter 50: Train loss 1.172, Learning Rate 1.000e-05, It/sec 1.274, Tokens/sec 501.700, Trained Tokens 19837
Iter 60: Train loss 1.062, Learning Rate 1.000e-05, It/sec 1.284, Tokens/sec 503.399, Trained Tokens 23757

Sorting only, no compile

Iter 10: Train loss 2.063, Learning Rate 1.000e-05, It/sec 1.446, Tokens/sec 578.409, Trained Tokens 3999, Peak mem 17.383 GB
Iter 20: Train loss 1.639, Learning Rate 1.000e-05, It/sec 1.325, Tokens/sec 534.407, Trained Tokens 8032, Peak mem 17.383 GB
Iter 30: Train loss 1.364, Learning Rate 1.000e-05, It/sec 1.374, Tokens/sec 528.128, Trained Tokens 11877, Peak mem 17.383 GB
Iter 40: Train loss 1.253, Learning Rate 1.000e-05, It/sec 1.391, Tokens/sec 529.224, Trained Tokens 15682, Peak mem 17.383 GB
Iter 50: Train loss 1.087, Learning Rate 1.000e-05, It/sec 1.425, Tokens/sec 527.453, Trained Tokens 19383, Peak mem 17.383 GB
Iter 60: Train loss 1.165, Learning Rate 1.000e-05, It/sec 1.354, Tokens/sec 514.346, Trained Tokens 23181, Peak mem 17.383 GB

With Compile + Sorting

Iter 10: Train loss 2.061, Learning Rate 1.000e-05, It/sec 1.462, Tokens/sec 584.663, Trained Tokens 3999, Peak mem 17.408 GB
Iter 20: Train loss 1.651, Learning Rate 1.000e-05, It/sec 1.366, Tokens/sec 551.017, Trained Tokens 8032, Peak mem 17.423 GB
Iter 30: Train loss 1.368, Learning Rate 1.000e-05, It/sec 1.403, Tokens/sec 539.275, Trained Tokens 11877, Peak mem 17.423 GB
Iter 40: Train loss 1.257, Learning Rate 1.000e-05, It/sec 1.398, Tokens/sec 531.795, Trained Tokens 15682, Peak mem 17.423 GB
Iter 50: Train loss 1.089, Learning Rate 1.000e-05, It/sec 1.488, Tokens/sec 550.811, Trained Tokens 19383, Peak mem 17.423 GB
Iter 60: Train loss 1.168, Learning Rate 1.000e-05, It/sec 1.410, Tokens/sec 535.432, Trained Tokens 23181, Peak mem 17.423 GB
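
For context, here is a minimal sketch of the usual MLX pattern for compiling a training step; the helper name make_step and the loss_fn signature are placeholders, not the exact code in this PR. The model and optimizer state are captured so the compiled function can see the parameter and optimizer updates that happen as side effects.

from functools import partial

import mlx.core as mx
import mlx.nn as nn

def make_step(model, optimizer, loss_fn):
    # loss_fn(model, inputs, targets) -> scalar loss (placeholder signature)
    loss_value_and_grad = nn.value_and_grad(model, loss_fn)
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inputs, targets):
        loss, grads = loss_value_and_grad(model, inputs, targets)
        optimizer.update(model, grads)
        return loss

    return step

Each call to step would typically be followed by mx.eval(state) to force the lazy updates.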

Checkpoint Benchmarks

TL;DR: checkpointing reduces memory nicely with a large batch and many LoRA layers; it is especially noticeable with QLoRA, where the model itself occupies less memory.

Regular LoRA with the command:

python -m mlx_lm.lora --model mistralai/Mistral-7B-v0.1 --train --data ../lora/data --lora-layers 32 --grad-checkpoint --batch-size 8
Peak memory:
  No checkpoint   32.078 GB
  Checkpoint      19.795 GB

QLoRA with the command:

python -m mlx_lm.lora --model mlx-community/NeuralBeagle14-7B-4bit-mlx --train --data ../lora/data --lora-layers 32 --batch-size 8
Peak memory:
  No checkpoint   20.695 GB
  Checkpoint       8.199 GB

@@ -124,7 +124,7 @@ def __init__(self, args: ModelArgs):
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

-    def __call__(
+    def forward(
awni (Member, Author):

@angeloskath maybe you have a better idea for how to checkpoint all the models. What I found to be simplest is to change __call__ to forward and then set forward to the checkpointed function in the main training file.

Note that monkey patching __call__ does not work for this as far as I can tell. I didn't know this before, but evidently a.__call__ is actually looked up on type(a) rather than on the instance.

I will make this change for all the models unless you have a better idea.
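
For illustration, a rough sketch of what this wiring could look like in the training script; the helper name and the exact nn.utils.checkpoint signature are assumptions based on this thread, not the code that was merged.

import mlx.nn as nn

def enable_grad_checkpointing(model):
    # Swap each layer's forward for a checkpointed version so activations
    # are recomputed during the backward pass instead of being stored.
    # The (module, function) signature of nn.utils.checkpoint is assumed here.
    for layer in model.layers:
        layer.forward = nn.utils.checkpoint(layer, layer.forward)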

angeloskath (Member):

Interesting. First because I had no idea about the weird __call__ semantics, and second because doing

# Naively wrap each transformer layer in a checkpointing wrapper
for i in range(len(model.layers)):
    model.layers[i] = nn.utils.checkpoint(model.layers[i])

will obviously break a bunch of other things, like accessing the parameters. Maybe we should do proper monkey-patching when checkpoint is passed a module.

Another option is passing a checkpoint argument to the module containing the layers, but having a forward() seems fine. We are not using __call__ for anything except as a convention, and there is almost always an alternative path (e.g. nn.utils.checkpoint, which can also take a function).

awni (Member, Author):

lol that was the first thing I tried.

awni (Member, Author):

I thought about the argument approach, but it was slightly more involved to wire it all the way through.

awni (Member, Author):

A proper monkey patch of the module is interesting if we can make it work.

awni requested a review from angeloskath on March 4, 2024 at 15:27
angeloskath (Member) left a review:

Looks great!

    print(
        f"Iter {it + 1}: Train loss {train_loss:.3f}, "
        f"Learning Rate {learning_rate:.3e}, "
        f"It/sec {it_sec:.3f}, "
        f"Tokens/sec {tokens_sec:.3f}, "
-       f"Trained Tokens {trained_tokens}"
+       f"Trained Tokens {trained_tokens}, "
+       f"Peak mem {peak_mem:.3f} GB"
    )

if training_callback is not None:
Contributor:

Could the peak memory information also be exposed in training_callback? I think it would be very useful.

awni (Member, Author):

Happy to
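
As an illustration, surfacing it could look roughly like the following inside the training loop; the fields of the report dict and the callback method name are assumptions, not the merged API.

import mlx.core as mx

train_info = {
    "iteration": it + 1,
    "train_loss": train_loss,
    "iterations_per_second": it_sec,
    "tokens_per_second": tokens_sec,
    "trained_tokens": trained_tokens,
    "peak_memory": mx.metal.get_peak_memory() / 2**30,  # GB, as printed above
}
if training_callback is not None:
    training_callback.on_train_loss_report(train_info)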

awni (Member, Author) commented Mar 5, 2024

I will wait for this to land and then adopt it here ml-explore/mlx#788

awni marked this pull request as ready for review on March 5, 2024 at 23:37
madroidmaq (Contributor):

> I will wait for this to land and then adopt it here ml-explore/mlx#788

When fine-tuning, I only tried printing the peak memory (mx.metal.get_peak_memory() / 2**30) without making any other adjustments. Everything was normal at the beginning of the run, but after running for a while there was a significant anomaly in the values (14.809 GB -> 17179869184.000 GB). The printed output is as follows:

...
Iter 160: Train loss 1.165, Learning Rate 1.000e-05, It/sec 0.279, Tokens/sec 267.007, Trained Tokens 152357, peak_memory 14.809 GB
Iter 170: Train loss 1.123, Learning Rate 1.000e-05, It/sec 0.291, Tokens/sec 270.279, Trained Tokens 161649, peak_memory 14.809 GB
Iter 180: Train loss 1.077, Learning Rate 1.000e-05, It/sec 0.239, Tokens/sec 234.937, Trained Tokens 171491, peak_memory 14.809 GB
Iter 190: Train loss 1.042, Learning Rate 1.000e-05, It/sec 0.291, Tokens/sec 272.583, Trained Tokens 180845, peak_memory 17179869183.987 GB
Iter 200: Train loss 1.058, Learning Rate 1.000e-05, It/sec 0.275, Tokens/sec 267.219, Trained Tokens 190570, peak_memory 17179869183.997 GB
Iter 200: Val loss 1.017, Val took 60.756s
Iter 200: Saved adapter weights to checkpoints/200_adapters.npz.
Iter 210: Train loss 1.002, Learning Rate 1.000e-05, It/sec 0.252, Tokens/sec 242.204, Trained Tokens 200185, peak_memory 17179869184.000 GB
Iter 220: Train loss 1.012, Learning Rate 1.000e-05, It/sec 0.253, Tokens/sec 249.107, Trained Tokens 210023, peak_memory 17179869184.000 GB
Iter 230: Train loss 0.994, Learning Rate 1.000e-05, It/sec 0.255, Tokens/sec 248.113, Trained Tokens 219751, peak_memory 17179869184.000 GB
...

I'm not sure if this phenomenon is caused by not fully adopting all the code logic of the current PR. If that's the case, the current issue can be ignored.

awni (Member, Author) commented Mar 6, 2024

@madroidmaq I think that is caused by a race condition that we recently fixed. It should already be fixed on main in MLX.

awni merged commit 39084e8 into main on Mar 13, 2024 (2 checks passed).
awni deleted the lora branch on March 13, 2024.
-        indices = np.random.permutation(indices)
-        # Collect batches from dataset
-        for i in range(0, len(indices) - batch_size + 1, batch_size):
+        indices = np.random.permutation(len(batch_idx))
Contributor:

@awni I noticed this line of code when I was trying to add a sort flag to tweak this section (#583). I'm a little confused about whether the earlier sorting by length is canceled out by calling the permutation here. Why sort by length first and then randomize the order?

I'm not very familiar with this part, so please let me know if my understanding is off.

awni (Member, Author):

That is just to randomize the order of the batches, so training doesn't go from the batches with the shortest sequences to the batches with the longest sequences.
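
To make that concrete, here is a simplified sketch of the scheme; the function and variable names are illustrative, not the exact ones in the file. Samples are sorted by length, grouped into batches of similar length, and only the visiting order of the batches is shuffled.

import numpy as np

def iterate_batches(dataset, batch_size):
    # Sort sample indices by sequence length so each batch has similar lengths
    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))

    # Collect batches of consecutive (similar-length) indices
    batch_idx = [
        idx[i : i + batch_size]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    # Shuffle only the order of the batches: each batch still contains
    # similar-length samples, so the length sort is not undone, but training
    # no longer sweeps from the shortest batches to the longest ones.
    for j in np.random.permutation(len(batch_idx)):
        yield [dataset[k] for k in batch_idx[j]]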
