Skip to content

Commit

Permalink
Enable users to store sample look up optimized tables
Browse files Browse the repository at this point in the history
 + Add input flag and validate its value.
 + Add BQ queries to flatten call column.
 + Extract the schema of the flatten table.
 + Add unit tests to verify the correctness of extracted schema.
 + Add a new method `table_empty()` for checking whether a table is empty or not.
  • Loading branch information
samanvp committed Apr 25, 2020
1 parent 671c4b9 commit c74ad88
Show file tree
Hide file tree
Showing 5 changed files with 431 additions and 53 deletions.
162 changes: 159 additions & 3 deletions gcp_variant_transforms/libs/bigquery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Constants and simple utility functions related to BigQuery."""

from concurrent.futures import TimeoutError
import enum
import exceptions
import logging
Expand All @@ -40,17 +41,32 @@
_RANGE_END_SIG_DIGITS = 4
_RANGE_INTERVAL_SIG_DIGITS = 1

START_POSITION_COLUMN = 'start_position'
_BQ_CREATE_PARTITIONED_TABLE_COMMAND = (
'bq mk --table --range_partitioning='
'{PARTITION_COLUMN},0,{RANGE_END},{RANGE_INTERVAL} '
'--clustering_fields=start_position,end_position '
'{FULL_TABLE_ID} {SCHEMA_FILE_PATH}')
_BQ_DELETE_TABLE_COMMAND = 'bq rm -f -t {FULL_TABLE_ID}'
_BQ_EXTRACT_SCHEMA_COMMAND = (
'bq show --schema --format=prettyjson {FULL_TABLE_ID} > {SCHEMA_FILE_PATH}')
_GCS_DELETE_FILES_COMMAND = 'gsutil -m rm -f -R {ROOT_PATH}'
_BQ_LOAD_JOB_NUM_RETRIES = 5
_BQ_NUM_RETRIES = 3
_MAX_NUM_CONCURRENT_BQ_LOAD_JOBS = 4

_GET_COLUMN_NAMES_QUERY = (
'SELECT column_name '
'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMNS '
'WHERE table_name = "{TABLE_ID}"')
_GET_CALL_SUB_FIELDS_QUERY = (
'SELECT field_path '
'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS '
'WHERE table_name = "{TABLE_ID}" AND column_name="{CALL_COLUMN}"')
_MAIN_TABLE_ALIAS = 'main_table'
_CALL_TABLE_ALIAS = 'call_table'
_FLATTEN_CALL_QUERY = (
'SELECT {SELECT_COLUMNS} '
'FROM `{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}` as {MAIN_TABLE_ALIAS}, '
'UNNEST({CALL_COLUMN}) as {CALL_TABLE_ALIAS}')

class ColumnKeyConstants(object):
"""Constants for column names in the BigQuery schema."""
Expand Down Expand Up @@ -189,6 +205,37 @@ def table_exist(client, project_id, dataset_id, table_id):
raise
return True

def table_empty(project_id, dataset_id, table_id):
client = bigquery.Client(project=project_id)
query = 'SELECT count(0) AS num_rows FROM {DATASET_ID}.{TABLE_ID}'.format(
DATASET_ID=dataset_id, TABLE_ID=table_id)
query_job = client.query(query)
num_retries = 0
while True:
try:
iterator = query_job.result(timeout=300)
except TimeoutError as e:
logging.warning('Time out waiting for query: %s', query)
if num_retries < _BQ_NUM_RETIRES:
num_retries += 1
time.sleep(90)
else:
raise e
else:
break

rows = list(iterator)
if len(rows) != 1:
logging.error('Query did not returned expected # of rows: {}'.format(query))
raise ValueError('Expected 1 row in query result, got {}'.format(len(rows)))

col = rows[0]
if len(col) != 1:
logging.error('Query did not returned expected # of cols: {}'.format(query))
raise ValueError('Expected 1 col in query result, got {}'.format(len(col)))

return col.get('num_rows') == 0

def get_bigquery_type_from_vcf_type(vcf_type):
# type: (str) -> str
vcf_type = vcf_type.lower()
Expand Down Expand Up @@ -416,7 +463,7 @@ def _cancel_all_running_load_jobs(self):
load_job.cancel()

def _handle_failed_load_job(self, suffix, load_job):
if self._num_load_jobs_retries < _BQ_LOAD_JOB_NUM_RETRIES:
if self._num_load_jobs_retries < _BQ_NUM_RETRIES:
self._num_load_jobs_retries += 1
# Retry the failed job after 5 minutes wait.
time.sleep(300)
Expand Down Expand Up @@ -447,6 +494,115 @@ def _monitor_load_jobs(self):
self._start_one_load_job(next_suffix)


class FlattenCallColumn(object):
def __init__(self, full_table_id):
(self._project_id,
self._dataset_id,
self._table_id) = parse_table_reference(full_table_id)

self._client = bigquery.Client(project=self._project_id)

def _run_query(self, query):
query_job = self._client.query(query)
num_retries = 0
while True:
try:
iterator = query_job.result(timeout=300)
except TimeoutError as e:
logging.warning('Time out waiting for query: %s', query)
if num_retries < _BQ_NUM_RETIRES:
num_retries += 1
time.sleep(90)
else:
raise e
else:
break
result = []
for i in iterator:
result.append(str(i.values()[0]))
return result

def _get_column_names(self):
query = _GET_COLUMN_NAMES_QUERY.format(PROJECT_ID=self._project_id,
DATASET_ID=self._dataset_id,
TABLE_ID=self._table_id)
return self._run_query(query)

def _get_call_sub_fields(self):
query = _GET_CALL_SUB_FIELDS_QUERY.format(PROJECT_ID=self._project_id,
DATASET_ID=self._dataset_id,
TABLE_ID=self._table_id,
CALL_COLUMN=ColumnKeyConstants.CALLS)
# returned list will be [call, call.name, call.genotype, call.phaseset, ...]
result = self._run_query(query)[1:] # Drop the first element
return [sub_field.split('.')[1] for sub_field in result]

def _get_flatten_column_names(self):
column_names = self._get_column_names()
assert len(column_names) > 0
sub_fields = self._get_call_sub_fields()
assert len(sub_fields) > 0
select_list = []
for column in column_names:
if column != ColumnKeyConstants.CALLS:
select_list.append(_MAIN_TABLE_ALIAS + '.' + column + ' AS ' + column)
else:
for s_f in sub_fields:
select_list.append(_CALL_TABLE_ALIAS + '.' + s_f + ' AS ' + ColumnKeyConstants.CALLS + '_' + s_f)
return ', '.join(select_list)

def _create_temp_flatten_table(self):
temp_suffix = time.strftime('%Y%m%d_%H%M%S')
temp_table_id = '{}{}'.format(self._table_id, temp_suffix)
full_table_id = '{}.{}.{}'.format(
self._project_id, self._dataset_id, temp_table_id)
job_config = bigquery.QueryJobConfig(destination=full_table_id)

select_columns = self._get_flatten_column_names()
sql = _FLATTEN_CALL_QUERY.format(SELECT_COLUMNS=select_columns,
PROJECT_ID=self._project_id,
DATASET_ID=self._dataset_id,
TABLE_ID=self._table_id,
MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS,
CALL_COLUMN=ColumnKeyConstants.CALLS,
CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS)
sql += ' LIMIT 1' # We need this table only to extract its schema.
query_job = self._client.query(sql, job_config=job_config)
while True:
try:
_ = query_job.result(timeout=300)
except TimeoutError as e:
logging.warning('Time out waiting for query: %s', sql)
if num_retries < _BQ_NUM_RETIRES:
num_retries += 1
time.sleep(90)
else:
raise e
else:
break
logging.info('A new table with only 1 row was crated: %s', full_table_id)
logging.info('This table is used to extract the schema of flatten tables.')
return temp_table_id

def get_flatten_table_schema(self, schema_file_path):
temp_table_id = self._create_temp_flatten_table()
full_table_id = '{}:{}.{}'.format(
self._project_id, self._dataset_id, temp_table_id)
bq_command = _BQ_EXTRACT_SCHEMA_COMMAND.format(
FULL_TABLE_ID=full_table_id,
SCHEMA_FILE_PATH=schema_file_path)
result = os.system(bq_command)
if result != 0:
logging.error('Failed to extract flatten table schema using "%s" command',
bq_command)
else:
logging.info('Successfully extracted the schema of flatten table.')
if _delete_table(full_table_id) == 0:
logging.info('Successfully deleted temporary table: %s', full_table_id)
else:
logging.error('Was not able to delete temporary table: %s', full_table_id)
return result

def create_output_table(full_table_id, # type: str
partition_column, # type: str
range_end, # type: int
Expand Down
61 changes: 61 additions & 0 deletions gcp_variant_transforms/libs/bigquery_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import absolute_import

import filecmp
import tempfile
import unittest
from apitools.base.py import exceptions

Expand Down Expand Up @@ -439,6 +441,17 @@ def test_does_table_exist(self):
bigquery_util.table_exist,
client, 'project', 'dataset', 'table')

def test_table_empty(self):
project_id = 'gcp-variant-transforms-test'
dataset_id = 'bq_to_vcf_integration_tests'
table_with_3_rows = 'merge_option_move_to_calls___chr20'
self.assertFalse(bigquery_util.table_empty(project_id, dataset_id,
table_with_3_rows))

table_with_0_rows = 'merge_option_move_to_calls___chr21'
self.assertTrue(bigquery_util.table_empty(project_id, dataset_id,
table_with_0_rows))

def test_raise_error_if_dataset_not_exists(self):
client = mock.Mock()
client.datasets.Get.return_value = bigquery.Dataset(
Expand Down Expand Up @@ -502,3 +515,51 @@ def test_calculate_optimal_partition_size(self):
self.assertEqual(expected_partition_size *
(bigquery_util._MAX_BQ_NUM_PARTITIONS - 1),
total_base_pairs_enlarged)

class FlattenCallColumnTest(unittest.TestCase):
"""Test cases for class `FlattenCallColumn`."""

def setUp(self):
self._input_table = 'gcp-variant-transforms-test:bq_to_vcf_integration_tests.merge_option_move_to_calls___chr20'
self._flatter = bigquery_util.FlattenCallColumn(self._input_table)

def test_get_column_names(self):
expected_column_names = ['reference_name', 'start_position', 'end_position',
'reference_bases', 'alternate_bases', 'names',
'quality', 'filter', 'call', 'NS', 'DP', 'AA',
'DB', 'H2']
self.assertEqual(expected_column_names, self._flatter._get_column_names())

def test_get_call_sub_fields(self):
expected_sub_fields = \
['sample_id', 'genotype', 'phaseset', 'DP', 'GQ', 'HQ']
self.assertEqual(expected_sub_fields, self._flatter._get_call_sub_fields())

def test_get_flatten_column_names(self):
expected_select = (
'main_table.reference_name AS reference_name, '
'main_table.start_position AS start_position, '
'main_table.end_position AS end_position, '
'main_table.reference_bases AS reference_bases, '
'main_table.alternate_bases AS alternate_bases, '
'main_table.names AS names, '
'main_table.quality AS quality, '
'main_table.filter AS filter, '
'call_table.sample_id AS call_sample_id, '
'call_table.genotype AS call_genotype, '
'call_table.phaseset AS call_phaseset, '
'call_table.DP AS call_DP, '
'call_table.GQ AS call_GQ, '
'call_table.HQ AS call_HQ, '
'main_table.NS AS NS, '
'main_table.DP AS DP, '
'main_table.AA AS AA, '
'main_table.DB AS DB, '
'main_table.H2 AS H2')
self.assertEqual(expected_select, self._flatter._get_flatten_column_names())

def test_get_flatten_table_schema(self):
generated_flatten_schema = tempfile.mkstemp()[1]
self._flatter.get_flatten_table_schema(generated_flatten_schema)
expected_flatten_schema = 'gcp_variant_transforms/testing/data/schema/flatten_merge_option_move_to_calls.json'
self.assertTrue(filecmp(generated_flatten_schema, expected_flatten_schema), 'Schema file does not match the expected schema')
Loading

0 comments on commit c74ad88

Please sign in to comment.