Skip to content

Commit

Permalink
Convert truncate_reference_blocks to use LEN
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisvittal committed Oct 5, 2024
1 parent ffa541d commit d0f0308
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions hail/python/hail/vds/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,28 +944,28 @@ def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, ref_block_wi
f" recommended values are <0.05."
)
max_ref_block_base_pairs = rd.aggregate_entries(
hl.agg.approx_quantiles(rd.END - rd.locus.position + 1, 1 - ref_block_winsorize_fraction, k=200)
hl.agg.approx_quantiles(rd.LEN, 1 - ref_block_winsorize_fraction, k=200)
)

assert (
max_ref_block_base_pairs > 0
), 'truncate_reference_blocks: "max_ref_block_base_pairs" must be between greater than zero'
info(f"splitting VDS reference blocks at {max_ref_block_base_pairs} base pairs")

rd_under_limit = rd.filter_entries(rd.END - rd.locus.position < max_ref_block_base_pairs).localize_entries(
'fixed_blocks', 'cols'
)
rd_under_limit = rd.filter_entries(rd.LEN <= max_ref_block_base_pairs).localize_entries('fixed_blocks', 'cols')

rd_over_limit = rd.filter_entries(rd.END - rd.locus.position >= max_ref_block_base_pairs).key_cols_by(
col_idx=hl.scan.count()
)
rd_over_limit = rd.filter_entries(rd.LEN > max_ref_block_base_pairs).key_cols_by(col_idx=hl.scan.count())
rd_over_limit = rd_over_limit.select_rows().select_cols().key_rows_by().key_cols_by()
es = rd_over_limit.entries()
es = es.annotate(new_start=hl.range(es.locus.position, es.END + 1, max_ref_block_base_pairs))
es = es.annotate(new_start=hl.range(es.locus.position, es.locus.position + es.LEN, max_ref_block_base_pairs))
es = es.explode('new_start')
es = es.transmute(
locus=hl.locus(es.locus.contig, es.new_start, reference_genome=es.locus.dtype.reference_genome),
END=hl.min(es.new_start + max_ref_block_base_pairs - 1, es.END),
LEN=hl.if_else(
es.new_start + max_ref_block_base_pairs <= es.locus.position + es.LEN,
max_ref_block_base_pairs,
es.LEN % max_ref_block_base_pairs,
),
)
es = es.key_by(es.locus).collect_by_key("new_blocks")
es = es.transmute(moved_blocks_dict=hl.dict(es.new_blocks.map(lambda x: (x.col_idx, x.drop('col_idx')))))
Expand All @@ -981,6 +981,10 @@ def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, ref_block_wi
)
new_rd = new_rd.annotate_globals(**{fd_name: max_ref_block_base_pairs})

# we've changed LEN so we need to make sure that END is correct.
if 'END' in new_rd.entry:
new_rd = VariantDataset._add_end(new_rd.drop('END'))

if isinstance(ds, hl.vds.VariantDataset):
return VariantDataset(reference_data=new_rd, variant_data=ds.variant_data)
return new_rd
Expand Down

0 comments on commit d0f0308

Please sign in to comment.