From 07b33bc95d51f47bd79b44b65f61227346e4c06b Mon Sep 17 00:00:00 2001
From: Eric Kidd
Date: Sun, 22 Oct 2023 06:44:43 -0400
Subject: [PATCH] sql_quote: Implement ToTokens for AST

This requires another derive macro, but at this point, why not? We now
have the right infrastructure to do lots of SQL rewriting.
---
 joinery_macros/src/emit.rs       | 129 +++++++++++
 joinery_macros/src/field_info.rs | 105 +++++++++
 joinery_macros/src/lib.rs        | 354 +++++--------------------------
 joinery_macros/src/sql_quote.rs  | 114 ++++++++++
 joinery_macros/src/to_tokens.rs  | 112 ++++++++++
 src/ast.rs                       | 163 +++++++-------
 src/tokenizer.rs                 |  84 +++++++-
 7 files changed, 672 insertions(+), 389 deletions(-)
 create mode 100644 joinery_macros/src/emit.rs
 create mode 100644 joinery_macros/src/field_info.rs
 create mode 100644 joinery_macros/src/sql_quote.rs
 create mode 100644 joinery_macros/src/to_tokens.rs

diff --git a/joinery_macros/src/emit.rs b/joinery_macros/src/emit.rs
new file mode 100644
index 0000000..3b3768d
--- /dev/null
+++ b/joinery_macros/src/emit.rs
@@ -0,0 +1,129 @@
+//! Implementations of `#[derive(Emit)]` and `#[derive(EmitDefault)]`.
+
+use darling::{util::Flag, FromField};
+use proc_macro2::TokenStream as TokenStream2;
+use quote::quote;
+
+use crate::field_info::{AttrWithSkip, FieldInfo};
+
+/// The contents of an `#[emit(...)]` attribute, parsed using the
+/// [`darling`](https://github.com/TedDriggs/darling) crate.
+#[derive(Default, FromField)]
+#[darling(default, attributes(emit))]
+struct EmitAttr {
+    /// Should we omit this field from our output?
+    skip: Flag,
+}
+
+impl AttrWithSkip for EmitAttr {
+    fn skip(&self) -> &Flag {
+        &self.skip
+    }
+}
+
+pub(crate) fn impl_emit_macro(ast: &syn::DeriveInput) -> TokenStream2 {
+    let name = &ast.ident;
+    let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
+    quote! {
+        impl #impl_generics Emit for #name #ty_generics #where_clause {
+            fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> {
+                <#name #ty_generics as EmitDefault>::emit_default(self, t, f)
+            }
+        }
+    }
+}
+
+/// [`EmitDefault::emit_default`] should just iterate over the struct or enum,
+/// and generate a recursive call to [`Emit::emit`] for each field.
+///
+/// TODO: If we see `#[emit(skip)]` on a field, we should skip it.
+pub(crate) fn impl_emit_default_macro(ast: &syn::DeriveInput) -> TokenStream2 {
+    let name = &ast.ident;
+    let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
+    let implementation = emit_default_body(name, &ast.data);
+    quote! {
+        impl #impl_generics EmitDefault for #name #ty_generics #where_clause {
+            fn emit_default(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> {
+                #implementation
+            }
+        }
+    }
+}
+
+fn emit_default_body(name: &syn::Ident, data: &syn::Data) -> TokenStream2 {
+    match data {
+        syn::Data::Struct(s) => emit_default_body_struct(s),
+        syn::Data::Enum(e) => emit_default_body_enum(name, e),
+        syn::Data::Union(_) => panic!("Cannot derive EmitDefault for unions"),
+    }
+}
+
+fn emit_default_body_struct(data: &syn::DataStruct) -> TokenStream2 {
+    match &data.fields {
+        syn::Fields::Named(fields) => {
+            let field_names = FieldInfo::<EmitAttr>::named_iter(fields)
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.struct_field());
+            quote! {
+                #( self.#field_names.emit(t, f)?; )*
+                Ok(())
+            }
+        }
+        syn::Fields::Unnamed(fields) => {
+            let field_names = FieldInfo::<EmitAttr>::unnamed_iter(fields)
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.struct_field());
+            quote! {
+                #( self.#field_names.emit(t, f)?; )*
+                Ok(())
+            }
+        }
+        syn::Fields::Unit => quote! { Ok(()) },
+    }
+}
+
+fn emit_default_body_enum(name: &syn::Ident, data: &syn::DataEnum) -> TokenStream2 {
+    let variants = data
+        .variants
+        .iter()
+        .map(|v| emit_default_body_enum_variant(name, v));
+    quote! {
+        match self {
+            #( #variants )*
+        }
+    }
+}
+
+fn emit_default_body_enum_variant(name: &syn::Ident, variant: &syn::Variant) -> TokenStream2 {
+    let variant_name = &variant.ident;
+    match &variant.fields {
+        syn::Fields::Named(fields) => {
+            let fields = FieldInfo::<EmitAttr>::named_iter(fields).collect::<Vec<_>>();
+            let patterns = fields.iter().map(|f| f.enum_pattern());
+            let field_names = fields
+                .iter()
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.enum_name());
+            quote! {
+                #name::#variant_name { #(#patterns),* } => {
+                    #( #field_names.emit(t, f)?; )*
+                    Ok(())
+                }
+            }
+        }
+        syn::Fields::Unnamed(fields) => {
+            let fields = FieldInfo::<EmitAttr>::unnamed_iter(fields).collect::<Vec<_>>();
+            let patterns = fields.iter().map(|f| f.enum_pattern());
+            let field_names = fields
+                .iter()
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.enum_name());
+            quote! {
+                #name::#variant_name( #(#patterns),* ) => {
+                    #( #field_names.emit(t, f)?; )*
+                    Ok(())
+                }
+            }
+        }
+        syn::Fields::Unit => quote! { Ok(()) },
+    }
+}
diff --git a/joinery_macros/src/field_info.rs b/joinery_macros/src/field_info.rs
new file mode 100644
index 0000000..2ef92f8
--- /dev/null
+++ b/joinery_macros/src/field_info.rs
@@ -0,0 +1,105 @@
+//! Information about a field in a `struct` or `enum` with a `#[derive(..)]`
+//! attribute.
+
+use std::borrow::Cow;
+
+use darling::{util::Flag, FromField};
+use proc_macro2::{Span, TokenStream as TokenStream2};
+use quote::quote;
+use syn::{spanned::Spanned, Field, Ident};
+
+pub(crate) trait AttrWithSkip: Default + FromField + 'static {
+    /// Should we skip this field?
+    fn skip(&self) -> &Flag;
+}
+
+/// Information we have about a field. Needed to generate the correct code.
+pub(crate) enum FieldInfo<'a, Attr: AttrWithSkip> {
+    /// A named field, such as `struct S { foo: Foo }`.
+    Named { ident: &'a Ident, attr: Attr },
+    /// An unnamed field, such as `struct S(Foo)`.
+    Unnamed {
+        index: usize,
+        span: Span,
+        attr: Attr,
+    },
+}
+
+impl<'a, Attr: AttrWithSkip> FieldInfo<'a, Attr> {
+    /// Collect info about named fields.
+    pub(crate) fn named_iter(
+        fields: &'a syn::FieldsNamed,
+    ) -> impl Iterator<Item = FieldInfo<'a, Attr>> + 'a {
+        fields.named.iter().map(Self::named)
+    }
+
+    /// Collect info about unnamed fields.
+    pub(crate) fn unnamed_iter(
+        fields: &'a syn::FieldsUnnamed,
+    ) -> impl Iterator<Item = FieldInfo<'a, Attr>> + 'a {
+        fields
+            .unnamed
+            .iter()
+            .enumerate()
+            .map(|(i, f)| Self::unnamed(i, f))
+    }
+
+    /// Collect info about a named field.
+    fn named(f: &'a Field) -> Self {
+        Self::Named {
+            ident: f.ident.as_ref().expect("field should be named"),
+            attr: Attr::from_field(f).unwrap_or_default(),
+        }
+    }
+
+    /// Collect info about an unnamed field.
+    fn unnamed(index: usize, f: &'a Field) -> Self {
+        Self::Unnamed {
+            index,
+            span: f.span(),
+            attr: Attr::from_field(f).unwrap_or_default(),
+        }
+    }
+
+    pub(crate) fn attr(&self) -> &Attr {
+        match self {
+            Self::Named { attr, .. } => attr,
+            Self::Unnamed { attr, .. } => attr,
+        }
+    }
+
+    /// How to name this field when accessing it as a struct field.
+    pub(crate) fn struct_field(&self) -> TokenStream2 {
+        match self {
+            Self::Named { ident, .. } => quote! { #ident },
+            Self::Unnamed { index, span, .. } => {
+                let index = syn::Index {
+                    index: *index as u32,
+                    span: *span,
+                };
+                quote! { #index }
+            }
+        }
+    }
+
+    /// How to name this field when accessing inside a `match` arm.
+    pub(crate) fn enum_name(&self) -> Cow<'_, Ident> {
+        match self {
+            Self::Named { ident, .. } => Cow::Borrowed(ident),
+            Self::Unnamed { index, span, .. } => {
+                Cow::Owned(syn::Ident::new(&format!("f{}", index), *span))
+            }
+        }
+    }
+
+    /// How to name this field when using it as a pattern in a `match` arm.
+    pub(crate) fn enum_pattern(&self) -> TokenStream2 {
+        let name = self.enum_name();
+        match (self, self.attr().skip().is_present()) {
+            (Self::Named { .. }, true) => quote! { #name: _ },
+            (Self::Named { .. }, false) => quote! { #name },
+            (Self::Unnamed { .. }, true) => quote! { _ },
+            (Self::Unnamed { .. }, false) => quote! { #name },
+        }
+    }
+}
diff --git a/joinery_macros/src/lib.rs b/joinery_macros/src/lib.rs
index 7da45e8..e8528df 100644
--- a/joinery_macros/src/lib.rs
+++ b/joinery_macros/src/lib.rs
@@ -1,324 +1,74 @@
-use std::borrow::Cow;
+//! Macros used internally by the `joinery` crate. Not intended for use by
+//! anything but the `joinery` crate itself.
 
-use darling::{util::Flag, FromField};
 use proc_macro::TokenStream;
-use proc_macro2::{Delimiter, Span, TokenStream as TokenStream2, TokenTree};
-use quote::{quote, quote_spanned};
-use syn::{spanned::Spanned, Field, Ident};
+use proc_macro2::TokenStream as TokenStream2;
 
+use crate::{
+    emit::{impl_emit_default_macro, impl_emit_macro},
+    sql_quote::impl_sql_quote,
+    to_tokens::impl_to_tokens_macro,
+};
+
+mod emit;
+mod field_info;
+mod sql_quote;
+mod to_tokens;
+
+/// Use `#[derive(Emit)]` to generate a simple implementation of `Emit` that
+/// calls `EmitDefault::emit_default`.
+///
+/// If you need to customize the code that's generated for a specific database,
+/// you can implement `Emit` manually and call `EmitDefault::emit_default` for
+/// any fields that you don't want to customize.
 #[proc_macro_derive(Emit)]
 pub fn emit_macro_derive(input: TokenStream) -> TokenStream {
     let ast = syn::parse(input).unwrap();
     impl_emit_macro(&ast).into()
 }
 
-fn impl_emit_macro(ast: &syn::DeriveInput) -> TokenStream2 {
-    let name = &ast.ident;
-    let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
-    quote! {
-        impl #impl_generics Emit for #name #ty_generics #where_clause {
-            fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> {
-                <#name #ty_generics as EmitDefault>::emit_default(self, t, f)
-            }
-        }
-    }
-}
-
+/// Use `#[derive(EmitDefault)]` to generate an implementation of `EmitDefault`
+/// that calls `Emit::emit` for each field.
+///
+/// You should never need to implement `EmitDefault` manually. It exists only to
+/// allow `Emit` to call `EmitDefault::emit_default` for any fields that aren't
+/// customized. If all fields are customized, then `EmitDefault` doesn't need to
+/// be implemented.
 #[proc_macro_derive(EmitDefault, attributes(emit))]
 pub fn emit_default_macro_derive(input: TokenStream) -> TokenStream {
     let ast = syn::parse(input).unwrap();
     impl_emit_default_macro(&ast).into()
 }
 
-/// [`EmitDefault::emit_default`] should just iterate over the struct or enum,
-/// and generate a recursive call to [`Emit::emit`] for each field.
+/// Use `#[sql_quote]` to write SQL queries inline in Rust code, with Rust
+/// expressions interpolated into the query using `#expr`. Similar to
+/// [`quote::quote!`], except for SQL instead of Rust.
 ///
-/// TODO: If we see `#[emit(skip)]` on a field, we should skip it.
-fn impl_emit_default_macro(ast: &syn::DeriveInput) -> TokenStream2 {
-    let name = &ast.ident;
-    let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
-    let implementation = emit_default_body(name, &ast.data);
-    quote! {
-        impl #impl_generics EmitDefault for #name #ty_generics #where_clause {
-            fn emit_default(&self, t: Target, f: &mut TokenWriter<'_>) -> ::std::io::Result<()> {
-                #implementation
-            }
-        }
-    }
-}
-
-fn emit_default_body(name: &syn::Ident, data: &syn::Data) -> TokenStream2 {
-    match data {
-        syn::Data::Struct(s) => emit_default_body_struct(s),
-        syn::Data::Enum(e) => emit_default_body_enum(name, e),
-        syn::Data::Union(_) => panic!("Cannot derive EmitDefault for unions"),
-    }
-}
-
-fn emit_default_body_struct(data: &syn::DataStruct) -> TokenStream2 {
-    match &data.fields {
-        syn::Fields::Named(fields) => {
-            let field_names = FieldInfo::named_iter(fields)
-                .filter(|f| !f.attr().skip.is_present())
-                .map(|f| f.struct_field());
-            quote! {
-                #( self.#field_names.emit(t, f)?; )*
-                Ok(())
-            }
-        }
-        syn::Fields::Unnamed(fields) => {
-            let field_names = FieldInfo::unnamed_iter(fields)
-                .filter(|f| !f.attr().skip.is_present())
-                .map(|f| f.struct_field());
-            quote! {
-                #( self.#field_names.emit(t, f)?; )*
-                Ok(())
-            }
-        }
-        syn::Fields::Unit => quote! { Ok(()) },
-    }
-}
-
-fn emit_default_body_enum(name: &syn::Ident, data: &syn::DataEnum) -> TokenStream2 {
-    let variants = data
-        .variants
-        .iter()
-        .map(|v| emit_default_body_enum_variant(name, v));
-    quote! {
-        match self {
-            #( #variants )*
-        }
-    }
-}
-
-fn emit_default_body_enum_variant(name: &syn::Ident, variant: &syn::Variant) -> TokenStream2 {
-    let variant_name = &variant.ident;
-    match &variant.fields {
-        syn::Fields::Named(fields) => {
-            let fields = FieldInfo::named_iter(fields).collect::<Vec<_>>();
-            let patterns = fields.iter().map(|f| f.enum_pattern());
-            let field_names = fields
-                .iter()
-                .filter(|f| !f.attr().skip.is_present())
-                .map(|f| f.enum_name());
-            quote! {
-                #name::#variant_name { #(#patterns),* } => {
-                    #( #field_names.emit(t, f)?; )*
-                    Ok(())
-                }
-            }
-        }
-        syn::Fields::Unnamed(fields) => {
-            let fields = FieldInfo::unnamed_iter(fields).collect::<Vec<_>>();
-            let patterns = fields.iter().map(|f| f.enum_pattern());
-            let field_names = fields
-                .iter()
-                .filter(|f| !f.attr().skip.is_present())
-                .map(|f| f.enum_name());
-            quote! {
-                #name::#variant_name( #(#patterns),* ) => {
-                    #( #field_names.emit(t, f)?; )*
-                    Ok(())
-                }
-            }
-        }
-        syn::Fields::Unit => quote! { Ok(()) },
-    }
-}
-
-/// Information we have about a field. Needed to generate the correct code.
-enum FieldInfo<'a> {
-    /// A named field, such a `struct S { foo: Foo }`.
-    Named { ident: &'a Ident, attr: EmitAttr },
-    /// An unnamed field, such as `struct S(Foo)`.
-    Unnamed {
-        index: usize,
-        span: Span,
-        attr: EmitAttr,
-    },
-}
-
-impl<'a> FieldInfo<'a> {
-    /// Collect info about named fields.
-    fn named_iter(fields: &'a syn::FieldsNamed) -> impl Iterator<Item = FieldInfo<'a>> + 'a {
-        fields.named.iter().map(Self::named)
-    }
-
-    /// Collect info about unnamed fields.
-    fn unnamed_iter(fields: &'a syn::FieldsUnnamed) -> impl Iterator<Item = FieldInfo<'a>> + 'a {
-        fields
-            .unnamed
-            .iter()
-            .enumerate()
-            .map(|(i, f)| Self::unnamed(i, f))
-    }
-
-    /// Collect info about a named field.
-    fn named(f: &'a Field) -> Self {
-        Self::Named {
-            ident: f.ident.as_ref().expect("field should be named"),
-            attr: EmitAttr::from_field(f).unwrap_or_default(),
-        }
-    }
-
-    /// Collect info about an unnamed field.
-    fn unnamed(index: usize, f: &'a Field) -> Self {
-        Self::Unnamed {
-            index,
-            span: f.span(),
-            attr: EmitAttr::from_field(f).unwrap_or_default(),
-        }
-    }
-
-    fn attr(&self) -> &EmitAttr {
-        match self {
-            Self::Named { attr, .. } => attr,
-            Self::Unnamed { attr, .. } => attr,
-        }
-    }
-
-    /// How to name this field when accessing it as a struct field.
-    fn struct_field(&self) -> TokenStream2 {
-        match self {
-            Self::Named { ident, .. } => quote! { #ident },
-            Self::Unnamed { index, span, .. } => {
-                let index = syn::Index {
-                    index: *index as u32,
-                    span: *span,
-                };
-                quote! { #index }
-            }
-        }
-    }
-
-    /// How to name this field when accessing inside a `match` arm.
-    fn enum_name(&self) -> Cow<'_, Ident> {
-        match self {
-            Self::Named { ident, .. } => Cow::Borrowed(ident),
-            Self::Unnamed { index, span, .. } => {
-                Cow::Owned(syn::Ident::new(&format!("f{}", index), *span))
-            }
-        }
-    }
-
-    /// How to name this field when using it as a pattern in a `match` arm.
-    fn enum_pattern(&self) -> TokenStream2 {
-        let name = self.enum_name();
-        match (self, self.attr().skip.is_present()) {
-            (Self::Named { .. }, true) => quote! { #name: _ },
-            (Self::Named { .. }, false) => quote! { #name },
-            (Self::Unnamed { .. }, true) => quote! { _ },
-            (Self::Unnamed { .. }, false) => quote! { #name },
-        }
-    }
-}
-
-/// The contents of an `#[emit(...)]` attribute, parsing using the
-/// [`darling`](https://github.com/TedDriggs/darling) crate.
-#[derive(Default, FromField)]
-#[darling(default, attributes(emit))]
-struct EmitAttr {
-    /// Should we omit this field from our output?
-    skip: Flag,
-}
-
+/// The output of this macro is a `joinery::tokenizer::TokenStream`, which can
+/// then be re-parsed using various methods on
+/// `joinery::tokenizer::TokenStream`.
+///
+/// # Example
+///
+/// ```no_compile
+/// use joinery::sql_quote;
+///
+/// let table = Ident::new("foo");
+/// let query = sql_quote! {
+///     SELECT * FROM #table
+/// }.try_into_statement()?;
+/// ```
 #[proc_macro]
 pub fn sql_quote(input: TokenStream) -> TokenStream {
     let input = TokenStream2::from(input);
-
-    let mut sql_token_exprs = vec![];
-    emit_sql_token_exprs(&mut sql_token_exprs, input.into_iter());
-    let capacity = sql_token_exprs.len();
-    let output = quote! {
-        {
-            use crate::tokenizer::{Literal, Token, TokenStream};
-            let mut __tokens = Vec::with_capacity(#capacity);
-            #( #sql_token_exprs; )*
-            TokenStream { tokens: __tokens }
-        }
-    };
-    output.into()
-}
-
-fn emit_sql_token_exprs(
-    sql_token_exprs: &mut Vec<TokenStream2>,
-    mut tokens: impl Iterator<Item = TokenTree>,
-) {
-    while let Some(token) = tokens.next() {
-        match token {
-            // Treat `#` as interpolation.
-            TokenTree::Punct(p) if p.to_string() == "#" => {
-                if let Some(expr) = tokens.next() {
-                    sql_token_exprs.push(quote! {
-                        (#expr).to_tokens(&mut __tokens)
-                    });
-                } else {
-                    sql_token_exprs.push(quote_spanned! {
-                        p.span() =>
-                        compile_error!("expected expression after `#`")
-                    });
-                }
-            }
-            TokenTree::Group(group) => {
-                // We flatten this and use `Punct::new`.
-                let (open, close) = delimiter_pair(group.delimiter());
-                if let Some(open) = open {
-                    sql_token_exprs.push(quote! { __tokens.push(Token::punct(#open)) });
-                }
-                emit_sql_token_exprs(sql_token_exprs, group.stream().into_iter());
-                if let Some(close) = close {
-                    sql_token_exprs.push(quote! { __tokens.push(Token::punct(#close)) });
-                }
-            }
-            TokenTree::Ident(ident) => {
-                let ident_str = ident.to_string();
-                sql_token_exprs.push(quote! { __tokens.push(Token::ident(#ident_str)) });
-            }
-            TokenTree::Punct(punct) => {
-                let punct_str = punct.to_string();
-                sql_token_exprs.push(quote! { __tokens.push(Token::punct(#punct_str)) });
-            }
-            TokenTree::Literal(lit) => {
-                // There's probably a better way to do this.
-                let lit: syn::Lit = syn::parse_quote!(#lit);
-                match lit {
-                    syn::Lit::Int(i) => {
-                        sql_token_exprs.push(quote! {
-                            __tokens.push(Token::Literal(Literal::int(#i)))
-                        });
-                    }
-                    syn::Lit::Str(s) => {
-                        sql_token_exprs.push(quote! {
-                            __tokens.push(Token::Literal(Literal::string(#s)))
-                        });
-                    }
-                    syn::Lit::Float(f) => {
-                        sql_token_exprs.push(quote! {
-                            __tokens.push(Token::Literal(Literal::float(#f)))
-                        });
-                    }
-                    // syn::Lit::ByteStr(_) => todo!(),
-                    // syn::Lit::Byte(_) => todo!(),
-                    // syn::Lit::Char(_) => todo!(),
-                    // syn::Lit::Bool(_) => todo!(),
-                    // syn::Lit::Verbatim(_) => todo!(),
-                    _ => {
-                        sql_token_exprs.push(quote_spanned! {
-                            lit.span() =>
-                            compile_error!("unsupported literal type")
-                        });
-                    }
-                }
-            }
-        }
-    }
+    impl_sql_quote(input).into()
 }
 
-fn delimiter_pair(d: Delimiter) -> (Option<&'static str>, Option<&'static str>) {
-    match d {
-        Delimiter::Parenthesis => (Some("("), Some(")")),
-        Delimiter::Brace => (Some("{"), Some("}")),
-        Delimiter::Bracket => (Some("["), Some("]")),
-        Delimiter::None => (None, None),
-    }
-}
+/// Use `#[derive(ToTokens)]` to generate an implementation of `ToTokens` that
+/// recursively calls `ToTokens::to_tokens` for each field. This is used to
+/// convert parsed SQL back into tokens, for use with [`sql_quote!`].
+#[proc_macro_derive(ToTokens, attributes(to_tokens))]
+pub fn to_owned_macro_derive(input: TokenStream) -> TokenStream {
+    let ast = syn::parse(input).unwrap();
+    impl_to_tokens_macro(&ast).into()
+}
diff --git a/joinery_macros/src/sql_quote.rs b/joinery_macros/src/sql_quote.rs
new file mode 100644
index 0000000..525aece
--- /dev/null
+++ b/joinery_macros/src/sql_quote.rs
@@ -0,0 +1,114 @@
+//! Quasi-quoting for SQL.
+//!
+//! This is similar to Rust's `quote` crate. It allows you to write SQL queries
+//! inline in Rust code, with Rust expressions interpolated into the query.
+//!
+//! The output of this macro is a `joinery::tokenizer::TokenStream`, which is
+//! used by the `joinery` crate to generate SQL.
+
+use proc_macro2::{Delimiter, TokenStream as TokenStream2, TokenTree};
+use quote::{quote, quote_spanned};
+use syn::spanned::Spanned;
+
+pub(crate) fn impl_sql_quote(input: TokenStream2) -> TokenStream2 {
+    let mut sql_token_exprs = vec![];
+    emit_sql_token_exprs(&mut sql_token_exprs, input.into_iter());
+    let capacity = sql_token_exprs.len();
+    quote! {
+        {
+            use crate::tokenizer::{Literal, Token, TokenStream};
+            let mut __tokens = Vec::with_capacity(#capacity);
+            #( #sql_token_exprs; )*
+            TokenStream { tokens: __tokens }
+        }
+    }
+}
+
+fn emit_sql_token_exprs(
+    sql_token_exprs: &mut Vec<TokenStream2>,
+    mut tokens: impl Iterator<Item = TokenTree>,
+) {
+    while let Some(token) = tokens.next() {
+        match token {
+            // Treat `#` as interpolation.
+            TokenTree::Punct(p) if p.to_string() == "#" => {
+                if let Some(expr) = tokens.next() {
+                    sql_token_exprs.push(quote_spanned! { expr.span() =>
+                        (#expr).to_tokens(&mut __tokens)
+                    });
+                } else {
+                    sql_token_exprs.push(quote_spanned! { p.span() =>
+                        compile_error!("expected expression after `#`")
+                    });
+                }
+            }
+            TokenTree::Group(group) => {
+                // We flatten this and use `Punct::new`.
+                let (open, close) = delimiter_pair(group.delimiter());
+                if let Some(open) = open {
+                    sql_token_exprs.push(quote_spanned! { group.span_open() =>
+                        __tokens.push(Token::punct(#open))
+                    });
+                }
+                emit_sql_token_exprs(sql_token_exprs, group.stream().into_iter());
+                if let Some(close) = close {
+                    sql_token_exprs.push(quote_spanned! { group.span_close() =>
+                        __tokens.push(Token::punct(#close))
+                    });
+                }
+            }
+            TokenTree::Ident(ident) => {
+                let ident_str = ident.to_string();
+                sql_token_exprs.push(quote_spanned! { ident.span() =>
+                    __tokens.push(Token::ident(#ident_str))
+                });
+            }
+            TokenTree::Punct(punct) => {
+                let punct_str = punct.to_string();
+                sql_token_exprs.push(quote_spanned! { punct.span() =>
+                    __tokens.push(Token::punct(#punct_str))
+                });
+            }
+            TokenTree::Literal(lit) => {
+                // There's probably a better way to do this.
+                let lit: syn::Lit = syn::parse_quote!(#lit);
+                match lit {
+                    syn::Lit::Int(i) => {
+                        sql_token_exprs.push(quote_spanned! { i.span() =>
+                            __tokens.push(Token::Literal(Literal::int(#i)))
+                        });
+                    }
+                    syn::Lit::Str(s) => {
+                        sql_token_exprs.push(quote_spanned! { s.span() =>
+                            __tokens.push(Token::Literal(Literal::string(#s)))
+                        });
+                    }
+                    syn::Lit::Float(f) => {
+                        sql_token_exprs.push(quote_spanned! { f.span() =>
+                            __tokens.push(Token::Literal(Literal::float(#f)))
+                        });
+                    }
+                    // syn::Lit::ByteStr(_) => todo!(),
+                    // syn::Lit::Byte(_) => todo!(),
+                    // syn::Lit::Char(_) => todo!(),
+                    // syn::Lit::Bool(_) => todo!(),
+                    // syn::Lit::Verbatim(_) => todo!(),
+                    _ => {
+                        sql_token_exprs.push(quote_spanned! { lit.span() =>
+                            compile_error!("unsupported literal type")
+                        });
+                    }
+                }
+            }
+        }
+    }
+}
+
+fn delimiter_pair(d: Delimiter) -> (Option<&'static str>, Option<&'static str>) {
+    match d {
+        Delimiter::Parenthesis => (Some("("), Some(")")),
+        Delimiter::Brace => (Some("{"), Some("}")),
+        Delimiter::Bracket => (Some("["), Some("]")),
+        Delimiter::None => (None, None),
+    }
+}
diff --git a/joinery_macros/src/to_tokens.rs b/joinery_macros/src/to_tokens.rs
new file mode 100644
index 0000000..24e466d
--- /dev/null
+++ b/joinery_macros/src/to_tokens.rs
@@ -0,0 +1,112 @@
+use darling::{util::Flag, FromField};
+use proc_macro2::TokenStream as TokenStream2;
+use quote::quote;
+
+use crate::field_info::{AttrWithSkip, FieldInfo};
+
+/// The contents of a `#[to_tokens(...)]` attribute, parsed using the
+/// [`darling`](https://github.com/TedDriggs/darling) crate.
+#[derive(Default, FromField)]
+#[darling(default, attributes(to_tokens))]
+struct ToTokensAttr {
+    /// Should we omit this field from our output?
+    skip: Flag,
+}
+
+impl AttrWithSkip for ToTokensAttr {
+    fn skip(&self) -> &Flag {
+        &self.skip
+    }
+}
+
+/// [`ToTokens::to_tokens`] should just iterate over the struct or enum, and
+/// generate a recursive call to [`ToTokens::to_tokens`] for each field.
+///
+/// If we see `#[to_tokens(skip)]` on a field, we should skip it.
+pub(crate) fn impl_to_tokens_macro(ast: &syn::DeriveInput) -> TokenStream2 {
+    let name = &ast.ident;
+    let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
+    let implementation = to_tokens_body(name, &ast.data);
+    quote! {
+        impl #impl_generics ToTokens for #name #ty_generics #where_clause {
+            fn to_tokens(&self, tokens: &mut Vec<Token>) {
+                #implementation
+            }
+        }
+    }
+}
+
+fn to_tokens_body(name: &syn::Ident, data: &syn::Data) -> TokenStream2 {
+    match data {
+        syn::Data::Struct(s) => to_tokens_body_struct(s),
+        syn::Data::Enum(e) => to_tokens_body_enum(name, e),
+        syn::Data::Union(_) => panic!("Cannot derive ToTokens for unions"),
+    }
+}
+
+fn to_tokens_body_struct(data: &syn::DataStruct) -> TokenStream2 {
+    match &data.fields {
+        syn::Fields::Named(fields) => {
+            let field_names = FieldInfo::<ToTokensAttr>::named_iter(fields)
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.struct_field());
+            quote! {
+                #( self.#field_names.to_tokens(tokens); )*
+            }
+        }
+        syn::Fields::Unnamed(fields) => {
+            let field_names = FieldInfo::<ToTokensAttr>::unnamed_iter(fields)
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.struct_field());
+            quote! {
+                #( self.#field_names.to_tokens(tokens); )*
+            }
+        }
+        syn::Fields::Unit => quote! {},
+    }
+}
+
+fn to_tokens_body_enum(name: &syn::Ident, data: &syn::DataEnum) -> TokenStream2 {
+    let variants = data
+        .variants
+        .iter()
+        .map(|v| to_tokens_body_enum_variant(name, v));
+    quote! {
+        match self {
+            #( #variants )*
+        }
+    }
+}
+
+fn to_tokens_body_enum_variant(name: &syn::Ident, variant: &syn::Variant) -> TokenStream2 {
+    let variant_name = &variant.ident;
+    match &variant.fields {
+        syn::Fields::Named(fields) => {
+            let fields = FieldInfo::<ToTokensAttr>::named_iter(fields).collect::<Vec<_>>();
+            let patterns = fields.iter().map(|f| f.enum_pattern());
+            let field_names = fields
+                .iter()
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.enum_name());
+            quote! {
+                #name::#variant_name { #(#patterns),* } => {
+                    #( #field_names.to_tokens(tokens); )*
+                }
+            }
+        }
+        syn::Fields::Unnamed(fields) => {
+            let fields = FieldInfo::<ToTokensAttr>::unnamed_iter(fields).collect::<Vec<_>>();
+            let patterns = fields.iter().map(|f| f.enum_pattern());
+            let field_names = fields
+                .iter()
+                .filter(|f| !f.attr().skip().is_present())
+                .map(|f| f.enum_name());
+            quote! {
+                #name::#variant_name( #(#patterns),* ) => {
+                    #( #field_names.to_tokens(tokens); )*
+                }
+            }
+        }
+        syn::Fields::Unit => quote! {},
+    }
+}
diff --git a/src/ast.rs b/src/ast.rs
index 6095ece..b1ca97b 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -29,7 +29,7 @@ use codespan_reporting::{
     files::SimpleFiles,
 };
 use derive_visitor::{Drive, DriveMut};
-use joinery_macros::{Emit, EmitDefault};
+use joinery_macros::{Emit, EmitDefault, ToTokens};
 
 use crate::{
     drivers::{
@@ -41,7 +41,7 @@ use crate::{
     errors::{Result, SourceError},
     tokenizer::{
         tokenize_sql, CaseInsensitiveIdent, EmptyFile, Ident, Keyword, Literal, LiteralValue,
-        Punct, RawToken, Token, TokenStream, TokenWriter,
+        Punct, RawToken, ToTokens, Token, TokenStream, TokenWriter,
     },
     util::{is_c_ident, AnsiIdent, AnsiString},
 };
@@ -102,7 +102,7 @@ impl fmt::Display for Target {
 /// A default version of [`Emit`] which attempts to write the AST back out as
 /// BigQuery SQL, extremely close to the original input.
 ///
-/// For most types, you will start by using `#[derive(Emit, EmitDefault)]`. This
+/// For most types, you will start by using `#[derive(Emit, EmitDefault, ToTokens)]`. This
 /// will generate:
 ///
 /// - An implementation of [`Emit`] which calls [`EmitDefault::emit_default`].
@@ -124,7 +124,7 @@ pub trait EmitDefault {
 /// Emit the AST as code for a specific database.
 ///
-/// If you use `#[derive(Emit, EmitDefault)]` on a type, then [`Emit::emit`]
+/// If you use `#[derive(Emit, EmitDefault, ToTokens)]` on a type, then [`Emit::emit`]
 /// will be generated to call [`EmitDefault::emit_default`].
 pub trait Emit: Sized {
     /// Format this node for the specified database.
@@ -281,12 +281,12 @@ fn emit_whitespace(ws: &str, t: Target, f: &mut dyn io::Write) -> io::Result<()>
 }
 
 /// A node type, for use with [`NodeVec`].
-pub trait Node: Clone + fmt::Debug + Drive + DriveMut + Emit + 'static {}
+pub trait Node: Clone + fmt::Debug + Drive + DriveMut + Emit + ToTokens + 'static {}
 
-impl<T: Clone + fmt::Debug + Drive + DriveMut + Emit + 'static> Node for T {}
+impl<T: Clone + fmt::Debug + Drive + DriveMut + Emit + ToTokens + 'static> Node for T {}
 
 /// Either a node or a separator token.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum NodeOrSep<T: Node> {
     Node(T),
     Sep(Punct),
 }
@@ -304,9 +304,10 @@ pub enum NodeOrSep<T: Node> {
 /// this, we would need to modify [`IntoIterator`] to return interleaved
 /// nodes and separators, and define custom node-only and separator-only
 /// iterators.
-#[derive(Debug)]
+#[derive(Debug, ToTokens)]
 pub struct NodeVec<T: Node> {
     /// The separator to use when adding items.
+    #[to_tokens(skip)]
     pub separator: &'static str,
     /// The nodes and separators in this vector.
     pub items: Vec<NodeOrSep<T>>,
 }
@@ -446,7 +447,7 @@ impl<T: Node> Emit for NodeVec<T> {
 }
 
 /// A table name.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum TableName {
     ProjectDatasetTable {
         project: Ident,
@@ -503,7 +504,7 @@ impl Emit for TableName {
 }
 
 /// A table and a column name.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct TableAndColumnName {
     pub table_name: TableName,
     pub dot: Punct,
@@ -511,7 +512,7 @@ pub struct TableAndColumnName {
     pub column_name: Ident,
 }
 
 /// An entire SQL program.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct SqlProgram {
     /// For now, just handle single statements; BigQuery DDL is messy and maybe
     /// out of scope.
@@ -519,7 +520,7 @@ pub struct SqlProgram {
 }
 
 /// A statement in our abstract syntax tree.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum Statement {
     Query(QueryStatement),
     DeleteFrom(DeleteFromStatement),
@@ -531,7 +532,7 @@ pub enum Statement {
 }
 
 /// A query statement. This exists mainly because it's in the official grammar.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct QueryStatement {
     pub query_expression: QueryExpression,
 }
@@ -542,7 +543,7 @@ pub struct QueryStatement {
 ///
 /// [official grammar]:
 /// https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#sql_syntax.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum QueryExpression {
     SelectExpression(SelectExpression),
     Nested {
@@ -583,7 +584,7 @@ impl Emit for QueryExpression {
 }
 
 /// Common table expressions (CTEs).
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct CommonTableExpression {
     pub name: Ident,
     pub as_token: Keyword,
@@ -593,7 +594,7 @@ pub struct CommonTableExpression {
 }
 
 /// Set operators.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum SetOperator {
     UnionAll {
         union_token: Keyword,
@@ -640,7 +641,7 @@ impl Emit for SetOperator {
 }
 
 /// A `SELECT` expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct SelectExpression {
     pub select_options: SelectOptions,
     pub select_list: SelectList,
@@ -656,14 +657,14 @@ pub struct SelectExpression {
 }
 
 /// The head of a `SELECT`, including any modifiers.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct SelectOptions {
     pub select_token: Keyword,
     pub distinct: Option<Distinct>,
 }
 
 /// The `DISTINCT` modifier.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Distinct {
     pub distinct_token: Keyword,
 }
@@ -686,13 +687,13 @@ pub struct Distinct {
 /// select_expression:
 ///   expression [ [ AS ] alias ]
 /// ```
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct SelectList {
     pub items: NodeVec<SelectListItem>,
 }
 
 /// A single item in a `SELECT` list.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum SelectListItem {
     /// An expression, optionally with an alias.
     Expression {
@@ -711,7 +712,7 @@ pub enum SelectListItem {
 }
 
 /// An `EXCEPT` clause.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub struct Except {
     pub except_token: Keyword,
     pub paren1: Punct,
@@ -738,7 +739,7 @@ impl Emit for Except {
 }
 
 /// An SQL expression.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum Expression {
     Literal(Literal),
     BoolValue(Keyword),
@@ -861,7 +862,7 @@ impl Emit for Expression {
 }
 
 /// An `INTERVAL` expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct IntervalExpression {
     pub interval_token: Keyword,
     pub number: Literal,
@@ -869,13 +870,13 @@ pub struct IntervalExpression {
 }
 
 /// A date part in an `INTERVAL` expression, or in the special date functions.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct DatePart {
     pub date_part_token: CaseInsensitiveIdent,
 }
 
 /// A cast expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Cast {
     cast_type: CastType,
     paren1: Punct,
@@ -886,7 +887,7 @@ pub struct Cast {
 }
 
 /// What type of cast do we want to perform?
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum CastType {
     Cast {
         cast_token: Keyword,
@@ -918,7 +919,7 @@ impl Emit for CastType {
 ///
 /// [official grammar]:
 /// https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#in_operators
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum InValueSet {
     QueryExpression {
         paren1: Punct,
@@ -943,7 +944,7 @@ pub enum InValueSet {
 ///
 /// Not all combinations of our fields are valid. For example, we can't have
 /// a missing `ARRAY` and a `delim1` of `(`. We'll let the parser handle that.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub struct ArrayExpression {
     pub array_token: Option<Keyword>,
     pub element_type: Option<ArrayElementType>,
@@ -988,14 +989,14 @@ impl Emit for ArrayExpression {
 
 /// An `ARRAY` definition. Either a `SELECT` expression or a list of
 /// expressions.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum ArrayDefinition {
     Query(Box<SelectExpression>),
     Elements(NodeVec<Expression>),
 }
 
 /// A struct expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct StructExpression {
     pub struct_token: Keyword,
     pub paren1: Punct,
@@ -1004,7 +1005,7 @@ pub struct StructExpression {
 }
 
 /// The type of the elements in an `ARRAY` expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct ArrayElementType {
     pub lt: Punct,
     pub elem_type: DataType,
@@ -1012,7 +1013,7 @@ pub struct ArrayElementType {
 }
 
 /// A `COUNT` expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum CountExpression {
     CountStar {
         count_token: CaseInsensitiveIdent,
@@ -1030,7 +1031,7 @@ pub enum CountExpression {
 }
 
 /// A `CASE WHEN` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct CaseWhenClause {
     pub when_token: Keyword,
     pub condition: Box<Expression>,
@@ -1039,7 +1040,7 @@ pub struct CaseWhenClause {
 }
 
 /// A `CASE ELSE` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct CaseElseClause {
     pub else_token: Keyword,
     pub result: Box<Expression>,
 }
 
 /// `CURRENT_DATE` may appear as either `CURRENT_DATE` or `CURRENT_DATE()`.
 /// And different databases seem to support one or the other or both.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub struct CurrentDate {
     pub current_date_token: CaseInsensitiveIdent,
     pub empty_parens: Option<EmptyParens>,
 }
@@ -1064,7 +1065,7 @@ impl Emit for CurrentDate {
 }
 
 /// An empty `()` expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct EmptyParens {
     pub paren1: Punct,
     pub paren2: Punct,
 }
 
 /// Special "functions" that manipulate dates. These all take a [`DatePart`]
 /// as a final argument.
So in the Lisp sense, these are special forms or macros,
 /// not ordinary function calls.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct SpecialDateFunctionCall {
     pub function_name: CaseInsensitiveIdent,
     pub paren1: Punct,
@@ -1082,14 +1083,14 @@ pub struct SpecialDateFunctionCall {
 }
 
 /// An expression or a date part.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum ExpressionOrDatePart {
     Expression(Expression),
     DatePart(DatePart),
 }
 
 /// A function call.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct FunctionCall {
     pub name: FunctionName,
     pub paren1: Punct,
@@ -1099,7 +1100,7 @@ pub struct FunctionCall {
 }
 
 /// A function name.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum FunctionName {
     ProjectDatasetFunction {
         project: Ident,
@@ -1166,7 +1167,7 @@ impl Emit for FunctionName {
 /// See the [official grammar][]. We only implement part of this.
 ///
 /// [official grammar]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls#syntax
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct OverClause {
     pub over_token: Keyword,
     pub paren1: Punct,
@@ -1177,7 +1178,7 @@ pub struct OverClause {
 }
 
 /// A `PARTITION BY` clause for a window function.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct PartitionBy {
     pub partition_token: Keyword,
     pub by_token: Keyword,
@@ -1185,7 +1186,7 @@ pub struct PartitionBy {
 }
 
 /// An `ORDER BY` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct OrderBy {
     pub order_token: Keyword,
     pub by_token: Keyword,
@@ -1193,28 +1194,28 @@ pub struct OrderBy {
 }
 
 /// An item in an `ORDER BY` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct OrderByItem {
     pub expression: Expression,
     pub asc_desc: Option<AscDesc>,
 }
 
 /// An `ASC` or `DESC` modifier.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct AscDesc {
     direction: Keyword,
     nulls_clause: Option<NullsClause>,
 }
 
 /// A `NULLS FIRST` or `NULLS LAST` modifier.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct NullsClause {
     nulls_token: Keyword,
     first_last_token: CaseInsensitiveIdent,
 }
 
 /// A `LIMIT` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Limit {
     pub limit_token: Keyword,
     pub value: Box<Expression>,
@@ -1225,14 +1226,14 @@ pub struct Limit {
 
 /// See the [official grammar][]. We only implement part of this.
 ///
 /// [official grammar]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls#def_window_frame
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct WindowFrame {
     pub rows_token: Keyword,
     pub definition: WindowFrameDefinition,
 }
 
 /// A window frame definition.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum WindowFrameDefinition {
     Start(WindowFrameStart),
     Between {
@@ -1244,7 +1245,7 @@ pub enum WindowFrameDefinition {
 }
 
 /// A window frame start. Keep this simple for now.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum WindowFrameStart {
     UnboundedPreceding {
         unbounded_token: Keyword,
@@ -1253,7 +1254,7 @@ pub enum WindowFrameStart {
 }
 
 /// A window frame end. Keep this simple for now.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum WindowFrameEnd {
     CurrentRow {
         current_token: Keyword,
@@ -1262,7 +1263,7 @@ pub enum WindowFrameEnd {
 }
 
 /// Data types.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum DataType {
     Bool(CaseInsensitiveIdent),
     Bytes(CaseInsensitiveIdent),
@@ -1383,14 +1384,14 @@ impl Emit for DataType {
 }
 
 /// A field in a `STRUCT` type.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct StructField {
     pub name: Option<Ident>,
     pub data_type: DataType,
 }
 
 /// An array index expression.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct IndexExpression {
     pub expression: Box<Expression>,
     pub bracket1: Punct,
@@ -1399,7 +1400,7 @@ pub struct IndexExpression {
 }
 
 /// Different ways to index arrays.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum IndexOffset {
     Simple(Box<Expression>),
     Offset {
@@ -1460,14 +1461,14 @@ impl Emit for IndexOffset {
 }
 
 /// An `AS` alias.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Alias {
     pub as_token: Option<Keyword>,
     pub ident: Ident,
 }
 
 /// The `FROM` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct FromClause {
     pub from_token: Keyword,
     pub from_item: FromItem,
@@ -1475,7 +1476,7 @@ pub struct FromClause {
 }
 
 /// Items which may appear in a `FROM` clause.
-#[derive(Clone, Debug, Drive, DriveMut, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
 pub enum FromItem {
     /// A table name, optionally with an alias.
     TableName {
@@ -1529,7 +1530,7 @@ impl Emit for FromItem {
 }
 
 /// A join operation.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum JoinOperation {
     /// A `JOIN` clause.
     ConditionJoin {
@@ -1549,7 +1550,7 @@ pub enum JoinOperation {
 }
 
 /// The type of a join.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum JoinType {
     Inner {
         inner_token: Option<Keyword>,
@@ -1569,7 +1570,7 @@ pub enum JoinType {
 }
 
 /// The condition used for a `JOIN`.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum ConditionJoinOperator {
     Using {
         using_token: Keyword,
@@ -1584,14 +1585,14 @@ pub enum ConditionJoinOperator {
 }
 
 /// A `WHERE` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct WhereClause {
     pub where_token: Keyword,
     pub expression: Expression,
 }
 
 /// A `GROUP BY` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct GroupBy {
     pub group_token: Keyword,
     pub by_token: Keyword,
@@ -1599,21 +1600,21 @@ pub struct GroupBy {
 }
 
 /// A `HAVING` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Having {
     pub having_token: Keyword,
     pub expression: Expression,
 }
 
 /// A `QUALIFY` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Qualify {
     pub qualify_token: Keyword,
     pub expression: Expression,
 }
 
 /// A `DELETE FROM` statement.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct DeleteFromStatement {
     // DDL "keywords" are not actually treated as such by BigQuery.
     pub delete_token: CaseInsensitiveIdent,
@@ -1624,7 +1625,7 @@ pub struct DeleteFromStatement {
 }
 
 /// A `INSERT INTO` statement. We only support the `SELECT` version.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct InsertIntoStatement {
     pub insert_token: CaseInsensitiveIdent,
     pub into_token: Keyword,
@@ -1633,7 +1634,7 @@ pub struct InsertIntoStatement {
 }
 
 /// The data to be inserted into a table.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum InsertedData {
     /// A `SELECT` statement.
     Select { query: QueryExpression },
@@ -1645,7 +1646,7 @@ pub enum InsertedData {
 }
 
 /// A row in a `VALUES` clause.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Row {
     pub paren1: Punct,
     pub expressions: NodeVec<Expression>,
@@ -1653,7 +1654,7 @@ pub struct Row {
 }
 
 /// A `CREATE TABLE` statement.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct CreateTableStatement {
     pub create_token: Keyword,
     pub or_replace: Option<OrReplace>,
@@ -1664,7 +1665,7 @@ pub struct CreateTableStatement {
 }
 
 /// A `CREATE VIEW` statement.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct CreateViewStatement {
     pub create_token: Keyword,
     pub or_replace: Option<OrReplace>,
@@ -1675,20 +1676,20 @@ pub struct CreateViewStatement {
 }
 
 /// The `OR REPLACE` modifier.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct OrReplace {
     pub or_token: Keyword,
     pub replace_token: CaseInsensitiveIdent,
 }
 
 /// The `TEMPORARY` modifier.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct Temporary {
     pub temporary_token: CaseInsensitiveIdent,
 }
 
 /// The part of a `CREATE TABLE` statement that defines the columns.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub enum CreateTableDefinition {
     /// ( column_definition [, ...] )
     Columns {
@@ -1704,14 +1705,14 @@ pub enum CreateTableDefinition {
 }
 
 /// A column definition.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct ColumnDefinition {
     pub name: Ident,
     pub data_type: DataType,
 }
 
 /// A `DROP TABLE` statement.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct DropTableStatement {
     pub drop_token: CaseInsensitiveIdent,
     pub table_token: CaseInsensitiveIdent,
@@ -1720,7 +1721,7 @@ pub struct DropTableStatement {
 }
 
 /// A `DROP VIEW` statement.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct DropViewStatement {
     pub drop_token: CaseInsensitiveIdent,
     pub view_token: CaseInsensitiveIdent,
@@ -1729,7 +1730,7 @@ pub struct DropViewStatement {
 }
 
 /// An `IF EXISTS` modifier.
-#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
 pub struct IfExists {
     pub if_token: Keyword,
     pub exists_token: Keyword,
@@ -1807,7 +1808,7 @@ peg::parser! {
             = statements:sep_opt_trailing(<statement()>, ";") { SqlProgram { statements } }
 
-        rule statement() -> Statement
+        pub rule statement() -> Statement
             = s:query_statement() { Statement::Query(s) }
             / i:insert_into_statement() { Statement::InsertInto(i) }
             / d:delete_from_statement() { Statement::DeleteFrom(d) }
@@ -1975,7 +1976,7 @@ peg::parser! {
         ///
        /// [precedence table]:
        /// https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#operator_precedence
-        rule expression() -> Expression = precedence! {
+        pub rule expression() -> Expression = precedence! {
            left:(@) or_token:k("OR") right:@ { Expression::Or { left: Box::new(left), or_token, right: Box::new(right) } }
            --
            left:(@) and_token:k("AND") right:@ { Expression::And { left: Box::new(left), and_token, right: Box::new(right) } }
diff --git a/src/tokenizer.rs b/src/tokenizer.rs
index ceafef1..56e6e7b 100644
--- a/src/tokenizer.rs
+++ b/src/tokenizer.rs
@@ -38,9 +38,11 @@ use codespan_reporting::{
     files::SimpleFiles,
 };
 use derive_visitor::{Drive, DriveMut};
-use peg::{Parse, ParseElem, RuleResult};
+use joinery_macros::ToTokens;
+use peg::{error::ParseError, Parse, ParseElem, RuleResult};
 
 use crate::{
+    ast,
     drivers::bigquery::BigQueryString,
     errors::{Result, SourceError},
 };
@@ -280,7 +282,7 @@ impl PartialEq for Ident {
 /// A keyword. This is just a thin wrapper over an `Ident` to change the
 /// equality semantics.
-#[derive(Debug, Drive, DriveMut, Clone, Eq)]
+#[derive(Debug, Drive, DriveMut, Clone, Eq, ToTokens)]
 pub struct Keyword {
     /// Our keyword.
     pub ident: Ident,
 }
@@ -305,7 +307,7 @@ impl PartialEq for Keyword {
 /// but it appears in different places in the grammar. These words are only
 /// reserved in specific contexts and they don't normally need to be quoted
 /// when used as column names, etc.
-#[derive(Debug, Drive, DriveMut, Clone, Eq)]
+#[derive(Debug, Drive, DriveMut, Clone, Eq, ToTokens)]
 pub struct CaseInsensitiveIdent {
     /// Our identifier.
     pub ident: Ident,
 }
@@ -421,6 +423,39 @@ impl TokenStream {
         }
     }
 
+    /// Try to parse this stream using a grammar rule. This is generally called
+    /// to re-parse the token stream created by [`sql_quote!`].
+    #[allow(dead_code)]
+    fn try_into_parsed<T, R>(self, grammar_rule: R) -> Result<T>
+    where
+        R: FnOnce(&TokenStream) -> Result<T, ParseError<usize>>,
+    {
+        match grammar_rule(&self) {
+            Ok(t) => Ok(t),
+            Err(err) => {
+                let diagnostic = Diagnostic::error().with_message("Failed to parse token stream");
+                Err(SourceError {
+                    expected: err.to_string(),
+                    files: SimpleFiles::new(),
+                    diagnostic,
+                }
+                .into())
+            }
+        }
+    }
+
+    /// Try to parse this stream as a [`ast::Statement`].
+    #[allow(dead_code)]
+    pub fn try_into_statement(self) -> Result<ast::Statement> {
+        self.try_into_parsed(ast::sql_program::statement)
+    }
+
+    /// Try to parse this stream as a [`ast::Expression`].
+    #[allow(dead_code)]
+    pub fn try_into_expression(self) -> Result<ast::Expression> {
+        self.try_into_parsed(ast::sql_program::expression)
+    }
+
     /// Parse a literal.
     pub fn literal(&self, pos: usize) -> RuleResult<Literal> {
         match self.tokens.get(pos) {
@@ -599,6 +634,18 @@ impl<T: ToTokens> ToTokens for Vec<T> {
     }
 }
 
+impl<T: ToTokens> ToTokens for &T {
+    fn to_tokens(&self, tokens: &mut Vec<Token>) {
+        (*self).to_tokens(tokens);
+    }
+}
+
+impl<T: ToTokens> ToTokens for Box<T> {
+    fn to_tokens(&self, tokens: &mut Vec<Token>) {
+        (**self).to_tokens(tokens);
+    }
+}
+
 impl ToTokens for TokenStream {
     /// This allows composing `TokenStream`s before re-parsing them.
     fn to_tokens(&self, tokens: &mut Vec<Token>) {
@@ -1079,9 +1126,9 @@ mod test {
     }
 
     #[test]
-    fn sql_quote_builds_a_token_stream() {
+    fn sql_quote_and_try_into_statement() {
         let optional_distinct = Some(sql_quote! { DISTINCT });
-        sql_quote! {
+        let statement = sql_quote! {
             SELECT
                 #optional_distinct
                 generate_uuid() AS id,
@@ -1089,6 +1136,31 @@ mod test {
                 "hello" AS message,
                 1.0 AS x,
                 true AS t,
                 false AS f,
-        };
+        }
+        .try_into_statement()
+        .unwrap();
+        assert!(matches!(statement, ast::Statement::Query(_)));
+    }
+
+    #[test]
+    fn expression_rewriting() {
+        let if_expr = sql_quote! { IF(TRUE, 2.0, 1.0) }
+            .try_into_expression()
+            .unwrap();
+        if let ast::Expression::If {
+            condition,
+            then_expression,
+            else_expression,
+            ..
+        } = &if_expr
+        {
+            let case_expr =
+                sql_quote! { CASE WHEN #condition THEN #then_expression ELSE #else_expression END }
+                    .try_into_expression()
+                    .unwrap();
+            assert!(matches!(case_expr, ast::Expression::Case { .. }));
+        } else {
+            panic!("expected IF expression");
+        }
+    }
 }
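
To make the new derive concrete: for a typical AST node, `impl_to_tokens_macro` simply replays each field in declaration order. Here is a hand-expanded sketch of the expansion for a two-field struct (the real output comes from the `quote!` block in `to_tokens.rs`; `Token` is `crate::tokenizer::Token`):

```rust
// Hand-expanded sketch of `#[derive(ToTokens)]` for:
//
//     #[derive(ToTokens)]
//     pub struct WhereClause {
//         pub where_token: Keyword,
//         pub expression: Expression,
//     }
//
// Fields marked `#[to_tokens(skip)]` (like `NodeVec::separator`) are
// filtered out before this list of calls is generated.
impl ToTokens for WhereClause {
    fn to_tokens(&self, tokens: &mut Vec<Token>) {
        self.where_token.to_tokens(tokens);
        self.expression.to_tokens(tokens);
    }
}
```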
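Similarly, `impl_sql_quote` lowers each SQL token to a push onto `__tokens`, and each `#expr` interpolation to a `ToTokens::to_tokens` call on the interpolated Rust expression. A hand-expanded sketch, assuming a `table` variable implementing `ToTokens` is in scope:

```rust
// Hand-expanded sketch of what `sql_quote! { SELECT * FROM #table }`
// produces. The capacity is the number of generated token expressions.
{
    use crate::tokenizer::{Literal, Token, TokenStream};
    let mut __tokens = Vec::with_capacity(4);
    __tokens.push(Token::ident("SELECT"));
    __tokens.push(Token::punct("*"));
    __tokens.push(Token::ident("FROM"));
    // `#table` interpolates a Rust value that implements `ToTokens`.
    (table).to_tokens(&mut __tokens);
    TokenStream { tokens: __tokens }
}
```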
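Together with `TokenStream::try_into_expression`, this is enough for AST-to-AST rewrites: parse a fragment, splice its pieces into a new quotation, and re-parse the result, as the `expression_rewriting` test exercises:

```rust
// Condensed from the `expression_rewriting` test above: rewrite
// `IF(a, b, c)` into `CASE WHEN a THEN b ELSE c END`.
let if_expr = sql_quote! { IF(TRUE, 2.0, 1.0) }
    .try_into_expression()
    .unwrap();
if let ast::Expression::If { condition, then_expression, else_expression, .. } = &if_expr {
    let case_expr =
        sql_quote! { CASE WHEN #condition THEN #then_expression ELSE #else_expression END }
            .try_into_expression()
            .unwrap();
    // `case_expr` is a freshly parsed `ast::Expression::Case { .. }`.
}
```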