diff --git a/src/predict.py b/src/predict.py index e2f763e..5f800cc 100644 --- a/src/predict.py +++ b/src/predict.py @@ -2,12 +2,11 @@ """Predict -Usage: predict.py --test_file= --model_file= --target_name= +Usage: predict.py --test_file= --model_file= Options: --test_file= the test dataframe to predict --model_file= Path (including filename) of the model ---target_name= if test_file has a target col, put the column name here """ import pickle @@ -36,5 +35,5 @@ def main(test_file, model_file, target_name): evaluate_performance(y_test, model.predict_proba(X_test)[:, 1], save="./../results/test_metrics.jpg") if __name__ == "__main__": - main(opt["--test_file"], opt["--model_file"], opt["--target_name"]) + main(opt["--test_file"], opt["--model_file"], "Stars")