diff --git a/component/component.go b/component/component.go index ee6b8e8..400517a 100644 --- a/component/component.go +++ b/component/component.go @@ -34,13 +34,13 @@ func (c *Component) WithDescription(description string) *Component { // WithInputs ads input ports func (c *Component) WithInputs(portNames ...string) *Component { - c.inputs.Add(port.NewPortGroup(portNames...)...) + c.inputs = c.inputs.Add(port.NewPortGroup(portNames...)...) return c } // WithOutputs adds output ports func (c *Component) WithOutputs(portNames ...string) *Component { - c.outputs.Add(port.NewPortGroup(portNames...)...) + c.outputs = c.outputs.Add(port.NewPortGroup(portNames...)...) return c } @@ -92,7 +92,6 @@ func (c *Component) MaybeActivate() (activationResult *ActivationResult) { return } - //@TODO:: https://github.com/hovsep/fmesh/issues/15 if !c.inputs.AnyHasSignal() { //No inputs set, stop here activationResult = c.newActivationCodeNoInput() @@ -103,10 +102,10 @@ func (c *Component) MaybeActivate() (activationResult *ActivationResult) { //Run the computation err := c.f(c.inputs, c.outputs) - if IsWaitingForInputError(err) { + if errors.Is(err, errWaitingForInputs) { activationResult = c.newActivationCodeWaitingForInput() - if !errors.Is(err, ErrWaitingForInputKeepInputs) { + if !errors.Is(err, errWaitingForInputsKeep) { c.inputs.ClearSignal() } diff --git a/component/component_test.go b/component/component_test.go index 2875d4d..1cc3554 100644 --- a/component/component_test.go +++ b/component/component_test.go @@ -171,8 +171,8 @@ func TestComponent_Inputs(t *testing.T) { name: "with inputs", component: NewComponent("c1").WithInputs("i1", "i2"), want: port.Collection{ - "i1": port.NewPort("i1"), - "i2": port.NewPort("i2"), + port.NewPort("i1"), + port.NewPort("i2"), }, }, } @@ -200,8 +200,8 @@ func TestComponent_Outputs(t *testing.T) { name: "with outputs", component: NewComponent("c1").WithOutputs("o1", "o2"), want: port.Collection{ - "o1": port.NewPort("o1"), - "o2": port.NewPort("o2"), + port.NewPort("o1"), + port.NewPort("o2"), }, }, } @@ -306,8 +306,8 @@ func TestComponent_WithInputs(t *testing.T) { name: "c1", description: "", inputs: port.Collection{ - "p1": port.NewPort("p1"), - "p2": port.NewPort("p2"), + port.NewPort("p1"), + port.NewPort("p2"), }, outputs: port.Collection{}, f: nil, @@ -358,8 +358,8 @@ func TestComponent_WithOutputs(t *testing.T) { description: "", inputs: port.Collection{}, outputs: port.Collection{ - "p1": port.NewPort("p1"), - "p2": port.NewPort("p2"), + port.NewPort("p1"), + port.NewPort("p2"), }, f: nil, }, @@ -420,7 +420,7 @@ func TestComponent_MaybeActivate(t *testing.T) { WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { if !inputs.ByNames("i1", "i2").AllHaveSignal() { - return ErrWaitingForInputResetInputs + return NewErrWaitForInputs(false) } return nil @@ -432,14 +432,35 @@ func TestComponent_MaybeActivate(t *testing.T) { WithActivationCode(ActivationCodeNoInput), }, { - name: "component is waiting for input", + name: "component is waiting for input, reset inputs", getComponent: func() *Component { c := NewComponent("c1"). WithInputs("i1", "i2"). WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { if !inputs.ByNames("i1", "i2").AllHaveSignal() { - return ErrWaitingForInputResetInputs + return NewErrWaitForInputs(false) + } + + return nil + }) + //Only one input set + c.Inputs().ByName("i1").PutSignal(signal.New(123)) + return c + }, + wantActivationResult: NewActivationResult("c1"). + SetActivated(false). + WithActivationCode(ActivationCodeWaitingForInput), + }, + { + name: "component is waiting for input, keep inputs", + getComponent: func() *Component { + c := NewComponent("c1"). + WithInputs("i1", "i2"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + + if !inputs.ByNames("i1", "i2").AllHaveSignal() { + return NewErrWaitForInputs(true) } return nil diff --git a/component/errors.go b/component/errors.go index 4708a4e..671df5f 100644 --- a/component/errors.go +++ b/component/errors.go @@ -1,13 +1,19 @@ package component -import "errors" +import ( + "errors" + "fmt" +) var ( - //@TODO: provide wrapper methods so exact input can be specified within error - ErrWaitingForInputResetInputs = errors.New("component is waiting for one or more inputs. All inputs will be reset") - ErrWaitingForInputKeepInputs = errors.New("component is waiting for one or more inputs. All inputs will be kept") + errWaitingForInputs = errors.New("component is waiting for some inputs") + errWaitingForInputsKeep = fmt.Errorf("%w: do not clear input ports", errWaitingForInputs) ) -func IsWaitingForInputError(err error) bool { - return errors.Is(err, ErrWaitingForInputResetInputs) || errors.Is(err, ErrWaitingForInputKeepInputs) +// NewErrWaitForInputs returns respective error +func NewErrWaitForInputs(keepInputs bool) error { + if keepInputs { + return errWaitingForInputsKeep + } + return errWaitingForInputs } diff --git a/component/errors_test.go b/component/errors_test.go deleted file mode 100644 index b1a631b..0000000 --- a/component/errors_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package component - -import ( - "errors" - "testing" -) - -func TestIsWaitingForInputError(t *testing.T) { - type args struct { - err error - } - tests := []struct { - name string - args args - want bool - }{ - { - name: "no", - args: args{ - err: errors.New("test error"), - }, - want: false, - }, - { - name: "yes", - args: args{ - err: ErrWaitingForInputKeepInputs, - }, - want: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := IsWaitingForInputError(tt.args.err); got != tt.want { - t.Errorf("IsWaitingForInputError() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/fmesh.go b/fmesh.go index bb28bcf..7356a9c 100644 --- a/fmesh.go +++ b/fmesh.go @@ -53,45 +53,30 @@ func (fm *FMesh) WithErrorHandlingStrategy(strategy ErrorHandlingStrategy) *FMes return fm } -// runCycle runs one activation cycle (tries to activate all components) +// runCycle runs one activation cycle (tries to activate ready components) func (fm *FMesh) runCycle() *cycle.Cycle { - cycleResult := cycle.New() + newCycle := cycle.New() if len(fm.components) == 0 { - return cycleResult + return newCycle } - activationResultsChan := make(chan *component.ActivationResult) //@TODO: close the channel - doneChan := make(chan struct{}) //@TODO: close the channel - var wg sync.WaitGroup - go func() { - for { - select { - case aRes := <-activationResultsChan: - //@TODO :check for closed channel - cycleResult.Lock() - cycleResult.ActivationResults().Add(aRes) - cycleResult.Unlock() - case <-doneChan: - return - } - } - }() - for _, c := range fm.components { wg.Add(1) - c := c //@TODO: check if this needed - go func() { + + go func(component *component.Component, cycle *cycle.Cycle) { defer wg.Done() - activationResultsChan <- c.MaybeActivate() - }() + + cycle.Lock() + cycle.ActivationResults().Add(c.MaybeActivate()) + cycle.Unlock() + }(c, newCycle) } wg.Wait() - doneChan <- struct{}{} //@TODO: no need to send close signal, just close the channel - return cycleResult + return newCycle } // DrainComponents drains the data from all components outputs diff --git a/fmesh_test.go b/fmesh_test.go index 5cfb484..3cb8b93 100644 --- a/fmesh_test.go +++ b/fmesh_test.go @@ -669,7 +669,7 @@ func TestFMesh_runCycle(t *testing.T) { WithOutputs("o1"). WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { if !inputs.ByNames("i1", "i2").AllHaveSignal() { - return component.ErrWaitingForInputKeepInputs + return component.NewErrWaitForInputs(true) } return nil }), diff --git a/integration_tests/computation/basic_math_test.go b/integration_tests/computation/math_test.go similarity index 98% rename from integration_tests/computation/basic_math_test.go rename to integration_tests/computation/math_test.go index b0ef6ed..ebae1cb 100644 --- a/integration_tests/computation/basic_math_test.go +++ b/integration_tests/computation/math_test.go @@ -10,7 +10,7 @@ import ( "testing" ) -func Test_BasicMath(t *testing.T) { +func Test_Math(t *testing.T) { tests := []struct { name string setupFM func() *fmesh.FMesh diff --git a/integration_tests/piping/multiplexing_test.go b/integration_tests/piping/multiplexing_test.go new file mode 100644 index 0000000..7530865 --- /dev/null +++ b/integration_tests/piping/multiplexing_test.go @@ -0,0 +1,154 @@ +package integration_tests + +import ( + "github.com/hovsep/fmesh" + "github.com/hovsep/fmesh/component" + "github.com/hovsep/fmesh/cycle" + "github.com/hovsep/fmesh/port" + "github.com/hovsep/fmesh/signal" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" + "time" +) + +func Test_Multiplexing(t *testing.T) { + tests := []struct { + name string + setupFM func() *fmesh.FMesh + setInputs func(fm *fmesh.FMesh) + assertions func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error) + }{ + { + name: "fan-out (3 pipes from 1 source port)", + setupFM: func() *fmesh.FMesh { + fm := fmesh.New("fan-out").WithComponents( + component.NewComponent("producer"). + WithInputs("start"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + outputs.ByName("o1").PutSignal(signal.New(time.Now())) + return nil + }), + + component.NewComponent("consumer1"). + WithInputs("i1"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + //Bypass received signal to output + port.ForwardSignal(inputs.ByName("i1"), outputs.ByName("o1")) + return nil + }), + + component.NewComponent("consumer2"). + WithInputs("i1"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + //Bypass received signal to output + port.ForwardSignal(inputs.ByName("i1"), outputs.ByName("o1")) + return nil + }), + + component.NewComponent("consumer3"). + WithInputs("i1"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + //Bypass received signal to output + port.ForwardSignal(inputs.ByName("i1"), outputs.ByName("o1")) + return nil + }), + ) + + fm.Components().ByName("producer").Outputs().ByName("o1").PipeTo( + fm.Components().ByName("consumer1").Inputs().ByName("i1"), + fm.Components().ByName("consumer2").Inputs().ByName("i1"), + fm.Components().ByName("consumer3").Inputs().ByName("i1")) + + return fm + }, + setInputs: func(fm *fmesh.FMesh) { + //Fire the mesh + fm.Components().ByName("producer").Inputs().ByName("start").PutSignal(signal.New(struct{}{})) + }, + assertions: func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error) { + //All consumers received a signal + c1, c2, c3 := fm.Components().ByName("consumer1"), fm.Components().ByName("consumer2"), fm.Components().ByName("consumer3") + assert.True(t, c1.Outputs().ByName("o1").HasSignal()) + assert.True(t, c2.Outputs().ByName("o1").HasSignal()) + assert.True(t, c3.Outputs().ByName("o1").HasSignal()) + + //All 3 signals are the same (literally the same address in memory) + sig1, sig2, sig3 := c1.Outputs().ByName("o1").Signal(), c2.Outputs().ByName("o1").Signal(), c3.Outputs().ByName("o1").Signal() + assert.Equal(t, sig1.Payload(), sig2.Payload()) + assert.Equal(t, sig2.Payload(), sig3.Payload()) + }, + }, + { + name: "multiplexing", + setupFM: func() *fmesh.FMesh { + producer1 := component.NewComponent("producer1"). + WithInputs("start"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + outputs.ByName("o1").PutSignal(signal.New(rand.Int())) + return nil + }) + + producer2 := component.NewComponent("producer2"). + WithInputs("start"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + outputs.ByName("o1").PutSignal(signal.New(rand.Int())) + return nil + }) + + producer3 := component.NewComponent("producer3"). + WithInputs("start"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + outputs.ByName("o1").PutSignal(signal.New(rand.Int())) + return nil + }) + consumer := component.NewComponent("consumer"). + WithInputs("i1"). + WithOutputs("o1"). + WithActivationFunc(func(inputs port.Collection, outputs port.Collection) error { + //Bypass + port.ForwardSignal(inputs.ByName("i1"), outputs.ByName("o1")) + return nil + }) + + producer1.Outputs().ByName("o1").PipeTo(consumer.Inputs().ByName("i1")) + producer2.Outputs().ByName("o1").PipeTo(consumer.Inputs().ByName("i1")) + producer3.Outputs().ByName("o1").PipeTo(consumer.Inputs().ByName("i1")) + + return fmesh.New("multiplexer").WithComponents(producer1, producer2, producer3, consumer) + }, + setInputs: func(fm *fmesh.FMesh) { + fm.Components().ByName("producer1").Inputs().ByName("start").PutSignal(signal.New(struct{}{})) + fm.Components().ByName("producer2").Inputs().ByName("start").PutSignal(signal.New(struct{}{})) + fm.Components().ByName("producer3").Inputs().ByName("start").PutSignal(signal.New(struct{}{})) + }, + assertions: func(t *testing.T, fm *fmesh.FMesh, cycles cycle.Collection, err error) { + //Consumer received a signal + assert.True(t, fm.Components().ByName("consumer").Outputs().ByName("o1").HasSignal()) + + //The signal is combined and consist of 3 payloads + resultSignal := fm.Components().ByName("consumer").Outputs().ByName("o1").Signal() + assert.Equal(t, resultSignal.Len(), 3) + + //And they are all different + assert.NotEqual(t, resultSignal.Payloads()[0], resultSignal.Payloads()[1]) + assert.NotEqual(t, resultSignal.Payloads()[1], resultSignal.Payloads()[2]) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fm := tt.setupFM() + tt.setInputs(fm) + cycles, err := fm.Run() + tt.assertions(t, fm, cycles, err) + }) + } +} diff --git a/port/collection.go b/port/collection.go index 5f44ca9..370af50 100644 --- a/port/collection.go +++ b/port/collection.go @@ -5,25 +5,31 @@ import ( ) // Collection is a port collection with useful methods -type Collection map[string]*Port +type Collection []*Port // NewPortsCollection creates empty collection func NewPortsCollection() Collection { - return make(Collection) + return make(Collection, 0) } // ByName returns a port by its name func (collection Collection) ByName(name string) *Port { - return collection[name] + for _, p := range collection { + if p.Name() == name { + return p + } + } + return nil } // ByNames returns multiple ports by their names func (collection Collection) ByNames(names ...string) Collection { - selectedPorts := make(Collection) + selectedPorts := NewPortsCollection() for _, name := range names { - if p, ok := collection[name]; ok { - selectedPorts[name] = p + p := collection.ByName(name) + if p != nil { + selectedPorts = selectedPorts.Add(p) } } @@ -66,12 +72,13 @@ func (collection Collection) ClearSignal() { } } +// Add adds ports to collection func (collection Collection) Add(ports ...*Port) Collection { for _, port := range ports { if port == nil { continue } - collection[port.Name()] = port + collection = append(collection, port) } return collection diff --git a/port/collection_test.go b/port/collection_test.go index ff6e219..7e92f68 100644 --- a/port/collection_test.go +++ b/port/collection_test.go @@ -151,7 +151,7 @@ func TestCollection_ByNames(t *testing.T) { names: []string{"p1"}, }, want: Collection{ - "p1": &Port{ + &Port{ name: "p1", pipes: Collection{}, }, @@ -164,8 +164,8 @@ func TestCollection_ByNames(t *testing.T) { names: []string{"p1", "p2"}, }, want: Collection{ - "p1": &Port{name: "p1", pipes: Collection{}}, - "p2": &Port{name: "p2", pipes: Collection{}}, + &Port{name: "p1", pipes: Collection{}}, + &Port{name: "p2", pipes: Collection{}}, }, }, { @@ -183,8 +183,8 @@ func TestCollection_ByNames(t *testing.T) { names: []string{"p1", "p2", "p3"}, }, want: Collection{ - "p1": &Port{name: "p1", pipes: Collection{}}, - "p2": &Port{name: "p2", pipes: Collection{}}, + &Port{name: "p1", pipes: Collection{}}, + &Port{name: "p2", pipes: Collection{}}, }, }, } @@ -252,7 +252,7 @@ func TestCollection_Add(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.collection.Add(tt.args.ports...) + tt.collection = tt.collection.Add(tt.args.ports...) if tt.assertions != nil { tt.assertions(t, tt.collection) } diff --git a/port/port.go b/port/port.go index 599f70b..5b647f4 100644 --- a/port/port.go +++ b/port/port.go @@ -50,7 +50,7 @@ func (p *Port) PipeTo(toPorts ...*Port) { if toPort == nil { continue } - p.pipes.Add(toPort) + p.pipes = p.pipes.Add(toPort) } } @@ -61,7 +61,7 @@ func (p *Port) Flush() { } for _, outboundPort := range p.pipes { - //Multiplexing + //Fan-Out ForwardSignal(p, outboundPort) } p.ClearSignal() diff --git a/port/port_test.go b/port/port_test.go index 55cba02..f94c314 100644 --- a/port/port_test.go +++ b/port/port_test.go @@ -3,7 +3,6 @@ package port import ( "github.com/hovsep/fmesh/signal" "github.com/stretchr/testify/assert" - "reflect" "testing" ) @@ -74,9 +73,8 @@ func TestPort_Signal(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.port.Signal(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("Signal() = %v, want %v", got, tt.want) - } + got := tt.port.Signal() + assert.Equal(t, tt.want, got) }) } } @@ -104,9 +102,7 @@ func TestPort_ClearSignal(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.before.ClearSignal() - if !reflect.DeepEqual(tt.before, tt.after) { - t.Errorf("ClearSignal() = %v, want %v", tt.before, tt.after) - } + assert.Equal(t, tt.after, tt.before) }) } } @@ -149,9 +145,7 @@ func TestPort_PipeTo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.before.PipeTo(tt.args.toPorts...) - if !reflect.DeepEqual(tt.before, tt.after) { - t.Errorf("PipeTo() = %v, want %v", tt.before, tt.after) - } + assert.Equal(t, tt.after, tt.before) }) } } @@ -239,9 +233,7 @@ func TestPort_PutSignal(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.before.PutSignal(tt.args.sig) - if !reflect.DeepEqual(tt.before, tt.after) { - t.Errorf("ClearSignal() = %v, want %v", tt.before, tt.after) - } + assert.Equal(t, tt.after, tt.before) }) } } @@ -260,7 +252,7 @@ func TestPort_Name(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, tt.port.Name(), "Name()") + assert.Equal(t, tt.want, tt.port.Name()) }) } } @@ -291,7 +283,7 @@ func TestNewPort(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, NewPort(tt.args.name), "NewPort(%v)", tt.args.name) + assert.Equal(t, tt.want, NewPort(tt.args.name)) }) } }