diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index e4a331c..b219ad8 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -21,6 +21,8 @@ def common_expected_args(): "output_single_stem": None, "invert_using_spec": False, "sample_rate": 44100, + "use_autocast": False, + "use_soundfile": False, "mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, "vr_params": {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, "demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}, @@ -114,10 +116,11 @@ def test_cli_output_dir_argument(common_expected_args): main() # Update expected args for this specific test - common_expected_args["output_dir"] = "/custom/output/dir" + expected_args = common_expected_args.copy() + expected_args["output_dir"] = "/custom/output/dir" # Assertions - mock_separator.assert_called_once_with(**common_expected_args) + mock_separator.assert_called_once_with(**expected_args) # Test using output format argument @@ -130,10 +133,11 @@ def test_cli_output_format_argument(common_expected_args): main() # Update expected args for this specific test - common_expected_args["output_format"] = "MP3" + expected_args = common_expected_args.copy() + expected_args["output_format"] = "MP3" # Assertions - mock_separator.assert_called_once_with(**common_expected_args) + mock_separator.assert_called_once_with(**expected_args) # Test using normalization_threshold argument @@ -146,10 +150,11 @@ def test_cli_normalization_threshold_argument(common_expected_args): main() # Update expected args for this specific test - common_expected_args["normalization_threshold"] = 0.75 + expected_args = common_expected_args.copy() + expected_args["normalization_threshold"] = 0.75 # Assertions - mock_separator.assert_called_once_with(**common_expected_args) + mock_separator.assert_called_once_with(**expected_args) # Test using normalization_threshold argument def test_cli_amplification_threshold_argument(common_expected_args): @@ -161,10 +166,11 @@ def test_cli_amplification_threshold_argument(common_expected_args): main() # Update expected args for this specific test - common_expected_args["amplification_threshold"] = 0.75 + expected_args = common_expected_args.copy() + expected_args["amplification_threshold"] = 0.75 # Assertions - mock_separator.assert_called_once_with(**common_expected_args) + mock_separator.assert_called_once_with(**expected_args) # Test using single stem argument def test_cli_single_stem_argument(common_expected_args): @@ -176,10 +182,11 @@ def test_cli_single_stem_argument(common_expected_args): main() # Update expected args for this specific test - common_expected_args["output_single_stem"] = "instrumental" + expected_args = common_expected_args.copy() + expected_args["output_single_stem"] = "instrumental" # Assertions - mock_separator.assert_called_once_with(**common_expected_args) + mock_separator.assert_called_once_with(**expected_args) # Test using invert spectrogram argument @@ -192,7 +199,24 @@ def test_cli_invert_spectrogram_argument(common_expected_args): main() # Update expected args for this specific test - common_expected_args["invert_using_spec"] = True + expected_args = common_expected_args.copy() + expected_args["invert_using_spec"] = True # Assertions - mock_separator.assert_called_once_with(**common_expected_args) + mock_separator.assert_called_once_with(**expected_args) + +# Test using use_autocast argument +def test_cli_use_autocast_argument(common_expected_args): + test_args = ["cli.py", "test_audio.mp3", "--use_autocast"] + with patch("sys.argv", test_args): + with patch("audio_separator.separator.Separator") as mock_separator: + mock_separator_instance = mock_separator.return_value + mock_separator_instance.separate.return_value = ["output_file.mp3"] + main() + + # Update expected args for this specific test + expected_args = common_expected_args.copy() + expected_args["use_autocast"] = True + + # Assertions + mock_separator.assert_called_once_with(**expected_args)