Skip to content

Commit

Permalink
Merge pull request #28 from karaokenerds/load-model-once-for-repeated-use
Browse files Browse the repository at this point in the history

Batch processing: Separated model load and audio file for separation from init
  • Loading branch information
beveradb authored Jan 1, 2024
2 parents 932fa2a + 78f2d8e commit 93c58d3
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 51 deletions.
44 changes: 38 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,27 +145,59 @@ This command will process the file and generate two new files in the current dir
### As a Dependency in a Python Project
You can also use Audio Separator in your Python project. Here's how you can use it:
You can use Audio Separator in your own Python project. Here's how you can use it:

```
from audio_separator.separator import Separator

# Initialize the Separator with the audio file and model name
separator = Separator('/path/to/your/audio.m4a', model_name='UVR_MDXNET_KARA_2')
# Initialize the Separator class (with optional configuration properties below)
separator = Separator()

# Perform the separation
primary_stem_path, secondary_stem_path = separator.separate()
# Load a machine learning model (if unspecified, defaults to 'UVR-MDX-NET-Inst_HQ_3')
separator.load_model()

# Perform the separation on specific audio files without reloading the model
primary_stem_path, secondary_stem_path = separator.separate('audio1.wav')

print(f'Primary stem saved at {primary_stem_path}')
print(f'Secondary stem saved at {secondary_stem_path}')
```
#### Batch processing, or processing with multiple models
You can process multiple separations without reloading the model, to save time and memory.
You only need to load a model when choosing or changing models. See example below:
```
from audio_separator.separator import Separator

# Initialize the Separator class (with optional configuration properties, listed below)
separator = Separator()

# Load a model
separator.load_model('UVR-MDX-NET-Inst_HQ_3')

# Separate multiple audio files without reloading the model
output_file_paths_1 = separator.separate('audio1.wav')
output_file_paths_2 = separator.separate('audio2.wav')
output_file_paths_3 = separator.separate('audio3.wav')

# Load a different model
separator.load_model('UVR_MDXNET_KARA_2')

# Separate the same files with the new model
output_file_paths_4 = separator.separate('audio1.wav')
output_file_paths_5 = separator.separate('audio2.wav')
output_file_paths_6 = separator.separate('audio3.wav')
```
## Parameters for the Separator class
- audio_file: The path to the audio file to be separated, passed to separate() rather than to the Separator constructor. Supports all common formats (WAV, MP3, FLAC, M4A, etc.)
- log_level: (Optional) Logging level, e.g. info, debug, warning. Default: INFO
- log_formatter: (Optional) The log format. Default: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- model_name: (Optional) The name of the model to use for separation. Defaults to 'UVR_MDXNET_KARA_2', a very powerful model for Karaoke instrumental tracks.
- model_name: (Optional) The name of the model to use for separation, passed to load_model() rather than to the Separator constructor. Defaults to 'UVR-MDX-NET-Inst_HQ_3', a very powerful model for Karaoke instrumental tracks.
- model_file_dir: (Optional) Directory to cache model files in. Default: /tmp/audio-separator-models/
- output_dir: (Optional) Directory where the separated files will be saved. If not specified, outputs to current dir.
- output_format: (Optional) Format to encode output files, any common format (WAV, MP3, FLAC, M4A, etc.). Default: WAV
Expand Down
94 changes: 56 additions & 38 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
class Separator:
def __init__(
self,
audio_file_path,
log_level=logging.DEBUG,
log_formatter=None,
model_name="UVR_MDXNET_KARA_2",
model_file_dir="/tmp/audio-separator-models/",
output_dir=None,
primary_stem_path=None,
Expand Down Expand Up @@ -55,11 +53,8 @@ def __init__(
if not self.logger.hasHandlers():
self.logger.addHandler(self.log_handler)

self.logger.info(
f"Separator instantiating with input file: {audio_file_path}, model_name: {model_name}, output_dir: {output_dir}, output_format: {output_format}"
)
self.logger.info(f"Separator instantiating with output_dir: {output_dir}, output_format: {output_format}")

self.model_name = model_name
self.model_file_dir = model_file_dir
self.output_dir = output_dir
self.primary_stem_path = primary_stem_path
Expand All @@ -68,13 +63,6 @@ def __init__(
# Create the model directory if it does not exist
os.makedirs(self.model_file_dir, exist_ok=True)

self.audio_file_path = audio_file_path
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

self.model_name = model_name
self.model_url = f"https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/{self.model_name}.onnx"
self.model_data_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json"

self.output_subtype = output_subtype
self.output_format = output_format

Expand Down Expand Up @@ -119,9 +107,6 @@ def __init__(
f"Separation settings set: sample_rate={self.sample_rate}, hop_length={self.hop_length}, segment_size={self.segment_size}, overlap={self.overlap}, batch_size={self.batch_size}"
)

self.primary_source = None
self.secondary_source = None

warnings.filterwarnings("ignore")
self.cpu = torch.device("cpu")

Expand Down Expand Up @@ -204,10 +189,14 @@ def clear_gpu_cache(self):
self.logger.debug("Clearing CUDA cache...")
torch.cuda.empty_cache()

def separate(self):
# Starting the separation process
self.logger.debug("Starting separation process...")
self.separate_start_time = time.perf_counter()
def load_model(self, model_name="UVR-MDX-NET-Inst_HQ_3"):
self.logger.info(f"Loading model {model_name}...")

self.load_model_start_time = time.perf_counter()

self.model_name = model_name
self.model_url = f"https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/{self.model_name}.onnx"
self.model_data_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json"

# Setting up the model path
model_path = os.path.join(self.model_file_dir, f"{self.model_name}.onnx")
Expand Down Expand Up @@ -237,18 +226,18 @@ def separate(self):
self.logger.debug(f"Model data loaded: {model_data}")

# Initializing model parameters
self.compensate, self.dim_f, self.dim_t, self.n_fft, self.primary_stem = (
self.compensate, self.dim_f, self.dim_t, self.n_fft, self.model_primary_stem = (
model_data["compensate"],
model_data["mdx_dim_f_set"],
2 ** model_data["mdx_dim_t_set"],
model_data["mdx_n_fft_scale_set"],
model_data["primary_stem"],
)
self.secondary_stem = "Vocals" if self.primary_stem == "Instrumental" else "Instrumental"
self.model_secondary_stem = "Vocals" if self.model_primary_stem == "Instrumental" else "Instrumental"

# In UVR, these variables are set but either aren't useful or are better handled in audio-separator.
# These explanatory comments are kept to help myself or future developers understand why they aren't in audio-separator.

# "chunks" is not actually used for anything in UVR...
# self.chunks = 0

Expand All @@ -261,17 +250,15 @@ def separate(self):
# "margin" maps to sample rate and is set from the GUI in UVR (default: 44100). We have a "sample_rate" parameter instead.
# self.margin = 44100

# "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
# "dim_c" is hard-coded to 4 in UVR, seems to be a parameter for the number of channels, and is only used for checkpoint models.
# We haven't implemented support for the checkpoint models here, so we're not using it.
# self.dim_c = 4

self.logger.debug(f"Model params: primary_stem={self.primary_stem}, secondary_stem={self.secondary_stem}")
self.logger.debug(f"Model params: primary_stem={self.model_primary_stem}, secondary_stem={self.model_secondary_stem}")
self.logger.debug(
f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}"
)
self.logger.debug(
f"Model params: n_fft={self.n_fft}, hop={self.hop_length}"
)
self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")

# Loading the model for inference
self.logger.debug("Loading ONNX model for inference...")
Expand All @@ -284,6 +271,23 @@ def separate(self):
self.model_run.to(self.device).eval()
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")

# Log the completion of the model loading process
self.logger.debug("Loading model completed.")
self.logger.info(
f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - self.load_model_start_time)))}'
)

