diff --git a/server/console_channel.go b/server/console_channel.go index a0dfd45733..fea2fc1658 100644 --- a/server/console_channel.go +++ b/server/console_channel.go @@ -5,8 +5,6 @@ import ( "database/sql" "net/url" "sort" - "strconv" - "strings" "time" "github.com/gofrs/uuid/v5" @@ -66,20 +64,16 @@ func (s *ConsoleServer) DeleteChannelMessages(ctx context.Context, in *console.D s.logger.Info("Messages deleted.", zap.Int64("affected", affected), zap.String("timestamp", deleteBefore.String())) } if len(in.Ids) > 0 { - params := make([]interface{}, 0, len(in.Ids)) - statements := make([]string, len(in.Ids)) - for i, id := range in.Ids { - idStr, err := uuid.FromString(id) + for _, id := range in.Ids { + _, err := uuid.FromString(id) if err != nil { return nil, status.Error(codes.InvalidArgument, "Requires a valid message ID.") } - params = append(params, idStr) - statements[i] = "$" + strconv.Itoa(i+1) } - query := "DELETE FROM message WHERE id IN (" + strings.Join(statements, ",") + ")" + query := "DELETE FROM message WHERE id = ANY($1)" var res sql.Result var err error - if res, err = s.db.ExecContext(ctx, query, params...); err != nil { + if res, err = s.db.ExecContext(ctx, query, in.Ids); err != nil { s.logger.Error("Could not delete messages.", zap.Error(err)) return nil, status.Error(codes.Internal, "An error occurred while trying to delete messages.") } diff --git a/server/core_account.go b/server/core_account.go index 9d52894ba9..f83a455e46 100644 --- a/server/core_account.go +++ b/server/core_account.go @@ -139,20 +139,13 @@ WHERE u.id = $1` } func GetAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, statusRegistry *StatusRegistry, userIDs []string) ([]*api.Account, error) { - statements := make([]string, 0, len(userIDs)) - parameters := make([]interface{}, 0, len(userIDs)) - for _, userID := range userIDs { - parameters = append(parameters, userID) - statements = append(statements, "$"+strconv.Itoa(len(parameters))) - } - query := ` SELECT u.id, u.username, u.display_name, u.avatar_url, u.lang_tag, u.location, u.timezone, u.metadata, u.wallet, u.email, u.apple_id, u.facebook_id, u.facebook_instant_game_id, u.google_id, u.gamecenter_id, u.steam_id, u.custom_id, u.edge_count, u.create_time, u.update_time, u.verify_time, u.disable_time, array(select ud.id from user_device ud where u.id = ud.user_id) FROM users u -WHERE u.id IN (` + strings.Join(statements, ",") + `)` - rows, err := db.QueryContext(ctx, query, parameters...) +WHERE u.id = ANY($1)` + rows, err := db.QueryContext(ctx, query, userIDs) if err != nil { logger.Error("Error retrieving user accounts.", zap.Error(err)) return nil, err diff --git a/server/core_authenticate.go b/server/core_authenticate.go index e21447cab0..8bc2285493 100644 --- a/server/core_authenticate.go +++ b/server/core_authenticate.go @@ -854,15 +854,13 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes return nil } - statements := make([]string, 0, len(steamProfiles)) - params := make([]interface{}, 0, len(steamProfiles)) - for i, steamProfile := range steamProfiles { - statements = append(statements, "$"+strconv.Itoa(i+1)) - params = append(params, strconv.FormatUint(steamProfile.SteamID, 10)) + steamIDs := make([]string, 0, len(steamProfiles)) + for _, steamProfile := range steamProfiles { + steamIDs = append(steamIDs, strconv.FormatUint(steamProfile.SteamID, 10)) } - query := "SELECT id FROM users WHERE steam_id IN (" + strings.Join(statements, ", ") + ")" - rows, err := tx.QueryContext(ctx, query, params...) + query := "SELECT id FROM users WHERE steam_id = ANY($1::text[])" + rows, err := tx.QueryContext(ctx, query, steamIDs) if err != nil { if err == sql.ErrNoRows { // None of the friend profiles exist. @@ -872,7 +870,7 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes } var id string - possibleFriendIDs := make([]uuid.UUID, 0, len(statements)) + possibleFriendIDs := make([]uuid.UUID, 0, len(steamIDs)) for rows.Next() { err = rows.Scan(&id) if err != nil { @@ -930,15 +928,13 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, return nil } - statements := make([]string, 0, len(facebookProfiles)) - params := make([]interface{}, 0, len(facebookProfiles)) - for i, facebookProfile := range facebookProfiles { - statements = append(statements, "$"+strconv.Itoa(i+1)) + params := make([]string, 0, len(facebookProfiles)) + for _, facebookProfile := range facebookProfiles { params = append(params, facebookProfile.ID) } - query := "SELECT id FROM users WHERE facebook_id IN (" + strings.Join(statements, ", ") + ")" - rows, err := tx.QueryContext(ctx, query, params...) + query := "SELECT id FROM users WHERE facebook_id = ANY($1::text[])" + rows, err := tx.QueryContext(ctx, query, params) if err != nil { if err == sql.ErrNoRows { // None of the friend profiles exist. @@ -948,7 +944,7 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, } var id string - possibleFriendIDs := make([]uuid.UUID, 0, len(statements)) + possibleFriendIDs := make([]uuid.UUID, 0, len(params)) for rows.Next() { err = rows.Scan(&id) if err != nil { @@ -1005,8 +1001,7 @@ func resetUserFriends(ctx context.Context, tx *sql.Tx, userID uuid.UUID) error { if err != nil { return err } - statements := make([]string, 0, 10) - params := make([]interface{}, 0, 10) + params := make([]string, 0, 10) for rows.Next() { var id string err = rows.Scan(&id) @@ -1015,17 +1010,16 @@ func resetUserFriends(ctx context.Context, tx *sql.Tx, userID uuid.UUID) error { return err } params = append(params, id) - statements = append(statements, "$"+strconv.Itoa(len(params))) } _ = rows.Close() - if len(statements) > 0 { - query = "UPDATE users SET edge_count = edge_count - 1 WHERE id IN (" + strings.Join(statements, ",") + ")" - result, err := tx.ExecContext(ctx, query, params...) + if len(params) > 0 { + query = "UPDATE users SET edge_count = edge_count - 1 WHERE id = ANY($1)" + result, err := tx.ExecContext(ctx, query, params) if err != nil { return err } - if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != int64(len(statements)) { + if rowsAffectedCount, _ := result.RowsAffected(); rowsAffectedCount != int64(len(params)) { return errors.New("error updating edge count after friend reset") } } diff --git a/server/core_group.go b/server/core_group.go index 8fd1136d21..bae04568ee 100644 --- a/server/core_group.go +++ b/server/core_group.go @@ -1624,18 +1624,10 @@ func GetGroups(ctx context.Context, logger *zap.Logger, db *sql.DB, ids []string return make([]*api.Group, 0), nil } - statements := make([]string, 0, len(ids)) - params := make([]interface{}, 0, len(ids)) - for i, id := range ids { - statements = append(statements, "$"+strconv.Itoa(i+1)) - params = append(params, id) - } - query := `SELECT id, creator_id, name, description, avatar_url, state, edge_count, lang_tag, max_count, metadata, create_time, update_time FROM groups -WHERE disable_time = '1970-01-01 00:00:00 UTC' -AND id IN (` + strings.Join(statements, ",") + `)` - rows, err := db.QueryContext(ctx, query, params...) +WHERE disable_time = '1970-01-01 00:00:00 UTC' AND id = ANY($1)` + rows, err := db.QueryContext(ctx, query, ids) if err != nil { if err == sql.ErrNoRows { return make([]*api.Group, 0), nil diff --git a/server/core_leaderboard.go b/server/core_leaderboard.go index 226a3afb4b..1f9b7f6e1b 100644 --- a/server/core_leaderboard.go +++ b/server/core_leaderboard.go @@ -22,7 +22,6 @@ import ( "encoding/gob" "errors" "sort" - "strconv" "strings" "time" @@ -286,15 +285,11 @@ func LeaderboardRecordsList(ctx context.Context, logger *zap.Logger, db *sql.DB, } if len(ownerIds) != 0 { - params := make([]interface{}, 0, len(ownerIds)+2) - params = append(params, leaderboardId, time.Unix(expiryTime, 0).UTC()) - statements := make([]string, len(ownerIds)) - for i, ownerID := range ownerIds { - params = append(params, ownerID) - statements[i] = "$" + strconv.Itoa(i+3) - } + params := []any{leaderboardId, time.Unix(expiryTime, 0).UTC(), ownerIds} + query := `SELECT owner_id, username, score, subscore, num_score, max_num_score, metadata, create_time, update_time +FROM leaderboard_record +WHERE leaderboard_id = $1 AND expiry_time = $2 AND owner_id = ANY($3)` - query := "SELECT owner_id, username, score, subscore, num_score, max_num_score, metadata, create_time, update_time FROM leaderboard_record WHERE leaderboard_id = $1 AND expiry_time = $2 AND owner_id IN (" + strings.Join(statements, ", ") + ")" rows, err := db.QueryContext(ctx, query, params...) if err != nil { logger.Error("Error reading leaderboard records", zap.Error(err)) diff --git a/server/core_notification.go b/server/core_notification.go index 80512e75be..cddafba179 100644 --- a/server/core_notification.go +++ b/server/core_notification.go @@ -21,8 +21,6 @@ import ( "encoding/base64" "encoding/gob" "fmt" - "strconv" - "strings" "time" "github.com/gofrs/uuid/v5" @@ -261,17 +259,9 @@ ORDER BY create_time ASC, id ASC`+limitQuery, params...) } func NotificationDelete(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid.UUID, notificationIDs []string) error { - statements := make([]string, 0, len(notificationIDs)) - params := make([]interface{}, 0, len(notificationIDs)+1) - params = append(params, userID) - - for _, id := range notificationIDs { - statement := "$" + strconv.Itoa(len(params)+1) - statements = append(statements, statement) - params = append(params, id) - } + params := []any{userID, notificationIDs} - query := "DELETE FROM notification WHERE user_id = $1 AND id IN (" + strings.Join(statements, ", ") + ")" + query := "DELETE FROM notification WHERE user_id = $1 AND id = ANY($2)" logger.Debug("Delete notification query", zap.String("query", query), zap.Any("params", params)) _, err := db.ExecContext(ctx, query, params...) if err != nil { @@ -283,33 +273,35 @@ func NotificationDelete(ctx context.Context, logger *zap.Logger, db *sql.DB, use } func NotificationSave(ctx context.Context, logger *zap.Logger, db *sql.DB, notifications map[uuid.UUID][]*api.Notification) error { - statements := make([]string, 0, len(notifications)) - params := make([]interface{}, 0, len(notifications)) - counter := 0 + ids := make([]string, 0, len(notifications)) + userIds := make([]uuid.UUID, 0, len(notifications)) + subjects := make([]string, 0, len(notifications)) + contents := make([]string, 0, len(notifications)) + codes := make([]int32, 0, len(notifications)) + senderIds := make([]string, 0, len(notifications)) + query := ` +INSERT INTO + notification (id, user_id, subject, content, code, sender_id) +SELECT + unnest($1::uuid[]), + unnest($2::uuid[]), + unnest($3::text[]), + unnest($4::jsonb[]), + unnest($5::smallint[]), + unnest($6::uuid[]); +` for userID, no := range notifications { for _, un := range no { - statement := "$" + strconv.Itoa(counter+1) + - ",$" + strconv.Itoa(counter+2) + - ",$" + strconv.Itoa(counter+3) + - ",$" + strconv.Itoa(counter+4) + - ",$" + strconv.Itoa(counter+5) + - ",$" + strconv.Itoa(counter+6) - - counter = counter + 6 - statements = append(statements, "("+statement+")") - - params = append(params, un.Id) - params = append(params, userID) - params = append(params, un.Subject) - params = append(params, un.Content) - params = append(params, un.Code) - params = append(params, un.SenderId) + ids = append(ids, un.Id) + userIds = append(userIds, userID) + subjects = append(subjects, un.Subject) + contents = append(contents, un.Content) + codes = append(codes, un.Code) + senderIds = append(senderIds, un.SenderId) } } - query := "INSERT INTO notification (id, user_id, subject, content, code, sender_id) VALUES " + strings.Join(statements, ", ") - - if _, err := db.ExecContext(ctx, query, params...); err != nil { + if _, err := db.ExecContext(ctx, query, ids, userIds, subjects, contents, codes, senderIds); err != nil { logger.Error("Could not save notifications.", zap.Error(err)) return err } diff --git a/server/core_purchase.go b/server/core_purchase.go index 9b4d4478a0..744ad6c09a 100644 --- a/server/core_purchase.go +++ b/server/core_purchase.go @@ -24,7 +24,6 @@ import ( "fmt" "net/http" "strconv" - "strings" "time" "github.com/gofrs/uuid/v5" @@ -558,10 +557,17 @@ func upsertPurchases(ctx context.Context, db *sql.DB, purchases []*storagePurcha return nil, errors.New("expects at least one receipt") } - statements := make([]string, 0, len(purchases)) - params := make([]interface{}, 0, len(purchases)*8) transactionIDsToPurchase := make(map[string]*storagePurchase) - offset := 0 + + userIdParams := make([]uuid.UUID, 0, len(purchases)) + storeParams := make([]api.StoreProvider, 0, len(purchases)) + transactionIdParams := make([]string, 0, len(purchases)) + productIdParams := make([]string, 0, len(purchases)) + purchaseTimeParams := make([]time.Time, 0, len(purchases)) + rawResponseParams := make([]string, 0, len(purchases)) + environmentParams := make([]api.StoreEnvironment, 0, len(purchases)) + refundTimeParams := make([]time.Time, 0, len(purchases)) + for _, purchase := range purchases { if purchase.refundTime.IsZero() { purchase.refundTime = time.Unix(0, 0) @@ -569,34 +575,35 @@ func upsertPurchases(ctx context.Context, db *sql.DB, purchases []*storagePurcha if purchase.rawResponse == "" { purchase.rawResponse = "{}" } - - statement := fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", offset+1, offset+2, offset+3, offset+4, offset+5, offset+6, offset+7, offset+8) - offset += 8 - statements = append(statements, statement) - params = append(params, purchase.userID, purchase.store, purchase.transactionId, purchase.productId, purchase.purchaseTime, purchase.rawResponse, purchase.environment, purchase.refundTime) transactionIDsToPurchase[purchase.transactionId] = purchase + + userIdParams = append(userIdParams, purchase.userID) + storeParams = append(storeParams, purchase.store) + transactionIdParams = append(transactionIdParams, purchase.transactionId) + productIdParams = append(productIdParams, purchase.productId) + purchaseTimeParams = append(purchaseTimeParams, purchase.purchaseTime) + rawResponseParams = append(rawResponseParams, purchase.rawResponse) + environmentParams = append(environmentParams, purchase.environment) + refundTimeParams = append(refundTimeParams, purchase.refundTime) } query := ` -INSERT -INTO - purchase - ( - user_id, - store, - transaction_id, - product_id, - purchase_time, - raw_response, - environment, - refund_time - ) -VALUES - ` + strings.Join(statements, ", ") + ` +INSERT INTO purchase + ( + user_id, + store, + transaction_id, + product_id, + purchase_time, + raw_response, + environment, + refund_time + ) +SELECT unnest($1::uuid[]), unnest($2::smallint[]), unnest($3::text[]), unnest($4::text[]), unnest($5::timestamptz[]), unnest($6::jsonb[]), unnest($7::smallint[]), unnest($8::timestamptz[]) ON CONFLICT (transaction_id) DO UPDATE SET - refund_time = $8, + refund_time = EXCLUDED.refund_time, update_time = now() RETURNING user_id, @@ -605,7 +612,8 @@ RETURNING update_time, refund_time ` - rows, err := db.QueryContext(ctx, query, params...) + + rows, err := db.QueryContext(ctx, query, userIdParams, storeParams, transactionIdParams, productIdParams, purchaseTimeParams, rawResponseParams, environmentParams, refundTimeParams) if err != nil { return nil, err } diff --git a/server/core_tournament.go b/server/core_tournament.go index b23dc9a671..da01b2db78 100644 --- a/server/core_tournament.go +++ b/server/core_tournament.go @@ -21,8 +21,6 @@ import ( "encoding/base64" "encoding/gob" "errors" - "fmt" - "strconv" "strings" "time" @@ -252,19 +250,13 @@ func TournamentsGet(ctx context.Context, logger *zap.Logger, db *sql.DB, leaderb } if len(dbLookupTournamentIDs) > 0 { - params := make([]interface{}, 0, len(dbLookupTournamentIDs)) - statements := make([]string, 0, len(dbLookupTournamentIDs)) - for i, tournamentID := range dbLookupTournamentIDs { - params = append(params, tournamentID) - statements = append(statements, fmt.Sprintf("$%v", i+1)) - } query := `SELECT id, sort_order, operator, reset_schedule, metadata, create_time, category, description, duration, end_time, max_size, max_num_score, title, size, start_time FROM leaderboard -WHERE id IN (` + strings.Join(statements, ",") + `)` +WHERE id = ANY($1::text[])` // Retrieved directly from database to have the latest configuration and 'size' etc field values. // Ensures consistency between return data from this call and TournamentList. - rows, err := db.QueryContext(ctx, query, params...) + rows, err := db.QueryContext(ctx, query, dbLookupTournamentIDs) if err != nil { logger.Error("Could not retrieve tournaments", zap.Error(err)) return nil, err @@ -308,22 +300,18 @@ func TournamentList(ctx context.Context, logger *zap.Logger, db *sql.DB, leaderb } // Read most up to date sizes from database. - statements := make([]string, 0, len(list)) - params := make([]interface{}, 0, len(list)) - var count int + ids := make([]string, 0, len(list)) for _, leaderboard := range list { if !leaderboard.HasMaxSize() { continue } - params = append(params, leaderboard.Id) - statements = append(statements, "$"+strconv.Itoa(count+1)) - count++ + ids = append(ids, leaderboard.Id) } sizes := make(map[string]int, len(list)) - if len(statements) > 0 { - query := "SELECT id, size FROM leaderboard WHERE id IN (" + strings.Join(statements, ",") + ")" - rows, err := db.QueryContext(ctx, query, params...) + if len(ids) > 0 { + query := "SELECT id, size FROM leaderboard WHERE id = ANY($1::text[])" + rows, err := db.QueryContext(ctx, query, ids) if err != nil { logger.Error("Could not retrieve tournaments", zap.Error(err)) return nil, err diff --git a/server/core_user.go b/server/core_user.go index af98828f32..a1a5d2f046 100644 --- a/server/core_user.go +++ b/server/core_user.go @@ -17,8 +17,7 @@ package server import ( "context" "database/sql" - "strconv" - "strings" + "fmt" "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/api" @@ -34,49 +33,33 @@ SELECT id, username, display_name, avatar_url, lang_tag, location, timezone, met FROM users WHERE` - idStatements := make([]string, 0, len(ids)) - usernameStatements := make([]string, 0, len(usernames)) - facebookStatements := make([]string, 0, len(fbIDs)) - params := make([]interface{}, 0) + params := make([]any, 0) counter := 1 useSQLOr := false if len(ids) > 0 { - for _, id := range ids { - params = append(params, id) - statement := "$" + strconv.Itoa(counter) - idStatements = append(idStatements, statement) - counter++ - } - query = query + " id IN (" + strings.Join(idStatements, ", ") + ")" + params = append(params, ids) + query = query + fmt.Sprintf(" id = ANY($%d)", counter) + counter++ useSQLOr = true } if len(usernames) > 0 { - for _, username := range usernames { - params = append(params, username) - statement := "$" + strconv.Itoa(counter) - usernameStatements = append(usernameStatements, statement) - counter++ - } + params = append(params, usernames) if useSQLOr { query = query + " OR" } - query = query + " username IN (" + strings.Join(usernameStatements, ", ") + ")" + query = query + fmt.Sprintf(" username = ANY($%d::text[])", counter) + counter++ useSQLOr = true } if len(fbIDs) > 0 { - for _, id := range fbIDs { - params = append(params, id) - statement := "$" + strconv.Itoa(counter) - facebookStatements = append(facebookStatements, statement) - counter++ - } + params = append(params, fbIDs) if useSQLOr { query = query + " OR" } - query = query + " facebook_id IN (" + strings.Join(facebookStatements, ", ") + ")" + query = query + fmt.Sprintf(" facebook_id = ANY($%d::text[])", counter) } rows, err := db.QueryContext(ctx, query, params...) @@ -180,17 +163,10 @@ func DeleteUser(ctx context.Context, tx *sql.Tx, userID uuid.UUID) (int64, error } func BanUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, sessionCache SessionCache, sessionRegistry SessionRegistry, tracker Tracker, ids []uuid.UUID) error { - statements := make([]string, 0, len(ids)) - params := make([]interface{}, 0, len(ids)) - for i, id := range ids { - statements = append(statements, "$"+strconv.Itoa(i+1)) - params = append(params, id.String()) - } - - query := "UPDATE users SET disable_time = now() WHERE id IN (" + strings.Join(statements, ", ") + ")" - _, err := db.ExecContext(ctx, query, params...) + query := "UPDATE users SET disable_time = now() WHERE id = ANY($1::UUID[])" + _, err := db.ExecContext(ctx, query, ids) if err != nil { - logger.Error("Error banning user accounts.", zap.Error(err), zap.Any("ids", params)) + logger.Error("Error banning user accounts.", zap.Error(err), zap.Any("ids", ids)) return err } @@ -209,17 +185,10 @@ func BanUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config } func UnbanUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, sessionCache SessionCache, ids []uuid.UUID) error { - statements := make([]string, 0, len(ids)) - params := make([]interface{}, 0, len(ids)) - for i, id := range ids { - statements = append(statements, "$"+strconv.Itoa(i+1)) - params = append(params, id.String()) - } - - query := "UPDATE users SET disable_time = '1970-01-01 00:00:00 UTC' WHERE id IN (" + strings.Join(statements, ", ") + ")" - _, err := db.ExecContext(ctx, query, params...) + query := "UPDATE users SET disable_time = '1970-01-01 00:00:00 UTC' WHERE id = ANY($1::UUID[])" + _, err := db.ExecContext(ctx, query, ids) if err != nil { - logger.Error("Error unbanning user accounts.", zap.Error(err), zap.Any("ids", params)) + logger.Error("Error unbanning user accounts.", zap.Error(err), zap.Any("ids", ids)) return err } @@ -295,18 +264,8 @@ func fetchUserID(ctx context.Context, db *sql.DB, usernames []string) ([]string, return ids, nil } - statements := make([]string, 0, len(usernames)) - params := make([]interface{}, 0, len(usernames)) - counter := 1 - for _, username := range usernames { - params = append(params, username) - statement := "$" + strconv.Itoa(counter) - statements = append(statements, statement) - counter++ - } - - query := "SELECT id FROM users WHERE username IN (" + strings.Join(statements, ", ") + ")" - rows, err := db.QueryContext(ctx, query, params...) + query := "SELECT id FROM users WHERE username = ANY($1::text[])" + rows, err := db.QueryContext(ctx, query, usernames) if err != nil { if err == sql.ErrNoRows { return ids, nil diff --git a/server/core_wallet.go b/server/core_wallet.go index ed3550f963..9ffd8fda60 100644 --- a/server/core_wallet.go +++ b/server/core_wallet.go @@ -23,8 +23,6 @@ import ( "encoding/json" "fmt" "sort" - "strconv" - "strings" "time" "github.com/gofrs/uuid/v5" @@ -116,18 +114,16 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates [ return nil, nil } - initialParams := make([]interface{}, 0, len(updates)) - initialStatements := make([]string, 0, len(updates)) + ids := make([]uuid.UUID, 0, len(updates)) for _, update := range updates { - initialParams = append(initialParams, update.UserID) - initialStatements = append(initialStatements, "$"+strconv.Itoa(len(initialParams))+"::UUID") + ids = append(ids, update.UserID) } - initialQuery := "SELECT id, wallet FROM users WHERE id IN (" + strings.Join(initialStatements, ",") + ")" + initialQuery := "SELECT id, wallet FROM users WHERE id = ANY($1::UUID[])" // Select the wallets from the DB and decode them. wallets := make(map[string]map[string]int64, len(updates)) - rows, err := tx.Query(ctx, initialQuery, initialParams...) + rows, err := tx.Query(ctx, initialQuery, ids) if err != nil { logger.Debug("Error retrieving user wallets.", zap.Error(err)) return nil, err @@ -159,11 +155,16 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates [ // Prepare the set of wallet updates and ledger updates. updatedWallets := make(map[string][]byte, len(updates)) updateOrder := make([]string, 0, len(updates)) - var statements []string - var params []interface{} + + var idParams []uuid.UUID + var userIdParams []string + var changesetParams [][]byte + var metadataParams []string if updateLedger { - statements = make([]string, 0, len(updates)) - params = make([]interface{}, 0, len(updates)*4) + idParams = make([]uuid.UUID, 0, len(updates)) + userIdParams = make([]string, 0, len(updates)) + changesetParams = make([][]byte, 0, len(updates)) + metadataParams = make([]string, 0, len(updates)) } // Go through the changesets and attempt to calculate the new state for each wallet. @@ -216,8 +217,10 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates [ return nil, err } - params = append(params, uuid.Must(uuid.NewV4()), userID, changesetData, update.Metadata) - statements = append(statements, fmt.Sprintf("($%v::UUID, $%v, $%v, $%v)", strconv.Itoa(len(params)-3), strconv.Itoa(len(params)-2), strconv.Itoa(len(params)-1), strconv.Itoa(len(params)))) + idParams = append(idParams, uuid.Must(uuid.NewV4())) + userIdParams = append(userIdParams, userID) + changesetParams = append(changesetParams, changesetData) + metadataParams = append(metadataParams, update.Metadata) } } @@ -241,8 +244,13 @@ func updateWallets(ctx context.Context, logger *zap.Logger, tx pgx.Tx, updates [ } // Write the ledger updates, if any. - if updateLedger && (len(statements) > 0) { - _, err = tx.Exec(ctx, "INSERT INTO wallet_ledger (id, user_id, changeset, metadata) VALUES "+strings.Join(statements, ", "), params...) + if updateLedger && (len(idParams) > 0) { + _, err = tx.Exec(ctx, ` +INSERT INTO + wallet_ledger (id, user_id, changeset, metadata) +SELECT + unnest($1::uuid[]), unnest($2::uuid[]), unnest($3::jsonb[]), unnest($4::jsonb[]); +`, idParams, userIdParams, changesetParams, metadataParams) if err != nil { logger.Debug("Error writing user wallet ledgers.", zap.Error(err)) return nil, err diff --git a/server/pipeline_status.go b/server/pipeline_status.go index 7ece03a814..74110c74a9 100644 --- a/server/pipeline_status.go +++ b/server/pipeline_status.go @@ -15,8 +15,7 @@ package server import ( - "strconv" - "strings" + "fmt" "github.com/gofrs/uuid/v5" "github.com/heroiclabs/nakama-common/rtapi" @@ -86,16 +85,14 @@ func (p *Pipeline) statusFollow(logger *zap.Logger, session Session, envelope *r followUserIDs := make(map[uuid.UUID]struct{}, len(uniqueUserIDs)+len(uniqueUsernames)) foundUsernames := make(map[string]struct{}, len(uniqueUsernames)) if len(uniqueUsernames) == 0 { - params := make([]interface{}, 0, len(uniqueUserIDs)) - statements := make([]string, 0, len(uniqueUserIDs)) + ids := make([]uuid.UUID, 0, len(uniqueUserIDs)) for userID := range uniqueUserIDs { - params = append(params, userID) - statements = append(statements, "$"+strconv.Itoa(len(params))+"::UUID") + ids = append(ids, userID) } // See if all the users exist. - query := "SELECT id FROM users WHERE id IN (" + strings.Join(statements, ", ") + ")" - rows, err := p.db.QueryContext(session.Context(), query, params...) + query := "SELECT id FROM users WHERE id = ANY($1::UUID[])" + rows, err := p.db.QueryContext(session.Context(), query, ids) if err != nil { logger.Error("Error checking users in status follow", zap.Error(err)) _ = session.Send(&rtapi.Envelope{Cid: envelope.Cid, Message: &rtapi.Envelope_Error{Error: &rtapi.Error{ @@ -122,25 +119,25 @@ func (p *Pipeline) statusFollow(logger *zap.Logger, session Session, envelope *r } else { query := "SELECT id, username FROM users WHERE " - params := make([]interface{}, 0, len(uniqueUserIDs)) - statements := make([]string, 0, len(uniqueUserIDs)) + params := make([]any, 0, 2) + ids := make([]uuid.UUID, 0, len(uniqueUserIDs)) for userID := range uniqueUserIDs { - params = append(params, userID) - statements = append(statements, "$"+strconv.Itoa(len(params))+"::UUID") + ids = append(ids, userID) } - if len(statements) != 0 { - query += "id IN (" + strings.Join(statements, ", ") + ")" - statements = make([]string, 0, len(uniqueUsernames)) + if len(ids) != 0 { + params = append(params, ids) + query += fmt.Sprintf("id = ANY($%d::UUID[])", len(params)) } + usernames := make([]string, 0, len(uniqueUsernames)) for username := range uniqueUsernames { - params = append(params, username) - statements = append(statements, "$"+strconv.Itoa(len(params))) + usernames = append(usernames, username) } if len(uniqueUserIDs) != 0 { query += " OR " } - query += "username IN (" + strings.Join(statements, ", ") + ")" + params = append(params, usernames) + query += fmt.Sprintf("username = ANY($%d::text[])", len(params)) // See if all the users exist. rows, err := p.db.QueryContext(session.Context(), query, params...)