From ba5d0f0870db0179182946319ac24bde55100f36 Mon Sep 17 00:00:00 2001 From: erdem Date: Fri, 16 Aug 2024 16:10:08 +0100 Subject: [PATCH] Revise score calculation in `calculate_safety_score` --- src/calculators.py | 4 +++- src/main.py | 5 +---- src/schemas.py | 2 +- tests/functional/test_cli.py | 10 +++++----- tests/unit/test_calculators.py | 10 +++++----- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/calculators.py b/src/calculators.py index 1ccf1db..c45dee4 100644 --- a/src/calculators.py +++ b/src/calculators.py @@ -1,4 +1,5 @@ from typing import List, Optional + from typeguard import typechecked @@ -36,7 +37,8 @@ def calculate_safety_score( inverse_accident_history = 1 / (accident_history + 1) score = inverse_traffic_density score += road_quality * inverse_accident_history - score += lighting_conditions * average_speed + score += lighting_conditions + score += average_speed if incident_reports is not None: score -= incident_reports diff --git a/src/main.py b/src/main.py index 54beee5..890663d 100644 --- a/src/main.py +++ b/src/main.py @@ -79,10 +79,7 @@ def main(): try: if args.file: input_data = load_json_file(args.file) - if args.json: - import ipdb - - ipdb.set_trace() + elif args.json: input_data = load_json(args.json) else: print("No argument provided, check `python src/main.py --help") diff --git a/src/schemas.py b/src/schemas.py index 31497c0..17a298a 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -45,7 +45,7 @@ class OverallCommuteQualitySchema(BaseModel): @model_validator(mode="before") @classmethod - def set_commute_quality(cls, data: dict) -> dict: + def determine_commute_quality(cls, data: dict) -> dict: score = data.get("quality_score") if score is None: raise ValueError("quality_score is required to determine `commute_quality`") diff --git a/tests/functional/test_cli.py b/tests/functional/test_cli.py index 4769c0f..6fc6308 100644 --- a/tests/functional/test_cli.py +++ b/tests/functional/test_cli.py @@ -8,8 +8,8 @@ def test_valid_input_path(run_cli_command, valid_input_path): assert error == "" result = json.loads(output) - assert result["quality_score"] == 88.35474254742546 - assert result["commute_quality"] == "Good" + assert result["quality_score"] == 44.388075880758805 + assert result["commute_quality"] == "Average" def test_valid_input_json(run_cli_command, valid_input_json): @@ -19,8 +19,8 @@ def test_valid_input_json(run_cli_command, valid_input_json): assert error == "" result = json.loads(output) - assert result["quality_score"] == 88.35474254742546 - assert result["commute_quality"] == "Good" + assert result["quality_score"] == 44.388075880758805 + assert result["commute_quality"] == "Average" def test_valid_input_without_optionals(run_cli_command, valid_input_without_optionals_path): @@ -30,7 +30,7 @@ def test_valid_input_without_optionals(run_cli_command, valid_input_without_opti assert error == "" result = json.loads(output) - assert result["quality_score"] == 31.519445362371297 + assert result["quality_score"] == 25.25277869570463 assert result["commute_quality"] == "Average" diff --git a/tests/unit/test_calculators.py b/tests/unit/test_calculators.py index 955cc82..7ec7eeb 100644 --- a/tests/unit/test_calculators.py +++ b/tests/unit/test_calculators.py @@ -1,8 +1,8 @@ from src.calculators import ( calculate_comfort_index, - calculate_traffic_flow_efficiency, - calculate_safety_score, calculate_overall_commute_quality_score, + calculate_safety_score, + calculate_traffic_flow_efficiency, ) @@ -19,9 +19,9 @@ def test_calculate_traffic_flow_efficiency(): def test_calculate_safety_score(): - assert calculate_safety_score(8, 7, 2, 50, [0.3, 0.4, 0.3]) == 353.6666666666667 - assert calculate_safety_score(9, 8, 1, 60, [0.2, 0.2, 0.1], 3, 5) == 478.5 - assert calculate_safety_score(7, 6, 3, 40, [0.5, 0.5], 1) == 241.75 + assert calculate_safety_score(8, 7, 2, 50, [0.3, 0.4, 0.3]) == 60.666666666666664 + assert calculate_safety_score(9, 8, 1, 60, [0.2, 0.2, 0.1], 3, 5) == 66.5 + assert calculate_safety_score(7, 6, 3, 40, [0.5, 0.5], 1) == 47.75 def test_calculate_overall_commute_quality_score():