diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt index 3ed50ad08..50cc00d03 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt @@ -207,6 +207,7 @@ public class AddDsl( public infix fun ColumnReference.into(column: KProperty): Boolean = into(column.name) + @Interpretable("AddDslStringInvoke") public operator fun String.invoke(body: AddDsl.() -> Unit): Unit = group(this, body) public infix fun AnyColumnGroupAccessor.from(body: AddDsl.() -> Unit): Unit = group(this, body) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/addId.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/addId.kt index 14a5e7840..d29f697ab 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/addId.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/addId.kt @@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.api import org.jetbrains.kotlinx.dataframe.AnyCol import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor // region DataColumn @@ -15,6 +17,8 @@ public fun AnyCol.addId(columnName: String = "id"): AnyFrame = toDataFrame().add public fun DataFrame.addId(column: ColumnAccessor): DataFrame = insert(column) { index() }.at(0) +@Refine +@Interpretable("AddId") public fun DataFrame.addId(columnName: String = "id"): DataFrame = insert(columnName) { index() }.at(0) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index cfad7fd1d..3f59ef8f2 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -29,6 +29,7 @@ import kotlin.reflect.KProperty * * `df.add("columnName") { "someColumn"() + 15 }.groupBy("columnName")` */ +@Refine @Interpretable("DataFrameGroupBy") public fun DataFrame.groupBy(moveToTop: Boolean = true, cols: ColumnsSelector): GroupBy = groupByImpl(moveToTop, cols) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt index c79717fd1..6907c9e0b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt @@ -7,6 +7,8 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.Selector +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.FrameColumn import org.jetbrains.kotlinx.dataframe.impl.columnName @@ -102,6 +104,8 @@ public fun ColumnsContainer.mapToColumn( body: AddExpression, ): DataColumn = mapToColumn(column.columnName, type, infer, body) +@Refine +@Interpretable("MapToFrame") public fun DataFrame.mapToFrame(body: AddDsl.() -> Unit): AnyFrame { val dsl = AddDsl(this) body(dsl) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/move.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/move.kt index 2f6efc217..6fd3f3b59 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/move.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/move.kt @@ -5,6 +5,8 @@ import org.jetbrains.kotlinx.dataframe.AnyColumnReference import org.jetbrains.kotlinx.dataframe.ColumnSelector import org.jetbrains.kotlinx.dataframe.ColumnsSelector import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath import org.jetbrains.kotlinx.dataframe.columns.toColumnSet @@ -18,6 +20,7 @@ import kotlin.reflect.KProperty // region move +@Interpretable("Move0") public fun DataFrame.move(columns: ColumnsSelector): MoveClause = MoveClause(this, columns) public fun DataFrame.move(vararg cols: String): MoveClause = move { cols.toColumnSet() } @@ -120,6 +123,8 @@ public fun MoveClause.under( public fun MoveClause.to(columnIndex: Int): DataFrame = moveTo(columnIndex) +@Refine +@Interpretable("ToTop") public fun MoveClause.toTop( newColumnName: ColumnsSelectionDsl.(ColumnWithPath) -> String = { it.name() }, ): DataFrame = into { newColumnName(it).toColumnAccessor() } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt index 8c3f169e9..b59f32bad 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt @@ -155,6 +155,7 @@ private interface UpdateWithNote * @include [SelectingColumns.Dsl.WithExample] {@include [SetSelectingColumnsOperationArg]} * @include [Update.DslParam] */ +@Interpretable("Update0") public fun DataFrame.update(columns: ColumnsSelector): Update = Update(this, null, columns) /** diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt index 6b720b8ae..6d4e142f8 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/analyzeRefinedCallShape.kt @@ -12,42 +12,33 @@ import org.jetbrains.kotlin.fir.types.ConeKotlinType import org.jetbrains.kotlin.fir.types.ConeTypeProjection import org.jetbrains.kotlin.fir.types.classId import org.jetbrains.kotlin.fir.types.resolvedType +import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter -import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema -import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate -import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID -fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? { +internal inline fun KotlinTypeFacade.analyzeRefinedCallShape( + call: FirFunctionCall, + expectedReturnType: ClassId, + reporter: InterpretationErrorReporter +): CallResult? { val callReturnType = call.resolvedType - if (callReturnType.classId != DF_CLASS_ID) return null - val rootMarker = callReturnType.typeArguments[0] + if (callReturnType.classId != expectedReturnType) return null // rootMarker is expected to be a token generated by the plugin. // it's implied by "refined call" // thus ConeClassLikeType - if (rootMarker !is ConeClassLikeType) { - return null - } + val rootMarkers = callReturnType.typeArguments.filterIsInstance() + if (rootMarkers.size != callReturnType.typeArguments.size) return null - val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name -> + val newSchema: T = call.interpreterName(session)?.let { name -> when (name) { - "Aggregate" -> { - val groupByCall = call.explicitReceiver as? FirFunctionCall - val interpreter = groupByCall?.loadInterpreter(session) - if (interpreter != null) { - aggregate(groupByCall, interpreter, reporter, call) - } else { - PluginDataFrameSchema(emptyList()) - } - } else -> name.load>().let { processor -> val dataFrameSchema = interpret(call, processor, reporter = reporter) .let { val value = it?.value - if (value !is PluginDataFrameSchema) { + if (value !is T) { if (!reporter.errorReported) { - reporter.reportInterpretationError(call, "${processor::class} must return ${PluginDataFrameSchema::class}, but was ${value}") + reporter.reportInterpretationError(call, "${processor::class} must return ${T::class}, but was $value") } return null } @@ -58,10 +49,10 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In } } ?: return null - return CallResult(rootMarker, newSchema) + return CallResult(rootMarkers, newSchema) } -data class CallResult(val rootMarker: ConeClassLikeType, val newSchema: PluginDataFrameSchema) +data class CallResult(val markers: List, val result: T) class RefinedArguments(val refinedArguments: List) : List by refinedArguments diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt index 6b4fe2be1..6088f6376 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt @@ -10,6 +10,7 @@ import org.jetbrains.kotlin.fir.FirAnnotationContainer import org.jetbrains.kotlin.fir.FirElement import org.jetbrains.kotlin.fir.FirFunctionTarget import org.jetbrains.kotlin.fir.FirSession +import org.jetbrains.kotlin.fir.analysis.checkers.fullyExpandedClassId import org.jetbrains.kotlin.fir.caches.FirCache import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter import org.jetbrains.kotlinx.dataframe.plugin.SchemaProperty @@ -17,6 +18,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.analyzeRefinedCallShape import org.jetbrains.kotlinx.dataframe.plugin.utils.Names import org.jetbrains.kotlinx.dataframe.plugin.utils.projectOverDataColumnType import org.jetbrains.kotlin.fir.declarations.EmptyDeprecationsProvider +import org.jetbrains.kotlin.fir.declarations.FirClass import org.jetbrains.kotlin.fir.declarations.FirDeclarationOrigin import org.jetbrains.kotlin.fir.declarations.FirRegularClass import org.jetbrains.kotlin.fir.declarations.FirResolvePhase @@ -25,6 +27,7 @@ import org.jetbrains.kotlin.fir.declarations.builder.buildAnonymousFunction import org.jetbrains.kotlin.fir.declarations.builder.buildRegularClass import org.jetbrains.kotlin.fir.declarations.builder.buildValueParameter import org.jetbrains.kotlin.fir.declarations.impl.FirResolvedDeclarationStatusImpl +import org.jetbrains.kotlin.fir.declarations.utils.classId import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression import org.jetbrains.kotlin.fir.expressions.buildResolvedArgumentList @@ -55,6 +58,7 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirValueParameterSymbol import org.jetbrains.kotlin.fir.toFirResolvedTypeRef import org.jetbrains.kotlin.fir.types.ConeKotlinTypeProjection import org.jetbrains.kotlin.fir.types.ConeStarProjection +import org.jetbrains.kotlin.fir.types.ConeTypeProjection import org.jetbrains.kotlin.fir.types.builder.buildResolvedTypeRef import org.jetbrains.kotlin.fir.types.builder.buildTypeProjectionWithVariance import org.jetbrains.kotlin.fir.types.classId @@ -75,6 +79,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema import kotlin.math.abs @OptIn(FirExtensionApiInternals::class) @@ -89,6 +95,13 @@ class FunctionCallTransformer( const val DEFAULT_NAME = "DataFrameType" } + private interface CallTransformer { + fun interceptOrNull(callInfo: CallInfo, symbol: FirNamedFunctionSymbol, hash: String): CallReturnType? + fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? + } + + private val transformers = listOf(GroupByCallTransformer(), DataFrameCallTransformer()) + override fun intercept(callInfo: CallInfo, symbol: FirNamedFunctionSymbol): CallReturnType? { val callSiteAnnotations = (callInfo.callSite as? FirAnnotationContainer)?.annotations ?: emptyList() if (callSiteAnnotations.any { it.fqName(session)?.shortName()?.equals(Name.identifier("DisableInterpretation")) == true }) { @@ -103,7 +116,6 @@ class FunctionCallTransformer( } if (exposesLocalType(callInfo)) return null - val lookupTag = ConeClassLikeLookupTagImpl(Names.DF_CLASS_ID) val hash = run { val hash = callInfo.name.hashCode() + callInfo.arguments.sumOf { when (it) { @@ -114,16 +126,134 @@ class FunctionCallTransformer( hashToTwoCharString(abs(hash)) } - fun Name.asTokenName() = identifierOrNullIfSpecial?.titleCase() ?: DEFAULT_NAME + return transformers.firstNotNullOfOrNull { it.interceptOrNull(callInfo, symbol, hash) } + } + + private fun exposesLocalType(callInfo: CallInfo): Boolean { + val property = callInfo.containingDeclarations.lastOrNull()?.symbol as? FirPropertySymbol + return (property != null && !property.resolvedStatus.effectiveVisibility.privateApi) + } + + private fun hashToTwoCharString(hash: Int): String { + val baseChars = "0123456789" + val base = baseChars.length + val positiveHash = abs(hash) + val char1 = baseChars[positiveHash % base] + val char2 = baseChars[(positiveHash / base) % base] + + return "$char1$char2" + } + + override fun transform(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall { + return transformers + .firstNotNullOfOrNull { it.transformOrNull(call, originalSymbol) } + ?: call + } + + inner class DataFrameCallTransformer : CallTransformer { + override fun interceptOrNull(callInfo: CallInfo, symbol: FirNamedFunctionSymbol, hash: String): CallReturnType? { + if (symbol.resolvedReturnType.fullyExpandedClassId(session) != Names.DF_CLASS_ID) return null + // possibly null if explicit receiver type is AnyFrame + val argument = (callInfo.explicitReceiver?.resolvedType)?.typeArguments?.getOrNull(0) + val newDataFrameArgument = buildNewTypeArgument(argument, callInfo.name, hash) + + val lookupTag = ConeClassLikeLookupTagImpl(Names.DF_CLASS_ID) + val typeRef = buildResolvedTypeRef { + type = ConeClassLikeTypeImpl( + lookupTag, + arrayOf( + ConeClassLikeTypeImpl( + ConeClassLookupTagWithFixedSymbol(newDataFrameArgument.classId, newDataFrameArgument.symbol), + emptyArray(), + isNullable = false + ) + ), + isNullable = false + ) + } + return CallReturnType(typeRef) + } + + @OptIn(SymbolInternals::class) + override fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? { + val callResult = analyzeRefinedCallShape(call, Names.DF_CLASS_ID, InterpretationErrorReporter.DEFAULT) + val (tokens, dataFrameSchema) = callResult ?: return null + val token = tokens[0] + val firstSchema = token.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!! + val dataSchemaApis = materialize(dataFrameSchema, call, firstSchema) - // possibly null if explicit receiver type is AnyFrame - val argument = (callInfo.explicitReceiver?.resolvedType)?.typeArguments?.singleOrNull() + val tokenFir = token.toClassSymbol(session)!!.fir + tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol }) + + return buildLetCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir)) + } + } + + inner class GroupByCallTransformer : CallTransformer { + override fun interceptOrNull( + callInfo: CallInfo, + symbol: FirNamedFunctionSymbol, + hash: String + ): CallReturnType? { + if (symbol.resolvedReturnType.fullyExpandedClassId(session) != Names.GROUP_BY_CLASS_ID) return null + val keys = buildNewTypeArgument(null, Name.identifier("Key"), hash) + val group = buildNewTypeArgument(null, Name.identifier("Group"), hash) + val lookupTag = ConeClassLikeLookupTagImpl(Names.GROUP_BY_CLASS_ID) + val typeRef = buildResolvedTypeRef { + type = ConeClassLikeTypeImpl( + lookupTag, + arrayOf( + ConeClassLikeTypeImpl( + ConeClassLookupTagWithFixedSymbol(keys.classId, keys.symbol), + emptyArray(), + isNullable = false + ), + ConeClassLikeTypeImpl( + ConeClassLookupTagWithFixedSymbol(group.classId, group.symbol), + emptyArray(), + isNullable = false + ) + ), + isNullable = false + ) + } + return CallReturnType(typeRef) + } + + @OptIn(SymbolInternals::class) + override fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? { + val callResult = analyzeRefinedCallShape(call, Names.GROUP_BY_CLASS_ID, InterpretationErrorReporter.DEFAULT) + val (rootMarkers, groupBy) = callResult ?: return null + + val keyMarker = rootMarkers[0] + val groupMarker = rootMarkers[1] + + val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop) + val groupSchema = PluginDataFrameSchema(groupBy.df.columns()) + + val firstSchema = keyMarker.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!! + val firstSchema1 = groupMarker.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!! + + val keyApis = materialize(keySchema, call, firstSchema, "Key") + val groupApis = materialize(groupSchema, call, firstSchema1, "Group", i = keyApis.size) + + val groupToken = keyMarker.toClassSymbol(session)!!.fir + groupToken.callShapeData = CallShapeData.RefinedType(keyApis.map { it.scope.symbol }) + + val keyToken = groupMarker.toClassSymbol(session)!!.fir + keyToken.callShapeData = CallShapeData.RefinedType(groupApis.map { it.scope.symbol }) + + return buildLetCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken)) + } + } + + private fun buildNewTypeArgument(argument: ConeTypeProjection?, name: Name, hash: String): FirRegularClass { val suggestedName = if (argument == null) { - "${callInfo.name.asTokenName()}_$hash" + "${name.asTokenName()}_$hash" } else { when (argument) { is ConeStarProjection -> { - "${callInfo.name.asTokenName()}_$hash" + "${name.asTokenName()}_$hash" } is ConeKotlinTypeProjection -> { val titleCase = argument.type.classId?.shortClassName @@ -154,149 +284,30 @@ class FunctionCallTransformer( ) } - name = dataFrameTypeId.shortClassName + this.name = dataFrameTypeId.shortClassName this.symbol = FirRegularClassSymbol(dataFrameTypeId) } - - val typeRef = buildResolvedTypeRef { - type = ConeClassLikeTypeImpl( - lookupTag, - arrayOf( - ConeClassLikeTypeImpl( - ConeClassLookupTagWithFixedSymbol(dataFrameTypeId, dataFrameType.symbol), - emptyArray(), - isNullable = false - ) - ), - isNullable = false - ) - } - return CallReturnType(typeRef) - } - - private fun exposesLocalType(callInfo: CallInfo): Boolean { - val property = callInfo.containingDeclarations.lastOrNull()?.symbol as? FirPropertySymbol - return (property != null && !property.resolvedStatus.effectiveVisibility.privateApi) - } - - private fun hashToTwoCharString(hash: Int): String { - val baseChars = "0123456789" - val base = baseChars.length - val positiveHash = abs(hash) - val char1 = baseChars[positiveHash % base] - val char2 = baseChars[(positiveHash / base) % base] - - return "$char1$char2" + return dataFrameType } private fun nextName(s: String) = ClassId(CallableId.PACKAGE_FQ_NAME_FOR_LOCAL, FqName(s), true) + private fun Name.asTokenName() = identifierOrNullIfSpecial?.titleCase() ?: DEFAULT_NAME + @OptIn(SymbolInternals::class) - override fun transform(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall { - val (token, dataFrameSchema) = - analyzeRefinedCallShape(call, InterpretationErrorReporter.DEFAULT) ?: return call + private fun buildLetCall( + call: FirFunctionCall, + originalSymbol: FirNamedFunctionSymbol, + dataSchemaApis: List, + additionalDeclarations: List + ): FirFunctionCall { val explicitReceiver = call.explicitReceiver ?: return call val receiverType = explicitReceiver.resolvedType val returnType = call.resolvedType - val resolvedLet = findLet() val parameter = resolvedLet.valueParameterSymbols[0] - - val firstSchema = token.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!! - - data class DataSchemaApi(val schema: FirRegularClass, val scope: FirRegularClass) - - var i = 0 - val dataSchemaApis = mutableListOf() - val usedNames = mutableMapOf() - fun PluginDataFrameSchema.materialize(schema: FirRegularClass? = null, suggestedName: String? = null): DataSchemaApi { - val schema = if (schema != null) { - schema - } else { - requireNotNull(suggestedName) - val uniqueSuffix = usedNames.compute(suggestedName) { _, i -> (i ?: 0) + 1 } - val name = nextName(suggestedName + uniqueSuffix) - buildSchema(name) - } - - val scopeId = ClassId(CallableId.PACKAGE_FQ_NAME_FOR_LOCAL, FqName("Scope${i++}"), true) - val scope = buildRegularClass { - moduleData = session.moduleData - resolvePhase = FirResolvePhase.BODY_RESOLVE - origin = FirDeclarationOrigin.Source - status = FirResolvedDeclarationStatusImpl(Visibilities.Local, Modality.FINAL, EffectiveVisibility.Local) - deprecationsProvider = EmptyDeprecationsProvider - classKind = ClassKind.CLASS - scopeProvider = FirKotlinScopeProvider() - superTypeRefs += FirImplicitAnyTypeRef(null) - - this.name = scopeId.shortClassName - this.symbol = FirRegularClassSymbol(scopeId) - } - - val properties = columns().map { - fun PluginDataFrameSchema.materialize(column: SimpleCol): DataSchemaApi { - val text = call.source?.text ?: call.calleeReference.name - val name = "${column.name.titleCase().replEscapeLineBreaks()}_${hashToTwoCharString(abs(text.hashCode()))}" - return materialize(suggestedName = name) - } - - when (it) { - is SimpleColumnGroup -> { - val nestedSchema = PluginDataFrameSchema(it.columns()).materialize(it) - val columnsContainerReturnType = - ConeClassLikeTypeImpl( - ConeClassLikeLookupTagImpl(Names.COLUM_GROUP_CLASS_ID), - typeArguments = arrayOf(nestedSchema.schema.defaultType()), - isNullable = false - ) - - val dataRowReturnType = - ConeClassLikeTypeImpl( - ConeClassLikeLookupTagImpl(Names.DATA_ROW_CLASS_ID), - typeArguments = arrayOf(nestedSchema.schema.defaultType()), - isNullable = false - ) - - SchemaProperty(schema.defaultType(), it.name, dataRowReturnType, columnsContainerReturnType) - } - - is SimpleFrameColumn -> { - val nestedClassMarker = PluginDataFrameSchema(it.columns()).materialize(it) - val frameColumnReturnType = - ConeClassLikeTypeImpl( - ConeClassLikeLookupTagImpl(Names.DF_CLASS_ID), - typeArguments = arrayOf(nestedClassMarker.schema.defaultType()), - isNullable = false - ) - - SchemaProperty( - marker = schema.defaultType(), - name = it.name, - dataRowReturnType = frameColumnReturnType, - columnContainerReturnType = frameColumnReturnType.toFirResolvedTypeRef().projectOverDataColumnType() - ) - } - - is SimpleDataColumn -> SchemaProperty( - marker = schema.defaultType(), - name = it.name, - dataRowReturnType = it.type.type(), - columnContainerReturnType = it.type.type().toFirResolvedTypeRef().projectOverDataColumnType() - ) - } - } - schema.callShapeData = CallShapeData.Schema(properties) - scope.callShapeData = CallShapeData.Scope(properties) - val schemaApi = DataSchemaApi(schema, scope) - dataSchemaApis.add(schemaApi) - return schemaApi - } - - dataFrameSchema.materialize(firstSchema) - // original call is inserted later call.transformCalleeReference(object : FirTransformer() { override fun transformElement(element: E, data: Nothing?): E { @@ -312,14 +323,11 @@ class FunctionCallTransformer( } }, null) - val tokenFir = token.toClassSymbol(session)!!.fir - tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol }) - val callExplicitReceiver = call.explicitReceiver val callDispatchReceiver = call.dispatchReceiver val callExtensionReceiver = call.extensionReceiver - val argument = buildAnonymousFunctionExpression { + val argument = buildAnonymousFunctionExpression { isTrailingLambda = true val fSymbol = FirAnonymousFunctionSymbol() val target = FirFunctionTarget(null, isLambda = true) @@ -354,7 +362,7 @@ class FunctionCallTransformer( statements += it.scope } - statements += tokenFir + statements += additionalDeclarations statements += buildReturnExpression { val itPropertyAccess = buildPropertyAccessExpression { @@ -421,6 +429,111 @@ class FunctionCallTransformer( return newCall1 } + private fun materialize( + dataFrameSchema: PluginDataFrameSchema, + call: FirFunctionCall, + firstSchema: FirRegularClass, + prefix: String = "", + i: Int = 0 + ): List { + var i = i + val dataSchemaApis = mutableListOf() + val usedNames = mutableMapOf() + fun PluginDataFrameSchema.materialize( + schema: FirRegularClass? = null, + suggestedName: String? = null + ): DataSchemaApi { + val schema = if (schema != null) { + schema + } else { + requireNotNull(suggestedName) + val uniqueSuffix = usedNames.compute(suggestedName) { _, i -> (i ?: 0) + 1 } + val name = nextName(suggestedName + uniqueSuffix) + buildSchema(name) + } + + val scopeId = ClassId(CallableId.PACKAGE_FQ_NAME_FOR_LOCAL, FqName("Scope${i++}"), true) + val scope = buildRegularClass { + moduleData = session.moduleData + resolvePhase = FirResolvePhase.BODY_RESOLVE + origin = FirDeclarationOrigin.Source + status = FirResolvedDeclarationStatusImpl(Visibilities.Local, Modality.FINAL, EffectiveVisibility.Local) + deprecationsProvider = EmptyDeprecationsProvider + classKind = ClassKind.CLASS + scopeProvider = FirKotlinScopeProvider() + superTypeRefs += FirImplicitAnyTypeRef(null) + + this.name = scopeId.shortClassName + this.symbol = FirRegularClassSymbol(scopeId) + } + + val properties = columns().map { + fun PluginDataFrameSchema.materialize(column: SimpleCol): DataSchemaApi { + val text = call.source?.text ?: call.calleeReference.name + val name = + "${column.name.titleCase().replEscapeLineBreaks()}_${hashToTwoCharString(abs(text.hashCode()))}" + return materialize(suggestedName = "$prefix$name") + } + + when (it) { + is SimpleColumnGroup -> { + val nestedSchema = PluginDataFrameSchema(it.columns()).materialize(it) + val columnsContainerReturnType = + ConeClassLikeTypeImpl( + ConeClassLikeLookupTagImpl(Names.COLUM_GROUP_CLASS_ID), + typeArguments = arrayOf(nestedSchema.schema.defaultType()), + isNullable = false + ) + + val dataRowReturnType = + ConeClassLikeTypeImpl( + ConeClassLikeLookupTagImpl(Names.DATA_ROW_CLASS_ID), + typeArguments = arrayOf(nestedSchema.schema.defaultType()), + isNullable = false + ) + + SchemaProperty(schema.defaultType(), it.name, dataRowReturnType, columnsContainerReturnType) + } + + is SimpleFrameColumn -> { + val nestedClassMarker = PluginDataFrameSchema(it.columns()).materialize(it) + val frameColumnReturnType = + ConeClassLikeTypeImpl( + ConeClassLikeLookupTagImpl(Names.DF_CLASS_ID), + typeArguments = arrayOf(nestedClassMarker.schema.defaultType()), + isNullable = false + ) + + SchemaProperty( + marker = schema.defaultType(), + name = it.name, + dataRowReturnType = frameColumnReturnType, + columnContainerReturnType = frameColumnReturnType.toFirResolvedTypeRef() + .projectOverDataColumnType() + ) + } + + is SimpleDataColumn -> SchemaProperty( + marker = schema.defaultType(), + name = it.name, + dataRowReturnType = it.type.type(), + columnContainerReturnType = it.type.type().toFirResolvedTypeRef().projectOverDataColumnType() + ) + } + } + schema.callShapeData = CallShapeData.Schema(properties) + scope.callShapeData = CallShapeData.Scope(properties) + val schemaApi = DataSchemaApi(schema, scope) + dataSchemaApis.add(schemaApi) + return schemaApi + } + + dataFrameSchema.materialize(firstSchema) + return dataSchemaApis + } + + data class DataSchemaApi(val schema: FirRegularClass, val scope: FirRegularClass) + private fun buildSchema(tokenId: ClassId): FirRegularClass { val token = buildRegularClass { moduleData = session.moduleData diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt index 508e34059..f3478b1a9 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt @@ -1,7 +1,6 @@ package org.jetbrains.kotlinx.dataframe.plugin.extensions import org.jetbrains.kotlin.fir.FirSession -import org.jetbrains.kotlinx.dataframe.plugin.utils.Names import org.jetbrains.kotlin.fir.declarations.FirResolvePhase import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId import org.jetbrains.kotlin.fir.expressions.FirFunctionCall @@ -10,32 +9,33 @@ import org.jetbrains.kotlin.fir.scopes.collectAllProperties import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope import org.jetbrains.kotlin.fir.symbols.SymbolInternals import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol -import org.jetbrains.kotlin.fir.types.ConeClassLikeType import org.jetbrains.kotlin.fir.types.ConeKotlinType import org.jetbrains.kotlin.fir.types.classId import org.jetbrains.kotlin.fir.types.resolvedType import org.jetbrains.kotlin.fir.types.toRegularClassSymbol +import org.jetbrains.kotlinx.dataframe.plugin.utils.Names class ReturnTypeBasedReceiverInjector(session: FirSession) : FirExpressionResolutionExtension(session) { - override fun addNewImplicitReceivers(functionCall: FirFunctionCall): List { - val symbol = generatedTokenOrNull(functionCall) ?: return emptyList() - return symbol.declaredMemberScope(session, FirResolvePhase.DECLARATIONS).collectAllProperties() - .filterIsInstance() - .filter { it.getAnnotationByClassId(Names.SCOPE_PROPERTY_ANNOTATION, session) != null } - .map { it.resolvedReturnType } - } - @OptIn(SymbolInternals::class) - private fun generatedTokenOrNull(call: FirFunctionCall): FirRegularClassSymbol? { - val callReturnType = call.resolvedType - if (callReturnType.classId != Names.DF_CLASS_ID) return null - val rootMarker = callReturnType.typeArguments[0] - if (rootMarker !is ConeClassLikeType) { - return null + override fun addNewImplicitReceivers(functionCall: FirFunctionCall): List { + val callReturnType = functionCall.resolvedType + return if (callReturnType.classId in setOf(Names.DF_CLASS_ID, Names.GROUP_BY_CLASS_ID)) { + val typeArguments = callReturnType.typeArguments + typeArguments + .mapNotNull { + val symbol = (it as? ConeKotlinType)?.toRegularClassSymbol(session) + symbol?.takeIf { it.fir.callShapeData != null } + } + .takeIf { it.size == typeArguments.size } + .orEmpty() + .flatMap { marker -> + marker.declaredMemberScope(session, FirResolvePhase.DECLARATIONS).collectAllProperties() + .filterIsInstance() + .filter { it.getAnnotationByClassId(Names.SCOPE_PROPERTY_ANNOTATION, session) != null } + .map { it.resolvedReturnType } + } + } else { + emptyList() } - - val symbol = rootMarker.toRegularClassSymbol(session) - return symbol.takeIf { it?.fir?.callShapeData != null } } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/add.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/add.kt index b8237c93d..8c073fecd 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/add.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/add.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl @@ -57,3 +58,15 @@ class AddWithDsl : AbstractSchemaModificationInterpreter() { return PluginDataFrameSchema(addDsl.columns) } } + +class AddDslStringInvoke : AbstractInterpreter() { + val Arguments.dsl: AddDslApproximation by arg(lens = Interpreter.Value) + val Arguments.receiver: String by string() + val Arguments.body by dsl() + + override fun Arguments.interpret() { + val addDsl = AddDslApproximation(mutableListOf()) + body(addDsl, emptyMap()) + dsl.columns.add(SimpleColumnGroup(receiver, addDsl.columns)) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/addId.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/addId.kt new file mode 100644 index 000000000..6879e9394 --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/addId.kt @@ -0,0 +1,21 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.Present +import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf + +class AddId : AbstractSchemaModificationInterpreter() { + val Arguments.receiver: PluginDataFrameSchema by dataFrame() + val Arguments.columnName: String by arg(defaultValue = Present("id")) + + override fun Arguments.interpret(): PluginDataFrameSchema { + val columns = buildList { + add(simpleColumnOf(columnName, session.builtinTypes.intType.type)) + addAll(receiver.columns()) + } + return PluginDataFrameSchema(columns) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/fillNulls.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/fillNulls.kt index 7474230fe..6f0f14fc4 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/fillNulls.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/fillNulls.kt @@ -16,13 +16,16 @@ class FillNulls0 : AbstractInterpreter() { } } -class FillNullsApproximation(val schema: PluginDataFrameSchema, val columns: ColumnsResolver) +class FillNullsApproximation(val schema: PluginDataFrameSchema, val columns: ColumnsResolver) : UpdateApproximation class UpdateWith0 : AbstractSchemaModificationInterpreter() { - val Arguments.receiver: FillNullsApproximation by arg() + val Arguments.receiver: UpdateApproximation by arg() val Arguments.expression: TypeApproximation by type() override fun Arguments.interpret(): PluginDataFrameSchema { - return convertImpl(receiver.schema, receiver.columns.resolve(receiver.schema).map { it.path.path }, expression) + return when (val receiver = receiver) { + is FillNullsApproximation -> convertImpl(receiver.schema, receiver.columns.resolve(receiver.schema).map { it.path.path }, expression) + is UpdateApproximationImpl -> convertImpl(receiver.schema, receiver.columns.resolve(receiver.schema).map { it.path.path }, expression) + } } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 28b47a037..d3ad1033d 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -3,15 +3,11 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter import org.jetbrains.kotlinx.dataframe.plugin.interpret import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter -import org.jetbrains.kotlinx.dataframe.plugin.pluginDataFrameSchema -import org.jetbrains.kotlinx.dataframe.plugin.utils.Names import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.types.ConeKotlinType -import org.jetbrains.kotlin.fir.types.ConeNullability -import org.jetbrains.kotlin.fir.types.classId import org.jetbrains.kotlin.fir.types.resolvedType import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter @@ -55,21 +51,30 @@ class GroupByInto : AbstractInterpreter() { } } +class Aggregate : AbstractSchemaModificationInterpreter() { + val Arguments.receiver: GroupBy by arg() + val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id) + override fun Arguments.interpret(): PluginDataFrameSchema { + return aggregate( + receiver, + InterpretationErrorReporter.DEFAULT, + body + ) + } +} + fun KotlinTypeFacade.aggregate( - groupByCall: FirFunctionCall, - interpreter: Interpreter<*>, + groupBy: GroupBy, reporter: InterpretationErrorReporter, - call: FirFunctionCall -): PluginDataFrameSchema? { - val groupBy = interpret(groupByCall, interpreter, reporter = reporter)?.value as? GroupBy ?: return null - val aggregate = call.argumentList.arguments.singleOrNull() as? FirAnonymousFunctionExpression - val body = aggregate?.anonymousFunction?.body ?: return null - val lastExpression = (body.statements.lastOrNull() as? FirReturnExpression)?.result + firAnonymousFunctionExpression: FirAnonymousFunctionExpression +): PluginDataFrameSchema { + val body = firAnonymousFunctionExpression.anonymousFunction.body + val lastExpression = (body?.statements?.lastOrNull() as? FirReturnExpression)?.result val type = lastExpression?.resolvedType return if (type != session.builtinTypes.unitType) { val dsl = GroupByDsl() val calls = buildList { - body.statements.filterIsInstance().let { addAll(it) } + body?.statements?.filterIsInstance()?.let { addAll(it) } if (lastExpression is FirFunctionCall) add(lastExpression) } calls.forEach { call -> @@ -87,7 +92,7 @@ fun KotlinTypeFacade.aggregate( } PluginDataFrameSchema(cols) } else { - null + PluginDataFrameSchema(emptyList()) } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/map.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/map.kt new file mode 100644 index 000000000..02f3fe83a --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/map.kt @@ -0,0 +1,18 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl + +class MapToFrame : AbstractSchemaModificationInterpreter() { + val Arguments.receiver: PluginDataFrameSchema by dataFrame() + val Arguments.body by dsl() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val addDsl = AddDslApproximation(mutableListOf()) + body(addDsl, emptyMap()) + return PluginDataFrameSchema(addDsl.columns) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/move.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/move.kt new file mode 100644 index 000000000..e80de799c --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/move.kt @@ -0,0 +1,33 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.api.move +import org.jetbrains.kotlinx.dataframe.api.pathOf +import org.jetbrains.kotlinx.dataframe.api.toTop +import org.jetbrains.kotlinx.dataframe.columns.toColumnSet +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.toPluginDataFrameSchema + +class Move0 : AbstractInterpreter() { + val Arguments.receiver: PluginDataFrameSchema by dataFrame() + val Arguments.columns: ColumnsResolver by arg() + + override fun Arguments.interpret(): MoveClauseApproximation { + return MoveClauseApproximation(receiver, columns) + } +} + +class ToTop : AbstractSchemaModificationInterpreter() { + val Arguments.receiver: MoveClauseApproximation by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val columns = receiver.columns.resolve(receiver.df).map { pathOf(*it.path.path.toTypedArray()) } + return receiver.df.asDataFrame().move { columns.toColumnSet() }.toTop().toPluginDataFrameSchema() + } +} + +class MoveClauseApproximation(val df: PluginDataFrameSchema, val columns: ColumnsResolver) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/update.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/update.kt new file mode 100644 index 000000000..813272478 --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/update.kt @@ -0,0 +1,19 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame + +class Update0 : AbstractInterpreter() { + val Arguments.receiver: PluginDataFrameSchema by dataFrame() + val Arguments.columns: ColumnsResolver by arg() + + override fun Arguments.interpret(): UpdateApproximationImpl { + return UpdateApproximationImpl(receiver, columns) + } +} + +sealed interface UpdateApproximation + +class UpdateApproximationImpl(val schema: PluginDataFrameSchema, val columns: ColumnsResolver) : UpdateApproximation diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt index a03804d14..04cc1fe0b 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt @@ -28,6 +28,7 @@ import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier import org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.expressions.FirThisReceiverExpression import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression +import org.jetbrains.kotlin.fir.expressions.arguments import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList import org.jetbrains.kotlin.fir.references.FirResolvedCallableReference import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference @@ -145,9 +146,19 @@ fun KotlinTypeFacade.interpret( is FirFunctionCall -> { val interpreter = expression.loadInterpreter() - interpreter?.let { - val result = interpret(expression, interpreter, emptyMap(), reporter) - result + if (interpreter == null) { + // if the plugin already transformed call, its original form is the last expression of .let { } + val argument = expression.arguments[0] + val last = (argument as? FirAnonymousFunctionExpression)?.anonymousFunction?.body?.statements?.lastOrNull() + val call = (last as? FirReturnExpression)?.result as? FirFunctionCall + call?.loadInterpreter()?.let { + interpret(call, it, emptyMap(), reporter) + } + } else { + interpreter.let { + val result = interpret(expression, interpreter, emptyMap(), reporter) + result + } } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index ffc938c17..128b9fcb8 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -67,6 +67,9 @@ import org.jetbrains.kotlin.fir.types.classId import org.jetbrains.kotlin.fir.types.coneType import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddId +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Aggregate import org.jetbrains.kotlinx.dataframe.plugin.impl.api.All0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0 @@ -77,12 +80,16 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReadExcel import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDefault import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDsl import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToTop +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Update0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.UpdateWith0 internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? { @@ -194,6 +201,13 @@ internal inline fun String.load(): T { "UpdateWith0" -> UpdateWith0() "Flatten0" -> Flatten0() "FlattenDefault" -> FlattenDefault() + "AddId" -> AddId() + "AddDslStringInvoke" -> AddDslStringInvoke() + "MapToFrame" -> MapToFrame() + "Move0" -> Move0() + "ToTop" -> ToTop() + "Update0" -> Update0() + "Aggregate" -> Aggregate() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt index afc0331e1..7eb9d5971 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt @@ -16,6 +16,9 @@ import kotlin.reflect.KClass object Names { val DF_CLASS_ID: ClassId get() = ClassId.topLevel(FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe", "DataFrame"))) + val GROUP_BY_CLASS_ID: ClassId + get() = ClassId.topLevel(FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe", "api", "GroupBy"))) + val COLUM_GROUP_CLASS_ID: ClassId get() = ClassId(FqName("org.jetbrains.kotlinx.dataframe.columns"), Name.identifier("ColumnGroup")) val DATA_COLUMN_CLASS_ID: ClassId diff --git a/plugins/kotlin-dataframe/testData/box/addDsl.kt b/plugins/kotlin-dataframe/testData/box/addDsl.kt new file mode 100644 index 000000000..2cc4cb84f --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/addDsl.kt @@ -0,0 +1,16 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a")(1).add { + "id" from { it } + "group" { + "a" from { it } + } + } + + df.group.a + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/addId.kt b/plugins/kotlin-dataframe/testData/box/addId.kt new file mode 100644 index 000000000..3c08ba886 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/addId.kt @@ -0,0 +1,11 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a")(1).addId() + val i: DataColumn = df.id + val i1: DataColumn = dataFrameOf("a")(1).addId("i").i + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_refine.kt b/plugins/kotlin-dataframe/testData/box/groupBy_refine.kt new file mode 100644 index 000000000..de22ef265 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_refine.kt @@ -0,0 +1,12 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + + +fun box(): String { + val df = dataFrameOf("a", "b", "c")(1,2,3) + val df1 = df.groupBy { a } + df1.keys.compareSchemas(strict = true) + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/mapToFrame.kt b/plugins/kotlin-dataframe/testData/box/mapToFrame.kt new file mode 100644 index 000000000..2aec74c8b --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/mapToFrame.kt @@ -0,0 +1,19 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a")(1).mapToFrame { + "id" from { it } + "group" { + "a" from { it } + } + } + + df.id + df.group.a + + df.compareSchemas(strict = true) + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/moveToTop.kt b/plugins/kotlin-dataframe/testData/box/moveToTop.kt new file mode 100644 index 000000000..1ed3f934b --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/moveToTop.kt @@ -0,0 +1,12 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a", "b")(1, 2).group { a and b }.into("c").move { c.a }.toTop() + df.a + df.c.b + df.compareSchemas() + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/update.kt b/plugins/kotlin-dataframe/testData/box/update.kt new file mode 100644 index 000000000..fdfdaf290 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/update.kt @@ -0,0 +1,11 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a", "b")(1, null, null, "") + val df1 = df.update { b }.with { "empty" } + val b: DataColumn = df1.b + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index c803cbde7..01049fb35 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -17,6 +17,18 @@ @TestMetadata("testData/box") @TestDataPath("$PROJECT_ROOT") public class DataFrameBlackBoxCodegenTestGenerated extends AbstractDataFrameBlackBoxCodegenTest { + @Test + @TestMetadata("addDsl.kt") + public void testAddDsl() { + runTest("testData/box/addDsl.kt"); + } + + @Test + @TestMetadata("addId.kt") + public void testAddId() { + runTest("testData/box/addId.kt"); + } + @Test public void testAllFilesPresentInBox() { KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("testData/box"), Pattern.compile("^(.+)\\.kt$"), null, TargetBackend.JVM_IR, true); @@ -178,6 +190,12 @@ public void testGroupBy_DataRow() { runTest("testData/box/groupBy_DataRow.kt"); } + @Test + @TestMetadata("groupBy_refine.kt") + public void testGroupBy_refine() { + runTest("testData/box/groupBy_refine.kt"); + } + @Test @TestMetadata("groupBy_toDataFrame.kt") public void testGroupBy_toDataFrame() { @@ -232,6 +250,18 @@ public void testMain() { runTest("testData/box/main.kt"); } + @Test + @TestMetadata("mapToFrame.kt") + public void testMapToFrame() { + runTest("testData/box/mapToFrame.kt"); + } + + @Test + @TestMetadata("moveToTop.kt") + public void testMoveToTop() { + runTest("testData/box/moveToTop.kt"); + } + @Test @TestMetadata("nestedDataSchemaCodegen.kt") public void testNestedDataSchemaCodegen() { @@ -394,6 +424,12 @@ public void testUngroup() { runTest("testData/box/ungroup.kt"); } + @Test + @TestMetadata("update.kt") + public void testUpdate() { + runTest("testData/box/update.kt"); + } + @Nested @TestMetadata("testData/box/colKinds") @TestDataPath("$PROJECT_ROOT")