diff --git a/src/__test__/database-schema.test.ts b/src/__test__/database-schema.test.ts index e7f098b..5fa4b19 100644 --- a/src/__test__/database-schema.test.ts +++ b/src/__test__/database-schema.test.ts @@ -321,3 +321,50 @@ describe('migrate to version', () => { .rejects.toThrowError('Target version of migrateToVersion() has to be greater 1') }) }) + +describe('multi-node environment', () => { + const simulateNode = async (nodeOprations: (...args: any[]) => Promise, nodeName: string) => { + await nodeOprations(nodeName) + } + + test('test', async () => { + const { database } = await setupTest() + + const migration2 = jest.fn() + const migration3 = jest.fn() + const migration4 = jest.fn() + const migration5 = jest.fn() + + const migrations = new Map() + migrations.set(5, Migration(migration5)) + migrations.set(2, Migration(migration2)) + migrations.set(3, Migration(migration3)) + migrations.set(4, Migration(migration4)) + + const operations = async (nodeName: string) => { + const databaseSchema = DatabaseSchema({ + name: 'TestSchema', + client: database, + createStatements: composeCreateTableStatements(TestTables), + migrations, + }) + + await databaseSchema.init() + await databaseSchema.migrateLatest() + + expect(databaseSchema.getVersion()).toBe(5) + } + + await Promise.all([ + simulateNode(operations, 'Node1'), + simulateNode(operations, 'Node2'), + simulateNode(operations, 'Node3'), + simulateNode(operations, 'Node4'), + ]) + + expect(migration2).toHaveBeenCalledTimes(1) + expect(migration3).toHaveBeenCalledTimes(1) + expect(migration4).toHaveBeenCalledTimes(1) + expect(migration5).toHaveBeenCalledTimes(1) + }) +}) diff --git a/src/database-schema.ts b/src/database-schema.ts index f8ac496..5800545 100644 --- a/src/database-schema.ts +++ b/src/database-schema.ts @@ -1,10 +1,12 @@ import { IDatabaseClient, IDatabaseBaseClient } from './database-client' import { TableSchema, ColumnType, NativeFunction, Table } from './table'; +import { SQL } from './sql'; const schema_management = TableSchema({ - name: { type: ColumnType.Varchar, primaryKey: true, nullable: false }, + name: { type: ColumnType.Varchar, primaryKey: true, nullable: false, unique: true }, version: { type: ColumnType.Integer, nullable: false }, date_added: { type: ColumnType.TimestampTZ, nullable: false, defaultValue: { func: NativeFunction.Now } }, + locked: { type: ColumnType.Boolean, nullable: false, defaultValue: false }, }) const SchemaManagementTable = Table({ schema_management }, 'schema_management') @@ -12,13 +14,7 @@ const selectVersionQuery = (name: string) => SchemaManagementTable.select('*', [ const insertSchemaQuery = (name: string, version: number) => SchemaManagementTable.insertFromObj({ name, version }) const updateSchemaVersionQuery = (name: string, newVersion: number) => SchemaManagementTable.update(['version'], ['name'])([newVersion], [name]) -export interface IDatabaseSchema { - readonly name: string; - getVersion(): number; - init(): Promise; - migrateLatest(): Promise; - migrateToVersion(version: number): Promise; -} +export type IDatabaseSchema = ReturnType export interface IMigration { up: (client: IDatabaseBaseClient) => Promise; @@ -36,7 +32,7 @@ export interface IDatabaseSchemaArgs { logMigrations?: boolean; } -export const DatabaseSchema = ({ client, createStatements, name, migrations, logMigrations }: IDatabaseSchemaArgs): IDatabaseSchema => { +export const DatabaseSchema = ({ client, createStatements, name, migrations, logMigrations }: IDatabaseSchemaArgs) => { let version = 0 let isInitialized = false @@ -45,22 +41,28 @@ export const DatabaseSchema = ({ client, createStatements, name, migrations, log throw new Error(`Database schema ${name} has already been initialized.`) } - await client.transaction(async (transaction) => { - await transaction.query(SchemaManagementTable.create()) + try { + await client.transaction(async (transaction) => { + await transaction.query(SchemaManagementTable.create()) - const versionDBResults = await transaction.query(selectVersionQuery(name)) + const versionDBResults = await transaction.query(selectVersionQuery(name)) - if (versionDBResults.length === 0) { - await transaction.query({ - sql: createStatements.join('\n'), - }) - await transaction.query(insertSchemaQuery(name, 1)) + if (versionDBResults.length === 0) { + await transaction.query({ + sql: createStatements.join('\n'), + }) + await transaction.query(insertSchemaQuery(name, 1)) - version = 1 - } else { - version = versionDBResults[0].version + version = 1 + } else { + version = versionDBResults[0].version + } + }) + } catch (err) { + if (err.message.indexOf('duplicate key value violates unique constraint') === -1) { + throw err } - }) + } isInitialized = true } @@ -69,6 +71,35 @@ export const DatabaseSchema = ({ client, createStatements, name, migrations, log throw new Error(`Migration failed, database schema is not initialized. Please call init() first on your database schema.`) } + const lockSchemaTableQuery = SQL.raw(` + LOCK TABLE ${SchemaManagementTable.name} IN ACCESS EXCLUSIVE MODE; + `, []) + const getSchemaVersionQuery = (awaitLock: boolean) => SQL.raw(` + SELECT * FROM ${SchemaManagementTable.name} + WHERE name = $1 ${!awaitLock ? 'FOR UPDATE NOWAIT' : ''}; + `, [name]) + const setSchemaLockQuery = (locked: boolean) => SQL.raw(` + UPDATE ${SchemaManagementTable.name} SET locked = $1 WHERE name=$2; + `, [locked, name]) + + /* + Locks schema_management table for given transaction and retrievs current schema version + If table is already locked, the postgres client is advised to await execution until lock is released + This ensures, that in a multi-node environment all starting nodes proceed code execution after all migrations are done + */ + const getCurrentVersionAndLockSchema = async (client: IDatabaseBaseClient, awaitLock: boolean) => { + await client.query(lockSchemaTableQuery) + const dbResults = await client.query(getSchemaVersionQuery(awaitLock)) + + if (dbResults.length === 1 && dbResults[0].locked === false) { + await client.query(setSchemaLockQuery(true)) + + return dbResults[0].version + } + + return null + } + const migrateToVersion = async (targetVersion: number) => { if (!isInitialized) throwNotInitialized() @@ -76,26 +107,37 @@ export const DatabaseSchema = ({ client, createStatements, name, migrations, log throw new Error('Target version of migrateToVersion() has to be greater 1') } - const currentVersion = version - - for (let newVersion = currentVersion + 1; newVersion <= targetVersion; newVersion -= -1) { + for (let newVersion = version; newVersion <= targetVersion; newVersion++) { await client.transaction(async (transaction) => { + const currentVersion = await getCurrentVersionAndLockSchema(transaction, true) + + if (currentVersion === null || currentVersion >= newVersion) { + if (currentVersion) { + await transaction.query(setSchemaLockQuery(false)) + } + + return + } + const migration = migrations.get(newVersion) if (!migration) { + await transaction.query(setSchemaLockQuery(false)) + throw new Error(`Migration with version ${newVersion} not found. Aborting migration process...`) } await migration.up(transaction) await transaction.query(updateSchemaVersionQuery(name, newVersion)) + await transaction.query(setSchemaLockQuery(false)) + + // istanbul ignore next + if (!(logMigrations === false)) { + console.info(`Successfully migrated ${name} from version ${version} to ${newVersion}`) + } }) version = newVersion - - // istanbul ignore next - if (!(logMigrations === false)) { - console.info(`Successfully migrated ${name} from version ${version - 1} to ${version}`) - } } } diff --git a/src/sql.ts b/src/sql.ts index 464d271..0a1436b 100644 --- a/src/sql.ts +++ b/src/sql.ts @@ -1,7 +1,7 @@ import { Columns, Column, IReferenceConstraintInternal, isCollection, isSQLFunction, ForeignKeyUpdateDeleteRule, ICreateIndexStatement, IQuery, isJSONType, IWhereConditionColumned, ISQLArg } from "./table"; import * as pgEscape from 'pg-escape'; import { dateToSQLUTCFormat } from "./sql-utils"; -import moment from 'moment' +import * as moment from 'moment' import { flatten } from './utils' const isStringArray = (arr: any): arr is string[] => Array.isArray(arr) && arr.every(item => typeof item === 'string')