-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinstall_nlp_models.py
110 lines (85 loc) · 3.46 KB
/
install_nlp_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Install the default NLP models defined in the provided yaml file."""
import argparse
import logging
from typing import Dict, Union
import yaml
from spacy.cli import download as spacy_download
try:
import stanza
except ImportError:
# stanza should be installed manually
stanza = None
try:
import transformers
from huggingface_hub import snapshot_download
from transformers import AutoModelForTokenClassification, AutoTokenizer
except ImportError:
# transformers should be installed manually
transformers = None
# Configure the root logger for INFO-level console output so model-download
# progress is visible while this installation script runs.
logger = logging.getLogger()
logger.setLevel("INFO")
logger.addHandler(logging.StreamHandler())
def install_models(conf_file: str) -> None:
    """Installs models in conf/default.yaml.

    :param conf_file: Path to the yaml file containing the models to install.
    See examples in the conf directory.
    :raises ValueError: If the configuration lacks an ``nlp_engine_name``
        field or a ``models`` list.
    """
    # Use a context manager so the config file handle is closed deterministically;
    # the original `yaml.safe_load(open(conf_file))` left it open until GC.
    with open(conf_file) as config_file:
        nlp_configuration = yaml.safe_load(config_file)

    logger.info(f"Installing models from configuration: {nlp_configuration}")

    if "nlp_engine_name" not in nlp_configuration:
        raise ValueError("NLP config file should contain an nlp_engine_name field")

    if "models" not in nlp_configuration:
        raise ValueError("NLP config file should contain a list of models")

    # The engine name is loop-invariant; look it up once.
    engine_name = nlp_configuration["nlp_engine_name"]
    for model in nlp_configuration["models"]:
        _download_model(engine_name, model["model_name"])

    logger.info("finished installing models")
def _download_model(engine_name: str, model_name: Union[str, Dict[str, str]]) -> None:
if engine_name == "spacy":
spacy_download(model_name)
elif engine_name == "stanza":
if stanza:
stanza.download(model_name)
else:
raise ImportError("stanza is not installed")
elif engine_name == "transformers":
if transformers:
_install_transformers_spacy_models(model_name)
else:
raise ImportError("transformers is not installed")
else:
raise ValueError(f"Unsupported nlp engine: {engine_name}")
def _install_transformers_spacy_models(model_name: Dict[str, str]) -> None:
if "spacy" not in model_name:
raise ValueError(
"transformers config should contain "
"a spacy model/pipeline such as en_core_web_sm"
)
if "transformers" not in model_name:
raise ValueError(
"transformers config should contain a path to a transformers model"
)
spacy_model = model_name["spacy"]
transformers_model = model_name["transformers"]
# download spacy model/pipeline
logger.info(f"Installing spaCy model: {spacy_model}")
spacy_download(spacy_model)
# download transformers model
logger.info(f"Installing transformers model: {transformers_model}")
snapshot_download(repo_id=transformers_model)
# Instantiate to make sure it's downloaded during installation and not runtime
AutoTokenizer.from_pretrained(transformers_model)
AutoModelForTokenClassification.from_pretrained(transformers_model)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Install NLP models into the presidio-analyzer Docker container"
)
parser.add_argument(
"--conf_file",
required=False,
default="presidio_analyzer/conf/default.yaml",
help="Location of nlp configuration yaml file. Default: conf/default.yaml",
)
args = parser.parse_args()
install_models(conf_file=args.conf_file)