Skip to content

Commit

Permalink
Move queries with variable number of args to a fixed number of args s…
Browse files Browse the repository at this point in the history
…yntax (#1114)

WHERE x in ($1,$2,$3, ...) is replaced with WHERE x = ANY($1::type_of_x[]), with all values passed as a single array argument.

INSERT INTO t (a,b,c) VALUES ($1,$2,$3), ($4,$5,$6), ... is replaced with INSERT INTO .. SELECT unnest()
  • Loading branch information
ftkg authored Nov 1, 2023
1 parent 2ee6c92 commit 211f6d1
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 232 deletions.
14 changes: 4 additions & 10 deletions server/console_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"database/sql"
"net/url"
"sort"
"strconv"
"strings"
"time"

"github.com/gofrs/uuid/v5"
Expand Down Expand Up @@ -66,20 +64,16 @@ func (s *ConsoleServer) DeleteChannelMessages(ctx context.Context, in *console.D
s.logger.Info("Messages deleted.", zap.Int64("affected", affected), zap.String("timestamp", deleteBefore.String()))
}
if len(in.Ids) > 0 {
params := make([]interface{}, 0, len(in.Ids))
statements := make([]string, len(in.Ids))
for i, id := range in.Ids {
idStr, err := uuid.FromString(id)
for _, id := range in.Ids {
_, err := uuid.FromString(id)
if err != nil {
return nil, status.Error(codes.InvalidArgument, "Requires a valid message ID.")
}
params = append(params, idStr)
statements[i] = "$" + strconv.Itoa(i+1)
}
query := "DELETE FROM message WHERE id IN (" + strings.Join(statements, ",") + ")"
query := "DELETE FROM message WHERE id = ANY($1)"
var res sql.Result
var err error
if res, err = s.db.ExecContext(ctx, query, params...); err != nil {
if res, err = s.db.ExecContext(ctx, query, in.Ids); err != nil {
s.logger.Error("Could not delete messages.", zap.Error(err))
return nil, status.Error(codes.Internal, "An error occurred while trying to delete messages.")
}
Expand Down
11 changes: 2 additions & 9 deletions server/core_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,13 @@ WHERE u.id = $1`
}

func GetAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, statusRegistry *StatusRegistry, userIDs []string) ([]*api.Account, error) {
statements := make([]string, 0, len(userIDs))
parameters := make([]interface{}, 0, len(userIDs))
for _, userID := range userIDs {
parameters = append(parameters, userID)
statements = append(statements, "$"+strconv.Itoa(len(parameters)))
}

query := `
SELECT u.id, u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet,
u.email, u.apple_id, u.facebook_id, u.facebook_instant_game_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count,
u.create_time, u.update_time, u.verify_time, u.disable_time, array(select ud.id from user_device ud where u.id = ud.user_id)
FROM users u
WHERE u.id IN (` + strings.Join(statements, ",") + `)`
rows, err := db.QueryContext(ctx, query, parameters...)
WHERE u.id = ANY($1)`
rows, err := db.QueryContext(ctx, query, userIDs)
if err != nil {
logger.Error("Error retrieving user accounts.", zap.Error(err))
return nil, err
Expand Down
38 changes: 16 additions & 22 deletions server/core_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -854,15 +854,13 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes
return nil
}

statements := make([]string, 0, len(steamProfiles))
params := make([]interface{}, 0, len(steamProfiles))
for i, steamProfile := range steamProfiles {
statements = append(statements, "$"+strconv.Itoa(i+1))
params = append(params, strconv.FormatUint(steamProfile.SteamID, 10))
steamIDs := make([]string, 0, len(steamProfiles))
for _, steamProfile := range steamProfiles {
steamIDs = append(steamIDs, strconv.FormatUint(steamProfile.SteamID, 10))
}

query := "SELECT id FROM users WHERE steam_id IN (" + strings.Join(statements, ", ") + ")"
rows, err := tx.QueryContext(ctx, query, params...)
query := "SELECT id FROM users WHERE steam_id = ANY($1::text[])"
rows, err := tx.QueryContext(ctx, query, steamIDs)
if err != nil {
if err == sql.ErrNoRows {
// None of the friend profiles exist.
Expand All @@ -872,7 +870,7 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes
}

var id string
possibleFriendIDs := make([]uuid.UUID, 0, len(statements))
possibleFriendIDs := make([]uuid.UUID, 0, len(steamIDs))
for rows.Next() {
err = rows.Scan(&id)
if err != nil {
Expand Down Expand Up @@ -930,15 +928,13 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB,
return nil
}

statements := make([]string, 0, len(facebookProfiles))
params := make([]interface{}, 0, len(facebookProfiles))
for i, facebookProfile := range facebookProfiles {
statements = append(statements, "$"+strconv.Itoa(i+1))
params := make([]string, 0, len(facebookProfiles))
for _, facebookProfile := range facebookProfiles {
params = append(params, facebookProfile.ID)
}

query := "SELECT id FROM users WHERE facebook_id IN (" + strings.Join(statements, ", ") + ")"
rows, err := tx.QueryContext(ctx, query, params...)
query := "SELECT id FROM users WHERE facebook_id = ANY($1::text[])"
rows, err := tx.QueryContext(ctx, query, params)
if err != nil {
if err == sql.ErrNoRows {
// None of the friend profiles exist.
Expand All @@ -948,7 +944,7 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB,
}

var id string
possibleFriendIDs := make([]uuid.UUID, 0, len(statements))
possibleFriendIDs := make([]uuid.UUID, 0, len(params))
for rows.Next() {
err = rows.Scan(&id)
if err != nil {
Expand Down Expand Up @@ -1005,8 +1001,7 @@ func resetUserFriends(ctx context.Context, tx *sql.Tx, userID uuid.UUID) error {
if err != nil {
return err
}
statements := make([]string, 0, 10)
params := make([]interface{}, 0, 10)
params := make([]string, 0, 10)
for rows.Next() {
var id string
err = rows.Scan(&id)
Expand All @@ -1015,17 +1010,16 @@ func resetUserFriends(ctx context.Context, tx *sql.Tx, userID uuid.UUID) error {
return err
}
params = append(params, id)
statements = append(statements, "$"+strconv.Itoa(len(params)))
}
_ = rows.Close()

if len(statements) > 0 {
query = "UPDATE users SET edge_count = edge_count - 1 WHERE id IN (" + strings.Join(statements, ",") + ")"
result, err := tx.ExecContext(ctx, query, params...)
if len(params) > 0 {
query = "UPDATE users SET edge_count = edge_count - 1 WHERE id = ANY($1)"
result, err := tx.ExecContext(ctx, query, params)
if err != nil {
return err
}
if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != int64(len(statements)) {
if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != int64(len(params)) {
return errors.New("error updating edge count after friend reset")
}
}
Expand Down
12 changes: 2 additions & 10 deletions server/core_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -1624,18 +1624,10 @@ func GetGroups(ctx context.Context, logger *zap.Logger, db *sql.DB, ids []string
return make([]*api.Group, 0), nil
}

statements := make([]string, 0, len(ids))
params := make([]interface{}, 0, len(ids))
for i, id := range ids {
statements = append(statements, "$"+strconv.Itoa(i+1))
params = append(params, id)
}

query := `SELECT id, creator_id, name, description, avatar_url, state, edge_count, lang_tag, max_count, metadata, create_time, update_time
FROM groups
WHERE disable_time = '1970-01-01 00:00:00 UTC'
AND id IN (` + strings.Join(statements, ",") + `)`
rows, err := db.QueryContext(ctx, query, params...)
WHERE disable_time = '1970-01-01 00:00:00 UTC' AND id = ANY($1)`
rows, err := db.QueryContext(ctx, query, ids)
if err != nil {
if err == sql.ErrNoRows {
return make([]*api.Group, 0), nil
Expand Down
13 changes: 4 additions & 9 deletions server/core_leaderboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"encoding/gob"
"errors"
"sort"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -286,15 +285,11 @@ func LeaderboardRecordsList(ctx context.Context, logger *zap.Logger, db *sql.DB,
}

if len(ownerIds) != 0 {
params := make([]interface{}, 0, len(ownerIds)+2)
params = append(params, leaderboardId, time.Unix(expiryTime, 0).UTC())
statements := make([]string, len(ownerIds))
for i, ownerID := range ownerIds {
params = append(params, ownerID)
statements[i] = "$" + strconv.Itoa(i+3)
}
params := []any{leaderboardId, time.Unix(expiryTime, 0).UTC(), ownerIds}
query := `SELECT owner_id, username, score, subscore, num_score, max_num_score, metadata, create_time, update_time
FROM leaderboard_record
WHERE leaderboard_id = $1 AND expiry_time = $2 AND owner_id = ANY($3)`

query := "SELECT owner_id, username, score, subscore, num_score, max_num_score, metadata, create_time, update_time FROM leaderboard_record WHERE leaderboard_id = $1 AND expiry_time = $2 AND owner_id IN (" + strings.Join(statements, ", ") + ")"
rows, err := db.QueryContext(ctx, query, params...)
if err != nil {
logger.Error("Error reading leaderboard records", zap.Error(err))
Expand Down
60 changes: 26 additions & 34 deletions server/core_notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"encoding/base64"
"encoding/gob"
"fmt"
"strconv"
"strings"
"time"

"github.com/gofrs/uuid/v5"
Expand Down Expand Up @@ -261,17 +259,9 @@ ORDER BY create_time ASC, id ASC`+limitQuery, params...)
}

func NotificationDelete(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, notificationIDs []string) error {
statements := make([]string, 0, len(notificationIDs))
params := make([]interface{}, 0, len(notificationIDs)+1)
params = append(params, userID)

for _, id := range notificationIDs {
statement := "$" + strconv.Itoa(len(params)+1)
statements = append(statements, statement)
params = append(params, id)
}
params := []any{userID, notificationIDs}

query := "DELETE FROM notification WHERE user_id = $1 AND id IN (" + strings.Join(statements, ", ") + ")"
query := "DELETE FROM notification WHERE user_id = $1 AND id = ANY($2)"
logger.Debug("Delete notification query", zap.String("query", query), zap.Any("params", params))
_, err := db.ExecContext(ctx, query, params...)
if err != nil {
Expand All @@ -283,33 +273,35 @@ func NotificationDelete(ctx context.Context, logger *zap.Logger, db *sql.DB, use
}

func NotificationSave(ctx context.Context, logger *zap.Logger, db *sql.DB, notifications map[uuid.UUID][]*api.Notification) error {
statements := make([]string, 0, len(notifications))
params := make([]interface{}, 0, len(notifications))
counter := 0
ids := make([]string, 0, len(notifications))
userIds := make([]uuid.UUID, 0, len(notifications))
subjects := make([]string, 0, len(notifications))
contents := make([]string, 0, len(notifications))
codes := make([]int32, 0, len(notifications))
senderIds := make([]string, 0, len(notifications))
query := `
INSERT INTO
notification (id, user_id, subject, content, code, sender_id)
SELECT
unnest($1::uuid[]),
unnest($2::uuid[]),
unnest($3::text[]),
unnest($4::jsonb[]),
unnest($5::smallint[]),
unnest($6::uuid[]);
`
for userID, no := range notifications {
for _, un := range no {
statement := "$" + strconv.Itoa(counter+1) +
",$" + strconv.Itoa(counter+2) +
",$" + strconv.Itoa(counter+3) +
",$" + strconv.Itoa(counter+4) +
",$" + strconv.Itoa(counter+5) +
",$" + strconv.Itoa(counter+6)

counter = counter + 6
statements = append(statements, "("+statement+")")

params = append(params, un.Id)
params = append(params, userID)
params = append(params, un.Subject)
params = append(params, un.Content)
params = append(params, un.Code)
params = append(params, un.SenderId)
ids = append(ids, un.Id)
userIds = append(userIds, userID)
subjects = append(subjects, un.Subject)
contents = append(contents, un.Content)
codes = append(codes, un.Code)
senderIds = append(senderIds, un.SenderId)
}
}

query := "INSERT INTO notification (id, user_id, subject, content, code, sender_id) VALUES " + strings.Join(statements, ", ")

if _, err := db.ExecContext(ctx, query, params...); err != nil {
if _, err := db.ExecContext(ctx, query, ids, userIds, subjects, contents, codes, senderIds); err != nil {
logger.Error("Could not save notifications.", zap.Error(err))
return err
}
Expand Down
60 changes: 34 additions & 26 deletions server/core_purchase.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"fmt"
"net/http"
"strconv"
"strings"
"time"

"github.com/gofrs/uuid/v5"
Expand Down Expand Up @@ -558,45 +557,53 @@ func upsertPurchases(ctx context.Context, db *sql.DB, purchases []*storagePurcha
return nil, errors.New("expects at least one receipt")
}

statements := make([]string, 0, len(purchases))
params := make([]interface{}, 0, len(purchases)*8)
transactionIDsToPurchase := make(map[string]*storagePurchase)
offset := 0

userIdParams := make([]uuid.UUID, 0, len(purchases))
storeParams := make([]api.StoreProvider, 0, len(purchases))
transactionIdParams := make([]string, 0, len(purchases))
productIdParams := make([]string, 0, len(purchases))
purchaseTimeParams := make([]time.Time, 0, len(purchases))
rawResponseParams := make([]string, 0, len(purchases))
environmentParams := make([]api.StoreEnvironment, 0, len(purchases))
refundTimeParams := make([]time.Time, 0, len(purchases))

for _, purchase := range purchases {
if purchase.refundTime.IsZero() {
purchase.refundTime = time.Unix(0, 0)
}
if purchase.rawResponse == "" {
purchase.rawResponse = "{}"
}

statement := fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", offset+1, offset+2, offset+3, offset+4, offset+5, offset+6, offset+7, offset+8)
offset += 8
statements = append(statements, statement)
params = append(params, purchase.userID, purchase.store, purchase.transactionId, purchase.productId, purchase.purchaseTime, purchase.rawResponse, purchase.environment, purchase.refundTime)
transactionIDsToPurchase[purchase.transactionId] = purchase

userIdParams = append(userIdParams, purchase.userID)
storeParams = append(storeParams, purchase.store)
transactionIdParams = append(transactionIdParams, purchase.transactionId)
productIdParams = append(productIdParams, purchase.productId)
purchaseTimeParams = append(purchaseTimeParams, purchase.purchaseTime)
rawResponseParams = append(rawResponseParams, purchase.rawResponse)
environmentParams = append(environmentParams, purchase.environment)
refundTimeParams = append(refundTimeParams, purchase.refundTime)
}

query := `
INSERT
INTO
purchase
(
user_id,
store,
transaction_id,
product_id,
purchase_time,
raw_response,
environment,
refund_time
)
VALUES
` + strings.Join(statements, ", ") + `
INSERT INTO purchase
(
user_id,
store,
transaction_id,
product_id,
purchase_time,
raw_response,
environment,
refund_time
)
SELECT unnest($1::uuid[]), unnest($2::smallint[]), unnest($3::text[]), unnest($4::text[]), unnest($5::timestamptz[]), unnest($6::jsonb[]), unnest($7::smallint[]), unnest($8::timestamptz[])
ON CONFLICT
(transaction_id)
DO UPDATE SET
refund_time = $8,
refund_time = EXCLUDED.refund_time,
update_time = now()
RETURNING
user_id,
Expand All @@ -605,7 +612,8 @@ RETURNING
update_time,
refund_time
`
rows, err := db.QueryContext(ctx, query, params...)

rows, err := db.QueryContext(ctx, query, userIdParams, storeParams, transactionIdParams, productIdParams, purchaseTimeParams, rawResponseParams, environmentParams, refundTimeParams)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 211f6d1

Please sign in to comment.