def separate(self, audio_file_path):
# Starting the separation process
self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
self.separate_start_time = time.perf_counter()

self.primary_source = None
self.secondary_source = None

self.audio_file_path = audio_file_path
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

# Prepare the mix for processing
self.logger.debug("Preparing mix...")
mix = self.prepare_mix(self.audio_file_path)
Expand Down Expand Up @@ -315,27 +319,27 @@ def separate(self):
self.secondary_source = mix.T - source.T

# Save and process the secondary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem.lower():
self.logger.info(f"Saving {self.secondary_stem} stem...")
if not self.output_single_stem or self.output_single_stem.lower() == self.model_secondary_stem.lower():
self.logger.info(f"Saving {self.model_secondary_stem} stem...")
if not self.secondary_stem_path:
self.secondary_stem_path = os.path.join(
f"{self.audio_file_base}_({self.secondary_stem})_{self.model_name}.{self.output_format.lower()}"
f"{self.audio_file_base}_({self.model_secondary_stem})_{self.model_name}.{self.output_format.lower()}"
)
self.secondary_source_map = self.final_process(
self.secondary_stem_path, self.secondary_source, self.secondary_stem, self.sample_rate
self.secondary_stem_path, self.secondary_source, self.model_secondary_stem, self.sample_rate
)
output_files.append(self.secondary_stem_path)

