Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add acs vector store #2041

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import org.apache.spark.internal.{Logging => SLogging}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.{col, expr, struct, to_json}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -142,7 +144,7 @@ class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)
override def responseDataType: DataType = ASResponses.schema
}

object AzureSearchWriter extends IndexParser with SLogging {
object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging {

val Logger: Logger = LogManager.getRootLogger

Expand All @@ -166,9 +168,11 @@ object AzureSearchWriter extends IndexParser with SLogging {
private def convertFields(fields: Seq[StructField],
keyCol: String,
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]],
prefix: Option[String]): Seq[IndexField] = {
fields.filterNot(_.name == searchActionCol).map { sf =>
val fullName = prefix.map(_ + sf.name).getOrElse(sf.name)
val isVector = vectorCols.exists(_.exists(_.name == fullName))
val (innerType, _) = sparkTypeToEdmType(sf.dataType)
IndexField(
sf.name,
Expand All @@ -177,31 +181,44 @@ object AzureSearchWriter extends IndexParser with SLogging {
if (keyCol == fullName) Some(true) else None,
None, None, None, None,
structFieldToSearchFields(sf.dataType,
keyCol, searchActionCol, prefix = Some(prefix.getOrElse("") + sf.name + "."))
keyCol, searchActionCol, None, prefix = Some(prefix.getOrElse("") + sf.name + ".")),
if (isVector) vectorCols.get.find(_.name == fullName).map(_.dimension) else None,
if (isVector) Some(AzureSearchAPIConstants.VectorConfigName) else None
)
}
}

private def structFieldToSearchFields(schema: DataType,
keyCol: String,
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]],
prefix: Option[String] = None
): Option[Seq[IndexField]] = {
schema match {
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, vectorCols, prefix))
// TODO: Support vector search in nested fields
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, None, prefix))
case _ => None
}
}

private def parseVectorColsJson(str: String): Seq[VectorColParams] = {
str.parseJson.convertTo[Seq[VectorColParams]]
}

private def dfToIndexJson(schema: StructType,
indexName: String,
keyCol: String,
searchActionCol: String): String = {
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]]): String = {

val vectorConfig = Some(VectorSearch(Seq(AlgorithmConfigs(AzureSearchAPIConstants.VectorConfigName,
AzureSearchAPIConstants.VectorSearchAlgorithm))))
val is = IndexInfo(
Some(indexName),
structFieldToSearchFields(schema, keyCol, searchActionCol).get,
None, None, None, None, None, None, None, None
structFieldToSearchFields(schema, keyCol, searchActionCol, vectorCols).get,
None, None, None, None, None, None, None, None,
if (vectorCols.isEmpty) None else vectorConfig
)
is.toJson.compactPrint
}
Expand All @@ -210,7 +227,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
options: Map[String, String] = Map()): DataFrame = {
val applicableOptions = Set(
"subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol"
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols"
)

options.keys.foreach(k =>
Expand All @@ -224,11 +241,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
val batchSize = options.getOrElse("batchSize", "100").toInt
val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean
val filterNulls = options.getOrElse("filterNulls", "false").toBoolean
val vectorColsInfo = options.get("vectorCols")

val keyCol = options.get("keyCol")
val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)
if (indexJsonOpt.isDefined) {
List("keyCol", "indexName").foreach(opt =>
List("keyCol", "indexName", "vectorCols").foreach(opt =>
assert(!options.contains(opt), s"Cannot set both indexJson options and $opt")
)
}
Expand All @@ -242,22 +260,41 @@ object AzureSearchWriter extends IndexParser with SLogging {
}
}

val indexJson = indexJsonOpt.getOrElse {
dfToIndexJson(df.schema, indexName, keyCol.get, actionCol)
val (indexJson, preppedDF) = if (getExisting(subscriptionKey, serviceName, apiVersion).contains(indexName)) {
if (indexJsonOpt.isDefined) {
println(f"indexJsonOpt is specified, however an index for $indexName already exists," +
f"we will use the index definition obtained from the existing index instead")
}
val existingIndexJson = getIndexJsonFromExistingIndex(subscriptionKey, serviceName, indexName)
val vectorColNameTypeTuple = getVectorColConf(existingIndexJson)
(existingIndexJson, makeColsCompatible(vectorColNameTypeTuple, df))
} else if (indexJsonOpt.isDefined) {
val vectorColNameTypeTuple = getVectorColConf(indexJsonOpt.get)
(indexJsonOpt.get, makeColsCompatible(vectorColNameTypeTuple, df))
} else {
val vectorCols = vectorColsInfo.map(parseVectorColsJson)
val vectorColNameTypeTuple = vectorCols.map(_.map(vc => (vc.name, "Collection(Edm.Single)"))).getOrElse(Seq.empty)
val newDF = makeColsCompatible(vectorColNameTypeTuple, df)
val inferredIndexJson = dfToIndexJson(newDF.schema, indexName, keyCol.getOrElse(""), actionCol, vectorCols)
(inferredIndexJson, newDF)
}

// TODO: Support vector search in nested fields
// Throws an exception if any nested field is a vector in the schema
parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors))

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)

