Skip to content

Commit

Permalink
Merge pull request #20 from dolthub/daylon/gms-ast
Browse files Browse the repository at this point in the history
Dolt server handling parsed statements
  • Loading branch information
Hydrocharged authored Oct 18, 2023
2 parents d869ee4 + 85dc5c3 commit aefe550
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 35 deletions.
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ require (
github.com/biogo/store v0.0.0-20201120204734-aad293a2328f
github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a
github.com/cockroachdb/errors v1.7.5
github.com/dolthub/dolt/go v0.40.5-0.20231007000446-a29c9d57492d
github.com/dolthub/go-mysql-server v0.17.1-0.20231005225621-4cc2f2ca38ce
github.com/dolthub/vitess v0.0.0-20230929000236-6c60b48b32da
github.com/dolthub/dolt/go v0.40.5-0.20231018220650-48f565111c6a
github.com/dolthub/go-mysql-server v0.17.1-0.20231018193155-b175f6f77388
github.com/dolthub/vitess v0.0.0-20231018185551-6acf9c09c4fa
github.com/fatih/color v1.13.0
github.com/gogo/protobuf v1.3.2
github.com/golang/geo v0.0.0-20200730024412-e86565bf3f35
Expand Down Expand Up @@ -47,7 +47,7 @@ require (
github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/denisbrodbeck/machineid v1.0.1 // indirect
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231007000446-a29c9d57492d // indirect
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231018220650-48f565111c6a // indirect
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 // indirect
github.com/dolthub/fslock v0.0.3 // indirect
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e // indirect
Expand Down
16 changes: 8 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,18 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dolthub/dolt/go v0.40.5-0.20231007000446-a29c9d57492d h1:jdKO88AGZtbyNi7xFByY3S2u92YwHSxpNvahqTe/I4Q=
github.com/dolthub/dolt/go v0.40.5-0.20231007000446-a29c9d57492d/go.mod h1:zpwMN0iYV4H9BV84wI8QiNVVgp01t3tVcKDdBw7lbb0=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231007000446-a29c9d57492d h1:KayDNzxpIHS+x8ZBEDrh4f+013V1/zHfVBzudgXUfv8=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231007000446-a29c9d57492d/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU=
github.com/dolthub/dolt/go v0.40.5-0.20231018220650-48f565111c6a h1:JhVPlX0bjOODGoG15d61p+EdRNhqmWEfDqmrAbAUQEQ=
github.com/dolthub/dolt/go v0.40.5-0.20231018220650-48f565111c6a/go.mod h1:8zPn6VcsIyoLDhESnV9zZlN7j2rb1DrQ7TxljxQKjTk=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231018220650-48f565111c6a h1:WYBHmrFuKPqPfLB8ab1WqL5MgR8PQICkfUs5RTjHZM8=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231018220650-48f565111c6a/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
github.com/dolthub/go-mysql-server v0.17.1-0.20231005225621-4cc2f2ca38ce h1:x27xw9s1Odf4MVCmjU8bqqOEY5wQY4YL18tOPUtJp1k=
github.com/dolthub/go-mysql-server v0.17.1-0.20231005225621-4cc2f2ca38ce/go.mod h1:KWMgEn//scUZuT8vHeHdMWrvCvcE7FrizZ0HKB08zrU=
github.com/dolthub/go-mysql-server v0.17.1-0.20231018193155-b175f6f77388 h1:Vzj2+SdblD+GJX4U0+m5a5Ve6lY/rTwOz1wpRwG+uE0=
github.com/dolthub/go-mysql-server v0.17.1-0.20231018193155-b175f6f77388/go.mod h1:Nk8uVbrCJjRlBluG4jtXaP+frjRFtanjOKJTKnvZDTk=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ=
Expand All @@ -227,8 +227,8 @@ github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 h1:SegEguMxToBn045
github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4=
github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc=
github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ=
github.com/dolthub/vitess v0.0.0-20230929000236-6c60b48b32da h1:QWw80StS0Sxn9UArlg8JE15Rka8g0Uz/nrIhC6K8PUA=
github.com/dolthub/vitess v0.0.0-20230929000236-6c60b48b32da/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20231018185551-6acf9c09c4fa h1:5k+dGyoUAnan2RZmLiYEp8svmEFlJIUfaSKbZ5xXv1s=
github.com/dolthub/vitess v0.0.0-20231018185551-6acf9c09c4fa/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
Expand Down
53 changes: 35 additions & 18 deletions server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/mysql_db"
"github.com/dolthub/vitess/go/mysql"
"github.com/dolthub/vitess/go/sqltypes"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/connection"
"github.com/dolthub/doltgresql/postgres/messages"
Expand All @@ -34,6 +35,13 @@ import (
var connectionIDCounter uint32
var processID = int32(os.Getpid())

// ParsedQuery represents a query that may have been parsed. If Parsed is nil, then the Query should be parsed by the
// listener, otherwise the Parsed statement should be used.
type ParsedQuery struct {
Query string
Parsed vitess.Statement
}

// Listener listens for connections to process PostgreSQL requests into Dolt requests.
type Listener struct {
listener net.Listener
Expand Down Expand Up @@ -221,7 +229,7 @@ InitialMessageLoop:
return
}

statementCache := make(map[string]string)
statementCache := make(map[string]ParsedQuery)
for {
receivedMessages, err := connection.Receive(conn)
if err != nil {
Expand All @@ -232,7 +240,7 @@ InitialMessageLoop:
return
}

portals := make(map[string]string)
portals := make(map[string]ParsedQuery)
ReadMessages:
for _, message := range receivedMessages {
switch message := message.(type) {
Expand All @@ -247,7 +255,7 @@ InitialMessageLoop:
case messages.Query:
var ok bool
if ok, err = l.handledPSQLCommands(conn, mysqlConn, message.String); !ok && err == nil {
var query string
var query ParsedQuery
if query, err = l.reinterpretQuery(message.String); err != nil {
l.endOfMessages(conn, err)
break ReadMessages
Expand All @@ -258,7 +266,7 @@ InitialMessageLoop:
l.endOfMessages(conn, err)
case messages.Parse:
//TODO: fully support prepared statements
var query string
var query ParsedQuery
if query, err = l.reinterpretQuery(message.Query); err != nil {
l.endOfMessages(conn, err)
break ReadMessages
Expand All @@ -270,7 +278,7 @@ InitialMessageLoop:
break ReadMessages
}
case messages.Describe:
var query string
var query ParsedQuery
if message.IsPrepared {
query = statementCache[message.Target]
} else {
Expand Down Expand Up @@ -298,13 +306,13 @@ InitialMessageLoop:
}

// execute handles running the given query. This will post the RowDescription, DataRow, and CommandComplete messages.
func (l *Listener) execute(conn net.Conn, mysqlConn *mysql.Conn, query string) error {
func (l *Listener) execute(conn net.Conn, mysqlConn *mysql.Conn, query ParsedQuery) error {
commandComplete := messages.CommandComplete{
Query: query,
Query: query.Query,
Rows: 0,
}

if err := l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error {
if err := l.comQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error {
if err := connection.Send(conn, messages.RowDescription{
Fields: res.Fields,
}); err != nil {
Expand Down Expand Up @@ -340,7 +348,7 @@ func (l *Listener) execute(conn net.Conn, mysqlConn *mysql.Conn, query string) e
}

// describe handles the description of the given query. This will post the ParameterDescription and RowDescription messages.
func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messages.Describe, statement string) error {
func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messages.Describe, statement ParsedQuery) error {
//TODO: fully support prepared statements
if err := connection.Send(conn, messages.ParameterDescription{
ObjectIDs: nil,
Expand All @@ -349,7 +357,7 @@ func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messag
}

//TODO: properly handle these statements
if ImplicitlyCommits(statement) {
if ImplicitlyCommits(statement.Query) {
return fmt.Errorf("We do not yet support the Describe message for the given statement")
}
// We'll start a transaction, so that we can later rollback any changes that were made.
Expand All @@ -366,7 +374,7 @@ func (l *Listener) describe(conn net.Conn, mysqlConn *mysql.Conn, message messag
})
}()
// Execute the statement, and send the description.
if err := l.cfg.Handler.ComQuery(mysqlConn, statement, func(res *sqltypes.Result, more bool) error {
if err := l.comQuery(mysqlConn, statement, func(res *sqltypes.Result, more bool) error {
if res != nil {
if err := connection.Send(conn, messages.RowDescription{
Fields: res.Fields,
Expand All @@ -387,15 +395,15 @@ func (l *Listener) handledPSQLCommands(conn net.Conn, mysqlConn *mysql.Conn, sta
statement = strings.ToLower(statement)
// Command: \l
if statement == "select d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" {
return true, l.execute(conn, mysqlConn, `SELECT SCHEMA_NAME AS 'Name', 'postgres' AS 'Owner', 'UTF8' AS 'Encoding', 'English_United States.1252' AS 'Collate', 'English_United States.1252' AS 'Ctype', '' AS 'ICU Locale', 'libc' AS 'Locale Provider', '' AS 'Access privileges' FROM INFORMATION_SCHEMA.SCHEMATA ORDER BY 1;`)
return true, l.execute(conn, mysqlConn, ParsedQuery{`SELECT SCHEMA_NAME AS 'Name', 'postgres' AS 'Owner', 'UTF8' AS 'Encoding', 'English_United States.1252' AS 'Collate', 'English_United States.1252' AS 'Ctype', '' AS 'ICU Locale', 'libc' AS 'Locale Provider', '' AS 'Access privileges' FROM INFORMATION_SCHEMA.SCHEMATA ORDER BY 1;`, nil})
}
// Command: \dt
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, l.execute(conn, mysqlConn, `SELECT 'public' AS 'Schema', TABLE_NAME AS 'Name', 'table' AS 'Type', 'postgres' AS 'Owner' FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database() AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`)
return true, l.execute(conn, mysqlConn, ParsedQuery{`SELECT 'public' AS 'Schema', TABLE_NAME AS 'Name', 'table' AS 'Type', 'postgres' AS 'Owner' FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database() AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, nil})
}
// Command: \d
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, l.execute(conn, mysqlConn, `SELECT 'public' AS 'Schema', TABLE_NAME AS 'Name', 'table' AS 'Type', 'postgres' AS 'Owner' FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database() AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`)
return true, l.execute(conn, mysqlConn, ParsedQuery{`SELECT 'public' AS 'Schema', TABLE_NAME AS 'Name', 'table' AS 'Type', 'postgres' AS 'Owner' FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database() AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, nil})
}
// Command: \d table_name
if strings.HasPrefix(statement, "select c.oid,\n n.nspname,\n c.relname\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relname operator(pg_catalog.~) '^(") && strings.HasSuffix(statement, ")$' collate pg_catalog.default\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 2, 3;") {
Expand All @@ -405,20 +413,20 @@ func (l *Listener) handledPSQLCommands(conn net.Conn, mysqlConn *mysql.Conn, sta
}
// Command: \dn
if statement == "select n.nspname as \"name\",\n pg_catalog.pg_get_userbyid(n.nspowner) as \"owner\"\nfrom pg_catalog.pg_namespace n\nwhere n.nspname !~ '^pg_' and n.nspname <> 'information_schema'\norder by 1;" {
return true, l.execute(conn, mysqlConn, "SELECT 'public' AS 'Name', 'pg_database_owner' AS 'Owner';")
return true, l.execute(conn, mysqlConn, ParsedQuery{"SELECT 'public' AS 'Name', 'pg_database_owner' AS 'Owner';", nil})
}
// Command: \df
if statement == "select n.nspname as \"schema\",\n p.proname as \"name\",\n pg_catalog.pg_get_function_result(p.oid) as \"result data type\",\n pg_catalog.pg_get_function_arguments(p.oid) as \"argument data types\",\n case p.prokind\n when 'a' then 'agg'\n when 'w' then 'window'\n when 'p' then 'proc'\n else 'func'\n end as \"type\"\nfrom pg_catalog.pg_proc p\n left join pg_catalog.pg_namespace n on n.oid = p.pronamespace\nwhere pg_catalog.pg_function_is_visible(p.oid)\n and n.nspname <> 'pg_catalog'\n and n.nspname <> 'information_schema'\norder by 1, 2, 4;" {
return true, l.execute(conn, mysqlConn, "SELECT '' AS 'Schema', '' AS 'Name', '' AS 'Result data type', '' AS 'Argument data types', '' AS 'Type' FROM dual LIMIT 0;")
return true, l.execute(conn, mysqlConn, ParsedQuery{"SELECT '' AS 'Schema', '' AS 'Name', '' AS 'Result data type', '' AS 'Argument data types', '' AS 'Type' FROM dual LIMIT 0;", nil})
}
// Command: \dv
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relkind in ('v','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, l.execute(conn, mysqlConn, "SELECT 'public' AS 'Schema', TABLE_NAME AS 'Name', 'view' AS 'Type', 'postgres' AS 'Owner' FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database() AND TABLE_TYPE = 'VIEW' ORDER BY 2;")
return true, l.execute(conn, mysqlConn, ParsedQuery{"SELECT 'public' AS 'Schema', TABLE_NAME AS 'Name', 'view' AS 'Type', 'postgres' AS 'Owner' FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database() AND TABLE_TYPE = 'VIEW' ORDER BY 2;", nil})
}
// Command: \du
if statement == "select r.rolname, r.rolsuper, r.rolinherit,\n r.rolcreaterole, r.rolcreatedb, r.rolcanlogin,\n r.rolconnlimit, r.rolvaliduntil,\n array(select b.rolname\n from pg_catalog.pg_auth_members m\n join pg_catalog.pg_roles b on (m.roleid = b.oid)\n where m.member = r.oid) as memberof\n, r.rolreplication\n, r.rolbypassrls\nfrom pg_catalog.pg_roles r\nwhere r.rolname !~ '^pg_'\norder by 1;" {
// We don't support users yet, so we'll just return nothing for now
return true, l.execute(conn, mysqlConn, "SELECT '' FROM dual LIMIT 0;")
return true, l.execute(conn, mysqlConn, ParsedQuery{"SELECT '' FROM dual LIMIT 0;", nil})
}
return false, nil
}
Expand Down Expand Up @@ -447,3 +455,12 @@ func (l *Listener) endOfMessages(conn net.Conn, err error) {
panic(sendErr)
}
}

// comQuery is a shortcut that determines which version of ComQuery to call based on whether the query has been parsed.
func (l *Listener) comQuery(mysqlConn *mysql.Conn, query ParsedQuery, callback func(res *sqltypes.Result, more bool) error) error {
if query.Parsed == nil {
return l.cfg.Handler.ComQuery(mysqlConn, query.Query, callback)
} else {
return l.cfg.Handler.ComParsedQuery(mysqlConn, query.Query, query.Parsed, callback)
}
}
28 changes: 23 additions & 5 deletions server/reinterpret.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,35 @@ package server
import (
"fmt"

vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/parser"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
)

// reinterpretQuery takes the given Postgres query, and reinterprets it as a query that will work with the handler.
func (l *Listener) reinterpretQuery(query string) (string, error) {
// reinterpretQuery takes the given Postgres query, and reinterprets it as a ParsedQuery that will work with the handler.
func (l *Listener) reinterpretQuery(query string) (ParsedQuery, error) {
s, err := parser.Parse(query)
if err != nil {
return "", err
return ParsedQuery{}, err
}
if len(s) > 1 {
return "", fmt.Errorf("only a single statement at a time is currently supported")
return ParsedQuery{}, fmt.Errorf("only a single statement at a time is currently supported")
}
parsedAST := s[0].AST
// Proof-of-concept on how this can be expanded and used. We'll eventually have a full translation layer to convert
// from one AST to the other. For now, this lets us parse CREATE DATABASE while ignoring extra options like templates.
switch ast := parsedAST.(type) {
case *tree.CreateDatabase:
vitessParsed := &vitess.DBDDL{
Action: vitess.CreateStr,
DBName: ast.Name.String(),
IfNotExists: ast.IfNotExists,
}
// Normally we'd pass the original query in rather than use the empty string (for tracking purposes).
// However, for the sake of demonstration, we're using an empty string so that it's clear that it's working.
return ParsedQuery{"", vitessParsed}, nil
default:
return ParsedQuery{parsedAST.String(), nil}, nil
}
return s[0].AST.String(), nil
}

0 comments on commit aefe550

Please sign in to comment.