diff --git a/.travis.yml b/.travis.yml index f0750a2..8a215bb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,12 @@ addons: postgresql: '9.6' env: - - DB=postgresql - - DB=mysql + matrix: + - DB=postgresql + - DB=mysql + global: + - secure: Wvmf5FAySAGXuugkWg/lNQuPfTDA7PYlZB8Izt+bhaVZDp3Re0xz8rhomUw/ZXP/2r81HHPTvDU+eNHxpneyBeHvSXvvhzMVyU+GFLeXdHA0rjvKittkkKq5XuUkj5oaSYJIh0YRn2UdfMaQ/R+y8ZgL2uP41slG6jWpfuMeFe+oqvJCgJ7qovQ7xUataDQdGCyWCpsTlBdpGz+YFUHNrcMK3ppvm4WhzDeMGXHPuTPcWZ+edO8QVWJkxVXamuWgq95tFB3EAXx5tfcg4CS5cC/l/K2tMH7MTsXvWOCRSU+f79cSNlrztP1Mhc/TPBZe87ubmkLOQGRK//ZBqbmSXjsyjv4qkq3z223Dmg+MRkt6ip01zsSVE6vmEvtVHGxtj8VKfWbwWusAXuLa1gtzXa/vVgisF++aj1ZEhApH5mu95gMqy9tOY/mX4Pm4iqalLujvlzVHKDDCmZsyF2O8mZCL2cDH7xlQTLeq9eyRPXNcm13WDDAcOhgVSpWNVObggUMZ67YDGV0yPQKGQBOCCqV3jCKgii/4m7n1H2mn1b8VpYOdK3ncA3U3bcHDPy3Nh44EmdJnj7QiUZtN+RyQGCz83h3X2owGjNG3qvGxrBPy87DbQU+FTN88zB4DdgO/TOYEzg3Wbz24D3thZUDR8rWvu9t9/PvH+CroTIpW4Iw= + - secure: d5+t1ysBsHPGKvrbxsN4M+4ns4Q3YHgAk3gapaPNxlX3pDKLgZ1Kd/XBnQ3dMSCP7FWejgmeLDxloGQMeieq6exEM9kmf8Ui/oFB97j2+ODXPv6+X69+aDDGMT94J77v5DwdJzyULI4vn3ZF2V6fM5/tlle4hf4uwCBNh30Svar7ZXsp/aZL7wpZzUG+efLFJJZ4oSD2NteWR9IeZbNqQtV0fUTHLk3g4Suze6Yt2OmZiFH4pn8Ihftjyr84nke65aPws8XKzoaVmghoZXmKNPNyiH3Dyu6uyGa2OhMspj0zjXHP+sMlgMrp5cvAGchRBz9NPL18VHtgID/hEPLfwb6ZZ6HypgXu5+4/O/J2vGHTFV0q2aluNBJi+sTzklYkHoCiBXtjdZxI/+mwK1VIeMejrhSjNtzdkryEbOQM/CoBcDacQIw4gUOloCG4AcFNXzQZifHZckIgc4JChYmNHhpEu3U1Zil8H7QDQ7MGqJG+HWuK/s/leoJx112jHLmCgPsq9hmMozDdQi8rVSS8Eo6bg7NTOuE9LK/uoweF31NJYITYV3vxmqluwT7+yn7V9avXRFTV1qTcavsjAYkMS3PlLX4f+w1HtZUIOjtdABcK5mg1JjZ92xHmNFudu8b2n6j3A3nOYuA0NsX5BlCuIG84MozNsMleOl6svbRKW+4= services: - postgresql diff --git a/README.md b/README.md index a8ea46c..44df42c 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,28 @@ # migrator [![Build Status](https://travis-ci.org/lukaszbudnik/migrator.svg?branch=master)](https://travis-ci.org/lukaszbudnik/migrator) [![Go Report Card](https://goreportcard.com/badge/github.com/lukaszbudnik/migrator)](https://goreportcard.com/report/github.com/lukaszbudnik/migrator) [![codecov](https://codecov.io/gh/lukaszbudnik/migrator/branch/master/graph/badge.svg)](https://codecov.io/gh/lukaszbudnik/migrator) -Super fast and lightweight DB migration tool written in go. +Super fast and lightweight DB migration tool written in go. migrator consumes 6MB of memory and outperforms other DB migration/evolution frameworks by a few orders of magnitude. -migrator manages and versions all the DB changes for you and completely eliminates manual and error-prone administrative tasks. migrator not only supports single schemas, but also comes with a multi-tenant support. +migrator manages and versions all the DB changes for you and completely eliminates manual and error-prone administrative tasks. migrator versions can be used for auditing and compliance purposes. migrator not only supports single schemas, but also comes with a multi-schema support (ideal for multi-schema multi-tenant SaaS products). migrator runs as a HTTP REST service and can be easily integrated into your continuous integration and continuous delivery pipeline. -Further, there is an official docker image available on docker hub. [lukasz/migrator](https://hub.docker.com/r/lukasz/migrator) is ultra lightweight and has a size of 15MB. Ideal for micro-services deployments! +The official docker image is available on docker hub at [lukasz/migrator](https://hub.docker.com/r/lukasz/migrator). It is ultra lightweight and has a size of 15MB. Ideal for micro-services deployments! 
# Table of contents * [Usage](#usage) * [GET /](#get-) - * [GET /v1/config](#get-v1config) - * [GET /v1/migrations/source](#get-v1migrationssource) - * [GET /v1/migrations/applied](#get-v1migrationsapplied) - * [POST /v1/migrations](#post-v1migrations) - * [GET /v1/tenants](#get-v1tenants) - * [POST /v1/tenants](#post-v1tenants) + * [/v2 - GraphQL API](#v2---graphql-api) + * [GET /v2/config](#get-v2config) + * [GET /v2/schema](#get-v2schema) + * [POST /v2/service](#post-v2service) + * [/v1](#v1) + * [GET /v1/config](#get-v1config) + * [GET /v1/migrations/source](#get-v1migrationssource) + * [GET /v1/migrations/applied](#get-v1migrationsapplied) + * [POST /v1/migrations](#post-v1migrations) + * [GET /v1/tenants](#get-v1tenants) + * [POST /v1/tenants](#post-v1tenants) * [Request tracing](#request-tracing) * [Quick Start Guide](#quick-start-guide) * [1. Get the migrator project](#1-get-the-migrator-project) @@ -53,7 +58,7 @@ migrator exposes a simple REST API described below. ## GET / -Migrator returns build information together with supported API versions. +Migrator returns build information together with a list of supported API versions. Sample request: @@ -66,20 +71,350 @@ Sample HTTP response: ``` < HTTP/1.1 200 OK < Content-Type: application/json; charset=utf-8 -< Date: Wed, 08 Jan 2020 09:13:58 GMT -< Content-Length: 142 +< Date: Mon, 02 Mar 2020 19:48:45 GMT +< Content-Length: 150 +< +{"release":"dev-v2020.1.0","commitSha":"c871b176f6e428e186dfe5114a9c86d52a4350f2","commitDate":"2020-03-01T20:58:32+01:00","apiVersions":["v1","v2"]} +``` + +## /v2 - GraphQL API + +API v2 was introduced in migrator v2020.1.0. API v2 is a GraphQL API. + +API v2 also introduced a formal concept of DB versions. Every migrator action creates a new DB version. A version logically groups all applied DB migrations for auditing and compliance purposes. You can browse versions together with executed DB migrations using the GraphQL API. + +## GET /v2/config + +Returns migrator's config as `application/x-yaml`. + +Sample request: + +``` +curl -v http://localhost:8080/v2/config +``` + +Sample HTTP response: + +``` +< HTTP/1.1 200 OK +< Content-Type: application/x-yaml; charset=utf-8 +< Date: Mon, 02 Mar 2020 20:03:13 GMT +< Content-Length: 244 +< +baseLocation: test/migrations +driver: sqlserver +dataSource: sqlserver://SA:YourStrongPassw0rd@127.0.0.1:32774/?database=migratortest&connection+timeout=1&dial+timeout=1 +singleMigrations: +- ref +- config +tenantMigrations: +- tenants +pathPrefix: / +``` + +## GET /v2/schema + +Returns migrator's GraphQL schema as `text/plain`. + +Although migrator supports GraphQL introspection, it is much more convenient to get the schema as plain text. + +Sample request: + +``` +curl -v http://localhost:8080/v2/schema +``` + +Sample HTTP response (truncated): + +``` +< HTTP/1.1 200 OK +< Content-Type: text/plain; charset=utf-8 +< Date: Mon, 02 Mar 2020 20:12:20 GMT +< Transfer-Encoding: chunked +< +schema { + query: Query + mutation: Mutation +} +enum MigrationType { + SingleMigration + TenantMigration + SingleScript + TenantScript +} +... +``` + +## POST /v2/service + +This is a GraphQL endpoint that handles both query and mutation requests.
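Requests are regular HTTP POSTs whose JSON body contains a `query` string plus optional `operationName` and `variables` keys, as the curl examples further down illustrate. If you prefer calling the endpoint from Go, a minimal client could look like the sketch below; the `graphQLRequest` struct and the `application/json` content type are illustrative assumptions, not part of migrator itself.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// graphQLRequest mirrors the JSON body shown in the curl examples:
// a GraphQL document plus optional operation name and variables.
// The struct name is illustrative; migrator does not export such a type.
type graphQLRequest struct {
	Query         string                 `json:"query"`
	OperationName string                 `json:"operationName,omitempty"`
	Variables     map[string]interface{} `json:"variables,omitempty"`
}

func main() {
	// list all tenants known to migrator
	payload, _ := json.Marshal(graphQLRequest{Query: `query Tenants { tenants { name } }`})

	// assumption: the endpoint accepts a JSON body posted as application/json,
	// just like the JSON payloads sent by the curl examples below
	resp, err := http.Post("http://localhost:8080/v2/service", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		panic(err)
	}
	fmt.Println(result["data"]) // e.g. map[tenants:[map[name:...] ...]]
}
```

Any query or mutation from the schema below can be sent the same way.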
+ +The current GraphQL schema together with description in comments is as follows: + +```graphql +schema { + query: Query + mutation: Mutation +} +enum MigrationType { + SingleMigration + TenantMigration + SingleScript + TenantScript +} +enum Action { + // Apply is the default action, migrator reads all source migrations and applies them + Apply + // Sync is an action where migrator reads all source migrations and marks them as applied in DB + // typical use cases are: + // importing source migrations from a legacy tool or synchronising tenant migrations when tenant was created using external tool + Sync +} +scalar Time +interface Migration { + name: String! + migrationType: MigrationType! + sourceDir: String! + file: String! + contents: String! + checkSum: String! +} +type SourceMigration implements Migration { + name: String! + migrationType: MigrationType! + sourceDir: String! + file: String! + contents: String! + checkSum: String! +} +type DBMigration implements Migration { + id: Int! + name: String! + migrationType: MigrationType! + sourceDir: String! + file: String! + contents: String! + checkSum: String! + schema: String! + created: Time! +} +type Tenant { + name: String! +} +type Version { + id: Int! + name: String! + created: Time! + dbMigrations: [DBMigration!]! +} +input SourceMigrationFilters { + name: String + sourceDir: String + file: String + migrationType: MigrationType +} +input VersionInput { + versionName: String! + action: Action = Apply + dryRun: Boolean = false +} +input TenantInput { + tenantName: String! + versionName: String! + action: Action = Apply + dryRun: Boolean = false +} +type Summary { + // date time operation started + startedAt: Time! + // how long the operation took in nanoseconds + duration: Int! + // number of tenants in the system + tenants: Int! + // number of applied single schema migrations + singleMigrations: Int! + // number of applied multi-tenant schema migrations + tenantMigrations: Int! + // number of all applied multi-tenant schema migrations (equals to tenants * tenantMigrations) + tenantMigrationsTotal: Int! + // sum of singleMigrations and tenantMigrationsTotal + migrationsGrandTotal: Int! + // number of applied single schema scripts + singleScripts: Int! + // number of applied multi-tenant schema scripts + tenantScripts: Int! + // number of all applied multi-tenant schema migrations (equals to tenants * tenantScripts) + tenantScriptsTotal: Int! + // sum of singleScripts and tenantScriptsTotal + scriptsGrandTotal: Int! +} +type CreateResults { + summary: Summary! + version: Version +} +type Query { + // returns array of SourceMigration objects + // note that if input query includes contents field this operation can produce large amounts of data - see sourceMigration(file: String!) + // all parameters are optional and can be used to filter source migrations + sourceMigrations(filters: SourceMigrationFilters): [SourceMigration!]! + // returns a single SourceMigration + // this operation can be used to fetch a complete SourceMigration including its contents field + // file is the unique identifier for a source migration which you can get from sourceMigrations() operation + sourceMigration(file: String!): SourceMigration + // returns array of Version objects + // note that if input query includes DBMigration array this operation can produce large amounts of data - see version(id: Int!) or dbMigration(id: Int!) 
+ // file is optional and can be used to return versions in which given source migration was applied + versions(file: String): [Version!]! + // returns a single Version + // note that if input query includes contents field this operation can produce large amounts of data - see dbMigration(id: Int!) + // id is the unique identifier of a version which you can get from versions() operation + version(id: Int!): Version + // returns a single DBMigration + // this operation can be used to fetch a complete SourceMigration including its contents field + // id is the unique identifier of a version which you can get from versions(file: String) or version(id: Int!) operations + dbMigration(id: Int!): DBMigration + // returns array of Tenant objects + tenants(): [Tenant!]! +} +type Mutation { + // creates new DB version by applying all eligible DB migrations & scripts + createVersion(input: VersionInput!): CreateResults! + // creates new tenant by applying only tenant-specific DB migrations & scripts, also creates new DB version + createTenant(input: TenantInput!): CreateResults! +} +``` + +There are code generators available which can generate client code based on GraphQL schema. This would be the preferred way of consuming migrator's GraphQL endpoint. + +However, below are a few curl examples to get you started. + +Create new version: +``` +# versionName parameter is required and can be: +# 1. your version number +# 2. if you do multiple deploys to dev envs perhaps it could be a version number concatenated with current date time +# 3. or if you do CI/CD the commit sha (recommended) +COMMIT_SHA="acfd70fd1f4c7413e558c03ed850012627c9caa9" +# new lines are used for readability but have to be removed from the actual request +cat < create_version.txt { - "release": "dev-v4.0.1", - "commitSha": "300ee8b98f4d6a4725d38b3676accd5a361d7a04", - "commitDate": "2020-01-07T14:52:00+01:00", - "apiVersions": [ - "v1" - ] + "query": " + mutation CreateVersion(\$input: VersionInput!) { + createVersion(input: \$input) { + version { + id, + name, + } + summary { + startedAt + tenants + migrationsGrandTotal + scriptsGrandTotal + } + } + }", + "operationName": "CreateVersion", + "variables": { + "input": { + "versionName": "$COMMIT_SHA" + } + } } +EOF +# and now execute the above query +curl -d @create_version.txt http://localhost:8080/v2/service ``` -API v1 was introduced in migrator v4.0. Any non API-breaking changes will be added to v1. Any significant change or an API-breaking change will be added to API v2. +Create new tenant, run in dry run and instead of default `Apply`, run `Sync` action, also include DB migrations in output: + +``` +# versionName parameter is required and can be: +# 1. your version number +# 2. if you do multiple deploys to dev envs perhaps it could be a version number concatenated with current date time +# 3. or if you do CI/CD the commit sha (recommended) +COMMIT_SHA="acfd70fd1f4c7413e558c03ed850012627c9caa9" +# tenantName parameter is also required (should not come as a surprise since we want to create new tenant) +TENANT_NAME="new_customer_of_yours" +# new lines are used for readability but have to be removed from the actual request +cat < create_tenant.txt +{ + "query": " + mutation CreateTenant(\$input: TenantInput!) 
{ + createTenant(input: \$input) { + version { + id, + name, + dbMigrations { + id, + file, + schema + } + } + summary { + startedAt + tenants + migrationsGrandTotal + scriptsGrandTotal + } + } + }", + "operationName": "CreateTenant", + "variables": { + "input": { + "versionName": "$COMMIT_SHA - $TENANT_NAME", + "tenantName": "$TENANT_NAME" + } + } +} +EOF +# and now execute the above query +curl -d @create_tenant.txt http://localhost:8080/v2/service +``` + +Query data (yes, migrator supports multiple operations in a single GraphQL query): + +``` +# new lines are used for readability but have to be removed from the actual request +cat < query.txt +{ + "query": " + query Data(\$singleMigrationsFilters: SourceMigrationFilters, \$tenantMigrationsFilters: SourceMigrationFilters) { + singleTenantSourceMigrations: sourceMigrations(filters: \$singleMigrationsFilters) { + file + migrationType + } + multiTenantSourceMigrations: sourceMigrations(filters: \$tenantMigrationsFilters) { + file + migrationType + checkSum + } + tenants { + name + } + }", + "operationName": "Data", + "variables": { + "singleMigrationsFilters": { + "migrationType": "SingleMigration" + }, + "tenantMigrationsFilters": { + "migrationType": "TenantMigration" + } + } +} +EOF +# and now execute the above query +curl -d @query.txt http://localhost:8080/v2/service +``` + +For more GraphQL query and mutation examples see `data/graphql_test.go`. + +## /v1 + +**Deprecation**: As of migrator v2020.1.0 API v1 is deprecated and will sunset in v2021.1.0. + +API v1 is available in migrator v4.x and v2020.x. ## GET /v1/config diff --git a/config/config.go b/config/config.go index 33a751d..f9a6b35 100644 --- a/config/config.go +++ b/config/config.go @@ -55,7 +55,7 @@ func FromBytes(contents []byte) (*Config, error) { } if len(config.BaseDir) > 0 && len(config.BaseLocation) == 0 { - common.Log("WARN", "Deprecated: config property `baseDir` will be removed in migrator v5.0, please rename it to `baseLocation`") + common.Log("WARN", "Deprecated: config property `baseDir` will be removed in migrator v2021.1.0, please rename it to `baseLocation`") config.BaseLocation = config.BaseDir } diff --git a/coordinator/coordinator.go b/coordinator/coordinator.go index 8537a3c..e5d1b84 100644 --- a/coordinator/coordinator.go +++ b/coordinator/coordinator.go @@ -3,7 +3,8 @@ package coordinator import ( "context" "encoding/json" - "sync" + "fmt" + "reflect" "github.com/lukaszbudnik/migrator/common" "github.com/lukaszbudnik/migrator/config" @@ -13,29 +14,43 @@ import ( "github.com/lukaszbudnik/migrator/types" ) +// SourceMigrationFilters defines filters which can be used to fetch source migrations +type SourceMigrationFilters struct { + Name *string + SourceDir *string + File *string + MigrationType *types.MigrationType +} + // Coordinator interface abstracts all operations performed by migrator type Coordinator interface { - GetTenants() []string - GetSourceMigrations() []types.Migration + GetTenants() []types.Tenant + GetVersions() []types.Version + GetVersionsByFile(string) []types.Version + GetVersionByID(int32) (*types.Version, error) + GetDBMigrationByID(int32) (*types.DBMigration, error) + GetSourceMigrations(*SourceMigrationFilters) []types.Migration + GetSourceMigrationByFile(string) (*types.Migration, error) + // deprecated in v2020.1.0 sunset in v2021.1.0 + // Version now contains slice of DBMigration GetAppliedMigrations() []types.MigrationDB VerifySourceMigrationsCheckSums() (bool, []types.Migration) + // Deprecated, uses CreateVersion under the 
hood ApplyMigrations(types.MigrationsModeType) (*types.MigrationResults, []types.Migration) + // Deprecated, uses CreateTenant under the hood AddTenantAndApplyMigrations(types.MigrationsModeType, string) (*types.MigrationResults, []types.Migration) + CreateVersion(string, types.Action, bool) *types.CreateResults + CreateTenant(string, types.Action, bool, string) *types.CreateResults Dispose() } // coordinator struct is a struct for implementing DB specific dialects type coordinator struct { - ctx context.Context - connector db.Connector - loader loader.Loader - notifier notifications.Notifier - config *config.Config - tenants []string - sourceMigrations []types.Migration - appliedMigrations []types.MigrationDB - loaderLock sync.Mutex - connectorLock sync.Mutex + ctx context.Context + connector db.Connector + loader loader.Loader + notifier notifications.Notifier + config *config.Config } // Factory creates new Coordinator instance @@ -56,34 +71,46 @@ func New(ctx context.Context, config *config.Config, newConnector db.Factory, ne return coordinator } -func (c *coordinator) GetTenants() []string { - c.connectorLock.Lock() - defer c.connectorLock.Unlock() - if c.tenants == nil { - tenants := c.connector.GetTenants() - c.tenants = tenants - } - return c.tenants +func (c *coordinator) GetTenants() []types.Tenant { + return c.connector.GetTenants() +} + +func (c *coordinator) GetVersions() []types.Version { + return c.connector.GetVersions() +} + +func (c *coordinator) GetVersionsByFile(file string) []types.Version { + return c.connector.GetVersionsByFile(file) } -func (c *coordinator) GetSourceMigrations() []types.Migration { - c.loaderLock.Lock() - defer c.loaderLock.Unlock() - if c.sourceMigrations == nil { - sourceMigrations := c.loader.GetSourceMigrations() - c.sourceMigrations = sourceMigrations +func (c *coordinator) GetVersionByID(ID int32) (*types.Version, error) { + return c.connector.GetVersionByID(ID) +} + +func (c *coordinator) GetSourceMigrations(filters *SourceMigrationFilters) []types.Migration { + allSourceMigrations := c.loader.GetSourceMigrations() + filteredMigrations := c.filterMigrations(allSourceMigrations, filters) + return filteredMigrations +} + +func (c *coordinator) GetSourceMigrationByFile(file string) (*types.Migration, error) { + allSourceMigrations := c.loader.GetSourceMigrations() + filters := SourceMigrationFilters{ + File: &file, } - return c.sourceMigrations + filteredMigrations := c.filterMigrations(allSourceMigrations, &filters) + if len(filteredMigrations) == 0 { + return nil, fmt.Errorf("Source migration not found: %v", file) + } + return &filteredMigrations[0], nil +} + +func (c *coordinator) GetDBMigrationByID(ID int32) (*types.DBMigration, error) { + return c.connector.GetDBMigrationByID(ID) } func (c *coordinator) GetAppliedMigrations() []types.MigrationDB { - c.connectorLock.Lock() - defer c.connectorLock.Unlock() - if c.appliedMigrations == nil { - appliedMigrations := c.connector.GetAppliedMigrations() - c.appliedMigrations = appliedMigrations - } - return c.appliedMigrations + return c.connector.GetAppliedMigrations() } // VerifySourceMigrationsCheckSums verifies if CheckSum of source and applied DB migrations match @@ -92,7 +119,7 @@ func (c *coordinator) GetAppliedMigrations() []types.MigrationDB { // if bool is false the function returns a slice of offending migrations // if bool is true the slice of effending migrations is empty func (c *coordinator) VerifySourceMigrationsCheckSums() (bool, []types.Migration) { - sourceMigrations := 
c.GetSourceMigrations() + sourceMigrations := c.GetSourceMigrations(nil) appliedMigrations := c.GetAppliedMigrations() flattenedAppliedMigration := c.flattenAppliedMigrations(appliedMigrations) @@ -114,31 +141,80 @@ func (c *coordinator) VerifySourceMigrationsCheckSums() (bool, []types.Migration } func (c *coordinator) ApplyMigrations(mode types.MigrationsModeType) (*types.MigrationResults, []types.Migration) { - sourceMigrations := c.GetSourceMigrations() + + // convert to new API params + versionName := "API v1 ApplyMigrations" + action := types.ActionApply + dryRun := false + if mode == types.ModeTypeDryRun { + dryRun = true + } else if mode == types.ModeTypeSync { + action = types.ActionSync + } + + sourceMigrations := c.GetSourceMigrations(nil) appliedMigrations := c.GetAppliedMigrations() migrationsToApply := c.computeMigrationsToApply(sourceMigrations, appliedMigrations) common.LogInfo(c.ctx, "Found migrations to apply: %d", len(migrationsToApply)) - results := c.connector.ApplyMigrations(mode, migrationsToApply) + results, _ := c.connector.CreateVersion(versionName, action, dryRun, migrationsToApply) c.sendNotification(results) return results, migrationsToApply } +func (c *coordinator) CreateVersion(versionName string, action types.Action, dryRun bool) *types.CreateResults { + sourceMigrations := c.GetSourceMigrations(nil) + appliedMigrations := c.GetAppliedMigrations() + + migrationsToApply := c.computeMigrationsToApply(sourceMigrations, appliedMigrations) + common.LogInfo(c.ctx, "Found migrations to apply: %d", len(migrationsToApply)) + + summary, version := c.connector.CreateVersion(versionName, action, dryRun, migrationsToApply) + + c.sendNotification(summary) + + return &types.CreateResults{Summary: summary, Version: version} +} + func (c *coordinator) AddTenantAndApplyMigrations(mode types.MigrationsModeType, tenant string) (*types.MigrationResults, []types.Migration) { - sourceMigrations := c.GetSourceMigrations() + // convert to new API params + versionName := "API v1 AddTenantAndApplyMigrations" + action := types.ActionApply + dryRun := false + if mode == types.ModeTypeDryRun { + dryRun = true + } else if mode == types.ModeTypeSync { + action = types.ActionSync + } + + sourceMigrations := c.GetSourceMigrations(nil) // filter only tenant schemas migrationsToApply := c.filterTenantMigrations(sourceMigrations) common.LogInfo(c.ctx, "Migrations to apply for new tenant: %d", len(migrationsToApply)) - results := c.connector.AddTenantAndApplyMigrations(mode, tenant, migrationsToApply) + summary, _ := c.connector.CreateTenant(versionName, action, dryRun, tenant, migrationsToApply) - c.sendNotification(results) + c.sendNotification(summary) - return results, migrationsToApply + return summary, migrationsToApply +} + +func (c *coordinator) CreateTenant(versionName string, action types.Action, dryRun bool, tenant string) *types.CreateResults { + sourceMigrations := c.GetSourceMigrations(nil) + + // filter only tenant schemas + migrationsToApply := c.filterTenantMigrations(sourceMigrations) + common.LogInfo(c.ctx, "Migrations to apply for new tenant: %d", len(migrationsToApply)) + + summary, version := c.connector.CreateTenant(versionName, action, dryRun, tenant, migrationsToApply) + + c.sendNotification(summary) + + return &types.CreateResults{Summary: summary, Version: version} } func (c *coordinator) Dispose() { @@ -235,3 +311,39 @@ func (c *coordinator) sendNotification(results *types.MigrationResults) { common.LogInfo(c.ctx, "Notifier response: %v", resp) } } + +func (c 
*coordinator) filterMigrations(migrations []types.Migration, filters *SourceMigrationFilters) []types.Migration { + filtered := []types.Migration{} + for _, m := range migrations { + match := c.matchMigration(m, filters) + if match { + filtered = append(filtered, m) + } + } + return filtered +} + +func (c *coordinator) matchMigration(m types.Migration, filters *SourceMigrationFilters) bool { + match := true + + if filters == nil { + return match + } + + ps := reflect.ValueOf(filters) + s := ps.Elem() + for i := 0; i < s.Type().NumField(); i++ { + // if filter is nil it means match + if s.Field(i).IsNil() { + continue + } + // args field names match migration names + pm := reflect.ValueOf(m).FieldByName(s.Type().Field(i).Name) + match = match && (pm.Interface() == s.Field(i).Elem().Interface()) + // if already non match don't bother further looping + if !match { + break + } + } + return match +} diff --git a/coordinator/coordinator_mocks.go b/coordinator/coordinator_mocks.go index adec717..bd67a5e 100644 --- a/coordinator/coordinator_mocks.go +++ b/coordinator/coordinator_mocks.go @@ -2,8 +2,11 @@ package coordinator import ( "context" + "errors" "time" + "github.com/graph-gophers/graphql-go" + "github.com/lukaszbudnik/migrator/config" "github.com/lukaszbudnik/migrator/db" "github.com/lukaszbudnik/migrator/loader" @@ -15,18 +18,32 @@ type mockedDiskLoader struct { } func (m *mockedDiskLoader) GetSourceMigrations() []types.Migration { + // 5 migrations in total + // 4 migrations with type MigrationTypeSingleMigration + // 3 migrations with sourceDir source and type MigrationTypeSingleMigration + // 2 migrations with name 201602220001.sql and type MigrationTypeSingleMigration + // 1 migration with file config/201602220001.sql + m1 := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} - m2 := types.Migration{Name: "201602220001.sql", SourceDir: "source", File: "source/201602220001.sql", MigrationType: types.MigrationTypeTenantMigration, Contents: "select def"} - return []types.Migration{m1, m2} + m2 := types.Migration{Name: "201602220001.sql", SourceDir: "source", File: "source/201602220001.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select def"} + m3 := types.Migration{Name: "201602220001.sql", SourceDir: "config", File: "config/201602220001.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select def"} + m4 := types.Migration{Name: "201602220002.sql", SourceDir: "source", File: "source/201602220002.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select def"} + m5 := types.Migration{Name: "201602220003.sql", SourceDir: "tenant", File: "tenant/201602220003.sql", MigrationType: types.MigrationTypeTenantMigration, Contents: "select def"} + return []types.Migration{m1, m2, m3, m4, m5} } func newMockedDiskLoader(_ context.Context, _ *config.Config) loader.Loader { return &mockedDiskLoader{} } -type mockedNotifier struct{} +type mockedNotifier struct { + returnError bool +} func (m *mockedNotifier) Notify(message string) (string, error) { + if m.returnError { + return "", errors.New("algo saliĆ³ terriblemente mal") + } return "mock", nil } @@ -34,6 +51,10 @@ func newMockedNotifier(_ context.Context, _ *config.Config) notifications.Notifi return &mockedNotifier{} } +func newErrorMockedNotifier(_ context.Context, _ *config.Config) notifications.Notifier { + return &mockedNotifier{returnError: true} +} + type 
mockedBrokenCheckSumDiskLoader struct { } @@ -65,21 +86,56 @@ type mockedConnector struct { func (m *mockedConnector) Dispose() { } +func (m *mockedConnector) CreateTenant(string, types.Action, bool, string, []types.Migration) (*types.MigrationResults, *types.Version) { + return &types.MigrationResults{}, &types.Version{} +} + +func (m *mockedConnector) CreateVersion(string, types.Action, bool, []types.Migration) (*types.MigrationResults, *types.Version) { + return &types.MigrationResults{}, &types.Version{} +} + func (m *mockedConnector) AddTenantAndApplyMigrations(types.MigrationsModeType, string, []types.Migration) *types.MigrationResults { return &types.MigrationResults{} } -func (m *mockedConnector) GetTenants() []string { - return []string{"a", "b", "c"} +func (m *mockedConnector) GetTenants() []types.Tenant { + a := types.Tenant{Name: "a"} + b := types.Tenant{Name: "b"} + c := types.Tenant{Name: "c"} + return []types.Tenant{a, b, c} +} + +func (m *mockedConnector) GetVersions() []types.Version { + a := types.Version{ID: 12, Name: "a", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -2)}} + b := types.Version{ID: 121, Name: "bb", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -1)}} + c := types.Version{ID: 122, Name: "ccc", Created: graphql.Time{Time: time.Now()}} + return []types.Version{a, b, c} +} + +func (m *mockedConnector) GetVersionsByFile(file string) []types.Version { + a := types.Version{ID: 12, Name: "a", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -2)}} + return []types.Version{a} +} + +func (m *mockedConnector) GetVersionByID(ID int32) (*types.Version, error) { + a := types.Version{ID: ID, Name: "a", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -2)}} + return &a, nil } func (m *mockedConnector) GetAppliedMigrations() []types.MigrationDB { m1 := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} d1 := time.Date(2016, 02, 22, 16, 41, 1, 123, time.UTC) - ms := []types.MigrationDB{{Migration: m1, Schema: "source", AppliedAt: d1}} + ms := []types.MigrationDB{{Migration: m1, Schema: "source", AppliedAt: graphql.Time{Time: d1}}} return ms } +func (m *mockedConnector) GetDBMigrationByID(ID int32) (*types.DBMigration, error) { + mdef := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + date := time.Date(2016, 02, 22, 16, 41, 1, 123, time.UTC) + db := types.DBMigration{Migration: mdef, ID: ID, Schema: "source", AppliedAt: graphql.Time{Time: date}} + return &db, nil +} + func (m *mockedConnector) ApplyMigrations(types.MigrationsModeType, []types.Migration) *types.MigrationResults { return &types.MigrationResults{} } @@ -97,7 +153,7 @@ func (m *mockedDifferentScriptCheckSumMockedConnector) GetAppliedMigrations() [] d1 := time.Date(2016, 02, 22, 16, 41, 1, 123, time.UTC) m2 := types.Migration{Name: "recreate-indexes.sql", SourceDir: "tenants-scripts", File: "tenants-scripts/recreate-indexes.sql", MigrationType: types.MigrationTypeTenantScript, Contents: "select abc", CheckSum: "sha256-2"} d2 := time.Date(2016, 02, 22, 16, 41, 1, 456, time.UTC) - ms := []types.MigrationDB{{Migration: m1, Schema: "source", AppliedAt: d1}, {Migration: m2, Schema: "customer1", AppliedAt: d2}} + ms := []types.MigrationDB{{Migration: m1, Schema: "source", AppliedAt: graphql.Time{Time: d1}}, {Migration: m2, Schema: "customer1", AppliedAt: 
graphql.Time{Time: d2}}} return ms } diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index 357489e..c3a9c05 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -5,21 +5,23 @@ import ( "testing" "time" - "github.com/lukaszbudnik/migrator/types" + "github.com/graph-gophers/graphql-go" "github.com/stretchr/testify/assert" + + "github.com/lukaszbudnik/migrator/types" ) func TestMigrationsFlattenMigrationDBs1(t *testing.T) { m1 := types.Migration{Name: "001.sql", SourceDir: "public", File: "public/001.sql", MigrationType: types.MigrationTypeSingleMigration} - db1 := types.MigrationDB{Migration: m1, Schema: "public", AppliedAt: time.Now()} + db1 := types.MigrationDB{Migration: m1, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}} m2 := types.Migration{Name: "002.sql", SourceDir: "tenants", File: "tenants/002.sql", MigrationType: types.MigrationTypeTenantMigration} - db2 := types.MigrationDB{Migration: m2, Schema: "abc", AppliedAt: time.Now()} + db2 := types.MigrationDB{Migration: m2, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}} - db3 := types.MigrationDB{Migration: m2, Schema: "def", AppliedAt: time.Now()} + db3 := types.MigrationDB{Migration: m2, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}} m4 := types.Migration{Name: "003.sql", SourceDir: "ref", File: "ref/003.sql", MigrationType: types.MigrationTypeSingleMigration} - db4 := types.MigrationDB{Migration: m4, Schema: "ref", AppliedAt: time.Now()} + db4 := types.MigrationDB{Migration: m4, Schema: "ref", AppliedAt: graphql.Time{Time: time.Now()}} dbs := []types.MigrationDB{db1, db2, db3, db4} @@ -35,12 +37,12 @@ func TestMigrationsFlattenMigrationDBs1(t *testing.T) { func TestMigrationsFlattenMigrationDBs2(t *testing.T) { m2 := types.Migration{Name: "002.sql", SourceDir: "tenants", File: "tenants/002.sql", MigrationType: types.MigrationTypeTenantMigration} - db2 := types.MigrationDB{Migration: m2, Schema: "abc", AppliedAt: time.Now()} + db2 := types.MigrationDB{Migration: m2, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}} - db3 := types.MigrationDB{Migration: m2, Schema: "def", AppliedAt: time.Now()} + db3 := types.MigrationDB{Migration: m2, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}} m4 := types.Migration{Name: "003.sql", SourceDir: "ref", File: "ref/003.sql", MigrationType: types.MigrationTypeSingleMigration} - db4 := types.MigrationDB{Migration: m4, Schema: "ref", AppliedAt: time.Now()} + db4 := types.MigrationDB{Migration: m4, Schema: "ref", AppliedAt: graphql.Time{Time: time.Now()}} dbs := []types.MigrationDB{db2, db3, db4} @@ -52,26 +54,26 @@ func TestMigrationsFlattenMigrationDBs2(t *testing.T) { func TestMigrationsFlattenMigrationDBs3(t *testing.T) { m1 := types.Migration{Name: "001.sql", SourceDir: "public", File: "public/001.sql", MigrationType: types.MigrationTypeSingleMigration} - db1 := types.MigrationDB{Migration: m1, Schema: "public", AppliedAt: time.Now()} + db1 := types.MigrationDB{Migration: m1, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}} m2 := types.Migration{Name: "002.sql", SourceDir: "tenants", File: "tenants/002.sql", MigrationType: types.MigrationTypeTenantMigration} - db2 := types.MigrationDB{Migration: m2, Schema: "abc", AppliedAt: time.Now()} + db2 := types.MigrationDB{Migration: m2, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}} - db3 := types.MigrationDB{Migration: m2, Schema: "def", AppliedAt: time.Now()} + db3 := types.MigrationDB{Migration: m2, Schema: "def", 
AppliedAt: graphql.Time{Time: time.Now()}} m4 := types.Migration{Name: "003.sql", SourceDir: "ref", File: "ref/003.sql", MigrationType: types.MigrationTypeSingleMigration} - db4 := types.MigrationDB{Migration: m4, Schema: "ref", AppliedAt: time.Now()} + db4 := types.MigrationDB{Migration: m4, Schema: "ref", AppliedAt: graphql.Time{Time: time.Now()}} m5 := types.Migration{Name: "global-stored-procedure1.sql", SourceDir: "public", File: "public-scripts/global-stored-procedure1.sql", MigrationType: types.MigrationTypeSingleScript} - db5 := types.MigrationDB{Migration: m5, Schema: "public", AppliedAt: time.Now()} + db5 := types.MigrationDB{Migration: m5, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}} m6 := types.Migration{Name: "global-stored-procedure2.sql", SourceDir: "public", File: "public-scripts/global-stored-procedure2sql", MigrationType: types.MigrationTypeSingleScript} - db6 := types.MigrationDB{Migration: m6, Schema: "public", AppliedAt: time.Now()} + db6 := types.MigrationDB{Migration: m6, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}} m7 := types.Migration{Name: "002.sql", SourceDir: "tenants-scripts", File: "tenants/002.sql", MigrationType: types.MigrationTypeTenantMigration} - db7 := types.MigrationDB{Migration: m7, Schema: "abc", AppliedAt: time.Now()} + db7 := types.MigrationDB{Migration: m7, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}} - db8 := types.MigrationDB{Migration: m7, Schema: "def", AppliedAt: time.Now()} + db8 := types.MigrationDB{Migration: m7, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}} dbs := []types.MigrationDB{db1, db2, db3, db4, db5, db6, db7, db8} @@ -91,7 +93,7 @@ func TestComputeMigrationsToApply(t *testing.T) { mdef7 := types.Migration{Name: "g", SourceDir: "g", File: "g", MigrationType: types.MigrationTypeTenantScript} diskMigrations := []types.Migration{mdef1, mdef2, mdef3, mdef4, mdef5, mdef6, mdef7} - dbMigrations := []types.MigrationDB{{Migration: mdef1, Schema: "a", AppliedAt: time.Now()}, {Migration: mdef2, Schema: "abc", AppliedAt: time.Now()}, {Migration: mdef2, Schema: "def", AppliedAt: time.Now()}, {Migration: mdef5, Schema: "e", AppliedAt: time.Now()}, {Migration: mdef6, Schema: "f", AppliedAt: time.Now()}, {Migration: mdef7, Schema: "abc", AppliedAt: time.Now()}, {Migration: mdef7, Schema: "def", AppliedAt: time.Now()}} + dbMigrations := []types.MigrationDB{{Migration: mdef1, Schema: "a", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef2, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef2, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef5, Schema: "e", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef6, Schema: "f", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef7, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef7, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}}} coordinator := &coordinator{ ctx: context.TODO(), @@ -135,7 +137,7 @@ func TestComputeMigrationsToApplyDifferentTimestamps(t *testing.T) { dev2p := types.Migration{Name: "20181120", SourceDir: "public", File: "public/20181120", MigrationType: types.MigrationTypeSingleMigration} diskMigrations := []types.Migration{mdef1, mdef2, mdef3, dev1, dev1p1, dev1p2, dev2, dev2p} - dbMigrations := []types.MigrationDB{{Migration: mdef1, Schema: "abc", AppliedAt: time.Now()}, {Migration: mdef1, Schema: "def", AppliedAt: time.Now()}, {Migration: mdef2, Schema: "public", AppliedAt: time.Now()}, {Migration: mdef3, Schema: 
"public", AppliedAt: time.Now()}, {Migration: dev2, Schema: "abc", AppliedAt: time.Now()}, {Migration: dev2, Schema: "def", AppliedAt: time.Now()}, {Migration: dev2p, Schema: "public", AppliedAt: time.Now()}} + dbMigrations := []types.MigrationDB{{Migration: mdef1, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef1, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef2, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: mdef3, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: dev2, Schema: "abc", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: dev2, Schema: "def", AppliedAt: graphql.Time{Time: time.Now()}}, {Migration: dev2p, Schema: "public", AppliedAt: graphql.Time{Time: time.Now()}}} coordinator := &coordinator{ ctx: context.TODO(), @@ -226,7 +228,7 @@ func TestVerifySourceMigrationsCheckSumsKO(t *testing.T) { defer coordinator.Dispose() verified, offendingMigrations := coordinator.VerifySourceMigrationsCheckSums() assert.False(t, verified) - assert.Equal(t, coordinator.GetSourceMigrations()[0], offendingMigrations[0]) + assert.Equal(t, coordinator.GetSourceMigrations(nil)[0], offendingMigrations[0]) } func TestVerifySourceMigrationsAndScriptsCheckSumsOK(t *testing.T) { @@ -241,8 +243,9 @@ func TestApplyMigrations(t *testing.T) { coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newMockedNotifier) defer coordinator.Dispose() _, appliedMigrations := coordinator.ApplyMigrations(types.ModeTypeApply) - assert.Len(t, appliedMigrations, 1) - assert.Equal(t, coordinator.GetSourceMigrations()[1], appliedMigrations[0]) + assert.Len(t, appliedMigrations, 4) + // first source migration is already applied so getting the 2nd one + assert.Equal(t, coordinator.GetSourceMigrations(nil)[1], appliedMigrations[0]) } func TestAddTenantAndApplyMigrations(t *testing.T) { @@ -250,12 +253,143 @@ func TestAddTenantAndApplyMigrations(t *testing.T) { defer coordinator.Dispose() _, appliedMigrations := coordinator.AddTenantAndApplyMigrations(types.ModeTypeApply, "new") assert.Len(t, appliedMigrations, 1) - assert.Equal(t, coordinator.GetSourceMigrations()[1], appliedMigrations[0]) + assert.Equal(t, coordinator.GetSourceMigrations(nil)[4], appliedMigrations[0]) } func TestGetTenants(t *testing.T) { coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newMockedNotifier) defer coordinator.Dispose() tenants := coordinator.GetTenants() - assert.Equal(t, []string{"a", "b", "c"}, tenants) + a := types.Tenant{Name: "a"} + b := types.Tenant{Name: "b"} + c := types.Tenant{Name: "c"} + assert.Equal(t, []types.Tenant{a, b, c}, tenants) +} + +func TestGetVersions(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newMockedNotifier) + defer coordinator.Dispose() + versions := coordinator.GetVersions() + + assert.Equal(t, int32(12), versions[0].ID) + assert.Equal(t, int32(121), versions[1].ID) + assert.Equal(t, int32(122), versions[2].ID) +} + +func TestGetVersionByID(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newMockedNotifier) + defer coordinator.Dispose() + version, _ := coordinator.GetVersionByID(123) + + assert.Equal(t, int32(123), version.ID) +} + +func TestGetVersionsByFile(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newMockedNotifier) + defer coordinator.Dispose() + versions := 
coordinator.GetVersionsByFile("tenants/abc.sql") + + assert.Equal(t, int32(12), versions[0].ID) +} + +func TestGetMigrationByID(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newMockedNotifier) + defer coordinator.Dispose() + migration, _ := coordinator.GetDBMigrationByID(456) + + assert.Equal(t, int32(456), migration.ID) +} + +// Notifier error should not cause the whole process to fail +func TestApplyMigrationsNotifierError(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + _, appliedMigrations := coordinator.ApplyMigrations(types.ModeTypeApply) + assert.Len(t, appliedMigrations, 4) + // first source migration is already applied so getting the 2nd one + assert.Equal(t, coordinator.GetSourceMigrations(nil)[1], appliedMigrations[0]) +} + +func TestGetSourceMigrationByFile(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + file := "source/201602220001.sql" + migration, err := coordinator.GetSourceMigrationByFile(file) + assert.Nil(t, err) + assert.Equal(t, file, migration.File) +} + +func TestGetSourceMigrationByFileNotFound(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + file := "xyz/201602220001.sql" + _, err := coordinator.GetSourceMigrationByFile(file) + assert.NotNil(t, err) + assert.Equal(t, "Source migration not found: xyz/201602220001.sql", err.Error()) +} + +func TestGetSourceMigrationsFilterMigrationType(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + migrationType := types.MigrationTypeSingleMigration + filters := SourceMigrationFilters{ + MigrationType: &migrationType, + } + migrations := coordinator.GetSourceMigrations(&filters) + assert.True(t, len(migrations) == 4) +} + +func TestGetSourceMigrationsFilterMigrationTypeSourceDir(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + migrationType := types.MigrationTypeSingleMigration + sourceDir := "source" + filters := SourceMigrationFilters{ + MigrationType: &migrationType, + SourceDir: &sourceDir, + } + migrations := coordinator.GetSourceMigrations(&filters) + assert.True(t, len(migrations) == 3) +} + +func TestGetSourceMigrationsFilterMigrationTypeName(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + migrationType := types.MigrationTypeSingleMigration + name := "201602220001.sql" + filters := SourceMigrationFilters{ + MigrationType: &migrationType, + Name: &name, + } + migrations := coordinator.GetSourceMigrations(&filters) + assert.True(t, len(migrations) == 2) +} + +func TestGetSourceMigrationsFilterFile(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + file := "source/201602220001.sql" + filters := SourceMigrationFilters{ + File: &file, + } + migrations := coordinator.GetSourceMigrations(&filters) + assert.True(t, len(migrations) == 1) +} + +func TestCreateVersion(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, 
newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + results := coordinator.CreateVersion("commit-sha", types.ActionApply, false) + assert.NotNil(t, results) + assert.NotNil(t, results.Summary) + assert.NotNil(t, results.Version) +} + +func TestCreateTenant(t *testing.T) { + coordinator := New(context.TODO(), nil, newMockedConnector, newMockedDiskLoader, newErrorMockedNotifier) + defer coordinator.Dispose() + results := coordinator.CreateTenant("commit-sha", types.ActionSync, true, "NewTenant") + assert.NotNil(t, results) + assert.NotNil(t, results.Summary) + assert.NotNil(t, results.Version) } diff --git a/data/graphql.go b/data/graphql.go new file mode 100644 index 0000000..8ba5855 --- /dev/null +++ b/data/graphql.go @@ -0,0 +1,206 @@ +package data + +import ( + "github.com/lukaszbudnik/migrator/coordinator" + "github.com/lukaszbudnik/migrator/types" +) + +// SchemaDefinition contains GraphQL migrator schema +const SchemaDefinition = ` +schema { + query: Query + mutation: Mutation +} +enum MigrationType { + SingleMigration + TenantMigration + SingleScript + TenantScript +} +enum Action { + // Apply is the default action, migrator reads all source migrations and applies them + Apply + // Sync is an action where migrator reads all source migrations and marks them as applied in DB + // typical use cases are: + // importing source migrations from a legacy tool or synchronising tenant migrations when tenant was created using external tool + Sync +} +scalar Time +interface Migration { + name: String! + migrationType: MigrationType! + sourceDir: String! + file: String! + contents: String! + checkSum: String! +} +type SourceMigration implements Migration { + name: String! + migrationType: MigrationType! + sourceDir: String! + file: String! + contents: String! + checkSum: String! +} +type DBMigration implements Migration { + id: Int! + name: String! + migrationType: MigrationType! + sourceDir: String! + file: String! + contents: String! + checkSum: String! + schema: String! + created: Time! +} +type Tenant { + name: String! +} +type Version { + id: Int! + name: String! + created: Time! + dbMigrations: [DBMigration!]! +} +input SourceMigrationFilters { + name: String + sourceDir: String + file: String + migrationType: MigrationType +} +input VersionInput { + versionName: String! + action: Action = Apply + dryRun: Boolean = false +} +input TenantInput { + tenantName: String! + versionName: String! + action: Action = Apply + dryRun: Boolean = false +} +type Summary { + // date time operation started + startedAt: Time! + // how long the operation took in nanoseconds + duration: Int! + // number of tenants in the system + tenants: Int! + // number of applied single schema migrations + singleMigrations: Int! + // number of applied multi-tenant schema migrations + tenantMigrations: Int! + // number of all applied multi-tenant schema migrations (equals to tenants * tenantMigrations) + tenantMigrationsTotal: Int! + // sum of singleMigrations and tenantMigrationsTotal + migrationsGrandTotal: Int! + // number of applied single schema scripts + singleScripts: Int! + // number of applied multi-tenant schema scripts + tenantScripts: Int! + // number of all applied multi-tenant schema migrations (equals to tenants * tenantScripts) + tenantScriptsTotal: Int! + // sum of singleScripts and tenantScriptsTotal + scriptsGrandTotal: Int! +} +type CreateResults { + summary: Summary! 
+ version: Version +} +type Query { + // returns array of SourceMigration objects + // note that if input query includes contents field this operation can produce large amounts of data - see sourceMigration(file: String!) + // all parameters are optional and can be used to filter source migrations + sourceMigrations(filters: SourceMigrationFilters): [SourceMigration!]! + // returns a single SourceMigration + // this operation can be used to fetch a complete SourceMigration including its contents field + // file is the unique identifier for a source migration which you can get from sourceMigrations() operation + sourceMigration(file: String!): SourceMigration + // returns array of Version objects + // note that if input query includes DBMigration array this operation can produce large amounts of data - see version(id: Int!) or dbMigration(id: Int!) + // file is optional and can be used to return versions in which given source migration was applied + versions(file: String): [Version!]! + // returns a single Version + // note that if input query includes contents field this operation can produce large amounts of data - see dbMigration(id: Int!) + // id is the unique identifier of a version which you can get from versions() operation + version(id: Int!): Version + // returns a single DBMigration + // this operation can be used to fetch a complete SourceMigration including its contents field + // id is the unique identifier of a version which you can get from versions(file: String) or version(id: Int!) operations + dbMigration(id: Int!): DBMigration + // returns array of Tenant objects + tenants(): [Tenant!]! +} +type Mutation { + // creates new DB version by applying all eligible DB migrations & scripts + createVersion(input: VersionInput!): CreateResults! + // creates new tenant by applying only tenant-specific DB migrations & scripts, also creates new DB version + createTenant(input: TenantInput!): CreateResults! 
+} +` + +// RootResolver is resolver for all the migrator data +type RootResolver struct { + Coordinator coordinator.Coordinator +} + +// Tenants resolves all tenants +func (r *RootResolver) Tenants() ([]types.Tenant, error) { + tenants := r.Coordinator.GetTenants() + return tenants, nil +} + +// Versions resoves all versions, optionally can return versions with specific source migration (file is the identifier for source migrations) +func (r *RootResolver) Versions(args struct { + File *string +}) ([]types.Version, error) { + if args.File != nil { + return r.Coordinator.GetVersionsByFile(*args.File), nil + } + return r.Coordinator.GetVersions(), nil +} + +// Version resolves version by ID +func (r *RootResolver) Version(args struct { + ID int32 +}) (*types.Version, error) { + return r.Coordinator.GetVersionByID(args.ID) +} + +// SourceMigrations resolves source migrations using optional filters +func (r *RootResolver) SourceMigrations(args struct { + Filters *coordinator.SourceMigrationFilters +}) ([]types.Migration, error) { + sourceMigrations := r.Coordinator.GetSourceMigrations(args.Filters) + return sourceMigrations, nil +} + +// SourceMigration resolves source migration by its file name +func (r *RootResolver) SourceMigration(args struct { + File string +}) (*types.Migration, error) { + return r.Coordinator.GetSourceMigrationByFile(args.File) +} + +// DBMigration resolves DB migration by ID +func (r *RootResolver) DBMigration(args struct { + ID int32 +}) (*types.MigrationDB, error) { + return r.Coordinator.GetDBMigrationByID(args.ID) +} + +// CreateVersion creates new DB version +func (r *RootResolver) CreateVersion(args struct { + Input types.VersionInput +}) (*types.CreateResults, error) { + results := r.Coordinator.CreateVersion(args.Input.VersionName, args.Input.Action, args.Input.DryRun) + return results, nil +} + +// CreateTenant creates new tenant +func (r *RootResolver) CreateTenant(args struct { + Input types.TenantInput +}) (*types.CreateResults, error) { + results := r.Coordinator.CreateTenant(args.Input.VersionName, args.Input.Action, args.Input.DryRun, args.Input.TenantName) + return results, nil +} diff --git a/data/graphql_mocks.go b/data/graphql_mocks.go new file mode 100644 index 0000000..06fd53c --- /dev/null +++ b/data/graphql_mocks.go @@ -0,0 +1,120 @@ +package data + +import ( + "strings" + "time" + + "github.com/graph-gophers/graphql-go" + "github.com/lukaszbudnik/migrator/coordinator" + "github.com/lukaszbudnik/migrator/types" +) + +type mockedCoordinator struct { +} + +func (m *mockedCoordinator) safeString(value *string) string { + if value == nil { + return "" + } + return *value +} + +func (m *mockedCoordinator) CreateTenant(string, types.Action, bool, string) *types.CreateResults { + version, _ := m.GetVersionByID(0) + return &types.CreateResults{Summary: &types.MigrationResults{}, Version: version} +} + +func (m *mockedCoordinator) CreateVersion(string, types.Action, bool) *types.CreateResults { + // re-use mocked version from GetVersionByID... 
+ version, _ := m.GetVersionByID(0) + return &types.CreateResults{Summary: &types.MigrationResults{}, Version: version} +} + +func (m *mockedCoordinator) GetSourceMigrations(filters *coordinator.SourceMigrationFilters) []types.Migration { + + if filters == nil { + m1 := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + m2 := types.Migration{Name: "201602220001.sql", SourceDir: "source", File: "source/201602220001.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select def"} + m3 := types.Migration{Name: "201602220001.sql", SourceDir: "config", File: "config/201602220001.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select def"} + m4 := types.Migration{Name: "201602220002.sql", SourceDir: "source", File: "source/201602220002.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select def"} + m5 := types.Migration{Name: "201602220003.sql", SourceDir: "tenant", File: "tenant/201602220003.sql", MigrationType: types.MigrationTypeTenantMigration, Contents: "select def"} + return []types.Migration{m1, m2, m3, m4, m5} + } + + m1 := types.Migration{Name: m.safeString(filters.Name), SourceDir: m.safeString(filters.SourceDir), File: m.safeString(filters.File), MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + return []types.Migration{m1} +} + +func (m *mockedCoordinator) GetSourceMigrationByFile(file string) (*types.Migration, error) { + i := strings.Index(file, "/") + sourceDir := file[:i] + name := file[i+1:] + m1 := types.Migration{Name: name, SourceDir: sourceDir, File: file, MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + return &m1, nil +} + +func (m *mockedCoordinator) Dispose() { +} + +func (m *mockedCoordinator) GetTenants() []types.Tenant { + a := types.Tenant{Name: "a"} + b := types.Tenant{Name: "b"} + c := types.Tenant{Name: "c"} + return []types.Tenant{a, b, c} +} + +func (m *mockedCoordinator) GetVersions() []types.Version { + a := types.Version{ID: 12, Name: "a", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -2)}} + b := types.Version{ID: 121, Name: "bb", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -1)}} + c := types.Version{ID: 122, Name: "ccc", Created: graphql.Time{Time: time.Now()}} + return []types.Version{a, b, c} +} + +func (m *mockedCoordinator) GetVersionsByFile(file string) []types.Version { + a := types.Version{ID: 12, Name: "a", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -2)}} + return []types.Version{a} +} + +func (m *mockedCoordinator) GetVersionByID(ID int32) (*types.Version, error) { + m1 := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + d1 := time.Date(2016, 02, 22, 16, 41, 1, 123, time.UTC) + db1 := types.MigrationDB{Migration: m1, Schema: "source", Created: graphql.Time{Time: d1}} + + m2 := types.Migration{Name: "202002180000.sql", SourceDir: "config", File: "config/202002180000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + d2 := time.Date(2020, 02, 18, 16, 41, 1, 123, time.UTC) + db2 := types.MigrationDB{Migration: m2, Schema: "source", Created: graphql.Time{Time: d2}} + + m3 := types.Migration{Name: "202002180000.sql", SourceDir: "tenants", File: "tenants/202002180000.sql", MigrationType: types.MigrationTypeTenantMigration, Contents: "select abc"} + 
d3 := time.Date(2020, 02, 18, 16, 41, 1, 123, time.UTC) + db3 := types.MigrationDB{Migration: m3, Schema: "abc", Created: graphql.Time{Time: d3}} + db4 := types.MigrationDB{Migration: m3, Schema: "def", Created: graphql.Time{Time: d3}} + db5 := types.MigrationDB{Migration: m3, Schema: "xyz", Created: graphql.Time{Time: d3}} + + a := types.Version{ID: ID, Name: "a", Created: graphql.Time{Time: time.Now().AddDate(0, 0, -2)}, DBMigrations: []types.MigrationDB{db1, db2, db3, db4, db5}} + + return &a, nil +} + +// not used in GraphQL +func (m *mockedCoordinator) GetAppliedMigrations() []types.MigrationDB { + return []types.MigrationDB{} +} + +func (m *mockedCoordinator) GetDBMigrationByID(ID int32) (*types.DBMigration, error) { + migration := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + d := time.Date(2016, 02, 22, 16, 41, 1, 123, time.UTC) + db := types.DBMigration{Migration: migration, ID: ID, Schema: "source", Created: graphql.Time{Time: d}} + return &db, nil +} + +func (m *mockedCoordinator) ApplyMigrations(types.MigrationsModeType) (*types.MigrationResults, []types.Migration) { + return &types.MigrationResults{}, []types.Migration{} +} + +func (m *mockedCoordinator) AddTenantAndApplyMigrations(types.MigrationsModeType, string) (*types.MigrationResults, []types.Migration) { + return &types.MigrationResults{}, []types.Migration{} +} + +func (m *mockedCoordinator) VerifySourceMigrationsCheckSums() (bool, []types.Migration) { + return true, nil +} diff --git a/data/graphql_test.go b/data/graphql_test.go new file mode 100644 index 0000000..c5b3be1 --- /dev/null +++ b/data/graphql_test.go @@ -0,0 +1,636 @@ +package data + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/graph-gophers/graphql-go" +) + +func TestTenants(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "Tenants" + query := `query Tenants { + tenants { + name + } + }` + variables := map[string]interface{}{} + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := len(jsonMap["tenants"].([]interface{})) + assert.Equal(t, 3, results) +} + +func TestVersions(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) 
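+ // the assertions below rely on mockedCoordinator.GetVersions returning three versions: a, bb and ccc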
+ + opName := "Versions" + query := `query Versions { + versions { + id + name + created + } + }` + variables := map[string]interface{}{} + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + versions := jsonMap["versions"].([]interface{}) + results := len(versions) + + assert.Equal(t, 3, results) + assert.Equal(t, "a", versions[0].(map[string]interface{})["name"]) + assert.Equal(t, "bb", versions[1].(map[string]interface{})["name"]) + assert.Equal(t, "ccc", versions[2].(map[string]interface{})["name"]) +} + +func TestVersionsByFile(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "Versions" + query := `query Versions($file: String) { + versions(file: $file) { + id + name + created + } + }` + variables := map[string]interface{}{ + "file": "config/202002180000.sql", + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + versions := jsonMap["versions"].([]interface{}) + results := len(versions) + + assert.Equal(t, 1, results) + assert.Equal(t, "a", versions[0].(map[string]interface{})["name"]) +} + +func TestVersionByID(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "Version" + query := `query Version($id: Int!) { + version(id: $id) { + id + created + dbMigrations { + file + schema + migrationType + } + } + }` + variables := map[string]interface{}{ + "id": 1234, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + version := jsonMap["version"].(map[string]interface{}) + dbMigrations := version["dbMigrations"].([]interface{}) + // json.Unmarshal actually creates float64 for all number, but this is only unit test + assert.Equal(t, float64(1234), version["id"]) + assert.Nil(t, version["name"]) + assert.NotNil(t, version["created"]) + assert.Equal(t, 5, len(dbMigrations)) + lastDBMigration := dbMigrations[4].(map[string]interface{}) + assert.Equal(t, "tenants/202002180000.sql", lastDBMigration["file"]) + assert.Equal(t, "TenantMigration", lastDBMigration["migrationType"]) + assert.Equal(t, "xyz", lastDBMigration["schema"]) +} + +func TestSourceMigrationsNoFilters(t *testing.T) { + + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) 
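+ // with no filters mockedCoordinator.GetSourceMigrations returns all five source migrations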
+ + opName := "SourceMigrations" + query := `query SourceMigrations { + sourceMigrations { + name, + migrationType, + sourceDir, + file, + contents, + checkSum + } + }` + variables := map[string]interface{}{} + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := len(jsonMap["sourceMigrations"].([]interface{})) + assert.Equal(t, 5, results) +} + +func TestSourceMigrationsTypeFilter(t *testing.T) { + + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "SourceMigrations" + query := `query SourceMigrations($filters: SourceMigrationFilters) { + sourceMigrations(filters: $filters) { + name, + migrationType, + sourceDir, + file, + contents, + checkSum + } + }` + variables := map[string]interface{}{ + "filters": map[string]interface{}{ + "migrationType": "SingleMigration", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["sourceMigrations"].([]interface{}) + migration := results[0].(map[string]interface{}) + assert.Equal(t, "SingleMigration", migration["migrationType"]) +} + +func TestSourceMigrationsTypeSourceDirFilter(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "SourceMigrations" + query := `query SourceMigrations($filters: SourceMigrationFilters) { + sourceMigrations(filters: $filters) { + name, + migrationType, + sourceDir, + file, + contents, + checkSum + } + }` + variables := map[string]interface{}{ + "filters": map[string]interface{}{ + "migrationType": "SingleMigration", + "sourceDir": "source", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["sourceMigrations"].([]interface{}) + migration := results[0].(map[string]interface{}) + assert.Equal(t, "SingleMigration", migration["migrationType"]) + assert.Equal(t, "source", migration["sourceDir"]) +} + +func TestSourceMigrationsTypeSourceDirNameFilter(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) 
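+ // with filters set mockedCoordinator.GetSourceMigrations echoes the filter values back in a single migration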
+
+ opName := "SourceMigrations"
+ query := `query SourceMigrations($filters: SourceMigrationFilters) {
+ sourceMigrations(filters: $filters) {
+ name,
+ migrationType,
+ sourceDir,
+ file,
+ contents,
+ checkSum
+ }
+ }`
+ variables := map[string]interface{}{
+ "filters": map[string]interface{}{
+ "migrationType": "SingleMigration",
+ "name": "201602220001.sql",
+ },
+ }
+
+ resp := schema.Exec(ctx, query, opName, variables)
+ jsonMap := make(map[string]interface{})
+ err := json.Unmarshal(resp.Data, &jsonMap)
+ assert.Nil(t, err)
+ results := jsonMap["sourceMigrations"].([]interface{})
+ migration := results[0].(map[string]interface{})
+ assert.Equal(t, "SingleMigration", migration["migrationType"])
+ assert.Equal(t, "201602220001.sql", migration["name"])
+}
+
+func TestSourceMigrationsTypeNameFilter(t *testing.T) {
+ ctx := context.Background()
+
+ opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()}
+ schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...)
+
+ opName := "SourceMigrations"
+ query := `query SourceMigrations($filters: SourceMigrationFilters) {
+ sourceMigrations(filters: $filters) {
+ name,
+ migrationType,
+ sourceDir,
+ file,
+ contents,
+ checkSum
+ }
+ }`
+ variables := map[string]interface{}{
+ "filters": map[string]interface{}{
+ "file": "config/201602220001.sql",
+ },
+ }
+
+ resp := schema.Exec(ctx, query, opName, variables)
+ jsonMap := make(map[string]interface{})
+ err := json.Unmarshal(resp.Data, &jsonMap)
+ assert.Nil(t, err)
+ results := len(jsonMap["sourceMigrations"].([]interface{}))
+ assert.Equal(t, 1, results)
+}
+
+func TestSourceMigration(t *testing.T) {
+ ctx := context.Background()
+
+ opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()}
+ schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...)
+
+ opName := "SourceMigration"
+ query := `query SourceMigration($file: String!) {
+ sourceMigration(file: $file) {
+ name,
+ migrationType,
+ sourceDir,
+ file
+ }
+ }`
+ variables := map[string]interface{}{
+ "file": "config/201602220001.sql",
+ }
+
+ resp := schema.Exec(ctx, query, opName, variables)
+ jsonMap := make(map[string]interface{})
+ err := json.Unmarshal(resp.Data, &jsonMap)
+ assert.Nil(t, err)
+ results := jsonMap["sourceMigration"].(map[string]interface{})
+ assert.Equal(t, "config/201602220001.sql", results["file"])
+ assert.Equal(t, "201602220001.sql", results["name"])
+ assert.Equal(t, "SingleMigration", results["migrationType"])
+ assert.Equal(t, "config", results["sourceDir"])
+ // we request only 4 fields in the above query; the others should be nil
+ assert.Nil(t, results["contents"])
+ assert.Nil(t, results["checkSum"])
+}
+
+func TestDBMigration(t *testing.T) {
+ ctx := context.Background()
+
+ opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()}
+ schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...)
+
+ opName := "DBMigration"
+ query := `query DBMigration($id: Int!) 
{ + dbMigration(id: $id) { + name, + migrationType, + sourceDir, + file, + checkSum, + schema, + created + } + }` + variables := map[string]interface{}{ + "id": 123, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["dbMigration"].(map[string]interface{}) + assert.Equal(t, "201602220000.sql", results["name"]) + assert.Equal(t, "SingleMigration", results["migrationType"]) + assert.Equal(t, "source", results["sourceDir"]) + assert.Equal(t, "source/201602220000.sql", results["file"]) + assert.Equal(t, "source", results["schema"]) + assert.Equal(t, "2016-02-22T16:41:01.000000123Z", results["created"]) + // we return all fields except contents - should be nil + assert.Nil(t, results["contents"]) +} + +func TestComplexQueries(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "Data" + query := ` + query Data($singleMigrationsFilters: SourceMigrationFilters, $tenantMigrationsFilters: SourceMigrationFilters) { + singleTenantSourceMigrations: sourceMigrations(filters: $singleMigrationsFilters) { + file + migrationType + } + multiTenantSourceMigrations: sourceMigrations(filters: $tenantMigrationsFilters) { + file + migrationType + checkSum + } + tenants { + name + } + } + ` + variables := map[string]interface{}{ + "singleMigrationsFilters": map[string]interface{}{ + "migrationType": "SingleMigration", + }, + "tenantMigrationsFilters": map[string]interface{}{ + "migrationType": "TenantMigration", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + single := len(jsonMap["singleTenantSourceMigrations"].([]interface{})) + assert.Equal(t, 1, single) + multi := len(jsonMap["multiTenantSourceMigrations"].([]interface{})) + assert.Equal(t, 1, multi) + tenants := len(jsonMap["tenants"].([]interface{})) + assert.Equal(t, 3, tenants) +} + +func TestCreateVersionWithDefaults(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "CreateVersion" + query := `mutation CreateVersion($input: VersionInput!) 
{ + createVersion(input: $input) { + version { + id, + name, + } + summary { + startedAt + tenants + migrationsGrandTotal + scriptsGrandTotal + } + } +}` + variables := map[string]interface{}{ + "input": map[string]interface{}{ + "versionName": "commit-sha", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["createVersion"].(map[string]interface{}) + + // check version part + version := results["version"].(map[string]interface{}) + assert.NotNil(t, version["id"]) + assert.NotNil(t, version["name"]) + // we return only 2 fields in above query others should be nil including dbMigrations + assert.Nil(t, version["dbMigrations"]) + + // check summary part + summary := results["summary"].(map[string]interface{}) + assert.NotNil(t, summary["startedAt"]) + assert.NotNil(t, summary["tenants"]) + assert.NotNil(t, summary["migrationsGrandTotal"]) + assert.NotNil(t, summary["scriptsGrandTotal"]) + // we return only 4 fields in above query others should be nil including duration + assert.Nil(t, summary["duration"]) +} + +func TestCreateVersionNonDefaultParams(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "CreateVersion" + query := `mutation CreateVersion($input: VersionInput!) { + createVersion(input: $input) { + version { + id, + name, + dbMigrations { + id, + file, + schema + } + } + summary { + startedAt + tenants + migrationsGrandTotal + scriptsGrandTotal + } + } +}` + variables := map[string]interface{}{ + "input": map[string]interface{}{ + "action": "Sync", + "dryRun": true, + "versionName": "commit-sha", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["createVersion"].(map[string]interface{}) + + // check version part + version := results["version"].(map[string]interface{}) + assert.NotNil(t, version["id"]) + assert.NotNil(t, version["name"]) + // in this test we also fetch dbMigrations + dbMigrations := version["dbMigrations"].([]interface{}) + assert.Equal(t, 5, len(dbMigrations)) + // for each migration we fetch only 3 fields + dbMigration := dbMigrations[0].(map[string]interface{}) + assert.NotNil(t, dbMigration["id"]) + assert.NotNil(t, dbMigration["file"]) + assert.NotNil(t, dbMigration["schema"]) + // others should be nil + assert.Nil(t, dbMigration["contents"]) + + // check summary part + summary := results["summary"].(map[string]interface{}) + assert.NotNil(t, summary["startedAt"]) + assert.NotNil(t, summary["tenants"]) + assert.NotNil(t, summary["migrationsGrandTotal"]) + assert.NotNil(t, summary["scriptsGrandTotal"]) + // we return only 4 fields in above query others should be nil including duration + assert.Nil(t, summary["duration"]) +} + +func TestCreateTenantWithDefaults(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "CreateTenant" + query := `mutation CreateTenant($input: TenantInput!) 
{ + createTenant(input: $input) { + version { + id, + name, + } + summary { + startedAt + tenants + migrationsGrandTotal + scriptsGrandTotal + } + } +}` + variables := map[string]interface{}{ + "input": map[string]interface{}{ + "versionName": "commit-sha", + "tenantName": "new-tenant", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["createTenant"].(map[string]interface{}) + + // check version part + version := results["version"].(map[string]interface{}) + assert.NotNil(t, version["id"]) + assert.NotNil(t, version["name"]) + // we return only 2 fields in above query others should be nil including dbMigrations + assert.Nil(t, version["dbMigrations"]) + + // check summary part + summary := results["summary"].(map[string]interface{}) + assert.NotNil(t, summary["startedAt"]) + assert.NotNil(t, summary["tenants"]) + assert.NotNil(t, summary["migrationsGrandTotal"]) + assert.NotNil(t, summary["scriptsGrandTotal"]) + // we return only 4 fields in above query others should be nil including duration + assert.Nil(t, summary["duration"]) +} + +func TestCreateTenantNonDefaultParams(t *testing.T) { + ctx := context.Background() + + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(SchemaDefinition, &RootResolver{Coordinator: &mockedCoordinator{}}, opts...) + + opName := "CreateTenant" + query := `mutation CreateTenant($input: TenantInput!) { + createTenant(input: $input) { + version { + id, + name, + dbMigrations { + id, + file, + schema + } + } + summary { + startedAt + tenants + migrationsGrandTotal + scriptsGrandTotal + } + } +}` + variables := map[string]interface{}{ + "input": map[string]interface{}{ + "action": "Sync", + "dryRun": true, + "versionName": "commit-sha", + "tenantName": "new-tenant", + }, + } + + resp := schema.Exec(ctx, query, opName, variables) + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(resp.Data, &jsonMap) + assert.Nil(t, err) + results := jsonMap["createTenant"].(map[string]interface{}) + + // check version part + version := results["version"].(map[string]interface{}) + assert.NotNil(t, version["id"]) + assert.NotNil(t, version["name"]) + // in this test we also fetch dbMigrations + dbMigrations := version["dbMigrations"].([]interface{}) + assert.Equal(t, 5, len(dbMigrations)) + // for each migration we fetch only 3 fields + dbMigration := dbMigrations[0].(map[string]interface{}) + assert.NotNil(t, dbMigration["id"]) + assert.NotNil(t, dbMigration["file"]) + assert.NotNil(t, dbMigration["schema"]) + // others should be nil + assert.Nil(t, dbMigration["contents"]) + + // check summary part + summary := results["summary"].(map[string]interface{}) + assert.NotNil(t, summary["startedAt"]) + assert.NotNil(t, summary["tenants"]) + assert.NotNil(t, summary["migrationsGrandTotal"]) + assert.NotNil(t, summary["scriptsGrandTotal"]) + // we return only 4 fields in above query others should be nil including duration + assert.Nil(t, summary["duration"]) +} diff --git a/db/db.go b/db/db.go index 5cf72ef..05dca5c 100644 --- a/db/db.go +++ b/db/db.go @@ -5,9 +5,12 @@ import ( "database/sql" "fmt" "path/filepath" + "sort" "strings" "time" + "github.com/graph-gophers/graphql-go" + "github.com/lukaszbudnik/migrator/common" "github.com/lukaszbudnik/migrator/config" "github.com/lukaszbudnik/migrator/types" @@ -15,10 +18,15 @@ import ( // Connector interface abstracts all DB operations 
performed by migrator type Connector interface { - GetTenants() []string + GetTenants() []types.Tenant + GetVersions() []types.Version + GetVersionsByFile(file string) []types.Version + GetVersionByID(ID int32) (*types.Version, error) + GetDBMigrationByID(ID int32) (*types.DBMigration, error) + // deprecated in v2020.1.0 sunset in v2021.1.0 GetAppliedMigrations() []types.MigrationDB - ApplyMigrations(types.MigrationsModeType, []types.Migration) *types.MigrationResults - AddTenantAndApplyMigrations(types.MigrationsModeType, string, []types.Migration) *types.MigrationResults + CreateVersion(string, types.Action, bool, []types.Migration) (*types.MigrationResults, *types.Version) + CreateTenant(string, types.Action, bool, string, []types.Migration) (*types.MigrationResults, *types.Version) Dispose() } @@ -46,6 +54,7 @@ const ( migratorSchema = "migrator" migratorTenantsTable = "migrator_tenants" migratorMigrationsTable = "migrator_migrations" + migratorVersionsTable = "migrator_versions" defaultSchemaPlaceHolder = "{schema}" ) @@ -81,6 +90,14 @@ func (bc *baseConnector) init() { panic(fmt.Sprintf("Could not create migrations table: %v", err)) } + // make sure versions table exists + createVersionsTableSQLs := bc.dialect.GetCreateVersionsTableSQL() + for _, createVersionsTableSQL := range createVersionsTableSQLs { + if _, err := bc.db.Query(createVersionsTableSQL); err != nil { + panic(fmt.Sprintf("Could not create versions table: %v", err)) + } + } + // if using default migrator tenants table make sure it exists if bc.config.TenantSelectSQL == "" { createTenantsTable := bc.dialect.GetCreateTenantsTableSQL() @@ -113,10 +130,10 @@ func (bc *baseConnector) getTenantSelectSQL() string { } // GetTenants returns a list of all DB tenants -func (bc *baseConnector) GetTenants() []string { +func (bc *baseConnector) GetTenants() []types.Tenant { tenantSelectSQL := bc.getTenantSelectSQL() - tenants := []string{} + tenants := []types.Tenant{} rows, err := bc.db.Query(tenantSelectSQL) if err != nil { @@ -128,12 +145,149 @@ func (bc *baseConnector) GetTenants() []string { if err = rows.Scan(&name); err != nil { panic(fmt.Sprintf("Could not read tenants: %v", err)) } - tenants = append(tenants, name) + tenants = append(tenants, types.Tenant{Name: name}) } return tenants } +func (bc *baseConnector) GetVersions() []types.Version { + versionsSelectSQL := bc.dialect.GetVersionsSelectSQL() + + rows, err := bc.db.Query(versionsSelectSQL) + if err != nil { + panic(fmt.Sprintf("Could not query versions: %v", err)) + } + + return bc.readVersions(rows) +} + +func (bc *baseConnector) GetVersionsByFile(file string) []types.Version { + versionsSelectSQL := bc.dialect.GetVersionsByFileSQL() + + rows, err := bc.db.Query(versionsSelectSQL, file) + if err != nil { + panic(fmt.Sprintf("Could not query versions: %v", err)) + } + + return bc.readVersions(rows) +} + +func (bc *baseConnector) GetVersionByID(ID int32) (*types.Version, error) { + versionsSelectSQL := bc.dialect.GetVersionByIDSQL() + + rows, err := bc.db.Query(versionsSelectSQL, ID) + if err != nil { + panic(fmt.Sprintf("Could not query versions: %v", err)) + } + + // readVersions is generic and returns a slice of Version objects + // we are querying by ID and are interested in only the first one + versions := bc.readVersions(rows) + + if len(versions) == 0 { + return nil, fmt.Errorf("Version not found ID: %v", ID) + } + + return &versions[0], nil +} + +func (bc *baseConnector) getVersionByIDInTx(tx *sql.Tx, ID int32) *types.Version { + versionsSelectSQL := 
bc.dialect.GetVersionByIDSQL() + + rows, err := tx.Query(versionsSelectSQL, ID) + if err != nil { + panic(fmt.Sprintf("Could not query versions: %v", err)) + } + + // readVersions is generic and returns a slice of Version objects + // we are querying by ID and are interested in only the first one + versions := bc.readVersions(rows) + + // when running in transaction version must be found + if len(versions) == 0 { + panic(fmt.Sprintf("Version not found ID: %v", ID)) + } + + return &versions[0] +} + +func (bc *baseConnector) readVersions(rows *sql.Rows) []types.Version { + versions := []types.Version{} + versionsMap := map[int64]*types.Version{} + + for rows.Next() { + var ( + vid int64 + vname string + vcreated time.Time + mid int64 + name string + sourceDir string + filename string + migrationType types.MigrationType + schema string + created time.Time + contents string + checksum string + ) + + if err := rows.Scan(&vid, &vname, &vcreated, &mid, &name, &sourceDir, &filename, &migrationType, &schema, &created, &contents, &checksum); err != nil { + panic(fmt.Sprintf("Could not read versions: %v", err)) + } + if versionsMap[vid] == nil { + version := types.Version{ID: int32(vid), Name: vname, Created: graphql.Time{Time: vcreated}} + versionsMap[vid] = &version + } + + version := versionsMap[vid] + migration := types.Migration{Name: name, SourceDir: sourceDir, File: filename, MigrationType: migrationType, Contents: contents, CheckSum: checksum} + version.DBMigrations = append(version.DBMigrations, types.MigrationDB{Migration: migration, ID: int32(mid), Schema: schema, AppliedAt: graphql.Time{Time: created}, Created: graphql.Time{Time: created}}) + } + + // map to versions + for _, v := range versionsMap { + versions = append(versions, *v) + } + // since we used map above sort by version + sort.Slice(versions, func(i, j int) bool { + return versions[i].ID > versions[j].ID + }) + return versions +} + +func (bc *baseConnector) GetDBMigrationByID(ID int32) (*types.DBMigration, error) { + query := bc.dialect.GetMigrationByIDSQL() + + rows, err := bc.db.Query(query, ID) + if err != nil { + panic(fmt.Sprintf("Could not query DB migrations: %v", err.Error())) + } + + if !rows.Next() { + return nil, fmt.Errorf("DB migration not found ID: %v", ID) + } + + var ( + id int64 + name string + sourceDir string + filename string + migrationType types.MigrationType + schema string + created time.Time + contents string + checksum string + ) + if err = rows.Scan(&id, &name, &sourceDir, &filename, &migrationType, &schema, &created, &contents, &checksum); err != nil { + panic(fmt.Sprintf("Could not read DB migration: %v", err.Error())) + } + m := types.Migration{Name: name, SourceDir: sourceDir, File: filename, MigrationType: migrationType, Contents: contents, CheckSum: checksum} + db := types.MigrationDB{Migration: m, ID: int32(id), Schema: schema, AppliedAt: graphql.Time{Time: created}, Created: graphql.Time{Time: created}} + + return &db, nil +} + // GetAppliedMigrations returns a list of all applied DB migrations func (bc *baseConnector) GetAppliedMigrations() []types.MigrationDB { query := bc.dialect.GetMigrationSelectSQL() @@ -152,26 +306,26 @@ func (bc *baseConnector) GetAppliedMigrations() []types.MigrationDB { filename string migrationType types.MigrationType schema string - appliedAt time.Time + created time.Time contents string checksum string ) - if err = rows.Scan(&name, &sourceDir, &filename, &migrationType, &schema, &appliedAt, &contents, &checksum); err != nil { + if err = rows.Scan(&name, 
&sourceDir, &filename, &migrationType, &schema, &created, &contents, &checksum); err != nil { panic(fmt.Sprintf("Could not read DB migration: %v", err.Error())) } mdef := types.Migration{Name: name, SourceDir: sourceDir, File: filename, MigrationType: migrationType, Contents: contents, CheckSum: checksum} - dbMigrations = append(dbMigrations, types.MigrationDB{Migration: mdef, Schema: schema, AppliedAt: appliedAt}) + dbMigrations = append(dbMigrations, types.MigrationDB{Migration: mdef, Schema: schema, AppliedAt: graphql.Time{Time: created}, Created: graphql.Time{Time: created}}) } return dbMigrations } -// ApplyMigrations applies passed migrations -func (bc *baseConnector) ApplyMigrations(mode types.MigrationsModeType, migrations []types.Migration) *types.MigrationResults { +// CreateVersion creates new DB version and applies passed migrations +func (bc *baseConnector) CreateVersion(versionName string, action types.Action, dryRun bool, migrations []types.Migration) (*types.MigrationResults, *types.Version) { if len(migrations) == 0 { return &types.MigrationResults{ - StartedAt: time.Now(), + StartedAt: graphql.Time{Time: time.Now()}, Duration: 0, - } + }, nil } tenants := bc.GetTenants() @@ -184,27 +338,30 @@ func (bc *baseConnector) ApplyMigrations(mode types.MigrationsModeType, migratio defer func() { r := recover() if r == nil { - if mode == types.ModeTypeDryRun { + if dryRun { common.LogInfo(bc.ctx, "Running in dry-run mode, calling rollback") tx.Rollback() } else { - common.LogInfo(bc.ctx, "Running in %v mode, committing transaction", mode) + common.LogInfo(bc.ctx, "Running %v, committing transaction", action) if err := tx.Commit(); err != nil { panic(fmt.Sprintf("Could not commit transaction: %v", err.Error())) } } } else { - common.LogInfo(bc.ctx, "Recovered in ApplyMigrations. Transaction rollback.") + common.LogInfo(bc.ctx, "Recovered in CreateVersion. Transaction rollback.") tx.Rollback() panic(r) } }() - return bc.applyMigrationsInTx(tx, mode, tenants, migrations) + results, versionID := bc.applyMigrationsInTx(tx, versionName, action, tenants, migrations) + version := bc.getVersionByIDInTx(tx, int32(versionID)) + + return results, version } -// AddTenantAndApplyMigrations adds new tenant and applies all existing tenant migrations -func (bc *baseConnector) AddTenantAndApplyMigrations(mode types.MigrationsModeType, tenant string, migrations []types.Migration) *types.MigrationResults { +// CreateTenant creates new tenant and applies passed tenant migrations +func (bc *baseConnector) CreateTenant(versionName string, action types.Action, dryRun bool, tenant string, migrations []types.Migration) (*types.MigrationResults, *types.Version) { tenantInsertSQL := bc.getTenantInsertSQL() tx, err := bc.db.Begin() @@ -215,17 +372,17 @@ func (bc *baseConnector) AddTenantAndApplyMigrations(mode types.MigrationsModeTy defer func() { r := recover() if r == nil { - if mode == types.ModeTypeDryRun { + if dryRun { common.LogInfo(bc.ctx, "Running in dry-run mode, calling rollback") tx.Rollback() } else { - common.LogInfo(bc.ctx, "Running in %v mode, committing transaction", mode) + common.LogInfo(bc.ctx, "Running %v action, committing transaction", action) if err := tx.Commit(); err != nil { panic(fmt.Sprintf("Could not commit transaction: %v", err.Error())) } } } else { - common.LogInfo(bc.ctx, "Recovered in AddTenantAndApplyMigrations. Transaction rollback.") + common.LogInfo(bc.ctx, "Recovered in CreateTenant. 
Transaction rollback.") tx.Rollback() panic(r) } @@ -246,9 +403,12 @@ func (bc *baseConnector) AddTenantAndApplyMigrations(mode types.MigrationsModeTy panic(fmt.Sprintf("Failed to add tenant entry: %v", err)) } - results := bc.applyMigrationsInTx(tx, mode, []string{tenant}, migrations) + tenantStruct := types.Tenant{Name: tenant} + results, versionID := bc.applyMigrationsInTx(tx, versionName, action, []types.Tenant{tenantStruct}, migrations) + + version := bc.getVersionByIDInTx(tx, int32(versionID)) - return results + return results, version } // getTenantInsertSQL returns tenant insert SQL statement from configuration file @@ -277,31 +437,47 @@ func (bc *baseConnector) getSchemaPlaceHolder() string { return schemaPlaceHolder } -func (bc *baseConnector) applyMigrationsInTx(tx *sql.Tx, mode types.MigrationsModeType, tenants []string, migrations []types.Migration) *types.MigrationResults { +func (bc *baseConnector) applyMigrationsInTx(tx *sql.Tx, versionName string, action types.Action, tenants []types.Tenant, migrations []types.Migration) (*types.MigrationResults, int64) { results := &types.MigrationResults{ - StartedAt: time.Now(), - Tenants: len(tenants), + StartedAt: graphql.Time{Time: time.Now()}, + Tenants: int32(len(tenants)), } defer func() { - results.Duration = time.Now().Sub(results.StartedAt) + results.Duration = int32(time.Now().Sub(results.StartedAt.Time)) results.MigrationsGrandTotal = results.TenantMigrationsTotal + results.SingleMigrations results.ScriptsGrandTotal = results.TenantScriptsTotal + results.SingleScripts }() schemaPlaceHolder := bc.getSchemaPlaceHolder() + var versionID int64 + versionInsertSQL := bc.dialect.GetVersionInsertSQL() + versionInsert, err := bc.db.Prepare(versionInsertSQL) + if err != nil { + panic(fmt.Sprintf("Could not create prepared statement for version: %v", err)) + } + stmt := tx.Stmt(versionInsert) + if bc.dialect.LastInsertIDSupported() { + result, _ := stmt.Exec(versionName) + versionID, _ = result.LastInsertId() + } else { + stmt.QueryRow(versionName).Scan(&versionID) + } + insertMigrationSQL := bc.dialect.GetMigrationInsertSQL() insert, err := bc.db.Prepare(insertMigrationSQL) if err != nil { - panic(fmt.Sprintf("Could not create prepared statement: %v", err)) + panic(fmt.Sprintf("Could not create prepared statement for migration: %v", err)) } for _, m := range migrations { var schemas []string if m.MigrationType == types.MigrationTypeTenantMigration || m.MigrationType == types.MigrationTypeTenantScript { - schemas = tenants + for _, t := range tenants { + schemas = append(schemas, t.Name) + } } else { schemas = []string{filepath.Base(m.SourceDir)} } @@ -309,14 +485,14 @@ func (bc *baseConnector) applyMigrationsInTx(tx *sql.Tx, mode types.MigrationsMo for _, s := range schemas { common.LogInfo(bc.ctx, "Applying migration type: %d, schema: %s, file: %s ", m.MigrationType, s, m.File) - if mode != types.ModeTypeSync { + if action == types.ActionApply { contents := strings.Replace(m.Contents, schemaPlaceHolder, s, -1) if _, err = tx.Exec(contents); err != nil { panic(fmt.Sprintf("SQL migration %v failed with error: %v", m.File, err.Error())) } } - if _, err = tx.Stmt(insert).Exec(m.Name, m.SourceDir, m.File, m.MigrationType, s, m.Contents, m.CheckSum); err != nil { + if _, err = tx.Stmt(insert).Exec(m.Name, m.SourceDir, m.File, m.MigrationType, s, m.Contents, m.CheckSum, versionID); err != nil { panic(fmt.Sprintf("Failed to add migration entry: %v", err.Error())) } } @@ -329,14 +505,14 @@ func (bc *baseConnector) applyMigrationsInTx(tx 
*sql.Tx, mode types.MigrationsMo } if m.MigrationType == types.MigrationTypeTenantMigration { results.TenantMigrations++ - results.TenantMigrationsTotal += len(schemas) + results.TenantMigrationsTotal += int32(len(schemas)) } if m.MigrationType == types.MigrationTypeTenantScript { results.TenantScripts++ - results.TenantScriptsTotal += len(schemas) + results.TenantScriptsTotal += int32(len(schemas)) } } - return results + return results, versionID } diff --git a/db/db_dialect.go b/db/db_dialect.go index dc9bc96..032c843 100644 --- a/db/db_dialect.go +++ b/db/db_dialect.go @@ -12,9 +12,16 @@ type dialect interface { GetTenantSelectSQL() string GetMigrationInsertSQL() string GetMigrationSelectSQL() string + GetMigrationByIDSQL() string GetCreateTenantsTableSQL() string GetCreateMigrationsTableSQL() string GetCreateSchemaSQL(string) string + GetCreateVersionsTableSQL() []string + GetVersionInsertSQL() string + GetVersionsSelectSQL() string + GetVersionsByFileSQL() string + GetVersionByIDSQL() string + LastInsertIDSupported() bool } // baseDialect struct is used to provide default dialect interface implementation @@ -22,6 +29,7 @@ type baseDialect struct { } const ( + selectVersionsSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id order by vid desc, mid asc" selectMigrationsSQL = "select name, source_dir as sd, filename, type, db_schema, created, contents, checksum from %v.%v order by name, source_dir" selectTenantsSQL = "select name from %v.%v" createMigrationsTableSQL = ` @@ -77,6 +85,12 @@ func (bd *baseDialect) GetCreateSchemaSQL(schema string) string { return fmt.Sprintf(createSchemaSQL, schema) } +// GetVersionsSelectSQL returns select SQL statement that returns all versions +// This SQL is used by both MySQL and PostgreSQL. 
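+// The migrator schema and table names are substituted into the statement via fmt.Sprintf in the implementation below.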
+func (bd *baseDialect) GetVersionsSelectSQL() string { + return fmt.Sprintf(selectVersionsSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable) +} + // newDialect constructs dialect instance based on the passed Config func newDialect(config *config.Config) dialect { diff --git a/db/db_dialect_test.go b/db/db_dialect_test.go index 03b7eef..f64f474 100644 --- a/db/db_dialect_test.go +++ b/db/db_dialect_test.go @@ -69,3 +69,18 @@ func TestBaseDialectGetCreateSchemaSQL(t *testing.T) { assert.Equal(t, expected, createSchemaSQL) } + +func TestBaseDialectGetVersionsSelectSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "postgres" + + dialect := newDialect(config) + + versionsSelectSQL := dialect.GetVersionsSelectSQL() + + expected := "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id order by vid desc, mid asc" + + assert.Equal(t, expected, versionsSelectSQL) +} diff --git a/db/db_error_handling_test.go b/db/db_error_handling_test.go index a278187..a43eb1e 100644 --- a/db/db_error_handling_test.go +++ b/db/db_error_handling_test.go @@ -76,6 +76,31 @@ func TestInitCannotCreateMigratorMigrationsTable(t *testing.T) { } } +func TestInitCannotCreateMigratorVersionsTable(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + mock.ExpectBegin() + // don't have to provide full SQL here - patterns at work + mock.ExpectQuery("create schema").WillReturnRows() + mock.ExpectQuery("create table").WillReturnRows() + // create versions table is a script + mock.ExpectQuery("begin").WillReturnError(errors.New("trouble maker")) + + assert.PanicsWithValue(t, "Could not create versions table: trouble maker", func() { + connector.init() + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + func TestInitCannotCreateMigratorTenantsTable(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -89,6 +114,8 @@ func TestInitCannotCreateMigratorTenantsTable(t *testing.T) { // don't have to provide full SQL here - patterns at work mock.ExpectQuery("create schema").WillReturnRows() mock.ExpectQuery("create table").WillReturnRows() + // create versions table is a script + mock.ExpectQuery("begin").WillReturnRows() mock.ExpectQuery("create table").WillReturnError(errors.New("trouble maker")) assert.PanicsWithValue(t, "Could not create default tenants table: trouble maker", func() { @@ -113,6 +140,7 @@ func TestInitCannotCommitTransaction(t *testing.T) { // don't have to provide full SQL here - patterns at work mock.ExpectQuery("create schema").WillReturnRows() mock.ExpectQuery("create table").WillReturnRows() + mock.ExpectQuery("begin").WillReturnRows() mock.ExpectQuery("create table").WillReturnRows() mock.ExpectCommit().WillReturnError(errors.New("trouble maker")) @@ -167,7 +195,7 @@ func TestGetAppliedMigrationsError(t *testing.T) { } } -func TestApplyTransactionBeginError(t *testing.T) { +func TestCreateVersionTransactionBeginError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -185,7 +213,7 @@ func 
TestApplyTransactionBeginError(t *testing.T) { migrationsToApply := []types.Migration{tenant1} assert.PanicsWithValue(t, "Could not start transaction: trouble maker tx.Begin()", func() { - connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -193,7 +221,7 @@ func TestApplyTransactionBeginError(t *testing.T) { } } -func TestApplyInsertMigrationPreparedStatementError(t *testing.T) { +func TestCreateVersionInsertVersionPreparedStatementError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -212,8 +240,8 @@ func TestApplyInsertMigrationPreparedStatementError(t *testing.T) { tenant1 := types.Migration{Name: fmt.Sprintf("%v.sql", t1), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", t1), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{tenant1} - assert.PanicsWithValue(t, "Could not create prepared statement: trouble maker", func() { - connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) + assert.PanicsWithValue(t, "Could not create prepared statement for version: trouble maker", func() { + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -221,7 +249,7 @@ func TestApplyInsertMigrationPreparedStatementError(t *testing.T) { } } -func TestApplyMigrationSQLError(t *testing.T) { +func TestCreateVersionInsertMigrationPreparedStatementError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -233,7 +261,43 @@ func TestApplyMigrationSQLError(t *testing.T) { tenants := sqlmock.NewRows([]string{"name"}).AddRow("tenantname") mock.ExpectQuery("select").WillReturnRows(tenants) mock.ExpectBegin() - mock.ExpectPrepare("insert into") + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations").WillReturnError(errors.New("trouble maker")) + mock.ExpectRollback() + + t1 := time.Now().UnixNano() + tenant1 := types.Migration{Name: fmt.Sprintf("%v.sql", t1), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", t1), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + migrationsToApply := []types.Migration{tenant1} + + assert.PanicsWithValue(t, "Could not create prepared statement for migration: trouble maker", func() { + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCreateVersionMigrationSQLError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + tenants := sqlmock.NewRows([]string{"name"}).AddRow("tenantname") + mock.ExpectQuery("select").WillReturnRows(tenants) + mock.ExpectBegin() + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into 
migrator.migrator_migrations") mock.ExpectExec("insert into").WillReturnError(errors.New("trouble maker")) mock.ExpectRollback() @@ -242,7 +306,7 @@ func TestApplyMigrationSQLError(t *testing.T) { migrationsToApply := []types.Migration{tenant1} assert.PanicsWithValue(t, fmt.Sprintf("SQL migration %v failed with error: trouble maker", tenant1.File), func() { - connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -250,7 +314,7 @@ func TestApplyMigrationSQLError(t *testing.T) { } } -func TestApplyInsertMigrationError(t *testing.T) { +func TestCreateVersionInsertMigrationError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -267,13 +331,17 @@ func TestApplyInsertMigrationError(t *testing.T) { tenants := sqlmock.NewRows([]string{"name"}).AddRow(tenant) mock.ExpectQuery("select").WillReturnRows(tenants) mock.ExpectBegin() - mock.ExpectPrepare("insert into") - mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum).WillReturnError(errors.New("trouble maker")) + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into migrator.migrator_migrations").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnError(errors.New("trouble maker")) mock.ExpectRollback() assert.PanicsWithValue(t, "Failed to add migration entry: trouble maker", func() { - connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -281,7 +349,7 @@ func TestApplyInsertMigrationError(t *testing.T) { } } -func TestApplyMigrationsCommitError(t *testing.T) { +func TestCreateVersionGetVersionError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -290,21 +358,101 @@ func TestApplyMigrationsCommitError(t *testing.T) { dialect := newDialect(config) connector := baseConnector{newTestContext(), config, dialect, db} - time := time.Now().UnixNano() - m := types.Migration{Name: fmt.Sprintf("%v.sql", time), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", time), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{m} tenant := "tenantname" tenants := sqlmock.NewRows([]string{"name"}).AddRow(tenant) mock.ExpectQuery("select").WillReturnRows(tenants) mock.ExpectBegin() - mock.ExpectPrepare("insert into") - mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, 
m.CheckSum).WillReturnResult(sqlmock.NewResult(1, 1)) + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into migrator.migrator_migrations").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + mock.ExpectQuery("select").WillReturnError(errors.New("get version trouble maker")) + + assert.PanicsWithValue(t, "Could not query versions: get version trouble maker", func() { + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCreateVersionVersionNotFound(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + migrationsToApply := []types.Migration{m} + + tenant := "tenantname" + tenants := sqlmock.NewRows([]string{"name"}).AddRow(tenant) + mock.ExpectQuery("select").WillReturnRows(tenants) + mock.ExpectBegin() + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into migrator.migrator_migrations").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}) + mock.ExpectQuery("select").WillReturnRows(rows) + + assert.PanicsWithValue(t, "Version not found ID: 0", func() { + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestCreateVersionMigrationsCommitError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + migrationsToApply := []types.Migration{m} + + tenant := "tenantname" + tenants := sqlmock.NewRows([]string{"name"}).AddRow(tenant) + mock.ExpectQuery("select").WillReturnRows(tenants) + mock.ExpectBegin() + // 
version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into migrator.migrator_migrations").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}).AddRow("123", "vname", time.Now(), "456", m.Name, m.SourceDir, m.File, m.MigrationType, tenant, time.Now(), m.Contents, m.CheckSum) + mock.ExpectQuery("select").WillReturnRows(rows) mock.ExpectCommit().WillReturnError(errors.New("tx trouble maker")) assert.PanicsWithValue(t, "Could not commit transaction: tx trouble maker", func() { - connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) + connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -312,7 +460,7 @@ func TestApplyMigrationsCommitError(t *testing.T) { } } -func TestAddTenantTransactionBeginError(t *testing.T) { +func TestCreateTenantTransactionBeginError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -328,7 +476,7 @@ func TestAddTenantTransactionBeginError(t *testing.T) { migrationsToApply := []types.Migration{tenant1} assert.PanicsWithValue(t, "Could not start transaction: trouble maker tx.Begin()", func() { - connector.AddTenantAndApplyMigrations(types.ModeTypeApply, "newtenant", migrationsToApply) + connector.CreateTenant("commit-sha", types.ActionApply, false, "newtenant", migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -336,7 +484,7 @@ func TestAddTenantTransactionBeginError(t *testing.T) { } } -func TestAddTenantAndApplyMigrationsCreateSchemaError(t *testing.T) { +func TestCreateTenantCreateSchemaError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -354,7 +502,7 @@ func TestAddTenantAndApplyMigrationsCreateSchemaError(t *testing.T) { migrationsToApply := []types.Migration{tenant1} assert.PanicsWithValue(t, "Create schema failed: trouble maker", func() { - connector.AddTenantAndApplyMigrations(types.ModeTypeApply, "newtenant", migrationsToApply) + connector.CreateTenant("commit-sha", types.ActionApply, false, "newtenant", migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -362,7 +510,7 @@ func TestAddTenantAndApplyMigrationsCreateSchemaError(t *testing.T) { } } -func TestAddTenantAndApplyMigrationsInsertTenantPreparedStatementError(t *testing.T) { +func TestCreateTenantInsertTenantPreparedStatementError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -372,7 +520,7 @@ func TestAddTenantAndApplyMigrationsInsertTenantPreparedStatementError(t *testin connector := baseConnector{newTestContext(), config, dialect, db} mock.ExpectBegin() - mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectPrepare("insert into").WillReturnError(errors.New("trouble maker")) mock.ExpectRollback() @@ -381,7 +529,7 @@ func TestAddTenantAndApplyMigrationsInsertTenantPreparedStatementError(t *testin migrationsToApply := 
[]types.Migration{tenant1} assert.PanicsWithValue(t, "Could not create prepared statement: trouble maker", func() { - connector.AddTenantAndApplyMigrations(types.ModeTypeApply, "newtenant", migrationsToApply) + connector.CreateTenant("commit-sha", types.ActionApply, false, "newtenant", migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -389,7 +537,7 @@ func TestAddTenantAndApplyMigrationsInsertTenantPreparedStatementError(t *testin } } -func TestAddTenantAndApplyMigrationsInsertTenantError(t *testing.T) { +func TestCreateTenantInsertTenantError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -401,7 +549,7 @@ func TestAddTenantAndApplyMigrationsInsertTenantError(t *testing.T) { tenant := "tenant" mock.ExpectBegin() - mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectPrepare("insert into") mock.ExpectPrepare("insert into").ExpectExec().WithArgs(tenant).WillReturnError(errors.New("trouble maker")) mock.ExpectRollback() @@ -411,7 +559,7 @@ func TestAddTenantAndApplyMigrationsInsertTenantError(t *testing.T) { migrationsToApply := []types.Migration{m1} assert.PanicsWithValue(t, "Failed to add tenant entry: trouble maker", func() { - connector.AddTenantAndApplyMigrations(types.ModeTypeApply, tenant, migrationsToApply) + connector.CreateTenant("commit-sha", types.ActionApply, false, tenant, migrationsToApply) }) if err := mock.ExpectationsWereMet(); err != nil { @@ -419,7 +567,7 @@ func TestAddTenantAndApplyMigrationsInsertTenantError(t *testing.T) { } } -func TestAddTenantAndApplyMigrationsCommitError(t *testing.T) { +func TestCreateTenantCommitError(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -428,22 +576,114 @@ func TestAddTenantAndApplyMigrationsCommitError(t *testing.T) { dialect := newDialect(config) connector := baseConnector{newTestContext(), config, dialect, db} - time := time.Now().UnixNano() - m := types.Migration{Name: fmt.Sprintf("%v.sql", time), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", time), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{m} tenant := "tenantname" mock.ExpectBegin() - mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(0, 0)) + // tenant mock.ExpectPrepare("insert into") - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(tenant).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into") - mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectPrepare("insert into").ExpectExec().WithArgs(tenant).WillReturnResult(sqlmock.NewResult(0, 0)) + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into 
migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}).AddRow("123", "vname", time.Now(), "456", m.Name, m.SourceDir, m.File, m.MigrationType, tenant, time.Now(), m.Contents, m.CheckSum) + mock.ExpectQuery("select").WillReturnRows(rows) mock.ExpectCommit().WillReturnError(errors.New("tx trouble maker")) assert.PanicsWithValue(t, "Could not commit transaction: tx trouble maker", func() { - connector.AddTenantAndApplyMigrations(types.ModeTypeApply, tenant, migrationsToApply) + connector.CreateTenant("commit-sha", types.ActionApply, false, tenant, migrationsToApply) + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestGetVersionsError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + // don't have to provide full SQL here - patterns at work + mock.ExpectQuery("select").WillReturnError(errors.New("trouble maker")) + + assert.PanicsWithValue(t, "Could not query versions: trouble maker", func() { + connector.GetVersions() + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestGetVersionsByFileError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + // don't have to provide full SQL here - patterns at work + mock.ExpectQuery("select").WillReturnError(errors.New("trouble maker")) + + assert.PanicsWithValue(t, "Could not query versions: trouble maker", func() { + connector.GetVersionsByFile("file") + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestGetVersionsByIDError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + // don't have to provide full SQL here - patterns at work + mock.ExpectQuery("select").WillReturnError(errors.New("trouble maker")) + + assert.PanicsWithValue(t, "Could not query versions: trouble maker", func() { + connector.GetVersionByID(0) + }) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestGetDBMigrationByIDError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.Nil(t, err) + + config := &config.Config{} + config.Driver = "postgres" + dialect := newDialect(config) + connector := baseConnector{newTestContext(), config, dialect, db} + + // don't have to provide full SQL here - patterns at work + mock.ExpectQuery("select").WillReturnError(errors.New("trouble maker")) + + assert.PanicsWithValue(t, "Could not query DB migrations: trouble maker", func() { + connector.GetDBMigrationByID(0) }) if err := 
mock.ExpectationsWereMet(); err != nil { diff --git a/db/db_mssql.go b/db/db_mssql.go index 615bf51..a785093 100644 --- a/db/db_mssql.go +++ b/db/db_mssql.go @@ -11,9 +11,13 @@ type msSQLDialect struct { } const ( - insertMigrationMSSQLDialectSQL = "insert into %v.%v (name, source_dir, filename, type, db_schema, contents, checksum) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7)" - insertTenantMSSQLDialectSQL = "insert into %v.%v (name) values (@p1)" - createTenantsTableMSSQLDialectSQL = ` + insertMigrationMSSQLDialectSQL = "insert into %v.%v (name, source_dir, filename, type, db_schema, contents, checksum, version_id) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8)" + insertTenantMSSQLDialectSQL = "insert into %v.%v (name) values (@p1)" + insertVersionMSSQLSQLDialectSQL = "insert into %v.%v (name) output inserted.id values (@p1)" + selectVersionsByFileMSSQLDialectSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id where mv.id in (select version_id from %v.%v where filename = @p1) order by vid desc, mid asc" + selectVersionByIDMSSQLDialectSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id where mv.id = @p1 order by mid asc" + selectMigrationByIDMSSQLDialectSQL = "select id, name, source_dir, filename, type, db_schema, created, contents, checksum from %v.%v where id = @p1" + createTenantsTableMSSQLDialectSQL = ` IF NOT EXISTS (select * from information_schema.tables where table_schema = '%v' and table_name = '%v') BEGIN create table [%v].%v ( @@ -44,9 +48,41 @@ IF NOT EXISTS (select * from information_schema.schemata where schema_name = '%v BEGIN EXEC sp_executesql N'create schema %v'; END +` + versionsTableSetupMSSQLDialectSQL = ` +if not exists (select * from information_schema.tables where table_schema = '%v' and table_name = '%v') +begin + declare @cn nvarchar(200); + create table [%v].%v ( + id int identity (1,1) primary key, + name varchar(200) not null, + created datetime default CURRENT_TIMESTAMP + ); + -- workaround for MSSQL not finding a newly created column + -- when creating initial version default value is set to 1 + alter table [%v].%v add version_id int not null default 1; + if exists (select * from [%v].%v) + begin + insert into [%v].%v (name) values ('Initial version'); + end + -- change version_id to not null + alter table [%v].%v + alter column version_id int not null; + alter table [%v].%v + add constraint migrator_versions_version_id_fk foreign key (version_id) references [%v].%v (id) on delete cascade; + create index migrator_migrations_version_id_idx on [%v].%v (version_id); + -- remove workaround default value + select @cn = name from sys.default_constraints where parent_object_id = object_id('[%v].%v') and name like '%%ver%%'; + EXEC ('alter table [%v].%v drop constraint ' + @cn); +end ` ) +// LastInsertIDSupported instructs migrator if Result.LastInsertId() is supported by the DB driver +func (md *msSQLDialect) LastInsertIDSupported() bool { + return false +} + // GetMigrationInsertSQL returns MS SQL-specific migration insert SQL statement func (md *msSQLDialect) GetMigrationInsertSQL() string { return fmt.Sprintf(insertMigrationMSSQLDialectSQL, migratorSchema, migratorMigrationsTable) @@ -74,3 
+110,23 @@ func (md *msSQLDialect) GetCreateMigrationsTableSQL() string { func (md *msSQLDialect) GetCreateSchemaSQL(schema string) string { return fmt.Sprintf(createSchemaMSSQLDialectSQL, schema, schema) } + +func (md *msSQLDialect) GetVersionInsertSQL() string { + return fmt.Sprintf(insertVersionMSSQLSQLDialectSQL, migratorSchema, migratorVersionsTable) +} + +func (md *msSQLDialect) GetCreateVersionsTableSQL() []string { + return []string{fmt.Sprintf(versionsTableSetupMSSQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable)} +} + +func (md *msSQLDialect) GetVersionsByFileSQL() string { + return fmt.Sprintf(selectVersionsByFileMSSQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable) +} + +func (md *msSQLDialect) GetVersionByIDSQL() string { + return fmt.Sprintf(selectVersionByIDMSSQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable) +} + +func (md *msSQLDialect) GetMigrationByIDSQL() string { + return fmt.Sprintf(selectMigrationByIDMSSQLDialectSQL, migratorSchema, migratorMigrationsTable) +} diff --git a/db/db_mssql_test.go b/db/db_mssql_test.go index 0f985a3..922b491 100644 --- a/db/db_mssql_test.go +++ b/db/db_mssql_test.go @@ -14,6 +14,18 @@ func TestDBCreateDialectMSSQLDriver(t *testing.T) { assert.IsType(t, &msSQLDialect{}, dialect) } +func TestMSSQLLastInsertIdSupported(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "sqlserver" + dialect := newDialect(config) + + lastInsertIDSupported := dialect.LastInsertIDSupported() + + assert.False(t, lastInsertIDSupported) +} + func TestMSSQLGetMigrationInsertSQL(t *testing.T) { config, err := config.FromFile("../test/migrator.yaml") assert.Nil(t, err) @@ -24,7 +36,7 @@ func TestMSSQLGetMigrationInsertSQL(t *testing.T) { insertMigrationSQL := dialect.GetMigrationInsertSQL() - assert.Equal(t, "insert into migrator.migrator_migrations (name, source_dir, filename, type, db_schema, contents, checksum) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7)", insertMigrationSQL) + assert.Equal(t, "insert into migrator.migrator_migrations (name, source_dir, filename, type, db_schema, contents, checksum, version_id) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8)", insertMigrationSQL) } func TestMSSQLGetTenantInsertSQLDefault(t *testing.T) { @@ -114,3 +126,92 @@ END assert.Equal(t, expected, createSchemaSQL) } + +func TestMSSQLGetVersionInsertSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "sqlserver" + dialect := newDialect(config) + + versionInsertSQL := dialect.GetVersionInsertSQL() + + assert.Equal(t, "insert into migrator.migrator_versions (name) output inserted.id values (@p1)", versionInsertSQL) +} + +func TestMSSQLGetCreateVersionsTableSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "sqlserver" + dialect := newDialect(config) + + actual := dialect.GetCreateVersionsTableSQL() + + expected := + ` +if not exists (select * from 
information_schema.tables where table_schema = 'migrator' and table_name = 'migrator_versions') +begin + declare @cn nvarchar(200); + create table [migrator].migrator_versions ( + id int identity (1,1) primary key, + name varchar(200) not null, + created datetime default CURRENT_TIMESTAMP + ); + -- workaround for MSSQL not finding a newly created column + -- when creating initial version default value is set to 1 + alter table [migrator].migrator_migrations add version_id int not null default 1; + if exists (select * from [migrator].migrator_migrations) + begin + insert into [migrator].migrator_versions (name) values ('Initial version'); + end + -- change version_id to not null + alter table [migrator].migrator_migrations + alter column version_id int not null; + alter table [migrator].migrator_migrations + add constraint migrator_versions_version_id_fk foreign key (version_id) references [migrator].migrator_versions (id) on delete cascade; + create index migrator_migrations_version_id_idx on [migrator].migrator_migrations (version_id); + -- remove workaround default value + select @cn = name from sys.default_constraints where parent_object_id = object_id('[migrator].migrator_migrations') and name like '%ver%'; + EXEC ('alter table [migrator].migrator_migrations drop constraint ' + @cn); +end +` + + assert.Equal(t, expected, actual[0]) +} + +func TestMSSQLGetVersionsByFileSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "sqlserver" + dialect := newDialect(config) + + versionsByFile := dialect.GetVersionsByFileSQL() + + assert.Equal(t, "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id where mv.id in (select version_id from migrator.migrator_migrations where filename = @p1) order by vid desc, mid asc", versionsByFile) +} + +func TestMSSQLGetVersionByIDSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "sqlserver" + dialect := newDialect(config) + + versionByID := dialect.GetVersionByIDSQL() + + assert.Equal(t, "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id where mv.id = @p1 order by mid asc", versionByID) +} + +func TestMSSQLGetMigrationByIDSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "sqlserver" + dialect := newDialect(config) + + migrationByID := dialect.GetMigrationByIDSQL() + + assert.Equal(t, "select id, name, source_dir, filename, type, db_schema, created, contents, checksum from migrator.migrator_migrations where id = @p1", migrationByID) +} diff --git a/db/db_mysql.go b/db/db_mysql.go index ed42980..e20b13a 100644 --- a/db/db_mysql.go +++ b/db/db_mysql.go @@ -11,10 +11,44 @@ type mySQLDialect struct { } const ( - insertMigrationMySQLDialectSQL = "insert into %v.%v (name, source_dir, filename, type, db_schema, contents, checksum) values (?, ?, ?, ?, ?, ?, ?)" - insertTenantMySQLDialectSQL = "insert into %v.%v (name) values (?)" + insertMigrationMySQLDialectSQL = "insert into %v.%v (name, source_dir, filename, type, db_schema, 
contents, checksum, version_id) values (?, ?, ?, ?, ?, ?, ?, ?)" + insertTenantMySQLDialectSQL = "insert into %v.%v (name) values (?)" + insertVersionMySQLDialectSQL = "insert into %v.%v (name) values (?)" + selectVersionsByFileMySQLDialectSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id where mv.id in (select version_id from %v.%v where filename = ?) order by vid desc, mid asc" + selectVersionByIDMySQLDialectSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id where mv.id = ? order by mid asc" + selectMigrationByIDMySQLDialectSQL = "select id, name, source_dir, filename, type, db_schema, created, contents, checksum from %v.%v where id = ?" + versionsTableSetupMySQLDropDialectSQL = `drop procedure if exists migrator_create_versions` + versionsTableSetupMySQLCallDialectSQL = `call migrator_create_versions()` + versionsTableSetupMySQLProcedureDialectSQL = ` +create procedure migrator_create_versions() +begin +if not exists (select * from information_schema.tables where table_schema = '%v' and table_name = '%v') then + create table %v.%v ( + id serial primary key, + name varchar(200) not null, + created timestamp default now() + ); + alter table %v.%v add column version_id bigint unsigned; + create index migrator_versions_version_id_idx on %v.%v (version_id); + if exists (select * from %v.%v) then + insert into %v.%v (name) values ('Initial version'); + -- initial version_id sequence is always 1 + update %v.%v set version_id = 1; + end if; + alter table %v.%v + modify version_id bigint unsigned not null; + alter table %v.%v + add constraint migrator_versions_version_id_fk foreign key (version_id) references %v.%v (id) on delete cascade; +end if; +end; +` ) +// LastInsertIDSupported instructs migrator if Result.LastInsertId() is supported by the DB driver +func (md *mySQLDialect) LastInsertIDSupported() bool { + return true +} + // GetMigrationInsertSQL returns MySQL-specific migration insert SQL statement func (md *mySQLDialect) GetMigrationInsertSQL() string { return fmt.Sprintf(insertMigrationMySQLDialectSQL, migratorSchema, migratorMigrationsTable) @@ -24,3 +58,32 @@ func (md *mySQLDialect) GetMigrationInsertSQL() string { func (md *mySQLDialect) GetTenantInsertSQL() string { return fmt.Sprintf(insertTenantMySQLDialectSQL, migratorSchema, migratorTenantsTable) } + +func (md *mySQLDialect) GetVersionInsertSQL() string { + return fmt.Sprintf(insertVersionMySQLDialectSQL, migratorSchema, migratorVersionsTable) +} + +// GetCreateVersionsTableSQL returns MySQL-specific SQLs which does: +// 1. drop procedure if exists +// 2. create procedure +// 3. 
calls procedure +// far from ideal MySQL in contrast to MS SQL and PostgreSQL does not support the execution of anonymous blocks of code +func (md *mySQLDialect) GetCreateVersionsTableSQL() []string { + return []string{ + versionsTableSetupMySQLDropDialectSQL, + fmt.Sprintf(versionsTableSetupMySQLProcedureDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorVersionsTable), + versionsTableSetupMySQLCallDialectSQL, + } +} + +func (md *mySQLDialect) GetVersionsByFileSQL() string { + return fmt.Sprintf(selectVersionsByFileMySQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable) +} + +func (md *mySQLDialect) GetVersionByIDSQL() string { + return fmt.Sprintf(selectVersionByIDMySQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable) +} + +func (md *mySQLDialect) GetMigrationByIDSQL() string { + return fmt.Sprintf(selectMigrationByIDMySQLDialectSQL, migratorSchema, migratorMigrationsTable) +} diff --git a/db/db_mysql_test.go b/db/db_mysql_test.go index 37e49dd..c30d807 100644 --- a/db/db_mysql_test.go +++ b/db/db_mysql_test.go @@ -14,6 +14,17 @@ func TestDBCreateDialectMysqlDriver(t *testing.T) { assert.IsType(t, &mySQLDialect{}, dialect) } +func TestMySQLLastInsertIdSupported(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "mysql" + dialect := newDialect(config) + lastInsertIDSupported := dialect.LastInsertIDSupported() + + assert.True(t, lastInsertIDSupported) +} + func TestMySQLGetMigrationInsertSQL(t *testing.T) { config, err := config.FromFile("../test/migrator.yaml") assert.Nil(t, err) @@ -24,7 +35,7 @@ func TestMySQLGetMigrationInsertSQL(t *testing.T) { insertMigrationSQL := dialect.GetMigrationInsertSQL() - assert.Equal(t, "insert into migrator.migrator_migrations (name, source_dir, filename, type, db_schema, contents, checksum) values (?, ?, ?, ?, ?, ?, ?)", insertMigrationSQL) + assert.Equal(t, "insert into migrator.migrator_migrations (name, source_dir, filename, type, db_schema, contents, checksum, version_id) values (?, ?, ?, ?, ?, ?, ?, ?)", insertMigrationSQL) } func TestMySQLGetTenantInsertSQLDefault(t *testing.T) { @@ -40,3 +51,91 @@ func TestMySQLGetTenantInsertSQLDefault(t *testing.T) { assert.Equal(t, "insert into migrator.migrator_tenants (name) values (?)", tenantInsertSQL) } + +func TestMySQLGetVersionInsertSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "mysql" + dialect := newDialect(config) + + versionInsertSQL := dialect.GetVersionInsertSQL() + + assert.Equal(t, "insert into migrator.migrator_versions (name) values (?)", versionInsertSQL) +} + +func TestMySQLGetCreateVersionsTableSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "mysql" + dialect := newDialect(config) + + actual := dialect.GetCreateVersionsTableSQL() + expectedDrop := `drop procedure if exists migrator_create_versions` + expectedCall := `call migrator_create_versions()` + expectedProcedure := + ` +create procedure migrator_create_versions() 
+begin +if not exists (select * from information_schema.tables where table_schema = 'migrator' and table_name = 'migrator_versions') then + create table migrator.migrator_versions ( + id serial primary key, + name varchar(200) not null, + created timestamp default now() + ); + alter table migrator.migrator_migrations add column version_id bigint unsigned; + create index migrator_versions_version_id_idx on migrator.migrator_migrations (version_id); + if exists (select * from migrator.migrator_migrations) then + insert into migrator.migrator_versions (name) values ('Initial version'); + -- initial version_id sequence is always 1 + update migrator.migrator_migrations set version_id = 1; + end if; + alter table migrator.migrator_migrations + modify version_id bigint unsigned not null; + alter table migrator.migrator_migrations + add constraint migrator_versions_version_id_fk foreign key (version_id) references migrator.migrator_versions (id) on delete cascade; +end if; +end; +` + + assert.Equal(t, expectedDrop, actual[0]) + assert.Equal(t, expectedProcedure, actual[1]) + assert.Equal(t, expectedCall, actual[2]) +} + +func TestMySQLGetVersionsByFileSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "mysql" + dialect := newDialect(config) + + versionsByFile := dialect.GetVersionsByFileSQL() + + assert.Equal(t, "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id where mv.id in (select version_id from migrator.migrator_migrations where filename = ?) order by vid desc, mid asc", versionsByFile) +} + +func TestMySQLGetVersionByIDSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "mysql" + dialect := newDialect(config) + + versionsByID := dialect.GetVersionByIDSQL() + + assert.Equal(t, "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id where mv.id = ? 
order by mid asc", versionsByID) +} + +func TestMySQLGetMigrationByIDSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "mysql" + dialect := newDialect(config) + + migrationByID := dialect.GetMigrationByIDSQL() + + assert.Equal(t, "select id, name, source_dir, filename, type, db_schema, created, contents, checksum from migrator.migrator_migrations where id = ?", migrationByID) +} diff --git a/db/db_postgresql.go b/db/db_postgresql.go index 42491e7..58fd8e2 100644 --- a/db/db_postgresql.go +++ b/db/db_postgresql.go @@ -11,10 +11,41 @@ type postgreSQLDialect struct { } const ( - insertMigrationPostgreSQLDialectSQL = "insert into %v.%v (name, source_dir, filename, type, db_schema, contents, checksum) values ($1, $2, $3, $4, $5, $6, $7)" - insertTenantPostgreSQLDialectSQL = "insert into %v.%v (name) values ($1)" + insertMigrationPostgreSQLDialectSQL = "insert into %v.%v (name, source_dir, filename, type, db_schema, contents, checksum, version_id) values ($1, $2, $3, $4, $5, $6, $7, $8)" + insertTenantPostgreSQLDialectSQL = "insert into %v.%v (name) values ($1)" + insertVersionPostgreSQLDialectSQL = "insert into %v.%v (name) values ($1) returning id" + selectVersionsByFilePostgreSQLDialectSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id where mv.id in (select version_id from %v.%v where filename = $1) order by vid desc, mid asc" + selectVersionByIDPostgreSQLDialectSQL = "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from %v.%v mv left join %v.%v mm on mv.id = mm.version_id where mv.id = $1 order by mid asc" + selectMigrationByIDPostgreSQLDialectSQL = "select id, name, source_dir, filename, type, db_schema, created, contents, checksum from %v.%v where id = $1" + versionsTableSetupPostgreSQLDialectSQL = ` +do $$ +begin +if not exists (select * from information_schema.tables where table_schema = '%v' and table_name = '%v') then + create table %v.%v ( + id serial primary key, + name varchar(200) not null, + created timestamp default now() + ); + alter table %v.%v add column version_id integer; + create index migrator_versions_version_id_idx on %v.%v (version_id); + if exists (select * from %v.%v) then + insert into %v.%v (name) values ('Initial version'); + -- initial version_id sequence is always 1 + update %v.%v set version_id = 1; + end if; + alter table %v.%v + alter column version_id set not null, + add constraint migrator_versions_version_id_fk foreign key (version_id) references %v.%v (id) on delete cascade; +end if; +end $$; +` ) +// LastInsertIDSupported instructs migrator if Result.LastInsertId() is supported by the DB driver +func (pd *postgreSQLDialect) LastInsertIDSupported() bool { + return false +} + // GetMigrationInsertSQL returns PostgreSQL-specific migration insert SQL statement func (pd *postgreSQLDialect) GetMigrationInsertSQL() string { return fmt.Sprintf(insertMigrationPostgreSQLDialectSQL, migratorSchema, migratorMigrationsTable) @@ -24,3 +55,28 @@ func (pd *postgreSQLDialect) GetMigrationInsertSQL() string { func (pd *postgreSQLDialect) GetTenantInsertSQL() string { return fmt.Sprintf(insertTenantPostgreSQLDialectSQL, migratorSchema, migratorTenantsTable) } + +func (pd *postgreSQLDialect) 
GetVersionInsertSQL() string { + return fmt.Sprintf(insertVersionPostgreSQLDialectSQL, migratorSchema, migratorVersionsTable) +} + +// GetCreateVersionsTableSQL returns PostgreSQL-specific SQL which does: +// 1. create versions table +// 2. alter statement used to add version column to migration +// 3. create initial version if migrations exists (backwards compatibility) +// 4. create not null consttraint on version column +func (pd *postgreSQLDialect) GetCreateVersionsTableSQL() []string { + return []string{fmt.Sprintf(versionsTableSetupPostgreSQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorVersionsTable)} +} + +func (pd *postgreSQLDialect) GetVersionsByFileSQL() string { + return fmt.Sprintf(selectVersionsByFilePostgreSQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable, migratorSchema, migratorMigrationsTable) +} + +func (pd *postgreSQLDialect) GetVersionByIDSQL() string { + return fmt.Sprintf(selectVersionByIDPostgreSQLDialectSQL, migratorSchema, migratorVersionsTable, migratorSchema, migratorMigrationsTable) +} + +func (pd *postgreSQLDialect) GetMigrationByIDSQL() string { + return fmt.Sprintf(selectMigrationByIDPostgreSQLDialectSQL, migratorSchema, migratorMigrationsTable) +} diff --git a/db/db_postgresql_test.go b/db/db_postgresql_test.go index 1d3ffad..54ed17c 100644 --- a/db/db_postgresql_test.go +++ b/db/db_postgresql_test.go @@ -14,6 +14,17 @@ func TestDBCreateDialectPostgreSQLDriver(t *testing.T) { assert.IsType(t, &postgreSQLDialect{}, dialect) } +func TestPostgreSQLLastInsertIdSupported(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "postgres" + dialect := newDialect(config) + lastInsertIDSupported := dialect.LastInsertIDSupported() + + assert.False(t, lastInsertIDSupported) +} + func TestPostgreSQLGetMigrationInsertSQL(t *testing.T) { config, err := config.FromFile("../test/migrator.yaml") assert.Nil(t, err) @@ -24,7 +35,7 @@ func TestPostgreSQLGetMigrationInsertSQL(t *testing.T) { insertMigrationSQL := dialect.GetMigrationInsertSQL() - assert.Equal(t, "insert into migrator.migrator_migrations (name, source_dir, filename, type, db_schema, contents, checksum) values ($1, $2, $3, $4, $5, $6, $7)", insertMigrationSQL) + assert.Equal(t, "insert into migrator.migrator_migrations (name, source_dir, filename, type, db_schema, contents, checksum, version_id) values ($1, $2, $3, $4, $5, $6, $7, $8)", insertMigrationSQL) } func TestPostgreSQLGetTenantInsertSQLDefault(t *testing.T) { @@ -40,3 +51,87 @@ func TestPostgreSQLGetTenantInsertSQLDefault(t *testing.T) { assert.Equal(t, "insert into migrator.migrator_tenants (name) values ($1)", tenantInsertSQL) } + +func TestPostgreSQLGetVersionInsertSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "postgres" + dialect := newDialect(config) + + versionInsertSQL := dialect.GetVersionInsertSQL() + + assert.Equal(t, "insert into migrator.migrator_versions (name) values ($1) returning id", versionInsertSQL) +} + +func TestPostgreSQLGetCreateVersionsTableSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + 
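// note: the PostgreSQL bootstrap DDL is emitted as a single do $$ ... end $$; block, + // so the dialect returns a one-element slice and only actual[0] is asserted below + 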
config.Driver = "postgres" + dialect := newDialect(config) + + actual := dialect.GetCreateVersionsTableSQL() + + expected := + ` +do $$ +begin +if not exists (select * from information_schema.tables where table_schema = 'migrator' and table_name = 'migrator_versions') then + create table migrator.migrator_versions ( + id serial primary key, + name varchar(200) not null, + created timestamp default now() + ); + alter table migrator.migrator_migrations add column version_id integer; + create index migrator_versions_version_id_idx on migrator.migrator_migrations (version_id); + if exists (select * from migrator.migrator_migrations) then + insert into migrator.migrator_versions (name) values ('Initial version'); + -- initial version_id sequence is always 1 + update migrator.migrator_migrations set version_id = 1; + end if; + alter table migrator.migrator_migrations + alter column version_id set not null, + add constraint migrator_versions_version_id_fk foreign key (version_id) references migrator.migrator_versions (id) on delete cascade; +end if; +end $$; +` + + assert.Equal(t, expected, actual[0]) +} + +func TestPostgreSQLGetVersionsByFileSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "postgres" + dialect := newDialect(config) + + versionsByFile := dialect.GetVersionsByFileSQL() + + assert.Equal(t, "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id where mv.id in (select version_id from migrator.migrator_migrations where filename = $1) order by vid desc, mid asc", versionsByFile) +} + +func TestPostgreSQLGetVersionByIDSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "postgres" + dialect := newDialect(config) + + versionsByID := dialect.GetVersionByIDSQL() + + assert.Equal(t, "select mv.id as vid, mv.name as vname, mv.created as vcreated, mm.id as mid, mm.name, mm.source_dir, mm.filename, mm.type, mm.db_schema, mm.created, mm.contents, mm.checksum from migrator.migrator_versions mv left join migrator.migrator_migrations mm on mv.id = mm.version_id where mv.id = $1 order by mid asc", versionsByID) +} + +func TestPostgreSQLGetMigrationByIDSQL(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + config.Driver = "postgres" + dialect := newDialect(config) + + migrationByID := dialect.GetMigrationByIDSQL() + + assert.Equal(t, "select id, name, source_dir, filename, type, db_schema, created, contents, checksum from migrator.migrator_migrations where id = $1", migrationByID) +} diff --git a/db/db_test.go b/db/db_test.go index af4a30c..bbe202e 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -14,6 +14,11 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + existingVersion types.Version + existingDBMigration types.DBMigration +) + func newTestContext() context.Context { ctx := context.TODO() ctx = context.WithValue(ctx, common.RequestIDKey{}, time.Now().Nanosecond()) @@ -62,12 +67,12 @@ func TestGetTenants(t *testing.T) { tenants := connector.GetTenants() assert.True(t, len(tenants) >= 3) - assert.Contains(t, tenants, "abc") - assert.Contains(t, tenants, "def") - assert.Contains(t, tenants, "xyz") + assert.Contains(t, tenants, types.Tenant{Name: "abc"}) + assert.Contains(t, tenants, 
types.Tenant{Name: "def"}) + assert.Contains(t, tenants, types.Tenant{Name: "xyz"}) } -func TestApplyMigrations(t *testing.T) { +func TestCreateVersion(t *testing.T) { config, err := config.FromFile("../test/migrator.yaml") assert.Nil(t, err) @@ -109,17 +114,21 @@ func TestApplyMigrations(t *testing.T) { migrationsToApply := []types.Migration{public1, public2, public3, tenant1, tenant2, tenant3, public4, public5, tenant4} - results := connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) - - assert.Equal(t, noOfTenants, results.Tenants) - assert.Equal(t, 3, results.SingleMigrations) - assert.Equal(t, 2, results.SingleScripts) - assert.Equal(t, 3, results.TenantMigrations) - assert.Equal(t, 1, results.TenantScripts) - assert.Equal(t, noOfTenants*3, results.TenantMigrationsTotal) - assert.Equal(t, noOfTenants*1, results.TenantScriptsTotal) - assert.Equal(t, noOfTenants*3+3, results.MigrationsGrandTotal) - assert.Equal(t, noOfTenants*1+2, results.ScriptsGrandTotal) + results, version := connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) + + assert.NotNil(t, version) + assert.True(t, version.ID > 0) + assert.Equal(t, "commit-sha", version.Name) + assert.Equal(t, results.MigrationsGrandTotal+results.ScriptsGrandTotal, int32(len(version.DBMigrations))) + assert.Equal(t, int32(noOfTenants), results.Tenants) + assert.Equal(t, int32(3), results.SingleMigrations) + assert.Equal(t, int32(2), results.SingleScripts) + assert.Equal(t, int32(3), results.TenantMigrations) + assert.Equal(t, int32(1), results.TenantScripts) + assert.Equal(t, int32(noOfTenants*3), results.TenantMigrationsTotal) + assert.Equal(t, int32(noOfTenants*1), results.TenantScriptsTotal) + assert.Equal(t, int32(noOfTenants*3+3), results.MigrationsGrandTotal) + assert.Equal(t, int32(noOfTenants*1+2), results.ScriptsGrandTotal) dbMigrationsAfter := connector.GetAppliedMigrations() lenAfter := len(dbMigrationsAfter) @@ -130,7 +139,7 @@ func TestApplyMigrations(t *testing.T) { assert.Equal(t, expected, lenAfter-lenBefore) } -func TestApplyMigrationsEmptyMigrationArray(t *testing.T) { +func TestCreateVersionEmptyMigrationArray(t *testing.T) { config, err := config.FromFile("../test/migrator.yaml") assert.Nil(t, err) @@ -139,13 +148,14 @@ func TestApplyMigrationsEmptyMigrationArray(t *testing.T) { migrationsToApply := []types.Migration{} - results := connector.ApplyMigrations(types.ModeTypeApply, migrationsToApply) - - assert.Equal(t, 0, results.MigrationsGrandTotal) - assert.Equal(t, 0, results.ScriptsGrandTotal) + results, version := connector.CreateVersion("commit-sha", types.ActionApply, false, migrationsToApply) + // empty migrations slice - no version created + assert.Nil(t, version) + assert.Equal(t, int32(0), results.MigrationsGrandTotal) + assert.Equal(t, int32(0), results.ScriptsGrandTotal) } -func TestApplyMigrationsDryRunMode(t *testing.T) { +func TestCreateVersionDryRunMode(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -154,32 +164,40 @@ func TestApplyMigrationsDryRunMode(t *testing.T) { dialect := newDialect(config) connector := baseConnector{newTestContext(), config, dialect, db} - time := time.Now().UnixNano() - m := types.Migration{Name: fmt.Sprintf("%v.sql", time), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", time), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: 
fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{m} tenant := "tenantname" tenants := sqlmock.NewRows([]string{"name"}).AddRow(tenant) mock.ExpectQuery("select").WillReturnRows(tenants) mock.ExpectBegin() - mock.ExpectPrepare("insert into") - mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum).WillReturnResult(sqlmock.NewResult(1, 1)) - + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into migrator.migrator_migrations").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}).AddRow("123", "vname", time.Now(), "456", m.Name, m.SourceDir, m.File, m.MigrationType, tenant, time.Now(), m.Contents, m.CheckSum) + mock.ExpectQuery("select").WillReturnRows(rows) // dry-run mode calls rollback instead of commit mock.ExpectRollback() // however the results contain correct dry-run data like number of applied migrations/scripts - results := connector.ApplyMigrations(types.ModeTypeDryRun, migrationsToApply) - - assert.Equal(t, 1, results.MigrationsGrandTotal) + results, version := connector.CreateVersion("commit-sha", types.ActionApply, true, migrationsToApply) + assert.NotNil(t, version) + assert.True(t, version.ID > 0) + assert.Equal(t, results.MigrationsGrandTotal+results.ScriptsGrandTotal, int32(len(version.DBMigrations))) + assert.Equal(t, int32(1), results.MigrationsGrandTotal) if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } -func TestApplyMigrationsSyncMode(t *testing.T) { +func TestCreateVersionSyncMode(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -188,22 +206,31 @@ func TestApplyMigrationsSyncMode(t *testing.T) { dialect := newDialect(config) connector := baseConnector{newTestContext(), config, dialect, db} - time := time.Now().UnixNano() - m := types.Migration{Name: fmt.Sprintf("%v.sql", time), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", time), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{m} tenant := "tenantname" tenants := sqlmock.NewRows([]string{"name"}).AddRow(tenant) mock.ExpectQuery("select").WillReturnRows(tenants) mock.ExpectBegin() - mock.ExpectPrepare("insert into") - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum).WillReturnResult(sqlmock.NewResult(1, 1)) + // version + 
mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}).AddRow("123", "vname", time.Now(), "456", m.Name, m.SourceDir, m.File, m.MigrationType, tenant, time.Now(), m.Contents, m.CheckSum) + mock.ExpectQuery("select").WillReturnRows(rows) mock.ExpectCommit() // sync the results contain correct data like number of applied migrations/scripts - results := connector.ApplyMigrations(types.ModeTypeSync, migrationsToApply) - - assert.Equal(t, 1, results.MigrationsGrandTotal) + results, version := connector.CreateVersion("commit-sha", types.ActionSync, false, migrationsToApply) + assert.NotNil(t, version) + assert.True(t, version.ID > 0) + assert.Equal(t, results.MigrationsGrandTotal+results.ScriptsGrandTotal, int32(len(version.DBMigrations))) + assert.Equal(t, int32(1), results.MigrationsGrandTotal) if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) @@ -262,7 +289,7 @@ func TestGetSchemaPlaceHolderOverride(t *testing.T) { assert.Equal(t, "[schema]", placeholder) } -func TestAddTenantAndApplyMigrations(t *testing.T) { +func TestCreateTenant(t *testing.T) { config, err := config.FromFile("../test/migrator.yaml") assert.Nil(t, err) @@ -281,16 +308,21 @@ func TestAddTenantAndApplyMigrations(t *testing.T) { uniqueTenant := fmt.Sprintf("new_test_tenant_%v", time.Now().UnixNano()) - results := connector.AddTenantAndApplyMigrations(types.ModeTypeApply, uniqueTenant, migrationsToApply) + results, version := connector.CreateTenant("commit-sha", types.ActionApply, false, uniqueTenant, migrationsToApply) + + assert.NotNil(t, version) + assert.True(t, version.ID > 0) + assert.Equal(t, "commit-sha", version.Name) + assert.Equal(t, results.MigrationsGrandTotal+results.ScriptsGrandTotal, int32(len(version.DBMigrations))) // applied only for one tenant - the newly added one - assert.Equal(t, 1, results.Tenants) + assert.Equal(t, int32(1), results.Tenants) // just one tenant so total number of tenant migrations is equal to tenant migrations - assert.Equal(t, 3, results.TenantMigrations) - assert.Equal(t, 3, results.TenantMigrationsTotal) + assert.Equal(t, int32(3), results.TenantMigrations) + assert.Equal(t, int32(3), results.TenantMigrationsTotal) } -func TestAddTenantAndApplyMigrationsDryRunMode(t *testing.T) { +func TestCreateTenantDryRunMode(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -299,33 +331,43 @@ func TestAddTenantAndApplyMigrationsDryRunMode(t *testing.T) { dialect := newDialect(config) connector := baseConnector{newTestContext(), config, dialect, db} - time := time.Now().UnixNano() - m := types.Migration{Name: fmt.Sprintf("%v.sql", time), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", time), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: 
types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{m} tenant := "tenantname" + mock.ExpectBegin() - mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(0, 0)) + // tenant mock.ExpectPrepare("insert into") mock.ExpectPrepare("insert into").ExpectExec().WithArgs(tenant).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into") - mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum).WillReturnResult(sqlmock.NewResult(1, 1)) - + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectExec("insert into").WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectPrepare("insert into migrator.migrator_migrations").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}).AddRow("123", "vname", time.Now(), "456", m.Name, m.SourceDir, m.File, m.MigrationType, tenant, time.Now(), m.Contents, m.CheckSum) + mock.ExpectQuery("select").WillReturnRows(rows) // dry-run mode calls rollback instead of commit mock.ExpectRollback() // however the results contain correct dry-run data like number of applied migrations/scripts - results := connector.AddTenantAndApplyMigrations(types.ModeTypeDryRun, tenant, migrationsToApply) - - assert.Equal(t, 1, results.MigrationsGrandTotal) + results, version := connector.CreateTenant("commit-sha", types.ActionApply, true, tenant, migrationsToApply) + assert.NotNil(t, version) + assert.True(t, version.ID > 0) + assert.Equal(t, results.MigrationsGrandTotal+results.ScriptsGrandTotal, int32(len(version.DBMigrations))) + assert.Equal(t, int32(1), results.MigrationsGrandTotal) if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } -func TestAddTenantAndApplyMigrationsSyncMode(t *testing.T) { +func TestCreateTenantSyncMode(t *testing.T) { db, mock, err := sqlmock.New() assert.Nil(t, err) @@ -334,23 +376,33 @@ func TestAddTenantAndApplyMigrationsSyncMode(t *testing.T) { dialect := newDialect(config) connector := baseConnector{newTestContext(), config, dialect, db} - time := time.Now().UnixNano() - m := types.Migration{Name: fmt.Sprintf("%v.sql", time), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", time), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} + tn := time.Now().UnixNano() + m := types.Migration{Name: fmt.Sprintf("%v.sql", tn), SourceDir: "tenants", File: fmt.Sprintf("tenants/%v.sql", tn), MigrationType: types.MigrationTypeTenantMigration, Contents: "insert into {schema}.settings values (456, '456') "} migrationsToApply := []types.Migration{m} tenant := "tenantname" mock.ExpectBegin() - mock.ExpectExec("create schema").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("create 
schema").WillReturnResult(sqlmock.NewResult(0, 0)) + // tenant mock.ExpectPrepare("insert into") - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(tenant).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectPrepare("insert into") - mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectPrepare("insert into").ExpectExec().WithArgs(tenant).WillReturnResult(sqlmock.NewResult(0, 0)) + // version + mock.ExpectPrepare("insert into migrator.migrator_versions") + mock.ExpectPrepare("insert into migrator.migrator_versions").ExpectQuery().WithArgs("commit-sha") + // migration + mock.ExpectPrepare("insert into migrator.migrator_migrations") + mock.ExpectPrepare("insert into").ExpectExec().WithArgs(m.Name, m.SourceDir, m.File, m.MigrationType, tenant, m.Contents, m.CheckSum, 0).WillReturnResult(sqlmock.NewResult(0, 0)) + // get version + rows := sqlmock.NewRows([]string{"vid", "vname", "vcreated", "mid", "name", "source_dir", "filename", "type", "db_schema", "created", "contents", "checksum"}).AddRow("123", "vname", time.Now(), "456", m.Name, m.SourceDir, m.File, m.MigrationType, tenant, time.Now(), m.Contents, m.CheckSum) + mock.ExpectQuery("select").WillReturnRows(rows) mock.ExpectCommit() // sync results contain correct data like number of applied migrations/scripts - results := connector.AddTenantAndApplyMigrations(types.ModeTypeSync, tenant, migrationsToApply) - - assert.Equal(t, 1, results.MigrationsGrandTotal) + results, version := connector.CreateTenant("commit-sha", types.ActionSync, false, tenant, migrationsToApply) + assert.NotNil(t, version) + assert.True(t, version.ID > 0) + assert.Equal(t, results.MigrationsGrandTotal+results.ScriptsGrandTotal, int32(len(version.DBMigrations))) + assert.Equal(t, int32(1), results.MigrationsGrandTotal) if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) @@ -369,3 +421,83 @@ func TestGetTenantInsertSQLOverride(t *testing.T) { assert.Equal(t, "insert into someschema.sometable (somename) values ($1)", tenantInsertSQL) } + +func TestGetVersions(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + connector := New(newTestContext(), config) + defer connector.Dispose() + + versions := connector.GetVersions() + + assert.True(t, len(versions) >= 2) + // versions are sorted from newest (highest ID) to oldest (lowest ID) + assert.True(t, versions[0].ID > versions[1].ID) + + existingVersion = versions[0] + existingDBMigration = existingVersion.DBMigrations[0] +} + +func TestGetVersionsByFile(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + connector := New(newTestContext(), config) + defer connector.Dispose() + + versions := connector.GetVersionsByFile(existingVersion.DBMigrations[0].File) + version := versions[0] + assert.Equal(t, existingVersion.ID, version.ID) + assert.Equal(t, existingVersion.DBMigrations[0].File, version.DBMigrations[0].File) + assert.True(t, len(version.DBMigrations) > 0) +} + +func TestGetVersionByID(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + connector := New(newTestContext(), config) + defer connector.Dispose() + + version, err := connector.GetVersionByID(existingVersion.ID) + assert.Nil(t, err) + assert.Equal(t, existingVersion.ID, version.ID) + assert.True(t, len(version.DBMigrations) > 0) +} + 
+func TestGetVersionByIDNotFound(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + connector := New(newTestContext(), config) + defer connector.Dispose() + + version, err := connector.GetVersionByID(-1) + assert.Nil(t, version) + assert.Equal(t, "Version not found ID: -1", err.Error()) +} + +func TestGetDBMigrationByID(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + connector := New(newTestContext(), config) + defer connector.Dispose() + + dbMigration, err := connector.GetDBMigrationByID(existingDBMigration.ID) + assert.Nil(t, err) + assert.Equal(t, existingDBMigration.ID, dbMigration.ID) +} + +func TestGetDBMigrationByIDNotFound(t *testing.T) { + config, err := config.FromFile("../test/migrator.yaml") + assert.Nil(t, err) + + connector := New(newTestContext(), config) + defer connector.Dispose() + + dbMigration, err := connector.GetDBMigrationByID(-1) + assert.Nil(t, dbMigration) + assert.Equal(t, "DB migration not found ID: -1", err.Error()) +} diff --git a/go.mod b/go.mod index 891facf..8bc27a1 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gin-gonic/gin v1.5.0 github.com/go-playground/universal-translator v0.17.0 // indirect github.com/go-sql-driver/mysql v1.5.0 + github.com/graph-gophers/graphql-go v0.0.0-20200207002730-8334863f2c8b github.com/leodido/go-urn v1.2.0 // indirect github.com/lib/pq v1.3.0 github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum index fd113bb..ac8caf6 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/graph-gophers/graphql-go v0.0.0-20200207002730-8334863f2c8b h1:fRjb9ncV+Aad/w56TstaCM/xGusFsfDfeGhhc+k4IBg= +github.com/graph-gophers/graphql-go v0.0.0-20200207002730-8334863f2c8b/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -38,6 +40,8 @@ github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HN github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/loader/azureblob_loader_test.go b/loader/azureblob_loader_test.go index a99ddf1..0984795 100644 --- a/loader/azureblob_loader_test.go +++ 
b/loader/azureblob_loader_test.go @@ -2,6 +2,7 @@ package loader import ( "context" + "fmt" "os" "testing" @@ -10,14 +11,14 @@ import ( ) func TestAzureGetSourceMigrations(t *testing.T) { - - travis := os.Getenv("TRAVIS") - if len(travis) > 0 { - t.Skip("Does not work on travis due to Azure Storage Account credentials required") - } + // migrator implements env variable substitution and normally we would use: + // "https://${AZURE_STORAGE_ACCOUNT}.blob.core.windows.net/mycontainer" + // however below we are creating the Config struct directly + // and that's why we need to build correct URL ourselves + baseLocation := fmt.Sprintf("https://%v.blob.core.windows.net/mycontainer", os.Getenv("AZURE_STORAGE_ACCOUNT")) config := &config.Config{ - BaseLocation: "https://storageaccountname.blob.core.windows.net/mycontainer", + BaseLocation: baseLocation, SingleMigrations: []string{"migrations/config", "migrations/ref"}, TenantMigrations: []string{"migrations/tenants"}, SingleScripts: []string{"migrations/config-scripts"}, diff --git a/loader/s3_loader.go b/loader/s3_loader.go index e457459..3bc5c0a 100644 --- a/loader/s3_loader.go +++ b/loader/s3_loader.go @@ -66,15 +66,14 @@ func (s3l *s3Loader) getObjectList(client s3iface.S3API, prefixes []string) []*s MaxKeys: aws.Int64(1000), } - pageNum := 0 err := client.ListObjectsV2Pages(input, func(page *s3.ListObjectsV2Output, lastPage bool) bool { - pageNum++ + for _, o := range page.Contents { objects = append(objects, o.Key) } - return pageNum <= 10 + return !lastPage }) if err != nil { diff --git a/loader/s3_loader_test.go b/loader/s3_loader_test.go index 960841c..518dd7c 100644 --- a/loader/s3_loader_test.go +++ b/loader/s3_loader_test.go @@ -62,7 +62,7 @@ func TestS3GetSourceMigrations(t *testing.T) { mock := &mockS3Client{} config := &config.Config{ - BaseLocation: "s3://lukasz-budnik-migrator-us-east-1", + BaseLocation: "s3://your-bucket-migrator", SingleMigrations: []string{"migrations/config", "migrations/ref"}, TenantMigrations: []string{"migrations/tenants"}, SingleScripts: []string{"migrations/config-scripts"}, diff --git a/migrator.go b/migrator.go index c18bbb6..6575889 100644 --- a/migrator.go +++ b/migrator.go @@ -59,7 +59,7 @@ func main() { } gin.SetMode(gin.ReleaseMode) - versionInfo := &types.VersionInfo{Release: GitBranch, CommitSha: GitCommitSha, CommitDate: GitCommitDate, APIVersions: []string{"v1"}} + versionInfo := &types.VersionInfo{Release: GitBranch, CommitSha: GitCommitSha, CommitDate: GitCommitDate, APIVersions: []string{"v1", "v2"}} g := server.SetupRouter(versionInfo, cfg, createCoordinator) if err := g.Run(":" + server.GetPort(cfg)); err != nil { common.Log("ERROR", "Error starting migrator: %v", err) diff --git a/server/server.go b/server/server.go index 57a3aea..da3c75f 100644 --- a/server/server.go +++ b/server/server.go @@ -3,20 +3,20 @@ package server import ( "context" "fmt" - "net" "net/http" - "os" "runtime/debug" "strings" "time" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" + "github.com/graph-gophers/graphql-go" "gopkg.in/go-playground/validator.v9" "github.com/lukaszbudnik/migrator/common" "github.com/lukaszbudnik/migrator/config" "github.com/lukaszbudnik/migrator/coordinator" + "github.com/lukaszbudnik/migrator/data" "github.com/lukaszbudnik/migrator/types" ) @@ -61,6 +61,10 @@ func requestIDHandler() gin.HandlerFunc { } ctx := context.WithValue(c.Request.Context(), common.RequestIDKey{}, requestID) c.Request = c.Request.WithContext(ctx) + if strings.Contains(c.Request.URL.Path, "/v1/") { + 
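// requests to the deprecated /v1 API are flagged starting with v2020.1.0 and clients are pointed + // to the v2 successor via the Deprecation and Link headers below + 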
c.Header("Deprecation", `version="v2020.1.0"`) + c.Header("Link", `; rel="successor-version"`) + } c.Next() } } @@ -69,29 +73,11 @@ func recovery() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - // Check for a broken connection, as it is not really a - // condition that warrants a panic stack trace. - var brokenPipe bool - if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { - if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - brokenPipe = true - } - } - } - - // If the connection is dead, we can't write a status to it. - if brokenPipe { - common.LogPanic(c.Request.Context(), "Broken pipe: %v", err) - c.Error(err.(error)) // nolint: errcheck - c.Abort() - } else { - common.LogPanic(c.Request.Context(), "Panic recovered: %v", err) - if gin.IsDebugging() { - debug.PrintStack() - } - c.AbortWithStatusJSON(http.StatusInternalServerError, &errorResponse{err.(string), nil}) + common.LogPanic(c.Request.Context(), "Panic recovered: %v", err) + if gin.IsDebugging() { + debug.PrintStack() } + c.AbortWithStatusJSON(http.StatusInternalServerError, &errorResponse{err.(string), nil}) } }() c.Next() @@ -118,7 +104,7 @@ func configHandler(c *gin.Context, config *config.Config, newCoordinator func(co func migrationsSourceHandler(c *gin.Context, config *config.Config, newCoordinator func(context.Context, *config.Config) coordinator.Coordinator) { coordinator := newCoordinator(c.Request.Context(), config) defer coordinator.Dispose() - migrations := coordinator.GetSourceMigrations() + migrations := coordinator.GetSourceMigrations(nil) common.LogInfo(c.Request.Context(), "Returning source migrations: %v", len(migrations)) c.JSON(http.StatusOK, migrations) } @@ -172,9 +158,15 @@ func migrationsPostHandler(c *gin.Context, config *config.Config, newCoordinator func tenantsGetHandler(c *gin.Context, config *config.Config, newCoordinator func(context.Context, *config.Config) coordinator.Coordinator) { coordinator := newCoordinator(c.Request.Context(), config) defer coordinator.Dispose() + // starting v2019.1.0 GetTenants returns a slice of Tenant struct + // /v1 API returns a slice of strings and we must maintain backward compatibility tenants := coordinator.GetTenants() + tenantNames := []string{} + for _, t := range tenants { + tenantNames = append(tenantNames, t.Name) + } common.LogInfo(c.Request.Context(), "Returning tenants: %v", len(tenants)) - c.JSON(http.StatusOK, tenants) + c.JSON(http.StatusOK, tenantNames) } func tenantsPostHandler(c *gin.Context, config *config.Config, newCoordinator func(context.Context, *config.Config) coordinator.Coordinator) { @@ -215,6 +207,32 @@ func tenantsPostHandler(c *gin.Context, config *config.Config, newCoordinator fu c.JSON(http.StatusOK, response) } +func schemaHandler(c *gin.Context, config *config.Config, newCoordinator func(context.Context, *config.Config) coordinator.Coordinator) { + c.String(http.StatusOK, strings.TrimSpace(data.SchemaDefinition)) +} + +// GraphQL endpoint +func serviceHandler(c *gin.Context, config *config.Config, newCoordinator func(context.Context, *config.Config) coordinator.Coordinator) { + var params struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + Variables map[string]interface{} `json:"variables"` + } + if err := c.ShouldBindJSON(¶ms); err != nil { + common.LogError(c.Request.Context(), "Bad request: %v", err.Error()) + 
c.AbortWithStatusJSON(http.StatusBadRequest, errorResponse{"Invalid request, please see documentation for valid JSON payload", nil}) + return + } + + coordinator := newCoordinator(c.Request.Context(), config) + defer coordinator.Dispose() + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} + schema := graphql.MustParseSchema(data.SchemaDefinition, &data.RootResolver{Coordinator: coordinator}, opts...) + + response := schema.Exec(c.Request.Context(), params.Query, params.OperationName, params.Variables) + c.JSON(http.StatusOK, response) +} + // SetupRouter setups router func SetupRouter(versionInfo *types.VersionInfo, config *config.Config, newCoordinator func(ctx context.Context, config *config.Config) coordinator.Coordinator) *gin.Engine { r := gin.New() @@ -247,5 +265,10 @@ func SetupRouter(versionInfo *types.VersionInfo, config *config.Config, newCoord v1.GET("/migrations/applied", makeHandler(config, newCoordinator, migrationsAppliedHandler)) v1.POST("/migrations", makeHandler(config, newCoordinator, migrationsPostHandler)) + v2 := r.Group(config.PathPrefix + "/v2") + v2.GET("/config", makeHandler(config, newCoordinator, configHandler)) + v2.GET("/schema", makeHandler(config, newCoordinator, schemaHandler)) + v2.POST("/service", makeHandler(config, newCoordinator, serviceHandler)) + return r } diff --git a/server/server_mocks.go b/server/server_mocks.go index 2e4d7b6..08cbea2 100644 --- a/server/server_mocks.go +++ b/server/server_mocks.go @@ -3,8 +3,11 @@ package server import ( "context" "fmt" + "strings" "time" + "github.com/graph-gophers/graphql-go" + "github.com/lukaszbudnik/migrator/config" "github.com/lukaszbudnik/migrator/coordinator" "github.com/lukaszbudnik/migrator/types" @@ -28,7 +31,15 @@ func newMockedErrorCoordinator(errorThreshold int) func(context.Context, *config func (m *mockedCoordinator) Dispose() { } -func (m *mockedCoordinator) GetSourceMigrations() []types.Migration { +func (m *mockedCoordinator) CreateTenant(string, types.Action, bool, string) *types.CreateResults { + return &types.CreateResults{Summary: &types.MigrationResults{}, Version: &types.Version{}} +} + +func (m *mockedCoordinator) CreateVersion(string, types.Action, bool) *types.CreateResults { + return &types.CreateResults{Summary: &types.MigrationResults{}, Version: &types.Version{}} +} + +func (m *mockedCoordinator) GetSourceMigrations(_ *coordinator.SourceMigrationFilters) []types.Migration { if m.errorThreshold == m.counter { panic(fmt.Sprintf("Mocked Error Disk Loader: threshold %v reached", m.errorThreshold)) } @@ -38,15 +49,46 @@ func (m *mockedCoordinator) GetSourceMigrations() []types.Migration { return []types.Migration{m1, m2} } +func (m *mockedCoordinator) GetSourceMigrationByFile(file string) (*types.Migration, error) { + i := strings.Index(file, "/") + sourceDir := file[:i] + name := file[i+1:] + m1 := types.Migration{Name: name, SourceDir: sourceDir, File: file, MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc"} + return &m1, nil +} + func (m *mockedCoordinator) GetAppliedMigrations() []types.MigrationDB { m1 := types.Migration{Name: "201602220000.sql", SourceDir: "source", File: "source/201602220000.sql", MigrationType: types.MigrationTypeSingleMigration, Contents: "select abc", CheckSum: "sha256"} d1 := time.Date(2016, 02, 22, 16, 41, 1, 123, time.UTC) - ms := []types.MigrationDB{{Migration: m1, Schema: "source", AppliedAt: d1}} + ms := []types.MigrationDB{{Migration: m1, Schema: "source", AppliedAt: graphql.Time{Time: d1}, Created: graphql.Time{Time: 
d1}}} return ms } -func (m *mockedCoordinator) GetTenants() []string { - return []string{"a", "b", "c"} +// part of interface but not used in server tests - tested in data package +func (m *mockedCoordinator) GetDBMigrationByID(ID int32) (*types.DBMigration, error) { + return nil, nil +} + +func (m *mockedCoordinator) GetTenants() []types.Tenant { + a := types.Tenant{Name: "a"} + b := types.Tenant{Name: "b"} + c := types.Tenant{Name: "c"} + return []types.Tenant{a, b, c} +} + +// part of interface but not used in server tests - tested in data package +func (m *mockedCoordinator) GetVersions() []types.Version { + return []types.Version{} +} + +// part of interface but not used in server tests - tested in data package +func (m *mockedCoordinator) GetVersionsByFile(file string) []types.Version { + return []types.Version{} +} + +// part of interface but not used in server tests - tested in data package +func (m *mockedCoordinator) GetVersionByID(ID int32) (*types.Version, error) { + return nil, nil } func (m *mockedCoordinator) VerifySourceMigrationsCheckSums() (bool, []types.Migration) { @@ -59,9 +101,9 @@ func (m *mockedCoordinator) VerifySourceMigrationsCheckSums() (bool, []types.Mig } func (m *mockedCoordinator) ApplyMigrations(types.MigrationsModeType) (*types.MigrationResults, []types.Migration) { - return &types.MigrationResults{}, m.GetSourceMigrations() + return &types.MigrationResults{}, m.GetSourceMigrations(nil) } func (m *mockedCoordinator) AddTenantAndApplyMigrations(types.MigrationsModeType, string) (*types.MigrationResults, []types.Migration) { - return &types.MigrationResults{}, m.GetSourceMigrations()[1:] + return &types.MigrationResults{}, m.GetSourceMigrations(nil)[1:] } diff --git a/server/server_test.go b/server/server_test.go index 05f296b..275f8ae 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/gin-gonic/gin" "github.com/lukaszbudnik/migrator/config" "github.com/lukaszbudnik/migrator/coordinator" + "github.com/lukaszbudnik/migrator/data" "github.com/lukaszbudnik/migrator/types" "github.com/stretchr/testify/assert" ) @@ -21,11 +22,16 @@ var ( configFileOverrides = "../test/migrator-overrides.yaml" ) -func newTestRequest(method, url string, body io.Reader) (*http.Request, error) { +func newTestRequestV1(method, url string, body io.Reader) (*http.Request, error) { versionURL := "/v1" + url return http.NewRequest(method, versionURL, body) } +func newTestRequestV2(method, url string, body io.Reader) (*http.Request, error) { + versionURL := "/v2" + url + return http.NewRequest(method, versionURL, body) +} + func testSetupRouter(config *config.Config, newCoordinator func(ctx context.Context, config *config.Config) coordinator.Coordinator) *gin.Engine { versionInfo := &types.VersionInfo{Release: "GitBranch", CommitSha: "GitCommitSha", CommitDate: "2020-01-08T09:56:41+01:00", APIVersions: []string{"v1"}} gin.SetMode(gin.ReleaseMode) @@ -78,6 +84,8 @@ func TestRootWithPathPrefix(t *testing.T) { assert.Equal(t, `{"release":"GitBranch","commitSha":"GitCommitSha","commitDate":"2020-01-08T09:56:41+01:00","apiVersions":["v1"]}`, strings.TrimSpace(w.Body.String())) } +// /v1 API + // section /config func TestConfigRoute(t *testing.T) { @@ -87,11 +95,14 @@ func TestConfigRoute(t *testing.T) { router := testSetupRouter(config, nil) w := httptest.NewRecorder() - req, _ := newTestRequest("GET", "/config", nil) + req, _ := newTestRequestV1("GET", "/config", nil) router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) 
assert.Equal(t, "application/x-yaml; charset=utf-8", w.HeaderMap["Content-Type"][0]) + // confirm /v1 has Deprecated and Sunset headers + assert.Equal(t, `version="v2020.1.0"`, w.HeaderMap["Deprecation"][0]) + assert.Equal(t, `; rel="successor-version"`, w.HeaderMap["Link"][0]) assert.Equal(t, config.String(), strings.TrimSpace(w.Body.String())) } @@ -104,7 +115,7 @@ func TestDiskMigrationsRoute(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) w := httptest.NewRecorder() - req, _ := newTestRequest("GET", "/migrations/source", nil) + req, _ := newTestRequestV1("GET", "/migrations/source", nil) router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -120,14 +131,14 @@ func TestAppliedMigrationsRoute(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) - req, _ := newTestRequest(http.MethodGet, "/migrations/applied", nil) + req, _ := newTestRequestV1(http.MethodGet, "/migrations/applied", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "application/json; charset=utf-8", w.HeaderMap["Content-Type"][0]) - assert.Equal(t, `[{"name":"201602220000.sql","sourceDir":"source","file":"source/201602220000.sql","migrationType":1,"contents":"select abc","checkSum":"sha256","schema":"source","appliedAt":"2016-02-22T16:41:01.000000123Z"}]`, strings.TrimSpace(w.Body.String())) + assert.Equal(t, `[{"name":"201602220000.sql","sourceDir":"source","file":"source/201602220000.sql","migrationType":1,"contents":"select abc","checkSum":"sha256","id":0,"schema":"source","appliedAt":"2016-02-22T16:41:01.000000123Z","created":"2016-02-22T16:41:01.000000123Z"}]`, strings.TrimSpace(w.Body.String())) } // section /migrations @@ -139,7 +150,7 @@ func TestMigrationsPostRoute(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"mode": "apply", "response": "full"}`) - req, _ := newTestRequest(http.MethodPost, "/migrations", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/migrations", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -156,7 +167,7 @@ func TestMigrationsPostRouteSummaryResponse(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"mode": "apply", "response": "summary"}`) - req, _ := newTestRequest(http.MethodPost, "/migrations", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/migrations", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -174,7 +185,7 @@ func TestMigrationsPostRouteListResponse(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"mode": "apply", "response": "list"}`) - req, _ := newTestRequest(http.MethodPost, "/migrations", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/migrations", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -193,7 +204,7 @@ func TestMigrationsPostRouteBadRequest(t *testing.T) { // response is invalid json := []byte(`{"mode": "apply", "response": "abc"}`) - req, _ := newTestRequest(http.MethodPost, "/migrations", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/migrations", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -210,7 +221,7 @@ func TestMigrationsPostRouteCheckSumError(t *testing.T) { router := testSetupRouter(config, newMockedErrorCoordinator(0)) json := []byte(`{"mode": "apply", "response": "full"}`) - req, _ := 
newTestRequest(http.MethodPost, "/migrations", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/migrations", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -229,7 +240,7 @@ func TestTenantsGetRoute(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) w := httptest.NewRecorder() - req, _ := newTestRequest("GET", "/tenants", nil) + req, _ := newTestRequestV1("GET", "/tenants", nil) router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -244,7 +255,7 @@ func TestTenantsPostRoute(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"name": "new_tenant", "response": "full", "mode":"dry-run"}`) - req, _ := newTestRequest(http.MethodPost, "/tenants", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/tenants", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -261,7 +272,7 @@ func TestTenantsPostRouteSummaryResponse(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"name": "new_tenant", "response": "summary", "mode":"dry-run"}`) - req, _ := newTestRequest(http.MethodPost, "/tenants", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/tenants", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -279,7 +290,7 @@ func TestTenantsPostRouteListResponse(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"name": "new_tenant", "response": "list", "mode":"dry-run"}`) - req, _ := newTestRequest(http.MethodPost, "/tenants", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/tenants", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -297,7 +308,7 @@ func TestTenantsPostRouteBadRequestError(t *testing.T) { router := testSetupRouter(config, newMockedCoordinator) json := []byte(`{"a": "new_tenant"}`) - req, _ := newTestRequest(http.MethodPost, "/tenants", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/tenants", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -314,7 +325,7 @@ func TestTenantsPostRouteCheckSumError(t *testing.T) { router := testSetupRouter(config, newMockedErrorCoordinator(0)) json := []byte(`{"name": "new_tenant", "response": "full", "mode":"dry-run"}`) - req, _ := newTestRequest(http.MethodPost, "/tenants", bytes.NewBuffer(json)) + req, _ := newTestRequestV1(http.MethodPost, "/tenants", bytes.NewBuffer(json)) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -331,10 +342,68 @@ func TestRouteError(t *testing.T) { router := testSetupRouter(config, newMockedErrorCoordinator(0)) w := httptest.NewRecorder() - req, _ := newTestRequest("GET", "/migrations/source", nil) + req, _ := newTestRequestV1("GET", "/migrations/source", nil) router.ServeHTTP(w, req) assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, "application/json; charset=utf-8", w.HeaderMap["Content-Type"][0]) assert.Equal(t, `{"error":"Mocked Error Disk Loader: threshold 0 reached"}`, strings.TrimSpace(w.Body.String())) } + +// /v2 API + +func TestGraphQLSchema(t *testing.T) { + config, err := config.FromFile(configFile) + assert.Nil(t, err) + + router := testSetupRouter(config, newMockedCoordinator) + + w := httptest.NewRecorder() + req, _ := newTestRequestV2("GET", "/schema", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "text/plain; charset=utf-8", 
w.HeaderMap["Content-Type"][0]) + assert.Equal(t, strings.TrimSpace(data.SchemaDefinition), strings.TrimSpace(w.Body.String())) +} + +func TestGraphQLQueryWithVariables(t *testing.T) { + config, err := config.FromFile(configFile) + assert.Nil(t, err) + + router := testSetupRouter(config, newMockedCoordinator) + + w := httptest.NewRecorder() + req, _ := newTestRequestV2("POST", "/service", strings.NewReader(` + { + "query": "query SourceMigration($file: String!) { sourceMigration(file: $file) { name, migrationType, sourceDir, file } }", + "operationName": "SourceMigration", + "variables": { "file": "source/201602220001.sql" } + } + `)) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json; charset=utf-8", w.HeaderMap["Content-Type"][0]) + assert.Equal(t, `{"data":{"sourceMigration":{"name":"201602220001.sql","migrationType":"SingleMigration","sourceDir":"source","file":"source/201602220001.sql"}}}`, strings.TrimSpace(w.Body.String())) +} + +func TestGraphQLQueryError(t *testing.T) { + config, err := config.FromFile(configFile) + assert.Nil(t, err) + + router := testSetupRouter(config, newMockedCoordinator) + + w := httptest.NewRecorder() + req, _ := newTestRequestV2("POST", "/service", strings.NewReader(` + { + "query": "query SourceMigration($file: String!) { sourceMigration(file: $file) { name, migrationType, sourceDir, file } }", + "operationName": "SourceMigration", + } + `)) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Equal(t, "application/json; charset=utf-8", w.HeaderMap["Content-Type"][0]) + assert.Equal(t, `{"error":"Invalid request, please see documentation for valid JSON payload"}`, strings.TrimSpace(w.Body.String())) +} diff --git a/types/types.go b/types/types.go index 80ca79a..22e9392 100644 --- a/types/types.go +++ b/types/types.go @@ -1,8 +1,9 @@ package types import ( - "time" + "fmt" + "github.com/graph-gophers/graphql-go" "gopkg.in/go-playground/validator.v9" ) @@ -20,6 +21,48 @@ const ( MigrationTypeTenantScript MigrationType = 4 ) +// ImplementsGraphQLType maps MigrationType Go type +// to the graphql scalar type in the schema +func (MigrationType) ImplementsGraphQLType(name string) bool { + return name == "MigrationType" +} + +// String converts MigrationType Go type to string literal +func (t MigrationType) String() string { + switch t { + case MigrationTypeSingleMigration: + return "SingleMigration" + case MigrationTypeTenantMigration: + return "TenantMigration" + case MigrationTypeSingleScript: + return "SingleScript" + case MigrationTypeTenantScript: + return "TenantScript" + default: + panic(fmt.Sprintf("Unknown MigrationType value: %v", uint32(t))) + } +} + +// UnmarshalGraphQL converts string literal to MigrationType Go type +func (t *MigrationType) UnmarshalGraphQL(input interface{}) error { + if str, ok := input.(string); ok { + switch str { + case "SingleMigration": + *t = MigrationTypeSingleMigration + case "TenantMigration": + *t = MigrationTypeTenantMigration + case "SingleScript": + *t = MigrationTypeSingleScript + case "TenantScript": + *t = MigrationTypeTenantScript + default: + panic(fmt.Sprintf("Unknown MigrationType literal: %v", str)) + } + return nil + } + return fmt.Errorf("Wrong type for MigrationType: %T", input) +} + // MigrationsResponseType represents type of response either full or summary type MigrationsResponseType string @@ -62,6 +105,19 @@ func ValidateMigrationsResponseType(fl validator.FieldLevel) bool { return false } +// Tenant contains 
basic information about tenant +type Tenant struct { + Name string `json:"name"` +} + +// Version contains information about migrator versions +type Version struct { + ID int32 `json:"id"` + Name string `json:"name"` + Created graphql.Time `json:"created"` + DBMigrations []DBMigration `json:"dbMigrations"` +} + // Migration contains basic information about migration type Migration struct { Name string `json:"name"` @@ -72,26 +128,109 @@ type Migration struct { CheckSum string `json:"checkSum"` } +// DBMigration embeds Migration and adds DB-specific fields +// replaces deprecated MigrationDB +type DBMigration = MigrationDB + // MigrationDB embeds Migration and adds DB-specific fields +// deprecated in v2020.1.0 sunset in v2021.1.0 +// replaced by DBMigration type MigrationDB struct { Migration - Schema string `json:"schema"` - AppliedAt time.Time `json:"appliedAt"` + ID int32 `json:"id"` + Schema string `json:"schema"` + // appliedAt is deprecated, the SQL column is already called created + // API v1 uses AppliedAt + // this field is ignored by GraphQL + AppliedAt graphql.Time `json:"appliedAt"` + // API v2 uses Created + // this field is returned together with appliedAt + // however it does not break the API contract as this is a new field + Created graphql.Time `json:"created"` } +// Summary contains summary information about a created version +// replaces deprecated MigrationResults +type Summary = MigrationResults + // MigrationResults contains summary information about executed migrations +// deprecated in v2020.1.0 sunset in v2021.1.0 +// replaced by Summary type MigrationResults struct { - StartedAt time.Time `json:"startedAt"` - Duration time.Duration `json:"duration"` - Tenants int `json:"tenants"` - SingleMigrations int `json:"singleMigrations"` - TenantMigrations int `json:"tenantMigrations"` - TenantMigrationsTotal int `json:"tenantMigrationsTotal"` // tenant migrations for all tenants - MigrationsGrandTotal int `json:"migrationsGrandTotal"` // total number of all migrations applied - SingleScripts int `json:"singleScripts"` - TenantScripts int `json:"tenantScripts"` - TenantScriptsTotal int `json:"tenantScriptsTotal"` // tenant scripts for all tenants - ScriptsGrandTotal int `json:"scriptsGrandTotal"` // total number of all scripts applied + StartedAt graphql.Time `json:"startedAt"` + Duration int32 `json:"duration"` + Tenants int32 `json:"tenants"` + SingleMigrations int32 `json:"singleMigrations"` + TenantMigrations int32 `json:"tenantMigrations"` + TenantMigrationsTotal int32 `json:"tenantMigrationsTotal"` // tenant migrations for all tenants + MigrationsGrandTotal int32 `json:"migrationsGrandTotal"` // total number of all migrations applied + SingleScripts int32 `json:"singleScripts"` + TenantScripts int32 `json:"tenantScripts"` + TenantScriptsTotal int32 `json:"tenantScriptsTotal"` // tenant scripts for all tenants + ScriptsGrandTotal int32 `json:"scriptsGrandTotal"` // total number of all scripts applied +} + +// CreateResults contains results of CreateVersion or CreateTenant +type CreateResults struct { + Summary *Summary + Version *Version +} + +// Action stores information about migrator action +type Action int + +const ( + // ActionApply (the default action) tells migrator to apply all source migrations + ActionApply Action = iota + // ActionSync tells migrator to synchronise source migrations and not apply them + ActionSync +) + +// ImplementsGraphQLType maps Action Go type +// to the graphql scalar type in the schema +func (Action) ImplementsGraphQLType(name string) bool { + 
return name == "Action" +} + +// String converts MigrationType Go type to string literal +func (a Action) String() string { + switch a { + case ActionSync: + return "Sync" + case ActionApply: + return "Apply" + default: + panic(fmt.Sprintf("Unknown Action value: %v", uint32(a))) + } +} + +// UnmarshalGraphQL converts string literal to MigrationType Go type +func (a *Action) UnmarshalGraphQL(input interface{}) error { + if str, ok := input.(string); ok { + switch str { + case "Sync": + *a = ActionSync + case "Apply": + *a = ActionApply + default: + panic(fmt.Sprintf("Unknown Action literal: %v", str)) + } + return nil + } + return fmt.Errorf("Wrong type for Action: %T", input) +} + +type VersionInput struct { + VersionName string + Action Action + DryRun bool +} + +type TenantInput struct { + VersionName string + Action Action + DryRun bool + TenantName string } // VersionInfo contains build information and supported API versions