diff --git a/src/viadot/orchestration/prefect/flows/__init__.py b/src/viadot/orchestration/prefect/flows/__init__.py
index acc981caf..a2d489ab9 100644
--- a/src/viadot/orchestration/prefect/flows/__init__.py
+++ b/src/viadot/orchestration/prefect/flows/__init__.py
@@ -30,6 +30,7 @@ from .supermetrics_to_adls import supermetrics_to_adls
 from .transform import transform
 from .transform_and_catalog import transform_and_catalog
+from .vid_club_to_adls import vid_club_to_adls


 __all__ = [
@@ -63,4 +64,5 @@
     "supermetrics_to_adls",
     "transform",
     "transform_and_catalog",
+    "vid_club_to_adls",
 ]
diff --git a/src/viadot/orchestration/prefect/flows/vid_club_to_adls.py b/src/viadot/orchestration/prefect/flows/vid_club_to_adls.py
new file mode 100644
index 000000000..8ba05ab25
--- /dev/null
+++ b/src/viadot/orchestration/prefect/flows/vid_club_to_adls.py
@@ -0,0 +1,98 @@
+"""Download data from the Vid Club API and load it into Azure Data Lake Storage."""
+
+from typing import Any, Literal
+
+from prefect import flow
+from prefect.task_runners import ConcurrentTaskRunner
+
+from viadot.orchestration.prefect.tasks import df_to_adls, vid_club_to_df
+
+
+@flow(
+    name="Vid Club extraction to ADLS",
+    description="Extract data from Vid Club and load it into Azure Data Lake Storage.",
+    retries=1,
+    retry_delay_seconds=60,
+    task_runner=ConcurrentTaskRunner,
+)
+def vid_club_to_adls(  # noqa: PLR0913
+    *args: list[Any],
+    endpoint: Literal["jobs", "product", "company", "survey"] | None = None,
+    from_date: str = "2022-03-22",
+    to_date: str | None = None,
+    items_per_page: int = 100,
+    region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None,
+    days_interval: int = 30,
+    cols_to_drop: list[str] | None = None,
+    config_key: str | None = None,
+    azure_key_vault_secret: str | None = None,
+    adls_config_key: str | None = None,
+    adls_azure_key_vault_secret: str | None = None,
+    adls_path: str | None = None,
+    adls_path_overwrite: bool = False,
+    validate_df_dict: dict | None = None,
+    timeout: int = 3600,
+    **kwargs: dict[str, Any],
+) -> None:
+    """Download data from the Vid Club API and upload it to Azure Data Lake.
+
+    Args:
+        endpoint (Literal["jobs", "product", "company", "survey"], optional): The
+            endpoint to be accessed. Defaults to None.
+        from_date (str, optional): Start date for the query. Defaults to
+            "2022-03-22", the oldest date available in the data.
+        to_date (str, optional): End date for the query. Defaults to None, which
+            is resolved to datetime.today().strftime("%Y-%m-%d") at runtime.
+        items_per_page (int, optional): Number of entries per page. Defaults to 100.
+        region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional): Region
+            filter for the query. Defaults to None (the parameter is then omitted
+            from the URL). [December 2023 status: the value 'all' does not work
+            for the company and jobs endpoints.]
+        days_interval (int, optional): Length of the date range used per API call
+            (tests showed that 30-40 days is optimal for performance). Defaults
+            to 30.
+        cols_to_drop (list[str], optional): List of columns to drop. Defaults to
+            None.
+        config_key (str, optional): The key in the viadot config holding relevant
+            Vid Club credentials. Defaults to None.
+        azure_key_vault_secret (str, optional): The name of the Azure Key Vault
+            secret where Vid Club credentials are stored. Defaults to None.
+        adls_config_key (str, optional): The key in the viadot config holding
+            relevant ADLS credentials. Defaults to None.
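For orientation, here is how the new flow might be invoked once this lands — a minimal sketch, assuming the referenced Key Vault secrets exist (the secret names and ADLS path below are placeholders, not values from this change):

```python
from viadot.orchestration.prefect.flows import vid_club_to_adls

# Pull the "jobs" endpoint for January 2023 and land it in ADLS as Parquet.
vid_club_to_adls(
    endpoint="jobs",
    from_date="2023-01-01",
    to_date="2023-01-31",
    region="pl",
    azure_key_vault_secret="vidclub-dev",  # placeholder: secret holding the Vid Club token
    adls_azure_key_vault_secret="adls-dev",  # placeholder: secret holding ADLS credentials
    adls_path="raw/vid_club/jobs.parquet",  # placeholder destination path
    adls_path_overwrite=True,
)
```
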
+        adls_azure_key_vault_secret (str, optional): The name of the Azure Key
+            Vault secret containing a dictionary with ACCOUNT_NAME and Service
+            Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the
+            Azure Data Lake. Defaults to None.
+        adls_path (str, optional): Azure Data Lake destination file path.
+            Defaults to None.
+        adls_path_overwrite (bool, optional): Whether to overwrite the file in
+            ADLS. Defaults to False.
+        validate_df_dict (dict, optional): A dictionary with an optional list of
+            tests to verify the output dataframe. If defined, triggers the
+            `validate_df` task from task_utils. Defaults to None.
+        timeout (int, optional): The time (in seconds) to wait while running this
+            task before a timeout occurs. Defaults to 3600.
+    """
+    data_frame = vid_club_to_df(
+        *args,
+        endpoint=endpoint,
+        from_date=from_date,
+        to_date=to_date,
+        items_per_page=items_per_page,
+        region=region,
+        days_interval=days_interval,
+        cols_to_drop=cols_to_drop,
+        config_key=config_key,
+        azure_key_vault_secret=azure_key_vault_secret,
+        validate_df_dict=validate_df_dict,
+        timeout=timeout,
+        **kwargs,
+    )
+
+    return df_to_adls(
+        df=data_frame,
+        path=adls_path,
+        credentials_secret=adls_azure_key_vault_secret,
+        config_key=adls_config_key,
+        overwrite=adls_path_overwrite,
+    )
diff --git a/src/viadot/orchestration/prefect/tasks/__init__.py b/src/viadot/orchestration/prefect/tasks/__init__.py
index 8cdc4b52a..db43224a7 100644
--- a/src/viadot/orchestration/prefect/tasks/__init__.py
+++ b/src/viadot/orchestration/prefect/tasks/__init__.py
@@ -27,6 +27,7 @@ from .sharepoint import sharepoint_download_file, sharepoint_to_df
 from .sql_server import create_sql_server_table, sql_server_query, sql_server_to_df
 from .supermetrics import supermetrics_to_df
+from .vid_club import vid_club_to_df


 __all__ = [
@@ -62,5 +63,6 @@
     "sharepoint_to_df",
     "sql_server_query",
     "sql_server_to_df",
     "supermetrics_to_df",
+    "vid_club_to_df",
 ]
diff --git a/src/viadot/orchestration/prefect/tasks/vid_club.py b/src/viadot/orchestration/prefect/tasks/vid_club.py
new file mode 100644
index 000000000..b74380c7f
--- /dev/null
+++ b/src/viadot/orchestration/prefect/tasks/vid_club.py
@@ -0,0 +1,79 @@
+"""Task for downloading data from the Vid Club Cloud API."""
+
+from typing import Any, Literal
+
+import pandas as pd
+from prefect import task
+
+from viadot.orchestration.prefect.exceptions import MissingSourceCredentialsError
+from viadot.orchestration.prefect.utils import get_credentials
+from viadot.sources import VidClub
+
+
+@task(retries=3, log_prints=True, retry_delay_seconds=10, timeout_seconds=2 * 60 * 60)
+def vid_club_to_df(  # noqa: PLR0913
+    *args: list[Any],
+    endpoint: Literal["jobs", "product", "company", "survey"] | None = None,
+    from_date: str = "2022-03-22",
+    to_date: str | None = None,
+    items_per_page: int = 100,
+    region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None,
+    days_interval: int = 30,
+    cols_to_drop: list[str] | None = None,
+    config_key: str | None = None,
+    azure_key_vault_secret: str | None = None,
+    validate_df_dict: dict | None = None,
+    timeout: int = 3600,
+    **kwargs: dict[str, Any],
+) -> pd.DataFrame:
+    """Download data from the Vid Club API into a pandas DataFrame.
+
+    Args:
+        endpoint (Literal["jobs", "product", "company", "survey"], optional):
+            The endpoint to be accessed. Defaults to None.
+        from_date (str, optional): Start date for the query. Defaults to
+            "2022-03-22", the oldest date available in the data.
+        to_date (str, optional): End date for the query. Defaults to None, which
+            is resolved to datetime.today().strftime("%Y-%m-%d") at runtime.
+        items_per_page (int, optional): Number of entries per page. Defaults to 100.
+        region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional): Region
+            filter for the query. Defaults to None (the parameter is then omitted
+            from the URL). [December 2023 status: the value 'all' does not work
+            for the company and jobs endpoints.]
+        days_interval (int, optional): Length of the date range used per API call
+            (tests showed that 30-40 days is optimal for performance). Defaults
+            to 30.
+        cols_to_drop (list[str], optional): List of columns to drop. Defaults to
+            None.
+        config_key (str, optional): The key in the viadot config holding relevant
+            credentials. Defaults to None.
+        azure_key_vault_secret (str, optional): The name of the Azure Key Vault
+            secret where credentials are stored. Defaults to None.
+        validate_df_dict (dict, optional): A dictionary with an optional list of
+            tests to verify the output dataframe. If defined, triggers the
+            `validate_df` task from task_utils. Defaults to None.
+        timeout (int, optional): The time (in seconds) to wait while running this
+            task before a timeout occurs. Defaults to 3600.
+
+    Returns:
+        pd.DataFrame: The data downloaded from the Vid Club API.
+    """
+    if not (azure_key_vault_secret or config_key):
+        raise MissingSourceCredentialsError
+
+    credentials = get_credentials(azure_key_vault_secret) if not config_key else None
+
+    vc_obj = VidClub(
+        *args,
+        endpoint=endpoint,
+        from_date=from_date,
+        to_date=to_date,
+        items_per_page=items_per_page,
+        region=region,
+        days_interval=days_interval,
+        cols_to_drop=cols_to_drop,
+        vid_club_credentials=credentials,
+        config_key=config_key,
+        validate_df_dict=validate_df_dict,
+        timeout=timeout,
+        **kwargs,
+    )
+
+    return vc_obj.to_df()
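The task can also be exercised outside a flow run via Prefect's `Task.fn`, which calls the undecorated function directly — a minimal sketch, assuming the secret name (a placeholder here) resolves in your environment:

```python
from viadot.orchestration.prefect.tasks import vid_club_to_df

# .fn bypasses the Prefect task runner, so no flow-run context is needed.
df = vid_club_to_df.fn(
    endpoint="company",
    from_date="2023-01-01",
    to_date="2023-03-31",
    region="ro",
    azure_key_vault_secret="vidclub-dev",  # placeholder secret name
)
print(df.head())
```
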
diff --git a/src/viadot/sources/__init__.py b/src/viadot/sources/__init__.py
index 789b8ca6a..08615a9f3 100644
--- a/src/viadot/sources/__init__.py
+++ b/src/viadot/sources/__init__.py
@@ -20,6 +20,7 @@ from .sql_server import SQLServer
 from .supermetrics import Supermetrics, SupermetricsCredentials
 from .uk_carbon_intensity import UKCarbonIntensity
+from .vid_club import VidClub


 __all__ = [
@@ -41,6 +42,7 @@
     "SupermetricsCredentials",  # pragma: allowlist-secret
     "Trino",
     "UKCarbonIntensity",
+    "VidClub",
 ]
 if find_spec("adlfs"):
     from viadot.sources.azure_data_lake import AzureDataLake  # noqa: F401
diff --git a/src/viadot/sources/vid_club.py b/src/viadot/sources/vid_club.py
new file mode 100644
index 000000000..db0adca08
--- /dev/null
+++ b/src/viadot/sources/vid_club.py
@@ -0,0 +1,390 @@
+"""Vid Club Cloud API connector."""
+
+from datetime import datetime, timedelta
+from typing import Any, Literal
+
+import pandas as pd
+
+from viadot.config import get_source_credentials
+from viadot.exceptions import CredentialError, ValidationError
+from viadot.sources.base import Source
+from viadot.utils import handle_api_response
+
+
+class VidClub(Source):
+    """A class implementing the Vid Club API.
+
+    Documentation for this API is located at: https://evps01.envoo.net/vipapi/
+    Data can be retrieved from four endpoints: jobs, product, company and survey.
+    """
+
+    def __init__(
+        self,
+        *args,
+        endpoint: Literal["jobs", "product", "company", "survey"] | None = None,
+        from_date: str = "2022-03-22",
+        to_date: str | None = None,
+        items_per_page: int = 100,
+        region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None,
+        days_interval: int = 30,
+        cols_to_drop: list[str] | None = None,
+        vid_club_credentials: dict[str, Any] | None = None,
+        config_key: str | None = None,
+        validate_df_dict: dict | None = None,
+        timeout: int = 3600,
+        **kwargs,
+    ):
+        """Create an instance of VidClub.
+
+        Args:
+            endpoint (Literal["jobs", "product", "company", "survey"], optional):
+                The endpoint to be accessed. Defaults to None.
+            from_date (str, optional): Start date for the query. Defaults to
+                "2022-03-22", the oldest date available in the data.
+            to_date (str, optional): End date for the query. Defaults to None,
+                which is resolved to datetime.today().strftime("%Y-%m-%d") at
+                runtime.
+            items_per_page (int, optional): Number of entries per page. Defaults
+                to 100.
+            region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional):
+                Region filter for the query. Defaults to None (the parameter is
+                then omitted from the URL). [December 2023 status: the value
+                'all' does not work for the company and jobs endpoints.]
+            days_interval (int, optional): Length of the date range used per API
+                call (tests showed that 30-40 days is optimal for performance).
+                Defaults to 30.
+            cols_to_drop (list[str], optional): List of columns to drop.
+                Defaults to None.
+            vid_club_credentials (dict[str, Any], optional): A dictionary holding
+                the credentials ("token" and the base API "url"). Defaults to
+                None.
+            config_key (str, optional): The key in the viadot config holding
+                relevant credentials; used when `vid_club_credentials` is not
+                provided. Defaults to None.
+            validate_df_dict (dict, optional): A dictionary with an optional list
+                of tests to verify the output dataframe. If defined, triggers
+                the `validate_df` task from task_utils. Defaults to None.
+            timeout (int, optional): The time (in seconds) to wait while running
+                this task before a timeout occurs. Defaults to 3600.
+
+        Raises:
+            CredentialError: If neither `vid_club_credentials` nor `config_key`
+                yields a credentials dictionary with a "token" key.
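+
+        Example:
+            A hypothetical instantiation (credential values are illustrative
+            placeholders)::
+
+                vc = VidClub(
+                    endpoint="jobs",
+                    vid_club_credentials={
+                        "token": "<bearer-token>",
+                        "url": "<vid-club-api-base-url>",
+                    },
+                )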
+ """ + + def __init__( + self, + *args, + endpoint: Literal["jobs", "product", "company", "survey"] | None = None, + from_date: str = "2022-03-22", + to_date: str | None = None, + items_per_page: int = 100, + region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None, + days_interval: int = 30, + cols_to_drop: list[str] | None = None, + vid_club_credentials: dict[str, Any] | None = None, + validate_df_dict: dict | None = None, + timeout: int = 3600, + **kwargs, + ): + """Create an instance of VidClub. + + Args: + endpoint (Literal["jobs", "product", "company", "survey"], optional): The + endpoint source to be accessed. Defaults to None. + from_date (str, optional): Start date for the query, by default is the + oldest date in the data 2022-03-22. + to_date (str, optional): End date for the query. By default None, + which will be executed as datetime.today().strftime("%Y-%m-%d") in code. + items_per_page (int, optional): Number of entries per page. Defaults to 100. + region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional): + Region filter for the query. Defaults to None + (parameter is not used in url). [December 2023 status: value 'all' + does not work for company and jobs] + days_interval (int, optional): Days specified in date range per API call + (test showed that 30-40 is optimal for performance). Defaults to 30. + cols_to_drop (List[str], optional): List of columns to drop. + Defaults to None. + vid_club_credentials (Dict[str, Any], optional): Stores the credentials + information. Defaults to None. + validate_df_dict (dict, optional): A dictionary with optional list of tests + to verify the output + dataframe. If defined, triggers the `validate_df` task from task_utils. + Defaults to None. + timeout (int, optional): The time (in seconds) to wait while running this + task before a timeout occurs. Defaults to 3600. + """ + self.endpoint = endpoint + self.from_date = from_date + self.to_date = to_date + self.items_per_page = items_per_page + self.region = region + self.days_interval = days_interval + self.cols_to_drop = cols_to_drop + self.vid_club_credentials = vid_club_credentials + self.validate_df_dict = validate_df_dict + self.timeout = timeout + + self.headers = { + "Authorization": "Bearer " + vid_club_credentials["token"], + "Content-Type": "application/json", + } + + super().__init__(credentials=vid_club_credentials, *args, **kwargs) # noqa: B026 + + def build_query( + self, + from_date: str, + to_date: str, + api_url: str, + items_per_page: int, + endpoint: Literal["jobs", "product", "company", "survey"] | None = None, + region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None, + ) -> str: + """Builds the query from the inputs. + + Args: + from_date (str): Start date for the query. + to_date (str): End date for the query, if empty, will be executed as + datetime.today().strftime("%Y-%m-%d"). + api_url (str): Generic part of the URL to Vid Club API. + items_per_page (int): number of entries per page. + endpoint (Literal["jobs", "product", "company", "survey"], optional): + The endpoint source to be accessed. Defaults to None. + region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional): + Region filter for the query. Defaults to None + (parameter is not used in url). [December 2023 status: value 'all' + does not work for company and jobs] + + Returns: + str: Final query with all filters added. + + Raises: + ValidationError: If any source different than the ones in the list are used. 
+ """ + if endpoint in ["jobs", "product", "company"]: + region_url_string = f"®ion={region}" if region else "" + url = ( + f"""{api_url}{endpoint}?from={from_date}&to={to_date}""" + f"""{region_url_string}&limit={items_per_page}""" + ) + elif endpoint == "survey": + url = f"{api_url}{endpoint}?language=en&type=question" + else: + msg = "Pick one these sources: jobs, product, company, survey" + raise ValidationError(msg) + return url + + def intervals( + self, from_date: str, to_date: str, days_interval: int + ) -> tuple[list[str], list[str]]: + """Breaks dates range into smaller by provided days interval. + + Args: + from_date (str): Start date for the query in "%Y-%m-%d" format. + to_date (str): End date for the query, if empty, will be executed as + datetime.today().strftime("%Y-%m-%d"). + days_interval (int): Days specified in date range per api call + (test showed that 30-40 is optimal for performance). + + Returns: + List[str], List[str]: Starts and Ends lists that contains information + about date ranges for specific period and time interval. + + Raises: + ValidationError: If the final date of the query is before the start date. + """ + if to_date is None: + to_date = datetime.today().strftime("%Y-%m-%d") + + end_date = datetime.strptime(to_date, "%Y-%m-%d").date() + start_date = datetime.strptime(from_date, "%Y-%m-%d").date() + + from_date_obj = datetime.strptime(from_date, "%Y-%m-%d") + + to_date_obj = datetime.strptime(to_date, "%Y-%m-%d") + delta = to_date_obj - from_date_obj + + if delta.days < 0: + msg = "to_date cannot be earlier than from_date." + raise ValidationError(msg) + + interval = timedelta(days=days_interval) + starts = [] + ends = [] + + period_start = start_date + while period_start < end_date: + period_end = min(period_start + interval, end_date) + starts.append(period_start.strftime("%Y-%m-%d")) + ends.append(period_end.strftime("%Y-%m-%d")) + period_start = period_end + if len(starts) == 0 and len(ends) == 0: + starts.append(from_date) + ends.append(to_date) + return starts, ends + + def check_connection( + self, + endpoint: Literal["jobs", "product", "company", "survey"] | None = None, + from_date: str = "2022-03-22", + to_date: str | None = None, + items_per_page: int = 100, + region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None, + url: str | None = None, + ) -> tuple[dict[str, Any], str]: + """Initiate first connection to API to retrieve piece of data. + + With information about type of pagination in API URL. + This option is added because type of pagination for endpoints is being changed + in the future from page number to 'next' id. + + Args: + endpoint (Literal["jobs", "product", "company", "survey"], optional): + The endpoint source to be accessed. Defaults to None. + from_date (str, optional): Start date for the query, by default is the + oldest date in the data 2022-03-22. + to_date (str, optional): End date for the query. By default None, + which will be executed as datetime.today().strftime("%Y-%m-%d") in code. + items_per_page (int, optional): Number of entries per page. + 100 entries by default. + region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional): + Region filter for the query. Defaults to None + (parameter is not used in url). [December 2023 status: value 'all' + does not work for company and jobs] + url (str, optional): Generic part of the URL to Vid Club API. + Defaults to None. 
+        """
+        if to_date is None:
+            to_date = datetime.today().strftime("%Y-%m-%d")
+
+        start_date = datetime.strptime(from_date, "%Y-%m-%d").date()
+        end_date = datetime.strptime(to_date, "%Y-%m-%d").date()
+
+        if (end_date - start_date).days < 0:
+            msg = "to_date cannot be earlier than from_date."
+            raise ValidationError(msg)
+
+        interval = timedelta(days=days_interval)
+        starts = []
+        ends = []
+
+        period_start = start_date
+        while period_start < end_date:
+            period_end = min(period_start + interval, end_date)
+            starts.append(period_start.strftime("%Y-%m-%d"))
+            ends.append(period_end.strftime("%Y-%m-%d"))
+            period_start = period_end
+
+        # Degenerate range (from_date == to_date): query the single day as-is.
+        if not starts and not ends:
+            starts.append(from_date)
+            ends.append(to_date)
+
+        return starts, ends
+
+    def check_connection(
+        self,
+        endpoint: Literal["jobs", "product", "company", "survey"] | None = None,
+        from_date: str = "2022-03-22",
+        to_date: str | None = None,
+        items_per_page: int = 100,
+        region: Literal["bg", "hu", "hr", "pl", "ro", "si", "all"] | None = None,
+        url: str | None = None,
+    ) -> tuple[dict[str, Any], str]:
+        """Make an initial API call to retrieve the first piece of data.
+
+        The response also reveals which pagination type the API URL uses. This
+        check exists because the pagination type of the endpoints is being
+        changed over time from page numbers to a 'next' id.
+
+        Args:
+            endpoint (Literal["jobs", "product", "company", "survey"], optional):
+                The endpoint to be accessed. Defaults to None.
+            from_date (str, optional): Start date for the query. Defaults to
+                "2022-03-22", the oldest date available in the data.
+            to_date (str, optional): End date for the query. Defaults to None,
+                which is resolved to datetime.today().strftime("%Y-%m-%d") at
+                runtime.
+            items_per_page (int, optional): Number of entries per page. Defaults
+                to 100.
+            region (Literal["bg", "hu", "hr", "pl", "ro", "si", "all"], optional):
+                Region filter for the query. Defaults to None (the parameter is
+                then omitted from the URL). [December 2023 status: the value
+                'all' does not work for the company and jobs endpoints.]
+            url (str, optional): Generic part of the URL to the Vid Club API.
+                Defaults to None (the URL from the credentials is used).
+
+        Returns:
+            tuple[dict[str, Any], str]: The first JSON response from the API and
+                the URL string that was queried.
+
+        Raises:
+            ValidationError: If from_date is earlier than 2022-03-22.
+            ValidationError: If to_date is earlier than from_date.
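+
+        Example:
+            A hypothetical call (values are illustrative)::
+
+                response, first_url = vc.check_connection(
+                    endpoint="jobs",
+                    from_date="2023-01-01",
+                    to_date="2023-01-31",
+                )
+                # "next" in response -> the endpoint paginates by 'next' id,
+                # otherwise by page number.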
+ """ + headers = self.headers + if endpoint not in ["jobs", "product", "company", "survey"]: + msg = "The source has to be: jobs, product, company or survey" + raise ValidationError(msg) + if to_date is None: + to_date = datetime.today().strftime("%Y-%m-%d") + + response, first_url = self.check_connection( + endpoint=endpoint, + from_date=from_date, + to_date=to_date, + items_per_page=items_per_page, + region=region, + ) + + if isinstance(response, dict): + keys_list = list(response.keys()) + elif isinstance(response, list): + keys_list = list(response[0].keys()) + else: + keys_list = [] + + ind = "next" in keys_list + + if "data" in keys_list: + df = pd.json_normalize(response["data"]) + df = pd.DataFrame(df) + length = df.shape[0] + page = 1 + + while length == items_per_page: + if ind is True: + next_page = response["next"] + url = f"{first_url}&next={next_page}" + else: + page += 1 + url = f"{first_url}&page={page}" + response_api = handle_api_response( + url=url, headers=headers, method="GET" + ) + response = response_api.json() + df_page = pd.json_normalize(response["data"]) + df_page = pd.DataFrame(df_page) + if endpoint == "product": + df_page = df_page.transpose() + length = df_page.shape[0] + df = pd.concat((df, df_page), axis=0) + else: + df = pd.DataFrame(response) + + return df + + def to_df( + self, + if_empty: Literal["warn", "skip", "fail"] = "warn", + ) -> pd.DataFrame: + """Looping get_response and iterating by date ranges defined in intervals. + + Stores outputs as DataFrames in a list. At the end, daframes are concatenated + in one and dropped duplicates that would appear when quering. + + Args: + if_empty (Literal["warn", "skip", "fail"], optional): What to do if a fetch + produce no data. Defaults to "warn + + Returns: + pd.DataFrame: Dataframe of the concatanated data carried in the responses. + """ + starts, ends = self.intervals( + from_date=self.from_date, + to_date=self.to_date, + days_interval=self.days_interval, + ) + + dfs_list = [] + if len(starts) > 0 and len(ends) > 0: + for start, end in zip(starts, ends, strict=False): + self.logger.info(f"ingesting data for dates [{start}]-[{end}]...") + df = self.get_response( + endpoint=self.endpoint, + from_date=start, + to_date=end, + items_per_page=self.items_per_page, + region=self.region, + ) + dfs_list.append(df) + if len(dfs_list) > 1: + df = pd.concat(dfs_list, axis=0, ignore_index=True) + else: + df = pd.DataFrame(dfs_list[0]) + else: + df = self.get_response( + endpoint=self.endpoint, + from_date=self.from_date, + to_date=self.to_date, + items_per_page=self.items_per_page, + region=self.region, + ) + list_columns = df.columns[df.map(lambda x: isinstance(x, list)).any()].tolist() + for i in list_columns: + df[i] = df[i].apply(lambda x: tuple(x) if isinstance(x, list) else x) + df.drop_duplicates(inplace=True) + + if self.cols_to_drop is not None: + if isinstance(self.cols_to_drop, list): + try: + self.logger.info( + f"Dropping following columns: {self.cols_to_drop}..." + ) + df.drop(columns=self.cols_to_drop, inplace=True, errors="raise") + except KeyError: + self.logger.exception( + f"""Column(s): {self.cols_to_drop} don't exist in the DataFrame. + No columns were dropped. Returning full DataFrame...""" + ) + self.logger.info(f"Existing columns: {df.columns}") + else: + msg = "Provide columns to drop in a List." 
+                raise TypeError(msg)
+            try:
+                self.logger.info(
+                    f"Dropping the following columns: {self.cols_to_drop}..."
+                )
+                df.drop(columns=self.cols_to_drop, inplace=True, errors="raise")
+            except KeyError:
+                self.logger.exception(
+                    f"Column(s) {self.cols_to_drop} do not exist in the "
+                    "DataFrame. No columns were dropped. Returning the full "
+                    "DataFrame..."
+                )
+                self.logger.info(f"Existing columns: {df.columns}")
+
+        if df.empty:
+            self._handle_if_empty(if_empty=if_empty)
+
+        return df
diff --git a/tests/integration/orchestration/prefect/flows/test_vid_club.py b/tests/integration/orchestration/prefect/flows/test_vid_club.py
new file mode 100644
index 000000000..7053e04f1
--- /dev/null
+++ b/tests/integration/orchestration/prefect/flows/test_vid_club.py
@@ -0,0 +1,31 @@
+from viadot.orchestration.prefect.flows import vid_club_to_adls
+from viadot.sources import AzureDataLake
+
+
+TEST_FILE_PATH = "test/path/to/adls.parquet"
+TEST_ENDPOINT = "jobs"
+TEST_FROM_DATE = "2023-01-01"
+TEST_TO_DATE = "2023-12-31"
+ADLS_CREDENTIALS_SECRET = "test_adls_secret"  # pragma: allowlist secret # noqa: S105
+VIDCLUB_CREDENTIALS_SECRET = (
+    "test_vidclub_secret"  # pragma: allowlist secret # noqa: S105
+)
+
+
+def test_vid_club_to_adls():
+    lake = AzureDataLake(config_key="adls_test")
+
+    assert not lake.exists(TEST_FILE_PATH)
+
+    vid_club_to_adls(
+        endpoint=TEST_ENDPOINT,
+        from_date=TEST_FROM_DATE,
+        to_date=TEST_TO_DATE,
+        adls_path=TEST_FILE_PATH,
+        adls_azure_key_vault_secret=ADLS_CREDENTIALS_SECRET,
+        azure_key_vault_secret=VIDCLUB_CREDENTIALS_SECRET,
+    )
+
+    assert lake.exists(TEST_FILE_PATH)
+
+    lake.rm(TEST_FILE_PATH)
diff --git a/tests/integration/orchestration/prefect/tasks/test_vid_club.py b/tests/integration/orchestration/prefect/tasks/test_vid_club.py
new file mode 100644
index 000000000..2d39001bc
--- /dev/null
+++ b/tests/integration/orchestration/prefect/tasks/test_vid_club.py
@@ -0,0 +1,52 @@
+import pandas as pd
+import pytest
+
+from viadot.orchestration.prefect.exceptions import MissingSourceCredentialsError
+from viadot.orchestration.prefect.tasks import vid_club_to_df
+
+
+EXPECTED_DF = pd.DataFrame(
+    {"id": [1, 2], "name": ["Company A", "Company B"], "region": ["pl", "ro"]}
+)
+
+
+class MockVidClub:
+    def __init__(self, *args, **kwargs):
+        """Init method."""
+
+    def to_df(self):
+        return EXPECTED_DF
+
+
+def test_vid_club_to_df(mocker):
+    mocker.patch(
+        "viadot.orchestration.prefect.tasks.vid_club.VidClub", new=MockVidClub
+    )
+
+    df = vid_club_to_df(
+        endpoint="company",
+        from_date="2023-01-01",
+        to_date="2023-12-31",
+        items_per_page=100,
+        region="pl",
+        azure_key_vault_secret="VIDCLUB",  # pragma: allowlist secret # noqa: S106
+    )
+
+    assert isinstance(df, pd.DataFrame)
+    assert not df.empty
+    assert df.equals(EXPECTED_DF)
+
+
+def test_vid_club_to_df_missing_credentials():
+    # With neither a Key Vault secret nor a config key, the task must raise
+    # before attempting any credential lookup.
+    with pytest.raises(MissingSourceCredentialsError):
+        vid_club_to_df(
+            endpoint="company",
+            from_date="2023-01-01",
+            to_date="2023-12-31",
+            items_per_page=100,
+            region="pl",
+        )
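A detail worth noting in the task test above: `tasks/vid_club.py` binds `VidClub` (and `get_credentials`) into its own namespace at import time, so mocks must patch the task module, not the modules where those names are defined — the pattern, using the module path from this PR:

```python
# Patch names where they are looked up (the importing module),
# not where they are defined (viadot.sources / the prefect utils).
mocker.patch("viadot.orchestration.prefect.tasks.vid_club.VidClub", new=MockVidClub)
mocker.patch(
    "viadot.orchestration.prefect.tasks.vid_club.get_credentials",
    return_value={"token": "fake-token"},  # illustrative placeholder
)
```
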
api_url = "https://example.com/api/" + items_per_page = 50 + endpoint = "jobs" + region = "pl" + + # Expected result URL + expected_url = "https://example.com/api/jobs?from=2023-01-01&to=2023-01-31®ion=pl&limit=50" + + # Check if the method returns the correct URL + result_url = self.vid_club.build_query( + from_date, to_date, api_url, items_per_page, endpoint, region + ) + assert result_url == expected_url + + def test_intervals(self): + """Test breaking date range into intervals based on the days_interval.""" + # Sample input data for the intervals method + from_date = "2023-01-01" + to_date = "2023-01-15" + days_interval = 5 + + # Expected starts and ends lists + expected_starts = ["2023-01-01", "2023-01-06", "2023-01-11"] + expected_ends = ["2023-01-06", "2023-01-11", "2023-01-15"] + + # Check if the method returns correct intervals + starts, ends = self.vid_club.intervals(from_date, to_date, days_interval) + assert starts == expected_starts + assert ends == expected_ends + + def test_intervals_invalid_date_range(self): + """Test that ValidationError is raised when to_date is before from_date.""" + # Sample input data where to_date is before from_date + from_date = "2023-01-15" + to_date = "2023-01-01" + days_interval = 5 + + # Check if ValidationError is raised + with pytest.raises(ValidationError): + self.vid_club.intervals(from_date, to_date, days_interval)