# Trainer In-depth Guide
The `Trainer` is the base class for training all models in Hezar, regardless of the task, dataset, model architecture, etc.
In this guide, we'll walk through the internals of this powerful class and show how you can customize it to your needs.

We'll go through this step by step, in the exact order things happen when training a model.

## Initialization
In order to initialize the Trainer, you need to have a few objects ready:

- **Trainer Config**: A config dataclass (of type `TrainerConfig`) with a bunch of attributes that control how the Trainer behaves. The most important ones are:
  - `output_dir` (required): Path to the directory where trainer outputs (checkpoints, logs, etc.) are saved
  - `task` (required): The training task, e.g. `text_classification`, `speech_recognition`, etc.
  - `num_epochs` (required): Number of training epochs
  - `init_weights_from` (optional): If the model is randomly initialized or you want to fine-tune it from a different set of weights, you can provide a path to a pretrained model here to load the weights from.
  - `batch_size` (required): Training batch size
  - `eval_batch_size` (optional): Evaluation batch size (defaults to `batch_size` if not set)
  - `mixed_precision` (optional): Mixed precision type, e.g. `fp16`, `bf16`, etc. (disabled by default)
  - `metrics` (optional): A set of metrics for model evaluation. Available metrics can be listed using `utils.list_available_metrics()` (see the snippet after this list)
  - and more
- **Model**: A Hezar Model instance
- **Train & Eval Datasets**: Train and evaluation datasets (Hezar Dataset instances)
- **Preprocessor(s)**: The model's preprocessor, in case it's not already attached to the model (`model.preprocessor is None`)
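
If you're not sure which metric names are available, you can list them programmatically. A quick check, assuming `list_available_metrics` is importable from `hezar.utils` as the config reference above suggests:

```python
from hezar.utils import list_available_metrics

# Print the metric names that can be passed to `TrainerConfig(metrics=...)`
print(list_available_metrics())
```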
As an example, here is how you can initialize a trainer for text classification using BERT:
```python
from hezar.models import BertTextClassification, BertTextClassificationConfig
from hezar.data import Dataset
from hezar.preprocessors import Preprocessor
from hezar.trainer import Trainer, TrainerConfig


dataset_path = "hezarai/sentiment-dksf"
base_model_path = "hezarai/bert-base-fa"

train_dataset = Dataset.load(dataset_path, split="train", tokenizer_path=base_model_path)
eval_dataset = Dataset.load(dataset_path, split="test", tokenizer_path=base_model_path)

model = BertTextClassification(BertTextClassificationConfig(id2label=train_dataset.config.id2label))
preprocessor = Preprocessor.load(base_model_path)

train_config = TrainerConfig(
    output_dir="bert-fa-sentiment-analysis-dksf",
    task="text_classification",
    device="cuda",
    init_weights_from=base_model_path,
    batch_size=8,
    num_epochs=5,
    metrics=["f1", "precision", "accuracy", "recall"],
)

trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
    preprocessor=preprocessor,
)
```
### `Trainer.__init__()`
So what exactly happens inside the Trainer's `__init__`?

* Config sanitization
* Configure determinism (controlled by `config.seed`)
* Configure device(s)
* Setup model and preprocessor (load pretrained weights if configured)
* Setup data loaders
* Setup optimizer and LR scheduler
* Configure mixed precision if set
* Setup metrics handler (chosen based on the trainer's task, of type `MetricsHandler`)
* Configure paths and loggers
* Setup trainer state
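
Several of these steps are driven directly by fields on the config. For example, seeding and mixed precision can be set like this (a hedged sketch; `seed` is assumed to be a `TrainerConfig` field since the list above references `config.seed`):

```python
from hezar.trainer import TrainerConfig

config = TrainerConfig(
    output_dir="bert-fa-sentiment-analysis-dksf",
    task="text_classification",
    num_epochs=5,
    batch_size=8,
    seed=42,                 # drives the "configure determinism" step
    mixed_precision="fp16",  # enables the "configure precision" step (disabled by default)
    metrics=["f1", "accuracy"],
)
```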
Now the trainer has all the objects ready (almost!).

## Training process
The training process starts right when you call `trainer.train()`. This simple method does all the heavy lifting needed
for a full training run. We'll go through each stage one by one.

In a nutshell, the training process is a repeating loop of training the model on the full train data, evaluating it on
the evaluation data, calculating the metrics, and saving logs and results.
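
Conceptually, `trainer.train()` boils down to a loop like the one below. This is a simplified sketch based on the stages described in the following sections, not the actual implementation; `inner_training_loop` is the method mentioned later in this guide, while the other two names are placeholders:

```python
def train(trainer, num_epochs):
    """Illustrative outline of a full training run (not Hezar's actual code)."""
    for epoch in range(1, num_epochs + 1):
        trainer.inner_training_loop(epoch)         # section 2: one full pass over the training data
        results = trainer.evaluation_loop(epoch)   # section 3: evaluation + metrics (placeholder name)
        trainer.save_and_log(results, epoch)       # section 4: checkpoints, CSV and TensorBoard logs (placeholder name)
```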
### 1. Training info
Right before the trainer starts the main training process, it prints some info about the run.
The output looks something like this:
```
******************** Training Info ********************
Output Directory: `bert-fa-sentiment-analysis-dksf`
Task: `text_classification`
Model: `BertTextClassification`
Init Weights: `hezarai/bert-base-fa`
Device(s): `cuda`
Training Dataset: `TextClassificationDataset(path=hezarai/sentiment-dksf['train'], size=28602)`
Evaluation Dataset: `TextClassificationDataset(path=hezarai/sentiment-dksf['test'], size=2315)`
Optimizer: `adam`
Initial Learning Rate: `2e-05`
Learning Rate Decay: `0.0`
Epochs: `5`
Batch Size: `8`
Number of Parameters: `118299651`
Number of Trainable Parameters: `118299651`
Mixed Precision: `Full (fp32)`
Metrics: `['f1', 'precision', 'accuracy', 'recall']`
Checkpoints Path: `bert-fa-sentiment-analysis-dksf/checkpoints`
Logs Path: `bert-fa-sentiment-analysis-dksf/logs`
*******************************************************
```
### 2. Inner training loop
The inner training loop, invoked by `.inner_training_loop(epoch)`, trains the model on one full iteration of the
training data. This iteration is itself a repeating loop of the steps below (see the sketch after this list):
1. Prepare the input batch using `.prepare_input_batch(input_batch)`, which performs all the necessary validations and
checks on the input batch, like casting the data type, moving tensors to the right device, etc.
2. Perform one training step using `.training_step(input_batch)`, which is the usual PyTorch forward pass followed by
loss computation, a backward pass, and finally an optimizer step. This method returns the loss value along with the model outputs.
3. Aggregate the loss values through the training loop and display the running average loss in the progress bar.
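
For reference, here is what a single training step looks like in generic PyTorch terms. This is a sketch of the pattern described above, not Hezar's exact `training_step`; in particular, it assumes the model returns a dict containing its loss:

```python
def training_step(model, optimizer, input_batch):
    """Generic PyTorch training step: forward pass, loss, backward pass, optimizer step."""
    model.train()
    outputs = model(**input_batch)  # forward pass; `input_batch` is assumed to be a dict of tensors
    loss = outputs["loss"]          # assumes the model computes and returns its own loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), outputs     # loss value along with the model outputs, as described above
```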
### 3. Evaluation loop
The evaluation loop mirrors the training loop, but everything runs in eval mode (under `torch.inference_mode`); a sketch of a single step follows this list:
1. Prepare the input batch using `.prepare_input_batch(input_batch)`, which performs all the necessary validations and
checks on the input batch, like casting the data type, moving tensors to the right device, etc.
2. Perform one evaluation step using `.evaluation_step(input_batch)`:
    - Perform one forward pass on the inputs
    - Calculate the loss on the inputs
    - Calculate generation outputs if the model is generative (`model.is_generative`)
3. Calculate metrics on the outputs, which is handled by the metrics handler (`self.metrics_handler`)
4. Aggregate the metric values through the evaluation loop and display the running averages in the progress bar (the
final value is the total average).
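
And a single evaluation step, again as a generic sketch rather than Hezar's exact code:

```python
import torch

def evaluation_step(model, input_batch):
    """Generic evaluation step: forward pass and loss with gradients disabled."""
    model.eval()
    with torch.inference_mode():
        outputs = model(**input_batch)  # forward pass
        loss = outputs["loss"]          # assumes the model returns its loss in the outputs
        if getattr(model, "is_generative", False):
            # For generative models, generation outputs are also computed (hypothetical call).
            outputs["generated"] = model.generate(**input_batch)
    return loss.item(), outputs
```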
### 4. Saving & logging
When a full training loop and evaluation are done, the trainer saves a few artifacts:
- The trainer state, in a `trainer_state.yaml` file at `config.output_dir/config.checkpoints_dir`
- The current checkpoint (model weights, configs, etc.) at `config.checkpoints_dir/epoch`
- All the new training and evaluation results, written to the CSV file and TensorBoard logs at `config.logs_dir`
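
With the example config from earlier (and the `checkpoints`/`logs` directory names shown in the training info above), the output directory would end up looking roughly like this:

```
bert-fa-sentiment-analysis-dksf/
├── checkpoints/
│   ├── 1/                  # checkpoint for epoch 1 (model weights, configs, etc.)
│   ├── 2/
│   ├── ...
│   └── trainer_state.yaml  # trainer state
└── logs/                   # CSV results and TensorBoard event files
```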
All of the above steps are repeated `config.num_epochs` times.
## Save & Push to Hub | ||
Like all other components in Hezar, you can also save or push the trainer to the Hub. | ||
|
||
### Save Trainer | ||
In order to save the trainer, you can simply call `trainer.save()` which accepts the following: | ||
- `path`: The target directory to save the objects in the trainer | ||
- `config_filename` | ||
- `model_filename` | ||
- `model_config_filename` | ||
- `subfolder` | ||
- `dataset_config_file` | ||
|
||
Overall, the `save` method saves the model, preprocessor, dataset config and the trainer config. | ||
|
||
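
For example (a hedged usage sketch, assuming the filename arguments above fall back to their defaults when not provided):

```python
# Saves the model, preprocessor, dataset config, and trainer config into one directory.
trainer.save("saved/bert-fa-sentiment-analysis-dksf")
```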
### Push to Hub
You can also push the trainer to the Hub; doing so simply calls the `push_to_hub` method on the model, preprocessor, and configs.
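
A hedged example, assuming the trainer's `push_to_hub` takes a Hub repo ID like other Hezar components (the repo ID below is a placeholder):

```python
# Pushes the model, preprocessor, and configs managed by this trainer to the Hub.
trainer.push_to_hub("my-username/bert-fa-sentiment-analysis-dksf")
```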
## Advanced Training & Customization
The Trainer is implemented in a flexible, customizable way, so most changes can be made by simply overriding the
method you need. You can learn more about how to do this [here](advanced_training.md).
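
For instance, overriding `training_step` (mentioned earlier in this guide) might look like the sketch below; the exact signatures and return values to respect are covered in the advanced training guide:

```python
from hezar.trainer import Trainer


class MyTrainer(Trainer):
    def training_step(self, input_batch):
        # Wrap the default step with custom behavior, e.g. extra logging or gradient tweaks.
        outputs = super().training_step(input_batch)
        return outputs


# Used exactly like the base Trainer (reusing the objects from the initialization example).
trainer = MyTrainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
    preprocessor=preprocessor,
)
trainer.train()
```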