fix(data-warehouse): Trim whitespace in job inputs (#27235)
Gilbert09 authored Jan 6, 2025
1 parent e809dc3 commit c4317b0
Showing 4 changed files with 88 additions and 0 deletions.
18 changes: 18 additions & 0 deletions posthog/temporal/data_imports/workflow_activities/import_data_sync.py
@@ -51,6 +51,21 @@ def process_incremental_last_value(value: Any | None, field_type: IncrementalFie
return parser.parse(value).date()


def _trim_source_job_inputs(source: ExternalDataSource) -> None:
if not source.job_inputs:
return

did_update_inputs = False
for key, value in source.job_inputs.items():
if isinstance(value, str):
if value.startswith(" ") or value.endswith(" "):
source.job_inputs[key] = value.strip()
did_update_inputs = True

if did_update_inputs:
source.save()


@activity.defn
def import_data_activity_sync(inputs: ImportDataActivityInputs):
logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id)
@@ -73,6 +88,8 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
dataset_name=model.folder_path(),
)

_trim_source_job_inputs(model.pipeline)

reset_pipeline = model.pipeline.job_inputs.get("reset_pipeline", "False") == "True"

schema = (
@@ -526,4 +543,5 @@ def _run(

source = ExternalDataSource.objects.get(id=inputs.source_id)
source.job_inputs.pop("reset_pipeline", None)

source.save()
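
As a standalone illustration (not part of the diff), here is a minimal sketch of the behaviour the new _trim_source_job_inputs helper adds: only string values with leading or trailing spaces are stripped, non-string values are left alone, and the source is saved only when something actually changed. The job_inputs values below are hypothetical.

    # Hypothetical job_inputs as they might be stored on an ExternalDataSource
    job_inputs = {
        "host": " host.com ",  # leading/trailing spaces -> stripped
        "port": 5432,          # non-string value -> left untouched
        "schema": "schema",    # already clean -> unchanged
    }

    did_update_inputs = False
    for key, value in job_inputs.items():
        if isinstance(value, str) and (value.startswith(" ") or value.endswith(" ")):
            job_inputs[key] = value.strip()
            did_update_inputs = True

    # did_update_inputs is True here, so the real helper would call source.save()
    print(job_inputs)  # {'host': 'host.com', 'port': 5432, 'schema': 'schema'}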
39 changes: 39 additions & 0 deletions posthog/temporal/tests/batch_exports/test_import_data.py
@@ -54,6 +54,45 @@ def _setup(team: Team, job_inputs: dict[Any, Any]) -> ImportDataActivityInputs:
return ImportDataActivityInputs(team_id=team.pk, schema_id=schema.pk, source_id=source.pk, run_id=str(job.pk))


@pytest.mark.django_db(transaction=True)
def test_job_inputs_with_whitespace(activity_environment, team, **kwargs):
job_inputs = {
"host": " host.com ",
"port": 5432,
"user": "Username ",
"password": " password",
"database": " database",
"schema": "schema ",
}

activity_inputs = _setup(team, job_inputs)

with (
mock.patch(
"posthog.temporal.data_imports.pipelines.sql_database_v2.sql_source_for_type"
) as sql_source_for_type,
mock.patch("posthog.temporal.data_imports.workflow_activities.import_data_sync._run"),
):
activity_environment.run(import_data_activity_sync, activity_inputs)

sql_source_for_type.assert_called_once_with(
source_type=ExternalDataSource.Type.POSTGRES,
host="host.com",
port="5432",
user="Username",
password="password",
database="database",
sslmode="prefer",
schema="schema",
table_names=["table_1"],
incremental_field=None,
incremental_field_type=None,
db_incremental_field_last_value=None,
team_id=team.id,
using_ssl=True,
)


@pytest.mark.django_db(transaction=True)
def test_postgres_source_without_ssh_tunnel(activity_environment, team, **kwargs):
job_inputs = {
7 changes: 7 additions & 0 deletions posthog/warehouse/api/external_data_source.py
@@ -299,6 +299,13 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
data={"message": "Monthly sync limit reached. Please increase your billing limit to resume syncing."},
)

# Strip leading and trailing whitespace
payload = request.data["payload"]
if payload is not None:
for key, value in payload.items():
if isinstance(value, str):
payload[key] = value.strip()

# TODO: remove dummy vars
if source_type == ExternalDataSource.Type.STRIPE:
new_source_model = self._handle_stripe_source(request, *args, **kwargs)
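
For the API-side change, a standalone sketch (hypothetical payload values, outside the view code) of the same stripping applied to the incoming request payload before the source is created; nested values such as the schemas list are not strings, so they pass through unchanged.

    # Hypothetical payload as it might arrive from the source-creation form
    payload = {
        "client_secret": " sk_test_123 ",
        "account_id": " blah ",
        "schemas": [{"name": "BalanceTransaction", "should_sync": True}],
    }

    if payload is not None:
        for key, value in payload.items():
            if isinstance(value, str):
                payload[key] = value.strip()

    print(payload["client_secret"])  # sk_test_123
    print(payload["account_id"])     # blah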
24 changes: 24 additions & 0 deletions posthog/warehouse/api/test/test_external_data_source.py
@@ -810,3 +810,27 @@ def test_source_jobs_pagination(self):
assert response.status_code, status.HTTP_200_OK
assert len(data) == 1
assert data[0]["id"] == str(job3.pk)

def test_trimming_payload(self):
response = self.client.post(
f"/api/projects/{self.team.pk}/external_data_sources/",
data={
"source_type": "Stripe",
"payload": {
"client_secret": " sk_test_123 ",
"account_id": " blah ",
"schemas": [
{"name": "BalanceTransaction", "should_sync": True, "sync_type": "full_refresh"},
],
},
},
)
payload = response.json()

assert response.status_code == 201

source = ExternalDataSource.objects.get(id=payload["id"])
assert source.job_inputs is not None

assert source.job_inputs["stripe_secret_key"] == "sk_test_123"
assert source.job_inputs["stripe_account_id"] == "blah"
