Skip to content

Commit

Permalink
sql: add sqrt() function
Browse files Browse the repository at this point in the history
  • Loading branch information
erikgrinaker committed Jul 17, 2024
1 parent 92c471f commit bfecbbe
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 26 deletions.
4 changes: 4 additions & 0 deletions docs/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ The operator precedence (order of operations) is as follows:

Precedence can be overridden by wrapping an expression in parentheses, e.g. `(1 + 2) * 3`.

### Functions

* `sqrt(expr)`: returns the square root of a numerical argument.

### Aggregate functions

Aggregate function aggregate an expression across all rows, optionally grouped into buckets given by `GROUP BY`, and results can be filtered via `HAVING`.
Expand Down
8 changes: 5 additions & 3 deletions src/sql/planner/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,11 @@ impl<'a, C: Catalog> Planner<'a, C> {
ast::Expression::Column(table, name) => {
Column(scope.lookup_column(table.as_deref(), &name)?)
}
// Currently, all functions are aggregates, and processed above.
// TODO: consider adding some basic functions for fun.
ast::Expression::Function(name, _) => return errinput!("unknown function {name}"),
ast::Expression::Function(name, mut args) => match (name.as_str(), args.len()) {
// NB: aggregate functions are processed above.
("sqrt", 1) => SquareRoot(build(Box::new(args.remove(0)))?),
(name, n) => return errinput!("unknown function {name} with {n} arguments"),
},
ast::Expression::Operator(op) => match op {
ast::Operator::And(lhs, rhs) => And(build(lhs)?, build(rhs)?),
ast::Operator::Not(expr) => Not(build(expr)?),
Expand Down
36 changes: 36 additions & 0 deletions src/sql/testscripts/expressions/func
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Tests function calls.

# Function names are case-insensitive.
> sqrt(1)
> SQRT(1)
---
Float(1.0)
Float(1.0)

# A space is allowed around the arguments.
> sqrt ( 1 )
---
Float(1.0)

# Wrong number of arguments errors.
!> sqrt()
!> sqrt(1, 2)
---
Error: invalid input: unknown function sqrt with 0 arguments
Error: invalid input: unknown function sqrt with 2 arguments

# Unknown functions error.
!> unknown()
!> unknown(1, 2, 3)
---
Error: invalid input: unknown function unknown with 0 arguments
Error: invalid input: unknown function unknown with 3 arguments

# Parse errors.
!> unknown(1, 2, 3
!> unknown(1, 2, 3,)
!> unknown(1, 2 3)
---
Error: invalid input: unexpected end of input
Error: invalid input: expected expression atom, found )
Error: invalid input: expected token ,, found 3
52 changes: 52 additions & 0 deletions src/sql/testscripts/expressions/func_sqrt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Tests sqrt().

# Integers work, and return floats.
[expr]> sqrt(2)
[expr]> sqrt(100)
---
Float(1.4142135623730951) ← SquareRoot(Constant(Integer(2)))
Float(10.0) ← SquareRoot(Constant(Integer(100)))

# Negative integers error, but 0 is valid.
!> sqrt(-1)
> sqrt(0)
---
Error: invalid input: can't take square root of -1
Float(0.0)

# Floats work.
> sqrt(3.14)
> sqrt(100.0)
---
Float(1.772004514666935)
Float(10.0)

# Negative floats work, but return NAN.
> sqrt(-1.0)
---
Float(NaN)

# Test various special float values.
> sqrt(-0.0)
> sqrt(0.0)
> sqrt(NAN)
> sqrt(INFINITY)
> sqrt(-INFINITY)
---
Float(-0.0)
Float(0.0)
Float(NaN)
Float(inf)
Float(NaN)

# NULL is passed through.
> sqrt(NULL)
---
Null

# Strings and booleans error.
!> sqrt(TRUE)
!> sqrt('foo')
---
Error: invalid input: can't take square root of TRUE
Error: invalid input: can't take square root of foo
21 changes: 0 additions & 21 deletions src/sql/testscripts/expressions/function

This file was deleted.

2 changes: 1 addition & 1 deletion src/sql/testscripts/queries/group_by
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Error: invalid input: unknown column mod
# GROUP BY can't use aggregate functions.
!> SELECT COUNT(*) FROM test GROUP BY MIN(id)
---
Error: invalid input: unknown function min
Error: invalid input: unknown function min with 1 arguments

# GROUP BY works with multiple groups.
[plan]> SELECT "group", "bool", COUNT(*) FROM test GROUP BY "group", "bool"
Expand Down
13 changes: 12 additions & 1 deletion src/sql/types/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ pub enum Expression {
Multiply(Box<Expression>, Box<Expression>),
/// Negates the given number: -a.
Negate(Box<Expression>),
/// Takes the square root of a number: √a.
SquareRoot(Box<Expression>),
/// Subtracts two numbers: a - b.
Subtract(Box<Expression>, Box<Expression>),

Expand Down Expand Up @@ -88,6 +90,7 @@ impl Expression {
Self::Modulo(lhs, rhs) => format!("{} % {}", format(lhs), format(rhs)),
Self::Multiply(lhs, rhs) => format!("{} * {}", format(lhs), format(rhs)),
Self::Negate(expr) => format!("-{}", format(expr)),
Self::SquareRoot(expr) => format!("sqrt({})", format(expr)),
Self::Subtract(lhs, rhs) => format!("{} - {}", format(lhs), format(rhs)),

Self::Like(lhs, rhs) => format!("{} LIKE {}", format(lhs), format(rhs)),
Expand Down Expand Up @@ -202,6 +205,12 @@ impl Expression {
Null => Null,
value => return errinput!("can't negate {value}"),
},
Self::SquareRoot(expr) => match expr.evaluate(row)? {
Integer(i) if i >= 0 => Float((i as f64).sqrt()),
Float(f) => Float(f.sqrt()),
Null => Null,
value => return errinput!("can't take square root of {value}"),
},
Self::Subtract(lhs, rhs) => lhs.evaluate(row)?.checked_sub(&rhs.evaluate(row)?)?,

// LIKE pattern matching, using _ and % as single- and
Expand Down Expand Up @@ -242,7 +251,8 @@ impl Expression {
| Self::IsNaN(expr)
| Self::IsNull(expr)
| Self::Negate(expr)
| Self::Not(expr) => expr.walk(visitor),
| Self::Not(expr)
| Self::SquareRoot(expr) => expr.walk(visitor),

Self::Constant(_) | Self::Column(_) => true,
}
Expand Down Expand Up @@ -281,6 +291,7 @@ impl Expression {
Self::Modulo(lhs, rhs) => Self::Modulo(transform(lhs)?, transform(rhs)?),
Self::Multiply(lhs, rhs) => Self::Multiply(transform(lhs)?, transform(rhs)?),
Self::Or(lhs, rhs) => Self::Or(transform(lhs)?, transform(rhs)?),
Self::SquareRoot(expr) => Self::SquareRoot(transform(expr)?),
Self::Subtract(lhs, rhs) => Self::Subtract(transform(lhs)?, transform(rhs)?),

Self::Factorial(expr) => Self::Factorial(transform(expr)?),
Expand Down

0 comments on commit bfecbbe

Please sign in to comment.