From 97a18c6aeaacedeac8fe476c439ce7833d236145 Mon Sep 17 00:00:00 2001 From: Derek Perkins Date: Tue, 12 Nov 2024 17:24:11 -0700 Subject: [PATCH] vstreamclient: framework for robust + simple usage Signed-off-by: Derek Perkins --- go/vt/vstreamclient/convert.go | 84 ++++++ go/vt/vstreamclient/options.go | 115 ++++++++ go/vt/vstreamclient/run.go | 302 ++++++++++++++++++++ go/vt/vstreamclient/state.go | 192 +++++++++++++ go/vt/vstreamclient/table.go | 324 ++++++++++++++++++++++ go/vt/vstreamclient/vstreamclient.go | 222 +++++++++++++++ go/vt/vstreamclient/vstreamclient_test.go | 267 ++++++++++++++++++ 7 files changed, 1506 insertions(+) create mode 100644 go/vt/vstreamclient/convert.go create mode 100644 go/vt/vstreamclient/options.go create mode 100644 go/vt/vstreamclient/run.go create mode 100644 go/vt/vstreamclient/state.go create mode 100644 go/vt/vstreamclient/table.go create mode 100644 go/vt/vstreamclient/vstreamclient.go create mode 100644 go/vt/vstreamclient/vstreamclient_test.go diff --git a/go/vt/vstreamclient/convert.go b/go/vt/vstreamclient/convert.go new file mode 100644 index 00000000000..fff71ec2989 --- /dev/null +++ b/go/vt/vstreamclient/convert.go @@ -0,0 +1,84 @@ +package vstreamclient + +import ( + "fmt" + "reflect" + "time" + + "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// VStreamScanner allows for custom scan implementations +type VStreamScanner interface { + VStreamScan(fields []*querypb.Field, row []sqltypes.Value, rowEvent *binlogdatapb.RowEvent, rowChange *binlogdatapb.RowChange) error +} + +// copyRowToStruct builds a customer from a row event +// TODO: this is very rudimentary mapping that only works for top-level fields +func copyRowToStruct(shard shardConfig, row []sqltypes.Value, vPtr reflect.Value) error { + for fieldName, m := range shard.fieldMap { + structField := reflect.Indirect(vPtr).FieldByIndex(m.structIndex) + + switch m.kind { + case reflect.Bool: + rowVal, err := row[m.rowIndex].ToBool() + if err != nil { + return fmt.Errorf("error converting row value to bool for field %s: %w", fieldName, err) + } + structField.SetBool(rowVal) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + rowVal, err := row[m.rowIndex].ToInt64() + if err != nil { + return fmt.Errorf("error converting row value to int64 for field %s: %w", fieldName, err) + } + structField.SetInt(rowVal) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + rowVal, err := row[m.rowIndex].ToUint64() + if err != nil { + return fmt.Errorf("error converting row value to uint64 for field %s: %w", fieldName, err) + } + structField.SetUint(rowVal) + + case reflect.Float32, reflect.Float64: + rowVal, err := row[m.rowIndex].ToFloat64() + if err != nil { + return fmt.Errorf("error converting row value to float64 for field %s: %w", fieldName, err) + } + structField.SetFloat(rowVal) + + case reflect.String: + rowVal := row[m.rowIndex].ToString() + structField.SetString(rowVal) + + case reflect.Struct: + switch m.structType.(type) { + case time.Time, *time.Time: + rowVal, err := row[m.rowIndex].ToTime() + if err != nil { + return fmt.Errorf("error converting row value to time.Time for field %s: %w", fieldName, err) + } + structField.Set(reflect.ValueOf(rowVal)) + } + + case reflect.Pointer, + reflect.Slice, + reflect.Array, + reflect.Invalid, + reflect.Uintptr, + reflect.Complex64, + reflect.Complex128, + reflect.Chan, + reflect.Func, + 
reflect.Interface,
+			reflect.Map,
+			reflect.UnsafePointer:
+			return fmt.Errorf("vstreamclient: unsupported field type: %s", m.kind.String())
+		}
+	}
+
+	return nil
+}
diff --git a/go/vt/vstreamclient/options.go b/go/vt/vstreamclient/options.go
new file mode 100644
index 00000000000..9f527af2b30
--- /dev/null
+++ b/go/vt/vstreamclient/options.go
@@ -0,0 +1,115 @@
+package vstreamclient
+
+import (
+	"fmt"
+	"time"
+
+	"vitess.io/vitess/go/sqlescape"
+	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
+	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
+)
+
+var (
+	// DefaultMinFlushDuration is the default minimum duration between flushes, used if not explicitly
+	// set using WithMinFlushDuration. This can be safely modified if needed before calling New.
+	DefaultMinFlushDuration = 5 * time.Second
+
+	// DefaultMaxRowsPerFlush is the default number of rows to buffer per table, used if not explicitly
+	// set in the table configuration. This same number is also used to chunk rows when calling flush.
+	// This can be safely modified if needed before calling New.
+	DefaultMaxRowsPerFlush = 1000
+)
+
+// Option is a function that can be used to configure a VStreamClient
type Option func(v *VStreamClient) error
+
+// WithMinFlushDuration sets the minimum duration between flushes. This is useful for ensuring that data
+// isn't flushed too often, which can be inefficient. The default is DefaultMinFlushDuration (5 seconds).
+func WithMinFlushDuration(d time.Duration) Option {
+	return func(v *VStreamClient) error {
+		if d <= 0 {
+			return fmt.Errorf("vstreamclient: minimum flush duration must be positive, got %s", d.String())
+		}
+
+		v.minFlushDuration = d
+		return nil
+	}
+}
+
+// WithHeartbeatSeconds sets the interval, in whole seconds, at which the server sends heartbeat events
+// when the stream is otherwise idle. It overrides the HeartbeatInterval set in the default flags.
+func WithHeartbeatSeconds(seconds int) Option {
+	return func(v *VStreamClient) error {
+		if seconds <= 0 {
+			return fmt.Errorf("vstreamclient: heartbeat seconds must be positive, got %d", seconds)
+		}
+
+		v.heartbeatSeconds = seconds
+		return nil
+	}
+}
+
+// WithStateTable sets the unsharded keyspace and table where the stream state (latest vgtid, table
+// config, copy progress) is stored. The table is created during New if it does not already exist.
+func WithStateTable(keyspace, table string) Option {
+	return func(v *VStreamClient) error {
+		shards, ok := v.shardsByKeyspace[keyspace]
+		if !ok {
+			return fmt.Errorf("vstreamclient: keyspace %s not found", keyspace)
+		}
+
+		// this could allow for shard pinning, but we can support that if it becomes useful
+		if len(shards) > 1 {
+			return fmt.Errorf("vstreamclient: keyspace %s is sharded, only unsharded keyspaces are supported", keyspace)
+		}
+
+		v.vgtidStateKeyspace = sqlescape.EscapeID(keyspace)
+		v.vgtidStateTable = sqlescape.EscapeID(table)
+		return nil
+	}
+}
+
+// DefaultFlags returns a default set of flags for a VStreamClient, safe to use in most cases, but can be customized
+func DefaultFlags() *vtgatepb.VStreamFlags {
+	return &vtgatepb.VStreamFlags{
+		HeartbeatInterval: 1,
+	}
+}
+
+// WithFlags lets you manually control all the flag options, instead of using helper functions
+func WithFlags(flags *vtgatepb.VStreamFlags) Option {
+	return func(v *VStreamClient) error {
+		v.flags = flags
+		return nil
+	}
+}
+
+// WithEventFunc provides for custom event handling functions for specific event types. Only one function
+// can be registered per event type, and it is called before the default event handling function. Returning
+// an error from the custom function will exit the stream before the default function is called.
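+//
+// A minimal usage sketch (the logging handler here is illustrative, not part of this package):
+//
+//	WithEventFunc(func(ctx context.Context, ev *binlogdatapb.VEvent) error {
+//		log.Printf("ddl applied: %s", ev.Statement)
+//		return nil
+//	}, binlogdatapb.VEventType_DDL)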
+func WithEventFunc(fn EventFunc, eventTypes ...binlogdatapb.VEventType) Option {
+	return func(v *VStreamClient) error {
+		if len(eventTypes) == 0 {
+			return fmt.Errorf("vstreamclient: no event types provided")
+		}
+
+		if v.eventFuncs == nil {
+			v.eventFuncs = make(map[binlogdatapb.VEventType]EventFunc)
+		}
+
+		for _, eventType := range eventTypes {
+			if _, ok := v.eventFuncs[eventType]; ok {
+				return fmt.Errorf("vstreamclient: event type %s already has a function", eventType.String())
+			}
+
+			v.eventFuncs[eventType] = fn
+		}
+
+		return nil
+	}
+}
+
+// WithStartingVGtid sets the starting VGtid for the VStreamClient. This is useful for resuming a stream from a
+// specific point, vs what might be stored in the state table.
+func WithStartingVGtid(vgtid *binlogdatapb.VGtid) Option {
+	return func(v *VStreamClient) error {
+		v.latestVgtid = vgtid
+		return nil
+	}
+}
diff --git a/go/vt/vstreamclient/run.go b/go/vt/vstreamclient/run.go
new file mode 100644
index 00000000000..b933f4faf4c
--- /dev/null
+++ b/go/vt/vstreamclient/run.go
@@ -0,0 +1,302 @@
+package vstreamclient
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"slices"
+	"time"
+
+	"google.golang.org/protobuf/proto"
+
+	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
+)
+
+// EventFunc is an optional callback function that can be registered for individual event types
+type EventFunc func(ctx context.Context, event *binlogdatapb.VEvent) error
+
+// FlushFunc is called per batch of rows, which could be any number of rows, but will be limited by the
+// MaxRowsPerFlush setting. Each Row carries the data type configured for the table in Row.Data; use a
+// type assertion to convert it to the correct type, and then process it as needed.
+//
+//	func(ctx context.Context, rows []Row, meta FlushMeta) error {
+//		for _, row := range rows {
+//			typed := row.Data.(*DataType)
+//			// do something with typed
+//		}
+//		return nil
+//	}
+//
+// Returning an error will stop the stream, and the last vgtid will not be stored. The stream will need to be
+// restarted from the last successfully flushed vgtid.
+type FlushFunc func(ctx context.Context, rows []Row, meta FlushMeta) error
+
+// Row is the data structure that will be passed to the FlushFunc. It contains the row event, the row change,
+// and the scanned data itself. Data is always populated as the type registered for the table; for delete
+// events it is scanned from the before image of the row, so the primary key fields are still available.
+type Row struct {
+	RowEvent  *binlogdatapb.RowEvent
+	RowChange *binlogdatapb.RowChange
+	Data      any // the data type registered for the table; scanned from the before image for deletes
+}
+
+// FlushMeta is the metadata that is passed to the FlushFunc. It's not necessary, but might be useful
+// for logging, debugging, etc.
+type FlushMeta struct {
+	Keyspace string
+	Table    string
+
+	TableStats   TableStats
+	VStreamStats VStreamStats
+
+	LatestVGtid *binlogdatapb.VGtid
+}
+
+// Run starts the vstream, processing events from the stream until it ends or an error occurs.
+func (v *VStreamClient) Run(ctx context.Context) error {
+	// make a cancelable context, so the heartbeat monitor can cancel the stream if necessary
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	go v.monitorHeartbeat(ctx, cancel)
+
+	// ********************************************************************************************************
+	// Event Processing
+	//
+	// this is the main loop that processes events from the stream. It will continue until the stream ends or an
+	// error occurs. The context is checked before processing each event.
+	// ********************************************************************************************************
+	for {
+		// events come in batches, depending on how busy the keyspace is. This is where the network communication
+		// happens, so it's the most likely place for errors to occur.
+		events, err := v.reader.Recv()
+		switch {
+		case err == nil: // no error, continue processing below
+
+		case errors.Is(err, io.EOF):
+			fmt.Println("vstreamclient: stream ended")
+			return nil
+
+		default:
+			return fmt.Errorf("vstreamclient: remote error: %w", err)
+		}
+
+		for _, ev := range events {
+			// check for context errors before processing the next event, since any processing will likely be wasted
+			err = ctx.Err()
+			if err != nil {
+				return fmt.Errorf("vstreamclient: context error: %w", err)
+			}
+
+			// keep track of the last event time for heartbeat monitoring. We're purposefully not using the event
+			// timestamp, since that would cause cancellation if the stream was copying, delayed, or lagging.
+			v.lastEventReceivedAtUnix.Store(time.Now().Unix())
+
+			// call the user-defined event function, if it exists
+			fn, ok := v.eventFuncs[ev.Type]
+			if ok {
+				err = fn(ctx, ev)
+				if err != nil {
+					return fmt.Errorf("vstreamclient: user error processing %s event: %w", ev.Type.String(), err)
+				}
+			}
+
+			// handle individual events based on their type
+			switch ev.Type {
+			// field events are sent first, and contain schema information for any tables that are being streamed,
+			// so we cache the fields for each table as they come in
+			case binlogdatapb.VEventType_FIELD:
+				var table *TableConfig
+				table, ok = v.tables[ev.FieldEvent.TableName]
+				if !ok {
+					return errors.New("vstreamclient: unexpected table name: " + ev.FieldEvent.TableName)
+				}
+
+				err = table.handleFieldEvent(ev.FieldEvent)
+				if err != nil {
+					return err
+				}
+
+			// row events are the actual data changes, and we'll process them based on the table name
+			case binlogdatapb.VEventType_ROW:
+				var table *TableConfig
+				table, ok = v.tables[ev.RowEvent.TableName]
+				if !ok {
+					return errors.New("vstreamclient: unexpected table name: " + ev.RowEvent.TableName)
+				}
+
+				err = table.handleRowEvent(ev.RowEvent, &v.stats)
+				if err != nil {
+					return err
+				}
+
+			// vgtid events are sent periodically, and we'll store the latest vgtid for checkpointing. We may get
+			// this mid-transaction, so we don't flush here, to avoid propagating a partial transaction that may
+			// be rolled back.
+			case binlogdatapb.VEventType_VGTID:
+				v.latestVgtid = ev.Vgtid
+
+			// commit and other events are safe to flush on, since they indicate the end of a transaction.
+			// Otherwise, there's not much to do with these events.
+			case binlogdatapb.VEventType_COMMIT, binlogdatapb.VEventType_OTHER:
+				// only flush when we have an event that guarantees we're not flushing mid-transaction
+				err = v.flush(ctx)
+				if err != nil {
+					return err
+				}
+
+			// DDL events are schema changes, and we might want to handle them differently than data events.
+			// They are safe to flush on, since they indicate the end of a transaction. If you want to
+			// transparently adjust the destination schema based on DDL events, you would do that here.
+			case binlogdatapb.VEventType_DDL:
+				err = v.flush(ctx)
+				if err != nil {
+					return err
+				}
+
+			case binlogdatapb.VEventType_COPY_COMPLETED:
+				// TODO: don't flush until the copy is completed? do some sort of cleanup if we haven't received this?
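+				// COPY_COMPLETED is emitted once the copy phase (the initial snapshot of the streamed
+				// tables) has finished; after this point the stream is pure binlog replication.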
+ + err = setCopyCompleted(ctx, v.session, v.name, v.vgtidStateKeyspace, v.vgtidStateTable) + if err != nil { + return err + } + + // heartbeat events are sent periodically, if the source keyspace is idle and there are no other events. + // It's possible that there is still buffered data that hadn't exceeded the min duration or max rows + // thresholds the last time flush was called. Most of the time, that won't be the case, but flush will + // check for that and only run if necessary. + case binlogdatapb.VEventType_HEARTBEAT: + err = v.flush(ctx) + if err != nil { + return err + } + + // even if there are no changes to the tables being streamed, we'll still get a begin, vgtid, and commit + // event for each transaction. The other two are used for checkpoints, but nothing to do here. + case binlogdatapb.VEventType_BEGIN: + + // journal events are sent on resharding. Unless you are manually targeting shards, vstream should + // transparently handle resharding for you, so you shouldn't need to do anything with these events. + // You might want to log them for debugging purposes, or to alert on resharding events in case + // something goes wrong. After resharding, if the pre-reshard vgtid is no longer valid, you may need + // to restart the stream from the beginning. + case binlogdatapb.VEventType_JOURNAL: + + // there aren't strong cases for handling these events, but you might want to log them for debugging + case binlogdatapb.VEventType_VERSION, binlogdatapb.VEventType_LASTPK, binlogdatapb.VEventType_SAVEPOINT: + } + } + } +} + +// ******************************************************************************************************** +// Heartbeat Monitoring +// +// the heartbeat ticker will be used to ensure that we haven't been disconnected from the stream. This starts +// a goroutine that will cancel the context if we haven't received an event in twice the heartbeat duration. +// ******************************************************************************************************** +func (v *VStreamClient) monitorHeartbeat(ctx context.Context, cancel context.CancelFunc) { + heartbeatDur := time.Duration(v.flags.HeartbeatInterval) * time.Second + heartbeat := time.NewTicker(heartbeatDur) + defer heartbeat.Stop() + + for { + select { + case tm := <-heartbeat.C: + // if we haven't received an event yet, we'll skip the heartbeat check + if v.lastEventReceivedAtUnix.Load() == 0 { + continue + } + + // if we haven't received an event in twice the heartbeat duration, we'll cancel the context, since + // we're likely disconnected, and exit the goroutine + if tm.Sub(time.Unix(v.lastEventReceivedAtUnix.Load(), 0)) > heartbeatDur*2 { + cancel() + return + } + + case <-ctx.Done(): + // this cancel is probably unnecessary, since the context is already done, but it's good practice + cancel() + return + } + } +} + +// ******************************************************************************************************** +// Flush Data + Store Vgtid +// +// as we process events, we'll periodically need to check point the last vgtid and store the customers in the +// database. You can control the frequency of this flush by adjusting the minFlushDuration and maxRowsPerFlush. +// This is only called when we have an event that guarantees we're not flushing mid-transaction. 
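+//
+// A worked example of the thresholds: with minFlushDuration=5s and MaxRowsPerFlush=1000, a burst of
+// 2,500 buffered rows forces a flush even inside the 5s window, and the flush function is invoked in
+// chunks of 1,000 / 1,000 / 500 rows; a trickle of 10 rows instead waits until 5s have passed and a
+// commit-boundary event arrives.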
+// ******************************************************************************************************** +// +// we might consider exporting Flush, but we'd need to have a mutex or something to block the stream from +// processing, and technically we'd need to let it run until a commit event happens. +func (v *VStreamClient) flush(ctx context.Context) error { + // if the lastFlushedVgtid is the same as the latestVgtid, we don't need to do anything + if proto.Equal(v.lastFlushedVgtid, v.latestVgtid) { + return nil + } + + // if we have exceeded the minFlushDuration, we'll force a flush, regardless how many rows each table has + shouldFlush := time.Since(v.stats.LastFlushedAt) > v.minFlushDuration + + // if we haven't exceeded the min flush duration, we'll check if any of the tables have exceeded their + // max rows to flush. If any of them have, we will force a flush. Every table needs to be flushed at + // the same time, since the last vgtid covers all tables. + if !shouldFlush { + for _, table := range v.tables { + if len(table.currentBatch) >= table.MaxRowsPerFlush { + shouldFlush = true + break + } + } + } + + if !shouldFlush { + return nil + } + + // TODO: maybe start a transaction here, and pass it to the flush function + + for _, table := range v.tables { + // flush the rows to the database, chunked using the max batch size + for chunk := range slices.Chunk(table.currentBatch, table.MaxRowsPerFlush) { + err := table.FlushFn(ctx, chunk, FlushMeta{ + Keyspace: table.Keyspace, + Table: table.Table, + + TableStats: table.stats, + VStreamStats: v.stats, + + LatestVGtid: v.latestVgtid, + }) + if err != nil { + return fmt.Errorf("vstreamclient: error flushing table %s: %w", table.Table, err) + } + + // update the stats for the table and the vstream + table.stats.FlushCount++ + table.stats.FlushedRowCount += len(chunk) + table.stats.LastFlushedAt = time.Now() + + v.stats.TableFlushCount++ + v.stats.FlushedRowCount += len(chunk) + } + + table.resetBatch() + } + + // always store the latest vgtid, even if there are no customers to store + err := updateLatestVGtid(ctx, v.session, v.name, v.vgtidStateKeyspace, v.vgtidStateTable, v.latestVgtid) + if err != nil { + return err + } + + v.stats.FlushCount++ + v.stats.LastFlushedAt = time.Now() + v.lastFlushedVgtid = v.latestVgtid + + return nil +} diff --git a/go/vt/vstreamclient/state.go b/go/vt/vstreamclient/state.go new file mode 100644 index 00000000000..83e585f9936 --- /dev/null +++ b/go/vt/vstreamclient/state.go @@ -0,0 +1,192 @@ +package vstreamclient + +import ( + "context" + "encoding/json" + "fmt" + + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/vtgateconn" +) + +type dbTableConfig struct { + Keyspace string + Table string + Query string +} + +func tablesToDBTableConfig(tables map[string]*TableConfig) map[string]dbTableConfig { + m := make(map[string]dbTableConfig, len(tables)) + + for k, table := range tables { + m[k] = dbTableConfig{ + Keyspace: table.Keyspace, + Table: table.Table, + Query: table.Query, + } + } + + return m +} + +func initStateTable(ctx context.Context, session *vtgateconn.VTGateSession, keyspaceName, tableName string) error { + query := fmt.Sprintf(`create table if not exists %s.%s ( + name varbinary(64) not null, + latest_vgtid json, + table_config json not null, + copy_completed bool not null default false, + created_at timestamp default current_timestamp, + updated_at timestamp default current_timestamp on update 
current_timestamp, + PRIMARY KEY (name) +)`, keyspaceName, tableName) + + fmt.Println("query: ", query) + + _, err := session.Execute(ctx, query, nil) + if err != nil { + return fmt.Errorf("vstreamclient: failed to create state table: %w", err) + } + + return nil +} + +func initVGtid(ctx context.Context, session *vtgateconn.VTGateSession, name, keyspaceName, tableName string, tables map[string]*TableConfig, shardsByKeyspace map[string][]string) (*binlogdatapb.VGtid, error) { + vgtid, err := newVGtid(tables, shardsByKeyspace) + if err != nil { + return nil, err + } + + latestVgtidJSON, err := json.Marshal(vgtid) + if err != nil { + return nil, fmt.Errorf("vstreamclient: failed to marshal latest vgtid: %w", err) + } + + tablesJSON, err := json.Marshal(tablesToDBTableConfig(tables)) + if err != nil { + return nil, fmt.Errorf("vstreamclient: failed to marshal tables: %w", err) + } + + query := fmt.Sprintf(`insert into %s.%s (name, latest_vgtid, table_config) values (:name, :latest_vgtid, :table_config)`, + keyspaceName, tableName, + ) + _, err = session.Execute(ctx, query, map[string]*querypb.BindVariable{ + "name": {Type: querypb.Type_VARBINARY, Value: []byte(name)}, + "latest_vgtid": {Type: querypb.Type_JSON, Value: latestVgtidJSON}, + "table_config": {Type: querypb.Type_JSON, Value: tablesJSON}, + }) + if err != nil { + return nil, fmt.Errorf("vstreamclient: failed to get latest vgtid for %s.%s: %w", keyspaceName, tableName, err) + } + + return vgtid, nil +} + +func newVGtid(tables map[string]*TableConfig, shardsByKeyspace map[string][]string) (*binlogdatapb.VGtid, error) { + bootstrappedKeyspaces := make(map[string]bool) + vgtid := &binlogdatapb.VGtid{} + + for _, table := range tables { + if bootstrappedKeyspaces[table.Keyspace] { + continue + } + + // TODO: this currently doesn't support subsetting shards, but we can add that if needed + shards, ok := shardsByKeyspace[table.Keyspace] + if !ok { + return nil, fmt.Errorf("vstreamclient: keyspace %s not found", table.Keyspace) + } + + for _, shard := range shards { + vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ + Keyspace: table.Keyspace, + Shard: shard, + Gtid: "", // start from the beginning, meaning initializing a copy phase + }) + } + } + + return vgtid, nil +} + +func getLatestVGtid(ctx context.Context, session *vtgateconn.VTGateSession, name, keyspaceName, tableName string) (*binlogdatapb.VGtid, map[string]*TableConfig, bool, error) { + query := fmt.Sprintf(`select latest_vgtid, table_config, copy_completed from %s.%s where name = :name`, keyspaceName, tableName) + + result, err := session.Execute(ctx, query, map[string]*querypb.BindVariable{ + "name": {Type: querypb.Type_VARBINARY, Value: []byte(name)}, + }) + if err != nil { + return nil, nil, false, fmt.Errorf("vstreamclient: failed to get latest vgtid for %s.%s: %w", keyspaceName, tableName, err) + } + + // if there are no rows, or the value is null, return nil, which will start the stream from the beginning + if len(result.Rows) == 0 || result.Rows[0][0].IsNull() { + return nil, nil, false, nil + } + + // unmarshal the JSON value which should be a valid, initialized VGtid + latestVGtidJSON, err := result.Rows[0][0].ToBytes() + if err != nil { + return nil, nil, false, fmt.Errorf("vstreamclient: failed to convert latest_vgtid to bytes: %w", err) + } + + var latestVGtid binlogdatapb.VGtid + err = json.Unmarshal(latestVGtidJSON, &latestVGtid) + if err != nil { + return nil, nil, false, fmt.Errorf("vstreamclient: failed to unmarshal latest_vgtid: %w", err) + } 
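+
+	// note: encoding/json round-trips here because the same Go structs wrote the value; if any other
+	// writer ever produces this column, protojson would be the more robust way to decode a VGtid.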
+ + // unmarshal the JSON value which should be the original table config + tablesJSON, err := result.Rows[0][1].ToBytes() + if err != nil { + return nil, nil, false, fmt.Errorf("vstreamclient: failed to convert table_config to bytes: %w", err) + } + + var tables map[string]*TableConfig + err = json.Unmarshal(tablesJSON, &tables) + if err != nil { + return nil, nil, false, fmt.Errorf("vstreamclient: failed to unmarshal table_config: %w", err) + } + + // check if the copy has been completed + copyCompleted, err := result.Rows[0][2].ToBool() + if err != nil { + return nil, nil, false, fmt.Errorf("vstreamclient: failed to convert copy_completed to bool: %w", err) + } + + return &latestVGtid, tables, copyCompleted, nil +} + +func updateLatestVGtid(ctx context.Context, session *vtgateconn.VTGateSession, name, keyspaceName, tableName string, vgtid *binlogdatapb.VGtid) error { + latestVgtid, err := json.Marshal(vgtid) + if err != nil { + return fmt.Errorf("vstreamclient: failed to marshal latest_vgtid: %w", err) + } + + query := fmt.Sprintf(`update %s.%s set latest_vgtid = :latest_vgtid where name = :name`, + keyspaceName, tableName, + ) + _, err = session.Execute(ctx, query, map[string]*querypb.BindVariable{ + "latest_vgtid": {Type: querypb.Type_JSON, Value: latestVgtid}, + "name": {Type: querypb.Type_VARBINARY, Value: []byte(name)}, + }) + if err != nil { + return fmt.Errorf("vstreamclient: failed to update latest_vgtid for %s.%s: %w", keyspaceName, tableName, err) + } + + return nil +} + +func setCopyCompleted(ctx context.Context, session *vtgateconn.VTGateSession, name, keyspaceName, tableName string) error { + query := fmt.Sprintf(`update %s.%s set copy_completed = true where name = :name`, + keyspaceName, tableName, + ) + _, err := session.Execute(ctx, query, map[string]*querypb.BindVariable{ + "name": {Type: querypb.Type_VARBINARY, Value: []byte(name)}, + }) + if err != nil { + return fmt.Errorf("vstreamclient: failed to set copy_completed for %s.%s: %w", keyspaceName, tableName, err) + } + + return nil +} diff --git a/go/vt/vstreamclient/table.go b/go/vt/vstreamclient/table.go new file mode 100644 index 00000000000..3796afd05e9 --- /dev/null +++ b/go/vt/vstreamclient/table.go @@ -0,0 +1,324 @@ +package vstreamclient + +import ( + "fmt" + "maps" + "reflect" + "slices" + "time" + + "vitess.io/vitess/go/sqlescape" + "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// TableConfig is the configuration for a table, which is used to configure filtering and scanning of the results +type TableConfig struct { + Keyspace string + Table string + + // if no configured, this is set to "select * from keyspace.table", streaming all the fields. This can be + // overridden to select only the fields that are needed, which can reduce memory usage and improve performance, + // and also alias the fields to match the struct fields. + Query string + + // MaxRowsPerFlush serves two purposes: + // 1. it limits the number of rows that are flushed at once, to avoid large transactions. If more than + // this number of rows are processed, they will be flushed in chunks of this size. + // 2. if this number is exceeded before reaching the minFlushDuration, it will trigger a flush to avoid + // holding too much memory in the rows slice. + MaxRowsPerFlush int + + // if true, will reuse the same slice for each batch, which can reduce memory allocations. 
It does mean + // that the caller must copy the data if they want to keep it, because the slice will be reused. + ReuseBatchSlice bool + + // if true, will error and block the stream if there are fields in the result that don't match the struct + ErrorOnUnknownFields bool + + // this is the function that will be called to flush the rows to the handler. This is called when the + // minFlushDuration has passed, or the maxRowsPerFlush has been exceeded. + FlushFn FlushFunc + + // TODO: translate this to *sql.Tx so we can pass it to the flush function + FlushInTx bool + + // purposefully not exported, so we don't have to use a mutex to access it. The idea is that the consumer only + // needs to access this when they are flushing the rows, so they can copy the data if they want to keep it. If + // there is a use case for accessing this outside of the flush function, we can add a getter method. + stats TableStats + + // this is the data type for this table, which is used to scan the results. Regardless whether a pointer + // or value is supplied, the return value will always be []*DataType. + DataType any + underlyingType reflect.Type + implementsScanner bool + + // this stores the current batch of rows for this table, which caches the results until the next flush. + currentBatch []Row + + // this is the mapping of fields to the query results, which is used to scan the results. This is done + // on a per-shard basis, because while unlikely, it's possible that the same table in different shards + // could have different schemas. + shards map[string]shardConfig +} + +// TableStats keeps track of the number of rows processed and flushed for a single table +type TableStats struct { + // how many rows have been processed for this table. These are incremented as each row + // is processed, regardless of whether it is flushed. + RowInsertCount int + RowUpdateCount int + RowDeleteCount int + + FlushedRowCount int // sum of rows flushed, including all insert/update/delete events + + FlushCount int // how many times the flush function was executed for this table. Not incremented for no-ops + LastFlushedAt time.Time // only set after a flush successfully completes +} + +// shardConfig is the per-shard configuration for a table, which is used to scan the results +type shardConfig struct { + fieldMap map[string]fieldMapping + fields []*querypb.Field +} + +// fieldMapping caches the mapping of table fields to struct fields, to reduce reflection overhead. This is +// configured once per shard, and used for all rows in that shard, unless a DDL event changes the schema, +// in which case the mapping is updated. +type fieldMapping struct { + rowIndex int + structIndex []int + structType any + kind reflect.Kind + isPointer bool +} + +// TODO: there's a consolidation issue to deal with here. Someone could easily add a new table in +// code after the copy phase completes and redeploy, expecting it to catch up, but it wouldn't +// by default. To support that, I think we'd have to store a list of tables in state, and if a +// new table is added, catch it up to the same vgtid, then do a cutover to be in the same stream. +// That'd be a much better user experience, but a decent amount of added complexity. 
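+// (One possible shape for that: run a second, temporary stream for just the new table from a fresh
+// vgtid, and once it catches up to the main stream's position, fold it into the main table set.)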
+func (v *VStreamClient) initTables(tables []TableConfig) error { + if len(v.tables) > 0 { + return fmt.Errorf("vstreamclient: %d tables already configured", len(v.tables)) + } + + if len(tables) == 0 { + return fmt.Errorf("vstreamclient: no tables provided") + } + + for _, table := range tables { + // basic validation + if table.DataType == nil { + return fmt.Errorf("vstreamclient: table %s.%s has no data type", table.Keyspace, table.Table) + } + + if table.Keyspace == "" { + return fmt.Errorf("vstreamclient: table %v has no keyspace", table) + } + + if table.Table == "" { + return fmt.Errorf("vstreamclient: table %v has no table name", table) + } + + fmt.Println("shardsByKeyspace", v.shardsByKeyspace) + // make sure the keyspace exists in the cluster + _, ok := v.shardsByKeyspace[table.Keyspace] + if !ok { + return fmt.Errorf("vstreamclient: keyspace %s not found in the cluster", table.Keyspace) + } + + // the key is the keyspace and table name, separated by a period. We use this because vstream events + // use this as the table name, so it's easier for lookup, and it's unique. + k := fmt.Sprintf("%s.%s", table.Keyspace, table.Table) + + // if the same table is referenced multiple times in the same stream, only one table will actually + // receive events. This prevents users from unknowingly missing events for the second table reference. + if _, ok = v.tables[k]; ok { + return fmt.Errorf("duplicate table %s in keyspace %s", table.Table, table.Keyspace) + } + + // set defaults if not provided + if table.Query == "" { + table.Query = "select * from " + sqlescape.EscapeID(table.Table) + } + + if table.MaxRowsPerFlush == 0 { + table.MaxRowsPerFlush = DefaultMaxRowsPerFlush + } + + // if the data type implements VStreamScanner, we will use that to scan the results + _, table.implementsScanner = table.DataType.(VStreamScanner) + + // regardless whether the user provided a pointer to a struct or a struct, we want to store the + // underlying type of the struct, so we can create new instances of it later + table.underlyingType = reflect.Indirect(reflect.ValueOf(table.DataType)).Type() + + if table.underlyingType.Kind() != reflect.Struct { + return fmt.Errorf("vstreamclient: data type for table %s.%s must be a struct", table.Keyspace, table.Table) + } + + table.shards = make(map[string]shardConfig) + + // initialize the slice containing the batch of rows for this table + table.resetBatch() + + // store the table in the map + v.tables[k] = &table + } + + return nil +} + +func validateTableConfig(providedTables, dbTables map[string]*TableConfig) error { + providedTablesMap := tablesToDBTableConfig(providedTables) + dbTablesMap := tablesToDBTableConfig(dbTables) + + if !maps.Equal(providedTablesMap, dbTablesMap) { + // TODO: this could be more user-friendly and show the differences + return fmt.Errorf("vstreamclient: provided tables do not match stored tables") + } + + return nil +} + +func (table *TableConfig) resetBatch() { + if table.ReuseBatchSlice && table.currentBatch != nil { + table.currentBatch = slices.Delete(table.currentBatch, 0, len(table.currentBatch)) + } else { + table.currentBatch = make([]Row, 0, table.MaxRowsPerFlush) + } +} + +func (table *TableConfig) handleFieldEvent(ev *binlogdatapb.FieldEvent) error { + var fieldMap map[string]fieldMapping + var err error + + if !table.implementsScanner { + fieldMap, err = table.reflectMapFields(ev.Fields) + if err != nil { + return err + } + } + + table.shards[ev.Shard] = shardConfig{ + fieldMap: fieldMap, + fields: ev.Fields, + } + + return 
nil
+}
+
+func (table *TableConfig) reflectMapFields(fields []*querypb.Field) (map[string]fieldMapping, error) {
+	fieldMap := make(map[string]fieldMapping, len(fields))
+
+	for i := 0; i < table.underlyingType.NumField(); i++ {
+		structField := table.underlyingType.Field(i)
+		if !structField.IsExported() {
+			continue
+		}
+
+		// get the field name from the vstream, db, json tag, or the field name, in that order
+		mappedFieldName := structField.Tag.Get("vstream")
+		if mappedFieldName == "-" {
+			continue
+		}
+		if mappedFieldName == "" {
+			mappedFieldName = structField.Tag.Get("db")
+		}
+		if mappedFieldName == "" {
+			mappedFieldName = structField.Tag.Get("json")
+		}
+		if mappedFieldName == "" {
+			mappedFieldName = structField.Name
+		}
+
+		var found bool
+		for j, tableField := range fields {
+			if tableField.Name != mappedFieldName {
+				continue
+			}
+
+			found = true
+			fieldMap[mappedFieldName] = fieldMapping{
+				rowIndex:    j,
+				structIndex: structField.Index,
+				// store a zero value of the struct field so type switches (e.g. the time.Time case in
+				// copyRowToStruct) can dispatch on the concrete type
+				structType: reflect.Zero(structField.Type).Interface(),
+				kind:       structField.Type.Kind(),
+				isPointer:  structField.Type.Kind() == reflect.Ptr,
+			}
+			break
+		}
+		if !found && table.ErrorOnUnknownFields {
+			return nil, fmt.Errorf("vstreamclient: struct field %s has no matching field in the stream results", mappedFieldName)
+		}
+	}
+
+	// sanity check that we found at least one field
+	if len(fieldMap) == 0 {
+		return nil, fmt.Errorf("vstreamclient: no matching fields found for table %s", table.Table)
+	}
+
+	return fieldMap, nil
+}
+
+func (table *TableConfig) handleRowEvent(ev *binlogdatapb.RowEvent, vstreamStats *VStreamStats) error {
+	shard, ok := table.shards[ev.Shard]
+	if !ok {
+		return fmt.Errorf("unexpected shard: %s", ev.Shard)
+	}
+
+	table.currentBatch = slices.Grow(table.currentBatch, len(ev.RowChanges))
+
+	for _, rc := range ev.RowChanges {
+		var row []sqltypes.Value
+
+		switch {
+		case rc.After == nil: // delete event
+			// even though a delete event might be represented as a nil row, the consumer will still need to know
+			// which row was deleted, so we'll pass the before row to the consumer, which should contain the primary
+			// key fields, so they can be used however necessary to handle the delete in a downstream system.
+ row = sqltypes.MakeRowTrusted(shard.fields, rc.Before) + vstreamStats.RowDeleteCount++ + table.stats.RowDeleteCount++ + + case rc.Before == nil: // insert event + row = sqltypes.MakeRowTrusted(shard.fields, rc.After) + vstreamStats.RowInsertCount++ + table.stats.RowInsertCount++ + + case rc.Before != nil: // update event + row = sqltypes.MakeRowTrusted(shard.fields, rc.After) + vstreamStats.RowUpdateCount++ + table.stats.RowUpdateCount++ + } + + // create a new struct for the row + v := reflect.New(table.underlyingType) + table.currentBatch = append(table.currentBatch, Row{ + RowEvent: ev, + RowChange: rc, + Data: v.Interface(), + }) + + // use the custom scanner if available + if table.implementsScanner { + returnVals := v.MethodByName("VStreamScan").Call([]reflect.Value{ + reflect.ValueOf(shard.fields), + reflect.ValueOf(row), + reflect.ValueOf(ev), + reflect.ValueOf(rc), + }) + if !returnVals[0].IsNil() { + return fmt.Errorf("vstreamclient: client scan failed: %w", returnVals[0].Interface().(error)) + } + } else { + err := copyRowToStruct(shard, row, v) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/go/vt/vstreamclient/vstreamclient.go b/go/vt/vstreamclient/vstreamclient.go new file mode 100644 index 00000000000..86598c7007a --- /dev/null +++ b/go/vt/vstreamclient/vstreamclient.go @@ -0,0 +1,222 @@ +package vstreamclient + +import ( + "context" + "errors" + "fmt" + "strings" + "sync/atomic" + "time" + + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + _ "vitess.io/vitess/go/vt/vtctl/grpcvtctlclient" + _ "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" + "vitess.io/vitess/go/vt/vtgate/vtgateconn" +) + +// VStreamClient is the primary struct +type VStreamClient struct { + name string + conn *vtgateconn.VTGateConn + + // this session is used to obtain the shards for the keyspace, and to manage the state table + session *vtgateconn.VTGateSession + + // the reader is the vstream reader, which is used to read the binlog events + reader vtgateconn.VStreamReader + filter *binlogdatapb.Filter + flags *vtgatepb.VStreamFlags + + // vgtidStateKeyspace and vgtidStateTable are the keyspace and table where the last VGtid is stored + vgtidStateKeyspace string + vgtidStateTable string + + // may not be necessary... + shardsByKeyspace map[string][]string + + // keep per table state and config, which is used to generate the vgtid filter. + // this is a map of keyspace.table to TableConfig, since that's how the binlog table is stored + tables map[string]*TableConfig + + // to avoid flushing too often, we will only flush if it has been at least minFlushDuration since the last flush. + // we're relying on heartbeat events to handle max duration between flushes, in case there are no other events. + minFlushDuration time.Duration + + // this is the duration between heartbeat events. This is not a duration because the server side + // parameter only has a granularity of seconds. + heartbeatSeconds int + + // lastEventReceivedAtUnix is the time the last event was received, which is used in the heartbeat monitor + // to let us know if we've been disconnected from the stream. + lastEventReceivedAtUnix atomic.Int64 + // lastFlushedVgtid is the last vgtid that was flushed, which is compared to the latestVgtid to determine + // if we need to flush again. 
+ lastFlushedVgtid, latestVgtid *binlogdatapb.VGtid + + stats VStreamStats + + // these are the optional functions that are called for each event type + eventFuncs map[binlogdatapb.VEventType]EventFunc +} + +// VStreamStats keeps track of the number of rows processed and flushed for the whole stream +type VStreamStats struct { + // how many rows have been processed for the whole stream, across all tables. These are incremented as each row + // is processed, regardless of whether it is flushed. + RowInsertCount int + RowUpdateCount int + RowDeleteCount int + + FlushedRowCount int // sum of rows flushed, regardless of table and including all insert/update/delete events + + // how many times the flush function was executed for the whole stream. Not incremented for no-ops + FlushCount int + // sum of successful, individual, table flush functions. Only increments if the table flush func is called + TableFlushCount int + + LastFlushedAt time.Time // only set after a flush successfully completes +} + +// New initializes a new VStreamClient, which is used to stream binlog events from Vitess. +func New(ctx context.Context, name string, conn *vtgateconn.VTGateConn, tables []TableConfig, opts ...Option) (*VStreamClient, error) { + // validate required parameters + if len(name) > 64 { + return nil, fmt.Errorf("vstreamclient: name must be 64 characters or less, got %d", len(name)) + } + + if conn == nil { + return nil, fmt.Errorf("vstreamclient: conn is required") + } + + // initialize the VStreamClient, with options and settings to be set later + v := &VStreamClient{ + name: name, + conn: conn, + session: conn.Session("", nil), + tables: make(map[string]*TableConfig), + minFlushDuration: DefaultMinFlushDuration, + } + + var err error + + // load all shards, so we can validate settings before starting. It's not technically necessary to do this here, + // but it's more user-friendly to fail early if there is misconfiguration. This needs to be done before running + // the options, so that the shards are available for validation. 
+	v.shardsByKeyspace, err = getShardsByKeyspace(ctx, v.session)
+	if err != nil {
+		return nil, err
+	}
+
+	err = v.initTables(tables)
+	if err != nil {
+		return nil, err
+	}
+
+	// set options from the variadic list
+	for _, opt := range opts {
+		if err = opt(v); err != nil {
+			return nil, err
+		}
+	}
+
+	// capture a vgtid supplied via WithStartingVGtid, since the state lookup below would otherwise
+	// silently overwrite it
+	startingVgtid := v.latestVgtid
+
+	// validate required options and set defaults where possible
+
+	if len(v.tables) == 0 {
+		return nil, fmt.Errorf("vstreamclient: no tables configured")
+	}
+
+	if v.vgtidStateKeyspace == "" || v.vgtidStateTable == "" {
+		return nil, fmt.Errorf("vstreamclient: no state table configured, use WithStateTable")
+	}
+
+	// convert the tables into filter + rules
+
+	rules := make([]*binlogdatapb.Rule, 0, len(v.tables))
+
+	for _, table := range v.tables {
+		rules = append(rules, &binlogdatapb.Rule{
+			// rules match on table name; the keyspaces to stream are determined by the vgtid's ShardGtids
+			Match:  table.Table,
+			Filter: table.Query,
+		})
+	}
+
+	v.filter = &binlogdatapb.Filter{
+		Rules: rules,
+	}
+
+	if v.flags == nil {
+		v.flags = DefaultFlags()
+		if v.heartbeatSeconds > 0 {
+			v.flags.HeartbeatInterval = uint32(v.heartbeatSeconds)
+		}
+	}
+
+	// handle state lookup
+
+	err = initStateTable(ctx, v.session, v.vgtidStateKeyspace, v.vgtidStateTable)
+	if err != nil {
+		return nil, err
+	}
+
+	var storedTableConfig map[string]*TableConfig
+	var copyCompleted bool
+	v.latestVgtid, storedTableConfig, copyCompleted, err = getLatestVGtid(ctx, v.session, v.name, v.vgtidStateKeyspace, v.vgtidStateTable)
+	if err != nil {
+		return nil, err
+	}
+
+	if v.latestVgtid == nil {
+		// we need to bootstrap the stream, which means we need to create a new vgtid and store the table config
+		v.latestVgtid, err = initVGtid(ctx, v.session, v.name, v.vgtidStateKeyspace, v.vgtidStateTable, v.tables, v.shardsByKeyspace)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		// we need to check if the tables have changed since the last stream, to make
+		// sure users aren't expecting to catch up on a new table that was added after the last stream.
+		err = validateTableConfig(v.tables, storedTableConfig)
+		if err != nil {
+			return nil, err
+		}
+
+		// since we have a vgtid, but the copy never completed, vstream docs say we need to restart from the beginning
+		if !copyCompleted {
+			// TODO: we could probably handle the recovery, by resetting the vgtid to nil, and calling some user
+			// defined function to have them truncate/recreate the intended destination tables.
+			return nil, errors.New("vstreamclient: copy phase not completed, need to restart stream")
+		}
+	}
+
+	// an explicitly provided starting vgtid takes precedence over whatever was stored in the state table
+	if startingVgtid != nil {
+		v.latestVgtid = startingVgtid
+	}
+
+	// initialize the streamer
+
+	v.reader, err = conn.VStream(ctx, topodatapb.TabletType_REPLICA, v.latestVgtid, v.filter, v.flags)
+	if err != nil {
+		return nil, fmt.Errorf("vstreamclient: failed to create vstream: %w", err)
+	}
+
+	return v, nil
+}
+
+// Close closes the VStreamClient. It is currently a no-op placeholder: the stream itself is stopped by
+// cancelling the context passed to Run, and the underlying VTGateConn remains owned by the caller.
+func (v *VStreamClient) Close(ctx context.Context) error { + return nil +} + +func getShardsByKeyspace(ctx context.Context, session *vtgateconn.VTGateSession) (map[string][]string, error) { + query := "SHOW VITESS_SHARDS" + result, err := session.Execute(ctx, query, nil) + if err != nil { + return nil, fmt.Errorf("vstreamclient: failed to get shards by keyspace: %w", err) + } + + shardsByKeyspace := make(map[string][]string) + + for _, row := range result.Rows { + keyspace, shard, found := strings.Cut(row[0].ToString(), "/") + if !found { + return nil, fmt.Errorf("vstreamclient: failed to parse keyspace_id: %s", row[0].ToString()) + } + + shardsByKeyspace[keyspace] = append(shardsByKeyspace[keyspace], shard) + } + + return shardsByKeyspace, nil +} diff --git a/go/vt/vstreamclient/vstreamclient_test.go b/go/vt/vstreamclient/vstreamclient_test.go new file mode 100644 index 00000000000..9bdc1a4f700 --- /dev/null +++ b/go/vt/vstreamclient/vstreamclient_test.go @@ -0,0 +1,267 @@ +package vstreamclient + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "slices" + "testing" + "time" + + "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/vtgateconn" +) + +// Customer is the concrete type that will be built from the stream +type Customer struct { + ID int64 `vstream:"customer_id"` + Email string `vstream:"email"` + DeletedAt time.Time `vstream:"-"` +} + +func getConn(t *testing.T, ctx context.Context) *vtgateconn.VTGateConn { + t.Helper() + conn, err := vtgateconn.Dial(ctx, "localhost:15991") + if err != nil { + t.Fatal(err) + } + return conn +} + +// To run the tests, this currently expects the local example to be running +// ./101_initial_cluster.sh; mysql < ../common/insert_commerce_data.sql; ./201_customer_tablets.sh; ./202_move_tables.sh; ./203_switch_reads.sh; ./204_switch_writes.sh; ./205_clean_commerce.sh; ./301_customer_sharded.sh; ./302_new_shards.sh; ./303_reshard.sh; ./304_switch_reads.sh; ./305_switch_writes.sh; ./306_down_shard_0.sh; ./307_delete_shard_0.sh +func TestVStreamClient(t *testing.T) { + conn := getConn(t, context.Background()) + defer conn.Close() + + flushCount := 0 + gotCustomers := make([]*Customer, 0) + + tables := []TableConfig{{ + Keyspace: "customer", + Table: "customer", + MaxRowsPerFlush: 7, + DataType: &Customer{}, + FlushFn: func(ctx context.Context, rows []Row, meta FlushMeta) error { + flushCount++ + + fmt.Printf("upserting %d customers\n", len(rows)) + for i, row := range rows { + switch { + // delete event + case row.RowChange.After == nil: + customer := row.Data.(*Customer) + customer.DeletedAt = time.Now() + + gotCustomers = append(gotCustomers, customer) + fmt.Printf("deleting customer %d: %v\n", i, row) + + // insert event + case row.RowChange.Before == nil: + gotCustomers = append(gotCustomers, row.Data.(*Customer)) + fmt.Printf("inserting customer %d: %v\n", i, row) + + // update event + case row.RowChange.Before != nil: + gotCustomers = append(gotCustomers, row.Data.(*Customer)) + fmt.Printf("updating customer %d: %v\n", i, row) + } + } + + // a real implementation would do something more meaningful here. For a data warehouse type workload, + // it would probably look like streaming rows into the data warehouse, or for more complex versions, + // write newline delimited json or a parquet file to object storage, then trigger a load job. 
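+			// an illustrative sketch of such a sink (writeNDJSONBatch is a hypothetical helper):
+			//
+			//	if err := writeNDJSONBatch(ctx, rows, meta.LatestVGtid); err != nil {
+			//		return err // the vgtid won't be checkpointed, so the batch will be redelivered
+			//	}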
+ return nil + }, + }} + + t.Run("first vstream run, should succeed", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + vstreamClient, err := New(ctx, "bob", conn, tables, + WithMinFlushDuration(500*time.Millisecond), + WithHeartbeatSeconds(1), + WithStateTable("commerce", "vstreams"), + WithEventFunc(func(ctx context.Context, ev *binlogdatapb.VEvent) error { + fmt.Printf("** FIELD EVENT: %v\n", ev) + return nil + }, binlogdatapb.VEventType_FIELD), + ) + if err != nil { + t.Fatalf("failed to create VStreamClient: %v", err) + } + + err = vstreamClient.Run(ctx) + if err != nil && ctx.Err() == nil { + t.Fatalf("failed to run vstreamclient: %v", err) + } + + slices.SortFunc(gotCustomers, func(a, b *Customer) int { + return int(a.ID - b.ID) + }) + + wantCustomers := []*Customer{ + {ID: 1, Email: "alice@domain.com"}, + {ID: 2, Email: "bob@domain.com"}, + {ID: 3, Email: "charlie@domain.com"}, + {ID: 4, Email: "dan@domain.com"}, + {ID: 5, Email: "eve@domain.com"}, + } + + fmt.Printf("got %d customers | flushed %d times\n", len(gotCustomers), flushCount) + if !reflect.DeepEqual(gotCustomers, wantCustomers) { + t.Fatalf("got %d customers, want %d", len(gotCustomers), len(wantCustomers)) + } + }) + + // this should fail because we're going to restart the stream, but with an additional table + t.Run("second vstream run, should fail", func(t *testing.T) { + withAdditionalTable := append(tables, TableConfig{ + Keyspace: "customer", + Table: "corder", + DataType: &Customer{}, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err := New(ctx, "bob", conn, withAdditionalTable, + WithStateTable("commerce", "vstreams"), + ) + if err == nil { + t.Fatalf("expected VStreamClient error, got nil") + } else if err.Error() != "vstreamclient: provided tables do not match stored tables" { + t.Fatalf("expected error 'vstreamclient: provided tables do not match stored tables', got '%v'", err) + } + }) +} + +// Customer is the concrete type that will be built from the stream. This version implements +// the VStreamScanner interface to do custom mapping of fields. 
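+// When the DataType implements VStreamScanner, the reflection-based field mapper is bypassed entirely,
+// so no vstream/db/json struct tags are needed; VStreamScan receives the raw fields and values instead.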
+type CustomerWithScan struct { + ID int64 + Email string + + // the fields below aren't actually in the schema, but are added for illustrative purposes + EmailConfirmed bool + Details map[string]any + CreatedAt time.Time +} + +var _ VStreamScanner = (*CustomerWithScan)(nil) + +func (customer *CustomerWithScan) VStreamScan(fields []*querypb.Field, row []sqltypes.Value, rowEvent *binlogdatapb.RowEvent, rowChange *binlogdatapb.RowChange) error { + var err error + + for i := range row { + if row[i].IsNull() { + continue + } + + switch fields[i].Name { + case "customer_id": + customer.ID, err = row[i].ToCastInt64() + + case "email": + customer.Email = row[i].ToString() + + // the fields below aren't actually in the example schema, but are added to + // show how you should handle different data types + + case "email_confirmed": + customer.EmailConfirmed, err = row[i].ToBool() + + case "details": + // assume the details field is a json blob + var b []byte + b, err = row[i].ToBytes() + if err == nil { + err = json.Unmarshal(b, &customer.Details) + } + + case "created_at": + customer.CreatedAt, err = row[i].ToTime() + } + if err != nil { + return fmt.Errorf("error processing field %s: %w", fields[i].Name, err) + } + } + + return nil +} + +// To run the tests, this currently expects the local example to be running +// ./101_initial_cluster.sh; mysql < ../common/insert_commerce_data.sql; ./201_customer_tablets.sh; ./202_move_tables.sh; ./203_switch_reads.sh; ./204_switch_writes.sh; ./205_clean_commerce.sh; ./301_customer_sharded.sh; ./302_new_shards.sh; ./303_reshard.sh; ./304_switch_reads.sh; ./305_switch_writes.sh; ./306_down_shard_0.sh; ./307_delete_shard_0.sh +func TestVStreamClientWithScan(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := vtgateconn.Dial(ctx, "localhost:15991") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + flushCount := 0 + gotCustomers := make([]*CustomerWithScan, 0) + + tables := []TableConfig{{ + Keyspace: "customer", + Table: "customer", + MaxRowsPerFlush: 7, + FlushFn: func(ctx context.Context, rows []Row, meta FlushMeta) error { + flushCount++ + + fmt.Printf("upserting %d customers\n", len(rows)) + for i, row := range rows { + gotCustomers = append(gotCustomers, row.Data.(*CustomerWithScan)) + fmt.Printf("upserting customer %d: %v\n", i, row) + } + + // a real implementation would do something more meaningful here. For a data warehouse type workload, + // it would probably look like streaming rows into the data warehouse, or for more complex versions, + // write newline delimited json or a parquet file to object storage, then trigger a load job. 
+ return nil + }, + DataType: &CustomerWithScan{}, + }} + + vstreamClient, err := New(ctx, "bob2", conn, tables, + WithMinFlushDuration(500*time.Millisecond), + WithHeartbeatSeconds(1), + WithStateTable("commerce", "vstreams"), + WithEventFunc(func(ctx context.Context, ev *binlogdatapb.VEvent) error { + fmt.Printf("** FIELD EVENT: %v\n", ev) + return nil + }, binlogdatapb.VEventType_FIELD), + ) + if err != nil { + t.Fatalf("failed to create VStreamClient: %v", err) + } + + err = vstreamClient.Run(ctx) + if err != nil && ctx.Err() == nil { + t.Fatalf("failed to run vstreamclient: %v", err) + } + + slices.SortFunc(gotCustomers, func(a, b *CustomerWithScan) int { + return int(a.ID - b.ID) + }) + + wantCustomers := []*CustomerWithScan{ + {ID: 1, Email: "alice@domain.com"}, + {ID: 2, Email: "bob@domain.com"}, + {ID: 3, Email: "charlie@domain.com"}, + {ID: 4, Email: "dan@domain.com"}, + {ID: 5, Email: "eve@domain.com"}, + } + + fmt.Printf("got %d customers | flushed %d times\n", len(gotCustomers), flushCount) + if !reflect.DeepEqual(gotCustomers, wantCustomers) { + t.Fatalf("got %d customers, want %d", len(gotCustomers), len(wantCustomers)) + } +}
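+
+// ExampleNew is a compilable usage sketch of the package as a whole. The address, keyspace, and table
+// names are illustrative assumptions from the local example cluster, not requirements of the API.
+func ExampleNew() {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+
+	conn, err := vtgateconn.Dial(ctx, "localhost:15991")
+	if err != nil {
+		panic(err)
+	}
+	defer conn.Close()
+
+	tables := []TableConfig{{
+		Keyspace: "customer",
+		Table:    "customer",
+		DataType: &Customer{},
+		FlushFn: func(ctx context.Context, rows []Row, meta FlushMeta) error {
+			for _, row := range rows {
+				_ = row.Data.(*Customer) // send each row to the downstream system here
+			}
+			return nil
+		},
+	}}
+
+	v, err := New(ctx, "example", conn, tables, WithStateTable("commerce", "vstreams"))
+	if err != nil {
+		panic(err)
+	}
+
+	// Run blocks until the stream ends, the context is cancelled, or an error occurs
+	if err := v.Run(ctx); err != nil {
+		panic(err)
+	}
+}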