From 298230177a844ebc01c8983fd28f23a7104b0713 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 22 Aug 2024 14:22:56 +0800 Subject: [PATCH] fix format --- egs/audioset/AT/local/compute_weight.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/audioset/AT/local/compute_weight.py b/egs/audioset/AT/local/compute_weight.py index 1da6eb23cc..a0deddc0c9 100644 --- a/egs/audioset/AT/local/compute_weight.py +++ b/egs/audioset/AT/local/compute_weight.py @@ -25,25 +25,24 @@ import lhotse from lhotse import load_manifest + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( - "--input-manifest", - type=str, - default="data/fbank/cuts_audioset_full.jsonl.gz" + "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" ) parser.add_argument( "--output", type=str, required=True, - ) return parser + def main(): # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py parser = get_parser() @@ -53,7 +52,7 @@ def main(): print(f"A total of {len(cuts)} cuts.") - label_count = [0] * 527 # a total of 527 classes + label_count = [0] * 527 # a total of 527 classes for c in cuts: audio_event = c.supervisions[0].audio_event labels = list(map(int, audio_event.split(";"))) @@ -68,6 +67,7 @@ def main(): for label in labels: weight += 1000 / (label_count[label] + 0.01) f.write(f"{c.id} {weight}\n") - + + if __name__ == "__main__": main()