diff --git a/docs/tutorial/training/image_captioning.md b/docs/tutorial/training/image_captioning.md index 0560b113..4d1456ae 100644 --- a/docs/tutorial/training/image_captioning.md +++ b/docs/tutorial/training/image_captioning.md @@ -41,8 +41,8 @@ from hezar.utils import shift_tokens_right class Flickr30kDataset(ImageCaptioningDataset): - def __init__(self, config: ImageCaptioningDatasetConfig, split=None, **kwargs): - super().__init__(config=config, split=split, **kwargs) + def __init__(self, config: ImageCaptioningDatasetConfig, split=None, preprocessor=None, **kwargs): + super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs) # Override the `_load` method (originally loads a dataset from the Hub) to load the csv file def _load(self, split=None): diff --git a/docs/tutorial/training/license_plate_recognition.md b/docs/tutorial/training/license_plate_recognition.md index b0200608..8f35a239 100644 --- a/docs/tutorial/training/license_plate_recognition.md +++ b/docs/tutorial/training/license_plate_recognition.md @@ -3,6 +3,13 @@ Image to text is the task of generating text from an image e.g., image captioning In Hezar, the `image2text` task is responsible for all of those, which currently include image captioning and OCR. In this tutorial, we'll finetune a base OCR model (CRNN) on a license plate recognition dataset. +```python +from hezar.data import Dataset +from hezar.models import CRNNImage2TextConfig, CRNNImage2Text +from hezar.preprocessors import Preprocessor + +base_model_path = "hezarai/crnn-fa-printed-96-long" +``` ## Dataset In Hezar, there are two types of `image2text` datasets: `OCRDataset` and `ImageCaptioningDataset`. The reason is that @@ -13,27 +20,22 @@ but Transformer-based models like `ViTRoberta` require the labels as token ids. 
We do provide a pretty solid ALPR dataset at [hezarai/persian-license-plate-v1](https://huggingface.co/datasets/hezarai/persian-license-plate-v1) which you can load as easily as: ```python -from hezar.data import Dataset -from hezar.preprocessors import Preprocessor - -dataset_id = "hezarai/persian-license-plate-v1" max_length = 8 reverse_digits = True -image_processor_config = {"size": (384, 32)} train_dataset = Dataset.load( "hezarai/persian-license-plate-v1", split="train", + preprocessor=base_model_path, max_length=8, reverse_digits=True, - image_processor_config=image_processor_config, ) eval_dataset = Dataset.load( "hezarai/persian-license-plate-v1", split="test", + preprocessor=base_model_path, max_length=8, reverse_digits=True, - image_processor_config=image_processor_config, ) ``` - License plates have only 8 characters so we set the max_length=8 which makes the dataset remove longer/shorter samples @@ -55,12 +57,13 @@ Since we're customizing an `image2text` dataset, we can override the `OCRDataset Let's consider you have a CSV file of your dataset with two columns: `image_path`, `text`. ```python +import pandas as pd from hezar.data import OCRDataset, OCRDatasetConfig class ALPRDataset(OCRDataset): - def __init__(self, config: OCRDatasetConfig, split=None, **kwargs): - super().__init__(config=config, split=split, **kwargs) + def __init__(self, config: OCRDatasetConfig, split=None, preprocessor=None, **kwargs): + super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs) # Override the `_load` method (originally loads a dataset from the Hub) to load the csv file def _load(self, split=None): @@ -91,11 +94,6 @@ You can customize this class further according to your needs. For the model we'll use the `CRNN` model with pretrained weights from `hezarai/crnn-fa-printed-96-long` which was trained on a large Persian corpus with millions of synthetic samples. 
```python -from hezar.models import CRNNImage2TextConfig, CRNNImage2Text -from hezar.preprocessors import Preprocessor - -base_model_path = "hezarai/crnn-fa-printed-96-long" - model_config = CRNNImage2TextConfig.load(base_model_path, id2label=train_dataset.config.id2label) model = CRNNImage2Text(model_config) preprocessor = Preprocessor.load(base_model_path) diff --git a/docs/tutorial/training/sequence_labeling.md b/docs/tutorial/training/sequence_labeling.md index 452ffd03..0d6958c9 100644 --- a/docs/tutorial/training/sequence_labeling.md +++ b/docs/tutorial/training/sequence_labeling.md @@ -32,8 +32,8 @@ custom datasets refer to [this tutorial](). Loading Hezar datasets is pretty straightforward: ```python -train_dataset = Dataset.load("hezarai/lscp-pos-500k", split="train", tokenizer_path=base_model_path) -eval_dataset = Dataset.load("hezarai/lscp-pos-500k", split="test", tokenizer_path=base_model_path) +train_dataset = Dataset.load("hezarai/lscp-pos-500k", split="train", preprocessor=base_model_path) +eval_dataset = Dataset.load("hezarai/lscp-pos-500k", split="test", preprocessor=base_model_path) ``` What are these objects? Well, these are basically PyTorch Dataset instances which are actually wrapped by Hezar's `SequenceLabelingDataset` class (a subclass of `hezar.data.datasets.Dataset`). 
diff --git a/docs/tutorial/training/text_classification.md b/docs/tutorial/training/text_classification.md index 849f77f0..98ae4fb1 100644 --- a/docs/tutorial/training/text_classification.md +++ b/docs/tutorial/training/text_classification.md @@ -46,8 +46,8 @@ from hezar.data import TextClassificationDataset, TextClassificationDatasetConfi class SentimentAnalysisDataset(TextClassificationDataset): id2label = {0: "negative", 1: "positive", 2: "neutral"} - def __init__(self, config: TextClassificationDatasetConfig, split=None, **kwargs): - super().__init__(config, split=split, **kwargs) + def __init__(self, config: TextClassificationDatasetConfig, split=None, preprocessor=None, **kwargs): + super().__init__(config, split=split, preprocessor=preprocessor, **kwargs) def _load(self, split): # Load a dataframe here and make sure the split is fetched diff --git a/docs/tutorial/training/trainer.md b/docs/tutorial/training/trainer.md index 2985ef9a..72e5e5ca 100644 --- a/docs/tutorial/training/trainer.md +++ b/docs/tutorial/training/trainer.md @@ -16,8 +16,20 @@ from hezar.trainer import Trainer, TrainerConfig base_model_path = "hezarai/crnn-fa-printed-96-long" -train_dataset = Dataset.load("hezarai/persian-license-plate-v1", split="train", max_length=8, reverse_digits=True) -eval_dataset = Dataset.load("hezarai/persian-license-plate-v1", split="test", max_length=8, reverse_digits=True) +train_dataset = Dataset.load( + "hezarai/persian-license-plate-v1", + split="train", + preprocessor=base_model_path, + max_length=8, + reverse_digits=True, +) +eval_dataset = Dataset.load( + "hezarai/persian-license-plate-v1", + split="test", + preprocessor=base_model_path, + max_length=8, + reverse_digits=True, +) model = CRNNImage2Text(CRNNImage2TextConfig(id2label=train_dataset.config.id2label)) preprocessor = Preprocessor.load(base_model_path)