Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add acs vector store #2041

Merged
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import org.apache.spark.internal.{Logging => SLogging}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.{col, expr, struct, to_json}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -142,7 +144,7 @@ class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)
override def responseDataType: DataType = ASResponses.schema
}

object AzureSearchWriter extends IndexParser with SLogging {
object AzureSearchWriter extends IndexParser with IndexJsonGetter with VectorColsParser with SLogging {
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved

val Logger: Logger = LogManager.getRootLogger

Expand All @@ -166,9 +168,11 @@ object AzureSearchWriter extends IndexParser with SLogging {
private def convertFields(fields: Seq[StructField],
keyCol: String,
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]],
prefix: Option[String]): Seq[IndexField] = {
fields.filterNot(_.name == searchActionCol).map { sf =>
val fullName = prefix.map(_ + sf.name).getOrElse(sf.name)
val isVector = vectorCols.exists(_.exists(_.name == fullName))
val (innerType, _) = sparkTypeToEdmType(sf.dataType)
IndexField(
sf.name,
Expand All @@ -177,31 +181,39 @@ object AzureSearchWriter extends IndexParser with SLogging {
if (keyCol == fullName) Some(true) else None,
None, None, None, None,
structFieldToSearchFields(sf.dataType,
keyCol, searchActionCol, prefix = Some(prefix.getOrElse("") + sf.name + "."))
keyCol, searchActionCol, None, prefix = Some(prefix.getOrElse("") + sf.name + ".")),
if (isVector) vectorCols.get.find(_.name == fullName).map(_.dimension) else None,
if (isVector) Some(AzureSearchAPIConstants.VectorConfigName) else None
)
}
}

private def structFieldToSearchFields(schema: DataType,
keyCol: String,
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]],
prefix: Option[String] = None
): Option[Seq[IndexField]] = {
schema match {
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, vectorCols, prefix))
// TODO: Support vector search in nested fields
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, None, prefix))
case _ => None
}
}

private def dfToIndexJson(schema: StructType,
indexName: String,
keyCol: String,
searchActionCol: String): String = {
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]]): String = {

val is = IndexInfo(
Some(indexName),
structFieldToSearchFields(schema, keyCol, searchActionCol).get,
None, None, None, None, None, None, None, None
structFieldToSearchFields(schema, keyCol, searchActionCol, vectorCols).get,
None, None, None, None, None, None, None, None,
if (vectorCols.isEmpty) None else Some(VectorSearch(Seq(AlgorithmConfigs(AzureSearchAPIConstants.VectorConfigName,
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved
AzureSearchAPIConstants.VectorSearchAlgorithm))))
)
is.toJson.compactPrint
}
Expand All @@ -210,7 +222,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
options: Map[String, String] = Map()): DataFrame = {
val applicableOptions = Set(
"subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol"
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols"
)

options.keys.foreach(k =>
Expand All @@ -224,11 +236,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
val batchSize = options.getOrElse("batchSize", "100").toInt
val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean
val filterNulls = options.getOrElse("filterNulls", "false").toBoolean
val vectorColsOpt = options.get("vectorCols")
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved

val keyCol = options.get("keyCol")
val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)
if (indexJsonOpt.isDefined) {
List("keyCol", "indexName").foreach(opt =>
List("keyCol", "indexName", "vectorCols").foreach(opt =>
assert(!options.contains(opt), s"Cannot set both indexJson options and $opt")
)
}
Expand All @@ -242,22 +255,45 @@ object AzureSearchWriter extends IndexParser with SLogging {
}
}

val indexJson = indexJsonOpt.getOrElse {
dfToIndexJson(df.schema, indexName, keyCol.get, actionCol)
var indexJson = ""
var castedDF: Option[DataFrame] = None
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved
if (getExisting(subscriptionKey, serviceName, apiVersion).contains(indexName)) {
indexJson = getIndexJsonFromExistingIndex(subscriptionKey, serviceName, indexName)
if (indexJsonOpt.isDefined) {
println(f"indexJsonOpt is specified, however an index for the $indexName already exists," +
f"we will use the index definition obtained from the existing index instead")
}
val vectorColNameTypeTuple = getVectorColNameTypeTupleFromIndexJson(indexJson)
castedDF = Some(castDFColsToVectorCompatibleType(vectorColNameTypeTuple, df))
}
else if (indexJsonOpt.isDefined) {
indexJson = indexJsonOpt.get
val vectorColNameTypeTuple = getVectorColNameTypeTupleFromIndexJson(indexJson)
castedDF = Some(castDFColsToVectorCompatibleType(vectorColNameTypeTuple, df))
}
else {
val vectorCols = vectorColsOpt.map(parseVectorColsJson)
val vectorColNameTypeTuple = vectorCols.map(_.map(vc => (vc.name, "Collection(Edm.Single)"))).getOrElse(Seq.empty)
castedDF = Some(castDFColsToVectorCompatibleType(vectorColNameTypeTuple, df))
indexJson = dfToIndexJson(castedDF.get.schema, indexName, keyCol.get, actionCol, vectorCols)
}

// TODO: Support vector fields in nested fields
// Throws an exception if any nested field is a vector in the schema
parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors))

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)

