diff --git a/cmd/config.go b/cmd/config.go index b137b03..a2a030c 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -6,11 +6,11 @@ import ( "fmt" "github.com/spf13/viper" - "github.com/xataio/pgstream/internal/backoff" - "github.com/xataio/pgstream/internal/kafka" - "github.com/xataio/pgstream/internal/tls" + "github.com/xataio/pgstream/pkg/backoff" + "github.com/xataio/pgstream/pkg/kafka" pgschemalog "github.com/xataio/pgstream/pkg/schemalog/postgres" "github.com/xataio/pgstream/pkg/stream" + "github.com/xataio/pgstream/pkg/tls" kafkacheckpoint "github.com/xataio/pgstream/pkg/wal/checkpointer/kafka" kafkalistener "github.com/xataio/pgstream/pkg/wal/listener/kafka" kafkaprocessor "github.com/xataio/pgstream/pkg/wal/processor/kafka" @@ -169,7 +169,7 @@ func parseSearchProcessorConfig() *stream.SearchProcessorConfig { Store: opensearch.Config{ URL: searchStore, }, - Retrier: &search.StoreRetryConfig{ + Retrier: search.StoreRetryConfig{ Backoff: parseBackoffConfig("PGSTREAM_SEARCH_STORE"), }, } @@ -245,8 +245,8 @@ func parseTranslatorConfig() *translator.Config { } } -func parseTLSConfig(prefix string) *tls.Config { - return &tls.Config{ +func parseTLSConfig(prefix string) tls.Config { + return tls.Config{ Enabled: viper.GetBool(fmt.Sprintf("%s_TLS_ENABLED", prefix)), CaCertFile: viper.GetString(fmt.Sprintf("%s_TLS_CA_CERT_FILE", prefix)), ClientCertFile: viper.GetString(fmt.Sprintf("%s_TLS_CLIENT_CERT_FILE", prefix)), diff --git a/go.mod b/go.mod index 2c492c1..8fe3d2b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 - github.com/jackc/pgx/v5 v5.5.5 + github.com/jackc/pgx/v5 v5.6.0 github.com/labstack/echo/v4 v4.12.0 github.com/mitchellh/mapstructure v1.5.0 github.com/pterm/pterm v0.12.79 diff --git a/go.sum b/go.sum index 0576bfe..69e494d 100644 --- a/go.sum +++ b/go.sum @@ -123,8 +123,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= diff --git a/internal/tls/tls.go b/internal/tls/tls.go deleted file mode 100644 index 8351681..0000000 --- a/internal/tls/tls.go +++ /dev/null @@ -1,83 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package tls - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "io" - "os" -) - -type Config struct { - // Enabled determines if TLS should be used. Defaults to false. - Enabled bool - // File path to the CA PEM certificate to be used for TLS connection. If TLS is - // enabled and no CA cert file is provided, the system certificate pool is - // used as default. 
- CaCertFile string - // File path to the client PEM certificate - ClientCertFile string - // File path to the client PEM key - ClientKeyFile string -} - -func NewConfig(cfg *Config) (*tls.Config, error) { - if !cfg.Enabled { - return nil, nil - } - - certPool, err := getCertPool(cfg.CaCertFile) - if err != nil { - return nil, err - } - - certificates, err := getCertificates(cfg.ClientCertFile, cfg.ClientKeyFile) - if err != nil { - return nil, err - } - - return &tls.Config{ - MinVersion: tls.VersionTLS12, - MaxVersion: 0, - Certificates: certificates, - RootCAs: certPool, - }, nil -} - -func getCertPool(caCertFile string) (*x509.CertPool, error) { - if caCertFile != "" { - pemCertBytes, err := readFile(caCertFile) - if err != nil { - return nil, fmt.Errorf("reading CA certificate file: %w", err) - } - certPool := x509.NewCertPool() - certPool.AppendCertsFromPEM(pemCertBytes) - return certPool, nil - } - - return x509.SystemCertPool() -} - -func getCertificates(clientCertFile, clientKeyFile string) ([]tls.Certificate, error) { - if clientCertFile != "" && clientKeyFile != "" { - cert, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile) - if err != nil { - return nil, err - } - return []tls.Certificate{cert}, nil - } - - return []tls.Certificate{}, nil -} - -func readFile(path string) ([]byte, error) { - file, err := os.Open(path) - if err != nil { - return nil, err - } - defer file.Close() - - return io.ReadAll(file) -} diff --git a/internal/backoff/backoff.go b/pkg/backoff/backoff.go similarity index 100% rename from internal/backoff/backoff.go rename to pkg/backoff/backoff.go diff --git a/internal/backoff/mocks/mock_backoff.go b/pkg/backoff/mocks/mock_backoff.go similarity index 87% rename from internal/backoff/mocks/mock_backoff.go rename to pkg/backoff/mocks/mock_backoff.go index aef8da2..6a3be66 100644 --- a/internal/backoff/mocks/mock_backoff.go +++ b/pkg/backoff/mocks/mock_backoff.go @@ -2,7 +2,7 @@ package mocks -import "github.com/xataio/pgstream/internal/backoff" +import "github.com/xataio/pgstream/pkg/backoff" type Backoff struct { RetryNotifyFn func(backoff.Operation, backoff.Notify) error diff --git a/internal/kafka/config.go b/pkg/kafka/config.go similarity index 90% rename from internal/kafka/config.go rename to pkg/kafka/config.go index 37473de..55fc7fc 100644 --- a/internal/kafka/config.go +++ b/pkg/kafka/config.go @@ -2,12 +2,12 @@ package kafka -import tlslib "github.com/xataio/pgstream/internal/tls" +import tlslib "github.com/xataio/pgstream/pkg/tls" type ConnConfig struct { Servers []string Topic TopicConfig - TLS *tlslib.Config + TLS tlslib.Config } type TopicConfig struct { diff --git a/internal/kafka/conn.go b/pkg/kafka/conn.go similarity index 94% rename from internal/kafka/conn.go rename to pkg/kafka/conn.go index 36256df..df9ef78 100644 --- a/internal/kafka/conn.go +++ b/pkg/kafka/conn.go @@ -9,7 +9,7 @@ import ( "strconv" "time" - tlslib "github.com/xataio/pgstream/internal/tls" + tlslib "github.com/xataio/pgstream/pkg/tls" "github.com/segmentio/kafka-go" ) @@ -17,7 +17,7 @@ import ( // withConnection creates a connection that can be used by the kafka operation // passed in the parameters. This ensures the cleanup of all connection resources. 
func withConnection(config *ConnConfig, kafkaOperation func(conn *kafka.Conn) error) error { - dialer, err := buildDialer(config.TLS) + dialer, err := buildDialer(&config.TLS) if err != nil { return err } diff --git a/internal/kafka/instrumentation/instrumented_kafka_writer.go b/pkg/kafka/instrumentation/instrumented_kafka_writer.go similarity index 97% rename from internal/kafka/instrumentation/instrumented_kafka_writer.go rename to pkg/kafka/instrumentation/instrumented_kafka_writer.go index 5258db6..7467a1d 100644 --- a/internal/kafka/instrumentation/instrumented_kafka_writer.go +++ b/pkg/kafka/instrumentation/instrumented_kafka_writer.go @@ -7,7 +7,7 @@ import ( "fmt" "time" - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/kafka" "go.opentelemetry.io/otel/metric" ) diff --git a/internal/kafka/kafka_offset_parser.go b/pkg/kafka/kafka_offset_parser.go similarity index 100% rename from internal/kafka/kafka_offset_parser.go rename to pkg/kafka/kafka_offset_parser.go diff --git a/internal/kafka/kafka_offset_parser_test.go b/pkg/kafka/kafka_offset_parser_test.go similarity index 100% rename from internal/kafka/kafka_offset_parser_test.go rename to pkg/kafka/kafka_offset_parser_test.go diff --git a/internal/kafka/kafka_reader.go b/pkg/kafka/kafka_reader.go similarity index 98% rename from internal/kafka/kafka_reader.go rename to pkg/kafka/kafka_reader.go index 0916cb3..609337e 100644 --- a/internal/kafka/kafka_reader.go +++ b/pkg/kafka/kafka_reader.go @@ -44,7 +44,7 @@ func NewReader(config ReaderConfig, logger loglib.Logger) (*Reader, error) { return nil, fmt.Errorf("unsupported start offset [%s], must be one of [%s, %s]", config.ConsumerGroupStartOffset, earliestOffset, latestOffset) } - dialer, err := buildDialer(config.Conn.TLS) + dialer, err := buildDialer(&config.Conn.TLS) if err != nil { return nil, err } diff --git a/internal/kafka/kafka_writer.go b/pkg/kafka/kafka_writer.go similarity index 96% rename from internal/kafka/kafka_writer.go rename to pkg/kafka/kafka_writer.go index 66513af..04ba0f4 100644 --- a/internal/kafka/kafka_writer.go +++ b/pkg/kafka/kafka_writer.go @@ -8,8 +8,8 @@ import ( "time" "github.com/segmentio/kafka-go" - tlslib "github.com/xataio/pgstream/internal/tls" loglib "github.com/xataio/pgstream/pkg/log" + tlslib "github.com/xataio/pgstream/pkg/tls" ) type MessageWriter interface { @@ -56,7 +56,7 @@ func NewWriter(config WriterConfig, logger loglib.Logger) (*Writer, error) { } } - transport, err := buildTransport(config.Conn.TLS) + transport, err := buildTransport(&config.Conn.TLS) if err != nil { return nil, err } diff --git a/internal/kafka/log.go b/pkg/kafka/log.go similarity index 100% rename from internal/kafka/log.go rename to pkg/kafka/log.go diff --git a/internal/kafka/mocks/mock_kafka_parser.go b/pkg/kafka/mocks/mock_kafka_parser.go similarity index 87% rename from internal/kafka/mocks/mock_kafka_parser.go rename to pkg/kafka/mocks/mock_kafka_parser.go index bb5fa38..d1d314e 100644 --- a/internal/kafka/mocks/mock_kafka_parser.go +++ b/pkg/kafka/mocks/mock_kafka_parser.go @@ -2,7 +2,7 @@ package mocks -import "github.com/xataio/pgstream/internal/kafka" +import "github.com/xataio/pgstream/pkg/kafka" type OffsetParser struct { ToStringFn func(o *kafka.Offset) string diff --git a/internal/kafka/mocks/mock_kafka_reader.go b/pkg/kafka/mocks/mock_kafka_reader.go similarity index 92% rename from internal/kafka/mocks/mock_kafka_reader.go rename to pkg/kafka/mocks/mock_kafka_reader.go index 62bc746..63f8149 100644 --- 
a/internal/kafka/mocks/mock_kafka_reader.go +++ b/pkg/kafka/mocks/mock_kafka_reader.go @@ -5,7 +5,7 @@ package mocks import ( "context" - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/kafka" ) type Reader struct { diff --git a/internal/kafka/mocks/mock_kafka_writer.go b/pkg/kafka/mocks/mock_kafka_writer.go similarity index 92% rename from internal/kafka/mocks/mock_kafka_writer.go rename to pkg/kafka/mocks/mock_kafka_writer.go index f366988..4e0c233 100644 --- a/internal/kafka/mocks/mock_kafka_writer.go +++ b/pkg/kafka/mocks/mock_kafka_writer.go @@ -6,7 +6,7 @@ import ( "context" "sync/atomic" - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/kafka" ) type Writer struct { diff --git a/pkg/schemalog/log_entry.go b/pkg/schemalog/log_entry.go index e7389e7..ae5cb24 100644 --- a/pkg/schemalog/log_entry.go +++ b/pkg/schemalog/log_entry.go @@ -64,8 +64,6 @@ func (m *LogEntry) UnmarshalJSON(b []byte) error { if err := json.Unmarshal([]byte(schemaStr), &m.Schema); err != nil { return err } - default: - panic(fmt.Sprintf("unmarshal LogEntry, got unexpected key when unmarshalling: %s", k)) } } diff --git a/pkg/stream/config.go b/pkg/stream/config.go index 925e1ba..653fad1 100644 --- a/pkg/stream/config.go +++ b/pkg/stream/config.go @@ -50,7 +50,7 @@ type KafkaProcessorConfig struct { type SearchProcessorConfig struct { Indexer search.IndexerConfig Store opensearch.Config - Retrier *search.StoreRetryConfig + Retrier search.StoreRetryConfig } type WebhookProcessorConfig struct { diff --git a/pkg/stream/integration/helper_test.go b/pkg/stream/integration/helper_test.go index 1419268..586ce4c 100644 --- a/pkg/stream/integration/helper_test.go +++ b/pkg/stream/integration/helper_test.go @@ -10,13 +10,13 @@ import ( "testing" "github.com/stretchr/testify/require" - kafkalib "github.com/xataio/pgstream/internal/kafka" "github.com/xataio/pgstream/internal/log/zerolog" pglib "github.com/xataio/pgstream/internal/postgres" - "github.com/xataio/pgstream/internal/tls" + kafkalib "github.com/xataio/pgstream/pkg/kafka" loglib "github.com/xataio/pgstream/pkg/log" schemalogpg "github.com/xataio/pgstream/pkg/schemalog/postgres" "github.com/xataio/pgstream/pkg/stream" + "github.com/xataio/pgstream/pkg/tls" "github.com/xataio/pgstream/pkg/wal" kafkacheckpoint "github.com/xataio/pgstream/pkg/wal/checkpointer/kafka" kafkalistener "github.com/xataio/pgstream/pkg/wal/listener/kafka" @@ -174,7 +174,7 @@ func testKafkaCfg() kafkalib.ConnConfig { Name: "integration-tests", AutoCreate: true, }, - TLS: &tls.Config{ + TLS: tls.Config{ Enabled: false, }, } diff --git a/pkg/stream/stream_run.go b/pkg/stream/stream_run.go index 7fc82e6..71c376f 100644 --- a/pkg/stream/stream_run.go +++ b/pkg/stream/stream_run.go @@ -112,10 +112,7 @@ func Run(ctx context.Context, logger loglib.Logger, config *Config, meter metric if err != nil { return err } - if config.Processor.Search.Retrier != nil { - logger.Debug("using retry logic with search store...") - searchStore = search.NewStoreRetrier(searchStore, config.Processor.Search.Retrier, search.WithStoreLogger(logger)) - } + searchStore = search.NewStoreRetrier(searchStore, config.Processor.Search.Retrier, search.WithStoreLogger(logger)) searchIndexer := search.NewBatchIndexer(ctx, config.Processor.Search.Indexer, diff --git a/pkg/tls/test/test.csr b/pkg/tls/test/test.csr new file mode 100644 index 0000000..159d465 --- /dev/null +++ b/pkg/tls/test/test.csr @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE REQUEST----- 
+MIIEjTCCAnUCAQAwSDELMAkGA1UEBhMCU1AxEzARBgNVBAgMClNvbWUtU3RhdGUx +FTATBgNVBAoMDFBnc3RyZWFtIEx0ZDENMAsGA1UECwwEWGF0YTCCAiIwDQYJKoZI +hvcNAQEBBQADggIPADCCAgoCggIBAMZG8/obpyvJ+WkGkdO/hbExSN1nWR206/Dh +pSYzcZyI1Jj0R4Af0gD/EFVM+4KTDr20nOofmfBWOYHV+KwiKtWQQ+oT0+xVcTT6 +IC5I5K9+AXERuTu/NbnjkxuC/1u7K511RrUK0lxUra2/B8mGTc9nu2g415GVk2hU +rJjWEX09hVH7xBSmnzYN+IfepsftxgnR5m2YzqOSnsphBBfyyOsL+3Jo5Uv4yY22 +bnJCDxx/TPG37EcGMb4Q/aCWk5mXm4Io5mBfcl7SNxy867JpBO2CCP6fFaWRUGmF +/O+YBD5z0cSb1wZrMBRgezggpa2gacYVtsWrQYzAxMtCf3MwM89z8i+W9ME3cMhg +T7b6T0XCQ5gkqmLxDCsg9ocNV0W0wb5EEPlt4/TeqMmLbpLkIa6rnR2C6gg+wFSQ +Vk8c/aBtm9BjKFLWWUwDUPPzp1RdIVAfuobDK37dVqr+1m6oTDl+BoDuM5PsgTS4 +bFW2ZJVZB+d4IFmGhqTtKKmYa7QbwhzF8i5ShV+419KOt6hRkJ/jREdea7ZS/uyK +wBAvPHZVkUk8tfTpcuKTFVKsgXXV8uwwnI2W6safYbTXaN+7gfA88wntaigjeij/ +LrO+itAiv9GTGkpMlXyuAX/d5+4j0EHcV64NYL61GvMcbJ5G0SeQrXuBlgMTYaNr +O7miEujnAgMBAAGgADANBgkqhkiG9w0BAQsFAAOCAgEAGYFGSgMARWuH5VXK0/Fl +nNg+/rzdff1NYkY9QrvFxCQeVJ9rD1ml7VLZXDtXhMNEGJyYbkouc1Ehx0BsihT6 +YztcnQ6TzWAzvr3Ns9X3riADzXxdDHV5xs+8VPV8RvT3XNcrlw2NQmzJ4Juc8PkT +4ZfguZBywAmFTw1oX8JqlQSp5pYtP7popsvGPS6ieUm0Kmv8kK3sRDs+JSc7iXtB +/HymqeSylFNHgFZsdbYmu32v2qbcqimAitB/v5tGNhuiMXx6vEeQnB69V+AV70Rl +9dnvAo7ihTRMzecUVsoDFtc8OWSPdTm6t8vDI2JqmeDN25Xhyf89cKwoY97NhA99 +ds5WHs6TzsHohPZAsaxtZnjwxMEne7Y4FLFmTHVk5o0POTZcOC/sMB1iBDsd/YJe +AYsoiqLYtu6x6Avfe6LXWYWYa/R4/UXh8H6WiChsFXzOIilp3apjaeHM3z7iKx2S +VtGyVTrrcbzRiF0ShKVnbDXnvcoNZxPiXfh6Zz4SkBbV01T3hluBwzp4mjcWpiv3 +AOAWChMnbmkg/T+OME6e1JVHDR5tAC/7vF2QkZYpiH2RVnZmCTDWBcRGpMkhkRgF +eycowzKBkgIOcJ99p0sGEqQ3W0J1M4bzuumncLID08EG/dEp1eIdunahcHHyhnnv +BcGFr2/OxuaVmxcy5/QQjAg= +-----END CERTIFICATE REQUEST----- diff --git a/pkg/tls/test/test.key b/pkg/tls/test/test.key new file mode 100644 index 0000000..df3f707 --- /dev/null +++ b/pkg/tls/test/test.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQDGRvP6G6cryflp +BpHTv4WxMUjdZ1kdtOvw4aUmM3GciNSY9EeAH9IA/xBVTPuCkw69tJzqH5nwVjmB +1fisIirVkEPqE9PsVXE0+iAuSOSvfgFxEbk7vzW545Mbgv9buyuddUa1CtJcVK2t +vwfJhk3PZ7toONeRlZNoVKyY1hF9PYVR+8QUpp82DfiH3qbH7cYJ0eZtmM6jkp7K +YQQX8sjrC/tyaOVL+MmNtm5yQg8cf0zxt+xHBjG+EP2glpOZl5uCKOZgX3Je0jcc +vOuyaQTtggj+nxWlkVBphfzvmAQ+c9HEm9cGazAUYHs4IKWtoGnGFbbFq0GMwMTL +Qn9zMDPPc/IvlvTBN3DIYE+2+k9FwkOYJKpi8QwrIPaHDVdFtMG+RBD5beP03qjJ +i26S5CGuq50dguoIPsBUkFZPHP2gbZvQYyhS1llMA1Dz86dUXSFQH7qGwyt+3Vaq +/tZuqEw5fgaA7jOT7IE0uGxVtmSVWQfneCBZhoak7SipmGu0G8IcxfIuUoVfuNfS +jreoUZCf40RHXmu2Uv7sisAQLzx2VZFJPLX06XLikxVSrIF11fLsMJyNlurGn2G0 +12jfu4HwPPMJ7WooI3oo/y6zvorQIr/RkxpKTJV8rgF/3efuI9BB3FeuDWC+tRrz +HGyeRtEnkK17gZYDE2Gjazu5ohLo5wIDAQABAoICAFEXJaMNejIzeViVwkA6nP/Z +6zX5lX3Lx48NidB0y6s8Xs5rYW6qFOYpatGoGVjOsgGuA1rRL9EWQpCyJPCpTKFp +Tg1GrK6ERzdmcJDdaQHI4+gNWpdv3RY4V6qxyaQHiY/tLczPLzdpvlpHvXSTA/Gm +OAQo8yjsZowNzUT4j9CLv6HG+OuFNaoSzqkqy0ULHqpXeQkrrJ9DUMPuJ5FvzvIq +RV0GP3jxt+TITqVWFP4PpjVZhj2J8AAOzNvHmXgAhC4YchfKEWlsSfPr4+1kfApy +2yDfiSfcpWlyzf5jSqEMFyd0oN1UKya6Ssqqt3eqGnhT2xs+riFVmWaTvLIsbZNb +ML2OT7s0nqSkhkomVcQxgh+rypgljPB2JhHNSScKF3BYlIBFX9GarZ32vcLdBb/Z +u1GOiu6Fd5uBVfS0PTwX/8/NwCOmofUTxUa3J8oLcxBZRe8y0sxnfdsbKbXLviym +okOUyk6h4KaTXB2+99Q7TSHGGXissCeTCU8tIMEKFHlN3PQeSphEcHvhIBKTPdK3 +vCbzfpoMWzyTBvSdffQRewSwij3fImPrWsE9s15jjDW3AwLjuNtPCYg3Amu0yK0k +ktAnJgruoWhnrxJae9iXPUuZghPvweDNnhtQb5Uf3jeYKwL2fLE3eaVtjqiWha5B +ziDK7ZelBH85IEP6COWxAoIBAQDpHIZ/aFl49AV0BhCDv2+28OlR1s6p4QZmX5mb +pQfON868Ya7iR+tPjzN71IJfga02QiaaXOUeh0FxCJMqdypSLW6XOIlppHLO8jmV +6RztF/iA1h1+fo3DMuuL2Nw3ivo4NhBvSoyfAfSI+Nn5eJl2Zqhason5aWSZsxrN +9A4Kcig8RS3aE3zQKCsT/ZBeM7AT+qqtGNceRh9pLgYkSnaugDX4JwD+f5VhUQOd 
+ivB9E4xHZJJN61iefok1SkMekQ+GF8WQRCr8dKbQAgxh5Eu84J7iax4EvFsoHuIs +9Wez0Yiy14ArW2+tkA+2GE0PBRvOPfjtVZ3n3lNOCjbLozWfAoIBAQDZvtXst+CS +lQJjkDenb51ujAum9JaxbKoU2kOEEEeN8AYgvmIqpSsgaMmzYNtW1yVT+O2z0Ix7 +I6lG2cNz5XuWWJkBroAwPeHBZjOwyR9AdglYenIui6dftTdMDVnYbt/Fu905yh3t +fGtT8IRGond3c+WZZCgbr6sO+nq4PpG+i+nuyqOo0JJCLtgWyyK9jrfEhk2fNZyK +B7uaKwOhMHetkfnp5jyszGYsU2YU++NGL0vcacXUbZyYV4BT7+s3m+H0G8wf4qSP +FAkX+sloEl5HRWDB4SP0H9erKg8zoKDI7DVtv2DJLnLMB88XjOz3l24/HOVzyI+N +X3SfSQP55De5AoIBACpwGgA55AgEDLYZoIoLoO/iHefbPlZo8/xRLSrLuYcOW+Gp +uufRBgK+5DWH85AlkH4PPu3dOYz8PKqyT/BsL1U0liyLi2CjIo+QQ3GKNczoD0KN +OGNd8Lr3mzAjc7vc3j67gPRx0vXjqjwBadVj4jRO7hlM5Zd1W24r0BZsdt3p+G84 +fOd1osRWe7kw8UZlDIommUnX+tm1FGTWjyGuOLr99lVN7H1ohq5nzEuzDqMGmwQo +SAZNcR2xlZMRCPUYnYXg8AOalWTOa8v0g4KSyEMDdYlszNM54zKDpNNgfdebrtI4 +L0o1ZDhpwKJ6/BRe7rf2SkoSyyN6MxpC+8TI2qsCggEAbN/X1VoHpyNso13cBhNw +E3Ng7CUGKEbeMDkGY0VEkfr/BWZMbWhSzQy4NcHrSlufJYKlUDCp3XRyUqPV7+BB +0GYSc13OaNC4TdyNYgreXnvmpl/rMczQbrGMqbFPSEIAD72kmx2toy5/9+OeMDdS +Jt9DYVRMHbPTg1TJAdD/TNhmquiVtnY7e24yzArcHw36YwCIVWAYGohNTIPPd8xl +Ottvq31cv0YgnG9C7qEX/eLuOpKEwXfhQecWmmGvKgn+i/FOOm83uvbYqS3TgP8W +NurAu5CYSpuVWddY7IaXfn9lI6/6c/2Olugcq3jij9Ye4N3Q+PjClnyxMmfu3gc3 +uQKCAQAmpjnO3OukXxgEtDusQKh/IxsvKQMfSQNh68cQ6iroGDAguYBusSTYU1gA +IKvR6Uqqv52yPe6u5pYC0fA8L+2S3nuFA4f4a1JHdpf5X5HwBtTERG4Qxh7p1gss +3AZpLfYa442d8UuCCYoimBXqXXF0TLsfoRjcgrKd9yNO6Pa79jzUR+ixQmkyE7XA +XGHx7Qsl6E0E1DgK4MHPOALg/tJiLNJQgIKDtiBn4GTR8tSHRDY64rEuOClq14y1 +cpXNj7lcV1xz1vLNRcksS0/QA7Hzol+Dt8xVOBK1smhCREMetUMXHwBYLnNemxNX +/WF38bnN3/Dej6Ns6FRE8gNSRmeT +-----END PRIVATE KEY----- diff --git a/pkg/tls/test/test.pem b/pkg/tls/test/test.pem new file mode 100644 index 0000000..75f10e2 --- /dev/null +++ b/pkg/tls/test/test.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFPzCCAyegAwIBAgIULV6zejFwt/Tri8WKZFBxj15uvaIwDQYJKoZIhvcNAQEL +BQAwSDELMAkGA1UEBhMCU1AxEzARBgNVBAgMClNvbWUtU3RhdGUxFTATBgNVBAoM +DFBnc3RyZWFtIEx0ZDENMAsGA1UECwwEWGF0YTAeFw0yNDA4MDUxNTE3NDBaFw0y +NTA4MDUxNTE3NDBaMEgxCzAJBgNVBAYTAlNQMRMwEQYDVQQIDApTb21lLVN0YXRl +MRUwEwYDVQQKDAxQZ3N0cmVhbSBMdGQxDTALBgNVBAsMBFhhdGEwggIiMA0GCSqG +SIb3DQEBAQUAA4ICDwAwggIKAoICAQDGRvP6G6cryflpBpHTv4WxMUjdZ1kdtOvw +4aUmM3GciNSY9EeAH9IA/xBVTPuCkw69tJzqH5nwVjmB1fisIirVkEPqE9PsVXE0 ++iAuSOSvfgFxEbk7vzW545Mbgv9buyuddUa1CtJcVK2tvwfJhk3PZ7toONeRlZNo +VKyY1hF9PYVR+8QUpp82DfiH3qbH7cYJ0eZtmM6jkp7KYQQX8sjrC/tyaOVL+MmN +tm5yQg8cf0zxt+xHBjG+EP2glpOZl5uCKOZgX3Je0jccvOuyaQTtggj+nxWlkVBp +hfzvmAQ+c9HEm9cGazAUYHs4IKWtoGnGFbbFq0GMwMTLQn9zMDPPc/IvlvTBN3DI +YE+2+k9FwkOYJKpi8QwrIPaHDVdFtMG+RBD5beP03qjJi26S5CGuq50dguoIPsBU +kFZPHP2gbZvQYyhS1llMA1Dz86dUXSFQH7qGwyt+3Vaq/tZuqEw5fgaA7jOT7IE0 +uGxVtmSVWQfneCBZhoak7SipmGu0G8IcxfIuUoVfuNfSjreoUZCf40RHXmu2Uv7s +isAQLzx2VZFJPLX06XLikxVSrIF11fLsMJyNlurGn2G012jfu4HwPPMJ7WooI3oo +/y6zvorQIr/RkxpKTJV8rgF/3efuI9BB3FeuDWC+tRrzHGyeRtEnkK17gZYDE2Gj +azu5ohLo5wIDAQABoyEwHzAdBgNVHQ4EFgQUcj2UaSuSsgkex5hfS0eDAkPdbJ0w +DQYJKoZIhvcNAQELBQADggIBAJK6fMa4L7iIQSlPzG3pHTSSLQd9Unev2naX9/S1 +Yo55Tj9VCBhViGa7CbDtaW7ZYr/fXydZVcthXYZzZ7QEVyYaguWlzXLjy/qF8kgk +cDwinFa8hiJnP+BJUGnzq3LYQJ2labI4YUscc6p4inh9y8JZ3n33VqX2YjqCdHMA +j8nw5xpThdQ/a8z3Z8ugFCLO09Hts1eKFhs5PwaQvjkoX+dSE2FeX51OMlLOPDsu +C6ScDU7FG0J5JE36nRqp2XwSdGAfc5pHKmsuomxnoE/d/hL7O6zouo/jvQyCNFtn +5/pzhkhhOjUTP2gIW5ueNn8oQF9F32GWRNJGQVTBiK17dWvHxiSvIzgKqUyrD8lI +VefVEQgRbfHD3nSk6G30gAeWzt8T10lI8MtQWTtoFJGFBaSVr/lSyHo4QS4SyTmK +uvnFGJVivRtaAP4d+u/6/1Mvy5sNsSiWRRKfKTB/FEerbe7blhnqJhrBp98nKBv/ +IJtmewD7lVGGDY8sWnxnpNyqLVhvRilO9d+4oQWqKgN8m1PXAI2jDYA0RZ+Qs/ko +5FB88mRi8hNOhmADXguKlnCid/X0StK6wpphvFIaNnGLFzjeZXc65BoV1c2+Boxe 
+cNwpkrW5tKaf/Ox1ntHnnQBpUhM4AxGoczfIj0dEnYio54gagsAMK/Pjq2KiX7Bi +RUm4 +-----END CERTIFICATE----- diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go new file mode 100644 index 0000000..8735e0d --- /dev/null +++ b/pkg/tls/tls.go @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: Apache-2.0 + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" +) + +type Config struct { + // Enabled determines if TLS should be used. Defaults to false. + Enabled bool + // File path to the CA PEM certificate or PEM certificate to be used for TLS + // connection. If TLS is enabled and no CA cert file is provided, the system + // certificate pool is used as default. + CaCertFile string + CaCertPEM string + // File path to the client PEM certificate or client PEM certificate + ClientCertFile string + ClientCertPEM string + // File path to the client PEM key or client PEM key content + ClientKeyFile string + ClientKeyPEM string +} + +func NewConfig(cfg *Config) (*tls.Config, error) { + if !cfg.Enabled { + return nil, nil + } + + certPool, err := getCertPool(cfg) + if err != nil { + return nil, err + } + + certificates, err := getCertificates(cfg) + if err != nil { + return nil, err + } + + return &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: 0, + Certificates: certificates, + RootCAs: certPool, + }, nil +} + +func getCertPool(cfg *Config) (*x509.CertPool, error) { + pemCertBytes, err := readPEMBytes(cfg.CaCertFile, cfg.CaCertPEM) + if err != nil { + return nil, fmt.Errorf("reading CA certificate file: %w", err) + } + + if len(pemCertBytes) == 0 { + return x509.SystemCertPool() + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(pemCertBytes) + return certPool, nil +} + +func getCertificates(cfg *Config) ([]tls.Certificate, error) { + if cfg.IsClientCertProvided() { + pemCertBytes, err := readPEMBytes(cfg.ClientCertFile, cfg.ClientCertPEM) + if err != nil { + return nil, err + } + pemKeyBytes, err := readPEMBytes(cfg.ClientKeyFile, cfg.ClientKeyPEM) + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(pemCertBytes, pemKeyBytes) + if err != nil { + return nil, err + } + return []tls.Certificate{cert}, nil + } + + return []tls.Certificate{}, nil +} + +// readPEMBytes will parse the certificate on input and return the pem byte +// content. It accepts a pem certificate or the file path to a pem certificate. 
+func readPEMBytes(certFile, certPEM string) ([]byte, error) { + if certFile != "" { + return os.ReadFile(certFile) + } + + return []byte(certPEM), nil +} + +func (c *Config) IsClientCertProvided() bool { + return (c.ClientCertFile != "" || c.ClientCertPEM != "") && (c.ClientKeyFile != "" || c.ClientKeyPEM != "") +} diff --git a/pkg/tls/tls_test.go b/pkg/tls/tls_test.go new file mode 100644 index 0000000..53e79aa --- /dev/null +++ b/pkg/tls/tls_test.go @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: Apache-2.0 + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" +) + +func Test_NewConfig(t *testing.T) { + t.Parallel() + + systemCAs, err := x509.SystemCertPool() + require.NoError(t, err) + + testPEMBytes, err := os.ReadFile("test/test.pem") + require.NoError(t, err) + testCertPool := x509.NewCertPool() + testCertPool.AppendCertsFromPEM(testPEMBytes) + + testClientKeyPair, err := tls.LoadX509KeyPair("test/test.pem", "test/test.key") + require.NoError(t, err) + + tests := []struct { + name string + cfg *Config + + wantConfig *tls.Config + wantErr error + }{ + { + name: "ok - tls not enabled", + cfg: &Config{ + Enabled: false, + }, + + wantConfig: nil, + wantErr: nil, + }, + { + name: "ok - tls enabled no certificates", + cfg: &Config{ + Enabled: true, + }, + + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: 0, + Certificates: []tls.Certificate{}, + RootCAs: systemCAs, + }, + wantErr: nil, + }, + { + name: "ok - tls enabled with CA certificate", + cfg: &Config{ + Enabled: true, + CaCertFile: "test/test.pem", + }, + + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: 0, + Certificates: []tls.Certificate{}, + RootCAs: testCertPool, + }, + wantErr: nil, + }, + { + name: "ok - tls enabled with client certificate", + cfg: &Config{ + Enabled: true, + ClientCertFile: "test/test.pem", + ClientKeyFile: "test/test.key", + }, + + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: 0, + Certificates: []tls.Certificate{testClientKeyPair}, + RootCAs: systemCAs, + }, + wantErr: nil, + }, + { + name: "ok - tls enabled with CA and client certificate", + cfg: &Config{ + Enabled: true, + CaCertFile: "test/test.pem", + ClientCertFile: "test/test.pem", + ClientKeyFile: "test/test.key", + }, + + wantConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: 0, + Certificates: []tls.Certificate{testClientKeyPair}, + RootCAs: testCertPool, + }, + wantErr: nil, + }, + { + name: "error - invalid CA certificate file", + cfg: &Config{ + Enabled: true, + CaCertFile: "test/doesnotexist.pem", + }, + + wantConfig: nil, + wantErr: os.ErrNotExist, + }, + { + name: "error - invalid client certificate file", + cfg: &Config{ + Enabled: true, + ClientCertFile: "test/doesnotexist.pem", + ClientKeyFile: "test/test.pem", + }, + + wantConfig: nil, + wantErr: os.ErrNotExist, + }, + { + name: "error - invalid client key file", + cfg: &Config{ + Enabled: true, + ClientCertFile: "test/test.pem", + ClientKeyFile: "test/doesnotexist.pem", + }, + + wantConfig: nil, + wantErr: os.ErrNotExist, + }, + { + name: "error - invalid client key pair", + cfg: &Config{ + Enabled: true, + ClientCertFile: "test/test.pem", + ClientKeyFile: "test/test.pem", + }, + + wantConfig: nil, + wantErr: errors.New("tls: found a certificate rather than a key in the PEM for the private key"), + }, + } + + for _, tc := range tests { + tc := tc + 
t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tlsCfg, err := NewConfig(tc.cfg) + if !errors.Is(err, tc.wantErr) { + require.EqualError(t, err, tc.wantErr.Error()) + } + require.Equal(t, "", cmp.Diff(tlsCfg, tc.wantConfig, cmpopts.IgnoreUnexported(tls.Config{}))) //nolint:gosec + }) + } +} + +func Test_readPEMBytes(t *testing.T) { + t.Parallel() + + testPEMBytes, err := os.ReadFile("test/test.pem") + require.NoError(t, err) + + tests := []struct { + name string + file string + pem string + + wantBytes []byte + wantErr error + }{ + { + name: "with file", + file: "test/test.pem", + + wantBytes: testPEMBytes, + wantErr: nil, + }, + { + name: "with pem", + pem: string(testPEMBytes), + + wantBytes: testPEMBytes, + wantErr: nil, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + pemBytes, err := readPEMBytes(tc.file, tc.pem) + require.ErrorIs(t, err, tc.wantErr) + require.Equal(t, tc.wantBytes, pemBytes) + }) + } +} diff --git a/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer.go b/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer.go index 1ee144c..1a9e5e8 100644 --- a/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer.go +++ b/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer.go @@ -7,8 +7,8 @@ import ( "fmt" "time" - "github.com/xataio/pgstream/internal/backoff" - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/backoff" + "github.com/xataio/pgstream/pkg/kafka" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal" ) diff --git a/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer_test.go b/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer_test.go index f18f688..f216349 100644 --- a/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer_test.go +++ b/pkg/wal/checkpointer/kafka/wal_kafka_checkpointer_test.go @@ -10,10 +10,10 @@ import ( "github.com/stretchr/testify/require" - "github.com/xataio/pgstream/internal/backoff" - backoffmocks "github.com/xataio/pgstream/internal/backoff/mocks" - "github.com/xataio/pgstream/internal/kafka" - kafkamocks "github.com/xataio/pgstream/internal/kafka/mocks" + "github.com/xataio/pgstream/pkg/backoff" + backoffmocks "github.com/xataio/pgstream/pkg/backoff/mocks" + "github.com/xataio/pgstream/pkg/kafka" + kafkamocks "github.com/xataio/pgstream/pkg/kafka/mocks" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal" ) diff --git a/pkg/wal/listener/kafka/wal_kafka_reader.go b/pkg/wal/listener/kafka/wal_kafka_reader.go index a5aa0d7..1c8ff90 100644 --- a/pkg/wal/listener/kafka/wal_kafka_reader.go +++ b/pkg/wal/listener/kafka/wal_kafka_reader.go @@ -9,7 +9,7 @@ import ( "fmt" "runtime/debug" - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/kafka" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal" ) diff --git a/pkg/wal/listener/kafka/wal_kafka_reader_test.go b/pkg/wal/listener/kafka/wal_kafka_reader_test.go index 86e8f50..954d602 100644 --- a/pkg/wal/listener/kafka/wal_kafka_reader_test.go +++ b/pkg/wal/listener/kafka/wal_kafka_reader_test.go @@ -11,8 +11,8 @@ import ( "time" "github.com/stretchr/testify/require" - "github.com/xataio/pgstream/internal/kafka" - kafkamocks "github.com/xataio/pgstream/internal/kafka/mocks" + "github.com/xataio/pgstream/pkg/kafka" + kafkamocks "github.com/xataio/pgstream/pkg/kafka/mocks" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal" ) diff --git a/pkg/wal/processor/kafka/config.go 
b/pkg/wal/processor/kafka/config.go index 3639fd7..b57f27e 100644 --- a/pkg/wal/processor/kafka/config.go +++ b/pkg/wal/processor/kafka/config.go @@ -6,7 +6,7 @@ import ( "errors" "time" - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/kafka" ) type Config struct { diff --git a/pkg/wal/processor/kafka/wal_kafka_batch_writer.go b/pkg/wal/processor/kafka/wal_kafka_batch_writer.go index 898c563..5d2d8a7 100644 --- a/pkg/wal/processor/kafka/wal_kafka_batch_writer.go +++ b/pkg/wal/processor/kafka/wal_kafka_batch_writer.go @@ -10,9 +10,9 @@ import ( "runtime/debug" "time" - "github.com/xataio/pgstream/internal/kafka" - kafkainstrumentation "github.com/xataio/pgstream/internal/kafka/instrumentation" synclib "github.com/xataio/pgstream/internal/sync" + "github.com/xataio/pgstream/pkg/kafka" + kafkainstrumentation "github.com/xataio/pgstream/pkg/kafka/instrumentation" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/wal" "github.com/xataio/pgstream/pkg/wal/checkpointer" diff --git a/pkg/wal/processor/kafka/wal_kafka_batch_writer_test.go b/pkg/wal/processor/kafka/wal_kafka_batch_writer_test.go index 32eeb9f..dc9a3cc 100644 --- a/pkg/wal/processor/kafka/wal_kafka_batch_writer_test.go +++ b/pkg/wal/processor/kafka/wal_kafka_batch_writer_test.go @@ -12,10 +12,10 @@ import ( "time" "github.com/stretchr/testify/require" - "github.com/xataio/pgstream/internal/kafka" - kafkamocks "github.com/xataio/pgstream/internal/kafka/mocks" synclib "github.com/xataio/pgstream/internal/sync" syncmocks "github.com/xataio/pgstream/internal/sync/mocks" + "github.com/xataio/pgstream/pkg/kafka" + kafkamocks "github.com/xataio/pgstream/pkg/kafka/mocks" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/schemalog" "github.com/xataio/pgstream/pkg/wal" diff --git a/pkg/wal/processor/kafka/wal_kafka_msg_batch.go b/pkg/wal/processor/kafka/wal_kafka_msg_batch.go index ad35dea..492a1bd 100644 --- a/pkg/wal/processor/kafka/wal_kafka_msg_batch.go +++ b/pkg/wal/processor/kafka/wal_kafka_msg_batch.go @@ -3,7 +3,7 @@ package kafka import ( - "github.com/xataio/pgstream/internal/kafka" + "github.com/xataio/pgstream/pkg/kafka" "github.com/xataio/pgstream/pkg/wal" ) diff --git a/pkg/wal/processor/search/config.go b/pkg/wal/processor/search/config.go index 4b0a6d0..16db131 100644 --- a/pkg/wal/processor/search/config.go +++ b/pkg/wal/processor/search/config.go @@ -5,7 +5,7 @@ package search import ( "time" - "github.com/xataio/pgstream/internal/backoff" + "github.com/xataio/pgstream/pkg/backoff" ) type IndexerConfig struct { diff --git a/pkg/wal/processor/search/opensearch/opensearch_adapter.go b/pkg/wal/processor/search/opensearch/opensearch_adapter.go index 31f32d7..30c8777 100644 --- a/pkg/wal/processor/search/opensearch/opensearch_adapter.go +++ b/pkg/wal/processor/search/opensearch/opensearch_adapter.go @@ -5,7 +5,6 @@ package opensearch import ( "encoding/json" "fmt" - "strings" "github.com/xataio/pgstream/internal/es" "github.com/xataio/pgstream/pkg/schemalog" @@ -13,36 +12,28 @@ import ( ) // Adapter converts from/to search types and opensearch types -type Adapter interface { - SchemaNameToIndex(schemaName string) IndexName - IndexToSchemaName(index string) string +type SearchAdapter interface { SearchDocToBulkItem(docs search.Document) es.BulkItem BulkItemsToSearchDocErrs(items []es.BulkItem) []search.DocumentError RecordToLogEntry(rec map[string]any) (*schemalog.LogEntry, error) } type adapter struct { - marshaler func(any) ([]byte, error) - 
unmarshaler func([]byte, any) error + indexNameAdapter IndexNameAdapter + marshaler func(any) ([]byte, error) + unmarshaler func([]byte, any) error } -func newDefaultAdapter() *adapter { +func newDefaultAdapter(indexNameAdapter IndexNameAdapter) *adapter { return &adapter{ - marshaler: json.Marshal, - unmarshaler: json.Unmarshal, + indexNameAdapter: indexNameAdapter, + marshaler: json.Marshal, + unmarshaler: json.Unmarshal, } } -func (a *adapter) SchemaNameToIndex(schemaName string) IndexName { - return newDefaultIndexName(schemaName) -} - -func (a *adapter) IndexToSchemaName(index string) string { - return strings.TrimSuffix(index, "-1") -} - func (a *adapter) SearchDocToBulkItem(doc search.Document) es.BulkItem { - indexName := a.SchemaNameToIndex(doc.Schema) + indexName := a.indexNameAdapter.SchemaNameToIndex(doc.Schema) item := es.BulkItem{ Doc: doc.Data, } @@ -94,13 +85,13 @@ func (a *adapter) bulkItemToSearchDocErr(item es.BulkItem) search.DocumentError } switch { case item.Index != nil: - doc.Document.Schema = a.IndexToSchemaName(item.Index.Index) + doc.Document.Schema = a.indexNameAdapter.IndexToSchemaName(item.Index.Index) doc.Document.ID = item.Index.ID if item.Index.Version != nil { doc.Document.Version = *item.Index.Version } case item.Delete != nil: - doc.Document.Schema = a.IndexToSchemaName(item.Delete.Index) + doc.Document.Schema = a.indexNameAdapter.IndexToSchemaName(item.Delete.Index) doc.Document.ID = item.Delete.ID if item.Delete.Version != nil { doc.Document.Version = *item.Delete.Version diff --git a/pkg/wal/processor/search/opensearch/opensearch_index_name.go b/pkg/wal/processor/search/opensearch/opensearch_index_name.go index 1432f40..85a9d80 100644 --- a/pkg/wal/processor/search/opensearch/opensearch_index_name.go +++ b/pkg/wal/processor/search/opensearch/opensearch_index_name.go @@ -4,8 +4,14 @@ package opensearch import ( "fmt" + "strings" ) +type IndexNameAdapter interface { + SchemaNameToIndex(schemaName string) IndexName + IndexToSchemaName(index string) string +} + // IndexName represents an opensearch index name constructed from a schema name. type IndexName interface { Name() string @@ -14,33 +20,47 @@ type IndexName interface { SchemaName() string } -type indexName struct { +type defaultIndexNameAdapter struct{} + +func newDefaultIndexNameAdapter() IndexNameAdapter { + return &defaultIndexNameAdapter{} +} + +func (i *defaultIndexNameAdapter) SchemaNameToIndex(schemaName string) IndexName { + return newDefaultIndexName(schemaName) +} + +func (i *defaultIndexNameAdapter) IndexToSchemaName(index string) string { + return strings.TrimSuffix(index, "-1") +} + +type defaultIndexName struct { schemaName string version int } func newDefaultIndexName(schemaName string) IndexName { - return &indexName{ + return &defaultIndexName{ schemaName: schemaName, version: 1, } } -func (i indexName) SchemaName() string { +func (i defaultIndexName) SchemaName() string { return i.schemaName } // NameWithVersion represents the name of the index with the version number. This should // generally not be needed, in favour of `Name`. -func (i indexName) NameWithVersion() string { +func (i defaultIndexName) NameWithVersion() string { return fmt.Sprintf("%s-%d", i.schemaName, i.version) } // Name returns the name we should use for querying the index. 
-func (i *indexName) Name() string { +func (i *defaultIndexName) Name() string { return i.schemaName } -func (i *indexName) Version() int { +func (i *defaultIndexName) Version() int { return i.version } diff --git a/pkg/wal/processor/search/opensearch/opensearch_store.go b/pkg/wal/processor/search/opensearch/opensearch_store.go index 4cd90e5..2628f53 100644 --- a/pkg/wal/processor/search/opensearch/opensearch_store.go +++ b/pkg/wal/processor/search/opensearch/opensearch_store.go @@ -16,11 +16,12 @@ import ( ) type Store struct { - logger loglib.Logger - client es.SearchClient - mapper search.Mapper - adapter Adapter - marshaler func(any) ([]byte, error) + logger loglib.Logger + client es.SearchClient + mapper search.Mapper + adapter SearchAdapter + indexNameAdapter IndexNameAdapter + marshaler func(any) ([]byte, error) } type Config struct { @@ -57,12 +58,14 @@ func NewStore(cfg Config, opts ...Option) (*Store, error) { } func NewStoreWithClient(client es.SearchClient) *Store { + indexNameAdapter := newDefaultIndexNameAdapter() return &Store{ - logger: loglib.NewNoopLogger(), - client: client, - adapter: newDefaultAdapter(), - mapper: NewPostgresMapper(), - marshaler: json.Marshal, + logger: loglib.NewNoopLogger(), + client: client, + indexNameAdapter: indexNameAdapter, + adapter: newDefaultAdapter(indexNameAdapter), + mapper: NewPostgresMapper(), + marshaler: json.Marshal, } } @@ -72,6 +75,19 @@ func WithLogger(l loglib.Logger) Option { } } +func WithMapper(m search.Mapper) Option { + return func(s *Store) { + s.mapper = m + } +} + +func WithIndexNameAdapter(a IndexNameAdapter) Option { + return func(s *Store) { + s.indexNameAdapter = a + s.adapter = newDefaultAdapter(a) + } +} + func (s *Store) GetMapper() search.Mapper { return s.mapper } @@ -154,7 +170,7 @@ func (s *Store) SendDocuments(ctx context.Context, docs []search.Document) ([]se } func (s *Store) DeleteSchema(ctx context.Context, schemaName string) error { - index := s.adapter.SchemaNameToIndex(schemaName) + index := s.indexNameAdapter.SchemaNameToIndex(schemaName) exists, err := s.client.IndexExists(ctx, index.NameWithVersion()) if err != nil { return mapError(err) @@ -184,7 +200,7 @@ func (s *Store) DeleteSchema(ctx context.Context, schemaName string) error { } func (s *Store) DeleteTableDocuments(ctx context.Context, schemaName string, tableIDs []string) error { - index := s.adapter.SchemaNameToIndex(schemaName) + index := s.indexNameAdapter.SchemaNameToIndex(schemaName) if err := s.deleteTableDocuments(ctx, index, tableIDs); err != nil { return mapError(err) } @@ -241,7 +257,7 @@ func (s *Store) getLastSchemaLogEntry(ctx context.Context, schemaName string) (* } func (s *Store) schemaExists(ctx context.Context, schemaName string) (bool, error) { - indexName := s.adapter.SchemaNameToIndex(schemaName) + indexName := s.indexNameAdapter.SchemaNameToIndex(schemaName) exists, err := s.client.IndexExists(ctx, indexName.NameWithVersion()) if err != nil { return false, mapError(err) @@ -250,7 +266,7 @@ func (s *Store) schemaExists(ctx context.Context, schemaName string) (bool, erro } func (s *Store) createSchema(ctx context.Context, schemaName string) error { - index := s.adapter.SchemaNameToIndex(schemaName) + index := s.indexNameAdapter.SchemaNameToIndex(schemaName) err := s.client.CreateIndex(ctx, index.NameWithVersion(), map[string]any{ "mappings": map[string]any{ "dynamic": "strict", @@ -284,7 +300,7 @@ func (s *Store) createSchema(ctx context.Context, schemaName string) error { } func (s *Store) updateMapping(ctx 
context.Context, schemaName string, logEntry *schemalog.LogEntry, diff *schemalog.SchemaDiff) error { - index := s.adapter.SchemaNameToIndex(schemaName) + index := s.indexNameAdapter.SchemaNameToIndex(schemaName) if diff != nil { if err := s.updateMappingAddNewColumns(ctx, index, diff.ColumnsToAdd); err != nil { return fmt.Errorf("failed to add new columns: %w", mapError(err)) diff --git a/pkg/wal/processor/search/opensearch/opensearch_store_test.go b/pkg/wal/processor/search/opensearch/opensearch_store_test.go index 3af3629..3a1d7b4 100644 --- a/pkg/wal/processor/search/opensearch/opensearch_store_test.go +++ b/pkg/wal/processor/search/opensearch/opensearch_store_test.go @@ -506,7 +506,7 @@ func TestStore_getLastSchemaLogEntry(t *testing.T) { tests := []struct { name string client es.SearchClient - adapter Adapter + adapter SearchAdapter marshaler func(any) ([]byte, error) wantLogEntry *schemalog.LogEntry diff --git a/pkg/wal/processor/search/search_schema_cleaner.go b/pkg/wal/processor/search/search_schema_cleaner.go index 4b200b1..e318800 100644 --- a/pkg/wal/processor/search/search_schema_cleaner.go +++ b/pkg/wal/processor/search/search_schema_cleaner.go @@ -8,7 +8,7 @@ import ( "fmt" "time" - "github.com/xataio/pgstream/internal/backoff" + "github.com/xataio/pgstream/pkg/backoff" loglib "github.com/xataio/pgstream/pkg/log" ) diff --git a/pkg/wal/processor/search/search_schema_cleaner_test.go b/pkg/wal/processor/search/search_schema_cleaner_test.go index 8a01bd2..7809e05 100644 --- a/pkg/wal/processor/search/search_schema_cleaner_test.go +++ b/pkg/wal/processor/search/search_schema_cleaner_test.go @@ -10,8 +10,8 @@ import ( "time" "github.com/stretchr/testify/require" - "github.com/xataio/pgstream/internal/backoff" - "github.com/xataio/pgstream/internal/backoff/mocks" + "github.com/xataio/pgstream/pkg/backoff" + "github.com/xataio/pgstream/pkg/backoff/mocks" loglib "github.com/xataio/pgstream/pkg/log" ) diff --git a/pkg/wal/processor/search/search_store_retrier.go b/pkg/wal/processor/search/search_store_retrier.go index 1372b56..f2275ed 100644 --- a/pkg/wal/processor/search/search_store_retrier.go +++ b/pkg/wal/processor/search/search_store_retrier.go @@ -9,7 +9,7 @@ import ( "fmt" "time" - "github.com/xataio/pgstream/internal/backoff" + "github.com/xataio/pgstream/pkg/backoff" loglib "github.com/xataio/pgstream/pkg/log" "github.com/xataio/pgstream/pkg/schemalog" ) @@ -37,7 +37,7 @@ const ( var errPartialDocumentSend = errors.New("failed to send some or all documents") -func NewStoreRetrier(s Store, cfg *StoreRetryConfig, opts ...StoreOption) *StoreRetrier { +func NewStoreRetrier(s Store, cfg StoreRetryConfig, opts ...StoreOption) *StoreRetrier { sr := &StoreRetrier{ inner: s, logger: loglib.NewNoopLogger(), diff --git a/pkg/wal/processor/search/search_store_retrier_test.go b/pkg/wal/processor/search/search_store_retrier_test.go index 9134ed9..7921053 100644 --- a/pkg/wal/processor/search/search_store_retrier_test.go +++ b/pkg/wal/processor/search/search_store_retrier_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/stretchr/testify/require" - "github.com/xataio/pgstream/internal/backoff" + "github.com/xataio/pgstream/pkg/backoff" loglib "github.com/xataio/pgstream/pkg/log" ) diff --git a/pkg/wal/processor/translator/wal_translator.go b/pkg/wal/processor/translator/wal_translator.go index 9149c03..11bfba3 100644 --- a/pkg/wal/processor/translator/wal_translator.go +++ b/pkg/wal/processor/translator/wal_translator.go @@ -23,7 +23,8 @@ type Translator struct { logger loglib.Logger 
processor processor.Processor walToLogEntryAdapter walToLogEntryAdapter - skipSchema schemaFilter + skipDataEvent dataEventFilter + skipSchemaEvent schemaEventFilter schemaLogStore schemalog.Store idFinder columnFinder versionFinder columnFinder @@ -38,8 +39,9 @@ type Config struct { // configurable filters that allow the user of this library to have flexibility // when processing and translating the wal event data type ( - schemaFilter func(string) bool - columnFinder func(*schemalog.Column, *schemalog.Table) bool + dataEventFilter func(*wal.Data) bool + schemaEventFilter func(*schemalog.LogEntry) bool + columnFinder func(*schemalog.Column, *schemalog.Table) bool ) type Option func(t *Translator) @@ -62,8 +64,9 @@ func New(cfg *Config, p processor.Processor, opts ...Option) (*Translator, error processor: p, schemaLogStore: schemaLogStore, walToLogEntryAdapter: processor.WalDataToLogEntry, - // by default all schemas are processed - skipSchema: func(s string) bool { return false }, + // by default all events are processed + skipDataEvent: func(*wal.Data) bool { return false }, + skipSchemaEvent: func(*schemalog.LogEntry) bool { return false }, // by default we look for the primary key to use as identity column idFinder: primaryKeyFinder, } @@ -87,9 +90,15 @@ func WithVersionFinder(versionFinder columnFinder) Option { } } -func WithSkipSchema(skipSchema schemaFilter) Option { +func WithSkipSchemaEvent(skip schemaEventFilter) Option { return func(t *Translator) { - t.skipSchema = skipSchema + t.skipSchemaEvent = skip + } +} + +func WithSkipDataEvent(skip dataEventFilter) Option { + return func(t *Translator) { + t.skipDataEvent = skip } } @@ -109,7 +118,7 @@ func (t *Translator) ProcessWALEvent(ctx context.Context, event *wal.Event) erro } data := event.Data - if t.skipSchema(data.Schema) { + if t.skipDataEvent(data) { return nil } @@ -125,7 +134,7 @@ func (t *Translator) ProcessWALEvent(ctx context.Context, event *wal.Event) erro return err } - if t.skipSchema(logEntry.SchemaName) { + if t.skipSchemaEvent(logEntry) { return nil } diff --git a/pkg/wal/processor/translator/wal_translator_test.go b/pkg/wal/processor/translator/wal_translator_test.go index a05096a..99714dc 100644 --- a/pkg/wal/processor/translator/wal_translator_test.go +++ b/pkg/wal/processor/translator/wal_translator_test.go @@ -22,27 +22,28 @@ func TestTranslator_ProcessWALEvent(t *testing.T) { testLogEntry := newTestLogEntry() tests := []struct { - name string - event *wal.Event - store schemalog.Store - adapter walToLogEntryAdapter - skipSchema schemaFilter - idFinder columnFinder - processor processor.Processor + name string + event *wal.Event + store schemalog.Store + adapter walToLogEntryAdapter + skipDataEvent dataEventFilter + skipSchemaEvent schemaEventFilter + idFinder columnFinder + processor processor.Processor wantErr error }{ { - name: "ok - skip schema", - event: newTestDataEvent("I"), - skipSchema: func(s string) bool { return true }, + name: "ok - skip schema", + event: newTestDataEvent("I"), + skipDataEvent: func(*wal.Data) bool { return true }, wantErr: nil, }, { - name: "ok - skip log entry schema log", - event: newTestSchemaChangeEvent("I"), - skipSchema: func(s string) bool { return s == testSchemaName }, + name: "ok - skip log entry schema log", + event: newTestSchemaChangeEvent("I"), + skipSchemaEvent: func(s *schemalog.LogEntry) bool { return s.SchemaName == testSchemaName }, wantErr: nil, }, @@ -189,7 +190,8 @@ func TestTranslator_ProcessWALEvent(t *testing.T) { logger: loglib.NewNoopLogger(), 
processor: tc.processor, schemaLogStore: tc.store, - skipSchema: func(s string) bool { return false }, + skipDataEvent: func(d *wal.Data) bool { return false }, + skipSchemaEvent: func(*schemalog.LogEntry) bool { return false }, idFinder: func(c *schemalog.Column, _ *schemalog.Table) bool { return c.Name == "col-1" }, versionFinder: func(c *schemalog.Column, _ *schemalog.Table) bool { return c.Name == "col-2" }, walToLogEntryAdapter: func(d *wal.Data) (*schemalog.LogEntry, error) { return testLogEntry, nil }, @@ -203,8 +205,12 @@ func TestTranslator_ProcessWALEvent(t *testing.T) { translator.walToLogEntryAdapter = tc.adapter } - if tc.skipSchema != nil { - translator.skipSchema = tc.skipSchema + if tc.skipSchemaEvent != nil { + translator.skipSchemaEvent = tc.skipSchemaEvent + } + + if tc.skipDataEvent != nil { + translator.skipDataEvent = tc.skipDataEvent } err := translator.ProcessWALEvent(context.Background(), tc.event)
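
The following is not part of the diff above; it is a minimal usage sketch of the relocated pkg/tls and pkg/kafka packages as they look after this change, using only the exported fields and functions visible in the hunks. The PEM string, broker address and topic name are placeholders, not values from the repository:

package main

import (
	"fmt"
	"log"

	"github.com/xataio/pgstream/pkg/kafka"
	tlslib "github.com/xataio/pgstream/pkg/tls"
)

func main() {
	// TLS material can now be given either as a file path or as inline PEM
	// content; the *PEM field is read when the corresponding *File field is empty.
	tlsCfg := tlslib.Config{
		Enabled:   true,
		CaCertPEM: "-----BEGIN CERTIFICATE-----\n...placeholder...\n-----END CERTIFICATE-----",
	}

	// NewConfig returns (nil, nil) when TLS is disabled, otherwise a *tls.Config
	// pinned to a minimum of TLS 1.2.
	stdCfg, err := tlslib.NewConfig(&tlsCfg)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("TLS configured:", stdCfg != nil)

	// kafka.ConnConfig now embeds the TLS config by value rather than as a pointer.
	connCfg := kafka.ConnConfig{
		Servers: []string{"localhost:9092"},
		Topic: kafka.TopicConfig{
			Name:       "pgstream-events",
			AutoCreate: true,
		},
		TLS: tlsCfg,
	}
	_ = connCfg
}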