diff --git a/src/dist/edu/umd/cloud9/collection/XMLInputFormat.java b/src/dist/edu/umd/cloud9/collection/XMLInputFormat.java index 3de8357a8..5de4f4499 100644 --- a/src/dist/edu/umd/cloud9/collection/XMLInputFormat.java +++ b/src/dist/edu/umd/cloud9/collection/XMLInputFormat.java @@ -16,6 +16,7 @@ package edu.umd.cloud9.collection; +import java.io.InputStream; import java.io.DataInputStream; import java.io.IOException; @@ -27,8 +28,12 @@ import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.compress.CodecPool; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.io.compress.CompressionCodecFactory; +import org.apache.hadoop.io.compress.Decompressor; +import org.apache.hadoop.io.compress.SplitCompressionInputStream; +import org.apache.hadoop.io.compress.SplittableCompressionCodec; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; @@ -80,9 +85,12 @@ public static class XMLRecordReader extends RecordReader { private long start; private long end; private long pos; - private DataInputStream fsin = null; + private InputStream fsin = null; private DataOutputBuffer buffer = new DataOutputBuffer(); + private CompressionCodec codec = null; + private Decompressor decompressor = null; + private long recordStartPos; private final LongWritable key = new LongWritable(); @@ -108,18 +116,30 @@ public void initialize(InputSplit input, TaskAttemptContext context) FileSplit split = (FileSplit) input; start = split.getStart(); + end = start + split.getLength(); Path file = split.getPath(); CompressionCodecFactory compressionCodecs = new CompressionCodecFactory(conf); - CompressionCodec codec = compressionCodecs.getCodec(file); + codec = compressionCodecs.getCodec(file); FileSystem fs = file.getFileSystem(conf); - if (codec != null) { + if (isCompressedInput()) { LOG.info("Reading compressed file " + file + "..."); - fsin = new DataInputStream(codec.createInputStream(fs.open(file))); - - end = Long.MAX_VALUE; + FSDataInputStream fileIn = fs.open(file); + decompressor = CodecPool.getDecompressor(codec); + if (codec instanceof SplittableCompressionCodec) { + // We can read blocks + final SplitCompressionInputStream cIn = ((SplittableCompressionCodec)codec).createInputStream(fileIn, decompressor, start, end, SplittableCompressionCodec.READ_MODE.BYBLOCK); + fsin = cIn; + start = cIn.getAdjustedStart(); + end = cIn.getAdjustedEnd(); + } else { + // We cannot read blocks, we have to read everything + fsin = new DataInputStream(codec.createInputStream(fileIn, decompressor)); + + end = Long.MAX_VALUE; + } } else { LOG.info("Reading uncompressed file " + file + "..."); FSDataInputStream fileIn = fs.open(file); @@ -146,7 +166,7 @@ public void initialize(InputSplit input, TaskAttemptContext context) */ @Override public boolean nextKeyValue() throws IOException, InterruptedException { - if (pos < end) { + if (getFilePosition() < end) { if (readUntilMatch(startTag, false)) { recordStartPos = pos - startTag.length; @@ -166,8 +186,11 @@ public boolean nextKeyValue() throws IOException, InterruptedException { // works correctly. if (fsin instanceof Seekable) { - if (pos != ((Seekable) fsin).getPos()) { - throw new RuntimeException("bytes consumed error!"); + // The position for compressed inputs is weird + if (!isCompressedInput()) { + if (pos != ((Seekable) fsin).getPos()) { + throw new RuntimeException("bytes consumed error!"); + } } } @@ -219,7 +242,25 @@ public void close() throws IOException { */ @Override public float getProgress() throws IOException { - return ((float) (pos - start)) / ((float) (end - start)); + if (start == end) { + return 0.0f; + } else { + return Math.min(1.0f, (getFilePosition() - start) / (float)(end - start)); + } + } + + private boolean isCompressedInput() { + return (codec != null); + } + + protected long getFilePosition() throws IOException { + long retVal; + if (isCompressedInput() && null != fsin && fsin instanceof Seekable) { + retVal = ((Seekable)fsin).getPos(); + } else { + retVal = pos; + } + return retVal; } private boolean readUntilMatch(byte[] match, boolean withinBlock) @@ -227,12 +268,14 @@ private boolean readUntilMatch(byte[] match, boolean withinBlock) int i = 0; while (true) { int b = fsin.read(); - // increment position (bytes consumed) - pos++; // end of file: if (b == -1) return false; + + // increment position (bytes consumed) + pos++; + // save to buffer: if (withinBlock) buffer.write(b); @@ -245,7 +288,7 @@ private boolean readUntilMatch(byte[] match, boolean withinBlock) } else i = 0; // see if we've passed the stop point: - if (!withinBlock && i == 0 && pos >= end) + if (!withinBlock && i == 0 && getFilePosition() >= end) return false; } }