Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed Nov 2, 2024
1 parent 003b5af commit ec4bfcf
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def common_expected_args():
"output_single_stem": None,
"invert_using_spec": False,
"sample_rate": 44100,
"use_autocast": False,
"use_soundfile": False,
"mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
"vr_params": {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
"demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
Expand Down Expand Up @@ -114,10 +116,11 @@ def test_cli_output_dir_argument(common_expected_args):
main()

# Update expected args for this specific test
common_expected_args["output_dir"] = "/custom/output/dir"
expected_args = common_expected_args.copy()
expected_args["output_dir"] = "/custom/output/dir"

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator.assert_called_once_with(**expected_args)


# Test using output format argument
Expand All @@ -130,10 +133,11 @@ def test_cli_output_format_argument(common_expected_args):
main()

# Update expected args for this specific test
common_expected_args["output_format"] = "MP3"
expected_args = common_expected_args.copy()
expected_args["output_format"] = "MP3"

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator.assert_called_once_with(**expected_args)


# Test using normalization_threshold argument
Expand All @@ -146,10 +150,11 @@ def test_cli_normalization_threshold_argument(common_expected_args):
main()

# Update expected args for this specific test
common_expected_args["normalization_threshold"] = 0.75
expected_args = common_expected_args.copy()
expected_args["normalization_threshold"] = 0.75

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator.assert_called_once_with(**expected_args)

# Test using normalization_threshold argument
def test_cli_amplification_threshold_argument(common_expected_args):
Expand All @@ -161,10 +166,11 @@ def test_cli_amplification_threshold_argument(common_expected_args):
main()

# Update expected args for this specific test
common_expected_args["amplification_threshold"] = 0.75
expected_args = common_expected_args.copy()
expected_args["amplification_threshold"] = 0.75

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator.assert_called_once_with(**expected_args)

# Test using single stem argument
def test_cli_single_stem_argument(common_expected_args):
Expand All @@ -176,10 +182,11 @@ def test_cli_single_stem_argument(common_expected_args):
main()

# Update expected args for this specific test
common_expected_args["output_single_stem"] = "instrumental"
expected_args = common_expected_args.copy()
expected_args["output_single_stem"] = "instrumental"

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator.assert_called_once_with(**expected_args)


# Test using invert spectrogram argument
Expand All @@ -192,7 +199,24 @@ def test_cli_invert_spectrogram_argument(common_expected_args):
main()

# Update expected args for this specific test
common_expected_args["invert_using_spec"] = True
expected_args = common_expected_args.copy()
expected_args["invert_using_spec"] = True

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator.assert_called_once_with(**expected_args)

# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
with patch("sys.argv", test_args):
with patch("audio_separator.separator.Separator") as mock_separator:
mock_separator_instance = mock_separator.return_value
mock_separator_instance.separate.return_value = ["output_file.mp3"]
main()

# Update expected args for this specific test
expected_args = common_expected_args.copy()
expected_args["use_autocast"] = True

# Assertions
mock_separator.assert_called_once_with(**expected_args)

0 comments on commit ec4bfcf

Please sign in to comment.