From a27435d5eeb1623ad51c669cf1dda83ae788ae55 Mon Sep 17 00:00:00 2001 From: Yaliang Wang Date: Wed, 6 Sep 2017 00:00:40 -0700 Subject: [PATCH] Use class reflection to determine TProtocol --- .../thrift/HiveThriftFieldIdResolver.java | 3 +- .../twitter/hive/thrift/ThriftGenericRow.java | 36 ++++++++++++++++--- .../decoder/thrift/ThriftGenericRow.java | 36 ++++++++++++++++--- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/HiveThriftFieldIdResolver.java b/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/HiveThriftFieldIdResolver.java index 077a409fd2a2e..ddd610eb12881 100644 --- a/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/HiveThriftFieldIdResolver.java +++ b/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/HiveThriftFieldIdResolver.java @@ -52,13 +52,14 @@ * '0': { * '0': 1, * '1': 2, + * 'id': 1 * }, * '1': 3 * } * * The json property is: * - * {"0":{"0":1,"1":2},"1":3} + * {"0":{"0":1,"1":2,"id":1},"1":3} */ public class HiveThriftFieldIdResolver implements ThriftFieldIdResolver diff --git a/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftGenericRow.java b/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftGenericRow.java index 5b961f7d813ac..59efc03abccac 100644 --- a/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftGenericRow.java +++ b/presto-hive/src/main/java/com/facebook/presto/twitter/hive/thrift/ThriftGenericRow.java @@ -18,11 +18,12 @@ import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.apache.thrift.TFieldIdEnum; -import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TJSONProtocol; import org.apache.thrift.protocol.TList; import org.apache.thrift.protocol.TMap; import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.protocol.TProtocolUtil; import org.apache.thrift.protocol.TSet; import org.apache.thrift.protocol.TType; @@ -35,13 +36,16 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class ThriftGenericRow implements TBase { private static final Logger log = Logger.get(ThriftGenericRow.class); + private static final byte[] COLON = new byte[]{58}; private final Map values = new HashMap<>(); + private TProtocolFactory iprotFactory; private byte[] buf; private int off; private int len; @@ -81,13 +85,34 @@ public String getFieldName() public void read(TProtocol iprot) throws TException { + for (Class clazz = iprot.getClass(); clazz != null; clazz = clazz.getSuperclass()) { + try { + Optional> factory = Arrays.stream(clazz.getDeclaredClasses()) + .filter(TProtocolFactory.class::isAssignableFrom) + .findFirst(); + if (factory.isPresent()) { + iprotFactory = factory.get().asSubclass(TProtocolFactory.class).newInstance(); + break; + } + } + catch (InstantiationException | IllegalAccessException | SecurityException ignored) { + } + } TTransport trans = iprot.getTransport(); buf = trans.getBuffer(); off = trans.getBufferPosition(); TProtocolUtil.skip(iprot, TType.STRUCT); + omitContextSyntaxChar(iprot); len = trans.getBufferPosition() - off; } + private void omitContextSyntaxChar(TProtocol iprot) + { + if (TJSONProtocol.class.isAssignableFrom(iprot.getClass()) && buf[off] == COLON[0]) { + off = off + 1; + } + } + public void parse() throws TException { @@ -97,9 +122,12 @@ public void parse() public void parse(short[] thriftIds) throws TException { - Set idSet = thriftIds == null ? null : new HashSet(Arrays.asList(ArrayUtils.toObject(thriftIds))); + Set idSet = thriftIds == null ? null : new HashSet<>(Arrays.asList(ArrayUtils.toObject(thriftIds))); TMemoryInputTransport trans = new TMemoryInputTransport(buf, off, len); - TBinaryProtocol iprot = new TBinaryProtocol(trans); + if (iprotFactory == null) { + throw new TException("Failed to find the TProtocol factory"); + } + TProtocol iprot = iprotFactory.getProtocol(trans); TField field; iprot.readStructBegin(); while (true) { @@ -107,7 +135,7 @@ public void parse(short[] thriftIds) if (field.type == TType.STOP) { break; } - if (idSet != null && !idSet.remove(Short.valueOf(field.id))) { + if (idSet != null && !idSet.remove(field.id)) { TProtocolUtil.skip(iprot, field.type); } else { diff --git a/presto-record-decoder/src/main/java/com/facebook/presto/decoder/thrift/ThriftGenericRow.java b/presto-record-decoder/src/main/java/com/facebook/presto/decoder/thrift/ThriftGenericRow.java index f3f356983bed6..99524fc473f72 100644 --- a/presto-record-decoder/src/main/java/com/facebook/presto/decoder/thrift/ThriftGenericRow.java +++ b/presto-record-decoder/src/main/java/com/facebook/presto/decoder/thrift/ThriftGenericRow.java @@ -18,11 +18,12 @@ import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.apache.thrift.TFieldIdEnum; -import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TJSONProtocol; import org.apache.thrift.protocol.TList; import org.apache.thrift.protocol.TMap; import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.protocol.TProtocolUtil; import org.apache.thrift.protocol.TSet; import org.apache.thrift.protocol.TType; @@ -35,13 +36,16 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class ThriftGenericRow implements TBase { private static final Logger log = Logger.get(ThriftGenericRow.class); + private static final byte[] COLON = new byte[]{58}; private final Map values = new HashMap<>(); + private TProtocolFactory iprotFactory; private byte[] buf; private int off; private int len; @@ -81,13 +85,34 @@ public String getFieldName() public void read(TProtocol iprot) throws TException { + for (Class clazz = iprot.getClass(); clazz != null; clazz = clazz.getSuperclass()) { + try { + Optional> factory = Arrays.stream(clazz.getDeclaredClasses()) + .filter(TProtocolFactory.class::isAssignableFrom) + .findFirst(); + if (factory.isPresent()) { + iprotFactory = factory.get().asSubclass(TProtocolFactory.class).newInstance(); + break; + } + } + catch (InstantiationException | IllegalAccessException | SecurityException ignored) { + } + } TTransport trans = iprot.getTransport(); buf = trans.getBuffer(); off = trans.getBufferPosition(); TProtocolUtil.skip(iprot, TType.STRUCT); + omitContextSyntaxChar(iprot); len = trans.getBufferPosition() - off; } + private void omitContextSyntaxChar(TProtocol iprot) + { + if (TJSONProtocol.class.isAssignableFrom(iprot.getClass()) && buf[off] == COLON[0]) { + off = off + 1; + } + } + public void parse() throws TException { @@ -97,9 +122,12 @@ public void parse() public void parse(short[] thriftIds) throws TException { - Set idSet = thriftIds == null ? null : new HashSet(Arrays.asList(ArrayUtils.toObject(thriftIds))); + Set idSet = thriftIds == null ? null : new HashSet<>(Arrays.asList(ArrayUtils.toObject(thriftIds))); TMemoryInputTransport trans = new TMemoryInputTransport(buf, off, len); - TBinaryProtocol iprot = new TBinaryProtocol(trans); + if (iprotFactory == null) { + throw new TException("Failed to find the TProtocol factory"); + } + TProtocol iprot = iprotFactory.getProtocol(trans); TField field; iprot.readStructBegin(); while (true) { @@ -107,7 +135,7 @@ public void parse(short[] thriftIds) if (field.type == TType.STOP) { break; } - if (idSet != null && !idSet.remove(Short.valueOf(field.id))) { + if (idSet != null && !idSet.remove(field.id)) { TProtocolUtil.skip(iprot, field.type); } else {