Some improvements to LoRA #528
Conversation
llms/mlx_lm/models/llama.py
@@ -124,7 +124,7 @@ def __init__(self, args: ModelArgs):
         self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.args = args

-    def __call__(
+    def forward(
@angeloskath maybe you have a better idea for how to checkpoint all the models. What I found to be the simplest is to change __call__ to forward and then set forward to the checkpointed function in the main training file.
Note that monkey patching __call__ does not work for this as far as I can tell. I didn't know this before, but evidently a.__call__ is actually looked up on the type rather than on the instance when a(...) is invoked.
I will make this change for all the models unless you have a better idea.
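A quick standalone illustration of that __call__ behavior (not code from this PR): assigning __call__ on an instance is ignored by the call syntax, because Python resolves special methods on the type.

class Layer:
    def __call__(self, x):
        return x + 1

layer = Layer()
# Patch the instance attribute; the call syntax below never sees it.
layer.__call__ = lambda x: x + 100

print(layer(1))           # 2   -> still uses type(layer).__call__
print(layer.__call__(1))  # 101 -> explicit attribute access does see the patch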
Interesting. First because I had no idea about the weird __call__ semantics, and second because doing

for i in range(len(model.layers)):
    model.layers[i] = nn.utils.checkpoint(model.layers[i])

will obviously break a bunch of other things, like accessing the parameters. Maybe we should do proper monkey-patching if checkpoint is passed a module.
Another option is passing a checkpoint argument to the module containing the layers, but having a forward() seems fine. We are not using __call__ for anything except as a convention sometimes, which almost always has an alternative path (e.g. nn.utils.checkpoint, which can also take a function).
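A rough sketch of the approach being discussed, for illustration only: __call__ simply delegates to forward, and the training script rebinds the class-level forward to a checkpointed version. The exact nn.utils.checkpoint call is an assumption based on the comment above that it can also wrap a plain function.

import mlx.nn as nn

class TransformerBlock(nn.Module):
    def forward(self, x, mask=None, cache=None):
        ...  # attention + MLP as usual

    def __call__(self, *args, **kwargs):
        # Route calls through forward so forward can be swapped out later.
        return self.forward(*args, **kwargs)

def enable_grad_checkpoint(model, checkpoint_fn=nn.utils.checkpoint):
    # Patch forward on the layer class so every layer instance is checkpointed.
    # checkpoint_fn is assumed to accept a plain function (see the comment above).
    layer_cls = type(model.layers[0])
    layer_cls.forward = checkpoint_fn(layer_cls.forward)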
lol that was the first thing I tried.
I thought about the arg, but it was slightly more involved to wire it all the way through.
A proper monkey patch of the module is interesting if we can make it work.
Looks great!
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}"
+                f"Trained Tokens {trained_tokens}, "
+                f"Peak mem {peak_mem:.3f} GB"
             )

             if training_callback is not None:
Could the peak memory information also be exposed in training_callback? I think this part would be very useful.
Happy to. I will wait for ml-explore/mlx#788 to land and then adopt it here.
When fine-tuning, I only tried to print out the peak memory. I'm not sure if this phenomenon is caused by not fully adopting all of the code logic in the current PR; if that's the case, the current issue can be ignored.
@madroidmaq I think that is accounted for by a race condition that we recently fixed. It should already be fixed on main in MLX.
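For reference, a hedged sketch of how the peak_mem figure in the log line above could be obtained; it assumes mx.metal.get_peak_memory() (which reports bytes) is available in the installed MLX version.

import mlx.core as mx

# Assumed API: mx.metal.get_peak_memory() returns the peak bytes allocated so far.
peak_mem = mx.metal.get_peak_memory() / 1e9  # bytes -> GB
print(f"Peak mem {peak_mem:.3f} GB")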
-            indices = np.random.permutation(indices)
-            # Collect batches from dataset
-            for i in range(0, len(indices) - batch_size + 1, batch_size):
+            indices = np.random.permutation(len(batch_idx))
@awni I noticed this line of code when I was trying to add a sort flag to tweak this section (#583). I'm a little confused about whether sorting by length is canceled out by calling this function again after the earlier sort by length. Why sort by length first and then randomize the order?
I'm not very familiar with this part, so please let me know if my understanding is off.
That is just to randomize the order of the batches, so training doesn't go from the batches with the shortest sequences to the batches with the longest sequences.
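For illustration, a hypothetical sketch of the scheme being described (the names are illustrative, not the PR's exact code): samples are sorted by length so each batch holds similarly sized sequences, and then the order of the batches themselves is shuffled.

import numpy as np

def iterate_batches(dataset, batch_size):
    # Sort sample indices by sequence length to minimize padding within a batch.
    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
    # Group consecutive (similar-length) indices into batches.
    batch_idx = [
        idx[i : i + batch_size]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
    # Shuffle which batch comes first, not the contents of each batch.
    for b in np.random.permutation(len(batch_idx)):
        yield [dataset[i] for i in batch_idx[b]]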
--grad-checkpoint
Compile Benchmarks
Decent gain from both sorting and compile. We need a larger average to do a proper comparison, but you can see that overall sorting helps, as does compile:
Original (no sorting, no compile)
Just sorting, no compile
With compile + sorting
Checkpoint Benchmarks
TL;DR: checkpointing reduces memory nicely with a large batch and many LoRA layers; it is especially noticeable with QLoRA, where the model itself occupies less memory.
Regular LoRA with the command:
QLoRA with the command:
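The exact benchmark commands were not captured above. Purely as an illustration (the model path, batch size, and layer count are placeholders), a gradient-checkpointed LoRA run might look something like:

# Illustrative only; not the exact command used for these benchmarks.
python -m mlx_lm.lora \
    --model <path-or-hf-repo> \
    --train \
    --batch-size 4 \
    --lora-layers 16 \
    --grad-checkpoint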