Skip to content

Commit

Permalink
Rewrite COUNTIF to CASE
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Oct 22, 2023
1 parent d1fe10e commit 59d7208
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2192,7 +2192,7 @@ peg::parser! {
= date_part:date_part() { ExpressionOrDatePart::DatePart(date_part) }
/ expression:expression() { ExpressionOrDatePart::Expression(expression) }

rule function_call() -> FunctionCall
pub rule function_call() -> FunctionCall
= name:function_name() paren1:p("(")
args:sep_opt_trailing(<expression()>, ",")? paren2:p(")")
over_clause:over_clause()?
Expand Down Expand Up @@ -2824,6 +2824,7 @@ mod tests {
(r"SELECT * FROM t WHERE a IN UNNEST([1])", None),
(r"SELECT IF(a = 0, 1, 2) c FROM t", None),
(r"SELECT CASE WHEN a = 0 THEN 1 ELSE 2 END c FROM t", None),
(r"SELECT CASE WHEN a = 0 THEN 1 END c FROM t", None),
(r"SELECT TRUE AND FALSE", None),
(r"SELECT TRUE OR FALSE", None),
(r"SELECT NOT TRUE", None),
Expand Down
1 change: 1 addition & 0 deletions src/drivers/snowflake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ impl Driver for SnowflakeDriver {

fn transforms(&self) -> Vec<Box<dyn Transform>> {
vec![
Box::new(transforms::CountifToCase),
Box::new(transforms::IfToCase),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
Expand Down
1 change: 1 addition & 0 deletions src/drivers/sqlite3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ impl Driver for SQLite3Driver {
fn transforms(&self) -> Vec<Box<dyn Transform>> {
vec![
Box::new(transforms::BoolToInt),
Box::new(transforms::CountifToCase),
Box::new(transforms::IfToCase),
Box::new(transforms::OrReplaceToDropIfExists),
Box::new(transforms::WrapNestedQueries),
Expand Down
1 change: 1 addition & 0 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ impl Driver for TrinoDriver {

fn transforms(&self) -> Vec<Box<dyn Transform>> {
vec![
Box::new(transforms::CountifToCase),
Box::new(transforms::OrReplaceToDropIfExists),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
Expand Down
9 changes: 9 additions & 0 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use codespan_reporting::{
use derive_visitor::{Drive, DriveMut};
use joinery_macros::ToTokens;
use peg::{error::ParseError, Parse, ParseElem, RuleResult};
use tracing::{error, trace};

use crate::{
ast,
Expand Down Expand Up @@ -428,9 +429,11 @@ impl TokenStream {
where
R: FnOnce(&TokenStream) -> Result<T, ParseError<Loc>>,
{
trace!(token_stream = ?self.tokens, "re-parsing token stream from `sql_quote!`");
match grammar_rule(&self) {
Ok(t) => Ok(t),
Err(err) => {
error!(error = ?err, token_stream = ?self.tokens, "failed to re-parse token stream from `sql_quote!`");
let diagnostic = Diagnostic::error().with_message("Failed to parse token stream");
Err(SourceError {
expected: err.to_string(),
Expand All @@ -457,6 +460,12 @@ impl TokenStream {
self.try_into_parsed(ast::sql_program::expression)
}

/// Try to parse this stream as a [`ast::FunctionCall`].
#[allow(dead_code)]
pub fn try_into_function_call(self) -> Result<ast::FunctionCall> {
self.try_into_parsed(ast::sql_program::function_call)
}

/// Parse a literal.
pub fn literal(&self, pos: usize) -> RuleResult<Literal> {
match self.tokens.get(pos) {
Expand Down
43 changes: 43 additions & 0 deletions src/transforms/countif_to_case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use derive_visitor::{DriveMut, VisitorMut};
use joinery_macros::sql_quote;

use crate::{
ast::{self, Expression, FunctionCall, FunctionName},
errors::Result,
};

use super::{Transform, TransformExtra};

/// Transform `COUNTIF(condition)` into a portable `CASE` expression.
#[derive(VisitorMut)]
#[visitor(Expression(enter))]
pub struct CountifToCase;

impl CountifToCase {
fn enter_expression(&mut self, expr: &mut Expression) {
if let Expression::FunctionCall(FunctionCall {
name: FunctionName::Function { function },
args,
over_clause: None,
..
}) = expr
{
if function.name.eq_ignore_ascii_case("COUNTIF") && args.node_iter().count() == 1 {
let condition = args.node_iter().next().expect("has 1 arg");
let replacement = sql_quote! {
COUNT(CASE WHEN #condition THEN 1 END)
}
.try_into_expression()
.expect("generated SQL should always parse");
*expr = replacement;
}
}
}
}

impl Transform for CountifToCase {
fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
}
}
2 changes: 2 additions & 0 deletions src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{ast, errors::Result};
pub use self::{
bool_to_int::BoolToInt,
clean_up_temp_manually::CleanUpTempManually,
countif_to_case::CountifToCase,
if_to_case::IfToCase,
or_replace_to_drop_if_exists::OrReplaceToDropIfExists,
rename_functions::{RenameFunctions, Udf},
Expand All @@ -18,6 +19,7 @@ pub use self::{

mod bool_to_int;
mod clean_up_temp_manually;
mod countif_to_case;
mod if_to_case;
mod or_replace_to_drop_if_exists;
mod rename_functions;
Expand Down
4 changes: 0 additions & 4 deletions tests/sql/functions/aggregate/countif.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
-- pending: snowflake COUNTIF can be rewritten portably by transpiler
-- pending: sqlite3 COUNTIF Can be rewritten portably by transpiler
-- pending: trino COUNTIF Can be rewritten portably by transpiler

CREATE TEMP TABLE vals (i INT64);
INSERT INTO vals VALUES (1), (2), (2), (3), (NULL);

Expand Down

0 comments on commit 59d7208

Please sign in to comment.