Skip to content

Commit

Permalink
Merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmu committed Jan 5, 2024
2 parents b684aa1 + 9262b49 commit 9492d9c
Show file tree
Hide file tree
Showing 16 changed files with 1,184 additions and 5 deletions.
18 changes: 18 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,24 @@ There are exceptions, as some statements we do not yet support, and cannot suppo
In these cases, we must add a `//TODO:` comment stating what is missing and why it isn't an error.
This will at least allow us to track all such instances where we deviate from the expected behavior, which we can also document elsewhere for users of DoltgreSQL.

### `server/functions`

The `functions` package contains the functions, along with an implementation to approximate the function overloading structure (and type coercion).

The function overloading structure is defined in all files that have the `zinternal_` prefix.
Although not preferable, this was chosen as Go does not allow cyclical references between packages.
Rather than have half of the implementation in `functions`, and the other half in another package, the decision was made to include both in the `functions` package with the added prefix for distinction.

There's an `init` function in `server/functions/zinternal_catalog.go` (this is included in `server/listener.go`) that removes any conflicting GMS function names, and replaces them with the PostgreSQL equivalents.
This means that the functions that we've added behave as expected, and for others to have _some_ sort of implementation rather than outright failing.
We will eventually remove all GMS functions once all PostgreSQL functions have been implemented.
The other internal files all contribute to the generation of functions, along with their proper handling.

Each function (and all overloads) are contained in a single file.
Overloads are named according to their parameters, and prefixed by their target function name.
The set of overloads are then added to the `Catalog` within `server/functions/zinternal_catalog.go`.
To add a new function, it is as simple as creating the `Function`, adding the overloads, and adding it to the `Catalog`.

### `testing/bats`

All Bats tests must follow this general structure:
Expand Down
6 changes: 2 additions & 4 deletions server/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,8 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
}

switch node.SyntaxMode {
case tree.CastExplicit:
// only acceptable cast type
case tree.CastShort:
return nil, fmt.Errorf("TYPECAST is not yet supported")
case tree.CastExplicit, tree.CastShort:
// Both of these are acceptable
case tree.CastPrepend:
return nil, fmt.Errorf("typed literals are not yet supported")
default:
Expand Down
1 change: 1 addition & 0 deletions server/ast/resolvable_type_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv
columnTypeName = columnType.SQLStandardName()
switch columnType.Family() {
case types.DecimalFamily:
columnTypeName = "decimal"
columnTypeLength = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Precision()))))
columnTypeScale = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Scale()))))
case types.JsonFamily:
Expand Down
37 changes: 37 additions & 0 deletions server/functions/cbrt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"fmt"
"math"
)

// cbrt represents the PostgreSQL function of the same name.
var cbrt = Function{
Name: "cbrt",
Overloads: []interface{}{cbrt_float},
}

// cbrt_float is one of the overloads of cbrt.
func cbrt_float(num FloatType) (FloatType, error) {
if num.IsNull {
return FloatType{IsNull: true}, nil
}
if num.OriginalType == ParameterType_String {
return FloatType{}, fmt.Errorf("function cbrt(%s) does not exist", ParameterType_String.String())
}
return FloatType{Value: math.Cbrt(num.Value)}, nil
}
44 changes: 44 additions & 0 deletions server/functions/gcd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"fmt"

"github.com/dolthub/doltgresql/utils"
)

// gcd represents the PostgreSQL function of the same name.
var gcd = Function{
Name: "gcd",
Overloads: []interface{}{gcd_int_int},
}

// gcd_int_int is one of the overloads of gcd.
func gcd_int_int(num1 IntegerType, num2 IntegerType) (IntegerType, error) {
if num1.IsNull || num2.IsNull {
return IntegerType{IsNull: true}, nil
}
if num1.OriginalType == ParameterType_String || num2.OriginalType == ParameterType_String {
return IntegerType{}, fmt.Errorf("function gcd(%s, %s) does not exist",
num1.OriginalType.String(), num2.OriginalType.String())
}
for num2.Value != 0 {
temp := num2.Value
num2.Value = num1.Value % num2.Value
num1.Value = temp
}
return IntegerType{Value: utils.Abs(num1.Value)}, nil
}
46 changes: 46 additions & 0 deletions server/functions/lcm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"fmt"

"github.com/dolthub/doltgresql/utils"
)

// lcm represents the PostgreSQL function of the same name.
var lcm = Function{
Name: "lcm",
Overloads: []interface{}{lcm1_int_int},
}

// lcm1 is one of the overloads of lcm.
func lcm1_int_int(num1 IntegerType, num2 IntegerType) (IntegerType, error) {
if num1.IsNull || num2.IsNull {
return IntegerType{IsNull: true}, nil
}
if num1.OriginalType == ParameterType_String || num2.OriginalType == ParameterType_String {
return IntegerType{}, fmt.Errorf("function lcm(%s, %s) does not exist",
num1.OriginalType.String(), num2.OriginalType.String())
}
gcdResult, err := gcd_int_int(num1, num2)
if err != nil {
return IntegerType{}, err
}
if gcdResult.Value == 0 {
return IntegerType{Value: 0}, nil
}
return IntegerType{Value: utils.Abs((num1.Value * num2.Value) / gcdResult.Value)}, nil
}
48 changes: 48 additions & 0 deletions server/functions/round.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import "math"

