diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 755b3158d3..2096f22433 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -122,5 +122,8 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (d *DynamoDB) Close() error { - return d.authProvider.Close() + if d.authProvider != nil { + return d.authProvider.Close() + } + return nil } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 7ede7ba245..bf684f8bbb 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -266,7 +266,10 @@ func (a *AWSKinesis) Close() error { close(a.closeCh) } a.wg.Wait() - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) { diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 13f8730e78..fa20c70a6b 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -153,7 +153,10 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { } func (s *AWSS3) Close() error { - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } func (s *AWSS3) Operations() []bindings.OperationKind { diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 4cd752bac5..b8d2ff3faa 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -176,5 +176,8 @@ func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMa } func (a *AWSSES) Close() error { - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 55e3ccefa5..5464f1f044 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -128,5 +128,8 @@ func (a *AWSSNS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (a *AWSSNS) Close() error { - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index d803bafc5a..b09fde61f6 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -173,7 +173,10 @@ func (a *AWSSQS) Close() error { close(a.closeCh) } a.wg.Wait() - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) { diff --git a/bindings/postgres/metadata.go b/bindings/postgres/metadata.go index b4747c33ff..33eae83f58 100644 --- a/bindings/postgres/metadata.go +++ b/bindings/postgres/metadata.go @@ -14,6 +14,7 @@ limitations under the License. package postgres import ( + "errors" "time" "github.com/dapr/components-contrib/common/authentication/aws" @@ -53,5 +54,9 @@ func (m *psqlMetadata) InitWithMetadata(meta map[string]string) error { return err } + if m.Timeout < 1*time.Second { + return errors.New("invalid value for 'timeout': must be greater than 1s") + } + return nil } diff --git a/bindings/postgres/metadata_test.go b/bindings/postgres/metadata_test.go new file mode 100644 index 0000000000..ece5e433ed --- /dev/null +++ b/bindings/postgres/metadata_test.go @@ -0,0 +1,88 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package postgres + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMetadata(t *testing.T) { + t.Run("missing connection string", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{} + + err := m.InitWithMetadata(props) + require.Error(t, err) + require.ErrorContains(t, err, "connection string") + }) + + t.Run("has connection string", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + }) + + t.Run("default timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + assert.Equal(t, 20*time.Second, m.Timeout) + }) + + t.Run("invalid timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "NaN", + } + + err := m.InitWithMetadata(props) + require.Error(t, err) + }) + + t.Run("positive timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "42", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + assert.Equal(t, 42*time.Second, m.Timeout) + }) + + t.Run("zero timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "0", + } + + err := m.InitWithMetadata(props) + require.Error(t, err) + }) +} diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go index c9dc6bcfbe..e6dce08e0d 100644 --- a/bindings/postgres/postgres.go +++ b/bindings/postgres/postgres.go @@ -73,11 +73,20 @@ func (p *Postgres) Init(ctx context.Context, meta bindings.Metadata) error { // This context doesn't control the lifetime of the connection pool, and is // only scoped to postgres creating resources at init. - p.db, err = pgxpool.NewWithConfig(ctx, poolConfig) + connCtx, connCancel := context.WithTimeout(ctx, m.Timeout) + defer connCancel() + p.db, err = pgxpool.NewWithConfig(connCtx, poolConfig) if err != nil { return fmt.Errorf("unable to connect to the DB: %w", err) } + pingCtx, pingCancel := context.WithTimeout(ctx, m.Timeout) + defer pingCancel() + err = p.db.Ping(pingCtx) + if err != nil { + return fmt.Errorf("failed to ping the DB: %w", err) + } + return nil } diff --git a/bindings/postgres/postgres_test.go b/bindings/postgres/postgres_test.go index c24fc099fb..6a517fcd6a 100644 --- a/bindings/postgres/postgres_test.go +++ b/bindings/postgres/postgres_test.go @@ -15,6 +15,7 @@ package postgres import ( "context" + "errors" "fmt" "os" "testing" @@ -62,6 +63,10 @@ func TestPostgresIntegration(t *testing.T) { t.SkipNow() } + t.Run("Test init configurations", func(t *testing.T) { + testInitConfiguration(t, url) + }) + // live DB test b := NewPostgres(logger.NewLogger("test")).(*Postgres) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"connectionString": url}}} @@ -131,6 +136,46 @@ func TestPostgresIntegration(t *testing.T) { }) } +// testInitConfiguration tests valid and invalid config settings. +func testInitConfiguration(t *testing.T, connectionString string) { + logger := logger.NewLogger("test") + tests := []struct { + name string + props map[string]string + expectedErr error + }{ + { + name: "Empty", + props: map[string]string{}, + expectedErr: errors.New("missing connection string"), + }, + { + name: "Valid connection string", + props: map[string]string{"connectionString": connectionString}, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPostgres(logger).(*Postgres) + defer p.Close() + + metadata := bindings.Metadata{ + Base: metadata.Base{Properties: tt.props}, + } + + err := p.Init(context.Background(), metadata) + if tt.expectedErr == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.Equal(t, tt.expectedErr, err) + } + }) + } +} + func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) { require.NoError(t, err) assert.NotNil(t, res) diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index cb1bafdeb3..52af56d271 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -96,6 +96,7 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return GetConfig(opts) }(), clients: newClients(), + closeCh: make(chan struct{}), } if err := auth.getCertPEM(ctx); err != nil { diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 93481fb733..4e2371764b 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -875,7 +875,10 @@ func (s *snsSqs) Close() error { s.subscriptionManager.Close() } - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } func (s *snsSqs) Features() []pubsub.Feature { diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index abf9c6c4de..038399b30c 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -182,5 +182,8 @@ func (s *ssmSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataM } func (s *ssmSecretStore) Close() error { - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 6faf1f1eab..979739be5b 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -170,5 +170,8 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index ae4ba7c5e9..d3bbd39a85 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -275,7 +275,10 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { - return d.authProvider.Close() + if d.authProvider != nil { + return d.authProvider.Close() + } + return nil } func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata, error) { diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 1bd5472ef9..112043d637 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "github.com/dapr/components-contrib/contenttype" + "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -528,7 +530,18 @@ func (m *MongoDB) doTransaction(sessCtx mongo.SessionContext, operations []state var err error switch req := o.(type) { case state.SetRequest: - err = m.setInternal(sessCtx, &req) + { + isJSON := (len(req.Metadata) > 0 && req.Metadata[metadata.ContentType] == contenttype.JSONContentType) + if isJSON { + if bytes, ok := req.Value.([]byte); ok { + err = json.Unmarshal(bytes, &req.Value) + if err != nil { + break + } + } + } + err = m.setInternal(sessCtx, &req) + } case state.DeleteRequest: err = m.deleteInternal(sessCtx, &req) } diff --git a/state/postgresql/v2/postgresql.go b/state/postgresql/v2/postgresql.go index d323ca5c90..a0f44ec043 100644 --- a/state/postgresql/v2/postgresql.go +++ b/state/postgresql/v2/postgresql.go @@ -99,16 +99,16 @@ func (p *PostgreSQL) Init(ctx context.Context, meta state.Metadata) (err error) } connCtx, connCancel := context.WithTimeout(ctx, p.metadata.Timeout) + defer connCancel() p.db, err = pgxpool.NewWithConfig(connCtx, config) - connCancel() if err != nil { err = fmt.Errorf("failed to connect to the database: %w", err) return err } pingCtx, pingCancel := context.WithTimeout(ctx, p.metadata.Timeout) + defer pingCancel() err = p.db.Ping(pingCtx) - pingCancel() if err != nil { err = fmt.Errorf("failed to ping the database: %w", err) return err diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 25de650cf9..cfb94dbfbb 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -15,6 +15,7 @@ package state import ( "context" + "encoding/base64" "encoding/json" "fmt" "slices" @@ -784,6 +785,70 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.Equal(t, v, res.Data) } }) + + t.Run("transaction-serialization-grpc-json", func(t *testing.T) { + features := statestore.Features() + // this check for exclude redis 7 + if state.FeatureQueryAPI.IsPresent(features) { + json := "{\"id\":1223,\"name\":\"test\"}" + keyTest1 := key + "-key-grpc" + valueTest := []byte(json) + keyTest2 := key + "-key-grpc-no-json" + + metadataTest1 := map[string]string{ + "contentType": "application/json", + } + + operations := []state.TransactionalStateOperation{ + state.SetRequest{ + Key: keyTest1, + Value: valueTest, + Metadata: metadataTest1, + }, + state.SetRequest{ + Key: keyTest2, + Value: valueTest, + }, + } + + expected := map[string][]byte{ + keyTest1: []byte(json), + keyTest2: []byte(json), + } + + expectedMetadata := map[string]map[string]string{ + keyTest1: metadataTest1, + } + + // Act + transactionStore, ok := statestore.(state.TransactionalStore) + assert.True(t, ok) + err := transactionStore.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + require.NoError(t, err) + + // Assert + for k, v := range expected { + res, err := statestore.Get(context.Background(), &state.GetRequest{ + Key: k, + Metadata: expectedMetadata[k], + }) + expectedValue := res.Data + + // In redisjson when set the value with contentType = application/Json store the value in base64 + if strings.HasPrefix(string(expectedValue), "\"ey") { + valueBase64 := strings.Trim(string(expectedValue), "\"") + expectedValueDecoded, _ := base64.StdEncoding.DecodeString(valueBase64) + require.NoError(t, err) + assert.Equal(t, expectedValueDecoded, v) + } else { + require.NoError(t, err) + assert.Equal(t, expectedValue, v) + } + } + } + }) } else { t.Run("component does not implement TransactionalStore interface", func(t *testing.T) { _, ok := statestore.(state.TransactionalStore)