logInfo("checking schema parity")
checkSchemaParity(df.schema, indexJson, actionCol)
checkSchemaParity(castedDF.get.schema, indexJson, actionCol)

val df1 = if (filterNulls) {
val collectionColumns = parseIndexJson(indexJson).fields
.filter(_.`type`.startsWith("Collection"))
.map(_.name)
collectionColumns.foldLeft(df) { (ndf, c) => filterOutNulls(ndf, c) }
collectionColumns.foldLeft(castedDF.get) { (ndf, c) => filterOutNulls(ndf, c) }
} else {
df
castedDF.get
}

new AddDocuments()
Expand All @@ -273,6 +309,46 @@ object AzureSearchWriter extends IndexParser with SLogging {
UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input")))
}

/** Recursively walks a top-level field's nested fields and fails if any of them is a
  * vector field (i.e. has both `dimensions` and `vectorSearchConfiguration` set).
  * Vector search is only supported on top-level index fields.
  *
  * @throws IllegalArgumentException if a nested vector field is found
  */
private def assertNoNestedVectors(fields: Seq[IndexField]): Unit = {
  def visit(f: IndexField): Unit = {
    val isVector = f.dimensions.nonEmpty && f.vectorSearchConfiguration.nonEmpty
    if (isVector) {
      throw new IllegalArgumentException(s"Nested field ${f.name} is a vector field, vector fields in nested" +
        s" fields are not supported.")
    }
    // Descend into complex (struct-like) children, if any
    for {
      children <- f.fields
      child <- children
    } visit(child)
  }
  fields.foreach(visit)
}

/** Extracts the (column name, EDM type) pairs of every vector field declared in an
  * index definition JSON. A field counts as a vector field when both its
  * `vectorSearchConfiguration` and `dimensions` attributes are present.
  */
private def getVectorColNameTypeTupleFromIndexJson(indexJson: String): Seq[(String, String)] = {
  val allFields = parseIndexJson(indexJson).fields
  for {
    field <- allFields
    if field.vectorSearchConfiguration.nonEmpty && field.dimensions.nonEmpty
  } yield (field.name, field.`type`)
}
/** Casts each declared vector column of `df` to a representation Azure Search can ingest:
  * array columns (of FloatType/DoubleType elements) are cast to the element type implied by
  * the EDM type, and Spark ML VectorType columns are converted with `vector_to_array`.
  *
  * Columns listed in the index definition but absent from the DataFrame are skipped with a
  * warning rather than failing, since the index may legitimately contain fields the caller
  * does not populate.
  *
  * @param vectorColNameTypeTuple (column name, EDM collection type) pairs of vector fields
  * @param df                     input DataFrame
  * @return the DataFrame with vector columns converted in place
  */
private def castDFColsToVectorCompatibleType(vectorColNameTypeTuple: Seq[(String, String)],
                                             df: DataFrame): DataFrame = {
  vectorColNameTypeTuple.foldLeft(df) { case (accDF, (colName, colType)) =>
    if (!accDF.columns.contains(colName)) {
      // Use the object's SLogging mixin instead of println so the warning reaches Spark's logs
      logWarning(s"Column $colName is specified in either indexJsonOpt or vectorColsOpt but not found in dataframe " +
        s"columns ${accDF.columns.toList}")
      accDF
    } else {
      val colDataType = accDF.schema(colName).dataType
      // Only float/double arrays and ML vectors are convertible to Azure Search vector fields
      assert(colDataType match {
        case ArrayType(elementType, _) => elementType == FloatType || elementType == DoubleType
        case VectorType => true
        case _ => false
      }, s"Vector column $colName needs to be one of (ArrayType(FloatType), ArrayType(DoubleType), VectorType)")
      colDataType match {
        case _: ArrayType =>
          // Array column: only the element type may need to change (e.g. double -> float)
          accDF.withColumn(colName, accDF(colName).cast(edmTypeToSparkType(colType, None)))
        case _ =>
          // ML VectorType column: unpack into a primitive array of the requested precision
          accDF.withColumn(colName, vector_to_array(accDF(colName), edmTypeToVectordType(colType)))
      }
    }
  }
}

