From a58c26520d08140c19a6fef748a1e4faa8bc62e6 Mon Sep 17 00:00:00 2001 From: Anjok07 <68268275+Anjok07@users.noreply.github.com> Date: Sun, 18 Dec 2022 21:18:56 -0600 Subject: [PATCH] Add files via upload --- UVR.py | 4830 ++++++++++++++++- __version__.py | 3 +- demucs/__init__.py | 5 + demucs/__main__.py | 272 + demucs/apply.py | 294 + demucs/demucs.py | 459 ++ demucs/filtering.py | 502 ++ demucs/hdemucs.py | 782 +++ demucs/htdemucs.py | 648 +++ demucs/model.py | 218 + demucs/model_v2.py | 218 + demucs/pretrained.py | 180 + demucs/repo.py | 148 + demucs/spec.py | 41 + demucs/states.py | 148 + demucs/tasnet.py | 447 ++ demucs/tasnet_v2.py | 452 ++ demucs/transformer.py | 839 +++ demucs/utils.py | 502 ++ lib_v5/spec_utils.py | 736 +++ lib_v5/vr_network/__init__.py | 1 + lib_v5/vr_network/layers.py | 143 + lib_v5/vr_network/layers_new.py | 126 + lib_v5/vr_network/model_param_init.py | 59 + .../modelparams/1band_sr16000_hl512.json | 19 + .../modelparams/1band_sr32000_hl512.json | 19 + .../modelparams/1band_sr33075_hl384.json | 19 + .../modelparams/1band_sr44100_hl1024.json | 19 + .../modelparams/1band_sr44100_hl256.json | 19 + .../modelparams/1band_sr44100_hl512.json | 19 + .../modelparams/1band_sr44100_hl512_cut.json | 19 + .../1band_sr44100_hl512_nf1024.json | 19 + .../vr_network/modelparams/2band_32000.json | 30 + .../modelparams/2band_44100_lofi.json | 30 + .../vr_network/modelparams/2band_48000.json | 30 + .../vr_network/modelparams/3band_44100.json | 42 + .../modelparams/3band_44100_mid.json | 43 + .../modelparams/3band_44100_msb2.json | 43 + .../vr_network/modelparams/4band_44100.json | 54 + .../modelparams/4band_44100_mid.json | 55 + .../modelparams/4band_44100_msb.json | 55 + .../modelparams/4band_44100_msb2.json | 55 + .../modelparams/4band_44100_reverse.json | 55 + .../modelparams/4band_44100_sw.json | 55 + lib_v5/vr_network/modelparams/4band_v2.json | 54 + .../vr_network/modelparams/4band_v2_sn.json | 55 + lib_v5/vr_network/modelparams/4band_v3.json | 54 + lib_v5/vr_network/modelparams/ensemble.json | 43 + lib_v5/vr_network/nets.py | 171 + lib_v5/vr_network/nets_new.py | 143 + .../v3_v4_repo/demucs_models.txt | 1 + .../MDX_Net_Models/model_data/model_data.json | 184 + models/VR_Models/model_data/model_data.json | 94 + separate.py | 924 ++++ 54 files changed, 14473 insertions(+), 2 deletions(-) create mode 100644 demucs/__init__.py create mode 100644 demucs/__main__.py create mode 100644 demucs/apply.py create mode 100644 demucs/demucs.py create mode 100644 demucs/filtering.py create mode 100644 demucs/hdemucs.py create mode 100644 demucs/htdemucs.py create mode 100644 demucs/model.py create mode 100644 demucs/model_v2.py create mode 100644 demucs/pretrained.py create mode 100644 demucs/repo.py create mode 100644 demucs/spec.py create mode 100644 demucs/states.py create mode 100644 demucs/tasnet.py create mode 100644 demucs/tasnet_v2.py create mode 100644 demucs/transformer.py create mode 100644 demucs/utils.py create mode 100644 lib_v5/spec_utils.py create mode 100644 lib_v5/vr_network/__init__.py create mode 100644 lib_v5/vr_network/layers.py create mode 100644 lib_v5/vr_network/layers_new.py create mode 100644 lib_v5/vr_network/model_param_init.py create mode 100644 lib_v5/vr_network/modelparams/1band_sr16000_hl512.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr32000_hl512.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr33075_hl384.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr44100_hl256.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr44100_hl512.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json create mode 100644 lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json create mode 100644 lib_v5/vr_network/modelparams/2band_32000.json create mode 100644 lib_v5/vr_network/modelparams/2band_44100_lofi.json create mode 100644 lib_v5/vr_network/modelparams/2band_48000.json create mode 100644 lib_v5/vr_network/modelparams/3band_44100.json create mode 100644 lib_v5/vr_network/modelparams/3band_44100_mid.json create mode 100644 lib_v5/vr_network/modelparams/3band_44100_msb2.json create mode 100644 lib_v5/vr_network/modelparams/4band_44100.json create mode 100644 lib_v5/vr_network/modelparams/4band_44100_mid.json create mode 100644 lib_v5/vr_network/modelparams/4band_44100_msb.json create mode 100644 lib_v5/vr_network/modelparams/4band_44100_msb2.json create mode 100644 lib_v5/vr_network/modelparams/4band_44100_reverse.json create mode 100644 lib_v5/vr_network/modelparams/4band_44100_sw.json create mode 100644 lib_v5/vr_network/modelparams/4band_v2.json create mode 100644 lib_v5/vr_network/modelparams/4band_v2_sn.json create mode 100644 lib_v5/vr_network/modelparams/4band_v3.json create mode 100644 lib_v5/vr_network/modelparams/ensemble.json create mode 100644 lib_v5/vr_network/nets.py create mode 100644 lib_v5/vr_network/nets_new.py create mode 100644 models/Demucs_Models/v3_v4_repo/demucs_models.txt create mode 100644 models/MDX_Net_Models/model_data/model_data.json create mode 100644 models/VR_Models/model_data/model_data.json create mode 100644 separate.py diff --git a/UVR.py b/UVR.py index 3dd08107..43fff6a3 100644 --- a/UVR.py +++ b/UVR.py @@ -1 +1,4829 @@ -##Code Undergoing Refactorization## +# GUI modules +import audioread +import base64 +import gui_data.sv_ttk +import hashlib +import json +import librosa +import logging +import math +import natsort +import onnx +import os +import pickle # Save Data +import psutil +import pyglet +import pyperclip +import base64 +import queue +import re +import shutil +import string +import subprocess +import sys +import soundfile as sf +import time +#import timeit +import tkinter as tk +import tkinter.filedialog +import tkinter.font +import tkinter.messagebox +import tkinter.ttk as ttk +import torch +import urllib.request +import webbrowser +import wget +import traceback +#import multiprocessing as KThread +from __version__ import VERSION, PATCH +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from datetime import datetime +from gui_data.app_size_values import ImagePath, AdjustedValues as av +from gui_data.constants import * +from gui_data.error_handling import error_text, error_dialouge +from gui_data.old_data_check import file_check, remove_unneeded_yamls, remove_temps +from gui_data.tkinterdnd2 import TkinterDnD, DND_FILES # Enable Drag & Drop +from lib_v5.vr_network.model_param_init import ModelParameters +from kthread import KThread +from lib_v5 import spec_utils +from pathlib import Path +from separate import SeperateDemucs, SeperateMDX, SeperateVR, save_format +from playsound import playsound +from tkinter import * +from tkinter.tix import * +import re +from typing import List + +logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO) +logging.info('UVR BEGIN') + +try: + with open(os.path.join(os.getcwd(), 'tmp', 'splash.txt'), 'w') as f: + f.write('1') +except: + pass + +def save_data(data): + """ + Saves given data as a .pkl (pickle) file + + Paramters: + data(dict): + Dictionary containing all the necessary data to save + """ + # Open data file, create it if it does not exist + with open('data.pkl', 'wb') as data_file: + pickle.dump(data, data_file) + +def load_data() -> dict: + """ + Loads saved pkl file and returns the stored data + + Returns(dict): + Dictionary containing all the saved data + """ + try: + with open('data.pkl', 'rb') as data_file: # Open data file + data = pickle.load(data_file) + + return data + except (ValueError, FileNotFoundError): + # Data File is corrupted or not found so recreate it + + save_data(data=DEFAULT_DATA) + + return load_data() + +def load_model_hash_data(dictionary): + '''Get the model hash dictionary''' + + with open(dictionary) as d: + data = d.read() + + return json.loads(data) + +# Change the current working directory to the directory +# this file sits in +if getattr(sys, 'frozen', False): + # If the application is run as a bundle, the PyInstaller bootloader + # extends the sys module by a flag frozen=True and sets the app + # path into variable _MEIPASS'. + BASE_PATH = sys._MEIPASS +else: + BASE_PATH = os.path.dirname(os.path.abspath(__file__)) + +os.chdir(BASE_PATH) # Change the current working directory to the base path + +debugger = [] + +#--Constants-- +#Models +MODELS_DIR = os.path.join(BASE_PATH, 'models') +VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models') +MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models') +DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models') +DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo') + +#Cache & Parameters +VR_HASH_DIR = os.path.join(VR_MODELS_DIR, 'model_data') +VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json') +MDX_HASH_DIR = os.path.join(MDX_MODELS_DIR, 'model_data') +MDX_HASH_JSON = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_data.json') +ENSEMBLE_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_ensembles') +SETTINGS_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_settings') +VR_PARAM_DIR = os.path.join(BASE_PATH, 'lib_v5', 'vr_network', 'modelparams') +SAMPLE_CLIP_PATH = os.path.join(BASE_PATH, 'temp_sample_clips') +ENSEMBLE_TEMP_PATH = os.path.join(BASE_PATH, 'ensemble_temps') + +#Style +ICON_IMG_PATH = os.path.join(BASE_PATH, 'gui_data', 'img', 'GUI-icon.ico') +MAIN_ICON_IMG_PATH = os.path.join(BASE_PATH, 'gui_data', 'img', 'GUI-icon.png') +FONT_PATH = os.path.join(BASE_PATH, 'gui_data', 'fonts', 'centurygothic', 'GOTHIC.TTF')#ensemble_temps +MENU_COMBOBOX_WIDTH = 18 + +#Other +COMPLETE_CHIME = os.path.join(BASE_PATH, 'gui_data', 'complete_chime.wav') +FAIL_CHIME = os.path.join(BASE_PATH, 'gui_data', 'fail_chime.wav') +CHANGE_LOG = os.path.join(BASE_PATH, 'gui_data', 'change_log.txt') +SPLASH_DOC = os.path.join(BASE_PATH, 'tmp', 'splash.txt') + +file_check(os.path.join(MODELS_DIR, 'Main_Models'), VR_MODELS_DIR) +file_check(os.path.join(DEMUCS_MODELS_DIR, 'v3_repo'), DEMUCS_NEWER_REPO_DIR) +remove_unneeded_yamls(DEMUCS_MODELS_DIR) + +remove_temps(ENSEMBLE_TEMP_PATH) +remove_temps(SAMPLE_CLIP_PATH) +remove_temps(os.path.join(BASE_PATH, 'img')) + +if not os.path.isdir(ENSEMBLE_TEMP_PATH): + os.mkdir(ENSEMBLE_TEMP_PATH) + +if not os.path.isdir(SAMPLE_CLIP_PATH): + os.mkdir(SAMPLE_CLIP_PATH) + +model_hash_table = {} +data = load_data() + +def drop(event, accept_mode: str = 'files'): + """Drag & Drop verification process""" + + path = event.data + + if accept_mode == 'folder': + path = path.replace('{', '').replace('}', '') + if not os.path.isdir(path): + tk.messagebox.showerror(title='Invalid Folder', + message='Your given export path is not a valid folder!') + return + # Set Variables + root.export_path_var.set(path) + elif accept_mode == 'files': + # Clean path text and set path to the list of paths + path = path.replace("{", "").replace("}", "") + for drive_letter in list(string.ascii_lowercase.upper()): + path = path.replace(f" {drive_letter}:", f";{drive_letter}:") + path = path.split(';') + path[-1] = path[-1].replace(';', '') + # Set Variables + root.inputPaths = tuple(path) + root.process_input_selections() + root.update_inputPaths() + + else: + # Invalid accept mode + return + +class ModelData(): + def __init__(self, model_name: str, + selected_process_method=ENSEMBLE_MODE, + is_secondary_model=False, + primary_model_primary_stem=None, + is_primary_model_primary_stem_only=False, + is_primary_model_secondary_stem_only=False, + is_pre_proc_model=False, + is_dry_check=False): + + self.is_gpu_conversion = 0 if root.is_gpu_conversion_var.get() else -1 + self.is_normalization = root.is_normalization_var.get() + self.is_primary_stem_only = root.is_primary_stem_only_var.get() + self.is_secondary_stem_only = root.is_secondary_stem_only_var.get() + self.is_denoise = root.is_denoise_var.get() + self.wav_type_set = root.wav_type_set + self.mp3_bit_set = root.mp3_bit_set_var.get() + self.save_format = root.save_format_var.get() + self.is_invert_spec = root.is_invert_spec_var.get() + self.demucs_stems = root.demucs_stems_var.get() + self.demucs_source_list = [] + self.demucs_stem_count = 0 + self.model_name = model_name + self.process_method = selected_process_method + self.model_status = False if self.model_name == CHOOSE_MODEL or self.model_name == NO_MODEL else True + self.primary_stem = None + self.secondary_stem = None + self.is_ensemble_mode = False + self.ensemble_primary_stem = None + self.ensemble_secondary_stem = None + self.primary_model_primary_stem = primary_model_primary_stem + self.is_secondary_model = is_secondary_model + self.secondary_model = None + self.secondary_model_scale = None + self.demucs_4_stem_added_count = 0 + self.is_demucs_4_stem_secondaries = False + self.is_4_stem_ensemble = False + self.pre_proc_model = None + self.pre_proc_model_activated = False + self.is_pre_proc_model = is_pre_proc_model + self.is_dry_check = is_dry_check + self.model_samplerate = 44100 + self.is_demucs_pre_proc_model_inst_mix = False + + self.secondary_model_4_stem = [] + self.secondary_model_4_stem_scale = [] + self.secondary_model_4_stem_names = [] + self.secondary_model_4_stem_model_names_list = [] + self.all_models = [] + + self.secondary_model_other = None + self.secondary_model_scale_other = None + self.secondary_model_bass = None + self.secondary_model_scale_bass = None + self.secondary_model_drums = None + self.secondary_model_scale_drums = None + + if selected_process_method == ENSEMBLE_MODE: + partitioned_name = model_name.partition(ENSEMBLE_PARTITION) + self.process_method = partitioned_name[0] + self.model_name = partitioned_name[2] + self.model_and_process_tag = model_name + self.ensemble_primary_stem, self.ensemble_secondary_stem = root.return_ensemble_stems() + self.is_ensemble_mode = True if not is_secondary_model and not is_pre_proc_model else False + self.is_4_stem_ensemble = True if root.ensemble_main_stem_var.get() == FOUR_STEM_ENSEMBLE and self.is_ensemble_mode else False + self.pre_proc_model_activated = root.is_demucs_pre_proc_model_activate_var.get() if not self.ensemble_primary_stem == VOCAL_STEM else False + + if self.process_method == VR_ARCH_TYPE: + self.is_secondary_model_activated = root.vr_is_secondary_model_activate_var.get() if not self.is_secondary_model else False + self.aggression_setting = float(int(root.aggression_setting_var.get())/100) + self.is_tta = root.is_tta_var.get() + self.is_post_process = root.is_post_process_var.get() + self.window_size = int(root.window_size_var.get()) + self.batch_size = int(root.batch_size_var.get()) + self.crop_size = int(root.crop_size_var.get()) + self.is_high_end_process = 'mirroring' if root.is_high_end_process_var.get() else 'None' + self.model_path = os.path.join(VR_MODELS_DIR, f"{self.model_name}.pth") + self.get_model_hash() + if self.model_hash: + self.model_data = self.get_model_data(VR_HASH_DIR, root.vr_hash_MAPPER) + if self.model_data: + vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"])) + self.primary_stem = self.model_data["primary_stem"] + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + self.vr_model_param = ModelParameters(vr_model_param) + self.model_samplerate = self.vr_model_param.param['sr'] + else: + self.model_status = False + + if self.process_method == MDX_ARCH_TYPE: + self.is_secondary_model_activated = root.mdx_is_secondary_model_activate_var.get() if not is_secondary_model else False + self.margin = int(root.margin_var.get()) + self.chunks = root.determine_auto_chunks(root.chunks_var.get(), self.is_gpu_conversion) + self.get_mdx_model_path() + self.get_model_hash() + if self.model_hash: + self.model_data = self.get_model_data(MDX_HASH_DIR, root.mdx_hash_MAPPER) + if self.model_data: + self.compensate = self.model_data["compensate"] if root.compensate_var.get() == AUTO_SELECT else float(root.compensate_var.get()) + self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"] + self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"] + self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"] + self.primary_stem = self.model_data["primary_stem"] + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + else: + self.model_status = False + + if self.process_method == DEMUCS_ARCH_TYPE: + self.is_secondary_model_activated = root.demucs_is_secondary_model_activate_var.get() if not is_secondary_model else False + if not self.is_ensemble_mode: + self.pre_proc_model_activated = root.is_demucs_pre_proc_model_activate_var.get() if not root.demucs_stems_var.get() in [VOCAL_STEM, INST_STEM] else False + self.overlap = float(root.overlap_var.get()) + self.margin_demucs = int(root.margin_demucs_var.get()) + self.chunks_demucs = root.determine_auto_chunks(root.chunks_demucs_var.get(), self.is_gpu_conversion) + self.shifts = int(root.shifts_var.get()) + self.is_split_mode = root.is_split_mode_var.get() + self.segment = root.segment_var.get() + self.is_chunk_demucs = root.is_chunk_demucs_var.get() + self.is_demucs_combine_stems = root.is_demucs_combine_stems_var.get() + self.is_primary_stem_only = root.is_primary_stem_only_var.get() if self.is_ensemble_mode else root.is_primary_stem_only_Demucs_var.get() + self.is_secondary_stem_only = root.is_secondary_stem_only_var.get() if self.is_ensemble_mode else root.is_secondary_stem_only_Demucs_var.get() + self.get_demucs_model_path() + self.get_demucs_model_data() + + self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0] if self.model_status else None + self.pre_proc_model_activated = self.pre_proc_model_activated if not self.is_secondary_model else False + + self.is_primary_model_primary_stem_only = is_primary_model_primary_stem_only + self.is_primary_model_secondary_stem_only = is_primary_model_secondary_stem_only + + if self.is_secondary_model_activated and self.model_status: + + if (not self.is_ensemble_mode and root.demucs_stems_var.get() == ALL_STEMS and self.process_method == DEMUCS_ARCH_TYPE) or self.is_4_stem_ensemble: + + for key in DEMUCS_4_SOURCE_LIST: + self.secondary_model_data(key) + self.secondary_model_4_stem.append(self.secondary_model) + self.secondary_model_4_stem_scale.append(self.secondary_model_scale) + self.secondary_model_4_stem_names.append(key) + + self.demucs_4_stem_added_count = sum(i is not None for i in self.secondary_model_4_stem) + self.is_secondary_model_activated = False if all(i is None for i in self.secondary_model_4_stem) else True + self.demucs_4_stem_added_count = self.demucs_4_stem_added_count - 1 if self.is_secondary_model_activated else self.demucs_4_stem_added_count + if self.is_secondary_model_activated: + self.secondary_model_4_stem_model_names_list = [None if i is None else i.model_basename for i in self.secondary_model_4_stem] + self.is_demucs_4_stem_secondaries = True + else: + primary_stem = self.ensemble_primary_stem if self.is_ensemble_mode and self.process_method == DEMUCS_ARCH_TYPE else self.primary_stem + self.secondary_model_data(primary_stem) + + if self.process_method == DEMUCS_ARCH_TYPE and not is_secondary_model: + if self.demucs_stem_count >= 3 and self.pre_proc_model_activated: + self.pre_proc_model_activated = True + self.pre_proc_model = root.process_determine_demucs_pre_proc_model(self.primary_stem) + self.is_demucs_pre_proc_model_inst_mix = root.is_demucs_pre_proc_model_inst_mix_var.get() if self.pre_proc_model else False + + def secondary_model_data(self, primary_stem): + secondary_model_data = root.process_determine_secondary_model(self.process_method, primary_stem, self.is_primary_stem_only, self.is_secondary_stem_only) + self.secondary_model = secondary_model_data[0] + self.secondary_model_scale = secondary_model_data[1] + self.is_secondary_model_activated = False if not self.secondary_model else True + if self.secondary_model: + self.is_secondary_model_activated = False if self.secondary_model.model_basename == self.model_basename else True + + def get_mdx_model_path(self): + + for file_name, chosen_mdx_model in MDX_NAME_SELECT.items(): + if self.model_name in chosen_mdx_model: + self.model_path = os.path.join(MDX_MODELS_DIR, f"{file_name}.onnx") + break + else: + self.model_path = os.path.join(MDX_MODELS_DIR, f"{self.model_name}.onnx") + + self.mixer_path = os.path.join(MDX_MODELS_DIR, f"mixer_val.ckpt") + + def get_demucs_model_path(self): + + demucs_newer = [True for x in DEMUCS_NEWER_TAGS if x in self.model_name] + demucs_model_dir = DEMUCS_NEWER_REPO_DIR if demucs_newer else DEMUCS_MODELS_DIR + + for file_name, chosen_model in DEMUCS_NAME_SELECT.items(): + if self.model_name in chosen_model: + self.model_path = os.path.join(demucs_model_dir, file_name) + break + else: + self.model_path = os.path.join(DEMUCS_NEWER_REPO_DIR, f'{self.model_name}.yaml') + + def get_demucs_model_data(self): + + self.demucs_version = DEMUCS_V4 + + for key, value in DEMUCS_VERSION_MAPPER.items(): + if value in self.model_name: + self.demucs_version = key + + self.demucs_source_list = DEMUCS_2_SOURCE if DEMUCS_UVR_MODEL in self.model_name else DEMUCS_4_SOURCE + self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER if DEMUCS_UVR_MODEL in self.model_name else DEMUCS_4_SOURCE_MAPPER + self.demucs_stem_count = 2 if DEMUCS_UVR_MODEL in self.model_name else 4 + + if not self.is_ensemble_mode: + self.primary_stem = PRIMARY_STEM if self.demucs_stems == ALL_STEMS else self.demucs_stems + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + + def get_model_data(self, model_hash_dir, hash_mapper): + + model_settings_json = os.path.join(model_hash_dir, "{}.json".format(self.model_hash)) + + if os.path.isfile(model_settings_json): + return json.load(open(model_settings_json)) + else: + for hash, settings in hash_mapper.items(): + if self.model_hash in hash: + return settings + else: + return self.get_model_data_from_popup() + + def get_model_data_from_popup(self): + + if not self.is_dry_check: + confirm = tk.messagebox.askyesno(title=UNRECOGNIZED_MODEL[0], + message=f"\"{self.model_name}\"{UNRECOGNIZED_MODEL[1]}", + parent=root) + + if confirm: + if self.process_method == VR_ARCH_TYPE: + root.pop_up_vr_param(self.model_hash) + return root.vr_model_params + if self.process_method == MDX_ARCH_TYPE: + root.pop_up_mdx_model(self.model_hash, self.model_path) + return root.mdx_model_params + else: + return None + else: + return None + + def get_model_hash(self): + self.model_hash = None + + if not os.path.isfile(self.model_path): + self.model_status = False + self.model_hash is None + else: + if model_hash_table: + for (key, value) in model_hash_table.items(): + if self.model_path == key: + self.model_hash = value + break + + if not self.model_hash: + with open(self.model_path, 'rb') as f: + f.seek(- 10000 * 1024, 2) + self.model_hash = hashlib.md5(f.read()).hexdigest() + + table_entry = {self.model_path: self.model_hash} + model_hash_table.update(table_entry) + +class Ensembler(): + def __init__(self, is_manual_ensemble=False): + self.is_save_all_outputs_ensemble = root.is_save_all_outputs_ensemble_var.get() + chosen_ensemble_name = '{}'.format(root.chosen_ensemble_var.get().replace(" ", "_")) if not root.chosen_ensemble_var.get() == CHOOSE_ENSEMBLE_OPTION else 'Ensembled' + ensemble_algorithm = root.ensemble_type_var.get().partition("/") + ensemble_main_stem_pair = root.ensemble_main_stem_var.get().partition("/") + time_stamp = round(time.time()) + self.audio_tool = MANUAL_ENSEMBLE + self.main_export_path = Path(root.export_path_var.get()) + self.chosen_ensemble = f"_{chosen_ensemble_name}" if root.is_append_ensemble_name_var.get() else '' + ensemble_folder_name = self.main_export_path if self.is_save_all_outputs_ensemble else ENSEMBLE_TEMP_PATH + self.ensemble_folder_name = os.path.join(ensemble_folder_name, '{}_Outputs_{}'.format(chosen_ensemble_name, time_stamp)) + self.is_testing_audio = f"{time_stamp}_" if root.is_testing_audio_var.get() else '' + self.primary_algorithm = ensemble_algorithm[0] + self.secondary_algorithm = ensemble_algorithm[2] + self.ensemble_primary_stem = ensemble_main_stem_pair[0] + self.ensemble_secondary_stem = ensemble_main_stem_pair[2] + self.is_normalization = root.is_normalization_var.get() + self.wav_type_set = root.wav_type_set + self.mp3_bit_set = root.mp3_bit_set_var.get() + self.save_format = root.save_format_var.get() + if not is_manual_ensemble: + os.mkdir(self.ensemble_folder_name) + + def ensemble_outputs(self, audio_file_base, export_path, stem, is_4_stem=False, is_inst_mix=False): + """Processes the given outputs and ensembles them with the chosen algorithm""" + + if is_4_stem: + algorithm = root.ensemble_type_var.get() + stem_tag = stem + else: + if is_inst_mix: + algorithm = self.secondary_algorithm + stem_tag = f"{self.ensemble_secondary_stem} {INST_STEM}" + else: + algorithm = self.primary_algorithm if stem == PRIMARY_STEM else self.secondary_algorithm + stem_tag = self.ensemble_primary_stem if stem == PRIMARY_STEM else self.ensemble_secondary_stem + + stem_outputs = self.get_files_to_ensemble(folder=export_path, prefix=audio_file_base, suffix=f"_({stem_tag}).wav") + audio_file_output = f"{self.is_testing_audio}{audio_file_base}{self.chosen_ensemble}_({stem_tag})" + stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}.wav'.format(audio_file_output)) + if stem_outputs: + spec_utils.ensemble_inputs(stem_outputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path) + save_format(stem_save_path, self.save_format, self.mp3_bit_set) + + if self.is_save_all_outputs_ensemble: + for i in stem_outputs: + save_format(i, self.save_format, self.mp3_bit_set) + else: + for i in stem_outputs: + try: + os.remove(i) + except Exception as e: + print(e) + + def ensemble_manual(self, audio_inputs, audio_file_base): + """Processes the given outputs and ensembles them with the chosen algorithm""" + + algorithm = root.choose_algorithm_var.get() + stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}{}_({}).wav'.format(self.is_testing_audio, audio_file_base, algorithm)) + spec_utils.ensemble_inputs(audio_inputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path) + save_format(stem_save_path, self.save_format, self.mp3_bit_set) + + def get_files_to_ensemble(self, folder="", prefix="", suffix=""): + """Grab all the files to be ensembles""" + + return [os.path.join(folder, i) for i in os.listdir(folder) if i.startswith(prefix) and i.endswith(suffix)] + +class AudioTools(): + def __init__(self, audio_tool): + time_stamp = round(time.time()) + self.audio_tool = audio_tool + self.main_export_path = Path(root.export_path_var.get()) + self.wav_type_set = root.wav_type_set + self.is_normalization = root.is_normalization_var.get() + self.is_testing_audio = f"{time_stamp}_" if root.is_testing_audio_var.get() else '' + self.save_format = lambda save_path:save_format(save_path, root.save_format_var.get(), root.mp3_bit_set_var.get()) + + def align_inputs(self, audio_inputs, audio_file_base, audio_file_2_base, command_Text): + audio_file_base = f"{self.is_testing_audio}{audio_file_base}" + audio_file_2_base = f"{self.is_testing_audio}{audio_file_2_base}" + + aligned_path = os.path.join('{}'.format(self.main_export_path),'{}_aligned.wav'.format(audio_file_2_base)) + inverted_path = os.path.join('{}'.format(self.main_export_path),'{}_inverted.wav'.format(audio_file_base)) + + spec_utils.align_audio(audio_inputs[0], audio_inputs[1], aligned_path, inverted_path, self.wav_type_set, self.is_normalization, command_Text, root.progress_bar_main_var, self.save_format) + + def pitch_or_time_shift(self, audio_file, audio_file_base): + + rate = float(root.time_stretch_rate_var.get()) if self.audio_tool == TIME_STRETCH else float(root.pitch_rate_var.get()) + is_pitch = False if self.audio_tool == TIME_STRETCH else True + file_text = TIME_TEXT if self.audio_tool == TIME_STRETCH else PITCH_TEXT + save_path = os.path.join(self.main_export_path, f"{self.is_testing_audio}{audio_file_base}{file_text}.wav") + spec_utils.augment_audio(save_path, audio_file, rate, self.is_normalization, self.wav_type_set, self.save_format, is_pitch=is_pitch) + +class ToolTip(object): + + def __init__(self, widget): + self.widget = widget + self.tipwindow = None + self.id = None + self.x = self.y = 0 + + def showtip(self, text): + "Display text in tooltip window" + self.text = text + if self.tipwindow or not self.text: + return + x, y, cx, cy = self.widget.bbox("insert") + x = x + self.widget.winfo_rootx() + 57 + y = y + cy + self.widget.winfo_rooty() +27 + self.tipwindow = tw = Toplevel(self.widget) + tw.wm_overrideredirect(1) + tw.wm_geometry("+%d+%d" % (x, y)) + label = Label(tw, text=self.text, justify=LEFT, + background="#ffffe0", foreground="black", relief=SOLID, borderwidth=1, + font=("tahoma", "8", "normal")) + label.pack(ipadx=1) + + def hidetip(self): + tw = self.tipwindow + self.tipwindow = None + if tw: + tw.destroy() + +class ThreadSafeConsole(tk.Text): + """ + Text Widget which is thread safe for tkinter + """ + + def __init__(self, master, **options): + tk.Text.__init__(self, master, **options) + self.queue = queue.Queue() + self.update_me() + + def write(self, line): + self.queue.put(line) + + def clear(self): + self.queue.put(None) + + def update_me(self): + self.configure(state=tk.NORMAL) + try: + while 1: + line = self.queue.get_nowait() + if line is None: + self.delete(1.0, tk.END) + else: + self.insert(tk.END, str(line)) + self.see(tk.END) + self.update_idletasks() + except queue.Empty: + pass + self.configure(state=tk.DISABLED) + self.after(100, self.update_me) + + def copy_text(self): + hightlighted_text = self.selection_get() + self.clipboard_clear() + self.clipboard_append(hightlighted_text) + + def select_all_text(self): + self.tag_add('sel', '1.0', 'end') + +class MainWindow(TkinterDnD.Tk): + # --Constants-- + # Layout + + IMAGE_HEIGHT = av.IMAGE_HEIGHT + FILEPATHS_HEIGHT = av.FILEPATHS_HEIGHT + OPTIONS_HEIGHT = av.OPTIONS_HEIGHT + CONVERSIONBUTTON_HEIGHT = av.CONVERSIONBUTTON_HEIGHT + COMMAND_HEIGHT = av.COMMAND_HEIGHT + PROGRESS_HEIGHT = av.PROGRESS_HEIGHT + PADDING = av.PADDING + COL1_ROWS = 11 + COL2_ROWS = 11 + + def __init__(self): + #Run the __init__ method on the tk.Tk class + super().__init__() + + gui_data.sv_ttk.set_theme("dark") + gui_data.sv_ttk.use_dark_theme() # Set dark theme + + # Calculate window height + height = self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + height += self.CONVERSIONBUTTON_HEIGHT + self.COMMAND_HEIGHT + self.PROGRESS_HEIGHT + height += self.PADDING * 5 # Padding + width = 680 + self.main_window_width = width + self.main_window_height = height + + # --Window Settings-- + + self.title('Ultimate Vocal Remover') + # Set Geometry and Center Window + self.geometry('{width}x{height}+{xpad}+{ypad}'.format( + width=self.main_window_width, + height=height, + xpad=int(self.winfo_screenwidth()/2 - width/2), + ypad=int(self.winfo_screenheight()/2 - height/2 - 30))) + + self.tk.call('wm', 'iconphoto', self._w, tk.PhotoImage(file=MAIN_ICON_IMG_PATH)) + self.configure(bg='#0e0e0f') # Set background color to #0c0c0d + self.protocol("WM_DELETE_WINDOW", self.save_values) + self.resizable(False, False) + #self.withdraw() + self.update() + + #Load Images + img = ImagePath(BASE_PATH) + self.logo_img = img.open_image(path=img.banner_path, size=(self.winfo_width(), 9999)) + self.efile_img = img.efile_img + self.stop_img = img.stop_img + self.help_img = img.help_img + self.download_img = img.download_img + self.donate_img = img.donate_img + self.key_img = img.key_img + self.credits_img = img.credits_img + + #Placeholders + self.error_log_var = tk.StringVar(value='') + self.vr_secondary_model_names = [] + self.mdx_secondary_model_names = [] + self.demucs_secondary_model_names = [] + self.vr_primary_model_names = [] + self.mdx_primary_model_names = [] + self.demucs_primary_model_names = [] + + self.vr_cache_source_mapper = {} + self.mdx_cache_source_mapper = {} + self.demucs_cache_source_mapper = {} + + # -Tkinter Value Holders- + + try: + self.load_saved_vars(data) + except Exception as e: + self.error_log_var.set(error_text('Loading Saved Variables', e)) + self.load_saved_vars(DEFAULT_DATA) + + self.cached_sources_clear() + + self.method_mapper = { + VR_ARCH_PM: self.vr_model_var, + MDX_ARCH_TYPE: self.mdx_net_model_var, + DEMUCS_ARCH_TYPE: self.demucs_model_var} + + self.vr_secondary_model_vars = {'voc_inst_secondary_model': self.vr_voc_inst_secondary_model_var, + 'other_secondary_model': self.vr_other_secondary_model_var, + 'bass_secondary_model': self.vr_bass_secondary_model_var, + 'drums_secondary_model': self.vr_drums_secondary_model_var, + 'is_secondary_model_activate': self.vr_is_secondary_model_activate_var, + 'voc_inst_secondary_model_scale': self.vr_voc_inst_secondary_model_scale_var, + 'other_secondary_model_scale': self.vr_other_secondary_model_scale_var, + 'bass_secondary_model_scale': self.vr_bass_secondary_model_scale_var, + 'drums_secondary_model_scale': self.vr_drums_secondary_model_scale_var} + + self.demucs_secondary_model_vars = {'voc_inst_secondary_model': self.demucs_voc_inst_secondary_model_var, + 'other_secondary_model': self.demucs_other_secondary_model_var, + 'bass_secondary_model': self.demucs_bass_secondary_model_var, + 'drums_secondary_model': self.demucs_drums_secondary_model_var, + 'is_secondary_model_activate': self.demucs_is_secondary_model_activate_var, + 'voc_inst_secondary_model_scale': self.demucs_voc_inst_secondary_model_scale_var, + 'other_secondary_model_scale': self.demucs_other_secondary_model_scale_var, + 'bass_secondary_model_scale': self.demucs_bass_secondary_model_scale_var, + 'drums_secondary_model_scale': self.demucs_drums_secondary_model_scale_var} + + self.mdx_secondary_model_vars = {'voc_inst_secondary_model': self.mdx_voc_inst_secondary_model_var, + 'other_secondary_model': self.mdx_other_secondary_model_var, + 'bass_secondary_model': self.mdx_bass_secondary_model_var, + 'drums_secondary_model': self.mdx_drums_secondary_model_var, + 'is_secondary_model_activate': self.mdx_is_secondary_model_activate_var, + 'voc_inst_secondary_model_scale': self.mdx_voc_inst_secondary_model_scale_var, + 'other_secondary_model_scale': self.mdx_other_secondary_model_scale_var, + 'bass_secondary_model_scale': self.mdx_bass_secondary_model_scale_var, + 'drums_secondary_model_scale': self.mdx_drums_secondary_model_scale_var} + + #Main Application Vars + self.progress_bar_main_var = tk.IntVar(value=0) + self.inputPathsEntry_var = tk.StringVar(value='') + self.conversion_Button_Text_var = tk.StringVar(value=START_PROCESSING) + self.chosen_ensemble_var = tk.StringVar(value=CHOOSE_ENSEMBLE_OPTION) + self.ensemble_main_stem_var = tk.StringVar(value=CHOOSE_STEM_PAIR) + self.ensemble_type_var = tk.StringVar(value=MAX_MIN) + self.save_current_settings_var = tk.StringVar(value=SELECT_SAVED_SET) + self.demucs_stems_var = tk.StringVar(value=ALL_STEMS) + self.is_primary_stem_only_Text_var = tk.StringVar(value='') + self.is_secondary_stem_only_Text_var = tk.StringVar(value='') + self.is_primary_stem_only_Demucs_Text_var = tk.StringVar(value='') + self.is_secondary_stem_only_Demucs_Text_var = tk.StringVar(value='') + self.scaling_var = tk.DoubleVar(value=1.0) + self.active_processing_thread = None + self.verification_thread = None + self.is_menu_settings_open = False + + self.is_open_menu_advanced_vr_options = tk.BooleanVar(value=False) + self.is_open_menu_advanced_demucs_options = tk.BooleanVar(value=False) + self.is_open_menu_advanced_mdx_options = tk.BooleanVar(value=False) + self.is_open_menu_advanced_ensemble_options = tk.BooleanVar(value=False) + self.is_open_menu_view_inputs = tk.BooleanVar(value=False) + self.is_open_menu_help = tk.BooleanVar(value=False) + self.is_open_menu_error_log = tk.BooleanVar(value=False) + + self.mdx_model_params = None + self.vr_model_params = None + self.current_text_box = None + self.wav_type_set = None + self.is_online_model_menu = None + self.progress_bar_var = tk.IntVar(value=0) + self.is_confirm_error_var = tk.BooleanVar(value=False) + self.clear_cache_torch = False + self.vr_hash_MAPPER = load_model_hash_data(VR_HASH_JSON) + self.mdx_hash_MAPPER = load_model_hash_data(MDX_HASH_JSON) + self.is_gpu_available = torch.cuda.is_available() + self.is_process_stopped = False + self.inputs_from_dir = [] + self.iteration = 0 + self.vr_primary_source = None + self.vr_secondary_source = None + self.mdx_primary_source = None + self.mdx_secondary_source = None + self.demucs_primary_source = None + self.demucs_secondary_source = None + + #Download Center Vars + self.online_data = {} + self.is_online = False + self.lastest_version = '' + self.model_download_demucs_var = tk.StringVar(value='') + self.model_download_mdx_var = tk.StringVar(value='') + self.model_download_vr_var = tk.StringVar(value='') + self.selected_download_var = tk.StringVar(value=NO_MODEL) + self.select_download_var = tk.StringVar(value='') + self.download_progress_info_var = tk.StringVar(value='') + self.download_progress_percent_var = tk.StringVar(value='') + self.download_progress_bar_var = tk.IntVar(value=0) + self.download_stop_var = tk.StringVar(value='') + self.app_update_status_Text_var = tk.StringVar(value='') + self.app_update_button_Text_var = tk.StringVar(value='') + self.user_code_validation_var = tk.StringVar(value='') + self.download_link_path_var = tk.StringVar(value='') + self.download_save_path_var = tk.StringVar(value='') + self.download_update_link_var = tk.StringVar(value='') + self.download_update_path_var = tk.StringVar(value='') + self.download_demucs_models_list = [] + self.download_demucs_newer_models = [] + self.refresh_list_Button = None + self.stop_download_Button_DISABLE = None + self.enable_tabs = None + self.is_download_thread_active = False + self.is_process_thread_active = False + self.is_active_processing_thread = False + self.active_download_thread = None + + # Font + pyglet.font.add_file(FONT_PATH) + self.font = tk.font.Font(family='Century Gothic', size=10) + self.fontRadio = tk.font.Font(family='Century Gothic', size=9) + + #Model Update + self.last_found_ensembles = ENSEMBLE_OPTIONS + self.last_found_settings = ENSEMBLE_OPTIONS + self.last_found_models = () + self.model_data_table = () + self.ensemble_model_list = () + + # --Widgets-- + self.fill_main_frame() + self.bind_widgets() + self.online_data_refresh(user_refresh=False) + + # --Update Widgets-- + self.update_available_models() + self.update_main_widget_states() + self.update_loop() + self.update_button_states() + self.delete_temps() + self.download_validate_code() + self.ensemble_listbox_Option.configure(state=tk.DISABLED) + + self.command_Text.write(f'Ultimate Vocal Remover {VERSION} [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]') + self.new_update_notify = lambda latest_version:self.command_Text.write(f"\n\nNew Update Found: {latest_version}\n\nClick the update button in the \"Settings\" menu to download and install!") + self.update_checkbox_text = lambda:self.selection_action_process_method(self.chosen_process_method_var.get()) + + # Menu Functions + def main_window_LABEL_SET(self, master, text):return ttk.Label(master=master, text=text, background='#0e0e0f', font=self.font, foreground='#13a4c9', anchor=tk.CENTER) + def menu_title_LABEL_SET(self, frame, text, width=35):return ttk.Label(master=frame, text=text, font=("Century Gothic", "12", "underline"), justify="center", foreground="#13a4c9", width=width, anchor=tk.CENTER) + def menu_sub_LABEL_SET(self, frame, text, font_size=9):return ttk.Label(master=frame, text=text, font=("Century Gothic", f"{font_size}"), foreground='#13a4c9', anchor=tk.CENTER) + def menu_FRAME_SET(self, frame):return Frame(frame, highlightbackground='#0e0e0f', highlightcolor='#0e0e0f', highlightthicknes=20) + def check_is_menu_settings_open(self):self.menu_settings() if not self.is_menu_settings_open else None + + def check_is_open_menu_advanced_vr_options(self): + if not self.is_open_menu_advanced_vr_options.get(): + self.menu_advanced_vr_options() + else: + self.menu_advanced_vr_options_close_window() + self.menu_advanced_vr_options() + + def check_is_open_menu_advanced_demucs_options(self): + if not self.is_open_menu_advanced_demucs_options.get(): + self.menu_advanced_demucs_options() + else: + self.menu_advanced_demucs_options_close_window() + self.menu_advanced_demucs_options() + + def check_is_open_menu_advanced_mdx_options(self): + if not self.is_open_menu_advanced_mdx_options.get(): + self.menu_advanced_mdx_options() + else: + self.menu_advanced_mdx_options_close_window() + self.menu_advanced_mdx_options() + + def check_is_open_menu_advanced_ensemble_options(self): + if not self.is_open_menu_advanced_ensemble_options.get(): + self.menu_advanced_ensemble_options() + else: + self.menu_advanced_ensemble_options_close_window() + self.menu_advanced_ensemble_options() + + def check_is_open_menu_help(self): + if not self.is_open_menu_help.get(): + self.menu_help() + else: + self.menu_help_close_window() + self.menu_help() + + def check_is_open_menu_error_log(self): + if not self.is_open_menu_error_log.get(): + self.menu_error_log() + else: + self.menu_error_log_close_window() + self.menu_error_log() + + def check_is_open_menu_view_inputs(self): + if not self.is_open_menu_view_inputs.get(): + self.menu_view_inputs() + else: + self.menu_view_inputs_close_window() + self.menu_view_inputs() + + #Ensemble Listbox Functions + def ensemble_listbox_get_all_selected_models(self):return [self.ensemble_listbox_Option.get(i) for i in self.ensemble_listbox_Option.curselection()] + def ensemble_listbox_select_from_indexs(self, indexes):return [self.ensemble_listbox_Option.selection_set(i) for i in indexes] + def ensemble_listbox_clear_and_insert_new(self, model_ensemble_updated):return (self.ensemble_listbox_Option.delete(0, 'end'), [self.ensemble_listbox_Option.insert(tk.END, models) for models in model_ensemble_updated]) + def ensemble_listbox_get_indexes_for_files(self, updated, selected):return [updated.index(model) for model in selected if model in updated] + + def process_iteration(self): + self.iteration = self.iteration + 1 + + def assemble_model_data(self, model=None, arch_type=ENSEMBLE_MODE, is_dry_check=False): + + if arch_type == ENSEMBLE_STEM_CHECK: + + model_data = self.model_data_table + missing_models = [model.model_status for model in model_data if not model.model_status] + + if missing_models or not model_data: + model_data: List[ModelData] = [ModelData(model_name, is_dry_check=is_dry_check) for model_name in self.ensemble_model_list] + self.model_data_table = model_data + + if arch_type == ENSEMBLE_MODE: + model_data: List[ModelData] = [ModelData(model_name) for model_name in self.ensemble_listbox_get_all_selected_models()] + if arch_type == ENSEMBLE_CHECK: + model_data: List[ModelData] = [ModelData(model)] + if arch_type == VR_ARCH_TYPE or arch_type == VR_ARCH_PM: + model_data: List[ModelData] = [ModelData(model, VR_ARCH_TYPE)] + if arch_type == MDX_ARCH_TYPE: + model_data: List[ModelData] = [ModelData(model, MDX_ARCH_TYPE)] + if arch_type == DEMUCS_ARCH_TYPE: + model_data: List[ModelData] = [ModelData(model, DEMUCS_ARCH_TYPE)] + + return model_data + + def clear_cache(self, network): + + if network == VR_ARCH_TYPE: + dir = VR_HASH_DIR + if network == MDX_ARCH_TYPE: + dir = MDX_HASH_DIR + + [os.remove(os.path.join(dir, x)) for x in os.listdir(dir) if x not in 'model_data.json'] + self.vr_model_var.set(CHOOSE_MODEL) + self.mdx_net_model_var.set(CHOOSE_MODEL) + self.model_data_table.clear() + self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION) + self.ensemble_main_stem_var.set(CHOOSE_STEM_PAIR) + self.ensemble_listbox_Option.configure(state=tk.DISABLED) + self.update_checkbox_text() + + def thread_check(self, thread_to_check): + '''Checks if thread is alive''' + + is_running = False + + if type(thread_to_check) is KThread: + if thread_to_check.is_alive(): + is_running = True + + return is_running + + # -Widget Methods-- + + def fill_main_frame(self): + """Creates root window widgets""" + + self.title_Label = tk.Label(master=self, image=self.logo_img, compound=tk.TOP) + self.title_Label.place(x=-2, y=-2) + + button_y = self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT - 8 + self.PADDING*2 + + self.fill_filePaths_Frame() + self.fill_options_Frame() + + self.conversion_Button = ttk.Button(master=self, textvariable=self.conversion_Button_Text_var, command=self.process_initialize) + self.conversion_Button.place(x=50, y=button_y, width=-100, height=35, + relx=0, rely=0, relwidth=1, relheight=0) + self.conversion_Button_enable = lambda:(self.conversion_Button_Text_var.set(START_PROCESSING), self.conversion_Button.configure(state=tk.NORMAL)) + self.conversion_Button_disable = lambda message:(self.conversion_Button_Text_var.set(message), self.conversion_Button.configure(state=tk.DISABLED)) + + self.stop_Button = ttk.Button(master=self, image=self.stop_img, command=self.confirm_stop_process) + self.stop_Button.place(x=-10 - 35, y=button_y, width=35, height=35, + relx=1, rely=0, relwidth=0, relheight=0) + self.help_hints(self.stop_Button, text=STOP_HELP) + + self.settings_Button = ttk.Button(master=self, image=self.help_img, command=self.check_is_menu_settings_open) + self.settings_Button.place(x=-670, y=button_y, width=35, height=35, + relx=1, rely=0, relwidth=0, relheight=0) + self.help_hints(self.settings_Button, text=SETTINGS_HELP) + + self.progressbar = ttk.Progressbar(master=self, variable=self.progress_bar_main_var) + self.progressbar.place(x=25, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + self.CONVERSIONBUTTON_HEIGHT + self.COMMAND_HEIGHT + self.PADDING*4, width=-50, height=self.PROGRESS_HEIGHT, + relx=0, rely=0, relwidth=1, relheight=0) + + # Select Music Files Option + self.console_Frame = Frame(master=self, highlightbackground='#101012', highlightcolor='#101012', highlightthicknes=2) + self.console_Frame.place(x=15, y=self.IMAGE_HEIGHT + self.FILEPATHS_HEIGHT + self.OPTIONS_HEIGHT + self.CONVERSIONBUTTON_HEIGHT + self.PADDING + 5 *3, width=-30, height=self.COMMAND_HEIGHT+7, + relx=0, rely=0, relwidth=1, relheight=0) + + self.command_Text = ThreadSafeConsole(master=self.console_Frame, background='#0c0c0d',fg='#898b8e', font=('Century Gothic', 11), borderwidth=0) + self.command_Text.pack(fill=BOTH, expand=1) + self.command_Text.bind('', lambda e:self.right_click_console(e)) + + def fill_filePaths_Frame(self): + """Fill Frame with neccessary widgets""" + + # Select Music Files Option + self.filePaths_Frame = ttk.Frame(master=self) + self.filePaths_Frame.place(x=10, y=155, width=-20, height=self.FILEPATHS_HEIGHT, relx=0, rely=0, relwidth=1, relheight=0) + + self.filePaths_musicFile_Button = ttk.Button(master=self.filePaths_Frame, text='Select Input', command=self.input_select_filedialog) + self.filePaths_musicFile_Button.place(x=0, y=5, width=0, height=-5, relx=0, rely=0, relwidth=0.3, relheight=0.5) + self.filePaths_musicFile_Entry = ttk.Entry(master=self.filePaths_Frame, textvariable=self.inputPathsEntry_var, font=self.fontRadio, state=tk.DISABLED) + self.filePaths_musicFile_Entry.place(x=7.5, y=5, width=-50, height=-5, relx=0.3, rely=0, relwidth=0.7, relheight=0.5) + self.filePaths_musicFile_Open = ttk.Button(master=self, image=self.efile_img, command=lambda:os.startfile(os.path.dirname(self.inputPaths[0])) if self.inputPaths and os.path.isdir(os.path.dirname(self.inputPaths[0])) else self.error_dialoge(INVALID_INPUT)) + self.filePaths_musicFile_Open.place(x=-45, y=160, width=35, height=33, relx=1, rely=0, relwidth=0, relheight=0) + self.filePaths_musicFile_Entry.configure(cursor="hand2") + self.help_hints(self.filePaths_musicFile_Button, text=INPUT_FOLDER_ENTRY_HELP) + self.help_hints(self.filePaths_musicFile_Open, text=INPUT_FOLDER_BUTTON_HELP) + + # Save To Option + self.filePaths_saveTo_Button = ttk.Button(master=self.filePaths_Frame, text='Select Output', command=self.export_select_filedialog) + self.filePaths_saveTo_Button.place(x=0, y=5, width=0, height=-5, relx=0, rely=0.5, relwidth=0.3, relheight=0.5) + self.filePaths_saveTo_Entry = ttk.Entry(master=self.filePaths_Frame, textvariable=self.export_path_var, font=self.fontRadio, state=tk.DISABLED) + self.filePaths_saveTo_Entry.place(x=7.5, y=5, width=-50, height=-5, relx=0.3, rely=0.5, relwidth=0.7, relheight=0.5) + self.filePaths_saveTo_Open = ttk.Button(master=self, image=self.efile_img, command=lambda:os.startfile(Path(self.export_path_var.get())) if os.path.isdir(self.export_path_var.get()) else self.error_dialoge(INVALID_EXPORT)) + self.filePaths_saveTo_Open.place(x=-45, y=197.5, width=35, height=33, relx=1, rely=0, relwidth=0, relheight=0) + self.help_hints(self.filePaths_saveTo_Button, text=OUTPUT_FOLDER_ENTRY_HELP) + self.help_hints(self.filePaths_saveTo_Entry, text=OUTPUT_FOLDER_ENTRY_HELP) + self.help_hints(self.filePaths_saveTo_Open, text=OUTPUT_FOLDER_BUTTON_HELP) + + def fill_options_Frame(self): + """Fill Frame with neccessary widgets""" + + self.options_Frame = ttk.Frame(master=self) + self.options_Frame.place(x=10, y=250, width=-20, height=self.OPTIONS_HEIGHT, relx=0, rely=0, relwidth=1, relheight=0) + + # -Create Widgets- + + ## Save Format + self.wav_button = ttk.Radiobutton(master=self.options_Frame, text=WAV, variable=self.save_format_var, value=WAV) + self.wav_button.place(x=457, y=-5, width=0, height=6, relx=0, rely=0/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.help_hints(self.wav_button, text=f'{FORMAT_SETTING_HELP}{WAV}') + self.flac_button = ttk.Radiobutton(master=self.options_Frame, text=FLAC, variable=self.save_format_var, value=FLAC) + self.flac_button.place(x=300, y=-5, width=0, height=6, relx=1/3, rely=0/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.help_hints(self.flac_button, text=f'{FORMAT_SETTING_HELP}{FLAC}') + self.mp3_button = ttk.Radiobutton(master=self.options_Frame, text=MP3, variable=self.save_format_var, value=MP3) + self.mp3_button.place(x=143, y=-5, width=0, height=6, relx=2/3, rely=0/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.help_hints(self.mp3_button, text=f'{FORMAT_SETTING_HELP}{MP3}') + + # Choose Conversion Method + self.chosen_process_method_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose Process Method')#tk.Button(master=self.options_Frame, text='Choose Process Method', anchor=tk.CENTER, background='#0e0e0f', font=self.font, foreground='#13a4c9', borderwidth=0, command=lambda:self.pop_up_vr_param('ihbuhb')) + self.chosen_process_method_Label.place(x=0, y=MAIN_ROW_Y[0], width=LEFT_ROW_WIDTH, height=LABEL_HEIGHT, relx=0, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.chosen_process_method_Option = ttk.OptionMenu(self.options_Frame, self.chosen_process_method_var, None, *PROCESS_METHODS, command=lambda s:self.selection_action_process_method(s, from_widget=True)) + self.chosen_process_method_Option.place(x=0, y=MAIN_ROW_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.chosen_process_method_var.trace_add('write', lambda *args: self.update_main_widget_states()) + self.help_hints(self.chosen_process_method_Label, text=CHOSEN_PROCESS_METHOD_HELP) + + # Choose Settings Option + self.save_current_settings_Label = self.main_window_LABEL_SET(self.options_Frame, 'Select Saved Settings') + self.save_current_settings_Label_place = lambda:self.save_current_settings_Label.place(x=MAIN_ROW_2_X[0], y=LOW_MENU_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.save_current_settings_Option = ttk.OptionMenu(self.options_Frame, self.save_current_settings_var) + self.save_current_settings_Option_place = lambda:self.save_current_settings_Option.place(x=MAIN_ROW_2_X[1], y=LOW_MENU_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=2/3, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.save_current_settings_Label, text=SAVE_CURRENT_SETTINGS_HELP) + + ### MDX-NET ### + + # Choose MDX-Net Model + self.mdx_net_model_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose MDX-Net Model') + self.mdx_net_model_Label_place = lambda:self.mdx_net_model_Label.place(x=0, y=LOW_MENU_Y[0], width=LEFT_ROW_WIDTH, height=LABEL_HEIGHT, relx=0, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.mdx_net_model_Option = ttk.OptionMenu(self.options_Frame, self.mdx_net_model_var) + self.mdx_net_model_Option_place = lambda:self.mdx_net_model_Option.place(x=0, y=LOW_MENU_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.mdx_net_model_Label, text=CHOOSE_MODEL_HELP) + + # MDX-chunks + self.chunks_Label = self.main_window_LABEL_SET(self.options_Frame, 'Chunks') + self.chunks_Label_place = lambda:self.chunks_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.chunks_Option = ttk.Combobox(self.options_Frame, value=CHUNKS, textvariable=self.chunks_var) + self.chunks_Option_place = lambda:self.chunks_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.chunks_Option, self.chunks_var, REG_CHUNKS, CHUNKS) + self.help_hints(self.chunks_Label, text=CHUNKS_HELP) + + # MDX-Margin + self.margin_Label = self.main_window_LABEL_SET(self.options_Frame, 'Margin Size') + self.margin_Label_place = lambda:self.margin_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.margin_Option = ttk.Combobox(self.options_Frame, value=MARGIN_SIZE, textvariable=self.margin_var) + self.margin_Option_place = lambda:self.margin_Option.place(x=MAIN_ROW_2_X[1], y=MAIN_ROW_2_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=2/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.margin_Option, self.margin_var, REG_MARGIN, MARGIN_SIZE) + self.help_hints(self.margin_Label, text=MARGIN_HELP) + + ### VR ARCH ### + + # Choose VR Model + self.vr_model_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose VR Model') + self.vr_model_Label_place = lambda:self.vr_model_Label.place(x=0, y=LOW_MENU_Y[0], width=LEFT_ROW_WIDTH, height=LABEL_HEIGHT, relx=0, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.vr_model_Option = ttk.OptionMenu(self.options_Frame, self.vr_model_var) + self.vr_model_Option_place = lambda:self.vr_model_Option.place(x=0, y=LOW_MENU_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.vr_model_Label, text=CHOOSE_MODEL_HELP) + + # Aggression Setting + self.aggression_setting_Label = self.main_window_LABEL_SET(self.options_Frame, 'Aggression Setting') + self.aggression_setting_Label_place = lambda:self.aggression_setting_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=2/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.aggression_setting_Option = ttk.Combobox(self.options_Frame, value=VR_AGGRESSION, textvariable=self.aggression_setting_var) + self.aggression_setting_Option_place = lambda:self.aggression_setting_Option.place(x=MAIN_ROW_2_X[1], y=MAIN_ROW_2_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=2/3, rely=3/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.aggression_setting_Option, self.aggression_setting_var, REG_AGGRESSION, ['10']) + self.help_hints(self.aggression_setting_Label, text=AGGRESSION_SETTING_HELP) + + # Window Size + self.window_size_Label = self.main_window_LABEL_SET(self.options_Frame, 'Window Size')#anchor=tk.CENTER + self.window_size_Label_place = lambda:self.window_size_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.window_size_Option = ttk.Combobox(self.options_Frame, value=VR_WINDOW, textvariable=self.window_size_var) + self.window_size_Option_place = lambda:self.window_size_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.window_size_Option, self.window_size_var, REG_WINDOW, VR_WINDOW) + self.help_hints(self.window_size_Label, text=WINDOW_SIZE_HELP) + + ### DEMUCS ### + + # Choose Demucs Model + self.demucs_model_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose Demucs Model') + self.demucs_model_Label_place = lambda:self.demucs_model_Label.place(x=0, y=LOW_MENU_Y[0], width=LEFT_ROW_WIDTH, height=LABEL_HEIGHT, relx=0, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.demucs_model_Option = ttk.OptionMenu(self.options_Frame, self.demucs_model_var) + self.demucs_model_Option_place = lambda:self.demucs_model_Option.place(x=0, y=LOW_MENU_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.demucs_model_Label, text=CHOOSE_MODEL_HELP) + + # Choose Demucs Stems + self.demucs_stems_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose Stem(s)') + self.demucs_stems_Label_place = lambda:self.demucs_stems_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.demucs_stems_Option = ttk.OptionMenu(self.options_Frame, self.demucs_stems_var, None) + self.demucs_stems_Option_place = lambda:self.demucs_stems_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.help_hints(self.demucs_stems_Label, text=DEMUCS_STEMS_HELP) + + # Demucs-Segment + self.segment_Label = self.main_window_LABEL_SET(self.options_Frame, 'Segment') + self.segment_Label_place = lambda:self.segment_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=2/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.segment_Option = ttk.Combobox(self.options_Frame, value=DEMUCS_SEGMENTS, textvariable=self.segment_var) + self.segment_Option_place = lambda:self.segment_Option.place(x=MAIN_ROW_2_X[1], y=MAIN_ROW_2_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=2/3, rely=3/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.segment_Option, self.segment_var, REG_SEGMENTS, DEMUCS_SEGMENTS) + self.help_hints(self.segment_Label, text=SEGMENT_HELP) + + # Stem A + self.is_primary_stem_only_Demucs_Option = ttk.Checkbutton(master=self.options_Frame, textvariable=self.is_primary_stem_only_Demucs_Text_var, variable=self.is_primary_stem_only_Demucs_var, command=lambda:self.is_primary_stem_only_Demucs_Option_toggle()) + self.is_primary_stem_only_Demucs_Option_place = lambda:self.is_primary_stem_only_Demucs_Option.place(x=CHECK_BOX_X, y=CHECK_BOX_Y, width=CHECK_BOX_WIDTH, height=CHECK_BOX_HEIGHT, relx=1/3, rely=6/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.is_primary_stem_only_Demucs_Option_toggle = lambda:self.is_secondary_stem_only_Demucs_var.set(False) if self.is_primary_stem_only_Demucs_var.get() else self.is_secondary_stem_only_Demucs_Option.configure(state=tk.NORMAL) + self.help_hints(self.is_primary_stem_only_Demucs_Option, text=SAVE_STEM_ONLY_HELP) + + # Stem B + self.is_secondary_stem_only_Demucs_Option = ttk.Checkbutton(master=self.options_Frame, textvariable=self.is_secondary_stem_only_Demucs_Text_var, variable=self.is_secondary_stem_only_Demucs_var, command=lambda:self.is_secondary_stem_only_Demucs_Option_toggle()) + self.is_secondary_stem_only_Demucs_Option_place = lambda:self.is_secondary_stem_only_Demucs_Option.place(x=CHECK_BOX_X, y=CHECK_BOX_Y, width=CHECK_BOX_WIDTH, height=CHECK_BOX_HEIGHT, relx=1/3, rely=7/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.is_secondary_stem_only_Demucs_Option_toggle = lambda:self.is_primary_stem_only_Demucs_var.set(False) if self.is_secondary_stem_only_Demucs_var.get() else self.is_primary_stem_only_Demucs_Option.configure(state=tk.NORMAL) + self.is_stem_only_Demucs_Options_Enable = lambda:(self.is_primary_stem_only_Demucs_Option.configure(state=tk.NORMAL), self.is_secondary_stem_only_Demucs_Option.configure(state=tk.NORMAL)) + self.help_hints(self.is_secondary_stem_only_Demucs_Option, text=SAVE_STEM_ONLY_HELP) + + ### ENSEMBLE MODE ### + + # Ensemble Mode + self.chosen_ensemble_Label = self.main_window_LABEL_SET(self.options_Frame, 'Ensemble Options') + self.chosen_ensemble_Label_place = lambda:self.chosen_ensemble_Label.place(x=0, y=LOW_MENU_Y[0], width=LEFT_ROW_WIDTH, height=LABEL_HEIGHT, relx=0, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.chosen_ensemble_Option = ttk.OptionMenu(self.options_Frame, self.chosen_ensemble_var) + self.chosen_ensemble_Option_place = lambda:self.chosen_ensemble_Option.place(x=0, y=LOW_MENU_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.chosen_ensemble_Label, text=CHOSEN_ENSEMBLE_HELP) + + # Ensemble Main Stems + self.ensemble_main_stem_Label = self.main_window_LABEL_SET(self.options_Frame, 'Main Stem Pair') + self.ensemble_main_stem_Label_place = lambda:self.ensemble_main_stem_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.ensemble_main_stem_Option = ttk.OptionMenu(self.options_Frame, self.ensemble_main_stem_var, None, *ENSEMBLE_MAIN_STEM, command=self.selection_action_ensemble_stems) + self.ensemble_main_stem_Option_place = lambda:self.ensemble_main_stem_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.ensemble_main_stem_Label, text=ENSEMBLE_MAIN_STEM_HELP) + + # Ensemble Algorithm + self.ensemble_type_Label = self.main_window_LABEL_SET(self.options_Frame, 'Ensemble Algorithm') + self.ensemble_type_Label_place = lambda:self.ensemble_type_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[0], width=0, height=LABEL_HEIGHT, relx=2/3, rely=2/11, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.ensemble_type_Option = ttk.OptionMenu(self.options_Frame, self.ensemble_type_var, None, *ENSEMBLE_TYPE) + self.ensemble_type_Option_place = lambda:self.ensemble_type_Option.place(x=MAIN_ROW_2_X[1], y=MAIN_ROW_2_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT,relx=2/3, rely=3/11, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.help_hints(self.ensemble_type_Label, text=ENSEMBLE_TYPE_HELP) + + # Select Music Files Option + + # Ensemble Save Ensemble Outputs + + self.ensemble_listbox_Label = self.main_window_LABEL_SET(self.options_Frame, 'Available Models') + self.ensemble_listbox_Label_place = lambda:self.ensemble_listbox_Label.place(x=MAIN_ROW_2_X[0], y=MAIN_ROW_2_Y[1], width=0, height=LABEL_HEIGHT, relx=2/3, rely=5/11, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.ensemble_listbox_Frame = Frame(self.options_Frame, highlightbackground='#04332c', highlightcolor='#04332c', highlightthicknes=1) + self.ensemble_listbox_Option = tk.Listbox(self.ensemble_listbox_Frame, selectmode=tk.MULTIPLE, activestyle='dotbox', font=("Century Gothic", "8"), background='#070708', exportselection=0, relief=SOLID, borderwidth=0) + self.ensemble_listbox_scroll = ttk.Scrollbar(self.options_Frame, orient=VERTICAL) + self.ensemble_listbox_Option.config(yscrollcommand=self.ensemble_listbox_scroll.set) + self.ensemble_listbox_scroll.configure(command=self.ensemble_listbox_Option.yview) + self.ensemble_listbox_Option_place = lambda:(self.ensemble_listbox_Frame.place(x=-25, y=-20, width=0, height=67, relx=2/3, rely=6/11, relwidth=1/3, relheight=1/self.COL1_ROWS), + self.ensemble_listbox_scroll.place(x=195, y=-20, width=-48, height=69, relx=2/3, rely=6/11, relwidth=1/10, relheight=1/self.COL1_ROWS)) + self.ensemble_listbox_Option_pack = lambda:self.ensemble_listbox_Option.pack(fill=BOTH, expand=1) + self.help_hints(self.ensemble_listbox_Label, text=ENSEMBLE_LISTBOX_HELP) + + ### AUDIO TOOLS ### + + # Chosen Audio Tool + self.chosen_audio_tool_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose Audio Tool') + self.chosen_audio_tool_Label_place = lambda:self.chosen_audio_tool_Label.place(x=0, y=LOW_MENU_Y[0], width=LEFT_ROW_WIDTH, height=LABEL_HEIGHT, relx=0, rely=6/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.chosen_audio_tool_Option = ttk.OptionMenu(self.options_Frame, self.chosen_audio_tool_var, None, *AUDIO_TOOL_OPTIONS) + self.chosen_audio_tool_Option_place = lambda:self.chosen_audio_tool_Option.place(x=0, y=LOW_MENU_Y[1], width=LEFT_ROW_WIDTH, height=OPTION_HEIGHT, relx=0, rely=7/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL1_ROWS) + self.chosen_audio_tool_var.trace_add('write', lambda *args: self.update_main_widget_states()) + self.help_hints(self.chosen_audio_tool_Label, text=AUDIO_TOOLS_HELP) + + # Choose Agorithim + self.choose_algorithm_Label = self.main_window_LABEL_SET(self.options_Frame, 'Choose Algorithm') + self.choose_algorithm_Label_place = lambda:self.choose_algorithm_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.choose_algorithm_Option = ttk.OptionMenu(self.options_Frame, self.choose_algorithm_var, None, *MANUAL_ENSEMBLE_OPTIONS) + self.choose_algorithm_Option_place = lambda:self.choose_algorithm_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + + # Time Stretch + self.time_stretch_rate_Label = self.main_window_LABEL_SET(self.options_Frame, 'Rate') + self.time_stretch_rate_Label_place = lambda:self.time_stretch_rate_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.time_stretch_rate_Option = ttk.Combobox(self.options_Frame, value=TIME_PITCH, textvariable=self.time_stretch_rate_var) + self.time_stretch_rate_Option_place = lambda:self.time_stretch_rate_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.time_stretch_rate_Option, self.time_stretch_rate_var, REG_TIME_PITCH, TIME_PITCH) + + # Pitch Rate + self.pitch_rate_Label = self.main_window_LABEL_SET(self.options_Frame, 'Semitones') + self.pitch_rate_Label_place = lambda:self.pitch_rate_Label.place(x=MAIN_ROW_X[0], y=MAIN_ROW_Y[0], width=0, height=LABEL_HEIGHT, relx=1/3, rely=2/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.pitch_rate_Option = ttk.Combobox(self.options_Frame, value=TIME_PITCH, textvariable=self.pitch_rate_var) + self.pitch_rate_Option_place = lambda:self.pitch_rate_Option.place(x=MAIN_ROW_X[1], y=MAIN_ROW_Y[1], width=MAIN_ROW_WIDTH, height=OPTION_HEIGHT, relx=1/3, rely=3/self.COL1_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.combobox_entry_validation(self.pitch_rate_Option, self.pitch_rate_var, REG_TIME_PITCH, TIME_PITCH) + + ### SHARED SETTINGS ### + + # GPU Selection + self.is_gpu_conversion_Option = ttk.Checkbutton(master=self.options_Frame, text='GPU Conversion', variable=self.is_gpu_conversion_var) + self.is_gpu_conversion_Option_place = lambda:self.is_gpu_conversion_Option.place(x=CHECK_BOX_X, y=CHECK_BOX_Y, width=CHECK_BOX_WIDTH, height=CHECK_BOX_HEIGHT, relx=1/3, rely=5/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.is_gpu_conversion_Disable = lambda:(self.is_gpu_conversion_Option.configure(state=tk.DISABLED), self.is_gpu_conversion_var.set(False)) + self.help_hints(self.is_gpu_conversion_Option, text=IS_GPU_CONVERSION_HELP) + + # Vocal Only + self.is_primary_stem_only_Option = ttk.Checkbutton(master=self.options_Frame, textvariable=self.is_primary_stem_only_Text_var, variable=self.is_primary_stem_only_var, command=lambda:self.is_primary_stem_only_Option_toggle()) + self.is_primary_stem_only_Option_place = lambda:self.is_primary_stem_only_Option.place(x=CHECK_BOX_X, y=CHECK_BOX_Y, width=CHECK_BOX_WIDTH, height=CHECK_BOX_HEIGHT, relx=1/3, rely=6/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.is_primary_stem_only_Option_toggle = lambda:self.is_secondary_stem_only_var.set(False) if self.is_primary_stem_only_var.get() else self.is_secondary_stem_only_Option.configure(state=tk.NORMAL) + self.help_hints(self.is_primary_stem_only_Option, text=SAVE_STEM_ONLY_HELP) + + # Instrumental Only + self.is_secondary_stem_only_Option = ttk.Checkbutton(master=self.options_Frame, textvariable=self.is_secondary_stem_only_Text_var, variable=self.is_secondary_stem_only_var, command=lambda:self.is_secondary_stem_only_Option_toggle()) + self.is_secondary_stem_only_Option_place = lambda:self.is_secondary_stem_only_Option.place(x=CHECK_BOX_X, y=CHECK_BOX_Y, width=CHECK_BOX_WIDTH, height=CHECK_BOX_HEIGHT, relx=1/3, rely=7/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.is_secondary_stem_only_Option_toggle = lambda:self.is_primary_stem_only_var.set(False) if self.is_secondary_stem_only_var.get() else self.is_primary_stem_only_Option.configure(state=tk.NORMAL) + self.is_stem_only_Options_Enable = lambda:(self.is_primary_stem_only_Option.configure(state=tk.NORMAL), self.is_secondary_stem_only_Option.configure(state=tk.NORMAL)) + self.help_hints(self.is_secondary_stem_only_Option, text=SAVE_STEM_ONLY_HELP) + + # Sample Mode + self.model_sample_mode_Option = ttk.Checkbutton(master=self.options_Frame, textvariable=self.model_sample_mode_duration_checkbox_var, variable=self.model_sample_mode_var)#f'Sample ({self.model_sample_mode_duration_var.get()} Seconds)' + self.model_sample_mode_Option_place = lambda rely=8:self.model_sample_mode_Option.place(x=CHECK_BOX_X, y=CHECK_BOX_Y, width=CHECK_BOX_WIDTH, height=CHECK_BOX_HEIGHT, relx=1/3, rely=rely/self.COL2_ROWS, relwidth=1/3, relheight=1/self.COL2_ROWS) + self.help_hints(self.model_sample_mode_Option, text=MODEL_SAMPLE_MODE_HELP) + + self.GUI_LIST = (self.vr_model_Label, + self.vr_model_Option, + self.aggression_setting_Label, + self.aggression_setting_Option, + self.window_size_Label, + self.window_size_Option, + self.demucs_model_Label, + self.demucs_model_Option, + self.demucs_stems_Label, + self.demucs_stems_Option, + self.segment_Label, + self.segment_Option, + self.mdx_net_model_Label, + self.mdx_net_model_Option, + self.chunks_Label, + self.chunks_Option, + self.margin_Label, + self.margin_Option, + self.chosen_ensemble_Label, + self.chosen_ensemble_Option, + self.save_current_settings_Label, + self.save_current_settings_Option, + self.ensemble_main_stem_Label, + self.ensemble_main_stem_Option, + self.ensemble_type_Label, + self.ensemble_type_Option, + self.ensemble_listbox_Label, + self.ensemble_listbox_Frame, + self.ensemble_listbox_Option, + self.ensemble_listbox_scroll, + self.chosen_audio_tool_Label, + self.chosen_audio_tool_Option, + self.choose_algorithm_Label, + self.choose_algorithm_Option, + self.time_stretch_rate_Label, + self.time_stretch_rate_Option, + self.pitch_rate_Label, + self.pitch_rate_Option, + self.is_gpu_conversion_Option, + self.is_primary_stem_only_Option, + self.is_secondary_stem_only_Option, + self.is_primary_stem_only_Demucs_Option, + self.is_secondary_stem_only_Demucs_Option, + self.model_sample_mode_Option) + + REFRESH_VARS = (self.mdx_net_model_var, + self.vr_model_var, + self.demucs_model_var, + self.demucs_stems_var, + self.is_chunk_demucs_var, + self.is_primary_stem_only_Demucs_var, + self.is_secondary_stem_only_Demucs_var, + self.is_primary_stem_only_var, + self.is_secondary_stem_only_var, + self.model_download_demucs_var, + self.model_download_mdx_var, + self.model_download_vr_var, + self.select_download_var, + self.is_primary_stem_only_Demucs_Text_var, + self.is_secondary_stem_only_Demucs_Text_var, + self.chosen_process_method_var, + self.ensemble_main_stem_var) + + # Change States + for var in REFRESH_VARS: + var.trace_add('write', lambda *args: self.update_button_states()) + + def combobox_entry_validation(self, combobox: ttk.Combobox, var: tk.StringVar, pattern, default): + """Verifies valid input for comboboxes""" + + validation = lambda value:False if re.fullmatch(pattern, value) is None else True + invalid = lambda:(var.set(default[0])) + combobox.config(validate='focus', validatecommand=(self.register(validation), '%P'), invalidcommand=(self.register(invalid),)) + + def bind_widgets(self): + """Bind widgets to the drag & drop mechanic""" + + #print(self.chosen_audio_tool_Option.option_get()) + self.chosen_audio_tool_align = tk.BooleanVar(value=True) + add_align = lambda e:(self.chosen_audio_tool_Option['menu'].add_radiobutton(label=ALIGN_INPUTS, command=tk._setit(self.chosen_audio_tool_var, ALIGN_INPUTS)), self.chosen_audio_tool_align.set(False)) if self.chosen_audio_tool_align else None + + self.filePaths_saveTo_Button.drop_target_register(DND_FILES) + self.filePaths_saveTo_Entry.drop_target_register(DND_FILES) + self.drop_target_register(DND_FILES) + + self.dnd_bind('<>', lambda e: drop(e, accept_mode='files')) + self.bind(" <\>", add_align) + self.filePaths_saveTo_Button.dnd_bind('<>', lambda e: drop(e, accept_mode='folder')) + self.filePaths_saveTo_Entry.dnd_bind('<>', lambda e: drop(e, accept_mode='folder')) + self.ensemble_listbox_Option.bind('<>', lambda e: self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION)) + + self.options_Frame.bind('', lambda e:self.right_click_menu_popup(e, main_menu=True)) + self.filePaths_musicFile_Entry.bind('', lambda e:self.input_right_click_menu(e)) + self.filePaths_musicFile_Entry.bind('', lambda e:self.check_is_open_menu_view_inputs()) + + #--Input/Export Methods-- + + def input_select_filedialog(self): + """Make user select music files""" + + if self.lastDir is not None: + if not os.path.isdir(self.lastDir): + self.lastDir = None + + paths = tk.filedialog.askopenfilenames( + parent=self, + title=f'Select Music Files', + initialfile='', + initialdir=self.lastDir) + + if paths: # Path selected + self.inputPaths = paths + + self.process_input_selections() + self.update_inputPaths() + + def export_select_filedialog(self): + """Make user select a folder to export the converted files in""" + + export_path = None + + path = tk.filedialog.askdirectory( + parent=self, + title=f'Select Folder',) + + if path: # Path selected + self.export_path_var.set(path) + export_path = self.export_path_var.get() + + return export_path + + def update_inputPaths(self): + """Update the music file entry""" + + if self.inputPaths: + if len(self.inputPaths) == 1: + text = self.inputPaths[0] + else: + count = len(self.inputPaths) - 1 + file_text = 'file' if len(self.inputPaths) == 2 else 'files' + text = f"{self.inputPaths[0]}, +{count} {file_text}" + else: + # Empty Selection + text = '' + + self.inputPathsEntry_var.set(text) + + #--Utility Methods-- + + def restart(self): + """Restart the application after asking for confirmation""" + + confirm = tk.messagebox.askyesno(title='Restart Confirmation', + message='This will restart the application and halt any running processes. Your current settings will be saved. \n\n Are you sure you wish to continue?') + + if confirm: + #self.save_values() + try: + subprocess.Popen(f'UVR_Launcher.exe') + except Exception: + logging.exception("Restart") + subprocess.Popen(f'python "{__file__}"', shell=True) + + self.destroy() + + def delete_temps(self): + """Deletes temp files""" + + DIRECTORIES = (BASE_PATH, VR_MODELS_DIR, MDX_MODELS_DIR, DEMUCS_MODELS_DIR, DEMUCS_NEWER_REPO_DIR) + EXTENSIONS = (('.aes', '.txt', '.tmp')) + + try: + if os.path.isfile(f"{PATCH}.exe"): + os.remove(f"{PATCH}.exe") + + if os.path.isfile(SPLASH_DOC): + os.remove(SPLASH_DOC) + + for dir in DIRECTORIES: + for temp_file in os.listdir(dir): + if temp_file.endswith(EXTENSIONS): + if os.path.isfile(os.path.join(dir, temp_file)): + os.remove(os.path.join(dir, temp_file)) + except Exception as e: + self.error_log_var.set(error_text('Temp File Deletion', e)) + + def get_files_from_dir(self, directory, ext): + """Gets files from specified directory that ends with specified extention""" + + return tuple(os.path.splitext(x)[0] for x in os.listdir(directory) if x.endswith(ext)) + + def determine_auto_chunks(self, chunks, gpu): + """Determines appropriate chunk size based on user computer specs""" + + if chunks == 'Full': + chunk_set = 0 + elif chunks == 'Auto': + if gpu == 0: + gpu_mem = round(torch.cuda.get_device_properties(0).total_memory/1.074e+9) + if gpu_mem <= int(6): + chunk_set = int(5) + if gpu_mem in [7, 8, 9, 10, 11, 12, 13, 14, 15]: + chunk_set = int(10) + if gpu_mem >= int(16): + chunk_set = int(40) + if gpu == -1: + sys_mem = psutil.virtual_memory().total >> 30 + if sys_mem <= int(4): + chunk_set = int(1) + if sys_mem in [5, 6, 7, 8]: + chunk_set = int(10) + if sys_mem in [9, 10, 11, 12, 13, 14, 15, 16]: + chunk_set = int(25) + if sys_mem >= int(17): + chunk_set = int(60) + elif chunks == '0': + chunk_set = 0 + else: + chunk_set = int(chunks) + + return chunk_set + + def return_ensemble_stems(self, is_primary=False): + """Grabs and returns the chosen ensemble stems.""" + + ensemble_stem = self.ensemble_main_stem_var.get().partition("/") + + if is_primary: + return ensemble_stem[0] + else: + return ensemble_stem[0], ensemble_stem[2] + + def message_box(self, message): + """Template for confirmation box""" + + confirm = tk.messagebox.askyesno(title=message[0], + message=message[1], + parent=root) + + return confirm + + def error_dialoge(self, message): + """Template for messagebox that informs user of error""" + + tk.messagebox.showerror(master=self, + title=message[0], + message=message[1], + parent=root) + + def model_list(self, primary_stem: str, secondary_stem: str, is_4_stem_check=False, is_dry_check=False, is_no_demucs=False): + stem_check = self.assemble_model_data(arch_type=ENSEMBLE_STEM_CHECK, is_dry_check=is_dry_check) + + if is_no_demucs: + return [model.model_and_process_tag for model in stem_check if model.primary_stem == primary_stem or model.primary_stem == secondary_stem] + else: + if is_4_stem_check: + return [model.model_and_process_tag for model in stem_check if model.demucs_stem_count == 4] + else: + return [model.model_and_process_tag for model in stem_check if model.primary_stem == primary_stem or model.primary_stem == secondary_stem or primary_stem.lower() in model.demucs_source_list] + + def help_hints(self, widget, text): + toolTip = ToolTip(widget) + def enter(event): + if self.help_hints_var.get(): + toolTip.showtip(text) + def leave(event): + toolTip.hidetip() + widget.bind('', enter) + widget.bind('', leave) + widget.bind('', lambda e:copy_help_hint(e)) + + def copy_help_hint(event): + if self.help_hints_var.get(): + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + right_click_menu.add_command(label='Copy Help Hint Text', command=right_click_menu_copy_hint) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + else: + self.right_click_menu_popup(event, main_menu=True) + + def right_click_menu_copy_hint(): + pyperclip.copy(text) + + def input_right_click_menu(self, event): + + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + right_click_menu.add_command(label='See All Inputs', command=self.check_is_open_menu_view_inputs) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + + def cached_sources_clear(self): + + # print('\n==================================\n', 'vr_cache_source_mapper: \n\n', self.vr_cache_source_mapper, '\n==================================\n') + # print('\n==================================\n', 'mdx_cache_source_mapper: \n\n', self.mdx_cache_source_mapper, '\n==================================\n') + # print('\n==================================\n', 'demucs_cache_source_mapper: \n\n', self.demucs_cache_source_mapper, '\n==================================\n') + + self.vr_cache_source_mapper = {} + self.mdx_cache_source_mapper = {} + self.demucs_cache_source_mapper = {} + + def cached_model_source_holder(self, process_method, sources, model_name=None): + + if process_method == VR_ARCH_TYPE: + self.vr_cache_source_mapper = {**self.vr_cache_source_mapper, **{model_name: sources}} + if process_method == MDX_ARCH_TYPE: + self.mdx_cache_source_mapper = {**self.mdx_cache_source_mapper, **{model_name: sources}} + if process_method == DEMUCS_ARCH_TYPE: + self.demucs_cache_source_mapper = {**self.demucs_cache_source_mapper, **{model_name: sources}} + + def cached_source_callback(self, process_method, model_name=None): + + model, sources = None, None + + if process_method == VR_ARCH_TYPE: + mapper = self.vr_cache_source_mapper + if process_method == MDX_ARCH_TYPE: + mapper = self.mdx_cache_source_mapper + if process_method == DEMUCS_ARCH_TYPE: + mapper = self.demucs_cache_source_mapper + + for key, value in mapper.items(): + if model_name in key: + model = key + sources = value + + return model, sources + + def cached_source_model_list_check(self, model_list: list[ModelData]): + + model: ModelData + primary_model_names = lambda process_method:[model.model_basename if model.process_method == process_method else None for model in model_list] + secondary_model_names = lambda process_method:[model.secondary_model.model_basename if model.is_secondary_model_activated and model.process_method == process_method else None for model in model_list] + + self.vr_primary_model_names = primary_model_names(VR_ARCH_TYPE) + self.mdx_primary_model_names = primary_model_names(MDX_ARCH_TYPE) + self.demucs_primary_model_names = primary_model_names(DEMUCS_ARCH_TYPE) + self.vr_secondary_model_names = secondary_model_names(VR_ARCH_TYPE) + self.mdx_secondary_model_names = secondary_model_names(MDX_ARCH_TYPE) + self.demucs_secondary_model_names = [model.secondary_model.model_basename if model.is_secondary_model_activated and model.process_method == DEMUCS_ARCH_TYPE and not model.secondary_model is None else None for model in model_list] + self.demucs_pre_proc_model_name = [model.pre_proc_model.model_basename if model.pre_proc_model else None for model in model_list]#list(dict.fromkeys()) + + for model in model_list: + if model.process_method == DEMUCS_ARCH_TYPE and model.is_demucs_4_stem_secondaries: + if not model.is_4_stem_ensemble: + self.demucs_secondary_model_names = model.secondary_model_4_stem_model_names_list + break + else: + for i in model.secondary_model_4_stem_model_names_list: + self.demucs_secondary_model_names.append(i) + + print('self.demucs_pre_proc_model_name: ', self.demucs_pre_proc_model_name) + + self.all_models = self.vr_primary_model_names + self.mdx_primary_model_names + self.demucs_primary_model_names + self.vr_secondary_model_names + self.mdx_secondary_model_names + self.demucs_secondary_model_names + self.demucs_pre_proc_model_name + + def verify_audio(self, audio_file, is_process=True, sample_path=None): + is_good = False + error_data = '' + + if os.path.isfile(audio_file): + try: + librosa.load(audio_file, duration=3, mono=False, sr=44100) if not type(sample_path) is str else self.create_sample(audio_file, sample_path) + is_good = True + except Exception as e: + error_name = f'{type(e).__name__}' + traceback_text = ''.join(traceback.format_tb(e.__traceback__)) + message = f'{error_name}: "{e}"\n{traceback_text}"' + if is_process: + audio_base_name = os.path.basename(audio_file) + self.error_log_var.set(f'Error Loading the Following File:\n\n\"{audio_base_name}\"\n\nRaw Error Details:\n\n{message}') + else: + error_data = AUDIO_VERIFICATION_CHECK(audio_file, message) + + if is_process: + return is_good + else: + return is_good, error_data + + def create_sample(self, audio_file, sample_path=SAMPLE_CLIP_PATH): + try: + with audioread.audio_open(audio_file) as f: + track_length = int(f.duration) + except Exception as e: + print('Audioread failed to get duration. Trying Librosa...') + y, sr = librosa.load(audio_file, mono=False, sr=44100) + track_length = int(librosa.get_duration(y=y, sr=sr)) + + clip_duration = int(self.model_sample_mode_duration_var.get()) + + if track_length >= clip_duration: + offset_cut = track_length//3 + off_cut = offset_cut + track_length + if not off_cut >= clip_duration: + offset_cut = 0 + name_apped = f'{clip_duration}_second_' + else: + offset_cut, clip_duration = 0, track_length + name_apped = '' + + #if not track_length <= clip_duration: + sample = librosa.load(audio_file, offset=offset_cut, duration=clip_duration, mono=False, sr=44100)[0].T + audio_sample = os.path.join(sample_path, f'{os.path.splitext(os.path.basename(audio_file))[0]}_{name_apped}sample.wav') + sf.write(audio_sample, sample, 44100) + # else: + # audio_sample = audio_file + + return audio_sample + + #--Right Click Menu Pop-Ups-- + + def right_click_select_settings_sub(self, parent_menu, process_method): + saved_settings_sub_menu = Menu(parent_menu, font=('Century Gothic', 8), tearoff=False) + settings_options = self.last_found_settings + SAVE_SET_OPTIONS + + for settings_options in settings_options: + settings_options = settings_options.replace("_", " ") + saved_settings_sub_menu.add_command(label=settings_options, command=lambda o=settings_options:self.selection_action_saved_settings(o, process_method=process_method)) + + saved_settings_sub_menu.insert_separator(len(self.last_found_settings)) + + return saved_settings_sub_menu + + def right_click_menu_popup(self, event, text_box=False, main_menu=False): + + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + + PM_RIGHT_CLICK_MAPPER = { + ENSEMBLE_MODE:self.check_is_open_menu_advanced_ensemble_options, + VR_ARCH_PM:self.check_is_open_menu_advanced_vr_options, + MDX_ARCH_TYPE:self.check_is_open_menu_advanced_mdx_options, + DEMUCS_ARCH_TYPE:self.check_is_open_menu_advanced_demucs_options} + + PM_RIGHT_CLICK_VAR_MAPPER = { + ENSEMBLE_MODE:True, + VR_ARCH_PM:self.vr_is_secondary_model_activate_var.get(), + MDX_ARCH_TYPE:self.mdx_is_secondary_model_activate_var.get(), + DEMUCS_ARCH_TYPE:self.demucs_is_secondary_model_activate_var.get()} + + saved_settings_sub_load_for_menu = Menu(right_click_menu, font=('Century Gothic', 8), tearoff=False) + saved_settings_sub_load_for_menu.add_cascade(label=VR_ARCH_SETTING_LOAD, menu=self.right_click_select_settings_sub(saved_settings_sub_load_for_menu, VR_ARCH_PM)) + saved_settings_sub_load_for_menu.add_cascade(label=MDX_SETTING_LOAD, menu=self.right_click_select_settings_sub(saved_settings_sub_load_for_menu, MDX_ARCH_TYPE)) + saved_settings_sub_load_for_menu.add_cascade(label=DEMUCS_SETTING_LOAD, menu=self.right_click_select_settings_sub(saved_settings_sub_load_for_menu, DEMUCS_ARCH_TYPE)) + saved_settings_sub_load_for_menu.add_cascade(label=ALL_ARCH_SETTING_LOAD, menu=self.right_click_select_settings_sub(saved_settings_sub_load_for_menu, None)) + + if not main_menu: + right_click_menu.add_command(label='Copy', command=self.right_click_menu_copy) + right_click_menu.add_command(label='Paste', command=lambda:self.right_click_menu_paste(text_box=text_box)) + right_click_menu.add_command(label='Delete', command=lambda:self.right_click_menu_delete(text_box=text_box)) + else: + for method_type, option in PM_RIGHT_CLICK_MAPPER.items(): + if method_type == self.chosen_process_method_var.get(): + if PM_RIGHT_CLICK_VAR_MAPPER[method_type] or (method_type == DEMUCS_ARCH_TYPE and self.is_demucs_pre_proc_model_activate_var.get()): + right_click_menu.add_cascade(label='Select Saved Settings', menu=saved_settings_sub_load_for_menu) + right_click_menu.add_separator() + for method_type_sub, option_sub in PM_RIGHT_CLICK_MAPPER.items(): + if method_type_sub == ENSEMBLE_MODE and not self.chosen_process_method_var.get() == ENSEMBLE_MODE: + pass + else: + right_click_menu.add_command(label=f'Advanced {method_type_sub} Settings', command=option_sub) + else: + right_click_menu.add_command(label=f'Advanced {method_type} Settings', command=option) + break + + if not self.is_menu_settings_open: + right_click_menu.add_command(label='Additional Settings', command=lambda:self.menu_settings(select_tab_2=True)) + + help_hints_label = 'Enable' if self.help_hints_var.get() == False else 'Disable' + help_hints_bool = True if self.help_hints_var.get() == False else False + right_click_menu.add_command(label=f'{help_hints_label} Help Hints', command=lambda:self.help_hints_var.set(help_hints_bool)) + + if self.error_log_var.get(): + right_click_menu.add_command(label='Error Log', command=self.check_is_open_menu_error_log) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + + def right_click_menu_copy(self): + hightlighted_text = self.current_text_box.selection_get() + self.clipboard_clear() + self.clipboard_append(hightlighted_text) + + def right_click_menu_paste(self, text_box=False): + clipboard = self.clipboard_get() + self.right_click_menu_delete(text_box=True) if text_box else self.right_click_menu_delete() + self.current_text_box.insert(self.current_text_box.index(tk.INSERT), clipboard) + + def right_click_menu_delete(self, text_box=False): + if text_box: + try: + s0 = self.current_text_box.index("sel.first") + s1 = self.current_text_box.index("sel.last") + self.current_text_box.tag_configure('highlight') + self.current_text_box.tag_add("highlight", s0, s1) + start_indexes = self.current_text_box.tag_ranges("highlight")[0::2] + end_indexes = self.current_text_box.tag_ranges("highlight")[1::2] + + for start, end in zip(start_indexes, end_indexes): + self.current_text_box.tag_remove("highlight", start, end) + + for start, end in zip(start_indexes, end_indexes): + self.current_text_box.delete(start, end) + except Exception as e: + print('RIGHT-CLICK-DELETE ERROR: \n', e) + else: + self.current_text_box.delete(0, END) + + def right_click_console(self, event): + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + right_click_menu.add_command(label='Copy', command=self.command_Text.copy_text) + right_click_menu.add_command(label='Select All', command=self.command_Text.select_all_text) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + + #--Secondary Window Methods-- + + def menu_placement(self, window: Toplevel, title, pop_up=False, is_help_hints=False, close_function=None): + """Prepares and centers each secondary window relative to the main window""" + + window.geometry("+%d+%d" %(8000, 5000)) + window.resizable(False, False) + window.wm_transient(root) + window.title(title) + window.iconbitmap(ICON_IMG_PATH) + window.update() + window.deiconify() + + root_location_x = root.winfo_x() + root_location_y = root.winfo_y() + + root_x = root.winfo_width() + root_y = root.winfo_height() + + sub_menu_x = window.winfo_width() + sub_menu_y = window.winfo_height() + + menu_offset_x = (root_x - sub_menu_x) // 2 + menu_offset_y = (root_y - sub_menu_y) // 2 + window.geometry("+%d+%d" %(root_location_x+menu_offset_x, root_location_y+menu_offset_y)) + + def right_click_menu(event): + help_hints_label = 'Enable' if self.help_hints_var.get() == False else 'Disable' + help_hints_bool = True if self.help_hints_var.get() == False else False + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + if is_help_hints: + right_click_menu.add_command(label=f'{help_hints_label} Help Hints', command=lambda:self.help_hints_var.set(help_hints_bool)) + right_click_menu.add_command(label='Exit Window', command=close_function) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + + if close_function: + window.bind('', lambda e:right_click_menu(e)) + + if pop_up: + window.grab_set() + root.wait_window(window) + + def menu_tab_control(self, toplevel, ai_network_vars, is_demucs=False): + """Prepares the tabs setup for some windows""" + + tabControl = ttk.Notebook(toplevel) + + tab1 = ttk.Frame(tabControl) + tab2 = ttk.Frame(tabControl) + + tabControl.add(tab1, text ='Settings Guide') + tabControl.add(tab2, text ='Secondary Model') + + tabControl.pack(expand = 1, fill ="both") + + tab1.grid_rowconfigure(0, weight=1) + tab1.grid_columnconfigure(0, weight=1) + + tab2.grid_rowconfigure(0, weight=1) + tab2.grid_columnconfigure(0, weight=1) + + self.menu_secondary_model(tab2, ai_network_vars) + + if is_demucs: + tab3 = ttk.Frame(tabControl) + tabControl.add(tab3, text ='Pre-process Model') + tab3.grid_rowconfigure(0, weight=1) + tab3.grid_columnconfigure(0, weight=1) + + return tab1, tab3 + else: + return tab1 + + def menu_view_inputs(self): + + menu_view_inputs_top = Toplevel(root) + + self.is_open_menu_view_inputs.set(True) + self.menu_view_inputs_close_window = lambda:close_window() + menu_view_inputs_top.protocol("WM_DELETE_WINDOW", self.menu_view_inputs_close_window) + + input_length_var = tk.StringVar(value='') + input_info_text_var = tk.StringVar(value='') + is_widen_box_var = tk.BooleanVar(value=False) + is_play_file_var = tk.BooleanVar(value=False) + varification_text_var = tk.StringVar(value='Verify Inputs') + + reset_list = lambda:(input_files_listbox_Option.delete(0, 'end'), [input_files_listbox_Option.insert(tk.END, inputs) for inputs in self.inputPaths]) + audio_input_total = lambda:input_length_var.set(f'Audio Input Total: {len(self.inputPaths)}') + audio_input_total() + + def list_diff(list1, list2): return list(set(list1).symmetric_difference(set(list2))) + + def list_to_string(list1): return '\n'.join(''.join(sub) for sub in list1) + + def close_window(): + self.verification_thread.kill() if self.thread_check(self.verification_thread) else None + self.is_open_menu_view_inputs.set(False) + menu_view_inputs_top.destroy() + + def drag_n_drop(e): + input_info_text_var.set('') + drop(e, accept_mode='files') + reset_list() + audio_input_total() + + def selected_files(is_remove=False): + if not self.thread_check(self.active_processing_thread): + items_list = [input_files_listbox_Option.get(i) for i in input_files_listbox_Option.curselection()] + inputPaths = list(self.inputPaths)# if is_remove else items_list + if is_remove: + [inputPaths.remove(i) for i in items_list if items_list] + else: + [inputPaths.remove(i) for i in self.inputPaths if i not in items_list] + removed_files = list_diff(self.inputPaths, inputPaths) + [input_files_listbox_Option.delete(input_files_listbox_Option.get(0, tk.END).index(i)) for i in removed_files] + starting_len = len(self.inputPaths) + self.inputPaths = tuple(inputPaths) + self.update_inputPaths() + audio_input_total() + input_info_text_var.set(f'{starting_len - len(self.inputPaths)} input(s) removed.') + else: + input_info_text_var.set('You cannot remove inputs during an active process.') + + def box_size(): + input_info_text_var.set('') + input_files_listbox_Option.config(width=230, height=25) if is_widen_box_var.get() else input_files_listbox_Option.config(width=110, height=17) + self.menu_placement(menu_view_inputs_top, 'Selected Inputs', pop_up=True) + + def input_options(is_select_inputs=True): + input_info_text_var.set('') + if is_select_inputs: + self.input_select_filedialog() + else: + self.inputPaths = () + reset_list() + self.update_inputPaths() + audio_input_total() + + def pop_open_file_path(is_play_file=False): + if self.inputPaths: + track_selected = self.inputPaths[input_files_listbox_Option.index(tk.ACTIVE)] + if os.path.isfile(track_selected): + os.startfile(track_selected) if is_play_file else os.startfile(os.path.dirname(track_selected)) + + def get_export_dir(): + if os.path.isdir(self.export_path_var.get()): + export_dir = self.export_path_var.get() + else: + export_dir = self.export_select_filedialog() + + return export_dir + + def verify_audio(is_create_samples=False): + inputPaths = list(self.inputPaths) + iterated_list = self.inputPaths if not is_create_samples else [input_files_listbox_Option.get(i) for i in input_files_listbox_Option.curselection()] + removed_files = [] + export_dir = None + total_audio_count, current_file = len(iterated_list), 0 + if iterated_list: + for i in iterated_list: + current_file += 1 + input_info_text_var.set(f'{SAMPLE_BEGIN if is_create_samples else VERIFY_BEGIN}{current_file}/{total_audio_count}') + if is_create_samples: + export_dir = get_export_dir() + if not export_dir: + input_info_text_var.set(f'No export directory selected.') + return + is_good, error_data = self.verify_audio(i, is_process=False, sample_path=export_dir) + if not is_good: + inputPaths.remove(i) + removed_files.append(error_data)#sample = self.create_sample(i) + + varification_text_var.set('Verify Inputs') + input_files_listbox_Option.configure(state=tk.NORMAL) + + if removed_files: + input_info_text_var.set(f'{len(removed_files)} Broken or Incompatible File(s) Removed. Check Error Log for details.') + error_text = '' + for i in removed_files: + error_text += i + removed_files = list_diff(self.inputPaths, inputPaths) + [input_files_listbox_Option.delete(input_files_listbox_Option.get(0, tk.END).index(i)) for i in removed_files] + self.error_log_var.set(REMOVED_FILES(list_to_string(removed_files), error_text)) + self.inputPaths = tuple(inputPaths) + self.update_inputPaths() + else: + input_info_text_var.set(f'No errors found!') + + audio_input_total() + else: + input_info_text_var.set(f'No Files {SELECTED_VER if is_create_samples else DETECTED_VER}') + varification_text_var.set('Verify Inputs') + input_files_listbox_Option.configure(state=tk.NORMAL) + return + + #print(list_to_string(self.inputPaths)) + audio_input_total() + + def verify_audio_start_thread(is_create_samples=False): + + if not self.thread_check(self.active_processing_thread): + if not self.thread_check(self.verification_thread): + varification_text_var.set('Stop Progress') + input_files_listbox_Option.configure(state=tk.DISABLED) + self.verification_thread = KThread(target=lambda:verify_audio(is_create_samples=is_create_samples)) + self.verification_thread.start() + else: + input_files_listbox_Option.configure(state=tk.NORMAL) + varification_text_var.set('Verify Inputs') + input_info_text_var.set('Process Stopped') + self.verification_thread.kill() + else: + input_info_text_var.set('You cannot verify inputs during an active process.') + + def right_click_menu(event): + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + right_click_menu.add_command(label='Remove Selected Items Only', command=lambda:selected_files(is_remove=True)) + right_click_menu.add_command(label='Keep Selected Items Only', command=lambda:selected_files(is_remove=False)) + right_click_menu.add_command(label='Clear All Input(s)', command=lambda:input_options(is_select_inputs=False)) + right_click_menu.add_separator() + right_click_menu_sub = Menu(right_click_menu, font=('Century Gothic', 8), tearoff=False) + right_click_menu.add_command(label='Verify and Create Samples of Selected Inputs', command=lambda:verify_audio_start_thread(is_create_samples=True)) + right_click_menu.add_cascade(label='Preferred Double Click Action', menu=right_click_menu_sub) + if is_play_file_var.get(): + right_click_menu_sub.add_command(label='Enable: Open Audio File Directory', command=lambda:(input_files_listbox_Option.bind('', lambda e:pop_open_file_path()), is_play_file_var.set(False))) + else: + right_click_menu_sub.add_command(label='Enable: Open Audio File', command=lambda:(input_files_listbox_Option.bind('', lambda e:pop_open_file_path(is_play_file=True)), is_play_file_var.set(True))) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + + menu_view_inputs_Frame = self.menu_FRAME_SET(menu_view_inputs_top) + menu_view_inputs_Frame.grid(row=0,column=0,padx=0,pady=0) + + self.main_window_LABEL_SET(menu_view_inputs_Frame, 'Selected Inputs').grid(row=0,column=0,padx=0,pady=5) + tk.Label(menu_view_inputs_Frame, textvariable=input_length_var, font=("Century Gothic", "8"), foreground='#13a4c9').grid(row=1, column=0, padx=0, pady=5) + ttk.Button(menu_view_inputs_Frame, text='Select Input(s)', command=lambda:input_options()).grid(row=2,column=0,padx=0,pady=10) + + input_files_listbox_Option = tk.Listbox(menu_view_inputs_Frame, selectmode=tk.EXTENDED, activestyle='dotbox', font=("Century Gothic", "8"), background='#101414', exportselection=0, width=110, height=17, relief=SOLID, borderwidth=0) + input_files_listbox_vertical_scroll = ttk.Scrollbar(menu_view_inputs_Frame, orient=VERTICAL) + input_files_listbox_Option.config(yscrollcommand=input_files_listbox_vertical_scroll.set) + input_files_listbox_vertical_scroll.configure(command=input_files_listbox_Option.yview) + input_files_listbox_Option.grid(row=4, sticky=W) + input_files_listbox_vertical_scroll.grid(row=4, column=1, sticky=NS) + + tk.Label(menu_view_inputs_Frame, textvariable=input_info_text_var, font=("Century Gothic", "8"), foreground='#13a4c9').grid(row=5, column=0, padx=0, pady=0) + ttk.Checkbutton(menu_view_inputs_Frame, text='Widen Box', variable=is_widen_box_var, command=lambda:box_size()).grid(row=6,column=0,padx=0,pady=0) + verify_audio_Button = ttk.Button(menu_view_inputs_Frame, textvariable=varification_text_var, command=lambda:verify_audio_start_thread()) + verify_audio_Button.grid(row=7,column=0,padx=0,pady=5) + ttk.Button(menu_view_inputs_Frame, text='Close Window', command=lambda:menu_view_inputs_top.destroy()).grid(row=8,column=0,padx=0,pady=5) + + menu_view_inputs_top.drop_target_register(DND_FILES) + menu_view_inputs_top.dnd_bind('<>', lambda e: drag_n_drop(e)) + input_files_listbox_Option.bind('', lambda e:right_click_menu(e)) + input_files_listbox_Option.bind('', lambda e:pop_open_file_path()) + input_files_listbox_Option.bind('', lambda e:selected_files(is_remove=True)) + input_files_listbox_Option.bind('', lambda e:selected_files(is_remove=False)) + + reset_list() + + self.menu_placement(menu_view_inputs_top, 'Selected Inputs', pop_up=True) + + def menu_settings(self, select_tab_2=False, select_tab_3=False): + """Open Settings and Download Center""" + + settings_menu = Toplevel() + + option_var = tk.StringVar(value=SELECT_SAVED_SETTING) + self.is_menu_settings_open = True + + tabControl = ttk.Notebook(settings_menu) + + tab1 = ttk.Frame(tabControl) + tab2 = ttk.Frame(tabControl) + tab3 = ttk.Frame(tabControl) + + tabControl.add(tab1, text ='Settings Guide') + tabControl.add(tab2, text ='Additional Settings') + tabControl.add(tab3, text ='Download Center') + + tabControl.pack(expand = 1, fill ="both") + + tab1.grid_rowconfigure(0, weight=1) + tab1.grid_columnconfigure(0, weight=1) + + tab2.grid_rowconfigure(0, weight=1) + tab2.grid_columnconfigure(0, weight=1) + + tab3.grid_rowconfigure(0, weight=1) + tab3.grid_columnconfigure(0, weight=1) + + self.disable_tabs = lambda:(tabControl.tab(0, state="disabled"), tabControl.tab(1, state="disabled")) + self.enable_tabs = lambda:(tabControl.tab(0, state="normal"), tabControl.tab(1, state="normal")) + self.main_menu_var = tk.StringVar(value='Choose Option') + model_sample_mode_duration_label_var = tk.StringVar(value=f'{self.model_sample_mode_duration_var.get()} Seconds') + + self.download_progress_bar_var.set(0) + self.download_progress_info_var.set('') + self.download_progress_percent_var.set('') + + OPTION_LIST = { + ENSEMBLE_OPTION:self.check_is_open_menu_advanced_ensemble_options, + MDX_OPTION:self.check_is_open_menu_advanced_mdx_options, + DEMUCS_OPTION:self.check_is_open_menu_advanced_demucs_options, + VR_OPTION:self.check_is_open_menu_advanced_vr_options, + HELP_OPTION:self.check_is_open_menu_help, + ERROR_OPTION:self.check_is_open_menu_error_log} + + def set_vars_for_sample_mode(event): + value = int(float(event)) + value = round(value / 5) * 5 + self.model_sample_mode_duration_var.set(value) + self.model_sample_mode_duration_checkbox_var.set(SAMPLE_MODE_CHECKBOX(value)) + model_sample_mode_duration_label_var.set(f'{value} Seconds') + + #Settings Tab 1 + settings_menu_main_Frame = self.menu_FRAME_SET(tab1) + settings_menu_main_Frame.grid(row=0,column=0,padx=0,pady=0) + settings_title_Label = self.menu_title_LABEL_SET(settings_menu_main_Frame, "General Menu") + settings_title_Label.grid(row=0,column=0,padx=0,pady=15) + + select_Label = self.menu_sub_LABEL_SET(settings_menu_main_Frame, 'Additional Menus & Information') + select_Label.grid(row=1,column=0,padx=0,pady=5) + + select_Option = ttk.OptionMenu(settings_menu_main_Frame, self.main_menu_var, None, *ADVANCED_SETTINGS, command=lambda selection:(OPTION_LIST[selection](), close_window())) + select_Option.grid(row=2,column=0,padx=0,pady=5) + + help_hints_Option = ttk.Checkbutton(settings_menu_main_Frame, text='Enable Help Hints', variable=self.help_hints_var, width=16) + help_hints_Option.grid(row=3,column=0,padx=0,pady=5) + + open_app_dir_Button = ttk.Button(settings_menu_main_Frame, text='Open Application Directory', command=lambda:os.startfile('.')) + open_app_dir_Button.grid(row=6,column=0,padx=0,pady=5) + + reset_all_app_settings_Button = ttk.Button(settings_menu_main_Frame, text='Reset All Settings to Default', command=lambda:self.load_to_default_confirm()) + reset_all_app_settings_Button.grid(row=7,column=0,padx=0,pady=5) + + restart_app_Button = ttk.Button(settings_menu_main_Frame, text='Restart Application', command=lambda:self.restart()) + restart_app_Button.grid(row=8,column=0,padx=0,pady=5) + + close_settings_win_Button = ttk.Button(settings_menu_main_Frame, text='Close Window', command=lambda:close_window()) + close_settings_win_Button.grid(row=9,column=0,padx=0,pady=5) + + app_update_Label = self.menu_title_LABEL_SET(settings_menu_main_Frame, "Application Updates") + app_update_Label.grid(row=10,column=0,padx=0,pady=15) + + self.app_update_button = ttk.Button(settings_menu_main_Frame, textvariable=self.app_update_button_Text_var, command=lambda:self.pop_up_update_confirmation()) + self.app_update_button.grid(row=11,column=0,padx=0,pady=5) + + self.app_update_status_Label = tk.Label(settings_menu_main_Frame, textvariable=self.app_update_status_Text_var, font=("Century Gothic", "12"), width=35, justify="center", relief="ridge", fg="#13a4c9") + self.app_update_status_Label.grid(row=12,column=0,padx=0,pady=20) + + donate_Button = ttk.Button(settings_menu_main_Frame, image=self.donate_img, command=lambda:webbrowser.open_new_tab(DONATE_LINK_BMAC)) + donate_Button.grid(row=13,column=0,padx=0,pady=5) + self.help_hints(donate_Button, text=DONATE_HELP) + + #Settings Tab 2 + settings_menu_format_Frame = self.menu_FRAME_SET(tab2) + settings_menu_format_Frame.grid(row=0,column=0,padx=0,pady=0) + + audio_format_title_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, "Audio Format Settings", width=20) + audio_format_title_Label.grid(row=0,column=0,padx=0,pady=10) + + self.wav_type_set_Label = self.menu_sub_LABEL_SET(settings_menu_format_Frame, 'Wav Type') + self.wav_type_set_Label.grid(row=1,column=0,padx=0,pady=5) + + self.wav_type_set_Option = ttk.OptionMenu(settings_menu_format_Frame, self.wav_type_set_var, None, *WAV_TYPE) + self.wav_type_set_Option.grid(row=2,column=0,padx=20,pady=5) + + self.mp3_bit_set_Label = self.menu_sub_LABEL_SET(settings_menu_format_Frame, 'Mp3 Bitrate') + self.mp3_bit_set_Label.grid(row=3,column=0,padx=0,pady=5) + + self.mp3_bit_set_Option = ttk.OptionMenu(settings_menu_format_Frame, self.mp3_bit_set_var, None, *MP3_BIT_RATES) + self.mp3_bit_set_Option.grid(row=4,column=0,padx=20,pady=5) + + audio_format_title_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, "General Process Settings") + audio_format_title_Label.grid(row=5,column=0,padx=0,pady=10) + + self.is_testing_audio_Option = ttk.Checkbutton(settings_menu_format_Frame, text='Settings Test Mode', width=23, variable=self.is_testing_audio_var) + self.is_testing_audio_Option.grid(row=7,column=0,padx=0,pady=0) + self.help_hints(self.is_testing_audio_Option, text=IS_TESTING_AUDIO_HELP) + + self.is_add_model_name_Option = ttk.Checkbutton(settings_menu_format_Frame, text='Model Test Mode', width=23, variable=self.is_add_model_name_var) + self.is_add_model_name_Option.grid(row=8,column=0,padx=0,pady=0) + self.help_hints(self.is_add_model_name_Option, text=IS_MODEL_TESTING_AUDIO_HELP) + + self.is_create_model_folder_Option = ttk.Checkbutton(settings_menu_format_Frame, text='Generate Model Folders', width=23, variable=self.is_create_model_folder_var) + self.is_create_model_folder_Option.grid(row=9,column=0,padx=0,pady=0) + self.help_hints(self.is_create_model_folder_Option, text=IS_CREATE_MODEL_FOLDER_HELP) + + self.is_accept_any_input_Option = ttk.Checkbutton(settings_menu_format_Frame, text='Accept Any Input', width=23, variable=self.is_accept_any_input_var) + self.is_accept_any_input_Option.grid(row=10,column=0,padx=0,pady=0) + self.help_hints(self.is_accept_any_input_Option, text=IS_ACCEPT_ANY_INPUT_HELP) + + self.is_task_complete_Option = ttk.Checkbutton(settings_menu_format_Frame, text='Notification Chimes', width=23, variable=self.is_task_complete_var) + self.is_task_complete_Option.grid(row=11,column=0,padx=0,pady=0) + self.help_hints(self.is_task_complete_Option, text=IS_TASK_COMPLETE_HELP) + + is_normalization_Option = ttk.Checkbutton(settings_menu_format_Frame, text='Normalize Output', width=23, variable=self.is_normalization_var) + is_normalization_Option.grid(row=12,column=0,padx=0,pady=0) + self.help_hints(is_normalization_Option, text=IS_NORMALIZATION_HELP) + + model_sample_mode_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, "Model Sample Mode Settings") + model_sample_mode_Label.grid(row=13,column=0,padx=0,pady=10) + + self.model_sample_mode_duration_Label = self.menu_sub_LABEL_SET(settings_menu_format_Frame, 'Sample Clip Duration') + self.model_sample_mode_duration_Label.grid(row=14,column=0,padx=0,pady=5) + + tk.Label(settings_menu_format_Frame, textvariable=model_sample_mode_duration_label_var, font=("Century Gothic", "8"), foreground='#13a4c9').grid(row=15,column=0,padx=0,pady=2) + model_sample_mode_duration_Option = ttk.Scale(settings_menu_format_Frame, variable=self.model_sample_mode_duration_var, from_=5, to=120, command=set_vars_for_sample_mode, orient='horizontal') + model_sample_mode_duration_Option.grid(row=16,column=0,padx=0,pady=2) + + delete_your_settings_Label = self.menu_title_LABEL_SET(settings_menu_format_Frame, "Delete User Saved Setting") + delete_your_settings_Label.grid(row=17,column=0,padx=0,pady=10) + self.help_hints(delete_your_settings_Label, text=DELETE_YOUR_SETTINGS_HELP) + + delete_your_settings_Option = ttk.OptionMenu(settings_menu_format_Frame, option_var) + delete_your_settings_Option.grid(row=18,column=0,padx=20,pady=5) + self.deletion_list_fill(delete_your_settings_Option, option_var, self.last_found_settings, SETTINGS_CACHE_DIR, SELECT_SAVED_SETTING) + + #Settings Tab 3 + settings_menu_download_center_Frame = self.menu_FRAME_SET(tab3) + settings_menu_download_center_Frame.grid(row=0,column=0,padx=0,pady=0) + + download_center_title_Label = self.menu_title_LABEL_SET(settings_menu_download_center_Frame, "Application Download Center") + download_center_title_Label.grid(row=0,column=0,padx=20,pady=10) + + select_download_Label = self.menu_sub_LABEL_SET(settings_menu_download_center_Frame, "Select Download") + select_download_Label.grid(row=1,column=0,padx=0,pady=10) + + self.model_download_vr_Button = ttk.Radiobutton(settings_menu_download_center_Frame, text='VR Arch', width=8, variable=self.select_download_var, value='VR Arc', command=lambda:self.download_list_state()) + self.model_download_vr_Button.grid(row=3,column=0,padx=0,pady=5) + self.model_download_vr_Option = ttk.OptionMenu(settings_menu_download_center_Frame, self.model_download_vr_var) + self.model_download_vr_Option.grid(row=4,column=0,padx=0,pady=5) + + self.model_download_mdx_Button = ttk.Radiobutton(settings_menu_download_center_Frame, text='MDX-Net', width=8, variable=self.select_download_var, value='MDX-Net', command=lambda:self.download_list_state()) + self.model_download_mdx_Button.grid(row=5,column=0,padx=0,pady=5) + self.model_download_mdx_Option = ttk.OptionMenu(settings_menu_download_center_Frame, self.model_download_mdx_var) + self.model_download_mdx_Option.grid(row=6,column=0,padx=0,pady=5) + + self.model_download_demucs_Button = ttk.Radiobutton(settings_menu_download_center_Frame, text='Demucs', width=8, variable=self.select_download_var, value='Demucs', command=lambda:self.download_list_state()) + self.model_download_demucs_Button.grid(row=7,column=0,padx=0,pady=5) + self.model_download_demucs_Option = ttk.OptionMenu(settings_menu_download_center_Frame, self.model_download_demucs_var) + self.model_download_demucs_Option.grid(row=8,column=0,padx=0,pady=5) + + self.download_Button = ttk.Button(settings_menu_download_center_Frame, image=self.download_img, command=lambda:self.download_item())#, command=download_model) + self.download_Button.grid(row=9,column=0,padx=0,pady=5) + + self.download_progress_info_Label = tk.Label(settings_menu_download_center_Frame, textvariable=self.download_progress_info_var, font=("Century Gothic", "9"), foreground='#13a4c9', borderwidth=0) + self.download_progress_info_Label.grid(row=10,column=0,padx=0,pady=5) + + self.download_progress_percent_Label = tk.Label(settings_menu_download_center_Frame, textvariable=self.download_progress_percent_var, font=("Century Gothic", "9"), wraplength=350, foreground='#13a4c9') + self.download_progress_percent_Label.grid(row=11,column=0,padx=0,pady=5) + + self.download_progress_bar_Progressbar = ttk.Progressbar(settings_menu_download_center_Frame, variable=self.download_progress_bar_var) + self.download_progress_bar_Progressbar.grid(row=12,column=0,padx=0,pady=5) + + self.stop_download_Button = ttk.Button(settings_menu_download_center_Frame, textvariable=self.download_stop_var, width=15, command=lambda:self.download_post_action(DOWNLOAD_STOPPED)) + self.stop_download_Button.grid(row=13,column=0,padx=0,pady=5) + self.stop_download_Button_DISABLE = lambda:(self.download_stop_var.set(""), self.stop_download_Button.configure(state=tk.DISABLED)) + self.stop_download_Button_ENABLE = lambda:(self.download_stop_var.set("Stop Download"), self.stop_download_Button.configure(state=tk.NORMAL)) + + self.refresh_list_Button = ttk.Button(settings_menu_download_center_Frame, text='Refresh List', command=lambda:(self.online_data_refresh(refresh_list_Button=True), self.download_list_state()))#, command=refresh_list) + self.refresh_list_Button.grid(row=14,column=0,padx=0,pady=5) + + self.download_key_Button = ttk.Button(settings_menu_download_center_Frame, image=self.key_img, command=lambda:self.pop_up_user_code_input()) + self.download_key_Button.grid(row=15,column=0,padx=0,pady=5) + + self.download_center_Buttons = (self.model_download_vr_Button, + self.model_download_mdx_Button, + self.model_download_demucs_Button, + self.download_Button, + self.download_key_Button) + + self.download_lists = (self.model_download_vr_Option, + self.model_download_mdx_Option, + self.model_download_demucs_Option) + + self.download_list_vars = (self.model_download_vr_var, + self.model_download_mdx_var, + self.model_download_demucs_var) + + self.online_data_refresh() + self.download_list_state() + + if self.is_online: + self.download_list_fill() + + self.menu_placement(settings_menu, "Settings Guide", is_help_hints=True, close_function=lambda:close_window()) + + if select_tab_2: + tabControl.select(tab2) + + if select_tab_3: + tabControl.select(tab3) + + def close_window(): + self.active_download_thread.terminate() if self.thread_check(self.active_download_thread) else None + self.is_menu_settings_open = False + settings_menu.destroy() + + settings_menu.protocol("WM_DELETE_WINDOW", close_window) + + def menu_advanced_vr_options(self): + """Open Advanced VR Options""" + + vr_opt = Toplevel() + + tab1 = self.menu_tab_control(vr_opt, self.vr_secondary_model_vars) + + self.is_open_menu_advanced_vr_options.set(True) + self.menu_advanced_vr_options_close_window = lambda:(self.is_open_menu_advanced_vr_options.set(False), vr_opt.destroy()) + vr_opt.protocol("WM_DELETE_WINDOW", self.menu_advanced_vr_options_close_window) + + vr_opt_frame = self.menu_FRAME_SET(tab1) + vr_opt_frame.grid(row=0,column=0,padx=0,pady=0) + + vr_title = self.menu_title_LABEL_SET(vr_opt_frame, "Advanced VR Options") + vr_title.grid(row=0,column=0,padx=0,pady=10) + + if not self.chosen_process_method_var.get() == VR_ARCH_PM: + window_size_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Window Size') + window_size_Label.grid(row=1,column=0,padx=0,pady=5) + window_size_Option = ttk.Combobox(vr_opt_frame, value=VR_WINDOW, width=MENU_COMBOBOX_WIDTH, textvariable=self.window_size_var) + window_size_Option.grid(row=2,column=0,padx=0,pady=5) + self.combobox_entry_validation(window_size_Option, self.window_size_var, REG_WINDOW, VR_WINDOW) + self.help_hints(window_size_Label, text=WINDOW_SIZE_HELP) + + aggression_setting_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Aggression Setting') + aggression_setting_Label.grid(row=3,column=0,padx=0,pady=5) + aggression_setting_Option = ttk.Combobox(vr_opt_frame, value=VR_BATCH, width=MENU_COMBOBOX_WIDTH, textvariable=self.aggression_setting_var) + aggression_setting_Option.grid(row=4,column=0,padx=0,pady=5) + self.combobox_entry_validation(aggression_setting_Option, self.aggression_setting_var, REG_WINDOW, VR_BATCH) + self.help_hints(aggression_setting_Label, text=AGGRESSION_SETTING_HELP) + + self.crop_size_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Crop Size') + self.crop_size_Label.grid(row=5,column=0,padx=0,pady=5) + self.crop_size_sub_Label = self.menu_sub_LABEL_SET(vr_opt_frame, '(Works with select models only)', font_size=8) + self.crop_size_sub_Label.grid(row=6,column=0,padx=0,pady=0) + self.crop_size_Option = ttk.Combobox(vr_opt_frame, value=VR_CROP, width=MENU_COMBOBOX_WIDTH, textvariable=self.crop_size_var) + self.crop_size_Option.grid(row=7,column=0,padx=0,pady=5) + self.combobox_entry_validation(self.crop_size_Option, self.crop_size_var, REG_WINDOW, VR_CROP) + self.help_hints(self.crop_size_Label, text=CROP_SIZE_HELP) + + self.batch_size_Label = self.menu_sub_LABEL_SET(vr_opt_frame, 'Batch Size') + self.batch_size_Label.grid(row=8,column=0,padx=0,pady=5) + self.batch_size_sub_Label = self.menu_sub_LABEL_SET(vr_opt_frame, '(Works with select models only)', font_size=8) + self.batch_size_sub_Label.grid(row=9,column=0,padx=0,pady=0) + self.batch_size_Option = ttk.Combobox(vr_opt_frame, value=VR_BATCH, width=MENU_COMBOBOX_WIDTH, textvariable=self.batch_size_var) + self.batch_size_Option.grid(row=10,column=0,padx=0,pady=5) + self.combobox_entry_validation(self.batch_size_Option, self.batch_size_var, REG_WINDOW, VR_BATCH) + self.help_hints(self.batch_size_Label, text=BATCH_SIZE_HELP) + + self.is_tta_Option = ttk.Checkbutton(vr_opt_frame, text='Enable TTA', width=16, variable=self.is_tta_var) + self.is_tta_Option.grid(row=11,column=0,padx=0,pady=0) + self.help_hints(self.is_tta_Option, text=IS_TTA_HELP) + + self.is_post_process_Option = ttk.Checkbutton(vr_opt_frame, text='Post-Process', width=16, variable=self.is_post_process_var) + self.is_post_process_Option.grid(row=12,column=0,padx=0,pady=0) + self.help_hints(self.is_post_process_Option, text=IS_POST_PROCESS_HELP) + + self.is_high_end_process_Option = ttk.Checkbutton(vr_opt_frame, text='High-End Process', width=16, variable=self.is_high_end_process_var) + self.is_high_end_process_Option.grid(row=13,column=0,padx=0,pady=0) + self.help_hints(self.is_high_end_process_Option, text=IS_HIGH_END_PROCESS_HELP) + + self.vr_clear_cache_Button = ttk.Button(vr_opt_frame, text='Clear Auto-Set Cache', command=lambda:self.clear_cache(VR_ARCH_TYPE)) + self.vr_clear_cache_Button.grid(row=14,column=0,padx=0,pady=5) + self.help_hints(self.vr_clear_cache_Button, text=CLEAR_CACHE_HELP) + + self.open_vr_model_dir_Button = ttk.Button(vr_opt_frame, text='Open VR Models Folder', command=lambda:os.startfile(VR_MODELS_DIR)) + self.open_vr_model_dir_Button.grid(row=15,column=0,padx=0,pady=5) + + self.vr_return_Button=ttk.Button(vr_opt_frame, text=BACK_TO_MAIN_MENU, command=lambda:(self.menu_advanced_vr_options_close_window(), self.check_is_menu_settings_open())) + self.vr_return_Button.grid(row=16,column=0,padx=0,pady=5) + + self.vr_close_Button = ttk.Button(vr_opt_frame, text='Close Window', command=lambda:self.menu_advanced_vr_options_close_window()) + self.vr_close_Button.grid(row=17,column=0,padx=0,pady=5) + + self.menu_placement(vr_opt, "Advanced VR Options", is_help_hints=True, close_function=self.menu_advanced_vr_options_close_window) + + def menu_advanced_demucs_options(self): + """Open Advanced Demucs Options""" + + demuc_opt = Toplevel() + + self.is_open_menu_advanced_demucs_options.set(True) + self.menu_advanced_demucs_options_close_window = lambda:(self.is_open_menu_advanced_demucs_options.set(False), demuc_opt.destroy()) + demuc_opt.protocol("WM_DELETE_WINDOW", self.menu_advanced_demucs_options_close_window) + pre_proc_list = self.model_list(VOCAL_STEM, INST_STEM, is_dry_check=True, is_no_demucs=True) + + tab1, tab3 = self.menu_tab_control(demuc_opt, self.demucs_secondary_model_vars, is_demucs=True) + + demucs_frame = self.menu_FRAME_SET(tab1) + demucs_frame.grid(row=0,column=0,padx=0,pady=0) + + demucs_pre_model_frame = self.menu_FRAME_SET(tab3) + demucs_pre_model_frame.grid(row=0,column=0,padx=0,pady=0) + + demucs_title_Label = self.menu_title_LABEL_SET(demucs_frame, "Advanced Demucs Options") + demucs_title_Label.grid(row=0,column=0,padx=0,pady=10) + + enable_chunks = lambda:(self.margin_demucs_Option.configure(state=tk.NORMAL), self.chunks_demucs_Option.configure(state=tk.NORMAL)) + disable_chunks = lambda:(self.margin_demucs_Option.configure(state=tk.DISABLED), self.chunks_demucs_Option.configure(state=tk.DISABLED)) + chunks_toggle = lambda:enable_chunks() if self.is_chunk_demucs_var.get() else disable_chunks() + enable_pre_proc_model = lambda:(is_demucs_pre_proc_model_inst_mix_Option.configure(state=tk.NORMAL), demucs_pre_proc_model_Option.configure(state=tk.NORMAL)) + disable_pre_proc_model = lambda:(is_demucs_pre_proc_model_inst_mix_Option.configure(state=tk.DISABLED), demucs_pre_proc_model_Option.configure(state=tk.DISABLED), self.is_demucs_pre_proc_model_inst_mix_var.set(False)) + pre_proc_model_toggle = lambda:enable_pre_proc_model() if self.is_demucs_pre_proc_model_activate_var.get() else disable_pre_proc_model() + + if not self.chosen_process_method_var.get() == DEMUCS_ARCH_TYPE: + segment_Label = self.menu_sub_LABEL_SET(demucs_frame, 'Segments') + segment_Label.grid(row=1,column=0,padx=0,pady=10) + segment_Option = ttk.Combobox(demucs_frame, value=DEMUCS_SEGMENTS, width=MENU_COMBOBOX_WIDTH, textvariable=self.segment_var) + segment_Option.grid(row=2,column=0,padx=0,pady=0) + self.combobox_entry_validation(segment_Option, self.segment_var, REG_SEGMENTS, DEMUCS_SEGMENTS) + self.help_hints(segment_Label, text=SEGMENT_HELP) + + self.shifts_Label = self.menu_sub_LABEL_SET(demucs_frame, 'Shifts') + self.shifts_Label.grid(row=3,column=0,padx=0,pady=5) + self.shifts_Option = ttk.Combobox(demucs_frame, value=DEMUCS_SHIFTS, width=MENU_COMBOBOX_WIDTH, textvariable=self.shifts_var) + self.shifts_Option.grid(row=4,column=0,padx=0,pady=5) + self.combobox_entry_validation(self.shifts_Option, self.shifts_var, REG_SHIFTS, DEMUCS_SHIFTS) + self.help_hints(self.shifts_Label, text=SHIFTS_HELP) + + self.overlap_Label = self.menu_sub_LABEL_SET(demucs_frame, 'Overlap') + self.overlap_Label.grid(row=5,column=0,padx=0,pady=5) + self.overlap_Option = ttk.Combobox(demucs_frame, value=DEMUCS_OVERLAP, width=MENU_COMBOBOX_WIDTH, textvariable=self.overlap_var) + self.overlap_Option.grid(row=6,column=0,padx=0,pady=5) + self.combobox_entry_validation(self.overlap_Option, self.overlap_var, REG_OVERLAP, DEMUCS_OVERLAP) + self.help_hints(self.overlap_Label, text=OVERLAP_HELP) + + self.chunks_demucs_Label = self.menu_sub_LABEL_SET(demucs_frame, 'Chunks') + self.chunks_demucs_Label.grid(row=7,column=0,padx=0,pady=5) + self.chunks_demucs_Option = ttk.Combobox(demucs_frame, value=CHUNKS, width=MENU_COMBOBOX_WIDTH, textvariable=self.chunks_demucs_var) + self.chunks_demucs_Option.grid(row=8,column=0,padx=0,pady=5) + self.combobox_entry_validation(self.chunks_demucs_Option, self.chunks_demucs_var, REG_CHUNKS, CHUNKS) + self.help_hints(self.chunks_demucs_Label, text=CHUNKS_HELP) + + self.margin_demucs_Label = self.menu_sub_LABEL_SET(demucs_frame, 'Chunk Margin') + self.margin_demucs_Label.grid(row=9,column=0,padx=0,pady=5) + self.margin_demucs_Option = ttk.Combobox(demucs_frame, value=MARGIN_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.margin_demucs_var) + self.margin_demucs_Option.grid(row=10,column=0,padx=0,pady=5) + self.combobox_entry_validation(self.margin_Option, self.margin_demucs_var, REG_MARGIN, MARGIN_SIZE) + self.help_hints(self.margin_demucs_Label, text=MARGIN_HELP) + + self.is_chunk_demucs_Option = ttk.Checkbutton(demucs_frame, text='Enable Chunks', width=18, variable=self.is_chunk_demucs_var, command=chunks_toggle) + self.is_chunk_demucs_Option.grid(row=11,column=0,padx=0,pady=0) + self.help_hints(self.is_chunk_demucs_Option, text=IS_CHUNK_DEMUCS_HELP) + + self.is_split_mode_Option = ttk.Checkbutton(demucs_frame, text='Split Mode', width=18, variable=self.is_split_mode_var) + self.is_split_mode_Option.grid(row=12,column=0,padx=0,pady=0) + self.help_hints(self.is_split_mode_Option, text=IS_SPLIT_MODE_HELP) + + self.is_demucs_combine_stems_Option = ttk.Checkbutton(demucs_frame, text='Combine Stems', width=18, variable=self.is_demucs_combine_stems_var) + self.is_demucs_combine_stems_Option.grid(row=13,column=0,padx=0,pady=0) + self.help_hints(self.is_demucs_combine_stems_Option, text=IS_DEMUCS_COMBINE_STEMS_HELP) + + is_invert_spec_Option = ttk.Checkbutton(demucs_frame, text='Spectral Inversion', width=18, variable=self.is_invert_spec_var) + is_invert_spec_Option.grid(row=14,column=0,padx=0,pady=0) + self.help_hints(is_invert_spec_Option, text=IS_INVERT_SPEC_HELP) + + self.open_demucs_model_dir_Button = ttk.Button(demucs_frame, text='Open Demucs Model Folder', command=lambda:os.startfile('models\Demucs_Models')) + self.open_demucs_model_dir_Button.grid(row=15,column=0,padx=0,pady=5) + + self.demucs_return_Button = ttk.Button(demucs_frame, text=BACK_TO_MAIN_MENU, command=lambda:(self.menu_advanced_demucs_options_close_window(), self.check_is_menu_settings_open())) + self.demucs_return_Button.grid(row=16,column=0,padx=0,pady=5) + + self.demucs_close_Button = ttk.Button(demucs_frame, text='Close Window', command=lambda:self.menu_advanced_demucs_options_close_window()) + self.demucs_close_Button.grid(row=17,column=0,padx=0,pady=5) + + demucs_pre_proc_model_title_Label = self.menu_title_LABEL_SET(demucs_pre_model_frame, "Pre-process Model") + demucs_pre_proc_model_title_Label.grid(row=0,column=0,padx=0,pady=15) + + demucs_pre_proc_model_Label = self.menu_sub_LABEL_SET(demucs_pre_model_frame, 'Select Model', font_size=10) + demucs_pre_proc_model_Label.grid(row=1,column=0,padx=0,pady=0) + demucs_pre_proc_model_Option = ttk.OptionMenu(demucs_pre_model_frame, self.demucs_pre_proc_model_var, None, NO_MODEL, *pre_proc_list) + demucs_pre_proc_model_Option.configure(width=33) + demucs_pre_proc_model_Option.grid(row=2,column=0,padx=0,pady=10) + + is_demucs_pre_proc_model_inst_mix_Option = ttk.Checkbutton(demucs_pre_model_frame, text='Save Instrumental Mixture', width=27, variable=self.is_demucs_pre_proc_model_inst_mix_var) + is_demucs_pre_proc_model_inst_mix_Option.grid(row=3,column=0,padx=0,pady=0) + self.help_hints(is_demucs_pre_proc_model_inst_mix_Option, text=PRE_PROC_MODEL_INST_MIX_HELP) + + is_demucs_pre_proc_model_activate_Option = ttk.Checkbutton(demucs_pre_model_frame, text='Activate Pre-process Model', width=27, variable=self.is_demucs_pre_proc_model_activate_var, command=pre_proc_model_toggle) + is_demucs_pre_proc_model_activate_Option.grid(row=4,column=0,padx=0,pady=0) + self.help_hints(is_demucs_pre_proc_model_activate_Option, text=PRE_PROC_MODEL_ACTIVATE_HELP) + + chunks_toggle() + pre_proc_model_toggle() + + self.menu_placement(demuc_opt, "Advanced Demucs Options", is_help_hints=True, close_function=self.menu_advanced_demucs_options_close_window) + + def menu_advanced_mdx_options(self): + """Open Advanced MDX Options""" + + mdx_net_opt = Toplevel() + + self.is_open_menu_advanced_mdx_options.set(True) + self.menu_advanced_mdx_options_close_window = lambda:(self.is_open_menu_advanced_mdx_options.set(False), mdx_net_opt.destroy()) + mdx_net_opt.protocol("WM_DELETE_WINDOW", self.menu_advanced_mdx_options_close_window) + + tab1 = self.menu_tab_control(mdx_net_opt, self.mdx_secondary_model_vars) + + mdx_net_frame = self.menu_FRAME_SET(tab1) + mdx_net_frame.grid(row=0,column=0,padx=0,pady=0) + + mdx_opt_title = self.menu_title_LABEL_SET(mdx_net_frame, "Advanced MDX-Net Options") + mdx_opt_title.grid(row=0,column=0,padx=0,pady=10) + + if not self.chosen_process_method_var.get() == MDX_ARCH_TYPE: + chunks_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Chunks') + chunks_Label.grid(row=1,column=0,padx=0,pady=5) + chunks_Option = ttk.Combobox(mdx_net_frame, value=CHUNKS, width=MENU_COMBOBOX_WIDTH, textvariable=self.chunks_var) + chunks_Option.grid(row=2,column=0,padx=0,pady=5) + self.combobox_entry_validation(chunks_Option, self.chunks_var, REG_CHUNKS, CHUNKS) + self.help_hints(chunks_Label, text=CHUNKS_HELP) + + margin_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Chunk Margin') + margin_Label.grid(row=3,column=0,padx=0,pady=5) + margin_Option = ttk.Combobox(mdx_net_frame, value=MARGIN_SIZE, width=MENU_COMBOBOX_WIDTH, textvariable=self.margin_var) + margin_Option.grid(row=4,column=0,padx=0,pady=5) + self.combobox_entry_validation(margin_Option, self.margin_var, REG_MARGIN, MARGIN_SIZE) + self.help_hints(margin_Label, text=MARGIN_HELP) + + compensate_Label = self.menu_sub_LABEL_SET(mdx_net_frame, 'Volume Compensation') + compensate_Label.grid(row=5,column=0,padx=0,pady=5) + compensate_Option = ttk.Combobox(mdx_net_frame, value=VOL_COMPENSATION, width=MENU_COMBOBOX_WIDTH, textvariable=self.compensate_var) + compensate_Option.grid(row=6,column=0,padx=0,pady=5) + self.combobox_entry_validation(compensate_Option, self.compensate_var, REG_COMPENSATION, VOL_COMPENSATION) + self.help_hints(compensate_Label, text=COMPENSATE_HELP) + + is_denoise_Option = ttk.Checkbutton(mdx_net_frame, text='Denoise Output', width=18, variable=self.is_denoise_var) + is_denoise_Option.grid(row=8,column=0,padx=0,pady=0) + self.help_hints(is_denoise_Option, text=IS_DENOISE_HELP) + + is_invert_spec_Option = ttk.Checkbutton(mdx_net_frame, text='Spectral Inversion', width=18, variable=self.is_invert_spec_var) + is_invert_spec_Option.grid(row=9,column=0,padx=0,pady=0) + self.help_hints(is_invert_spec_Option, text=IS_INVERT_SPEC_HELP) + + clear_mdx_cache_Button = ttk.Button(mdx_net_frame, text='Clear Auto-Set Cache', command=lambda:self.clear_cache(MDX_ARCH_TYPE)) + clear_mdx_cache_Button.grid(row=10,column=0,padx=0,pady=5) + self.help_hints(clear_mdx_cache_Button, text=CLEAR_CACHE_HELP) + + open_mdx_model_dir_Button = ttk.Button(mdx_net_frame, text='Open MDX-Net Models Folder', command=lambda:os.startfile(MDX_MODELS_DIR)) + open_mdx_model_dir_Button.grid(row=11,column=0,padx=0,pady=5) + + mdx_return_Button = ttk.Button(mdx_net_frame, text=BACK_TO_MAIN_MENU, command=lambda:(self.menu_advanced_mdx_options_close_window(), self.check_is_menu_settings_open())) + mdx_return_Button.grid(row=12,column=0,padx=0,pady=5) + + mdx_close_Button = ttk.Button(mdx_net_frame, text='Close Window', command=lambda:self.menu_advanced_mdx_options_close_window()) + mdx_close_Button.grid(row=13,column=0,padx=0,pady=5) + + self.menu_placement(mdx_net_opt, "Advanced MDX-Net Options", is_help_hints=True, close_function=self.menu_advanced_mdx_options_close_window) + + def menu_advanced_ensemble_options(self): + """Open Ensemble Custom""" + + custom_ens_opt = Toplevel() + + self.is_open_menu_advanced_ensemble_options.set(True) + self.menu_advanced_ensemble_options_close_window = lambda:(self.is_open_menu_advanced_ensemble_options.set(False), custom_ens_opt.destroy()) + custom_ens_opt.protocol("WM_DELETE_WINDOW", self.menu_advanced_ensemble_options_close_window) + + option_var = tk.StringVar(value=SELECT_SAVED_ENSEMBLE) + + custom_ens_opt_frame = self.menu_FRAME_SET(custom_ens_opt) + custom_ens_opt_frame.grid(row=0,column=0,padx=0,pady=0) + + settings_title_Label = self.menu_title_LABEL_SET(custom_ens_opt_frame, "Advanced Option Menu") + settings_title_Label.grid(row=1,column=0,padx=0,pady=10) + + delete_entry_Label = self.menu_sub_LABEL_SET(custom_ens_opt_frame, 'Remove Saved Ensemble') + delete_entry_Label.grid(row=2,column=0,padx=0,pady=5) + delete_entry_Option = ttk.OptionMenu(custom_ens_opt_frame, option_var) + delete_entry_Option.grid(row=3,column=0,padx=20,pady=5) + + is_save_all_outputs_ensemble_Option = ttk.Checkbutton(custom_ens_opt_frame, text='Save All Outputs', width=25, variable=self.is_save_all_outputs_ensemble_var) + is_save_all_outputs_ensemble_Option.grid(row=4,column=0,padx=0,pady=0) + self.help_hints(is_save_all_outputs_ensemble_Option, text=IS_SAVE_ALL_OUTPUTS_ENSEMBLE_HELP) + + is_append_ensemble_name_Option = ttk.Checkbutton(custom_ens_opt_frame, text='Append Ensemble Name', width=25, variable=self.is_append_ensemble_name_var) + is_append_ensemble_name_Option.grid(row=5,column=0,padx=0,pady=0) + self.help_hints(is_append_ensemble_name_Option, text=IS_APPEND_ENSEMBLE_NAME_HELP) + + ensemble_return_Button = ttk.Button(custom_ens_opt_frame, text="Back to Main Menu", command=lambda:(self.menu_advanced_ensemble_options_close_window(), self.check_is_menu_settings_open())) + ensemble_return_Button.grid(row=10,column=0,padx=0,pady=5) + + ensemble_close_Button = ttk.Button(custom_ens_opt_frame, text='Close Window', command=lambda:self.menu_advanced_ensemble_options_close_window()) + ensemble_close_Button.grid(row=11,column=0,padx=0,pady=5) + + self.deletion_list_fill(delete_entry_Option, option_var, self.last_found_ensembles, ENSEMBLE_CACHE_DIR, SELECT_SAVED_ENSEMBLE) + + self.menu_placement(custom_ens_opt, "Advanced Ensemble Options", is_help_hints=True, close_function=self.menu_advanced_ensemble_options_close_window) + + def menu_help(self): + """Open Help Guide""" + + help_guide_opt = Toplevel() + + self.is_open_menu_help.set(True) + self.menu_help_close_window = lambda:(self.is_open_menu_help.set(False), help_guide_opt.destroy()) + help_guide_opt.protocol("WM_DELETE_WINDOW", self.menu_help_close_window) + + tabControl = ttk.Notebook(help_guide_opt) + + tab1 = ttk.Frame(tabControl) + tab2 = ttk.Frame(tabControl) + tab3 = ttk.Frame(tabControl) + tab4 = ttk.Frame(tabControl) + + tabControl.add(tab1, text ='Credits') + tabControl.add(tab2, text ='Resources') + tabControl.add(tab3, text ='Application License & Version Information') + tabControl.add(tab4, text ='Application Change Log') + + tabControl.pack(expand = 1, fill ="both") + + tab1.grid_rowconfigure(0, weight=1) + tab1.grid_columnconfigure(0, weight=1) + + tab2.grid_rowconfigure(0, weight=1) + tab2.grid_columnconfigure(0, weight=1) + + tab3.grid_rowconfigure(0, weight=1) + tab3.grid_columnconfigure(0, weight=1) + + tab4.grid_rowconfigure(0, weight=1) + tab4.grid_columnconfigure(0, weight=1) + + section_title_Label = lambda place, frame, text, font_size=11: tk.Label(master=frame, text=text,font=("Century Gothic", f"{font_size}", "bold"), justify="center", fg="#F4F4F4").grid(row=place,column=0,padx=0,pady=3) + description_Label = lambda place, frame, text, font=9: tk.Label(master=frame, text=text, font=("Century Gothic", f"{font}"), justify="center", fg="#F6F6F7").grid(row=place,column=0,padx=0,pady=3) + + def credit_label(place, frame, text, link=None, message=None, is_link=False, is_top=False): + if is_top: + thank = tk.Label(master=frame, text=text, font=("Century Gothic", "10", "bold"), justify="center", fg="#13a4c9") + else: + thank = tk.Label(master=frame, text=text, font=("Century Gothic", "10", "underline" if is_link else "normal"), justify="center", fg="#13a4c9") + thank.configure(cursor="hand2") if is_link else None + thank.grid(row=place,column=0,padx=0,pady=1) + if link: + thank.bind("", lambda e:webbrowser.open_new_tab(link)) + if message: + description_Label(place+1, frame, message) + + def Link(place, frame, text, link, description, font=9): + link_label = tk.Label(master=frame, text=text, font=("Century Gothic", "11", "underline"), foreground='#15bfeb', justify="center", cursor="hand2") + link_label.grid(row=place,column=0,padx=0,pady=5) + link_label.bind("", lambda e:webbrowser.open_new_tab(link)) + description_Label(place+1, frame, description, font=font) + + def right_click_menu(event): + right_click_menu = Menu(self, font=('Century Gothic', 8), tearoff=0) + right_click_menu.add_command(label='Return to Settings Menu', command=lambda:(self.menu_help_close_window(), self.check_is_menu_settings_open())) + right_click_menu.add_command(label='Exit Window', command=lambda:self.menu_help_close_window()) + + try: + right_click_menu.tk_popup(event.x_root,event.y_root) + finally: + right_click_menu.grab_release() + + help_guide_opt.bind('', lambda e:right_click_menu(e)) + credits_Frame = Frame(tab1, highlightthicknes=50) + credits_Frame.grid(row=0, column=0, padx=0, pady=0) + tk.Label(credits_Frame, image=self.credits_img).grid(row=1,column=0,padx=0,pady=5) + + section_title_Label(place=0, + frame=credits_Frame, + text="Core UVR Developers") + + credit_label(place=2, + frame=credits_Frame, + text="Anjok07\nAufr33", + is_top=True) + + section_title_Label(place=3, + frame=credits_Frame, + text="Special Thanks") + + credit_label(place=6, + frame=credits_Frame, + text="Tsurumeso", + message="Developed the original VR Architecture AI code.", + link="https://github.com/tsurumeso/vocal-remover", + is_link=True) + + credit_label(place=8, + frame=credits_Frame, + text="Kuielab & Woosung Choi", + message="Developed the original MDX-Net AI code.", + link="https://github.com/kuielab", + is_link=True) + + credit_label(place=10, + frame=credits_Frame, + text="Adefossez & Demucs", + message="Core developer of Facebook's Demucs Music Source Separation.", + link="https://github.com/facebookresearch/demucs", + is_link=True) + + credit_label(place=12, + frame=credits_Frame, + text="Bas Curtiz", + message="Designed the official UVR logo, icon, banner, splash screen.") + + credit_label(place=14, + frame=credits_Frame, + text="DilanBoskan", + message="Your contributions at the start of this project were essential to the success of UVR. Thank you!") + + credit_label(place=16, + frame=credits_Frame, + text="Audio Separation and CC Karokee & Friends Discord Communities", + message="Thank you for the support!") + + more_info_tab_Frame = Frame(tab2, highlightthicknes=30) + more_info_tab_Frame.grid(row=0,column=0,padx=0,pady=0) + + section_title_Label(place=3, + frame=more_info_tab_Frame, + text="Resources") + + Link(place=4, + frame=more_info_tab_Frame, + text="Ultimate Vocal Remover (Official GitHub)", + link="https://github.com/Anjok07/ultimatevocalremovergui", + description="You can find updates, report issues, and give us a shout via our official GitHub.", + font=10) + + Link(place=8, + frame=more_info_tab_Frame, + text="X-Minus AI", + link="https://x-minus.pro/ai", + description="Many of the models provided are also on X-Minus.\n" + \ + "X-Minus benefits users without the computing resources to run the GUI or models locally.", + font=10) + + Link(place=12, + frame=more_info_tab_Frame, + text="FFmpeg", + link="https://www.wikihow.com/Install-FFmpeg-on-Windows", + description="UVR relies on FFmpeg for processing non-wav audio files.\n" + \ + "If you are missing FFmpeg, please see the installation guide via the link provided.", + font=10) + + Link(place=18, + frame=more_info_tab_Frame, + text="Rubber Band Library", + link="https://breakfastquay.com/rubberband/", + description="UVR uses the Rubber Band library for the sound stretch and pitch shift tool.\n" + \ + "You can get more information on it via the link provided.", + font=10) + + Link(place=22, + frame=more_info_tab_Frame, + text="Official UVR Patreon", + link=DONATE_LINK_PATREON, + description="If you wish to support and donate to this project, click the link above and become a Patreon!", + font=10) + + + appplication_license_tab_Frame = Frame(tab3) + appplication_license_tab_Frame.grid(row=0,column=0,padx=0,pady=0) + + appplication_license_Label = tk.Label(appplication_license_tab_Frame, text='UVR License Information', font=("Century Gothic", "15", "bold"), justify="center", fg="#f4f4f4") + appplication_license_Label.grid(row=0,column=0,padx=0,pady=25) + + appplication_license_Text = tk.Text(appplication_license_tab_Frame, font=("Century Gothic", "11"), fg="white", bg="black", width=80, wrap=WORD, borderwidth=0) + appplication_license_Text.grid(row=1,column=0,padx=0,pady=0) + appplication_license_Text_scroll = ttk.Scrollbar(appplication_license_tab_Frame, orient=VERTICAL) + appplication_license_Text.config(yscrollcommand=appplication_license_Text_scroll.set) + appplication_license_Text_scroll.configure(command=appplication_license_Text.yview) + appplication_license_Text.grid(row=4,sticky=W) + appplication_license_Text_scroll.grid(row=4, column=1, sticky=NS) + appplication_license_Text.insert("insert", LICENSE_TEXT(VERSION, PATCH)) + appplication_license_Text.configure(state=tk.DISABLED) + + application_change_log_tab_Frame = Frame(tab4) + application_change_log_tab_Frame.grid(row=0,column=0,padx=0,pady=0) + + if os.path.isfile(CHANGE_LOG): + with open(CHANGE_LOG, 'r') as file : + change_log_text = file.read() + else: + change_log_text = 'Change log unavailable.' + + application_change_log_Label = tk.Label(application_change_log_tab_Frame, text='UVR Change Log', font=("Century Gothic", "15", "bold"), justify="center", fg="#f4f4f4") + application_change_log_Label.grid(row=0,column=0,padx=0,pady=25) + + application_change_log_Text = tk.Text(application_change_log_tab_Frame, font=("Century Gothic", "11"), fg="white", bg="black", width=80, wrap=WORD, borderwidth=0) + application_change_log_Text.grid(row=1,column=0,padx=0,pady=0) + application_change_log_Text_scroll = ttk.Scrollbar(application_change_log_tab_Frame, orient=VERTICAL) + application_change_log_Text.config(yscrollcommand=application_change_log_Text_scroll.set) + application_change_log_Text_scroll.configure(command=application_change_log_Text.yview) + application_change_log_Text.grid(row=4,sticky=W) + application_change_log_Text_scroll.grid(row=4, column=1, sticky=NS) + application_change_log_Text.insert("insert", change_log_text) + application_change_log_Text.configure(state=tk.DISABLED) + + self.menu_placement(help_guide_opt, "Information Guide") + + def menu_error_log(self): + """Open Error Log""" + + self.is_confirm_error_var.set(False) + + copied_var = tk.StringVar('') + error_log_screen = Toplevel() + + self.is_open_menu_error_log.set(True) + self.menu_error_log_close_window = lambda:(self.is_open_menu_error_log.set(False), error_log_screen.destroy()) + error_log_screen.protocol("WM_DELETE_WINDOW", self.menu_error_log_close_window) + + error_log_frame = self.menu_FRAME_SET(error_log_screen) + error_log_frame.grid(row=0,column=0,padx=0,pady=0) + + error_consol_title_Label = self.menu_title_LABEL_SET(error_log_frame, "Error Console") + error_consol_title_Label.grid(row=1,column=0,padx=20,pady=10) + + # error_details_Text = tk.Text(error_log_frame, font=("Century Gothic", "8"), fg="#D37B7B", bg="black", width=110, relief="sunken") + # error_details_Text.grid(row=4,column=0,padx=0,pady=0) + # error_details_Text.insert("insert", self.error_log_var.get()) + # error_details_Text.bind('', lambda e:self.right_click_menu_popup(e, text_box=True)) + + error_details_Text = tk.Text(error_log_frame, font=("Century Gothic", "8"), fg="#D37B7B", bg="black", width=110, wrap=WORD, borderwidth=0) + error_details_Text.grid(row=4,column=0,padx=0,pady=0) + error_details_Text.insert("insert", self.error_log_var.get()) + error_details_Text.bind('', lambda e:self.right_click_menu_popup(e, text_box=True)) + self.current_text_box = error_details_Text + error_details_Text_scroll = ttk.Scrollbar(error_log_frame, orient=VERTICAL) + error_details_Text.config(yscrollcommand=error_details_Text_scroll.set) + error_details_Text_scroll.configure(command=error_details_Text.yview) + error_details_Text.grid(row=4,sticky=W) + error_details_Text_scroll.grid(row=4, column=1, sticky=NS) + + copy_text_Label = tk.Label(error_log_frame, textvariable=copied_var, font=("Century Gothic", "7"), justify="center", fg="#f4f4f4") + copy_text_Label.grid(row=5,column=0,padx=20,pady=0) + + copy_text_Button = ttk.Button(error_log_frame, text="Copy All Text", command=lambda:(pyperclip.copy(error_details_Text.get(1.0, tk.END+"-1c")), copied_var.set('Copied!'))) + copy_text_Button.grid(row=6,column=0,padx=20,pady=5) + + report_issue_Button = ttk.Button(error_log_frame, text="Report Issue", command=lambda:webbrowser.open_new_tab(ISSUE_LINK)) + report_issue_Button.grid(row=7,column=0,padx=20,pady=5) + + error_log_return_Button = ttk.Button(error_log_frame, text="Back to Main Menu", command=lambda:(self.menu_error_log_close_window(), self.menu_settings())) + error_log_return_Button.grid(row=8,column=0,padx=20,pady=5) + + error_log_close_Button = ttk.Button(error_log_frame, text='Close Window', command=lambda:self.menu_error_log_close_window()) + error_log_close_Button.grid(row=9,column=0,padx=20,pady=5) + + self.menu_placement(error_log_screen, "UVR Error Log") + + def menu_secondary_model(self, tab, ai_network_vars: dict): + + #Settings Tab 1 + secondary_model_Frame = self.menu_FRAME_SET(tab) + secondary_model_Frame.grid(row=0,column=0,padx=0,pady=0) + + settings_title_Label = self.menu_title_LABEL_SET(secondary_model_Frame, "Secondary Model") + settings_title_Label.grid(row=0,column=0,padx=0,pady=15) + + voc_inst_list = self.model_list(VOCAL_STEM, INST_STEM, is_dry_check=True) + other_list = self.model_list(OTHER_STEM, NO_OTHER_STEM, is_dry_check=True) + bass_list = self.model_list(BASS_STEM, NO_BASS_STEM, is_dry_check=True) + drum_list = self.model_list(DRUM_STEM, NO_DRUM_STEM, is_dry_check=True) + + voc_inst_secondary_model_var = ai_network_vars["voc_inst_secondary_model"] + other_secondary_model_var = ai_network_vars["other_secondary_model"] + bass_secondary_model_var = ai_network_vars["bass_secondary_model"] + drums_secondary_model_var = ai_network_vars["drums_secondary_model"] + voc_inst_secondary_model_scale_var = ai_network_vars['voc_inst_secondary_model_scale'] + other_secondary_model_scale_var = ai_network_vars['other_secondary_model_scale'] + bass_secondary_model_scale_var = ai_network_vars['bass_secondary_model_scale'] + drums_secondary_model_scale_var = ai_network_vars['drums_secondary_model_scale'] + is_secondary_model_activate_var = ai_network_vars["is_secondary_model_activate"] + + change_state_lambda = lambda:change_state(NORMAL if is_secondary_model_activate_var.get() else DISABLED) + init_convert_to_percentage = lambda raw_value:f"{int(float(raw_value)*100)}%" + + voc_inst_secondary_model_scale_LABEL_var = tk.StringVar(value=init_convert_to_percentage(voc_inst_secondary_model_scale_var.get())) + other_secondary_model_scale_LABEL_var = tk.StringVar(value=init_convert_to_percentage(other_secondary_model_scale_var.get())) + bass_secondary_model_scale_LABEL_var = tk.StringVar(value=init_convert_to_percentage(bass_secondary_model_scale_var.get())) + drums_secondary_model_scale_LABEL_var = tk.StringVar(value=init_convert_to_percentage(drums_secondary_model_scale_var.get())) + + def change_state(change_state): + for child_widget in secondary_model_Frame.winfo_children(): + if type(child_widget) is ttk.OptionMenu or type(child_widget) is ttk.Scale: + child_widget.configure(state=change_state) + + def convert_to_percentage(raw_value, scale_var: tk.StringVar, label_var: tk.StringVar): + raw_value = '%0.2f' % float(raw_value) + scale_var.set(raw_value) + label_var.set(f"{int(float(raw_value)*100)}%") + + def build_widgets(stem_pair: str, model_list: list, option_var: tk.StringVar, label_var: tk.StringVar, scale_var: tk.DoubleVar, placement: tuple): + secondary_model_Label = self.menu_sub_LABEL_SET(secondary_model_Frame, f'{stem_pair}', font_size=10) + secondary_model_Label.grid(row=placement[0],column=0,padx=0,pady=5) + secondary_model_Option = ttk.OptionMenu(secondary_model_Frame, option_var, None, NO_MODEL, *model_list) + secondary_model_Option.configure(width=33) + secondary_model_Option.grid(row=placement[1],column=0,padx=0,pady=5) + secondary_scale_info_Label = tk.Label(secondary_model_Frame, textvariable=label_var, font=("Century Gothic", "8"), foreground='#13a4c9') + secondary_scale_info_Label.grid(row=placement[2],column=0,padx=0,pady=0) + secondary_model_scale_Option = ttk.Scale(secondary_model_Frame, variable=scale_var, from_=0.01, to=0.99, command=lambda s:convert_to_percentage(s, scale_var, label_var), orient='horizontal') + secondary_model_scale_Option.grid(row=placement[3],column=0,padx=0,pady=2) + self.help_hints(secondary_model_Label, text=SECONDARY_MODEL_HELP) + self.help_hints(secondary_scale_info_Label, text=SECONDARY_MODEL_SCALE_HELP) + + build_widgets(stem_pair=VOCAL_PAIR, + model_list=voc_inst_list, + option_var=voc_inst_secondary_model_var, + label_var=voc_inst_secondary_model_scale_LABEL_var, + scale_var=voc_inst_secondary_model_scale_var, + placement=VOCAL_PAIR_PLACEMENT) + + build_widgets(stem_pair=OTHER_PAIR, + model_list=other_list, + option_var=other_secondary_model_var, + label_var=other_secondary_model_scale_LABEL_var, + scale_var=other_secondary_model_scale_var, + placement=OTHER_PAIR_PLACEMENT) + + build_widgets(stem_pair=BASS_PAIR, + model_list=bass_list, + option_var=bass_secondary_model_var, + label_var=bass_secondary_model_scale_LABEL_var, + scale_var=bass_secondary_model_scale_var, + placement=BASS_PAIR_PLACEMENT) + + build_widgets(stem_pair=DRUM_PAIR, + model_list=drum_list, + option_var=drums_secondary_model_var, + label_var=drums_secondary_model_scale_LABEL_var, + scale_var=drums_secondary_model_scale_var, + placement=DRUMS_PAIR_PLACEMENT) + + is_secondary_model_activate_Option = ttk.Checkbutton(secondary_model_Frame, text='Activate Secondary Model', variable=is_secondary_model_activate_var, command=change_state_lambda) + is_secondary_model_activate_Option.grid(row=21,column=0,padx=0,pady=5) + self.help_hints(is_secondary_model_activate_Option, text=SECONDARY_MODEL_ACTIVATE_HELP) + + change_state_lambda() + + def pop_up_save_current_settings(self): + """Save current application settings as...""" + + settings_save = Toplevel(root) + + settings_save_var = tk.StringVar('') + entry_validation_header_var = tk.StringVar(value='Input Notes') + + settings_save_Frame = self.menu_FRAME_SET(settings_save) + settings_save_Frame.grid(row=1,column=0,padx=0,pady=0) + + validation = lambda value:False if re.fullmatch(REG_SAVE_INPUT, value) is None and settings_save_var.get() else True + invalid = lambda:(entry_validation_header_var.set(INVALID_ENTRY)) + save_func = lambda:(self.pop_up_save_current_settings_sub_json_dump(settings_save_var.get()), settings_save.destroy()) + + settings_save_title = self.menu_title_LABEL_SET(settings_save_Frame, "Save Current Settings") + settings_save_title.grid(row=2,column=0,padx=0,pady=0) + + settings_save_name_Label = self.menu_sub_LABEL_SET(settings_save_Frame, 'Name Settings') + settings_save_name_Label.grid(row=3,column=0,padx=0,pady=5) + settings_save_name_Entry = ttk.Entry(settings_save_Frame, textvariable=settings_save_var, justify='center', width=25) + settings_save_name_Entry.grid(row=4,column=0,padx=0,pady=5) + settings_save_name_Entry.config(validate='focus', validatecommand=(self.register(validation), '%P'), invalidcommand=(self.register(invalid),)) + settings_save_name_Entry.bind('', self.right_click_menu_popup) + self.current_text_box = settings_save_name_Entry + + entry_validation_header_Label = tk.Label(settings_save_Frame, textvariable=entry_validation_header_var, font=("Century Gothic", "8"), foreground='#868687', justify="left") + entry_validation_header_Label.grid(row=5,column=0,padx=0,pady=0) + + entry_rules_Label = tk.Label(settings_save_Frame, text=ENSEMBLE_INPUT_RULE, font=("Century Gothic", "8"), foreground='#868687', justify="left") + entry_rules_Label.grid(row=6,column=0,padx=0,pady=0) + + settings_save_Button = ttk.Button(settings_save_Frame, text="Save", command=lambda:save_func() if validation(settings_save_var.get()) else None) + settings_save_Button.grid(row=7,column=0,padx=0,pady=5) + + stop_process_Button = ttk.Button(settings_save_Frame, text="Cancel", command=lambda:settings_save.destroy()) + stop_process_Button.grid(row=8,column=0,padx=0,pady=5) + + self.menu_placement(settings_save, "Save Current Settings", pop_up=True) + + def pop_up_save_current_settings_sub_json_dump(self, settings_save_name: str): + """Dumps current application settings to a json named after user input""" + + if settings_save_name: + self.save_current_settings_var.set(settings_save_name) + settings_save_name = settings_save_name.replace(" ", "_") + current_settings = self.save_values(app_close=False) + + saved_data_dump = json.dumps(current_settings, indent=4) + with open(os.path.join(SETTINGS_CACHE_DIR, f'{settings_save_name}.json'), "w") as outfile: + outfile.write(saved_data_dump) + + def pop_up_update_confirmation(self): + """Ask user is they want to update""" + + is_new_update = self.online_data_refresh(confirmation_box=True) + + if is_new_update: + + update_confirmation_win = Toplevel() + + update_confirmation_Frame = self.menu_FRAME_SET(update_confirmation_win) + update_confirmation_Frame.grid(row=0,column=0,padx=0,pady=0) + + update_found_label = self.menu_title_LABEL_SET(update_confirmation_Frame, 'Update Found', width=15) + update_found_label.grid(row=0,column=0,padx=0,pady=10) + + confirm_update_label = self.menu_sub_LABEL_SET(update_confirmation_Frame, 'Are you sure you want to continue?\n\nThe application will need to be restarted.\n', font_size=10) + confirm_update_label.grid(row=1,column=0,padx=0,pady=5) + + yes_button = ttk.Button(update_confirmation_Frame, text='Yes', command=lambda:(self.download_item(is_update_app=True), update_confirmation_win.destroy())) + yes_button.grid(row=2,column=0,padx=0,pady=5) + + no_button = ttk.Button(update_confirmation_Frame, text='No', command=lambda:(update_confirmation_win.destroy())) + no_button.grid(row=3,column=0,padx=0,pady=5) + + self.menu_placement(update_confirmation_win, "Confirm Update", pop_up=True) + + def pop_up_user_code_input(self): + """Input VIP Code""" + + self.user_code_validation_var.set('') + + self.user_code = Toplevel() + + user_code_Frame = self.menu_FRAME_SET(self.user_code) + user_code_Frame.grid(row=0,column=0,padx=0,pady=0) + + user_code_title_Label = self.menu_title_LABEL_SET(user_code_Frame, 'User Download Codes', width=20) + user_code_title_Label.grid(row=0,column=0,padx=0,pady=5) + + user_code_Label = self.menu_sub_LABEL_SET(user_code_Frame, 'Download Code') + user_code_Label.grid(row=1,column=0,padx=0,pady=5) + + self.user_code_Entry = ttk.Entry(user_code_Frame, textvariable=self.user_code_var, justify='center') + self.user_code_Entry.grid(row=2,column=0,padx=0,pady=5) + self.user_code_Entry.bind('', self.right_click_menu_popup) + self.current_text_box = self.user_code_Entry + + validation_Label = tk.Label(user_code_Frame, textvariable=self.user_code_validation_var, font=("Century Gothic", "7"), foreground='#868687') + validation_Label.grid(row=3,column=0,padx=0,pady=0) + + user_code_confrim_Button = ttk.Button(user_code_Frame, text='Confirm', command=lambda:self.download_validate_code(confirm=True)) + user_code_confrim_Button.grid(row=4,column=0,padx=0,pady=5) + + user_code_cancel_Button = ttk.Button(user_code_Frame, text='Cancel', command=lambda:self.user_code.destroy()) + user_code_cancel_Button.grid(row=5,column=0,padx=0,pady=5) + + support_title_Label = self.menu_title_LABEL_SET(user_code_Frame, text='Support UVR', width=20) + support_title_Label.grid(row=6,column=0,padx=0,pady=5) + + support_sub_Label = tk.Label(user_code_Frame, text="Obtain codes by making a one-time donation\n via \"Buy Me a Coffee\" " +\ + "or by becoming a Patreon.\nClick one of the buttons below to donate or pledge!", + font=("Century Gothic", "8"), foreground='#13a4c9') + support_sub_Label.grid(row=7,column=0,padx=0,pady=5) + + uvr_patreon_Button = ttk.Button(user_code_Frame, text='UVR Patreon Link', command=lambda:webbrowser.open_new_tab(DONATE_LINK_PATREON)) + uvr_patreon_Button.grid(row=8,column=0,padx=0,pady=5) + + bmac_patreon_Button=ttk.Button(user_code_Frame, text='UVR \"Buy Me a Coffee\" Link', command=lambda:webbrowser.open_new_tab(DONATE_LINK_BMAC)) + bmac_patreon_Button.grid(row=9,column=0,padx=0,pady=5) + + self.menu_placement(self.user_code, "Input Code", pop_up=True) + + def pop_up_mdx_model(self, mdx_model_hash, model_path): + """Opens MDX-Net model settings""" + + model = onnx.load(model_path) + model_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in model.graph.input][0] + dim_f = model_shapes[2] + dim_t = int(math.log(model_shapes[3], 2)) + + mdx_model_set = Toplevel(root) + + mdx_n_fft_scale_set_var = tk.StringVar(value='6144') + mdx_dim_f_set_var = tk.StringVar(value=dim_f) + mdx_dim_t_set_var = tk.StringVar(value=dim_t) + primary_stem_var = tk.StringVar(value='Vocals') + mdx_compensate_var = tk.StringVar(value=1.035) + + mdx_model_set_Frame = self.menu_FRAME_SET(mdx_model_set) + mdx_model_set_Frame.grid(row=2,column=0,padx=0,pady=0) + + mdx_model_set_title = self.menu_title_LABEL_SET(mdx_model_set_Frame, "Specify MDX-Net Model Parameters") + mdx_model_set_title.grid(row=0,column=0,padx=0,pady=15) + + set_stem_name_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, 'Primary Stem') + set_stem_name_Label.grid(row=3,column=0,padx=0,pady=5) + set_stem_name_Option = ttk.OptionMenu(mdx_model_set_Frame, primary_stem_var, None, *STEM_SET_MENU) + set_stem_name_Option.configure(width=12) + set_stem_name_Option.grid(row=4,column=0,padx=0,pady=5) + self.help_hints(set_stem_name_Label, text=SET_STEM_NAME_HELP) + + mdx_dim_t_set_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, 'Dim_t') + mdx_dim_t_set_Label.grid(row=5,column=0,padx=0,pady=5) + mdx_dim_f_set_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, '(Leave this setting as is if you are unsure.)') + mdx_dim_f_set_Label.grid(row=6,column=0,padx=0,pady=5) + mdx_dim_t_set_Option = ttk.Combobox(mdx_model_set_Frame, value=('7', '8'), textvariable=mdx_dim_t_set_var) + mdx_dim_t_set_Option.configure(width=12) + mdx_dim_t_set_Option.grid(row=7,column=0,padx=0,pady=5) + self.help_hints(mdx_dim_t_set_Label, text=MDX_DIM_T_SET_HELP) + + mdx_dim_f_set_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, 'Dim_f') + mdx_dim_f_set_Label.grid(row=8,column=0,padx=0,pady=5) + mdx_dim_f_set_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, '(Leave this setting as is if you are unsure.)') + mdx_dim_f_set_Label.grid(row=9,column=0,padx=0,pady=5) + mdx_dim_f_set_Option = ttk.Combobox(mdx_model_set_Frame, value=(MDX_POP_DIMF), textvariable=mdx_dim_f_set_var) + mdx_dim_f_set_Option.configure(width=12) + mdx_dim_f_set_Option.grid(row=10,column=0,padx=0,pady=5) + self.help_hints(mdx_dim_f_set_Label, text=MDX_DIM_F_SET_HELP) + + mdx_n_fft_scale_set_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, 'N_FFT Scale') + mdx_n_fft_scale_set_Label.grid(row=11,column=0,padx=0,pady=5) + mdx_n_fft_scale_set_Option = ttk.Combobox(mdx_model_set_Frame, values=(MDX_POP_NFFT), textvariable=mdx_n_fft_scale_set_var) + mdx_n_fft_scale_set_Option.configure(width=12) + mdx_n_fft_scale_set_Option.grid(row=12,column=0,padx=0,pady=5) + self.help_hints(mdx_n_fft_scale_set_Label, text=MDX_N_FFT_SCALE_SET_HELP) + + mdx_compensate_Label = self.menu_sub_LABEL_SET(mdx_model_set_Frame, 'Volume Compensation') + mdx_compensate_Label.grid(row=13,column=0,padx=0,pady=5) + mdx_compensate_Entry = ttk.Combobox(mdx_model_set_Frame, value=('1.035', '1.08'), textvariable=mdx_compensate_var) + mdx_compensate_Entry.configure(width=14) + mdx_compensate_Entry.grid(row=15,column=0,padx=0,pady=5) + self.help_hints(mdx_compensate_Label, text=POPUP_COMPENSATE_HELP) + + mdx_param_set_Button = ttk.Button(mdx_model_set_Frame, text="Confirm", command=lambda:pull_data()) + mdx_param_set_Button.grid(row=16,column=0,padx=0,pady=10) + + stop_process_Button = ttk.Button(mdx_model_set_Frame, text="Cancel", command=lambda:cancel()) + stop_process_Button.grid(row=17,column=0,padx=0,pady=0) + + def pull_data(): + mdx_model_params = { + 'compensate': float(mdx_compensate_var.get()), + 'mdx_dim_f_set': int(mdx_dim_f_set_var.get()), + 'mdx_dim_t_set': int(mdx_dim_t_set_var.get()), + 'mdx_n_fft_scale_set': int(mdx_n_fft_scale_set_var.get()), + 'primary_stem': primary_stem_var.get() + } + + self.pop_up_mdx_model_sub_json_dump(mdx_model_params, mdx_model_hash) + mdx_model_set.destroy() + + def cancel(): + mdx_model_set.destroy() + + mdx_model_set.protocol("WM_DELETE_WINDOW", cancel) + + self.menu_placement(mdx_model_set, "Specify Parameters", pop_up=True) + + def pop_up_mdx_model_sub_json_dump(self, mdx_model_params, mdx_model_hash): + """Dumps current selected MDX-Net model settings to a json named after model hash""" + + self.mdx_model_params = mdx_model_params + + mdx_model_params_dump = json.dumps(mdx_model_params, indent=4) + with open(os.path.join(MDX_HASH_DIR, f'{mdx_model_hash}.json'), "w") as outfile: + outfile.write(mdx_model_params_dump) + + def pop_up_vr_param(self, vr_model_hash): + """Opens VR param settings""" + + vr_param_menu = Toplevel() + + get_vr_params = lambda dir, ext:tuple(os.path.splitext(x)[0] for x in os.listdir(dir) if x.endswith(ext)) + new_vr_params = get_vr_params(VR_PARAM_DIR, '.json') + vr_model_param_var = tk.StringVar(value='None Selected') + vr_model_stem_var = tk.StringVar(value='Vocals') + + def pull_data(): + vr_model_params = { + 'vr_model_param': vr_model_param_var.get(), + 'primary_stem': vr_model_stem_var.get()} + + if not vr_model_param_var.get() == 'None Selected': + self.pop_up_vr_param_sub_json_dump(vr_model_params, vr_model_hash) + vr_param_menu.destroy() + else: + self.vr_model_params = None + + def cancel(): + self.vr_model_params = None + vr_param_menu.destroy() + + vr_param_Frame = self.menu_FRAME_SET(vr_param_menu) + vr_param_Frame.grid(row=0,column=0,padx=0,pady=0) + + vr_param_title_title = self.menu_title_LABEL_SET(vr_param_Frame, "Specify VR Model Parameters", width=25) + vr_param_title_title.grid(row=0,column=0,padx=0,pady=0) + + vr_model_stem_Label = self.menu_sub_LABEL_SET(vr_param_Frame, 'Primary Stem') + vr_model_stem_Label.grid(row=1,column=0,padx=0,pady=5) + vr_model_stem_Option = ttk.OptionMenu(vr_param_Frame, vr_model_stem_var, None, *STEM_SET_MENU) + vr_model_stem_Option.grid(row=2,column=0,padx=20,pady=5) + self.help_hints(vr_model_stem_Label, text=SET_STEM_NAME_HELP) + + vr_model_param_Label = self.menu_sub_LABEL_SET(vr_param_Frame, 'Select Model Param') + vr_model_param_Label.grid(row=3,column=0,padx=0,pady=5) + vr_model_param_Option = ttk.OptionMenu(vr_param_Frame, vr_model_param_var) + vr_model_param_Option.configure(width=30) + vr_model_param_Option.grid(row=4,column=0,padx=20,pady=5) + self.help_hints(vr_model_param_Label, text=VR_MODEL_PARAM_HELP) + + vr_param_confrim_Button = ttk.Button(vr_param_Frame, text='Confirm', command=lambda:pull_data()) + vr_param_confrim_Button.grid(row=5,column=0,padx=0,pady=5) + + vr_param_cancel_Button = ttk.Button(vr_param_Frame, text='Cancel', command=cancel) + vr_param_cancel_Button.grid(row=6,column=0,padx=0,pady=5) + + for option_name in new_vr_params: + vr_model_param_Option['menu'].add_radiobutton(label=option_name, command=tk._setit(vr_model_param_var, option_name)) + + vr_param_menu.protocol("WM_DELETE_WINDOW", cancel) + + self.menu_placement(vr_param_menu, "Choose Model Param", pop_up=True) + + def pop_up_vr_param_sub_json_dump(self, vr_model_params, vr_model_hash): + """Dumps current selected VR model settings to a json named after model hash""" + + self.vr_model_params = vr_model_params + + vr_model_params_dump = json.dumps(vr_model_params, indent=4) + + with open(os.path.join(VR_HASH_DIR, f'{vr_model_hash}.json'), "w") as outfile: + outfile.write(vr_model_params_dump) + + def pop_up_save_ensemble(self): + """ + Save Ensemble as... + """ + + ensemble_save = Toplevel(root) + + ensemble_save_var = tk.StringVar('') + entry_validation_header_var = tk.StringVar(value='Input Notes') + + ensemble_save_Frame = self.menu_FRAME_SET(ensemble_save) + ensemble_save_Frame.grid(row=1,column=0,padx=0,pady=0) + + validation = lambda value:False if re.fullmatch(REG_SAVE_INPUT, value) is None and ensemble_save_var.get() else True + invalid = lambda:(entry_validation_header_var.set(INVALID_ENTRY)) + save_func = lambda:(self.pop_up_save_ensemble_sub_json_dump(self.ensemble_listbox_get_all_selected_models(), ensemble_save_var.get()), ensemble_save.destroy()) + + if len(self.ensemble_listbox_get_all_selected_models()) <= 1: + ensemble_save_title = self.menu_title_LABEL_SET(ensemble_save_Frame, "Not Enough Models", width=20) + ensemble_save_title.grid(row=1,column=0,padx=0,pady=0) + + ensemble_save_title = self.menu_sub_LABEL_SET(ensemble_save_Frame, "You must select 2 or more models to save an ensemble.") + ensemble_save_title.grid(row=2,column=0,padx=0,pady=5) + + stop_process_Button = ttk.Button(ensemble_save_Frame, text="OK", command=lambda:ensemble_save.destroy()) + stop_process_Button.grid(row=3,column=0,padx=0,pady=10) + else: + ensemble_save_title = self.menu_title_LABEL_SET(ensemble_save_Frame, "Save Current Ensemble") + ensemble_save_title.grid(row=2,column=0,padx=0,pady=0) + + ensemble_name_Label = self.menu_sub_LABEL_SET(ensemble_save_Frame, 'Ensemble Name') + ensemble_name_Label.grid(row=3,column=0,padx=0,pady=5) + ensemble_name_Entry = ttk.Entry(ensemble_save_Frame, textvariable=ensemble_save_var, justify='center', width=25) + ensemble_name_Entry.grid(row=4,column=0,padx=0,pady=5) + ensemble_name_Entry.config(validate='focus', validatecommand=(self.register(validation), '%P'), invalidcommand=(self.register(invalid),)) + + entry_validation_header_Label = tk.Label(ensemble_save_Frame, textvariable=entry_validation_header_var, font=("Century Gothic", "8"), foreground='#868687', justify="left") + entry_validation_header_Label.grid(row=5,column=0,padx=0,pady=0) + + entry_rules_Label = tk.Label(ensemble_save_Frame, text=ENSEMBLE_INPUT_RULE, font=("Century Gothic", "8"), foreground='#868687', justify="left") + entry_rules_Label.grid(row=6,column=0,padx=0,pady=0) + + mdx_param_set_Button = ttk.Button(ensemble_save_Frame, text="Save", command=lambda:save_func() if validation(ensemble_save_var.get()) else None) + mdx_param_set_Button.grid(row=7,column=0,padx=0,pady=5) + + stop_process_Button = ttk.Button(ensemble_save_Frame, text="Cancel", command=lambda:ensemble_save.destroy()) + stop_process_Button.grid(row=8,column=0,padx=0,pady=5) + + self.menu_placement(ensemble_save, "Save Current Ensemble", pop_up=True) + + def pop_up_save_ensemble_sub_json_dump(self, selected_ensemble_model, ensemble_save_name: str): + """Dumps current ensemble settings to a json named after user input""" + + if ensemble_save_name: + self.chosen_ensemble_var.set(ensemble_save_name) + ensemble_save_name = ensemble_save_name.replace(" ", "_") + saved_data = { + 'ensemble_main_stem': self.ensemble_main_stem_var.get(), + 'ensemble_type': self.ensemble_type_var.get(), + 'selected_models': selected_ensemble_model, + } + + saved_data_dump = json.dumps(saved_data, indent=4) + with open(os.path.join(ENSEMBLE_CACHE_DIR, f'{ensemble_save_name}.json'), "w") as outfile: + outfile.write(saved_data_dump) + + def deletion_list_fill(self, option_menu: ttk.OptionMenu, selection_var: tk.StringVar, selection_list, selection_dir, var_set): + """Fills the saved settings menu located in tab 2 of the main settings window""" + + option_menu['menu'].delete(0, 'end') + for selection in selection_list: + selection = selection.replace("_", " ") + option_menu['menu'].add_radiobutton(label=selection, + command=tk._setit(selection_var, + selection, + lambda s:(self.deletion_entry(s, option_menu, selection_dir), + selection_var.set(var_set)))) + + def deletion_entry(self, selection: str, option_menu, path): + """Deletes selected user saved application settings""" + + if selection not in [SELECT_SAVED_SET, SELECT_SAVED_ENSEMBLE]: + saved_path = os.path.join(path, f'{selection.replace(" ", "_")}.json') + confirm = self.message_box(DELETE_ENS_ENTRY) + + if confirm: + if os.path.isfile(saved_path): + os.remove(saved_path) + r_index=option_menu['menu'].index(selection) # index of selected option. + option_menu['menu'].delete(r_index) # deleted the option + + #--Download Center Methods-- + + def online_data_refresh(self, user_refresh=True, confirmation_box=False, refresh_list_Button=False): + """Checks for application updates""" + + def online_check(): + + self.app_update_status_Text_var.set('Loading version information...') + self.app_update_button_Text_var.set('Check for Updates') + is_new_update = False + + try: + self.online_data = json.load(urllib.request.urlopen(DOWNLOAD_CHECKS)) + self.is_online = True + self.lastest_version = self.online_data["current_version"] + + if self.lastest_version == PATCH: + self.app_update_status_Text_var.set('UVR Version Current') + else: + is_new_update = True + self.app_update_status_Text_var.set(f"Update Found: {self.lastest_version}") + self.app_update_button_Text_var.set('Click Here to Update') + self.download_update_link_var.set('{}{}.exe'.format(UPDATE_REPO, self.lastest_version)) + self.download_update_path_var.set(os.path.join(BASE_PATH, f'{self.lastest_version}.exe')) + + if not user_refresh: + self.new_update_notify(self.lastest_version) + + if user_refresh: + self.download_list_state() + self.download_list_fill() + for widget in self.download_center_Buttons:widget.configure(state=tk.NORMAL) + + if refresh_list_Button: + self.download_progress_info_var.set('Download List Refreshed!') + + self.download_model_settings() + + except Exception as e: + self.error_log_var.set(error_text('Online Data Refresh', e)) + self.offline_state_set() + is_new_update = False + + if user_refresh: + self.download_list_state(disable_only=True) + for widget in self.download_center_Buttons:widget.configure(state=tk.DISABLED) + + return is_new_update + + if confirmation_box: + return online_check() + else: + self.current_thread = KThread(target=online_check) + self.current_thread.start() + + def offline_state_set(self): + """Changes relevent settings and "Download Center" buttons if no internet connection is available""" + + self.app_update_status_Text_var.set(f'Version Status: {NO_CONNECTION}') + self.download_progress_info_var.set(NO_CONNECTION) + self.app_update_button_Text_var.set('Refresh') + self.refresh_list_Button.configure(state=tk.NORMAL) if self.refresh_list_Button else None + self.stop_download_Button_DISABLE() if self.stop_download_Button_DISABLE else None + self.enable_tabs() if self.enable_tabs else None + self.is_online = False + + def download_validate_code(self, confirm=False): + """Verifies the VIP download code""" + + self.decoded_vip_link = vip_downloads(self.user_code_var.get()) + + if confirm: + if not self.decoded_vip_link == NO_CODE: + self.download_progress_info_var.set('VIP Models Added!') + self.user_code.destroy() + else: + self.download_progress_info_var.set('Incorrect Code') + self.user_code_validation_var.set('Code Incorrect') + + self.download_list_fill() + + def download_list_fill(self): + """Fills the download lists with the data retrieved from the update check.""" + + self.download_demucs_models_list.clear() + + for list_option in self.download_lists: + list_option['menu'].delete(0, 'end') + + self.vr_download_list = self.online_data["vr_download_list"] + self.mdx_download_list = self.online_data["mdx_download_list"] + self.demucs_download_list = self.online_data["demucs_download_list"] + + if not self.decoded_vip_link is NO_CODE: + self.vr_download_list.update(self.online_data["vr_download_vip_list"]) + self.mdx_download_list.update(self.online_data["mdx_download_vip_list"]) + + for (selectable, model) in self.vr_download_list.items(): + if not os.path.isfile(os.path.join(VR_MODELS_DIR, model)): + self.model_download_vr_Option['menu'].add_radiobutton(label=selectable, command=tk._setit(self.model_download_vr_var, selectable, lambda s:self.download_model_select(s, VR_ARCH_TYPE))) + + for (selectable, model) in self.mdx_download_list.items(): + if not os.path.isfile(os.path.join(MDX_MODELS_DIR, model)): + self.model_download_mdx_Option['menu'].add_radiobutton(label=selectable, command=tk._setit(self.model_download_mdx_var, selectable, lambda s:self.download_model_select(s, MDX_ARCH_TYPE))) + + for (selectable, model) in self.demucs_download_list.items(): + for name in model.items(): + if [True for x in DEMUCS_NEWER_ARCH_TYPES if x in selectable]: + if not os.path.isfile(os.path.join(DEMUCS_NEWER_REPO_DIR, name[0])): + self.download_demucs_models_list.append(selectable) + else: + if not os.path.isfile(os.path.join(DEMUCS_MODELS_DIR, name[0])): + self.download_demucs_models_list.append(selectable) + + self.download_demucs_models_list = list(dict.fromkeys(self.download_demucs_models_list)) + + for option_name in self.download_demucs_models_list: + self.model_download_demucs_Option['menu'].add_radiobutton(label=option_name, command=tk._setit(self.model_download_demucs_var, option_name, lambda s:self.download_model_select(s, DEMUCS_ARCH_TYPE))) + + if self.model_download_vr_Option['menu'].index("end") is None: + self.model_download_vr_Option['menu'].add_radiobutton(label=NO_NEW_MODELS, command=tk._setit(self.model_download_vr_var, NO_MODEL, lambda s:self.download_model_select(s, MDX_ARCH_TYPE))) + + if self.model_download_mdx_Option['menu'].index("end") is None: + self.model_download_mdx_Option['menu'].add_radiobutton(label=NO_NEW_MODELS, command=tk._setit(self.model_download_mdx_var, NO_MODEL, lambda s:self.download_model_select(s, MDX_ARCH_TYPE))) + + if self.model_download_demucs_Option['menu'].index("end") is None: + self.model_download_demucs_Option['menu'].add_radiobutton(label=NO_NEW_MODELS, command=tk._setit(self.model_download_demucs_var, NO_MODEL, lambda s:self.download_model_select(s, DEMUCS_ARCH_TYPE))) + + def download_model_settings(self): + '''Update the newest model settings''' + + self.vr_hash_MAPPER = json.load(urllib.request.urlopen(VR_MODEL_DATA_LINK)) + self.mdx_hash_MAPPER = json.load(urllib.request.urlopen(MDX_MODEL_DATA_LINK)) + + try: + vr_hash_MAPPER_dump = json.dumps(self.vr_hash_MAPPER, indent=4) + with open(VR_HASH_JSON, "w") as outfile: + outfile.write(vr_hash_MAPPER_dump) + + mdx_hash_MAPPER_dump = json.dumps(self.mdx_hash_MAPPER, indent=4) + with open(MDX_HASH_JSON, "w") as outfile: + outfile.write(mdx_hash_MAPPER_dump) + except Exception as e: + self.error_log_var.set(e) + print(e) + + def download_list_state(self, reset=True, disable_only=False): + """Makes sure only the models from the chosen AI network are selectable.""" + + for widget in self.download_lists:widget.configure(state=tk.DISABLED) + + if reset: + for download_list_var in self.download_list_vars: + if self.is_online: + download_list_var.set(NO_MODEL) + self.download_Button.configure(state=tk.NORMAL) + else: + download_list_var.set(NO_CONNECTION) + self.download_Button.configure(state=tk.DISABLED) + + if not disable_only: + + self.download_Button.configure(state=tk.NORMAL) + if self.select_download_var.get() == VR_ARCH_TYPE: + self.model_download_vr_Option.configure(state=tk.NORMAL) + self.selected_download_var = self.model_download_vr_var + if self.select_download_var.get() == MDX_ARCH_TYPE: + self.model_download_mdx_Option.configure(state=tk.NORMAL) + self.selected_download_var = self.model_download_mdx_var + if self.select_download_var.get() == DEMUCS_ARCH_TYPE: + self.model_download_demucs_Option.configure(state=tk.NORMAL) + self.selected_download_var = self.model_download_demucs_var + + self.stop_download_Button_DISABLE() + + def download_model_select(self, selection, type): + """Prepares the data needed to download selected model.""" + + self.download_demucs_newer_models.clear() + + model_repo = self.decoded_vip_link if VIP_SELECTION in selection else NORMAL_REPO + is_demucs_newer = [True for x in DEMUCS_NEWER_ARCH_TYPES if x in selection] + + if type == VR_ARCH_TYPE: + for selected_model in self.vr_download_list.items(): + if selection in selected_model: + self.download_link_path_var.set("{}{}".format(model_repo, selected_model[1])) + self.download_save_path_var.set(os.path.join(VR_MODELS_DIR, selected_model[1])) + break + + if type == MDX_ARCH_TYPE: + for selected_model in self.mdx_download_list.items(): + if selection in selected_model: + self.download_link_path_var.set("{}{}".format(model_repo, selected_model[1])) + self.download_save_path_var.set(os.path.join(MDX_MODELS_DIR, selected_model[1])) + break + + if type == DEMUCS_ARCH_TYPE: + for selected_model, model_data in self.demucs_download_list.items(): + if selection == selected_model: + for key, value in model_data.items(): + if is_demucs_newer: + self.download_demucs_newer_models.append([os.path.join(DEMUCS_NEWER_REPO_DIR, key), value]) + else: + self.download_save_path_var.set(os.path.join(DEMUCS_MODELS_DIR, key)) + self.download_link_path_var.set(value) + + def download_item(self, is_update_app=False): + """Downloads the model selected.""" + + if not is_update_app: + if self.selected_download_var.get() == NO_MODEL: + self.download_progress_info_var.set(NO_MODEL) + return + + for widget in self.download_center_Buttons:widget.configure(state=tk.DISABLED) + self.refresh_list_Button.configure(state=tk.DISABLED) + + is_demucs_newer = [True for x in DEMUCS_NEWER_ARCH_TYPES if x in self.selected_download_var.get()] + + self.download_list_state(reset=False, disable_only=True) + self.stop_download_Button_ENABLE() + self.disable_tabs() + + def download_progress_bar(current, total, model=80): + progress = ('%s' % (100 * current // total)) + self.download_progress_bar_var.set(int(progress)) + self.download_progress_percent_var.set(progress + ' %') + + def push_download(): + self.is_download_thread_active = True + try: + if is_update_app: + self.download_progress_info_var.set(DOWNLOADING_UPDATE) + if os.path.isfile(self.download_update_path_var.get()): + self.download_progress_info_var.set(FILE_EXISTS) + else: + wget.download(self.download_update_link_var.get(), self.download_update_path_var.get(), bar=download_progress_bar) + + self.download_post_action(DOWNLOAD_UPDATE_COMPLETE) + else: + if self.select_download_var.get() == DEMUCS_ARCH_TYPE and is_demucs_newer: + for model_num, model_data in enumerate(self.download_demucs_newer_models, start=1): + self.download_progress_info_var.set('{} {}/{}...'.format(DOWNLOADING_ITEM, model_num, len(self.download_demucs_newer_models))) + if os.path.isfile(model_data[0]): + continue + else: + wget.download(model_data[1], model_data[0], bar=download_progress_bar) + else: + self.download_progress_info_var.set(SINGLE_DOWNLOAD) + if os.path.isfile(self.download_save_path_var.get()): + self.download_progress_info_var.set(FILE_EXISTS) + else: + wget.download(self.download_link_path_var.get(), self.download_save_path_var.get(), bar=download_progress_bar) + + self.download_post_action(DOWNLOAD_COMPLETE) + + except Exception as e: + self.error_log_var.set(error_text(DOWNLOADING_ITEM, e)) + self.download_progress_info_var.set(DOWNLOAD_FAILED) + + if type(e).__name__ == 'URLError': + self.offline_state_set() + else: + self.download_progress_percent_var.set(f"{type(e).__name__}") + self.download_post_action(DOWNLOAD_FAILED) + + self.active_download_thread = KThread(target=push_download) + self.active_download_thread.start() + + def download_post_action(self, action): + """Resets the widget variables in the "Download Center" based on the state of the download.""" + + for widget in self.download_center_Buttons:widget.configure(state=tk.NORMAL) + self.refresh_list_Button.configure(state=tk.NORMAL) + + self.enable_tabs() + self.stop_download_Button_DISABLE() + + if action == DOWNLOAD_FAILED: + try: + self.active_download_thread.terminate() + finally: + self.download_progress_info_var.set(DOWNLOAD_FAILED) + self.download_list_state(reset=False) + if action == DOWNLOAD_STOPPED: + try: + self.active_download_thread.terminate() + finally: + self.download_progress_info_var.set(DOWNLOAD_STOPPED) + self.download_list_state(reset=False) + if action == DOWNLOAD_COMPLETE: + self.online_data_refresh() + self.download_progress_info_var.set(DOWNLOAD_COMPLETE) + self.download_list_state() + if action == DOWNLOAD_UPDATE_COMPLETE: + self.download_progress_info_var.set(DOWNLOAD_UPDATE_COMPLETE) + if os.path.isfile(self.download_update_path_var.get()): + subprocess.Popen(self.download_update_path_var.get()) + self.download_list_state() + + self.is_download_thread_active = False + + self.delete_temps() + + #--Refresh/Loop Methods-- + + def update_loop(self): + """Update the model dropdown menus""" + + if self.clear_cache_torch: + torch.cuda.empty_cache() + self.clear_cache_torch = False + + if self.is_process_stopped: + if self.thread_check(self.active_processing_thread): + self.conversion_Button_Text_var.set(STOP_PROCESSING) + self.conversion_Button.configure(state=tk.DISABLED) + self.stop_Button.configure(state=tk.DISABLED) + else: + self.stop_Button.configure(state=tk.NORMAL) + self.conversion_Button_Text_var.set(START_PROCESSING) + self.conversion_Button.configure(state=tk.NORMAL) + self.progress_bar_main_var.set(0) + torch.cuda.empty_cache() + self.is_process_stopped = False + + if self.is_confirm_error_var.get(): + self.check_is_open_menu_error_log() + self.is_confirm_error_var.set(False) + + self.update_available_models() + self.after(600, self.update_loop) + + def update_available_models(self): + """ + Loops through all models in each model directory and adds them to the appropriate model menu. + Also updates ensemble listbox and user saved settings list. + """ + + def fix_names(file, name_mapper: dict):return tuple(new_name for (old_name, new_name) in name_mapper.items() if file in old_name) + + new_vr_models = self.get_files_from_dir(VR_MODELS_DIR, '.pth') + new_mdx_models = self.get_files_from_dir(MDX_MODELS_DIR, '.onnx') + new_demucs_models = self.get_files_from_dir(DEMUCS_MODELS_DIR, ('.ckpt', '.gz', '.th')) + self.get_files_from_dir(DEMUCS_NEWER_REPO_DIR, '.yaml') + new_ensembles_found = self.get_files_from_dir(ENSEMBLE_CACHE_DIR, '.json') + new_settings_found = self.get_files_from_dir(SETTINGS_CACHE_DIR, '.json') + new_models_found = new_vr_models + new_mdx_models + new_demucs_models + is_online = self.is_online_model_menu + + def loop_directories(option_menu, option_var, model_list, model_type, name_mapper): + + option_list = model_list + option_menu['menu'].delete(0, 'end') + + if name_mapper: + option_list = [] + for file_name in model_list: + if fix_names(file_name, name_mapper): + file = fix_names(file_name, name_mapper)[0] + else: + file = file_name + option_list.append(file) + + option_list = tuple(option_list) + + for option_name in natsort.natsorted(option_list): + option_menu['menu'].add_radiobutton(label=option_name, command=tk._setit(option_var, option_name, self.selection_action_models)) + + if self.is_online: + option_menu['menu'].insert_separator(len(model_list)) + option_menu['menu'].add_radiobutton(label=DOWNLOAD_MORE, command=tk._setit(option_var, DOWNLOAD_MORE, self.selection_action_models)) + + return tuple(f"{model_type}{ENSEMBLE_PARTITION}{model_name}" for model_name in natsort.natsorted(option_list)) + + if new_models_found != self.last_found_models or is_online != self.is_online: + self.model_data_table = [] + + vr_model_list = loop_directories(self.vr_model_Option, self.vr_model_var, new_vr_models, VR_ARCH_TYPE, name_mapper=None) + mdx_model_list = loop_directories(self.mdx_net_model_Option, self.mdx_net_model_var, new_mdx_models, MDX_ARCH_TYPE, name_mapper=MDX_NAME_SELECT) + demucs_model_list = loop_directories(self.demucs_model_Option, self.demucs_model_var, new_demucs_models, DEMUCS_ARCH_TYPE, name_mapper=DEMUCS_NAME_SELECT) + + self.ensemble_model_list = vr_model_list + mdx_model_list + demucs_model_list + self.last_found_models = new_models_found + self.is_online_model_menu = self.is_online + + if not self.chosen_ensemble_var.get() == CHOOSE_ENSEMBLE_OPTION: + self.selection_action_chosen_ensemble(self.chosen_ensemble_var.get()) + else: + if not self.ensemble_main_stem_var.get() == CHOOSE_STEM_PAIR: + self.selection_action_ensemble_stems(self.ensemble_main_stem_var.get(), auto_update=self.ensemble_listbox_get_all_selected_models()) + else: + self.ensemble_listbox_clear_and_insert_new(self.ensemble_model_list) + + if new_ensembles_found != self.last_found_ensembles: + ensemble_options = new_ensembles_found + ENSEMBLE_OPTIONS + self.chosen_ensemble_Option['menu'].delete(0, 'end') + + for saved_ensemble in ensemble_options: + saved_ensemble = saved_ensemble.replace("_", " ") + self.chosen_ensemble_Option['menu'].add_radiobutton(label=saved_ensemble, + command=tk._setit(self.chosen_ensemble_var, saved_ensemble, self.selection_action_chosen_ensemble)) + + self.chosen_ensemble_Option['menu'].insert_separator(len(new_ensembles_found)) + self.last_found_ensembles = new_ensembles_found + + if new_settings_found != self.last_found_settings: + settings_options = new_settings_found + SAVE_SET_OPTIONS + self.save_current_settings_Option['menu'].delete(0, 'end') + + for settings_options in settings_options: + settings_options = settings_options.replace("_", " ") + self.save_current_settings_Option['menu'].add_radiobutton(label=settings_options, + command=tk._setit(self.save_current_settings_var, settings_options, self.selection_action_saved_settings)) + + self.save_current_settings_Option['menu'].insert_separator(len(new_settings_found)) + self.last_found_settings = new_settings_found + + def update_main_widget_states(self): + """Updates main widget states based on chosen process method""" + + for widget in self.GUI_LIST: + widget.place_forget() + + general_shared_Buttons_place = lambda:(self.is_gpu_conversion_Option_place(), self.model_sample_mode_Option_place()) + stem_save_Options_place = lambda:(self.is_primary_stem_only_Option_place(), self.is_secondary_stem_only_Option_place()) + stem_save_demucs_Options_place = lambda:(self.is_primary_stem_only_Demucs_Option_place(), self.is_secondary_stem_only_Demucs_Option_place()) + no_ensemble_shared = lambda:(self.save_current_settings_Label_place(), self.save_current_settings_Option_place()) + + if self.chosen_process_method_var.get() == MDX_ARCH_TYPE: + self.mdx_net_model_Label_place() + self.mdx_net_model_Option_place() + self.chunks_Label_place() + self.chunks_Option_place() + self.margin_Label_place() + self.margin_Option_place() + general_shared_Buttons_place() + stem_save_Options_place() + no_ensemble_shared() + elif self.chosen_process_method_var.get() == VR_ARCH_PM: + self.vr_model_Label_place() + self.vr_model_Option_place() + self.aggression_setting_Label_place() + self.aggression_setting_Option_place() + self.window_size_Label_place() + self.window_size_Option_place() + general_shared_Buttons_place() + stem_save_Options_place() + no_ensemble_shared() + elif self.chosen_process_method_var.get() == DEMUCS_ARCH_TYPE: + self.demucs_model_Label_place() + self.demucs_model_Option_place() + self.demucs_stems_Label_place() + self.demucs_stems_Option_place() + self.segment_Label_place() + self.segment_Option_place() + general_shared_Buttons_place() + stem_save_demucs_Options_place() + no_ensemble_shared() + elif self.chosen_process_method_var.get() == AUDIO_TOOLS: + self.chosen_audio_tool_Label_place() + self.chosen_audio_tool_Option_place() + if self.chosen_audio_tool_var.get() == MANUAL_ENSEMBLE: + self.choose_algorithm_Label_place() + self.choose_algorithm_Option_place() + elif self.chosen_audio_tool_var.get() == TIME_STRETCH: + self.model_sample_mode_Option_place(rely=5) + self.time_stretch_rate_Label_place() + self.time_stretch_rate_Option_place() + elif self.chosen_audio_tool_var.get() == CHANGE_PITCH: + self.model_sample_mode_Option_place(rely=5) + self.pitch_rate_Label_place() + self.pitch_rate_Option_place() + elif self.chosen_process_method_var.get() == ENSEMBLE_MODE: + self.chosen_ensemble_Label_place() + self.chosen_ensemble_Option_place() + self.ensemble_main_stem_Label_place() + self.ensemble_main_stem_Option_place() + self.ensemble_type_Label_place() + self.ensemble_type_Option_place() + self.ensemble_listbox_Label_place() + self.ensemble_listbox_Option_place() + self.ensemble_listbox_Option_pack() + general_shared_Buttons_place() + stem_save_Options_place() + + self.is_gpu_conversion_Disable() if not self.is_gpu_available else None + + self.update_inputPaths() + + def update_button_states(self): + """Updates the available stems for selected Demucs model""" + + if self.demucs_stems_var.get() == ALL_STEMS: + self.update_stem_checkbox_labels(PRIMARY_STEM, demucs=True) + elif self.demucs_stems_var.get() == VOCAL_STEM: + self.update_stem_checkbox_labels(VOCAL_STEM, demucs=True, is_disable_demucs_boxes=False) + self.is_stem_only_Demucs_Options_Enable() + else: + self.is_stem_only_Demucs_Options_Enable() + + self.demucs_stems_Option['menu'].delete(0,'end') + + if not self.demucs_model_var.get() == CHOOSE_MODEL: + if DEMUCS_UVR_MODEL in self.demucs_model_var.get(): + stems = DEMUCS_2_STEM_OPTIONS + elif DEMUCS_6_STEM_MODEL in self.demucs_model_var.get(): + stems = DEMUCS_6_STEM_OPTIONS + else: + stems = DEMUCS_4_STEM_OPTIONS + + for stem in stems: + self.demucs_stems_Option['menu'].add_radiobutton(label=stem, + command=tk._setit(self.demucs_stems_var, stem, lambda s:self.update_stem_checkbox_labels(s, demucs=True))) + + def update_stem_checkbox_labels(self, selection, demucs=False, disable_boxes=False, is_disable_demucs_boxes=True): + """Updates the "save only" checkboxes based on the model selected""" + + stem_text = self.is_primary_stem_only_Text_var, self.is_secondary_stem_only_Text_var + + if disable_boxes: + self.is_primary_stem_only_Option.configure(state=tk.DISABLED) + self.is_secondary_stem_only_Option.configure(state=tk.DISABLED) + self.is_primary_stem_only_var.set(False) + self.is_secondary_stem_only_var.set(False) + + if demucs: + stem_text = self.is_primary_stem_only_Demucs_Text_var, self.is_secondary_stem_only_Demucs_Text_var + if is_disable_demucs_boxes: + self.is_primary_stem_only_Demucs_Option.configure(state=tk.DISABLED) + self.is_secondary_stem_only_Demucs_Option.configure(state=tk.DISABLED) + self.is_primary_stem_only_Demucs_var.set(False) + self.is_secondary_stem_only_Demucs_var.set(False) + + for primary_stem, secondary_stem in STEM_PAIR_MAPPER.items(): + if selection == primary_stem: + stem_text[0].set(f"{primary_stem} Only") + stem_text[1].set(f"{secondary_stem} Only") + + def update_ensemble_algorithm_menu(self, is_4_stem=False): + + self.ensemble_type_Option['menu'].delete(0, 'end') + options = ENSEMBLE_TYPE_4_STEM if is_4_stem else ENSEMBLE_TYPE + + if not "/" in self.ensemble_type_var.get() or is_4_stem: + self.ensemble_type_var.set(options[0]) + + for choice in options: + self.ensemble_type_Option['menu'].add_command(label=choice, command=tk._setit(self.ensemble_type_var, choice)) + + def selection_action_models(self, selection): + """Accepts model names and verifies their state""" + + if selection in DOWNLOAD_MORE: + self.update_stem_checkbox_labels(PRIMARY_STEM, disable_boxes=True) + self.menu_settings(select_tab_3=True) if not self.is_menu_settings_open else None + for method_type, model_var in self.method_mapper.items(): + if method_type == self.chosen_process_method_var.get(): + model_var.set(CHOOSE_ENSEMBLE_OPTION) if method_type in ENSEMBLE_MODE else model_var.set(CHOOSE_MODEL) + elif selection in CHOOSE_MODEL: + self.update_stem_checkbox_labels(PRIMARY_STEM, disable_boxes=True) + else: + self.is_stem_only_Options_Enable() + + for method_type, model_var in self.method_mapper.items(): + if method_type == self.chosen_process_method_var.get(): + self.selection_action_models_sub(selection, method_type, model_var) + + if self.chosen_process_method_var.get() == ENSEMBLE_MODE: + model_data = self.assemble_model_data(selection, ENSEMBLE_CHECK)[0] + if not model_data.model_status: + return self.model_stems_list.index(selection) + else: + return False + + def selection_action_models_sub(self, selection, ai_type, var: tk.StringVar): + """Takes input directly from the selection_action_models parent function""" + + model_data = self.assemble_model_data(selection, ai_type)[0] + + if not model_data.model_status: + var.set(CHOOSE_MODEL) + self.update_stem_checkbox_labels(PRIMARY_STEM, disable_boxes=True) + else: + if ai_type == DEMUCS_ARCH_TYPE: + if not self.demucs_stems_var.get().lower() in model_data.demucs_source_list: + self.demucs_stems_var.set(ALL_STEMS if model_data.demucs_stem_count == 4 else VOCAL_STEM) + else: + stem = model_data.primary_stem + self.update_stem_checkbox_labels(stem) + + def selection_action_process_method(self, selection, from_widget=False): + """Checks model and variable status when toggling between process methods""" + + if from_widget: + self.save_current_settings_var.set(CHOOSE_ENSEMBLE_OPTION) + + if selection == ENSEMBLE_MODE: + if self.ensemble_main_stem_var.get() in [CHOOSE_STEM_PAIR, FOUR_STEM_ENSEMBLE]: + self.update_stem_checkbox_labels(PRIMARY_STEM, disable_boxes=True) + else: + self.update_stem_checkbox_labels(self.return_ensemble_stems(is_primary=True)) + self.is_stem_only_Options_Enable() + else: + for method_type, model_var in self.method_mapper.items(): + if method_type in selection: + self.selection_action_models(model_var.get()) + + def selection_action_chosen_ensemble(self, selection): + """Activates specific actions depending on selected ensemble option""" + + if selection not in ENSEMBLE_OPTIONS: + self.selection_action_chosen_ensemble_load_saved(selection) + if selection == SAVE_ENSEMBLE: + self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION) + self.pop_up_save_ensemble() + if selection == MENU_SEPARATOR: + self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION) + if selection == CLEAR_ENSEMBLE: + self.ensemble_listbox_Option.selection_clear(0, 'end') + self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION) + + def selection_action_chosen_ensemble_load_saved(self, saved_ensemble): + """Loads the data from selected saved ensemble""" + + saved_data = None + saved_ensemble = saved_ensemble.replace(" ", "_") + saved_ensemble_path = os.path.join(ENSEMBLE_CACHE_DIR, f'{saved_ensemble}.json') + + if os.path.isfile(saved_ensemble_path): + saved_data = json.load(open(saved_ensemble_path)) + + if saved_data: + self.selection_action_ensemble_stems(saved_data['ensemble_main_stem'], from_menu=False) + self.ensemble_main_stem_var.set(saved_data['ensemble_main_stem']) + self.ensemble_type_var.set(saved_data['ensemble_type']) + self.saved_model_list = saved_data['selected_models'] + + for saved_model in self.saved_model_list: + status = self.assemble_model_data(saved_model, ENSEMBLE_CHECK)[0].model_status + if not status: + self.saved_model_list.remove(saved_model) + + indexes = self.ensemble_listbox_get_indexes_for_files(self.model_stems_list, self.saved_model_list) + + for i in indexes: + self.ensemble_listbox_Option.selection_set(i) + + def selection_action_ensemble_stems(self, selection: str, from_menu=True, auto_update=None): + """Filters out all models from ensemble listbox that are incompatible with selected ensemble stem""" + + if not selection == CHOOSE_STEM_PAIR: + + if selection == FOUR_STEM_ENSEMBLE: + self.update_stem_checkbox_labels(PRIMARY_STEM, disable_boxes=True) + self.update_ensemble_algorithm_menu(is_4_stem=True) + self.ensemble_primary_stem = PRIMARY_STEM + self.ensemble_secondary_stem = SECONDARY_STEM + is_4_stem_check = True + else: + self.update_ensemble_algorithm_menu() + self.is_stem_only_Options_Enable() + stems = selection.partition("/") + self.update_stem_checkbox_labels(stems[0]) + self.ensemble_primary_stem = stems[0] + self.ensemble_secondary_stem = stems[2] + is_4_stem_check = False + + self.model_stems_list = self.model_list(self.ensemble_primary_stem, self.ensemble_secondary_stem, is_4_stem_check=is_4_stem_check) + self.ensemble_listbox_Option.configure(state=tk.NORMAL) + self.ensemble_listbox_clear_and_insert_new(self.model_stems_list) + + if auto_update: + indexes = self.ensemble_listbox_get_indexes_for_files(self.model_stems_list, auto_update) + self.ensemble_listbox_select_from_indexs(indexes) + else: + self.ensemble_listbox_Option.configure(state=tk.DISABLED) + self.update_stem_checkbox_labels(PRIMARY_STEM, disable_boxes=True) + self.model_stems_list = () + + if from_menu: + self.chosen_ensemble_var.set(CHOOSE_ENSEMBLE_OPTION) + + def selection_action_saved_settings(self, selection, process_method=None): + """Activates specific action based on the selected settings from the saved settings selections""" + + if self.thread_check(self.active_processing_thread): + self.error_dialoge(SET_TO_ANY_PROCESS_ERROR) + else: + saved_data = None + chosen_process_method = self.chosen_process_method_var.get() if not process_method else process_method + + if selection not in SAVE_SET_OPTIONS: + selection = selection.replace(" ", "_") + saved_ensemble_path = os.path.join(SETTINGS_CACHE_DIR, f'{selection}.json') + + if os.path.isfile(saved_ensemble_path): + saved_data = json.load(open(saved_ensemble_path)) + + if saved_data: + self.load_saved_settings(saved_data, chosen_process_method) + + if selection == SAVE_SETTINGS: + self.save_current_settings_var.set(SELECT_SAVED_SET) + self.pop_up_save_current_settings() + + if selection == RESET_TO_DEFAULT: + self.save_current_settings_var.set(SELECT_SAVED_SET) + self.load_saved_settings(DEFAULT_DATA, chosen_process_method) + + self.update_checkbox_text() + + #--Processing Methods-- + + def process_input_selections(self): + """Grabbing all audio files from selected directories.""" + + input_list = [] + + ext = FFMPEG_EXT if not self.is_accept_any_input_var.get() else ANY_EXT + + for i in self.inputPaths: + if os.path.isfile(i): + if i.endswith(ext): + input_list.append(i) + for root, dirs, files in os.walk(i): + for file in files: + if file.endswith(ext): + file = os.path.join(root, file) + if os.path.isfile(file): + input_list.append(file) + + self.inputPaths = tuple(input_list) + + def process_preliminary_checks(self): + """Verifies a valid model is chosen""" + + if self.wav_type_set_var.get() == '32-bit Float': + self.wav_type_set = 'FLOAT' + elif self.wav_type_set_var.get() == '64-bit Float': + self.wav_type_set = 'FLOAT' if not self.save_format_var.get() == WAV else 'DOUBLE' + else: + self.wav_type_set = self.wav_type_set_var.get() + + if self.chosen_process_method_var.get() == ENSEMBLE_MODE: + continue_process = lambda:False if len(self.ensemble_listbox_get_all_selected_models()) <= 1 else True + if self.chosen_process_method_var.get() == VR_ARCH_PM: + continue_process = lambda:False if self.vr_model_var.get() == CHOOSE_MODEL else True + if self.chosen_process_method_var.get() == MDX_ARCH_TYPE: + continue_process = lambda:False if self.mdx_net_model_var.get() == CHOOSE_MODEL else True + if self.chosen_process_method_var.get() == DEMUCS_ARCH_TYPE: + continue_process = lambda:False if self.demucs_model_var.get() == CHOOSE_MODEL else True + + return continue_process() + + def process_storage_check(self): + """Verifies storage requirments""" + + total, used, free = shutil.disk_usage("/") + + space_details = f'Detected Total Space: {int(total/1.074e+9)} GB\'s\n' +\ + f'Detected Used Space: {int(used/1.074e+9)} GB\'s\n' +\ + f'Detected Free Space: {int(free/1.074e+9)} GB\'s\n' + + appropriate_storage = True + + if int(free/1.074e+9) <= int(2): + self.error_dialoge([STORAGE_ERROR[0], f'{STORAGE_ERROR[1]}{space_details}']) + appropriate_storage = False + + if int(free/1.074e+9) in [3, 4, 5, 6, 7, 8]: + appropriate_storage = self.message_box([STORAGE_WARNING[0], f'{STORAGE_WARNING[1]}{space_details}{CONFIRM_WARNING}']) + + return appropriate_storage + + def process_initialize(self): + """Verifies the input/output directories are valid and prepares to thread the main process.""" + + if self.inputPaths: + if not os.path.isfile(self.inputPaths[0]): + self.error_dialoge(INVALID_INPUT) + return + else: + self.error_dialoge(INVALID_INPUT) + return + + if not os.path.isdir(self.export_path_var.get()): + self.error_dialoge(INVALID_EXPORT) + return + + if not self.process_storage_check(): + return + + if not self.chosen_process_method_var.get() == AUDIO_TOOLS: + if not self.process_preliminary_checks(): + self.error_dialoge(INVALID_ENSEMBLE if self.chosen_process_method_var.get() == ENSEMBLE_MODE else INVALID_MODEL) + return + + self.active_processing_thread = KThread(target=self.process_start) + self.active_processing_thread.start() + else: + self.active_processing_thread = KThread(target=self.process_tool_start) + self.active_processing_thread.start() + + def process_button_init(self): + self.command_Text.clear() + self.conversion_Button_Text_var.set(WAIT_PROCESSING) + self.conversion_Button.configure(state=tk.DISABLED) + + def process_get_baseText(self, total_files, file_num): + """Create the base text for the command widget""" + + text = 'File {file_num}/{total_files} '.format(file_num=file_num, + total_files=total_files) + + return text + + def process_update_progress(self, model_count, total_files, step: float = 1): + """Calculate the progress for the progress widget in the GUI""" + + total_count = model_count * total_files + base = (100 / total_count) + progress = base * self.iteration - base + progress += base * step + + self.progress_bar_main_var.set(progress) + + self.conversion_Button_Text_var.set(f'Process Progress: {int(progress)}%') + + def confirm_stop_process(self): + """Asks for confirmation before halting active process""" + + if self.thread_check(self.active_processing_thread): + confirm = tk.messagebox.askyesno(title=STOP_PROCESS_CONFIRM[0], + message=STOP_PROCESS_CONFIRM[1]) + + if confirm: + try: + self.active_processing_thread.terminate() + finally: + self.is_process_stopped = True + self.command_Text.write('\n\nProcess stopped by user.') + else: + self.clear_cache_torch = True + + def process_end(self, error=None): + """End of process actions""" + + self.cached_sources_clear() + self.clear_cache_torch = True + self.conversion_Button_Text_var.set(START_PROCESSING) + self.conversion_Button.configure(state=tk.NORMAL) + self.progress_bar_main_var.set(0) + + if error: + error_message_box_text = f'{error_dialouge(error)}{ERROR_OCCURED[1]}' + confirm = tk.messagebox.askyesno(master=self, + title=ERROR_OCCURED[0], + message=error_message_box_text) + + if confirm: + self.is_confirm_error_var.set(True) + self.clear_cache_torch = True + + self.clear_cache_torch = True + + if MODEL_MISSING_CHECK in error_message_box_text: + self.update_checkbox_text() + + def process_tool_start(self): + """Start the conversion for all the given mp3 and wav files""" + + multiple_files = False + stime = time.perf_counter() + time_elapsed = lambda:f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}' + self.process_button_init() + inputPaths = self.inputPaths + is_verified_audio = True + is_model_sample_mode = self.model_sample_mode_var.get() + + try: + total_files = len(inputPaths) + + if self.chosen_audio_tool_var.get() == TIME_STRETCH: + audio_tool = AudioTools(TIME_STRETCH) + self.progress_bar_main_var.set(2) + if self.chosen_audio_tool_var.get() == CHANGE_PITCH: + audio_tool = AudioTools(CHANGE_PITCH) + self.progress_bar_main_var.set(2) + if self.chosen_audio_tool_var.get() == MANUAL_ENSEMBLE: + audio_tool = Ensembler(is_manual_ensemble=True) + multiple_files = True + if total_files <= 1: + self.command_Text.write("Not enough files to process.\n") + self.process_end() + return + if self.chosen_audio_tool_var.get() == ALIGN_INPUTS: + multiple_files = True + audio_tool = AudioTools(ALIGN_INPUTS) + if not total_files == 2: + self.command_Text.write("You must select exactly 2 inputs!\n") + self.process_end() + return + + for file_num, audio_file in enumerate(inputPaths, start=1): + + base = (100 / total_files) + + if audio_tool.audio_tool in [MANUAL_ENSEMBLE, ALIGN_INPUTS]: + audio_file_base = f'{os.path.splitext(os.path.basename(inputPaths[0]))[0]}' + + else: + audio_file_base = f'{os.path.splitext(os.path.basename(audio_file))[0]}' + + self.base_text = self.process_get_baseText(total_files=total_files, file_num=total_files if multiple_files else file_num) + command_Text = lambda text:self.command_Text.write(self.base_text + text) + + if self.verify_audio(audio_file): + if not audio_tool.audio_tool in [MANUAL_ENSEMBLE, ALIGN_INPUTS]: + audio_file = self.create_sample(audio_file) if is_model_sample_mode else audio_file + self.command_Text.write(f'{NEW_LINE if not file_num ==1 else NO_LINE}{self.base_text}"{os.path.basename(audio_file)}\".{NEW_LINES}') + elif audio_tool.audio_tool == ALIGN_INPUTS: + self.command_Text.write('File 1 "{}"{}'.format(os.path.basename(inputPaths[0]), NEW_LINE)) + self.command_Text.write('File 2 "{}"{}'.format(os.path.basename(inputPaths[1]), NEW_LINES)) + elif audio_tool.audio_tool == MANUAL_ENSEMBLE: + for n, i in enumerate(inputPaths): + self.command_Text.write('File {} "{}"{}'.format(n+1, os.path.basename(i), NEW_LINE)) + self.command_Text.write(NEW_LINE) + + is_verified_audio = True + else: + error_text_console = f'{self.base_text}"{os.path.basename(audio_file)}\" is missing or currupted.\n' + self.command_Text.write(f'\n{error_text_console}') if total_files >= 2 else None + is_verified_audio = False + continue + + command_Text('Process starting... ') if not audio_tool.audio_tool == ALIGN_INPUTS else None + + if audio_tool.audio_tool == MANUAL_ENSEMBLE: + self.progress_bar_main_var.set(50) + audio_tool.ensemble_manual(inputPaths, audio_file_base) + self.progress_bar_main_var.set(100) + self.command_Text.write(DONE) + break + if audio_tool.audio_tool == ALIGN_INPUTS: + command_Text('Process starting... \n') + audio_file_2_base = f'{os.path.splitext(os.path.basename(inputPaths[1]))[0]}' + audio_tool.align_inputs(inputPaths, audio_file_base, audio_file_2_base, command_Text) + self.command_Text.write(DONE) + break + if audio_tool.audio_tool in [TIME_STRETCH, CHANGE_PITCH]: + audio_tool.pitch_or_time_shift(audio_file, audio_file_base) + self.progress_bar_main_var.set(base*file_num) + self.command_Text.write(DONE) + + if total_files == 1 and not is_verified_audio: + self.command_Text.write(f'{error_text_console}\n{PROCESS_FAILED}') + self.command_Text.write(time_elapsed()) + playsound(FAIL_CHIME) if self.is_task_complete_var.get() else None + else: + self.command_Text.write('\nProcess complete\n{}'.format(time_elapsed())) + playsound(COMPLETE_CHIME) if self.is_task_complete_var.get() else None + + self.process_end() + + except Exception as e: + self.error_log_var.set(error_text(self.chosen_audio_tool_var.get(), e)) + self.command_Text.write(f'\n\n{PROCESS_FAILED}') + self.command_Text.write(time_elapsed()) + playsound(FAIL_CHIME) if self.is_task_complete_var.get() else None + self.process_end(error=e) + + def process_determine_secondary_model(self, process_method, main_model_primary_stem, is_primary_stem_only=False, is_secondary_stem_only=False): + """Obtains the correct secondary model data for conversion.""" + + secondary_model_scale = None + secondary_model = StringVar(value=NO_MODEL) + + if process_method == VR_ARCH_TYPE: + secondary_model_vars = self.vr_secondary_model_vars + if process_method == MDX_ARCH_TYPE: + secondary_model_vars = self.mdx_secondary_model_vars + if process_method == DEMUCS_ARCH_TYPE: + secondary_model_vars = self.demucs_secondary_model_vars + + if main_model_primary_stem in [VOCAL_STEM, INST_STEM]: + secondary_model = secondary_model_vars["voc_inst_secondary_model"] + secondary_model_scale = secondary_model_vars["voc_inst_secondary_model_scale"].get() + if main_model_primary_stem in [OTHER_STEM, NO_OTHER_STEM]: + secondary_model = secondary_model_vars["other_secondary_model"] + secondary_model_scale = secondary_model_vars["other_secondary_model_scale"].get() + if main_model_primary_stem in [DRUM_STEM, NO_DRUM_STEM]: + secondary_model = secondary_model_vars["drums_secondary_model"] + secondary_model_scale = secondary_model_vars["drums_secondary_model_scale"].get() + if main_model_primary_stem in [BASS_STEM, NO_BASS_STEM]: + secondary_model = secondary_model_vars["bass_secondary_model"] + secondary_model_scale = secondary_model_vars["bass_secondary_model_scale"].get() + + if secondary_model_scale: + secondary_model_scale = float(secondary_model_scale) + + if not secondary_model.get() == NO_MODEL: + secondary_model = ModelData(secondary_model.get(), + is_secondary_model=True, + primary_model_primary_stem=main_model_primary_stem, + is_primary_model_primary_stem_only=is_primary_stem_only, + is_primary_model_secondary_stem_only=is_secondary_stem_only) + if not secondary_model.model_status: + secondary_model = None + else: + secondary_model = None + + return secondary_model, secondary_model_scale + + def process_determine_demucs_pre_proc_model(self, primary_stem=None): + """Obtains the correct secondary model data for conversion.""" + + pre_proc_model = None + + if not self.demucs_pre_proc_model_var.get() == NO_MODEL and self.is_demucs_pre_proc_model_activate_var.get(): + pre_proc_model = ModelData(self.demucs_pre_proc_model_var.get(), + primary_model_primary_stem=primary_stem, + is_pre_proc_model=True) + if not pre_proc_model.model_status: + pre_proc_model = None + else: + pre_proc_model = None + + return pre_proc_model + + def process_start(self): + """Start the conversion for all the given mp3 and wav files""" + + stime = time.perf_counter() + time_elapsed = lambda:f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}' + export_path = self.export_path_var.get() + is_ensemble = False + true_model_count = 0 + self.iteration = 0 + is_verified_audio = True + self.process_button_init() + inputPaths = self.inputPaths + inputPath_total_len = len(inputPaths) + is_model_sample_mode = self.model_sample_mode_var.get() + + try: + if self.chosen_process_method_var.get() == ENSEMBLE_MODE: + model, ensemble = self.assemble_model_data(), Ensembler() + export_path, is_ensemble = ensemble.ensemble_folder_name, True + if self.chosen_process_method_var.get() == VR_ARCH_PM: + model = self.assemble_model_data(self.vr_model_var.get(), VR_ARCH_TYPE) + if self.chosen_process_method_var.get() == MDX_ARCH_TYPE: + model = self.assemble_model_data(self.mdx_net_model_var.get(), MDX_ARCH_TYPE) + if self.chosen_process_method_var.get() == DEMUCS_ARCH_TYPE: + model = self.assemble_model_data(self.demucs_model_var.get(), DEMUCS_ARCH_TYPE) + + self.cached_source_model_list_check(model) + + true_model_4_stem_count = sum(m.demucs_4_stem_added_count if m.process_method == DEMUCS_ARCH_TYPE else 0 for m in model) + true_model_pre_proc_model_count = sum(2 if m.pre_proc_model_activated else 0 for m in model) + true_model_count = sum(2 if m.is_secondary_model_activated else 1 for m in model) + true_model_4_stem_count + true_model_pre_proc_model_count + + for file_num, audio_file in enumerate(inputPaths, start=1): + self.cached_sources_clear() + base_text = self.process_get_baseText(total_files=inputPath_total_len, file_num=file_num) + + if self.verify_audio(audio_file): + audio_file = self.create_sample(audio_file) if is_model_sample_mode else audio_file + self.command_Text.write(f'{NEW_LINE if not file_num ==1 else NO_LINE}{base_text}"{os.path.basename(audio_file)}\".{NEW_LINES}') + is_verified_audio = True + else: + error_text_console = f'{base_text}"{os.path.basename(audio_file)}\" is missing or currupted.\n' + self.command_Text.write(f'\n{error_text_console}') if inputPath_total_len >= 2 else None + self.iteration += true_model_count + is_verified_audio = False + continue + + for current_model_num, current_model in enumerate(model, start=1): + self.iteration += 1 + + if is_ensemble: + self.command_Text.write(f'Ensemble Mode - {current_model.model_basename} - Model {current_model_num}/{len(model)}{NEW_LINES}') + + model_name_text = f'({current_model.model_basename})' if not is_ensemble else '' + self.command_Text.write(base_text + f'Loading model {model_name_text}...') + + progress_kwargs = {'model_count': true_model_count, + 'total_files': inputPath_total_len} + + set_progress_bar = lambda step, inference_iterations=0:self.process_update_progress(**progress_kwargs, step=(step + (inference_iterations))) + write_to_console = lambda progress_text, base_text=base_text:self.command_Text.write(base_text + progress_text) + + audio_file_base = f"{file_num}_{os.path.splitext(os.path.basename(audio_file))[0]}" + audio_file_base = audio_file_base if not self.is_testing_audio_var.get() or is_ensemble else f"{round(time.time())}_{audio_file_base}" + audio_file_base = audio_file_base if not is_ensemble else f"{audio_file_base}_{current_model.model_basename}" + audio_file_base = audio_file_base if not self.is_add_model_name_var.get() else f"{audio_file_base}_{current_model.model_basename}" + + if self.is_create_model_folder_var.get() and not is_ensemble: + export_path = os.path.join(Path(self.export_path_var.get()), current_model.model_basename, os.path.splitext(os.path.basename(audio_file))[0]) + if not os.path.isdir(export_path):os.makedirs(export_path) + + process_data = { + 'model_data': current_model, + 'export_path': export_path, + 'audio_file_base': audio_file_base, + 'audio_file': audio_file, + 'set_progress_bar': set_progress_bar, + 'write_to_console': write_to_console, + 'process_iteration': self.process_iteration, + 'cached_source_callback': self.cached_source_callback, + 'cached_model_source_holder': self.cached_model_source_holder, + 'list_all_models': self.all_models, + 'is_ensemble_master': is_ensemble, + 'is_4_stem_ensemble': True if self.ensemble_main_stem_var.get() == FOUR_STEM_ENSEMBLE and is_ensemble else False} + + if current_model.process_method == VR_ARCH_TYPE: + seperator = SeperateVR(current_model, process_data) + if current_model.process_method == MDX_ARCH_TYPE: + seperator = SeperateMDX(current_model, process_data) + if current_model.process_method == DEMUCS_ARCH_TYPE: + seperator = SeperateDemucs(current_model, process_data) + + seperator.seperate() + + if is_ensemble: + self.command_Text.write('\n') + + if is_ensemble: + + audio_file_base = audio_file_base.replace(f"_{current_model.model_basename}","") + self.command_Text.write(base_text + ENSEMBLING_OUTPUTS) + + if self.ensemble_main_stem_var.get() == FOUR_STEM_ENSEMBLE: + for output_stem in DEMUCS_4_SOURCE_LIST: + ensemble.ensemble_outputs(audio_file_base, export_path, output_stem, is_4_stem=True) + else: + if not self.is_secondary_stem_only_var.get(): + ensemble.ensemble_outputs(audio_file_base, export_path, PRIMARY_STEM) + if not self.is_primary_stem_only_var.get(): + ensemble.ensemble_outputs(audio_file_base, export_path, SECONDARY_STEM) + ensemble.ensemble_outputs(audio_file_base, export_path, SECONDARY_STEM, is_inst_mix=True) + + self.command_Text.write(DONE) + + if is_model_sample_mode: + if os.path.isfile(audio_file): + os.remove(audio_file) + + torch.cuda.empty_cache() + + shutil.rmtree(export_path) if is_ensemble and len(os.listdir(export_path)) == 0 else None + + if inputPath_total_len == 1 and not is_verified_audio: + self.command_Text.write(f'{error_text_console}\n{PROCESS_FAILED}') + self.command_Text.write(time_elapsed()) + playsound(FAIL_CHIME) if self.is_task_complete_var.get() else None + else: + set_progress_bar(1.0) + self.command_Text.write('\nProcess Complete\n') + self.command_Text.write(time_elapsed()) + playsound(COMPLETE_CHIME) if self.is_task_complete_var.get() else None + + self.process_end() + + except Exception as e: + self.error_log_var.set(error_text(self.chosen_process_method_var.get(), e)) + self.command_Text.write(f'\n\n{PROCESS_FAILED}') + self.command_Text.write(time_elapsed()) + playsound(FAIL_CHIME) if self.is_task_complete_var.get() else None + self.process_end(error=e) + + #--Varible Methods-- + + def load_to_default_confirm(self): + """Reset settings confirmation after asking for confirmation""" + + if self.thread_check(self.active_processing_thread): + self.error_dialoge(SET_TO_DEFAULT_PROCESS_ERROR) + else: + confirm = tk.messagebox.askyesno(title=RESET_ALL_TO_DEFAULT_WARNING[0], + message=RESET_ALL_TO_DEFAULT_WARNING[1]) + + if confirm: + self.load_saved_settings(DEFAULT_DATA) + + def load_saved_vars(self, data): + """Initializes primary Tkinter vars""" + + ## ADD_BUTTON + self.chosen_process_method_var = tk.StringVar(value=data['chosen_process_method']) + + #VR Architecture Vars + self.vr_model_var = tk.StringVar(value=data['vr_model']) + self.aggression_setting_var = tk.StringVar(value=data['aggression_setting']) + self.window_size_var = tk.StringVar(value=data['window_size']) + self.batch_size_var = tk.StringVar(value=data['batch_size']) + self.crop_size_var = tk.StringVar(value=data['crop_size']) + self.is_tta_var = tk.BooleanVar(value=data['is_tta']) + self.is_output_image_var = tk.BooleanVar(value=data['is_output_image']) + self.is_post_process_var = tk.BooleanVar(value=data['is_post_process']) + self.is_high_end_process_var = tk.BooleanVar(value=data['is_high_end_process']) + self.vr_voc_inst_secondary_model_var = tk.StringVar(value=data['vr_voc_inst_secondary_model']) + self.vr_other_secondary_model_var = tk.StringVar(value=data['vr_other_secondary_model']) + self.vr_bass_secondary_model_var = tk.StringVar(value=data['vr_bass_secondary_model']) + self.vr_drums_secondary_model_var = tk.StringVar(value=data['vr_drums_secondary_model']) + self.vr_is_secondary_model_activate_var = tk.BooleanVar(value=data['vr_is_secondary_model_activate']) + self.vr_voc_inst_secondary_model_scale_var = tk.StringVar(value=data['vr_voc_inst_secondary_model_scale']) + self.vr_other_secondary_model_scale_var = tk.StringVar(value=data['vr_other_secondary_model_scale']) + self.vr_bass_secondary_model_scale_var = tk.StringVar(value=data['vr_bass_secondary_model_scale']) + self.vr_drums_secondary_model_scale_var = tk.StringVar(value=data['vr_drums_secondary_model_scale']) + + #Demucs Vars + self.demucs_model_var = tk.StringVar(value=data['demucs_model']) + self.segment_var = tk.StringVar(value=data['segment']) + self.overlap_var = tk.StringVar(value=data['overlap']) + self.shifts_var = tk.StringVar(value=data['shifts']) + self.chunks_demucs_var = tk.StringVar(value=data['chunks_demucs']) + self.margin_demucs_var = tk.StringVar(value=data['margin_demucs']) + self.is_chunk_demucs_var = tk.BooleanVar(value=data['is_chunk_demucs']) + self.is_primary_stem_only_Demucs_var = tk.BooleanVar(value=data['is_primary_stem_only_Demucs']) + self.is_secondary_stem_only_Demucs_var = tk.BooleanVar(value=data['is_secondary_stem_only_Demucs']) + self.is_split_mode_var = tk.BooleanVar(value=data['is_split_mode']) + self.is_demucs_combine_stems_var = tk.BooleanVar(value=data['is_demucs_combine_stems']) + self.demucs_voc_inst_secondary_model_var = tk.StringVar(value=data['demucs_voc_inst_secondary_model']) + self.demucs_other_secondary_model_var = tk.StringVar(value=data['demucs_other_secondary_model']) + self.demucs_bass_secondary_model_var = tk.StringVar(value=data['demucs_bass_secondary_model']) + self.demucs_drums_secondary_model_var = tk.StringVar(value=data['demucs_drums_secondary_model']) + self.demucs_is_secondary_model_activate_var = tk.BooleanVar(value=data['demucs_is_secondary_model_activate']) + self.demucs_voc_inst_secondary_model_scale_var = tk.StringVar(value=data['demucs_voc_inst_secondary_model_scale']) + self.demucs_other_secondary_model_scale_var = tk.StringVar(value=data['demucs_other_secondary_model_scale']) + self.demucs_bass_secondary_model_scale_var = tk.StringVar(value=data['demucs_bass_secondary_model_scale']) + self.demucs_drums_secondary_model_scale_var = tk.StringVar(value=data['demucs_drums_secondary_model_scale']) + self.demucs_pre_proc_model_var = tk.StringVar(value=data['demucs_pre_proc_model']) + self.is_demucs_pre_proc_model_activate_var = tk.BooleanVar(value=data['is_demucs_pre_proc_model_activate']) + self.is_demucs_pre_proc_model_inst_mix_var = tk.BooleanVar(value=data['is_demucs_pre_proc_model_inst_mix']) + + #MDX-Net Vars + self.mdx_net_model_var = tk.StringVar(value=data['mdx_net_model']) + self.chunks_var = tk.StringVar(value=data['chunks']) + self.margin_var = tk.StringVar(value=data['margin']) + self.compensate_var = tk.StringVar(value=data['compensate']) + self.is_denoise_var = tk.BooleanVar(value=data['is_denoise']) + self.is_invert_spec_var = tk.BooleanVar(value=data['is_invert_spec']) + self.mdx_voc_inst_secondary_model_var = tk.StringVar(value=data['mdx_voc_inst_secondary_model']) + self.mdx_other_secondary_model_var = tk.StringVar(value=data['mdx_other_secondary_model']) + self.mdx_bass_secondary_model_var = tk.StringVar(value=data['mdx_bass_secondary_model']) + self.mdx_drums_secondary_model_var = tk.StringVar(value=data['mdx_drums_secondary_model']) + self.mdx_is_secondary_model_activate_var = tk.BooleanVar(value=data['mdx_is_secondary_model_activate']) + self.mdx_voc_inst_secondary_model_scale_var = tk.StringVar(value=data['mdx_voc_inst_secondary_model_scale']) + self.mdx_other_secondary_model_scale_var = tk.StringVar(value=data['mdx_other_secondary_model_scale']) + self.mdx_bass_secondary_model_scale_var = tk.StringVar(value=data['mdx_bass_secondary_model_scale']) + self.mdx_drums_secondary_model_scale_var = tk.StringVar(value=data['mdx_drums_secondary_model_scale']) + + #Ensemble Vars + self.is_save_all_outputs_ensemble_var = tk.BooleanVar(value=data['is_save_all_outputs_ensemble']) + self.is_append_ensemble_name_var = tk.BooleanVar(value=data['is_append_ensemble_name']) + + #Audio Tool Vars + self.chosen_audio_tool_var = tk.StringVar(value=data['chosen_audio_tool']) + self.choose_algorithm_var = tk.StringVar(value=data['choose_algorithm']) + self.time_stretch_rate_var = tk.StringVar(value=data['time_stretch_rate']) + self.pitch_rate_var = tk.StringVar(value=data['pitch_rate']) + + #Shared Vars + self.mp3_bit_set_var = tk.StringVar(value=data['mp3_bit_set']) + self.save_format_var = tk.StringVar(value=data['save_format']) + self.wav_type_set_var = tk.StringVar(value=data['wav_type_set']) + self.user_code_var = tk.StringVar(value=data['user_code']) + self.is_gpu_conversion_var = tk.BooleanVar(value=data['is_gpu_conversion']) + self.is_primary_stem_only_var = tk.BooleanVar(value=data['is_primary_stem_only']) + self.is_secondary_stem_only_var = tk.BooleanVar(value=data['is_secondary_stem_only']) + self.is_testing_audio_var = tk.BooleanVar(value=data['is_testing_audio']) + self.is_add_model_name_var = tk.BooleanVar(value=data['is_add_model_name']) + self.is_accept_any_input_var = tk.BooleanVar(value=data['is_accept_any_input']) + self.is_task_complete_var = tk.BooleanVar(value=data['is_task_complete']) + self.is_normalization_var = tk.BooleanVar(value=data['is_normalization']) + self.is_create_model_folder_var = tk.BooleanVar(value=data['is_create_model_folder']) + self.help_hints_var = tk.BooleanVar(value=data['help_hints_var']) + self.model_sample_mode_var = tk.BooleanVar(value=data['model_sample_mode']) + self.model_sample_mode_duration_var = tk.StringVar(value=data['model_sample_mode_duration']) + self.model_sample_mode_duration_checkbox_var = tk.StringVar(value=SAMPLE_MODE_CHECKBOX(self.model_sample_mode_duration_var.get())) + + #Path Vars + self.export_path_var = tk.StringVar(value=data['export_path']) + self.inputPaths = data['input_paths'] + self.lastDir = data['lastDir'] + + def load_saved_settings(self, loaded_setting: dict, process_method=None): + """Loads user saved application settings or resets to default""" + + if not process_method or process_method == VR_ARCH_PM: + self.vr_model_var.set(loaded_setting['vr_model']) + self.aggression_setting_var.set(loaded_setting['aggression_setting']) + self.window_size_var.set(loaded_setting['window_size']) + self.batch_size_var.set(loaded_setting['batch_size']) + self.crop_size_var.set(loaded_setting['crop_size']) + self.is_tta_var.set(loaded_setting['is_tta']) + self.is_output_image_var.set(loaded_setting['is_output_image']) + self.is_post_process_var.set(loaded_setting['is_post_process']) + self.is_high_end_process_var.set(loaded_setting['is_high_end_process']) + self.vr_voc_inst_secondary_model_var.set(loaded_setting['vr_voc_inst_secondary_model']) + self.vr_other_secondary_model_var.set(loaded_setting['vr_other_secondary_model']) + self.vr_bass_secondary_model_var.set(loaded_setting['vr_bass_secondary_model']) + self.vr_drums_secondary_model_var.set(loaded_setting['vr_drums_secondary_model']) + self.vr_is_secondary_model_activate_var.set(loaded_setting['vr_is_secondary_model_activate']) + self.vr_voc_inst_secondary_model_scale_var.set(loaded_setting['vr_voc_inst_secondary_model_scale']) + self.vr_other_secondary_model_scale_var.set(loaded_setting['vr_other_secondary_model_scale']) + self.vr_bass_secondary_model_scale_var.set(loaded_setting['vr_bass_secondary_model_scale']) + self.vr_drums_secondary_model_scale_var.set(loaded_setting['vr_drums_secondary_model_scale']) + + if not process_method or process_method == DEMUCS_ARCH_TYPE: + self.demucs_model_var.set(loaded_setting['demucs_model']) + self.segment_var.set(loaded_setting['segment']) + self.overlap_var.set(loaded_setting['overlap']) + self.shifts_var.set(loaded_setting['shifts']) + self.chunks_demucs_var.set(loaded_setting['chunks_demucs']) + self.margin_demucs_var.set(loaded_setting['margin_demucs']) + self.is_chunk_demucs_var.set(loaded_setting['is_chunk_demucs']) + self.is_primary_stem_only_Demucs_var.set(loaded_setting['is_primary_stem_only_Demucs']) + self.is_secondary_stem_only_Demucs_var.set(loaded_setting['is_secondary_stem_only_Demucs']) + self.is_split_mode_var.set(loaded_setting['is_split_mode']) + self.is_demucs_combine_stems_var.set(loaded_setting['is_demucs_combine_stems']) + self.demucs_voc_inst_secondary_model_var.set(loaded_setting['demucs_voc_inst_secondary_model']) + self.demucs_other_secondary_model_var.set(loaded_setting['demucs_other_secondary_model']) + self.demucs_bass_secondary_model_var.set(loaded_setting['demucs_bass_secondary_model']) + self.demucs_drums_secondary_model_var.set(loaded_setting['demucs_drums_secondary_model']) + self.demucs_is_secondary_model_activate_var.set(loaded_setting['demucs_is_secondary_model_activate']) + self.demucs_voc_inst_secondary_model_scale_var.set(loaded_setting['demucs_voc_inst_secondary_model_scale']) + self.demucs_other_secondary_model_scale_var.set(loaded_setting['demucs_other_secondary_model_scale']) + self.demucs_bass_secondary_model_scale_var.set(loaded_setting['demucs_bass_secondary_model_scale']) + self.demucs_drums_secondary_model_scale_var.set(loaded_setting['demucs_drums_secondary_model_scale']) + self.demucs_stems_var.set(loaded_setting['demucs_stems']) + self.update_stem_checkbox_labels(self.demucs_stems_var.get(), demucs=True) + self.demucs_pre_proc_model_var.set(data['demucs_pre_proc_model']) + self.is_demucs_pre_proc_model_activate_var.set(data['is_demucs_pre_proc_model_activate']) + self.is_demucs_pre_proc_model_inst_mix_var.set(data['is_demucs_pre_proc_model_inst_mix']) + + if not process_method or process_method == MDX_ARCH_TYPE: + self.mdx_net_model_var.set(loaded_setting['mdx_net_model']) + self.chunks_var.set(loaded_setting['chunks']) + self.margin_var.set(loaded_setting['margin']) + self.compensate_var.set(loaded_setting['compensate']) + self.is_denoise_var.set(loaded_setting['is_denoise']) + self.is_invert_spec_var.set(loaded_setting['is_invert_spec']) + self.mdx_voc_inst_secondary_model_var.set(loaded_setting['mdx_voc_inst_secondary_model']) + self.mdx_other_secondary_model_var.set(loaded_setting['mdx_other_secondary_model']) + self.mdx_bass_secondary_model_var.set(loaded_setting['mdx_bass_secondary_model']) + self.mdx_drums_secondary_model_var.set(loaded_setting['mdx_drums_secondary_model']) + self.mdx_is_secondary_model_activate_var.set(loaded_setting['mdx_is_secondary_model_activate']) + self.mdx_voc_inst_secondary_model_scale_var.set(loaded_setting['mdx_voc_inst_secondary_model_scale']) + self.mdx_other_secondary_model_scale_var.set(loaded_setting['mdx_other_secondary_model_scale']) + self.mdx_bass_secondary_model_scale_var.set(loaded_setting['mdx_bass_secondary_model_scale']) + self.mdx_drums_secondary_model_scale_var.set(loaded_setting['mdx_drums_secondary_model_scale']) + + if not process_method: + self.is_save_all_outputs_ensemble_var.set(loaded_setting['is_save_all_outputs_ensemble']) + self.is_append_ensemble_name_var.set(loaded_setting['is_append_ensemble_name']) + self.chosen_audio_tool_var.set(loaded_setting['chosen_audio_tool']) + self.choose_algorithm_var.set(loaded_setting['choose_algorithm']) + self.time_stretch_rate_var.set(loaded_setting['time_stretch_rate']) + self.pitch_rate_var.set(loaded_setting['pitch_rate']) + self.is_primary_stem_only_var.set(loaded_setting['is_primary_stem_only']) + self.is_secondary_stem_only_var.set(loaded_setting['is_secondary_stem_only']) + self.is_testing_audio_var.set(loaded_setting['is_testing_audio']) + self.is_add_model_name_var.set(loaded_setting['is_add_model_name']) + self.is_accept_any_input_var.set(loaded_setting["is_accept_any_input"]) + self.is_task_complete_var.set(loaded_setting['is_task_complete']) + self.is_create_model_folder_var.set(loaded_setting['is_create_model_folder']) + self.mp3_bit_set_var.set(loaded_setting['mp3_bit_set']) + self.save_format_var.set(loaded_setting['save_format']) + self.wav_type_set_var.set(loaded_setting['wav_type_set']) + self.user_code_var.set(loaded_setting['user_code']) + + self.is_gpu_conversion_var.set(loaded_setting['is_gpu_conversion']) + self.is_normalization_var.set(loaded_setting['is_normalization']) + self.help_hints_var.set(loaded_setting['help_hints_var']) + + self.model_sample_mode_var.set(loaded_setting['model_sample_mode']) + self.model_sample_mode_duration_var.set(loaded_setting['model_sample_mode_duration']) + self.model_sample_mode_duration_checkbox_var.set(SAMPLE_MODE_CHECKBOX(self.model_sample_mode_duration_var.get())) + + def save_values(self, app_close=True): + """Saves application data""" + + # -Save Data- + main_settings={ + 'vr_model': self.vr_model_var.get(), + 'aggression_setting': self.aggression_setting_var.get(), + 'window_size': self.window_size_var.get(), + 'batch_size': self.batch_size_var.get(), + 'crop_size': self.crop_size_var.get(), + 'is_tta': self.is_tta_var.get(), + 'is_output_image': self.is_output_image_var.get(), + 'is_post_process': self.is_post_process_var.get(), + 'is_high_end_process': self.is_high_end_process_var.get(), + 'vr_voc_inst_secondary_model': self.vr_voc_inst_secondary_model_var.get(), + 'vr_other_secondary_model': self.vr_other_secondary_model_var.get(), + 'vr_bass_secondary_model': self.vr_bass_secondary_model_var.get(), + 'vr_drums_secondary_model': self.vr_drums_secondary_model_var.get(), + 'vr_is_secondary_model_activate': self.vr_is_secondary_model_activate_var.get(), + 'vr_voc_inst_secondary_model_scale': self.vr_voc_inst_secondary_model_scale_var.get(), + 'vr_other_secondary_model_scale': self.vr_other_secondary_model_scale_var.get(), + 'vr_bass_secondary_model_scale': self.vr_bass_secondary_model_scale_var.get(), + 'vr_drums_secondary_model_scale': self.vr_drums_secondary_model_scale_var.get(), + + 'demucs_model': self.demucs_model_var.get(), + 'segment': self.segment_var.get(), + 'overlap': self.overlap_var.get(), + 'shifts': self.shifts_var.get(), + 'chunks_demucs': self.chunks_demucs_var.get(), + 'margin_demucs': self.margin_demucs_var.get(), + 'is_chunk_demucs': self.is_chunk_demucs_var.get(), + 'is_primary_stem_only_Demucs': self.is_primary_stem_only_Demucs_var.get(), + 'is_secondary_stem_only_Demucs': self.is_secondary_stem_only_Demucs_var.get(), + 'is_split_mode': self.is_split_mode_var.get(), + 'is_demucs_combine_stems': self.is_demucs_combine_stems_var.get(), + 'demucs_voc_inst_secondary_model': self.demucs_voc_inst_secondary_model_var.get(), + 'demucs_other_secondary_model': self.demucs_other_secondary_model_var.get(), + 'demucs_bass_secondary_model': self.demucs_bass_secondary_model_var.get(), + 'demucs_drums_secondary_model': self.demucs_drums_secondary_model_var.get(), + 'demucs_is_secondary_model_activate': self.demucs_is_secondary_model_activate_var.get(), + 'demucs_voc_inst_secondary_model_scale': self.demucs_voc_inst_secondary_model_scale_var.get(), + 'demucs_other_secondary_model_scale': self.demucs_other_secondary_model_scale_var.get(), + 'demucs_bass_secondary_model_scale': self.demucs_bass_secondary_model_scale_var.get(), + 'demucs_drums_secondary_model_scale': self.demucs_drums_secondary_model_scale_var.get(), + 'demucs_pre_proc_model': self.demucs_pre_proc_model_var.get(), + 'is_demucs_pre_proc_model_activate': self.is_demucs_pre_proc_model_activate_var.get(), + 'is_demucs_pre_proc_model_inst_mix': self.is_demucs_pre_proc_model_inst_mix_var.get(), + + 'mdx_net_model': self.mdx_net_model_var.get(), + 'chunks': self.chunks_var.get(), + 'margin': self.margin_var.get(), + 'compensate': self.compensate_var.get(), + 'is_denoise': self.is_denoise_var.get(), + 'is_invert_spec': self.is_invert_spec_var.get(), + 'mdx_voc_inst_secondary_model': self.mdx_voc_inst_secondary_model_var.get(), + 'mdx_other_secondary_model': self.mdx_other_secondary_model_var.get(), + 'mdx_bass_secondary_model': self.mdx_bass_secondary_model_var.get(), + 'mdx_drums_secondary_model': self.mdx_drums_secondary_model_var.get(), + 'mdx_is_secondary_model_activate': self.mdx_is_secondary_model_activate_var.get(), + 'mdx_voc_inst_secondary_model_scale': self.mdx_voc_inst_secondary_model_scale_var.get(), + 'mdx_other_secondary_model_scale': self.mdx_other_secondary_model_scale_var.get(), + 'mdx_bass_secondary_model_scale': self.mdx_bass_secondary_model_scale_var.get(), + 'mdx_drums_secondary_model_scale': self.mdx_drums_secondary_model_scale_var.get(), + + 'is_save_all_outputs_ensemble': self.is_save_all_outputs_ensemble_var.get(), + 'is_append_ensemble_name': self.is_append_ensemble_name_var.get(), + 'chosen_audio_tool': self.chosen_audio_tool_var.get(), + 'choose_algorithm': self.choose_algorithm_var.get(), + 'time_stretch_rate': self.time_stretch_rate_var.get(), + 'pitch_rate': self.pitch_rate_var.get(), + 'is_gpu_conversion': self.is_gpu_conversion_var.get(), + 'is_primary_stem_only': self.is_primary_stem_only_var.get(), + 'is_secondary_stem_only': self.is_secondary_stem_only_var.get(), + 'is_testing_audio': self.is_testing_audio_var.get(), + 'is_add_model_name': self.is_add_model_name_var.get(), + 'is_accept_any_input': self.is_add_model_name_var.get(), + 'is_task_complete': self.is_task_complete_var.get(), + 'is_normalization': self.is_normalization_var.get(), + 'is_create_model_folder': self.is_create_model_folder_var.get(), + 'mp3_bit_set': self.mp3_bit_set_var.get(), + 'save_format': self.save_format_var.get(), + 'wav_type_set': self.wav_type_set_var.get(), + 'user_code': self.user_code_var.get(), + 'help_hints_var': self.help_hints_var.get(), + 'model_sample_mode': self.model_sample_mode_var.get(), + 'model_sample_mode_duration': self.model_sample_mode_duration_var.get() + } + + other_data = { + 'chosen_process_method': self.chosen_process_method_var.get(), + 'input_paths': self.inputPaths, + 'lastDir': self.lastDir, + 'export_path': self.export_path_var.get(), + 'model_hash_table': model_hash_table, + } + + user_saved_extras = { + 'demucs_stems': self.demucs_stems_var.get()} + + if app_close: + save_data(data={**main_settings, **other_data}) + + if self.thread_check(self.active_download_thread): + self.error_dialoge(EXIT_DOWNLOAD_ERROR) + return + + if self.thread_check(self.active_processing_thread): + if self.is_process_stopped: + self.error_dialoge(EXIT_HALTED_PROCESS_ERROR) + else: + self.error_dialoge(EXIT_PROCESS_ERROR) + return + + remove_temps(ENSEMBLE_TEMP_PATH) + remove_temps(SAMPLE_CLIP_PATH) + self.delete_temps() + self.destroy() + + else: + return {**main_settings, **user_saved_extras} + +def secondary_stem(stem): + """Determines secondary stem""" + + for key, value in STEM_PAIR_MAPPER.items(): + if stem in key: + secondary_stem = value + + return secondary_stem + +def vip_downloads(password, link_type=VIP_REPO): + """Attempts to decrypt VIP model link with given input code""" + + try: + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=link_type[0], + iterations=390000,) + + key = base64.urlsafe_b64encode(kdf.derive(bytes(password, 'utf-8'))) + f = Fernet(key) + + return str(f.decrypt(link_type[1]), 'UTF-8') + except Exception: + return NO_CODE + +if __name__ == "__main__": + + try: + from ctypes import windll, wintypes + windll.user32.SetThreadDpiAwarenessContext(wintypes.HANDLE(-1)) + except Exception as e: + print(e) + pass + + root = MainWindow() + + root.update_checkbox_text() + + root.mainloop() \ No newline at end of file diff --git a/__version__.py b/__version__.py index 7342798e..81f221d7 100644 --- a/__version__.py +++ b/__version__.py @@ -1 +1,2 @@ -VERSION = '5.3.0' +VERSION = 'v5.5.0' +PATCH = 'UVR_Patch_12_16_22_3_30' \ No newline at end of file diff --git a/demucs/__init__.py b/demucs/__init__.py new file mode 100644 index 00000000..5656d59e --- /dev/null +++ b/demucs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/demucs/__main__.py b/demucs/__main__.py new file mode 100644 index 00000000..5de878f1 --- /dev/null +++ b/demucs/__main__.py @@ -0,0 +1,272 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import sys +import time +from dataclasses import dataclass, field +from fractions import Fraction + +import torch as th +from torch import distributed, nn +from torch.nn.parallel.distributed import DistributedDataParallel + +from .augment import FlipChannels, FlipSign, Remix, Shift +from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks +from .model import Demucs +from .parser import get_name, get_parser +from .raw import Rawset +from .tasnet import ConvTasNet +from .test import evaluate +from .train import train_model, validate_model +from .utils import human_seconds, load_model, save_model, sizeof_fmt + + +@dataclass +class SavedState: + metrics: list = field(default_factory=list) + last_state: dict = None + best_state: dict = None + optimizer: dict = None + + +def main(): + parser = get_parser() + args = parser.parse_args() + name = get_name(parser, args) + print(f"Experiment {name}") + + if args.musdb is None and args.rank == 0: + print( + "You must provide the path to the MusDB dataset with the --musdb flag. " + "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", + file=sys.stderr) + sys.exit(1) + + eval_folder = args.evals / name + eval_folder.mkdir(exist_ok=True, parents=True) + args.logs.mkdir(exist_ok=True) + metrics_path = args.logs / f"{name}.json" + eval_folder.mkdir(exist_ok=True, parents=True) + args.checkpoints.mkdir(exist_ok=True, parents=True) + args.models.mkdir(exist_ok=True, parents=True) + + if args.device is None: + device = "cpu" + if th.cuda.is_available(): + device = "cuda" + else: + device = args.device + + th.manual_seed(args.seed) + # Prevents too many threads to be started when running `museval` as it can be quite + # inefficient on NUMA architectures. + os.environ["OMP_NUM_THREADS"] = "1" + + if args.world_size > 1: + if device != "cuda" and args.rank == 0: + print("Error: distributed training is only available with cuda device", file=sys.stderr) + sys.exit(1) + th.cuda.set_device(args.rank % th.cuda.device_count()) + distributed.init_process_group(backend="nccl", + init_method="tcp://" + args.master, + rank=args.rank, + world_size=args.world_size) + + checkpoint = args.checkpoints / f"{name}.th" + checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" + if args.restart and checkpoint.exists(): + checkpoint.unlink() + + if args.test: + args.epochs = 1 + args.repeat = 0 + model = load_model(args.models / args.test) + elif args.tasnet: + model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X) + else: + model = Demucs( + audio_channels=args.audio_channels, + channels=args.channels, + context=args.context, + depth=args.depth, + glu=args.glu, + growth=args.growth, + kernel_size=args.kernel_size, + lstm_layers=args.lstm_layers, + rescale=args.rescale, + rewrite=args.rewrite, + sources=4, + stride=args.conv_stride, + upsample=args.upsample, + samplerate=args.samplerate + ) + model.to(device) + if args.show: + print(model) + size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) + print(f"Model size {size}") + return + + optimizer = th.optim.Adam(model.parameters(), lr=args.lr) + + try: + saved = th.load(checkpoint, map_location='cpu') + except IOError: + saved = SavedState() + else: + model.load_state_dict(saved.last_state) + optimizer.load_state_dict(saved.optimizer) + + if args.save_model: + if args.rank == 0: + model.to("cpu") + model.load_state_dict(saved.best_state) + save_model(model, args.models / f"{name}.th") + return + + if args.rank == 0: + done = args.logs / f"{name}.done" + if done.exists(): + done.unlink() + + if args.augment: + augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), + Remix(group_size=args.remix_group_size)).to(device) + else: + augment = Shift(args.data_stride) + + if args.mse: + criterion = nn.MSELoss() + else: + criterion = nn.L1Loss() + + # Setting number of samples so that all convolution windows are full. + # Prevents hard to debug mistake with the prediction being shifted compared + # to the input mixture. + samples = model.valid_length(args.samples) + print(f"Number of training samples adjusted to {samples}") + + if args.raw: + train_set = Rawset(args.raw / "train", + samples=samples + args.data_stride, + channels=args.audio_channels, + streams=[0, 1, 2, 3, 4], + stride=args.data_stride) + + valid_set = Rawset(args.raw / "valid", channels=args.audio_channels) + else: + if not args.metadata.is_file() and args.rank == 0: + build_musdb_metadata(args.metadata, args.musdb, args.workers) + if args.world_size > 1: + distributed.barrier() + metadata = json.load(open(args.metadata)) + duration = Fraction(samples + args.data_stride, args.samplerate) + stride = Fraction(args.data_stride, args.samplerate) + train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), + metadata, + duration=duration, + stride=stride, + samplerate=args.samplerate, + channels=args.audio_channels) + valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), + metadata, + samplerate=args.samplerate, + channels=args.audio_channels) + + best_loss = float("inf") + for epoch, metrics in enumerate(saved.metrics): + print(f"Epoch {epoch:03d}: " + f"train={metrics['train']:.8f} " + f"valid={metrics['valid']:.8f} " + f"best={metrics['best']:.4f} " + f"duration={human_seconds(metrics['duration'])}") + best_loss = metrics['best'] + + if args.world_size > 1: + dmodel = DistributedDataParallel(model, + device_ids=[th.cuda.current_device()], + output_device=th.cuda.current_device()) + else: + dmodel = model + + for epoch in range(len(saved.metrics), args.epochs): + begin = time.time() + model.train() + train_loss = train_model(epoch, + train_set, + dmodel, + criterion, + optimizer, + augment, + batch_size=args.batch_size, + device=device, + repeat=args.repeat, + seed=args.seed, + workers=args.workers, + world_size=args.world_size) + model.eval() + valid_loss = validate_model(epoch, + valid_set, + model, + criterion, + device=device, + rank=args.rank, + split=args.split_valid, + world_size=args.world_size) + + duration = time.time() - begin + if valid_loss < best_loss: + best_loss = valid_loss + saved.best_state = { + key: value.to("cpu").clone() + for key, value in model.state_dict().items() + } + saved.metrics.append({ + "train": train_loss, + "valid": valid_loss, + "best": best_loss, + "duration": duration + }) + if args.rank == 0: + json.dump(saved.metrics, open(metrics_path, "w")) + + saved.last_state = model.state_dict() + saved.optimizer = optimizer.state_dict() + if args.rank == 0 and not args.test: + th.save(saved, checkpoint_tmp) + checkpoint_tmp.rename(checkpoint) + + print(f"Epoch {epoch:03d}: " + f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} " + f"duration={human_seconds(duration)}") + + del dmodel + model.load_state_dict(saved.best_state) + if args.eval_cpu: + device = "cpu" + model.to(device) + model.eval() + evaluate(model, + args.musdb, + eval_folder, + rank=args.rank, + world_size=args.world_size, + device=device, + save=args.save, + split=args.split_valid, + shifts=args.shifts, + workers=args.eval_workers) + model.to("cpu") + save_model(model, args.models / f"{name}.th") + if args.rank == 0: + print("done") + done.write_text("done") + + +if __name__ == "__main__": + main() diff --git a/demucs/apply.py b/demucs/apply.py new file mode 100644 index 00000000..5769376c --- /dev/null +++ b/demucs/apply.py @@ -0,0 +1,294 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Code to apply a model to a mix. It will handle chunking with overlaps and +inteprolation between chunks, as well as the "shift trick". +""" +from concurrent.futures import ThreadPoolExecutor +import random +import typing as tp +from multiprocessing import Process,Queue,Pipe + +import torch as th +from torch import nn +from torch.nn import functional as F +import tqdm +import tkinter as tk + +from .demucs import Demucs +from .hdemucs import HDemucs +from .utils import center_trim, DummyPoolExecutor + +Model = tp.Union[Demucs, HDemucs] + +progress_bar_num = 0 + +class BagOfModels(nn.Module): + def __init__(self, models: tp.List[Model], + weights: tp.Optional[tp.List[tp.List[float]]] = None, + segment: tp.Optional[float] = None): + """ + Represents a bag of models with specific weights. + You should call `apply_model` rather than calling directly the forward here for + optimal performance. + + Args: + models (list[nn.Module]): list of Demucs/HDemucs models. + weights (list[list[float]]): list of weights. If None, assumed to + be all ones, otherwise it should be a list of N list (N number of models), + each containing S floats (S number of sources). + segment (None or float): overrides the `segment` attribute of each model + (this is performed inplace, be careful if you reuse the models passed). + """ + + super().__init__() + assert len(models) > 0 + first = models[0] + for other in models: + assert other.sources == first.sources + assert other.samplerate == first.samplerate + assert other.audio_channels == first.audio_channels + if segment is not None: + other.segment = segment + + self.audio_channels = first.audio_channels + self.samplerate = first.samplerate + self.sources = first.sources + self.models = nn.ModuleList(models) + + if weights is None: + weights = [[1. for _ in first.sources] for _ in models] + else: + assert len(weights) == len(models) + for weight in weights: + assert len(weight) == len(first.sources) + self.weights = weights + + def forward(self, x): + raise NotImplementedError("Call `apply_model` on this.") + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + if isinstance(tensor, TensorChunk): + self.tensor = tensor.tensor + self.offset = offset + tensor.offset + else: + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + +def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + device (torch.device, str, or None): if provided, device on which to + execute the computation, otherwise `mix.device` is assumed. + When `device` is different from `mix.device`, only local computations will + be on `device`, while the entire tracks will be stored on `mix.device`. + """ + + global fut_length + global bag_num + global prog_bar + + if device is None: + device = mix.device + else: + device = th.device(device) + if pool is None: + if num_workers > 0 and device.type == 'cpu': + pool = ThreadPoolExecutor(num_workers) + else: + pool = DummyPoolExecutor() + + kwargs = { + 'shifts': shifts, + 'split': split, + 'overlap': overlap, + 'transition_power': transition_power, + 'progress': progress, + 'device': device, + 'pool': pool, + 'set_progress_bar': set_progress_bar, + 'static_shifts': static_shifts, + } + + if isinstance(model, BagOfModels): + # Special treatment for bag of model. + # We explicitely apply multiple times `apply_model` so that the random shifts + # are different for each model. + + estimates = 0 + totals = [0] * len(model.sources) + bag_num = len(model.models) + fut_length = 0 + prog_bar = 0 + current_model = 0 #(bag_num + 1) + for sub_model, weight in zip(model.models, model.weights): + original_model_device = next(iter(sub_model.parameters())).device + sub_model.to(device) + fut_length += fut_length + current_model += 1 + out = apply_model(sub_model, mix, **kwargs) + sub_model.to(original_model_device) + for k, inst_weight in enumerate(weight): + out[:, k, :, :] *= inst_weight + totals[k] += inst_weight + estimates += out + del out + + for k in range(estimates.shape[1]): + estimates[:, k, :, :] /= totals[k] + return estimates + + model.to(device) + model.eval() + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + batch, channels, length = mix.shape + + if shifts: + kwargs['shifts'] = 0 + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0 + for _ in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + shifted_out = apply_model(model, shifted, **kwargs) + out += shifted_out[..., max_shift - offset:] + out /= shifts + return out + elif split: + kwargs['split'] = False + out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) + sum_weight = th.zeros(length, device=mix.device) + segment = int(model.samplerate * model.segment) + stride = int((1 - overlap) * segment) + offsets = range(0, length, stride) + scale = float(format(stride / model.samplerate, ".2f")) + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment // 2 + 1, device=device), + th.arange(segment - segment // 2, 0, -1, device=device)]) + assert len(weight) == segment + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + futures = [] + for offset in offsets: + chunk = TensorChunk(mix, offset, segment) + future = pool.submit(apply_model, model, chunk, **kwargs) + futures.append((future, offset)) + offset += segment + if progress: + futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') + for future, offset in futures: + if set_progress_bar: + fut_length = (len(futures) * bag_num * static_shifts) + prog_bar += 1 + set_progress_bar(0.1, (0.8/fut_length*prog_bar)) + chunk_out = future.result() + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device) + sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device) + assert sum_weight.min() > 0 + out /= sum_weight + return out + else: + if hasattr(model, 'valid_length'): + valid_length = model.valid_length(length) + else: + valid_length = length + mix = tensor_chunk(mix) + padded_mix = mix.padded(valid_length).to(device) + with th.no_grad(): + out = model(padded_mix) + return center_trim(out, length) + +def demucs_segments(demucs_segment, demucs_model): + + if demucs_segment == 'Default': + segment = None + if isinstance(demucs_model, BagOfModels): + if segment is not None: + for sub in demucs_model.models: + sub.segment = segment + else: + if segment is not None: + sub.segment = segment + else: + try: + segment = int(demucs_segment) + if isinstance(demucs_model, BagOfModels): + if segment is not None: + for sub in demucs_model.models: + sub.segment = segment + else: + if segment is not None: + sub.segment = segment + except: + segment = None + if isinstance(demucs_model, BagOfModels): + if segment is not None: + for sub in demucs_model.models: + sub.segment = segment + else: + if segment is not None: + sub.segment = segment + + return demucs_model \ No newline at end of file diff --git a/demucs/demucs.py b/demucs/demucs.py new file mode 100644 index 00000000..d2c08e73 --- /dev/null +++ b/demucs/demucs.py @@ -0,0 +1,459 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import julius +import torch +from torch import nn +from torch.nn import functional as F + +from .states import capture_init +from .utils import center_trim, unfold + + +class BLSTM(nn.Module): + """ + BiLSTM with same hidden units as input dim. + If `max_steps` is not None, input will be splitting in overlapping + chunks and the LSTM applied separately on each chunk. + """ + def __init__(self, dim, layers=1, max_steps=None, skip=False): + super().__init__() + assert max_steps is None or max_steps % 4 == 0 + self.max_steps = max_steps + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def forward(self, x): + B, C, T = x.shape + y = x + framed = False + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + return x + + +def rescale_conv(conv, reference): + """Rescale initial weight scale. It is unclear why it helps but it certainly does. + """ + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): + rescale_conv(sub, reference) + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. + """ + def __init__(self, channels: int, init: float = 0): + super().__init__() + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + return self.scale[:, None] * x + + +class DConv(nn.Module): + """ + New residual branches in each encoder layer. + This alternates dilated convolutions, potentially with LSTMs and attention. + Also before entering each residual branch, dimension is projected on a smaller subspace, + e.g. of dim `channels // compress`. + """ + def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, + norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, + kernel=3, dilate=True): + """ + Args: + channels: input/output channels for residual branch. + compress: amount of channel compression inside the branch. + depth: number of layers in the residual branch. Each layer has its own + projection, and potentially LSTM and attention. + init: initial scale for LayerNorm. + norm: use GroupNorm. + attn: use LocalAttention. + heads: number of heads for the LocalAttention. + ndecay: number of decay controls in the LocalAttention. + lstm: use LSTM. + gelu: Use GELU activation. + kernel: kernel size for the (dilated) convolutions. + dilate: if true, use dilation, increasing with the depth. + """ + + super().__init__() + assert kernel % 2 == 1 + self.channels = channels + self.compress = compress + self.depth = abs(depth) + dilate = depth > 0 + + norm_fn: tp.Callable[[int], nn.Module] + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(1, d) # noqa + + hidden = int(channels / compress) + + act: tp.Type[nn.Module] + if gelu: + act = nn.GELU + else: + act = nn.ReLU + + self.layers = nn.ModuleList([]) + for d in range(self.depth): + dilation = 2 ** d if dilate else 1 + padding = dilation * (kernel // 2) + mods = [ + nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), + norm_fn(hidden), act(), + nn.Conv1d(hidden, 2 * channels, 1), + norm_fn(2 * channels), nn.GLU(1), + LayerScale(channels, init), + ] + if attn: + mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) + if lstm: + mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class LocalState(nn.Module): + """Local state allows to have attention based only on data (no positional embedding), + but while setting a constraint on the time window (e.g. decaying penalty term). + + Also a failed experiments with trying to provide some frequency based attention. + """ + def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): + super().__init__() + assert channels % heads == 0, (channels, heads) + self.heads = heads + self.nfreqs = nfreqs + self.ndecay = ndecay + self.content = nn.Conv1d(channels, channels, 1) + self.query = nn.Conv1d(channels, channels, 1) + self.key = nn.Conv1d(channels, channels, 1) + if nfreqs: + self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) + if ndecay: + self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) + # Initialize decay close to zero (there is a sigmoid), for maximum initial window. + self.query_decay.weight.data *= 0.01 + assert self.query_decay.bias is not None # stupid type checker + self.query_decay.bias.data[:] = -2 + self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) + + def forward(self, x): + B, C, T = x.shape + heads = self.heads + indexes = torch.arange(T, device=x.device, dtype=x.dtype) + # left index are keys, right index are queries + delta = indexes[:, None] - indexes[None, :] + + queries = self.query(x).view(B, heads, -1, T) + keys = self.key(x).view(B, heads, -1, T) + # t are keys, s are queries + dots = torch.einsum("bhct,bhcs->bhts", keys, queries) + dots /= keys.shape[2]**0.5 + if self.nfreqs: + periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) + freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) + freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 + dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) + if self.ndecay: + decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) + decay_q = self.query_decay(x).view(B, heads, -1, T) + decay_q = torch.sigmoid(decay_q) / 2 + decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 + dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) + + # Kill self reference. + dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) + weights = torch.softmax(dots, dim=2) + + content = self.content(x).view(B, heads, -1, T) + result = torch.einsum("bhts,bhct->bhcs", weights, content) + if self.nfreqs: + time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) + result = torch.cat([result, time_sig], 2) + result = result.reshape(B, -1, T) + return x + self.proj(result) + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=64, + growth=2., + # Main structure + depth=6, + rewrite=True, + lstm_layers=0, + # Convolutions + kernel_size=8, + stride=4, + context=1, + # Activations + gelu=True, + glu=True, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Pre/post processing + normalize=True, + resample=True, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated + by default, as this is now replaced by the smaller and faster small LSTMs + in the DConv branches. + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time steps. + gelu: use GELU activation function. + glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + normalize (bool): normalizes the input audio on the fly, and scales back + the output by the same amount. + resample (bool): upsample x2 the input and downsample /2 the output. + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale`. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment (float): duration of the chunks of audio to ideally evaluate the model on. + This is used by `demucs.apply.apply_model`. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment = segment + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + self.skip_scales = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + if gelu: + act2 = nn.GELU + else: + act2 = nn.ReLU + + in_channels = audio_channels + padding = 0 + for index in range(depth): + norm_fn = lambda d: nn.Identity() # noqa + if index >= norm_starts: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + + encode = [] + encode += [ + nn.Conv1d(in_channels, channels, kernel_size, stride), + norm_fn(channels), + act2(), + ] + attn = index >= dconv_attn + lstm = index >= dconv_lstm + if dconv_mode & 1: + encode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + if rewrite: + encode += [ + nn.Conv1d(channels, ch_scale * channels, 1), + norm_fn(ch_scale * channels), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [ + nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), + norm_fn(ch_scale * channels), activation] + if dconv_mode & 2: + decode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + decode += [nn.ConvTranspose1d(channels, out_channels, + kernel_size, stride, padding=padding)] + if index > 0: + decode += [norm_fn(out_channels), act2()] + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolution, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + Note that input are automatically padded if necessary to ensure that the output + has the same length as the input. + """ + if self.resample: + length *= 2 + + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + + for idx in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + x = (x - mean) / (1e-5 + std) + else: + mean = 0 + std = 1 + + delta = self.valid_length(length) - length + x = F.pad(x, (delta // 2, delta - delta // 2)) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + + if self.lstm: + x = self.lstm(x) + + for decode in self.decoder: + skip = saved.pop(-1) + skip = center_trim(skip, x) + x = decode(x + skip) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = center_trim(x, length) + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x + + def load_state_dict(self, state, strict=True): + # fix a mismatch with previous generation Demucs models. + for idx in range(self.depth): + for a in ['encoder', 'decoder']: + for b in ['bias', 'weight']: + new = f'{a}.{idx}.3.{b}' + old = f'{a}.{idx}.2.{b}' + if old in state and new not in state: + state[new] = state.pop(old) + super().load_state_dict(state, strict=strict) diff --git a/demucs/filtering.py b/demucs/filtering.py new file mode 100644 index 00000000..08a2c17d --- /dev/null +++ b/demucs/filtering.py @@ -0,0 +1,502 @@ +from typing import Optional +import torch +import torch.nn as nn +from torch import Tensor +from torch.utils.data import DataLoader + +def atan2(y, x): + r"""Element-wise arctangent function of y/x. + Returns a new tensor with signed angles in radians. + It is an alternative implementation of torch.atan2 + + Args: + y (Tensor): First input tensor + x (Tensor): Second input tensor [shape=y.shape] + + Returns: + Tensor: [shape=y.shape]. + """ + pi = 2 * torch.asin(torch.tensor(1.0)) + x += ((x == 0) & (y == 0)) * 1.0 + out = torch.atan(y / x) + out += ((y >= 0) & (x < 0)) * pi + out -= ((y < 0) & (x < 0)) * pi + out *= 1 - ((y > 0) & (x == 0)) * 1.0 + out += ((y > 0) & (x == 0)) * (pi / 2) + out *= 1 - ((y < 0) & (x == 0)) * 1.0 + out += ((y < 0) & (x == 0)) * (-pi / 2) + return out + + +# Define basic complex operations on torch.Tensor objects whose last dimension +# consists in the concatenation of the real and imaginary parts. + + +def _norm(x: torch.Tensor) -> torch.Tensor: + r"""Computes the norm value of a torch Tensor, assuming that it + comes as real and imaginary part in its last dimension. + + Args: + x (Tensor): Input Tensor of shape [shape=(..., 2)] + + Returns: + Tensor: shape as x excluding the last dimension. + """ + return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2 + + +def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise multiplication of two complex Tensors described + through their real and imaginary parts. + The result is added to the `out` tensor""" + + # check `out` and allocate it if needed + target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)]) + if out is None or out.shape != target_shape: + out = torch.zeros(target_shape, dtype=a.dtype, device=a.device) + if out is a: + real_a = a[..., 0] + out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1]) + out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0]) + else: + out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]) + out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]) + return out + + +def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise multiplication of two complex Tensors described + through their real and imaginary parts + can work in place in case out is a only""" + target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)]) + if out is None or out.shape != target_shape: + out = torch.zeros(target_shape, dtype=a.dtype, device=a.device) + if out is a: + real_a = a[..., 0] + out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1] + out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0] + else: + out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1] + out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0] + return out + + +def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise multiplicative inverse of a Tensor with complex + entries described through their real and imaginary parts. + can work in place in case out is z""" + ez = _norm(z) + if out is None or out.shape != z.shape: + out = torch.zeros_like(z) + out[..., 0] = z[..., 0] / ez + out[..., 1] = -z[..., 1] / ez + return out + + +def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise complex conjugate of a Tensor with complex entries + described through their real and imaginary parts. + can work in place in case out is z""" + if out is None or out.shape != z.shape: + out = torch.zeros_like(z) + out[..., 0] = z[..., 0] + out[..., 1] = -z[..., 1] + return out + + +def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Invert 1x1 or 2x2 matrices + + Will generate errors if the matrices are singular: user must handle this + through his own regularization schemes. + + Args: + M (Tensor): [shape=(..., nb_channels, nb_channels, 2)] + matrices to invert: must be square along dimensions -3 and -2 + + Returns: + invM (Tensor): [shape=M.shape] + inverses of M + """ + nb_channels = M.shape[-2] + + if out is None or out.shape != M.shape: + out = torch.empty_like(M) + + if nb_channels == 1: + # scalar case + out = _inv(M, out) + elif nb_channels == 2: + # two channels case: analytical expression + + # first compute the determinent + det = _mul(M[..., 0, 0, :], M[..., 1, 1, :]) + det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :]) + # invert it + invDet = _inv(det) + + # then fill out the matrix with the inverse + out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :]) + out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :]) + out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :]) + out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :]) + else: + raise Exception("Only 2 channels are supported for the torch version.") + return out + + +# Now define the signal-processing low-level functions used by the Separator + + +def expectation_maximization( + y: torch.Tensor, + x: torch.Tensor, + iterations: int = 2, + eps: float = 1e-10, + batch_size: int = 200, +): + r"""Expectation maximization algorithm, for refining source separation + estimates. + + This algorithm allows to make source separation results better by + enforcing multichannel consistency for the estimates. This usually means + a better perceptual quality in terms of spatial artifacts. + + The implementation follows the details presented in [1]_, taking + inspiration from the original EM algorithm proposed in [2]_ and its + weighted refinement proposed in [3]_, [4]_. + It works by iteratively: + + * Re-estimate source parameters (power spectral densities and spatial + covariance matrices) through :func:`get_local_gaussian_model`. + + * Separate again the mixture with the new parameters by first computing + the new modelled mixture covariance matrices with :func:`get_mix_model`, + prepare the Wiener filters through :func:`wiener_gain` and apply them + with :func:`apply_filter``. + + References + ---------- + .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and + N. Takahashi and Y. Mitsufuji, "Improving music source separation based + on deep neural networks through data augmentation and network + blending." 2017 IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP). IEEE, 2017. + + .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined + reverberant audio source separation using a full-rank spatial + covariance model." IEEE Transactions on Audio, Speech, and Language + Processing 18.7 (2010): 1830-1840. + + .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source + separation with deep neural networks." IEEE/ACM Transactions on Audio, + Speech, and Language Processing 24.9 (2016): 1652-1664. + + .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music + separation with deep neural networks." 2016 24th European Signal + Processing Conference (EUSIPCO). IEEE, 2016. + + .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for + source separation." IEEE Transactions on Signal Processing + 62.16 (2014): 4298-4310. + + Args: + y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)] + initial estimates for the sources + x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)] + complex STFT of the mixture signal + iterations (int): [scalar] + number of iterations for the EM algorithm. + eps (float or None): [scalar] + The epsilon value to use for regularization and filters. + + Returns: + y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)] + estimated sources after iterations + v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)] + estimated power spectral densities + R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)] + estimated spatial covariance matrices + + Notes: + * You need an initial estimate for the sources to apply this + algorithm. This is precisely what the :func:`wiener` function does. + * This algorithm *is not* an implementation of the "exact" EM + proposed in [1]_. In particular, it does compute the posterior + covariance matrices the same (exact) way. Instead, it uses the + simplified approximate scheme initially proposed in [5]_ and further + refined in [3]_, [4]_, that boils down to just take the empirical + covariance of the recent source estimates, followed by a weighted + average for the update of the spatial covariance matrix. It has been + empirically demonstrated that this simplified algorithm is more + robust for music separation. + + Warning: + It is *very* important to make sure `x.dtype` is `torch.float64` + if you want double precision, because this function will **not** + do such conversion for you from `torch.complex32`, in case you want the + smaller RAM usage on purpose. + + It is usually always better in terms of quality to have double + precision, by e.g. calling :func:`expectation_maximization` + with ``x.to(torch.float64)``. + """ + # dimensions + (nb_frames, nb_bins, nb_channels) = x.shape[:-1] + nb_sources = y.shape[-1] + + regularization = torch.cat( + ( + torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], + torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device), + ), + dim=2, + ) + regularization = torch.sqrt(torch.as_tensor(eps)) * ( + regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)) + ) + + # allocate the spatial covariance matrices + R = [ + torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) + for j in range(nb_sources) + ] + weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device) + + v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device) + for it in range(iterations): + # constructing the mixture covariance matrix. Doing it with a loop + # to avoid storing anytime in RAM the whole 6D tensor + + # update the PSD as the average spectrogram over channels + v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2) + + # update spatial covariance matrices (weighted update) + for j in range(nb_sources): + R[j] = torch.tensor(0.0, device=x.device) + weight = torch.tensor(eps, device=x.device) + pos: int = 0 + batch_size = batch_size if batch_size else nb_frames + while pos < nb_frames: + t = torch.arange(pos, min(nb_frames, pos + batch_size)) + pos = int(t[-1]) + 1 + + R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0) + weight = weight + torch.sum(v[t, ..., j], dim=0) + R[j] = R[j] / weight[..., None, None, None] + weight = torch.zeros_like(weight) + + # cloning y if we track gradient, because we're going to update it + if y.requires_grad: + y = y.clone() + + pos = 0 + while pos < nb_frames: + t = torch.arange(pos, min(nb_frames, pos + batch_size)) + pos = int(t[-1]) + 1 + + y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype) + + # compute mix covariance matrix + Cxx = regularization + for j in range(nb_sources): + Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone()) + + # invert it + inv_Cxx = _invert(Cxx) + + # separate the sources + for j in range(nb_sources): + + # create a wiener gain for this source + gain = torch.zeros_like(inv_Cxx) + + # computes multichannel Wiener gain as v_j R_j inv_Cxx + indices = torch.cartesian_prod( + torch.arange(nb_channels), + torch.arange(nb_channels), + torch.arange(nb_channels), + ) + for index in indices: + gain[:, :, index[0], index[1], :] = _mul_add( + R[j][None, :, index[0], index[2], :].clone(), + inv_Cxx[:, :, index[2], index[1], :], + gain[:, :, index[0], index[1], :], + ) + gain = gain * v[t, ..., None, None, None, j] + + # apply it to the mixture + for i in range(nb_channels): + y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j]) + + return y, v, R + + +def wiener( + targets_spectrograms: torch.Tensor, + mix_stft: torch.Tensor, + iterations: int = 1, + softmask: bool = False, + residual: bool = False, + scale_factor: float = 10.0, + eps: float = 1e-10, +): + """Wiener-based separation for multichannel audio. + + The method uses the (possibly multichannel) spectrograms of the + sources to separate the (complex) Short Term Fourier Transform of the + mix. Separation is done in a sequential way by: + + * Getting an initial estimate. This can be done in two ways: either by + directly using the spectrograms with the mixture phase, or + by using a softmasking strategy. This initial phase is controlled + by the `softmask` flag. + + * If required, adding an additional residual target as the mix minus + all targets. + + * Refinining these initial estimates through a call to + :func:`expectation_maximization` if the number of iterations is nonzero. + + This implementation also allows to specify the epsilon value used for + regularization. It is based on [1]_, [2]_, [3]_, [4]_. + + References + ---------- + .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and + N. Takahashi and Y. Mitsufuji, "Improving music source separation based + on deep neural networks through data augmentation and network + blending." 2017 IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP). IEEE, 2017. + + .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source + separation with deep neural networks." IEEE/ACM Transactions on Audio, + Speech, and Language Processing 24.9 (2016): 1652-1664. + + .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music + separation with deep neural networks." 2016 24th European Signal + Processing Conference (EUSIPCO). IEEE, 2016. + + .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for + source separation." IEEE Transactions on Signal Processing + 62.16 (2014): 4298-4310. + + Args: + targets_spectrograms (Tensor): spectrograms of the sources + [shape=(nb_frames, nb_bins, nb_channels, nb_sources)]. + This is a nonnegative tensor that is + usually the output of the actual separation method of the user. The + spectrograms may be mono, but they need to be 4-dimensional in all + cases. + mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)] + STFT of the mixture signal. + iterations (int): [scalar] + number of iterations for the EM algorithm + softmask (bool): Describes how the initial estimates are obtained. + * if `False`, then the mixture phase will directly be used with the + spectrogram as initial estimates. + * if `True`, initial estimates are obtained by multiplying the + complex mix element-wise with the ratio of each target spectrogram + with the sum of them all. This strategy is better if the model are + not really good, and worse otherwise. + residual (bool): if `True`, an additional target is created, which is + equal to the mixture minus the other targets, before application of + expectation maximization + eps (float): Epsilon value to use for computing the separations. + This is used whenever division with a model energy is + performed, i.e. when softmasking and when iterating the EM. + It can be understood as the energy of the additional white noise + that is taken out when separating. + + Returns: + Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources) + STFT of estimated sources + + Notes: + * Be careful that you need *magnitude spectrogram estimates* for the + case `softmask==False`. + * `softmask=False` is recommended + * The epsilon value will have a huge impact on performance. If it's + large, only the parts of the signal with a significant energy will + be kept in the sources. This epsilon then directly controls the + energy of the reconstruction error. + + Warning: + As in :func:`expectation_maximization`, we recommend converting the + mixture `x` to double precision `torch.float64` *before* calling + :func:`wiener`. + """ + if softmask: + # if we use softmask, we compute the ratio mask for all targets and + # multiply by the mix stft + y = ( + mix_stft[..., None] + * ( + targets_spectrograms + / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)) + )[..., None, :] + ) + else: + # otherwise, we just multiply the targets spectrograms with mix phase + # we tacitly assume that we have magnitude estimates. + angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None] + nb_sources = targets_spectrograms.shape[-1] + y = torch.zeros( + mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device + ) + y[..., 0, :] = targets_spectrograms * torch.cos(angle) + y[..., 1, :] = targets_spectrograms * torch.sin(angle) + + if residual: + # if required, adding an additional target as the mix minus + # available targets + y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1) + + if iterations == 0: + return y + + # we need to refine the estimates. Scales down the estimates for + # numerical stability + max_abs = torch.max( + torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), + torch.sqrt(_norm(mix_stft)).max() / scale_factor, + ) + + mix_stft = mix_stft / max_abs + y = y / max_abs + + # call expectation maximization + y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0] + + # scale estimates up again + y = y * max_abs + return y + + +def _covariance(y_j): + """ + Compute the empirical covariance for a source. + + Args: + y_j (Tensor): complex stft of the source. + [shape=(nb_frames, nb_bins, nb_channels, 2)]. + + Returns: + Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)] + just y_j * conj(y_j.T): empirical covariance for each TF bin. + """ + (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1] + Cj = torch.zeros( + (nb_frames, nb_bins, nb_channels, nb_channels, 2), + dtype=y_j.dtype, + device=y_j.device, + ) + indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels)) + for index in indices: + Cj[:, :, index[0], index[1], :] = _mul_add( + y_j[:, :, index[0], :], + _conj(y_j[:, :, index[1], :]), + Cj[:, :, index[0], index[1], :], + ) + return Cj diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py new file mode 100644 index 00000000..d776d556 --- /dev/null +++ b/demucs/hdemucs.py @@ -0,0 +1,782 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +from copy import deepcopy +import math +import typing as tp +import torch +from torch import nn +from torch.nn import functional as F +from .filtering import wiener +from .demucs import DConv, rescale_module +from .states import capture_init +from .spec import spectro, ispectro + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen.""" + x0 = x + length = x.shape[-1] + padding_left, padding_right = paddings + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + if length <= max_pad: + extra_pad = max_pad - length + 1 + extra_pad_right = min(padding_right, extra_pad) + extra_pad_left = extra_pad - extra_pad_right + paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right) + x = F.pad(x, (extra_pad_left, extra_pad_right)) + out = F.pad(x, paddings, mode, value) + assert out.shape[-1] == length + padding_left + padding_right + assert (out[..., padding_left: padding_left + length] == x0).all() + return out + +class ScaledEmbedding(nn.Module): + """ + Boost learning rate for embeddings (with `scale`). + Also, can make embeddings continuous with `smooth`. + """ + def __init__(self, num_embeddings: int, embedding_dim: int, + scale: float = 10., smooth=False): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + if smooth: + weight = torch.cumsum(self.embedding.weight.data, dim=0) + # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. + weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] + self.embedding.weight.data[:] = weight + self.embedding.weight.data /= scale + self.scale = scale + + @property + def weight(self): + return self.embedding.weight * self.scale + + def forward(self, x): + out = self.embedding(x) * self.scale + return out + + +class HEncLayer(nn.Module): + def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, + rewrite=True): + """Encoder layer. This used both by the time and the frequency branch. + + Args: + chin: number of input channels. + chout: number of output channels. + norm_groups: number of groups for group norm. + empty: used to make a layer with just the first conv. this is used + before merging the time and freq. branches. + freq: this is acting on frequencies. + dconv: insert DConv residual branches. + norm: use GroupNorm. + context: context size for the 1x1 conv. + dconv_kw: list of kwargs for the DConv class. + pad: pad the input. Padding is done so that the output size is + always the input size / stride. + rewrite: add 1x1 conv at the end of the layer. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + klass = nn.Conv1d + self.freq = freq + self.kernel_size = kernel_size + self.stride = stride + self.empty = empty + self.norm = norm + self.pad = pad + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + pad = [pad, 0] + klass = nn.Conv2d + self.conv = klass(chin, chout, kernel_size, stride, pad) + if self.empty: + return + self.norm1 = norm_fn(chout) + self.rewrite = None + if rewrite: + self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) + self.norm2 = norm_fn(2 * chout) + + self.dconv = None + if dconv: + self.dconv = DConv(chout, **dconv_kw) + + def forward(self, x, inject=None): + """ + `inject` is used to inject the result from the time branch into the frequency branch, + when both have the same stride. + """ + if not self.freq and x.dim() == 4: + B, C, Fr, T = x.shape + x = x.view(B, -1, T) + + if not self.freq: + le = x.shape[-1] + if not le % self.stride == 0: + x = F.pad(x, (0, self.stride - (le % self.stride))) + y = self.conv(x) + if self.empty: + return y + if inject is not None: + assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) + if inject.dim() == 3 and y.dim() == 4: + inject = inject[:, :, None] + y = y + inject + y = F.gelu(self.norm1(y)) + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + if self.rewrite: + z = self.norm2(self.rewrite(y)) + z = F.glu(z, dim=1) + else: + z = y + return z + + +class MultiWrap(nn.Module): + """ + Takes one layer and replicate it N times. each replica will act + on a frequency band. All is done so that if the N replica have the same weights, + then this is exactly equivalent to applying the original module on all frequencies. + + This is a bit over-engineered to avoid edge artifacts when splitting + the frequency bands, but it is possible the naive implementation would work as well... + """ + def __init__(self, layer, split_ratios): + """ + Args: + layer: module to clone, must be either HEncLayer or HDecLayer. + split_ratios: list of float indicating which ratio to keep for each band. + """ + super().__init__() + self.split_ratios = split_ratios + self.layers = nn.ModuleList() + self.conv = isinstance(layer, HEncLayer) + assert not layer.norm + assert layer.freq + assert layer.pad + if not self.conv: + assert not layer.context_freq + for k in range(len(split_ratios) + 1): + lay = deepcopy(layer) + if self.conv: + lay.conv.padding = (0, 0) + else: + lay.pad = False + for m in lay.modules(): + if hasattr(m, 'reset_parameters'): + m.reset_parameters() + self.layers.append(lay) + + def forward(self, x, skip=None, length=None): + B, C, Fr, T = x.shape + + ratios = list(self.split_ratios) + [1] + start = 0 + outs = [] + for ratio, layer in zip(ratios, self.layers): + if self.conv: + pad = layer.kernel_size // 4 + if ratio == 1: + limit = Fr + frames = -1 + else: + limit = int(round(Fr * ratio)) + le = limit - start + if start == 0: + le += pad + frames = round((le - layer.kernel_size) / layer.stride + 1) + limit = start + (frames - 1) * layer.stride + layer.kernel_size + if start == 0: + limit -= pad + assert limit - start > 0, (limit, start) + assert limit <= Fr, (limit, Fr) + y = x[:, :, start:limit, :] + if start == 0: + y = F.pad(y, (0, 0, pad, 0)) + if ratio == 1: + y = F.pad(y, (0, 0, 0, pad)) + outs.append(layer(y)) + start = limit - layer.kernel_size + layer.stride + else: + if ratio == 1: + limit = Fr + else: + limit = int(round(Fr * ratio)) + last = layer.last + layer.last = True + + y = x[:, :, start:limit] + s = skip[:, :, start:limit] + out, _ = layer(y, s, None) + if outs: + outs[-1][:, :, -layer.stride:] += ( + out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) + out = out[:, :, layer.stride:] + if ratio == 1: + out = out[:, :, :-layer.stride // 2, :] + if start == 0: + out = out[:, :, layer.stride // 2:, :] + outs.append(out) + layer.last = last + start = limit + out = torch.cat(outs, dim=2) + if not self.conv and not last: + out = F.gelu(out) + if self.conv: + return out + else: + return out, None + + +class HDecLayer(nn.Module): + def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, + context_freq=True, rewrite=True): + """ + Same as HEncLayer but for decoder. See `HEncLayer` for documentation. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + self.pad = pad + self.last = last + self.freq = freq + self.chin = chin + self.empty = empty + self.stride = stride + self.kernel_size = kernel_size + self.norm = norm + self.context_freq = context_freq + klass = nn.Conv1d + klass_tr = nn.ConvTranspose1d + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + klass = nn.Conv2d + klass_tr = nn.ConvTranspose2d + self.conv_tr = klass_tr(chin, chout, kernel_size, stride) + self.norm2 = norm_fn(chout) + if self.empty: + return + self.rewrite = None + if rewrite: + if context_freq: + self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) + else: + self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, + [0, context]) + self.norm1 = norm_fn(2 * chin) + + self.dconv = None + if dconv: + self.dconv = DConv(chin, **dconv_kw) + + def forward(self, x, skip, length): + if self.freq and x.dim() == 3: + B, C, T = x.shape + x = x.view(B, self.chin, -1, T) + + if not self.empty: + x = x + skip + + if self.rewrite: + y = F.glu(self.norm1(self.rewrite(x)), dim=1) + else: + y = x + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + else: + y = x + assert skip is None + z = self.norm2(self.conv_tr(y)) + if self.freq: + if self.pad: + z = z[..., self.pad:-self.pad, :] + else: + z = z[..., self.pad:self.pad + length] + assert z.shape[-1] == length, (z.shape[-1], length) + if not self.last: + z = F.gelu(z) + return z, y + + +class HDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=6, + rewrite=True, + hybrid=True, + hybrid_old=False, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=2, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. + hybrid_old: some models trained for MDX had a padding bug. This replicates + this bug to avoid retraining them. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + rescale: weight recaling trick + + """ + super().__init__() + + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.channels = channels + self.samplerate = samplerate + self.segment = segment + + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + self.hybrid = hybrid + self.hybrid_old = hybrid_old + if hybrid_old: + assert hybrid, "hybrid_old must come with hybrid=True" + if hybrid: + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if hybrid: + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + lstm = index >= dconv_lstm + attn = index >= dconv_attn + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + 'kernel_size': ker, + 'stride': stri, + 'freq': freq, + 'pad': pad, + 'norm': norm, + 'rewrite': rewrite, + 'norm_groups': norm_groups, + 'dconv_kw': { + 'lstm': lstm, + 'attn': attn, + 'depth': dconv_depth, + 'compress': dconv_comp, + 'init': dconv_init, + 'gelu': True, + } + } + kwt = dict(kw) + kwt['freq'] = 0 + kwt['kernel_size'] = kernel_size + kwt['stride'] = stride + kwt['pad'] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec['context_freq'] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer(chin_z, chout_z, + dconv=dconv_mode & 1, context=context_enc, **kw) + if hybrid and freq: + tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, + empty=last_freq, **kwt) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, + last=index == 0, context=context, **kw_dec) + if multi: + dec = MultiWrap(dec, multi_freqs) + if hybrid and freq: + tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, + last=index == 0, context=context, **kwt) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + if self.hybrid: + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + if not self.hybrid_old: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') + else: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1])) + + z = spectro(x, nfft, hl)[..., :-1, :] + if self.hybrid: + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2:2+le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4 ** scale) + z = F.pad(z, (0, 0, 0, 1)) + if self.hybrid: + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + if not self.hybrid_old: + le = hl * int(math.ceil(length / hl)) + 2 * pad + else: + le = hl * int(math.ceil(length / hl)) + x = ispectro(z, hl, length=le) + if not self.hybrid_old: + x = x[..., pad:pad + length] + else: + x = x[..., :length] + else: + x = ispectro(z, hl, length) + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], mix_stft[sample, frame], niters, + residual=residual) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + z = self._spec(mix) + mag = self._magnitude(z) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + if self.hybrid: + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if self.hybrid and idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + + x = torch.zeros_like(x) + if self.hybrid: + xt = torch.zeros_like(x) + # initialize everything to zero (signal will go through u-net skips). + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + if self.hybrid: + offset = self.depth - len(self.tdecoder) + if self.hybrid and idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + zout = self._mask(z, x) + x = self._ispec(zout, length) + + if self.hybrid: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + return x + + diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py new file mode 100644 index 00000000..ffa466bb --- /dev/null +++ b/demucs/htdemucs.py @@ -0,0 +1,648 @@ +# Copyright (c) Meta, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +import math + +from .filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from .transformer import CrossTransformerEncoder + +from .demucs import rescale_module +from .states import capture_init +from .spec import spectro, ispectro +from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=True, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + last=index == 0, + context=context, + **kw_dec + ) + if multi: + dec = MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2: 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad: pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}") + return training_length + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + z = self._spec(mix) + mag = self._magnitude(z) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x diff --git a/demucs/model.py b/demucs/model.py new file mode 100644 index 00000000..e2745b8c --- /dev/null +++ b/demucs/model.py @@ -0,0 +1,218 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch as th +from torch import nn + +from .utils import capture_init, center_trim + + +class BLSTM(nn.Module): + def __init__(self, dim, layers=1): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + + +def upsample(x, stride): + """ + Linear upsampling, the output will be `stride` times longer. + """ + batch, channels, time = x.size() + weight = th.arange(stride, device=x.device, dtype=th.float) / stride + x = x.view(batch, channels, time, 1) + out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight + return out.reshape(batch, channels, -1) + + +def downsample(x, stride): + """ + Downsample x by decimation. + """ + return x[:, :, ::stride] + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources=4, + audio_channels=2, + channels=64, + depth=6, + rewrite=True, + glu=True, + upsample=False, + rescale=0.1, + kernel_size=8, + stride=4, + growth=2., + lstm_layers=2, + context=3, + samplerate=44100): + """ + Args: + sources (int): number of sources to separate + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + rewrite (bool): add 1x1 convolution to each encoder layer + and a convolution to each decoder layer. + For the decoder layer, `context` gives the kernel size. + glu (bool): use glu instead of ReLU + upsample (bool): use linear upsampling with convolutions + Wave-U-Net style, instead of transposed convolutions + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale` + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + lstm_layers (int): number of lstm layers, 0 = no lstm + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time + steps. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.upsample = upsample + self.channels = channels + self.samplerate = samplerate + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.final = None + if upsample: + self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1) + stride = 1 + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + in_channels = audio_channels + for index in range(depth): + encode = [] + encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] + if rewrite: + encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + if upsample: + out_channels = channels + else: + out_channels = sources * audio_channels + if rewrite: + decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] + if upsample: + decode += [ + nn.Conv1d(channels, out_channels, kernel_size, stride=1), + ] + else: + decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length when context = 1. If context > 1, + the two signals can be center trimmed to match. + + For training, extracts should have a valid length.For evaluation + on full tracks we recommend passing `pad = True` to :method:`forward`. + """ + for _ in range(self.depth): + if self.upsample: + length = math.ceil(length / self.stride) + self.kernel_size - 1 + else: + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + length += self.context - 1 + for _ in range(self.depth): + if self.upsample: + length = length * self.stride + self.kernel_size - 1 + else: + length = (length - 1) * self.stride + self.kernel_size + + return int(length) + + def forward(self, mix): + x = mix + saved = [x] + for encode in self.encoder: + x = encode(x) + saved.append(x) + if self.upsample: + x = downsample(x, self.stride) + if self.lstm: + x = self.lstm(x) + for decode in self.decoder: + if self.upsample: + x = upsample(x, stride=self.stride) + skip = center_trim(saved.pop(-1), x) + x = x + skip + x = decode(x) + if self.final: + skip = center_trim(saved.pop(-1), x) + x = th.cat([x, skip], dim=1) + x = self.final(x) + + x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1)) + return x diff --git a/demucs/model_v2.py b/demucs/model_v2.py new file mode 100644 index 00000000..db43fc5e --- /dev/null +++ b/demucs/model_v2.py @@ -0,0 +1,218 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import julius +from torch import nn +from .tasnet_v2 import ConvTasNet + +from .utils import capture_init, center_trim + + +class BLSTM(nn.Module): + def __init__(self, dim, layers=1): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + +def auto_load_demucs_model_v2(sources, demucs_model_name): + + if '48' in demucs_model_name: + channels=48 + elif 'unittest' in demucs_model_name: + channels=4 + else: + channels=64 + + if 'tasnet' in demucs_model_name: + init_demucs_model = ConvTasNet(sources, X=10) + else: + init_demucs_model = Demucs(sources, channels=channels) + + return init_demucs_model + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + audio_channels=2, + channels=64, + depth=6, + rewrite=True, + glu=True, + rescale=0.1, + resample=True, + kernel_size=8, + stride=4, + growth=2., + lstm_layers=2, + context=3, + normalize=False, + samplerate=44100, + segment_length=4 * 10 * 44100): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + rewrite (bool): add 1x1 convolution to each encoder layer + and a convolution to each decoder layer. + For the decoder layer, `context` gives the kernel size. + glu (bool): use glu instead of ReLU + resample_input (bool): upsample x2 the input and downsample /2 the output. + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale` + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + lstm_layers (int): number of lstm layers, 0 = no lstm + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time + steps. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment_length (int): stored as meta information for easing + future evaluations of the model. Length of the segments on which + the model was trained. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment_length = segment_length + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + in_channels = audio_channels + for index in range(depth): + encode = [] + encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] + if rewrite: + encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] + decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length when context = 1. If context > 1, + the two signals can be center trimmed to match. + + For training, extracts should have a valid length.For evaluation + on full tracks we recommend passing `pad = True` to :method:`forward`. + """ + if self.resample: + length *= 2 + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + length += self.context - 1 + for _ in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + else: + mean = 0 + std = 1 + + x = (x - mean) / (1e-5 + std) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + if self.lstm: + x = self.lstm(x) + for decode in self.decoder: + skip = center_trim(saved.pop(-1), x) + x = x + skip + x = decode(x) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x diff --git a/demucs/pretrained.py b/demucs/pretrained.py new file mode 100644 index 00000000..ee5bc730 --- /dev/null +++ b/demucs/pretrained.py @@ -0,0 +1,180 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading pretrained models. +""" + +import logging +from pathlib import Path +import typing as tp + +from dora.log import fatal + +import logging + +from diffq import DiffQuantizer +import torch.hub + +from .model import Demucs +from .tasnet_v2 import ConvTasNet +from .utils import set_state + +from .hdemucs import HDemucs +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa + +logger = logging.getLogger(__name__) +ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" +REMOTE_ROOT = Path(__file__).parent / 'remote' + +SOURCES = ["drums", "bass", "other", "vocals"] + + +def demucs_unittest(): + model = HDemucs(channels=4, sources=SOURCES) + return model + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("-s", "--sig", help="Locally trained XP signature.") + group.add_argument("-n", "--name", default="mdx_extra_q", + help="Pretrained model name or signature. Default is mdx_extra_q.") + parser.add_argument("--repo", type=Path, + help="Folder containing all pre-trained models for use with -n.") + + +def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: + root: str = '' + models: tp.Dict[str, str] = {} + for line in remote_file_list.read_text().split('\n'): + line = line.strip() + if line.startswith('#'): + continue + elif line.startswith('root:'): + root = line.split(':', 1)[1].strip() + else: + sig = line.split('-', 1)[0] + assert sig not in models + models[sig] = ROOT_URL + root + line + return models + +def get_model(name: str, + repo: tp.Optional[Path] = None): + """`name` must be a bag of models name or a pretrained signature + from the remote AWS model repo or the specified local repo if `repo` is not None. + """ + if name == 'demucs_unittest': + return demucs_unittest() + model_repo: ModelOnlyRepo + if repo is None: + models = _parse_remote_files(REMOTE_ROOT / 'files.txt') + model_repo = RemoteRepo(models) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + any_repo = AnyModelRepo(model_repo, bag_repo) + model = any_repo.get_model(name) + model.eval() + return model + +def get_model_from_args(args): + """ + Load local model package or pre-trained model. + """ + return get_model(name=args.name, repo=args.repo) + +logger = logging.getLogger(__name__) +ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/" + +PRETRAINED_MODELS = { + 'demucs': 'e07c671f', + 'demucs48_hq': '28a1282c', + 'demucs_extra': '3646af93', + 'demucs_quantized': '07afea75', + 'tasnet': 'beb46fac', + 'tasnet_extra': 'df3777b2', + 'demucs_unittest': '09ebc15f', +} + +SOURCES = ["drums", "bass", "other", "vocals"] + + +def get_url(name): + sig = PRETRAINED_MODELS[name] + return ROOT + name + "-" + sig[:8] + ".th" + +def is_pretrained(name): + return name in PRETRAINED_MODELS + + +def load_pretrained(name): + if name == "demucs": + return demucs(pretrained=True) + elif name == "demucs48_hq": + return demucs(pretrained=True, hq=True, channels=48) + elif name == "demucs_extra": + return demucs(pretrained=True, extra=True) + elif name == "demucs_quantized": + return demucs(pretrained=True, quantized=True) + elif name == "demucs_unittest": + return demucs_unittest(pretrained=True) + elif name == "tasnet": + return tasnet(pretrained=True) + elif name == "tasnet_extra": + return tasnet(pretrained=True, extra=True) + else: + raise ValueError(f"Invalid pretrained name {name}") + + +def _load_state(name, model, quantizer=None): + url = get_url(name) + state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) + set_state(model, quantizer, state) + if quantizer: + quantizer.detach() + + +def demucs_unittest(pretrained=True): + model = Demucs(channels=4, sources=SOURCES) + if pretrained: + _load_state('demucs_unittest', model) + return model + + +def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64): + if not pretrained and (extra or quantized or hq): + raise ValueError("if extra or quantized is True, pretrained must be True.") + model = Demucs(sources=SOURCES, channels=channels) + if pretrained: + name = 'demucs' + if channels != 64: + name += str(channels) + quantizer = None + if sum([extra, quantized, hq]) > 1: + raise ValueError("Only one of extra, quantized, hq, can be True.") + if quantized: + quantizer = DiffQuantizer(model, group_size=8, min_size=1) + name += '_quantized' + if extra: + name += '_extra' + if hq: + name += '_hq' + _load_state(name, model, quantizer) + return model + + +def tasnet(pretrained=True, extra=False): + if not pretrained and extra: + raise ValueError("if extra is True, pretrained must be True.") + model = ConvTasNet(X=10, sources=SOURCES) + if pretrained: + name = 'tasnet' + if extra: + name = 'tasnet_extra' + _load_state(name, model) + return model \ No newline at end of file diff --git a/demucs/repo.py b/demucs/repo.py new file mode 100644 index 00000000..65ff6b33 --- /dev/null +++ b/demucs/repo.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Represents a model repository, including pre-trained models and bags of models. +A repo can either be the main remote repository stored in AWS, or a local repository +with your own models. +""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import yaml + +from .apply import BagOfModels, Model +from .states import load_model + + +AnyModel = tp.Union[Model, BagOfModels] + + +class ModelLoadingError(RuntimeError): + pass + + +def check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise ModelLoadingError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + +class ModelOnlyRepo: + """Base class for all model only repos. + """ + def has_model(self, sig: str) -> bool: + raise NotImplementedError() + + def get_model(self, sig: str) -> Model: + raise NotImplementedError() + + +class RemoteRepo(ModelOnlyRepo): + def __init__(self, models: tp.Dict[str, str]): + self._models = models + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + url = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') + pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) + return load_model(pkg) + + +class LocalRepo(ModelOnlyRepo): + def __init__(self, root: Path): + self.root = root + self.scan() + + def scan(self): + self._models = {} + self._checksums = {} + for file in self.root.iterdir(): + if file.suffix == '.th': + if '-' in file.stem: + xp_sig, checksum = file.stem.split('-') + self._checksums[xp_sig] = checksum + else: + xp_sig = file.stem + if xp_sig in self._models: + print('Whats xp? ', xp_sig) + raise ModelLoadingError( + f'Duplicate pre-trained model exist for signature {xp_sig}. ' + 'Please delete all but one.') + self._models[xp_sig] = file + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + file = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') + if sig in self._checksums: + check_checksum(file, self._checksums[sig]) + return load_model(file) + + +class BagOnlyRepo: + """Handles only YAML files containing bag of models, leaving the actual + model loading to some Repo. + """ + def __init__(self, root: Path, model_repo: ModelOnlyRepo): + self.root = root + self.model_repo = model_repo + self.scan() + + def scan(self): + self._bags = {} + for file in self.root.iterdir(): + if file.suffix == '.yaml': + self._bags[file.stem] = file + + def has_model(self, name: str) -> bool: + return name in self._bags + + def get_model(self, name: str) -> BagOfModels: + try: + yaml_file = self._bags[name] + except KeyError: + raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' + 'a bag of models.') + bag = yaml.safe_load(open(yaml_file)) + signatures = bag['models'] + models = [self.model_repo.get_model(sig) for sig in signatures] + weights = bag.get('weights') + segment = bag.get('segment') + return BagOfModels(models, weights, segment) + + +class AnyModelRepo: + def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): + self.model_repo = model_repo + self.bag_repo = bag_repo + + def has_model(self, name_or_sig: str) -> bool: + return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) + + def get_model(self, name_or_sig: str) -> AnyModel: + print('name_or_sig: ', name_or_sig) + if self.model_repo.has_model(name_or_sig): + return self.model_repo.get_model(name_or_sig) + else: + return self.bag_repo.get_model(name_or_sig) diff --git a/demucs/spec.py b/demucs/spec.py new file mode 100644 index 00000000..85e5dc93 --- /dev/null +++ b/demucs/spec.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Conveniance wrapper to perform STFT and iSTFT""" + +import torch as th + + +def spectro(x, n_fft=512, hop_length=None, pad=0): + *other, length = x.shape + x = x.reshape(-1, length) + z = th.stft(x, + n_fft * (1 + pad), + hop_length or n_fft // 4, + window=th.hann_window(n_fft).to(x), + win_length=n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode='reflect') + _, freqs, frame = z.shape + return z.view(*other, freqs, frame) + + +def ispectro(z, hop_length=None, length=None, pad=0): + *other, freqs, frames = z.shape + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + pad) + x = th.istft(z, + n_fft, + hop_length, + window=th.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True) + _, length = x.shape + return x.view(*other, length) diff --git a/demucs/states.py b/demucs/states.py new file mode 100644 index 00000000..db17a182 --- /dev/null +++ b/demucs/states.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Utilities to save and load models. +""" +from contextlib import contextmanager + +import functools +import hashlib +import inspect +import io +from pathlib import Path +import warnings + +from omegaconf import OmegaConf +from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state +import torch + + +def get_quantizer(model, args, optimizer=None): + """Return the quantizer given the XP quantization args.""" + quantizer = None + if args.diffq: + quantizer = DiffQuantizer( + model, min_size=args.min_size, group_size=args.group_size) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.min_size) + return quantizer + + +def load_model(path_or_package, strict=False): + """Load a model from the given serialized model, either given as a dict (already loaded) + or a path to a file on disk.""" + if isinstance(path_or_package, dict): + package = path_or_package + elif isinstance(path_or_package, (str, Path)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + path = path_or_package + package = torch.load(path, 'cpu') + else: + raise ValueError(f"Invalid type for {path_or_package}.") + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + + set_state(model, state) + return model + + +def get_state(model, quantizer, half=False): + """Get the state from a model, potentially with quantization applied. + If `half` is True, model are stored as half precision, which shouldn't impact performance + but half the state size.""" + if quantizer is None: + dtype = torch.half if half else None + state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + state['__quantized'] = True + return state + + +def set_state(model, state, quantizer=None): + """Set the state on a given model.""" + if state.get('__quantized'): + if quantizer is not None: + quantizer.restore_quantized_state(model, state['quantized']) + else: + restore_quantized_state(model, state) + else: + model.load_state_dict(state) + return state + + +def save_with_checksum(content, path): + """Save the given value on disk, along with a sha256 hash. + Should be used with the output of either `serialize_model` or `get_state`.""" + buf = io.BytesIO() + torch.save(content, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def serialize_model(model, training_args, quantizer=None, half=True): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer, half) + return { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': OmegaConf.to_container(training_args, resolve=True), + } + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, strict=False) + try: + yield + finally: + model.load_state_dict(old_state) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ diff --git a/demucs/tasnet.py b/demucs/tasnet.py new file mode 100644 index 00000000..9cb7a957 --- /dev/null +++ b/demucs/tasnet.py @@ -0,0 +1,447 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Created on 2018/12 +# Author: Kaituo XU +# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels +# Here is the original license: +# The MIT License (MIT) +# +# Copyright (c) 2018 Kaituo XU +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import capture_init + +EPS = 1e-8 + + +def overlap_and_add(signal, frame_step): + outer_dimensions = signal.size()[:-2] + frames, frame_length = signal.size()[-2:] + + subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) + + frame = torch.arange(0, output_subframes, + device=signal.device).unfold(0, subframes_per_frame, subframe_step) + frame = frame.long() # signal may in GPU or CPU + frame = frame.contiguous().view(-1) + + result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) + result.index_add_(-2, frame, subframe_signal) + result = result.view(*outer_dimensions, -1) + return result + + +class ConvTasNet(nn.Module): + @capture_init + def __init__(self, + N=256, + L=20, + B=256, + H=512, + P=3, + X=8, + R=4, + C=4, + audio_channels=1, + samplerate=44100, + norm_type="gLN", + causal=False, + mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + L: Length of the filters (in samples) + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(ConvTasNet, self).__init__() + # Hyper-parameter + self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + self.audio_channels = audio_channels + self.samplerate = samplerate + # Components + self.encoder = Encoder(L, N, audio_channels) + self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear) + self.decoder = Decoder(N, L, audio_channels) + # init + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def valid_length(self, length): + return length + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + est_source: [M, C, T] + """ + mixture_w = self.encoder(mixture) + est_mask = self.separator(mixture_w) + est_source = self.decoder(mixture_w, est_mask) + + # T changed after conv1d in encoder, fix it here + T_origin = mixture.size(-1) + T_conv = est_source.size(-1) + est_source = F.pad(est_source, (0, T_origin - T_conv)) + return est_source + + +class Encoder(nn.Module): + """Estimation of the nonnegative mixture weight by a 1-D conv layer. + """ + def __init__(self, L, N, audio_channels): + super(Encoder, self).__init__() + # Hyper-parameter + self.L, self.N = L, N + # Components + # 50% overlap + self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 + """ + mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] + return mixture_w + + +class Decoder(nn.Module): + def __init__(self, N, L, audio_channels): + super(Decoder, self).__init__() + # Hyper-parameter + self.N, self.L = N, L + self.audio_channels = audio_channels + # Components + self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) + + def forward(self, mixture_w, est_mask): + """ + Args: + mixture_w: [M, N, K] + est_mask: [M, C, N, K] + Returns: + est_source: [M, C, T] + """ + # D = W * M + source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] + source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] + # S = DV + est_source = self.basis_signals(source_w) # [M, C, K, ac * L] + m, c, k, _ = est_source.size() + est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() + est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T + return est_source + + +class TemporalConvNet(nn.Module): + def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(TemporalConvNet, self).__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2**x + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock(B, + H, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) + # Put together + self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, + mask_conv1x1) + + def forward(self, mixture_w): + """ + Keep this API same with TasNet + Args: + mixture_w: [M, N, K], M is batch size + returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] + score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] + if self.mask_nonlinear == 'softmax': + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == 'relu': + est_mask = F.relu(score) + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + + +class TemporalBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(TemporalBlock, self).__init__() + # [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = chose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, + dilation, norm_type, causal) + # Put together + self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """ + Args: + x: [M, B, K] + Returns: + [M, B, K] + """ + residual = x + out = self.net(x) + # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? + return out + residual # look like w/o F.relu is better than w/ F.relu + # return F.relu(out + residual) + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d(in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = chose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + def forward(self, x): + """ + Args: + x: [M, H, K] + Returns: + result: [M, B, K] + """ + return self.net(x) + + +class Chomp1d(nn.Module): + """To ensure the output length is the same as the input. + """ + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + """ + Args: + x: [M, H, Kpad] + Returns: + [M, H, K] + """ + return x[:, :, :-self.chomp_size].contiguous() + + +def chose_norm(norm_type, channel_size): + """The input of normlization will be (M, C, K), where M is batch size, + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size) + elif norm_type == "id": + return nn.Identity() + else: # norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channel_size) + + +# TODO: Use nn.LayerNorm to impl cLN to speed up +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN)""" + def __init__(self, channel_size): + super(ChannelwiseLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + cLN_y: [M, N, K] + """ + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return cLN_y + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN)""" + def __init__(self, channel_size): + super(GlobalLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + gLN_y: [M, N, K] + """ + # TODO: in torch 1.0, torch.mean() support dim list + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] + var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return gLN_y + + +if __name__ == "__main__": + torch.manual_seed(123) + M, N, L, T = 2, 3, 4, 12 + K = 2 * T // L - 1 + B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False + mixture = torch.randint(3, (M, T)) + # test Encoder + encoder = Encoder(L, N) + encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) + mixture_w = encoder(mixture) + print('mixture', mixture) + print('U', encoder.conv1d_U.weight) + print('mixture_w', mixture_w) + print('mixture_w size', mixture_w.size()) + + # test TemporalConvNet + separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) + est_mask = separator(mixture_w) + print('est_mask', est_mask) + + # test Decoder + decoder = Decoder(N, L) + est_mask = torch.randint(2, (B, K, C, N)) + est_source = decoder(mixture_w, est_mask) + print('est_source', est_source) + + # test Conv-TasNet + conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) + est_source = conv_tasnet(mixture) + print('est_source', est_source) + print('est_source size', est_source.size()) diff --git a/demucs/tasnet_v2.py b/demucs/tasnet_v2.py new file mode 100644 index 00000000..ecc12579 --- /dev/null +++ b/demucs/tasnet_v2.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Created on 2018/12 +# Author: Kaituo XU +# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels +# Here is the original license: +# The MIT License (MIT) +# +# Copyright (c) 2018 Kaituo XU +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import capture_init + +EPS = 1e-8 + + +def overlap_and_add(signal, frame_step): + outer_dimensions = signal.size()[:-2] + frames, frame_length = signal.size()[-2:] + + subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) + + frame = torch.arange(0, output_subframes, + device=signal.device).unfold(0, subframes_per_frame, subframe_step) + frame = frame.long() # signal may in GPU or CPU + frame = frame.contiguous().view(-1) + + result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) + result.index_add_(-2, frame, subframe_signal) + result = result.view(*outer_dimensions, -1) + return result + + +class ConvTasNet(nn.Module): + @capture_init + def __init__(self, + sources, + N=256, + L=20, + B=256, + H=512, + P=3, + X=8, + R=4, + audio_channels=2, + norm_type="gLN", + causal=False, + mask_nonlinear='relu', + samplerate=44100, + segment_length=44100 * 2 * 4): + """ + Args: + sources: list of sources + N: Number of filters in autoencoder + L: Length of the filters (in samples) + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(ConvTasNet, self).__init__() + # Hyper-parameter + self.sources = sources + self.C = len(sources) + self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + self.audio_channels = audio_channels + self.samplerate = samplerate + self.segment_length = segment_length + # Components + self.encoder = Encoder(L, N, audio_channels) + self.separator = TemporalConvNet( + N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear) + self.decoder = Decoder(N, L, audio_channels) + # init + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def valid_length(self, length): + return length + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + est_source: [M, C, T] + """ + mixture_w = self.encoder(mixture) + est_mask = self.separator(mixture_w) + est_source = self.decoder(mixture_w, est_mask) + + # T changed after conv1d in encoder, fix it here + T_origin = mixture.size(-1) + T_conv = est_source.size(-1) + est_source = F.pad(est_source, (0, T_origin - T_conv)) + return est_source + + +class Encoder(nn.Module): + """Estimation of the nonnegative mixture weight by a 1-D conv layer. + """ + def __init__(self, L, N, audio_channels): + super(Encoder, self).__init__() + # Hyper-parameter + self.L, self.N = L, N + # Components + # 50% overlap + self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 + """ + mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] + return mixture_w + + +class Decoder(nn.Module): + def __init__(self, N, L, audio_channels): + super(Decoder, self).__init__() + # Hyper-parameter + self.N, self.L = N, L + self.audio_channels = audio_channels + # Components + self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) + + def forward(self, mixture_w, est_mask): + """ + Args: + mixture_w: [M, N, K] + est_mask: [M, C, N, K] + Returns: + est_source: [M, C, T] + """ + # D = W * M + source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] + source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] + # S = DV + est_source = self.basis_signals(source_w) # [M, C, K, ac * L] + m, c, k, _ = est_source.size() + est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() + est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T + return est_source + + +class TemporalConvNet(nn.Module): + def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(TemporalConvNet, self).__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2**x + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock(B, + H, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) + # Put together + self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, + mask_conv1x1) + + def forward(self, mixture_w): + """ + Keep this API same with TasNet + Args: + mixture_w: [M, N, K], M is batch size + returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] + score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] + if self.mask_nonlinear == 'softmax': + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == 'relu': + est_mask = F.relu(score) + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + + +class TemporalBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(TemporalBlock, self).__init__() + # [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = chose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, + dilation, norm_type, causal) + # Put together + self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """ + Args: + x: [M, B, K] + Returns: + [M, B, K] + """ + residual = x + out = self.net(x) + # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? + return out + residual # look like w/o F.relu is better than w/ F.relu + # return F.relu(out + residual) + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d(in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = chose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + def forward(self, x): + """ + Args: + x: [M, H, K] + Returns: + result: [M, B, K] + """ + return self.net(x) + + +class Chomp1d(nn.Module): + """To ensure the output length is the same as the input. + """ + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + """ + Args: + x: [M, H, Kpad] + Returns: + [M, H, K] + """ + return x[:, :, :-self.chomp_size].contiguous() + + +def chose_norm(norm_type, channel_size): + """The input of normlization will be (M, C, K), where M is batch size, + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size) + elif norm_type == "id": + return nn.Identity() + else: # norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channel_size) + + +# TODO: Use nn.LayerNorm to impl cLN to speed up +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN)""" + def __init__(self, channel_size): + super(ChannelwiseLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + cLN_y: [M, N, K] + """ + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return cLN_y + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN)""" + def __init__(self, channel_size): + super(GlobalLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + gLN_y: [M, N, K] + """ + # TODO: in torch 1.0, torch.mean() support dim list + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] + var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return gLN_y + + +if __name__ == "__main__": + torch.manual_seed(123) + M, N, L, T = 2, 3, 4, 12 + K = 2 * T // L - 1 + B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False + mixture = torch.randint(3, (M, T)) + # test Encoder + encoder = Encoder(L, N) + encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) + mixture_w = encoder(mixture) + print('mixture', mixture) + print('U', encoder.conv1d_U.weight) + print('mixture_w', mixture_w) + print('mixture_w size', mixture_w.size()) + + # test TemporalConvNet + separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) + est_mask = separator(mixture_w) + print('est_mask', est_mask) + + # test Decoder + decoder = Decoder(N, L) + est_mask = torch.randint(2, (B, K, C, N)) + est_source = decoder(mixture_w, est_mask) + print('est_source', est_source) + + # test Conv-TasNet + conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) + est_source = conv_tasnet(mixture) + print('est_source', est_source) + print('est_source size', est_source.size()) diff --git a/demucs/transformer.py b/demucs/transformer.py new file mode 100644 index 00000000..56a465b8 --- /dev/null +++ b/demucs/transformer.py @@ -0,0 +1,839 @@ +# Copyright (c) 2019-present, Meta, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. + +import random +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from einops import rearrange + + +def create_sin_embedding( + length: int, dim: int, shift: int = 0, device="cpu", max_period=10000 +): + # We aim for TBC format + assert dim % 2 == 0 + pos = shift + torch.arange(length, device=device).view(-1, 1, 1) + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ) + + +def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000): + """ + :param d_model: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if d_model % 4 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(d_model) + ) + pe = torch.zeros(d_model, height, width) + # Each dimension use half of d_model + d_model = int(d_model / 2) + div_term = torch.exp( + torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model) + ) + pos_w = torch.arange(0.0, width).unsqueeze(1) + pos_h = torch.arange(0.0, height).unsqueeze(1) + pe[0:d_model:2, :, :] = ( + torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[1:d_model:2, :, :] = ( + torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[d_model::2, :, :] = ( + torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + pe[d_model + 1:: 2, :, :] = ( + torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + + return pe[None, :].to(device) + + +def create_sin_embedding_cape( + length: int, + dim: int, + batch_size: int, + mean_normalize: bool, + augment: bool, # True during training + max_global_shift: float = 0.0, # delta max + max_local_shift: float = 0.0, # epsilon max + max_scale: float = 1.0, + device: str = "cpu", + max_period: float = 10000.0, +): + # We aim for TBC format + assert dim % 2 == 0 + pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1) + pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1) + if mean_normalize: + pos -= torch.nanmean(pos, dim=0, keepdim=True) + + if augment: + delta = np.random.uniform( + -max_global_shift, +max_global_shift, size=[1, batch_size, 1] + ) + delta_local = np.random.uniform( + -max_local_shift, +max_local_shift, size=[length, batch_size, 1] + ) + log_lambdas = np.random.uniform( + -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1] + ) + pos = (pos + delta + delta_local) * np.exp(log_lambdas) + + pos = pos.to(device) + + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ).float() + + +def get_causal_mask(length): + pos = torch.arange(length) + return pos > pos[:, None] + + +def get_elementary_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + When the input of the Decoder has length T1 and the output T2 + The mask matrix has shape (T2, T1) + """ + assert mask_type in ["diag", "jmask", "random", "global"] + + if mask_type == "global": + mask = torch.zeros(T2, T1, dtype=torch.bool) + mask[:, :global_window] = True + line_window = int(global_window * T2 / T1) + mask[:line_window, :] = True + + if mask_type == "diag": + + mask = torch.zeros(T2, T1, dtype=torch.bool) + rows = torch.arange(T2)[:, None] + cols = ( + (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)) + .long() + .clamp(0, T1 - 1) + ) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + + elif mask_type == "jmask": + mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool) + rows = torch.arange(T2 + 2)[:, None] + t = torch.arange(0, int((2 * T1) ** 0.5 + 1)) + t = (t * (t + 1) / 2).int() + t = torch.cat([-t.flip(0)[:-1], t]) + cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + mask = mask[1:-1, 1:-1] + + elif mask_type == "random": + gene = torch.Generator(device=device) + gene.manual_seed(mask_random_seed) + mask = ( + torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) + > sparsity + ) + + mask = mask.to(device) + return mask + + +def get_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + Return a SparseCSRTensor mask that is a combination of elementary masks + mask_type can be a combination of multiple masks: for instance "diag_jmask_random" + """ + from xformers.sparse import SparseCSRTensor + # create a list + mask_types = mask_type.split("_") + + all_masks = [ + get_elementary_mask( + T1, + T2, + mask, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, + ) + for mask in mask_types + ] + + final_mask = torch.stack(all_masks).sum(axis=0) > 0 + + return SparseCSRTensor.from_dense(final_mask[None]) + + +class ScaledEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + scale: float = 1.0, + boost: float = 3.0, + ): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data *= scale / boost + self.boost = boost + + @property + def weight(self): + return self.embedding.weight * self.boost + + def forward(self, x): + return self.embedding(x) * self.boost + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. + """ + + def __init__(self, channels: int, init: float = 0, channel_last=False): + """ + channel_last = False corresponds to (B, C, T) tensors + channel_last = True corresponds to (T, B, C) tensors + """ + super().__init__() + self.channel_last = channel_last + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + if self.channel_last: + return self.scale * x + else: + return self.scale[:, None] * x + + +class MyGroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + """ + x: (B, T, C) + if num_groups=1: Normalisation on all T and C together for each B + """ + x = x.transpose(1, 2) + return super().forward(x).transpose(1, 2) + + +class MyTransformerEncoderLayer(nn.TransformerEncoderLayer): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation=F.relu, + group_norm=0, + norm_first=False, + norm_out=False, + layer_norm_eps=1e-5, + layer_scale=False, + init_values=1e-4, + device=None, + dtype=None, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + auto_sparsity=False, + sparsity=0.95, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + batch_first=batch_first, + norm_first=norm_first, + device=device, + dtype=dtype, + ) + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + if sparse: + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0, + ) + self.__setattr__("src_mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """ + if batch_first = False, src shape is (T, B, C) + the case where batch_first=True is not covered + """ + device = src.device + x = src + T, B, C = x.shape + if self.sparse and not self.auto_sparsity: + assert src_mask is None + src_mask = self.src_mask + if src_mask.shape[-1] != T: + src_mask = get_mask( + T, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("src_mask", src_mask) + + if self.norm_first: + x = x + self.gamma_1( + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + ) + x = x + self.gamma_2(self._ff_block(self.norm2(x))) + + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1( + x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)) + ) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation=F.relu, + layer_norm_eps: float = 1e-5, + layer_scale: bool = False, + init_values: float = 1e-4, + norm_first: bool = False, + group_norm: bool = False, + norm_out: bool = False, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + sparsity=0.95, + auto_sparsity=None, + device=None, + dtype=None, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + + self.cross_attn: nn.Module + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1: nn.Module + self.norm2: nn.Module + self.norm3: nn.Module + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + else: + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = self._get_activation_fn(activation) + else: + self.activation = activation + + if sparse: + self.cross_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0) + if not auto_sparsity: + self.__setattr__("mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, q, k, mask=None): + """ + Args: + q: tensor of shape (T, B, C) + k: tensor of shape (S, B, C) + mask: tensor of shape (T, S) + + """ + device = q.device + T, B, C = q.shape + S, B, C = k.shape + if self.sparse and not self.auto_sparsity: + assert mask is None + mask = self.mask + if mask.shape[-1] != S or mask.shape[-2] != T: + mask = get_mask( + S, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("mask", mask) + + if self.norm_first: + x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask)) + x = x + self.gamma_2(self._ff_block(self.norm3(x))) + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask))) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + # self-attention block + def _ca_block(self, q, k, attn_mask=None): + x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + def _get_activation_fn(self, activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +# ----------------- MULTI-BLOCKS MODELS: ----------------------- + + +class CrossTransformerEncoder(nn.Module): + def __init__( + self, + dim: int, + emb: str = "sin", + hidden_scale: float = 4.0, + num_heads: int = 8, + num_layers: int = 6, + cross_first: bool = False, + dropout: float = 0.0, + max_positions: int = 1000, + norm_in: bool = True, + norm_in_group: bool = False, + group_norm: int = False, + norm_first: bool = False, + norm_out: bool = False, + max_period: float = 10000.0, + weight_decay: float = 0.0, + lr: tp.Optional[float] = None, + layer_scale: bool = False, + gelu: bool = True, + sin_random_shift: int = 0, + weight_pos_embed: float = 1.0, + cape_mean_normalize: bool = True, + cape_augment: bool = True, + cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], + sparse_self_attn: bool = False, + sparse_cross_attn: bool = False, + mask_type: str = "diag", + mask_random_seed: int = 42, + sparse_attn_window: int = 500, + global_window: int = 50, + auto_sparsity: bool = False, + sparsity: float = 0.95, + ): + super().__init__() + """ + """ + assert dim % num_heads == 0 + + hidden_dim = int(dim * hidden_scale) + + self.num_layers = num_layers + # classic parity = 1 means that if idx%2 == 1 there is a + # classical encoder else there is a cross encoder + self.classic_parity = 1 if cross_first else 0 + self.emb = emb + self.max_period = max_period + self.weight_decay = weight_decay + self.weight_pos_embed = weight_pos_embed + self.sin_random_shift = sin_random_shift + if emb == "cape": + self.cape_mean_normalize = cape_mean_normalize + self.cape_augment = cape_augment + self.cape_glob_loc_scale = cape_glob_loc_scale + if emb == "scaled": + self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2) + + self.lr = lr + + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + self.norm_in_t: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + self.norm_in_t = nn.LayerNorm(dim) + elif norm_in_group: + self.norm_in = MyGroupNorm(int(norm_in_group), dim) + self.norm_in_t = MyGroupNorm(int(norm_in_group), dim) + else: + self.norm_in = nn.Identity() + self.norm_in_t = nn.Identity() + + # spectrogram layers + self.layers = nn.ModuleList() + # temporal layers + self.layers_t = nn.ModuleList() + + kwargs_common = { + "d_model": dim, + "nhead": num_heads, + "dim_feedforward": hidden_dim, + "dropout": dropout, + "activation": activation, + "group_norm": group_norm, + "norm_first": norm_first, + "norm_out": norm_out, + "layer_scale": layer_scale, + "mask_type": mask_type, + "mask_random_seed": mask_random_seed, + "sparse_attn_window": sparse_attn_window, + "global_window": global_window, + "sparsity": sparsity, + "auto_sparsity": auto_sparsity, + "batch_first": True, + } + + kwargs_classic_encoder = dict(kwargs_common) + kwargs_classic_encoder.update({ + "sparse": sparse_self_attn, + }) + kwargs_cross_encoder = dict(kwargs_common) + kwargs_cross_encoder.update({ + "sparse": sparse_cross_attn, + }) + + for idx in range(num_layers): + if idx % 2 == self.classic_parity: + + self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) + self.layers_t.append( + MyTransformerEncoderLayer(**kwargs_classic_encoder) + ) + + else: + self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) + + self.layers_t.append( + CrossTransformerEncoderLayer(**kwargs_cross_encoder) + ) + + def forward(self, x, xt): + B, C, Fr, T1 = x.shape + pos_emb_2d = create_2d_sin_embedding( + C, Fr, T1, x.device, self.max_period + ) # (1, C, Fr, T1) + pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c") + x = rearrange(x, "b c fr t1 -> b (t1 fr) c") + x = self.norm_in(x) + x = x + self.weight_pos_embed * pos_emb_2d + + B, C, T2 = xt.shape + xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C + pos_emb = self._get_pos_embedding(T2, B, C, x.device) + pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c") + xt = self.norm_in_t(xt) + xt = xt + self.weight_pos_embed * pos_emb + + for idx in range(self.num_layers): + if idx % 2 == self.classic_parity: + x = self.layers[idx](x) + xt = self.layers_t[idx](xt) + else: + old_x = x + x = self.layers[idx](x, xt) + xt = self.layers_t[idx](xt, old_x) + + x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1) + xt = rearrange(xt, "b t2 c -> b c t2") + return x, xt + + def _get_pos_embedding(self, T, B, C, device): + if self.emb == "sin": + shift = random.randrange(self.sin_random_shift + 1) + pos_emb = create_sin_embedding( + T, C, shift=shift, device=device, max_period=self.max_period + ) + elif self.emb == "cape": + if self.training: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=self.cape_augment, + max_global_shift=self.cape_glob_loc_scale[0], + max_local_shift=self.cape_glob_loc_scale[1], + max_scale=self.cape_glob_loc_scale[2], + ) + else: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=False, + ) + + elif self.emb == "scaled": + pos = torch.arange(T, device=device) + pos_emb = self.position_embeddings(pos)[:, None] + + return pos_emb + + def make_optim_group(self): + group = {"params": list(self.parameters()), "weight_decay": self.weight_decay} + if self.lr is not None: + group["lr"] = self.lr + return group + + +# Attention Modules + + +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + auto_sparsity=None, + ): + super().__init__() + assert auto_sparsity is not None, "sanity check" + self.num_heads = num_heads + self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn_drop = torch.nn.Dropout(dropout) + self.proj = torch.nn.Linear(embed_dim, embed_dim, bias) + self.proj_drop = torch.nn.Dropout(dropout) + self.batch_first = batch_first + self.auto_sparsity = auto_sparsity + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + ): + + if not self.batch_first: # N, B, C + query = query.permute(1, 0, 2) # B, N_q, C + key = key.permute(1, 0, 2) # B, N_k, C + value = value.permute(1, 0, 2) # B, N_k, C + B, N_q, C = query.shape + B, N_k, C = key.shape + + q = ( + self.q(query) + .reshape(B, N_q, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + q = q.flatten(0, 1) + k = ( + self.k(key) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = k.flatten(0, 1) + v = ( + self.v(value) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = v.flatten(0, 1) + + if self.auto_sparsity: + assert attn_mask is None + x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity) + else: + x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop) + x = x.reshape(B, self.num_heads, N_q, C // self.num_heads) + + x = x.transpose(1, 2).reshape(B, N_q, C) + x = self.proj(x) + x = self.proj_drop(x) + if not self.batch_first: + x = x.permute(1, 0, 2) + return x, None + + +def scaled_query_key_softmax(q, k, att_mask): + from xformers.ops import masked_matmul + q = q / (k.size(-1)) ** 0.5 + att = masked_matmul(q, k.transpose(-2, -1), att_mask) + att = torch.nn.functional.softmax(att, -1) + return att + + +def scaled_dot_product_attention(q, k, v, att_mask, dropout): + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + att = dropout(att) + y = att @ v + return y + + +def _compute_buckets(x, R): + qq = torch.einsum('btf,bfhi->bhti', x, R) + qq = torch.cat([qq, -qq], dim=-1) + buckets = qq.argmax(dim=-1) + + return buckets.permute(0, 2, 1).byte().contiguous() + + +def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None): + # assert False, "The code for the custom sparse kernel is not ready for release yet." + from xformers.ops import find_locations, sparse_memory_efficient_attention + n_hashes = 32 + proj_size = 4 + query, key, value = [x.contiguous() for x in [query, key, value]] + with torch.no_grad(): + R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device) + bucket_query = _compute_buckets(query, R) + bucket_key = _compute_buckets(key, R) + row_offsets, column_indices = find_locations( + bucket_query, bucket_key, sparsity, infer_sparsity) + return sparse_memory_efficient_attention( + query, key, value, row_offsets, column_indices, attn_bias) diff --git a/demucs/utils.py b/demucs/utils.py new file mode 100644 index 00000000..94bd323d --- /dev/null +++ b/demucs/utils.py @@ -0,0 +1,502 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from contextlib import contextmanager +import math +import os +import tempfile +import typing as tp + +import errno +import functools +import hashlib +import inspect +import io +import os +import random +import socket +import tempfile +import warnings +import zlib +import tkinter as tk + +from diffq import UniformQuantizer, DiffQuantizer +import torch as th +import tqdm +from torch import distributed +from torch.nn import functional as F + +import torch + +def unfold(a, kernel_size, stride): + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + + This will pad the input so that `F = ceil(T / K)`. + + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, 'data should be contiguous' + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): + """ + Center trim `tensor` with respect to `reference`, along the last dimension. + `reference` can also be a number, representing the length to trim to. + If the size difference != 0 mod 2, the extra sample is removed on the right side. + """ + ref_size: int + if isinstance(reference, torch.Tensor): + ref_size = reference.size(-1) + else: + ref_size = reference + delta = tensor.size(-1) - ref_size + if delta < 0: + raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") + if delta: + tensor = tensor[..., delta // 2:-(delta - delta // 2)] + return tensor + + +def pull_metric(history: tp.List[dict], name: str): + out = [] + for metrics in history: + metric = metrics + for part in name.split("."): + metric = metric[part] + out.append(metric) + return out + + +def EMA(beta: float = 1): + """ + Exponential Moving Average callback. + Returns a single function that can be called to repeatidly update the EMA + with a dict of metrics. The callback will return + the new averaged dict of metrics. + + Note that for `beta=1`, this is just plain averaging. + """ + fix: tp.Dict[str, float] = defaultdict(float) + total: tp.Dict[str, float] = defaultdict(float) + + def _update(metrics: dict, weight: float = 1) -> dict: + nonlocal total, fix + for key, value in metrics.items(): + total[key] = total[key] * beta + weight * float(value) + fix[key] = fix[key] * beta + weight + return {key: tot / fix[key] for key, tot in total.items()} + return _update + + +def sizeof_fmt(num: float, suffix: str = 'B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +@contextmanager +def temp_filenames(count: int, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + +def average_metric(metric, count=1.): + """ + Average `metric` which should be a float across all hosts. `count` should be + the weight for this particular host (i.e. number of examples). + """ + metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda') + distributed.all_reduce(metric, op=distributed.ReduceOp.SUM) + return metric[1].item() / metric[0].item() + + +def free_port(host='', low=20000, high=40000): + """ + Return a port number that is most likely free. + This could suffer from a race condition although + it should be quite rare. + """ + sock = socket.socket() + while True: + port = random.randint(low, high) + try: + sock.bind((host, port)) + except OSError as error: + if error.errno == errno.EADDRINUSE: + continue + raise + return port + + +def sizeof_fmt(num, suffix='B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +def human_seconds(seconds, display='.2f'): + """ + Given `seconds` seconds, return human readable duration. + """ + value = seconds * 1e6 + ratios = [1e3, 1e3, 60, 60, 24] + names = ['us', 'ms', 's', 'min', 'hrs', 'days'] + last = names.pop(0) + for name, ratio in zip(names, ratios): + if value / ratio < 0.3: + break + value /= ratio + last = name + return f"{format(value, display)} {last}" + + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + + +def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + """ + + channels, length = mix.size() + device = mix.device + progress_value = 0 + + if split: + out = th.zeros(4, channels, length, device=device) + shift = model.samplerate * 10 + offsets = range(0, length, shift) + scale = 10 + if progress: + offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') + for offset in offsets: + chunk = mix[..., offset:offset + shift] + if set_progress_bar: + progress_value += 1 + set_progress_bar(0.1, (0.8/len(offsets)*progress_value)) + chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar) + else: + chunk_out = apply_model_v1(model, chunk, shifts=shifts) + out[..., offset:offset + shift] = chunk_out + offset += shift + return out + elif shifts: + max_shift = int(model.samplerate / 2) + mix = F.pad(mix, (max_shift, max_shift)) + offsets = list(range(max_shift)) + random.shuffle(offsets) + out = 0 + for offset in offsets[:shifts]: + shifted = mix[..., offset:offset + length + max_shift] + if set_progress_bar: + shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar) + else: + shifted_out = apply_model_v1(model, shifted) + out += shifted_out[..., max_shift - offset:max_shift - offset + length] + out /= shifts + return out + else: + valid_length = model.valid_length(length) + delta = valid_length - length + padded = F.pad(mix, (delta // 2, delta - delta // 2)) + with th.no_grad(): + out = model(padded.unsqueeze(0))[0] + return center_trim(out, mix) + +def apply_model_v2(model, mix, shifts=None, split=False, + overlap=0.25, transition_power=1., progress=False, set_progress_bar=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + """ + + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + device = mix.device + channels, length = mix.shape + progress_value = 0 + + if split: + out = th.zeros(len(model.sources), channels, length, device=device) + sum_weight = th.zeros(length, device=device) + segment = model.segment_length + stride = int((1 - overlap) * segment) + offsets = range(0, length, stride) + scale = stride / model.samplerate + if progress: + offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment // 2 + 1), + th.arange(segment - segment // 2, 0, -1)]).to(device) + assert len(weight) == segment + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + for offset in offsets: + chunk = TensorChunk(mix, offset, segment) + if set_progress_bar: + progress_value += 1 + set_progress_bar(0.1, (0.8/len(offsets)*progress_value)) + chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar) + else: + chunk_out = apply_model_v2(model, chunk, shifts=shifts) + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out + sum_weight[offset:offset + segment] += weight[:chunk_length] + offset += segment + assert sum_weight.min() > 0 + out /= sum_weight + return out + elif shifts: + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0 + for _ in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + + if set_progress_bar: + progress_value += 1 + shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar) + else: + shifted_out = apply_model_v2(model, shifted) + out += shifted_out[..., max_shift - offset:] + out /= shifts + return out + else: + valid_length = model.valid_length(length) + mix = tensor_chunk(mix) + padded_mix = mix.padded(valid_length) + with th.no_grad(): + out = model(padded_mix.unsqueeze(0))[0] + return center_trim(out, length) + + +@contextmanager +def temp_filenames(count, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + + +def get_quantizer(model, args, optimizer=None): + quantizer = None + if args.diffq: + quantizer = DiffQuantizer( + model, min_size=args.q_min_size, group_size=8) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.q_min_size) + return quantizer + + +def load_model(path, strict=False): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + load_from = path + package = th.load(load_from, 'cpu') + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + training_args = package["training_args"] + quantizer = get_quantizer(model, training_args) + + set_state(model, quantizer, state) + return model + + +def get_state(model, quantizer): + if quantizer is None: + state = {k: p.data.to('cpu') for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + buf = io.BytesIO() + th.save(state, buf) + state = {'compressed': zlib.compress(buf.getvalue())} + return state + + +def set_state(model, quantizer, state): + if quantizer is None: + model.load_state_dict(state) + else: + buf = io.BytesIO(zlib.decompress(state["compressed"])) + state = th.load(buf, "cpu") + quantizer.restore_quantized_state(state) + + return state + + +def save_state(state, path): + buf = io.BytesIO() + th.save(state, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def save_model(model, quantizer, training_args, path): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer) + + save_to = path + package = { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': training_args, + } + th.save(package, save_to) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ + +class DummyPoolExecutor: + class DummyResult: + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def result(self): + return self.func(*self.args, **self.kwargs) + + def __init__(self, workers=0): + pass + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return diff --git a/lib_v5/spec_utils.py b/lib_v5/spec_utils.py new file mode 100644 index 00000000..fc980169 --- /dev/null +++ b/lib_v5/spec_utils.py @@ -0,0 +1,736 @@ +import librosa +import numpy as np +import soundfile as sf +import math +import random +import pyrubberband +import math +#import noisereduce as nr + +MAX_SPEC = 'Max Spec' +MIN_SPEC = 'Min Spec' +AVERAGE = 'Average' + +def crop_center(h1, h2): + h1_shape = h1.size() + h2_shape = h2.size() + + if h1_shape[3] == h2_shape[3]: + return h1 + elif h1_shape[3] < h2_shape[3]: + raise ValueError('h1_shape[3] must be greater than h2_shape[3]') + + # s_freq = (h2_shape[2] - h1_shape[2]) // 2 + # e_freq = s_freq + h1_shape[2] + s_time = (h1_shape[3] - h2_shape[3]) // 2 + e_time = s_time + h2_shape[3] + h1 = h1[:, :, :, s_time:e_time] + + return h1 + +def preprocess(X_spec): + X_mag = np.abs(X_spec) + X_phase = np.angle(X_spec) + + return X_mag, X_phase + +def make_padding(width, cropsize, offset): + left = offset + roi_size = cropsize - offset * 2 + if roi_size == 0: + roi_size = cropsize + right = roi_size - (width % roi_size) + left + + return left, right, roi_size + +def wave_to_spectrogram(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False): + if reverse: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mid_side: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mid_side_b2: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +def wave_to_spectrogram_mt(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False): + import threading + + if reverse: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mid_side: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mid_side_b2: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + def run_thread(**kwargs): + global spec_left + spec_left = librosa.stft(**kwargs) + + thread = threading.Thread(target=run_thread, kwargs={'y': wave_left, 'n_fft': n_fft, 'hop_length': hop_length}) + thread.start() + spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) + thread.join() + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +def normalize(wave, is_normalize=False): + """Save output music files""" + maxv = np.abs(wave).max() + if maxv > 1.0: + print(f"\nNormalization Set {is_normalize}: Input above threshold for clipping. Max:{maxv}") + if is_normalize: + print(f"The result was normalized.") + wave /= maxv + else: + print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") + + return wave + +def normalize_two_stem(wave, mix, is_normalize=False): + """Save output music files""" + + maxv = np.abs(wave).max() + max_mix = np.abs(mix).max() + + if maxv > 1.0: + print(f"\nNormalization Set {is_normalize}: Primary source above threshold for clipping. The result was normalized. Max:{maxv}") + print(f"\nNormalization Set {is_normalize}: Mixture above threshold for clipping. The result was normalized. Max:{max_mix}") + if is_normalize: + wave /= maxv + mix /= maxv + else: + print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") + + + print(f"\nNormalization Set {is_normalize}: Primary source - Max:{np.abs(wave).max()}") + print(f"\nNormalization Set {is_normalize}: Mixture - Max:{np.abs(mix).max()}") + + return wave, mix + +def combine_spectrograms(specs, mp): + l = min([specs[i].shape[2] for i in specs]) + spec_c = np.zeros(shape=(2, mp.param['bins'] + 1, l), dtype=np.complex64) + offset = 0 + bands_n = len(mp.param['band']) + + for d in range(1, bands_n + 1): + h = mp.param['band'][d]['crop_stop'] - mp.param['band'][d]['crop_start'] + spec_c[:, offset:offset+h, :l] = specs[d][:, mp.param['band'][d]['crop_start']:mp.param['band'][d]['crop_stop'], :l] + offset += h + + if offset > mp.param['bins']: + raise ValueError('Too much bins') + + # lowpass fiter + if mp.param['pre_filter_start'] > 0: # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']: + if bands_n == 1: + spec_c = fft_lp_filter(spec_c, mp.param['pre_filter_start'], mp.param['pre_filter_stop']) + else: + gp = 1 + for b in range(mp.param['pre_filter_start'] + 1, mp.param['pre_filter_stop']): + g = math.pow(10, -(b - mp.param['pre_filter_start']) * (3.5 - gp) / 20.0) + gp = g + spec_c[:, b, :] *= g + + return np.asfortranarray(spec_c) + +def spectrogram_to_image(spec, mode='magnitude'): + if mode == 'magnitude': + if np.iscomplexobj(spec): + y = np.abs(spec) + else: + y = spec + y = np.log10(y ** 2 + 1e-8) + elif mode == 'phase': + if np.iscomplexobj(spec): + y = np.angle(spec) + else: + y = spec + + y -= y.min() + y *= 255 / y.max() + img = np.uint8(y) + + if y.ndim == 3: + img = img.transpose(1, 2, 0) + img = np.concatenate([ + np.max(img, axis=2, keepdims=True), img + ], axis=2) + + return img + +def reduce_vocal_aggressively(X, y, softmask): + v = X - y + y_mag_tmp = np.abs(y) + v_mag_tmp = np.abs(v) + + v_mask = v_mag_tmp > y_mag_tmp + y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf) + + return y_mag * np.exp(1.j * np.angle(y)) + +def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32): + if min_range < fade_size * 2: + raise ValueError('min_range must be >= fade_size * 2') + + idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0] + start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0]) + end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1]) + artifact_idx = np.where(end_idx - start_idx > min_range)[0] + weight = np.zeros_like(y_mask) + if len(artifact_idx) > 0: + start_idx = start_idx[artifact_idx] + end_idx = end_idx[artifact_idx] + old_e = None + for s, e in zip(start_idx, end_idx): + if old_e is not None and s - old_e < fade_size: + s = old_e - fade_size * 2 + + if s != 0: + weight[:, :, s:s + fade_size] = np.linspace(0, 1, fade_size) + else: + s -= fade_size + + if e != y_mask.shape[2]: + weight[:, :, e - fade_size:e] = np.linspace(1, 0, fade_size) + else: + e += fade_size + + weight[:, :, s + fade_size:e - fade_size] = 1 + old_e = e + + v_mask = 1 - y_mask + y_mask += weight * v_mask + + return y_mask + +def mask_silence(mag, ref, thres=0.2, min_range=64, fade_size=32): + if min_range < fade_size * 2: + raise ValueError('min_range must be >= fade_area * 2') + + mag = mag.copy() + + idx = np.where(ref.mean(axis=(0, 1)) < thres)[0] + starts = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0]) + ends = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1]) + uninformative = np.where(ends - starts > min_range)[0] + if len(uninformative) > 0: + starts = starts[uninformative] + ends = ends[uninformative] + old_e = None + for s, e in zip(starts, ends): + if old_e is not None and s - old_e < fade_size: + s = old_e - fade_size * 2 + + if s != 0: + weight = np.linspace(0, 1, fade_size) + mag[:, :, s:s + fade_size] += weight * ref[:, :, s:s + fade_size] + else: + s -= fade_size + + if e != mag.shape[2]: + weight = np.linspace(1, 0, fade_size) + mag[:, :, e - fade_size:e] += weight * ref[:, :, e - fade_size:e] + else: + e += fade_size + + mag[:, :, s + fade_size:e - fade_size] += ref[:, :, s + fade_size:e - fade_size] + old_e = e + + return mag + +def align_wave_head_and_tail(a, b): + l = min([a[0].size, b[0].size]) + + return a[:l,:l], b[:l,:l] + +def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse, clamp=False): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + + if reverse: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mid_side: + return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) + elif mid_side_b2: + return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) + else: + return np.asfortranarray([wave_left, wave_right]) + +def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2): + import threading + + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + def run_thread(**kwargs): + global wave_left + wave_left = librosa.istft(**kwargs) + + thread = threading.Thread(target=run_thread, kwargs={'stft_matrix': spec_left, 'hop_length': hop_length}) + thread.start() + wave_right = librosa.istft(spec_right, hop_length=hop_length) + thread.join() + + if reverse: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mid_side: + return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) + elif mid_side_b2: + return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) + else: + return np.asfortranarray([wave_left, wave_right]) + +def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None): + bands_n = len(mp.param['band']) + offset = 0 + + for d in range(1, bands_n + 1): + bp = mp.param['band'][d] + spec_s = np.ndarray(shape=(2, bp['n_fft'] // 2 + 1, spec_m.shape[2]), dtype=complex) + h = bp['crop_stop'] - bp['crop_start'] + spec_s[:, bp['crop_start']:bp['crop_stop'], :] = spec_m[:, offset:offset+h, :] + + offset += h + if d == bands_n: # higher + if extra_bins_h: # if --high_end_process bypass + max_bin = bp['n_fft'] // 2 + spec_s[:, max_bin-extra_bins_h:max_bin, :] = extra_bins[:, :extra_bins_h, :] + if bp['hpf_start'] > 0: + spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) + if bands_n == 1: + wave = spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']) + else: + wave = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])) + else: + sr = mp.param['band'][d+1]['sr'] + if d == 1: # lower + spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) + wave = librosa.resample(spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']), bp['sr'], sr, res_type="sinc_fastest") + else: # mid + spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) + spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) + wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])) + wave = librosa.resample(wave2, bp['sr'], sr, res_type="sinc_fastest") + + return wave + +def fft_lp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop): + g -= 1 / (bin_stop - bin_start) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, bin_stop:, :] *= 0 + + return spec + +def fft_hp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop, -1): + g -= 1 / (bin_start - bin_stop) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, 0:bin_stop+1, :] *= 0 + + return spec + +def mirroring(a, spec_m, input_high_end, mp): + if 'mirroring' == a: + mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1) + mirror = mirror * np.exp(1.j * np.angle(input_high_end)) + + return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror) + + if 'mirroring2' == a: + mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1) + mi = np.multiply(mirror, input_high_end * 1.7) + + return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi) + +def adjust_aggr(mask, is_vocal_model, aggressiveness): + aggr = aggressiveness.get('value', 0.0) * 4 + + if aggr != 0: + if is_vocal_model: + aggr = 1 - aggr + + aggr = [aggr, aggr] + + if aggressiveness['aggr_correction'] is not None: + aggr[0] += aggressiveness['aggr_correction']['left'] + aggr[1] += aggressiveness['aggr_correction']['right'] + + for ch in range(2): + mask[ch, :aggressiveness['split_bin']] = np.power(mask[ch, :aggressiveness['split_bin']], 1 + aggr[ch] / 3) + mask[ch, aggressiveness['split_bin']:] = np.power(mask[ch, aggressiveness['split_bin']:], 1 + aggr[ch]) + + return mask + +def stft(wave, nfft, hl): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +def istft(spec, hl): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hl) + wave_right = librosa.istft(spec_right, hop_length=hl) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + +def spec_effects(wave, algorithm='Default', value=None): + spec = [stft(wave[0],2048,1024), stft(wave[1],2048,1024)] + if algorithm == 'Min_Mag': + v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0]) + wave = istft(v_spec_m,1024) + elif algorithm == 'Max_Mag': + v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0]) + wave = istft(v_spec_m,1024) + elif algorithm == 'Default': + wave = (wave[1] * value) + (wave[0] * (1-value)) + elif algorithm == 'Invert_p': + X_mag = np.abs(spec[0]) + y_mag = np.abs(spec[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0])) + wave = istft(v_spec,1024) + + return wave + +def spectrogram_to_wave_bare(spec, hop_length=1024): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + +def spectrogram_to_wave_no_mp(spec, hop_length=1024): + if spec.ndim == 2: + wave = librosa.istft(spec, hop_length=hop_length) + elif spec.ndim == 3: + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + +def wave_to_spectrogram_no_mp(wave): + + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft=2048, hop_length=1024) + spec_right = librosa.stft(wave_right, n_fft=2048, hop_length=1024) + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +# def noise_reduction(audio_file): + +# noise_pro = 'noise_pro.wav' + +# wav, sr = librosa.load(audio_file, sr=44100, mono=False) +# wav_noise, noise_rate = librosa.load(noise_pro, sr=44100, mono=False) + +# if wav.ndim == 1: +# wav = np.asfortranarray([wav,wav]) + +# wav_1 = nr.reduce_noise(audio_clip=wav[0], noise_clip=wav_noise, verbose=True) +# wav_2 = nr.reduce_noise(audio_clip=wav[1], noise_clip=wav_noise, verbose=True) + +# if wav_1.shape > wav_2.shape: +# wav_2 = to_shape(wav_2, wav_1.shape) +# if wav_1.shape < wav_2.shape: +# wav_1 = to_shape(wav_1, wav_2.shape) + +# #print('wav_1.shape: ', wav_1.shape) + +# wav_mix = np.asfortranarray([wav_1, wav_2]) + +# return wav_mix, sr + +def invert_audio(specs, invert_p=True): + + ln = min([specs[0].shape[2], specs[1].shape[2]]) + specs[0] = specs[0][:,:,:ln] + specs[1] = specs[1][:,:,:ln] + + if invert_p: + X_mag = np.abs(specs[0]) + y_mag = np.abs(specs[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = specs[1] - max_mag * np.exp(1.j * np.angle(specs[0])) + else: + specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2) + v_spec = specs[0] - specs[1] + + return v_spec + +def invert_stem(mixture, stem): + + mixture = wave_to_spectrogram_no_mp(mixture) + stem = wave_to_spectrogram_no_mp(stem) + output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem])) + + return -output.T + +def ensembling(a, specs): + for i in range(1, len(specs)): + if i == 1: + spec = specs[0] + + ln = min([spec.shape[2], specs[i].shape[2]]) + spec = spec[:,:,:ln] + specs[i] = specs[i][:,:,:ln] + + #print('spec: ', a) + + if MIN_SPEC == a: + spec = np.where(np.abs(specs[i]) <= np.abs(spec), specs[i], spec) + if MAX_SPEC == a: + spec = np.where(np.abs(specs[i]) >= np.abs(spec), specs[i], spec) + if AVERAGE == a: + spec = np.where(np.abs(specs[i]) == np.abs(spec), specs[i], spec) + + return spec + +def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path): + + #print(algorithm) + + if algorithm == AVERAGE: + output = average_audio(audio_input) + samplerate = 44100 + else: + specs = [] + + for i in range(len(audio_input)): + wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100) + spec = wave_to_spectrogram_no_mp(wave) + specs.append(spec) + #print('output size: ', sys.getsizeof(spec)) + + #print('output size: ', sys.getsizeof(specs)) + + output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs)) + + sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set) + +def to_shape(x, target_shape): + padding_list = [] + for x_dim, target_dim in zip(x.shape, target_shape): + pad_value = (target_dim - x_dim) + pad_tuple = ((0, pad_value)) + padding_list.append(pad_tuple) + + return np.pad(x, tuple(padding_list), mode='constant') + +def to_shape_minimize(x: np.ndarray, target_shape): + + padding_list = [] + for x_dim, target_dim in zip(x.shape, target_shape): + pad_value = (target_dim - x_dim) + pad_tuple = ((0, pad_value)) + padding_list.append(pad_tuple) + + return np.pad(x, tuple(padding_list), mode='constant') + +def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False): + + #print(rate) + + wav, sr = librosa.load(audio_file, sr=44100, mono=False) + + if wav.ndim == 1: + wav = np.asfortranarray([wav,wav]) + + if is_pitch: + wav_1 = pyrubberband.pyrb.pitch_shift(wav[0], sr, rate, rbargs=None) + wav_2 = pyrubberband.pyrb.pitch_shift(wav[1], sr, rate, rbargs=None) + else: + wav_1 = pyrubberband.pyrb.time_stretch(wav[0], sr, rate, rbargs=None) + wav_2 = pyrubberband.pyrb.time_stretch(wav[1], sr, rate, rbargs=None) + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + wav_1 = to_shape(wav_1, wav_2.shape) + + wav_mix = np.asfortranarray([wav_1, wav_2]) + + sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set) + save_format(export_path) + +def average_audio(audio): + + waves = [] + wave_shapes = [] + final_waves = [] + + for i in range(len(audio)): + wave = librosa.load(audio[i], sr=44100, mono=False) + waves.append(wave[0]) + wave_shapes.append(wave[0].shape[1]) + + wave_shapes_index = wave_shapes.index(max(wave_shapes)) + target_shape = waves[wave_shapes_index] + waves.pop(wave_shapes_index) + final_waves.append(target_shape) + + for n_array in waves: + wav_target = to_shape(n_array, target_shape.shape) + final_waves.append(wav_target) + + waves = sum(final_waves) + waves = waves/len(audio) + + return waves + +def average_dual_sources(wav_1, wav_2, value): + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + wav_1 = to_shape(wav_1, wav_2.shape) + + wave = (wav_1 * value) + (wav_2 * (1-value)) + + return wave + +def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray): + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + ln = min([wav_1.shape[1], wav_2.shape[1]]) + wav_2 = wav_2[:,:ln] + + ln = min([wav_1.shape[1], wav_2.shape[1]]) + wav_1 = wav_1[:,:ln] + wav_2 = wav_2[:,:ln] + + return wav_2 + +def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_normalization, command_Text, progress_bar_main_var, save_format): + def get_diff(a, b): + corr = np.correlate(a, b, "full") + diff = corr.argmax() - (b.shape[0] - 1) + return diff + + progress_bar_main_var.set(10) + + # read tracks + wav1, sr1 = librosa.load(file1, sr=44100, mono=False) + wav2, sr2 = librosa.load(file2, sr=44100, mono=False) + wav1 = wav1.transpose() + wav2 = wav2.transpose() + + command_Text(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n") + + wav2_org = wav2.copy() + progress_bar_main_var.set(20) + + command_Text("Processing files... \n") + + # pick random position and get diff + + counts = {} # counting up for each diff value + progress = 20 + + check_range = 64 + + base = (64 / check_range) + + for i in range(check_range): + index = int(random.uniform(44100 * 2, min(wav1.shape[0], wav2.shape[0]) - 44100 * 2)) + shift = int(random.uniform(-22050,+22050)) + samp1 = wav1[index :index +44100, 0] # currently use left channel + samp2 = wav2[index+shift:index+shift+44100, 0] + progress += 1 * base + progress_bar_main_var.set(progress) + diff = get_diff(samp1, samp2) + diff -= shift + + if abs(diff) < 22050: + if not diff in counts: + counts[diff] = 0 + counts[diff] += 1 + + # use max counted diff value + max_count = 0 + est_diff = 0 + for diff in counts.keys(): + if counts[diff] > max_count: + max_count = counts[diff] + est_diff = diff + + command_Text(f"Estimated difference is {est_diff} (count: {max_count})\n") + + progress_bar_main_var.set(90) + + audio_files = [] + + def save_aligned_audio(wav2_aligned): + command_Text(f"Aligned File 2 with File 1.\n") + command_Text(f"Saving files... ") + sf.write(file2_aligned, normalize(wav2_aligned, is_normalization), sr2, subtype=wav_type_set) + save_format(file2_aligned) + min_len = min(wav1.shape[0], wav2_aligned.shape[0]) + wav_sub = wav1[:min_len] - wav2_aligned[:min_len] + audio_files.append(file2_aligned) + return min_len, wav_sub + + # make aligned track 2 + if est_diff > 0: + wav2_aligned = np.append(np.zeros((est_diff, 2)), wav2_org, axis=0) + min_len, wav_sub = save_aligned_audio(wav2_aligned) + elif est_diff < 0: + wav2_aligned = wav2_org[-est_diff:] + min_len, wav_sub = save_aligned_audio(wav2_aligned) + else: + command_Text(f"Audio files already aligned.\n") + command_Text(f"Saving inverted track... ") + min_len = min(wav1.shape[0], wav2.shape[0]) + wav_sub = wav1[:min_len] - wav2[:min_len] + + wav_sub = np.clip(wav_sub, -1, +1) + + sf.write(file_subtracted, normalize(wav_sub, is_normalization), sr1, subtype=wav_type_set) + save_format(file_subtracted) + + progress_bar_main_var.set(95) \ No newline at end of file diff --git a/lib_v5/vr_network/__init__.py b/lib_v5/vr_network/__init__.py new file mode 100644 index 00000000..361b7086 --- /dev/null +++ b/lib_v5/vr_network/__init__.py @@ -0,0 +1 @@ +# VR init. diff --git a/lib_v5/vr_network/layers.py b/lib_v5/vr_network/layers.py new file mode 100644 index 00000000..0120a347 --- /dev/null +++ b/lib_v5/vr_network/layers.py @@ -0,0 +1,143 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from lib_v5 import spec_utils + +class Conv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + +class SeperableConv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False), + nn.Conv2d( + nin, nout, + kernel_size=1, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + + def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + ) + + self.nn_architecture = nn_architecture + self.six_layer = [129605] + self.seven_layer = [537238, 537227, 33966] + + extra_conv = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ) + + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ) + + if self.nn_architecture in self.six_layer: + self.conv6 = extra_conv + nin_x = 6 + elif self.nn_architecture in self.seven_layer: + self.conv6 = extra_conv + self.conv7 = extra_conv + nin_x = 7 + else: + nin_x = 5 + + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), + nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + + if self.nn_architecture in self.six_layer: + feat6 = self.conv6(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1) + elif self.nn_architecture in self.seven_layer: + feat6 = self.conv6(x) + feat7 = self.conv7(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1) + else: + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + + bottle = self.bottleneck(out) + return bottle diff --git a/lib_v5/vr_network/layers_new.py b/lib_v5/vr_network/layers_new.py new file mode 100644 index 00000000..33181dd8 --- /dev/null +++ b/lib_v5/vr_network/layers_new.py @@ -0,0 +1,126 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from lib_v5 import spec_utils + +class Conv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + +class Encoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + + def __call__(self, x): + h = self.conv1(x) + h = self.conv2(h) + + return h + + +class Decoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False): + super(Decoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + + h = self.conv1(x) + # h = self.conv2(h) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + + def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + ) + self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + self.conv3 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + out = self.bottleneck(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out + + +class LSTMModule(nn.Module): + + def __init__(self, nin_conv, nin_lstm, nout_lstm): + super(LSTMModule, self).__init__() + self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) + self.lstm = nn.LSTM( + input_size=nin_lstm, + hidden_size=nout_lstm // 2, + bidirectional=True + ) + self.dense = nn.Sequential( + nn.Linear(nout_lstm, nin_lstm), + nn.BatchNorm1d(nin_lstm), + nn.ReLU() + ) + + def forward(self, x): + N, _, nbins, nframes = x.size() + h = self.conv(x)[:, 0] # N, nbins, nframes + h = h.permute(2, 0, 1) # nframes, N, nbins + h, _ = self.lstm(h) + h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins + h = h.reshape(nframes, N, 1, nbins) + h = h.permute(1, 2, 3, 0) + + return h diff --git a/lib_v5/vr_network/model_param_init.py b/lib_v5/vr_network/model_param_init.py new file mode 100644 index 00000000..42fca9a4 --- /dev/null +++ b/lib_v5/vr_network/model_param_init.py @@ -0,0 +1,59 @@ +import json +import pathlib + +default_param = {} +default_param['bins'] = 768 +default_param['unstable_bins'] = 9 # training only +default_param['reduction_bins'] = 762 # training only +default_param['sr'] = 44100 +default_param['pre_filter_start'] = 757 +default_param['pre_filter_stop'] = 768 +default_param['band'] = {} + + +default_param['band'][1] = { + 'sr': 11025, + 'hl': 128, + 'n_fft': 960, + 'crop_start': 0, + 'crop_stop': 245, + 'lpf_start': 61, # inference only + 'res_type': 'polyphase' +} + +default_param['band'][2] = { + 'sr': 44100, + 'hl': 512, + 'n_fft': 1536, + 'crop_start': 24, + 'crop_stop': 547, + 'hpf_start': 81, # inference only + 'res_type': 'sinc_best' +} + + +def int_keys(d): + r = {} + for k, v in d: + if k.isdigit(): + k = int(k) + r[k] = v + return r + + +class ModelParameters(object): + def __init__(self, config_path=''): + if '.pth' == pathlib.Path(config_path).suffix: + import zipfile + + with zipfile.ZipFile(config_path, 'r') as zip: + self.param = json.loads(zip.read('param.json'), object_pairs_hook=int_keys) + elif '.json' == pathlib.Path(config_path).suffix: + with open(config_path, 'r') as f: + self.param = json.loads(f.read(), object_pairs_hook=int_keys) + else: + self.param = default_param + + for k in ['mid_side', 'mid_side_b', 'mid_side_b2', 'stereo_w', 'stereo_n', 'reverse']: + if not k in self.param: + self.param[k] = False \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr16000_hl512.json b/lib_v5/vr_network/modelparams/1band_sr16000_hl512.json new file mode 100644 index 00000000..72cb4499 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr16000_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 16000, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 16000, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr32000_hl512.json b/lib_v5/vr_network/modelparams/1band_sr32000_hl512.json new file mode 100644 index 00000000..3c00ecf0 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr32000_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 32000, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "kaiser_fast" + } + }, + "sr": 32000, + "pre_filter_start": 1000, + "pre_filter_stop": 1021 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr33075_hl384.json b/lib_v5/vr_network/modelparams/1band_sr33075_hl384.json new file mode 100644 index 00000000..55666ac9 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr33075_hl384.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 33075, + "hl": 384, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 33075, + "pre_filter_start": 1000, + "pre_filter_stop": 1021 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json new file mode 100644 index 00000000..665abe20 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 1024, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl256.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl256.json new file mode 100644 index 00000000..0e8b16f8 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl256.json @@ -0,0 +1,19 @@ +{ + "bins": 256, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 256, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 256, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 256, + "pre_filter_stop": 256 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl512.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl512.json new file mode 100644 index 00000000..3b38fcaf --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json new file mode 100644 index 00000000..630df352 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 700, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 700 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json new file mode 100644 index 00000000..0cdf45b2 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/2band_32000.json b/lib_v5/vr_network/modelparams/2band_32000.json new file mode 100644 index 00000000..ab9cf115 --- /dev/null +++ b/lib_v5/vr_network/modelparams/2band_32000.json @@ -0,0 +1,30 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 705, + "band": { + "1": { + "sr": 6000, + "hl": 66, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 240, + "lpf_start": 60, + "lpf_stop": 118, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 32000, + "hl": 352, + "n_fft": 1024, + "crop_start": 22, + "crop_stop": 505, + "hpf_start": 44, + "hpf_stop": 23, + "res_type": "sinc_medium" + } + }, + "sr": 32000, + "pre_filter_start": 710, + "pre_filter_stop": 731 +} diff --git a/lib_v5/vr_network/modelparams/2band_44100_lofi.json b/lib_v5/vr_network/modelparams/2band_44100_lofi.json new file mode 100644 index 00000000..7faa216d --- /dev/null +++ b/lib_v5/vr_network/modelparams/2band_44100_lofi.json @@ -0,0 +1,30 @@ +{ + "bins": 512, + "unstable_bins": 7, + "reduction_bins": 510, + "band": { + "1": { + "sr": 11025, + "hl": 160, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 192, + "lpf_start": 41, + "lpf_stop": 139, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 44100, + "hl": 640, + "n_fft": 1024, + "crop_start": 10, + "crop_stop": 320, + "hpf_start": 47, + "hpf_stop": 15, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 510, + "pre_filter_stop": 512 +} diff --git a/lib_v5/vr_network/modelparams/2band_48000.json b/lib_v5/vr_network/modelparams/2band_48000.json new file mode 100644 index 00000000..be075f52 --- /dev/null +++ b/lib_v5/vr_network/modelparams/2band_48000.json @@ -0,0 +1,30 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 705, + "band": { + "1": { + "sr": 6000, + "hl": 66, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 240, + "lpf_start": 60, + "lpf_stop": 240, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 48000, + "hl": 528, + "n_fft": 1536, + "crop_start": 22, + "crop_stop": 505, + "hpf_start": 82, + "hpf_stop": 22, + "res_type": "sinc_medium" + } + }, + "sr": 48000, + "pre_filter_start": 710, + "pre_filter_stop": 731 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/3band_44100.json b/lib_v5/vr_network/modelparams/3band_44100.json new file mode 100644 index 00000000..d99e2398 --- /dev/null +++ b/lib_v5/vr_network/modelparams/3band_44100.json @@ -0,0 +1,42 @@ +{ + "bins": 768, + "unstable_bins": 5, + "reduction_bins": 733, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 278, + "lpf_start": 28, + "lpf_stop": 140, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 256, + "n_fft": 768, + "crop_start": 14, + "crop_stop": 322, + "hpf_start": 70, + "hpf_stop": 14, + "lpf_start": 283, + "lpf_stop": 314, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 131, + "crop_stop": 313, + "hpf_start": 154, + "hpf_stop": 141, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 757, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/3band_44100_mid.json b/lib_v5/vr_network/modelparams/3band_44100_mid.json new file mode 100644 index 00000000..fc2c487d --- /dev/null +++ b/lib_v5/vr_network/modelparams/3band_44100_mid.json @@ -0,0 +1,43 @@ +{ + "mid_side": true, + "bins": 768, + "unstable_bins": 5, + "reduction_bins": 733, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 278, + "lpf_start": 28, + "lpf_stop": 140, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 256, + "n_fft": 768, + "crop_start": 14, + "crop_stop": 322, + "hpf_start": 70, + "hpf_stop": 14, + "lpf_start": 283, + "lpf_stop": 314, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 131, + "crop_stop": 313, + "hpf_start": 154, + "hpf_stop": 141, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 757, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/3band_44100_msb2.json b/lib_v5/vr_network/modelparams/3band_44100_msb2.json new file mode 100644 index 00000000..33b0877c --- /dev/null +++ b/lib_v5/vr_network/modelparams/3band_44100_msb2.json @@ -0,0 +1,43 @@ +{ + "mid_side_b2": true, + "bins": 640, + "unstable_bins": 7, + "reduction_bins": 565, + "band": { + "1": { + "sr": 11025, + "hl": 108, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 187, + "lpf_start": 92, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 216, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 212, + "hpf_start": 68, + "hpf_stop": 34, + "lpf_start": 174, + "lpf_stop": 209, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 432, + "n_fft": 640, + "crop_start": 66, + "crop_stop": 307, + "hpf_start": 86, + "hpf_stop": 72, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 639, + "pre_filter_stop": 640 +} diff --git a/lib_v5/vr_network/modelparams/4band_44100.json b/lib_v5/vr_network/modelparams/4band_44100.json new file mode 100644 index 00000000..4ae850a0 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100.json @@ -0,0 +1,54 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/4band_44100_mid.json b/lib_v5/vr_network/modelparams/4band_44100_mid.json new file mode 100644 index 00000000..63467015 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_mid.json @@ -0,0 +1,55 @@ +{ + "bins": 768, + "unstable_bins": 7, + "mid_side": true, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/4band_44100_msb.json b/lib_v5/vr_network/modelparams/4band_44100_msb.json new file mode 100644 index 00000000..0bf47711 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_msb.json @@ -0,0 +1,55 @@ +{ + "mid_side_b": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_44100_msb2.json b/lib_v5/vr_network/modelparams/4band_44100_msb2.json new file mode 100644 index 00000000..0bf47711 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_msb2.json @@ -0,0 +1,55 @@ +{ + "mid_side_b": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_44100_reverse.json b/lib_v5/vr_network/modelparams/4band_44100_reverse.json new file mode 100644 index 00000000..779a1c90 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_reverse.json @@ -0,0 +1,55 @@ +{ + "reverse": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_44100_sw.json b/lib_v5/vr_network/modelparams/4band_44100_sw.json new file mode 100644 index 00000000..1fefd4aa --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_sw.json @@ -0,0 +1,55 @@ +{ + "stereo_w": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_v2.json b/lib_v5/vr_network/modelparams/4band_v2.json new file mode 100644 index 00000000..af798108 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_v2.json @@ -0,0 +1,54 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 637, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_v2_sn.json b/lib_v5/vr_network/modelparams/4band_v2_sn.json new file mode 100644 index 00000000..319b9981 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_v2_sn.json @@ -0,0 +1,55 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 637, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "convert_channels": "stereo_n", + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_v3.json b/lib_v5/vr_network/modelparams/4band_v3.json new file mode 100644 index 00000000..2a73bc97 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_v3.json @@ -0,0 +1,54 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 530, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/ensemble.json b/lib_v5/vr_network/modelparams/ensemble.json new file mode 100644 index 00000000..ca96bf19 --- /dev/null +++ b/lib_v5/vr_network/modelparams/ensemble.json @@ -0,0 +1,43 @@ +{ + "mid_side_b2": true, + "bins": 1280, + "unstable_bins": 7, + "reduction_bins": 565, + "band": { + "1": { + "sr": 11025, + "hl": 108, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 374, + "lpf_start": 92, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 216, + "n_fft": 1536, + "crop_start": 0, + "crop_stop": 424, + "hpf_start": 68, + "hpf_stop": 34, + "lpf_start": 348, + "lpf_stop": 418, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 432, + "n_fft": 1280, + "crop_start": 132, + "crop_stop": 614, + "hpf_start": 172, + "hpf_stop": 144, + "res_type": "polyphase" + } + }, + "sr": 44100, + "pre_filter_start": 1280, + "pre_filter_stop": 1280 +} \ No newline at end of file diff --git a/lib_v5/vr_network/nets.py b/lib_v5/vr_network/nets.py new file mode 100644 index 00000000..7e53b6f9 --- /dev/null +++ b/lib_v5/vr_network/nets.py @@ -0,0 +1,171 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from . import layers + +class BaseASPPNet(nn.Module): + + def __init__(self, nn_architecture, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.nn_architecture = nn_architecture + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + if self.nn_architecture == 129605: + self.enc5 = layers.Encoder(ch * 8, ch * 16, 3, 2, 1) + self.aspp = layers.ASPPModule(nn_architecture, ch * 16, ch * 32, dilations) + self.dec5 = layers.Decoder(ch * (16 + 32), ch * 16, 3, 1, 1) + else: + self.aspp = layers.ASPPModule(nn_architecture, ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + if self.nn_architecture == 129605: + h, e5 = self.enc5(h) + h = self.aspp(h) + h = self.dec5(h, e5) + else: + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + +def determine_model_capacity(n_fft_bins, nn_architecture): + + sp_model_arch = [31191, 33966, 129605] + hp_model_arch = [123821, 123812] + hp2_model_arch = [537238, 537227] + + if nn_architecture in sp_model_arch: + model_capacity_data = [ + (2, 16), + (2, 16), + (18, 8, 1, 1, 0), + (8, 16), + (34, 16, 1, 1, 0), + (16, 32), + (32, 2, 1), + (16, 2, 1), + (16, 2, 1), + ] + + if nn_architecture in hp_model_arch: + model_capacity_data = [ + (2, 32), + (2, 32), + (34, 16, 1, 1, 0), + (16, 32), + (66, 32, 1, 1, 0), + (32, 64), + (64, 2, 1), + (32, 2, 1), + (32, 2, 1), + ] + + if nn_architecture in hp2_model_arch: + model_capacity_data = [ + (2, 64), + (2, 64), + (66, 32, 1, 1, 0), + (32, 64), + (130, 64, 1, 1, 0), + (64, 128), + (128, 2, 1), + (64, 2, 1), + (64, 2, 1), + ] + + cascaded = CascadedASPPNet + model = cascaded(n_fft_bins, model_capacity_data, nn_architecture) + + return model + +class CascadedASPPNet(nn.Module): + + def __init__(self, n_fft, model_capacity_data, nn_architecture): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[0]) + self.stg1_high_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[1]) + + self.stg2_bridge = layers.Conv2DBNActiv(*model_capacity_data[2]) + self.stg2_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[3]) + + self.stg3_bridge = layers.Conv2DBNActiv(*model_capacity_data[4]) + self.stg3_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[5]) + + self.out = nn.Conv2d(*model_capacity_data[6], bias=False) + self.aux1_out = nn.Conv2d(*model_capacity_data[7], bias=False) + self.aux2_out = nn.Conv2d(*model_capacity_data[8], bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x, aggressiveness=None): + mix = x.detach() + x = x.clone() + + x = x[:, :, :self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat([ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]) + ], dim=2) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode='replicate') + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode='replicate') + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode='replicate') + return mask * mix, aux1 * mix, aux2 * mix + else: + if aggressiveness: + mask[:, :, :aggressiveness['split_bin']] = torch.pow(mask[:, :, :aggressiveness['split_bin']], 1 + aggressiveness['value'] / 3) + mask[:, :, aggressiveness['split_bin']:] = torch.pow(mask[:, :, aggressiveness['split_bin']:], 1 + aggressiveness['value']) + + return mask * mix + + def predict(self, x_mag, aggressiveness=None): + h = self.forward(x_mag, aggressiveness) + + if self.offset > 0: + h = h[:, :, :, self.offset:-self.offset] + assert h.size()[3] > 0 + + return h diff --git a/lib_v5/vr_network/nets_new.py b/lib_v5/vr_network/nets_new.py new file mode 100644 index 00000000..1629f8a3 --- /dev/null +++ b/lib_v5/vr_network/nets_new.py @@ -0,0 +1,143 @@ +import torch +from torch import nn +import torch.nn.functional as F +from . import layers_new as layers + +class BaseNet(nn.Module): + + def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))): + super(BaseNet, self).__init__() + self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) + self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) + self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1) + self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1) + self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) + + self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) + self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) + self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) + self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) + self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) + + def __call__(self, x): + e1 = self.enc1(x) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + e5 = self.enc5(e4) + + h = self.aspp(e5) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = torch.cat([h, self.lstm_dec2(h)], dim=1) + h = self.dec1(h, e1) + + return h + +class CascadedNet(nn.Module): + + def __init__(self, n_fft, nn_architecture): + super(CascadedNet, self).__init__() + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + self.nin_lstm = self.max_bin // 2 + self.offset = 64 + self.nn_architecture = nn_architecture + + print('ARC SIZE: ', nn_architecture) + + if nn_architecture == 218409: + self.stg1_low_band_net = nn.Sequential( + BaseNet(2, 32, self.nin_lstm // 2, 128), + layers.Conv2DBNActiv(32, 16, 1, 1, 0) + ) + self.stg1_high_band_net = BaseNet(2, 16, self.nin_lstm // 2, 64) + + self.stg2_low_band_net = nn.Sequential( + BaseNet(18, 64, self.nin_lstm // 2, 128), + layers.Conv2DBNActiv(64, 32, 1, 1, 0) + ) + self.stg2_high_band_net = BaseNet(18, 32, self.nin_lstm // 2, 64) + + self.stg3_full_band_net = BaseNet(50, 64, self.nin_lstm, 128) + + self.out = nn.Conv2d(64, 2, 1, bias=False) + self.aux_out = nn.Conv2d(48, 2, 1, bias=False) + else: + self.stg1_low_band_net = nn.Sequential( + BaseNet(2, 16, self.nin_lstm // 2, 128), + layers.Conv2DBNActiv(16, 8, 1, 1, 0) + ) + self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64) + + self.stg2_low_band_net = nn.Sequential( + BaseNet(10, 32, self.nin_lstm // 2, 128), + layers.Conv2DBNActiv(32, 16, 1, 1, 0) + ) + self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64) + + self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128) + + self.out = nn.Conv2d(32, 2, 1, bias=False) + self.aux_out = nn.Conv2d(24, 2, 1, bias=False) + + def forward(self, x): + x = x[:, :, :self.max_bin] + + bandw = x.size()[2] // 2 + l1_in = x[:, :, :bandw] + h1_in = x[:, :, bandw:] + l1 = self.stg1_low_band_net(l1_in) + h1 = self.stg1_high_band_net(h1_in) + aux1 = torch.cat([l1, h1], dim=2) + + l2_in = torch.cat([l1_in, l1], dim=1) + h2_in = torch.cat([h1_in, h1], dim=1) + l2 = self.stg2_low_band_net(l2_in) + h2 = self.stg2_high_band_net(h2_in) + aux2 = torch.cat([l2, h2], dim=2) + + f3_in = torch.cat([x, aux1, aux2], dim=1) + f3 = self.stg3_full_band_net(f3_in) + + mask = torch.sigmoid(self.out(f3)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode='replicate' + ) + + if self.training: + aux = torch.cat([aux1, aux2], dim=1) + aux = torch.sigmoid(self.aux_out(aux)) + aux = F.pad( + input=aux, + pad=(0, 0, 0, self.output_bin - aux.size()[2]), + mode='replicate' + ) + return mask, aux + else: + return mask + + def predict_mask(self, x): + mask = self.forward(x) + + if self.offset > 0: + mask = mask[:, :, :, self.offset:-self.offset] + assert mask.size()[3] > 0 + + return mask + + def predict(self, x): + mask = self.forward(x) + pred_mag = x * mask + + if self.offset > 0: + pred_mag = pred_mag[:, :, :, self.offset:-self.offset] + assert pred_mag.size()[3] > 0 + + return pred_mag diff --git a/models/Demucs_Models/v3_v4_repo/demucs_models.txt b/models/Demucs_Models/v3_v4_repo/demucs_models.txt new file mode 100644 index 00000000..584c89e5 --- /dev/null +++ b/models/Demucs_Models/v3_v4_repo/demucs_models.txt @@ -0,0 +1 @@ +Demucs v3 and v4 models go here. \ No newline at end of file diff --git a/models/MDX_Net_Models/model_data/model_data.json b/models/MDX_Net_Models/model_data/model_data.json new file mode 100644 index 00000000..03bd1bdf --- /dev/null +++ b/models/MDX_Net_Models/model_data/model_data.json @@ -0,0 +1,184 @@ +{ + "0ddfc0eb5792638ad5dc27850236c246": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "26d308f91f3423a67dc69a6d12a8793d": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 8192, + "primary_stem": "Other" + }, + "2cdd429caac38f0194b133884160f2c6": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "2f5501189a2f6db6349916fabe8c90de": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "398580b6d5d973af3120df54cee6759d": { + "compensate": 1.75, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "488b3e6f8bd3717d9d7c428476be2d75": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "4910e7827f335048bdac11fa967772f9": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 7, + "mdx_n_fft_scale_set": 4096, + "primary_stem": "Drums" + }, + "53c4baf4d12c3e6c3831bb8f5b532b93": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "5d343409ef0df48c7d78cce9f0106781": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "5f6483271e1efb9bfb59e4a3e6d4d098": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "65ab5919372a128e4167f5e01a8fda85": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 8192, + "primary_stem": "Other" + }, + "6703e39f36f18aa7855ee1047765621d": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 16384, + "primary_stem": "Bass" + }, + "6b31de20e84392859a3d09d43f089515": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "867595e9de46f6ab699008295df62798": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "a3cd63058945e777505c01d2507daf37": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "b33d9b3950b6cbf5fe90a32608924700": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "c3b29bdce8c4fa17ec609e16220330ab": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 16384, + "primary_stem": "Bass" + }, + "ceed671467c1f64ebdfac8a2490d0d52": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "d2a1376f310e4f7fa37fb9b5774eb701": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "d7bff498db9324db933d913388cba6be": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "d94058f8c7f1fae4164868ae8ae66b20": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "dc41ede5961d50f277eb846db17f5319": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 4096, + "primary_stem": "Drums" + }, + "e5572e58abf111f80d8241d2e44e7fa4": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "e7324c873b1f615c35c1967f912db92a": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "1c56ec0224f1d559c42fd6fd2a67b154": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 5120, + "primary_stem": "Instrumental" + }, + "f2df6d6863d8f435436d8b561594ff49": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + } +} \ No newline at end of file diff --git a/models/VR_Models/model_data/model_data.json b/models/VR_Models/model_data/model_data.json new file mode 100644 index 00000000..b1abb2c9 --- /dev/null +++ b/models/VR_Models/model_data/model_data.json @@ -0,0 +1,94 @@ +{ + "0d0e6d143046b0eecc41a22e60224582": { + "vr_model_param": "3band_44100_mid", + "primary_stem": "Instrumental" + }, + "18b52f873021a0af556fb4ecd552bb8e": { + "vr_model_param": "2band_32000", + "primary_stem": "Instrumental" + }, + "1fc66027c82b499c7d8f55f79e64cadc": { + "vr_model_param": "2band_32000", + "primary_stem": "Instrumental" + }, + "2aa34fbc01f8e6d2bf509726481e7142": { + "vr_model_param": "4band_44100", + "primary_stem": "Other" + }, + "3e18f639b11abea7361db1a4a91c2559": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "570b5f50054609a17741369a35007ddd": { + "vr_model_param": "4band_v3", + "primary_stem": "Instrumental" + }, + "5a6e24c1b530f2dab045a522ef89b751": { + "vr_model_param": "1band_sr44100_hl512", + "primary_stem": "Instrumental" + }, + "6b5916069a49be3fe29d4397ecfd73fa": { + "vr_model_param": "3band_44100_msb2", + "primary_stem": "Instrumental" + }, + "74b3bc5fa2b69f29baf7839b858bc679": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "827213b316df36b52a1f3d04fec89369": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "911d4048eee7223eca4ee0efb7d29256": { + "vr_model_param": "4band_44100", + "primary_stem": "Vocals" + }, + "941f3f7f0b0341f12087aacdfef644b1": { + "vr_model_param": "4band_v2", + "primary_stem": "Instrumental" + }, + "a02827cf69d75781a35c0e8a327f3195": { + "vr_model_param": "1band_sr33075_hl384", + "primary_stem": "Instrumental" + }, + "b165fbff113c959dba5303b74c6484bc": { + "vr_model_param": "3band_44100", + "primary_stem": "Instrumental" + }, + "b5f988cd3e891dca7253bf5f0f3427c7": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "b99c35723bc35cb11ed14a4780006a80": { + "vr_model_param": "1band_sr44100_hl1024", + "primary_stem": "Instrumental" + }, + "ba02fd25b71d620eebbdb49e18e4c336": { + "vr_model_param": "3band_44100_mid", + "primary_stem": "Instrumental" + }, + "c4476ef424d8cba65f38d8d04e8514e2": { + "vr_model_param": "3band_44100_msb2", + "primary_stem": "Instrumental" + }, + "da2d37b8be2972e550a409bae08335aa": { + "vr_model_param": "4band_44100", + "primary_stem": "Vocals" + }, + "db57205d3133e39df8e050b435a78c80": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "ea83b08e32ec2303456fe50659035f69": { + "vr_model_param": "4band_v3", + "primary_stem": "Instrumental" + }, + "f6ea8473ff86017b5ebd586ccacf156b": { + "vr_model_param": "4band_v2_sn", + "primary_stem": "Instrumental" + }, + "fd297a61eafc9d829033f8b987c39a3d": { + "vr_model_param": "1band_sr32000_hl512", + "primary_stem": "Instrumental" + } +} \ No newline at end of file diff --git a/separate.py b/separate.py new file mode 100644 index 00000000..2cf43ca8 --- /dev/null +++ b/separate.py @@ -0,0 +1,924 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +from demucs.apply import apply_model, demucs_segments +from demucs.hdemucs import HDemucs +from demucs.model_v2 import auto_load_demucs_model_v2 +from demucs.pretrained import get_model as _gm +from demucs.utils import apply_model_v1 +from demucs.utils import apply_model_v2 +from lib_v5 import spec_utils +from lib_v5.vr_network import nets +from lib_v5.vr_network import nets_new +#from lib_v5.vr_network.model_param_init import ModelParameters +from pathlib import Path +from gui_data.constants import * +import gzip +import librosa +import math +import numpy as np +import onnxruntime as ort +import os +import torch +import warnings +import pydub +import soundfile as sf + +if TYPE_CHECKING: + from UVR import ModelData + +warnings.filterwarnings("ignore") +cpu = torch.device('cpu') + +class SeperateAttributes: + def __init__(self, model_data: ModelData, process_data: dict, main_model_primary_stem_4_stem=None, main_process_method=None): + + self.list_all_models: list + self.process_data = process_data + self.progress_value = 0 + self.set_progress_bar = process_data['set_progress_bar'] + self.write_to_console = process_data['write_to_console'] + self.audio_file = process_data['audio_file'] + self.audio_file_base = process_data['audio_file_base'] + self.export_path = process_data['export_path'] + self.cached_source_callback = process_data['cached_source_callback'] + self.cached_model_source_holder = process_data['cached_model_source_holder'] + self.is_4_stem_ensemble = process_data['is_4_stem_ensemble'] + self.list_all_models = process_data['list_all_models'] + self.process_iteration = process_data['process_iteration'] + self.model_samplerate = model_data.model_samplerate + self.is_pre_proc_model = model_data.is_pre_proc_model + self.is_secondary_model_activated = model_data.is_secondary_model_activated if not self.is_pre_proc_model else False + self.is_secondary_model = model_data.is_secondary_model if not self.is_pre_proc_model else True + self.process_method = model_data.process_method + self.model_path = model_data.model_path + self.model_name = model_data.model_name + self.model_basename = model_data.model_basename + self.wav_type_set = model_data.wav_type_set + self.mp3_bit_set = model_data.mp3_bit_set + self.save_format = model_data.save_format + self.is_gpu_conversion = model_data.is_gpu_conversion + self.is_normalization = model_data.is_normalization + self.is_primary_stem_only = model_data.is_primary_stem_only if not self.is_secondary_model else model_data.is_primary_model_primary_stem_only + self.is_secondary_stem_only = model_data.is_secondary_stem_only if not self.is_secondary_model else model_data.is_primary_model_secondary_stem_only + self.is_ensemble_mode = model_data.is_ensemble_mode + self.secondary_model = model_data.secondary_model #VERIFY WHERE + self.primary_model_primary_stem = model_data.primary_model_primary_stem + self.primary_stem = model_data.primary_stem #- + self.secondary_stem = model_data.secondary_stem #- + self.is_invert_spec = model_data.is_invert_spec # + self.secondary_model_scale = model_data.secondary_model_scale # + self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix # + ############################# + self.primary_source_map = {} + self.secondary_source_map = {} + self.primary_source = None + self.secondary_source = None + self.secondary_source_primary = None + self.secondary_source_secondary = None + + if not model_data.process_method == DEMUCS_ARCH_TYPE: + if process_data['is_ensemble_master'] and not self.is_4_stem_ensemble: + if not model_data.ensemble_primary_stem == self.primary_stem: + self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only + + if self.is_secondary_model and not process_data['is_ensemble_master']: + if not self.primary_model_primary_stem == self.primary_stem and not main_model_primary_stem_4_stem: + self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only + + if main_model_primary_stem_4_stem: + self.is_primary_stem_only = True if main_model_primary_stem_4_stem == self.primary_stem else False + self.is_secondary_stem_only = True if not main_model_primary_stem_4_stem == self.primary_stem else False + + if self.is_pre_proc_model: + self.is_primary_stem_only = True if self.primary_stem == INST_STEM else False + self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False + + if model_data.process_method == MDX_ARCH_TYPE: + self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename) + self.is_denoise = model_data.is_denoise + self.compensate = model_data.compensate + self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set + self.n_fft = model_data.mdx_n_fft_scale_set + self.chunks = model_data.chunks + self.margin = model_data.margin + self.hop = 1024 + self.n_bins = self.n_fft//2+1 + self.chunk_size = self.hop * (self.dim_t-1) + self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(cpu) + self.dim_c = 4 + out_c = self.dim_c + self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(cpu) + + if model_data.process_method == DEMUCS_ARCH_TYPE: + self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None + self.secondary_model_4_stem = model_data.secondary_model_4_stem + self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale + self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem + self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem + self.is_chunk_demucs = model_data.is_chunk_demucs + self.segment = model_data.segment + self.demucs_version = model_data.demucs_version + self.demucs_source_list = model_data.demucs_source_list + self.demucs_source_map = model_data.demucs_source_map + self.is_demucs_combine_stems = model_data.is_demucs_combine_stems + self.demucs_stem_count = model_data.demucs_stem_count + self.pre_proc_model = model_data.pre_proc_model + + if self.is_secondary_model and not process_data['is_ensemble_master']: + if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM: + self.primary_stem = VOCAL_STEM + self.secondary_stem = INST_STEM + else: + self.primary_stem = model_data.primary_model_primary_stem + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + + if self.is_chunk_demucs: + self.chunks_demucs = model_data.chunks_demucs + self.margin_demucs = model_data.margin_demucs + else: + self.chunks_demucs = 0 + self.margin_demucs = 44100 + + self.shifts = model_data.shifts + self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True + self.overlap = model_data.overlap + self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename) + + if model_data.process_method == VR_ARCH_TYPE: + self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename) + self.mp = model_data.vr_model_param + self.high_end_process = model_data.is_high_end_process + self.is_tta = model_data.is_tta + self.is_post_process = model_data.is_post_process + self.is_gpu_conversion = model_data.is_gpu_conversion + self.batch_size = model_data.batch_size + self.crop_size = model_data.crop_size + self.window_size = model_data.window_size + self.input_high_end_h = None + self.aggressiveness = {'value': model_data.aggression_setting, + 'split_bin': self.mp.param['band'][1]['crop_stop'], + 'aggr_correction': self.mp.param.get('aggr_correction')} + + def start_inference(self): + + if self.is_secondary_model and not self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename)) + + if self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename)) + + def running_inference(self, is_no_write=False): + + self.write_to_console(DONE, base_text='') if not is_no_write else None + self.set_progress_bar(0.05) if not is_no_write else None + + if self.is_secondary_model and not self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_1_SEC) + elif self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_1_PRE) + else: + self.write_to_console(INFERENCE_STEP_1) + + def load_cached_sources(self, is_4_stem_demucs=False): + + if self.is_secondary_model and not self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename)) + elif self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename)) + else: + self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED) + + if not is_4_stem_demucs: + primary_stem, secondary_stem = gather_sources(self.primary_stem, self.secondary_stem, self.primary_sources) + + return primary_stem, secondary_stem + + def cache_source(self, secondary_sources): + + model_occurrences = self.list_all_models.count(self.model_basename) + + if not model_occurrences <= 1: + if self.process_method == MDX_ARCH_TYPE: + self.cached_model_source_holder(MDX_ARCH_TYPE, secondary_sources, self.model_basename) + + if self.process_method == VR_ARCH_TYPE: + self.cached_model_source_holder(VR_ARCH_TYPE, secondary_sources, self.model_basename) + + if self.process_method == DEMUCS_ARCH_TYPE: + self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename) + + # if isinstance(secondary_sources, np.ndarray): + # print('\n==================================\n', secondary_sources, f"\n\nMemory size of source for model {self.model_basename}: ", secondary_sources.size * secondary_sources.itemsize, 'BYTES', '\n==================================\n') + + # if type(secondary_sources) is dict: + # print('\n==================================\n', secondary_sources, f"\n\nMemory size of source for model {self.model_basename}: ", sum(v.size * v.itemsize for v in secondary_sources.values()), 'BYTES', '\n==================================\n') + + def write_audio(self, stem_path, stem_source, samplerate, secondary_model_source=None, model_scale=None): + + if not self.is_secondary_model: + if self.is_secondary_model_activated: + if isinstance(secondary_model_source, np.ndarray): + secondary_model_scale = model_scale if model_scale else self.secondary_model_scale + stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale) + + sf.write(stem_path, stem_source, samplerate, subtype=self.wav_type_set) + save_format(stem_path, self.save_format, self.mp3_bit_set) if not self.is_ensemble_mode else None + + self.write_to_console(DONE, base_text='') + self.set_progress_bar(0.95) + +class SeperateMDX(SeperateAttributes): + + def seperate(self): + + samplerate = 44100 + + if self.primary_model_name == self.model_basename and self.primary_sources: + self.primary_source, self.secondary_source = self.load_cached_sources() + else: + self.start_inference() + if self.is_gpu_conversion >= 0: + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + run_type = ['CUDAExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider'] + else: + self.device = torch.device('cpu') + run_type = ['CPUExecutionProvider'] + + self.onnx_model = ort.InferenceSession(self.model_path, providers=run_type) + + self.running_inference() + mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False + mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut) + + source = self.demix_base(mix) + self.write_to_console(DONE, base_text='') + + if self.is_secondary_model_activated: + if self.secondary_model: + self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method) + + if not self.is_secondary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') + if not isinstance(self.primary_source, np.ndarray): + self.primary_source = spec_utils.normalize(source[0], self.is_normalization).T + self.primary_source_map = {self.primary_stem: self.primary_source} + self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary) + + if not self.is_primary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav') + if not isinstance(self.secondary_source, np.ndarray): + raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix + self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source[0]*self.compensate, raw_mix, self.is_normalization) + + if self.is_invert_spec: + self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source) + else: + self.secondary_source = (-self.secondary_source.T+raw_mix.T) + + self.secondary_source_map = {self.secondary_stem: self.secondary_source} + self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary) + + torch.cuda.empty_cache() + + secondary_sources = {**self.primary_source_map, **self.secondary_source_map} + + self.cache_source(secondary_sources) + + if self.is_secondary_model: + return secondary_sources + + def demix_base(self, mix, is_match_mix=False): + chunked_sources = [] + + for slice in mix: + self.progress_value += 1 + self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if not is_match_mix else None + cmix = mix[slice] + sources = [] + mix_waves = [] + n_sample = cmix.shape[1] + trim = self.n_fft//2 + gen_size = self.chunk_size-2*trim + pad = gen_size - n_sample%gen_size + mix_p = np.concatenate((np.zeros((2,trim)), cmix, np.zeros((2,pad)), np.zeros((2,trim))), 1) + i = 0 + while i < n_sample + pad: + waves = np.array(mix_p[:, i:i+self.chunk_size]) + mix_waves.append(waves) + i += gen_size + mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu) + with torch.no_grad(): + _ort = self.onnx_model if not is_match_mix else None + adjust = 1 + spek = self.stft(mix_waves)*adjust + + if not is_match_mix: + if self.is_denoise: + spec_pred = -_ort.run(None, {'input': -spek.cpu().numpy()})[0]*0.5+_ort.run(None, {'input': spek.cpu().numpy()})[0]*0.5 + else: + spec_pred = _ort.run(None, {'input': spek.cpu().numpy()})[0] + else: + spec_pred = spek.cpu().numpy() + + tar_waves = self.istft(torch.tensor(spec_pred))#.cpu() + tar_signal = tar_waves[:,:,trim:-trim].transpose(0,1).reshape(2, -1).numpy()[:, :-pad] + start = 0 if slice == 0 else self.margin + end = None if slice == list(mix.keys())[::-1][0] else -self.margin + if self.margin == 0: + end = None + sources.append(tar_signal[:,start:end]*(1/adjust)) + chunked_sources.append(sources) + sources = np.concatenate(chunked_sources, axis=-1) + + if not is_match_mix: + del self.onnx_model + + return sources + + def stft(self, x): + x = x.reshape([-1, self.chunk_size]) + x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) + x = x.permute([0,3,1,2]) + x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t]) + return x[:,:,:self.dim_f] + + def istft(self, x, freq_pad=None): + freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad + x = torch.cat([x, freq_pad], -2) + c = 2 + x = x.reshape([-1,c,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t]) + x = x.permute([0,2,3,1]) + x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) + return x.reshape([-1,c,self.chunk_size]) + +class SeperateDemucs(SeperateAttributes): + + def seperate(self): + + samplerate = 44100 + source = None + model_scale = None + stem_source = None + stem_source_secondary = None + inst_mix = None + inst_raw_mix = None + raw_mix = None + inst_source = None + is_no_write = False + is_no_piano_guitar = False + + if self.primary_model_name == self.model_basename and type(self.primary_sources) is dict and not self.pre_proc_model: + self.primary_source, self.secondary_source = self.load_cached_sources() + elif self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model: + source = self.primary_sources + self.load_cached_sources(is_4_stem_demucs=True) + else: + self.start_inference() + if self.is_gpu_conversion >= 0: + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + else: + self.device = torch.device('cpu') + + if self.demucs_version == DEMUCS_V1: + if str(self.model_path).endswith(".gz"): + self.model_path = gzip.open(self.model_path, "rb") + klass, args, kwargs, state = torch.load(self.model_path) + self.demucs = klass(*args, **kwargs) + self.demucs.to(self.device) + self.demucs.load_state_dict(state) + elif self.demucs_version == DEMUCS_V2: + self.demucs = auto_load_demucs_model_v2(self.demucs_source_list, self.model_path) + self.demucs.to(self.device) + self.demucs.load_state_dict(torch.load(self.model_path)) + self.demucs.eval() + else: + self.demucs = HDemucs(sources=self.demucs_source_list) + self.demucs = _gm(name=os.path.splitext(os.path.basename(self.model_path))[0], + repo=Path(os.path.dirname(self.model_path))) + self.demucs = demucs_segments(self.segment, self.demucs) + self.demucs.to(self.device) + self.demucs.eval() + + if self.pre_proc_model: + if self.primary_stem not in [VOCAL_STEM, INST_STEM]: + is_no_write = True + self.write_to_console(DONE, base_text='') + mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True) + inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs) + self.process_iteration() + self.running_inference(is_no_write=is_no_write) + inst_source = self.demix_demucs(inst_mix) + self.process_iteration() + + self.running_inference(is_no_write=is_no_write) if not self.pre_proc_model else None + mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs) + + if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model: + source = self.primary_sources + else: + source = self.demix_demucs(mix) + + self.write_to_console(DONE, base_text='') + + del self.demucs + + if isinstance(inst_source, np.ndarray): + source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]]) + inst_source[self.demucs_source_map[VOCAL_STEM]] = source_reshape + source = inst_source + + if isinstance(source, np.ndarray): + if len(source) == 2: + self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER + else: + self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER if len(source) == 6 else DEMUCS_4_SOURCE_MAPPER + if len(source) == 6 and self.process_data['is_ensemble_master'] or len(source) == 6 and self.is_secondary_model: + is_no_piano_guitar = True + six_stem_other_source = list(source) + six_stem_other_source = [i for n, i in enumerate(source) if n in [self.demucs_source_map[OTHER_STEM], self.demucs_source_map[GUITAR_STEM], self.demucs_source_map[PIANO_STEM]]] + other_source = np.zeros_like(six_stem_other_source[0]) + for i in six_stem_other_source: + other_source += i + source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source) + source[self.demucs_source_map[OTHER_STEM]] = source_reshape + + if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble: + self.cache_source(source) + + for stem_name, stem_value in self.demucs_source_map.items(): + + if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4: + if self.secondary_model_4_stem[stem_value]: + model_scale = self.secondary_model_4_stem_scale[stem_value] + stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_4_stem_demucs=True) + if isinstance(stem_source_secondary, np.ndarray): + stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value] + stem_source_secondary = spec_utils.normalize(stem_source_secondary, self.is_normalization).T + elif type(stem_source_secondary) is dict: + stem_source_secondary = stem_source_secondary[stem_name] + + stem_source_secondary = None if stem_value >= 4 else stem_source_secondary + self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None + stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav') + stem_source = spec_utils.normalize(source[stem_value], self.is_normalization).T + self.write_audio(stem_path, stem_source, samplerate, secondary_model_source=stem_source_secondary, model_scale=model_scale) + + if self.is_secondary_model: + return source + else: + if self.is_secondary_model_activated: + if self.secondary_model: + self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method) + + if not self.is_secondary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') + if not isinstance(self.primary_source, np.ndarray): + self.primary_source = spec_utils.normalize(source[self.demucs_source_map[self.primary_stem]], self.is_normalization).T + self.primary_source_map = {self.primary_stem: self.primary_source} + self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary) + + if not self.is_primary_stem_only: + def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False): + secondary_source = self.secondary_source if not is_inst_mixture else None + self.write_to_console(f'{SAVING_STEM[0]}{sec_stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None + secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav') + secondary_source_secondary = None + + if not isinstance(secondary_source, np.ndarray): + if self.is_demucs_combine_stems: + source = list(source) + if is_inst_mixture: + source = [i for n, i in enumerate(source) if not n in [self.demucs_source_map[self.primary_stem], self.demucs_source_map[VOCAL_STEM]]] + else: + source.pop(self.demucs_source_map[self.primary_stem]) + + source = source[:len(source) - 2] if is_no_piano_guitar else source + secondary_source = np.zeros_like(source[0]) + for i in source: + secondary_source += i + secondary_source = spec_utils.normalize(secondary_source, self.is_normalization).T + else: + if not isinstance(raw_mixture, np.ndarray): + raw_mixture = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs, is_missing_mix=True) + + secondary_source, raw_mixture = spec_utils.normalize_two_stem(source[self.demucs_source_map[self.primary_stem]], raw_mixture, self.is_normalization) + + if self.is_invert_spec: + secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source) + else: + raw_mixture = spec_utils.reshape_sources(secondary_source, raw_mixture) + secondary_source = (-secondary_source.T+raw_mixture.T) + + if not is_inst_mixture: + self.secondary_source = secondary_source + secondary_source_secondary = self.secondary_source_secondary + self.secondary_source_map = {self.secondary_stem: self.secondary_source} + + self.write_audio(secondary_stem_path, secondary_source, samplerate, secondary_source_secondary) + + secondary_save(self.secondary_stem, source, raw_mixture=raw_mix) + + if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble: + secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True) + + torch.cuda.empty_cache() + + secondary_sources = {**self.primary_source_map, **self.secondary_source_map} + + self.cache_source(secondary_sources) + + if self.is_secondary_model: + return secondary_sources + + def demix_demucs(self, mix): + processed = {} + + set_progress_bar = None if self.is_chunk_demucs else self.set_progress_bar + + for nmix in mix: + self.progress_value += 1 + self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if self.is_chunk_demucs else None + cmix = mix[nmix] + cmix = torch.tensor(cmix, dtype=torch.float32) + ref = cmix.mean(0) + cmix = (cmix - ref.mean()) / ref.std() + mix_infer = cmix + + with torch.no_grad(): + if self.demucs_version == DEMUCS_V1: + sources = apply_model_v1(self.demucs, + mix_infer.to(self.device), + self.shifts, + self.is_split_mode, + set_progress_bar=set_progress_bar) + elif self.demucs_version == DEMUCS_V2: + sources = apply_model_v2(self.demucs, + mix_infer.to(self.device), + self.shifts, + self.is_split_mode, + self.overlap, + set_progress_bar=set_progress_bar) + else: + sources = apply_model(self.demucs, + mix_infer[None], + self.shifts, + self.is_split_mode, + self.overlap, + static_shifts=1 if self.shifts == 0 else self.shifts, + set_progress_bar=set_progress_bar, + device=self.device)[0] + + sources = (sources * ref.std() + ref.mean()).cpu().numpy() + sources[[0,1]] = sources[[1,0]] + start = 0 if nmix == 0 else self.margin_demucs + end = None if nmix == list(mix.keys())[::-1][0] else -self.margin_demucs + if self.margin_demucs == 0: + end = None + processed[nmix] = sources[:,:,start:end].copy() + sources = list(processed.values()) + sources = np.concatenate(sources, axis=-1) + + return sources + +class SeperateVR(SeperateAttributes): + + def seperate(self): + + if self.primary_model_name == self.model_basename and self.primary_sources: + self.primary_source, self.secondary_source = self.load_cached_sources() + else: + self.start_inference() + if self.is_gpu_conversion >= 0: + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device('cpu') + + nn_arch_sizes = [ + 31191, # default + 33966, 56817, 218409, 123821, 123812, 129605, 537238, 537227] + vr_5_1_models = [56817, 218409] + + model_size = math.ceil(os.stat(self.model_path).st_size / 1024) + nn_architecture = min(nn_arch_sizes, key=lambda x:abs(x-model_size)) + + #print('ARC SIZE: ', nn_architecture) + + if nn_architecture in vr_5_1_models: + model = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_architecture) + inference = self.inference_vr_new + else: + model = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_architecture) + inference = self.inference_vr + + model.load_state_dict(torch.load(self.model_path, map_location=device)) + model.to(device) + + self.running_inference() + + y_spec, v_spec = inference(self.loading_mix(), device, model, self.aggressiveness) + self.write_to_console(DONE, base_text='') + + del model + + if self.is_secondary_model_activated: + if self.secondary_model: + self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method) + + if not self.is_secondary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') + if not isinstance(self.primary_source, np.ndarray): + self.primary_source = spec_utils.normalize(self.spec_to_wav(y_spec), self.is_normalization).T + if not self.model_samplerate == 44100: + self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T + + self.primary_source_map = {self.primary_stem: self.primary_source} + + self.write_audio(primary_stem_path, self.primary_source, 44100, self.secondary_source_primary) + + if not self.is_primary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav') + if not isinstance(self.secondary_source, np.ndarray): + self.secondary_source = self.spec_to_wav(v_spec) + self.secondary_source = spec_utils.normalize(self.spec_to_wav(v_spec), self.is_normalization).T + if not self.model_samplerate == 44100: + self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T + + self.secondary_source_map = {self.secondary_stem: self.secondary_source} + + self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary) + + torch.cuda.empty_cache() + + secondary_sources = {**self.primary_source_map, **self.secondary_source_map} + self.cache_source(secondary_sources) + + if self.is_secondary_model: + return secondary_sources + + def loading_mix(self): + + X_wave, X_spec_s = {}, {} + + bands_n = len(self.mp.param['band']) + + for d in range(bands_n, 0, -1): + bp = self.mp.param['band'][d] + + if d == bands_n: # high-end band + X_wave[d], _ = librosa.load( + self.audio_file, bp['sr'], False, dtype=np.float32, res_type=bp['res_type']) + + if X_wave[d].ndim == 1: + X_wave[d] = np.asarray([X_wave[d], X_wave[d]]) + else: # lower bands + X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=bp['res_type']) + + X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], self.mp.param['mid_side'], + self.mp.param['mid_side_b2'], self.mp.param['reverse']) + + if d == bands_n and self.high_end_process != 'none': + self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start']) + self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :] + + X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp) + + del X_wave, X_spec_s + + return X_spec + + def inference_vr(self, X_spec, device, model, aggressiveness): + + def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness): + model.eval() + + total_iterations = sum([n_window]) if not self.is_tta else sum([n_window])*2 + + with torch.no_grad(): + preds = [] + + for i in range(n_window): + self.progress_value +=1 + self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value) + start = i * roi_size + X_mag_window = X_mag_pad[None, :, :, start:start + self.window_size] + X_mag_window = torch.from_numpy(X_mag_window).to(device) + pred = model.predict(X_mag_window, aggressiveness) + pred = pred.detach().cpu().numpy() + preds.append(pred[0]) + + pred = np.concatenate(preds, axis=2) + return pred + + X_mag, X_phase = spec_utils.preprocess(X_spec) + coef = X_mag.max() + X_mag_pre = X_mag / coef + n_frame = X_mag_pre.shape[2] + pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, model.offset) + n_window = int(np.ceil(n_frame / roi_size)) + X_mag_pad = np.pad( + X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + pred = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness) + pred = pred[:, :, :n_frame] + + if self.is_tta: + pad_l += roi_size // 2 + pad_r += roi_size // 2 + n_window += 1 + X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + pred_tta = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness) + pred_tta = pred_tta[:, :, roi_size // 2:] + pred_tta = pred_tta[:, :, :n_frame] + pred, X_mag, X_phase = (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.j * X_phase) + else: + pred, X_mag, X_phase = pred * coef, X_mag, np.exp(1.j * X_phase) + + if self.is_post_process: + pred_inv = np.clip(X_mag - pred, 0, np.inf) + pred = spec_utils.mask_silence(pred, pred_inv) + + y_spec = pred * X_phase + v_spec = X_spec - y_spec + + return y_spec, v_spec + + def inference_vr_new(self, X_spec, device, model, aggressiveness): + + def _execute(X_mag_pad, roi_size): + + X_dataset = [] + patches = (X_mag_pad.shape[2] - 2 * model.offset) // roi_size + total_iterations = patches//self.batch_size if not self.is_tta else (patches//self.batch_size)*2 + + for i in range(patches): + start = i * roi_size + X_mag_crop = X_mag_pad[:, :, start:start + self.crop_size] + X_dataset.append(X_mag_crop) + + X_dataset = np.asarray(X_dataset) + model.eval() + + with torch.no_grad(): + mask = [] + # To reduce the overhead, dataloader is not used. + for i in range(0, patches, self.batch_size): + self.progress_value += 1 + if self.progress_value >= total_iterations: + self.progress_value = total_iterations + self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value) + X_batch = X_dataset[i: i + self.batch_size] + X_batch = torch.from_numpy(X_batch).to(device) + pred = model.predict_mask(X_batch) + pred = pred.detach().cpu().numpy() + pred = np.concatenate(pred, axis=2) + mask.append(pred) + + mask = np.concatenate(mask, axis=2) + + return mask + + def postprocess(mask, X_mag, X_phase, aggressiveness): + + if self.primary_stem == VOCAL_STEM: + mask = (1.0 - spec_utils.adjust_aggr(mask, True, aggressiveness)) + else: + mask = spec_utils.adjust_aggr(mask, False, aggressiveness) + + if self.is_post_process: + mask = spec_utils.merge_artifacts(mask) + + y_spec = mask * X_mag * np.exp(1.j * X_phase) + v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase) + + return y_spec, v_spec + + X_mag, X_phase = spec_utils.preprocess(X_spec) + n_frame = X_mag.shape[2] + pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.crop_size, model.offset) + X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + X_mag_pad /= X_mag_pad.max() + mask = _execute(X_mag_pad, roi_size) + + if self.is_tta: + pad_l += roi_size // 2 + pad_r += roi_size // 2 + X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + X_mag_pad /= X_mag_pad.max() + mask_tta = _execute(X_mag_pad, roi_size) + mask_tta = mask_tta[:, :, roi_size // 2:] + mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5 + else: + mask = mask[:, :, :n_frame] + + y_spec, v_spec = postprocess(mask, X_mag, X_phase, aggressiveness) + + return y_spec, v_spec + + def spec_to_wav(self, spec): + + if self.high_end_process.startswith('mirroring'): + input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp) + wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_) + else: + wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp) + + return wav + +def process_secondary_model(secondary_model: ModelData, process_data, main_model_primary_stem_4_stem=None, is_4_stem_demucs=False, main_process_method=None, is_pre_proc_model=False): + + if not is_pre_proc_model: + process_iteration = process_data['process_iteration'] + process_iteration() + + if secondary_model.process_method == VR_ARCH_TYPE: + seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method) + if secondary_model.process_method == MDX_ARCH_TYPE: + seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method) + if secondary_model.process_method == DEMUCS_ARCH_TYPE: + seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method) + + secondary_sources = seperator.seperate() + + if type(secondary_sources) is dict and not is_4_stem_demucs and not is_pre_proc_model: + return gather_sources(secondary_model.primary_model_primary_stem, STEM_PAIR_MAPPER[secondary_model.primary_model_primary_stem], secondary_sources) + else: + return secondary_sources + +def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict): + + source_primary = False + source_secondary = False + + for key, value in secondary_sources.items(): + if key in primary_stem_name: + source_primary = value + if key in secondary_stem_name: + source_secondary = value + + return source_primary, source_secondary + +def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=False): + + samplerate = 44100 + + if not isinstance(mix, np.ndarray): + mix, samplerate = librosa.load(mix, mono=False, sr=44100) + else: + mix = mix.T + + if mix.ndim == 1: + mix = np.asfortranarray([mix,mix]) + + def get_segmented_mix(chunk_set=chunk_set): + segmented_mix = {} + + samples = mix.shape[-1] + margin = margin_set + chunk_size = chunk_set*44100 + assert not margin == 0, 'margin cannot be zero!' + if margin > chunk_size: + margin = chunk_size + if chunk_set == 0 or samples < chunk_size: + chunk_size = samples + + counter = -1 + for skip in range(0, samples, chunk_size): + counter+=1 + s_margin = 0 if counter == 0 else margin + end = min(skip+chunk_size+margin, samples) + start = skip-s_margin + segmented_mix[skip] = mix[:,start:end].copy() + if end == samples: + break + + return segmented_mix + + if is_missing_mix: + return mix + else: + segmented_mix = get_segmented_mix() + raw_mix = get_segmented_mix(chunk_set=0) if mdx_net_cut else mix + return segmented_mix, raw_mix, samplerate + +def save_format(audio_path, save_format, mp3_bit_set): + + if not save_format == WAV: + + musfile = pydub.AudioSegment.from_wav(audio_path) + + if save_format == FLAC: + audio_path_flac = audio_path.replace(".wav", ".flac") + musfile.export(audio_path_flac, format="flac") + + if save_format == MP3: + audio_path_mp3 = audio_path.replace(".wav", ".mp3") + musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set) + + try: + os.remove(audio_path) + except Exception as e: + print(e) \ No newline at end of file