-
Notifications
You must be signed in to change notification settings - Fork 31
/
predict.py
55 lines (47 loc) · 2.24 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
import dill
from argparse import Namespace
import torch
import torchaudio
from utils import (detect_peaks, max_min_norm, replicate_first_k_frames)
from next_frame_classifier import NextFrameClassifier
def main(wav, ckpt, prominence):
    """Predict segment boundaries for one audio file and print them.

    Loads a NextFrameClassifier checkpoint and scores the audio with the
    model. The scores are normalized, peak detection runs on them, and the
    detected peak positions are printed after converting frame indexes to
    seconds.

    Args:
        wav: path to the input audio file; must be sampled at 16 kHz.
        ckpt: path to a checkpoint holding "hparams", "state_dict" and a
            dill-serialized "peak_detection_params" entry.
        prominence: if not None, overrides the "prominence" peak-detection
            parameter read from the checkpoint.

    Raises:
        ValueError: if the audio's sample rate is not 16 kHz.
    """
    expected_sr = 16000
    # Number of audio samples per score frame; used below to map frame
    # indexes to seconds (frame_index * samples_per_frame / sr).
    samples_per_frame = 160

    print(f"running inference on: {wav}")
    print(f"running inference using ckpt: {ckpt}")
    print("\n\n", 90 * "-")

    # NOTE(review): torch.load and dill.loads both unpickle arbitrary
    # objects -- only run this script on checkpoints from a trusted source.
    ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
    hp = Namespace(**dict(ckpt["hparams"]))

    # load weights and peak detection params
    model = NextFrameClassifier(hp)
    weights = ckpt["state_dict"]
    # Drop the "NFC." substring from parameter names so they match the bare
    # model (presumably a wrapper-module prefix added during training -- confirm).
    weights = {k.replace("NFC.", ""): v for k, v in weights.items()}
    model.load_state_dict(weights)
    # Inference mode: makes dropout / batch-norm layers (if the model has
    # any) behave deterministically instead of using training behaviour.
    model.eval()
    peak_detection_params = dill.loads(ckpt["peak_detection_params"])["cpc_1"]
    if prominence is not None:
        print(f"overriding prominence with {prominence}")
        peak_detection_params["prominence"] = prominence

    # load data
    audio, sr = torchaudio.load(wav)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if sr != expected_sr:
        raise ValueError(
            f"model was trained with audio sampled at 16khz (got {sr} Hz), "
            "please resample."
        )
    # Keep only the first channel, then add a leading dimension of size 1
    # (presumably the batch dimension the model expects -- confirm).
    audio = audio[0].unsqueeze(0)

    # run inference; gradients are never needed here
    with torch.no_grad():
        scores = model(audio)  # get scores
    scores = scores[1][0]  # get scores of positive pairs
    scores = replicate_first_k_frames(scores, k=1, dim=1)  # padding
    scores = 1 - max_min_norm(scores)  # normalize scores (good for visualizations)
    peaks = detect_peaks(
        x=scores,
        lengths=[scores.shape[1]],
        prominence=peak_detection_params["prominence"],
        width=peak_detection_params["width"],
        distance=peak_detection_params["distance"],
    )  # run peak detection on scores
    boundaries = peaks[0] * samples_per_frame / sr  # transform frame indexes to seconds

    print("predicted boundaries (in seconds):")
    print(boundaries)
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run inference on one file.
    parser = argparse.ArgumentParser(description='Unsupervised segmentation inference script')
    # Both paths are mandatory; without them main() would fail deep inside
    # torch/torchaudio with an unhelpful error.
    parser.add_argument('--wav', required=True, help='path to wav file')
    parser.add_argument('--ckpt', required=True, help='path to checkpoint file')
    # None means "use the prominence stored in the checkpoint".
    parser.add_argument('--prominence', type=float, default=None,
                        help='prominence for peak detection '
                             '(default: value stored in the checkpoint)')
    args = parser.parse_args()
    main(args.wav, args.ckpt, args.prominence)