This repository has been archived by the owner on Oct 29, 2023. It is now read-only.

Update to handle API reading and optimize sharded writing and indexing #165

Merged · 1 commit · Mar 2, 2016
CombineShardsFn.java
@@ -15,55 +15,58 @@
*/
package com.google.cloud.genomics.dataflow.functions;

import com.google.api.services.storage.Storage;
import com.google.api.services.storage.Storage.Objects.Compose;
import com.google.api.services.storage.model.ComposeRequest;
import com.google.api.services.storage.model.ComposeRequest.SourceObjects;
import com.google.api.services.storage.model.StorageObject;

import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
import com.google.cloud.dataflow.sdk.util.GcsUtil;
import com.google.cloud.dataflow.sdk.util.Transport;
import com.google.cloud.dataflow.sdk.util.gcsfs.GcsPath;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
-import com.google.cloud.genomics.dataflow.readers.bam.BAMIO;
import com.google.cloud.genomics.dataflow.utils.GCSOptions;
import com.google.cloud.genomics.dataflow.utils.GCSOutputOptions;

-import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;

-import htsjdk.samtools.BAMIndexer;
-import htsjdk.samtools.SAMRecord;
-import htsjdk.samtools.SamReader;
-import htsjdk.samtools.ValidationStringency;
-import htsjdk.samtools.util.BlockCompressedStreamConstants;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;

/*
- * Takes a set of BAM files that have been written to disk, concatenates them into one
- * file (removing unneeded EOF blocks), and writes an index for the combined file.
+ * Takes a set of files that have been written to disk, concatenates them into one
+ * file and also appends "EOF" content at the end.
 */
public class CombineShardsFn extends DoFn<String, String> {

public static interface Options extends GCSOutputOptions, GCSOptions {}

  private static final int MAX_FILES_FOR_COMPOSE = 32;
-  private static final String BAM_INDEX_FILE_MIME_TYPE = "application/octet-stream";
+  private static final int MAX_RETRY_COUNT = 3;
+  private static final String FILE_MIME_TYPE = "application/octet-stream";
private static final Logger LOG = Logger.getLogger(CombineShardsFn.class.getName());

final PCollectionView<Iterable<String>> shards;
+  final PCollectionView<byte[]> eofContent;
+  Aggregator<Integer, Integer> filesToCombineAggregator;
+  Aggregator<Integer, Integer> combinedFilesAggregator;
+  Aggregator<Integer, Integer> createdFilesAggregator;
+  Aggregator<Integer, Integer> deletedFilesAggregator;

-  public CombineShardsFn(PCollectionView<Iterable<String>> shards) {
+  public CombineShardsFn(PCollectionView<Iterable<String>> shards, PCollectionView<byte[]> eofContent) {
    this.shards = shards;
+    this.eofContent = eofContent;
+    filesToCombineAggregator = createAggregator("Files to combine", new SumIntegerFn());
+    combinedFilesAggregator = createAggregator("Files combined", new SumIntegerFn());
+    createdFilesAggregator = createAggregator("Created files", new SumIntegerFn());
+    deletedFilesAggregator = createAggregator("Deleted files", new SumIntegerFn());
}
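
A minimal wiring sketch for the new two-side-input constructor, assuming hypothetical upstream collections `writtenShardNames`, `eofBytes`, and `destFileName` (these names are illustrative, not from this PR). For BAM output, `eofBytes` would carry `BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK`, the terminator the old code wrote unconditionally; passing the bytes in as a side input is what lets this DoFn stay format-agnostic.

```java
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionView;

// Hypothetical wiring; the three input PCollections are assumed to be
// produced earlier in the pipeline (shard file names, EOF bytes such as
// BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK for BAM output, and the
// destination file name).
static PCollection<String> combineShardedOutput(
    PCollection<String> writtenShardNames,
    PCollection<byte[]> eofBytes,
    PCollection<String> destFileName) {
  PCollectionView<Iterable<String>> shardsView =
      writtenShardNames.apply(View.<String>asIterable());
  PCollectionView<byte[]> eofView =
      eofBytes.apply(View.<byte[]>asSingleton());
  return destFileName.apply(
      ParDo.named("CombineShards")
          .withSideInputs(shardsView, eofView)
          .of(new CombineShardsFn(shardsView, eofView)));
}
```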

@Override
@@ -72,12 +75,13 @@ public void processElement(DoFn<String, String>.ProcessContext c) throws Exception {
combineShards(
c.getPipelineOptions().as(Options.class),
c.element(),
-        c.sideInput(shards));
+        c.sideInput(shards),
+        c.sideInput(eofContent));
c.output(result);
}

-  static String combineShards(Options options, String dest,
-      Iterable<String> shards) throws IOException {
+  String combineShards(Options options, String dest,
+      Iterable<String> shards, byte[] eofContent) throws IOException {
LOG.info("Combining shards into " + dest);
final Storage.Objects storage = Transport.newStorageClient(
options
@@ -89,14 +93,20 @@ static String combineShards
Collections.sort(sortedShardsNames);

// Write an EOF block (empty gzip block), and put it at the end.
-    String eofFileName = options.getOutput() + "-EOF";
-    final OutputStream os = Channels.newOutputStream(
-        (new GcsUtil.GcsUtilFactory()).create(options).create(
-            GcsPath.fromUri(eofFileName),
-            BAM_INDEX_FILE_MIME_TYPE));
-    os.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
-    os.close();
-    sortedShardsNames.add(eofFileName);
+    if (eofContent != null && eofContent.length > 0) {
+      String eofFileName = options.getOutput() + "-EOF";
+      final OutputStream os = Channels.newOutputStream(
+          (new GcsUtil.GcsUtilFactory()).create(options).create(
+              GcsPath.fromUri(eofFileName),
+              FILE_MIME_TYPE));
+      os.write(eofContent);
+      os.close();
+      sortedShardsNames.add(eofFileName);
+      LOG.info("Written " + eofContent.length + " bytes into EOF file " +
+          eofFileName);
+    } else {
+      LOG.info("No EOF content");
+    }

int stageNumber = 0;
/*
@@ -135,77 +145,54 @@ static String combineShards
LOG.info("Combining a final group of " + sortedShardsNames.size() + " shards");
final String combineResult = composeAndCleanUpShards(storage,
sortedShardsNames, dest);
generateIndex(options, storage, combineResult);
return combineResult;
}

static void generateIndex(Options options,
Storage.Objects storage, String bamFilePath) throws IOException {
final String baiFilePath = bamFilePath + ".bai";
Stopwatch timer = Stopwatch.createStarted();
LOG.info("Generating BAM index: " + baiFilePath);
LOG.info("Reading BAM file: " + bamFilePath);
final SamReader reader = BAMIO.openBAM(storage, bamFilePath, ValidationStringency.LENIENT, true);

final OutputStream outputStream =
Channels.newOutputStream(
new GcsUtil.GcsUtilFactory().create(options)
.create(GcsPath.fromUri(baiFilePath),
BAM_INDEX_FILE_MIME_TYPE));
BAMIndexer indexer = new BAMIndexer(outputStream, reader.getFileHeader());
String composeAndCleanUpShards(
Storage.Objects storage, List<String> shardNames, String dest) throws IOException {
LOG.info("Combining shards into " + dest);

final GcsPath destPath = GcsPath.fromUri(dest);

long processedReads = 0;
StorageObject destination = new StorageObject().setContentType(FILE_MIME_TYPE);

// create and write the content
for (SAMRecord rec : reader) {
if (++processedReads % 1000000 == 0) {
dumpStats(processedReads, timer);
ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
int addedShardCount = 0;
for (String shard : shardNames) {
final GcsPath shardPath = GcsPath.fromUri(shard);
LOG.info("Adding shard " + shardPath + " for result " + dest);
sourceObjects.add(new SourceObjects().setName(shardPath.getObject()));
addedShardCount++;
}
LOG.info("Added " + addedShardCount + " shards for composition");
filesToCombineAggregator.addValue(addedShardCount);

final ComposeRequest composeRequest =
new ComposeRequest().setDestination(destination).setSourceObjects(sourceObjects);
final Compose compose =
storage.compose(destPath.getBucket(), destPath.getObject(), composeRequest);
final StorageObject result = compose.execute();
final String combineResult = GcsPath.fromObject(result).toString();
LOG.info("Combine result is " + combineResult);
combinedFilesAggregator.addValue(addedShardCount);
createdFilesAggregator.addValue(1);
for (SourceObjects sourceObject : sourceObjects) {
final String shardToDelete = sourceObject.getName();
LOG.info("Cleaning up shard " + shardToDelete + " for result " + dest);
int retryCount = MAX_RETRY_COUNT;
boolean done = false;
while (!done && retryCount > 0) {
try {
storage.delete(destPath.getBucket(), shardToDelete).execute();
done = true;
} catch (Exception ex) {
LOG.info("Error deleting " + ex.getMessage() + retryCount + " retries left");
}
indexer.processAlignment(rec);
retryCount--;
}
deletedFilesAggregator.addValue(1);
}
indexer.finish();
dumpStats(processedReads, timer);
}

static void dumpStats(long processedReads, Stopwatch timer) {
LOG.info("Processed " + processedReads + " records in " + timer +
". Speed: " + processedReads/timer.elapsed(TimeUnit.SECONDS) + " reads/sec");

}

static String composeAndCleanUpShards(Storage.Objects storage,
List<String> shardNames, String dest) throws IOException {
LOG.info("Combining shards into " + dest);

final GcsPath destPath = GcsPath.fromUri(dest);

StorageObject destination = new StorageObject()
.setContentType(BAM_INDEX_FILE_MIME_TYPE);

ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
int addedShardCount = 0;
for (String shard : shardNames) {
final GcsPath shardPath = GcsPath.fromUri(shard);
LOG.info("Adding shard " + shardPath + " for result " + dest);
sourceObjects.add( new SourceObjects().setName(shardPath.getObject()));
addedShardCount++;
}
LOG.info("Added " + addedShardCount + " shards for composition");

final ComposeRequest composeRequest = new ComposeRequest()
.setDestination(destination)
.setSourceObjects(sourceObjects);
final Compose compose = storage.compose(
destPath.getBucket(), destPath.getObject(), composeRequest);
final StorageObject result = compose.execute();
final String combineResult = GcsPath.fromObject(result).toString();
LOG.info("Combine result is " + combineResult);
for (SourceObjects sourceObject : sourceObjects) {
final String shardToDelete = sourceObject.getName();
LOG.info("Cleaning up shard " + shardToDelete + " for result " + dest);
storage.delete(destPath.getBucket(), shardToDelete).execute();
}

return combineResult;
}
return combineResult;
}
}
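
The stage loop that precedes the final compose is collapsed in this view. GCS `compose` accepts at most 32 source objects per request, so `MAX_FILES_FOR_COMPOSE`-sized groups must be composed into intermediate objects until few enough remain. A rough sketch of that shape (illustrative only — the PR's actual loop is hidden above, and the intermediate object naming here is invented):

```java
// Illustrative staged composition over the 32-object GCS compose limit;
// storage, dest, and composeAndCleanUpShards(...) are the ones from the
// class above, the "-stage-" naming is assumed.
String stagedCompose(Storage.Objects storage, List<String> sortedShardsNames, String dest)
    throws IOException {
  List<String> remaining = new ArrayList<String>(sortedShardsNames);
  int stageNumber = 0;
  while (remaining.size() > MAX_FILES_FOR_COMPOSE) {
    final List<String> nextStage = new ArrayList<String>();
    for (int i = 0; i < remaining.size(); i += MAX_FILES_FOR_COMPOSE) {
      // Each group of up to 32 shards collapses into one intermediate
      // object, which the next stage treats as an ordinary shard.
      final List<String> group = remaining.subList(
          i, Math.min(i + MAX_FILES_FOR_COMPOSE, remaining.size()));
      nextStage.add(composeAndCleanUpShards(
          storage, group, dest + "-stage-" + stageNumber + "-" + i));
    }
    remaining = nextStage;
    stageNumber++;
  }
  // The final group (now at most 32 objects) composes directly into dest.
  return composeAndCleanUpShards(storage, remaining, dest);
}
```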
GetReferencesFromHeaderFn.java
@@ -0,0 +1,23 @@
package com.google.cloud.genomics.dataflow.functions;

import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceRecord;

import java.util.logging.Logger;

public class GetReferencesFromHeaderFn extends DoFn<HeaderInfo, String> {
  private static final Logger LOG = Logger.getLogger(GetReferencesFromHeaderFn.class.getName());

  @Override
  public void processElement(DoFn<HeaderInfo, String>.ProcessContext c) throws Exception {
    final SAMFileHeader header = c.element().header;
    for (SAMSequenceRecord sequence : header.getSequenceDictionary().getSequences()) {
      c.output(sequence.getSequenceName());
    }
    LOG.info("Processed " + header.getSequenceDictionary().size() + " references");
  }
}
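
A short usage sketch, assuming a hypothetical upstream `PCollection<HeaderInfo>` (the name `headerInfo` is illustrative, not from this PR): each header element fans out into one output element per reference sequence, which gives downstream stages a natural unit for parallelizing per-reference work such as indexing.

```java
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.genomics.dataflow.readers.bam.HeaderInfo;

// headerInfo is an assumed PCollection<HeaderInfo> read from the BAM being processed.
static PCollection<String> referenceNames(PCollection<HeaderInfo> headerInfo) {
  // Emits one reference name per sequence in the SAM/BAM sequence dictionary.
  return headerInfo.apply(
      ParDo.named("GetReferences").of(new GetReferencesFromHeaderFn()));
}
```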

KeyReadsFn.java
@@ -15,24 +15,26 @@
*/
package com.google.cloud.genomics.dataflow.functions;

-import com.google.api.services.genomics.model.Read;
import com.google.cloud.dataflow.sdk.options.Default;
import com.google.cloud.dataflow.sdk.options.Description;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.genomics.utils.Contig;
+import com.google.genomics.v1.Read;

+import java.util.logging.Logger;

/*
* Takes a read and associates it with a Contig.
* This can be used to shard Reads so they can be written to disk in parallel.
* The size of the Contigs is determined by Options.getLociPerWritingShard.
*/
public class KeyReadsFn extends DoFn<Read, KV<Contig,Read>> {

+  private static final Logger LOG = Logger.getLogger(KeyReadsFn.class.getName());

public static interface Options extends PipelineOptions {
@Description("Loci per writing shard")
@Default.Long(10000)
@@ -44,6 +46,11 @@ public static interface Options extends PipelineOptions {
  private Aggregator<Integer, Integer> readCountAggregator;
  private Aggregator<Integer, Integer> unmappedReadCountAggregator;
  private long lociPerShard;
+  private long count;
+  private long minPos = Long.MAX_VALUE;
+  private long maxPos = Long.MIN_VALUE;

public KeyReadsFn() {
readCountAggregator = createAggregator("Keyed reads", new SumIntegerFn());
@@ -52,15 +59,26 @@ public KeyReadsFn() {

  @Override
  public void startBundle(Context c) {
    lociPerShard = c.getPipelineOptions()
        .as(Options.class)
        .getLociPerWritingShard();
+    count = 0;
  }

+  @Override
+  public void finishBundle(Context c) {
+    LOG.info("KeyReadsDone: Processed " + count + " reads, min=" + minPos +
+        " max=" + maxPos);
+  }
+
  @Override
  public void processElement(DoFn<Read, KV<Contig, Read>>.ProcessContext c)
      throws Exception {
    final Read read = c.element();
+    long pos = read.getAlignment().getPosition().getPosition();
+    minPos = Math.min(minPos, pos);
+    maxPos = Math.max(maxPos, pos);
+    count++;
    c.output(
        KV.of(
            shardKeyForRead(read, lociPerShard),
@@ -82,7 +100,7 @@ static boolean isUnmapped(Read read) {
return false;
}

-  static Contig shardKeyForRead(Read read, long lociPerShard) {
+  public static Contig shardKeyForRead(Read read, long lociPerShard) {
String referenceName = null;
Long alignmentStart = null;
if (read.getAlignment() != null) {
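
To make the keying concrete: `shardKeyForRead` buckets a read by integer division of its alignment start, so all reads whose positions fall in the same `lociPerShard`-sized window share a key and end up in one writing shard. A sketch of the arithmetic, assuming `Contig`'s `(referenceName, start, end)` constructor from `com.google.cloud.genomics.utils`; it mirrors the keying logic rather than reproducing the PR's exact code:

```java
import com.google.cloud.genomics.utils.Contig;

// With lociPerShard = 10000, a read aligned at chr1:123456 keys to the
// window [120000, 130000) on chr1, so neighboring reads land in the same
// shard and can be written out together.
static Contig shardKeySketch(String referenceName, long alignmentStart, long lociPerShard) {
  final long shardStart = (alignmentStart / lociPerShard) * lociPerShard;
  return new Contig(referenceName, shardStart, shardStart + lociPerShard);
}
```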