This repository has been archived by the owner on Dec 11, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
74 lines (58 loc) · 2.23 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import torch
# Lightning: Trainer drives the loop; seed_everything pins RNG state for reproducibility.
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.plugins.training_type import DDPPlugin
# HuggingFace transformers' logging module (not the stdlib `logging`).
from transformers import logging
# Project-local package: data module, LightningModule, and a custom progress bar.
from detector import SarcasmDataModule, SarcasmDetector, SarcasmProgressBar
from detector.util import get_device_count
# Silence transformers' warning chatter (e.g. weight-init notices) during training.
logging.set_verbosity_error()
def _build_callbacks():
    """Assemble the Trainer callbacks: checkpointing, LR logging, early stop, progress bar."""
    checkpoint = ModelCheckpoint(
        filename='{epoch}-{val_loss:.2f}-{val_f1:.2f}-{val_accuracy:.2f}',
        monitor='val_f1',
        mode='max',
        save_last=False,
        save_top_k=1,       # keep only the single best checkpoint by val_f1
        every_n_epochs=1
    )
    lr_monitor = LearningRateMonitor(
        logging_interval='step',
        log_momentum=True
    )
    early_stop = EarlyStopping(
        monitor='val_f1',
        min_delta=0.0,
        patience=3,         # stop after 3 epochs without val_f1 improvement
        mode='max'
    )
    return [checkpoint, lr_monitor, early_stop, SarcasmProgressBar()]


def main(args: argparse.Namespace):
    """Train a SarcasmDetector from parsed CLI args; optionally run the test loop.

    Args:
        args: Fully parsed namespace combining data-module, model, and Trainer
            options (see the argparse wiring in the __main__ block).
    """
    # Seed all RNGs (including dataloader workers) when deterministic mode is on.
    if args.deterministic:
        seed_everything(args.seed, workers=True)

    data_module = SarcasmDataModule.from_argparse_args(args)

    # DDP is only needed when more than one device participates overall.
    if get_device_count(args.devices) * args.num_nodes > 1:
        # find_unused_parameters=False avoids the per-step graph scan overhead.
        strategy = DDPPlugin(find_unused_parameters=False)
    else:
        strategy = None

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=_build_callbacks(),
        strategy=strategy
    )

    detector = SarcasmDetector(**vars(args))
    trainer.fit(detector, datamodule=data_module)

    # Evaluate on the test split only when explicitly requested.
    if args.test:
        trainer.test(detector)
if __name__ == "__main__":
    # conflict_handler='resolve' lets later add_argparse_args calls override
    # earlier flags of the same name instead of raising.
    parser = argparse.ArgumentParser(conflict_handler='resolve')

    # General args
    parser.add_argument('--seed', help="Seed to use if deterministic flag is set", default=3244, type=int)
    parser.add_argument('--test', help="Perform testing at the end of training", action='store_true')

    # Let each component register its own CLI options.
    for component in (SarcasmDataModule, SarcasmDetector, Trainer):
        parser = component.add_argparse_args(parser)

    # Prefer bfloat16 mixed precision when the GPU supports it; fall back to fp16.
    default_precision = 'bf16' if torch.cuda.is_bf16_supported() else 16
    parser.set_defaults(
        accelerator='auto',
        devices=1,
        enable_model_summary=False,
        precision=default_precision,
        max_epochs=4
    )

    main(parser.parse_args())