logInfo("checking schema parity")
checkSchemaParity(df.schema, indexJson, actionCol)
checkSchemaParity(preppedDF.schema, indexJson, actionCol)

val df1 = if (filterNulls) {
val collectionColumns = parseIndexJson(indexJson).fields
.filter(_.`type`.startsWith("Collection"))
.map(_.name)
collectionColumns.foldLeft(df) { (ndf, c) => filterOutNulls(ndf, c) }
collectionColumns.foldLeft(preppedDF) { (ndf, c) => filterOutNulls(ndf, c) }
} else {
df
preppedDF
}

new AddDocuments()
Expand All @@ -273,6 +310,48 @@ object AzureSearchWriter extends IndexParser with SLogging {
UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input")))
}

private def assertNoNestedVectors(fields: Seq[IndexField]): Unit = {
def checkVectorField(field: IndexField): Unit = {
if (field.dimensions.nonEmpty && field.vectorSearchConfiguration.nonEmpty) {
throw new IllegalArgumentException(s"Nested field ${field.name} is a vector field, vector fields in nested" +
s" fields are not supported.")
}
field.fields.foreach(_.foreach(checkVectorField))
}
fields.foreach(checkVectorField)
}

private def getVectorColConf(indexJson: String): Seq[(String, String)] = {
parseIndexJson(indexJson).fields
.filter(f => f.vectorSearchConfiguration.nonEmpty && f.dimensions.nonEmpty)
.map(f => (f.name, f.`type`))
}
private def makeColsCompatible(vectorColNameTypeTuple: Seq[(String, String)],
df: DataFrame): DataFrame = {
vectorColNameTypeTuple.foldLeft(df) { case (accDF, (colName, colType)) =>
if (!accDF.columns.contains(colName)) {
println(s"Column $colName is specified in either indexJson or vectorCols but not found in dataframe " +
s"columns ${accDF.columns.toList}")
accDF
}
else {
val colDataType = accDF.schema(colName).dataType
assert(colDataType match {
case ArrayType(elementType, _) => elementType == FloatType || elementType == DoubleType
case VectorType => true
case _ => false
}, s"Vector column $colName needs to be one of (ArrayType(FloatType), ArrayType(DoubleType), VectorType)")
if (colDataType.isInstanceOf[ArrayType]) {
accDF.withColumn(colName, accDF(colName).cast(edmTypeToSparkType(colType, None)))
} else {
// first cast vectorUDT to array<double>, then cast it to correct array type
val modifiedDF = accDF.withColumn(colName, vector_to_array(accDF(colName)))
modifiedDF.withColumn(colName, modifiedDF(colName).cast(edmTypeToSparkType(colType, None)))
}
}
}
}

private def isEdmCollection(t: String): Boolean = {
t.startsWith("Collection(") && t.endsWith(")")
}
Expand All @@ -290,6 +369,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
case "Edm.Int64" => LongType
case "Edm.Int32" => IntegerType
case "Edm.Double" => DoubleType
case "Edm.Single" => FloatType
case "Edm.DateTimeOffset" => StringType //See if there's a way to use spark datetimes
case "Edm.GeographyPoint" => StringType
case "Edm.ComplexType" => StructType(fields.get.map(f =>
Expand All @@ -310,10 +390,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
case IntegerType => ("Edm.Int32", None)
case LongType => ("Edm.Int64", None)
case DoubleType => ("Edm.Double", None)
case FloatType => ("Edm.Single", None)
case DateType => ("Edm.DateTimeOffset", None)
case StructType(fields) => ("Edm.ComplexType", Some(fields.map { f =>
val (innerType, innerFields) = sparkTypeToEdmType(f.dataType)
IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields)
IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields,
None, None) // TODO: Support vector search in nested fields
}))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import spray.json._
import scala.util.{Failure, Success, Try}

object AzureSearchAPIConstants {
val DefaultAPIVersion = "2019-05-06"
val DefaultAPIVersion = "2023-07-01-Preview"
val VectorConfigName = "vectorConfig"
val VectorSearchAlgorithm = "hnsw"
}
import com.microsoft.azure.synapse.ml.cognitive.search.AzureSearchAPIConstants._

