Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Custom Output File Naming for MDX, MDXC, and VR Models #141

Merged
merged 6 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions audio_separator/separator/architectures/mdx_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ def load_model(self):
self.model_run.to(self.torch_device).eval()
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.

Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -182,15 +184,22 @@ def separate(self, audio_file_path):

# Save and process the secondary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)

# Save and process the primary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T

Expand Down
14 changes: 11 additions & 3 deletions audio_separator/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def load_model(self):
self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
sys.exit(1)

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.

Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -152,14 +154,20 @@ def separate(self, audio_file_path):
self.secondary_source = spec_utils.normalize(wave=source[self.secondary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)

if not isinstance(source, dict) or not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
Expand Down
14 changes: 11 additions & 3 deletions audio_separator/separator/architectures/vr_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def __init__(self, common_config, arch_config: dict):

self.logger.info("VR Separator initialisation complete")

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.

Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -195,7 +197,10 @@ def separate(self, audio_file_path):
self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.logger.debug("Resampling primary source to 44100Hz.")

self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
Expand All @@ -213,7 +218,10 @@ def separate(self, audio_file_path):
self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.logger.debug("Resampling secondary source to 44100Hz.")

self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
Expand Down
8 changes: 5 additions & 3 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
self.logger.debug("Loading model completed.")
self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into different stems (e.g., vocals, instruments) using the loaded model.

Expand All @@ -747,6 +747,8 @@ def separate(self, audio_file_path):

Parameters:
- audio_file_path (str): The path to the audio file to be separated.
- primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
- secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.

Returns:
- output_files (list of str): A list containing the paths to the separated audio stem files.
Expand All @@ -766,10 +768,10 @@ def separate(self, audio_file_path):
if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
self.logger.debug("Autocast available.")
with autocast_mode.autocast(self.torch_device.type):
output_files = self.model_instance.separate(audio_file_path)
output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name)
else:
self.logger.debug("Autocast unavailable.")
output_files = self.model_instance.separate(audio_file_path)
output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name)

# Clear GPU cache to free up memory
self.model_instance.clear_gpu_cache()
Expand Down
6 changes: 5 additions & 1 deletion audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def main():
sample_rate_help = "modify the sample rate of the output audio (default: %(default)s). Example: --sample_rate=44100"
use_soundfile_help = "Use soundfile to write audio output (default: %(default)s). Example: --use_soundfile"
use_autocast_help = "use PyTorch autocast for faster inference (default: %(default)s). Do not use for CPU inference. Example: --use_autocast"
primary_output_name_help = "Custom name for the primary output file (default: %(default)s). Example: --primary_output_name=custom_primary_output"
secondary_output_name_help = "Custom name for the secondary output file (default: %(default)s). Example: --secondary_output_name=custom_secondary_output"

common_params = parser.add_argument_group("Common Separation Parameters")
common_params.add_argument("--invert_spect", action="store_true", help=invert_spect_help)
Expand All @@ -65,6 +67,8 @@ def main():
common_params.add_argument("--sample_rate", type=int, default=44100, help=sample_rate_help)
common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help)
common_params.add_argument("--use_autocast", action="store_true", help=use_autocast_help)
common_params.add_argument("--primary_output_name", default=None, help=primary_output_name_help)
common_params.add_argument("--secondary_output_name", default=None, help=secondary_output_name_help)

mdx_segment_size_help = "larger consumes more resources, but may give better results (default: %(default)s). Example: --mdx_segment_size=256"
mdx_overlap_help = "amount of overlap between prediction windows, 0.001-0.999. higher is better but slower (default: %(default)s). Example: --mdx_overlap=0.25"
Expand Down Expand Up @@ -201,6 +205,6 @@ def main():
separator.load_model(model_filename=args.model_filename)

for audio_file in args.audio_files:
output_files = separator.separate(audio_file)
output_files = separator.separate(audio_file, primary_output_name=args.primary_output_name, secondary_output_name=args.secondary_output_name)

logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")
44 changes: 44 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def test_cli_invert_spectrogram_argument(common_expected_args):
# Assertions
mock_separator.assert_called_once_with(**expected_args)


# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
Expand All @@ -240,6 +241,7 @@ def test_cli_use_autocast_argument(common_expected_args):
# Assertions
mock_separator.assert_called_once_with(**common_expected_args)


# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
Expand All @@ -254,3 +256,45 @@ def test_cli_use_autocast_argument(common_expected_args):

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)


# Test using primary_output_name argument
def test_cli_primary_output_name_argument(common_expected_args):
    """CLI forwards --primary_output_name to Separator.separate; secondary stays None."""
    argv = ["cli.py", "test_audio.mp3", "--primary_output_name=custom_primary_output"]
    with patch("sys.argv", argv):
        with patch("audio_separator.separator.Separator") as mock_separator:
            separator_instance = mock_separator.return_value
            separator_instance.separate.return_value = ["output_file.mp3"]
            main()

            # Constructor receives only the common defaults...
            mock_separator.assert_called_once_with(**common_expected_args)
            # ...while the custom primary name is passed through to separate().
            separator_instance.separate.assert_called_once_with(
                "test_audio.mp3",
                primary_output_name="custom_primary_output",
                secondary_output_name=None,
            )


# Test using secondary_output_name argument
def test_cli_secondary_output_name_argument(common_expected_args):
    """CLI forwards --secondary_output_name to Separator.separate; primary stays None."""
    argv = ["cli.py", "test_audio.mp3", "--secondary_output_name=custom_secondary_output"]
    with patch("sys.argv", argv):
        with patch("audio_separator.separator.Separator") as mock_separator:
            separator_instance = mock_separator.return_value
            separator_instance.separate.return_value = ["output_file.mp3"]
            main()

            # Constructor receives only the common defaults...
            mock_separator.assert_called_once_with(**common_expected_args)
            # ...while the custom secondary name is passed through to separate().
            separator_instance.separate.assert_called_once_with(
                "test_audio.mp3",
                primary_output_name=None,
                secondary_output_name="custom_secondary_output",
            )


# Test using both primary_output_name and secondary_output_name arguments
def test_cli_both_output_names_argument(common_expected_args):
    """CLI forwards both custom output names to Separator.separate together."""
    argv = [
        "cli.py",
        "test_audio.mp3",
        "--primary_output_name=custom_primary_output",
        "--secondary_output_name=custom_secondary_output",
    ]
    with patch("sys.argv", argv):
        with patch("audio_separator.separator.Separator") as mock_separator:
            separator_instance = mock_separator.return_value
            separator_instance.separate.return_value = ["output_file.mp3"]
            main()

            # Constructor receives only the common defaults...
            mock_separator.assert_called_once_with(**common_expected_args)
            # ...while both custom names are passed through to separate().
            separator_instance.separate.assert_called_once_with(
                "test_audio.mp3",
                primary_output_name="custom_primary_output",
                secondary_output_name="custom_secondary_output",
            )
Loading