From a0ddcc5cda68f8f33776e1fb072d69b9a05770ac Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Sun, 11 Jun 2023 14:11:12 +1000 Subject: [PATCH 1/8] RestClient: support sub-partitions --- README.md | 4 +- RELEASE-NOTES.md | 6 + SQL.md | 4 +- driver.go | 2 +- restclient.go | 12 +- restclient_collection_test.go | 48 ++ ...client_document_crud_subpartitions_test.go | 443 ++++++++++++++++++ stmt_collection.go | 40 +- stmt_collection_parsing_test.go | 10 +- stmt_collection_test.go | 48 +- 10 files changed, 584 insertions(+), 33 deletions(-) create mode 100644 restclient_document_crud_subpartitions_test.go diff --git a/README.md b/README.md index 0fabc81..9873ffc 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ AccountEndpoint= - `AccountEndpoint`: (required) endpoint to access Cosmos DB. For example, the endpoint for Azure Cosmos DB Emulator running on local is `https://localhost:8081/`. - `AccountKey`: (required) account key to authenticate. - `TimeoutMs`: (optional) operation timeout in milliseconds. Default value is `10 seconds` if not specified. -- `Version`: (optional) version of Cosmos DB to use. Default value is `2018-12-31` if not specified. See: https://learn.microsoft.com/rest/api/cosmos-db/#supported-rest-api-versions. +- `Version`: (optional) version of Cosmos DB to use. Default value is `2020-07-15` if not specified. See: https://learn.microsoft.com/rest/api/cosmos-db/#supported-rest-api-versions. - `DefaultDb`: (optional, available since [v0.1.1](RELEASE-NOTES.md)) specify the default database used in Cosmos DB operations. Alias `Db` can also be used instead of `DefaultDb`. - `AutoId`: (optional, available since [v0.1.2](RELEASE-NOTES.md)) see [auto id](#auto-id) session. - `InsecureSkipVerify`: (optional, available since [v0.1.4](RELEASE-NOTES.md)) if `true`, disable CA verification for https endpoint (useful to run against test/dev env with local/docker Cosmos DB emulator). @@ -156,7 +156,7 @@ AccountEndpoint= - `AccountEndpoint`: (required) endpoint to access Cosmos DB. For example, the endpoint for Azure Cosmos DB Emulator running on local is `https://localhost:8081/`. - `AccountKey`: (required) account key to authenticate. - `TimeoutMs`: (optional) operation timeout in milliseconds. Default value is `10 seconds` if not specified. -- `Version`: (optional) version of Cosmos DB to use. Default value is `2018-12-31` if not specified. See: https://learn.microsoft.com/rest/api/cosmos-db/#supported-rest-api-versions. +- `Version`: (optional) version of Cosmos DB to use. Default value is `2020-07-15` if not specified. See: https://learn.microsoft.com/rest/api/cosmos-db/#supported-rest-api-versions. - `AutoId`: (optional, available since [v0.1.2](RELEASE-NOTES.md)) see [auto id](#auto-id) session. - `InsecureSkipVerify`: (optional, available since [v0.1.4](RELEASE-NOTES.md)) if `true`, disable CA verification for https endpoint (useful to run against test/dev env with local/docker Cosmos DB emulator). diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index e9b4923..7e504e3 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,5 +1,11 @@ # gocosmos - Release notes +## 2023-06-0x - v0.3.0 + +- Change default API version to `2020-07-15`. +- Add Hierarchical Partition Keys (sub-partitions) support. +- PartitionKey version 1 is no longer used (hence large PK is always enabled). + ## 2023-06-09 - v0.2.1 - Bug fixes, Refactoring & Enhancements. 
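For quick orientation, the sub-partition support introduced by this patch can be exercised from the `database/sql` side roughly as follows. This is a minimal sketch only: the import path, the driver name `gocosmos`, and the endpoint/key/database names are illustrative assumptions, while the `CREATE COLLECTION ... WITH pk=/TenantId,/UserId` syntax is the one documented and tested in this patch.

```go
package main

import (
	"database/sql"
	"fmt"

	_ "github.com/btnguyen2k/gocosmos" // assumed import path; blank import registers the driver
)

func main() {
	// DSN format as described in README.md above (values are placeholders).
	dsn := "AccountEndpoint=https://localhost:8081/;AccountKey=<account-key>;DefaultDb=mydb"
	db, err := sql.Open("gocosmos", dsn) // driver name assumed to be "gocosmos"
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// Hierarchical partition key (sub-partitions): up to 3 paths, separated by commas.
	result, err := db.Exec("CREATE COLLECTION IF NOT EXISTS mydb.mytable WITH pk=/TenantId,/UserId WITH ru=400")
	if err != nil {
		panic(err)
	}
	numRows, _ := result.RowsAffected()
	fmt.Println("rows affected:", numRows) // expected: 1 on successful creation
}
```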
diff --git a/SQL.md b/SQL.md index 27f14cf..ed62cae 100644 --- a/SQL.md +++ b/SQL.md @@ -168,7 +168,9 @@ fmt.Println(dbresult.RowsAffected()) - Upon successful execution, `RowsAffected()` returns `(1, nil)`. - This statement returns error `ErrConflict` if the specified collection already existed. If `IF NOT EXISTS` is specified, `RowsAffected()` returns `(0, nil)`. -- Partition key must be specified using `WITH pk=`. If partition key is larger than 100 bytes, use `WITH pk=` instead. +- Partition key must be specified using `WITH pk=`. + - Since [v0.3.0](RELEASE-NOTES.md), large pk is always enabled, `WITH largepk` is for backward compatibility only. + - Since [v0.3.0](RELEASE-NOTES.md), Hierarchical Partition Key is supported, using `WITH pk=/path1,/path2...` (up to 3 path levels). - Provisioned capacity can be optionally specified via `WITH RU=` or `WITH MAXRU=`. - Only one of `RU` and `MAXRU` options should be specified, _not both_; error is returned if both optiosn are specified. - Unique keys are optionally specified via `WITH uk=/uk1_path:/uk2_path1,/uk2_path2:/uk3_path`. Each unique key is a comma-separated list of paths (e.g. `/uk_path1,/uk_path2`); unique keys are separated by colons (e.g. `/uk1:/uk2:/uk3`). diff --git a/driver.go b/driver.go index 5446cd1..e267f4e 100644 --- a/driver.go +++ b/driver.go @@ -98,7 +98,7 @@ type Driver struct { // // AccountEndpoint=;AccountKey=[;TimeoutMs=][;Version=][;DefaultDb=][;AutoId=][;InsecureSkipVerify=] // -// If not supplied, default value for TimeoutMs is 10 seconds, Version is defaultApiVersion (which is "2018-12-31"), AutoId is true, and InsecureSkipVerify is false +// If not supplied, default value for TimeoutMs is 10 seconds, Version is DefaultApiVersion (which is "2020-07-15"), AutoId is true, and InsecureSkipVerify is false // // - DefaultDb is added since v0.1.1 // - AutoId is added since v0.1.2 diff --git a/restclient.go b/restclient.go index 109fdc3..f0218d1 100644 --- a/restclient.go +++ b/restclient.go @@ -31,7 +31,13 @@ const ( settingVersion = "VERSION" settingAutoId = "AUTOID" settingInsecureSkipVerify = "INSECURESKIPVERIFY" - defaultApiVersion = "2018-12-31" + + // DefaultApiVersion holds the default REST API version if not specified in the connection string. + // + // See: https://learn.microsoft.com/en-us/rest/api/cosmos-db/#supported-rest-api-versions + // + // @Available since v0.3.0 + DefaultApiVersion = "2020-07-15" ) // NewRestClient constructs a new RestClient instance from the supplied connection string. 
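At the REST-client layer, sub-partitions are driven by `CollectionSpec.PartitionKeyInfo` (kind `MultiHash`, partition-key version 2) and by passing one value per partition-key path in `DocumentSpec.PartitionKeyValues`, as the new tests below demonstrate. A condensed, hedged sketch (placeholder endpoint/key, a `nil` HTTP client assumed to be acceptable, error handling on the create calls elided):

```go
client, err := NewRestClient(nil, "AccountEndpoint=https://localhost:8081/;AccountKey=<account-key>")
if err != nil {
	panic(err)
}

client.CreateDatabase(DatabaseSpec{Id: "mydb"})

// Hierarchical partition key: up to 3 paths, kind "MultiHash", PK version 2.
client.CreateCollection(CollectionSpec{
	DbName:   "mydb",
	CollName: "mytable",
	PartitionKeyInfo: map[string]interface{}{
		"paths":   []string{"/TenantId", "/UserId"},
		"kind":    "MultiHash",
		"version": 2,
	},
})

// Supply one partition-key value per path, in the same order as the paths.
client.CreateDocument(DocumentSpec{
	DbName: "mydb", CollName: "mytable",
	PartitionKeyValues: []interface{}{"tenant1", "user1"},
	DocumentData:       map[string]interface{}{"id": "1", "TenantId": "tenant1", "UserId": "user1"},
})
```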
@@ -41,7 +47,7 @@ const ( // // AccountEndpoint=;AccountKey=[;TimeoutMs=][;Version=][;AutoId=][;InsecureSkipVerify=] // -// If not supplied, default value for TimeoutMs is 10 seconds, Version is defaultApiVersion (which is "2018-12-31"), AutoId is true, and InsecureSkipVerify is false +// If not supplied, default value for TimeoutMs is 10 seconds, Version is DefaultApiVersion (which is "2020-07-15"), AutoId is true, and InsecureSkipVerify is false // // - AutoId is added since v0.1.2 // - InsecureSkipVerify is added since v0.1.4 @@ -75,7 +81,7 @@ func NewRestClient(httpClient *http.Client, connStr string) (*RestClient, error) } apiVersion := params[settingVersion] if apiVersion == "" { - apiVersion = defaultApiVersion + apiVersion = DefaultApiVersion } autoId, err := strconv.ParseBool(params[settingAutoId]) if err != nil { diff --git a/restclient_collection_test.go b/restclient_collection_test.go index e112342..3e2068b 100644 --- a/restclient_collection_test.go +++ b/restclient_collection_test.go @@ -65,6 +65,54 @@ func TestRestClient_CreateCollection(t *testing.T) { } } +func TestRestClient_CreateCollection_SubPartitions(t *testing.T) { + name := "TestRestClient_CreateCollection_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + collspecList := []CollectionSpec{ + // Hierarchical Partition Keys + {DbName: dbname, CollName: collname, PartitionKeyInfo: map[string]interface{}{"paths": []string{"/TenantId", "/UserId"}, "kind": "MultiHash", "version": 2}}, + {DbName: dbname, CollName: collname, MaxRu: 4000, PartitionKeyInfo: map[string]interface{}{"paths": []string{"/TenantId", "/UserId", "/SessionId"}, "kind": "MultiHash", "version": 2}}, + } + for _, collspec := range collspecList { + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + var collInfo CollInfo + if result := client.CreateCollection(collspec); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.Id != collname { + t.Fatalf("%s failed: expected %#v but received %#v", name+"/CreateDatabase", collname, result.Id) + } else if result.Rid == "" || result.Self == "" || result.Etag == "" || result.Docs == "" || + result.Sprocs == "" || result.Triggers == "" || result.Udfs == "" || result.Conflicts == "" || + result.Ts <= 0 || len(result.IndexingPolicy) == 0 || len(result.PartitionKey) == 0 { + t.Fatalf("%s failed: invalid collinfo returned %#v", name, result.CollInfo) + } else { + collInfo = result.CollInfo + } + + if collspec.Ru > 0 || collspec.MaxRu > 0 { + if result := client.GetOfferForResource(collInfo.Rid); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else { + if ru, maxru := result.OfferThroughput(), result.MaxThroughputEverProvisioned(); collspec.Ru > 0 && (collspec.Ru != ru || collspec.Ru != maxru) { + t.Fatalf("%s failed: expected %#v but expected {ru:%#v, maxru:%#v}", name, collspec.Ru, ru, maxru) + } + if ru, maxru := result.OfferThroughput(), result.MaxThroughputEverProvisioned(); collspec.MaxRu > 0 && (collspec.MaxRu != ru*10 || collspec.MaxRu != maxru) { + t.Fatalf("%s failed: expected %#v but expected {ru:%#v, maxru:%#v}", name, collspec.MaxRu, ru, maxru) + } + } + } + + if result := client.CreateCollection(collspec); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 409 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 409, result.StatusCode) + } + } +} + func TestRestClient_ChangeOfferCollection(t 
*testing.T) { name := "TestRestClient_ChangeOfferCollection" client := _newRestClient(t, name) diff --git a/restclient_document_crud_subpartitions_test.go b/restclient_document_crud_subpartitions_test.go new file mode 100644 index 0000000..5417523 --- /dev/null +++ b/restclient_document_crud_subpartitions_test.go @@ -0,0 +1,443 @@ +package gocosmos + +import ( + "strings" + "testing" + "time" +) + +/*----------------------------------------------------------------------*/ + +func TestRestClient_CreateDocument_SubPartitions(t *testing.T) { + name := "TestRestClient_CreateDocument_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app1", "user1"}, + DocumentData: map[string]interface{}{"id": "1", "app": "app1", "username": "user1", "email": "user1@domain.com", "grade": 1, "active": true}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != "1" || result.DocInfo["app"] != "app1" || result.DocInfo["username"] != "user1" || + result.DocInfo["email"] != "user1@domain.com" || result.DocInfo["grade"].(float64) != 1.0 || result.DocInfo["active"] != true || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app1", "user11"}, IndexingDirective: "Include", + DocumentData: map[string]interface{}{"id": "11", "app": "app1", "username": "user11", "email": "user11@domain.com", "grade": 1.1, "active": false}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != "11" || result.DocInfo["app"] != "app1" || result.DocInfo["username"] != "user11" || + result.DocInfo["email"] != "user11@domain.com" || result.DocInfo["grade"].(float64) != 1.1 || result.DocInfo["active"] != false || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app2", "user111"}, IndexingDirective: "Exclude", + DocumentData: map[string]interface{}{"id": "111", "app": "app2", "username": "user111", "email": "user111@domain.com", "grade": 1.11, "active": false}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != "111" || result.DocInfo["app"] != "app2" || result.DocInfo["username"] != "user111" || + result.DocInfo["email"] != "user111@domain.com" || result.DocInfo["grade"].(float64) != 1.11 || result.DocInfo["active"] != false || + 
result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + // duplicated id + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app1", "user1"}, + DocumentData: map[string]interface{}{"id": "1", "app": "app1", "username": "user1", "email": "user@domain1.com", "grade": 2, "active": false}, + }); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 409 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 409, result.StatusCode) + } + + // duplicated unique index + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app1", "user1"}, + DocumentData: map[string]interface{}{"id": "2", "app": "app1", "username": "user1", "email": "user1@domain.com", "grade": 3, "active": true}, + }); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 409 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 409, result.StatusCode) + } +} + +func TestRestClient_CreateDocumentNoId_SubPartitions(t *testing.T) { + name := "TestRestClient_CreateDocumentNoId_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + }) + + client.autoId = true + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, + DocumentData: map[string]interface{}{"app": "app", "username": "user1", "email": "user1@domain.com", "grade": 1, "active": true}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] == "" || result.DocInfo["app"] != "app" || result.DocInfo["username"] != "user1" || result.DocInfo["email"] != "user1@domain.com" || + result.DocInfo["grade"].(float64) != 1.0 || result.DocInfo["active"] != true || result.DocInfo["_rid"] == "" || + result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + client.autoId = false + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user2"}, + DocumentData: map[string]interface{}{"app": "app", "username": "user2", "email": "user2@domain.com", "grade": 2, "active": false}, + }); result.Error() == nil { + t.Fatalf("%s failed: this operation should not be successful", name) + } +} + +func TestRestClient_UpsertDocument_SubPartitions(t *testing.T) { + name := "TestRestClient_UpsertDocument_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": 
"MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, IsUpsert: true, + DocumentData: map[string]interface{}{"id": "1", "app": "app", "username": "user1", "email": "user1@domain.com", "grade": 1, "active": true}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != "1" || result.DocInfo["app"] != "app" || result.DocInfo["username"] != "user1" || + result.DocInfo["email"] != "user1@domain.com" || result.DocInfo["grade"].(float64) != 1.0 || result.DocInfo["active"] != true || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user2"}, IsUpsert: true, + DocumentData: map[string]interface{}{"id": "2", "app": "app", "username": "user2", "email": "user2@domain.com", "grade": 2, "active": false}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != "2" || result.DocInfo["app"] != "app" || result.DocInfo["username"] != "user2" || + result.DocInfo["email"] != "user2@domain.com" || result.DocInfo["grade"].(float64) != 2.0 || result.DocInfo["active"] != false || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, IsUpsert: true, + DocumentData: map[string]interface{}{"id": "1", "app": "app", "username": "user1", "email": "user1@domain1.com", "grade": 2, "active": false, "data": "value"}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != "1" || result.DocInfo["app"] != "app" || result.DocInfo["username"] != "user1" || + result.DocInfo["email"] != "user1@domain1.com" || result.DocInfo["grade"].(float64) != 2.0 || result.DocInfo["active"] != false || result.DocInfo["data"] != "value" || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + // duplicated unique key + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, IsUpsert: true, + DocumentData: map[string]interface{}{"id": "3", "app": "app", "username": "user1", "email": "user1@domain1.com", "grade": 2, "active": false, "data": "value"}, + }); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 409 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 409, result.StatusCode) + } +} + +func TestRestClient_UpsertDocumentNoId_SubPartitions(t *testing.T) { + name := 
"TestRestClient_UpsertDocumentNoId_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + client.autoId = true + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, IsUpsert: true, + DocumentData: map[string]interface{}{"app": "app", "username": "user1", "email": "user1@domain.com", "grade": 1, "active": true}, + }); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] == "" || result.DocInfo["app"] != "app" || result.DocInfo["username"] != "user1" || + result.DocInfo["email"] != "user1@domain.com" || result.DocInfo["grade"].(float64) != 1.0 || result.DocInfo["active"] != true || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + client.autoId = false + if result := client.CreateDocument(DocumentSpec{ + DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user2"}, IsUpsert: true, + DocumentData: map[string]interface{}{"app": "app", "username": "user2", "email": "user2@domain.com", "grade": 2, "active": false}, + }); result.Error() == nil { + t.Fatalf("%s failed: this operation should not be successful", name) + } +} + +func TestRestClient_ReplaceDocument_SubPartitions(t *testing.T) { + name := "TestRestClient_ReplaceDocument_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + // insert 2 documents + docInfo := map[string]interface{}{"id": "2", "app": "app", "username": "user", "email": "user2@domain.com", "grade": 2.0, "active": false} + if result := client.CreateDocument(DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + docInfo = map[string]interface{}{"id": "1", "app": "app", "username": "user", "email": "user1@domain.com", "grade": 1.0, "active": 
true} + if result := client.CreateDocument(DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + // conflict unique key + docInfo["email"] = "user2@domain.com" + if result := client.ReplaceDocument("", DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 409 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 404, result.StatusCode) + } + + // replace document without etag matching + var etag string + docInfo = map[string]interface{}{"id": "1", "app": "app", "username": "user", "email": "user1@domain.com", "grade": 1.0, "active": true} + docInfo["grade"] = 2.0 + docInfo["active"] = false + if result := client.ReplaceDocument("", DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } else { + etag = result.DocInfo["_etag"].(string) + } + + // replace document with etag matching: should not match + docInfo["email"] = "user3@domain.com" + docInfo["grade"] = 3.0 + docInfo["active"] = true + if result := client.ReplaceDocument(etag+"dummy", DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 412 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 412, result.StatusCode) + } + // replace document with etag matching: should match + if result := client.ReplaceDocument(etag, DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) 
== 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + // document not found + docInfo["id"] = "0" + if result := client.ReplaceDocument("", DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 404 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 404, result.StatusCode) + } +} + +func TestRestClient_ReplaceDocumentCrossPartitions_SubPartitions(t *testing.T) { + name := "TestRestClient_ReplaceDocumentCrossPartitions_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + // insert a document + docInfo := map[string]interface{}{"id": "1", "app": "app", "username": "user1", "email": "user1@domain.com", "grade": 1.0, "active": true} + if result := client.CreateDocument(DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + docInfo["username"] = "user2" + if result := client.ReplaceDocument("", DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user2"}, DocumentData: docInfo}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 404 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 404, result.StatusCode) + } + + docInfo["username"] = "user2" + if result := client.ReplaceDocument("", DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user1"}, DocumentData: docInfo}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 400 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 400, result.StatusCode) + } +} + +func TestRestClient_GetDocument_SubPartitions(t *testing.T) { + name := "TestRestClient_GetDocument_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": 
[]map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + var etag, sessionToken string + docInfo := map[string]interface{}{"id": "1", "app": "app", "username": "user", "email": "user1@domain.com", "grade": 1.0, "active": true} + if result := client.CreateDocument(DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else { + etag = result.DocInfo["_etag"].(string) + sessionToken = result.SessionToken + } + + if result := client.GetDocument(DocReq{DbName: dbname, CollName: collname, DocId: "1", PartitionKeyValues: []interface{}{"app", "user"}}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo.Id() != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo.Rid() == "" || result.DocInfo.Self() == "" || result.DocInfo.Ts() == 0 || result.DocInfo.Etag() == "" || result.DocInfo.Attachments() == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } else { + ago := time.Now().Add(-5 * time.Minute) + docTime := result.DocInfo.TsAsTime() + if !ago.Before(docTime) { + t.Fatalf("%s failed: invalid document time %s", name, docTime) + } + + clone := result.DocInfo.RemoveSystemAttrs() + for k := range clone { + if strings.HasPrefix(k, "_") { + t.Fatalf("%s failed: invalid cloned document %#v", name, clone) + } + } + } + + if result := client.GetDocument(DocReq{NotMatchEtag: etag + "dummy", DbName: dbname, CollName: collname, DocId: "1", + ConsistencyLevel: "Session", SessionToken: sessionToken, + PartitionKeyValues: []interface{}{"app", "user"}}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else if result.DocInfo["id"] != docInfo["id"] || result.DocInfo["app"] != docInfo["app"] || result.DocInfo["username"] != docInfo["username"] || + result.DocInfo["email"] != docInfo["email"] || result.DocInfo["grade"] != docInfo["grade"] || result.DocInfo["active"] != docInfo["active"] || + result.DocInfo["_rid"] == "" || result.DocInfo["_self"] == "" || result.DocInfo["_ts"].(float64) == 0.0 || result.DocInfo["_etag"] == "" || result.DocInfo["_attachments"] == "" { + t.Fatalf("%s failed: invalid dbinfo returned %#v", name, result.DocInfo) + } + + if result := client.GetDocument(DocReq{NotMatchEtag: etag, DbName: dbname, CollName: collname, DocId: "1", PartitionKeyValues: []interface{}{"app", "user"}}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 304 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 304, result.StatusCode) + } + + if result := client.GetDocument(DocReq{DbName: dbname, CollName: collname, DocId: "0", PartitionKeyValues: []interface{}{"app", "user"}}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 404 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 404, result.StatusCode) + } +} + +func TestRestClient_DeleteDocument_SubPartitions(t *testing.T) { + name := "TestRestClient_DeleteDocument_SubPartitions" + client := _newRestClient(t, name) + + dbname := testDb + collname := testTable + client.DeleteDatabase(dbname) + client.CreateDatabase(DatabaseSpec{Id: dbname}) + 
client.CreateCollection(CollectionSpec{ + DbName: dbname, + CollName: collname, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + }) + + var etag string + docInfo := map[string]interface{}{"id": "1", "app": "app", "username": "user", "email": "user1@domain.com", "grade": 1.0, "active": true} + if result := client.CreateDocument(DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } else { + etag = result.DocInfo["_etag"].(string) + } + + if result := client.DeleteDocument(DocReq{MatchEtag: etag + "dummy", DbName: dbname, CollName: collname, DocId: "1", PartitionKeyValues: []interface{}{"app", "user"}}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 412 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 404, result.StatusCode) + } + if result := client.DeleteDocument(DocReq{MatchEtag: etag, DbName: dbname, CollName: collname, DocId: "1", PartitionKeyValues: []interface{}{"app", "user"}}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } + + if result := client.CreateDocument(DocumentSpec{DbName: dbname, CollName: collname, PartitionKeyValues: []interface{}{"app", "user"}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } + if result := client.DeleteDocument(DocReq{DbName: dbname, CollName: collname, DocId: "1", PartitionKeyValues: []interface{}{"app", "user"}}); result.Error() != nil { + t.Fatalf("%s failed: %s", name, result.Error()) + } + + if result := client.DeleteDocument(DocReq{DbName: dbname, CollName: collname, DocId: "1", PartitionKeyValues: []interface{}{"app", "user"}}); result.CallErr != nil { + t.Fatalf("%s failed: %s", name, result.CallErr) + } else if result.StatusCode != 404 { + t.Fatalf("%s failed: expected %#v but received %#v", name, 404, result.StatusCode) + } +} diff --git a/stmt_collection.go b/stmt_collection.go index 31ee81f..8c92c6e 100644 --- a/stmt_collection.go +++ b/stmt_collection.go @@ -6,6 +6,7 @@ import ( "fmt" "regexp" "strconv" + "strings" ) // StmtCreateCollection implements "CREATE COLLECTION" statement. @@ -13,15 +14,15 @@ import ( // Syntax: // // CREATE COLLECTION|TABLE [IF NOT EXISTS] [.] -// +// // [[,] WITH RU|MAXRU=ru] // [[,] WITH UK=/path1:/path2,/path3;/path4] // // - ru: an integer specifying CosmosDB's collection throughput expressed in RU/s. Supply either RU or MAXRU, not both! // -// - If "IF NOT EXISTS" is specified, Exec will silently swallow the error "409 Conflict". +// - partitionKey is either single (single value of /path) or hierarchical, up to 3 path levels, levels are separated by commas, for example: /path1,/path2,/path3. // -// - Use LARGEPK if partitionKey is larger than 100 bytes. +// - If "IF NOT EXISTS" is specified, Exec will silently swallow the error "409 Conflict". // // - Use UK to define unique keys. Each unique key consists a list of paths separated by comma (,). Unique keys are separated by colons (:) or semi-colons (;). 
type StmtCreateCollection struct { @@ -29,7 +30,6 @@ type StmtCreateCollection struct { dbName string collName string // collection name ifNotExists bool - isLargePk bool ru, maxru int pk string // partition key uk [][]string // unique keys @@ -42,20 +42,16 @@ func (s *StmtCreateCollection) parse() error { } // partition key - pk, okPk := s.withOpts["PK"] - largePk, okLargePk := s.withOpts["LARGEPK"] - if pk != "" && largePk != "" { + pk, largepk := s.withOpts["PK"], s.withOpts["LARGEPK"] + if pk != "" && largepk != "" && pk != largepk { return fmt.Errorf("only one of PK or LARGEPK must be specified") } - if !okPk && !okLargePk && pk == "" && largePk == "" { - return fmt.Errorf("invalid or missting PartitionKey value: %s%s", s.withOpts["PK"], s.withOpts["LARGEPK"]) + s.pk = s.withOpts["PK"] + if s.pk == "" { + s.pk = s.withOpts["LARGEPK"] } - if okPk && pk != "" { - s.pk = pk - } - if okLargePk && largePk != "" { - s.pk = largePk - s.isLargePk = true + if s.pk == "" { + return fmt.Errorf("missting PartitionKey value") } // request unit @@ -104,13 +100,17 @@ func (s *StmtCreateCollection) Query(_ []driver.Value) (driver.Rows, error) { // Exec implements driver.Stmt/Exec. func (s *StmtCreateCollection) Exec(_ []driver.Value) (driver.Result, error) { + pkPaths := strings.Split(s.pk, ",") + pkType := "Hash" + if len(pkPaths) > 1 { + pkType = "MultiHash" + } spec := CollectionSpec{DbName: s.dbName, CollName: s.collName, Ru: s.ru, MaxRu: s.maxru, PartitionKeyInfo: map[string]interface{}{ - "paths": []string{s.pk}, - "kind": "Hash", - }} - if s.isLargePk { - spec.PartitionKeyInfo["Version"] = 2 + "paths": pkPaths, + "kind": pkType, + "version": 2, + }, } if len(s.uk) > 0 { uniqueKeys := make([]interface{}, 0) diff --git a/stmt_collection_parsing_test.go b/stmt_collection_parsing_test.go index 772abab..8c33e86 100644 --- a/stmt_collection_parsing_test.go +++ b/stmt_collection_parsing_test.go @@ -27,8 +27,9 @@ func TestStmtCreateCollection_parse(t *testing.T) { {name: "basic", sql: "CREATE COLLECTION db1.table1 WITH pk=/id", expected: &StmtCreateCollection{dbName: "db1", collName: "table1", pk: "/id"}}, {name: "table_with_ru", sql: "create\ntable\rdb-2.table_2 WITH\tPK=/email WITH\r\nru=100", expected: &StmtCreateCollection{dbName: "db-2", collName: "table_2", pk: "/email", ru: 100}}, - {name: "if_not_exists_large_pk_with_maxru", sql: "CREATE collection\nIF\rNOT\t\nEXISTS\n\tdb_3.table-3 with largePK=/id WITH\t\rmaxru=100", expected: &StmtCreateCollection{dbName: "db_3", collName: "table-3", ifNotExists: true, isLargePk: true, pk: "/id", maxru: 100}}, - {name: "table_if_not_exists_large_pk_with_uk", sql: "create TABLE if not exists db-0_1.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g", expected: &StmtCreateCollection{dbName: "db-0_1", collName: "table_0-1", ifNotExists: true, isLargePk: true, pk: "/a/b/c", uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}}, + {name: "if_not_exists_large_pk_with_maxru", sql: "CREATE collection\nIF\rNOT\t\nEXISTS\n\tdb_3.table-3 with largePK=/id WITH\t\rmaxru=100", expected: &StmtCreateCollection{dbName: "db_3", collName: "table-3", ifNotExists: true, pk: "/id", maxru: 100}}, + {name: "table_if_not_exists_large_pk_with_uk", sql: "create TABLE if not exists db-0_1.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g", expected: &StmtCreateCollection{dbName: "db-0_1", collName: "table_0-1", ifNotExists: true, pk: "/a/b/c", uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}}, + {name: "subpartitions", sql: "CREATE COLLECTION db1.table1 WITH 
pk=/TenantId,/UserId,/SessionId", expected: &StmtCreateCollection{dbName: "db1", collName: "table1", pk: "/TenantId,/UserId,/SessionId"}}, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -68,8 +69,9 @@ func TestStmtCreateCollection_parse_defaultDb(t *testing.T) { {name: "basic", db: "mydb", sql: "CREATE COLLECTION table1 WITH pk=/id", expected: &StmtCreateCollection{dbName: "mydb", collName: "table1", pk: "/id"}}, {name: "db_in_query", db: "mydb", sql: "create\ntable\r\ndb2.table_2 WITH\r\t\nPK=/email WITH\nru=100", expected: &StmtCreateCollection{dbName: "db2", collName: "table_2", pk: "/email", ru: 100}}, - {name: "if_not_exists", db: "mydb", sql: "CREATE collection\nIF\nNOT\t\nEXISTS\n\ttable-3 with largePK=/id WITH\tmaxru=100", expected: &StmtCreateCollection{dbName: "mydb", collName: "table-3", ifNotExists: true, isLargePk: true, pk: "/id", maxru: 100}}, - {name: "db_in_query_if_not_exists", db: "mydb", sql: "create TABLE if not exists db3.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g", expected: &StmtCreateCollection{dbName: "db3", collName: "table_0-1", ifNotExists: true, isLargePk: true, pk: "/a/b/c", uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}}, + {name: "if_not_exists", db: "mydb", sql: "CREATE collection\nIF\nNOT\t\nEXISTS\n\ttable-3 with largePK=/id WITH\tmaxru=100", expected: &StmtCreateCollection{dbName: "mydb", collName: "table-3", ifNotExists: true, pk: "/id", maxru: 100}}, + {name: "db_in_query_if_not_exists", db: "mydb", sql: "create TABLE if not exists db3.table_0-1 WITH LARGEpk=/a/b/c with uk=/a:/b,/c/d;/e/f/g", expected: &StmtCreateCollection{dbName: "db3", collName: "table_0-1", ifNotExists: true, pk: "/a/b/c", uk: [][]string{{"/a"}, {"/b", "/c/d"}, {"/e/f/g"}}}}, + {name: "subpartitions", db: "mydb", sql: "CREATE COLLECTION table1 WITH pk=/TenantId,/UserId", expected: &StmtCreateCollection{dbName: "mydb", collName: "table1", pk: "/TenantId,/UserId"}}, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { diff --git a/stmt_collection_test.go b/stmt_collection_test.go index 035b376..cdadbf8 100644 --- a/stmt_collection_test.go +++ b/stmt_collection_test.go @@ -26,6 +26,7 @@ func TestStmtCreateCollection_Exec(t *testing.T) { sql string mustConflict bool mustNotFound bool + mustError bool affectedRows int64 }{ { @@ -54,6 +55,24 @@ func TestStmtCreateCollection_Exec(t *testing.T) { sql: "CREATE COLLECTION db_not_exists.table WITH pk=/a", mustNotFound: true, }, + { + name: "create_subpartitions", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname)}, + sql: fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/TenantId,/UserId", dbname), + affectedRows: 1, + }, + { + name: "create_subpartitions2", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname)}, + sql: fmt.Sprintf("CREATE TABLE %s.tbltemp WITH pk=/TenantId,/UserId,/SessionId", dbname), + affectedRows: 1, + }, + { + name: "error_subpartitions", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname)}, + sql: fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/TenantId,/UserId,/SessionId,/Level4NotSupported", dbname), + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -70,7 +89,10 @@ func TestStmtCreateCollection_Exec(t *testing.T) { if testCase.mustNotFound && err != ErrNotFound { t.Fatalf("%s failed: expect 
ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) } - if testCase.mustConflict || testCase.mustNotFound { + if testCase.mustError && err == nil { + t.Fatalf("%s failed: expect error", testName+"/"+testCase.name+"/exec") + } + if testCase.mustConflict || testCase.mustNotFound || testCase.mustError { return } if err != nil { @@ -110,6 +132,7 @@ func TestStmtCreateCollection_Exec_DefaultDb(t *testing.T) { sql string mustConflict bool mustNotFound bool + mustError bool affectedRows int64 }{ { @@ -139,6 +162,24 @@ func TestStmtCreateCollection_Exec_DefaultDb(t *testing.T) { sql: "CREATE COLLECTION table WITH pk=/a", mustNotFound: true, }, + { + name: "create_subpartitions", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname)}, + sql: "CREATE COLLECTION tbltemp WITH pk=/TenantId,/UserId", + affectedRows: 1, + }, + { + name: "create_subpartitions2", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname)}, + sql: "CREATE TABLE tbltemp WITH pk=/TenantId,/UserId,/SessionId", + affectedRows: 1, + }, + { + name: "error_subpartitions", + initSqls: []string{fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), fmt.Sprintf("CREATE DATABASE %s", dbname)}, + sql: "CREATE COLLECTION tbltemp WITH pk=/TenantId,/UserId,/SessionId,/Level4NotSupported", + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -155,7 +196,10 @@ func TestStmtCreateCollection_Exec_DefaultDb(t *testing.T) { if testCase.mustNotFound && err != ErrNotFound { t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) } - if testCase.mustConflict || testCase.mustNotFound { + if testCase.mustError && err == nil { + t.Fatalf("%s failed: expect error", testName+"/"+testCase.name+"/exec") + } + if testCase.mustConflict || testCase.mustNotFound || testCase.mustError { return } if err != nil { From e30f95a366687fb21c8b47fac038d0d12a1eb520 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Tue, 13 Jun 2023 21:58:04 +1000 Subject: [PATCH 2/8] INSERT statement: support subpartitions --- RELEASE-NOTES.md | 4 +- SQL.md | 10 +- restclient.go | 31 +++- stmt.go | 22 +-- stmt_document.go | 64 +++++++- stmt_document_parsing_test.go | 281 +++++++++++++++++++--------------- stmt_document_test.go | 108 ++++++++++++- 7 files changed, 367 insertions(+), 153 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 7e504e3..f8f1c00 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -3,8 +3,8 @@ ## 2003-06-0x - v0.3.0 - Change default API version to `2020-07-15`. -- Add Hierarchical Partition Keys (sub-partitions) support. -- PartitionKey version 1 is no longer used (hence large PK is always enabled). +- Add [Hierarchical Partition Keys](https://learn.microsoft.com/en-us/azure/cosmos-db/hierarchical-partition-keys) (sub-partitions) support. +- Use PartitionKey version 2 (replacing version 1), hence large PK is always enabled. ## 2023-06-09 - v0.2.1 diff --git a/SQL.md b/SQL.md index ed62cae..f6f0771 100644 --- a/SQL.md +++ b/SQL.md @@ -297,11 +297,16 @@ Description: insert a new document into an existing collection. Syntax: ```sql -INSERT INTO [.] (, ,...) VALUES (, ,...) +INSERT INTO [.] +(, ,...) +VALUES (, ,...) +[WITH singlePK|SINGLE_PK] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). 
+Since [v0.3.0](RELEASE-NOTES.md), `gocosmos` supports [Hierarchical Partition Keys](https://learn.microsoft.com/en-us/azure/cosmos-db/hierarchical-partition-keys) (or sub-partitions). If the collection is known not to have sub-partitions, supplying `WITH singlePK` (or `WITH SINGLE_PK`) can save one roundtrip to Cosmos DB server. + Example: ```go sql := `INSERT INTO mydb.mytable (a, b, c, d, e) VALUES (1, "\"a string\"", :1, @2, $3)` @@ -314,8 +319,7 @@ fmt.Println(dbresult.RowsAffected()) > Use `sql.DB.Exec` to execute the statement, `Query` will return error. -> Value of partition key _must_ be supplied at the last argument of `db.Exec()` call. - +> Values of partition key _must_ be supplied at the end of the argument list when invoking `db.Exec()`. A value is either: - a placeholder - which is a number prefixed by `$` or `@` or `:`, for example `$1`, `@2` or `:3`. Placeholders are 1-based index, that means starting from 1. diff --git a/restclient.go b/restclient.go index f0218d1..b2d66e3 100644 --- a/restclient.go +++ b/restclient.go @@ -1200,6 +1200,35 @@ type RespListDb struct { Databases []DbInfo `json:"Databases"` } +// PkInfo holds partitioning configuration settings for a collection. +// +// @Available since v0.3.0 +type PkInfo map[string]interface{} + +func (pk PkInfo) Kind() string { + kind, err := reddo.ToString(pk["kind"]) + if err == nil { + return kind + } + return "" +} + +func (pk PkInfo) Version() int { + version, err := reddo.ToInt(pk["version"]) + if err == nil { + return int(version) + } + return 0 +} + +func (pk PkInfo) Paths() []string { + paths, err := reddo.ToSlice(pk["paths"], reddo.TypeString) + if err == nil { + return paths.([]string) + } + return nil +} + // CollInfo captures info of a Cosmos DB collection. type CollInfo struct { Id string `json:"id"` // user-generated unique name for the collection @@ -1213,7 +1242,7 @@ type CollInfo struct { Udfs string `json:"_udfs"` // (system-generated property) _udfs attribute of the collection Conflicts string `json:"_conflicts"` // (system-generated property) _conflicts attribute of the collection IndexingPolicy map[string]interface{} `json:"indexingPolicy"` // indexing policy settings for collection - PartitionKey map[string]interface{} `json:"partitionKey"` // partitioning configuration settings for collection + PartitionKey PkInfo `json:"partitionKey"` // partitioning configuration settings for collection ConflictResolutionPolicy map[string]interface{} `json:"conflictResolutionPolicy"` // conflict resolution policy settings for collection GeospatialConfig map[string]interface{} `json:"geospatialConfig"` // Geo-spatial configuration settings for collection } diff --git a/stmt.go b/stmt.go index fd85955..76fc822 100644 --- a/stmt.go +++ b/stmt.go @@ -14,7 +14,7 @@ const ( field = `([\w\-]+)` ifNotExists = `(\s+IF\s+NOT\s+EXISTS)?` ifExists = `(\s+IF\s+EXISTS)?` - with = `(\s+WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+)((\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+))*)?` + with = `(\s+WITH\s+` + field + `(\s*=\s*([\w/\.\*,;:'"-]+))?((\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `(\s*=\s*([\w/\.\*,;:'"-]+))*)?)?` ) var ( @@ -28,7 +28,7 @@ var ( reDropColl = regexp.MustCompile(`(?is)^DROP\s+(COLLECTION|TABLE)` + ifExists + `\s+(` + field + `\.)?` + field + `$`) reListColls = regexp.MustCompile(`(?is)^LIST\s+(COLLECTIONS?|TABLES?)(\s+FROM\s+` + field + `)?$`) - reInsert = regexp.MustCompile(`(?is)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)$`) + reInsert = 
regexp.MustCompile(`(?is)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)` + with + `$`) reSelect = regexp.MustCompile(`(?is)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) reUpdate = regexp.MustCompile(`(?is)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) reDelete = regexp.MustCompile(`(?is)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*)$`) @@ -142,17 +142,19 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err if re := reInsert; re.MatchString(query) { groups := re.FindAllStringSubmatch(query, -1) stmt := &StmtInsert{ - Stmt: &Stmt{query: query, conn: c, numInput: 0}, + StmtCRUD: &StmtCRUD{ + Stmt: &Stmt{query: query, conn: c, numInput: 0}, + dbName: strings.TrimSpace(groups[0][3]), + collName: strings.TrimSpace(groups[0][4]), + }, isUpsert: strings.ToUpper(strings.TrimSpace(groups[0][1])) == "UPSERT", - dbName: strings.TrimSpace(groups[0][3]), - collName: strings.TrimSpace(groups[0][4]), fieldsStr: strings.TrimSpace(groups[0][5]), valuesStr: strings.TrimSpace(groups[0][6]), } if stmt.dbName == "" { stmt.dbName = defaultDb } - if err := stmt.parse(); err != nil { + if err := stmt.parse(groups[0][7]); err != nil { return nil, err } return stmt, stmt.validate() @@ -212,11 +214,11 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err type Stmt struct { query string // the SQL query conn *Conn // the connection that this prepared statement is bound to - numInput int // number of placeholder parameters + numInput int // number of placeholder parameters, INCLUDING PK values! withOpts map[string]string } -var reWithOpts = regexp.MustCompile(`(?is)^(\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `\s*=\s*([\w/\.\*,;:'"-]+)`) +var reWithOpts = regexp.MustCompile(`(?is)^(\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `(\s*=\s*([\w/\.\*,;:'"-]+))?`) // parseWithOpts parses "WITH..." clause and store result in withOpts map. // This function returns no error. Sub-implementations may override this behavior. @@ -229,7 +231,7 @@ func (s *Stmt) parseWithOpts(withOptsStr string) error { break } k := strings.TrimSpace(strings.ToUpper(matches[2])) - s.withOpts[k] = strings.TrimSuffix(strings.TrimSpace(matches[3]), ",") + s.withOpts[k] = strings.TrimSuffix(strings.TrimSpace(matches[4]), ",") withOptsStr = withOptsStr[len(matches[0]):] } return nil @@ -242,7 +244,7 @@ func (s *Stmt) Close() error { // NumInput implements driver.Stmt/NumInput. func (s *Stmt) NumInput() int { - return s.numInput + return -1 } /*----------------------------------------------------------------------*/ diff --git a/stmt_document.go b/stmt_document.go index 23c0c53..92b9192 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -65,6 +65,41 @@ func _parseValue(input string, separator rune) (value interface{}, leftOver stri return nil, input, errors.New("cannot parse query, invalid token at: " + input) } +// StmtCRUD is abstract implementation of "INSERT|UPSERT|UPDATE|DELETE|SELECT" operations. 
+// +// @Available since v0.3.0 +type StmtCRUD struct { + *Stmt + dbName string + collName string + numPkPaths int // number of PK paths + isSinglePathPk bool +} + +func (s *StmtCRUD) extractPkValuesFromArgs(args ...driver.Value) []interface{} { + n := len(args) + result := make([]interface{}, s.numPkPaths) + for i := n - s.numPkPaths; i < n; i++ { + result[i-n+s.numPkPaths] = args[i] + } + return result +} + +func (s *StmtCRUD) fetchPkInfo() error { + if s.numPkPaths > 0 || s.conn == nil || s.isSinglePathPk { + if s.isSinglePathPk { + s.numPkPaths = 1 + } + return nil + } + + getCollResult := s.conn.restClient.GetCollection(s.dbName, s.collName) + if getCollResult.Error() == nil { + s.numPkPaths = len(getCollResult.CollInfo.PartitionKey.Paths()) + } + return normalizeError(getCollResult.StatusCode, 0, getCollResult.Error()) +} + // StmtInsert implements "INSERT" operation. // // Syntax: @@ -88,9 +123,7 @@ func _parseValue(input string, separator rune) (value interface{}, leftOver stri // CosmosDB automatically creates a few extra fields for the insert document. // See https://docs.microsoft.com/en-us/azure/cosmos-db/account-databases-containers-items#properties-of-an-item. type StmtInsert struct { - *Stmt - dbName string - collName string + *StmtCRUD isUpsert bool fieldsStr string valuesStr string @@ -98,10 +131,20 @@ type StmtInsert struct { values []interface{} } -func (s *StmtInsert) parse() error { +func (s *StmtInsert) parse(withOptsStr string) error { + if err := s.parseWithOpts(withOptsStr); err != nil { + return err + } + _, ok1 := s.withOpts["SINGLEPK"] + _, ok2 := s.withOpts["SINGLE_PK"] + s.isSinglePathPk = ok1 || ok2 + if s.isSinglePathPk { + s.numPkPaths = 1 + } + s.fields = regexp.MustCompile(`[,\s]+`).Split(s.fieldsStr, -1) s.values = make([]interface{}, 0) - s.numInput = 1 + s.numInput = 0 for temp := strings.TrimSpace(s.valuesStr); temp != ""; temp = strings.TrimSpace(temp) { value, leftOver, err := _parseValue(temp, ',') if err == nil { @@ -130,13 +173,20 @@ func (s *StmtInsert) validate() error { // Exec implements driver.Stmt/Exec. // -// Note: this function expects the _last_ argument is _partition_ key value. +// Note: this function expects the _partition key values are placed at the end_ of the argument list. 
func (s *StmtInsert) Exec(args []driver.Value) (driver.Result, error) { + if err := s.fetchPkInfo(); err != nil { + return nil, err + } + if len(args) != s.numInput+s.numPkPaths { + return nil, fmt.Errorf("expected %d arguments, got %d", s.numInput+s.numPkPaths, len(args)) + } + spec := DocumentSpec{ DbName: s.dbName, CollName: s.collName, IsUpsert: s.isUpsert, - PartitionKeyValues: []interface{}{args[s.numInput-1]}, // expect the last argument is partition key value + PartitionKeyValues: s.extractPkValuesFromArgs(args...), DocumentData: make(map[string]interface{}), } for i := 0; i < len(s.fields); i++ { diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go index f004f77..5de7742 100644 --- a/stmt_document_parsing_test.go +++ b/stmt_document_parsing_test.go @@ -30,7 +30,7 @@ db1.table1 (a, b, c, d, e, f) VALUES (null, 1.0, true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, - expected: &StmtInsert{dbName: "db1", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db1", collName: "table1"}, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, }, { name: "with_placeholders", @@ -38,7 +38,22 @@ true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,fal INTO db-2.table_2 ( a,b,c) VALUES ( $1, :3, @2)`, - expected: &StmtInsert{dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db-2", collName: "table_2"}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "singlepk", + sql: `INSERT INTO db.table (a,b,c) VALUES (1,2,3) WITH singlePK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, 3.0}}, + }, + { + name: "single_pk", + sql: `INSERT INTO db.table (a,b,c) VALUES (:1,$2,3) WITH SINGLE_PK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{2}, 3.0}}, + }, + { + name: "singlepk_single_pk", + sql: `INSERT INTO db.table (a,b,c) VALUES (1,2,@1) WITH singlePK, with SINGLE_PK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, placeholder{1}}}, }, } for _, testCase := range testData { @@ -61,7 +76,7 @@ $1, :3, @2)`, stmt.fieldsStr = "" stmt.valuesStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { - t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + t.Fatalf("%s failed:\nexpected %#v/%#v\nreceived %#v/%#v", testName+"/"+testCase.name, testCase.expected.StmtCRUD, testCase.expected, stmt.StmtCRUD, stmt) } }) } @@ -83,142 +98,38 @@ func TestStmtInsert_parse_defaultDb(t *testing.T) { name: "basic", db: "mydb", sql: `INSERT INTO -table1 
(a, b, c, d, e, +table1 (a, b, c, d, e, f) VALUES - (null, 1.0, + (null, 1.0, true, "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, - expected: &StmtInsert{dbName: "mydb", collName: "table1", fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table1"}, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, }, { name: "with_placeholders_table_in_query", db: "mydb", - sql: `INSERT + sql: `INSERT INTO db-2.table_2 ( a,b,c) VALUES ( $1, :3, @2)`, - expected: &StmtInsert{dbName: "db-2", collName: "table_2", fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, - }, - } - for _, testCase := range testData { - t.Run(testCase.name, func(t *testing.T) { - s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) - if testCase.mustError && err == nil { - t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) - } - if testCase.mustError { - return - } - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - stmt, ok := s.(*StmtInsert) - if !ok { - t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) - } - stmt.Stmt = nil - stmt.fieldsStr = "" - stmt.valuesStr = "" - if !reflect.DeepEqual(stmt, testCase.expected) { - t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) - } - }) - } -} - -func TestStmtUpsert_parse(t *testing.T) { - testName := "TestStmtUpsert_parse" - testData := []struct { - name string - sql string - expected *StmtInsert - mustError bool - }{ - {name: "error_no_collection", sql: `UPSERT INTO db (a,b,c) VALUES (1,2,3)`, mustError: true}, - {name: "error_values", sql: `UPSERT INTO db.table (a,b,c)`, mustError: true}, - {name: "error_columns", sql: `UPSERT INTO db.table VALUES (1,2,3)`, mustError: true}, - {name: "error_invalid_string", sql: `UPSERT INTO db.table (a) VALUES ('a string')`, mustError: true}, - {name: "error_invalid_string2", sql: `UPSERT INTO db.table (a) VALUES ("a string")`, mustError: true}, - {name: "error_invalid_string3", sql: `UPSERT INTO db.table (a) VALUES ("{key:value}")`, mustError: true}, - {name: "error_num_values_not_matched", sql: `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, - {name: "error_invalid_number", sql: `UPSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, - {name: "error_invalid_string", sql: `UPSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, - - { - name: "basic", - sql: `UPSERT INTO -db1.table1 (a, -b, c, d, e, -f) VALUES - (null, 1.0, true, - "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, - expected: &StmtInsert{dbName: "db1", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db-2", collName: "table_2"}, fields: []string{"a", "b", 
"c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, }, { - name: "with_placeholders", - sql: `UPSERT -INTO db-2.table_2 ( -a,b,c) VALUES ($1, - :3, @2)`, - expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + name: "singlepk", + db: "mydb", + sql: `INSERT INTO table (a,b,c) VALUES (1,2,3) WITH singlePK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, 3.0}}, }, - } - for _, testCase := range testData { - t.Run(testCase.name, func(t *testing.T) { - s, err := parseQuery(nil, testCase.sql) - if testCase.mustError && err == nil { - t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) - } - if testCase.mustError { - return - } - if err != nil { - t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) - } - stmt, ok := s.(*StmtInsert) - if !ok { - t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) - } - stmt.Stmt = nil - stmt.fieldsStr = "" - stmt.valuesStr = "" - if !reflect.DeepEqual(stmt, testCase.expected) { - t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) - } - }) - } -} - -func TestStmtUpsert_parse_defaultDb(t *testing.T) { - testName := "TestStmtUpsert_parse_defaultDb" - testData := []struct { - name string - db string - sql string - expected *StmtInsert - mustError bool - }{ - {name: "error_invalid_query", sql: `UPSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, - {name: "error_invalid_query2", sql: `UPSERT INTO db. (a,b) VALUES (1,2)`, mustError: true}, - { - name: "basic", - db: "mydb", - sql: `UPSERT INTO -table1 (a, -b, c, d, e, -f) VALUES - (null, 1.0, true, - "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, - expected: &StmtInsert{dbName: "mydb", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + name: "single_pk", + db: "mydb", + sql: `INSERT INTO db.table (a,b,c) VALUES (:1,$2,3) WITH SINGLE_PK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{2}, 3.0}}, }, { - name: "with_placeholders_table_in_query", - db: "mydb", - sql: `UPSERT -INTO db-2.table_2 ( -a,b,c) VALUES ($1, - :3, @2)`, - expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + name: "singlepk_single_pk", + db: "mydb", + sql: `INSERT INTO table (a,b,c) VALUES (1,2,@1) WITH singlePK, with SINGLE_PK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, placeholder{1}}}, }, } for _, testCase := range testData { @@ -247,6 +158,128 @@ a,b,c) VALUES ($1, } } +// func TestStmtUpsert_parse(t *testing.T) { +// testName := "TestStmtUpsert_parse" +// testData := []struct { +// name string +// sql string +// expected *StmtInsert +// mustError bool +// }{ +// {name: 
"error_no_collection", sql: `UPSERT INTO db (a,b,c) VALUES (1,2,3)`, mustError: true}, +// {name: "error_values", sql: `UPSERT INTO db.table (a,b,c)`, mustError: true}, +// {name: "error_columns", sql: `UPSERT INTO db.table VALUES (1,2,3)`, mustError: true}, +// {name: "error_invalid_string", sql: `UPSERT INTO db.table (a) VALUES ('a string')`, mustError: true}, +// {name: "error_invalid_string2", sql: `UPSERT INTO db.table (a) VALUES ("a string")`, mustError: true}, +// {name: "error_invalid_string3", sql: `UPSERT INTO db.table (a) VALUES ("{key:value}")`, mustError: true}, +// {name: "error_num_values_not_matched", sql: `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, +// {name: "error_invalid_number", sql: `UPSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, +// {name: "error_invalid_string", sql: `UPSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, +// +// { +// name: "basic", +// sql: `UPSERT INTO +// db1.table1 (a, +// b, c, d, e, +// f) VALUES +// (null, 1.0, true, +// "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, +// expected: &StmtInsert{dbName: "db1", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, +// }, +// { +// name: "with_placeholders", +// sql: `UPSERT +// INTO db-2.table_2 ( +// a,b,c) VALUES ($1, +// :3, @2)`, +// expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, +// }, +// } +// for _, testCase := range testData { +// t.Run(testCase.name, func(t *testing.T) { +// s, err := parseQuery(nil, testCase.sql) +// if testCase.mustError && err == nil { +// t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) +// } +// if testCase.mustError { +// return +// } +// if err != nil { +// t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) +// } +// stmt, ok := s.(*StmtInsert) +// if !ok { +// t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) +// } +// stmt.Stmt = nil +// stmt.fieldsStr = "" +// stmt.valuesStr = "" +// if !reflect.DeepEqual(stmt, testCase.expected) { +// t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) +// } +// }) +// } +// } +// +// func TestStmtUpsert_parse_defaultDb(t *testing.T) { +// testName := "TestStmtUpsert_parse_defaultDb" +// testData := []struct { +// name string +// db string +// sql string +// expected *StmtInsert +// mustError bool +// }{ +// {name: "error_invalid_query", sql: `UPSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, +// {name: "error_invalid_query2", sql: `UPSERT INTO db. 
(a,b) VALUES (1,2)`, mustError: true}, +// +// { +// name: "basic", +// db: "mydb", +// sql: `UPSERT INTO +// table1 (a, +// b, c, d, e, +// f) VALUES +// (null, 1.0, true, +// "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, +// expected: &StmtInsert{dbName: "mydb", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, +// }, +// { +// name: "with_placeholders_table_in_query", +// db: "mydb", +// sql: `UPSERT +// INTO db-2.table_2 ( +// a,b,c) VALUES ($1, +// :3, @2)`, +// expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, +// }, +// } +// for _, testCase := range testData { +// t.Run(testCase.name, func(t *testing.T) { +// s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) +// if testCase.mustError && err == nil { +// t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) +// } +// if testCase.mustError { +// return +// } +// if err != nil { +// t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) +// } +// stmt, ok := s.(*StmtInsert) +// if !ok { +// t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) +// } +// stmt.Stmt = nil +// stmt.fieldsStr = "" +// stmt.valuesStr = "" +// if !reflect.DeepEqual(stmt, testCase.expected) { +// t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) +// } +// }) +// } +// } + func TestStmtDelete_parse(t *testing.T) { testName := "TestStmtDelete_parse" testData := []struct { diff --git a/stmt_document_test.go b/stmt_document_test.go index 36091ff..2722c2f 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -46,13 +46,13 @@ func TestStmtInsert_Exec(t *testing.T) { }, { name: "insert_conflict_pk", - sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user\"", "\"user@domain2.com\"", 8, false)`, dbname), + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user\"", "\"user@domain2.com\"", 8, false) WITH singlePK`, dbname), args: []interface{}{"user"}, mustConflict: true, }, { name: "insert_conflict_uk", - sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"2\"", "\"user\"", "\"user@domain1.com\"", 9, false)`, dbname), + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"2\"", "\"user\"", "\"user@domain1.com\"", 9, false) WITH SINGLE_PK`, dbname), args: []interface{}{"user"}, mustConflict: true, }, @@ -82,13 +82,13 @@ func TestStmtInsert_Exec(t *testing.T) { }, { name: "insert_conflict_pk_placeholders", - sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH SINGLE_PK`, dbname), args: []interface{}{"1", "user", "user@domain2.com", 2, false, nil, "user"}, mustConflict: true, }, { name: "insert_conflict_uk_placeholders", - sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp 
(id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH singlePK`, dbname), args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, mustConflict: true, }, @@ -184,7 +184,7 @@ func TestStmtInsert_Exec_DefaultDb(t *testing.T) { fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), }, - sql: `INSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user\"", "\"user@domain1.com\"", 7, true)`, + sql: `INSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user\"", "\"user@domain1.com\"", 7, true) WITH SINGLE_PK`, args: []interface{}{"user"}, affectedRows: 1, }, @@ -214,7 +214,7 @@ func TestStmtInsert_Exec_DefaultDb(t *testing.T) { fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), }, - sql: `INSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + sql: `INSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH singlePK`, args: []interface{}{"1", "user", "user@domain1.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user"}, affectedRows: 1, }, @@ -293,6 +293,102 @@ func TestStmtInsert_Exec_DefaultDb(t *testing.T) { } } +func TestStmtInsert_SubPartitions(t *testing.T) { + testName := "TestStmtInsert_SubPartitions" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "insert_new", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/app,/username WITH uk=/email", dbname), + }, + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, :4, $5, @6, :7)`, dbname), + args: []interface{}{"1", "app", "user", "user@domain1.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "app", "user"}, + affectedRows: 1, + }, + { + name: "insert_conflict_pk_", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, :4, $5, @6, :7)`, dbname), + args: []interface{}{"1", "app", "user", "user@domain2.com", 2, false, nil, "app", "user"}, + mustConflict: true, + }, + { + name: "insert_conflict_uk", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, :4, $5, @6, :7)`, dbname), + args: []interface{}{"2", "app", "user", "user@domain1.com", 3, false, nil, "app", "user"}, + mustConflict: true, + }, + { + name: "error_invalid_value_index", + sql: fmt.Sprintf(`INSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, $6, :10)`, dbname), + args: []interface{}{"2", "app", "user", "user@domain1.com", 3, false, nil, "app", "user"}, + mustError: "invalid value index", + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: 
%s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + func TestStmtUpsert_Query(t *testing.T) { testName := "TestStmtUpsert_Query" db := _openDb(t, testName) From cb301ff6f3ce65da61b761e73304c216e8a96c8a Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Tue, 13 Jun 2023 22:23:14 +1000 Subject: [PATCH 3/8] UPSERT statement: support subpartitions --- SQL.md | 5 +- stmt_document_parsing_test.go | 275 +++++++++++++++++++--------------- stmt_document_test.go | 149 ++++++++++++++++-- 3 files changed, 296 insertions(+), 133 deletions(-) diff --git a/SQL.md b/SQL.md index f6f0771..a667b0e 100644 --- a/SQL.md +++ b/SQL.md @@ -343,7 +343,10 @@ Description: insert a new document or replace an existing one. Syntax & Usage: similar to [INSERT](#insert). ```sql -UPSERT INTO [.] (, ,...) VALUES (, ,...) +UPSERT INTO [.] +(, ,...) +VALUES (, ,...) 
+[WITH singlePK|SINGLE_PK] ``` [Back to top](#top) diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go index 5de7742..c117d60 100644 --- a/stmt_document_parsing_test.go +++ b/stmt_document_parsing_test.go @@ -158,127 +158,160 @@ $1, :3, @2)`, } } -// func TestStmtUpsert_parse(t *testing.T) { -// testName := "TestStmtUpsert_parse" -// testData := []struct { -// name string -// sql string -// expected *StmtInsert -// mustError bool -// }{ -// {name: "error_no_collection", sql: `UPSERT INTO db (a,b,c) VALUES (1,2,3)`, mustError: true}, -// {name: "error_values", sql: `UPSERT INTO db.table (a,b,c)`, mustError: true}, -// {name: "error_columns", sql: `UPSERT INTO db.table VALUES (1,2,3)`, mustError: true}, -// {name: "error_invalid_string", sql: `UPSERT INTO db.table (a) VALUES ('a string')`, mustError: true}, -// {name: "error_invalid_string2", sql: `UPSERT INTO db.table (a) VALUES ("a string")`, mustError: true}, -// {name: "error_invalid_string3", sql: `UPSERT INTO db.table (a) VALUES ("{key:value}")`, mustError: true}, -// {name: "error_num_values_not_matched", sql: `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, -// {name: "error_invalid_number", sql: `UPSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, -// {name: "error_invalid_string", sql: `UPSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, -// -// { -// name: "basic", -// sql: `UPSERT INTO -// db1.table1 (a, -// b, c, d, e, -// f) VALUES -// (null, 1.0, true, -// "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, -// expected: &StmtInsert{dbName: "db1", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, -// }, -// { -// name: "with_placeholders", -// sql: `UPSERT -// INTO db-2.table_2 ( -// a,b,c) VALUES ($1, -// :3, @2)`, -// expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, -// }, -// } -// for _, testCase := range testData { -// t.Run(testCase.name, func(t *testing.T) { -// s, err := parseQuery(nil, testCase.sql) -// if testCase.mustError && err == nil { -// t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) -// } -// if testCase.mustError { -// return -// } -// if err != nil { -// t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) -// } -// stmt, ok := s.(*StmtInsert) -// if !ok { -// t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) -// } -// stmt.Stmt = nil -// stmt.fieldsStr = "" -// stmt.valuesStr = "" -// if !reflect.DeepEqual(stmt, testCase.expected) { -// t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) -// } -// }) -// } -// } -// -// func TestStmtUpsert_parse_defaultDb(t *testing.T) { -// testName := "TestStmtUpsert_parse_defaultDb" -// testData := []struct { -// name string -// db string -// sql string -// expected *StmtInsert -// mustError bool -// }{ -// {name: "error_invalid_query", sql: `UPSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, -// {name: "error_invalid_query2", sql: `UPSERT INTO db. 
(a,b) VALUES (1,2)`, mustError: true}, -// -// { -// name: "basic", -// db: "mydb", -// sql: `UPSERT INTO -// table1 (a, -// b, c, d, e, -// f) VALUES -// (null, 1.0, true, -// "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, -// expected: &StmtInsert{dbName: "mydb", collName: "table1", isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, -// }, -// { -// name: "with_placeholders_table_in_query", -// db: "mydb", -// sql: `UPSERT -// INTO db-2.table_2 ( -// a,b,c) VALUES ($1, -// :3, @2)`, -// expected: &StmtInsert{dbName: "db-2", collName: "table_2", isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, -// }, -// } -// for _, testCase := range testData { -// t.Run(testCase.name, func(t *testing.T) { -// s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) -// if testCase.mustError && err == nil { -// t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) -// } -// if testCase.mustError { -// return -// } -// if err != nil { -// t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) -// } -// stmt, ok := s.(*StmtInsert) -// if !ok { -// t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) -// } -// stmt.Stmt = nil -// stmt.fieldsStr = "" -// stmt.valuesStr = "" -// if !reflect.DeepEqual(stmt, testCase.expected) { -// t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) -// } -// }) -// } -// } +func TestStmtUpsert_parse(t *testing.T) { + testName := "TestStmtUpsert_parse" + testData := []struct { + name string + sql string + expected *StmtInsert + mustError bool + }{ + {name: "error_no_collection", sql: `UPSERT INTO db (a,b,c) VALUES (1,2,3)`, mustError: true}, + {name: "error_values", sql: `UPSERT INTO db.table (a,b,c)`, mustError: true}, + {name: "error_columns", sql: `UPSERT INTO db.table VALUES (1,2,3)`, mustError: true}, + {name: "error_invalid_string", sql: `UPSERT INTO db.table (a) VALUES ('a string')`, mustError: true}, + {name: "error_invalid_string2", sql: `UPSERT INTO db.table (a) VALUES ("a string")`, mustError: true}, + {name: "error_invalid_string3", sql: `UPSERT INTO db.table (a) VALUES ("{key:value}")`, mustError: true}, + {name: "error_num_values_not_matched", sql: `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, + {name: "error_invalid_number", sql: `UPSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, + {name: "error_invalid_string", sql: `UPSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, + + { + name: "basic", + sql: `UPSERT INTO +db1.table1 (a, +b, c, d, e, +f) VALUES + (null, 1.0, true, + "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db1", collName: "table1"}, isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + }, + { + name: "with_placeholders", + sql: `UPSERT +INTO db-2.table_2 ( +a,b,c) VALUES ($1, + :3, @2)`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db-2", collName: "table_2"}, 
isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "singlepk", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH singlePK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "single_pk", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH SINGLE_PK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "singlepk_single_pk", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH SINGLE_PK, with singlePK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQuery(nil, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtInsert) + if !ok { + t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.fieldsStr = "" + stmt.valuesStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) + } + }) + } +} + +func TestStmtUpsert_parse_defaultDb(t *testing.T) { + testName := "TestStmtUpsert_parse_defaultDb" + testData := []struct { + name string + db string + sql string + expected *StmtInsert + mustError bool + }{ + {name: "error_invalid_query", sql: `UPSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, + {name: "error_invalid_query2", sql: `UPSERT INTO db. 
(a,b) VALUES (1,2)`, mustError: true}, + + { + name: "basic", + db: "mydb", + sql: `UPSERT INTO +table1 (a, +b, c, d, e, +f) VALUES + (null, 1.0, true, + "\"a string 'with' \\\"quote\\\"\"", "{\"key\":\"value\"}", "[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]")`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table1"}, isUpsert: true, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}}, + }, + { + name: "with_placeholders_table_in_query", + db: "mydb", + sql: `UPSERT +INTO db-2.table_2 ( +a,b,c) VALUES ($1, + :3, @2)`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db-2", collName: "table_2"}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "singlepk", + db: "mydb", + sql: `UPSERT INTO db.table (a,b,c) VALUES ($1, :3, @2) WITH SINGLEPK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "single_pk", + db: "mydb", + sql: `UPSERT INTO table (a,b,c) VALUES ($1, :3, @2) WITH single_pk`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "singlepk_single_pk", + db: "mydb", + sql: `UPSERT INTO db.table (a,b,c) VALUES ($1, :3, @2) WITH single_pk WITH singlePK`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + s, err := parseQueryWithDefaultDb(nil, testCase.db, testCase.sql) + if testCase.mustError && err == nil { + t.Fatalf("%s failed: parsing must fail", testName+"/"+testCase.name) + } + if testCase.mustError { + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + stmt, ok := s.(*StmtInsert) + if !ok { + t.Fatalf("%s failed: expected StmtInsert but received %T", testName+"/"+testCase.name, s) + } + stmt.Stmt = nil + stmt.fieldsStr = "" + stmt.valuesStr = "" + if !reflect.DeepEqual(stmt, testCase.expected) { + t.Fatalf("%s failed:\nexpected %#v/%#v\nreceived %#v/%#v", testName+"/"+testCase.name, testCase.expected.StmtCRUD, testCase.expected, stmt.StmtCRUD, stmt) + } + }) + } +} func TestStmtDelete_parse(t *testing.T) { testName := "TestStmtDelete_parse" diff --git a/stmt_document_test.go b/stmt_document_test.go index 2722c2f..595f37f 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -427,13 +427,13 @@ func TestStmtUpsert_Exec(t *testing.T) { }, { name: "upsert_another", - sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived) VALUES ("\"2\"", "\"user2\"", "\"user2@domain.com\"", 7, true)`, dbname), + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived) VALUES ("\"2\"", "\"user2\"", "\"user2@domain.com\"", 7, true) WITH singlePK`, dbname), args: []interface{}{"user2"}, affectedRows: 1, }, { name: "upsert_duplicated_id_placeholders", - 
sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user1\"", "\"user3@domain1.com\"", 8, false)`, dbname), + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id,username,email,grade,actived) VALUES ("\"1\"", "\"user1\"", "\"user3@domain1.com\"", 8, false) WITH single_PK`, dbname), args: []interface{}{"user1"}, affectedRows: 1, }, @@ -463,7 +463,7 @@ func TestStmtUpsert_Exec(t *testing.T) { fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH uk=/email", dbname), }, - sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH singlePK`, dbname), args: []interface{}{"1", "user1", "user1@domain.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user1"}, affectedRows: 1, }, @@ -475,7 +475,7 @@ func TestStmtUpsert_Exec(t *testing.T) { }, { name: "upsert_duplicated_id_placeholders", - sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, dbname), + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH single_PK`, dbname), args: []interface{}{"1", "user1", "user2@domain.com", 2, false, nil, "user1"}, affectedRows: 1, }, @@ -487,19 +487,19 @@ func TestStmtUpsert_Exec(t *testing.T) { }, { name: "table_not_exists_placeholders", - sql: fmt.Sprintf(`UPSERT INTO %s.tbl_not_found (id,username,email) VALUES (:1, :2, :3)`, dbname), + sql: fmt.Sprintf(`UPSERT INTO %s.tbl_not_found (id,username,email) VALUES (:1, :2, :3) WITH singlePK`, dbname), args: []interface{}{"x", "y", "x", "y"}, mustNotFound: true, }, { name: "db_not_exists_placeholders", - sql: `UPSERT INTO db_not_exists.table (id,username,email) VALUES (@1, @2, @3)`, + sql: `UPSERT INTO db_not_exists.table (id,username,email) VALUES (@1, @2, @3) WITH singlePK`, args: []interface{}{"x", "y", "x", "y"}, mustNotFound: true, }, { name: "error_invalid_value_index", - sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10)`, dbname), + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :10) WITH singlePK`, dbname), args: []interface{}{"2", "user", "user@domain1.com", 3, false, nil, "user"}, mustError: "invalid value index", }, @@ -578,7 +578,7 @@ func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { fmt.Sprintf("CREATE DATABASE %s", dbname), fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/username WITH UK=/email", dbname), }, - sql: `UPSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user1\"", "\"user1@domain.com\"", 7, true)`, + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived) VALUES ("\"1\"", "\"user1\"", "\"user1@domain.com\"", 7, true) WITH singlePK`, args: []interface{}{"user1"}, affectedRows: 1, }, @@ -602,7 +602,7 @@ func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { }, { name: "table_not_exists", - sql: `UPSERT INTO tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"")`, + sql: `UPSERT INTO tbl_not_found (id,username,email) VALUES ("\"x\"", "\"y\"", "\"x\"") WITH single_PK`, args: []interface{}{"y"}, mustNotFound: true, }, @@ -620,7 +620,7 @@ func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { }, { name: 
"upsert_another_placeholders", - sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH singlePK`, args: []interface{}{"2", "user2", "user2@domain.com", 2, false, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "user2"}, affectedRows: 1, }, @@ -632,7 +632,7 @@ func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { }, { name: "upsert_conflict_uk_placeholders", - sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6)`, + sql: `UPSERT INTO tbltemp (id, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6) WITH single_PK`, args: []interface{}{"2", "user1", "user2@domain.com", 3, false, nil, "user1"}, mustConflict: true, }, @@ -700,6 +700,133 @@ func TestStmtUpsert_Exec_DefaultDb(t *testing.T) { } } +func TestStmtUpsert_SubPartitions(t *testing.T) { + testName := "TestStmtUpsert_SubPartitions" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "upsert_new", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/app,/username WITH UK=/email", dbname), + }, + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, app, username, email, grade, actived) VALUES ("\"1\"", "\"app\"", "\"user1\"", "\"user1@domain.com\"", 7, true)`, dbname), + args: []interface{}{"app", "user1"}, + affectedRows: 1, + }, + { + name: "upsert_another", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, app, username, email, grade, actived) VALUES ("\"2\"", "\"app\"", "\"user2\"", "\"user2@domain.com\"", 7, true)`, dbname), + args: []interface{}{"app", "user2"}, + affectedRows: 1, + }, + { + name: "upsert_duplicated_id_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id,app,username,email,grade,actived) VALUES ("\"1\"", "\"app\"", "\"user1\"", "\"user3@domain1.com\"", 8, false)`, dbname), + args: []interface{}{"app", "user1"}, + affectedRows: 1, + }, + { + name: "upsert_conflict_uk", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id,app,username,email,grade,actived) VALUES ("\"3\"", "\"app\"", "\"user2\"", "\"user2@domain.com\"", 9, true)`, dbname), + args: []interface{}{"app", "user2"}, + mustConflict: true, + }, + { + name: "upsert_new_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/app,/username WITH uk=/email", dbname), + }, + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6, :7)`, dbname), + args: []interface{}{"1", "app", "user1", "user1@domain.com", 1, true, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "app", "user1"}, + affectedRows: 1, + }, + { + name: "upsert_another_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6, :7)`, dbname), + args: []interface{}{"2", "app", "user2", 
"user2@domain.com", 2, false, map[string]interface{}{"str": "a string", "num": 1.23, "bool": true, "date": time.Now()}, "app", "user2"}, + affectedRows: 1, + }, + { + name: "upsert_duplicated_id_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6, :7)`, dbname), + args: []interface{}{"1", "app", "user1", "user2@domain.com", 2, false, nil, "app", "user1"}, + affectedRows: 1, + }, + { + name: "upsert_conflict_uk_placeholders", + sql: fmt.Sprintf(`UPSERT INTO %s.tbltemp (id, app, username, email, grade, actived, data) VALUES (:1, $2, @3, @4, $5, :6, :7)`, dbname), + args: []interface{}{"2", "app", "user1", "user2@domain.com", 3, false, nil, "app", "user1"}, + mustConflict: true, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for _, initSql := range testCase.initSqls { + _, err := db.Exec(initSql) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + _, err = execResult.LastInsertId() + if err == nil { + t.Fatalf("%s failed: expected LastInsertId but received nil", testName+"/"+testCase.name) + } + lastInsertId := make(map[string]interface{}) + err = json.Unmarshal([]byte(err.Error()), &lastInsertId) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/LastInsertId", err) + } + if len(lastInsertId) != 1 { + t.Fatalf("%s failed - LastInsertId: %#v", testName+"/"+testCase.name+"/LastInsertId", lastInsertId) + } + }) + } +} + func TestStmtDelete_Query(t *testing.T) { testName := "TestStmtDelete_Query" db := _openDb(t, testName) From c32eb2803061e655dc1b6808eca1728471dad12c Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Wed, 14 Jun 2023 12:26:09 +1000 Subject: [PATCH 4/8] DELETE statement: support subpartitions --- SQL.md | 8 ++- stmt.go | 14 ++-- stmt_document.go | 37 +++++++---- stmt_document_parsing_test.go | 45 +++++++++++-- stmt_document_test.go | 121 ++++++++++++++++++++++++++++++++-- 5 files changed, 193 insertions(+), 32 deletions(-) diff --git a/SQL.md b/SQL.md index a667b0e..9460ac2 100644 --- a/SQL.md +++ b/SQL.md @@ -319,7 +319,7 @@ fmt.Println(dbresult.RowsAffected()) > Use `sql.DB.Exec` to execute the statement, `Query` will return error. -> Values of partition key _must_ be supplied at the end of the argument list when invoking `db.Exec()`. 
+> Values of partition keys _must_ be supplied at the end of the argument list when invoking `db.Exec()`. A value is either: - a placeholder - which is a number prefixed by `$` or `@` or `:`, for example `$1`, `@2` or `:3`. Placeholders are 1-based index, that means starting from 1. @@ -358,11 +358,13 @@ Description: delete an existing document. Syntax: ```sql -DELETE FROM [.] WHERE id= +DELETE FROM [.] WHERE id= [WITH singlePK|SINGLE_PK] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). +Since [v0.3.0](RELEASE-NOTES.md), `gocosmos` supports [Hierarchical Partition Keys](https://learn.microsoft.com/en-us/azure/cosmos-db/hierarchical-partition-keys) (or sub-partitions). If the collection is known not to have sub-partitions, supplying `WITH singlePK` (or `WITH SINGLE_PK`) can save one roundtrip to Cosmos DB server. + Example: ```go sql := `DELETE FROM mydb.mytable WHERE id=@1` @@ -375,7 +377,7 @@ fmt.Println(dbresult.RowsAffected()) > Use `sql.DB.Exec` to execute the statement, `Query` will return error. -> Value of partition key _must_ be supplied at the last argument of `db.Exec()` call. +> Values of partition keys _must_ be supplied at the end of the argument list when invoking `db.Exec()`. - `DELETE` removes only one document specified by id. - Upon successful execution, `RowsAffected()` returns `(1, nil)`. If no document matched, `RowsAffected()` returns `(0, nil)`. diff --git a/stmt.go b/stmt.go index 76fc822..8a55bac 100644 --- a/stmt.go +++ b/stmt.go @@ -31,7 +31,7 @@ var ( reInsert = regexp.MustCompile(`(?is)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)` + with + `$`) reSelect = regexp.MustCompile(`(?is)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) reUpdate = regexp.MustCompile(`(?is)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) - reDelete = regexp.MustCompile(`(?is)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*)$`) + reDelete = regexp.MustCompile(`(?is)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*?)` + with + `$`) ) func parseQuery(c *Conn, query string) (driver.Stmt, error) { @@ -193,15 +193,17 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err if re := reDelete; re.MatchString(query) { groups := re.FindAllStringSubmatch(query, -1) stmt := &StmtDelete{ - Stmt: &Stmt{query: query, conn: c, numInput: 0}, - dbName: strings.TrimSpace(groups[0][2]), - collName: strings.TrimSpace(groups[0][3]), - idStr: strings.TrimSpace(groups[0][4]), + StmtCRUD: &StmtCRUD{ + Stmt: &Stmt{query: query, conn: c, numInput: 0}, + dbName: strings.TrimSpace(groups[0][2]), + collName: strings.TrimSpace(groups[0][3]), + }, + idStr: strings.TrimSpace(groups[0][4]), } if stmt.dbName == "" { stmt.dbName = defaultDb } - if err := stmt.parse(); err != nil { + if err := stmt.parse(groups[0][5]); err != nil { return nil, err } return stmt, stmt.validate() diff --git a/stmt_document.go b/stmt_document.go index 92b9192..3d6cc89 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -104,7 +104,7 @@ func (s *StmtCRUD) fetchPkInfo() error { // // Syntax: // -// INSERT|UPSERT INTO . () VALUES () +// INSERT|UPSERT INTO . () VALUES () [WITH singlePK|SINGLE_PK] // // - values are comma separated. // - a value is either: @@ -222,21 +222,29 @@ func (s *StmtInsert) Query(_ []driver.Value) (driver.Rows, error) { // // Syntax: // -// DELETE FROM . WHERE id= +// DELETE FROM . 
WHERE id= [WITH singlePK|SINGLE_PK] // // - Currently DELETE only removes one document specified by id. // // - is treated as string. `WHERE id=abc` has the same effect as `WHERE id="abc"`. type StmtDelete struct { - *Stmt - dbName string - collName string - idStr string - id interface{} + *StmtCRUD + idStr string + id interface{} } -func (s *StmtDelete) parse() error { - s.numInput = 1 +func (s *StmtDelete) parse(withOptsStr string) error { + if err := s.parseWithOpts(withOptsStr); err != nil { + return err + } + _, ok1 := s.withOpts["SINGLEPK"] + _, ok2 := s.withOpts["SINGLE_PK"] + s.isSinglePathPk = ok1 || ok2 + if s.isSinglePathPk { + s.numPkPaths = 1 + } + + s.numInput = 0 hasPrefix := strings.HasPrefix(s.idStr, `"`) hasSuffix := strings.HasSuffix(s.idStr, `"`) if hasPrefix != hasSuffix { @@ -268,8 +276,15 @@ func (s *StmtDelete) validate() error { // Exec implements driver.Stmt/Exec. // -// Note: this function expects the _last_ argument is _partition_ key value. +// Note: this function expects the _partition key values are placed at the end_ of the argument list. func (s *StmtDelete) Exec(args []driver.Value) (driver.Result, error) { + if err := s.fetchPkInfo(); err != nil { + return nil, err + } + if len(args) != s.numInput+s.numPkPaths { + return nil, fmt.Errorf("expected %d arguments, got %d", s.numInput+s.numPkPaths, len(args)) + } + id := s.idStr if s.id != nil { ph := s.id.(placeholder) @@ -279,7 +294,7 @@ func (s *StmtDelete) Exec(args []driver.Value) (driver.Result, error) { id = fmt.Sprintf("%s", args[ph.index-1]) } restResult := s.conn.restClient.DeleteDocument(DocReq{DbName: s.dbName, CollName: s.collName, DocId: id, - PartitionKeyValues: []interface{}{args[s.numInput-1]}, // expect the last argument is partition key value + PartitionKeyValues: s.extractPkValuesFromArgs(args...), }) result := buildResultNoResultSet(&restResult.RestReponse, false, "", 0) switch restResult.StatusCode { diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go index c117d60..cfad982 100644 --- a/stmt_document_parsing_test.go +++ b/stmt_document_parsing_test.go @@ -335,7 +335,7 @@ func TestStmtDelete_parse(t *testing.T) { sql: `DELETE FROM db1.table1 WHERE id=abc`, - expected: &StmtDelete{dbName: "db1", collName: "table1", idStr: "abc"}, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db1", collName: "table1"}, idStr: "abc"}, }, { name: "basic2", @@ -343,14 +343,29 @@ db1.table1 WHERE DELETE FROM db-2.table_2 WHERE id="def"`, - expected: &StmtDelete{dbName: "db-2", collName: "table_2", idStr: "def"}, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db-2", collName: "table_2"}, idStr: "def"}, }, { name: "basic3", sql: `DELETE FROM db_3-0.table-3_0 WHERE id=@2`, - expected: &StmtDelete{dbName: "db_3-0", collName: "table-3_0", idStr: "@2", id: placeholder{2}}, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db_3-0", collName: "table-3_0"}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "singlepk", + sql: `DELETE FROM db.table WHERE id=@2 WITH singlePK`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "single_pk", + sql: `DELETE FROM db.table WHERE id=@2 with Single_PK`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "singlepk_single_pk", + sql: `DELETE FROM db.table WHERE id=@2 with SinglePK WITH SINGLE_PK`, + expected: 
&StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, }, } for _, testCase := range testData { @@ -395,7 +410,7 @@ func TestStmtDelete_parse_defaultDb(t *testing.T) { sql: `DELETE FROM table1 WHERE id=abc`, - expected: &StmtDelete{dbName: "mydb", collName: "table1", idStr: "abc"}, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table1"}, idStr: "abc"}, }, { name: "db_in_query", @@ -404,7 +419,7 @@ table1 WHERE DELETE FROM db-2.table_2 WHERE id="def"`, - expected: &StmtDelete{dbName: "db-2", collName: "table_2", idStr: "def"}, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db-2", collName: "table_2"}, idStr: "def"}, }, { name: "placeholder", @@ -412,7 +427,25 @@ FROM db-2.table_2 sql: `DELETE FROM db_3-0.table-3_0 WHERE id=@2`, - expected: &StmtDelete{dbName: "db_3-0", collName: "table-3_0", idStr: "@2", id: placeholder{2}}, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db_3-0", collName: "table-3_0"}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "singlepk", + db: "mydb", + sql: `DELETE FROM table WHERE id=@2 With singlePk`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "single_pk", + db: "mydb", + sql: `DELETE FROM db.table WHERE id=@2 With single_Pk`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "singlepk_single_pk", + db: "mydb", + sql: `DELETE FROM table WHERE id=@2 With single_Pk, With SinglePK`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, }, } for _, testCase := range testData { diff --git a/stmt_document_test.go b/stmt_document_test.go index 595f37f..048d1b1 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -873,7 +873,7 @@ func TestStmtDelete_Exec(t *testing.T) { }, { name: "delete_2", - sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id="2"`, dbname), + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id="2" with SINGLE_PK`, dbname), args: []interface{}{"user"}, affectedRows: 1, }, @@ -885,7 +885,7 @@ func TestStmtDelete_Exec(t *testing.T) { }, { name: "delete_4", - sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=@1`, dbname), + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=@1 with SINGLEPK`, dbname), args: []interface{}{"4", "user"}, affectedRows: 1, }, @@ -897,7 +897,7 @@ func TestStmtDelete_Exec(t *testing.T) { }, { name: "row_not_exists", - sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=1`, dbname), + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=1 with SINGLE_PK`, dbname), args: []interface{}{"user"}, affectedRows: 0, }, @@ -993,7 +993,7 @@ func TestStmtDelete_Exec_DefaultDb(t *testing.T) { }, initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain1.com", "user"}, {"2", "user", "user@domain2.com", "user"}, {"3", "user", "user@domain3.com", "user"}, {"4", "user", "user@domain4.com", "user"}, {"5", "user", "user@domain5.com", "user"}}, - sql: `DELETE FROM tbltemp WHERE id=1`, + sql: `DELETE FROM tbltemp WHERE id=1 with SINGLE_PK`, args: []interface{}{"user"}, affectedRows: 1, }, @@ -1005,7 +1005,7 @@ func TestStmtDelete_Exec_DefaultDb(t *testing.T) { }, { name: "delete_3", - sql: `DELETE FROM tbltemp WHERE id=:1`, + sql: `DELETE FROM tbltemp WHERE id=:1 with SINGLEPK`, 
args: []interface{}{"3", "user"}, affectedRows: 1, }, @@ -1017,7 +1017,7 @@ func TestStmtDelete_Exec_DefaultDb(t *testing.T) { }, { name: "delete_5", - sql: `DELETE FROM tbltemp WHERE id=$1`, + sql: `DELETE FROM tbltemp WHERE id=$1 with SINGLE_PK`, args: []interface{}{"5", "user"}, affectedRows: 1, }, @@ -1082,6 +1082,115 @@ func TestStmtDelete_Exec_DefaultDb(t *testing.T) { } } +func TestStmtDelete_SubPartitions(t *testing.T) { + testName := "TestStmtDelete_Exec" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + initParams [][]interface{} + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "delete_1", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/app,/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email) VALUES (:1,:2,:3,:4)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email) VALUES (:1,:2,:3,:4)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email) VALUES (:1,:2,:3,:4)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email) VALUES (:1,:2,:3,:4)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email) VALUES (:1,:2,:3,:4)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "app", "user", "user@domain1.com", "app", "user"}, + {"2", "app", "user", "user@domain2.com", "app", "user"}, {"3", "app", "user", "user@domain3.com", "app", "user"}, + {"4", "app", "user", "user@domain4.com", "app", "user"}, {"5", "app", "user", "user@domain5.com", "app", "user"}}, + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=1`, dbname), + args: []interface{}{"app", "user"}, + affectedRows: 1, + }, + { + name: "delete_2", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id="2"`, dbname), + args: []interface{}{"app", "user"}, + affectedRows: 1, + }, + { + name: "delete_3", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=:1`, dbname), + args: []interface{}{"3", "app", "user"}, + affectedRows: 1, + }, + { + name: "delete_4", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=@1`, dbname), + args: []interface{}{"4", "app", "user"}, + affectedRows: 1, + }, + { + name: "delete_5", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=$1`, dbname), + args: []interface{}{"5", "app", "user"}, + affectedRows: 1, + }, + { + name: "row_not_exists", + sql: fmt.Sprintf(`DELETE FROM %s.tbltemp WHERE id=1`, dbname), + args: []interface{}{"app", "user"}, + affectedRows: 0, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for i, initSql := range testCase.initSqls { + var params []interface{} + if len(testCase.initParams) > i { + params = testCase.initParams[i] + } + _, err := db.Exec(initSql, params...) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) 
+ if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} + func TestStmtUpdate_Query(t *testing.T) { testName := "TestStmtUpdate_Query" db := _openDb(t, testName) From a071e2f6f9c22d8e9c31757ab25252cb3a8d0a14 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Wed, 14 Jun 2023 15:03:02 +1000 Subject: [PATCH 5/8] UPDATE statement: support subpartitions --- SQL.md | 13 ++- stmt.go | 12 +-- stmt_document.go | 55 +++++++----- stmt_document_parsing_test.go | 41 ++++++++- stmt_document_test.go | 159 +++++++++++++++++++++++++++++++--- 5 files changed, 235 insertions(+), 45 deletions(-) diff --git a/SQL.md b/SQL.md index 9460ac2..d8e0245 100644 --- a/SQL.md +++ b/SQL.md @@ -358,7 +358,9 @@ Description: delete an existing document. Syntax: ```sql -DELETE FROM [.] WHERE id= [WITH singlePK|SINGLE_PK] +DELETE FROM [.] +WHERE id= +[WITH singlePK|SINGLE_PK] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). @@ -392,11 +394,16 @@ Description: update an existing document. Syntax: ```sql -UPDATE [.] SET =[,=,...=] WHERE id= +UPDATE [.] +SET =[,=,...=] +WHERE id= +[WITH singlePK|SINGLE_PK] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). +Since [v0.3.0](RELEASE-NOTES.md), `gocosmos` supports [Hierarchical Partition Keys](https://learn.microsoft.com/en-us/azure/cosmos-db/hierarchical-partition-keys) (or sub-partitions). If the collection is known not to have sub-partitions, supplying `WITH singlePK` (or `WITH SINGLE_PK`) can save one roundtrip to Cosmos DB server. + Example: ```go sql := `UPDATE mydb.mytable SET a=1, b="\"a string\"", c=true, d="[1,true,null,\"string\"]", e=:2 WHERE id=@1` @@ -409,7 +416,7 @@ fmt.Println(dbresult.RowsAffected()) > Use `sql.DB.Exec` to execute the statement, `Query` will return error. -> Value of partition key _must_ be supplied at the last argument of `db.Exec()` call. +> Values of partition keys _must_ be supplied at the end of the argument list when invoking `db.Exec()`. - `UPDATE` modifies only one document specified by id. - Upon successful execution, `RowsAffected()` returns `(1, nil)`. If no document matched, `RowsAffected()` returns `(0, nil)`. 
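
For illustration, a minimal sketch of an `UPDATE` against a collection with a hierarchical partition key (assuming the collection was created with `WITH pk=/app,/username`; the database, collection, field names and values below are placeholders, and `db` is a `database/sql` handle opened with the `gocosmos` driver as in the examples above):

```go
// Assumed collection: CREATE COLLECTION mydb.mytable WITH pk=/app,/username
// The partition-key values ("myapp", "user1") are appended after all other arguments,
// in the same order as the partition-key paths (/app first, then /username).
sql := `UPDATE mydb.mytable SET grade=:1 WHERE id=@2`
dbresult, err := db.Exec(sql, 2.0, "01", "myapp", "user1")
if err != nil {
	panic(err)
}
fmt.Println(dbresult.RowsAffected()) // (1, nil) if the document was found and updated, (0, nil) otherwise
```

Note that the number of trailing partition-key values must match the number of partition-key paths of the collection; otherwise the statement fails with an argument-count error.
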
diff --git a/stmt.go b/stmt.go index 8a55bac..fcd6398 100644 --- a/stmt.go +++ b/stmt.go @@ -30,7 +30,7 @@ var ( reInsert = regexp.MustCompile(`(?is)^(INSERT|UPSERT)\s+INTO\s+(` + field + `\.)?` + field + `\s*\(([^)]*?)\)\s*VALUES\s*\(([^)]*?)\)` + with + `$`) reSelect = regexp.MustCompile(`(?is)^SELECT\s+(CROSS\s+PARTITION\s+)?.*?\s+FROM\s+` + field + `.*?` + with + `$`) - reUpdate = regexp.MustCompile(`(?is)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*)$`) + reUpdate = regexp.MustCompile(`(?is)^UPDATE\s+(` + field + `\.)?` + field + `\s+SET\s+(.*)\s+WHERE\s+id\s*=\s*(.*?)` + with + `$`) reDelete = regexp.MustCompile(`(?is)^DELETE\s+FROM\s+(` + field + `\.)?` + field + `\s+WHERE\s+id\s*=\s*(.*?)` + with + `$`) ) @@ -176,16 +176,18 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err if re := reUpdate; re.MatchString(query) { groups := re.FindAllStringSubmatch(query, -1) stmt := &StmtUpdate{ - Stmt: &Stmt{query: query, conn: c, numInput: 0}, - dbName: strings.TrimSpace(groups[0][2]), - collName: strings.TrimSpace(groups[0][3]), + StmtCRUD: &StmtCRUD{ + Stmt: &Stmt{query: query, conn: c, numInput: 0}, + dbName: strings.TrimSpace(groups[0][2]), + collName: strings.TrimSpace(groups[0][3]), + }, updateStr: strings.TrimSpace(groups[0][4]), idStr: strings.TrimSpace(groups[0][5]), } if stmt.dbName == "" { stmt.dbName = defaultDb } - if err := stmt.parse(); err != nil { + if err := stmt.parse(groups[0][6]); err != nil { return nil, err } return stmt, stmt.validate() diff --git a/stmt_document.go b/stmt_document.go index 3d6cc89..55fdc4b 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -100,6 +100,20 @@ func (s *StmtCRUD) fetchPkInfo() error { return normalizeError(getCollResult.StatusCode, 0, getCollResult.Error()) } +func (s *StmtCRUD) parseWithOpts(withOptsStr string) error { + if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { + return err + } + _, ok1 := s.withOpts["SINGLEPK"] + _, ok2 := s.withOpts["SINGLE_PK"] + s.isSinglePathPk = ok1 || ok2 + if s.isSinglePathPk { + s.numPkPaths = 1 + } + s.numInput = 0 + return nil +} + // StmtInsert implements "INSERT" operation. // // Syntax: @@ -135,16 +149,9 @@ func (s *StmtInsert) parse(withOptsStr string) error { if err := s.parseWithOpts(withOptsStr); err != nil { return err } - _, ok1 := s.withOpts["SINGLEPK"] - _, ok2 := s.withOpts["SINGLE_PK"] - s.isSinglePathPk = ok1 || ok2 - if s.isSinglePathPk { - s.numPkPaths = 1 - } s.fields = regexp.MustCompile(`[,\s]+`).Split(s.fieldsStr, -1) s.values = make([]interface{}, 0) - s.numInput = 0 for temp := strings.TrimSpace(s.valuesStr); temp != ""; temp = strings.TrimSpace(temp) { value, leftOver, err := _parseValue(temp, ',') if err == nil { @@ -237,14 +244,7 @@ func (s *StmtDelete) parse(withOptsStr string) error { if err := s.parseWithOpts(withOptsStr); err != nil { return err } - _, ok1 := s.withOpts["SINGLEPK"] - _, ok2 := s.withOpts["SINGLE_PK"] - s.isSinglePathPk = ok1 || ok2 - if s.isSinglePathPk { - s.numPkPaths = 1 - } - s.numInput = 0 hasPrefix := strings.HasPrefix(s.idStr, `"`) hasSuffix := strings.HasSuffix(s.idStr, `"`) if hasPrefix != hasSuffix { @@ -427,7 +427,7 @@ func (s *StmtSelect) Exec(_ []driver.Value) (driver.Result, error) { // // Syntax: // -// UPDATE . SET =[,=]*, WHERE id= +// UPDATE . SET =[,=]* WHERE id= [WITH singlePK|SINGLE_PK] // // - is treated as a string. `WHERE id=abc` has the same effect as `WHERE id="abc"`. 
// - is either: @@ -445,9 +445,7 @@ func (s *StmtSelect) Exec(_ []driver.Value) (driver.Result, error) { // // Currently UPDATE only updates one document specified by id. type StmtUpdate struct { - *Stmt - dbName string - collName string + *StmtCRUD updateStr string idStr string id interface{} @@ -512,8 +510,10 @@ func (s *StmtUpdate) _parseUpdateClause() error { return nil } -func (s *StmtUpdate) parse() error { - s.numInput = 1 +func (s *StmtUpdate) parse(withOptsStr string) error { + if err := s.parseWithOpts(withOptsStr); err != nil { + return err + } if err := s._parseId(); err != nil { return err @@ -541,8 +541,15 @@ func (s *StmtUpdate) validate() error { // Exec implements driver.Stmt/Exec. // -// Note: this function expects the _last_ argument is _partition_ key value. +// Note: this function expects the _partition key values are placed at the end_ of the argument list. func (s *StmtUpdate) Exec(args []driver.Value) (driver.Result, error) { + if err := s.fetchPkInfo(); err != nil { + return nil, err + } + if len(args) != s.numInput+s.numPkPaths { + return nil, fmt.Errorf("expected %d arguments, got %d", s.numInput+s.numPkPaths, len(args)) + } + // firstly, fetch the document id := s.idStr if s.id != nil { @@ -552,7 +559,7 @@ func (s *StmtUpdate) Exec(args []driver.Value) (driver.Result, error) { } id = fmt.Sprintf("%s", args[ph.index-1]) } - docReq := DocReq{DbName: s.dbName, CollName: s.collName, DocId: id, PartitionKeyValues: []interface{}{args[len(args)-1]}} + docReq := DocReq{DbName: s.dbName, CollName: s.collName, DocId: id, PartitionKeyValues: s.extractPkValuesFromArgs(args...)} getDocResult := s.conn.restClient.GetDocument(docReq) if err := getDocResult.Error(); err != nil { result := buildResultNoResultSet(&getDocResult.RestReponse, false, "", 0) @@ -565,8 +572,10 @@ func (s *StmtUpdate) Exec(args []driver.Value) (driver.Result, error) { } return result, result.err } + + // secondly, update the fetched document etag := getDocResult.DocInfo.Etag() - spec := DocumentSpec{DbName: s.dbName, CollName: s.collName, PartitionKeyValues: []interface{}{args[len(args)-1]}, DocumentData: getDocResult.DocInfo.RemoveSystemAttrs()} + spec := DocumentSpec{DbName: s.dbName, CollName: s.collName, PartitionKeyValues: s.extractPkValuesFromArgs(args...), DocumentData: getDocResult.DocInfo.RemoveSystemAttrs()} for i := 0; i < len(s.fields); i++ { switch s.values[i].(type) { case placeholder: diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go index cfad982..69d6ff4 100644 --- a/stmt_document_parsing_test.go +++ b/stmt_document_parsing_test.go @@ -614,7 +614,7 @@ SET a=null, b= d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" ,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]" WHERE id="abc"`, - expected: &StmtUpdate{dbName: "db1", collName: "table1", updateStr: `a=null, b= + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db1", collName: "table1"}, updateStr: `a=null, b= 1.0, c=true, d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" ,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]"`, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}, idStr: "abc"}, @@ -625,9 +625,24 @@ SET a=null, b= SET a=$1, b= $2, c=:3, d=0 WHERE id=@4`, - expected: &StmtUpdate{dbName: "db-1", collName: "table_1", updateStr: `a=$1, b= + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db-1", collName: 
"table_1"}, updateStr: `a=$1, b= $2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, }, + { + name: "singlepk", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SinglePk`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + { + name: "single_pk", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 WITH SINGLE_PK`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + { + name: "singlepk_single_pk", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SINGLE_PK, With SinglePk`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -674,7 +689,7 @@ SET a=null, b= d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" ,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]" WHERE id="abc"`, - expected: &StmtUpdate{dbName: "mydb", collName: "table1", updateStr: `a=null, b= + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table1"}, updateStr: `a=null, b= 1.0, c=true, d="\"a string 'with' \\\"quote\\\"\"", e="{\"key\":\"value\"}" ,f="[2.0,null,false,\"a string 'with' \\\"quote\\\"\"]"`, fields: []string{"a", "b", "c", "d", "e", "f"}, values: []interface{}{nil, 1.0, true, `a string 'with' "quote"`, map[string]interface{}{"key": "value"}, []interface{}{2.0, nil, false, `a string 'with' "quote"`}}, idStr: "abc"}}, @@ -685,9 +700,27 @@ SET a=null, b= SET a=$1, b= $2, c=:3, d=0 WHERE id=@4`, - expected: &StmtUpdate{dbName: "db-1", collName: "table_1", updateStr: `a=$1, b= + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db-1", collName: "table_1"}, updateStr: `a=$1, b= $2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, }, + { + name: "singlepk", + db: "mydb", + sql: `UPDATE table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SinglePk`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + { + name: "single_pk", + db: "mydb", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 WITH SINGLE_PK`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + { + name: "singlepk_single_pk", + 
db: "mydb", + sql: `UPDATE table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SINGLE_PK, With SinglePk`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { diff --git a/stmt_document_test.go b/stmt_document_test.go index 048d1b1..795cff4 100644 --- a/stmt_document_test.go +++ b/stmt_document_test.go @@ -1233,7 +1233,7 @@ func TestStmtUpdate_Exec(t *testing.T) { }, { name: "update_pk", - sql: fmt.Sprintf(`UPDATE %s.tbltemp SET username="\"user1\"" WHERE id=1`, dbname), + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET username="\"user1\"" WHERE id=1 with SinglePk`, dbname), args: []interface{}{"user1"}, affectedRows: 0, }, @@ -1245,7 +1245,7 @@ func TestStmtUpdate_Exec(t *testing.T) { }, { name: "row_not_exists", - sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=3.4 WHERE id=3`, dbname), + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=3.4 WHERE id=3 with Single_Pk`, dbname), args: []interface{}{"user"}, affectedRows: 0, }, @@ -1278,7 +1278,7 @@ func TestStmtUpdate_Exec(t *testing.T) { fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), }, initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, - sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4`, dbname), + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4 with SinglePk`, dbname), args: []interface{}{2.0, false, "a string 'with' \"quote\"", "1", "user"}, affectedRows: 1, }, @@ -1290,7 +1290,7 @@ func TestStmtUpdate_Exec(t *testing.T) { }, { name: "error_uk_placeholders", - sql: fmt.Sprintf(`UPDATE %s.tbltemp SET email=@1 WHERE id=:2`, dbname), + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET email=@1 WHERE id=:2 with Single_Pk`, dbname), args: []interface{}{"user2@domain.com", "1", "user"}, mustConflict: true, }, @@ -1308,7 +1308,7 @@ func TestStmtUpdate_Exec(t *testing.T) { }, { name: "table_not_exists_placeholders", - sql: fmt.Sprintf(`UPDATE %s.tbl_not_found SET email=:1 WHERE id=:2`, dbname), + sql: fmt.Sprintf(`UPDATE %s.tbl_not_found SET email=:1 WHERE id=:2 with SinglePk`, dbname), args: []interface{}{"user2@domain.com", "1", "user"}, mustNotFound: true, }, @@ -1388,7 +1388,7 @@ func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), }, initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, - sql: `UPDATE tbltemp SET grade=2.0,active=false,data="\"a string 'with' \\\"quote\\\"\"" WHERE id=1`, + sql: `UPDATE tbltemp SET grade=2.0,active=false,data="\"a string 'with' \\\"quote\\\"\"" WHERE id=1 with SinglePk`, args: []interface{}{"user"}, affectedRows: 1, }, @@ -1400,7 +1400,7 @@ func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { }, { name: "error_uk", - sql: `UPDATE tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, + sql: `UPDATE tbltemp SET email="\"user2@domain.com\"" WHERE id=1 with Single_Pk`, args: []interface{}{"user"}, mustConflict: true, }, @@ -1412,7 +1412,7 @@ func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { }, { 
name: "row_not_exists_in_partition", - sql: `UPDATE tbltemp SET grade=5.6 WHERE id=2`, + sql: `UPDATE tbltemp SET grade=5.6 WHERE id=2 with SinglePk`, args: []interface{}{"user2"}, affectedRows: 0, }, @@ -1433,7 +1433,7 @@ func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { fmt.Sprintf(`INSERT INTO %s.tbltemp (id,username,email,grade,active) VALUES (@1,$2,:3,$4,@5)`, dbname), }, initParams: [][]interface{}{nil, nil, nil, nil, {"1", "user", "user@domain.com", 1, true, "user"}, {"2", "user", "user2@domain.com", 1, true, "user"}}, - sql: `UPDATE tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4`, + sql: `UPDATE tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4 with SinglePk`, args: []interface{}{2.0, false, "a string 'with' \"quote\"", "1", "user"}, affectedRows: 1, }, @@ -1451,7 +1451,7 @@ func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { }, { name: "row_not_exists_placeholders", - sql: `UPDATE tbltemp SET grade=$1 WHERE id=:2`, + sql: `UPDATE tbltemp SET grade=$1 WHERE id=:2 with Single_Pk`, args: []interface{}{3.4, "3", "user"}, affectedRows: 0, }, @@ -1509,3 +1509,142 @@ func TestStmtUpdate_Exec_DefaultDb(t *testing.T) { }) } } + +func TestStmtUpdate_SubPartitions(t *testing.T) { + testName := "TestStmtUpdate_SubPartitions" + db := _openDb(t, testName) + dbname := "dbtemp" + defer db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname)) + testData := []struct { + name string + initSqls []string + initParams [][]interface{} + sql string + args []interface{} + mustConflict bool + mustNotFound bool + mustError string + affectedRows int64 + }{ + { + name: "update_1", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/app,/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email,grade,active) VALUES (@1,$2,:3,$4,@5,:6)`, dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email,grade,active) VALUES (@1,$2,:3,$4,@5,:6)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "app", "user", "user@domain.com", 1, true, "app", "user"}, + {"2", "app", "user", "user2@domain.com", 1, true, "app", "user"}}, + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=2.0,active=false,data="\"a string 'with' \\\"quote\\\"\"" WHERE id=1`, dbname), + args: []interface{}{"app", "user"}, + affectedRows: 1, + }, + { + name: "update_pk", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET username="\"user1\"" WHERE id=1`, dbname), + args: []interface{}{"app", "user1"}, + affectedRows: 0, + }, + { + name: "error_uk", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET email="\"user2@domain.com\"" WHERE id=1`, dbname), + args: []interface{}{"app", "user"}, + mustConflict: true, + }, + { + name: "row_not_exists", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=3.4 WHERE id=3`, dbname), + args: []interface{}{"app", "user"}, + affectedRows: 0, + }, + { + name: "row_not_exists_in_partition", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=5.6 WHERE id=2`, dbname), + args: []interface{}{"app", "user2"}, + affectedRows: 0, + }, + { + name: "update_1_placeholders", + initSqls: []string{ + "DROP DATABASE IF EXISTS db_not_exists", + fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbname), + fmt.Sprintf("CREATE DATABASE %s", dbname), + fmt.Sprintf("CREATE COLLECTION %s.tbltemp WITH pk=/app,/username WITH uk=/email", dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email,grade,active) VALUES (@1,$2,:3,$4,@5,:6)`, 
dbname), + fmt.Sprintf(`INSERT INTO %s.tbltemp (id,app,username,email,grade,active) VALUES (@1,$2,:3,$4,@5,:6)`, dbname), + }, + initParams: [][]interface{}{nil, nil, nil, nil, {"1", "app", "user", "user@domain.com", 1, true, "app", "user"}, + {"2", "app", "user", "user2@domain.com", 1, true, "app", "user"}}, + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=:1,active=@2,data=$3 WHERE id=:4`, dbname), + args: []interface{}{2.0, false, "a string 'with' \"quote\"", "1", "app", "user"}, + affectedRows: 1, + }, + { + name: "update_pk_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET username=$1 WHERE id=:2`, dbname), + args: []interface{}{"user1", "1", "app", "user1"}, + affectedRows: 0, + }, + { + name: "error_uk_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET email=@1 WHERE id=:2`, dbname), + args: []interface{}{"user2@domain.com", "1", "app", "user"}, + mustConflict: true, + }, + { + name: "row_not_exists_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=$1 WHERE id=:2`, dbname), + args: []interface{}{3.4, "3", "app", "user"}, + affectedRows: 0, + }, + { + name: "row_not_exists_in_partition_placeholders", + sql: fmt.Sprintf(`UPDATE %s.tbltemp SET grade=@1 WHERE id=:2`, dbname), + args: []interface{}{5.6, "2", "app", "user2"}, + affectedRows: 0, + }, + } + for _, testCase := range testData { + t.Run(testCase.name, func(t *testing.T) { + for i, initSql := range testCase.initSqls { + var params []interface{} + if len(testCase.initParams) > i { + params = testCase.initParams[i] + } + _, err := db.Exec(initSql, params...) + if err != nil { + t.Fatalf("%s failed: {error: %s / sql: %s}", testName+"/"+testCase.name+"/init", err, initSql) + } + } + execResult, err := db.Exec(testCase.sql, testCase.args...) + if testCase.mustConflict && err != ErrConflict { + t.Fatalf("%s failed: expect ErrConflict but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustNotFound && err != ErrNotFound { + t.Fatalf("%s failed: expect ErrNotFound but received %#v", testName+"/"+testCase.name+"/exec", err) + } + if testCase.mustConflict || testCase.mustNotFound { + return + } + if testCase.mustError != "" { + if err == nil || strings.Index(err.Error(), testCase.mustError) < 0 { + t.Fatalf("%s failed: expected '%s' bur received %#v", testCase.name, testCase.mustError, err) + } + return + } + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/exec", err) + } + affectedRows, err := execResult.RowsAffected() + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name+"/rows_affected", err) + } + if affectedRows != testCase.affectedRows { + t.Fatalf("%s failed: expected %#v affected-rows but received %#v", testName+"/"+testCase.name, testCase.affectedRows, affectedRows) + } + }) + } +} From d6eab98a097b65af098946d957f50fd197acd667 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Fri, 16 Jun 2023 18:27:37 +1000 Subject: [PATCH 6/8] fixes & enhancements to parsing phase --- .github/workflows/gocosmos.yaml | 2 +- SQL.md | 10 +- stmt.go | 36 +++-- stmt_collection.go | 113 ++++++++-------- stmt_collection_parsing_test.go | 8 +- stmt_database.go | 76 ++++++----- stmt_database_parsing_test.go | 6 +- stmt_document.go | 103 ++++++++++---- stmt_document_parsing_test.go | 233 +++++++++++++++++++++++++++++--- stmt_test.go | 60 ++++---- 10 files changed, 455 insertions(+), 192 deletions(-) diff --git a/.github/workflows/gocosmos.yaml b/.github/workflows/gocosmos.yaml index 78fb30b..5386acd 100644 --- a/.github/workflows/gocosmos.yaml +++ 
b/.github/workflows/gocosmos.yaml @@ -195,7 +195,7 @@ jobs: - name: Test run: | export COSMOSDB_URL="AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" - go test -cover -coverprofile="coverage_other.txt" -v -count 1 -p 1 -run "TestNew|TestStmt_NumInput|TestDriver_" . + go test -cover -coverprofile="coverage_other.txt" -v -count 1 -p 1 -run "TestNew|TestDriver_" . - name: Codecov uses: codecov/codecov-action@v3 with: diff --git a/SQL.md b/SQL.md index d8e0245..89cb73e 100644 --- a/SQL.md +++ b/SQL.md @@ -300,7 +300,7 @@ Syntax: INSERT INTO [.] (, ,...) VALUES (, ,...) -[WITH singlePK|SINGLE_PK] +[WITH singlePK|SINGLE_PK[=true]] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). @@ -346,7 +346,7 @@ Syntax & Usage: similar to [INSERT](#insert). UPSERT INTO [.] (, ,...) VALUES (, ,...) -[WITH singlePK|SINGLE_PK] +[WITH singlePK|SINGLE_PK[=true]] ``` [Back to top](#top) @@ -360,7 +360,7 @@ Syntax: ```sql DELETE FROM [.] WHERE id= -[WITH singlePK|SINGLE_PK] +[WITH singlePK|SINGLE_PK[=true]] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). @@ -397,7 +397,7 @@ Syntax: UPDATE [.] SET =[,=,...=] WHERE id= -[WITH singlePK|SINGLE_PK] +[WITH singlePK|SINGLE_PK[=true]] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). @@ -435,7 +435,7 @@ Syntax: SELECT [CROSS PARTITION] ... FROM ... [WITH database=] [[,] WITH collection=] -[[,] WITH cross_partition=true] +[[,] WITH cross_partition|CrossPartition[=true]] ``` > `` can be ommitted if `DefaultDb` is supplied in the Data Source Name (DSN). diff --git a/stmt.go b/stmt.go index fcd6398..80af664 100644 --- a/stmt.go +++ b/stmt.go @@ -14,7 +14,7 @@ const ( field = `([\w\-]+)` ifNotExists = `(\s+IF\s+NOT\s+EXISTS)?` ifExists = `(\s+IF\s+EXISTS)?` - with = `(\s+WITH\s+` + field + `(\s*=\s*([\w/\.\*,;:'"-]+))?((\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `(\s*=\s*([\w/\.\*,;:'"-]+))*)?)?` + with = `(\s+WITH\s+.*)?` ) var ( @@ -46,9 +46,8 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err Stmt: &Stmt{query: query, conn: c, numInput: 0}, dbName: strings.TrimSpace(groups[0][2]), ifNotExists: strings.TrimSpace(groups[0][1]) != "", - withOptsStr: strings.TrimSpace(groups[0][3]), } - if err := stmt.parse(); err != nil { + if err := stmt.parse(strings.TrimSpace(groups[0][3])); err != nil { return nil, err } return stmt, stmt.validate() @@ -56,11 +55,10 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err if re := reAlterDb; re.MatchString(query) { groups := re.FindAllStringSubmatch(query, -1) stmt := &StmtAlterDatabase{ - Stmt: &Stmt{query: query, conn: c, numInput: 0}, - dbName: strings.TrimSpace(groups[0][1]), - withOptsStr: strings.TrimSpace(groups[0][2]), + Stmt: &Stmt{query: query, conn: c, numInput: 0}, + dbName: strings.TrimSpace(groups[0][1]), } - if err := stmt.parse(); err != nil { + if err := stmt.parse(strings.TrimSpace(groups[0][2])); err != nil { return nil, err } return stmt, stmt.validate() @@ -88,12 +86,11 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err ifNotExists: strings.TrimSpace(groups[0][2]) != "", dbName: strings.TrimSpace(groups[0][4]), collName: strings.TrimSpace(groups[0][5]), - withOptsStr: strings.TrimSpace(groups[0][6]), } if stmt.dbName == "" { stmt.dbName = defaultDb } - if err := stmt.parse(); err != nil { + if err := 
stmt.parse(strings.TrimSpace(groups[0][6])); err != nil { return nil, err } return stmt, stmt.validate() @@ -101,15 +98,14 @@ func parseQueryWithDefaultDb(c *Conn, defaultDb, query string) (driver.Stmt, err if re := reAlterColl; re.MatchString(query) { groups := re.FindAllStringSubmatch(query, -1) stmt := &StmtAlterCollection{ - Stmt: &Stmt{query: query, conn: c, numInput: 0}, - dbName: strings.TrimSpace(groups[0][3]), - collName: strings.TrimSpace(groups[0][4]), - withOptsStr: strings.TrimSpace(groups[0][5]), + Stmt: &Stmt{query: query, conn: c, numInput: 0}, + dbName: strings.TrimSpace(groups[0][3]), + collName: strings.TrimSpace(groups[0][4]), } if stmt.dbName == "" { stmt.dbName = defaultDb } - if err := stmt.parse(); err != nil { + if err := stmt.parse(strings.TrimSpace(groups[0][5])); err != nil { return nil, err } return stmt, stmt.validate() @@ -222,6 +218,18 @@ type Stmt struct { withOpts map[string]string } +func (s *Stmt) onlyOneWithOption(errmsg string, optKeys ...string) error { + ok := false + for _, k := range optKeys { + _, exist := s.withOpts[k] + if ok && exist { + return fmt.Errorf(errmsg) + } + ok = ok || exist + } + return nil +} + var reWithOpts = regexp.MustCompile(`(?is)^(\s+|\s*,\s+|\s+,\s*)WITH\s+` + field + `(\s*=\s*([\w/\.\*,;:'"-]+))?`) // parseWithOpts parses "WITH..." clause and store result in withOpts map. diff --git a/stmt_collection.go b/stmt_collection.go index 8c92c6e..60006e3 100644 --- a/stmt_collection.go +++ b/stmt_collection.go @@ -33,49 +33,40 @@ type StmtCreateCollection struct { ru, maxru int pk string // partition key uk [][]string // unique keys - withOptsStr string } -func (s *StmtCreateCollection) parse() error { - if err := s.Stmt.parseWithOpts(s.withOptsStr); err != nil { +func (s *StmtCreateCollection) parse(withOptsStr string) error { + if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { return err } - // partition key - pk, largepk := s.withOpts["PK"], s.withOpts["LARGEPK"] - if pk != "" && largepk != "" && pk != largepk { - return fmt.Errorf("only one of PK or LARGEPK must be specified") - } - s.pk = s.withOpts["PK"] - if s.pk == "" { - s.pk = s.withOpts["LARGEPK"] - } - if s.pk == "" { - return fmt.Errorf("missting PartitionKey value") - } - - // request unit - if _, ok := s.withOpts["RU"]; ok { - ru, err := strconv.ParseInt(s.withOpts["RU"], 10, 64) - if err != nil || ru < 0 { - return fmt.Errorf("invalid RU value: %s", s.withOpts["RU"]) - } - s.ru = int(ru) - } - if _, ok := s.withOpts["MAXRU"]; ok { - maxru, err := strconv.ParseInt(s.withOpts["MAXRU"], 10, 64) - if err != nil || maxru < 0 { - return fmt.Errorf("invalid MAXRU value: %s", s.withOpts["MAXRU"]) - } - s.maxru = int(maxru) - } - - // unique key - if ukOpts, ok := s.withOpts["UK"]; ok && ukOpts != "" { - tokens := regexp.MustCompile(`[;:]+`).Split(ukOpts, -1) - for _, token := range tokens { - paths := regexp.MustCompile(`[,\s]+`).Split(token, -1) - s.uk = append(s.uk, paths) + for k, v := range s.withOpts { + switch k { + case "PK", "LARGEPK": + if s.pk != "" { + return fmt.Errorf("only one of PK or LARGEPK must be specified") + } + s.pk = v + case "RU": + ru, err := strconv.ParseInt(v, 10, 64) + if err != nil || ru < 0 { + return fmt.Errorf("invalid RU value: %s", v) + } + s.ru = int(ru) + case "MAXRU": + maxru, err := strconv.ParseInt(v, 10, 64) + if err != nil || maxru < 0 { + return fmt.Errorf("invalid MAXRU value: %s", v) + } + s.maxru = int(maxru) + case "UK": + tokens := regexp.MustCompile(`[;:]+`).Split(v, -1) + for _, token := range tokens { + paths := 
regexp.MustCompile(`[,\s]+`).Split(token, -1) + s.uk = append(s.uk, paths) + } + default: + return fmt.Errorf("invalid query, parsing error at WITH %s=%s", k, v) } } @@ -83,8 +74,11 @@ func (s *StmtCreateCollection) parse() error { } func (s *StmtCreateCollection) validate() error { + if s.pk == "" { + return fmt.Errorf("missing PartitionKey value") + } if s.ru > 0 && s.maxru > 0 { - return errors.New("only one of RU or MAXRU must be specified") + return errors.New("only one of RU or MAXRU should be specified") } if s.dbName == "" || s.collName == "" { return errors.New("database/collection is missing") @@ -142,30 +136,33 @@ func (s *StmtCreateCollection) Exec(_ []driver.Value) (driver.Result, error) { // Available since v0.1.1 type StmtAlterCollection struct { *Stmt - dbName string - collName string // collection name - ru, maxru int - withOptsStr string + dbName string + collName string // collection name + ru, maxru int } -func (s *StmtAlterCollection) parse() error { - if err := s.Stmt.parseWithOpts(s.withOptsStr); err != nil { +func (s *StmtAlterCollection) parse(withOptsStr string) error { + if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { return err } - if _, ok := s.withOpts["RU"]; ok { - ru, err := strconv.ParseInt(s.withOpts["RU"], 10, 64) - if err != nil || ru < 0 { - return fmt.Errorf("invalid RU value: %s", s.withOpts["RU"]) - } - s.ru = int(ru) - } - if _, ok := s.withOpts["MAXRU"]; ok { - maxru, err := strconv.ParseInt(s.withOpts["MAXRU"], 10, 64) - if err != nil || maxru < 0 { - return fmt.Errorf("invalid MAXRU value: %s", s.withOpts["MAXRU"]) + for k, v := range s.withOpts { + switch k { + case "RU": + ru, err := strconv.ParseInt(v, 10, 64) + if err != nil || ru < 0 { + return fmt.Errorf("invalid RU value: %s", v) + } + s.ru = int(ru) + case "MAXRU": + maxru, err := strconv.ParseInt(v, 10, 64) + if err != nil || maxru < 0 { + return fmt.Errorf("invalid MAXRU value: %s", v) + } + s.maxru = int(maxru) + default: + return fmt.Errorf("invalid query, parsing error at WITH %s=%s", k, v) } - s.maxru = int(maxru) } return nil @@ -173,7 +170,7 @@ func (s *StmtAlterCollection) parse() error { func (s *StmtAlterCollection) validate() error { if (s.ru <= 0 && s.maxru <= 0) || (s.ru > 0 && s.maxru > 0) { - return errors.New("only one of RU or MAXRU must be specified") + return errors.New("only one of RU or MAXRU should be specified") } if s.dbName == "" || s.collName == "" { return errors.New("database/collection is missing") diff --git a/stmt_collection_parsing_test.go b/stmt_collection_parsing_test.go index 8c33e86..f69095b 100644 --- a/stmt_collection_parsing_test.go +++ b/stmt_collection_parsing_test.go @@ -24,6 +24,7 @@ func TestStmtCreateCollection_parse(t *testing.T) { {name: "error_invalid_maxru2", sql: "CREATE COLLECTION db.table WITH pk=/id WITH maxru=-1", mustError: true}, {name: "error_no_collection", sql: "CREATE TABLE db WITH pk=/id", mustError: true}, {name: "error_if_not_exist", sql: "CREATE TABLE IF NOT EXIST db.table WITH pk=/id", mustError: true}, + {name: "error_invalid_with", sql: "CREATE TABLE db.table WITH pk=/id, WITH a=1", mustError: true}, {name: "basic", sql: "CREATE COLLECTION db1.table1 WITH pk=/id", expected: &StmtCreateCollection{dbName: "db1", collName: "table1", pk: "/id"}}, {name: "table_with_ru", sql: "create\ntable\rdb-2.table_2 WITH\tPK=/email WITH\r\nru=100", expected: &StmtCreateCollection{dbName: "db-2", collName: "table_2", pk: "/email", ru: 100}}, @@ -48,7 +49,6 @@ func TestStmtCreateCollection_parse(t *testing.T) { t.Fatalf("%s 
failed: expected StmtCreateCollection but received %T", testName+"/"+testCase.name, s) } stmt.Stmt = nil - stmt.withOptsStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) } @@ -66,6 +66,7 @@ func TestStmtCreateCollection_parse_defaultDb(t *testing.T) { mustError bool }{ {name: "error_invalid_query", db: "mydb", sql: "CREATE TABLE .mytable WITH pk=/id", mustError: true}, + {name: "error_invalid_with", db: "mydb", sql: "CREATE TABLE mytable WITH pk=/id WITH a", mustError: true}, {name: "basic", db: "mydb", sql: "CREATE COLLECTION table1 WITH pk=/id", expected: &StmtCreateCollection{dbName: "mydb", collName: "table1", pk: "/id"}}, {name: "db_in_query", db: "mydb", sql: "create\ntable\r\ndb2.table_2 WITH\r\t\nPK=/email WITH\nru=100", expected: &StmtCreateCollection{dbName: "db2", collName: "table_2", pk: "/email", ru: 100}}, @@ -90,7 +91,6 @@ func TestStmtCreateCollection_parse_defaultDb(t *testing.T) { t.Fatalf("%s failed: expected StmtCreateCollection but received %T", testName+"/"+testCase.name, s) } stmt.Stmt = nil - stmt.withOptsStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) } @@ -112,6 +112,7 @@ func TestStmtAlterCollection_parse(t *testing.T) { {name: "error_ru_and_maxru", sql: "alter TABLE db.coll WITH ru=400 WITH maxru=4000", mustError: true}, {name: "error_invalid_ru", sql: "alter TABLE db.coll WITH ru=-1", mustError: true}, {name: "error_invalid_maxru", sql: "alter TABLE db.coll WITH maxru=-1", mustError: true}, + {name: "error_invalid_with", sql: "alter TABLE db.coll WITH ru=400, WITH a=1", mustError: true}, {name: "basic", sql: "ALTER collection db1.table1 WITH ru=400", expected: &StmtAlterCollection{dbName: "db1", collName: "table1", ru: 400}}, {name: "table", sql: "alter\nTABLE\rdb-2.table_2 WITH\tmaxru=40000", expected: &StmtAlterCollection{dbName: "db-2", collName: "table_2", maxru: 40000}}, @@ -133,7 +134,6 @@ func TestStmtAlterCollection_parse(t *testing.T) { t.Fatalf("%s failed: expected StmtAlterCollection but received %T", testName+"/"+testCase.name, s) } stmt.Stmt = nil - stmt.withOptsStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) } @@ -153,6 +153,7 @@ func TestStmtAlterCollection_parse_defaultDb(t *testing.T) { {name: "error_invalid_query", db: "mydb", sql: "ALTER COLLECTION .mytable WITH ru=400", mustError: true}, {name: "error_notable", db: "mydb", sql: "ALTER COLLECTION mydb. 
WITH ru=400", mustError: true}, {name: "error_no_db_table", db: "mydb", sql: "ALTER COLLECTION WITH ru=400", mustError: true}, + {name: "error_invalid_with", db: "mydb", sql: "ALTER COLLECTION mytable WITH ru=400 WITH a", mustError: true}, {name: "basic", db: "mydb", sql: "ALTER collection table1 WITH ru=400", expected: &StmtAlterCollection{dbName: "mydb", collName: "table1", ru: 400}}, {name: "db_in_query", db: "mydb", sql: "alter\nTABLE\rdb-2.table_2 WITH\tmaxru=40000", expected: &StmtAlterCollection{dbName: "db-2", collName: "table_2", maxru: 40000}}, @@ -174,7 +175,6 @@ func TestStmtAlterCollection_parse_defaultDb(t *testing.T) { t.Fatalf("%s failed: expected StmtAlterCollection but received %T", testName+"/"+testCase.name, s) } stmt.Stmt = nil - stmt.withOptsStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) } diff --git a/stmt_database.go b/stmt_database.go index cf15992..6e8b07b 100644 --- a/stmt_database.go +++ b/stmt_database.go @@ -21,33 +21,38 @@ type StmtCreateDatabase struct { dbName string ifNotExists bool ru, maxru int - withOptsStr string } -func (s *StmtCreateDatabase) parse() error { - if err := s.Stmt.parseWithOpts(s.withOptsStr); err != nil { +func (s *StmtCreateDatabase) parse(withOptsStr string) error { + if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { return err } - if _, ok := s.withOpts["RU"]; ok { - ru, err := strconv.ParseInt(s.withOpts["RU"], 10, 64) - if err != nil || ru < 0 { - return fmt.Errorf("invalid RU value: %s", s.withOpts["RU"]) - } - s.ru = int(ru) - } - if _, ok := s.withOpts["MAXRU"]; ok { - maxru, err := strconv.ParseInt(s.withOpts["MAXRU"], 10, 64) - if err != nil || maxru < 0 { - return fmt.Errorf("invalid MAXRU value: %s", s.withOpts["MAXRU"]) + + for k, v := range s.withOpts { + switch k { + case "RU": + ru, err := strconv.ParseInt(v, 10, 64) + if err != nil || ru < 0 { + return fmt.Errorf("invalid RU value: %s", v) + } + s.ru = int(ru) + case "MAXRU": + maxru, err := strconv.ParseInt(v, 10, 64) + if err != nil || maxru < 0 { + return fmt.Errorf("invalid RU value: %s", v) + } + s.maxru = int(maxru) + default: + return fmt.Errorf("invalid query, parsing error at WITH %s=%s", k, v) } - s.maxru = int(maxru) } + return nil } func (s *StmtCreateDatabase) validate() error { if s.ru > 0 && s.maxru > 0 { - return errors.New("only one of RU or MAXRU must be specified") + return errors.New("only one of RU or MAXRU should be specified") } return nil } @@ -82,29 +87,34 @@ func (s *StmtCreateDatabase) Exec(_ []driver.Value) (driver.Result, error) { // Available since v0.1.1 type StmtAlterDatabase struct { *Stmt - dbName string - ru, maxru int - withOptsStr string + dbName string + ru, maxru int } -func (s *StmtAlterDatabase) parse() error { - if err := s.Stmt.parseWithOpts(s.withOptsStr); err != nil { +func (s *StmtAlterDatabase) parse(withOptsStr string) error { + if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { return err } - if _, ok := s.withOpts["RU"]; ok { - ru, err := strconv.ParseInt(s.withOpts["RU"], 10, 64) - if err != nil || ru < 0 { - return fmt.Errorf("invalid RU value: %s", s.withOpts["RU"]) - } - s.ru = int(ru) - } - if _, ok := s.withOpts["MAXRU"]; ok { - maxru, err := strconv.ParseInt(s.withOpts["MAXRU"], 10, 64) - if err != nil || maxru < 0 { - return fmt.Errorf("invalid MAXRU value: %s", s.withOpts["MAXRU"]) + + for k, v := range s.withOpts { + switch k { + case "RU": + ru, err := strconv.ParseInt(v, 10, 64) + 
if err != nil || ru < 0 { + return fmt.Errorf("invalid RU value: %s", v) + } + s.ru = int(ru) + case "MAXRU": + maxru, err := strconv.ParseInt(v, 10, 64) + if err != nil || maxru < 0 { + return fmt.Errorf("invalid RU value: %s", v) + } + s.maxru = int(maxru) + default: + return fmt.Errorf("invalid query, parsing error at WITH %s=%s", k, v) } - s.maxru = int(maxru) } + return nil } diff --git a/stmt_database_parsing_test.go b/stmt_database_parsing_test.go index c55221b..bc64d20 100644 --- a/stmt_database_parsing_test.go +++ b/stmt_database_parsing_test.go @@ -19,6 +19,8 @@ func TestStmtCreateDatabase_parse(t *testing.T) { {name: "error_if_exists", sql: "CREATE DATABASE if exists db0", mustError: true}, {name: "error_if_not_exist", sql: "CREATE DATABASE IF NOT EXIST db0", mustError: true}, {name: "error_ru_and_maxru", sql: "CREATE DATABASE db0 with RU=400, WITH MAXru=4000", mustError: true}, + {name: "error_invalid_with", sql: "CREATE DATABASE db0 with a", mustError: true}, + {name: "error_invalid_with2", sql: "CREATE DATABASE db0 with a=1", mustError: true}, {name: "basic", sql: "CREATE DATABASE db1", expected: &StmtCreateDatabase{dbName: "db1"}}, {name: "with_ru", sql: "create\ndatabase\n db-2 \nWITH \n ru=100", expected: &StmtCreateDatabase{dbName: "db-2", ru: 100}}, @@ -44,7 +46,6 @@ func TestStmtCreateDatabase_parse(t *testing.T) { t.Fatalf("%s failed: expected StmtCreateDatabase but received %T", testName+"/"+testCase.name, s) } stmt.Stmt = nil - stmt.withOptsStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) } @@ -62,6 +63,8 @@ func TestStmtAlterDatabase_parse(t *testing.T) { }{ {name: "error_no_ru_maxru", sql: "ALTER database db0", mustError: true}, {name: "error_ru_and_maxru", sql: "ALTER database db0 WITH RU=400, WITH maxRU=4000", mustError: true}, + {name: "error_invalid_with", sql: "ALTER database db0 WITH RU=400, WITH a", mustError: true}, + {name: "error_invalid_with2", sql: "ALTER database db0 WITH RU=400 WITH a=1", mustError: true}, {name: "with_ru", sql: "ALTER\rdatabase\ndb1\tWITH ru=400", expected: &StmtAlterDatabase{dbName: "db1", ru: 400}}, {name: "with_maxru", sql: "alter DATABASE db-1 with maxru=4000", expected: &StmtAlterDatabase{dbName: "db-1", maxru: 4000}}, @@ -83,7 +86,6 @@ func TestStmtAlterDatabase_parse(t *testing.T) { t.Fatalf("%s failed: expected StmtAlterDatabase but received %T", testName+"/"+testCase.name, s) } stmt.Stmt = nil - stmt.withOptsStr = "" if !reflect.DeepEqual(stmt, testCase.expected) { t.Fatalf("%s failed:\nexpected %#v\nreceived %#v", testName+"/"+testCase.name, testCase.expected, stmt) } diff --git a/stmt_document.go b/stmt_document.go index 55fdc4b..3b7edf2 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -104,9 +104,26 @@ func (s *StmtCRUD) parseWithOpts(withOptsStr string) error { if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { return err } - _, ok1 := s.withOpts["SINGLEPK"] - _, ok2 := s.withOpts["SINGLE_PK"] - s.isSinglePathPk = ok1 || ok2 + + if err := s.onlyOneWithOption("single PK path is specified more than once, only one of SINGLE_PK or SINGLEPK should be specified", "SINGLE_PK", "SINGLEPK"); err != nil { + return err + } + + for k, v := range s.withOpts { + switch k { + case "SINGLE_PK", "SINGLEPK": + if v == "" { + s.isSinglePathPk = true + } else { + val, err := strconv.ParseBool(v) + if err != nil || !val { + return fmt.Errorf("invalid value at WITH %s (only value 'true' is accepted)", k) + } + 
s.isSinglePathPk = true + } + } + } + if s.isSinglePathPk { s.numPkPaths = 1 } @@ -118,7 +135,10 @@ func (s *StmtCRUD) parseWithOpts(withOptsStr string) error { // // Syntax: // -// INSERT|UPSERT INTO . () VALUES () [WITH singlePK|SINGLE_PK] +// INSERT|UPSERT INTO . +// () +// VALUES () +// [WITH singlePK|SINGLE_PK[=true]] // // - values are comma separated. // - a value is either: @@ -150,6 +170,12 @@ func (s *StmtInsert) parse(withOptsStr string) error { return err } + for k := range s.withOpts { + if k != "SINGLE_PK" && k != "SINGLEPK" { + return fmt.Errorf("invalid query, parsing error at WITH %s", k) + } + } + s.fields = regexp.MustCompile(`[,\s]+`).Split(s.fieldsStr, -1) s.values = make([]interface{}, 0) for temp := strings.TrimSpace(s.valuesStr); temp != ""; temp = strings.TrimSpace(temp) { @@ -170,7 +196,7 @@ func (s *StmtInsert) parse(withOptsStr string) error { func (s *StmtInsert) validate() error { if len(s.fields) != len(s.values) { - return fmt.Errorf("number of field (%d) does not match number of input value (%d)", len(s.fields), len(s.values)) + return fmt.Errorf("number of fields (%d) does not match number of input values (%d)", len(s.fields), len(s.values)) } if s.dbName == "" || s.collName == "" { return errors.New("database/collection is missing") @@ -229,7 +255,9 @@ func (s *StmtInsert) Query(_ []driver.Value) (driver.Rows, error) { // // Syntax: // -// DELETE FROM . WHERE id= [WITH singlePK|SINGLE_PK] +// DELETE FROM . +// WHERE id= +// [WITH singlePK|SINGLE_PK[=true]] // // - Currently DELETE only removes one document specified by id. // @@ -245,6 +273,12 @@ func (s *StmtDelete) parse(withOptsStr string) error { return err } + for k := range s.withOpts { + if k != "SINGLE_PK" && k != "SINGLEPK" { + return fmt.Errorf("invalid query, parsing error at WITH %s", k) + } + } + hasPrefix := strings.HasPrefix(s.idStr, `"`) hasSuffix := strings.HasSuffix(s.idStr, `"`) if hasPrefix != hasSuffix { @@ -324,7 +358,7 @@ func (s *StmtDelete) Query(_ []driver.Value) (driver.Rows, error) { // SELECT [CROSS PARTITION] ... FROM ... // WITH database|db= // [WITH collection|table=] -// [WITH cross_partition=true] +// [WITH cross_partition|CrossPartition[=true]] // // - (extension) If the collection is partitioned, specify "CROSS PARTITION" to allow execution across multiple partitions. // This clause is not required if query is to be executed on a single partition. 
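// For illustration, a usage sketch of a cross-partition query issued through the database/sql API
// (the database, collection and field names below are illustrative placeholders):
//
//	dbRows, err := db.Query(`SELECT CROSS PARTITION * FROM c WHERE c.grade>=@1 WITH db=mydb WITH table=mytable`, 7.0)
//
// Supplying "WITH cross_partition=true" (or "WITH CrossPartition") is an equivalent way to enable
// cross-partition execution when the "CROSS PARTITION" keyword is omitted.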
@@ -343,25 +377,39 @@ type StmtSelect struct { } func (s *StmtSelect) parse(withOptsStr string) error { - if err := s.Stmt.parseWithOpts(withOptsStr); err != nil { + if err := s.parseWithOpts(withOptsStr); err != nil { return err } - if v, ok := s.withOpts["DATABASE"]; ok { - s.dbName = strings.TrimSpace(v) - } else if v, ok := s.withOpts["DB"]; ok { - s.dbName = strings.TrimSpace(v) - } - if v, ok := s.withOpts["COLLECTION"]; ok { - s.collName = strings.TrimSpace(v) - } else if v, ok := s.withOpts["TABLE"]; ok { - s.collName = strings.TrimSpace(v) + + if err := s.onlyOneWithOption("database is specified more than once, only one of DATABASE or DB should be specified", "DATABASE", "DB"); err != nil { + return err } - if v, ok := s.withOpts["CROSS_PARTITION"]; ok && !s.isCrossPartition { - vbool, err := strconv.ParseBool(v) - if err != nil || !vbool { - return errors.New("cannot parse query (the only accepted value for cross_partition is true), invalid token at: " + v) + + for k, v := range s.withOpts { + switch k { + case "DATABASE", "DB": + s.dbName = v + case "COLLECTION", "TABLE": + if s.collName != "" && s.collName != "c" && s.collName != "C" { + return errors.New("collection is specified more than once, only one of COLLECTION or TABLE should be specified") + } + s.collName = v + case "CROSS_PARTITION", "CROSSPARTITION": + if s.isCrossPartition { + return fmt.Errorf("cross-partition is specified more than once, only one of CROSS_PARTITION or CrossPartition should be specified") + } + if v == "" { + s.isCrossPartition = true + } else { + val, err := strconv.ParseBool(v) + if err != nil || !val { + return fmt.Errorf("invalid value at WITH %s (only value 'true' is accepted)", k) + } + s.isCrossPartition = true + } + default: + return fmt.Errorf("invalid query, parsing error at WITH %s", k) } - s.isCrossPartition = true } matches := reValPlaceholder.FindAllStringSubmatch(s.selectQuery, -1) @@ -427,7 +475,10 @@ func (s *StmtSelect) Exec(_ []driver.Value) (driver.Result, error) { // // Syntax: // -// UPDATE . SET =[,=]* WHERE id= [WITH singlePK|SINGLE_PK] +// UPDATE . +// SET =[,=]* +// WHERE id= +// [WITH singlePK|SINGLE_PK[=true]] // // - is treated as a string. `WHERE id=abc` has the same effect as `WHERE id="abc"`. 
// - is either: @@ -515,6 +566,12 @@ func (s *StmtUpdate) parse(withOptsStr string) error { return err } + for k := range s.withOpts { + if k != "SINGLE_PK" && k != "SINGLEPK" { + return fmt.Errorf("invalid query, parsing error at WITH %s", k) + } + } + if err := s._parseId(); err != nil { return err } diff --git a/stmt_document_parsing_test.go b/stmt_document_parsing_test.go index 69d6ff4..a0032b9 100644 --- a/stmt_document_parsing_test.go +++ b/stmt_document_parsing_test.go @@ -22,6 +22,7 @@ func TestStmtInsert_parse(t *testing.T) { {name: "error_num_values_not_matched", sql: `INSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, {name: "error_invalid_number", sql: `INSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, {name: "error_invalid_string", sql: `INSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, + {name: "error_invalid_with", sql: `INSERT INTO db.table (a,b) VALUES (1,2) WITH a`, mustError: true}, { name: "basic", @@ -51,9 +52,29 @@ $1, :3, @2)`, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{2}, 3.0}}, }, { - name: "singlepk_single_pk", - sql: `INSERT INTO db.table (a,b,c) VALUES (1,2,@1) WITH singlePK, with SINGLE_PK`, - expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, placeholder{1}}}, + name: "singlepk2", + sql: `INSERT INTO db.table (a,b,c) VALUES (1,2,3) WITH singlePK=true`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, 3.0}}, + }, + { + name: "single_pk2", + sql: `INSERT INTO db.table (a,b,c) VALUES (:1,$2,3) WITH SINGLE_PK=true`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{2}, 3.0}}, + }, + { + name: "error_singlepk_single_pk", + sql: `INSERT INTO db.table (a,b,c) VALUES (1,2,@1) WITH singlePK, with SINGLE_PK`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + sql: `INSERT INTO db.table (a,b,c) VALUES (1,2,3) WITH singlePK=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + sql: `INSERT INTO db.table (a,b,c) VALUES (:1,$2,3) WITH SINGLE_PK=error`, + mustError: true, }, } for _, testCase := range testData { @@ -93,6 +114,7 @@ func TestStmtInsert_parse_defaultDb(t *testing.T) { }{ {name: "error_invalid_query", sql: `INSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, {name: "error_invalid_query2", sql: `INSERT INTO db. 
(a,b) VALUES (1,2)`, mustError: true}, + {name: "error_invalid_with", db: "mydb", sql: `INSERT INTO table (a,b) VALUES (1,2) WITH a=1`, mustError: true}, { name: "basic", @@ -119,6 +141,12 @@ $1, :3, @2)`, sql: `INSERT INTO table (a,b,c) VALUES (1,2,3) WITH singlePK`, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, 3.0}}, }, + { + name: "singlepk2", + db: "mydb", + sql: `INSERT INTO table (a,b,c) VALUES (1,2,3) WITH singlePK=true`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, 3.0}}, + }, { name: "single_pk", db: "mydb", @@ -126,10 +154,28 @@ $1, :3, @2)`, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{2}, 3.0}}, }, { - name: "singlepk_single_pk", + name: "single_pk2", db: "mydb", - sql: `INSERT INTO table (a,b,c) VALUES (1,2,@1) WITH singlePK, with SINGLE_PK`, - expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{1.0, 2.0, placeholder{1}}}, + sql: `INSERT INTO db.table (a,b,c) VALUES (:1,$2,3) WITH SINGLE_PK=true`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{2}, 3.0}}, + }, + { + name: "error_singlepk_single_pk", + db: "mydb", + sql: `INSERT INTO table (a,b,c) VALUES (1,2,@1) WITH singlePK, with SINGLE_PK`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + db: "mydb", + sql: `INSERT INTO table (a,b,c) VALUES (1,2,3) WITH singlePK=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + db: "mydb", + sql: `INSERT INTO db.table (a,b,c) VALUES (:1,$2,3) WITH SINGLE_PK=error`, + mustError: true, }, } for _, testCase := range testData { @@ -175,6 +221,7 @@ func TestStmtUpsert_parse(t *testing.T) { {name: "error_num_values_not_matched", sql: `UPSERT INTO db.table (a,b) VALUES (1,2,3)`, mustError: true}, {name: "error_invalid_number", sql: `UPSERT INTO db.table (a,b) VALUES (0x1qa,2)`, mustError: true}, {name: "error_invalid_string", sql: `UPSERT INTO db.table (a,b) VALUES ("cannot \\"unquote",2)`, mustError: true}, + {name: "error_invalid_with", sql: `UPSERT INTO db.table (a,b) VALUES (1,2) WITH a`, mustError: true}, { name: "basic", @@ -205,10 +252,30 @@ a,b,c) VALUES ($1, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, }, { - name: "singlepk_single_pk", - sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH SINGLE_PK, with singlePK`, + name: "singlepk2", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH singlePK=true`, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, }, + { + name: "single_pk2", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH SINGLE_PK=true`, + expected: 
&StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "error_singlepk_single_pk", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH SINGLE_PK, with singlePK`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH SINGLE_PK=error`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + sql: `UPSERT INTO db.table (a,b,c) VALUES (:1, :3, :2) WITH singlePK=false`, + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -247,6 +314,7 @@ func TestStmtUpsert_parse_defaultDb(t *testing.T) { }{ {name: "error_invalid_query", sql: `UPSERT INTO .table (a,b) VALUES (1,2)`, mustError: true}, {name: "error_invalid_query2", sql: `UPSERT INTO db. (a,b) VALUES (1,2)`, mustError: true}, + {name: "error_invalid_with", db: "mydb", sql: `UPSERT INTO table (a,b) VALUES (1,2) WITH a=1`, mustError: true}, { name: "basic", @@ -281,11 +349,35 @@ a,b,c) VALUES ($1, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, }, { - name: "singlepk_single_pk", + name: "singlepk2", db: "mydb", - sql: `UPSERT INTO db.table (a,b,c) VALUES ($1, :3, @2) WITH single_pk WITH singlePK`, + sql: `UPSERT INTO db.table (a,b,c) VALUES ($1, :3, @2) WITH SINGLEPK=true`, expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, }, + { + name: "single_pk2", + db: "mydb", + sql: `UPSERT INTO table (a,b,c) VALUES ($1, :3, @2) WITH single_pk=true`, + expected: &StmtInsert{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, isUpsert: true, fields: []string{"a", "b", "c"}, values: []interface{}{placeholder{1}, placeholder{3}, placeholder{2}}}, + }, + { + name: "error_singlepk_single_pk", + db: "mydb", + sql: `UPSERT INTO db.table (a,b,c) VALUES ($1, :3, @2) WITH single_pk WITH singlePK`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + db: "mydb", + sql: `UPSERT INTO db.table (a,b,c) VALUES ($1, :3, @2) WITH SINGLEPK=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + db: "mydb", + sql: `UPSERT INTO table (a,b,c) VALUES ($1, :3, @2) WITH single_pk=error`, + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -329,6 +421,7 @@ func TestStmtDelete_parse(t *testing.T) { {name: "error_invalid_where", sql: `DELETE FROM db.table WHERE id=@1 a`, mustError: true}, {name: "error_invalid_where2", sql: `DELETE FROM db.table WHERE id=b $2`, mustError: true}, {name: "error_invalid_where3", sql: `DELETE FROM db.table WHERE id=c :3 d`, mustError: true}, + {name: "error_invalid_with", sql: `DELETE FROM db.table WHERE id=1 WITH a`, mustError: true}, { name: "basic", @@ -363,10 +456,30 @@ db_3-0.table-3_0 WHERE expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, }, { - name: "singlepk_single_pk", - sql: `DELETE FROM db.table WHERE id=@2 with SinglePK WITH SINGLE_PK`, + 
name: "singlepk2", + sql: `DELETE FROM db.table WHERE id=@2 WITH singlePK=true`, expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, }, + { + name: "single_pk", + sql: `DELETE FROM db.table WHERE id=@2 with Single_PK=true`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "error_singlepk_single_pk", + sql: `DELETE FROM db.table WHERE id=@2 with SinglePK WITH SINGLE_PK`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + sql: `DELETE FROM db.table WHERE id=@2 WITH singlePK=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + sql: `DELETE FROM db.table WHERE id=@2 with Single_PK=error`, + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -403,6 +516,7 @@ func TestStmtDelete_parse_defaultDb(t *testing.T) { }{ {name: "error_invalid_query", sql: `DELETE FROM .table WHERE id=1`, mustError: true}, {name: "error_invalid_query2", sql: `DELETE FROM db. WHERE id=1`, mustError: true}, + {name: "error_invalid_with", db: "mydb", sql: `DELETE FROM table WHERE id=1 WITH a=1`, mustError: true}, { name: "basic", @@ -442,11 +556,35 @@ db_3-0.table-3_0 WHERE expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, }, { - name: "singlepk_single_pk", + name: "singlepk2", db: "mydb", - sql: `DELETE FROM table WHERE id=@2 With single_Pk, With SinglePK`, + sql: `DELETE FROM table WHERE id=@2 With singlePk=true`, expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, }, + { + name: "single_pk2", + db: "mydb", + sql: `DELETE FROM db.table WHERE id=@2 With single_Pk=true`, + expected: &StmtDelete{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, idStr: "@2", id: placeholder{2}}, + }, + { + name: "error_singlepk_single_pk", + db: "mydb", + sql: `DELETE FROM table WHERE id=@2 With single_Pk, With SinglePK`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + db: "mydb", + sql: `DELETE FROM table WHERE id=@2 With singlePk=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + db: "mydb", + sql: `DELETE FROM db.table WHERE id=@2 With single_Pk=error`, + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -484,6 +622,11 @@ func TestStmtSelect_parse(t *testing.T) { {name: "error_no_collection", sql: `SELECT * WITH db=dbname`, mustError: true}, {name: "error_no_db", sql: `SELECT * FROM c WITH collection=collname`, mustError: true}, {name: "error_cross_partition_must_be_true", sql: `SELECT * FROM c WITH db=dbname WITH collection=collname WITH cross_partition=false`, mustError: true}, + {name: "error_cross_partition_must_be_true2", sql: `SELECT * FROM c WITH db=dbname WITH collection=collname WITH cross_partition=error`, mustError: true}, + {name: "error_cross_partition_more_than_once", sql: `SELECT * FROM c WITH db=dbname WITH collection=collname WITH cross_partition WITH CrossPartition=true`, mustError: true}, + {name: "error_cross_partition_more_than_once2", sql: `SELECT CROSS PARTITION * FROM c WITH db=dbname WITH collection=collname WITH CrossPartition`, mustError: true}, + {name: "error_invalid_with", sql: `SELECT * FROM c 
WITH db=dbname WITH collection=collname WITH a`, mustError: true}, + {name: "error_invalid_with2", sql: `SELECT * FROM c WITH db=dbname WITH collection=collname WITH a=1`, mustError: true}, { name: "basic", @@ -502,8 +645,8 @@ func TestStmtSelect_parse(t *testing.T) { }, { name: "collection_in_query", - sql: `SELECT a,b,c FROM user u WHERE u.id="1" WITH db=dbtemp`, - expected: &StmtSelect{dbName: "dbtemp", collName: "user", selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`, placeholders: map[int]string{}}, + sql: `SELECT a,b,c FROM user u WHERE u.id="1" WITH db=dbtemp WITH CrossPartition`, + expected: &StmtSelect{dbName: "dbtemp", collName: "user", isCrossPartition: true, selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`, placeholders: map[int]string{}}, }, } for _, testCase := range testData { @@ -560,8 +703,8 @@ func TestStmtSelect_parse_defaultDb(t *testing.T) { { name: "collection_in_query", db: "mydb", - sql: `SELECT a,b,c FROM user u WHERE u.id="1"`, - expected: &StmtSelect{dbName: "mydb", collName: "user", selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`, placeholders: map[int]string{}}, + sql: `SELECT a,b,c FROM user u WHERE u.id="1" with CrossPartition`, + expected: &StmtSelect{dbName: "mydb", collName: "user", isCrossPartition: true, selectQuery: `SELECT a,b,c FROM user u WHERE u.id="1"`, placeholders: map[int]string{}}, }, } for _, testCase := range testData { @@ -605,6 +748,7 @@ func TestStmtUpdate_parse(t *testing.T) { {name: "error_invalid_query", sql: `UPDATE db.table SET =1 WHERE id=2`, mustError: true}, {name: "error_invalid_query2", sql: `UPDATE db.table SET a=1 WHERE id= `, mustError: true}, {name: "error_invalid_query3", sql: `UPDATE db.table SET a=1,b=2,c=3 WHERE id="4`, mustError: true}, + {name: "error_invalid_with", sql: `UPDATE db.table SET a=1,b=2,c=3 WHERE id=4 WITH a`, mustError: true}, { name: "basic", @@ -639,10 +783,30 @@ SET a=$1, b= expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, }, { - name: "singlepk_single_pk", - sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SINGLE_PK, With SinglePk`, + name: "singlepk2", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SinglePk=true`, expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, }, + { + name: "single_pk2", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 WITH SINGLE_PK=true`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + { + name: "error_singlepk_single_pk", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SINGLE_PK, With SinglePk`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SinglePk=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 
WHERE id=@4 WITH SINGLE_PK=error`, + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { @@ -679,6 +843,7 @@ func TestStmtUpdate_parse_defaultDb(t *testing.T) { }{ {name: "error_invalid_query", sql: `UPDATE .table SET a=1,b=2,c=3 WHERE id=4`, mustError: true}, {name: "error_invalid_query2", sql: `UPDATE db. SET a=1,b=2,c=3 WHERE id=4`, mustError: true}, + {name: "error_invalid_with", db: "mydb", sql: `UPDATE table SET a=1,b=2,c=3 WHERE id=4 WITH a=1`, mustError: true}, { name: "basic", @@ -716,11 +881,35 @@ SET a=$1, b= expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, }, { - name: "singlepk_single_pk", + name: "singlepk2", db: "mydb", - sql: `UPDATE table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SINGLE_PK, With SinglePk`, + sql: `UPDATE table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SinglePk=true`, expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "mydb", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, }, + { + name: "single_pk2", + db: "mydb", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 WITH SINGLE_PK=true`, + expected: &StmtUpdate{StmtCRUD: &StmtCRUD{dbName: "db", collName: "table", isSinglePathPk: true, numPkPaths: 1}, updateStr: `a=$1, b=$2, c=:3, d=0`, fields: []string{"a", "b", "c", "d"}, values: []interface{}{placeholder{1}, placeholder{2}, placeholder{3}, 0.0}, idStr: "@4", id: placeholder{4}}, + }, + { + name: "error_singlepk_single_pk", + db: "mydb", + sql: `UPDATE table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SINGLE_PK, With SinglePk`, + mustError: true, + }, + { + name: "error_invalid_singlepk", + db: "mydb", + sql: `UPDATE table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 with SinglePk=false`, + mustError: true, + }, + { + name: "error_invalid_single_pk", + db: "mydb", + sql: `UPDATE db.table SET a=$1, b=$2, c=:3, d=0 WHERE id=@4 WITH SINGLE_PK=error`, + mustError: true, + }, } for _, testCase := range testData { t.Run(testCase.name, func(t *testing.T) { diff --git a/stmt_test.go b/stmt_test.go index 609f904..54772d3 100644 --- a/stmt_test.go +++ b/stmt_test.go @@ -55,33 +55,33 @@ func _fetchAllRows(dbRows *sql.Rows) ([]map[string]interface{}, error) { return rows, nil } -func TestStmt_NumInput(t *testing.T) { - name := "TestStmt_NumInput" - testData := map[string]int{ - "CREATE DATABASE dbtemp": 0, - "DROP DATABASE dbtemp": 0, - "CREATE DATABASE IF NOT EXISTS dbtemp": 0, - "DROP DATABASE IF EXISTS dbtemp": 0, - - "CREATE TABLE db.tbltemp WITH pk=/id": 0, - "DROP TABLE db.tbltemp": 0, - "CREATE TABLE IF NOT EXISTS db.tbltemp WITH pk=/id": 0, - "DROP TABLE IF EXISTS db.tbltemp": 0, - "CREATE COLLECTION db.tbltemp WITH pk=/id": 0, - "DROP COLLECTION db.tbltemp": 0, - "CREATE COLLECTION IF NOT EXISTS db.tbltemp WITH pk=/id": 0, - "DROP COLLECTION IF EXISTS db.tbltemp": 0, - - "SELECT * FROM tbltemp WHERE id=@1 AND email=$2 OR username=:3 WITH db=mydb": 3, - "INSERT INTO db.tbltemp (id, name, email) VALUES ($1, :2, @3)": 3 + 1, // need one extra input for partition key - "DELETE FROM db.tbltemp WHERE id=$1": 1 + 1, // need one extra input for partition key - } - - for query, 
numInput := range testData { - if stmt, err := parseQuery(nil, query); err != nil { - t.Fatalf("%s failed: %s", name+"/"+query, err) - } else if v := stmt.NumInput(); v != numInput { - t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, numInput, v) - } - } -} +// func TestStmt_NumInput(t *testing.T) { +// name := "TestStmt_NumInput" +// testData := map[string]int{ +// "CREATE DATABASE dbtemp": 0, +// "DROP DATABASE dbtemp": 0, +// "CREATE DATABASE IF NOT EXISTS dbtemp": 0, +// "DROP DATABASE IF EXISTS dbtemp": 0, +// +// "CREATE TABLE db.tbltemp WITH pk=/id": 0, +// "DROP TABLE db.tbltemp": 0, +// "CREATE TABLE IF NOT EXISTS db.tbltemp WITH pk=/id": 0, +// "DROP TABLE IF EXISTS db.tbltemp": 0, +// "CREATE COLLECTION db.tbltemp WITH pk=/id": 0, +// "DROP COLLECTION db.tbltemp": 0, +// "CREATE COLLECTION IF NOT EXISTS db.tbltemp WITH pk=/id": 0, +// "DROP COLLECTION IF EXISTS db.tbltemp": 0, +// +// "SELECT * FROM tbltemp WHERE id=@1 AND email=$2 OR username=:3 WITH db=mydb": 3, +// "INSERT INTO db.tbltemp (id, name, email) VALUES ($1, :2, @3)": 3 + 1, // need one extra input for partition key +// "DELETE FROM db.tbltemp WHERE id=$1": 1 + 1, // need one extra input for partition key +// } +// +// for query, numInput := range testData { +// if stmt, err := parseQuery(nil, query); err != nil { +// t.Fatalf("%s failed: %s", name+"/"+query, err) +// } else if v := stmt.NumInput(); v != numInput { +// t.Fatalf("%s failed: expected %#v but received %#v", name+"/"+query, numInput, v) +// } +// } +// } From fd8824e3360009bd1a03ba1a868ab810d74d891a Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Fri, 16 Jun 2023 23:37:48 +1000 Subject: [PATCH 7/8] validation fix --- stmt_document.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stmt_document.go b/stmt_document.go index 3b7edf2..7013546 100644 --- a/stmt_document.go +++ b/stmt_document.go @@ -384,15 +384,15 @@ func (s *StmtSelect) parse(withOptsStr string) error { if err := s.onlyOneWithOption("database is specified more than once, only one of DATABASE or DB should be specified", "DATABASE", "DB"); err != nil { return err } + if err := s.onlyOneWithOption("collection is specified more than once, only one of COLLECTION or TABLE should be specified", "COLLECTION", "TABLE"); err != nil { + return err + } for k, v := range s.withOpts { switch k { case "DATABASE", "DB": s.dbName = v case "COLLECTION", "TABLE": - if s.collName != "" && s.collName != "c" && s.collName != "C" { - return errors.New("collection is specified more than once, only one of COLLECTION or TABLE should be specified") - } s.collName = v case "CROSS_PARTITION", "CROSSPARTITION": if s.isCrossPartition { From 6facdde8cbc68454e1cdef8666f7844c14f0f556 Mon Sep 17 00:00:00 2001 From: Thanh Nguyen Date: Sat, 17 Jun 2023 10:23:31 +1000 Subject: [PATCH 8/8] prepare to release v0.3.0 --- README.md | 2 +- RELEASE-NOTES.md | 2 +- data_test.go | 61 ++++++++++++++++++++++ gocosmos.go | 2 +- stmt_document_select_test.go | 98 ++++++++++++++++++++++++++++++++++++ 5 files changed, 162 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9873ffc..6be28ec 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Summary of supported SQL statements: |Delete an existing database |`DROP DATABASE [IF EXISTS] `| |List all existing databases |`LIST DATABASES`| |Create a new collection |`CREATE COLLECTION [IF NOT EXISTS] [.] `| -|Change collection's throughput |`ALTER COLLECTION [.] 
WITH RU/MAXRU=`| +|Change collection's throughput |`ALTER COLLECTION [.] WITH RU|MAXRU=`| |Delete an existing collection |`DROP COLLECTION [IF EXISTS] [.]`| |List all existing collections in a database|`LIST COLLECTIONS [FROM ]`| |Insert a new document into collection |`INSERT INTO [.] ...`| diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f8f1c00..bff768a 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,6 +1,6 @@ # gocosmos - Release notes -## 2003-06-0x - v0.3.0 +## 2023-06-16 - v0.3.0 - Change default API version to `2020-07-15`. - Add [Hierarchical Partition Keys](https://learn.microsoft.com/en-us/azure/cosmos-db/hierarchical-partition-keys) (sub-partitions) support. diff --git a/data_test.go b/data_test.go index 116b17d..6fa411b 100644 --- a/data_test.go +++ b/data_test.go @@ -14,11 +14,72 @@ import ( /*======================================================================*/ +const numApps = 4 const numLogicalPartitions = 16 const numCategories = 19 var dataList []DocInfo +func _initDataSubPartitions(t *testing.T, testName string, client *RestClient, db, container string, numItem int) { + totalRu := 0.0 + randList := make([]int, numItem) + for i := 0; i < numItem; i++ { + randList[i] = i*2 + 1 + } + rand.Shuffle(numItem, func(i, j int) { + randList[i], randList[j] = randList[j], randList[i] + }) + dataList = make([]DocInfo, numItem) + for i := 0; i < numItem; i++ { + category := randList[i] % numCategories + app := "app" + strconv.Itoa(i%numApps) + username := "user" + strconv.Itoa(i%numLogicalPartitions) + docInfo := DocInfo{ + "id": fmt.Sprintf("%05d", i), + "app": app, + "username": username, + "email": "user" + strconv.Itoa(i) + "@domain.com", + "grade": float64(randList[i]), + "category": float64(category), + "active": i%10 == 0, + "big": fmt.Sprintf("%05d", i) + "/" + strings.Repeat("this is a very long string/", 256), + } + dataList[i] = docInfo + if result := client.CreateDocument(DocumentSpec{DbName: db, CollName: container, PartitionKeyValues: []interface{}{app, username}, DocumentData: docInfo}); result.Error() != nil { + t.Fatalf("%s failed: %s", testName, result.Error()) + } else { + totalRu += result.RequestCharge + } + } + // fmt.Printf("\t%s - total RU charged: %0.3f\n", testName+"/Insert", totalRu) +} + +func _initDataSubPartitionsSmallRU(t *testing.T, testName string, client *RestClient, db, container string, numItem int) { + client.DeleteDatabase(db) + client.CreateDatabase(DatabaseSpec{Id: db, Ru: 400}) + client.CreateCollection(CollectionSpec{ + DbName: db, + CollName: container, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + Ru: 400, + }) + _initDataSubPartitions(t, testName, client, db, container, numItem) +} + +func _initDataSubPartitionsLargeRU(t *testing.T, testName string, client *RestClient, db, container string, numItem int) { + client.DeleteDatabase(db) + client.CreateDatabase(DatabaseSpec{Id: db, Ru: 20000}) + client.CreateCollection(CollectionSpec{ + DbName: db, + CollName: container, + PartitionKeyInfo: map[string]interface{}{"paths": []string{"/app", "/username"}, "kind": "MultiHash", "version": 2}, + UniqueKeyPolicy: map[string]interface{}{"uniqueKeys": []map[string]interface{}{{"paths": []string{"/email"}}}}, + Ru: 20000, + }) + _initDataSubPartitions(t, testName, client, db, container, numItem) +} + func _initData(t *testing.T, testName string, 
client *RestClient, db, container string, numItem int) { totalRu := 0.0 randList := make([]int, numItem) diff --git a/gocosmos.go b/gocosmos.go index ee41912..0fbc91f 100644 --- a/gocosmos.go +++ b/gocosmos.go @@ -7,7 +7,7 @@ import ( const ( // Version of package gocosmos. - Version = "0.2.1" + Version = "0.3.0" ) func goTypeToCosmosDbType(typ reflect.Type) string { diff --git a/stmt_document_select_test.go b/stmt_document_select_test.go index dda1399..e41f573 100644 --- a/stmt_document_select_test.go +++ b/stmt_document_select_test.go @@ -25,6 +25,104 @@ func TestStmtSelect_Exec(t *testing.T) { /*----------------------------------------------------------------------*/ +func _testSelectPkValueSubPartitions(t *testing.T, testName string, db *sql.DB, collname string) { + low, high := 123, 987 + lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high) + countPerPartition := _countPerPartition(low, high, dataList) + distinctPerPartition := _distinctPerPartition(low, high, dataList, "category") + var testCases = []queryTestCase{ + {name: "NoLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true"}, + {name: "OffsetLimit_Bare", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5}, + {name: "NoLimit_OrderAsc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.grade WITH collection=%s WITH cross_partition=true", orderType: reddo.TypeInt, orderField: "grade", orderDirection: "asc"}, + {name: "OffsetLimit_OrderDesc", query: "SELECT * FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC OFFSET 3 LIMIT 5 WITH collection=%s WITH cross_partition=true", expectedNumItems: 5, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, + + {name: "NoLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1}, + {name: "NoLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1}, + {name: "OffsetLimit_DistinctValue", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: 1, expectedNumItems: 3}, + {name: "OffsetLimit_DistinctDoc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, expectedNumItems: 3}, + + {name: "NoLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category WITH collection=%s WITH cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc"}, + {name: "NoLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc"}, + {name: "OffsetLimit_DistinctValue_OrderAsc", query: "SELECT DISTINCT VALUE c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH 
cross_partition=true", distinctQuery: 1, orderType: reddo.TypeInt, orderField: "$1", orderDirection: "asc", expectedNumItems: 3}, + {name: "OffsetLimit_DistinctDoc_OrderDesc", query: "SELECT DISTINCT c.category FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 ORDER BY c.category DESC OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", distinctQuery: -1, orderType: reddo.TypeInt, orderField: "category", orderDirection: "desc", expectedNumItems: 3}, + + /* GROUP BY with ORDER BY is not supported! */ + {name: "NoLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "count"}, + {name: "OffsetLimit_GroupByCount", query: "SELECT c.category AS 'Category', count(1) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "count"}, + {name: "NoLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "sum"}, + {name: "OffsetLimit_GroupBySum", query: "SELECT c.category AS 'Category', sum(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "sum"}, + {name: "NoLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "min"}, + {name: "OffsetLimit_GroupByMin", query: "SELECT c.category AS 'Category', min(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "min"}, + {name: "NoLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "max"}, + {name: "OffsetLimit_GroupByMax", query: "SELECT c.category AS 'Category', max(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "max"}, + {name: "NoLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category WITH collection=%s WITH cross_partition=true", groupByAggr: "average"}, + {name: "OffsetLimit_GroupByAvg", query: "SELECT c.category AS 'Category', avg(c.grade) AS 'Value' FROM c WHERE $1<=c.id AND c.id<@2 AND c.username=:3 GROUP BY c.category OFFSET 1 LIMIT 3 WITH collection=%s WITH cross_partition=true", expectedNumItems: 3, groupByAggr: "average"}, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + savedExpectedNumItems := testCase.expectedNumItems + for i := 0; i < numLogicalPartitions; i++ { + testCase.expectedNumItems = savedExpectedNumItems + expectedNumItems := testCase.expectedNumItems + username := "user" + strconv.Itoa(i) + params := []interface{}{lowStr, highStr, username} + if expectedNumItems <= 0 && testCase.maxItemCount <= 0 { + expectedNumItems = 
countPerPartition[username] + if testCase.distinctQuery != 0 { + expectedNumItems = distinctPerPartition[username] + } + testCase.expectedNumItems = expectedNumItems + } + sql := fmt.Sprintf(testCase.query, collname) + dbRows, err := db.Query(sql, params...) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + rows, err := _fetchAllRows(dbRows) + if err != nil { + t.Fatalf("%s failed: %s", testName+"/"+testCase.name, err) + } + _verifyResult(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, expectedNumItems, rows) + _verifyDistinct(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, rows) + _verifyOrderBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, rows) + _verifyGroupBy(func(msg string) { t.Fatal(msg) }, testName+"/"+testCase.name+"/pk="+username, testCase, username, lowStr, highStr, rows) + } + }) + } +} + +func TestStmtSelect_Query_PkValue_SubPartitions_SmallRU(t *testing.T) { + testName := "TestStmtSelect_Query_PkValue_SmallRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataSubPartitionsSmallRU(t, testName, client, dbname, collname, 1000) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count != 1 { + t.Fatalf("%s failed: expected to be %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectPkValueSubPartitions(t, testName, db, collname) +} + +func TestStmtSelect_Query_PkValue_SubPartitions_LargeRU(t *testing.T) { + testName := "TestStmtSelect_Query_PkValue_LargeRU" + dbname := testDb + collname := testTable + client := _newRestClient(t, testName) + _initDataSubPartitionsLargeRU(t, testName, client, dbname, collname, 1000) + if result := client.GetPkranges(dbname, collname); result.Error() != nil { + t.Fatalf("%s failed: %s", testName+"/GetPkranges", result.Error()) + } else if result.Count < 2 { + t.Fatalf("%s failed: expected to be larger than %#v but received %#v", testName+"/GetPkranges", 1, result.Count) + } + db := _openDefaultDb(t, testName, dbname) + _testSelectPkValueSubPartitions(t, testName, db, collname) +} + +/*----------------------------------------------------------------------*/ + func _testSelectPkValue(t *testing.T, testName string, db *sql.DB, collname string) { low, high := 123, 987 lowStr, highStr := fmt.Sprintf("%05d", low), fmt.Sprintf("%05d", high)