// round represents the PostgreSQL function of the same name.
var round = Function{
Name: "round",
Overloads: []interface{}{round_num, round_float, round_num_dec},
}

// round1 is one of the overloads of round.
func round_num(num NumericType) (NumericType, error) {
if num.IsNull {
return NumericType{IsNull: true}, nil
}
return NumericType{Value: math.Round(num.Value)}, nil
}

// round2 is one of the overloads of round.
func round_float(num FloatType) (FloatType, error) {
if num.IsNull {
return FloatType{IsNull: true}, nil
}
return FloatType{Value: math.RoundToEven(num.Value)}, nil
}

// round3 is one of the overloads of round.
func round_num_dec(num NumericType, decimalPlaces IntegerType) (NumericType, error) {
if num.IsNull || decimalPlaces.IsNull {
return NumericType{IsNull: true}, nil
}
ratio := math.Pow10(int(decimalPlaces.Value))
return NumericType{Value: math.Round(num.Value*ratio) / ratio}, nil
}
121 changes: 121 additions & 0 deletions server/functions/zinternal_catalog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"fmt"
"reflect"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression/function"
)

// Function is a name, along with a collection of functions, that represent a single PostgreSQL function with all of its
// overloads.
type Function struct {
Name string
Overloads []any
}

// Catalog contains all of the PostgreSQL functions. If a new function is added, make sure to add it to the catalog here.
var Catalog = []Function{
cbrt,
gcd,
lcm,
round,
}

// init handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not apply to
// PostgreSQL (and functions of the same name often have different behavior).
func init() {
catalogMap := make(map[string]struct{})
for _, f := range Catalog {
catalogMap[strings.ToLower(f.Name)] = struct{}{}
}
var newBuiltIns []sql.Function
for _, f := range function.BuiltIns {
if _, ok := catalogMap[strings.ToLower(f.FunctionName())]; !ok {
newBuiltIns = append(newBuiltIns, f)
}
}
function.BuiltIns = newBuiltIns

allNames := make(map[string]struct{})
for _, catalogItem := range Catalog {
funcName := strings.ToLower(catalogItem.Name)
if _, ok := allNames[funcName]; ok {
panic("duplicate name: " + catalogItem.Name)
}
allNames[funcName] = struct{}{}

baseOverload := &OverloadDeduction{}
for _, functionOverload := range catalogItem.Overloads {
// For each function overload, we first need to ensure that it has an acceptable signature
funcVal := reflect.ValueOf(functionOverload)
if !funcVal.IsValid() || funcVal.IsNil() {
panic(fmt.Errorf("function `%s` has an invalid item", catalogItem.Name))
}
if funcVal.Kind() != reflect.Func {
panic(fmt.Errorf("function `%s` has a non-function item", catalogItem.Name))
}
if funcVal.Type().NumOut() != 2 {
panic(fmt.Errorf("function `%s` has an overload that does not return two values", catalogItem.Name))
}
if funcVal.Type().Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
panic(fmt.Errorf("function `%s` has an overload that does not return an error", catalogItem.Name))
}
returnValType, returnSqlType, ok := ParameterTypeFromReflection(funcVal.Type().Out(0))
if !ok {
panic(fmt.Errorf("function `%s` has an overload that returns as invalid type (`%s`)",
catalogItem.Name, funcVal.Type().Out(0).String()))
}

// Loop through all of the parameters to ensure uniqueness, then store it
currentOverload := baseOverload
for i := 0; i < funcVal.Type().NumIn(); i++ {
paramValType, _, ok := ParameterTypeFromReflection(funcVal.Type().In(i))
if !ok {
panic(fmt.Errorf("function `%s` has an overload with an invalid parameter type (`%s`)",
catalogItem.Name, funcVal.Type().In(i).String()))
}
nextOverload := currentOverload.Parameter[paramValType]
if nextOverload == nil {
nextOverload = &OverloadDeduction{}
currentOverload.Parameter[paramValType] = nextOverload
}
currentOverload = nextOverload
}
if currentOverload.Function.IsValid() && !currentOverload.Function.IsNil() {
panic(fmt.Errorf("function `%s` has duplicate overloads", catalogItem.Name))
}
currentOverload.Function = funcVal
currentOverload.ReturnValType = returnValType
currentOverload.ReturnSqlType = returnSqlType
}

// Store the compiled function into the engine's built-in functions
function.BuiltIns = append(function.BuiltIns, sql.FunctionN{
Name: funcName,
Fn: func(params ...sql.Expression) (sql.Expression, error) {
return &CompiledFunction{
Name: catalogItem.Name,
Parameters: params,
Functions: baseOverload,
}, nil
},
})
}
}
Loading

0 comments on commit 9492d9c

Please sign in to comment.