diff --git a/database_validation.sh b/database_validation.sh new file mode 100755 index 0000000..ad29579 --- /dev/null +++ b/database_validation.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +# Function to display usage +usage() { + echo "Usage: $0 -s -t -P -d -U -u -S -T " + exit 1 +} + +# Parse command-line arguments +while getopts ":s:t:P:d:U:u:S:T:" opt; do + case "${opt}" in + s) + SRC_DB_HOST=${OPTARG} + ;; + t) + TRG_DB_HOST=${OPTARG} + ;; + P) + DB_PORT=${OPTARG} + ;; + d) + DB_NAME=${OPTARG} + ;; + U) + SRC_DB_USER=${OPTARG} + ;; + u) + TRG_DB_USER=${OPTARG} + ;; + S) + SRC_DB_PASSWORD=${OPTARG} + ;; + T) + TRG_DB_PASSWORD=${OPTARG} + ;; + *) + usage + ;; + esac +done + +# Check if all required arguments are provided +if [ -z "${SRC_DB_HOST}" ] || [ -z "${TRG_DB_HOST}" ] || [ -z "${DB_PORT}" ] || [ -z "${DB_NAME}" ] || [ -z "${SRC_DB_USER}" ] || [ -z "${TRG_DB_USER}" ] || [ -z "${SRC_DB_PASSWORD}" ] || [ -z "${TRG_DB_PASSWORD}" ]; then + usage +fi + +# Export the passwords to be used by psql +export PGPASSWORD_SRC="${SRC_DB_PASSWORD}" +export PGPASSWORD_TRG="${TRG_DB_PASSWORD}" + +# Get the list of tables in the source database +tables=$(PGPASSWORD="${PGPASSWORD_SRC}" psql -h "${SRC_DB_HOST}" -p "${DB_PORT}" -U "${SRC_DB_USER}" -d "${DB_NAME}" -t -c "SELECT tablename FROM pg_tables WHERE schemaname='public';") + +# Print the header +printf "%-30s %-20s %-20s\n" "Table Name" "Row Count (Source)" "Row Count (Target)" +printf "%-30s %-20s %-20s\n" "----------" "-----------------" "-----------------" + +# Iterate over each table and count the rows in source and target +for table in $tables; do + table=$(echo $table | xargs) # Trim any leading or trailing whitespace + # Check if the table exists in the source database + table_exists_source=$(PGPASSWORD="${PGPASSWORD_SRC}" psql -h "${SRC_DB_HOST}" -p "${DB_PORT}" -U "${SRC_DB_USER}" -d "${DB_NAME}" -t -c "SELECT to_regclass('public.${table}');") + # Check if the table exists in the target database + table_exists_target=$(PGPASSWORD="${PGPASSWORD_TRG}" psql -h "${TRG_DB_HOST}" -p "${DB_PORT}" -U "${TRG_DB_USER}" -d "${DB_NAME}" -t -c "SELECT to_regclass('public.${table}');") + + if [[ $table_exists_source != " " && $table_exists_source != "" && $table_exists_target != " " && $table_exists_target != "" ]]; then + row_count_source=$(PGPASSWORD="${PGPASSWORD_SRC}" psql -h "${SRC_DB_HOST}" -p "${DB_PORT}" -U "${SRC_DB_USER}" -d "${DB_NAME}" -t -c "SELECT COUNT(*) FROM \"$table\";") + row_count_target=$(PGPASSWORD="${PGPASSWORD_TRG}" psql -h "${TRG_DB_HOST}" -p "${DB_PORT}" -U "${TRG_DB_USER}" -d "${DB_NAME}" -t -c "SELECT COUNT(*) FROM \"$table\";") + printf "%-30s %-20s %-20s\n" "$table" "$row_count_source" "$row_count_target" + elif [[ $table_exists_source != " " && $table_exists_source != "" ]]; then + row_count_source=$(PGPASSWORD="${PGPASSWORD_SRC}" psql -h "${SRC_DB_HOST}" -p "${DB_PORT}" -U "${SRC_DB_USER}" -d "${DB_NAME}" -t -c "SELECT COUNT(*) FROM \"$table\";") + printf "%-30s %-20s %-20s\n" "$table" "$row_count_source" "Table does not exist" + elif [[ $table_exists_target != " " && $table_exists_target != "" ]]; then + row_count_target=$(PGPASSWORD="${PGPASSWORD_TRG}" psql -h "${TRG_DB_HOST}" -p "${DB_PORT}" -U "${TRG_DB_USER}" -d "${DB_NAME}" -t -c "SELECT COUNT(*) FROM \"$table\";") + printf "%-30s %-20s %-20s\n" "$table" "Table does not exist" "$row_count_target" + else + printf "%-30s %-20s %-20s\n" "$table" "Table does not exist" "Table does not exist" + fi +done + diff --git a/pg_migration_tool/main.py b/pg_migration_tool/main.py index c7760a5..6d24993 100644 --- a/pg_migration_tool/main.py +++ b/pg_migration_tool/main.py @@ -242,45 +242,66 @@ def stream_error(process): thread_err.start() async def validate_migration(self): - self.query_one(Log).write_line("Starting validation...") + self.query_one(Log).write_line("Starting validation...") - db = config["dbs"][self.title] + db = config["dbs"][self.title] - source_conn = await asyncpg.connect( - database=db["source"]["db_database_name"], - user=db["source"]["db_username"], - password=db["source"]["db_password"], - host=db["source"]["db_connection_host"], - port=db.get('port', 5432), - ) + # Connect to source and target databases + source_conn = await asyncpg.connect( + database=db["source"]["db_database_name"], + user=db["source"]["db_username"], + password=db["source"]["db_password"], + host=db["source"]["db_connection_host"], + port=db.get('port', 5432), + ) - target_conn = await asyncpg.connect( - database=db["target"]["db_database_name"], - user=db["target"]["db_username"], - password=db["target"]["db_password"], - host=db["target"]["db_connection_host"], - port=db.get('port', 5432), - ) + target_conn = await asyncpg.connect( + database=db["target"]["db_database_name"], + user=db["target"]["db_username"], + password=db["target"]["db_password"], + host=db["target"]["db_connection_host"], + port=db.get('port', 5432), + ) + + # Get the list of tables in the source database + source_tables = await source_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname='public';") + source_table_names = {table['tablename'] for table in source_tables} + + # Get the list of tables in the target database + target_tables = await target_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname='public';") + target_table_names = {table['tablename'] for table in target_tables} + + # Print the header for the validation results + validation_results = "| Table | Source Rows | Target Rows | Match |\n" + validation_results += "| --- | --- | --- | --- |\n" - source_tables = await source_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname='public' AND schemaname NOT LIKE 'awsdms_%';") - target_tables = await target_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname='public' AND schemaname NOT LIKE 'awsdms_%';") + # Validate each table + for table_name in source_table_names.union(target_table_names): + source_count = target_count = "N/A" + match = "No" - validation_results = "| Table | Source Rows | Target Rows | Match |\n" - validation_results += "| --- | --- | --- | --- |\n" + if table_name in source_table_names: + try: + source_count = await source_conn.fetchval(f"SELECT COUNT(*) FROM {table_name};") + except Exception as e: + source_count = f"Error: {e}" - for table in source_tables: - table_name = table["tablename"] - source_count = await source_conn.fetchval(f"SELECT COUNT(*) FROM {table_name};") - target_count = await target_conn.fetchval(f"SELECT COUNT(*) FROM {table_name};") + if table_name in target_table_names: + try: + target_count = await target_conn.fetchval(f"SELECT COUNT(*) FROM {table_name};") + except Exception as e: + target_count = f"Error: {e}" - match = "Yes" if source_count == target_count else "No" - validation_results += f"| {table_name} | {source_count} | {target_count} | {match} |\n" + if source_count == target_count: + match = "Yes" - self.query_one(Markdown).update(validation_results) + validation_results += f"| {table_name} | {source_count} | {target_count} | {match} |\n" - await source_conn.close() - await target_conn.close() + self.query_one(Markdown).update(validation_results) + # Close connections + await source_conn.close() + await target_conn.close() @on(Print) def log_printed(self, event: Print):