diff --git a/.travis.yml b/.travis.yml index 00cac622c..d2a8ce113 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,7 +15,7 @@ env: - TESTDIR=adaptor/postgres/... - TESTDIR=adaptor/rabbitmq/... - TESTDIR=adaptor/rethinkdb/... - - TESTDIR="adaptor, adaptor/all, adaptor/file/..., adaptor/transformer/..., client/..., events/..., log/..., message/..., pipe/..., state/..., pipeline/..." + - TESTDIR="adaptor, adaptor/all, adaptor/file/..., client/..., events/..., function/..., log/..., message/..., pipe/..., state/..., pipeline/..." - TESTDIR=integration_tests/mongo_to_mongo - TESTDIR=integration_tests/mongo_to_es - TESTDIR=integration_tests/mongo_to_rethink diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f1d62d85..93db26dbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Transporter no longer requires a YAML file. All configuration is in the JS file - NEW RabbitMQ adaptor [#298](https://github.com/compose/transporter/pull/298) - MongoDB adaptor supports per collection query filter when needing to copy only a subset of data [#301](https://github.com/compose/transporter/pull/301) - [goja](https://github.com/dop251/goja) added as an option for the JavaScript VM in transformers [#294](https://github.com/compose/transporter/pull/294) +- NEW [native functions](https://github.com/compose/transporter#native-functions) ### Bugfixes diff --git a/README.md b/README.md index b62575591..981a9d34a 100644 --- a/README.md +++ b/README.md @@ -28,16 +28,18 @@ Each adaptor has its own README page with details on configuration and capabilit * [postgresql](./adaptor/postgres) * [rabbitmq](./adaptor/rabbitmq) * [rethinkdb](./adaptor/rethinkdb) -* [transformer](./adaptor/transformer) Native Functions ---------------- Each native function can be used as part of a `Transform` step in the pipeline. +* [goja](./adaptor/function/gojajs) * [omit](./adaptor/function/omit) +* [otto](./adaptor/function/ottojs) * [pick](./adaptor/function/pick) * [pretty](./adaptor/function/pretty) +* [rename](./adaptor/function/rename) * [skip](./adaptor/function/skip) Commands @@ -94,7 +96,6 @@ mongodb - a mongodb adaptor that functions as both a source and a sink postgres - a postgres adaptor that functions as both a source and a sink rabbitmq - an adaptor that handles publish/subscribe messaging with RabbitMQ rethinkdb - a rethinkdb adaptor that functions as both a source and a sink -transformer - an adaptor that transforms documents using a javascript function ``` Giving the name of an adaptor produces more detail, such as the sample configuration. diff --git a/adaptor/adaptor.go b/adaptor/adaptor.go index b94d72375..e030db0d6 100644 --- a/adaptor/adaptor.go +++ b/adaptor/adaptor.go @@ -116,7 +116,6 @@ func CompileNamespace(ns string) (string, *regexp.Regexp, error) { // BaseConfig is a standard typed config struct to use for as general purpose config for most databases. type BaseConfig struct { - URI string `json:"uri"` - Namespace string `json:"namespace"` - Timeout string `json:"timeout"` + URI string `json:"uri"` + Timeout string `json:"timeout"` } diff --git a/adaptor/all/all.go b/adaptor/all/all.go index dbc0ad166..299f10599 100644 --- a/adaptor/all/all.go +++ b/adaptor/all/all.go @@ -4,10 +4,8 @@ import ( // Initialize all adapters by importing this package _ "github.com/compose/transporter/adaptor/elasticsearch" _ "github.com/compose/transporter/adaptor/file" - _ "github.com/compose/transporter/adaptor/function" _ "github.com/compose/transporter/adaptor/mongodb" _ "github.com/compose/transporter/adaptor/postgres" _ "github.com/compose/transporter/adaptor/rabbitmq" _ "github.com/compose/transporter/adaptor/rethinkdb" - _ "github.com/compose/transporter/adaptor/transformer" ) diff --git a/adaptor/elasticsearch/elasticsearch.go b/adaptor/elasticsearch/elasticsearch.go index da590d563..9a44d90d6 100644 --- a/adaptor/elasticsearch/elasticsearch.go +++ b/adaptor/elasticsearch/elasticsearch.go @@ -72,11 +72,10 @@ func (e *Elasticsearch) Reader() (client.Reader, error) { // Writer determines the which underlying writer to used based on the cluster's version. func (e *Elasticsearch) Writer(done chan struct{}, wg *sync.WaitGroup) (client.Writer, error) { - index, _, _ := adaptor.CompileNamespace(e.Namespace) - return setupWriter(index, e) + return setupWriter(e) } -func setupWriter(index string, conf *Elasticsearch) (client.Writer, error) { +func setupWriter(conf *Elasticsearch) (client.Writer, error) { uri, err := url.Parse(conf.URI) if err != nil { return nil, client.InvalidURIError{URI: conf.URI, Err: err.Error()} @@ -95,7 +94,7 @@ func setupWriter(index string, conf *Elasticsearch) (client.Writer, error) { timeout, err := time.ParseDuration(conf.Timeout) if err != nil { - log.Infof("failed to parse duration, %s, falling back to default timeout of 30s", conf.Timeout) + log.Debugf("failed to parse duration, %s, falling back to default timeout of 30s", conf.Timeout) timeout = 30 * time.Second } @@ -114,7 +113,7 @@ func setupWriter(index string, conf *Elasticsearch) (client.Writer, error) { URLs: urls, UserInfo: uri.User, HTTPClient: httpClient, - Index: index, + Index: uri.Path[1:], } versionedClient, _ := vc.Creator(opts) return versionedClient, nil diff --git a/adaptor/elasticsearch/elasticsearch_test.go b/adaptor/elasticsearch/elasticsearch_test.go index 10df223ee..3c0add306 100644 --- a/adaptor/elasticsearch/elasticsearch_test.go +++ b/adaptor/elasticsearch/elasticsearch_test.go @@ -38,7 +38,7 @@ var ( authURI = func() string { uri, _ := url.Parse(authedServer.URL) uri.User = url.UserPassword(testUser, testPwd) - return uri.String() + return fmt.Sprintf("%s/test", uri.String()) } ) var authedServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -73,48 +73,48 @@ var clientTests = []struct { }{ { "base config", - adaptor.Config{"uri": goodVersionServer.URL, "namespace": "test.test"}, + adaptor.Config{"uri": fmt.Sprintf("%s/test", goodVersionServer.URL)}, nil, }, { "timeout config", - adaptor.Config{"uri": goodVersionServer.URL, "namespace": "test.test", "timeout": "60s"}, + adaptor.Config{"uri": fmt.Sprintf("%s/test", goodVersionServer.URL), "timeout": "60s"}, nil, }, { "authed URI", - adaptor.Config{"uri": authURI(), "namespace": "test.test"}, + adaptor.Config{"uri": authURI()}, nil, }, { "bad URI", - adaptor.Config{"uri": "%gh&%ij", "namespace": "test.test"}, + adaptor.Config{"uri": "%gh&%ij"}, client.InvalidURIError{URI: "%gh&%ij", Err: `parse %gh&%ij: invalid URL escape "%gh"`}, }, { "no connection", - adaptor.Config{"uri": "http://localhost:7200", "namespace": "test.test"}, + adaptor.Config{"uri": "http://localhost:7200/test"}, client.ConnectError{Reason: "http://localhost:7200"}, }, { "empty body", - adaptor.Config{"uri": emptyBodyServer.URL, "namespace": "test.test"}, + adaptor.Config{"uri": fmt.Sprintf("%s/test", emptyBodyServer.URL)}, client.VersionError{URI: emptyBodyServer.URL, V: "", Err: "missing version: {}"}, }, { "malformed JSON", - adaptor.Config{"uri": badJSONServer.URL, "namespace": "test.test"}, + adaptor.Config{"uri": fmt.Sprintf("%s/test", badJSONServer.URL)}, client.VersionError{URI: badJSONServer.URL, V: "", Err: "malformed JSON: Hello, client"}, }, { "bad version", - adaptor.Config{"uri": badVersionServer.URL, "namespace": "test.test"}, - client.VersionError{URI: badVersionServer.URL, V: "not a version", Err: "Malformed version: not a version"}, + adaptor.Config{"uri": fmt.Sprintf("%s/test", badVersionServer.URL)}, + client.VersionError{URI: fmt.Sprintf("%s/test", badVersionServer.URL), V: "not a version", Err: "Malformed version: not a version"}, }, { "unsupported version", - adaptor.Config{"uri": unsupportedVersionServer.URL, "namespace": "test.test"}, - client.VersionError{URI: unsupportedVersionServer.URL, V: "0.9.2", Err: "unsupported client"}, + adaptor.Config{"uri": fmt.Sprintf("%s/test", unsupportedVersionServer.URL)}, + client.VersionError{URI: fmt.Sprintf("%s/test", unsupportedVersionServer.URL), V: "0.9.2", Err: "unsupported client"}, }, } diff --git a/adaptor/file/file_test.go b/adaptor/file/file_test.go index 63b30a94b..cc6e4114b 100644 --- a/adaptor/file/file_test.go +++ b/adaptor/file/file_test.go @@ -21,7 +21,7 @@ func TestSampleConfig(t *testing.T) { } var initTests = []map[string]interface{}{ - {"uri": DefaultURI, "namespace": "test.test"}, + {"uri": DefaultURI}, } func TestInit(t *testing.T) { diff --git a/adaptor/function/all.go b/adaptor/function/all.go deleted file mode 100644 index 6fe69fade..000000000 --- a/adaptor/function/all.go +++ /dev/null @@ -1,8 +0,0 @@ -package function - -import ( - _ "github.com/compose/transporter/adaptor/function/omit" - _ "github.com/compose/transporter/adaptor/function/pick" - _ "github.com/compose/transporter/adaptor/function/pretty" - _ "github.com/compose/transporter/adaptor/function/skip" -) diff --git a/adaptor/function/omit/omitter.go b/adaptor/function/omit/omitter.go deleted file mode 100644 index 4cd2559b5..000000000 --- a/adaptor/function/omit/omitter.go +++ /dev/null @@ -1,43 +0,0 @@ -package omit - -import ( - "sync" - - "github.com/compose/transporter/adaptor" - "github.com/compose/transporter/client" - "github.com/compose/transporter/message" -) - -func init() { - adaptor.Add( - "omit", - func() adaptor.Adaptor { - return &Omitter{} - }, - ) -} - -type Omitter struct { - Fields []string `json:"fields"` -} - -func (o *Omitter) Client() (client.Client, error) { - return &client.Mock{}, nil -} - -func (o *Omitter) Reader() (client.Reader, error) { - return nil, adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} -} - -func (o *Omitter) Writer(chan struct{}, *sync.WaitGroup) (client.Writer, error) { - return o, nil -} - -func (o *Omitter) Write(msg message.Msg) func(client.Session) (message.Msg, error) { - return func(s client.Session) (message.Msg, error) { - for _, k := range o.Fields { - msg.Data().Delete(k) - } - return msg, nil - } -} diff --git a/adaptor/function/pick/picker.go b/adaptor/function/pick/picker.go deleted file mode 100644 index d5c8068ea..000000000 --- a/adaptor/function/pick/picker.go +++ /dev/null @@ -1,46 +0,0 @@ -package pick - -import ( - "sync" - - "github.com/compose/transporter/adaptor" - "github.com/compose/transporter/client" - "github.com/compose/transporter/message" -) - -func init() { - adaptor.Add( - "pick", - func() adaptor.Adaptor { - return &Picker{} - }, - ) -} - -type Picker struct { - Fields []string `json:"fields"` -} - -func (p *Picker) Client() (client.Client, error) { - return &client.Mock{}, nil -} - -func (p *Picker) Reader() (client.Reader, error) { - return nil, adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} -} - -func (p *Picker) Writer(chan struct{}, *sync.WaitGroup) (client.Writer, error) { - return p, nil -} - -func (p *Picker) Write(msg message.Msg) func(client.Session) (message.Msg, error) { - return func(s client.Session) (message.Msg, error) { - pluckedMsg := map[string]interface{}{} - for _, k := range p.Fields { - if v, ok := msg.Data().AsMap()[k]; ok { - pluckedMsg[k] = v - } - } - return message.From(msg.OP(), msg.Namespace(), pluckedMsg), nil - } -} diff --git a/adaptor/function/pretty/prettify.go b/adaptor/function/pretty/prettify.go deleted file mode 100644 index db2d09529..000000000 --- a/adaptor/function/pretty/prettify.go +++ /dev/null @@ -1,58 +0,0 @@ -package pretty - -import ( - "encoding/json" - "strings" - "sync" - - "github.com/compose/mejson" - "github.com/compose/transporter/adaptor" - "github.com/compose/transporter/client" - "github.com/compose/transporter/log" - "github.com/compose/transporter/message" -) - -const ( - DefaultIndent = 2 -) - -var ( - DefaultPrettifier = &Prettify{Spaces: DefaultIndent} -) - -func init() { - adaptor.Add( - "pretty", - func() adaptor.Adaptor { - return DefaultPrettifier - }, - ) -} - -type Prettify struct { - Spaces int `json:"spaces"` -} - -func (p *Prettify) Client() (client.Client, error) { - return &client.Mock{}, nil -} - -func (p *Prettify) Reader() (client.Reader, error) { - return nil, adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} -} - -func (p *Prettify) Writer(chan struct{}, *sync.WaitGroup) (client.Writer, error) { - return p, nil -} - -func (p *Prettify) Write(msg message.Msg) func(client.Session) (message.Msg, error) { - return func(s client.Session) (message.Msg, error) { - d, _ := mejson.Unmarshal(msg.Data()) - b, _ := json.Marshal(d) - if p.Spaces > 0 { - b, _ = json.MarshalIndent(d, "", strings.Repeat(" ", p.Spaces)) - } - log.Infof("\n%s", string(b)) - return msg, nil - } -} diff --git a/adaptor/function/skip/skipper.go b/adaptor/function/skip/skipper.go deleted file mode 100644 index cd6357ee5..000000000 --- a/adaptor/function/skip/skipper.go +++ /dev/null @@ -1,127 +0,0 @@ -package skip - -import ( - "fmt" - "math" - "reflect" - "regexp" - "strconv" - "sync" - - "github.com/compose/transporter/adaptor" - "github.com/compose/transporter/client" - "github.com/compose/transporter/message" -) - -type UnknownOperatorError struct { - Op string -} - -func (e UnknownOperatorError) Error() string { - return fmt.Sprintf("unkown operator, %s", e.Op) -} - -type WrongTypeError struct { - Wanted string - Got string -} - -func (e WrongTypeError) Error() string { - return fmt.Sprintf("value is of incompatible type, wanted %s, got %s", e.Wanted, e.Got) -} - -func init() { - adaptor.Add( - "skip", - func() adaptor.Adaptor { - return &Skip{} - }, - ) -} - -type Skip struct { - Field string `json:"field"` - Operator string `json:"operator"` - Match interface{} `json:"match"` -} - -func (s *Skip) Client() (client.Client, error) { - return &client.Mock{}, nil -} - -func (s *Skip) Reader() (client.Reader, error) { - return nil, adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} -} - -func (s *Skip) Writer(chan struct{}, *sync.WaitGroup) (client.Writer, error) { - return s, nil -} - -func (s *Skip) Write(msg message.Msg) func(client.Session) (message.Msg, error) { - return func(client.Session) (message.Msg, error) { - val := msg.Data().Get(s.Field) - switch s.Operator { - case "==", "eq", "$eq": - if reflect.DeepEqual(val, s.Match) { - return msg, nil - } - case "=~": - if ok, err := regexp.MatchString(s.Match.(string), val.(string)); err != nil || ok { - return msg, err - } - case ">", "gt", "$gt": - v, m, err := convertForComparison(val, s.Match) - if err == nil && v > m { - return msg, err - } - return nil, err - case ">=", "gte", "$gte": - v, m, err := convertForComparison(val, s.Match) - if err == nil && v >= m { - return msg, err - } - return nil, err - case "<", "lt", "$lt": - v, m, err := convertForComparison(val, s.Match) - if err == nil && v < m { - return msg, err - } - return nil, err - case "<=", "lte", "$lte": - v, m, err := convertForComparison(val, s.Match) - if err == nil && v <= m { - return msg, err - } - return nil, err - default: - return nil, UnknownOperatorError{s.Operator} - } - return nil, nil - } -} - -func convertForComparison(in1, in2 interface{}) (float64, float64, error) { - float1, err := convertToFloat(in1) - if err != nil { - return math.NaN(), math.NaN(), err - } - float2, err := convertToFloat(in2) - if err != nil { - return math.NaN(), math.NaN(), err - } - return float1, float2, nil -} - -func convertToFloat(in interface{}) (float64, error) { - switch i := in.(type) { - case float64: - return i, nil - case int: - return float64(i), nil - case string: - return strconv.ParseFloat(i, 0) - default: - return math.NaN(), WrongTypeError{"float64 or int", fmt.Sprintf("%T", i)} - } - -} diff --git a/adaptor/mongodb/bulk.go b/adaptor/mongodb/bulk.go index 360f7d06e..99e1304dd 100644 --- a/adaptor/mongodb/bulk.go +++ b/adaptor/mongodb/bulk.go @@ -25,7 +25,6 @@ var ( // Bulk implements client.Writer for use with MongoDB and takes advantage of the Bulk API for // performance improvements. type Bulk struct { - db string bulkMap map[string]*bulkOperation *sync.RWMutex } @@ -40,9 +39,8 @@ type bulkOperation struct { bsonOpSize int } -func newBulker(db string, done chan struct{}, wg *sync.WaitGroup) *Bulk { +func newBulker(done chan struct{}, wg *sync.WaitGroup) *Bulk { b := &Bulk{ - db: db, bulkMap: make(map[string]*bulkOperation), RWMutex: &sync.RWMutex{}, } @@ -60,7 +58,7 @@ func (b *Bulk) Write(msg message.Msg) func(client.Session) (message.Msg, error) s := s.(*Session).mgoSession.Copy() bOp = &bulkOperation{ s: s, - bulk: s.DB(b.db).C(coll).Bulk(), + bulk: s.DB("").C(coll).Bulk(), } b.bulkMap[coll] = bOp } diff --git a/adaptor/mongodb/bulk_test.go b/adaptor/mongodb/bulk_test.go index e5ab8be89..e655bcc8c 100644 --- a/adaptor/mongodb/bulk_test.go +++ b/adaptor/mongodb/bulk_test.go @@ -2,6 +2,7 @@ package mongodb import ( "crypto/rand" + "fmt" "sync" "testing" "time" @@ -41,15 +42,21 @@ func checkBulkCount(c string, countQuery bson.M, expectedCount int, t *testing.T func TestBulkWrite(t *testing.T) { var wg sync.WaitGroup done := make(chan struct{}) - b := newBulker(bulkTestData.DB, done, &wg) + b := newBulker(done, &wg) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", bulkTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() for _, bt := range bulkTests { for i := 0; i < testBulkMsgCount; i++ { data := map[string]interface{}{"_id": i, "i": i} for k, v := range bt.extraData { data[k] = v } - b.Write(message.From(bt.op, bulkTestData.C, data))(defaultSession) + b.Write(message.From(bt.op, bulkTestData.C, data))(s) } time.Sleep(3 * time.Second) checkBulkCount(bulkTestData.C, bt.countQuery, bt.expectedCount, t) @@ -60,19 +67,25 @@ func TestBulkWrite(t *testing.T) { func TestBulkWriteMixedOps(t *testing.T) { var wg sync.WaitGroup done := make(chan struct{}) - b := newBulker(bulkTestData.DB, done, &wg) + b := newBulker(done, &wg) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", bulkTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() mixedModeC := "mixed_mode" - b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 0}))(defaultSession) - b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 1}))(defaultSession) - b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 2}))(defaultSession) - b.Write(message.From(ops.Update, mixedModeC, map[string]interface{}{"_id": 2, "hello": "world"}))(defaultSession) - b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 3}))(defaultSession) - b.Write(message.From(ops.Update, mixedModeC, map[string]interface{}{"_id": 1, "moar": "tests"}))(defaultSession) - b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 4, "say": "goodbye"}))(defaultSession) - b.Write(message.From(ops.Delete, mixedModeC, map[string]interface{}{"_id": 1, "moar": "tests"}))(defaultSession) - b.Write(message.From(ops.Delete, mixedModeC, map[string]interface{}{"_id": 3}))(defaultSession) - b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 5}))(defaultSession) + b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 0}))(s) + b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 1}))(s) + b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 2}))(s) + b.Write(message.From(ops.Update, mixedModeC, map[string]interface{}{"_id": 2, "hello": "world"}))(s) + b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 3}))(s) + b.Write(message.From(ops.Update, mixedModeC, map[string]interface{}{"_id": 1, "moar": "tests"}))(s) + b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 4, "say": "goodbye"}))(s) + b.Write(message.From(ops.Delete, mixedModeC, map[string]interface{}{"_id": 1, "moar": "tests"}))(s) + b.Write(message.From(ops.Delete, mixedModeC, map[string]interface{}{"_id": 3}))(s) + b.Write(message.From(ops.Insert, mixedModeC, map[string]interface{}{"_id": 5}))(s) // so... after the ops get flushed we should have the following: // 4 docs left @@ -86,11 +99,17 @@ func TestBulkWriteMixedOps(t *testing.T) { func TestBulkOpCount(t *testing.T) { var wg sync.WaitGroup done := make(chan struct{}) - b := newBulker(bulkTestData.DB, done, &wg) + b := newBulker(done, &wg) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", bulkTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() for i := 0; i < maxObjSize; i++ { msg := message.From(ops.Insert, "bar", map[string]interface{}{"i": i}) - b.Write(msg)(defaultSession) + b.Write(msg)(s) } close(done) wg.Wait() @@ -100,11 +119,17 @@ func TestBulkOpCount(t *testing.T) { func TestFlushOnDone(t *testing.T) { var wg sync.WaitGroup done := make(chan struct{}) - b := newBulker(bulkTestData.DB, done, &wg) + b := newBulker(done, &wg) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", bulkTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() for i := 0; i < testBulkMsgCount; i++ { msg := message.From(ops.Insert, "baz", map[string]interface{}{"i": i}) - b.Write(msg)(defaultSession) + b.Write(msg)(s) } close(done) wg.Wait() @@ -114,14 +139,20 @@ func TestFlushOnDone(t *testing.T) { func TestBulkMulitpleCollections(t *testing.T) { var wg sync.WaitGroup done := make(chan struct{}) - b := newBulker(bulkTestData.DB, done, &wg) + b := newBulker(done, &wg) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", bulkTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() for i := 0; i < (maxObjSize + 1); i++ { - b.Write(message.From(ops.Insert, "multi_c", map[string]interface{}{"i": i}))(defaultSession) + b.Write(message.From(ops.Insert, "multi_c", map[string]interface{}{"i": i}))(s) } for i := 0; i < testBulkMsgCount; i++ { - b.Write(message.From(ops.Insert, "multi_a", map[string]interface{}{"i": i}))(defaultSession) - b.Write(message.From(ops.Insert, "multi_b", map[string]interface{}{"i": i}))(defaultSession) + b.Write(message.From(ops.Insert, "multi_a", map[string]interface{}{"i": i}))(s) + b.Write(message.From(ops.Insert, "multi_b", map[string]interface{}{"i": i}))(s) } checkBulkCount("multi_a", bson.M{}, 0, t) checkBulkCount("multi_b", bson.M{}, 0, t) @@ -134,10 +165,17 @@ func TestBulkMulitpleCollections(t *testing.T) { func TestBulkSize(t *testing.T) { b := &Bulk{ - db: bulkTestData.DB, bulkMap: make(map[string]*bulkOperation), RWMutex: &sync.RWMutex{}, } + + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", bulkTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + var bsonSize int for i := 0; i < (maxObjSize - 1); i++ { doc := map[string]interface{}{"i": randStr(2), "rand": randStr(16)} @@ -149,7 +187,7 @@ func TestBulkSize(t *testing.T) { bsonSize += (len(bs) + 4) msg := message.From(ops.Insert, "size", doc) - b.Write(msg)(defaultSession) + b.Write(msg)(s) } bOp := b.bulkMap["size"] if int(bOp.bsonOpSize) != bsonSize { diff --git a/adaptor/mongodb/mongodb.go b/adaptor/mongodb/mongodb.go index c573bad30..1f11af0bb 100644 --- a/adaptor/mongodb/mongodb.go +++ b/adaptor/mongodb/mongodb.go @@ -65,24 +65,20 @@ func (m *MongoDB) Client() (client.Client, error) { } func (m *MongoDB) Reader() (client.Reader, error) { - // TODO: pull db from the URI - db, _, err := adaptor.CompileNamespace(m.Namespace) var f map[string]CollectionFilter if m.CollectionFilters != "" { if jerr := json.Unmarshal([]byte(m.CollectionFilters), &f); jerr != nil { return nil, ErrCollectionFilter } } - return newReader(db, m.Tail, f), err + return newReader(m.Tail, f), nil } func (m *MongoDB) Writer(done chan struct{}, wg *sync.WaitGroup) (client.Writer, error) { - // TODO: pull db from the URI - db, _, err := adaptor.CompileNamespace(m.Namespace) if m.Bulk { - return newBulker(db, done, wg), err + return newBulker(done, wg), nil } - return newWriter(db), err + return newWriter(), nil } // Description for mongodb adaptor diff --git a/adaptor/mongodb/mongodb_test.go b/adaptor/mongodb/mongodb_test.go index 5d6288e04..d511c9df3 100644 --- a/adaptor/mongodb/mongodb_test.go +++ b/adaptor/mongodb/mongodb_test.go @@ -32,50 +32,38 @@ var initTests = []struct { }{ { "base", - map[string]interface{}{"uri": DefaultURI, "namespace": "test.test"}, - &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI, Namespace: "test.test"}}, + map[string]interface{}{"uri": DefaultURI}, + &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI}}, nil, nil, nil, }, { "with timeout", - map[string]interface{}{"uri": DefaultURI, "namespace": "test.test", "timeout": "60s"}, - &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI, Namespace: "test.test", Timeout: "60s"}}, + map[string]interface{}{"uri": DefaultURI, "timeout": "60s"}, + &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI, Timeout: "60s"}}, nil, nil, nil, }, { "with tail", - map[string]interface{}{"uri": DefaultURI, "namespace": "test.test", "tail": true}, - &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI, Namespace: "test.test"}, Tail: true}, + map[string]interface{}{"uri": DefaultURI, "tail": true}, + &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI}, Tail: true}, nil, nil, nil, }, { "with bulk", - map[string]interface{}{"uri": DefaultURI, "namespace": "test.test", "bulk": true}, - &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI, Namespace: "test.test"}, Bulk: true}, + map[string]interface{}{"uri": DefaultURI, "bulk": true}, + &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI}, Bulk: true}, nil, nil, nil, }, { "with collection filters", - map[string]interface{}{"uri": DefaultURI, "namespace": "test.test", "collection_filters": `{"foo":{"i":{"$gt":10}}}`}, - &MongoDB{ - BaseConfig: adaptor.BaseConfig{ - URI: DefaultURI, - Namespace: "test.test", - }, - CollectionFilters: `{"foo":{"i":{"$gt":10}}}`, - }, + map[string]interface{}{"uri": DefaultURI, "collection_filters": `{"foo":{"i":{"$gt":10}}}`}, + &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI}, CollectionFilters: `{"foo":{"i":{"$gt":10}}}`}, nil, nil, nil, }, { "bad collection filter", - map[string]interface{}{"uri": DefaultURI, "namespace": "test.test", "collection_filters": `{"foo":{"i":{"$gt":10}}`}, - &MongoDB{ - BaseConfig: adaptor.BaseConfig{ - URI: DefaultURI, - Namespace: "test.test", - }, - CollectionFilters: `{"foo":{"i":{"$gt":10}}`, - }, + map[string]interface{}{"uri": DefaultURI, "collection_filters": `{"foo":{"i":{"$gt":10}}`}, + &MongoDB{BaseConfig: adaptor.BaseConfig{URI: DefaultURI}, CollectionFilters: `{"foo":{"i":{"$gt":10}}`}, nil, ErrCollectionFilter, nil, }, } diff --git a/adaptor/mongodb/reader.go b/adaptor/mongodb/reader.go index eaa09c62c..392a0d46f 100644 --- a/adaptor/mongodb/reader.go +++ b/adaptor/mongodb/reader.go @@ -28,14 +28,13 @@ type CollectionFilter map[string]interface{} // Reader implements the behavior defined by client.Reader for interfacing with MongoDB. type Reader struct { - db string tail bool collectionFilters map[string]CollectionFilter oplogTimeout time.Duration } -func newReader(db string, tail bool, filters map[string]CollectionFilter) client.Reader { - return &Reader{db, tail, filters, 5 * time.Second} +func newReader(tail bool, filters map[string]CollectionFilter) client.Reader { + return &Reader{tail, filters, 5 * time.Second} } type resultDoc struct { @@ -57,20 +56,20 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { session.Close() close(out) }() - log.With("db", r.db).Infoln("starting Read func") + log.With("db", session.DB("").Name).Infoln("starting Read func") collections, err := r.listCollections(session.Copy(), filterFn) if err != nil { - log.With("db", r.db).Errorf("unable to list collections, %s", err) + log.With("db", session.DB("").Name).Errorf("unable to list collections, %s", err) return } var wg sync.WaitGroup for _, c := range collections { oplogTime := timeAsMongoTimestamp(time.Now()) if err := r.iterateCollection(session.Copy(), c, out, done); err != nil { - log.With("db", r.db).Errorln(err) + log.With("db", session.DB("").Name).Errorln(err) return } - log.With("db", r.db).With("collection", c).Infoln("iterating complete") + log.With("db", session.DB("").Name).With("collection", c).Infoln("iterating complete") if r.tail { wg.Add(1) log.With("collection", c).Infof("oplog start timestamp: %d", oplogTime) @@ -78,13 +77,13 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { defer wg.Done() errc := r.tailCollection(c, session.Copy(), o, out, done) for err := range errc { - log.With("db", r.db).With("collection", c).Errorln(err) + log.With("db", session.DB("").Name).With("collection", c).Errorln(err) return } }(&wg, c, oplogTime) } } - log.With("db", r.db).Infoln("Read completed") + log.With("db", session.DB("").Name).Infoln("Read completed") // this will block if we're tailing wg.Wait() return @@ -97,20 +96,21 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { func (r *Reader) listCollections(mgoSession *mgo.Session, filterFn func(name string) bool) ([]string, error) { defer mgoSession.Close() var colls []string - collections, err := mgoSession.DB(r.db).CollectionNames() + db := mgoSession.DB("") + collections, err := db.CollectionNames() if err != nil { return colls, err } - log.With("db", r.db).With("num_collections", len(collections)).Infoln("collection count") + log.With("db", db.Name).With("num_collections", len(collections)).Infoln("collection count") for _, c := range collections { if filterFn(c) && !strings.HasPrefix(c, "system.") { - log.With("db", r.db).With("collection", c).Infoln("adding for iteration...") + log.With("db", db.Name).With("collection", c).Infoln("adding for iteration...") colls = append(colls, c) } else { - log.With("db", r.db).With("collection", c).Infoln("skipping iteration...") + log.With("db", db.Name).With("collection", c).Infoln("skipping iteration...") } } - log.With("db", r.db).Infoln("done iterating collections") + log.With("db", db.Name).Infoln("done iterating collections") return colls, nil } @@ -136,6 +136,7 @@ func (r *Reader) iterate(s *mgo.Session, c string) <-chan message.Msg { s.Close() close(msgChan) }() + db := s.DB("").Name canReissueQuery := r.requeryable(c, s) var lastID interface{} for { @@ -151,10 +152,10 @@ func (r *Reader) iterate(s *mgo.Session, c string) <-chan message.Msg { result = bson.M{} } if err := iter.Err(); err != nil { - log.With("database", r.db).With("collection", c).Errorf("error reading, %s", err) + log.With("database", db).With("collection", c).Errorf("error reading, %s", err) session.Close() if canReissueQuery { - log.With("database", r.db).With("collection", c).Errorln("attempting to reissue query") + log.With("database", db).With("collection", c).Errorln("attempting to reissue query") time.Sleep(5 * time.Second) continue } @@ -176,21 +177,22 @@ func (r *Reader) catQuery(c string, lastID interface{}, mgoSession *mgo.Session) if lastID != nil { query["_id"] = bson.M{"$gt": lastID} } - return mgoSession.DB(r.db).C(c).Find(query).Sort("_id") + return mgoSession.DB("").C(c).Find(query).Sort("_id") } func (r *Reader) requeryable(c string, mgoSession *mgo.Session) bool { - indexes, err := mgoSession.DB(r.db).C(c).Indexes() + db := mgoSession.DB("") + indexes, err := db.C(c).Indexes() if err != nil { - log.With("database", r.db).With("collection", c).Errorf("unable to list indexes, %s", err) + log.With("database", db.Name).With("collection", c).Errorf("unable to list indexes, %s", err) return false } for _, index := range indexes { if index.Key[0] == "_id" { var result bson.M - err := mgoSession.DB(r.db).C(c).Find(nil).Select(bson.M{"_id": 1}).One(&result) + err := db.C(c).Find(nil).Select(bson.M{"_id": 1}).One(&result) if err != nil { - log.With("database", r.db).With("collection", c).Errorf("unable to sample document, %s", err) + log.With("database", db.Name).With("collection", c).Errorf("unable to sample document, %s", err) break } if id, ok := result["_id"]; ok && sortable(id) { @@ -199,7 +201,7 @@ func (r *Reader) requeryable(c string, mgoSession *mgo.Session) bool { break } } - log.With("database", r.db).With("collection", c).Infoln("invalid _id, any issues copying will be aborted") + log.With("database", db.Name).With("collection", c).Infoln("invalid _id, any issues copying will be aborted") return false } @@ -222,16 +224,17 @@ func (r *Reader) tailCollection(c string, mgoSession *mgo.Session, oplogTime bso var ( collection = mgoSession.DB("local").C("oplog.rs") result oplogDoc // hold the document - query = bson.M{"ns": fmt.Sprintf("%s.%s", r.db, c), "ts": bson.M{"$gte": oplogTime}} + db = mgoSession.DB("").Name + query = bson.M{"ns": fmt.Sprintf("%s.%s", db, c), "ts": bson.M{"$gte": oplogTime}} iter = collection.Find(query).LogReplay().Sort("$natural").Tail(r.oplogTimeout) ) defer iter.Close() for { - log.With("db", r.db).Infof("tailing oplog with query %+v", query) + log.With("db", db).Infof("tailing oplog with query %+v", query) select { case <-done: - log.With("db", r.db).Infoln("tailing stopping...") + log.With("db", db).Infoln("tailing stopping...") return default: for iter.Next(&result) { @@ -273,7 +276,7 @@ func (r *Reader) tailCollection(c string, mgoSession *mgo.Session, oplogTime bso continue } if iter.Err() != nil { - log.With("path", r.db).Errorf("error tailing oplog, %s", iter.Err()) + log.With("path", db).Errorf("error tailing oplog, %s", iter.Err()) // return adaptor.NewError(adaptor.CRITICAL, m.path, fmt.Sprintf("MongoDB error (error reading collection %s)", iter.Err()), nil) } @@ -300,9 +303,9 @@ func (r *Reader) getOriginalDoc(doc bson.M, c string, s *mgo.Session) (result bs } query["_id"] = id - err = s.DB(r.db).C(c).Find(query).One(&result) + err = s.DB("").C(c).Find(query).One(&result) if err != nil { - err = fmt.Errorf("%s.%s %v %v", r.db, c, id, err) + err = fmt.Errorf("%s.%s %v %v", s.DB("").Name, c, id, err) } return } diff --git a/adaptor/mongodb/reader_test.go b/adaptor/mongodb/reader_test.go index a67f01ecb..eea7a90c8 100644 --- a/adaptor/mongodb/reader_test.go +++ b/adaptor/mongodb/reader_test.go @@ -32,10 +32,16 @@ func TestRead(t *testing.T) { t.Skip("skipping Read in short mode") } - reader := newReader(readerTestData.DB, false, DefaultCollectionFilter) + reader := newReader(false, DefaultCollectionFilter) readFunc := reader.Read(filterFunc) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", readerTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } @@ -55,7 +61,6 @@ func TestFilteredRead(t *testing.T) { } reader := newReader( - filteredReaderTestData.DB, false, map[string]CollectionFilter{"foo": CollectionFilter{"i": map[string]interface{}{"$gt": filteredReaderTestData.InsertCount}}}, ) @@ -66,7 +71,13 @@ func TestFilteredRead(t *testing.T) { readFunc := reader.Read(filterFunc) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", filteredReaderTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } @@ -86,10 +97,16 @@ func TestCancelledRead(t *testing.T) { t.Skip("skipping TestCancelledRead in short mode") } - reader := newReader(cancelledReaderTestData.DB, false, DefaultCollectionFilter) + reader := newReader(false, DefaultCollectionFilter) readFunc := reader.Read(filterFunc) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", cancelledReaderTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } @@ -136,7 +153,7 @@ func TestReadRestart(t *testing.T) { session.mgoSession.DB(db).C("lotsodata").Insert(bson.M{"i": i}) } - reader := newReader(db, false, DefaultCollectionFilter) + reader := newReader(false, DefaultCollectionFilter) readFunc := reader.Read(filterFunc) done := make(chan struct{}) msgChan, err := readFunc(s, done) @@ -220,7 +237,7 @@ func TestTail(t *testing.T) { t.Fatalf("unexpected insertMockTailData error, %s\n", err) } - tail := newReader(tailTestData.DB, true, DefaultCollectionFilter) + tail := newReader(true, DefaultCollectionFilter) time.Sleep(1 * time.Second) tailFunc := tail.Read(func(c string) bool { @@ -233,7 +250,13 @@ func TestTail(t *testing.T) { }) done := make(chan struct{}) - msgChan, err := tailFunc(defaultSession, done) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", tailTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + msgChan, err := tailFunc(s, done) if err != nil { t.Fatalf("unexpected Tail error, %s\n", err) } diff --git a/adaptor/mongodb/writer.go b/adaptor/mongodb/writer.go index d5a78006e..ce9e04938 100644 --- a/adaptor/mongodb/writer.go +++ b/adaptor/mongodb/writer.go @@ -13,12 +13,11 @@ var _ client.Writer = &Writer{} // Writer implements client.Writer for use with MongoDB type Writer struct { - db string writeMap map[ops.Op]func(message.Msg, *mgo.Collection) error } -func newWriter(db string) *Writer { - w := &Writer{db: db} +func newWriter() *Writer { + w := &Writer{} w.writeMap = map[ops.Op]func(message.Msg, *mgo.Collection) error{ ops.Insert: insertMsg, ops.Update: updateMsg, @@ -34,12 +33,12 @@ func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error log.Infof("no function registered for operation, %s\n", msg.OP()) return msg, nil } - return msg, writeFunc(msg, msgCollection(w.db, msg, s)) + return msg, writeFunc(msg, msgCollection(msg, s)) } } -func msgCollection(db string, msg message.Msg, s client.Session) *mgo.Collection { - return s.(*Session).mgoSession.DB(db).C(msg.Namespace()) +func msgCollection(msg message.Msg, s client.Session) *mgo.Collection { + return s.(*Session).mgoSession.DB("").C(msg.Namespace()) } func insertMsg(msg message.Msg, c *mgo.Collection) error { diff --git a/adaptor/mongodb/writer_test.go b/adaptor/mongodb/writer_test.go index e30ae3598..e4d8b8b78 100644 --- a/adaptor/mongodb/writer_test.go +++ b/adaptor/mongodb/writer_test.go @@ -27,7 +27,7 @@ var optests = []struct { } func TestOpFunc(t *testing.T) { - w := newWriter("test") + w := newWriter() for _, ot := range optests { if _, ok := w.writeMap[ot.op]; ok != ot.registered { t.Errorf("op (%s) registration incorrect, expected %+v, got %+v\n", ot.op.String(), ot.registered, ok) @@ -81,11 +81,17 @@ func TestInsert(t *testing.T) { if testing.Short() { t.Skip("skipping Insert in short mode") } - w := newWriter(writerTestData.DB) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", writerTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + w := newWriter() for _, it := range inserttests { for _, data := range it.data { msg := message.From(ops.Insert, it.collection, data) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } } @@ -130,18 +136,24 @@ func TestUpdate(t *testing.T) { if testing.Short() { t.Skip("skipping Update in short mode") } - w := newWriter(writerTestData.DB) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", writerTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + w := newWriter() for _, ut := range updatetests { // Insert data ut.originalDoc.Set("_id", ut.id) msg := message.From(ops.Insert, ut.collection, ut.originalDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } // Update data ut.updatedDoc.Set("_id", ut.id) msg = message.From(ops.Update, ut.collection, ut.updatedDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Update error, %s\n", err) } // Validate update @@ -175,17 +187,23 @@ func TestDelete(t *testing.T) { if testing.Short() { t.Skip("skipping Delete in short mode") } - w := newWriter(writerTestData.DB) + c, _ := NewClient(WithURI(fmt.Sprintf("mongodb://127.0.0.1:27017/%s", writerTestData.DB))) + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to initialize connection to mongodb, %s", err) + } + defer s.(*Session).Close() + w := newWriter() for _, dt := range deletetests { // Insert data dt.originalDoc.Set("_id", dt.id) msg := message.From(ops.Insert, dt.collection, dt.originalDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } // Delete data msg = message.From(ops.Delete, dt.collection, dt.originalDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Delete error, %s\n", err) } // Validate delete @@ -222,7 +240,7 @@ func TestRestartWrites(t *testing.T) { log.Errorf("failed to drop database (%s), may affect tests!, %s", writerTestData.DB, dropErr) } - w := newWriter(writerTestData.DB) + w := newWriter() done := make(chan struct{}) go func() { for { diff --git a/adaptor/postgres/adaptor_test.go b/adaptor/postgres/adaptor_test.go index 54d1b39be..0866686ea 100644 --- a/adaptor/postgres/adaptor_test.go +++ b/adaptor/postgres/adaptor_test.go @@ -98,33 +98,44 @@ func setup() { os.Exit(1) } defaultSession = s.(*Session) - defaultSession.pqSession.Exec("CREATE TYPE mood AS ENUM('sad', 'ok', 'happy');") for _, testData := range dbsToTest { + if _, err := defaultSession.pqSession.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", testData.DB)); err != nil { + log.Errorf("unable to drop database, could affect tests, %s", err) + } + if _, err := defaultSession.pqSession.Exec(fmt.Sprintf("CREATE DATABASE %s;", testData.DB)); err != nil { + log.Errorf("unable to create database, could affect tests, %s", err) + } setupData(testData) } } func setupData(data *TestData) { - if _, err := defaultSession.pqSession.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", data.DB)); err != nil { - log.Errorf("unable to drop database, could affect tests, %s", err) + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", data.DB))) + if err != nil { + log.Errorf("unable to initialize connection to postgres, %s", err) } - - if _, err := defaultSession.pqSession.Exec(fmt.Sprintf("CREATE DATABASE %s;", data.DB)); err != nil { - log.Errorf("unable to create database, could affect tests, %s", err) + defer c.Close() + s, err := c.Connect() + if err != nil { + log.Errorf("unable to obtain session to postgres, %s", err) + } + pqSession := s.(*Session).pqSession + if data.Schema == complexSchema { + pqSession.Exec("CREATE TYPE mood AS ENUM('sad', 'ok', 'happy');") } - if _, err := defaultSession.pqSession.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", data.Table)); err != nil { + if _, err := pqSession.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", data.Table)); err != nil { log.Errorf("unable to drop table, could affect tests, %s", err) } - _, err := defaultSession.pqSession.Exec(fmt.Sprintf("CREATE TABLE %s ( %s );", data.Table, data.Schema)) + _, err = pqSession.Exec(fmt.Sprintf("CREATE TABLE %s ( %s );", data.Table, data.Schema)) if err != nil { log.Errorf("unable to create table, could affect tests, %s", err) } for i := 0; i < data.InsertCount; i++ { if data.Schema == complexSchema { - if _, err := defaultSession.pqSession.Exec(fmt.Sprintf(` + if _, err := pqSession.Exec(fmt.Sprintf(` INSERT INTO %s VALUES ( %d, -- id '%s', -- colvar VARCHAR(255), @@ -174,7 +185,7 @@ func setupData(data *TestData) { log.Errorf("unexpected Insert error, %s\n", err) } } else if data.Schema == basicSchema { - if _, err := defaultSession.pqSession.Exec(fmt.Sprintf(`INSERT INTO %s VALUES ( + if _, err := pqSession.Exec(fmt.Sprintf(`INSERT INTO %s VALUES ( %d, -- id '%s', -- colvar VARCHAR(255), now() at time zone 'utc' -- coltimestamp TIMESTAMP, diff --git a/adaptor/postgres/client.go b/adaptor/postgres/client.go index 33e7df109..426a7c2d1 100644 --- a/adaptor/postgres/client.go +++ b/adaptor/postgres/client.go @@ -2,6 +2,7 @@ package postgres import ( "database/sql" + "net/url" "github.com/compose/transporter/client" @@ -25,6 +26,7 @@ type ClientOptionFunc func(*Client) error // Client represents a client to the underlying File source. type Client struct { uri string + db string pqSession *sql.DB } @@ -33,6 +35,7 @@ func NewClient(options ...ClientOptionFunc) (*Client, error) { // Set up the client c := &Client{ uri: DefaultURI, + db: "postgres", } // Run the options on it @@ -47,8 +50,9 @@ func NewClient(options ...ClientOptionFunc) (*Client, error) { // WithURI defines the full connection string for the Postgres connection func WithURI(uri string) ClientOptionFunc { return func(c *Client) error { + _, err := url.Parse(uri) c.uri = uri - return nil + return err } } @@ -65,7 +69,11 @@ func (c *Client) Connect() (client.Session, error) { // there's really no way for this to error because we know the driver we're passing is // available. c.pqSession, _ = sql.Open("postgres", c.uri) + uri, _ := url.Parse(c.uri) + if uri.Path != "" { + c.db = uri.Path[1:] + } } err := c.pqSession.Ping() - return &Session{c.pqSession}, err + return &Session{c.pqSession, c.db}, err } diff --git a/adaptor/postgres/client_test.go b/adaptor/postgres/client_test.go index f0d40f0f6..755984e82 100644 --- a/adaptor/postgres/client_test.go +++ b/adaptor/postgres/client_test.go @@ -9,6 +9,7 @@ import ( var ( defaultClient = &Client{ uri: DefaultURI, + db: "postgres", } errBadClient = errors.New("bad client") diff --git a/adaptor/postgres/postgres.go b/adaptor/postgres/postgres.go index 53c2d6568..b0cf66891 100644 --- a/adaptor/postgres/postgres.go +++ b/adaptor/postgres/postgres.go @@ -47,17 +47,14 @@ func (p *Postgres) Client() (client.Client, error) { } func (p *Postgres) Reader() (client.Reader, error) { - db, _, err := adaptor.CompileNamespace(p.Namespace) if p.Tail { - return newTailer(db, p.ReplicationSlot), err + return newTailer(p.ReplicationSlot), nil } - return newReader(db), err + return newReader(), nil } func (p *Postgres) Writer(done chan struct{}, wg *sync.WaitGroup) (client.Writer, error) { - // TODO: pull db from the URI - db, _, err := adaptor.CompileNamespace(p.Namespace) - return newWriter(db), err + return newWriter(), nil } // Description for postgres adaptor diff --git a/adaptor/postgres/postgres_test.go b/adaptor/postgres/postgres_test.go index 0d9942fa8..d43af6099 100644 --- a/adaptor/postgres/postgres_test.go +++ b/adaptor/postgres/postgres_test.go @@ -21,8 +21,8 @@ func TestSampleConfig(t *testing.T) { } var initTests = []map[string]interface{}{ - {"uri": DefaultURI, "namespace": "test.test"}, - {"uri": DefaultURI, "namespace": "test.test", "tail": true}, + {"uri": DefaultURI}, + {"uri": DefaultURI, "tail": true}, } func TestInit(t *testing.T) { diff --git a/adaptor/postgres/reader.go b/adaptor/postgres/reader.go index c8e0303c8..ffa04e71b 100644 --- a/adaptor/postgres/reader.go +++ b/adaptor/postgres/reader.go @@ -19,11 +19,10 @@ var ( // Reader implements the behavior defined by client.Reader for interfacing with MongoDB. type Reader struct { - db string } -func newReader(db string) client.Reader { - return &Reader{db} +func newReader() client.Reader { + return &Reader{} } func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { @@ -32,20 +31,20 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { session := s.(*Session) go func() { defer close(out) - log.With("db", r.db).Infoln("starting Read func") - tables, err := r.listTables(session.pqSession, filterFn) + log.With("db", session.db).Infoln("starting Read func") + tables, err := r.listTables(session.db, session.pqSession, filterFn) if err != nil { - log.With("db", r.db).Errorf("unable to list tables, %s", err) + log.With("db", session.db).Errorf("unable to list tables, %s", err) return } - results := r.iterateTable(session.pqSession, tables, done) + results := r.iterateTable(session.db, session.pqSession, tables, done) for { select { case <-done: return case result, ok := <-results: if !ok { - log.With("db", r.db).Infoln("Read completed") + log.With("db", session.db).Infoln("Read completed") return } msg := message.From(ops.Insert, result.table, result.data) @@ -58,9 +57,8 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { } } -func (r *Reader) listTables(session *sql.DB, filterFn func(name string) bool) (<-chan string, error) { +func (r *Reader) listTables(db string, session *sql.DB, filterFn func(name string) bool) (<-chan string, error) { out := make(chan string) - fmt.Println("Exporting data from matching tables:") tablesResult, err := session.Query("SELECT table_schema,table_name FROM information_schema.tables") if err != nil { return nil, err @@ -72,18 +70,18 @@ func (r *Reader) listTables(session *sql.DB, filterFn func(name string) bool) (< var tname string err = tablesResult.Scan(&schema, &tname) if err != nil { - log.With("db", r.db).Infoln("error scanning table name...") + log.With("db", db).Infoln("error scanning table name...") continue } name := fmt.Sprintf("%s.%s", schema, tname) if filterFn(name) && matchFunc(name) { - log.With("db", r.db).With("table", name).Infoln("sending for iteration...") + log.With("db", db).With("table", name).Infoln("sending for iteration...") out <- name } else { - log.With("db", r.db).With("table", name).Infoln("skipping iteration...") + log.With("db", db).With("table", name).Debugln("skipping iteration...") } } - log.With("db", r.db).Infoln("done iterating collections") + log.With("db", db).Infoln("done iterating collections") }() return out, nil } @@ -100,7 +98,7 @@ type doc struct { data data.Data } -func (r *Reader) iterateTable(session *sql.DB, in <-chan string, done chan struct{}) <-chan doc { +func (r *Reader) iterateTable(db string, session *sql.DB, in <-chan string, done chan struct{}) <-chan doc { out := make(chan doc) go func() { defer close(out) @@ -110,7 +108,7 @@ func (r *Reader) iterateTable(session *sql.DB, in <-chan string, done chan struc if !ok { return } - log.With("table", c).With("table", c).Infoln("iterating...") + log.With("db", db).With("table", c).With("table", c).Infoln("iterating...") if strings.HasPrefix(c, "information_schema.") || strings.HasPrefix(c, "pg_catalog.") { continue } @@ -124,7 +122,7 @@ func (r *Reader) iterateTable(session *sql.DB, in <-chan string, done chan struc ORDER BY c.ordinal_position; `, schemaTable[0], schemaTable[1])) if err != nil { - log.With("table", c).Errorf("error getting columns %v", err) + log.With("db", db).With("table", c).Errorf("error getting columns %v", err) continue } var columns [][]string @@ -183,9 +181,9 @@ func (r *Reader) iterateTable(session *sql.DB, in <-chan string, done chan struc } out <- doc{table: c, data: docMap} } - log.With("table", c).Infoln("iterating complete") + log.With("db", db).With("table", c).Infoln("iterating complete") case <-done: - log.With("db", r.db).Infoln("iterating no more") + log.With("db", db).Infoln("iterating no more") return } } diff --git a/adaptor/postgres/reader_test.go b/adaptor/postgres/reader_test.go index f1fea934b..e85fa3643 100644 --- a/adaptor/postgres/reader_test.go +++ b/adaptor/postgres/reader_test.go @@ -18,7 +18,7 @@ func TestRead(t *testing.T) { t.Skip("skipping Read in short mode") } - reader := newReader(readerTestData.DB) + reader := newReader() readFunc := reader.Read(func(table string) bool { if strings.HasPrefix(table, "information_schema.") || strings.HasPrefix(table, "pg_catalog.") { return false @@ -26,7 +26,16 @@ func TestRead(t *testing.T) { return table == fmt.Sprintf("public.%s", readerTestData.Table) }) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", readerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } @@ -49,7 +58,7 @@ func TestReadComplex(t *testing.T) { t.Skip("skipping Read in short mode") } - reader := newReader(readerComplexTestData.DB) + reader := newReader() readFunc := reader.Read(func(table string) bool { if strings.HasPrefix(table, "information_schema.") || strings.HasPrefix(table, "pg_catalog.") { return false @@ -57,7 +66,16 @@ func TestReadComplex(t *testing.T) { return table == fmt.Sprintf("public.%s", readerComplexTestData.Table) }) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", readerComplexTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } diff --git a/adaptor/postgres/session.go b/adaptor/postgres/session.go index 4ea634f7c..bb33f1d64 100644 --- a/adaptor/postgres/session.go +++ b/adaptor/postgres/session.go @@ -11,4 +11,5 @@ var _ client.Session = &Session{} // Session serves as a wrapper for the underlying *sql.DB type Session struct { pqSession *sql.DB + db string } diff --git a/adaptor/postgres/tailer.go b/adaptor/postgres/tailer.go index 7691f2e45..dda2ef652 100644 --- a/adaptor/postgres/tailer.go +++ b/adaptor/postgres/tailer.go @@ -23,12 +23,11 @@ var ( // Tailer implements the behavior defined by client.Tailer for interfacing with the MongoDB oplog. type Tailer struct { reader client.Reader - db string replicationSlot string } -func newTailer(db, replicationSlot string) client.Reader { - return &Tailer{newReader(db), db, replicationSlot} +func newTailer(replicationSlot string) client.Reader { + return &Tailer{newReader(), replicationSlot} } // Tail does the things @@ -39,6 +38,7 @@ func (t *Tailer) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { if err != nil { return nil, err } + session := s.(*Session) out := make(chan message.Msg) go func() { defer close(out) @@ -48,16 +48,16 @@ func (t *Tailer) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { } // start tailing - log.With("db", t.db).With("logical_decoding_slot", t.replicationSlot).Infoln("Listening for changes...") + log.With("db", session.db).With("logical_decoding_slot", t.replicationSlot).Infoln("Listening for changes...") for { select { case <-done: - log.With("db", t.db).Infoln("tailing stopping...") + log.With("db", session.db).Infoln("tailing stopping...") return case <-time.After(time.Second): msgSlice, err := t.pluckFromLogicalDecoding(s.(*Session), filterFn) if err != nil { - log.With("db", t.db).Errorf("error plucking from logical decoding %v", err) + log.With("db", session.db).Errorf("error plucking from logical decoding %v", err) continue } for _, msg := range msgSlice { diff --git a/adaptor/postgres/tailer_test.go b/adaptor/postgres/tailer_test.go index 74ff14a10..8ba0efa57 100644 --- a/adaptor/postgres/tailer_test.go +++ b/adaptor/postgres/tailer_test.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "fmt" "strings" "sync" @@ -10,15 +11,15 @@ import ( "github.com/compose/transporter/message" ) -func addTestReplicationSlot() error { - _, err := defaultSession.pqSession.Exec(` +func addTestReplicationSlot(s *sql.DB) error { + _, err := s.Exec(` SELECT * FROM pg_create_logical_replication_slot('test_slot', 'test_decoding'); `) return err } -func dropTestReplicationSlot() error { - _, err := defaultSession.pqSession.Exec(` +func dropTestReplicationSlot(s *sql.DB) error { + _, err := s.Exec(` SELECT * FROM pg_drop_replication_slot('test_slot'); `) return err @@ -32,13 +33,23 @@ func TestTailer(t *testing.T) { if testing.Short() { t.Skip("skipping Tailer in short mode") } - dropTestReplicationSlot() - if err := addTestReplicationSlot(); err != nil { + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", tailerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } + + dropTestReplicationSlot(s.(*Session).pqSession) + if err := addTestReplicationSlot(s.(*Session).pqSession); err != nil { t.Fatalf("unable to create replication slot, %s", err) } time.Sleep(1 * time.Second) - r := newTailer(tailerTestData.DB, "test_slot") + r := newTailer("test_slot") readFunc := r.Read(func(table string) bool { if strings.HasPrefix(table, "information_schema.") || strings.HasPrefix(table, "pg_catalog.") { return false @@ -46,14 +57,14 @@ func TestTailer(t *testing.T) { return table == fmt.Sprintf("public.%s", tailerTestData.Table) }) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } checkCount("initial drain", tailerTestData.InsertCount, msgChan, t) for i := 10; i < 20; i++ { - defaultSession.pqSession.Exec(fmt.Sprintf(`INSERT INTO %s VALUES ( + s.(*Session).pqSession.Exec(fmt.Sprintf(`INSERT INTO %s VALUES ( %d, -- id '%s', -- colvar VARCHAR(255), now() at time zone 'utc' -- coltimestamp TIMESTAMP, @@ -62,12 +73,12 @@ func TestTailer(t *testing.T) { checkCount("tailed data", 10, msgChan, t) for i := 10; i < 20; i++ { - defaultSession.pqSession.Exec(fmt.Sprintf("UPDATE %s SET colvar = 'hello' WHERE id = %d;", tailerTestData.Table, i)) + s.(*Session).pqSession.Exec(fmt.Sprintf("UPDATE %s SET colvar = 'hello' WHERE id = %d;", tailerTestData.Table, i)) } checkCount("updated data", 10, msgChan, t) for i := 10; i < 20; i++ { - defaultSession.pqSession.Exec(fmt.Sprintf(`DELETE FROM %v WHERE id = %d; `, tailerTestData.Table, i)) + s.(*Session).pqSession.Exec(fmt.Sprintf(`DELETE FROM %v WHERE id = %d; `, tailerTestData.Table, i)) } checkCount("deleted data", 10, msgChan, t) diff --git a/adaptor/postgres/writer.go b/adaptor/postgres/writer.go index b4da1aea7..089d8230c 100644 --- a/adaptor/postgres/writer.go +++ b/adaptor/postgres/writer.go @@ -16,12 +16,11 @@ var _ client.Writer = &Writer{} // Writer implements client.Writer for use with MongoDB type Writer struct { - db string writeMap map[ops.Op]func(message.Msg, *sql.DB) error } -func newWriter(db string) *Writer { - w := &Writer{db: db} +func newWriter() *Writer { + w := &Writer{} w.writeMap = map[ops.Op]func(message.Msg, *sql.DB) error{ ops.Insert: insertMsg, ops.Update: updateMsg, diff --git a/adaptor/postgres/writer_test.go b/adaptor/postgres/writer_test.go index 427332f4a..0b6fdc582 100644 --- a/adaptor/postgres/writer_test.go +++ b/adaptor/postgres/writer_test.go @@ -23,7 +23,7 @@ var optests = []struct { } func TestOpFunc(t *testing.T) { - w := newWriter("test") + w := newWriter() for _, ot := range optests { if _, ok := w.writeMap[ot.op]; ok != ot.registered { t.Errorf("op (%s) registration incorrect, expected %+v, got %+v\n", ot.op.String(), ot.registered, ok) @@ -36,13 +36,22 @@ var ( ) func TestInsert(t *testing.T) { - w := newWriter(writerTestData.DB) + w := newWriter() + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", writerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } for i := 0; i < 10; i++ { msg := message.From( ops.Insert, fmt.Sprintf("public.%s", writerTestData.Table), data.Data{"id": i, "colvar": "hello world", "coltimestamp": time.Now().UTC()}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } } @@ -52,7 +61,7 @@ func TestInsert(t *testing.T) { stringValue string timeValue time.Time ) - if err := defaultSession.pqSession. + if err := s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT id, colvar, coltimestamp FROM %s WHERE id = 4", writerTestData.Table)). Scan(&id, &stringValue, &timeValue); err != nil { t.Fatalf("Error on test query: %v", err) @@ -62,7 +71,7 @@ func TestInsert(t *testing.T) { } var count int - err := defaultSession.pqSession. + err = s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s;", writerTestData.Table)). Scan(&count) if err != nil { @@ -78,13 +87,22 @@ var ( ) func TestUpdate(t *testing.T) { - w := newWriter(writerUpdateTestData.DB) + w := newWriter() + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", writerUpdateTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } for i := 0; i < 10; i++ { msg := message.From( ops.Insert, fmt.Sprintf("public.%s", writerUpdateTestData.Table), data.Data{"id": i, "colvar": "hello world", "coltimestamp": time.Now().UTC()}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } } @@ -92,7 +110,7 @@ func TestUpdate(t *testing.T) { ops.Update, fmt.Sprintf("public.%s", writerUpdateTestData.Table), data.Data{"id": 1, "colvar": "robin", "coltimestamp": time.Now().UTC()}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Update error, %s\n", err) } @@ -101,7 +119,7 @@ func TestUpdate(t *testing.T) { stringValue string timeValue time.Time ) - if err := defaultSession.pqSession. + if err := s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT id, colvar, coltimestamp FROM %s WHERE id = 1", writerUpdateTestData.Table)). Scan(&id, &stringValue, &timeValue); err != nil { t.Fatalf("Error on test query: %v", err) @@ -111,7 +129,7 @@ func TestUpdate(t *testing.T) { } var count int - err := defaultSession.pqSession. + err = s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s;", writerUpdateTestData.Table)). Scan(&count) if err != nil { @@ -128,7 +146,16 @@ var ( func TestComplexUpdate(t *testing.T) { ranInt := rand.Intn(writerComplexUpdateTestData.InsertCount) - w := newWriter(writerComplexUpdateTestData.DB) + w := newWriter() + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", writerComplexUpdateTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } msg := message.From(ops.Update, fmt.Sprintf("public.%s", writerComplexUpdateTestData.Table), data.Data{ "id": ranInt, "colvar": randomHeros[ranInt], @@ -167,7 +194,7 @@ func TestComplexUpdate(t *testing.T) { "coluuid": "f0a0da24-4068-4be4-961d-7c295117ccca", "colxml": "Batman", }) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Update error, %s\n", err) } @@ -177,7 +204,7 @@ func TestComplexUpdate(t *testing.T) { timeValue time.Time bigint int64 ) - if err := defaultSession.pqSession. + if err := s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT id, colvar, coltimestamp, colbigint FROM %s WHERE id = %d", writerComplexUpdateTestData.Table, ranInt)). Scan(&id, &stringValue, &timeValue, &bigint); err != nil { t.Fatalf("Error on test query: %v", err) @@ -187,7 +214,7 @@ func TestComplexUpdate(t *testing.T) { } var count int - err := defaultSession.pqSession. + err = s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s;", writerComplexUpdateTestData.Table)). Scan(&count) if err != nil { @@ -203,30 +230,39 @@ var ( ) func TestDelete(t *testing.T) { - w := newWriter(writerDeleteTestData.DB) + w := newWriter() + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", writerDeleteTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } for i := 0; i < 10; i++ { msg := message.From( ops.Insert, fmt.Sprintf("public.%s", writerDeleteTestData.Table), data.Data{"id": i, "colvar": "hello world", "coltimestamp": time.Now().UTC()}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } } msg := message.From(ops.Delete, fmt.Sprintf("public.%s", writerDeleteTestData.Table), data.Data{"id": 1}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Update error, %s\n", err) } var id int - if err := defaultSession.pqSession. + if err := s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT id FROM %s WHERE id = 1", writerDeleteTestData.Table)). Scan(&id); err == nil { t.Fatalf("Values were found, but where not expected to be: %v", id) } var count int - err := defaultSession.pqSession. + err = s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s;", writerDeleteTestData.Table)). Scan(&count) if err != nil { @@ -243,17 +279,26 @@ var ( func TestComplexDelete(t *testing.T) { ranInt := rand.Intn(writerComplexDeleteTestData.InsertCount) - w := newWriter(writerComplexDeleteTestData.DB) + w := newWriter() + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", writerComplexDeleteTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } msg := message.From( ops.Delete, fmt.Sprintf("public.%s", writerComplexDeleteTestData.Table), data.Data{"id": ranInt, "colvar": randomHeros[ranInt]}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Delete error, %s\n", err) } var id int - if err := defaultSession.pqSession. + if err := s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT id FROM %s WHERE id = %d AND colvar = '%s'", writerComplexDeleteTestData.Table, ranInt, randomHeros[ranInt])). Scan(&id); err == nil { t.Fatalf("Values were found, but where not expected to be: %v", id) @@ -266,15 +311,23 @@ var ( func TestComplexDeleteWithoutAllPrimarykeys(t *testing.T) { ranInt := rand.Intn(writerComplexDeletePkTestData.InsertCount) - w := newWriter(writerComplexDeletePkTestData.DB) - + w := newWriter() + c, err := NewClient(WithURI(fmt.Sprintf("postgres://127.0.0.1:5432/%s?sslmode=disable", writerComplexDeletePkTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to postgres, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to postgres, %s", err) + } msg := message.From(ops.Delete, fmt.Sprintf("public.%s", writerComplexDeletePkTestData.Table), data.Data{"id": ranInt}) - if _, err := w.Write(msg)(defaultSession); err == nil { + if _, err := w.Write(msg)(s); err == nil { t.Fatalf("Did not receive anticipated error from postgres.writeMessage") } var id int - if err := defaultSession.pqSession. + if err := s.(*Session).pqSession. QueryRow(fmt.Sprintf("SELECT id FROM %s WHERE id = %d AND colvar = '%s'", writerComplexDeletePkTestData.Table, ranInt, randomHeros[ranInt])). diff --git a/adaptor/rethinkdb/adaptor_test.go b/adaptor/rethinkdb/adaptor_test.go index dcaf71252..745351a24 100644 --- a/adaptor/rethinkdb/adaptor_test.go +++ b/adaptor/rethinkdb/adaptor_test.go @@ -12,7 +12,6 @@ import ( var ( defaultTestClient = &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, } defaultSession *Session diff --git a/adaptor/rethinkdb/client.go b/adaptor/rethinkdb/client.go index 18257fcb5..d70c3790e 100644 --- a/adaptor/rethinkdb/client.go +++ b/adaptor/rethinkdb/client.go @@ -25,9 +25,6 @@ const ( // DefaultTimeout is the default time.Duration used if one is not provided for options // that pertain to timeouts. DefaultTimeout = 10 * time.Second - - // DefaultDatabase used for the connection options - DefaultDatabase = "test" ) var ( @@ -75,7 +72,6 @@ func NewClient(options ...ClientOptionFunc) (*Client, error) { // Set up the client c := &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, tlsConfig: nil, } @@ -101,17 +97,6 @@ func WithURI(uri string) ClientOptionFunc { } } -// WithDatabase configures the database to use for the connection. -func WithDatabase(db string) ClientOptionFunc { - return func(c *Client) error { - if db == "" { - db = DefaultDatabase - } - c.db = db - return nil - } -} - // WithSessionTimeout overrides the DefaultTimeout and should be parseable by time.ParseDuration func WithSessionTimeout(timeout string) ClientOptionFunc { return func(c *Client) error { @@ -216,6 +201,7 @@ func (c *Client) Close() { func (c *Client) initConnection() error { uri, _ := url.Parse(c.uri) + c.db = uri.Path[1:] opts := r.ConnectOpts{ Addresses: strings.Split(uri.Host, ","), Database: c.db, diff --git a/adaptor/rethinkdb/client_test.go b/adaptor/rethinkdb/client_test.go index bd4e24dc4..9f3d12682 100644 --- a/adaptor/rethinkdb/client_test.go +++ b/adaptor/rethinkdb/client_test.go @@ -40,7 +40,6 @@ yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx var ( defaultClient = &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, tlsConfig: nil, } @@ -66,30 +65,9 @@ var clientTests = []struct { }, { "with_url_fake", - []ClientOptionFunc{WithURI("rethinkdb://fakeurl:28015")}, + []ClientOptionFunc{WithURI("rethinkdb://fakeurl:28015/test")}, &Client{ - uri: "rethinkdb://fakeurl:28015", - db: DefaultDatabase, - sessionTimeout: DefaultTimeout, - }, - nil, - }, - { - "with_database", - []ClientOptionFunc{WithDatabase("not_the_default")}, - &Client{ - uri: DefaultURI, - db: "not_the_default", - sessionTimeout: DefaultTimeout, - }, - nil, - }, - { - "with_database_empty", - []ClientOptionFunc{WithDatabase("")}, - &Client{ - uri: DefaultURI, - db: DefaultDatabase, + uri: "rethinkdb://fakeurl:28015/test", sessionTimeout: DefaultTimeout, }, nil, @@ -99,7 +77,6 @@ var clientTests = []struct { []ClientOptionFunc{WithSessionTimeout("30s")}, &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: 30 * time.Second, }, nil, @@ -121,7 +98,6 @@ var clientTests = []struct { []ClientOptionFunc{WithSSL(true)}, &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, tlsConfig: &tls.Config{InsecureSkipVerify: true, RootCAs: x509.NewCertPool()}, }, @@ -132,7 +108,6 @@ var clientTests = []struct { []ClientOptionFunc{WithSSL(true), WithCACerts([]string{rootPEM})}, &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, tlsConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: certPool()}, }, @@ -143,7 +118,6 @@ var clientTests = []struct { []ClientOptionFunc{WithCACerts([]string{rootPEM})}, &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, tlsConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: certPool()}, }, @@ -154,7 +128,6 @@ var clientTests = []struct { []ClientOptionFunc{WithCACerts([]string{"notacert"})}, &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, tlsConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: certPool()}, }, @@ -191,7 +164,6 @@ var ( "default connect", &Client{ uri: DefaultURI, - db: DefaultDatabase, sessionTimeout: DefaultTimeout, }, nil, @@ -199,7 +171,7 @@ var ( { "timeout connect", &Client{ - uri: "rethinkdb://127.0.0.1:37017", + uri: "rethinkdb://127.0.0.1:37017/test", sessionTimeout: 2 * time.Second, }, client.ConnectError{Reason: "gorethink: dial tcp 127.0.0.1:37017: getsockopt: connection refused"}, @@ -207,8 +179,7 @@ var ( { "authenticated connect", &Client{ - uri: "rethinkdb://admin:admin123@127.0.0.1:48015", - db: DefaultDatabase, + uri: "rethinkdb://admin:admin123@127.0.0.1:48015/test", sessionTimeout: DefaultTimeout, }, nil, @@ -216,8 +187,7 @@ var ( { "failed authenticated connect", &Client{ - uri: "rethinkdb://admin:wrongpassword@127.0.0.1:48015", - db: DefaultDatabase, + uri: "rethinkdb://admin:wrongpassword@127.0.0.1:48015/test", sessionTimeout: DefaultTimeout, }, client.ConnectError{Reason: "gorethink: Wrong password"}, @@ -225,8 +195,7 @@ var ( { "connect with ssl and verify", &Client{ - uri: "rethinkdb://localhost:38015", - db: DefaultDatabase, + uri: "rethinkdb://localhost:38015/test", sessionTimeout: DefaultTimeout, tlsConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: caCertPool()}, }, @@ -235,8 +204,7 @@ var ( { "connect with ssl skip verify", &Client{ - uri: "rethinkdb://localhost:38015", - db: DefaultDatabase, + uri: "rethinkdb://localhost:38015/test", sessionTimeout: DefaultTimeout, tlsConfig: &tls.Config{InsecureSkipVerify: true, RootCAs: x509.NewCertPool()}, }, diff --git a/adaptor/rethinkdb/reader.go b/adaptor/rethinkdb/reader.go index 322d693a5..48fc3dddf 100644 --- a/adaptor/rethinkdb/reader.go +++ b/adaptor/rethinkdb/reader.go @@ -19,12 +19,11 @@ var ( // Reader fulfills the client.Reader interface for use with both copying and tailing a RethinkDB // database. type Reader struct { - db string tail bool } -func newReader(db string, tail bool) client.Reader { - return &Reader{db, tail} +func newReader(tail bool) client.Reader { + return &Reader{tail} } type iterationComplete struct { @@ -38,10 +37,10 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { session := s.(*Session).session go func() { defer close(out) - log.With("db", r.db).Infoln("starting Read func") + log.With("db", session.Database()).Infoln("starting Read func") tables, err := r.listTables(session, filterFn) if err != nil { - log.With("db", r.db).Errorf("unable to list tables, %s", err) + log.With("db", session.Database()).Errorf("unable to list tables, %s", err) return } iterationComplete := r.iterateTable(session, tables, out, done) @@ -55,14 +54,14 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { if !ok { return } - log.With("db", r.db).With("table", i.table).Infoln("iterating complete") + log.With("db", session.Database()).With("table", i.table).Infoln("iterating complete") if i.cursor != nil { go func(wg *sync.WaitGroup, t string, c *re.Cursor) { wg.Add(1) defer wg.Done() - errc := r.sendChanges(t, c, out, done) + errc := r.sendChanges(session.Database(), t, c, out, done) for err := range errc { - log.With("db", r.db).With("table", t).Errorln(err) + log.With("db", session.Database()).With("table", t).Errorln(err) return } }(&wg, i.table, i.cursor) @@ -70,7 +69,7 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { } } }() - log.With("db", r.db).Infoln("Read completed") + log.With("db", session.Database()).Infoln("Read completed") // this will block if we're tailing wg.Wait() return @@ -81,7 +80,7 @@ func (r *Reader) Read(filterFn client.NsFilterFunc) client.MessageChanFunc { func (r *Reader) listTables(session *re.Session, filterFn func(name string) bool) (<-chan string, error) { out := make(chan string) - tables, err := re.DB(r.db).TableList().Run(session) + tables, err := re.DB(session.Database()).TableList().Run(session) if err != nil { return nil, err } @@ -92,13 +91,13 @@ func (r *Reader) listTables(session *re.Session, filterFn func(name string) bool var table string for tables.Next(&table) { if filterFn(table) { - log.With("db", r.db).With("table", table).Infoln("sending for iteration...") + log.With("db", session.Database()).With("table", table).Infoln("sending for iteration...") out <- table } else { - log.With("db", r.db).With("table", table).Infoln("skipping iteration...") + log.With("db", session.Database()).With("table", table).Infoln("skipping iteration...") } } - log.With("db", r.db).Infoln("done iterating tables") + log.With("db", session.Database()).Infoln("done iterating tables") }() return out, nil } @@ -113,8 +112,8 @@ func (r *Reader) iterateTable(session *re.Session, in <-chan string, out chan<- if !ok { return } - log.With("db", r.db).With("table", t).Infoln("iterating...") - cursor, err := re.DB(r.db).Table(t).Run(session) + log.With("db", session.Database()).With("table", t).Infoln("iterating...") + cursor, err := re.DB(session.Database()).Table(t).Run(session) if err != nil { return } @@ -122,7 +121,7 @@ func (r *Reader) iterateTable(session *re.Session, in <-chan string, out chan<- var ccursor *re.Cursor if r.tail { var err error - ccursor, err = re.DB(r.db).Table(t).Changes(re.ChangesOpts{}).Run(session) + ccursor, err = re.DB(session.Database()).Table(t).Changes(re.ChangesOpts{}).Run(session) if err != nil { return } @@ -139,7 +138,7 @@ func (r *Reader) iterateTable(session *re.Session, in <-chan string, out chan<- } tableDone <- iterationComplete{ccursor, t} case <-done: - log.With("db", r.db).Infoln("iterating no more") + log.With("db", session.Database()).Infoln("iterating no more") return } } @@ -153,14 +152,14 @@ type rethinkDbChangeNotification struct { NewVal map[string]interface{} `gorethink:"new_val"` } -func (r *Reader) sendChanges(table string, ccursor *re.Cursor, out chan<- message.Msg, done chan struct{}) chan error { +func (r *Reader) sendChanges(db, table string, ccursor *re.Cursor, out chan<- message.Msg, done chan struct{}) chan error { errc := make(chan error) go func() { defer ccursor.Close() defer close(errc) changes := make(chan rethinkDbChangeNotification) ccursor.Listen(changes) - log.With("db", r.db).With("table", table).Debugln("starting changes feed...") + log.With("db", db).With("table", table).Debugln("starting changes feed...") for { if err := ccursor.Err(); err != nil { errc <- err @@ -168,14 +167,14 @@ func (r *Reader) sendChanges(table string, ccursor *re.Cursor, out chan<- messag } select { case <-done: - log.With("db", r.db).With("table", table).Infoln("stopping changes...") + log.With("db", db).With("table", table).Infoln("stopping changes...") return case change := <-changes: if done == nil { - log.With("db", r.db).With("table", table).Infoln("stopping changes...") + log.With("db", db).With("table", table).Infoln("stopping changes...") return } - log.With("db", r.db).With("table", table).With("change", change).Debugln("received") + log.With("db", db).With("table", table).With("change", change).Debugln("received") if change.Error != "" { errc <- errors.New(change.Error) diff --git a/adaptor/rethinkdb/reader_test.go b/adaptor/rethinkdb/reader_test.go index 9b8cc7098..ac5232fbf 100644 --- a/adaptor/rethinkdb/reader_test.go +++ b/adaptor/rethinkdb/reader_test.go @@ -1,6 +1,7 @@ package rethinkdb import ( + "fmt" "sync" "testing" "time" @@ -19,10 +20,19 @@ func TestRead(t *testing.T) { t.Skip("skipping Read in short mode") } - reader := newReader(readerTestData.DB, false) + reader := newReader(false) readFunc := reader.Read(func(c string) bool { return true }) done := make(chan struct{}) - msgChan, err := readFunc(defaultSession, done) + c, err := NewClient(WithURI(fmt.Sprintf("rethinkdb://127.0.0.1:28015/%s", readerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to rethinkdb, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to rethinkdb, %s", err) + } + msgChan, err := readFunc(s, done) if err != nil { t.Fatalf("unexpected Read error, %s\n", err) } @@ -95,7 +105,7 @@ func TestTail(t *testing.T) { t.Fatalf("unexpected insertMockTailData error, %s\n", err) } - tail := newReader(tailTestData.DB, true) + tail := newReader(true) time.Sleep(1 * time.Second) tailFunc := tail.Read(func(c string) bool { @@ -105,7 +115,16 @@ func TestTail(t *testing.T) { return true }) done := make(chan struct{}) - msgChan, err := tailFunc(defaultSession, done) + c, err := NewClient(WithURI(fmt.Sprintf("rethinkdb://127.0.0.1:28015/%s", tailTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to rethinkdb, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to rethinkdb, %s", err) + } + msgChan, err := tailFunc(s, done) if err != nil { t.Fatalf("unexpected Tail error, %s\n", err) } diff --git a/adaptor/rethinkdb/rethinkdb.go b/adaptor/rethinkdb/rethinkdb.go index 03669c773..149d9fab7 100644 --- a/adaptor/rethinkdb/rethinkdb.go +++ b/adaptor/rethinkdb/rethinkdb.go @@ -43,10 +43,8 @@ func init() { func (r *RethinkDB) Client() (client.Client, error) { // TODO: pull db from the URI - db, _, _ := adaptor.CompileNamespace(r.Namespace) return NewClient( WithURI(r.URI), - WithDatabase(db), WithSessionTimeout(r.Timeout), WithSSL(r.SSL), WithCACerts(r.CACerts), @@ -54,15 +52,11 @@ func (r *RethinkDB) Client() (client.Client, error) { } func (r *RethinkDB) Reader() (client.Reader, error) { - // TODO: pull db from the URI - db, _, err := adaptor.CompileNamespace(r.Namespace) - return newReader(db, r.Tail), err + return newReader(r.Tail), nil } func (r *RethinkDB) Writer(done chan struct{}, wg *sync.WaitGroup) (client.Writer, error) { - // TODO: pull db from the URI - db, _, err := adaptor.CompileNamespace(r.Namespace) - return newWriter(db, done, wg), err + return newWriter(done, wg), nil } // Description for rethinkdb adaptor diff --git a/adaptor/rethinkdb/rethinkdb_test.go b/adaptor/rethinkdb/rethinkdb_test.go index d9d9ca351..3894226a0 100644 --- a/adaptor/rethinkdb/rethinkdb_test.go +++ b/adaptor/rethinkdb/rethinkdb_test.go @@ -22,8 +22,8 @@ func TestSampleConfig(t *testing.T) { } var initTests = []map[string]interface{}{ - {"uri": DefaultURI, "namespace": "test.test"}, - {"uri": DefaultURI, "namespace": "test.test", "tail": true}, + {"uri": DefaultURI}, + {"uri": DefaultURI, "tail": true}, } func TestInit(t *testing.T) { diff --git a/adaptor/rethinkdb/writer.go b/adaptor/rethinkdb/writer.go index 1767e2f42..8be968d44 100644 --- a/adaptor/rethinkdb/writer.go +++ b/adaptor/rethinkdb/writer.go @@ -24,7 +24,6 @@ var ( // Writer implements client.Writer for use with RethinkDB type Writer struct { - db string bulkMap map[string]*bulkOperation *sync.Mutex opCounter int @@ -35,9 +34,8 @@ type bulkOperation struct { docs []map[string]interface{} } -func newWriter(db string, done chan struct{}, wg *sync.WaitGroup) *Writer { +func newWriter(done chan struct{}, wg *sync.WaitGroup) *Writer { w := &Writer{ - db: db, bulkMap: make(map[string]*bulkOperation), Mutex: &sync.Mutex{}, } @@ -53,7 +51,7 @@ func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error switch msg.OP() { case ops.Delete: w.flushAll() - return msg, do(r.DB(w.db).Table(table).Get(prepareDocument(msg)["id"]).Delete(), rSession) + return msg, do(r.DB(rSession.Database()).Table(table).Get(prepareDocument(msg)["id"]).Delete(), rSession) case ops.Insert: w.Lock() bOp, ok := w.bulkMap[table] @@ -72,7 +70,7 @@ func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error } case ops.Update: w.flushAll() - return msg, do(r.DB(w.db).Table(table).Insert(prepareDocument(msg), r.InsertOpts{Conflict: "replace"}), rSession) + return msg, do(r.DB(rSession.Database()).Table(table).Insert(prepareDocument(msg), r.InsertOpts{Conflict: "replace"}), rSession) } return msg, nil } @@ -115,8 +113,8 @@ func (w *Writer) flushAll() error { w.Unlock() }() for t, bOp := range w.bulkMap { - log.With("db", w.db).With("table", t).With("op_counter", w.opCounter).With("doc_count", len(bOp.docs)).Infoln("flushing bulk messages") - resp, err := r.DB(w.db).Table(t).Insert(bOp.docs, r.InsertOpts{Conflict: "replace"}).RunWrite(bOp.s) + log.With("db", bOp.s.Database()).With("table", t).With("op_counter", w.opCounter).With("doc_count", len(bOp.docs)).Infoln("flushing bulk messages") + resp, err := r.DB(bOp.s.Database()).Table(t).Insert(bOp.docs, r.InsertOpts{Conflict: "replace"}).RunWrite(bOp.s) if err != nil { return err } diff --git a/adaptor/rethinkdb/writer_test.go b/adaptor/rethinkdb/writer_test.go index c2eb0b754..26ab11633 100644 --- a/adaptor/rethinkdb/writer_test.go +++ b/adaptor/rethinkdb/writer_test.go @@ -1,6 +1,7 @@ package rethinkdb import ( + "fmt" "reflect" "sync" "testing" @@ -62,15 +63,23 @@ func TestBulkInsert(t *testing.T) { } var wg sync.WaitGroup done := make(chan struct{}) - w := newWriter(writerTestData.DB, done, &wg) - + w := newWriter(done, &wg) if _, err := r.DB(writerTestData.DB).TableCreate("bulk").RunWrite(defaultSession.session); err != nil { log.Errorf("failed to create table (bulk) in %s, may affect tests!, %s", writerTestData.DB, err) } + c, err := NewClient(WithURI(fmt.Sprintf("rethinkdb://127.0.0.1:28015/%s", writerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to rethinkdb, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to rethinkdb, %s", err) + } for i := 0; i < 999; i++ { msg := message.From(ops.Insert, "bulk", map[string]interface{}{"i": i}) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s", err) } } @@ -97,15 +106,24 @@ func TestInsert(t *testing.T) { for _, it := range inserttests { var wg sync.WaitGroup done := make(chan struct{}) - w := newWriter(writerTestData.DB, done, &wg) + w := newWriter(done, &wg) if _, err := r.DB(writerTestData.DB).TableCreate(it.table).RunWrite(defaultSession.session); err != nil { log.Errorf("failed to create table (%s) in %s, may affect tests!, %s", it.table, writerTestData.DB, err) } + c, err := NewClient(WithURI(fmt.Sprintf("rethinkdb://127.0.0.1:28015/%s", writerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to rethinkdb, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to rethinkdb, %s", err) + } for _, data := range it.data { msg := message.From(ops.Insert, it.table, data) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } } @@ -157,19 +175,28 @@ func TestUpdate(t *testing.T) { for _, ut := range updatetests { var wg sync.WaitGroup done := make(chan struct{}) - w := newWriter(writerTestData.DB, done, &wg) + w := newWriter(done, &wg) if _, err := r.DB(writerTestData.DB).TableCreate(ut.table).RunWrite(defaultSession.session); err != nil { log.Errorf("failed to create table (%s) in %s, may affect tests!, %s", ut.table, writerTestData.DB, err) } + c, err := NewClient(WithURI(fmt.Sprintf("rethinkdb://127.0.0.1:28015/%s", writerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to rethinkdb, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to rethinkdb, %s", err) + } // Insert data msg := message.From(ops.Insert, ut.table, ut.originalDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } // Update data msg = message.From(ops.Update, ut.table, ut.updatedDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Update error, %s\n", err) } close(done) @@ -211,20 +238,29 @@ func TestDelete(t *testing.T) { for _, dt := range deletetests { var wg sync.WaitGroup done := make(chan struct{}) - w := newWriter(writerTestData.DB, done, &wg) + w := newWriter(done, &wg) if _, err := r.DB(writerTestData.DB).TableCreate(dt.table).RunWrite(defaultSession.session); err != nil { log.Errorf("failed to create table (%s) in %s, may affect tests!, %s", dt.table, writerTestData.DB, err) } + c, err := NewClient(WithURI(fmt.Sprintf("rethinkdb://127.0.0.1:28015/%s", writerTestData.DB))) + if err != nil { + t.Fatalf("unable to initialize connection to rethinkdb, %s", err) + } + defer c.Close() + s, err := c.Connect() + if err != nil { + t.Fatalf("unable to obtain session to rethinkdb, %s", err) + } // Insert data dt.originalDoc.Set("_id", dt.id) msg := message.From(ops.Insert, dt.table, dt.originalDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Insert error, %s\n", err) } // Delete data msg = message.From(ops.Delete, dt.table, dt.originalDoc) - if _, err := w.Write(msg)(defaultSession); err != nil { + if _, err := w.Write(msg)(s); err != nil { t.Errorf("unexpected Delete error, %s\n", err) } close(done) diff --git a/adaptor/transformer/README.md b/adaptor/transformer/README.md deleted file mode 100644 index f081dc6c4..000000000 --- a/adaptor/transformer/README.md +++ /dev/null @@ -1,63 +0,0 @@ -# transformer adaptor - -The transformer adaptor receives and sends data through the defined javascript function for processing. - -The parameter passed to the function has been converted from a go `map[string]interface{}` to a JS object of the following form: - -```javascript -{ - "ns":"message.namespace", - "ts":12345, // time represented in milliseconds since epoch - "op":"insert", - "data": { - "id": "abcdef", - "name": "hello world" - } -} -``` - -***NOTE*** when working with data from MongoDB, the `_id` field will be represented in the following fashion: - -```javascript -{ - "ns":"message.namespace", - "ts":12345, // time represented in milliseconds since epoch - "op":"insert", - "data": { - "_id": { - "$oid": "54a4420502a14b9641000001" - }, - "name": "hello world" - } -} -``` - -There are two types of JavaScript VMs available, `otto` and `goja`. You can configure which one to use via the JS configration and each has its own JavaScript function signature. The `goja` VM has shown better performance in benchmarks but it does *NOT* include the underscore library. - -The default JavaScript VM is `otto` as we are trying to maintain backwards compatability for users but that may change in the future. - -### Configuration -```javascript -tx = transformer({ - "filename": "transform.js", - // "vm": "otto" -}) -``` - -### otto VM -```javascript -module.exports=function(doc) { - console.log(doc['ns']); - console.log(doc['ts']); - console.log(doc['op']); - console.log(doc['data']); - return doc -} -``` - -### goja VM -```javascript -function transform(doc) { - return doc -} -``` diff --git a/adaptor/transformer/gojajs/client.go b/adaptor/transformer/gojajs/client.go deleted file mode 100644 index bfe57529d..000000000 --- a/adaptor/transformer/gojajs/client.go +++ /dev/null @@ -1,113 +0,0 @@ -package gojajs - -import ( - "errors" - "io/ioutil" - - "github.com/compose/transporter/client" - - "github.com/dop251/goja" -) - -var ( - _ client.Client = &Client{} - // ErrEmptyFilename will be returned when the profided filename is empty. - ErrEmptyFilename = errors.New("no filename specified") -) - -// JSFunc defines the structure a transformer function. -type JSFunc func(map[string]interface{}) *goja.Object - -// Client represents a client to the underlying transformer function. -type Client struct { - fn string - vm *goja.Runtime - jsf JSFunc -} - -// ClientOptionFunc is a function that configures a Client. -// It is used in NewClient. -type ClientOptionFunc func(*Client) error - -// NewClient creates a new client to work with Transformer functions. -// -// The caller can configure the new client by passing configuration options -// to the func. -// -// Example: -// -// client, err := NewClient( -// WithFilename("path/to/transformer.js")) -// -// If no URI is configured, it uses defaultURI by default. -// -// An error is also returned when some configuration option is invalid -func NewClient(options ...ClientOptionFunc) (*Client, error) { - // Set up the client - c := &Client{ - vm: nil, - } - - // Run the options on it - for _, option := range options { - if err := option(c); err != nil { - return nil, err - } - } - return c, nil -} - -// WithFilename defines the path to the tranformer file. -func WithFilename(filename string) ClientOptionFunc { - return func(c *Client) error { - if filename == "" { - return ErrEmptyFilename - } - - ba, err := ioutil.ReadFile(filename) - if err != nil { - return err - } - - c.fn = string(ba) - return nil - } -} - -// WithFunction allows for passing a string version of the JS function. -func WithFunction(function string) ClientOptionFunc { - return func(c *Client) error { - c.fn = function - return nil - } -} - -// Connect initializes the JS VM and tests the provided script. -func (c *Client) Connect() (client.Session, error) { - if c.vm == nil { - if err := c.initSession(); err != nil { - return nil, err - } - } - return &Session{c.vm, c.jsf}, nil -} - -// initSession prepares the javascript vm and compiles the transformer script -func (c *Client) initSession() error { - c.vm = goja.New() - - _, err := c.vm.RunString(c.fn) - if err != nil { - return err - } - var jsf JSFunc - c.vm.ExportTo(c.vm.Get("transform"), &jsf) - c.jsf = jsf - return nil -} - -// Session wraps the underlying otto.Otto vm for use by Writer. -type Session struct { - vm *goja.Runtime - fn JSFunc -} diff --git a/adaptor/transformer/gojajs/client_test.go b/adaptor/transformer/gojajs/client_test.go deleted file mode 100644 index d0c24050d..000000000 --- a/adaptor/transformer/gojajs/client_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package gojajs - -import ( - "errors" - "reflect" - "testing" -) - -var ( - defaultClient = &Client{} - - errBadClient = errors.New("bad client") - - clientTests = []struct { - name string - options []ClientOptionFunc // input - expected *Client // expected result - expectedErr error // expected error - }{ - { - "default_client", - make([]ClientOptionFunc, 0), - defaultClient, - nil, - }, - { - "default_client_with_filename", - []ClientOptionFunc{WithFilename("testdata/transformer.js")}, - &Client{fn: `function transform(doc) { return doc }`}, - nil, - }, - { - "default_client_empty_filename", - []ClientOptionFunc{WithFilename("")}, - nil, - ErrEmptyFilename, - }, - { - "with_err", - []ClientOptionFunc{WithErr()}, - defaultClient, - errBadClient, - }, - } -) - -func WithErr() ClientOptionFunc { - return func(c *Client) error { - return errBadClient - } -} - -func TestNewClient(t *testing.T) { - for _, ct := range clientTests { - actual, err := NewClient(ct.options...) - if err != ct.expectedErr { - t.Fatalf("[%s] unexpected NewClient error, expected %+v, got %+v\n", ct.name, ct.expectedErr, err) - } - if err == nil && !reflect.DeepEqual(ct.expected, actual) { - t.Errorf("[%s] Client mismatch\nexpected %+v\ngot %+v", ct.name, ct.expected, actual) - } - } -} diff --git a/adaptor/transformer/ottojs/client.go b/adaptor/transformer/ottojs/client.go deleted file mode 100644 index 2ca7df479..000000000 --- a/adaptor/transformer/ottojs/client.go +++ /dev/null @@ -1,116 +0,0 @@ -package ottojs - -import ( - "errors" - "io/ioutil" - - "github.com/compose/transporter/client" - "github.com/robertkrimen/otto" - _ "github.com/robertkrimen/otto/underscore" // enable underscore -) - -var ( - _ client.Client = &Client{} - // ErrEmptyFilename will be returned when the profided filename is empty. - ErrEmptyFilename = errors.New("no filename specified") -) - -// Client represents a client to the underlying transformer function. -type Client struct { - fn string - vm *otto.Otto -} - -// ClientOptionFunc is a function that configures a Client. -// It is used in NewClient. -type ClientOptionFunc func(*Client) error - -// NewClient creates a new client to work with Transformer functions. -// -// The caller can configure the new client by passing configuration options -// to the func. -// -// Example: -// -// client, err := NewClient( -// WithFilename("path/to/transformer.js")) -// -// If no URI is configured, it uses defaultURI by default. -// -// An error is also returned when some configuration option is invalid -func NewClient(options ...ClientOptionFunc) (*Client, error) { - // Set up the client - c := &Client{ - vm: nil, - } - - // Run the options on it - for _, option := range options { - if err := option(c); err != nil { - return nil, err - } - } - return c, nil -} - -// WithFilename defines the path to the tranformer file. -func WithFilename(filename string) ClientOptionFunc { - return func(c *Client) error { - if filename == "" { - return ErrEmptyFilename - } - - ba, err := ioutil.ReadFile(filename) - if err != nil { - return err - } - - c.fn = string(ba) - return nil - } -} - -// WithFunction allows for passing a string version of the JS function. -func WithFunction(function string) ClientOptionFunc { - return func(c *Client) error { - c.fn = function - return nil - } -} - -// Connect initializes the JS VM and tests the provided script. -func (c *Client) Connect() (client.Session, error) { - if c.vm == nil { - if err := c.initSession(); err != nil { - return nil, err - } - } - return &Session{c.vm}, nil -} - -// initSession prepares the javascript vm and compiles the transformer script -func (c *Client) initSession() error { - c.vm = otto.New() - - // set up the vm environment, make `module = {}` - if _, err := c.vm.Run(`module = {}`); err != nil { - return err - } - - // compile our script - script, err := c.vm.Compile("", c.fn) - if err != nil { - return err - } - - // run the script, ignore the output - if _, err = c.vm.Run(script); err != nil { - return err - } - return nil -} - -// Session wraps the underlying otto.Otto vm for use by Writer. -type Session struct { - vm *otto.Otto -} diff --git a/adaptor/transformer/testdata/goja-transformer.js b/adaptor/transformer/testdata/goja-transformer.js deleted file mode 100644 index a36e059d9..000000000 --- a/adaptor/transformer/testdata/goja-transformer.js +++ /dev/null @@ -1 +0,0 @@ -function f(doc) { return doc } diff --git a/adaptor/transformer/testdata/transformer.js b/adaptor/transformer/testdata/transformer.js deleted file mode 100644 index a2cb11ff2..000000000 --- a/adaptor/transformer/testdata/transformer.js +++ /dev/null @@ -1 +0,0 @@ -module.exports=function(doc) { return doc } diff --git a/adaptor/transformer/transformer.go b/adaptor/transformer/transformer.go deleted file mode 100644 index b1cdea9ff..000000000 --- a/adaptor/transformer/transformer.go +++ /dev/null @@ -1,74 +0,0 @@ -package transformer - -import ( - "sync" - - "github.com/compose/transporter/adaptor" - "github.com/compose/transporter/client" - - goja "github.com/compose/transporter/adaptor/transformer/gojajs" - otto "github.com/compose/transporter/adaptor/transformer/ottojs" -) - -const ( - sampleConfig = `{ - "filename": "transformer.js" - // "vm": "otto" -}` - - description = "an adaptor that transforms documents using a javascript function" - - // DefaultVM defines the javascript interpreter to be used if one is not specified - DefaultVM = "otto" -) - -var ( - _ adaptor.Adaptor = &Transformer{} -) - -// Transformer is an adaptor which consumes data from a source, transforms it using a supplied javascript -// function and then emits it. The javascript transformation function is supplied as a separate file on disk, -// and is called by calling the defined module.exports function -type Transformer struct { - Filename string `json:"filename"` - VM string `json:"vm"` -} - -func init() { - adaptor.Add( - "transformer", - func() adaptor.Adaptor { - return &Transformer{ - VM: DefaultVM, - } - }, - ) -} - -func (t *Transformer) Client() (client.Client, error) { - if t.VM == DefaultVM { - return otto.NewClient(otto.WithFilename(t.Filename)) - } - return goja.NewClient(goja.WithFilename(t.Filename)) -} - -func (t *Transformer) Reader() (client.Reader, error) { - return nil, adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} -} - -func (t *Transformer) Writer(chan struct{}, *sync.WaitGroup) (client.Writer, error) { - if t.VM == DefaultVM { - return &otto.Writer{}, nil - } - return &goja.Writer{}, nil -} - -// Description for transformer adaptor -func (t *Transformer) Description() string { - return description -} - -// SampleConfig for transformer adaptor -func (t *Transformer) SampleConfig() string { - return sampleConfig -} diff --git a/adaptor/transformer/transformer_test.go b/adaptor/transformer/transformer_test.go deleted file mode 100644 index 4d2a8ab6f..000000000 --- a/adaptor/transformer/transformer_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package transformer - -import ( - "testing" - - "github.com/compose/transporter/adaptor" -) - -func TestDescription(t *testing.T) { - transformer := Transformer{} - if transformer.Description() != description { - t.Errorf("unexpected Description, expected %s, got %s\n", description, transformer.Description()) - } -} - -func TestSampleConfig(t *testing.T) { - transformer := Transformer{} - if transformer.SampleConfig() != sampleConfig { - t.Errorf("unexpected SampleConfig, expected %s, got %s\n", sampleConfig, transformer.SampleConfig()) - } -} - -var initTests = []map[string]interface{}{ - {"filename": "testdata/transformer.js"}, - {"filename": "testdata/goja-transformer.js", "vm": "goja"}, -} - -func TestInit(t *testing.T) { - for _, it := range initTests { - a, err := adaptor.GetAdaptor("transformer", it) - if err != nil { - t.Fatalf("unexpected GetV2() error, %s", err) - } - if _, err := a.Client(); err != nil { - t.Errorf("unexpected Client() error, %s", err) - } - rerr := adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} - if _, err := a.Reader(); err != rerr { - t.Errorf("wrong Reader() error, expected %s, got %s", rerr, err) - } - if _, err := a.Writer(nil, nil); err != nil { - t.Errorf("unexpected Writer() error, %s", err) - } - } -} diff --git a/cmd/transporter/goja_buider_test.go b/cmd/transporter/goja_buider_test.go index 9fe60c54f..783eb69b7 100644 --- a/cmd/transporter/goja_buider_test.go +++ b/cmd/transporter/goja_buider_test.go @@ -3,14 +3,28 @@ package main import ( "os" "reflect" + "regexp" "testing" + + "github.com/compose/transporter/pipeline" ) func TestNewBuilder(t *testing.T) { - source := buildAdaptor("mongodb")(map[string]interface{}{"name": "source", "uri": "mongo://localhost:27017"}) - transformer := buildAdaptor("transformer")(map[string]interface{}{"name": "trans", "filename": "pipeline.js"}) - source.Add(transformer) - transformer.Add(buildAdaptor("elasticsearch")(map[string]interface{}{"name": "sink", "uri": "http://localhost:2900"})) + a := buildAdaptor("mongodb")(map[string]interface{}{"uri": "mongo://localhost:27017"}) + source, err := pipeline.NewNode("source", a.name, "test./.*/", a.a, nil) + if err != nil { + t.Fatalf("unexpected error, %s\n", err) + } + + a = buildAdaptor("elasticsearch")(map[string]interface{}{"uri": "http://localhost:9200"}) + sink, err := pipeline.NewNode("sink", a.name, "test./.*/", a.a, source) + if err != nil { + t.Fatalf("unexpected error, %s\n", err) + } + + transformer := buildFunction("transformer")(map[string]interface{}{"filename": "pipeline.js"}) + sink.Transforms = []*pipeline.Transform{&pipeline.Transform{Name: "trans", Fn: transformer, NsFilter: regexp.MustCompile(".*")}} + expected := "Transporter:\n" expected += source.String() @@ -19,17 +33,27 @@ func TestNewBuilder(t *testing.T) { t.Fatalf("unexpected error, %s", err) } actual := builder.String() - if reflect.DeepEqual(actual, expected) { + if !reflect.DeepEqual(actual, expected) { t.Errorf("misconfigured transporter\nexpected:\n%s\ngot:\n%s", expected, actual) } } func TestNewBuilderWithEnv(t *testing.T) { os.Setenv("TEST_MONGO_URI", "mongo://localhost:27017") - source := buildAdaptor("mongodb")(map[string]interface{}{"name": "source", "uri": "mongo://localhost:27017"}) - transformer := buildAdaptor("transformer")(map[string]interface{}{"name": "trans", "filename": "pipeline.js"}) - source.Add(transformer) - transformer.Add(buildAdaptor("elasticsearch")(map[string]interface{}{"name": "sink", "uri": "http://localhost:2900"})) + a := buildAdaptor("mongodb")(map[string]interface{}{"uri": "mongo://localhost:27017"}) + source, err := pipeline.NewNode("source", a.name, "test./.*/", a.a, nil) + if err != nil { + t.Fatalf("unexpected error, %s\n", err) + } + a = buildAdaptor("elasticsearch")(map[string]interface{}{"uri": "http://localhost:9200"}) + sink, err := pipeline.NewNode("sink", a.name, "test./.*/", a.a, source) + if err != nil { + t.Fatalf("unexpected error, %s\n", err) + } + + transformer := buildFunction("transformer")(map[string]interface{}{"filename": "pipeline.js"}) + sink.Transforms = []*pipeline.Transform{&pipeline.Transform{Name: "trans", Fn: transformer, NsFilter: regexp.MustCompile(".*")}} + expected := "Transporter:\n" expected += source.String() @@ -38,8 +62,7 @@ func TestNewBuilderWithEnv(t *testing.T) { t.Fatalf("unexpected error, %s", err) } actual := builder.String() - if reflect.DeepEqual(actual, expected) { + if !reflect.DeepEqual(actual, expected) { t.Errorf("misconfigured transporter\nexpected:\n%s\ngot:\n%s", expected, actual) } - } diff --git a/cmd/transporter/goja_builder.go b/cmd/transporter/goja_builder.go index f70a5b20c..30c64aaac 100644 --- a/cmd/transporter/goja_builder.go +++ b/cmd/transporter/goja_builder.go @@ -13,6 +13,7 @@ import ( "github.com/compose/transporter/adaptor" "github.com/compose/transporter/events" + "github.com/compose/transporter/function" "github.com/compose/transporter/pipeline" "github.com/dop251/goja" uuid "github.com/nu7hatch/gouuid" @@ -27,6 +28,9 @@ func NewBuilder(file string) (*Transporter, error) { for _, name := range adaptor.RegisteredAdaptors() { t.vm.Set(name, buildAdaptor(name)) } + for _, name := range function.RegisteredFunctions() { + t.vm.Set(name, buildFunction(name)) + } ba, err := ioutil.ReadFile(file) if err != nil { @@ -63,7 +67,22 @@ type Transporter struct { vm *goja.Runtime sourceNode *pipeline.Node - lastNode *pipeline.Node +} + +type Node struct { + vm *goja.Runtime + parent *pipeline.Node +} + +type Transformer struct { + vm *goja.Runtime + source *pipeline.Node + transforms []*pipeline.Transform +} + +type Adaptor struct { + name string + a adaptor.Adaptor } func (t *Transporter) Run() error { @@ -108,51 +127,111 @@ func (t *Transporter) String() string { return out } -func buildAdaptor(name string) func(map[string]interface{}) *pipeline.Node { - return func(args map[string]interface{}) *pipeline.Node { - uuid, _ := uuid.NewV4() - nodeName := uuid.String() - if name, ok := args["name"]; ok { - nodeName = name.(string) - delete(args, "name") +func buildAdaptor(name string) func(map[string]interface{}) Adaptor { + return func(args map[string]interface{}) Adaptor { + a, err := adaptor.GetAdaptor(name, args) + if err != nil { + panic(err) } - if _, ok := args["namespace"]; !ok { - args["namespace"] = "test./.*/" + return Adaptor{name, a} + } +} + +func buildFunction(name string) func(map[string]interface{}) function.Function { + return func(args map[string]interface{}) function.Function { + f, err := function.GetFunction(name, args) + if err != nil { + panic(err) } - return pipeline.NewNode(nodeName, name, args) + return f } } func (t *Transporter) Source(call goja.FunctionCall) goja.Value { - args := exportArgs(call.Arguments) - t.sourceNode = args[0].(*pipeline.Node) - t.lastNode = t.sourceNode - return t.vm.ToValue(t) + name, out, namespace := exportArgs(call.Arguments) + a := out.(Adaptor) + n, err := pipeline.NewNode(name, a.name, namespace, a.a, nil) + if err != nil { + panic(err) + } + t.sourceNode = n + return t.vm.ToValue(&Node{t.vm, n}) +} + +func (n *Node) Transform(call goja.FunctionCall) goja.Value { + name, f, ns := exportArgs(call.Arguments) + _, nsFilter, err := adaptor.CompileNamespace(ns) + if err != nil { + panic(err) + } + tf := &Transformer{ + vm: n.vm, + source: n.parent, + transforms: make([]*pipeline.Transform, 0), + } + tf.transforms = append(tf.transforms, &pipeline.Transform{Name: name, Fn: f.(function.Function), NsFilter: nsFilter}) + return n.vm.ToValue(tf) +} + +func (tf *Transformer) Transform(call goja.FunctionCall) goja.Value { + name, f, ns := exportArgs(call.Arguments) + _, nsFilter, err := adaptor.CompileNamespace(ns) + if err != nil { + panic(err) + } + t := &pipeline.Transform{Name: name, Fn: f.(function.Function), NsFilter: nsFilter} + tf.transforms = append(tf.transforms, t) + return tf.vm.ToValue(tf) } -func (t *Transporter) Transform(call goja.FunctionCall) goja.Value { - args := exportArgs(call.Arguments) - node := args[0].(*pipeline.Node) - t.lastNode.Add(node) - t.lastNode = node - return t.vm.ToValue(t) +func (n *Node) Save(call goja.FunctionCall) goja.Value { + name, out, namespace := exportArgs(call.Arguments) + a := out.(Adaptor) + child, err := pipeline.NewNode(name, a.name, namespace, a.a, n.parent) + if err != nil { + panic(err) + } + return n.vm.ToValue(&Node{n.vm, child}) } -func (t *Transporter) Save(call goja.FunctionCall) goja.Value { - args := exportArgs(call.Arguments) - node := args[0].(*pipeline.Node) - t.lastNode.Add(node) - t.lastNode = node - return t.vm.ToValue(t) +func (tf *Transformer) Save(call goja.FunctionCall) goja.Value { + name, out, namespace := exportArgs(call.Arguments) + a := out.(Adaptor) + child, err := pipeline.NewNode(name, a.name, namespace, a.a, tf.source) + if err != nil { + panic(err) + } + child.Transforms = tf.transforms + return tf.vm.ToValue(&Node{tf.vm, child}) } -func exportArgs(args []goja.Value) []interface{} { +// arguments can be any of the following forms: +// ("name", Adaptor/Function, "namespace") +// ("name", Adaptor/Function) +// (Adaptor/Function, "namespace") +// (Adaptor/Function) +// the only *required* argument is a Adaptor or Function +func exportArgs(args []goja.Value) (string, interface{}, string) { if len(args) == 0 { - return nil + panic("at least 1 argument required") } - out := make([]interface{}, 0, len(args)) - for _, a := range args { - out = append(out, a.Export()) + uuid, _ := uuid.NewV4() + var ( + name = uuid.String() + namespace = "test./.*/" + a interface{} + ) + if n, ok := args[0].Export().(string); ok { + name = n + a = args[1].Export() + if len(args) == 3 { + namespace = args[2].Export().(string) + } + } else { + a = args[0].Export() + if len(args) == 2 { + namespace = args[1].Export().(string) + } } - return out + return name, a, namespace } diff --git a/cmd/transporter/main.go b/cmd/transporter/main.go index 2723759c9..e811d6694 100644 --- a/cmd/transporter/main.go +++ b/cmd/transporter/main.go @@ -8,6 +8,7 @@ import ( "text/tabwriter" _ "github.com/compose/transporter/adaptor/all" + _ "github.com/compose/transporter/function/all" "github.com/compose/transporter/log" ) diff --git a/cmd/transporter/testdata/test_pipeline.js b/cmd/transporter/testdata/test_pipeline.js index bf5f6451b..44603a67e 100644 --- a/cmd/transporter/testdata/test_pipeline.js +++ b/cmd/transporter/testdata/test_pipeline.js @@ -1,2 +1,4 @@ -m = mongodb({name: "source", uri: "mongo://localhost:27017"}) -t.Source(m).Transform(transformer({name: "trans", filename: "pipeline.js"})).Save(elasticsearch({name: "sink", uri:"http://localhost:9200"})) \ No newline at end of file +m = mongodb({"uri": "mongo://localhost:27017"}) +t.Source("source", m) + .Transform("trans", transformer({filename: "pipeline.js"})) + .Save("sink", elasticsearch({uri:"http://localhost:9200"})) \ No newline at end of file diff --git a/cmd/transporter/testdata/test_pipeline_env.js b/cmd/transporter/testdata/test_pipeline_env.js index 85faf2181..9b1d657f6 100644 --- a/cmd/transporter/testdata/test_pipeline_env.js +++ b/cmd/transporter/testdata/test_pipeline_env.js @@ -1,2 +1,4 @@ -m = mongodb({name: "source", uri: "${TEST_MONGO_URI}"}) -t.Source(m).Transform(transformer({name: "trans", filename: "pipeline.js"})).Save(elasticsearch({name: "sink", uri:"http://localhost:9200"})) \ No newline at end of file +m = mongodb({uri: "${TEST_MONGO_URI}"}) +t.Source("source", m) + .Transform("trans", transformer({filename: "pipeline.js"})) + .Save("sink", elasticsearch({uri:"http://localhost:9200"})) \ No newline at end of file diff --git a/function/all/all.go b/function/all/all.go new file mode 100644 index 000000000..299a41693 --- /dev/null +++ b/function/all/all.go @@ -0,0 +1,11 @@ +package all + +import ( + _ "github.com/compose/transporter/function/gojajs" + _ "github.com/compose/transporter/function/omit" + _ "github.com/compose/transporter/function/ottojs" + _ "github.com/compose/transporter/function/pick" + _ "github.com/compose/transporter/function/pretty" + _ "github.com/compose/transporter/function/rename" + _ "github.com/compose/transporter/function/skip" +) diff --git a/function/function.go b/function/function.go new file mode 100644 index 000000000..c3393f813 --- /dev/null +++ b/function/function.go @@ -0,0 +1,9 @@ +package function + +import "github.com/compose/transporter/message" + +// Function has a single defined function to serve the purpose of apply logic to a message in order to return +// a message. +type Function interface { + Apply(message.Msg) (message.Msg, error) +} diff --git a/adaptor/transformer/gojajs/writer.go b/function/gojajs/goja.go similarity index 59% rename from adaptor/transformer/gojajs/writer.go rename to function/gojajs/goja.go index f6459248d..3b032e72a 100644 --- a/adaptor/transformer/gojajs/writer.go +++ b/function/gojajs/goja.go @@ -2,10 +2,11 @@ package gojajs import ( "errors" + "io/ioutil" "time" "github.com/compose/mejson" - "github.com/compose/transporter/client" + "github.com/compose/transporter/function" "github.com/compose/transporter/log" "github.com/compose/transporter/message" "github.com/compose/transporter/message/data" @@ -14,28 +15,69 @@ import ( ) var ( - _ client.Writer = &Writer{} + _ function.Function = &Goja{} // ErrInvalidMessageType is a generic error returned when the `data` property returned in the document from // the JS function was not of type map[string]interface{} ErrInvalidMessageType = errors.New("returned document was not a map") + + // ErrEmptyFilename will be returned when the profided filename is empty. + ErrEmptyFilename = errors.New("no filename specified") ) -// Writer implements the client.Writer interface. -type Writer struct{} +func init() { + function.Add( + "goja", + func() function.Function { + return &Goja{} + }, + ) +} + +type Goja struct { + Filename string `json:"filename"` + vm *goja.Runtime +} -func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error) { - return func(s client.Session) (message.Msg, error) { - // short circuit for commands - if msg.OP() == ops.Command { - return msg, nil +// JSFunc defines the structure a transformer function. +type JSFunc func(map[string]interface{}) *goja.Object + +// Apply fulfills the function.Function interface by transforming the incoming message with the configured +// JavaScript function. +func (g *Goja) Apply(msg message.Msg) (message.Msg, error) { + if g.vm == nil { + if err := g.initVM(); err != nil { + return nil, err } + } + return g.transformOne(msg) +} - return w.transformOne(s.(*Session), msg) +func (g *Goja) initVM() error { + g.vm = goja.New() + + fn, err := extractFunction(g.Filename) + if err != nil { + return err + } + _, err = g.vm.RunString(fn) + return err +} + +func extractFunction(filename string) (string, error) { + if filename == "" { + return "", ErrEmptyFilename } + + ba, err := ioutil.ReadFile(filename) + if err != nil { + return "", err + } + + return string(ba), nil } -func (w *Writer) transformOne(s *Session, msg message.Msg) (message.Msg, error) { +func (g *Goja) transformOne(msg message.Msg) (message.Msg, error) { var ( outDoc goja.Value doc interface{} @@ -54,14 +96,16 @@ func (w *Writer) transformOne(s *Session, msg message.Msg) (message.Msg, error) // lets run our transformer on the document beforeVM := time.Now().Nanosecond() - outDoc = s.fn(currMsg) + var jsf JSFunc + g.vm.ExportTo(g.vm.Get("transform"), &jsf) + outDoc = jsf(currMsg) var res map[string]interface{} - if s.vm.ExportTo(outDoc, &res); err != nil { + if g.vm.ExportTo(outDoc, &res); err != nil { return msg, err } afterVM := time.Now().Nanosecond() - newMsg, err := toMsg(s.vm, msg, res) + newMsg, err := toMsg(g.vm, msg, res) if err != nil { return nil, err } diff --git a/adaptor/transformer/gojajs/writer_test.go b/function/gojajs/goja_test.go similarity index 72% rename from adaptor/transformer/gojajs/writer_test.go rename to function/gojajs/goja_test.go index d85ffd4f0..86ebabd46 100644 --- a/adaptor/transformer/gojajs/writer_test.go +++ b/function/gojajs/goja_test.go @@ -4,12 +4,32 @@ import ( "reflect" "testing" + "github.com/compose/transporter/function" "github.com/compose/transporter/message" "github.com/compose/transporter/message/data" "github.com/compose/transporter/message/ops" "gopkg.in/mgo.v2/bson" ) +var initTests = []struct { + in map[string]interface{} + expect *Goja +}{ + {map[string]interface{}{"filename": "testdata/transformer.js"}, &Goja{Filename: "testdata/transformer.js"}}, +} + +func TestInit(t *testing.T) { + for _, it := range initTests { + a, err := function.GetFunction("goja", it.in) + if err != nil { + t.Fatalf("unexpected GetFunction() error, %s", err) + } + if !reflect.DeepEqual(a, it.expect) { + t.Errorf("misconfigured Function, expected %+v, got %+v", it.expect, a) + } + } +} + var ( bsonID1 = bson.NewObjectId() bsonID2 = bson.ObjectIdHex("54a4420502a14b9641000001") @@ -22,88 +42,80 @@ var ( }{ { "just pass through", - "function transform(doc) { return doc }", + "testdata/transformer.js", message.From(ops.Insert, "collection", data.Data{"id": "id1", "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": "id1", "name": "nick"}), nil, }, { "delete the 'name' property", - "function transform(doc) { delete doc['data']['name']; return doc }", + "testdata/delete_name.js", message.From(ops.Insert, "collection", data.Data{"id": "id2", "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": "id2"}), nil, }, { "delete's should be processed the same", - "function transform(doc) { delete doc['data']['name']; return doc }", + "testdata/delete_name.js", message.From(ops.Delete, "collection", data.Data{"id": "id2", "name": "nick"}), message.From(ops.Delete, "collection", data.Data{"id": "id2"}), nil, }, - { - "delete's and commands should pass through, and the transformer fn shouldn't run", - "function transform(doc) { delete doc['data']['name']; return doc }", - message.From(ops.Command, "collection", data.Data{"id": "id2", "name": "nick"}), - message.From(ops.Command, "collection", data.Data{"id": "id2", "name": "nick"}), - nil, - }, { "bson should marshal and unmarshal properly", - "function transform(doc) { return doc }", + "testdata/transformer.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), nil, }, { "we should be able to change the bson", - "function transform(doc) { doc['data']['id']['$oid'] = '54a4420502a14b9641000001'; return doc }", + "testdata/change_bson.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": bsonID2, "name": "nick"}), nil, }, { "we should be able to skip a message", - "function transform(doc) { doc['op'] = 's'; return doc }", + "testdata/skip.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), nil, nil, }, { "we should be able to change the namespace", - "function transform(doc) { doc['ns'] = 'table'; return doc }", + "testdata/change_ns.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), message.From(ops.Insert, "table", data.Data{"id": bsonID1, "name": "nick"}), nil, }, { "we should be able to add an object to the bson", - `function transform(doc) { doc['data']['added'] = {"name":"batman","villain":"joker"}; return doc }`, + "testdata/add_data.js", message.From(ops.Insert, "collection", data.Data{"name": "nick"}), message.From(ops.Insert, "collection", data.Data{"name": "nick", "added": bson.M{"name": "batman", "villain": "joker"}}), nil, }, { "Invalid data returned", - `function transform(doc) { doc["data"] = "not a map";return doc }`, + "testdata/invalid_data.js", message.From(ops.Insert, "collection", data.Data{"id": "id1", "name": "nick"}), nil, ErrInvalidMessageType, }, + { + "empty filename", + "", + message.From(ops.Insert, "collection", data.Data{"id": "id1", "name": "nick"}), + nil, + ErrEmptyFilename, + }, } ) -func TestWrite(t *testing.T) { +func TestApply(t *testing.T) { for _, v := range writeTests { - c, err := NewClient(WithFunction(v.fn)) - if err != nil { - t.Fatalf("[%s] NewClient() error, %s", v.name, err) - } - s, err := c.Connect() - if err != nil { - t.Fatalf("[%s] unexpected Connect() error, %s", v.name, err) - } - w := Writer{} - msg, err := w.Write(v.in)(s) + g := Goja{Filename: v.fn} + msg, err := g.Apply(v.in) if err != v.err { t.Errorf("[%s] wrong error, expected: %+v, got; %v", v.name, v.err, err) } @@ -164,20 +176,11 @@ func isEqualBSON(m1 map[string]interface{}, m2 map[string]interface{}) bool { } func BenchmarkTransformOne(b *testing.B) { - c, err := NewClient(WithFunction("function f(doc) { return doc }")) - if err != nil { - panic(err) - } - s, err := c.Connect() - if err != nil { - panic(err) - } - w := Writer{} + g := &Goja{Filename: "testdata/transformer.js"} msg := message.From(ops.Insert, "collection", map[string]interface{}{"id": bson.NewObjectId(), "name": "nick"}) - b.ResetTimer() for i := 0; i < b.N; i++ { - w.Write(msg)(s) + g.Apply(msg) } } diff --git a/function/gojajs/testdata/add_data.js b/function/gojajs/testdata/add_data.js new file mode 100644 index 000000000..5e8f42c1b --- /dev/null +++ b/function/gojajs/testdata/add_data.js @@ -0,0 +1 @@ +function transform(doc) { doc['data']['added'] = {"name":"batman","villain":"joker"}; return doc } \ No newline at end of file diff --git a/function/gojajs/testdata/change_bson.js b/function/gojajs/testdata/change_bson.js new file mode 100644 index 000000000..6b0873051 --- /dev/null +++ b/function/gojajs/testdata/change_bson.js @@ -0,0 +1 @@ +function transform(doc) { doc['data']['id']['$oid'] = '54a4420502a14b9641000001'; return doc } \ No newline at end of file diff --git a/function/gojajs/testdata/change_ns.js b/function/gojajs/testdata/change_ns.js new file mode 100644 index 000000000..fde3b3fcf --- /dev/null +++ b/function/gojajs/testdata/change_ns.js @@ -0,0 +1 @@ +function transform(doc) { doc['ns'] = 'table'; return doc } \ No newline at end of file diff --git a/function/gojajs/testdata/delete_name.js b/function/gojajs/testdata/delete_name.js new file mode 100644 index 000000000..840b4ac75 --- /dev/null +++ b/function/gojajs/testdata/delete_name.js @@ -0,0 +1 @@ +function transform(doc) { delete doc['data']['name']; return doc } \ No newline at end of file diff --git a/function/gojajs/testdata/invalid_data.js b/function/gojajs/testdata/invalid_data.js new file mode 100644 index 000000000..c0ef6f556 --- /dev/null +++ b/function/gojajs/testdata/invalid_data.js @@ -0,0 +1 @@ +function transform(doc) { doc["data"] = "not a map";return doc } \ No newline at end of file diff --git a/function/gojajs/testdata/skip.js b/function/gojajs/testdata/skip.js new file mode 100644 index 000000000..ed9dc4d9d --- /dev/null +++ b/function/gojajs/testdata/skip.js @@ -0,0 +1 @@ +function transform(doc) { doc['op'] = 's'; return doc } \ No newline at end of file diff --git a/adaptor/transformer/gojajs/testdata/transformer.js b/function/gojajs/testdata/transformer.js similarity index 100% rename from adaptor/transformer/gojajs/testdata/transformer.js rename to function/gojajs/testdata/transformer.js diff --git a/adaptor/function/omit/README.md b/function/omit/README.md similarity index 100% rename from adaptor/function/omit/README.md rename to function/omit/README.md diff --git a/function/omit/omitter.go b/function/omit/omitter.go new file mode 100644 index 000000000..4b5d5d95c --- /dev/null +++ b/function/omit/omitter.go @@ -0,0 +1,26 @@ +package omit + +import ( + "github.com/compose/transporter/function" + "github.com/compose/transporter/message" +) + +func init() { + function.Add( + "omit", + func() function.Function { + return &Omitter{} + }, + ) +} + +type Omitter struct { + Fields []string `json:"fields"` +} + +func (o *Omitter) Apply(msg message.Msg) (message.Msg, error) { + for _, k := range o.Fields { + msg.Data().Delete(k) + } + return msg, nil +} diff --git a/adaptor/function/omit/omitter_test.go b/function/omit/omitter_test.go similarity index 64% rename from adaptor/function/omit/omitter_test.go rename to function/omit/omitter_test.go index a1ca0e2f9..7c7fe05f3 100644 --- a/adaptor/function/omit/omitter_test.go +++ b/function/omit/omitter_test.go @@ -4,31 +4,27 @@ import ( "reflect" "testing" - "github.com/compose/transporter/adaptor" + "github.com/compose/transporter/function" _ "github.com/compose/transporter/log" "github.com/compose/transporter/message" "github.com/compose/transporter/message/ops" ) -var initTests = []map[string]interface{}{ - {"fields": []string{"test"}}, +var initTests = []struct { + in map[string]interface{} + expect *Omitter +}{ + {map[string]interface{}{"fields": []string{"test"}}, &Omitter{Fields: []string{"test"}}}, } func TestInit(t *testing.T) { for _, it := range initTests { - a, err := adaptor.GetAdaptor("omit", it) + a, err := function.GetFunction("omit", it.in) if err != nil { - t.Fatalf("unexpected GetAdaptor() error, %s", err) - } - if _, err := a.Client(); err != nil { - t.Errorf("unexpected Client() error, %s", err) - } - rerr := adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} - if _, err := a.Reader(); err != rerr { - t.Errorf("wrong Reader() error, expected %s, got %s", rerr, err) + t.Fatalf("unexpected GetFunction() error, %s", err) } - if _, err := a.Writer(nil, nil); err != nil { - t.Errorf("unexpected Writer() error, %s", err) + if !reflect.DeepEqual(a, it.expect) { + t.Errorf("misconfigured Function, expected %+v, got %+v", it.expect, a) } } } @@ -63,10 +59,10 @@ var omitTests = []struct { }, } -func TestOmit(t *testing.T) { +func TestApply(t *testing.T) { for _, ot := range omitTests { omit := &Omitter{ot.fields} - msg, err := omit.Write(message.From(ops.Insert, "test", ot.in))(nil) + msg, err := omit.Apply(message.From(ops.Insert, "test", ot.in)) if !reflect.DeepEqual(err, ot.err) { t.Errorf("[%s] error mismatch, expected %s, got %s", ot.name, ot.err, err) } diff --git a/adaptor/transformer/ottojs/writer.go b/function/ottojs/otto.go similarity index 66% rename from adaptor/transformer/ottojs/writer.go rename to function/ottojs/otto.go index 89e49e999..3be621691 100644 --- a/adaptor/transformer/ottojs/writer.go +++ b/function/ottojs/otto.go @@ -1,37 +1,97 @@ package ottojs import ( + "errors" "fmt" + "io/ioutil" "time" "github.com/compose/mejson" - "github.com/compose/transporter/client" + "github.com/compose/transporter/function" "github.com/compose/transporter/log" "github.com/compose/transporter/message" "github.com/compose/transporter/message/data" "github.com/compose/transporter/message/ops" "github.com/robertkrimen/otto" + + _ "github.com/robertkrimen/otto/underscore" // enable underscore ) var ( - _ client.Writer = &Writer{} + _ function.Function = &Otto{} + // ErrEmptyFilename will be returned when the profided filename is empty. + ErrEmptyFilename = errors.New("no filename specified") ) -// Writer implements the client.Writer interface. -type Writer struct{} +func init() { + function.Add( + "otto", + func() function.Function { + return &Otto{} + }, + ) + + // adding for backwards compatibility + function.Add( + "transformer", + func() function.Function { + return &Otto{} + }, + ) +} + +type Otto struct { + Filename string `json:"filename"` + vm *otto.Otto +} -func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error) { - return func(s client.Session) (message.Msg, error) { - // short circuit for deletes and commands - if msg.OP() == ops.Command { - return msg, nil +func (o *Otto) Apply(msg message.Msg) (message.Msg, error) { + if o.vm == nil { + if err := o.initVM(); err != nil { + return nil, err } + } + return o.transformOne(msg) +} + +func (o *Otto) initVM() error { + o.vm = otto.New() + + fn, err := extractFunction(o.Filename) + if err != nil { + return err + } + + // set up the vm environment, make `module = {}` + if _, err := o.vm.Run(`module = {}`); err != nil { + return err + } + + // compile our script + script, err := o.vm.Compile("", fn) + if err != nil { + return err + } - return w.transformOne(s.(*Session).vm, msg) + // run the script, ignore the output + _, err = o.vm.Run(script) + return err +} + +func extractFunction(filename string) (string, error) { + if filename == "" { + return "", ErrEmptyFilename + } + + ba, err := ioutil.ReadFile(filename) + if err != nil { + return "", err } + + return string(ba), nil } -func (w *Writer) transformOne(vm *otto.Otto, msg message.Msg) (message.Msg, error) { +func (o *Otto) transformOne(msg message.Msg) (message.Msg, error) { var ( value, outDoc otto.Value result, doc interface{} @@ -52,27 +112,23 @@ func (w *Writer) transformOne(vm *otto.Otto, msg message.Msg) (message.Msg, erro } currMsg["data"] = doc - if value, err = vm.ToValue(currMsg); err != nil { - // t.pipe.Err <- t.transformerError(adaptor.ERROR, err, msg) + if value, err = o.vm.ToValue(currMsg); err != nil { return msg, err } // now that we have finished casting our map to a bunch of different types, // lets run our transformer on the document beforeVM := time.Now().Nanosecond() - if outDoc, err = vm.Call(`module.exports`, nil, value); err != nil { - // t.pipe.Err <- t.transformerError(adaptor.ERROR, err, msg) + if outDoc, err = o.vm.Call(`module.exports`, nil, value); err != nil { return msg, err } if result, err = outDoc.Export(); err != nil { - // t.pipe.Err <- t.transformerError(adaptor.ERROR, err, msg) return msg, err } afterVM := time.Now().Nanosecond() newMsg, err := toMsg(msg, result) if err != nil { - // t.pipe.Err <- t.transformerError(adaptor.ERROR, err, msg) return msg, err } then := time.Now().Nanosecond() diff --git a/adaptor/transformer/ottojs/writer_test.go b/function/ottojs/otto_test.go similarity index 70% rename from adaptor/transformer/ottojs/writer_test.go rename to function/ottojs/otto_test.go index 0e6d6ce0b..6329db90b 100644 --- a/adaptor/transformer/ottojs/writer_test.go +++ b/function/ottojs/otto_test.go @@ -22,61 +22,54 @@ var ( }{ { "just pass through", - "module.exports=function(doc) { return doc }", + "testdata/transformer.js", message.From(ops.Insert, "collection", data.Data{"id": "id1", "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": "id1", "name": "nick"}), false, }, { "delete the 'name' property", - "module.exports=function(doc) { doc['data'] = _.omit(doc['data'], ['name']); return doc }", + "testdata/delete_name.js", message.From(ops.Insert, "collection", data.Data{"id": "id2", "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": "id2"}), false, }, { "delete's should be processed the same", - "module.exports=function(doc) { doc['data'] = _.omit(doc['data'], ['name']); return doc }", + "testdata/delete_name.js", message.From(ops.Delete, "collection", data.Data{"id": "id2", "name": "nick"}), message.From(ops.Delete, "collection", data.Data{"id": "id2"}), false, }, - { - "delete's and commands should pass through, and the transformer fn shouldn't run", - "module.exports=function(doc) { return _.omit(doc['data'], ['name']) }", - message.From(ops.Command, "collection", data.Data{"id": "id2", "name": "nick"}), - message.From(ops.Command, "collection", data.Data{"id": "id2", "name": "nick"}), - false, - }, { "bson should marshal and unmarshal properly", - "module.exports=function(doc) { return doc }", + "testdata/transformer.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), false, }, { "we should be able to change the bson", - "module.exports=function(doc) { doc['data']['id']['$oid'] = '54a4420502a14b9641000001'; return doc }", + "testdata/change_bson.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), message.From(ops.Insert, "collection", data.Data{"id": bsonID2, "name": "nick"}), false, }, { "we should be able to skip a nil message", - "module.exports=function(doc) { return false }", + "testdata/skip.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), nil, false, }, { "we should be able to change the namespace", - "module.exports=function(doc) { doc['ns'] = 'table'; return doc }", + "testdata/change_ns.js", message.From(ops.Insert, "collection", data.Data{"id": bsonID1, "name": "nick"}), message.From(ops.Insert, "table", data.Data{"id": bsonID1, "name": "nick"}), false, }, { "we should be able to add an object to the bson", - `module.exports=function(doc) { doc['data']['added'] = {"name":"batman","villain":"joker"}; return doc }`, + "testdata/add_data.js", message.From(ops.Insert, "collection", data.Data{"name": "nick"}), message.From(ops.Insert, "collection", data.Data{"name": "nick", "added": bson.M{"name": "batman", "villain": "joker"}}), false, @@ -84,18 +77,10 @@ var ( } ) -func TestWrite(t *testing.T) { +func TestApply(t *testing.T) { for _, v := range writeTests { - c, err := NewClient(WithFunction(v.fn)) - if err != nil { - t.Fatalf("[%s] NewClient() error, %s", v.name, err) - } - s, err := c.Connect() - if err != nil { - t.Fatalf("[%s] unexpected Connect() error, %s", v.name, err) - } - w := Writer{} - msg, err := w.Write(v.in)(s) + o := Otto{Filename: v.fn} + msg, err := o.Apply(v.in) if (err != nil) != v.err { t.Errorf("[%s] error expected %t but actually got %v", v.name, v.err, err) continue @@ -151,20 +136,11 @@ func isEqualBSON(m1 map[string]interface{}, m2 map[string]interface{}) bool { } func BenchmarkTransformOne(b *testing.B) { - c, err := NewClient(WithFunction("module.exports=function(doc) { return doc }")) - if err != nil { - panic(err) - } - s, err := c.Connect() - if err != nil { - panic(err) - } - w := Writer{} + o := Otto{Filename: "testdata/transformer.js"} msg := message.From(ops.Insert, "collection", map[string]interface{}{"id": bson.NewObjectId(), "name": "nick"}) - b.ResetTimer() for i := 0; i < b.N; i++ { - w.Write(msg)(s) + o.Apply(msg) } } diff --git a/function/ottojs/testdata/add_data.js b/function/ottojs/testdata/add_data.js new file mode 100644 index 000000000..9ea51c935 --- /dev/null +++ b/function/ottojs/testdata/add_data.js @@ -0,0 +1 @@ +module.exports=function(doc) { doc['data']['added'] = {"name":"batman","villain":"joker"}; return doc } \ No newline at end of file diff --git a/function/ottojs/testdata/change_bson.js b/function/ottojs/testdata/change_bson.js new file mode 100644 index 000000000..5faebc9ae --- /dev/null +++ b/function/ottojs/testdata/change_bson.js @@ -0,0 +1 @@ +module.exports=function(doc) { doc['data']['id']['$oid'] = '54a4420502a14b9641000001'; return doc } \ No newline at end of file diff --git a/function/ottojs/testdata/change_ns.js b/function/ottojs/testdata/change_ns.js new file mode 100644 index 000000000..dd360946e --- /dev/null +++ b/function/ottojs/testdata/change_ns.js @@ -0,0 +1 @@ +module.exports=function(doc) { doc['ns'] = 'table'; return doc } \ No newline at end of file diff --git a/function/ottojs/testdata/delete_name.js b/function/ottojs/testdata/delete_name.js new file mode 100644 index 000000000..f1d0094c2 --- /dev/null +++ b/function/ottojs/testdata/delete_name.js @@ -0,0 +1 @@ +module.exports=function(doc) { doc['data'] = _.omit(doc['data'], ['name']); return doc } \ No newline at end of file diff --git a/function/ottojs/testdata/skip.js b/function/ottojs/testdata/skip.js new file mode 100644 index 000000000..4ee1a816b --- /dev/null +++ b/function/ottojs/testdata/skip.js @@ -0,0 +1 @@ +module.exports=function(doc) { return false } \ No newline at end of file diff --git a/function/ottojs/testdata/transformer.js b/function/ottojs/testdata/transformer.js new file mode 100644 index 000000000..d275c60fa --- /dev/null +++ b/function/ottojs/testdata/transformer.js @@ -0,0 +1 @@ +module.exports=function(doc) { return doc } \ No newline at end of file diff --git a/adaptor/function/pick/README.md b/function/pick/README.md similarity index 100% rename from adaptor/function/pick/README.md rename to function/pick/README.md diff --git a/function/pick/picker.go b/function/pick/picker.go new file mode 100644 index 000000000..4200fbfec --- /dev/null +++ b/function/pick/picker.go @@ -0,0 +1,32 @@ +package pick + +import ( + "github.com/compose/transporter/function" + "github.com/compose/transporter/log" + "github.com/compose/transporter/message" +) + +func init() { + function.Add( + "pick", + func() function.Function { + return &Picker{} + }, + ) +} + +type Picker struct { + Fields []string `json:"fields"` +} + +func (p *Picker) Apply(msg message.Msg) (message.Msg, error) { + log.With("msg", msg).Infof("picking...") + pluckedMsg := map[string]interface{}{} + for _, k := range p.Fields { + if v, ok := msg.Data().AsMap()[k]; ok { + pluckedMsg[k] = v + } + } + log.With("msg", pluckedMsg).Infof("...picked") + return message.From(msg.OP(), msg.Namespace(), pluckedMsg), nil +} diff --git a/adaptor/function/pick/picker_test.go b/function/pick/picker_test.go similarity index 64% rename from adaptor/function/pick/picker_test.go rename to function/pick/picker_test.go index 32d8c2a13..1529b10d8 100644 --- a/adaptor/function/pick/picker_test.go +++ b/function/pick/picker_test.go @@ -4,31 +4,27 @@ import ( "reflect" "testing" - "github.com/compose/transporter/adaptor" + "github.com/compose/transporter/function" _ "github.com/compose/transporter/log" "github.com/compose/transporter/message" "github.com/compose/transporter/message/ops" ) -var initTests = []map[string]interface{}{ - {"fields": []string{"test"}}, +var initTests = []struct { + in map[string]interface{} + expect *Picker +}{ + {map[string]interface{}{"fields": []string{"test"}}, &Picker{Fields: []string{"test"}}}, } func TestInit(t *testing.T) { for _, it := range initTests { - a, err := adaptor.GetAdaptor("pick", it) + a, err := function.GetFunction("pick", it.in) if err != nil { - t.Fatalf("unexpected GetAdaptor() error, %s", err) - } - if _, err := a.Client(); err != nil { - t.Errorf("unexpected Client() error, %s", err) - } - rerr := adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} - if _, err := a.Reader(); err != rerr { - t.Errorf("wrong Reader() error, expected %s, got %s", rerr, err) + t.Fatalf("unexpected GetFunction() error, %s", err) } - if _, err := a.Writer(nil, nil); err != nil { - t.Errorf("unexpected Writer() error, %s", err) + if !reflect.DeepEqual(a, it.expect) { + t.Errorf("misconfigured Function, expected %+v, got %+v", it.expect, a) } } } @@ -63,10 +59,10 @@ var pickTests = []struct { }, } -func TestOmit(t *testing.T) { +func TestApply(t *testing.T) { for _, pt := range pickTests { pick := &Picker{pt.fields} - msg, err := pick.Write(message.From(ops.Insert, "test", pt.in))(nil) + msg, err := pick.Apply(message.From(ops.Insert, "test", pt.in)) if !reflect.DeepEqual(err, pt.err) { t.Errorf("[%s] error mismatch, expected %s, got %s", pt.name, pt.err, err) } diff --git a/adaptor/function/pretty/README.md b/function/pretty/README.md similarity index 100% rename from adaptor/function/pretty/README.md rename to function/pretty/README.md diff --git a/function/pretty/prettify.go b/function/pretty/prettify.go new file mode 100644 index 000000000..a0440a007 --- /dev/null +++ b/function/pretty/prettify.go @@ -0,0 +1,42 @@ +package pretty + +import ( + "encoding/json" + "strings" + + "github.com/compose/mejson" + "github.com/compose/transporter/function" + "github.com/compose/transporter/log" + "github.com/compose/transporter/message" +) + +const ( + DefaultIndent = 2 +) + +var ( + DefaultPrettifier = &Prettify{Spaces: DefaultIndent} +) + +func init() { + function.Add( + "pretty", + func() function.Function { + return DefaultPrettifier + }, + ) +} + +type Prettify struct { + Spaces int `json:"spaces"` +} + +func (p *Prettify) Apply(msg message.Msg) (message.Msg, error) { + d, _ := mejson.Unmarshal(msg.Data()) + b, _ := json.Marshal(d) + if p.Spaces > 0 { + b, _ = json.MarshalIndent(d, "", strings.Repeat(" ", p.Spaces)) + } + log.Infof("\n%s", string(b)) + return msg, nil +} diff --git a/adaptor/function/pretty/prettify_test.go b/function/pretty/prettify_test.go similarity index 61% rename from adaptor/function/pretty/prettify_test.go rename to function/pretty/prettify_test.go index d36c58ac2..c34f0f950 100644 --- a/adaptor/function/pretty/prettify_test.go +++ b/function/pretty/prettify_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/compose/transporter/adaptor" + "github.com/compose/transporter/function" _ "github.com/compose/transporter/log" "github.com/compose/transporter/message" "github.com/compose/transporter/message/ops" @@ -13,20 +13,22 @@ import ( bson "gopkg.in/mgo.v2/bson" ) +var initTests = []struct { + in map[string]interface{} + expect *Prettify +}{ + {map[string]interface{}{}, DefaultPrettifier}, +} + func TestInit(t *testing.T) { - a, err := adaptor.GetAdaptor("pretty", map[string]interface{}{}) - if err != nil { - t.Fatalf("unexpected GetAdaptor() error, %s", err) - } - if _, err := a.Client(); err != nil { - t.Errorf("unexpected Client() error, %s", err) - } - rerr := adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} - if _, err := a.Reader(); err != rerr { - t.Errorf("wrong Reader() error, expected %s, got %s", rerr, err) - } - if _, err := a.Writer(nil, nil); err != nil { - t.Errorf("unexpected Writer() error, %s", err) + for _, it := range initTests { + a, err := function.GetFunction("pretty", it.in) + if err != nil { + t.Fatalf("unexpected GetFunction() error, %s", err) + } + if !reflect.DeepEqual(a, it.expect) { + t.Errorf("misconfigured Function, expected %+v, got %+v", it.expect, a) + } } } @@ -52,9 +54,9 @@ var prettyTests = []struct { }, } -func TestPretty(t *testing.T) { +func TestApply(t *testing.T) { for _, pt := range prettyTests { - msg, err := pt.p.Write(message.From(ops.Insert, "test", pt.data))(nil) + msg, err := pt.p.Apply(message.From(ops.Insert, "test", pt.data)) if err != nil { t.Errorf("unexpected error, got %s", err) } diff --git a/function/registry.go b/function/registry.go new file mode 100644 index 000000000..596d94792 --- /dev/null +++ b/function/registry.go @@ -0,0 +1,53 @@ +package function + +import ( + "encoding/json" + "fmt" +) + +// ErrNotFound gives the details of the failed function +type ErrNotFound struct { + Name string +} + +func (a ErrNotFound) Error() string { + return fmt.Sprintf("function '%s' not found in registry", a.Name) +} + +// Creator defines the init structure for a Function. +type Creator func() Function + +var functions = map[string]Creator{} + +// Add should be called in init func of an implementing Function. +func Add(name string, creator Creator) { + functions[name] = creator +} + +// GetFunction looks up a function by name and then init's it with the provided map. +// returns ErrNotFound if the provided name was not registered. +func GetFunction(name string, conf map[string]interface{}) (Function, error) { + creator, ok := functions[name] + if ok { + a := creator() + b, err := json.Marshal(conf) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, a) + if err != nil { + return nil, err + } + return a, nil + } + return nil, ErrNotFound{name} +} + +// RegisteredFunctions returns a slice of the names of every function registered. +func RegisteredFunctions() []string { + all := make([]string, 0) + for i := range functions { + all = append(all, i) + } + return all +} diff --git a/function/rename/README.md b/function/rename/README.md new file mode 100644 index 000000000..dba55eaf3 --- /dev/null +++ b/function/rename/README.md @@ -0,0 +1,36 @@ +# rename function + +`rename()` will update the replace existing key names with new ones based on the provided configuration. It currently only works for top level fields (i.e. `address.street` would not work). + +### configuration + +```javascript +rename({"field_map": {"test":"renamed"}}) +``` + +### example + +message in +```JSON +{ + "_id": 0, + "name": "transporter", + "type": "function", + "count": 10 +} +``` + +config +```javascript +rename({"field_map": {"count":"total"}}) +``` + +message out +```JSON +{ + "_id": 0, + "name": "transporter", + "type": "function", + "total": 10 +} +``` \ No newline at end of file diff --git a/function/rename/rename.go b/function/rename/rename.go new file mode 100644 index 000000000..4fee2ad1d --- /dev/null +++ b/function/rename/rename.go @@ -0,0 +1,34 @@ +package rename + +import ( + "github.com/compose/transporter/function" + "github.com/compose/transporter/message" +) + +var ( + _ function.Function = &Rename{} +) + +func init() { + function.Add( + "rename", + func() function.Function { + return &Rename{} + }, + ) +} + +// Rename swaps out the field names based on the provided config +type Rename struct { + SwapMap map[string]string `json:"field_map"` +} + +func (r *Rename) Apply(msg message.Msg) (message.Msg, error) { + for oldName, newName := range r.SwapMap { + if val, ok := msg.Data().AsMap()[oldName]; ok { + msg.Data().Set(newName, val) + msg.Data().Delete(oldName) + } + } + return msg, nil +} diff --git a/function/rename/rename_test.go b/function/rename/rename_test.go new file mode 100644 index 000000000..4128d56c9 --- /dev/null +++ b/function/rename/rename_test.go @@ -0,0 +1,76 @@ +package rename + +import ( + "reflect" + "testing" + + "github.com/compose/transporter/function" + _ "github.com/compose/transporter/log" + "github.com/compose/transporter/message" + "github.com/compose/transporter/message/ops" +) + +var initTests = []struct { + in map[string]interface{} + expect *Rename +}{ + { + map[string]interface{}{"field_map": map[string]string{"test": "newtest"}}, + &Rename{SwapMap: map[string]string{"test": "newtest"}}, + }, +} + +func TestInit(t *testing.T) { + for _, it := range initTests { + a, err := function.GetFunction("rename", it.in) + if err != nil { + t.Fatalf("unexpected GetFunction() error, %s", err) + } + if !reflect.DeepEqual(a, it.expect) { + t.Errorf("misconfigured Function, expected %+v, got %+v", it.expect, a) + } + } +} + +var renameTests = []struct { + name string + fieldMap map[string]string + in map[string]interface{} + out map[string]interface{} + err error +}{ + { + "single field", + map[string]string{"type": "expression"}, + map[string]interface{}{"_id": "blah", "type": "good"}, + map[string]interface{}{"_id": "blah", "expression": "good"}, + nil, + }, + { + "multiple fields", + map[string]string{"_id": "id", "name": "n"}, + map[string]interface{}{"_id": "blah", "type": "good", "name": "hello"}, + map[string]interface{}{"id": "blah", "type": "good", "n": "hello"}, + nil, + }, + { + "no matched fields", + map[string]string{"name": "n"}, + map[string]interface{}{"_id": "blah", "type": "good"}, + map[string]interface{}{"_id": "blah", "type": "good"}, + nil, + }, +} + +func TestApply(t *testing.T) { + for _, rt := range renameTests { + rename := &Rename{rt.fieldMap} + msg, err := rename.Apply(message.From(ops.Insert, "test", rt.in)) + if !reflect.DeepEqual(err, rt.err) { + t.Errorf("[%s] error mismatch, expected %s, got %s", rt.name, rt.err, err) + } + if !reflect.DeepEqual(msg.Data().AsMap(), rt.out) { + t.Errorf("[%s] wrong message, expected %+v, got %+v", rt.name, rt.out, msg.Data().AsMap()) + } + } +} diff --git a/adaptor/function/skip/README.md b/function/skip/README.md similarity index 100% rename from adaptor/function/skip/README.md rename to function/skip/README.md diff --git a/function/skip/skipper.go b/function/skip/skipper.go new file mode 100644 index 000000000..adebb9a5f --- /dev/null +++ b/function/skip/skipper.go @@ -0,0 +1,111 @@ +package skip + +import ( + "fmt" + "math" + "reflect" + "regexp" + "strconv" + + "github.com/compose/transporter/function" + "github.com/compose/transporter/message" +) + +type UnknownOperatorError struct { + Op string +} + +func (e UnknownOperatorError) Error() string { + return fmt.Sprintf("unkown operator, %s", e.Op) +} + +type WrongTypeError struct { + Wanted string + Got string +} + +func (e WrongTypeError) Error() string { + return fmt.Sprintf("value is of incompatible type, wanted %s, got %s", e.Wanted, e.Got) +} + +func init() { + function.Add( + "skip", + func() function.Function { + return &Skip{} + }, + ) +} + +type Skip struct { + Field string `json:"field"` + Operator string `json:"operator"` + Match interface{} `json:"match"` +} + +func (s *Skip) Apply(msg message.Msg) (message.Msg, error) { + val := msg.Data().Get(s.Field) + switch s.Operator { + case "==", "eq", "$eq": + if reflect.DeepEqual(val, s.Match) { + return msg, nil + } + case "=~": + if ok, err := regexp.MatchString(s.Match.(string), val.(string)); err != nil || ok { + return msg, err + } + case ">", "gt", "$gt": + v, m, err := convertForComparison(val, s.Match) + if err == nil && v > m { + return msg, err + } + return nil, err + case ">=", "gte", "$gte": + v, m, err := convertForComparison(val, s.Match) + if err == nil && v >= m { + return msg, err + } + return nil, err + case "<", "lt", "$lt": + v, m, err := convertForComparison(val, s.Match) + if err == nil && v < m { + return msg, err + } + return nil, err + case "<=", "lte", "$lte": + v, m, err := convertForComparison(val, s.Match) + if err == nil && v <= m { + return msg, err + } + return nil, err + default: + return nil, UnknownOperatorError{s.Operator} + } + return nil, nil +} + +func convertForComparison(in1, in2 interface{}) (float64, float64, error) { + float1, err := convertToFloat(in1) + if err != nil { + return math.NaN(), math.NaN(), err + } + float2, err := convertToFloat(in2) + if err != nil { + return math.NaN(), math.NaN(), err + } + return float1, float2, nil +} + +func convertToFloat(in interface{}) (float64, error) { + switch i := in.(type) { + case float64: + return i, nil + case int: + return float64(i), nil + case string: + return strconv.ParseFloat(i, 0) + default: + return math.NaN(), WrongTypeError{"float64 or int", fmt.Sprintf("%T", i)} + } + +} diff --git a/adaptor/function/skip/skipper_test.go b/function/skip/skipper_test.go similarity index 86% rename from adaptor/function/skip/skipper_test.go rename to function/skip/skipper_test.go index e782b9eeb..a700c5566 100644 --- a/adaptor/function/skip/skipper_test.go +++ b/function/skip/skipper_test.go @@ -5,7 +5,7 @@ import ( "strconv" "testing" - "github.com/compose/transporter/adaptor" + "github.com/compose/transporter/function" _ "github.com/compose/transporter/log" "github.com/compose/transporter/message" "github.com/compose/transporter/message/ops" @@ -36,25 +36,24 @@ func TestErrors(t *testing.T) { } } -var initTests = []map[string]interface{}{ - {"field": "test", "operator": "==", "match": 10}, +var initTests = []struct { + in map[string]interface{} + expect *Skip +}{ + { + map[string]interface{}{"field": "test", "operator": "==", "match": 10}, + &Skip{Field: "test", Operator: "==", Match: float64(10)}, + }, } func TestInit(t *testing.T) { for _, it := range initTests { - a, err := adaptor.GetAdaptor("skip", it) + a, err := function.GetFunction("skip", it.in) if err != nil { - t.Fatalf("unexpected GetAdaptor() error, %s", err) - } - if _, err := a.Client(); err != nil { - t.Errorf("unexpected Client() error, %s", err) - } - rerr := adaptor.ErrFuncNotSupported{Name: "transformer", Func: "Reader()"} - if _, err := a.Reader(); err != rerr { - t.Errorf("wrong Reader() error, expected %s, got %s", rerr, err) + t.Fatalf("unexpected GetFunction() error, %s", err) } - if _, err := a.Writer(nil, nil); err != nil { - t.Errorf("unexpected Writer() error, %s", err) + if !reflect.DeepEqual(a, it.expect) { + t.Errorf("misconfigured Function, expected %+v, got %+v", it.expect, a) } } } @@ -145,11 +144,11 @@ var skipTests = []struct { }, } -func TestSkip(t *testing.T) { +func TestApply(t *testing.T) { for _, st := range skipTests { for _, op := range st.operators { skip := &Skip{st.field, op, st.match} - msg, err := skip.Write(message.From(ops.Insert, "test", st.data))(nil) + msg, err := skip.Apply(message.From(ops.Insert, "test", st.data)) if !reflect.DeepEqual(err, st.err) { t.Errorf("[%s %s] error mismatch, expected %s, got %s", op, st.name, st.err, err) } diff --git a/function/testing.go b/function/testing.go new file mode 100644 index 000000000..978860643 --- /dev/null +++ b/function/testing.go @@ -0,0 +1,21 @@ +package function + +import ( + "github.com/compose/transporter/log" + "github.com/compose/transporter/message" +) + +var ( + _ Function = &Mock{} +) + +type Mock struct { + ApplyCount int + Err error +} + +func (m *Mock) Apply(msg message.Msg) (message.Msg, error) { + m.ApplyCount++ + log.With("apply_count", m.ApplyCount).With("err", m.Err).Debugln("applying...") + return msg, m.Err +} diff --git a/integration_tests/mongo_to_es/app.js b/integration_tests/mongo_to_es/app.js index b95fc6156..dc0f3fb4d 100644 --- a/integration_tests/mongo_to_es/app.js +++ b/integration_tests/mongo_to_es/app.js @@ -1,12 +1,11 @@ enron_source_mongo = mongodb({ "uri": "mongodb://${MONGODB_ENRON_SOURCE_USER}:${MONGODB_ENRON_SOURCE_PASSWORD}@${MONGODB_ENRON_SOURCE_URI}/enron", - "tail": false, - "namespace": "enron.emails" + "tail": false }) enron_sink_es = elasticsearch({ - "uri": "https://${ES_ENRON_SINK_USER}:${ES_ENRON_SINK_PASSWORD}@${ES_ENRON_SINK_URI}", - "namespace": "enron.emails" + "uri": "https://${ES_ENRON_SINK_USER}:${ES_ENRON_SINK_PASSWORD}@${ES_ENRON_SINK_URI}" }) -t.Source(enron_source_mongo).Save(enron_sink_es); +t.Source("enron_source_mongo", enron_source_mongo, "enron.emails") + .Save("enron_sink_es", enron_sink_es, "enron.emails"); diff --git a/integration_tests/mongo_to_mongo/app.js b/integration_tests/mongo_to_mongo/app.js index 48426261d..1eab2b65c 100644 --- a/integration_tests/mongo_to_mongo/app.js +++ b/integration_tests/mongo_to_mongo/app.js @@ -1,7 +1,6 @@ enron_source_mongo = mongodb({ "uri": "mongodb://${MONGODB_ENRON_SOURCE_USER}:${MONGODB_ENRON_SOURCE_PASSWORD}@${MONGODB_ENRON_SOURCE_URI}/enron", - "tail": false, - "namespace": "enron.emails" + "tail": false }) enron_sink_mongo = mongodb({ @@ -9,8 +8,8 @@ enron_sink_mongo = mongodb({ "ssl": true, "bulk": true, "wc": 2, - "fsync": true, - "namespace": "enron.emails" + "fsync": true }) -t.Source(enron_source_mongo).Save(enron_sink_mongo); +t.Source("enron_source_mongo", enron_source_mongo, "enron.emails") + .Save("enron_sink_mongo", enron_sink_mongo, "enron.emails"); diff --git a/integration_tests/mongo_to_rethink/app.js b/integration_tests/mongo_to_rethink/app.js index 7b3f0f70d..02719dd91 100644 --- a/integration_tests/mongo_to_rethink/app.js +++ b/integration_tests/mongo_to_rethink/app.js @@ -1,13 +1,12 @@ enron_source_mongo = mongodb({ "uri": "mongodb://${MONGODB_ENRON_SOURCE_USER}:${MONGODB_ENRON_SOURCE_PASSWORD}@${MONGODB_ENRON_SOURCE_URI}/enron", - "tail": false, - "namespace": "enron.emails" + "tail": false }) enron_sink_rethink = rethinkdb({ "uri": "rethink://admin:${RETHINKDB_ENRON_SINK_PASSWORD}@${RETHINKDB_ENRON_SINK_URI}/enron", - "ssl": true, - "namespace": "enron.emails" + "ssl": true }) -t.Source(enron_source_mongo).Save(enron_sink_rethink); +t.Source("enron_source_mongo", enron_source_mongo, "enron.emails") + .Save("enron_sink_rethink", enron_sink_rethink, "enron.emails"); diff --git a/integration_tests/rethink_to_postgres/app.js b/integration_tests/rethink_to_postgres/app.js index 232d5538a..5c80e0c66 100644 --- a/integration_tests/rethink_to_postgres/app.js +++ b/integration_tests/rethink_to_postgres/app.js @@ -1,12 +1,11 @@ enron_source_rethink = rethinkdb({ "uri": "rethink://admin:${RETHINKDB_ENRON_SOURCE_PASSWORD}@${RETHINKDB_ENRON_SOURCE_URI}/enron", - "ssl": true, - namespace: "enron.emails" + "ssl": true }) enron_sink_postgres = postgres({ - "uri": "postgres://${POSTGRES_ENRON_SINK_USER}:${POSTGRES_ENRON_SINK_PASSWORD}@${POSTGRES_ENRON_SINK_URI}", - "namespace": "enron.emails" + "uri": "postgres://${POSTGRES_ENRON_SINK_USER}:${POSTGRES_ENRON_SINK_PASSWORD}@${POSTGRES_ENRON_SINK_URI}" }) -t.Source(enron_source_rethink).Save(enron_sink_postgres); +t.Source("enron_source_rethink", enron_source_rethink, "enron.emails") + .Save("enron_sink_postgres", enron_sink_postgres, "enron.emails"); diff --git a/pipe/pipe.go b/pipe/pipe.go index 045d6891b..df8499973 100644 --- a/pipe/pipe.go +++ b/pipe/pipe.go @@ -90,7 +90,7 @@ func (m *Pipe) Listen(fn func(message.Msg) (message.Msg, error), nsFilter *regex m.Err <- err return err } - if skipMsg(outmsg) { + if outmsg == nil { break } if len(m.Out) > 0 { @@ -121,13 +121,13 @@ func (m *Pipe) Stop() { // Send emits the given message on the 'Out' channel. the send Timesout after 100 ms in order to chaeck of the Pipe has stopped and we've been asked to exit. // If the Pipe has been stopped, the send will fail and there is no guarantee of either success or failure func (m *Pipe) Send(msg message.Msg) { + m.MessageCount++ for _, ch := range m.Out { A: for { select { case ch <- msg: - m.MessageCount++ m.LastMsg = msg break A case <-time.After(100 * time.Millisecond): @@ -139,8 +139,3 @@ func (m *Pipe) Send(msg message.Msg) { } } } - -// skipMsg returns true if the message should be skipped and not send on to any listening nodes -func skipMsg(msg message.Msg) bool { - return msg == nil -} diff --git a/pipeline/node.go b/pipeline/node.go index bf9a9f128..a9d4efcc4 100644 --- a/pipeline/node.go +++ b/pipeline/node.go @@ -13,8 +13,10 @@ import ( "github.com/compose/transporter/adaptor" "github.com/compose/transporter/client" + "github.com/compose/transporter/function" "github.com/compose/transporter/log" "github.com/compose/transporter/message" + "github.com/compose/transporter/message/ops" "github.com/compose/transporter/pipe" ) @@ -33,60 +35,85 @@ var ( // source.Add(sink2) // type Node struct { - Name string `json:"name"` // the name of this node - Type string `json:"type"` // the node's type, used to create the adaptorementation - Extra adaptor.Config `json:"extra"` // extra config options that are passed to the adaptorementation - Children []*Node `json:"children"` // the nodes are set up as a tree, this is an array of this nodes children - Parent *Node `json:"parent"` // this node's parent node, if this is nil, this is a 'source' node + Name string `json:"name"` // the name of this node + Type string `json:"type"` // the node's type, used to create the adaptorementation + Children []*Node `json:"children"` // the nodes are set up as a tree, this is an array of this nodes children + Parent *Node `json:"parent"` // this node's parent node, if this is nil, this is a 'source' node + Transforms []*Transform nsFilter *regexp.Regexp c client.Client - r client.Reader - w client.Writer + reader client.Reader + writer client.Writer done chan struct{} wg sync.WaitGroup l log.Logger pipe *pipe.Pipe } +type Transform struct { + Name string + Fn function.Function + NsFilter *regexp.Regexp +} + // NewNode creates a new Node struct -func NewNode(name, kind string, extra adaptor.Config) *Node { - return &Node{ - Name: name, - Type: kind, - Extra: extra, - Children: make([]*Node, 0), - done: make(chan struct{}), +func NewNode(name, kind, ns string, a adaptor.Adaptor, parent *Node) (*Node, error) { + _, nsFilter, err := adaptor.CompileNamespace(ns) + if err != nil { + return nil, err + } + n := &Node{ + Name: name, + Type: kind, + nsFilter: nsFilter, + Children: make([]*Node, 0), + Transforms: make([]*Transform, 0), + done: make(chan struct{}), + } + + n.c, err = a.Client() + if err != nil { + return nil, err + } + + if parent == nil { + // TODO: remove path param + n.pipe = pipe.NewPipe(nil, "") + n.reader, err = a.Reader() + if err != nil { + return nil, err + } + } else { + n.Parent = parent + // TODO: remove path param + n.pipe = pipe.NewPipe(parent.pipe, "") + parent.Children = append(parent.Children, n) + n.writer, err = a.Writer(n.done, &n.wg) + if err != nil { + return nil, err + } } + + return n, nil } // String func (n *Node) String() string { var ( - uri string - s string - prefix string - namespace = n.Extra.GetString("namespace") + s, prefix string depth = n.depth() ) - if n.Type == transformerNode { - uri = n.Extra.GetString("filename") - } else { - uri = n.Extra.GetString("uri") - } prefixformatter := fmt.Sprintf("%%%ds%%-%ds", depth, 18-depth) if n.Parent == nil { // root node - // s = fmt.Sprintf("%18s %-40s %-15s %-30s %s\n", " ", "Name", "Type", "Namespace", "URI") prefix = fmt.Sprintf(prefixformatter, " ", "- Source: ") - } else if len(n.Children) == 0 { + } else { prefix = fmt.Sprintf(prefixformatter, " ", "- Sink: ") - } else if n.Type == transformerNode { - prefix = fmt.Sprintf(prefixformatter, " ", "- Transformer: ") } - s += fmt.Sprintf("%-18s %-40s %-15s %-30s %s", prefix, n.Name, n.Type, namespace, uri) + s += fmt.Sprintf("%s %-40s %-15s %-30s", prefix, n.Name, n.Type, n.nsFilter.String()) for _, child := range n.Children { s += "\n" + child.String() @@ -117,58 +144,14 @@ func (n *Node) Path() string { return n.Parent.Path() + "/" + n.Name } -// Add the given node as a child of this node. -// This has side effects, and sets the parent of the given node -func (n *Node) Add(node *Node) *Node { - node.Parent = n - n.Children = append(n.Children, node) - return n -} - -// Init sets up the node for action. It creates a pipe and adaptor for this node, -// and then recurses down the tree calling Init on each child -func (n *Node) Init() (err error) { - path := n.Path() - - n.l = log.With("name", n.Name).With("type", n.Type).With("path", path) - - _, nsFilter, err := adaptor.CompileNamespace(n.Extra.GetString("namespace")) +// AddTransform adds the provided function.Function to the Node and will be called +// before sending any messages down the pipeline. +func (n *Node) AddTransform(name string, f function.Function, ns string) error { + _, nsFilter, err := adaptor.CompileNamespace(ns) if err != nil { return err } - n.nsFilter = nsFilter - - a, err := adaptor.GetAdaptor(n.Type, n.Extra) - if err != nil { - return err - } - - n.c, err = a.Client() - if err != nil { - return err - } - - if n.Parent == nil { // we don't have a parent, we're the source - n.pipe = pipe.NewPipe(nil, path) - n.r, err = a.Reader() - if err != nil { - return err - } - } else { // we have a parent, so pass in the parent's pipe here - n.pipe = pipe.NewPipe(n.Parent.pipe, path) - n.w, err = a.Writer(n.done, &n.wg) - if err != nil { - return err - } - } - - for _, child := range n.Children { - err = child.Init() // init each child - if err != nil { - return err - } - } - + n.Transforms = append(n.Transforms, &Transform{name, f, nsFilter}) return nil } @@ -177,6 +160,9 @@ func (n *Node) Init() (err error) { // and will emit messages to it's children, // All descendant nodes run Listen() on the adaptor func (n *Node) Start() error { + path := n.Path() + n.l = log.With("name", n.Name).With("type", n.Type).With("path", path) + for _, child := range n.Children { go func(node *Node) { node.Start() @@ -205,7 +191,7 @@ func (n *Node) start() (err error) { n.l.Infoln("session closed...") }() } - readFunc := n.r.Read(func(check string) bool { return n.nsFilter.MatchString(check) }) + readFunc := n.reader.Read(func(check string) bool { return n.nsFilter.MatchString(check) }) msgChan, err := readFunc(s, n.done) if err != nil { return err @@ -226,6 +212,12 @@ func (n *Node) listen() (err error) { } func (n *Node) write(msg message.Msg) (message.Msg, error) { + transformedMsg, err := n.applyTransforms(msg) + if err != nil { + return msg, nil + } else if transformedMsg == nil { + return nil, nil + } sess, err := n.c.Connect() if err != nil { return msg, err @@ -235,8 +227,7 @@ func (n *Node) write(msg message.Msg) (message.Msg, error) { s.Close() } }() - msg, err = n.w.Write(message.From(msg.OP(), msg.Namespace(), msg.Data()))(sess) - + returnMsg, err := n.writer.Write(transformedMsg)(sess) if err != nil { n.pipe.Err <- adaptor.Error{ Lvl: adaptor.ERROR, @@ -245,7 +236,32 @@ func (n *Node) write(msg message.Msg) (message.Msg, error) { Record: msg.Data, } } - return msg, err + return returnMsg, err +} + +func (n *Node) applyTransforms(msg message.Msg) (message.Msg, error) { + if msg.OP() != ops.Command { + for _, transform := range n.Transforms { + if !transform.NsFilter.MatchString(msg.Namespace()) { + n.l.With("transform", transform.Name).With("ns", msg.Namespace()).Infoln("filtered message") + continue + } + m, err := transform.Fn.Apply(msg) + if err != nil { + n.l.Errorf("transform function error, %s", err) + return nil, err + } else if m == nil { + n.l.With("transform", transform.Name).Infoln("returned nil message, skipping") + return nil, nil + } + msg = m + if msg.OP() == ops.Skip { + n.l.With("transform", transform.Name).With("op", msg.OP()).Infoln("skipping message") + return nil, nil + } + } + } + return msg, nil } // Stop this node's adaptor, and sends a stop to each child of this node @@ -263,7 +279,7 @@ func (n *Node) stop() error { close(n.done) n.wg.Wait() - if closer, ok := n.w.(client.Closer); ok { + if closer, ok := n.writer.(client.Closer); ok { defer func() { n.l.Infoln("closing writer...") closer.Close() @@ -291,10 +307,6 @@ func (n *Node) Validate() bool { return false } - if n.Type == transformerNode && len(n.Children) == 0 { // transformers need children - return false - } - for _, child := range n.Children { if !child.Validate() { return false diff --git a/pipeline/node_test.go b/pipeline/node_test.go index 34adcbb74..12084fe64 100644 --- a/pipeline/node_test.go +++ b/pipeline/node_test.go @@ -1,26 +1,34 @@ package pipeline import ( + "errors" + "regexp" "sync" "testing" + "time" "github.com/compose/transporter/adaptor" "github.com/compose/transporter/client" + "github.com/compose/transporter/function" "github.com/compose/transporter/message" + "github.com/compose/transporter/message/ops" + "github.com/compose/transporter/pipe" ) +var DefaultNS = regexp.MustCompile(".*") + func TestNodeString(t *testing.T) { data := []struct { in *Node out string }{ { - &Node{}, - " - Source: ", - }, - { - NewNode("name", "mongo", adaptor.Config{"uri": "uri", "namespace": "db.col", "debug": false}), - " - Source: name mongo db.col uri", + &Node{ + Name: "name", + Type: "mongodb", + nsFilter: DefaultNS, + }, + " - Source: name mongodb .* ", }, } @@ -37,24 +45,23 @@ func TestValidate(t *testing.T) { out bool }{ { - NewNode("first", "mongo", adaptor.Config{}), - false, - }, - { - NewNode("second", "mongo", adaptor.Config{}).Add(NewNode("name", "mongo", adaptor.Config{})), - true, - }, - { - NewNode("third", "mongo", adaptor.Config{}).Add(NewNode("name", "transformer", adaptor.Config{})), + &Node{Name: "first", Type: "mongodb", nsFilter: DefaultNS, Parent: nil}, false, }, { - NewNode("fourth", "mongo", adaptor.Config{}).Add(NewNode("name", "transformer", adaptor.Config{}).Add(NewNode("name", "mongo", adaptor.Config{}))), + &Node{Name: "first", Type: "mongodb", nsFilter: DefaultNS, Parent: nil, + Children: []*Node{ + &Node{Name: "second", Type: "mongodb", nsFilter: DefaultNS}, + }, + }, true, }, } for _, v := range data { + for _, child := range v.in.Children { + child.Parent = v.in + } if v.in.Validate() != v.out { t.Errorf("%s: expected: %t got: %t", v.in.Name, v.out, v.in.Validate()) } @@ -67,15 +74,15 @@ func TestPath(t *testing.T) { out string }{ { - NewNode("first", "mongo", adaptor.Config{}), + &Node{Name: "first", Type: "mongodb", nsFilter: DefaultNS}, "first", }, { - NewNode("first", "mongo", adaptor.Config{}).Add(NewNode("second", "mongo", adaptor.Config{})), + &Node{Name: "second", Type: "mongodb", nsFilter: DefaultNS, Parent: &Node{Name: "first", Type: "mongodb", nsFilter: DefaultNS}}, "first/second", }, { - NewNode("first", "mongo", adaptor.Config{}).Add(NewNode("second", "transformer", adaptor.Config{}).Add(NewNode("third", "mongo", adaptor.Config{}))), + &Node{Name: "third", Type: "mongodb", nsFilter: DefaultNS, Parent: &Node{Name: "second", Type: "transformer", nsFilter: DefaultNS, Parent: &Node{Name: "first", Type: "mongodb", nsFilter: DefaultNS}}}, "first/second/third", }, } @@ -103,7 +110,8 @@ func init() { } type StopWriter struct { - Closed bool + MsgCount int + Closed bool } func (s *StopWriter) Client() (client.Client, error) { @@ -111,7 +119,7 @@ func (s *StopWriter) Client() (client.Client, error) { } func (s *StopWriter) Reader() (client.Reader, error) { - return &client.MockReader{}, nil + return &client.MockReader{MsgCount: 10}, nil } func (s *StopWriter) Writer(done chan struct{}, wg *sync.WaitGroup) (client.Writer, error) { @@ -120,6 +128,7 @@ func (s *StopWriter) Writer(done chan struct{}, wg *sync.WaitGroup) (client.Writ func (s *StopWriter) Write(msg message.Msg) func(client.Session) (message.Msg, error) { return func(client.Session) (message.Msg, error) { + s.MsgCount++ return msg, nil } } @@ -128,47 +137,182 @@ func (s *StopWriter) Close() { s.Closed = true } -var stopTests = []struct { - node *Node -}{ - { - &Node{ - Name: "starter", - Type: "stopWriter", - Extra: map[string]interface{}{"namespace": "test.test"}, - Children: []*Node{ - &Node{ - Name: "stopper", - Type: "stopWriter", - Extra: map[string]interface{}{"namespace": "test.test"}, - Children: nil, - Parent: nil, - done: make(chan struct{}), +type SkipFunc struct { + UsingOp bool +} + +func (s *SkipFunc) Apply(msg message.Msg) (message.Msg, error) { + if s.UsingOp { + return message.From(ops.Skip, msg.Namespace(), msg.Data()), nil + } + return nil, nil +} + +var ( + stopTests = []struct { + node *Node + msgCount int + applyCount int + }{ + { + &Node{ + Name: "starter", + Type: "stopWriter", + nsFilter: DefaultNS, + Children: []*Node{ + &Node{ + Name: "stopper", + Type: "stopWriter", + nsFilter: DefaultNS, + done: make(chan struct{}), + }, }, + Parent: nil, + done: make(chan struct{}), + pipe: pipe.NewPipe(nil, "starter"), }, - Parent: nil, - done: make(chan struct{}), + 10, + 0, }, - }, -} + { + &Node{ + Name: "starter", + Type: "stopWriter", + nsFilter: DefaultNS, + Children: []*Node{ + &Node{ + Name: "stopper", + Type: "stopWriter", + nsFilter: DefaultNS, + done: make(chan struct{}), + Transforms: []*Transform{&Transform{"mock", &function.Mock{}, DefaultNS}}, + }, + }, + Parent: nil, + done: make(chan struct{}), + pipe: pipe.NewPipe(nil, "starter"), + }, + 10, + 10, + }, + { + &Node{ + Name: "starter", + Type: "stopWriter", + nsFilter: DefaultNS, + Children: []*Node{ + &Node{ + Name: "stopper", + Type: "stopWriter", + nsFilter: DefaultNS, + done: make(chan struct{}), + Transforms: []*Transform{&Transform{"mock", &function.Mock{}, regexp.MustCompile("blah")}}, + }, + }, + Parent: nil, + done: make(chan struct{}), + pipe: pipe.NewPipe(nil, "starter"), + }, + 10, + 0, + }, + { + &Node{ + Name: "starter", + Type: "stopWriter", + nsFilter: DefaultNS, + Children: []*Node{ + &Node{ + Name: "stopper", + Type: "stopWriter", + nsFilter: DefaultNS, + done: make(chan struct{}), + Transforms: []*Transform{&Transform{"mock", &function.Mock{Err: errors.New("apply failed")}, DefaultNS}}, + }, + }, + Parent: nil, + done: make(chan struct{}), + pipe: pipe.NewPipe(nil, "starter"), + }, + 0, + 10, + }, + { + &Node{ + Name: "starter", + Type: "stopWriter", + nsFilter: DefaultNS, + Children: []*Node{ + &Node{ + Name: "stopper", + Type: "stopWriter", + nsFilter: DefaultNS, + done: make(chan struct{}), + Transforms: []*Transform{&Transform{"mock", &SkipFunc{}, DefaultNS}}, + }, + }, + Parent: nil, + done: make(chan struct{}), + pipe: pipe.NewPipe(nil, "starter"), + }, + 0, + 10, + }, + { + &Node{ + Name: "starter", + Type: "stopWriter", + nsFilter: DefaultNS, + Children: []*Node{ + &Node{ + Name: "stopper", + Type: "stopWriter", + nsFilter: DefaultNS, + done: make(chan struct{}), + Transforms: []*Transform{&Transform{"mock", &SkipFunc{UsingOp: true}, DefaultNS}}, + }, + }, + Parent: nil, + done: make(chan struct{}), + pipe: pipe.NewPipe(nil, "starter"), + }, + 0, + 10, + }, + } +) func TestStop(t *testing.T) { for _, st := range stopTests { + s := &StopWriter{} + st.node.c, _ = s.Client() for _, child := range st.node.Children { + child.c, _ = s.Client() + child.writer, _ = s.Writer(child.done, &child.wg) + child.pipe = pipe.NewPipe(st.node.pipe, "stopper") child.Parent = st.node } - - if err := st.node.Init(); err != nil { - t.Errorf("unexpected Init() error, %s", err) - } + st.node.reader, _ = s.Reader() if err := st.node.Start(); err != nil { t.Errorf("unexpected Start() error, %s", err) } + time.Sleep(1 * time.Second) st.node.Stop() for _, child := range st.node.Children { - if !child.w.(*StopWriter).Closed { + if !s.Closed { t.Errorf("[%s] child node was not closed but should have been", child.Name) } } + if st.msgCount != s.MsgCount { + t.Errorf("wrong number of messages received, expected %d, got %d", st.msgCount, s.MsgCount) + } + if len(st.node.Children[0].Transforms) > 0 { + switch mock := st.node.Children[0].Transforms[0].Fn.(type) { + case *function.Mock: + if mock.ApplyCount != st.applyCount { + t.Errorf("wrong number of transforms applied, expected %d, got %d", st.applyCount, mock.ApplyCount) + } + } + } } } diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index a1522ee58..eb479e749 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -66,12 +66,6 @@ func NewPipeline(version string, source *Node, emit events.EmitFunc, interval ti pipeline.sessionTicker = time.NewTicker(sessionInterval) } - // init the pipeline - err := pipeline.source.Init() - if err != nil { - return pipeline, err - } - // init the emitter with the right chan pipeline.emitter = events.NewEmitter(source.pipe.Event, emit) diff --git a/pipeline/pipeline_events_integration_test.go b/pipeline/pipeline_events_integration_test.go index a4ef7c7d5..30dff6eba 100644 --- a/pipeline/pipeline_events_integration_test.go +++ b/pipeline/pipeline_events_integration_test.go @@ -63,8 +63,22 @@ func TestEventsBroadcast(t *testing.T) { setupFiles(inFile, outFile) // set up the nodes - dummyOutNode := NewNode("dummyFileOut", "file", adaptor.Config{"uri": "file://" + outFile, "namespace": "a./.*/"}) - dummyOutNode.Add(NewNode("dummyFileIn", "file", adaptor.Config{"uri": "file://" + inFile, "namespace": "a./.*/"})) + f, err := adaptor.GetAdaptor("file", adaptor.Config{"uri": "file://" + outFile}) + if err != nil { + t.Fatalf("can't create GetAdaptor, got %s", err) + } + dummyOutNode, err := NewNode("dummyFileOut", "file", "blah./.*/", f, nil) + if err != nil { + t.Fatalf("can't create NewNode, got %s", err) + } + f, err = adaptor.GetAdaptor("file", adaptor.Config{"uri": "file://" + inFile}) + if err != nil { + t.Fatalf("can't create GetAdaptor, got %s", err) + } + _, err = NewNode("dummyFileIn", "file", "blah./.*/", f, dummyOutNode) + if err != nil { + t.Fatalf("can't create NewNode, got %s", err) + } p, err := NewDefaultPipeline(dummyOutNode, ts.URL, "asdf", "jklm", "test", 1*time.Second) if err != nil { diff --git a/pipeline/pipeline_integration_test.go b/pipeline/pipeline_integration_test.go index cf564fa71..21394797e 100644 --- a/pipeline/pipeline_integration_test.go +++ b/pipeline/pipeline_integration_test.go @@ -40,8 +40,22 @@ func TestFileToFile(t *testing.T) { numgorosBefore := runtime.NumGoroutine() // create the source node and attach our sink - outNode := NewNode("localfileout", "file", adaptor.Config{"uri": "file://" + outFile, "namespace": "a./.*/"}). - Add(NewNode("localfilein", "file", adaptor.Config{"uri": "file://" + inFile, "namespace": "a./.*/"})) + f, err := adaptor.GetAdaptor("file", adaptor.Config{"uri": "file://" + outFile}) + if err != nil { + t.Fatalf("can't create GetAdaptor, got %s", err) + } + outNode, err := NewNode("localfileout", "file", "blah./.*/", f, nil) + if err != nil { + t.Fatalf("can't create newnode, got %s", err) + } + f, err = adaptor.GetAdaptor("file", adaptor.Config{"uri": "file://" + inFile}) + if err != nil { + t.Fatalf("can't create GetAdaptor, got %s", err) + } + _, err = NewNode("localfilein", "file", "blah./.*/", f, outNode) + if err != nil { + t.Fatalf("can't create newnode, got %s", err) + } // create the pipeline p, err := NewPipeline("test", outNode, events.LogEmitter(), 60*time.Second, nil, 10*time.Second) diff --git a/pipeline/pipeline_test.go b/pipeline/pipeline_test.go index af5ba91b3..0879e5891 100644 --- a/pipeline/pipeline_test.go +++ b/pipeline/pipeline_test.go @@ -1,16 +1,13 @@ package pipeline import ( + "regexp" "testing" "time" "github.com/compose/transporter/adaptor" _ "github.com/compose/transporter/adaptor/file" -) - -var ( - fakesourceCN = NewNode("source1", "source", adaptor.Config{"value": "rockettes", "namespace": "a./.*/"}) - fileNode = NewNode("localfile", "file", adaptor.Config{"uri": "file:///tmp/foo", "namespace": "a./.*/"}) + "github.com/compose/transporter/pipe" ) // a noop node adaptor to help test @@ -42,28 +39,27 @@ func TestPipelineString(t *testing.T) { out string }{ { - fakesourceCN, - nil, - " - Source: source1 source a./.*/ ", - }, - { - fakesourceCN, - fileNode, - " - Source: source1 source a./.*/ \n - Sink: localfile file a./.*/ file:///tmp/foo", + &Node{Name: "source1", Type: "source", nsFilter: regexp.MustCompile(".*"), pipe: pipe.NewPipe(nil, "source1")}, + &Node{Name: "localfile", Type: "file", nsFilter: regexp.MustCompile(".*")}, + ` - Source: source1 source .* + - Sink: localfile file .* `, }, } for _, v := range data { if v.terminalNode != nil { - v.in.Add(v.terminalNode) + v.terminalNode.Parent = v.in + v.terminalNode.pipe = pipe.NewPipe(v.in.pipe, "localfile") + v.in.Children = []*Node{v.terminalNode} } p, err := NewDefaultPipeline(v.in, "", "", "", "test", 100*time.Millisecond) if err != nil { t.Errorf("can't create pipeline, got %s", err.Error()) t.FailNow() } - if p.String() != v.out { - t.Errorf("\nexpected:\n%s\ngot:\n%s\n", v.out, p.String()) + actual := p.String() + if actual != v.out { + t.Errorf("\nexpected:\n%v\ngot:\n%v\n", v.out, actual) } } }