diff --git a/src/Kezdi.jl b/src/Kezdi.jl index 9348046..efbea14 100644 --- a/src/Kezdi.jl +++ b/src/Kezdi.jl @@ -2,9 +2,9 @@ Kezdi.jl is a Julia package for data manipulation and analysis. It is inspired by Stata, but it is written in Julia, which makes it faster and more flexible. It is designed to be used in the Julia REPL, but it can also be used in Jupyter notebooks or in scripts. """ module Kezdi -export @generate, @replace, @egen, @collapse, @keep, @drop, @summarize, @regress, @use, @tabulate, @count, @sort, @order, @list, @head, @tail, @names, @rename, @clear, @describe +export @generate, @replace, @egen, @collapse, @keep, @drop, @summarize, @regress, @use, @tabulate, @count, @sort, @order, @list, @head, @tail, @names, @rename, @clear, @describe, @mvencode, @save -export getdf, setdf, display_and_return, keep_only_values, rowcount, distinct, cond +export getdf, setdf, display_and_return, keep_only_values, rowcount, distinct, cond, mvreplace using Reexport using Logging diff --git a/src/commands.jl b/src/commands.jl index abaf4f5..500b1c5 100644 --- a/src/commands.jl +++ b/src/commands.jl @@ -31,16 +31,18 @@ function rewrite(::Val{:replace}, command::Command) target_column = get_LHS(command.arguments[1]) LHS, RHS = split_assignment(arguments[1]) third_vector = gensym() + eltype_LHS = gensym() + eltype_RHS = gensym() bitmask = build_bitmask(local_copy, command.condition) quote !($target_column in names(getdf())) && ArgumentError("Column \"$($target_column)\" does not exist in $(names(getdf()))") |> throw $setup - eltype_RHS = $RHS isa AbstractVector ? eltype($RHS) : typeof($RHS) - eltype_LHS = eltype($local_copy[.!$bitmask, $target_column]) - if eltype_RHS != eltype_LHS - local $third_vector = Vector{promote_type(eltype_LHS, eltype_RHS)}(undef, nrow($local_copy)) + $eltype_RHS = $RHS isa AbstractVector ? 
eltype($RHS) : typeof($RHS) + $eltype_LHS = eltype($local_copy[.!$bitmask, $target_column]) + if $eltype_RHS != $eltype_LHS + local $third_vector = Vector{promote_type($eltype_LHS, $eltype_RHS)}(undef, nrow($local_copy)) $third_vector[$bitmask] .= $RHS - $third_vector[.!$bitmask] .= $local_copy[.!$bitmask, $target_column] + $third_vector[.!($bitmask)] .= $local_copy[.!($bitmask), $target_column] $local_copy[!, $target_column] = $third_vector else $target_df[!, $target_column] .= $RHS @@ -52,9 +54,10 @@ end function rewrite(::Val{:keep}, command::Command) gc = generate_command(command; options=[:variables, :ifable, :nofunction]) (; local_copy, target_df, setup, teardown, arguments, options) = gc + cols = isempty(command.arguments) ? :(:) : :(collect($command.arguments)) quote $setup - $target_df[!, isempty($(command.arguments)) ? eval(:(:)) : collect($command.arguments)] |> $teardown |> setdf + $target_df[!, $cols] |> $teardown |> setdf end |> esc end @@ -139,36 +142,62 @@ function rewrite(::Val{:order}, command::Command) ArgumentError("Only one variable can be specified for `before` or `after` options in @order") |> throw end + target_cols = :(collect($(command.arguments))) + cols = gensym() + idx = gensym() quote $setup - target_cols = collect($(command.arguments)) - cols = [Symbol(col) for col in names($target_df) if Symbol(col) ∉ target_cols] + $cols = [Symbol(col) for col in names($target_df) if Symbol(col) ∉ $target_cols] if $alphabetical - cols = sort(cols, rev = $desc) + $cols = sort($cols, rev = $desc) end if $after - idx = findfirst(x -> x == $var[1], cols) - for (i,col) in enumerate(target_cols) - insert!(cols, idx + i, col) + $idx = findfirst(x -> x == $var[1], $cols) + for (i, col) in enumerate($target_cols) + insert!($cols, $idx + i, col) end end if $before - idx = findfirst(x -> x == $var[1], cols) - for (i,col) in enumerate(target_cols) - insert!(cols, idx + i - 1, col) + $idx = findfirst(x -> x == $var[1], $cols) + for (i, col) in 
enumerate($target_cols) + insert!($cols, $idx + i - 1, col) end end if $last && !($after || $before) - cols = push!(cols, target_cols...) + $cols = push!($cols, $target_cols...) elseif !($after || $before) - cols = pushfirst!(cols, target_cols...) + $cols = pushfirst!($cols, $target_cols...) end - $target_df[!,cols]|> $teardown |> setdf + $target_df[!, $cols]|> $teardown |> setdf end |> esc end +function rewrite(::Val{:mvencode}, command::Command) + gc = generate_command(command; options=[:variables, :ifable, :nofunction], allowed=[:mv]) + (; local_copy, target_df, setup, teardown, arguments, options) = gc + cols = :(collect($command.arguments)) + value = isnothing(get_option(command, :mv)) ? missing : get_option(command, :mv)[1] + value isa AbstractVector && ArgumentError("The value for @mvencode cannot be a vector") |> throw + bitmask = build_bitmask(local_copy, command.condition) + third_vector = gensym() + valtype = gensym() + coltype = gensym() + quote + $setup + $valtype = typeof($value) + for col in $cols + $coltype = eltype($local_copy[.!($bitmask), col]) + if $valtype != $coltype + local $third_vector = Vector{promote_type($coltype, $valtype)}($local_copy[!, col]) + $local_copy[!, col] = $third_vector + end + end + $local_copy[$bitmask, $cols] = mvreplace.($local_copy[$bitmask, $cols], $value) + $local_copy |> $teardown |> setdf + end |> esc +end \ No newline at end of file diff --git a/src/functions.jl b/src/functions.jl index 880f22c..1cc6d29 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -1,4 +1,5 @@ use(fname::AbstractString) = readstat(fname) |> DataFrame |> setdf +save(fname::AbstractString) = writestat(fname, getdf()) """ getdf() -> AbstractDataFrame @@ -110,4 +111,6 @@ function _describe(df::AbstractDataFrame, cols::Vector{Symbol}=Symbol[]) table = isempty(cols) ? 
describe(df) : describe(df[!, cols]) table.eltype = nonmissingtype.(table.eltype) table[!, [:variable, :eltype]] -end \ No newline at end of file +end + +mvreplace(x, y) = ismissing(x) ? y : x \ No newline at end of file diff --git a/src/macros.jl b/src/macros.jl index f95f1d0..ca289a7 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -17,7 +17,7 @@ end """ @drop y1 y2 ... or - @drop if condition] + @drop [@if condition] Drop the variables `y1`, `y2`, etc. from `df`. If `condition` is provided, the rows for which the condition is true are dropped. """ @@ -91,7 +91,7 @@ macro tabulate(exprs...) end """ - @count if condition] + @count [@if condition] Count the number of rows for which the condition is true. If `condition` is not provided, the total number of rows is counted. """ @@ -100,7 +100,7 @@ macro count(exprs...) end """ - @sort y1 y2 ...[, desc] + @sort y1 y2 ... , [desc] Sort the data frame by the variables `y1`, `y2`, etc. By default, the variables are sorted in ascending order. If `desc` is provided, the variables are sorted in descending order """ @@ -109,7 +109,7 @@ macro sort(exprs...) end """ - @order y1 y2 ... [desc] [last] [after=var] [before=var] [alphabetical] + @order y1 y2 ... , [desc] [last] [after=var] [before=var] [alphabetical] Reorder the variables `y1`, `y2`, etc. in the data frame. By default, the variables are ordered in the order they are listed. If `desc` is provided, the variables are ordered in descending order. If `last` is provided, the variables are moved to the end of the data frame. If `after` is provided, the variables are moved after the variable `var`. If `before` is provided, the variables are moved before the variable `var`. If `alphabetical` is provided, the variables are ordered alphabetically. """ @@ -128,7 +128,7 @@ end """ - @use "filename.dta"[, clear] + @use "filename.dta", [clear] Read the data from the file `filename.dta` and set it as the global data frame. 
If there is already a global data frame, `@use` will throw an error unless the `clear` option is provided """ @@ -144,6 +144,15 @@ macro use(exprs...) :(println("$(Kezdi.prompt())$($command)\n");Kezdi.use($fname)) |> esc end +macro save(exprs...) + command = parse(exprs, :save) + length(command.arguments) == 1 || ArgumentError("@save takes a single file name as an argument:\n@save \"filename.dta\"") |> throw + isnothing(getdf()) && ArgumentError("There is no data frame to save.") |> throw + fname = command.arguments[1] + replace = :replace in command.options + ispath(fname) && !replace && ArgumentError("File $fname already exists.") |> throw + :(println("$(Kezdi.prompt())$($command)\n");Kezdi.save($fname)) |> esc +end """ @head [n] @@ -198,3 +207,12 @@ macro describe(exprs...) :describe |> parse(exprs) |> rewrite end + +""" + @mvencode y1 y2 ... [@if condition], [mv(value)] + +Encode missing values in the variables `y1`, `y2`, etc. in the data frame. If `condition` is provided, the operation is executed only on rows for which the condition is true. If `mv` is provided, the missing values are encoded with the value `value`. The default value is `missing`, which leaves the data frame unchanged. +""" +macro mvencode(exprs...) + :mvencode |> parse(exprs) |> rewrite +end \ No newline at end of file diff --git a/src/side_effects.jl b/src/side_effects.jl index ff84c26..51273b1 100644 --- a/src/side_effects.jl +++ b/src/side_effects.jl @@ -58,9 +58,10 @@ end function rewrite(::Val{:list}, command::Command) gc = generate_command(command; options=[:variables, :ifable, :nofunction]) (; local_copy, target_df, setup, teardown, arguments, options) = gc + cols = isempty(command.arguments) ? :(:) : :(collect($command.arguments)) quote $setup - $target_df[!, isempty($(command.arguments)) ? 
eval(:(:)) : collect($command.arguments)] |> Kezdi.display_and_return |> $teardown + $target_df[!, $cols] |> Kezdi.display_and_return |> $teardown end |> esc end diff --git a/test/commands.jl b/test/commands.jl index 952d698..8409a10 100644 --- a/test/commands.jl +++ b/test/commands.jl @@ -117,6 +117,15 @@ end positive(x) = x > 0 @test (@with DataFrame(x=1:4, y=5:8) @replace y = 0 @if positive(x - 2)).y == [5, 6, 0, 0] end + + @testset "Local variable escaping bug" begin + df = DataFrame(x=[1, 2, 3]) + global eltype_LHS = :eltype_LHS + global eltype_RHS = :eltype_RHS + @with df @replace x = 1.1 @if _n == 1 + @test eltype_LHS == :eltype_LHS + @test eltype_RHS == :eltype_RHS + end end @testset "Missing values" begin @@ -631,7 +640,7 @@ end @test t[2] == 2 @test t[3] == 3 end - df = DataFrame(x=[1, 2, 2, 3, 3, 3], y= [0, 0, 0, 1, 1, 1]) + df = DataFrame(x=[1, 2, 2, 3, 3, 3], y=[0, 0, 0, 1, 1, 1]) @testset "Twoway" begin t = @with df @tabulate x y @test :x in t.dimnames @@ -758,5 +767,61 @@ end df = DataFrame(x=1:10, y=11:20) @use "test.dta", clear @test df == getdf() - try @use "test.dta" @if x<5, clear; catch e; @test e isa LoadError; end +end + +@testset "Save" begin + @clear + df = DataFrame(x=Vector{Any}(1:11), y=11:21) + setdf(df) + try @save "test.dta", replace catch e @test e == ErrorException("element type Any is not supported") end + df = DataFrame(x=1:11, y=11:21) + setdf(df) + @save "test.dta", replace + df2 = @use "test.dta", clear + @test df == df2 + df = DataFrame(x=1:10, y=11:20) + setdf(df) + @save "test.dta", replace +end + +@testset "Missing encode" begin + df = DataFrame(x=[1, 2, missing, 3, missing, 4], y=[missing, 0, 1, 2, missing, 1]) + @testset "Known values" begin + df2 = @with df @mvencode x + @test all(df2.x .=== [1, 2, missing, 3, missing, 4]) + df2 = @with df @mvencode x, mv(-99.0) + @test all(df2.x .== [1, 2, -99.0, 3, -99.0, 4]) + @test typeof(df2.x) == Vector{Union{Missing, Float64}} + df2 = @with df @mvencode x, mv(-99) + @test 
all(df2.x .== [1, 2, -99, 3, -99, 4]) + @test typeof(df2.x) == Vector{Union{Missing, Int64}} + df2 = @with df @mvencode x, mv(mean(skipmissing(getdf().x))) + @test all(df2.x .== [1, 2, 2.5, 3, 2.5, 4]) + @test typeof(df2.x) == Vector{Union{Missing, Float64}} + df2 = @with df @mvencode y, mv(-99) + @test all(df2.y .== [-99, 0, 1, 2, -99, 1]) + df2 = @with df @mvencode x y, mv(-99) + @test all(df2.x .== [1, 2, -99, 3, -99, 4]) + @test all(df2.y .== [-99, 0, 1, 2, -99, 1]) + end + + @testset "If" begin + df2 = @with df @mvencode x @if ismissing(y), mv(-99) + @test all(df2.x .=== [1, 2, missing, 3, -99, 4]) + df2 = @with df @mvencode x @if ismissing(x), mv(-99) + @test all(df2.x .=== [1, 2, -99, 3, -99, 4]) + df2 = @with df @mvencode y @if ismissing(y), mv(-99) + @test all(df2.y .=== [-99, 0, 1, 2, -99, 1]) + df2 = @with df @mvencode y @if ismissing(x), mv(-99) + @test all(df2.y .=== [missing, 0, 1, 2, -99, 1]) + df2 = @with df @mvencode x y @if ismissing(y), mv(-99) + @test all(df2.x .=== [1, 2, missing, 3, -99, 4]) + @test all(df2.y .=== [-99, 0, 1, 2, -99, 1]) + df2 = @with df @mvencode x y @if ismissing(x), mv(-99) + @test all(df2.x .=== [1, 2, -99, 3, -99, 4]) + @test all(df2.y .=== [missing, 0, 1, 2, -99, 1]) + df2 = @with df @mvencode x y @if ismissing(x) || !ismissing(y), mv(-99) + @test all(df2.x .=== [1, 2, -99, 3, -99, 4]) + @test all(df2.y .=== [missing, 0, 1, 2, -99, 1]) + end end \ No newline at end of file diff --git a/test/test.dta b/test/test.dta index 00865f4..5463a41 100644 Binary files a/test/test.dta and b/test/test.dta differ