diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearch.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearch.scala
index d4db72e3f3..54764d6404 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearch.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearch.scala
@@ -18,6 +18,8 @@ import org.apache.spark.internal.{Logging => SLogging}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
+import org.apache.spark.ml.functions.vector_to_array
 import org.apache.spark.sql.functions.{col, expr, struct, to_json}
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
@@ -142,7 +144,7 @@ class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)
   override def responseDataType: DataType = ASResponses.schema
 }
 
-object AzureSearchWriter extends IndexParser with SLogging {
+object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging {
 
   val Logger: Logger = LogManager.getRootLogger
 
@@ -166,9 +168,11 @@ object AzureSearchWriter extends IndexParser with SLogging {
   private def convertFields(fields: Seq[StructField],
                             keyCol: String,
                             searchActionCol: String,
+                            vectorCols: Option[Seq[VectorColParams]],
                             prefix: Option[String]): Seq[IndexField] = {
     fields.filterNot(_.name == searchActionCol).map { sf =>
       val fullName = prefix.map(_ + sf.name).getOrElse(sf.name)
+      val isVector = vectorCols.exists(_.exists(_.name == fullName))
       val (innerType, _) = sparkTypeToEdmType(sf.dataType)
       IndexField(
         sf.name,
@@ -177,7 +181,9 @@ object AzureSearchWriter extends IndexParser with SLogging {
         if (keyCol == fullName) Some(true) else None,
         None, None, None, None,
         structFieldToSearchFields(sf.dataType,
-          keyCol, searchActionCol, prefix = Some(prefix.getOrElse("") + sf.name + "."))
+          keyCol, searchActionCol, None, prefix = Some(prefix.getOrElse("") + sf.name + ".")),
+        if (isVector) vectorCols.get.find(_.name == fullName).map(_.dimension) else None,
+        if (isVector) Some(AzureSearchAPIConstants.VectorConfigName) else None
       )
     }
   }
@@ -185,23 +191,34 @@ object AzureSearchWriter extends IndexParser with SLogging {
   private def structFieldToSearchFields(schema: DataType,
                                         keyCol: String,
                                         searchActionCol: String,
+                                        vectorCols: Option[Seq[VectorColParams]],
                                         prefix: Option[String] = None
                                        ): Option[Seq[IndexField]] = {
     schema match {
-      case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
-      case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
+      case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, vectorCols, prefix))
+      // TODO: Support vector search in nested fields
+      case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, None, prefix))
       case _ => None
     }
   }
 
+  private def parseVectorColsJson(str: String): Seq[VectorColParams] = {
+    str.parseJson.convertTo[Seq[VectorColParams]]
+  }
+
   private def dfToIndexJson(schema: StructType,
                             indexName: String,
                             keyCol: String,
-                            searchActionCol: String): String = {
+                            searchActionCol: String,
+                            vectorCols: Option[Seq[VectorColParams]]): String = {
+
+    val vectorConfig = Some(VectorSearch(Seq(AlgorithmConfigs(AzureSearchAPIConstants.VectorConfigName,
+      AzureSearchAPIConstants.VectorSearchAlgorithm))))
     val is = IndexInfo(
       Some(indexName),
-      structFieldToSearchFields(schema, keyCol, searchActionCol).get,
-      None, None, None, None, None, None, None, None
+      structFieldToSearchFields(schema, keyCol, searchActionCol, vectorCols).get,
+      None, None, None, None, None, None, None, None,
+      if (vectorCols.isEmpty) None else vectorConfig
     )
     is.toJson.compactPrint
   }
@@ -210,7 +227,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
                             options: Map[String, String] = Map()): DataFrame = {
     val applicableOptions = Set(
       "subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
-      "apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol"
+      "apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols"
     )
 
     options.keys.foreach(k =>
@@ -224,11 +241,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
     val batchSize = options.getOrElse("batchSize", "100").toInt
     val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean
     val filterNulls = options.getOrElse("filterNulls", "false").toBoolean
+    val vectorColsInfo = options.get("vectorCols")
     val keyCol = options.get("keyCol")
     val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)
     if (indexJsonOpt.isDefined) {
-      List("keyCol", "indexName").foreach(opt =>
+      List("keyCol", "indexName", "vectorCols").foreach(opt =>
        assert(!options.contains(opt), s"Cannot set both indexJson options and $opt")
      )
     }
@@ -242,22 +260,41 @@ object AzureSearchWriter extends IndexParser with SLogging {
       }
     }
 
-    val indexJson = indexJsonOpt.getOrElse {
-      dfToIndexJson(df.schema, indexName, keyCol.get, actionCol)
+    val (indexJson, preppedDF) = if (getExisting(subscriptionKey, serviceName, apiVersion).contains(indexName)) {
+      if (indexJsonOpt.isDefined) {
+        println(s"indexJsonOpt is specified, but an index for $indexName already exists; " +
+          s"the index definition obtained from the existing index will be used instead")
+      }
+      val existingIndexJson = getIndexJsonFromExistingIndex(subscriptionKey, serviceName, indexName)
+      val vectorColNameTypeTuple = getVectorColConf(existingIndexJson)
+      (existingIndexJson, makeColsCompatible(vectorColNameTypeTuple, df))
+    } else if (indexJsonOpt.isDefined) {
+      val vectorColNameTypeTuple = getVectorColConf(indexJsonOpt.get)
+      (indexJsonOpt.get, makeColsCompatible(vectorColNameTypeTuple, df))
+    } else {
+      val vectorCols = vectorColsInfo.map(parseVectorColsJson)
+      val vectorColNameTypeTuple = vectorCols.map(_.map(vc => (vc.name, "Collection(Edm.Single)"))).getOrElse(Seq.empty)
+      val newDF = makeColsCompatible(vectorColNameTypeTuple, df)
+      val inferredIndexJson = dfToIndexJson(newDF.schema, indexName, keyCol.getOrElse(""), actionCol, vectorCols)
+      (inferredIndexJson, newDF)
     }
 
+    // TODO: Support vector search in nested fields
+    // Throws an exception if any nested field is a vector in the schema
+    parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors))
+
     SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)
 
     logInfo("checking schema parity")
-    checkSchemaParity(df.schema, indexJson, actionCol)
+    checkSchemaParity(preppedDF.schema, indexJson, actionCol)
 
     val df1 = if (filterNulls) {
       val collectionColumns = parseIndexJson(indexJson).fields
         .filter(_.`type`.startsWith("Collection"))
         .map(_.name)
-      collectionColumns.foldLeft(df) { (ndf, c) => filterOutNulls(ndf, c) }
+      collectionColumns.foldLeft(preppedDF) { (ndf, c) => filterOutNulls(ndf, c) }
     } else {
-      df
+      preppedDF
     }
 
     new AddDocuments()
@@ -273,6 +310,48 @@ object AzureSearchWriter extends IndexParser with SLogging {
       UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input")))
   }
 
+  private def assertNoNestedVectors(fields: Seq[IndexField]): Unit = {
+    def checkVectorField(field: IndexField): Unit = {
+      if (field.dimensions.nonEmpty && field.vectorSearchConfiguration.nonEmpty) {
+        throw new IllegalArgumentException(s"Nested field ${field.name} is a vector field; vector fields in" +
+          s" nested fields are not supported.")
+      }
+      field.fields.foreach(_.foreach(checkVectorField))
+    }
+    fields.foreach(checkVectorField)
+  }
+
+  private def getVectorColConf(indexJson: String): Seq[(String, String)] = {
+    parseIndexJson(indexJson).fields
+      .filter(f => f.vectorSearchConfiguration.nonEmpty && f.dimensions.nonEmpty)
+      .map(f => (f.name, f.`type`))
+  }
+  private def makeColsCompatible(vectorColNameTypeTuple: Seq[(String, String)],
+                                 df: DataFrame): DataFrame = {
+    vectorColNameTypeTuple.foldLeft(df) { case (accDF, (colName, colType)) =>
+      if (!accDF.columns.contains(colName)) {
+        println(s"Column $colName is specified in either indexJson or vectorCols but not found in dataframe " +
+          s"columns ${accDF.columns.toList}")
+        accDF
+      }
+      else {
+        val colDataType = accDF.schema(colName).dataType
+        assert(colDataType match {
+          case ArrayType(elementType, _) => elementType == FloatType || elementType == DoubleType
+          case VectorType => true
+          case _ => false
+        }, s"Vector column $colName needs to be one of (ArrayType(FloatType), ArrayType(DoubleType), VectorType)")
+        if (colDataType.isInstanceOf[ArrayType]) {
+          accDF.withColumn(colName, accDF(colName).cast(edmTypeToSparkType(colType, None)))
+        } else {
+          // first cast the VectorUDT column to an array, then cast it to the correct array type
+          val modifiedDF = accDF.withColumn(colName, vector_to_array(accDF(colName)))
+          modifiedDF.withColumn(colName, modifiedDF(colName).cast(edmTypeToSparkType(colType, None)))
+        }
+      }
+    }
+  }
+
   private def isEdmCollection(t: String): Boolean = {
     t.startsWith("Collection(") && t.endsWith(")")
   }
@@ -290,6 +369,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
       case "Edm.Int64" => LongType
       case "Edm.Int32" => IntegerType
       case "Edm.Double" => DoubleType
+      case "Edm.Single" => FloatType
       case "Edm.DateTimeOffset" => StringType //See if there's a way to use spark datetimes
       case "Edm.GeographyPoint" => StringType
       case "Edm.ComplexType" => StructType(fields.get.map(f =>
@@ -310,10 +390,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
       case IntegerType => ("Edm.Int32", None)
       case LongType => ("Edm.Int64", None)
       case DoubleType => ("Edm.Double", None)
+      case FloatType => ("Edm.Single", None)
       case DateType => ("Edm.DateTimeOffset", None)
       case StructType(fields) => ("Edm.ComplexType", Some(fields.map { f =>
         val (innerType, innerFields) = sparkTypeToEdmType(f.dataType)
-        IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields)
+        IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields,
+          None, None) // TODO: Support vector search in nested fields
       }))
     }
   }
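Editor's note (not part of the patch): the column coercion that `makeColsCompatible` performs above can be reproduced in isolation. Below is a minimal Scala sketch of how a Spark ML `VectorUDT` column becomes the `array<float>` column the writer pairs with a `Collection(Edm.Single)` index field; it assumes a `SparkSession` named `spark` is in scope, and the DataFrame and column names are illustrative.

```scala
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, FloatType}

import spark.implicits._

// Illustrative DataFrame with a VectorUDT column named "embedding"
val df = Seq(("0", Vectors.dense(0.1, 0.2, 0.3))).toDF("id", "embedding")

// VectorUDT -> array<double> -> array<float>, matching Collection(Edm.Single) on the service side
val prepped = df
  .withColumn("embedding", vector_to_array(col("embedding")))            // dense/sparse vector to array<double>
  .withColumn("embedding", col("embedding").cast(ArrayType(FloatType)))  // then down to array<float>

prepped.printSchema()  // embedding: array<float>
```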
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchAPI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchAPI.scala
index 9a9860857e..f30ab9cd92 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchAPI.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchAPI.scala
@@ -14,7 +14,9 @@ import spray.json._
 import scala.util.{Failure, Success, Try}
 
 object AzureSearchAPIConstants {
-  val DefaultAPIVersion = "2019-05-06"
+  val DefaultAPIVersion = "2023-07-01-Preview"
+  val VectorConfigName = "vectorConfig"
+  val VectorSearchAlgorithm = "hnsw"
 }
 import com.microsoft.azure.synapse.ml.cognitive.search.AzureSearchAPIConstants._
 
@@ -39,6 +41,26 @@ trait IndexLister {
     }
   }
 }
 
+trait IndexJsonGetter extends IndexLister {
+  def getIndexJsonFromExistingIndex(key: String,
+                                    serviceName: String,
+                                    indexName: String,
+                                    apiVersion: String = DefaultAPIVersion): String = {
+    val existingIndexNames = getExisting(key, serviceName, apiVersion)
+    assert(existingIndexNames.contains(indexName), s"Cannot find an existing index named $indexName")
+
+    val indexJsonRequest = new HttpGet(
+      s"https://$serviceName.search.windows.net/indexes/$indexName?api-version=$apiVersion"
+    )
+    indexJsonRequest.setHeader("api-key", key)
+    indexJsonRequest.setHeader("Content-Type", "application/json")
+    val indexJsonResponse = safeSend(indexJsonRequest, close = false)
+    val indexJson = IOUtils.toString(indexJsonResponse.getEntity.getContent, "utf-8")
+    indexJsonResponse.close()
+    indexJson
+  }
+}
+
 object SearchIndex extends IndexParser with IndexLister {
 
   import AzureSearchProtocol._
 
@@ -94,7 +116,9 @@ object SearchIndex extends IndexParser with IndexLister {
       _ <- validAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
       _ <- validSearchAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
       _ <- validIndexAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
-      _ <- validSynonymMaps(field.synonymMap)
+      _ <- validVectorField(field.dimensions, field.vectorSearchConfiguration)
+      // TODO: Fix and add back the validSynonymMaps check. SynonymMaps needs to be Option[Seq[String]] type.
+      //_ <- validSynonymMaps(field.synonymMap)
     } yield field
   }
 
@@ -182,6 +206,15 @@ object SearchIndex extends IndexParser with IndexLister {
     }
   }
 
+  private def validVectorField(d: Option[Int], v: Option[String]): Try[Option[String]] = {
+    if ((d.isDefined && v.isEmpty) || (v.isDefined && d.isEmpty)) {
+      Failure(new IllegalArgumentException("Both dimensions and vectorSearchConfiguration need to be defined " +
+        "for vector fields"))
+    } else {
+      Success(v)
+    }
+  }
+
   def getStatistics(indexName: String,
                     key: String,
                     serviceName: String,
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchSchemas.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchSchemas.scala
index a8d9142e09..7b0612330c 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchSchemas.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/search/AzureSearchSchemas.scala
@@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.search
 
 import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
 import spray.json.DefaultJsonProtocol._
-import spray.json.{JsonFormat, RootJsonFormat}
+import spray.json.{DefaultJsonProtocol, JsonFormat, RootJsonFormat}
 
 object ASResponses extends SparkBindings[ASResponses]
 
@@ -23,9 +23,19 @@ case class IndexInfo(
                       tokenizers: Option[Seq[String]],
                       tokenFilters: Option[Seq[String]],
                       defaultScoringProfile: Option[Seq[String]],
-                      corsOptions: Option[Seq[String]]
+                      corsOptions: Option[Seq[String]],
+                      vectorSearch: Option[VectorSearch]
                     )
 
+case class AlgorithmConfigs(
                             name: String,
                             kind: String
                           )
+
+case class VectorSearch(
                          algorithmConfigurations: Seq[AlgorithmConfigs]
                        )
+
 case class IndexField(
                        name: String,
                        `type`: String,
@@ -38,21 +48,32 @@ case class IndexField(
                        analyzer: Option[String],
                        searchAnalyzer: Option[String],
                        indexAnalyzer: Option[String],
-                       synonymMap: Option[String],
-                       fields: Option[Seq[IndexField]]
+                       synonymMap: Option[Seq[String]],
+                       fields: Option[Seq[IndexField]],
+                       dimensions: Option[Int],
+                       vectorSearchConfiguration: Option[String]
                      )
 
+case class VectorColParams(
                             name: String,
                             dimension: Int
                           )
+
 case class IndexStats(documentCount: Int, storageSize: Int)
 
 case class IndexList(`@odata.context`: String, value: Seq[IndexName])
 
 case class IndexName(name: String)
 
-object AzureSearchProtocol {
+object AzureSearchProtocol extends DefaultJsonProtocol {
   implicit val IfEnc: JsonFormat[IndexField] = lazyFormat(jsonFormat(
     IndexField,"name","type","searchable","filterable","sortable",
-    "facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields"))
-  implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat10(IndexInfo.apply)
+    "facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields",
+    "dimensions", "vectorSearchConfiguration"))
+  implicit val AcEnc: RootJsonFormat[AlgorithmConfigs] = jsonFormat2(AlgorithmConfigs.apply)
+  implicit val VsEnc: RootJsonFormat[VectorSearch] = jsonFormat1(VectorSearch.apply)
+  implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat11(IndexInfo.apply)
   implicit val IsEnc: RootJsonFormat[IndexStats] = jsonFormat2(IndexStats.apply)
   implicit val InEnc: RootJsonFormat[IndexName] = jsonFormat1(IndexName.apply)
   implicit val IlEnc: RootJsonFormat[IndexList] = jsonFormat2(IndexList.apply)
+  implicit val VcpEnc: RootJsonFormat[VectorColParams] = jsonFormat2(VectorColParams.apply)
 }
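Editor's note (not part of the patch): a small sketch of how the writer's `vectorCols` option round-trips through the `VectorColParams` spray-json format added above; the option value and field names are illustrative.

```scala
import com.microsoft.azure.synapse.ml.cognitive.search.AzureSearchProtocol._
import com.microsoft.azure.synapse.ml.cognitive.search.VectorColParams
import spray.json._

// The "vectorCols" writer option is a JSON array of {name, dimension} objects,
// parsed with the VcpEnc format defined in AzureSearchProtocol
val vectorColsJson = """[{"name": "embeddings", "dimension": 1536}]"""
val parsed: Seq[VectorColParams] = vectorColsJson.parseJson.convertTo[Seq[VectorColParams]]

assert(parsed == Seq(VectorColParams("embeddings", 1536)))

// And back out to a compact JSON string, e.g. when building the option map programmatically
val roundTripped: String = parsed.toJson.compactPrint
```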
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/search/SearchWriterSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/search/SearchWriterSuite.scala
index 433a0f17ed..2a92b78d12 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/search/SearchWriterSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/search/SearchWriterSuite.scala
@@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.search
 
 import com.microsoft.azure.synapse.ml.Secrets
 import com.microsoft.azure.synapse.ml.cognitive._
+import com.microsoft.azure.synapse.ml.cognitive.openai.{OpenAIAPIKey, OpenAIEmbedding}
 import com.microsoft.azure.synapse.ml.cognitive.vision.AnalyzeImage
 import com.microsoft.azure.synapse.ml.core.test.base.TestBase
 import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
@@ -12,6 +13,7 @@ import com.microsoft.azure.synapse.ml.io.http.RESTHelpers._
 import org.apache.http.client.methods.HttpDelete
 import org.apache.spark.ml.util.MLReadable
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.ml.linalg.Vectors
 
 import java.time.LocalDateTime
 import java.time.format.{DateTimeFormatterBuilder, DateTimeParseException, SignStyle}
@@ -25,8 +27,8 @@ trait AzureSearchKey {
 }
 
 //scalastyle:off null
-class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
-  with TransformerFuzzing[AddDocuments] with CognitiveKey {
+class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGetter with IndexParser
+  with TransformerFuzzing[AddDocuments] with CognitiveKey with OpenAIAPIKey {
 
   import spark.implicits._
 
@@ -44,6 +46,12 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
       .toDF("searchAction", "id", "fileName", "text")
   }
 
+  private def createTestDataWithVector(numDocs: Int): DataFrame = {
+    (0 until numDocs)
+      .map(i => ("upload", s"$i", s"file$i", Array(0.001, 0.002, 0.003).map(_ * i)))
+      .toDF("searchAction", "id", "fileName", "vectorCol")
+  }
+
   private def createSimpleIndexJson(indexName: String): String = {
     s"""
       |{
@@ -74,6 +82,43 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
     """.stripMargin
   }
 
+  private def createSimpleIndexJsonWithVector(indexName: String): String = {
+    s"""
+      |{
+      |  "name": "$indexName",
+      |  "fields": [
+      |    {
+      |      "name": "id",
+      |      "type": "Edm.String",
+      |      "key": true,
+      |      "facetable": false
+      |    },
+      |    {
+      |      "name": "fileName",
+      |      "type": "Edm.String",
+      |      "searchable": false,
+      |      "sortable": false,
+      |      "facetable": false
+      |    },
+      |    {
+      |      "name": "vectorCol",
+      |      "type": "Collection(Edm.Single)",
+      |      "dimensions": 3,
+      |      "vectorSearchConfiguration": "vectorConfig"
+      |    }
+      |  ],
+      |  "vectorSearch": {
+      |    "algorithmConfigurations": [
+      |      {
+      |        "name": "vectorConfig",
+      |        "kind": "hnsw"
+      |      }
+      |    ]
+      |  }
+      |}
+    """.stripMargin
+  }
+
   private val createdIndexes: mutable.ListBuffer[String] = mutable.ListBuffer()
 
   private def generateIndexName(): String = {
@@ -105,7 +150,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
     println("Cleaning up services")
     val successfulCleanup = getExisting(azureSearchKey, testServiceName)
       .intersect(createdIndexes).map { n =>
-      deleteIndex(n)
+        deleteIndex(n)
     }.forall(_ == 204)
     cleanOldIndexes()
     super.afterAll()
@@ -173,12 +218,15 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
 
   def writeHelper(df: DataFrame,
                   indexName: String,
+                  isVectorField: Boolean,
                   extraParams: Map[String, String] = Map()): Unit = {
+    val indexJson = if (isVectorField) createSimpleIndexJsonWithVector(indexName) else createSimpleIndexJson(indexName)
     AzureSearchWriter.write(df,
       Map("subscriptionKey" -> azureSearchKey,
         "actionCol" -> "searchAction",
         "serviceName" -> testServiceName,
-        "indexJson" -> createSimpleIndexJson(indexName)) ++ extraParams)
+        "indexJson" -> indexJson)
+        ++ extraParams)
   }
 
   def assertSize(indexName: String, size: Int): Unit = {
@@ -186,15 +234,15 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
     ()
   }
 
-  ignore("clean up all search indexes"){
+  ignore("clean up all search indexes") {
     getExisting(azureSearchKey, testServiceName)
       .foreach { n =>
-      val deleteRequest = new HttpDelete(
-        s"https://$testServiceName.search.windows.net/indexes/$n?api-version=2017-11-11")
-      deleteRequest.setHeader("api-key", azureSearchKey)
-      val response = safeSend(deleteRequest)
-      println(s"Deleted index $n, status code ${response.getStatusLine.getStatusCode}")
-    }
+        val deleteRequest = new HttpDelete(
+          s"https://$testServiceName.search.windows.net/indexes/$n?api-version=2017-11-11")
+        deleteRequest.setHeader("api-key", azureSearchKey)
+        val response = safeSend(deleteRequest)
+        println(s"Deleted index $n, status code ${response.getStatusLine.getStatusCode}")
+      }
   }
 
   test("Run azure-search tests with waits") {
@@ -209,17 +257,17 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
 
     //create new index and add docs
     lazy val in1 = generateIndexName()
-    dependsOn(1, writeHelper(df4, in1))
+    dependsOn(1, writeHelper(df4, in1, isVectorField=false))
 
     //push docs to existing index
     lazy val in2 = generateIndexName()
     lazy val dfA = df10.limit(4)
     lazy val dfB = df10.except(dfA)
-    dependsOn(2, writeHelper(dfA, in2))
+    dependsOn(2, writeHelper(dfA, in2, isVectorField=false))
     dependsOn(2, retryWithBackoff({
       if (getExisting(azureSearchKey, testServiceName).contains(in2)) {
-        writeHelper(dfB, in2)
+        writeHelper(dfB, in2, isVectorField=false)
       } else {
         throw new RuntimeException("No existing service found")
       }
@@ -227,7 +275,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
 
     //push docs with custom batch size
     lazy val in3 = generateIndexName()
-    dependsOn(3, writeHelper(bigDF, in3, Map("batchSize" -> "2000")))
+    dependsOn(3, writeHelper(bigDF, in3, isVectorField=false, Map("batchSize" -> "2000")))
 
     dependsOn(1, retryWithBackoff(assertSize(in1, 4)))
     dependsOn(2, retryWithBackoff(assertSize(in2, 10)))
@@ -276,17 +324,17 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
       .map { i => ("upload", s"$i", s"file$i", s"text$i") }
       .toDF("searchAction", "badkeyname", "fileName", "text")
     assertThrows[IllegalArgumentException] {
-      writeHelper(mismatchDF, generateIndexName())
+      writeHelper(mismatchDF, generateIndexName(), isVectorField=false)
     }
   }
 
   /**
-    * All the Edm Types are nullable in Azure Search except for Collection(Edm.String).
-    * Because it is not possible to store a null value in a Collection(Edm.String) field,
-    * there is an option to set a boolean flag, filterNulls, that will remove null values
-    * from the dataset in the Collection(Edm.String) fields before writing the data to the search index.
-    * The default value for this boolean flag is False.
-    */
+   * All the Edm Types are nullable in Azure Search except for Collection(Edm.String).
+   * Because it is not possible to store a null value in a Collection(Edm.String) field,
+   * there is an option to set a boolean flag, filterNulls, that will remove null values
+   * from the dataset in the Collection(Edm.String) fields before writing the data to the search index.
+   * The default value for this boolean flag is False.
+   */
   test("Handle null values for Collection(Edm.String) fields") {
     val in = generateIndexName()
     val phraseIndex =
@@ -387,4 +435,233 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexLister
     retryWithBackoff(assertSize(in, 2))
   }
 
+  test("Run azure-search tests with vector fields") {
+    val in1 = generateIndexName()
+    val vectorDF4 = createTestDataWithVector(4)
+
+    writeHelper(vectorDF4, in1, isVectorField=true)
+
+    val in2 = generateIndexName()
+    val vectorDF10 = createTestDataWithVector(10)
+    val dfA = vectorDF10.limit(4)
+    val dfB = vectorDF10.except(dfA)
+
+    writeHelper(dfA, in2, isVectorField=true)
+
+    retryWithBackoff({
+      if (getExisting(azureSearchKey, testServiceName).contains(in2)) {
+        writeHelper(dfB, in2, isVectorField=true)
+      } else {
+        throw new RuntimeException("No existing service found")
+      }
+    })
+
+    retryWithBackoff(assertSize(in1, 4))
+    retryWithBackoff(assertSize(in2, 10))
+
+    val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in1))
+    // assert that vectorCol is a vector field
+    assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol").get.vectorSearchConfiguration.nonEmpty)
+  }
+
+  test("Infer the structure of the index from the dataframe with vector columns") {
+    val in = generateIndexName()
+    val phraseDF = Seq(
+      ("upload", "0", "file0", Array(1.1, 2.1, 3.1), Vectors.dense(0.11, 0.21, 0.31),
+        Vectors.sparse(3, Array(0, 1, 2), Array(0.11, 0.21, 0.31))),
+      ("upload", "1", "file1", Array(1.2, 2.2, 3.2), Vectors.dense(0.12, 0.22, 0.32),
+        Vectors.sparse(3, Array(0, 1, 2), Array(0.11, 0.21, 0.31))))
+      .toDF("searchAction", "id", "fileName", "vectorCol1", "vectorCol2", "vectorCol3")
+
+    val vectorCols =
+      """
+        |[
+        |  {"name": "vectorCol1", "dimension": 3},
+        |  {"name": "vectorCol2", "dimension": 3},
+        |  {"name": "vectorCol3", "dimension": 3}
+        |]
+        |""".stripMargin
+
+    AzureSearchWriter.write(phraseDF,
+      Map(
+        "subscriptionKey" -> azureSearchKey,
+        "actionCol" -> "searchAction",
+        "serviceName" -> testServiceName,
+        "filterNulls" -> "true",
+        "indexName" -> in,
+        "keyCol" -> "id",
+        "vectorCols" -> vectorCols
+      ))
+
+    retryWithBackoff(assertSize(in, 2))
+
+    // assert that the vectorCols are vector fields
+    val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in))
+    assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol1").get.vectorSearchConfiguration.nonEmpty)
+    assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol2").get.vectorSearchConfiguration.nonEmpty)
+    assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol3").get.vectorSearchConfiguration.nonEmpty)
+  }
+
+  test("Throw useful error when given vector columns in nested fields") {
+    val in = generateIndexName()
+    val badJson =
+      s"""
+        |{
+        |  "name": "$in",
+        |  "fields": [
+        |    {
+        |      "name": "id",
+        |      "type": "Edm.String",
+        |      "key": true,
+        |      "facetable": false
+        |    },
+        |    {
+        |      "name": "someCollection",
+        |      "type": "Edm.String"
+        |    },
+        |    {
+        |      "name": "complexField",
+        |      "type": "Edm.ComplexType",
+        |      "fields": [
+        |        {
+        |          "name": "StreetAddress",
+        |          "type": "Edm.String"
+        |        },
+        |        {
+        |          "name": "contentVector",
+        |          "type": "Collection(Edm.Single)",
+        |          "dimensions": 3,
+        |          "vectorSearchConfiguration": "vectorConfig"
+        |        }
+        |      ]
+        |    }
+        |  ]
+        |}
+      """.stripMargin
+
+    assertThrows[IllegalArgumentException] {
+      AzureSearchWriter.write(df4,
+        Map(
+          "subscriptionKey" -> azureSearchKey,
+          "actionCol" -> "searchAction",
+          "serviceName" -> testServiceName,
+          "filterNulls" -> "true",
+          "indexJson" -> badJson
+        ))
+    }
+  }
+
+  test("Throw useful error when one of dimensions or vectorSearchConfig is not defined") {
+    val in = generateIndexName()
+    val badJson =
+      s"""
+        |{
+        |  "name": "$in",
+        |  "fields": [
+        |    {
+        |      "name": "id",
+        |      "type": "Edm.String",
+        |      "key": true,
+        |      "facetable": false
+        |    },
+        |    {
+        |      "name": "someCollection",
+        |      "type": "Edm.String"
+        |    },
+        |    {
+        |      "name": "contentVector",
+        |      "type": "Collection(Edm.Single)",
+        |      "dimensions": 3
+        |    }
+        |  ]
+        |}
+      """.stripMargin
+
+    assertThrows[IllegalArgumentException] {
+      SearchIndex.createIfNoneExists(azureSearchKey, testServiceName, badJson)
+    }
+  }
+
+  test("Handle non-existent vector column specified in vectorCols option") {
+    val in = generateIndexName()
+    val phraseDF = Seq(
+      ("upload", "0", "file0"),
+      ("upload", "1", "file1"))
+      .toDF("searchAction", "id", "fileName")
+
+    AzureSearchWriter.write(phraseDF,
+      Map(
+        "subscriptionKey" -> azureSearchKey,
+        "actionCol" -> "searchAction",
+        "serviceName" -> testServiceName,
+        "indexName" -> in,
+        "keyCol" -> "id",
+        "vectorCols" -> """[{"name": "vectorCol", "dimension": 3}]"""
+      ))
+
+    retryWithBackoff(assertSize(in, 2))
+  }
+
+  test("Handle non-existing vector column specified in index JSON option") {
+    val in = generateIndexName()
+    val phraseDF = Seq(
+      ("upload", "0", "file0"),
+      ("upload", "1", "file1"))
+      .toDF("searchAction", "id", "fileName")
+
+    AzureSearchWriter.write(phraseDF,
+      Map(
+        "subscriptionKey" -> azureSearchKey,
+        "actionCol" -> "searchAction",
+        "serviceName" -> testServiceName,
+        "indexJson" -> createSimpleIndexJsonWithVector(in)
+      ))
+
+    retryWithBackoff(assertSize(in, 2))
+  }
+
+  test("Throw useful error when the vector column is an unsupported type") {
+    val in = generateIndexName()
+    val badDF = Seq(
+      ("upload", "0", "file0", Array("p1", "p2", "p3")),
+      ("upload", "1", "file1", Array("p4", "p5", "p6")))
+      .toDF("searchAction", "id", "fileName", "vectorCol")
+
+    assertThrows[AssertionError] {
+      writeHelper(badDF, in, isVectorField=true)
+    }
+  }
+
+  test("pipeline with openai embedding") {
+    val in = generateIndexName()
+
+    val df = Seq(
+      ("upload", "0", "this is the first sentence"),
+      ("upload", "1", "this is the second sentence")
+    ).toDF("searchAction", "id", "content")
+
+    val tdf = new OpenAIEmbedding()
+      .setSubscriptionKey(openAIAPIKey)
+      .setDeploymentName("text-embedding-ada-002")
+      .setCustomServiceName(openAIServiceName)
+      .setTextCol("content")
+      .setErrorCol("error")
+      .setOutputCol("vectorContent")
+      .transform(df)
+      .drop("error")
+
+    AzureSearchWriter.write(tdf,
+      Map(
+        "subscriptionKey" -> azureSearchKey,
+        "actionCol" -> "searchAction",
+        "serviceName" -> testServiceName,
+        "indexName" -> in,
+        "keyCol" -> "id",
+        "vectorCols" -> """[{"name": "vectorContent", "dimension": 1536}]"""
+      ))
+
+    retryWithBackoff(assertSize(in, 2))
+    val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in))
+    assert(parseIndexJson(indexJson).fields.find(_.name == "vectorContent").get.vectorSearchConfiguration.nonEmpty)
+  }
 }
diff --git a/docs/Explore Algorithms/AI Services/Quickstart - Document Question and Answering with PDFs.ipynb b/docs/Explore Algorithms/AI Services/Quickstart - Document Question and Answering with PDFs.ipynb
index 35be3a8dd3..6d248ed270 100644
--- a/docs/Explore Algorithms/AI Services/Quickstart - Document Question and Answering with PDFs.ipynb
+++ b/docs/Explore Algorithms/AI Services/Quickstart - Document Question and Answering with PDFs.ipynb
@@ -148,7 +148,7 @@
     "\n",
     "# Azure Cognitive Search\n",
     "cogsearch_name = \"mmlspark-azure-search\"\n",
-    "cogsearch_index_name = \"exampleindex\"\n",
+    "cogsearch_index_name = \"examplevectorindex\"\n",
     "cogsearch_api_key = find_secret(\"azure-search-key\")"
    ]
   },
@@ -612,12 +612,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Import necessary packages\n",
-    "import requests\n",
-    "import json\n",
-    "\n",
-    "EMBEDDING_LENGTH = (\n",
-    "    1536  # length of the embedding vector (OpenAI generates embeddings of length 1536)\n",
+    "from pyspark.sql.functions import monotonically_increasing_id\n",
+    "from pyspark.sql.functions import lit\n",
+    "\n",
+    "df_embeddings = (\n",
+    "    df_embeddings.drop(\"error\")\n",
+    "    .withColumn(\n",
+    "        \"idx\", monotonically_increasing_id().cast(\"string\")\n",
+    "    )  # create index ID for ACS\n",
+    "    .withColumn(\"searchAction\", lit(\"upload\"))\n",
     ")"
    ]
   },
@@ -627,148 +630,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Create Index for Cog Search with fields as id, content, and contentVector\n",
-    "# Note the datatypes for each field below\n",
-    "\n",
-    "url = f\"https://{cogsearch_name}.search.windows.net/indexes/{cogsearch_index_name}?api-version=2023-07-01-Preview\"\n",
-    "payload = json.dumps(\n",
-    "    {\n",
-    "        \"name\": cogsearch_index_name,\n",
-    "        \"fields\": [\n",
-    "            {\"name\": \"id\", \"type\": \"Edm.String\", \"key\": True, \"filterable\": True},\n",
-    "            {\n",
-    "                \"name\": \"content\",\n",
-    "                \"type\": \"Edm.String\",\n",
-    "                \"searchable\": True,\n",
-    "                \"retrievable\": True,\n",
-    "            },\n",
-    "            {\n",
-    "                \"name\": \"contentVector\",\n",
-    "                \"type\": \"Collection(Edm.Single)\",\n",
-    "                \"searchable\": True,\n",
-    "                \"retrievable\": True,\n",
-    "                \"dimensions\": EMBEDDING_LENGTH,\n",
-    "                \"vectorSearchConfiguration\": \"vectorConfig\",\n",
-    "            },\n",
-    "        ],\n",
-    "        \"vectorSearch\": {\n",
-    "            \"algorithmConfigurations\": [\n",
-    "                {\n",
-    "                    \"name\": \"vectorConfig\",\n",
-    "                    \"kind\": \"hnsw\",\n",
-    "                }\n",
-    "            ]\n",
-    "        },\n",
-    "    }\n",
-    ")\n",
-    "headers = {\"Content-Type\": \"application/json\", \"api-key\": cogsearch_api_key}\n",
-    "\n",
-    "response = requests.request(\"PUT\", url, headers=headers, data=payload)\n",
-    "print(response.status_code)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "application/vnd.databricks.v1+cell": {
-     "cellMetadata": {
-      "byteLimit": 2048000,
-      "rowLimit": 10000
-     },
-     "inputWidgets": {},
-     "nuid": "07396763-74c3-4299-8976-e15e6d510d47",
-     "showTitle": false,
-     "title": ""
-    },
-    "nteract": {
-     "transient": {
-      "deleting": false
-     }
-    }
-   },
-   "source": [
-    "We need to use User Defined Function (UDF) through the udf() method in order to apply functions directly to the DataFrames and SQL databases in Python, without any need to individually register them."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Use Spark's UDF to insert entries to Cognitive Search\n",
-    "# This allows to run the code in a distributed fashion\n",
-    "\n",
-    "# Define a UDF using the @udf decorator\n",
-    "@udf(returnType=StringType())\n",
-    "def insert_to_cog_search(idx, content, contentVector):\n",
-    "    url = f\"https://{cogsearch_name}.search.windows.net/indexes/{cogsearch_index_name}/docs/index?api-version=2023-07-01-Preview\"\n",
-    "\n",
-    "    payload = json.dumps(\n",
-    "        {\n",
-    "            \"value\": [\n",
-    "                {\n",
-    "                    \"id\": str(idx),\n",
-    "                    \"content\": content,\n",
-    "                    \"contentVector\": contentVector.tolist(),\n",
-    "                    \"@search.action\": \"upload\",\n",
-    "                },\n",
-    "            ]\n",
-    "        }\n",
-    "    )\n",
-    "    headers = {\n",
-    "        \"Content-Type\": \"application/json\",\n",
-    "        \"api-key\": cogsearch_api_key,\n",
-    "    }\n",
-    "\n",
-    "    response = requests.request(\"POST\", url, headers=headers, data=payload)\n",
-    "    # response.text\n",
-    "\n",
-    "    if response.status_code == 200 or response.status_code == 201:\n",
-    "        return \"Success\"\n",
-    "    else:\n",
-    "        return \"Failure\""
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "application/vnd.databricks.v1+cell": {
-     "cellMetadata": {
-      "byteLimit": 2048000,
-      "rowLimit": 10000
-     },
-     "inputWidgets": {},
-     "nuid": "42688e00-98fb-406e-9f19-c89fed3248ef",
-     "showTitle": false,
-     "title": ""
-    }
-   },
-   "source": [
-    "In the following, we apply UDF to different columns. Note that UDF also helps to add new columns to the DataFrame."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Apply the UDF on the different columns\n",
-    "from pyspark.sql.functions import monotonically_increasing_id\n",
-    "\n",
-    "df_embeddings = df_embeddings.withColumn(\n",
-    "    \"idx\", monotonically_increasing_id()\n",
-    ")  ## adding a column with id\n",
-    "df_embeddings = df_embeddings.withColumn(\n",
-    "    \"errorCogSearch\",\n",
-    "    insert_to_cog_search(\n",
-    "        df_embeddings[\"idx\"], df_embeddings[\"chunk\"], df_embeddings[\"embeddings\"]\n",
-    "    ),\n",
-    ")\n",
+    "from synapse.ml.cognitive import writeToAzureSearch\n",
+    "import json\n",
     "\n",
-    "# Show the transformed DataFrame\n",
-    "df_embeddings.show()"
+    "df_embeddings.writeToAzureSearch(\n",
+    "    subscriptionKey=cogsearch_api_key,\n",
+    "    actionCol=\"searchAction\",\n",
+    "    serviceName=cogsearch_name,\n",
+    "    indexName=cogsearch_index_name,\n",
+    "    keyCol=\"idx\",\n",
+    "    vectorCols=json.dumps([{\"name\": \"embeddings\", \"dimension\": 1536}]),\n",
+    ")"
    ]
   },
@@ -833,6 +705,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import requests\n",
+    "\n",
     "# Ask a question and convert to embeddings\n",
     "\n",
     "\n",
@@ -861,7 +735,7 @@
     "    url = f\"https://{cogsearch_name}.search.windows.net/indexes/{cogsearch_index_name}/docs/search?api-version=2023-07-01-Preview\"\n",
     "\n",
     "    payload = json.dumps(\n",
-    "        {\"vector\": {\"value\": question_embedding, \"fields\": \"contentVector\", \"k\": 2}}\n",
+    "        {\"vector\": {\"value\": question_embedding, \"fields\": \"embeddings\", \"k\": k}}\n",
    "    )\n",
     "    headers = {\n",
     "        \"Content-Type\": \"application/json\",\n",
@@ -996,7 +870,7 @@
     "\n",
     "\n",
     "# Concatenate the content of retrieved documents\n",
-    "context = [i[\"content\"] for i in output[\"value\"]]\n",
+    "context = [i[\"chunk\"] for i in output[\"value\"]]\n",
     "\n",
     "# Make a Quesion Answer chain function and pass\n",
     "qa_chain = qa_chain_func()\n",
@@ -1012,5 +886,5 @@
   }
  },
 "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 5
}
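Editor's note (not part of the patch): end to end, the changes above let text be embedded and pushed into a vector index in a few lines. The following is a hedged Scala sketch mirroring the new "pipeline with openai embedding" test; it assumes a `SparkSession` named `spark` is in scope, and all keys, service names, and the index name are placeholders you would supply.

```scala
import com.microsoft.azure.synapse.ml.cognitive.openai.OpenAIEmbedding
import com.microsoft.azure.synapse.ml.cognitive.search.AzureSearchWriter

import spark.implicits._

val docs = Seq(
  ("upload", "0", "this is the first sentence"),
  ("upload", "1", "this is the second sentence")
).toDF("searchAction", "id", "content")

// Embed the text column with Azure OpenAI (text-embedding-ada-002 returns 1536-dimensional vectors)
val embedded = new OpenAIEmbedding()
  .setSubscriptionKey("<openai-api-key>")          // placeholder
  .setCustomServiceName("<openai-service-name>")   // placeholder
  .setDeploymentName("text-embedding-ada-002")
  .setTextCol("content")
  .setErrorCol("error")
  .setOutputCol("vectorContent")
  .transform(docs)
  .drop("error")

// Write to Azure Cognitive Search; vectorCols marks vectorContent as a vector field of dimension 1536
AzureSearchWriter.write(embedded,
  Map(
    "subscriptionKey" -> "<azure-search-key>",     // placeholder
    "serviceName" -> "<azure-search-service>",     // placeholder
    "indexName" -> "<my-vector-index>",            // placeholder
    "actionCol" -> "searchAction",
    "keyCol" -> "id",
    "vectorCols" -> """[{"name": "vectorContent", "dimension": 1536}]"""
  ))
```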