Skip to content

Commit

Permalink
Merge branch 'main' into custom_stem_name
Browse files Browse the repository at this point in the history
  • Loading branch information
Bebra777228 authored Dec 8, 2024
2 parents 8d14f1b + 62d8b00 commit f5c141f
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 333 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -500,4 +500,4 @@ For questions or feedback, please raise an issue or reach out to @beveradb ([And
## Sponsors
<!-- sponsors --><!-- sponsors -->
<!-- sponsors --><!-- sponsors -->
19 changes: 9 additions & 10 deletions audio_separator/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_model(self):
raise ValueError("Unknown Roformer model type in the configuration.")

# Load model checkpoint
checkpoint = torch.load(self.model_path, map_location="cpu")
checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
self.model_run.load_state_dict(checkpoint)
self.model_run.to(self.torch_device).eval()
Expand Down Expand Up @@ -191,8 +191,6 @@ def overlap_add(self, result, x, weights, start, length):
"""
Adds the overlapping part of the result to the result tensor.
"""
x = x.to(result.device)
weights = weights.to(result.device)
result[..., start : start + length] += x[..., :length] * weights[:length]
return result

Expand Down Expand Up @@ -239,13 +237,11 @@ def demix(self, mix: np.ndarray) -> dict:

device = next(self.model_run.parameters()).device

# Move the weighting window to the same device as the other tensors
window = window.to(device)

with torch.no_grad():
req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32).to(device)
counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)

for i in tqdm(range(0, mix.shape[1], step)):
part = mix[:, i : i + chunk_size]
Expand All @@ -255,8 +251,10 @@ def demix(self, mix: np.ndarray) -> dict:
length = chunk_size
part = part.to(device)
x = self.model_run(part.unsqueeze(0))[0]
x = x.cpu()
# Perform overlap_add on CPU
if i + chunk_size > mix.shape[1]:
# Corrigido para adicionar corretamente ao final do tensor
# Fixed to correctly add to the end of the tensor
result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
counter[..., result.shape[-1] - chunk_size :] += window[:length]
else:
Expand Down Expand Up @@ -304,7 +302,6 @@ def demix(self, mix: np.ndarray) -> dict:
# It starts as a tensor of zeros and is updated in-place as the model processes each batch.
# The variable holds the combined result of all processed batches, which, after post-processing, represents the separated audio sources.
accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)
accumulated_outputs = accumulated_outputs.to(self.torch_device)

with torch.no_grad():
count = 0
Expand All @@ -317,7 +314,9 @@ def demix(self, mix: np.ndarray) -> dict:
# Since single_batch_result can contain multiple output tensors (one for each piece of audio in the batch),
# individual_output is used to iterate through these tensors and accumulate them into accumulated_outputs.
for individual_output in single_batch_result:
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output
individual_output_cpu = individual_output.cpu()
# Accumulate outputs on CPU
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
count += 1

self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
Expand Down
6 changes: 3 additions & 3 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
output_format="WAV",
output_bitrate=None,
normalization_threshold=0.9,
amplification_threshold=0.6,
amplification_threshold=0.0,
output_single_stem=None,
invert_using_spec=False,
sample_rate=44100,
Expand Down Expand Up @@ -142,8 +142,8 @@ def __init__(
raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")

self.amplification_threshold = amplification_threshold
if amplification_threshold <= 0 or amplification_threshold > 1:
raise ValueError("The amplification_threshold must be greater than 0 and less than or equal to 1.")
if amplification_threshold < 0 or amplification_threshold > 1:
raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")

self.output_single_stem = output_single_stem
if output_single_stem is not None:
Expand Down
2 changes: 1 addition & 1 deletion audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def main():
common_params = parser.add_argument_group("Common Separation Parameters")
common_params.add_argument("--invert_spect", action="store_true", help=invert_spect_help)
common_params.add_argument("--normalization", type=float, default=0.9, help=normalization_help)
common_params.add_argument("--amplification", type=float, default=0.6, help=amplification_help)
common_params.add_argument("--amplification", type=float, default=0.0, help=amplification_help)
common_params.add_argument("--single_stem", default=None, help=single_stem_help)
common_params.add_argument("--sample_rate", type=int, default=44100, help=sample_rate_help)
common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help)
Expand Down
Loading

0 comments on commit f5c141f

Please sign in to comment.