Skip to content

Commit

Permalink
[#387]: feature: use PHP-SDK version with the new temporal-php gR…
Browse files Browse the repository at this point in the history
…PC header
  • Loading branch information
rustatian authored Jul 6, 2023
2 parents 432c8ad + ded546b commit 4857c77
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 63 deletions.
9 changes: 3 additions & 6 deletions internal/worker_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ import (
// WorkerInfo lists available task queues, workflows and activities.
type WorkerInfo struct {
// TaskQueue assigned to the worker.
TaskQueue string `json:"taskQueue"`

TaskQueue string `json:"TaskQueue"`
// Options describe worker options.
Options worker.Options `json:"options,omitempty"`

// PhpSdkVersion is the underlying PHP-SDK version
PhpSdkVersion string `json:"PhpSdkVersion,omitempty"`
// Workflows provided by the worker.
Workflows []WorkflowInfo

// Activities provided by the worker.
Activities []ActivityInfo
}
Expand All @@ -25,10 +24,8 @@ type WorkerInfo struct {
type WorkflowInfo struct {
// Name of the workflow.
Name string `json:"name"`

// Queries pre-defined for the workflow type.
Queries []string `json:"queries"`

// Signals pre-defined for the workflow type.
Signals []string `json:"signals"`
}
Expand Down
129 changes: 72 additions & 57 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ const (

// temporal, sync with https://github.com/temporalio/sdk-go/blob/master/internal/internal_utils.go#L44
clientNameHeaderName = "client-name"
clientNameHeaderValue = "roadrunner-temporal"
clientNameHeaderValue = "temporal-php-2"
clientVersionHeaderName = "client-version"
clientBaselineVersion = "2.5.0"
)

type Logger interface {
Expand Down Expand Up @@ -235,35 +236,7 @@ func (p *Plugin) Serve() chan error {
p.mu.Lock()
defer p.mu.Unlock()

worker.SetStickyWorkflowCacheSize(p.config.CacheSize)

dc := data_converter.NewDataConverter(converter.GetDefaultDataConverter())

opts := temporalClient.Options{
HostPort: p.config.Address,
MetricsHandler: p.mh,
Namespace: p.config.Namespace,
Logger: logger.NewZapAdapter(p.log),
DataConverter: dc,
ConnectionOptions: temporalClient.ConnectionOptions{
TLS: p.tlsCfg,
DialOptions: []grpc.DialOption{
grpc.WithUnaryInterceptor(p.rewriteNameAndVersion),
},
},
}

var err error
p.client, err = temporalClient.Dial(opts)
if err != nil {
errCh <- errors.E(op, err)
return errCh
}

p.log.Info("connected to temporal server", zap.String("address", p.config.Address))
p.codec = proto.NewCodec(p.log, dc)

err = p.initPool()
err := p.initPool()
if err != nil {
errCh <- errors.E(op, err)
return errCh
Expand Down Expand Up @@ -450,6 +423,19 @@ func (p *Plugin) Reset() error {
return nil
}

// Collects collecting grpc interceptors
func (p *Plugin) Collects() []*dep.In {
return []*dep.In{
dep.Fits(func(pp any) {
mdw := pp.(common.Interceptor)
// just to be safe
p.mu.Lock()
p.interceptors[mdw.Name()] = mdw
p.mu.Unlock()
}, (*common.Interceptor)(nil)),
}
}

func (p *Plugin) Name() string {
return pluginName
}
Expand All @@ -458,24 +444,54 @@ func (p *Plugin) RPC() any {
return &rpc{srv: p, client: p.client}
}

func (p *Plugin) rewriteNameAndVersion(
ctx context.Context,
method string,
req, reply interface{},
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption) error {
md, _, _ := metadata.FromOutgoingContextRaw(ctx)
if md == nil {
return invoker(ctx, method, req, reply, cc, opts...)
/// INTERNAL

func (p *Plugin) initTemporalClient(phpSdkVersion string, dc converter.DataConverter) error {
if phpSdkVersion == "" {
phpSdkVersion = clientBaselineVersion
}
p.log.Debug("PHP-SDK version: " + phpSdkVersion)
worker.SetStickyWorkflowCacheSize(p.config.CacheSize)

md.Set(clientNameHeaderName, clientNameHeaderValue)
md.Set(clientVersionHeaderName, p.rrVersion)
opts := temporalClient.Options{
HostPort: p.config.Address,
MetricsHandler: p.mh,
Namespace: p.config.Namespace,
Logger: logger.NewZapAdapter(p.log),
DataConverter: dc,
ConnectionOptions: temporalClient.ConnectionOptions{
TLS: p.tlsCfg,
DialOptions: []grpc.DialOption{
grpc.WithUnaryInterceptor(rewriteNameAndVersion(phpSdkVersion)),
},
},
}

ctx = metadata.NewOutgoingContext(ctx, md)
var err error
p.client, err = temporalClient.Dial(opts)
if err != nil {
return err
}

p.log.Info("connected to temporal server", zap.String("address", p.config.Address))

return nil
}

return invoker(ctx, method, req, reply, cc, opts...)
func rewriteNameAndVersion(phpSdkVersion string) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
md, _, _ := metadata.FromOutgoingContextRaw(ctx)
if md == nil {
return invoker(ctx, method, req, reply, cc, opts...)
}

md.Set(clientNameHeaderName, clientNameHeaderValue)
md.Set(clientVersionHeaderName, phpSdkVersion)

ctx = metadata.NewOutgoingContext(ctx, md)

return invoker(ctx, method, req, reply, cc, opts...)
}
}

func (p *Plugin) initPool() error {
Expand All @@ -485,6 +501,9 @@ func (p *Plugin) initPool() error {
return err
}

dc := data_converter.NewDataConverter(converter.GetDefaultDataConverter())
p.codec = proto.NewCodec(p.log, dc)

p.rrActivityDef = aggregatedpool.NewActivityDefinition(p.codec, ap, p.log)

// ---------- WORKFLOW POOL -------------
Expand Down Expand Up @@ -513,6 +532,15 @@ func (p *Plugin) initPool() error {
return err
}

if len(wi) == 0 {
return errors.Str("worker info should contain at least 1 worker")
}

err = p.initTemporalClient(wi[0].PhpSdkVersion, dc)
if err != nil {
return err
}

p.workers, err = aggregatedpool.TemporalWorkers(p.rrWorkflowDef, p.rrActivityDef, wi, p.log, p.client, p.interceptors)
if err != nil {
return err
Expand All @@ -539,16 +567,3 @@ func (p *Plugin) initPool() error {

return nil
}

// Collects collecting grpc interceptors
func (p *Plugin) Collects() []*dep.In {
return []*dep.In{
dep.Fits(func(pp any) {
mdw := pp.(common.Interceptor)
// just to be safe
p.mu.Lock()
p.interceptors[mdw.Name()] = mdw
p.mu.Unlock()
}, (*common.Interceptor)(nil)),
}
}

0 comments on commit 4857c77

Please sign in to comment.