#!/usr/bin/env python
"""Masked language modeling training script."""
import json
import os
import sys
from typing import Optional

import click
from loguru import logger

from rxn_aa_mapper.dataset import DATASETS
from rxn_aa_mapper.training import train

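# Illustrative invocation (all paths and the file pattern below are placeholders,
# not taken from the repository):
#   ./mlm-trainer data/train data/val vocabulary.txt logs/ config.json "*.csv" 4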
@click.command()
@click.argument("train_dataset_dir", type=click.Path(exists=True))
@click.argument("val_dataset_dir", type=click.Path(exists=True))
@click.argument("vocabulary_file", type=click.Path(exists=True))
@click.argument("log_dir", type=click.Path())
@click.argument("config_file", type=click.Path(exists=True))
@click.argument("pattern", type=str)
@click.argument("processes", type=int, default=4)
@click.argument(
    "train_organic_dataset_dir", type=click.Path(exists=True), required=False
)
@click.argument(
    "val_organic_dataset_dir", type=click.Path(exists=True), required=False
)
@click.argument(
    "aa_sequence_tokenizer_filepath", type=click.Path(exists=False), required=False
)
@click.argument("aa_sequence_tokenizer_type", type=str, required=False)
def main(
    train_dataset_dir: str,
    val_dataset_dir: str,
    vocabulary_file: str,
    log_dir: str,
    config_file: str,
    pattern: str,
    processes: int,
    train_organic_dataset_dir: Optional[str] = None,
    val_organic_dataset_dir: Optional[str] = None,
    aa_sequence_tokenizer_filepath: Optional[str] = None,
    aa_sequence_tokenizer_type: Optional[str] = None,
) -> None:
"""Train a model via masked language modeling."""
with open(config_file) as fp:
config = json.load(fp)
model_args, dataset_args, trainer_args = (
config["model"],
config["dataset"],
config["trainer"],
)
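    # Illustrative shape of the configuration file (the nested values are
    # placeholders, not taken from the repository):
    #   {"model": {"architecture": {...}, ...}, "dataset": {"dataset_type": "enzymatic", ...}, "trainer": {...}}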
    dataset_type = dataset_args.get("dataset_type", "enzymatic")
    supported_datasets = set(DATASETS.keys())
    if dataset_type not in supported_datasets:
        logger.error(
            f"unsupported dataset type: {dataset_type}, expected one of: {supported_datasets}"
        )
        sys.exit(1)
    if dataset_type == "enzymatic-organic" and (
        train_organic_dataset_dir is None or val_organic_dataset_dir is None
    ):
        logger.error(
            "the `enzymatic-organic` dataset type requires both `train_organic_dataset_dir` and `val_organic_dataset_dir` to be provided"
        )
        sys.exit(1)
    logger.info(f"saving run logs to: {log_dir}")
    trainer_args["log_dir"] = log_dir
    # extract the architecture block; pop avoids a KeyError when it is absent
    # and keeps it out of model_args, since it is passed to train() separately
    model_architecture = model_args.pop("architecture", {})
    dataset_args["num_dataloader_workers"] = min(processes, os.cpu_count())
    # command-line values take precedence over the corresponding config-file entries
    dataset_args = {
        **dataset_args,
        "train_dataset_dir": train_dataset_dir,
        "val_dataset_dir": val_dataset_dir,
        "vocabulary_file": vocabulary_file,
        "aa_sequence_tokenizer_filepath": aa_sequence_tokenizer_filepath,
        "aa_sequence_tokenizer_type": aa_sequence_tokenizer_type,
        "pattern": pattern,
        "train_organic_dataset_dir": train_organic_dataset_dir,
        "val_organic_dataset_dir": val_organic_dataset_dir,
        "dataset_type": dataset_type,
    }
if "shuffle" not in dataset_args:
dataset_args["shuffle"] = True
if "seed" not in dataset_args:
dataset_args["seed"] = None
if "samples_per_epoch" not in dataset_args:
dataset_args["samples_per_epoch"] = 10000000
    train(model_args, model_architecture, dataset_args, trainer_args)


if __name__ == "__main__":
    main()