diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 802a134765d98..3ad29c66f05cf 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -129,7 +129,8 @@ def __init__( def get_conn(self) -> BigQueryConnection: """Get a BigQuery PEP 249 connection object.""" - service = self.get_service() + http_authorized = self._authorize() + service = build("bigquery", "v2", http=http_authorized, cache_discovery=False) return BigQueryConnection( service=service, project_id=self.project_id, @@ -2775,7 +2776,7 @@ def execute(self, operation: str, parameters: dict | None = None) -> None: """ sql = _bind_parameters(operation, parameters) if parameters else operation self.flush_results() - self.job_id = self.hook.run_query(sql) + self.job_id = self._run_query(sql) query_results = self._get_query_result() if "schema" in query_results: @@ -2913,6 +2914,171 @@ def _get_query_result(self) -> dict: return query_results + def _run_query( + self, + sql, + location: str | None = None, + ) -> str: + """Run job query.""" + if not self.project_id: + raise ValueError("The project_id should be set") + + configuration = self._prepare_query_configuration(sql) + job = self.hook.insert_job(configuration=configuration, project_id=self.project_id, location=location) + + return job.job_id + + def _prepare_query_configuration( + self, + sql, + destination_dataset_table: str | None = None, + write_disposition: str = "WRITE_EMPTY", + allow_large_results: bool = False, + flatten_results: bool | None = None, + udf_config: list | None = None, + use_legacy_sql: bool | None = None, + maximum_billing_tier: int | None = None, + maximum_bytes_billed: float | None = None, + create_disposition: str = "CREATE_IF_NEEDED", + query_params: list | None = None, + labels: dict | None = None, + schema_update_options: Iterable | None = None, + priority: str | None = None, + time_partitioning: dict | None = None, + api_resource_configs: dict | None = None, + cluster_fields: list[str] | None = None, + encryption_configuration: dict | None = None, + ): + """Helper method that prepare configuration for query.""" + labels = labels or self.hook.labels + schema_update_options = list(schema_update_options or []) + + priority = priority or self.hook.priority + + if time_partitioning is None: + time_partitioning = {} + + if not api_resource_configs: + api_resource_configs = self.hook.api_resource_configs + else: + _validate_value("api_resource_configs", api_resource_configs, dict) + + configuration = deepcopy(api_resource_configs) + + if "query" not in configuration: + configuration["query"] = {} + else: + _validate_value("api_resource_configs['query']", configuration["query"], dict) + + if sql is None and not configuration["query"].get("query", None): + raise TypeError("`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`") + + # BigQuery also allows you to define how you want a table's schema to change + # as a side effect of a query job + # for more details: + # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions + + allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"] + + if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): + raise ValueError( + f"{schema_update_options} contains invalid schema update options." + f" Please only use one or more of the following options: {allowed_schema_update_options}" + ) + + if schema_update_options: + if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: + raise ValueError( + "schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'." + ) + + if destination_dataset_table: + destination_project, destination_dataset, destination_table = self.hook.split_tablename( + table_input=destination_dataset_table, default_project_id=self.project_id + ) + + destination_dataset_table = { # type: ignore + "projectId": destination_project, + "datasetId": destination_dataset, + "tableId": destination_table, + } + + if cluster_fields: + cluster_fields = {"fields": cluster_fields} # type: ignore + + query_param_list: list[tuple[Any, str, str | bool | None | dict, type | tuple[type]]] = [ + (sql, "query", None, (str,)), + (priority, "priority", priority, (str,)), + (use_legacy_sql, "useLegacySql", self.use_legacy_sql, bool), + (query_params, "queryParameters", None, list), + (udf_config, "userDefinedFunctionResources", None, list), + (maximum_billing_tier, "maximumBillingTier", None, int), + (maximum_bytes_billed, "maximumBytesBilled", None, float), + (time_partitioning, "timePartitioning", {}, dict), + (schema_update_options, "schemaUpdateOptions", None, list), + (destination_dataset_table, "destinationTable", None, dict), + (cluster_fields, "clustering", None, dict), + ] + + for param, param_name, param_default, param_type in query_param_list: + if param_name not in configuration["query"] and param in [None, {}, ()]: + if param_name == "timePartitioning": + param_default = _cleanse_time_partitioning(destination_dataset_table, time_partitioning) + param = param_default + + if param in [None, {}, ()]: + continue + + _api_resource_configs_duplication_check(param_name, param, configuration["query"]) + + configuration["query"][param_name] = param + + # check valid type of provided param, + # it last step because we can get param from 2 sources, + # and first of all need to find it + + _validate_value(param_name, configuration["query"][param_name], param_type) + + if param_name == "schemaUpdateOptions" and param: + self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) + + if param_name == "destinationTable": + for key in ["projectId", "datasetId", "tableId"]: + if key not in configuration["query"]["destinationTable"]: + raise ValueError( + "Not correct 'destinationTable' in " + "api_resource_configs. 'destinationTable' " + "must be a dict with {'projectId':'', " + "'datasetId':'', 'tableId':''}" + ) + else: + configuration["query"].update( + { + "allowLargeResults": allow_large_results, + "flattenResults": flatten_results, + "writeDisposition": write_disposition, + "createDisposition": create_disposition, + } + ) + + if ( + "useLegacySql" in configuration["query"] + and configuration["query"]["useLegacySql"] + and "queryParameters" in configuration["query"] + ): + raise ValueError("Query parameters are not allowed when using legacy SQL") + + if labels: + _api_resource_configs_duplication_check("labels", labels, configuration) + configuration["labels"] = labels + + if encryption_configuration: + configuration["query"]["destinationEncryptionConfiguration"] = encryption_configuration + + return configuration + def _bind_parameters(operation: str, parameters: dict) -> str: """Helper method that binds parameters to a SQL query.""" diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 238be83badc56..47fc20464749e 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -1208,7 +1208,7 @@ def test_create_materialized_view(self, mock_bq_client, mock_table): @pytest.mark.db_test class TestBigQueryCursor(_BigQueryBaseTestClass): - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_execute_with_parameters(self, mock_insert, _): bq_cursor = self.hook.get_cursor() @@ -1223,7 +1223,7 @@ def test_execute_with_parameters(self, mock_insert, _): } mock_insert.assert_called_once_with(configuration=conf, project_id=PROJECT_ID, location=None) - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_execute_many(self, mock_insert, _): bq_cursor = self.hook.get_cursor() @@ -1275,10 +1275,10 @@ def test_format_schema_for_description(self): ("field_3", "STRING", None, None, None, None, False), ] - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") - def test_description(self, mock_insert, mock_get_service): - mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults + def test_description(self, mock_insert, mock_build): + mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults mock_execute = mock_get_query_results.return_value.execute mock_execute.return_value = { "schema": { @@ -1292,10 +1292,10 @@ def test_description(self, mock_insert, mock_get_service): bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts") assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)] - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") - def test_description_no_schema(self, mock_insert, mock_get_service): - mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults + def test_description_no_schema(self, mock_insert, mock_build): + mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults mock_execute = mock_get_query_results.return_value.execute mock_execute.return_value = {} @@ -1369,9 +1369,9 @@ def test_next_buffer(self, mock_get_service): result = bq_cursor.next() assert result is None - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") - def test_next(self, mock_get_service): - mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") + def test_next(self, mock_build): + mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults mock_execute = mock_get_query_results.return_value.execute mock_execute.return_value = { "rows": [ @@ -1402,10 +1402,10 @@ def test_next(self, mock_get_service): ) mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries) - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.flush_results") - def test_next_no_rows(self, mock_flush_results, mock_get_service): - mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults + def test_next_no_rows(self, mock_flush_results, mock_build): + mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults mock_execute = mock_get_query_results.return_value.execute mock_execute.return_value = {} @@ -1421,10 +1421,10 @@ def test_next_no_rows(self, mock_flush_results, mock_get_service): mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries) assert mock_flush_results.call_count == 1 - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.flush_results") - def test_flush_cursor_in_execute(self, _, mock_insert, mock_get_service): + def test_flush_cursor_in_execute(self, _, mock_insert, mock_build): bq_cursor = self.hook.get_cursor() bq_cursor.execute("SELECT %(foo)s", {"foo": "bar"}) assert mock_insert.call_count == 1 @@ -1786,7 +1786,7 @@ def test_run_query_with_arg(self, mock_insert): class TestBigQueryHookLegacySql(_BigQueryBaseTestClass): """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly.""" - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_hook_uses_legacy_sql_by_default(self, mock_insert, _): self.hook.get_first("query") @@ -1797,10 +1797,10 @@ def test_hook_uses_legacy_sql_by_default(self, mock_insert, _): "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id", return_value=(CREDENTIALS, PROJECT_ID), ) - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job") def test_legacy_sql_override_propagates_properly( - self, mock_insert, mock_get_service, mock_get_creds_and_proj_id + self, mock_insert, mock_build, mock_get_creds_and_proj_id ): bq_hook = BigQueryHook(use_legacy_sql=False) bq_hook.get_first("query")