From 6e698096eb3400f5c9b7817a6ba5902b532123e7 Mon Sep 17 00:00:00 2001 From: Ilyas Gasanov Date: Wed, 15 Jan 2025 09:48:16 +0300 Subject: [PATCH] [DOP-22348] Add transformations for Transfers with dataframe row filtering --- docs/changelog/next_release/184.feature.rst | 1 + .../2023-11-23_0007_create_transfer_table.py | 1 + syncmaster/db/models/transfer.py | 1 + syncmaster/db/repositories/transfer.py | 28 ++- syncmaster/dto/transfers.py | 2 + syncmaster/schemas/v1/connections/oracle.py | 4 +- syncmaster/schemas/v1/transfers/__init__.py | 27 ++- .../v1/transfers/transformations/__init__.py | 2 + .../transformations/dataframe_rows_filter.py | 86 ++++++++ syncmaster/schemas/v1/transformation_types.py | 5 + syncmaster/server/api/v1/transfers.py | 4 + syncmaster/worker/controller.py | 2 +- syncmaster/worker/handlers/base.py | 44 ++++ syncmaster/worker/handlers/db/base.py | 7 +- syncmaster/worker/handlers/file/base.py | 7 +- syncmaster/worker/handlers/file/s3.py | 3 +- .../test_run_transfer/conftest.py | 58 +++++- .../test_run_transfer/test_clickhouse.py | 81 +++++++- .../test_run_transfer/test_hive.py | 4 +- .../test_run_transfer/test_mssql.py | 88 +++++++- .../test_run_transfer/test_mysql.py | 88 +++++++- .../test_run_transfer/test_oracle.py | 88 +++++++- .../test_run_transfer/test_s3.py | 70 +++++++ .../scheduler_fixtures/transfer_fixture.py | 5 +- .../test_transfers/test_create_transfer.py | 196 +++++++++++++++++- .../test_create_transfer.py | 2 + .../test_file_transfers/test_read_transfer.py | 106 ++++++---- .../test_update_transfer.py | 123 +++++++---- .../test_transfers/test_read_transfer.py | 2 + .../test_transfers/test_read_transfers.py | 5 + .../test_transfers/test_update_transfer.py | 3 + .../transfer_fixtures/transfer_fixture.py | 5 +- .../transfer_with_user_role_fixtures.py | 5 +- tests/test_unit/utils.py | 2 + 34 files changed, 1034 insertions(+), 121 deletions(-) create mode 100644 docs/changelog/next_release/184.feature.rst create mode 100644 syncmaster/schemas/v1/transfers/transformations/__init__.py create mode 100644 syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py create mode 100644 syncmaster/schemas/v1/transformation_types.py diff --git a/docs/changelog/next_release/184.feature.rst b/docs/changelog/next_release/184.feature.rst new file mode 100644 index 00000000..3ffebc37 --- /dev/null +++ b/docs/changelog/next_release/184.feature.rst @@ -0,0 +1 @@ +Add transformations for **Transfers** with dataframe row filtering \ No newline at end of file diff --git a/syncmaster/db/migrations/versions/2023-11-23_0007_create_transfer_table.py b/syncmaster/db/migrations/versions/2023-11-23_0007_create_transfer_table.py index bc98cc66..f0777703 100644 --- a/syncmaster/db/migrations/versions/2023-11-23_0007_create_transfer_table.py +++ b/syncmaster/db/migrations/versions/2023-11-23_0007_create_transfer_table.py @@ -44,6 +44,7 @@ def upgrade(): sa.Column("strategy_params", sa.JSON(), nullable=False), sa.Column("source_params", sa.JSON(), nullable=False), sa.Column("target_params", sa.JSON(), nullable=False), + sa.Column("transformations", sa.JSON(), nullable=False), sa.Column("is_scheduled", sa.Boolean(), nullable=False), sa.Column("schedule", sa.String(length=32), nullable=False), sa.Column("queue_id", sa.BigInteger(), nullable=False), diff --git a/syncmaster/db/models/transfer.py b/syncmaster/db/models/transfer.py index c2029871..11e29928 100644 --- a/syncmaster/db/models/transfer.py +++ b/syncmaster/db/models/transfer.py @@ -46,6 +46,7 @@ class Transfer( 
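+    # NOTE: stored like the *_params columns, as a plain JSON list of {"type": ..., "filters": [...]} objects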
strategy_params: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False, default={}) source_params: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False, default={}) target_params: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False, default={}) + transformations: Mapped[list[dict[str, Any]]] = mapped_column(JSON, nullable=False, default=list) is_scheduled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) schedule: Mapped[str] = mapped_column(String(32), nullable=False, default="") queue_id: Mapped[int] = mapped_column( diff --git a/syncmaster/db/repositories/transfer.py b/syncmaster/db/repositories/transfer.py index 48b5808b..3b692c46 100644 --- a/syncmaster/db/repositories/transfer.py +++ b/syncmaster/db/repositories/transfer.py @@ -115,6 +115,7 @@ async def create( source_params: dict[str, Any], target_params: dict[str, Any], strategy_params: dict[str, Any], + transformations: list[dict[str, Any]], queue_id: int, is_scheduled: bool, schedule: str | None, @@ -130,6 +131,7 @@ async def create( source_params=source_params, target_params=target_params, strategy_params=strategy_params, + transformations=transformations, queue_id=queue_id, is_scheduled=is_scheduled, schedule=schedule or "", @@ -154,20 +156,27 @@ async def update( source_params: dict[str, Any], target_params: dict[str, Any], strategy_params: dict[str, Any], + transformations: list[dict[str, Any]], is_scheduled: bool | None, schedule: str | None, new_queue_id: int | None, ) -> Transfer: try: - for key in transfer.source_params: - if key not in source_params or source_params[key] is None: - source_params[key] = transfer.source_params[key] - for key in transfer.target_params: - if key not in target_params or target_params[key] is None: - target_params[key] = transfer.target_params[key] - for key in transfer.strategy_params: - if key not in strategy_params or strategy_params[key] is None: - strategy_params[key] = transfer.strategy_params[key] + for old, new in [ + (transfer.source_params, source_params), + (transfer.target_params, target_params), + (transfer.strategy_params, strategy_params), + ]: + for key in old: + if key not in new or new[key] is None: + new[key] = old[key] + + new_transformations = {d["type"]: d["filters"] for d in transformations} + old_transformations = {d["type"]: d["filters"] for d in transfer.transformations} + for t_type, t_filters in new_transformations.items(): + old_transformations[t_type] = t_filters + transformations = [{"type": t, "filters": f} for t, f in old_transformations.items()] + return await self._update( Transfer.id == transfer.id, name=name or transfer.name, @@ -179,6 +188,7 @@ async def update( target_connection_id=target_connection_id or transfer.target_connection_id, source_params=source_params, target_params=target_params, + transformations=transformations, queue_id=new_queue_id or transfer.queue_id, ) except IntegrityError as e: diff --git a/syncmaster/dto/transfers.py b/syncmaster/dto/transfers.py index d09914d3..d9fff7bc 100644 --- a/syncmaster/dto/transfers.py +++ b/syncmaster/dto/transfers.py @@ -15,6 +15,7 @@ class TransferDTO: @dataclass class DBTransferDTO(TransferDTO): table_name: str + transformations: list[dict] | None = None @dataclass @@ -23,6 +24,7 @@ class FileTransferDTO(TransferDTO): file_format: CSV | JSONLine | JSON | Excel | XML | ORC | Parquet options: dict df_schema: dict | None = None + transformations: list[dict] | None = None _format_parsers = { "csv": CSV, diff --git a/syncmaster/schemas/v1/connections/oracle.py 
b/syncmaster/schemas/v1/connections/oracle.py
index 1e1364c0..c51ea253 100644
--- a/syncmaster/schemas/v1/connections/oracle.py
+++ b/syncmaster/schemas/v1/connections/oracle.py
@@ -24,7 +24,7 @@ class CreateOracleConnectionDataSchema(BaseModel):
     additional_params: dict = Field(default_factory=dict)
 
     @model_validator(mode="before")
-    def check_owner_id(cls, values):
+    def validate_connection_identifiers(cls, values):
         sid, service_name = values.get("sid"), values.get("service_name")
         if sid and service_name:
             raise ValueError("You must specify either sid or service_name but not both")
@@ -47,7 +47,7 @@ class UpdateOracleConnectionDataSchema(BaseModel):
     additional_params: dict | None = Field(default_factory=dict)
 
     @model_validator(mode="before")
-    def check_owner_id(cls, values):
+    def validate_connection_identifiers(cls, values):
         sid, service_name = values.get("sid"), values.get("service_name")
         if sid and service_name:
             raise ValueError("You must specify either sid or service_name but not both")
diff --git a/syncmaster/schemas/v1/transfers/__init__.py b/syncmaster/schemas/v1/transfers/__init__.py
index d90732d3..5d650842 100644
--- a/syncmaster/schemas/v1/transfers/__init__.py
+++ b/syncmaster/schemas/v1/transfers/__init__.py
@@ -2,7 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
-from pydantic import BaseModel, Field, model_validator
+from typing import Annotated
+
+from pydantic import BaseModel, Field, field_validator, model_validator
 
 from syncmaster.schemas.v1.connections.connection import ReadConnectionSchema
 from syncmaster.schemas.v1.page import PageSchema
@@ -27,6 +29,9 @@
     S3ReadTransferTarget,
 )
 from syncmaster.schemas.v1.transfers.strategy import FullStrategy, IncrementalStrategy
+from syncmaster.schemas.v1.transfers.transformations.dataframe_rows_filter import (
+    DataframeRowsFilter,
+)
 from syncmaster.schemas.v1.types import NameConstr
 
 ReadTransferSchemaSource = (
@@ -97,6 +102,8 @@
     | None
 )
 
+TransformationSchema = DataframeRowsFilter
+
 
 class CopyTransferSchema(BaseModel):
     new_group_id: int
@@ -129,6 +136,9 @@ class ReadTransferSchema(BaseModel):
         ...,
         discriminator="type",
     )
+    transformations: list[Annotated[TransformationSchema, Field(..., discriminator="type")]] = Field(
+        default_factory=list,
+    )
 
     class Config:
         from_attributes = True
@@ -158,15 +168,27 @@ class CreateTransferSchema(BaseModel):
         discriminator="type",
         description="Incremental or archive download options",
     )
+    transformations: list[
+        Annotated[TransformationSchema, Field(None, discriminator="type", description="List of transformations")]
+    ] = Field(default_factory=list)
 
     @model_validator(mode="before")
-    def check_owner_id(cls, values):
+    def validate_scheduling(cls, values):
         is_scheduled, schedule = values.get("is_scheduled"), values.get("schedule")
         if is_scheduled and schedule is None:
             # TODO: add cron string validation
             raise ValueError("If transfer must be scheduled, then the schedule param must be set")
         return values
 
+    @field_validator("transformations", mode="after")
+    def validate_transformations_uniqueness(cls, transformations):
+        if transformations:
+            types = [tr.type for tr in transformations]
+            duplicates = {t for t in types if types.count(t) > 1}
+            if duplicates:
+                raise ValueError(f"Duplicate 'type' values found in transformations: {' '.join(map(str, duplicates))}")
+        return transformations
+
 
 class UpdateTransferSchema(BaseModel):
     source_connection_id: int | None = None
@@ -179,6 +201,7 @@ class UpdateTransferSchema(BaseModel):
     source_params: UpdateTransferSchemaSource =
Field(discriminator="type", default=None) target_params: UpdateTransferSchemaTarget = Field(discriminator="type", default=None) strategy_params: FullStrategy | IncrementalStrategy | None = Field(discriminator="type", default=None) + transformations: list[Annotated[TransformationSchema, Field(discriminator="type", default=None)]] = None class ReadFullTransferSchema(ReadTransferSchema): diff --git a/syncmaster/schemas/v1/transfers/transformations/__init__.py b/syncmaster/schemas/v1/transfers/transformations/__init__.py new file mode 100644 index 00000000..eb9bf462 --- /dev/null +++ b/syncmaster/schemas/v1/transfers/transformations/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2023-2024 MTS PJSC +# SPDX-License-Identifier: Apache-2.0 diff --git a/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py b/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py new file mode 100644 index 00000000..1636cf2c --- /dev/null +++ b/syncmaster/schemas/v1/transfers/transformations/dataframe_rows_filter.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: 2023-2024 MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + +from syncmaster.schemas.v1.transformation_types import DATAFRAME_ROWS_FILTER + + +class BaseRowFilter(BaseModel): + field: str + + +class EqualFilter(BaseRowFilter): + type: Literal["equal"] + value: str | None = None + + +class NotEqualFilter(BaseRowFilter): + type: Literal["not_equal"] + value: str | None = None + + +class GreaterThanFilter(BaseRowFilter): + type: Literal["greater_than"] + value: str + + +class GreaterOrEqualFilter(BaseRowFilter): + type: Literal["greater_or_equal"] + value: str + + +class LessThanFilter(BaseRowFilter): + type: Literal["less_than"] + value: str + + +class LessOrEqualFilter(BaseRowFilter): + type: Literal["less_or_equal"] + value: str + + +class LikeFilter(BaseRowFilter): + type: Literal["like"] + value: str + + +class ILikeFilter(BaseRowFilter): + type: Literal["ilike"] + value: str + + +class NotLikeFilter(BaseRowFilter): + type: Literal["not_like"] + value: str + + +class NotILikeFilter(BaseRowFilter): + type: Literal["not_ilike"] + value: str + + +class RegexpFilter(BaseRowFilter): + type: Literal["regexp"] + value: str + + +RowFilter = ( + EqualFilter + | NotEqualFilter + | GreaterThanFilter + | GreaterOrEqualFilter + | LessThanFilter + | LessOrEqualFilter + | LikeFilter + | ILikeFilter + | NotLikeFilter + | NotILikeFilter + | RegexpFilter +) + + +class DataframeRowsFilter(BaseModel): + type: DATAFRAME_ROWS_FILTER + filters: list[Annotated[RowFilter, Field(..., discriminator="type")]] = Field(default_factory=list) diff --git a/syncmaster/schemas/v1/transformation_types.py b/syncmaster/schemas/v1/transformation_types.py new file mode 100644 index 00000000..9393306e --- /dev/null +++ b/syncmaster/schemas/v1/transformation_types.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023-2024 MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +from typing import Literal + +DATAFRAME_ROWS_FILTER = Literal["dataframe_rows_filter"] diff --git a/syncmaster/server/api/v1/transfers.py b/syncmaster/server/api/v1/transfers.py index 919a588c..f8f2ccff 100644 --- a/syncmaster/server/api/v1/transfers.py +++ b/syncmaster/server/api/v1/transfers.py @@ -130,6 +130,7 @@ async def create_transfer( source_params=transfer_data.source_params.dict(), target_params=transfer_data.target_params.dict(), strategy_params=transfer_data.strategy_params.dict(), + 
transformations=[tr.dict() for tr in transfer_data.transformations], queue_id=transfer_data.queue_id, is_scheduled=transfer_data.is_scheduled, schedule=transfer_data.schedule, @@ -326,6 +327,9 @@ async def update_transfer( source_params=transfer_data.source_params.dict() if transfer_data.source_params else {}, target_params=transfer_data.target_params.dict() if transfer_data.target_params else {}, strategy_params=transfer_data.strategy_params.dict() if transfer_data.strategy_params else {}, + transformations=( + [tr.dict() for tr in transfer_data.transformations] if transfer_data.transformations else [] + ), is_scheduled=transfer_data.is_scheduled, schedule=transfer_data.schedule, new_queue_id=transfer_data.new_queue_id, diff --git a/syncmaster/worker/controller.py b/syncmaster/worker/controller.py index 37d6fcd0..982a2dde 100644 --- a/syncmaster/worker/controller.py +++ b/syncmaster/worker/controller.py @@ -98,7 +98,7 @@ def __init__( self.run = run self.source_handler = self.get_handler( connection_data=source_connection.data, - transfer_params=run.transfer.source_params, + transfer_params={**run.transfer.source_params, "transformations": run.transfer.transformations}, connection_auth_data=source_auth_data, ) self.target_handler = self.get_handler( diff --git a/syncmaster/worker/handlers/base.py b/syncmaster/worker/handlers/base.py index 58575fcb..45354200 100644 --- a/syncmaster/worker/handlers/base.py +++ b/syncmaster/worker/handlers/base.py @@ -31,3 +31,47 @@ def read(self) -> DataFrame: ... @abstractmethod def write(self, df: DataFrame) -> None: ... + + +class BaseHandler(Handler): + + def _apply_filters(self, df: DataFrame) -> DataFrame: + for transformation in self.transfer_dto.transformations: + if transformation["type"] == "dataframe_rows_filter": + filter_expression = self._get_filter_expression(transformation["filters"]) + if filter_expression: + df = df.where(filter_expression) + return df + + @staticmethod + def _get_filter_expression(filters: list[dict]) -> str: + operators = { + "equal": "=", + "not_equal": "!=", + "greater_than": ">", + "greater_or_equal": ">=", + "less_than": "<", + "less_or_equal": "<=", + "like": "LIKE", + "ilike": "ILIKE", + "not_like": "NOT LIKE", + "not_ilike": "NOT ILIKE", + "regexp": "RLIKE", + } + + expressions = [] + for filter in filters: + field = filter["field"] + op = operators[filter["type"]] + value = filter["value"] + + if value is None: + if op == "!=": + expressions.append(f"{field} IS NOT NULL") + elif op == "=": + expressions.append(f"{field} IS NULL") + else: + value = repr(value) if isinstance(value, str) else value + expressions.append(f"{field} {op} {value}") + + return " AND ".join(expressions) diff --git a/syncmaster/worker/handlers/db/base.py b/syncmaster/worker/handlers/db/base.py index 5f1c2c39..b3b08422 100644 --- a/syncmaster/worker/handlers/db/base.py +++ b/syncmaster/worker/handlers/db/base.py @@ -10,13 +10,13 @@ from onetl.db import DBReader, DBWriter from syncmaster.dto.transfers import DBTransferDTO -from syncmaster.worker.handlers.base import Handler +from syncmaster.worker.handlers.base import BaseHandler if TYPE_CHECKING: from pyspark.sql.dataframe import DataFrame -class DBHandler(Handler): +class DBHandler(BaseHandler): connection: BaseDBConnection transfer_dto: DBTransferDTO @@ -25,7 +25,8 @@ def read(self) -> DataFrame: connection=self.connection, table=self.transfer_dto.table_name, ) - return reader.run() + df = reader.run() + return self._apply_filters(df) def write(self, df: DataFrame) -> None: writer = 
DBWriter( diff --git a/syncmaster/worker/handlers/file/base.py b/syncmaster/worker/handlers/file/base.py index fb409391..a3003d87 100644 --- a/syncmaster/worker/handlers/file/base.py +++ b/syncmaster/worker/handlers/file/base.py @@ -10,13 +10,13 @@ from syncmaster.dto.connections import ConnectionDTO from syncmaster.dto.transfers import FileTransferDTO -from syncmaster.worker.handlers.base import Handler +from syncmaster.worker.handlers.base import BaseHandler if TYPE_CHECKING: from pyspark.sql.dataframe import DataFrame -class FileHandler(Handler): +class FileHandler(BaseHandler): connection: BaseFileDFConnection connection_dto: ConnectionDTO transfer_dto: FileTransferDTO @@ -31,8 +31,9 @@ def read(self) -> DataFrame: df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None, options=self.transfer_dto.options, ) + df = reader.run() - return reader.run() + return self._apply_filters(df) def write(self, df: DataFrame): writer = FileDFWriter( diff --git a/syncmaster/worker/handlers/file/s3.py b/syncmaster/worker/handlers/file/s3.py index a805ad38..a71f33d2 100644 --- a/syncmaster/worker/handlers/file/s3.py +++ b/syncmaster/worker/handlers/file/s3.py @@ -45,5 +45,6 @@ def read(self) -> DataFrame: df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None, options={**options, **self.transfer_dto.options}, ) + df = reader.run() - return reader.run() + return self._apply_filters(df) diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py index 820ca5de..47695f97 100644 --- a/tests/test_integration/test_run_transfer/conftest.py +++ b/tests/test_integration/test_run_transfer/conftest.py @@ -237,8 +237,8 @@ def mssql_for_conftest(test_settings: TestSettings) -> MSSQLConnectionDTO: ) def mssql_for_worker(test_settings: TestSettings) -> MSSQLConnectionDTO: return MSSQLConnectionDTO( - host=test_settings.TEST_MSSQL_HOST_FOR_CONFTEST, - port=test_settings.TEST_MSSQL_PORT_FOR_CONFTEST, + host=test_settings.TEST_MSSQL_HOST_FOR_WORKER, + port=test_settings.TEST_MSSQL_PORT_FOR_WORKER, user=test_settings.TEST_MSSQL_USER, password=test_settings.TEST_MSSQL_PASSWORD, database_name=test_settings.TEST_MSSQL_DB, @@ -1255,8 +1255,9 @@ def init_df_with_mixed_column_naming(spark: SparkSession) -> DataFrame: df_schema = StructType( [ StructField("Id", IntegerType()), - StructField("Phone Number", StringType()), + StructField("Phone_Number", StringType()), StructField("region", StringType()), + StructField("NUMBER", IntegerType()), StructField("birth_DATE", DateType()), StructField("Registered At", TimestampType()), StructField("account_balance", DoubleType()), @@ -1269,10 +1270,61 @@ def init_df_with_mixed_column_naming(spark: SparkSession) -> DataFrame: 1, "+79123456789", "Mordor", + 1, datetime.date(year=2023, month=3, day=11), datetime.datetime.now(), 1234.2343, ), + ( + 2, + "+79234567890", + "Gondor", + 2, + datetime.date(2022, 6, 19), + datetime.datetime.now(), + 2345.5678, + ), + ( + 3, + "+79345678901", + "Rohan", + 3, + datetime.date(2021, 11, 5), + datetime.datetime.now(), + 3456.7890, + ), + ( + 4, + "+79456789012", + "Shire", + 4, + datetime.date(2020, 1, 30), + datetime.datetime.now(), + 4567.8901, + ), + ( + 5, + "+79567890123", + "Isengard", + 5, + datetime.date(2023, 8, 15), + datetime.datetime.now(), + 5678.9012, + ), ], schema=df_schema, ) + + +@pytest.fixture +def filter_constants() -> dict[str, str]: + return { + "NOT_EQUAL_FILTER_VALUE": "Baileytown", 
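+        # values are deliberately strings: the worker joins all filters into a single
+        # Spark SQL expression (e.g. NUMBER > '3'), and Spark casts quoted literals
+        # when comparing them against numeric columns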
+ "GREATER_THAN_FILTER_VALUE": "3", + "GREATER_OR_EQUAL_FILTER_VALUE": "3", + "LESS_THAN_FILTER_VALUE": "25", + "LESS_OR_EQUAL_FILTER_VALUE": "25", + "NOT_ILIKE_FILTER_VALUE": "new%", + "NOT_LIKE_FILTER_VALUE": "%port", + "REGEXP_FILTER_VALUE": "^8[ \\t]", + } diff --git a/tests/test_integration/test_run_transfer/test_clickhouse.py b/tests/test_integration/test_run_transfer/test_clickhouse.py index eb651db5..2c59c5dd 100644 --- a/tests/test_integration/test_run_transfer/test_clickhouse.py +++ b/tests/test_integration/test_run_transfer/test_clickhouse.py @@ -24,6 +24,7 @@ async def postgres_to_clickhouse( clickhouse_for_conftest: Clickhouse, clickhouse_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -39,6 +40,33 @@ async def postgres_to_clickhouse( "type": "clickhouse", "table_name": f"{clickhouse_for_conftest.user}.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "REGION", + "value": filter_constants["NOT_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_than", + "field": "NUMBER", + "value": filter_constants["GREATER_THAN_FILTER_VALUE"], + }, + { + "type": "less_than", + "field": "NUMBER", + "value": filter_constants["LESS_THAN_FILTER_VALUE"], + }, + { + "type": "not_ilike", + "field": "REGION", + "value": filter_constants["NOT_ILIKE_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -54,6 +82,7 @@ async def clickhouse_to_postgres( clickhouse_for_conftest: Clickhouse, clickhouse_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -69,6 +98,28 @@ async def clickhouse_to_postgres( "type": "postgres", "table_name": "public.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "less_or_equal", + "field": "NUMBER", + "value": filter_constants["LESS_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_or_equal", + "field": "NUMBER", + "value": filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "regexp", + "field": "PHONE_NUMBER", + "value": filter_constants["REGEXP_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -83,11 +134,18 @@ async def test_run_transfer_postgres_to_clickhouse( prepare_clickhouse, init_df: DataFrame, postgres_to_clickhouse: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df) clickhouse, _ = prepare_clickhouse + init_df = init_df.where( + (init_df["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -129,11 +187,18 @@ async def test_run_transfer_postgres_to_clickhouse_mixed_naming( prepare_clickhouse, init_df_with_mixed_column_naming: DataFrame, postgres_to_clickhouse: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df_with_mixed_column_naming) clickhouse, _ = prepare_clickhouse + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] > 
filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -169,7 +234,7 @@ async def test_run_transfer_postgres_to_clickhouse_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() async def test_run_transfer_clickhouse_to_postgres( @@ -179,11 +244,17 @@ async def test_run_transfer_clickhouse_to_postgres( prepare_postgres, init_df: DataFrame, clickhouse_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_clickhouse fill_with_data(init_df) postgres, _ = prepare_postgres + init_df = init_df.where( + (init_df["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (init_df["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -227,11 +298,17 @@ async def test_run_transfer_clickhouse_to_postgres_mixed_naming( prepare_postgres, init_df_with_mixed_column_naming: DataFrame, clickhouse_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_clickhouse fill_with_data(init_df_with_mixed_column_naming) postgres, _ = prepare_postgres + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -268,4 +345,4 @@ async def test_run_transfer_clickhouse_to_postgres_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() diff --git a/tests/test_integration/test_run_transfer/test_hive.py b/tests/test_integration/test_run_transfer/test_hive.py index 67757876..c23b3b48 100644 --- a/tests/test_integration/test_run_transfer/test_hive.py +++ b/tests/test_integration/test_run_transfer/test_hive.py @@ -167,7 +167,7 @@ async def test_run_transfer_postgres_to_hive_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() async def test_run_transfer_hive_to_postgres( @@ -264,4 +264,4 @@ async def test_run_transfer_hive_to_postgres_mixes_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() diff --git a/tests/test_integration/test_run_transfer/test_mssql.py 
b/tests/test_integration/test_run_transfer/test_mssql.py index 93ef3895..faea7404 100644 --- a/tests/test_integration/test_run_transfer/test_mssql.py +++ b/tests/test_integration/test_run_transfer/test_mssql.py @@ -25,6 +25,7 @@ async def postgres_to_mssql( mssql_for_conftest: MSSQL, mssql_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -40,6 +41,33 @@ async def postgres_to_mssql( "type": "mssql", "table_name": "dbo.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "REGION", + "value": filter_constants["NOT_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_than", + "field": "NUMBER", + "value": filter_constants["GREATER_THAN_FILTER_VALUE"], + }, + { + "type": "less_than", + "field": "NUMBER", + "value": filter_constants["LESS_THAN_FILTER_VALUE"], + }, + { + "type": "not_ilike", + "field": "REGION", + "value": filter_constants["NOT_ILIKE_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -55,6 +83,7 @@ async def mssql_to_postgres( mssql_for_conftest: MSSQL, mssql_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -70,6 +99,33 @@ async def mssql_to_postgres( "type": "postgres", "table_name": "public.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "less_or_equal", + "field": "NUMBER", + "value": filter_constants["LESS_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_or_equal", + "field": "NUMBER", + "value": filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "not_like", + "field": "REGION", + "value": filter_constants["NOT_LIKE_FILTER_VALUE"], + }, + { + "type": "regexp", + "field": "PHONE_NUMBER", + "value": filter_constants["REGEXP_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -84,11 +140,18 @@ async def test_run_transfer_postgres_to_mssql( prepare_mssql, init_df: DataFrame, postgres_to_mssql: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df) mssql, _ = prepare_mssql + init_df = init_df.where( + (init_df["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -135,11 +198,18 @@ async def test_run_transfer_postgres_to_mssql_mixed_naming( prepare_mssql, init_df_with_mixed_column_naming: DataFrame, postgres_to_mssql: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df_with_mixed_column_naming) mssql, _ = prepare_mssql + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -183,7 +253,7 @@ async def test_run_transfer_postgres_to_mssql_mixed_naming( for field in 
init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() async def test_run_transfer_mssql_to_postgres( @@ -193,11 +263,18 @@ async def test_run_transfer_mssql_to_postgres( prepare_postgres, init_df: DataFrame, mssql_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_mssql fill_with_data(init_df) postgres, _ = prepare_postgres + init_df = init_df.where( + (init_df["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -245,11 +322,18 @@ async def test_run_transfer_mssql_to_postgres_mixed_naming( prepare_postgres, init_df_with_mixed_column_naming: DataFrame, mssql_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_mssql fill_with_data(init_df_with_mixed_column_naming) postgres, _ = prepare_postgres + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df_with_mixed_column_naming["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -293,4 +377,4 @@ async def test_run_transfer_mssql_to_postgres_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() diff --git a/tests/test_integration/test_run_transfer/test_mysql.py b/tests/test_integration/test_run_transfer/test_mysql.py index 513c2098..cd2ee479 100644 --- a/tests/test_integration/test_run_transfer/test_mysql.py +++ b/tests/test_integration/test_run_transfer/test_mysql.py @@ -25,6 +25,7 @@ async def postgres_to_mysql( mysql_for_conftest: MySQL, mysql_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -40,6 +41,33 @@ async def postgres_to_mysql( "type": "mysql", "table_name": f"{mysql_for_conftest.database_name}.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "REGION", + "value": filter_constants["NOT_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_than", + "field": "NUMBER", + "value": filter_constants["GREATER_THAN_FILTER_VALUE"], + }, + { + "type": "less_than", + "field": "NUMBER", + "value": filter_constants["LESS_THAN_FILTER_VALUE"], + }, + { + "type": "not_ilike", + "field": "REGION", + "value": filter_constants["NOT_ILIKE_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -55,6 +83,7 @@ async def mysql_to_postgres( mysql_for_conftest: MySQL, mysql_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await 
create_transfer( session=session, @@ -70,6 +99,33 @@ async def mysql_to_postgres( "type": "postgres", "table_name": "public.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "less_or_equal", + "field": "NUMBER", + "value": filter_constants["LESS_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_or_equal", + "field": "NUMBER", + "value": filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "not_like", + "field": "REGION", + "value": filter_constants["NOT_LIKE_FILTER_VALUE"], + }, + { + "type": "regexp", + "field": "PHONE_NUMBER", + "value": filter_constants["REGEXP_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -84,11 +140,18 @@ async def test_run_transfer_postgres_to_mysql( prepare_mysql, init_df: DataFrame, postgres_to_mysql: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df) mysql, _ = prepare_mysql + init_df = init_df.where( + (init_df["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -140,11 +203,18 @@ async def test_run_transfer_postgres_to_mysql_mixed_naming( prepare_mysql, init_df_with_mixed_column_naming: DataFrame, postgres_to_mysql: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df_with_mixed_column_naming) mysql, _ = prepare_mysql + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -190,7 +260,7 @@ async def test_run_transfer_postgres_to_mysql_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() async def test_run_transfer_mysql_to_postgres( @@ -200,11 +270,18 @@ async def test_run_transfer_mysql_to_postgres( prepare_postgres, init_df: DataFrame, mysql_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_mysql fill_with_data(init_df) postgres, _ = prepare_postgres + init_df = init_df.where( + (init_df["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -257,11 +334,18 @@ async def test_run_transfer_mysql_to_postgres_mixed_naming( prepare_postgres, init_df_with_mixed_column_naming: DataFrame, mysql_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_mysql fill_with_data(init_df_with_mixed_column_naming) postgres, _ = prepare_postgres + 
init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df_with_mixed_column_naming["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -307,4 +391,4 @@ async def test_run_transfer_mysql_to_postgres_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() diff --git a/tests/test_integration/test_run_transfer/test_oracle.py b/tests/test_integration/test_run_transfer/test_oracle.py index 9737ab0f..cc35ad06 100644 --- a/tests/test_integration/test_run_transfer/test_oracle.py +++ b/tests/test_integration/test_run_transfer/test_oracle.py @@ -24,6 +24,7 @@ async def postgres_to_oracle( oracle_for_conftest: Oracle, oracle_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -39,6 +40,33 @@ async def postgres_to_oracle( "type": "oracle", "table_name": f"{oracle_for_conftest.user}.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "REGION", + "value": filter_constants["NOT_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_than", + "field": "NUMBER", + "value": filter_constants["GREATER_THAN_FILTER_VALUE"], + }, + { + "type": "less_than", + "field": "NUMBER", + "value": filter_constants["LESS_THAN_FILTER_VALUE"], + }, + { + "type": "not_ilike", + "field": "REGION", + "value": filter_constants["NOT_ILIKE_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -54,6 +82,7 @@ async def oracle_to_postgres( oracle_for_conftest: Oracle, oracle_connection: Connection, postgres_connection: Connection, + filter_constants: dict[str, str], ): result = await create_transfer( session=session, @@ -69,6 +98,33 @@ async def oracle_to_postgres( "type": "postgres", "table_name": "public.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "less_or_equal", + "field": "NUMBER", + "value": filter_constants["LESS_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_or_equal", + "field": "NUMBER", + "value": filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "not_like", + "field": "REGION", + "value": filter_constants["NOT_LIKE_FILTER_VALUE"], + }, + { + "type": "regexp", + "field": "PHONE_NUMBER", + "value": filter_constants["REGEXP_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -83,11 +139,18 @@ async def test_run_transfer_postgres_to_oracle( prepare_oracle, init_df: DataFrame, postgres_to_oracle: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df) oracle, _ = prepare_oracle + init_df = init_df.where( + (init_df["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & 
(~init_df["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -129,11 +192,18 @@ async def test_run_transfer_postgres_to_oracle_mixed_naming( prepare_oracle, init_df_with_mixed_column_naming: DataFrame, postgres_to_oracle: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df_with_mixed_column_naming) oracle, _ = prepare_oracle + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -169,7 +239,7 @@ async def test_run_transfer_postgres_to_oracle_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() async def test_run_transfer_oracle_to_postgres( @@ -179,11 +249,18 @@ async def test_run_transfer_oracle_to_postgres( prepare_postgres, init_df: DataFrame, oracle_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_oracle fill_with_data(init_df) postgres, _ = prepare_postgres + init_df = init_df.where( + (init_df["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -227,11 +304,18 @@ async def test_run_transfer_oracle_to_postgres_mixed_naming( prepare_postgres, init_df_with_mixed_column_naming: DataFrame, oracle_to_postgres: Transfer, + filter_constants: dict[str, str], ): # Arrange _, fill_with_data = prepare_oracle fill_with_data(init_df_with_mixed_column_naming) postgres, _ = prepare_postgres + init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.where( + (init_df_with_mixed_column_naming["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df_with_mixed_column_naming["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df_with_mixed_column_naming["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df_with_mixed_column_naming["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -268,4 +352,4 @@ async def test_run_transfer_oracle_to_postgres_mixed_naming( for field in init_df_with_mixed_column_naming.schema: df = df.withColumn(field.name, df[field.name].cast(field.dataType)) - assert df.collect() == init_df_with_mixed_column_naming.collect() + assert df.sort("ID").collect() == init_df_with_mixed_column_naming.sort("ID").collect() diff --git a/tests/test_integration/test_run_transfer/test_s3.py b/tests/test_integration/test_run_transfer/test_s3.py index 1ad7ac7a..359202c6 100644 --- a/tests/test_integration/test_run_transfer/test_s3.py +++ b/tests/test_integration/test_run_transfer/test_s3.py @@ -37,6 +37,7 @@ async def s3_to_postgres( prepare_s3, 
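+    # source_file_format below is a (format_name, file_format) tuple fixture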
source_file_format, file_format_flavor: str, + filter_constants: dict[str, str], ): format_name, file_format = source_file_format format_name_in_path = "xlsx" if format_name == "excel" else format_name @@ -62,6 +63,33 @@ async def s3_to_postgres( "type": "postgres", "table_name": "public.target_table", }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "REGION", + "value": filter_constants["NOT_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_than", + "field": "NUMBER", + "value": filter_constants["GREATER_THAN_FILTER_VALUE"], + }, + { + "type": "less_than", + "field": "NUMBER", + "value": filter_constants["LESS_THAN_FILTER_VALUE"], + }, + { + "type": "not_ilike", + "field": "REGION", + "value": filter_constants["NOT_ILIKE_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -78,6 +106,7 @@ async def postgres_to_s3( postgres_connection: Connection, target_file_format, file_format_flavor: str, + filter_constants: dict[str, str], ): format_name, file_format = target_file_format result = await create_transfer( @@ -99,6 +128,33 @@ async def postgres_to_s3( }, "options": {}, }, + transformations=[ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "less_or_equal", + "field": "NUMBER", + "value": filter_constants["LESS_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "greater_or_equal", + "field": "NUMBER", + "value": filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"], + }, + { + "type": "not_like", + "field": "REGION", + "value": filter_constants["NOT_LIKE_FILTER_VALUE"], + }, + { + "type": "regexp", + "field": "PHONE_NUMBER", + "value": filter_constants["REGEXP_FILTER_VALUE"], + }, + ], + }, + ], queue_id=queue.id, ) yield result @@ -153,12 +209,19 @@ async def test_run_transfer_s3_to_postgres( init_df: DataFrame, client: AsyncClient, s3_to_postgres: Transfer, + filter_constants: dict[str, str], source_file_format, file_format_flavor, ): # Arrange postgres, _ = prepare_postgres file_format, _ = source_file_format + init_df = init_df.where( + (init_df["REGION"] != filter_constants["NOT_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] > filter_constants["GREATER_THAN_FILTER_VALUE"]) + & (init_df["NUMBER"] < filter_constants["LESS_THAN_FILTER_VALUE"]) + & (~init_df["REGION"].ilike(filter_constants["NOT_ILIKE_FILTER_VALUE"])), + ) # Act result = await client.post( @@ -245,6 +308,7 @@ async def test_run_transfer_postgres_to_s3( prepare_postgres, prepare_s3, postgres_to_s3: Transfer, + filter_constants: dict[str, str], target_file_format, file_format_flavor: str, ): @@ -253,6 +317,12 @@ async def test_run_transfer_postgres_to_s3( # Arrange _, fill_with_data = prepare_postgres fill_with_data(init_df) + init_df = init_df.where( + (init_df["NUMBER"] >= filter_constants["GREATER_OR_EQUAL_FILTER_VALUE"]) + & (init_df["NUMBER"] <= filter_constants["LESS_OR_EQUAL_FILTER_VALUE"]) + & (~init_df["REGION"].like(filter_constants["NOT_LIKE_FILTER_VALUE"])) + & (init_df["PHONE_NUMBER"].rlike(filter_constants["REGEXP_FILTER_VALUE"])), + ) # Act result = await client.post( diff --git a/tests/test_integration/test_scheduler/scheduler_fixtures/transfer_fixture.py b/tests/test_integration/test_scheduler/scheduler_fixtures/transfer_fixture.py index 173b5c87..e99bd89f 100644 --- a/tests/test_integration/test_scheduler/scheduler_fixtures/transfer_fixture.py +++ b/tests/test_integration/test_scheduler/scheduler_fixtures/transfer_fixture.py @@ -105,8 +105,9 @@ async def group_transfer_integration_mock( 
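+        # create_transfer_data may carry optional "source_and_target_params" and "transformations" keys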
source_connection_id=source_connection.id, target_connection_id=target_connection.id, queue_id=queue.id, - source_params=create_transfer_data, - target_params=create_transfer_data, + source_params=create_transfer_data.get("source_and_target_params") if create_transfer_data else None, + target_params=create_transfer_data.get("source_and_target_params") if create_transfer_data else None, + transformations=create_transfer_data.get("transformations") if create_transfer_data else None, ) yield MockTransfer( diff --git a/tests/test_unit/test_transfers/test_create_transfer.py b/tests/test_unit/test_transfers/test_create_transfer.py index e3e7b6e0..953c09be 100644 --- a/tests/test_unit/test_transfers/test_create_transfer.py +++ b/tests/test_unit/test_transfers/test_create_transfer.py @@ -36,6 +36,23 @@ async def test_developer_plus_can_create_transfer( "source_params": {"type": "postgres", "table_name": "source_table"}, "target_params": {"type": "postgres", "table_name": "target_table"}, "strategy_params": {"type": "full"}, + "transformations": [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "equal", + "field": "col1", + "value": "something", + }, + { + "type": "greater_than", + "field": "col2", + "value": "20", + }, + ], + }, + ], "queue_id": group_queue.id, }, ) @@ -64,6 +81,7 @@ async def test_developer_plus_can_create_transfer( "source_params": transfer.source_params, "target_params": transfer.target_params, "strategy_params": transfer.strategy_params, + "transformations": transfer.transformations, "queue_id": transfer.queue_id, } @@ -211,6 +229,23 @@ async def test_superuser_can_create_transfer( "source_params": {"type": "postgres", "table_name": "source_table"}, "target_params": {"type": "postgres", "table_name": "target_table"}, "strategy_params": {"type": "full"}, + "transformations": [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "equal", + "field": "col1", + "value": "something", + }, + { + "type": "greater_than", + "field": "col2", + "value": "20", + }, + ], + }, + ], "queue_id": group_queue.id, }, ) @@ -236,6 +271,7 @@ async def test_superuser_can_create_transfer( "source_params": transfer.source_params, "target_params": transfer.target_params, "strategy_params": transfer.strategy_params, + "transformations": transfer.transformations, "queue_id": transfer.queue_id, } @@ -387,6 +423,95 @@ async def test_superuser_can_create_transfer( }, }, ), + ( + { + "transformations": [ + { + "type": "some unknown transformation type", + "filters": [ + { + "type": "equal", + "field": "col1", + "value": "something", + }, + ], + }, + ], + }, + { + "error": { + "code": "invalid_request", + "message": "Invalid request", + "details": [ + { + "location": ["body", "transformations", 0], + "message": ( + "Input tag 'some unknown transformation type' found using 'type' " + "does not match any of the expected tags: 'dataframe_rows_filter'" + ), + "code": "union_tag_invalid", + "context": { + "discriminator": "'type'", + "expected_tags": "'dataframe_rows_filter'", + "tag": "some unknown transformation type", + }, + "input": { + "type": "some unknown transformation type", + "filters": [ + { + "type": "equal", + "field": "col1", + "value": "something", + }, + ], + }, + }, + ], + }, + }, + ), + ( + { + "transformations": [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "equals_today", + "field": "col1", + "value": "something", + }, + ], + }, + ], + }, + { + "error": { + "code": "invalid_request", + "message": "Invalid request", + "details": [ + { + 
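+                        # pydantic discriminated-union error: the filter "type" tag is not a known filter type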
"location": ["body", "transformations", 0, "dataframe_rows_filter", "filters", 0], + "message": ( + "Input tag 'equals_today' found using 'type' does not match any of the expected tags: 'equal', 'not_equal', " + "'greater_than', 'greater_or_equal', 'less_than', 'less_or_equal', 'like', 'ilike', 'not_like', 'not_ilike', 'regexp'" + ), + "code": "union_tag_invalid", + "context": { + "discriminator": "'type'", + "tag": "equals_today", + "expected_tags": "'equal', 'not_equal', 'greater_than', 'greater_or_equal', 'less_than', 'less_or_equal', 'like', 'ilike', 'not_like', 'not_ilike', 'regexp'", + }, + "input": { + "type": "equals_today", + "field": "col1", + "value": "something", + }, + }, + ], + }, + }, + ), ), ) async def test_check_fields_validation_on_create_transfer( @@ -414,6 +539,7 @@ async def test_check_fields_validation_on_create_transfer( "source_params": {"type": "postgres", "table_name": "source_table"}, "target_params": {"type": "postgres", "table_name": "target_table"}, "strategy_params": {"type": "full"}, + "transformations": [], "queue_id": group_queue.id, } transfer_data.update(new_data) @@ -584,7 +710,7 @@ async def test_developer_plus_cannot_create_transfer_with_other_group_queue( } -async def test_developer_plus_can_not_create_transfer_with_target_format_json( +async def test_developer_plus_cannot_create_transfer_with_target_format_json( client: AsyncClient, two_group_connections: tuple[MockConnection, MockConnection], session: AsyncSession, @@ -918,3 +1044,71 @@ async def test_superuser_cannot_create_transfer_with_unknown_queue_error( "details": None, }, } + + +async def test_superuser_cannot_create_transfer_with_duplicate_transformations( + client: AsyncClient, + two_group_connections: tuple[MockConnection, MockConnection], + session: AsyncSession, + superuser: MockUser, + group_queue: Queue, + mock_group: MockGroup, +): + # Arrange + first_connection, second_connection = two_group_connections + + # Act + result = await client.post( + "v1/transfers", + headers={"Authorization": f"Bearer {superuser.token}"}, + json={ + "group_id": mock_group.group.id, + "name": "new test transfer", + "description": "", + "is_scheduled": False, + "source_connection_id": first_connection.id, + "target_connection_id": second_connection.id, + "source_params": {"type": "postgres", "table_name": "source_table"}, + "target_params": {"type": "postgres", "table_name": "target_table"}, + "strategy_params": {"type": "full"}, + "transformations": [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "equal", + "field": "col1", + "value": "something", + }, + { + "type": "greater_than", + "field": "col2", + "value": "20", + }, + ], + }, + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "equal", + "field": "col1", + "value": "something_else", + }, + ], + }, + ], + "queue_id": group_queue.id, + }, + ) + + # Assert + result_json = result.json() + assert result.status_code == 422 + assert result_json["error"]["code"] == "invalid_request" + assert result_json["error"]["message"] == "Invalid request" + assert result_json["error"]["details"][0]["code"] == "value_error" + assert ( + result_json["error"]["details"][0]["message"] + == "Value error, Duplicate 'type' values found in transformations: dataframe_rows_filter" + ) diff --git a/tests/test_unit/test_transfers/test_file_transfers/test_create_transfer.py b/tests/test_unit/test_transfers/test_file_transfers/test_create_transfer.py index 38ac7590..2f16cf2b 100644 --- 
a/tests/test_unit/test_transfers/test_file_transfers/test_create_transfer.py +++ b/tests/test_unit/test_transfers/test_file_transfers/test_create_transfer.py @@ -146,6 +146,7 @@ async def test_developer_plus_can_create_s3_transfer( "source_params": transfer.source_params, "target_params": transfer.target_params, "strategy_params": transfer.strategy_params, + "transformations": transfer.transformations, "queue_id": transfer.queue_id, } @@ -304,6 +305,7 @@ async def test_developer_plus_can_create_hdfs_transfer( "source_params": transfer.source_params, "target_params": transfer.target_params, "strategy_params": transfer.strategy_params, + "transformations": transfer.transformations, "queue_id": transfer.queue_id, } diff --git a/tests/test_unit/test_transfers/test_file_transfers/test_read_transfer.py b/tests/test_unit/test_transfers/test_file_transfers/test_read_transfer.py index 412f90b6..53d2f544 100644 --- a/tests/test_unit/test_transfers/test_file_transfers/test_read_transfer.py +++ b/tests/test_unit/test_transfers/test_file_transfers/test_read_transfer.py @@ -10,58 +10,85 @@ "create_transfer_data", [ { - "type": "s3", - "directory_path": "/some/pure/path", - "file_format": { - "delimiter": ",", - "encoding": "utf-8", - "escape": "\\", - "include_header": False, - "line_sep": "\n", - "quote": '"', - "type": "csv", - "compression": "gzip", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/pure/path", + "file_format": { + "delimiter": ",", + "encoding": "utf-8", + "escape": "\\", + "include_header": False, + "line_sep": "\n", + "quote": '"', + "type": "csv", + "compression": "gzip", + }, + "options": {}, }, - "options": {}, + "transformations": [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "col1", + "value": "something", + }, + { + "type": "less_than", + "field": "col2", + "value": "20", + }, + ], + }, + ], }, { - "type": "s3", - "directory_path": "/some/excel/path", - "file_format": { - "type": "excel", - "include_header": True, - "start_cell": "A1", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/excel/path", + "file_format": { + "type": "excel", + "include_header": True, + "start_cell": "A1", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/xml/path", - "file_format": { - "type": "xml", - "root_tag": "data", - "row_tag": "record", - "compression": "bzip2", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/xml/path", + "file_format": { + "type": "xml", + "root_tag": "data", + "row_tag": "record", + "compression": "bzip2", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/orc/path", - "file_format": { - "type": "orc", - "compression": "zlib", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/orc/path", + "file_format": { + "type": "orc", + "compression": "zlib", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/parquet/path", - "file_format": { - "type": "parquet", - "compression": "lz4", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/parquet/path", + "file_format": { + "type": "parquet", + "compression": "lz4", + }, + "options": {}, }, - "options": {}, }, ], ) @@ -104,6 +131,7 @@ async def test_guest_plus_can_read_s3_transfer( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": 
group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, } assert result.status_code == 200 diff --git a/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py b/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py index f65a7f12..dd0aabdd 100644 --- a/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py +++ b/tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py @@ -10,58 +10,85 @@ "create_transfer_data", [ { - "type": "s3", - "directory_path": "/some/pure/path", - "file_format": { - "delimiter": ",", - "encoding": "utf-8", - "escape": "\\", - "include_header": False, - "line_sep": "\n", - "quote": '"', - "type": "csv", - "compression": "gzip", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/pure/path", + "file_format": { + "delimiter": ",", + "encoding": "utf-8", + "escape": "\\", + "include_header": False, + "line_sep": "\n", + "quote": '"', + "type": "csv", + "compression": "gzip", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/excel/path", - "file_format": { - "type": "excel", - "include_header": True, - "start_cell": "A1", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/excel/path", + "file_format": { + "type": "excel", + "include_header": True, + "start_cell": "A1", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/xml/path", - "file_format": { - "type": "xml", - "root_tag": "data", - "row_tag": "record", - "compression": "bzip2", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/xml/path", + "file_format": { + "type": "xml", + "root_tag": "data", + "row_tag": "record", + "compression": "bzip2", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/orc/path", - "file_format": { - "type": "orc", - "compression": "snappy", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/orc/path", + "file_format": { + "type": "orc", + "compression": "snappy", + }, + "options": {}, }, - "options": {}, }, { - "type": "s3", - "directory_path": "/some/parquet/path", - "file_format": { - "type": "parquet", - "compression": "snappy", + "source_and_target_params": { + "type": "s3", + "directory_path": "/some/parquet/path", + "file_format": { + "type": "parquet", + "compression": "snappy", + }, + "options": {}, }, - "options": {}, + "transformations": [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "greater_than", + "field": "col2", + "value": "30", + }, + { + "type": "like", + "field": "col1", + "value": "some%", + }, + ], + }, + ], }, ], ) @@ -87,6 +114,18 @@ async def test_developer_plus_can_update_s3_transfer( ): # Arrange user = group_transfer.owner_group.get_member_of_role(role_developer_plus) + transformations = [ + { + "type": "dataframe_rows_filter", + "filters": [ + { + "type": "not_equal", + "field": "col2", + "value": None, + }, + ], + }, + ] # Act result = await client.patch( @@ -96,9 +135,10 @@ async def test_developer_plus_can_update_s3_transfer( "source_params": { "type": "s3", "directory_path": "/some/new/test/directory", - "file_format": create_transfer_data["file_format"], + "file_format": create_transfer_data["source_and_target_params"]["file_format"], "options": {"some": "option"}, }, + "transformations": transformations, }, ) @@ -107,7 +147,7 @@ async def test_developer_plus_can_update_s3_transfer( source_params.update( { "directory_path": 
"/some/new/test/directory", - "file_format": create_transfer_data["file_format"], + "file_format": create_transfer_data["source_and_target_params"]["file_format"], "options": {"some": "option"}, }, ) @@ -126,5 +166,6 @@ async def test_developer_plus_can_update_s3_transfer( "source_params": source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": transformations, "queue_id": group_transfer.transfer.queue_id, } diff --git a/tests/test_unit/test_transfers/test_read_transfer.py b/tests/test_unit/test_transfers/test_read_transfer.py index b49bb001..f1886fe8 100644 --- a/tests/test_unit/test_transfers/test_read_transfer.py +++ b/tests/test_unit/test_transfers/test_read_transfer.py @@ -33,6 +33,7 @@ async def test_guest_plus_can_read_transfer( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, } assert result.status_code == 200 @@ -110,6 +111,7 @@ async def test_superuser_can_read_transfer( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, } assert result.status_code == 200 diff --git a/tests/test_unit/test_transfers/test_read_transfers.py b/tests/test_unit/test_transfers/test_read_transfers.py index 2108d0a1..9390c657 100644 --- a/tests/test_unit/test_transfers/test_read_transfers.py +++ b/tests/test_unit/test_transfers/test_read_transfers.py @@ -49,6 +49,7 @@ async def test_guest_plus_can_read_transfers( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, }, ], @@ -116,6 +117,7 @@ async def test_superuser_can_read_transfers( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, }, ], @@ -173,6 +175,7 @@ async def test_search_transfers_with_query( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, }, ], @@ -253,6 +256,7 @@ async def test_filter_transfers( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, }, ], @@ -352,6 +356,7 @@ async def test_filter_transfers_with_multiple_transfers( "source_params": t.source_params, "target_params": t.target_params, "strategy_params": t.strategy_params, + "transformations": t.transformations, "queue_id": t.queue_id, } for t in expected_transfers diff --git a/tests/test_unit/test_transfers/test_update_transfer.py b/tests/test_unit/test_transfers/test_update_transfer.py index 8ca196b6..1ced65c7 100644 --- a/tests/test_unit/test_transfers/test_update_transfer.py +++ b/tests/test_unit/test_transfers/test_update_transfer.py @@ -36,6 +36,7 @@ async def 
test_developer_plus_can_update_transfer( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, } @@ -89,6 +90,7 @@ async def test_superuser_can_update_transfer( "source_params": group_transfer.source_params, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, } @@ -206,6 +208,7 @@ async def test_check_connection_types_and_its_params_transfer( }, "target_params": group_transfer.target_params, "strategy_params": group_transfer.strategy_params, + "transformations": group_transfer.transformations, "queue_id": group_transfer.transfer.queue_id, } assert result.status_code == 200 diff --git a/tests/test_unit/test_transfers/transfer_fixtures/transfer_fixture.py b/tests/test_unit/test_transfers/transfer_fixtures/transfer_fixture.py index 084b5faa..4ac28bb0 100644 --- a/tests/test_unit/test_transfers/transfer_fixtures/transfer_fixture.py +++ b/tests/test_unit/test_transfers/transfer_fixtures/transfer_fixture.py @@ -109,8 +109,9 @@ async def group_transfer( source_connection_id=source_connection.id, target_connection_id=target_connection.id, queue_id=queue.id, - source_params=create_transfer_data, - target_params=create_transfer_data, + source_params=create_transfer_data.get("source_and_target_params") if create_transfer_data else None, + target_params=create_transfer_data.get("source_and_target_params") if create_transfer_data else None, + transformations=create_transfer_data.get("transformations") if create_transfer_data else None, ) yield MockTransfer( diff --git a/tests/test_unit/test_transfers/transfer_fixtures/transfer_with_user_role_fixtures.py b/tests/test_unit/test_transfers/transfer_fixtures/transfer_with_user_role_fixtures.py index 91ec5403..e11d53fd 100644 --- a/tests/test_unit/test_transfers/transfer_fixtures/transfer_with_user_role_fixtures.py +++ b/tests/test_unit/test_transfers/transfer_fixtures/transfer_with_user_role_fixtures.py @@ -1,4 +1,5 @@ import secrets +from collections.abc import AsyncGenerator import pytest_asyncio from sqlalchemy.ext.asyncio import AsyncSession @@ -38,7 +39,7 @@ async def group_transfer_with_same_name_maintainer_plus( group_transfer: MockTransfer, role_maintainer_plus: UserTestRoles, role_maintainer_or_below_without_guest: UserTestRoles, -) -> str: +) -> AsyncGenerator[str, None]: user = group_transfer.owner_group.get_member_of_role(role_maintainer_plus) await add_user_to_group( @@ -113,7 +114,7 @@ async def group_transfer_and_group_connection_developer_plus( role_developer_plus: UserTestRoles, role_maintainer_or_below_without_guest: UserTestRoles, settings: Settings, -) -> tuple[str, Connection]: +) -> AsyncGenerator[tuple[str, Connection], None]: user = group_transfer.owner_group.get_member_of_role(role_developer_plus) await add_user_to_group( diff --git a/tests/test_unit/utils.py b/tests/test_unit/utils.py index b8330f01..558bf0f4 100644 --- a/tests/test_unit/utils.py +++ b/tests/test_unit/utils.py @@ -180,6 +180,7 @@ async def create_transfer( group_id: int | None = None, source_params: dict | None = None, target_params: dict | None = None, + transformations: list | None = None, is_scheduled: bool = True, schedule: str = "* * * * *", strategy_params: dict | None = None, @@ -193,6 +194,7 @@ async def create_transfer( 
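For orientation, this is the nested create_transfer_data layout the reworked fixtures above expect; the values are borrowed from the parametrizations in this patch and the dict is illustrative only. The group_transfer fixture splits "source_and_target_params" into both source_params and target_params and forwards "transformations" as-is; leaving either key out falls back to None and, via the helper defaults, to an empty list for the NOT NULL JSON column.

create_transfer_data = {
    "source_and_target_params": {
        "type": "s3",
        "directory_path": "/some/pure/path",
        "file_format": {"type": "parquet", "compression": "lz4"},
        "options": {},
    },
    "transformations": [
        {
            "type": "dataframe_rows_filter",
            "filters": [{"type": "equal", "field": "col1", "value": "something"}],
        },
    ],
}
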
source_params=source_params or {"type": "postgres", "table_name": "table1"}, target_connection_id=target_connection_id, target_params=target_params or {"type": "postgres", "table_name": "table1"}, + transformations=transformations or [], is_scheduled=is_scheduled, schedule=schedule, strategy_params=strategy_params or {"type": "full"},
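
The worker handlers that consume the stored transformations are outside this excerpt. As a non-authoritative sketch only, one plausible way such filter entries could be compiled into a Spark SQL predicate and applied with DataFrame.where() is shown below; the function name is hypothetical, and the ILIKE/RLIKE spellings assume Spark >= 3.3.

_SQL_OPERATORS = {
    "equal": "=",
    "not_equal": "!=",
    "greater_than": ">",
    "greater_or_equal": ">=",
    "less_than": "<",
    "less_or_equal": "<=",
    "like": "LIKE",
    "ilike": "ILIKE",
    "not_like": "NOT LIKE",
    "not_ilike": "NOT ILIKE",
    "regexp": "RLIKE",
}


def filters_to_predicate(filters: list[dict]) -> str:
    clauses = []
    for item in filters:
        # Null comparisons need IS [NOT] NULL, e.g. the not_equal/None
        # filter used in the update test above.
        if item["value"] is None and item["type"] in ("equal", "not_equal"):
            suffix = "IS NULL" if item["type"] == "equal" else "IS NOT NULL"
            clauses.append(f"{item['field']} {suffix}")
            continue
        # Values travel as strings in the payload; escape single quotes.
        value = "'" + str(item["value"]).replace("'", "''") + "'"
        clauses.append(f"{item['field']} {_SQL_OPERATORS[item['type']]} {value}")
    return " AND ".join(clauses)


# Hypothetical application, once per transformation block:
# for transformation in transfer_dto.transformations or []:
#     if transformation["type"] == "dataframe_rows_filter":
#         df = df.where(filters_to_predicate(transformation["filters"]))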