diff --git a/spotiflow/cli/predict.py b/spotiflow/cli/predict.py index cb93d29..fcb492b 100644 --- a/spotiflow/cli/predict.py +++ b/spotiflow/cli/predict.py @@ -61,8 +61,8 @@ def get_args(): "--out-dir", type=Path, required=False, - default=None, - help="Output directory. If not provided, will create a 'spotiflow_results' subfolder in the input folder and write the CSV(s) there.", + default='spotiflow_results', + help="Output directory to write the CSV(s). If not provided, will create a 'spotiflow_results' subfolder in the current folder.", ) predict = parser.add_argument_group( diff --git a/spotiflow/cli/train.py b/spotiflow/cli/train.py index e30ed18..6cf846b 100644 --- a/spotiflow/cli/train.py +++ b/spotiflow/cli/train.py @@ -64,8 +64,9 @@ def get_args() -> argparse.Namespace: "-o", "--outdir", type=Path, - required=True, - help="Output directory where the model will be stored.", + required=False, + default="spotiflow_model", + help="Output directory where the model will be stored (defaults to 'spotiflow_model' in current directory).", ) model_args = parser.add_argument_group(