From a6dff8fc1f4566880b06910d0b084b32be1e29eb Mon Sep 17 00:00:00 2001 From: Mike Moore Date: Mon, 1 Jan 2024 13:22:05 +0000 Subject: [PATCH 1/2] feat: start to add typing to methods --- src/bqtools/__init__.py | 6074 ++++++++++++++++++++------------------- 1 file changed, 3047 insertions(+), 3027 deletions(-) diff --git a/src/bqtools/__init__.py b/src/bqtools/__init__.py index f7adac2..34b2e07 100644 --- a/src/bqtools/__init__.py +++ b/src/bqtools/__init__.py @@ -31,12 +31,17 @@ from time import sleep import requests + # handle python 2 and 3 versions of this import six from google.cloud import bigquery, exceptions, storage from googleapiclient import discovery from jinja2 import Environment, select_autoescape, FileSystemLoader, Template import queue +from google.cloud.bigquery.client import Client +from google.cloud.bigquery.dataset import Dataset, DatasetReference +from google.cloud.bigquery.job.query import QueryJob +from typing import Any, Callable, List, Optional, Union __version__ = importlib.metadata.version("bqtools-json") @@ -357,10 +362,14 @@ def __init__(self, objtomatch, keyi, path): class BQQueryError(BQTError): """Error for Query execution error.""" - CUSTOM_ERROR_MESSAGE = "GCP API Error: unable to processquery {0} from GCP {1}:\n{2}" + CUSTOM_ERROR_MESSAGE = ( + "GCP API Error: unable to processquery {0} from GCP {1}:\n{2}" + ) def __init__(self, query, desc, e): - super(BQQueryError, self).__init__(self.CUSTOM_ERROR_MESSAGE.format(query, desc, e)) + super(BQQueryError, self).__init__( + self.CUSTOM_ERROR_MESSAGE.format(query, desc, e) + ) class BQHittingQueryGenerationQuotaLimit(BQTError): @@ -374,3335 +383,3337 @@ def __init__(self, query, desc, e): ) -class BQSyncTask(object): - def __init__(self, function, args): - assert callable(function), "Tasks must be constructed with a function" - assert isinstance(args, list), "Must have arguments" - self.__function = function - self.__args = args - - @property - def function(self): - return self.__function - - @property - def args(self): - return self.__args +class DefaultBQSyncDriver(object): + """This class provides mechanical input to bqsync functions""" - def __eq__(self, other): - return self.args == other.args + threadLocal = threading.local() - def __lt__(self, other): - return self.args < other.args + def __init__( + self, + srcproject: str, + srcdataset: str, + dstdataset: str, + dstproject: Optional[str] = None, + srcbucket: Optional[str] = None, + dstbucket: Optional[str] = None, + remove_deleted_tables: bool = True, + copy_data: bool = True, + copy_types: Optional[List[str]] = None, + check_depth: int = -1, + copy_access: bool = True, + table_view_filter: Optional[List[str]] = None, + table_or_views_to_exclude: Optional[List[str]] = None, + latest_date: Optional[datetime] = None, + days_before_latest_day: Optional[int] = None, + day_partition_deep_check: bool = False, + analysis_project: str = None, + query_cmek: Optional[List[str]] = None, + src_policy_tags: Optional[List[str]] = [], + dst_policy_tags: Optional[List[str]] = [], + ) -> None: + """ + Constructor for base copy driver all other drivers should inherit from this + :param srcproject: The project that is the source for the copy (note all actions are done + inc ontext of source project) + :param srcdataset: The source dataset + :param dstdataset: The destination dataset + :param dstproject: The source project if None assumed to be source project + :param srcbucket: The source bucket when copying cross region data is extracted to this + bucket rewritten to 
destination bucket + :param dstbucket: The destination bucket where data is loaded from + :param remove_deleted_tables: If table exists in destination but not in source should it + be deleted + :param copy_data: Copy data or just do schema + :param copy_types: Copy object types i.e. TABLE,VIEW,ROUTINE,MODEL + """ + assert query_cmek is None or ( + isinstance(query_cmek, list) and len(query_cmek) == 2 + ), ( + "If cmek key is specified has to be a list and MUST be 2 keys with " + "" + "1st key being source location key and destination key" + ) + self._sessionid = datetime.utcnow().isoformat().replace(":", "-") + if copy_types is None: + copy_types = ["TABLE", "VIEW", "ROUTINE", "MODEL", "MATERIALIZEDVIEW"] + if table_view_filter is None: + table_view_filter = [".*"] + if table_or_views_to_exclude is None: + table_or_views_to_exclude = [] + if dstproject is None: + dstproject = srcproject - def __gt__(self, other): - return self.args > other.args + self._remove_deleted_tables = remove_deleted_tables + # check copy makes some basic sense + assert srcproject != dstproject or srcdataset != dstdataset, ( + "Source and destination " "datasets cannot be the same" + ) + assert latest_date is None or isinstance(latest_date, datetime) -def get_json_struct(jsonobj, template=None): - """ + self._source_project = srcproject + self._source_dataset = srcdataset + self._destination_project = dstproject + self._destination_dataset = dstdataset + self._copy_data = copy_data + self._http = None + self.__copy_q = None + self.__schema_q = None + self.__jobs = [] + self.__copy_types = copy_types + self.reset_stats() + self.__logger = logging + self.__check_depth = check_depth + self.__copy_access = copy_access + self.__table_view_filter = table_view_filter + self.__table_or_views_to_exclude = table_or_views_to_exclude + self.__re_table_view_filter = [] + self.__re_table_or_views_to_exclude = [] + self.__base_predicates = [] + self.__day_partition_deep_check = day_partition_deep_check + self.__analysisproject = self._destination_project + if analysis_project is not None: + self.__analysisproject = analysis_project - :param jsonobj: Object to parse and adjust so could be loaded into big query - :param template: An input object to use as abasis as a template defaullt no template provided - :return: A json object that is a template object. 
This can be used as input to - get_bq_schema_from_json_repr - """ - if template is None: - template = {} - for key in jsonobj: - newkey = INVALIDBQFIELDCHARS.sub("_", key) - if jsonobj[key] is None: - continue - if newkey not in template: - value = None - if isinstance(jsonobj[key], bool): - value = False - elif isinstance(jsonobj[key], six.string_types): - value = "" - elif isinstance(jsonobj[key], six.text_type): - value = "" - elif isinstance(jsonobj[key], int): - value = 0 - elif isinstance(jsonobj[key], float): - value = 0.0 - elif isinstance(jsonobj[key], date): - value = jsonobj[key] - elif isinstance(jsonobj[key], datetime): - value = jsonobj[key] - elif isinstance(jsonobj[key], dict): - value = get_json_struct(jsonobj[key]) - elif isinstance(jsonobj[key], list): - value = [{}] - if len(jsonobj[key]) > 0: - if not isinstance(jsonobj[key][0], dict): - new_value = [] - for vali in jsonobj[key]: - new_value.append({"value": vali}) - jsonobj[key] = new_value - for list_item in jsonobj[key]: - value[0] = get_json_struct(list_item, value[0]) - else: - raise UnexpectedType(str(jsonobj[key])) - template[newkey] = value - else: - if isinstance(jsonobj[key], type(template[newkey])): - if isinstance(jsonobj[key], dict): - template[key] = get_json_struct(jsonobj[key], template[newkey]) - if isinstance(jsonobj[key], list): - if len(jsonobj[key]) != 0: - if not isinstance(jsonobj[key][0], dict): - new_value = [] - for vali in jsonobj[key]: - new_value.append({"value": vali}) - jsonobj[key] = new_value - for list_item in jsonobj[key]: - template[newkey][0] = get_json_struct( - list_item, template[newkey][0] - ) + if days_before_latest_day is not None: + if latest_date is None: + end_date = "TIMESTAMP_ADD(TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(),DAY),INTERVAL 1 DAY)" else: - # work out best way to loosen types with worst case change to string - newtype = "" - if isinstance(jsonobj[key], float) and isinstance( - template[newkey], int - ): - newtype = 0.0 - elif isinstance(jsonobj[key], datetime) and isinstance( - template[newkey], date - ): - newtype = jsonobj[key] - if not ( - isinstance(jsonobj[key], dict) or isinstance(jsonobj[key], list) - ) and not ( - isinstance(template[newkey], list) - or isinstance(template[newkey], dict) - ): - template[newkey] = newtype - else: - # this is so different type cannot be loosened - raise InconsistentJSONStructure( - key, str(jsonobj[key]), str(template[newkey]) - ) - return template - + end_date = "TIMESTAMP('{}')".format(latest_date.strftime("%Y-%m-%d")) + self.__base_predicates.append( + "{{retpartition}} BETWEEN TIMESTAMP_SUB({end_date}, INTERVAL {" + "days_before_latest_day} * 24 HOUR) AND {end_date}".format( + end_date=end_date, days_before_latest_day=days_before_latest_day + ) + ) -def clean_json_for_bq(anobject): - """ + # now check that from a service the copy makes sense + assert dataset_exists( + self.source_client, self.source_client.dataset(self.source_dataset) + ), ("Source dataset does not exist %r" % self.source_dataset) + assert dataset_exists( + self.destination_client, + self.destination_client.dataset(self.destination_dataset), + ), ( + "Destination dataset does not " "exists %r" % self.destination_dataset + ) - :param anobject: to be converted to big query json compatible format: - :return: cleaned object - """ - newobj = {} - if not isinstance(anobject, dict): - raise NotADictionary(str(anobject)) - for key in anobject: - newkey = INVALIDBQFIELDCHARS.sub("_", key) + # figure out if cross region copy if within region copies optimised to happen 
in big query + # if cross region buckets need to exist to support copy and they need to be in same region + source_dataset_impl = self.source_dataset_impl + destination_dataset_impl = self.destination_dataset_impl + if query_cmek is None: + query_cmek = [] + if destination_dataset_impl.default_encryption_configuration is not None: + query_cmek.append( + destination_dataset_impl.default_encryption_configuration.kms_key_name + ) + else: + query_cmek.append(None) - value = anobject[key] - if isinstance(value, dict): - value = clean_json_for_bq(value) - if isinstance(value, list): - if len(value) != 0: - if not isinstance(value[0], dict): - new_value = [] - for vali in value: - new_value.append({"value": vali}) - value = new_value - valllist = [] - for vali in value: - vali = clean_json_for_bq(vali) - valllist.append(vali) - value = valllist - newobj[newkey] = value - return newobj + if ( + not query_cmek + and source_dataset_impl.default_encryption_configuration is not None + ): + query_cmek.append( + source_dataset_impl.default_encryption_configuration.kms_key_name + ) + else: + query_cmek.append(None) + self._query_cmek = query_cmek + self._same_region = ( + source_dataset_impl.location == destination_dataset_impl.location + ) + self._source_location = source_dataset_impl.location + self._destination_location = destination_dataset_impl.location + self._src_policy_tags = src_policy_tags + self._dst_policy_tags = dst_policy_tags -def get_bq_schema_from_json_repr(jsondict): - """ - Generate fields structure of Big query resource if the input json structure is vallid - :param jsondict: a template object in json format to use as basis to create a big query - schema object from - :return: a big query schema - """ - fields = [] - for key, data in list(jsondict.items()): - field = {"name": key} - if isinstance(data, bool): - field["type"] = "BOOLEAN" - field["mode"] = "NULLABLE" - elif isinstance(data, six.string_types): - field["type"] = "STRING" - field["mode"] = "NULLABLE" - elif isinstance(data, six.text_type): - field["type"] = "STRING" - field["mode"] = "NULLABLE" - elif isinstance(data, int): - field["type"] = "INTEGER" - field["mode"] = "NULLABLE" - elif isinstance(data, float): - field["type"] = "FLOAT" - field["mode"] = "NULLABLE" - elif isinstance(data, datetime): - field["type"] = "TIMESTAMP" - field["mode"] = "NULLABLE" - elif isinstance(data, date): - field["type"] = "DATE" - field["mode"] = "NULLABLE" - elif isinstance(data, dttime): - field["type"] = "TIME" - field["mode"] = "NULLABLE" - elif isinstance(data, six.binary_type): - field["type"] = "BYTES" - field["mode"] = "NULLABLE" - elif isinstance(data, dict): - field["type"] = "RECORD" - field["mode"] = "NULLABLE" - field["fields"] = get_bq_schema_from_json_repr(data) - elif isinstance(data, list): - field["type"] = "RECORD" - field["mode"] = "REPEATED" - field["fields"] = get_bq_schema_from_json_repr(data[0]) - fields.append(field) - return fields - - -def generate_create_schema(resourcelist, file_handle): - """ - Generates using a jinja template bash command using bq to for a set of schemas - supports views, tables or exetrnal tables. 
- The resource list is a list of tables as you would get from table.get from big query - or generated by get_bq_schema_from_json_repr - - :param resourcelist: list of resources to genereate code for - :param file_handle: file handle to output too expected to be utf-8 - :return: nothing - """ - jinjaenv = Environment( - loader=FileSystemLoader(os.path.join(_ROOT, "templates")), - autoescape=select_autoescape(["html", "xml"]), - extensions=["jinja2.ext.do", "jinja2.ext.loopcontrols"], - ) - objtemplate = jinjaenv.get_template("bqschema.in") - output = objtemplate.render(resourcelist=resourcelist) - print(output, file=file_handle) - - -def generate_create_schema_file(filename, resourcelist): - """ - Generates using a jinja template bash command using bq to for a set of schemas - supports views, tables or exetrnal tables. - The resource list is a list of tables as you would get from table.get from big query - or generated by get_bq_schema_from_json_repr + # if not same region where are the buckets for copying + if not self.same_region: + assert srcbucket is not None, ( + "Being asked to copy datasets across region but no " + "source bucket is defined these must be in same region " + "as source dataset" + ) + assert isinstance(srcbucket, six.string_types), ( + "Being asked to copy datasets across region but " + "no " + "" + "" + "source bucket is not a string" + ) + self._source_bucket = srcbucket + assert dstbucket is not None, ( + "Being asked to copy datasets across region but no " + "destination bucket is defined these must be in same " + "region " + "as destination dataset" + ) + assert isinstance(dstbucket, six.string_types), ( + "Being asked to copy datasets across region but " + "destination bucket is not a string" + ) + self._destination_bucket = dstbucket + client = storage.Client(project=self.source_project) + src_bucket = client.get_bucket(self.source_bucket) + assert compute_region_equals_bqregion( + src_bucket.location, source_dataset_impl.location + ), ( + "Source bucket " + "location is not " + "" + "" + "" + "" + "" + "" + "" + "" + "same as source " + "dataset location" + ) + dst_bucket = client.get_bucket(self.destination_bucket) + assert compute_region_equals_bqregion( + dst_bucket.location, destination_dataset_impl.location + ), "Destination bucket location is not same as destination dataset location" - :param filename: filename to putput too - :param resourcelist: list of resources to genereate code for - :return:nothing - """ - with open(filename, mode="w+", encoding="utf-8") as file_handle: - generate_create_schema(resourcelist, file_handle) + @property + def session_id(self): + """ + This method returns the uniue identifier for this copy session + :return: The copy drivers session + """ + return self._sessionid + def map_schema_policy_tags(self, schema): + """ + A method that takes as input a big query schema iterates through the schema + and reads policy tags and maps them from src to destination. + :param schema: The schema to map + :return: schema the same schema but with policy tags updated + """ + nschema = [] + for field in schema: + if field.field_type == "RECORD": + tmp_field = field.to_api_repr() + tmp_field["fields"] = [ + i.to_api_repr() for i in self.map_schema_policy_tags(field.fields) + ] + field = bigquery.schema.SchemaField.from_api_repr(tmp_field) + else: + _, field = self.map_policy_tag(field) + nschema.append(field) + return nschema -def dataset_exists(client, dataset_reference): - """Return if a dataset exists. 
+ def map_policy_tag(self, field, dst_tgt=None): + """ + Method that tages a policy tags and will remap + :param field: The field to map + :param dst_tgt: The destination tag + :return: The new policy tag + """ + change = 0 + if field.field_type == "RECORD": + schema = self.map_schema_policy_tags(field.fields) + if field.fields != schema: + change = 1 + tmp_field = field.to_api_repr() + tmp_field["fields"] = [i.to_api_repr() for i in schema] + field = bigquery.schema.SchemaField.from_api_repr(tmp_field) + else: + if field.policy_tags is not None: + field_api_repr = field.to_api_repr() + tags = field_api_repr["policyTags"]["names"] + ntag = [] + for tag in tags: + new_tag = self.map_policy_tag_string(tag) + if new_tag is None: + field_api_repr.pop("policyTags", None) + break + ntag.append(new_tag) + if "policyTags" in field_api_repr: + field_api_repr["policyTags"]["names"] = ntag + field = bigquery.schema.SchemaField.from_api_repr(field_api_repr) + if dst_tgt is None or dst_tgt.policy_tags != field.policy_tags: + change = 1 + return change, bigquery.schema.SchemaField.from_api_repr(field_api_repr) - Args: - client (google.cloud.bigquery.client.Client): - A client to connect to the BigQuery API. - dataset_reference (google.cloud.bigquery.dataset.DatasetReference): - A reference to the dataset to look for. + return change, field - Returns: - bool: ``True`` if the dataset exists, ``False`` otherwise. - """ - from google.cloud.exceptions import NotFound + def map_policy_tag_string(self, src_tag): + """ + This method maps a source tag to a destination policy tag + :param src_tag: The starting table column tags + :return: The expected destination tags + """ + # if same region and src has atag just simply reuse the tag + if self._same_region and not self._src_policy_tags: + return src_tag + try: + # look for the tag in the src destination map + return self._dst_policy_tags[self._src_policy_tags.index(src_tag)] + # if it doesnt return None as the tag + except ValueError: + return None - try: - client.get_dataset(dataset_reference) - return True - except NotFound: - return False + def base_predicates(self, retpartition): + actual_basepredicates = [] + for predicate in self.__base_predicates: + actual_basepredicates.append(predicate.format(retpartition=retpartition)) + return actual_basepredicates + def comparison_predicates(self, table_name, retpartition="_PARTITIONTIME"): + return self.base_predicates(retpartition) -def table_exists(client, table_reference): - """Return if a table exists. + def istableincluded(self, table_name: str) -> bool: + """ + This method when passed a table_name returns true if it should be processed in a copy action + :param table_name: + :return: boolean True then it should be include False then no + """ + if len(self.__re_table_view_filter) == 0: + for filter in self.__table_view_filter: + self.__re_table_view_filter.append(re.compile(filter)) + for filter in self.__table_or_views_to_exclude: + self.__re_table_or_views_to_exclude.append(re.compile(filter)) - Args: - client (google.cloud.bigquery.client.Client): - A client to connect to the BigQuery API. - table_reference (google.cloud.bigquery.table.TableReference): - A reference to the table to look for. + result = False - Returns: - bool: ``True`` if the table exists, ``False`` otherwise. 
- """ - from google.cloud.exceptions import NotFound + for regexp2check in self.__re_table_view_filter: + if regexp2check.search(table_name): + result = True + break - try: - client.get_table(table_reference) - return True - except NotFound: - return False + if result: + for regexp2check in self.__re_table_or_views_to_exclude: + if regexp2check.search(table_name): + result = False + break + return result -def get_kms_key_name(kms_or_key_version): - # handle sept 2021 move to kms keys having versions and gradually - # ignore end part of version - # projects/methodical-bee-162815/locations/europe-west2/keyRings/cloudStorage/cryptoKeys/cloudStorage/cryptoKeyVersions/6 - # so ignore /cryptoKeyVersions/6 - # include though everything else - return re.findall( - "projects/[^/]+/locations/[^/]+/keyRings/[^/]+/cryptoKeys/[^/]+", - kms_or_key_version, - )[0] + def reset_stats(self) -> None: + self.__bytes_synced = 0 + self.__rows_synced = 0 + self.__bytes_avoided = 0 + self.__rows_avoided = 0 + self.__tables_synced = 0 + self.__materialized_views_synced = 0 + self.__views_synced = 0 + self.__routines_synced = 0 + self.__routines_failed_sync = 0 + self.__routines_avoided = 0 + self.__models_synced = 0 + self.__models_failed_sync = 0 + self.__models_avoided = 0 + self.__views_failed_sync = 0 + self.__tables_failed_sync = 0 + self.__tables_avoided = 0 + self.__view_avoided = 0 + self.__extract_fails = 0 + self.__load_fails = 0 + self.__copy_fails = 0 + self.__query_cache_hits = 0 + self.__total_bytes_processed = 0 + self.__total_bytes_billed = 0 + self.__start_time = None + self.__end_time = None + self.__load_input_file_bytes = 0 + self.__load_input_files = 0 + self.__load_output_bytes = 0 + self.__blob_rewrite_retried_exceptions = 0 + self.__blob_rewrite_unretryable_exceptions = 0 + @property + def blob_rewrite_retried_exceptions(self): + return self.__blob_rewrite_retried_exceptions -def create_schema( - sobject, - schema_depth=0, - fname=None, - dschema=None, - path="$", - tableId=None, - policy_tag_callback=None, -): - schema = [] - if dschema is None: - dschema = {} - dummyfield = bigquery.SchemaField("xxxDummySchemaAsNoneDefinedxxx", "STRING") + @property + def blob_rewrite_unretryable_exceptions(self): + return self.__blob_rewrite_unretryable_exceptions - if fname is not None: - fname = INVALIDBQFIELDCHARS.sub("_", fname) - path = "{}.{}".format(path, fname) + def increment_blob_rewrite_retried_exceptions(self): + self.__blob_rewrite_retried_exceptions += 1 - def _default_policy_callback(tableid, path): - return None + def increment_blob_rewrite_unretryable_exceptions(self): + self.__blob_rewrite_unretryable_exceptions += 1 - if policy_tag_callback is None: - policy_tag_callback = _default_policy_callback + @property + def models_synced(self): + return self.__models_synced - def _create_field( - fname, - sobject, - mode="NULLABLE", - tableid=None, - path=None, - policy_tag_callback=None, - ): - fieldschema = None - if isinstance(sobject, bool): - fieldschema = bigquery.SchemaField( - fname, - "BOOLEAN", - mode=mode, - policy_tags=policy_tag_callback(tableid, path), - ) - elif isinstance(sobject, six.integer_types): - fieldschema = bigquery.SchemaField( - fname, - "INTEGER", - mode=mode, - policy_tags=policy_tag_callback(tableid, path), - ) - elif isinstance(sobject, float): - fieldschema = bigquery.SchemaField( - fname, - "FLOAT", - mode=mode, - policy_tags=policy_tag_callback(tableid, path), - ) - # start adtes and times at lowest levelof hierarchy - # 
https://docs.python.org/3/library/datetime.html subclass - # relationships - elif isinstance(sobject, datetime): - fieldschema = bigquery.SchemaField( - fname, - "TIMESTAMP", - mode=mode, - policy_tags=policy_tag_callback(tableid, path), - ) - elif isinstance(sobject, date): - fieldschema = bigquery.SchemaField( - fname, "DATE", mode=mode, policy_tags=policy_tag_callback(tableid, path) - ) - elif isinstance(sobject, dttime): - fieldschema = bigquery.SchemaField( - fname, "TIME", mode=mode, policy_tags=policy_tag_callback(tableid, path) - ) - elif isinstance(sobject, six.string_types): - fieldschema = bigquery.SchemaField( - fname, - "STRING", - mode=mode, - policy_tags=policy_tag_callback(tableid, path), - ) - elif isinstance(sobject, bytes): - fieldschema = bigquery.SchemaField( - fname, - "BYTES", - mode=mode, - policy_tags=policy_tag_callback(tableid, path), - ) - else: - raise UnexpectedType(str(type(sobject))) + def increament_models_synced(self): + self.__models_synced += 1 - return fieldschema + @property + def models_failed_sync(self) -> int: + return self.__models_failed_sync - if isinstance(sobject, list): - tschema = [] - # if fname is not None: - # recordschema = bigquery.SchemaField(fname, 'RECORD', mode='REPEATED') - # # recordschema.fields = tschema - # ok so scenarios to handle here are - # creating a schema from a delliberate schema object for these we know - # there will be only 1 item in the ist with al fields - # but also we have creating a schema from an object which is just - # an object so could have in a list more than 1 item and items coudl have - # different fiedls - # - pdschema = dschema - if fname is not None and fname not in dschema: - dschema[fname] = {} - pdschema = dschema[fname] - fieldschema = False - sampleobject = None - for i in sobject: - # lists must have dictionaries and not base types - # if not a dictionary skip - if (isinstance(sobject, dict) or isinstance(sobject, list)) and isinstance( - i, dict - ): - tschema.extend( - create_schema( - i, - dschema=pdschema, - tableId=tableId, - path=path + "[]", - policy_tag_callback=policy_tag_callback, - ) - ) - else: - fieldschema = True - sampleobject = i - break - if len(tschema) == 0 and not fieldschema: - tschema.append(dummyfield) - if fname is not None: - if not fieldschema: - recordschema = bigquery.SchemaField( - fname, "RECORD", mode="REPEATED", fields=tschema - ) - # recordschema.fields = tuple(tschema) - schema.append(recordschema) - else: - schema.append( - _create_field( - fname, - sampleobject, - mode="REPEATED", - tableid=tableId, - path=path + "[]", - policy_tag_callback=policy_tag_callback, - ) - ) - else: - schema = tschema + def increment_models_failed_sync(self): + self.__models_failed_sync += 1 - elif isinstance(sobject, dict): - tschema = [] - # if fname is not None: - # recordschema = bigquery.SchemaField(fname, 'RECORD') - # recordschema.fields = tschema - if len(sobject) > 0: - for j in sorted(sobject): - if j not in dschema: - dschema[j] = {} - if "simple" not in dschema[j]: - fieldschema = create_schema( - sobject[j], - fname=j, - dschema=dschema[j], - tableId=tableId, - path=path, - policy_tag_callback=policy_tag_callback, - ) - if fieldschema is not None: - if fname is not None: - tschema.extend(fieldschema) - else: - schema.extend(fieldschema) - else: - if fname is not None: - tschema.append(dummyfield) - else: - schema.append(dummyfield) + @property + def models_avoided(self): + return self.__models_avoided - if fname is not None: - recordschema = bigquery.SchemaField(fname, 
"RECORD", fields=tschema) - schema = [recordschema] + def increment_models_avoided(self): + self.__models_avoided += 1 - else: - fieldschema = None - if fname is not None: - if isinstance(sobject, list): - if len(sobject) > 0: - if isinstance(sobject[0], dict): - fieldschema = bigquery.SchemaField(fname, "RECORD") - fieldschema.mode = "REPEATED" - mylist = sobject - head = mylist[0] - fieldschema.fields = create_schema( - head, - schema_depth + 1, - tableId=tableId, - path=path + "[]", - policy_tag_callback=policy_tag_callback, - ) - else: - fieldschema = _create_field( - fname, - sobject, - mode="REPEATED", - tableid=tableId, - path=path + "[]", - policy_tag_callback=policy_tag_callback, - ) - elif isinstance(sobject, dict): - fieldschema = bigquery.SchemaField(fname, "RECORD") - fieldschema.fields = create_schema( - sobject, - schema_depth + 1, - tableId=tableId, - path=path, - policy_tag_callback=policy_tag_callback, - ) - else: - fieldschema = _create_field( - fname, - sobject, - tableid=tableId, - path=path, - policy_tag_callback=policy_tag_callback, - ) + @property + def routines_synced(self): + return self.__routines_synced - if dschema is not None: - dschema["simple"] = True - return [fieldschema] - else: - return [] + @property + def query_cmek(self) -> List[Optional[str]]: + return self._query_cmek - return schema + def increment_routines_synced(self): + self.__routines_synced += 1 + def increment_routines_avoided(self): + self.__routines_avoided += 1 -# convert a dict and with a schema object to assict convert dict into tuple -def dict_plus_schema_2_tuple(data, schema): - """ - :param data: - :param schema: - :return: - """ - otuple = [] + @property + def routines_failed_sync(self) -> int: + return self.__routines_failed_sync - # must iterate through schema to add Nones so dominates - for schema_item in schema: - value = None - if data is not None and schema_item.name in data: - value = data[schema_item.name] - if schema_item.field_type == "RECORD": - ttuple = [] - if schema_item.mode != "REPEATED" or value is None: - value = [value] - for value_item in value: - value = dict_plus_schema_2_tuple(value_item, schema_item.fields) - ttuple.append(value) - value = ttuple - otuple.append(value) + @property + def routines_avoided(self): + return self.__routines_avoided - return tuple(otuple) + def increment_rows_avoided(self): + self.__routines_avoided += 1 + def increment_routines_failed_sync(self): + self.__routines_failed_sync += 1 -# so assumes a list containing list of lists for structs and -# arras but for an arra of structs value is always an array -def tuple_plus_schema_2_dict(data, schema): - """ - :param data: - :param schema: - :return: - """ - rdata = {} - for schema_item, value in zip(schema, data): - if schema_item.field_type == "RECORD": - ldata = [] - if schema_item.mode == "REPEATED": - llist = value - else: - llist = [value] - for list_item in llist: - ldata.append(tuple_plus_schema_2_dict(list_item, schema_item.fields)) - if schema_item.mode == "REPEATED": - value = ldata - else: - value = ldata[0] - rdata[schema_item.name] = value + @property + def bytes_copied_across_region(self): + return self.__load_input_file_bytes - return rdata + @property + def files_copied_across_region(self): + return self.__load_input_files + @property + def bytes_copied(self): + return self.__load_output_bytes -def gen_template_dict(schema): - """ + def increment_load_input_file_bytes(self, value): + self.__load_input_file_bytes += value - :param schema: Take a rest representation of 
google big query table fields and create a - template json object - :return: - """ - rdata = {} - for schema_item in schema: - value = None - if schema_item.field_type == "RECORD": - tvalue = gen_template_dict(schema_item.fields) - if schema_item.mode == "REPEATED": - value = [tvalue] - else: - value = tvalue - elif schema_item.field_type == "INTEGER": - value = 0 - elif schema_item.field_type == "BOOLEAN": - value = False - elif schema_item.field_type == "FLOAT": - value = 0.0 - elif schema_item.field_type == "STRING": - value = "" - elif schema_item.field_type == "TIMESTAMP": - value = datetime.utcnow() - elif schema_item.field_type == "DATE": - value = date.today() - elif schema_item.field_type == "TIME": - value = datetime.utcnow().time() - elif schema_item.field_type == "BYTES": - value = b"\x00" - else: - raise UnexpectedType(str(type(schema_item))) - rdata[schema_item.name] = value + def increment_load_input_files(self, value): + self.__load_input_files += value - return rdata + def increment_load_output_bytes(self, value): + self.__load_output_bytes += value + @property + def start_time(self): + return self.__start_time -def to_dict(schema): - field_member = { - "name": schema.name, - "type": schema.field_type, - "description": schema.description, - "mode": schema.mode, - "fields": None, - } - if schema.fields is not None: - fields_to_append = [] - for field_item in sorted(schema.fields, key=lambda x : x.name): - fields_to_append.append(to_dict(field_item)) - field_member["fields"] = fields_to_append - return field_member + @property + def end_time(self): + if self.__end_time is None and self.__start_time is not None: + return datetime.utcnow() + return self.__end_time + @property + def sync_time_seconds(self): + if self.__start_time is None: + return None + return (self.__end_time - self.__start_time).seconds -def calc_field_depth(fieldlist, depth=0): - max_depth = depth - recursive_depth = depth - for i in fieldlist: - if "fields" in i: - recursive_depth = calc_field_depth(i["fields"], depth + 1) - if recursive_depth > max_depth: - max_depth = recursive_depth - return max_depth + def start_sync(self) -> None: + self.__start_time = datetime.utcnow() + def end_sync(self) -> None: + self.__end_time = datetime.utcnow() -def trunc_field_depth(fieldlist, maxdepth, depth=0): - new_field = [] - if depth <= maxdepth: - for i in fieldlist: - new_field.append(i) - if "fields" in i: - if depth == maxdepth: - # json.JSONEncoder().encode(fieldlist) - i["type"] = "STRING" - i.pop("fields", None) - else: - i["fields"] = trunc_field_depth(i["fields"], maxdepth, depth + 1) + @property + def query_cache_hits(self): + return self.__query_cache_hits - return new_field + def increment_cache_hits(self): + self.__query_cache_hits += 1 + @property + def total_bytes_processed(self): + return self.__total_bytes_processed -def match_and_addtoschema(objtomatch, schema, evolved=False, path="", logger=None): - pretty_printer = pprint.PrettyPrinter(indent=4) - poplist = {} - - for keyi in objtomatch: - # Create schema does this adjustment so we need to do same in actual object - thekey = INVALIDBQFIELDCHARS.sub("_", keyi) - # Work out if object keys have invalid values and n - if thekey != keyi: - poplist[keyi] = thekey - matchstruct = False - # look for bare list should not have any if known about - # big query cannot hande bare lists - # so to alow schema evoution MUST be removed - # this test if we have a list and a value in it is it a bare type i.e.not a dictionary - # if it is not a dictionary use bare 
type ist method to cnvert to a dictionary - # where object vallue is a singe key in a dict of value - # this changes each object as well meaning they will load into the evolved schema - # we call this with log error false as this method checks if the key exists and - # if the object is a list and lengh > 0 and if the object at the end is dict or not only - # converts if not a dict - # this is important optimisation as if we checked here it would be a double check - # as lots of objects this overhead is imprtant to minimise hence why this - # looks like it does - do_bare_type_list(objtomatch, keyi, "value") - for schema_item in schema: - if thekey == schema_item.name: - if schema_item.field_type == "RECORD": - if schema_item.mode == "REPEATED": - subevolve = evolved - for listi in objtomatch[keyi]: - # TODO hack to modify fields as .fields is immutable since version - # 0.28 and later but not - # in docs!! - schema_item._fields = list(schema_item.fields) - tsubevolve = match_and_addtoschema( - listi, - schema_item.fields, - evolved=evolved, - path=path + "." + thekey, - ) - if not subevolve and tsubevolve: - subevolve = tsubevolve - evolved = subevolve - else: - # TODO hack to modify fields as .fields is immutable since version 0.28 - # and later but not in - # docs!! - schema_item._fields = list(schema_item.fields) - evolved = match_and_addtoschema( - objtomatch[keyi], schema_item.fields, evolved=evolved - ) - matchstruct = True - break - if matchstruct: - continue - - # Construct addition to schema here based on objtomatch[keyi] schema or object type - # append to the schema list - try: - toadd = create_schema(objtomatch[keyi], fname=keyi) - except Exception: - raise SchemaMutationError(str(objtomatch), keyi, path) - - if toadd is not None: - schema.extend(toadd) - if logger is not None: - logger.warning( - "Evolved path = {}, struct={}".format( - path + "." + thekey, pretty_printer.pformat(objtomatch[keyi]) - ) - ) - evolved = True - - # If values of keys did need changing change them - if len(poplist): - for pop_item in poplist: - objtomatch[poplist[pop_item]] = objtomatch[pop_item] - objtomatch.pop(pop_item, None) - - return evolved - - -def do_bare_type_list(adict, key, detail, logger=None): - """ - Converts a list that is pointed to be a key in a dctionary from - non dictionary object to dictionary object. We do this as bare types - are not allowed in BQ jsons structures. So structures of type - - "foo":[ 1,2,3 ] - - to - - "foo":[{"detail":1},{"detail":2},{"detail":3}] - - Args: - adict: The dictionary the key of the list object is in. This object is modified so mutated. - key: The key name of the list if it does not exist this does nothing. if the item at the - key is not a list it - does nothing if length of list is 0 this does nothing - detail: The name of the field in new sub dictionary of each object - - - Returns: - Nothing. 
- - Raises: - Nothing - """ - try: - if key in adict: - if key in adict and isinstance(adict[key], list) and len(adict[key]) > 0: - if not isinstance(adict[key][0], dict): - new_list = [] - for list_item in adict[key]: - new_list.append({detail: list_item}) - adict[key] = new_list - else: - if logger is not None: - tbs = traceback.extract_stack() - tbsflat = "\n".join(map(str, tbs)) - logger.error( - "Bare list for key {} in dict {} expected a basic type not converting " - "{}".format(key, str(adict), tbsflat) - ) - except Exception: - raise UnexpectedDict( - "Bare list for key {} in dict {} expected a basic type not converting".format( - key, str(adict) - ) - ) - - -def recurse_and_add_to_schema(schema, oschema): - changes = False - - # Minimum is new schema now this can have less than old - wschema = copy.deepcopy(schema) - - # Everything in old schema stays as a patch - for output_schema_item in oschema: - nschema = [] - # Look for - for new_schema_item in wschema: - if output_schema_item["name"].lower() == new_schema_item.name.lower(): - if output_schema_item["type"] == "RECORD": - rchanges, output_schema_item["fields"] = recurse_and_add_to_schema( - new_schema_item.fields, output_schema_item["fields"] - ) - if rchanges and not changes: - changes = rchanges - else: - nschema.append(new_schema_item) - wschema = nschema - - # Now just has what remain in it. - for wsi in wschema: - changes = True - oschema.append(to_dict(wsi)) - - return (changes, oschema) + def increment_total_bytes_processed(self, total_bytes_processed: int) -> None: + self.__total_bytes_processed += total_bytes_processed + @property + def total_bytes_billed(self): + return self.__total_bytes_processed -FSLST = """#standardSQL -SELECT - ut.*, - fls.firstSeenTime, - fls.lastSeenTime, - fls.numSeen -FROM `{0}.{1}.{2}` as ut -JOIN ( - SELECT - id, - min({4}) AS firstSeenTime, - max({4}) AS lastSeenTime, - COUNT(*) AS numSeen - FROM `{0}.{1}.{2}` - GROUP BY - 1) AS fls -ON fls.id = ut.id AND fls.{3} = {4} -""" -FSLSTDT = ( - "View that shows {} captured values of underlying table for object of a " - "given non repeating key " - "of 'id' {}.{}.{}" -) + def increment_total_bytes_billed(self, total_bytes_billed: int) -> None: + self.__total_bytes_billed += total_bytes_billed + @property + def copy_fails(self) -> int: + return self.__copy_fails -def gen_diff_views( - project, - dataset, - table, - schema, - description="", - intervals=None, - hint_fields=None, - hint_mutable_fields=True, - time_expr=None, - fieldsappend=None, -): - """ + @property + def copy_access(self): + return self.__copy_access - :param project: google project id of underlying table - :param dataset: google dataset id of underlying table - :param table: the base table to do diffs (assumes each time slaice is a view of what data - looked like)) - :param schema: the schema of the base table - :param description: a base description for the views - :param intervals: a list of form [] - :param hint_fields: - :param time_expr: - :param fieldsappend: - :return: - """ + @copy_access.setter + def copy_access(self, value): + self.__copy_access = value - views = [] - fieldsnot4diff = [] - if intervals is None: - intervals = [ - {"day": "1 DAY"}, - {"week": "7 DAY"}, - {"month": "30 DAY"}, - {"fortnight": "14 DAY"}, - ] + def increment_copy_fails(self): + self.__copy_fails += 1 - if time_expr is None: - time_expr = "_PARTITIONTIME" - fieldsnot4diff.append("scantime") - if isinstance(fieldsappend, list): - for fdiffi in fieldsappend: - fieldsnot4diff.append(fdiffi) - 
if hint_fields is None: - hint_fields = [ - "creationTime", - "usage", - "title", - "description", - "preferred", - "documentationLink", - "discoveryLink", - "numLongTermBytes", - "detailedStatus", - "lifecycleState", - "size", - "md5Hash", - "crc32c", - "timeStorageClassUpdated", - "deleted", - "networkIP", - "natIP", - "changePasswordAtNextLogin", - "status", - "state", - "substate", - "stateStartTime", - "metricValue", - "requestedState", - "statusMessage", - "numWorkers", - "currentStateTime", - "currentState", - "lastLoginTime", - "lastViewedByMeDate", - "modifiedByMeDate", - "etag", - "servingStatus", - "lastUpdated", - "updateTime", - "lastModified", - "lastModifiedTime", - "timeStorageClassUpdated", - "updated", - "numRows", - "numBytes", - "numUsers", - "isoCountryCodes", - "countries", - "uriDescription", - "riskScore", - "controlId", - "resolutionDate", - ] + @property + def load_fails(self) -> int: + return self.__load_fails - fqtablename = "{}.{}.{}".format(project, dataset, table) - basediffview = table + "db" - basefromclause = "\nfrom `{}` as {}".format(fqtablename, "ta" + table) - baseselectclause = """#standardSQL -SELECT - {} AS scantime, - xxrownumbering.partRowNumber""".format( - time_expr - ) - baseendselectclause = """ - JOIN ( - SELECT - scantime, - ROW_NUMBER() OVER(ORDER BY scantime) AS partRowNumber - FROM ( - SELECT - DISTINCT {time_expr} AS scantime, - FROM - `{project}.{dataset}.{table}`)) AS xxrownumbering - ON - {time_expr} = xxrownumbering.scantime - """.format( - time_expr=time_expr, project=project, dataset=dataset, table=table - ) + def increment_load_fails(self): + self.__load_fails += 1 - curtablealias = "ta" + table - fieldprefix = "" - aliasstack = [] - fieldprefixstack = [] - fields4diff = [] + @property + def extract_fails(self) -> int: + return self.__extract_fails - # fields to ignore as in each snapshot and different even if content is the same - fields_update_only = [] - aliasnum = 1 + def increment_extract_fails(self): + self.__extract_fails += 1 - basedata = { - "select": baseselectclause, - "from": basefromclause, - "aliasnum": aliasnum, - } + @property + def check_depth(self): + return self.__check_depth - def recurse_diff_base(schema, fieldprefix, curtablealias): - # pretty_printer = pprint.PrettyPrinter(indent=4) + @check_depth.setter + def check_depth(self, value): + self.__check_depth = value - for schema_item in sorted(schema, key=lambda x : x.name): - if schema_item.name in fieldsnot4diff: - continue - # field names can only be up o 128 characters long - if len(fieldprefix + schema_item.name) > 127: - raise BQHittingQueryGenerationQuotaLimit( - "Field alias is over 128 bytes {} aborting code generation".format( - fieldprefix + schema_item.name - ) - ) - if schema_item.mode != "REPEATED": - if schema_item.field_type == "STRING": - basefield = ',\n ifnull({}.{},"None") as `{}`'.format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "BOOLEAN": - basefield = ",\n ifnull({}.{},False) as `{}`".format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "INTEGER": - basefield = ",\n ifnull({}.{},0) as `{}`".format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "FLOAT": - basefield = ",\n ifnull({}.{},0.0) as `{}`".format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "DATE": - basefield = ",\n ifnull({}.{},DATE(1970,1,1)) as 
`{}`".format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "DATETIME": - basefield = ( - ",\n ifnull({}.{},DATETIME(1970,1,1,0,0,0)) as `{}`".format( - curtablealias, - schema_item.name, - fieldprefix + schema_item.name, - ) - ) - elif schema_item.field_type == "TIMESTAMP": - basefield = ( - ',\n ifnull({}.{},TIMESTAMP("1970-01-01T00:00:00Z")) as `{' - "}`".format( - curtablealias, - schema_item.name, - fieldprefix + schema_item.name, - ) - ) - elif schema_item.field_type == "TIME": - basefield = ",\n ifnull({}.{},TIME(0,0,0)) as `{}`".format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "BYTES": - basefield = ',\n ifnull({}.{},b"\x00") as `{}`'.format( - curtablealias, schema_item.name, fieldprefix + schema_item.name - ) - elif schema_item.field_type == "RECORD": - aliasstack.append(curtablealias) - fieldprefixstack.append(fieldprefix) - fieldprefix = fieldprefix + schema_item.name - if schema_item.mode == "REPEATED": - oldalias = curtablealias - curtablealias = "A{}".format(basedata["aliasnum"]) - basedata["aliasnum"] = basedata["aliasnum"] + 1 + @property + def views_failed_sync(self) -> int: + return self.__views_failed_sync - basedata["from"] = basedata[ - "from" - ] + "\nLEFT JOIN UNNEST({}) as {}".format( - oldalias + "." + schema_item.name, curtablealias - ) + def increment_views_failed_sync(self): + self.__views_failed_sync += 1 - else: - curtablealias = curtablealias + "." + schema_item.name - recurse_diff_base(schema_item.fields, fieldprefix, curtablealias) - curtablealias = aliasstack.pop() - fieldprefix = fieldprefixstack.pop() - continue - else: - aliasstack.append(curtablealias) - fieldprefixstack.append(fieldprefix) - fieldprefix = fieldprefix + schema_item.name - oldalias = curtablealias - curtablealias = "A{}".format(basedata["aliasnum"]) - basedata["aliasnum"] = basedata["aliasnum"] + 1 - basedata["from"] = basedata[ - "from" - ] + "\nLEFT JOIN UNNEST({}) as {}".format( - oldalias + "." 
+ schema_item.name, curtablealias - ) - if schema_item.field_type == "STRING": - basefield = ',\n ifnull({},"None") as {}'.format( - curtablealias, fieldprefix - ) - elif schema_item.field_type == "BOOLEAN": - basefield = ",\n ifnull({},False) as {}".format( - curtablealias, fieldprefix - ) - elif schema_item.field_type == "INTEGER": - basefield = ",\n ifnull({},0) as {}".format( - curtablealias, fieldprefix - ) - elif schema_item.field_type == "FLOAT": - basefield = ",\n ifnull({},0.0) as {}".format( - curtablealias, fieldprefix - ) - elif schema_item.field_type == "DATE": - basefield = ",\n ifnull({},DATE(1970,1,1)) as {}".format( - curtablealias, fieldprefix - ) - elif schema_item.field_type == "DATETIME": - basefield = ( - ",\n ifnull({},DATETIME(1970,1,1,0,0,0)) as {}".format( - curtablealias, fieldprefix - ) - ) - elif schema_item.field_type == "TIME": - basefield = ",\n ifnull({},TIME(0,0,0)) as {}".format( - curtablealias, fieldprefix - ) - elif schema_item.field_type == "BYTES": - basefield = ',\n ifnull({},b"\x00") as {}'.format( - curtablealias, fieldprefix - ) - if schema_item.field_type == "RECORD": - recurse_diff_base(schema_item.fields, fieldprefix, curtablealias) - else: - # as an array has to be a diff not an update - fields4diff.append(fieldprefix) - basedata["select"] = basedata["select"] + basefield - curtablealias = aliasstack.pop() - fieldprefix = fieldprefixstack.pop() - continue + @property + def tables_failed_sync(self) -> int: + return self.__tables_failed_sync - if hint_mutable_fields: - update_only = False - else: - update_only = True - if schema_item.name in hint_fields: - if hint_mutable_fields: - update_only = True - else: - update_only = False - if update_only: - fields_update_only.append(fieldprefix + schema_item.name) - else: - fields4diff.append(fieldprefix + schema_item.name) - basedata["select"] = basedata["select"] + basefield - return + def increment_tables_failed_sync(self): + self.__tables_failed_sync += 1 - try: - recurse_diff_base(schema, fieldprefix, curtablealias) - allfields = fields4diff + fields_update_only - basechangeselect = basedata["select"] + basedata["from"] + baseendselectclause - joinfields = "" - if len(fields4diff) > 0 and len(fields_update_only) > 0: - joinfields = "\n UNION ALL\n" - auditchangequery = AUDITCHANGESELECT.format( - mutatedimmutablefields="\n UNION ALL".join( - [ - TEMPLATEMUTATEDIMMUTABLE.format(fieldname=field) - for field in fields4diff - ] - ), - mutablefieldchanges="\n UNION ALL".join( - [ - TEMPLATEMUTATEDFIELD.format(fieldname=field) - for field in fields_update_only - ] - ), - beforeorafterfields=",\n".join( - [TEMPLATEBEFOREORAFTER.format(fieldname=field) for field in allfields] - ), - basechangeselect=basechangeselect, - immutablefieldjoin="AND ".join( - [ - TEMPLATEFORIMMUTABLEJOINFIELD.format(fieldname=field) - for field in fields4diff - ] - ), - avoidlastpredicate=AVOIDLASTSETINCROSSJOIN.format( - time_expr=time_expr, project=project, dataset=dataset, table=table - ), - joinfields=joinfields, - ) + @property + def tables_avoided(self): + return self.__tables_avoided - if len(basechangeselect) > 256 * 1024: - raise BQHittingQueryGenerationQuotaLimit( - "Query {} is over 256kb".format(basechangeselect) - ) - views.append( - { - "name": basediffview, - "query": basechangeselect, - "description": "View used as basis for diffview:" + description, - } - ) - if len(auditchangequery) > 256 * 1024: - raise BQHittingQueryGenerationQuotaLimit( - "Query {} is over 256kb".format(auditchangequery) - ) - views.append( 
- { - "name": "{}diff".format(table), - "query": auditchangequery, - "description": "View calculates what has changed at what time:" - + description, - } - ) + def increment_tables_avoided(self): + self.__tables_avoided += 1 - refbasediffview = "{}.{}.{}".format(project, dataset, basediffview) + @property + def view_avoided(self): + return self.__view_avoided - # Now fields4 diff has field sto compare fieldsnot4diff appear in select but are not - # compared. - # basic logic is like below - # - # select action (a case statement but "Added","Deleted","Sames") - # origfield, - # lastfield, - # if origfield != lastfield diff = 1 else diff = 0 - # from diffbaseview as orig with select of orig timestamp - # from diffbaseview as later with select of later timestamp - # This template logic is then changed for each interval to actually generate concrete views + def increment_view_avoided(self): + self.__view_avoided += 1 - # mutatedimmutablefields = "" - # mutablefieldchanges = "" - # beforeorafterfields = "" - # basechangeselect = "" - # immutablefieldjoin = "" + @property + def bytes_synced(self): + return self.__bytes_synced - diffviewselectclause = """#standardSQL -SELECT - o.scantime as origscantime, - l.scantime as laterscantime,""" - diffieldclause = "" - diffcaseclause = "" - diffwhereclause = "" - diffviewfromclause = """ - FROM (SELECT - * - FROM - `{0}` - WHERE - scantime = ( - SELECT - MAX({1}) - FROM - `{2}.{3}.{4}` - WHERE - {1} < ( - SELECT - MAX({1}) - FROM - `{2}.{3}.{4}`) - AND - {1} < TIMESTAMP_SUB(CURRENT_TIMESTAMP(),INTERVAL %interval%) ) ) o -FULL OUTER JOIN ( - SELECT - * - FROM - `{0}` - WHERE - scantime =( - SELECT - MAX({1}) - FROM - `{2}.{3}.{4}` )) l -ON -""".format( - refbasediffview, time_expr, project, dataset, table - ) + @property + def copy_types(self) -> List[str]: + return self.__copy_types - for f4i in fields4diff: - diffieldclause = ( - diffieldclause - + ",\n o.{} as orig{},\n l.{} as later{},\n case " - "when o.{} = l.{} " - "then 0 else 1 end as diff{}".format(f4i, f4i, f4i, f4i, f4i, f4i, f4i) - ) - if diffcaseclause == "": - diffcaseclause = """ - CASE - WHEN o.{} IS NULL THEN 'Added' - WHEN l.{} IS NULL THEN 'Deleted' - WHEN o.{} = l.{} """.format( - f4i, f4i, f4i, f4i - ) - else: - diffcaseclause = diffcaseclause + "AND o.{} = l.{} ".format(f4i, f4i) + def add_bytes_synced(self, bytes): + self.__bytes_synced += bytes - if diffwhereclause == "": - diffwhereclause = " l.{} = o.{}".format(f4i, f4i) + def update_job_stats(self, job: QueryJob) -> None: + """ + Given a big query job figure out what stats to process + :param job: + :return: None + """ + if isinstance(job, bigquery.QueryJob): + if job.cache_hit: + self.increment_cache_hits() + self.increment_total_bytes_billed(job.total_bytes_billed) + self.increment_total_bytes_processed(job.total_bytes_processed) + + if isinstance(job, bigquery.CopyJob): + if job.error_result is not None: + self.increment_copy_fails() + + if isinstance(job, bigquery.LoadJob): + if job.error_result: + self.increment_load_fails() else: - diffwhereclause = diffwhereclause + "\n AND l.{}=o.{}".format( - f4i, f4i - ) + self.increment_load_input_files(job.input_files) + self.increment_load_input_file_bytes(job.input_file_bytes) + self.increment_load_output_bytes(job.output_bytes) - for f4i in fields_update_only: - diffieldclause = ( - diffieldclause - + ",\n o.{} as orig{},\n l.{} as later{},\n case " - "" - "when o.{} = l.{} " - "then 0 else 1 end as diff{}".format(f4i, f4i, f4i, f4i, f4i, f4i, f4i) - ) - diffcaseclause = 
diffcaseclause + "AND o.{} = l.{} ".format(f4i, f4i) + @property + def rows_synced(self): + # as time can be different between these assume avoided is always more accurae + if self.rows_avoided > self.__rows_synced: + return self.rows_avoided + return self.__rows_synced - diffcaseclause = ( - diffcaseclause - + """THEN 'Same' - ELSE 'Updated' - END AS action""" - ) + def add_rows_synced(self, rows): + self.__rows_synced += rows - for intervali in intervals: - for keyi in intervali: - view_name = table + "diff" + keyi - view_description = ( - "Diff of {} of underlying table {} description: {}".format( - keyi, table, description - ) - ) - diff_query = ( - diffviewselectclause - + diffcaseclause - + diffieldclause - + diffviewfromclause.replace("%interval%", intervali[keyi]) - + diffwhereclause - ) + @property + def bytes_avoided(self): + return self.__bytes_avoided + + def add_bytes_avoided(self, bytes): + self.__bytes_avoided += bytes + + @property + def rows_avoided(self): + return self.__rows_avoided + + def add_rows_avoided(self, rows): + self.__rows_avoided += rows + + @property + def tables_synced(self): + return self.__tables_synced + + @property + def views_synced(self): + return self.__views_synced + + def increment_tables_synced(self) -> None: + self.__tables_synced += 1 + + def increment_materialized_views_synced(self): + self.__materialized_views_synced += 1 - if len(diff_query) > 256 * 1024: - raise BQHittingQueryGenerationQuotaLimit( - "Query {} is over 256kb".format(diff_query) - ) + def increment_views_synced(self): + self.__views_synced += 1 - views.append( - { - "name": view_name, - "query": diff_query, - "description": view_description, - } - ) + @property + def copy_q(self): + return self.__copy_q - # look for id in top level fields if exists create first seen and last seen views - for i in schema: - if i.name == "id": - fsv = FSLST.format(project, dataset, table, "firstSeenTime", time_expr) - fsd = FSLSTDT.format("first", project, dataset, table) - lsv = FSLST.format(project, dataset, table, "lastSeenTime", time_expr) - lsd = FSLSTDT.format("last", project, dataset, table) + @copy_q.setter + def copy_q(self, value): + self.__copy_q = value - if len(fsv) > 256 * 1024: - raise BQHittingQueryGenerationQuotaLimit( - "Query {} is over 256kb".format(fsv) - ) + @property + def schema_q(self): + return self.__schema_q - views.append({"name": table + "fs", "query": fsv, "description": fsd}) - if len(lsv) > 256 * 1024: - raise BQHittingQueryGenerationQuotaLimit( - "Query {} is over 256kb".format(lsv) - ) - views.append({"name": table + "ls", "query": lsv, "description": lsd}) - break + @schema_q.setter + def schema_q(self, value): + self.__schema_q = value - except BQHittingQueryGenerationQuotaLimit: - pass + @property + def source_location(self) -> str: + return self._source_location - return views + @property + def destination_location(self) -> str: + return self._destination_location + @property + def source_bucket(self) -> str: + return self._source_bucket -def evolve_schema(insertobj, table, client, bigquery, logger=None): - """ + @property + def destination_bucket(self) -> str: + return self._destination_bucket - :param insertobj: json object that represents schema expected - :param table: a table object from python api thats been git through client.get_table - :param client: a big query client object - :param bigquery: big query service as created with google discovery discovery.build( - "bigquery","v2") - :param logger: a google logger class - :return: evolved True or 
False - """ + @property + def same_region(self) -> bool: + return self._same_region - schema = list(table.schema) - tablechange = False + @property + def source_dataset_impl(self) -> Dataset: + source_datasetref = self.source_client.dataset(self.source_dataset) + return self.source_client.get_dataset(source_datasetref) - evolved = match_and_addtoschema(insertobj, schema) + @property + def destination_dataset_impl(self) -> Dataset: + destination_datasetref = self.destination_client.dataset( + self.destination_dataset + ) + return self.destination_client.get_dataset(destination_datasetref) - if evolved: - if logger is not None: - logger.warning( - "Evolving schema as new field(s) on {}:{}.{} views with * will need " - "reapplying".format(table.project, table.dataset_id, table.table_id) + @property + def query_client(self) -> Client: + """ + Returns the client to be charged for analysis of comparison could be + source could be destination could be another. + By default it is the destination but can be overriden by passing a target project + :return: A big query client for the project to be charged + """ + warnings.filterwarnings( + "ignore", "Your application has authenticated using end user credentials" + ) + """ + Obtains a source client in the current thread only constructs a client once per thread + :return: + """ + source_client = getattr( + DefaultBQSyncDriver.threadLocal, self.__analysisproject, None + ) + if source_client is None: + setattr( + DefaultBQSyncDriver.threadLocal, + self.__analysisproject, + bigquery.Client(project=self.__analysisproject, _http=self.http), ) + return getattr(DefaultBQSyncDriver.threadLocal, self.__analysisproject, None) - treq = bigquery.tables().get( - projectId=table.project, datasetId=table.dataset_id, tableId=table.table_id + @property + def source_client(self) -> Client: + warnings.filterwarnings( + "ignore", "Your application has authenticated using end user credentials" ) - table_data = treq.execute() - oschema = table_data.get("schema") - tablechange, pschema = recurse_and_add_to_schema(schema, oschema["fields"]) - update = {"schema": {"fields": pschema}} - preq = bigquery.tables().patch( - projectId=table.project, - datasetId=table.dataset_id, - tableId=table.table_id, - body=update, + """ + Obtains a source client in the current thread only constructs a client once per thread + :return: + """ + source_client = getattr( + DefaultBQSyncDriver.threadLocal, + self.source_project + self._source_dataset, + None, + ) + if source_client is None: + setattr( + DefaultBQSyncDriver.threadLocal, + self.source_project + self._source_dataset, + bigquery.Client(project=self.source_project, _http=self.http), + ) + return getattr( + DefaultBQSyncDriver.threadLocal, + self.source_project + self._source_dataset, + None, ) - preq.execute() - client.get_table(table) - # table.reload() - - return evolved - - -def create_default_bq_resources( - template, - basename, - project, - dataset, - location, - hint_fields=None, - hint_mutable_fields=True, - optheaddays=None, -): - """ - :param template: a template json object to create a big query schema for - :param basename: a base name of the table to create that will also be used as a basis for views - :param project: the project to create resources in - :param dataset: the datasets to create them in - :param location: The locatin - :return: a list of big query table resources as dicionaries that can be passe dto code - genearteor or used in rest - calls - """ - resourcelist = [] - table = { - "type": "TABLE", - "location": 
location, - "tableReference": { - "projectId": project, - "datasetId": dataset, - "tableId": basename, - }, - "timePartitioning": {"type": "DAY", "expirationMs": "94608000000"}, - "schema": {}, - } - table["schema"]["fields"] = get_bq_schema_from_json_repr(template) - resourcelist.append(table) - views = gen_diff_views( - project, - dataset, - basename, - create_schema(template), - hint_fields=hint_fields, - hint_mutable_fields=hint_mutable_fields, - ) + @property + def http(self) -> None: + """ + Allow override of http transport per client + usefule for proxy handlng but should be handled by sub-classes default is do nothing + :return: + """ + return self._http - head_view_format = HEADVIEW - # if a max days look back is given use it to optimise bytes in the head view - if optheaddays is not None: - head_view_format = OPTHEADVIEW + @property + def destination_client(self) -> Client: + """ + Obtains a s destination client in current thread only constructs a client once per thread + :return: + """ + source_client = getattr( + DefaultBQSyncDriver.threadLocal, + self.destination_project + self.destination_dataset, + None, + ) + if source_client is None: + setattr( + DefaultBQSyncDriver.threadLocal, + self.destination_project + self.destination_dataset, + bigquery.Client(project=self.destination_project, _http=self.http), + ) + return getattr( + DefaultBQSyncDriver.threadLocal, + self.destination_project + self.destination_dataset, + None, + ) - table = { - "type": "VIEW", - "tableReference": { - "projectId": project, - "datasetId": dataset, - "tableId": "{}head".format(basename), - }, - "view": { - "query": head_view_format.format(project, dataset, basename, optheaddays), - "useLegacySql": False, - }, - } - resourcelist.append(table) - for view_item in views: - table = { - "type": "VIEW", - "tableReference": { - "projectId": project, - "datasetId": dataset, - "tableId": view_item["name"], - }, - "view": {"query": view_item["query"], "useLegacySql": False}, - } - resourcelist.append(table) - return resourcelist + def export_import_format_supported(self, srctable, dsttable=None): + """Calculates a suitable export import type based upon schema + default mecahnism is to use AVRO and SNAPPY as parallel and fast + If though a schema type notsupported by AVRO fall back to jsonnl and gzip + """ + dst_encryption_configuration = None + if dsttable is None and srctable.encryption_configuration is not None: + dst_encryption_configuration = self.calculate_target_cmek_config( + srctable.encryption_configuration + ) + else: + dst_encryption_configuration = None + return ExportImportType(srctable, dsttable, dst_encryption_configuration) -class ViewCompiler(object): - def __init__(self, render_dictionary=None): - if render_dictionary is None: - render_dictionary = {} - self.view_depth_optimiser = {} - self._add_auth_view = {} - self._views = {} - self._render_dictionary = render_dictionary - self._lock = threading.RLock() + @property + def source_project(self) -> str: + return self._source_project @property - def render_dictionary(self): - return self._render_dictionary + def source_dataset(self) -> str: + return self._source_dataset - @render_dictionary.setter - def render_dictionary(self, render_dictionary): - self._render_dictionary = render_dictionary + @property + def jobs(self) -> List[Any]: + return self.__jobs - def render(self, raw_sql): - template_for_render = Template(raw_sql) - return template_for_render.render(self.render_dictionary) + def add_job(self, job): + self.__jobs.append(job) - def 
add_view_to_process(self, dataset, name, sql, unnest=True, description=None): - standard_sql = True - compiled_sql = self.render(sql) - sql = compiled_sql - # prefix = "" - if sql.strip().lower().find("#legacysql") == 0: - standard_sql = False + @property + def destination_project(self) -> str: + return self._destination_project - splitprojectdataset = ":" - repattern = LEGACY_SQL_PDTCTABLEREGEXP - # no_auth_view = "{}:{}".format(dataset.project, dataset.dataset_id) + @property + def destination_dataset(self) -> str: + return self._destination_dataset - if standard_sql: - splitprojectdataset = "." - repattern = STANDARD_SQL_PDTCTABLEREGEXP + @property + def remove_deleted_tables(self): + return self._remove_deleted_tables - key = "{}{}{}.{}".format( - dataset.project, splitprojectdataset, dataset.dataset_id, name - ) + def day_partition_deep_check(self): + """ - dependsOn = [] + :return: True if should check rows and bytes counts False if notrequired + """ + return self.__day_partition_deep_check - for project_dataset_table in re.findall( - repattern, re.sub(SQL_COMMENT_REGEXP, "", sql).replace("\n", " ") + def extra_dp_compare_functions(self, table): + """ + Function to allow record comparison checks for a day partition + These shoul dbe aggregates for the partition i..e max, min, avg + Override when row count if not sufficient. Row count works for + tables where rows only added non deleted or removed + :return: a comma seperated extension of functions + """ + default = "" + retpartition_time = "_PARTITIONTIME" + + if ( + getattr(table, "time_partitioning", None) + and table.time_partitioning.field is not None ): - dependsOn.append(project_dataset_table) + retpartition_time = "TIMESTAMP({})".format(table.time_partitioning.field) - self._views[key] = { - "sql": sql, - "dataset": dataset, - "name": name, - "dependsOn": dependsOn, - "unnest": unnest, - "description": description, - } + SCHEMA = list(table.schema) - def compile_views(self): - for tranche in self.view_tranche: - for view in self.view_in_tranche(tranche): - view["sql"] = self.compile( - view["dataset"], - view["name"], - view["sql"], - unnest=view["unnest"], - description=view["description"], - rendered=True, - ) + # emulate nonlocal variables this works in python 2.7 and python 3.6 + aliasdict = {"alias": "", "extrajoinpredicates": ""} - @property - def view_tranche(self): - view_tranches = self.plan_view_apply_tranches() - for tranche in view_tranches: - yield tranche + """ + Use FRAM_FINGERPRINT as hash of each value and then summed + basically a merkel function + """ - def view_in_tranche(self, tranche): - for view in sorted(tranche): - yield self._views[view] + def add_data_check(SCHEMA, prefix=None, depth=0): + if prefix is None: + prefix = [] + # add base table alia + prefix.append("zzz") - def plan_view_apply_tranches(self): - view_tranches = [] + expression_list = [] + + # if we are beyond check depth exit + if self.check_depth >= 0 and depth > self.check_depth: + return expression_list + + for field in SCHEMA: + prefix.append(field.name) + if field.mode != "REPEATED": + if self.check_depth >= 0 or ( + self.check_depth >= -1 + and ( + re.search("update.*time", field.name.lower()) + or re.search("modifi.*time", field.name.lower()) + or re.search("version", field.name.lower()) + or re.search("creat.*time", field.name.lower()) + ) + ): + if field.field_type == "STRING": + expression_list.append( + "IFNULL(`{0}`,'')".format("`.`".join(prefix)) + ) + elif field.field_type == "TIMESTAMP": + expression_list.append( + 
"CAST(IFNULL(`{0}`,TIMESTAMP('1970-01-01')) AS STRING)".format( + "`.`".join(prefix) + ) + ) + elif ( + field.field_type == "INTEGER" or field.field_type == "INT64" + ): + expression_list.append( + "CAST(IFNULL(`{0}`,0) AS STRING)".format( + "`.`".join(prefix) + ) + ) + elif ( + field.field_type == "FLOAT" + or field.field_type == "FLOAT64" + or field.field_type == "NUMERIC" + ): + expression_list.append( + "CAST(IFNULL(`{0}`,0.0) AS STRING)".format( + "`.`".join(prefix) + ) + ) + elif ( + field.field_type == "BOOL" or field.field_type == "BOOLEAN" + ): + expression_list.append( + "CAST(IFNULL(`{0}`,false) AS STRING)".format( + "`.`".join(prefix) + ) + ) + elif field.field_type == "BYTES": + expression_list.append( + "CAST(IFNULL(`{0}`,'') AS STRING)".format( + "`.`".join(prefix) + ) + ) + if field.field_type == "RECORD": + SSCHEMA = list(field.fields) + expression_list.extend( + add_data_check(SSCHEMA, prefix=prefix, depth=depth) + ) + else: + if field.field_type != "RECORD" and ( + self.check_depth >= 0 + or ( + self.check_depth >= -1 + and ( + re.search("update.*time", field.name.lower()) + or re.search("modifi.*time", field.name.lower()) + or re.search("version", field.name.lower()) + or re.search("creat.*time", field.name.lower()) + ) + ) + ): + # add the unnestof repeated base type can use own field name + fieldname = "{}{}".format(aliasdict["alias"], field.name) + aliasdict[ + "extrajoinpredicates" + ] = """{} +LEFT JOIN UNNEST(`{}`) AS `{}`""".format( + aliasdict["extrajoinpredicates"], field.name, fieldname + ) + if field.field_type == "STRING": + expression_list.append("IFNULL(`{0}`,'')".format(fieldname)) + elif field.field_type == "TIMESTAMP": + expression_list.append( + "CAST(IFNULL(`{0}`,TIMESTAMP('1970-01-01')) AS STRING)".format( + fieldname + ) + ) + elif ( + field.field_type == "INTEGER" or field.field_type == "INT64" + ): + expression_list.append( + "CAST(IFNULL(`{0}`,0) AS STRING)".format(fieldname) + ) + elif ( + field.field_type == "FLOAT" + or field.field_type == "FLOAT64" + or field.field_type == "NUMERIC" + ): + expression_list.append( + "CAST(IFNULL(`{0}`,0.0) AS STRING)".format(fieldname) + ) + elif ( + field.field_type == "BOOL" or field.field_type == "BOOLEAN" + ): + expression_list.append( + "CAST(IFNULL(`{0}`,false) AS STRING)".format(fieldname) + ) + elif field.field_type == "BYTES": + expression_list.append( + "CAST(IFNULL(`{0}`,'') AS STRING)".format(fieldname) + ) + if field.field_type == "RECORD": + # if we want to go deeper in checking + if depth < self.check_depth: + # need to unnest as repeated + if aliasdict["alias"] == "": + # uses this to name space from real columns + # bit hopeful + aliasdict["alias"] = "zzz" + + aliasdict["alias"] = aliasdict["alias"] + "z" + + # so need to rest prefix for this record + newprefix = [] + newprefix.append(aliasdict["alias"]) + + # add the unnest + aliasdict[ + "extrajoinpredicates" + ] = """{} +LEFT JOIN UNNEST(`{}`) AS {}""".format( + aliasdict["extrajoinpredicates"], + "`.`".join(prefix), + aliasdict["alias"], + ) - max_tranche_depth = 0 - for view in self._views: - depend_depth = self.calc_view_dependency_depth(view) - self._views[view]["tranche"] = depend_depth - if depend_depth > max_tranche_depth: - max_tranche_depth = depend_depth + # add the fields + expression_list.extend( + add_data_check( + SSCHEMA, prefix=newprefix, depth=depth + 1 + ) + ) + prefix.pop() + return expression_list - for tranche in range(max_tranche_depth + 1): - tranchelist = [] - for view in self._views: - if 
self._views[view]["tranche"] == tranche: - tranchelist.append(view) - view_tranches.append(tranchelist) + expression_list = add_data_check(SCHEMA) - return view_tranches + if len(expression_list) > 0: + # algorithm to compare sets + # java uses overflowing sum but big query does not allow in sum + # using average as scales sum and available as aggregate + # split top end and bottom end of hash + # that way we maximise bits and fidelity + # and + default = """, + AVG(FARM_FINGERPRINT(CONCAT({0})) & 0x0000000FFFFFFFF) as avgFingerprintLB, + AVG((FARM_FINGERPRINT(CONCAT({0})) & 0xFFFFFFF00000000) >> 32) as avgFingerprintHB,""".format( + ",".join(expression_list) + ) - def calc_view_dependency_depth(self, name, depth=0): - max_depth = depth - for depends_on in self._views[name].get("dependsOn", []): - if depends_on in self._views: - retdepth = self.calc_view_dependency_depth(depends_on, depth=depth + 1) - if retdepth > max_depth: - max_depth = retdepth - return max_depth + predicates = self.comparison_predicates(table.table_id, retpartition_time) + if len(predicates) > 0: + aliasdict[ + "extrajoinpredicates" + ] = """{} +WHERE ({})""".format( + aliasdict["extrajoinpredicates"], ") AND (".join(predicates) + ) - def add_auth_view(self, project_auth, dataset_auth, view_to_authorise): - if project_auth not in self._add_auth_view: - self._add_auth_view[project_auth] = {} - if dataset_auth not in self._add_auth_view[project_auth]: - self._add_auth_view[project_auth][dataset_auth] = [] - found = False - for view in self._add_auth_view[project_auth][dataset_auth]: - if ( - view_to_authorise["tableId"] == view["tableId"] - and view_to_authorise["datasetId"] == view["datasetId"] - and view_to_authorise["projectId"] == view["projectId"] - ): - found = True - break - if not found: - self._add_auth_view[project_auth][dataset_auth].append(view_to_authorise) + return retpartition_time, default, aliasdict["extrajoinpredicates"] @property - def projects_to_authorise_views_in(self): - for project in self._add_auth_view: - yield project + def copy_data(self): + """ + True if to copy data not just structure + False just keeps structure in sync + :return: + """ + return self._copy_data - def datasets_to_authorise_views_in(self, project): - for dataset in self._add_auth_view.get(project, {}): - yield dataset + def table_data_change(self, srctable, dsttable): + """ + Method to allow customisation of detecting if table differs + default method is check rows, numberofbytes and last modified time + this exists to allow something like a query to do comparison + if you want to force a copy this should retrn True + :param srctable: + :param dsttable: + :return: + """ - def update_authorised_view_access(self, project, dataset, current_access_entries): - if project in self._add_auth_view: - if dataset in self._add_auth_view[project]: - expected_auth_view_access = self._add_auth_view[project][dataset] - managed_by_view_compiler = {} - new_access_entries = [] + return False - # iterate expected authorised views - # from this caluclate the project and datsestt we want to reset auth - # views for - for access in expected_auth_view_access: - if access["projectId"] not in managed_by_view_compiler: - managed_by_view_compiler[access["projectId"]] = {} - if ( - access["datasetId"] - not in managed_by_view_compiler[access["projectId"]] - ): - managed_by_view_compiler[access["projectId"]][ - access["datasetId"] - ] = True - new_access_entries.append( - bigquery.dataset.AccessEntry(None, "view", access) - ) + def fault_barrier(self, 
function, *args): + """ + A fault barrie here to ensure functions called in thread + do n9t exit prematurely + :param function: A function to call + :param args: The functions arguments + :return: + """ + try: + function(*args) + except Exception: + pretty_printer = pprint.PrettyPrinter() + self.get_logger().exception( + "Exception calling function {} args {}".format( + function.__name__, pretty_printer.pformat(args) + ) + ) - for current_access in current_access_entries: - if current_access.entity_type == "view": - if ( - current_access.entity_id["projectId"] - in managed_by_view_compiler - and current_access.entity_id["datasetId"] - in managed_by_view_compiler[ - current_access.entity_id["projectId"] - ] - ): - continue - new_access_entries.append(current_access) - return new_access_entries + def update_source_view_definition(self, view_definition, use_standard_sql): + view_definition = view_definition.replace( + r"`{}.{}.".format(self.source_project, self.source_dataset), + "`{}.{}.".format(self.destination_project, self.destination_dataset), + ) + view_definition = view_definition.replace( + r"[{}:{}.".format(self.source_project, self.source_dataset), + "[{}:{}.".format(self.destination_project, self.destination_dataset), + ) + # this should not be required but seems it is + view_definition = view_definition.replace( + r"[{}.{}.".format(self.source_project, self.source_dataset), + "[{}:{}.".format(self.destination_project, self.destination_dataset), + ) + # support short names + view_definition = view_definition.replace( + r"{}.".format(self.source_dataset), "{}.".format(self.destination_dataset) + ) - return current_access_entries + return view_definition - def compile( - self, dataset, name, sql, unnest=True, description=None, rendered=False - ): - # if already done don't do again - key = dataset.project + ":" + dataset.dataset_id + "." 
+ name - if key in self.view_depth_optimiser: - return ( - self.view_depth_optimiser[key]["prefix"] - + self.view_depth_optimiser[key]["unnested"] + def calculate_target_cmek_config(self, encryption_config): + assert isinstance(encryption_config, bigquery.EncryptionConfiguration) or ( + getattr( + self.destination_dataset_impl, "default_encryption_configuration", None ) + is not None + and self.destination_dataset_impl.default_encryption_configuration + is not None + ), ( + " To recaclculate a new encryption " + "config the original config has to be passed in and be of class " + "bigquery.EncryptionConfig" + ) - if not rendered: - sql = self.render(sql) - - standard_sql = True - # standard sql limit 1Mb - # https://cloud.google.com/bigquery/quotas - max_query_size = 256 * 1024 - compiled_sql = sql - prefix = "" + # if destination dataset has default kms key already, just use the same + if ( + getattr( + self.destination_dataset_impl, "default_encryption_configuration", None + ) + is not None + and self.destination_dataset_impl.default_encryption_configuration + is not None + ): + return self.destination_dataset_impl.default_encryption_configuration - if sql.strip().lower().find("#standardsql") == 0: - prefix = "#standardSQL\n" - else: - max_query_size = 256 * 1024 - standard_sql = False - prefix = "#legacySQL\n" + # if a global key or same region we are good to go + if ( + self.same_region + or encryption_config.kms_key_name.find("/locations/global/") != -1 + ): + # strip off version if exists + return bigquery.EncryptionConfiguration( + get_kms_key_name(encryption_config.kms_key_name) + ) - # get rid of nested comments as they can break this even if in a string in a query - prefix = ( - prefix - + r""" --- =================================================================================== --- --- ViewCompresser Output --- --- \/\/\/\/\/\/Original SQL Below\/\/\/\/ -""" + # if global key can still be used + # if comparing table key get rid fo version + parts = get_kms_key_name(encryption_config.kms_key_name).split("/") + parts[3] = MAPBQREGION2KMSREGION.get( + self.destination_location, self.destination_location.lower() ) - for line in sql.splitlines(): - prefix = prefix + "-- " + line + "\n" - prefix = ( - prefix - + r""" --- --- /\/\/\/\/\/\Original SQL Above/\/\/\/\ --- --- Compiled SQL below --- =================================================================================== -""" + + return bigquery.encryption_configuration.EncryptionConfiguration( + kms_key_name="/".join(parts) ) - splitprojectdataset = ":" - repattern = LEGACY_SQL_PDCTABLEREGEXP - no_auth_view = "{}:{}".format(dataset.project, dataset.dataset_id) + def copy_access_to_destination(self) -> None: + # for those not created compare data structures + # copy data + # compare data + # copy views + # copy dataset permissions + if self.copy_access: + src_dataset = self.source_client.get_dataset( + self.source_client.dataset(self.source_dataset) + ) + dst_dataset = self.destination_client.get_dataset( + self.destination_client.dataset(self.destination_dataset) + ) + access_entries = src_dataset.access_entries + dst_access_entries = [] + for access in access_entries: + newaccess = access + if access.role is None: + # if not copying views these will fail + if "VIEW" not in self.copy_types: + continue + newaccess = self.create_access_view(access.entity_id) + dst_access_entries.append(newaccess) + dst_dataset.access_entries = dst_access_entries - if standard_sql: - splitprojectdataset = "." 
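# Illustrative sketch only (not part of the patch): the region rewrite that
# calculate_target_cmek_config() above performs on a versioned KMS key name.
# The project, key ring and destination region below are hypothetical
# placeholders; the real code maps the region via MAPBQREGION2KMSREGION.
def _example_remap_kms_key_region():
    import re

    key_version = (
        "projects/my-project/locations/europe-west2/keyRings/cloudStorage"
        "/cryptoKeys/cloudStorage/cryptoKeyVersions/6"
    )
    # same pattern get_kms_key_name() uses: keep everything up to the key,
    # dropping the trailing /cryptoKeyVersions/N component
    key = re.findall(
        "projects/[^/]+/locations/[^/]+/keyRings/[^/]+/cryptoKeys/[^/]+",
        key_version,
    )[0]
    parts = key.split("/")
    # parts[3] is the location segment; swap it for the destination region
    parts[3] = "us-central1"
    return "/".join(parts)
    # 'projects/my-project/locations/us-central1/keyRings/cloudStorage/cryptoKeys/cloudStorage'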
- repattern = STANDARD_SQL_PDCTABLEREGEXP - no_auth_view = "{}.{}".format(dataset.project, dataset.dataset_id) + fields = ["access_entries"] + if dst_dataset.description != src_dataset.description: + dst_dataset.description = src_dataset.description + fields.append("description") - # we unnest everything to get to underlying authorised views - # needed - with self._lock: - for i in self.view_depth_optimiser: - # relaces a table or view name with sql - if not standard_sql: - compiled_sql = compiled_sql.replace( - "[" + i + "]", - "( /* flattened view [-" - + i - + "-]*/ " - + self.view_depth_optimiser[i]["unnested"] - + ")", - ) - else: - compiled_sql = compiled_sql.replace( - "`" + i.replace(":", ".") + "`", - "( /* flattened view `-" - + i - + "-`*/ " - + self.view_depth_optimiser[i]["unnested"] - + ")", - ) + if dst_dataset.friendly_name != src_dataset.friendly_name: + dst_dataset.friendly_name = src_dataset.friendly_name + fields.append("friendly_name") - # strip off comments before analysing queries for tables used in the query and remove - # newlines - # to avoid false positives - # use set to get unique values - for project_dataset in set( - re.findall( - repattern, - re.sub(SQL_COMMENT_REGEXP, "", compiled_sql).replace("\n", " "), - ) - ): - # if is a table in same dataset nothing to authorise - if project_dataset != no_auth_view: - project_auth, dataset_auth = project_dataset.split(splitprojectdataset) - # spool the view needs authorisation for the authorisation task - self.add_auth_view( - project_auth, - dataset_auth, - { - "projectId": dataset.project, - "datasetId": dataset.dataset_id, - "tableId": name, - }, + if ( + dst_dataset.default_table_expiration_ms + != src_dataset.default_table_expiration_ms + ): + dst_dataset.default_table_expiration_ms = ( + src_dataset.default_table_expiration_ms ) - unnest = False + fields.append("default_table_expiration_ms") - # and we put back original if we want to hide logic - if not unnest: - compiled_sql = sql + if getattr(dst_dataset, "default_partition_expiration_ms", None): + if ( + dst_dataset.default_partition_expiration_ms + != src_dataset.default_partition_expiration_ms + ): + dst_dataset.default_partition_expiration_ms = ( + src_dataset.default_partition_expiration_ms + ) + fields.append("default_partition_expiration_ms") - # look to keep queriesbelow maximumsize - if len(prefix + compiled_sql) > max_query_size: - # strip out comment - if standard_sql: - prefix = "#standardSQL\n" - else: - prefix = "#legacySQL\n" - # if still too big strip out other comments - # and extra space - if len(prefix + compiled_sql) > max_query_size: - nsql = "" - for line in compiled_sql.split("\n"): - # if not a comment - trimline = line.strip() - if trimline[:2] != "--": - " ".join(trimline.split()) - nsql = nsql + "\n" + trimline - compiled_sql = nsql + # compare 2 dictionaries that are simple key, value + x = dst_dataset.labels + y = src_dataset.labels - # if still too big go back to original sql stripped - if len(prefix + compiled_sql) > max_query_size: - if len(sql) > max_query_size: - nsql = "" - for line in sql.split("\n"): - trimline = line.strip() - # if not a comment - if trimline[:1] != "#": - " ".join(line.split()) - nsql = nsql + "\n" + trimline - compiled_sql = nsql - else: - compiled_sql = sql + # get shared key values + shared_items = {k: x[k] for k in x if k in y and x[k] == y[k]} - with self._lock: - self.view_depth_optimiser[key] = { - "raw": sql, - "unnested": compiled_sql, - "prefix": prefix, - } + # must be same size and values if 
not set labels + if len(dst_dataset.labels) != len(src_dataset.labels) or len( + shared_items + ) != len(src_dataset.labels): + dst_dataset.labels = src_dataset.labels + fields.append("labels") - return prefix + compiled_sql + if getattr(dst_dataset, "default_encryption_configuration", None): + if not ( + src_dataset.default_encryption_configuration is None + and dst_dataset.default_encryption_configuration is None + ): + # if src_dataset.default_encryption_configuration is None: + # dst_dataset.default_encryption_configuration = None + # else: + # dst_dataset.default_encryption_configuration = \ + # self.calculate_target_cmek_config( + # src_dataset.default_encryption_configuration) + # equate dest kms config to src only if it's None + if dst_dataset.default_encryption_configuration is None: + dst_dataset.default_encryption_configuration = ( + self.calculate_target_cmek_config( + src_dataset.default_encryption_configuration + ) + ) + fields.append("default_encryption_configuration") -def compute_region_equals_bqregion(compute_region, bq_region): - if compute_region == bq_region or compute_region.lower() == bq_region.lower(): - bq_compute_region = bq_region.lower() - else: - bq_compute_region = MAPBQREGION2KMSREGION.get(bq_region, bq_region.lower()) - return compute_region.lower() == bq_compute_region + try: + self.destination_client.update_dataset(dst_dataset, fields) + except exceptions.Forbidden: + self.logger.error( + "Unable to det permission on {}.{} dataset as Forbidden".format( + self.destination_project, self.destination_dataset + ) + ) + except exceptions.BadRequest: + self.logger.error( + "Unable to det permission on {}.{} dataset as BadRequest".format( + self.destination_project, self.destination_dataset + ) + ) + def create_access_view(self, entity_id): + """ + Convert an old view authorised view + to a new one i.e. change project id -def run_query( - client, - query, - logger, - desctext="", - location=None, - max_results=10000, - callback_on_complete=None, - labels=None, - params=None, - query_cmek=None, -): - """ - Runa big query query and yield on each row returned as a generator - :param client: The BQ client to use to run the query - :param query: The query text assumed standardsql unless starts with #legcaySQL - :param desctext: Some descriptive text to put out if in debug mode - :return: nothing - """ - use_legacy_sql = False - if query.lower().find("#legacysql") == 0: - use_legacy_sql = True + :param entity_id: + :return: a view { + ... 'projectId': 'my-project', + ... 'datasetId': 'my_dataset', + ... 'tableId': 'my_table' + ... 
} + """ + if ( + entity_id["projectId"] == self.source_project + and entity_id["datasetId"] == self.source_dataset + ): + entity_id["projectId"] = self.destination_project + entity_id["datasetId"] = self.destination_dataset - def _get_parameter(name, argtype): - def _get_type(argtype): - ptype = None - if isinstance(argtype, six.string_types): - ptype = "STRING" - elif isinstance(argtype, int): - ptype = "INT64" - elif isinstance(argtype, float): - ptype = "FLOAT64" - elif isinstance(argtype, bool): - ptype = "BOOL" - elif isinstance(argtype, datetime): - ptype = "TIMESTAMP" - elif isinstance(argtype, date): - ptype = "DATE" - elif isinstance(argtype, bytes): - ptype = "BYTES" - elif isinstance(argtype, decimal.Decimal): - ptype = "NUMERIC" - else: - raise TypeError("Unrcognized type for qury paramter") - return ptype + return bigquery.AccessEntry(None, "view", entity_id) - if isinstance(argtype, list): - if argtype: - if isinstance(argtype[0], dict): - struct_list = [] - for item in argtype: - struct_list.append(_get_parameter(None, item)) - return bigquery.ArrayQueryParameter(name, "STRUCT", struct_list) - else: - return bigquery.ArrayQueryParameter( - name, _get_type(argtype[0]), argtype - ) - else: - return None + @property + def logger(self): + return self.__logger - if isinstance(argtype, dict): - struct_param = [] - for key in argtype: - struct_param.append(_get_parameter(key, argtype[key])) - return bigquery.StructQueryParameter(name, *struct_param) + @logger.setter + def logger(self, alogger): + self.__logger = alogger - return bigquery.ScalarQueryParameter(name, _get_type(argtype), argtype) + def get_logger(self): + """ + Returns the python logger to use for logging errors and issues + :return: + """ + return self.__logger - query_parameters = [] - if params is not None: - # https://cloud.google.com/bigquery/docs/parameterized-queries#python - # named parameters - if isinstance(params, dict): - for key in params: - param = _get_parameter(key, params[key]) - if param is not None: - query_parameters.append(param) - # positional paramters - elif isinstance(params, list): - for p in params: - param = _get_parameter(None, p) - if param is not None: - query_parameters.append(param) - else: - raise TypeError("Query parameter not a dict or a list") - job_config = bigquery.QueryJobConfig(query_parameters=query_parameters) - job_config.maximum_billing_tier = 10 - job_config.use_legacy_sql = use_legacy_sql - if query_cmek is not None: - job_config.destination_encryption_configuration = ( - bigquery.EncryptionConfiguration(kms_key_name=query_cmek) - ) - if labels is not None: - job_config.labels = labels +class BQSyncTask(object): + def __init__( + self, function: Callable, args: List[Union[DefaultBQSyncDriver, str]] + ) -> None: + assert callable(function), "Tasks must be constructed with a function" + assert isinstance(args, list), "Must have arguments" + self.__function = function + self.__args = args - query_job = client.query(query, job_config=job_config, location=location) + @property + def function(self): + return self.__function - pretty_printer = pprint.PrettyPrinter(indent=4) - results = False - while True: - query_job.reload() # Refreshes the state via a GET request. 
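# Illustrative sketch only (not part of the patch): how the named-parameter
# branch of run_query() above is typically driven. The project, table and
# label names are hypothetical placeholders.
def _example_run_parameterized_query():
    from datetime import datetime, timezone
    from google.cloud import bigquery

    client = bigquery.Client(project="my-analysis-project")
    sql = (
        "SELECT name FROM `my-project.my_dataset.my_table` "
        "WHERE created >= @cutoff LIMIT 100"
    )
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            # named scalar parameter, as built by _get_parameter() for TIMESTAMP
            bigquery.ScalarQueryParameter(
                "cutoff", "TIMESTAMP", datetime(2024, 1, 1, tzinfo=timezone.utc)
            )
        ],
        labels={"tool": "bqtools"},
    )
    # query() starts the job; result() waits and yields rows, much as
    # run_query() yields each row back to its caller
    for row in client.query(sql, job_config=job_config).result():
        print(row.name)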
+ @property + def args(self) -> List[Union[DefaultBQSyncDriver, str]]: + return self.__args - if query_job.state == "DONE": - if query_job.error_result: - errtext = "Query error {}{}".format( - pretty_printer.pformat(query_job.error_result), - pretty_printer.pformat(query_job.errors), - ) - logger.error(errtext, exc_info=True) - raise BQQueryError(query, desctext, errtext) - else: - results = True - break + def __eq__(self, other: "BQSyncTask") -> bool: + return self.args == other.args - if results: - # query_results = query_job.results() + def __lt__(self, other: "BQSyncTask") -> bool: + return self.args < other.args - # Drain the query results by requesting - # a page at a time. - # page_token = None + def __gt__(self, other): + return self.args > other.args - for irow in query_job.result(): - yield irow - if callback_on_complete is not None and callable(callback_on_complete): - callback_on_complete(query_job) +def get_json_struct(jsonobj, template=None): + """ - return + :param jsonobj: Object to parse and adjust so could be loaded into big query + :param template: An input object to use as abasis as a template defaullt no template provided + :return: A json object that is a template object. This can be used as input to + get_bq_schema_from_json_repr + """ + if template is None: + template = {} + for key in jsonobj: + newkey = INVALIDBQFIELDCHARS.sub("_", key) + if jsonobj[key] is None: + continue + if newkey not in template: + value = None + if isinstance(jsonobj[key], bool): + value = False + elif isinstance(jsonobj[key], six.string_types): + value = "" + elif isinstance(jsonobj[key], six.text_type): + value = "" + elif isinstance(jsonobj[key], int): + value = 0 + elif isinstance(jsonobj[key], float): + value = 0.0 + elif isinstance(jsonobj[key], date): + value = jsonobj[key] + elif isinstance(jsonobj[key], datetime): + value = jsonobj[key] + elif isinstance(jsonobj[key], dict): + value = get_json_struct(jsonobj[key]) + elif isinstance(jsonobj[key], list): + value = [{}] + if len(jsonobj[key]) > 0: + if not isinstance(jsonobj[key][0], dict): + new_value = [] + for vali in jsonobj[key]: + new_value.append({"value": vali}) + jsonobj[key] = new_value + for list_item in jsonobj[key]: + value[0] = get_json_struct(list_item, value[0]) + else: + raise UnexpectedType(str(jsonobj[key])) + template[newkey] = value + else: + if isinstance(jsonobj[key], type(template[newkey])): + if isinstance(jsonobj[key], dict): + template[key] = get_json_struct(jsonobj[key], template[newkey]) + if isinstance(jsonobj[key], list): + if len(jsonobj[key]) != 0: + if not isinstance(jsonobj[key][0], dict): + new_value = [] + for vali in jsonobj[key]: + new_value.append({"value": vali}) + jsonobj[key] = new_value + for list_item in jsonobj[key]: + template[newkey][0] = get_json_struct( + list_item, template[newkey][0] + ) + else: + # work out best way to loosen types with worst case change to string + newtype = "" + if isinstance(jsonobj[key], float) and isinstance( + template[newkey], int + ): + newtype = 0.0 + elif isinstance(jsonobj[key], datetime) and isinstance( + template[newkey], date + ): + newtype = jsonobj[key] + if not ( + isinstance(jsonobj[key], dict) or isinstance(jsonobj[key], list) + ) and not ( + isinstance(template[newkey], list) + or isinstance(template[newkey], dict) + ): + template[newkey] = newtype + else: + # this is so different type cannot be loosened + raise InconsistentJSONStructure( + key, str(jsonobj[key]), str(template[newkey]) + ) + return template -class ExportImportType(object): +def 
clean_json_for_bq(anobject): """ - Class that calculate the export import types that are best to use to copy the table - passed in initialiser across region. + + :param anobject: to be converted to big query json compatible format: + :return: cleaned object """ + newobj = {} + if not isinstance(anobject, dict): + raise NotADictionary(str(anobject)) + for key in anobject: + newkey = INVALIDBQFIELDCHARS.sub("_", key) - def __init__(self, srctable, dsttable=None, dst_encryption_configuration=None): - """ - Construct an ExportImportType around a table that describes best format to copy this - table across region - how to compress - :param srctable: A big query table implementation that is he source of the copy - :param dsttable: optional the target definition if not specified specificfication of - source used if provided it dominates - :param dst_encryption_configuration: if set overrides keys on tables some calc has driven - this - """ - assert isinstance(srctable, bigquery.Table), ( - "Export Import Type MUST be constructed with a " "bigquery.Table object" - ) - assert dsttable is None or isinstance(dsttable, bigquery.Table), ( - "Export Import dsttabl Type MUST " - "be constructed with a " - "bigquery.Table object or None" - ) + value = anobject[key] + if isinstance(value, dict): + value = clean_json_for_bq(value) + if isinstance(value, list): + if len(value) != 0: + if not isinstance(value[0], dict): + new_value = [] + for vali in value: + new_value.append({"value": vali}) + value = new_value + valllist = [] + for vali in value: + vali = clean_json_for_bq(vali) + valllist.append(vali) + value = valllist + newobj[newkey] = value + return newobj - if dsttable is None: - self.__table = srctable - else: - self.__table = dsttable - self._dst_encryption_configuration = None - if dst_encryption_configuration is not None: - self._dst_encryption_configuration = dst_encryption_configuration +def get_bq_schema_from_json_repr(jsondict): + """ + Generate fields structure of Big query resource if the input json structure is vallid + :param jsondict: a template object in json format to use as basis to create a big query + schema object from + :return: a big query schema + """ + fields = [] + for key, data in list(jsondict.items()): + field = {"name": key} + if isinstance(data, bool): + field["type"] = "BOOLEAN" + field["mode"] = "NULLABLE" + elif isinstance(data, six.string_types): + field["type"] = "STRING" + field["mode"] = "NULLABLE" + elif isinstance(data, six.text_type): + field["type"] = "STRING" + field["mode"] = "NULLABLE" + elif isinstance(data, int): + field["type"] = "INTEGER" + field["mode"] = "NULLABLE" + elif isinstance(data, float): + field["type"] = "FLOAT" + field["mode"] = "NULLABLE" + elif isinstance(data, datetime): + field["type"] = "TIMESTAMP" + field["mode"] = "NULLABLE" + elif isinstance(data, date): + field["type"] = "DATE" + field["mode"] = "NULLABLE" + elif isinstance(data, dttime): + field["type"] = "TIME" + field["mode"] = "NULLABLE" + elif isinstance(data, six.binary_type): + field["type"] = "BYTES" + field["mode"] = "NULLABLE" + elif isinstance(data, dict): + field["type"] = "RECORD" + field["mode"] = "NULLABLE" + field["fields"] = get_bq_schema_from_json_repr(data) + elif isinstance(data, list): + field["type"] = "RECORD" + field["mode"] = "REPEATED" + field["fields"] = get_bq_schema_from_json_repr(data[0]) + fields.append(field) + return fields + - # detect if any GEOGRAPHY or DATETIME fields - def _detect_non_avro_and_parquet_types(schema): - for field in schema: - if 
field.field_type == "GEOGRAPHY" or field.field_type == "DATETIME": - return True - if field.field_type == "RECORD": - if _detect_non_avro_and_parquet_types(list(field.fields)): - return True - return False +def generate_create_schema(resourcelist, file_handle): + """ + Generates using a jinja template bash command using bq to for a set of schemas + supports views, tables or exetrnal tables. + The resource list is a list of tables as you would get from table.get from big query + or generated by get_bq_schema_from_json_repr - # detect if any GEOGRAPHY or DATETIME fields - def _detect_non_parquet_types(schema): - for field in schema: - # https://cloud.google.com/bigquery/docs/exporting-data#parquet_export_details - if field.field_type == "DATETIME" or field.field_type == "TIME": - return True - # parquet only support repeated base types - if field.field_type == "RECORD": - if field.mode == "REPEATED": - return True - if _detect_non_parquet_types(list(field.fields)): - return True - return False + :param resourcelist: list of resources to genereate code for + :param file_handle: file handle to output too expected to be utf-8 + :return: nothing + """ + jinjaenv = Environment( + loader=FileSystemLoader(os.path.join(_ROOT, "templates")), + autoescape=select_autoescape(["html", "xml"]), + extensions=["jinja2.ext.do", "jinja2.ext.loopcontrols"], + ) + objtemplate = jinjaenv.get_template("bqschema.in") + output = objtemplate.render(resourcelist=resourcelist) + print(output, file=file_handle) - # go columnar compression over row as generally better - # as we always process whole rows actually not worth it but we - # had fun working out whats needed to support parquet - # we have wish list of sync with hive schema which could be parquet - # so let me figure out limitations of this - # self.__destination_format = bigquery.job.DestinationFormat.PARQUET - # but if thats impossible - # if _detect_non_parquet_types(list(srctable.schema)): - self.__destination_format = bigquery.job.DestinationFormat.AVRO - if _detect_non_avro_and_parquet_types(list(srctable.schema)): - self.__destination_format = ( - bigquery.job.DestinationFormat.NEWLINE_DELIMITED_JSON - ) - @property - def destination_format(self): - """ - The destination format to use for exports for ths table when copying across regions - :return: a bigquery.job.DestinationFormat enumerator - """ - return self.__destination_format +def generate_create_schema_file(filename, resourcelist): + """ + Generates using a jinja template bash command using bq to for a set of schemas + supports views, tables or exetrnal tables. 
+ The resource list is a list of tables as you would get from table.get from big query + or generated by get_bq_schema_from_json_repr - @property - def source_format(self): - """ - The source format to use to load this table in destnation region - :return: a bigquery.job.SourceFormat enumerator that matches the prefferd export format - """ - # only support the exports that are possible - if self.destination_format == bigquery.job.DestinationFormat.AVRO: - return bigquery.job.SourceFormat.AVRO - if self.destination_format == bigquery.job.DestinationFormat.PARQUET: - return bigquery.job.SourceFormat.PARQUET - if ( - self.destination_format - == bigquery.job.DestinationFormat.NEWLINE_DELIMITED_JSON - ): - return bigquery.job.SourceFormat.NEWLINE_DELIMITED_JSON - if self.destination_format == bigquery.job.DestinationFormat.CSV: - return bigquery.job.SourceFormat.CSV + :param filename: filename to putput too + :param resourcelist: list of resources to genereate code for + :return:nothing + """ + with open(filename, mode="w+", encoding="utf-8") as file_handle: + generate_create_schema(resourcelist, file_handle) - @property - def source_file_extension(self): - if self.destination_format == bigquery.job.DestinationFormat.AVRO: - return ".avro" - if self.destination_format == bigquery.job.DestinationFormat.PARQUET: - return ".parquet" - if ( - self.destination_format - == bigquery.job.DestinationFormat.NEWLINE_DELIMITED_JSON - ): - return ".jsonl" - if self.destination_format == bigquery.job.DestinationFormat.CSV: - return ".csv" - @property - def compression_format(self): - """ - The calculated compression type to use based on supported format - :return: one of bigquery.job.Compression enumerators or None - """ - if self.destination_format == bigquery.job.DestinationFormat.AVRO: - return bigquery.job.Compression.DEFLATE - if self.destination_format == bigquery.job.DestinationFormat.PARQUET: - return bigquery.job.Compression.SNAPPY - return bigquery.job.Compression.GZIP +def dataset_exists(client: Client, dataset_reference: DatasetReference) -> bool: + """Return if a dataset exists. - @property - def schema(self): - """ - The target schema so if needed on load can be obtained from same object - :return: - """ - return self.__table.schema + Args: + client (google.cloud.bigquery.client.Client): + A client to connect to the BigQuery API. + dataset_reference (google.cloud.bigquery.dataset.DatasetReference): + A reference to the dataset to look for. - @property - def encryption_configuration(self): - if self._dst_encryption_configuration is not None: - return self._dst_encryption_configuration - return self.__table.encryption_configuration + Returns: + bool: ``True`` if the dataset exists, ``False`` otherwise. + """ + from google.cloud.exceptions import NotFound + try: + client.get_dataset(dataset_reference) + return True + except NotFound: + return False -class DefaultBQSyncDriver(object): - """This class provides mechanical input to bqsync functions""" - threadLocal = threading.local() +def table_exists(client, table_reference): + """Return if a table exists. 
- def __init__( - self, - srcproject, - srcdataset, - dstdataset, - dstproject=None, - srcbucket=None, - dstbucket=None, - remove_deleted_tables=True, - copy_data=True, - copy_types=None, - check_depth=-1, - copy_access=True, - table_view_filter=None, - table_or_views_to_exclude=None, - latest_date=None, - days_before_latest_day=None, - day_partition_deep_check=False, - analysis_project=None, - query_cmek=None, - src_policy_tags=[], - dst_policy_tags=[], - ): - """ - Constructor for base copy driver all other drivers should inherit from this - :param srcproject: The project that is the source for the copy (note all actions are done - inc ontext of source project) - :param srcdataset: The source dataset - :param dstdataset: The destination dataset - :param dstproject: The source project if None assumed to be source project - :param srcbucket: The source bucket when copying cross region data is extracted to this - bucket rewritten to destination bucket - :param dstbucket: The destination bucket where data is loaded from - :param remove_deleted_tables: If table exists in destination but not in source should it - be deleted - :param copy_data: Copy data or just do schema - :param copy_types: Copy object types i.e. TABLE,VIEW,ROUTINE,MODEL - """ - assert query_cmek is None or ( - isinstance(query_cmek, list) and len(query_cmek) == 2 - ), ( - "If cmek key is specified has to be a list and MUST be 2 keys with " - "" - "1st key being source location key and destination key" - ) - self._sessionid = datetime.utcnow().isoformat().replace(":", "-") - if copy_types is None: - copy_types = ["TABLE", "VIEW", "ROUTINE", "MODEL", "MATERIALIZEDVIEW"] - if table_view_filter is None: - table_view_filter = [".*"] - if table_or_views_to_exclude is None: - table_or_views_to_exclude = [] - if dstproject is None: - dstproject = srcproject + Args: + client (google.cloud.bigquery.client.Client): + A client to connect to the BigQuery API. + table_reference (google.cloud.bigquery.table.TableReference): + A reference to the table to look for. - self._remove_deleted_tables = remove_deleted_tables + Returns: + bool: ``True`` if the table exists, ``False`` otherwise. 
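    Example (illustrative only, not part of the patch; assumes default
    credentials and a hypothetical project, dataset and table):

        client = bigquery.Client()
        ref = client.dataset("my_dataset", project="my-project").table("my_table")
        if table_exists(client, ref):
            print("table already present")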
+ """ + from google.cloud.exceptions import NotFound - # check copy makes some basic sense - assert srcproject != dstproject or srcdataset != dstdataset, ( - "Source and destination " "datasets cannot be the same" - ) - assert latest_date is None or isinstance(latest_date, datetime) + try: + client.get_table(table_reference) + return True + except NotFound: + return False - self._source_project = srcproject - self._source_dataset = srcdataset - self._destination_project = dstproject - self._destination_dataset = dstdataset - self._copy_data = copy_data - self._http = None - self.__copy_q = None - self.__schema_q = None - self.__jobs = [] - self.__copy_types = copy_types - self.reset_stats() - self.__logger = logging - self.__check_depth = check_depth - self.__copy_access = copy_access - self.__table_view_filter = table_view_filter - self.__table_or_views_to_exclude = table_or_views_to_exclude - self.__re_table_view_filter = [] - self.__re_table_or_views_to_exclude = [] - self.__base_predicates = [] - self.__day_partition_deep_check = day_partition_deep_check - self.__analysisproject = self._destination_project - if analysis_project is not None: - self.__analysisproject = analysis_project - if days_before_latest_day is not None: - if latest_date is None: - end_date = "TIMESTAMP_ADD(TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(),DAY),INTERVAL 1 DAY)" - else: - end_date = "TIMESTAMP('{}')".format(latest_date.strftime("%Y-%m-%d")) - self.__base_predicates.append( - "{{retpartition}} BETWEEN TIMESTAMP_SUB({end_date}, INTERVAL {" - "days_before_latest_day} * 24 HOUR) AND {end_date}".format( - end_date=end_date, days_before_latest_day=days_before_latest_day - ) - ) +def get_kms_key_name(kms_or_key_version): + # handle sept 2021 move to kms keys having versions and gradually + # ignore end part of version + # projects/methodical-bee-162815/locations/europe-west2/keyRings/cloudStorage/cryptoKeys/cloudStorage/cryptoKeyVersions/6 + # so ignore /cryptoKeyVersions/6 + # include though everything else + return re.findall( + "projects/[^/]+/locations/[^/]+/keyRings/[^/]+/cryptoKeys/[^/]+", + kms_or_key_version, + )[0] - # now check that from a service the copy makes sense - assert dataset_exists( - self.source_client, self.source_client.dataset(self.source_dataset) - ), ("Source dataset does not exist %r" % self.source_dataset) - assert dataset_exists( - self.destination_client, - self.destination_client.dataset(self.destination_dataset), - ), ( - "Destination dataset does not " "exists %r" % self.destination_dataset - ) - # figure out if cross region copy if within region copies optimised to happen in big query - # if cross region buckets need to exist to support copy and they need to be in same region - source_dataset_impl = self.source_dataset_impl - destination_dataset_impl = self.destination_dataset_impl - if query_cmek is None: - query_cmek = [] - if destination_dataset_impl.default_encryption_configuration is not None: - query_cmek.append( - destination_dataset_impl.default_encryption_configuration.kms_key_name - ) - else: - query_cmek.append(None) +def create_schema( + sobject, + schema_depth=0, + fname=None, + dschema=None, + path="$", + tableId=None, + policy_tag_callback=None, +): + schema = [] + if dschema is None: + dschema = {} + dummyfield = bigquery.SchemaField("xxxDummySchemaAsNoneDefinedxxx", "STRING") - if ( - not query_cmek - and source_dataset_impl.default_encryption_configuration is not None - ): - query_cmek.append( - source_dataset_impl.default_encryption_configuration.kms_key_name - ) - 
else: - query_cmek.append(None) + if fname is not None: + fname = INVALIDBQFIELDCHARS.sub("_", fname) + path = "{}.{}".format(path, fname) - self._query_cmek = query_cmek - self._same_region = ( - source_dataset_impl.location == destination_dataset_impl.location - ) - self._source_location = source_dataset_impl.location - self._destination_location = destination_dataset_impl.location - self._src_policy_tags = src_policy_tags - self._dst_policy_tags = dst_policy_tags + def _default_policy_callback(tableid, path): + return None - # if not same region where are the buckets for copying - if not self.same_region: - assert srcbucket is not None, ( - "Being asked to copy datasets across region but no " - "source bucket is defined these must be in same region " - "as source dataset" + if policy_tag_callback is None: + policy_tag_callback = _default_policy_callback + + def _create_field( + fname, + sobject, + mode="NULLABLE", + tableid=None, + path=None, + policy_tag_callback=None, + ): + fieldschema = None + if isinstance(sobject, bool): + fieldschema = bigquery.SchemaField( + fname, + "BOOLEAN", + mode=mode, + policy_tags=policy_tag_callback(tableid, path), ) - assert isinstance(srcbucket, six.string_types), ( - "Being asked to copy datasets across region but " - "no " - "" - "" - "source bucket is not a string" + elif isinstance(sobject, six.integer_types): + fieldschema = bigquery.SchemaField( + fname, + "INTEGER", + mode=mode, + policy_tags=policy_tag_callback(tableid, path), ) - self._source_bucket = srcbucket - assert dstbucket is not None, ( - "Being asked to copy datasets across region but no " - "destination bucket is defined these must be in same " - "region " - "as destination dataset" + elif isinstance(sobject, float): + fieldschema = bigquery.SchemaField( + fname, + "FLOAT", + mode=mode, + policy_tags=policy_tag_callback(tableid, path), ) - assert isinstance(dstbucket, six.string_types), ( - "Being asked to copy datasets across region but " - "destination bucket is not a string" + # start adtes and times at lowest levelof hierarchy + # https://docs.python.org/3/library/datetime.html subclass + # relationships + elif isinstance(sobject, datetime): + fieldschema = bigquery.SchemaField( + fname, + "TIMESTAMP", + mode=mode, + policy_tags=policy_tag_callback(tableid, path), ) - self._destination_bucket = dstbucket - client = storage.Client(project=self.source_project) - src_bucket = client.get_bucket(self.source_bucket) - assert compute_region_equals_bqregion( - src_bucket.location, source_dataset_impl.location - ), ( - "Source bucket " - "location is not " - "" - "" - "" - "" - "" - "" - "" - "" - "same as source " - "dataset location" + elif isinstance(sobject, date): + fieldschema = bigquery.SchemaField( + fname, "DATE", mode=mode, policy_tags=policy_tag_callback(tableid, path) ) - dst_bucket = client.get_bucket(self.destination_bucket) - assert compute_region_equals_bqregion( - dst_bucket.location, destination_dataset_impl.location - ), "Destination bucket location is not same as destination dataset location" + elif isinstance(sobject, dttime): + fieldschema = bigquery.SchemaField( + fname, "TIME", mode=mode, policy_tags=policy_tag_callback(tableid, path) + ) + elif isinstance(sobject, six.string_types): + fieldschema = bigquery.SchemaField( + fname, + "STRING", + mode=mode, + policy_tags=policy_tag_callback(tableid, path), + ) + elif isinstance(sobject, bytes): + fieldschema = bigquery.SchemaField( + fname, + "BYTES", + mode=mode, + policy_tags=policy_tag_callback(tableid, path), + ) 
+ else: + raise UnexpectedType(str(type(sobject))) - @property - def session_id(self): - """ - This method returns the uniue identifier for this copy session - :return: The copy drivers session - """ - return self._sessionid + return fieldschema + + if isinstance(sobject, list): + tschema = [] + # if fname is not None: + # recordschema = bigquery.SchemaField(fname, 'RECORD', mode='REPEATED') + # # recordschema.fields = tschema + # ok so scenarios to handle here are + # creating a schema from a delliberate schema object for these we know + # there will be only 1 item in the ist with al fields + # but also we have creating a schema from an object which is just + # an object so could have in a list more than 1 item and items coudl have + # different fiedls + # + pdschema = dschema + if fname is not None and fname not in dschema: + dschema[fname] = {} + pdschema = dschema[fname] + fieldschema = False + sampleobject = None + for i in sobject: + # lists must have dictionaries and not base types + # if not a dictionary skip + if (isinstance(sobject, dict) or isinstance(sobject, list)) and isinstance( + i, dict + ): + tschema.extend( + create_schema( + i, + dschema=pdschema, + tableId=tableId, + path=path + "[]", + policy_tag_callback=policy_tag_callback, + ) + ) + else: + fieldschema = True + sampleobject = i + break + if len(tschema) == 0 and not fieldschema: + tschema.append(dummyfield) + if fname is not None: + if not fieldschema: + recordschema = bigquery.SchemaField( + fname, "RECORD", mode="REPEATED", fields=tschema + ) + # recordschema.fields = tuple(tschema) + schema.append(recordschema) + else: + schema.append( + _create_field( + fname, + sampleobject, + mode="REPEATED", + tableid=tableId, + path=path + "[]", + policy_tag_callback=policy_tag_callback, + ) + ) + else: + schema = tschema + + elif isinstance(sobject, dict): + tschema = [] + # if fname is not None: + # recordschema = bigquery.SchemaField(fname, 'RECORD') + # recordschema.fields = tschema + if len(sobject) > 0: + for j in sorted(sobject): + if j not in dschema: + dschema[j] = {} + if "simple" not in dschema[j]: + fieldschema = create_schema( + sobject[j], + fname=j, + dschema=dschema[j], + tableId=tableId, + path=path, + policy_tag_callback=policy_tag_callback, + ) + if fieldschema is not None: + if fname is not None: + tschema.extend(fieldschema) + else: + schema.extend(fieldschema) + else: + if fname is not None: + tschema.append(dummyfield) + else: + schema.append(dummyfield) - def map_schema_policy_tags(self, schema): - """ - A method that takes as input a big query schema iterates through the schema - and reads policy tags and maps them from src to destination. 
- :param schema: The schema to map - :return: schema the same schema but with policy tags updated - """ - nschema = [] - for field in schema: - if field.field_type == "RECORD": - tmp_field = field.to_api_repr() - tmp_field["fields"] = [ - i.to_api_repr() for i in self.map_schema_policy_tags(field.fields) - ] - field = bigquery.schema.SchemaField.from_api_repr(tmp_field) + if fname is not None: + recordschema = bigquery.SchemaField(fname, "RECORD", fields=tschema) + schema = [recordschema] + + else: + fieldschema = None + if fname is not None: + if isinstance(sobject, list): + if len(sobject) > 0: + if isinstance(sobject[0], dict): + fieldschema = bigquery.SchemaField(fname, "RECORD") + fieldschema.mode = "REPEATED" + mylist = sobject + head = mylist[0] + fieldschema.fields = create_schema( + head, + schema_depth + 1, + tableId=tableId, + path=path + "[]", + policy_tag_callback=policy_tag_callback, + ) + else: + fieldschema = _create_field( + fname, + sobject, + mode="REPEATED", + tableid=tableId, + path=path + "[]", + policy_tag_callback=policy_tag_callback, + ) + elif isinstance(sobject, dict): + fieldschema = bigquery.SchemaField(fname, "RECORD") + fieldschema.fields = create_schema( + sobject, + schema_depth + 1, + tableId=tableId, + path=path, + policy_tag_callback=policy_tag_callback, + ) else: - _, field = self.map_policy_tag(field) - nschema.append(field) - return nschema + fieldschema = _create_field( + fname, + sobject, + tableid=tableId, + path=path, + policy_tag_callback=policy_tag_callback, + ) - def map_policy_tag(self, field, dst_tgt=None): - """ - Method that tages a policy tags and will remap - :param field: The field to map - :param dst_tgt: The destination tag - :return: The new policy tag - """ - change = 0 - if field.field_type == "RECORD": - schema = self.map_schema_policy_tags(field.fields) - if field.fields != schema: - change = 1 - tmp_field = field.to_api_repr() - tmp_field["fields"] = [i.to_api_repr() for i in schema] - field = bigquery.schema.SchemaField.from_api_repr(tmp_field) + if dschema is not None: + dschema["simple"] = True + return [fieldschema] else: - if field.policy_tags is not None: - field_api_repr = field.to_api_repr() - tags = field_api_repr["policyTags"]["names"] - ntag = [] - for tag in tags: - new_tag = self.map_policy_tag_string(tag) - if new_tag is None: - field_api_repr.pop("policyTags", None) - break - ntag.append(new_tag) - if "policyTags" in field_api_repr: - field_api_repr["policyTags"]["names"] = ntag - field = bigquery.schema.SchemaField.from_api_repr(field_api_repr) - if dst_tgt is None or dst_tgt.policy_tags != field.policy_tags: - change = 1 - return change, bigquery.schema.SchemaField.from_api_repr(field_api_repr) + return [] - return change, field + return schema - def map_policy_tag_string(self, src_tag): - """ - This method maps a source tag to a destination policy tag - :param src_tag: The starting table column tags - :return: The expected destination tags - """ - # if same region and src has atag just simply reuse the tag - if self._same_region and not self._src_policy_tags: - return src_tag - try: - # look for the tag in the src destination map - return self._dst_policy_tags[self._src_policy_tags.index(src_tag)] - # if it doesnt return None as the tag - except ValueError: - return None - def base_predicates(self, retpartition): - actual_basepredicates = [] - for predicate in self.__base_predicates: - actual_basepredicates.append(predicate.format(retpartition=retpartition)) - return actual_basepredicates +# convert a dict and 
a schema object into a tuple (field order follows the schema)
+def dict_plus_schema_2_tuple(data, schema):
+    """
+    :param data: the dict to convert
+    :param schema: a list of bigquery SchemaField objects describing the dict
+    :return: a tuple of values ordered to match the schema
+    """
+    otuple = []

+    # must iterate through the schema and add Nones so the schema dominates
+    for schema_item in schema:
+        value = None
+        if data is not None and schema_item.name in data:
+            value = data[schema_item.name]
+        if schema_item.field_type == "RECORD":
+            ttuple = []
+            if schema_item.mode != "REPEATED" or value is None:
+                value = [value]
+            for value_item in value:
+                value = dict_plus_schema_2_tuple(value_item, schema_item.fields)
+                ttuple.append(value)
+            value = ttuple
+        otuple.append(value)

+    return tuple(otuple)


+# assumes a list containing lists of lists for structs and
+# arrays, but for an array of structs the value is always an array
+def tuple_plus_schema_2_dict(data, schema):
+    """
+    :param data: the tuple to convert
+    :param schema: a list of bigquery SchemaField objects describing the data
+    :return: a dict keyed by field name
+    """
+    rdata = {}
+    for schema_item, value in zip(schema, data):
+        if schema_item.field_type == "RECORD":
+            ldata = []
+            if schema_item.mode == "REPEATED":
+                llist = value
+            else:
+                llist = [value]
+            for list_item in llist:
+                ldata.append(tuple_plus_schema_2_dict(list_item, schema_item.fields))
+            if schema_item.mode == "REPEATED":
+                value = ldata
+            else:
+                value = ldata[0]
+        rdata[schema_item.name] = value

+    return rdata


+def gen_template_dict(schema):
+    """
+
+    :param schema: a REST representation of google big query table fields from which to create a
+    template json object
+    :return: a template dict with a default value for every field type
+    """
+    rdata = {}
+    for schema_item in schema:
+        value = None
+        if schema_item.field_type == "RECORD":
+            tvalue = gen_template_dict(schema_item.fields)
+            if schema_item.mode == "REPEATED":
+                value = [tvalue]
+            else:
+                value = tvalue
+        elif 
schema_item.field_type == "INTEGER": + value = 0 + elif schema_item.field_type == "BOOLEAN": + value = False + elif schema_item.field_type == "FLOAT": + value = 0.0 + elif schema_item.field_type == "STRING": + value = "" + elif schema_item.field_type == "TIMESTAMP": + value = datetime.utcnow() + elif schema_item.field_type == "DATE": + value = date.today() + elif schema_item.field_type == "TIME": + value = datetime.utcnow().time() + elif schema_item.field_type == "BYTES": + value = b"\x00" + else: + raise UnexpectedType(str(type(schema_item))) + rdata[schema_item.name] = value - @property - def blob_rewrite_retried_exceptions(self): - return self.__blob_rewrite_retried_exceptions + return rdata - @property - def blob_rewrite_unretryable_exceptions(self): - return self.__blob_rewrite_unretryable_exceptions - def increment_blob_rewrite_retried_exceptions(self): - self.__blob_rewrite_retried_exceptions += 1 +def to_dict(schema): + field_member = { + "name": schema.name, + "type": schema.field_type, + "description": schema.description, + "mode": schema.mode, + "fields": None, + } + if schema.fields is not None: + fields_to_append = [] + for field_item in sorted(schema.fields, key=lambda x: x.name): + fields_to_append.append(to_dict(field_item)) + field_member["fields"] = fields_to_append + return field_member - def increment_blob_rewrite_unretryable_exceptions(self): - self.__blob_rewrite_unretryable_exceptions += 1 - @property - def models_synced(self): - return self.__models_synced +def calc_field_depth(fieldlist, depth=0): + max_depth = depth + recursive_depth = depth + for i in fieldlist: + if "fields" in i: + recursive_depth = calc_field_depth(i["fields"], depth + 1) + if recursive_depth > max_depth: + max_depth = recursive_depth + return max_depth - def increament_models_synced(self): - self.__models_synced += 1 - @property - def models_failed_sync(self): - return self.__models_failed_sync +def trunc_field_depth(fieldlist, maxdepth, depth=0): + new_field = [] + if depth <= maxdepth: + for i in fieldlist: + new_field.append(i) + if "fields" in i: + if depth == maxdepth: + # json.JSONEncoder().encode(fieldlist) + i["type"] = "STRING" + i.pop("fields", None) + else: + i["fields"] = trunc_field_depth(i["fields"], maxdepth, depth + 1) - def increment_models_failed_sync(self): - self.__models_failed_sync += 1 + return new_field - @property - def models_avoided(self): - return self.__models_avoided - def increment_models_avoided(self): - self.__models_avoided += 1 +def match_and_addtoschema(objtomatch, schema, evolved=False, path="", logger=None): + pretty_printer = pprint.PrettyPrinter(indent=4) + poplist = {} - @property - def routines_synced(self): - return self.__routines_synced + for keyi in objtomatch: + # Create schema does this adjustment so we need to do same in actual object + thekey = INVALIDBQFIELDCHARS.sub("_", keyi) + # Work out if object keys have invalid values and n + if thekey != keyi: + poplist[keyi] = thekey + matchstruct = False + # look for bare list should not have any if known about + # big query cannot hande bare lists + # so to alow schema evoution MUST be removed + # this test if we have a list and a value in it is it a bare type i.e.not a dictionary + # if it is not a dictionary use bare type ist method to cnvert to a dictionary + # where object vallue is a singe key in a dict of value + # this changes each object as well meaning they will load into the evolved schema + # we call this with log error false as this method checks if the key exists and + # if the object 
is a list and lengh > 0 and if the object at the end is dict or not only + # converts if not a dict + # this is important optimisation as if we checked here it would be a double check + # as lots of objects this overhead is imprtant to minimise hence why this + # looks like it does + do_bare_type_list(objtomatch, keyi, "value") + for schema_item in schema: + if thekey == schema_item.name: + if schema_item.field_type == "RECORD": + if schema_item.mode == "REPEATED": + subevolve = evolved + for listi in objtomatch[keyi]: + # TODO hack to modify fields as .fields is immutable since version + # 0.28 and later but not + # in docs!! + schema_item._fields = list(schema_item.fields) + tsubevolve = match_and_addtoschema( + listi, + schema_item.fields, + evolved=evolved, + path=path + "." + thekey, + ) + if not subevolve and tsubevolve: + subevolve = tsubevolve + evolved = subevolve + else: + # TODO hack to modify fields as .fields is immutable since version 0.28 + # and later but not in + # docs!! + schema_item._fields = list(schema_item.fields) + evolved = match_and_addtoschema( + objtomatch[keyi], schema_item.fields, evolved=evolved + ) + matchstruct = True + break + if matchstruct: + continue - @property - def query_cmek(self): - return self._query_cmek + # Construct addition to schema here based on objtomatch[keyi] schema or object type + # append to the schema list + try: + toadd = create_schema(objtomatch[keyi], fname=keyi) + except Exception: + raise SchemaMutationError(str(objtomatch), keyi, path) - def increment_routines_synced(self): - self.__routines_synced += 1 + if toadd is not None: + schema.extend(toadd) + if logger is not None: + logger.warning( + "Evolved path = {}, struct={}".format( + path + "." + thekey, pretty_printer.pformat(objtomatch[keyi]) + ) + ) + evolved = True - def increment_routines_avoided(self): - self.__routines_avoided += 1 + # If values of keys did need changing change them + if len(poplist): + for pop_item in poplist: + objtomatch[poplist[pop_item]] = objtomatch[pop_item] + objtomatch.pop(pop_item, None) - @property - def routines_failed_sync(self): - return self.__routines_failed_sync + return evolved - @property - def routines_avoided(self): - return self.__routines_avoided - def increment_rows_avoided(self): - self.__routines_avoided += 1 +def do_bare_type_list(adict, key, detail, logger=None): + """ + Converts a list that is pointed to be a key in a dctionary from + non dictionary object to dictionary object. We do this as bare types + are not allowed in BQ jsons structures. So structures of type - def increment_routines_failed_sync(self): - self.__routines_failed_sync += 1 + "foo":[ 1,2,3 ] - @property - def bytes_copied_across_region(self): - return self.__load_input_file_bytes + to - @property - def files_copied_across_region(self): - return self.__load_input_files + "foo":[{"detail":1},{"detail":2},{"detail":3}] - @property - def bytes_copied(self): - return self.__load_output_bytes + Args: + adict: The dictionary the key of the list object is in. This object is modified so mutated. + key: The key name of the list if it does not exist this does nothing. if the item at the + key is not a list it + does nothing if length of list is 0 this does nothing + detail: The name of the field in new sub dictionary of each object - def increment_load_input_file_bytes(self, value): - self.__load_input_file_bytes += value - def increment_load_input_files(self, value): - self.__load_input_files += value + Returns: + Nothing. 
- def increment_load_output_bytes(self, value): - self.__load_output_bytes += value + Raises: + Nothing + """ + try: + if key in adict: + if key in adict and isinstance(adict[key], list) and len(adict[key]) > 0: + if not isinstance(adict[key][0], dict): + new_list = [] + for list_item in adict[key]: + new_list.append({detail: list_item}) + adict[key] = new_list + else: + if logger is not None: + tbs = traceback.extract_stack() + tbsflat = "\n".join(map(str, tbs)) + logger.error( + "Bare list for key {} in dict {} expected a basic type not converting " + "{}".format(key, str(adict), tbsflat) + ) + except Exception: + raise UnexpectedDict( + "Bare list for key {} in dict {} expected a basic type not converting".format( + key, str(adict) + ) + ) - @property - def start_time(self): - return self.__start_time - @property - def end_time(self): - if self.__end_time is None and self.__start_time is not None: - return datetime.utcnow() - return self.__end_time +def recurse_and_add_to_schema(schema, oschema): + changes = False - @property - def sync_time_seconds(self): - if self.__start_time is None: - return None - return (self.__end_time - self.__start_time).seconds + # Minimum is new schema now this can have less than old + wschema = copy.deepcopy(schema) + + # Everything in old schema stays as a patch + for output_schema_item in oschema: + nschema = [] + # Look for + for new_schema_item in wschema: + if output_schema_item["name"].lower() == new_schema_item.name.lower(): + if output_schema_item["type"] == "RECORD": + rchanges, output_schema_item["fields"] = recurse_and_add_to_schema( + new_schema_item.fields, output_schema_item["fields"] + ) + if rchanges and not changes: + changes = rchanges + else: + nschema.append(new_schema_item) + wschema = nschema + + # Now just has what remain in it. 
+ for wsi in wschema: + changes = True + oschema.append(to_dict(wsi)) + + return (changes, oschema) - def start_sync(self): - self.__start_time = datetime.utcnow() - def end_sync(self): - self.__end_time = datetime.utcnow() +FSLST = """#standardSQL +SELECT + ut.*, + fls.firstSeenTime, + fls.lastSeenTime, + fls.numSeen +FROM `{0}.{1}.{2}` as ut +JOIN ( + SELECT + id, + min({4}) AS firstSeenTime, + max({4}) AS lastSeenTime, + COUNT(*) AS numSeen + FROM `{0}.{1}.{2}` + GROUP BY + 1) AS fls +ON fls.id = ut.id AND fls.{3} = {4} +""" +FSLSTDT = ( + "View that shows {} captured values of underlying table for object of a " + "given non repeating key " + "of 'id' {}.{}.{}" +) - @property - def query_cache_hits(self): - return self.__query_cache_hits - def increment_cache_hits(self): - self.__query_cache_hits += 1 +def gen_diff_views( + project, + dataset, + table, + schema, + description="", + intervals=None, + hint_fields=None, + hint_mutable_fields=True, + time_expr=None, + fieldsappend=None, +): + """ - @property - def total_bytes_processed(self): - return self.__total_bytes_processed + :param project: google project id of underlying table + :param dataset: google dataset id of underlying table + :param table: the base table to do diffs (assumes each time slaice is a view of what data + looked like)) + :param schema: the schema of the base table + :param description: a base description for the views + :param intervals: a list of form [] + :param hint_fields: + :param time_expr: + :param fieldsappend: + :return: + """ - def increment_total_bytes_processed(self, total_bytes_processed): - self.__total_bytes_processed += total_bytes_processed + views = [] + fieldsnot4diff = [] + if intervals is None: + intervals = [ + {"day": "1 DAY"}, + {"week": "7 DAY"}, + {"month": "30 DAY"}, + {"fortnight": "14 DAY"}, + ] - @property - def total_bytes_billed(self): - return self.__total_bytes_processed + if time_expr is None: + time_expr = "_PARTITIONTIME" + fieldsnot4diff.append("scantime") + if isinstance(fieldsappend, list): + for fdiffi in fieldsappend: + fieldsnot4diff.append(fdiffi) + if hint_fields is None: + hint_fields = [ + "creationTime", + "usage", + "title", + "description", + "preferred", + "documentationLink", + "discoveryLink", + "numLongTermBytes", + "detailedStatus", + "lifecycleState", + "size", + "md5Hash", + "crc32c", + "timeStorageClassUpdated", + "deleted", + "networkIP", + "natIP", + "changePasswordAtNextLogin", + "status", + "state", + "substate", + "stateStartTime", + "metricValue", + "requestedState", + "statusMessage", + "numWorkers", + "currentStateTime", + "currentState", + "lastLoginTime", + "lastViewedByMeDate", + "modifiedByMeDate", + "etag", + "servingStatus", + "lastUpdated", + "updateTime", + "lastModified", + "lastModifiedTime", + "timeStorageClassUpdated", + "updated", + "numRows", + "numBytes", + "numUsers", + "isoCountryCodes", + "countries", + "uriDescription", + "riskScore", + "controlId", + "resolutionDate", + ] - def increment_total_bytes_billed(self, total_bytes_billed): - self.__total_bytes_billed += total_bytes_billed + fqtablename = "{}.{}.{}".format(project, dataset, table) + basediffview = table + "db" + basefromclause = "\nfrom `{}` as {}".format(fqtablename, "ta" + table) + baseselectclause = """#standardSQL +SELECT + {} AS scantime, + xxrownumbering.partRowNumber""".format( + time_expr + ) + baseendselectclause = """ + JOIN ( + SELECT + scantime, + ROW_NUMBER() OVER(ORDER BY scantime) AS partRowNumber + FROM ( + SELECT + DISTINCT {time_expr} AS scantime, + 
FROM + `{project}.{dataset}.{table}`)) AS xxrownumbering + ON + {time_expr} = xxrownumbering.scantime + """.format( + time_expr=time_expr, project=project, dataset=dataset, table=table + ) - @property - def copy_fails(self): - return self.__copy_fails + curtablealias = "ta" + table + fieldprefix = "" + aliasstack = [] + fieldprefixstack = [] + fields4diff = [] - @property - def copy_access(self): - return self.__copy_access + # fields to ignore as in each snapshot and different even if content is the same + fields_update_only = [] + aliasnum = 1 - @copy_access.setter - def copy_access(self, value): - self.__copy_access = value + basedata = { + "select": baseselectclause, + "from": basefromclause, + "aliasnum": aliasnum, + } - def increment_copy_fails(self): - self.__copy_fails += 1 + def recurse_diff_base(schema, fieldprefix, curtablealias): + # pretty_printer = pprint.PrettyPrinter(indent=4) - @property - def load_fails(self): - return self.__load_fails + for schema_item in sorted(schema, key=lambda x: x.name): + if schema_item.name in fieldsnot4diff: + continue + # field names can only be up o 128 characters long + if len(fieldprefix + schema_item.name) > 127: + raise BQHittingQueryGenerationQuotaLimit( + "Field alias is over 128 bytes {} aborting code generation".format( + fieldprefix + schema_item.name + ) + ) + if schema_item.mode != "REPEATED": + if schema_item.field_type == "STRING": + basefield = ',\n ifnull({}.{},"None") as `{}`'.format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "BOOLEAN": + basefield = ",\n ifnull({}.{},False) as `{}`".format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "INTEGER": + basefield = ",\n ifnull({}.{},0) as `{}`".format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "FLOAT": + basefield = ",\n ifnull({}.{},0.0) as `{}`".format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "DATE": + basefield = ",\n ifnull({}.{},DATE(1970,1,1)) as `{}`".format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "DATETIME": + basefield = ( + ",\n ifnull({}.{},DATETIME(1970,1,1,0,0,0)) as `{}`".format( + curtablealias, + schema_item.name, + fieldprefix + schema_item.name, + ) + ) + elif schema_item.field_type == "TIMESTAMP": + basefield = ( + ',\n ifnull({}.{},TIMESTAMP("1970-01-01T00:00:00Z")) as `{' + "}`".format( + curtablealias, + schema_item.name, + fieldprefix + schema_item.name, + ) + ) + elif schema_item.field_type == "TIME": + basefield = ",\n ifnull({}.{},TIME(0,0,0)) as `{}`".format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "BYTES": + basefield = ',\n ifnull({}.{},b"\x00") as `{}`'.format( + curtablealias, schema_item.name, fieldprefix + schema_item.name + ) + elif schema_item.field_type == "RECORD": + aliasstack.append(curtablealias) + fieldprefixstack.append(fieldprefix) + fieldprefix = fieldprefix + schema_item.name + if schema_item.mode == "REPEATED": + oldalias = curtablealias + curtablealias = "A{}".format(basedata["aliasnum"]) + basedata["aliasnum"] = basedata["aliasnum"] + 1 + + basedata["from"] = basedata[ + "from" + ] + "\nLEFT JOIN UNNEST({}) as {}".format( + oldalias + "." + schema_item.name, curtablealias + ) + + else: + curtablealias = curtablealias + "." 
+ schema_item.name + recurse_diff_base(schema_item.fields, fieldprefix, curtablealias) + curtablealias = aliasstack.pop() + fieldprefix = fieldprefixstack.pop() + continue + else: + aliasstack.append(curtablealias) + fieldprefixstack.append(fieldprefix) + fieldprefix = fieldprefix + schema_item.name + oldalias = curtablealias + curtablealias = "A{}".format(basedata["aliasnum"]) + basedata["aliasnum"] = basedata["aliasnum"] + 1 + basedata["from"] = basedata[ + "from" + ] + "\nLEFT JOIN UNNEST({}) as {}".format( + oldalias + "." + schema_item.name, curtablealias + ) + if schema_item.field_type == "STRING": + basefield = ',\n ifnull({},"None") as {}'.format( + curtablealias, fieldprefix + ) + elif schema_item.field_type == "BOOLEAN": + basefield = ",\n ifnull({},False) as {}".format( + curtablealias, fieldprefix + ) + elif schema_item.field_type == "INTEGER": + basefield = ",\n ifnull({},0) as {}".format( + curtablealias, fieldprefix + ) + elif schema_item.field_type == "FLOAT": + basefield = ",\n ifnull({},0.0) as {}".format( + curtablealias, fieldprefix + ) + elif schema_item.field_type == "DATE": + basefield = ",\n ifnull({},DATE(1970,1,1)) as {}".format( + curtablealias, fieldprefix + ) + elif schema_item.field_type == "DATETIME": + basefield = ( + ",\n ifnull({},DATETIME(1970,1,1,0,0,0)) as {}".format( + curtablealias, fieldprefix + ) + ) + elif schema_item.field_type == "TIME": + basefield = ",\n ifnull({},TIME(0,0,0)) as {}".format( + curtablealias, fieldprefix + ) + elif schema_item.field_type == "BYTES": + basefield = ',\n ifnull({},b"\x00") as {}'.format( + curtablealias, fieldprefix + ) + if schema_item.field_type == "RECORD": + recurse_diff_base(schema_item.fields, fieldprefix, curtablealias) + else: + # as an array has to be a diff not an update + fields4diff.append(fieldprefix) + basedata["select"] = basedata["select"] + basefield + curtablealias = aliasstack.pop() + fieldprefix = fieldprefixstack.pop() + continue - def increment_load_fails(self): - self.__load_fails += 1 + if hint_mutable_fields: + update_only = False + else: + update_only = True + if schema_item.name in hint_fields: + if hint_mutable_fields: + update_only = True + else: + update_only = False + if update_only: + fields_update_only.append(fieldprefix + schema_item.name) + else: + fields4diff.append(fieldprefix + schema_item.name) + basedata["select"] = basedata["select"] + basefield + return - @property - def extract_fails(self): - return self.__extract_fails + try: + recurse_diff_base(schema, fieldprefix, curtablealias) + allfields = fields4diff + fields_update_only + basechangeselect = basedata["select"] + basedata["from"] + baseendselectclause + joinfields = "" + if len(fields4diff) > 0 and len(fields_update_only) > 0: + joinfields = "\n UNION ALL\n" + auditchangequery = AUDITCHANGESELECT.format( + mutatedimmutablefields="\n UNION ALL".join( + [ + TEMPLATEMUTATEDIMMUTABLE.format(fieldname=field) + for field in fields4diff + ] + ), + mutablefieldchanges="\n UNION ALL".join( + [ + TEMPLATEMUTATEDFIELD.format(fieldname=field) + for field in fields_update_only + ] + ), + beforeorafterfields=",\n".join( + [TEMPLATEBEFOREORAFTER.format(fieldname=field) for field in allfields] + ), + basechangeselect=basechangeselect, + immutablefieldjoin="AND ".join( + [ + TEMPLATEFORIMMUTABLEJOINFIELD.format(fieldname=field) + for field in fields4diff + ] + ), + avoidlastpredicate=AVOIDLASTSETINCROSSJOIN.format( + time_expr=time_expr, project=project, dataset=dataset, table=table + ), + joinfields=joinfields, + ) - def 
increment_extract_fails(self): - self.__extract_fails += 1 + if len(basechangeselect) > 256 * 1024: + raise BQHittingQueryGenerationQuotaLimit( + "Query {} is over 256kb".format(basechangeselect) + ) + views.append( + { + "name": basediffview, + "query": basechangeselect, + "description": "View used as basis for diffview:" + description, + } + ) + if len(auditchangequery) > 256 * 1024: + raise BQHittingQueryGenerationQuotaLimit( + "Query {} is over 256kb".format(auditchangequery) + ) + views.append( + { + "name": "{}diff".format(table), + "query": auditchangequery, + "description": "View calculates what has changed at what time:" + + description, + } + ) - @property - def check_depth(self): - return self.__check_depth + refbasediffview = "{}.{}.{}".format(project, dataset, basediffview) - @check_depth.setter - def check_depth(self, value): - self.__check_depth = value + # Now fields4 diff has field sto compare fieldsnot4diff appear in select but are not + # compared. + # basic logic is like below + # + # select action (a case statement but "Added","Deleted","Sames") + # origfield, + # lastfield, + # if origfield != lastfield diff = 1 else diff = 0 + # from diffbaseview as orig with select of orig timestamp + # from diffbaseview as later with select of later timestamp + # This template logic is then changed for each interval to actually generate concrete views - @property - def views_failed_sync(self): - return self.__views_failed_sync + # mutatedimmutablefields = "" + # mutablefieldchanges = "" + # beforeorafterfields = "" + # basechangeselect = "" + # immutablefieldjoin = "" - def increment_views_failed_sync(self): - self.__views_failed_sync += 1 + diffviewselectclause = """#standardSQL +SELECT + o.scantime as origscantime, + l.scantime as laterscantime,""" + diffieldclause = "" + diffcaseclause = "" + diffwhereclause = "" + diffviewfromclause = """ + FROM (SELECT + * + FROM + `{0}` + WHERE + scantime = ( + SELECT + MAX({1}) + FROM + `{2}.{3}.{4}` + WHERE + {1} < ( + SELECT + MAX({1}) + FROM + `{2}.{3}.{4}`) + AND + {1} < TIMESTAMP_SUB(CURRENT_TIMESTAMP(),INTERVAL %interval%) ) ) o +FULL OUTER JOIN ( + SELECT + * + FROM + `{0}` + WHERE + scantime =( + SELECT + MAX({1}) + FROM + `{2}.{3}.{4}` )) l +ON +""".format( + refbasediffview, time_expr, project, dataset, table + ) - @property - def tables_failed_sync(self): - return self.__tables_failed_sync + for f4i in fields4diff: + diffieldclause = ( + diffieldclause + + ",\n o.{} as orig{},\n l.{} as later{},\n case " + "when o.{} = l.{} " + "then 0 else 1 end as diff{}".format(f4i, f4i, f4i, f4i, f4i, f4i, f4i) + ) + if diffcaseclause == "": + diffcaseclause = """ + CASE + WHEN o.{} IS NULL THEN 'Added' + WHEN l.{} IS NULL THEN 'Deleted' + WHEN o.{} = l.{} """.format( + f4i, f4i, f4i, f4i + ) + else: + diffcaseclause = diffcaseclause + "AND o.{} = l.{} ".format(f4i, f4i) - def increment_tables_failed_sync(self): - self.__tables_failed_sync += 1 + if diffwhereclause == "": + diffwhereclause = " l.{} = o.{}".format(f4i, f4i) + else: + diffwhereclause = diffwhereclause + "\n AND l.{}=o.{}".format( + f4i, f4i + ) - @property - def tables_avoided(self): - return self.__tables_avoided + for f4i in fields_update_only: + diffieldclause = ( + diffieldclause + + ",\n o.{} as orig{},\n l.{} as later{},\n case " + "" + "when o.{} = l.{} " + "then 0 else 1 end as diff{}".format(f4i, f4i, f4i, f4i, f4i, f4i, f4i) + ) + diffcaseclause = diffcaseclause + "AND o.{} = l.{} ".format(f4i, f4i) - def increment_tables_avoided(self): - self.__tables_avoided += 1 + 
diffcaseclause = ( + diffcaseclause + + """THEN 'Same' + ELSE 'Updated' + END AS action""" + ) - @property - def view_avoided(self): - return self.__view_avoided + for intervali in intervals: + for keyi in intervali: + view_name = table + "diff" + keyi + view_description = ( + "Diff of {} of underlying table {} description: {}".format( + keyi, table, description + ) + ) + diff_query = ( + diffviewselectclause + + diffcaseclause + + diffieldclause + + diffviewfromclause.replace("%interval%", intervali[keyi]) + + diffwhereclause + ) - def increment_view_avoided(self): - self.__view_avoided += 1 + if len(diff_query) > 256 * 1024: + raise BQHittingQueryGenerationQuotaLimit( + "Query {} is over 256kb".format(diff_query) + ) - @property - def bytes_synced(self): - return self.__bytes_synced + views.append( + { + "name": view_name, + "query": diff_query, + "description": view_description, + } + ) - @property - def copy_types(self): - return self.__copy_types + # look for id in top level fields if exists create first seen and last seen views + for i in schema: + if i.name == "id": + fsv = FSLST.format(project, dataset, table, "firstSeenTime", time_expr) + fsd = FSLSTDT.format("first", project, dataset, table) + lsv = FSLST.format(project, dataset, table, "lastSeenTime", time_expr) + lsd = FSLSTDT.format("last", project, dataset, table) - def add_bytes_synced(self, bytes): - self.__bytes_synced += bytes + if len(fsv) > 256 * 1024: + raise BQHittingQueryGenerationQuotaLimit( + "Query {} is over 256kb".format(fsv) + ) - def update_job_stats(self, job): - """ - Given a big query job figure out what stats to process - :param job: - :return: None - """ - if isinstance(job, bigquery.QueryJob): - if job.cache_hit: - self.increment_cache_hits() - self.increment_total_bytes_billed(job.total_bytes_billed) - self.increment_total_bytes_processed(job.total_bytes_processed) + views.append({"name": table + "fs", "query": fsv, "description": fsd}) + if len(lsv) > 256 * 1024: + raise BQHittingQueryGenerationQuotaLimit( + "Query {} is over 256kb".format(lsv) + ) + views.append({"name": table + "ls", "query": lsv, "description": lsd}) + break - if isinstance(job, bigquery.CopyJob): - if job.error_result is not None: - self.increment_copy_fails() + except BQHittingQueryGenerationQuotaLimit: + pass - if isinstance(job, bigquery.LoadJob): - if job.error_result: - self.increment_load_fails() - else: - self.increment_load_input_files(job.input_files) - self.increment_load_input_file_bytes(job.input_file_bytes) - self.increment_load_output_bytes(job.output_bytes) + return views - @property - def rows_synced(self): - # as time can be different between these assume avoided is always more accurae - if self.rows_avoided > self.__rows_synced: - return self.rows_avoided - return self.__rows_synced - def add_rows_synced(self, rows): - self.__rows_synced += rows +def evolve_schema(insertobj, table, client, bigquery, logger=None): + """ - @property - def bytes_avoided(self): - return self.__bytes_avoided + :param insertobj: json object that represents schema expected + :param table: a table object from python api thats been git through client.get_table + :param client: a big query client object + :param bigquery: big query service as created with google discovery discovery.build( + "bigquery","v2") + :param logger: a google logger class + :return: evolved True or False + """ - def add_bytes_avoided(self, bytes): - self.__bytes_avoided += bytes + schema = list(table.schema) + tablechange = False - @property - def rows_avoided(self): 
- return self.__rows_avoided + evolved = match_and_addtoschema(insertobj, schema) - def add_rows_avoided(self, rows): - self.__rows_avoided += rows + if evolved: + if logger is not None: + logger.warning( + "Evolving schema as new field(s) on {}:{}.{} views with * will need " + "reapplying".format(table.project, table.dataset_id, table.table_id) + ) - @property - def tables_synced(self): - return self.__tables_synced + treq = bigquery.tables().get( + projectId=table.project, datasetId=table.dataset_id, tableId=table.table_id + ) + table_data = treq.execute() + oschema = table_data.get("schema") + tablechange, pschema = recurse_and_add_to_schema(schema, oschema["fields"]) + update = {"schema": {"fields": pschema}} + preq = bigquery.tables().patch( + projectId=table.project, + datasetId=table.dataset_id, + tableId=table.table_id, + body=update, + ) + preq.execute() + client.get_table(table) + # table.reload() - @property - def views_synced(self): - return self.__views_synced + return evolved - def increment_tables_synced(self): - self.__tables_synced += 1 - def increment_materialized_views_synced(self): - self.__materialized_views_synced += 1 +def create_default_bq_resources( + template, + basename, + project, + dataset, + location, + hint_fields=None, + hint_mutable_fields=True, + optheaddays=None, +): + """ - def increment_views_synced(self): - self.__views_synced += 1 + :param template: a template json object to create a big query schema for + :param basename: a base name of the table to create that will also be used as a basis for views + :param project: the project to create resources in + :param dataset: the datasets to create them in + :param location: The locatin + :return: a list of big query table resources as dicionaries that can be passe dto code + genearteor or used in rest + calls + """ + resourcelist = [] + table = { + "type": "TABLE", + "location": location, + "tableReference": { + "projectId": project, + "datasetId": dataset, + "tableId": basename, + }, + "timePartitioning": {"type": "DAY", "expirationMs": "94608000000"}, + "schema": {}, + } + table["schema"]["fields"] = get_bq_schema_from_json_repr(template) + resourcelist.append(table) + views = gen_diff_views( + project, + dataset, + basename, + create_schema(template), + hint_fields=hint_fields, + hint_mutable_fields=hint_mutable_fields, + ) - @property - def copy_q(self): - return self.__copy_q + head_view_format = HEADVIEW + # if a max days look back is given use it to optimise bytes in the head view + if optheaddays is not None: + head_view_format = OPTHEADVIEW - @copy_q.setter - def copy_q(self, value): - self.__copy_q = value + table = { + "type": "VIEW", + "tableReference": { + "projectId": project, + "datasetId": dataset, + "tableId": "{}head".format(basename), + }, + "view": { + "query": head_view_format.format(project, dataset, basename, optheaddays), + "useLegacySql": False, + }, + } + resourcelist.append(table) + for view_item in views: + table = { + "type": "VIEW", + "tableReference": { + "projectId": project, + "datasetId": dataset, + "tableId": view_item["name"], + }, + "view": {"query": view_item["query"], "useLegacySql": False}, + } + resourcelist.append(table) + return resourcelist - @property - def schema_q(self): - return self.__schema_q - @schema_q.setter - def schema_q(self, value): - self.__schema_q = value +class ViewCompiler(object): + def __init__(self, render_dictionary=None): + if render_dictionary is None: + render_dictionary = {} + self.view_depth_optimiser = {} + self._add_auth_view = {} + 
self._views = {} + self._render_dictionary = render_dictionary + self._lock = threading.RLock() @property - def source_location(self): - return self._source_location + def render_dictionary(self): + return self._render_dictionary - @property - def destination_location(self): - return self._destination_location + @render_dictionary.setter + def render_dictionary(self, render_dictionary): + self._render_dictionary = render_dictionary - @property - def source_bucket(self): - return self._source_bucket + def render(self, raw_sql): + template_for_render = Template(raw_sql) + return template_for_render.render(self.render_dictionary) - @property - def destination_bucket(self): - return self._destination_bucket + def add_view_to_process(self, dataset, name, sql, unnest=True, description=None): + standard_sql = True + compiled_sql = self.render(sql) + sql = compiled_sql + # prefix = "" + if sql.strip().lower().find("#legacysql") == 0: + standard_sql = False - @property - def same_region(self): - return self._same_region + splitprojectdataset = ":" + repattern = LEGACY_SQL_PDTCTABLEREGEXP + # no_auth_view = "{}:{}".format(dataset.project, dataset.dataset_id) - @property - def source_dataset_impl(self): - source_datasetref = self.source_client.dataset(self.source_dataset) - return self.source_client.get_dataset(source_datasetref) + if standard_sql: + splitprojectdataset = "." + repattern = STANDARD_SQL_PDTCTABLEREGEXP - @property - def destination_dataset_impl(self): - destination_datasetref = self.destination_client.dataset( - self.destination_dataset + key = "{}{}{}.{}".format( + dataset.project, splitprojectdataset, dataset.dataset_id, name ) - return self.destination_client.get_dataset(destination_datasetref) - @property - def query_client(self): - """ - Returns the client to be charged for analysis of comparison could be - source could be destination could be another. 
- By default it is the destination but can be overriden by passing a target project - :return: A big query client for the project to be charged - """ - warnings.filterwarnings( - "ignore", "Your application has authenticated using end user credentials" - ) - """ - Obtains a source client in the current thread only constructs a client once per thread - :return: - """ - source_client = getattr( - DefaultBQSyncDriver.threadLocal, self.__analysisproject, None - ) - if source_client is None: - setattr( - DefaultBQSyncDriver.threadLocal, - self.__analysisproject, - bigquery.Client(project=self.__analysisproject, _http=self.http), - ) - return getattr(DefaultBQSyncDriver.threadLocal, self.__analysisproject, None) + dependsOn = [] - @property - def source_client(self): - warnings.filterwarnings( - "ignore", "Your application has authenticated using end user credentials" - ) - """ - Obtains a source client in the current thread only constructs a client once per thread - :return: - """ - source_client = getattr( - DefaultBQSyncDriver.threadLocal, - self.source_project + self._source_dataset, - None, - ) - if source_client is None: - setattr( - DefaultBQSyncDriver.threadLocal, - self.source_project + self._source_dataset, - bigquery.Client(project=self.source_project, _http=self.http), - ) - return getattr( - DefaultBQSyncDriver.threadLocal, - self.source_project + self._source_dataset, - None, - ) + for project_dataset_table in re.findall( + repattern, re.sub(SQL_COMMENT_REGEXP, "", sql).replace("\n", " ") + ): + dependsOn.append(project_dataset_table) - @property - def http(self): - """ - Allow override of http transport per client - usefule for proxy handlng but should be handled by sub-classes default is do nothing - :return: - """ - return self._http + self._views[key] = { + "sql": sql, + "dataset": dataset, + "name": name, + "dependsOn": dependsOn, + "unnest": unnest, + "description": description, + } + + def compile_views(self): + for tranche in self.view_tranche: + for view in self.view_in_tranche(tranche): + view["sql"] = self.compile( + view["dataset"], + view["name"], + view["sql"], + unnest=view["unnest"], + description=view["description"], + rendered=True, + ) @property - def destination_client(self): - """ - Obtains a s destination client in current thread only constructs a client once per thread - :return: - """ - source_client = getattr( - DefaultBQSyncDriver.threadLocal, - self.destination_project + self.destination_dataset, - None, - ) - if source_client is None: - setattr( - DefaultBQSyncDriver.threadLocal, - self.destination_project + self.destination_dataset, - bigquery.Client(project=self.destination_project, _http=self.http), - ) - return getattr( - DefaultBQSyncDriver.threadLocal, - self.destination_project + self.destination_dataset, - None, - ) + def view_tranche(self): + view_tranches = self.plan_view_apply_tranches() + for tranche in view_tranches: + yield tranche - def export_import_format_supported(self, srctable, dsttable=None): - """Calculates a suitable export import type based upon schema - default mecahnism is to use AVRO and SNAPPY as parallel and fast - If though a schema type notsupported by AVRO fall back to jsonnl and gzip - """ - dst_encryption_configuration = None - if dsttable is None and srctable.encryption_configuration is not None: - dst_encryption_configuration = self.calculate_target_cmek_config( - srctable.encryption_configuration - ) - else: - dst_encryption_configuration = None + def view_in_tranche(self, tranche): + for view in sorted(tranche): + 
yield self._views[view] - return ExportImportType(srctable, dsttable, dst_encryption_configuration) + def plan_view_apply_tranches(self): + view_tranches = [] - @property - def source_project(self): - return self._source_project + max_tranche_depth = 0 + for view in self._views: + depend_depth = self.calc_view_dependency_depth(view) + self._views[view]["tranche"] = depend_depth + if depend_depth > max_tranche_depth: + max_tranche_depth = depend_depth - @property - def source_dataset(self): - return self._source_dataset + for tranche in range(max_tranche_depth + 1): + tranchelist = [] + for view in self._views: + if self._views[view]["tranche"] == tranche: + tranchelist.append(view) + view_tranches.append(tranchelist) - @property - def jobs(self): - return self.__jobs + return view_tranches - def add_job(self, job): - self.__jobs.append(job) + def calc_view_dependency_depth(self, name, depth=0): + max_depth = depth + for depends_on in self._views[name].get("dependsOn", []): + if depends_on in self._views: + retdepth = self.calc_view_dependency_depth(depends_on, depth=depth + 1) + if retdepth > max_depth: + max_depth = retdepth + return max_depth - @property - def destination_project(self): - return self._destination_project + def add_auth_view(self, project_auth, dataset_auth, view_to_authorise): + if project_auth not in self._add_auth_view: + self._add_auth_view[project_auth] = {} + if dataset_auth not in self._add_auth_view[project_auth]: + self._add_auth_view[project_auth][dataset_auth] = [] + found = False + for view in self._add_auth_view[project_auth][dataset_auth]: + if ( + view_to_authorise["tableId"] == view["tableId"] + and view_to_authorise["datasetId"] == view["datasetId"] + and view_to_authorise["projectId"] == view["projectId"] + ): + found = True + break + if not found: + self._add_auth_view[project_auth][dataset_auth].append(view_to_authorise) @property - def destination_dataset(self): - return self._destination_dataset + def projects_to_authorise_views_in(self): + for project in self._add_auth_view: + yield project - @property - def remove_deleted_tables(self): - return self._remove_deleted_tables + def datasets_to_authorise_views_in(self, project): + for dataset in self._add_auth_view.get(project, {}): + yield dataset - def day_partition_deep_check(self): - """ + def update_authorised_view_access(self, project, dataset, current_access_entries): + if project in self._add_auth_view: + if dataset in self._add_auth_view[project]: + expected_auth_view_access = self._add_auth_view[project][dataset] + managed_by_view_compiler = {} + new_access_entries = [] - :return: True if should check rows and bytes counts False if notrequired - """ - return self.__day_partition_deep_check + # iterate expected authorised views + # from this caluclate the project and datsestt we want to reset auth + # views for + for access in expected_auth_view_access: + if access["projectId"] not in managed_by_view_compiler: + managed_by_view_compiler[access["projectId"]] = {} + if ( + access["datasetId"] + not in managed_by_view_compiler[access["projectId"]] + ): + managed_by_view_compiler[access["projectId"]][ + access["datasetId"] + ] = True + new_access_entries.append( + bigquery.dataset.AccessEntry(None, "view", access) + ) - def extra_dp_compare_functions(self, table): - """ - Function to allow record comparison checks for a day partition - These shoul dbe aggregates for the partition i..e max, min, avg - Override when row count if not sufficient. 
Row count works for - tables where rows only added non deleted or removed - :return: a comma seperated extension of functions - """ - default = "" - retpartition_time = "_PARTITIONTIME" + for current_access in current_access_entries: + if current_access.entity_type == "view": + if ( + current_access.entity_id["projectId"] + in managed_by_view_compiler + and current_access.entity_id["datasetId"] + in managed_by_view_compiler[ + current_access.entity_id["projectId"] + ] + ): + continue + new_access_entries.append(current_access) + return new_access_entries - if ( - getattr(table, "time_partitioning", None) - and table.time_partitioning.field is not None - ): - retpartition_time = "TIMESTAMP({})".format(table.time_partitioning.field) + return current_access_entries + + def compile( + self, dataset, name, sql, unnest=True, description=None, rendered=False + ): + # if already done don't do again + key = dataset.project + ":" + dataset.dataset_id + "." + name + if key in self.view_depth_optimiser: + return ( + self.view_depth_optimiser[key]["prefix"] + + self.view_depth_optimiser[key]["unnested"] + ) - SCHEMA = list(table.schema) + if not rendered: + sql = self.render(sql) - # emulate nonlocal variables this works in python 2.7 and python 3.6 - aliasdict = {"alias": "", "extrajoinpredicates": ""} + standard_sql = True + # standard sql limit 1Mb + # https://cloud.google.com/bigquery/quotas + max_query_size = 256 * 1024 + compiled_sql = sql + prefix = "" - """ - Use FRAM_FINGERPRINT as hash of each value and then summed - basically a merkel function - """ + if sql.strip().lower().find("#standardsql") == 0: + prefix = "#standardSQL\n" + else: + max_query_size = 256 * 1024 + standard_sql = False + prefix = "#legacySQL\n" - def add_data_check(SCHEMA, prefix=None, depth=0): - if prefix is None: - prefix = [] - # add base table alia - prefix.append("zzz") + # get rid of nested comments as they can break this even if in a string in a query + prefix = ( + prefix + + r""" +-- =================================================================================== +-- +-- ViewCompresser Output +-- +-- \/\/\/\/\/\/Original SQL Below\/\/\/\/ +""" + ) + for line in sql.splitlines(): + prefix = prefix + "-- " + line + "\n" + prefix = ( + prefix + + r""" +-- +-- /\/\/\/\/\/\Original SQL Above/\/\/\/\ +-- +-- Compiled SQL below +-- =================================================================================== +""" + ) - expression_list = [] + splitprojectdataset = ":" + repattern = LEGACY_SQL_PDCTABLEREGEXP + no_auth_view = "{}:{}".format(dataset.project, dataset.dataset_id) - # if we are beyond check depth exit - if self.check_depth >= 0 and depth > self.check_depth: - return expression_list + if standard_sql: + splitprojectdataset = "." 
+ repattern = STANDARD_SQL_PDCTABLEREGEXP + no_auth_view = "{}.{}".format(dataset.project, dataset.dataset_id) - for field in SCHEMA: - prefix.append(field.name) - if field.mode != "REPEATED": - if self.check_depth >= 0 or ( - self.check_depth >= -1 - and ( - re.search("update.*time", field.name.lower()) - or re.search("modifi.*time", field.name.lower()) - or re.search("version", field.name.lower()) - or re.search("creat.*time", field.name.lower()) - ) - ): - if field.field_type == "STRING": - expression_list.append( - "IFNULL(`{0}`,'')".format("`.`".join(prefix)) - ) - elif field.field_type == "TIMESTAMP": - expression_list.append( - "CAST(IFNULL(`{0}`,TIMESTAMP('1970-01-01')) AS STRING)".format( - "`.`".join(prefix) - ) - ) - elif ( - field.field_type == "INTEGER" or field.field_type == "INT64" - ): - expression_list.append( - "CAST(IFNULL(`{0}`,0) AS STRING)".format( - "`.`".join(prefix) - ) - ) - elif ( - field.field_type == "FLOAT" - or field.field_type == "FLOAT64" - or field.field_type == "NUMERIC" - ): - expression_list.append( - "CAST(IFNULL(`{0}`,0.0) AS STRING)".format( - "`.`".join(prefix) - ) - ) - elif ( - field.field_type == "BOOL" or field.field_type == "BOOLEAN" - ): - expression_list.append( - "CAST(IFNULL(`{0}`,false) AS STRING)".format( - "`.`".join(prefix) - ) - ) - elif field.field_type == "BYTES": - expression_list.append( - "CAST(IFNULL(`{0}`,'') AS STRING)".format( - "`.`".join(prefix) - ) - ) - if field.field_type == "RECORD": - SSCHEMA = list(field.fields) - expression_list.extend( - add_data_check(SSCHEMA, prefix=prefix, depth=depth) - ) + # we unnest everything to get to underlying authorised views + # needed + with self._lock: + for i in self.view_depth_optimiser: + # relaces a table or view name with sql + if not standard_sql: + compiled_sql = compiled_sql.replace( + "[" + i + "]", + "( /* flattened view [-" + + i + + "-]*/ " + + self.view_depth_optimiser[i]["unnested"] + + ")", + ) else: - if field.field_type != "RECORD" and ( - self.check_depth >= 0 - or ( - self.check_depth >= -1 - and ( - re.search("update.*time", field.name.lower()) - or re.search("modifi.*time", field.name.lower()) - or re.search("version", field.name.lower()) - or re.search("creat.*time", field.name.lower()) - ) - ) - ): - # add the unnestof repeated base type can use own field name - fieldname = "{}{}".format(aliasdict["alias"], field.name) - aliasdict[ - "extrajoinpredicates" - ] = """{} -LEFT JOIN UNNEST(`{}`) AS `{}`""".format( - aliasdict["extrajoinpredicates"], field.name, fieldname - ) - if field.field_type == "STRING": - expression_list.append("IFNULL(`{0}`,'')".format(fieldname)) - elif field.field_type == "TIMESTAMP": - expression_list.append( - "CAST(IFNULL(`{0}`,TIMESTAMP('1970-01-01')) AS STRING)".format( - fieldname - ) - ) - elif ( - field.field_type == "INTEGER" or field.field_type == "INT64" - ): - expression_list.append( - "CAST(IFNULL(`{0}`,0) AS STRING)".format(fieldname) - ) - elif ( - field.field_type == "FLOAT" - or field.field_type == "FLOAT64" - or field.field_type == "NUMERIC" - ): - expression_list.append( - "CAST(IFNULL(`{0}`,0.0) AS STRING)".format(fieldname) - ) - elif ( - field.field_type == "BOOL" or field.field_type == "BOOLEAN" - ): - expression_list.append( - "CAST(IFNULL(`{0}`,false) AS STRING)".format(fieldname) - ) - elif field.field_type == "BYTES": - expression_list.append( - "CAST(IFNULL(`{0}`,'') AS STRING)".format(fieldname) - ) - if field.field_type == "RECORD": - # if we want to go deeper in checking - if depth < self.check_depth: - # need 
to unnest as repeated - if aliasdict["alias"] == "": - # uses this to name space from real columns - # bit hopeful - aliasdict["alias"] = "zzz" + compiled_sql = compiled_sql.replace( + "`" + i.replace(":", ".") + "`", + "( /* flattened view `-" + + i + + "-`*/ " + + self.view_depth_optimiser[i]["unnested"] + + ")", + ) - aliasdict["alias"] = aliasdict["alias"] + "z" + # strip off comments before analysing queries for tables used in the query and remove + # newlines + # to avoid false positives + # use set to get unique values + for project_dataset in set( + re.findall( + repattern, + re.sub(SQL_COMMENT_REGEXP, "", compiled_sql).replace("\n", " "), + ) + ): + # if is a table in same dataset nothing to authorise + if project_dataset != no_auth_view: + project_auth, dataset_auth = project_dataset.split(splitprojectdataset) + # spool the view needs authorisation for the authorisation task + self.add_auth_view( + project_auth, + dataset_auth, + { + "projectId": dataset.project, + "datasetId": dataset.dataset_id, + "tableId": name, + }, + ) + unnest = False - # so need to rest prefix for this record - newprefix = [] - newprefix.append(aliasdict["alias"]) + # and we put back original if we want to hide logic + if not unnest: + compiled_sql = sql - # add the unnest - aliasdict[ - "extrajoinpredicates" - ] = """{} -LEFT JOIN UNNEST(`{}`) AS {}""".format( - aliasdict["extrajoinpredicates"], - "`.`".join(prefix), - aliasdict["alias"], - ) + # look to keep queriesbelow maximumsize + if len(prefix + compiled_sql) > max_query_size: + # strip out comment + if standard_sql: + prefix = "#standardSQL\n" + else: + prefix = "#legacySQL\n" + # if still too big strip out other comments + # and extra space + if len(prefix + compiled_sql) > max_query_size: + nsql = "" + for line in compiled_sql.split("\n"): + # if not a comment + trimline = line.strip() + if trimline[:2] != "--": + " ".join(trimline.split()) + nsql = nsql + "\n" + trimline + compiled_sql = nsql - # add the fields - expression_list.extend( - add_data_check( - SSCHEMA, prefix=newprefix, depth=depth + 1 - ) - ) - prefix.pop() - return expression_list + # if still too big go back to original sql stripped + if len(prefix + compiled_sql) > max_query_size: + if len(sql) > max_query_size: + nsql = "" + for line in sql.split("\n"): + trimline = line.strip() + # if not a comment + if trimline[:1] != "#": + " ".join(line.split()) + nsql = nsql + "\n" + trimline + compiled_sql = nsql + else: + compiled_sql = sql - expression_list = add_data_check(SCHEMA) + with self._lock: + self.view_depth_optimiser[key] = { + "raw": sql, + "unnested": compiled_sql, + "prefix": prefix, + } - if len(expression_list) > 0: - # algorithm to compare sets - # java uses overflowing sum but big query does not allow in sum - # using average as scales sum and available as aggregate - # split top end and bottom end of hash - # that way we maximise bits and fidelity - # and - default = """, - AVG(FARM_FINGERPRINT(CONCAT({0})) & 0x0000000FFFFFFFF) as avgFingerprintLB, - AVG((FARM_FINGERPRINT(CONCAT({0})) & 0xFFFFFFF00000000) >> 32) as avgFingerprintHB,""".format( - ",".join(expression_list) - ) + return prefix + compiled_sql - predicates = self.comparison_predicates(table.table_id, retpartition_time) - if len(predicates) > 0: - aliasdict[ - "extrajoinpredicates" - ] = """{} -WHERE ({})""".format( - aliasdict["extrajoinpredicates"], ") AND (".join(predicates) - ) - return retpartition_time, default, aliasdict["extrajoinpredicates"] +def compute_region_equals_bqregion(compute_region: str, 
bq_region: str) -> bool: + if compute_region == bq_region or compute_region.lower() == bq_region.lower(): + bq_compute_region = bq_region.lower() + else: + bq_compute_region = MAPBQREGION2KMSREGION.get(bq_region, bq_region.lower()) + return compute_region.lower() == bq_compute_region - @property - def copy_data(self): - """ - True if to copy data not just structure - False just keeps structure in sync - :return: - """ - return self._copy_data - def table_data_change(self, srctable, dsttable): - """ - Method to allow customisation of detecting if table differs - default method is check rows, numberofbytes and last modified time - this exists to allow something like a query to do comparison - if you want to force a copy this should retrn True - :param srctable: - :param dsttable: - :return: - """ +def run_query( + client, + query, + logger, + desctext="", + location=None, + max_results=10000, + callback_on_complete=None, + labels=None, + params=None, + query_cmek=None, +): + """ + Runa big query query and yield on each row returned as a generator + :param client: The BQ client to use to run the query + :param query: The query text assumed standardsql unless starts with #legcaySQL + :param desctext: Some descriptive text to put out if in debug mode + :return: nothing + """ + use_legacy_sql = False + if query.lower().find("#legacysql") == 0: + use_legacy_sql = True - return False + def _get_parameter(name, argtype): + def _get_type(argtype): + ptype = None + if isinstance(argtype, six.string_types): + ptype = "STRING" + elif isinstance(argtype, int): + ptype = "INT64" + elif isinstance(argtype, float): + ptype = "FLOAT64" + elif isinstance(argtype, bool): + ptype = "BOOL" + elif isinstance(argtype, datetime): + ptype = "TIMESTAMP" + elif isinstance(argtype, date): + ptype = "DATE" + elif isinstance(argtype, bytes): + ptype = "BYTES" + elif isinstance(argtype, decimal.Decimal): + ptype = "NUMERIC" + else: + raise TypeError("Unrcognized type for qury paramter") + return ptype - def fault_barrier(self, function, *args): - """ - A fault barrie here to ensure functions called in thread - do n9t exit prematurely - :param function: A function to call - :param args: The functions arguments - :return: - """ - try: - function(*args) - except Exception: - pretty_printer = pprint.PrettyPrinter() - self.get_logger().exception( - "Exception calling function {} args {}".format( - function.__name__, pretty_printer.pformat(args) - ) - ) + if isinstance(argtype, list): + if argtype: + if isinstance(argtype[0], dict): + struct_list = [] + for item in argtype: + struct_list.append(_get_parameter(None, item)) + return bigquery.ArrayQueryParameter(name, "STRUCT", struct_list) + else: + return bigquery.ArrayQueryParameter( + name, _get_type(argtype[0]), argtype + ) + else: + return None - def update_source_view_definition(self, view_definition, use_standard_sql): - view_definition = view_definition.replace( - r"`{}.{}.".format(self.source_project, self.source_dataset), - "`{}.{}.".format(self.destination_project, self.destination_dataset), - ) - view_definition = view_definition.replace( - r"[{}:{}.".format(self.source_project, self.source_dataset), - "[{}:{}.".format(self.destination_project, self.destination_dataset), - ) - # this should not be required but seems it is - view_definition = view_definition.replace( - r"[{}.{}.".format(self.source_project, self.source_dataset), - "[{}:{}.".format(self.destination_project, self.destination_dataset), - ) - # support short names - view_definition = 
view_definition.replace( - r"{}.".format(self.source_dataset), "{}.".format(self.destination_dataset) - ) + if isinstance(argtype, dict): + struct_param = [] + for key in argtype: + struct_param.append(_get_parameter(key, argtype[key])) + return bigquery.StructQueryParameter(name, *struct_param) - return view_definition + return bigquery.ScalarQueryParameter(name, _get_type(argtype), argtype) - def calculate_target_cmek_config(self, encryption_config): - assert isinstance(encryption_config, bigquery.EncryptionConfiguration) or ( - getattr( - self.destination_dataset_impl, "default_encryption_configuration", None - ) - is not None - and self.destination_dataset_impl.default_encryption_configuration - is not None - ), ( - " To recaclculate a new encryption " - "config the original config has to be passed in and be of class " - "bigquery.EncryptionConfig" + query_parameters = [] + if params is not None: + # https://cloud.google.com/bigquery/docs/parameterized-queries#python + # named parameters + if isinstance(params, dict): + for key in params: + param = _get_parameter(key, params[key]) + if param is not None: + query_parameters.append(param) + # positional paramters + elif isinstance(params, list): + for p in params: + param = _get_parameter(None, p) + if param is not None: + query_parameters.append(param) + else: + raise TypeError("Query parameter not a dict or a list") + + job_config = bigquery.QueryJobConfig(query_parameters=query_parameters) + job_config.maximum_billing_tier = 10 + job_config.use_legacy_sql = use_legacy_sql + if query_cmek is not None: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=query_cmek) ) + if labels is not None: + job_config.labels = labels - # if destination dataset has default kms key already, just use the same - if ( - getattr( - self.destination_dataset_impl, "default_encryption_configuration", None - ) - is not None - and self.destination_dataset_impl.default_encryption_configuration - is not None - ): - return self.destination_dataset_impl.default_encryption_configuration + query_job = client.query(query, job_config=job_config, location=location) - # if a global key or same region we are good to go - if ( - self.same_region - or encryption_config.kms_key_name.find("/locations/global/") != -1 - ): - # strip off version if exists - return bigquery.EncryptionConfiguration( - get_kms_key_name(encryption_config.kms_key_name) - ) + pretty_printer = pprint.PrettyPrinter(indent=4) + results = False + while True: + query_job.reload() # Refreshes the state via a GET request. 
- # if global key can still be used - # if comparing table key get rid fo version - parts = get_kms_key_name(encryption_config.kms_key_name).split("/") - parts[3] = MAPBQREGION2KMSREGION.get( - self.destination_location, self.destination_location.lower() - ) + if query_job.state == "DONE": + if query_job.error_result: + errtext = "Query error {}{}".format( + pretty_printer.pformat(query_job.error_result), + pretty_printer.pformat(query_job.errors), + ) + logger.error(errtext, exc_info=True) + raise BQQueryError(query, desctext, errtext) + else: + results = True + break - return bigquery.encryption_configuration.EncryptionConfiguration( - kms_key_name="/".join(parts) - ) + if results: + # query_results = query_job.results() - def copy_access_to_destination(self): - # for those not created compare data structures - # copy data - # compare data - # copy views - # copy dataset permissions - if self.copy_access: - src_dataset = self.source_client.get_dataset( - self.source_client.dataset(self.source_dataset) - ) - dst_dataset = self.destination_client.get_dataset( - self.destination_client.dataset(self.destination_dataset) - ) - access_entries = src_dataset.access_entries - dst_access_entries = [] - for access in access_entries: - newaccess = access - if access.role is None: - # if not copying views these will fail - if "VIEW" not in self.copy_types: - continue - newaccess = self.create_access_view(access.entity_id) - dst_access_entries.append(newaccess) - dst_dataset.access_entries = dst_access_entries + # Drain the query results by requesting + # a page at a time. + # page_token = None - fields = ["access_entries"] - if dst_dataset.description != src_dataset.description: - dst_dataset.description = src_dataset.description - fields.append("description") + for irow in query_job.result(): + yield irow - if dst_dataset.friendly_name != src_dataset.friendly_name: - dst_dataset.friendly_name = src_dataset.friendly_name - fields.append("friendly_name") + if callback_on_complete is not None and callable(callback_on_complete): + callback_on_complete(query_job) - if ( - dst_dataset.default_table_expiration_ms - != src_dataset.default_table_expiration_ms - ): - dst_dataset.default_table_expiration_ms = ( - src_dataset.default_table_expiration_ms - ) - fields.append("default_table_expiration_ms") + return - if getattr(dst_dataset, "default_partition_expiration_ms", None): - if ( - dst_dataset.default_partition_expiration_ms - != src_dataset.default_partition_expiration_ms - ): - dst_dataset.default_partition_expiration_ms = ( - src_dataset.default_partition_expiration_ms - ) - fields.append("default_partition_expiration_ms") - # compare 2 dictionaries that are simple key, value - x = dst_dataset.labels - y = src_dataset.labels +class ExportImportType(object): + """ + Class that calculate the export import types that are best to use to copy the table + passed in initialiser across region. 
+    """

-            # get shared key values
-            shared_items = {k: x[k] for k in x if k in y and x[k] == y[k]}

+    def __init__(self, srctable, dsttable=None, dst_encryption_configuration=None):
+        """
+        Construct an ExportImportType around a table that describes the best format to copy
+        this table across regions and how to compress it
+        :param srctable: A big query table implementation that is the source of the copy
+        :param dsttable: optional target definition; if not specified the specification of the
+        source is used, if provided it dominates
+        :param dst_encryption_configuration: if set overrides keys on tables some calc has driven
+        this
+        """
+        assert isinstance(srctable, bigquery.Table), (
+            "Export Import Type MUST be constructed with a " "bigquery.Table object"
+        )
+        assert dsttable is None or isinstance(dsttable, bigquery.Table), (
+            "Export Import dsttable Type MUST "
+            "be constructed with a "
+            "bigquery.Table object or None"
+        )

-            # must be same size and values if not set labels
-            if len(dst_dataset.labels) != len(src_dataset.labels) or len(
-                shared_items
-            ) != len(src_dataset.labels):
-                dst_dataset.labels = src_dataset.labels
-                fields.append("labels")

+        if dsttable is None:
+            self.__table = srctable
+        else:
+            self.__table = dsttable

-            if getattr(dst_dataset, "default_encryption_configuration", None):
-                if not (
-                    src_dataset.default_encryption_configuration is None
-                    and dst_dataset.default_encryption_configuration is None
-                ):
-                    # if src_dataset.default_encryption_configuration is None:
-                    #     dst_dataset.default_encryption_configuration = None
-                    # else:
-                    #     dst_dataset.default_encryption_configuration = \
-                    #         self.calculate_target_cmek_config(
-                    #             src_dataset.default_encryption_configuration)

+        self._dst_encryption_configuration = None
+        if dst_encryption_configuration is not None:
+            self._dst_encryption_configuration = dst_encryption_configuration

-                    # equate dest kms config to src only if it's None
-                    if dst_dataset.default_encryption_configuration is None:
-                        dst_dataset.default_encryption_configuration = (
-                            self.calculate_target_cmek_config(
-                                src_dataset.default_encryption_configuration
-                            )
-                        )
-                    fields.append("default_encryption_configuration")

+        # detect if any GEOGRAPHY or DATETIME fields
+        def _detect_non_avro_and_parquet_types(schema):
+            for field in schema:
+                if field.field_type == "GEOGRAPHY" or field.field_type == "DATETIME":
+                    return True
+                if field.field_type == "RECORD":
+                    if _detect_non_avro_and_parquet_types(list(field.fields)):
+                        return True
+            return False

-            try:
-                self.destination_client.update_dataset(dst_dataset, fields)
-            except exceptions.Forbidden:
-                self.logger.error(
-                    "Unable to det permission on {}.{} dataset as Forbidden".format(
-                        self.destination_project, self.destination_dataset
-                    )
-                )
-            except exceptions.BadRequest:
-                self.logger.error(
-                    "Unable to det permission on {}.{} dataset as BadRequest".format(
-                        self.destination_project, self.destination_dataset
-                    )
-                )

+        # detect if any DATETIME, TIME or repeated RECORD fields that parquet cannot export
+        def _detect_non_parquet_types(schema):
+            for field in schema:
+                # https://cloud.google.com/bigquery/docs/exporting-data#parquet_export_details
+                if field.field_type == "DATETIME" or field.field_type == "TIME":
+                    return True
+                # parquet only supports repeated base types
+                if field.field_type == "RECORD":
+                    if field.mode == "REPEATED":
+                        return True
+                    if _detect_non_parquet_types(list(field.fields)):
+                        return True
+            return False

-    def create_access_view(self, entity_id):

+        # go columnar compression over row as generally better
+        # as we always process whole rows 
actually not worth it but we
+        # had fun working out what's needed to support parquet
+        # we have wish list of sync with hive schema which could be parquet
+        # so let me figure out limitations of this
+        # self.__destination_format = bigquery.job.DestinationFormat.PARQUET
+        # but if that's impossible
+        # if _detect_non_parquet_types(list(srctable.schema)):
+        self.__destination_format = bigquery.job.DestinationFormat.AVRO
+        if _detect_non_avro_and_parquet_types(list(srctable.schema)):
+            self.__destination_format = (
+                bigquery.job.DestinationFormat.NEWLINE_DELIMITED_JSON
+            )
+
+    @property
+    def destination_format(self):
         """
-        Convert an old view authorised view
-        to a new one i.e. change project id
+        The destination format to use for exports for this table when copying across regions
+        :return: a bigquery.job.DestinationFormat enumerator
+        """
+        return self.__destination_format

-        :param entity_id:
-        :return: a view {
-            ... 'projectId': 'my-project',
-            ... 'datasetId': 'my_dataset',
-            ... 'tableId': 'my_table'
-            ... }
+    @property
+    def source_format(self):
+        """
+        The source format to use to load this table in the destination region
+        :return: a bigquery.job.SourceFormat enumerator that matches the preferred export format
         """
+        # only support the exports that are possible
+        if self.destination_format == bigquery.job.DestinationFormat.AVRO:
+            return bigquery.job.SourceFormat.AVRO
+        if self.destination_format == bigquery.job.DestinationFormat.PARQUET:
+            return bigquery.job.SourceFormat.PARQUET
         if (
-            entity_id["projectId"] == self.source_project
-            and entity_id["datasetId"] == self.source_dataset
+            self.destination_format
+            == bigquery.job.DestinationFormat.NEWLINE_DELIMITED_JSON
         ):
-            entity_id["projectId"] = self.destination_project
-            entity_id["datasetId"] = self.destination_dataset
-
-        return bigquery.AccessEntry(None, "view", entity_id)
+            return bigquery.job.SourceFormat.NEWLINE_DELIMITED_JSON
+        if self.destination_format == bigquery.job.DestinationFormat.CSV:
+            return bigquery.job.SourceFormat.CSV

     @property
-    def logger(self):
-        return self.__logger
+    def source_file_extension(self):
+        if self.destination_format == bigquery.job.DestinationFormat.AVRO:
+            return ".avro"
+        if self.destination_format == bigquery.job.DestinationFormat.PARQUET:
+            return ".parquet"
+        if (
+            self.destination_format
+            == bigquery.job.DestinationFormat.NEWLINE_DELIMITED_JSON
+        ):
+            return ".jsonl"
+        if self.destination_format == bigquery.job.DestinationFormat.CSV:
+            return ".csv"

-    @logger.setter
-    def logger(self, alogger):
-        self.__logger = alogger
+    @property
+    def compression_format(self):
+        """
+        The calculated compression type to use based on the supported format
+        :return: one of bigquery.job.Compression enumerators or None
+        """
+        if self.destination_format == bigquery.job.DestinationFormat.AVRO:
+            return bigquery.job.Compression.DEFLATE
+        if self.destination_format == bigquery.job.DestinationFormat.PARQUET:
+            return bigquery.job.Compression.SNAPPY
+        return bigquery.job.Compression.GZIP

-    def get_logger(self):
+    @property
+    def schema(self):
         """
-        Returns the python logger to use for logging errors and issues
+        The target schema, so if needed on load it can be obtained from the same object
         :return:
         """
-        return self.__logger
+        return self.__table.schema
+
+    @property
+    def encryption_configuration(self):
+        if self._dst_encryption_configuration is not None:
+            return self._dst_encryption_configuration
+        return self.__table.encryption_configuration


 class MultiBQSyncCoordinator(object):
@@ -3719,28 +3730,28 @@ class 
MultiBQSyncCoordinator(object): def __init__( self, - srcproject_and_dataset_list, - dstproject_and_dataset_list, - srcbucket=None, - dstbucket=None, - remove_deleted_tables=True, - copy_data=True, - copy_types=("TABLE", "VIEW", "ROUTINE", "MODEL", "MATERIALIZEDVIEW"), - check_depth=-1, - copy_access=True, - table_view_filter=(".*"), - table_or_views_to_exclude=None, - latest_date=None, - days_before_latest_day=None, - day_partition_deep_check=False, - analysis_project=None, - cloud_logging_and_monitoring=False, - src_ref_project_datasets=None, - dst_ref_project_datasets=(), - query_cmek=None, - src_policy_tags=[], - dst_policy_tags=[], - ): + srcproject_and_dataset_list: List[str], + dstproject_and_dataset_list: List[str], + srcbucket: Optional[str] = None, + dstbucket: Optional[str] = None, + remove_deleted_tables: bool = True, + copy_data: bool = True, + copy_types: Optional[List[str]] = None, + check_depth: int = -1, + copy_access: bool = True, + table_view_filter: Optional[List[str]] = None, + table_or_views_to_exclude: Optional[List[str]] = None, + latest_date: Optional[datetime] = None, + days_before_latest_day: Optional[int] = None, + day_partition_deep_check: bool = False, + analysis_project: str = None, + cloud_logging_and_monitoring: bool = False, + src_ref_project_datasets: Optional[List[Any]] = None, + dst_ref_project_datasets: List[Any] = (), + query_cmek: Optional[List[str]] = None, + src_policy_tags: Optional[List[str]] = [], + dst_policy_tags: Optional[List[str]] = [], + ) -> None: if copy_types is None: copy_types = ["TABLE", "VIEW", "ROUTINE", "MODEL", "MATERIALIZEDVIEW"] if table_view_filter is None: @@ -3804,7 +3815,7 @@ def __init__( self.__copy_drivers.append(copy_driver) @property - def cloud_logging_and_monitoring(self): + def cloud_logging_and_monitoring(self) -> bool: return self.__cloud_logging_and_monitoring @property @@ -3927,7 +3938,7 @@ def rows_avoided(self): return total @property - def views_failed_sync(self): + def views_failed_sync(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.views_failed_sync @@ -3948,7 +3959,7 @@ def routines_avoided(self): return total @property - def routines_failed_sync(self): + def routines_failed_sync(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.routines_failed_sync @@ -3970,7 +3981,7 @@ def models_avoided(self): return total @property - def models_failed_sync(self): + def models_failed_sync(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.models_failed_sync @@ -3978,7 +3989,7 @@ def models_failed_sync(self): return total @property - def tables_failed_sync(self): + def tables_failed_sync(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.tables_failed_sync @@ -3999,21 +4010,21 @@ def view_avoided(self): return total @property - def extract_fails(self): + def extract_fails(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.extract_fails return total @property - def load_fails(self): + def load_fails(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.load_fails return total @property - def copy_fails(self): + def copy_fails(self) -> int: total = 0 for copy_driver in self.__copy_drivers: total += copy_driver.copy_fails @@ -4073,7 +4084,7 @@ def sync_monitor_thread(self, stop_event): self.logger.info("=============== Sync Completed ================") self.log_stats() - def sync(self): + def sync(self) -> None: """ 
Synchronise all the datasets in the driver :return: @@ -4260,28 +4271,28 @@ def update_source_view_definition(self, view_definition, use_standard_sql): class MultiBQSyncDriver(DefaultBQSyncDriver): def __init__( self, - srcproject, - srcdataset, - dstdataset, - dstproject=None, - srcbucket=None, - dstbucket=None, - remove_deleted_tables=True, - copy_data=True, - copy_types=("TABLE", "VIEW", "ROUTINE", "MODEL", "MATERIALIZEDVIEW"), - check_depth=-1, - copy_access=True, - coordinator=None, - table_view_filter=(".*"), - table_or_views_to_exclude=(), - latest_date=None, - days_before_latest_day=None, - day_partition_deep_check=False, - analysis_project=None, - query_cmek=None, - src_policy_tags=[], - dst_policy_tags=[], - ): + srcproject: str, + srcdataset: str, + dstdataset: str, + dstproject: Optional[str] = None, + srcbucket: Optional[str] = None, + dstbucket: Optional[str] = None, + remove_deleted_tables: bool = True, + copy_data: bool = True, + copy_types: Optional[List[str]] = None, + check_depth: int = -1, + copy_access: bool = True, + coordinator: Optional[MultiBQSyncCoordinator] = None, + table_view_filter: Optional[List[str]] = None, + table_or_views_to_exclude: List[str] = (), + latest_date: Optional[datetime] = None, + days_before_latest_day: Optional[int] = None, + day_partition_deep_check: bool = False, + analysis_project: str = None, + query_cmek: Optional[List[str]] = None, + src_policy_tags: Optional[List[str]] = [], + dst_policy_tags: Optional[List[str]] = [], + ) -> None: DefaultBQSyncDriver.__init__( self, srcproject, @@ -5527,9 +5538,7 @@ def rewrite_blob(new_name): # None is ThreadPoolExecutor max_workers default. 1 is single-threaded. # this fires up a number of background threads to rewrite all the blobs # we keep storage client work withi - with ThreadPoolExecutor( - max_workers=max_workers - ) as executor: + with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ executor.submit(rewrite_blob, new_name=blob_name) for blob_name in generate_cp_files() @@ -5822,26 +5831,33 @@ def create_destination_routine(copy_driver, routine_name, routine_input): copy_driver.get_logger().exception( f"Unable to create routine {routine_name} in {copy_driver.destination_project}.{copy_driver.destination_dataset} definition {routine_input['routine_definition']}" ) - if dstroutine_ref.type_ == "SCALAR_FUNCTION" and dstroutine_ref.language == "SQL": - copy_driver.get_logger().info("As scalar function and SQL attempting adding as query") + if ( + dstroutine_ref.type_ == "SCALAR_FUNCTION" + and dstroutine_ref.language == "SQL" + ): + copy_driver.get_logger().info( + "As scalar function and SQL attempting adding as query" + ) function_as_query = f"""CREATE OR REPLACE FUNCTION `{dstroutine_ref.project}.{dstroutine_ref.dataset_id}.{dstroutine_ref.routine_id}` ({",".join([arg.name + " " + arg.data_type.type_kind for arg in dstroutine_ref.arguments])}) AS ({dstroutine_ref.body}) {"RETURNS " + dstroutine_ref.return_type.type_kind if dstroutine_ref.return_type else ""} OPTIONS (description="{dstroutine_ref.description if dstroutine_ref.description else ""}")""" try: for result in run_query( - copy_driver.query_client, - function_as_query, - copy_driver.get_logger(), - "Apply SQL scalar function", - location=copy_driver.destination_location, - callback_on_complete=copy_driver.update_job_stats, - labels=BQSYNCQUERYLABELS, - # ddl statements cannot use CMEK - query_cmek=None, + copy_driver.query_client, + function_as_query, + copy_driver.get_logger(), + "Apply SQL scalar function", + 
location=copy_driver.destination_location, + callback_on_complete=copy_driver.update_job_stats, + labels=BQSYNCQUERYLABELS, + # ddl statements cannot use CMEK + query_cmek=None, ): pass - copy_driver.get_logger().info(f"Running as query did work function {routine_name} in {copy_driver.destination_project}.{copy_driver.destination_dataset} created") + copy_driver.get_logger().info( + f"Running as query did work function {routine_name} in {copy_driver.destination_project}.{copy_driver.destination_dataset} created" + ) except Exception: copy_driver.increment_routines_failed_sync() copy_driver.get_logger().exception( @@ -5995,7 +6011,7 @@ def patch_destination_view(copy_driver, table_name, view_input): ) else: copy_driver.increment_view_avoided() - except exceptions.PreconditionFailed : + except exceptions.PreconditionFailed: copy_driver.increment_views_failed_sync() copy_driver.get_logger().exception( "Pre conditionfailed patching view {}.{}.{}".format( @@ -6024,7 +6040,11 @@ def patch_destination_view(copy_driver, table_name, view_input): ) -def sync_bq_datset(copy_driver, schema_threads=10, copy_data_threads=50): +def sync_bq_datset( + copy_driver: MultiBQSyncDriver, + schema_threads: int = 10, + copy_data_threads: int = 50, +) -> None: """ Function to use copy driver to copy tables from 1 dataset to another :param copy_driver: From 44eb05b606185c30e1536b93ec842f85206d8d48 Mon Sep 17 00:00:00 2001 From: Mike Moore Date: Mon, 1 Jan 2024 13:28:44 +0000 Subject: [PATCH 2/2] chore: add vcs.xml to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 186ea35..bf565ad 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,7 @@ params .idea/jarRepositories.xml .idea/modules.xml .idea/*.iml +.idea/vcs.xml .idea/modules *.iml *.ipr
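Reviewer note, not part of the patch: a minimal usage sketch of the MultiBQSyncCoordinator API
whose signature is annotated above, to illustrate how the typed constructor is driven. The
project, dataset and bucket names below are placeholders, and the "project.dataset" string
format for the source/destination lists is an assumption inferred from the parameter names;
buckets are only expected to matter for cross-region copies.

    import logging

    # assumes the package is installed and exposes the class from bqtools/__init__.py
    from bqtools import MultiBQSyncCoordinator

    logging.basicConfig(level=logging.INFO)

    # hypothetical source/destination pairs and working buckets
    coordinator = MultiBQSyncCoordinator(
        ["src-project.src_dataset"],
        ["dst-project.dst_dataset"],
        srcbucket="src-region-working-bucket",
        dstbucket="dst-region-working-bucket",
        remove_deleted_tables=True,
        copy_data=True,
        copy_types=["TABLE", "VIEW", "ROUTINE", "MODEL", "MATERIALIZEDVIEW"],
        check_depth=-1,
    )

    # run the copy and then report the counters exposed as properties above
    coordinator.sync()
    coordinator.log_stats()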