diff --git a/cmd/config.go b/cmd/config.go index 621bd7d..3796019 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -18,7 +18,7 @@ import ( "github.com/xataio/pgstream/pkg/wal/processor/search/opensearch" "github.com/xataio/pgstream/pkg/wal/processor/translator" "github.com/xataio/pgstream/pkg/wal/processor/webhook/notifier" - "github.com/xataio/pgstream/pkg/wal/processor/webhook/server" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/server" pgreplication "github.com/xataio/pgstream/pkg/wal/replication/postgres" ) @@ -181,7 +181,11 @@ func parseWebhookProcessorConfig() *stream.WebhookProcessorConfig { } return &stream.WebhookProcessorConfig{ - SubscriptionStoreURL: subscriptionStore, + SubscriptionStore: stream.WebhookSubscriptionStoreConfig{ + URL: subscriptionStore, + CacheEnabled: viper.GetBool("PGSTREAM_WEBHOOK_SUBSCRIPTION_STORE_CACHE_ENABLED"), + CacheRefreshInterval: viper.GetDuration("PGSTREAM_WEBHOOK_SUBSCRIPTION_STORE_CACHE_REFRESH_INTERVAL"), + }, Notifier: notifier.Config{ MaxQueueBytes: viper.GetInt64("PGSTREAM_WEBHOOK_NOTIFIER_MAX_QUEUE_BYTES"), URLWorkerCount: viper.GetUint("PGSTREAM_WEBHOOK_NOTIFIER_WORKER_COUNT"), diff --git a/pg2webhook.env b/pg2webhook.env index 9a0d36c..3b752ac 100644 --- a/pg2webhook.env +++ b/pg2webhook.env @@ -3,3 +3,5 @@ PGSTREAM_POSTGRES_LISTENER_URL="postgres://postgres:postgres@localhost?sslmode=d # Processor config PGSTREAM_WEBHOOK_SUBSCRIPTION_STORE_URL="postgres://postgres:postgres@localhost?sslmode=disable" +PGSTREAM_WEBHOOK_SUBSCRIPTION_STORE_CACHE_ENABLED=true +PGSTREAM_WEBHOOK_SUBSCRIPTION_STORE_CACHE_REFRESH_INTERVAL="60s" diff --git a/pkg/stream/config.go b/pkg/stream/config.go index a301855..925e1ba 100644 --- a/pkg/stream/config.go +++ b/pkg/stream/config.go @@ -4,6 +4,7 @@ package stream import ( "errors" + "time" kafkacheckpoint "github.com/xataio/pgstream/pkg/wal/checkpointer/kafka" kafkalistener "github.com/xataio/pgstream/pkg/wal/listener/kafka" @@ -12,7 +13,7 @@ import ( "github.com/xataio/pgstream/pkg/wal/processor/search/opensearch" "github.com/xataio/pgstream/pkg/wal/processor/translator" "github.com/xataio/pgstream/pkg/wal/processor/webhook/notifier" - "github.com/xataio/pgstream/pkg/wal/processor/webhook/server" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/server" pgreplication "github.com/xataio/pgstream/pkg/wal/replication/postgres" ) @@ -53,9 +54,15 @@ type SearchProcessorConfig struct { } type WebhookProcessorConfig struct { - SubscriptionStoreURL string - Notifier notifier.Config - SubscriptionServer server.Config + Notifier notifier.Config + SubscriptionServer server.Config + SubscriptionStore WebhookSubscriptionStoreConfig +} + +type WebhookSubscriptionStoreConfig struct { + URL string + CacheEnabled bool + CacheRefreshInterval time.Duration } func (c *Config) IsValid() error { diff --git a/pkg/stream/stream_start.go b/pkg/stream/stream_start.go index fedecc4..a8ce3e0 100644 --- a/pkg/stream/stream_start.go +++ b/pkg/stream/stream_start.go @@ -19,10 +19,12 @@ import ( "github.com/xataio/pgstream/pkg/wal/processor/search" "github.com/xataio/pgstream/pkg/wal/processor/search/opensearch" "github.com/xataio/pgstream/pkg/wal/processor/translator" - "github.com/xataio/pgstream/pkg/wal/processor/webhook" webhooknotifier "github.com/xataio/pgstream/pkg/wal/processor/webhook/notifier" - pgwebhook "github.com/xataio/pgstream/pkg/wal/processor/webhook/postgres" - webhookserver "github.com/xataio/pgstream/pkg/wal/processor/webhook/server" + subscriptionserver "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/server" + webhookstore "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store" + subscriptionstorecache "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store/cache" + pgwebhook "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store/postgres" + "github.com/xataio/pgstream/pkg/wal/replication" replicationinstrumentation "github.com/xataio/pgstream/pkg/wal/replication/instrumentation" pgreplication "github.com/xataio/pgstream/pkg/wal/replication/postgres" @@ -133,15 +135,28 @@ func Start(ctx context.Context, logger loglib.Logger, config *Config, meter metr }) case config.Processor.Webhook != nil: - var subscriptionStore webhook.SubscriptionStore + var subscriptionStore webhookstore.Store var err error subscriptionStore, err = pgwebhook.NewSubscriptionStore(ctx, - config.Processor.Webhook.SubscriptionStoreURL, + config.Processor.Webhook.SubscriptionStore.URL, pgwebhook.WithLogger(logger), ) if err != nil { return err } + + if config.Processor.Webhook.SubscriptionStore.CacheEnabled { + logger.Info("setting up subscription store cache...") + subscriptionStore, err = subscriptionstorecache.New(ctx, subscriptionStore, + &subscriptionstorecache.Config{ + SyncInterval: config.Processor.Webhook.SubscriptionStore.CacheRefreshInterval, + }, + subscriptionstorecache.WithLogger(logger)) + if err != nil { + return err + } + } + notifier := webhooknotifier.New( &config.Processor.Webhook.Notifier, subscriptionStore, @@ -150,10 +165,10 @@ func Start(ctx context.Context, logger loglib.Logger, config *Config, meter metr defer notifier.Close() processor = notifier - subscriptionServer := webhookserver.New( + subscriptionServer := subscriptionserver.New( &config.Processor.Webhook.SubscriptionServer, subscriptionStore, - webhookserver.WithLogger(logger)) + subscriptionserver.WithLogger(logger)) eg.Go(func() error { logger.Info("running subscription server...") diff --git a/pkg/wal/processor/webhook/mocks/mock_subscription_store.go b/pkg/wal/processor/webhook/mocks/mock_subscription_store.go deleted file mode 100644 index b1365bb..0000000 --- a/pkg/wal/processor/webhook/mocks/mock_subscription_store.go +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package mocks - -import ( - "context" - - "github.com/xataio/pgstream/pkg/wal/processor/webhook" -) - -type SubscriptionStore struct { - CreateSubscriptionFn func(ctx context.Context, s *webhook.Subscription) error - DeleteSubscriptionFn func(ctx context.Context, s *webhook.Subscription) error - GetSubscriptionsFn func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) -} - -func (m *SubscriptionStore) CreateSubscription(ctx context.Context, s *webhook.Subscription) error { - return m.CreateSubscriptionFn(ctx, s) -} - -func (m *SubscriptionStore) DeleteSubscription(ctx context.Context, s *webhook.Subscription) error { - return m.DeleteSubscriptionFn(ctx, s) -} - -func (m *SubscriptionStore) GetSubscriptions(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { - return m.GetSubscriptionsFn(ctx, action, schema, table) -} diff --git a/pkg/wal/processor/webhook/notifier/helper_test.go b/pkg/wal/processor/webhook/notifier/helper_test.go index 9ee02f5..a6505bf 100644 --- a/pkg/wal/processor/webhook/notifier/helper_test.go +++ b/pkg/wal/processor/webhook/notifier/helper_test.go @@ -6,15 +6,16 @@ import ( "errors" "github.com/xataio/pgstream/pkg/wal" - "github.com/xataio/pgstream/pkg/wal/processor/webhook" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" ) -var testCommitPos = wal.CommitPosition("test-pos") - -var errTest = errors.New("oh noes") +var ( + testCommitPos = wal.CommitPosition("test-pos") + errTest = errors.New("oh noes") +) -func newTestSubscription(url, schema, table string, eventTypes []string) *webhook.Subscription { - return &webhook.Subscription{ +func newTestSubscription(url, schema, table string, eventTypes []string) *subscription.Subscription { + return &subscription.Subscription{ URL: url, Schema: schema, Table: table, diff --git a/pkg/wal/processor/webhook/notifier/webhook_notifier.go b/pkg/wal/processor/webhook/notifier/webhook_notifier.go index 46ed864..d2cc8f3 100644 --- a/pkg/wal/processor/webhook/notifier/webhook_notifier.go +++ b/pkg/wal/processor/webhook/notifier/webhook_notifier.go @@ -18,7 +18,7 @@ import ( "github.com/xataio/pgstream/pkg/wal" "github.com/xataio/pgstream/pkg/wal/checkpointer" "github.com/xataio/pgstream/pkg/wal/processor" - "github.com/xataio/pgstream/pkg/wal/processor/webhook" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" ) // Notifier represents the process that notifies any subscribed webhooks when @@ -38,7 +38,7 @@ type Notifier struct { } type subscriptionRetriever interface { - GetSubscriptions(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) + GetSubscriptions(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) } type Option func(*Notifier) @@ -92,7 +92,7 @@ func (n *Notifier) ProcessWALEvent(ctx context.Context, walEvent *wal.Event) (er } }() - subscriptions := []*webhook.Subscription{} + subscriptions := []*subscription.Subscription{} if walEvent.Data != nil { data := walEvent.Data subscriptions, err = n.subscriptionStore.GetSubscriptions(ctx, data.Action, data.Schema, data.Table) diff --git a/pkg/wal/processor/webhook/notifier/webhook_notifier_test.go b/pkg/wal/processor/webhook/notifier/webhook_notifier_test.go index 223a8e9..f554acd 100644 --- a/pkg/wal/processor/webhook/notifier/webhook_notifier_test.go +++ b/pkg/wal/processor/webhook/notifier/webhook_notifier_test.go @@ -20,7 +20,8 @@ import ( "github.com/xataio/pgstream/pkg/wal/checkpointer" "github.com/xataio/pgstream/pkg/wal/processor" "github.com/xataio/pgstream/pkg/wal/processor/webhook" - "github.com/xataio/pgstream/pkg/wal/processor/webhook/mocks" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store/mocks" ) func TestNotifier_ProcessWALEvent(t *testing.T) { @@ -35,7 +36,7 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { CommitPosition: testCommitPos, } - testSubscription := func(url string) *webhook.Subscription { + testSubscription := func(url string) *subscription.Subscription { return newTestSubscription(url, "", "", nil) } @@ -54,9 +55,9 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { }{ { name: "ok - no subscriptions for event", - store: &mocks.SubscriptionStore{ - GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { - return []*webhook.Subscription{}, nil + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return []*subscription.Subscription{}, nil }, }, weightedSemaphore: &syncmocks.WeightedSemaphore{ @@ -72,9 +73,9 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { }, { name: "ok - subscriptions for event", - store: &mocks.SubscriptionStore{ - GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { - return []*webhook.Subscription{ + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return []*subscription.Subscription{ testSubscription("url-1"), testSubscription("url-2"), }, nil }, @@ -94,8 +95,8 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { }, { name: "error - getting subscriptions", - store: &mocks.SubscriptionStore{ - GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { return nil, errTest }, }, @@ -106,9 +107,9 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { }, { name: "error - serialising payload", - store: &mocks.SubscriptionStore{ - GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { - return []*webhook.Subscription{ + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return []*subscription.Subscription{ testSubscription("url-1"), testSubscription("url-2"), }, nil }, @@ -121,9 +122,9 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { }, { name: "error - acquiring semaphore", - store: &mocks.SubscriptionStore{ - GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { - return []*webhook.Subscription{ + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return []*subscription.Subscription{ testSubscription("url-1"), testSubscription("url-2"), }, nil }, @@ -139,8 +140,8 @@ func TestNotifier_ProcessWALEvent(t *testing.T) { }, { name: "error - panic recovery", - store: &mocks.SubscriptionStore{ - GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { panic(errTest) }, }, @@ -299,7 +300,7 @@ func TestNotifier_Notify(t *testing.T) { doneChan := make(chan struct{}, 1) defer close(doneChan) - n := New(testCfg, &mocks.SubscriptionStore{}) + n := New(testCfg, &mocks.Store{}) n.client = tc.client n.queueBytesSema = tc.semaphore n.checkpointer = tc.checkpointer(doneChan) diff --git a/pkg/wal/processor/webhook/notifier/webhook_notify_msg.go b/pkg/wal/processor/webhook/notifier/webhook_notify_msg.go index 5b9337f..5fe2f4f 100644 --- a/pkg/wal/processor/webhook/notifier/webhook_notify_msg.go +++ b/pkg/wal/processor/webhook/notifier/webhook_notify_msg.go @@ -7,6 +7,7 @@ import ( "github.com/xataio/pgstream/pkg/wal" "github.com/xataio/pgstream/pkg/wal/processor/webhook" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" ) type notifyMsg struct { @@ -17,7 +18,7 @@ type notifyMsg struct { type serialiser func(any) ([]byte, error) -func newNotifyMsg(event *wal.Event, subscriptions []*webhook.Subscription, serialiser serialiser) (*notifyMsg, error) { +func newNotifyMsg(event *wal.Event, subscriptions []*subscription.Subscription, serialiser serialiser) (*notifyMsg, error) { var payload []byte urls := make([]string, 0, len(subscriptions)) if len(subscriptions) > 0 { diff --git a/pkg/wal/processor/webhook/server/config.go b/pkg/wal/processor/webhook/subscription/server/config.go similarity index 100% rename from pkg/wal/processor/webhook/server/config.go rename to pkg/wal/processor/webhook/subscription/server/config.go diff --git a/pkg/wal/processor/webhook/server/subscription_server.go b/pkg/wal/processor/webhook/subscription/server/subscription_server.go similarity index 76% rename from pkg/wal/processor/webhook/server/subscription_server.go rename to pkg/wal/processor/webhook/subscription/server/subscription_server.go index 7fa0ece..17fcb45 100644 --- a/pkg/wal/processor/webhook/server/subscription_server.go +++ b/pkg/wal/processor/webhook/subscription/server/subscription_server.go @@ -12,20 +12,21 @@ import ( httplib "github.com/xataio/pgstream/internal/http" loglib "github.com/xataio/pgstream/pkg/log" - "github.com/xataio/pgstream/pkg/wal/processor/webhook" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store" ) -type SubscriptionServer struct { +type Server struct { server httplib.Server logger loglib.Logger - store webhook.SubscriptionStore + store store.Store address string } -type Option func(*SubscriptionServer) +type Option func(*Server) -func New(cfg *Config, store webhook.SubscriptionStore, opts ...Option) *SubscriptionServer { - s := &SubscriptionServer{ +func New(cfg *Config, store store.Store, opts ...Option) *Server { + s := &Server{ address: cfg.address(), store: store, logger: loglib.NewNoopLogger(), @@ -51,7 +52,7 @@ func New(cfg *Config, store webhook.SubscriptionStore, opts ...Option) *Subscrip } func WithLogger(l loglib.Logger) Option { - return func(s *SubscriptionServer) { + return func(s *Server) { s.logger = loglib.NewLogger(l).WithFields(loglib.Fields{ loglib.ServiceField: "webhook_subscription_server", }) @@ -59,23 +60,23 @@ func WithLogger(l loglib.Logger) Option { } // Start will start the subscription server. This call is blocking. -func (s *SubscriptionServer) Start() error { +func (s *Server) Start() error { s.logger.Info(fmt.Sprintf("subscription server listening on: %s...", s.address)) return s.server.Start(s.address) } -func (s *SubscriptionServer) Shutdown(ctx context.Context) error { +func (s *Server) Shutdown(ctx context.Context) error { return s.server.Shutdown(ctx) } -func (s *SubscriptionServer) subscribe(c echo.Context) error { +func (s *Server) subscribe(c echo.Context) error { if c.Request().Method != http.MethodPost { return c.JSON(http.StatusMethodNotAllowed, nil) } s.logger.Trace("request received on /subscribe endpoint") - subscription := &webhook.Subscription{} + subscription := &subscription.Subscription{} if err := c.Bind(subscription); err != nil { return c.JSON(http.StatusBadRequest, err) } @@ -88,13 +89,13 @@ func (s *SubscriptionServer) subscribe(c echo.Context) error { return c.JSON(http.StatusCreated, nil) } -func (s *SubscriptionServer) unsubscribe(c echo.Context) error { +func (s *Server) unsubscribe(c echo.Context) error { if c.Request().Method != http.MethodPost { return c.JSON(http.StatusMethodNotAllowed, nil) } s.logger.Trace("request received on /unsubscribe endpoint") - subscription := &webhook.Subscription{} + subscription := &subscription.Subscription{} if err := c.Bind(subscription); err != nil { return c.JSON(http.StatusBadRequest, err) } diff --git a/pkg/wal/processor/webhook/server/subscription_server_test.go b/pkg/wal/processor/webhook/subscription/server/subscription_server_test.go similarity index 75% rename from pkg/wal/processor/webhook/server/subscription_server_test.go rename to pkg/wal/processor/webhook/subscription/server/subscription_server_test.go index 00639bc..691319b 100644 --- a/pkg/wal/processor/webhook/server/subscription_server_test.go +++ b/pkg/wal/processor/webhook/subscription/server/subscription_server_test.go @@ -15,14 +15,15 @@ import ( "github.com/labstack/echo/v4" "github.com/stretchr/testify/require" "github.com/xataio/pgstream/pkg/log" - "github.com/xataio/pgstream/pkg/wal/processor/webhook" - "github.com/xataio/pgstream/pkg/wal/processor/webhook/mocks" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store/mocks" ) func TestSubscriptionServer_subscribe(t *testing.T) { t.Parallel() - testSubscription := &webhook.Subscription{ + testSubscription := &subscription.Subscription{ URL: "url-1", Schema: "test_schema", Table: "test_table", @@ -35,7 +36,7 @@ func TestSubscriptionServer_subscribe(t *testing.T) { tests := []struct { name string - store webhook.SubscriptionStore + store store.Store method string payload io.Reader @@ -43,8 +44,8 @@ func TestSubscriptionServer_subscribe(t *testing.T) { }{ { name: "ok", - store: &mocks.SubscriptionStore{ - CreateSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + CreateSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { require.Equal(t, testSubscription, s) return nil }, @@ -55,8 +56,8 @@ func TestSubscriptionServer_subscribe(t *testing.T) { }, { name: "error - creating subscription", - store: &mocks.SubscriptionStore{ - CreateSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + CreateSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { return errTest }, }, @@ -66,8 +67,8 @@ func TestSubscriptionServer_subscribe(t *testing.T) { }, { name: "error - method not allowed", - store: &mocks.SubscriptionStore{ - CreateSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + CreateSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { return errors.New("CreateSubscriptionFn: should not be called") }, }, @@ -77,8 +78,8 @@ func TestSubscriptionServer_subscribe(t *testing.T) { }, { name: "error - invalid payload", - store: &mocks.SubscriptionStore{ - CreateSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + CreateSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { return errors.New("CreateSubscriptionFn: should not be called") }, }, @@ -93,7 +94,7 @@ func TestSubscriptionServer_subscribe(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - server := &SubscriptionServer{ + server := &Server{ logger: log.NewNoopLogger(), store: tc.store, } @@ -112,7 +113,7 @@ func TestSubscriptionServer_subscribe(t *testing.T) { func TestSubscriptionServer_unsubscribe(t *testing.T) { t.Parallel() - testSubscription := &webhook.Subscription{ + testSubscription := &subscription.Subscription{ URL: "url-1", Schema: "test_schema", Table: "test_table", @@ -125,7 +126,7 @@ func TestSubscriptionServer_unsubscribe(t *testing.T) { tests := []struct { name string - store webhook.SubscriptionStore + store store.Store method string payload io.Reader @@ -133,8 +134,8 @@ func TestSubscriptionServer_unsubscribe(t *testing.T) { }{ { name: "ok", - store: &mocks.SubscriptionStore{ - DeleteSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + DeleteSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { require.Equal(t, testSubscription, s) return nil }, @@ -145,8 +146,8 @@ func TestSubscriptionServer_unsubscribe(t *testing.T) { }, { name: "error - creating subscription", - store: &mocks.SubscriptionStore{ - DeleteSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + DeleteSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { return errTest }, }, @@ -156,8 +157,8 @@ func TestSubscriptionServer_unsubscribe(t *testing.T) { }, { name: "error - method not allowed", - store: &mocks.SubscriptionStore{ - DeleteSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + DeleteSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { return errors.New("DeleteSubscriptionFn: should not be called") }, }, @@ -167,8 +168,8 @@ func TestSubscriptionServer_unsubscribe(t *testing.T) { }, { name: "error - invalid payload", - store: &mocks.SubscriptionStore{ - DeleteSubscriptionFn: func(ctx context.Context, s *webhook.Subscription) error { + store: &mocks.Store{ + DeleteSubscriptionFn: func(ctx context.Context, s *subscription.Subscription) error { return errors.New("DeleteSubscriptionFn: should not be called") }, }, @@ -183,7 +184,7 @@ func TestSubscriptionServer_unsubscribe(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - server := &SubscriptionServer{ + server := &Server{ logger: log.NewNoopLogger(), store: tc.store, } diff --git a/pkg/wal/processor/webhook/subscription/store/cache/config.go b/pkg/wal/processor/webhook/subscription/store/cache/config.go new file mode 100644 index 0000000..d94e219 --- /dev/null +++ b/pkg/wal/processor/webhook/subscription/store/cache/config.go @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import "time" + +type Config struct { + // SyncInterval represents how frequently the cache will attempt to sync + // with the internal subscription store to retrieve the latest data. It + // defaults to 5min. + SyncInterval time.Duration +} + +const ( + defaultSyncInterval = 5 * time.Minute +) + +func (c *Config) syncInterval() time.Duration { + if c.SyncInterval > 0 { + return c.SyncInterval + } + return defaultSyncInterval +} diff --git a/pkg/wal/processor/webhook/subscription/store/cache/helper_test.go b/pkg/wal/processor/webhook/subscription/store/cache/helper_test.go new file mode 100644 index 0000000..c5186bb --- /dev/null +++ b/pkg/wal/processor/webhook/subscription/store/cache/helper_test.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "errors" + + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" +) + +var errTest = errors.New("oh noes") + +func newTestSubscription(url, schema, table string, eventTypes []string) *subscription.Subscription { + return &subscription.Subscription{ + URL: url, + Schema: schema, + Table: table, + EventTypes: eventTypes, + } +} diff --git a/pkg/wal/processor/webhook/subscription/store/cache/subscription_store_cache.go b/pkg/wal/processor/webhook/subscription/store/cache/subscription_store_cache.go new file mode 100644 index 0000000..29b5c3c --- /dev/null +++ b/pkg/wal/processor/webhook/subscription/store/cache/subscription_store_cache.go @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "fmt" + "sync" + "time" + + loglib "github.com/xataio/pgstream/pkg/log" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store" +) + +// Store is a wrapper around a subscription store that keeps an in memory cache +// to minimise calls to the persistent store. It is concurrency safe. The cache +// contents will be refreshed on a configurable interval. This is an ephemeral +// lightweight wrapper that doesn't control memory usage. Should only be used +// when the amount of subscriptions is manageable for the resources allocated. +// The sync interval will represent in the worst case scenario the staleness of +// the cache. +type Store struct { + inner store.Store + logger loglib.Logger + cacheLock *sync.RWMutex + cache map[string]*subscription.Subscription + syncInterval time.Duration +} + +type Option func(*Store) + +// NewStoreCache will wrap the store on input, providing a simple in memory +// cache to minimise calls to the persistent store. It will perform an initial +// warm up to retrieve all existing subscriptions, and will sync with the store +// on input on a configured interval. +func New(ctx context.Context, store store.Store, cfg *Config, opts ...Option) (*Store, error) { + s := &Store{ + inner: store, + logger: loglib.NewNoopLogger(), + cache: make(map[string]*subscription.Subscription), + cacheLock: &sync.RWMutex{}, + syncInterval: cfg.syncInterval(), + } + + for _, opt := range opts { + opt(s) + } + + if err := s.refresh(ctx); err != nil { + return nil, err + } + + // start a go routine that will refresh the cache contents on the configured + // interval. + go s.syncRefresh(ctx) + + return s, nil +} + +func WithLogger(l loglib.Logger) Option { + return func(sc *Store) { + sc.logger = loglib.NewLogger(l).WithFields(loglib.Fields{ + loglib.ServiceField: "subscription_store_cache", + }) + } +} + +func (s *Store) CreateSubscription(ctx context.Context, subscription *subscription.Subscription) error { + return s.inner.CreateSubscription(ctx, subscription) +} + +func (s *Store) DeleteSubscription(ctx context.Context, subscription *subscription.Subscription) error { + return s.inner.DeleteSubscription(ctx, subscription) +} + +func (s *Store) GetSubscriptions(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + s.cacheLock.RLock() + defer s.cacheLock.RUnlock() + + subscriptions := make([]*subscription.Subscription, 0, len(s.cache)) + for _, subscription := range s.cache { + if subscription.IsFor(action, schema, table) { + subscriptions = append(subscriptions, subscription) + } + } + + return subscriptions, nil +} + +func (s *Store) syncRefresh(ctx context.Context) { + ticker := time.NewTicker(s.syncInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.refresh(ctx); err != nil { + s.logger.Error(err, "refreshing store cache") + } + } + } +} + +func (s *Store) refresh(ctx context.Context) error { + // get all subscriptions and populate the cache + subscriptions, err := s.inner.GetSubscriptions(ctx, "", "", "") + if err != nil { + return fmt.Errorf("retrieving subscriptions: %w", err) + } + + s.cacheLock.Lock() + defer s.cacheLock.Unlock() + + s.cache = make(map[string]*subscription.Subscription, len(subscriptions)) + for _, subscription := range subscriptions { + s.cache[subscription.Key()] = subscription + } + + s.logger.Debug("cache refreshed", loglib.Fields{ + "subscription_total_count": len(s.cache), + "subscriptions": s.cache, + }) + + return nil +} diff --git a/pkg/wal/processor/webhook/subscription/store/cache/subscription_store_cache_test.go b/pkg/wal/processor/webhook/subscription/store/cache/subscription_store_cache_test.go new file mode 100644 index 0000000..d84307c --- /dev/null +++ b/pkg/wal/processor/webhook/subscription/store/cache/subscription_store_cache_test.go @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription/store/mocks" +) + +func TestSubscriptionStoreCache_NewSubscriptionStoreCache(t *testing.T) { + t.Parallel() + + subscription1 := newTestSubscription("url-1", "", "", []string{"D"}) + subscription2 := newTestSubscription("url-2", "my_schema", "my_table", []string{"I"}) + + tests := []struct { + name string + store store.Store + + wantCache map[string]*subscription.Subscription + wantErr error + }{ + { + name: "ok", + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return []*subscription.Subscription{subscription1, subscription2}, nil + }, + }, + wantCache: map[string]*subscription.Subscription{ + subscription1.Key(): subscription1, + subscription2.Key(): subscription2, + }, + wantErr: nil, + }, + { + name: "error - refreshing cache", + store: &mocks.Store{ + GetSubscriptionsFn: func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return nil, errTest + }, + }, + wantCache: map[string]*subscription.Subscription{}, + wantErr: errTest, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cacheStore, err := New(context.Background(), tc.store, &Config{}) + require.ErrorIs(t, err, tc.wantErr) + if err == nil { + require.Equal(t, tc.wantCache, cacheStore.cache) + } + }) + } +} + +func TestSubscriptionStoreCache_GetSubscriptions(t *testing.T) { + t.Parallel() + + testSubscription1 := newTestSubscription("test-url-1", "test_schema", "test_table", []string{"D"}) + testSubscription2 := newTestSubscription("test-url-2", "test_schema", "test_table", []string{"I"}) + testSubscription3 := newTestSubscription("test-url-3", "", "", []string{"I"}) + + tests := []struct { + name string + store store.Store + action string + schema string + table string + + wantSubscriptions []*subscription.Subscription + wantErr error + }{ + { + name: "ok", + action: "I", + + wantSubscriptions: []*subscription.Subscription{ + testSubscription2, + testSubscription3, + }, + wantErr: nil, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cacheStore := &Store{ + inner: tc.store, + cacheLock: &sync.RWMutex{}, + cache: map[string]*subscription.Subscription{ + testSubscription1.Key(): testSubscription1, + testSubscription2.Key(): testSubscription2, + testSubscription3.Key(): testSubscription3, + }, + } + + subscriptions, err := cacheStore.GetSubscriptions(context.Background(), tc.action, tc.schema, tc.table) + require.ErrorIs(t, err, tc.wantErr) + require.ElementsMatch(t, tc.wantSubscriptions, subscriptions) + }) + } +} diff --git a/pkg/wal/processor/webhook/subscription/store/mocks/mock_subscription_store.go b/pkg/wal/processor/webhook/subscription/store/mocks/mock_subscription_store.go new file mode 100644 index 0000000..0201ab8 --- /dev/null +++ b/pkg/wal/processor/webhook/subscription/store/mocks/mock_subscription_store.go @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" +) + +type Store struct { + CreateSubscriptionFn func(ctx context.Context, s *subscription.Subscription) error + DeleteSubscriptionFn func(ctx context.Context, s *subscription.Subscription) error + GetSubscriptionsFn func(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) +} + +func (m *Store) CreateSubscription(ctx context.Context, s *subscription.Subscription) error { + return m.CreateSubscriptionFn(ctx, s) +} + +func (m *Store) DeleteSubscription(ctx context.Context, s *subscription.Subscription) error { + return m.DeleteSubscriptionFn(ctx, s) +} + +func (m *Store) GetSubscriptions(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { + return m.GetSubscriptionsFn(ctx, action, schema, table) +} diff --git a/pkg/wal/processor/webhook/postgres/pg_subscription_store.go b/pkg/wal/processor/webhook/subscription/store/postgres/pg_subscription_store.go similarity index 77% rename from pkg/wal/processor/webhook/postgres/pg_subscription_store.go rename to pkg/wal/processor/webhook/subscription/store/postgres/pg_subscription_store.go index ceab7a7..fa3e09f 100644 --- a/pkg/wal/processor/webhook/postgres/pg_subscription_store.go +++ b/pkg/wal/processor/webhook/subscription/store/postgres/pg_subscription_store.go @@ -9,19 +9,19 @@ import ( "github.com/jackc/pgx/v5" loglib "github.com/xataio/pgstream/pkg/log" - "github.com/xataio/pgstream/pkg/wal/processor/webhook" + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" ) -type SubscriptionStore struct { +type Store struct { conn *pgx.Conn logger loglib.Logger } -type Option func(*SubscriptionStore) +type Option func(*Store) const subscriptionsTable = "webhook_subscriptions" -func NewSubscriptionStore(ctx context.Context, url string, opts ...Option) (*SubscriptionStore, error) { +func NewSubscriptionStore(ctx context.Context, url string, opts ...Option) (*Store, error) { pgCfg, err := pgx.ParseConfig(url) if err != nil { return nil, err @@ -31,7 +31,7 @@ func NewSubscriptionStore(ctx context.Context, url string, opts ...Option) (*Sub return nil, fmt.Errorf("create postgres client: %w", err) } - ss := &SubscriptionStore{ + ss := &Store{ conn: pgConn, } @@ -48,14 +48,14 @@ func NewSubscriptionStore(ctx context.Context, url string, opts ...Option) (*Sub } func WithLogger(l loglib.Logger) Option { - return func(ss *SubscriptionStore) { + return func(ss *Store) { ss.logger = loglib.NewLogger(l).WithFields(loglib.Fields{ loglib.ServiceField: "webhook_subscription_store", }) } } -func (s *SubscriptionStore) CreateSubscription(ctx context.Context, subscription *webhook.Subscription) error { +func (s *Store) CreateSubscription(ctx context.Context, subscription *subscription.Subscription) error { query := fmt.Sprintf(` INSERT INTO %s(url, schema_name, table_name, event_types) VALUES($1, $2, $3, $4) ON CONFLICT (url,schema_name,table_name) DO UPDATE SET event_types = EXCLUDED.event_types;`, subscriptionsTable) @@ -63,15 +63,15 @@ func (s *SubscriptionStore) CreateSubscription(ctx context.Context, subscription return err } -func (s *SubscriptionStore) DeleteSubscription(ctx context.Context, subscription *webhook.Subscription) error { +func (s *Store) DeleteSubscription(ctx context.Context, subscription *subscription.Subscription) error { query := fmt.Sprintf(`DELETE FROM %s WHERE url=$1 AND schema_name=$2 AND table=$3;`, subscriptionsTable) _, err := s.conn.Exec(ctx, query, subscription.URL, subscription.Schema, subscription.Table) return err } -func (s *SubscriptionStore) GetSubscriptions(ctx context.Context, action, schema, table string) ([]*webhook.Subscription, error) { +func (s *Store) GetSubscriptions(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) { query, params := s.buildGetQuery(action, schema, table) - s.logger.Debug("getting subscriptions", loglib.Fields{ + s.logger.Trace("getting subscriptions", loglib.Fields{ "query": query, "params": params, }) @@ -81,9 +81,9 @@ func (s *SubscriptionStore) GetSubscriptions(ctx context.Context, action, schema } defer rows.Close() - subscriptions := []*webhook.Subscription{} + subscriptions := []*subscription.Subscription{} for rows.Next() { - subscription := &webhook.Subscription{} + subscription := &subscription.Subscription{} if err := rows.Scan(&subscription.URL, &subscription.Schema, &subscription.Table, &subscription.EventTypes); err != nil { return nil, fmt.Errorf("scanning subscription row: %w", err) } @@ -94,7 +94,7 @@ func (s *SubscriptionStore) GetSubscriptions(ctx context.Context, action, schema return subscriptions, nil } -func (s *SubscriptionStore) createTable(ctx context.Context) error { +func (s *Store) createTable(ctx context.Context) error { query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s( url TEXT, schema_name TEXT, @@ -105,7 +105,7 @@ func (s *SubscriptionStore) createTable(ctx context.Context) error { return err } -func (s *SubscriptionStore) buildGetQuery(action, schema, table string) (string, []any) { +func (s *Store) buildGetQuery(action, schema, table string) (string, []any) { query := fmt.Sprintf(`SELECT url, schema_name, table_name, event_types FROM %s`, subscriptionsTable) separator := func(params []any) string { diff --git a/pkg/wal/processor/webhook/postgres/pg_subscription_store_test.go b/pkg/wal/processor/webhook/subscription/store/postgres/pg_subscription_store_test.go similarity index 95% rename from pkg/wal/processor/webhook/postgres/pg_subscription_store_test.go rename to pkg/wal/processor/webhook/subscription/store/postgres/pg_subscription_store_test.go index b29ccbf..3dc355f 100644 --- a/pkg/wal/processor/webhook/postgres/pg_subscription_store_test.go +++ b/pkg/wal/processor/webhook/subscription/store/postgres/pg_subscription_store_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestSubscriptionStore_buildGetQuery(t *testing.T) { +func TestStore_buildGetQuery(t *testing.T) { t.Parallel() tests := []struct { @@ -62,7 +62,7 @@ func TestSubscriptionStore_buildGetQuery(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - s := &SubscriptionStore{} + s := &Store{} query, params := s.buildGetQuery(tc.action, tc.schema, tc.table) require.Equal(t, tc.wantQuery, query) require.Equal(t, tc.wantParams, params) diff --git a/pkg/wal/processor/webhook/subscription/store/subscription_store.go b/pkg/wal/processor/webhook/subscription/store/subscription_store.go new file mode 100644 index 0000000..d2a199d --- /dev/null +++ b/pkg/wal/processor/webhook/subscription/store/subscription_store.go @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 + +package store + +import ( + "context" + + "github.com/xataio/pgstream/pkg/wal/processor/webhook/subscription" +) + +type Store interface { + CreateSubscription(ctx context.Context, s *subscription.Subscription) error + DeleteSubscription(ctx context.Context, s *subscription.Subscription) error + GetSubscriptions(ctx context.Context, action, schema, table string) ([]*subscription.Subscription, error) +} diff --git a/pkg/wal/processor/webhook/subscription_store.go b/pkg/wal/processor/webhook/subscription/subscription.go similarity index 67% rename from pkg/wal/processor/webhook/subscription_store.go rename to pkg/wal/processor/webhook/subscription/subscription.go index 5b190f7..2c429b5 100644 --- a/pkg/wal/processor/webhook/subscription_store.go +++ b/pkg/wal/processor/webhook/subscription/subscription.go @@ -1,21 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -package webhook +package subscription import ( - "context" "fmt" "slices" - - "github.com/xataio/pgstream/pkg/wal" ) -type SubscriptionStore interface { - CreateSubscription(ctx context.Context, s *Subscription) error - DeleteSubscription(ctx context.Context, s *Subscription) error - GetSubscriptions(ctx context.Context, action, schema, table string) ([]*Subscription, error) -} - type Subscription struct { URL string `json:"url"` EventTypes []string `json:"event_types"` @@ -23,10 +14,6 @@ type Subscription struct { Table string `json:"table"` } -type Payload struct { - Data *wal.Data -} - func (s *Subscription) IsFor(action, schema, table string) bool { if action == "" && schema == "" && table == "" { return true diff --git a/pkg/wal/processor/webhook/subscription_store_test.go b/pkg/wal/processor/webhook/subscription/subscription_test.go similarity index 99% rename from pkg/wal/processor/webhook/subscription_store_test.go rename to pkg/wal/processor/webhook/subscription/subscription_test.go index a27ca61..dd1336b 100644 --- a/pkg/wal/processor/webhook/subscription_store_test.go +++ b/pkg/wal/processor/webhook/subscription/subscription_test.go @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -package webhook +package subscription import ( "testing" diff --git a/pkg/wal/processor/webhook/webhook.go b/pkg/wal/processor/webhook/webhook.go new file mode 100644 index 0000000..fab1cdf --- /dev/null +++ b/pkg/wal/processor/webhook/webhook.go @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import "github.com/xataio/pgstream/pkg/wal" + +type Payload struct { + Data *wal.Data +}