# Save and process the primary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem.lower():
self.logger.info(f"Saving {self.primary_stem} stem...")
if not self.output_single_stem or self.output_single_stem.lower() == self.model_primary_stem.lower():
self.logger.info(f"Saving {self.model_primary_stem} stem...")
if not self.primary_stem_path:
self.primary_stem_path = os.path.join(
f"{self.audio_file_base}_({self.primary_stem})_{self.model_name}.{self.output_format.lower()}"
f"{self.audio_file_base}_({self.model_primary_stem})_{self.model_name}.{self.output_format.lower()}"
)
if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.primary_stem, self.sample_rate)
self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate)
output_files.append(self.primary_stem_path)

# Clear GPU cache to free up memory
Expand All @@ -345,7 +349,21 @@ def separate(self):

# Log the completion of the separation process
self.logger.debug("Separation process completed.")
self.logger.info(f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - self.separate_start_time)))}')
self.logger.info(
f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - self.separate_start_time)))}'
)

# Unset the audio file to prevent accidental re-separation of the same file
self.logger.debug("Clearing audio file...")
self.audio_file_path = None
self.audio_file_base = None

# Unset more separation params to prevent accidentally re-using the wrong source files or output paths
self.logger.debug("Clearing sources and stems...")
self.primary_source = None
self.secondary_source = None
self.primary_stem_path = None
self.secondary_stem_path = None

return output_files

Expand Down Expand Up @@ -482,7 +500,7 @@ def initialize_mix(self, mix, is_ckpt=False):
return mix_waves_tensor, pad

def demix(self, mix, is_match_mix=False):
self.logger.info(f"Starting demixing process with is_match_mix: {is_match_mix}...")
self.logger.debug(f"Starting demixing process with is_match_mix: {is_match_mix}...")
self.initialize_model_settings()

# Preserves the original mix for later use.
Expand Down
13 changes: 7 additions & 6 deletions audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def main():

parser.add_argument(
"--model_name",
default="UVR_MDXNET_KARA_2",
help="Optional: model name to be used for separation (default: %(default)s). Example: --model_name=UVR-MDX-NET-Inst_HQ_3",
default="UVR-MDX-NET-Inst_HQ_3",
help="Optional: model name to be used for separation (default: %(default)s). Example: --model_name=UVR_MDXNET_KARA_2",
)

parser.add_argument(
Expand Down Expand Up @@ -124,14 +124,12 @@ def main():

logger.info(f"Separator beginning with input file: {args.audio_file}")

# Deliberately import here to avoid loading heave dependencies when just running --help
# Deliberately import here to avoid loading slow dependencies when just running --help
from audio_separator.separator import Separator

separator = Separator(
args.audio_file,
log_formatter=log_formatter,
log_level=log_level,
model_name=args.model_name,
model_file_dir=args.model_file_dir,
output_dir=args.output_dir,
output_format=args.output_format,
Expand All @@ -145,7 +143,10 @@ def main():
overlap=args.overlap,
batch_size=args.batch_size,
)
output_files = separator.separate()

separator.load_model(args.model_name)

output_files = separator.separate(args.audio_file)

logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "audio-separator"
version = "0.11.7"
version = "0.12.0"
description = "Easy to use vocal separation, using MDX-Net models from UVR trained by @Anjok07"
authors = ["Andrew Beveridge <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 93c58d3

Please sign in to comment.