From ad68e7df8a232f080de740cbc8ce4075752db487 Mon Sep 17 00:00:00 2001 From: Anush Date: Sat, 17 Feb 2024 14:17:31 +0530 Subject: [PATCH] chore: formatting (#11) * chore: Spotify formatting * fix: version --- pom.xml | 2 +- .../java/io/qdrant/spark/ObjectFactory.java | 117 +++++++++--------- src/main/java/io/qdrant/spark/Qdrant.java | 21 ++-- .../io/qdrant/spark/QdrantDataWriter.java | 22 ++-- src/main/java/io/qdrant/spark/QdrantRest.java | 14 +-- .../java/io/qdrant/spark/TestIntegration.java | 94 +++++++------- 6 files changed, 132 insertions(+), 138 deletions(-) diff --git a/pom.xml b/pom.xml index af503f3..f974b08 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ 4.0.0 io.qdrant spark - 1.13 + 1.12.1 qdrant-spark https://github.com/qdrant/qdrant-spark An Apache Spark connector for the Qdrant vector database diff --git a/src/main/java/io/qdrant/spark/ObjectFactory.java b/src/main/java/io/qdrant/spark/ObjectFactory.java index 71e0749..f4155de 100644 --- a/src/main/java/io/qdrant/spark/ObjectFactory.java +++ b/src/main/java/io/qdrant/spark/ObjectFactory.java @@ -2,77 +2,78 @@ import java.util.HashMap; import java.util.Map; - import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.ArrayType; class ObjectFactory { - public static Object object(InternalRow record, StructField field, int fieldIndex) { - DataType dataType = field.dataType(); + public static Object object(InternalRow record, StructField field, int fieldIndex) { + DataType dataType = field.dataType(); - switch (dataType.typeName()) { - case "integer": - return record.getInt(fieldIndex); - case "float": - return record.getFloat(fieldIndex); - case "double": - return record.getDouble(fieldIndex); - case "long": - return record.getLong(fieldIndex); - case "boolean": - return record.getBoolean(fieldIndex); - case "string": - return record.getString(fieldIndex); - case "array": - ArrayType arrayType = (ArrayType) dataType; - ArrayData arrayData = record.getArray(fieldIndex); - return object(arrayData, arrayType.elementType()); - case "struct": - StructType structType = (StructType) dataType; - InternalRow structData = record.getStruct(fieldIndex, structType.fields().length); - return object(structData, structType); - default: - return null; - } + switch (dataType.typeName()) { + case "integer": + return record.getInt(fieldIndex); + case "float": + return record.getFloat(fieldIndex); + case "double": + return record.getDouble(fieldIndex); + case "long": + return record.getLong(fieldIndex); + case "boolean": + return record.getBoolean(fieldIndex); + case "string": + return record.getString(fieldIndex); + case "array": + ArrayType arrayType = (ArrayType) dataType; + ArrayData arrayData = record.getArray(fieldIndex); + return object(arrayData, arrayType.elementType()); + case "struct": + StructType structType = (StructType) dataType; + InternalRow structData = record.getStruct(fieldIndex, structType.fields().length); + return object(structData, structType); + default: + return null; } + } - public static Object object(ArrayData arrayData, DataType elementType) { + public static Object object(ArrayData arrayData, DataType elementType) { - switch (elementType.typeName()) { - case "string": { - int length = arrayData.numElements(); - String[] result = new String[length]; - for (int i = 0; i < length; i++) { - result[i] = arrayData.getUTF8String(i).toString(); - } - return result; - } + switch (elementType.typeName()) { + case "string": + { + int length = arrayData.numElements(); + String[] result = new String[length]; + for (int i = 0; i < length; i++) { + result[i] = arrayData.getUTF8String(i).toString(); + } + return result; + } - case "struct": { - StructType structType = (StructType) elementType; - int length = arrayData.numElements(); - Object[] result = new Object[length]; - for (int i = 0; i < length; i++) { - InternalRow structData = arrayData.getStruct(i, structType.fields().length); - result[i] = object(structData, structType); - } - return result; - } - default: - return arrayData.toObjectArray(elementType); + case "struct": + { + StructType structType = (StructType) elementType; + int length = arrayData.numElements(); + Object[] result = new Object[length]; + for (int i = 0; i < length; i++) { + InternalRow structData = arrayData.getStruct(i, structType.fields().length); + result[i] = object(structData, structType); + } + return result; } + default: + return arrayData.toObjectArray(elementType); } + } - public static Object object(InternalRow structData, StructType structType) { - Map result = new HashMap<>(); - for (int i = 0; i < structType.fields().length; i++) { - StructField structField = structType.fields()[i]; - result.put(structField.name(), object(structData, structField, i)); - } - return result; + public static Object object(InternalRow structData, StructType structType) { + Map result = new HashMap<>(); + for (int i = 0; i < structType.fields().length; i++) { + StructField structField = structType.fields()[i]; + result.put(structField.name(), object(structData, structField, i)); } -} \ No newline at end of file + return result; + } +} diff --git a/src/main/java/io/qdrant/spark/Qdrant.java b/src/main/java/io/qdrant/spark/Qdrant.java index 5678cad..14974f9 100644 --- a/src/main/java/io/qdrant/spark/Qdrant.java +++ b/src/main/java/io/qdrant/spark/Qdrant.java @@ -11,13 +11,13 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** - * A class that implements the TableProvider and DataSourceRegister interfaces. - * Provides methods to + * A class that implements the TableProvider and DataSourceRegister interfaces. Provides methods to * infer schema, get table, and check required options. */ public class Qdrant implements TableProvider, DataSourceRegister { - private final String[] requiredFields = new String[] { "schema", "collection_name", "embedding_field", "qdrant_url" }; + private final String[] requiredFields = + new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"}; /** * Returns the short name of the data source. @@ -42,15 +42,15 @@ public StructType inferSchema(CaseInsensitiveStringMap options) { checkRequiredOptions(options, schema); return schema; - }; + } + ; /** - * Returns a table for the data source based on the provided schema, - * partitioning, and properties. + * Returns a table for the data source based on the provided schema, partitioning, and properties. * - * @param schema The schema of the table. + * @param schema The schema of the table. * @param partitioning The partitioning of the table. - * @param properties The properties of the table. + * @param properties The properties of the table. * @return The table for the data source. */ @Override @@ -61,12 +61,11 @@ public Table getTable( } /** - * Checks if the required options are present in the provided options and if the - * id_field and + * Checks if the required options are present in the provided options and if the id_field and * embedding_field options are present in the provided schema. * * @param options The options to check. - * @param schema The schema to check. + * @param schema The schema to check. */ void checkRequiredOptions(CaseInsensitiveStringMap options, StructType schema) { for (String fieldName : requiredFields) { diff --git a/src/main/java/io/qdrant/spark/QdrantDataWriter.java b/src/main/java/io/qdrant/spark/QdrantDataWriter.java index b7f23bb..3baff24 100644 --- a/src/main/java/io/qdrant/spark/QdrantDataWriter.java +++ b/src/main/java/io/qdrant/spark/QdrantDataWriter.java @@ -1,5 +1,7 @@ package io.qdrant.spark; +import static io.qdrant.spark.ObjectFactory.object; + import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -13,17 +15,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static io.qdrant.spark.ObjectFactory.object; - /** - * A DataWriter implementation that writes data to Qdrant, a vector search - * engine. This class takes - * QdrantOptions and StructType as input and writes data to QdrantRest. It - * implements the DataWriter - * interface and overrides its methods write, commit, abort and close. It also - * has a private method - * write that is used to upload a batch of points to Qdrant. The class uses a - * Point class to + * A DataWriter implementation that writes data to Qdrant, a vector search engine. This class takes + * QdrantOptions and StructType as input and writes data to QdrantRest. It implements the DataWriter + * interface and overrides its methods write, commit, abort and close. It also has a private method + * write that is used to upload a batch of points to Qdrant. The class uses a Point class to * represent a data point and an ArrayList to store the points. */ public class QdrantDataWriter implements DataWriter, Serializable { @@ -111,12 +107,10 @@ public void write(int retries) { } @Override - public void abort() { - } + public void abort() {} @Override - public void close() { - } + public void close() {} } class Point implements Serializable { diff --git a/src/main/java/io/qdrant/spark/QdrantRest.java b/src/main/java/io/qdrant/spark/QdrantRest.java index f790b0a..65b2511 100644 --- a/src/main/java/io/qdrant/spark/QdrantRest.java +++ b/src/main/java/io/qdrant/spark/QdrantRest.java @@ -1,13 +1,11 @@ package io.qdrant.spark; +import com.google.gson.Gson; import java.io.IOException; import java.io.Serializable; import java.net.MalformedURLException; import java.net.URL; import java.util.List; - -import com.google.gson.Gson; - import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; @@ -23,7 +21,7 @@ public class QdrantRest implements Serializable { * Constructor for QdrantRest class. * * @param qdrantUrl The URL of the Qdrant instance. - * @param apiKey The API key to authenticate with Qdrant. + * @param apiKey The API key to authenticate with Qdrant. */ public QdrantRest(String qdrantUrl, String apiKey) { this.qdrantUrl = qdrantUrl; @@ -34,11 +32,9 @@ public QdrantRest(String qdrantUrl, String apiKey) { * Uploads a batch of points to a Qdrant collection. * * @param collectionName The name of the collection to upload the points to. - * @param points The list of points to upload. - * @throws IOException If there was an error uploading the batch to - * Qdrant. - * @throws RuntimeException If there was an error uploading the batch to - * Qdrant. + * @param points The list of points to upload. + * @throws IOException If there was an error uploading the batch to Qdrant. + * @throws RuntimeException If there was an error uploading the batch to Qdrant. * @throws MalformedURLException If the Qdrant URL is malformed. */ public void uploadBatch(String collectionName, List points) diff --git a/src/test/java/io/qdrant/spark/TestIntegration.java b/src/test/java/io/qdrant/spark/TestIntegration.java index 9e9f6f5..4e5b0f9 100644 --- a/src/test/java/io/qdrant/spark/TestIntegration.java +++ b/src/test/java/io/qdrant/spark/TestIntegration.java @@ -30,54 +30,58 @@ public TestIntegration() { @Test public void testSparkSession() { - SparkSession spark = SparkSession.builder().master("local[1]").appName("qdrant-spark").getOrCreate(); + SparkSession spark = + SparkSession.builder().master("local[1]").appName("qdrant-spark").getOrCreate(); - List data = Arrays.asList( - RowFactory.create( - 1, - 1, - new float[] { 1.0f, 2.0f, 3.0f }, - "John Doe", - new String[] { "Hello", "Hi" }, - RowFactory.create(99, "AnotherNestedStruct"), - new int[] { 4, 32, 323, 788 }), - RowFactory.create( - 2, - 2, - new float[] { 4.0f, 5.0f, 6.0f }, - "Jane Doe", - new String[] { "Bonjour", "Salut" }, - RowFactory.create(99, "AnotherNestedStruct"), - new int[] { 1, 2, 3, 4, 5 })); + List data = + Arrays.asList( + RowFactory.create( + 1, + 1, + new float[] {1.0f, 2.0f, 3.0f}, + "John Doe", + new String[] {"Hello", "Hi"}, + RowFactory.create(99, "AnotherNestedStruct"), + new int[] {4, 32, 323, 788}), + RowFactory.create( + 2, + 2, + new float[] {4.0f, 5.0f, 6.0f}, + "Jane Doe", + new String[] {"Bonjour", "Salut"}, + RowFactory.create(99, "AnotherNestedStruct"), + new int[] {1, 2, 3, 4, 5})); - StructType structType = new StructType( - new StructField[] { - new StructField("nested_id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("nested_value", DataTypes.StringType, false, Metadata.empty()) - }); + StructType structType = + new StructType( + new StructField[] { + new StructField("nested_id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("nested_value", DataTypes.StringType, false, Metadata.empty()) + }); - StructType schema = new StructType( - new StructField[] { - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("score", DataTypes.IntegerType, true, Metadata.empty()), - new StructField( - "embedding", - DataTypes.createArrayType(DataTypes.FloatType), - true, - Metadata.empty()), - new StructField("name", DataTypes.StringType, true, Metadata.empty()), - new StructField( - "greetings", - DataTypes.createArrayType(DataTypes.StringType), - true, - Metadata.empty()), - new StructField("struct_data", structType, true, Metadata.empty()), - new StructField( - "numbers", - DataTypes.createArrayType(DataTypes.IntegerType), - true, - Metadata.empty()), - }); + StructType schema = + new StructType( + new StructField[] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("score", DataTypes.IntegerType, true, Metadata.empty()), + new StructField( + "embedding", + DataTypes.createArrayType(DataTypes.FloatType), + true, + Metadata.empty()), + new StructField("name", DataTypes.StringType, true, Metadata.empty()), + new StructField( + "greetings", + DataTypes.createArrayType(DataTypes.StringType), + true, + Metadata.empty()), + new StructField("struct_data", structType, true, Metadata.empty()), + new StructField( + "numbers", + DataTypes.createArrayType(DataTypes.IntegerType), + true, + Metadata.empty()), + }); Dataset df = spark.createDataFrame(data, schema); df.write()