Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replication handler tests #52

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions internal/postgres/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"errors"

"github.com/jackc/pgx/v5/pgconn"
)

var ErrConnTimeout = errors.New("connection timeout")

func mapError(err error) error {
if pgconn.Timeout(err) {
return ErrConnTimeout
}

return err
}
24 changes: 24 additions & 0 deletions internal/postgres/mocks/mock_pg_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"context"

"github.com/jackc/pgx/v5/pgconn"
"github.com/xataio/pgstream/internal/postgres"
)

type Conn struct {
QueryRowFn func(ctx context.Context, query string, args ...any) postgres.Row
ExecFn func(context.Context, string, ...any) (pgconn.CommandTag, error)
CloseFn func(context.Context) error
}

func (m *Conn) QueryRow(ctx context.Context, query string, args ...any) postgres.Row {
return m.QueryRowFn(ctx, query, args...)
}

func (m *Conn) Close(ctx context.Context) error {
return m.CloseFn(ctx)
}
37 changes: 37 additions & 0 deletions internal/postgres/mocks/mock_pg_replication_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"context"

"github.com/xataio/pgstream/internal/postgres"
)

type ReplicationConn struct {
IdentifySystemFn func(ctx context.Context) (postgres.IdentifySystemResult, error)
StartReplicationFn func(ctx context.Context, cfg postgres.ReplicationConfig) error
SendStandbyStatusUpdateFn func(ctx context.Context, lsn uint64) error
ReceiveMessageFn func(ctx context.Context) (*postgres.ReplicationMessage, error)
CloseFn func(ctx context.Context) error
}

func (m *ReplicationConn) IdentifySystem(ctx context.Context) (postgres.IdentifySystemResult, error) {
return m.IdentifySystemFn(ctx)
}

func (m *ReplicationConn) StartReplication(ctx context.Context, cfg postgres.ReplicationConfig) error {
return m.StartReplicationFn(ctx, cfg)
}

func (m *ReplicationConn) SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error {
return m.SendStandbyStatusUpdateFn(ctx, lsn)
}

func (m *ReplicationConn) ReceiveMessage(ctx context.Context) (*postgres.ReplicationMessage, error) {
return m.ReceiveMessageFn(ctx)
}

func (m *ReplicationConn) Close(ctx context.Context) error {
return m.CloseFn(ctx)
}
40 changes: 40 additions & 0 deletions internal/postgres/pg_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"
"fmt"

"github.com/jackc/pgx/v5"
)

type Conn struct {
conn *pgx.Conn
}

type Row interface {
pgx.Row
}

func NewConn(ctx context.Context, url string) (*Conn, error) {
pgCfg, err := pgx.ParseConfig(url)
if err != nil {
return nil, fmt.Errorf("failed parsing postgres connection string: %w", mapError(err))
}

conn, err := pgx.ConnectConfig(ctx, pgCfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to postgres: %w", mapError(err))
}

return &Conn{conn: conn}, nil
}

func (c *Conn) QueryRow(ctx context.Context, query string, args ...any) Row {
return c.conn.QueryRow(ctx, query, args...)
}

func (c *Conn) Close(ctx context.Context) error {
return mapError(c.conn.Close(ctx))
}
138 changes: 138 additions & 0 deletions internal/postgres/pg_replication_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"
"errors"
"fmt"
"time"

"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
)

type ReplicationConn struct {
conn *pgconn.PgConn
}

type ReplicationConfig struct {
SlotName string
StartPos uint64
PluginArguments []string
}

type ReplicationMessage struct {
LSN uint64
ServerTime time.Time
WALData []byte
ReplyRequested bool
}

type IdentifySystemResult pglogrepl.IdentifySystemResult

var ErrUnsupportedCopyDataMessage = errors.New("unsupported copy data message")

func NewReplicationConn(ctx context.Context, url string) (*ReplicationConn, error) {
pgCfg, err := pgx.ParseConfig(url)
if err != nil {
return nil, fmt.Errorf("failed parsing postgres connection string: %w", err)
}

pgCfg.RuntimeParams["replication"] = "database"

conn, err := pgconn.ConnectConfig(context.Background(), &pgCfg.Config)
if err != nil {
return nil, fmt.Errorf("create postgres replication client: %w", mapError(err))
}

return &ReplicationConn{
conn: conn,
}, nil
}

func (c *ReplicationConn) IdentifySystem(ctx context.Context) (IdentifySystemResult, error) {
res, err := pglogrepl.IdentifySystem(ctx, c.conn)
return IdentifySystemResult(res), mapError(err)
}

func (c *ReplicationConn) StartReplication(ctx context.Context, cfg ReplicationConfig) error {
return mapError(pglogrepl.StartReplication(
ctx,
c.conn,
cfg.SlotName,
pglogrepl.LSN(cfg.StartPos),
pglogrepl.StartReplicationOptions{PluginArgs: cfg.PluginArguments}))
}

func (c *ReplicationConn) SendStandbyStatusUpdate(ctx context.Context, lsn uint64) error {
return mapError(pglogrepl.SendStandbyStatusUpdate(
ctx,
c.conn,
pglogrepl.StandbyStatusUpdate{WALWritePosition: pglogrepl.LSN(lsn)},
))
}

