Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use in-place commands when possible to increase speed #192

Merged
merged 10 commits into from
Sep 20, 2024
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Kezdi"
uuid = "48308a23-c29e-446c-b4c0-d9446a767439"
authors = ["Miklos Koren <[email protected]>", "Gergely Attila Kiss <[email protected]>"]
version = "0.5.1"
version = "0.5.2"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand All @@ -21,6 +22,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
BenchmarkTools = "1"
CSV = "0.10"
Crayons = "4"
DataFrames = "1"
Expand Down
24 changes: 20 additions & 4 deletions docs/examples/benchmark.do
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,27 @@ gen i = _n
set seed 12345
gen g = floor(runiform() * 100)

timer clear 1
preserve
timer on 1
generate ln_i = log(i)
timer off 1
restore
timer list 1

timer clear 1
preserve
timer on 1
replace g = 2*i
timer off 1
restore
timer list 1

* Measure time for mean calculation by group
timer clear 1
preserve
timer on 1
egen mean_i = mean(i), by(g)
egen mean_i = mean(i), by(g)
timer off 1
restore
timer list 1
Expand All @@ -23,7 +39,7 @@ timer list 1
preserve
timer clear 3
timer on 3
collapse (mean) mean_i=i, by(g)
collapse (mean) mean_i=i, by(g)
timer off 3
restore
timer list 3
Expand All @@ -38,15 +54,15 @@ timer list 5
* Measure time for summarize
timer clear 7
timer on 7
summarize g, detail
summarize g, detail
timer off 7
timer list 7

* Measure time for regress with condition
preserve
timer clear 9
timer on 9
regress i g if g > 50
regress i g if g > 50
timer off 9
restore
timer list 9
12 changes: 12 additions & 0 deletions docs/examples/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ using Pkg; Pkg.precompile()
df = DataFrame(i = 1:10_000_000)
df.g = rand(0:99, nrow(df))


println("Generate")
setdf(df)
@time @generate ln_i = log(i)
setdf(df)
@time @generate ln_i = log(i)

setdf(df)

println("Replace")
@btime @replace g = 2*i

println("Egen")
@btime @with df @egen mean_i = mean(i), by(g)

Expand Down
16 changes: 9 additions & 7 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ end
### Free and open-source
### Speed

| Command | Stata | Julia 1st run | Julia 2nd run | Speedup |
| ------------ | ----- | ------------- | ------------- | ------- |
| `@egen` | 4.90s | 1.36s | 0.36s | 14x |
| `@collapse` | 0.92s | 0.39s | 0.28s | 3x |
| `@tabulate` | 2.14s | 0.68s | 0.09s | 24x |
| `@summarize` | 10.40s | 0.58s | 0.36s | 29x |
| `@regress` | 0.89s | 1.95s | 0.11s | 8x |
| Command | Stata | Julia 2nd run | Speedup |
| ------------ | ----- | ------------- | ------- |
| `@generate` | 230ms | 46ms | 5x |
| `@replace` | 232ms | 32ms | 7x |
| `@egen` | 5.00s | 0.37s | 13x |
| `@collapse` | 0.94s | 0.28s | 3x |
| `@tabulate` | 2.19s | 0.09s | 24x |
| `@summarize` | 10.56s | 0.35s | 30x |
| `@regress` | 0.85s | 0.14s | 6x |

