-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
71 lines (57 loc) · 1.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python
import torch
import typer
from allennlp.commands.train import train_model_from_file
from ger_wiki.batch_predictor import RunBatchPredictions
# Module-level Typer application object.
# NOTE(review): nothing in the visible file registers a command on `app` or
# invokes it -- the script entry point calls `typer.run(main)` instead. Confirm
# whether it is still needed (e.g. imported from another module) before removing.
app = typer.Typer()
def train_model(name: str):
    """Train the model described by the config file ``./configs/<name>.jsonnet``.

    Announces the run on the console, then delegates to AllenNLP's
    ``train_model_from_file``. Results are serialized to
    ``./models/<name>_model``.

    :param name: Config name; used to build both the config path and the
        output directory.

    A ``FileNotFoundError`` raised during training is printed and swallowed
    rather than propagated, so the caller keeps running.
    """
    typer.echo(f"Running {name}")

    config_file = f"./configs/{name}.jsonnet"
    output_dir = f"./models/{name}_model"
    extra_packages = ["ger_wiki", "allennlp_models"]

    try:
        train_model_from_file(
            parameter_filename=config_file,
            serialization_dir=output_dir,
            include_package=extra_packages,
            # Presumably lets AllenNLP overwrite an existing serialization
            # directory -- confirm against the AllenNLP docs.
            force=True,
        )
    except FileNotFoundError as err:
        print(err)
# Device index handed to the batch predictor: 0 when torch reports CUDA as
# available, otherwise -1 (presumably the "run on CPU" convention -- confirm
# against the predictor's expectations).
if torch.cuda.is_available():
    cuda_device = 0
else:
    cuda_device = -1
def get_predictions(name: str):
    """Run batch predictions with a trained model and save the results.

    Loads the archive ``./models/<name>_model/model.tar.gz`` through
    ``RunBatchPredictions`` (predictor ``"text_predictor"``), reads the text
    from the ``abs`` column of the wiki CSV, predicts in batches of 8 on the
    module-level ``cuda_device``, and writes the output as both CSV and JSON.

    :param name: Config/model name; selects which model archive is loaded.
    """
    model_archive = f"./models/{name}_model/model.tar.gz"
    wiki_csv = "./data_processing/data/raw/wiki/wiki_info.csv"
    csv_out = "./data_processing/data/results/predictions.csv"
    json_out = "./data_processing/data/results/predictions.json"

    # Label the Wikipedia corpus with the selected model.
    runner = RunBatchPredictions(
        archive_path=model_archive,
        predictor_name="text_predictor",
        text_path=wiki_csv,
        text_col="abs",
        cuda_device=cuda_device,
    )
    runner.run_batch_predictions(batch_size=8)

    # Persist the same predictions in both output formats.
    runner.write_csv(csv_file=csv_out)
    runner.write_json(json_file=json_out)
def main(
    name: str,
    predict: bool = False,
):
    """
    Choose an NER model to train, or use model predictor.
    :param name str: Config name (see configs/), or "all" to train every listed model\n
    :param predict bool: Label Wikipedia corpus using predictor\n
    """
    # The docstring above doubles as the CLI help text, so it only lists
    # options that actually exist on this command. The explicit "\n" escapes
    # are kept from the original; presumably they force line breaks in the
    # rendered help -- confirm before removing them.

    if predict:
        # Prediction takes precedence over training, including for "all".
        get_predictions(name)
        return

    if name == "all":
        model_names = (
            "wiki_bert",
            "wiki_crf_basic",
            "wiki_crf",
            "wiki_distil",
            "wiki_roberta",
        )
    else:
        model_names = (name,)

    # Use a distinct loop variable so the `name` parameter is not rebound.
    for model_name in model_names:
        train_model(model_name)
# Script entry point: when this file is executed directly (not imported),
# hand `main` to `typer.run` so it is exposed as the command-line interface.
if __name__ == "__main__":
    typer.run(main)