Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pg listener #9

Merged
merged 4 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions internal/replication/mocks/mock_replication_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,57 @@ package mocks

import (
"context"
"sync/atomic"

"github.com/xataio/pgstream/internal/replication"
)

type ReplicationHandler struct {
type Handler struct {
StartReplicationFn func(context.Context) error
ReceiveMessageFn func(context.Context) (replication.Message, error)
ReceiveMessageFn func(context.Context, uint64) (replication.Message, error)
UpdateLSNPositionFn func(lsn replication.LSN)
SyncLSNFn func(context.Context) error
DropReplicationSlotFn func(ctx context.Context) error
GetLSNParserFn func() replication.LSNParser
CloseFn func() error
SyncLSNCalls uint64
ReceiveMessageCalls uint64
}

func (m *ReplicationHandler) StartReplication(ctx context.Context) error {
func (m *Handler) StartReplication(ctx context.Context) error {
return m.StartReplicationFn(ctx)
}

func (m *ReplicationHandler) ReceiveMessage(ctx context.Context) (replication.Message, error) {
return m.ReceiveMessageFn(ctx)
func (m *Handler) ReceiveMessage(ctx context.Context) (replication.Message, error) {
atomic.AddUint64(&m.ReceiveMessageCalls, 1)
return m.ReceiveMessageFn(ctx, m.GetReceiveMessageCalls())
}

func (m *ReplicationHandler) UpdateLSNPosition(lsn replication.LSN) {
func (m *Handler) UpdateLSNPosition(lsn replication.LSN) {
m.UpdateLSNPositionFn(lsn)
}

func (m *ReplicationHandler) SyncLSN(ctx context.Context) error {
func (m *Handler) SyncLSN(ctx context.Context) error {
atomic.AddUint64(&m.SyncLSNCalls, 1)
return m.SyncLSNFn(ctx)
}

func (m *ReplicationHandler) DropReplicationSlot(ctx context.Context) error {
func (m *Handler) DropReplicationSlot(ctx context.Context) error {
return m.DropReplicationSlotFn(ctx)
}

func (m *ReplicationHandler) GetLSNParser() replication.LSNParser {
func (m *Handler) GetLSNParser() replication.LSNParser {
return m.GetLSNParserFn()
}

func (m *ReplicationHandler) Close() error {
func (m *Handler) Close() error {
return m.CloseFn()
}

func (m *Handler) GetSyncLSNCalls() uint64 {
return atomic.LoadUint64(&m.SyncLSNCalls)
}

func (m *Handler) GetReceiveMessageCalls() uint64 {
return atomic.LoadUint64(&m.ReceiveMessageCalls)
}
20 changes: 20 additions & 0 deletions internal/replication/mocks/mock_replication_lsn_parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"github.com/xataio/pgstream/internal/replication"
)

type LSNParser struct {
ToStringFn func(replication.LSN) string
FromStringFn func(string) (replication.LSN, error)
}

func (m *LSNParser) ToString(lsn replication.LSN) string {
return m.ToStringFn(lsn)
}

func (m *LSNParser) FromString(lsn string) (replication.LSN, error) {
return m.FromStringFn(lsn)
}
15 changes: 15 additions & 0 deletions internal/replication/mocks/mock_replication_message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"github.com/xataio/pgstream/internal/replication"
)

type Message struct {
GetDataFn func() *replication.MessageData
}

func (m *Message) GetData() *replication.MessageData {
return m.GetDataFn()
}
74 changes: 37 additions & 37 deletions internal/replication/postgres/pg_replication_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/xataio/pgstream/internal/replication"
)

type ReplicationHandler struct {
type Handler struct {
// Create two connections. One for querying, one for handling replication
// events.
pgConn *pgx.Conn
Expand All @@ -37,7 +37,7 @@ const (
logSystemID = "system_id"
)

func NewReplicationHandler(ctx context.Context, cfg *pgx.ConnConfig) (*ReplicationHandler, error) {
func NewHandler(ctx context.Context, cfg *pgx.ConnConfig) (*Handler, error) {
pgConn, err := pgx.ConnectConfig(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("create postgres client: %w", err)
Expand All @@ -52,25 +52,25 @@ func NewReplicationHandler(ctx context.Context, cfg *pgx.ConnConfig) (*Replicati
return nil, fmt.Errorf("create postgres replication client: %w", err)
}

return &ReplicationHandler{
return &Handler{
pgConn: pgConn,
pgReplicationConn: pgReplicationConn,
lsnParser: &LSNParser{},
}, nil
}

func (c *ReplicationHandler) StartReplication(ctx context.Context) error {
sysID, err := pglogrepl.IdentifySystem(ctx, c.pgReplicationConn)
func (h *Handler) StartReplication(ctx context.Context) error {
sysID, err := pglogrepl.IdentifySystem(ctx, h.pgReplicationConn)
if err != nil {
return fmt.Errorf("identifySystem failed: %w", err)
}

c.pgReplicationSlotName = fmt.Sprintf("%s_slot", sysID.DBName)
h.pgReplicationSlotName = fmt.Sprintf("%s_slot", sysID.DBName)

logger := log.Ctx(ctx).With().
Str(logSystemID, sysID.SystemID).
Str(logDBName, sysID.DBName).
Str(logSlotName, c.pgReplicationSlotName).
Str(logSlotName, h.pgReplicationSlotName).
Logger()
ctx = logger.WithContext(ctx)

Expand All @@ -79,7 +79,7 @@ func (c *ReplicationHandler) StartReplication(ctx context.Context) error {
Stringer(logLSNPosition, sysID.XLogPos).
Msg("identifySystem success")

startPos, err := c.getLastSyncedLSN(ctx)
startPos, err := h.getLastSyncedLSN(ctx)
if err != nil {
return fmt.Errorf("read last position: %w", err)
}
Expand All @@ -92,7 +92,7 @@ func (c *ReplicationHandler) StartReplication(ctx context.Context) error {
// todo(deverts): If we don't have a position. Read from as early as possible.
// this _could_ be too old. In the future, it would be good to calculate if we're
// too far behind, so we can fix it.
startPos, err = c.getRestartLSN(ctx, c.pgReplicationSlotName)
startPos, err = h.getRestartLSN(ctx, h.pgReplicationSlotName)
if err != nil {
return fmt.Errorf("get restart LSN: %w", err)
}
Expand All @@ -109,23 +109,23 @@ func (c *ReplicationHandler) StartReplication(ctx context.Context) error {
}
err = pglogrepl.StartReplication(
ctx,
c.pgReplicationConn,
c.pgReplicationSlotName,
h.pgReplicationConn,
h.pgReplicationSlotName,
pglogrepl.LSN(startPos),
pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments})
if err != nil {
return fmt.Errorf("startReplication: %w", err)
}

logger.Info().Msgf("logical replication started on slot %v.", c.pgReplicationSlotName)
logger.Info().Msgf("logical replication started on slot %v.", h.pgReplicationSlotName)

c.UpdateLSNPosition(startPos)
h.UpdateLSNPosition(startPos)

return nil
}

func (c *ReplicationHandler) ReceiveMessage(ctx context.Context) (replication.Message, error) {
msg, err := c.pgReplicationConn.ReceiveMessage(ctx)
func (h *Handler) ReceiveMessage(ctx context.Context) (replication.Message, error) {
msg, err := h.pgReplicationConn.ReceiveMessage(ctx)
if err != nil {
return nil, mapPostgresError(err)
}
Expand Down Expand Up @@ -159,16 +159,16 @@ func (c *ReplicationHandler) ReceiveMessage(ctx context.Context) (replication.Me
}
}

func (c *ReplicationHandler) UpdateLSNPosition(lsn replication.LSN) {
atomic.StoreUint64(&c.currentLSN, uint64(lsn))
func (h *Handler) UpdateLSNPosition(lsn replication.LSN) {
atomic.StoreUint64(&h.currentLSN, uint64(lsn))
}

// SyncLSN notifies Postgres how far we have processed in the WAL.
func (c *ReplicationHandler) SyncLSN(ctx context.Context) error {
lsn := c.getLSNPosition()
func (h *Handler) SyncLSN(ctx context.Context) error {
lsn := h.getLSNPosition()
err := pglogrepl.SendStandbyStatusUpdate(
ctx,
c.pgReplicationConn,
h.pgReplicationConn,
pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(lsn)},
)
if err != nil {
Expand All @@ -178,43 +178,43 @@ func (c *ReplicationHandler) SyncLSN(ctx context.Context) error {
return nil
}

func (c *ReplicationHandler) DropReplicationSlot(ctx context.Context) error {
func (h *Handler) DropReplicationSlot(ctx context.Context) error {
err := pglogrepl.DropReplicationSlot(
ctx,
c.pgReplicationConn,
c.pgReplicationSlotName,
h.pgReplicationConn,
h.pgReplicationSlotName,
pglogrepl.DropReplicationSlotOptions{Wait: true},
)
if err != nil {
return fmt.Errorf("clean up replication slot %q: %w", c.pgReplicationSlotName, err)
return fmt.Errorf("clean up replication slot %q: %w", h.pgReplicationSlotName, err)
}

return nil
}

func (c *ReplicationHandler) GetLSNParser() replication.LSNParser {
return c.lsnParser
func (h *Handler) GetLSNParser() replication.LSNParser {
return h.lsnParser
}

// Close closes the database connections.
func (c *ReplicationHandler) Close() error {
err := c.pgReplicationConn.Close(context.Background())
func (h *Handler) Close() error {
err := h.pgReplicationConn.Close(context.Background())
if err != nil {
return err
}
return c.pgConn.Close(context.Background())
return h.pgConn.Close(context.Background())
}

func (c *ReplicationHandler) getLSNPosition() replication.LSN {
return replication.LSN(atomic.LoadUint64(&c.currentLSN))
func (h *Handler) getLSNPosition() replication.LSN {
return replication.LSN(atomic.LoadUint64(&h.currentLSN))
}

// getRestartLSN returns the absolute earliest possible LSN we can support. If
// the consumer's LSN is earlier than this, we cannot (easily) catch the
// consumer back up.
func (c *ReplicationHandler) getRestartLSN(ctx context.Context, slotName string) (replication.LSN, error) {
func (h *Handler) getRestartLSN(ctx context.Context, slotName string) (replication.LSN, error) {
var restartLSN string
err := c.pgConn.QueryRow(
err := h.pgConn.QueryRow(
ctx,
`select restart_lsn from pg_replication_slots where slot_name=$1`,
slotName,
Expand All @@ -223,17 +223,17 @@ func (c *ReplicationHandler) getRestartLSN(ctx context.Context, slotName string)
// TODO: improve error message in case the slot doesn't exist
return 0, err
}
return c.lsnParser.FromString(restartLSN)
return h.lsnParser.FromString(restartLSN)
}

// getLastSyncedLSN gets the `confirmed_flush_lsn` from PG. This is the last LSN
// that the consumer confirmed it had completed.
func (c *ReplicationHandler) getLastSyncedLSN(ctx context.Context) (replication.LSN, error) {
func (h *Handler) getLastSyncedLSN(ctx context.Context) (replication.LSN, error) {
var confirmedFlushLSN string
err := c.pgConn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, c.pgReplicationSlotName).Scan(&confirmedFlushLSN)
err := h.pgConn.QueryRow(ctx, `select confirmed_flush_lsn from pg_replication_slots where slot_name=$1`, h.pgReplicationSlotName).Scan(&confirmedFlushLSN)
if err != nil {
return 0, err
}

return c.lsnParser.FromString(confirmedFlushLSN)
return h.lsnParser.FromString(confirmedFlushLSN)
}
4 changes: 2 additions & 2 deletions pkg/schemalog/postgres/pg_schemalog_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type querier interface {
Close(ctx context.Context)
}

func NewSchemalogStore(ctx context.Context, cfg *pgx.ConnConfig) (*Store, error) {
func NewStore(ctx context.Context, cfg *pgx.ConnConfig) (*Store, error) {
pgConn, err := pgx.ConnectConfig(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("create postgres client: %w", err)
Expand All @@ -35,7 +35,7 @@ func NewSchemalogStore(ctx context.Context, cfg *pgx.ConnConfig) (*Store, error)
}, nil
}

func NewSchemalogStoreWithQuerier(querier querier) *Store {
func NewStoreWithQuerier(querier querier) *Store {
return &Store{
querier: querier,
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/schemalog/postgres/pg_schemalog_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestStore_Fetch(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

s := NewSchemalogStoreWithQuerier(tc.querier)
s := NewStoreWithQuerier(tc.querier)

logEntry, err := s.Fetch(context.Background(), testSchema, tc.ackedOnly)
require.ErrorIs(t, err, tc.wantErr)
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestStore_Ack(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

s := NewSchemalogStoreWithQuerier(tc.querier)
s := NewStoreWithQuerier(tc.querier)

err := s.Ack(context.Background(), tc.logEntry)
require.ErrorIs(t, err, tc.wantErr)
Expand Down
56 changes: 56 additions & 0 deletions pkg/wal/listener/postgres/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"
"time"

"github.com/xataio/pgstream/internal/replication"
replicationmocks "github.com/xataio/pgstream/internal/replication/mocks"
)

const (
testLSN = replication.LSN(7773397064)
)

func newMockReplicationHandler() *replicationmocks.Handler {
return &replicationmocks.Handler{
StartReplicationFn: func(context.Context) error { return nil },
GetLSNParserFn: func() replication.LSNParser { return newMockLSNParser() },
SyncLSNFn: func(ctx context.Context) error { return nil },
ReceiveMessageFn: func(ctx context.Context, i uint64) (replication.Message, error) {
return newMockMessage(), nil
},
}
}

func newMockMessage() *replicationmocks.Message {
return &replicationmocks.Message{
GetDataFn: func() *replication.MessageData {
return &replication.MessageData{
LSN: testLSN,
Data: []byte("test-data"),
ReplyRequested: false,
ServerTime: time.Now(),
}
},
}
}

func newMockKeepAliveMessage(replyRequested bool) *replicationmocks.Message {
return &replicationmocks.Message{
GetDataFn: func() *replication.MessageData {
return &replication.MessageData{
LSN: testLSN,
ReplyRequested: replyRequested,
}
},
}
}

func newMockLSNParser() *replicationmocks.LSNParser {
return &replicationmocks.LSNParser{
ToStringFn: func(replication.LSN) string { return "lsn" },
}
}
Loading