-
Notifications
You must be signed in to change notification settings - Fork 2
/
ws_multilingual_exp_sasaki.py
57 lines (44 loc) · 1.65 KB
/
ws_multilingual_exp_sasaki.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging
import multiprocessing as mp
from pathlib import Path
from datasets import prepare_target_vector_paths, prepare_ws_combined_query_path
from sasaki_utils import inference, prepare_codecs_path, train, get_info_from_result_path
from utils import dotdict
from ws_multilingual_exp_pbos import evaluate
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def exp(ref_vec_name):
    """Train, run inference with, and evaluate the "sasaki" model for one set of reference vectors.

    Outputs are placed under ``results/ws_multi/<ref_vec_name>_sasaki``.
    Log records emitted while the experiment runs are written to ``log.txt``
    in that directory.

    Args:
        ref_vec_name: Suffix identifying the reference vectors; they are
            looked up as ``wiki2vec-<ref_vec_name>``.
    """
    result_path = Path("results") / "ws_multi" / f"{ref_vec_name}_sasaki"
    ref_vec_path = prepare_target_vector_paths(f"wiki2vec-{ref_vec_name}").w2v_emb_path
    codecs_path = prepare_codecs_path(ref_vec_path, result_path)

    # prepare_codecs_path may already have created this directory (not
    # visible from here); in that case this call is a no-op.
    result_path.mkdir(parents=True, exist_ok=True)

    with open(result_path / "log.txt", "w+") as log_file:
        # Attach a dedicated handler instead of calling logging.basicConfig:
        # basicConfig does nothing once the root logger has a handler, so a
        # pool worker that runs a second experiment would keep logging into
        # the first experiment's file.
        handler = logging.StreamHandler(log_file)
        # Same record layout that basicConfig would have installed.
        handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)
        root_logger.addHandler(handler)
        try:
            logger.info("Training...")
            train(
                ref_vec_path,
                result_path,
                codecs_path=codecs_path,
                H=40_000,
                F=500_000,
                epoch=300,
            )
            model_info = get_info_from_result_path(result_path / "sep_kvq")

            logger.info("Inferencing...")
            combined_query_path = prepare_ws_combined_query_path(ref_vec_name)
            result_emb_path = inference(model_info, combined_query_path)

            logger.info("Evaluating...")
            evaluate(dotdict(
                model_type="sasaki",
                eval_result_path=result_path / "result.txt",
                pred_path=result_emb_path,
                target_vector_name=ref_vec_name,
                results_dir=result_path,
            ))
        finally:
            # Detach before the file is closed so no later record is written
            # to a closed stream.
            root_logger.removeHandler(handler)
            handler.close()
if __name__ == '__main__':
    # Launch one experiment per reference-vector name in a process pool,
    # then wait for every one of them to finish.
    target_vector_names = ("en", "de", "it", "ru")
    with mp.Pool() as pool:
        pending = []
        for name in target_vector_names:
            pending.append(pool.apply_async(exp, (name,)))
        # get() blocks until the task is done and re-raises any exception
        # raised inside the worker.
        for task in pending:
            task.get()