From f5606429d33e8a2a506672542c60b9c95578bcf4 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:27:20 +0500 Subject: [PATCH 1/6] Update mdx_separator.py --- .../separator/architectures/mdx_separator.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/audio_separator/separator/architectures/mdx_separator.py b/audio_separator/separator/architectures/mdx_separator.py index babaa67..6f43a57 100644 --- a/audio_separator/separator/architectures/mdx_separator.py +++ b/audio_separator/separator/architectures/mdx_separator.py @@ -132,13 +132,15 @@ def load_model(self): self.model_run.to(self.torch_device).eval() self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.") - def separate(self, audio_file_path): + def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None): """ Separates the audio file into primary and secondary sources based on the model's configuration. It processes the mix, demixes it into sources, normalizes the sources, and saves the output files. Args: audio_file_path (str): The path to the audio file to be processed. + primary_output_name (str, optional): Custom name for the primary output file. Defaults to None. + secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None. Returns: list: A list of paths to the output files generated by the separation process. @@ -182,7 +184,10 @@ def separate(self, audio_file_path): # Save and process the secondary stem if needed if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower(): - self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if secondary_output_name: + self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}") + else: + self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}") self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...") self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name) @@ -190,7 +195,11 @@ def separate(self, audio_file_path): # Save and process the primary stem if needed if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower(): - self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if primary_output_name: + self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}") + else: + self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T From f32c3b7ebfe45811d538cbf0e510924390a82b25 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:27:52 +0500 Subject: [PATCH 2/6] Update mdxc_separator.py --- .../separator/architectures/mdxc_separator.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/audio_separator/separator/architectures/mdxc_separator.py b/audio_separator/separator/architectures/mdxc_separator.py index 459a4ae..009ea7e 100644 --- a/audio_separator/separator/architectures/mdxc_separator.py +++ b/audio_separator/separator/architectures/mdxc_separator.py @@ -111,13 +111,15 @@ def load_model(self): self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.") sys.exit(1) - def separate(self, audio_file_path): + def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None): """ Separates the audio file into primary and secondary sources based on the model's configuration. It processes the mix, demixes it into sources, normalizes the sources, and saves the output files. Args: audio_file_path (str): The path to the audio file to be processed. + primary_output_name (str, optional): Custom name for the primary output file. Defaults to None. + secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None. Returns: list: A list of paths to the output files generated by the separation process. @@ -152,14 +154,20 @@ def separate(self, audio_file_path): self.secondary_source = spec_utils.normalize(wave=source[self.secondary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower(): - self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if secondary_output_name: + self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}") + else: + self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}") self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...") self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name) output_files.append(self.secondary_stem_output_path) if not isinstance(source, dict) or not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower(): - self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if primary_output_name: + self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}") + else: + self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}") if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T From 59b38d720aea871421db1953b1b059045ac0ca54 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:28:27 +0500 Subject: [PATCH 3/6] Update vr_separator.py --- .../separator/architectures/vr_separator.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/audio_separator/separator/architectures/vr_separator.py b/audio_separator/separator/architectures/vr_separator.py index bbf16c8..1820d78 100644 --- a/audio_separator/separator/architectures/vr_separator.py +++ b/audio_separator/separator/architectures/vr_separator.py @@ -110,13 +110,15 @@ def __init__(self, common_config, arch_config: dict): self.logger.info("VR Separator initialisation complete") - def separate(self, audio_file_path): + def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None): """ Separates the audio file into primary and secondary sources based on the model's configuration. It processes the mix, demixes it into sources, normalizes the sources, and saves the output files. Args: audio_file_path (str): The path to the audio file to be processed. + primary_output_name (str, optional): Custom name for the primary output file. Defaults to None. + secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None. Returns: list: A list of paths to the output files generated by the separation process. @@ -195,7 +197,10 @@ def separate(self, audio_file_path): self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T self.logger.debug("Resampling primary source to 44100Hz.") - self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if primary_output_name: + self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}") + else: + self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}") self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...") self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name) @@ -213,7 +218,10 @@ def separate(self, audio_file_path): self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T self.logger.debug("Resampling secondary source to 44100Hz.") - self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}") + if secondary_output_name: + self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}") + else: + self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}") self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...") self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name) From e5538be42c34c3282a6cfe91e1f412fb7a04e540 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:29:07 +0500 Subject: [PATCH 4/6] Update separator.py --- audio_separator/separator/separator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py index d92d354..3e1d9ca 100644 --- a/audio_separator/separator/separator.py +++ b/audio_separator/separator/separator.py @@ -737,7 +737,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360 self.logger.debug("Loading model completed.") self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}') - def separate(self, audio_file_path): + def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None): """ Separates the audio file into different stems (e.g., vocals, instruments) using the loaded model. @@ -747,6 +747,8 @@ def separate(self, audio_file_path): Parameters: - audio_file_path (str): The path to the audio file to be separated. + - primary_output_name (str, optional): Custom name for the primary output file. Defaults to None. + - secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None. Returns: - output_files (list of str): A list containing the paths to the separated audio stem files. @@ -766,10 +768,10 @@ def separate(self, audio_file_path): if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type): self.logger.debug("Autocast available.") with autocast_mode.autocast(self.torch_device.type): - output_files = self.model_instance.separate(audio_file_path) + output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name) else: self.logger.debug("Autocast unavailable.") - output_files = self.model_instance.separate(audio_file_path) + output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name) # Clear GPU cache to free up memory self.model_instance.clear_gpu_cache() From cc54f26f76d049b115d1b1d8c506ab646fb21c28 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:29:47 +0500 Subject: [PATCH 5/6] Update cli.py --- audio_separator/utils/cli.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/audio_separator/utils/cli.py b/audio_separator/utils/cli.py index 2030b9c..548c965 100755 --- a/audio_separator/utils/cli.py +++ b/audio_separator/utils/cli.py @@ -56,6 +56,8 @@ def main(): sample_rate_help = "modify the sample rate of the output audio (default: %(default)s). Example: --sample_rate=44100" use_soundfile_help = "Use soundfile to write audio output (default: %(default)s). Example: --use_soundfile" use_autocast_help = "use PyTorch autocast for faster inference (default: %(default)s). Do not use for CPU inference. Example: --use_autocast" + primary_output_name_help = "Custom name for the primary output file (default: %(default)s). Example: --primary_output_name=custom_primary_output" + secondary_output_name_help = "Custom name for the secondary output file (default: %(default)s). Example: --secondary_output_name=custom_secondary_output" common_params = parser.add_argument_group("Common Separation Parameters") common_params.add_argument("--invert_spect", action="store_true", help=invert_spect_help) @@ -65,6 +67,8 @@ def main(): common_params.add_argument("--sample_rate", type=int, default=44100, help=sample_rate_help) common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help) common_params.add_argument("--use_autocast", action="store_true", help=use_autocast_help) + common_params.add_argument("--primary_output_name", default=None, help=primary_output_name_help) + common_params.add_argument("--secondary_output_name", default=None, help=secondary_output_name_help) mdx_segment_size_help = "larger consumes more resources, but may give better results (default: %(default)s). Example: --mdx_segment_size=256" mdx_overlap_help = "amount of overlap between prediction windows, 0.001-0.999. higher is better but slower (default: %(default)s). Example: --mdx_overlap=0.25" @@ -201,6 +205,6 @@ def main(): separator.load_model(model_filename=args.model_filename) for audio_file in args.audio_files: - output_files = separator.separate(audio_file) + output_files = separator.separate(audio_file, primary_output_name=args.primary_output_name, secondary_output_name=args.secondary_output_name) logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}") From 966a434adbef62e7943bb078328d3bccb6423d38 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:31:07 +0500 Subject: [PATCH 6/6] Update test_cli.py --- tests/unit/test_cli.py | 44 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 3cc7a28..56b66d6 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -224,6 +224,7 @@ def test_cli_invert_spectrogram_argument(common_expected_args): # Assertions mock_separator.assert_called_once_with(**expected_args) + # Test using use_autocast argument def test_cli_use_autocast_argument(common_expected_args): test_args = ["cli.py", "test_audio.mp3", "--use_autocast"] @@ -240,6 +241,7 @@ def test_cli_use_autocast_argument(common_expected_args): # Assertions mock_separator.assert_called_once_with(**common_expected_args) + # Test using use_autocast argument def test_cli_use_autocast_argument(common_expected_args): test_args = ["cli.py", "test_audio.mp3", "--use_autocast"] @@ -254,3 +256,45 @@ def test_cli_use_autocast_argument(common_expected_args): # Assertions mock_separator.assert_called_once_with(**common_expected_args) + + +# Test using primary_output_name argument +def test_cli_primary_output_name_argument(common_expected_args): + test_args = ["cli.py", "test_audio.mp3", "--primary_output_name=custom_primary_output"] + with patch("sys.argv", test_args): + with patch("audio_separator.separator.Separator") as mock_separator: + mock_separator_instance = mock_separator.return_value + mock_separator_instance.separate.return_value = ["output_file.mp3"] + main() + + # Assertions + mock_separator.assert_called_once_with(**common_expected_args) + mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", primary_output_name="custom_primary_output", secondary_output_name=None) + + +# Test using secondary_output_name argument +def test_cli_secondary_output_name_argument(common_expected_args): + test_args = ["cli.py", "test_audio.mp3", "--secondary_output_name=custom_secondary_output"] + with patch("sys.argv", test_args): + with patch("audio_separator.separator.Separator") as mock_separator: + mock_separator_instance = mock_separator.return_value + mock_separator_instance.separate.return_value = ["output_file.mp3"] + main() + + # Assertions + mock_separator.assert_called_once_with(**common_expected_args) + mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", primary_output_name=None, secondary_output_name="custom_secondary_output") + + +# Test using both primary_output_name and secondary_output_name arguments +def test_cli_both_output_names_argument(common_expected_args): + test_args = ["cli.py", "test_audio.mp3", "--primary_output_name=custom_primary_output", "--secondary_output_name=custom_secondary_output"] + with patch("sys.argv", test_args): + with patch("audio_separator.separator.Separator") as mock_separator: + mock_separator_instance = mock_separator.return_value + mock_separator_instance.separate.return_value = ["output_file.mp3"] + main() + + # Assertions + mock_separator.assert_called_once_with(**common_expected_args) + mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", primary_output_name="custom_primary_output", secondary_output_name="custom_secondary_output")