Expand All @@ -39,6 +41,26 @@ trait IndexLister {
}
}

trait IndexJsonGetter extends IndexLister {
def getIndexJsonFromExistingIndex(key: String,
serviceName: String,
indexName: String,
apiVersion: String = DefaultAPIVersion): String = {
val existingIndexNames = getExisting(key, serviceName, apiVersion)
assert(existingIndexNames.contains(indexName), s"Cannot find an existing index name with $indexName")

val indexJsonRequest = new HttpGet(
s"https://$serviceName.search.windows.net/indexes/$indexName?api-version=$apiVersion"
)
indexJsonRequest.setHeader("api-key", key)
indexJsonRequest.setHeader("Content-Type", "application/json")
val indexJsonResponse = safeSend(indexJsonRequest, close = false)
val indexJson = IOUtils.toString(indexJsonResponse.getEntity.getContent, "utf-8")
indexJsonResponse.close()
indexJson
}
}

object SearchIndex extends IndexParser with IndexLister {

import AzureSearchProtocol._
Expand Down Expand Up @@ -94,7 +116,9 @@ object SearchIndex extends IndexParser with IndexLister {
_ <- validAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSearchAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validIndexAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSynonymMaps(field.synonymMap)
_ <- validVectorField(field.dimensions, field.vectorSearchConfiguration)
// TODO: Fix and add back validSynonymMaps check. SynonymMaps needs to be Option[Seq[String]] type
//_ <- validSynonymMaps(field.synonymMap)
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved
} yield field
}

Expand Down Expand Up @@ -182,6 +206,15 @@ object SearchIndex extends IndexParser with IndexLister {
}
}

private def validVectorField(d: Option[Int], v: Option[String]): Try[Option[String]] = {
if ((d.isDefined && v.isEmpty) || (v.isDefined && d.isEmpty)) {
Failure(new IllegalArgumentException("Both dimensions and vectorSearchConfig fields need to be defined for " +
"vector search"))
} else {
Success(v)
}
}

def getStatistics(indexName: String,
key: String,
serviceName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.search

import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
import spray.json.DefaultJsonProtocol._
import spray.json.{JsonFormat, RootJsonFormat}
import spray.json.{DefaultJsonProtocol, JsonFormat, RootJsonFormat}

object ASResponses extends SparkBindings[ASResponses]

Expand All @@ -23,9 +23,19 @@ case class IndexInfo(
tokenizers: Option[Seq[String]],
tokenFilters: Option[Seq[String]],
defaultScoringProfile: Option[Seq[String]],
corsOptions: Option[Seq[String]]
corsOptions: Option[Seq[String]],
vectorSearch: Option[VectorSearch]
)

case class AlgorithmConfigs(
name: String,
kind: String
)

case class VectorSearch(
algorithmConfigurations: Seq[AlgorithmConfigs]
)

case class IndexField(
name: String,
`type`: String,
Expand All @@ -38,21 +48,32 @@ case class IndexField(
analyzer: Option[String],
searchAnalyzer: Option[String],
indexAnalyzer: Option[String],
synonymMap: Option[String],
fields: Option[Seq[IndexField]]
synonymMap: Option[Seq[String]],
fields: Option[Seq[IndexField]],
dimensions: Option[Int],
vectorSearchConfiguration: Option[String]
)

case class VectorColParams(
name: String,
dimension: Int
)

case class IndexStats(documentCount: Int, storageSize: Int)

case class IndexList(`@odata.context`: String, value: Seq[IndexName])
case class IndexName(name: String)

object AzureSearchProtocol {
object AzureSearchProtocol extends DefaultJsonProtocol {
implicit val IfEnc: JsonFormat[IndexField] = lazyFormat(jsonFormat(
IndexField,"name","type","searchable","filterable","sortable",
"facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields"))
implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat10(IndexInfo.apply)
"facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields",
"dimensions", "vectorSearchConfiguration"))
implicit val AcEnc: RootJsonFormat[AlgorithmConfigs] = jsonFormat2(AlgorithmConfigs.apply)
implicit val VsEnc: RootJsonFormat[VectorSearch] = jsonFormat1(VectorSearch.apply)
implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat11(IndexInfo.apply)
implicit val IsEnc: RootJsonFormat[IndexStats] = jsonFormat2(IndexStats.apply)
implicit val InEnc: RootJsonFormat[IndexName] = jsonFormat1(IndexName.apply)
implicit val IlEnc: RootJsonFormat[IndexList] = jsonFormat2(IndexList.apply)
implicit val VcpEnc: RootJsonFormat[VectorColParams] = jsonFormat2(VectorColParams.apply)
}
Loading
Loading