diff --git a/comptages/test/__init__.py b/comptages/test/__init__.py index e69de29b..1ba2fc29 100644 --- a/comptages/test/__init__.py +++ b/comptages/test/__init__.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional +from comptages.datamodel import models +import pytz + + +def yearly_count_for( + year: int, + installation: models.Installation, + class_: Optional[models.Class] = None, + model: Optional[models.Model] = None, + device: Optional[models.Device] = None, + sensor_type: Optional[models.SensorType] = None, +) -> models.Count: + tz = pytz.timezone("Europe/Zurich") + model = model or models.Model.objects.all()[0] + device = device or models.Device.objects.all()[0] + sensor_type = sensor_type or models.SensorType.objects.all()[0] + class_ = class_ or models.Class.objects.get(name="SWISS10") + return models.Count.objects.create( + start_put_date=tz.localize(datetime(year, 1, 1)), + start_service_date=tz.localize(datetime(year, 1, 8)), + start_process_date=tz.localize(datetime(year, 1, 15)), + end_process_date=tz.localize(datetime(year, 12, 17)), + end_service_date=tz.localize(datetime(year, 12, 24)), + end_put_date=tz.localize(datetime(year, 12, 31)), + id_model=model, + id_device=device, + id_sensor_type=sensor_type, + id_class=class_, + id_installation=installation, + ) diff --git a/comptages/test/test_report.py b/comptages/test/test_report.py index bf0bf569..92b063b8 100644 --- a/comptages/test/test_report.py +++ b/comptages/test/test_report.py @@ -1,13 +1,11 @@ import decimal from itertools import chain -import pytz -from datetime import datetime from django.test import TransactionTestCase from django.core.management import call_command from django.db.models.manager import Manager from comptages.report.yearly_report_bike import YearlyReportBike -from comptages.test import utils +from comptages.test import utils, yearly_count_for from comptages.datamodel import models from comptages.core import report, importer @@ -24,53 +22,15 @@ def setUp(self): def test_report(self): # Create count and import some data - model = models.Model.objects.all()[0] - device = models.Device.objects.all()[0] - sensor_type = models.SensorType.objects.all()[0] - class_ = models.Class.objects.get(name="SWISS10") installation = models.Installation.objects.get(name="00056520") - tz = pytz.timezone("Europe/Zurich") - - count = models.Count.objects.create( - start_service_date=tz.localize(datetime(2021, 10, 11)), - end_service_date=tz.localize(datetime(2021, 10, 17)), - start_process_date=tz.localize(datetime(2021, 10, 11)), - end_process_date=tz.localize(datetime(2021, 10, 17)), - start_put_date=tz.localize(datetime(2021, 10, 11)), - end_put_date=tz.localize(datetime(2021, 10, 17)), - id_model=model, - id_device=device, - id_sensor_type=sensor_type, - id_class=class_, - id_installation=installation, - ) - + count = yearly_count_for(2021, installation) importer.import_file(utils.test_data_path("00056520.V01"), count) importer.import_file(utils.test_data_path("00056520.V02"), count) - report.prepare_reports("/tmp/", count) def test_ensure_non_rounded_values(self): - model = models.Model.objects.all()[0] - device = models.Device.objects.all()[0] - sensor_type = models.SensorType.objects.all()[0] - class_ = models.Class.objects.get(name="SPCH13") - installation = models.Installation.objects.get(name="64210836") - - count = models.Count.objects.create( - start_service_date=datetime(2021, 9, 10), - end_service_date=datetime(2021, 9, 21), - start_process_date=datetime(2021, 9, 10), - end_process_date=datetime(2021, 9, 21), - start_put_date=datetime(2021, 9, 10), - end_put_date=datetime(2021, 9, 21), - id_model=model, - id_device=device, - id_sensor_type=sensor_type, - id_class=class_, - id_installation=installation, - ) - + installation = models.Installation.objects.get(name="00056520") + count = yearly_count_for(2021, installation) importer.import_file(utils.test_data_path("64210836_TCHO-Capitaine.txt"), count) lanes_installation: Manager = installation.lane_set self.assertIsNotNone(lanes_installation) @@ -79,18 +39,18 @@ def test_ensure_non_rounded_values(self): lane__id=lanes_installation.values_list("pk", flat=True)[0] ) report = YearlyReportBike("template_yearly_bike.xlsx", 2021, section_id) - tjms_dir1 = ( - decimal.Decimal(v) - for v in report.values_by_hour_and_direction(1).values_list( - "tjm", flat=True - ) + report_dir1 = report.values_by_hour_and_direction(1).values_list( + "tjm", flat=True ) - tjms_dir2 = ( - decimal.Decimal(v) - for v in report.values_by_hour_and_direction(2).values_list( - "tjm", flat=True - ) + report_dir2 = report.values_by_hour_and_direction(2).values_list( + "tjm", flat=True ) + self.assertTrue(report_dir1.exists()) + self.assertTrue(report_dir2.exists()) + + tjms_dir1 = (decimal.Decimal(v) for v in report_dir1) + tjms_dir2 = (decimal.Decimal(v) for v in report_dir2) with self.subTest(): for value in chain(tjms_dir1, tjms_dir2): + print(value) self.assertEqual(value.as_tuple().exponent, 3)