This repository contains the code and resources for text classification using state-of-the-art Transformer models such as BERT and XLNet. You can use this GitHub repository to fine-tune different transformer models on your own dataset for the task of sequence classification.
For the experimental part, we propose a query gender classifier that is capable of identifying the gender of each sequence. As a first step, and in order to label sequences by gender at scale, we employed the gender-annotated dataset released by Navid Rekabsaz to train the classifiers. This dataset consists of 742 female, 1,202 male, and 1,765 neutral queries. We trained various types of classifiers on this dataset and, to evaluate their performance, adopted a 5-fold cross-validation strategy. The results are reported below.
| Classifier | Accuracy | F1-Score (Female) | F1-Score (Male) | F1-Score (Neutral) |
|---|---|---|---|---|
| BERT (base uncased) | 0.856 | 0.816 | 0.872 | 0.862 |
| DistilBERT (base uncased) | 0.847 | 0.815 | 0.861 | 0.853 |
| RoBERTa | 0.810 | 0.733 | 0.820 | 0.836 |
| DistilBERT (base cased) | 0.800 | 0.730 | 0.823 | 0.833 |
| BERT (base cased) | 0.797 | 0.710 | 0.805 | 0.827 |
| XLNet (base cased) | 0.795 | 0.710 | 0.805 | 0.826 |
You can find the BERT model fine-tuned on the gender-annotated dataset here, and use `inference.py` to predict the gender of sequences.
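If you would rather call the fine-tuned model directly from Python instead of going through `inference.py`, the snippet below is a minimal sketch using the `transformers` library. The checkpoint path and the label-index order (female/male/neutral) are assumptions for illustration; verify them against the actual fine-tuned model.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical path to the fine-tuned checkpoint directory
checkpoint = "experiments/bert-base-uncased-gender"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.eval()

# Assumed label order; check it against the mapping used during training
labels = ["female", "male", "neutral"]

query = "hillary clinton biography"
inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=64)
with torch.no_grad():
    logits = model(**inputs).logits
print(labels[logits.argmax(dim=-1).item()])  # predicted gender of the query
```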
All models on the Hugging Face Hub that support `AutoModelForSequenceClassification` are supported by this repository and can be used by setting the model parameter of `train.py` to the appropriate model name. Some of them are listed below; the others can be found on the Hugging Face website.
```python
Models = {
    "BERT base uncased": "bert-base-uncased",
    "BERT base cased": "bert-base-cased",
    "BERT large uncased": "bert-large-uncased",
    "BERT large cased": "bert-large-cased",
    "XLNet": "xlnet-base-cased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased"
}
```
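As a quick compatibility check for any other model name, you can try instantiating it yourself. A minimal sketch, with the three labels mirroring the gender-annotated dataset:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "roberta-base"  # any identifier from the dict above or the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
```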
In order to fine-tune any of the transformer models on your dataset, you can execute the following bash script:
```bash
bash train.sh
```
Please note that before executing the script, you need to set the following file paths and parameters inside it (a sample invocation is shown after the list):
- `-model`: model type, e.g. BERT, XLNet, RoBERTa, DistilBERT, etc.
- `-train`: path to the training dataset.
- `-dev`: path to the validation dataset; omit it if you do not want to validate the trained model.
- `-res`: path to the results directory. If not defined, the model will be saved in the experiments directory.
- `-max_sequence_len`: maximum length of sequence tokens.
- `-epoch`: number of training epochs.
- `-train_batch_size`: training batch size.
- `-valid_batch_size`: validation batch size.
- `-lr`: learning rate.
- `-n_warmup_steps`: number of warmup steps.
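For example, assuming `train.sh` wraps `train.py` as its entry point, the training command inside it might look like `python train.py -model BERT -train data/train.tsv -dev data/dev.tsv -res experiments/ -max_sequence_len 64 -epoch 5 -train_batch_size 32 -valid_batch_size 32 -lr 2e-5 -n_warmup_steps 100`, where every flag value is purely illustrative.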
In order to run inference with the fine-tuned models, you can execute the following bash script:
```bash
bash inference.sh
```
Please note that before executing the script, you need to set the following file paths and parameters inside it (a sample invocation is shown after the list):
- `-model`: model type, e.g. BERT, XLNet, RoBERTa, DistilBERT, etc.
- `-test`: path to the test dataset.
- `-checkpoint`: path to the fine-tuned model checkpoint.
- `-res`: path to the results directory.
- `-num_labels`: number of classes.
- `-max_sequence_len`: maximum length of sequence tokens.
- `-test_batch_size`: test batch size.
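For example, assuming `inference.sh` wraps `inference.py`, the command inside it might look like `python inference.py -model BERT -test data/test.tsv -checkpoint experiments/model.ckpt -res results/ -num_labels 3 -max_sequence_len 64 -test_batch_size 32`, with all values purely illustrative.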
Both the `train.py` and `inference.py` scripts receive datasets as tab-separated (.tsv) files with the following format:
| File | Format (.tsv) |
|---|---|
| train | id \t sequence \t label |
| dev | id \t sequence \t label |
| test | id \t sequence |
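If you are preparing these files from Python, here is a minimal sketch that writes a training file in the expected layout. The example rows are made up, and the string labels are an assumption; the scripts may expect a different label encoding (e.g. integers).

```python
import csv

# Made-up rows in the expected (id, sequence, label) layout
rows = [
    (1, "best colleges for nursing", "female"),
    (2, "nfl draft results", "male"),
    (3, "weather in london", "neutral"),
]

with open("train.tsv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerows(rows)
```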
In order to compare the performance of different transformer models on your dataset, you can perform k-fold cross-validation by executing the following bash script:
```bash
bash cross_validation.sh
```
By default, the dataset is the gender-annotated dataset and the results are written to the /data/cross_validation_path folder.
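If you want to reproduce the 5-fold protocol outside of `cross_validation.sh`, the following is an independent sketch of the splitting step using scikit-learn; it is not the code the script actually runs, the toy data stands in for the real queries, and the stratification (to keep class ratios balanced across folds) is an added assumption.

```python
from sklearn.model_selection import StratifiedKFold

# Toy stand-ins for the queries and their gender labels (15 examples, 5 per class)
sequences = [f"query {i}" for i in range(15)]
labels = ["female", "male", "neutral"] * 5

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, test_idx) in enumerate(skf.split(sequences, labels)):
    # Fine-tune a fresh classifier on the train split and score it on the test split here
    print(f"fold {fold}: {len(train_idx)} train / {len(test_idx)} test examples")
```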