Skip to content

Commit

Permalink
enhance: Support consume command with Kafka MQ (#258)
Browse files Browse the repository at this point in the history
See also #252

Add Kafka support for `consume` command.
Make it possible to consume from channel checkpoint if any

---------

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored May 14, 2024
1 parent 3ef67d4 commit b715414
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ run:
- docs
- scripts
- internal/core
build-tags:
- "WKAFKA"

linters:
disable-all: true
Expand Down
15 changes: 15 additions & 0 deletions mq/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,25 @@ import (
"fmt"
"time"

"github.com/cockroachdb/errors"
"github.com/milvus-io/birdwatcher/mq/ifc"
"github.com/milvus-io/birdwatcher/mq/pulsar"
)

func ParsePositionFromCheckpoint(mqType string, messageID []byte) (ifc.MessageID, error) {
switch mqType {
case "pulsar":
return pulsar.DeserializePulsarMsgID(messageID)
default:
return nil, errors.Newf("not supported mq type: %s", mqType)
}
}

func ParseManualMessageID(mqType string, manualID int64) (ifc.MessageID, error) {
// pulsar not supported yet
return nil, errors.Newf("not supported mq type: %s", mqType)
}

func NewConsumer(mqType, address, channel string, config ifc.MqOption) (ifc.Consumer, error) {
groupID := fmt.Sprintf("group-id-%d", time.Now().UnixNano())
switch mqType {
Expand Down
24 changes: 23 additions & 1 deletion mq/factory_wkafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,38 @@ import (
"fmt"
"time"

"github.com/cockroachdb/errors"
"github.com/milvus-io/birdwatcher/mq/ifc"
"github.com/milvus-io/birdwatcher/mq/kafka"
"github.com/milvus-io/birdwatcher/mq/pulsar"
)

func ParsePositionFromCheckpoint(mqType string, messageID []byte) (ifc.MessageID, error) {
switch mqType {
case "pulsar":
return pulsar.DeserializePulsarMsgID(messageID)
case "kafka":
return kafka.DeserializeKafkaID(messageID), nil
default:
return nil, errors.Newf("not supported mq type: %s", mqType)
}
}

func ParseManualMessageID(mqType string, manualID int64) (ifc.MessageID, error) {
switch mqType {
// pulsar not supported yet
case "kafka":
return kafka.DeserializeKafkaID(kafka.SerializeKafkaID(manualID)), nil
default:
return nil, errors.Newf("not supported mq type: %s", mqType)
}
}

func NewConsumer(mqType, address, channel string, config ifc.MqOption) (ifc.Consumer, error) {
groupID := fmt.Sprintf("group-id-%d", time.Now().UnixNano())
switch mqType {
case "kafka":
return kafka.NewKafkaConsumer(address, channel, groupID)
return kafka.NewKafkaConsumer(address, channel, groupID, config)
case "pulsar":
return pulsar.NewPulsarConsumer(address, channel, groupID, config)
default:
Expand Down
1 change: 1 addition & 0 deletions mq/ifc/msgstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type Consumer interface {
GetLastMessageID() (MessageID, error)
GetLastMessage() (Message, error)
Consume() (Message, error)
Seek(MessageID) error
Close() error
}

Expand Down
37 changes: 36 additions & 1 deletion mq/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
"github.com/milvus-io/birdwatcher/mq/ifc"
)

const DefaultPartitionIdx = 0

type Consumer struct {
topic string
c *kafka.Consumer
}

func NewKafkaConsumer(address, topic, groupID string) (*Consumer, error) {
func NewKafkaConsumer(address, topic, groupID string, mqConfig ifc.MqOption) (*Consumer, error) {
config := &kafka.ConfigMap{
"bootstrap.servers": address,
"api.version.request": true,
Expand All @@ -31,6 +33,39 @@ func NewKafkaConsumer(address, topic, groupID string) (*Consumer, error) {
return &Consumer{topic: topic, c: c}, nil
}

func (k *Consumer) Consume() (ifc.Message, error) {
e, err := k.c.ReadMessage(time.Second * 5)
if err != nil {
return nil, err
}

return &kafkaMessage{msg: e}, nil
}

func (k *Consumer) Seek(id ifc.MessageID) error {
offset := kafka.Offset(id.(*kafkaID).messageID)
return k.internalSeek(offset, true)
}

func (k *Consumer) internalSeek(offset kafka.Offset, inclusive bool) error {
err := k.c.Assign([]kafka.TopicPartition{{Topic: &k.topic, Partition: DefaultPartitionIdx, Offset: offset}})
if err != nil {
return err
}

timeout := 0
// If seek timeout is not 0 the call twice will return error isStarted RD_KAFKA_RESP_ERR__STATE.
// if the timeout is 0 it will initiate the seek but return immediately without any error reporting
if err := k.c.Seek(kafka.TopicPartition{
Topic: &k.topic,
Partition: DefaultPartitionIdx,
Offset: offset,
}, timeout); err != nil {
return err
}
return nil
}

func (k *Consumer) GetLastMessageID() (ifc.MessageID, error) {
low, high, err := k.c.QueryWatermarkOffsets(k.topic, ifc.DefaultPartitionIdx, 1200)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions mq/kafka/kafka_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ func (kid *kafkaID) AtEarliestPosition() bool {
}

func (kid *kafkaID) Equal(msgID []byte) (bool, error) {
return kid.messageID == DeserializeKafkaID(msgID), nil
return kid.messageID == DeserializeKafkaID(msgID).messageID, nil
}

func (kid *kafkaID) LessOrEqualThan(msgID []byte) (bool, error) {
return kid.messageID <= DeserializeKafkaID(msgID), nil
return kid.messageID <= DeserializeKafkaID(msgID).messageID, nil
}

func (kid *kafkaID) String() string {
Expand All @@ -41,6 +41,6 @@ func SerializeKafkaID(messageID int64) []byte {
return b
}

func DeserializeKafkaID(messageID []byte) int64 {
return int64(ifc.Endian.Uint64(messageID))
func DeserializeKafkaID(messageID []byte) *kafkaID {
return &kafkaID{messageID: int64(ifc.Endian.Uint64(messageID))}
}
3 changes: 2 additions & 1 deletion mq/kafka/kafka_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/milvus-io/birdwatcher/mq/ifc"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -38,7 +39,7 @@ func TestConsumer(t *testing.T) {
fmt.Println("send finished, msg offset:", km.TopicPartition.Offset)
}

c, err := NewKafkaConsumer(address, topic, "gid")
c, err := NewKafkaConsumer(address, topic, "gid", ifc.MqOption{})
if err != nil {
t.Fatal("create consumer fail", err)
}
Expand Down
8 changes: 6 additions & 2 deletions mq/pulsar/pulsar_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@ func SerializePulsarMsgID(messageID pulsar.MessageID) []byte {
}

// DeserializePulsarMsgID returns the deserialized message ID
func DeserializePulsarMsgID(messageID []byte) (pulsar.MessageID, error) {
return pulsar.DeserializeMessageID(messageID)
func DeserializePulsarMsgID(messageID []byte) (ifc.MessageID, error) {
id, err := pulsar.DeserializeMessageID(messageID)
if err != nil {
return nil, err
}
return &pulsarID{messageID: id}, nil
}

// msgIDToString is used to convert a message ID to string
Expand Down
6 changes: 6 additions & 0 deletions mq/pulsar/puslar.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ func (p *pulsarConsumer) Consume() (ifc.Message, error) {
return &pulsarMessage{msg: msg}, nil
}

func (p *pulsarConsumer) Seek(id ifc.MessageID) error {
messageID := id.(*pulsarID).messageID
err := p.consumer.Seek(messageID)
return err
}

func (p *pulsarConsumer) GetLastMessageID() (ifc.MessageID, error) {
msgID, err := p.consumer.GetLastMessageID(p.topic, 0)
return &pulsarID{messageID: msgID}, err
Expand Down
88 changes: 87 additions & 1 deletion states/consume.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,77 @@ package states
import (
"context"
"fmt"
"path"

"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/birdwatcher/framework"
"github.com/milvus-io/birdwatcher/mq"
"github.com/milvus-io/birdwatcher/mq/ifc"
"github.com/milvus-io/birdwatcher/proto/v2.2/commonpb"
"github.com/milvus-io/birdwatcher/proto/v2.2/msgpb"
"github.com/milvus-io/birdwatcher/proto/v2.2/schemapb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
)

type ConsumeParam struct {
framework.ParamBase `use:"consume" desc:"consume msgs from provided topic"`
StartPosition string `name:"start_pos" default:"cp" desc:"position to start with"`
MqType string `name:"mq_type" default:"pulsar" desc:"message queue type to consume"`
MqAddress string `name:"mq_addr" default:"pulsar://127.0.0.1:6650" desc:"message queue service address"`
Topic string `name:"topic" default:"" desc:"topic to consume"`
ShardName string `name:"shard_name" default:"" desc:"shard name(vchannel name) to filter with"`
Detail bool `name:"detail" default:"false" desc:"print msg detail"`
ManualID int64 `name:"manual_id" default:"0" desc:"manual id"`
}

func (s *InstanceState) ConsumeCommand(ctx context.Context, p *ConsumeParam) error {

var messageID ifc.MessageID
switch p.StartPosition {
case "cp":
prefix := path.Join(s.basePath, "datacoord-meta", "channel-cp", p.ShardName)
results, _, err := common.ListProtoObjects[msgpb.MsgPosition](ctx, s.client, prefix)
if err != nil {
return err
}
if len(results) == 1 {
checkpoint := results[0]
messageID, err = mq.ParsePositionFromCheckpoint(p.MqType, checkpoint.GetMsgID())
if err != nil {
return err
}
}
case "manual":
var err error
messageID, err = mq.ParseManualMessageID(p.MqType, p.ManualID)
if err != nil {
return err
}
default:
}

subPos := ifc.SubscriptionPositionEarliest
if messageID != nil {
subPos = ifc.SubscriptionPositionLatest
}

c, err := mq.NewConsumer(p.MqType, p.MqAddress, p.Topic, ifc.MqOption{
SubscriptionInitPos: ifc.SubscriptionPositionEarliest,
SubscriptionInitPos: subPos,
})

if err != nil {
return err
}

if messageID != nil {
fmt.Println("Using message ID to seek", messageID)
err := c.Seek(messageID)
if err != nil {
return err
}
}

latestID, err := c.GetLastMessageID()
if err != nil {
return err
Expand Down Expand Up @@ -59,6 +104,10 @@ func (s *InstanceState) ConsumeCommand(ctx context.Context, p *ConsumeParam) err
fmt.Print(v)
} else {
fmt.Print(v.GetShardName())
err := ValidateMsg(msgType, msg.Payload())
if err != nil {
fmt.Println(err.Error())
}
}
}
default:
Expand Down Expand Up @@ -92,3 +141,40 @@ func ParseMsg(msgType commonpb.MsgType, payload []byte) (interface {
}
return msg, nil
}

func ValidateMsg(msgType commonpb.MsgType, payload []byte) error {
switch msgType {
case commonpb.MsgType_Insert:
msg := &msgpb.InsertRequest{}
proto.Unmarshal(payload, msg)
for _, fieldData := range msg.GetFieldsData() {
msgType := fieldData.GetType()
switch msgType {
case schemapb.DataType_Int64:
l := len(fieldData.GetScalars().GetLongData().GetData())
if l != int(msg.GetNumRows()) {
return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows())
}
case schemapb.DataType_VarChar:
l := len(fieldData.GetScalars().GetStringData().GetData())
if l != int(msg.GetNumRows()) {
return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows())
}
case schemapb.DataType_Bool:
l := len(fieldData.GetScalars().GetBoolData().GetData())
if l != int(msg.GetNumRows()) {
return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows())
}
case schemapb.DataType_FloatVector:
l := len(fieldData.GetVectors().GetFloatVector().GetData())
dim := fieldData.GetVectors().GetDim()
if l/int(dim) != int(msg.GetNumRows()) {
return errors.Newf("Field %d(%s) len = %d, datatype %v mismatch num rows: %d", fieldData.GetFieldId(), fieldData.GetFieldName(), l, msgType, msg.GetNumRows())
}
default:
fmt.Println("skip unhanlded data type", fieldData.GetType())
}
}
}
return nil
}

0 comments on commit b715414

Please sign in to comment.