Commit df0ab1f

📝 Fix some docs

arxyzan committed Jun 12, 2024
1 parent 0524341 commit df0ab1f
Showing 5 changed files with 32 additions and 22 deletions.
docs/tutorial/training/image_captioning.md (4 changes: 2 additions & 2 deletions)
@@ -41,8 +41,8 @@ from hezar.utils import shift_tokens_right


class Flickr30kDataset(ImageCaptioningDataset):
-    def __init__(self, config: ImageCaptioningDatasetConfig, split=None, **kwargs):
-        super().__init__(config=config, split=split, **kwargs)
+    def __init__(self, config: ImageCaptioningDatasetConfig, split=None, preprocessor=None, **kwargs):
+        super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs)

    # Override the `_load` method (originally loads a dataset from the Hub) to load the csv file
    def _load(self, split=None):
docs/tutorial/training/license_plate_recognition.md (26 changes: 12 additions & 14 deletions)
@@ -3,6 +3,13 @@ Image to text is the task of generating text from an image, e.g., image captioning
In Hezar, the `image2text` task is responsible for all of those, which currently includes image captioning and OCR.

In this tutorial, we'll finetune a base OCR model (CRNN) on a license plate recognition dataset.
+```python
+from hezar.data import Dataset
+from hezar.models import CRNNImage2TextConfig, CRNNImage2Text
+from hezar.preprocessors import Preprocessor
+
+base_model_path = "hezarai/crnn-fa-printed-96-long"
+```

## Dataset
In Hezar, there are two types of `image2text` datasets: `OCRDataset` and `ImageCaptioningDataset`. The reason is that
@@ -13,27 +20,22 @@ but Transformer-based models like `ViTRoberta` require the labels as token ids.
We do provide a pretty solid ALPR dataset at [hezarai/persian-license-plate-v1](https://huggingface.co/datasets/hezarai/persian-license-plate-v1)
which you can load as easily as:
```python
-from hezar.data import Dataset
-from hezar.preprocessors import Preprocessor
-
-dataset_id = "hezarai/persian-license-plate-v1"
-max_length = 8
-reverse_digits = True
image_processor_config = {"size": (384, 32)}
-
train_dataset = Dataset.load(
    "hezarai/persian-license-plate-v1",
    split="train",
+    preprocessor=base_model_path,
    max_length=8,
    reverse_digits=True,
    image_processor_config=image_processor_config,
)
eval_dataset = Dataset.load(
    "hezarai/persian-license-plate-v1",
    split="test",
+    preprocessor=base_model_path,
    max_length=8,
    reverse_digits=True,
    image_processor_config=image_processor_config,
)
```
- License plates have only 8 characters, so we set `max_length=8`, which makes the dataset remove longer/shorter samples
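
As a quick sanity check that the filtering took effect, the loaded objects can be indexed like any PyTorch dataset; a minimal sketch (the exact structure of a returned sample depends on the dataset class, so the comments are illustrative):

```python
print(len(train_dataset))  # sample count after dropping plates longer/shorter than max_length
sample = train_dataset[0]  # one preprocessed sample, ready for the data collator
```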
@@ -55,12 +57,13 @@ Since we're customizing an `image2text` dataset, we can override the `OCRDataset
Suppose you have a CSV file of your dataset with two columns: `image_path` and `text`.

```python
+import pandas as pd
from hezar.data import OCRDataset, OCRDatasetConfig


class ALPRDataset(OCRDataset):
-    def __init__(self, config: OCRDatasetConfig, split=None, **kwargs):
-        super().__init__(config=config, split=split, **kwargs)
+    def __init__(self, config: OCRDatasetConfig, split=None, preprocessor=None, **kwargs):
+        super().__init__(config=config, split=split, preprocessor=preprocessor, **kwargs)

    # Override the `_load` method (originally loads a dataset from the Hub) to load the csv file
    def _load(self, split=None):
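
The body of `_load` is collapsed in this diff. As an illustration only, an override along these lines could read such a file (the CSV path below is a hypothetical placeholder):

```python
    # A minimal sketch of the override, assuming a CSV with `image_path` and `text` columns
    def _load(self, split=None):
        data = pd.read_csv("alpr_dataset.csv")  # hypothetical path; `pd` is pandas, imported above
        return data
```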
@@ -91,11 +94,6 @@ You can customize this class further according to your needs.
For the model, we'll use the `CRNN` model with pretrained weights from `hezarai/crnn-fa-printed-96-long`, which was trained
on a large Persian corpus with millions of synthetic samples.
```python
-from hezar.models import CRNNImage2TextConfig, CRNNImage2Text
-from hezar.preprocessors import Preprocessor
-
-base_model_path = "hezarai/crnn-fa-printed-96-long"
-
model_config = CRNNImage2TextConfig.load(base_model_path, id2label=train_dataset.config.id2label)
model = CRNNImage2Text(model_config)
preprocessor = Preprocessor.load(base_model_path)
docs/tutorial/training/sequence_labeling.md (4 changes: 2 additions & 2 deletions)
@@ -32,8 +32,8 @@ custom datasets refer to [this tutorial]().
Loading Hezar datasets is pretty straightforward:

```python
train_dataset = Dataset.load("hezarai/lscp-pos-500k", split="train", tokenizer_path=base_model_path)
eval_dataset = Dataset.load("hezarai/lscp-pos-500k", split="test", tokenizer_path=base_model_path)
train_dataset = Dataset.load("hezarai/lscp-pos-500k", split="train", preprocessor=base_model_path)
eval_dataset = Dataset.load("hezarai/lscp-pos-500k", split="test", preprocessor=base_model_path)
```
What are these objects? Well, they're PyTorch `Dataset` instances wrapped by Hezar's
`SequenceLabelingDataset` class (a subclass of `hezar.data.datasets.Dataset`).
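
Since they follow the regular PyTorch `Dataset` protocol, you can inspect them directly; a minimal sketch (what a sample contains depends on `SequenceLabelingDataset`, so the comments are illustrative):

```python
print(len(train_dataset))  # number of samples in the split
sample = train_dataset[0]  # a single tokenized sample with its labels
```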
docs/tutorial/training/text_classification.md (4 changes: 2 additions & 2 deletions)
@@ -46,8 +46,8 @@ from hezar.data import TextClassificationDataset, TextClassificationDatasetConfig
class SentimentAnalysisDataset(TextClassificationDataset):
    id2label = {0: "negative", 1: "positive", 2: "neutral"}

-    def __init__(self, config: TextClassificationDatasetConfig, split=None, **kwargs):
-        super().__init__(config, split=split, **kwargs)
+    def __init__(self, config: TextClassificationDatasetConfig, split=None, preprocessor=None, **kwargs):
+        super().__init__(config, split=split, preprocessor=preprocessor, **kwargs)

    def _load(self, split):
        # Load a dataframe here and make sure the split is fetched
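
As an illustration of what that comment describes, a `_load` along these lines would work, assuming `pandas` is imported as `pd` and a local CSV with a `split` column (the file name and column layout are hypothetical):

```python
    def _load(self, split):
        data = pd.read_csv("sentiment_data.csv")  # hypothetical path
        data = data[data["split"] == split]       # keep only the requested split
        return data
```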
docs/tutorial/training/trainer.md (16 changes: 14 additions & 2 deletions)
@@ -16,8 +16,20 @@ from hezar.trainer import Trainer, TrainerConfig

base_model_path = "hezarai/crnn-fa-printed-96-long"

train_dataset = Dataset.load("hezarai/persian-license-plate-v1", split="train", max_length=8, reverse_digits=True)
eval_dataset = Dataset.load("hezarai/persian-license-plate-v1", split="test", max_length=8, reverse_digits=True)
train_dataset = Dataset.load(
"hezarai/persian-license-plate-v1",
split="train",
preprocessor=base_model_path,
max_length=8,
reverse_digits=True,
)
eval_dataset = Dataset.load(
"hezarai/persian-license-plate-v1",
split="test",
preprocessor=base_model_path,
max_length=8,
reverse_digits=True,
)

model = CRNNImage2Text(CRNNImage2TextConfig(id2label=train_dataset.config.id2label))
preprocessor = Preprocessor.load(base_model_path)
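
The remainder of the file is collapsed here. For context, wiring these pieces into a `Trainer` usually looks something like the sketch below; the specific `TrainerConfig` field values are assumptions for illustration, not taken from this diff:

```python
train_config = TrainerConfig(
    output_dir="crnn-fa-license-plates",  # hypothetical output directory
    task="image2text",
    batch_size=8,
    num_epochs=5,
)
trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    preprocessor=preprocessor,
)
trainer.train()
```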