/** True when the EDM type string denotes a collection, i.e. has the shape `Collection(...)`. */
private def isEdmCollection(t: String): Boolean = {
  val hasCollectionPrefix = t.startsWith("Collection(")
  val hasClosingParen = t.endsWith(")")
  hasCollectionPrefix && hasClosingParen
}
Expand All @@ -281,6 +357,13 @@ object AzureSearchWriter extends IndexParser with SLogging {
t.substring("Collection(".length).dropRight(1)
}

/** Maps the EDM collection type of a vector field to the element-type name expected by
  * Spark's `vector_to_array` ("float32" or "float64").
  *
  * The original match was non-exhaustive and would throw an opaque `MatchError` for any
  * other EDM type; fail with an informative `IllegalArgumentException` instead.
  *
  * @throws IllegalArgumentException for EDM types that are not single/double collections
  */
private def edmTypeToVectordType(t: String): String = {
  t match {
    case "Collection(Edm.Single)" => "float32"
    case "Collection(Edm.Double)" => "float64"
    case other => throw new IllegalArgumentException(
      s"Unsupported EDM type for a vector column: $other. " +
        "Expected Collection(Edm.Single) or Collection(Edm.Double).")
  }
}

private[ml] def edmTypeToSparkType(dt: String, //scalastyle:ignore cyclomatic.complexity
fields: Option[Seq[IndexField]]): DataType = dt match {
case t if isEdmCollection(t) =>
Expand All @@ -290,6 +373,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
case "Edm.Int64" => LongType
case "Edm.Int32" => IntegerType
case "Edm.Double" => DoubleType
case "Edm.Single" => FloatType
case "Edm.DateTimeOffset" => StringType //See if there's a way to use spark datetimes
case "Edm.GeographyPoint" => StringType
case "Edm.ComplexType" => StructType(fields.get.map(f =>
Expand All @@ -310,10 +394,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
case IntegerType => ("Edm.Int32", None)
case LongType => ("Edm.Int64", None)
case DoubleType => ("Edm.Double", None)
case FloatType => ("Edm.Single", None)
case DateType => ("Edm.DateTimeOffset", None)
case StructType(fields) => ("Edm.ComplexType", Some(fields.map { f =>
val (innerType, innerFields) = sparkTypeToEdmType(f.dataType)
IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields)
IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields,
None, None) // right now not supporting vectors in complex types
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved
}))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import spray.json._
import scala.util.{Failure, Success, Try}

/** Constants shared by the Azure Cognitive Search client code.
  * (The diff rendering duplicated the pre-change DefaultAPIVersion line; only the
  * post-change value is kept.)
  */
object AzureSearchAPIConstants {
  // Preview API version that introduces vector search support
  val DefaultAPIVersion = "2023-07-01-Preview"
  // Name of the single vector-search algorithm configuration created per index
  val VectorConfigName = "vectorConfig"
  // Approximate-nearest-neighbor algorithm used for vector search
  val VectorSearchAlgorithm = "hnsw"
}
import com.microsoft.azure.synapse.ml.cognitive.search.AzureSearchAPIConstants._

Expand All @@ -24,6 +26,12 @@ trait IndexParser {
}
}

/** Parses the writer's `vectorCols` option: a JSON array of `{name, dimension}` objects. */
trait VectorColsParser {
  /** Deserializes the JSON string into the vector-column parameter list. */
  def parseVectorColsJson(str: String): Seq[VectorColParams] = {
    val json = str.parseJson
    json.convertTo[Seq[VectorColParams]]
  }
}

trait IndexLister {
def getExisting(key: String,
serviceName: String,
Expand All @@ -39,6 +47,26 @@ trait IndexLister {
}
}

/** Retrieves the JSON definition of an existing Azure Search index via the REST API. */
trait IndexJsonGetter extends IndexLister {
  /** Fetches the definition JSON for `indexName` from the service.
    * Fails fast (via assert) when the index does not exist on the service.
    */
  def getIndexJsonFromExistingIndex(key: String,
                                    serviceName: String,
                                    indexName: String,
                                    apiVersion: String = DefaultAPIVersion): String = {
    val knownIndexes = getExisting(key, serviceName, apiVersion)
    assert(knownIndexes.contains(indexName), s"Cannot find an existing index name with $indexName")

    val request = new HttpGet(
      s"https://$serviceName.search.windows.net/indexes/$indexName?api-version=$apiVersion"
    )
    request.setHeader("api-key", key)
    request.setHeader("Content-Type", "application/json")
    // close = false: the entity stream must be consumed before the response is closed
    val response = safeSend(request, close = false)
    val body = IOUtils.toString(response.getEntity.getContent, "utf-8")
    response.close()
    body
  }
}

object SearchIndex extends IndexParser with IndexLister {

import AzureSearchProtocol._
Expand Down Expand Up @@ -94,7 +122,9 @@ object SearchIndex extends IndexParser with IndexLister {
_ <- validAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSearchAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validIndexAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSynonymMaps(field.synonymMap)
_ <- validVectorField(field.dimensions, field.vectorSearchConfiguration)
// TODO: Fix and add back validSynonymMaps check. SynonymMaps needs to be Option[Seq[String]] type
//_ <- validSynonymMaps(field.synonymMap)
aydan-at-microsoft marked this conversation as resolved.
Show resolved Hide resolved
} yield field
}

Expand Down Expand Up @@ -182,6 +212,15 @@ object SearchIndex extends IndexParser with IndexLister {
}
}