See the benchmarking code for [Stata](https://github.com/codedthinking/Kezdi.jl/blob/main/docs/examples/benchmark.do) and [Kezdi.jl](https://github.com/codedthinking/Kezdi.jl/blob/main/docs/examples/benchmark.jl).

Expand Down
2 changes: 1 addition & 1 deletion src/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function generate_command(command::Command; options=[], allowed=[])

push!(setup, :(println("$(Kezdi.prompt())$($(string(command)))\n")))
push!(setup, :(getdf() isa AbstractDataFrame || error("Kezdi.jl commands can only operate on a global DataFrame set by setdf()")))
push!(setup, :(local $df2 = copy(getdf())))
push!(setup, :(local $df2 = Kezdi._global_dataframe))
variables_condition = (:ifable in options) ? vcat(extract_column_references(command.condition)...) : Symbol[]
variables_RHS = (:variables in options) ? vcat(extract_column_references.(command.arguments)...) : Symbol[]
variables = vcat(variables_condition, variables_RHS)
Expand Down
16 changes: 8 additions & 8 deletions src/commands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function rewrite(::Val{:rename}, command::Command)
quote
(length($arguments) != 2) && ArgumentError("Syntax is @rename oldname newname") |> throw
$setup
rename!($local_copy, $arguments[1] => $arguments[2]) |> $teardown |> setdf
rename!($local_copy, $arguments[1] => $arguments[2]) |> $teardown
end |> esc
end

Expand All @@ -21,7 +21,7 @@ function rewrite(::Val{:generate}, command::Command)
$setup
$local_copy[!, $target_column] .= missing
$target_df[!, $target_column] .= $RHS
$local_copy |> $teardown |> setdf
$local_copy |> $teardown
end |> esc
end

Expand All @@ -47,7 +47,7 @@ function rewrite(::Val{:replace}, command::Command)
else
$target_df[!, $target_column] .= $RHS
end
$local_copy |> $teardown |> setdf
$local_copy |> $teardown
end |> esc
end

Expand All @@ -67,7 +67,7 @@ function rewrite(::Val{:drop}, command::Command)
if isnothing(command.condition)
return quote
$setup
select($local_copy, Not(collect($(command.arguments)))) |> $teardown |> setdf
select!($local_copy, Not(collect($(command.arguments)))) |> $teardown |> setdf
end |> esc
end
bitmask = build_bitmask(local_copy, command.condition)
Expand Down Expand Up @@ -96,7 +96,7 @@ function rewrite(::Val{:egen}, command::Command)
($target_column in names(getdf())) && ArgumentError("Column \"$($target_column)\" already exists in $(names(getdf()))") |> throw
$setup
$transform_expression
$local_copy |> $teardown |> setdf
$local_copy |> $teardown
end |> esc
end

Expand All @@ -107,7 +107,7 @@ function rewrite(::Val{:sort}, command::Command)
desc = :desc in get_top_symbol.(options) ? true : false
quote
$setup
sort($target_df, $columns, rev=$desc) |> $teardown |> setdf
sort!($target_df, $columns, rev=$desc) |> $teardown
end |> esc
end

Expand Down Expand Up @@ -173,7 +173,7 @@ function rewrite(::Val{:order}, command::Command)
$cols = pushfirst!($cols, $target_cols...)
end

$target_df[!, $cols]|> $teardown |> setdf
$target_df[!, $cols]|> $teardown
end |> esc
end

Expand Down Expand Up @@ -202,6 +202,6 @@ function rewrite(::Val{:mvencode}, command::Command)
end
end
$local_copy[$bitmask, $cols] = mvreplace.($local_copy[$bitmask, $cols], $value)
$local_copy |> $teardown |> setdf
$local_copy |> $teardown
end |> esc
end
3 changes: 2 additions & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use(fname::AbstractString) = readstat(fname) |> DataFrame |> setdf
save(fname::AbstractString) = writestat(fname, getdf())

function append(fname::AbstractString)
ispath(fname) || ArgumentError("File $fname does not exist.") |> throw
_, ext = splitext(fname)
Expand Down Expand Up @@ -48,7 +49,7 @@ getdf() = _global_dataframe

Set the global data frame.
"""
setdf(df::Union{AbstractDataFrame, Nothing}) = global _global_dataframe = df
setdf(df::Union{AbstractDataFrame, Nothing}) = global _global_dataframe = isnothing(df) ? nothing : copy(df)
display_and_return(x) = (display(x); x)

"""
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test
using Expronicon
using Kezdi
using Logging
using BenchmarkTools

macro return_arguments(expr)
return (expr,)
Expand Down Expand Up @@ -42,4 +43,8 @@ end
@testset "Functions" begin
include("functions.jl")
end

@testset "Speed" begin
include("speed.jl")
end
end # all tests
20 changes: 20 additions & 0 deletions test/speed.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@testset "Generate completes within 30 seconds" begin
df = DataFrame(rand(20_000_000, 20), :auto)

t = @benchmark let df = $df
@with df begin
@generate ln_x1 = log(x1)
@generate ln_x2 = log(x2)
@generate ln_x3 = log(x3)
@generate ln_x4 = log(x4)
@generate ln_x5 = log(x5)
@generate ln_x6 = log(x6)
@generate ln_x7 = log(x7)
@generate ln_x8 = log(x8)
@generate ln_x9 = log(x9)
end
end

time = median(t).time / 1e9
@test time < 30.0
end
Binary file modified test/test.dta
Binary file not shown.
Loading