feat: add acs vector store #2041

Merged
@@ -19,6 +19,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.{col, expr, struct, to_json}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
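
The newly imported `vector_to_array` is what later lets the writer flatten Spark ML `VectorType` columns into plain float arrays before rows are pushed to the index. A minimal standalone sketch of that conversion, assuming a made-up `embedding` column:

```scala
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object VectorToArraySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("vector-to-array-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical embeddings stored as an ML VectorType column.
    val df = Seq(
      ("doc1", Vectors.dense(0.1, 0.2, 0.3)),
      ("doc2", Vectors.dense(0.4, 0.5, 0.6))
    ).toDF("id", "embedding")

    // "float32" yields array<float> (Collection(Edm.Single)); "float64" yields array<double> (Collection(Edm.Double)).
    df.withColumn("embedding", vector_to_array($"embedding", "float32")).printSchema()

    spark.stop()
  }
}
```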
@@ -195,7 +196,7 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with VectorCol
): Option[Seq[IndexField]] = {
schema match {
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, vectorCols, prefix))
// not supporting vector columns in complex scenarios
// TODO: Support vector search in nested fields
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, None, prefix))
case _ => None
}
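
The updated comment marks the remaining gap: vector columns are only recognized for the top-level `StructType`, and the recursion into array-of-struct columns passes `None` for them. A small sketch of a schema where only the top-level `contentVector` column could be registered as a vector field (all names are illustrative):

```scala
import org.apache.spark.sql.types._

object TopLevelVectorSchemaSketch {
  // Only the top-level contentVector column can be declared as a vector field;
  // vectors nested inside the "pages" array-of-struct column are not supported yet.
  val schema: StructType = StructType(Seq(
    StructField("id", StringType),
    StructField("content", StringType),
    StructField("contentVector", ArrayType(FloatType)),
    StructField("pages", ArrayType(StructType(Seq(
      StructField("pageNumber", IntegerType),
      StructField("title", StringType)
    ))))
  ))
}
```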
@@ -277,6 +278,10 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with VectorCol
indexJson = dfToIndexJson(castedDF.get.schema, indexName, keyCol.get, actionCol, vectorCols)
}

// TODO: Support vector fields in nested fields
// Throws an exception if any nested field is a vector in the schema
parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors))

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)

logInfo("checking schema parity")
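
For context on what the new nested-vector assertion is inspecting: in the index JSON, a vector field is the one that carries both `dimensions` and `vectorSearchConfiguration`. A hand-written sketch of such a field definition, kept as a Scala string; the field names, dimension count, and configuration name are made up, and the full set of required index properties depends on the service API version:

```scala
object ExampleIndexJson {
  val indexJson: String =
    """{
      |  "name": "demo-index",
      |  "fields": [
      |    { "name": "id", "type": "Edm.String", "key": true },
      |    { "name": "content", "type": "Edm.String", "searchable": true },
      |    {
      |      "name": "contentVector",
      |      "type": "Collection(Edm.Single)",
      |      "searchable": true,
      |      "dimensions": 1536,
      |      "vectorSearchConfiguration": "vectorConfig"
      |    }
      |  ]
      |}""".stripMargin
}
```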
@@ -304,25 +309,45 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with VectorCol
UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input")))
}

private def assertNoNestedVectors(fields: Seq[IndexField]): Unit = {
def checkVectorField(field: IndexField): Unit = {
if (field.dimensions.nonEmpty && field.vectorSearchConfiguration.nonEmpty) {
throw new IllegalArgumentException(s"Nested field ${field.name} is a vector field, vector fields in nested" +
s" fields are not supported.")
}
field.fields.foreach(_.foreach(checkVectorField))
}
fields.foreach(checkVectorField)
}
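
To make the rejected shape concrete: `assertNoNestedVectors` is only ever fed the children of top-level fields, so a top-level vector field passes untouched while a vector declared inside a complex field throws. The sketch below restates the walk against a simplified stand-in for `IndexField` (the real case class has many more properties); it is illustrative, not the library type:

```scala
import scala.util.Try

object NestedVectorCheckSketch extends App {
  // Simplified stand-in for the real IndexField case class.
  case class MiniField(name: String,
                       dimensions: Option[Int] = None,
                       vectorSearchConfiguration: Option[String] = None,
                       fields: Option[Seq[MiniField]] = None)

  def checkVectorField(field: MiniField): Unit = {
    if (field.dimensions.nonEmpty && field.vectorSearchConfiguration.nonEmpty) {
      throw new IllegalArgumentException(
        s"Nested field ${field.name} is a vector field; vector fields in nested fields are not supported.")
    }
    field.fields.foreach(_.foreach(checkVectorField))
  }

  // A complex field whose child carries dimensions + vectorSearchConfiguration is rejected.
  val complex = MiniField("metadata", fields = Some(Seq(
    MiniField("embedding", dimensions = Some(3), vectorSearchConfiguration = Some("cfg")))))

  println(Try(complex.fields.get.foreach(checkVectorField)))
  // Failure(java.lang.IllegalArgumentException: Nested field embedding is a vector field; ...)
}
```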

private def getVectorColNameTypeTupleFromIndexJson(indexJson: String): Seq[(String, String)] = {
parseIndexJson(indexJson).fields
.filter(f => f.vectorSearchConfiguration.nonEmpty && f.dimensions.nonEmpty)
.map(f => (f.name, f.`type`))
}
private def castDFColsToVectorCompatibleType(vectorColNameTypeTuple: Seq[(String, String)],
df: DataFrame): DataFrame = {

vectorColNameTypeTuple.foldLeft(df) { case (accDF, (colName, colType)) =>
assert(accDF.columns.contains(colName), s"Column $colName not found in dataframe columns ${accDF.columns.toList}")
val colDataType = accDF.schema(colName).dataType
assert(colDataType match {
case ArrayType(elementType, _) => elementType == FloatType || elementType == DoubleType
case VectorType => true
case _ => false
}, s"Vector column $colName needs to be one of (ArrayType(FloatType), ArrayType(DoubleType), VectorType)")
accDF.withColumn(colName, accDF(colName).cast(edmTypeToSparkType(colType, None)))
if (!accDF.columns.contains(colName)) {
println(s"Column $colName is specified in either indexJsonOpt or vectorColsOpt but not found in dataframe " +
s"columns ${accDF.columns.toList}")
accDF
}
else {
val colDataType = accDF.schema(colName).dataType
assert(colDataType match {
case ArrayType(elementType, _) => elementType == FloatType || elementType == DoubleType
case VectorType => true
case _ => false
}, s"Vector column $colName needs to be one of (ArrayType(FloatType), ArrayType(DoubleType), VectorType)")
if (colDataType.isInstanceOf[ArrayType]) {
accDF.withColumn(colName, accDF(colName).cast(edmTypeToSparkType(colType, None)))
} else {
accDF.withColumn(colName, vector_to_array(accDF(colName), edmTypeToVectordType(colType)))
}
}
}
}
}
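
Seen from the outside, the reworked cast does three things: columns named in the index JSON or `vectorColsOpt` but missing from the DataFrame are now warned about and skipped instead of failing an assertion, array columns are cast to the Spark type implied by the EDM collection type, and ML `VectorType` columns are converted with `vector_to_array` at the matching float width. A rough standalone equivalent of the two casts, with made-up column names and `Collection(Edm.Single)` as the target type:

```scala
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, FloatType}

object VectorCastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("vector-cast-sketch").getOrCreate()
    import spark.implicits._

    // One embedding stored as array<double>, one as an ML vector.
    val df = Seq(
      ("doc1", Array(0.1, 0.2), Vectors.dense(0.1, 0.2))
    ).toDF("id", "arrayEmbedding", "mlVectorEmbedding")

    val prepared = df
      .withColumn("arrayEmbedding", $"arrayEmbedding".cast(ArrayType(FloatType)))         // array column: plain cast
      .withColumn("mlVectorEmbedding", vector_to_array($"mlVectorEmbedding", "float32"))  // ML vector: convert

    prepared.printSchema() // both embedding columns end up as array<float>

    spark.stop()
  }
}
```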

private def isEdmCollection(t: String): Boolean = {
t.startsWith("Collection(") && t.endsWith(")")
@@ -332,6 +357,13 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with VectorCol
t.substring("Collection(".length).dropRight(1)
}

private def edmTypeToVectordType(t: String): String = {
t match {
case "Collection(Edm.Single)" => "float32"
case "Collection(Edm.Double)" => "float64"
}
}

private[ml] def edmTypeToSparkType(dt: String, //scalastyle:ignore cyclomatic.complexity
fields: Option[Seq[IndexField]]): DataType = dt match {
case t if isEdmCollection(t) =>
@@ -122,6 +122,8 @@ object SearchIndex extends IndexParser with IndexLister
_ <- validAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSearchAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validIndexAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validVectorField(field.dimensions, field.vectorSearchConfiguration)
// TODO: Fix and add back validSynonymMaps check. SynonymMaps needs to be Option[Seq[String]] type
//_ <- validSynonymMaps(field.synonymMap)
} yield field
}
@@ -210,6 +212,15 @@ object SearchIndex extends IndexParser with IndexLister
}
}

private def validVectorField(d: Option[Int], v: Option[String]): Try[Option[String]] = {
if ((d.isDefined && v.isEmpty) || (v.isDefined && d.isEmpty)) {
Failure(new IllegalArgumentException("Both dimensions and vectorSearchConfig fields need to be defined for " +
"vector search"))
} else {
Success(v)
}
}
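
The new `validVectorField` check only insists that `dimensions` and `vectorSearchConfiguration` arrive as a pair: both absent (an ordinary field) and both present (a vector field) are accepted, while a half-specified field is rejected. A small standalone restatement of the rule showing the three cases (not the library API, just the same predicate in isolation):

```scala
import scala.util.{Failure, Success, Try}

object ValidVectorFieldSketch extends App {
  // Same predicate as in the diff, lifted out for illustration.
  def validVectorField(d: Option[Int], v: Option[String]): Try[Option[String]] = {
    if ((d.isDefined && v.isEmpty) || (v.isDefined && d.isEmpty)) {
      Failure(new IllegalArgumentException(
        "Both dimensions and vectorSearchConfig fields need to be defined for vector search"))
    } else {
      Success(v)
    }
  }

  println(validVectorField(None, None))                    // Success(None): ordinary field
  println(validVectorField(Some(1536), Some("vectorCfg"))) // Success(Some(vectorCfg)): vector field
  println(validVectorField(Some(1536), None))              // Failure(...): half-specified, rejected
}
```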

def getStatistics(indexName: String,
key: String,
serviceName: String,