From 1b5b3ad0d0c400575ad113e365a0c0cf9476e977 Mon Sep 17 00:00:00 2001 From: Martin Mader Date: Fri, 19 Jan 2024 11:43:22 +0100 Subject: [PATCH 1/5] add ruff config and pre-commit hook --- .flake8 | 5 --- .github/workflows/python-app.yml | 6 +++- .pre-commit-config.yaml | 10 ++++++ README.md | 18 +++------- pyproject.toml | 5 --- requirements.txt | 2 ++ ruff.toml | 61 ++++++++++++++++++++++++++++++++ 7 files changed, 82 insertions(+), 25 deletions(-) delete mode 100644 .flake8 create mode 100644 .pre-commit-config.yaml create mode 100644 ruff.toml diff --git a/.flake8 b/.flake8 deleted file mode 100644 index b9a569e7..00000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -ignore = E203, E266, E501, W503, F403, F401 -max-line-length = 79 -max-complexity = 18 -select = B,C,E,F,W,T4,B9 \ No newline at end of file diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index e16fb9e4..fa3a9b05 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -35,9 +35,13 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pip install -r requirements.txt + - name: Lint with Ruff + run: | + pip install ruff + ruff --output-format=github . + - name: Run tests run: | pip install -r test/requirements.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..6eca08c0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.13 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/README.md b/README.md index 2a33397d..83cf3ccd 100644 --- a/README.md +++ b/README.md @@ -218,18 +218,8 @@ alembic revision --autogenerate -m "your revision comment" ### Coding Style & Formatting -Please take advantage of the following tooling: +We are using +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +for linting and code formatting. -```bash -pip install isort autoflake black -``` - -Black re-formats the code, isort orders the imports and flake8 checks for remaining issues. -Example usage: - -```bash -isort --force-single-line-imports . -autoflake --remove-all-unused-imports -i -r --exclude ./alembic . -# Note: '3' means 3-vert-hanging multiline imports -isort --multi-line 3 . -``` +A pre-commit hook can be set up with `pre-commit install` after installing dependencies. 
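
The new README paragraph above only mentions `pre-commit install`. As an illustration (not part of the patch itself), a typical local workflow with this configuration might look like the following, assuming Ruff and pre-commit are installed from the pinned versions in `requirements.txt` (`ruff==0.1.13`, `pre-commit==3.6.0`):

```bash
# Install project dependencies, including ruff and pre-commit
pip install -r requirements.txt

# Register the git hook defined in .pre-commit-config.yaml
pre-commit install

# Run the same hooks manually across the whole repository
pre-commit run --all-files

# Or invoke Ruff directly, mirroring the CI lint step and the formatter hook
ruff check --fix .
ruff format .
```

The CI step in `python-app.yml` uses `--output-format=github` so that lint findings are emitted as GitHub workflow annotations on pull requests.
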
diff --git a/pyproject.toml b/pyproject.toml index b892b137..5d4dfb0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ version = "0.1.0" Click = "^7.0" [dev-dependencies] -black = { version = "^18.3-alpha.0", python = "^3.9" } [build-system] requires = [ @@ -15,10 +14,6 @@ requires = [ ] build-backend = "setuptools.build_meta" -[tool.black] -line-length = 88 -target_version = ['py39'] - [tool.pytest.ini_options] # Set logging for code under test diff --git a/requirements.txt b/requirements.txt index a5ee31c4..524ed6b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,7 @@ oauthlib==3.2.2 openpyxl==3.0.9 packaging==21.3 pandas==1.4.1 +pre-commit==3.6.0 protobuf==4.21.9 psycopg2==2.9.3 pyasn1==0.4.8 @@ -47,6 +48,7 @@ pytz==2022.1 requests==2.31.0 requests-oauthlib==1.3.1 rsa==4.9 +ruff==0.1.13 Shapely==1.8.2 six==1.16.0 smmap==5.0.0 diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..30dfdbcf --- /dev/null +++ b/ruff.toml @@ -0,0 +1,61 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py38" + +[lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. 
+line-ending = "auto" \ No newline at end of file From d87b57e68a0d2e174fa643a4acf665e8ae6cf2d8 Mon Sep 17 00:00:00 2001 From: Martin Mader Date: Fri, 19 Jan 2024 11:46:47 +0100 Subject: [PATCH 2/5] fix lint issues and code style --- alembic/env.py | 22 +- charging_stations_pipelines/db_utils.py | 21 +- .../attribute_match_thresholds_strategy.py | 61 ++-- .../deduplication/merger.py | 12 +- charging_stations_pipelines/models/address.py | 4 +- .../models/charging.py | 14 +- charging_stations_pipelines/models/station.py | 4 +- .../pipelines/at/__init__.py | 4 +- .../pipelines/at/econtrol_crawler.py | 33 +- .../pipelines/at/econtrol_mapper.py | 6 +- .../pipelines/de/__init__.py | 4 + .../pipelines/de/bna.py | 4 +- .../pipelines/de/bna_crawler.py | 20 +- .../pipelines/de/bna_mapper.py | 14 +- .../pipelines/fr/france.py | 39 ++- .../pipelines/fr/france_mapper.py | 12 +- .../pipelines/gb/gb_mapper.py | 20 +- .../pipelines/gb/gb_receiver.py | 2 +- .../pipelines/gb/gbgov.py | 4 +- .../pipelines/nobil/nobil_pipeline.py | 90 +++-- .../pipelines/ocm/ocm.py | 24 +- .../pipelines/ocm/ocm_extractor.py | 112 ++++--- .../pipelines/osm/osm.py | 4 +- .../pipelines/pipeline_factory.py | 8 +- .../pipelines/station_table_updater.py | 18 +- charging_stations_pipelines/shared.py | 2 +- .../stations_data_export.py | 75 +++-- test/pipelines/at/test_econtrol_crawler.py | 9 +- test/pipelines/at/test_econtrol_mapper.py | 307 ++++++++++-------- test/pipelines/de/test_bna_crawler.py | 102 +++--- test/pipelines/osm/test_osm_mapper.py | 3 +- test/shared.py | 6 +- test/test_main.py | 8 +- testdata_import.py | 22 +- testing/testdata.py | 33 +- tests/integration/test_int_de_bna.py | 64 ++-- tests/integration/test_int_merger.py | 128 ++++---- 37 files changed, 805 insertions(+), 510 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index 463b594f..e61c187c 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,12 +4,13 @@ from alembic import context from sqlalchemy import engine_from_config, pool +from charging_stations_pipelines import models, settings +from charging_stations_pipelines.models import address, charging, station current_path = os.path.abspath(".") sys.path.append(current_path) -from charging_stations_pipelines import settings -from charging_stations_pipelines.models import address, charging, station + # this is the Alembic Config object, which provides # access to the values within the .ini file in use. 
@@ -21,7 +22,8 @@ # add your model's MetaData object here # for 'autogenerate' support -from charging_stations_pipelines import models + + target_metadata = models.Base.metadata # other values from the config, defined by the needs of env.py, @@ -39,12 +41,14 @@ def exclude_tables_from_config(config_): # Excluded tables are defined in alembic.ini in section [alembic:exclude], here for PostGIS table spatial_ref_sys -exclude_tables = exclude_tables_from_config(config.get_section('alembic:exclude')) +exclude_tables = exclude_tables_from_config(config.get_section("alembic:exclude")) -restrict_tables = [address.Address.__tablename__, - charging.Charging.__tablename__, - station.Station.__tablename__, - station.MergedStationSource.__tablename__] +restrict_tables = [ + address.Address.__tablename__, + charging.Charging.__tablename__, + station.Station.__tablename__, + station.MergedStationSource.__tablename__, +] def include_object(object, name, type_, reflected, compare_to): @@ -120,7 +124,7 @@ def run_migrations_online(): ) with context.begin_transaction(): - context.execute(f'set search_path to {target_metadata.schema},public') + context.execute(f"set search_path to {target_metadata.schema},public") context.run_migrations() diff --git a/charging_stations_pipelines/db_utils.py b/charging_stations_pipelines/db_utils.py index fb9c59e1..49fbb3a7 100644 --- a/charging_stations_pipelines/db_utils.py +++ b/charging_stations_pipelines/db_utils.py @@ -12,10 +12,17 @@ def delete_all_data(session: Session): """Deletes all data from the database.""" logger.info("Deleting all data from the database...") - models = [station.MergedStationSource, station.Station, address.Address, charging.Charging] + models = [ + station.MergedStationSource, + station.Station, + address.Address, + charging.Charging, + ] for model in models: logger.info(f"Dropping table: '{model.__tablename__}'...") - session.execute(text(f'TRUNCATE TABLE {model.__tablename__} RESTART IDENTITY CASCADE')) + session.execute( + text(f"TRUNCATE TABLE {model.__tablename__} RESTART IDENTITY CASCADE") + ) session.commit() session.close() logger.info("Finished deleting all data from the database.") @@ -23,9 +30,15 @@ def delete_all_data(session: Session): def delete_all_merged_data(session: Session): """Deletes all merged data from the database.""" - session.execute(text(f'TRUNCATE TABLE {station.MergedStationSource.__tablename__} RESTART IDENTITY')) + session.execute( + text( + f"TRUNCATE TABLE {station.MergedStationSource.__tablename__} RESTART IDENTITY" + ) + ) session.execute(delete(address.Address).where(address.Address.is_merged.is_(True))) - session.execute(delete(charging.Charging).where(charging.Charging.is_merged.is_(True))) + session.execute( + delete(charging.Charging).where(charging.Charging.is_merged.is_(True)) + ) session.execute(delete(station.Station).where(station.Station.is_merged.is_(True))) session.execute(update(station.Station).values(merge_status=None)) session.commit() diff --git a/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py b/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py index e4abd67f..a6fccebe 100644 --- a/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py +++ b/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py @@ -7,37 +7,46 @@ def attribute_match_thresholds_duplicates( - current_station: pd.Series, - duplicate_candidates: pd.DataFrame, - station_id_name: str, - max_distance: int = 100, + 
current_station: pd.Series, + duplicate_candidates: pd.DataFrame, + station_id_name: str, + max_distance: int = 100, ) -> pd.DataFrame: pd.options.mode.chained_assignment = None - remaining_duplicate_candidates = duplicate_candidates[~duplicate_candidates["is_duplicate"].astype(bool)] + remaining_duplicate_candidates = duplicate_candidates[ + ~duplicate_candidates["is_duplicate"].astype(bool) + ] if remaining_duplicate_candidates.empty: return duplicate_candidates - logger.debug(f"### Searching for duplicates to station {current_station.source_id}, " - f"operator: {current_station.operator}, " - f"address: {current_station['address']}" - ) + logger.debug( + f"### Searching for duplicates to station {current_station.source_id}, " + f"operator: {current_station.operator}, " + f"address: {current_station['address']}" + ) logger.debug(f"{len(remaining_duplicate_candidates)} duplicate candidates") - remaining_duplicate_candidates["operator_match"] = remaining_duplicate_candidates.operator.apply( + remaining_duplicate_candidates[ + "operator_match" + ] = remaining_duplicate_candidates.operator.apply( lambda x: SequenceMatcher(None, current_station.operator, str(x)).ratio() if (current_station.operator is not None) & (x is not None) else 0.0 ) - remaining_duplicate_candidates["address_match"] = remaining_duplicate_candidates.address.apply( - lambda x: SequenceMatcher(None, current_station['address'], x).ratio() - if (current_station['address'] != "None,None") & (x != "None,None") + remaining_duplicate_candidates[ + "address_match" + ] = remaining_duplicate_candidates.address.apply( + lambda x: SequenceMatcher(None, current_station["address"], x).ratio() + if (current_station["address"] != "None,None") & (x != "None,None") else 0.0, ) # this is always the distance to the initial central charging station - remaining_duplicate_candidates["distance_match"] = 1 - remaining_duplicate_candidates["distance"] / max_distance + remaining_duplicate_candidates["distance_match"] = ( + 1 - remaining_duplicate_candidates["distance"] / max_distance + ) def is_duplicate_by_score(duplicate_candidate): if duplicate_candidate["address_match"] >= 0.7: @@ -51,22 +60,28 @@ def is_duplicate_by_score(duplicate_candidate): logger.debug("duplicate according to distance") else: is_duplicate = False - logger.debug(f"no duplicate: {duplicate_candidate.data_source}, " - f"source id: {duplicate_candidate.source_id}, " - f"operator: {duplicate_candidate.operator}, " - f"address: {duplicate_candidate.address}, " - f"row id: {duplicate_candidate.name}, " - f"distance: {duplicate_candidate.distance}") + logger.debug( + f"no duplicate: {duplicate_candidate.data_source}, " + f"source id: {duplicate_candidate.source_id}, " + f"operator: {duplicate_candidate.operator}, " + f"address: {duplicate_candidate.address}, " + f"row id: {duplicate_candidate.name}, " + f"distance: {duplicate_candidate.distance}" + ) return is_duplicate - remaining_duplicate_candidates["is_duplicate"] = remaining_duplicate_candidates.apply(is_duplicate_by_score, axis=1) + remaining_duplicate_candidates[ + "is_duplicate" + ] = remaining_duplicate_candidates.apply(is_duplicate_by_score, axis=1) # update original candidates duplicate_candidates.update(remaining_duplicate_candidates) # for all duplicates found via OSM, which has most of the time no address info, # run the check again against all candidates # so e.g. 
if we have a duplicate with address it can be matched to other data sources via this attribute - new_duplicates = remaining_duplicate_candidates[remaining_duplicate_candidates["is_duplicate"]] + new_duplicates = remaining_duplicate_candidates[ + remaining_duplicate_candidates["is_duplicate"] + ] for idx in range(new_duplicates.shape[0]): current_station: pd.Series = new_duplicates.iloc[idx] @@ -77,7 +92,7 @@ def is_duplicate_by_score(duplicate_candidate): duplicate_candidates = attribute_match_thresholds_duplicates( current_station=current_station, duplicate_candidates=duplicate_candidates, - station_id_name=station_id_name + station_id_name=station_id_name, ) return duplicate_candidates diff --git a/charging_stations_pipelines/deduplication/merger.py b/charging_stations_pipelines/deduplication/merger.py index ff0eb77d..fc33e3d5 100644 --- a/charging_stations_pipelines/deduplication/merger.py +++ b/charging_stations_pipelines/deduplication/merger.py @@ -9,7 +9,9 @@ from tqdm import tqdm from charging_stations_pipelines import settings -from charging_stations_pipelines.deduplication import attribute_match_thresholds_strategy +from charging_stations_pipelines.deduplication import ( + attribute_match_thresholds_strategy, +) from charging_stations_pipelines.models.station import MergedStationSource, Station logger = logging.getLogger(__name__) @@ -93,7 +95,9 @@ def _get_station_with_address_and_charging_by_priority( ): merged_station: Optional[Station] = None for source in [self.gov_source, "OCM", "OSM"]: - station_id = stations_to_merge[stations_to_merge["data_source"] == source]["station_id_col"] + station_id = stations_to_merge[stations_to_merge["data_source"] == source][ + "station_id_col" + ] if len(station_id) > 0: station_id = int(station_id.iloc[0]) station, address, charging = self.get_station_with_address_and_charging( @@ -166,7 +170,9 @@ def _merge_duplicates(self, stations_to_merge, session) -> Station: def get_station_with_address_and_charging(self, session, station_id): # get station from DB and create new object - merged_station: Station = session.query(Station).filter(Station.id == station_id).first() + merged_station: Station = ( + session.query(Station).filter(Station.id == station_id).first() + ) address = merged_station.address charging = merged_station.charging session.expunge(merged_station) # expunge the object from session diff --git a/charging_stations_pipelines/models/address.py b/charging_stations_pipelines/models/address.py index ec367460..b176b9ec 100644 --- a/charging_stations_pipelines/models/address.py +++ b/charging_stations_pipelines/models/address.py @@ -10,7 +10,9 @@ class Address(Base): __tablename__ = f"{settings.db_table_prefix}address" id = Column(Integer, primary_key=True, autoincrement=True) - station_id = Column(Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True) + station_id = Column( + Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True + ) date_created = Column(Date) date_updated = Column(Date) street = Column(String) diff --git a/charging_stations_pipelines/models/charging.py b/charging_stations_pipelines/models/charging.py index 9b72e9fa..247fd238 100644 --- a/charging_stations_pipelines/models/charging.py +++ b/charging_stations_pipelines/models/charging.py @@ -1,12 +1,4 @@ -from sqlalchemy import ( - ARRAY, - Boolean, - Column, - Date, - ForeignKey, - Integer, - String -) +from sqlalchemy import ARRAY, Boolean, Column, Date, ForeignKey, Integer, String from sqlalchemy.orm import 
relationship from sqlalchemy.types import Float @@ -18,7 +10,9 @@ class Charging(Base): __tablename__ = f"{settings.db_table_prefix}charging" id = Column(Integer, primary_key=True, autoincrement=True) - station_id = Column(Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True) + station_id = Column( + Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True + ) date_created = Column(Date) date_updated = Column(Date) capacity = Column(Integer) diff --git a/charging_stations_pipelines/models/station.py b/charging_stations_pipelines/models/station.py index d3bf32f1..fa5733e1 100644 --- a/charging_stations_pipelines/models/station.py +++ b/charging_stations_pipelines/models/station.py @@ -9,6 +9,7 @@ class Station(Base): """Station class for representing a station in a database.""" + __tablename__ = f"{settings.db_table_prefix}stations" id = Column(Integer, primary_key=True, autoincrement=True) source_id = Column(String, index=True, nullable=True, unique=True) @@ -19,7 +20,7 @@ class Station(Base): operator = Column(String) payment = Column(String) authentication = Column(String) - point = Column(Geography(geometry_type='POINT', srid=4326)) + point = Column(Geography(geometry_type="POINT", srid=4326)) date_created = Column(Date) date_updated = Column(Date) raw_data = Column(JSON) @@ -40,6 +41,7 @@ class Station(Base): class MergedStationSource(Base): """This class represents a merged station source entity.""" + __tablename__ = f"{settings.db_table_prefix}merged_station_source" id = Column(Integer, primary_key=True, autoincrement=True) merged_station_id = Column(Integer, ForeignKey(f"{Station.__tablename__}.id")) diff --git a/charging_stations_pipelines/pipelines/at/__init__.py b/charging_stations_pipelines/pipelines/at/__init__.py index 04551554..5f687afb 100644 --- a/charging_stations_pipelines/pipelines/at/__init__.py +++ b/charging_stations_pipelines/pipelines/at/__init__.py @@ -1,8 +1,8 @@ """The AT package contains the pipelines for the Austrian data source.""" from typing import Final -DATA_SOURCE_KEY: Final[str] = 'AT_ECONTROL' +DATA_SOURCE_KEY: Final[str] = "AT_ECONTROL" """The data source key for the e-control data source.""" -SCOPE_COUNTRIES: Final[list[str]] = ['AT'] +SCOPE_COUNTRIES: Final[list[str]] = ["AT"] """The list of country codes covered by the e-control data source.""" diff --git a/charging_stations_pipelines/pipelines/at/econtrol_crawler.py b/charging_stations_pipelines/pipelines/at/econtrol_crawler.py index 23e7f8cf..414b9292 100644 --- a/charging_stations_pipelines/pipelines/at/econtrol_crawler.py +++ b/charging_stations_pipelines/pipelines/at/econtrol_crawler.py @@ -12,7 +12,9 @@ logger = logging.getLogger(__name__) -def _get_paginated_stations(url: str, headers: dict[str, str] = None) -> Generator[dict[str, Any], None, None]: +def _get_paginated_stations( + url: str, headers: dict[str, str] = None +) -> Generator[dict[str, Any], None, None]: session = requests.Session() session.headers.update(headers) @@ -22,14 +24,14 @@ def _get_paginated_stations(url: str, headers: dict[str, str] = None) -> Generat try: # Sample data from returned JSON chunk: "totalResults":9454,"fromIndex":0,"endIndex":999 - total_count = first_page['totalResults'] + total_count = first_page["totalResults"] logger.info(f"Total count of stations: {total_count}") yield first_page - idx_start = first_page['fromIndex'] - idx_end = first_page['endIndex'] + idx_start = first_page["fromIndex"] + idx_end = first_page["endIndex"] except KeyError as e: - 
logging.fatal(f'Failed to parse response:\n{first_page}\n{e}') + logging.fatal(f"Failed to parse response:\n{first_page}\n{e}") raise e # Number of datapoints (=station) per page, e.g. 1000 @@ -44,8 +46,10 @@ def _get_paginated_stations(url: str, headers: dict[str, str] = None) -> Generat idx_start = page_size * (page_num - 1) idx_end = min(page_size * page_num - 1, total_count - 1) - logger.debug(f'Downloading chunk: {idx_start}..{idx_end}') - next_page = session.get(url, params={'fromIndex': idx_start, 'endIndex': idx_end}).json() + logger.debug(f"Downloading chunk: {idx_start}..{idx_end}") + next_page = session.get( + url, params={"fromIndex": idx_start, "endIndex": idx_end} + ).json() yield next_page @@ -57,23 +61,26 @@ def get_data(tmp_data_path): :return: None :rtype: None """ - url: Final[str] = 'https://api.e-control.at/charge/1.0/search/stations' + url: Final[str] = "https://api.e-control.at/charge/1.0/search/stations" # HTTP header # TODO fix the issue with the api key # econtrol_at_apikey = os.getenv('ECONTROL_AT_APIKEY') # econtrol_at_domain = os.getenv('ECONTROL_AT_DOMAIN') - headers = {'Authorization': f"Basic {os.getenv('ECONTROL_AT_AUTH')}", 'User-Agent': 'Mozilla/5.0'} - logger.debug(f'Using HTTP headers:\n{headers}') + headers = { + "Authorization": f"Basic {os.getenv('ECONTROL_AT_AUTH')}", + "User-Agent": "Mozilla/5.0", + } + logger.debug(f"Using HTTP headers:\n{headers}") logger.info(f"Downloading {at.DATA_SOURCE_KEY} data from {url}...") - with open(tmp_data_path, 'w') as f: + with open(tmp_data_path, "w") as f: for page in _get_paginated_stations(url, headers): logger.debug(f"Getting data: {page['fromIndex']}..{page['endIndex']}") # Save as newline-delimited JSON (*.ndjson), i.e. one JSON object per line - for station in page['stations']: + for station in page["stations"]: json.dump(station, f, ensure_ascii=False) - f.write('\n') + f.write("\n") logger.info(f"Downloaded {at.DATA_SOURCE_KEY} data to: {tmp_data_path}") logger.info(f"Downloaded file size: {os.path.getsize(tmp_data_path)} bytes") diff --git a/charging_stations_pipelines/pipelines/at/econtrol_mapper.py b/charging_stations_pipelines/pipelines/at/econtrol_mapper.py index f29b220f..6fad36f7 100644 --- a/charging_stations_pipelines/pipelines/at/econtrol_mapper.py +++ b/charging_stations_pipelines/pipelines/at/econtrol_mapper.py @@ -79,7 +79,7 @@ def map_station(row: pd.Series, country_code: str) -> Station: - 'date_created': The creation date of the station. - 'date_updated': The update date of the station. """ - operator_id = str_strip_whitespace(row.get("evseOperatorId")) or None + operator_id = str_strip_whitespace(row.get("evseOperatorId")) or None station_id = str_strip_whitespace(row.get("evseStationId")) or None station = Station() @@ -103,7 +103,9 @@ def map_station(row: pd.Series, country_code: str) -> Station: return station -def map_address(row: pd.Series, country_code: str, station_id: Optional[int]) -> Address: +def map_address( + row: pd.Series, country_code: str, station_id: Optional[int] +) -> Address: """Maps the given raw datapoint to an Address object. :param row: A datapoint representing the raw data. 
diff --git a/charging_stations_pipelines/pipelines/de/__init__.py b/charging_stations_pipelines/pipelines/de/__init__.py index 6f74616a..3cabab65 100644 --- a/charging_stations_pipelines/pipelines/de/__init__.py +++ b/charging_stations_pipelines/pipelines/de/__init__.py @@ -8,19 +8,23 @@ class BnaCrawlerException(Exception): """Base class for exceptions in BnaCrawler.""" + pass class FetchWebsiteException(BnaCrawlerException): """Raised when there is an error fetching the website.""" + pass class ExtractURLException(BnaCrawlerException): """Raised when there is an error extracting the URL.""" + pass class DownloadFileException(BnaCrawlerException): """Raised when there is an error downloading a file.""" + pass diff --git a/charging_stations_pipelines/pipelines/de/bna.py b/charging_stations_pipelines/pipelines/de/bna.py index 3a336de3..13c41c8e 100644 --- a/charging_stations_pipelines/pipelines/de/bna.py +++ b/charging_stations_pipelines/pipelines/de/bna.py @@ -30,7 +30,9 @@ def __init__(self, config: configparser, session: Session, online: bool = False) # All BNA data is from Germany self.country_code = "DE" - self.data_dir: Final[pathlib.Path] = (pathlib.Path(__file__).parents[3] / "data").resolve() + self.data_dir: Final[pathlib.Path] = ( + pathlib.Path(__file__).parents[3] / "data" + ).resolve() def retrieve_data(self): self.data_dir.mkdir(parents=True, exist_ok=True) diff --git a/charging_stations_pipelines/pipelines/de/bna_crawler.py b/charging_stations_pipelines/pipelines/de/bna_crawler.py index 69c7f7ee..091d2a27 100644 --- a/charging_stations_pipelines/pipelines/de/bna_crawler.py +++ b/charging_stations_pipelines/pipelines/de/bna_crawler.py @@ -7,7 +7,11 @@ import requests as requests from bs4 import BeautifulSoup -from charging_stations_pipelines.pipelines.de import DownloadFileException, ExtractURLException, FetchWebsiteException +from charging_stations_pipelines.pipelines.de import ( + DownloadFileException, + ExtractURLException, + FetchWebsiteException, +) from charging_stations_pipelines.shared import download_file logger = logging.getLogger(__name__) @@ -37,15 +41,19 @@ def get_bna_data(tmp_data_path: str) -> None: download_link_elems = soup.find_all("a", class_=LINK_CLASS) download_link_url: Optional[str] = None for link in download_link_elems: - download_link_url = link.get('href') - if (download_link_url - and SEARCH_TERM in download_link_url.lower() - and download_link_url.lower().endswith(FILE_EXTENSION)): + download_link_url = link.get("href") + if ( + download_link_url + and SEARCH_TERM in download_link_url.lower() + and download_link_url.lower().endswith(FILE_EXTENSION) + ): break # Check if the url extraction is successful if download_link_url is None: - raise ExtractURLException("Failed to extract the download url from the website.") + raise ExtractURLException( + "Failed to extract the download url from the website." 
+ ) logger.debug(f"Downloading BNA data from '{download_link_url}'") try: diff --git a/charging_stations_pipelines/pipelines/de/bna_mapper.py b/charging_stations_pipelines/pipelines/de/bna_mapper.py index 0c1cbd23..5790273e 100644 --- a/charging_stations_pipelines/pipelines/de/bna_mapper.py +++ b/charging_stations_pipelines/pipelines/de/bna_mapper.py @@ -28,7 +28,9 @@ def map_station_bna(row: pd.Series): new_station.country_code = "DE" new_station.data_source = bna.DATA_SOURCE_KEY - new_station.source_id = hashlib.sha256(f"{lat}{long}{new_station.data_source}".encode()).hexdigest() + new_station.source_id = hashlib.sha256( + f"{lat}{long}{new_station.data_source}".encode() + ).hexdigest() new_station.operator = row["Betreiber"] new_station.point = from_shape(Point(float(long), float(lat))) @@ -45,7 +47,9 @@ def map_address_bna(row: pd.Series, station_id) -> Address: if len(postcode) == 4: postcode = "0" + postcode if len(postcode) != 5: - logger.debug(f"Failed to process postcode {postcode}! Will set postcode to None!") + logger.debug( + f"Failed to process postcode {postcode}! Will set postcode to None!" + ) postcode = None if len(town) < 2: logger.debug(f"Failed to process town {town}! Will set town to None!") @@ -54,7 +58,11 @@ def map_address_bna(row: pd.Series, station_id) -> Address: address = Address() address.station_id = station_id - address.street = str_strip_whitespace(row.get("Straße")) + " " + str_strip_whitespace(row.get("Hausnummer")) + address.street = ( + str_strip_whitespace(row.get("Straße")) + + " " + + str_strip_whitespace(row.get("Hausnummer")) + ) address.town = town address.postcode = postcode address.district = row["Kreis/kreisfreie Stadt"] diff --git a/charging_stations_pipelines/pipelines/fr/france.py b/charging_stations_pipelines/pipelines/fr/france.py index 28d9f2f4..a1378b22 100644 --- a/charging_stations_pipelines/pipelines/fr/france.py +++ b/charging_stations_pipelines/pipelines/fr/france.py @@ -10,8 +10,14 @@ from tqdm import tqdm from charging_stations_pipelines.pipelines import Pipeline -from charging_stations_pipelines.pipelines.fr.france_mapper import map_address_fra, map_charging_fra, map_station_fra -from charging_stations_pipelines.pipelines.station_table_updater import StationTableUpdater +from charging_stations_pipelines.pipelines.fr.france_mapper import ( + map_address_fra, + map_charging_fra, + map_station_fra, +) +from charging_stations_pipelines.pipelines.station_table_updater import ( + StationTableUpdater, +) from charging_stations_pipelines.shared import download_file, reject_if logger = logging.getLogger(__name__) @@ -19,14 +25,20 @@ class FraPipeline(Pipeline): def _retrieve_data(self): - data_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), "../../..", "data") + data_dir = os.path.join( + pathlib.Path(__file__).parent.resolve(), "../../..", "data" + ) pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True) tmp_data_path = os.path.join(data_dir, self.config["FRGOV"]["filename"]) if self.online: logger.info("Retrieving Online Data") self.download_france_gov_file(tmp_data_path) - self.data = pd.read_csv(os.path.join(data_dir, "france_stations.csv"), delimiter=",", encoding="utf-8", - encoding_errors="replace") + self.data = pd.read_csv( + os.path.join(data_dir, "france_stations.csv"), + delimiter=",", + encoding="utf-8", + encoding_errors="replace", + ) def run(self): logger.info("Running FR GOV Pipeline...") @@ -39,7 +51,9 @@ def run(self): mapped_station = map_station_fra(row) mapped_station.address = mapped_address 
mapped_station.charging = mapped_charging - station_updater.update_station(station=mapped_station, data_source_key='FRGOV') + station_updater.update_station( + station=mapped_station, data_source_key="FRGOV" + ) station_updater.log_update_station_counts() @staticmethod @@ -50,9 +64,16 @@ def download_france_gov_file(target_file): r = requests.get(base_url, headers={"User-Agent": "Mozilla/5.0"}) soup = BeautifulSoup(r.content, "html.parser") - all_links_on_gov_page = soup.findAll('a') + all_links_on_gov_page = soup.findAll("a") link_to_dataset = list( - filter(lambda a: a["href"].startswith("https://www.data.gouv.fr/fr/datasets"), all_links_on_gov_page)) - reject_if(len(link_to_dataset) != 1, "Could not determine source for french government data") + filter( + lambda a: a["href"].startswith("https://www.data.gouv.fr/fr/datasets"), + all_links_on_gov_page, + ) + ) + reject_if( + len(link_to_dataset) != 1, + "Could not determine source for french government data", + ) download_file(link_to_dataset[0]["href"], target_file) diff --git a/charging_stations_pipelines/pipelines/fr/france_mapper.py b/charging_stations_pipelines/pipelines/fr/france_mapper.py index c82ac8d0..553f0f26 100644 --- a/charging_stations_pipelines/pipelines/fr/france_mapper.py +++ b/charging_stations_pipelines/pipelines/fr/france_mapper.py @@ -35,13 +35,19 @@ def map_station_fra(row: pd.Series) -> Station: station.source_id = row.get("id_station_itinerance") station.operator = row.get("nom_operateur") station.data_source = "FRGOV" - station.point = from_shape(Point(float(check_coordinates(row.get("consolidated_longitude"))), - float(check_coordinates(row.get("consolidated_latitude"))))) + station.point = from_shape( + Point( + float(check_coordinates(row.get("consolidated_longitude"))), + float(check_coordinates(row.get("consolidated_latitude"))), + ) + ) station.date_created = row.get("date_mise_en_service").strptime("%Y-%m-%d") station.date_updated = row.get("date_maj").strptime("%Y-%m-%d") if not pd.isna(row.get("date_mise_en_service")): - station.date_created = datetime.strptime(row.get("date_mise_en_service"), "%Y-%m-%d") + station.date_created = datetime.strptime( + row.get("date_mise_en_service"), "%Y-%m-%d" + ) if not pd.isna(row.get("date_maj")): station.date_updated = datetime.strptime(row.get("date_maj"), "%Y-%m-%d") else: diff --git a/charging_stations_pipelines/pipelines/gb/gb_mapper.py b/charging_stations_pipelines/pipelines/gb/gb_mapper.py index 94eaf3f3..a5aa7582 100644 --- a/charging_stations_pipelines/pipelines/gb/gb_mapper.py +++ b/charging_stations_pipelines/pipelines/gb/gb_mapper.py @@ -32,18 +32,28 @@ def map_station_gb(entry, country_code: str): def map_address_gb(entry, station_id): - postcode_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("PostCode") + postcode_raw: Optional[str] = ( + entry.get("ChargeDeviceLocation").get("Address").get("PostCode") + ) postcode: Optional[str] = postcode_raw - town_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("PostTown") + town_raw: Optional[str] = ( + entry.get("ChargeDeviceLocation").get("Address").get("PostTown") + ) town: Optional[str] = town_raw if isinstance(town_raw, str) else None - state_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("County") + state_raw: Optional[str] = ( + entry.get("ChargeDeviceLocation").get("Address").get("County") + ) state: Optional[str] = state_raw if isinstance(state_raw, str) else None - country: Optional[str] = 
entry.get("ChargeDeviceLocation").get("Address").get("Country") + country: Optional[str] = ( + entry.get("ChargeDeviceLocation").get("Address").get("Country") + ) - street_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("Street") + street_raw: Optional[str] = ( + entry.get("ChargeDeviceLocation").get("Address").get("Street") + ) street: Optional[str] = street_raw if isinstance(street_raw, str) else None map_address = Address() diff --git a/charging_stations_pipelines/pipelines/gb/gb_receiver.py b/charging_stations_pipelines/pipelines/gb/gb_receiver.py index 4ea1340f..848c2a67 100644 --- a/charging_stations_pipelines/pipelines/gb/gb_receiver.py +++ b/charging_stations_pipelines/pipelines/gb/gb_receiver.py @@ -8,7 +8,7 @@ def get_gb_data(tmp_file_path): """Retrieves data from the GB-Gov-Data API and writes it to a temporary file. - See https://chargepoints.dft.gov.uk/api/help.""" + See https://chargepoints.dft.gov.uk/api/help.""" api_url = "https://chargepoints.dft.gov.uk/api/retrieve/registry/format/json/" response: Response = requests.get(api_url) response.json() diff --git a/charging_stations_pipelines/pipelines/gb/gbgov.py b/charging_stations_pipelines/pipelines/gb/gbgov.py index 569f6a15..07a15a40 100644 --- a/charging_stations_pipelines/pipelines/gb/gbgov.py +++ b/charging_stations_pipelines/pipelines/gb/gbgov.py @@ -56,5 +56,7 @@ def run(self): mapped_station = map_station_gb(entry, " GB") mapped_station.address = mapped_address mapped_station.charging = mapped_charging - station_updater.update_station(station=mapped_station, data_source_key="GBGOV") + station_updater.update_station( + station=mapped_station, data_source_key="GBGOV" + ) station_updater.log_update_station_counts() diff --git a/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py b/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py index 354ff7ec..8e11aa88 100644 --- a/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py +++ b/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py @@ -30,8 +30,20 @@ def __init__(self, power_in_kw: Decimal): class NobilStation: """This class represents a station from the Nobil API.""" - def __init__(self, station_id, operator, position, created, updated, street, house_number, zipcode, city, - number_charging_points, connectors: list[NobilConnector]): + def __init__( + self, + station_id, + operator, + position, + created, + updated, + street, + house_number, + zipcode, + city, + number_charging_points, + connectors: list[NobilConnector], + ): self.station_id = station_id self.operator = operator self.position = position @@ -47,13 +59,23 @@ def __init__(self, station_id, operator, position, created, updated, street, hou def _parse_json_data(json_data) -> list[NobilStation]: all_nobil_stations: list[NobilStation] = [] - for s in json_data['chargerstations']: - csmd = s['csmd'] - parsed_connectors = parse_nobil_connectors(s['attr']['conn']) - - nobil_station = NobilStation(csmd['id'], csmd['Operator'], csmd['Position'], csmd['Created'], - csmd['Updated'], csmd['Street'], csmd['House_number'], csmd['Zipcode'], - csmd['City'], csmd['Number_charging_points'], parsed_connectors) + for s in json_data["chargerstations"]: + csmd = s["csmd"] + parsed_connectors = parse_nobil_connectors(s["attr"]["conn"]) + + nobil_station = NobilStation( + csmd["id"], + csmd["Operator"], + csmd["Position"], + csmd["Created"], + csmd["Updated"], + csmd["Street"], + csmd["House_number"], + csmd["Zipcode"], + csmd["City"], + csmd["Number_charging_points"], 
+ parsed_connectors, + ) all_nobil_stations.append(nobil_station) return all_nobil_stations @@ -63,11 +85,16 @@ def parse_nobil_connectors(connectors: dict): parsed_connectors: list[NobilConnector] = [] # iterate over all connectors and add them to the station for k, v in connectors.items(): - charging_capacity = v['5']['trans'] # contains a string like "7,4 kW - 230V 1-phase max 32A" or "75 kW DC" + charging_capacity = v["5"][ + "trans" + ] # contains a string like "7,4 kW - 230V 1-phase max 32A" or "75 kW DC" # extract the power in kW from the charging capacity string - power_in_kw = Decimal(charging_capacity.split(" kW")[0].replace(",", ".")) \ - if " kW" in charging_capacity else None + power_in_kw = ( + Decimal(charging_capacity.split(" kW")[0].replace(",", ".")) + if " kW" in charging_capacity + else None + ) parsed_connectors.append(NobilConnector(power_in_kw)) return parsed_connectors @@ -104,8 +131,11 @@ def _map_address_to_domain(nobil_station: NobilStation) -> Address: def _map_charging_to_domain(nobil_station: NobilStation) -> Charging: new_charging: Charging = Charging() new_charging.capacity = nobil_station.number_charging_points - new_charging.kw_list = [connector.power_in_kw for connector in nobil_station.connectors - if connector.power_in_kw is not None] + new_charging.kw_list = [ + connector.power_in_kw + for connector in nobil_station.connectors + if connector.power_in_kw is not None + ] if len(new_charging.kw_list) > 0: new_charging.max_kw = max(new_charging.kw_list) new_charging.total_kw = sum(new_charging.kw_list) @@ -114,25 +144,37 @@ def _map_charging_to_domain(nobil_station: NobilStation) -> Charging: def _load_datadump_and_write_to_target(path_to_target, country_code: str): nobil_api_key = os.getenv("NOBIL_APIKEY") - link_to_datadump = (f"https://nobil.no/api/server/datadump.php?apikey=" - f"{nobil_api_key}&countrycode={country_code}&format=json&file=true") + link_to_datadump = ( + f"https://nobil.no/api/server/datadump.php?apikey=" + f"{nobil_api_key}&countrycode={country_code}&format=json&file=true" + ) download_file(link_to_datadump, path_to_target) class NobilPipeline(Pipeline): """This class represents the pipeline for the Nobil data provider.""" - def __init__(self, session: Session, country_code: str, config: configparser, online: bool = False): + def __init__( + self, + session: Session, + country_code: str, + config: configparser, + online: bool = False, + ): super().__init__(config, session, online) accepted_country_codes = ["NOR", "SWE"] - reject_if(country_code.upper() not in accepted_country_codes, "Invalid country code ") + reject_if( + country_code.upper() not in accepted_country_codes, "Invalid country code " + ) self.country_code = country_code.upper() def run(self): """Run the pipeline.""" logger.info("Running NOR/SWE GOV Pipeline...") - path_to_target = Path(__file__).parent.parent.parent.parent.joinpath("data/" + self.country_code + "_gov.json") + path_to_target = Path(__file__).parent.parent.parent.parent.joinpath( + "data/" + self.country_code + "_gov.json" + ) if self.online: logger.info("Retrieving Online Data") _load_datadump_and_write_to_target(path_to_target, self.country_code) @@ -140,7 +182,9 @@ def run(self): nobil_stations_as_json = load_json_file(path_to_target) all_nobil_stations = _parse_json_data(nobil_stations_as_json) - for nobil_station in tqdm(iterable=all_nobil_stations, total=len(all_nobil_stations)): + for nobil_station in tqdm( + iterable=all_nobil_stations, total=len(all_nobil_stations) + ): station: Station = 
_map_station_to_domain(nobil_station, self.country_code) address: Address = _map_address_to_domain(nobil_station) charging: Charging = _map_charging_to_domain(nobil_station) @@ -149,7 +193,11 @@ def run(self): station.charging = charging # check if station already exists in db and add - existing_station = self.session.query(Station).filter_by(source_id=station.source_id).first() + existing_station = ( + self.session.query(Station) + .filter_by(source_id=station.source_id) + .first() + ) if existing_station is None: self.session.add(station) diff --git a/charging_stations_pipelines/pipelines/ocm/ocm.py b/charging_stations_pipelines/pipelines/ocm/ocm.py index 61e52057..dc7e44a1 100644 --- a/charging_stations_pipelines/pipelines/ocm/ocm.py +++ b/charging_stations_pipelines/pipelines/ocm/ocm.py @@ -11,15 +11,27 @@ from charging_stations_pipelines.pipelines import Pipeline from charging_stations_pipelines.pipelines.ocm.ocm_extractor import ocm_extractor -from charging_stations_pipelines.pipelines.ocm.ocm_mapper import map_address_ocm, map_charging_ocm, map_station_ocm -from charging_stations_pipelines.pipelines.station_table_updater import StationTableUpdater +from charging_stations_pipelines.pipelines.ocm.ocm_mapper import ( + map_address_ocm, + map_charging_ocm, + map_station_ocm, +) +from charging_stations_pipelines.pipelines.station_table_updater import ( + StationTableUpdater, +) from charging_stations_pipelines.shared import JSON logger = logging.getLogger(__name__) class OcmPipeline(Pipeline): - def __init__(self, country_code: str, config: configparser, session: Session, online: bool = False): + def __init__( + self, + country_code: str, + config: configparser, + session: Session, + online: bool = False, + ): super().__init__(config, session, online) self.country_code = country_code @@ -27,7 +39,7 @@ def __init__(self, country_code: str, config: configparser, session: Session, on def _retrieve_data(self): data_dir: str = os.path.join( - pathlib.Path(__file__).parent.resolve(), "../../..", "data" + pathlib.Path(__file__).parent.resolve(), "../../..", "data" ) pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True) tmp_file_path = os.path.join(data_dir, self.config["OCM"]["filename"]) @@ -49,5 +61,7 @@ def run(self): mapped_station = map_station_ocm(entry, self.country_code) mapped_station.address = mapped_address mapped_station.charging = mapped_charging - station_updater.update_station(station=mapped_station, data_source_key='OCM') + station_updater.update_station( + station=mapped_station, data_source_key="OCM" + ) station_updater.log_update_station_counts() diff --git a/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py b/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py index f4e7e2a6..326d5075 100644 --- a/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py +++ b/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py @@ -21,37 +21,37 @@ def reference_data_to_frame(data: List[Dict]) -> pd.DataFrame: def merge_connection_types( - connection: pd.DataFrame, reference_data: pd.DataFrame + connection: pd.DataFrame, reference_data: pd.DataFrame ) -> pd.DataFrame: connection_ids: pd.Series = ( connection["ConnectionTypeID"].dropna().drop_duplicates() ) return connection.merge( - reference_data.loc[connection_ids], - how="left", - left_on="ConnectionTypeID", - right_index=True, + reference_data.loc[connection_ids], + how="left", + left_on="ConnectionTypeID", + right_index=True, ) def merge_address_infos( - address_info: pd.Series, reference_data: pd.DataFrame + 
address_info: pd.Series, reference_data: pd.DataFrame ) -> pd.DataFrame: return pd.concat([address_info, reference_data.loc[address_info["CountryID"]]]) def merge_with_reference_data( - row: pd.Series, - connection_types: pd.DataFrame, - address_info: pd.DataFrame, - operators: pd.DataFrame, + row: pd.Series, + connection_types: pd.DataFrame, + address_info: pd.DataFrame, + operators: pd.DataFrame, ): row["Connections"] = merge_connection_types( - connection=pd.json_normalize(row["Connections"]), - reference_data=connection_types, + connection=pd.json_normalize(row["Connections"]), + reference_data=connection_types, ) row["AddressInfo"] = merge_address_infos( - address_info=pd.Series(row["AddressInfo"]), reference_data=address_info + address_info=pd.Series(row["AddressInfo"]), reference_data=address_info ) row["OperatorID"] = operators.loc[row["OperatorID"]] return row @@ -61,7 +61,9 @@ def merge_connections(row, connection_types): frame = pd.DataFrame(row) if "ConnectionTypeID" not in frame.columns: return frame - return pd.merge(frame, connection_types, how="left", left_on="ConnectionTypeID", right_on="ID") + return pd.merge( + frame, connection_types, how="left", left_on="ConnectionTypeID", right_on="ID" + ) def ocm_extractor(tmp_file_path: str, country_code: str): @@ -87,33 +89,35 @@ def ocm_extractor(tmp_file_path: str, country_code: str): raise RuntimeError(f"Could not parse git version! {e}") else: if git_version < version.parse("2.25.0"): - logger.warning(f"found git version {git_version}, extracted from git" - f" --version: {git_version_raw} and regex match {match}") + logger.warning( + f"found git version {git_version}, extracted from git" + f" --version: {git_version_raw} and regex match {match}" + ) raise RuntimeError("Git version must be >= 2.25.0!") if (not os.path.isdir(data_dir)) or len(os.listdir(data_dir)) == 0: shutil.rmtree(data_root_dir, ignore_errors=True) subprocess.call( - [ - "git", - "clone", - "https://github.com/openchargemap/ocm-export", - "--no-checkout", - "--depth", - "1", - ], - cwd=project_data_dir, - stdout=subprocess.PIPE, + [ + "git", + "clone", + "https://github.com/openchargemap/ocm-export", + "--no-checkout", + "--depth", + "1", + ], + cwd=project_data_dir, + stdout=subprocess.PIPE, ) subprocess.call( - ["git", "sparse-checkout", "init", "--cone"], - cwd=data_root_dir, - stdout=subprocess.PIPE, + ["git", "sparse-checkout", "init", "--cone"], + cwd=data_root_dir, + stdout=subprocess.PIPE, ) subprocess.call( - ["git", "sparse-checkout", "set", f"data/{country_code}"], - cwd=data_root_dir, - stdout=subprocess.PIPE, + ["git", "sparse-checkout", "set", f"data/{country_code}"], + cwd=data_root_dir, + stdout=subprocess.PIPE, ) subprocess.call(["git", "checkout"], cwd=data_root_dir, stdout=subprocess.PIPE) else: @@ -131,47 +135,47 @@ def ocm_extractor(tmp_file_path: str, country_code: str): connection_types: pd.DataFrame = pd.json_normalize(data_ref["ConnectionTypes"]) connection_frame = pd.json_normalize( - records, record_path=["Connections"], meta=["UUID"] + records, record_path=["Connections"], meta=["UUID"] ) connection_frame = pd.merge( - connection_frame, - connection_types, - how="left", - left_on="ConnectionTypeID", - right_on="ID", + connection_frame, + connection_types, + how="left", + left_on="ConnectionTypeID", + right_on="ID", ) connection_frame_grouped = connection_frame.groupby("UUID").agg(list) connection_frame_grouped.reset_index(inplace=True) connection_frame_grouped["ConnectionsEnriched"] = connection_frame_grouped.apply( - lambda x: 
x.to_frame(), axis=1 + lambda x: x.to_frame(), axis=1 ) data = pd.merge( - data, - connection_frame_grouped[["ConnectionsEnriched", "UUID"]], - how="left", - on="UUID", + data, + connection_frame_grouped[["ConnectionsEnriched", "UUID"]], + how="left", + on="UUID", ) address_info: pd.DataFrame = pd.json_normalize(data_ref["Countries"]) address_info = address_info.rename(columns={"ID": "CountryID"}) pd_merged_with_countries = pd.merge( - data, - address_info, - left_on="AddressInfo.CountryID", - right_on="CountryID", - how="left", + data, + address_info, + left_on="AddressInfo.CountryID", + right_on="CountryID", + how="left", ) operators: pd.DataFrame = pd.json_normalize(data_ref["Operators"]) operators = operators.rename(columns={"ID": "OperatorIDREF"}) pd_merged_with_operators = pd.merge( - pd_merged_with_countries, - operators, - left_on="OperatorID", - right_on="OperatorIDREF", - how="left", + pd_merged_with_countries, + operators, + left_on="OperatorID", + right_on="OperatorIDREF", + how="left", ) pd_merged_with_operators.reset_index(drop=True).to_json( - tmp_file_path, orient="index" + tmp_file_path, orient="index" ) diff --git a/charging_stations_pipelines/pipelines/osm/osm.py b/charging_stations_pipelines/pipelines/osm/osm.py index e56b903a..068048cf 100644 --- a/charging_stations_pipelines/pipelines/osm/osm.py +++ b/charging_stations_pipelines/pipelines/osm/osm.py @@ -93,7 +93,9 @@ def run(self): stats["count_valid_stations"] += 1 except Exception as ex: stats["count_parse_error"] += 1 - logger.debug(f"{DATA_SOURCE_KEY} entry could not be parsed, error: {ex}. Row: {entry}") + logger.debug( + f"{DATA_SOURCE_KEY} entry could not be parsed, error: {ex}. Row: {entry}" + ) logger.info( f"Finished {DATA_SOURCE_KEY} Pipeline:\n" diff --git a/charging_stations_pipelines/pipelines/pipeline_factory.py b/charging_stations_pipelines/pipelines/pipeline_factory.py index a1e049be..041cda2f 100644 --- a/charging_stations_pipelines/pipelines/pipeline_factory.py +++ b/charging_stations_pipelines/pipelines/pipeline_factory.py @@ -25,10 +25,10 @@ def run(self): def pipeline_factory(db_session: Session, country="DE", online=True) -> Pipeline: """Creates a pipeline based on the country code.""" pipelines = { - "AT": EcontrolAtPipeline(config, db_session, online), - "DE": BnaPipeline(config, db_session, online), - "FR": FraPipeline(config, db_session, online), - "GB": GbPipeline(config, db_session, online), + "AT": EcontrolAtPipeline(config, db_session, online), + "DE": BnaPipeline(config, db_session, online), + "FR": FraPipeline(config, db_session, online), + "GB": GbPipeline(config, db_session, online), "NOR": NobilPipeline(db_session, "NOR", online), "SWE": NobilPipeline(db_session, "SWE", online), } diff --git a/charging_stations_pipelines/pipelines/station_table_updater.py b/charging_stations_pipelines/pipelines/station_table_updater.py index 63066b4f..24aafaa0 100644 --- a/charging_stations_pipelines/pipelines/station_table_updater.py +++ b/charging_stations_pipelines/pipelines/station_table_updater.py @@ -15,9 +15,9 @@ def __init__(self, session: Session, logger: Logger): self.session = session self.logger = logger self.counts = { - 'new': 0, - 'updated': 0, # no update mechanism yet - 'error': 0 + "new": 0, + "updated": 0, # no update mechanism yet + "error": 0, } def update_station(self, station: Station, data_source_key: str): @@ -38,12 +38,14 @@ def update_station(self, station: Station, data_source_key: str): self.session.rollback() if error_occurred: - self.counts['error'] += 1 + 
self.counts["error"] += 1 else: - self.counts['new'] += 1 + self.counts["new"] += 1 def log_update_station_counts(self): """Log the number of new and updated stations.""" - self.logger.info(f"new stations: {self.counts['new']}, " - f"updated stations: {self.counts['updated']}, " - f"errors: {self.counts['error']}") + self.logger.info( + f"new stations: {self.counts['new']}, " + f"updated stations: {self.counts['updated']}, " + f"errors: {self.counts['error']}" + ) diff --git a/charging_stations_pipelines/shared.py b/charging_stations_pipelines/shared.py index 9a54ce8f..b3f966a4 100644 --- a/charging_stations_pipelines/shared.py +++ b/charging_stations_pipelines/shared.py @@ -202,7 +202,7 @@ def lst_expand(aggregated_list: list[tuple[float, int]]) -> list[float]: def coalesce(*args): """Returns the first non-empty argument.""" for arg in args: - if arg is not None and arg != '': + if arg is not None and arg != "": return arg return None diff --git a/charging_stations_pipelines/stations_data_export.py b/charging_stations_pipelines/stations_data_export.py index b69842b0..612a0b34 100644 --- a/charging_stations_pipelines/stations_data_export.py +++ b/charging_stations_pipelines/stations_data_export.py @@ -14,32 +14,46 @@ @dataclass class ExportArea: """Represents an area targeted for data export.""" + lon: float lat: float radius_meters: float -def stations_data_export(db_connection, - country_code: str, - export_merged: bool = False, - export_charging_attributes: bool = False, - export_all_countries: bool = False, - export_to_csv: bool = False, - export_area: Optional[ExportArea] = None, - file_descriptor: str = ""): +def stations_data_export( + db_connection, + country_code: str, + export_merged: bool = False, + export_charging_attributes: bool = False, + export_all_countries: bool = False, + export_to_csv: bool = False, + export_area: Optional[ExportArea] = None, + file_descriptor: str = "", +): """Exports stations data to a file.""" logger.info(f"Exporting stations data for country {country_code}") - country_filter = f"country_code='{country_code}' AND " if country_code != "" and not export_all_countries else "" + country_filter = ( + f"country_code='{country_code}' AND " + if country_code != "" and not export_all_countries + else "" + ) merged_filter = "s.is_merged" if export_merged else "NOT s.is_merged" - export_area_filter = (f" AND ST_Dwithin(" - f"point, " - f"ST_MakePoint({export_area.lon}, {export_area.lat}, 4326)::geography, " - f"{export_area.radius_meters}" - f")") \ - if export_area else "" - - logger.debug(f"Using stations filter: country_filter='{country_filter}', " - f"merged_filter='{merged_filter}', export_area_filter='{export_area_filter}'") + export_area_filter = ( + ( + f" AND ST_Dwithin(" + f"point, " + f"ST_MakePoint({export_area.lon}, {export_area.lat}, 4326)::geography, " + f"{export_area.radius_meters}" + f")" + ) + if export_area + else "" + ) + + logger.debug( + f"Using stations filter: country_filter='{country_filter}', " + f"merged_filter='{merged_filter}', export_area_filter='{export_area_filter}'" + ) get_stations_filter = f"{country_filter}{merged_filter}{export_area_filter}" @@ -63,7 +77,9 @@ def stations_data_export(db_connection, """ logger.debug(f"Running postgis query {get_stations_list_sql}") - gdf: gpd.GeoDataFrame = gpd.read_postgis(get_stations_list_sql, con=db_connection, geom_col="point") + gdf: gpd.GeoDataFrame = gpd.read_postgis( + get_stations_list_sql, con=db_connection, geom_col="point" + ) logger.debug(f"Found stations of shape: 
{gdf.shape}") if len(gdf) == 0: @@ -71,18 +87,23 @@ def stations_data_export(db_connection, else: if export_to_csv: suffix = "csv" - gdf['latitude'] = gdf['point'].apply(lambda point: point.y if point else None) - gdf['longitude'] = gdf['point'].apply(lambda point: point.x if point else None) + gdf["latitude"] = gdf["point"].apply( + lambda point: point.y if point else None + ) + gdf["longitude"] = gdf["point"].apply( + lambda point: point.x if point else None + ) export_data = gdf.to_csv() else: suffix = "geo.json" export_data = gdf.to_json() - logger.debug(f"Data sample: {gdf.sample(5)}") file_country = "europe" if export_all_countries else country_code - file_description = get_file_description(file_descriptor, file_country, export_area) + file_description = get_file_description( + file_descriptor, file_country, export_area + ) file_suffix_merged = "merged" if export_merged else "w_duplicates" file_suffix_charging = "_w_charging" if export_charging_attributes else "" @@ -93,12 +114,16 @@ def stations_data_export(db_connection, logger.info(f"Done writing, file size: {outfile.tell()}") -def get_file_description(file_descriptor: str, file_country: str, export_circle: ExportArea): +def get_file_description( + file_descriptor: str, file_country: str, export_circle: ExportArea +): """Returns a file description based on the given parameters.""" is_export_circle_specified = export_circle is not None if file_descriptor == "": if is_export_circle_specified: - return f"{export_circle.lon}_{export_circle.lat}_{export_circle.radius_meters}" + return ( + f"{export_circle.lon}_{export_circle.lat}_{export_circle.radius_meters}" + ) else: return file_country else: diff --git a/test/pipelines/at/test_econtrol_crawler.py b/test/pipelines/at/test_econtrol_crawler.py index 1368f31c..50e447bd 100644 --- a/test/pipelines/at/test_econtrol_crawler.py +++ b/test/pipelines/at/test_econtrol_crawler.py @@ -11,6 +11,7 @@ from charging_stations_pipelines.pipelines.at.econtrol_crawler import ( __name__ as test_module_name, ) + # NOTE: "local_caplog" is a pytest fixture from test.shared.local_caplog from test.shared import local_caplog, LogLocalCaptureFixture # noqa: F401 @@ -82,8 +83,8 @@ def test_get_data( mock_getenv, mock_get_paginated_stations, mock_open, - local_caplog: LogLocalCaptureFixture, -): # noqa: F811 + local_caplog: LogLocalCaptureFixture, # noqa: F811 +): # Prepare test data and mocks expected_file_size: Final[int] = 2184 tmp_data_path = "/tmp/test_data.ndjson" @@ -145,8 +146,8 @@ def test_get_data_empty_response( mock_getsize, mock_get_paginated_stations, mock_open, - local_caplog: LogLocalCaptureFixture, -): # noqa: F811 + local_caplog: LogLocalCaptureFixture, # noqa: F811 +): # Prepare test data and mocks expected_file_size: Final[int] = 0 tmp_data_path = "/tmp/test_data.ndjson" diff --git a/test/pipelines/at/test_econtrol_mapper.py b/test/pipelines/at/test_econtrol_mapper.py index 0f2e93e8..74844ab0 100644 --- a/test/pipelines/at/test_econtrol_mapper.py +++ b/test/pipelines/at/test_econtrol_mapper.py @@ -6,123 +6,139 @@ from numpy import float64 from shapely.geometry import Point -from charging_stations_pipelines.pipelines.at.econtrol_mapper import aggregate_attribute, map_charging, map_station +from charging_stations_pipelines.pipelines.at.econtrol_mapper import ( + aggregate_attribute, + map_charging, + map_station, +) from charging_stations_pipelines.shared import float_cmp_eq -@pytest.mark.parametrize("datapoint, expected_result", [(pd.Series({ - "city": "Reichenau im M\u00fchlkreis ", - 
"contactName": "Marktgemeindeamt Reichenau i.M.", - "description": "Ortsplatz vor dem Gemeindeamt", - "email": "marktgemeindeamt@reichenau-ooe.at", - "evseCountryId": " AT ", - "evseOperatorId": " 000", - "evseStationId": " EREI001 ", - "freeParking": True, - "greenEnergy": True, - "label": "Marktplatz/Gemeindeamt", - "location": { - "latitude": 48.456161, - "longitude": 14.349852 - }, - "openingHours": { - "details": [] - }, - "points": [ +@pytest.mark.parametrize( + "datapoint, expected_result", + [ + ( + pd.Series( + { + "city": "Reichenau im M\u00fchlkreis ", + "contactName": "Marktgemeindeamt Reichenau i.M.", + "description": "Ortsplatz vor dem Gemeindeamt", + "email": "marktgemeindeamt@reichenau-ooe.at", + "evseCountryId": " AT ", + "evseOperatorId": " 000", + "evseStationId": " EREI001 ", + "freeParking": True, + "greenEnergy": True, + "label": "Marktplatz/Gemeindeamt", + "location": {"latitude": 48.456161, "longitude": 14.349852}, + "openingHours": {"details": []}, + "points": [ + { + "authenticationModes": [], + "connectorTypes": ["SCEE-7-8"], + "energyInKw": 3.0, + "evseId": "AT*000*EREI001", + "freeOfCharge": True, + "location": {"latitude": 48.456161, "longitude": 14.349852}, + "public": True, + "roaming": True, + "status": "UNKNOWN", + "vehicleTypes": ["CAR", "BICYCLE", "MOTORCYCLE"], + } + ], + "postCode": "4204", + "public": True, + "status": "ACTIVE", + "street": "Marktplatz 2", + "telephone": "+43 7211 82550", + "website": "www.reichenau-ooe.at", + } + ), { - "authenticationModes": [], - "connectorTypes": ["SCEE-7-8"], - "energyInKw": 3.0, - "evseId": "AT*000*EREI001", - "freeOfCharge": True, - "location": { - "latitude": 48.456161, - "longitude": 14.349852 - }, - "public": True, - "roaming": True, - "status": "UNKNOWN", - "vehicleTypes": [ - "CAR", - "BICYCLE", - "MOTORCYCLE" - ] - } - ], - "postCode": "4204", - "public": True, - "status": "ACTIVE", - "street": "Marktplatz 2", - "telephone": "+43 7211 82550", - "website": "www.reichenau-ooe.at" - }), { - 'source_id': 'AT*000*EREI001', - 'data_source': 'AT_ECONTROL', - 'evse_country_id': 'AT', - 'evse_operator_id': '000', - 'evse_station_id': 'EREI001', - 'operator': 'Marktgemeindeamt Reichenau i.M.', - 'payment': None, - 'authentication': None, - 'point': from_shape(Point(14.349852, 48.456161)), - 'country_code': 'AT' - }), (pd.Series({ - "city": "Reichenau im M\u00fchlkreis ", - "description": "Ortsplatz vor dem Gemeindeamt", - "email": "marktgemeindeamt@reichenau-ooe.at", - "evseCountryId": " DE", - "evseOperatorId": " 000", - "evseStationId": " EREI001 ", - "freeParking": True, - "greenEnergy": True, - "label": "Marktplatz/Gemeindeamt", - "location": { - "latitude": 48.111, - "longitude": 14.222 - }, - }), { - 'source_id': 'AT*000*EREI001', - 'data_source': 'AT_ECONTROL', - 'evse_country_id': 'DE', - 'evse_operator_id': '000', - 'evse_station_id': 'EREI001', - 'operator': None, - 'payment': None, - 'authentication': None, - 'point': from_shape(Point(14.222, 48.111)), - 'country_code': 'AT' - })]) + "source_id": "AT*000*EREI001", + "data_source": "AT_ECONTROL", + "evse_country_id": "AT", + "evse_operator_id": "000", + "evse_station_id": "EREI001", + "operator": "Marktgemeindeamt Reichenau i.M.", + "payment": None, + "authentication": None, + "point": from_shape(Point(14.349852, 48.456161)), + "country_code": "AT", + }, + ), + ( + pd.Series( + { + "city": "Reichenau im M\u00fchlkreis ", + "description": "Ortsplatz vor dem Gemeindeamt", + "email": "marktgemeindeamt@reichenau-ooe.at", + "evseCountryId": " DE", + 
"evseOperatorId": " 000", + "evseStationId": " EREI001 ", + "freeParking": True, + "greenEnergy": True, + "label": "Marktplatz/Gemeindeamt", + "location": {"latitude": 48.111, "longitude": 14.222}, + } + ), + { + "source_id": "AT*000*EREI001", + "data_source": "AT_ECONTROL", + "evse_country_id": "DE", + "evse_operator_id": "000", + "evse_station_id": "EREI001", + "operator": None, + "payment": None, + "authentication": None, + "point": from_shape(Point(14.222, 48.111)), + "country_code": "AT", + }, + ), + ], +) def test_map_station(datapoint, expected_result): - s = map_station(datapoint, 'AT') + s = map_station(datapoint, "AT") # noinspection DuplicatedCode - assert s.source_id == expected_result['source_id'] - assert s.data_source == expected_result['data_source'] - assert s.evse_country_id == expected_result['evse_country_id'] - assert s.evse_operator_id == expected_result['evse_operator_id'] - assert s.evse_station_id == expected_result['evse_station_id'] - assert s.operator == expected_result['operator'] - assert s.payment is expected_result['payment'] - assert s.authentication is expected_result['authentication'] - assert s.point == expected_result['point'] + assert s.source_id == expected_result["source_id"] + assert s.data_source == expected_result["data_source"] + assert s.evse_country_id == expected_result["evse_country_id"] + assert s.evse_operator_id == expected_result["evse_operator_id"] + assert s.evse_station_id == expected_result["evse_station_id"] + assert s.operator == expected_result["operator"] + assert s.payment is expected_result["payment"] + assert s.authentication is expected_result["authentication"] + assert s.point == expected_result["point"] assert s.raw_data == datapoint.to_json() - assert s.country_code == expected_result['country_code'] + assert s.country_code == expected_result["country_code"] def test_map_charging(): - datapoint = pd.Series({ - 'points': [{'evseId': 'AT*002*E200101*1', - 'energyInKw': 12, - 'authenticationModes': ['APP', 'SMS', 'WEBSITE'], - 'connectorTypes': ['CTESLA', 'S309-1P-16A', 'CG105', 'PAN'], - 'vehicleTypes': ['CAR', 'TRUCK', 'BICYCLE', 'MOTORCYCLE', 'BOAT']}, - {'evseId': 'AT*002*E2001*5', - 'energyInKw': 15, - 'location': {'latitude': 48.198523499134545, 'longitude': 16.325340999197394}, - 'priceInCentPerKwh': 12, - 'priceInCentPerMin': 13, - 'authenticationModes': ['SMS', "DEBIT_CARD", "CASH", "CREDIT_CARD"], - 'connectorTypes': ['CTESLA', 'CG105', 'CCCS2', 'CCCS1']}] - }) + datapoint = pd.Series( + { + "points": [ + { + "evseId": "AT*002*E200101*1", + "energyInKw": 12, + "authenticationModes": ["APP", "SMS", "WEBSITE"], + "connectorTypes": ["CTESLA", "S309-1P-16A", "CG105", "PAN"], + "vehicleTypes": ["CAR", "TRUCK", "BICYCLE", "MOTORCYCLE", "BOAT"], + }, + { + "evseId": "AT*002*E2001*5", + "energyInKw": 15, + "location": { + "latitude": 48.198523499134545, + "longitude": 16.325340999197394, + }, + "priceInCentPerKwh": 12, + "priceInCentPerMin": 13, + "authenticationModes": ["SMS", "DEBIT_CARD", "CASH", "CREDIT_CARD"], + "connectorTypes": ["CTESLA", "CG105", "CCCS2", "CCCS1"], + }, + ] + } + ) c = map_charging(datapoint, 1) @@ -133,7 +149,16 @@ def test_map_charging(): assert float_cmp_eq(c.max_kw, 15.0) assert c.ampere_list is None assert c.volt_list is None - assert c.socket_type_list == ['CTESLA', 'S309-1P-16A', 'CG105', 'PAN', 'CTESLA', 'CG105', 'CCCS2', 'CCCS1'] + assert c.socket_type_list == [ + "CTESLA", + "S309-1P-16A", + "CG105", + "PAN", + "CTESLA", + "CG105", + "CCCS2", + "CCCS1", + ] assert c.dc_support is None @@ 
-141,18 +166,13 @@ def test_map_charging__kw_list(): sample_data = [ ([None, None], [None, None, None]), ([None, []], [None, None, None]), - ([3.14, ['CTESLA', 'S309-1P-16A']], [[3.14, 3.14], 3.14, 3.14 + 3.14]), + ([3.14, ["CTESLA", "S309-1P-16A"]], [[3.14, 3.14], 3.14, 3.14 + 3.14]), ] for raw, expected in sample_data: - raw_datapoint = pd.Series({ - 'points': [ - { - 'energyInKw': raw[0], - 'connectorTypes': raw[1] - } - ] - }) + raw_datapoint = pd.Series( + {"points": [{"energyInKw": raw[0], "connectorTypes": raw[1]}]} + ) exp_kw_list, exp_max_kw, exp_total_kw = expected @@ -163,21 +183,42 @@ def test_map_charging__kw_list(): assert c.total_kw == exp_total_kw -@pytest.mark.parametrize("points, attr, expected_result", [ - (pd.Series([ - {'test_attr': ['a', 'b', 'c'], }, - {'test_attr': ['c', 'd', 'e'], }, - {'test_attr': [], }, ]), - 'test_attr', - [['a', 'b', 'c'], ['c', 'd', 'e'], []]), - (pd.Series([{'test_attr': ['a', 'b', 'c']}]), 'test_attr', [['a', 'b', 'c']]), - (pd.Series([{'test_attr': ['a', 'b', 'c']}, {'test_attr': ['d', 'e', 'f']}]), 'test_attr', - [['a', 'b', 'c'], ['d', 'e', 'f']]), - (pd.Series([], dtype=float64), 'test_attr', []), - (None, None, None), - (None, 'test_attr', None), - (pd.Series([{'test_attr': ['a', 'b', 'c']}]), None, [[]]), - (pd.Series([{'test_attr': ['a', 'b']}]), 'missing_attr', [[]]), ]) +@pytest.mark.parametrize( + "points, attr, expected_result", + [ + ( + pd.Series( + [ + { + "test_attr": ["a", "b", "c"], + }, + { + "test_attr": ["c", "d", "e"], + }, + { + "test_attr": [], + }, + ] + ), + "test_attr", + [["a", "b", "c"], ["c", "d", "e"], []], + ), + (pd.Series([{"test_attr": ["a", "b", "c"]}]), "test_attr", [["a", "b", "c"]]), + ( + pd.Series([{"test_attr": ["a", "b", "c"]}, {"test_attr": ["d", "e", "f"]}]), + "test_attr", + [["a", "b", "c"], ["d", "e", "f"]], + ), + (pd.Series([], dtype=float64), "test_attr", []), + (None, None, None), + (None, "test_attr", None), + (pd.Series([{"test_attr": ["a", "b", "c"]}]), None, [[]]), + (pd.Series([{"test_attr": ["a", "b"]}]), "missing_attr", [[]]), + ], +) def test_aggregate_attribute(points, attr, expected_result): - result = aggregate_attribute(points, attr, ) + result = aggregate_attribute( + points, + attr, + ) assert result == expected_result diff --git a/test/pipelines/de/test_bna_crawler.py b/test/pipelines/de/test_bna_crawler.py index 6548d5b4..2dcf447f 100644 --- a/test/pipelines/de/test_bna_crawler.py +++ b/test/pipelines/de/test_bna_crawler.py @@ -8,73 +8,94 @@ import charging_stations_pipelines.pipelines.de.bna_crawler from charging_stations_pipelines.pipelines.de import bna_crawler -from charging_stations_pipelines.pipelines.de.bna_crawler import __name__ as test_module_name +from charging_stations_pipelines.pipelines.de.bna_crawler import ( + __name__ as test_module_name, +) + # NOTE: "local_caplog" is a pytest fixture from test.shared.local_caplog from test.shared import local_caplog, LogLocalCaptureFixture # noqa: F401 -@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, 'BeautifulSoup') -@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, 'download_file') -@patch.object(requests, 'get') -@patch.object(os.path, 'getsize') -def test_get_bna_data_downloads_file_with_correct_url(mock_getsize, mock_requests_get, mock_download_file, - mock_beautiful_soup): +@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, "BeautifulSoup") +@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, "download_file") +@patch.object(requests, "get") 
+@patch.object(os.path, "getsize") +def test_get_bna_data_downloads_file_with_correct_url( + mock_getsize, mock_requests_get, mock_download_file, mock_beautiful_soup +): # Mock the requests.get response mock_response = Mock() - mock_response.content = b'something, something...' + mock_response.content = b"something, something..." mock_response.status_code = 200 mock_requests_get.return_value = mock_response # Mock the BeautifulSoup find_all method - mock_beautiful_soup.return_value.find_all.return_value = [{'href': 'https://some_ladesaeulenregister_url.xlsx'}] + mock_beautiful_soup.return_value.find_all.return_value = [ + {"href": "https://some_ladesaeulenregister_url.xlsx"} + ] # Mock the os.path.getsize method mock_getsize.return_value = 4321 # Call the method under test - bna_crawler.get_bna_data('./tmp_data_path/some_ladesaeulenregister_url.xlsx') + bna_crawler.get_bna_data("./tmp_data_path/some_ladesaeulenregister_url.xlsx") # Ensure these function were called with the expected arguments. mock_requests_get.assert_called_with( - "https://www.bundesnetzagentur.de/DE/Fachthemen/ElektrizitaetundGas/E-Mobilitaet/start.html", - headers={"User-Agent": "Mozilla/5.0"}) + "https://www.bundesnetzagentur.de/DE/Fachthemen/ElektrizitaetundGas/E-Mobilitaet/start.html", + headers={"User-Agent": "Mozilla/5.0"}, + ) # Assert that the download_file method was called with the correct parameters mock_download_file.assert_called_once_with( - 'https://some_ladesaeulenregister_url.xlsx', './tmp_data_path/some_ladesaeulenregister_url.xlsx') + "https://some_ladesaeulenregister_url.xlsx", + "./tmp_data_path/some_ladesaeulenregister_url.xlsx", + ) # Assert that the os.path.getsize method was called with the correct parameters - mock_getsize.assert_called_once_with('./tmp_data_path/some_ladesaeulenregister_url.xlsx') + mock_getsize.assert_called_once_with( + "./tmp_data_path/some_ladesaeulenregister_url.xlsx" + ) -@patch.object(requests, 'get') -@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, 'BeautifulSoup') -def test_get_bna_data_logs_error_when_no_download_link_found(mock_beautiful_soup, mock_requests_get, caplog): +@patch.object(requests, "get") +@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, "BeautifulSoup") +def test_get_bna_data_logs_error_when_no_download_link_found( + mock_beautiful_soup, mock_requests_get, caplog +): # Mock the requests.get response - mock_requests_get.return_value = Mock(content=b'some content', status_code=200) + mock_requests_get.return_value = Mock(content=b"some content", status_code=200) # Mock the BeautifulSoup find method to return None mock_beautiful_soup.return_value.find_all.return_value = [] - with pytest.raises(bna_crawler.ExtractURLException, match='Failed to extract the download url from the website.'): + with pytest.raises( + bna_crawler.ExtractURLException, + match="Failed to extract the download url from the website.", + ): # Call the function under test - bna_crawler.get_bna_data('tmp_data_path') - - -@patch.object(requests, 'get') -@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, 'BeautifulSoup') -@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, 'download_file') -@patch.object(os.path, 'getsize') -def test_get_bna_data_logs_file_size_after_download(mock_getsize, mock_download_file, mock_beautiful_soup, - mock_requests_get, local_caplog: LogLocalCaptureFixture): + bna_crawler.get_bna_data("tmp_data_path") + + +@patch.object(requests, "get") 
+@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, "BeautifulSoup") +@patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, "download_file") +@patch.object(os.path, "getsize") +def test_get_bna_data_logs_file_size_after_download( + mock_getsize, + mock_download_file, + mock_beautiful_soup, + mock_requests_get, + local_caplog: LogLocalCaptureFixture, # noqa: F811 +): # Mock the requests.get response - mock_requests_get.return_value = Mock(content=b'some content') + mock_requests_get.return_value = Mock(content=b"some content") mock_requests_get.return_value.status_code = 200 # Mock the BeautifulSoup find_all method mock_beautiful_soup.return_value.find_all.return_value = [ - {'href': 'some_url_without_search_term.xlsx'}, - {'href': 'tmp_data_path/ladesaeulenregister.xlsx'} + {"href": "some_url_without_search_term.xlsx"}, + {"href": "tmp_data_path/ladesaeulenregister.xlsx"}, ] # Mock the os.path.getsize method @@ -83,22 +104,27 @@ def test_get_bna_data_logs_file_size_after_download(mock_getsize, mock_download_ logger = logging.getLogger(test_module_name) with local_caplog(level=logging.DEBUG, logger=logger): # Call method under test... with mocked logging - bna_crawler.get_bna_data('tmp_data_path/some_url1_with_search_term.xlsx') + bna_crawler.get_bna_data("tmp_data_path/some_url1_with_search_term.xlsx") # Assert that the file size was logged assert "Downloaded file size: 1234 bytes" in local_caplog.logs # Assert that requests.get was called correctly mock_requests_get.assert_called_once_with( - 'https://www.bundesnetzagentur.de/DE/Fachthemen/ElektrizitaetundGas/E-Mobilitaet/start.html', - headers={'User-Agent': 'Mozilla/5.0'}) + "https://www.bundesnetzagentur.de/DE/Fachthemen/ElektrizitaetundGas/E-Mobilitaet/start.html", + headers={"User-Agent": "Mozilla/5.0"}, + ) # Assert that BeautifulSoup was called correctly - mock_beautiful_soup.assert_called_once_with(b'some content', 'html.parser') + mock_beautiful_soup.assert_called_once_with(b"some content", "html.parser") # Assert that download_file was called correctly - mock_download_file.assert_called_once_with('tmp_data_path/ladesaeulenregister.xlsx', - 'tmp_data_path/some_url1_with_search_term.xlsx') + mock_download_file.assert_called_once_with( + "tmp_data_path/ladesaeulenregister.xlsx", + "tmp_data_path/some_url1_with_search_term.xlsx", + ) # Assert that os.path.getsize was called correctly - mock_getsize.assert_called_once_with('tmp_data_path/some_url1_with_search_term.xlsx') + mock_getsize.assert_called_once_with( + "tmp_data_path/some_url1_with_search_term.xlsx" + ) diff --git a/test/pipelines/osm/test_osm_mapper.py b/test/pipelines/osm/test_osm_mapper.py index 78fbbc9c..ec1f6e95 100644 --- a/test/pipelines/osm/test_osm_mapper.py +++ b/test/pipelines/osm/test_osm_mapper.py @@ -294,7 +294,8 @@ def test_charging_mapping__amperage(raw_amperage, expected): @pytest.mark.parametrize( - "amperage_tag_missing, expected", [({}, None)] # Missing amperage key + "amperage_tag_missing, expected", + [({}, None)], # Missing amperage key ) def test_charging_mapping__amperage_missing_key(amperage_tag_missing, expected): charging = osm_mapper.map_charging_osm( diff --git a/test/shared.py b/test/shared.py index 4298135d..296e9f4a 100644 --- a/test/shared.py +++ b/test/shared.py @@ -92,7 +92,7 @@ def create_charging() -> Charging: def skip_if_github(): """Checks if the current workflow is running on GitHub.""" - return 'GITHUB_WORKFLOW' in os.environ + return "GITHUB_WORKFLOW" in os.environ class LogLocalCaptureFixture: @@ 
-102,7 +102,9 @@ def __init__(self): self.handler = LogCaptureHandler() @contextmanager - def __call__(self, level: int, logger: logging.Logger) -> Generator[None, None, None]: + def __call__( + self, level: int, logger: logging.Logger + ) -> Generator[None, None, None]: """Context manager that sets the level for capturing of logs.""" orig_level = logger.level diff --git a/test/test_main.py b/test/test_main.py index 4baadc3b..72ba8301 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -6,7 +6,7 @@ def test_parse_valid_args(): - arguments = parse_args('import merge --countries de GB'.split()) + arguments = parse_args("import merge --countries de GB".split()) assert arguments.tasks == ["import", "merge"] assert arguments.countries == ["DE", "GB"] assert not arguments.offline @@ -14,12 +14,12 @@ def test_parse_valid_args(): def test_parse_offline_arg(): - arguments = parse_args('import --offline'.split()) + arguments = parse_args("import --offline".split()) assert arguments.offline def test_parse_delete_data_arg(): - arguments = parse_args('import --delete_data'.split()) + arguments = parse_args("import --delete_data".split()) assert arguments.delete_data @@ -40,4 +40,4 @@ def test_parse_invalid_task_arg(): # ---- # ... is an expected side-effect of using `pytest.raises(SystemExit)` with pytest.raises(SystemExit): - parse_args('invalid_task --countries de'.split()) + parse_args("invalid_task --countries de".split()) diff --git a/testdata_import.py b/testdata_import.py index e92a80dc..830ddaa1 100644 --- a/testdata_import.py +++ b/testdata_import.py @@ -13,9 +13,9 @@ from googleapiclient.errors import HttpError # If modifying these scopes, delete the file token.json. -SCOPES = ['https://www.googleapis.com/auth/spreadsheets.readonly'] +SCOPES = ["https://www.googleapis.com/auth/spreadsheets.readonly"] -SPREADSHEET_ID = '1bvwxsGRMaEsiuz_ghY3HEbFEMCPahINcVoGE2k_zgOc' +SPREADSHEET_ID = "1bvwxsGRMaEsiuz_ghY3HEbFEMCPahINcVoGE2k_zgOc" def main() -> list[Any]: @@ -28,7 +28,7 @@ def main() -> list[Any]: # The file token.json stores the user's access and refresh tokens, and is # created automatically when the authorization flow completes for the first # time. - token_filename = os.path.join(directory, 'token_deepatlas.json') + token_filename = os.path.join(directory, "token_deepatlas.json") if os.path.exists(token_filename): creds = Credentials.from_authorized_user_file(token_filename, SCOPES) # If there are no (valid) credentials available, let the user log in. 
@@ -36,24 +36,28 @@ def main() -> list[Any]: if creds and creds.expired and creds.refresh_token: creds.refresh(Request()) else: - flow = InstalledAppFlow.from_client_secrets_file(os.path.join(directory, 'credentials.json'), SCOPES) + flow = InstalledAppFlow.from_client_secrets_file( + os.path.join(directory, "credentials.json"), SCOPES + ) creds = flow.run_local_server(port=8083) # Save the credentials for the next run - with open(token_filename, 'w') as token: + with open(token_filename, "w") as token: token.write(creds.to_json()) try: - service = build('sheets', 'v4', credentials=creds) + service = build("sheets", "v4", credentials=creds) # Call the Sheets API sheet = service.spreadsheets() - result = sheet.values().get(spreadsheetId=SPREADSHEET_ID, range='A1:Z100').execute() - values = result.get('values', []) + result = ( + sheet.values().get(spreadsheetId=SPREADSHEET_ID, range="A1:Z100").execute() + ) + values = result.get("values", []) return values except HttpError as err: print(err) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/testing/testdata.py b/testing/testdata.py index a7219eea..5bc8245f 100644 --- a/testing/testdata.py +++ b/testing/testdata.py @@ -17,6 +17,7 @@ @dataclass class TestData: """Class used to store test data for geographic locations.""" + osm_id: Optional[str] = None ocm_id: Optional[str] = None bna_id: Optional[str] = None @@ -27,7 +28,7 @@ def load_test_data() -> List[TestData]: """Create TestData objects from spreadsheet data.""" rows = testdata_import.main() if not rows: - print('No test data found.') + print("No test data found.") return [] test_data = [] for row in rows[1:]: @@ -63,7 +64,9 @@ def run(): current_dir = os.path.join(pathlib.Path(__file__).parent.resolve()) config.read(os.path.join(os.path.join(current_dir, "config", "config.ini"))) - merger: StationMerger = StationMerger('DE', config=config, db_engine=create_engine(db_uri, echo=True)) + merger: StationMerger = StationMerger( + "DE", config=config, db_engine=create_engine(db_uri, echo=True) + ) # print(test_data) with open("testdata_merge.csv", "w") as outfile: @@ -71,23 +74,31 @@ def run(): duplicates: pd.DataFrame = pd.DataFrame() if station.osm_coordinates: # print(f"OSM ID of central charging station: {station.osm_id}") - duplicates, _ = merger.find_duplicates(station.osm_id, station.osm_coordinates, radius_m, - filter_by_source_id=True) + duplicates, _ = merger.find_duplicates( + station.osm_id, + station.osm_coordinates, + radius_m, + filter_by_source_id=True, + ) if not duplicates.empty: - data_sources_in_duplicates = ','.join(duplicates.data_source.unique()) + data_sources_in_duplicates = ",".join(duplicates.data_source.unique()) # print(f"Data Sources in duplicates: {data_sources_in_duplicates}") - result: list[Optional[str]] = [station.osm_id, str(len(duplicates)), data_sources_in_duplicates] + result: list[Optional[str]] = [ + station.osm_id, + str(len(duplicates)), + data_sources_in_duplicates, + ] print(f"Result of merge on test data: {result}") - outfile.write(','.join(result)) - outfile.write('\n') + outfile.write(",".join(result)) + outfile.write("\n") else: print("No successful merge on test data") - outfile.write('NA\n') + outfile.write("NA\n") - if station.osm_id == '6417375309': + if station.osm_id == "6417375309": pass # break outfile.close() -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/tests/integration/test_int_de_bna.py b/tests/integration/test_int_de_bna.py index ba74ad6e..004807c2 100644 --- 
a/tests/integration/test_int_de_bna.py +++ b/tests/integration/test_int_de_bna.py @@ -10,38 +10,34 @@ from charging_stations_pipelines.shared import load_excel_file from tests.test_utils import verify_schema_follows -EXPECTED_DATA_SCHEMA = {'Betreiber': 'object', - 'Straße': 'object', - 'Hausnummer': 'object', - 'Adresszusatz': 'object', - 'Postleitzahl': 'object', - 'Ort': 'object', - 'Bundesland': 'object', - 'Kreis/kreisfreie Stadt': 'object', - 'Breitengrad': 'object', - 'Längengrad': 'object', - - 'Inbetriebnahmedatum': 'object', - - 'Nennleistung Ladeeinrichtung [kW]': 'object', - 'Art der Ladeeinrichung': 'object', - 'Anzahl Ladepunkte': 'object', - - 'Steckertypen1': 'object', - 'P1 [kW]': 'object', - 'Public Key1': 'object', - - 'Steckertypen2': 'object', - 'P2 [kW]': 'object', - 'Public Key2': 'object', - - 'Steckertypen3': 'object', - 'P3 [kW]': 'object', - 'Public Key3': 'object', - - 'Steckertypen4': 'object', - 'P4 [kW]': 'object', - 'Public Key4': 'object'} +EXPECTED_DATA_SCHEMA = { + "Betreiber": "object", + "Straße": "object", + "Hausnummer": "object", + "Adresszusatz": "object", + "Postleitzahl": "object", + "Ort": "object", + "Bundesland": "object", + "Kreis/kreisfreie Stadt": "object", + "Breitengrad": "object", + "Längengrad": "object", + "Inbetriebnahmedatum": "object", + "Nennleistung Ladeeinrichtung [kW]": "object", + "Art der Ladeeinrichung": "object", + "Anzahl Ladepunkte": "object", + "Steckertypen1": "object", + "P1 [kW]": "object", + "Public Key1": "object", + "Steckertypen2": "object", + "P2 [kW]": "object", + "Public Key2": "object", + "Steckertypen3": "object", + "P3 [kW]": "object", + "Public Key3": "object", + "Steckertypen4": "object", + "P4 [kW]": "object", + "Public Key4": "object", +} @pytest.fixture(scope="module") @@ -72,7 +68,9 @@ def test_file_size(bna_data): def test_dataframe_schema(bna_data): _, bna_in_data = bna_data # Check schema of the downloaded Excel file - assert verify_schema_follows(bna_in_data, EXPECTED_DATA_SCHEMA), "Mismatch in schema of the downloaded Excel file!" + assert verify_schema_follows( + bna_in_data, EXPECTED_DATA_SCHEMA + ), "Mismatch in schema of the downloaded Excel file!" @pytest.mark.integration_test diff --git a/tests/integration/test_int_merger.py b/tests/integration/test_int_merger.py index 50b74401..db0fe926 100644 --- a/tests/integration/test_int_merger.py +++ b/tests/integration/test_int_merger.py @@ -63,11 +63,13 @@ def _run_merger(engine): # Suppressing Pandas warning (1/2): "A value is trying to be set on a copy of a slice from a DataFrame." 
pd.options.mode.chained_assignment = None # default: 'warn' - station_merger = StationMerger(country_code='DE', config=(get_config()), db_engine=engine) + station_merger = StationMerger( + country_code="DE", config=(get_config()), db_engine=engine + ) station_merger.run() # Suppressing Pandas warning (2/2): restoring default value - pd.options.mode.chained_assignment = 'warn' + pd.options.mode.chained_assignment = "warn" @pytest.mark.integration_test @@ -75,12 +77,12 @@ def test_int_deduplication_expect_a_merged_entry_if_two_duplicates_exists(engine def _create_stations(): # Given: two duplicate stations station_one = create_station() - station_one.data_source = 'BNA' - station_one.source_id = 'BNA_ID1' + station_one.data_source = "BNA" + station_one.source_id = "BNA_ID1" station_duplicate = create_station() - station_duplicate.data_source = 'OSM' - station_duplicate.source_id = 'OSM_ID1' + station_duplicate.data_source = "OSM" + station_duplicate.source_id = "OSM_ID1" return station_one, station_duplicate @@ -92,7 +94,7 @@ def _check_results(session): not_merged_stations = [s for s in all_stations if not s.is_merged] assert len(not_merged_stations) == 2 - assert all(s.merge_status == 'is_duplicate' for s in not_merged_stations) + assert all(s.merge_status == "is_duplicate" for s in not_merged_stations) merged_stations = [s for s in all_stations if s.is_merged] assert len(merged_stations) == 1 @@ -128,7 +130,7 @@ def _check_results(session): assert len(merged_stations) == 1 merged_station = merged_stations[0] - assert merged_station.data_source == 'OCM' + assert merged_station.data_source == "OCM" point = wkb.loads(bytes(merged_station.point.data)) expected_x, expected_y = 1.11111112, 1.111111 @@ -153,56 +155,62 @@ def _create_test_station(raw: pd.Series, country_code: str) -> Station: def _create_stations(): # Given: two problematic stations # noinspection DuplicatedCode - raw1: pd.Series = pd.Series({ - 'evseCountryId': 'DE', - 'evseOperatorId': 'ELE', - 'evseStationId': 'EKRIMML4', - 'status': 'ACTIVE', - 'label': 'KRIMML', - 'description': None, - 'postCode': 5743, - 'city': 'Krimml', - 'street': 'Gerlos Straße, Parkplatz P4', - 'location': {'latitude': 12.167938, 'longitude': 12.167938}, - 'distance': None, - 'contactName': 'David Gruber', - 'telephone': '+4369917057801', - 'email': 'david@elektroauto.at', - 'website': 'www.elektroauto.at', - 'directions': 'Ober dem letzten Parkplatz Krimmler Wasserfälle', - 'greenEnergy': 1.0, - 'freeParking': 1.0, - 'openingHours': {'text': None, 'details': []}, - 'priceUrl': None, - 'points': [], - 'public': True}) + raw1: pd.Series = pd.Series( + { + "evseCountryId": "DE", + "evseOperatorId": "ELE", + "evseStationId": "EKRIMML4", + "status": "ACTIVE", + "label": "KRIMML", + "description": None, + "postCode": 5743, + "city": "Krimml", + "street": "Gerlos Straße, Parkplatz P4", + "location": {"latitude": 12.167938, "longitude": 12.167938}, + "distance": None, + "contactName": "David Gruber", + "telephone": "+4369917057801", + "email": "david@elektroauto.at", + "website": "www.elektroauto.at", + "directions": "Ober dem letzten Parkplatz Krimmler Wasserfälle", + "greenEnergy": 1.0, + "freeParking": 1.0, + "openingHours": {"text": None, "details": []}, + "priceUrl": None, + "points": [], + "public": True, + } + ) # noinspection DuplicatedCode - raw2: pd.Series = pd.Series({ - 'evseCountryId': 'DE', - 'evseOperatorId': 'ELE', - 'evseStationId': 'EKRIMML', - 'status': 'ACTIVE', - 'label': 'KRIMML', - 'description': None, - 'postCode': 5743, - 'city': 
'Krimml', - 'street': 'Gerlos Straße, Parkplatz P4', - 'location': {'latitude': 12.167938, 'longitude': 12.167938}, - 'distance': None, - 'contactName': 'David Gruber', - 'telephone': '+4369917057801', - 'email': 'david@elektroauto.at', - 'website': 'www.elektroauto.at', - 'directions': 'Ober dem letzten Parkplatz Krimmler Wasserfälle', - 'greenEnergy': 1.0, - 'freeParking': 1.0, - 'openingHours': {'text': None, 'details': []}, - 'priceUrl': None, - 'points': [], - 'public': True}) - - return _create_test_station(raw1, 'AT'), _create_test_station(raw2, 'AT') + raw2: pd.Series = pd.Series( + { + "evseCountryId": "DE", + "evseOperatorId": "ELE", + "evseStationId": "EKRIMML", + "status": "ACTIVE", + "label": "KRIMML", + "description": None, + "postCode": 5743, + "city": "Krimml", + "street": "Gerlos Straße, Parkplatz P4", + "location": {"latitude": 12.167938, "longitude": 12.167938}, + "distance": None, + "contactName": "David Gruber", + "telephone": "+4369917057801", + "email": "david@elektroauto.at", + "website": "www.elektroauto.at", + "directions": "Ober dem letzten Parkplatz Krimmler Wasserfälle", + "greenEnergy": 1.0, + "freeParking": 1.0, + "openingHours": {"text": None, "details": []}, + "priceUrl": None, + "points": [], + "public": True, + } + ) + + return _create_test_station(raw1, "AT"), _create_test_station(raw2, "AT") # Given: db with two problematic station entries session = _set_up_db(engine, _create_stations()) @@ -213,11 +221,13 @@ def _create_stations(): # Suppressing Pandas warning (1/2): "A value is trying to be set on a copy of a slice from a DataFrame." pd.options.mode.chained_assignment = None # default: 'warn' - station_merger = StationMerger(country_code='AT', config=(get_config()), db_engine=engine) + station_merger = StationMerger( + country_code="AT", config=(get_config()), db_engine=engine + ) station_merger.run() # Suppressing Pandas warning (2/2): restoring default value - pd.options.mode.chained_assignment = 'warn' + pd.options.mode.chained_assignment = "warn" # Check that all_stations are merged # noinspection DuplicatedCode @@ -226,7 +236,7 @@ def _create_stations(): not_merged_stations = [s for s in all_stations if not s.is_merged] assert len(not_merged_stations) == 2 - assert all(s.merge_status == 'is_duplicate' for s in not_merged_stations) + assert all(s.merge_status == "is_duplicate" for s in not_merged_stations) merged_stations = [s for s in all_stations if s.is_merged] assert len(merged_stations) == 1 From dfd6fbd38b08aaba39bf8b8383c7c8c7cf448fb8 Mon Sep 17 00:00:00 2001 From: Martin Mader Date: Fri, 26 Jan 2024 14:29:17 +0100 Subject: [PATCH 3/5] change ruff line-length setting to 120 --- charging_stations_pipelines/db_utils.py | 14 +-- .../attribute_match_thresholds_strategy.py | 24 ++--- .../deduplication/merger.py | 94 ++++++------------- .../models/__init__.py | 8 +- charging_stations_pipelines/models/address.py | 4 +- .../models/charging.py | 4 +- .../pipelines/at/econtrol.py | 21 +---- .../pipelines/at/econtrol_crawler.py | 8 +- .../pipelines/at/econtrol_mapper.py | 4 +- .../pipelines/de/bna.py | 4 +- .../pipelines/de/bna_crawler.py | 4 +- .../pipelines/de/bna_mapper.py | 42 ++------- .../pipelines/fr/france.py | 8 +- .../pipelines/fr/france_mapper.py | 4 +- .../pipelines/gb/gb_mapper.py | 20 +--- .../pipelines/gb/gbgov.py | 8 +- .../pipelines/nobil/nobil_pipeline.py | 30 ++---- .../pipelines/ocm/ocm.py | 8 +- .../pipelines/ocm/ocm_extractor.py | 32 ++----- .../pipelines/ocm/ocm_mapper.py | 20 +--- .../pipelines/osm/osm.py | 14 +-- 
.../pipelines/osm/osm_mapper.py | 21 +---- .../pipelines/osm/osm_receiver.py | 8 +- charging_stations_pipelines/settings.py | 17 +--- charging_stations_pipelines/shared.py | 22 +---- .../stations_data_export.py | 30 ++---- main.py | 13 +-- ruff.toml | 3 +- test/pipelines/at/test_econtrol_crawler.py | 25 ++--- test/pipelines/at/test_econtrol_mapper.py | 4 +- test/pipelines/de/test_bna_crawler.py | 16 +--- test/pipelines/de/test_bna_mapper.py | 14 +-- test/shared.py | 4 +- test/test_shared.py | 12 +-- testdata_import.py | 8 +- testing/testdata.py | 4 +- tests/integration/test_int_de_bna.py | 4 +- tests/integration/test_int_merger.py | 8 +- 38 files changed, 144 insertions(+), 444 deletions(-) diff --git a/charging_stations_pipelines/db_utils.py b/charging_stations_pipelines/db_utils.py index 49fbb3a7..8c8b5502 100644 --- a/charging_stations_pipelines/db_utils.py +++ b/charging_stations_pipelines/db_utils.py @@ -20,9 +20,7 @@ def delete_all_data(session: Session): ] for model in models: logger.info(f"Dropping table: '{model.__tablename__}'...") - session.execute( - text(f"TRUNCATE TABLE {model.__tablename__} RESTART IDENTITY CASCADE") - ) + session.execute(text(f"TRUNCATE TABLE {model.__tablename__} RESTART IDENTITY CASCADE")) session.commit() session.close() logger.info("Finished deleting all data from the database.") @@ -30,15 +28,9 @@ def delete_all_data(session: Session): def delete_all_merged_data(session: Session): """Deletes all merged data from the database.""" - session.execute( - text( - f"TRUNCATE TABLE {station.MergedStationSource.__tablename__} RESTART IDENTITY" - ) - ) + session.execute(text(f"TRUNCATE TABLE {station.MergedStationSource.__tablename__} RESTART IDENTITY")) session.execute(delete(address.Address).where(address.Address.is_merged.is_(True))) - session.execute( - delete(charging.Charging).where(charging.Charging.is_merged.is_(True)) - ) + session.execute(delete(charging.Charging).where(charging.Charging.is_merged.is_(True))) session.execute(delete(station.Station).where(station.Station.is_merged.is_(True))) session.execute(update(station.Station).values(merge_status=None)) session.commit() diff --git a/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py b/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py index a6fccebe..5261852d 100644 --- a/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py +++ b/charging_stations_pipelines/deduplication/attribute_match_thresholds_strategy.py @@ -14,9 +14,7 @@ def attribute_match_thresholds_duplicates( ) -> pd.DataFrame: pd.options.mode.chained_assignment = None - remaining_duplicate_candidates = duplicate_candidates[ - ~duplicate_candidates["is_duplicate"].astype(bool) - ] + remaining_duplicate_candidates = duplicate_candidates[~duplicate_candidates["is_duplicate"].astype(bool)] if remaining_duplicate_candidates.empty: return duplicate_candidates @@ -27,26 +25,20 @@ def attribute_match_thresholds_duplicates( ) logger.debug(f"{len(remaining_duplicate_candidates)} duplicate candidates") - remaining_duplicate_candidates[ - "operator_match" - ] = remaining_duplicate_candidates.operator.apply( + remaining_duplicate_candidates["operator_match"] = remaining_duplicate_candidates.operator.apply( lambda x: SequenceMatcher(None, current_station.operator, str(x)).ratio() if (current_station.operator is not None) & (x is not None) else 0.0 ) - remaining_duplicate_candidates[ - "address_match" - ] = remaining_duplicate_candidates.address.apply( + 
remaining_duplicate_candidates["address_match"] = remaining_duplicate_candidates.address.apply( lambda x: SequenceMatcher(None, current_station["address"], x).ratio() if (current_station["address"] != "None,None") & (x != "None,None") else 0.0, ) # this is always the distance to the initial central charging station - remaining_duplicate_candidates["distance_match"] = ( - 1 - remaining_duplicate_candidates["distance"] / max_distance - ) + remaining_duplicate_candidates["distance_match"] = 1 - remaining_duplicate_candidates["distance"] / max_distance def is_duplicate_by_score(duplicate_candidate): if duplicate_candidate["address_match"] >= 0.7: @@ -70,18 +62,14 @@ def is_duplicate_by_score(duplicate_candidate): ) return is_duplicate - remaining_duplicate_candidates[ - "is_duplicate" - ] = remaining_duplicate_candidates.apply(is_duplicate_by_score, axis=1) + remaining_duplicate_candidates["is_duplicate"] = remaining_duplicate_candidates.apply(is_duplicate_by_score, axis=1) # update original candidates duplicate_candidates.update(remaining_duplicate_candidates) # for all duplicates found via OSM, which has most of the time no address info, # run the check again against all candidates # so e.g. if we have a duplicate with address it can be matched to other data sources via this attribute - new_duplicates = remaining_duplicate_candidates[ - remaining_duplicate_candidates["is_duplicate"] - ] + new_duplicates = remaining_duplicate_candidates[remaining_duplicate_candidates["is_duplicate"]] for idx in range(new_duplicates.shape[0]): current_station: pd.Series = new_duplicates.iloc[idx] diff --git a/charging_stations_pipelines/deduplication/merger.py b/charging_stations_pipelines/deduplication/merger.py index fc33e3d5..d4949ad2 100644 --- a/charging_stations_pipelines/deduplication/merger.py +++ b/charging_stations_pipelines/deduplication/merger.py @@ -18,9 +18,7 @@ class StationMerger: - def __init__( - self, country_code: str, config: configparser, db_engine, is_test: bool = False - ): + def __init__(self, country_code: str, config: configparser, db_engine, is_test: bool = False): self.country_code = country_code self.config = config self.db_engine: Engine = db_engine @@ -64,25 +62,19 @@ def merge_attributes(station: pd.Series, duplicates_to_merge: pd.DataFrame): att_values = [str(x) for x in att_values if len(str(x)) > 0] if att_name in station.dropna(): att_value = str(station[att_name]) - att_values += ( - [att_value] if ";" not in att_value else att_value.split(";") - ) + att_values += [att_value] if ";" not in att_value else att_value.split(";") att_values = set(att_values) new_value = ";".join([str(x) for x in att_values]) if att_values else None station.at[att_name] = new_value station.at["merged_attributes"] = True - def _get_attribute_by_priority( - self, stations_to_merge, column_name, priority_list=None - ): + def _get_attribute_by_priority(self, stations_to_merge, column_name, priority_list=None): attribute = None if priority_list is None: priority_list = [self.gov_source, "OCM", "OSM"] for source in priority_list: # get stations of source with attribute not empty, and return only attribute column - stations_by_source = stations_to_merge[ - stations_to_merge["data_source"] == source - ][column_name].dropna() + stations_by_source = stations_to_merge[stations_to_merge["data_source"] == source][column_name].dropna() if len(stations_by_source) > 0: attribute = stations_by_source.iloc[0] break @@ -90,19 +82,13 @@ def _get_attribute_by_priority( logger.debug(f"attribute {column_name} not 
found ?!?!? {stations_to_merge}") return attribute - def _get_station_with_address_and_charging_by_priority( - self, session, stations_to_merge - ): + def _get_station_with_address_and_charging_by_priority(self, session, stations_to_merge): merged_station: Optional[Station] = None for source in [self.gov_source, "OCM", "OSM"]: - station_id = stations_to_merge[stations_to_merge["data_source"] == source][ - "station_id_col" - ] + station_id = stations_to_merge[stations_to_merge["data_source"] == source]["station_id_col"] if len(station_id) > 0: station_id = int(station_id.iloc[0]) - station, address, charging = self.get_station_with_address_and_charging( - session, station_id - ) + station, address, charging = self.get_station_with_address_and_charging(session, station_id) if not merged_station and station: merged_station = station if merged_station and address and not merged_station.address: @@ -133,14 +119,10 @@ def _merge_duplicates(self, stations_to_merge, session) -> Station: merged_station.point = stations_to_merge["point"].wkt merged_station.operator = stations_to_merge["operator"] - source = MergedStationSource( - duplicate_source_id=stations_to_merge["source_id"] - ) + source = MergedStationSource(duplicate_source_id=stations_to_merge["source_id"]) merged_station.source_stations.append(source) else: - merged_station = self._get_station_with_address_and_charging_by_priority( - session, stations_to_merge - ) + merged_station = self._get_station_with_address_and_charging_by_priority(session, stations_to_merge) data_sources = stations_to_merge["data_source"].unique() data_sources.sort() @@ -155,9 +137,7 @@ def _merge_duplicates(self, stations_to_merge, session) -> Station: priority_list=["OSM", "OCM", self.gov_source], ) merged_station.point = point.wkt - merged_station.operator = self._get_attribute_by_priority( - stations_to_merge, "operator" - ) + merged_station.operator = self._get_attribute_by_priority(stations_to_merge, "operator") for source_id in stations_to_merge["source_id"]: source = MergedStationSource(duplicate_source_id=source_id) @@ -170,9 +150,7 @@ def _merge_duplicates(self, stations_to_merge, session) -> Station: def get_station_with_address_and_charging(self, session, station_id): # get station from DB and create new object - merged_station: Station = ( - session.query(Station).filter(Station.id == station_id).first() - ) + merged_station: Station = session.query(Station).filter(Station.id == station_id).first() address = merged_station.address charging = merged_station.charging session.expunge(merged_station) # expunge the object from session @@ -242,9 +220,7 @@ def run(self): """ with self.db_engine.connect() as con: - gdf: GeoDataFrame = read_postgis( - get_stations_list_sql, con=con, geom_col="point" - ) + gdf: GeoDataFrame = read_postgis(get_stations_list_sql, con=con, geom_col="point") gdf.sort_values(by=["station_id"], inplace=True, ignore_index=True) @@ -264,21 +240,15 @@ def run(self): stations_to_merge = current_station_full # .to_frame() station_ids = [current_station["station_id"].item()] else: - stations_to_merge = pd.concat( - [duplicates, current_station_full.to_frame().T] - ) - station_ids = ( - stations_to_merge["station_id_col"].values.astype(int).tolist() - ) + stations_to_merge = pd.concat([duplicates, current_station_full.to_frame().T]) + station_ids = stations_to_merge["station_id_col"].values.astype(int).tolist() if not stations_to_merge.empty: # merge attributes of duplicates into one station 
session.query(Station).filter(Station.id.in_(station_ids)).update( {Station.merge_status: "is_duplicate"}, synchronize_session="fetch" ) - merged_station: Station = self._merge_duplicates( - stations_to_merge, session - ) + merged_station: Station = self._merge_duplicates(stations_to_merge, session) session.add(merged_station) self._write_session(session) session.close() @@ -309,17 +279,13 @@ def find_duplicates( """ with self.db_engine.connect() as con: - nearby_stations: GeoDataFrame = read_postgis( - find_surrounding_stations_sql, con=con, geom_col="point" - ) + nearby_stations: GeoDataFrame = read_postgis(find_surrounding_stations_sql, con=con, geom_col="point") if nearby_stations.empty: logger.debug(f"##### Already merged, id {current_station_id} #####") return GeoDataFrame(), GeoSeries() - logger.debug( - f"coordinates of current station: {current_station_coordinates}, ID: {current_station_id}" - ) + logger.debug(f"coordinates of current station: {current_station_coordinates}, ID: {current_station_id}") logger.debug(f"# nearby stations incl current: {len(nearby_stations)}") # copy station id to new column otherwise it's not addressable as column after setting index station_id_col = "station_id_col" @@ -344,24 +310,18 @@ def find_duplicates( # skip if only center station itself was found if len(nearby_stations) < 2 or current_station_full.empty: return GeoDataFrame(), current_station_full - duplicate_candidates = nearby_stations[ - nearby_stations[station_id_name] != current_station_id - ] + duplicate_candidates = nearby_stations[nearby_stations[station_id_name] != current_station_id] duplicate_candidates["is_duplicate"] = False current_station_full["is_duplicate"] = True - duplicate_candidates["address"] = duplicate_candidates[ - ["street", "town"] - ].apply(lambda x: f"{x['street']},{x['town']}", axis=1) - current_station_full[ - "address" - ] = f"{current_station_full['street']},{current_station_full['town']}" - duplicate_candidates = ( - attribute_match_thresholds_strategy.attribute_match_thresholds_duplicates( - current_station=current_station_full, - duplicate_candidates=duplicate_candidates, - station_id_name=station_id_name, - max_distance=radius_m, - ) + duplicate_candidates["address"] = duplicate_candidates[["street", "town"]].apply( + lambda x: f"{x['street']},{x['town']}", axis=1 + ) + current_station_full["address"] = f"{current_station_full['street']},{current_station_full['town']}" + duplicate_candidates = attribute_match_thresholds_strategy.attribute_match_thresholds_duplicates( + current_station=current_station_full, + duplicate_candidates=duplicate_candidates, + station_id_name=station_id_name, + max_distance=radius_m, ) duplicates = duplicate_candidates[duplicate_candidates["is_duplicate"]] return duplicates, current_station_full diff --git a/charging_stations_pipelines/models/__init__.py b/charging_stations_pipelines/models/__init__.py index 8cf33b02..5f316e06 100644 --- a/charging_stations_pipelines/models/__init__.py +++ b/charging_stations_pipelines/models/__init__.py @@ -24,16 +24,12 @@ def __setattr__(self, name: str, value: Any) -> None: :return: None. """ if not (name.startswith("_") or hasattr(self, name)): - raise AttributeError( - f"Cannot set non-existing attribute '{name}' on class '{self.__class__.__name__}'." 
- ) + raise AttributeError(f"Cannot set non-existing attribute '{name}' on class '{self.__class__.__name__}'.") super().__setattr__(name, value) def __repr__(self): return f"<{self.__class__.__name__} with id: {self.id}>" -Base = declarative_base( - cls=BaseWithSafeSetProperty, metadata=MetaData(schema=settings.db_schema) -) +Base = declarative_base(cls=BaseWithSafeSetProperty, metadata=MetaData(schema=settings.db_schema)) """The base class for all models.""" diff --git a/charging_stations_pipelines/models/address.py b/charging_stations_pipelines/models/address.py index b176b9ec..ec367460 100644 --- a/charging_stations_pipelines/models/address.py +++ b/charging_stations_pipelines/models/address.py @@ -10,9 +10,7 @@ class Address(Base): __tablename__ = f"{settings.db_table_prefix}address" id = Column(Integer, primary_key=True, autoincrement=True) - station_id = Column( - Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True - ) + station_id = Column(Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True) date_created = Column(Date) date_updated = Column(Date) street = Column(String) diff --git a/charging_stations_pipelines/models/charging.py b/charging_stations_pipelines/models/charging.py index 247fd238..24a9d30a 100644 --- a/charging_stations_pipelines/models/charging.py +++ b/charging_stations_pipelines/models/charging.py @@ -10,9 +10,7 @@ class Charging(Base): __tablename__ = f"{settings.db_table_prefix}charging" id = Column(Integer, primary_key=True, autoincrement=True) - station_id = Column( - Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True - ) + station_id = Column(Integer, ForeignKey(f"{Station.__tablename__}.id"), nullable=False, unique=True) date_created = Column(Date) date_updated = Column(Date) capacity = Column(Integer) diff --git a/charging_stations_pipelines/pipelines/at/econtrol.py b/charging_stations_pipelines/pipelines/at/econtrol.py index a150be0b..dd5ac58e 100644 --- a/charging_stations_pipelines/pipelines/at/econtrol.py +++ b/charging_stations_pipelines/pipelines/at/econtrol.py @@ -50,15 +50,11 @@ def __init__(self, config: configparser, session: Session, online: bool = False) self.country_code = "AT" relative_dir = os.path.join("../../..", "data") - self.data_dir = os.path.join( - pathlib.Path(__file__).parent.resolve(), relative_dir - ) + self.data_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), relative_dir) def _retrieve_data(self): pathlib.Path(self.data_dir).mkdir(parents=True, exist_ok=True) - tmp_data_path = os.path.join( - self.data_dir, self.config[DATA_SOURCE_KEY]["filename"] - ) + tmp_data_path = os.path.join(self.data_dir, self.config[DATA_SOURCE_KEY]["filename"]) if self.online: logger.info("Retrieving Online Data") get_data(tmp_data_path) @@ -78,9 +74,7 @@ def run(self): stats = collections.defaultdict(int) datapoint: pd.Series - for _, datapoint in tqdm( - iterable=self.data.iterrows(), total=self.data.shape[0] - ): + for _, datapoint in tqdm(iterable=self.data.iterrows(), total=self.data.shape[0]): try: station = map_station(datapoint, self.country_code) @@ -92,11 +86,7 @@ def run(self): station.address = map_address(datapoint, self.country_code, None) # Count stations which have an invalid address - if ( - station.address - and station.address.country - and station.address.country not in SCOPE_COUNTRIES - ): + if station.address and station.address.country and station.address.country not in SCOPE_COUNTRIES: stats["count_country_mismatch_stations"] += 1 # 
Count stations which have a mismatching country code between Station and Address @@ -111,8 +101,7 @@ def run(self): except Exception as e: stats["count_parse_error"] += 1 logger.debug( - f"{DATA_SOURCE_KEY} entry could not be parsed, error:\n{e}\n" - f"Row:\n----\n{datapoint}\n----\n" + f"{DATA_SOURCE_KEY} entry could not be parsed, error:\n{e}\n" f"Row:\n----\n{datapoint}\n----\n" ) logger.info( f"Finished {DATA_SOURCE_KEY} Pipeline:\n" diff --git a/charging_stations_pipelines/pipelines/at/econtrol_crawler.py b/charging_stations_pipelines/pipelines/at/econtrol_crawler.py index 414b9292..37bc4139 100644 --- a/charging_stations_pipelines/pipelines/at/econtrol_crawler.py +++ b/charging_stations_pipelines/pipelines/at/econtrol_crawler.py @@ -12,9 +12,7 @@ logger = logging.getLogger(__name__) -def _get_paginated_stations( - url: str, headers: dict[str, str] = None -) -> Generator[dict[str, Any], None, None]: +def _get_paginated_stations(url: str, headers: dict[str, str] = None) -> Generator[dict[str, Any], None, None]: session = requests.Session() session.headers.update(headers) @@ -47,9 +45,7 @@ def _get_paginated_stations( idx_end = min(page_size * page_num - 1, total_count - 1) logger.debug(f"Downloading chunk: {idx_start}..{idx_end}") - next_page = session.get( - url, params={"fromIndex": idx_start, "endIndex": idx_end} - ).json() + next_page = session.get(url, params={"fromIndex": idx_start, "endIndex": idx_end}).json() yield next_page diff --git a/charging_stations_pipelines/pipelines/at/econtrol_mapper.py b/charging_stations_pipelines/pipelines/at/econtrol_mapper.py index 6fad36f7..55ba8b04 100644 --- a/charging_stations_pipelines/pipelines/at/econtrol_mapper.py +++ b/charging_stations_pipelines/pipelines/at/econtrol_mapper.py @@ -103,9 +103,7 @@ def map_station(row: pd.Series, country_code: str) -> Station: return station -def map_address( - row: pd.Series, country_code: str, station_id: Optional[int] -) -> Address: +def map_address(row: pd.Series, country_code: str, station_id: Optional[int]) -> Address: """Maps the given raw datapoint to an Address object. :param row: A datapoint representing the raw data. diff --git a/charging_stations_pipelines/pipelines/de/bna.py b/charging_stations_pipelines/pipelines/de/bna.py index 13c41c8e..3a336de3 100644 --- a/charging_stations_pipelines/pipelines/de/bna.py +++ b/charging_stations_pipelines/pipelines/de/bna.py @@ -30,9 +30,7 @@ def __init__(self, config: configparser, session: Session, online: bool = False) # All BNA data is from Germany self.country_code = "DE" - self.data_dir: Final[pathlib.Path] = ( - pathlib.Path(__file__).parents[3] / "data" - ).resolve() + self.data_dir: Final[pathlib.Path] = (pathlib.Path(__file__).parents[3] / "data").resolve() def retrieve_data(self): self.data_dir.mkdir(parents=True, exist_ok=True) diff --git a/charging_stations_pipelines/pipelines/de/bna_crawler.py b/charging_stations_pipelines/pipelines/de/bna_crawler.py index 091d2a27..eec722ff 100644 --- a/charging_stations_pipelines/pipelines/de/bna_crawler.py +++ b/charging_stations_pipelines/pipelines/de/bna_crawler.py @@ -51,9 +51,7 @@ def get_bna_data(tmp_data_path: str) -> None: # Check if the url extraction is successful if download_link_url is None: - raise ExtractURLException( - "Failed to extract the download url from the website." 
- ) + raise ExtractURLException("Failed to extract the download url from the website.") logger.debug(f"Downloading BNA data from '{download_link_url}'") try: diff --git a/charging_stations_pipelines/pipelines/de/bna_mapper.py b/charging_stations_pipelines/pipelines/de/bna_mapper.py index 5790273e..18f5e99d 100644 --- a/charging_stations_pipelines/pipelines/de/bna_mapper.py +++ b/charging_stations_pipelines/pipelines/de/bna_mapper.py @@ -28,9 +28,7 @@ def map_station_bna(row: pd.Series): new_station.country_code = "DE" new_station.data_source = bna.DATA_SOURCE_KEY - new_station.source_id = hashlib.sha256( - f"{lat}{long}{new_station.data_source}".encode() - ).hexdigest() + new_station.source_id = hashlib.sha256(f"{lat}{long}{new_station.data_source}".encode()).hexdigest() new_station.operator = row["Betreiber"] new_station.point = from_shape(Point(float(long), float(lat))) @@ -47,9 +45,7 @@ def map_address_bna(row: pd.Series, station_id) -> Address: if len(postcode) == 4: postcode = "0" + postcode if len(postcode) != 5: - logger.debug( - f"Failed to process postcode {postcode}! Will set postcode to None!" - ) + logger.debug(f"Failed to process postcode {postcode}! Will set postcode to None!") postcode = None if len(town) < 2: logger.debug(f"Failed to process town {town}! Will set town to None!") @@ -58,11 +54,7 @@ def map_address_bna(row: pd.Series, station_id) -> Address: address = Address() address.station_id = station_id - address.street = ( - str_strip_whitespace(row.get("Straße")) - + " " - + str_strip_whitespace(row.get("Hausnummer")) - ) + address.street = str_strip_whitespace(row.get("Straße")) + " " + str_strip_whitespace(row.get("Hausnummer")) address.town = town address.postcode = postcode address.district = row["Kreis/kreisfreie Stadt"] @@ -82,18 +74,14 @@ def map_charging_bna(row: pd.Series, station_id): total_kw = float(total_kw.replace(",", ".")) logger.debug(f"Converting total_kw from string {total_kw} to int!") except Exception as conversionErr: - logger.warning( - f"Failed to convert string {total_kw} to Number! Will set total_kw to None! {conversionErr}" - ) + logger.warning(f"Failed to convert string {total_kw} to Number! Will set total_kw to None! {conversionErr}") total_kw = None if isinstance(total_kw, Number) and math.isnan(total_kw): total_kw = None if not isinstance(total_kw, Number): - logger.warning( - f"Cannot process total_kw {total_kw} with type {type(total_kw)}! Will set total_kw to None!" - ) + logger.warning(f"Cannot process total_kw {total_kw} with type {type(total_kw)}! Will set total_kw to None!") total_kw = None # kw_list @@ -106,16 +94,12 @@ def map_charging_bna(row: pd.Series, station_id): if isinstance(v, str): if "," in v: v: str = v.replace(",", ".") - logger.debug( - "Replaced coma with point for string to float conversion of kw!" - ) + logger.debug("Replaced coma with point for string to float conversion of kw!") try: float_kw: float = float(v) kw_list += [float_kw] except Exception: - logger.warning( - f"Failed to convert kw string {v} to float! Will not add this kw entry to list!" - ) + logger.warning(f"Failed to convert kw string {v} to float! 
Will not add this kw entry to list!") if isinstance(v, Number): kw_list += [v] @@ -130,24 +114,18 @@ def map_charging_bna(row: pd.Series, station_id): # volt_list not available # socket_type_list socket_types_infos: list[str] = [ - v - for k, v in station_raw.items() - if ("Steckertypen" in k) & (isinstance(v, str)) & (not pd.isnull(v)) + v for k, v in station_raw.items() if ("Steckertypen" in k) & (isinstance(v, str)) & (not pd.isnull(v)) ] socket_type_list: list[str] = [] dc_support: bool = False for socket_types_info in socket_types_infos: tmp_socket_info: list[str] = socket_types_info.split(",") - if (not dc_support) & ( - any(["DC" in s for s in tmp_socket_info]) - ): # TODO: find more reliable way! + if (not dc_support) & (any(["DC" in s for s in tmp_socket_info])): # TODO: find more reliable way! dc_support = True socket_type_list += socket_types_info.split(",") kw_list_len: int = len(kw_list) if len(kw_list) != capacity: - logger.warning( - f"Difference between length of kw_list {kw_list_len} and capacity {capacity}!" - ) + logger.warning(f"Difference between length of kw_list {kw_list_len} and capacity {capacity}!") charging = Charging() charging.station_id = station_id diff --git a/charging_stations_pipelines/pipelines/fr/france.py b/charging_stations_pipelines/pipelines/fr/france.py index a1378b22..205908aa 100644 --- a/charging_stations_pipelines/pipelines/fr/france.py +++ b/charging_stations_pipelines/pipelines/fr/france.py @@ -25,9 +25,7 @@ class FraPipeline(Pipeline): def _retrieve_data(self): - data_dir = os.path.join( - pathlib.Path(__file__).parent.resolve(), "../../..", "data" - ) + data_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), "../../..", "data") pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True) tmp_data_path = os.path.join(data_dir, self.config["FRGOV"]["filename"]) if self.online: @@ -51,9 +49,7 @@ def run(self): mapped_station = map_station_fra(row) mapped_station.address = mapped_address mapped_station.charging = mapped_charging - station_updater.update_station( - station=mapped_station, data_source_key="FRGOV" - ) + station_updater.update_station(station=mapped_station, data_source_key="FRGOV") station_updater.log_update_station_counts() @staticmethod diff --git a/charging_stations_pipelines/pipelines/fr/france_mapper.py b/charging_stations_pipelines/pipelines/fr/france_mapper.py index 553f0f26..453a44c5 100644 --- a/charging_stations_pipelines/pipelines/fr/france_mapper.py +++ b/charging_stations_pipelines/pipelines/fr/france_mapper.py @@ -45,9 +45,7 @@ def map_station_fra(row: pd.Series) -> Station: station.date_updated = row.get("date_maj").strptime("%Y-%m-%d") if not pd.isna(row.get("date_mise_en_service")): - station.date_created = datetime.strptime( - row.get("date_mise_en_service"), "%Y-%m-%d" - ) + station.date_created = datetime.strptime(row.get("date_mise_en_service"), "%Y-%m-%d") if not pd.isna(row.get("date_maj")): station.date_updated = datetime.strptime(row.get("date_maj"), "%Y-%m-%d") else: diff --git a/charging_stations_pipelines/pipelines/gb/gb_mapper.py b/charging_stations_pipelines/pipelines/gb/gb_mapper.py index a5aa7582..94eaf3f3 100644 --- a/charging_stations_pipelines/pipelines/gb/gb_mapper.py +++ b/charging_stations_pipelines/pipelines/gb/gb_mapper.py @@ -32,28 +32,18 @@ def map_station_gb(entry, country_code: str): def map_address_gb(entry, station_id): - postcode_raw: Optional[str] = ( - entry.get("ChargeDeviceLocation").get("Address").get("PostCode") - ) + postcode_raw: Optional[str] = 
entry.get("ChargeDeviceLocation").get("Address").get("PostCode") postcode: Optional[str] = postcode_raw - town_raw: Optional[str] = ( - entry.get("ChargeDeviceLocation").get("Address").get("PostTown") - ) + town_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("PostTown") town: Optional[str] = town_raw if isinstance(town_raw, str) else None - state_raw: Optional[str] = ( - entry.get("ChargeDeviceLocation").get("Address").get("County") - ) + state_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("County") state: Optional[str] = state_raw if isinstance(state_raw, str) else None - country: Optional[str] = ( - entry.get("ChargeDeviceLocation").get("Address").get("Country") - ) + country: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("Country") - street_raw: Optional[str] = ( - entry.get("ChargeDeviceLocation").get("Address").get("Street") - ) + street_raw: Optional[str] = entry.get("ChargeDeviceLocation").get("Address").get("Street") street: Optional[str] = street_raw if isinstance(street_raw, str) else None map_address = Address() diff --git a/charging_stations_pipelines/pipelines/gb/gbgov.py b/charging_stations_pipelines/pipelines/gb/gbgov.py index 07a15a40..3e252fd2 100644 --- a/charging_stations_pipelines/pipelines/gb/gbgov.py +++ b/charging_stations_pipelines/pipelines/gb/gbgov.py @@ -31,9 +31,7 @@ def __init__(self, config: configparser, session: Session, online: bool = False) self.data: Optional[JSON] = None def _retrieve_data(self): - data_dir: str = os.path.join( - pathlib.Path(__file__).parent.resolve(), "../../..", "data" - ) + data_dir: str = os.path.join(pathlib.Path(__file__).parent.resolve(), "../../..", "data") pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True) tmp_file_path = os.path.join(data_dir, self.config["GBGOV"]["filename"]) if self.online: @@ -56,7 +54,5 @@ def run(self): mapped_station = map_station_gb(entry, " GB") mapped_station.address = mapped_address mapped_station.charging = mapped_charging - station_updater.update_station( - station=mapped_station, data_source_key="GBGOV" - ) + station_updater.update_station(station=mapped_station, data_source_key="GBGOV") station_updater.log_update_station_counts() diff --git a/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py b/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py index 8e11aa88..5f46520c 100644 --- a/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py +++ b/charging_stations_pipelines/pipelines/nobil/nobil_pipeline.py @@ -85,15 +85,11 @@ def parse_nobil_connectors(connectors: dict): parsed_connectors: list[NobilConnector] = [] # iterate over all connectors and add them to the station for k, v in connectors.items(): - charging_capacity = v["5"][ - "trans" - ] # contains a string like "7,4 kW - 230V 1-phase max 32A" or "75 kW DC" + charging_capacity = v["5"]["trans"] # contains a string like "7,4 kW - 230V 1-phase max 32A" or "75 kW DC" # extract the power in kW from the charging capacity string power_in_kw = ( - Decimal(charging_capacity.split(" kW")[0].replace(",", ".")) - if " kW" in charging_capacity - else None + Decimal(charging_capacity.split(" kW")[0].replace(",", ".")) if " kW" in charging_capacity else None ) parsed_connectors.append(NobilConnector(power_in_kw)) @@ -132,9 +128,7 @@ def _map_charging_to_domain(nobil_station: NobilStation) -> Charging: new_charging: Charging = Charging() new_charging.capacity = nobil_station.number_charging_points new_charging.kw_list = [ - 
connector.power_in_kw - for connector in nobil_station.connectors - if connector.power_in_kw is not None + connector.power_in_kw for connector in nobil_station.connectors if connector.power_in_kw is not None ] if len(new_charging.kw_list) > 0: new_charging.max_kw = max(new_charging.kw_list) @@ -164,17 +158,13 @@ def __init__( super().__init__(config, session, online) accepted_country_codes = ["NOR", "SWE"] - reject_if( - country_code.upper() not in accepted_country_codes, "Invalid country code " - ) + reject_if(country_code.upper() not in accepted_country_codes, "Invalid country code ") self.country_code = country_code.upper() def run(self): """Run the pipeline.""" logger.info("Running NOR/SWE GOV Pipeline...") - path_to_target = Path(__file__).parent.parent.parent.parent.joinpath( - "data/" + self.country_code + "_gov.json" - ) + path_to_target = Path(__file__).parent.parent.parent.parent.joinpath("data/" + self.country_code + "_gov.json") if self.online: logger.info("Retrieving Online Data") _load_datadump_and_write_to_target(path_to_target, self.country_code) @@ -182,9 +172,7 @@ def run(self): nobil_stations_as_json = load_json_file(path_to_target) all_nobil_stations = _parse_json_data(nobil_stations_as_json) - for nobil_station in tqdm( - iterable=all_nobil_stations, total=len(all_nobil_stations) - ): + for nobil_station in tqdm(iterable=all_nobil_stations, total=len(all_nobil_stations)): station: Station = _map_station_to_domain(nobil_station, self.country_code) address: Address = _map_address_to_domain(nobil_station) charging: Charging = _map_charging_to_domain(nobil_station) @@ -193,11 +181,7 @@ def run(self): station.charging = charging # check if station already exists in db and add - existing_station = ( - self.session.query(Station) - .filter_by(source_id=station.source_id) - .first() - ) + existing_station = self.session.query(Station).filter_by(source_id=station.source_id).first() if existing_station is None: self.session.add(station) diff --git a/charging_stations_pipelines/pipelines/ocm/ocm.py b/charging_stations_pipelines/pipelines/ocm/ocm.py index dc7e44a1..5d70dfe2 100644 --- a/charging_stations_pipelines/pipelines/ocm/ocm.py +++ b/charging_stations_pipelines/pipelines/ocm/ocm.py @@ -38,9 +38,7 @@ def __init__( self.data: JSON = None def _retrieve_data(self): - data_dir: str = os.path.join( - pathlib.Path(__file__).parent.resolve(), "../../..", "data" - ) + data_dir: str = os.path.join(pathlib.Path(__file__).parent.resolve(), "../../..", "data") pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True) tmp_file_path = os.path.join(data_dir, self.config["OCM"]["filename"]) if self.online: @@ -61,7 +59,5 @@ def run(self): mapped_station = map_station_ocm(entry, self.country_code) mapped_station.address = mapped_address mapped_station.charging = mapped_charging - station_updater.update_station( - station=mapped_station, data_source_key="OCM" - ) + station_updater.update_station(station=mapped_station, data_source_key="OCM") station_updater.log_update_station_counts() diff --git a/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py b/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py index 326d5075..bdf42193 100644 --- a/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py +++ b/charging_stations_pipelines/pipelines/ocm/ocm_extractor.py @@ -20,12 +20,8 @@ def reference_data_to_frame(data: List[Dict]) -> pd.DataFrame: return frame -def merge_connection_types( - connection: pd.DataFrame, reference_data: pd.DataFrame -) -> pd.DataFrame: - connection_ids: 
pd.Series = ( - connection["ConnectionTypeID"].dropna().drop_duplicates() - ) +def merge_connection_types(connection: pd.DataFrame, reference_data: pd.DataFrame) -> pd.DataFrame: + connection_ids: pd.Series = connection["ConnectionTypeID"].dropna().drop_duplicates() return connection.merge( reference_data.loc[connection_ids], how="left", @@ -34,9 +30,7 @@ def merge_connection_types( ) -def merge_address_infos( - address_info: pd.Series, reference_data: pd.DataFrame -) -> pd.DataFrame: +def merge_address_infos(address_info: pd.Series, reference_data: pd.DataFrame) -> pd.DataFrame: return pd.concat([address_info, reference_data.loc[address_info["CountryID"]]]) @@ -50,9 +44,7 @@ def merge_with_reference_data( connection=pd.json_normalize(row["Connections"]), reference_data=connection_types, ) - row["AddressInfo"] = merge_address_infos( - address_info=pd.Series(row["AddressInfo"]), reference_data=address_info - ) + row["AddressInfo"] = merge_address_infos(address_info=pd.Series(row["AddressInfo"]), reference_data=address_info) row["OperatorID"] = operators.loc[row["OperatorID"]] return row @@ -61,9 +53,7 @@ def merge_connections(row, connection_types): frame = pd.DataFrame(row) if "ConnectionTypeID" not in frame.columns: return frame - return pd.merge( - frame, connection_types, how="left", left_on="ConnectionTypeID", right_on="ID" - ) + return pd.merge(frame, connection_types, how="left", left_on="ConnectionTypeID", right_on="ID") def ocm_extractor(tmp_file_path: str, country_code: str): @@ -134,9 +124,7 @@ def ocm_extractor(tmp_file_path: str, country_code: str): data_ref: Dict = json.load(f) connection_types: pd.DataFrame = pd.json_normalize(data_ref["ConnectionTypes"]) - connection_frame = pd.json_normalize( - records, record_path=["Connections"], meta=["UUID"] - ) + connection_frame = pd.json_normalize(records, record_path=["Connections"], meta=["UUID"]) connection_frame = pd.merge( connection_frame, connection_types, @@ -146,9 +134,7 @@ def ocm_extractor(tmp_file_path: str, country_code: str): ) connection_frame_grouped = connection_frame.groupby("UUID").agg(list) connection_frame_grouped.reset_index(inplace=True) - connection_frame_grouped["ConnectionsEnriched"] = connection_frame_grouped.apply( - lambda x: x.to_frame(), axis=1 - ) + connection_frame_grouped["ConnectionsEnriched"] = connection_frame_grouped.apply(lambda x: x.to_frame(), axis=1) data = pd.merge( data, connection_frame_grouped[["ConnectionsEnriched", "UUID"]], @@ -176,6 +162,4 @@ def ocm_extractor(tmp_file_path: str, country_code: str): how="left", ) - pd_merged_with_operators.reset_index(drop=True).to_json( - tmp_file_path, orient="index" - ) + pd_merged_with_operators.reset_index(drop=True).to_json(tmp_file_path, orient="index") diff --git a/charging_stations_pipelines/pipelines/ocm/ocm_mapper.py b/charging_stations_pipelines/pipelines/ocm/ocm_mapper.py index ffe9ef56..434dd651 100644 --- a/charging_stations_pipelines/pipelines/ocm/ocm_mapper.py +++ b/charging_stations_pipelines/pipelines/ocm/ocm_mapper.py @@ -63,27 +63,17 @@ def map_charging_ocm(row, station_id) -> Charging: mapped_charging_ocm.station_id = station_id mapped_charging_ocm.capacity = row.get("NumberOfPoints") mapped_charging_ocm.kw_list = None - mapped_charging_ocm.ampere_list = ( - connections["Amps"].to_list() if "Amps" in connections.columns else None - ) - mapped_charging_ocm.volt_list = ( - connections["Voltage"].to_list() if "Voltage" in connections.columns else None - ) + mapped_charging_ocm.ampere_list = connections["Amps"].to_list() if "Amps" in 
connections.columns else None + mapped_charging_ocm.volt_list = connections["Voltage"].to_list() if "Voltage" in connections.columns else None mapped_charging_ocm.socket_type_list = ( - connections["Title"].str.cat(sep=",") - if "Title" in connections.columns - else None + connections["Title"].str.cat(sep=",") if "Title" in connections.columns else None ) mapped_charging_ocm.dc_support = None mapped_charging_ocm.total_kw = ( - float(round(connections["PowerKW"].dropna().sum(), 2)) - if "PowerKW" in connections.columns - else None + float(round(connections["PowerKW"].dropna().sum(), 2)) if "PowerKW" in connections.columns else None ) mapped_charging_ocm.max_kw = ( - float(connections["PowerKW"].dropna().max()) - if "PowerKW" in connections.columns - else None + float(connections["PowerKW"].dropna().max()) if "PowerKW" in connections.columns else None ) return mapped_charging_ocm diff --git a/charging_stations_pipelines/pipelines/osm/osm.py b/charging_stations_pipelines/pipelines/osm/osm.py index 068048cf..d34d192d 100644 --- a/charging_stations_pipelines/pipelines/osm/osm.py +++ b/charging_stations_pipelines/pipelines/osm/osm.py @@ -39,9 +39,7 @@ def __init__( self.country_code = country_code def retrieve_data(self): - data_dir: str = os.path.join( - pathlib.Path(__file__).parent.resolve(), "../../..", "data" - ) + data_dir: str = os.path.join(pathlib.Path(__file__).parent.resolve(), "../../..", "data") pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True) tmp_file_path = os.path.join(data_dir, self.config[DATA_SOURCE_KEY]["filename"]) if self.online: @@ -70,11 +68,7 @@ def run(self): station.address = osm_mapper.map_address_osm(entry, None) # Count stations which have an invalid address - if ( - station.address - and station.address.country - and station.address.country not in SCOPE_COUNTRIES - ): + if station.address and station.address.country and station.address.country not in SCOPE_COUNTRIES: stats["count_country_mismatch_stations"] += 1 # Count stations which have a mismatching country code between Station and Address @@ -93,9 +87,7 @@ def run(self): stats["count_valid_stations"] += 1 except Exception as ex: stats["count_parse_error"] += 1 - logger.debug( - f"{DATA_SOURCE_KEY} entry could not be parsed, error: {ex}. Row: {entry}" - ) + logger.debug(f"{DATA_SOURCE_KEY} entry could not be parsed, error: {ex}. 
Row: {entry}") logger.info( f"Finished {DATA_SOURCE_KEY} Pipeline:\n" diff --git a/charging_stations_pipelines/pipelines/osm/osm_mapper.py b/charging_stations_pipelines/pipelines/osm/osm_mapper.py index 22c3d186..aba20ae7 100644 --- a/charging_stations_pipelines/pipelines/osm/osm_mapper.py +++ b/charging_stations_pipelines/pipelines/osm/osm_mapper.py @@ -59,14 +59,10 @@ def map_station_osm(entry: JSON, country_code: str) -> Station: new_station = Station() new_station.country_code = country_code new_station.source_id = entry.get("id") or None - new_station.operator = ( - str_strip_whitespace(entry.get("tags", {}).get("operator")) or None - ) + new_station.operator = str_strip_whitespace(entry.get("tags", {}).get("operator")) or None new_station.data_source = DATA_SOURCE_KEY new_station.point = from_shape(Point(lon, lat)) if lon and lat else None - new_station.date_created = ( - str_strip_whitespace(entry.get("timestamp")) or datetime.now() - ) + new_station.date_created = str_strip_whitespace(entry.get("timestamp")) or datetime.now() new_station.raw_data = json.dumps(entry, ensure_ascii=False) return new_station @@ -90,9 +86,7 @@ def map_address_osm(entry: JSON, station_id: Optional[int]) -> Optional[Address] map_address.station_id = station_id map_address.street = ( str_strip_whitespace( - str_strip_whitespace(tags.get("addr:street")) - + " " - + str_strip_whitespace(tags.get("addr:housenumber")) + str_strip_whitespace(tags.get("addr:street")) + " " + str_strip_whitespace(tags.get("addr:housenumber")) ) or None ) @@ -153,9 +147,7 @@ def extract_kw_map(datapoint: JSON) -> dict[str, list[float]]: socket_output_dict = {} for socket_type in SOCKET_TYPES.keys(): - socket_output_dict[socket_type] = extract_kw_list( - tags.get(f"{socket_type}:output") - ) + socket_output_dict[socket_type] = extract_kw_list(tags.get(f"{socket_type}:output")) socket_output_map = {k: v for k, v in socket_output_dict.items() if v} return socket_output_map @@ -210,10 +202,7 @@ def map_charging_osm(row: JSON, station_id: Optional[int]) -> Charging: charging.volt_list = extract_volt_list(row) or None charging.socket_type_list = [SOCKET_TYPES.get(k) for k in kw_map.keys()] or None charging.dc_support = None - charging.total_kw = ( - calc_total_kw(kw_list, row.get("tags", {}).get("charging_station:output")) - or None - ) + charging.total_kw = calc_total_kw(kw_list, row.get("tags", {}).get("charging_station:output")) or None charging.max_kw = max(kw_list) if kw_list else None return charging diff --git a/charging_stations_pipelines/pipelines/osm/osm_receiver.py b/charging_stations_pipelines/pipelines/osm/osm_receiver.py index 425fa75f..a7866a30 100644 --- a/charging_stations_pipelines/pipelines/osm/osm_receiver.py +++ b/charging_stations_pipelines/pipelines/osm/osm_receiver.py @@ -57,13 +57,9 @@ def get_osm_data(country_code: str, tmp_data_path): """ } - response: Response = requests.get( - "https://overpass-api.de/api/interpreter", query_params - ) + response: Response = requests.get("https://overpass-api.de/api/interpreter", query_params) status_code: int = response.status_code if status_code != 200: - raise RuntimeError( - f"Failed to get {DATA_SOURCE_KEY} data! Status code: {status_code}" - ) + raise RuntimeError(f"Failed to get {DATA_SOURCE_KEY} data! 
Status code: {status_code}") with open(tmp_data_path, "w") as f: json.dump(response.json(), f, ensure_ascii=False, indent=4, sort_keys=True) diff --git a/charging_stations_pipelines/settings.py b/charging_stations_pipelines/settings.py index 5f06a0cc..695848f3 100644 --- a/charging_stations_pipelines/settings.py +++ b/charging_stations_pipelines/settings.py @@ -7,9 +7,7 @@ parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) sys.path.append(parent_dir) -logging_conf_path = os.path.normpath( - os.path.join(os.path.dirname(__file__), "../logging.conf") -) +logging_conf_path = os.path.normpath(os.path.join(os.path.dirname(__file__), "../logging.conf")) logging.config.fileConfig(logging_conf_path) log = logging.getLogger(__package__) @@ -27,15 +25,4 @@ db_password = os.getenv("DB_PASSWORD", "postgres") db_schema = os.getenv("DB_SCHEMA", "public") db_table_prefix = os.getenv("DB_TABLE_PREFIX", "") -db_uri = ( - "postgresql://" - + db_user - + ":" - + db_password - + "@" - + db_host - + ":" - + db_port - + "/" - + db_name -) +db_uri = "postgresql://" + db_user + ":" + db_password + "@" + db_host + ":" + db_port + "/" + db_name diff --git a/charging_stations_pipelines/shared.py b/charging_stations_pipelines/shared.py index b3f966a4..d1eeb192 100644 --- a/charging_stations_pipelines/shared.py +++ b/charging_stations_pipelines/shared.py @@ -18,9 +18,7 @@ current_dir = os.path.join(pathlib.Path(__file__).parent.parent.resolve()) -_PlainJSON = Union[ - None, bool, int, float, str, List["_PlainJSON"], Dict[str, "_PlainJSON"] -] +_PlainJSON = Union[None, bool, int, float, str, List["_PlainJSON"], Dict[str, "_PlainJSON"]] """_PlainJSON is a type alias for a JSON object without custom types.""" JSON = Union[_PlainJSON, Dict[str, "JSON"], List["JSON"]] @@ -52,9 +50,7 @@ def check_coordinates(coords: Optional[Union[float, int, str]]) -> Optional[floa if isinstance(coords, str): try: - processed_coords = "".join( - c for c in coords.replace(",", ".") if c.isdigit() or c in ".-" - ) + processed_coords = "".join(c for c in coords.replace(",", ".") if c.isdigit() or c in ".-") logger.debug(f"Coords are string: {coords} will be transformed!") return float(processed_coords) except (ValueError, TypeError): @@ -117,15 +113,9 @@ def str_strip_whitespace( return default -def str_clean_pattern( - raw_str: Optional[str], remove_pattern: Optional[str] -) -> Optional[str]: +def str_clean_pattern(raw_str: Optional[str], remove_pattern: Optional[str]) -> Optional[str]: """Removes a given pattern from a string.""" - return ( - re.sub(remove_pattern, "", raw_str, flags=re.IGNORECASE).strip() - if raw_str and remove_pattern - else None - ) + return re.sub(remove_pattern, "", raw_str, flags=re.IGNORECASE).strip() if raw_str and remove_pattern else None def str_split_pattern(raw_str: Optional[str], split_pattern: str) -> list[str]: @@ -194,9 +184,7 @@ def lst_expand(aggregated_list: list[tuple[float, int]]) -> list[float]: [1.0, 1.0, 1.0, 2.5, 2.5] """ # [0] - float value, [1] - count, how often this value occurs - return ( - [e[0] for e in aggregated_list for _ in range(e[1])] if aggregated_list else [] - ) + return [e[0] for e in aggregated_list for _ in range(e[1])] if aggregated_list else [] def coalesce(*args): diff --git a/charging_stations_pipelines/stations_data_export.py b/charging_stations_pipelines/stations_data_export.py index 612a0b34..43ce5191 100644 --- a/charging_stations_pipelines/stations_data_export.py +++ b/charging_stations_pipelines/stations_data_export.py @@ -32,11 +32,7 @@ def 
stations_data_export( ): """Exports stations data to a file.""" logger.info(f"Exporting stations data for country {country_code}") - country_filter = ( - f"country_code='{country_code}' AND " - if country_code != "" and not export_all_countries - else "" - ) + country_filter = f"country_code='{country_code}' AND " if country_code != "" and not export_all_countries else "" merged_filter = "s.is_merged" if export_merged else "NOT s.is_merged" export_area_filter = ( ( @@ -77,9 +73,7 @@ def stations_data_export( """ logger.debug(f"Running postgis query {get_stations_list_sql}") - gdf: gpd.GeoDataFrame = gpd.read_postgis( - get_stations_list_sql, con=db_connection, geom_col="point" - ) + gdf: gpd.GeoDataFrame = gpd.read_postgis(get_stations_list_sql, con=db_connection, geom_col="point") logger.debug(f"Found stations of shape: {gdf.shape}") if len(gdf) == 0: @@ -87,12 +81,8 @@ def stations_data_export( else: if export_to_csv: suffix = "csv" - gdf["latitude"] = gdf["point"].apply( - lambda point: point.y if point else None - ) - gdf["longitude"] = gdf["point"].apply( - lambda point: point.x if point else None - ) + gdf["latitude"] = gdf["point"].apply(lambda point: point.y if point else None) + gdf["longitude"] = gdf["point"].apply(lambda point: point.x if point else None) export_data = gdf.to_csv() else: suffix = "geo.json" @@ -101,9 +91,7 @@ def stations_data_export( logger.debug(f"Data sample: {gdf.sample(5)}") file_country = "europe" if export_all_countries else country_code - file_description = get_file_description( - file_descriptor, file_country, export_area - ) + file_description = get_file_description(file_descriptor, file_country, export_area) file_suffix_merged = "merged" if export_merged else "w_duplicates" file_suffix_charging = "_w_charging" if export_charging_attributes else "" @@ -114,16 +102,12 @@ def stations_data_export( logger.info(f"Done writing, file size: {outfile.tell()}") -def get_file_description( - file_descriptor: str, file_country: str, export_circle: ExportArea -): +def get_file_description(file_descriptor: str, file_country: str, export_circle: ExportArea): """Returns a file description based on the given parameters.""" is_export_circle_specified = export_circle is not None if file_descriptor == "": if is_export_circle_specified: - return ( - f"{export_circle.lon}_{export_circle.lat}_{export_circle.radius_meters}" - ) + return f"{export_circle.lon}_{export_circle.lat}_{export_circle.radius_meters}" else: return file_country else: diff --git a/main.py b/main.py index 440ac7d9..87e1e405 100644 --- a/main.py +++ b/main.py @@ -175,9 +175,7 @@ def run_merge(countries: list[str], delete_data: bool): def run_export(cli_args): """This method runs the export process based on the provided command line arguments.""" - args_file_descriptor = ( - cli_args.export_file_descriptor if cli_args.export_file_descriptor else "" - ) + args_file_descriptor = cli_args.export_file_descriptor if cli_args.export_file_descriptor else "" args_export_area = ( ExportArea( @@ -229,19 +227,14 @@ def main(): setup_logging(command_line_args.verbose) tasks = { - "import": lambda args: run_import( - args.countries, not args.offline, args.delete_data - ), + "import": lambda args: run_import(args.countries, not args.offline, args.delete_data), "merge": lambda args: run_merge(args.countries, args.delete_data), "testdata": lambda args: testdata.run(), "export": run_export, } logger.info("Starting eCharm...") - logger.info( - f"Specified {len(command_line_args.tasks)} tasks: " - + ", 
".join(command_line_args.tasks) - ) + logger.info(f"Specified {len(command_line_args.tasks)} tasks: " + ", ".join(command_line_args.tasks)) for task_name in command_line_args.tasks: logger.info(f"Running task: {task_name}") task_function = tasks.get(task_name) diff --git a/ruff.toml b/ruff.toml index 30dfdbcf..f43b4ee3 100644 --- a/ruff.toml +++ b/ruff.toml @@ -28,8 +28,7 @@ exclude = [ "venv", ] -# Same as Black. -line-length = 88 +line-length = 120 indent-width = 4 # Assume Python 3.8 diff --git a/test/pipelines/at/test_econtrol_crawler.py b/test/pipelines/at/test_econtrol_crawler.py index 50e447bd..1b9de995 100644 --- a/test/pipelines/at/test_econtrol_crawler.py +++ b/test/pipelines/at/test_econtrol_crawler.py @@ -24,16 +24,12 @@ def test_paginated_stations(mock_get): "totalResults": 13, "fromIndex": i * 3, "endIndex": min((i + 1) * 3 - 1, 13 - 1), - "stations": [ - {f"id{j}": f"station{j}"} for j in range(i * 3 + 1, (i + 1) * 3) - ], + "stations": [{f"id{j}": f"station{j}"} for j in range(i * 3 + 1, (i + 1) * 3)], } for i in range(13 // 3 + 13 % 3) ] - page_generator = econtrol_crawler._get_paginated_stations( - station_api_url, headers={} - ) + page_generator = econtrol_crawler._get_paginated_stations(station_api_url, headers={}) for idx, station_page in enumerate(station_pages): mock_get.return_value.json.return_value = station_page @@ -63,9 +59,7 @@ def test_get_paginated_stations_key_error(): } with mock.patch("requests.Session") as mock_session: - mock_session.return_value.get.return_value.json.return_value = ( - mock_response_content - ) + mock_session.return_value.get.return_value.json.return_value = mock_response_content # Expect a KeyError due to missing 'endIndex' key with pytest.raises(KeyError, match=r"endIndex"): @@ -73,9 +67,7 @@ def test_get_paginated_stations_key_error(): @mock.patch("builtins.open", new_callable=mock.mock_open) -@mock.patch( - "charging_stations_pipelines.pipelines.at.econtrol_crawler._get_paginated_stations" -) +@mock.patch("charging_stations_pipelines.pipelines.at.econtrol_crawler._get_paginated_stations") @mock.patch("os.getenv") @mock.patch("os.path.getsize") def test_get_data( @@ -104,10 +96,7 @@ def test_get_data( "totalResults": 100, "fromIndex": i * 10, "endIndex": min((i + 1) * 10 - 1, 100 - 1), - "stations": [ - {f"id{j}": f"station{j}"} - for j in range(i * 10 + 1, (i + 1) * 10 + 1) - ], + "stations": [{f"id{j}": f"station{j}"} for j in range(i * 10 + 1, (i + 1) * 10 + 1)], } for i in range(100 // 10 + 100 % 10) ] @@ -138,9 +127,7 @@ def test_get_data( @mock.patch("builtins.open", new_callable=mock.mock_open) -@mock.patch( - "charging_stations_pipelines.pipelines.at.econtrol_crawler._get_paginated_stations" -) +@mock.patch("charging_stations_pipelines.pipelines.at.econtrol_crawler._get_paginated_stations") @mock.patch("os.path.getsize") def test_get_data_empty_response( mock_getsize, diff --git a/test/pipelines/at/test_econtrol_mapper.py b/test/pipelines/at/test_econtrol_mapper.py index 74844ab0..3d01f111 100644 --- a/test/pipelines/at/test_econtrol_mapper.py +++ b/test/pipelines/at/test_econtrol_mapper.py @@ -170,9 +170,7 @@ def test_map_charging__kw_list(): ] for raw, expected in sample_data: - raw_datapoint = pd.Series( - {"points": [{"energyInKw": raw[0], "connectorTypes": raw[1]}]} - ) + raw_datapoint = pd.Series({"points": [{"energyInKw": raw[0], "connectorTypes": raw[1]}]}) exp_kw_list, exp_max_kw, exp_total_kw = expected diff --git a/test/pipelines/de/test_bna_crawler.py b/test/pipelines/de/test_bna_crawler.py index 
2dcf447f..98fdf296 100644 --- a/test/pipelines/de/test_bna_crawler.py +++ b/test/pipelines/de/test_bna_crawler.py @@ -30,9 +30,7 @@ def test_get_bna_data_downloads_file_with_correct_url( mock_requests_get.return_value = mock_response # Mock the BeautifulSoup find_all method - mock_beautiful_soup.return_value.find_all.return_value = [ - {"href": "https://some_ladesaeulenregister_url.xlsx"} - ] + mock_beautiful_soup.return_value.find_all.return_value = [{"href": "https://some_ladesaeulenregister_url.xlsx"}] # Mock the os.path.getsize method mock_getsize.return_value = 4321 @@ -53,16 +51,12 @@ def test_get_bna_data_downloads_file_with_correct_url( ) # Assert that the os.path.getsize method was called with the correct parameters - mock_getsize.assert_called_once_with( - "./tmp_data_path/some_ladesaeulenregister_url.xlsx" - ) + mock_getsize.assert_called_once_with("./tmp_data_path/some_ladesaeulenregister_url.xlsx") @patch.object(requests, "get") @patch.object(charging_stations_pipelines.pipelines.de.bna_crawler, "BeautifulSoup") -def test_get_bna_data_logs_error_when_no_download_link_found( - mock_beautiful_soup, mock_requests_get, caplog -): +def test_get_bna_data_logs_error_when_no_download_link_found(mock_beautiful_soup, mock_requests_get, caplog): # Mock the requests.get response mock_requests_get.return_value = Mock(content=b"some content", status_code=200) @@ -125,6 +119,4 @@ def test_get_bna_data_logs_file_size_after_download( ) # Assert that os.path.getsize was called correctly - mock_getsize.assert_called_once_with( - "tmp_data_path/some_url1_with_search_term.xlsx" - ) + mock_getsize.assert_called_once_with("tmp_data_path/some_url1_with_search_term.xlsx") diff --git a/test/pipelines/de/test_bna_mapper.py b/test/pipelines/de/test_bna_mapper.py index 45abebc9..b21abec3 100644 --- a/test/pipelines/de/test_bna_mapper.py +++ b/test/pipelines/de/test_bna_mapper.py @@ -34,16 +34,10 @@ def test_map_station_bna(): assert station.data_source == DATA_SOURCE_KEY assert ( station.source_id - == hashlib.sha256( - ( - data_row["Breitengrad"] + data_row["Längengrad"] + station.data_source - ).encode() - ).hexdigest() + == hashlib.sha256((data_row["Breitengrad"] + data_row["Längengrad"] + station.data_source).encode()).hexdigest() ) assert station.operator == data_row["Betreiber"] - assert station.point == from_shape( - Point(float(data_row["Längengrad"]), float(data_row["Breitengrad"])) - ) + assert station.point == from_shape(Point(float(data_row["Längengrad"]), float(data_row["Breitengrad"]))) assert station.date_created == data_row["Inbetriebnahmedatum"].strftime("%Y-%m-%d") @@ -98,9 +92,7 @@ def test_map_address_bna(): def test_map_charging_bna(): # Pandas Series to simulate the row - row = pd.Series( - {"Nennleistung Ladeeinrichtung [kW]": "5.0", "Anzahl Ladepunkte": 2} - ) + row = pd.Series({"Nennleistung Ladeeinrichtung [kW]": "5.0", "Anzahl Ladepunkte": 2}) charging = de_mapper.map_charging_bna(row, 1) diff --git a/test/shared.py b/test/shared.py index 296e9f4a..079f6d96 100644 --- a/test/shared.py +++ b/test/shared.py @@ -102,9 +102,7 @@ def __init__(self): self.handler = LogCaptureHandler() @contextmanager - def __call__( - self, level: int, logger: logging.Logger - ) -> Generator[None, None, None]: + def __call__(self, level: int, logger: logging.Logger) -> Generator[None, None, None]: """Context manager that sets the level for capturing of logs.""" orig_level = logger.level diff --git a/test/test_shared.py b/test/test_shared.py index 631f3fea..416f8bff 100644 --- a/test/test_shared.py +++ 
b/test/test_shared.py @@ -137,15 +137,9 @@ def test_lst_expand(): assert lst_expand([]) == [] assert lst_expand([(2.5, 1)]) == [2.5] assert lst_expand([(1.0, 3), (2.5, 2)]) == [1.0, 1.0, 1.0, 2.5, 2.5] - assert ( - lst_expand([(1.0, 1500), (3.9, 20), (2.0, 5)]) - == [1.0] * 1500 + [3.9] * 20 + [2.0] * 5 - ) - - assert ( - lst_expand([(6.75, 2), (4.0, 100), (1.5, 35)]) - == [6.75] * 2 + [4.0] * 100 + [1.5] * 35 - ) + assert lst_expand([(1.0, 1500), (3.9, 20), (2.0, 5)]) == [1.0] * 1500 + [3.9] * 20 + [2.0] * 5 + + assert lst_expand([(6.75, 2), (4.0, 100), (1.5, 35)]) == [6.75] * 2 + [4.0] * 100 + [1.5] * 35 assert lst_expand([(1.0, 3), (2.5, 2)]) == [1.0, 1.0, 1.0, 2.5, 2.5] diff --git a/testdata_import.py b/testdata_import.py index 830ddaa1..598b340c 100644 --- a/testdata_import.py +++ b/testdata_import.py @@ -36,9 +36,7 @@ def main() -> list[Any]: if creds and creds.expired and creds.refresh_token: creds.refresh(Request()) else: - flow = InstalledAppFlow.from_client_secrets_file( - os.path.join(directory, "credentials.json"), SCOPES - ) + flow = InstalledAppFlow.from_client_secrets_file(os.path.join(directory, "credentials.json"), SCOPES) creds = flow.run_local_server(port=8083) # Save the credentials for the next run with open(token_filename, "w") as token: @@ -49,9 +47,7 @@ def main() -> list[Any]: # Call the Sheets API sheet = service.spreadsheets() - result = ( - sheet.values().get(spreadsheetId=SPREADSHEET_ID, range="A1:Z100").execute() - ) + result = sheet.values().get(spreadsheetId=SPREADSHEET_ID, range="A1:Z100").execute() values = result.get("values", []) return values diff --git a/testing/testdata.py b/testing/testdata.py index 5bc8245f..e9463a2a 100644 --- a/testing/testdata.py +++ b/testing/testdata.py @@ -64,9 +64,7 @@ def run(): current_dir = os.path.join(pathlib.Path(__file__).parent.resolve()) config.read(os.path.join(os.path.join(current_dir, "config", "config.ini"))) - merger: StationMerger = StationMerger( - "DE", config=config, db_engine=create_engine(db_uri, echo=True) - ) + merger: StationMerger = StationMerger("DE", config=config, db_engine=create_engine(db_uri, echo=True)) # print(test_data) with open("testdata_merge.csv", "w") as outfile: diff --git a/tests/integration/test_int_de_bna.py b/tests/integration/test_int_de_bna.py index 004807c2..017bb6b1 100644 --- a/tests/integration/test_int_de_bna.py +++ b/tests/integration/test_int_de_bna.py @@ -68,9 +68,7 @@ def test_file_size(bna_data): def test_dataframe_schema(bna_data): _, bna_in_data = bna_data # Check schema of the downloaded Excel file - assert verify_schema_follows( - bna_in_data, EXPECTED_DATA_SCHEMA - ), "Mismatch in schema of the downloaded Excel file!" + assert verify_schema_follows(bna_in_data, EXPECTED_DATA_SCHEMA), "Mismatch in schema of the downloaded Excel file!" @pytest.mark.integration_test diff --git a/tests/integration/test_int_merger.py b/tests/integration/test_int_merger.py index db0fe926..74d70bd1 100644 --- a/tests/integration/test_int_merger.py +++ b/tests/integration/test_int_merger.py @@ -63,9 +63,7 @@ def _run_merger(engine): # Suppressing Pandas warning (1/2): "A value is trying to be set on a copy of a slice from a DataFrame." 
pd.options.mode.chained_assignment = None # default: 'warn' - station_merger = StationMerger( - country_code="DE", config=(get_config()), db_engine=engine - ) + station_merger = StationMerger(country_code="DE", config=(get_config()), db_engine=engine) station_merger.run() # Suppressing Pandas warning (2/2): restoring default value @@ -221,9 +219,7 @@ def _create_stations(): # Suppressing Pandas warning (1/2): "A value is trying to be set on a copy of a slice from a DataFrame." pd.options.mode.chained_assignment = None # default: 'warn' - station_merger = StationMerger( - country_code="AT", config=(get_config()), db_engine=engine - ) + station_merger = StationMerger(country_code="AT", config=(get_config()), db_engine=engine) station_merger.run() # Suppressing Pandas warning (2/2): restoring default value From 80d4ec9d79b19f3add7be6bd28346e08895ee641 Mon Sep 17 00:00:00 2001 From: Martin Mader Date: Fri, 26 Jan 2024 14:24:01 +0100 Subject: [PATCH 4/5] fix FRGOV data retrieval and add integration tests --- .../pipelines/fr/france.py | 19 ++++++---- .../pipelines/fr/france_mapper.py | 4 -- tests/integration/test_int_de_bna.py | 2 +- tests/integration/test_int_fr_france.py | 38 ++++++++++++++++--- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/charging_stations_pipelines/pipelines/fr/france.py b/charging_stations_pipelines/pipelines/fr/france.py index 205908aa..f2fd8fee 100644 --- a/charging_stations_pipelines/pipelines/fr/france.py +++ b/charging_stations_pipelines/pipelines/fr/france.py @@ -31,12 +31,7 @@ def _retrieve_data(self): if self.online: logger.info("Retrieving Online Data") self.download_france_gov_file(tmp_data_path) - self.data = pd.read_csv( - os.path.join(data_dir, "france_stations.csv"), - delimiter=",", - encoding="utf-8", - encoding_errors="replace", - ) + self.data = self.load_csv_file(tmp_data_path) def run(self): logger.info("Running FR GOV Pipeline...") @@ -55,7 +50,7 @@ def run(self): @staticmethod def download_france_gov_file(target_file): """Download a file from the French government website.""" - base_url = "https://transport.data.gouv.fr/resources/79624" + base_url = "https://transport.data.gouv.fr/resources/81548" r = requests.get(base_url, headers={"User-Agent": "Mozilla/5.0"}) soup = BeautifulSoup(r.content, "html.parser") @@ -73,3 +68,13 @@ def download_france_gov_file(target_file): "Could not determine source for french government data", ) download_file(link_to_dataset[0]["href"], target_file) + + @staticmethod + def load_csv_file(target_file): + return pd.read_csv( + target_file, + delimiter=",", + encoding="utf-8", + encoding_errors="replace", + low_memory=False, + ) diff --git a/charging_stations_pipelines/pipelines/fr/france_mapper.py b/charging_stations_pipelines/pipelines/fr/france_mapper.py index 453a44c5..0d421c0f 100644 --- a/charging_stations_pipelines/pipelines/fr/france_mapper.py +++ b/charging_stations_pipelines/pipelines/fr/france_mapper.py @@ -41,15 +41,11 @@ def map_station_fra(row: pd.Series) -> Station: float(check_coordinates(row.get("consolidated_latitude"))), ) ) - station.date_created = row.get("date_mise_en_service").strptime("%Y-%m-%d") - station.date_updated = row.get("date_maj").strptime("%Y-%m-%d") if not pd.isna(row.get("date_mise_en_service")): station.date_created = datetime.strptime(row.get("date_mise_en_service"), "%Y-%m-%d") if not pd.isna(row.get("date_maj")): station.date_updated = datetime.strptime(row.get("date_maj"), "%Y-%m-%d") - else: - station.date_updated = datetime.now return station diff --git 
a/tests/integration/test_int_de_bna.py b/tests/integration/test_int_de_bna.py index 017bb6b1..6aee7e4c 100644 --- a/tests/integration/test_int_de_bna.py +++ b/tests/integration/test_int_de_bna.py @@ -61,7 +61,7 @@ def bna_data(): def test_file_size(bna_data): bna_file_name, _ = bna_data # Check file size of the downloaded file - assert os.path.getsize(bna_file_name) >= 8_602_458 # ~ 9 MB + assert os.path.getsize(bna_file_name) >= 1_000 # actual file is ~ 9 MB, just make sure it is not quasi empty here @pytest.mark.integration_test diff --git a/tests/integration/test_int_fr_france.py b/tests/integration/test_int_fr_france.py index b8338e75..ce53e7b4 100644 --- a/tests/integration/test_int_fr_france.py +++ b/tests/integration/test_int_fr_france.py @@ -2,17 +2,45 @@ import os import tempfile - import pytest from charging_stations_pipelines.pipelines.fr.france import FraPipeline from test.shared import skip_if_github +from tests.test_utils import verify_schema_follows + +EXPECTED_DATA_SCHEMA = { + "id_station_itinerance": "object", + "nom_operateur": "object", + "consolidated_longitude": "float64", + "consolidated_latitude": "float64", + "date_mise_en_service": "object", + "date_maj": "object", + "nbre_pdc": "int64", + "adresse_station": "object", + "consolidated_commune": "object", + "consolidated_code_postal": "float64", +} + + +@pytest.fixture(scope="module") +def fr_data(): + """Setup method for tests. Executes once at the beginning of the test session (and not before each test).""" + # Download to a temporary file + with tempfile.NamedTemporaryFile() as temp_file: + FraPipeline.download_france_gov_file(temp_file.name) + fr_dataframe = FraPipeline.load_csv_file(temp_file.name) + yield temp_file.name, fr_dataframe @pytest.mark.integration_test @pytest.mark.skipif(skip_if_github(), reason="Skip the test when running on Github") -def test_download_france_gov_file(): +def test_download_france_gov_file(fr_data): """Test the download function.""" - with tempfile.NamedTemporaryFile() as temp_file: - FraPipeline.download_france_gov_file(temp_file.name) - assert os.path.getsize(temp_file.name) >= 47_498_370 # ~ 50 MB + fr_filename, _ = fr_data + assert os.path.getsize(fr_filename) >= 1_000 # actual file is ~ 45 MB, just make sure it is not quasi empty here + + +@pytest.mark.integration_test +def test_dataframe_schema(fr_data): + _, fr_dataframe = fr_data + assert verify_schema_follows(fr_dataframe, EXPECTED_DATA_SCHEMA), "Mismatch in schema of the downloaded csv file!" 
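Both new integration test modules import a `verify_schema_follows` helper from `tests.test_utils`, which is not part of these patches. Purely for context, here is a minimal sketch of what such a dtype check could look like; this is an assumption about the helper, not the project's actual implementation:

```python
# Hypothetical sketch of the verify_schema_follows helper referenced by the
# new integration tests; the real version lives in tests/test_utils.py and
# is not shown in these patches.
import pandas as pd


def verify_schema_follows(df: pd.DataFrame, expected_schema: dict[str, str]) -> bool:
    """Return True if every expected column is present in df with the expected dtype."""
    for column, expected_dtype in expected_schema.items():
        if column not in df.columns:
            return False
        if str(df[column].dtype) != expected_dtype:
            return False
    return True
```

Under this reading, extra columns in the downloaded file are allowed; the check only fails when an expected column is missing or has drifted to a different dtype, which matches how the tests use it against `EXPECTED_DATA_SCHEMA`.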
From 8c88b5b45e90b6d8dfc441af44c303ffabc0b635 Mon Sep 17 00:00:00 2001 From: Martin Mader Date: Fri, 26 Jan 2024 15:36:50 +0100 Subject: [PATCH 5/5] mark BNA and FR data retrieval tests with "check_datasource" and adapt github workflows --- .github/workflows/check-datasources.yml | 13 ++++++------- .github/workflows/python-app.yml | 2 +- pyproject.toml | 3 ++- tests/integration/test_int_de_bna.py | 13 ++----------- tests/integration/test_int_fr_france.py | 6 +++--- 5 files changed, 14 insertions(+), 23 deletions(-) diff --git a/.github/workflows/check-datasources.yml b/.github/workflows/check-datasources.yml index fe13f29b..6cbf5b83 100644 --- a/.github/workflows/check-datasources.yml +++ b/.github/workflows/check-datasources.yml @@ -24,13 +24,12 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install -r test/requirements.txt - # - name: Run integration tests (only) - # run: | - # pip install -r test/requirements.txt - # pytest -m "integration_test" - - - name: "[DE/BNA] Real data validity checks" + - name: "[DE/BNA] Data retrieval check" run: | - pip install -r test/requirements.txt pytest tests/integration/test_int_de_bna.py + + - name: "[FR] Data retrieval check" + run: | + pytest tests/integration/test_int_fr_france.py diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index fa3a9b05..8c340289 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -45,4 +45,4 @@ jobs: - name: Run tests run: | pip install -r test/requirements.txt - pytest + pytest -m "not check_datasource" diff --git a/pyproject.toml b/pyproject.toml index 5d4dfb0d..2d058bfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,5 +27,6 @@ testpaths = [ # Declare custom markers markers = [ - "integration_test: marks tests as integration tests, which are bit slow (deselect with '-m \"integration_test\"')", + "integration_test: marks tests as integration tests (deselect with '-m \"not integration_test\"')", + "check_datasource: marks tests as datasource check for scheduled github action (deselect with '-m \"not check_datasource\"')", ] diff --git a/tests/integration/test_int_de_bna.py b/tests/integration/test_int_de_bna.py index 6aee7e4c..fee3a832 100644 --- a/tests/integration/test_int_de_bna.py +++ b/tests/integration/test_int_de_bna.py @@ -58,6 +58,7 @@ def bna_data(): @pytest.mark.integration_test +@pytest.mark.check_datasource def test_file_size(bna_data): bna_file_name, _ = bna_data # Check file size of the downloaded file @@ -65,18 +66,8 @@ def test_file_size(bna_data): @pytest.mark.integration_test +@pytest.mark.check_datasource def test_dataframe_schema(bna_data): _, bna_in_data = bna_data # Check schema of the downloaded Excel file assert verify_schema_follows(bna_in_data, EXPECTED_DATA_SCHEMA), "Mismatch in schema of the downloaded Excel file!" - - -@pytest.mark.integration_test -def test_dataframe_shape(bna_data): - _, bna_in_data = bna_data - # Check shape of the dataframe - # Not exact check, because file grows over time - # Expected: at least 54,223 rows and 23 columns - num_rows, num_cols = bna_in_data.shape - assert num_rows >= 54_223, "Mismatch in dataframe shape: too few rows!" - assert num_cols >= 23, "Mismatch in dataframe shape: too few columns!" 
diff --git a/tests/integration/test_int_fr_france.py b/tests/integration/test_int_fr_france.py index ce53e7b4..e7eea471 100644 --- a/tests/integration/test_int_fr_france.py +++ b/tests/integration/test_int_fr_france.py @@ -5,7 +5,6 @@ import pytest from charging_stations_pipelines.pipelines.fr.france import FraPipeline -from test.shared import skip_if_github from tests.test_utils import verify_schema_follows EXPECTED_DATA_SCHEMA = { @@ -25,7 +24,7 @@ @pytest.fixture(scope="module") def fr_data(): """Setup method for tests. Executes once at the beginning of the test session (and not before each test).""" - # Download to a temporary file + # Download real FR GOV data to a temporary file with tempfile.NamedTemporaryFile() as temp_file: FraPipeline.download_france_gov_file(temp_file.name) fr_dataframe = FraPipeline.load_csv_file(temp_file.name) @@ -33,7 +32,7 @@ def fr_data(): @pytest.mark.integration_test -@pytest.mark.skipif(skip_if_github(), reason="Skip the test when running on Github") +@pytest.mark.check_datasource def test_download_france_gov_file(fr_data): """Test the download function.""" fr_filename, _ = fr_data @@ -41,6 +40,7 @@ def test_download_france_gov_file(fr_data): @pytest.mark.integration_test +@pytest.mark.check_datasource def test_dataframe_schema(fr_data): _, fr_dataframe = fr_data assert verify_schema_follows(fr_dataframe, EXPECTED_DATA_SCHEMA), "Mismatch in schema of the downloaded csv file!"
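Taken together, the last patch splits the suite along two custom markers: the regular CI workflow deselects `check_datasource`, while the scheduled check-datasources workflow runs the two datasource modules directly. A small standalone illustration of how that marker-based selection behaves; the test names and bodies below are invented for the example and are not part of the patches:

```python
# Minimal illustration of combining the integration_test and check_datasource
# markers with pytest's -m selection (both markers registered in pyproject.toml).
import pytest


@pytest.mark.integration_test
@pytest.mark.check_datasource
def test_external_datasource_reachable():
    # Stands in for the BNA/FR download checks: talks to a real data source,
    # so it should only run in the scheduled check-datasources workflow.
    assert True


@pytest.mark.integration_test
def test_local_integration_only():
    # Integration test without the datasource marker: still runs in the
    # regular CI job, which deselects only check_datasource.
    assert True


# Selection behaviour:
#   pytest -m "not check_datasource"  -> runs test_local_integration_only only
#   pytest -m "check_datasource"      -> runs test_external_datasource_reachable only
#   pytest                            -> runs both
```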