Skip to content

Commit

Permalink
feat: added support for database transactions
Browse files Browse the repository at this point in the history
- two new functions `GraphQLTx` and `GraphQLByNameTx` allows you to pass a db transaction to be used for the request
- added a new `Tx` property to `ReqConfig` so you can also just use the existing functions with transactions
  • Loading branch information
dosco committed Jan 17, 2023
1 parent 2ffe65e commit c7ed332
Show file tree
Hide file tree
Showing 33 changed files with 419 additions and 932 deletions.
47 changes: 42 additions & 5 deletions core/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ type ReqConfig struct {

// Pass additional variables complex variables such as functions that return string values.
Vars map[string]interface{}

// Execute this query as part of a transaction
Tx *sql.Tx
}

// SetNamespace is used to set namespace requests within a single instance of GraphJin. For example queries with the same name
Expand All @@ -277,12 +280,14 @@ func (rc *ReqConfig) GetNamespace() (string, bool) {
return "", false
}

// GraphQL function is called on the GraphJin struct to convert the provided GraphQL query into an
// SQL query and execute it on the database. In production mode prepared statements are directly used
// and no query compiling takes places.
// GraphQL function is our main function it takes a GraphQL query compiles it
// to SQL and executes returning the resulting JSON.
//
// In production mode the compiling happens only once and from there on the compiled queries
// are directly executed.
//
// In developer mode all named queries are saved into the queries folder and in production mode only
// queries from these saved queries can be used
// queries from these saved queries can be used.
func (g *GraphJin) GraphQL(c context.Context,
query string,
vars json.RawMessage,
Expand Down Expand Up @@ -345,8 +350,24 @@ func (g *GraphJin) GraphQL(c context.Context,
return
}

// GraphQLTx is similiar to the GraphQL function except that it can be used
// within a database transactions.
func (g *GraphJin) GraphQLTx(c context.Context,
tx *sql.Tx,
query string,
vars json.RawMessage,
rc *ReqConfig,
) (res *Result, err error) {
if rc == nil {
rc = &ReqConfig{Tx: tx}
} else {
rc.Tx = tx
}
return g.GraphQL(c, query, vars, rc)
}

// GraphQLByName is similar to the GraphQL function except that queries saved
// in the queries folder can directly be used by their filename.
// in the queries folder can directly be used just by their name (filename).
func (g *GraphJin) GraphQLByName(c context.Context,
name string,
vars json.RawMessage,
Expand All @@ -370,6 +391,22 @@ func (g *GraphJin) GraphQLByName(c context.Context,
return
}

// GraphQLByNameTx is similiar to the GraphQLByName function except
// that it can be used within a database transactions.
func (g *GraphJin) GraphQLByNameTx(c context.Context,
tx *sql.Tx,
name string,
vars json.RawMessage,
rc *ReqConfig,
) (res *Result, err error) {
if rc == nil {
rc = &ReqConfig{Tx: tx}
} else {
rc.Tx = tx
}
return g.GraphQLByName(c, name, vars, rc)
}

type graphqlReq struct {
ns string
op qcode.QType
Expand Down
3 changes: 3 additions & 0 deletions core/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ type Config struct {
// Duration for polling the database to detect schema changes
DBSchemaPollDuration time.Duration `mapstructure:"db_schema_poll_duration" json:"db_schema_poll_duration" yaml:"db_schema_poll_duration" jsonschema:"title=Schema Change Detection Polling Duration,default=10s"`

// When set to the string "yes" it disables production security features like enforcing the allow list
DisableProdSecurity string `mapstructure:"disable_production_security" json:"disable_production_security" yaml:"disable_production_security" jsonschema:"title=Disable Production Security"`

// The default path to find all configuration files and scripts under
configPath string

Expand Down
15 changes: 10 additions & 5 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ func (gj *graphjin) executeRoleQuery(c context.Context,
return
}

if conn == nil {
needsConn := ((rc != nil && rc.Tx == nil) && conn == nil)
if needsConn {
c1, span := gj.spanStart(c, "Get Connection")
defer span.End()

Expand All @@ -254,10 +255,14 @@ func (gj *graphjin) executeRoleQuery(c context.Context,
c1, span := gj.spanStart(c, "Execute Role Query")
defer span.End()

err = retryOperation(c1, func() (err1 error) {
return conn.
QueryRowContext(c1, gj.roleStmt, ar.values...).
Scan(&role)
err = retryOperation(c1, func() error {
var row *sql.Row
if rc != nil && rc.Tx != nil {
row = rc.Tx.QueryRowContext(c1, gj.roleStmt, ar.values...)
} else {
row = conn.QueryRowContext(c1, gj.roleStmt, ar.values...)
}
return row.Scan(&role)
})
if err != nil {
span.Error(err)
Expand Down
2 changes: 1 addition & 1 deletion core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"golang.org/x/sync/errgroup"
)

// nolint: errcheck
// nolint:errcheck
func TestReadInConfigWithEnvVars(t *testing.T) {
devConfig := "secret_key: dev_secret_key\n"
prodConfig := "inherits: dev\nsecret_key: \"prod_secret_key\"\n"
Expand Down
72 changes: 45 additions & 27 deletions core/gstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,21 @@ func (s *gstate) compileAndExecuteWrapper(c context.Context) (err error) {
func (s *gstate) compileAndExecute(c context.Context) (err error) {
var conn *sql.Conn

// get a new database connection
c1, span1 := s.gj.spanStart(c, "Get Connection")
defer span1.End()
if s.tx() == nil {
// get a new database connection
c1, span1 := s.gj.spanStart(c, "Get Connection")
defer span1.End()

err = retryOperation(c1, func() (err1 error) {
conn, err1 = s.gj.db.Conn(c1)
return
})
if err != nil {
span1.Error(err)
return
err = retryOperation(c1, func() (err1 error) {
conn, err1 = s.gj.db.Conn(c1)
return
})
if err != nil {
span1.Error(err)
return
}
defer conn.Close()
}
defer conn.Close()

// set the local user id on the connection if needed
if s.gj.conf.SetUserID {
Expand Down Expand Up @@ -236,9 +238,13 @@ func (s *gstate) execute(c context.Context, conn *sql.Conn) (err error) {
defer span.End()

err = retryOperation(c1, func() (err1 error) {
return conn.
QueryRowContext(c1, cs.st.sql, args.values...).
Scan(&s.data)
var row *sql.Row
if tx := s.tx(); tx != nil {
row = tx.QueryRowContext(c1, cs.st.sql, args.values...)
} else {
row = conn.QueryRowContext(c1, cs.st.sql, args.values...)
}
return row.Scan(&s.data)
})

if err != nil && err != sql.ErrNoRows {
Expand Down Expand Up @@ -289,12 +295,17 @@ func (s *gstate) setLocalUserID(c context.Context, conn *sql.Conn) (err error) {
if v := c.Value(UserIDKey); v == nil {
return nil
} else {
var q string
switch v1 := v.(type) {
case string:
_, err = conn.ExecContext(c, `SET SESSION "user.id" = '`+v1+`'`)

q = `SET SESSION "user.id" = '` + v1 + `'`
case int:
_, err = conn.ExecContext(c, `SET SESSION "user.id" = `+strconv.Itoa(v1))
q = `SET SESSION "user.id" = ` + strconv.Itoa(v1)
}
if tx := s.tx(); tx != nil {
_, err = tx.ExecContext(c, q)
} else {
_, err = conn.ExecContext(c, q)
}
}
return
Expand Down Expand Up @@ -344,23 +355,30 @@ func (s *gstate) validateAndUpdateVars(c context.Context) (err error) {
return
}

func (s *gstate) sql() string {
func (s *gstate) sql() (sql string) {
if s.cs != nil && s.cs.st.qc != nil {
return s.cs.st.sql
sql = s.cs.st.sql
}
return ""
return
}

func (s *gstate) cacheHeader() string {
func (s *gstate) cacheHeader() (ch string) {
if s.cs != nil && s.cs.st.qc != nil {
return s.cs.st.qc.Cache.Header
ch = s.cs.st.qc.Cache.Header
}
return ""
return
}

func (s *gstate) qcode() *qcode.QCode {
if s.cs != nil && s.cs.st.qc != nil {
return s.cs.st.qc
func (s *gstate) qcode() (qc *qcode.QCode) {
if s.cs != nil {
qc = s.cs.st.qc
}
return nil
return
}

func (s *gstate) tx() (tx *sql.Tx) {
if s.r.rc != nil {
tx = s.r.rc.Tx
}
return
}
55 changes: 38 additions & 17 deletions core/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@ import (

func Example_insert() {
gql := `mutation {
users(insert: $data) {
users(insert: {
id: $id,
email: $email,
full_name: $fullName,
stripe_id: $stripeID,
category_counts: $categoryCounts
}) {
id
email
}
}`

vars := json.RawMessage(`{
"data": {
"id": 1001,
"email": "[email protected]",
"full_name": "User 1001",
"stripe_id": "payment_id_1001",
"category_counts": [{"category_id": 1, "count": 400},{"category_id": 2, "count": 600}]
}
"id": 1001,
"email": "[email protected]",
"fullName": "User 1001",
"stripeID": "payment_id_1001",
"categoryCounts": [{"category_id": 1, "count": 400},{"category_id": 2, "count": 600}]
}`)

conf := newConfig(&core.Config{DBType: dbType, DisableAllowList: true})
Expand All @@ -48,19 +52,26 @@ func Example_insert() {
// Output: {"users":[{"email":"[email protected]","id":1001}]}
}

func Example_insertInline() {
func Example_insertWithTransaction() {
gql := `mutation {
users(insert: { id: $id, email: $email, full_name: $full_name }) {
users(insert: {
id: $id,
email: $email,
full_name: $fullName,
stripe_id: $stripeID,
category_counts: $categoryCounts
}) {
id
email
full_name
}
}`

vars := json.RawMessage(`{
"id": 1007,
"email": "[email protected]",
"full_name": "User 1007"
"fullName": "User 1007",
"stripeID": "payment_id_1007",
"categoryCounts": [{"category_id": 1, "count": 400},{"category_id": 2, "count": 600}]
}`)

conf := newConfig(&core.Config{DBType: dbType, DisableAllowList: true})
Expand All @@ -69,14 +80,24 @@ func Example_insertInline() {
panic(err)
}

ctx := context.WithValue(context.Background(), core.UserIDKey, 3)
res, err := gj.GraphQL(ctx, gql, vars, nil)
c := context.Background()
tx, err := db.BeginTx(c, nil)
if err != nil {
panic(err)
}
defer tx.Rollback() //nolint:errcheck

c = context.WithValue(c, core.UserIDKey, 3)
res, err := gj.GraphQLTx(c, tx, gql, vars, nil)
if err != nil {
fmt.Println(err)
} else {
printJSON(res.Data)
return
}
if err := tx.Commit(); err != nil {
panic(err)
}
// Output: {"users":[{"email":"[email protected]","full_name":"User 1007","id":1007}]}
printJSON(res.Data)
// Output: {"users":[{"email":"[email protected]","id":1007}]}
}

func Example_insertInlineWithValidation() {
Expand Down
2 changes: 1 addition & 1 deletion core/internal/allow/gql.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func parseGQL(fs plugin.FS, fname string, r io.Writer) (err error) {
for s.Scan() {
m := incRe.FindStringSubmatch(s.Text())
if len(m) == 0 {
r.Write(s.Bytes()) //nolint: errcheck
r.Write(s.Bytes()) //nolint:errcheck
continue
}

Expand Down
4 changes: 2 additions & 2 deletions core/internal/psql/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func (c *Compiler) RenderVar(w *bytes.Buffer, md *Metadata, vv string) {
cc.renderVar(vv)
}

// nolint: errcheck
// nolint:errcheck
func (c *compilerContext) renderVar(vv string) {
f, s := -1, 0

Expand Down Expand Up @@ -47,7 +47,7 @@ func (c *compilerContext) renderVar(vv string) {
}
}

// nolint: errcheck
// nolint:errcheck
func (c *compilerContext) renderParam(p Param) {
var id int
var ok bool
Expand Down
2 changes: 1 addition & 1 deletion core/internal/util/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
)

// nolint: errcheck
// nolint:errcheck
func TestGraph1(t *testing.T) {
g := util.NewGraph()

Expand Down
Loading

1 comment on commit c7ed332

@vercel
Copy link

@vercel vercel bot commented on c7ed332 Jan 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.