From 11cdd97d2d2d7dea52a252a2292ad7309c668efa Mon Sep 17 00:00:00 2001 From: Shubham Raj <48172486+shubhamraj-git@users.noreply.github.com> Date: Sun, 29 Dec 2024 23:03:14 +0530 Subject: [PATCH] Import Variable API (#45265) * add the api * tests * ImportVariablesBody * action_if_exists * rename * add desc * fix * remove unwanted teardown --- .../core_api/datamodels/variables.py | 8 ++ .../core_api/openapi/v1-generated.yaml | 92 ++++++++++++++ .../core_api/routes/public/variables.py | 59 ++++++++- airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 44 +++++++ .../ui/openapi-gen/requests/schemas.gen.ts | 37 ++++++ .../ui/openapi-gen/requests/services.gen.ts | 32 +++++ airflow/ui/openapi-gen/requests/types.gen.ts | 51 ++++++++ .../core_api/routes/public/test_variables.py | 115 ++++++++++++++++++ 9 files changed, 439 insertions(+), 2 deletions(-) diff --git a/airflow/api_fastapi/core_api/datamodels/variables.py b/airflow/api_fastapi/core_api/datamodels/variables.py index ab40415ac3c2b..2e6f25993a55d 100644 --- a/airflow/api_fastapi/core_api/datamodels/variables.py +++ b/airflow/api_fastapi/core_api/datamodels/variables.py @@ -65,3 +65,11 @@ class VariableCollectionResponse(BaseModel): variables: list[VariableResponse] total_entries: int + + +class VariablesImportResponse(BaseModel): + """Import Variables serializer for responses.""" + + created_variable_keys: list[str] + import_count: int + created_count: int diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index aee6f8bacf1bd..a1177ec24096e 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -5857,6 +5857,68 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/variables/import: + post: + tags: + - Variable + summary: Import Variables + description: Import variables from a JSON file. + operationId: import_variables + parameters: + - name: action_if_exists + in: query + required: false + schema: + enum: + - overwrite + - fail + - skip + type: string + default: fail + title: Action If Exists + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_import_variables' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/VariablesImportResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unprocessable Entity /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{try_number}: get: tags: @@ -6435,6 +6497,16 @@ components: - status title: BaseInfoResponse description: Base info serializer for responses. + Body_import_variables: + properties: + file: + type: string + format: binary + title: File + type: object + required: + - file + title: Body_import_variables ClearTaskInstancesBody: properties: dry_run: @@ -9709,6 +9781,26 @@ components: - is_encrypted title: VariableResponse description: Variable serializer for responses. + VariablesImportResponse: + properties: + created_variable_keys: + items: + type: string + type: array + title: Created Variable Keys + import_count: + type: integer + title: Import Count + created_count: + type: integer + title: Created Count + type: object + required: + - created_variable_keys + - import_count + - created_count + title: VariablesImportResponse + description: Import Variables serializer for responses. VersionInfo: properties: version: diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index ccc8ee7dc2265..0f02dc14e03a4 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -16,9 +16,10 @@ # under the License. from __future__ import annotations -from typing import Annotated +import json +from typing import Annotated, Literal -from fastapi import Depends, HTTPException, Query, status +from fastapi import Depends, HTTPException, Query, UploadFile, status from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import select @@ -35,6 +36,7 @@ VariableBody, VariableCollectionResponse, VariableResponse, + VariablesImportResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.models.variable import Variable @@ -180,3 +182,56 @@ def post_variable( variable = session.scalar(select(Variable).where(Variable.key == post_body.key).limit(1)) return variable + + +@variables_router.post( + "/import", + status_code=status.HTTP_200_OK, + responses=create_openapi_http_exception_doc( + [status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT, status.HTTP_422_UNPROCESSABLE_ENTITY] + ), +) +def import_variables( + file: UploadFile, + session: SessionDep, + action_if_exists: Literal["overwrite", "fail", "skip"] = "fail", +) -> VariablesImportResponse: + """Import variables from a JSON file.""" + try: + file_content = file.file.read().decode("utf-8") + variables = json.loads(file_content) + + if not isinstance(variables, dict): + raise ValueError("Uploaded JSON must contain key-value pairs.") + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON format: {e}") + + if not variables: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No variables found in the provided JSON.", + ) + + existing_keys = {variable for variable in session.execute(select(Variable.key)).scalars()} + import_keys = set(variables.keys()) + + matched_keys = existing_keys & import_keys + + if action_if_exists == "fail" and matched_keys: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"The variables with these keys: {matched_keys} already exists.", + ) + elif action_if_exists == "skip": + create_keys = import_keys - matched_keys + else: + create_keys = import_keys + + for key in create_keys: + Variable.set(key=key, value=variables[key], session=session) + + return VariablesImportResponse( + created_count=len(create_keys), + import_count=len(import_keys), + created_variable_keys=list(create_keys), + ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 3a3dd229ca6ed..eef4fca5b8331 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1773,6 +1773,9 @@ export type PoolServicePostPoolsMutationResult = Awaited< export type VariableServicePostVariableMutationResult = Awaited< ReturnType >; +export type VariableServiceImportVariablesMutationResult = Awaited< + ReturnType +>; export type BackfillServicePauseBackfillMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index a6c3b69cd5604..ddec28ddd618e 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -37,6 +37,7 @@ import { } from "../requests/services.gen"; import { BackfillPostBody, + Body_import_variables, ClearTaskInstancesBody, ConnectionBody, ConnectionBulkBody, @@ -3354,6 +3355,49 @@ export const useVariableServicePostVariable = < }) as unknown as Promise, ...options, }); +/** + * Import Variables + * Import variables from a JSON file. + * @param data The data for the request. + * @param data.formData + * @param data.actionIfExists + * @returns VariablesImportResponse Successful Response + * @throws ApiError + */ +export const useVariableServiceImportVariables = < + TData = Common.VariableServiceImportVariablesMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + actionIfExists?: "overwrite" | "fail" | "skip"; + formData: Body_import_variables; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + actionIfExists?: "overwrite" | "fail" | "skip"; + formData: Body_import_variables; + }, + TContext + >({ + mutationFn: ({ actionIfExists, formData }) => + VariableService.importVariables({ + actionIfExists, + formData, + }) as unknown as Promise, + ...options, + }); /** * Pause Backfill * @param data The data for the request. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 630698f3cab34..6d87d20de1016 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -488,6 +488,19 @@ export const $BaseInfoResponse = { description: "Base info serializer for responses.", } as const; +export const $Body_import_variables = { + properties: { + file: { + type: "string", + format: "binary", + title: "File", + }, + }, + type: "object", + required: ["file"], + title: "Body_import_variables", +} as const; + export const $ClearTaskInstancesBody = { properties: { dry_run: { @@ -5573,6 +5586,30 @@ export const $VariableResponse = { description: "Variable serializer for responses.", } as const; +export const $VariablesImportResponse = { + properties: { + created_variable_keys: { + items: { + type: "string", + }, + type: "array", + title: "Created Variable Keys", + }, + import_count: { + type: "integer", + title: "Import Count", + }, + created_count: { + type: "integer", + title: "Created Count", + }, + }, + type: "object", + required: ["created_variable_keys", "import_count", "created_count"], + title: "VariablesImportResponse", + description: "Import Variables serializer for responses.", +} as const; + export const $VersionInfo = { properties: { version: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 10dc8dc00a0a0..9100a65d0d185 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -182,6 +182,8 @@ import type { GetVariablesResponse, PostVariableData, PostVariableResponse, + ImportVariablesData, + ImportVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetHealthResponse, @@ -3154,6 +3156,36 @@ export class VariableService { }, }); } + + /** + * Import Variables + * Import variables from a JSON file. + * @param data The data for the request. + * @param data.formData + * @param data.actionIfExists + * @returns VariablesImportResponse Successful Response + * @throws ApiError + */ + public static importVariables( + data: ImportVariablesData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/variables/import", + query: { + action_if_exists: data.actionIfExists, + }, + formData: data.formData, + mediaType: "multipart/form-data", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 409: "Conflict", + 422: "Unprocessable Entity", + }, + }); + } } export class DagParsingService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 31c1d21514d94..36ecc53863f19 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -138,6 +138,10 @@ export type BaseInfoResponse = { status: string | null; }; +export type Body_import_variables = { + file: Blob | File; +}; + /** * Request body for Clear Task Instances endpoint. */ @@ -1305,6 +1309,15 @@ export type VariableResponse = { is_encrypted: boolean; }; +/** + * Import Variables serializer for responses. + */ +export type VariablesImportResponse = { + created_variable_keys: Array; + import_count: number; + created_count: number; +}; + /** * Version information serializer for responses. */ @@ -2160,6 +2173,13 @@ export type PostVariableData = { export type PostVariableResponse = VariableResponse; +export type ImportVariablesData = { + actionIfExists?: "overwrite" | "fail" | "skip"; + formData: Body_import_variables; +}; + +export type ImportVariablesResponse = VariablesImportResponse; + export type ReparseDagFileData = { fileToken: string; }; @@ -4574,6 +4594,37 @@ export type $OpenApiTs = { }; }; }; + "/public/variables/import": { + post: { + req: ImportVariablesData; + res: { + /** + * Successful Response + */ + 200: VariablesImportResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Unprocessable Entity + */ + 422: HTTPExceptionResponse; + }; + }; + }; "/public/parseDagFile/{file_token}": { put: { req: ReparseDagFileData; diff --git a/tests/api_fastapi/core_api/routes/public/test_variables.py b/tests/api_fastapi/core_api/routes/public/test_variables.py index 5371a3479053f..6f780e9f0cedd 100644 --- a/tests/api_fastapi/core_api/routes/public/test_variables.py +++ b/tests/api_fastapi/core_api/routes/public/test_variables.py @@ -16,6 +16,9 @@ # under the License. from __future__ import annotations +import json +from io import BytesIO + import pytest from airflow.models.variable import Variable @@ -45,6 +48,11 @@ TEST_VARIABLE_SEARCH_DESCRIPTION = "Some description for the variable" +# Helper function to simulate file upload +def create_file_upload(content: dict) -> BytesIO: + return BytesIO(json.dumps(content).encode("utf-8")) + + @provide_session def _create_variables(session) -> None: Variable.set( @@ -391,3 +399,110 @@ def test_post_should_respond_422_when_key_too_large(self, test_client): } ] } + + +class TestImportVariables(TestVariableEndpoint): + @pytest.mark.enable_redact + @pytest.mark.parametrize( + "variables_data, behavior, expected_status_code, expected_created_count, expected_created_keys, expected_conflict_keys", + [ + ( + {"new_key1": "new_value1", "new_key2": "new_value2"}, + "overwrite", + 200, + 2, + {"new_key1", "new_key2"}, + set(), + ), + ( + {"new_key1": "new_value1", "new_key2": "new_value2"}, + "skip", + 200, + 2, + {"new_key1", "new_key2"}, + set(), + ), + ( + {"test_variable_key": "new_value", "new_key": "new_value"}, + "fail", + 409, + 0, + set(), + {"test_variable_key"}, + ), + ( + {"test_variable_key": "new_value", "new_key": "new_value"}, + "skip", + 200, + 1, + {"new_key"}, + {"test_variable_key"}, + ), + ( + {"test_variable_key": "new_value", "new_key": "new_value"}, + "overwrite", + 200, + 2, + {"test_variable_key", "new_key"}, + set(), + ), + ], + ) + def test_import_variables( + self, + test_client, + variables_data, + behavior, + expected_status_code, + expected_created_count, + expected_created_keys, + expected_conflict_keys, + session, + ): + """Test variable import with different behaviors (overwrite, fail, skip).""" + + self.create_variables() + + file = create_file_upload(variables_data) + response = test_client.post( + "/public/variables/import", + files={"file": ("variables.json", file, "application/json")}, + params={"action_if_exists": behavior}, + ) + + assert response.status_code == expected_status_code + + if expected_status_code == 200: + body = response.json() + assert body["created_count"] == expected_created_count + assert set(body["created_variable_keys"]) == expected_created_keys + + elif expected_status_code == 409: + body = response.json() + assert ( + f"The variables with these keys: {expected_conflict_keys} already exists." == body["detail"] + ) + + def test_import_invalid_json(self, test_client): + """Test invalid JSON import.""" + file = BytesIO(b"import variable test") + response = test_client.post( + "/public/variables/import", + files={"file": ("variables.json", file, "application/json")}, + params={"action_if_exists": "overwrite"}, + ) + + assert response.status_code == 400 + assert "Invalid JSON format" in response.json()["detail"] + + def test_import_empty_file(self, test_client): + """Test empty file import.""" + file = create_file_upload({}) + response = test_client.post( + "/public/variables/import", + files={"file": ("empty_variables.json", file, "application/json")}, + params={"action_if_exists": "overwrite"}, + ) + + assert response.status_code == 422 + assert response.json()["detail"] == "No variables found in the provided JSON."