Skip to content

Commit

Permalink
Improved input of custom stem names (now for all architectures) (#159)
Browse files Browse the repository at this point in the history
* Update common_separator.py

* Update separator.py

* Update demucs_separator.py

* Update vr_separator.py

* Update mdxc_separator.py

* Update mdx_separator.py

* Update cli.py

* Update test_cli.py

* Update test_cli.py

* Update test_cli.py

* Update cli.py

* Update test_cli.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Minor update of README

* Update README.md
  • Loading branch information
Bebra777228 authored Dec 8, 2024
1 parent 62d8b00 commit 4410feb
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 170 deletions.
209 changes: 134 additions & 75 deletions README.md

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions audio_separator/separator/architectures/demucs_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,16 @@ def __init__(self, common_config, arch_config):

self.logger.info("Demucs Separator initialisation complete")

def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
def separate(self, audio_file_path, custom_output_names=None):
"""
Separates the audio file into its component stems using the Demucs model.
Args:
audio_file_path (str): The path to the audio file to be processed.
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
"""
self.logger.debug("Starting separation process...")
source = None
Expand Down Expand Up @@ -144,7 +151,7 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n
self.logger.debug(f"Skipping writing stem {stem_name} as output_single_stem is set to {self.output_single_stem}...")
continue

stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
stem_path = self.get_stem_output_path(stem_name, custom_output_names)
stem_source = source[stem_value].T

self.final_process(stem_path, stem_source, stem_name)
Expand Down
15 changes: 4 additions & 11 deletions audio_separator/separator/architectures/mdx_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,14 @@ def load_model(self):
self.model_run.to(self.torch_device).eval()
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")

def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
def separate(self, audio_file_path, custom_output_names=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -184,21 +183,15 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n

# Save and process the secondary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)

# Save and process the primary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
Expand Down
15 changes: 4 additions & 11 deletions audio_separator/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,14 @@ def load_model(self):
self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
sys.exit(1)

def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
def separate(self, audio_file_path, custom_output_names=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -154,20 +153,14 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n
self.secondary_source = spec_utils.normalize(wave=source[self.secondary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)

if not isinstance(source, dict) or not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
Expand Down
15 changes: 4 additions & 11 deletions audio_separator/separator/architectures/vr_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,14 @@ def __init__(self, common_config, arch_config: dict):

self.logger.info("VR Separator initialisation complete")

def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
def separate(self, audio_file_path, custom_output_names=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -197,10 +196,7 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n
self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.logger.debug("Resampling primary source to 44100Hz.")

if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
Expand All @@ -218,10 +214,7 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n
self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.logger.debug("Resampling secondary source to 44100Hz.")

if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
Expand Down
9 changes: 9 additions & 0 deletions audio_separator/separator/common_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,12 @@ def clear_file_specific_paths(self):

self.primary_stem_output_path = None
self.secondary_stem_output_path = None

def get_stem_output_path(self, stem_name, custom_output_names):
"""
Gets the output path for a stem based on the stem name and custom output names.
"""
if custom_output_names and stem_name in custom_output_names:
return os.path.join(f"{custom_output_names[stem_name]}.{self.output_format.lower()}")
else:
return os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
13 changes: 6 additions & 7 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
self.logger.debug("Loading model completed.")
self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')

def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
def separate(self, audio_file_path, custom_output_names=None):
"""
Separates the audio file into different stems (e.g., vocals, instruments) using the loaded model.
Expand All @@ -747,8 +747,7 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n
Parameters:
- audio_file_path (str): The path to the audio file to be separated.
- primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
- secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
- custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
Returns:
- output_files (list of str): A list containing the paths to the separated audio stem files.
Expand All @@ -760,18 +759,18 @@ def separate(self, audio_file_path, primary_output_name=None, secondary_output_n
self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
separate_start_time = time.perf_counter()

self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will lowered to this max amplitude to avoid clipping.")
self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will scaled up to this max amplitude if below it.")
self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")

# Run separation method for the loaded model with autocast enabled if supported by the device.
output_files = None
if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
self.logger.debug("Autocast available.")
with autocast_mode.autocast(self.torch_device.type):
output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name)
output_files = self.model_instance.separate(audio_file_path, custom_output_names)
else:
self.logger.debug("Autocast unavailable.")
output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name)
output_files = self.model_instance.separate(audio_file_path, custom_output_names)

# Clear GPU cache to free up memory
self.model_instance.clear_gpu_cache()
Expand Down
Loading

0 comments on commit 4410feb

Please sign in to comment.