Skip to content

Commit

Permalink
Performance improvement of PrimitiveFloatList by lazy deserialization (
Browse files Browse the repository at this point in the history
…#36)

 Performance improvement of PrimitiveFloatList by lazy deserialization

Co-authored-by: Sourav Maji <[email protected]>
  • Loading branch information
majisourav99 and Sourav Maji authored Apr 2, 2020
1 parent a20fd47 commit a31b48f
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.linkedin.avro.fastserde;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;


/**
 * An ordered sequence of little-endian {@link ByteBuffer}s holding raw float bytes, which can be
 * decoded in one pass into a {@code float[]} via {@link #setArray(float[])}.
 *
 * <p>Buffers are kept between uses so that a later {@link #allocate(int, int)} call for the same
 * index can reuse the existing backing storage when it is large enough.
 */
public class CompositeByteBuffer {
  /** Number of leading buffers that {@link #setArray(float[])} decodes. */
  private int byteBufferCount;
  private final List<ByteBuffer> byteBuffers;

  public CompositeByteBuffer() {
    byteBuffers = new ArrayList<>(2);
  }

  /**
   * Returns a buffer able to hold {@code size} bytes at position {@code index}.
   *
   * <p>The returned buffer always has position 0 and limit {@code size}, regardless of whether it
   * is freshly allocated or reused, so that {@link #setArray(float[])} decodes exactly the bytes
   * written for this call and never stale bytes from an earlier, larger use.
   *
   * @param index slot of the buffer; if it is not an existing slot the buffer is appended
   * @param size number of bytes the caller is going to store
   * @return a little-endian, array-backed buffer with {@code limit() == size}
   */
  public ByteBuffer allocate(int index, int size) {
    ByteBuffer byteBuffer;

    // Reuse the previous buffer at this slot when it is big enough (an exact fit counts).
    if (index < byteBuffers.size() && byteBuffers.get(index).capacity() >= size) {
      byteBuffer = byteBuffers.get(index);
      byteBuffer.clear();
      // clear() resets the limit to the capacity; shrink it back to the requested size so readers
      // that iterate up to limit() do not run past the freshly written data.
      byteBuffer.limit(size);
    } else {
      byteBuffer = ByteBuffer.allocate(size).order(ByteOrder.LITTLE_ENDIAN);
      if (index < byteBuffers.size()) {
        byteBuffers.set(index, byteBuffer);
      } else {
        byteBuffers.add(byteBuffer);
      }
    }
    return byteBuffer;
  }

  /** Clears every held buffer (position 0, limit = capacity); the buffers themselves are kept. */
  public void clear() {
    for (ByteBuffer byteBuffer : byteBuffers) {
      byteBuffer.clear();
    }
  }

  /**
   * Sets how many buffers, counted from slot 0, hold valid data for {@link #setArray(float[])}.
   *
   * @param count number of buffers in use
   */
  public void setByteBufferCount(int count) {
    byteBufferCount = count;
  }

  /**
   * Decodes the floats stored in the first {@code byteBufferCount} buffers into {@code array},
   * starting at index 0 and filling consecutively across buffers.
   *
   * @param array destination; must be large enough for all floats in the buffers in use
   */
  public void setArray(float[] array) {
    int k = 0;
    for (int i = 0; i < byteBufferCount; i++) {
      ByteBuffer byteBuffer = byteBuffers.get(i);
      // Absolute reads: the buffer's position is left untouched.
      for (int j = 0; j < byteBuffer.limit(); j += Float.BYTES) {
        array[k++] = byteBuffer.getFloat(j);
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.linkedin.avro.fastserde;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.AbstractList;
import java.util.Collection;
import java.util.Iterator;
Expand All @@ -25,23 +26,34 @@
* - It re-implements {@link #compareTo(GenericArray)}, {@link #equals(Object)} and {@link #hashCode()}
* in order to leverage the primitive types, rather than causing unintended boxing.
*
* TODO: Provide arrays for other primitive types.
* Using ByteBuffer to speed up float-array deserialization: we allocate ByteBuffers to store the raw bytes from
* BinaryDecoder and deserialize them only on first element access. The results are cached in the elements array
* after the first access so that subsequent array accesses are fast. In the reuse case, we try to reuse
* the existing ByteBuffers as long as their capacity can hold the array.
*
* TODO: Provide arrays for other primitive types.
*/
public class PrimitiveFloatList extends AbstractList<Float>
implements GenericArray<Float>, Comparable<GenericArray<Float>> {
private static final float[] EMPTY = new float[0];
private static final int FLOAT_SIZE = Float.BYTES;
private static final Schema FLOAT_SCHEMA = Schema.create(Schema.Type.FLOAT);
private static final Schema SCHEMA = Schema.createArray(FLOAT_SCHEMA);

private int size;
private float[] elements = EMPTY;
private boolean isCached = false;
private CompositeByteBuffer byteBuffer;

/**
 * Creates an empty list whose backing array is pre-sized for {@code capacity} floats.
 *
 * @param capacity initial length of the backing array; 0 keeps the shared empty array
 */
public PrimitiveFloatList(int capacity) {
  if (capacity != 0) {
    elements = new float[capacity];
  }
  // A list built this way has no pending raw bytes. Give it an empty composite buffer anyway so
  // that code paths which dereference byteBuffer (lazy cache fill, reuse as the "old" instance)
  // operate on an empty buffer set instead of hitting a null reference.
  byteBuffer = new CompositeByteBuffer();
}

/**
 * Creates a list with no backing float storage yet; only the composite byte buffer that will
 * receive raw bytes is set up.
 */
public PrimitiveFloatList() {
  this.byteBuffer = new CompositeByteBuffer();
}

public PrimitiveFloatList(Collection<Float> c) {
if (c != null) {
elements = new float[c.size()];
Expand All @@ -60,44 +72,64 @@ public PrimitiveFloatList(Collection<Float> c) {
* @throws IOException on io errors
*/
public static Object readPrimitiveFloatArray(Object old, Decoder in) throws IOException {
long l = in.readArrayStart();
if (l > 0) {
PrimitiveFloatList array = (PrimitiveFloatList) newPrimitiveFloatArray(old, (int) l);
long length = in.readArrayStart();
long totalLength = 0;

if (length > 0) {
PrimitiveFloatList array = (PrimitiveFloatList) newPrimitiveFloatArray(old);
int index = 0;

do {
for (long i = 0; i < l; i++) {
array.addPrimitive(in.readFloat());
}
l = in.arrayNext();
} while (l > 0);
long byteSize = length * FLOAT_SIZE;
ByteBuffer byteBuffer = array.byteBuffer.allocate(index++, (int)byteSize);
in.readFixed(byteBuffer.array(), 0, (int)byteSize);
totalLength += length;
length = in.arrayNext();
} while (length > 0);

array.byteBuffer.setByteBufferCount(index);
setupElements(array, (int)totalLength);
return array;
} else {
return newPrimitiveFloatArray(old, 0);
return new PrimitiveFloatList(0);
}
}

/**
 * Prepares {@code list}'s float storage to hold {@code totalSize} elements and records that size.
 *
 * @param list the list whose backing array is being prepared
 * @param totalSize number of elements the list will report
 */
private static void setupElements(PrimitiveFloatList list, int totalSize) {
  if (list.elements.length == 0) {
    // Nothing to reuse: allocate storage of exactly the required length.
    list.elements = new float[totalSize];
  } else if (totalSize > list.getCapacity()) {
    // Existing storage is too small.
    list.resizeAndClear(totalSize);
  } else {
    // Existing storage is large enough; reuse the float array directly.
    list.clear();
  }
  list.size = totalSize;
}

/**
* @param expected {@link Schema} to inspect
* @return true if the {@code expected} SCHEMA is of the right type to decode as a {@link PrimitiveFloatList}
* false otherwise
*/
* @param expected {@link Schema} to inspect
* @return true if the {@code expected} SCHEMA is of the right type to decode as a {@link PrimitiveFloatList}
* false otherwise
*/
public static boolean isFloatArray(Schema expected) {
  // Guard clauses: anything that is not a non-null ARRAY schema is rejected up front.
  if (expected == null) {
    return false;
  }
  if (!Schema.Type.ARRAY.equals(expected.getType())) {
    return false;
  }
  // An array schema qualifies only when its element schema equals the FLOAT schema.
  return FLOAT_SCHEMA.equals(expected.getElementType());
}

private static Object newPrimitiveFloatArray(Object old, int size) {
private static Object newPrimitiveFloatArray(Object old) {
if (old instanceof PrimitiveFloatList) {
PrimitiveFloatList oldFloatList = (PrimitiveFloatList) old;
if (size <= oldFloatList.getCapacity()) {
// reuse the float array directly
oldFloatList.clear();
return old;
} else {
oldFloatList.resizeAndClear(size);
return oldFloatList;
}
oldFloatList.byteBuffer.clear();
oldFloatList.isCached = false;
oldFloatList.size = 0;
return oldFloatList;
} else {
return new PrimitiveFloatList(size);
// Just a place holder, will set up the elements later.
return new PrimitiveFloatList();
}
}

Expand Down Expand Up @@ -137,7 +169,9 @@ public boolean hasNext() {

@Override
public Float next() {
return elements[position++];
float f = getPrimitive(position);
position++;
return f;
}

@Override
Expand All @@ -151,6 +185,7 @@ public float getPrimitive(int i) {
if (i >= size) {
throw new IndexOutOfBoundsException("Index " + i + " out of bounds.");
}
cacheFromByteBuffer();
return elements[i];
}

Expand All @@ -167,6 +202,7 @@ public Float get(int i) {
* @return true?
*/
public boolean addPrimitive(float o) {
cacheFromByteBuffer();
if (size == elements.length) {
float[] newElements = new float[(size * 3) / 2 + 1];
System.arraycopy(elements, 0, newElements, 0, size);
Expand All @@ -186,6 +222,7 @@ public void add(int location, Float o) {
if (location > size || location < 0) {
throw new IndexOutOfBoundsException("Index " + location + " out of bounds.");
}
cacheFromByteBuffer();
if (size == elements.length) {
float[] newElements = new float[(size * 3) / 2 + 1];
System.arraycopy(elements, 0, newElements, 0, size);
Expand All @@ -201,6 +238,7 @@ public Float set(int i, Float o) {
if (i >= size) {
throw new IndexOutOfBoundsException("Index " + i + " out of bounds.");
}
cacheFromByteBuffer();
Float response = elements[i];
elements[i] = o;

Expand All @@ -212,14 +250,28 @@ public Float remove(int i) {
if (i >= size) {
throw new IndexOutOfBoundsException("Index " + i + " out of bounds.");
}
cacheFromByteBuffer();
Float result = elements[i];
--size;
System.arraycopy(elements, i + 1, elements, i, (size - i));
elements[size] = 0;
return result;
}

/**
 * Decodes any pending raw bytes into {@code elements} the first time it is called, then marks
 * the list as cached so later calls return immediately.
 *
 * <p>Lists that were not populated from raw bytes may have no composite buffer at all; for those
 * there is nothing to decode and the list is simply marked as cached.
 */
private void cacheFromByteBuffer() {
  if (isCached) {
    return;
  }
  // NOTE(review): the unsynchronized check above reads a field that is not declared volatile in
  // the visible declaration, so it gives no cross-thread visibility guarantee — confirm whether
  // concurrent access is actually expected here.
  synchronized (this) {
    if (!isCached) {
      if (byteBuffer != null) {
        byteBuffer.setArray(elements);
      }
      isCached = true;
    }
  }
}

/**
 * Returns the float stored in the backing array just past the last element, i.e. at index
 * {@code size}, after making sure any pending raw bytes have been decoded.
 *
 * @return the value at {@code elements[size]} when the backing array extends past {@code size}
 */
public float peekPrimitive() {
cacheFromByteBuffer();
// NOTE(review): the conditional mixes a float with null, so its type is the boxed Float; when
// size >= elements.length the null branch is unboxed for the primitive return and throws a
// NullPointerException. Confirm the intended result for a full backing array.
return (size < elements.length) ? elements[size] : null;
}

Expand All @@ -230,6 +282,7 @@ public Float peek() {

@Override
public int compareTo(GenericArray<Float> that) {
cacheFromByteBuffer();
if (that instanceof PrimitiveFloatList) {
PrimitiveFloatList thatPrimitiveList = (PrimitiveFloatList) that;
if (this.size == thatPrimitiveList.size) {
Expand All @@ -253,6 +306,7 @@ public int compareTo(GenericArray<Float> that) {

@Override
public void reverse() {
cacheFromByteBuffer();
int left = 0;
int right = elements.length - 1;

Expand Down Expand Up @@ -283,6 +337,7 @@ public String toString() {

@Override
public boolean equals(Object o) {
cacheFromByteBuffer();
if (o instanceof GenericArray) {
return compareTo((GenericArray) o) == 0;
} else {
Expand All @@ -292,6 +347,7 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
cacheFromByteBuffer();
int hashCode = 1;
for (int i = 0; i < this.size; i++) {
hashCode = 31 * hashCode + Float.hashCode(elements[i]);
Expand Down
Loading

0 comments on commit a31b48f

Please sign in to comment.