diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt index 43b529522..105c56aa2 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt @@ -1,3 +1,7 @@ package org.jetbrains.kotlinx.dataframe.aggregation -public abstract class AggregateGroupedDsl : AggregateDsl() +import org.jetbrains.kotlinx.dataframe.AnyRow + +public abstract class AggregateGroupedDsl : AggregateDsl() { + public abstract val keys: AnyRow +} diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt index 960b24b79..1fa61a0b2 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt @@ -122,7 +122,7 @@ internal interface AccessApi { * * For example: * ```kotlin - * val df = DataFrame.read("titanic.csv") + * val df /* : AnyFrame */ = DataFrame.read("titanic.csv") * ``` */ interface ExtensionPropertiesApi diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt index bf58b3a66..6860343d5 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt @@ -48,14 +48,24 @@ internal class GroupByImpl( override fun updateGroups(transform: Selector, DataFrame>) = df.convert(groups) { 
transform(it, it) }.asGroupBy(groups.name()) as GroupBy - override fun toDataFrame(groupedColumnName: String?) = if (groupedColumnName == null || groupedColumnName == groups.name()) df else df.rename(groups).into(groupedColumnName) + override fun toDataFrame(groupedColumnName: String?) = + if (groupedColumnName == null || groupedColumnName == groups.name()) { + df + } else { + df.rename(groups).into(groupedColumnName) + } override fun toString() = df.toString() override fun remainingColumnsSelector(): ColumnsSelector<*, *> = keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } } - override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy(toDataFrame(), { groups }, removeColumns = true, body).cast() + override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy( + df = toDataFrame(), + selector = { groups }, + removeColumns = true, + body = body, + ).cast() override fun filter(predicate: GroupedRowFilter): GroupBy { val indices = (0 until df.nrow).filter { @@ -78,12 +88,13 @@ internal fun aggregateGroupBy( val removed = df.removeImpl(columns = selector) - val hasKeyColumns = removed.df.ncol > 0 + val keys = removed.df + val hasKeyColumns = keys.ncol > 0 - val groupedFrame = column.values.map { + val groupedFrame = column.values.mapIndexed { i, it -> if (it == null) null else { - val builder = GroupByReceiverImpl(it, hasKeyColumns) + val builder = GroupByReceiverImpl(it, hasKeyColumns) { keys[i] } val result = body(builder, builder) if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield( NamedValue.create( diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt index 9b857c254..5860b29a6 100644 --- 
a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt @@ -18,12 +18,22 @@ import org.jetbrains.kotlinx.dataframe.impl.createTypeWithArgument import org.jetbrains.kotlinx.dataframe.impl.getListType import kotlin.reflect.KType -internal class GroupByReceiverImpl(override val df: DataFrame, override val hasGroupingKeys: Boolean) : +internal class GroupByReceiverImpl( + override val df: DataFrame, + override val hasGroupingKeys: Boolean, + private val retrieveKey: () -> AnyRow = { + error("This property can only be used inside 'groupBy { }.aggregate { }' clause") + } +) : AggregateGroupedDsl(), AggregateInternalDsl, AggregatableInternal by df as AggregatableInternal, DataFrame by df { + override val keys by lazy { + retrieveKey() + } + private val values = mutableListOf() internal fun child(): GroupByReceiverImpl { @@ -41,16 +51,41 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va allValues.add(it) } } + is ValueColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + is ColumnGroup<*> -> { val frameType = it.value.type().arguments.singleOrNull()?.type - allValues.add(NamedValue.create(it.path, it.value.asDataFrame(), DataFrame::class.createTypeWithArgument(frameType), DataFrame.Empty)) + allValues.add( + NamedValue.create( + it.path, + it.value.asDataFrame(), + DataFrame::class.createTypeWithArgument(frameType), + DataFrame.Empty + ) + ) } + is FrameColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + 
emptyList() + ) + ) } + else -> { allValues.add(it) } @@ -70,7 +105,9 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va when (value.value) { is AggregatedPivot<*> -> { val pivot = value.value - val dropFirstNameInPath = pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }.count() == 1 + val dropFirstNameInPath = + pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() } + .count() == 1 pivot.aggregator.values.forEach { val targetPath = if (dropFirstNameInPath && it.path.size > 0) value.path + it.path.dropFirst() @@ -80,6 +117,7 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va } pivot.aggregator.values.clear() } + is AggregateInternalDsl<*> -> yield(value.copy(value = value.value.df)) else -> values.add(value) } diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index 4f7e78748..59770fa72 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -55,4 +55,21 @@ class GroupByTests { getFrameColumn("d") into "e" }["e"].type() shouldBe typeOf>() } + + @Test + fun `aggregate based on the key column`() { + val df = dataFrameOf( + "a", "b", "c" + )( + 1, 2, 3, + 4, 5, 6, + ) + val grouped = df.groupBy { expr("test") { "a"() + "b"() } } + .aggregate { + count() into "count" + keys into "keys" + } + + grouped.print() + } } diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt index f37615b63..f579f53bd 100644 --- 
a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt @@ -140,7 +140,7 @@ class ApiLevels { @TransformDataFrameExpressions fun extensionProperties1() { // SampleStart - val df = DataFrame.read("titanic.csv") + val df /* : AnyFrame */ = DataFrame.read("titanic.csv") // SampleEnd } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt index 43b529522..105c56aa2 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt @@ -1,3 +1,7 @@ package org.jetbrains.kotlinx.dataframe.aggregation -public abstract class AggregateGroupedDsl : AggregateDsl() +import org.jetbrains.kotlinx.dataframe.AnyRow + +public abstract class AggregateGroupedDsl : AggregateDsl() { + public abstract val keys: AnyRow +} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt index bf58b3a66..6860343d5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt @@ -48,14 +48,24 @@ internal class GroupByImpl( override fun updateGroups(transform: Selector, DataFrame>) = df.convert(groups) { transform(it, it) }.asGroupBy(groups.name()) as GroupBy - override fun toDataFrame(groupedColumnName: String?) = if (groupedColumnName == null || groupedColumnName == groups.name()) df else df.rename(groups).into(groupedColumnName) + override fun toDataFrame(groupedColumnName: String?) 
= + if (groupedColumnName == null || groupedColumnName == groups.name()) { + df + } else { + df.rename(groups).into(groupedColumnName) + } override fun toString() = df.toString() override fun remainingColumnsSelector(): ColumnsSelector<*, *> = keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } } - override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy(toDataFrame(), { groups }, removeColumns = true, body).cast() + override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy( + df = toDataFrame(), + selector = { groups }, + removeColumns = true, + body = body, + ).cast() override fun filter(predicate: GroupedRowFilter): GroupBy { val indices = (0 until df.nrow).filter { @@ -78,12 +88,13 @@ internal fun aggregateGroupBy( val removed = df.removeImpl(columns = selector) - val hasKeyColumns = removed.df.ncol > 0 + val keys = removed.df + val hasKeyColumns = keys.ncol > 0 - val groupedFrame = column.values.map { + val groupedFrame = column.values.mapIndexed { i, it -> if (it == null) null else { - val builder = GroupByReceiverImpl(it, hasKeyColumns) + val builder = GroupByReceiverImpl(it, hasKeyColumns) { keys[i] } val result = body(builder, builder) if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield( NamedValue.create( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt index 9b857c254..5860b29a6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt @@ -18,12 +18,22 @@ import org.jetbrains.kotlinx.dataframe.impl.createTypeWithArgument import org.jetbrains.kotlinx.dataframe.impl.getListType import kotlin.reflect.KType -internal class GroupByReceiverImpl(override val df: DataFrame, 
override val hasGroupingKeys: Boolean) : +internal class GroupByReceiverImpl( + override val df: DataFrame, + override val hasGroupingKeys: Boolean, + private val retrieveKey: () -> AnyRow = { + error("This property can only be used inside 'groupBy { }.aggregate { }' clause") + } +) : AggregateGroupedDsl(), AggregateInternalDsl, AggregatableInternal by df as AggregatableInternal, DataFrame by df { + override val keys by lazy { + retrieveKey() + } + private val values = mutableListOf() internal fun child(): GroupByReceiverImpl { @@ -41,16 +51,41 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va allValues.add(it) } } + is ValueColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + is ColumnGroup<*> -> { val frameType = it.value.type().arguments.singleOrNull()?.type - allValues.add(NamedValue.create(it.path, it.value.asDataFrame(), DataFrame::class.createTypeWithArgument(frameType), DataFrame.Empty)) + allValues.add( + NamedValue.create( + it.path, + it.value.asDataFrame(), + DataFrame::class.createTypeWithArgument(frameType), + DataFrame.Empty + ) + ) } + is FrameColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + else -> { allValues.add(it) } @@ -70,7 +105,9 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va when (value.value) { is AggregatedPivot<*> -> { val pivot = value.value - val dropFirstNameInPath = pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }.count() == 1 + val dropFirstNameInPath = + pivot.inward == true && value.path.isNotEmpty() && 
pivot.aggregator.values.distinctBy { it.path.firstOrNull() } + .count() == 1 pivot.aggregator.values.forEach { val targetPath = if (dropFirstNameInPath && it.path.size > 0) value.path + it.path.dropFirst() @@ -80,6 +117,7 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va } pivot.aggregator.values.clear() } + is AggregateInternalDsl<*> -> yield(value.copy(value = value.value.df)) else -> values.add(value) } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index 4f7e78748..59770fa72 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -55,4 +55,21 @@ class GroupByTests { getFrameColumn("d") into "e" }["e"].type() shouldBe typeOf>() } + + @Test + fun `aggregate based on the key column`() { + val df = dataFrameOf( + "a", "b", "c" + )( + 1, 2, 3, + 4, 5, 6, + ) + val grouped = df.groupBy { expr("test") { "a"() + "b"() } } + .aggregate { + count() into "count" + keys into "keys" + } + + grouped.print() + } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt index f37615b63..f579f53bd 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt @@ -140,7 +140,7 @@ class ApiLevels { @TransformDataFrameExpressions fun extensionProperties1() { // SampleStart - val df = DataFrame.read("titanic.csv") + val df /* : AnyFrame */ = DataFrame.read("titanic.csv") // SampleEnd } } diff --git a/docs/StardustDocs/topics/extensionPropertiesApi.md b/docs/StardustDocs/topics/extensionPropertiesApi.md index 00bfba268..ccbe2a08a 100644 --- a/docs/StardustDocs/topics/extensionPropertiesApi.md +++ 
b/docs/StardustDocs/topics/extensionPropertiesApi.md @@ -2,19 +2,30 @@ -When [`DataFrame`](DataFrame.md) is used within Jupyter Notebooks or Datalore with Kotlin Kernel, -after every cell execution all new global variables of type DataFrame are analyzed and replaced -with typed [`DataFrame`](DataFrame.md) wrapper with auto-generated extension properties for data access: +When [`DataFrame`](DataFrame.md) is used within Jupyter/Kotlin Notebook or Datalore with the Kotlin Kernel, +something special happens: +After every cell execution, all new global variables of type DataFrame are analyzed and replaced +with a typed [`DataFrame`](DataFrame.md) wrapper along with auto-generated extension properties for data access. +For instance, say we run: ```kotlin -val df = DataFrame.read("titanic.csv") +val df /* : AnyFrame */ = DataFrame.read("titanic.csv") ``` -Now data can be accessed by `.` member accessor +In normal Kotlin code, we would now have a variable of type [`AnyFrame` (=`DataFrame<*>`)](DataFrame.md) that doesn't have any +extension properties to access its columns. We would either have to define them manually or use the +[`@DataSchema`](schemas.md) annotation to [generate them](schemasGradle.md#configuration). + +By contrast, after this cell is run in a notebook, the columns of the dataframe are used as a basis +to generate a hidden `@DataSchema interface TypeX`, +along with extension properties like `val DataFrame<TypeX>.age` etc. +Next, the `df` variable is shadowed by a new version cast to `DataFrame<TypeX>`. + +As a result, columns can now be accessed directly on `df`! @@ -28,12 +39,9 @@ df.add("lastName") { name.split(",").last() } The `titanic.csv` file could be found [here](https://github.com/Kotlin/dataframe/blob/master/data/titanic.csv). -In notebooks, extension properties are generated for [`DataSchema`](schemas.md) that is extracted from [`DataFrame`](DataFrame.md) -instance after REPL line execution. 
-After that [`DataFrame`](DataFrame.md) variable is typed with its own [`DataSchema`](schemas.md), so only valid extension properties corresponding to actual columns in DataFrame will be allowed by the compiler and suggested by completion. - Extension properties can be generated in IntelliJ IDEA using the [Kotlin Dataframe Gradle plugin](schemasGradle.md#configuration). -In notebooks generated properties won't appear and be updated until the cell has been executed. It often means that you have to introduce new variable frequently to sync extension properties with actual schema +In notebooks, generated properties won't appear or be updated until the cell has been executed. +This often means that you have to introduce a new variable frequently to sync extension properties with the actual schema.