-
Notifications
You must be signed in to change notification settings - Fork 46
/
encode.py
67 lines (47 loc) · 1.96 KB
/
encode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import hydra
import hydra.utils as utils
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch
from model import Encoder
def _save_txt(path, array):
    """Write ``array`` to ``path`` with ``np.savetxt`` using 16 decimal places."""
    with open(path, "w") as file:
        np.savetxt(file, array, fmt="%.16f")


@hydra.main(config_path="config/encode.yaml")
def encode_dataset(cfg):
    """Encode every utterance listed in the test split and dump the codes as text.

    For each entry of ``datasets/<cfg.dataset.path>/test.json`` this loads the
    ``.mel.npy`` features next to the listed path, runs them through
    ``Encoder.encode`` and writes the returned ``z`` to
    ``<cfg.out_dir>/<stem>.txt``.  When ``cfg.save_auxiliary`` is true, the
    output of the last module in ``encoder.encoder`` is captured with a forward
    hook (transposed on dims 1 and 2) and written to
    ``<parent of cfg.out_dir>/auxiliary_embedding1/<stem>.txt``.

    Args:
        cfg: Hydra config (injected by ``@hydra.main``). Fields read here:
            ``out_dir``, ``save_auxiliary``, ``dataset.path``,
            ``model.encoder`` (keyword arguments for ``Encoder``) and
            ``checkpoint`` (a file containing an ``"encoder"`` state dict).
    """
    out_dir = Path(utils.to_absolute_path(cfg.out_dir))
    out_dir.mkdir(exist_ok=True, parents=True)
    if cfg.save_auxiliary:
        aux_path = out_dir.parent / "auxiliary_embedding1"
        aux_path.mkdir(exist_ok=True, parents=True)

    root_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    with open(root_path / "test.json") as file:
        metadata = json.load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    encoder.to(device)

    print("Load checkpoint from: {}:".format(cfg.checkpoint))
    checkpoint_path = utils.to_absolute_path(cfg.checkpoint)
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    # The map_location lambda keeps every tensor on CPU while loading.
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    encoder.load_state_dict(checkpoint["encoder"])
    encoder.eval()

    if cfg.save_auxiliary:
        auxiliary = []

        def hook(_module, _input, output):
            # Clone so the saved activation is independent of the module's buffer.
            auxiliary.append(output.clone().transpose(1, 2))

        encoder.encoder[-1].register_forward_hook(hook)

    # Each metadata entry is a 4-item record; only the last field (the
    # utterance path relative to the dataset root's parent) is used here.
    for _, _, _, path in tqdm(metadata):
        path = root_path.parent / path
        mel = np.load(path.with_suffix(".mel.npy"))
        # unsqueeze(0) adds a batch dimension of size 1.
        mel = torch.from_numpy(mel).unsqueeze(0).to(device)
        with torch.no_grad():
            z, _ = encoder.encode(mel)
        # Drop only the batch dimension; a bare squeeze() would also collapse a
        # length-1 time axis and change the layout of the written file.
        z = z.squeeze(0).cpu().numpy()

        # Build the file name from the full stem; with_suffix() would cut off
        # anything after a dot inside the stem and could make names collide.
        file_name = path.stem + ".txt"
        _save_txt(out_dir / file_name, z)

        if cfg.save_auxiliary:
            # The hook fired once during encode(); take that activation.
            aux = auxiliary.pop().squeeze(0).cpu().numpy()
            _save_txt(aux_path / file_name, aux)
if __name__ == "__main__":
    # @hydra.main builds `cfg` from the command line / config file, so the
    # entry point is invoked without arguments.
    encode_dataset()