func (c *ReplicationConn) ReceiveMessage(ctx context.Context) (*ReplicationMessage, error) {
msg, err := c.conn.ReceiveMessage(ctx)
if err != nil {
return nil, mapError(err)
}

switch msg := msg.(type) {
case *pgproto3.CopyData:
switch msg.Data[0] {
case pglogrepl.PrimaryKeepaliveMessageByteID:
pka, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:])
if err != nil {
return nil, fmt.Errorf("parse keep alive: %w", err)
}
return &ReplicationMessage{
LSN: uint64(pka.ServerWALEnd),
ServerTime: pka.ServerTime,
ReplyRequested: pka.ReplyRequested,
}, nil
case pglogrepl.XLogDataByteID:
xld, err := pglogrepl.ParseXLogData(msg.Data[1:])
if err != nil {
return nil, fmt.Errorf("parse xlog data: %w", err)
}

return &ReplicationMessage{
LSN: uint64(xld.WALStart) + uint64(len(xld.WALData)),
ServerTime: xld.ServerTime,
WALData: xld.WALData,
}, nil
default:
return nil, fmt.Errorf("%v: %w", msg.Data[0], ErrUnsupportedCopyDataMessage)
}
case *pgproto3.NoticeResponse:
return nil, parseErrNoticeResponse(msg)
default:
// unexpected message (WAL error?)
return nil, fmt.Errorf("unexpected message: %#v", msg)
}
}

func (c *ReplicationConn) Close(ctx context.Context) error {
return mapError(c.conn.Close(ctx))
}

type Error struct {
Severity string
Msg string
}

func (e *Error) Error() string {
return fmt.Sprintf("replication error: %s", e.Msg)
}

func parseErrNoticeResponse(errMsg *pgproto3.NoticeResponse) error {
return &Error{
Severity: errMsg.Severity,
Msg: fmt.Sprintf("replication notice response: severity: %s, code: %s, message: %s, detail: %s, schemaName: %s, tableName: %s, columnName: %s",
errMsg.Severity, errMsg.Code, errMsg.Message, errMsg.Detail, errMsg.SchemaName, errMsg.TableName, errMsg.ColumnName),
}
}
30 changes: 11 additions & 19 deletions pkg/wal/listener/postgres/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,25 @@ func newMockReplicationHandler() *replicationmocks.Handler {
StartReplicationFn: func(context.Context) error { return nil },
GetLSNParserFn: func() replication.LSNParser { return newMockLSNParser() },
SyncLSNFn: func(ctx context.Context, lsn replication.LSN) error { return nil },
ReceiveMessageFn: func(ctx context.Context, i uint64) (replication.Message, error) {
ReceiveMessageFn: func(ctx context.Context, i uint64) (*replication.Message, error) {
return newMockMessage(), nil
},
}
}

func newMockMessage() *replicationmocks.Message {
return &replicationmocks.Message{
GetDataFn: func() *replication.MessageData {
return &replication.MessageData{
LSN: testLSN,
Data: []byte("test-data"),
ReplyRequested: false,
ServerTime: time.Now(),
}
},
func newMockMessage() *replication.Message {
return &replication.Message{
LSN: testLSN,
Data: []byte("test-data"),
ReplyRequested: false,
ServerTime: time.Now(),
}
}

func newMockKeepAliveMessage(replyRequested bool) *replicationmocks.Message {
return &replicationmocks.Message{
GetDataFn: func() *replication.MessageData {
return &replication.MessageData{
LSN: testLSN,
ReplyRequested: replyRequested,
}
},
func newMockKeepAliveMessage(replyRequested bool) *replication.Message {
return &replication.Message{
LSN: testLSN,
ReplyRequested: replyRequested,
}
}

Expand Down
26 changes: 12 additions & 14 deletions pkg/wal/listener/postgres/wal_pg_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Listener struct {

type replicationHandler interface {
StartReplication(ctx context.Context) error
ReceiveMessage(ctx context.Context) (replication.Message, error)
ReceiveMessage(ctx context.Context) (*replication.Message, error)
GetLSNParser() replication.LSNParser
Close() error
}
Expand Down Expand Up @@ -89,46 +89,44 @@ func (l *Listener) listen(ctx context.Context) error {
default:
msg, err := l.replicationHandler.ReceiveMessage(ctx)
if err != nil {
replErr := &replication.Error{}
if errors.Is(err, replication.ErrConnTimeout) || (errors.As(err, &replErr) && replErr.Severity == "WARNING") {
if errors.Is(err, replication.ErrConnTimeout) {
continue
}
return fmt.Errorf("receiving message: %w", err)
}

msgData := msg.GetData()
if msgData == nil {
if msg == nil {
continue
}

l.logger.Trace("", loglib.Fields{
"wal_end": l.lsnParser.ToString(msgData.LSN),
"server_time": msgData.ServerTime,
"wal_data": msgData.Data,
"wal_end": l.lsnParser.ToString(msg.LSN),
"server_time": msg.ServerTime,
"wal_data": msg.Data,
})

if err := l.processWALEvent(ctx, msgData); err != nil {
if err := l.processWALEvent(ctx, msg); err != nil {
return err
}
}
}
}

func (l *Listener) processWALEvent(ctx context.Context, msgData *replication.MessageData) error {
func (l *Listener) processWALEvent(ctx context.Context, msg *replication.Message) error {
// if there's no data, it's a keep alive. If a reply is not requested,
// no need to process this message.
if msgData.Data == nil && !msgData.ReplyRequested {
if msg.Data == nil && !msg.ReplyRequested {
return nil
}

event := &wal.Event{}
if msgData.Data != nil {
if msg.Data != nil {
event.Data = &wal.Data{}
if err := l.walDataDeserialiser(msgData.Data, event.Data); err != nil {
if err := l.walDataDeserialiser(msg.Data, event.Data); err != nil {
return fmt.Errorf("error unmarshaling wal data: %w", err)
}
}
event.CommitPosition = wal.CommitPosition(l.lsnParser.ToString(msgData.LSN))
event.CommitPosition = wal.CommitPosition(l.lsnParser.ToString(msg.LSN))

return l.processEvent(ctx, event)
}
Loading