feat: spark-type conversion
Anush008 committed Dec 12, 2023
1 parent 6fbc0e5 commit 39653e3
Showing 2 changed files with 59 additions and 76 deletions.
src/main/java/io/qdrant/spark/QdrantDataWriter.java (48 changes: 14 additions & 34 deletions)
@@ -18,10 +18,14 @@
 import org.slf4j.LoggerFactory;

 /**
- * A DataWriter implementation that writes data to Qdrant, a vector search engine. This class takes
- * QdrantOptions and StructType as input and writes data to QdrantRest. It implements the DataWriter
- * interface and overrides its methods write, commit, abort and close. It also has a private method
- * write that is used to upload a batch of points to Qdrant. The class uses a Point class to
+ * A DataWriter implementation that writes data to Qdrant, a vector search
+ * engine. This class takes
+ * QdrantOptions and StructType as input and writes data to QdrantRest. It
+ * implements the DataWriter
+ * interface and overrides its methods write, commit, abort and close. It also
+ * has a private method
+ * write that is used to upload a batch of points to Qdrant. The class uses a
+ * Point class to
  * represent a data point and an ArrayList to store the points.
  */
 public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
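The Javadoc above describes the usual Spark DataSource V2 writer lifecycle: `write` is called once per row of a partition, `commit` finalizes the partition's output, `abort` discards it, and `close` releases resources. As a rough orientation only, here is a minimal `DataWriter<InternalRow>` skeleton under that assumption; the class names, the buffer field, and the trivial commit message are illustrative and are not the actual QdrantDataWriter body.

```java
// Illustrative skeleton of the DataWriter<InternalRow> lifecycle described in the
// Javadoc; names and the buffering strategy are assumptions, not the real class body.
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

class SketchCommitMessage implements WriterCommitMessage {}

class SketchDataWriter implements DataWriter<InternalRow> {
  private final List<InternalRow> buffer = new ArrayList<>();

  @Override
  public void write(InternalRow record) {
    // Buffer a copy of the row; a real writer converts it to a point and
    // flushes the buffer once a batch size is reached.
    buffer.add(record.copy());
  }

  @Override
  public WriterCommitMessage commit() {
    // Flush whatever is still buffered, then report success for this partition.
    buffer.clear();
    return new SketchCommitMessage();
  }

  @Override
  public void abort() {
    buffer.clear(); // Drop anything not yet uploaded.
  }

  @Override
  public void close() {}
}
```

Per the description above, the real class converts buffered rows to points and uploads them in batches through the configured Qdrant REST client.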
@@ -97,26 +101,18 @@ public void write(int retries) {
   }

   @Override
-  public void abort() {}
+  public void abort() {
+  }

   @Override
-  public void close() {}
+  public void close() {
+  }

   private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) {
     DataType dataType = field.dataType();

     if (dataType == DataTypes.StringType) {
       return record.getString(fieldIndex);
-    } else if (dataType == DataTypes.IntegerType) {
-      return record.getInt(fieldIndex);
-    } else if (dataType == DataTypes.LongType) {
-      return record.getLong(fieldIndex);
-    } else if (dataType == DataTypes.FloatType) {
-      return record.getFloat(fieldIndex);
-    } else if (dataType == DataTypes.DoubleType) {
-      return record.getDouble(fieldIndex);
-    } else if (dataType == DataTypes.BooleanType) {
-      return record.getBoolean(fieldIndex);
     } else if (dataType == DataTypes.DateType || dataType == DataTypes.TimestampType) {
       return record.getString(fieldIndex);
     } else if (dataType instanceof ArrayType) {
@@ -130,26 +126,11 @@ private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) {
     }

     // Fall back to the generic get method
-    // TODO: Add explicit parsings for other data types like maps
     return record.get(fieldIndex, dataType);
   }

   private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) {
-    if (elementType == DataTypes.IntegerType) {
-      return arrayData.toIntArray();
-    } else if (elementType == DataTypes.FloatType) {
-      return arrayData.toFloatArray();
-    } else if (elementType == DataTypes.ShortType) {
-      return arrayData.toShortArray();
-    } else if (elementType == DataTypes.ByteType) {
-      return arrayData.toByteArray();
-    } else if (elementType == DataTypes.DoubleType) {
-      return arrayData.toDoubleArray();
-    } else if (elementType == DataTypes.LongType) {
-      return arrayData.toLongArray();
-    } else if (elementType == DataTypes.BooleanType) {
-      return arrayData.toBooleanArray();
-    } else if (elementType == DataTypes.StringType) {
+    if (elementType == DataTypes.StringType) {
       int length = arrayData.numElements();
       String[] result = new String[length];
       for (int i = 0; i < length; i++) {
@@ -165,9 +146,8 @@ private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) {
         result[i] = convertStructToJavaType(structData, structType);
       }
       return result;
-
     } else {
-      throw new UnsupportedOperationException("Unsupported array type");
+      return arrayData.toObjectArray(elementType);
     }
   }

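The simplification in the two methods above leans on Spark's generic accessors: `InternalRow.get(ordinal, dataType)` hands back boxed Java values for the primitive Catalyst types, and `ArrayData.toObjectArray(elementType)` boxes array elements the same way, so the long per-type chains become redundant. StringType stays special-cased because the generic path yields Catalyst `UTF8String` objects rather than `java.lang.String`. A self-contained sketch of that behaviour follows; the class name and sample values are illustrative, not from the repository.

```java
// Sketch of the boxed-value behaviour the simplified conversion relies on.
// Class name and sample values are illustrative only.
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.sql.types.DataTypes;

public class ConversionSketch {
  public static void main(String[] args) {
    // A row holding a long and an int array, as Catalyst would represent them.
    InternalRow row = new GenericInternalRow(new Object[] {
        42L, new GenericArrayData(new Object[] {1, 2, 3})
    });

    // Generic scalar access: comes back as java.lang.Long, so no LongType branch is needed.
    Object id = row.get(0, DataTypes.LongType);

    // Generic array access: comes back as an Object[] of boxed Integers.
    ArrayData numbers = row.getArray(1);
    Object[] boxed = numbers.toObjectArray(DataTypes.IntegerType);

    System.out.println(id + " " + java.util.Arrays.toString(boxed));
  }
}
```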
src/test/java/io/qdrant/spark/TestIntegration.java (87 changes: 45 additions & 42 deletions)
@@ -30,51 +30,54 @@ public TestIntegration() {

   @Test
   public void testSparkSession() {
-    SparkSession spark =
-        SparkSession.builder().master("local[1]").appName("qdrant-spark").getOrCreate();
+    SparkSession spark = SparkSession.builder().master("local[1]").appName("qdrant-spark").getOrCreate();

-    List<Row> data =
-        Arrays.asList(
-            RowFactory.create(
-                1,
-                1,
-                new float[] {1.0f, 2.0f, 3.0f},
-                "John Doe",
-                new String[] {"Hello", "Hi"},
-                RowFactory.create(99, "AnotherNestedStruct")),
-            RowFactory.create(
-                2,
-                2,
-                new float[] {4.0f, 5.0f, 6.0f},
-                "Jane Doe",
-                new String[] {"Bonjour", "Salut"},
-                RowFactory.create(99, "AnotherNestedStruct")));
+    List<Row> data = Arrays.asList(
+        RowFactory.create(
+            1,
+            1,
+            new float[] { 1.0f, 2.0f, 3.0f },
+            "John Doe",
+            new String[] { "Hello", "Hi" },
+            RowFactory.create(99, "AnotherNestedStruct"),
+            new int[] { 4, 32, 323, 788 }),
+        RowFactory.create(
+            2,
+            2,
+            new float[] { 4.0f, 5.0f, 6.0f },
+            "Jane Doe",
+            new String[] { "Bonjour", "Salut" },
+            RowFactory.create(99, "AnotherNestedStruct"),
+            new int[] { 1, 2, 3, 4, 5 }));

-    StructType structType =
-        new StructType(
-            new StructField[] {
-              new StructField("nested_id", DataTypes.IntegerType, false, Metadata.empty()),
-              new StructField("nested_value", DataTypes.StringType, false, Metadata.empty())
-            });
+    StructType structType = new StructType(
+        new StructField[] {
+            new StructField("nested_id", DataTypes.IntegerType, false, Metadata.empty()),
+            new StructField("nested_value", DataTypes.StringType, false, Metadata.empty())
+        });

-    StructType schema =
-        new StructType(
-            new StructField[] {
-              new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-              new StructField("score", DataTypes.IntegerType, true, Metadata.empty()),
-              new StructField(
-                  "embedding",
-                  DataTypes.createArrayType(DataTypes.FloatType),
-                  true,
-                  Metadata.empty()),
-              new StructField("name", DataTypes.StringType, true, Metadata.empty()),
-              new StructField(
-                  "greetings",
-                  DataTypes.createArrayType(DataTypes.StringType),
-                  true,
-                  Metadata.empty()),
-              new StructField("struct_data", structType, true, Metadata.empty())
-            });
+    StructType schema = new StructType(
+        new StructField[] {
+            new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+            new StructField("score", DataTypes.IntegerType, true, Metadata.empty()),
+            new StructField(
+                "embedding",
+                DataTypes.createArrayType(DataTypes.FloatType),
+                true,
+                Metadata.empty()),
+            new StructField("name", DataTypes.StringType, true, Metadata.empty()),
+            new StructField(
+                "greetings",
+                DataTypes.createArrayType(DataTypes.StringType),
+                true,
+                Metadata.empty()),
+            new StructField("struct_data", structType, true, Metadata.empty()),
+            new StructField(
+                "numbers",
+                DataTypes.createArrayType(DataTypes.IntegerType),
+                true,
+                Metadata.empty()),
+        });
     Dataset<Row> df = spark.createDataFrame(data, schema);

     df.write()
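The new `numbers` column gives the writer an integer-array payload to exercise the generic array conversion. As a quick local sanity check (hypothetical, not part of this test), one could confirm that the column's Spark type is `array<int>` and that collecting the DataFrame yields boxed Integers:

```java
// Hypothetical sanity check (not part of the test): the "numbers" column is
// array<int> in the Spark schema and collects back as boxed Integers.
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.ArrayType;

public class NumbersColumnCheck {
  static void check(Dataset<Row> df) {
    ArrayType numbersType = (ArrayType) df.schema().apply("numbers").dataType();
    System.out.println(numbersType.elementType()); // IntegerType

    for (Row row : df.collectAsList()) {
      List<Integer> numbers = row.getList(row.fieldIndex("numbers"));
      System.out.println(numbers); // e.g. [4, 32, 323, 788]
    }
  }
}
```

Called with the `df` constructed in the test above, this would print `IntegerType` followed by the two integer lists.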
