From b715414e3cb7ec678b974eec25ec43b99c3464e0 Mon Sep 17 00:00:00 2001 From: congqixia Date: Tue, 14 May 2024 11:51:32 +0800 Subject: [PATCH] enhance: Support `consume` command with Kafka MQ (#258) See also #252 Add Kafka support for `consume` command. Make it possible to consume from channel checkpoint if any --------- Signed-off-by: Congqi Xia --- .golangci.yml | 2 + mq/factory.go | 15 +++++++ mq/factory_wkafka.go | 24 +++++++++++- mq/ifc/msgstream.go | 1 + mq/kafka/kafka.go | 37 +++++++++++++++++- mq/kafka/kafka_id.go | 8 ++-- mq/kafka/kafka_test.go | 3 +- mq/pulsar/pulsar_id.go | 8 +++- mq/pulsar/puslar.go | 6 +++ states/consume.go | 88 +++++++++++++++++++++++++++++++++++++++++- 10 files changed, 182 insertions(+), 10 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index b52cd3d..8cf223d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -6,6 +6,8 @@ run: - docs - scripts - internal/core + build-tags: + - "WKAFKA" linters: disable-all: true diff --git a/mq/factory.go b/mq/factory.go index 6685712..df529aa 100644 --- a/mq/factory.go +++ b/mq/factory.go @@ -7,10 +7,25 @@ import ( "fmt" "time" + "github.com/cockroachdb/errors" "github.com/milvus-io/birdwatcher/mq/ifc" "github.com/milvus-io/birdwatcher/mq/pulsar" ) +func ParsePositionFromCheckpoint(mqType string, messageID []byte) (ifc.MessageID, error) { + switch mqType { + case "pulsar": + return pulsar.DeserializePulsarMsgID(messageID) + default: + return nil, errors.Newf("not supported mq type: %s", mqType) + } +} + +func ParseManualMessageID(mqType string, manualID int64) (ifc.MessageID, error) { + // pulsar not supported yet + return nil, errors.Newf("not supported mq type: %s", mqType) +} + func NewConsumer(mqType, address, channel string, config ifc.MqOption) (ifc.Consumer, error) { groupID := fmt.Sprintf("group-id-%d", time.Now().UnixNano()) switch mqType { diff --git a/mq/factory_wkafka.go b/mq/factory_wkafka.go index 72f11f9..9b85c1f 100644 --- a/mq/factory_wkafka.go +++ b/mq/factory_wkafka.go @@ -7,16 +7,38 @@ import ( "fmt" "time" + "github.com/cockroachdb/errors" "github.com/milvus-io/birdwatcher/mq/ifc" "github.com/milvus-io/birdwatcher/mq/kafka" "github.com/milvus-io/birdwatcher/mq/pulsar" ) +func ParsePositionFromCheckpoint(mqType string, messageID []byte) (ifc.MessageID, error) { + switch mqType { + case "pulsar": + return pulsar.DeserializePulsarMsgID(messageID) + case "kafka": + return kafka.DeserializeKafkaID(messageID), nil + default: + return nil, errors.Newf("not supported mq type: %s", mqType) + } +} + +func ParseManualMessageID(mqType string, manualID int64) (ifc.MessageID, error) { + switch mqType { + // pulsar not supported yet + case "kafka": + return kafka.DeserializeKafkaID(kafka.SerializeKafkaID(manualID)), nil + default: + return nil, errors.Newf("not supported mq type: %s", mqType) + } +} + func NewConsumer(mqType, address, channel string, config ifc.MqOption) (ifc.Consumer, error) { groupID := fmt.Sprintf("group-id-%d", time.Now().UnixNano()) switch mqType { case "kafka": - return kafka.NewKafkaConsumer(address, channel, groupID) + return kafka.NewKafkaConsumer(address, channel, groupID, config) case "pulsar": return pulsar.NewPulsarConsumer(address, channel, groupID, config) default: diff --git a/mq/ifc/msgstream.go b/mq/ifc/msgstream.go index fc11cdb..42bef23 100644 --- a/mq/ifc/msgstream.go +++ b/mq/ifc/msgstream.go @@ -8,6 +8,7 @@ type Consumer interface { GetLastMessageID() (MessageID, error) GetLastMessage() (Message, error) Consume() (Message, error) + Seek(MessageID) error Close() error } diff --git a/mq/kafka/kafka.go b/mq/kafka/kafka.go index 93fc423..d383e57 100644 --- a/mq/kafka/kafka.go +++ b/mq/kafka/kafka.go @@ -12,12 +12,14 @@ import ( "github.com/milvus-io/birdwatcher/mq/ifc" ) +const DefaultPartitionIdx = 0 + type Consumer struct { topic string c *kafka.Consumer } -func NewKafkaConsumer(address, topic, groupID string) (*Consumer, error) { +func NewKafkaConsumer(address, topic, groupID string, mqConfig ifc.MqOption) (*Consumer, error) { config := &kafka.ConfigMap{ "bootstrap.servers": address, "api.version.request": true, @@ -31,6 +33,39 @@ func NewKafkaConsumer(address, topic, groupID string) (*Consumer, error) { return &Consumer{topic: topic, c: c}, nil } +func (k *Consumer) Consume() (ifc.Message, error) { + e, err := k.c.ReadMessage(time.Second * 5) + if err != nil { + return nil, err + } + + return &kafkaMessage{msg: e}, nil +} + +func (k *Consumer) Seek(id ifc.MessageID) error { + offset := kafka.Offset(id.(*kafkaID).messageID) + return k.internalSeek(offset, true) +} + +func (k *Consumer) internalSeek(offset kafka.Offset, inclusive bool) error { + err := k.c.Assign([]kafka.TopicPartition{{Topic: &k.topic, Partition: DefaultPartitionIdx, Offset: offset}}) + if err != nil { + return err + } + + timeout := 0 + // If seek timeout is not 0 the call twice will return error isStarted RD_KAFKA_RESP_ERR__STATE. + // if the timeout is 0 it will initiate the seek but return immediately without any error reporting + if err := k.c.Seek(kafka.TopicPartition{ + Topic: &k.topic, + Partition: DefaultPartitionIdx, + Offset: offset, + }, timeout); err != nil { + return err + } + return nil +} + func (k *Consumer) GetLastMessageID() (ifc.MessageID, error) { low, high, err := k.c.QueryWatermarkOffsets(k.topic, ifc.DefaultPartitionIdx, 1200) if err != nil { diff --git a/mq/kafka/kafka_id.go b/mq/kafka/kafka_id.go index 8554fbd..b3fa24d 100644 --- a/mq/kafka/kafka_id.go +++ b/mq/kafka/kafka_id.go @@ -24,11 +24,11 @@ func (kid *kafkaID) AtEarliestPosition() bool { } func (kid *kafkaID) Equal(msgID []byte) (bool, error) { - return kid.messageID == DeserializeKafkaID(msgID), nil + return kid.messageID == DeserializeKafkaID(msgID).messageID, nil } func (kid *kafkaID) LessOrEqualThan(msgID []byte) (bool, error) { - return kid.messageID <= DeserializeKafkaID(msgID), nil + return kid.messageID <= DeserializeKafkaID(msgID).messageID, nil } func (kid *kafkaID) String() string { @@ -41,6 +41,6 @@ func SerializeKafkaID(messageID int64) []byte { return b } -func DeserializeKafkaID(messageID []byte) int64 { - return int64(ifc.Endian.Uint64(messageID)) +func DeserializeKafkaID(messageID []byte) *kafkaID { + return &kafkaID{messageID: int64(ifc.Endian.Uint64(messageID))} } diff --git a/mq/kafka/kafka_test.go b/mq/kafka/kafka_test.go index 33dcad4..5102de2 100644 --- a/mq/kafka/kafka_test.go +++ b/mq/kafka/kafka_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/birdwatcher/mq/ifc" "github.com/stretchr/testify/assert" ) @@ -38,7 +39,7 @@ func TestConsumer(t *testing.T) { fmt.Println("send finished, msg offset:", km.TopicPartition.Offset) } - c, err := NewKafkaConsumer(address, topic, "gid") + c, err := NewKafkaConsumer(address, topic, "gid", ifc.MqOption{}) if err != nil { t.Fatal("create consumer fail", err) } diff --git a/mq/pulsar/pulsar_id.go b/mq/pulsar/pulsar_id.go index e9d5f64..397bcf4 100644 --- a/mq/pulsar/pulsar_id.go +++ b/mq/pulsar/pulsar_id.go @@ -72,8 +72,12 @@ func SerializePulsarMsgID(messageID pulsar.MessageID) []byte { } // DeserializePulsarMsgID returns the deserialized message ID -func DeserializePulsarMsgID(messageID []byte) (pulsar.MessageID, error) { - return pulsar.DeserializeMessageID(messageID) +func DeserializePulsarMsgID(messageID []byte) (ifc.MessageID, error) { + id, err := pulsar.DeserializeMessageID(messageID) + if err != nil { + return nil, err + } + return &pulsarID{messageID: id}, nil } // msgIDToString is used to convert a message ID to string diff --git a/mq/pulsar/puslar.go b/mq/pulsar/puslar.go index d0c82b5..87dda6f 100644 --- a/mq/pulsar/puslar.go +++ b/mq/pulsar/puslar.go @@ -49,6 +49,12 @@ func (p *pulsarConsumer) Consume() (ifc.Message, error) { return &pulsarMessage{msg: msg}, nil } +func (p *pulsarConsumer) Seek(id ifc.MessageID) error { + messageID := id.(*pulsarID).messageID + err := p.consumer.Seek(messageID) + return err +} + func (p *pulsarConsumer) GetLastMessageID() (ifc.MessageID, error) { msgID, err := p.consumer.GetLastMessageID(p.topic, 0) return &pulsarID{messageID: msgID}, err diff --git a/states/consume.go b/states/consume.go index ced69e7..b32e02b 100644 --- a/states/consume.go +++ b/states/consume.go @@ -3,32 +3,77 @@ package states import ( "context" "fmt" + "path" + "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/milvus-io/birdwatcher/framework" "github.com/milvus-io/birdwatcher/mq" "github.com/milvus-io/birdwatcher/mq/ifc" "github.com/milvus-io/birdwatcher/proto/v2.2/commonpb" "github.com/milvus-io/birdwatcher/proto/v2.2/msgpb" + "github.com/milvus-io/birdwatcher/proto/v2.2/schemapb" + "github.com/milvus-io/birdwatcher/states/etcd/common" ) type ConsumeParam struct { framework.ParamBase `use:"consume" desc:"consume msgs from provided topic"` + StartPosition string `name:"start_pos" default:"cp" desc:"position to start with"` MqType string `name:"mq_type" default:"pulsar" desc:"message queue type to consume"` MqAddress string `name:"mq_addr" default:"pulsar://127.0.0.1:6650" desc:"message queue service address"` Topic string `name:"topic" default:"" desc:"topic to consume"` ShardName string `name:"shard_name" default:"" desc:"shard name(vchannel name) to filter with"` Detail bool `name:"detail" default:"false" desc:"print msg detail"` + ManualID int64 `name:"manual_id" default:"0" desc:"manual id"` } func (s *InstanceState) ConsumeCommand(ctx context.Context, p *ConsumeParam) error { + + var messageID ifc.MessageID + switch p.StartPosition { + case "cp": + prefix := path.Join(s.basePath, "datacoord-meta", "channel-cp", p.ShardName) + results, _, err := common.ListProtoObjects[msgpb.MsgPosition](ctx, s.client, prefix) + if err != nil { + return err + } + if len(results) == 1 { + checkpoint := results[0] + messageID, err = mq.ParsePositionFromCheckpoint(p.MqType, checkpoint.GetMsgID()) + if err != nil { + return err + } + } + case "manual": + var err error + messageID, err = mq.ParseManualMessageID(p.MqType, p.ManualID) + if err != nil { + return err + } + default: + } + + subPos := ifc.SubscriptionPositionEarliest + if messageID != nil { + subPos = ifc.SubscriptionPositionLatest + } + c, err := mq.NewConsumer(p.MqType, p.MqAddress, p.Topic, ifc.MqOption{ - SubscriptionInitPos: ifc.SubscriptionPositionEarliest, + SubscriptionInitPos: subPos, }) + if err != nil { return err } + if messageID != nil { + fmt.Println("Using message ID to seek", messageID) + err := c.Seek(messageID) + if err != nil { + return err + } + } + latestID, err := c.GetLastMessageID() if err != nil { return err @@ -59,6 +104,10 @@ func (s *InstanceState) ConsumeCommand(ctx context.Context, p *ConsumeParam) err fmt.Print(v) } else { fmt.Print(v.GetShardName()) + err := ValidateMsg(msgType, msg.Payload()) + if err != nil { + fmt.Println(err.Error()) + } } } default: @@ -92,3 +141,40 @@ func ParseMsg(msgType commonpb.MsgType, payload []byte) (interface { } return msg, nil } + +func ValidateMsg(msgType commonpb.MsgType, payload []byte) error { + switch msgType { + case commonpb.MsgType_Insert: + msg := &msgpb.InsertRequest{} + proto.Unmarshal(payload, msg) + for _, fieldData := range msg.GetFieldsData() { + msgType := fieldData.GetType() + switch msgType { + case schemapb.DataType_Int64: + l := len(fieldData.GetScalars().GetLongData().GetData()) + if l != int(msg.GetNumRows()) { + return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows()) + } + case schemapb.DataType_VarChar: + l := len(fieldData.GetScalars().GetStringData().GetData()) + if l != int(msg.GetNumRows()) { + return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows()) + } + case schemapb.DataType_Bool: + l := len(fieldData.GetScalars().GetBoolData().GetData()) + if l != int(msg.GetNumRows()) { + return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows()) + } + case schemapb.DataType_FloatVector: + l := len(fieldData.GetVectors().GetFloatVector().GetData()) + dim := fieldData.GetVectors().GetDim() + if l/int(dim) != int(msg.GetNumRows()) { + return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows()) + } + default: + fmt.Println("skip unhanlded data type", fieldData.GetType()) + } + } + } + return nil +}