Commit

chore: formatting (#11)
* chore: Spotify formatting

* fix: version
Anush008 authored Feb 17, 2024
1 parent 8583ff2 commit ad68e7d
Showing 6 changed files with 132 additions and 138 deletions.
2 changes: 1 addition & 1 deletion pom.xml
@@ -6,7 +6,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>io.qdrant</groupId>
   <artifactId>spark</artifactId>
-  <version>1.13</version>
+  <version>1.12.1</version>
   <name>qdrant-spark</name>
   <url>https://github.com/qdrant/qdrant-spark</url>
   <description>An Apache Spark connector for the Qdrant vector database</description>
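For reference, these coordinates are also how a consuming project picks up the corrected release; a minimal sketch of the corresponding dependency declaration (not part of the diff):

  <dependency>
    <groupId>io.qdrant</groupId>
    <artifactId>spark</artifactId>
    <version>1.12.1</version>
  </dependency>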
117 changes: 59 additions & 58 deletions src/main/java/io/qdrant/spark/ObjectFactory.java
@@ -2,77 +2,78 @@
 
 import java.util.HashMap;
 import java.util.Map;
-
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.types.ArrayType;
 
 class ObjectFactory {
-    public static Object object(InternalRow record, StructField field, int fieldIndex) {
-        DataType dataType = field.dataType();
+  public static Object object(InternalRow record, StructField field, int fieldIndex) {
+    DataType dataType = field.dataType();
 
-        switch (dataType.typeName()) {
-            case "integer":
-                return record.getInt(fieldIndex);
-            case "float":
-                return record.getFloat(fieldIndex);
-            case "double":
-                return record.getDouble(fieldIndex);
-            case "long":
-                return record.getLong(fieldIndex);
-            case "boolean":
-                return record.getBoolean(fieldIndex);
-            case "string":
-                return record.getString(fieldIndex);
-            case "array":
-                ArrayType arrayType = (ArrayType) dataType;
-                ArrayData arrayData = record.getArray(fieldIndex);
-                return object(arrayData, arrayType.elementType());
-            case "struct":
-                StructType structType = (StructType) dataType;
-                InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
-                return object(structData, structType);
-            default:
-                return null;
-        }
+    switch (dataType.typeName()) {
+      case "integer":
+        return record.getInt(fieldIndex);
+      case "float":
+        return record.getFloat(fieldIndex);
+      case "double":
+        return record.getDouble(fieldIndex);
+      case "long":
+        return record.getLong(fieldIndex);
+      case "boolean":
+        return record.getBoolean(fieldIndex);
+      case "string":
+        return record.getString(fieldIndex);
+      case "array":
+        ArrayType arrayType = (ArrayType) dataType;
+        ArrayData arrayData = record.getArray(fieldIndex);
+        return object(arrayData, arrayType.elementType());
+      case "struct":
+        StructType structType = (StructType) dataType;
+        InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
+        return object(structData, structType);
+      default:
+        return null;
+    }
   }
 
-    public static Object object(ArrayData arrayData, DataType elementType) {
+  public static Object object(ArrayData arrayData, DataType elementType) {
 
-        switch (elementType.typeName()) {
-            case "string": {
-                int length = arrayData.numElements();
-                String[] result = new String[length];
-                for (int i = 0; i < length; i++) {
-                    result[i] = arrayData.getUTF8String(i).toString();
-                }
-                return result;
-            }
+    switch (elementType.typeName()) {
+      case "string":
+        {
+          int length = arrayData.numElements();
+          String[] result = new String[length];
+          for (int i = 0; i < length; i++) {
+            result[i] = arrayData.getUTF8String(i).toString();
+          }
+          return result;
+        }
 
-            case "struct": {
-                StructType structType = (StructType) elementType;
-                int length = arrayData.numElements();
-                Object[] result = new Object[length];
-                for (int i = 0; i < length; i++) {
-                    InternalRow structData = arrayData.getStruct(i, structType.fields().length);
-                    result[i] = object(structData, structType);
-                }
-                return result;
-            }
-            default:
-                return arrayData.toObjectArray(elementType);
+      case "struct":
+        {
+          StructType structType = (StructType) elementType;
+          int length = arrayData.numElements();
+          Object[] result = new Object[length];
+          for (int i = 0; i < length; i++) {
+            InternalRow structData = arrayData.getStruct(i, structType.fields().length);
+            result[i] = object(structData, structType);
+          }
+          return result;
+        }
+      default:
+        return arrayData.toObjectArray(elementType);
     }
   }
 
-    public static Object object(InternalRow structData, StructType structType) {
-        Map<String, Object> result = new HashMap<>();
-        for (int i = 0; i < structType.fields().length; i++) {
-            StructField structField = structType.fields()[i];
-            result.put(structField.name(), object(structData, structField, i));
-        }
-        return result;
+  public static Object object(InternalRow structData, StructType structType) {
+    Map<String, Object> result = new HashMap<>();
+    for (int i = 0; i < structType.fields().length; i++) {
+      StructField structField = structType.fields()[i];
+      result.put(structField.name(), object(structData, structField, i));
+    }
+    return result;
   }
 }
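To make the conversion concrete: the struct overload walks a StructType's fields and returns a HashMap keyed by field name. A minimal sketch (not part of the commit; the demo class is hypothetical and must live in the io.qdrant.spark package because ObjectFactory is package-private):

package io.qdrant.spark;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

public class ObjectFactoryDemo {
  public static void main(String[] args) {
    StructType schema =
        new StructType(
            new StructField[] {
              new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
              new StructField("name", DataTypes.StringType, false, Metadata.empty())
            });
    // InternalRow stores strings as UTF8String; ObjectFactory converts them back to String.
    InternalRow row = new GenericInternalRow(new Object[] {42, UTF8String.fromString("Alice")});
    // The struct overload returns a Map<String, Object> keyed by field name.
    Object converted = ObjectFactory.object(row, schema);
    System.out.println(converted); // prints {name=Alice, id=42} (HashMap, so order may vary)
  }
}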
21 changes: 10 additions & 11 deletions src/main/java/io/qdrant/spark/Qdrant.java
@@ -11,13 +11,13 @@
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /**
- * A class that implements the TableProvider and DataSourceRegister interfaces.
- * Provides methods to
+ * A class that implements the TableProvider and DataSourceRegister interfaces. Provides methods to
  * infer schema, get table, and check required options.
  */
 public class Qdrant implements TableProvider, DataSourceRegister {
 
-  private final String[] requiredFields = new String[] { "schema", "collection_name", "embedding_field", "qdrant_url" };
+  private final String[] requiredFields =
+      new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"};
 
   /**
    * Returns the short name of the data source.
@@ -42,15 +42,15 @@ public StructType inferSchema(CaseInsensitiveStringMap options) {
     checkRequiredOptions(options, schema);
 
     return schema;
-  };
+  }
+  ;
 
   /**
-   * Returns a table for the data source based on the provided schema,
-   * partitioning, and properties.
+   * Returns a table for the data source based on the provided schema, partitioning, and properties.
    *
-   * @param schema       The schema of the table.
+   * @param schema The schema of the table.
    * @param partitioning The partitioning of the table.
-   * @param properties   The properties of the table.
+   * @param properties The properties of the table.
    * @return The table for the data source.
    */
   @Override
@@ -61,12 +61,11 @@ public Table getTable(
   }
 
   /**
-   * Checks if the required options are present in the provided options and if the
-   * id_field and
+   * Checks if the required options are present in the provided options and if the id_field and
    * embedding_field options are present in the provided schema.
    *
    * @param options The options to check.
-   * @param schema  The schema to check.
+   * @param schema The schema to check.
    */
   void checkRequiredOptions(CaseInsensitiveStringMap options, StructType schema) {
     for (String fieldName : requiredFields) {
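Since checkRequiredOptions validates exactly these four options, a typical write with the connector looks roughly like the sketch below. This is assumed usage, not code from the diff: the provider is addressed here by its fully qualified class name, and the URL and collection values are placeholders.

    // Hedged usage sketch; the "schema" option carries the DataFrame schema as JSON.
    df.write()
        .format("io.qdrant.spark.Qdrant")
        .option("schema", df.schema().json())
        .option("collection_name", "demo")
        .option("embedding_field", "embedding")
        .option("qdrant_url", "http://localhost:6333")
        .mode("append")
        .save();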
22 changes: 8 additions & 14 deletions src/main/java/io/qdrant/spark/QdrantDataWriter.java
@@ -1,5 +1,7 @@
 package io.qdrant.spark;
 
+import static io.qdrant.spark.ObjectFactory.object;
+
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -13,17 +15,11 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import static io.qdrant.spark.ObjectFactory.object;
-
 /**
- * A DataWriter implementation that writes data to Qdrant, a vector search
- * engine. This class takes
- * QdrantOptions and StructType as input and writes data to QdrantRest. It
- * implements the DataWriter
- * interface and overrides its methods write, commit, abort and close. It also
- * has a private method
- * write that is used to upload a batch of points to Qdrant. The class uses a
- * Point class to
+ * A DataWriter implementation that writes data to Qdrant, a vector search engine. This class takes
+ * QdrantOptions and StructType as input and writes data to QdrantRest. It implements the DataWriter
+ * interface and overrides its methods write, commit, abort and close. It also has a private method
+ * write that is used to upload a batch of points to Qdrant. The class uses a Point class to
  * represent a data point and an ArrayList to store the points.
  */
 public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
@@ -111,12 +107,10 @@ public void write(int retries) {
   }
 
   @Override
-  public void abort() {
-  }
+  public void abort() {}
 
   @Override
-  public void close() {
-  }
+  public void close() {}
 }
 
 class Point implements Serializable {
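For context on the no-op abort() and close(): Spark drives one DataWriter per task, roughly as in the schematic below. This is a simplified sketch of the DataSource V2 executor-side contract, not code from this repository; factory and rows are hypothetical stand-ins.

import java.io.IOException;
import java.util.Iterator;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.DataWriterFactory;

class WriteTaskSketch {
  static void runTask(
      DataWriterFactory factory, Iterator<InternalRow> rows, int partitionId, long taskId)
      throws IOException {
    DataWriter<InternalRow> writer = factory.createWriter(partitionId, taskId);
    try {
      while (rows.hasNext()) {
        writer.write(rows.next()); // QdrantDataWriter buffers the converted point
      }
      writer.commit(); // flushes any remaining buffered points to Qdrant
    } catch (IOException e) {
      writer.abort(); // a no-op here: unflushed points live only in memory
      throw e;
    } finally {
      writer.close(); // likewise a no-op in this implementation
    }
  }
}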
14 changes: 5 additions & 9 deletions src/main/java/io/qdrant/spark/QdrantRest.java
@@ -1,13 +1,11 @@
 package io.qdrant.spark;
 
+import com.google.gson.Gson;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.MalformedURLException;
 import java.net.URL;
 import java.util.List;
-
-import com.google.gson.Gson;
-
 import okhttp3.MediaType;
 import okhttp3.OkHttpClient;
 import okhttp3.Request;
@@ -23,7 +21,7 @@ public class QdrantRest implements Serializable {
    * Constructor for QdrantRest class.
    *
    * @param qdrantUrl The URL of the Qdrant instance.
-   * @param apiKey    The API key to authenticate with Qdrant.
+   * @param apiKey The API key to authenticate with Qdrant.
    */
   public QdrantRest(String qdrantUrl, String apiKey) {
     this.qdrantUrl = qdrantUrl;
@@ -34,11 +32,9 @@ public QdrantRest(String qdrantUrl, String apiKey) {
    * Uploads a batch of points to a Qdrant collection.
    *
    * @param collectionName The name of the collection to upload the points to.
-   * @param points         The list of points to upload.
-   * @throws IOException           If there was an error uploading the batch to
-   *                               Qdrant.
-   * @throws RuntimeException      If there was an error uploading the batch to
-   *                               Qdrant.
+   * @param points The list of points to upload.
+   * @throws IOException If there was an error uploading the batch to Qdrant.
+   * @throws RuntimeException If there was an error uploading the batch to Qdrant.
    * @throws MalformedURLException If the Qdrant URL is malformed.
    */
   public void uploadBatch(String collectionName, List<Point> points)
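Because QdrantRest only needs a URL and an API key, it can also be exercised directly; a minimal hedged sketch (the URL, key, and collection name are placeholders, and the class must sit in the io.qdrant.spark package since Point is package-private):

package io.qdrant.spark;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

class QdrantRestDemo {
  static void uploadBatchDemo() throws IOException {
    QdrantRest rest = new QdrantRest("http://localhost:6333", "<api-key>");
    List<Point> points = new ArrayList<>(); // normally assembled by QdrantDataWriter
    rest.uploadBatch("demo-collection", points); // sends the batch to the collection
  }
}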
94 changes: 49 additions & 45 deletions src/test/java/io/qdrant/spark/TestIntegration.java
@@ -30,54 +30,58 @@ public TestIntegration() {
 
   @Test
   public void testSparkSession() {
-    SparkSession spark = SparkSession.builder().master("local[1]").appName("qdrant-spark").getOrCreate();
+    SparkSession spark =
+        SparkSession.builder().master("local[1]").appName("qdrant-spark").getOrCreate();
 
-    List<Row> data = Arrays.asList(
-        RowFactory.create(
-            1,
-            1,
-            new float[] { 1.0f, 2.0f, 3.0f },
-            "John Doe",
-            new String[] { "Hello", "Hi" },
-            RowFactory.create(99, "AnotherNestedStruct"),
-            new int[] { 4, 32, 323, 788 }),
-        RowFactory.create(
-            2,
-            2,
-            new float[] { 4.0f, 5.0f, 6.0f },
-            "Jane Doe",
-            new String[] { "Bonjour", "Salut" },
-            RowFactory.create(99, "AnotherNestedStruct"),
-            new int[] { 1, 2, 3, 4, 5 }));
+    List<Row> data =
+        Arrays.asList(
+            RowFactory.create(
+                1,
+                1,
+                new float[] {1.0f, 2.0f, 3.0f},
+                "John Doe",
+                new String[] {"Hello", "Hi"},
+                RowFactory.create(99, "AnotherNestedStruct"),
+                new int[] {4, 32, 323, 788}),
+            RowFactory.create(
+                2,
+                2,
+                new float[] {4.0f, 5.0f, 6.0f},
+                "Jane Doe",
+                new String[] {"Bonjour", "Salut"},
+                RowFactory.create(99, "AnotherNestedStruct"),
+                new int[] {1, 2, 3, 4, 5}));
 
-    StructType structType = new StructType(
-        new StructField[] {
-            new StructField("nested_id", DataTypes.IntegerType, false, Metadata.empty()),
-            new StructField("nested_value", DataTypes.StringType, false, Metadata.empty())
-        });
+    StructType structType =
+        new StructType(
+            new StructField[] {
+              new StructField("nested_id", DataTypes.IntegerType, false, Metadata.empty()),
+              new StructField("nested_value", DataTypes.StringType, false, Metadata.empty())
+            });
 
-    StructType schema = new StructType(
-        new StructField[] {
-            new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-            new StructField("score", DataTypes.IntegerType, true, Metadata.empty()),
-            new StructField(
-                "embedding",
-                DataTypes.createArrayType(DataTypes.FloatType),
-                true,
-                Metadata.empty()),
-            new StructField("name", DataTypes.StringType, true, Metadata.empty()),
-            new StructField(
-                "greetings",
-                DataTypes.createArrayType(DataTypes.StringType),
-                true,
-                Metadata.empty()),
-            new StructField("struct_data", structType, true, Metadata.empty()),
-            new StructField(
-                "numbers",
-                DataTypes.createArrayType(DataTypes.IntegerType),
-                true,
-                Metadata.empty()),
-        });
+    StructType schema =
+        new StructType(
+            new StructField[] {
+              new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+              new StructField("score", DataTypes.IntegerType, true, Metadata.empty()),
+              new StructField(
+                  "embedding",
+                  DataTypes.createArrayType(DataTypes.FloatType),
+                  true,
+                  Metadata.empty()),
+              new StructField("name", DataTypes.StringType, true, Metadata.empty()),
+              new StructField(
+                  "greetings",
+                  DataTypes.createArrayType(DataTypes.StringType),
+                  true,
+                  Metadata.empty()),
+              new StructField("struct_data", structType, true, Metadata.empty()),
+              new StructField(
+                  "numbers",
+                  DataTypes.createArrayType(DataTypes.IntegerType),
+                  true,
+                  Metadata.empty()),
+            });
     Dataset<Row> df = spark.createDataFrame(data, schema);
 
     df.write()
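As a cross-check against ObjectFactory above, the first test row's payload fields would flatten to something like the fragment below. This is illustrative only and rests on one assumption: the embedding column is consumed as the point's vector rather than stored as payload.

    // Hypothetical expected payload for the first row, per ObjectFactory's conversions.
    Map<String, Object> expected = new HashMap<>();
    expected.put("id", 1);
    expected.put("score", 1);
    expected.put("name", "John Doe");
    expected.put("greetings", new String[] {"Hello", "Hi"}); // string arrays become String[]
    expected.put("struct_data", Map.of("nested_id", 99, "nested_value", "AnotherNestedStruct"));
    expected.put("numbers", new Object[] {4, 32, 323, 788}); // other arrays via toObjectArray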
