Skip to content

Commit

Permalink
Merge pull request #851 from Kotlin/compiler-plugin-functions-1
Browse files Browse the repository at this point in the history
Compiler plugin update
  • Loading branch information
koperagen authored Sep 3, 2024
2 parents ea7dbe8 + a61f4a7 commit e32d6ac
Show file tree
Hide file tree
Showing 26 changed files with 580 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ public class AddDsl<T>(

public infix fun <R> ColumnReference<R>.into(column: KProperty<R>): Boolean = into(column.name)

@Interpretable("AddDslStringInvoke")
public operator fun String.invoke(body: AddDsl<T>.() -> Unit): Unit = group(this, body)

public infix fun AnyColumnGroupAccessor.from(body: AddDsl<T>.() -> Unit): Unit = group(this, body)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.api
import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor

// region DataColumn
Expand All @@ -15,6 +17,8 @@ public fun AnyCol.addId(columnName: String = "id"): AnyFrame = toDataFrame().add

public fun <T> DataFrame<T>.addId(column: ColumnAccessor<Int>): DataFrame<T> = insert(column) { index() }.at(0)

@Refine
@Interpretable("AddId")
public fun <T> DataFrame<T>.addId(columnName: String = "id"): DataFrame<T> = insert(columnName) { index() }.at(0)

// endregion
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import kotlin.reflect.KProperty
*
* `df.add("columnName") { "someColumn"<Int>() + 15 }.groupBy("columnName")`
*/
@Refine
@Interpretable("DataFrameGroupBy")
public fun <T> DataFrame<T>.groupBy(moveToTop: Boolean = true, cols: ColumnsSelector<T, *>): GroupBy<T, T> =
groupByImpl(moveToTop, cols)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.Selector
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.impl.columnName
Expand Down Expand Up @@ -102,6 +104,8 @@ public fun <T, R> ColumnsContainer<T>.mapToColumn(
body: AddExpression<T, R>,
): DataColumn<R> = mapToColumn(column.columnName, type, infer, body)

@Refine
@Interpretable("MapToFrame")
public fun <T> DataFrame<T>.mapToFrame(body: AddDsl<T>.() -> Unit): AnyFrame {
val dsl = AddDsl(this)
body(dsl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import org.jetbrains.kotlinx.dataframe.AnyColumnReference
import org.jetbrains.kotlinx.dataframe.ColumnSelector
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
Expand All @@ -18,6 +20,7 @@ import kotlin.reflect.KProperty

// region move

@Interpretable("Move0")
public fun <T, C> DataFrame<T>.move(columns: ColumnsSelector<T, C>): MoveClause<T, C> = MoveClause(this, columns)

public fun <T> DataFrame<T>.move(vararg cols: String): MoveClause<T, Any?> = move { cols.toColumnSet() }
Expand Down Expand Up @@ -120,6 +123,8 @@ public fun <T, C> MoveClause<T, C>.under(

public fun <T, C> MoveClause<T, C>.to(columnIndex: Int): DataFrame<T> = moveTo(columnIndex)

@Refine
@Interpretable("ToTop")
public fun <T, C> MoveClause<T, C>.toTop(
newColumnName: ColumnsSelectionDsl<T>.(ColumnWithPath<C>) -> String = { it.name() },
): DataFrame<T> = into { newColumnName(it).toColumnAccessor() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ private interface UpdateWithNote
* @include [SelectingColumns.Dsl.WithExample] {@include [SetSelectingColumnsOperationArg]}
* @include [Update.DslParam]
*/
@Interpretable("Update0")
public fun <T, C> DataFrame<T>.update(columns: ColumnsSelector<T, C>): Update<T, C> = Update(this, null, columns)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,33 @@ import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID

fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? {
internal inline fun <reified T> KotlinTypeFacade.analyzeRefinedCallShape(
call: FirFunctionCall,
expectedReturnType: ClassId,
reporter: InterpretationErrorReporter
): CallResult<T>? {
val callReturnType = call.resolvedType
if (callReturnType.classId != DF_CLASS_ID) return null
val rootMarker = callReturnType.typeArguments[0]
if (callReturnType.classId != expectedReturnType) return null
// rootMarker is expected to be a token generated by the plugin.
// it's implied by "refined call"
// thus ConeClassLikeType
if (rootMarker !is ConeClassLikeType) {
return null
}
val rootMarkers = callReturnType.typeArguments.filterIsInstance<ConeClassLikeType>()
if (rootMarkers.size != callReturnType.typeArguments.size) return null

val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
val newSchema: T = call.interpreterName(session)?.let { name ->
when (name) {
"Aggregate" -> {
val groupByCall = call.explicitReceiver as? FirFunctionCall
val interpreter = groupByCall?.loadInterpreter(session)
if (interpreter != null) {
aggregate(groupByCall, interpreter, reporter, call)
} else {
PluginDataFrameSchema(emptyList())
}
}
else -> name.load<Interpreter<*>>().let { processor ->
val dataFrameSchema = interpret(call, processor, reporter = reporter)
.let {
val value = it?.value
if (value !is PluginDataFrameSchema) {
if (value !is T) {
if (!reporter.errorReported) {
reporter.reportInterpretationError(call, "${processor::class} must return ${PluginDataFrameSchema::class}, but was ${value}")
reporter.reportInterpretationError(call, "${processor::class} must return ${T::class}, but was $value")
}
return null
}
Expand All @@ -58,10 +49,10 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In
}
} ?: return null

return CallResult(rootMarker, newSchema)
return CallResult(rootMarkers, newSchema)
}

data class CallResult(val rootMarker: ConeClassLikeType, val newSchema: PluginDataFrameSchema)
data class CallResult<T>(val markers: List<ConeClassLikeType>, val result: T)

class RefinedArguments(val refinedArguments: List<RefinedArgument>) : List<RefinedArgument> by refinedArguments

Expand Down
Loading

0 comments on commit e32d6ac

Please sign in to comment.