Skip to content

Commit

Permalink
Enable users to store sample-lookup-optimized tables
Browse files Browse the repository at this point in the history
 + Add input flag and validate its value.
 + Add BQ queries to flatten call column.
 + Extract the schema of the flattened table.
 + Add unit tests to verify the correctness of extracted schema.
  • Loading branch information
samanvp committed May 14, 2020
1 parent 99c3daf commit 3f3666a
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 56 deletions.
4 changes: 1 addition & 3 deletions cloudbuild_CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ steps:
- '--project ${PROJECT_ID}'
- '--image_tag ${COMMIT_SHA}'
- '--run_unit_tests'
- '--run_preprocessor_tests'
- '--run_bq_to_vcf_tests'
- '--run_all_tests'
- '--run_presubmit_tests'
- '--test_name_prefix cloud-ci-'
id: 'test-gcp-variant-transforms-docker'
entrypoint: '/opt/gcp_variant_transforms/src/deploy_and_run_tests.sh'
Expand Down
175 changes: 172 additions & 3 deletions gcp_variant_transforms/libs/bigquery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Constants and simple utility functions related to BigQuery."""

from concurrent.futures import TimeoutError
import enum
import exceptions
import logging
Expand Down Expand Up @@ -45,7 +46,6 @@
_TOTAL_BASE_PAIRS_SIG_DIGITS = 4
_PARTITION_SIZE_SIG_DIGITS = 1

START_POSITION_COLUMN = 'start_position'
_BQ_CREATE_PARTITIONED_TABLE_COMMAND = (
'bq mk --table --range_partitioning='
'{PARTITION_COLUMN},0,{RANGE_END},{RANGE_INTERVAL} '
Expand All @@ -54,10 +54,26 @@
# `bq` / `gsutil` command templates; placeholders are filled via str.format().
_BQ_CREATE_SAMPLE_INFO_TABLE_COMMAND = (
    'bq mk --table {FULL_TABLE_ID} {SCHEMA_FILE_PATH}')
_BQ_DELETE_TABLE_COMMAND = 'bq rm -f -t {FULL_TABLE_ID}'
# Dumps a table's schema as pretty-printed JSON into a local file.
_BQ_EXTRACT_SCHEMA_COMMAND = (
    'bq show --schema --format=prettyjson {FULL_TABLE_ID} > {SCHEMA_FILE_PATH}')
_GCS_DELETE_FILES_COMMAND = 'gsutil -m rm -f -R {ROOT_PATH}'
# Separate retry budgets for load jobs vs. generic BigQuery query jobs.
_BQ_LOAD_JOB_NUM_RETRIES = 5
_BQ_NUM_RETRIES = 3
_MAX_NUM_CONCURRENT_BQ_LOAD_JOBS = 4

# INFORMATION_SCHEMA queries used to discover an existing table's layout.
_GET_COLUMN_NAMES_QUERY = (
    'SELECT column_name '
    'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMNS '
    'WHERE table_name = "{TABLE_ID}"')
# Returns nested field paths (e.g. call.sample_id) under the call column.
_GET_CALL_SUB_FIELDS_QUERY = (
    'SELECT field_path '
    'FROM `{PROJECT_ID}`.{DATASET_ID}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS '
    'WHERE table_name = "{TABLE_ID}" AND column_name="{CALL_COLUMN}"')
# Aliases used when joining the main table with its UNNESTed call column.
_MAIN_TABLE_ALIAS = 'main_table'
_CALL_TABLE_ALIAS = 'call_table'
# Cross join with UNNEST(call) yields one output row per (variant, call) pair.
_FLATTEN_CALL_QUERY = (
    'SELECT {SELECT_COLUMNS} '
    'FROM `{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}` as {MAIN_TABLE_ALIAS}, '
    'UNNEST({CALL_COLUMN}) as {CALL_TABLE_ALIAS}')

class ColumnKeyConstants(object):
"""Constants for column names in the BigQuery schema."""
Expand All @@ -75,6 +91,9 @@ class ColumnKeyConstants(object):
CALLS_GENOTYPE = 'genotype'
CALLS_PHASESET = 'phaseset'

# Name of the flattened sample-id column, i.e. 'call' + '_' + 'sample_id'.
CALL_SAMPLE_ID_COLUMN = (ColumnKeyConstants.CALLS + '_' +
                         ColumnKeyConstants.CALLS_SAMPLE_ID)


class TableFieldConstants(object):
"""Constants for field modes/types in the BigQuery schema."""
Expand Down Expand Up @@ -435,7 +454,7 @@ def _cancel_all_running_load_jobs(self):
load_job.cancel()

def _handle_failed_load_job(self, suffix, load_job):
if self._num_load_jobs_retries < _BQ_LOAD_JOB_NUM_RETRIES:
if self._num_load_jobs_retries < _BQ_NUM_RETRIES:
self._num_load_jobs_retries += 1
# Retry the failed job after 5 minutes wait.
time.sleep(300)
Expand Down Expand Up @@ -482,6 +501,156 @@ def create_sample_info_table(output_table_id):
SCHEMA_FILE_PATH=SAMPLE_INFO_TABLE_SCHEMA_FILE_PATH)
_run_table_creation_command(bq_command)

class FlattenCallColumn(object):
  """Flattens the repeated `call` column of variant tables.

  Runs BigQuery queries that cross join a variants table with
  `UNNEST(call)`, producing one row per (variant, call) pair. It can
  extract the schema of the flattened table (via a temporary 1-row table)
  and copy a set of suffixed input tables into flattened output tables.
  """

  def __init__(self, base_table_id, suffixes):
    """Initializes a `FlattenCallColumn` object.

    Args:
      base_table_id: Base name of the variant tables in
        `project:dataset.table` format (parsed by `parse_table_reference`).
      suffixes: Non-empty list of table suffixes; `base + suffix` tables are
        the inputs to be flattened.
    """
    (self._project_id,
     self._dataset_id,
     self._base_table) = parse_table_reference(base_table_id)
    assert suffixes
    self._suffixes = suffixes[:]

    # We can use any of the input tables as source of schema, we use index 0
    self._schema_table_id = compose_table_name(self._base_table,
                                               suffixes[0])
    # Lazily populated caches for the schema table's columns and the
    # sub-fields of its call column.
    self._column_names = []
    self._sub_fields = []
    self._client = bigquery.Client(project=self._project_id)

  def _wait_for_query_job(self, query_job, query, timeout):
    """Waits for `query_job` to finish, retrying on timeout.

    Retries up to `_BQ_NUM_RETRIES` times, sleeping 90 seconds between
    attempts on the same (still running) job.

    Args:
      query_job: The `QueryJob` to wait on.
      query: The query string, used only for log messages.
      timeout: Seconds to wait for each `result()` attempt.

    Returns:
      The job's result row iterator.

    Raises:
      TimeoutError: If the job does not finish after all retries.
    """
    num_retries = 0
    while True:
      try:
        return query_job.result(timeout=timeout)
      except TimeoutError:
        logging.warning('Time out waiting for query: %s', query)
        if num_retries < _BQ_NUM_RETRIES:
          num_retries += 1
          time.sleep(90)
        else:
          raise

  def _run_query(self, query):
    """Runs `query` and returns the first column of every row as `str`."""
    query_job = self._client.query(query)
    iterator = self._wait_for_query_job(query_job, query, timeout=300)
    result = []
    for i in iterator:
      result.append(str(i.values()[0]))
    return result

  def _get_column_names(self):
    """Returns (and caches) the column names of the schema table."""
    if not self._column_names:
      query = _GET_COLUMN_NAMES_QUERY.format(PROJECT_ID=self._project_id,
                                             DATASET_ID=self._dataset_id,
                                             TABLE_ID=self._schema_table_id)
      self._column_names = self._run_query(query)[:]
      assert self._column_names
    return self._column_names

  def _get_call_sub_fields(self):
    """Returns (and caches) the sub-field names nested under `call`."""
    if not self._sub_fields:
      query = _GET_CALL_SUB_FIELDS_QUERY.format(
          PROJECT_ID=self._project_id, DATASET_ID=self._dataset_id,
          TABLE_ID=self._schema_table_id, CALL_COLUMN=ColumnKeyConstants.CALLS)
      # returned list is [call, call.name, call.genotype, call.phaseset, ...]
      result = self._run_query(query)[1:]  # Drop the first element
      self._sub_fields = [sub_field.split('.')[1] for sub_field in result]
      assert self._sub_fields
    return self._sub_fields

  def _get_flatten_column_names(self):
    """Builds the SELECT list for the flattening query.

    Non-call columns are selected from the main table unchanged; the `call`
    column is replaced by its sub-fields aliased as `call_<sub_field>`.
    """
    column_names = self._get_column_names()
    sub_fields = self._get_call_sub_fields()
    select_list = []
    for column in column_names:
      if column != ColumnKeyConstants.CALLS:
        select_list.append(_MAIN_TABLE_ALIAS + '.' + column + ' AS `' +
                           column + '`')
      else:
        for s_f in sub_fields:
          select_list.append(_CALL_TABLE_ALIAS + '.' + s_f + ' AS `' +
                             ColumnKeyConstants.CALLS + '_' + s_f + '`')
    return ', '.join(select_list)

  def _copy_to_flatten_table(self, output_table_id, cp_query):
    """Runs `cp_query` with `output_table_id` as its destination table.

    Raises:
      TimeoutError: If the copy query does not finish after all retries.
    """
    job_config = bigquery.QueryJobConfig(destination=output_table_id)
    query_job = self._client.query(cp_query, job_config=job_config)
    try:
      _ = self._wait_for_query_job(query_job, cp_query, timeout=600)
    except TimeoutError:
      logging.error('Copy to table query failed: %s', output_table_id)
      raise
    logging.info('Copy to table query was successful: %s', output_table_id)

  def _create_temp_flatten_table(self):
    """Creates a temporary 1-row flattened table and returns its table id.

    The table exists only to have its schema extracted; `LIMIT 1` keeps the
    query cost minimal.
    """
    temp_suffix = time.strftime('%Y%m%d_%H%M%S')
    temp_table_id = '{}{}'.format(self._schema_table_id, temp_suffix)
    full_output_table_id = '{}.{}.{}'.format(
        self._project_id, self._dataset_id, temp_table_id)

    select_columns = self._get_flatten_column_names()
    cp_query = _FLATTEN_CALL_QUERY.format(SELECT_COLUMNS=select_columns,
                                          PROJECT_ID=self._project_id,
                                          DATASET_ID=self._dataset_id,
                                          TABLE_ID=self._schema_table_id,
                                          MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS,
                                          CALL_COLUMN=ColumnKeyConstants.CALLS,
                                          CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS)
    cp_query += ' LIMIT 1'  # We need this table only to extract its schema.
    self._copy_to_flatten_table(full_output_table_id, cp_query)
    logging.info('A new table with 1 row was created: %s', full_output_table_id)
    logging.info('This table is used to extract the schema of flatten table.')
    return temp_table_id

  def get_flatten_table_schema(self, schema_file_path):
    """Extracts the flattened table's schema into `schema_file_path`.

    Creates a temporary flattened table, dumps its schema with `bq show`,
    then deletes the temporary table (best effort).

    Args:
      schema_file_path: Local path the JSON schema is written to.

    Returns:
      The `os.system` exit status of the schema-extraction command
      (0 on success).
    """
    temp_table_id = self._create_temp_flatten_table()
    full_table_id = '{}:{}.{}'.format(
        self._project_id, self._dataset_id, temp_table_id)
    bq_command = _BQ_EXTRACT_SCHEMA_COMMAND.format(
        FULL_TABLE_ID=full_table_id,
        SCHEMA_FILE_PATH=schema_file_path)
    result = os.system(bq_command)
    if result != 0:
      logging.error('Failed to extract flatten table schema using "%s" command',
                    bq_command)
    else:
      logging.info('Successfully extracted the schema of flatten table.')
    if _delete_table(full_table_id) == 0:
      logging.info('Successfully deleted temporary table: %s', full_table_id)
    else:
      logging.error('Was not able to delete temporary table: %s', full_table_id)
    return result

  def copy_to_flatten_table(self, output_base_table_id):
    """Flattens every input table into its corresponding output table.

    Args:
      output_base_table_id: Base name of the output tables in
        `project:dataset.table` format.
    """
    # Here we assume all output_table_base + suffixes[:] are already created.
    (output_project_id,
     output_dataset_id,
     output_base_table) = parse_table_reference(output_base_table_id)
    select_columns = self._get_flatten_column_names()
    for suffix in self._suffixes:
      input_table_id = compose_table_name(self._base_table, suffix)
      output_table_id = compose_table_name(output_base_table, suffix)

      full_output_table_id = '{}.{}.{}'.format(
          output_project_id, output_dataset_id, output_table_id)
      cp_query = _FLATTEN_CALL_QUERY.format(
          SELECT_COLUMNS=select_columns, PROJECT_ID=self._project_id,
          DATASET_ID=self._dataset_id, TABLE_ID=input_table_id,
          MAIN_TABLE_ALIAS=_MAIN_TABLE_ALIAS,
          CALL_COLUMN=ColumnKeyConstants.CALLS,
          CALL_TABLE_ALIAS=_CALL_TABLE_ALIAS)

      self._copy_to_flatten_table(full_output_table_id, cp_query)
      logging.info('Flatten table is fully loaded: %s', full_output_table_id)


def create_output_table(full_table_id, # type: str
partition_column, # type: str
range_end, # type: int
Expand Down
46 changes: 45 additions & 1 deletion gcp_variant_transforms/libs/bigquery_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,6 @@ def test_merge_field_schemas_merge_inner_record_fields(self):
field_schemas_2),
expected_merged_field_schemas)


def test_does_table_exist(self):
client = mock.Mock()
client.tables.Get.return_value = bigquery.Table(
Expand Down Expand Up @@ -516,3 +515,48 @@ def test_calculate_optimal_range_interval_large(self):
bigquery_util.calculate_optimal_range_interval(large_range_end))
self.assertEqual(expected_interval, range_interval)
self.assertEqual(expected_end, range_end_enlarged)


class FlattenCallColumnTest(unittest.TestCase):
  """Test cases for class `FlattenCallColumn`."""

  def setUp(self):
    # Integration-test table that serves as the schema source.
    base_table = ('gcp-variant-transforms-test:'
                  'bq_to_vcf_integration_tests.'
                  'merge_option_move_to_calls')
    suffixes = ['chr20']
    self._flatter = bigquery_util.FlattenCallColumn(base_table, suffixes)

  def test_get_column_names(self):
    expected = [
        'reference_name', 'start_position', 'end_position', 'reference_bases',
        'alternate_bases', 'names', 'quality', 'filter', 'call', 'NS', 'DP',
        'AA', 'DB', 'H2']
    self.assertEqual(expected, self._flatter._get_column_names())

  def test_get_call_sub_fields(self):
    expected = ['sample_id', 'genotype', 'phaseset', 'DP', 'GQ', 'HQ']
    self.assertEqual(expected, self._flatter._get_call_sub_fields())

  def test_get_flatten_column_names(self):
    # The SELECT list keeps non-call columns in order, with the call
    # column replaced in place by its aliased sub-fields.
    cols_before_call = ['reference_name', 'start_position', 'end_position',
                        'reference_bases', 'alternate_bases', 'names',
                        'quality', 'filter']
    call_sub_fields = ['sample_id', 'genotype', 'phaseset', 'DP', 'GQ', 'HQ']
    cols_after_call = ['NS', 'DP', 'AA', 'DB', 'H2']
    parts = ['main_table.{0} AS `{0}`'.format(c) for c in cols_before_call]
    parts.extend(
        'call_table.{0} AS `call_{0}`'.format(f) for f in call_sub_fields)
    parts.extend('main_table.{0} AS `{0}`'.format(c) for c in cols_after_call)
    self.assertEqual(', '.join(parts),
                     self._flatter._get_flatten_column_names())
Loading

0 comments on commit 3f3666a

Please sign in to comment.