/** Validates that `dimensions` and `vectorSearchConfiguration` are either both present
  * or both absent — a field with exactly one of them is an invalid vector field.
  *
  * @param d the field's dimensions, if any
  * @param v the field's vectorSearchConfiguration name, if any
  * @return Success(v) when consistent, Failure otherwise
  */
private def validVectorField(d: Option[Int], v: Option[String]): Try[Option[String]] = {
  // Exactly-one-defined is the invalid case (logical XOR of definedness)
  if (d.isDefined != v.isDefined) {
    Failure(new IllegalArgumentException("Both dimensions and vectorSearchConfig fields need to be defined for " +
      "vector search"))
  } else {
    Success(v)
  }
}

def getStatistics(indexName: String,
key: String,
serviceName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.search

import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
import spray.json.DefaultJsonProtocol._
import spray.json.{JsonFormat, RootJsonFormat}
import spray.json.{DefaultJsonProtocol, JsonFormat, RootJsonFormat}

object ASResponses extends SparkBindings[ASResponses]

Expand All @@ -23,9 +23,19 @@ case class IndexInfo(
tokenizers: Option[Seq[String]],
tokenFilters: Option[Seq[String]],
defaultScoringProfile: Option[Seq[String]],
corsOptions: Option[Seq[String]]
corsOptions: Option[Seq[String]],
vectorSearch: Option[VectorSearch]
)

// One entry of the index definition's vector-search "algorithmConfigurations" list:
// a configuration name paired with the algorithm kind (e.g. "hnsw").
case class AlgorithmConfigs(
name: String,
kind: String
)

// Top-level "vectorSearch" section of an Azure Search index definition, listing the
// available algorithm configurations that vector fields can reference.
case class VectorSearch(
algorithmConfigurations: Seq[AlgorithmConfigs]
)

case class IndexField(
name: String,
`type`: String,
Expand All @@ -38,21 +48,32 @@ case class IndexField(
analyzer: Option[String],
searchAnalyzer: Option[String],
indexAnalyzer: Option[String],
synonymMap: Option[String],
fields: Option[Seq[IndexField]]
synonymMap: Option[Seq[String]],
fields: Option[Seq[IndexField]],
dimensions: Option[Int],
vectorSearchConfiguration: Option[String]
)

// One entry of the writer's "vectorCols" option: the DataFrame column holding the
// embedding and the embedding's dimensionality.
case class VectorColParams(
name: String,
dimension: Int
)

// Response shape of the index statistics API: document count and storage size in bytes.
case class IndexStats(documentCount: Int, storageSize: Int)

// Response shape of the list-indexes API: OData context plus the index names.
case class IndexList(`@odata.context`: String, value: Seq[IndexName])
// A single index name entry within an IndexList response.
case class IndexName(name: String)

/** spray-json formats for the Azure Search REST API payloads.
  * (The diff rendering duplicated the pre-change object header and IfEnc/IiEnc lines;
  * only the post-change definitions are kept.)
  */
object AzureSearchProtocol extends DefaultJsonProtocol {
  // lazyFormat breaks the recursion introduced by IndexField.fields: Option[Seq[IndexField]]
  implicit val IfEnc: JsonFormat[IndexField] = lazyFormat(jsonFormat(
    IndexField,"name","type","searchable","filterable","sortable",
    "facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields",
    "dimensions", "vectorSearchConfiguration"))
  implicit val AcEnc: RootJsonFormat[AlgorithmConfigs] = jsonFormat2(AlgorithmConfigs.apply)
  implicit val VsEnc: RootJsonFormat[VectorSearch] = jsonFormat1(VectorSearch.apply)
  // jsonFormat11 matches the 11-field IndexInfo (vectorSearch added as the 11th field)
  implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat11(IndexInfo.apply)
  implicit val IsEnc: RootJsonFormat[IndexStats] = jsonFormat2(IndexStats.apply)
  implicit val InEnc: RootJsonFormat[IndexName] = jsonFormat1(IndexName.apply)
  implicit val IlEnc: RootJsonFormat[IndexList] = jsonFormat2(IndexList.apply)
  implicit val VcpEnc: RootJsonFormat[VectorColParams] = jsonFormat2(VectorColParams.apply)
}
Loading