From 9ab35ac85d7f50a51fe63adeef53609c44284728 Mon Sep 17 00:00:00 2001
From: madroid
Date: Sun, 24 Mar 2024 12:07:58 +0800
Subject: [PATCH] LoRA: add sort data flag

---
 llms/mlx_lm/lora.py          |  6 +++++
 llms/mlx_lm/tuner/trainer.py | 50 +++++++++++++++++++++++++++---------
 2 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index adc426e4f..d5f293b0a 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -75,6 +75,11 @@ def build_parser():
         type=str,
         help="Directory with {train, valid, test}.jsonl files",
     )
+    parser.add_argument(
+        "--sort-by-data-length",
+        action="store_true",
+        help="Sorts sequences by length to reduce padding and enhance efficiency.",
+    )
     parser.add_argument(
         "--lora-layers",
         type=int,
@@ -196,6 +201,7 @@ def run(args, training_callback: TrainingCallback = None):
         adapter_file=args.adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
+        sort_by_data_length=args.sort_by_data_length,
     )
 
     model.train()
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index f0d8e0a43..48afb2234 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -2,7 +2,6 @@
 
 import time
 from dataclasses import dataclass, field
-from functools import partial
 from pathlib import Path
 
 import mlx.core as mx
@@ -61,6 +60,12 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
+    sort_by_data_length: bool = field(
+        default=False,
+        metadata={
+            "help": "Sorts sequences by length to reduce padding and enhance efficiency."
+        },
+    )
 
 
 def default_loss(model, inputs, targets, lengths):
@@ -76,19 +81,37 @@
     return ce, ntoks
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    # Sort by length:
-    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
-
-    # Make the batches:
-    batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
+def iterate_batches(
+    dataset,
+    tokenizer,
+    batch_size,
+    max_seq_length,
+    train=False,
+    sort_by_data_length=False,
+):
+    if sort_by_data_length:
+        # Sort by length
+        idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+        # Make batches
+        batch_idx = [
+            idx[i : i + batch_size]
+            for i in range(0, len(idx) - batch_size + 1, batch_size)
+        ]
+    else:
+        # Shuffle indices
+        indices = np.arange(len(dataset))
+        indices = np.random.permutation(indices)
+        # Make batches
+        batch_idx = [
+            indices[i : i + batch_size]
+            for i in range(0, len(indices) - batch_size + 1, batch_size)
+        ]
 
     while True:
-        indices = np.random.permutation(len(batch_idx))
-        for i in indices:
-            # Encode batch
+        # Randomize batch order
+        batch_indices = np.random.permutation(len(batch_idx))
+
+        for i in batch_indices:
             batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
             lengths = [len(x) for x in batch]
 
@@ -129,6 +152,7 @@
     max_seq_length=2048,
     loss: callable = default_loss,
    iterate_batches: callable = iterate_batches,
+    sort_by_data_length: bool = False,
 ):
     all_losses = []
     ntokens = 0
@@ -139,6 +163,7 @@
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            sort_by_data_length=sort_by_data_length,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -213,6 +238,7 @@ def step(batch):
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            sort_by_data_length=args.sort_by_data_length,
         ),
     ):
         lvalue, toks = step(batch)
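
A minimal usage sketch of the new option (not part of the patch): only --sort-by-data-length itself comes from this change; the model path, the toy dataset, and the other CLI flags (--model, --train, --data) are assumed from the surrounding mlx_lm tooling and are placeholders here.

    # CLI: enable length-sorted batching during LoRA fine-tuning (paths are placeholders)
    #   python -m mlx_lm.lora --model <model-path> --train --data <jsonl-dir> --sort-by-data-length

    # Python: exercising the updated iterate_batches directly, assuming the patch is applied
    from mlx_lm import load
    from mlx_lm.tuner.trainer import iterate_batches

    # Placeholder model path and toy dataset; any tokenizer exposing .encode() works here.
    _, tokenizer = load("<path-or-hub-id-of-model>")
    train_set = [
        "short example",
        "a somewhat longer training example",
        "an even longer training example than the previous one",
        "the longest training example of the four, by a comfortable margin",
    ]

    batches = iterate_batches(
        dataset=train_set,
        tokenizer=tokenizer,
        batch_size=2,
        max_seq_length=2048,
        train=True,
        sort_by_data_length=True,  # new flag: group similar lengths to reduce padding
    )
    # Each yielded batch is (inputs, targets, lengths), matching loss(model, *batch) in the patch.
    inputs, targets, lengths = next(batches)

With sorting enabled, each batch draws sequences of similar length, so padding to the longest sequence in the batch wastes fewer tokens; the trade-off is less randomness in batch composition, since examples of similar length are always grouped together.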