Skip to content

Commit

Permalink
Use class reflection to determine TProtocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Yaliang committed May 29, 2018
1 parent e010962 commit a27435d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@
* '0': {
* '0': 1,
* '1': 2,
* 'id': 1
* },
* '1': 3
* }
*
* The json property is:
*
* {"0":{"0":1,"1":2},"1":3}
* {"0":{"0":1,"1":2,"id":1},"1":3}
*/
public class HiveThriftFieldIdResolver
implements ThriftFieldIdResolver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TJSONProtocol;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.protocol.TProtocolUtil;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TType;
Expand All @@ -35,13 +36,16 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

public class ThriftGenericRow
implements TBase<ThriftGenericRow, ThriftGenericRow.Fields>
{
private static final Logger log = Logger.get(ThriftGenericRow.class);
private static final byte[] COLON = new byte[]{58};
private final Map<Short, Object> values = new HashMap<>();
private TProtocolFactory iprotFactory;
private byte[] buf;
private int off;
private int len;
Expand Down Expand Up @@ -81,13 +85,34 @@ public String getFieldName()
public void read(TProtocol iprot)
throws TException
{
for (Class<?> clazz = iprot.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
try {
Optional<Class<?>> factory = Arrays.stream(clazz.getDeclaredClasses())
.filter(TProtocolFactory.class::isAssignableFrom)
.findFirst();
if (factory.isPresent()) {
iprotFactory = factory.get().asSubclass(TProtocolFactory.class).newInstance();
break;
}
}
catch (InstantiationException | IllegalAccessException | SecurityException ignored) {
}
}
TTransport trans = iprot.getTransport();
buf = trans.getBuffer();
off = trans.getBufferPosition();
TProtocolUtil.skip(iprot, TType.STRUCT);
omitContextSyntaxChar(iprot);
len = trans.getBufferPosition() - off;
}

private void omitContextSyntaxChar(TProtocol iprot)
{
if (TJSONProtocol.class.isAssignableFrom(iprot.getClass()) && buf[off] == COLON[0]) {
off = off + 1;
}
}

public void parse()
throws TException
{
Expand All @@ -97,17 +122,20 @@ public void parse()
public void parse(short[] thriftIds)
throws TException
{
Set<Short> idSet = thriftIds == null ? null : new HashSet(Arrays.asList(ArrayUtils.toObject(thriftIds)));
Set<Short> idSet = thriftIds == null ? null : new HashSet<>(Arrays.asList(ArrayUtils.toObject(thriftIds)));
TMemoryInputTransport trans = new TMemoryInputTransport(buf, off, len);
TBinaryProtocol iprot = new TBinaryProtocol(trans);
if (iprotFactory == null) {
throw new TException("Failed to find the TProtocol factory");
}
TProtocol iprot = iprotFactory.getProtocol(trans);
TField field;
iprot.readStructBegin();
while (true) {
field = iprot.readFieldBegin();
if (field.type == TType.STOP) {
break;
}
if (idSet != null && !idSet.remove(Short.valueOf(field.id))) {
if (idSet != null && !idSet.remove(field.id)) {
TProtocolUtil.skip(iprot, field.type);
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TJSONProtocol;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.protocol.TProtocolUtil;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TType;
Expand All @@ -35,13 +36,16 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

public class ThriftGenericRow
implements TBase<ThriftGenericRow, ThriftGenericRow.Fields>
{
private static final Logger log = Logger.get(ThriftGenericRow.class);
private static final byte[] COLON = new byte[]{58};
private final Map<Short, Object> values = new HashMap<>();
private TProtocolFactory iprotFactory;
private byte[] buf;
private int off;
private int len;
Expand Down Expand Up @@ -81,13 +85,34 @@ public String getFieldName()
public void read(TProtocol iprot)
throws TException
{
for (Class<?> clazz = iprot.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
try {
Optional<Class<?>> factory = Arrays.stream(clazz.getDeclaredClasses())
.filter(TProtocolFactory.class::isAssignableFrom)
.findFirst();
if (factory.isPresent()) {
iprotFactory = factory.get().asSubclass(TProtocolFactory.class).newInstance();
break;
}
}
catch (InstantiationException | IllegalAccessException | SecurityException ignored) {
}
}
TTransport trans = iprot.getTransport();
buf = trans.getBuffer();
off = trans.getBufferPosition();
TProtocolUtil.skip(iprot, TType.STRUCT);
omitContextSyntaxChar(iprot);
len = trans.getBufferPosition() - off;
}

private void omitContextSyntaxChar(TProtocol iprot)
{
if (TJSONProtocol.class.isAssignableFrom(iprot.getClass()) && buf[off] == COLON[0]) {
off = off + 1;
}
}

public void parse()
throws TException
{
Expand All @@ -97,17 +122,20 @@ public void parse()
public void parse(short[] thriftIds)
throws TException
{
Set<Short> idSet = thriftIds == null ? null : new HashSet(Arrays.asList(ArrayUtils.toObject(thriftIds)));
Set<Short> idSet = thriftIds == null ? null : new HashSet<>(Arrays.asList(ArrayUtils.toObject(thriftIds)));
TMemoryInputTransport trans = new TMemoryInputTransport(buf, off, len);
TBinaryProtocol iprot = new TBinaryProtocol(trans);
if (iprotFactory == null) {
throw new TException("Failed to find the TProtocol factory");
}
TProtocol iprot = iprotFactory.getProtocol(trans);
TField field;
iprot.readStructBegin();
while (true) {
field = iprot.readFieldBegin();
if (field.type == TType.STOP) {
break;
}
if (idSet != null && !idSet.remove(Short.valueOf(field.id))) {
if (idSet != null && !idSet.remove(field.id)) {
TProtocolUtil.skip(iprot, field.type);
}
else {
Expand Down

0 comments on commit a27435d

Please sign in to comment.